From 646010caaf079d985dc90b89b65eb5225e4227b1 Mon Sep 17 00:00:00 2001
From: hlky
Date: Mon, 10 Feb 2025 11:15:53 +0000
Subject: [PATCH] InferenceClient latent-to-image

---
 src/huggingface_hub/inference/_client.py      | 57 ++++++++++++++++++-
 src/huggingface_hub/inference/_common.py      | 37 ++++++++++++
 .../inference/_generated/_async_client.py     | 57 ++++++++++++++++++-
 .../inference/_providers/hf_inference.py      |  9 ++-
 4 files changed, 156 insertions(+), 4 deletions(-)

diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index 98b951bd40..6f68d3ba9e 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -59,6 +59,8 @@
     _set_unsupported_text_generation_kwargs,
     _stream_chat_completion_response,
     _stream_text_generation_response,
+    _tensor_shape_dtype,
+    _tensor_to_bytes,
     raise_text_generation_error,
 )
 from huggingface_hub.inference._generated.types import (
@@ -102,7 +104,7 @@
     ZeroShotImageClassificationOutputElement,
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
-from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
+from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status, is_torch_available
 from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method
 
 
@@ -110,6 +112,10 @@
     import numpy as np
     from PIL.Image import Image
 
+    if is_torch_available():
+        import torch
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -1434,6 +1440,55 @@ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> Imag
         output = ImageToTextOutput.parse_obj(response)
         return output[0] if isinstance(output, list) else output
 
+    def latent_to_image(self, latents: "torch.Tensor", *, model: Optional[str] = None) -> "Image":
+        """
+        Takes input latents and returns the decoded image.
+
+        <Tip>
+
+        You must have `torch` installed if you want to work with tensors (`pip install torch`).
+
+        </Tip>
+
+        <Tip>
+
+        You must have `safetensors` installed if you want to work with tensors (`pip install safetensors`).
+
+        </Tip>
+
+        The VAE model should match the model that generated the latents.
+
+        Args:
+            latents (`torch.Tensor`):
+                The input latents to decode.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a
+                deployed Inference Endpoint. This parameter overrides the model defined at the instance level.
+                Defaults to None. The `vae` subfolder is automatically detected.
+
+        Returns:
+            `Image`: The generated image.
+
+        Raises:
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        """
+        provider_helper = get_provider_helper(self.provider, task="latent-to-image")
+        headers = self.headers
+        headers.update(_tensor_shape_dtype(tensor=latents))
+        request_parameters = provider_helper.prepare_request(
+            data=_tensor_to_bytes(tensor=latents),
+            headers=headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response)
+        return _bytes_to_image(response)
+
     def object_detection(
         self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
     ) -> List[ObjectDetectionOutputElement]:
diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py
index 7a7e5d43db..057eb7abcb 100644
--- a/src/huggingface_hub/inference/_common.py
+++ b/src/huggingface_hub/inference/_common.py
@@ -55,6 +55,8 @@
     is_aiohttp_available,
     is_numpy_available,
     is_pillow_available,
+    is_safetensors_available,
+    is_torch_available,
 )
 from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput
 
@@ -63,6 +65,9 @@
     from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image
 
+    if is_torch_available():
+        import torch
+
 # TYPES
 UrlT = str
 PathT = Union[str, Path]
@@ -167,6 +172,24 @@ def _import_pil_image():
     return Image
 
 
+def _import_torch():
+    """Make sure `torch` is installed on the machine."""
+    if not is_torch_available():
+        raise ImportError("Please install torch to deal with tensors (`pip install torch`).")
+    import torch
+
+    return torch
+
+
+def _import_safetensors():
+    """Make sure `safetensors` is installed on the machine."""
+    if not is_safetensors_available():
+        raise ImportError("Please install safetensors to deal with tensors (`pip install safetensors`).")
+    import safetensors.torch
+
+    return safetensors.torch
+
+
 ## ENCODING / DECODING UTILS
 
 
@@ -262,6 +285,20 @@ def _as_dict(response: Union[bytes, Dict]) -> Dict:
     return json.loads(response) if isinstance(response, bytes) else response
 
 
+def _tensor_to_bytes(tensor: "torch.Tensor") -> bytes:
+    safetensors = _import_safetensors()
+    data = safetensors._tobytes(tensor=tensor, name="tensor")
+    if not isinstance(data, bytes):
+        data = data.tobytes()
+    return data
+
+
+def _tensor_shape_dtype(tensor: "torch.Tensor") -> Dict[str, str]:
+    shape = json.dumps(list(tensor.shape))
+    dtype = str(tensor.dtype).split(".")[-1]
+    return {"shape": shape, "dtype": dtype}
+
+
 ## PAYLOAD UTILS
 
 
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 3c814aca91..d9011f59ff 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -44,6 +44,8 @@
     _import_numpy,
     _open_as_binary,
     _set_unsupported_text_generation_kwargs,
+    _tensor_shape_dtype,
+    _tensor_to_bytes,
     raise_text_generation_error,
 )
 from huggingface_hub.inference._generated.types import (
@@ -87,7 +89,7 @@
     ZeroShotImageClassificationOutputElement,
 )
 from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
-from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
+from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status, is_torch_available
 from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method
 
 from .._common import _async_yield_from, _import_aiohttp
@@ -98,6 +100,10 @@
     from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image
 
+    if is_torch_available():
+        import torch
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -1481,6 +1487,55 @@ async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -
         output = ImageToTextOutput.parse_obj(response)
         return output[0] if isinstance(output, list) else output
 
+    async def latent_to_image(self, latents: "torch.Tensor", *, model: Optional[str] = None) -> "Image":
+        """
+        Takes input latents and returns the decoded image.
+
+        <Tip>
+
+        You must have `torch` installed if you want to work with tensors (`pip install torch`).
+
+        </Tip>
+
+        <Tip>
+
+        You must have `safetensors` installed if you want to work with tensors (`pip install safetensors`).
+
+        </Tip>
+
+        The VAE model should match the model that generated the latents.
+
+        Args:
+            latents (`torch.Tensor`):
+                The input latents to decode.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a
+                deployed Inference Endpoint. This parameter overrides the model defined at the instance level.
+                Defaults to None. The `vae` subfolder is automatically detected.
+
+        Returns:
+            `Image`: The generated image.
+
+        Raises:
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `aiohttp.ClientResponseError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        """
+        provider_helper = get_provider_helper(self.provider, task="latent-to-image")
+        headers = self.headers
+        headers.update(_tensor_shape_dtype(tensor=latents))
+        request_parameters = provider_helper.prepare_request(
+            data=_tensor_to_bytes(tensor=latents),
+            headers=headers,
+            model=model or self.model,
+            api_key=self.token,
+        )
+        response = await self._inner_post(request_parameters)
+        response = provider_helper.get_response(response)
+        return _bytes_to_image(response)
+
     async def object_detection(
         self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
     ) -> List[ObjectDetectionOutputElement]:
diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py
index 3066448b72..a3278376ab 100644
--- a/src/huggingface_hub/inference/_providers/hf_inference.py
+++ b/src/huggingface_hub/inference/_providers/hf_inference.py
@@ -107,15 +107,20 @@ def prepare_headers(self, headers: Dict, *, api_key: Optional[Union[bool, str]]
     def _prepare_payload(
         self, inputs: Any, parameters: Dict[str, Any], model: Optional[str], extra_payload: Dict[str, Any]
     ) -> Tuple[Any, Any]:
-        if isinstance(inputs, bytes):
+        if isinstance(inputs, bytes) and self.task not in ("latent-to-image",):
             raise ValueError(f"Unexpected binary input for task {self.task}.")
         if isinstance(inputs, Path):
             raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})")
-        return None, {
+        _data = None
+        _json = {
             "inputs": inputs,
             "parameters": {k: v for k, v in parameters.items() if v is not None},
             **extra_payload,
         }
+        if self.task in ("latent-to-image",):
+            _data = inputs
+            _json = None
+        return _data, _json
 
     def get_response(self, response: Union[bytes, Dict]) -> Any:
         return response
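
For reviewers, a minimal sketch of how the new method would be called. It assumes a provider that serves the `latent-to-image` task and a checkpoint whose repository contains a `vae` subfolder; the model ID and latent shape below are illustrative, not taken from the patch.

    import torch

    from huggingface_hub import InferenceClient

    client = InferenceClient()

    # Latents would normally come from a diffusion backbone; a random tensor with a
    # typical SDXL latent layout (batch, channels, height//8, width//8) stands in here.
    latents = torch.randn(1, 4, 128, 128, dtype=torch.float16)

    # The checkpoint's `vae` subfolder is detected automatically on the server side.
    image = client.latent_to_image(latents, model="stabilityai/stable-diffusion-xl-base-1.0")
    image.save("decoded.png")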
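
The request encoding added in `_common.py` is easy to reproduce by hand: the tensor's shape and dtype travel as HTTP headers, and the raw buffer becomes the request body. The sketch below mirrors `_tensor_shape_dtype` and `_tensor_to_bytes`; note that `safetensors.torch._tobytes` is a private helper, so this could drift if safetensors changes its internals.

    import json

    import torch
    from safetensors.torch import _tobytes

    latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

    headers = {
        "shape": json.dumps(list(latents.shape)),    # '[1, 4, 64, 64]'
        "dtype": str(latents.dtype).split(".")[-1],  # 'float16'
    }
    body = _tobytes(latents, "tensor")               # raw contiguous buffer

    # One float16 element is 2 bytes, so the body length is numel * element_size.
    assert len(body) == latents.numel() * latents.element_size()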
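
The patch does not include the server side, but the headers plus raw body are enough to reconstruct the tensor and run a VAE decode. The following is a hypothetical sketch of what an endpoint could do with diffusers: `AutoencoderKL`, the `vae` subfolder convention, and the scaling-factor division are standard diffusers usage, while `decode_request` itself is invented for illustration.

    import json

    import torch
    from diffusers import AutoencoderKL
    from diffusers.image_processor import VaeImageProcessor


    def decode_request(body: bytes, headers: dict, model_id: str):
        # Rebuild the tensor from the "shape"/"dtype" headers and the raw body.
        shape = json.loads(headers["shape"])
        dtype = getattr(torch, headers["dtype"])
        latents = torch.frombuffer(bytearray(body), dtype=dtype).reshape(shape)

        # Load only the VAE from the checkpoint's `vae` subfolder.
        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
        with torch.no_grad():
            image = vae.decode(latents / vae.config.scaling_factor).sample
        return VaeImageProcessor().postprocess(image, output_type="pil")[0]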