@@ -108,7 +108,7 @@ def which_set(filename, validation_percentage, validation_offset_percentage, tes
108108 return result
109109
110110def init (_data_dir ,
111- shiftby ,
111+ shiftby , _clip_window_samples ,
112112 labels_touse , kinds_touse ,
113113 validation_percentage , validation_offset_percentage , validation_files ,
114114 testing_percentage , testing_files , subsample_skip , subsample_label ,
@@ -121,9 +121,10 @@ def init(_data_dir,
121121 _audio_read_plugin , _video_read_plugin ,
122122 _audio_read_plugin_kwargs , _video_read_plugin_kwargs ):
123123
124- global data_dir , np_rng , audio_read_plugin_kwargs , audio_read_plugin , video_read_plugin_kwargs , video_read_plugin , queue_size , max_procs
124+ global data_dir , clip_window_samples , np_rng , audio_read_plugin_kwargs , audio_read_plugin , video_read_plugin_kwargs , video_read_plugin , queue_size , max_procs
125125
126126 data_dir = _data_dir
127+ clip_window_samples = _clip_window_samples
127128 np_rng = np .random .default_rng (None if random_seed_batch == - 1 else random_seed_batch )
128129
129130 audio_read_plugin = _audio_read_plugin
@@ -198,6 +199,8 @@ def prepare_data_index(shiftby,
198199 audio_tic_rate = model_settings ['audio_tic_rate' ]
199200 time_scale = model_settings ['time_scale' ]
200201 context_tics = int (audio_tic_rate * model_settings ['context' ] * time_scale )
202+ parallelize = int (model_settings ['parallelize' ])
203+ stride_x_downsample = (clip_window_samples - context_tics ) // (parallelize - 1 )
201204 shiftby_tics = int (shiftby * audio_tic_rate * time_scale )
202205 audio_ntics = {}
203206 video_nframes = {}
@@ -266,19 +269,21 @@ def prepare_data_index(shiftby,
266269 f'in configuration.py but is actually { audio_data_shape [1 ]} in { wav_path } ' )
267270 audio_ntics [wav_path ] = audio_data_shape [0 ]
268271 if use_audio :
269- if ticks [1 ] < context_tics // 2 + shiftby_tics or \
270- ticks [0 ] > (audio_ntics [wav_path ] - context_tics // 2 + shiftby_tics ):
272+ left_room = context_tics // 2 + (parallelize // 2 + 1 )* stride_x_downsample + shiftby_tics
273+ right_room = context_tics // 2 + (parallelize // 2 )* stride_x_downsample - shiftby_tics
274+ if ticks [1 ] < left_room or \
275+ ticks [0 ] > (audio_ntics [wav_path ] - right_room ):
271276 print (f"WARNING: { str (annotation )} is too close to an edge of the recording. "
272277 f"not using at all" )
273278 continue
274- if ticks [0 ] < context_tics // 2 + shiftby_tics :
279+ if ticks [0 ] < left_room :
275280 print (f"WARNING: { str (annotation )} is close to beginning of recording. "
276281 f"shortening interval to usable range" )
277- ticks [0 ] = context_tics // 2 + shiftby_tics
278- if ticks [1 ] > audio_ntics [wav_path ] - context_tics // 2 + shiftby_tics :
282+ ticks [0 ] = left_room
283+ if ticks [1 ] > audio_ntics [wav_path ] - right_room :
279284 print (f"WARNING: { str (annotation )} is close to end of recording. "
280285 f"shortening interval to usable range" )
281- ticks [1 ] = audio_ntics [wav_path ] - context_tics // 2 + shiftby_tics
286+ ticks [1 ] = audio_ntics [wav_path ] - right_room
282287 if use_video and wav_path not in video_nframes :
283288 sound_dirname = os .path .join (data_dir , wav_base2 [0 ])
284289 vidfile = video_findfile (sound_dirname , wavfile )
@@ -438,11 +443,14 @@ def augment(audio_slice, augmentation_parameters):
438443 video_channels = model_settings ['video_channels' ]
439444 time_scale = model_settings ['time_scale' ]
440445 context_tics = int (audio_tic_rate * model_settings ['context' ] * time_scale )
446+ parallelize = int (model_settings ['parallelize' ])
447+ ninput_tics = clip_window_samples
448+ stride_x_downsample = (clip_window_samples - context_tics ) // (parallelize - 1 )
441449 shiftby_tics = int (shiftby * audio_tic_rate * time_scale )
442450 if use_audio :
443- audio_slice = np .zeros ((nsounds , context_tics , audio_nchannels ), dtype = np .float32 )
451+ audio_slice = np .zeros ((nsounds , ninput_tics , audio_nchannels ), dtype = np .float32 )
444452 if use_video :
445- nframes = round (model_settings [ 'context' ] * time_scale * video_frame_rate )
453+ nframes = round (ninput_tics / audio_tic_rate * video_frame_rate )
446454 video_slice = np .zeros ((nsounds ,
447455 nframes ,
448456 model_settings ['video_frame_height' ],
@@ -451,9 +459,9 @@ def augment(audio_slice, augmentation_parameters):
451459 dtype = np .float32 )
452460 bkg = {}
453461 if loss == 'exclusive' :
454- labels = np .zeros ( nsounds , dtype = np .int32 )
462+ labels = - 1 * np .ones (( nsounds , parallelize ) , dtype = np .int32 )
455463 elif loss == 'overlapped' :
456- labels = 2 * np .ones ((nsounds , len (labels_list )), dtype = np .float32 )
464+ labels = 2 * np .ones ((nsounds , parallelize , len (labels_list )), dtype = np .float32 )
457465 # repeatedly to generate the final output sound data we'll use in training.
458466 for i in range (offset , offset + nsounds ):
459467 # Pick which sound to use.
@@ -466,8 +474,8 @@ def augment(audio_slice, augmentation_parameters):
466474 offset_tic = (np_rng .integers (sound ['ticks' ][0 ], high = 1 + sound ['ticks' ][1 ]) \
467475 if sound ['ticks' ][0 ] < sound ['ticks' ][1 ] \
468476 else sound ['ticks' ][0 ])
469- start_tic = offset_tic - math .floor ( context_tics / 2 ) - shiftby_tics
470- stop_tic = offset_tic + math .ceil ( context_tics / 2 ) - shiftby_tics
477+ start_tic = offset_tic - math .ceil ( ninput_tics / 2 ) - shiftby_tics
478+ stop_tic = offset_tic + math .floor ( ninput_tics / 2 ) - shiftby_tics
471479 if use_audio :
472480 wavpath = os .path .join (data_dir , * sound ['file' ])
473481 _ , _ , audio_data = audio_read (wavpath , start_tic , stop_tic )
@@ -486,9 +494,18 @@ def augment(audio_slice, augmentation_parameters):
486494 video_slice [i - offset ,iframe ,:,:,:] = \
487495 frame [:,:,video_channels ] - bkg [vidfile ][:,:,video_channels ]
488496 if loss == 'exclusive' :
489- labels [i - offset ] = labels_list .index (sound ['label' ])
497+ start_in_tic = max (sound ['ticks' ][0 ] - offset_tic ,
498+ (1 - parallelize / 2 ) * stride_x_downsample )
499+ start_out_tic = math .ceil (start_in_tic / stride_x_downsample )
500+ start_out_tic += parallelize // 2 - 1
501+ stop_in_tic = min (sound ['ticks' ][1 ] - offset_tic + stride_x_downsample ,
502+ parallelize / 2 * stride_x_downsample )
503+ stop_out_tic = math .floor (stop_in_tic / stride_x_downsample )
504+ stop_out_tic += parallelize // 2 - 1
505+ labels [i - offset , start_out_tic : 1 + stop_out_tic ] = labels_list .index (sound ['label' ])
506+ sound ['offset_tic' ] = offset_tic
490507 sounds .append ({k : v for k ,v in sound .items () if k != 'overlaps' })
491- elif loss == 'overlapped' :
508+ elif loss == 'overlapped' : ### !!!
492509 target = 0 if sound ['label' ].startswith (overlapped_prefix ) else 1
493510 root = sound ['label' ].removeprefix (overlapped_prefix )
494511 labels [i - offset , labels_list .index (root )] = target
0 commit comments