@@ -284,7 +284,6 @@ def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
         ray_broadcast_tensor_dict(data, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}")
 
     def loop(self) -> None:
-        # breakpoint()
         self.sync_model(0, 0)
         num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches
@@ -620,10 +619,10 @@ async def generate(self, input_ids, attention_mask, **kwargs):
         rollouts = await asyncio.gather(*tasks)
         rollouts = {
             k: (
-                torch.cat([r[k] for r in rollouts], dim=0)
+                torch.cat([r[k] for r in rollouts], dim=0).cpu()
                 if k not in ["gt_answer", "test_cases"]
                 else [r[k] for r in rollouts]
-            ).cpu()  # CUDA tensor is not serializable by ray
+            )  # CUDA tensor is not serializable by ray
             for k in rollouts[0].keys()
         }
         rollouts["consumer_global_step"] = self.consumer_global_step
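
The hunk above moves `.cpu()` inside the conditional so it only applies to the concatenated tensor branch; the list-valued keys ("gt_answer", "test_cases") have no `.cpu()` method, while the tensors still land in host memory because, as the inline comment says, CUDA tensors are not serializable by Ray. A minimal standalone sketch of the same idea, assuming a plain dict of rollout values (the helper name and the isinstance check are illustrative, not part of the patch):

import torch

def to_cpu_for_ray(rollouts: dict) -> dict:
    # Move tensor values to host memory before the dict crosses process
    # boundaries (Ray cannot serialize CUDA tensors); non-tensor values such
    # as the gt_answer/test_cases lists are passed through unchanged.
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in rollouts.items()}

The line added in the producer loop further below does the reverse after the rollout returns, moving every tensor value back onto `self.device`.
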
@@ -758,8 +757,8 @@ async def loop(self) -> None:
                     self.eval_mode = False
                     self.latest_eval_step = self.consumer_global_step
                 self.profiler.enter("rollout")
-                # breakpoint()
                 outputs = await self.rollout(**batch)
+                outputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}
                 self.profiler.exit("rollout")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
@@ -803,6 +802,8 @@ async def loop(self) -> None:
                     outputs.pop("gt_answer")
                 if "test_cases" in outputs:
                     outputs.pop("test_cases")
+                if "consumer_global_step" in outputs:
+                    outputs.pop("consumer_global_step")
                 self.profiler.exit("calculate_reward")
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
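
The last hunk extends the cleanup that runs before the outputs dict is logged and sent: "consumer_global_step" is a scalar bookkeeping value attached to the dict earlier (see `rollouts["consumer_global_step"]` above), so it is popped alongside the other non-tensor keys, leaving only tensor values for the `v.shape` print and the subsequent send. A compact sketch of that cleanup, assuming the same key names (the constant and helper are illustrative, not part of the patch):

NON_TENSOR_KEYS = ("gt_answer", "test_cases", "consumer_global_step")

def strip_non_tensor_keys(outputs: dict) -> dict:
    # Drop bookkeeping entries that are not torch tensors so the remaining
    # dict holds only tensors; pop(key, None) is a no-op for absent keys.
    for key in NON_TENSOR_KEYS:
        outputs.pop(key, None)
    return outputs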