summaryrefslogtreecommitdiff
path: root/learn_torch/basics/crossentropy_loss.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-07-24 20:25:48 +0800
committerzhang <zch921005@126.com>2022-07-24 20:25:48 +0800
commit92d3bc06bad13095df6515111bba45e73f701018 (patch)
tree5730478fe92b39f3b909843546291d0eced774d0 /learn_torch/basics/crossentropy_loss.py
parente9945ee44d8c46d93d50f023f49e79f3ba532583 (diff)
wordpiece
Diffstat (limited to 'learn_torch/basics/crossentropy_loss.py')
-rw-r--r--learn_torch/basics/crossentropy_loss.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/learn_torch/basics/crossentropy_loss.py b/learn_torch/basics/crossentropy_loss.py
new file mode 100644
index 0000000..67ccfa5
--- /dev/null
+++ b/learn_torch/basics/crossentropy_loss.py
@@ -0,0 +1,10 @@
+
+from torch import nn
+import torch
+
+if __name__ == '__main__':
+ input = torch.randn(3, 5)
+ target = torch.empty(3, dtype=torch.long).random_(5)
+ loss = nn.CrossEntropyLoss()
+ output = loss(input, target)
+ print(output) \ No newline at end of file