Skip to content

Commit 89e2752

Browse files
committed
input output nodes v1
1 parent 6627185 commit 89e2752

7 files changed

+548
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__

__init__.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
@author: MyShell
3+
@title: comfyui-shellagent-plugin
4+
@description:
5+
"""
6+
import os
7+
import sys
8+
9+
sys.path.append(os.path.join(os.path.dirname(__file__)))
10+
11+
import inspect
12+
import sys
13+
import importlib
14+
import subprocess
15+
import requests
16+
import folder_paths
17+
from folder_paths import add_model_folder_path, get_filename_list, get_folder_paths
18+
from tqdm import tqdm
19+
20+
# from . import custom_routes
21+
# import routes
22+
23+
ag_path = os.path.join(os.path.dirname(__file__))
24+
25+
def get_python_files(path):
26+
return [f[:-3] for f in os.listdir(path) if f.endswith(".py")]
27+
28+
def append_to_sys_path(path):
29+
if path not in sys.path:
30+
sys.path.append(path)
31+
32+
paths = ["comfy-nodes"]
33+
files = []
34+
35+
for path in paths:
36+
full_path = os.path.join(ag_path, path)
37+
append_to_sys_path(full_path)
38+
files.extend(get_python_files(full_path))
39+
40+
NODE_CLASS_MAPPINGS = {}
41+
NODE_DISPLAY_NAME_MAPPINGS = {}
42+
43+
# Import all the modules and append their mappings
44+
for file in files:
45+
module = importlib.import_module(file)
46+
47+
if hasattr(module, "NODE_CLASS_MAPPINGS"):
48+
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
49+
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"):
50+
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
51+
52+
WEB_DIRECTORY = "web-plugin"
53+
print(NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS)
54+
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

comfy-nodes/input_image.py

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import folder_paths
2+
from PIL import Image, ImageOps
3+
import numpy as np
4+
import torch
5+
import os
6+
import uuid
7+
import tqdm
8+
9+
10+
class ShellAgentPluginInputImage:
11+
@classmethod
12+
def INPUT_TYPES(s):
13+
input_dir = folder_paths.get_input_directory()
14+
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
15+
files = sorted(files)
16+
return {
17+
"required": {
18+
"input_name": (
19+
"STRING",
20+
{"multiline": False, "default": "input_image"},
21+
),
22+
"default_value": (
23+
"STRING", {"image_upload": True, "default": files[0] if len(files) else ""},
24+
),
25+
},
26+
"optional": {
27+
"description": (
28+
"STRING",
29+
{"multiline": True, "default": ""},
30+
),
31+
}
32+
}
33+
34+
RETURN_TYPES = ("IMAGE",)
35+
RETURN_NAMES = ("image",)
36+
37+
FUNCTION = "run"
38+
39+
CATEGORY = "shellagent"
40+
41+
def run(self, input_name, default_value=None, display_name=None, description=None):
42+
input_dir = folder_paths.get_input_directory()
43+
image_path = default_value
44+
try:
45+
if image_path.startswith('http'):
46+
import requests
47+
from io import BytesIO
48+
print("Fetching image from url: ", image)
49+
response = requests.get(image)
50+
image = Image.open(BytesIO(response.content))
51+
elif image_path.startswith('data:image/png;base64,') or image_path.startswith('data:image/jpeg;base64,') or image_path.startswith('data:image/jpg;base64,'):
52+
import base64
53+
from io import BytesIO
54+
print("Decoding base64 image")
55+
base64_image = image_path[image_path.find(",")+1:]
56+
decoded_image = base64.b64decode(base64_image)
57+
image = Image.open(BytesIO(decoded_image))
58+
else:
59+
# local path
60+
image_path = os.path.join(input_dir, image_path)
61+
image = Image.open(image_path).convert("RGB")
62+
63+
image = ImageOps.exif_transpose(image)
64+
image = image.convert("RGB")
65+
image = np.array(image).astype(np.float32) / 255.0
66+
image = torch.from_numpy(image)[None,]
67+
return [image]
68+
except Exception as e:
69+
raise e
70+
71+
video_extensions = ["webm", "mp4", "mkv", "gif"]
72+
73+
class ShellAgentPluginInputVideo:
74+
@classmethod
75+
def INPUT_TYPES(s):
76+
input_dir = folder_paths.get_input_directory()
77+
files = []
78+
for f in os.listdir(input_dir):
79+
if os.path.isfile(os.path.join(input_dir, f)):
80+
file_parts = f.split(".")
81+
if len(file_parts) > 1 and (file_parts[-1] in video_extensions):
82+
files.append(f)
83+
84+
return {
85+
"required": {
86+
"input_name": (
87+
"STRING",
88+
{"multiline": False, "default": "input_video"},
89+
),
90+
"default_value": (
91+
"STRING", {"video_upload": True, "default": files[0] if len(files) else ""},
92+
),
93+
},
94+
"optional": {
95+
"description": (
96+
"STRING",
97+
{"multiline": True, "default": ""},
98+
),
99+
}
100+
}
101+
102+
RETURN_TYPES = ("STRING",)
103+
RETURN_NAMES = ("video",)
104+
105+
FUNCTION = "run"
106+
107+
CATEGORY = "shellagent"
108+
109+
def run(self, input_name, default_value=None, description=None):
110+
input_dir = folder_paths.get_input_directory()
111+
if default_value.startswith("http"):
112+
import requests
113+
114+
print("Fetching video from URL: ", default_value)
115+
response = requests.get(default_value, stream=True)
116+
file_size = int(response.headers.get("Content-Length", 0))
117+
file_extension = default_value.split(".")[-1].split("?")[
118+
0
119+
] # Extract extension and handle URLs with parameters
120+
if file_extension not in video_extensions:
121+
file_extension = ".mp4"
122+
123+
unique_filename = str(uuid.uuid4()) + "." + file_extension
124+
video_path = os.path.join(input_dir, unique_filename)
125+
chunk_size = 1024 # 1 Kibibyte
126+
127+
num_bars = int(file_size / chunk_size)
128+
129+
with open(video_path, "wb") as out_file:
130+
for chunk in tqdm(
131+
response.iter_content(chunk_size=chunk_size),
132+
total=num_bars,
133+
unit="KB",
134+
desc="Downloading",
135+
leave=True,
136+
):
137+
out_file.write(chunk)
138+
else:
139+
video_path = os.path.abspath(os.path.join(input_dir, default_value))
140+
141+
return (video_path,)
142+
143+
144+
NODE_CLASS_MAPPINGS = {
145+
"ShellAgentPluginInputImage": ShellAgentPluginInputImage,
146+
# "ShellAgentPluginInputVideo": ShellAgentPluginInputVideo,
147+
}
148+
NODE_DISPLAY_NAME_MAPPINGS = {
149+
"ShellAgentPluginInputImage": "Input Image (ShellAgent Plugin)",
150+
# "ShellAgentPluginInputVideo": "Input Video (ShellAgent Plugin)"
151+
}

comfy-nodes/input_text.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import folder_paths
2+
from PIL import Image, ImageOps
3+
import numpy as np
4+
import torch
5+
6+
class ShellAgentPluginInputText:
7+
@classmethod
8+
def INPUT_TYPES(s):
9+
return {
10+
"required": {
11+
"input_name": (
12+
"STRING",
13+
{"multiline": False, "default": "input_text"},
14+
),
15+
},
16+
"optional": {
17+
"default_value": (
18+
"STRING",
19+
{"multiline": True, "default": ""},
20+
),
21+
"description": (
22+
"STRING",
23+
{"multiline": True, "default": ""},
24+
),
25+
"choices": (
26+
"STRING",
27+
{"multiline": False, "default": ""},
28+
),
29+
}
30+
}
31+
32+
RETURN_TYPES = ("STRING",)
33+
RETURN_NAMES = ("text",)
34+
35+
FUNCTION = "run"
36+
37+
CATEGORY = "shellagent"
38+
39+
def run(self, input_name, default_value=None, display_name=None, description=None, choices=None):
40+
return [default_value]
41+
42+
43+
NODE_CLASS_MAPPINGS = {"ShellAgentPluginInputText": ShellAgentPluginInputText}
44+
NODE_DISPLAY_NAME_MAPPINGS = {"ShellAgentPluginInputText": "Input Text (ShellAgent Plugin)"}

0 commit comments

Comments
 (0)