diff options
| author | zhang <zch921005@126.com> | 2022-07-24 20:25:48 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-07-24 20:25:48 +0800 |
| commit | 92d3bc06bad13095df6515111bba45e73f701018 (patch) | |
| tree | 5730478fe92b39f3b909843546291d0eced774d0 /learn_torch/basics | |
| parent | e9945ee44d8c46d93d50f023f49e79f3ba532583 (diff) | |
wordpiece
Diffstat (limited to 'learn_torch/basics')
| -rw-r--r-- | learn_torch/basics/crossentropy_loss.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/learn_torch/basics/crossentropy_loss.py b/learn_torch/basics/crossentropy_loss.py new file mode 100644 index 0000000..67ccfa5 --- /dev/null +++ b/learn_torch/basics/crossentropy_loss.py @@ -0,0 +1,10 @@ + +from torch import nn +import torch + +if __name__ == '__main__': + input = torch.randn(3, 5) + target = torch.empty(3, dtype=torch.long).random_(5) + loss = nn.CrossEntropyLoss() + output = loss(input, target) + print(output)
\ No newline at end of file |
