Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 67 additions & 89 deletions runware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,8 +1157,7 @@ async def _removeImageBackground(

# Add provider settings if provided
if removeImageBackgroundPayload.providerSettings:
self._addImageProviderSettings(task_params, removeImageBackgroundPayload)

self._addProviderSettings(task_params, removeImageBackgroundPayload)
# Add safety settings if provided
if removeImageBackgroundPayload.safety:
self._addSafetySettings(task_params, removeImageBackgroundPayload.safety)
Expand Down Expand Up @@ -1259,8 +1258,7 @@ async def _upscaleGan(self, upscaleGanPayload: "IImageUpscale") -> "Union[List[I

# Add provider settings if provided
if upscaleGanPayload.providerSettings:
self._addImageProviderSettings(task_params, upscaleGanPayload)

self._addProviderSettings(task_params, upscaleGanPayload)
# Add safety settings if provided
if upscaleGanPayload.safety:
self._addSafetySettings(task_params, upscaleGanPayload.safety)
Expand Down Expand Up @@ -1303,81 +1301,71 @@ async def imageVectorize(self, vectorizePayload: "IVectorize") -> "Union[List[II
async with self._request_semaphore:
return await self._retry_with_reconnect(self._vectorize, vectorizePayload)

async def _vectorize(self, vectorizePayload: "IVectorize") -> Union[List["IImage"], "IAsyncTaskResponse"]:
await self.ensureConnection()
# Process the image from inputs
input_image = vectorizePayload.inputs.image

if not input_image:
raise ValueError("Image is required in inputs for vectorize task")

# Upload the image if it's a local file
image_uploaded = await self.uploadImage(input_image)

if not image_uploaded or not image_uploaded.imageUUID:
return []

taskUUID = getUUID()
async def _processVectorizeInputs(self, vectorizePayload: IVectorize) -> None:
if not vectorizePayload.inputs or not vectorizePayload.inputs.image:
return
vectorizePayload.inputs.image = await process_image(vectorizePayload.inputs.image)

# Create a dictionary with mandatory parameters
task_params = {
def _buildVectorizeRequest(self, vectorizePayload: IVectorize) -> Dict[str, Any]:
request_object = {
"taskType": ETaskType.IMAGE_VECTORIZE.value,
"taskUUID": taskUUID,
"inputs": {
"image": image_uploaded.imageUUID
}
"taskUUID": vectorizePayload.taskUUID,
}

# Add optional parameters if they are provided
if vectorizePayload.model is not None:
task_params["model"] = vectorizePayload.model
request_object["model"] = vectorizePayload.model
if vectorizePayload.outputType is not None:
task_params["outputType"] = vectorizePayload.outputType
request_object["outputType"] = vectorizePayload.outputType
if vectorizePayload.outputFormat is not None:
task_params["outputFormat"] = vectorizePayload.outputFormat
request_object["outputFormat"] = vectorizePayload.outputFormat
if vectorizePayload.includeCost:
task_params["includeCost"] = vectorizePayload.includeCost
request_object["includeCost"] = vectorizePayload.includeCost
if vectorizePayload.webhookURL:
request_object["webhookURL"] = vectorizePayload.webhookURL
if vectorizePayload.width is not None:
request_object["width"] = vectorizePayload.width
if vectorizePayload.height is not None:
request_object["height"] = vectorizePayload.height
if vectorizePayload.positivePrompt is not None:
request_object["positivePrompt"] = vectorizePayload.positivePrompt.strip()
self._addOptionalField(request_object, vectorizePayload.inputs)
self._addProviderSettings(request_object, vectorizePayload)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use smth like:

def build_request(mapping: dict[str, Any]) -> dict[str, Any]:
    """Build request dict, excluding None values."""
    return {k: v for k, v in mapping.items() if v is not None}

return request_object

async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
await self.ensureConnection()
await self._processVectorizeInputs(vectorizePayload)
vectorizePayload.taskUUID = vectorizePayload.taskUUID or getUUID()
task_params = self._buildVectorizeRequest(vectorizePayload)

await self.send([task_params])

if vectorizePayload.webhookURL:
task_params["webhookURL"] = vectorizePayload.webhookURL
return await self._handleWebhookRequest(
request_object=task_params,
task_uuid=taskUUID,
task_uuid=vectorizePayload.taskUUID,
task_type="vectorize",
debug_key="image-vectorize-webhook"
)

future, should_send = await self._register_pending_operation(
taskUUID,
expected_results=1,
complete_predicate=None,
result_filter=lambda r: r.get("imageUUID") is not None
let_lis = await self.listenToImages(
onPartialImages=None,
taskUUID=vectorizePayload.taskUUID,
groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
)

try:

if should_send:
await self.send([task_params])
await self._mark_operation_sent(taskUUID)
results = await asyncio.wait_for(future, timeout=IMAGE_OPERATION_TIMEOUT / 1000)
images = await self.getSimililarImage(
taskUUID=vectorizePayload.taskUUID,
numberOfImages=1,
shouldThrowError=True,
lis=let_lis,
)

if not results:
raise Exception(f"No results received | TaskUUID: {taskUUID}")
let_lis["destroy"]()

for result in results:
if "code" in result or "errors" in result:
raise RunwareAPIError(result)
if "code" in images or "errors" in images:
raise RunwareAPIError(images)

return instantiateDataclassList(IImage, results)

except asyncio.TimeoutError:
raise Exception(
f"Timeout waiting for vectorize | TaskUUID: {taskUUID} | "
f"Timeout: {IMAGE_OPERATION_TIMEOUT}ms"
)
except RunwareAPIError:
raise
finally:
await self._unregister_pending_operation(taskUUID)
return instantiateDataclassList(IImage, images)

async def promptEnhance(
self, promptEnhancer: "IPromptEnhance"
Expand Down Expand Up @@ -2248,9 +2236,7 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]:
request_object["stopSequences"] = requestText.stopSequences
if requestText.includeCost is not None:
request_object["includeCost"] = requestText.includeCost
if requestText.numberResults is not None:
request_object["numberResults"] = requestText.numberResults
self._addTextProviderSettings(request_object, requestText)
self._addProviderSettings(request_object, requestText)
return request_object

async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
Expand Down Expand Up @@ -2370,7 +2356,7 @@ def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str
self._addOptionalImageFields(request_object, requestImage)
self._addImageSpecialFields(request_object, requestImage, control_net_data_dicts, instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data)
self._addOptionalField(request_object, requestImage.inputs)
self._addImageProviderSettings(request_object, requestImage)
self._addProviderSettings(request_object, requestImage)
self._addOptionalField(request_object, requestImage.ultralytics)
self._addOptionalField(request_object, requestImage.safety)
self._addOptionalField(request_object, requestImage.settings)
Expand Down Expand Up @@ -2482,17 +2468,23 @@ def _addSafetySettings(self, request_object: Dict[str, Any], safety: ISafety) ->
if safety_dict:
request_object["safety"] = safety_dict

def _addImageProviderSettings(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
if not requestImage.providerSettings:
return
provider_dict = requestImage.providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

def _addProviderSettings(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
if not requestVideo.providerSettings:
def _addProviderSettings(
self,
request_object: Dict[str, Any],
payload: Union[
IImageInference,
IImageBackgroundRemoval,
IImageUpscale,
IVectorize,
IVideoInference,
IAudioInference,
ITextInference,
],
) -> None:
providerSettings = getattr(payload, "providerSettings", None)
if not providerSettings:
return
provider_dict = requestVideo.providerSettings.to_request_dict()
provider_dict = providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

Expand Down Expand Up @@ -2865,7 +2857,7 @@ def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]:
self._addOptionalField(request_object, requestAudio.speech)
self._addOptionalField(request_object, requestAudio.audioSettings)
self._addOptionalField(request_object, requestAudio.settings)
self._addAudioProviderSettings(request_object, requestAudio)
self._addProviderSettings(request_object, requestAudio)
self._addOptionalField(request_object, requestAudio.inputs)
self._addOptionalField(request_object, requestAudio.settings)

Expand All @@ -2883,20 +2875,6 @@ def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio:
request_object[field] = value


def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
if not requestAudio.providerSettings:
return
provider_dict = requestAudio.providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

def _addTextProviderSettings(self, request_object: Dict[str, Any], requestText: ITextInference) -> None:
if not requestText.providerSettings:
return
provider_dict = requestText.providerSettings.to_request_dict()
if provider_dict:
request_object["providerSettings"] = provider_dict

async def _handleInitialAudioResponse(
self,
request_object: "Dict[str, Any]",
Expand Down
15 changes: 10 additions & 5 deletions runware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,8 @@ def request_key(self) -> str:
| IRecraftProviderSettings
)

VectorizeProviderSettings = IRecraftProviderSettings

@dataclass
class ISafety(SerializableMixin):
tolerance: Optional[bool] = None
Expand Down Expand Up @@ -1051,14 +1053,17 @@ def __post_init__(self):

@dataclass
class IVectorize:

inputs: IInputs = None
inputs: Optional[IInputs] = None
includeCost: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only one field is not optional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

taskUUID: Optional[str] = None
model: Optional[str] = None
outputType: Optional[IOutputType] = "URL"
outputFormat: Optional[IOutputFormat] = "SVG"
model: Optional[str] = None
outputType: Optional[IOutputType] = "URL"
outputFormat: Optional[IOutputFormat] = "SVG"
webhookURL: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
positivePrompt: Optional[str] = None
providerSettings: Optional[VectorizeProviderSettings] = None


@dataclass
Expand Down