-
Notifications
You must be signed in to change notification settings - Fork 58
Loading pytorch model from torchvision #409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vbhavank
wants to merge
15
commits into
master
Choose a base branch
from
dev/pytorchdescrgen
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
615a1cc
Loading pytorch model from torchvision
vbhavank 502f3a0
Removed imports and code refactoring
vbhavank c7ffbc8
Added overwrite features and safety nets
vbhavank f189e49
Removed comments
vbhavank 3cc0970
Removed one unwanted import
vbhavank e6a39cf
Fixed is_usable functional bug
vbhavank 8a46c96
Doc fix and truncate model change
vbhavank 6f741d3
Fixed multi-dimensional feature to 2-D vector
vbhavank 8cdc652
Minor refactoring
vbhavank 6c2fd63
truncate code refactoring to reduce operations
vbhavank c629a35
Test for PytorchModelDescriptor
vbhavank 3453ce5
Removed unused import Dataset
vbhavank 29a8623
Added pytorch test for lenna
vbhavank a84303c
Input image dimensions as argument in pytorch descriptor
vbhavank ae97c85
Fix for passing 10 tests
vbhavank File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
2 changes: 2 additions & 0 deletions
2
python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
# Expose the generator implementation and declare it as this plugin
# module's descriptor-generator class for SMQTK plugin discovery.
from .pytorch_model_descriptors import PytorchModelDescriptor
DESCRIPTOR_GENERATOR_CLASS = PytorchModelDescriptor
345 changes: 345 additions & 0 deletions
345
python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/pytorch_model_descriptors.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,345 @@ | ||
from smqtk.algorithms.descriptor_generator import DescriptorGenerator, \
    DFLT_DESCRIPTOR_FACTORY

from .utils import PytorchImagedataset

from collections import deque
import logging
import multiprocessing
import six

# torch / torchvision are optional dependencies: guard the imports so this
# module still loads (and ``is_usable`` reports False) when they are absent.
# The unconditional ``import torch`` / ``import torchvision`` lines that
# previously preceded this try/except defeated the guard entirely, and
# ``logging`` was used below without ever being imported.
try:
    import torch
    import torchvision
    from torch.utils.data import DataLoader
    from torch.autograd import Variable
except ImportError as ex:
    logging.getLogger(__name__).warning(
        "Failed to import torch/torchvision module: %s", str(ex))
    torch = None
    torchvision = None
    DataLoader = None
    Variable = None

__all__ = [
    "PytorchModelDescriptor",
]
|
|
||
|
|
||
class PytorchModelDescriptor (DescriptorGenerator):
    """
    Compute images against a PyTorch model, extracting a layer as the content
    descriptor.
    """

    @classmethod
    def is_usable(cls):
        """
        Report whether this implementation's dependencies are available.

        :return: True only if both ``torch`` and ``torchvision`` imported
            successfully at module load time, False otherwise.
        :rtype: bool
        """
        # Both packages are required: torchvision supplies the model zoo and
        # transforms, torch the runtime.  The previous ``or`` incorrectly
        # reported the generator usable when only one of the two imported.
        valid = torch is not None and torchvision is not None
        if not valid:
            cls.get_logger().debug("PyTorch and torchvision cannot be imported")
        return valid
|
|
||
|
|
||
def truncate_pytorch_model(self, model, return_layer_list):
    """
    Given a pytorch model and label of layer, the function returns a
    model truncated at the return layer.

    :param model: The pytorch model that needs to be truncated at
        a certain return layer in the network.
    :type model: torch.nn.Module
    :param return_layer_list: List of return-layer labels in hierarchical
        order (one element for a top-level layer, two for a sub-layer,
        e.g. ``['classifier', '4']``).
    :type return_layer_list: list[str]

    :return: Tuple of (last sequential block of the network in its present
        state, model truncated up to and including the requested layer).
    :rtype: (torch.nn.Module | None, torch.nn.Module)
    :raises KeyError: The requested sub-layer label is not present.
    """
    if len(return_layer_list) == 2:
        # First truncate at the top-level layer, then cut inside it.
        t1_model, _ = self.truncate_pytorch_model(model,
                                                  [return_layer_list[0]])
        sub_module_list = list(t1_model.named_children())
        sub_pos = None
        for inx, lay in enumerate(sub_module_list):
            if return_layer_list[1] == lay[0]:
                sub_pos = inx
        if sub_pos is None:
            # Previously an unmatched label left ``sub_pos`` unbound and
            # crashed with NameError; fail with a meaningful error instead.
            raise KeyError("Given sub-layer '%s' is not present in model"
                           % return_layer_list[1])
        # Slice by position from the front.  The previous
        # ``[:-(len - (sub_pos + 1))]`` form produced an EMPTY Sequential
        # when the requested sub-layer was the last child, since
        # ``[:-0]`` == ``[:0]``.
        model_sub_ = torch.nn.Sequential(
            *list(t1_model.children())[:sub_pos + 1])
        # ``locals().get("model")`` was an obfuscated way of saying ``model``.
        # NOTE(review): the attribute name 'classifier' is hard-coded; this
        # only re-attaches correctly when return_layer_list[0] is
        # 'classifier' — confirm intended for other architectures.
        setattr(model, 'classifier', model_sub_)
        return t1_model, model
    else:
        module_list = list(model.__dict__['_modules'])
        layer_position = module_list.index(return_layer_list[0])
        if len(module_list) == layer_position:
            # Unreachable in practice (index() < len()); kept for parity.
            return model, model
        else:
            trunc_model = torch.nn.Sequential(
                *list(model.children())[:layer_position + 1])
            try:
                seq_mod = torch.nn.Sequential(
                    list(model.children())[layer_position])
            except IndexError:
                seq_mod = None
            return seq_mod, trunc_model
|
|
||
def check_model_dict(self, model, return_key):
    """
    Check the model's module dictionary for the top-level return layer.

    :param model: Base model to be checked for presence of the layer.
    :type model: torch.nn.Module
    :param return_key: Label of the top return layer for feature
        collection.  An empty string skips the check.
    :type return_key: str
    :raises KeyError: The given return layer is not present in the model.
    """
    # The original used ``return_key is not ''`` (identity comparison with
    # a literal — a SyntaxWarning in modern Python) and an ``assert`` that
    # is stripped under ``python -O``.  Compare by value and raise
    # explicitly instead.
    if return_key != '':
        if return_key not in getattr(model, "__dict__").get("_modules"):
            raise KeyError("Given return layer '%s' is not present in model"
                           % return_key)
|
vbhavank marked this conversation as resolved.
Outdated
|
||
|
|
||
def __init__(self,
             model_name = 'resnet18', return_layer = 'avgpool',
             custom_model_arch = None, weights_filepath = None,
             norm_mean = None, norm_std = None, use_gpu = True,
             batch_size = 32, pretrained = True):
    """
    Create a PyTorch CNN descriptor generator.

    :param model_name: Name of a model in the torchvision model zoo, for
        example: 'resnet50', 'vgg16'.  Ignored when ``custom_model_arch``
        is provided.
    :type model_name: str
    :param return_layer: The label of the layer we take data from to
        compose the output descriptor vector.  A dotted label (e.g.
        'classifier.4') addresses a sub-module.
    :type return_layer: str
    :param custom_model_arch: Instantiated custom PyTorch model to use
        instead of a torchvision one.
    :type custom_model_arch: torch.nn.Module
    :param weights_filepath: Absolute file path to weights of the custom
        model ``custom_model_arch``.  Only loaded when ``pretrained`` is
        False.
    :type weights_filepath: str
    :param norm_mean: Mean for normalizing images across three channels.
    :type norm_mean: list[float]
    :param norm_std: Standard deviation for normalizing images across
        three channels.
    :type norm_std: list[float]
    :param use_gpu: If PyTorch should try to use the GPU.
    :type use_gpu: bool
    :param batch_size: The maximum number of images to process in one feed
        forward of the network.  This is especially important for GPUs
        since they can only process a batch that will fit in the GPU
        memory space.
    :type batch_size: int
    :param pretrained: The network is loaded with pretrained weights
        available on torchvision instead of custom weights.
    :type pretrained: bool
    :raises ValueError: ``model_name`` names no torchvision model.
    """
    self.model_name = model_name
    # NOTE(review): input size is fixed at 224x224 here, and
    # Normalize(None, None) will fail when the transform is applied —
    # confirm callers always supply norm_mean/norm_std.
    self.transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(norm_mean, norm_std)])
    self.batch_size = batch_size
    self.return_layer = return_layer
    self.norm_mean = norm_mean
    self.norm_std = norm_std
    self.use_gpu = use_gpu
    self.pretrained = pretrained
    self.weights_filepath = weights_filepath
    self.custom_model_arch = custom_model_arch
    if not custom_model_arch:
        if model_name not in torchvision.models.__dict__:
            # Previously an ``assert`` whose handler only logged, after
            # which execution continued into a confusing AttributeError
            # from getattr().  Log the guidance, then fail explicitly.
            self._log.info("Invalid model name, model not present "
                           "in torchvision. Please load network architecture")
            self._log.info("Available models include:{}".format(
                [s for s in torchvision.models.__dict__.keys()
                 if "__" not in s]))
            raise ValueError("Invalid torchvision model name: {}"
                             .format(model_name))
        model = getattr(torchvision.models, self.model_name)(self.pretrained)
        ret_para = self.return_layer.split('.')
        self.check_model_dict(model, ret_para[0])
        try:
            _, new_model = self.truncate_pytorch_model(model, ret_para)
            assert new_model
            model = new_model
        except AssertionError:
            self._log.info("Selected model{}".format(model))
            raise AssertionError("Invalid return layer label selected \
                model")
    else:
        model = custom_model_arch
    if (not self.pretrained) and (self.weights_filepath):
        checkpoint = torch.load(self.weights_filepath)
        # Support checkpoints saved as {'state_dict': ...} wrappers.
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)
    model.eval()
    # BUG FIX: ``self.model`` was previously only assigned inside the
    # ``use_gpu`` branch, leaving the instance without a model attribute
    # (and therefore unusable) when use_gpu was False.
    self.model = model
    if self.use_gpu:
        try:
            # torch raises AssertionError ("Torch not compiled with CUDA
            # enabled") when CUDA is unavailable.
            self.model = torch.nn.DataParallel(model.cuda())
        except AssertionError:
            self._log.info("Cannot load PyTorch model to GPU, running on CPU")
|
|
||
def __getstate__(self):
    """Pickle state is exactly this generator's configuration dict."""
    config_state = self.get_config()
    return config_state
|
|
||
def __setstate__(self, state):
    """Restore from a pickled configuration dictionary."""
    # This works because configuration parameters exactly match up with
    # instance attributes
    self.__dict__.update(state)
    # NOTE(review): ``_setup_network`` is not defined anywhere in this
    # module — unpickling will raise AttributeError.  Confirm the intended
    # method; network construction currently lives entirely in __init__.
    self._setup_network()
|
|
||
def get_config(self):
    """
    Produce the JSON-compliant configuration dictionary for this instance.

    Passing the returned dictionary to this class's ``from_config`` method
    (via keyword expansion of constructor arguments) yields an instance
    with identical configuration.

    :return: JSON type compliant configuration dictionary.
    :rtype: dict
    """
    return dict(
        model_name=self.model_name,
        return_layer=self.return_layer,
        custom_model_arch=self.custom_model_arch,
        weights_filepath=self.weights_filepath,
        norm_mean=self.norm_mean,
        norm_std=self.norm_std,
        use_gpu=self.use_gpu,
        batch_size=self.batch_size,
        pretrained=self.pretrained,
    )
|
|
||
def valid_content_types(self):
    """
    :return: A set of valid MIME content types that this descriptor can
        handle.
    :rtype: set[str]
    """
    supported = ('image/bmp', 'image/jpeg', 'image/png', 'image/tiff')
    return set(supported)
|
|
||
def compute_descriptor(self, data, descr_factory=DFLT_DESCRIPTOR_FACTORY,
                       overwrite=False):
    """
    Given some data, return a descriptor element containing a descriptor
    vector.

    :raises RuntimeError: Descriptor extraction failure of some kind.
    :raises ValueError: Given data element content was not of a valid type
        with respect to this descriptor.
    :param data: Some kind of input data for the feature descriptor.
    :type data: smqtk.representation.DataElement
    :param descr_factory: Factory instance to produce the wrapping
        descriptor element instance. The default factory produces
        ``DescriptorMemoryElement`` instances by default.
    :type descr_factory: smqtk.representation.DescriptorElementFactory
    :param overwrite: Whether or not to force re-computation of a descriptor
        vector for the given data even when there exists a precomputed
        vector in the generated DescriptorElement as generated from the
        provided factory. This will overwrite the persistently stored vector
        if the provided factory produces a DescriptorElement implementation
        with such storage.
    :type overwrite: bool
    :return: Result descriptor element. UUID of this output descriptor is
        the same as the UUID of the input data element.
    :rtype: smqtk.representation.DescriptorElement
    """
    # BUG FIX: ``overwrite`` was previously accepted but never forwarded,
    # so forced re-computation silently never happened for single-element
    # requests.
    m = self.compute_descriptor_async([data], descr_factory, overwrite)
    return m[data.uuid()]
|
|
||
def _compute_descriptor(self, data):
    """Unused: the batch entry points are overridden by this class."""
    raise NotImplementedError(
        "Shouldn't get here as compute_descriptor[_async] is being "
        "overridden")
|
|
||
def compute_descriptor_async(self, data_set, descriptor_elem_factory=
                             DFLT_DESCRIPTOR_FACTORY, overwrite=False):
    """
    Asynchronously compute feature data for multiple data items.

    :param data_set: Iterable of data elements to compute features for.
        These must have UIDs assigned for feature association in the
        return value.
    :type data_set: collections.Iterable[smqtk.representation.DataElement]
    :param descriptor_elem_factory: Factory instance to produce the
        wrapping descriptor element instance. The default factory produces
        ``DescriptorMemoryElement`` instances by default.
    :type descriptor_elem_factory:
        smqtk.representation.DescriptorElementFactory
    :param overwrite: Whether or not to force re-computation of descriptor
        vectors for the given data even when there exist precomputed
        vectors in the generated DescriptorElements as generated from the
        provided factory. This will overwrite the persistently stored
        vectors if the provided factory produces a DescriptorElement
        implementation with such storage.
    :type overwrite: bool
    :raises ValueError: An input DataElement was of a content type that we
        cannot handle, or no data elements were provided.
    :return: Mapping of input DataElement UUIDs to the computed descriptor
        element for that data. DescriptorElement UUIDs are congruent with
        the UUID of the data element it is the descriptor of.
    :rtype: dict[collections.Hashable,
        smqtk.representation.DescriptorElement]
    """
    self.data_elements = {}
    self.descr_elements = {}
    self.uuid4proc = deque()
    for d in data_set:
        ct = d.content_type()
        if ct not in self.valid_content_types():
            self._log.error("Cannot compute descriptor from content type "
                            "'%s' data: %s)" % (ct, d))
            raise ValueError("Cannot compute descriptor from content type "
                             "'%s' data: %s)" % (ct, d))
        self.data_elements[d.uuid()] = d
        self.descr_elements[d.uuid()] = descriptor_elem_factory \
            .new_descriptor(self.name, d.uuid())
    # Check before sizing the thread pool: ThreadPool(0) raises a
    # confusing ValueError, masking the real problem (no input).
    if len(self.data_elements) == 0:
        raise ValueError("No data elements provided")

    def check_get_uuid(descriptor_elem):
        # Queue only the UUIDs that still need a vector computed.
        if overwrite or not descriptor_elem.has_vector():
            self.uuid4proc.append(descriptor_elem.uuid())

    procs = multiprocessing.cpu_count()
    if len(self.data_elements) < procs:
        procs = len(self.data_elements)
    # Using thread-pool due to in-line function + updating local deque
    p = multiprocessing.pool.ThreadPool(procs)
    try:
        p.map(check_get_uuid, six.itervalues(self.descr_elements))
    finally:
        p.close()
        p.join()
        del p
    self._log.debug("%d descriptors already computed",
                    len(self.data_elements) - len(self.uuid4proc))
    self._log.debug("Given %d unique data elements",
                    len(self.data_elements))
    if self.uuid4proc:
        kwargs = {'num_workers': procs, 'pin_memory': True}
        data_loader_cls = PytorchImagedataset(self.data_elements,
                                              self.uuid4proc,
                                              self.transforms)
        data_loader = DataLoader(data_loader_cls,
                                 batch_size=self.batch_size,
                                 shuffle=False, **kwargs)
        self._log.debug("Extracting PyTorch features")
        for (d, uuids) in data_loader:
            if self.use_gpu:
                d = d.cuda()
            pytorch_f = self.model(Variable(d)).squeeze()
            # Descriptors must be one vector per item: restore the batch
            # axis that squeeze() removes for singleton batches.
            if len(pytorch_f.shape) < 2:
                pytorch_f = pytorch_f.unsqueeze(0)
            if len(pytorch_f.shape) > 2:
                # BUG FIX: ``shape`` is not callable — the original
                # ``pytorch_f.shape(0)`` raised TypeError.  Flatten all
                # trailing dimensions with -1, which also handles feature
                # maps of rank greater than 3.
                pytorch_f = pytorch_f.view(pytorch_f.shape[0], -1)
            features = pytorch_f.data.cpu().numpy()
            for idx, uuid in enumerate(uuids):
                self.descr_elements[uuid].set_vector(features[idx])
    self._log.debug("forming output dict")
    return dict((self.data_elements[k].uuid(), self.descr_elements[k])
                for k in self.data_elements)
|
|
||
23 changes: 23 additions & 0 deletions
23
python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/utils.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| from torch.utils.data import Dataset | ||
| from PIL import Image | ||
| import io | ||
|
|
||
|
|
||
class PytorchImagedataset(Dataset):
    """
    Torch ``Dataset`` adapter over SMQTK data elements.

    :param img_paths: Mapping of UUID to DataElement supplying image bytes
        via ``get_bytes()``.
    :param uuid4proc: Sequence of UUIDs (a subset of ``img_paths`` keys)
        that actually need processing; dataset indexing is over this
        sequence.
    :param transforms: Optional torchvision transform applied to each
        loaded image.
    """

    def __init__(self, img_paths, uuid4proc, transforms):
        self.transform = transforms
        self._uuid4proc = uuid4proc
        self.image_path_list = img_paths
        if not self.image_path_list:
            # BUG FIX: previous code referenced an undefined ``self._log``
            # and indexed the UUID->element mapping with ``[0]``; log via
            # the module logger instead.
            import logging
            logging.getLogger(__name__).info(
                "No image data elements were provided")

    def __len__(self):
        # BUG FIX: length must match the indexable domain of __getitem__,
        # which draws from ``_uuid4proc`` — not from the full element map.
        # The old ``len(self.image_path_list)`` caused IndexError when some
        # descriptors were already computed (uuid4proc smaller than map).
        return len(self._uuid4proc)

    def __getitem__(self, idx):
        uuid = self._uuid4proc[idx]
        img = Image.open(io.BytesIO(self.image_path_list[uuid].get_bytes()))
        # Normalize to 3-channel RGB so grayscale/palette images feed the
        # network uniformly.
        img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return (img, uuid)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.