1
+ import folder_paths
2
+ from PIL import Image , ImageOps
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ import uuid
7
+ import tqdm
8
+
9
+
10
+ class ShellAgentPluginInputImage :
11
+ @classmethod
12
+ def INPUT_TYPES (s ):
13
+ input_dir = folder_paths .get_input_directory ()
14
+ files = [f for f in os .listdir (input_dir ) if os .path .isfile (os .path .join (input_dir , f ))]
15
+ files = sorted (files )
16
+ return {
17
+ "required" : {
18
+ "input_name" : (
19
+ "STRING" ,
20
+ {"multiline" : False , "default" : "input_image" },
21
+ ),
22
+ "default_value" : (
23
+ "STRING" , {"image_upload" : True , "default" : files [0 ] if len (files ) else "" },
24
+ ),
25
+ },
26
+ "optional" : {
27
+ "description" : (
28
+ "STRING" ,
29
+ {"multiline" : True , "default" : "" },
30
+ ),
31
+ }
32
+ }
33
+
34
+ RETURN_TYPES = ("IMAGE" ,)
35
+ RETURN_NAMES = ("image" ,)
36
+
37
+ FUNCTION = "run"
38
+
39
+ CATEGORY = "shellagent"
40
+
41
+ def run (self , input_name , default_value = None , display_name = None , description = None ):
42
+ input_dir = folder_paths .get_input_directory ()
43
+ image_path = default_value
44
+ try :
45
+ if image_path .startswith ('http' ):
46
+ import requests
47
+ from io import BytesIO
48
+ print ("Fetching image from url: " , image )
49
+ response = requests .get (image )
50
+ image = Image .open (BytesIO (response .content ))
51
+ elif image_path .startswith ('data:image/png;base64,' ) or image_path .startswith ('data:image/jpeg;base64,' ) or image_path .startswith ('data:image/jpg;base64,' ):
52
+ import base64
53
+ from io import BytesIO
54
+ print ("Decoding base64 image" )
55
+ base64_image = image_path [image_path .find ("," )+ 1 :]
56
+ decoded_image = base64 .b64decode (base64_image )
57
+ image = Image .open (BytesIO (decoded_image ))
58
+ else :
59
+ # local path
60
+ image_path = os .path .join (input_dir , image_path )
61
+ image = Image .open (image_path ).convert ("RGB" )
62
+
63
+ image = ImageOps .exif_transpose (image )
64
+ image = image .convert ("RGB" )
65
+ image = np .array (image ).astype (np .float32 ) / 255.0
66
+ image = torch .from_numpy (image )[None ,]
67
+ return [image ]
68
+ except Exception as e :
69
+ raise e
70
+
71
+ video_extensions = ["webm" , "mp4" , "mkv" , "gif" ]
72
+
73
+ class ShellAgentPluginInputVideo :
74
+ @classmethod
75
+ def INPUT_TYPES (s ):
76
+ input_dir = folder_paths .get_input_directory ()
77
+ files = []
78
+ for f in os .listdir (input_dir ):
79
+ if os .path .isfile (os .path .join (input_dir , f )):
80
+ file_parts = f .split ("." )
81
+ if len (file_parts ) > 1 and (file_parts [- 1 ] in video_extensions ):
82
+ files .append (f )
83
+
84
+ return {
85
+ "required" : {
86
+ "input_name" : (
87
+ "STRING" ,
88
+ {"multiline" : False , "default" : "input_video" },
89
+ ),
90
+ "default_value" : (
91
+ "STRING" , {"video_upload" : True , "default" : files [0 ] if len (files ) else "" },
92
+ ),
93
+ },
94
+ "optional" : {
95
+ "description" : (
96
+ "STRING" ,
97
+ {"multiline" : True , "default" : "" },
98
+ ),
99
+ }
100
+ }
101
+
102
+ RETURN_TYPES = ("STRING" ,)
103
+ RETURN_NAMES = ("video" ,)
104
+
105
+ FUNCTION = "run"
106
+
107
+ CATEGORY = "shellagent"
108
+
109
+ def run (self , input_name , default_value = None , description = None ):
110
+ input_dir = folder_paths .get_input_directory ()
111
+ if default_value .startswith ("http" ):
112
+ import requests
113
+
114
+ print ("Fetching video from URL: " , default_value )
115
+ response = requests .get (default_value , stream = True )
116
+ file_size = int (response .headers .get ("Content-Length" , 0 ))
117
+ file_extension = default_value .split ("." )[- 1 ].split ("?" )[
118
+ 0
119
+ ] # Extract extension and handle URLs with parameters
120
+ if file_extension not in video_extensions :
121
+ file_extension = ".mp4"
122
+
123
+ unique_filename = str (uuid .uuid4 ()) + "." + file_extension
124
+ video_path = os .path .join (input_dir , unique_filename )
125
+ chunk_size = 1024 # 1 Kibibyte
126
+
127
+ num_bars = int (file_size / chunk_size )
128
+
129
+ with open (video_path , "wb" ) as out_file :
130
+ for chunk in tqdm (
131
+ response .iter_content (chunk_size = chunk_size ),
132
+ total = num_bars ,
133
+ unit = "KB" ,
134
+ desc = "Downloading" ,
135
+ leave = True ,
136
+ ):
137
+ out_file .write (chunk )
138
+ else :
139
+ video_path = os .path .abspath (os .path .join (input_dir , default_value ))
140
+
141
+ return (video_path ,)
142
+
143
+
144
+ NODE_CLASS_MAPPINGS = {
145
+ "ShellAgentPluginInputImage" : ShellAgentPluginInputImage ,
146
+ # "ShellAgentPluginInputVideo": ShellAgentPluginInputVideo,
147
+ }
148
+ NODE_DISPLAY_NAME_MAPPINGS = {
149
+ "ShellAgentPluginInputImage" : "Input Image (ShellAgent Plugin)" ,
150
+ # "ShellAgentPluginInputVideo": "Input Video (ShellAgent Plugin)"
151
+ }
0 commit comments