diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 1a41020a..1fec56f5 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -168,6 +168,11 @@ async def schedule_rollouts( retry_max_delay_s = float(getattr(cfg.actor, "rollout_retry_max_delay_s", 30.0)) def is_trainer_finished() -> bool: + # Fast-LLM ignores `gradient_accumulation_passes` and overshoots `docs_per_step` + # by a few docs per step, so the sample-counting formula below fires several + # optimizer steps early. Use the explicit `training_finished` event instead. + if cfg.use_fast_llm: + return trainer_state.training_done return ( trainer_state.samples_processed is not None and trainer_state.samples_processed >= samples_target @@ -609,9 +614,18 @@ def _run(self, dataset: list[tuple[str, dict]]): # the user function must do next(...) to run each iteration yield - final_steps = calculate_train_steps(self.cfg.finetune, self.cfg.finetune.interrupt_train_steps) - samples_target = final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes - if self.trainer_state.samples_processed is not None and self.trainer_state.samples_processed >= samples_target: + # Mirror `is_trainer_finished` (above): use the explicit training_done + # event under Fast-LLM; fall back to sample counting for HF/DeepSpeed. + if self.cfg.use_fast_llm: + trainer_finished = self.trainer_state.training_done + else: + final_steps = calculate_train_steps(self.cfg.finetune, self.cfg.finetune.interrupt_train_steps) + samples_target = final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes + trainer_finished = ( + self.trainer_state.samples_processed is not None + and self.trainer_state.samples_processed >= samples_target + ) + if trainer_finished: logger.info("Trainer signalled completion; stopping actor loop") break