Skip to content

Commit 8cf22c4

Browse files
committed
fix status bar when restoring from checkpoint
1 parent 8364819 commit 8cf22c4

1 file changed

Lines changed: 25 additions & 12 deletions

File tree

src/gui/controller.py

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1075,7 +1075,7 @@ def _train_succeeded(logdir, kind, model, reftime):
10751075
if "labels.txt" not in train_files:
10761076
bokehlog.info("ERROR: "+train_dir+os.sep+"labels.txt does not exist.")
10771077
return False
1078-
validate_step_period = save_step_period = how_many_training_steps = None
1078+
validate_step_period = save_step_period = how_many_training_steps = start_checkpoint = None
10791079
with open(train_dir+".log") as fid:
10801080
for line in fid:
10811081
if "validate_step_period = " in line:
@@ -1087,22 +1087,35 @@ def _train_succeeded(logdir, kind, model, reftime):
10871087
if "how_many_training_steps = " in line:
10881088
m=re.search('how_many_training_steps = (\d+)',line)
10891089
how_many_training_steps = int(m.group(1))
1090+
if "start_checkpoint = " in line:
1091+
m=re.search('start_checkpoint = .*ckpt-(\d+)',line)
1092+
if m: start_checkpoint = int(m.group(1))
10901093
if validate_step_period is None or save_step_period is None or how_many_training_steps is None:
10911094
bokehlog.info("ERROR: "+train_dir+".log should contain `validate_step_period`, `save_step_period`, and `how_many_training_steps`")
10921095
return False
1096+
if save_step_period != validate_step_period:
1097+
bokehlog.info("ERROR: `save_step_period` is not the same as `validate_step_period`")
1098+
return False
10931099
if save_step_period>0:
1094-
nckpts = how_many_training_steps // save_step_period + 1
1095-
if len(list(filter(lambda x: x.startswith("ckpt-"), \
1096-
train_files))) != 2*nckpts:
1097-
bokehlog.info("ERROR: "+train_dir+os.sep+" should contain "+ \
1098-
str(2*nckpts)+" ckpt-* files.")
1100+
start = save_step_period
1101+
if start_checkpoint: start += start_checkpoint
1102+
ckpts_config = set(range(start, how_many_training_steps+1, save_step_period))
1103+
ckpts_index = set()
1104+
ckpts_data = set()
1105+
logits_saved = set()
1106+
for train_file in train_files:
1107+
index = re.fullmatch("ckpt-([0-9]+)\.index", train_file)
1108+
if index: ckpts_index |= set([int(index[1])])
1109+
data = re.match("ckpt-([0-9]+)\.data", train_file)
1110+
if data: ckpts_data |= set([int(data[1])])
1111+
logits = re.fullmatch("logits.validation.ckpt-([0-9]+)\.npz", train_file)
1112+
if logits: logits_saved |= set([int(logits[1])])
1113+
ckpts_saved = ckpts_index & ckpts_data
1114+
if ckpts_config - ckpts_saved:
1115+
bokehlog.info("ERROR: "+train_dir+os.sep+" is missing ckpt files for steps "+str(ckpts_config-ckpts_saved))
10991116
return False
1100-
if validate_step_period>0:
1101-
nevals = how_many_training_steps // validate_step_period
1102-
if len(list(filter(lambda x: x.startswith("logits.validation.ckpt-"), \
1103-
train_files))) != nevals:
1104-
bokehlog.info("ERROR: "+train_dir+os.sep+" should contain "+str(nevals)+\
1105-
" logits.validation.ckpt-* files.")
1117+
if ckpts_config - logits_saved:
1118+
bokehlog.info("ERROR: "+train_dir+os.sep+" is missing logit.validation files for steps "+str(ckpts_config-logits_saved))
11061119
return False
11071120
return True
11081121

0 commit comments

Comments
 (0)