feat: add nebullvm as backend #697
Open
diegofiori wants to merge 15 commits into jina-ai:main from diegofiori:main
15 commits:
ad3737d  doc: add nebullvm to docs
a3c9bdc  style: adapt code to jina's style (jina-bot)
ba1bc5f  docs: add python versions supported by nebullvm
8e3e8bb  feat: allow user to select the runtime device
f93bee2  test: add tests with nebullvm
5930aa3  fix: fix bug with cuda devices in clip-nebullvm
5544a28  docs: add nebullvm in all the docs
6089275  build: add nebullvm to the list of requirements
085f865  refactor: adapt nebullvm executor to the new interface
4ca4300  ci: add nebullvm installation to tests in ci/cd
f110021  style: fix format error
44634cf  fix: fix nebullvm api
638e538  fix: removed unnecessary param
fbbd827  fix: added support for dynamic shape
dce8c09  fix: adapt code to the new structure
Pre-commit config: the black hook is bumped from 22.3.0 to 22.10.0:

```diff
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/ambv/black
-    rev: 22.3.0
+    rev: 22.10.0
     hooks:
       - id: black
         types: [python]
```
New file (+91 lines), the nebullvm `CLIPEncoder` executor:

```python
import os
import warnings
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Optional, Dict

import torch
from clip_server.executors.helper import (
    split_img_txt_da,
    preproc_image,
    preproc_text,
    set_rank,
)
from clip_server.model import clip
from clip_server.model.clip_nebullvm import CLIPNebullvmModel, EnvRunner
from jina import Executor, requests, DocumentArray


class CLIPEncoder(Executor):
    def __init__(
        self,
        name: str = 'ViT-B/32',
        device: Optional[str] = None,
        num_worker_preprocess: int = 4,
        minibatch_size: int = 64,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
        self._pool = ThreadPool(processes=num_worker_preprocess)

        self._minibatch_size = minibatch_size
        if not device:
            self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self._device = device

        # On CPU, split the host's Torch threads evenly across replicas,
        # unless the user already pinned NEBULLVM_THREADS_PER_MODEL.
        if not self._device.startswith('cuda') and (
            'NEBULLVM_THREADS_PER_MODEL' not in os.environ
            and hasattr(self.runtime_args, 'replicas')
        ):
            replicas = getattr(self.runtime_args, 'replicas', 1)
            num_threads = max(1, torch.get_num_threads() // replicas)
            if num_threads < 2:
                warnings.warn(
                    f'Too many replicas ({replicas}) vs too few threads '
                    f'{num_threads} may result in sub-optimal performance.'
                )
        else:
            num_threads = None

        self._model = CLIPNebullvmModel(name, clip.MODEL_SIZE[name])
        with EnvRunner(self._device, num_threads):
            self._model.optimize_models(batch_size=minibatch_size)

    @requests(on='/rank')
    async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
        await self.encode(docs['@r,m'])
        set_rank(docs)

    @requests
    async def encode(self, docs: 'DocumentArray', **kwargs):
        _img_da = DocumentArray()
        _txt_da = DocumentArray()
        for d in docs:
            split_img_txt_da(d, _img_da, _txt_da)

        # for image
        if _img_da:
            for minibatch in _img_da.map_batch(
                partial(
                    preproc_image, preprocess_fn=self._preprocess_tensor, return_np=True
                ),
                batch_size=self._minibatch_size,
                pool=self._pool,
            ):
                minibatch.embeddings = self._model.encode_image(minibatch.tensors)

        # for text
        if _txt_da:
            for minibatch, _texts in _txt_da.map_batch(
                partial(preproc_text, return_np=True),
                batch_size=self._minibatch_size,
                pool=self._pool,
            ):
                minibatch.embeddings = self._model.encode_text(minibatch.tensors)
                minibatch.texts = _texts

        # drop tensors
        docs.tensors = None

        return docs
```
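The executor's constructor splits the host's Torch threads evenly across replicas before handing the budget to nebullvm. A minimal, torch-free sketch of that arithmetic (the helper name `threads_per_replica` is hypothetical, not part of the PR):

```python
def threads_per_replica(total_threads: int, replicas: int) -> int:
    """Give each replica an equal share of the host's threads, clamped to at
    least one (mirrors the executor's `max(1, torch.get_num_threads() // replicas)`)."""
    return max(1, total_threads // replicas)


print(threads_per_replica(16, 4))  # 4
print(threads_per_replica(2, 8))   # 1 -- below 2, the executor emits its warning
```

When the per-replica share drops below two threads, the executor warns that the replica count is likely too high for the available cores rather than failing outright.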
New file (+154 lines), `CLIPNebullvmModel` and the `EnvRunner` context manager. Note that `optimize_models` reads `self.pixel_size` and the executor passes `clip.MODEL_SIZE[name]` as the second positional argument, so the constructor accepts the pixel size explicitly:

```python
import os

import numpy as np
import torch.cuda

from clip_server.model.pretrained_models import (
    download_model,
    _OPENCLIP_MODELS,
    _MULTILINGUALCLIP_MODELS,
)
from clip_server.model.clip_model import BaseCLIPModel
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET, _S3_BUCKET_V2


class CLIPNebullvmModel(BaseCLIPModel):
    def __init__(self, name: str, pixel_size: int = 224, model_path: str = None):
        super().__init__(name)
        # `pixel_size` is required by `optimize_models` below; the executor
        # passes `clip.MODEL_SIZE[name]` here.
        self.pixel_size = pixel_size
        if name in _MODELS:
            if not model_path:
                cache_dir = os.path.expanduser(
                    f'~/.cache/clip/{name.replace("/", "-").replace("::", "-")}'
                )
                textual_model_name, textual_model_md5 = _MODELS[name][0]
                self._textual_path = download_model(
                    url=_S3_BUCKET_V2 + textual_model_name,
                    target_folder=cache_dir,
                    md5sum=textual_model_md5,
                    with_resume=True,
                )
                visual_model_name, visual_model_md5 = _MODELS[name][1]
                self._visual_path = download_model(
                    url=_S3_BUCKET_V2 + visual_model_name,
                    target_folder=cache_dir,
                    md5sum=visual_model_md5,
                    with_resume=True,
                )
            else:
                if os.path.isdir(model_path):
                    self._textual_path = os.path.join(model_path, 'textual.onnx')
                    self._visual_path = os.path.join(model_path, 'visual.onnx')
                    if not os.path.isfile(self._textual_path) or not os.path.isfile(
                        self._visual_path
                    ):
                        raise RuntimeError(
                            f'The given model path {model_path} does not contain '
                            f'`textual.onnx` and `visual.onnx`'
                        )
                else:
                    raise RuntimeError(
                        f'The given model path {model_path} should be a folder containing both '
                        f'`textual.onnx` and `visual.onnx`.'
                    )
        else:
            raise RuntimeError(
                'CLIP model {} not found or does not support the ONNX backend; '
                'below is a list of all available models:\n{}'.format(
                    name,
                    ''.join(['\t- {}\n'.format(i) for i in list(_MODELS.keys())]),
                )
            )

    def optimize_models(
        self,
        **kwargs,
    ):
        from nebullvm.api.functions import optimize_model

        general_kwargs = {}
        general_kwargs.update(kwargs)

        # Visual model: batch, channel and spatial axes are all dynamic.
        dynamic_info = {
            "inputs": [
                {0: 'batch', 1: 'num_channels', 2: 'pixel_size', 3: 'pixel_size'}
            ],
            "outputs": [{0: 'batch'}],
        }

        self._visual_model = optimize_model(
            self._visual_path,
            input_data=[
                (
                    (
                        np.random.randn(1, 3, self.pixel_size, self.pixel_size).astype(
                            np.float32
                        ),
                    ),
                    0,
                )
            ],
            dynamic_info=dynamic_info,
            **general_kwargs,
        )

        # Textual model: batch and token axes are dynamic.
        dynamic_info = {
            "inputs": [
                {0: 'batch', 1: 'num_tokens'},
            ],
            "outputs": [
                {0: 'batch'},
            ],
        }

        self._textual_model = optimize_model(
            self._textual_path,
            input_data=[((np.random.randint(0, 100, (1, 77)),), 0)],
            dynamic_info=dynamic_info,
            **general_kwargs,
        )

    @staticmethod
    def get_model_name(name: str):
        if name in _OPENCLIP_MODELS:
            from clip_server.model.openclip_model import OpenCLIPModel

            return OpenCLIPModel.get_model_name(name)
        elif name in _MULTILINGUALCLIP_MODELS:
            from clip_server.model.mclip_model import MultilingualCLIPModel

            return MultilingualCLIPModel.get_model_name(name)

        return name

    def encode_image(self, onnx_image):
        (visual_output,) = self._visual_model(onnx_image)
        return visual_output

    def encode_text(self, onnx_text):
        (textual_output,) = self._textual_model(onnx_text)
        return textual_output


class EnvRunner:
    """Scope CUDA visibility and nebullvm's thread count to the optimization
    phase, restoring the previous environment on exit."""

    def __init__(self, device: str, num_threads: int = None):
        self.device = device
        self.cuda_str = None
        self.rm_cuda_flag = False
        self.num_threads = num_threads

    def __enter__(self):
        if self.device == "cpu" and torch.cuda.is_available():
            self.cuda_str = os.environ.get("CUDA_VISIBLE_DEVICES")
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
            self.rm_cuda_flag = self.cuda_str is None
        if self.num_threads is not None:
            os.environ["NEBULLVM_THREADS_PER_MODEL"] = f"{self.num_threads}"

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.cuda_str is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_str
        elif self.rm_cuda_flag:
            os.environ.pop("CUDA_VISIBLE_DEVICES")
        if self.num_threads is not None:
            os.environ.pop("NEBULLVM_THREADS_PER_MODEL")
```
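`EnvRunner` follows a common save/set/restore pattern for process environment variables. The same idea can be sketched with only the standard library; the `scoped_env` helper below is illustrative and not part of the PR:

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(**overrides):
    """Temporarily set environment variables, then restore the previous value
    (or remove the key if it was absent) -- the same dance EnvRunner performs."""
    saved = {k: os.environ.get(k) for k in overrides}
    os.environ.update({k: str(v) for k, v in overrides.items()})
    try:
        yield
    finally:
        for k, old in saved.items():
            if old is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = old


with scoped_env(NEBULLVM_THREADS_PER_MODEL=2):
    assert os.environ['NEBULLVM_THREADS_PER_MODEL'] == '2'
assert 'NEBULLVM_THREADS_PER_MODEL' not in os.environ
```

Restoring (or removing) the variable on exit is what keeps `NEBULLVM_THREADS_PER_MODEL` and `CUDA_VISIBLE_DEVICES` scoped to the optimization phase instead of leaking into the rest of the process.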