刷脸支付 LVS 跨域 自定义指令 winforms wcf magento parameters casting scope scrapy vue前端 十大erp系统 bootstrap日历控件 新手学c还是java mysql新增用户和权限 java获取字符串 python代码示例 python环境设置 python简易教程 python读文件 javapackage java实例变量 java如何使用 java获取数据类型 java泛型方法 java抛出自定义异常 java字符串操作 java文件输入输出 php项目实例 php实例教程 战地女记者 vbscript程序员参考手册 vbs表白代码 地球末日攻略 网络适配器下载 自动回复机器人 陌陌电脑直播设置教程 spss22安装教程 lol卡米尔
当前位置: 首页 > 学习教程  > 编程语言

pytorch(3)torch.cat()和torch.stack()

2020/12/5 10:34:06 文章标签:

在进行cv相关实验中我们用的比较多的都是torch.cat()和torch.stack()函数。其中cat()函数的功能是在当前维度进行数据的拼接。stack()函数首先将当前维度及其以后维度的数据向后移动一位,将该位置的大小修改为1然后在进行拼接。这两个函数的区别是cat()函数不会进行…

       在进行cv相关实验中我们用的比较多的都是torch.cat()和torch.stack()函数。其中cat()函数的功能是在当前维度进行数据的拼接。stack()函数首先将当前维度及其以后维度的数据向后移动一位,将该位置的大小修改为1然后在进行拼接。这两个函数的区别是cat()函数不会进行维度扩展,stack()函数会进行维度扩展。以下从代码角度去理解这两个函数。

import numpy as np
import torch


# 创建两个张量并将其cat到一起。
def create_cat():
    t1 = torch.zeros((3, 3))
    t2 = torch.ones((3, 3))
    t_cat = torch.cat([t1, t2], dim=1)
    print(t_cat,t_cat.shape)

# 创建两个张量并将其stack到一起。
def create_stack():
    t1 = torch.zeros((3,3))
    t2 = torch.ones((3,3))
    t_stack = torch.stack([t1,t2],dim=1)
    print(t_stack,t_stack.shape)


if __name__ == '__main__':
    create_cat()
    create_stack()

拼接
冲上述图可以看到cat()和stack()都是在1维进行拼接的,cat()函数是直接在1维拼接原数据由两个[3,3]大小的数据变成了[3,6]的。stack()函数将数据[3,3]扩充成两个[3,1,3]然后在1维进行拼接变成了[3,2,3].
以下是画图表示,大家凑合着看哈。
由[3*3*1]拓展到[3*3*2]

由[1*3*3]拓展到[2*3*3],由[3*1*3]拓展到[3*2*3]


本文链接: http://www.dtmao.cc/news_show_450256.shtml

附件下载

相关教程

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?