New Research: Supply Chain Attack on Axios Pulls Malicious Dependency from npm. Details
Socket
Book a Demo · Sign in
Socket

optimum

Package Overview
Dependencies
Maintainers
7
Versions
82
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

optimum - pypi Package Compare versions

Comparing version
1.25.3
to
1.26.0
+8
-6
optimum.egg-info/PKG-INFO
Metadata-Version: 2.1
Name: optimum
Version: 1.25.3
Version: 1.26.0
Summary: Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality.

@@ -34,3 +34,3 @@ Home-page: https://github.com/huggingface/optimum

Requires-Dist: onnxruntime>=1.11.0; extra == "onnxruntime"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime"
Provides-Extra: onnxruntime-gpu

@@ -41,3 +41,3 @@ Requires-Dist: onnx; extra == "onnxruntime-gpu"

Requires-Dist: onnxruntime-gpu>=1.11.0; extra == "onnxruntime-gpu"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-gpu"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-gpu"
Provides-Extra: onnxruntime-training

@@ -49,3 +49,3 @@ Requires-Dist: evaluate; extra == "onnxruntime-training"

Requires-Dist: protobuf>=3.20.1; extra == "onnxruntime-training"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-training"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-training"
Requires-Dist: onnxruntime-training>=1.11.0; extra == "onnxruntime-training"

@@ -57,3 +57,3 @@ Provides-Extra: exporters

Requires-Dist: protobuf>=3.20.1; extra == "exporters"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters"
Provides-Extra: exporters-gpu

@@ -64,3 +64,3 @@ Requires-Dist: onnx; extra == "exporters-gpu"

Requires-Dist: protobuf>=3.20.1; extra == "exporters-gpu"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters-gpu"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters-gpu"
Provides-Extra: exporters-tf

@@ -114,2 +114,3 @@ Requires-Dist: onnx; extra == "exporters-tf"

Requires-Dist: hf_xet; extra == "dev"
Requires-Dist: onnxslim>=0.1.53; extra == "dev"
Requires-Dist: black~=23.1; extra == "dev"

@@ -133,2 +134,3 @@ Requires-Dist: ruff==0.1.5; extra == "dev"

Requires-Dist: hf_xet; extra == "tests"
Requires-Dist: onnxslim>=0.1.53; extra == "tests"
Provides-Extra: quality

@@ -135,0 +137,0 @@ Requires-Dist: black~=23.1; extra == "quality"

@@ -34,2 +34,3 @@ transformers>=4.29

hf_xet
onnxslim>=0.1.53
black~=23.1

@@ -46,3 +47,3 @@ ruff==0.1.5

protobuf>=3.20.1
transformers<4.52.0,>=4.36
transformers<4.53.0,>=4.36

@@ -54,3 +55,3 @@ [exporters-gpu]

protobuf>=3.20.1
transformers<4.52.0,>=4.36
transformers<4.53.0,>=4.36

@@ -97,3 +98,3 @@ [exporters-tf]

onnxruntime>=1.11.0
transformers<4.52.0,>=4.36
transformers<4.53.0,>=4.36

@@ -105,3 +106,3 @@ [onnxruntime-gpu]

onnxruntime-gpu>=1.11.0
transformers<4.52.0,>=4.36
transformers<4.53.0,>=4.36

@@ -114,3 +115,3 @@ [onnxruntime-training]

protobuf>=3.20.1
transformers<4.52.0,>=4.36
transformers<4.53.0,>=4.36
onnxruntime-training>=1.11.0

@@ -144,1 +145,2 @@

hf_xet
onnxslim>=0.1.53

@@ -92,3 +92,2 @@ LICENSE

optimum/onnxruntime/constants.py
optimum/onnxruntime/graph.py
optimum/onnxruntime/modeling_decoder.py

@@ -105,6 +104,2 @@ optimum/onnxruntime/modeling_diffusion.py

optimum/onnxruntime/utils.py
optimum/onnxruntime/io_binding/__init__.py
optimum/onnxruntime/io_binding/io_binding_helper.py
optimum/onnxruntime/models/__init__.py
optimum/onnxruntime/models/bloom.py
optimum/onnxruntime/preprocessors/__init__.py

@@ -111,0 +106,0 @@ optimum/onnxruntime/preprocessors/quantization.py

@@ -172,2 +172,7 @@ # Copyright 2022 The HuggingFace Team. All rights reserved.

)
optional_group.add_argument(
"--slim",
action="store_true",
help="Enables onnxslim optimization.",
)

@@ -290,3 +295,4 @@ input_group = parser.add_argument_group(

do_constant_folding=not self.args.no_constant_folding,
slim=self.args.slim,
**input_shapes,
)

@@ -18,3 +18,2 @@ # coding=utf-8

import argparse
import warnings
from pathlib import Path

@@ -29,3 +28,8 @@

from ...utils import DEFAULT_DUMMY_SHAPES, logging
from ...utils.import_utils import is_transformers_version
from ...utils.import_utils import (
is_diffusers_available,
is_sentence_transformers_available,
is_timm_available,
is_transformers_version,
)
from ...utils.save_utils import maybe_load_preprocessors

@@ -63,5 +67,4 @@ from ..tasks import TasksManager

atol: Optional[float] = None,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
trust_remote_code: bool = False,
pad_token_id: Optional[int] = None,
# hub options
subfolder: str = "",

@@ -71,4 +74,6 @@ revision: str = "main",

local_files_only: bool = False,
use_auth_token: Optional[Union[bool, str]] = None,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
########################################
for_ort: bool = False,

@@ -85,2 +90,3 @@ do_validation: bool = True,

do_constant_folding: bool = True,
slim: bool = False,
**kwargs_shapes,

@@ -174,2 +180,4 @@ ):

PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
slim (bool, defaults to `False`):
PyTorch-specific argument. If `True`, use onnxslim to optimize the ONNX model.
**kwargs_shapes (`Dict`):

@@ -186,11 +194,2 @@ Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token
if fp16:

@@ -230,3 +229,3 @@ if dtype is not None:

if task in ["stable-diffusion", "stable-diffusion-xl"]:
if task in ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]:
logger.warning(

@@ -237,2 +236,19 @@ f"The task `{task}` is deprecated and will be removed in a future release of Optimum. "

if library_name == "sentence_transformers" and not is_sentence_transformers_available():
raise ImportError(
"The library `sentence_transformers` was specified, but it is not installed. "
"Please install it with `pip install sentence-transformers`."
)
if library_name == "diffusers" and not is_diffusers_available():
raise ImportError(
"The library `diffusers` was specified, but it is not installed. "
"Please install it with `pip install diffusers`."
)
if library_name == "timm" and not is_timm_available():
raise ImportError(
"The library `timm` was specified, but it is not installed. Please install it with `pip install timm`."
)
original_task = task

@@ -250,2 +266,18 @@ task = TasksManager.map_from_synonym(task)

)
if library_name == "sentence_transformers" and not is_sentence_transformers_available():
logger.warning(
"The library name was inferred as `sentence_transformers`, which is not installed. "
"Falling back to `transformers` to avoid breaking the export."
)
library_name = "transformers"
elif library_name == "timm" and not is_timm_available():
raise ImportError(
"The library name was inferred as `timm`, which is not installed. "
"Please install it with `pip install timm`."
)
elif library_name == "diffusers" and not is_diffusers_available():
raise ImportError(
"The library name was inferred as `diffusers`, which is not installed. "
"Please install it with `pip install diffusers`."
)

@@ -268,3 +300,10 @@ torch_dtype = None

try:
task = TasksManager.infer_task_from_model(model_name_or_path, library_name=library_name)
task = TasksManager.infer_task_from_model(
model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
library_name=library_name,
)
except KeyError as e:

@@ -310,4 +349,5 @@ raise KeyError(

# TODO: Fix in Transformers so that SdpaAttention class can be exported to ONNX. `attn_implementation` is introduced in Transformers 4.36.
if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version(">=", "4.35.99"):
# TODO: Fix in Transformers so that SdpaAttention class can be exported to ONNX.
# This was fixed in transformers 4.42.0, we can remove it when minimum transformers version is updated to 4.42
if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version("<", "4.42"):
loading_kwargs["attn_implementation"] = "eager"

@@ -406,2 +446,3 @@

do_constant_folding=do_constant_folding,
slim=slim,
**kwargs_shapes,

@@ -408,0 +449,0 @@ )

@@ -941,15 +941,15 @@ # coding=utf-8

# Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
if len(onnx_files_subpaths) >= 3 and self.use_past is True:
decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
# Attempt to merge only if the decoder was exported without/with past
onnx_decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx")
onnx_decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx")
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
if onnx_decoder_path.is_file() and onnx_decoder_with_past_path.is_file() and self.use_past is True:
try:
# The decoder with past does not output the cross attention past key values as they are constant,
# hence the need for strict=False
from ...onnx import merge_decoders
# The decoder with past does not output the cross attention past key values as they are constant,
# hence the need for strict=False
merge_decoders(
decoder=decoder_path,
decoder_with_past=decoder_with_past_path,
decoder=onnx_decoder_path,
decoder_with_past=onnx_decoder_with_past_path,
save_path=decoder_merged_path,

@@ -956,0 +956,0 @@ strict=False,

@@ -33,2 +33,3 @@ # coding=utf-8

from ...onnx.graph_transformations import check_and_save_model
from ...onnx.utils import _get_onnx_external_constants, _get_onnx_external_data_tensors, check_model_uses_external_data

@@ -40,2 +41,3 @@ from ...utils import (

is_diffusers_available,
is_onnxslim_available,
is_torch_onnx_support_available,

@@ -922,2 +924,3 @@ is_transformers_version,

do_constant_folding: bool = True,
slim: bool = False,
**kwargs_shapes,

@@ -978,2 +981,4 @@ ):

PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
slim (bool, defaults to `False`):
Use onnxslim to optimize the ONNX model.
**kwargs_shapes (`Dict`):

@@ -1203,2 +1208,13 @@ Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

if slim:
if not is_onnxslim_available():
raise ImportError("The pip package `onnxslim` is required to optimize onnx models.")
from onnxslim import slim
for subpath in onnx_files_subpaths:
file_path = os.path.join(output, subpath)
slimmed_model = slim(file_path)
check_and_save_model(slimmed_model, file_path)
# Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any

@@ -1205,0 +1221,0 @@ # TODO: treating diffusion separately is quite ugly

@@ -83,2 +83,3 @@ # coding=utf-8

"imagegpt",
"internlm2",
"llama",

@@ -89,2 +90,4 @@ "mistral",

"qwen2",
"qwen3",
"qwen3-moe",
"granite",

@@ -94,3 +97,3 @@ }

if is_transformers_version(">=", "4.45.99"):
if is_transformers_version(">=", "4.46.0"):
MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt")

@@ -97,0 +100,0 @@

@@ -23,3 +23,3 @@ # coding=utf-8

from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

@@ -36,5 +36,15 @@ from huggingface_hub import HfApi

if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers import (
FeatureExtractionMixin,
ImageProcessingMixin,
PreTrainedModel,
ProcessorMixin,
SpecialTokensMixin,
TFPreTrainedModel,
)
PreprocessorT = Union[SpecialTokensMixin, FeatureExtractionMixin, ImageProcessingMixin, ProcessorMixin]
ModelT = Union["PreTrainedModel", "TFPreTrainedModel"]
logger = logging.getLogger(__name__)

@@ -84,2 +94,3 @@

# TODO: Should be removed when we no longer use OptimizedModel for everything
# workaround to enable compatibility between optimum models and transformers pipelines

@@ -95,7 +106,8 @@ class PreTrainedModel(ABC): # noqa: F811

def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: PretrainedConfig):
super().__init__()
def __init__(
self, model: Union["ModelT"], config: "PretrainedConfig", preprocessors: Optional[List["PreprocessorT"]] = None
):
self.model = model
self.config = config
self.preprocessors = []
self.preprocessors = preprocessors or []

@@ -236,19 +248,11 @@ def __call__(self, *args, **kwargs):

config_name_or_path: Union[str, os.PathLike],
revision: Optional[str] = None,
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
force_download: bool = False,
subfolder: str = "",
trust_remote_code: bool = False,
) -> PretrainedConfig:
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token
try:

@@ -286,10 +290,10 @@ config = AutoConfig.from_pretrained(

model_id: Union[str, Path],
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
subfolder: str = "",
local_files_only: bool = False,
token: Optional[Union[bool, str]] = None,
**kwargs,

@@ -301,35 +305,14 @@ ) -> "OptimizedModel":

@classmethod
def _from_transformers(
def _export(
cls,
model_id: Union[str, Path],
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
subfolder: str = "",
local_files_only: bool = False,
trust_remote_code: bool = False,
**kwargs,
) -> "OptimizedModel":
"""Overwrite this method in subclass to define how to load your model from vanilla transformers model"""
raise NotImplementedError(
"`_from_transformers` method will be deprecated in a future release. Please override `_export` instead"
"to define how to load your model from vanilla transformers model"
)
@classmethod
def _export(
cls,
model_id: Union[str, Path],
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
subfolder: str = "",
local_files_only: bool = False,
trust_remote_code: bool = False,
**kwargs,

@@ -347,12 +330,12 @@ ) -> "OptimizedModel":

model_id: Union[str, Path],
config: Optional[PretrainedConfig] = None,
export: bool = False,
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
subfolder: str = "",
config: Optional[PretrainedConfig] = None,
local_files_only: bool = False,
trust_remote_code: bool = False,
revision: Optional[str] = None,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
**kwargs,

@@ -365,21 +348,5 @@ ) -> "OptimizedModel":

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token
if isinstance(model_id, Path):
model_id = model_id.as_posix()
from_transformers = kwargs.pop("from_transformers", None)
if from_transformers is not None:
logger.warning(
"The argument `from_transformers` is deprecated, and will be removed in optimum 2.0. Use `export` instead"
)
export = from_transformers
if len(model_id.split("@")) == 2:

@@ -448,3 +415,3 @@ logger.warning(

from_pretrained_method = cls._from_transformers if export else cls._from_pretrained
from_pretrained_method = cls._export if export else cls._from_pretrained

@@ -454,2 +421,3 @@ return from_pretrained_method(

config=config,
# hub options
revision=revision,

@@ -456,0 +424,0 @@ cache_dir=cache_dir,

@@ -14,451 +14,638 @@ # Copyright 2022 The HuggingFace Team. All rights reserved.

# limitations under the License.
"""Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models."""
"""Defines the base classes that are used to perform inference with ONNX Runtime sessions."""
from abc import abstractmethod
from typing import Dict, Optional, Set, Tuple, Union
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from onnxruntime import InferenceSession
from onnxruntime import InferenceSession, IOBinding
from onnxruntime.transformers.io_binding_helper import TypeHelper
from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
from .utils import logging
from ..onnx.utils import _get_model_external_data_paths
from ..utils.logging import get_logger
from .utils import (
get_device_for_provider,
get_dtype_from_session,
get_provider_for_device,
parse_device,
validate_provider_availability,
)
logger = logging.get_logger(__name__)
logger = get_logger(__name__)
NON_EMPTY_TENSOR = torch.tensor(0)
class ORTModelPart:
class ORTSessionMixin:
"""
For multi-file ONNX models, such as encoder-decoder models, represents a part of the model.
It has its own `onnxruntime.InferenceSession`, and can perform a forward pass.
Mixin class that provides common functionalities for an ONNX Runtime session.
This class is used to manage the session, the execution provider, and the IO binding.
It also provides methods to prepare the inputs and outputs for ONNX Runtime.
"""
# should be in an ORTMixin
_prepare_io_binding = ORTModel._prepare_io_binding
_prepare_output_buffer = ORTModel._prepare_output_buffer
_output_shape_inference = ORTModel._output_shape_inference
def initialize_ort_attributes(self, session: InferenceSession, use_io_binding: Optional[bool] = None):
"""
Initializes the ORTSessionMixin class.
Args:
session (`onnxruntime.InferenceSession`):
The ONNX Runtime session to use for inference.
use_io_binding (`Optional[bool]`, defaults to `None`):
Whether to use IO Binding or not. If `None`, it will be set to `True` for CUDAExecutionProvider and `False`
for other providers.
"""
_prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
_prepare_onnx_outputs = ORTModel._prepare_onnx_outputs
def __init__(self, session: InferenceSession, parent_model: "ORTModel"):
self.session = session
self.parent_model = parent_model
self.main_input_name = self.parent_model.main_input_name
self.path = Path(session._model_path)
self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
if use_io_binding is None:
if self.provider == "CUDAExecutionProvider":
logger.info(
"`use_io_binding` was not set, but CUDAExecutionProvider supports IO Binding. "
"Setting `use_io_binding=True` to leverage IO Binding and improve performance. "
"You can disable it by setting `model.use_io_binding=False`."
)
use_io_binding = True
else:
use_io_binding = False
self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()}
self._use_io_binding = use_io_binding
self._io_binding = IOBinding(session)
self._dtype = get_dtype_from_session(session)
self._device = get_device_for_provider(self.provider, self.provider_option)
self.input_shapes = {input_key.name: input_key.shape for input_key in session.get_inputs()}
self.output_shapes = {output_key.name: output_key.shape for output_key in session.get_outputs()}
self.input_names = {input.name: idx for idx, input in enumerate(session.get_inputs())}
self.output_names = {output.name: idx for idx, output in enumerate(session.get_outputs())}
self.input_shapes = {input.name: input.shape for input in session.get_inputs()}
self.output_shapes = {output.name: output.shape for output in session.get_outputs()}
self.input_dtypes = {input.name: input.type for input in session.get_inputs()}
self.output_dtypes = {output.name: output.type for output in session.get_outputs()}
@property
def device(self):
return self.parent_model.device
def model_path(self) -> str:
"""
Returns the path of the onnx file from which the session was created.
"""
logger.warning(
"The `ORTSessionMixin.model_path` property is deprecated and will be removed in a future version. "
"Please use `ORTSessionMixin.path` instead (`ORTSessionMixin.path` is a proper Path object)."
)
return self.path
@property
def dtype(self):
for dtype in self.input_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype
def model_name(self) -> str:
"""
Returns the name of the onnx file from which the session was created.
"""
logger.warning(
"The `ORTSessionMixin.model_name` property is deprecated and will be removed in a future version. "
"Please use `ORTSessionMixin.path.name` instead (`ORTSessionMixin.path` is a proper Path object)."
)
return self.path.name
for dtype in self.output_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype
@property
def providers(self) -> List[str]:
"""
Returns a list of Execution Providers registered with the session.
"""
return self.session.get_providers()
return None
@property
def provider(self) -> str:
"""
Returns the main Execution Provider registered with the session.
"""
return self.providers[0]
def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
@property
def provider_options(self) -> Dict[str, Any]:
"""
Returns a dictionary of Execution Providers configurations/options.
"""
return self.session.get_provider_options()
@property
def provider_option(self) -> Dict[str, Any]:
"""
Returns the configuration/options of the main Execution Provider.
"""
return self.provider_options[self.provider]
@property
def device(self) -> torch.device:
"""
Returns the `torch.device` associated with the ONNX Runtime session.
This device is inferred from the provider and provider options.
"""
return self._device
@device.setter
def device(self, *args, **kwargs):
raise AttributeError(
"The device attribute is read-only, please use the `.to(device)` "
"method to change both the device and the execution provider accordingly."
)
@property
def dtype(self) -> torch.dtype:
"""
Returns the `torch.dtype` associated with the ONNX Runtime session.
This dtype is inferred from the input/output dtypes of the session.
If no floating point type is found, it defaults to `torch.float32`.
"""
return self._dtype
@property
def use_io_binding(self) -> Optional[bool]:
"""
Returns whether IO Binding is used or not.
"""
return self._use_io_binding
@use_io_binding.setter
def use_io_binding(self, value: bool):
"""
Sets the IO Binding usage.
"""
if not isinstance(value, bool):
raise ValueError("`use_io_binding` should be a boolean value.")
self._use_io_binding = value
def to(self, *args, **kwargs):
"""
Moves the session to the specified device by updating the execution provider and its options.
Args:
device (`str`, `int`, `torch.device`):
The device to move the session to. It can be a string (e.g., "cuda", "cpu"), an integer (e.g., 0 for GPU 0),
or a `torch.device` object.
Returns:
`ORTSessionMixin`: The updated session.
Raises:
ValueError: If the device is not supported or if the provider is not available.
"""
dtype = None
device = None
for arg in args:
if isinstance(arg, torch.device):
if isinstance(arg, (str, torch.device)):
device = arg
elif isinstance(arg, int):
device = torch.device(arg)
elif isinstance(arg, torch.device):
device = arg
elif isinstance(arg, torch.dtype):
dtype = arg
if device is not None and device != self.device:
raise ValueError(
"Cannot change the device of a model part without changing the device of the parent model. "
"Please use the `to` method of the parent model to change the device."
)
for key, value in kwargs.items():
if key == "device":
device = value
elif key == "dtype":
dtype = value
if dtype is not None and dtype != self.dtype:
raise NotImplementedError(
f"Cannot change the dtype of the model from {self.dtype} to {dtype}. "
f"Please export the model with the desired dtype."
)
if dtype is not None:
# we don't support changing the dtype of the model
return self
@abstractmethod
def forward(self, *args, **kwargs):
pass
if device is None:
# no device was provided, we don't change the device
return self
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
device, provider_option = parse_device(device)
provider = get_provider_for_device(device)
validate_provider_availability(provider)
if device == self.device:
return self
class ORTEncoder(ORTModelPart):
"""
Encoder part of the encoder-decoder model for ONNX Runtime inference.
"""
self.session.set_providers([provider], provider_options=[provider_option])
def __init__(self, session: InferenceSession, parent_model: "ORTModel"):
super().__init__(session, parent_model)
if self.use_io_binding is None:
if self.provider == "CUDAExecutionProvider":
logger.info(
"`use_io_binding` was set to `None` before the provider was changed to CUDAExecutionProvider. "
"Setting `use_io_binding=True` to leverage IO Binding and improve performance. "
"You can disable it by setting `model.use_io_binding=False`."
)
self.use_io_binding = True
config = (
self.parent_model.config.encoder
if hasattr(self.parent_model.config, "encoder")
else self.parent_model.config
)
self._device = device
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
return self
def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs) -> BaseModelOutput:
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)
def raise_on_numpy_input_io_binding(self, use_torch: bool):
"""
Raises an error if IO Binding is requested although the tensor used are numpy arrays.
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
Args:
use_torch (`bool`):
Whether the tensor used during inference are of type torch.Tensor or not.
"""
if use_torch is False and self.use_io_binding is True:
raise ValueError(
"IO Binding can not be used when passing numpy inputs. Please disable IO Binding"
" with `model.use_io_binding=False`, or pass `torch.Tensor` inputs instead."
)
if self.parent_model.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)
def _prepare_onnx_inputs(
self, use_torch: bool, model_inputs: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> Dict[str, np.ndarray]:
"""
Prepares the inputs for ONNX Runtime by converting them to numpy arrays with the expected dtype.
if self.device.type == "cpu":
self.session.run_with_iobinding(io_binding)
Args:
use_torch (`bool`):
Whether the inputs are torch.Tensor or not.
inputs (`Dict[str, Union[torch.Tensor, np.ndarray]]`):
The inputs to prepare for ONNX Runtime.
Returns:
`Dict[str, np.ndarray]`: The inputs prepared for ONNX Runtime.
"""
onnx_inputs = {}
for input_name in self.input_names.keys():
if model_inputs.get(input_name, None) is None:
raise ValueError(f"Input {input_name} is required by model but not provided.")
if use_torch:
onnx_inputs[input_name] = model_inputs[input_name].numpy(force=True)
else:
io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
onnx_inputs[input_name] = model_inputs[input_name]
last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
expected_dtype = TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name])
last_hidden_state = model_outputs["last_hidden_state"]
if onnx_inputs[input_name].dtype != expected_dtype:
onnx_inputs[input_name] = onnx_inputs[input_name].astype(expected_dtype)
return BaseModelOutput(last_hidden_state=last_hidden_state)
return onnx_inputs
def _prepare_onnx_outputs(
self, use_torch: bool, onnx_outputs: List[np.ndarray]
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
"""
Prepares the outputs from ONNX Runtime by converting them to torch.Tensor if requested.
class ORTDecoderForSeq2Seq(ORTModelPart):
"""
Decoder model with a language modeling head on top for ONNX Runtime inference.
"""
Args:
use_torch (`bool`):
Whether the outputs should be torch.Tensor or not.
onnx_outputs (`List[np.ndarray]`):
The outputs from ONNX Runtime.
def __init__(
self,
session: InferenceSession,
parent_model: "ORTModel",
):
super().__init__(session, parent_model)
Returns:
`Dict[str, Union[torch.Tensor, np.ndarray]]`: The outputs prepared for the user.
"""
config = (
self.parent_model.config.decoder
if hasattr(self.parent_model.config, "decoder")
else self.parent_model.config
)
model_outputs = {}
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
for output_name, idx in self.output_names.items():
model_outputs[output_name] = onnx_outputs[idx]
# TODO: make this less hacky.
self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
if use_torch:
model_outputs[output_name] = torch.from_numpy(model_outputs[output_name]).to(self.device)
# To handle the old case when past_key_values were following the format: past_key_values_{idx}
if len(self.key_value_input_names) == 0:
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
if len(self.key_value_output_names) == 0:
self.key_value_output_names = [key for key in self.output_names if "key_values" in key]
return model_outputs
if self.parent_model.use_cache is True and len(self.key_value_output_names) == 0:
raise RuntimeError("Could not find the past key values in the provided model.")
def _prepare_output_buffer(self, output_name: str, output_shape: Tuple[int]) -> torch.Tensor:
"""
Prepares an output buffer for ONNX Runtime IO Binding.
self.use_past_in_outputs = len(self.key_value_output_names) > 0
self.use_past_in_inputs = len(self.key_value_input_names) > 0
self.use_fp16 = self.dtype == torch.float16
Args:
output_name (`str`):
The name of the output for which to prepare the buffer.
output_shape (`Tuple[int]`):
The shape of the output buffer.
# We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2
# can be used but do not support KV caching for the cross-attention key/values, see:
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311
# This attribute is used to avoid returning cross-attention KV-cache in this case.
self.no_cross_attention_cache = getattr(self.parent_model, "no_cross_attention_cache", False)
Returns:
`torch.Tensor`: The output buffer.
if (not self.parent_model.use_merged and self.use_past_in_inputs) or self.no_cross_attention_cache:
self.num_pkv = 2
"""
if len(output_shape) == 0:
raise ValueError("`output_shape` should not be empty")
elif not all(isinstance(dim, int) for dim in output_shape):
raise ValueError(f"`output_shape` should only contain integers but got {output_shape}.")
elif not all(dim > 0 for dim in output_shape):
raise ValueError(f"`output_shape` should only contain positive integers but got {output_shape}.")
output_dtype = TypeHelper.ort_type_to_torch_type(self.output_dtypes[output_name])
if len(output_shape) > 0:
output_buffer = torch.empty(np.prod(output_shape), dtype=output_dtype, device=self.device)
else:
# When using a merged model, we always have the same number of output whether we use past key values or not,
# and in the case past key values are used, empty tensors are given as cross-attention past key values as they
# are constants
self.num_pkv = 4
output_buffer = torch.tensor(0, dtype=output_dtype, device=self.device)
self.past_key_values_cross_attention_output_names = set()
for output_name in self.output_names:
if output_name.startswith("present") and "encoder" in output_name:
self.past_key_values_cross_attention_output_names.add(output_name)
return output_buffer
self.use_legacy_outputs = (
self.parent_model.use_merged is False and len(self.past_key_values_cross_attention_output_names) > 0
)
def _output_shape_inference(self, output_name: str, known_axes_values: Dict[str, int]) -> List[int]:
"""
Infers the shape of a given output by using the `known_axes_values` mapping.
def compute_past_key_values_output_shapes(
self,
input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor,
use_cache_branch: Optional[bool],
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> Dict[str, int]:
batch_size = input_ids.size(0)
Args:
output_name (`str`):
The name of the output for which to infer the shape.
known_axes_values (`Dict[str, int]`):
A mapping of the axis names to their values.
num_attention_heads = self.normalized_config.num_attention_heads
embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
Returns:
`List[int]`: The inferred shape of the output.
"""
sequence_length = input_ids.size(1)
encoder_sequence_length = encoder_hidden_states.size(1)
if past_key_values is not None and use_cache_branch is not False:
# Here, use_cache_branch may be None in the case of separate decoder without/with past, or True if the with past branch
# of a merged decoder is used
sequence_length += past_key_values[0].size(2)
output_shape = list(self.output_shapes[output_name])
self_attn_shape = (batch_size, num_attention_heads, sequence_length, embed_size_per_head)
for idx, axis_name in enumerate(output_shape):
if isinstance(axis_name, str):
output_shape[idx] = self._dynamic_axis_inference(axis_name, known_axes_values)
if past_key_values is not None and use_cache_branch is True:
cross_attn_shape = (0, num_attention_heads, 1, embed_size_per_head)
else:
cross_attn_shape = (batch_size, num_attention_heads, encoder_sequence_length, embed_size_per_head)
return output_shape
past_key_values_shapes = {}
for idx, name in enumerate(self.key_value_output_names):
is_self_attn = idx % 4 < 2
# decoder with past does not ouput cross attention key/values as they are constants
past_key_values_shapes[name] = self_attn_shape if (is_self_attn or self.num_pkv == 2) else cross_attn_shape
return past_key_values_shapes
def _dynamic_axis_inference(self, axis_name: Union[str], known_axes_values: Dict[str, int]) -> int:
"""
Infers the value of a given dynamic axis by using the `known_axes_values` mapping.
def get_outputs_not_to_bind(self, use_merged_cache: bool) -> Set[str]:
result = {
output_name
for output_name in self.output_names
if (not output_name.startswith("present") and output_name not in {"loss", "logits"})
}
if use_merged_cache is True:
# When using a merged decoder and the use cache branch, we output 0-dim tensors that IO Binding do not support.
# Therefore, we do not bind them.
result = result.union(self.past_key_values_cross_attention_output_names)
return result
For instance, for the following inputs:
axis_name = "sequence_length + past_sequence_length"
known_axes_values = {"batch_size": 2, "sequence_length": 3, "past_sequence_length": 7}
def forward(
The inferred value will be:
3 + 7 = 10
"""
if axis_name in known_axes_values:
# simple case, the axis value is known
return known_axes_values[axis_name]
tokens = axis_name.split(" ")
for idx, token in enumerate(tokens):
if token in known_axes_values:
tokens[idx] = str(known_axes_values[token])
return int(eval(" ".join(tokens)))
def _prepare_io_binding(
self,
input_ids: torch.LongTensor,
encoder_hidden_states: torch.FloatTensor,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Seq2SeqLMOutput:
# Adding use_cache_branch in the signature here is just a hack for IO Binding
model_inputs: Dict[str, torch.Tensor],
outputs_to_not_bind: Optional[Set[str]] = None,
known_output_buffers: Optional[Dict[str, str]] = None,
known_output_shapes: Optional[Dict[str, Tuple[int]]] = None,
) -> Tuple[Dict[str, Tuple[int]], Dict[str, torch.Tensor]]:
"""
Prepares IO binding for ONNX Runtime.
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)
Args:
model_inputs (`Dict[str, torch.Tensor]`):
The inputs to bind to the model.
outputs_to_not_bind (`Optional[Set[str]]`, defaults to `None`):
The names of the outputs that should not be bound.
known_output_buffers (`Optional[Dict[str, str]]`, defaults to `None`):
Sometimes we can reuse the same input buffer for the output. This is the case for the output sample
in a diffusion pipeline. It is possible to explicitely pass the buffer via this argument.
known_output_shapes (`Optional[Dict[str, Tuple[int]]]`, defaults to `None`):
It can be hard to infer all the output shapes from the inputs only. For instance for the past key /
values. It is possible to explicitely pass the shape via this argument.
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
Returns:
`TupleDict[str, Tuple[int]], Dict[str, torch.Tensor]`: A dictionary of the output shapes and a dictionary of
the output buffers.
"""
# no-ops if merged decoder is not used
use_merged_no_cache = past_key_values is None and self.parent_model.use_merged
use_merged_cache = past_key_values is not None and self.parent_model.use_merged
use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged(
input_ids, past_key_values, cache_position, use_torch=use_torch
)
known_axes_values = {}
model_inputs = {
"input_ids": input_ids,
"encoder_hidden_states": encoder_hidden_states,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"cache_position": cache_position,
}
if past_key_values is not None:
model_inputs.update(zip(self.key_value_input_names, past_key_values))
for input_name in self.input_names.keys():
input_shape = model_inputs[input_name].shape
if self.parent_model.use_io_binding:
known_output_shapes = self.compute_past_key_values_output_shapes(
input_ids,
encoder_hidden_states,
use_cache_branch=use_cache_branch_tensor.item() if use_cache_branch_tensor is not None else None,
past_key_values=past_key_values,
)
outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)
if not model_inputs[input_name].is_contiguous():
model_inputs[input_name] = model_inputs[input_name].contiguous()
io_binding, output_shapes, output_buffers = self._prepare_io_binding(
self.session,
model_inputs,
known_output_shapes=known_output_shapes,
outputs_to_not_bind=outputs_to_not_bind,
tensor_dtype = model_inputs[input_name].dtype
expected_dtype = TypeHelper.ort_type_to_torch_type(self.input_dtypes[input_name])
if tensor_dtype != expected_dtype:
model_inputs[input_name] = model_inputs[input_name].to(expected_dtype)
data_ptr = model_inputs[input_name].data_ptr()
if data_ptr == 0:
# During first generation, sequence_length can be 0 when use_cache=True, which results in data_ptr to also be 0.
# To keep compatibility with IO binding, we pass the data pointer of a non-empty tensor.
# No impact because past_key_values will not be used during the first generation.
data_ptr = NON_EMPTY_TENSOR.data_ptr()
self._io_binding.bind_input(
input_name,
self.device.type,
self.device.index or 0,
TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]),
input_shape,
data_ptr,
)
if self.device.type == "cpu":
self.session.run_with_iobinding(io_binding)
for idx, axis_name in enumerate(self.input_shapes[input_name]):
if isinstance(axis_name, str):
known_axes_values[axis_name] = input_shape[idx]
output_shapes = {}
output_buffers = {}
known_output_shapes = known_output_shapes or {}
known_output_buffers = known_output_buffers or {}
outputs_to_not_bind = outputs_to_not_bind or set()
for output_name in self.output_names.keys():
if output_name in outputs_to_not_bind:
continue
if output_name in known_output_shapes:
output_shape = known_output_shapes[output_name]
else:
io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
output_shape = self._output_shape_inference(output_name, known_axes_values)
# Set -1 for sequence_length as it could be larger than the real sequence_length
for name, shape in output_shapes.items():
if name in self.key_value_output_names:
output_shapes[name] = shape[:2] + (-1,) + shape[3:]
if output_name in known_output_buffers:
output_buffer = known_output_buffers[output_name]
else:
output_buffer = self._prepare_output_buffer(output_name, output_shape)
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = ()
for name in self.key_value_output_names:
# TODO: this should be improved
if name in self.past_key_values_cross_attention_output_names and use_merged_cache:
continue
out_past_key_values += (output_buffers[name].view(output_shapes[name]),)
data_ptr = output_buffer.data_ptr()
logits = output_buffers["logits"].view(output_shapes["logits"])
self._io_binding.bind_output(
output_name,
self.device.type,
self.device.index or 0,
TypeHelper.ort_type_to_numpy_type(self.output_dtypes[output_name]),
output_shape,
data_ptr,
)
loss = None
if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
output_buffers[output_name] = output_buffer
output_shapes[output_name] = output_shape
if not self.use_past_in_outputs:
out_past_key_values = None
elif not self.use_past_in_inputs or use_merged_no_cache or self.no_cross_attention_cache:
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
)
else:
if self.use_legacy_outputs is True:
msg = (
"For the decoder with past, using ONNX models outputting cross attention past key values"
" is deprecated and the support will be removed in optimum 2.0. We recommend exporting again the model"
" with optimum>=1.7.3."
)
warn_once(logger, msg=msg)
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv]
for i in range(0, len(out_past_key_values), self.num_pkv)
)
# grab the cross attention key/values from the inputs
elif self.num_pkv == 2:
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv]
+ past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv]
for i in range(0, len(out_past_key_values), self.num_pkv)
)
elif self.num_pkv == 4:
# despite num_pkv being 4, we did not bind the cross-attention output
out_past_key_values = tuple(
out_past_key_values[i : i + 2] + past_key_values[2 * i + 2 : 2 * i + 4]
for i in range(0, len(out_past_key_values), 2)
)
else:
raise ValueError("Unsupported num_pkv")
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
return output_shapes, output_buffers
# TODO: using a new variable out_past_key_values is memory inefficient,
# past_key_values is not used anymore at this point
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)
def forward(self, *args, **kwargs):
    """Abstract inference entry point; concrete subclasses must override this method."""
    raise NotImplementedError(
        "The `forward` method should be implemented in the derived class. "
        "Please refer to the documentation for more details."
    )
loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]
def __call__(self, *args, **kwargs):
    """Makes the instance callable by delegating directly to :meth:`forward`."""
    return self.forward(*args, **kwargs)
# TODO: this is extremely ugly and unreadable. What if cross-attention k/v change?
# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
# * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
# * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
if not self.use_past_in_outputs:
out_past_key_values = None
elif not self.use_past_in_inputs or use_merged_no_cache or self.no_cross_attention_cache:
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
)
else:
if self.use_legacy_outputs is True:
msg = (
"For the decoder with past, using ONNX models outputting cross attention past key values"
" is deprecated and the support will be removed in optimum 2.0. We recommend exporting again the model"
" with optimum>=1.7.3."
)
warn_once(logger, msg=msg)
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv]
for i in range(0, len(out_past_key_values), self.num_pkv)
)
# grab the cross attention key/values from the inputs
elif self.num_pkv == 2:
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv]
+ past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv]
for i in range(0, len(out_past_key_values), self.num_pkv)
)
elif self.num_pkv == 4:
out_past_key_values = tuple(
out_past_key_values[i : i + 2] + past_key_values[i + 2 : i + 4]
for i in range(0, len(out_past_key_values), self.num_pkv)
)
else:
raise ValueError("Unsupported num_pkv")
def save_session(self, save_directory: Union[str, Path]):
"""
Saves the ONNX Runtime session to the specified directory.
return Seq2SeqLMOutput(loss=loss, logits=logits, past_key_values=out_past_key_values)
Args:
save_directory (`Union[str, Path]`):
The directory where to save the ONNX Runtime session.
"""
def prepare_inputs_for_merged(
self,
input_ids: Optional[Union[torch.LongTensor, np.ndarray]],
past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]],
cache_position: Optional[Union[torch.Tensor, np.ndarray]],
use_torch: bool,
):
constructor = torch if use_torch is True else np
os.makedirs(save_directory, exist_ok=True)
if self.parent_model.use_merged:
# Uses without/with branch of a merged decoder depending on whether real past key values are passed
use_cache_branch_tensor = constructor.full((1,), past_key_values is not None)
if use_torch and use_cache_branch_tensor is not None:
use_cache_branch_tensor = use_cache_branch_tensor.to(self.device)
else:
use_cache_branch_tensor = None
model_path = Path(self.session._model_path)
model_save_path = Path(save_directory) / model_path.name
external_data_paths = _get_model_external_data_paths(model_path)
external_data_save_paths = [
Path(save_directory) / external_data_path.name for external_data_path in external_data_paths
]
# Generate dummy past for the first forward if uses a merged decoder
if self.parent_model.use_merged and past_key_values is None:
batch_size = input_ids.shape[0]
num_attention_heads = self.normalized_config.num_attention_heads
embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
dtype = constructor.float16 if self.use_fp16 else constructor.float32
shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)
shutil.copy(model_path, model_save_path)
for src_path, dst_path in zip(external_data_paths, external_data_save_paths):
shutil.copy(src_path, dst_path)
if use_torch is True:
key_or_value = key_or_value.to(self.device)
past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
class ORTParentMixin:
"""
Wrapper class for multiple ORTSessionMixin instances. This class allows to combine multiple parts into
a single wrapper. It is useful for pipelines/models that require multiple parts to work together, such
as diffusion pipelines or encoder-decoder models, as it provides a unified interface for inference.
"""
# Generate dummy position cache for the first forward if uses a merged decoder
if self.parent_model.use_merged and cache_position is None:
cache_position = constructor.zeros((1,), dtype=constructor.int64)
if use_torch is True:
cache_position = cache_position.to(self.device)
def initialize_ort_attributes(self, parts: List[ORTSessionMixin]):
"""
Initializes the ORTParentMixin class.
Args:
parts (`List[ORTSessionMixin]`):
List of ORTSessionMixin instances to wrap.
"""
return use_cache_branch_tensor, past_key_values, cache_position
if len(parts) < 1:
raise ValueError("ORTParentMixin should be initialized with at least one part.")
if any(not isinstance(model, ORTSessionMixin) for model in parts):
raise ValueError("All parts passed to ORTParentMixin should be ORTSessionMixin instances.")
self.parts = parts
@property
def providers(self):
"""
Returns a list of Execution Providers registered with the session.
"""
if not all(model.providers == self.parts[0].providers for model in self.parts):
logger.warning(
"Calling `ORTParentMixin.providers` when the underlying parts have different values "
"for `providers` is not recommended. The value of the first session will be returned. "
)
return self.parts[0].providers
@property
def provider(self):
"""
Returns the main Execution Provider registered with the session.
"""
if not all(model.provider == self.parts[0].provider for model in self.parts):
logger.warning(
"Calling `ORTParentMixin.provider` when the underlying parts have different values "
"for `provider` is not recommended. The value of the first session will be returned. "
)
return self.parts[0].provider
@property
def provider_options(self):
"""
Returns a dictionary of Execution Providers configurations/options.
"""
if not all(model.provider_options == self.parts[0].provider_options for model in self.parts):
logger.warning(
"Calling `ORTParentMixin.provider_options` when the underlying parts have different values "
"for `provider_options` is not recommended. The value of the first session will be returned. "
)
return self.parts[0].provider_options
@property
def provider_option(self):
"""
Returns the configuration/options of the main Execution Provider.
"""
if not all(model.provider_option == self.parts[0].provider_option for model in self.parts):
logger.warning(
"Calling `ORTParentMixin.provider_option` when the underlying parts have different values "
"for `provider_option` is not recommended. The value of the first session will be returned. "
)
return self.parts[0].provider_option
@property
def device(self):
"""
Returns the `torch.device` associated with the ONNX Runtime session.
This device is inferred from the provider and provider options.
"""
if not all(model.device == self.parts[0].device for model in self.parts):
logger.warning(
"Calling `ORTParentMixin.device` when the underlying parts have different values "
"for `device` is not recommended. The value of the first session will be returned. "
)
return self.parts[0].device
@property
def dtype(self):
"""
Returns the `torch.dtype` associated with the ONNX Runtime session.
This dtype is inferred from the input/output dtypes of the session.
If no floating point type is found, it defaults to `torch.float32`.
"""
if not all(model.dtype == self.parts[0].dtype for model in self.parts):
logger.warning(
"Calling `ORTParentMixin.dtype` when the underlying parts have different values "
"for `dtype` is not recommended. The value of the first session will be returned. "
)
return self.parts[0].dtype
@property
def use_io_binding(self):
"""
Returns whether IO Binding is used or not.
"""
if not all(model.use_io_binding == self.parts[0].use_io_binding for model in self.parts):
logger.warning(
"Calling `ORTParentMixin.use_io_binding` when the underlying parts have different values "
"for `use_io_binding` is not recommended. The value of the first session will be returned. "
)
return self.parts[0].use_io_binding
@use_io_binding.setter
def use_io_binding(self, value: bool):
"""
Setter for the use_io_binding property.
"""
for model in self.parts:
model.use_io_binding = value
def to(self, *args, **kwargs):
"""
Moves all parts to the specified device by updating the execution provider and its options.
Args:
device (`str`, `int`, `torch.device`):
The device to move the session to. It can be a string (e.g., "cuda", "cpu"), an integer (e.g., 0 for GPU 0),
or a `torch.device` object.
Returns:
`ORTParentMixin`: The updated session.
Raises:
ValueError: If the device is not supported or if the provider is not available.
"""
for model in self.parts:
model.to(*args, **kwargs)
return self

@@ -21,5 +21,4 @@ # Copyright 2022 The HuggingFace Team. All rights reserved.

from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
import numpy as np
import onnx

@@ -34,5 +33,6 @@ import torch

import onnxruntime
from onnxruntime import InferenceSession, SessionOptions
from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
from ..exporters.tasks import TasksManager
from ..onnx.utils import check_model_uses_external_data

@@ -49,4 +49,3 @@ from ..utils import NormalizedConfigManager, is_transformers_version

from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
from .utils import ONNX_WEIGHTS_NAME
from .utils import prepare_providers_and_provider_options

@@ -138,73 +137,109 @@

self,
model: onnxruntime.InferenceSession,
config: "PretrainedConfig",
*args,
config: "PretrainedConfig" = None,
session: "InferenceSession" = None,
use_io_binding: Optional[bool] = None,
generation_config: Optional["GenerationConfig"] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
use_cache: Optional[bool] = None,
**kwargs,
):
if use_io_binding is None:
use_io_binding = model.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"]
# DEPRECATED BEHAVIOR
if args:
logger.warning(
"Instantiating an ORTModelForCausalLM with positional arguments is deprecated and will be removed in the next version. "
"Please use the keywords arguments {config, session, use_io_binding, generation_config, model_save_dir, use_cache} instead."
)
# the old signature is ORTModelForCausalLM(model, config, use_io_binding, model_save_dir, preprocessors, generation_config, use_cache)
session = args[0]
if len(args) > 1:
config = args[1]
if len(args) > 2:
use_io_binding = args[2]
if len(args) > 3:
model_save_dir = args[3]
if len(args) > 4:
_ = args[4]
if len(args) > 5:
generation_config = args[5]
if len(args) > 6:
_ = args[6]
super().__init__(model, config, use_io_binding, model_save_dir, preprocessors, **kwargs)
if kwargs.get("model", None) is not None:
logger.warning(
"Passing the inference session as `model` argument to an ORTModelForCausalLM is deprecated. Please use `session` instead."
)
session = kwargs.pop("model")
if kwargs:
logger.warning(
f"Some keyword arguments were passed to the ORTModelForCausalLM constructor that are not part of its signature: {', '.join(kwargs.keys())}. "
"These arguments will be ignored in the current version and will raise an error in the next version."
)
self.num_pkv = 2
if config is None:
raise ValueError(
"The parameter config is required. Please pass a config or use the from_pretrained method."
)
if session is None:
raise ValueError(
"The parameter session is required. Please pass a session or use the from_pretrained method."
)
## END OF DEPRECATED BEHAVIOR
super().__init__(config=config, session=session, use_io_binding=use_io_binding, model_save_dir=model_save_dir)
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
self.use_cache = len(self.key_value_input_names) > 0
if generation_config is None:
generation_config = GenerationConfig.from_model_config(config)
self.can_use_cache = len(self.key_value_input_names) > 0 and len(self.key_value_output_names) > 0
self.is_merged = "use_cache_branch" in self.input_names
self.generation_config = generation_config
if is_transformers_version(">=", "4.44.99"):
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()
if len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
"generation parameters in the model config, as opposed to in the generation config.",
)
for param_name, param_value in misplaced_generation_parameters.items():
setattr(self.generation_config, param_name, param_value)
setattr(self.config, param_name, None)
self.onnx_paths = [self.model_path]
self.use_merged = "use_cache_branch" in self.input_names
self.model_type = self.config.model_type
self.use_fp16 = False
for inp in model.get_inputs():
if (
inp.name == "past_key_values" or inp.name in self.key_value_input_names
) and inp.type == "tensor(float16)":
self.use_fp16 = True
break
# Reference: https://github.com/huggingface/optimum/pull/1381
model_type = config.model_type.replace("_", "-")
model_type = self.config.model_type.replace("_", "-")
if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names:
logger.warning(
f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. "
"We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support."
f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although the model type {model_type} "
"requires it. for correct batched generation. We strongly encourage to re-export the model with "
"a newer version of Optimum for better performance and more reliable generation. "
)
if use_cache ^ self.use_cache:
raise ValueError(
f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. "
f"Please load your current model with `use_cache={self.use_cache}` or export the original model "
f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. "
"To export your model, simply set `export=True`."
if not self.can_use_cache and self.generation_config.use_cache:
logger.warning(
"`model.generation_config.use_cache=True` but the loaded model does not support using the past key values cache."
"Please re-export the original model once again with `use_cache=True` to be able to use it during generation. "
"Or set `model.generation_config.use_cache=False` to avoid errors from attempting to use the cache. "
"To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
)
if use_io_binding and not use_cache:
raise ValueError(
"The parameters combination use_cache=False, use_io_binding=True is not supported. "
"Please either pass use_cache=True, use_io_binding=True (default), or use_cache=False, use_io_binding=False."
if self.config.model_type == "gemma":
self.embed_size_per_head = self.normalized_config.head_dim
else:
self.embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
if self.config.model_type in {"gemma", "mistral", "llama", "qwen2", "qwen3", "qwen3_moe", "granite"}:
self.num_key_value_heads = self.normalized_config.num_key_value_heads
elif self.config.model_type == "falcon":
self.num_key_value_heads = (
self.config.num_kv_heads
if (self.config.new_decoder_architecture or not self.config.multi_query)
else 1
)
else:
self.num_key_value_heads = self.normalized_config.num_attention_heads
@property
def use_cache(self):
logger.warning(
"The `ORTModelForCausalLM.use_cache` property is deprecated and will be removed in a future version. "
"Please rather use `ORTModelForCausalLM.can_use_cache` to check if a model supports using cache during generation. "
"And use `ORTModelForCausalLM.generation_config.use_cache` to check if the model is configured to use cache during generation."
)
return self.can_use_cache
@property
def use_merged(self):
logger.warning(
"The `ORTModelForCausalLM.use_merged` property is deprecated and will be removed in a future version. "
"Please rather use `ORTModelForCausalLM.is_merged` to check if the underlying model is merged or not."
)
return self.is_merged
@add_start_docstrings_to_model_forward(

@@ -221,34 +256,52 @@ CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length")

input_ids: torch.LongTensor,
attention_mask: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
use_cache_branch: bool = None,
use_cache: Optional[bool] = None,
**kwargs,
) -> CausalLMOutputWithPast:
# adding use_cache_branch in the signature here is just a hack for IO Binding
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)
use_cache = use_cache if use_cache is not None else self.config.use_cache
known_output_shapes = {}
if use_cache and not self.can_use_cache:
raise ValueError(
f"`use_cache={use_cache}` was passed to the model but the loaded model only supports `use_cache={self.can_use_cache}`. "
f"Please load your current model with `use_cache={self.can_use_cache}` or export the original model "
f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. "
"To re-export your model, simply set `export=True` in the `from_pretrained` method."
)
if self.use_cache:
if past_key_values is not None:
# Flatten the past_key_values (gpt_bigcode has fused key/value cache, so no need to flatten it)
if self.model_type != "gpt_bigcode":
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
if past_key_values is not None and isinstance(past_key_values[0], tuple):
# Flattens the past_key_values to a single tuple
past_key_values = sum(past_key_values, ())
# Create dummy past_key_values for decoder first generation step if none given
use_cache_branch, past_key_values, known_output_shapes = self.prepare_past_key_values(
input_ids, past_key_values, use_torch
)
if "position_ids" in self.input_names and position_ids is None:
if attention_mask is not None:
# Create position_ids from attention_mask
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values is not None:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
raise ValueError(
"The model requires position_ids for batched generation but none were provided. "
"Please provide position_ids or attention_mask (from which position_ids can be inferred)."
)
# Create position_ids on the fly for batch generation
if "position_ids" in self.input_names and position_ids is None and attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
use_cache_branch = None
if self.is_merged:
# Uses cache branch of merged decoders depending on whether real past key values are passed
use_cache_branch = torch.full((1,), past_key_values is not None, dtype=torch.bool, device=self.device)
if past_key_values is None and len(self.key_value_input_names) > 0:
# Generates the input pkv for the first forward of the model (merged or with past)
batch_size, seq_len = input_ids.shape
if self.config.model_type == "gpt_bigcode":
shape = (batch_size, 0, self.embed_size_per_head * 2)
else:
shape = (batch_size, self.num_key_value_heads, 0, self.embed_size_per_head)
tensor = torch.empty(shape, dtype=self.dtype, device=self.device)
past_key_values = tuple(tensor for _ in range(len(self.key_value_input_names)))
model_inputs = {

@@ -260,19 +313,34 @@ "input_ids": input_ids,

}
if len(self.key_value_input_names) > 0:
model_inputs.update(zip(self.key_value_input_names, past_key_values))
if past_key_values is not None:
model_inputs.update(
zip(self.key_value_input_names, past_key_values),
)
known_output_shapes = None
outputs_to_not_bind = None
if use_cache:
# Infers the shape of the output pkv
batch_size, seq_len = input_ids.shape
if self.config.model_type == "gpt_bigcode":
pkv_seq_len, embed_size_per_head_2 = past_key_values[0].shape[1:]
pkv_output_shape = (batch_size, pkv_seq_len + seq_len, embed_size_per_head_2)
else:
num_key_value_heads, pkv_seq_len, embed_size_per_head = past_key_values[0].shape[1:]
pkv_output_shape = (batch_size, num_key_value_heads, pkv_seq_len + seq_len, embed_size_per_head)
known_output_shapes = dict.fromkeys(self.key_value_output_names, pkv_output_shape)
else:
# Don't bind the output pkv if not used/returned
outputs_to_not_bind = self.key_value_output_names
if self.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(
self.model, model_inputs, known_output_shapes=known_output_shapes
output_shapes, output_buffers = self._prepare_io_binding(
model_inputs,
outputs_to_not_bind=outputs_to_not_bind,
known_output_shapes=known_output_shapes,
)
if self.device.type == "cpu":
self.model.run_with_iobinding(io_binding)
self.session.run_with_iobinding(self._io_binding)
else:
io_binding.synchronize_inputs()
self.model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()

@@ -282,115 +350,83 @@ loss = output_buffers.get("loss", None)

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2 for the self-attention)
if use_cache:
past_key_values = tuple(
output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names
output_buffers.pop(name).view(output_shapes[name]) for name in self.key_value_output_names
)
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]
loss = model_outputs.pop("loss", None)
logits = model_outputs.pop("logits")
if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention)
past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)
if use_cache:
past_key_values = tuple(model_outputs.pop(name) for name in self.key_value_output_names)
if self.use_cache and self.model_type != "gpt_bigcode":
if use_cache and self.config.model_type != "gpt_bigcode":
# Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and per decoder layer
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
past_key_values = tuple(past_key_values[i : i + 2] for i in range(0, len(past_key_values), 2))
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values)
def prepare_past_key_values(
def prepare_inputs_for_generation(self, *args, **kwargs):
if is_transformers_version("<", "4.46.0"):
return self._prepare_inputs_for_generation_legacy(*args, **kwargs)
else:
return super().prepare_inputs_for_generation(*args, **kwargs)
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def _prepare_inputs_for_generation_legacy(
self,
input_ids: Union[None, torch.LongTensor, np.ndarray],
past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
use_torch: bool,
input_ids,
attention_mask=None,
past_key_values=None,
token_type_ids=None,
position_ids=None,
use_cache=None,
**kwargs,
):
sequence_length = input_ids.shape[1]
constructor = torch if use_torch else np
if self.use_merged:
# Uses without/with branch of a merged decoder depending on whether real past key values are passed
use_cache_branch = constructor.full((1,), past_key_values is not None)
else:
# Uses separate decoders
use_cache_branch = None
if use_torch and use_cache_branch is not None:
use_cache_branch = use_cache_branch.to(self.device)
pkv_output_shape = {}
# Generate dummy past for the first forward if uses a merged decoder
if past_key_values is None:
batch_size = input_ids.shape[0]
embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
if self.model_type == "gemma":
num_attention_heads = self.normalized_config.num_key_value_heads
embed_size_per_head = self.normalized_config.head_dim
elif self.model_type in {"mistral", "llama", "qwen2", "granite"}:
num_attention_heads = self.normalized_config.num_key_value_heads
if past_key_values is not None:
if self.config.model_type == "gpt_bigcode":
if self.config.multi_query:
past_length = past_key_values[0].shape[1]
else:
past_length = past_key_values[0].shape[2]
else:
num_attention_heads = self.normalized_config.num_attention_heads
past_length = past_key_values[0][0].shape[2]
dtype = constructor.float16 if self.use_fp16 else constructor.float32
# TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
if self.__class__.__name__ == "ORTBloomForCausalLM":
shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0)
key = constructor.zeros(shape_key, dtype=dtype)
value = constructor.zeros(shape_value, dtype=dtype)
if use_torch:
key = key.to(self.device)
value = value.to(self.device)
past_key_values = tuple(
key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
)
for name, value in zip(self.key_value_output_names, past_key_values):
shape = [*value.shape]
index = 1 if "value" in name else 2
shape[index] += sequence_length
pkv_output_shape[name] = shape
elif self.model_type == "gpt_bigcode":
# GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor.
shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)
if use_torch:
key_and_value = key_and_value.to(self.device)
past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
for name, value in zip(self.key_value_output_names, past_key_values):
shape = [*value.shape]
shape[1] += sequence_length
pkv_output_shape[name] = shape
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads
shape = (batch_size, num_key_value_heads, 0, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if use_torch:
key_or_value = key_or_value.to(self.device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
"use_cache": use_cache,
}
past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], tuple):
# GPT2 style
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)
elif isinstance(past_key_values, tuple) and isinstance(past_key_values[0], torch.Tensor):
# GPT BigCode style
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
else:
raise ValueError(
f"Unexpected past_key_values: {past_key_values}. "
"Expected tuple of tuples (GPT2 style) or tuple of tensors (GPT BigCode style)."
)
for name, value in zip(self.key_value_output_names, past_key_values):
shape = [*value.shape]
shape[2] += sequence_length
pkv_output_shape[name] = shape
return use_cache_branch, past_key_values, pkv_output_shape
@classmethod

@@ -401,29 +437,25 @@ def _from_pretrained(

config: "PretrainedConfig",
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
# file options
file_name: Optional[str] = None,
subfolder: str = "",
# session options
provider: str = "CPUExecutionProvider",
providers: Optional[Sequence[str]] = None,
provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None,
session_options: Optional[SessionOptions] = None,
# inference options
use_cache: bool = True,
local_files_only: bool = False,
use_merged: Optional[bool] = None,
provider: str = "CPUExecutionProvider",
session_options: Optional[onnxruntime.SessionOptions] = None,
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
generation_config: Optional[GenerationConfig] = None,
# other arguments
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
) -> "ORTModelForCausalLM":
generation_config = kwargs.pop("generation_config", None)
# We do not implement the logic for use_cache=False, use_merged=True
if use_cache is False:
if use_merged is True:
raise ValueError(
"The parameters combination use_cache=False, use_merged=True is not supported."
" To use a merged decoder, past key values must be used."
)
use_merged = False
onnx_files = find_files_matching_pattern(

@@ -504,4 +536,8 @@ model_id,

)
new_model_save_dir = Path(model_cache_path).parent
# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
# instead of the path only.
if model_save_dir is None:
model_save_dir = Path(model_cache_path).parent
try:

@@ -523,13 +559,7 @@ cached_file(

# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
# instead of the path only.
if model_save_dir is None:
model_save_dir = new_model_save_dir
# This should be removed at some point
onnx_model = onnx.load(str(model_cache_path), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)
if model_uses_external_data:
onnx_model = onnx.load(str(model_cache_path), load_external_data=True)
input_dims = {

@@ -543,5 +573,3 @@ node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]

}
override_dims = False
# Since v1.7.0 decoder with past models have fixed sequence length of 1

@@ -553,3 +581,2 @@ # To keep these models compatible we set this dimension to dynamic

override_dims = True
# Since https://github.com/huggingface/optimum/pull/871/

@@ -561,3 +588,2 @@ # changed axis notation/naming during export, we need to update the dims

override_dims = True
if override_dims:

@@ -582,32 +608,11 @@ # this is kinda dangerous, warning the user is the least we can do

)
# Since transformers 4.44, the bloom model has been updated to use the standard cache format
use_old_bloom_modeling = not is_transformers_version(">=", "4.44")
for input_name in input_dims.keys():
if input_dims[input_name][0] == "batch_size x num_heads":
use_old_bloom_modeling = True
del onnx_model
model = ORTModel.load_model(
model_cache_path,
provider=provider,
session_options=session_options,
provider_options=provider_options,
)
# Important: for encoder-decoder models used with CausalLM, we need to set the is_decoder flag to True
# and the is_encoder_decoder flag to False. This is needed for the model to work correctly with generation logic.
if hasattr(config, "is_decoder"):
config.is_decoder = True
if hasattr(config, "is_encoder_decoder"):
config.is_encoder_decoder = False
if config.model_type == "bloom" and use_old_bloom_modeling:
init_cls = ORTBloomForCausalLM
elif config.model_type == "falcon":
init_cls = ORTFalconForCausalLM
elif config.model_type == "mpt":
init_cls = ORTMPTForCausalLM
# if model was exported with position_ids it means the model was exported with transformers >= v4.46
elif config.model_type == "opt" and "position_ids" not in input_dims:
init_cls = ORTOPTForCausalLM
elif config.model_type == "gpt_bigcode":
init_cls = ORTGPTBigCodeForCausalLM
else:
init_cls = ORTModelForCausalLM
if generation_config is None:

@@ -617,55 +622,76 @@ try:

model_id,
token=token,
revision=revision,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
logger.info("Generation config file not found, creating a new one from model config.")
generation_config = GenerationConfig.from_model_config(config)
# TODO: not sure if setting config.use_cache is needed for older versions of transformers
generation_config.use_cache = use_cache
config.use_cache = use_cache
if is_transformers_version(">=", "4.45.0"):
misplaced_generation_parameters = config._get_non_default_generation_parameters()
if len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
"generation parameters in the model config, as opposed to in the generation config.",
)
for param_name, param_value in misplaced_generation_parameters.items():
setattr(generation_config, param_name, param_value)
setattr(config, param_name, None)
return init_cls(
model=model,
providers, provider_options = prepare_providers_and_provider_options(
provider=provider, providers=providers, provider_options=provider_options
)
session = InferenceSession(
model_cache_path,
providers=providers,
provider_options=provider_options,
sess_options=session_options,
)
return cls(
config=config,
session=session,
use_io_binding=use_io_binding,
generation_config=generation_config,
model_save_dir=model_save_dir,
use_cache=use_cache,
generation_config=generation_config,
)
@classmethod
def _from_transformers(
def _export(
cls,
model_id: str,
model_id: Union[str, Path],
config: "PretrainedConfig",
token: Optional[Union[bool, str]] = None,
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = True,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
subfolder: str = "",
force_download: bool = False,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
# inference options
use_cache: bool = True,
use_merged: bool = False,
provider: str = "CPUExecutionProvider",
session_options: Optional[onnxruntime.SessionOptions] = None,
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
task: Optional[str] = None,
**kwargs,
) -> "ORTModelForCausalLM":
file_name = ONNX_WEIGHTS_NAME
# this is garanteed to work since we it uses a mapping from model classes to task names
# instead of relying on the hub metadata or the model configuration
task = TasksManager._infer_task_from_model_or_model_class(model_class=cls.auto_model_class)
if use_cache:
task += "-with-past"
if use_merged:
logger.warning("The `use_merged` argument is deprecated when the model is exported, and not used anymore.")
use_merged = False
if kwargs.get("task", None) is not None:
raise ValueError(
f"The `task` argument is not needed when exporting a model with `{cls.__name__}`. "
f"The `task` is automatically inferred from the class as `{task}`."
)
if task is None:
task = cls._auto_model_to_task(cls.auto_model_class)
if use_cache:
task += "-with-past"
save_dir = TemporaryDirectory()

@@ -695,291 +721,15 @@ save_dir_path = Path(save_dir.name)

use_cache=use_cache,
use_merged=use_merged,
provider=provider,
session_options=session_options,
provider_options=provider_options,
use_io_binding=use_io_binding,
model_save_dir=save_dir,
file_name=file_name,
)
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    """Trim cached tokens from `input_ids` and build the per-step generation inputs.

    Derives `position_ids` from `attention_mask` when not supplied; keeps only the
    last position when a cache is present.
    """
    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)
    use_cache = kwargs.get("use_cache", None)

    if past_key_values is not None:
        cached_length = past_key_values[0][0].shape[2]
        seen_length = input_ids.shape[1]
        # Some generation methods already pass only the last input ID; otherwise
        # drop every token already covered by the cache.
        keep_from = cached_length if seen_length > cached_length else seen_length - 1
        input_ids = input_ids[:, keep_from:]

    if position_ids is None and attention_mask is not None:
        # Create position_ids on the fly for batch generation.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": use_cache,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
def _save_pretrained(self, save_directory: Union[str, Path]):
    """Save the model artifacts via the parent class, then the generation config alongside them."""
    super()._save_pretrained(save_directory)
    # Also persist the generation config next to the model files.
    self.generation_config.save_pretrained(save_directory)
class ORTGPTBigCodeForCausalLM(ORTModelForCausalLM):
    """ORT causal-LM specialization for GPT BigCode, whose cache fuses key and value per layer."""

    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if past_key_values:
            # Omit tokens covered by past_key_values. With multi-query attention the
            # sequence length sits on dim 1 of the fused cache, otherwise on dim 2.
            cache_seq_dim = 1 if self.config.multi_query else 2
            past_length = past_key_values[0].shape[cache_seq_dim]
            seen_length = input_ids.shape[1]
            # Some generation methods already pass only the last input ID;
            # default to keeping only the final token otherwise.
            keep_from = past_length if seen_length > past_length else seen_length - 1
            input_ids = input_ids[:, keep_from:]

        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            position_ids = None

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

    # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        reordered_layers = []
        for layer_past in past_key_values:
            reordered_layers.append(layer_past.index_select(0, beam_idx.to(layer_past.device)))
        return tuple(reordered_layers)
class ORTBloomForCausalLM(ORTModelForCausalLM):
    """ORT causal-LM specialization for Bloom's legacy fused (batch * heads) cache layout."""

    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        if past_key_values is not None:
            cached_length = past_key_values[0][0].shape[2]
            seen_length = input_ids.shape[1]
            # Some generation methods already pass only the last input ID;
            # default to keeping only the final token otherwise.
            keep_from = cached_length if seen_length > cached_length else seen_length - 1
            input_ids = input_ids[:, keep_from:]

        if past_key_values:
            # The cache may be in the standard format (e.g. in contrastive search):
            # convert to bloom's fused format if needed.
            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                past_key_values = bloom_convert_to_bloom_cache(past_key_values)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": None,
            "attention_mask": attention_mask,
        }

    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        standardized_past = bloom_convert_to_standard_cache(past, batch_size=len(beam_idx))
        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_layers = []
        for layer_past in standardized_past:
            indices = device_to_beam_idx[layer_past[0].device]
            reordered_layers.append(
                (layer_past[0].index_select(0, indices), layer_past[1].index_select(0, indices))
            )
        return bloom_convert_to_bloom_cache(tuple(reordered_layers))
class ORTOPTForCausalLM(ORTModelForCausalLM):
    """ORT causal-LM specialization for OPT models exported without position_ids."""

    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        if past_key_values is not None:
            cached_length = past_key_values[0][0].shape[2]
            seen_length = input_ids.shape[1]
            # Some generation methods already pass only the last input ID;
            # default to keeping only the final token otherwise.
            keep_from = cached_length if seen_length > cached_length else seen_length - 1
            input_ids = input_ids[:, keep_from:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", None),
            "position_ids": None,
            "attention_mask": kwargs.get("attention_mask", None),
        }
class ORTMPTForCausalLM(ORTModelForCausalLM):
    """ORT causal-LM specialization for MPT models, which do not consume position_ids."""

    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            if input_ids.shape[1] > past_length:
                # Drop tokens already covered by the cache.
                input_ids = input_ids[:, past_length:]
            else:
                # Some generation methods already pass only the last input ID:
                # keep only the final token.
                input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": None,
            "attention_mask": attention_mask,
        }
class ORTFalconForCausalLM(ORTModelForCausalLM):
def __init__(
    self,
    model: onnxruntime.InferenceSession,
    config: "PretrainedConfig",
    use_io_binding: Optional[bool] = None,
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    preprocessors: Optional[List] = None,
    generation_config: Optional[GenerationConfig] = None,
    use_cache: Optional[bool] = None,
    **kwargs,
):
    """Initialize the Falcon ORT wrapper and cache Falcon-specific attention settings."""
    super().__init__(
        model=model,
        config=config,
        use_io_binding=use_io_binding,
        model_save_dir=model_save_dir,
        preprocessors=preprocessors,
        generation_config=generation_config,
        use_cache=use_cache,
        **kwargs,
    )
    # New-decoder-architecture (and non-multi-query) Falcon checkpoints keep separate
    # KV heads; the legacy multi-query variant shares a single KV head.
    if config.new_decoder_architecture or not config.multi_query:
        self.num_key_value_heads = config.num_kv_heads
    else:
        self.num_key_value_heads = 1
    # Alibi-based Falcon variants do not use position_ids (see prepare_inputs_for_generation).
    self.use_alibi = config.alibi
# Copied from transformers.models.falcon.modeling_falcon.FalconForCausalLM._reorder_cache
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
def _save_config(self, save_directory):
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Save the model and generation configs to the specified directory.
Output shares the same memory storage as `past`.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the model and generation configs will be saved.
"""
# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in past
)
return reordered_past
# Adapted from transformers.models.falcon.modeling_falcon.FalconForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    """Build per-step generation inputs for Falcon, honoring its alibi/RoPE split."""
    if past_key_values is not None:
        cached_length = past_key_values[0][0].shape[2]
        seen_length = input_ids.shape[1]
        # Some generation methods already pass only the last input ID;
        # default to keeping only the final token otherwise.
        keep_from = cached_length if seen_length > cached_length else seen_length - 1
        input_ids = input_ids[:, keep_from:]

    # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
    if not self.use_alibi and attention_mask is not None and position_ids is None:
        # Create position_ids on the fly for batch generation.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
    }
self.config.save_pretrained(save_directory)
self.generation_config.save_pretrained(save_directory)

@@ -19,8 +19,6 @@ # Copyright 2023 The HuggingFace Team. All rights reserved.

import os
import shutil
from abc import abstractmethod
from collections import OrderedDict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union

@@ -47,4 +45,4 @@ import numpy as np

from diffusers.utils.constants import CONFIG_NAME
from huggingface_hub import HfApi
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import HfApi, create_repo
from huggingface_hub.utils import validate_hf_hub_args

@@ -56,7 +54,5 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer

import onnxruntime as ort
from optimum.utils import is_diffusers_version
from onnxruntime import InferenceSession, SessionOptions
from ..exporters.onnx import main_export
from ..onnx.utils import _get_model_external_data_paths
from ..utils import (

@@ -70,12 +66,8 @@ DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,

DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
)
from .io_binding import TypeHelper
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .utils import (
DIFFUSION_PIPELINE_CONFIG_FILE_NAME,
ONNX_WEIGHTS_NAME,
get_provider_for_device,
np_to_pt_generators,
parse_device,
validate_provider_availability,
is_diffusers_version,
)
from .base import ORTParentMixin, ORTSessionMixin
from .utils import get_device_for_provider, np_to_pt_generators, prepare_providers_and_provider_options

@@ -93,6 +85,7 @@

# TODO: support from_pipe()
# TODO: Instead of ORTModel, it makes sense to have a compositional ORTMixin
# TODO: instead of one bloated __init__, we should consider an __init__ per pipeline
class ORTDiffusionPipeline(ORTModel, DiffusionPipeline):
config_name = "model_index.json"
class ORTDiffusionPipeline(ORTParentMixin, DiffusionPipeline):
config_name = DIFFUSION_PIPELINE_CONFIG_FILE_NAME
task = "auto"
library = "diffusers"
auto_model_class = DiffusionPipeline

@@ -102,12 +95,13 @@

self,
scheduler: "SchedulerMixin",
vae_decoder_session: ort.InferenceSession,
# optional pipeline models
unet_session: Optional[ort.InferenceSession] = None,
transformer_session: Optional[ort.InferenceSession] = None,
vae_encoder_session: Optional[ort.InferenceSession] = None,
text_encoder_session: Optional[ort.InferenceSession] = None,
text_encoder_2_session: Optional[ort.InferenceSession] = None,
text_encoder_3_session: Optional[ort.InferenceSession] = None,
# optional pipeline submodels
*,
# pipeline models
unet_session: Optional["InferenceSession"] = None,
transformer_session: Optional["InferenceSession"] = None,
vae_decoder_session: Optional["InferenceSession"] = None,
vae_encoder_session: Optional["InferenceSession"] = None,
text_encoder_session: Optional["InferenceSession"] = None,
text_encoder_2_session: Optional["InferenceSession"] = None,
text_encoder_3_session: Optional["InferenceSession"] = None,
# pipeline submodels
scheduler: Optional["SchedulerMixin"] = None,
tokenizer: Optional["CLIPTokenizer"] = None,

@@ -126,22 +120,57 @@ tokenizer_2: Optional["CLIPTokenizer"] = None,

):
self.unet = ORTModelUnet(unet_session, self) if unet_session is not None else None
self.transformer = ORTModelTransformer(transformer_session, self) if transformer_session is not None else None
# We initialize all ort session mixins first
self.unet = ORTUnet(unet_session, self, use_io_binding) if unet_session is not None else None
self.transformer = (
ORTTransformer(transformer_session, self, use_io_binding) if transformer_session is not None else None
)
self.text_encoder = (
ORTModelTextEncoder(text_encoder_session, self) if text_encoder_session is not None else None
ORTTextEncoder(text_encoder_session, self, use_io_binding) if text_encoder_session is not None else None
)
self.text_encoder_2 = (
ORTModelTextEncoder(text_encoder_2_session, self) if text_encoder_2_session is not None else None
ORTTextEncoder(text_encoder_2_session, self, use_io_binding)
if text_encoder_2_session is not None
else None
)
self.text_encoder_3 = (
ORTModelTextEncoder(text_encoder_3_session, self) if text_encoder_3_session is not None else None
ORTTextEncoder(text_encoder_3_session, self, use_io_binding)
if text_encoder_3_session is not None
else None
)
# We wrap the VAE Decoder & Encoder in a single object to simulate diffusers API
self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) if vae_encoder_session is not None else None
self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) if vae_decoder_session is not None else None
self.vae = ORTWrapperVae(self.vae_encoder, self.vae_decoder)
self.vae_encoder = (
ORTVaeEncoder(vae_encoder_session, self, use_io_binding) if vae_encoder_session is not None else None
)
self.vae_decoder = (
ORTVaeDecoder(vae_decoder_session, self, use_io_binding) if vae_decoder_session is not None else None
)
# We register ort session mixins to the wrapper
super().initialize_ort_attributes(
parts=list(
filter(
None,
{
self.unet,
self.transformer,
self.vae_encoder,
self.vae_decoder,
self.text_encoder,
self.text_encoder_2,
self.text_encoder_3,
},
)
)
)
# We wrap the VAE Encoder & Decoder in a single object for convenience
self.vae = (
ORTVae(self.vae_encoder, self.vae_decoder)
if self.vae_encoder is not None or self.vae_decoder is not None
else None
)
# we allow passing these as torch models for now
self.image_encoder = kwargs.pop("image_encoder", None) # TODO: maybe implement ORTModelImageEncoder
self.safety_checker = kwargs.pop("safety_checker", None) # TODO: maybe implement ORTModelSafetyChecker
self.image_encoder = kwargs.pop("image_encoder", None) # TODO: maybe implement ORTImageEncoder
self.safety_checker = kwargs.pop("safety_checker", None) # TODO: maybe implement ORTSafetyChecker
# We register the submodels to the pipeline
self.scheduler = scheduler

@@ -153,2 +182,3 @@ self.tokenizer = tokenizer

# We initialize diffusers pipeline specific attributes (registers modules and config)
all_pipeline_init_args = {

@@ -172,3 +202,2 @@ "vae": self.vae,

}
diffusers_pipeline_args = {}

@@ -178,128 +207,170 @@ for key in inspect.signature(self.auto_model_class).parameters.keys():

diffusers_pipeline_args[key] = all_pipeline_init_args[key]
# inits diffusers pipeline specific attributes (registers modules and config)
self.auto_model_class.__init__(self, **diffusers_pipeline_args)
# inits ort specific attributes
self.shared_attributes_init(
model=unet_session if unet_session is not None else transformer_session,
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
**kwargs,
)
# This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
# would end-up removing the directory containing the underlying ONNX model (and thus failing inference).
self.model_save_dir = model_save_dir
def _save_pretrained(self, save_directory: Union[str, Path]):
save_directory = Path(save_directory)
models_to_save_paths = {
(self.unet, save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER),
(self.transformer, save_directory / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER),
(self.vae_decoder, save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER),
(self.vae_encoder, save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER),
(self.text_encoder, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER),
(self.text_encoder_2, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER),
(self.text_encoder_3, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER),
@property
def components(self) -> Dict[str, Optional[Union[ORTSessionMixin, torch.nn.Module]]]:
# TODO: all components should be ORTSessionMixin's at some point
components = {
"vae": self.vae,
"unet": self.unet,
"transformer": self.transformer,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
"text_encoder_3": self.text_encoder_3,
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
}
for model, save_path in models_to_save_paths:
if model is not None:
model_path = Path(model.session._model_path)
save_path.mkdir(parents=True, exist_ok=True)
# copy onnx model
shutil.copyfile(model_path, save_path / ONNX_WEIGHTS_NAME)
# copy external onnx data
external_data_paths = _get_model_external_data_paths(model_path)
for external_data_path in external_data_paths:
shutil.copyfile(external_data_path, save_path / external_data_path.name)
# copy model config
config_path = model_path.parent / CONFIG_NAME
if config_path.is_file():
config_save_path = save_path / CONFIG_NAME
shutil.copyfile(config_path, config_save_path)
components = {k: v for k, v in components.items() if v is not None}
return components
self.scheduler.save_pretrained(save_directory / "scheduler")
def to(self, device: Union[torch.device, str, int]):
"""
Changes the device of the pipeline components to the specified device.
if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_directory / "tokenizer")
if self.tokenizer_2 is not None:
self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2")
if self.tokenizer_3 is not None:
self.tokenizer_3.save_pretrained(save_directory / "tokenizer_3")
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
Args:
device (`torch.device` or `str` or `int`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run
the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too.
Returns:
`ORTDiffusionPipeline`: The pipeline with the updated device.
"""
for component in self.components.values():
if isinstance(component, (ORTSessionMixin, ORTParentMixin)):
component.to(device)
return self
@classmethod
def _from_pretrained(
def from_pretrained(
cls,
model_id: Union[str, Path],
config: Dict[str, Any],
subfolder: str = "",
force_download: bool = False,
local_files_only: bool = False,
revision: Optional[str] = None,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
unet_file_name: str = ONNX_WEIGHTS_NAME,
transformer_file_name: str = ONNX_WEIGHTS_NAME,
vae_decoder_file_name: str = ONNX_WEIGHTS_NAME,
vae_encoder_file_name: str = ONNX_WEIGHTS_NAME,
text_encoder_file_name: str = ONNX_WEIGHTS_NAME,
text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME,
text_encoder_3_file_name: str = ONNX_WEIGHTS_NAME,
model_name_or_path: Union[str, Path],
# export options
export: bool = False,
# session options
provider: str = "CPUExecutionProvider",
providers: Optional[Sequence[str]] = None,
provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None,
session_options: Optional[SessionOptions] = None,
# inference options
use_io_binding: Optional[bool] = None,
provider: str = "CPUExecutionProvider",
provider_options: Optional[Dict[str, Any]] = None,
session_options: Optional[ort.SessionOptions] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
# hub options and preloaded models
**kwargs,
):
if use_io_binding:
raise ValueError(
"IOBinding is not yet available for diffusion pipelines, please set `use_io_binding` to False."
"""
Instantiates a [`ORTDiffusionPipeline`] with ONNX Runtime sessions from a pretrained model.
This method can be used to load a model from the Hugging Face Hub or from a local directory.
Args:
model_name_or_path (`str` or `os.PathLike`):
Path to a folder containing the model files or a hub repository id.
export (`bool`, *optional*, defaults to `False`):
Whether to export the model to ONNX format. If set to `True`, the model will be exported and saved
in the specified directory.
provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`):
The execution provider for ONNX Runtime. Can be `"CUDAExecutionProvider"`, `"DmlExecutionProvider"`,
etc.
providers (`Sequence[str]`, *optional*):
A list of execution providers for ONNX Runtime. Overrides `provider`.
provider_options (`Union[Sequence[Dict[str, Any]], Dict[str, Any]]`, *optional*):
Options for each execution provider. Can be a single dictionary for the first provider or a list of
dictionaries for each provider. The order of the dictionaries should match the order of the providers.
session_options (`SessionOptions`, *optional*):
Options for the ONNX Runtime session. Can be used to set optimization levels, graph optimization,
etc.
use_io_binding (`bool`, *optional*):
Whether to use IOBinding for the ONNX Runtime session. If set to `True`, it will use IOBinding for
input and output tensors.
**kwargs:
Can include the following:
- Export arguments (e.g., `slim`, `dtype`, `device`, `no_dynamic_axes`, etc.).
- Hugging Face Hub arguments (e.g., `revision`, `cache_dir`, `force_download`, etc.).
- Preloaded models or sessions for the different components of the pipeline (e.g., `vae_encoder_session`,
`vae_decoder_session`, `unet_session`, `transformer_session`, `image_encoder`, `safety_checker`, etc.).
Returns:
[`ORTDiffusionPipeline`]: The loaded pipeline with ONNX Runtime sessions.
"""
providers, provider_options = prepare_providers_and_provider_options(
provider=provider, providers=providers, provider_options=provider_options
)
hub_kwargs = {
"force_download": kwargs.get("force_download", False),
"resume_download": kwargs.get("resume_download", None),
"local_files_only": kwargs.get("local_files_only", False),
"cache_dir": kwargs.get("cache_dir", None),
"revision": kwargs.get("revision", None),
"proxies": kwargs.get("proxies", None),
"token": kwargs.get("token", None),
}
# get the pipeline config
config = cls.load_config(model_name_or_path, **hub_kwargs)
config = config[0] if isinstance(config, tuple) else config
model_save_tmpdir = None
model_save_path = Path(model_name_or_path)
# export the model if requested
if export:
model_save_tmpdir = TemporaryDirectory()
model_save_path = Path(model_save_tmpdir.name)
export_kwargs = {
"slim": kwargs.pop("slim", False),
"dtype": kwargs.pop("dtype", None),
"device": get_device_for_provider(provider, {}).type,
"no_dynamic_axes": kwargs.pop("no_dynamic_axes", False),
}
main_export(
model_name_or_path=model_name_or_path,
# export related arguments
output=model_save_path,
no_post_process=True,
do_validation=False,
task=cls.task,
# export related arguments
**export_kwargs,
# hub related arguments
**hub_kwargs,
)
if not os.path.isdir(str(model_id)):
# download the model if needed
if not model_save_path.is_dir():
# everything in components subfolders
all_components = {key for key in config.keys() if not key.startswith("_")} | {"vae_encoder", "vae_decoder"}
allow_patterns = {os.path.join(component, "*") for component in all_components}
# plus custom file names
allow_patterns.update(
{
unet_file_name,
transformer_file_name,
vae_decoder_file_name,
vae_encoder_file_name,
text_encoder_file_name,
text_encoder_2_file_name,
text_encoder_3_file_name,
ONNX_WEIGHTS_NAME,
DIFFUSION_PIPELINE_CONFIG_FILE_NAME,
SCHEDULER_CONFIG_NAME,
cls.config_name,
CONFIG_NAME,
}
)
model_save_folder = HfApi(user_agent=http_user_agent(), token=token).snapshot_download(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
model_save_folder = HfApi(user_agent=http_user_agent()).snapshot_download(
repo_id=str(model_name_or_path),
allow_patterns=allow_patterns,
ignore_patterns=["*.msgpack", "*.safetensors", "*.bin", "*.xml"],
allow_patterns=allow_patterns,
**hub_kwargs,
)
else:
model_save_folder = str(model_id)
model_save_path = Path(model_save_folder)
model_save_path = Path(model_save_folder)
if model_save_dir is None:
model_save_dir = model_save_path
model_paths = {
"unet": model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
"transformer": model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER / transformer_file_name,
"vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
"vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
"text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
"text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name,
"text_encoder_3": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER / text_encoder_3_file_name,
"unet": model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME,
"transformer": model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / ONNX_WEIGHTS_NAME,
"text_encoder_3": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER / ONNX_WEIGHTS_NAME,
}
models = {}
sessions = {}

@@ -309,6 +380,12 @@ for model, path in model_paths.items():

# this allows passing a model directly to from_pretrained
sessions[f"{model}_session"] = kwargs.pop(model)
else:
sessions[f"{model}_session"] = (
ORTModel.load_model(path, provider, session_options, provider_options) if path.is_file() else None
models[model] = kwargs.pop(model)
elif kwargs.get(f"{model}_session", None) is not None:
# this allows passing a session directly to from_pretrained
sessions[f"{model}_session"] = kwargs.pop(f"{model}_session")
elif path.is_file():
sessions[f"{model}_session"] = InferenceSession(
path,
providers=providers,
provider_options=provider_options,
sess_options=session_options,
)

@@ -331,6 +408,6 @@

# same as DiffusionPipeline.from_pretraoned, if called directly, it loads the class in the config
# Same as DiffusionPipeline.from_pretrained
if cls.__name__ == "ORTDiffusionPipeline":
class_name = config["_class_name"]
ort_pipeline_class = _get_ort_class(class_name)
pipeline_class_name = config["_class_name"]
ort_pipeline_class = _get_ort_class(pipeline_class_name)
else:

@@ -343,153 +420,108 @@ ort_pipeline_class = cls

use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
model_save_dir=model_save_tmpdir,
**models,
**kwargs,
)
# same as in DiffusionPipeline.from_pretrained, we save where the model was instantiated from
ort_pipeline.register_to_config(_name_or_path=config.get("_name_or_path", str(model_id)))
ort_pipeline.register_to_config(**config)
ort_pipeline.register_to_config(_name_or_path=config.get("_name_or_path", model_name_or_path))
return ort_pipeline
@classmethod
def _export(
cls,
model_id: str,
config: Dict[str, Any],
subfolder: str = "",
force_download: bool = False,
local_files_only: bool = False,
revision: Optional[str] = None,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
use_io_binding: Optional[bool] = None,
provider: str = "CPUExecutionProvider",
session_options: Optional[ort.SessionOptions] = None,
provider_options: Optional[Dict[str, Any]] = None,
task: Optional[str] = None,
def save_pretrained(
self,
save_directory: Union[str, Path],
push_to_hub: Optional[bool] = False,
**kwargs,
) -> "ORTDiffusionPipeline":
if task is None:
task = cls._auto_model_to_task(cls.auto_model_class)
# we continue passing the model_save_dir from here on to avoid it being cleaned up
# might be better to use a persistent temporary directory such as the one implemented in
# https://gist.github.com/twolfson/2929dc1163b0a76d2c2b66d51f9bc808
model_save_dir = TemporaryDirectory()
model_save_path = Path(model_save_dir.name)
main_export(
model_id,
output=model_save_path,
do_validation=False,
no_post_process=True,
token=token,
revision=revision,
cache_dir=cache_dir,
subfolder=subfolder,
force_download=force_download,
local_files_only=local_files_only,
trust_remote_code=trust_remote_code,
library_name="diffusers",
task=task,
)
return cls._from_pretrained(
model_save_path,
config=config,
provider=provider,
provider_options=provider_options,
session_options=session_options,
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
**kwargs,
)
def to(self, device: Union[torch.device, str, int]):
):
"""
Changes the ONNX Runtime provider according to the device.
Saves a model and its configuration file to a directory, so that it can be re-loaded using the
[`from_pretrained`] class method.
Args:
device (`torch.device` or `str` or `int`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run
the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too.
Returns:
`ORTModel`: the model placed on the requested device.
save_directory (`Union[str, os.PathLike]`):
Directory to which to save. Will be created if it doesn't exist.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
**kwargs:
Additional keyword arguments passed along to [`~huggingface_hub.create_repo`] and
[`~huggingface_hub.HfApi.upload_folder`] if `push_to_hub` is set to `True`.
"""
device, provider_options = parse_device(device)
provider = get_provider_for_device(device)
validate_provider_availability(provider)
model_save_path = Path(save_directory)
model_save_path.mkdir(parents=True, exist_ok=True)
if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self
if push_to_hub:
token = kwargs.pop("token", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
self.save_config(model_save_path)
self.scheduler.save_pretrained(model_save_path / "scheduler")
if self.unet is not None:
self.unet.session.set_providers([provider], provider_options=[provider_options])
self.unet.save_pretrained(model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER)
if self.transformer is not None:
self.transformer.session.set_providers([provider], provider_options=[provider_options])
self.transformer.save_pretrained(model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER)
if self.vae_encoder is not None:
self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])
self.vae_encoder.save_pretrained(model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER)
if self.vae_decoder is not None:
self.vae_decoder.save_pretrained(model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER)
if self.text_encoder is not None:
self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
self.text_encoder.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER)
if self.text_encoder_2 is not None:
self.text_encoder_2.session.set_providers([provider], provider_options=[provider_options])
self.text_encoder_2.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER)
if self.text_encoder_3 is not None:
self.text_encoder_3.session.set_providers([provider], provider_options=[provider_options])
self.text_encoder_3.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER)
self.providers = (
self.unet.session.get_providers() if self.unet is not None else self.transformer.session.get_providers()
)
self._device = device
if self.image_encoder is not None:
self.image_encoder.save_pretrained(model_save_path / "image_encoder")
if self.safety_checker is not None:
self.safety_checker.save_pretrained(model_save_path / "safety_checker")
return self
if self.tokenizer is not None:
self.tokenizer.save_pretrained(model_save_path / "tokenizer")
if self.tokenizer_2 is not None:
self.tokenizer_2.save_pretrained(model_save_path / "tokenizer_2")
if self.tokenizer_3 is not None:
self.tokenizer_3.save_pretrained(model_save_path / "tokenizer_3")
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(model_save_path / "feature_extractor")
@classmethod
def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs):
return cls.load_config(config_name_or_path, **kwargs)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))
self._upload_folder(
save_directory,
repo_id,
token=token,
create_pr=create_pr,
commit_message=commit_message,
)
def _save_config(self, save_directory: Union[str, Path]):
model_dir = (
self.model_save_dir
if not isinstance(self.model_save_dir, TemporaryDirectory)
else self.model_save_dir.name
)
save_dir = Path(save_directory)
original_config = Path(model_dir) / self.config_name
if original_config.exists():
if not save_dir.exists():
save_dir.mkdir(parents=True)
shutil.copy(original_config, save_dir)
else:
self.save_config(save_directory)
@property
def components(self) -> Dict[str, Any]:
components = {
"vae": self.vae,
"unet": self.unet,
"transformer": self.transformer,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
"text_encoder_3": self.text_encoder_3,
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
}
components = {k: v for k, v in components.items() if v is not None}
return components
def __call__(self, *args, **kwargs):
# we do this to keep numpy random states support for now
# TODO: deprecate and add warnings when a random state is passed
args = list(args)
for i in range(len(args)):
args[i] = np_to_pt_generators(args[i], self.device)
new_args = np_to_pt_generators(args[i], self.device)
if args[i] is not new_args:
logger.warning(
"Converting numpy random state to torch generator is deprecated. "
"Please pass a torch generator directly to the pipeline."
)
for k, v in kwargs.items():
kwargs[k] = np_to_pt_generators(v, self.device)
for key, value in kwargs.items():
new_value = np_to_pt_generators(value, self.device)
if value is not new_value:
logger.warning(
"Converting numpy random state to torch generator is deprecated. "
"Please pass a torch generator directly to the pipeline."
)
kwargs[key] = new_value

@@ -499,18 +531,14 @@ return self.auto_model_class.__call__(self, *args, **kwargs)

class ORTPipelinePart(ConfigMixin):
class ORTModelMixin(ORTSessionMixin, ConfigMixin):
config_name: str = CONFIG_NAME
def __init__(self, session: ort.InferenceSession, parent_pipeline: ORTDiffusionPipeline):
self.session = session
self.parent_pipeline = parent_pipeline
def __init__(
self,
session: "InferenceSession",
parent: "ORTDiffusionPipeline",
use_io_binding: Optional[bool] = None,
):
self.initialize_ort_attributes(session, use_io_binding=use_io_binding)
self.parent = parent
self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.input_dtypes = {input_key.name: input_key.type for input_key in self.session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()}
self.input_shapes = {input_key.name: input_key.shape for input_key in self.session.get_inputs()}
self.output_shapes = {output_key.name: output_key.shape for output_key in self.session.get_outputs()}
config_file_path = Path(session._model_path).parent / self.config_name

@@ -523,81 +551,18 @@ if not config_file_path.is_file():

@property
def device(self):
return self.parent_pipeline.device
def save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the ONNX model and its configuration file to a directory, so that it can be re-loaded using the
[`from_pretrained`] class method.
@property
def dtype(self):
for dtype in self.input_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype
Args:
save_directory (`Union[str, os.PathLike]`):
Directory to which to save. Will be created if it doesn't exist.
"""
# save onnx model and external data
self.save_session(save_directory)
# save model configuration
self.save_config(save_directory)
for dtype in self.output_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype
return None
def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
for arg in args:
if isinstance(arg, torch.device):
device = arg
elif isinstance(arg, (int, str)):
device = torch.device(arg)
elif isinstance(arg, torch.dtype):
dtype = arg
if device is not None and device != self.device:
raise ValueError(
"Cannot change the device of a pipeline part without changing the device of the parent pipeline. "
"Please use the `to` method of the parent pipeline to change the device."
)
if dtype is not None and dtype != self.dtype:
raise NotImplementedError(
f"Cannot change the dtype of the pipeline from {self.dtype} to {dtype}. "
f"Please export the pipeline with the desired dtype."
)
def prepare_onnx_inputs(self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]) -> Dict[str, np.ndarray]:
onnx_inputs = {}
# converts pytorch inputs into numpy inputs for onnx
for input_name in self.input_names.keys():
onnx_inputs[input_name] = inputs.pop(input_name)
if use_torch:
onnx_inputs[input_name] = onnx_inputs[input_name].numpy(force=True)
if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]:
onnx_inputs[input_name] = onnx_inputs[input_name].astype(
TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name])
)
return onnx_inputs
def prepare_onnx_outputs(
self, use_torch: bool, *onnx_outputs: np.ndarray
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
model_outputs = {}
# converts onnxruntime outputs into tensor for standard outputs
for output_name, idx in self.output_names.items():
model_outputs[output_name] = onnx_outputs[idx]
if use_torch:
model_outputs[output_name] = torch.from_numpy(model_outputs[output_name]).to(self.device)
return model_outputs
@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
class ORTModelUnet(ORTPipelinePart):
class ORTUnet(ORTModelMixin):
def __init__(self, *args, **kwargs):

@@ -645,6 +610,30 @@ super().__init__(*args, **kwargs)

onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs)
if self.use_io_binding:
known_output_shapes = {"out_sample": sample.shape}
known_output_buffers = None
if "LatentConsistencyModel" not in self.parent.__class__.__name__:
known_output_buffers = {"out_sample": sample}
output_shapes, output_buffers = self._prepare_io_binding(
model_inputs,
known_output_shapes=known_output_shapes,
known_output_buffers=known_output_buffers,
)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
model_outputs["sample"] = model_outputs.pop("out_sample")
if not return_dict:

@@ -656,3 +645,3 @@ return tuple(model_outputs.values())

class ORTModelTransformer(ORTPipelinePart):
class ORTTransformer(ORTModelMixin):
def forward(

@@ -683,6 +672,30 @@ self,

onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs)
if self.use_io_binding:
known_output_shapes = {"out_hidden_states": hidden_states.shape}
known_output_buffers = None
if "Flux" not in self.parent.__class__.__name__:
known_output_buffers = {"out_hidden_states": hidden_states}
output_shapes, output_buffers = self._prepare_io_binding(
model_inputs,
known_output_shapes=known_output_shapes,
known_output_buffers=known_output_buffers,
)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
model_outputs["hidden_states"] = model_outputs.pop("out_hidden_states")
if not return_dict:

@@ -694,3 +707,3 @@ return tuple(model_outputs.values())

class ORTModelTextEncoder(ORTPipelinePart):
class ORTTextEncoder(ORTModelMixin):
def forward(

@@ -705,8 +718,22 @@ self,

model_inputs = {"input_ids": input_ids}
model_inputs = {
"input_ids": input_ids,
}
onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs)
if self.use_io_binding:
output_shapes, output_buffers = self._prepare_io_binding(model_inputs)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
if output_hidden_states:

@@ -729,3 +756,3 @@ model_outputs["hidden_states"] = []

class ORTModelVaeEncoder(ORTPipelinePart):
class ORTVaeEncoder(ORTModelMixin):
def __init__(self, *args, **kwargs):

@@ -750,8 +777,22 @@ super().__init__(*args, **kwargs)

model_inputs = {"sample": sample}
model_inputs = {
"sample": sample,
}
onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs)
if self.use_io_binding:
output_shapes, output_buffers = self._prepare_io_binding(model_inputs)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
if "latent_sample" in model_outputs:

@@ -771,3 +812,3 @@ model_outputs["latents"] = model_outputs.pop("latent_sample")

class ORTModelVaeDecoder(ORTPipelinePart):
class ORTVaeDecoder(ORTModelMixin):
def __init__(self, *args, **kwargs):

@@ -792,8 +833,22 @@ super().__init__(*args, **kwargs)

model_inputs = {"latent_sample": latent_sample}
model_inputs = {
"latent_sample": latent_sample,
}
onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs)
if self.use_io_binding:
output_shapes, output_buffers = self._prepare_io_binding(model_inputs)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
if "latent_sample" in model_outputs:

@@ -808,19 +863,9 @@ model_outputs["latents"] = model_outputs.pop("latent_sample")

class ORTWrapperVae(ORTPipelinePart):
def __init__(self, encoder: ORTModelVaeEncoder, decoder: ORTModelVaeDecoder):
class ORTVae(ORTParentMixin):
def __init__(self, encoder: Optional[ORTVaeEncoder] = None, decoder: Optional[ORTVaeDecoder] = None):
self.encoder = encoder
self.decoder = decoder
self.encoder = encoder
@property
def config(self):
return self.decoder.config
self.initialize_ort_attributes(parts=list(filter(None, {self.encoder, self.decoder})))
@property
def dtype(self):
return self.decoder.dtype
@property
def device(self):
return self.decoder.device
def decode(self, *args, **kwargs):

@@ -832,9 +877,14 @@ return self.decoder(*args, **kwargs)

def to(self, *args, **kwargs):
self.decoder.to(*args, **kwargs)
if self.encoder is not None:
self.encoder.to(*args, **kwargs)
@property
def config(self):
return self.decoder.config
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
ORT_PIPELINE_DOCSTRING = r"""
This Pipeline inherits from [`ORTDiffusionPipeline`] and is used to run inference with the ONNX Runtime.
The pipeline can be loaded from a pretrained pipeline using the [`ORTDiffusionPipeline.from_pretrained`] method.
"""
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionPipeline(ORTDiffusionPipeline, StableDiffusionPipeline):

@@ -845,8 +895,8 @@ """

task = "text-to-image"
main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = StableDiffusionPipeline
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionImg2ImgPipeline):

@@ -857,8 +907,8 @@ """

task = "image-to-image"
main_input_name = "image"
export_feature = "image-to-image"
auto_model_class = StableDiffusionImg2ImgPipeline
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionInpaintPipeline(ORTDiffusionPipeline, StableDiffusionInpaintPipeline):

@@ -869,8 +919,8 @@ """

task = "inpainting"
main_input_name = "prompt"
export_feature = "inpainting"
auto_model_class = StableDiffusionInpaintPipeline
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionXLPipeline(ORTDiffusionPipeline, StableDiffusionXLPipeline):

@@ -881,4 +931,4 @@ """

task = "text-to-image"
main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = StableDiffusionXLPipeline

@@ -900,3 +950,3 @@

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionXLImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionXLImg2ImgPipeline):

@@ -907,4 +957,4 @@ """

task = "image-to-image"
main_input_name = "prompt"
export_feature = "image-to-image"
auto_model_class = StableDiffusionXLImg2ImgPipeline

@@ -940,3 +990,3 @@

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionXLInpaintPipeline(ORTDiffusionPipeline, StableDiffusionXLInpaintPipeline):

@@ -948,3 +998,3 @@ """

main_input_name = "image"
export_feature = "inpainting"
task = "inpainting"
auto_model_class = StableDiffusionXLInpaintPipeline

@@ -980,3 +1030,3 @@

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTLatentConsistencyModelPipeline(ORTDiffusionPipeline, LatentConsistencyModelPipeline):

@@ -987,8 +1037,8 @@ """

task = "text-to-image"
main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = LatentConsistencyModelPipeline
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsistencyModelImg2ImgPipeline):

@@ -999,4 +1049,4 @@ """

task = "image-to-image"
main_input_name = "image"
export_feature = "image-to-image"
auto_model_class = LatentConsistencyModelImg2ImgPipeline

@@ -1018,3 +1068,3 @@

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusion3Pipeline(ORTDiffusionPipeline, StableDiffusion3Pipeline):

@@ -1025,7 +1075,7 @@ """

task = "text-to-image"
main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = StableDiffusion3Pipeline
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusion3Img2ImgPipeline(ORTDiffusionPipeline, StableDiffusion3Img2ImgPipeline):

@@ -1036,4 +1086,4 @@ """

task = "image-to-image"
main_input_name = "image"
export_feature = "image-to-image"
auto_model_class = StableDiffusion3Img2ImgPipeline

@@ -1053,3 +1103,3 @@

@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3InpaintPipeline):

@@ -1060,7 +1110,7 @@ """

task = "inpainting"
main_input_name = "prompt"
export_feature = "inpainting"
auto_model_class = StableDiffusion3InpaintPipeline
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTFluxPipeline(ORTDiffusionPipeline, FluxPipeline):

@@ -1071,4 +1121,4 @@ """

task = "text-to-image"
main_input_name = "prompt"
export_feature = "text-to-image"
auto_model_class = FluxPipeline

@@ -1075,0 +1125,0 @@

@@ -95,8 +95,8 @@ # Copyright 2021 The HuggingFace Team. All rights reserved.

onnx_model_path += [
model_or_path.encoder_model_path,
model_or_path.decoder_model_path,
model_or_path.encoder.path,
model_or_path.decoder.path,
]
# Add the decoder with past key/values if present
if model_or_path.use_cache:
onnx_model_path.append(model_or_path.decoder_with_past_model_path)
if model_or_path.decoder_with_past is not None:
onnx_model_path.append(model_or_path.decoder_with_past.path)
elif isinstance(model_or_path, ORTModelForCausalLM) and model_or_path.use_merged:

@@ -108,3 +108,3 @@ raise NotImplementedError(

else:
onnx_model_path.append(model_or_path.model_path)
onnx_model_path.append(model_or_path.path)
config = model_or_path.config

@@ -111,0 +111,0 @@ elif os.path.isdir(model_or_path):

@@ -21,3 +21,3 @@ # Copyright 2021 The HuggingFace Team. All rights reserved.

from inspect import signature
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

@@ -34,2 +34,3 @@ import numpy as np

import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import TypeHelper

@@ -54,31 +55,3 @@ from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss

# Mapping from ONNX Runtime element-type strings (as reported in the session's
# input/output metadata, e.g. "tensor(float)") to the corresponding NumPy dtypes.
_ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}
def _is_gpu_available():
    """
    Checks if a gpu is available.

    Returns `True` only when ONNX Runtime exposes a CUDA or ROCm execution
    provider AND PyTorch reports a usable CUDA device, `False` otherwise.
    """
    available_providers = ort.get_available_providers()
    # Return the condition directly instead of an if/else that returns True/False.
    return (
        "CUDAExecutionProvider" in available_providers or "ROCMExecutionProvider" in available_providers
    ) and torch.cuda.is_available()
def is_onnxruntime_training_available():

@@ -141,2 +114,3 @@ """

"mistral": "gpt2",
"modernbert": "bert",
"mpnet": "bert",

@@ -208,3 +182,3 @@ "mt5": "bart",

if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider", "ROCMExecutionProvider"]:
return torch.device(f"cuda:{provider_options['device_id']}")
return torch.device(f"cuda:{provider_options.get('device_id', 0)}")
else:

@@ -298,2 +272,36 @@ return torch.device("cpu")

def prepare_providers_and_provider_options(
    provider: str = "CPUExecutionProvider",
    providers: Optional[Sequence[str]] = None,
    provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None,
):
    """
    Prepare the providers and provider options for ONNX Runtime.

    Args:
        provider (`str`):
            The provider to use. Ignored when `providers` is given.
        providers (`Sequence[str]`, `optional`):
            The list of providers to use. If `None`, falls back to `[provider]`.
        provider_options (`Union[Sequence[Dict[str, Any]], Dict[str, Any]]`, `optional`):
            The options to use for the providers. If `None`, the default options will be used.

    Returns:
        A `(providers, provider_options)` pair of equal lengths, suitable for
        constructing an `onnxruntime.InferenceSession`.

    Raises:
        ValueError: if `provider_options` is a sequence whose length differs from `providers`.
    """
    if providers is None:
        providers = [provider]

    # Fail early if any requested provider is not available in this ORT build.
    # (Named differently from the `provider` parameter to avoid shadowing it.)
    for requested_provider in providers:
        validate_provider_availability(requested_provider)

    if provider_options is None:
        # One *distinct* empty dict per provider: `[{}] * n` would alias a single
        # shared dict, so mutating one provider's options would affect all of them.
        provider_options = [{} for _ in providers]
    elif isinstance(provider_options, dict):
        # A single dict applies to the first provider; the rest each get their own empty dict.
        provider_options = [provider_options] + [{} for _ in providers[1:]]
    elif len(provider_options) != len(providers):
        raise ValueError(
            f"When passing a list of provider options, it should be the same length as the list of providers. "
            f"Got {len(provider_options)} provider options for {len(providers)} providers."
        )

    return providers, provider_options
def check_io_binding(providers: List[str], use_io_binding: Optional[bool] = None) -> bool:

@@ -443,1 +451,21 @@ """

self.stride = stride
def get_dtype_from_session(session: ort.InferenceSession) -> torch.dtype:
    """
    Returns the `torch.dtype` associated with the ONNX Runtime session.

    The dtype is the first floating point type found among the session's
    inputs, then its outputs; `torch.float32` is returned when neither
    declares a floating point tensor.
    """
    # Inputs take precedence over outputs, so they are scanned first.
    io_nodes = list(session.get_inputs()) + list(session.get_outputs())
    for node in io_nodes:
        candidate_dtype = TypeHelper.ort_type_to_torch_type(node.type)
        if candidate_dtype.is_floating_point:
            return candidate_dtype
    return torch.float32

@@ -22,3 +22,6 @@ # coding=utf-8

AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutomaticSpeechRecognitionPipeline,
AutoTokenizer,
FeatureExtractionPipeline,

@@ -45,5 +48,12 @@ FillMaskPipeline,

from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
from transformers.onnx.utils import get_preprocessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.pipelines import (
FEATURE_EXTRACTOR_MAPPING,
IMAGE_PROCESSOR_MAPPING,
TOKENIZER_MAPPING,
check_task,
get_default_model_and_revision,
infer_framework_load_model,
)
from transformers.pipelines import SUPPORTED_TASKS as TRANSFORMERS_SUPPORTED_TASKS
from transformers.pipelines import infer_framework_load_model

@@ -92,3 +102,3 @@ from ..utils import is_onnxruntime_available, is_transformers_version

"impl": ImageSegmentationPipeline,
"class": (ORTModelForSemanticSegmentation,) if is_onnxruntime_available() else (),
"class": (ORTModelForSemanticSegmentation,),
"default": "nvidia/segformer-b0-finetuned-ade-512-512",

@@ -181,2 +191,4 @@ "type": "image",

load_feature_extractor=None,
image_processor=None,
load_image_processor=None,
SUPPORTED_TASKS=None,

@@ -201,5 +213,3 @@ subfolder: str = "",

if model is None:
model_id = SUPPORTED_TASKS[targeted_task]["default"]
elif isinstance(model, str):
if isinstance(model, str):
model_id = model

@@ -227,3 +237,3 @@ else:

return model, model_id, tokenizer, feature_extractor
return model, model_id, tokenizer, feature_extractor, image_processor

@@ -238,2 +248,4 @@

load_feature_extractor,
image_processor,
load_image_processor,
SUPPORTED_TASKS,

@@ -250,6 +262,3 @@ subfolder: str = "",

if model is None:
model_id = SUPPORTED_TASKS[targeted_task]["default"]
model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(model_id, export=True)
elif isinstance(model, str):
if isinstance(model, str):
model_id = model

@@ -279,2 +288,13 @@ model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(

)
if image_processor is None and load_image_processor:
for preprocessor in model.preprocessors:
if isinstance(preprocessor, BaseImageProcessor):
image_processor = preprocessor
break
if image_processor is None:
raise ValueError(
"Could not automatically find a feature extractor for the ORTModel, you must pass a "
"image_processor explictly"
)
model_id = None

@@ -286,3 +306,3 @@ else:

)
return model, model_id, tokenizer, feature_extractor
return model, model_id, tokenizer, feature_extractor, image_processor

@@ -301,2 +321,3 @@

feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
image_processor: Optional[Union[str, BaseImageProcessor]] = None,
use_fast: bool = True,

@@ -323,3 +344,12 @@ token: Optional[Union[str, bool]] = None,

# copied from transformers.pipelines.__init__.py
supported_tasks = ORT_SUPPORTED_TASKS if accelerator == "ort" else TRANSFORMERS_SUPPORTED_TASKS
if model is None:
if accelerator != "ort":
_, target_task, task_options = check_task(task)
model, default_revision = get_default_model_and_revision(target_task, "pt", task_options)
revision = revision or default_revision
else:
model = supported_tasks[targeted_task]["default"]
hub_kwargs = {

@@ -337,9 +367,9 @@ "revision": revision,

supported_tasks = ORT_SUPPORTED_TASKS if accelerator == "ort" else TRANSFORMERS_SUPPORTED_TASKS
no_feature_extractor_tasks = set()
no_tokenizer_tasks = set()
no_image_processor_tasks = set()
for _task, values in supported_tasks.items():
if values["type"] == "text":
no_feature_extractor_tasks.add(_task)
no_image_processor_tasks.add(_task)
elif values["type"] in {"image", "video"}:

@@ -349,5 +379,11 @@ no_tokenizer_tasks.add(_task)

no_tokenizer_tasks.add(_task)
no_image_processor_tasks.add(_task)
elif values["type"] not in ["multimodal", "audio", "video"]:
raise ValueError(f"SUPPORTED_TASK {_task} contains invalid type {values['type']}")
model_config = config or model.config
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
# copied from transformers.pipelines.__init__.py l.609

@@ -360,11 +396,13 @@ if targeted_task in no_tokenizer_tasks:

load_tokenizer = False
else:
load_tokenizer = True
if targeted_task in no_feature_extractor_tasks:
load_feature_extractor = False
else:
load_feature_extractor = True
model, model_id, tokenizer, feature_extractor = MAPPING_LOADING_FUNC[accelerator](
if targeted_task in no_image_processor_tasks:
load_image_processor = False
if load_image_processor and load_feature_extractor:
load_feature_extractor = False
model, model_id, tokenizer, feature_extractor, image_processor = MAPPING_LOADING_FUNC[accelerator](
model,

@@ -376,2 +414,4 @@ targeted_task,

load_feature_extractor,
image_processor,
load_image_processor,
SUPPORTED_TASKS=supported_tasks,

@@ -385,6 +425,9 @@ config=config,

use_fast = kwargs.get(use_fast, "True")
if tokenizer is None and load_tokenizer:
tokenizer = get_preprocessor(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, **hub_kwargs)
if feature_extractor is None and load_feature_extractor:
feature_extractor = get_preprocessor(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, use_fast=use_fast, **hub_kwargs)
if image_processor is None and load_image_processor:
image_processor = AutoImageProcessor.from_pretrained(model_id, **hub_kwargs)

@@ -396,4 +439,5 @@ return transformers_pipeline(

feature_extractor=feature_extractor,
image_processor=image_processor,
use_fast=use_fast,
**kwargs,
)

@@ -18,2 +18,4 @@ # Copyright 2021 The HuggingFace Team. All rights reserved.

CONFIG_NAME,
DIFFUSION_MODEL_CONFIG_FILE_NAME,
DIFFUSION_MODEL_ONNX_FILE_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,

@@ -26,2 +28,3 @@ DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER,

DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
DIFFUSION_PIPELINE_CONFIG_FILE_NAME,
ONNX_WEIGHTS_NAME,

@@ -46,4 +49,6 @@ )

is_onnxruntime_available,
is_onnxslim_available,
is_pydantic_available,
is_sentence_transformers_available,
is_tensorrt_available,
is_tf_available,

@@ -50,0 +55,0 @@ is_timm_available,

@@ -17,2 +17,4 @@ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.

CONFIG_NAME = "config.json"
ONNX_WEIGHTS_NAME = "model.onnx"
DIFFUSION_MODEL_UNET_SUBFOLDER = "unet"

@@ -25,2 +27,4 @@ DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER = "transformer"

DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER = "text_encoder_3"
ONNX_WEIGHTS_NAME = "model.onnx"
DIFFUSION_PIPELINE_CONFIG_FILE_NAME = "model_index.json"
DIFFUSION_MODEL_CONFIG_FILE_NAME = "config.json"
DIFFUSION_MODEL_ONNX_FILE_NAME = "model.onnx"

@@ -92,2 +92,3 @@ # Copyright 2022 The HuggingFace Team. All rights reserved.

_datasets_available = _is_package_available("datasets")
_tensorrt_available = _is_package_available("tensorrt")
_diffusers_available, _diffusers_version = _is_package_available("diffusers", return_version=True)

@@ -138,2 +139,3 @@ _transformers_available, _transformers_version = _is_package_available("transformers", return_version=True)

)
_onnxslim_available = _is_package_available("onnxslim")

@@ -242,2 +244,6 @@ if _tf_available and version.parse(_tf_version) < version.parse("2"):

def is_tensorrt_available():
return _tensorrt_available
def is_torch_available():

@@ -273,2 +279,6 @@ return _torch_available

def is_onnxslim_available():
return _onnxslim_available
@contextmanager

@@ -275,0 +285,0 @@ def check_if_pytorch_greater(target_version: str, message: str):

@@ -240,3 +240,3 @@ # coding=utf-8

"blenderbot-small": BartLikeNormalizedTextConfig,
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head"),
"falcon": NormalizedTextConfig,

@@ -261,2 +261,3 @@ "camembert": NormalizedTextConfig,

"imagegpt": GPT2LikeNormalizedTextConfig,
"internlm2": NormalizedTextConfigWithGQA,
"llama": NormalizedTextConfigWithGQA,

@@ -269,2 +270,3 @@ "longt5": T5LikeNormalizedTextConfig,

"mixtral": NormalizedTextConfigWithGQA,
"modernbert": NormalizedTextConfig,
"mpnet": NormalizedTextConfig,

@@ -275,2 +277,4 @@ "mpt": MPTNormalizedTextConfig,

"nystromformer": NormalizedTextConfig,
"olmo": NormalizedTextConfig,
"olmo2": NormalizedTextConfig,
"opt": NormalizedTextConfig,

@@ -281,3 +285,2 @@ "pegasus": BartLikeNormalizedTextConfig,

"phi3": NormalizedTextConfigWithGQA,
"phi3small": NormalizedTextConfigWithGQA,
"poolformer": NormalizedVisionConfig,

@@ -298,2 +301,4 @@ "regnet": NormalizedVisionConfig,

"qwen2": NormalizedTextConfig,
"qwen3": NormalizedTextConfig,
"qwen3-moe": NormalizedTextConfig,
"granite": NormalizedTextConfigWithGQA,

@@ -300,0 +305,0 @@ }

@@ -15,2 +15,2 @@ # Copyright 2021 The HuggingFace Team. All rights reserved.

__version__ = "1.25.3"
__version__ = "1.26.0"
Metadata-Version: 2.1
Name: optimum
Version: 1.25.3
Version: 1.26.0
Summary: Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality.

@@ -34,3 +34,3 @@ Home-page: https://github.com/huggingface/optimum

Requires-Dist: onnxruntime>=1.11.0; extra == "onnxruntime"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime"
Provides-Extra: onnxruntime-gpu

@@ -41,3 +41,3 @@ Requires-Dist: onnx; extra == "onnxruntime-gpu"

Requires-Dist: onnxruntime-gpu>=1.11.0; extra == "onnxruntime-gpu"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-gpu"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-gpu"
Provides-Extra: onnxruntime-training

@@ -49,3 +49,3 @@ Requires-Dist: evaluate; extra == "onnxruntime-training"

Requires-Dist: protobuf>=3.20.1; extra == "onnxruntime-training"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "onnxruntime-training"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "onnxruntime-training"
Requires-Dist: onnxruntime-training>=1.11.0; extra == "onnxruntime-training"

@@ -57,3 +57,3 @@ Provides-Extra: exporters

Requires-Dist: protobuf>=3.20.1; extra == "exporters"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters"
Provides-Extra: exporters-gpu

@@ -64,3 +64,3 @@ Requires-Dist: onnx; extra == "exporters-gpu"

Requires-Dist: protobuf>=3.20.1; extra == "exporters-gpu"
Requires-Dist: transformers<4.52.0,>=4.36; extra == "exporters-gpu"
Requires-Dist: transformers<4.53.0,>=4.36; extra == "exporters-gpu"
Provides-Extra: exporters-tf

@@ -114,2 +114,3 @@ Requires-Dist: onnx; extra == "exporters-tf"

Requires-Dist: hf_xet; extra == "dev"
Requires-Dist: onnxslim>=0.1.53; extra == "dev"
Requires-Dist: black~=23.1; extra == "dev"

@@ -133,2 +134,3 @@ Requires-Dist: ruff==0.1.5; extra == "dev"

Requires-Dist: hf_xet; extra == "tests"
Requires-Dist: onnxslim>=0.1.53; extra == "tests"
Provides-Extra: quality

@@ -135,0 +137,0 @@ Requires-Dist: black~=23.1; extra == "quality"

@@ -41,2 +41,3 @@ import re

"hf_xet",
"onnxslim>=0.1.53",
]

@@ -54,3 +55,3 @@

"onnxruntime>=1.11.0",
"transformers>=4.36,<4.52.0",
"transformers>=4.36,<4.53.0",
],

@@ -62,3 +63,3 @@ "onnxruntime-gpu": [

"onnxruntime-gpu>=1.11.0",
"transformers>=4.36,<4.52.0",
"transformers>=4.36,<4.53.0",
],

@@ -71,3 +72,3 @@ "onnxruntime-training": [

"protobuf>=3.20.1",
"transformers>=4.36,<4.52.0",
"transformers>=4.36,<4.53.0",
"onnxruntime-training>=1.11.0",

@@ -80,3 +81,3 @@ ],

"protobuf>=3.20.1",
"transformers>=4.36,<4.52.0",
"transformers>=4.36,<4.53.0",
],

@@ -88,3 +89,3 @@ "exporters-gpu": [

"protobuf>=3.20.1",
"transformers>=4.36,<4.52.0",
"transformers>=4.36,<4.53.0",
],

@@ -91,0 +92,0 @@ "exporters-tf": [

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from onnxruntime.transformers.onnx_model import OnnxModel
def find_fully_connected_layers_nodes(model: OnnxModel) -> List[List[str]]:
    """Return the (Add, MatMul-parent) node pairs that together form fully connected layers."""
    pairs = []
    for add_node in model.get_nodes_by_op_type("Add"):
        matmul_parent = model.match_parent(add_node, "MatMul")
        if matmul_parent is not None:
            pairs.append((add_node, matmul_parent))
    return pairs
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .io_binding_helper import IOBindingHelper, TypeHelper
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import traceback
from typing import TYPE_CHECKING
import numpy as np
import torch
import onnxruntime as ort
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.transformers.io_binding_helper import TypeHelper as ORTTypeHelper
from ..utils import is_cupy_available, is_onnxruntime_training_available
if TYPE_CHECKING:
from ..modeling_ort import ORTModel
if is_cupy_available():
import cupy as cp
# Adapted from https://github.com/microsoft/onnxruntime/blob/93e0a151177ad8222c2c95f814342bfa27f0a64d/onnxruntime/python/tools/transformers/io_binding_helper.py#L12
class TypeHelper(ORTTypeHelper):
    """
    Gets data type information of the ONNX Runtime inference session and provides the mapping from
    `OrtValue` data types to the data types of other frameworks (NumPy, PyTorch, etc).
    """

    @staticmethod
    def ort_type_to_numpy_type(ort_type: str):
        """Map an ORT element-type string (e.g. "tensor(float)") to a NumPy dtype."""
        ort_type_to_numpy_type_map = {
            "tensor(int64)": np.int64,
            "tensor(int32)": np.int32,
            "tensor(int8)": np.int8,
            "tensor(float)": np.float32,
            "tensor(float16)": np.float16,
            "tensor(bool)": bool,
        }
        # Guard clause: reject unknown type strings up front.
        if ort_type not in ort_type_to_numpy_type_map:
            raise ValueError(
                f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_numpy_type_map.keys()}"
            )
        return ort_type_to_numpy_type_map[ort_type]

    @staticmethod
    def ort_type_to_torch_type(ort_type: str):
        """Map an ORT element-type string (e.g. "tensor(float)") to a torch dtype."""
        ort_type_to_torch_type_map = {
            "tensor(int64)": torch.int64,
            "tensor(int32)": torch.int32,
            "tensor(int8)": torch.int8,
            "tensor(float)": torch.float32,
            "tensor(float16)": torch.float16,
            "tensor(bool)": torch.bool,
        }
        # Guard clause: reject unknown type strings up front.
        if ort_type not in ort_type_to_torch_type_map:
            raise ValueError(
                f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_torch_type_map.keys()}"
            )
        return ort_type_to_torch_type_map[ort_type]
# Adapted from https://github.com/microsoft/onnxruntime/blob/1ab11a111ce0717bfbfaca964d04a017cb9b1752/onnxruntime/python/tools/transformers/io_binding_helper.py#L97
class IOBindingHelper:
    """
    A helper class to enable `ORTModel` instances to prepare IO binding with dynamic shaped outputs for an inference session and transfer
    tensors from ONNX Runtime to other frameworks on device. It helps reduce memory copy between the host and device.
    """
    def __init__(self, model: ort.InferenceSession, device, **kwargs):
        # `model` is the live inference session; `device` is the target device
        # the caller wants tensors on. Extra kwargs are accepted but unused.
        self.model = model
        self.device = device
        # Create {name:idx} dict for model inputs and outputs
        self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(model.get_inputs())}
        self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())}
        self.model_input_names = list(self.model_inputs.keys())
        self.model_output_names = list(self.model_outputs.keys())
    @staticmethod
    def to_pytorch(ort_value: OrtValue) -> torch.Tensor:
        """
        Converts tensors held by OrtValues in ONNX runtime memory buffer to torch tensor.
        """
        # Preferred path: dlpack (zero-copy) when onnxruntime-training is installed;
        # otherwise try CuPy, and finally fall back to a host round-trip via NumPy.
        if is_onnxruntime_training_available():
            return IOBindingHelper.to_pytorch_via_dlpack(ort_value)
        else:
            try:
                return IOBindingHelper.to_pytorch_via_cupy(ort_value)
            except Exception:
                # Best-effort fallback: log the failure and copy through host memory.
                logging.error(traceback.format_exc())
                logging.info("Unable to access output memory in CUDA, will offload to CPU")
                return IOBindingHelper.to_pytorch_via_numpy(ort_value)
    @staticmethod
    def to_pytorch_via_numpy(ort_value: OrtValue) -> torch.Tensor:
        """Copy the OrtValue through host memory, then move the tensor back to its original device."""
        ort_device = ort_value.device_name().lower()
        return torch.tensor(ort_value.numpy()).to(ort_device)
    @staticmethod
    def to_pytorch_via_cupy(ort_value: OrtValue) -> torch.Tensor:
        """Wrap the OrtValue's CUDA memory with CuPy and hand it to torch via dlpack (no host copy)."""
        ort_device = ort_value.device_name().lower()
        if ort_device != "cuda":
            raise RuntimeError(f"Exchange tensors to PyTorch via CuPy only when device is CUDA, got: {ort_device}")
        ort_type = ort_value.data_type()
        numpy_type = TypeHelper.ort_type_to_numpy_type(ort_type)
        # Access CUDA memory via CuPy
        memory = cp.cuda.UnownedMemory(ort_value.data_ptr(), 0, None)
        memory_ptr = cp.cuda.MemoryPointer(memory, 0)
        cp_array = cp.ndarray(shape=ort_value.shape(), memptr=memory_ptr, dtype=numpy_type)
        torch_tensor = torch.from_dlpack(cp_array.toDlpack())
        # If is boolean, the dtype will be uint8 and need to be convert back to bool.
        if "bool" in ort_type:
            torch_tensor = torch_tensor.to(torch.bool)
        # Clone so the result owns its storage and does not alias ORT's buffer.
        torch_tensor = torch_tensor.clone()
        return torch_tensor
    @staticmethod
    # dlpack support is available for OrtValue only when `onnxruntime-training` is installed
    def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor:
        from torch._C import _from_dlpack
        torch_tensor = _from_dlpack(ort_value.to_dlpack())
        return torch_tensor
    @staticmethod
    def get_device_index(device):
        """Return the integer device index for a device string, torch.device, or int (0 for CPU/unspecified)."""
        if isinstance(device, str):
            # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
            device = torch.device(device)
        elif isinstance(device, int):
            return device
        return 0 if device.index is None else device.index
    @staticmethod
    def prepare_io_binding(ort_model: "ORTModel", **inputs) -> ort.IOBinding:
        """
        Returns an IOBinding object for an inference session. This method is for general purpose, if the inputs and outputs
        are determined, you can prepare data buffers directly to avoid tensor transfers across frameworks.
        """
        if not all(input_name in inputs.keys() for input_name in ort_model.input_names):
            # NOTE(review): `input_names.keys()` assumes `input_names` is a mapping;
            # if it is a list this f-string would raise — verify against ORTModel.
            raise ValueError(
                f"The ONNX model takes {ort_model.input_names.keys()} as inputs, but only {inputs.keys()} are given."
            )
        name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_model.model)
        # Bind inputs and outputs to onnxruntime session
        io_binding = ort_model.model.io_binding()
        # Bind inputs
        for input_name in ort_model.input_names:
            onnx_input = inputs.pop(input_name)
            # IO binding requires contiguous memory for the data pointer to be valid.
            onnx_input = onnx_input.contiguous()
            io_binding.bind_input(
                input_name,
                onnx_input.device.type,
                ort_model.device.index,
                name_to_np_type[input_name],
                list(onnx_input.size()),
                onnx_input.data_ptr(),
            )
        # Bind outputs
        # Outputs are bound by name only, letting ORT allocate dynamically shaped buffers on device.
        for name in ort_model.output_names:
            io_binding.bind_output(name, ort_model.device.type, device_id=ort_model.device.index)
        return io_binding
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
import torch
def bloom_convert_to_standard_cache(
    past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int
) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
    """
    Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
    num_heads, ...]))
    """
    fused_batch_size, head_dim, seq_length = past_key_value[0][0].shape
    num_heads = fused_batch_size // batch_size
    standardized = []
    for keys, values in past_key_value:
        # keys:   [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        # values: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
        standardized.append(
            (
                keys.view(batch_size, num_heads, head_dim, seq_length),
                values.view(batch_size, num_heads, seq_length, head_dim),
            )
        )
    return tuple(standardized)
def bloom_convert_to_bloom_cache(
    past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]]
) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
    """
    Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
    """
    batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
    fused_batch_size = batch_size * num_heads
    # keys:   [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
    # values: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
    return tuple(
        (
            keys.view(fused_batch_size, head_dim, seq_length),
            values.view(fused_batch_size, seq_length, head_dim),
        )
        for keys, values in past_key_value
    )

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display