diff --git a/docs/release_notes/pending_release.rst b/docs/release_notes/pending_release.rst index dfba8bc35..72be2e66d 100644 --- a/docs/release_notes/pending_release.rst +++ b/docs/release_notes/pending_release.rst @@ -5,6 +5,11 @@ SMQTK Pending Release Notes Updates / New Features ---------------------- +Scripts + - `train_itq` + - Added an optional configuration property + ``max_descriptors``. When it is set and more descriptors are available, + the ITQ model is trained on a random sample of that many descriptors. Fixes ----- diff --git a/python/smqtk/bin/train_itq.py b/python/smqtk/bin/train_itq.py index cfe40bf1c..2c4a937ac 100644 --- a/python/smqtk/bin/train_itq.py +++ b/python/smqtk/bin/train_itq.py @@ -9,11 +9,19 @@ be used to specify a sub-set of descriptors in the configured index to train on. This only works if the stored descriptors' UUID is a type of string. + +The ``max_descriptors`` configuration property is optional and can be +used to cap the number of descriptors used to train the model. If +more descriptors are available than requested, they are randomly +subsampled. 
""" import logging import os.path +import numpy +from six.moves import zip + from smqtk.algorithms.nn_index.lsh.functors.itq import ItqFunctor from smqtk.representation import ( get_descriptor_index_impls, @@ -32,6 +40,7 @@ def default_config(): "itq_config": ItqFunctor.get_default_config(), "uuids_list_filepath": None, "descriptor_index": plugin.make_config(get_descriptor_index_impls()), + "max_descriptors": None, } @@ -45,6 +54,7 @@ def main(): log = logging.getLogger(__name__) uuids_list_filepath = config['uuids_list_filepath'] + max_descriptors = config['max_descriptors'] log.info("Initializing ITQ functor") #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor @@ -63,12 +73,26 @@ def uuids_iter(): with open(uuids_list_filepath) as f: for l in f: yield l.strip() + uuids = uuids_iter() log.info("Loading UUIDs list from file: %s", uuids_list_filepath) - d_iter = descriptor_index.get_many_descriptors(uuids_iter()) + if max_descriptors: + uuids = list(uuids) + if max_descriptors < len(uuids): + log.info("Subsampling UUIDs (old count=%d, new count=%d)", + len(uuids), max_descriptors) + uuids = numpy.random.choice(uuids, max_descriptors, replace=False) + d_iter = descriptor_index.get_many_descriptors(uuids) else: + d_length = len(descriptor_index) log.info("Using UUIDs from loaded DescriptorIndex (count=%d)", - len(descriptor_index)) - d_iter = descriptor_index + d_length) + if max_descriptors and max_descriptors < d_length: + log.info("Subsampling loaded DescriptorIndex (new count=%d)", + max_descriptors) + selected = numpy.random.permutation(numpy.arange(d_length) < max_descriptors) + d_iter = (d for d, s in zip(descriptor_index, selected) if s) + else: + d_iter = descriptor_index log.info("Fitting ITQ model") functor.fit(d_iter)