Skip to content

Commit d45b4ba

Browse files
committed
fix tic-based congruence measurement
1 parent 14fd746 commit d45b4ba

4 files changed

Lines changed: 204 additions & 206 deletions

File tree

src/congruence

Lines changed: 87 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -41,97 +41,86 @@ srcdir, repodir, _ = get_srcrepobindirs()
4141

4242
def doit(intervals, do_tic, do_label):
4343

44-
#to calculate the intervals everyone agrees upon (i.e. "everyone"), choose
45-
#one of the sets at random. iterate through each interval therein, testing
46-
#whether it overlaps with any of the intervals in all of the other sets.
47-
#if it does, delete the matching intervals in the other sets, and add to
48-
#the "everyone" set just the intersection of the matching intervals.
49-
50-
intervals_copy = intervals.copy()
51-
key0 = next(iter(intervals_copy.keys()))
52-
everyone = P.empty()
53-
for interval0 in intervals_copy[key0]:
54-
ivalues = {}
55-
for keyN in set(intervals_copy.keys()) - set([key0]):
56-
for i,intervalN in enumerate(intervals_copy[keyN]):
57-
if len(interval0 & intervalN)>0:
58-
ivalues[keyN] = i
59-
break
60-
if keyN not in ivalues:
61-
break
62-
if len(ivalues)==len(intervals_copy)-1:
63-
for keyN in ivalues.keys():
64-
tmp = intervals_copy[keyN][ivalues[keyN]]
65-
interval0 &= tmp
66-
tmp |= P.open(tmp.lower-1, tmp.upper+1)
67-
intervals_copy[keyN] -= tmp
68-
everyone |= interval0
44+
key0 = next(iter(intervals.keys()))
45+
everyone = intervals[key0]
46+
for keyN in set(intervals.keys()) - set([key0]):
47+
everyone &= intervals[keyN]
6948

7049
#to calculate the intervals which only one set contains (e.g. "only
7150
#songexplorer"), iteratively test if each interval therein overlaps
7251
#with any of the other sets. if it does, delete the matching intervals
7352
#in the other sets; otherwise add this interval to the "only label" set.
74-
#for tics, delete from the interval the points in each matching interval
75-
#and add what remains to the "only tic" set.
7653

77-
onlyone_tic = {}
7854
onlyone_label = {}
7955
for key0 in intervals.keys():
80-
intervals_copy = intervals.copy()
81-
onlyone_tic[key0] = P.empty() if do_tic else None
82-
onlyone_label[key0] = P.empty() if do_label else None
83-
for interval0 in intervals_copy[key0]:
84-
ivalues = {}
85-
for keyN in set(intervals_copy.keys()) - set([key0]):
86-
for i,intervalN in enumerate(intervals_copy[keyN]):
87-
if len(interval0 & intervalN)>0:
88-
ivalues[keyN] = i
89-
break
90-
if do_label and len(ivalues)==0:
91-
onlyone_label[key0] |= interval0
92-
for keyN in ivalues.keys():
93-
tmp = intervals_copy[keyN][ivalues[keyN]]
94-
tmp |= P.open(tmp.lower-1, tmp.upper+1)
95-
if do_tic:
96-
interval0 -= tmp
97-
intervals_copy[keyN] -= tmp
56+
if do_label:
57+
onlyone_label[key0] = P.empty()
58+
intervals_copy = intervals.copy()
59+
for interval0 in intervals_copy[key0]:
60+
ivalues = {}
61+
for keyN in set(intervals_copy.keys()) - set([key0]):
62+
for i,intervalN in enumerate(intervals_copy[keyN]):
63+
if len(interval0 & intervalN)>0:
64+
ivalues[keyN] = i
65+
break
66+
if len(ivalues)==0:
67+
onlyone_label[key0] |= interval0
68+
for keyN in ivalues.keys():
69+
tmp = intervals_copy[keyN][ivalues[keyN]]
70+
tmp |= P.open(tmp.lower-1, tmp.upper+1)
71+
intervals_copy[keyN] -= tmp
72+
else:
73+
onlyone_label[key0] = None
74+
75+
onlyone_tic = {}
76+
for key0 in intervals.keys():
9877
if do_tic:
99-
onlyone_tic[key0] |= interval0
78+
onlyone_tic[key0] = intervals[key0]
79+
for keyN in set(intervals.keys()) - set([key0]):
80+
onlyone_tic[key0] -= intervals[keyN]
81+
else:
82+
onlyone_tic[key0] = None
10083

10184
#to calculate the intervals which only one set does not contain (e.g. "not
10285
#david"), choose one of the other sets at random. iteratively test whether
10386
#its intervals overlap with an interval in the rest of the other sets
10487
#but not with the set of interest. for those intervals which meet these
10588
#criteria, delete the matching intervals in the rest of the other sets,
106-
#and add this interval to the "not" set. for tics, add to the "not tic"
107-
#set the intersection of all the matching intervals.
89+
#and add this interval to the "not" set.
10890

109-
notone_tic = {}
11091
notone_label = {}
11192
for key0 in intervals.keys():
112-
intervals_copy = intervals.copy()
113-
notone_tic[key0] = P.empty() if do_tic else None
114-
notone_label[key0] = P.empty() if do_label else None
115-
key1 = next(iter(set(intervals_copy.keys()) - set([key0])))
116-
for interval1 in intervals_copy[key1]:
117-
ivalues = {}
118-
for keyN in set(intervals_copy.keys()) - set([key1]):
119-
for i,intervalN in enumerate(intervals_copy[keyN]):
120-
if len(interval1 & intervalN)>0:
121-
ivalues[keyN] = i
122-
break
123-
if len(ivalues)==len(intervals_copy)-2 and key0 not in ivalues.keys():
124-
if do_label:
125-
notone_label[key0] |= interval1
126-
for keyN in ivalues.keys():
127-
tmp = intervals_copy[keyN][ivalues[keyN]]
128-
if do_tic:
129-
interval1 &= tmp
130-
tmp |= P.open(tmp.lower-1, tmp.upper+1)
131-
intervals_copy[keyN] -= tmp
132-
if do_tic:
133-
notone_tic[key0] |= interval1
134-
93+
if do_label:
94+
notone_label[key0] = P.empty()
95+
intervals_copy = intervals.copy()
96+
key1 = next(iter(set(intervals_copy.keys()) - set([key0])))
97+
for interval1 in intervals_copy[key1]:
98+
ivalues = {}
99+
for keyN in set(intervals_copy.keys()) - set([key1]):
100+
for i,intervalN in enumerate(intervals_copy[keyN]):
101+
if len(interval1 & intervalN)>0:
102+
ivalues[keyN] = i
103+
break
104+
if len(ivalues)==len(intervals_copy)-2 and key0 not in ivalues.keys():
105+
notone_label[key0] |= interval1
106+
for keyN in ivalues.keys():
107+
tmp = intervals_copy[keyN][ivalues[keyN]]
108+
tmp |= P.open(tmp.lower-1, tmp.upper+1)
109+
intervals_copy[keyN] -= tmp
110+
else:
111+
notone_label[key0] = None
112+
113+
notone_tic = {}
114+
for key0 in intervals.keys():
115+
if do_tic:
116+
key1 = next(iter(set(intervals.keys()) - set([key0])))
117+
notone_tic[key0] = intervals[key1]
118+
for keyN in set(intervals.keys()) - set([key0,key1]):
119+
notone_tic[key0] &= intervals[keyN]
120+
notone_tic[key0] -= intervals[key0]
121+
else:
122+
notone_tic[key0] = None
123+
135124
return everyone, onlyone_tic, notone_tic, onlyone_label, notone_label
136125

137126
FLAGS = None
@@ -396,6 +385,9 @@ def main():
396385
ax.set_xticklabels(xdata, rotation=40, ha='right')
397386
ax.set_title('all files', fontsize=8)
398387

388+
def tic_dist(x):
389+
return x.upper - x.lower - 1 + (x.left==P.CLOSED) + (x.right==P.CLOSED)
390+
399391
for pr in precision_recalls:
400392
print('P/R = '+pr)
401393
for label in timestamps.keys():
@@ -436,10 +428,10 @@ def main():
436428

437429
if do_tic:
438430
plot_file(fig_tic, fig_tic_venn,
439-
[sum([x.upper-x.lower+1 for x in everyone[pr][label][csvbase]]), \
440-
*[sum([y.upper-y.lower+1 for y in onlyone_tic[pr][label][csvbase][x]]) \
431+
[sum([tic_dist(x) for x in everyone[pr][label][csvbase]]), \
432+
*[sum([tic_dist(y) for y in onlyone_tic[pr][label][csvbase][x]]) \
441433
for x in sorted_hm]],
442-
[sum([y.upper-y.lower+1 for y in notone_tic[pr][label][csvbase][x]]) \
434+
[sum([tic_dist(y) for y in notone_tic[pr][label][csvbase][x]]) \
443435
for x in sorted_hm] if len(sorted_hm)>2 else None)
444436
if do_label:
445437
plot_file(fig_label, fig_label_venn,
@@ -456,11 +448,11 @@ def main():
456448
csvbase0 = list(onlyone_tic[pr][label].keys())[0]
457449
if do_tic:
458450
plot_sumfiles(fig_tic, fig_tic_venn,
459-
[sum([x.upper-x.lower+1 for f in everyone[pr][label].values() for x in f]),
460-
*[sum([sum([y.upper-y.lower+1 for y in f[hm]])
451+
[sum([tic_dist(x) for f in everyone[pr][label].values() for x in f]),
452+
*[sum([sum([tic_dist(y) for y in f[hm]])
461453
for f in onlyone_tic[pr][label].values()])
462454
for hm in sorted_hm]],
463-
[sum([sum([y.upper-y.lower+1 for y in f[hm]]) \
455+
[sum([sum([tic_dist(y) for y in f[hm]]) \
464456
for f in notone_tic[pr][label].values()])
465457
for hm in onlyone_tic[pr][label][csvbase0].keys()]
466458
if len(sorted_hm)>2 else None)
@@ -504,8 +496,11 @@ def main():
504496
csvwriter = csv.writer(fid, lineterminator='\n')
505497
for ilabel,label in enumerate(timestamps.keys()):
506498
for i in intervals[ilabel]:
499+
if tic_dist(i)==0: continue
507500
csvwriter.writerow([os.path.basename(csvbase),
508-
int(i.lower), int(i.upper), whichset, label])
501+
int(i.lower)+(i.left==P.OPEN),
502+
int(i.upper)-(i.right==P.OPEN),
503+
whichset, label])
509504

510505
for pr in filter(lambda x: x not in thresholds, precision_recalls):
511506
for csvbase in csvbases:
@@ -552,19 +547,19 @@ def main():
552547
for hm in sorted_hm:
553548
key = 'only '+hm
554549
if do_tic:
555-
roc_table_tic[label][pr][key] = int(sum([sum([y.upper-y.lower+1 for y in f[hm]]) \
550+
roc_table_tic[label][pr][key] = int(sum([sum([tic_dist(y) for y in f[hm]]) \
556551
for f in onlyone_tic[pr][label].values()]))
557552
if do_label:
558553
roc_table_label[label][pr][key] = sum([len(f[hm]) for f in onlyone_label[pr][label].values()])
559554
if len(sorted_hm)>2:
560555
key = 'not '+hm
561556
if do_tic:
562-
roc_table_tic[label][pr][key] = int(sum([sum([y.upper-y.lower+1 for y in f[hm]]) \
557+
roc_table_tic[label][pr][key] = int(sum([sum([tic_dist(y) for y in f[hm]]) \
563558
for f in notone_tic[pr][label].values()]))
564559
if do_label:
565560
roc_table_label[label][pr][key] = sum([len(f[hm])
566561
for f in notone_label[pr][label].values()])
567-
roc_table_tic[label][pr]['Everyone'] = int(sum([x.upper-x.lower+1
562+
roc_table_tic[label][pr]['Everyone'] = int(sum([tic_dist(x)
568563
for f in everyone[pr][label].values()
569564
for x in f]))
570565
roc_table_label[label][pr]['Everyone'] = sum([len(f) for f in everyone[pr][label].values()])
@@ -666,13 +661,18 @@ def main():
666661
return thresholds_touse
667662

668663
if do_tic:
669-
plot_versus_thresholds(roc_table_tic, measure='tic')
664+
thresholds_touse = plot_versus_thresholds(roc_table_tic, measure='tic')
665+
if len(thresholds_touse)>0:
666+
save_thresholds(logdir, model, ckpt, thresholds_touse, precision_recalls_sparse,
667+
list(thresholds_touse.keys()),
668+
'-dense-tic-'+datetime.strftime(datetime.now(),'%Y%m%dT%H%M%S'))
670669
if do_label:
671670
thresholds_touse = plot_versus_thresholds(roc_table_label, measure='label')
671+
if len(thresholds_touse)>0:
672+
save_thresholds(logdir, model, ckpt, thresholds_touse, precision_recalls_sparse,
673+
list(thresholds_touse.keys()),
674+
'-dense-label-'+datetime.strftime(datetime.now(),'%Y%m%dT%H%M%S'))
672675

673-
if len(thresholds_touse)>0:
674-
save_thresholds(logdir, model, ckpt, thresholds_touse, precision_recalls_sparse,
675-
list(thresholds_touse.keys()), True)
676676

677677
if __name__ == "__main__":
678678
parser = argparse.ArgumentParser()

src/lib.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -673,10 +673,8 @@ def read_thresholds(logdir, model, thresholds_file):
673673
thresholds.append(row)
674674
return precision_recall_ratios, thresholds
675675

676-
def save_thresholds(logdir, model, ckpt, thresholds, ratios, labels, dense=False):
677-
filename = 'thresholds'+\
678-
('-dense-'+datetime.strftime(datetime.now(),'%Y%m%dT%H%M%S') if dense else '')+\
679-
'.ckpt-'+str(ckpt)+'.csv'
676+
def save_thresholds(logdir, model, ckpt, thresholds, ratios, labels, dense=''):
677+
filename = 'thresholds'+dense+'.ckpt-'+str(ckpt)+'.csv'
680678
fid = open(os.path.join(logdir,model,filename),"w")
681679
fidcsv = csv.writer(fid, lineterminator='\n')
682680
fidcsv.writerow(['precision/recall'] + ratios)

0 commit comments

Comments
 (0)