blob: 67ccfa54db5a70d104cba6ae97ed9d87c00dd46d (
plain)
1
2
3
4
5
6
7
8
9
10
|
from torch import nn
import torch
def cross_entropy_demo(batch_size=3, num_classes=5, *, seed=None):
    """Compute cross-entropy loss on random logits and random class targets.

    Draws a ``(batch_size, num_classes)`` tensor of standard-normal logits and
    a ``(batch_size,)`` tensor of integer class indices uniform in
    ``[0, num_classes)``. It then returns the mean cross-entropy loss that
    ``nn.CrossEntropyLoss`` computes between them.

    Args:
        batch_size: Number of samples (rows of logits / length of targets).
        num_classes: Number of classes (columns of logits).
        seed: Optional int. If given, a dedicated ``torch.Generator`` seeded
            with it is used, so results are reproducible and the global RNG
            is left untouched. If ``None``, the global default RNG is used.

    Returns:
        A 0-dim tensor holding the mean loss over the batch.

    Raises:
        ValueError: If ``batch_size`` or ``num_classes`` is less than 1. An
            empty batch would make the mean reduction return NaN, and zero
            classes leaves no valid target index.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # A private generator keeps seeded runs reproducible without reseeding
    # (and thus perturbing) the process-wide RNG.
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # Named `logits` rather than `input` to avoid shadowing the builtin.
    logits = torch.randn(batch_size, num_classes, generator=generator)
    # randint returns int64 (torch.long) by default, which is the dtype
    # CrossEntropyLoss expects for class-index targets.
    target = torch.randint(num_classes, (batch_size,), generator=generator)

    return nn.CrossEntropyLoss()(logits, target)


if __name__ == '__main__':
    print(cross_entropy_demo())
|