pytorch底层, 本文想要解决的是pytorch中间层的输出问题,有时我们训练神经网络时会设定回归或者分类作为目标,但在测试阶段实际需要的只是用神经网络提取输入的表征,因此需要获得网络的中间层输出。总结起来有两种方法:
这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。给出代码如下:
import torch
from torch import nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)return self.linear_class(x)features = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
# 设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))
net(a)
print(features)
实际过程中建议先打印所有层的名字以做到精确提取。
pytorch embedding? 值得注意的是这个函数的用途并不止于提取中间层的输出,它也可以用于对module的输出值进行修改,查看该函数的源码注释
r"""Registers a forward hook on the module.
The hook will be called every time after :func:
forward
has
computed an output. It should have the following signature::hook(module, input, output) -> None or modified output
pytorch batchsize,The input contains only the positional arguments given to the
module. Keyword arguments won’t be passed to the hooks and only to
theforward
. The hook can modify the output. It can modify the
input inplace but it will not have effect on forward since this is
called after :func:forward
is called.
我们可以分析得到:
register_hook()函数在对应module前向传播产生输出后自动执行,回调函数的输入只包括了module的位置参数不包括关键字参数。回调函数可以通过return 修改过的输出值来对module的最终输出进行修改,同样在回调函数内部我们也可以对输入进行inplace修改,但并不会对module的输出值造成影响因为register_hook()是在对应module前向传播产生输出后之执行,输入值已经被计算过了。(这里本人对最后一句话的理解与参考中的不一样,但未经验证过。)
通过返回值提取中间层输出比较简单,同样有两种方法来实现:
pytorch parameter。代码如下:
import torch
from torch import nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())self.feature=[]def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)self.feature.append(x.detach())feature = x.detach()return self.linear_class(x),featurefeatures = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
# 设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))_,final_out=net(a)
hook_out=features
att_out=net.feature
对比可以发现三者输出是一致的。
Pytorch获取中间层输出的几种方法
pytorch的hook机制之register_forward_hook
版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。
工作时间:8:00-18:00
客服电话
电子邮件
admin@qq.com
扫码二维码
获取最新动态