@@ -1075,7 +1075,7 @@ def _train_succeeded(logdir, kind, model, reftime):
10751075 if "labels.txt" not in train_files :
10761076 bokehlog .info ("ERROR: " + train_dir + os .sep + "labels.txt does not exist." )
10771077 return False
1078- validate_step_period = save_step_period = how_many_training_steps = None
1078+ validate_step_period = save_step_period = how_many_training_steps = start_checkpoint = None
10791079 with open (train_dir + ".log" ) as fid :
10801080 for line in fid :
10811081 if "validate_step_period = " in line :
@@ -1087,22 +1087,35 @@ def _train_succeeded(logdir, kind, model, reftime):
10871087 if "how_many_training_steps = " in line :
10881088 m = re .search ('how_many_training_steps = (\d+)' ,line )
10891089 how_many_training_steps = int (m .group (1 ))
1090+ if "start_checkpoint = " in line :
1091+ m = re .search ('start_checkpoint = .*ckpt-(\d+)' ,line )
1092+ if m : start_checkpoint = int (m .group (1 ))
10901093 if validate_step_period is None or save_step_period is None or how_many_training_steps is None :
10911094 bokehlog .info ("ERROR: " + train_dir + ".log should contain `validate_step_period`, `save_step_period`, and `how_many_training_steps`" )
10921095 return False
1096+ if save_step_period != validate_step_period :
1097+ bokehlog .info ("ERROR: `save_step_period` is not the same as `validate_step_period`" )
1098+ return False
10931099 if save_step_period > 0 :
1094- nckpts = how_many_training_steps // save_step_period + 1
1095- if len (list (filter (lambda x : x .startswith ("ckpt-" ), \
1096- train_files ))) != 2 * nckpts :
1097- bokehlog .info ("ERROR: " + train_dir + os .sep + " should contain " + \
1098- str (2 * nckpts )+ " ckpt-* files." )
1100+ start = save_step_period
1101+ if start_checkpoint : start += start_checkpoint
1102+ ckpts_config = set (range (start , how_many_training_steps + 1 , save_step_period ))
1103+ ckpts_index = set ()
1104+ ckpts_data = set ()
1105+ logits_saved = set ()
1106+ for train_file in train_files :
1107+ index = re .fullmatch ("ckpt-([0-9]+)\.index" , train_file )
1108+ if index : ckpts_index |= set ([int (index [1 ])])
1109+ data = re .match ("ckpt-([0-9]+)\.data" , train_file )
1110+ if data : ckpts_data |= set ([int (data [1 ])])
1111+ logits = re .fullmatch ("logits.validation.ckpt-([0-9]+)\.npz" , train_file )
1112+ if logits : logits_saved |= set ([int (logits [1 ])])
1113+ ckpts_saved = ckpts_index & ckpts_data
1114+ if ckpts_config - ckpts_saved :
1115+ bokehlog .info ("ERROR: " + train_dir + os .sep + " is missing ckpt files for steps " + str (ckpts_config - ckpts_saved ))
10991116 return False
1100- if validate_step_period > 0 :
1101- nevals = how_many_training_steps // validate_step_period
1102- if len (list (filter (lambda x : x .startswith ("logits.validation.ckpt-" ), \
1103- train_files ))) != nevals :
1104- bokehlog .info ("ERROR: " + train_dir + os .sep + " should contain " + str (nevals )+ \
1105- " logits.validation.ckpt-* files." )
1117+ if ckpts_config - logits_saved :
1118+ bokehlog .info ("ERROR: " + train_dir + os .sep + " is missing logit.validation files for steps " + str (ckpts_config - logits_saved ))
11061119 return False
11071120 return True
11081121
0 commit comments