summary | refs | log | tree | commit | diff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r-- train.py | 8
1 file changed, 3 insertions, 5 deletions
diff --git a/train.py b/train.py
index 6bad204..54ad33c 100644
--- a/train.py
+++ b/train.py
@@ -33,14 +33,12 @@ def parse_args():
parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
- parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
+ parser.add_argument('--max_steps', type=int, default=50, help='Maximum training steps')
parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
- parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+ parser.add_argument('--wandb_project', type=str, default='one-shot-em', help='W&B project name')
parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
- parser.add_argument('--enable_self_distill', action='store_true', help='Enable self-distillation MSE constraint')
- parser.add_argument('--self_distill_gamma', type=float, default=0.0, help='Self-distillation MSE weight γ')
- parser.add_argument('--seed', type=int, default=42, help='Random seed')
+ parser.add_argument('--seed', type=int, default=15, help='Random seed')
return parser.parse_args()
class FTDataset(Dataset):