Allow sending torch/tf/np dtypes directly for generating dummy inputs #2117

Status: Open. Wants to merge 7 commits into base: main.
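In short, the exporter configs now resolve the string dtype names ("int64", "fp32", ...) into concrete framework dtypes once, via `DTYPE_MAPPER`, and hand those resolved dtypes to the dummy input generators; the generators no longer receive a `framework` argument. A minimal sketch of the resulting flow, assuming the mapping returns the framework dtypes named in the comments (the exact tables live in `optimum.utils`):

```python
import torch

from optimum.utils import DTYPE_MAPPER  # the mapper imported throughout this diff

# String dtype names are resolved once at the config level...
int_dtype = DTYPE_MAPPER.pt("int64")   # expected: torch.int64
float_dtype = DTYPE_MAPPER.pt("fp32")  # expected: torch.float32

# ...and generators then take framework dtypes directly, so a torch/tf/np
# dtype can also be passed as-is (hypothetical generator call):
# dummy_input_gen.generate("input_ids", int_dtype=torch.int64, float_dtype=torch.float16)
```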
optimum/exporters/onnx/base.py (17 additions, 10 deletions)
```diff
@@ -39,6 +39,7 @@
 from ...onnx import merge_decoders
 from ...utils import (
     DEFAULT_DUMMY_SHAPES,
+    DTYPE_MAPPER,
     DummyInputGenerator,
     DummyLabelsGenerator,
     DummySeq2SeqPastKeyValuesGenerator,
@@ -467,8 +468,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
             input_was_inserted = False
             for dummy_input_gen in dummy_inputs_generators:
                 if dummy_input_gen.supports_input(input_name):
+                    int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype)
+                    float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype)
                     dummy_inputs[input_name] = dummy_input_gen.generate(
-                        input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
+                        input_name, int_dtype=int_dtype, float_dtype=float_dtype
                     )
                     input_was_inserted = True
                     break
@@ -679,6 +682,8 @@ def overwrite_shape_and_generate_input(
 
         # TODO: The check `self.task != "text-generation" and self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs.
        # This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models.
+        int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype)
+        float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype)
         if (
             self.use_past
             and self.use_past_in_inputs
@@ -689,14 +694,10 @@
             sequence_length = dummy_input_gen.sequence_length
             # Use a sequence length of 1 when the KV cache is already populated.
             dummy_input_gen.sequence_length = 1
-            dummy_input = dummy_input_gen.generate(
-                input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
-            )
+            dummy_input = dummy_input_gen.generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype)
             dummy_input_gen.sequence_length = sequence_length
         else:
-            dummy_input = dummy_input_gen.generate(
-                input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
-            )
+            dummy_input = dummy_input_gen.generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype)
 
         return dummy_input
 
@@ -740,8 +741,12 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) ->
         return flattened_output
 
     def generate_dummy_inputs_for_validation(
-        self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None
+        self,
+        reference_model_inputs: Dict[str, Any],
+        onnx_input_names: Optional[List[str]] = None,
     ) -> Dict[str, Any]:
+        int_dtype = DTYPE_MAPPER.pt(self.int_dtype)
+        float_dtype = DTYPE_MAPPER.pt(self.float_dtype)
         if self.is_merged is True and self.use_cache_branch is True:
             reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
         elif self.is_merged is True and self.use_cache_branch is False:
@@ -754,7 +759,7 @@
                 task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
             )
             reference_model_inputs["past_key_values"] = pkv_generator.generate(
-                "past_key_values", framework="pt", int_dtype=self.int_dtype, float_dtype=self.float_dtype
+                "past_key_values", int_dtype=int_dtype, float_dtype=float_dtype
             )
 
         return reference_model_inputs
@@ -1081,12 +1086,14 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             for cls_ in self.DUMMY_EXTRA_INPUT_GENERATOR_CLASSES
         ]
 
+        int_dtype = getattr(DTYPE_MAPPER, framework)(self.int_dtype)
+        float_dtype = getattr(DTYPE_MAPPER, framework)(self.float_dtype)
         for input_name in self._tasks_to_extra_inputs[self.task]:
             input_was_inserted = False
             for dummy_input_gen in dummy_inputs_generators:
                 if dummy_input_gen.supports_input(input_name):
                     dummy_inputs[input_name] = dummy_input_gen.generate(
-                        input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
+                        input_name, int_dtype=int_dtype, float_dtype=float_dtype
                     )
                     input_was_inserted = True
                     break
```
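The `getattr(DTYPE_MAPPER, framework)` pattern picks the mapper matching the requested backend ("pt", "tf", or "np"), while `generate_dummy_inputs_for_validation` hardcodes `DTYPE_MAPPER.pt`, presumably because validation compares against the PyTorch reference model. A rough stand-in for the dispatch, with hypothetical mapping tables (the real mapper is `optimum.utils.DTYPE_MAPPER`):

```python
import numpy as np


class DtypeMapperSketch:
    """Hypothetical stand-in for optimum.utils.DTYPE_MAPPER, shown only to illustrate the dispatch."""

    @classmethod
    def np(cls, dtype: str):
        return {"fp32": np.float32, "fp16": np.float16, "int64": np.int64, "int32": np.int32}[dtype]

    @classmethod
    def pt(cls, dtype: str):
        import torch  # local import so the "np" path runs without torch installed

        return {"fp32": torch.float32, "fp16": torch.float16, "int64": torch.int64}[dtype]


framework = "np"
int_dtype = getattr(DtypeMapperSketch, framework)("int64")  # -> numpy.int64
```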
optimum/exporters/tasks.py (2 additions, 2 deletions)
```diff
@@ -1790,7 +1790,7 @@ def infer_task_from_model(
                 cache_dir=cache_dir,
                 token=token,
             )
-        elif type(model) == type:
+        elif isinstance(model, type):
             inferred_task_name = cls._infer_task_from_model_or_model_class(model_class=model)
         else:
             inferred_task_name = cls._infer_task_from_model_or_model_class(model=model)
@@ -1944,7 +1944,7 @@ def infer_library_from_model(
                 cache_dir=cache_dir,
                 token=token,
             )
-        elif type(model) == type:
+        elif isinstance(model, type):
             library_name = cls._infer_library_from_model_or_model_class(model_class=model)
         else:
             library_name = cls._infer_library_from_model_or_model_class(model=model)
```
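The old `type(model) == type` check only matched classes whose metaclass is exactly `type`; `isinstance(model, type)` also covers classes built with a metaclass that subclasses `type` (such as `abc.ABCMeta`), so passing a model class rather than an instance is detected more reliably:

```python
from abc import ABC


class MyModel(ABC):  # toy stand-in for a model class with a custom metaclass
    pass


print(type(MyModel) == type)       # False: type(MyModel) is abc.ABCMeta, not type
print(isinstance(MyModel, type))   # True: ABCMeta is a subclass of type
print(isinstance(MyModel(), type)) # False: an instance is not a class
```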
optimum/exporters/tflite/base.py (10 additions, 1 deletion)
```diff
@@ -27,6 +27,7 @@
 if is_tf_available():
     import tensorflow as tf
 
+from ...utils import DTYPE_MAPPER
 from ..base import ExportConfig
 
 
@@ -191,12 +192,16 @@ def __init__(
         audio_sequence_length: Optional[int] = None,
         point_batch_size: Optional[int] = None,
         nb_points_per_image: Optional[int] = None,
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
     ):
         self._config = config
         self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
         self.mandatory_axes = ()
         self.task = task
         self._axes: Dict[str, int] = {}
+        self.int_dtype = int_dtype
+        self.float_dtype = float_dtype
 
         # To avoid using **kwargs.
         axes_values = {
@@ -310,12 +315,16 @@ def generate_dummy_inputs(self) -> Dict[str, "tf.Tensor"]:
         """
         dummy_inputs_generators = self._create_dummy_input_generator_classes()
         dummy_inputs = {}
+        int_dtype = DTYPE_MAPPER.tf(self.int_dtype)
+        float_dtype = DTYPE_MAPPER.tf(self.float_dtype)
 
         for input_name in self.inputs:
             input_was_inserted = False
             for dummy_input_gen in dummy_inputs_generators:
                 if dummy_input_gen.supports_input(input_name):
-                    dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework="tf")
+                    dummy_inputs[input_name] = dummy_input_gen.generate(
+                        input_name, int_dtype=int_dtype, float_dtype=float_dtype
+                    )
                     input_was_inserted = True
                     break
             if not input_was_inserted:
```
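Since TFLite export always targets TensorFlow, the config resolves its stored dtype strings with `DTYPE_MAPPER.tf` and the `framework="tf"` argument disappears from the generator call. A hedged usage sketch (the `tf.int64`/`tf.float32` results are assumptions about the mapping):

```python
from optimum.utils import DTYPE_MAPPER

# The TFLite config stores string dtype names and resolves them to TF dtypes:
int_dtype = DTYPE_MAPPER.tf("int64")   # expected: tf.int64
float_dtype = DTYPE_MAPPER.tf("fp32")  # expected: tf.float32
```

The new `int_dtype`/`float_dtype` constructor arguments then let callers pick these strings per export when instantiating a TFLite config subclass.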