Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 137 additions & 8 deletions Runware/Runware-base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import {
IRequestVideo,
IAsyncResults,
IVideoToImage,
TAudioInference,
IAudioResult,
} from "./types";
import {
BASE_RUNWARE_URLS,
Expand Down Expand Up @@ -266,7 +268,7 @@ export class RunwareBase {
}
};

private listenToImages({
private listenToResponse({
onPartialImages,
taskUUID,
groupKey,
Expand Down Expand Up @@ -574,7 +576,7 @@ export class RunwareBase {

// const generationTime = endTime - startTime;

lis = this.listenToImages({
lis = this.listenToResponse({
onPartialImages,
taskUUID: taskUUID,
groupKey: LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
Expand Down Expand Up @@ -723,6 +725,8 @@ export class RunwareBase {
retry,
includePayload,
includeGenerationTime,
inputImages,
...rest
}: IRequestImageToText): Promise<IImageToText> => {
const totalRetry = retry || this._globalMaxRetries;
let lis: any = undefined;
Expand All @@ -737,13 +741,33 @@ export class RunwareBase {
? await this.uploadImage(inputImage as File | string)
: null;

const imagesUploaded = inputImages?.length
? await Promise.all(
inputImages.map((image) =>
this.uploadImage(image as File | string)
)
)
: null;

const taskUUID = customTaskUUID || getUUID();

const payload = {
taskUUID,
taskType: ETaskType.IMAGE_CAPTION,
inputImage: imageUploaded?.imageUUID,

...(imageUploaded?.imageUUID
? { inputImage: imageUploaded.imageUUID }
: {}),

...(imagesUploaded?.length
? {
inputImages: imagesUploaded
.map((img) => img?.imageUUID)
.filter(Boolean),
}
: {}),
...evaluateNonTrue({ key: "includeCost", value: includeCost }),
...rest,
};

this.send(payload);
Expand Down Expand Up @@ -850,7 +874,7 @@ export class RunwareBase {
await getIntervalAsyncWithPromise(
async ({ resolve, reject }) => {
try {
const videos = await this.getResponse({ taskUUID });
const videos = await this.getResponse<IVideoToImage>({ taskUUID });

// Add videos to the collection
for (const video of videos || []) {
Expand Down Expand Up @@ -885,10 +909,10 @@ export class RunwareBase {
}
};

getResponse = async (payload: IAsyncResults): Promise<IVideoToImage[]> => {
getResponse = async <T>(payload: IAsyncResults): Promise<T[]> => {
const taskUUID = payload.taskUUID;
// const mock = getRandomTaskResponses({ count: 2, taskUUID });
return this.baseSingleRequest({
return this.baseSingleRequest<T[]>({
payload: {
...payload,
customTaskUUID: taskUUID,
Expand All @@ -910,6 +934,7 @@ export class RunwareBase {
retry,
includeGenerationTime,
includePayload,
...rest
}: IUpscaleGan): Promise<IImage> => {
const totalRetry = retry || this._globalMaxRetries;
let lis: any = undefined;
Expand All @@ -933,6 +958,7 @@ export class RunwareBase {
...(outputType ? { outputType } : {}),
...(outputQuality ? { outputQuality } : {}),
...(outputFormat ? { outputFormat } : {}),
...rest,
};

this.send(payload);
Expand Down Expand Up @@ -1178,7 +1204,7 @@ export class RunwareBase {
numberResults: imageRemaining,
});

lis = this.listenToImages({
lis = this.listenToResponse({
onPartialImages,
taskUUID: taskUUID,
groupKey: LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
Expand Down Expand Up @@ -1252,6 +1278,32 @@ export class RunwareBase {
});
};

audioInference = async (
payload: TAudioInference
): Promise<IAudioResult | IAudioResult[]> => {
const { skipResponse, deliveryMethod = "sync", ...rest } = payload;
try {
const requestMethod =
deliveryMethod === "sync"
? this.baseSyncRequest
: this.baseSingleRequest;

const request = await requestMethod<IAudioResult>({
payload: {
...rest,
numberResults: rest.numberResults || 1,
taskType: ETaskType.AUDIO_INFERENCE,
deliveryMethod: deliveryMethod,
},
debugKey: "audio-inference",
});

return request;
} catch (e) {
throw e;
}
};

protected baseSingleRequest = async <T>({
payload,
debugKey,
Expand Down Expand Up @@ -1335,6 +1387,81 @@ export class RunwareBase {
throw e;
}
};
protected baseSyncRequest = async <T>({
payload,
debugKey,
}: {
payload: Record<string, any>;
debugKey: string;
}): Promise<T> => {
const {
retry,
customTaskUUID,
includePayload,
numberResults = 1,
onPartialResponse,
includeGenerationTime,
...restPayload
} = payload;

const totalRetry = retry || this._globalMaxRetries;
let lis: any = undefined;
let taskUUIDs: string[] = [];
let retryCount = 0;

const startTime = Date.now();

try {
return await asyncRetry(
async () => {
await this.ensureConnection();
retryCount++;

const taskWithSimilarTaskUUID = this._globalImages.filter((audio) =>
taskUUIDs.includes(audio.taskUUID)
);

const taskUUID = customTaskUUID || getUUID();
taskUUIDs.push(taskUUID);
const taskRemaining = numberResults - taskWithSimilarTaskUUID.length;

const payload = {
...restPayload,
taskUUID,
numberResults: taskRemaining,
};

this.send(payload);

lis = this.listenToResponse({
onPartialImages: onPartialResponse,
taskUUID: taskUUID,
groupKey: LISTEN_TO_IMAGES_KEY.REQUEST_AUDIO,
requestPayload: includePayload ? payload : undefined,
startTime: includeGenerationTime ? startTime : undefined,
});

const promise = await this.getSimilarImages({
taskUUID: taskUUIDs,
numberResults,
lis,
debugKey,
});

lis.destroy();
return promise as T;
},
{
maxRetries: totalRetry,
callback: () => {
lis?.destroy();
},
}
);
} catch (e) {
throw e;
}
};

async ensureConnection() {
let isConnected = this.connected();
Expand Down Expand Up @@ -1437,11 +1564,13 @@ export class RunwareBase {
numberResults,
shouldThrowError,
lis,
debugKey = "getting-images",
}: {
taskUUID: string | string[];
numberResults: number;
shouldThrowError?: boolean;
lis: any;
debugKey?: string;
}): Promise<IImage[] | IError> {
return (await getIntervalWithPromise(
({ resolve, reject, intervalId }) => {
Expand Down Expand Up @@ -1471,7 +1600,7 @@ export class RunwareBase {
}
},
{
debugKey: "getting images",
debugKey,
shouldThrowError,
timeoutDuration: this._timeoutDuration,
}
Expand Down
78 changes: 78 additions & 0 deletions Runware/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export enum ETaskType {
AUTHENTICATION = "authentication",
MODEL_UPLOAD = "modelUpload",
MODEL_SEARCH = "modelSearch",
AUDIO_INFERENCE = "audioInference",
}

export type RunwareBaseType = {
Expand All @@ -34,6 +35,7 @@ export type RunwareBaseType = {
};

export type IOutputType = "base64Data" | "dataURI" | "URL";
export type IDeliveryType = "sync" | "async";
export type IOutputFormat = "JPG" | "PNG" | "WEBP";
export type IVideoOutputFormat = "MP4" | "WEBM" | "MOV";

Expand Down Expand Up @@ -69,6 +71,7 @@ export interface IVideoToImage {
seed?: number;
videoURL?: string;
}

export interface IControlNetImage {
taskUUID: string;
inputImageUUID: string;
Expand Down Expand Up @@ -249,6 +252,12 @@ export interface IRequestImageToText extends IAdditionalResponsePayload {
includeCost?: boolean;
customTaskUUID?: string;
retry?: number;

model?: string;
prompts?: string[];
inputImages?: string[];

[key: string]: any;
}
export interface IImageToText {
taskType: ETaskType;
Expand Down Expand Up @@ -335,9 +344,29 @@ export interface IUpscaleGan extends IAdditionalResponsePayload {
outputFormat?: IOutputFormat;
includeCost?: boolean;
outputQuality?: number;
revertExtra?: boolean;
model?: string;

customTaskUUID?: string;
retry?: number;

settings?: {
seed?: number;
controlNetWeight?: number;
CFGScale?: number;
positivePrompt?: string;
negativePrompt?: string;
scheduler?: string;
colorFix?: boolean;
tileDiffusion?: boolean;
clipSkip?: number;
steps?: number;
strength?: number;
checkNSFW?: boolean;
[key: string]: any;
};

[key: string]: any;
}

export type ReconnectingWebsocketProps = {
Expand Down Expand Up @@ -644,6 +673,39 @@ export type TModelSearch = {
retry?: number;
} & { [key: string]: any };

export type TAudioInference = {
model: string;
positivePrompt: string;
negativePrompt?: string;
duration: number;
numberResults?: number;
outputFormat?: "MP3" | "WAV" | "FLAC" | "AAC" | "OGG";
outputType?: IOutputType;
webhookURL?: string;
deliveryMethod?: IDeliveryType;
uploadEndpoint?: string;
includeCost?: boolean;
onPartialResponse?: (images: IImage[], error?: IError) => void;

audioSettings?: {
sampleRate?: number;
bitrate?: number;
[key: string]: any;
};

providerSettings?: {
elevenlabs?: {
music?: string;
[key: string]: any;
};
[key: string]: any;
};

// other options
customTaskUUID?: string;
retry?: number;
} & { [key: string]: any };

export type TModel = {
air: string;
name: string;
Expand Down Expand Up @@ -705,6 +767,22 @@ export type TImageUploadResponse = {
imageURL: string;
};

export type IAudioSyncResult = {
taskType: string;
taskUUID: string;
audioUUID: string;
audioURL?: string;
audioBase64Data?: string;
audioDataURI?: string;
cost: number;
};
export type IAuidoAsyncResult = {
taskType: string;
taskUUID: string;
status: string;
};
export type IAudioResult = IAudioSyncResult | IAuidoAsyncResult;

export type TImageMaskingResponse = {
taskType: string;
taskUUID: string;
Expand Down
Loading
Loading