Skip to content

feat: Add simpler HuggingFace repository support for lora_url #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from cog_model_helpers import seed as seed_helper
from replicate_weights import download_replicate_weights
from dataclasses import dataclass
import requests
from huggingface_hub import HfApi
import tempfile

OUTPUT_DIR = "/tmp/outputs"
INPUT_DIR = "/tmp/inputs"
Expand Down Expand Up @@ -42,7 +45,10 @@ class Inputs:
choices=["1.3b", "14b"],
default="14b",
)
lora_url = Input(description="Optional: The URL of a LORA to use", default=None)
lora_url = Input(
description="Optional: The URL or HuggingFace repo ID (username/repo) of a LORA to use",
default=None
)
lora_strength_model = Input(
description="Strength of the LORA applied to the model. 0.0 is no LORA.",
default=1.0,
Expand Down Expand Up @@ -83,6 +89,64 @@ class Inputs:
)


def is_huggingface_repo_id(text: str) -> bool:
"""Check if the text looks like a HuggingFace repo ID (username/repo)"""
if not text or "//" in text: # Skip empty strings and URLs
return False
return bool(re.match(r"^[^/]+/[^/]+$", text))

def download_from_huggingface(repo_id: str, lora_dir: str):
"""Downloads a .safetensors file from a HuggingFace repository"""
# Initialize HuggingFace API
api = HfApi()
try:
# List files in the repository
files = api.list_repo_files(repo_id)
except Exception as e:
raise ValueError(f"Failed to access HuggingFace repo '{repo_id}': {e}")

# Find safetensors files
safetensors_files = [f for f in files if f.endswith(".safetensors")]
if not safetensors_files:
raise ValueError(f"No .safetensors files found in HuggingFace repo: {repo_id}")

# Log all available .safetensors files for transparency
if len(safetensors_files) > 1:
print(f"Found multiple .safetensors files in {repo_id}:")
for i, file in enumerate(safetensors_files):
print(f" {i+1}. {file}")
print(f"Using the first file: {safetensors_files[0]}")

# Use the first safetensors file
hf_filename = safetensors_files[0]

# Determine model type from filename - use cautious approach
model_type = None
if "1.3b" in hf_filename.lower():
model_type = "1.3b"
elif "14b" in hf_filename.lower():
model_type = "14b"
else:
print(f"Warning: Could not determine model type from filename '{hf_filename}'. Using default model.")

# Download URL
download_url = f"https://huggingface.co/{repo_id}/resolve/main/{hf_filename}"

# Local path
local_filename = os.path.basename(hf_filename)
target_path = os.path.join(lora_dir, local_filename)

# Download file
print(f"Downloading {download_url}")
response = requests.get(download_url)
response.raise_for_status()

with open(target_path, 'wb') as f:
f.write(response.content)

return local_filename, model_type


class Predictor(BasePredictor):
def setup(self):
self.comfyUI = ComfyUI("127.0.0.1:8188")
Expand Down Expand Up @@ -223,24 +287,40 @@ def generate(
)
model = inferred_model_type
elif lora_url:
# Handle existing Replicate URL cases first
if m := re.match(
r"^(?:https?://replicate.com/)?([^/]+)/([^/]+)/?$", lora_url
):
owner, model_name = m.groups()
lora_filename, inferred_model_type = download_replicate_weights(
f"https://replicate.com/{owner}/{model_name}/_weights",
COMFYUI_LORAS_DIR,
)
try:
lora_filename, inferred_model_type = download_replicate_weights(
f"https://replicate.com/{owner}/{model_name}/_weights",
COMFYUI_LORAS_DIR,
)
except Exception as e:
print(f"Not a valid Replicate model, trying as HuggingFace repo: {str(e)}")
if is_huggingface_repo_id(lora_url):
lora_filename, inferred_model_type = download_from_huggingface(
lora_url, COMFYUI_LORAS_DIR
)
elif lora_url.startswith("https://replicate.delivery"):
lora_filename, inferred_model_type = download_replicate_weights(
lora_url, COMFYUI_LORAS_DIR
)

if inferred_model_type and inferred_model_type != model:
print(
f"Warning: Model type mismatch between requested model ({model}) and inferred model type ({inferred_model_type}). Using {inferred_model_type}."
# Try HuggingFace repo ID as a last resort
elif is_huggingface_repo_id(lora_url):
print(f"Processing HuggingFace repo ID: {lora_url}")
lora_filename, inferred_model_type = download_from_huggingface(
lora_url, COMFYUI_LORAS_DIR
)
model = inferred_model_type
if inferred_model_type:
model = model or inferred_model_type

if inferred_model_type and inferred_model_type != model:
print(
f"Warning: Model type mismatch between requested model ({model}) and inferred model type ({inferred_model_type}). Using {inferred_model_type}."
)
model = inferred_model_type

if resolution == "720p" and model == "1.3b":
print("Warning: 720p is not supported for 1.3b, using 480p instead")
Expand Down