summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/train.py b/train.py
index 9506d76..6bad204 100644
--- a/train.py
+++ b/train.py
@@ -137,7 +137,8 @@ def main():
with torch.no_grad():
gen_ids = accelerator.unwrap_model(model).generate(**enc,
- max_new_tokens=512, do_sample=True,
+ max_new_tokens=512,
+ do_sample=True,
top_p=0.95,
temperature=args.sample_temp,
synced_gpus=True,