cross entropy loss = log softmax + nll loss
2022-06-26 05:39:00 【wujpbb7】
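The identity in the title follows directly from the definitions. For a logit vector z over C classes with target class y, LogSoftmax, NLLLoss and the cross entropy compute:

```latex
\begin{aligned}
\log\operatorname{softmax}(z)_i &= z_i - \log\sum_{j=1}^{C} e^{z_j},\\
\operatorname{NLL}(\ell, y) &= -\ell_y,\\
\operatorname{CE}(z, y) &= -\log\frac{e^{z_y}}{\sum_{j=1}^{C} e^{z_j}}
  = \operatorname{NLL}\bigl(\log\operatorname{softmax}(z),\, y\bigr).
\end{aligned}
```

With the default reduction='mean' and no class weights, both CrossEntropyLoss and NLLLoss average this per-sample value over the batch, so the two calculations return the same number. Temperature scaling only replaces z with z / t before this computation.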
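The code below computes the cross entropy loss twice, once with CrossEntropyLoss directly and once as LogSoftmax followed by NLLLoss, and then repeats both after temperature scaling: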
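import torch
# Random logits for 3 samples and 4 classes (values differ on every run)
logits = torch.randn(3, 4, requires_grad=True)
# Target class index for each sample
labels = torch.tensor([1, 0, 2], dtype=torch.long)
print('logits={}, labels={}'.format(logits, labels))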
# Calculate the cross entropy directly (cross entropy loss)
def calc_ce_loss1(logits, labels):
    ce_loss = torch.nn.CrossEntropyLoss()
    loss = ce_loss(logits, labels)
    return loss
# Compute cross entropy in two steps (cross entropy loss = log softmax + nll loss)
def calc_ce_loss2(logits, labels):
    log_softmax = torch.nn.LogSoftmax(dim=1)  # normalize over the class dimension
    nll_loss = torch.nn.NLLLoss()
    logits_ls = log_softmax(logits)
    loss = nll_loss(logits_ls, labels)
    return loss
loss1 = calc_ce_loss1(logits, labels)
print('loss1={}'.format(loss1))
loss2 = calc_ce_loss2(logits, labels)
print('loss2={}'.format(loss2))
# Temperature scaling: divide the logits by t (t < 1 sharpens the softmax, t > 1 flattens it)
temperature = 0.05
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))
temperature = 2
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))
The output is as follows:
logits=tensor([[-0.7441, -2.3802, -0.1708, 0.5020],
[ 0.3381, -0.3981, 2.2979, 0.6773],
[-0.5372, -0.4489, -0.0680, 0.4889]], requires_grad=True), labels=tensor([1, 0, 2])
loss1=2.399930000305176
loss2=2.399930000305176
t=0.05, loss1=35.99229431152344
t=0.05, loss2=35.99229431152344
t=2, loss1=1.8117588758468628
t=2, loss2=1.8117588758468628
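To see what the two PyTorch modules compute, here is a minimal pure-Python sketch (no torch needed) that implements log softmax and NLL loss by hand and applies the same temperature scaling. It assumes the defaults used above (reduction='mean', no class weights). The helper names log_softmax, nll_loss and cross_entropy are hypothetical, chosen only for this illustration, and the logits are the rounded values from the printed output.

```python
import math

# Logits and labels copied from the printed output above (rounded to 4 decimals,
# so results match the printed losses closely but not to every digit).
LOGITS = [
    [-0.7441, -2.3802, -0.1708, 0.5020],
    [0.3381, -0.3981, 2.2979, 0.6773],
    [-0.5372, -0.4489, -0.0680, 0.4889],
]
LABELS = [1, 0, 2]


def log_softmax(row):
    """Log softmax of one row: z_i - logsumexp(z), computed stably."""
    m = max(row)  # subtract the row max so exp() stays bounded
    lse = m + math.log(sum(math.exp(z - m) for z in row))
    return [z - lse for z in row]


def nll_loss(log_probs, labels):
    """Mean of -log_probs[i][labels[i]] (like NLLLoss, reduction='mean', no weights)."""
    picked = [-lp[y] for lp, y in zip(log_probs, labels)]
    return sum(picked) / len(picked)


def cross_entropy(logits, labels, temperature=1.0):
    """Cross entropy as log softmax followed by NLL, after dividing logits by temperature."""
    scaled = [[z / temperature for z in row] for row in logits]
    return nll_loss([log_softmax(row) for row in scaled], labels)


for t in (1.0, 0.05, 2.0):
    print(f"t={t}, loss={cross_entropy(LOGITS, LABELS, t):.4f}")
```

The three printed values should closely match loss1/loss2 for t=1, t=0.05 and t=2; any difference in the last digits comes from the 4-decimal rounding of the printed logits. Subtracting the row maximum before exponentiating is the usual log-sum-exp safeguard for large scaled logits, for example with very small temperatures.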