Skip to content

Commit 22b9b5c

Browse files
committed
fix tic-based congruence measurement
1 parent c8cd4e5 commit 22b9b5c

2 files changed

Lines changed: 75 additions & 81 deletions

File tree

src/congruence

Lines changed: 73 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -41,97 +41,88 @@ srcdir, repodir, _ = get_srcrepobindirs()
4141

4242
def doit(intervals, do_tic, do_label):
4343

44-
#to calculate the intervals everyone agrees upon (i.e. "everyone"), choose
45-
#one of the sets at random. iterate through each interval therein, testing
46-
#whether it overlaps with any of the intervals in all of the other sets.
47-
#if it does, delete the matching intervals in the other sets, and add to
48-
#the "everyone" set just the intersection of the matching intervals.
49-
50-
intervals_copy = intervals.copy()
51-
key0 = next(iter(intervals_copy.keys()))
52-
everyone = P.empty()
53-
for interval0 in intervals_copy[key0]:
54-
ivalues = {}
55-
for keyN in set(intervals_copy.keys()) - set([key0]):
56-
for i,intervalN in enumerate(intervals_copy[keyN]):
57-
if len(interval0 & intervalN)>0:
58-
ivalues[keyN] = i
59-
break
60-
if keyN not in ivalues:
61-
break
62-
if len(ivalues)==len(intervals_copy)-1:
63-
for keyN in ivalues.keys():
64-
tmp = intervals_copy[keyN][ivalues[keyN]]
65-
interval0 &= tmp
66-
tmp |= P.open(tmp.lower-1, tmp.upper+1)
67-
intervals_copy[keyN] -= tmp
68-
everyone |= interval0
44+
key0 = next(iter(intervals.keys()))
45+
everyone = intervals[key0]
46+
for keyN in set(intervals.keys()) - set([key0]):
47+
everyone &= intervals[keyN]
6948

7049
#to calculate the intervals which only one set contains (e.g. "only
7150
#songexplorer"), iteratively test if each interval therein overlaps
7251
#with any of the other sets. if it does, delete the matching intervals
7352
#in the other sets; otherwise add this interval to the "only label" set.
74-
#for tics, delete from the interval the points in each matching interval
75-
#and add what remains to the "only tic" set.
7653

77-
onlyone_tic = {}
7854
onlyone_label = {}
7955
for key0 in intervals.keys():
80-
intervals_copy = intervals.copy()
81-
onlyone_tic[key0] = P.empty() if do_tic else None
82-
onlyone_label[key0] = P.empty() if do_label else None
83-
for interval0 in intervals_copy[key0]:
84-
ivalues = {}
85-
for keyN in set(intervals_copy.keys()) - set([key0]):
86-
for i,intervalN in enumerate(intervals_copy[keyN]):
87-
if len(interval0 & intervalN)>0:
88-
ivalues[keyN] = i
89-
break
90-
if do_label and len(ivalues)==0:
91-
onlyone_label[key0] |= interval0
92-
for keyN in ivalues.keys():
93-
tmp = intervals_copy[keyN][ivalues[keyN]]
94-
tmp |= P.open(tmp.lower-1, tmp.upper+1)
95-
if do_tic:
96-
interval0 -= tmp
97-
intervals_copy[keyN] -= tmp
56+
if do_label:
57+
onlyone_label[key0] = P.empty()
58+
intervals_copy = intervals.copy()
59+
for interval0 in intervals_copy[key0]:
60+
ivalues = {}
61+
for keyN in set(intervals_copy.keys()) - set([key0]):
62+
for i,intervalN in enumerate(intervals_copy[keyN]):
63+
if len(interval0 & intervalN)>0:
64+
ivalues[keyN] = i
65+
break
66+
if len(ivalues)==0:
67+
onlyone_label[key0] |= interval0
68+
for keyN in ivalues.keys():
69+
tmp = intervals_copy[keyN][ivalues[keyN]]
70+
tmp |= P.open(tmp.lower-1, tmp.upper+1)
71+
if do_tic:
72+
interval0 -= tmp
73+
intervals_copy[keyN] -= tmp
74+
else:
75+
onlyone_label[key0] = None
76+
77+
onlyone_tic = {}
78+
for key0 in intervals.keys():
9879
if do_tic:
99-
onlyone_tic[key0] |= interval0
80+
onlyone_tic[key0] = intervals[key0]
81+
for keyN in set(intervals.keys()) - set([key0]):
82+
onlyone_tic[key0] -= intervals[keyN]
83+
else:
84+
onlyone_tic[key0] = None
10085

10186
#to calculate the intervals which only one set does not contain (e.g. "not
10287
#david"), choose one of the other sets at random. iteratively test whether
10388
#its intervals overlap with an interval in the rest of the other sets
10489
#but not with the set of interest. for those intervals which meet these
10590
#criteria, delete the matching intervals in the rest of the other sets,
106-
#and add this interval to the "not" set. for tics, add to the "not tic"
107-
#set the intersection of all the matching intervals.
91+
#and add this interval to the "not" set.
10892

109-
notone_tic = {}
11093
notone_label = {}
11194
for key0 in intervals.keys():
112-
intervals_copy = intervals.copy()
113-
notone_tic[key0] = P.empty() if do_tic else None
114-
notone_label[key0] = P.empty() if do_label else None
115-
key1 = next(iter(set(intervals_copy.keys()) - set([key0])))
116-
for interval1 in intervals_copy[key1]:
117-
ivalues = {}
118-
for keyN in set(intervals_copy.keys()) - set([key1]):
119-
for i,intervalN in enumerate(intervals_copy[keyN]):
120-
if len(interval1 & intervalN)>0:
121-
ivalues[keyN] = i
122-
break
123-
if len(ivalues)==len(intervals_copy)-2 and key0 not in ivalues.keys():
124-
if do_label:
125-
notone_label[key0] |= interval1
126-
for keyN in ivalues.keys():
127-
tmp = intervals_copy[keyN][ivalues[keyN]]
128-
if do_tic:
129-
interval1 &= tmp
130-
tmp |= P.open(tmp.lower-1, tmp.upper+1)
131-
intervals_copy[keyN] -= tmp
132-
if do_tic:
133-
notone_tic[key0] |= interval1
134-
95+
if do_label:
96+
notone_label[key0] = P.empty()
97+
intervals_copy = intervals.copy()
98+
key1 = next(iter(set(intervals_copy.keys()) - set([key0])))
99+
for interval1 in intervals_copy[key1]:
100+
ivalues = {}
101+
for keyN in set(intervals_copy.keys()) - set([key1]):
102+
for i,intervalN in enumerate(intervals_copy[keyN]):
103+
if len(interval1 & intervalN)>0:
104+
ivalues[keyN] = i
105+
break
106+
if len(ivalues)==len(intervals_copy)-2 and key0 not in ivalues.keys():
107+
notone_label[key0] |= interval1
108+
for keyN in ivalues.keys():
109+
tmp = intervals_copy[keyN][ivalues[keyN]]
110+
tmp |= P.open(tmp.lower-1, tmp.upper+1)
111+
intervals_copy[keyN] -= tmp
112+
else:
113+
notone_label[key0] = None
114+
115+
notone_tic = {}
116+
for key0 in intervals.keys():
117+
if do_tic:
118+
key1 = next(iter(set(intervals.keys()) - set([key0])))
119+
notone_tic[key0] = intervals[key1]
120+
for keyN in set(intervals.keys()) - set([key0,key1]):
121+
notone_tic[key0] &= intervals[keyN]
122+
notone_tic[key0] -= intervals[key0]
123+
else:
124+
notone_tic[key0] = None
125+
135126
return everyone, onlyone_tic, notone_tic, onlyone_label, notone_label
136127

137128
FLAGS = None
@@ -666,13 +657,18 @@ def main():
666657
return thresholds_touse
667658

668659
if do_tic:
669-
plot_versus_thresholds(roc_table_tic, measure='tic')
660+
thresholds_touse = plot_versus_thresholds(roc_table_tic, measure='tic')
661+
if len(thresholds_touse)>0:
662+
save_thresholds(logdir, model, ckpt, thresholds_touse, precision_recalls_sparse,
663+
list(thresholds_touse.keys()),
664+
'-dense-tic-'+datetime.strftime(datetime.now(),'%Y%m%dT%H%M%S'))
670665
if do_label:
671666
thresholds_touse = plot_versus_thresholds(roc_table_label, measure='label')
667+
if len(thresholds_touse)>0:
668+
save_thresholds(logdir, model, ckpt, thresholds_touse, precision_recalls_sparse,
669+
list(thresholds_touse.keys()),
670+
'-dense-label-'+datetime.strftime(datetime.now(),'%Y%m%dT%H%M%S'))
672671

673-
if len(thresholds_touse)>0:
674-
save_thresholds(logdir, model, ckpt, thresholds_touse, precision_recalls_sparse,
675-
list(thresholds_touse.keys()), True)
676672

677673
if __name__ == "__main__":
678674
parser = argparse.ArgumentParser()

src/lib.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -673,10 +673,8 @@ def read_thresholds(logdir, model, thresholds_file):
673673
thresholds.append(row)
674674
return precision_recall_ratios, thresholds
675675

676-
def save_thresholds(logdir, model, ckpt, thresholds, ratios, labels, dense=False):
677-
filename = 'thresholds'+\
678-
('-dense-'+datetime.strftime(datetime.now(),'%Y%m%dT%H%M%S') if dense else '')+\
679-
'.ckpt-'+str(ckpt)+'.csv'
676+
def save_thresholds(logdir, model, ckpt, thresholds, ratios, labels, dense=''):
677+
filename = 'thresholds'+dense+'.ckpt-'+str(ckpt)+'.csv'
680678
fid = open(os.path.join(logdir,model,filename),"w")
681679
fidcsv = csv.writer(fid, lineterminator='\n')
682680
fidcsv.writerow(['precision/recall'] + ratios)

0 commit comments

Comments
 (0)