Skip to content

Commit a4a3ebd

Browse files
committed
recordings can now be arbitrarily arranged in the groundtruth folder
1 parent b317926 commit a4a3ebd

8 files changed

Lines changed: 177 additions & 155 deletions

File tree

src/accuracy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def doit(logdir, key_to_plot, ckpt, labels, nprobabilities, error_ratios, loss,
163163
if loss=='exclusive':
164164
for subdir in set([x['file'][0] for x in validation_sounds]):
165165
with open(os.path.join(logdir, key_to_plot, 'predictions.ckpt-'+str(ckpt), \
166-
subdir+'-mistakes.csv'), \
166+
subdir.replace(os.path.sep,'-')+'-mistakes.csv'), \
167167
'w', newline='') as csvfile:
168168
csvwriter = csv.writer(csvfile, lineterminator='\n')
169169
for i in range(len(validation_sounds)):

src/congruence

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# e.g. congruence \
66
# --basepath=/groups/stern/sternlab/behavior/arthurb/groundtruth/kyriacou2017 \
7+
# --topath=/groups/stern/sternlab/behavior/arthurb/groundtruth/kyriacou2017/congruence-20240718T091400 \
78
# --wavfiles=PS_20130625111709_ch3.wav,PS_20130625111709_ch7.wav \
89
# --portion=union \
910
# --convolve_ms=0 \
@@ -146,11 +147,17 @@ def main():
146147
convolve_tic = int(FLAGS.convolve_ms/2/1000*FLAGS.audio_tic_rate)
147148

148149
wavdirs = {}
149-
for subdir in filter(lambda x: os.path.isdir(os.path.join(FLAGS.basepath,x)), \
150-
os.listdir(FLAGS.basepath)):
151-
commonsubfiles = wavfiles & set(os.listdir(os.path.join(FLAGS.basepath, subdir)))
152-
if len(commonsubfiles) > 0:
153-
wavdirs[subdir] = commonsubfiles
150+
def traverse(curdir):
151+
entries = set()
152+
for entry in os.listdir(os.path.join(FLAGS.basepath, curdir)):
153+
if os.path.isdir(os.path.join(FLAGS.basepath, curdir, entry)):
154+
traverse(os.path.join(curdir, entry))
155+
else:
156+
entries |= set([entry])
157+
commonsubfiles = wavfiles & entries
158+
if len(commonsubfiles) > 0:
159+
wavdirs[curdir] = commonsubfiles
160+
traverse("")
154161

155162
labels=None
156163
temp_files=[]
@@ -209,6 +216,7 @@ def main():
209216
annotator_keys = set()
210217

211218
for wavdir in wavdirs:
219+
os.makedirs(os.path.join(FLAGS.topath, wavdir), exist_ok=True)
212220
for csvfile in filter(lambda x: ("-annotated-" in x or "-predicted-" in x) and
213221
x.endswith('.csv'), \
214222
os.listdir(os.path.join(FLAGS.basepath,wavdir))):
@@ -468,31 +476,31 @@ def main():
468476
if do_tic:
469477
fig_tic.tight_layout()
470478
plt.figure(fig_tic.number)
471-
plt.savefig(os.path.join(FLAGS.basepath, 'congruence.tic.'+label+'.'+pr+'.pdf'))
479+
plt.savefig(os.path.join(FLAGS.topath, 'congruence.tic.'+label+'.'+pr+'.pdf'))
472480
plt.close()
473481
if do_label:
474482
fig_label.tight_layout()
475483
plt.figure(fig_label.number)
476-
plt.savefig(os.path.join(FLAGS.basepath, 'congruence.label.'+label+'.'+pr+'.pdf'))
484+
plt.savefig(os.path.join(FLAGS.topath, 'congruence.label.'+label+'.'+pr+'.pdf'))
477485
plt.close()
478486
if len(sorted_hm)<4:
479487
if do_tic:
480488
fig_tic_venn.tight_layout()
481489
plt.figure(fig_tic_venn.number)
482-
plt.savefig(os.path.join(FLAGS.basepath, 'congruence.tic.'+label+'.'+pr+'-venn.pdf'))
490+
plt.savefig(os.path.join(FLAGS.topath, 'congruence.tic.'+label+'.'+pr+'-venn.pdf'))
483491
plt.close()
484492
if do_label:
485493
fig_label_venn.tight_layout()
486494
plt.figure(fig_label_venn.number)
487-
plt.savefig(os.path.join(FLAGS.basepath, 'congruence.label.'+label+'.'+pr+'-venn.pdf'))
495+
plt.savefig(os.path.join(FLAGS.topath, 'congruence.label.'+label+'.'+pr+'-venn.pdf'))
488496
plt.close()
489497

490498
if FLAGS.parallelize!=0:
491499
pool.close()
492500

493501
def to_csv(intervals, csvbase, whichset):
494502
filename = os.path.splitext(csvbase)[0]+'-disjoint-'+whichset+'.csv'
495-
with open(os.path.join(FLAGS.basepath,filename), 'w') as fid:
503+
with open(os.path.join(FLAGS.topath,filename), 'w') as fid:
496504
csvwriter = csv.writer(fid, lineterminator='\n')
497505
for ilabel,label in enumerate(timestamps.keys()):
498506
for i in intervals[ilabel]:
@@ -627,7 +635,7 @@ def main():
627635
ax2.legend(loc=(1.05, 0.0))
628636
ax1.legend(loc=(1.2, 0.1))
629637
fig.tight_layout()
630-
plt.savefig(os.path.join(FLAGS.basepath,'congruence.'+measure+'.'+label+'.pdf'))
638+
plt.savefig(os.path.join(FLAGS.topath,'congruence.'+measure+'.'+label+'.pdf'))
631639
plt.close()
632640

633641
inotnan = (~np.isnan(P) & ~np.isnan(R)).nonzero()[0]
@@ -637,7 +645,7 @@ def main():
637645
else:
638646
print(measure+' '+label+' area cannot be computed because recall is not monotonic')
639647

640-
with open(os.path.join(FLAGS.basepath,'congruence.'+measure+'.'+label+'.csv'), 'w') as fid:
648+
with open(os.path.join(FLAGS.topath,'congruence.'+measure+'.'+label+'.csv'), 'w') as fid:
641649
csvwriter = csv.writer(fid, lineterminator='\n')
642650
rows = roc_table[label].keys()
643651
cols = roc_table[label][next(iter(rows-thresholds))].keys()
@@ -671,6 +679,9 @@ if __name__ == "__main__":
671679
parser.add_argument(
672680
'--basepath',
673681
type=str)
682+
parser.add_argument(
683+
'--topath',
684+
type=str)
674685
parser.add_argument(
675686
'--wavfiles',
676687
type=str)

src/data.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,16 +199,19 @@ def prepare_data_index(self,
199199
video_frame_height = model_settings['video_frame_height']
200200
video_channels = model_settings['video_channels']
201201
shiftby_tics = int(shiftby_ms * model_settings["audio_tic_rate"] / 1000)
202-
search_path = os.path.join(self.data_dir, '*', '*.csv')
203202
audio_ntics = {}
204203
video_nframes = {}
205204
subsample = {x:int(y) for x,y in zip(subsample_label.split(','),subsample_skip.split(','))
206205
if x != ''}
207206
partition_labels = partition_label.split(',')
208207
if '' in partition_labels:
209208
partition_labels.remove('')
210-
for csv_path in glob(search_path):
211-
with (open(csv_path, 'r')) as csv_file:
209+
for csv_path in glob("**/*.csv", root_dir=self.data_dir, recursive=True):
210+
csv_dir = os.path.dirname(csv_path)
211+
if re.fullmatch('congruence-[0-9]{8}T[0-9]{6}', csv_dir) or \
212+
re.fullmatch('oldfiles-[0-9]{8}T[0-9]{6}', csv_dir):
213+
continue
214+
with (open(os.path.join(self.data_dir, csv_path), 'r')) as csv_file:
212215
annotation_reader = csv.reader(csv_file)
213216
annotation_list = list(annotation_reader)
214217
if len(partition_labels)>0:
@@ -226,8 +229,8 @@ def prepare_data_index(self,
226229
if (label if loss=='exclusive' else
227230
label.removeprefix(overlapped_prefix)) not in labels_touse:
228231
continue
229-
wav_path=os.path.join(os.path.dirname(csv_path),wavfile)
230-
wav_base2=[os.path.basename(os.path.dirname(csv_path)), wavfile]
232+
wav_path = os.path.join(self.data_dir, os.path.dirname(csv_path), wavfile)
233+
wav_base2 = [os.path.dirname(csv_path), wavfile]
231234
if wavfile in validation_files:
232235
set_index = 'validation'
233236
elif wavfile in testing_files:

src/gui/controller.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,16 +1027,6 @@ async def misses_actuate():
10271027
misses_succeeded(w, t)))
10281028
asyncio.create_task(actuate_finalize(threads, results, V.groundtruth_update))
10291029

1030-
def isoldfile(x,subdir,basewavs):
1031-
return \
1032-
np.any([x.startswith(b+'-') and x.endswith('.wav') for b in basewavs]) or \
1033-
x.endswith('-classify.log') or \
1034-
'-predicted' in x or \
1035-
x.endswith('-ethogram.log') or \
1036-
'-missed' in x or \
1037-
x.endswith('-misses.log') or \
1038-
x == subdir+'.csv'
1039-
10401030
def _validation_test_files(files_string, comma=True):
10411031
if files_string.rstrip(os.sep) == V.groundtruth_folder.value.rstrip(os.sep):
10421032
dfs = []
@@ -1137,31 +1127,37 @@ def sequester_stalefiles():
11371127
M.annotated_csvfiles_all=set([])
11381128
for button in V.nsounds_per_label_buttons:
11391129
button.label = str(0)
1140-
for subdir in filter(lambda x: os.path.isdir(os.path.join(V.groundtruth_folder.value,x)), \
1141-
os.listdir(V.groundtruth_folder.value)):
1130+
1131+
def isoldfile(x,curdir,basewavs):
1132+
return \
1133+
np.any([x.startswith(b+'-') and x.endswith('.wav') for b in basewavs]) or \
1134+
x.endswith('-classify.log') or \
1135+
'-predicted' in x or \
1136+
x.endswith('-ethogram.log') or \
1137+
'-missed' in x or \
1138+
x.endswith('-misses.log') or \
1139+
x == curdir+'.csv'
1140+
1141+
def _sequester(curdir):
11421142
dfs = []
1143-
for csvfile in filter(lambda x: '-annotated-' in x and x.endswith('.csv'), \
1144-
os.listdir(os.path.join(V.groundtruth_folder.value, \
1145-
subdir))):
1146-
filepath = os.path.join(V.groundtruth_folder.value, subdir, csvfile)
1147-
if os.path.getsize(filepath) > 0:
1148-
dfs.append(pd.read_csv(filepath, header=None, index_col=False))
1143+
for entry in os.listdir(curdir):
1144+
if os.path.isdir(os.path.join(curdir, entry)):
1145+
_sequester(os.path.join(curdir, entry))
1146+
elif '-annotated-' in entry and entry.endswith('.csv'):
1147+
filepath = os.path.join(curdir, entry)
1148+
if os.path.getsize(filepath) > 0:
1149+
dfs.append(pd.read_csv(filepath, header=None, index_col=False))
11491150
if dfs:
11501151
df = pd.concat(dfs)
11511152
basewavs = set([os.path.splitext(x)[0] for x in df[0]])
1152-
oldfiles = []
1153-
for oldfile in filter(lambda x: isoldfile(x,subdir,basewavs), \
1154-
os.listdir(os.path.join(V.groundtruth_folder.value, \
1155-
subdir))):
1156-
oldfiles.append(oldfile)
1153+
oldfiles = [x for x in os.listdir(curdir) if isoldfile(x, curdir, basewavs)]
11571154
if len(oldfiles)>0:
1158-
topath = os.path.join(V.groundtruth_folder.value, \
1159-
subdir, \
1160-
'oldfiles-'+M.songexplorer_starttime)
1155+
topath = os.path.join(curdir, 'oldfiles-'+M.songexplorer_starttime)
11611156
os.mkdir(topath)
11621157
for oldfile in oldfiles:
1163-
os.rename(os.path.join(V.groundtruth_folder.value, subdir, oldfile), \
1164-
os.path.join(topath, oldfile))
1158+
os.rename(os.path.join(curdir, oldfile), os.path.join(topath, oldfile))
1159+
1160+
_sequester(V.groundtruth_folder.value)
11651161
V.groundtruth_update()
11661162

11671163
async def train_actuate():
@@ -2006,14 +2002,19 @@ async def congruence_actuate():
20062002
all_files = validation_files + test_files
20072003
if '' in all_files:
20082004
all_files.remove('')
2009-
logfile = os.path.join(V.groundtruth_folder.value,'congruence.log')
2005+
timestamp = datetime.strftime(datetime.now(),'%Y%m%dT%H%M%S')
2006+
congruence_folder = os.path.join(V.groundtruth_folder.value, 'congruence-'+timestamp)
2007+
os.mkdir(congruence_folder)
2008+
logfile = os.path.join(congruence_folder, 'congruence.log')
20102009
jobid = generic_actuate("congruence", logfile,
20112010
M.congruence_where,
20122011
M.congruence_ncpu_cores,
20132012
M.congruence_ngpu_cards,
20142013
M.congruence_ngigabytes_memory,
20152014
M.congruence_cluster_flags,
20162015
"--basepath="+V.groundtruth_folder.value,
2016+
"--topath="+os.path.join(V.groundtruth_folder.value,
2017+
'congruence-'+timestamp),
20172018
"--wavfiles="+','.join(all_files),
20182019
"--portion="+V.congruence_portion.value,
20192020
"--convolve_ms="+V.congruence_convolve.value,
@@ -2031,7 +2032,7 @@ async def congruence_actuate():
20312032
threads[0] = asyncio.create_task(actuate_monitor(displaystring, results, 0, \
20322033
lambda l=logfile, t=currtime: recent_file_exists(l, t, False), \
20332034
lambda l=logfile: contains_two_timestamps(l), \
2034-
lambda l=V.groundtruth_folder.value, t=currtime, r=regex_files,
2035+
lambda l=congruence_folder, t=currtime, r=regex_files,
20352036
m=V.congruence_measure.value: congruence_succeeded(l, t, r, m)))
20362037
asyncio.create_task(actuate_finalize(threads, results, V.groundtruth_update))
20372038

src/gui/view.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,21 +1398,27 @@ def labelcounts_update():
13981398
if not os.path.isdir(groundtruth_folder.value):
13991399
labelcounts.text = ""
14001400
return dfs, subdirs
1401-
for subdir in filter(lambda x: os.path.isdir(os.path.join(groundtruth_folder.value,x)), \
1402-
os.listdir(groundtruth_folder.value)):
1403-
for csvfile in filter(lambda x: x.endswith('.csv'), \
1404-
os.listdir(os.path.join(groundtruth_folder.value, subdir))):
1405-
filepath = os.path.join(groundtruth_folder.value, subdir, csvfile)
1406-
if os.path.getsize(filepath) > 0:
1407-
try:
1408-
df = pd.read_csv(filepath, header=None, index_col=False)
1409-
except:
1410-
bokehlog.info("WARNING: "+csvfile+" is not in the correct format")
1411-
if 5<=len(df.columns)<=6:
1412-
dfs.append(df)
1413-
subdirs.append(subdir)
1414-
else:
1415-
bokehlog.info("WARNING: "+csvfile+" is not in the correct format")
1401+
1402+
def _labelcounts_update(curdir):
1403+
for entry in os.listdir(curdir):
1404+
if os.path.isdir(os.path.join(curdir, entry)):
1405+
timestamp = datetime.strftime(datetime.now(),'%Y')
1406+
if "congruence-"+timestamp not in entry and "oldfiles-"+timestamp not in entry:
1407+
_labelcounts_update(os.path.join(curdir, entry))
1408+
elif entry.endswith('.csv'):
1409+
filepath = os.path.join(curdir, entry)
1410+
if os.path.getsize(filepath) > 0:
1411+
try:
1412+
df = pd.read_csv(filepath, header=None, index_col=False)
1413+
except:
1414+
bokehlog.info("WARNING: "+entry+" is not in the correct format")
1415+
if 5<=len(df.columns)<=6:
1416+
dfs.append(df)
1417+
subdirs.append(curdir[len(groundtruth_folder.value):])
1418+
else:
1419+
bokehlog.info("WARNING: "+entry+" is not in the correct format")
1420+
_labelcounts_update(groundtruth_folder.value)
1421+
14161422
if dfs:
14171423
df = pd.concat(dfs)
14181424
M.kinds = sorted(set(df[3]))

test/runtests

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ if os.name == "posix":
5151
os.path.join("nfeaturesexclusive-32", "xvalidate_1k", "thresholds.ckpt-300.csv"),
5252
os.path.join("nfeaturesexclusive-64", "xvalidate_1k", "thresholds.ckpt-30.csv"),
5353
os.path.join("nfeaturesexclusive-64", "xvalidate_1k", "thresholds.ckpt-300.csv"),
54-
os.path.join("groundtruth-data", "congruence.tic.ambient.csv"),
55-
os.path.join("groundtruth-data", "congruence.tic.mel-pulse.csv"),
56-
os.path.join("groundtruth-data", "congruence.tic.mel-sine.csv"),
57-
os.path.join("groundtruth-data", "congruence.label.ambient.csv"),
58-
os.path.join("groundtruth-data", "congruence.label.mel-pulse.csv"),
59-
os.path.join("groundtruth-data", "congruence.label.mel-sine.csv")
54+
os.path.join("groundtruth-data", "congruence-11112233T445566", "congruence.tic.ambient.csv"),
55+
os.path.join("groundtruth-data", "congruence-11112233T445566", "congruence.tic.mel-pulse.csv"),
56+
os.path.join("groundtruth-data", "congruence-11112233T445566", "congruence.tic.mel-sine.csv"),
57+
os.path.join("groundtruth-data", "congruence-11112233T445566", "congruence.label.ambient.csv"),
58+
os.path.join("groundtruth-data", "congruence-11112233T445566", "congruence.label.mel-pulse.csv"),
59+
os.path.join("groundtruth-data", "congruence-11112233T445566", "congruence.label.mel-sine.csv")
6060
]
6161
for file in files:
6262
if not cmp(os.path.join(repo_path, "test", "scratch", "tutorial-sh", file),

0 commit comments

Comments
 (0)