diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py index 87ef3c52db..dd3c96e358 100644 --- a/litgpt/finetune/adapter.py +++ b/litgpt/finetune/adapter.py @@ -443,6 +443,8 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing + if max_steps <= warmup_steps: + raise ValueError(f"max_steps ({max_steps}) must be greater than warmup_steps ({warmup_steps})") scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py index be7f72e376..7193034c59 100644 --- a/litgpt/finetune/adapter_v2.py +++ b/litgpt/finetune/adapter_v2.py @@ -466,6 +466,8 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing + if max_steps <= warmup_steps: + raise ValueError(f"max_steps ({max_steps}) must be greater than warmup_steps ({warmup_steps})") scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index e13fa3f90a..fd0e61f603 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -414,6 +414,8 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing + if max_steps <= warmup_steps: + raise ValueError(f"max_steps ({max_steps}) must be greater than warmup_steps ({warmup_steps})") scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index fbecf5a815..2180322085 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -493,6 +493,8 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing + if max_steps <= warmup_steps: + raise ValueError(f"max_steps ({max_steps}) must be greater than warmup_steps ({warmup_steps})") scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) diff --git a/litgpt/finetune/lora_legacy.py b/litgpt/finetune/lora_legacy.py index 96e9c56eb0..da2eb69597 100644 --- a/litgpt/finetune/lora_legacy.py +++ b/litgpt/finetune/lora_legacy.py @@ -474,6 +474,8 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing + if max_steps <= warmup_steps: + raise ValueError(f"max_steps ({max_steps}) must be greater than warmup_steps ({warmup_steps})") scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])