端口映射 Zookeeper使用 Pytorch algorithm image scripting rss Animsition vue源码下载 click事件 js回调函数写法 matlab不等于 linux查询文件内容 python搭建环境 python的安装 python条件判断 python字典添加 java类与对象 java学习基础 java接口的使用 java求阶乘 java字符 tmac修改器 销售单软件 快点蛆虫成就单刷 su版本转换器 选择模拟位置信息应用 gg修改器下载 java字符串截取 流程图工具 comsol下载 工程html加密 微信小程序开发实例 蜘蛛皮肤 五笔字型86版 超级网游助手 砸金蛋抽奖活动 骰子牛牛 lol游戏环境异常 js组合
当前位置: 首页 > 学习教程  > 编程语言

机器学习笔记(4)Pytorch学习

2020/10/8 20:31:20 文章标签:

机器学习笔记(4)Pytorch学习 笔记(3)代码的遇到的一些问题和以及相关知识的进一步学习 读取数据 1.Dataset类 用images和labels来定义自己的数据集类,之后送入dataloader中 可以看这个博客,讲的很详细 h…

机器学习笔记(4)Pytorch学习

笔记(3)代码的遇到的一些问题和以及相关知识的进一步学习

读取数据

1.Dataset类
用images和labels来定义自己的数据集类,之后送入dataloader中
可以看这个博客,讲的很详细
https://blog.csdn.net/VictoriaW/article/details/72356453?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param

2.torchvision 包
PyTorch中专门用来处理图像的库,包含四个大类datasets,models,transforms,utils
2.1datasets
用来加载一些经典的数据集,即这些数据集可以不用1中的方法来自己定义,可以直接利用datasets来加载。

dataset_train = datasets.MNIST(
	root='./MNIST', 	#加载数据集的目录
    train=True,			#是否为训练集
    transform=None,		#是否进行图像预处理
    download=True)		#是否下载数据

2.2transforms
用来图像预处理
transforms.ToTensor():用来将像素值范围归到[0,1]
transforms.Normalize(mean,std):用来将像素值归到[-1,1]
transforms.Resize():图像缩放
transforms.Compose([]):组合多个预处理方法
2.3models
用来加载一些经典的模型

import torchvision.models as models
VGG1 = models.VGG()			       #q权重随机模型
VGG2 = models.VGG(pretrained=True) #预训练模型

3.torch.utils.data.DataLoader 生成训练所需batch数据的类。
batch的数量是指计算一次loss,输入的样本个数

train_loader = torch.utils.data.DataLoader(
    dataset_train,               #你的数据集
    batch_size=batch_size,  	 #训练时每批的大小
    shuffle=True,)				 #是否随机打乱

损失函数

1.遇到的问题:
笔记3中使用 nn.CrossEntropyLoss()作为损失函数,使用它时网络结构最后一层输出不需要softmax了,因为他整合了nn.logSoftmax()和nn.NLLLoss()。
并且它在计算每一batch时会自动求平均。所以计算这一epoch的损失函数需要乘上lable.size(0)或者batch_size,最后的总误差要除以len(datatrain)。

2.网络参数更新

optimizer.zero_grad()
loss.backward()
optimizer.step()

在每一个batch中,需要根据LOSS反向传播的梯度信息来更新网络参数,经常要用到以上三个函数。
首先清零梯度信息,这一batch参数的更新只跟这一batch的梯度有关,所以要先清除上一batch的梯度信息。
之后损失函数反向传播来计算梯度。
最后利用优化器(常用的有ADAM,SGD等)结合learning rate来更新网络参数


本文链接: http://www.dtmao.cc/news_show_250323.shtml

附件下载

相关教程

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?