一帧数据 tws mirror 远程桌面登陆 Morecoin idea 常用快捷键 Netty k8s serialization tfs timer process 建筑资质 八大员 找公司做网站 android项目实例 axure组件库下载 oracle行转列函数 java不定长数组 python循环 河南普通话报名入口 python生成多个随机数 python多线程编程 input函数python python支持中文 java语言基础教程 java判断语句 java写入txt java删除文件 java获取时间 网站数据分析工具 键盘模拟器 2k14生涯模式修改器 凯恩与林奇2下载 显示器面板类型 掌门一对一下载 c4dr19 mix2s拆机 pro换肤 c语言从入门到精通 无主之地2联机超时
当前位置: 首页 > 学习教程  > 编程语言

二分类和多分类交叉熵函数区别详解

2020/12/28 18:47:10 文章标签:

二分类和多分类交叉熵函数区别详解 写在前面 查了下百度,交叉熵,是度量两个分布间差异的概念。而在我们神经网络中,两个分布也就是y的真实值分布和预测值分布。当两个分布越接近时,其交叉熵值也就越小。 根据上面知识&#xff…

二分类和多分类交叉熵函数区别详解

写在前面

查了下百度,交叉熵,是度量两个分布间差异的概念。而在我们神经网络中,两个分布也就是y的真实值分布和预测值分布。当两个分布越接近时,其交叉熵值也就越小。

根据上面知识,也就转化为我们需要解决让预测值和真实值尽可能接近的问题,而这正与概率论数理统计中的最大似然分布一脉相承,进而目标转化为确定值的分布和求解最大似然估计问题。

二分类问题

表示分类任务中有两个类别,比如我们想判断一张图片是不是猫。也就是说,训练一个分类器,输入一张图片,用特征向量x表示,输出是不是猫用y=0或1表示,其中1表示是,0表示不是。

这样的问题,我们完全可以用0-1分布来进行表示:

y i y_i yi 1 − y i 1-y_i 1yi
y i ^ \hat{y_i} yi^ 1 − y i ^ 1-\hat{y_i} 1yi^

注:其中yi为真实值, y i ^ \hat{y_i} yi^为预测值,且 y i y_i yi的值为0或1

此时求解最大似然估计过程如下:
L ( y i ^ ) = Π i = 1 n y i ^ y i ( 1 − y i ^ ) 1 − y i L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_i}^{y_i}(1-\hat{y_i})^{1-y_i} L(yi^)=Πi=1nyi^yi(1yi^)1yi
两边同时取对数
l o g ( L ( y i ^ ) ) = ∑ i = 1 n ( y i l o g ( y i ^ ) + ( 1 − y i ) l o g ( 1 − y i ^ ) ) log(L(\hat{y_i}))=\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i})) log(L(yi^))=i=1n(yilog(yi^)+(1yi)log(1yi^))
最大似然估计要求数越大越好,而损失函数要求越小越好,因而损失函数在前面加上负号,因而也得到了二分类问题使用的交叉熵损失函数
L o s s = − ∑ i = 1 n ( y i l o g ( y i ^ ) + ( 1 − y i ) l o g ( 1 − y i ^ ) ) Loss=-\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i})) Loss=i=1n(yilog(yi^)+(1yi)log(1yi^))

多分类问题

表示分类任务有多个类别,如对一堆水果分类,它们可能是橘子、苹果、梨等,每个样本有且只有一个标签。

这种情况与二分类类似,只是可能的情况增多了,可以描述为一个离散分布

y 1 y_{1} y1 y 2 y_2 y2 y k y_k yk
y 1 ^ \hat{y_1} y1^ y 2 ^ \hat{y_2} y2^ y k ^ \hat{y_k} yk^

注: y 1 、 y 2 . . . y k y_1、y_2...y_k y1y2...yk为真实值,其中有且只有一个为1,其余为0。(采用one-hot编码)

此时求解最大似然函数过程如下:
L ( y i ^ ) = Π i = 1 n ( y ( i , 1 ) ^ y ( i , 1 ) y ( i , 2 ) ^ y ( i , 2 ) . . . y ( i , n ) ^ y ( i , n ) ) L(\hat{y_i})=\Pi_{i=1}^{n}(\hat{y_{(i,1)}}^{y_{(i,1)}}\hat{y_{(i,2)}}^{y_{(i,2)}}...\hat{y_{(i,n)}}^{y_{(i,n)}}) L(yi^)=Πi=1n(y(i,1)^y(i,1)y(i,2)^y(i,2)...y(i,n)^y(i,n))
因为真实值只有一个为1,其余为0,因而只有1项值非零,可化简为:
L ( y i ^ ) = Π i = 1 n y ( i , m ) ^ y ( i , m ) L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_{(i,m)}}^{y_{(i,m)}} L(yi^)=Πi=1ny(i,m)^y(i,m)
注: y ( i , m ) ^ \hat{y_{(i,m)}} y(i,m)^表示含义为第i个样本,属于第m个类别(m值会随样本的变化动态改变)

两边同时取对数:
l o g ( L ( y i ^ ) ) = ∑ i = 1 n y ( i , m ) l o g ( y i , m ^ ) log(L(\hat{y_i}))=\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}}) log(L(yi^))=i=1ny(i,m)log(yi,m^)
与二元分类同理,此时多分类的交叉熵损失函数即为:
L o s s = − ∑ i = 1 n y ( i , m ) l o g ( y i , m ^ ) Loss=-\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}}) Loss=i=1ny(i,m)log(yi,m^)

参考文献

[1] https://www.bilibili.com/video/BV1a5411W7Dn?t=47
[2] https://juejin.cn/post/6844903630479294477


本文链接: http://www.dtmao.cc/news_show_550113.shtml

附件下载

相关教程

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?