
Commit c1a131f

handling dependency
1 parent 89e2752 commit c1a131f

5 files changed: +90076 -1 lines changed

__init__.py (+1 -1)
@@ -17,7 +17,7 @@
 from folder_paths import add_model_folder_path, get_filename_list, get_folder_paths
 from tqdm import tqdm

-# from . import custom_routes
+from . import custom_routes
 # import routes

 ag_path = os.path.join(os.path.dirname(__file__))

custom_routes.py (+58)
@@ -0,0 +1,58 @@
+from io import BytesIO
+from pprint import pprint
+from aiohttp import web
+import os
+import requests
+import folder_paths
+import json
+import server
+from PIL import Image
+import time
+import execution
+import random
+import traceback
+import uuid
+import asyncio
+import logging
+from urllib.parse import quote
+import threading
+import hashlib
+import aiohttp
+from aiohttp import ClientSession, web
+import aiofiles
+from typing import Dict, List, Union, Any, Optional
+from PIL import Image
+import copy
+import struct
+from aiohttp import web, ClientSession, ClientError, ClientTimeout, ClientResponseError
+import atexit
+from datetime import datetime
+import nodes
+from .dependency_checker import resolve_dependencies
+
+
+
+@server.PromptServer.instance.routes.post("/shellagent/export") # data same as queue prompt, plus workflow_name
+async def shellagent_export(request):
+    data = await request.json()
+    client_id = data["client_id"]
+    prompt = data["prompt"]
+    extra_data = data["extra_data"]
+    workflow_name = data["workflow_name"]
+    workflow_id = str(uuid.uuid4())
+    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+    # metadata.json
+    metadata = {
+        "name": data["workflow_name"],
+        "workflow_id": workflow_id,
+        "create_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+    }
+
+    # custom_node.json
+    resolve_dependencies(prompt)
+
+    workflow = prompt # used during running
+
+    #
+    import pdb; pdb.set_trace()
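Note: a quick way to exercise the new /shellagent/export route might look like the sketch below, assuming a local ComfyUI server on 127.0.0.1:8188. The payload values and the use of requests are illustrative only, and the handler as committed still stops at a pdb trace rather than returning a response.

import uuid
import requests

# Hypothetical smoke test for the route added above; not part of this commit.
# "prompt" carries the same node-graph dict that ComfyUI's queue-prompt API uses,
# plus the new "workflow_name" field read by shellagent_export.
payload = {
    "client_id": str(uuid.uuid4()),
    "prompt": {},          # fill with a real workflow graph before posting
    "extra_data": {},
    "workflow_name": "my_workflow",
}
resp = requests.post("http://127.0.0.1:8188/shellagent/export", json=payload)
print(resp.status_code)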

dependency_checker.py (+134)
@@ -0,0 +1,134 @@
+import os
+import subprocess
+import json
+import logging
+from .utils import compute_sha256, windows_to_linux_path
+
+ComfyUIModelLoaders = {
+    'VAELoader': (["vae_name"], "vae"),
+    'CheckpointLoader': (["ckpt_name"], "checkpoints"),
+    'CheckpointLoaderSimple': (["ckpt_name"], "checkpoints"),
+    'DiffusersLoader': (["model_path"], "diffusers"),
+    'unCLIPCheckpointLoader': (["ckpt_name"], "checkpoints"),
+    'LoraLoader': (["lora_name"], "loras"),
+    'LoraLoaderModelOnly': (["lora_name"], "loras"),
+    'ControlNetLoader': (["control_net_name"], "controlnet"),
+    'DiffControlNetLoader': (["control_net_name"], "controlnet"),
+    'UNETLoader': (["unet_name"], "unet"),
+    'CLIPLoader': (["clip_name"], "clip"),
+    'DualCLIPLoader': (["clip_name1", "clip_name2"], "clip"),
+    'CLIPVisionLoader': (["clip_name"], "clip_vision"),
+    'StyleModelLoader': (["style_model_name"], "style_models"),
+    'GLIGENLoader': (["gligen_name"], "gligen"),
+}
+
+
+ComfyUIFileLoaders = {
+    'VAELoader': (["vae_name"], "vae"),
+    'CheckpointLoader': (["ckpt_name"], "checkpoints"),
+    'CheckpointLoaderSimple': (["ckpt_name"], "checkpoints"),
+    'DiffusersLoader': (["model_path"], "diffusers"),
+    'unCLIPCheckpointLoader': (["ckpt_name"], "checkpoints"),
+    'LoraLoader': (["lora_name"], "loras"),
+    'LoraLoaderModelOnly': (["lora_name"], "loras"),
+    'ControlNetLoader': (["control_net_name"], "controlnet"),
+    'DiffControlNetLoader': (["control_net_name"], "controlnet"),
+    'UNETLoader': (["unet_name"], "unet"),
+    'CLIPLoader': (["clip_name"], "clip"),
+    'DualCLIPLoader': (["clip_name1", "clip_name2"], "clip"),
+    'CLIPVisionLoader': (["clip_name"], "clip_vision"),
+    'StyleModelLoader': (["style_model_name"], "style_models"),
+    'GLIGENLoader': (["gligen_name"], "gligen"),
+}
+
+
+model_list_json = json.load(open(os.path.join(os.path.dirname(__file__), "model_info.json")))
+def handle_model_info(ckpt_path):
+    ckpt_path = windows_to_linux_path(ckpt_path)
+    filename = os.path.basename(ckpt_path)
+    dirname = os.path.dirname(ckpt_path)
+    save_path = dirname.split('/', 1)[1]
+    metadata_path = ckpt_path + ".json"
+    if os.path.isfile(metadata_path):
+        metadata = json.load(open(metadata_path))
+        model_id = metadata["id"]
+    else:
+        logging.info(f"computing sha256 of {ckpt_path}")
+        model_id = compute_sha256(ckpt_path)
+        data = {
+            "id": model_id,
+            "save_path": save_path,
+            "filename": filename,
+        }
+        json.dump(data, open(metadata_path, "w"))
+    if model_id in model_list_json:
+        urls = [item["url"] for item in model_list_json[model_id]["links"]][:10] # use the top 10
+    else:
+        urls = []
+
+    item = {
+        "filename": filename,
+        "save_path": save_path,
+        "urls": urls,
+    }
+    return model_id, item
+
+
+def inspect_repo_version(module_path):
+    # Get the remote repository URL
+    try:
+        remote_url = subprocess.check_output(
+            ['git', 'config', '--get', 'remote.origin.url'],
+            cwd=module_path
+        ).strip().decode()
+    except subprocess.CalledProcessError:
+        return {"error": "Failed to get remote repository URL"}
+
+    # Get the latest commit hash
+    try:
+        commit_hash = subprocess.check_output(
+            ['git', 'rev-parse', 'HEAD'],
+            cwd=module_path
+        ).strip().decode()
+    except subprocess.CalledProcessError:
+        return {"error": "Failed to get commit hash"}
+
+    # Create and return the JSON result
+    result = {
+        "repo": remote_url,
+        "commit": commit_hash
+    }
+    return result
+
+
+def resolve_dependencies(prompt): # resolve custom nodes and models at the same time
+    from nodes import NODE_CLASS_MAPPINGS
+    custom_nodes = []
+    ckpt_paths = []
+    for node_id, node_info in prompt.items():
+        node_class_type = node_info["class_type"]
+        node_cls = NODE_CLASS_MAPPINGS[node_class_type]
+        if hasattr(node_cls, "RELATIVE_PYTHON_MODULE"):
+            custom_nodes.append(node_cls.RELATIVE_PYTHON_MODULE)
+        if node_class_type in ComfyUIModelLoaders:
+            input_names, save_path = ComfyUIModelLoaders[node_class_type]
+            for input_name in input_names:
+                ckpt_path = os.path.join("models", save_path, node_info["inputs"][input_name])
+                ckpt_paths.append(ckpt_path)
+
+    ckpt_paths = list(set(ckpt_paths))
+    custom_nodes = list(set(custom_nodes))
+    # step 1: custom nodes
+    custom_nodes_list = [inspect_repo_version(custom_node.replace(".", "/")) for custom_node in custom_nodes]
+
+    # step 2: models
+    models_dict = {}
+    for ckpt_path in ckpt_paths:
+        model_id, item = handle_model_info(ckpt_path)
+        models_dict[model_id] = item
+
+    # step 1: handle the custom nodes version
+    import pdb; pdb.set_trace()
+    # return ckpt_pat
+    # # step 1:
+    # for class_type in
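Note: resolve_dependencies walks the same prompt dict that ComfyUI sends when queueing a workflow, keyed by node id. The fragment below is only an illustration of the shape it expects; the node ids and filenames are invented.

# Illustrative prompt fragment (hypothetical node ids and filenames), not part of this commit.
prompt = {
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"},
    },
    "10": {
        "class_type": "LoraLoader",
        "inputs": {"lora_name": "example_lora.safetensors"},
    },
}
# Node "4" matches ComfyUIModelLoaders['CheckpointLoaderSimple'] -> (["ckpt_name"], "checkpoints"),
# so models/checkpoints/sd_xl_base_1.0.safetensors is recorded via handle_model_info; node classes
# that define RELATIVE_PYTHON_MODULE are collected and versioned through inspect_repo_version.
resolve_dependencies(prompt)  # as committed, execution still stops at the pdb trace above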
