optimum
Advanced tools
| Metadata-Version: 2.1 | ||
| Name: optimum | ||
| Version: 1.25.3 | ||
| Version: 1.26.0 | ||
| Summary: Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality. | ||
@@ -34,3 +34,3 @@ Home-page: https://github.com/huggingface/optimum | ||
| Requires-Dist: onnxruntime>=1.11.0; extra == "onnxruntime" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime" | ||
| Provides-Extra: onnxruntime-gpu | ||
@@ -41,3 +41,3 @@ Requires-Dist: onnx; extra == "onnxruntime-gpu" | ||
| Requires-Dist: onnxruntime-gpu>=1.11.0; extra == "onnxruntime-gpu" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-gpu" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-gpu" | ||
| Provides-Extra: onnxruntime-training | ||
@@ -49,3 +49,3 @@ Requires-Dist: evaluate; extra == "onnxruntime-training" | ||
| Requires-Dist: protobuf>=3.20.1; extra == "onnxruntime-training" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-training" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-training" | ||
| Requires-Dist: onnxruntime-training>=1.11.0; extra == "onnxruntime-training" | ||
@@ -57,3 +57,3 @@ Provides-Extra: exporters | ||
| Requires-Dist: protobuf>=3.20.1; extra == "exporters" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters" | ||
| Provides-Extra: exporters-gpu | ||
@@ -64,3 +64,3 @@ Requires-Dist: onnx; extra == "exporters-gpu" | ||
| Requires-Dist: protobuf>=3.20.1; extra == "exporters-gpu" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters-gpu" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters-gpu" | ||
| Provides-Extra: exporters-tf | ||
@@ -114,2 +114,3 @@ Requires-Dist: onnx; extra == "exporters-tf" | ||
| Requires-Dist: hf_xet; extra == "dev" | ||
| Requires-Dist: onnxslim>=0.1.53; extra == "dev" | ||
| Requires-Dist: black~=23.1; extra == "dev" | ||
@@ -133,2 +134,3 @@ Requires-Dist: ruff==0.1.5; extra == "dev" | ||
| Requires-Dist: hf_xet; extra == "tests" | ||
| Requires-Dist: onnxslim>=0.1.53; extra == "tests" | ||
| Provides-Extra: quality | ||
@@ -135,0 +137,0 @@ Requires-Dist: black~=23.1; extra == "quality" |
@@ -34,2 +34,3 @@ transformers>=4.29 | ||
| hf_xet | ||
| onnxslim>=0.1.53 | ||
| black~=23.1 | ||
@@ -46,3 +47,3 @@ ruff==0.1.5 | ||
| protobuf>=3.20.1 | ||
| transformers<4.52.0,>=4.36 | ||
| transformers<4.53.0,>=4.36 | ||
@@ -54,3 +55,3 @@ [exporters-gpu] | ||
| protobuf>=3.20.1 | ||
| transformers<4.52.0,>=4.36 | ||
| transformers<4.53.0,>=4.36 | ||
@@ -97,3 +98,3 @@ [exporters-tf] | ||
| onnxruntime>=1.11.0 | ||
| transformers<4.52.0,>=4.36 | ||
| transformers<4.53.0,>=4.36 | ||
@@ -105,3 +106,3 @@ [onnxruntime-gpu] | ||
| onnxruntime-gpu>=1.11.0 | ||
| transformers<4.52.0,>=4.36 | ||
| transformers<4.53.0,>=4.36 | ||
@@ -114,3 +115,3 @@ [onnxruntime-training] | ||
| protobuf>=3.20.1 | ||
| transformers<4.52.0,>=4.36 | ||
| transformers<4.53.0,>=4.36 | ||
| onnxruntime-training>=1.11.0 | ||
@@ -144,1 +145,2 @@ | ||
| hf_xet | ||
| onnxslim>=0.1.53 |
@@ -92,3 +92,2 @@ LICENSE | ||
| optimum/onnxruntime/constants.py | ||
| optimum/onnxruntime/graph.py | ||
| optimum/onnxruntime/modeling_decoder.py | ||
@@ -105,6 +104,2 @@ optimum/onnxruntime/modeling_diffusion.py | ||
| optimum/onnxruntime/utils.py | ||
| optimum/onnxruntime/io_binding/__init__.py | ||
| optimum/onnxruntime/io_binding/io_binding_helper.py | ||
| optimum/onnxruntime/models/__init__.py | ||
| optimum/onnxruntime/models/bloom.py | ||
| optimum/onnxruntime/preprocessors/__init__.py | ||
@@ -111,0 +106,0 @@ optimum/onnxruntime/preprocessors/quantization.py |
@@ -172,2 +172,7 @@ # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| ) | ||
| optional_group.add_argument( | ||
| "--slim", | ||
| action="store_true", | ||
| help="Enables onnxslim optimization.", | ||
| ) | ||
@@ -290,3 +295,4 @@ input_group = parser.add_argument_group( | ||
| do_constant_folding=not self.args.no_constant_folding, | ||
| slim=self.args.slim, | ||
| **input_shapes, | ||
| ) |
@@ -18,3 +18,2 @@ # coding=utf-8 | ||
| import argparse | ||
| import warnings | ||
| from pathlib import Path | ||
@@ -29,3 +28,8 @@ | ||
| from ...utils import DEFAULT_DUMMY_SHAPES, logging | ||
| from ...utils.import_utils import is_transformers_version | ||
| from ...utils.import_utils import ( | ||
| is_diffusers_available, | ||
| is_sentence_transformers_available, | ||
| is_timm_available, | ||
| is_transformers_version, | ||
| ) | ||
| from ...utils.save_utils import maybe_load_preprocessors | ||
@@ -63,5 +67,4 @@ from ..tasks import TasksManager | ||
| atol: Optional[float] = None, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| trust_remote_code: bool = False, | ||
| pad_token_id: Optional[int] = None, | ||
| # hub options | ||
| subfolder: str = "", | ||
@@ -71,4 +74,6 @@ revision: str = "main", | ||
| local_files_only: bool = False, | ||
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| trust_remote_code: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| token: Optional[Union[bool, str]] = None, | ||
| ######################################## | ||
| for_ort: bool = False, | ||
@@ -85,2 +90,3 @@ do_validation: bool = True, | ||
| do_constant_folding: bool = True, | ||
| slim: bool = False, | ||
| **kwargs_shapes, | ||
@@ -174,2 +180,4 @@ ): | ||
| PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible. | ||
| slim (bool, defaults to `False`): | ||
| PyTorch-specific argument. If `True`, use onnxslim to optimize the ONNX model. | ||
| **kwargs_shapes (`Dict`): | ||
@@ -186,11 +194,2 @@ Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. | ||
| if use_auth_token is not None: | ||
| warnings.warn( | ||
| "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", | ||
| FutureWarning, | ||
| ) | ||
| if token is not None: | ||
| raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") | ||
| token = use_auth_token | ||
| if fp16: | ||
@@ -230,3 +229,3 @@ if dtype is not None: | ||
| if task in ["stable-diffusion", "stable-diffusion-xl"]: | ||
| if task in ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]: | ||
| logger.warning( | ||
@@ -237,2 +236,19 @@ f"The task `{task}` is deprecated and will be removed in a future release of Optimum. " | ||
| if library_name == "sentence_transformers" and not is_sentence_transformers_available(): | ||
| raise ImportError( | ||
| "The library `sentence_transformers` was specified, but it is not installed. " | ||
| "Please install it with `pip install sentence-transformers`." | ||
| ) | ||
| if library_name == "diffusers" and not is_diffusers_available(): | ||
| raise ImportError( | ||
| "The library `diffusers` was specified, but it is not installed. " | ||
| "Please install it with `pip install diffusers`." | ||
| ) | ||
| if library_name == "timm" and not is_timm_available(): | ||
| raise ImportError( | ||
| "The library `timm` was specified, but it is not installed. Please install it with `pip install timm`." | ||
| ) | ||
| original_task = task | ||
@@ -250,2 +266,18 @@ task = TasksManager.map_from_synonym(task) | ||
| ) | ||
| if library_name == "sentence_transformers" and not is_sentence_transformers_available(): | ||
| logger.warning( | ||
| "The library name was inferred as `sentence_transformers`, which is not installed. " | ||
| "Falling back to `transformers` to avoid breaking the export." | ||
| ) | ||
| library_name = "transformers" | ||
| elif library_name == "timm" and not is_timm_available(): | ||
| raise ImportError( | ||
| "The library name was inferred as `timm`, which is not installed. " | ||
| "Please install it with `pip install timm`." | ||
| ) | ||
| elif library_name == "diffusers" and not is_diffusers_available(): | ||
| raise ImportError( | ||
| "The library name was inferred as `diffusers`, which is not installed. " | ||
| "Please install it with `pip install diffusers`." | ||
| ) | ||
@@ -268,3 +300,10 @@ torch_dtype = None | ||
| try: | ||
| task = TasksManager.infer_task_from_model(model_name_or_path, library_name=library_name) | ||
| task = TasksManager.infer_task_from_model( | ||
| model_name_or_path, | ||
| subfolder=subfolder, | ||
| revision=revision, | ||
| cache_dir=cache_dir, | ||
| token=token, | ||
| library_name=library_name, | ||
| ) | ||
| except KeyError as e: | ||
@@ -310,4 +349,5 @@ raise KeyError( | ||
| # TODO: Fix in Transformers so that SdpaAttention class can be exported to ONNX. `attn_implementation` is introduced in Transformers 4.36. | ||
| if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version(">=", "4.35.99"): | ||
| # TODO: Fix in Transformers so that SdpaAttention class can be exported to ONNX. | ||
| # This was fixed in transformers 4.42.0, we can remove it when minimum transformers version is updated to 4.42 | ||
| if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version("<", "4.42"): | ||
| loading_kwargs["attn_implementation"] = "eager" | ||
@@ -406,2 +446,3 @@ | ||
| do_constant_folding=do_constant_folding, | ||
| slim=slim, | ||
| **kwargs_shapes, | ||
@@ -408,0 +449,0 @@ ) |
@@ -941,15 +941,15 @@ # coding=utf-8 | ||
| # Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task | ||
| if len(onnx_files_subpaths) >= 3 and self.use_past is True: | ||
| decoder_path = Path(path, onnx_files_subpaths[1]) | ||
| decoder_with_past_path = Path(path, onnx_files_subpaths[2]) | ||
| decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") | ||
| # Attempt to merge only if the decoder was exported without/with past | ||
| onnx_decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx") | ||
| onnx_decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx") | ||
| decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") | ||
| if onnx_decoder_path.is_file() and onnx_decoder_with_past_path.is_file() and self.use_past is True: | ||
| try: | ||
| # The decoder with past does not output the cross attention past key values as they are constant, | ||
| # hence the need for strict=False | ||
| from ...onnx import merge_decoders | ||
| # The decoder with past does not output the cross attention past key values as they are constant, | ||
| # hence the need for strict=False | ||
| merge_decoders( | ||
| decoder=decoder_path, | ||
| decoder_with_past=decoder_with_past_path, | ||
| decoder=onnx_decoder_path, | ||
| decoder_with_past=onnx_decoder_with_past_path, | ||
| save_path=decoder_merged_path, | ||
@@ -956,0 +956,0 @@ strict=False, |
@@ -33,2 +33,3 @@ # coding=utf-8 | ||
| from ...onnx.graph_transformations import check_and_save_model | ||
| from ...onnx.utils import _get_onnx_external_constants, _get_onnx_external_data_tensors, check_model_uses_external_data | ||
@@ -40,2 +41,3 @@ from ...utils import ( | ||
| is_diffusers_available, | ||
| is_onnxslim_available, | ||
| is_torch_onnx_support_available, | ||
@@ -922,2 +924,3 @@ is_transformers_version, | ||
| do_constant_folding: bool = True, | ||
| slim: bool = False, | ||
| **kwargs_shapes, | ||
@@ -978,2 +981,4 @@ ): | ||
| PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible. | ||
| slim (bool, defaults to `False`): | ||
| Use onnxslim to optimize the ONNX model. | ||
| **kwargs_shapes (`Dict`): | ||
@@ -1203,2 +1208,13 @@ Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. | ||
| if slim: | ||
| if not is_onnxslim_available(): | ||
| raise ImportError("The pip package `onnxslim` is required to optimize onnx models.") | ||
| from onnxslim import slim | ||
| for subpath in onnx_files_subpaths: | ||
| file_path = os.path.join(output, subpath) | ||
| slimmed_model = slim(file_path) | ||
| check_and_save_model(slimmed_model, file_path) | ||
| # Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any | ||
@@ -1205,0 +1221,0 @@ # TODO: treating diffusion separately is quite ugly |
@@ -83,2 +83,3 @@ # coding=utf-8 | ||
| "imagegpt", | ||
| "internlm2", | ||
| "llama", | ||
@@ -89,2 +90,4 @@ "mistral", | ||
| "qwen2", | ||
| "qwen3", | ||
| "qwen3-moe", | ||
| "granite", | ||
@@ -94,3 +97,3 @@ } | ||
| if is_transformers_version(">=", "4.45.99"): | ||
| if is_transformers_version(">=", "4.46.0"): | ||
| MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt") | ||
@@ -97,0 +100,0 @@ |
+42
-74
@@ -23,3 +23,3 @@ # coding=utf-8 | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Optional, Union | ||
| from typing import TYPE_CHECKING, List, Optional, Union | ||
@@ -36,5 +36,15 @@ from huggingface_hub import HfApi | ||
| if TYPE_CHECKING: | ||
| from transformers import PreTrainedModel, TFPreTrainedModel | ||
| from transformers import ( | ||
| FeatureExtractionMixin, | ||
| ImageProcessingMixin, | ||
| PreTrainedModel, | ||
| ProcessorMixin, | ||
| SpecialTokensMixin, | ||
| TFPreTrainedModel, | ||
| ) | ||
| PreprocessorT = Union[SpecialTokensMixin, FeatureExtractionMixin, ImageProcessingMixin, ProcessorMixin] | ||
| ModelT = Union["PreTrainedModel", "TFPreTrainedModel"] | ||
| logger = logging.getLogger(__name__) | ||
@@ -84,2 +94,3 @@ | ||
| # TODO: Should be removed when we no longer use OptimizedModel for everything | ||
| # workaround to enable compatibility between optimum models and transformers pipelines | ||
@@ -95,7 +106,8 @@ class PreTrainedModel(ABC): # noqa: F811 | ||
| def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: PretrainedConfig): | ||
| super().__init__() | ||
| def __init__( | ||
| self, model: Union["ModelT"], config: "PretrainedConfig", preprocessors: Optional[List["PreprocessorT"]] = None | ||
| ): | ||
| self.model = model | ||
| self.config = config | ||
| self.preprocessors = [] | ||
| self.preprocessors = preprocessors or [] | ||
@@ -236,19 +248,11 @@ def __call__(self, *args, **kwargs): | ||
| config_name_or_path: Union[str, os.PathLike], | ||
| revision: Optional[str] = None, | ||
| # hub options | ||
| subfolder: str = "", | ||
| revision: str = "main", | ||
| force_download: bool = False, | ||
| local_files_only: bool = False, | ||
| trust_remote_code: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| token: Optional[Union[bool, str]] = None, | ||
| force_download: bool = False, | ||
| subfolder: str = "", | ||
| trust_remote_code: bool = False, | ||
| ) -> PretrainedConfig: | ||
| if use_auth_token is not None: | ||
| warnings.warn( | ||
| "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", | ||
| FutureWarning, | ||
| ) | ||
| if token is not None: | ||
| raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") | ||
| token = use_auth_token | ||
| try: | ||
@@ -286,10 +290,10 @@ config = AutoConfig.from_pretrained( | ||
| model_id: Union[str, Path], | ||
| config: PretrainedConfig, | ||
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| token: Optional[Union[bool, str]] = None, | ||
| revision: Optional[str] = None, | ||
| # hub options | ||
| subfolder: str = "", | ||
| revision: str = "main", | ||
| force_download: bool = False, | ||
| local_files_only: bool = False, | ||
| trust_remote_code: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| subfolder: str = "", | ||
| local_files_only: bool = False, | ||
| token: Optional[Union[bool, str]] = None, | ||
| **kwargs, | ||
@@ -301,35 +305,14 @@ ) -> "OptimizedModel": | ||
| @classmethod | ||
| def _from_transformers( | ||
| def _export( | ||
| cls, | ||
| model_id: Union[str, Path], | ||
| config: PretrainedConfig, | ||
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| token: Optional[Union[bool, str]] = None, | ||
| revision: Optional[str] = None, | ||
| # hub options | ||
| subfolder: str = "", | ||
| revision: str = "main", | ||
| force_download: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| subfolder: str = "", | ||
| local_files_only: bool = False, | ||
| trust_remote_code: bool = False, | ||
| **kwargs, | ||
| ) -> "OptimizedModel": | ||
| """Overwrite this method in subclass to define how to load your model from vanilla transformers model""" | ||
| raise NotImplementedError( | ||
| "`_from_transformers` method will be deprecated in a future release. Please override `_export` instead" | ||
| "to define how to load your model from vanilla transformers model" | ||
| ) | ||
| @classmethod | ||
| def _export( | ||
| cls, | ||
| model_id: Union[str, Path], | ||
| config: PretrainedConfig, | ||
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| token: Optional[Union[bool, str]] = None, | ||
| revision: Optional[str] = None, | ||
| force_download: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| subfolder: str = "", | ||
| local_files_only: bool = False, | ||
| trust_remote_code: bool = False, | ||
| **kwargs, | ||
@@ -347,12 +330,12 @@ ) -> "OptimizedModel": | ||
| model_id: Union[str, Path], | ||
| config: Optional[PretrainedConfig] = None, | ||
| export: bool = False, | ||
| # hub options | ||
| subfolder: str = "", | ||
| revision: str = "main", | ||
| force_download: bool = False, | ||
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| token: Optional[Union[bool, str]] = None, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| subfolder: str = "", | ||
| config: Optional[PretrainedConfig] = None, | ||
| local_files_only: bool = False, | ||
| trust_remote_code: bool = False, | ||
| revision: Optional[str] = None, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| token: Optional[Union[bool, str]] = None, | ||
| **kwargs, | ||
@@ -365,21 +348,5 @@ ) -> "OptimizedModel": | ||
| if use_auth_token is not None: | ||
| warnings.warn( | ||
| "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", | ||
| FutureWarning, | ||
| ) | ||
| if token is not None: | ||
| raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") | ||
| token = use_auth_token | ||
| if isinstance(model_id, Path): | ||
| model_id = model_id.as_posix() | ||
| from_transformers = kwargs.pop("from_transformers", None) | ||
| if from_transformers is not None: | ||
| logger.warning( | ||
| "The argument `from_transformers` is deprecated, and will be removed in optimum 2.0. Use `export` instead" | ||
| ) | ||
| export = from_transformers | ||
| if len(model_id.split("@")) == 2: | ||
@@ -448,3 +415,3 @@ logger.warning( | ||
| from_pretrained_method = cls._from_transformers if export else cls._from_pretrained | ||
| from_pretrained_method = cls._export if export else cls._from_pretrained | ||
@@ -454,2 +421,3 @@ return from_pretrained_method( | ||
| config=config, | ||
| # hub options | ||
| revision=revision, | ||
@@ -456,0 +424,0 @@ cache_dir=cache_dir, |
+543
-356
@@ -14,451 +14,638 @@ # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| # limitations under the License. | ||
| """Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models.""" | ||
| """Defines the base classes that are used to perform inference with ONNX Runtime sessions.""" | ||
| from abc import abstractmethod | ||
| from typing import Dict, Optional, Set, Tuple, Union | ||
| import os | ||
| import shutil | ||
| from pathlib import Path | ||
| from typing import Any, Dict, List, Optional, Set, Tuple, Union | ||
| import numpy as np | ||
| import torch | ||
| from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput | ||
| from onnxruntime import InferenceSession | ||
| from onnxruntime import InferenceSession, IOBinding | ||
| from onnxruntime.transformers.io_binding_helper import TypeHelper | ||
| from ..utils import NormalizedConfigManager | ||
| from ..utils.logging import warn_once | ||
| from .io_binding import TypeHelper | ||
| from .modeling_ort import ORTModel | ||
| from .utils import logging | ||
| from ..onnx.utils import _get_model_external_data_paths | ||
| from ..utils.logging import get_logger | ||
| from .utils import ( | ||
| get_device_for_provider, | ||
| get_dtype_from_session, | ||
| get_provider_for_device, | ||
| parse_device, | ||
| validate_provider_availability, | ||
| ) | ||
| logger = logging.get_logger(__name__) | ||
| logger = get_logger(__name__) | ||
| NON_EMPTY_TENSOR = torch.tensor(0) | ||
| class ORTModelPart: | ||
| class ORTSessionMixin: | ||
| """ | ||
| For multi-file ONNX models, such as encoder-decoder models, represents a part of the model. | ||
| It has its own `onnxruntime.InferenceSession`, and can perform a forward pass. | ||
| Mixin class that provides common functionalities for an ONNX Runtime session. | ||
| This class is used to manage the session, the execution provider, and the IO binding. | ||
| It also provides methods to prepare the inputs and outputs for ONNX Runtime. | ||
| """ | ||
| # should be in an ORTMixin | ||
| _prepare_io_binding = ORTModel._prepare_io_binding | ||
| _prepare_output_buffer = ORTModel._prepare_output_buffer | ||
| _output_shape_inference = ORTModel._output_shape_inference | ||
| def initialize_ort_attributes(self, session: InferenceSession, use_io_binding: Optional[bool] = None): | ||
| """ | ||
| Initializes the ORTSessionMixin class. | ||
| Args: | ||
| session (`onnxruntime.InferenceSession`): | ||
| The ONNX Runtime session to use for inference. | ||
| use_io_binding (`Optional[bool]`, defaults to `None`): | ||
| Whether to use IO Binding or not. If `None`, it will be set to `True` for CUDAExecutionProvider and `False` | ||
| for other providers. | ||
| """ | ||
| _prepare_onnx_inputs = ORTModel._prepare_onnx_inputs | ||
| _prepare_onnx_outputs = ORTModel._prepare_onnx_outputs | ||
| def __init__(self, session: InferenceSession, parent_model: "ORTModel"): | ||
| self.session = session | ||
| self.parent_model = parent_model | ||
| self.main_input_name = self.parent_model.main_input_name | ||
| self.path = Path(session._model_path) | ||
| self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} | ||
| self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} | ||
| if use_io_binding is None: | ||
| if self.provider == "CUDAExecutionProvider": | ||
| logger.info( | ||
| "`use_io_binding` was not set, but CUDAExecutionProvider supports IO Binding. " | ||
| "Setting `use_io_binding=True` to leverage IO Binding and improve performance. " | ||
| "You can disable it by setting `model.use_io_binding=False`." | ||
| ) | ||
| use_io_binding = True | ||
| else: | ||
| use_io_binding = False | ||
| self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()} | ||
| self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()} | ||
| self._use_io_binding = use_io_binding | ||
| self._io_binding = IOBinding(session) | ||
| self._dtype = get_dtype_from_session(session) | ||
| self._device = get_device_for_provider(self.provider, self.provider_option) | ||
| self.input_shapes = {input_key.name: input_key.shape for input_key in session.get_inputs()} | ||
| self.output_shapes = {output_key.name: output_key.shape for output_key in session.get_outputs()} | ||
| self.input_names = {input.name: idx for idx, input in enumerate(session.get_inputs())} | ||
| self.output_names = {output.name: idx for idx, output in enumerate(session.get_outputs())} | ||
| self.input_shapes = {input.name: input.shape for input in session.get_inputs()} | ||
| self.output_shapes = {output.name: output.shape for output in session.get_outputs()} | ||
| self.input_dtypes = {input.name: input.type for input in session.get_inputs()} | ||
| self.output_dtypes = {output.name: output.type for output in session.get_outputs()} | ||
| @property | ||
| def device(self): | ||
| return self.parent_model.device | ||
| def model_path(self) -> str: | ||
| """ | ||
| Returns the path of the onnx file from which the session was created. | ||
| """ | ||
| logger.warning( | ||
| "The `ORTSessionMixin.model_path` property is deprecated and will be removed in a future version. " | ||
| "Please use `ORTSessionMixin.path` instead (`ORTSessionMixin.path` is a proper Path object)." | ||
| ) | ||
| return self.path | ||
| @property | ||
| def dtype(self): | ||
| for dtype in self.input_dtypes.values(): | ||
| torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) | ||
| if torch_dtype.is_floating_point: | ||
| return torch_dtype | ||
| def model_name(self) -> str: | ||
| """ | ||
| Returns the name of the onnx file from which the session was created. | ||
| """ | ||
| logger.warning( | ||
| "The `ORTSessionMixin.model_name` property is deprecated and will be removed in a future version. " | ||
| "Please use `ORTSessionMixin.path.name` instead (`ORTSessionMixin.path` is a proper Path object)." | ||
| ) | ||
| return self.path.name | ||
| for dtype in self.output_dtypes.values(): | ||
| torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) | ||
| if torch_dtype.is_floating_point: | ||
| return torch_dtype | ||
| @property | ||
| def providers(self) -> List[str]: | ||
| """ | ||
| Returns a list of Execution Providers registered with the session. | ||
| """ | ||
| return self.session.get_providers() | ||
| return None | ||
| @property | ||
| def provider(self) -> str: | ||
| """ | ||
| Returns the main Execution Provider registered with the session. | ||
| """ | ||
| return self.providers[0] | ||
| def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None): | ||
| @property | ||
| def provider_options(self) -> Dict[str, Any]: | ||
| """ | ||
| Returns a dictionary of Execution Providers configurations/options. | ||
| """ | ||
| return self.session.get_provider_options() | ||
| @property | ||
| def provider_option(self) -> Dict[str, Any]: | ||
| """ | ||
| Returns the configuration/options of the main Execution Provider. | ||
| """ | ||
| return self.provider_options[self.provider] | ||
| @property | ||
| def device(self) -> torch.device: | ||
| """ | ||
| Returns the `torch.device` associated with the ONNX Runtime session. | ||
| This device is inferred from the provider and provider options. | ||
| """ | ||
| return self._device | ||
| @device.setter | ||
| def device(self, *args, **kwargs): | ||
| raise AttributeError( | ||
| "The device attribute is read-only, please use the `.to(device)` " | ||
| "method to change both the device and the execution provider accordingly." | ||
| ) | ||
| @property | ||
| def dtype(self) -> torch.dtype: | ||
| """ | ||
| Returns the `torch.dtype` associated with the ONNX Runtime session. | ||
| This dtype is inferred from the input/output dtypes of the session. | ||
| If no floating point type is found, it defaults to `torch.float32`. | ||
| """ | ||
| return self._dtype | ||
| @property | ||
| def use_io_binding(self) -> Optional[bool]: | ||
| """ | ||
| Returns whether IO Binding is used or not. | ||
| """ | ||
| return self._use_io_binding | ||
| @use_io_binding.setter | ||
| def use_io_binding(self, value: bool): | ||
| """ | ||
| Sets the IO Binding usage. | ||
| """ | ||
| if not isinstance(value, bool): | ||
| raise ValueError("`use_io_binding` should be a boolean value.") | ||
| self._use_io_binding = value | ||
| def to(self, *args, **kwargs): | ||
| """ | ||
| Moves the session to the specified device by updating the execution provider and its options. | ||
| Args: | ||
| device (`str`, `int`, `torch.device`): | ||
| The device to move the session to. It can be a string (e.g., "cuda", "cpu"), an integer (e.g., 0 for GPU 0), | ||
| or a `torch.device` object. | ||
| Returns: | ||
| `ORTSessionMixin`: The updated session. | ||
| Raises: | ||
| ValueError: If the device is not supported or if the provider is not available. | ||
| """ | ||
| dtype = None | ||
| device = None | ||
| for arg in args: | ||
| if isinstance(arg, torch.device): | ||
| if isinstance(arg, (str, torch.device)): | ||
| device = arg | ||
| elif isinstance(arg, int): | ||
| device = torch.device(arg) | ||
| elif isinstance(arg, torch.device): | ||
| device = arg | ||
| elif isinstance(arg, torch.dtype): | ||
| dtype = arg | ||
| if device is not None and device != self.device: | ||
| raise ValueError( | ||
| "Cannot change the device of a model part without changing the device of the parent model. " | ||
| "Please use the `to` method of the parent model to change the device." | ||
| ) | ||
| for key, value in kwargs.items(): | ||
| if key == "device": | ||
| device = value | ||
| elif key == "dtype": | ||
| dtype = value | ||
| if dtype is not None and dtype != self.dtype: | ||
| raise NotImplementedError( | ||
| f"Cannot change the dtype of the model from {self.dtype} to {dtype}. " | ||
| f"Please export the model with the desired dtype." | ||
| ) | ||
| if dtype is not None: | ||
| # we don't support changing the dtype of the model | ||
| return self | ||
| @abstractmethod | ||
| def forward(self, *args, **kwargs): | ||
| pass | ||
| if device is None: | ||
| # no device was provided, we don't change the device | ||
| return self | ||
| def __call__(self, *args, **kwargs): | ||
| return self.forward(*args, **kwargs) | ||
| device, provider_option = parse_device(device) | ||
| provider = get_provider_for_device(device) | ||
| validate_provider_availability(provider) | ||
| if device == self.device: | ||
| return self | ||
| class ORTEncoder(ORTModelPart): | ||
| """ | ||
| Encoder part of the encoder-decoder model for ONNX Runtime inference. | ||
| """ | ||
| self.session.set_providers([provider], provider_options=[provider_option]) | ||
| def __init__(self, session: InferenceSession, parent_model: "ORTModel"): | ||
| super().__init__(session, parent_model) | ||
| if self.use_io_binding is None: | ||
| if self.provider == "CUDAExecutionProvider": | ||
| logger.info( | ||
| "`use_io_binding` was set to `None` before the provider was changed to CUDAExecutionProvider. " | ||
| "Setting `use_io_binding=True` to leverage IO Binding and improve performance. " | ||
| "You can disable it by setting `model.use_io_binding=False`." | ||
| ) | ||
| self.use_io_binding = True | ||
| config = ( | ||
| self.parent_model.config.encoder | ||
| if hasattr(self.parent_model.config, "encoder") | ||
| else self.parent_model.config | ||
| ) | ||
| self._device = device | ||
| self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) | ||
| return self | ||
| def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs) -> BaseModelOutput: | ||
| use_torch = isinstance(input_ids, torch.Tensor) | ||
| self.parent_model.raise_on_numpy_input_io_binding(use_torch) | ||
| def raise_on_numpy_input_io_binding(self, use_torch: bool): | ||
| """ | ||
Raises an error if IO Binding is requested although the tensors used are numpy arrays. | ||
| model_inputs = { | ||
| "input_ids": input_ids, | ||
| "attention_mask": attention_mask, | ||
| } | ||
| Args: | ||
| use_torch (`bool`): | ||
Whether the tensors used during inference are of type torch.Tensor or not. | ||
| """ | ||
| if use_torch is False and self.use_io_binding is True: | ||
| raise ValueError( | ||
| "IO Binding can not be used when passing numpy inputs. Please disable IO Binding" | ||
| " with `model.use_io_binding=False`, or pass `torch.Tensor` inputs instead." | ||
| ) | ||
def _prepare_onnx_inputs(
    self, use_torch: bool, model_inputs: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> Dict[str, np.ndarray]:
    """
    Prepares the inputs for ONNX Runtime by converting them to numpy arrays with the expected dtype.

    Args:
        use_torch (`bool`):
            Whether the inputs are torch.Tensor or not.
        model_inputs (`Dict[str, Union[torch.Tensor, np.ndarray]]`):
            The inputs to prepare for ONNX Runtime.

    Returns:
        `Dict[str, np.ndarray]`: The inputs prepared for ONNX Runtime.
    """
    onnx_inputs = {}
    for input_name in self.input_names.keys():
        if model_inputs.get(input_name, None) is None:
            raise ValueError(f"Input {input_name} is required by model but not provided.")

        if use_torch:
            # force=True also handles CUDA / non-contiguous tensors.
            onnx_inputs[input_name] = model_inputs[input_name].numpy(force=True)
        else:
            onnx_inputs[input_name] = model_inputs[input_name]

        expected_dtype = TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name])
        if onnx_inputs[input_name].dtype != expected_dtype:
            onnx_inputs[input_name] = onnx_inputs[input_name].astype(expected_dtype)

    return onnx_inputs


class ORTEncoder(ORTModelPart):
    """
    Encoder part of the encoder-decoder model for ONNX Runtime inference.
    """

    def __init__(self, session: InferenceSession, parent_model: "ORTModel"):
        super().__init__(session, parent_model)

        # Encoder-decoder configs nest the encoder config under `.encoder`; plain configs are used as-is.
        config = (
            self.parent_model.config.encoder
            if hasattr(self.parent_model.config, "encoder")
            else self.parent_model.config
        )
        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs) -> BaseModelOutput:
        """Runs the encoder session and returns its last hidden state."""
        use_torch = isinstance(input_ids, torch.Tensor)
        self.parent_model.raise_on_numpy_input_io_binding(use_torch)

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

        if self.parent_model.use_io_binding:
            output_shapes, output_buffers = self._prepare_io_binding(model_inputs)

            if self.device.type == "cpu":
                self.session.run_with_iobinding(self._io_binding)
            else:
                # On non-CPU devices, explicit synchronization is required around the run.
                self._io_binding.synchronize_inputs()
                self.session.run_with_iobinding(self._io_binding)
                self._io_binding.synchronize_outputs()

            last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
        else:
            onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
            onnx_outputs = self.session.run(None, onnx_inputs)
            model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

            last_hidden_state = model_outputs["last_hidden_state"]

        return BaseModelOutput(last_hidden_state=last_hidden_state)
def _prepare_onnx_outputs(
    self, use_torch: bool, onnx_outputs: List[np.ndarray]
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Prepares the outputs from ONNX Runtime by converting them to torch.Tensor if requested.

    Args:
        use_torch (`bool`):
            Whether the outputs should be torch.Tensor or not.
        onnx_outputs (`List[np.ndarray]`):
            The outputs from ONNX Runtime.

    Returns:
        `Dict[str, Union[torch.Tensor, np.ndarray]]`: The outputs prepared for the user.
    """
    model_outputs = {}

    # `self.output_names` maps each output name to its position in the session's output list.
    for output_name, idx in self.output_names.items():
        model_outputs[output_name] = onnx_outputs[idx]
        if use_torch:
            model_outputs[output_name] = torch.from_numpy(model_outputs[output_name]).to(self.device)

    return model_outputs
def _prepare_output_buffer(self, output_name: str, output_shape: Tuple[int]) -> torch.Tensor:
    """
    Prepares an output buffer for ONNX Runtime IO Binding.

    Args:
        output_name (`str`):
            The name of the output for which to prepare the buffer.
        output_shape (`Tuple[int]`):
            The shape of the output buffer.

    Returns:
        `torch.Tensor`: The output buffer.
    """
    if len(output_shape) == 0:
        raise ValueError("`output_shape` should not be empty")
    elif not all(isinstance(dim, int) for dim in output_shape):
        raise ValueError(f"`output_shape` should only contain integers but got {output_shape}.")
    elif not all(dim > 0 for dim in output_shape):
        raise ValueError(f"`output_shape` should only contain positive integers but got {output_shape}.")

    output_dtype = TypeHelper.ort_type_to_torch_type(self.output_dtypes[output_name])
    # A flat buffer is enough: callers reshape it with `.view(output_shape)` after the run.
    output_buffer = torch.empty(np.prod(output_shape), dtype=output_dtype, device=self.device)
    return output_buffer


class ORTDecoderForSeq2Seq(ORTModelPart):
    """
    Decoder model with a language modeling head on top for ONNX Runtime inference.
    """

    def __init__(
        self,
        session: InferenceSession,
        parent_model: "ORTModel",
    ):
        super().__init__(session, parent_model)

        # Encoder-decoder configs nest the decoder config under `.decoder`; plain configs are used as-is.
        config = (
            self.parent_model.config.decoder
            if hasattr(self.parent_model.config, "decoder")
            else self.parent_model.config
        )
        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

        # TODO: make this less hacky.
        self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
        self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]

        # To handle the old case when past_key_values were following the format: past_key_values_{idx}
        if len(self.key_value_input_names) == 0:
            self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
        if len(self.key_value_output_names) == 0:
            self.key_value_output_names = [key for key in self.output_names if "key_values" in key]

        if self.parent_model.use_cache is True and len(self.key_value_output_names) == 0:
            raise RuntimeError("Could not find the past key values in the provided model.")

        self.use_past_in_outputs = len(self.key_value_output_names) > 0
        self.use_past_in_inputs = len(self.key_value_input_names) > 0
        self.use_fp16 = self.dtype == torch.float16

        # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2
        # can be used but do not support KV caching for the cross-attention key/values, see:
        # https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311
        # This attribute is used to avoid returning cross-attention KV-cache in this case.
        self.no_cross_attention_cache = getattr(self.parent_model, "no_cross_attention_cache", False)

        if (not self.parent_model.use_merged and self.use_past_in_inputs) or self.no_cross_attention_cache:
            # Decoder-with-past outputs only self-attention k/v (cross-attention k/v are constants).
            self.num_pkv = 2
        else:
            # When using a merged model, we always have the same number of output whether we use past key values or not,
            # and in the case past key values are used, empty tensors are given as cross-attention past key values as they
            # are constants
            self.num_pkv = 4

        self.past_key_values_cross_attention_output_names = set()
        for output_name in self.output_names:
            if output_name.startswith("present") and "encoder" in output_name:
                self.past_key_values_cross_attention_output_names.add(output_name)

        self.use_legacy_outputs = (
            self.parent_model.use_merged is False and len(self.past_key_values_cross_attention_output_names) > 0
        )
def _output_shape_inference(self, output_name: str, known_axes_values: Dict[str, int]) -> List[int]:
    """
    Infers the shape of a given output by using the `known_axes_values` mapping.

    Args:
        output_name (`str`):
            The name of the output for which to infer the shape.
        known_axes_values (`Dict[str, int]`):
            A mapping of the axis names to their values.

    Returns:
        `List[int]`: The inferred shape of the output.
    """
    output_shape = list(self.output_shapes[output_name])
    for idx, axis_name in enumerate(output_shape):
        # Dynamic axes appear as strings in the session's output shapes; resolve them.
        if isinstance(axis_name, str):
            output_shape[idx] = self._dynamic_axis_inference(axis_name, known_axes_values)
    return output_shape

def _dynamic_axis_inference(self, axis_name: str, known_axes_values: Dict[str, int]) -> int:
    """
    Infers the value of a given dynamic axis by using the `known_axes_values` mapping.

    For instance, for the following inputs:
        axis_name = "sequence_length + past_sequence_length"
        known_axes_values = {"batch_size": 2, "sequence_length": 3, "past_sequence_length": 7}
    The inferred value will be:
        3 + 7 = 10
    """
    if axis_name in known_axes_values:
        # simple case, the axis value is known
        return known_axes_values[axis_name]

    # Composite axis expression: substitute known axis names, then evaluate.
    # NOTE(security): `eval` runs on axis names coming from the ONNX model metadata;
    # only load models from trusted sources.
    tokens = axis_name.split(" ")
    for idx, token in enumerate(tokens):
        if token in known_axes_values:
            tokens[idx] = str(known_axes_values[token])
    return int(eval(" ".join(tokens)))

def compute_past_key_values_output_shapes(
    self,
    input_ids: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    use_cache_branch: Optional[bool],
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> Dict[str, int]:
    """
    Computes the expected shape of every past key/value output, used to pre-allocate
    IO Binding buffers for the decoder.
    """
    batch_size = input_ids.size(0)
    num_attention_heads = self.normalized_config.num_attention_heads
    embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads

    sequence_length = input_ids.size(1)
    encoder_sequence_length = encoder_hidden_states.size(1)
    if past_key_values is not None and use_cache_branch is not False:
        # Here, use_cache_branch may be None in the case of separate decoder without/with past, or True if the with past branch
        # of a merged decoder is used
        sequence_length += past_key_values[0].size(2)

    self_attn_shape = (batch_size, num_attention_heads, sequence_length, embed_size_per_head)

    if past_key_values is not None and use_cache_branch is True:
        # Merged decoder cache branch: cross-attention k/v are dummy constants, bound as empty.
        cross_attn_shape = (0, num_attention_heads, 1, embed_size_per_head)
    else:
        cross_attn_shape = (batch_size, num_attention_heads, encoder_sequence_length, embed_size_per_head)

    past_key_values_shapes = {}
    for idx, name in enumerate(self.key_value_output_names):
        is_self_attn = idx % 4 < 2
        # decoder with past does not output cross attention key/values as they are constants
        past_key_values_shapes[name] = (
            self_attn_shape if (is_self_attn or self.num_pkv == 2) else cross_attn_shape
        )
    return past_key_values_shapes

def get_outputs_not_to_bind(self, use_merged_cache: bool) -> Set[str]:
    """Returns the names of outputs that should not be bound via IO Binding."""
    result = {
        output_name
        for output_name in self.output_names
        if (not output_name.startswith("present") and output_name not in {"loss", "logits"})
    }
    if use_merged_cache is True:
        # When using a merged decoder and the use cache branch, we output 0-dim tensors that IO Binding do not support.
        # Therefore, we do not bind them.
        result = result.union(self.past_key_values_cross_attention_output_names)
    return result
def _prepare_io_binding(
    self,
    model_inputs: Dict[str, torch.Tensor],
    outputs_to_not_bind: Optional[Set[str]] = None,
    known_output_buffers: Optional[Dict[str, torch.Tensor]] = None,
    known_output_shapes: Optional[Dict[str, Tuple[int]]] = None,
) -> Tuple[Dict[str, Tuple[int]], Dict[str, torch.Tensor]]:
    """
    Prepares IO binding for ONNX Runtime.

    Args:
        model_inputs (`Dict[str, torch.Tensor]`):
            The inputs to bind to the model.
        outputs_to_not_bind (`Optional[Set[str]]`, defaults to `None`):
            The names of the outputs that should not be bound.
        known_output_buffers (`Optional[Dict[str, torch.Tensor]]`, defaults to `None`):
            Sometimes we can reuse the same input buffer for the output. This is the case for the output sample
            in a diffusion pipeline. It is possible to explicitly pass the buffer via this argument.
        known_output_shapes (`Optional[Dict[str, Tuple[int]]]`, defaults to `None`):
            It can be hard to infer all the output shapes from the inputs only. For instance for the past key /
            values. It is possible to explicitly pass the shape via this argument.

    Returns:
        `Tuple[Dict[str, Tuple[int]], Dict[str, torch.Tensor]]`: A dictionary of the output shapes and a
        dictionary of the output buffers.
    """
    known_axes_values = {}

    for input_name in self.input_names.keys():
        input_shape = model_inputs[input_name].shape

        # IO Binding requires contiguous memory and the exact dtype expected by the session.
        if not model_inputs[input_name].is_contiguous():
            model_inputs[input_name] = model_inputs[input_name].contiguous()

        tensor_dtype = model_inputs[input_name].dtype
        expected_dtype = TypeHelper.ort_type_to_torch_type(self.input_dtypes[input_name])
        if tensor_dtype != expected_dtype:
            model_inputs[input_name] = model_inputs[input_name].to(expected_dtype)

        data_ptr = model_inputs[input_name].data_ptr()
        if data_ptr == 0:
            # During first generation, sequence_length can be 0 when use_cache=True, which results in data_ptr to also be 0.
            # To keep compatibility with IO binding, we pass the data pointer of a non-empty tensor.
            # No impact because past_key_values will not be used during the first generation.
            data_ptr = NON_EMPTY_TENSOR.data_ptr()

        self._io_binding.bind_input(
            input_name,
            self.device.type,
            self.device.index or 0,
            TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]),
            input_shape,
            data_ptr,
        )

        # Record the concrete value of every dynamic axis seen in the inputs,
        # so output shapes can be inferred from them below.
        for idx, axis_name in enumerate(self.input_shapes[input_name]):
            if isinstance(axis_name, str):
                known_axes_values[axis_name] = input_shape[idx]

    output_shapes = {}
    output_buffers = {}
    known_output_shapes = known_output_shapes or {}
    known_output_buffers = known_output_buffers or {}
    outputs_to_not_bind = outputs_to_not_bind or set()

    for output_name in self.output_names.keys():
        if output_name in outputs_to_not_bind:
            continue

        if output_name in known_output_shapes:
            output_shape = known_output_shapes[output_name]
        else:
            output_shape = self._output_shape_inference(output_name, known_axes_values)

        if output_name in known_output_buffers:
            output_buffer = known_output_buffers[output_name]
        else:
            output_buffer = self._prepare_output_buffer(output_name, output_shape)

        data_ptr = output_buffer.data_ptr()

        self._io_binding.bind_output(
            output_name,
            self.device.type,
            self.device.index or 0,
            TypeHelper.ort_type_to_numpy_type(self.output_dtypes[output_name]),
            output_shape,
            data_ptr,
        )

        output_buffers[output_name] = output_buffer
        output_shapes[output_name] = output_shape

    return output_shapes, output_buffers

def forward(
    self,
    input_ids: torch.LongTensor,
    encoder_hidden_states: torch.FloatTensor,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    encoder_attention_mask: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    cache_position: Optional[torch.Tensor] = None,
) -> "Seq2SeqLMOutput":
    """Runs one decoder step, with or without IO Binding, returning logits and regrouped past key/values."""
    use_torch = isinstance(input_ids, torch.Tensor)
    self.parent_model.raise_on_numpy_input_io_binding(use_torch)

    # Flatten the past_key_values
    if past_key_values is not None:
        past_key_values = tuple(
            past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
        )

    # no-ops if merged decoder is not used
    use_merged_no_cache = past_key_values is None and self.parent_model.use_merged
    use_merged_cache = past_key_values is not None and self.parent_model.use_merged

    use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged(
        input_ids, past_key_values, cache_position, use_torch=use_torch
    )

    model_inputs = {
        "input_ids": input_ids,
        "encoder_hidden_states": encoder_hidden_states,
        "decoder_attention_mask": decoder_attention_mask,
        "encoder_attention_mask": encoder_attention_mask,
        "use_cache_branch": use_cache_branch_tensor,
        "cache_position": cache_position,
    }
    if past_key_values is not None:
        model_inputs.update(zip(self.key_value_input_names, past_key_values))

    if self.parent_model.use_io_binding:
        known_output_shapes = self.compute_past_key_values_output_shapes(
            input_ids,
            encoder_hidden_states,
            use_cache_branch=use_cache_branch_tensor.item() if use_cache_branch_tensor is not None else None,
            past_key_values=past_key_values,
        )
        outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)

        output_shapes, output_buffers = self._prepare_io_binding(
            model_inputs,
            known_output_shapes=known_output_shapes,
            outputs_to_not_bind=outputs_to_not_bind,
        )

        if self.device.type == "cpu":
            self.session.run_with_iobinding(self._io_binding)
        else:
            self._io_binding.synchronize_inputs()
            self.session.run_with_iobinding(self._io_binding)
            self._io_binding.synchronize_outputs()

        # Set -1 for sequence_length as it could be larger than the real sequence_length
        for name, shape in output_shapes.items():
            if name in self.key_value_output_names:
                output_shapes[name] = shape[:2] + (-1,) + shape[3:]

        # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
        # self-attention layer and 2 to the cross-attention layer)
        out_past_key_values = ()
        for name in self.key_value_output_names:
            # TODO: this should be improved
            if name in self.past_key_values_cross_attention_output_names and use_merged_cache:
                continue
            out_past_key_values += (output_buffers[name].view(output_shapes[name]),)

        logits = output_buffers["logits"].view(output_shapes["logits"])

        loss = None
        if "loss" in self.output_names:
            loss = output_buffers["loss"].view(output_shapes["loss"])

        if not self.use_past_in_outputs:
            out_past_key_values = None
        elif not self.use_past_in_inputs or use_merged_no_cache or self.no_cross_attention_cache:
            out_past_key_values = tuple(
                out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
            )
        else:
            if self.use_legacy_outputs is True:
                msg = (
                    "For the decoder with past, using ONNX models outputting cross attention past key values"
                    " is deprecated and the support will be removed in optimum 2.0. We recommend exporting again the model"
                    " with optimum>=1.7.3."
                )
                warn_once(logger, msg=msg)
                out_past_key_values = tuple(
                    out_past_key_values[i : i + self.num_pkv]
                    for i in range(0, len(out_past_key_values), self.num_pkv)
                )
            # grab the cross attention key/values from the inputs
            elif self.num_pkv == 2:
                out_past_key_values = tuple(
                    out_past_key_values[i : i + self.num_pkv]
                    + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv]
                    for i in range(0, len(out_past_key_values), self.num_pkv)
                )
            elif self.num_pkv == 4:
                # despite num_pkv being 4, we did not bind the cross-attention output
                out_past_key_values = tuple(
                    out_past_key_values[i : i + 2] + past_key_values[2 * i + 2 : 2 * i + 4]
                    for i in range(0, len(out_past_key_values), 2)
                )
            else:
                raise ValueError("Unsupported num_pkv")
    else:
        onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
        onnx_outputs = self.session.run(None, onnx_inputs)
        model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

        loss = model_outputs.get("loss", None)
        logits = model_outputs["logits"]

        # TODO: using a new variable out_past_key_values is memory inefficient,
        # past_key_values is not used anymore at this point
        # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
        # self-attention layer and 2 to the cross-attention layer)
        out_past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)

        # TODO: this is extremely ugly and unreadable. What if cross-attention k/v change?
        # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
        # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
        # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
        if not self.use_past_in_outputs:
            out_past_key_values = None
        elif not self.use_past_in_inputs or use_merged_no_cache or self.no_cross_attention_cache:
            out_past_key_values = tuple(
                out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
            )
        else:
            if self.use_legacy_outputs is True:
                msg = (
                    "For the decoder with past, using ONNX models outputting cross attention past key values"
                    " is deprecated and the support will be removed in optimum 2.0. We recommend exporting again the model"
                    " with optimum>=1.7.3."
                )
                warn_once(logger, msg=msg)
                out_past_key_values = tuple(
                    out_past_key_values[i : i + self.num_pkv]
                    for i in range(0, len(out_past_key_values), self.num_pkv)
                )
            # grab the cross attention key/values from the inputs
            elif self.num_pkv == 2:
                out_past_key_values = tuple(
                    out_past_key_values[i : i + self.num_pkv]
                    + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv]
                    for i in range(0, len(out_past_key_values), self.num_pkv)
                )
            elif self.num_pkv == 4:
                out_past_key_values = tuple(
                    out_past_key_values[i : i + 2] + past_key_values[i + 2 : i + 4]
                    for i in range(0, len(out_past_key_values), self.num_pkv)
                )
            else:
                raise ValueError("Unsupported num_pkv")

    return Seq2SeqLMOutput(loss=loss, logits=logits, past_key_values=out_past_key_values)
def save_session(self, save_directory: Union[str, Path]):
    """
    Saves the ONNX Runtime session to the specified directory.

    Args:
        save_directory (`Union[str, Path]`):
            The directory where to save the ONNX Runtime session.
    """
    os.makedirs(save_directory, exist_ok=True)

    model_path = Path(self.session._model_path)
    model_save_path = Path(save_directory) / model_path.name
    # Large models store their weights in sibling external-data files; copy those too.
    external_data_paths = _get_model_external_data_paths(model_path)
    external_data_save_paths = [
        Path(save_directory) / external_data_path.name for external_data_path in external_data_paths
    ]

    shutil.copy(model_path, model_save_path)
    for src_path, dst_path in zip(external_data_paths, external_data_save_paths):
        shutil.copy(src_path, dst_path)

def prepare_inputs_for_merged(
    self,
    input_ids: Optional[Union[torch.LongTensor, np.ndarray]],
    past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]],
    cache_position: Optional[Union[torch.Tensor, np.ndarray]],
    use_torch: bool,
):
    """
    Builds the extra inputs a merged decoder needs: the `use_cache_branch` flag, dummy
    past key/values and a dummy cache position for the first (cache-less) forward.
    No-ops (returns `None` placeholders) when the parent model is not merged.
    """
    constructor = torch if use_torch is True else np

    if self.parent_model.use_merged:
        # Uses without/with branch of a merged decoder depending on whether real past key values are passed
        use_cache_branch_tensor = constructor.full((1,), past_key_values is not None)
        if use_torch and use_cache_branch_tensor is not None:
            use_cache_branch_tensor = use_cache_branch_tensor.to(self.device)
    else:
        use_cache_branch_tensor = None

    # Generate dummy past for the first forward if uses a merged decoder
    if self.parent_model.use_merged and past_key_values is None:
        batch_size = input_ids.shape[0]
        num_attention_heads = self.normalized_config.num_attention_heads
        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads

        dtype = constructor.float16 if self.use_fp16 else constructor.float32
        shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
        key_or_value = constructor.zeros(shape, dtype=dtype)
        if use_torch is True:
            key_or_value = key_or_value.to(self.device)

        past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))

    # Generate dummy position cache for the first forward if uses a merged decoder
    if self.parent_model.use_merged and cache_position is None:
        cache_position = constructor.zeros((1,), dtype=constructor.int64)
        if use_torch is True:
            cache_position = cache_position.to(self.device)

    return use_cache_branch_tensor, past_key_values, cache_position
class ORTParentMixin:
    """
    Wrapper class for multiple ORTSessionMixin instances. This class allows to combine multiple parts into
    a single wrapper. It is useful for pipelines/models that require multiple parts to work together, such
    as diffusion pipelines or encoder-decoder models, as it provides a unified interface for inference.
    """

    def initialize_ort_attributes(self, parts: "List[ORTSessionMixin]"):
        """
        Initializes the ORTParentMixin class.

        Args:
            parts (`List[ORTSessionMixin]`):
                List of ORTSessionMixin instances to wrap.
        """
        if len(parts) < 1:
            raise ValueError("ORTParentMixin should be initialized with at least one part.")
        if any(not isinstance(model, ORTSessionMixin) for model in parts):
            raise ValueError("All parts passed to ORTParentMixin should be ORTSessionMixin instances.")
        self.parts = parts

    @property
    def providers(self):
        """
        Returns a list of Execution Providers registered with the session.
        """
        if not all(model.providers == self.parts[0].providers for model in self.parts):
            logger.warning(
                "Calling `ORTParentMixin.providers` when the underlying parts have different values "
                "for `providers` is not recommended. The value of the first session will be returned. "
            )
        return self.parts[0].providers

    @property
    def provider(self):
        """
        Returns the main Execution Provider registered with the session.
        """
        if not all(model.provider == self.parts[0].provider for model in self.parts):
            logger.warning(
                "Calling `ORTParentMixin.provider` when the underlying parts have different values "
                "for `provider` is not recommended. The value of the first session will be returned. "
            )
        return self.parts[0].provider

    @property
    def provider_options(self):
        """
        Returns a dictionary of Execution Providers configurations/options.
        """
        if not all(model.provider_options == self.parts[0].provider_options for model in self.parts):
            logger.warning(
                "Calling `ORTParentMixin.provider_options` when the underlying parts have different values "
                "for `provider_options` is not recommended. The value of the first session will be returned. "
            )
        return self.parts[0].provider_options

    @property
    def provider_option(self):
        """
        Returns the configuration/options of the main Execution Provider.
        """
        if not all(model.provider_option == self.parts[0].provider_option for model in self.parts):
            logger.warning(
                "Calling `ORTParentMixin.provider_option` when the underlying parts have different values "
                "for `provider_option` is not recommended. The value of the first session will be returned. "
            )
        return self.parts[0].provider_option

    @property
    def device(self):
        """
        Returns the `torch.device` associated with the ONNX Runtime session.
        This device is inferred from the provider and provider options.
        """
        if not all(model.device == self.parts[0].device for model in self.parts):
            logger.warning(
                "Calling `ORTParentMixin.device` when the underlying parts have different values "
                "for `device` is not recommended. The value of the first session will be returned. "
            )
        return self.parts[0].device

    @property
    def dtype(self):
        """
        Returns the `torch.dtype` associated with the ONNX Runtime session.
        This dtype is inferred from the input/output dtypes of the session.
        If no floating point type is found, it defaults to `torch.float32`.
        """
        if not all(model.dtype == self.parts[0].dtype for model in self.parts):
            logger.warning(
                "Calling `ORTParentMixin.dtype` when the underlying parts have different values "
                "for `dtype` is not recommended. The value of the first session will be returned. "
            )
        return self.parts[0].dtype

    @property
    def use_io_binding(self):
        """
        Returns whether IO Binding is used or not.
        """
        if not all(model.use_io_binding == self.parts[0].use_io_binding for model in self.parts):
            logger.warning(
                "Calling `ORTParentMixin.use_io_binding` when the underlying parts have different values "
                "for `use_io_binding` is not recommended. The value of the first session will be returned. "
            )
        return self.parts[0].use_io_binding

    @use_io_binding.setter
    def use_io_binding(self, value: bool):
        """
        Setter for the use_io_binding property.
        """
        # Propagated to every part so the wrapper stays consistent.
        for model in self.parts:
            model.use_io_binding = value

    def to(self, *args, **kwargs):
        """
        Moves all parts to the specified device by updating the execution provider and its options.

        Args:
            device (`str`, `int`, `torch.device`):
                The device to move the session to. It can be a string (e.g., "cuda", "cpu"), an integer (e.g., 0 for GPU 0),
                or a `torch.device` object.

        Returns:
            `ORTParentMixin`: The updated session.

        Raises:
            ValueError: If the device is not supported or if the provider is not available.
        """
        for model in self.parts:
            model.to(*args, **kwargs)
        return self
@@ -21,5 +21,4 @@ # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| from tempfile import TemporaryDirectory | ||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | ||
| from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union | ||
| import numpy as np | ||
| import onnx | ||
@@ -34,5 +33,6 @@ import torch | ||
| import onnxruntime | ||
| from onnxruntime import InferenceSession, SessionOptions | ||
| from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export | ||
| from ..exporters.tasks import TasksManager | ||
| from ..onnx.utils import check_model_uses_external_data | ||
@@ -49,4 +49,3 @@ from ..utils import NormalizedConfigManager, is_transformers_version | ||
| from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel | ||
| from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache | ||
| from .utils import ONNX_WEIGHTS_NAME | ||
| from .utils import prepare_providers_and_provider_options | ||
@@ -138,73 +137,109 @@ | ||
| self, | ||
| model: onnxruntime.InferenceSession, | ||
| config: "PretrainedConfig", | ||
| *args, | ||
| config: "PretrainedConfig" = None, | ||
| session: "InferenceSession" = None, | ||
| use_io_binding: Optional[bool] = None, | ||
| generation_config: Optional["GenerationConfig"] = None, | ||
| model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, | ||
| preprocessors: Optional[List] = None, | ||
| generation_config: Optional[GenerationConfig] = None, | ||
| use_cache: Optional[bool] = None, | ||
| **kwargs, | ||
| ): | ||
| if use_io_binding is None: | ||
| use_io_binding = model.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"] | ||
| # DEPRECATED BEHAVIOR | ||
| if args: | ||
| logger.warning( | ||
| "Instantiating an ORTModelForCausalLM with positional arguments is deprecated and will be removed in the next version. " | ||
| "Please use the keywords arguments {config, session, use_io_binding, generation_config, model_save_dir, use_cache} instead." | ||
| ) | ||
| # the old signature is ORTModelForCausalLM(model, config, use_io_binding, model_save_dir, preprocessors, generation_config, use_cache) | ||
| session = args[0] | ||
| if len(args) > 1: | ||
| config = args[1] | ||
| if len(args) > 2: | ||
| use_io_binding = args[2] | ||
| if len(args) > 3: | ||
| model_save_dir = args[3] | ||
| if len(args) > 4: | ||
| _ = args[4] | ||
| if len(args) > 5: | ||
| generation_config = args[5] | ||
| if len(args) > 6: | ||
| _ = args[6] | ||
| super().__init__(model, config, use_io_binding, model_save_dir, preprocessors, **kwargs) | ||
| if kwargs.get("model", None) is not None: | ||
| logger.warning( | ||
| "Passing the inference session as `model` argument to an ORTModelForCausalLM is deprecated. Please use `session` instead." | ||
| ) | ||
| session = kwargs.pop("model") | ||
| if kwargs: | ||
| logger.warning( | ||
| f"Some keyword arguments were passed to the ORTModelForCausalLM constructor that are not part of its signature: {', '.join(kwargs.keys())}. " | ||
| "These arguments will be ignored in the current version and will raise an error in the next version." | ||
| ) | ||
| self.num_pkv = 2 | ||
| if config is None: | ||
| raise ValueError( | ||
| "The parameter config is required. Please pass a config or use the from_pretrained method." | ||
| ) | ||
| if session is None: | ||
| raise ValueError( | ||
| "The parameter session is required. Please pass a session or use the from_pretrained method." | ||
| ) | ||
| ## END OF DEPRECATED BEHAVIOR | ||
| super().__init__(config=config, session=session, use_io_binding=use_io_binding, model_save_dir=model_save_dir) | ||
| self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) | ||
| self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)] | ||
| self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)] | ||
| self.use_cache = len(self.key_value_input_names) > 0 | ||
| if generation_config is None: | ||
| generation_config = GenerationConfig.from_model_config(config) | ||
| self.can_use_cache = len(self.key_value_input_names) > 0 and len(self.key_value_output_names) > 0 | ||
| self.is_merged = "use_cache_branch" in self.input_names | ||
| self.generation_config = generation_config | ||
| if is_transformers_version(">=", "4.44.99"): | ||
| misplaced_generation_parameters = self.config._get_non_default_generation_parameters() | ||
| if len(misplaced_generation_parameters) > 0: | ||
| logger.warning( | ||
| "Moving the following attributes in the config to the generation config: " | ||
| f"{misplaced_generation_parameters}. You are seeing this warning because you've set " | ||
| "generation parameters in the model config, as opposed to in the generation config.", | ||
| ) | ||
| for param_name, param_value in misplaced_generation_parameters.items(): | ||
| setattr(self.generation_config, param_name, param_value) | ||
| setattr(self.config, param_name, None) | ||
| self.onnx_paths = [self.model_path] | ||
| self.use_merged = "use_cache_branch" in self.input_names | ||
| self.model_type = self.config.model_type | ||
| self.use_fp16 = False | ||
| for inp in model.get_inputs(): | ||
| if ( | ||
| inp.name == "past_key_values" or inp.name in self.key_value_input_names | ||
| ) and inp.type == "tensor(float16)": | ||
| self.use_fp16 = True | ||
| break | ||
| # Reference: https://github.com/huggingface/optimum/pull/1381 | ||
| model_type = config.model_type.replace("_", "-") | ||
| model_type = self.config.model_type.replace("_", "-") | ||
| if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names: | ||
| logger.warning( | ||
| f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. " | ||
| "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." | ||
| f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although the model type {model_type} " | ||
| "requires it. for correct batched generation. We strongly encourage to re-export the model with " | ||
| "a newer version of Optimum for better performance and more reliable generation. " | ||
| ) | ||
| if use_cache ^ self.use_cache: | ||
| raise ValueError( | ||
| f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. " | ||
| f"Please load your current model with `use_cache={self.use_cache}` or export the original model " | ||
| f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " | ||
| "To export your model, simply set `export=True`." | ||
| if not self.can_use_cache and self.generation_config.use_cache: | ||
| logger.warning( | ||
| "`model.generation_config.use_cache=True` but the loaded model does not support using the past key values cache." | ||
| "Please re-export the original model once again with `use_cache=True` to be able to use it during generation. " | ||
| "Or set `model.generation_config.use_cache=False` to avoid errors from attempting to use the cache. " | ||
| "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`." | ||
| ) | ||
| if use_io_binding and not use_cache: | ||
| raise ValueError( | ||
| "The parameters combination use_cache=False, use_io_binding=True is not supported. " | ||
| "Please either pass use_cache=True, use_io_binding=True (default), or use_cache=False, use_io_binding=False." | ||
| if self.config.model_type == "gemma": | ||
| self.embed_size_per_head = self.normalized_config.head_dim | ||
| else: | ||
| self.embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads | ||
| if self.config.model_type in {"gemma", "mistral", "llama", "qwen2", "qwen3", "qwen3_moe", "granite"}: | ||
| self.num_key_value_heads = self.normalized_config.num_key_value_heads | ||
| elif self.config.model_type == "falcon": | ||
| self.num_key_value_heads = ( | ||
| self.config.num_kv_heads | ||
| if (self.config.new_decoder_architecture or not self.config.multi_query) | ||
| else 1 | ||
| ) | ||
| else: | ||
| self.num_key_value_heads = self.normalized_config.num_attention_heads | ||
| @property | ||
| def use_cache(self): | ||
| logger.warning( | ||
| "The `ORTModelForCausalLM.use_cache` property is deprecated and will be removed in a future version. " | ||
| "Please rather use `ORTModelForCausalLM.can_use_cache` to check if a model supports using cache during generation. " | ||
| "And use `ORTModelForCausalLM.generation_config.use_cache` to check if the model is configured to use cache during generation." | ||
| ) | ||
| return self.can_use_cache | ||
| @property | ||
| def use_merged(self): | ||
| logger.warning( | ||
| "The `ORTModelForCausalLM.use_merged` property is deprecated and will be removed in a future version. " | ||
| "Please rather use `ORTModelForCausalLM.is_merged` to check if the underlying model is merged or not." | ||
| ) | ||
| return self.is_merged | ||
| @add_start_docstrings_to_model_forward( | ||
@@ -221,34 +256,52 @@ CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") | ||
| input_ids: torch.LongTensor, | ||
| attention_mask: Optional[torch.FloatTensor] = None, | ||
| attention_mask: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, | ||
| position_ids: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, | ||
| use_cache_branch: bool = None, | ||
| use_cache: Optional[bool] = None, | ||
| **kwargs, | ||
| ) -> CausalLMOutputWithPast: | ||
| # adding use_cache_branch in the signature here is just a hack for IO Binding | ||
| use_torch = isinstance(input_ids, torch.Tensor) | ||
| self.raise_on_numpy_input_io_binding(use_torch) | ||
| use_cache = use_cache if use_cache is not None else self.config.use_cache | ||
| known_output_shapes = {} | ||
| if use_cache and not self.can_use_cache: | ||
| raise ValueError( | ||
| f"`use_cache={use_cache}` was passed to the model but the loaded model only supports `use_cache={self.can_use_cache}`. " | ||
| f"Please load your current model with `use_cache={self.can_use_cache}` or export the original model " | ||
| f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " | ||
| "To re-export your model, simply set `export=True` in the `from_pretrained` method." | ||
| ) | ||
| if self.use_cache: | ||
| if past_key_values is not None: | ||
| # Flatten the past_key_values (gpt_bigcode has fused key/value cache, so no need to flatten it) | ||
| if self.model_type != "gpt_bigcode": | ||
| past_key_values = tuple( | ||
| past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer | ||
| ) | ||
| if past_key_values is not None and isinstance(past_key_values[0], tuple): | ||
| # Flattens the past_key_values to a single tuple | ||
| past_key_values = sum(past_key_values, ()) | ||
| # Create dummy past_key_values for decoder first generation step if none given | ||
| use_cache_branch, past_key_values, known_output_shapes = self.prepare_past_key_values( | ||
| input_ids, past_key_values, use_torch | ||
| ) | ||
| if "position_ids" in self.input_names and position_ids is None: | ||
| if attention_mask is not None: | ||
| # Create position_ids from attention_mask | ||
| position_ids = attention_mask.cumsum(-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 1) | ||
| if past_key_values is not None: | ||
| position_ids = position_ids[:, -1].unsqueeze(-1) | ||
| else: | ||
| raise ValueError( | ||
| "The model requires position_ids for batched generation but none were provided. " | ||
| "Please provide position_ids or attention_mask (from which position_ids can be inferred)." | ||
| ) | ||
| # Create position_ids on the fly for batch generation | ||
| if "position_ids" in self.input_names and position_ids is None and attention_mask is not None: | ||
| position_ids = attention_mask.long().cumsum(-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 1) | ||
| if past_key_values: | ||
| position_ids = position_ids[:, -1].unsqueeze(-1) | ||
| use_cache_branch = None | ||
| if self.is_merged: | ||
| # Uses cache branch of merged decoders depending on whether real past key values are passed | ||
| use_cache_branch = torch.full((1,), past_key_values is not None, dtype=torch.bool, device=self.device) | ||
| if past_key_values is None and len(self.key_value_input_names) > 0: | ||
| # Generates the input pkv for the first forward of the model (merged or with past) | ||
| batch_size, seq_len = input_ids.shape | ||
| if self.config.model_type == "gpt_bigcode": | ||
| shape = (batch_size, 0, self.embed_size_per_head * 2) | ||
| else: | ||
| shape = (batch_size, self.num_key_value_heads, 0, self.embed_size_per_head) | ||
| tensor = torch.empty(shape, dtype=self.dtype, device=self.device) | ||
| past_key_values = tuple(tensor for _ in range(len(self.key_value_input_names))) | ||
| model_inputs = { | ||
@@ -260,19 +313,34 @@ "input_ids": input_ids, | ||
| } | ||
| if len(self.key_value_input_names) > 0: | ||
| model_inputs.update(zip(self.key_value_input_names, past_key_values)) | ||
| if past_key_values is not None: | ||
| model_inputs.update( | ||
| zip(self.key_value_input_names, past_key_values), | ||
| ) | ||
| known_output_shapes = None | ||
| outputs_to_not_bind = None | ||
| if use_cache: | ||
| # Infers the shape of the output pkv | ||
| batch_size, seq_len = input_ids.shape | ||
| if self.config.model_type == "gpt_bigcode": | ||
| pkv_seq_len, embed_size_per_head_2 = past_key_values[0].shape[1:] | ||
| pkv_output_shape = (batch_size, pkv_seq_len + seq_len, embed_size_per_head_2) | ||
| else: | ||
| num_key_value_heads, pkv_seq_len, embed_size_per_head = past_key_values[0].shape[1:] | ||
| pkv_output_shape = (batch_size, num_key_value_heads, pkv_seq_len + seq_len, embed_size_per_head) | ||
| known_output_shapes = dict.fromkeys(self.key_value_output_names, pkv_output_shape) | ||
| else: | ||
| # Don't bind the output pkv if not used/returned | ||
| outputs_to_not_bind = self.key_value_output_names | ||
| if self.use_io_binding: | ||
| io_binding, output_shapes, output_buffers = self._prepare_io_binding( | ||
| self.model, model_inputs, known_output_shapes=known_output_shapes | ||
| output_shapes, output_buffers = self._prepare_io_binding( | ||
| model_inputs, | ||
| outputs_to_not_bind=outputs_to_not_bind, | ||
| known_output_shapes=known_output_shapes, | ||
| ) | ||
| if self.device.type == "cpu": | ||
| self.model.run_with_iobinding(io_binding) | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| else: | ||
| io_binding.synchronize_inputs() | ||
| self.model.run_with_iobinding(io_binding) | ||
| io_binding.synchronize_outputs() | ||
| self._io_binding.synchronize_inputs() | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| self._io_binding.synchronize_outputs() | ||
@@ -282,115 +350,83 @@ loss = output_buffers.get("loss", None) | ||
| if self.use_cache: | ||
| # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2 for the self-attention) | ||
| if use_cache: | ||
| past_key_values = tuple( | ||
| output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names | ||
| output_buffers.pop(name).view(output_shapes[name]) for name in self.key_value_output_names | ||
| ) | ||
| else: | ||
| onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) | ||
| onnx_outputs = self.model.run(None, onnx_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) | ||
| loss = model_outputs.get("loss", None) | ||
| logits = model_outputs["logits"] | ||
| loss = model_outputs.pop("loss", None) | ||
| logits = model_outputs.pop("logits") | ||
| if self.use_cache: | ||
| # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention) | ||
| past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names) | ||
| if use_cache: | ||
| past_key_values = tuple(model_outputs.pop(name) for name in self.key_value_output_names) | ||
| if self.use_cache and self.model_type != "gpt_bigcode": | ||
| if use_cache and self.config.model_type != "gpt_bigcode": | ||
| # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and per decoder layer | ||
| past_key_values = tuple( | ||
| past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) | ||
| ) | ||
| past_key_values = tuple(past_key_values[i : i + 2] for i in range(0, len(past_key_values), 2)) | ||
| return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values) | ||
| def prepare_past_key_values( | ||
| def prepare_inputs_for_generation(self, *args, **kwargs): | ||
| if is_transformers_version("<", "4.46.0"): | ||
| return self._prepare_inputs_for_generation_legacy(*args, **kwargs) | ||
| else: | ||
| return super().prepare_inputs_for_generation(*args, **kwargs) | ||
| # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation | ||
| def _prepare_inputs_for_generation_legacy( | ||
| self, | ||
| input_ids: Union[None, torch.LongTensor, np.ndarray], | ||
| past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]], | ||
| use_torch: bool, | ||
| input_ids, | ||
| attention_mask=None, | ||
| past_key_values=None, | ||
| token_type_ids=None, | ||
| position_ids=None, | ||
| use_cache=None, | ||
| **kwargs, | ||
| ): | ||
| sequence_length = input_ids.shape[1] | ||
| constructor = torch if use_torch else np | ||
| if self.use_merged: | ||
| # Uses without/with branch of a merged decoder depending on whether real past key values are passed | ||
| use_cache_branch = constructor.full((1,), past_key_values is not None) | ||
| else: | ||
| # Uses separate decoders | ||
| use_cache_branch = None | ||
| if use_torch and use_cache_branch is not None: | ||
| use_cache_branch = use_cache_branch.to(self.device) | ||
| pkv_output_shape = {} | ||
| # Generate dummy past for the first forward if uses a merged decoder | ||
| if past_key_values is None: | ||
| batch_size = input_ids.shape[0] | ||
| embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads | ||
| if self.model_type == "gemma": | ||
| num_attention_heads = self.normalized_config.num_key_value_heads | ||
| embed_size_per_head = self.normalized_config.head_dim | ||
| elif self.model_type in {"mistral", "llama", "qwen2", "granite"}: | ||
| num_attention_heads = self.normalized_config.num_key_value_heads | ||
| if past_key_values is not None: | ||
| if self.config.model_type == "gpt_bigcode": | ||
| if self.config.multi_query: | ||
| past_length = past_key_values[0].shape[1] | ||
| else: | ||
| past_length = past_key_values[0].shape[2] | ||
| else: | ||
| num_attention_heads = self.normalized_config.num_attention_heads | ||
| past_length = past_key_values[0][0].shape[2] | ||
| dtype = constructor.float16 if self.use_fp16 else constructor.float32 | ||
| # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY. | ||
| if self.__class__.__name__ == "ORTBloomForCausalLM": | ||
| shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head) | ||
| shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0) | ||
| key = constructor.zeros(shape_key, dtype=dtype) | ||
| value = constructor.zeros(shape_value, dtype=dtype) | ||
| if use_torch: | ||
| key = key.to(self.device) | ||
| value = value.to(self.device) | ||
| past_key_values = tuple( | ||
| key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value] | ||
| ) | ||
| for name, value in zip(self.key_value_output_names, past_key_values): | ||
| shape = [*value.shape] | ||
| index = 1 if "value" in name else 2 | ||
| shape[index] += sequence_length | ||
| pkv_output_shape[name] = shape | ||
| elif self.model_type == "gpt_bigcode": | ||
| # GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor. | ||
| shape_key_and_value = (batch_size, 0, embed_size_per_head * 2) | ||
| key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype) | ||
| if use_torch: | ||
| key_and_value = key_and_value.to(self.device) | ||
| past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names))) | ||
| for name, value in zip(self.key_value_output_names, past_key_values): | ||
| shape = [*value.shape] | ||
| shape[1] += sequence_length | ||
| pkv_output_shape[name] = shape | ||
| if input_ids.shape[1] > past_length: | ||
| remove_prefix_length = past_length | ||
| else: | ||
| num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads | ||
| shape = (batch_size, num_key_value_heads, 0, embed_size_per_head) | ||
| key_or_value = constructor.zeros(shape, dtype=dtype) | ||
| remove_prefix_length = input_ids.shape[1] - 1 | ||
| input_ids = input_ids[:, remove_prefix_length:] | ||
| if use_torch: | ||
| key_or_value = key_or_value.to(self.device) | ||
| return { | ||
| "input_ids": input_ids, | ||
| "attention_mask": attention_mask, | ||
| "past_key_values": past_key_values, | ||
| "token_type_ids": token_type_ids, | ||
| "position_ids": position_ids, | ||
| "use_cache": use_cache, | ||
| } | ||
| past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names))) | ||
| @staticmethod | ||
| def _reorder_cache( | ||
| past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor | ||
| ) -> Tuple[Tuple[torch.Tensor]]: | ||
| if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], tuple): | ||
| # GPT2 style | ||
| return tuple( | ||
| tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) | ||
| for layer_past in past_key_values | ||
| ) | ||
| elif isinstance(past_key_values, tuple) and isinstance(past_key_values[0], torch.Tensor): | ||
| # GPT BigCode style | ||
| return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) | ||
| else: | ||
| raise ValueError( | ||
| f"Unexpected past_key_values: {past_key_values}. " | ||
| "Expected tuple of tuples (GPT2 style) or tuple of tensors (GPT BigCode style)." | ||
| ) | ||
| for name, value in zip(self.key_value_output_names, past_key_values): | ||
| shape = [*value.shape] | ||
| shape[2] += sequence_length | ||
| pkv_output_shape[name] = shape | ||
| return use_cache_branch, past_key_values, pkv_output_shape | ||
| @classmethod | ||
@@ -401,29 +437,25 @@ def _from_pretrained( | ||
| config: "PretrainedConfig", | ||
| token: Optional[Union[bool, str]] = None, | ||
| revision: Optional[str] = None, | ||
| # hub options | ||
| subfolder: str = "", | ||
| revision: str = "main", | ||
| force_download: bool = False, | ||
| local_files_only: bool = False, | ||
| trust_remote_code: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| token: Optional[Union[bool, str]] = None, | ||
| # file options | ||
| file_name: Optional[str] = None, | ||
| subfolder: str = "", | ||
| # session options | ||
| provider: str = "CPUExecutionProvider", | ||
| providers: Optional[Sequence[str]] = None, | ||
| provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None, | ||
| session_options: Optional[SessionOptions] = None, | ||
| # inference options | ||
| use_cache: bool = True, | ||
| local_files_only: bool = False, | ||
| use_merged: Optional[bool] = None, | ||
| provider: str = "CPUExecutionProvider", | ||
| session_options: Optional[onnxruntime.SessionOptions] = None, | ||
| provider_options: Optional[Dict[str, Any]] = None, | ||
| use_io_binding: Optional[bool] = None, | ||
| generation_config: Optional[GenerationConfig] = None, | ||
| # other arguments | ||
| model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, | ||
| **kwargs, | ||
| ) -> "ORTModelForCausalLM": | ||
| generation_config = kwargs.pop("generation_config", None) | ||
| # We do not implement the logic for use_cache=False, use_merged=True | ||
| if use_cache is False: | ||
| if use_merged is True: | ||
| raise ValueError( | ||
| "The parameters combination use_cache=False, use_merged=True is not supported." | ||
| " To use a merged decoder, past key values must be used." | ||
| ) | ||
| use_merged = False | ||
| onnx_files = find_files_matching_pattern( | ||
@@ -504,4 +536,8 @@ model_id, | ||
| ) | ||
| new_model_save_dir = Path(model_cache_path).parent | ||
| # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it | ||
| # instead of the path only. | ||
| if model_save_dir is None: | ||
| model_save_dir = Path(model_cache_path).parent | ||
| try: | ||
@@ -523,13 +559,7 @@ cached_file( | ||
| # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it | ||
| # instead of the path only. | ||
| if model_save_dir is None: | ||
| model_save_dir = new_model_save_dir | ||
| # This should be removed at some point | ||
| onnx_model = onnx.load(str(model_cache_path), load_external_data=False) | ||
| model_uses_external_data = check_model_uses_external_data(onnx_model) | ||
| if model_uses_external_data: | ||
| onnx_model = onnx.load(str(model_cache_path), load_external_data=True) | ||
| input_dims = { | ||
@@ -543,5 +573,3 @@ node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] | ||
| } | ||
| override_dims = False | ||
| # Since v1.7.0 decoder with past models have fixed sequence length of 1 | ||
@@ -553,3 +581,2 @@ # To keep these models compatible we set this dimension to dynamic | ||
| override_dims = True | ||
| # Since https://github.com/huggingface/optimum/pull/871/ | ||
@@ -561,3 +588,2 @@ # changed axis notation/naming during export, we need to update the dims | ||
| override_dims = True | ||
| if override_dims: | ||
@@ -582,32 +608,11 @@ # this is kinda dangerous, warning the user is the least we can do | ||
| ) | ||
| # Since transformers 4.44, the bloom model has been updated to use the standard cache format | ||
| use_old_bloom_modeling = not is_transformers_version(">=", "4.44") | ||
| for input_name in input_dims.keys(): | ||
| if input_dims[input_name][0] == "batch_size x num_heads": | ||
| use_old_bloom_modeling = True | ||
| del onnx_model | ||
| model = ORTModel.load_model( | ||
| model_cache_path, | ||
| provider=provider, | ||
| session_options=session_options, | ||
| provider_options=provider_options, | ||
| ) | ||
| # Important: for encoder-decoder models used with CausalLM, we need to set the is_decoder flag to True | ||
| # and the is_encoder_decoder flag to False. This is needed for the model to work correctly with generation logic. | ||
| if hasattr(config, "is_decoder"): | ||
| config.is_decoder = True | ||
| if hasattr(config, "is_encoder_decoder"): | ||
| config.is_encoder_decoder = False | ||
| if config.model_type == "bloom" and use_old_bloom_modeling: | ||
| init_cls = ORTBloomForCausalLM | ||
| elif config.model_type == "falcon": | ||
| init_cls = ORTFalconForCausalLM | ||
| elif config.model_type == "mpt": | ||
| init_cls = ORTMPTForCausalLM | ||
| # if model was exported with position_ids it means the model was exported with transformers >= v4.46 | ||
| elif config.model_type == "opt" and "position_ids" not in input_dims: | ||
| init_cls = ORTOPTForCausalLM | ||
| elif config.model_type == "gpt_bigcode": | ||
| init_cls = ORTGPTBigCodeForCausalLM | ||
| else: | ||
| init_cls = ORTModelForCausalLM | ||
| if generation_config is None: | ||
@@ -617,55 +622,76 @@ try: | ||
| model_id, | ||
| token=token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| cache_dir=cache_dir, | ||
| force_download=force_download, | ||
| local_files_only=local_files_only, | ||
| token=token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| ) | ||
| except OSError: | ||
| logger.info( | ||
| "Generation config file not found, using a generation config created from the model config." | ||
| logger.info("Generation config file not found, creating a new one from model config.") | ||
| generation_config = GenerationConfig.from_model_config(config) | ||
| # TODO: not sure if setting config.use_cache is needed for older versions of transformers | ||
| generation_config.use_cache = use_cache | ||
| config.use_cache = use_cache | ||
| if is_transformers_version(">=", "4.45.0"): | ||
| misplaced_generation_parameters = config._get_non_default_generation_parameters() | ||
| if len(misplaced_generation_parameters) > 0: | ||
| logger.warning( | ||
| "Moving the following attributes in the config to the generation config: " | ||
| f"{misplaced_generation_parameters}. You are seeing this warning because you've set " | ||
| "generation parameters in the model config, as opposed to in the generation config.", | ||
| ) | ||
| for param_name, param_value in misplaced_generation_parameters.items(): | ||
| setattr(generation_config, param_name, param_value) | ||
| setattr(config, param_name, None) | ||
| return init_cls( | ||
| model=model, | ||
| providers, provider_options = prepare_providers_and_provider_options( | ||
| provider=provider, providers=providers, provider_options=provider_options | ||
| ) | ||
| session = InferenceSession( | ||
| model_cache_path, | ||
| providers=providers, | ||
| provider_options=provider_options, | ||
| sess_options=session_options, | ||
| ) | ||
| return cls( | ||
| config=config, | ||
| session=session, | ||
| use_io_binding=use_io_binding, | ||
| generation_config=generation_config, | ||
| model_save_dir=model_save_dir, | ||
| use_cache=use_cache, | ||
| generation_config=generation_config, | ||
| ) | ||
| @classmethod | ||
| def _from_transformers( | ||
| def _export( | ||
| cls, | ||
| model_id: str, | ||
| model_id: Union[str, Path], | ||
| config: "PretrainedConfig", | ||
| token: Optional[Union[bool, str]] = None, | ||
| # hub options | ||
| subfolder: str = "", | ||
| revision: str = "main", | ||
| force_download: bool = True, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| subfolder: str = "", | ||
| force_download: bool = False, | ||
| local_files_only: bool = False, | ||
| trust_remote_code: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| token: Optional[Union[bool, str]] = None, | ||
| # inference options | ||
| use_cache: bool = True, | ||
| use_merged: bool = False, | ||
| provider: str = "CPUExecutionProvider", | ||
| session_options: Optional[onnxruntime.SessionOptions] = None, | ||
| provider_options: Optional[Dict[str, Any]] = None, | ||
| use_io_binding: Optional[bool] = None, | ||
| task: Optional[str] = None, | ||
| **kwargs, | ||
| ) -> "ORTModelForCausalLM": | ||
| file_name = ONNX_WEIGHTS_NAME | ||
| # this is garanteed to work since we it uses a mapping from model classes to task names | ||
| # instead of relying on the hub metadata or the model configuration | ||
| task = TasksManager._infer_task_from_model_or_model_class(model_class=cls.auto_model_class) | ||
| if use_cache: | ||
| task += "-with-past" | ||
| if use_merged: | ||
| logger.warning("The `use_merged` argument is deprecated when the model is exported, and not used anymore.") | ||
| use_merged = False | ||
| if kwargs.get("task", None) is not None: | ||
| raise ValueError( | ||
| f"The `task` argument is not needed when exporting a model with `{cls.__name__}`. " | ||
| f"The `task` is automatically inferred from the class as `{task}`." | ||
| ) | ||
| if task is None: | ||
| task = cls._auto_model_to_task(cls.auto_model_class) | ||
| if use_cache: | ||
| task += "-with-past" | ||
| save_dir = TemporaryDirectory() | ||
@@ -695,291 +721,15 @@ save_dir_path = Path(save_dir.name) | ||
| use_cache=use_cache, | ||
| use_merged=use_merged, | ||
| provider=provider, | ||
| session_options=session_options, | ||
| provider_options=provider_options, | ||
| use_io_binding=use_io_binding, | ||
| model_save_dir=save_dir, | ||
| file_name=file_name, | ||
| ) | ||
| # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation | ||
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): | ||
| if past_key_values is not None: | ||
| past_length = past_key_values[0][0].shape[2] | ||
| # Some generation methods already pass only the last input ID | ||
| if input_ids.shape[1] > past_length: | ||
| remove_prefix_length = past_length | ||
| else: | ||
| # Default to old behavior: keep only final ID | ||
| remove_prefix_length = input_ids.shape[1] - 1 | ||
| input_ids = input_ids[:, remove_prefix_length:] | ||
| # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly | ||
| attention_mask = kwargs.get("attention_mask", None) | ||
| use_cache = kwargs.get("use_cache", None) | ||
| position_ids = kwargs.get("position_ids", None) | ||
| if attention_mask is not None and position_ids is None: | ||
| # create position_ids on the fly for batch generation | ||
| position_ids = attention_mask.long().cumsum(-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 1) | ||
| if past_key_values: | ||
| position_ids = position_ids[:, -1].unsqueeze(-1) | ||
| return { | ||
| "input_ids": input_ids, | ||
| "past_key_values": past_key_values, | ||
| "use_cache": use_cache, | ||
| "position_ids": position_ids, | ||
| "attention_mask": attention_mask, | ||
| } | ||
| # Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache | ||
| @staticmethod | ||
| def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: | ||
| return tuple( | ||
| tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) | ||
| for layer_past in past | ||
| ) | ||
| def _save_pretrained(self, save_directory: Union[str, Path]): | ||
| super()._save_pretrained(save_directory) | ||
| self.generation_config.save_pretrained(save_directory) | ||
| class ORTGPTBigCodeForCausalLM(ORTModelForCausalLM): | ||
| # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation | ||
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): | ||
| # Omit tokens covered by past_key_values | ||
| if past_key_values: | ||
| if self.config.multi_query: | ||
| past_length = past_key_values[0].shape[1] | ||
| else: | ||
| past_length = past_key_values[0].shape[2] | ||
| # Some generation methods already pass only the last input ID | ||
| if input_ids.shape[1] > past_length: | ||
| remove_prefix_length = past_length | ||
| else: | ||
| # Default to old behavior: keep only final ID | ||
| remove_prefix_length = input_ids.shape[1] - 1 | ||
| input_ids = input_ids[:, remove_prefix_length:] | ||
| attention_mask = kwargs.get("attention_mask", None) | ||
| position_ids = kwargs.get("position_ids", None) | ||
| if attention_mask is not None and position_ids is None: | ||
| # create position_ids on the fly for batch generation | ||
| position_ids = attention_mask.long().cumsum(-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 1) | ||
| if past_key_values: | ||
| position_ids = position_ids[:, -input_ids.shape[1] :] | ||
| else: | ||
| position_ids = None | ||
| model_inputs = {"input_ids": input_ids} | ||
| model_inputs.update( | ||
| { | ||
| "past_key_values": past_key_values, | ||
| "use_cache": kwargs.get("use_cache"), | ||
| "position_ids": position_ids, | ||
| "attention_mask": attention_mask, | ||
| } | ||
| ) | ||
| return model_inputs | ||
| # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache | ||
| @staticmethod | ||
| def _reorder_cache( | ||
| past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor | ||
| ) -> Tuple[Tuple[torch.Tensor]]: | ||
| return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) | ||
| class ORTBloomForCausalLM(ORTModelForCausalLM): | ||
| # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation | ||
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): | ||
| if past_key_values is not None: | ||
| past_length = past_key_values[0][0].shape[2] | ||
| # Some generation methods already pass only the last input ID | ||
| if input_ids.shape[1] > past_length: | ||
| remove_prefix_length = past_length | ||
| else: | ||
| # Default to old behavior: keep only final ID | ||
| remove_prefix_length = input_ids.shape[1] - 1 | ||
| input_ids = input_ids[:, remove_prefix_length:] | ||
| attention_mask = kwargs.get("attention_mask", None) | ||
| use_cache = kwargs.get("use_cache", None) | ||
| # only last token for input_ids if past is not None | ||
| if past_key_values: | ||
| # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed | ||
| if past_key_values[0][0].shape[0] == input_ids.shape[0]: | ||
| past_key_values = bloom_convert_to_bloom_cache(past_key_values) | ||
| return { | ||
| "input_ids": input_ids, | ||
| "past_key_values": past_key_values, | ||
| "use_cache": use_cache, | ||
| "position_ids": None, | ||
| "attention_mask": attention_mask, | ||
| } | ||
| # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache | ||
| @staticmethod | ||
| def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: | ||
| standardized_past = bloom_convert_to_standard_cache(past, batch_size=len(beam_idx)) | ||
| # Get a copy of `beam_idx` on all the devices where we need those indices. | ||
| device_to_beam_idx = { | ||
| past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past | ||
| } | ||
| reordered_past = tuple( | ||
| ( | ||
| layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), | ||
| layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), | ||
| ) | ||
| for layer_past in standardized_past | ||
| ) | ||
| return bloom_convert_to_bloom_cache(reordered_past) | ||
| class ORTOPTForCausalLM(ORTModelForCausalLM): | ||
| # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation | ||
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): | ||
| if past_key_values is not None: | ||
| past_length = past_key_values[0][0].shape[2] | ||
| # Some generation methods already pass only the last input ID | ||
| if input_ids.shape[1] > past_length: | ||
| remove_prefix_length = past_length | ||
| else: | ||
| # Default to old behavior: keep only final ID | ||
| remove_prefix_length = input_ids.shape[1] - 1 | ||
| input_ids = input_ids[:, remove_prefix_length:] | ||
| attention_mask = kwargs.get("attention_mask", None) | ||
| use_cache = kwargs.get("use_cache", None) | ||
| return { | ||
| "input_ids": input_ids, | ||
| "past_key_values": past_key_values, | ||
| "use_cache": use_cache, | ||
| "position_ids": None, | ||
| "attention_mask": attention_mask, | ||
| } | ||
| class ORTMPTForCausalLM(ORTModelForCausalLM): | ||
| # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation | ||
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): | ||
| if past_key_values is not None: | ||
| past_length = past_key_values[0][0].shape[2] | ||
| # Some generation methods already pass only the last input ID | ||
| if input_ids.shape[1] > past_length: | ||
| remove_prefix_length = past_length | ||
| else: | ||
| # Default to old behavior: keep only final ID | ||
| remove_prefix_length = input_ids.shape[1] - 1 | ||
| input_ids = input_ids[:, remove_prefix_length:] | ||
| attention_mask = kwargs.get("attention_mask", None) | ||
| use_cache = kwargs.get("use_cache", None) | ||
| return { | ||
| "input_ids": input_ids, | ||
| "past_key_values": past_key_values, | ||
| "use_cache": use_cache, | ||
| "position_ids": None, | ||
| "attention_mask": attention_mask, | ||
| } | ||
| class ORTFalconForCausalLM(ORTModelForCausalLM): | ||
| def __init__( | ||
| self, | ||
| model: onnxruntime.InferenceSession, | ||
| config: "PretrainedConfig", | ||
| use_io_binding: Optional[bool] = None, | ||
| model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, | ||
| preprocessors: Optional[List] = None, | ||
| generation_config: Optional[GenerationConfig] = None, | ||
| use_cache: Optional[bool] = None, | ||
| **kwargs, | ||
| ): | ||
| super().__init__( | ||
| model=model, | ||
| config=config, | ||
| use_io_binding=use_io_binding, | ||
| model_save_dir=model_save_dir, | ||
| preprocessors=preprocessors, | ||
| generation_config=generation_config, | ||
| use_cache=use_cache, | ||
| **kwargs, | ||
| ) | ||
| self.num_key_value_heads = ( | ||
| config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1 | ||
| ) | ||
| self.use_alibi = config.alibi | ||
| # Copied from transformers.models.falcon.modeling_falcon.FalconForCausalLM._reorder_cache | ||
| def _reorder_cache( | ||
| self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor | ||
| ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: | ||
| def _save_config(self, save_directory): | ||
| """ | ||
| This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or | ||
| [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct | ||
| beam_idx at every generation step. | ||
| Save the model and generation configs to the specified directory. | ||
| Output shares the same memory storage as `past`. | ||
| Args: | ||
| save_directory (`str` or `os.PathLike`): | ||
| Directory where the model and generation configs will be saved. | ||
| """ | ||
| # Get a copy of `beam_idx` on all the devices where we need those indices. | ||
| device_to_beam_idx = { | ||
| past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past | ||
| } | ||
| reordered_past = tuple( | ||
| ( | ||
| layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), | ||
| layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), | ||
| ) | ||
| for layer_past in past | ||
| ) | ||
| return reordered_past | ||
| # Adapted from transformers.models.falcon.modeling_falcon.FalconForCausalLM.prepare_inputs_for_generation | ||
| def prepare_inputs_for_generation( | ||
| self, | ||
| input_ids: torch.LongTensor, | ||
| past_key_values: Optional[torch.Tensor] = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| position_ids: Optional[torch.Tensor] = None, | ||
| **kwargs, | ||
| ) -> dict: | ||
| if past_key_values is not None: | ||
| past_length = past_key_values[0][0].shape[2] | ||
| # Some generation methods already pass only the last input ID | ||
| if input_ids.shape[1] > past_length: | ||
| remove_prefix_length = past_length | ||
| else: | ||
| # Default to old behavior: keep only final ID | ||
| remove_prefix_length = input_ids.shape[1] - 1 | ||
| input_ids = input_ids[:, remove_prefix_length:] | ||
| # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. | ||
| if not self.use_alibi and attention_mask is not None and position_ids is None: | ||
| # create position_ids on the fly for batch generation | ||
| position_ids = attention_mask.long().cumsum(-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 1) | ||
| if past_key_values: | ||
| position_ids = position_ids[:, -input_ids.shape[1] :] | ||
| return { | ||
| "input_ids": input_ids, | ||
| "position_ids": position_ids, | ||
| "past_key_values": past_key_values, | ||
| "use_cache": kwargs.get("use_cache"), | ||
| "attention_mask": attention_mask, | ||
| } | ||
| self.config.save_pretrained(save_directory) | ||
| self.generation_config.save_pretrained(save_directory) |
@@ -19,8 +19,6 @@ # Copyright 2023 The HuggingFace Team. All rights reserved. | ||
| import os | ||
| import shutil | ||
| from abc import abstractmethod | ||
| from collections import OrderedDict | ||
| from pathlib import Path | ||
| from tempfile import TemporaryDirectory | ||
| from typing import Any, Dict, Optional, Union | ||
| from typing import Any, Dict, Optional, Sequence, Union | ||
@@ -47,4 +45,4 @@ import numpy as np | ||
| from diffusers.utils.constants import CONFIG_NAME | ||
| from huggingface_hub import HfApi | ||
| from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE | ||
| from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card | ||
| from huggingface_hub import HfApi, create_repo | ||
| from huggingface_hub.utils import validate_hf_hub_args | ||
@@ -56,7 +54,5 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer | ||
| import onnxruntime as ort | ||
| from optimum.utils import is_diffusers_version | ||
| from onnxruntime import InferenceSession, SessionOptions | ||
| from ..exporters.onnx import main_export | ||
| from ..onnx.utils import _get_model_external_data_paths | ||
| from ..utils import ( | ||
@@ -70,12 +66,8 @@ DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, | ||
| DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, | ||
| ) | ||
| from .io_binding import TypeHelper | ||
| from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel | ||
| from .utils import ( | ||
| DIFFUSION_PIPELINE_CONFIG_FILE_NAME, | ||
| ONNX_WEIGHTS_NAME, | ||
| get_provider_for_device, | ||
| np_to_pt_generators, | ||
| parse_device, | ||
| validate_provider_availability, | ||
| is_diffusers_version, | ||
| ) | ||
| from .base import ORTParentMixin, ORTSessionMixin | ||
| from .utils import get_device_for_provider, np_to_pt_generators, prepare_providers_and_provider_options | ||
@@ -93,6 +85,7 @@ | ||
| # TODO: support from_pipe() | ||
| # TODO: Instead of ORTModel, it makes sense to have a compositional ORTMixin | ||
| # TODO: instead of one bloated __init__, we should consider an __init__ per pipeline | ||
| class ORTDiffusionPipeline(ORTModel, DiffusionPipeline): | ||
| config_name = "model_index.json" | ||
| class ORTDiffusionPipeline(ORTParentMixin, DiffusionPipeline): | ||
| config_name = DIFFUSION_PIPELINE_CONFIG_FILE_NAME | ||
| task = "auto" | ||
| library = "diffusers" | ||
| auto_model_class = DiffusionPipeline | ||
@@ -102,12 +95,13 @@ | ||
| self, | ||
| scheduler: "SchedulerMixin", | ||
| vae_decoder_session: ort.InferenceSession, | ||
| # optional pipeline models | ||
| unet_session: Optional[ort.InferenceSession] = None, | ||
| transformer_session: Optional[ort.InferenceSession] = None, | ||
| vae_encoder_session: Optional[ort.InferenceSession] = None, | ||
| text_encoder_session: Optional[ort.InferenceSession] = None, | ||
| text_encoder_2_session: Optional[ort.InferenceSession] = None, | ||
| text_encoder_3_session: Optional[ort.InferenceSession] = None, | ||
| # optional pipeline submodels | ||
| *, | ||
| # pipeline models | ||
| unet_session: Optional["InferenceSession"] = None, | ||
| transformer_session: Optional["InferenceSession"] = None, | ||
| vae_decoder_session: Optional["InferenceSession"] = None, | ||
| vae_encoder_session: Optional["InferenceSession"] = None, | ||
| text_encoder_session: Optional["InferenceSession"] = None, | ||
| text_encoder_2_session: Optional["InferenceSession"] = None, | ||
| text_encoder_3_session: Optional["InferenceSession"] = None, | ||
| # pipeline submodels | ||
| scheduler: Optional["SchedulerMixin"] = None, | ||
| tokenizer: Optional["CLIPTokenizer"] = None, | ||
@@ -126,22 +120,57 @@ tokenizer_2: Optional["CLIPTokenizer"] = None, | ||
| ): | ||
| self.unet = ORTModelUnet(unet_session, self) if unet_session is not None else None | ||
| self.transformer = ORTModelTransformer(transformer_session, self) if transformer_session is not None else None | ||
| # We initialize all ort session mixins first | ||
| self.unet = ORTUnet(unet_session, self, use_io_binding) if unet_session is not None else None | ||
| self.transformer = ( | ||
| ORTTransformer(transformer_session, self, use_io_binding) if transformer_session is not None else None | ||
| ) | ||
| self.text_encoder = ( | ||
| ORTModelTextEncoder(text_encoder_session, self) if text_encoder_session is not None else None | ||
| ORTTextEncoder(text_encoder_session, self, use_io_binding) if text_encoder_session is not None else None | ||
| ) | ||
| self.text_encoder_2 = ( | ||
| ORTModelTextEncoder(text_encoder_2_session, self) if text_encoder_2_session is not None else None | ||
| ORTTextEncoder(text_encoder_2_session, self, use_io_binding) | ||
| if text_encoder_2_session is not None | ||
| else None | ||
| ) | ||
| self.text_encoder_3 = ( | ||
| ORTModelTextEncoder(text_encoder_3_session, self) if text_encoder_3_session is not None else None | ||
| ORTTextEncoder(text_encoder_3_session, self, use_io_binding) | ||
| if text_encoder_3_session is not None | ||
| else None | ||
| ) | ||
| # We wrap the VAE Decoder & Encoder in a single object to simulate diffusers API | ||
| self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) if vae_encoder_session is not None else None | ||
| self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) if vae_decoder_session is not None else None | ||
| self.vae = ORTWrapperVae(self.vae_encoder, self.vae_decoder) | ||
| self.vae_encoder = ( | ||
| ORTVaeEncoder(vae_encoder_session, self, use_io_binding) if vae_encoder_session is not None else None | ||
| ) | ||
| self.vae_decoder = ( | ||
| ORTVaeDecoder(vae_decoder_session, self, use_io_binding) if vae_decoder_session is not None else None | ||
| ) | ||
| # We register ort session mixins to the wrapper | ||
| super().initialize_ort_attributes( | ||
| parts=list( | ||
| filter( | ||
| None, | ||
| { | ||
| self.unet, | ||
| self.transformer, | ||
| self.vae_encoder, | ||
| self.vae_decoder, | ||
| self.text_encoder, | ||
| self.text_encoder_2, | ||
| self.text_encoder_3, | ||
| }, | ||
| ) | ||
| ) | ||
| ) | ||
| # We wrap the VAE Encoder & Decoder in a single object for convenience | ||
| self.vae = ( | ||
| ORTVae(self.vae_encoder, self.vae_decoder) | ||
| if self.vae_encoder is not None or self.vae_decoder is not None | ||
| else None | ||
| ) | ||
| # we allow passing these as torch models for now | ||
| self.image_encoder = kwargs.pop("image_encoder", None) # TODO: maybe implement ORTModelImageEncoder | ||
| self.safety_checker = kwargs.pop("safety_checker", None) # TODO: maybe implement ORTModelSafetyChecker | ||
| self.image_encoder = kwargs.pop("image_encoder", None) # TODO: maybe implement ORTImageEncoder | ||
| self.safety_checker = kwargs.pop("safety_checker", None) # TODO: maybe implement ORTSafetyChecker | ||
| # We register the submodels to the pipeline | ||
| self.scheduler = scheduler | ||
@@ -153,2 +182,3 @@ self.tokenizer = tokenizer | ||
| # We initialize diffusers pipeline specific attributes (registers modules and config) | ||
| all_pipeline_init_args = { | ||
@@ -172,3 +202,2 @@ "vae": self.vae, | ||
| } | ||
| diffusers_pipeline_args = {} | ||
@@ -178,128 +207,170 @@ for key in inspect.signature(self.auto_model_class).parameters.keys(): | ||
| diffusers_pipeline_args[key] = all_pipeline_init_args[key] | ||
| # inits diffusers pipeline specific attributes (registers modules and config) | ||
| self.auto_model_class.__init__(self, **diffusers_pipeline_args) | ||
| # inits ort specific attributes | ||
| self.shared_attributes_init( | ||
| model=unet_session if unet_session is not None else transformer_session, | ||
| use_io_binding=use_io_binding, | ||
| model_save_dir=model_save_dir, | ||
| **kwargs, | ||
| ) | ||
| # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it | ||
| # would end-up removing the directory containing the underlying ONNX model (and thus failing inference). | ||
| self.model_save_dir = model_save_dir | ||
| def _save_pretrained(self, save_directory: Union[str, Path]): | ||
| save_directory = Path(save_directory) | ||
| models_to_save_paths = { | ||
| (self.unet, save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER), | ||
| (self.transformer, save_directory / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER), | ||
| (self.vae_decoder, save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER), | ||
| (self.vae_encoder, save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER), | ||
| (self.text_encoder, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER), | ||
| (self.text_encoder_2, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER), | ||
| (self.text_encoder_3, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER), | ||
| @property | ||
| def components(self) -> Dict[str, Optional[Union[ORTSessionMixin, torch.nn.Module]]]: | ||
| # TODO: all components should be ORTSessionMixin's at some point | ||
| components = { | ||
| "vae": self.vae, | ||
| "unet": self.unet, | ||
| "transformer": self.transformer, | ||
| "text_encoder": self.text_encoder, | ||
| "text_encoder_2": self.text_encoder_2, | ||
| "text_encoder_3": self.text_encoder_3, | ||
| "safety_checker": self.safety_checker, | ||
| "image_encoder": self.image_encoder, | ||
| } | ||
| for model, save_path in models_to_save_paths: | ||
| if model is not None: | ||
| model_path = Path(model.session._model_path) | ||
| save_path.mkdir(parents=True, exist_ok=True) | ||
| # copy onnx model | ||
| shutil.copyfile(model_path, save_path / ONNX_WEIGHTS_NAME) | ||
| # copy external onnx data | ||
| external_data_paths = _get_model_external_data_paths(model_path) | ||
| for external_data_path in external_data_paths: | ||
| shutil.copyfile(external_data_path, save_path / external_data_path.name) | ||
| # copy model config | ||
| config_path = model_path.parent / CONFIG_NAME | ||
| if config_path.is_file(): | ||
| config_save_path = save_path / CONFIG_NAME | ||
| shutil.copyfile(config_path, config_save_path) | ||
| components = {k: v for k, v in components.items() if v is not None} | ||
| return components | ||
| self.scheduler.save_pretrained(save_directory / "scheduler") | ||
| def to(self, device: Union[torch.device, str, int]): | ||
| """ | ||
| Changes the device of the pipeline components to the specified device. | ||
| if self.tokenizer is not None: | ||
| self.tokenizer.save_pretrained(save_directory / "tokenizer") | ||
| if self.tokenizer_2 is not None: | ||
| self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2") | ||
| if self.tokenizer_3 is not None: | ||
| self.tokenizer_3.save_pretrained(save_directory / "tokenizer_3") | ||
| if self.feature_extractor is not None: | ||
| self.feature_extractor.save_pretrained(save_directory / "feature_extractor") | ||
| Args: | ||
| device (`torch.device` or `str` or `int`): | ||
| Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run | ||
| the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. | ||
| Returns: | ||
| `ORTDiffusionPipeline`: The pipeline with the updated device. | ||
| """ | ||
| for component in self.components.values(): | ||
| if isinstance(component, (ORTSessionMixin, ORTParentMixin)): | ||
| component.to(device) | ||
| return self | ||
| @classmethod | ||
| def _from_pretrained( | ||
| def from_pretrained( | ||
| cls, | ||
| model_id: Union[str, Path], | ||
| config: Dict[str, Any], | ||
| subfolder: str = "", | ||
| force_download: bool = False, | ||
| local_files_only: bool = False, | ||
| revision: Optional[str] = None, | ||
| trust_remote_code: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| token: Optional[Union[bool, str]] = None, | ||
| unet_file_name: str = ONNX_WEIGHTS_NAME, | ||
| transformer_file_name: str = ONNX_WEIGHTS_NAME, | ||
| vae_decoder_file_name: str = ONNX_WEIGHTS_NAME, | ||
| vae_encoder_file_name: str = ONNX_WEIGHTS_NAME, | ||
| text_encoder_file_name: str = ONNX_WEIGHTS_NAME, | ||
| text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME, | ||
| text_encoder_3_file_name: str = ONNX_WEIGHTS_NAME, | ||
| model_name_or_path: Union[str, Path], | ||
| # export options | ||
| export: bool = False, | ||
| # session options | ||
| provider: str = "CPUExecutionProvider", | ||
| providers: Optional[Sequence[str]] = None, | ||
| provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None, | ||
| session_options: Optional[SessionOptions] = None, | ||
| # inference options | ||
| use_io_binding: Optional[bool] = None, | ||
| provider: str = "CPUExecutionProvider", | ||
| provider_options: Optional[Dict[str, Any]] = None, | ||
| session_options: Optional[ort.SessionOptions] = None, | ||
| model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, | ||
| # hub options and preloaded models | ||
| **kwargs, | ||
| ): | ||
| if use_io_binding: | ||
| raise ValueError( | ||
| "IOBinding is not yet available for diffusion pipelines, please set `use_io_binding` to False." | ||
| """ | ||
| Instantiates a [`ORTDiffusionPipeline`] with ONNX Runtime sessions from a pretrained model. | ||
| This method can be used to load a model from the Hugging Face Hub or from a local directory. | ||
| Args: | ||
| model_name_or_path (`str` or `os.PathLike`): | ||
| Path to a folder containing the model files or a hub repository id. | ||
| export (`bool`, *optional*, defaults to `False`): | ||
| Whether to export the model to ONNX format. If set to `True`, the model will be exported and saved | ||
| in the specified directory. | ||
| provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`): | ||
| The execution provider for ONNX Runtime. Can be `"CUDAExecutionProvider"`, `"DmlExecutionProvider"`, | ||
| etc. | ||
| providers (`Sequence[str]`, *optional*): | ||
| A list of execution providers for ONNX Runtime. Overrides `provider`. | ||
| provider_options (`Union[Sequence[Dict[str, Any]], Dict[str, Any]]`, *optional*): | ||
| Options for each execution provider. Can be a single dictionary for the first provider or a list of | ||
| dictionaries for each provider. The order of the dictionaries should match the order of the providers. | ||
| session_options (`SessionOptions`, *optional*): | ||
| Options for the ONNX Runtime session. Can be used to set optimization levels, graph optimization, | ||
| etc. | ||
| use_io_binding (`bool`, *optional*): | ||
| Whether to use IOBinding for the ONNX Runtime session. If set to `True`, it will use IOBinding for | ||
| input and output tensors. | ||
| **kwargs: | ||
| Can include the following: | ||
| - Export arguments (e.g., `slim`, `dtype`, `device`, `no_dynamic_axes`, etc.). | ||
| - Hugging Face Hub arguments (e.g., `revision`, `cache_dir`, `force_download`, etc.). | ||
| - Preloaded models or sessions for the different components of the pipeline (e.g., `vae_encoder_session`, | ||
| `vae_decoder_session`, `unet_session`, `transformer_session`, `image_encoder`, `safety_checker`, etc.). | ||
| Returns: | ||
| [`ORTDiffusionPipeline`]: The loaded pipeline with ONNX Runtime sessions. | ||
| """ | ||
| providers, provider_options = prepare_providers_and_provider_options( | ||
| provider=provider, providers=providers, provider_options=provider_options | ||
| ) | ||
| hub_kwargs = { | ||
| "force_download": kwargs.get("force_download", False), | ||
| "resume_download": kwargs.get("resume_download", None), | ||
| "local_files_only": kwargs.get("local_files_only", False), | ||
| "cache_dir": kwargs.get("cache_dir", None), | ||
| "revision": kwargs.get("revision", None), | ||
| "proxies": kwargs.get("proxies", None), | ||
| "token": kwargs.get("token", None), | ||
| } | ||
| # get the pipeline config | ||
| config = cls.load_config(model_name_or_path, **hub_kwargs) | ||
| config = config[0] if isinstance(config, tuple) else config | ||
| model_save_tmpdir = None | ||
| model_save_path = Path(model_name_or_path) | ||
| # export the model if requested | ||
| if export: | ||
| model_save_tmpdir = TemporaryDirectory() | ||
| model_save_path = Path(model_save_tmpdir.name) | ||
| export_kwargs = { | ||
| "slim": kwargs.pop("slim", False), | ||
| "dtype": kwargs.pop("dtype", None), | ||
| "device": get_device_for_provider(provider, {}).type, | ||
| "no_dynamic_axes": kwargs.pop("no_dynamic_axes", False), | ||
| } | ||
| main_export( | ||
| model_name_or_path=model_name_or_path, | ||
| # export related arguments | ||
| output=model_save_path, | ||
| no_post_process=True, | ||
| do_validation=False, | ||
| task=cls.task, | ||
| # export related arguments | ||
| **export_kwargs, | ||
| # hub related arguments | ||
| **hub_kwargs, | ||
| ) | ||
| if not os.path.isdir(str(model_id)): | ||
| # download the model if needed | ||
| if not model_save_path.is_dir(): | ||
| # everything in components subfolders | ||
| all_components = {key for key in config.keys() if not key.startswith("_")} | {"vae_encoder", "vae_decoder"} | ||
| allow_patterns = {os.path.join(component, "*") for component in all_components} | ||
| # plus custom file names | ||
| allow_patterns.update( | ||
| { | ||
| unet_file_name, | ||
| transformer_file_name, | ||
| vae_decoder_file_name, | ||
| vae_encoder_file_name, | ||
| text_encoder_file_name, | ||
| text_encoder_2_file_name, | ||
| text_encoder_3_file_name, | ||
| ONNX_WEIGHTS_NAME, | ||
| DIFFUSION_PIPELINE_CONFIG_FILE_NAME, | ||
| SCHEDULER_CONFIG_NAME, | ||
| cls.config_name, | ||
| CONFIG_NAME, | ||
| } | ||
| ) | ||
| model_save_folder = HfApi(user_agent=http_user_agent(), token=token).snapshot_download( | ||
| model_id, | ||
| token=token, | ||
| revision=revision, | ||
| cache_dir=cache_dir, | ||
| force_download=force_download, | ||
| local_files_only=local_files_only, | ||
| model_save_folder = HfApi(user_agent=http_user_agent()).snapshot_download( | ||
| repo_id=str(model_name_or_path), | ||
| allow_patterns=allow_patterns, | ||
| ignore_patterns=["*.msgpack", "*.safetensors", "*.bin", "*.xml"], | ||
| allow_patterns=allow_patterns, | ||
| **hub_kwargs, | ||
| ) | ||
| else: | ||
| model_save_folder = str(model_id) | ||
| model_save_path = Path(model_save_folder) | ||
| model_save_path = Path(model_save_folder) | ||
| if model_save_dir is None: | ||
| model_save_dir = model_save_path | ||
| model_paths = { | ||
| "unet": model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, | ||
| "transformer": model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER / transformer_file_name, | ||
| "vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, | ||
| "vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, | ||
| "text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, | ||
| "text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name, | ||
| "text_encoder_3": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER / text_encoder_3_file_name, | ||
| "unet": model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME, | ||
| "transformer": model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER / ONNX_WEIGHTS_NAME, | ||
| "vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME, | ||
| "vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME, | ||
| "text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME, | ||
| "text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / ONNX_WEIGHTS_NAME, | ||
| "text_encoder_3": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER / ONNX_WEIGHTS_NAME, | ||
| } | ||
| models = {} | ||
| sessions = {} | ||
@@ -309,6 +380,12 @@ for model, path in model_paths.items(): | ||
| # this allows passing a model directly to from_pretrained | ||
| sessions[f"{model}_session"] = kwargs.pop(model) | ||
| else: | ||
| sessions[f"{model}_session"] = ( | ||
| ORTModel.load_model(path, provider, session_options, provider_options) if path.is_file() else None | ||
| models[model] = kwargs.pop(model) | ||
| elif kwargs.get(f"{model}_session", None) is not None: | ||
| # this allows passing a session directly to from_pretrained | ||
| sessions[f"{model}_session"] = kwargs.pop(f"{model}_session") | ||
| elif path.is_file(): | ||
| sessions[f"{model}_session"] = InferenceSession( | ||
| path, | ||
| providers=providers, | ||
| provider_options=provider_options, | ||
| sess_options=session_options, | ||
| ) | ||
@@ -331,6 +408,6 @@ | ||
| # same as DiffusionPipeline.from_pretraoned, if called directly, it loads the class in the config | ||
| # Same as DiffusionPipeline.from_pretrained | ||
| if cls.__name__ == "ORTDiffusionPipeline": | ||
| class_name = config["_class_name"] | ||
| ort_pipeline_class = _get_ort_class(class_name) | ||
| pipeline_class_name = config["_class_name"] | ||
| ort_pipeline_class = _get_ort_class(pipeline_class_name) | ||
| else: | ||
@@ -343,153 +420,108 @@ ort_pipeline_class = cls | ||
| use_io_binding=use_io_binding, | ||
| model_save_dir=model_save_dir, | ||
| model_save_dir=model_save_tmpdir, | ||
| **models, | ||
| **kwargs, | ||
| ) | ||
| # same as in DiffusionPipeline.from_pretrained, we save where the model was instantiated from | ||
| ort_pipeline.register_to_config(_name_or_path=config.get("_name_or_path", str(model_id))) | ||
| ort_pipeline.register_to_config(**config) | ||
| ort_pipeline.register_to_config(_name_or_path=config.get("_name_or_path", model_name_or_path)) | ||
| return ort_pipeline | ||
| @classmethod | ||
| def _export( | ||
| cls, | ||
| model_id: str, | ||
| config: Dict[str, Any], | ||
| subfolder: str = "", | ||
| force_download: bool = False, | ||
| local_files_only: bool = False, | ||
| revision: Optional[str] = None, | ||
| trust_remote_code: bool = False, | ||
| cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
| token: Optional[Union[bool, str]] = None, | ||
| use_io_binding: Optional[bool] = None, | ||
| provider: str = "CPUExecutionProvider", | ||
| session_options: Optional[ort.SessionOptions] = None, | ||
| provider_options: Optional[Dict[str, Any]] = None, | ||
| task: Optional[str] = None, | ||
| def save_pretrained( | ||
| self, | ||
| save_directory: Union[str, Path], | ||
| push_to_hub: Optional[bool] = False, | ||
| **kwargs, | ||
| ) -> "ORTDiffusionPipeline": | ||
| if task is None: | ||
| task = cls._auto_model_to_task(cls.auto_model_class) | ||
| # we continue passing the model_save_dir from here on to avoid it being cleaned up | ||
| # might be better to use a persistent temporary directory such as the one implemented in | ||
| # https://gist.github.com/twolfson/2929dc1163b0a76d2c2b66d51f9bc808 | ||
| model_save_dir = TemporaryDirectory() | ||
| model_save_path = Path(model_save_dir.name) | ||
| main_export( | ||
| model_id, | ||
| output=model_save_path, | ||
| do_validation=False, | ||
| no_post_process=True, | ||
| token=token, | ||
| revision=revision, | ||
| cache_dir=cache_dir, | ||
| subfolder=subfolder, | ||
| force_download=force_download, | ||
| local_files_only=local_files_only, | ||
| trust_remote_code=trust_remote_code, | ||
| library_name="diffusers", | ||
| task=task, | ||
| ) | ||
| return cls._from_pretrained( | ||
| model_save_path, | ||
| config=config, | ||
| provider=provider, | ||
| provider_options=provider_options, | ||
| session_options=session_options, | ||
| use_io_binding=use_io_binding, | ||
| model_save_dir=model_save_dir, | ||
| **kwargs, | ||
| ) | ||
| def to(self, device: Union[torch.device, str, int]): | ||
| ): | ||
| """ | ||
| Changes the ONNX Runtime provider according to the device. | ||
| Saves a model and its configuration file to a directory, so that it can be re-loaded using the | ||
| [`from_pretrained`] class method. | ||
| Args: | ||
| device (`torch.device` or `str` or `int`): | ||
| Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run | ||
| the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. | ||
| Returns: | ||
| `ORTModel`: the model placed on the requested device. | ||
| save_directory (`Union[str, os.PathLike]`): | ||
| Directory to which to save. Will be created if it doesn't exist. | ||
| push_to_hub (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to push your model to the Hugging Face model hub after saving it. | ||
| **kwargs: | ||
| Additional keyword arguments passed along to [`~huggingface_hub.create_repo`] and | ||
| [`~huggingface_hub.HfApi.upload_folder`] if `push_to_hub` is set to `True`. | ||
| """ | ||
| device, provider_options = parse_device(device) | ||
| provider = get_provider_for_device(device) | ||
| validate_provider_availability(provider) | ||
| model_save_path = Path(save_directory) | ||
| model_save_path.mkdir(parents=True, exist_ok=True) | ||
| if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": | ||
| return self | ||
| if push_to_hub: | ||
| token = kwargs.pop("token", None) | ||
| private = kwargs.pop("private", False) | ||
| create_pr = kwargs.pop("create_pr", False) | ||
| commit_message = kwargs.pop("commit_message", None) | ||
| repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) | ||
| repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id | ||
| self.vae_decoder.session.set_providers([provider], provider_options=[provider_options]) | ||
| self.save_config(model_save_path) | ||
| self.scheduler.save_pretrained(model_save_path / "scheduler") | ||
| if self.unet is not None: | ||
| self.unet.session.set_providers([provider], provider_options=[provider_options]) | ||
| self.unet.save_pretrained(model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER) | ||
| if self.transformer is not None: | ||
| self.transformer.session.set_providers([provider], provider_options=[provider_options]) | ||
| self.transformer.save_pretrained(model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER) | ||
| if self.vae_encoder is not None: | ||
| self.vae_encoder.session.set_providers([provider], provider_options=[provider_options]) | ||
| self.vae_encoder.save_pretrained(model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER) | ||
| if self.vae_decoder is not None: | ||
| self.vae_decoder.save_pretrained(model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER) | ||
| if self.text_encoder is not None: | ||
| self.text_encoder.session.set_providers([provider], provider_options=[provider_options]) | ||
| self.text_encoder.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER) | ||
| if self.text_encoder_2 is not None: | ||
| self.text_encoder_2.session.set_providers([provider], provider_options=[provider_options]) | ||
| self.text_encoder_2.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER) | ||
| if self.text_encoder_3 is not None: | ||
| self.text_encoder_3.session.set_providers([provider], provider_options=[provider_options]) | ||
| self.text_encoder_3.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER) | ||
| self.providers = ( | ||
| self.unet.session.get_providers() if self.unet is not None else self.transformer.session.get_providers() | ||
| ) | ||
| self._device = device | ||
| if self.image_encoder is not None: | ||
| self.image_encoder.save_pretrained(model_save_path / "image_encoder") | ||
| if self.safety_checker is not None: | ||
| self.safety_checker.save_pretrained(model_save_path / "safety_checker") | ||
| return self | ||
| if self.tokenizer is not None: | ||
| self.tokenizer.save_pretrained(model_save_path / "tokenizer") | ||
| if self.tokenizer_2 is not None: | ||
| self.tokenizer_2.save_pretrained(model_save_path / "tokenizer_2") | ||
| if self.tokenizer_3 is not None: | ||
| self.tokenizer_3.save_pretrained(model_save_path / "tokenizer_3") | ||
| if self.feature_extractor is not None: | ||
| self.feature_extractor.save_pretrained(model_save_path / "feature_extractor") | ||
| @classmethod | ||
| def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs): | ||
| return cls.load_config(config_name_or_path, **kwargs) | ||
| if push_to_hub: | ||
| # Create a new empty model card and eventually tag it | ||
| model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True) | ||
| model_card = populate_model_card(model_card) | ||
| model_card.save(os.path.join(save_directory, "README.md")) | ||
| self._upload_folder( | ||
| save_directory, | ||
| repo_id, | ||
| token=token, | ||
| create_pr=create_pr, | ||
| commit_message=commit_message, | ||
| ) | ||
| def _save_config(self, save_directory: Union[str, Path]): | ||
| model_dir = ( | ||
| self.model_save_dir | ||
| if not isinstance(self.model_save_dir, TemporaryDirectory) | ||
| else self.model_save_dir.name | ||
| ) | ||
| save_dir = Path(save_directory) | ||
| original_config = Path(model_dir) / self.config_name | ||
| if original_config.exists(): | ||
| if not save_dir.exists(): | ||
| save_dir.mkdir(parents=True) | ||
| shutil.copy(original_config, save_dir) | ||
| else: | ||
| self.save_config(save_directory) | ||
| @property | ||
| def components(self) -> Dict[str, Any]: | ||
| components = { | ||
| "vae": self.vae, | ||
| "unet": self.unet, | ||
| "transformer": self.transformer, | ||
| "text_encoder": self.text_encoder, | ||
| "text_encoder_2": self.text_encoder_2, | ||
| "text_encoder_3": self.text_encoder_3, | ||
| "safety_checker": self.safety_checker, | ||
| "image_encoder": self.image_encoder, | ||
| } | ||
| components = {k: v for k, v in components.items() if v is not None} | ||
| return components | ||
| def __call__(self, *args, **kwargs): | ||
| # we do this to keep numpy random states support for now | ||
| # TODO: deprecate and add warnings when a random state is passed | ||
| args = list(args) | ||
| for i in range(len(args)): | ||
| args[i] = np_to_pt_generators(args[i], self.device) | ||
| new_args = np_to_pt_generators(args[i], self.device) | ||
| if args[i] is not new_args: | ||
| logger.warning( | ||
| "Converting numpy random state to torch generator is deprecated. " | ||
| "Please pass a torch generator directly to the pipeline." | ||
| ) | ||
| for k, v in kwargs.items(): | ||
| kwargs[k] = np_to_pt_generators(v, self.device) | ||
| for key, value in kwargs.items(): | ||
| new_value = np_to_pt_generators(value, self.device) | ||
| if value is not new_value: | ||
| logger.warning( | ||
| "Converting numpy random state to torch generator is deprecated. " | ||
| "Please pass a torch generator directly to the pipeline." | ||
| ) | ||
| kwargs[key] = new_value | ||
@@ -499,18 +531,14 @@ return self.auto_model_class.__call__(self, *args, **kwargs) | ||
| class ORTPipelinePart(ConfigMixin): | ||
| class ORTModelMixin(ORTSessionMixin, ConfigMixin): | ||
| config_name: str = CONFIG_NAME | ||
| def __init__(self, session: ort.InferenceSession, parent_pipeline: ORTDiffusionPipeline): | ||
| self.session = session | ||
| self.parent_pipeline = parent_pipeline | ||
| def __init__( | ||
| self, | ||
| session: "InferenceSession", | ||
| parent: "ORTDiffusionPipeline", | ||
| use_io_binding: Optional[bool] = None, | ||
| ): | ||
| self.initialize_ort_attributes(session, use_io_binding=use_io_binding) | ||
| self.parent = parent | ||
| self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} | ||
| self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} | ||
| self.input_dtypes = {input_key.name: input_key.type for input_key in self.session.get_inputs()} | ||
| self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()} | ||
| self.input_shapes = {input_key.name: input_key.shape for input_key in self.session.get_inputs()} | ||
| self.output_shapes = {output_key.name: output_key.shape for output_key in self.session.get_outputs()} | ||
| config_file_path = Path(session._model_path).parent / self.config_name | ||
@@ -523,81 +551,18 @@ if not config_file_path.is_file(): | ||
| @property | ||
| def device(self): | ||
| return self.parent_pipeline.device | ||
| def save_pretrained(self, save_directory: Union[str, Path]): | ||
| """ | ||
| Saves the ONNX model and its configuration file to a directory, so that it can be re-loaded using the | ||
| [`from_pretrained`] class method. | ||
| @property | ||
| def dtype(self): | ||
| for dtype in self.input_dtypes.values(): | ||
| torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) | ||
| if torch_dtype.is_floating_point: | ||
| return torch_dtype | ||
| Args: | ||
| save_directory (`Union[str, os.PathLike]`): | ||
| Directory to which to save. Will be created if it doesn't exist. | ||
| """ | ||
| # save onnx model and external data | ||
| self.save_session(save_directory) | ||
| # save model configuration | ||
| self.save_config(save_directory) | ||
| for dtype in self.output_dtypes.values(): | ||
| torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) | ||
| if torch_dtype.is_floating_point: | ||
| return torch_dtype | ||
| return None | ||
| def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None): | ||
| for arg in args: | ||
| if isinstance(arg, torch.device): | ||
| device = arg | ||
| elif isinstance(arg, (int, str)): | ||
| device = torch.device(arg) | ||
| elif isinstance(arg, torch.dtype): | ||
| dtype = arg | ||
| if device is not None and device != self.device: | ||
| raise ValueError( | ||
| "Cannot change the device of a pipeline part without changing the device of the parent pipeline. " | ||
| "Please use the `to` method of the parent pipeline to change the device." | ||
| ) | ||
| if dtype is not None and dtype != self.dtype: | ||
| raise NotImplementedError( | ||
| f"Cannot change the dtype of the pipeline from {self.dtype} to {dtype}. " | ||
| f"Please export the pipeline with the desired dtype." | ||
| ) | ||
| def prepare_onnx_inputs(self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]) -> Dict[str, np.ndarray]: | ||
| onnx_inputs = {} | ||
| # converts pytorch inputs into numpy inputs for onnx | ||
| for input_name in self.input_names.keys(): | ||
| onnx_inputs[input_name] = inputs.pop(input_name) | ||
| if use_torch: | ||
| onnx_inputs[input_name] = onnx_inputs[input_name].numpy(force=True) | ||
| if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]: | ||
| onnx_inputs[input_name] = onnx_inputs[input_name].astype( | ||
| TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]) | ||
| ) | ||
| return onnx_inputs | ||
| def prepare_onnx_outputs( | ||
| self, use_torch: bool, *onnx_outputs: np.ndarray | ||
| ) -> Dict[str, Union[torch.Tensor, np.ndarray]]: | ||
| model_outputs = {} | ||
| # converts onnxruntime outputs into tensor for standard outputs | ||
| for output_name, idx in self.output_names.items(): | ||
| model_outputs[output_name] = onnx_outputs[idx] | ||
| if use_torch: | ||
| model_outputs[output_name] = torch.from_numpy(model_outputs[output_name]).to(self.device) | ||
| return model_outputs | ||
| @abstractmethod | ||
| def forward(self, *args, **kwargs): | ||
| raise NotImplementedError | ||
| def __call__(self, *args, **kwargs): | ||
| return self.forward(*args, **kwargs) | ||
| class ORTModelUnet(ORTPipelinePart): | ||
| class ORTUnet(ORTModelMixin): | ||
| def __init__(self, *args, **kwargs): | ||
@@ -645,6 +610,30 @@ super().__init__(*args, **kwargs) | ||
| onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) | ||
| if self.use_io_binding: | ||
| known_output_shapes = {"out_sample": sample.shape} | ||
| known_output_buffers = None | ||
| if "LatentConsistencyModel" not in self.parent.__class__.__name__: | ||
| known_output_buffers = {"out_sample": sample} | ||
| output_shapes, output_buffers = self._prepare_io_binding( | ||
| model_inputs, | ||
| known_output_shapes=known_output_shapes, | ||
| known_output_buffers=known_output_buffers, | ||
| ) | ||
| if self.device.type == "cpu": | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| else: | ||
| self._io_binding.synchronize_inputs() | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| self._io_binding.synchronize_outputs() | ||
| model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names} | ||
| else: | ||
| onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) | ||
| model_outputs["sample"] = model_outputs.pop("out_sample") | ||
| if not return_dict: | ||
@@ -656,3 +645,3 @@ return tuple(model_outputs.values()) | ||
| class ORTModelTransformer(ORTPipelinePart): | ||
| class ORTTransformer(ORTModelMixin): | ||
| def forward( | ||
@@ -683,6 +672,30 @@ self, | ||
| onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) | ||
| if self.use_io_binding: | ||
| known_output_shapes = {"out_hidden_states": hidden_states.shape} | ||
| known_output_buffers = None | ||
| if "Flux" not in self.parent.__class__.__name__: | ||
| known_output_buffers = {"out_hidden_states": hidden_states} | ||
| output_shapes, output_buffers = self._prepare_io_binding( | ||
| model_inputs, | ||
| known_output_shapes=known_output_shapes, | ||
| known_output_buffers=known_output_buffers, | ||
| ) | ||
| if self.device.type == "cpu": | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| else: | ||
| self._io_binding.synchronize_inputs() | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| self._io_binding.synchronize_outputs() | ||
| model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names} | ||
| else: | ||
| onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) | ||
| model_outputs["hidden_states"] = model_outputs.pop("out_hidden_states") | ||
| if not return_dict: | ||
@@ -694,3 +707,3 @@ return tuple(model_outputs.values()) | ||
| class ORTModelTextEncoder(ORTPipelinePart): | ||
| class ORTTextEncoder(ORTModelMixin): | ||
| def forward( | ||
@@ -705,8 +718,22 @@ self, | ||
| model_inputs = {"input_ids": input_ids} | ||
| model_inputs = { | ||
| "input_ids": input_ids, | ||
| } | ||
| onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) | ||
| if self.use_io_binding: | ||
| output_shapes, output_buffers = self._prepare_io_binding(model_inputs) | ||
| if self.device.type == "cpu": | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| else: | ||
| self._io_binding.synchronize_inputs() | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| self._io_binding.synchronize_outputs() | ||
| model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names} | ||
| else: | ||
| onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) | ||
| if output_hidden_states: | ||
@@ -729,3 +756,3 @@ model_outputs["hidden_states"] = [] | ||
| class ORTModelVaeEncoder(ORTPipelinePart): | ||
| class ORTVaeEncoder(ORTModelMixin): | ||
| def __init__(self, *args, **kwargs): | ||
@@ -750,8 +777,22 @@ super().__init__(*args, **kwargs) | ||
| model_inputs = {"sample": sample} | ||
| model_inputs = { | ||
| "sample": sample, | ||
| } | ||
| onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) | ||
| if self.use_io_binding: | ||
| output_shapes, output_buffers = self._prepare_io_binding(model_inputs) | ||
| if self.device.type == "cpu": | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| else: | ||
| self._io_binding.synchronize_inputs() | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| self._io_binding.synchronize_outputs() | ||
| model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names} | ||
| else: | ||
| onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) | ||
| if "latent_sample" in model_outputs: | ||
@@ -771,3 +812,3 @@ model_outputs["latents"] = model_outputs.pop("latent_sample") | ||
| class ORTModelVaeDecoder(ORTPipelinePart): | ||
| class ORTVaeDecoder(ORTModelMixin): | ||
| def __init__(self, *args, **kwargs): | ||
@@ -792,8 +833,22 @@ super().__init__(*args, **kwargs) | ||
| model_inputs = {"latent_sample": latent_sample} | ||
| model_inputs = { | ||
| "latent_sample": latent_sample, | ||
| } | ||
| onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) | ||
| if self.use_io_binding: | ||
| output_shapes, output_buffers = self._prepare_io_binding(model_inputs) | ||
| if self.device.type == "cpu": | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| else: | ||
| self._io_binding.synchronize_inputs() | ||
| self.session.run_with_iobinding(self._io_binding) | ||
| self._io_binding.synchronize_outputs() | ||
| model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names} | ||
| else: | ||
| onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) | ||
| onnx_outputs = self.session.run(None, onnx_inputs) | ||
| model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) | ||
| if "latent_sample" in model_outputs: | ||
@@ -808,19 +863,9 @@ model_outputs["latents"] = model_outputs.pop("latent_sample") | ||
| class ORTWrapperVae(ORTPipelinePart): | ||
| def __init__(self, encoder: ORTModelVaeEncoder, decoder: ORTModelVaeDecoder): | ||
| class ORTVae(ORTParentMixin): | ||
| def __init__(self, encoder: Optional[ORTVaeEncoder] = None, decoder: Optional[ORTVaeDecoder] = None): | ||
| self.encoder = encoder | ||
| self.decoder = decoder | ||
| self.encoder = encoder | ||
| @property | ||
| def config(self): | ||
| return self.decoder.config | ||
| self.initialize_ort_attributes(parts=list(filter(None, {self.encoder, self.decoder}))) | ||
| @property | ||
| def dtype(self): | ||
| return self.decoder.dtype | ||
| @property | ||
| def device(self): | ||
| return self.decoder.device | ||
| def decode(self, *args, **kwargs): | ||
@@ -832,9 +877,14 @@ return self.decoder(*args, **kwargs) | ||
| def to(self, *args, **kwargs): | ||
| self.decoder.to(*args, **kwargs) | ||
| if self.encoder is not None: | ||
| self.encoder.to(*args, **kwargs) | ||
| @property | ||
| def config(self): | ||
| return self.decoder.config | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| ORT_PIPELINE_DOCSTRING = r""" | ||
| This Pipeline inherits from [`ORTDiffusionPipeline`] and is used to run inference with the ONNX Runtime. | ||
| The pipeline can be loaded from a pretrained pipeline using the [`ORTDiffusionPipeline.from_pretrained`] method. | ||
| """ | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusionPipeline(ORTDiffusionPipeline, StableDiffusionPipeline): | ||
@@ -845,8 +895,8 @@ """ | ||
| task = "text-to-image" | ||
| main_input_name = "prompt" | ||
| export_feature = "text-to-image" | ||
| auto_model_class = StableDiffusionPipeline | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusionImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionImg2ImgPipeline): | ||
@@ -857,8 +907,8 @@ """ | ||
| task = "image-to-image" | ||
| main_input_name = "image" | ||
| export_feature = "image-to-image" | ||
| auto_model_class = StableDiffusionImg2ImgPipeline | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusionInpaintPipeline(ORTDiffusionPipeline, StableDiffusionInpaintPipeline): | ||
@@ -869,8 +919,8 @@ """ | ||
| task = "inpainting" | ||
| main_input_name = "prompt" | ||
| export_feature = "inpainting" | ||
| auto_model_class = StableDiffusionInpaintPipeline | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusionXLPipeline(ORTDiffusionPipeline, StableDiffusionXLPipeline): | ||
@@ -881,4 +931,4 @@ """ | ||
| task = "text-to-image" | ||
| main_input_name = "prompt" | ||
| export_feature = "text-to-image" | ||
| auto_model_class = StableDiffusionXLPipeline | ||
@@ -900,3 +950,3 @@ | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusionXLImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionXLImg2ImgPipeline): | ||
@@ -907,4 +957,4 @@ """ | ||
| task = "image-to-image" | ||
| main_input_name = "prompt" | ||
| export_feature = "image-to-image" | ||
| auto_model_class = StableDiffusionXLImg2ImgPipeline | ||
@@ -940,3 +990,3 @@ | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusionXLInpaintPipeline(ORTDiffusionPipeline, StableDiffusionXLInpaintPipeline): | ||
@@ -948,3 +998,3 @@ """ | ||
| main_input_name = "image" | ||
| export_feature = "inpainting" | ||
| task = "inpainting" | ||
| auto_model_class = StableDiffusionXLInpaintPipeline | ||
@@ -980,3 +1030,3 @@ | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTLatentConsistencyModelPipeline(ORTDiffusionPipeline, LatentConsistencyModelPipeline): | ||
@@ -987,8 +1037,8 @@ """ | ||
| task = "text-to-image" | ||
| main_input_name = "prompt" | ||
| export_feature = "text-to-image" | ||
| auto_model_class = LatentConsistencyModelPipeline | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsistencyModelImg2ImgPipeline): | ||
@@ -999,4 +1049,4 @@ """ | ||
| task = "image-to-image" | ||
| main_input_name = "image" | ||
| export_feature = "image-to-image" | ||
| auto_model_class = LatentConsistencyModelImg2ImgPipeline | ||
@@ -1018,3 +1068,3 @@ | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusion3Pipeline(ORTDiffusionPipeline, StableDiffusion3Pipeline): | ||
@@ -1025,7 +1075,7 @@ """ | ||
| task = "text-to-image" | ||
| main_input_name = "prompt" | ||
| export_feature = "text-to-image" | ||
| auto_model_class = StableDiffusion3Pipeline | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusion3Img2ImgPipeline(ORTDiffusionPipeline, StableDiffusion3Img2ImgPipeline): | ||
@@ -1036,4 +1086,4 @@ """ | ||
| task = "image-to-image" | ||
| main_input_name = "image" | ||
| export_feature = "image-to-image" | ||
| auto_model_class = StableDiffusion3Img2ImgPipeline | ||
@@ -1053,3 +1103,3 @@ | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3InpaintPipeline): | ||
@@ -1060,7 +1110,7 @@ """ | ||
| task = "inpainting" | ||
| main_input_name = "prompt" | ||
| export_feature = "inpainting" | ||
| auto_model_class = StableDiffusion3InpaintPipeline | ||
| @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) | ||
| @add_end_docstrings(ORT_PIPELINE_DOCSTRING) | ||
| class ORTFluxPipeline(ORTDiffusionPipeline, FluxPipeline): | ||
@@ -1071,4 +1121,4 @@ """ | ||
| task = "text-to-image" | ||
| main_input_name = "prompt" | ||
| export_feature = "text-to-image" | ||
| auto_model_class = FluxPipeline | ||
@@ -1075,0 +1125,0 @@ |
@@ -95,8 +95,8 @@ # Copyright 2021 The HuggingFace Team. All rights reserved. | ||
| onnx_model_path += [ | ||
| model_or_path.encoder_model_path, | ||
| model_or_path.decoder_model_path, | ||
| model_or_path.encoder.path, | ||
| model_or_path.decoder.path, | ||
| ] | ||
| # Add the decoder with past key/values if present | ||
| if model_or_path.use_cache: | ||
| onnx_model_path.append(model_or_path.decoder_with_past_model_path) | ||
| if model_or_path.decoder_with_past is not None: | ||
| onnx_model_path.append(model_or_path.decoder_with_past.path) | ||
| elif isinstance(model_or_path, ORTModelForCausalLM) and model_or_path.use_merged: | ||
@@ -108,3 +108,3 @@ raise NotImplementedError( | ||
| else: | ||
| onnx_model_path.append(model_or_path.model_path) | ||
| onnx_model_path.append(model_or_path.path) | ||
| config = model_or_path.config | ||
@@ -111,0 +111,0 @@ elif os.path.isdir(model_or_path): |
@@ -21,3 +21,3 @@ # Copyright 2021 The HuggingFace Team. All rights reserved. | ||
| from inspect import signature | ||
| from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union | ||
@@ -34,2 +34,3 @@ import numpy as np | ||
| import onnxruntime as ort | ||
| from onnxruntime.transformers.io_binding_helper import TypeHelper | ||
@@ -54,31 +55,3 @@ from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss | ||
| _ORT_TO_NP_TYPE = { | ||
| "tensor(bool)": np.bool_, | ||
| "tensor(int8)": np.int8, | ||
| "tensor(uint8)": np.uint8, | ||
| "tensor(int16)": np.int16, | ||
| "tensor(uint16)": np.uint16, | ||
| "tensor(int32)": np.int32, | ||
| "tensor(uint32)": np.uint32, | ||
| "tensor(int64)": np.int64, | ||
| "tensor(uint64)": np.uint64, | ||
| "tensor(float16)": np.float16, | ||
| "tensor(float)": np.float32, | ||
| "tensor(double)": np.float64, | ||
| } | ||
| def _is_gpu_available(): | ||
| """ | ||
| Checks if a gpu is available. | ||
| """ | ||
| available_providers = ort.get_available_providers() | ||
| if ( | ||
| "CUDAExecutionProvider" in available_providers or "ROCMExecutionProvider" in available_providers | ||
| ) and torch.cuda.is_available(): | ||
| return True | ||
| else: | ||
| return False | ||
| def is_onnxruntime_training_available(): | ||
@@ -141,2 +114,3 @@ """ | ||
| "mistral": "gpt2", | ||
| "modernbert": "bert", | ||
| "mpnet": "bert", | ||
@@ -208,3 +182,3 @@ "mt5": "bart", | ||
| if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider", "ROCMExecutionProvider"]: | ||
| return torch.device(f"cuda:{provider_options['device_id']}") | ||
| return torch.device(f"cuda:{provider_options.get('device_id', 0)}") | ||
| else: | ||
@@ -298,2 +272,36 @@ return torch.device("cpu") | ||
def prepare_providers_and_provider_options(
    provider: str = "CPUExecutionProvider",
    providers: Optional[Sequence[str]] = None,
    provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None,
):
    """
    Prepare the providers and provider options for ONNX Runtime.

    Args:
        provider (`str`):
            The provider to use. Ignored when `providers` is given.
        providers (`Sequence[str]`, *optional*):
            The list of providers to use. If `None`, falls back to `[provider]`.
        provider_options (`Union[Sequence[Dict[str, Any]], Dict[str, Any]]`, *optional*):
            The options to use for the providers. A single dict applies to the first
            provider only; a sequence must have the same length as `providers`.
            If `None`, empty options are used for every provider.

    Returns:
        A `(providers, provider_options)` pair where `provider_options` is a list
        of the same length as `providers`.

    Raises:
        ValueError: if a sequence of provider options has a different length than
            the list of providers.
    """
    if providers is None:
        providers = [provider]

    # Validate every requested provider (raises if one is not available).
    for provider_name in providers:
        validate_provider_availability(provider_name)

    if provider_options is None:
        # Use a comprehension rather than `[{}] * len(providers)`: list
        # multiplication would alias a single shared dict, so a later in-place
        # update of one provider's options would silently affect all of them.
        provider_options = [{} for _ in providers]
    elif isinstance(provider_options, dict):
        # A bare dict is interpreted as options for the first provider only.
        provider_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
    elif len(provider_options) != len(providers):
        raise ValueError(
            f"When passing a list of provider options, it should be the same length as the list of providers. "
            f"Got {len(provider_options)} provider options for {len(providers)} providers."
        )

    return providers, provider_options
| def check_io_binding(providers: List[str], use_io_binding: Optional[bool] = None) -> bool: | ||
@@ -443,1 +451,21 @@ """ | ||
| self.stride = stride | ||
def get_dtype_from_session(session: ort.InferenceSession) -> torch.dtype:
    """
    Returns the `torch.dtype` associated with the ONNX Runtime session.

    The dtype is the first floating-point type found among the session's
    inputs, then (only if none is found there) among its outputs. If neither
    side exposes a floating-point type, defaults to `torch.float32`.
    """

    def _first_floating_dtype(io_specs):
        # Scan a list of input/output specs for a floating-point torch dtype.
        for spec in io_specs:
            candidate = TypeHelper.ort_type_to_torch_type(spec.type)
            if candidate.is_floating_point:
                return candidate
        return None

    # Outputs are only inspected when the inputs yield nothing, preserving the
    # lazy evaluation order of the original two-loop implementation.
    dtype = _first_floating_dtype(session.get_inputs())
    if dtype is None:
        dtype = _first_floating_dtype(session.get_outputs())
    return dtype if dtype is not None else torch.float32
@@ -22,3 +22,6 @@ # coding=utf-8 | ||
| AutoConfig, | ||
| AutoFeatureExtractor, | ||
| AutoImageProcessor, | ||
| AutomaticSpeechRecognitionPipeline, | ||
| AutoTokenizer, | ||
| FeatureExtractionPipeline, | ||
@@ -45,5 +48,12 @@ FillMaskPipeline, | ||
| from transformers.feature_extraction_utils import PreTrainedFeatureExtractor | ||
| from transformers.onnx.utils import get_preprocessor | ||
| from transformers.image_processing_utils import BaseImageProcessor | ||
| from transformers.pipelines import ( | ||
| FEATURE_EXTRACTOR_MAPPING, | ||
| IMAGE_PROCESSOR_MAPPING, | ||
| TOKENIZER_MAPPING, | ||
| check_task, | ||
| get_default_model_and_revision, | ||
| infer_framework_load_model, | ||
| ) | ||
| from transformers.pipelines import SUPPORTED_TASKS as TRANSFORMERS_SUPPORTED_TASKS | ||
| from transformers.pipelines import infer_framework_load_model | ||
@@ -92,3 +102,3 @@ from ..utils import is_onnxruntime_available, is_transformers_version | ||
| "impl": ImageSegmentationPipeline, | ||
| "class": (ORTModelForSemanticSegmentation,) if is_onnxruntime_available() else (), | ||
| "class": (ORTModelForSemanticSegmentation,), | ||
| "default": "nvidia/segformer-b0-finetuned-ade-512-512", | ||
@@ -181,2 +191,4 @@ "type": "image", | ||
| load_feature_extractor=None, | ||
| image_processor=None, | ||
| load_image_processor=None, | ||
| SUPPORTED_TASKS=None, | ||
@@ -201,5 +213,3 @@ subfolder: str = "", | ||
| if model is None: | ||
| model_id = SUPPORTED_TASKS[targeted_task]["default"] | ||
| elif isinstance(model, str): | ||
| if isinstance(model, str): | ||
| model_id = model | ||
@@ -227,3 +237,3 @@ else: | ||
| return model, model_id, tokenizer, feature_extractor | ||
| return model, model_id, tokenizer, feature_extractor, image_processor | ||
@@ -238,2 +248,4 @@ | ||
| load_feature_extractor, | ||
| image_processor, | ||
| load_image_processor, | ||
| SUPPORTED_TASKS, | ||
@@ -250,6 +262,3 @@ subfolder: str = "", | ||
| if model is None: | ||
| model_id = SUPPORTED_TASKS[targeted_task]["default"] | ||
| model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(model_id, export=True) | ||
| elif isinstance(model, str): | ||
| if isinstance(model, str): | ||
| model_id = model | ||
@@ -279,2 +288,13 @@ model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained( | ||
| ) | ||
| if image_processor is None and load_image_processor: | ||
| for preprocessor in model.preprocessors: | ||
| if isinstance(preprocessor, BaseImageProcessor): | ||
| image_processor = preprocessor | ||
| break | ||
| if image_processor is None: | ||
| raise ValueError( | ||
| "Could not automatically find a feature extractor for the ORTModel, you must pass a " | ||
| "image_processor explictly" | ||
| ) | ||
| model_id = None | ||
@@ -286,3 +306,3 @@ else: | ||
| ) | ||
| return model, model_id, tokenizer, feature_extractor | ||
| return model, model_id, tokenizer, feature_extractor, image_processor | ||
@@ -301,2 +321,3 @@ | ||
| feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, | ||
| image_processor: Optional[Union[str, BaseImageProcessor]] = None, | ||
| use_fast: bool = True, | ||
@@ -323,3 +344,12 @@ token: Optional[Union[str, bool]] = None, | ||
| # copied from transformers.pipelines.__init__.py | ||
| supported_tasks = ORT_SUPPORTED_TASKS if accelerator == "ort" else TRANSFORMERS_SUPPORTED_TASKS | ||
| if model is None: | ||
| if accelerator != "ort": | ||
| _, target_task, task_options = check_task(task) | ||
| model, default_revision = get_default_model_and_revision(target_task, "pt", task_options) | ||
| revision = revision or default_revision | ||
| else: | ||
| model = supported_tasks[targeted_task]["default"] | ||
| hub_kwargs = { | ||
@@ -337,9 +367,9 @@ "revision": revision, | ||
| supported_tasks = ORT_SUPPORTED_TASKS if accelerator == "ort" else TRANSFORMERS_SUPPORTED_TASKS | ||
| no_feature_extractor_tasks = set() | ||
| no_tokenizer_tasks = set() | ||
| no_image_processor_tasks = set() | ||
| for _task, values in supported_tasks.items(): | ||
| if values["type"] == "text": | ||
| no_feature_extractor_tasks.add(_task) | ||
| no_image_processor_tasks.add(_task) | ||
| elif values["type"] in {"image", "video"}: | ||
@@ -349,5 +379,11 @@ no_tokenizer_tasks.add(_task) | ||
| no_tokenizer_tasks.add(_task) | ||
| no_image_processor_tasks.add(_task) | ||
| elif values["type"] not in ["multimodal", "audio", "video"]: | ||
| raise ValueError(f"SUPPORTED_TASK {_task} contains invalid type {values['type']}") | ||
| model_config = config or model.config | ||
| load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None | ||
| load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None | ||
| load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None | ||
| # copied from transformers.pipelines.__init__.py l.609 | ||
@@ -360,11 +396,13 @@ if targeted_task in no_tokenizer_tasks: | ||
| load_tokenizer = False | ||
| else: | ||
| load_tokenizer = True | ||
| if targeted_task in no_feature_extractor_tasks: | ||
| load_feature_extractor = False | ||
| else: | ||
| load_feature_extractor = True | ||
| model, model_id, tokenizer, feature_extractor = MAPPING_LOADING_FUNC[accelerator]( | ||
| if targeted_task in no_image_processor_tasks: | ||
| load_image_processor = False | ||
| if load_image_processor and load_feature_extractor: | ||
| load_feature_extractor = False | ||
| model, model_id, tokenizer, feature_extractor, image_processor = MAPPING_LOADING_FUNC[accelerator]( | ||
| model, | ||
@@ -376,2 +414,4 @@ targeted_task, | ||
| load_feature_extractor, | ||
| image_processor, | ||
| load_image_processor, | ||
| SUPPORTED_TASKS=supported_tasks, | ||
@@ -385,6 +425,9 @@ config=config, | ||
| use_fast = kwargs.get(use_fast, "True") | ||
| if tokenizer is None and load_tokenizer: | ||
| tokenizer = get_preprocessor(model_id) | ||
| tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, **hub_kwargs) | ||
| if feature_extractor is None and load_feature_extractor: | ||
| feature_extractor = get_preprocessor(model_id) | ||
| feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, use_fast=use_fast, **hub_kwargs) | ||
| if image_processor is None and load_image_processor: | ||
| image_processor = AutoImageProcessor.from_pretrained(model_id, **hub_kwargs) | ||
@@ -396,4 +439,5 @@ return transformers_pipeline( | ||
| feature_extractor=feature_extractor, | ||
| image_processor=image_processor, | ||
| use_fast=use_fast, | ||
| **kwargs, | ||
| ) |
@@ -18,2 +18,4 @@ # Copyright 2021 The HuggingFace Team. All rights reserved. | ||
| CONFIG_NAME, | ||
| DIFFUSION_MODEL_CONFIG_FILE_NAME, | ||
| DIFFUSION_MODEL_ONNX_FILE_NAME, | ||
| DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, | ||
@@ -26,2 +28,3 @@ DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER, | ||
| DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, | ||
| DIFFUSION_PIPELINE_CONFIG_FILE_NAME, | ||
| ONNX_WEIGHTS_NAME, | ||
@@ -46,4 +49,6 @@ ) | ||
| is_onnxruntime_available, | ||
| is_onnxslim_available, | ||
| is_pydantic_available, | ||
| is_sentence_transformers_available, | ||
| is_tensorrt_available, | ||
| is_tf_available, | ||
@@ -50,0 +55,0 @@ is_timm_available, |
@@ -17,2 +17,4 @@ # Copyright 2023 The HuggingFace Inc. team. All rights reserved. | ||
| CONFIG_NAME = "config.json" | ||
| ONNX_WEIGHTS_NAME = "model.onnx" | ||
| DIFFUSION_MODEL_UNET_SUBFOLDER = "unet" | ||
@@ -25,2 +27,4 @@ DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER = "transformer" | ||
| DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER = "text_encoder_3" | ||
| ONNX_WEIGHTS_NAME = "model.onnx" | ||
| DIFFUSION_PIPELINE_CONFIG_FILE_NAME = "model_index.json" | ||
| DIFFUSION_MODEL_CONFIG_FILE_NAME = "config.json" | ||
| DIFFUSION_MODEL_ONNX_FILE_NAME = "model.onnx" |
@@ -92,2 +92,3 @@ # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| _datasets_available = _is_package_available("datasets") | ||
| _tensorrt_available = _is_package_available("tensorrt") | ||
| _diffusers_available, _diffusers_version = _is_package_available("diffusers", return_version=True) | ||
@@ -138,2 +139,3 @@ _transformers_available, _transformers_version = _is_package_available("transformers", return_version=True) | ||
| ) | ||
| _onnxslim_available = _is_package_available("onnxslim") | ||
@@ -242,2 +244,6 @@ if _tf_available and version.parse(_tf_version) < version.parse("2"): | ||
| def is_tensorrt_available(): | ||
| return _tensorrt_available | ||
| def is_torch_available(): | ||
@@ -273,2 +279,6 @@ return _torch_available | ||
| def is_onnxslim_available(): | ||
| return _onnxslim_available | ||
| @contextmanager | ||
@@ -275,0 +285,0 @@ def check_if_pytorch_greater(target_version: str, message: str): |
@@ -240,3 +240,3 @@ # coding=utf-8 | ||
| "blenderbot-small": BartLikeNormalizedTextConfig, | ||
| "bloom": NormalizedTextConfig.with_args(num_layers="n_layer"), | ||
| "bloom": NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head"), | ||
| "falcon": NormalizedTextConfig, | ||
@@ -261,2 +261,3 @@ "camembert": NormalizedTextConfig, | ||
| "imagegpt": GPT2LikeNormalizedTextConfig, | ||
| "internlm2": NormalizedTextConfigWithGQA, | ||
| "llama": NormalizedTextConfigWithGQA, | ||
@@ -269,2 +270,3 @@ "longt5": T5LikeNormalizedTextConfig, | ||
| "mixtral": NormalizedTextConfigWithGQA, | ||
| "modernbert": NormalizedTextConfig, | ||
| "mpnet": NormalizedTextConfig, | ||
@@ -275,2 +277,4 @@ "mpt": MPTNormalizedTextConfig, | ||
| "nystromformer": NormalizedTextConfig, | ||
| "olmo": NormalizedTextConfig, | ||
| "olmo2": NormalizedTextConfig, | ||
| "opt": NormalizedTextConfig, | ||
@@ -281,3 +285,2 @@ "pegasus": BartLikeNormalizedTextConfig, | ||
| "phi3": NormalizedTextConfigWithGQA, | ||
| "phi3small": NormalizedTextConfigWithGQA, | ||
| "poolformer": NormalizedVisionConfig, | ||
@@ -298,2 +301,4 @@ "regnet": NormalizedVisionConfig, | ||
| "qwen2": NormalizedTextConfig, | ||
| "qwen3": NormalizedTextConfig, | ||
| "qwen3-moe": NormalizedTextConfig, | ||
| "granite": NormalizedTextConfigWithGQA, | ||
@@ -300,0 +305,0 @@ } |
@@ -15,2 +15,2 @@ # Copyright 2021 The HuggingFace Team. All rights reserved. | ||
| __version__ = "1.25.3" | ||
| __version__ = "1.26.0" |
+8
-6
| Metadata-Version: 2.1 | ||
| Name: optimum | ||
| Version: 1.25.3 | ||
| Version: 1.26.0 | ||
| Summary: Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality. | ||
@@ -34,3 +34,3 @@ Home-page: https://github.com/huggingface/optimum | ||
| Requires-Dist: onnxruntime>=1.11.0; extra == "onnxruntime" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime" | ||
| Provides-Extra: onnxruntime-gpu | ||
@@ -41,3 +41,3 @@ Requires-Dist: onnx; extra == "onnxruntime-gpu" | ||
| Requires-Dist: onnxruntime-gpu>=1.11.0; extra == "onnxruntime-gpu" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-gpu" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-gpu" | ||
| Provides-Extra: onnxruntime-training | ||
@@ -49,3 +49,3 @@ Requires-Dist: evaluate; extra == "onnxruntime-training" | ||
| Requires-Dist: protobuf>=3.20.1; extra == "onnxruntime-training" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-training" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-training" | ||
| Requires-Dist: onnxruntime-training>=1.11.0; extra == "onnxruntime-training" | ||
@@ -57,3 +57,3 @@ Provides-Extra: exporters | ||
| Requires-Dist: protobuf>=3.20.1; extra == "exporters" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters" | ||
| Provides-Extra: exporters-gpu | ||
@@ -64,3 +64,3 @@ Requires-Dist: onnx; extra == "exporters-gpu" | ||
| Requires-Dist: protobuf>=3.20.1; extra == "exporters-gpu" | ||
| Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters-gpu" | ||
| Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters-gpu" | ||
| Provides-Extra: exporters-tf | ||
@@ -114,2 +114,3 @@ Requires-Dist: onnx; extra == "exporters-tf" | ||
| Requires-Dist: hf_xet; extra == "dev" | ||
| Requires-Dist: onnxslim>=0.1.53; extra == "dev" | ||
| Requires-Dist: black~=23.1; extra == "dev" | ||
@@ -133,2 +134,3 @@ Requires-Dist: ruff==0.1.5; extra == "dev" | ||
| Requires-Dist: hf_xet; extra == "tests" | ||
| Requires-Dist: onnxslim>=0.1.53; extra == "tests" | ||
| Provides-Extra: quality | ||
@@ -135,0 +137,0 @@ Requires-Dist: black~=23.1; extra == "quality" |
+6
-5
@@ -41,2 +41,3 @@ import re | ||
| "hf_xet", | ||
| "onnxslim>=0.1.53", | ||
| ] | ||
@@ -54,3 +55,3 @@ | ||
| "onnxruntime>=1.11.0", | ||
| "transformers>=4.36,<4.52.0", | ||
| "transformers>=4.36,<4.53.0", | ||
| ], | ||
@@ -62,3 +63,3 @@ "onnxruntime-gpu": [ | ||
| "onnxruntime-gpu>=1.11.0", | ||
| "transformers>=4.36,<4.52.0", | ||
| "transformers>=4.36,<4.53.0", | ||
| ], | ||
@@ -71,3 +72,3 @@ "onnxruntime-training": [ | ||
| "protobuf>=3.20.1", | ||
| "transformers>=4.36,<4.52.0", | ||
| "transformers>=4.36,<4.53.0", | ||
| "onnxruntime-training>=1.11.0", | ||
@@ -80,3 +81,3 @@ ], | ||
| "protobuf>=3.20.1", | ||
| "transformers>=4.36,<4.52.0", | ||
| "transformers>=4.36,<4.53.0", | ||
| ], | ||
@@ -88,3 +89,3 @@ "exporters-gpu": [ | ||
| "protobuf>=3.20.1", | ||
| "transformers>=4.36,<4.52.0", | ||
| "transformers>=4.36,<4.53.0", | ||
| ], | ||
@@ -91,0 +92,0 @@ "exporters-tf": [ |
| # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import List | ||
| from onnxruntime.transformers.onnx_model import OnnxModel | ||
def find_fully_connected_layers_nodes(model: OnnxModel) -> List[List[str]]:
    """
    Collect the (Add, MatMul) node pairs that form fully connected layers.

    A dense layer is exported to ONNX as a MatMul followed by a bias Add, so
    for every Add node in the graph the MatMul parent is looked up; Add nodes
    without a MatMul parent are discarded.
    """
    fc_pairs = []
    for add_node in model.get_nodes_by_op_type("Add"):
        matmul_parent = model.match_parent(add_node, "MatMul")
        if matmul_parent is not None:
            fc_pairs.append((add_node, matmul_parent))
    return fc_pairs
| # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from .io_binding_helper import IOBindingHelper, TypeHelper |
| # Copyright 2022 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import logging | ||
| import traceback | ||
| from typing import TYPE_CHECKING | ||
| import numpy as np | ||
| import torch | ||
| import onnxruntime as ort | ||
| from onnxruntime.capi.onnxruntime_inference_collection import OrtValue | ||
| from onnxruntime.transformers.io_binding_helper import TypeHelper as ORTTypeHelper | ||
| from ..utils import is_cupy_available, is_onnxruntime_training_available | ||
| if TYPE_CHECKING: | ||
| from ..modeling_ort import ORTModel | ||
| if is_cupy_available(): | ||
| import cupy as cp | ||
| # Adapted from https://github.com/microsoft/onnxruntime/blob/93e0a151177ad8222c2c95f814342bfa27f0a64d/onnxruntime/python/tools/transformers/io_binding_helper.py#L12 | ||
# Adapted from https://github.com/microsoft/onnxruntime/blob/93e0a151177ad8222c2c95f814342bfa27f0a64d/onnxruntime/python/tools/transformers/io_binding_helper.py#L12
class TypeHelper(ORTTypeHelper):
    """
    Gets data type information of the ONNX Runtime inference session and provides the mapping from
    `OrtValue` data types to the data types of other frameworks (NumPy, PyTorch, etc).
    """

    @staticmethod
    def ort_type_to_numpy_type(ort_type: str):
        """Map an ORT tensor element type string to its NumPy counterpart."""
        ort_type_to_numpy_type_map = {
            "tensor(int64)": np.int64,
            "tensor(int32)": np.int32,
            "tensor(int8)": np.int8,
            "tensor(float)": np.float32,
            "tensor(float16)": np.float16,
            "tensor(bool)": bool,
        }
        # None is never a valid map value, so it is a safe "missing" sentinel.
        numpy_type = ort_type_to_numpy_type_map.get(ort_type)
        if numpy_type is None:
            raise ValueError(
                f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_numpy_type_map.keys()}"
            )
        return numpy_type

    @staticmethod
    def ort_type_to_torch_type(ort_type: str):
        """Map an ORT tensor element type string to its torch counterpart."""
        ort_type_to_torch_type_map = {
            "tensor(int64)": torch.int64,
            "tensor(int32)": torch.int32,
            "tensor(int8)": torch.int8,
            "tensor(float)": torch.float32,
            "tensor(float16)": torch.float16,
            "tensor(bool)": torch.bool,
        }
        torch_type = ort_type_to_torch_type_map.get(ort_type)
        if torch_type is None:
            raise ValueError(
                f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_torch_type_map.keys()}"
            )
        return torch_type
# Adapted from https://github.com/microsoft/onnxruntime/blob/1ab11a111ce0717bfbfaca964d04a017cb9b1752/onnxruntime/python/tools/transformers/io_binding_helper.py#L97
class IOBindingHelper:
    """
    A helper class to enable `ORTModel` instances to prepare IO binding with dynamic shaped outputs for an inference session and transfer
    tensors from ONNX Runtime to other frameworks on device. It helps reduce memory copy between the host and device.
    """

    def __init__(self, model: ort.InferenceSession, device, **kwargs):
        self.model = model
        self.device = device
        # Create {name:idx} dict for model inputs and outputs
        self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(model.get_inputs())}
        self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())}
        self.model_input_names = list(self.model_inputs.keys())
        self.model_output_names = list(self.model_outputs.keys())

    @staticmethod
    def to_pytorch(ort_value: OrtValue) -> torch.Tensor:
        """
        Converts tensors held by OrtValues in ONNX runtime memory buffer to torch tensor.
        """
        # Preferred path is DLPack (only available with onnxruntime-training),
        # then CuPy; any CuPy failure falls back to a host round-trip via NumPy.
        if is_onnxruntime_training_available():
            return IOBindingHelper.to_pytorch_via_dlpack(ort_value)
        else:
            try:
                return IOBindingHelper.to_pytorch_via_cupy(ort_value)
            except Exception:
                logging.error(traceback.format_exc())
                logging.info("Unable to access output memory in CUDA, will offload to CPU")
                return IOBindingHelper.to_pytorch_via_numpy(ort_value)

    @staticmethod
    def to_pytorch_via_numpy(ort_value: OrtValue) -> torch.Tensor:
        """Copy the OrtValue to host memory as NumPy, then move the torch tensor back to its device."""
        ort_device = ort_value.device_name().lower()
        return torch.tensor(ort_value.numpy()).to(ort_device)

    @staticmethod
    def to_pytorch_via_cupy(ort_value: OrtValue) -> torch.Tensor:
        """Wrap the CUDA buffer owned by ONNX Runtime with CuPy (no copy) and hand it to torch via DLPack."""
        ort_device = ort_value.device_name().lower()
        if ort_device != "cuda":
            raise RuntimeError(f"Exchange tensors to PyTorch via CuPy only when device is CUDA, got: {ort_device}")

        ort_type = ort_value.data_type()
        numpy_type = TypeHelper.ort_type_to_numpy_type(ort_type)

        # Access CUDA memory via CuPy
        memory = cp.cuda.UnownedMemory(ort_value.data_ptr(), 0, None)
        memory_ptr = cp.cuda.MemoryPointer(memory, 0)
        cp_array = cp.ndarray(shape=ort_value.shape(), memptr=memory_ptr, dtype=numpy_type)
        torch_tensor = torch.from_dlpack(cp_array.toDlpack())

        # If is boolean, the dtype will be uint8 and need to be convert back to bool.
        if "bool" in ort_type:
            torch_tensor = torch_tensor.to(torch.bool)

        # Clone so the returned tensor owns its memory instead of aliasing the
        # ONNX Runtime buffer, which may be reused by a subsequent run.
        torch_tensor = torch_tensor.clone()

        return torch_tensor

    @staticmethod
    # dlpack support is available for OrtValue only when `onnxruntime-training` is installed
    def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor:
        from torch._C import _from_dlpack

        torch_tensor = _from_dlpack(ort_value.to_dlpack())
        return torch_tensor

    @staticmethod
    def get_device_index(device):
        """Return the integer index of `device` (0 for CPU or index-less devices)."""
        if isinstance(device, str):
            # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
            device = torch.device(device)
        elif isinstance(device, int):
            return device
        return 0 if device.index is None else device.index

    @staticmethod
    def prepare_io_binding(ort_model: "ORTModel", **inputs) -> ort.IOBinding:
        """
        Returns an IOBinding object for an inference session. This method is for general purpose, if the inputs and outputs
        are determined, you can prepare data buffers directly to avoid tensor transfers across frameworks.
        """
        # NOTE(review): the error message below calls `.keys()` on
        # `ort_model.input_names` — presumably it is a dict on ORTModel; verify
        # against the ORTModel definition (not visible in this file).
        if not all(input_name in inputs.keys() for input_name in ort_model.input_names):
            raise ValueError(
                f"The ONNX model takes {ort_model.input_names.keys()} as inputs, but only {inputs.keys()} are given."
            )

        name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_model.model)

        # Bind inputs and outputs to onnxruntime session
        io_binding = ort_model.model.io_binding()

        # Bind inputs
        for input_name in ort_model.input_names:
            onnx_input = inputs.pop(input_name)
            # Binding requires contiguous memory because only a base data pointer is passed.
            onnx_input = onnx_input.contiguous()

            io_binding.bind_input(
                input_name,
                onnx_input.device.type,
                ort_model.device.index,
                name_to_np_type[input_name],
                list(onnx_input.size()),
                onnx_input.data_ptr(),
            )

        # Bind outputs
        for name in ort_model.output_names:
            # No shape is given, so ONNX Runtime allocates dynamically shaped
            # outputs on the target device itself.
            io_binding.bind_output(name, ort_model.device.type, device_id=ort_model.device.index)

        return io_binding
| from typing import TYPE_CHECKING, Tuple | ||
| if TYPE_CHECKING: | ||
| import torch | ||
def bloom_convert_to_standard_cache(
    past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int
) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
    """
    Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
    num_heads, ...]))
    """
    fused_batch_size, head_dim, seq_length = past_key_value[0][0].shape
    num_heads = fused_batch_size // batch_size

    standardized_layers = []
    for key_tensor, value_tensor in past_key_value:
        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        standardized_key = key_tensor.view(batch_size, num_heads, head_dim, seq_length)
        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
        standardized_value = value_tensor.view(batch_size, num_heads, seq_length, head_dim)
        standardized_layers.append((standardized_key, standardized_value))

    return tuple(standardized_layers)
def bloom_convert_to_bloom_cache(
    past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]]
) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
    """
    Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
    """
    batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
    fused_batch_size = batch_size * num_heads

    def _fuse_layer(layer_past):
        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        key_tensor, value_tensor = layer_past
        return (
            key_tensor.view(fused_batch_size, head_dim, seq_length),
            value_tensor.view(fused_batch_size, seq_length, head_dim),
        )

    return tuple(_fuse_layer(layer_past) for layer_past in past_key_value)
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
1804497
-0.79%141
-3.42%35267
-0.73%