temporal shift module(TSM)
创始人
2024-04-11 07:53:04
0

【官方】Paddle2.1实现视频理解经典模型 — TSM - 飞桨AI Studio本项目将带大家深入理解视频理解领域经典模型TSM。从模型理论讲解入手,深入到代码实践。实践部分基于TSM模型在UCF101数据集上从训练到推理全流程实现行为识别任务。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/2310889?channelType=0&channel=0视频理解:基于TSM实现UCF101视频理解 - 飞桨AI Studio基于飞桨开源框架构建TSM,并实现对数据集UCF101的视频理解。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/4114499?channelType=0&channel=0

最近一直在做视频相关的项目,后续会陆续出一些视频理解和视频场景运动的案例,视频这块主推paddlevideo,里面应用层面的东西很丰富,paddle在应用侧一直做的比较好,模型训练这块可以结合mmaction2来,其实从实际应用角度来说,我觉得用paddle和pytorch训练都无所谓,部署的话可能以往我的经验更多是onnx,tensort服务侧的,目前来看,主要也就是服务器,端侧和页面侧的部署这三块,我看paddle分别有paddle inference、lite、js,国产框架中确实是首屈一指的,但是我自己的感觉是从我以前训练gan的结果看,paddle貌似要比pytorch的结果,一样的数据,一样的参数配置,好像要差一点。本文主要介绍一下tsm模块,利用2dcnn来模拟时序信息。视频中核心是视频动作识别,本质就是视频分类,可以用作特征提取,视频时序提取是输入一段长视频获取其中的时序片段,时空定位是同时获取视频中的人物物体的空间位置,核心三大任务,除此之外视频特征提取embedding,这块主要是结合多模态去做,视频,音频和文本侧特征的综合利用和提取。

1.时序信息维度 

上述这个视频序列从左向右播放和从右向左播放表达的意思是不同的,视频理解对视频顺序是强依赖的。

2.temporal shift module

这个模块是核心,其实tsm是可插拔模块,是可以很好的嵌入到resnet等模型中,上述图中,一种颜色是一帧,按照时序T上,一共是四帧,同一帧横向是一个channel,在cnn中channel是统一做cnn的,在a图中是没有shift的,在b中是离线shift操作,可见将channel中第一个向下移动,第二个向上移动,其实至于上下移动几个channel并没有很严的的限制,通常是分成几等分去移动,这样上下移动之后,则第一个channel会向下突出一帧,第二个channel会向上突出一帧,突出帧直接截断,空缺帧直接补0,这样在横向做cnn时,统一channel维度变引入不同色的帧,tsm正是通过这种平移的方式,TSM在特征图中引入 temporal 维度上的上下文交互,通过通道移动操作可以使得在当前帧中包含了前后两帧的通道信息,这样再进2D卷积操作就能像3D卷积一样直接提取视频的时空信息,提高了模型在时间维度上的建模能力。而online模式用于对视频类型的实时预测,在这种情况下,无法预知下一秒的图像,因此只能将channel维度由过去向现在移动,而不能从未来向现在移动。

3.缺点和改进

虽然时间位移的原理很简单,但作者发现直接将空间位移策略应用于时间维度并不能提供高性能和效率。具体来说,如果简单的转移所有通道,则会带来两个问题:

  1. 由于大量数据移动而导致的效率下降问题。位移操作不需要计算但是会涉及数据移动,数据移动增加了硬件上的内存占用和推理延迟,作者观察到在视频理解网络中,当使用naive shift策略时,CPU延迟增加13.7%,GPU延迟增加12.4%,使整体推理变慢。
  2. 空间建模能力变差导致性能下降,由于部分通道被转移到相邻帧,当前帧不能再访问通道中包含的信息,这可能会损失2D CNN主干的空间建模能力。与TSN基线相比,使用naive shift会降低2.6%的准确率。

为了解决naive shift的两个问题,TSM给出了相应的解决方法。

  1. 减少数据移动。 为了研究数据移动的影响,作者测量了TSM模型在不同硬件设备上的推理延迟,作者移动了不同比例的通道数并测量了延迟,位移方式分为无位移、部分位移(位移1/8、1/4、1/2的通道)和全部位移,使用ResNet-50主干和8帧输入测量模型。作者观察到,如果移动所有的通道,那么延迟开销将占CPU推理时间的13.7%,如果只移动一小部分通道,如1/8,则可将开销限制在3%左右。       
  2. 保持空间特征学习能力。 一种简单的TSM使用方法是将其直接插入到每个卷基层或残差模块前,如 所示,这种方法被称为 in-place shift,但是它会损失主干模型的空间特征学习能力,尤其当我们移动大量通道时,存储在通道中的当前帧信息会随着通道移动而丢失。为解决这个问题,作者提出了另一种方法,即将TSM放在残差模块的残差分支中,这种方法被称为 residual TSM,如所示,它可以解决退化的空间特征学习问题,因为原始的激活信息在时间转移后仍可通过identity映射访问。

 4.mmaction2中的代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import NonLocal3d
from torch.nn.modules.utils import _ntuplefrom ..builder import BACKBONES
from .resnet import ResNetclass NL3DWrapper(nn.Module):"""3D Non-local wrapper for ResNet50.Wrap ResNet layers with 3D NonLocal modules.Args:block (nn.Module): Residual blocks to be built.num_segments (int): Number of frame segments.non_local_cfg (dict): Config for non-local layers. Default: ``dict()``."""def __init__(self, block, num_segments, non_local_cfg=dict()):super(NL3DWrapper, self).__init__()self.block = blockself.non_local_cfg = non_local_cfgself.non_local_block = NonLocal3d(self.block.conv3.norm.num_features,**self.non_local_cfg)self.num_segments = num_segmentsdef forward(self, x):x = self.block(x)n, c, h, w = x.size()x = x.view(n // self.num_segments, self.num_segments, c, h,w).transpose(1, 2).contiguous()x = self.non_local_block(x)x = x.transpose(1, 2).contiguous().view(n, c, h, w)return xclass TemporalShift(nn.Module):"""Temporal shift module.This module is proposed in`TSM: Temporal Shift Module for Efficient Video Understanding`_Args:net (nn.module): Module to make temporal shift.num_segments (int): Number of frame segments. Default: 3.shift_div (int): Number of divisions for shift. Default: 8."""def __init__(self, net, num_segments=3, shift_div=8):super().__init__()self.net = netself.num_segments = num_segmentsself.shift_div = shift_divdef forward(self, x):"""Defines the computation performed at every call.Args:x (torch.Tensor): The input data.Returns:torch.Tensor: The output of the module."""x = self.shift(x, self.num_segments, shift_div=self.shift_div)return self.net(x)@staticmethoddef shift(x, num_segments, shift_div=3):"""Perform temporal shift operation on the feature.Args:x (torch.Tensor): The input feature to be shifted.num_segments (int): Number of frame segments.shift_div (int): Number of divisions for shift. Default: 3.Returns:torch.Tensor: The shifted feature."""# 假设当前feature map的通道是256,shift_div=3,# 那么就有256/3的特征进行shift left,256/3的特征进行shift right,其他一部分特征不动# num_segments每个视频采样的帧数# 每帧有c个通道,# [# [0_1,0_2,0_3,1_1,1_2,3_5,3_6,3_7]  第一帧,8个通道,但是shift_div表示这个通道维度被切分成3个等分# []  第二帧# []  第三帧# ]# [N, C, H, W]n, c, h, w = x.size()# [N // num_segments, num_segments, C, H*W]# can't use 5 dimensional array on PPL2D backend for caffex = x.view(-1, num_segments, c, h * w)# get shift foldfold = c // shift_div# split c channel into three parts:# left_split, mid_split, right_splitleft_split = x[:, :, :fold, :]mid_split = x[:, :, fold:2 * fold, :]right_split = x[:, :, 2 * fold:, :]# can't use torch.zeros(*A.shape) or torch.zeros_like(A)# because array on caffe inference must be got by computing# shift left on num_segments channel in `left_split`zeros = left_split - left_splitblank = zeros[:, :1, :, :]left_split = left_split[:, 1:, :, :]left_split = torch.cat((left_split, blank), 1)# shift right on num_segments channel in `mid_split`zeros = mid_split - mid_splitblank = zeros[:, :1, :, :]mid_split = mid_split[:, :-1, :, :]mid_split = torch.cat((blank, mid_split), 1)# right_split: no shift# concatenateout = torch.cat((left_split, mid_split, right_split), 2)# [N, C, H, W]# restore the original dimensionreturn out.view(n, c, h, w)@BACKBONES.register_module()
class ResNetTSM(ResNet):"""ResNet backbone for TSM.Args:num_segments (int): Number of frame segments. Default: 8.is_shift (bool): Whether to make temporal shift in reset layers.Default: True.non_local (Sequence[int]): Determine whether to apply non-local modulein the corresponding block of each stages. Default: (0, 0, 0, 0).non_local_cfg (dict): Config for non-local module. Default: ``dict()``.shift_div (int): Number of div for shift. Default: 8.shift_place (str): Places in resnet layers for shift, which is chosenfrom ['block', 'blockres'].If set to 'block', it will apply temporal shift to all child blocksin each resnet layer.If set to 'blockres', it will apply temporal shift to each `conv1`layer of all child blocks in each resnet layer.Default: 'blockres'.temporal_pool (bool): Whether to add temporal pooling. Default: False.**kwargs (keyword arguments, optional): Arguments for ResNet."""def __init__(self,depth,num_segments=8,is_shift=True,non_local=(0, 0, 0, 0),non_local_cfg=dict(),shift_div=8,shift_place='blockres',temporal_pool=False,**kwargs):super().__init__(depth, **kwargs)self.num_segments = num_segmentsself.is_shift = is_shiftself.shift_div = shift_divself.shift_place = shift_placeself.temporal_pool = temporal_poolself.non_local = non_localself.non_local_stages = _ntuple(self.num_stages)(non_local)self.non_local_cfg = non_local_cfgdef make_temporal_shift(self):"""Make temporal shift for some layers."""if self.temporal_pool:num_segment_list = [self.num_segments, self.num_segments // 2,self.num_segments // 2, self.num_segments // 2]else:num_segment_list = [self.num_segments] * 4if num_segment_list[-1] <= 0:raise ValueError('num_segment_list[-1] must be positive')if self.shift_place == 'block':def make_block_temporal(stage, num_segments):"""Make temporal shift on some blocks.Args:stage (nn.Module): Model layers to be shifted.num_segments (int): Number of frame segments.Returns:nn.Module: The shifted blocks."""blocks = list(stage.children())for i, b in enumerate(blocks):blocks[i] = TemporalShift(b, num_segments=num_segments, shift_div=self.shift_div)return nn.Sequential(*blocks)self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])elif 'blockres' in self.shift_place:n_round = 1if len(list(self.layer3.children())) >= 23:n_round = 2def make_block_temporal(stage, num_segments):"""Make temporal shift on some blocks.Args:stage (nn.Module): Model layers to be shifted.num_segments (int): Number of frame segments.Returns:nn.Module: The shifted blocks."""blocks = list(stage.children())for i, b in enumerate(blocks):if i % n_round == 0:blocks[i].conv1.conv = TemporalShift(b.conv1.conv,num_segments=num_segments,shift_div=self.shift_div)return nn.Sequential(*blocks)self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])else:raise NotImplementedErrordef make_temporal_pool(self):"""Make temporal pooling between layer1 and layer2, using a 3D maxpooling layer."""class TemporalPool(nn.Module):"""Temporal pool module.Wrap layer2 in ResNet50 with a 3D max pooling layer.Args:net (nn.Module): Module to make temporal pool.num_segments (int): Number of frame segments."""def __init__(self, net, num_segments):super().__init__()self.net = netself.num_segments = num_segmentsself.max_pool3d = nn.MaxPool3d(kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))def forward(self, x):# [N, C, H, W]n, c, h, w = x.size()# [N // num_segments, C, num_segments, H, W]x = x.view(n // self.num_segments, self.num_segments, c, h,w).transpose(1, 2)# [N // num_segmnets, C, num_segments // 2, H, W]x = self.max_pool3d(x)# [N // 2, C, H, W]x = x.transpose(1, 2).contiguous().view(n // 2, c, h, w)return self.net(x)self.layer2 = TemporalPool(self.layer2, self.num_segments)def make_non_local(self):# This part is for ResNet50for i in range(self.num_stages):non_local_stage = self.non_local_stages[i]if sum(non_local_stage) == 0:continuelayer_name = f'layer{i + 1}'res_layer = getattr(self, layer_name)for idx, non_local in enumerate(non_local_stage):if non_local:res_layer[idx] = NL3DWrapper(res_layer[idx],self.num_segments,self.non_local_cfg)def init_weights(self):"""Initiate the parameters either from existing checkpoint or fromscratch."""super().init_weights()if self.is_shift:self.make_temporal_shift()if len(self.non_local_cfg) != 0:self.make_non_local()if self.temporal_pool:self.make_temporal_pool()

 

相关内容

热门资讯

喜欢穿一身黑的男生性格(喜欢穿... 今天百科达人给各位分享喜欢穿一身黑的男生性格的知识,其中也会对喜欢穿一身黑衣服的男人人好相处吗进行解...
发春是什么意思(思春和发春是什... 本篇文章极速百科给大家谈谈发春是什么意思,以及思春和发春是什么意思对应的知识点,希望对各位有所帮助,...
网络用语zl是什么意思(zl是... 今天给各位分享网络用语zl是什么意思的知识,其中也会对zl是啥意思是什么网络用语进行解释,如果能碰巧...
为什么酷狗音乐自己唱的歌不能下... 本篇文章极速百科小编给大家谈谈为什么酷狗音乐自己唱的歌不能下载到本地?,以及为什么酷狗下载的歌曲不是...
家里可以做假山养金鱼吗(假山能... 今天百科达人给各位分享家里可以做假山养金鱼吗的知识,其中也会对假山能放鱼缸里吗进行解释,如果能碰巧解...
华为下载未安装的文件去哪找(华... 今天百科达人给各位分享华为下载未安装的文件去哪找的知识,其中也会对华为下载未安装的文件去哪找到进行解...
四分五裂是什么生肖什么动物(四... 本篇文章极速百科小编给大家谈谈四分五裂是什么生肖什么动物,以及四分五裂打一生肖是什么对应的知识点,希...
怎么往应用助手里添加应用(应用... 今天百科达人给各位分享怎么往应用助手里添加应用的知识,其中也会对应用助手怎么添加微信进行解释,如果能...
客厅放八骏马摆件可以吗(家里摆... 今天给各位分享客厅放八骏马摆件可以吗的知识,其中也会对家里摆八骏马摆件好吗进行解释,如果能碰巧解决你...
苏州离哪个飞机场近(苏州离哪个... 本篇文章极速百科小编给大家谈谈苏州离哪个飞机场近,以及苏州离哪个飞机场近点对应的知识点,希望对各位有...