@@ -91,16 +91,19 @@ def main():
9191 with open (FLAGS .model_labels , 'r' ) as fid :
9292 model_labels = fid .read ().splitlines ()
9393
94- if FLAGS .labels :
95- labels = np .array (FLAGS .labels .split (',' ))
96- iimodel_labels = np .argsort (np .argsort (model_labels ))
97- ilabels = np .argsort (labels )
98- labels = labels [ilabels ][iimodel_labels ]
99- assert np .all (labels == model_labels )
94+ if FLAGS .loss != 'autoencoder' :
95+ if FLAGS .labels :
96+ labels = np .array (FLAGS .labels .split (',' ))
97+ iimodel_labels = np .argsort (np .argsort (model_labels ))
98+ ilabels = np .argsort (labels )
99+ labels = labels [ilabels ][iimodel_labels ]
100+ assert np .all (labels == model_labels )
101+ else :
102+ labels = model_labels
103+ ilabels = iimodel_labels = range (len (labels ))
104+ print ('labels: ' + str (labels ))
100105 else :
101- labels = model_labels
102- ilabels = iimodel_labels = range (len (labels ))
103- print ('labels: ' + str (labels ))
106+ labels = ilabels = None
104107
105108 if FLAGS .prevalences and FLAGS .loss == 'exclusive' :
106109 prevalences = np .array ([float (x ) for x in FLAGS .prevalences .split (',' )])
@@ -184,11 +187,17 @@ def main():
184187
185188 context_samples = int (FLAGS .context * FLAGS .time_scale * data_sample_rate )
186189 stride_x_downsample_samples = (clip_window_samples - context_samples ) // (FLAGS .parallelize - 1 )
187- clip_stride_samples = stride_x_downsample_samples * FLAGS .parallelize
190+ if FLAGS .loss == 'autoencoder' :
191+ clip_stride_samples = clip_window_samples
192+ else :
193+ clip_stride_samples = stride_x_downsample_samples * FLAGS .parallelize
188194
189195 stride_x_downsample_sec = stride_x_downsample_samples / data_sample_rate
190196 npadding = round ((FLAGS .context / 2 + FLAGS .shiftby ) * FLAGS .time_scale / stride_x_downsample_sec )
191- probability_list = [np .zeros ((npadding , len (labels )), dtype = np .float32 )]
197+ if FLAGS .loss == 'autoencoder' :
198+ probability_list = [np .zeros ((npadding , ), dtype = np .float32 )]
199+ else :
200+ probability_list = [np .zeros ((npadding , len (labels )), dtype = np .float32 )]
192201
193202 # Inference along audio stream.
194203 for data_offset_samples in range (0 , 1 + data_len_samples , clip_stride_samples ):
@@ -220,31 +229,42 @@ def main():
220229 inputs = tf .expand_dims (video_slice , 0 )
221230 _ ,outputs = recognize_graph (inputs )
222231
223- current_time_sec = np .round (data_offset_samples / data_sample_rate ).astype (int )
224232 if pad_len > 0 :
225233 discard_len = np .ceil (pad_len / stride_x_downsample_samples ).astype (int )
226- probability_list .append (np .array (outputs .numpy ()[0 ,:- discard_len ,:]))
234+ if FLAGS .loss == 'autoencoder' :
235+ probability_list .append (np .array (outputs .numpy ()[0 ,:- discard_len ,0 ]))
236+ else :
237+ probability_list .append (np .array (outputs .numpy ()[0 ,:- discard_len ,:]))
227238 break
228239 else :
229- probability_list .append (np .array (outputs .numpy ()[0 ,:,:]))
240+ if FLAGS .loss == 'autoencoder' :
241+ probability_list .append (np .array (outputs .numpy ()[0 ,:,0 ]))
242+ else :
243+ probability_list .append (np .array (outputs .numpy ()[0 ,:,:]))
230244
231245 sample_rate = round (1 / stride_x_downsample_sec )
232246 if sample_rate != 1 / stride_x_downsample_sec :
233247 print ('WARNING: .wav files do not support fractional sampling rates!' )
234248
235249 probability_matrix = np .concatenate (probability_list )
236- if prevalences :
237- denominator = np .sum (probability_matrix * prevalences , axis = 1 )
238- for ch in range (len (labels )):
239- if prevalences :
240- adjusted_probability = probability_matrix [:,ch ] * prevalences [ch ]
241- adjusted_probability [npadding :] /= denominator [npadding :]
242- else :
243- adjusted_probability = probability_matrix [:,ch ]
244- waveform = adjusted_probability * np .iinfo (np .int16 ).max
245- withoutext = trim_ext (FLAGS .wav )
246- filename = withoutext + '-' + labels [ch ]+ '.wav'
247- wavfile .write (filename , int (sample_rate ), waveform .astype ('int16' ))
250+ if FLAGS .loss != 'autoencoder' :
251+ if prevalences :
252+ denominator = np .sum (probability_matrix * prevalences , axis = 1 )
253+ for ch in range (len (labels )):
254+ if prevalences :
255+ adjusted_probability = probability_matrix [:,ch ] * prevalences [ch ]
256+ adjusted_probability [npadding :] /= denominator [npadding :]
257+ else :
258+ adjusted_probability = probability_matrix [:,ch ]
259+ waveform = adjusted_probability * np .iinfo (np .int16 ).max
260+ withoutext = trim_ext (FLAGS .wav )
261+ filename = withoutext + '-' + labels [ch ]+ '.wav'
262+ wavfile .write (filename , int (sample_rate ), waveform .astype ('int16' ))
263+ else :
264+ waveform = probability_matrix * np .iinfo (np .int16 ).max
265+ withoutext = trim_ext (FLAGS .wav )
266+ filename = withoutext + '-.wav'
267+ wavfile .write (filename , int (sample_rate ), waveform .astype ('int16' ))
248268
249269if __name__ == '__main__' :
250270 parser = argparse .ArgumentParser (description = 'test_streaming_accuracy' )
@@ -271,8 +291,8 @@ if __name__ == '__main__':
271291 '--loss' ,
272292 type = str ,
273293 default = 'exclusive' ,
274- choices = ['exclusive' , 'overlapped' ],
275- help = 'Sigmoid cross entropy is used for "overlapped" labels while softmax cross entropy is used for "exclusive" labels.' )
294+ choices = ['exclusive' , 'overlapped' , 'autoencoder' ],
295+ help = 'Sigmoid cross entropy is used for "overlapped" or "autoencoder" labels while softmax cross entropy is used for "exclusive" labels.' )
276296 parser .add_argument (
277297 '--context' ,
278298 type = float ,
0 commit comments