cross entropy loss = log softmax + nll loss
2022-06-26 05:32:00 【wujpbb7】
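Written out for a single sample, the identity in the title looks like this. The symbols are assumptions introduced here for illustration and do not appear in the original post: z is one row of logits, y is its target class index, and \ell is the row of log-probabilities.

```latex
\[
\operatorname{log\_softmax}(z)_j = z_j - \log\sum_{k} e^{z_k},
\qquad
\operatorname{NLL}(\ell, y) = -\ell_y
\]
\[
\operatorname{CE}(z, y)
  = -\log\frac{e^{z_y}}{\sum_{k} e^{z_k}}
  = \log\sum_{k} e^{z_k} - z_y
  = \operatorname{NLL}\bigl(\operatorname{log\_softmax}(z),\, y\bigr)
\]
```

With the default `reduction='mean'`, both `CrossEntropyLoss` and `NLLLoss` average this per-sample value over the batch, which is why the two code paths should agree.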
The code is as follows:
import torch

logits = torch.randn(3,4,requires_grad=True)
labels = torch.LongTensor([1,0,2])
print('logits={}, labels={}'.format(logits,labels))

# Compute the cross entropy loss directly
def calc_ce_loss1(logits, labels):
    ce_loss = torch.nn.CrossEntropyLoss()
    loss = ce_loss(logits, labels)
    return loss

# Compute the cross entropy loss in two steps (cross entropy loss = log softmax + nll loss)
def calc_ce_loss2(logits, labels):
    log_softmax = torch.nn.LogSoftmax(dim=1)
    nll_loss = torch.nn.NLLLoss()
    logits_ls = log_softmax(logits)
    loss = nll_loss(logits_ls, labels)
    return loss

loss1 = calc_ce_loss1(logits, labels)
print('loss1={}'.format(loss1))
loss2 = calc_ce_loss2(logits, labels)
print('loss2={}'.format(loss2))

# Add a temperature
temperature = 0.05
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))

temperature = 2
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))

The output is as follows:
logits=tensor([[-0.7441, -2.3802, -0.1708, 0.5020],
[ 0.3381, -0.3981, 2.2979, 0.6773],
[-0.5372, -0.4489, -0.0680, 0.4889]], requires_grad=True), labels=tensor([1, 0, 2])
loss1=2.399930000305176
loss2=2.399930000305176
t=0.05, loss1=35.99229431152344
t=0.05, loss2=35.99229431152344
t=2, loss1=1.8117588758468628
t=2, loss2=1.8117588758468628
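
The printed losses can be cross-checked without PyTorch. The following is a minimal sketch in plain Python, not the author's code: the names `LOGITS`, `LABELS`, `log_softmax`, `nll_loss` and `ce_loss` are hypothetical helpers written for this check, and the logits are the 4-decimal values copied from the printout, so the results should only agree with the losses above up to rounding error.

```python
import math

# Logits and labels copied from the printed output above (rounded to 4 decimals).
LOGITS = [
    [-0.7441, -2.3802, -0.1708, 0.5020],
    [0.3381, -0.3981, 2.2979, 0.6773],
    [-0.5372, -0.4489, -0.0680, 0.4889],
]
LABELS = [1, 0, 2]


def log_softmax(row):
    """Return log(softmax(row)), shifting by the max for numerical stability."""
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
    return [z - log_sum_exp for z in row]


def nll_loss(log_probs, labels):
    """Mean of the negated log-probability at each row's target index."""
    picked = [-row[y] for row, y in zip(log_probs, labels)]
    return sum(picked) / len(picked)


def ce_loss(logits, labels, temperature=1.0):
    """Cross entropy as log softmax followed by NLL, applied to logits / temperature."""
    scaled = [[z / temperature for z in row] for row in logits]
    return nll_loss([log_softmax(row) for row in scaled], labels)


for t in (1.0, 0.05, 2.0):
    print('t={}, loss={:.4f}'.format(t, ce_loss(LOGITS, LABELS, t)))
```

In this particular random sample none of the target classes has the largest logit in its row, so a small temperature (0.05) widens the gap to the top logit and the loss grows, while a large temperature (2) flattens the distribution and pulls the loss toward log(4), the value for a uniform prediction over 4 classes.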