Skip to content

Commit

Permalink
Add memory limit to distillation cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
armbues committed Jan 8, 2025
1 parent 7cf9dc2 commit aa7def0
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions sillm/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
parser.add_argument("--lora_dropout", default=0.0, type=int, help="Dropout to use for LoRA (default: 0.0)")
parser.add_argument("--lora_scale", default=10.0, type=float, help="Scale to use for LoRA (default: 10.0)")
parser.add_argument("--optimizer", type=str, default="adam", help="Optimizer type (default: adam)")
parser.add_argument("--loss_alpha", default=0.5, type=float, help="Distillation loss alpha (default: 0.5)")
parser.add_argument("--grad_checkpoint", default=False, action="store_true", help="Use gradient checkpointing")
parser.add_argument("--grad_accu_steps", type=int, default=1, help="Gradient accumulation steps (default: 1)")
parser.add_argument("--learning_rate", default=1e-5, type=float, help="Learning rate (default: 1e-5)")
Expand All @@ -32,7 +33,8 @@
parser.add_argument("--report_steps", default=10, type=int, help="Number of batch iterations per training report (default: 10)")
parser.add_argument("--eval_steps", default=100, type=int, help="Number of batch iterations per evaluation (default: 100)")
parser.add_argument("--validation_samples", default=40, type=int, help="Number of validation_samples (default: 40)")
parser.add_argument("--loss_alpha", default=0.5, type=float, help="Distillation loss alpha (default: 0.5)")
parser.add_argument("--memory_limit", default=None, type=float, help="Memory limit for training (default: None)")
parser.add_argument("--relax_memory_limit", default=False, action="store_true", help="Relax memory limit for training")
parser.add_argument("-v", "--verbose", default=1, action="count", help="Increase output verbosity")
args = parser.parse_args()

Expand All @@ -48,6 +50,10 @@
# Dump the parsed CLI arguments when verbosity is at DEBUG level (<= 10).
# NOTE(review): the if-bodies below were flattened to column 0 in this copy
# (a syntax error); indentation restored to the obvious structure.
if log_level <= 10:
    utils.log_arguments(args.__dict__)

# Set memory limit for training if one was given on the command line;
# --relax_memory_limit softens enforcement (semantics defined in utils).
if args.memory_limit is not None:
    utils.set_memory_limit(args.memory_limit, relaxed=args.relax_memory_limit)

# Set random seed for reproducibility; a negative seed leaves RNGs unseeded.
if args.seed >= 0:
    utils.seed(args.seed)
Expand Down Expand Up @@ -103,6 +109,6 @@
"validation_samples": args.validation_samples,
}
# Run distillation training on the target model. The training/validation/test
# splits and the assembled training_config are defined earlier in the script.
# NOTE(review): this copy contained a duplicated argument list (a diff render
# showing both old and new indentation of the same call) — collapsed to the
# single well-formed call.
target_model.train(dataset_training,
                   dataset_validation,
                   dataset_test,
                   **training_config)

0 comments on commit aa7def0

Please sign in to comment.