summaryrefslogtreecommitdiff
path: root/learn_torch/basics/crossentropy_loss.py
blob: 67ccfa54db5a70d104cba6ae97ed9d87c00dd46d (plain)
1
2
3
4
5
6
7
8
9
10

from torch import nn
import torch

def cross_entropy_demo(batch_size=3, num_classes=5, *, generator=None):
    """Compute ``nn.CrossEntropyLoss`` on random logits and random class targets.

    Args:
        batch_size: Number of samples (rows of the logits matrix).
        num_classes: Number of classes (columns of the logits matrix);
            targets are drawn uniformly from ``[0, num_classes)``.
        generator: Optional ``torch.Generator`` for reproducible sampling.

    Returns:
        A 0-dim tensor with the cross-entropy averaged over the batch
        (``CrossEntropyLoss`` defaults to ``reduction='mean'``).

    Raises:
        ValueError: If ``batch_size`` or ``num_classes`` is less than 1.
    """
    # An empty batch would make the mean reduction return NaN, and zero
    # classes leaves nothing to sample targets from.
    if batch_size < 1 or num_classes < 1:
        raise ValueError(
            f"batch_size and num_classes must be >= 1, "
            f"got batch_size={batch_size}, num_classes={num_classes}"
        )
    # Raw, unnormalised scores: CrossEntropyLoss applies log-softmax
    # internally, so no softmax is applied here.
    logits = torch.randn(batch_size, num_classes, generator=generator)
    # Class-index targets (int64), one per sample.
    target = torch.randint(num_classes, (batch_size,), generator=generator)
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn(logits, target)


if __name__ == '__main__':
    print(cross_entropy_demo())