MindSpore和Python中nn.Unfold的区别
创始人
2024-02-07 12:23:59
0

在往MindSpore迁移项目中遇到了这个转换,以至于不得不去仔细研究一下。

Unfold是卷积操作中的一部分,我们来看一下描述。

Unfold()函数是从一个batch图片中,提取出滑动的局部区域块,也就是卷积操作中的提取kernel filter对应的滑动窗口。Mindspore和pytorch的功能比较可以参考官网的链接。
比较与torch.nn.Unfold的功能差异 — MindSpore master documentation

首先强调一下几个我认为重点的东西

1、MindSpore的Unfold只能GPU使用无法在CPU上使用,测试也是不行的。

2、MindSpore这个方法输出的是四维的,Pytorch的输出是三维的。而只要把MindSpore 的最后两个维度合并就是Pytorch的结果了,这个后续讲详细介绍。

3、官网举的例子特别容易让人误解,Pytorch的例子直接用的Pytorch用的他们官方的,但是自己的例子是自己编的,并且和pytorch的例子还不是对应的关系。这让我们刚刚开始使用MindSpore的特别容易误解。

下面详细说一下这两个框架中Unfold的输入输出,以至于可以快速进行迁移。

在Pytorch中需要的参数是

  • kernel_size (int or tuple) – 滑动窗口的大小

  • stride (int or tupleoptional) – 滑动步长 Default: 1

  • padding (int or tupleoptional) – padding Default: 0

  • dilation (int or tupleoptional) – 空洞大小,这里默认1就是没有空洞,和conv中的有所区别. Default: 1 (这里就可以看出pytorch的文档写的确实好,他怕解释不清楚,给了可视化的链接conv_arithmetic/README.md at master · vdumoulin/conv_arithmetic · GitHub)

那么同理MindSpore这边的参数也是差不多的,但是格式真的是差距很大。也不知道为什么要求必须两边加个1,而padding的数量也是帮你定好了的。

  • ksizes (Union[tuple[int], list[int]]) - 滑窗大小,其格式为[1, ksize_row, ksize_col, 1]的int组成的tuple或list。

  • strides (Union[tuple[int], list[int]]) - 滑窗步长,其格式为[1, stride_row, stride_col, 1]的int组成的tuple或list。

  • rates (Union[tuple[int], list[int]]) - 滑窗元素之间的空洞个数,其格式为[1, rate_row, rate_col, 1] 的int组成的tuple或list。

  • padding (str) - 填充模式,可选值有:”same”或”valid”的字符串,不区分大小写。默认值:”valid”。

    • same - 指所提取的区域块的部分区域可以在原始图像之外,此部分填充为0。

    • valid - 表示所取的区域快必须被原始图像所覆盖。

好的,假如目前我们都设置了同样参数的Unfold,因为mindspore没有默认的参数,所以需要我们输入。那么同样参数的Unfold,得到的结果有什么差别呢?

py_unfold = torch.nn.Unfold(kernel_size=(2, 2)) # pytorch
ms_unfold = mindspore.nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])

假设input的维度为(N,C,W,H),kernel_size=(k1, k2)

那么Pytorch的输出维度为(N, C\times k1\times k2, L)

MindSpore的输出维度为(N, C\times k1\times k2, L_r\times L_w)

在这里L的求法都给了在数学上很复杂的公式,如果只是为了理解,那么完全没必要去看那个公式,在这里L = L_r\times L_w,这里第一个维度N就是batch_size,一直没有变,第二个维度是kernel乘以通道数C,最后一个维度是每一层会产生多少个小的窗口。用一张图就可以很容易解释了。加入input为(N,C,3,3),Unfold的核大小为(2,2),stride = 1 ,paddind = 0, dilation = 1。则最后的L = 4,如果使用MS,得到的为L_r\times L_w = 2\times 2

相关内容

热门资讯

喜欢穿一身黑的男生性格(喜欢穿... 今天百科达人给各位分享喜欢穿一身黑的男生性格的知识,其中也会对喜欢穿一身黑衣服的男人人好相处吗进行解...
发春是什么意思(思春和发春是什... 本篇文章极速百科给大家谈谈发春是什么意思,以及思春和发春是什么意思对应的知识点,希望对各位有所帮助,...
网络用语zl是什么意思(zl是... 今天给各位分享网络用语zl是什么意思的知识,其中也会对zl是啥意思是什么网络用语进行解释,如果能碰巧...
为什么酷狗音乐自己唱的歌不能下... 本篇文章极速百科小编给大家谈谈为什么酷狗音乐自己唱的歌不能下载到本地?,以及为什么酷狗下载的歌曲不是...
华为下载未安装的文件去哪找(华... 今天百科达人给各位分享华为下载未安装的文件去哪找的知识,其中也会对华为下载未安装的文件去哪找到进行解...
家里可以做假山养金鱼吗(假山能... 今天百科达人给各位分享家里可以做假山养金鱼吗的知识,其中也会对假山能放鱼缸里吗进行解释,如果能碰巧解...
四分五裂是什么生肖什么动物(四... 本篇文章极速百科小编给大家谈谈四分五裂是什么生肖什么动物,以及四分五裂打一生肖是什么对应的知识点,希...
怎么往应用助手里添加应用(应用... 今天百科达人给各位分享怎么往应用助手里添加应用的知识,其中也会对应用助手怎么添加微信进行解释,如果能...
客厅放八骏马摆件可以吗(家里摆... 今天给各位分享客厅放八骏马摆件可以吗的知识,其中也会对家里摆八骏马摆件好吗进行解释,如果能碰巧解决你...
美团联名卡审核成功待激活(美团... 今天百科达人给各位分享美团联名卡审核成功待激活的知识,其中也会对美团联名卡审核未通过进行解释,如果能...