如何优雅地读取网络的中间特征?

0.前言

在调试深度神经网络工程时,常会在前向计算过程中将网络的中间层信息返回,便于打印或者可视化网络中间结果。实现该功能的一个常用方法是在构建model类时,在forward返回要保留的中间信息。

这里跟大家分享一个更优雅、便捷的方法,利用torchvision提供的IntermediateLayerGetter类,在网络前向计算时返回指定的特征。

1.使用方法

IntermediateLayerGetter类在torchvision/models/_utils.py中实现。

一个简单的使用案例如下:

import torch
import torchvision.models as models

original_model = models.resnet18(pretrained=True)
wrapped_model = models._utils.IntermediateLayerGetter(original_model, {'layer1': 'feat1', 'layer3': 'feat2'})
out = wrapped_model(torch.rand(1, 3, 224, 224))
print(out['feat1'].shape)
print(out['feat2'].shape)

IntermediateLayerGetter类在实例化时,对原来的模型类进行了一层封装,且需要传入字典来指示想返回的中间特征名和访问特征时使用的name。

构造IntermediateLayerGetter时需要传入字典,字典的key来源于dict(original_model.named_children()).keys(),对于上例,key来源于:

dict_keys(['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc'])

传入字典的值是自己定义的字符串,在前向推理结束后的返回结果中,将自定义的字符串作为key来访问对应的中间变量。比如上例传入的字典是{'layer1': 'feat1', 'layer3': 'feat2'},则得到前向推理输出结果out后,通过out['feat1']访问layer1的输出,通过out['feat2']访问layer3的输出。

2.原理

IntermediateLayerGetter类的源码比较简单,如下:

class IntermediateLayerGetter(nn.ModuleDict):
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

本质上来讲,IntermediateLayerGetter的实例在初始化时,使用model.named_children()构造一个OrderDict,再用得到的OrderDict去初始化容器nn.ModuleDict()

在前向计算时按照nn.ModuleDict()容器的内容,顺序执行里面的模块;只是在执行时,会判断容器中模块的名字【即model.named_children()的key】是否在指定的返回值名字列表中,若在列表中,则保存该中间结果到返回值字典中。

这就是在实例化IntermediateLayerGetter时传入字典的key来源于dict(original_model.named_children()).keys()的原因。

3.局限性

根据前文IntermediateLayerGetter的实现方法以及原理,可以很容易发现使用IntermediateLayerGetter获取网络推理中间结果的局限性:

(1)只能获取model.named_children()级别的模块的输出特征,对于更细分模块的输出特征则无法获取;

(2)模型的顶层必须是可以顺序执行的,因为只有这样才能将model.named_children()获取的模块存到OrderDict中并封装为nn.ModuleDict()

4.开源工程使用案例

在DETR官方的开源代码中(链接:https://github.com/facebookresearch/detr),在文件models/backbone.pyBackboneBase类中使用了该方法获取其中model的中间结果。

推荐阅读

港科大提出适用于夜间场景语义分割的无监督域自适应新方法

EViT:借鉴鹰眼视觉结构,南开大学等提出ViT新骨干架构,在多个任务上涨点

HSN:微调预训练ViT用于目标检测和语义分割,华南理工和阿里巴巴联合提出

CV计算机视觉每日开源代码Paper with code速览-2023.10.13

CV计算机视觉每日开源代码Paper with code速览-2023.10.12

CV计算机视觉每日开源代码Paper with code速览-2023.10.10


http://www.niftyadmin.cn/n/5094091.html

相关文章

python的搜索引擎系统设计与实现 计算机竞赛

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 python的搜索引擎系统设计与实现 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分工作量:5分创新点:3分 该项目较为新颖&#xff…

好未来sre面经

好未来sre CDN DNS(域名系统)底层使用的是UDP(用户数据报协议)。 服务器响应慢怎么排查 检查网络连接:确保服务器与网络连接稳定,没有网络故障或带宽限制。可以尝试使用其他设备或工具测试网络连接。 检…

助力森林火情预警检测,基于YOLOv7-tiny、YOLOv7和YOLOv7x开发构建无人机航拍场景下的森林火情检测是别预警系统

火情的预警与检测识别对于保障林业安全,减少人员伤亡有着重要的意义,科学有效地早发现早扑灭是最有效的干预手段,本文的主要是想就是想要建立基于无人机航拍场景下的森林火情检测预警系统,整体效果如下所示: 这里文中选…

【django2.0之Rest_Framework框架一】rest_framework序列器介绍

Django RestFramework(简称DRF) 提供了序列化器Serialzier的定义,可以帮助我们简化序列化与反序列化的过程,不仅如此,还提供丰富的类视图、扩展类、视图集来简化视图的编写工作。REST framework还提供了认证、权限、限流、过滤、分页、接口文…

【C++ Primer Plus学习记录】指针——指针和字符串

数组和指针的特殊关系可以扩展到C-风格字符串。请看下面的代码&#xff1a; char flower[10] "rose"; cout << flower << "s are red\n"; 数组名是第一个元素的地址&#xff0c;因此cout语句中的flower是包含字符r的char元素的地址。cout对…

CSS 复杂卡片/导航栏特效运用目录

主要是记录复杂卡片/导航栏相关的特效实践案例和实现思路。 章节名称完成度难度文章地址完整代码下载地址多曲面卡片实现完成复杂文章链接代码下载倒置边框半径卡片完成一般文章链接代码下载

COMMUTING CONDITIONAL GANS FOR MULTI-MODAL FUSION

方法 C f ^f f是分类器&#xff0c;P f ^f f(o i _i i​)是第i个物体出现的融合概率 作者未公布代码

【计算机毕设选题推荐】蛋糕甜品店管理系统SpringBoot+SSM+Vue

前言&#xff1a;我是IT源码社&#xff0c;从事计算机开发行业数年&#xff0c;专注Java领域&#xff0c;专业提供程序设计开发、源码分享、技术指导讲解、定制和毕业设计服务 项目名 基于SpringBoot的蛋糕甜品店管理系统 技术栈 SpringBootSSMVueMySQLMaven 文章目录 一、蛋糕…