Skip to content

Commit

Permalink
fix: missing git urls when freezing requirements (#22)
Browse files Browse the repository at this point in the history
Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Jan 3, 2025
1 parent 8891b27 commit 6dffa3d
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions nodes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import time
import uuid
import zipfile
from importlib.metadata import Distribution, distributions
from pathlib import Path
from typing import Any, Union

Expand All @@ -25,25 +26,33 @@
ZPath = Union[Path, zipfile.Path]
TEMP_FOLDER = Path(__file__).parent.parent / "temp"
COMFY_PACK_DIR = Path(__file__).parent.parent / "src" / "comfy_pack"
EXCLUDE_PACKAGES = ["bentoml", "onnxruntime"] # TODO: standardize this
EXCLUDE_PACKAGES = ["bentoml", "onnxruntime", "conda"] # TODO: standardize this


def _get_requirement_string(dist: Distribution) -> str:
direct_url_text = dist.read_text("direct_url.json")
pinned_str = f'{dist.metadata["Name"]}=={dist.version}'
if not direct_url_text:
return pinned_str
direct_url = json.loads(direct_url_text)
if url := direct_url.get("url"):
if url.startswith("file://"):
# we are not able to share local files
return pinned_str
if vcs_info := direct_url.get("vcs_info"):
url = f"{vcs_info['vcs']}+{url}@{vcs_info['commit_id']}"
if subdirectory := direct_url.get("subdirectory"):
url += f"#subdirectory={subdirectory}"
return f"{dist.metadata['Name']} @ {url}"
else:
return pinned_str


async def _write_requirements(path: ZPath, extras: list[str] | None = None) -> None:
print("Package => Writing requirements.txt")
with path.joinpath("requirements.txt").open("w") as f:
proc = await asyncio.subprocess.create_subprocess_exec(
sys.executable,
"-m",
"pip",
"list",
"--format",
"freeze",
"--exclude-editable",
*[f"--exclude={p}" for p in EXCLUDE_PACKAGES],
stdout=subprocess.PIPE,
)
stdout, _ = await proc.communicate()
f.write(stdout.decode().rstrip("\n") + "\n")
for dist in distributions():
f.write(_get_requirement_string(dist) + "\n")
if extras:
f.write("\n".join(extras) + "\n")

Expand Down

0 comments on commit 6dffa3d

Please sign in to comment.