33# plot accuracy across hyperparameter values
44
55# e.g. compare <logdirs-prefix>
6- # --logdirs_prefix =trained- \
6+ # --logdirs_filter =trained- \
77# --loss=exclusive \
88# --overlapped_prefix=not_
99
@@ -33,8 +33,9 @@ def main():
3333 for key in sorted (flags .keys ()):
3434 print ('%s = %s' % (key , flags [key ]))
3535
36- logdirs_prefix = FLAGS .logdirs_prefix
37- basename , dirname = os .path .split (logdirs_prefix )
36+ logdirs_filter = FLAGS .logdirs_filter
37+ logdirs_dirname , logdirs_basename = os .path .split (logdirs_filter )
38+ indepvar , * filters = logdirs_basename .split ('_' )
3839
3940 same_time = False
4041 outlier_criteria = 50
@@ -51,20 +52,24 @@ def main():
5152 nlayers = {}
5253 hyperparameters = {}
5354
54- logdirs = list (filter (lambda x : x .startswith (dirname + '-' ) and \
55- os .path .isdir (os .path .join (basename ,x )), os .listdir (basename )))
55+ def filter_logdirs (logdir ):
56+ params = logdir .split ('_' )
57+ return all ([f in params for f in filters ]) and \
58+ any ([p .startswith (indepvar ) for p in params ])
59+
60+ logdirs = list (filter (filter_logdirs , os .listdir (logdirs_dirname )))
5661
5762 for logdir in logdirs :
5863 print (logdir )
59- hyperparameters [logdir ] = set (logdir .split ('-' )[ - 1 ]. split ( ' _' ))
64+ hyperparameters [logdir ] = set (logdir .split ('_' ))
6065 _ , _ , train_time [logdir ], _ , \
6166 _ , _ , validation_precision [logdir ], validation_recall [logdir ], \
6267 validation_time [logdir ], validation_step [logdir ], \
6368 _ , _ , _ , _ , \
6469 labels_touse [logdir ], _ , \
6570 nparameters_total [logdir ], nparameters_finallayer [logdir ], \
6671 batch_size [logdir ], nlayers [logdir ] = \
67- read_logs (os .path .join (basename ,logdir ))
72+ read_logs (os .path .join (logdirs_dirname ,logdir ))
6873 if len (set ([tuple (x ) for x in labels_touse [logdir ].values ()]))> 1 :
6974 print ('WARNING: not all labels_touse are the same' )
7075 if len (set (nparameters_total [logdir ].values ()))> 1 :
@@ -89,15 +94,20 @@ def main():
8994
9095 commonparameters = reduce (lambda x ,y : x & y , hyperparameters .values ())
9196 differentparameters = {x :',' .join (natsorted (list (hyperparameters [x ]- commonparameters ))) \
92- for x in natsorted (logdirs )}
97+ for x in logdirs }
98+
9399
100+ def sortby_indepvar (logdir ):
101+ params = logdir .split ('_' )
102+ iindepvar = next (i for i ,x in enumerate (params ) if x .startswith (indepvar ))
103+ return str (params [iindepvar ]) + str (params [:iindepvar ]) + str (params [iindepvar + 1 :])
94104
95105 fig = plt .figure (figsize = (8 ,10 * 2 / 3 ))
96106
97107 ax = fig .add_subplot (2 ,2 ,1 )
98108
99109 precisions_mean , recalls_mean = [], []
100- for (ilogdir ,logdir ) in enumerate (natsorted (logdirs )):
110+ for (ilogdir ,logdir ) in enumerate (natsorted (logdirs , key = sortby_indepvar )):
101111 color = cm .viridis (ilogdir / max (1 ,len (validation_recall )- 1 ))
102112 precisions_all , recalls_all = [], []
103113 for model in validation_recall [logdir ].keys ():
@@ -111,7 +121,7 @@ def main():
111121
112122 ax = fig .add_subplot (2 ,2 ,2 )
113123 bottom = 100
114- for (iexpt ,expt ) in enumerate (natsorted (validation_recall .keys ())):
124+ for (iexpt ,expt ) in enumerate (natsorted (validation_recall .keys (), key = sortby_indepvar )):
115125 color = cm .viridis (iexpt / max (1 ,len (validation_recall )- 1 ))
116126 validation_recall_average = np .zeros (len (next (iter (validation_recall [expt ].values ()))))
117127 for model in validation_time [expt ].keys ():
@@ -127,45 +137,45 @@ def main():
127137 ax .set_ylim (bottom = bottom - 5 , top = 100 )
128138 ax .set_xlabel ('Training time (min)' )
129139 ax .set_ylabel ('Overall validation recall' )
130- ax .legend (loc = 'lower right' , title = dirname , ncol = 2 if "Annotations" in dirname else 1 )
140+ ax .legend (loc = 'lower right' , ncol = 2 if "Annotations" in logdirs_basename else 1 )
131141
132142 ax = fig .add_subplot (2 ,2 ,3 )
133- ldata = natsorted (nparameters_total .keys ())
143+ ldata = natsorted (nparameters_total .keys (), key = sortby_indepvar )
134144 xdata = range (len (ldata ))
135145 ydata = [next (iter (nparameters_total [x ].values ())) - \
136146 next (iter (nparameters_finallayer [x ].values ())) for x in ldata ]
137147 ydata2 = [next (iter (nparameters_finallayer [x ].values ())) for x in ldata ]
138148 bar1 = ax .bar (xdata ,ydata ,color = 'k' )
139149 bar2 = ax .bar (xdata ,ydata2 ,bottom = ydata ,color = 'gray' )
140150 ax .legend ((bar2 ,bar1 ), ('last' ,'rest' ))
141- ax .set_xlabel (dirname )
151+ ax .set_xlabel (logdirs_basename )
142152 ax .set_ylabel ('Trainable parameters' )
143153 ax .set_xticks (xdata )
144154 ax .set_xticklabels ([differentparameters [x ] for x in ldata ], rotation = 40 , ha = 'right' )
145155
146156 ax = fig .add_subplot (2 ,2 ,4 )
147- data = {k :list ([np .median (np .diff (x )) for x in train_time [k ].values ()]) for k in train_time }
157+ data = {k :list ([np .median (np .diff (x )) for x in train_time [k ].values ()])
158+ for k in sorted (train_time .keys (), key = sortby_indepvar )}
148159 ldata = jitter_plot (ax , data )
149160 ax .set_ylabel ('time / step (ms)' )
150- ax .set_xlabel (dirname )
161+ ax .set_xlabel (logdirs_basename )
151162 ax .set_xticks (range (len (ldata )))
152163 ax .set_xticklabels ([differentparameters [x ] for x in ldata ], rotation = 40 , ha = 'right' )
153164
154- fig .suptitle (',' .join (list (commonparameters )))
155-
156- fig .tight_layout (rect = [0 , 0.03 , 1 , 0.95 ])
157- plt .savefig (logdirs_prefix + '-compare-overall-params-speed.pdf' )
165+ fig .suptitle (',' .join (list (commonparameters )), fontsize = 'xx-large' )
166+ fig .tight_layout (rect = [0 , 0.03 , 1 , 0.97 ])
167+ plt .savefig (logdirs_filter + '-compare-overall-params-speed.pdf' )
158168 plt .close ()
159169
160170
161171 recall_confusion_matrices = {}
162172 precision_confusion_matrices = {}
163173 labels = None
164174
165- for ilogdir ,logdir in enumerate (natsorted (logdirs )):
175+ for ilogdir ,logdir in enumerate (natsorted (logdirs , key = sortby_indepvar )):
166176 kind = next (iter (validation_time [logdir ].keys ())).split ('_' )[0 ]
167177 confusion_matrices , theselabels = \
168- parse_confusion_matrices (os .path .join (basename ,logdir ), kind , \
178+ parse_confusion_matrices (os .path .join (logdirs_dirname ,logdir ), kind , \
169179 idx_time = idx_time [logdir ] if same_time else None )
170180
171181 recall_confusion_matrices [logdir ]= {}
@@ -215,7 +225,7 @@ def main():
215225 summed2_confusion_matrix ,
216226 precision_summed_matrix , recall_summed_matrix ,
217227 len (labels )< 10 ,
218- logdir + "\n " ,
228+ differentparameters [ logdir ] + "\n " ,
219229 labels if FLAGS .loss == 'exclusive' else
220230 ["song" , FLAGS .overlapped_prefix + "song" ],
221231 precision_summed , recall_summed )
@@ -229,12 +239,13 @@ def main():
229239 summed_confusion_matrix [ilabel ], \
230240 precision_summed_matrix , recall_summed_matrix , \
231241 len (labels )< 10 ,
232- logdir + "\n " ,
242+ differentparameters [ logdir ] + "\n " ,
233243 [labels [ilabel ], FLAGS .overlapped_prefix + labels [ilabel ]],
234244 precision_summed , recall_summed )
235245
236- fig .tight_layout ()
237- plt .savefig (logdirs_prefix + '-compare-confusion-matrices.pdf' )
246+ fig .suptitle (',' .join (list (commonparameters )), fontsize = 'xx-large' )
247+ fig .tight_layout (rect = [0 , 0.03 , 1 , 0.97 ])
248+ plt .savefig (logdirs_filter + '-compare-confusion-matrices.pdf' )
238249 plt .close ()
239250
240251
@@ -245,7 +256,7 @@ def main():
245256 for (ilabel ,label ) in enumerate (labels ):
246257 ax = fig .add_subplot (nrows , ncols , ilabel + 1 )
247258 precisions_mean , recalls_mean = [], []
248- for (ilogdir ,logdir ) in enumerate (natsorted (logdirs )):
259+ for (ilogdir ,logdir ) in enumerate (natsorted (logdirs , key = sortby_indepvar )):
249260 color = cm .viridis (ilogdir / max (1 ,len (validation_recall )- 1 ))
250261 precisions_all , recalls_all = [], []
251262 for (imodel ,model ) in enumerate (recall_confusion_matrices [logdir ].keys ()):
@@ -261,14 +272,15 @@ def main():
261272 'o' , markeredgecolor = 'k' , color = color )
262273 label_precisions_recall (ax , recalls_mean , precisions_mean , label + "\n " )
263274
264- fig .tight_layout ()
265- plt .savefig (logdirs_prefix + '-compare-PR-classes.pdf' )
275+ fig .suptitle (',' .join (list (commonparameters )), fontsize = 'xx-large' )
276+ fig .tight_layout (rect = [0 , 0.03 , 1 , 0.97 ])
277+ plt .savefig (logdirs_filter + '-compare-PR-classes.pdf' )
266278 plt .close ()
267279
268280if __name__ == "__main__" :
269281 parser = argparse .ArgumentParser ()
270282 parser .add_argument (
271- '--logdirs_prefix ' ,
283+ '--logdirs_filter ' ,
272284 type = str ,
273285 default = '/tmp/speech_commands_train' ,
274286 help = 'Common prefix of the directories of logs and checkpoints' )
0 commit comments