Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions comfy_api_nodes/apis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,9 +782,11 @@ def __init__(
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
*,
status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], Optional[str]] | None = None,
price_extractor: Callable[[R], Optional[float]] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
Expand Down Expand Up @@ -815,10 +817,12 @@ def __init__(
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.price_extractor = price_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
self.extracted_price: Optional[float] = None

async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
Expand All @@ -840,6 +844,8 @@ async def execute(self, client: Optional[ApiClient] = None) -> R:
def _display_text_on_node(self, text: str):
    """Push a progress message to the node identified by ``self.node_id``.

    Nothing is sent when no node id is set. If a price has already been
    extracted (``self.extracted_price`` is not ``None``), it is shown on its
    own line above ``text``.
    """
    if self.node_id:
        price = self.extracted_price
        # Only prepend the price line once a price is actually known.
        message = text if price is None else f"Price: {price}$\n{text}"
        PromptServer.instance.send_progress_text(message, self.node_id)

def _display_time_progress_on_node(self, time_completed: int | float):
Expand Down Expand Up @@ -877,9 +883,7 @@ async def _poll_until_complete(self, client: ApiClient) -> R:
try:
logging.debug("[DEBUG] Polling attempt #%s", poll_count)

request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
)
request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)

if poll_count == 1:
logging.debug(
Expand Down Expand Up @@ -912,6 +916,11 @@ async def _poll_until_complete(self, client: ApiClient) -> R:
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)

if self.price_extractor:
price = self.price_extractor(response_obj)
if price is not None:
self.extracted_price = price

if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
Expand Down
17 changes: 10 additions & 7 deletions comfy_api_nodes/apis/gemini_api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from __future__ import annotations

from typing import List, Optional
from typing import Optional

from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel


class GeminiImageConfig(BaseModel):
    """Image-specific settings carried in ``GeminiImageGenerationConfig.imageConfig``."""

    # Requested aspect ratio as a "W:H" string (e.g. "16:9"); None leaves it unset.
    # NOTE(review): the set of ratios accepted by the API is not visible here - confirm.
    aspectRatio: Optional[str] = None


class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None
responseModalities: Optional[list[str]] = None
imageConfig: Optional[GeminiImageConfig] = None


class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent]
contents: list[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None
safetySettings: Optional[list[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None
tools: Optional[list[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None
24 changes: 18 additions & 6 deletions comfy_api_nodes/nodes_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
Expand Down Expand Up @@ -63,6 +63,7 @@ class GeminiImageModel(str, Enum):
"""

gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
gemini_2_5_flash_image = "gemini-2.5-flash-image"


def get_gemini_endpoint(
Expand Down Expand Up @@ -538,7 +539,7 @@ def INPUT_TYPES(cls) -> InputTypeDict:
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
"default": GeminiImageModel.gemini_2_5_flash_image.value,
},
),
"seed": (
Expand Down Expand Up @@ -579,6 +580,14 @@ def INPUT_TYPES(cls) -> InputTypeDict:
# "tooltip": "How many images to generate",
# },
# ),
"aspect_ratio": (
IO.COMBO,
{
"tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.",
"options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
"default": "auto",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
Expand All @@ -600,15 +609,17 @@ async def api_call(
images: Optional[IO.IMAGE] = None,
files: Optional[list[GeminiPart]] = None,
n=1,
aspect_ratio: str = "auto",
unique_id: Optional[str] = None,
**kwargs,
):
# Validate inputs
validate_string(prompt, strip_whitespace=True, min_length=1)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]

# Add other modal parts
if not aspect_ratio:
aspect_ratio = "auto" # for backward compatibility with old workflows; TODO: remove this in December
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)

if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
Expand All @@ -625,7 +636,8 @@ async def api_call(
),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"]
responseModalities=["TEXT","IMAGE"],
imageConfig=None if aspect_ratio == "auto" else image_config,
)
),
auth_kwargs=kwargs,
Expand Down
33 changes: 19 additions & 14 deletions comfy_api_nodes/nodes_kling.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
validate_video_dimensions,
validate_video_duration,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.input.basic_types import AudioInput
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import ComfyExtension, io as comfy_io
Expand Down Expand Up @@ -511,7 +512,7 @@ async def execute_video_effect(
image_1: torch.Tensor,
image_2: Optional[torch.Tensor] = None,
model_mode: Optional[KlingVideoGenMode] = None,
) -> comfy_io.NodeOutput:
) -> tuple[VideoFromFile, str, str]:
if dual_character:
request_input_field = KlingDualCharacterEffectInput(
model_name=model_name,
Expand Down Expand Up @@ -562,7 +563,7 @@ async def execute_video_effect(
validate_video_result_response(final_response)

video = get_video_from_response(final_response)
return comfy_io.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration))
return await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)


async def execute_lipsync(
Expand Down Expand Up @@ -1271,7 +1272,7 @@ async def execute(
image_1=image_left,
image_2=image_right,
)
return video, duration
return comfy_io.NodeOutput(video, duration)


class KlingSingleImageVideoEffectNode(comfy_io.ComfyNode):
Expand Down Expand Up @@ -1320,17 +1321,21 @@ async def execute(
model_name: KlingSingleImageEffectModelName,
duration: KlingVideoGenDuration,
) -> comfy_io.NodeOutput:
return await execute_video_effect(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
dual_character=False,
effect_scene=effect_scene,
model_name=model_name,
duration=duration,
image_1=image,
return comfy_io.NodeOutput(
*(
await execute_video_effect(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
dual_character=False,
effect_scene=effect_scene,
model_name=model_name,
duration=duration,
image_1=image,
)
)
)


Expand Down
2 changes: 2 additions & 0 deletions comfy_api_nodes/nodes_pika.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
validate_string,
)
from comfy_api_nodes.apis import pika_defs
from comfy_api_nodes.apis.client import (
Expand Down Expand Up @@ -590,6 +591,7 @@ async def execute(
resolution: str,
duration: int,
) -> comfy_io.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
Expand Down
9 changes: 5 additions & 4 deletions comfy_extras/nodes_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,10 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
for key, value in metadata.items():
output_container.metadata[key] = value

layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
Expand All @@ -156,7 +157,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
#TODO: I would really love to support V3 and V5, but there doesn't seem to be a way to set the qscale level; the property below is a bool
out_stream.codec_context.qscale = 1
Expand All @@ -165,9 +166,9 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)

frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
Expand Down
Loading