@@ -56,10 +56,8 @@ def tokenize_sft(
5656 template .messages = []
5757 for idx , mess in enumerate (messages ):
5858 if mess ["from" ] != template .roles [idx % 2 ]:
59- raise ValueError (
60- f"Message should iterate between user and assistant and starts with a \
61- line from the user. Got the following data:\n { messages } "
62- )
59+ raise ValueError (f"Message should iterate between user and assistant and starts with a \
60+ line from the user. Got the following data:\n { messages } " )
6361 template .append_message (mess ["from" ], mess ["content" ])
6462
6563 if len (template .messages ) % 2 != 0 :
@@ -245,10 +243,8 @@ def tokenize_rlhf(
245243
246244 for idx , mess in enumerate (context ):
247245 if mess ["from" ] != template .roles [idx % 2 ]:
248- raise ValueError (
249- f"Message should iterate between user and assistant and starts with a \
250- line from the user. Got the following data:\n { context } "
251- )
246+ raise ValueError (f"Message should iterate between user and assistant and starts with a \
247+ line from the user. Got the following data:\n { context } " )
252248 template .append_message (mess ["from" ], mess ["content" ])
253249
254250 if len (template .messages ) % 2 != 1 :
@@ -272,18 +268,14 @@ def tokenize_rlhf(
272268 rejected_continuation = data_point ["rejected" ]
273269 for round in range (len (chosen_continuation )):
274270 if chosen_continuation [round ]["from" ] != template .roles [(round + 1 ) % 2 ]:
275- raise ValueError (
276- f"Message should iterate between user and assistant and starts with a \
277- line from the user. Got the following data:\n { chosen_continuation } "
278- )
271+ raise ValueError (f"Message should iterate between user and assistant and starts with a \
272+ line from the user. Got the following data:\n { chosen_continuation } " )
279273 chosen .append_message (chosen_continuation [round ]["from" ], chosen_continuation [round ]["content" ])
280274
281275 for round in range (len (rejected_continuation )):
282276 if rejected_continuation [round ]["from" ] != template .roles [(round + 1 ) % 2 ]:
283- raise ValueError (
284- f"Message should iterate between user and assistant and starts with a \
285- line from the user. Got the following data:\n { rejected_continuation } "
286- )
277+ raise ValueError (f"Message should iterate between user and assistant and starts with a \
278+ line from the user. Got the following data:\n { rejected_continuation } " )
287279 rejected .append_message (rejected_continuation [round ]["from" ], rejected_continuation [round ]["content" ])
288280
289281 (
@@ -296,14 +288,14 @@ def tokenize_rlhf(
296288 ) = (None , None , None , None , None , None )
297289
298290 chosen_data_packed = apply_rlhf_data_format (chosen , tokenizer )
299- ( chosen_input_ids , chosen_loss_mask , chosen_label_decode ) = (
291+ chosen_input_ids , chosen_loss_mask , chosen_label_decode = (
300292 chosen_data_packed ["input_ids" ],
301293 chosen_data_packed ["loss_mask" ],
302294 chosen_data_packed ["label_decode" ],
303295 )
304296
305297 rejected_data_packed = apply_rlhf_data_format (rejected , tokenizer )
306- ( rejected_input_ids , rejected_loss_mask , rejected_label_decode ) = (
298+ rejected_input_ids , rejected_loss_mask , rejected_label_decode = (
307299 rejected_data_packed ["input_ids" ],
308300 rejected_data_packed ["loss_mask" ],
309301 rejected_data_packed ["label_decode" ],
0 commit comments