Zookeeper 微信商家收款 哨兵模式 multithreading datetime delphi tcp colors uiviewcontroller menu pygame swift2 vue前端 后台管理网页模板 十大erp系统 找公司做网站 java并发编程视频 android实战项目 electron安装 jq获取最后一个子元素 mysql在线测试 mac安装hadoop db2从入门到精通 grep不是内部命令 linux获取当前时间 重置hosts python中文 python基本语法 配置python环境 java框架 java时间函数 java当前时间 java平台 java入门代码 java文件读取 linux系统教程 microkms 东方头条邀请码 din字体下载 python游戏代码
当前位置: 首页 > 学习教程  > 编程语言

nll loss 和CrossEntropyLoss 的一些区别

2020/12/28 19:06:23 文章标签:

NLL Loss 在传入这个loss前,需要先对输入进行一次 log_softmax 的变换, 例子如下: import torch import torch.nn as nn import torch.nn.functional as Fseed1 m nn.LogSoftmax(dim1) loss nn.NLLLoss() # input is of size N x C 3 x 5 input to…

NLL Loss

在传入这个loss前,需要先对输入进行一次 log_softmax 的变换, 例子如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

seed=1
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
print(input)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
print(output)
output.backward()

得到

tensor([[ 0.7865, -0.3559, -1.0439,  1.1853,  0.9505],
        [-0.7899,  0.4999, -1.6801, -0.4715, -0.4272],
        [ 1.4955, -2.1098, -2.5503, -0.1414, -0.1313]], requires_grad=True)
tensor(2.2048, grad_fn=<NllLossBackward>)

计算细节参考豪哥的博客

CrossEntropy Loss

和NLL Loss相比,省去了log_softmax 的变换, 例子如下:

target2=torch.tensor([1,0,4], dtype=torch.long)
print(target2)
CEloss=nn.CrossEntropyLoss()
output2 = CEloss(input,target2)
print(output2)

输出:

tensor([1, 0, 4])
tensor(2.2048, grad_fn=<NllLossBackward>)

值得注意的地方

  1. CE loss 的target需要指定为long
  2. CE loss 计算导数的时候,本质上和NLLLoss是一样的,看输出里的grad_fn就知道了

本文链接: http://www.dtmao.cc/news_show_550159.shtml

附件下载

相关教程

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?