Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions pipelinerl/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ async def schedule_rollouts(
retry_max_delay_s = float(getattr(cfg.actor, "rollout_retry_max_delay_s", 30.0))

def is_trainer_finished() -> bool:
# Fast-LLM ignores `gradient_accumulation_passes` and overshoots `docs_per_step`
# by a few docs per step, so the sample-counting formula below fires several
# optimizer steps early. Use the explicit `training_finished` event instead.
if cfg.use_fast_llm:
return trainer_state.training_done
return (
trainer_state.samples_processed is not None
and trainer_state.samples_processed >= samples_target
Expand Down Expand Up @@ -609,9 +614,18 @@ def _run(self, dataset: list[tuple[str, dict]]):
# the user function must do next(...) to run each iteration
yield

final_steps = calculate_train_steps(self.cfg.finetune, self.cfg.finetune.interrupt_train_steps)
samples_target = final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes
if self.trainer_state.samples_processed is not None and self.trainer_state.samples_processed >= samples_target:
# Mirror `is_trainer_finished` (above): use the explicit training_done
# event under Fast-LLM; fall back to sample counting for HF/DeepSpeed.
if self.cfg.use_fast_llm:
trainer_finished = self.trainer_state.training_done
else:
final_steps = calculate_train_steps(self.cfg.finetune, self.cfg.finetune.interrupt_train_steps)
samples_target = final_steps * self.cfg.finetune.train_batch_size * self.cfg.finetune.gradient_accumulation_passes
trainer_finished = (
self.trainer_state.samples_processed is not None
and self.trainer_state.samples_processed >= samples_target
)
if trainer_finished:
logger.info("Trainer signalled completion; stopping actor loop")
break

Expand Down