
cross entropy loss = log softmax + nll loss

2022-06-26 05:39:00 wujpbb7

In PyTorch, CrossEntropyLoss applies LogSoftmax followed by NLLLoss, so computing the two steps separately must give the same result. The code below verifies this, with and without temperature scaling:

import torch

logits = torch.randn(3,4,requires_grad=True)
labels = torch.LongTensor([1,0,2])
print('logits={}, labels={}'.format(logits,labels))

# Compute the cross entropy directly (CrossEntropyLoss)
def calc_ce_loss1(logits, labels):
    ce_loss = torch.nn.CrossEntropyLoss()
    loss = ce_loss(logits, labels)
    return loss

# Compute the cross entropy in two steps (cross entropy loss = log softmax + nll loss)
def calc_ce_loss2(logits, labels):
    log_softmax = torch.nn.LogSoftmax(dim=1)
    nll_loss = torch.nn.NLLLoss()
    logits_ls = log_softmax(logits)
    loss = nll_loss(logits_ls, labels)
    return loss

loss1 = calc_ce_loss1(logits, labels)
print('loss1={}'.format(loss1))
loss2 = calc_ce_loss2(logits, labels)
print('loss2={}'.format(loss2))

# Temperature scaling: t < 1 sharpens the softmax distribution, t > 1 smooths it
temperature = 0.05
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))

temperature = 2
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))

The output is as follows:

logits=tensor([[-0.7441, -2.3802, -0.1708,  0.5020],
        [ 0.3381, -0.3981,  2.2979,  0.6773],
        [-0.5372, -0.4489, -0.0680,  0.4889]], requires_grad=True), labels=tensor([1, 0, 2])
loss1=2.399930000305176
loss2=2.399930000305176
t=0.05, loss1=35.99229431152344
t=0.05, loss2=35.99229431152344
t=2, loss1=1.8117588758468628
t=2, loss2=1.8117588758468628
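
For completeness, the same decomposition can also be written out by hand, without the nn modules. The sketch below is not part of the original post (the variable names are my own); it computes log softmax via logsumexp, picks out the per-sample negative log likelihoods directly, and its result should match torch.nn.functional.cross_entropy up to floating-point precision.

import torch

torch.manual_seed(0)
logits = torch.randn(3, 4)
labels = torch.LongTensor([1, 0, 2])

# log softmax: subtract the log of the normalizer from each logit
log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
# nll loss: negative log probability of the true class, averaged over the batch
manual_loss = -log_probs[torch.arange(labels.size(0)), labels].mean()

builtin_loss = torch.nn.functional.cross_entropy(logits, labels)
print('manual={}, builtin={}'.format(manual_loss.item(), builtin_loss.item()))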


Copyright notice
This article was written by [wujpbb7]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/177/202206260532169998.html