Loading pytorch model from torchvision #409
Open
vbhavank wants to merge 15 commits into master from dev/pytorchdescrgen
15 commits
615a1cc Loading pytorch model from torchvision (vbhavank)
502f3a0 Removed imports and code refactoring (vbhavank)
c7ffbc8 Added overwrite features and safety nets (vbhavank)
f189e49 Removed comments (vbhavank)
3cc0970 Removed one unwanted import (vbhavank)
e6a39cf Fixed is_usable functional bug (vbhavank)
8a46c96 Doc fix and truncate model change (vbhavank)
6f741d3 Fixed multi-dimensional feature to 2-D vector (vbhavank)
8cdc652 Minor refactoring (vbhavank)
6c2fd63 truncate code refactoring to reduce operations (vbhavank)
c629a35 Test for PytorchModelDescriptor (vbhavank)
3453ce5 Removed unused import Dataset (vbhavank)
29a8623 Added pytorch test for lenna (vbhavank)
a84303c Input image dimensions as argument in pytorch descriptor (vbhavank)
ae97c85 Fix for passing 10 tests (vbhavank)
2 changes: 2 additions & 0 deletions
python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/__init__.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| from .pytorch_model_descriptors import PytorchModelDescriptor | ||
| DESCRIPTOR_GENERATOR_CLASS = PytorchModelDescriptor |
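
The module-level `DESCRIPTOR_GENERATOR_CLASS` assignment above appears to be what lets `DescriptorGenerator.get_impls()` find the new class, since the test file in this PR asserts exactly that. SMQTK's own discovery code is not part of this diff, so the following is only a minimal sketch of attribute-based plugin export. The helper `exported_classes`, the stand-in class `FakeDescriptor`, and the in-memory module are all hypothetical names used for illustration.

```python
import types

# Attribute name taken from the __init__.py diff above.
PLUGIN_ATTR = "DESCRIPTOR_GENERATOR_CLASS"


def exported_classes(module, attr_name=PLUGIN_ATTR):
    """Return the classes a module exports through ``attr_name``.

    The attribute may hold a single class or an iterable of classes;
    modules without the attribute export nothing.
    """
    exported = getattr(module, attr_name, None)
    if exported is None:
        return []
    if isinstance(exported, type):
        return [exported]
    return [cls for cls in exported if isinstance(cls, type)]


class FakeDescriptor:
    """Hypothetical stand-in for PytorchModelDescriptor."""


# Build an in-memory module that mimics the two-line __init__.py.
plugin_module = types.ModuleType("fake_pytorchdescriptor")
plugin_module.DESCRIPTOR_GENERATOR_CLASS = FakeDescriptor

found = exported_classes(plugin_module)
```

A single well-known attribute keeps the package's public surface small: a scanner only needs to import the package and read one name, without knowing which submodule defines the class.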
387 changes: 387 additions & 0 deletions
python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/pytorch_model_descriptors.py
Large diffs are not rendered by default.
55 changes: 55 additions & 0 deletions
python/smqtk/algorithms/descriptor_generator/pytorchdescriptor/utils.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| from torch.utils.data import Dataset | ||
| from PIL import Image | ||
| import io | ||
|
|
||
|
|
||
| class PytorchImagedataset(Dataset): | ||
| """ | ||
| A Pytorch dataset class that loads images for feature extraction, | ||
| while maintaining a correspondence between their feature vectors | ||
| and uuids. | ||
| """ | ||
|
|
||
| def __init__(self, data_elements, uuid4proc, transforms): | ||
| """ | ||
| Create a Pytorch dataset for feature extraction using CNN. | ||
| :param data_elements: A dictionary of uuids to corresponding | ||
| smqtk.representation.DataElement | ||
| :type data_elements: dict[uuid, smqtk.representation.DataElement] | ||
| :param uuid4proc: A queue of descriptor element uuids. | ||
| :type uuid4proc: list[uuid] | ||
| :param transforms: Augmentations and transforms applied to each | ||
| image. | ||
| :type transforms: torchvision.transforms | ||
|
|
||
| :return: A tuple containing the transformed image and corresponding | ||
| uuid. | ||
| :rtype: tuple(torch.tensor, str) | ||
| """ | ||
| self.transform = transforms | ||
| self._uuid4proc = uuid4proc | ||
| self.data_ele = data_elements | ||
|
|
||
| def __len__(self): | ||
| """ | ||
| Returns the length of dataset | ||
| """ | ||
| return len(self.data_ele) | ||
|
|
||
| def __getitem__(self, idx): | ||
| """ | ||
| Returns the transformed image tensor and its corresponding uuid | ||
| at the given index of the dataset. | ||
| :param idx: Index of the dataset element to fetch in the current batch | ||
| of feature extraction. | ||
| :type idx: int or [int] | ||
|
|
||
| :return res: A tuple of the image tensor and its uuid. | ||
| :rtype res: tuple(torch.tensor, str) | ||
| """ | ||
| img = Image.open(io.BytesIO(self.data_ele[self._uuid4proc[idx]].get_bytes())) | ||
| img = img.convert('RGB') | ||
| if self.transform: | ||
| img = self.transform(img) | ||
| res = (img, self._uuid4proc[idx]) | ||
| return res | ||
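
In the class above, `__len__` reports `len(self.data_ele)` while `__getitem__` indexes into `self._uuid4proc`. If the two collections ever differ in size, a loader could request an index that the uuid queue does not have, or never reach some queued uuids. The following is a minimal sketch of the same index-to-uuid pairing with the length taken from the uuid queue instead. It is an illustration under stated assumptions, not the PR's code: `UuidIndexedDataset` is a hypothetical name, it does not subclass `torch.utils.data.Dataset`, and raw bytes with `len` as the transform stand in for image decoding.

```python
class UuidIndexedDataset:
    """Map-style dataset pairing each processed item with its uuid."""

    def __init__(self, data_elements, uuid4proc, transform=None):
        self._data_elements = data_elements
        self._uuid4proc = list(uuid4proc)
        self._transform = transform

    def __len__(self):
        # Length follows the uuid queue, since __getitem__ indexes into it.
        return len(self._uuid4proc)

    def __getitem__(self, idx):
        uid = self._uuid4proc[idx]
        item = self._data_elements[uid]
        if self._transform is not None:
            item = self._transform(item)
        # Returning the uuid with the item keeps results traceable
        # even when a loader batches or reorders elements.
        return item, uid


# Three stored elements, but only two are queued for processing.
elements = {"a": b"\x01\x02", "b": b"\x03", "c": b"\x04\x05\x06"}
dataset = UuidIndexedDataset(elements, ["c", "a"], transform=len)
pairs = [dataset[i] for i in range(len(dataset))]
```

Here `pairs` holds one `(transformed_item, uuid)` tuple per queued uuid, in queue order, and the unqueued element `"b"` is never touched.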
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -19,3 +19,5 @@ scikit-learn==0.20.0 | ||
| scipy==1.1.0 | ||
| six==1.11.0 | ||
| stevedore==1.29.0 | ||
| torch==1.4.0 | ||
| torchvision==0.2.2 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| from __future__ import division, print_function | ||
| import inspect | ||
| import os | ||
| import unittest | ||
|
|
||
| import six | ||
| import PIL.Image | ||
| import numpy | ||
|
|
||
| from smqtk.algorithms.descriptor_generator import DescriptorGenerator | ||
| from smqtk.algorithms.descriptor_generator.pytorchdescriptor.pytorch_model_descriptors import \ | ||
| torch, PytorchModelDescriptor | ||
| from smqtk.representation.data_element.memory_element import DataMemoryElement | ||
|
|
||
| from tests import TEST_DATA_DIR | ||
| import pytest | ||
|
|
||
| if PytorchModelDescriptor.is_usable(): | ||
|
|
||
| class TestPytorchModelDescriptor (unittest.TestCase): | ||
|
|
||
| lenna_image_fp = os.path.join(TEST_DATA_DIR,'Lenna.png') | ||
| lenna_torch_res18_avgpool_descr_fp = os.path.join( | ||
| TEST_DATA_DIR, 'Lenna.resnet18_avgpool_output.npy' | ||
| ) | ||
|
|
||
| model_name_elem = 'resnet18' | ||
| return_layer_elem = 'avgpool' | ||
| norm_mean_elem = [0.485, 0.456, 0.406] | ||
| norm_std_elem = [0.229, 0.224, 0.225] | ||
| pretrained = True | ||
| resnet18_avgpool_weights = os.path.join( | ||
| TEST_DATA_DIR,'resnet18_avgpool_weights_torch.pth') | ||
|
|
||
| # Dummy pytorch configuration files + weights | ||
| dummy_model_name = 'dummy_model' | ||
| dummy_return_layer = 'junk_layer' | ||
|
|
||
| @classmethod | ||
| def setup_class(cls): | ||
| cls.model_name = 'resnet18' | ||
| cls.return_layer = 'avgpool' | ||
| cls.input_dim = (224, 224) | ||
| cls.norm_mean = [0.485, 0.456, 0.406] | ||
| cls.norm_std = [0.229, 0.224, 0.225] | ||
| if not torch.cuda.is_available(): | ||
| cls.use_gpu = False | ||
|
|
||
| def test_impl_findable(self): | ||
| self.assertIn(PytorchModelDescriptor, | ||
| DescriptorGenerator.get_impls()) | ||
|
|
||
| def test_get_config(self): | ||
| # Mocking set_network | ||
| expected_params = { | ||
| 'model_name': 'resnet18', | ||
| 'return_layer': 'avgpool', | ||
| 'custom_model_arch': False, | ||
| 'weights_filepath': None, | ||
| 'input_dim': (24, 996), | ||
| 'norm_mean': [0, 0, -0.5], | ||
| 'norm_std': [0.2, 0.3, 1], | ||
| 'use_gpu': True, | ||
| 'batch_size': 777, | ||
| 'pretrained': True, | ||
| } | ||
| g = PytorchModelDescriptor(**expected_params) | ||
| self.assertEqual(g.get_config(), expected_params) | ||
|
|
||
| def test_no_internal_compute_descriptor(self): | ||
| # This implementation's descriptor computation logic sits in async | ||
| # method override due to Pytorch's natural multi-element computation | ||
| # interface. Thus, ``_compute_descriptor`` should not be | ||
| # implemented. | ||
|
|
||
| # Passing purposefully bad constructor parameters and ignoring | ||
| # noinspection PyTypeChecker | ||
| g = PytorchModelDescriptor() | ||
| self.assertRaises( | ||
| NotImplementedError, | ||
| g._compute_descriptor, None | ||
| ) | ||
|
|
||
| def test_compute_descriptor_dummy_model(self): | ||
| # Pytorch dummy network interaction test (Lenna image) | ||
|
|
||
| # Construct network with a dummy model. | ||
| # We expect an AssertionError. | ||
| self.assertRaises( | ||
| AssertionError, | ||
| PytorchModelDescriptor, model_name = self.dummy_model_name) | ||
|
|
||
| @unittest.skipUnless(DataMemoryElement.is_usable(), | ||
| "Memory element not functional") | ||
| def test_compute_descriptor_lenna_description(self): | ||
| # Pytorch ResNet interaction test (Lenna image) | ||
| # This is a long test since it has to compute descriptors. | ||
| expected_descr = numpy.load(self.lenna_torch_res18_avgpool_descr_fp) | ||
| d = PytorchModelDescriptor( | ||
| self.model_name_elem, | ||
| self.return_layer_elem, | ||
| None, None, self.input_dim, | ||
| self.norm_mean_elem, | ||
| self.norm_std_elem, | ||
| True, 1, self.pretrained) | ||
| im = PIL.Image.open(self.lenna_image_fp) | ||
| buff = six.BytesIO() | ||
| (im).save(buff, format="bmp") | ||
| de = DataMemoryElement(buff.getvalue(), | ||
| content_type='image/bmp') | ||
| descr = (d.compute_descriptor(de)).vector() | ||
| numpy.testing.assert_allclose(expected_descr, descr, atol=1e-4) | ||
|
|
||
| @unittest.skipUnless(DataMemoryElement.is_usable(), | ||
| "Memory element not functional") | ||
| def test_load_image_data(self): | ||
| # Testing if image can be loaded and throw an error if uuid is | ||
| # not automatically generated. | ||
| buff = six.BytesIO() | ||
| im = PIL.Image.open(self.lenna_image_fp) | ||
| (im).save(buff, format="bmp") | ||
| de = DataMemoryElement(buff.getvalue(), | ||
| content_type='image/bmp') | ||
| with pytest.raises(AssertionError): | ||
| assert not (de.uuid()) | ||
|
|
||
| def test_compute_descriptor_async_no_data(self): | ||
| # Should get a ValueError when given no descriptors to async method | ||
| g = PytorchModelDescriptor( | ||
| self.model_name_elem, | ||
| self.return_layer_elem, | ||
| None, None, self.input_dim, | ||
| self.norm_mean_elem, | ||
| self.norm_std_elem, | ||
| True, 32, self.pretrained) | ||
| self.assertRaises( | ||
| ValueError, | ||
| g.compute_descriptor_async, [] | ||
| ) | ||
|
|
||
| def test_loading_custom_weights_model(self): | ||
| # Should get a ValueError when the network weights are not | ||
| # loaded to the network or junk weights loaded. | ||
| with pytest.raises(ValueError): | ||
| g = PytorchModelDescriptor(custom_model_arch=None, \ | ||
| weights_filepath=None, pretrained=False) | ||
|
|
||
| def test_weights_loaded_to_model(self): | ||
| # Should fail when the network weights with pretrained flag | ||
| # loaded are not the imagenet pretrained weights. | ||
| d = PytorchModelDescriptor( | ||
| self.model_name_elem, | ||
| self.return_layer_elem, | ||
| None, None, self.input_dim, | ||
| self.norm_mean_elem, | ||
| self.norm_std_elem, | ||
| True, 1, self.pretrained) | ||
| imagenet_weights = torch.load(self.resnet18_avgpool_weights) | ||
| d.model.state_dict() == pytest.approx(imagenet_weights, rel=1e-6, abs=1e-12) | ||
|
|
||
| def test_return_layer_from_network(self): | ||
| # Should get a KeyError when the network does not contain | ||
| # the given return layer | ||
| with pytest.raises(KeyError): | ||
| g = PytorchModelDescriptor( | ||
| self.model_name_elem, | ||
| self.dummy_return_layer, | ||
| None, None, self.input_dim, | ||
| self.norm_mean_elem, | ||
| self.norm_std_elem, | ||
| True, 32, True) | ||
|
|
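
In `test_weights_loaded_to_model` above, the result of the `==` comparison against `pytest.approx(...)` is never asserted, so that test passes regardless of the loaded weights. One way to make such a check explicit is a small helper that compares two weight mappings key by key within a tolerance. The sketch below is an assumption-laden illustration, not the PR's code: `state_dicts_close` is a hypothetical helper, and it expects each value to be a flat sequence of floats (for example, what `tensor.flatten().tolist()` would give for each entry of a PyTorch state dict). The tolerances mirror the `rel=1e-6, abs=1e-12` used in the test.

```python
import math


def state_dicts_close(actual, expected, rel_tol=1e-6, abs_tol=1e-12):
    """Return True when both mappings have the same keys and near-equal values.

    Values are assumed to be flat sequences of floats.
    """
    if set(actual) != set(expected):
        return False
    for key, expected_values in expected.items():
        actual_values = actual[key]
        if len(actual_values) != len(expected_values):
            return False
        for a, e in zip(actual_values, expected_values):
            if not math.isclose(a, e, rel_tol=rel_tol, abs_tol=abs_tol):
                return False
    return True


# Stand-in weights; a real test would flatten model.state_dict() entries.
reference = {"fc.weight": [0.25, -1.5], "fc.bias": [0.0]}
same = {"fc.weight": [0.25, -1.5 + 1e-9], "fc.bias": [0.0]}
different = {"fc.weight": [0.25, -1.4], "fc.bias": [0.0]}

assert state_dicts_close(same, reference)
assert not state_dicts_close(different, reference)
```

Returning a plain boolean lets the caller wrap the check in `assert` or `self.assertTrue(...)`, so a mismatch fails the test instead of being silently discarded.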
Binary file not shown.
Binary file not shown.