Skip to content

Commit f1d7ed0

Browse files
committed
refactor data.py into module instead of class thereby fixing multiprocessing on windows
1 parent 7f57e5b commit f1d7ed0

6 files changed

Lines changed: 492 additions & 489 deletions

File tree

src/activations

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ import socket
6262
import numpy as np
6363
import tensorflow as tf
6464

65-
import data
65+
import data as D
6666

6767
import datetime as dt
6868

@@ -120,7 +120,7 @@ def main():
120120
'parallelize': 1,
121121
'context': FLAGS.context}
122122

123-
audio_processor = data.AudioProcessor(
123+
D.init(
124124
FLAGS.data_dir,
125125
FLAGS.shiftby,
126126
FLAGS.labels_touse.split(','), FLAGS.kinds_touse.split(','),
@@ -146,14 +146,14 @@ def main():
146146

147147
time_shift_tics = int(FLAGS.shiftby * FLAGS.audio_tic_rate * FLAGS.time_scale)
148148

149-
testing_set_size = audio_processor.set_size('testing')
149+
testing_set_size = D.set_size('testing')
150150
if testing_set_size==0:
151151
print('ERROR: no annotations to process')
152152
exit()
153153

154154
def infer_step(isound):
155155
# HACK: get_data not guaranteed to return isounds in order
156-
fingerprints, _, sounds = audio_processor.get_data(
156+
fingerprints, _, sounds = D.get_data(
157157
FLAGS.batch_size, isound, model_settings,
158158
FLAGS.loss, FLAGS.overlapped_prefix,
159159
time_shift_tics, 'testing',

src/classify

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ import importlib
6060
from lib import log_nvidia_smi_output, select_GPUs
6161

6262
import tifffile
63-
from lib import compute_background, load_audio_read_plugin
63+
from lib import compute_background, load_audio_read_plugin, load_video_read_plugin
6464

6565
FLAGS = None
6666

@@ -80,13 +80,8 @@ def main():
8080
video_findfile = importlib.import_module(os.path.basename(FLAGS.video_findfile)).video_findfile
8181

8282
load_audio_read_plugin(FLAGS.audio_read_plugin, FLAGS.audio_read_plugin_kwargs)
83-
from lib import audio_read, trim_ext
84-
85-
sys.path.append(os.path.dirname(FLAGS.video_read_plugin))
86-
video_read_module = importlib.import_module(os.path.basename(FLAGS.video_read_plugin))
87-
def video_read(fullpath, start_frame=None, stop_frame=None):
88-
return video_read_module.video_read(fullpath, start_frame, stop_frame,
89-
**FLAGS.video_read_plugin_kwargs)
83+
load_video_read_plugin(FLAGS.video_read_plugin, FLAGS.video_read_plugin_kwargs)
84+
from lib import audio_read, video_read, trim_ext
9085

9186
with open(FLAGS.model_labels, 'r') as fid:
9287
model_labels = fid.read().splitlines()

0 commit comments

Comments
 (0)