Skip to content

Commit ca3985b

Browse files
committed
parallelize training
1 parent 72bee60 commit ca3985b

15 files changed

Lines changed: 155 additions & 99 deletions

File tree

configuration.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,6 @@
194194
# how many points to use for the precision-recall, sensitivity-specificity, and congruence curves
195195
nprobabilities=20
196196

197-
# used by freeze and classify to specify how many output time tics to process in parallel.
198-
# must be greater than one.
199-
classify_parallelize=64
200-
201197
# used by train, generalize, xvalidate, and activations
202198
data_loader_maxprocs=0 # 0 = num CPU cores
203199
data_loader_queuesize=1 # 0 = infinite

src/accuracy

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def plot_probability_density(test_ground_truth, test_logits, ratios, thresholds,
114114
igt = test_ground_truth==gt
115115
else:
116116
igt = test_ground_truth[:,ilabel]==gt
117-
if sum(igt)==0:
117+
if np.sum(igt)==0:
118118
continue
119119
xdata = test_logits[igt,ilabel]
120120
xdata = np.minimum(np.finfo(float).max, np.exp(xdata))
@@ -156,7 +156,6 @@ def doit(logdir, key_to_plot, ckpt, labels, nprobabilities, error_ratios, loss,
156156
'specificity-sensitivity.ckpt-'+str(ckpt)+'.pdf'))
157157
plt.close()
158158

159-
already_written = set()
160159
predictions_path = os.path.join(logdir,key_to_plot,'predictions.ckpt-'+str(ckpt))
161160
if not os.path.isdir(predictions_path):
162161
os.mkdir(os.path.join(logdir,key_to_plot,'predictions.ckpt-'+str(ckpt)))
@@ -169,20 +168,17 @@ def doit(logdir, key_to_plot, ckpt, labels, nprobabilities, error_ratios, loss,
169168
for i in range(len(validation_sounds)):
170169
if validation_sounds[i]['file'][0] != subdir:
171170
continue
172-
classified_as = np.argmax(validation_logits[i,:])
173-
id = os.path.join(*validation_sounds[i]['file']) + \
174-
str(validation_sounds[i]['ticks']) + \
175-
validation_sounds[i]['label'] + \
176-
labels[classified_as]
177-
if id in already_written:
178-
continue
179-
already_written |= set([id])
180-
csvwriter.writerow([validation_sounds[i]['file'][1],
181-
validation_sounds[i]['ticks'][0],
182-
validation_sounds[i]['ticks'][1],
183-
'correct' if classified_as == validation_ground_truth[i] else 'mistaken',
184-
labels[classified_as],
185-
validation_sounds[i]['label']])
171+
classified_as = np.argmax(validation_logits[i,:], axis=1)
172+
scores = 1+np.where(np.diff(classified_as == validation_ground_truth[i]))[0]
173+
scores = [0, *scores, len(classified_as)]
174+
for iscore in range(len(scores)-1):
175+
if validation_ground_truth[i][scores[iscore]] == -1: continue
176+
csvwriter.writerow([validation_sounds[i]['file'][1],
177+
validation_sounds[i]['offset_tic']-len(classified_as)//2+1+scores[iscore],
178+
validation_sounds[i]['offset_tic']-len(classified_as)//2+scores[iscore+1],
179+
'correct' if classified_as[scores[iscore]] == validation_ground_truth[i][scores[iscore]] else 'mistaken',
180+
labels[classified_as[scores[iscore]]],
181+
validation_sounds[i]['label']])
186182
else:
187183
for subdir in set([x[0]['file'][0] for x in validation_sounds]):
188184
with open(os.path.join(logdir, key_to_plot, 'predictions.ckpt-'+str(ckpt), \
@@ -195,13 +191,6 @@ def doit(logdir, key_to_plot, ckpt, labels, nprobabilities, error_ratios, loss,
195191
for j in range(len(validation_sounds[i])):
196192
k = labels.index(validation_sounds[i][j]['label'].removeprefix(overlapped_prefix))
197193
classified_as = validation_logits[i,k]>0.5
198-
id = os.path.join(*validation_sounds[i][j]['file']) + \
199-
str(validation_sounds[i][j]['ticks']) + \
200-
validation_sounds[i][j]['label'] + \
201-
str(classified_as)
202-
if id in already_written:
203-
continue
204-
already_written |= set([id])
205194
csvwriter.writerow([validation_sounds[i][j]['file'][1],
206195
validation_sounds[i][j]['ticks'][0],
207196
validation_sounds[i][j]['ticks'][1],

src/activations

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,19 @@ def main():
120120
'video_frame_width': FLAGS.video_frame_width,
121121
'video_frame_height': FLAGS.video_frame_height,
122122
'video_channels': [int(x)-1 for x in FLAGS.video_channels.split(',')],
123-
'parallelize': 1,
123+
'parallelize': FLAGS.parallelize,
124124
'context': FLAGS.context}
125125

126+
thismodel = model.create_model(model_settings, FLAGS.model_parameters)
127+
thismodel.summary(line_length=120, positions=[0.4,0.6,0.7,1])
128+
129+
input_shape = thismodel.input_shape
130+
clip_window_samples = input_shape[0][1] if model.use_video else input_shape[1]
131+
126132
D.init(
127133
FLAGS.data_dir,
128134
FLAGS.shiftby,
135+
clip_window_samples,
129136
FLAGS.labels_touse.split(','), FLAGS.kinds_touse.split(','),
130137
FLAGS.validation_percentage, FLAGS.validation_offset_percentage,
131138
FLAGS.validation_files.split(','),
@@ -141,9 +148,6 @@ def main():
141148
FLAGS.audio_read_plugin, FLAGS.video_read_plugin,
142149
audio_read_plugin_kwargs, video_read_plugin_kwargs)
143150

144-
thismodel = model.create_model(model_settings, FLAGS.model_parameters)
145-
thismodel.summary(line_length=120, positions=[0.4,0.6,0.7,1])
146-
147151
checkpoint = tf.train.Checkpoint(thismodel=thismodel)
148152
checkpoint.read(FLAGS.start_checkpoint).expect_partial()
149153

@@ -326,6 +330,11 @@ if __name__ == '__main__':
326330
type=float,
327331
default=1000,
328332
help='Expected duration in milliseconds of the wavs',)
333+
parser.add_argument(
334+
'--parallelize',
335+
type=int,
336+
default=64,
337+
help='how many output time tics to simultaneously process',)
329338
parser.add_argument(
330339
'--loss',
331340
type=str,

src/classify

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def main():
163163
exit()
164164

165165
if FLAGS.parallelize==1:
166-
print("WARNING: classify_parallelize in configuration.py is set to 1. making predictions is faster if it is > 1")
166+
print("WARNING: --parallelize is set to 1. making predictions is faster if it is > 1")
167167

168168
input_shape = thismodel.get_input_shape()
169169
if use_audio:

src/data.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def which_set(filename, validation_percentage, validation_offset_percentage, tes
108108
return result
109109

110110
def init(_data_dir,
111-
shiftby,
111+
shiftby, _clip_window_samples,
112112
labels_touse, kinds_touse,
113113
validation_percentage, validation_offset_percentage, validation_files,
114114
testing_percentage, testing_files, subsample_skip, subsample_label,
@@ -121,9 +121,10 @@ def init(_data_dir,
121121
_audio_read_plugin, _video_read_plugin,
122122
_audio_read_plugin_kwargs, _video_read_plugin_kwargs):
123123

124-
global data_dir, np_rng, audio_read_plugin_kwargs, audio_read_plugin, video_read_plugin_kwargs, video_read_plugin, queue_size, max_procs
124+
global data_dir, clip_window_samples, np_rng, audio_read_plugin_kwargs, audio_read_plugin, video_read_plugin_kwargs, video_read_plugin, queue_size, max_procs
125125

126126
data_dir = _data_dir
127+
clip_window_samples = _clip_window_samples
127128
np_rng = np.random.default_rng(None if random_seed_batch==-1 else random_seed_batch)
128129

129130
audio_read_plugin = _audio_read_plugin
@@ -198,6 +199,8 @@ def prepare_data_index(shiftby,
198199
audio_tic_rate = model_settings['audio_tic_rate']
199200
time_scale = model_settings['time_scale']
200201
context_tics = int(audio_tic_rate * model_settings['context'] * time_scale)
202+
parallelize = int(model_settings['parallelize'])
203+
stride_x_downsample = (clip_window_samples - context_tics) // (parallelize-1)
201204
shiftby_tics = int(shiftby * audio_tic_rate * time_scale)
202205
audio_ntics = {}
203206
video_nframes = {}
@@ -266,19 +269,21 @@ def prepare_data_index(shiftby,
266269
f'in configuration.py but is actually {audio_data_shape[1]} in {wav_path}')
267270
audio_ntics[wav_path] = audio_data_shape[0]
268271
if use_audio:
269-
if ticks[1] < context_tics//2 + shiftby_tics or \
270-
ticks[0] > (audio_ntics[wav_path] - context_tics//2 + shiftby_tics):
272+
left_room = context_tics//2 + (parallelize//2+1)*stride_x_downsample + shiftby_tics
273+
right_room = context_tics//2 + (parallelize//2)*stride_x_downsample - shiftby_tics
274+
if ticks[1] < left_room or \
275+
ticks[0] > (audio_ntics[wav_path] - right_room):
271276
print(f"WARNING: {str(annotation)} is too close to an edge of the recording. "
272277
f"not using at all")
273278
continue
274-
if ticks[0] < context_tics//2 + shiftby_tics:
279+
if ticks[0] < left_room:
275280
print(f"WARNING: {str(annotation)} is close to beginning of recording. "
276281
f"shortening interval to usable range")
277-
ticks[0] = context_tics//2 + shiftby_tics
278-
if ticks[1] > audio_ntics[wav_path] - context_tics//2 + shiftby_tics:
282+
ticks[0] = left_room
283+
if ticks[1] > audio_ntics[wav_path] - right_room:
279284
print(f"WARNING: {str(annotation)} is close to end of recording. "
280285
f"shortening interval to usable range")
281-
ticks[1] = audio_ntics[wav_path] - context_tics//2 + shiftby_tics
286+
ticks[1] = audio_ntics[wav_path] - right_room
282287
if use_video and wav_path not in video_nframes:
283288
sound_dirname = os.path.join(data_dir, wav_base2[0])
284289
vidfile = video_findfile(sound_dirname, wavfile)
@@ -438,11 +443,14 @@ def augment(audio_slice, augmentation_parameters):
438443
video_channels = model_settings['video_channels']
439444
time_scale = model_settings['time_scale']
440445
context_tics = int(audio_tic_rate * model_settings['context'] * time_scale)
446+
parallelize = int(model_settings['parallelize'])
447+
ninput_tics = clip_window_samples
448+
stride_x_downsample = (clip_window_samples - context_tics) // (parallelize-1)
441449
shiftby_tics = int(shiftby * audio_tic_rate * time_scale)
442450
if use_audio:
443-
audio_slice = np.zeros((nsounds, context_tics, audio_nchannels), dtype=np.float32)
451+
audio_slice = np.zeros((nsounds, ninput_tics, audio_nchannels), dtype=np.float32)
444452
if use_video:
445-
nframes = round(model_settings['context'] * time_scale * video_frame_rate)
453+
nframes = round(ninput_tics / audio_tic_rate * video_frame_rate)
446454
video_slice = np.zeros((nsounds,
447455
nframes,
448456
model_settings['video_frame_height'],
@@ -451,9 +459,9 @@ def augment(audio_slice, augmentation_parameters):
451459
dtype=np.float32)
452460
bkg = {}
453461
if loss=='exclusive':
454-
labels = np.zeros(nsounds, dtype=np.int32)
462+
labels = -1*np.ones((nsounds, parallelize), dtype=np.int32)
455463
elif loss=='overlapped':
456-
labels = 2*np.ones((nsounds, len(labels_list)), dtype=np.float32)
464+
labels = 2*np.ones((nsounds, parallelize, len(labels_list)), dtype=np.float32)
457465
# repeatedly to generate the final output sound data we'll use in training.
458466
for i in range(offset, offset + nsounds):
459467
# Pick which sound to use.
@@ -466,8 +474,8 @@ def augment(audio_slice, augmentation_parameters):
466474
offset_tic = (np_rng.integers(sound['ticks'][0], high=1+sound['ticks'][1]) \
467475
if sound['ticks'][0] < sound['ticks'][1] \
468476
else sound['ticks'][0])
469-
start_tic = offset_tic - math.floor(context_tics/2) - shiftby_tics
470-
stop_tic = offset_tic + math.ceil(context_tics/2) - shiftby_tics
477+
start_tic = offset_tic - math.ceil(ninput_tics/2) - shiftby_tics
478+
stop_tic = offset_tic + math.floor(ninput_tics/2) - shiftby_tics
471479
if use_audio:
472480
wavpath = os.path.join(data_dir, *sound['file'])
473481
_, _, audio_data = audio_read(wavpath, start_tic, stop_tic)
@@ -486,9 +494,18 @@ def augment(audio_slice, augmentation_parameters):
486494
video_slice[i-offset,iframe,:,:,:] = \
487495
frame[:,:,video_channels] - bkg[vidfile][:,:,video_channels]
488496
if loss=='exclusive':
489-
labels[i - offset] = labels_list.index(sound['label'])
497+
start_in_tic = max(sound['ticks'][0] - offset_tic,
498+
(1 - parallelize / 2) * stride_x_downsample)
499+
start_out_tic = math.ceil(start_in_tic / stride_x_downsample)
500+
start_out_tic += parallelize // 2 - 1
501+
stop_in_tic = min(sound['ticks'][1] - offset_tic + stride_x_downsample,
502+
parallelize / 2 * stride_x_downsample)
503+
stop_out_tic = math.floor(stop_in_tic / stride_x_downsample)
504+
stop_out_tic += parallelize // 2 - 1
505+
labels[i - offset, start_out_tic : 1+stop_out_tic] = labels_list.index(sound['label'])
506+
sound['offset_tic'] = offset_tic
490507
sounds.append({k: v for k,v in sound.items() if k!='overlaps'})
491-
elif loss=='overlapped':
508+
elif loss=='overlapped': ### !!!
492509
target = 0 if sound['label'].startswith(overlapped_prefix) else 1
493510
root = sound['label'].removeprefix(overlapped_prefix)
494511
labels[i - offset, labels_list.index(root)] = target

src/generalize

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# --shiftby=0.0 \
88
# --optimizer=Adam \
99
# --loss=exclusive \
10+
# --parallelize=64 \
1011
# --overlapped_prefix=not_ \
1112
# --learning_rate=0.0002 \
1213
# --audio_read_plugin=load_wav \
@@ -102,6 +103,7 @@ def main():
102103
"--shiftby="+str(FLAGS.shiftby),
103104
"--optimizer="+FLAGS.optimizer,
104105
"--loss="+FLAGS.loss,
106+
"--parallelize="+str(FLAGS.parallelize),
105107
"--overlapped_prefix="+FLAGS.overlapped_prefix,
106108
"--learning_rate="+str(FLAGS.learning_rate),
107109
"--audio_read_plugin="+FLAGS.audio_read_plugin,
@@ -232,6 +234,11 @@ if __name__ == '__main__':
232234
type=float,
233235
default=1000,
234236
help='Expected duration in milliseconds of the wavs',)
237+
parser.add_argument(
238+
'--parallelize',
239+
type=int,
240+
default=64,
241+
help='how many output time tics to simultaneously process',)
235242
parser.add_argument(
236243
'--learning_rate',
237244
type=float,

src/gui/controller.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def generic_parameters_callback(n):
113113
M.save_state_callback()
114114
V.buttons_update()
115115

116-
def context_callback(n):
116+
def context_parallelize_callback(n):
117117
V.model_summary_update()
118118
generic_parameters_callback(n)
119119

@@ -1190,6 +1190,7 @@ async def train_actuate():
11901190
"--shiftby="+V.shiftby.value, \
11911191
"--optimizer="+V.optimizer.value, \
11921192
"--loss="+V.loss.value, \
1193+
"--parallelize="+str(V.parallelize.value),
11931194
"--overlapped_prefix="+M.overlapped_prefix, \
11941195
"--learning_rate="+V.learning_rate.value, \
11951196
"--audio_read_plugin="+str(M.audio_read_plugin), \
@@ -1289,6 +1290,7 @@ async def leaveout_actuate(kind):
12891290
"--shiftby="+V.shiftby.value, \
12901291
"--optimizer="+V.optimizer.value, \
12911292
"--loss="+V.loss.value, \
1293+
"--parallelize="+str(V.parallelize.value),
12921294
"--overlapped_prefix="+M.overlapped_prefix, \
12931295
"--learning_rate="+V.learning_rate.value, \
12941296
"--audio_read_plugin="+str(M.audio_read_plugin), \
@@ -1362,6 +1364,7 @@ async def xvalidate_actuate():
13621364
"--shiftby="+V.shiftby.value, \
13631365
"--optimizer="+V.optimizer.value, \
13641366
"--loss="+V.loss.value, \
1367+
"--parallelize="+str(V.parallelize.value),
13651368
"--overlapped_prefix="+M.overlapped_prefix, \
13661369
"--learning_rate="+V.learning_rate.value, \
13671370
"--audio_read_plugin="+str(M.audio_read_plugin), \
@@ -1735,7 +1738,7 @@ async def _freeze_actuate(ckpts):
17351738
"--model_architecture="+M.architecture_plugin,
17361739
"--model_parameters="+json.dumps({k:v.value for k,v in V.model_parameters.items()}),
17371740
"--loss="+V.loss.value, \
1738-
"--parallelize="+str(M.classify_parallelize),
1741+
"--parallelize="+str(V.parallelize.value),
17391742
"--time_units="+str(M.time_units),
17401743
"--freq_units="+str(M.freq_units),
17411744
"--time_scale="+str(M.time_scale),
@@ -1821,7 +1824,7 @@ async def ensemble_actuate():
18211824
"--context="+V.context.value,
18221825
"--model_architecture="+M.architecture_plugin,
18231826
"--model_parameters="+json.dumps({k:v.value for k,v in V.model_parameters.items()}),
1824-
"--parallelize="+str(M.classify_parallelize),
1827+
"--parallelize="+str(V.parallelize.value),
18251828
"--time_units="+str(M.time_units),
18261829
"--freq_units="+str(M.freq_units),
18271830
"--time_scale="+str(M.time_scale),
@@ -1871,7 +1874,7 @@ async def _classify_actuate(wavfiles):
18711874
"--model="+os.path.join(logdir,model,"frozen-graph.ckpt-"+check_point+".pb"),
18721875
"--model_labels="+os.path.join(logdir,model,"labels.txt"),
18731876
"--wav="+wavfile,
1874-
"--parallelize="+str(M.classify_parallelize),
1877+
"--parallelize="+str(V.parallelize.value),
18751878
"--time_scale="+str(M.time_scale),
18761879
"--audio_tic_rate="+str(M.audio_tic_rate),
18771880
"--audio_nchannels="+str(M.audio_nchannels),
@@ -2230,6 +2233,9 @@ def _copy_callback():
22302233
elif "context = " in line:
22312234
m=re.search('context = (.*)', line)
22322235
V.context.value = m.group(1)
2236+
elif "parallelize = " in line:
2237+
m=re.search('parallelize = (.*)', line)
2238+
V.parallelize.value = m.group(1)
22332239
elif "time_shift_sec = " in line:
22342240
m=re.search('time_shift_sec = (.*)', line)
22352241
V.shiftby.value = m.group(1)

src/gui/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
V.prevalences_button, V.prevalences,
110110
V.delete_ckpts, V.copy, width=M.gui_width_pix),
111111
row(V.nsteps, V.restore_from, V.weights_seed, V.optimizer, V.context,
112-
V.mini_batch, V.nreplicates, V.activations_equalize_ratio,
112+
V.parallelize, V.mini_batch, V.nreplicates, V.activations_equalize_ratio,
113113
V.precision_recall_ratios, V.congruence_portion,
114114
width=M.gui_width_pix),
115115
row(V.save_and_validate_period, V.validate_percentage, V.batch_seed,

src/gui/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def save_state_callback():
6060
'labels': str.join(',',[x.value for x in V.label_texts]),
6161
'file_dialog_string': V.file_dialog_string.value,
6262
'context': V.context.value,
63+
'parallelize': V.parallelize.value,
6364
'shiftby': V.shiftby.value,
6465
'optimizer': V.optimizer.value,
6566
'loss': V.loss.value,
@@ -502,6 +503,7 @@ def is_local_server_or_cluster(varname, varvalue):
502503
'labels':','*(nlabels-1), \
503504
'file_dialog_string':os.getcwd(), \
504505
'context':str(0.2048 / time_scale), \
506+
'parallelize':'64', \
505507
'shiftby':'0.0', \
506508
'optimizer':'Adam', \
507509
'loss':'exclusive', \

0 commit comments

Comments (0)