搜索
您的当前位置:首页深入理解torch.functional.cross_entropy或F.cross_entropy的原理

深入理解torch.functional.cross_entropy或F.cross_entropy的原理

来源:世旅网

参考的回答:

可知以下两个调用等价:

import torch
import torch.nn.functional as F
x = torch.FloatTensor([[1.,0.,0.],
                       [0.,1.,0.],
                       [0.,0.,1.]])
y = torch.LongTensor([0,1,2])

print(torch.nn.functional.cross_entropy(x, y))  # tensor(0.5514)
print(F.nll_loss(F.log_softmax(x, 1), y))  # tensor(0.5514)

并且,以下两个调用是等价的:

print(F.softmax(x, 1).log())
print(F.log_softmax(x, 1))

由此可知,torch的CE loss会先沿着prediciton score矩阵的每一行计算softmax操作,再全部计算log。最后再基于negative likelyhood loss去计算最终的loss。。

因篇幅问题不能全部显示,请点此查看更多更全内容

Top