[Inference Providers] isolate image-to-image payload build for HF Inference API #1439

Merged: 4 commits, May 14, 2025

27 changes: 26 additions & 1 deletion packages/inference/src/providers/fal-ai.ts
@@ -14,10 +14,12 @@
*
* Thanks!
*/
import { base64FromBytes } from "../utils/base64FromBytes";

import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import { isUrl } from "../lib/isUrl";
import type { BodyParams, HeaderParams, ModelId, UrlParams } from "../types";
import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types";
import { delay } from "../utils/delay";
import { omit } from "../utils/omit";
import {
@@ -27,6 +29,7 @@ import {
type TextToVideoTaskHelper,
} from "./providerHelper";
import { HF_HUB_URL } from "../config";
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";

export interface FalAiQueueOutput {
request_id: string;
@@ -224,6 +227,28 @@ export class FalAIAutomaticSpeechRecognitionTask extends FalAITask implements AutomaticSpeechRecognitionTaskHelper {
}
return { text: res.text };
}

async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
const contentType = blob?.type;
if (!contentType) {
throw new Error(
`Unable to determine the input's content-type. Make sure you are passing a Blob when using provider fal-ai.`
);
}
if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
throw new Error(
`Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
", "
)}`
);
}
const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
return {
...("data" in args ? omit(args, "data") : omit(args, "inputs")),
audio_url: `data:${contentType};base64,${base64audio}`,
};
}
}

export class FalAITextToSpeechTask extends FalAITask {
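Review note: a minimal sketch of the behavior that moved into `FalAIAutomaticSpeechRecognitionTask.preparePayloadAsync`. The token, model id, and MIME type below are placeholder assumptions, not values from this diff.

```ts
// Illustrative only; import path is relative to packages/inference/src/providers.
import { getProviderHelper } from "../lib/getProviderHelper";

const helper = getProviderHelper("fal-ai", "automatic-speech-recognition");
// The blob's type must be one of FAL_AI_SUPPORTED_BLOB_TYPES, or the helper throws.
const audio = new Blob([/* raw audio bytes */], { type: "audio/mpeg" });
const payload = await helper.preparePayloadAsync({
  accessToken: "hf_...", // placeholder token
  model: "openai/whisper-large-v3", // placeholder model id
  inputs: audio,
});
// `inputs` is omitted from the result; the audio travels inline instead:
// payload.audio_url === "data:audio/mpeg;base64,..."
```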
33 changes: 31 additions & 2 deletions packages/inference/src/providers/hf-inference.ts
@@ -36,7 +36,7 @@ import type {
import { HF_ROUTER_URL } from "../config";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification";
import type { BodyParams, UrlParams } from "../types";
import type { BodyParams, RequestArgs, UrlParams } from "../types";
import { toArray } from "../utils/toArray";
import type {
AudioClassificationTaskHelper,
@@ -70,7 +70,10 @@
} from "./providerHelper";

import { TaskProviderHelper } from "./providerHelper";

import { base64FromBytes } from "../utils/base64FromBytes";
import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";
import { omit } from "../utils/omit";
interface Base64ImageGeneration {
data: Array<{
b64_json: string;
@@ -221,6 +224,15 @@ export class HFInferenceAutomaticSpeechRecognitionTask
override async getResponse(response: AutomaticSpeechRecognitionOutput): Promise<AutomaticSpeechRecognitionOutput> {
return response;
}

async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
return "data" in args
? args
: {
...omit(args, "inputs"),
data: args.inputs,
};
}
}

export class HFInferenceAudioToAudioTask extends HFInferenceTask implements AudioToAudioTaskHelper {
@@ -326,6 +338,23 @@ export class HFInferenceImageToTextTask extends HFInferenceTask implements Image
}

export class HFInferenceImageToImageTask extends HFInferenceTask implements ImageToImageTaskHelper {
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
if (!args.parameters) {
return {
...args,
model: args.model,
data: args.inputs,
};
} else {
return {
...args,
inputs: base64FromBytes(
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
),
};
}
}

override async getResponse(response: Blob): Promise<Blob> {
if (response instanceof Blob) {
return response;
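Review note: a hedged sketch of the two payload shapes `HFInferenceImageToImageTask.preparePayloadAsync` produces, mirroring the logic previously inlined in `imageToImage()`. Model id and prompt are placeholders.

```ts
// Illustrative only; import path is relative to packages/inference/src/providers.
import { getProviderHelper } from "../lib/getProviderHelper";

const helper = getProviderHelper("hf-inference", "image-to-image");
const image = new Blob([/* raw image bytes */], { type: "image/png" });

// Without parameters, the raw blob is forwarded as binary `data`.
const binaryPayload = await helper.preparePayloadAsync({
  accessToken: "hf_...", // placeholder token
  model: "some-user/some-img2img-model", // placeholder model id
  inputs: image,
});

// With parameters, the image is base64-encoded so it can ride in a JSON body.
const jsonPayload = await helper.preparePayloadAsync({
  accessToken: "hf_...",
  model: "some-user/some-img2img-model",
  inputs: image,
  parameters: { prompt: "make it snow" }, // assumed image-to-image parameter
});
```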
6 changes: 5 additions & 1 deletion packages/inference/src/providers/providerHelper.ts
@@ -48,8 +48,10 @@ import type {
import { HF_ROUTER_URL } from "../config";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio";
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, UrlParams } from "../types";
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types";
import { toArray } from "../utils/toArray";
import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";

/**
* Base class for task-specific provider helpers
@@ -142,6 +144,7 @@ export interface TextToVideoTaskHelper {
export interface ImageToImageTaskHelper {
getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
preparePayload(params: BodyParams<ImageToImageInput & BaseArgs>): Record<string, unknown>;
preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs>;
}

export interface ImageSegmentationTaskHelper {
@@ -245,6 +248,7 @@ export interface AudioToAudioTaskHelper {
export interface AutomaticSpeechRecognitionTaskHelper {
getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AutomaticSpeechRecognitionOutput>;
preparePayload(params: BodyParams<AutomaticSpeechRecognitionInput & BaseArgs>): Record<string, unknown> | BodyInit;
preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs>;
}

export interface AudioClassificationTaskHelper {
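Review note: a hedged illustration of what the widened contracts ask of a provider. The class below is invented for this note, reusing the types imported at the top of providerHelper.ts; it is a sketch, not a real provider.

```ts
// Illustrative only; imports mirror those in providerHelper.ts.
import type { ImageToImageInput } from "@huggingface/tasks";
import type { BaseArgs, BodyParams, RequestArgs } from "../types";
import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
import type { ImageToImageTaskHelper } from "./providerHelper";

class ExampleImageToImageHelper implements ImageToImageTaskHelper {
  preparePayload(params: BodyParams<ImageToImageInput & BaseArgs>): Record<string, unknown> {
    return { ...params.args };
  }
  // New in this PR: async payload building lives on the helper itself.
  async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
    // A real provider would do its async work here, e.g. base64-encoding args.inputs.
    return args;
  }
  async getResponse(response: unknown): Promise<Blob> {
    if (response instanceof Blob) return response;
    throw new Error("expected a Blob response");
  }
}
```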
packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -2,13 +2,9 @@ import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
import { getProviderHelper } from "../../lib/getProviderHelper";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
import type { BaseArgs, Options, RequestArgs } from "../../types";
import { base64FromBytes } from "../../utils/base64FromBytes";
import { omit } from "../../utils/omit";
import type { BaseArgs, Options } from "../../types";
import { innerRequest } from "../../utils/request";
import type { LegacyAudioInput } from "./utils";
import { preparePayload } from "./utils";

export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
/**
@@ -21,7 +17,7 @@ export async function automaticSpeechRecognition(
): Promise<AutomaticSpeechRecognitionOutput> {
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
const payload = await buildPayload(args);
const payload = await providerHelper.preparePayloadAsync(args);
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
...options,
task: "automatic-speech-recognition",
@@ -32,29 +28,3 @@
}
return providerHelper.getResponse(res);
}

async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
if (args.provider === "fal-ai") {
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
const contentType = blob?.type;
if (!contentType) {
throw new Error(
`Unable to determine the input's content-type. Make sure you are passing a Blob when using provider fal-ai.`
);
}
if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
throw new Error(
`Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
", "
)}`
);
}
const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
return {
...("data" in args ? omit(args, "data") : omit(args, "inputs")),
audio_url: `data:${contentType};base64,${base64audio}`,
};
} else {
return preparePayload(args);
}
}
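Review note: the public entry point is unchanged by this refactor; only where the payload gets built has moved. A hedged usage sketch, with placeholder token and model id:

```ts
import { automaticSpeechRecognition } from "@huggingface/inference";

// Illustrative only: callers look exactly as before; the fal-ai base64 branch
// deleted above now runs inside the provider helper.
const output = await automaticSpeechRecognition({
  accessToken: "hf_...", // placeholder token
  provider: "fal-ai",
  model: "openai/whisper-large-v3", // placeholder model id
  inputs: new Blob([/* raw audio bytes */], { type: "audio/mpeg" }),
});
console.log(output.text);
```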
21 changes: 3 additions & 18 deletions packages/inference/src/tasks/cv/imageToImage.ts
@@ -1,8 +1,7 @@
import type { ImageToImageInput } from "@huggingface/tasks";
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
import { getProviderHelper } from "../../lib/getProviderHelper";
import type { BaseArgs, Options, RequestArgs } from "../../types";
import { base64FromBytes } from "../../utils/base64FromBytes";
import type { BaseArgs, Options } from "../../types";
import { innerRequest } from "../../utils/request";

export type ImageToImageArgs = BaseArgs & ImageToImageInput;
@@ -14,22 +13,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "image-to-image");
let reqArgs: RequestArgs;
if (!args.parameters) {
reqArgs = {
accessToken: args.accessToken,
model: args.model,
data: args.inputs,
};
} else {
reqArgs = {
...args,
inputs: base64FromBytes(
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
),
};
}
const { data: res } = await innerRequest<Blob>(reqArgs, providerHelper, {
const payload = await providerHelper.preparePayloadAsync(args);
const { data: res } = await innerRequest<Blob>(payload, providerHelper, {
...options,
task: "image-to-image",
});
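Review note: same story for image-to-image; a hedged caller-side sketch with placeholder values:

```ts
import { imageToImage } from "@huggingface/inference";

// Illustrative only: the task function is now provider-agnostic and defers
// payload shaping to providerHelper.preparePayloadAsync.
const result: Blob = await imageToImage({
  accessToken: "hf_...", // placeholder token
  model: "some-user/some-img2img-model", // placeholder model id
  inputs: new Blob([/* raw image bytes */], { type: "image/png" }),
  parameters: { prompt: "make it snow" }, // triggers the base64/JSON path
});
```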