InferenceClient latent-to-image
hlky committed Feb 10, 2025
1 parent 85a5c8e commit 646010c
Showing 4 changed files with 156 additions and 4 deletions.
57 changes: 56 additions & 1 deletion src/huggingface_hub/inference/_client.py
@@ -59,6 +59,8 @@
    _set_unsupported_text_generation_kwargs,
    _stream_chat_completion_response,
    _stream_text_generation_response,
    _tensor_shape_dtype,
    _tensor_to_bytes,
    raise_text_generation_error,
)
from huggingface_hub.inference._generated.types import (
@@ -102,14 +104,18 @@
    ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
-from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
+from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status, is_torch_available
from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method


if TYPE_CHECKING:
    import numpy as np
    from PIL.Image import Image

if is_torch_available():
    import torch


logger = logging.getLogger(__name__)


@@ -1434,6 +1440,55 @@ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
        output = ImageToTextOutput.parse_obj(response)
        return output[0] if isinstance(output, list) else output

    def latent_to_image(self, latents: "torch.Tensor", *, model: Optional[str] = None) -> "Image":
        """
        Decode input latents into an image.

        <Tip warning={true}>
        You must have `torch` and `safetensors` installed to work with tensors (`pip install torch safetensors`).
        </Tip>

        The VAE model used for decoding should match the model that generated the latents.

        Args:
            latents (`torch.Tensor`):
                The input latents to decode.
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
                The `vae` subfolder is detected automatically.

        Returns:
            `Image`: The generated image.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
            `HTTPError`:
                If the request fails with an HTTP error status code other than HTTP 503.
        """
        provider_helper = get_provider_helper(self.provider, task="latent-to-image")
        headers = self.headers
        # Shape and dtype travel as HTTP headers since the tensor itself is sent as raw bytes.
        headers.update(_tensor_shape_dtype(tensor=latents))
        request_parameters = provider_helper.prepare_request(
            data=_tensor_to_bytes(tensor=latents),
            headers=headers,
            model=model or self.model,
            api_key=self.token,
        )
        response = self._inner_post(request_parameters)
        response = provider_helper.get_response(response)
        return _bytes_to_image(response)
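
End to end, the method serializes the latents with `safetensors`, attaches their shape and dtype as headers, POSTs the raw bytes, and decodes the response into a PIL image. A minimal usage sketch — the model ID and latent shape are illustrative assumptions, not values from this commit:

```python
import torch
from huggingface_hub import InferenceClient

client = InferenceClient()

# Latents as a diffusion pipeline would produce them: (batch, channels, height, width).
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

# Decode with the matching VAE; the `vae` subfolder of the repo is detected automatically.
image = client.latent_to_image(latents, model="stabilityai/stable-diffusion-xl-base-1.0")
image.save("decoded.png")
```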

    def object_detection(
        self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
    ) -> List[ObjectDetectionOutputElement]:
37 changes: 37 additions & 0 deletions src/huggingface_hub/inference/_common.py
@@ -55,6 +55,8 @@
    is_aiohttp_available,
    is_numpy_available,
    is_pillow_available,
    is_safetensors_available,
    is_torch_available,
)
from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput

@@ -63,6 +65,9 @@
    from aiohttp import ClientResponse, ClientSession
    from PIL.Image import Image

if is_torch_available():
    import torch

# TYPES
UrlT = str
PathT = Union[str, Path]
@@ -167,6 +172,24 @@ def _import_pil_image():
    return Image


def _import_torch():
    """Make sure `torch` is installed on the machine."""
    if not is_torch_available():
        raise ImportError("Please install torch to work with tensors (`pip install torch`).")
    import torch

    return torch


def _import_safetensors():
    """Make sure `safetensors` is installed on the machine."""
    if not is_safetensors_available():
        raise ImportError("Please install safetensors to work with tensors (`pip install safetensors`).")
    import safetensors.torch

    return safetensors.torch


## ENCODING / DECODING UTILS


@@ -262,6 +285,20 @@ def _as_dict(response: Union[bytes, Dict]) -> Dict:
    return json.loads(response) if isinstance(response, bytes) else response


def _tensor_to_bytes(tensor: "torch.Tensor") -> bytes:
    safetensors = _import_safetensors()
    # `_tobytes` is a private helper from `safetensors.torch` that serializes a tensor's raw storage.
    data = safetensors._tobytes(tensor=tensor, name="tensor")
    if not isinstance(data, bytes):
        data = data.tobytes()
    return data


def _tensor_shape_dtype(tensor: "torch.Tensor") -> Dict[str, str]:
    # Shape is JSON-encoded, e.g. "[1, 4, 64, 64]"; dtype drops the "torch." prefix, e.g. "float16".
    shape = json.dumps(list(tensor.shape))
    dtype = str(tensor.dtype).split(".")[-1]
    return {"shape": shape, "dtype": dtype}


## PAYLOAD UTILS


57 changes: 56 additions & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
@@ -44,6 +44,8 @@
    _import_numpy,
    _open_as_binary,
    _set_unsupported_text_generation_kwargs,
    _tensor_shape_dtype,
    _tensor_to_bytes,
    raise_text_generation_error,
)
from huggingface_hub.inference._generated.types import (
@@ -87,7 +89,7 @@
    ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
-from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
+from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status, is_torch_available
from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method

from .._common import _async_yield_from, _import_aiohttp
@@ -98,6 +100,10 @@
    from aiohttp import ClientResponse, ClientSession
    from PIL.Image import Image

if is_torch_available():
    import torch


logger = logging.getLogger(__name__)


@@ -1481,6 +1487,55 @@ async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
        output = ImageToTextOutput.parse_obj(response)
        return output[0] if isinstance(output, list) else output

    async def latent_to_image(self, latents: "torch.Tensor", *, model: Optional[str] = None) -> "Image":
        """
        Decode input latents into an image.

        <Tip warning={true}>
        You must have `torch` and `safetensors` installed to work with tensors (`pip install torch safetensors`).
        </Tip>

        The VAE model used for decoding should match the model that generated the latents.

        Args:
            latents (`torch.Tensor`):
                The input latents to decode.
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
                The `vae` subfolder is detected automatically.

        Returns:
            `Image`: The generated image.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
            `aiohttp.ClientResponseError`:
                If the request fails with an HTTP error status code other than HTTP 503.
        """
        provider_helper = get_provider_helper(self.provider, task="latent-to-image")
        headers = self.headers
        # Shape and dtype travel as HTTP headers since the tensor itself is sent as raw bytes.
        headers.update(_tensor_shape_dtype(tensor=latents))
        request_parameters = provider_helper.prepare_request(
            data=_tensor_to_bytes(tensor=latents),
            headers=headers,
            model=model or self.model,
            api_key=self.token,
        )
        response = await self._inner_post(request_parameters)
        response = provider_helper.get_response(response)
        return _bytes_to_image(response)
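
The async variant mirrors the sync client, awaiting the POST instead. A minimal sketch under the same illustrative assumptions as the sync example above:

```python
import asyncio

import torch
from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)
    image = await client.latent_to_image(latents, model="stabilityai/stable-diffusion-xl-base-1.0")
    image.save("decoded.png")


asyncio.run(main())
```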

    async def object_detection(
        self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None
    ) -> List[ObjectDetectionOutputElement]:
9 changes: 7 additions & 2 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -107,15 +107,20 @@ def prepare_headers(self, headers: Dict, *, api_key: Optional[Union[bool, str]]
    def _prepare_payload(
        self, inputs: Any, parameters: Dict[str, Any], model: Optional[str], extra_payload: Dict[str, Any]
    ) -> Tuple[Any, Any]:
-        if isinstance(inputs, bytes):
+        if isinstance(inputs, bytes) and self.task not in ("latent-to-image",):
            raise ValueError(f"Unexpected binary input for task {self.task}.")
        if isinstance(inputs, Path):
            raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})")
-        return None, {
+        _data = None
+        _json = {
            "inputs": inputs,
            "parameters": {k: v for k, v in parameters.items() if v is not None},
            **extra_payload,
        }
+        if self.task in ("latent-to-image",):
+            # Binary tasks send the raw bytes as the request body instead of a JSON payload.
+            _data = inputs
+            _json = None
+        return _data, _json

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        return response
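The net effect of the provider change: latent-to-image ships the raw tensor bytes as the request body, while every other task keeps the usual JSON payload. A standalone sketch of that routing decision — not the library's code, names are illustrative:

```python
from typing import Any, Dict, Optional, Tuple

# A one-element tuple, so `in` tests whole task names rather than substrings.
BINARY_INPUT_TASKS = ("latent-to-image",)


def route_payload(task: str, inputs: Any, parameters: Dict[str, Any]) -> Tuple[Optional[bytes], Optional[Dict]]:
    """Return (data, json): raw bytes for binary tasks, a JSON payload otherwise."""
    if task in BINARY_INPUT_TASKS:
        return inputs, None
    return None, {"inputs": inputs, "parameters": {k: v for k, v in parameters.items() if v is not None}}


# Usage: latent-to-image routes to the binary body, other tasks to JSON.
data, json_payload = route_payload("latent-to-image", b"\x00\x01", {})
assert data == b"\x00\x01" and json_payload is None
```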
