[feat] Resume from ckpt #135
Merged
66 commits
5cd3c0f
docs: add transformers resume design spec
kevssim 91eeaeb
docs: refine transformers resume design spec
kevssim 6eebda8
docs: trim resume state fields
kevssim cdd9c1b
docs: add npu resume compatibility requirements
kevssim 1542492
chore: ignore local worktrees
kevssim 9883118
wip
kevssim d41a634
wip
kevssim 21f9918
wip
kevssim 1e59531
fix
kevssim 9bb3f39
wip
kevssim fdf1f71
fix
kevssim 6cf5160
wip
kevssim 144ffe6
Merge branch 'modelscope:main' into resume_from_ckpt
kevssim e21f870
lint
kevssim 3359209
Merge branch 'resume_from_ckpt' of https://github.com/kevssim/twinkle…
kevssim 70ebe50
wip
kevssim 483778d
wip
kevssim 039789b
wip
kevssim 54de1a4
wip
kevssim 920ab86
wip
kevssim ffd6304
lint
kevssim 582bd41
wip
kevssim 9cb6106
wip
kevssim c0cf72e
wip
kevssim 505a75c
wip
kevssim a222b5b
fix
kevssim 7499e00
wip
kevssim cd0b094
doc
kevssim abf2c2f
wip
kevssim 8bf7a6a
lint
kevssim 27e76c6
Merge remote-tracking branch 'origin/main' into resume_from_ckpt
kevssim 5d68910
Merge remote-tracking branch 'origin' into resume_from_ckpt
kevssim 9326e64
wip
kevssim 670f0c1
feat: add resume_from_checkpoint abstract method to TwinkleModel base
kevssim 784730c
feat(dataloader): add resume_from_checkpoint wrapping skip_consumed_s…
kevssim 3db38e9
feat(transformers): replace load_training_state/read_training_progres…
kevssim 94679d5
feat(megatron): add resume_from_checkpoint and save trainer_state.json
kevssim 832ce87
refactor(cookbook): use model.resume_from_checkpoint API
kevssim e3a3cd6
feat(types): replace training state request types with ResumeFromChec…
kevssim a3effab
feat(server): replace training state endpoints with /resume_from_chec…
kevssim 383336d
feat(client): replace training state methods with resume_from_checkpoint
kevssim 54a1db6
docs: update checkpoint/resume documentation for unified API
kevssim 597cbd9
fix: remove stale load_training_state references from __init__.py, mu…
kevssim c55ab9f
fix(transformers): pass correct file paths to _load_scaler_state and …
kevssim 8f76b7b
fix: guard rng_state.pt existence check, add Config extra=allow to Re…
kevssim 4ffa5c7
wip
kevssim 0b43055
wip
kevssim c8bc9ab
wip
kevssim 8c0399e
wip
kevssim 94af275
Merge remote-tracking branch 'origin/main' into resume_from_ckpt
kevssim 10b4a20
refactor: delete resume_utils.py, inline logic in fsdp2.py, update docs
kevssim 3df191a
wip
kevssim deeb648
Merge remote-tracking branch 'origin/main' into resume_from_ckpt
kevssim 7a657e8
wip
kevssim ae67122
fix
kevssim 5b15d67
lint
kevssim f0d36e2
remove
kevssim d0219df
wip
kevssim 9d5327d
update
kevssim 85b7cf8
doc
kevssim 9af73bc
fix
kevssim 482a451
fix doc
kevssim 2396419
fix
kevssim 9a6fbb9
lint
kevssim daa9202
update cookbook
kevssim a75f8b1
fix
kevssim
New file (113 lines):

```python
from pathlib import Path

from peft import LoraConfig
from tqdm import tqdm

import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor

logger = get_logger()

MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
DATASET_ID = 'ms://swift/self-cognition'
TEMPLATE_NAME = 'Qwen3_5Template'
MODEL_NAME = 'twinkle大模型'
MODEL_AUTHOR = 'ModelScope社区'
DP_SIZE = 2
TP_SIZE = 2
PP_SIZE = 2
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
LOG_INTERVAL = 5
EVAL_INTERVAL = 20
EVAL_SAMPLES = 100
TRAIN_SAMPLES = 1000

OUTPUT_DIR = './output/megatron_tp'
RESUME_FROM_CHECKPOINT = None
RESUME_ONLY_MODEL = False
IGNORE_DATA_SKIP = False
ADAPTER_NAME = 'default'

device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE, tp_size=TP_SIZE, pp_size=PP_SIZE)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)


def build_dataset(num_samples: int) -> Dataset:
    dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples)))
    dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID)
    dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR))
    dataset.encode()
    return dataset


def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader):
    model.save(
        checkpoint_name,
        output_dir=OUTPUT_DIR,
        adapter_name=ADAPTER_NAME,
        save_optimizer=True,
        consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
    )


def evaluate(model):
    dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE)
    for batch in tqdm(dataloader):
        model.forward_only(inputs=batch)
    return model.calculate_metric(is_training=False)


def train():
    dataset = build_dataset(TRAIN_SAMPLES)
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

    model = MegatronModel(model_id=MODEL_ID)

    lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')

    # Add a LoRA adapter named `default` to the model.
    # Comment this out to use full-parameter training.
    model.add_adapter_to_model(ADAPTER_NAME, lora_config)
    model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE)
    model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader))

    if RESUME_FROM_CHECKPOINT:
        checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
        kwargs = {}
        if ADAPTER_NAME:
            kwargs['adapter_name'] = ADAPTER_NAME
        progress = model.resume_from_checkpoint(
            str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
        if not IGNORE_DATA_SKIP:
            dataloader.resume_from_checkpoint(progress['consumed_train_samples'])

    logger.info(get_device_placement())
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')

    best_loss = float('inf')

    for step, batch in enumerate(dataloader):
        model.forward_backward(inputs=batch)
        model.clip_grad_and_step()
        if step % LOG_INTERVAL == 0:
            metric = model.calculate_metric(is_training=True)
            logger.info(f'Step {step} of {len(dataloader)}, metric: {metric}')
        if step > 0 and step % EVAL_INTERVAL == 0:
            metrics = evaluate(model)
            logger.info(f'Eval metric: {metrics}')
            metrics['step'] = step
            current_loss = float(metrics['loss'])
            if current_loss < best_loss:
                save_checkpoint(model, f'checkpoint-{step}', dataloader)
                best_loss = current_loss
    save_checkpoint(model, 'last-checkpoint', dataloader)


if __name__ == '__main__':
    train()
```
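Unless `IGNORE_DATA_SKIP` is set, the script fast-forwards the dataloader past data the previous run already consumed (`dataloader.resume_from_checkpoint(progress['consumed_train_samples'])`), so a resumed run does not retrain on the same samples. A minimal sketch of what such skipping can look like, assuming sample-level iteration over a sequential dataset; this is an illustration, not the actual twinkle DataLoader internals:

```python
from typing import Iterable, Iterator


def skip_consumed_samples(samples: Iterable, consumed: int, batch_size: int) -> Iterator:
    """Skip the whole batches already seen before the checkpoint.

    Illustrative sketch: rounds the consumed-sample count down to whole
    batches, drops that many samples, then yields the rest.
    """
    skip = (consumed // batch_size) * batch_size
    it = iter(samples)
    for _ in range(skip):
        next(it, None)  # discard already-consumed samples
    yield from it
```

For example, with `consumed=32` and `batch_size=16`, the first two batches are dropped and iteration resumes at sample 32. Note this simple approach assumes a deterministic sample order across runs; shuffled dataloaders additionally need to restore their RNG state to replay the same order.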