Skip to content

Commit 1a790b8

Browse files
committed
move augmentation from model architecture to data loader
1 parent 2427161 commit 1a790b8

17 files changed

Lines changed: 257 additions & 128 deletions

src/convolutional.py

Lines changed: 3 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -208,59 +208,8 @@ def model_parameters(time_units, freq_units, time_scale, freq_scale):
208208
["max",
209209
"average"]], None, True],
210210
["denselayers", "dense layers", '', '', 1, [], None, False],
211-
["augment_volume", "augment volume", '', '1,1', 1, [], None, True],
212-
["augment_noise", "augment noise", '', '0,0', 1, [], None, True],
213-
["augment_dc", "augment DC", '', '0,0', 1, [], None, True],
214-
["augment_reverse", "augment reverse", ["yes", "no"], 'no', 1, [], None, True],
215-
["augment_invert", "augment invert", ["yes", "no"], 'no', 1, [], None, True],
216211
]
217212

218-
class Augment(tf.keras.layers.Layer):
    """Keras layer that randomly perturbs a batch of waveforms while training.

    The input is indexed as a rank-3 tensor; axis 0 is the batch, axis 1 is the
    axis that gets time-reversed, and axis 2 is the channel axis.  Outside of
    training (``training`` falsy) the input is returned untouched.

    Args:
        volume_range: ``[lo, hi]``; each (batch item, channel) is multiplied by
            a uniform random gain drawn from this interval.  ``[1,1]`` disables.
        noise_range: ``[lo, hi]``; zero-mean Gaussian noise is added whose std
            dev is drawn uniformly from this interval per (batch item, channel).
            ``[0,0]`` disables.
        baseline_range: ``[lo, hi]``; a uniform random offset drawn from this
            interval is added per (batch item, channel).  ``[0,0]`` disables.
        reverse_bool: if true, each batch item is reversed along axis 1 with
            probability one half.
        invert_bool: if true, each batch item is negated with probability one
            half.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, volume_range, noise_range, baseline_range, reverse_bool, invert_bool, **kwargs):
        super(Augment, self).__init__(**kwargs)
        self.volume_range = volume_range
        self.noise_range = noise_range
        self.baseline_range = baseline_range
        self.reverse_bool = reverse_bool
        self.invert_bool = invert_bool

    def get_config(self):
        """Return the layer config including every augmentation setting."""
        config = super().get_config().copy()
        config.update({
            'volume_range': self.volume_range,
            'noise_range': self.noise_range,
            'baseline_range': self.baseline_range,
            'reverse_bool': self.reverse_bool,
            'invert_bool': self.invert_bool,
        })
        return config

    @staticmethod
    def _coin_flips(nbatch):
        """Return a 1-D int32 tensor of ``nbatch`` fair 0/1 draws."""
        flips = tf.random.categorical(tf.math.log([[0.5, 0.5]]), nbatch, dtype=tf.int32)
        # categorical() yields shape (1, nbatch).  Squeeze ONLY the leading
        # axis: a bare tf.squeeze would also drop the batch axis when
        # nbatch == 1, leaving a scalar that tf.reverse_sequence rejects and
        # that mis-shapes the sign tensor used for inversion.
        return tf.squeeze(flips, axis=0)

    def call(self, inputs, training=None):
        """Apply the enabled augmentations to ``inputs`` when ``training`` is truthy."""
        if not training:
            return inputs
        if self.volume_range != [1,1] or self.noise_range != [0,0] or self.baseline_range != [0,0]:
            # one random draw per (batch item, channel), broadcast along axis 1
            nbatch_1_nchannel = tf.stack((tf.shape(inputs)[0], 1, tf.shape(inputs)[2]), axis=0)
        if self.volume_range != [1,1]:
            volume_ranges = tf.random.uniform(nbatch_1_nchannel, *self.volume_range)
            inputs = tf.math.multiply(volume_ranges, inputs)
        if self.noise_range != [0,0]:
            noise_ranges = tf.random.uniform(nbatch_1_nchannel, *self.noise_range)
            noises = tf.random.normal(tf.shape(inputs), 0, noise_ranges)
            inputs = tf.math.add(noises, inputs)
        if self.baseline_range != [0,0]:
            baseline_ranges = tf.random.uniform(nbatch_1_nchannel, *self.baseline_range)
            inputs = tf.math.add(baseline_ranges, inputs)
        if self.reverse_bool:
            ireverse = self._coin_flips(tf.shape(inputs)[0])
            # a sequence length of 0 leaves the item as is; the full length reverses it
            ireverse *= tf.shape(inputs)[1]
            inputs = tf.reverse_sequence(inputs, ireverse, seq_axis=1, batch_axis=0)
        if self.invert_bool:
            iinvert = self._coin_flips(tf.shape(inputs)[0])
            # map {0, 1} -> {-1, +1} and reshape to (nbatch, 1, 1) for broadcasting
            iinvert = tf.cast(iinvert, tf.float32)*2-1
            iinvert = tf.expand_dims(tf.expand_dims(iinvert, axis=1), axis=1)
            inputs *= iinvert
        return inputs
263-
264213
class Spectrogram(tf.keras.layers.Layer):
265214
def __init__(self, window_tics, stride_tics, **kwargs):
266215
super(Spectrogram, self).__init__(**kwargs)
@@ -452,20 +401,10 @@ def create_model(model_settings, model_parameters, io=sys.stdout):
452401
inputs = Input(shape=(ninput_tics, model_settings['audio_nchannels']))
453402
hidden_layers.append(inputs)
454403

455-
volume_range = [float(x) for x in model_parameters['augment_volume'].split(',')]
456-
noise_range = [float(x) for x in model_parameters['augment_noise'].split(',')]
457-
dc_range = [float(x) for x in model_parameters['augment_dc'].split(',')]
458-
reverse_bool = model_parameters['augment_reverse'] == 'yes'
459-
invert_bool = model_parameters['augment_invert'] == 'yes'
460-
if volume_range != [1,1] or noise_range != [0,0] or dc_range != [0,0]:
461-
x = Augment(volume_range, noise_range, dc_range, reverse_bool, invert_bool)(inputs)
462-
else:
463-
x = inputs
464-
465404
if representation == "waveform":
466-
x = Reshape((ninput_tics,1,model_settings['audio_nchannels']))(x)
405+
x = Reshape((ninput_tics,1,model_settings['audio_nchannels']))(inputs)
467406
elif representation == "spectrogram":
468-
x = Spectrogram(window_tics, stride_tics)(x)
407+
x = Spectrogram(window_tics, stride_tics)(inputs)
469408
if model_parameters['range'] != "":
470409
lo, hi = model_parameters['range'].split('-')
471410
lo = float(lo) * freq_scale
@@ -478,7 +417,7 @@ def create_model(model_settings, model_parameters, io=sys.stdout):
478417
elif representation == "mel-cepstrum":
479418
filterbank_nchannels, dct_ncoefficients = model_parameters['mel_dct'].split(',')
480419
x = MelCepstrum(window_tics, stride_tics, audio_tic_rate,
481-
int(filterbank_nchannels), int(dct_ncoefficients))(x)
420+
int(filterbank_nchannels), int(dct_ncoefficients))(inputs)
482421
hidden_layers.append(x)
483422
x_shape = x.shape
484423

src/convolutional1.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -216,36 +216,8 @@ def model_parameters(time_units, freq_units, time_scale, freq_scale):
216216
["max",
217217
"average"]], None, True],
218218
["denselayers", "dense layers", '', '', 1, [], None, False],
219-
["augment_volume", "augment volume", '', '1,1', 1, [], None, True],
220-
["augment_noise", "augment noise", '', '0,0', 1, [], None, True],
221219
]
222220

223-
class Augment(tf.keras.layers.Layer):
    """Training-only waveform augmentation: random gain and additive noise.

    ``volume_range`` and ``noise_range`` are unpacked as the bounds of a
    uniform draw made once per (batch item, channel).  A volume range of
    ``[1,1]`` and a noise range of ``[0,0]`` switch the respective step off.
    """

    def __init__(self, volume_range, noise_range, **kwargs):
        super(Augment, self).__init__(**kwargs)
        self.volume_range = volume_range
        self.noise_range = noise_range

    def get_config(self):
        """Return the base layer config extended with both ranges."""
        settings = super().get_config().copy()
        settings.update({
            'volume_range': self.volume_range,
            'noise_range': self.noise_range,
        })
        return settings

    def call(self, inputs, training=None):
        """Return ``inputs``, perturbed only when ``training`` is truthy."""
        if training:
            scale_volume = self.volume_range != [1,1]
            add_noise = self.noise_range != [0,0]
            if scale_volume or add_noise:
                # shape (nbatch, 1, nchannel): broadcasts along axis 1
                per_channel_shape = tf.stack((tf.shape(inputs)[0], 1, tf.shape(inputs)[2]), axis=0)
            if scale_volume:
                gains = tf.random.uniform(per_channel_shape, *self.volume_range)
                inputs = tf.math.multiply(gains, inputs)
            if add_noise:
                # the noise std dev is itself random per (batch item, channel)
                sigmas = tf.random.uniform(per_channel_shape, *self.noise_range)
                jitter = tf.random.normal(tf.shape(inputs), 0, sigmas)
                inputs = tf.math.add(jitter, inputs)
        return inputs
248-
249221
class Spectrogram(tf.keras.layers.Layer):
250222
def __init__(self, window_tics, stride_tics, **kwargs):
251223
super(Spectrogram, self).__init__(**kwargs)
@@ -444,17 +416,10 @@ def Identity(x): return lambda x: x
444416
inputs = Input(shape=(ninput_tics, model_settings['audio_nchannels']))
445417
hidden_layers.append(inputs)
446418

447-
volume_range = [float(x) for x in model_parameters['augment_volume'].split(',')]
448-
noise_range = [float(x) for x in model_parameters['augment_noise'].split(',')]
449-
if volume_range != [1,1] or noise_range != [0,0]:
450-
x = Augment(volume_range, noise_range)(inputs)
451-
else:
452-
x = inputs
453-
454419
if representation == "waveform":
455-
x = Reshape((ninput_tics,1,model_settings['audio_nchannels']))(x)
420+
x = Reshape((ninput_tics,1,model_settings['audio_nchannels']))(inputs)
456421
elif representation == "spectrogram":
457-
x = Spectrogram(window_tics, stride_tics)(x)
422+
x = Spectrogram(window_tics, stride_tics)(inputs)
458423
if model_parameters['range'] != "":
459424
lo, hi = model_parameters['range'].split('-')
460425
lo = float(lo) * freq_scale
@@ -467,7 +432,7 @@ def Identity(x): return lambda x: x
467432
elif representation == "mel-cepstrum":
468433
filterbank_nchannels, dct_ncoefficients = model_parameters['mel_dct'].split(',')
469434
x = MelCepstrum(window_tics, stride_tics, audio_tic_rate,
470-
int(filterbank_nchannels), int(dct_ncoefficients))(x)
435+
int(filterbank_nchannels), int(dct_ncoefficients))(inputs)
471436
hidden_layers.append(x)
472437
x_shape = x.shape
473438

src/data.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,31 @@ def _get_data(self, q, o, how_many, offset, model_settings, loss, overlapped_pre
494494
root = overlapped_sound['label'].removeprefix(overlapped_prefix)
495495
labels[i - offset, self.labels_list.index(root)] = target
496496
sounds[-1].append({k: v for k,v in overlapped_sound.items() if k!='overlaps'})
497+
498+
# augmentation
499+
if use_audio and mode=='training':
500+
volume_range = [float(x) for x in model_settings['augment_volume'].split(',')]
501+
noise_range = [float(x) for x in model_settings['augment_noise'].split(',')]
502+
dc_range = [float(x) for x in model_settings['augment_dc'].split(',')]
503+
reverse_bool = model_settings['augment_reverse'] == 'yes'
504+
invert_bool = model_settings['augment_invert'] == 'yes'
505+
if volume_range != [1,1]:
506+
volume_ranges = np.random.uniform(*volume_range, (nsounds,1,audio_nchannels))
507+
audio_slice *= volume_ranges
508+
if noise_range != [0,0]:
509+
noise_ranges = np.random.uniform(*noise_range, (nsounds,1,audio_nchannels))
510+
noises = np.random.normal(0, noise_ranges, audio_slice.shape)
511+
audio_slice += noises
512+
if dc_range != [0,0]:
513+
dc_ranges = np.random.uniform(*dc_range, (nsounds,1,audio_nchannels))
514+
audio_slice += dc_ranges
515+
if reverse_bool:
516+
ireverse = np.random.choice([False,True], nsounds)
517+
audio_slice[ireverse] = np.flip(audio_slice[ireverse], axis=1)
518+
if invert_bool:
519+
iinvert = np.random.choice([-1,1], (nsounds,1,1))
520+
audio_slice *= iinvert
521+
497522
if use_audio and use_video:
498523
q.put([[audio_slice, video_slice], labels, sounds])
499524
elif use_audio:

src/generalize

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
# --video_channels=0 \
4141
# --batch_seed=_1 \
4242
# --weights_seed=_1 \
43+
# --augment_volume=1,1 \
44+
# --augment_noise=0,0 \
45+
# --augment_dc=0,0 \
46+
# --augment_reverse=no \
47+
# --augment_invert=no \
4348
# --deterministic=0 \
4449
# --igpu=0 \
4550
# --ioffset=3 \
@@ -133,6 +138,11 @@ def main():
133138
"--video_channels="+FLAGS.video_channels,
134139
"--random_seed_batch="+str(FLAGS.batch_seed),
135140
"--random_seed_weights="+str(FLAGS.weights_seed),
141+
"--augment_volume="+str(FLAGS.augment_volume),
142+
"--augment_noise="+str(FLAGS.augment_noise),
143+
"--augment_dc="+str(FLAGS.augment_dc),
144+
"--augment_reverse="+str(FLAGS.augment_reverse),
145+
"--augment_invert="+str(FLAGS.augment_invert),
136146
"--deterministic="+FLAGS.deterministic,
137147
"--train_dir="+os.path.join(FLAGS.logdir,"generalize_"+model),
138148
"--summaries_dir="+os.path.join(FLAGS.logdir,"summaries_"+model),
@@ -294,6 +304,31 @@ if __name__ == '__main__':
294304
type=int,
295305
default=59185,
296306
help='Randomize weight initialization if -1; otherwise use supplied number as seed.')
307+
parser.add_argument(
308+
'--augment_volume',
309+
type=str,
310+
default='1,1',
311+
help='Multiply each annotation by a uniform random number in this interval when training')
312+
parser.add_argument(
313+
'--augment_noise',
314+
type=str,
315+
default='0,0',
316+
help='Add noise to each annotation with a uniform random std dev in this interval when training')
317+
parser.add_argument(
318+
'--augment_dc',
319+
type=str,
320+
default='0,0',
321+
help='Add to each annotation a uniform random number in this interval when training')
322+
parser.add_argument(
323+
'--augment_reverse',
324+
type=str,
325+
default='no',
326+
help='Flip in time with a probability of half each annotation when training')
327+
parser.add_argument(
328+
'--augment_invert',
329+
type=str,
330+
default='no',
331+
help='Negate with a probability of half each annotation when training')
297332
parser.add_argument(
298333
'--model_architecture',
299334
type=str,

src/gui/controller.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,11 @@ async def train_actuate():
12221222
"--video_channels="+str(M.video_channels), \
12231223
"--batch_seed="+V.batch_seed.value, \
12241224
"--weights_seed="+V.weights_seed.value, \
1225+
"--augment_volume="+V.augment_volume.value, \
1226+
"--augment_noise="+V.augment_noise.value, \
1227+
"--augment_dc="+V.augment_dc.value, \
1228+
"--augment_reverse="+V.augment_reverse.value, \
1229+
"--augment_invert="+V.augment_invert.value, \
12251230
"--deterministic="+M.deterministic, \
12261231
"--igpu=QUEUE1", \
12271232
"--ireplicates="+','.join([str(x) for x in range(ireplicate, min(1+nreplicates, \
@@ -1313,6 +1318,11 @@ async def leaveout_actuate(comma):
13131318
"--video_channels="+str(M.video_channels), \
13141319
"--batch_seed="+V.batch_seed.value, \
13151320
"--weights_seed="+V.weights_seed.value, \
1321+
"--augment_volume="+V.augment_volume.value, \
1322+
"--augment_noise="+V.augment_noise.value, \
1323+
"--augment_dc="+V.augment_dc.value, \
1324+
"--augment_reverse="+V.augment_reverse.value, \
1325+
"--augment_invert="+V.augment_invert.value, \
13161326
"--deterministic="+M.deterministic, \
13171327
"--ioffset="+str(ivalidation_file),
13181328
"--igpu=QUEUE1", \
@@ -1384,6 +1394,11 @@ async def xvalidate_actuate():
13841394
"--video_channels="+str(M.video_channels), \
13851395
"--batch_seed="+V.batch_seed.value, \
13861396
"--weights_seed="+V.weights_seed.value, \
1397+
"--augment_volume="+V.augment_volume.value, \
1398+
"--augment_noise="+V.augment_noise.value, \
1399+
"--augment_dc="+V.augment_dc.value, \
1400+
"--augment_reverse="+V.augment_reverse.value, \
1401+
"--augment_invert="+V.augment_invert.value, \
13871402
"--deterministic="+M.deterministic, \
13881403
"--igpu=QUEUE1", \
13891404
"--kfold="+V.kfold.value, \
@@ -2218,6 +2233,21 @@ def _copy_callback():
22182233
elif "random_seed_weights = " in line:
22192234
m=re.search('random_seed_weights = (.*)', line)
22202235
V.weights_seed.value = m.group(1)
2236+
elif "augment_volume = " in line:
2237+
m=re.search('augment_volume = (.*)', line)
2238+
V.augment_volume.value = m.group(1)
2239+
elif "augment_noise = " in line:
2240+
m=re.search('augment_noise = (.*)', line)
2241+
V.augment_noise.value = m.group(1)
2242+
elif "augment_dc = " in line:
2243+
m=re.search('augment_dc = (.*)', line)
2244+
V.augment_dc.value = m.group(1)
2245+
elif "augment_reverse = " in line:
2246+
m=re.search('augment_reverse = (.*)', line)
2247+
V.augment_reverse.value = m.group(1)
2248+
elif "augment_invert = " in line:
2249+
m=re.search('augment_invert = (.*)', line)
2250+
V.augment_invert.value = m.group(1)
22212251
elif "validate_step_period = " in line:
22222252
m=re.search('validate_step_period = (\d+)', line)
22232253
V.save_and_validate_period.value = m.group(1)

src/gui/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,13 @@
177177
width=105)),
178178
row(column(V.file_dialog_string,
179179
V.file_dialog_table),
180-
column(*[row([model_parameters[x] for x in p])
180+
column(row(V.augment_volume,
181+
V.augment_noise,
182+
V.augment_dc,
183+
V.augment_reverse,
184+
V.augment_invert,
185+
width=M.gui_width_pix//2),
186+
*[row([model_parameters[x] for x in p])
181187
for p in V.model_parameters_partitioned],
182188
V.model_summary,
183189
width=M.gui_width_pix//2))),

src/gui/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def save_state_callback():
5757
'nreplicates': V.nreplicates.value,
5858
'batch_seed': V.batch_seed.value,
5959
'weights_seed': V.weights_seed.value,
60+
'augment_volume': V.augment_volume.value,
61+
'augment_noise': V.augment_noise.value,
62+
'augment_dc': V.augment_dc.value,
63+
'augment_reverse': V.augment_reverse.value,
64+
'augment_invert': V.augment_invert.value,
6065
'labels': str.join(',',[x.value for x in V.label_texts]),
6166
'file_dialog_string': V.file_dialog_string.value,
6267
'context': V.context.value,
@@ -510,6 +515,11 @@ def is_local_server_or_cluster(varname, varvalue):
510515
'nreplicates':'1', \
511516
'batch_seed':'-1', \
512517
'weights_seed':'-1', \
518+
'augment_volume':'1,1', \
519+
'augment_noise':'0,0', \
520+
'augment_dc':'0,0', \
521+
'augment_reverse':'no', \
522+
'augment_invert':'no', \
513523
'labels':','*(nlabels-1), \
514524
'file_dialog_string':os.getcwd(), \
515525
'context':str(0.2048 / time_scale), \

0 commit comments

Comments
 (0)