pytorch中间层输出方法

 2023-09-10 阅读 20 评论 0

摘要:大纲引言一、钩子截流附:钩子函数二、视作输出参考 引言 pytorch底层, 本文想要解决的是pytorch中间层的输出问题,有时我们训练神经网络时会设定回归或者分类作为目标,但在测试阶段实际需要的只是用神经网络提取输入的表征,因此需要获得网络的

大纲

  • 引言
  • 一、钩子截流
    • 附:钩子函数
  • 二、视作输出
  • 参考

引言

pytorch底层, 本文想要解决的是pytorch中间层的输出问题,有时我们训练神经网络时会设定回归或者分类作为目标,但在测试阶段实际需要的只是用神经网络提取输入的表征,因此需要获得网络的中间层输出。总结起来有两种方法:

  • 一是在前向传播过程中通过"钩子"截取;
  • 二是将中间层输出也视作输出,在最后获取

一、钩子截流

 这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。给出代码如下:

import torch
from torch import  nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)return self.linear_class(x)features = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
#  设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))
net(a)
print(features)

实际过程中建议先打印所有层的名字以做到精确提取。

附:钩子函数

pytorch embedding? 值得注意的是这个函数的用途并不止于提取中间层的输出,它也可以用于对module的输出值进行修改,查看该函数的源码注释

r"""Registers a forward hook on the module.

The hook will be called every time after :func:forward has
computed an output. It should have the following signature::

   hook(module, input, output) -> None or modified output

pytorch batchsize,The input contains only the positional arguments given to the
module. Keyword arguments won’t be passed to the hooks and only to
the forward. The hook can modify the output. It can modify the
input inplace but it will not have effect on forward since this is
called after :func:forward is called.

我们可以分析得到:
 register_hook()函数在对应module前向传播产生输出后自动执行,回调函数的输入只包括了module的位置参数不包括关键字参数。回调函数可以通过return 修改过的输出值来对module的最终输出进行修改,同样在回调函数内部我们也可以对输入进行inplace修改,但并不会对module的输出值造成影响因为register_hook()是在对应module前向传播产生输出后之执行,输入值已经被计算过了。(这里本人对最后一句话的理解与参考中的不一样,但未经验证过。)

二、视作输出

 通过返回值提取中间层输出比较简单,同样有两种方法来实现:

  • 一是将中间层的返回值作为模型的属性,在初始化时定义好;
  • 二是在forward函数将中间层返回值一并输出;

pytorch parameter。代码如下:

import torch
from torch import  nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())self.feature=[]def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)self.feature.append(x.detach())feature = x.detach()return self.linear_class(x),featurefeatures = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
#  设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))_,final_out=net(a)
hook_out=features
att_out=net.feature

对比可以发现三者输出是一致的。

参考

Pytorch获取中间层输出的几种方法
pytorch的hook机制之register_forward_hook

版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。

原文链接:https://808629.com/39836.html

发表评论:

本站为非赢利网站,部分文章来源或改编自互联网及其他公众平台,主要目的在于分享信息,版权归原作者所有,内容仅供读者参考,如有侵权请联系我们删除!

Copyright © 2022 86后生记录生活 Inc. 保留所有权利。

底部版权信息