# Image-Gen / app.py
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import gradio as gr
import spaces
import os
import requests
import hashlib
import re
import random

# Default LoRA for fallback
DEFAULT_LORA = "OedoSoldier/detail-tweaker-lora"
LORA_CACHE_DIR = "lora_cache"
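
# LoRA files are cached on disk under LORA_CACHE_DIR, keyed by the MD5 hash of
# their source URL, so repeated requests for the same Civitai download skip the
# network round-trip.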
def download_lora(url):
    """Download a LoRA file from a Civitai URL and cache it locally."""
    # Create the cache directory if it doesn't exist
    os.makedirs(LORA_CACHE_DIR, exist_ok=True)
    # Generate a filename from the URL
    url_hash = hashlib.md5(url.encode()).hexdigest()
    local_path = os.path.join(LORA_CACHE_DIR, f"{url_hash}.safetensors")
    # If the file already exists in the cache, return its path
    if os.path.exists(local_path):
        print("\n********** Lora Already Exists **********\n")
        return local_path
    # Download the file
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        # Get the total file size
        total_size = int(response.headers.get('content-length', 0))
        # Download and save the file
        with open(local_path, 'wb') as f:
            if total_size == 0:
                f.write(response.content)
            else:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        print("\n********** Lora Download Successful **********\n")
        return local_path
    except Exception as e:
        print(f"\nError downloading LoRA: {str(e)}\n")
        # Remove any partially written file so a corrupt download is not served from the cache later
        if os.path.exists(local_path):
            os.remove(local_path)
        return None

def is_civitai_url(url):
    """Check whether the URL is a valid Civitai download URL."""
    return bool(re.match(r'https?://civitai\.com/api/download/models/\d+', url))
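
# Note: every pipeline component is reloaded on each call below. With the
# @spaces.GPU decorator, the GPU is only attached for the duration of a call,
# but caching the assembled pipeline per model_id would likely speed up
# repeated requests for the same base model.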
@spaces.GPU
def generate_image(prompt, negative_prompt, lora_url, num_inference_steps=30, guidance_scale=7.0,
                   model="Real6.0", num_images=1, width=512, height=512, seed=None):
    if model == "Real5.0":
        model_id = "SG161222/Realistic_Vision_V5.0_noVAE"
    elif model == "Real5.1":
        model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
    elif model == "majicv7":
        model_id = "digiplay/majicMIX_realistic_v7"
    else:
        model_id = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
    # Initialize models
    vae = AutoencoderKL.from_pretrained(
        model_id,
        subfolder="vae"
    ).to("cuda")
    text_encoder = CLIPTextModel.from_pretrained(
        model_id,
        subfolder="text_encoder"
    ).to("cuda")
    tokenizer = CLIPTokenizer.from_pretrained(
        model_id,
        subfolder="tokenizer"
    )
    unet = UNet2DConditionModel.from_pretrained(
        model_id,
        subfolder="unet"
    ).to("cuda")
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        unet=unet  # reuse the UNet loaded above instead of letting the pipeline load a second copy
    ).to("cuda")
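    # LoRA resolution order: a Civitai download URL is fetched and cached, a bare
    # "user/repo" string is treated as a Hugging Face repo id, and anything else
    # (or any failure) falls back to DEFAULT_LORA.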
    # Load LoRA weights
    try:
        if lora_url and lora_url.strip():
            if is_civitai_url(lora_url):
                # Download and load a Civitai LoRA
                lora_path = download_lora(lora_url)
                if lora_path:
                    pipe.load_lora_weights(lora_path)
                    print("\n********** URL Lora Loaded **********\n")
                else:
                    pipe.load_lora_weights(DEFAULT_LORA)
                    print("\n********** Default Lora Loaded **********\n")
            # If it's a HuggingFace repo path
            elif '/' in lora_url and not lora_url.startswith('http'):
                pipe.load_lora_weights(lora_url)
                print("\n********** URL Lora Loaded **********\n")
            else:
                pipe.load_lora_weights(DEFAULT_LORA)
                print("\n********** Default Lora Loaded **********\n")
        else:
            pipe.load_lora_weights(DEFAULT_LORA)
    except Exception as e:
        print(f"\nError loading LoRA weights: {str(e)}\n")
        pipe.load_lora_weights(DEFAULT_LORA)
    if model == "Real6.0":
        # Disable the safety checker for this model; the stub returns the images unchanged
        pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True
    )
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    seed = int(seed)  # Gradio's Number component may pass a float; normalize to int
    generator = torch.manual_seed(seed)
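    # The prompts are tokenized and encoded manually so precomputed embeddings can
    # be handed to the pipeline; CLIP truncates at 77 tokens either way.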
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to("cuda")
    negative_text_inputs = tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to("cuda")
    prompt_embeds = text_encoder(text_inputs.input_ids)[0]
    negative_prompt_embeds = text_encoder(negative_text_inputs.input_ids)[0]
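    # cross_attention_kwargs={"scale": 1} applies the loaded LoRA at full strength;
    # values below 1 would blend it more weakly with the base weights.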
    # Generate the image
    result = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        cross_attention_kwargs={"scale": 1},
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        num_images_per_prompt=num_images,
        generator=generator
    )
    torch.cuda.empty_cache()
    return result.images, seed

def clean_lora_cache():
    """Clean the LoRA cache directory."""
    if os.path.exists(LORA_CACHE_DIR):
        for file in os.listdir(LORA_CACHE_DIR):
            file_path = os.path.join(LORA_CACHE_DIR, file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except Exception as e:
                print(f"Error deleting {file_path}: {str(e)}")

title = """<h1 align="center">ProFaker</h1>"""
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            # Input components
            prompt = gr.Textbox(
                label="Prompt",
                info="Enter your image description here...",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="Enter what you don't want in the image...",
                lines=3
            )
            lora_input = gr.Textbox(
                label="LoRA URL/Path",
                info="Enter a Civitai download URL or a HuggingFace path (e.g., 'username/model-name')",
                value=DEFAULT_LORA
            )
            clear_cache = gr.Button("Clear LoRA Cache")
            generate_button = gr.Button("Generate Image")
            with gr.Accordion("Advanced Options", open=False):
                model = gr.Dropdown(
                    choices=["Real6.0", "Real5.1", "Real5.0", "majicv7"],
                    value="Real6.0",
                    label="Model",
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Images to Generate"
                )
                width = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Width"
                )
                height = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Height"
                )
                steps_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=30,
                    step=1,
                    label="Number of Steps"
                )
                guidance_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=7.0,
                    step=0.5,
                    label="Guidance Scale"
                )
                seed_input = gr.Number(
                    value=random.randint(0, 2**32 - 1),
                    precision=0,  # return an integer from the UI
                    label="Seed (optional)"
                )
        with gr.Column():
            # Output components
            gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2
            )
            seed_display = gr.Textbox(label="Seed Used", interactive=False)
    # Connect the interface to the generation function
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, lora_input, steps_slider, guidance_slider,
                model, num_images, width, height, seed_input],
        outputs=[gallery, seed_display]
    )
    # Connect the clear-cache button
    clear_cache.click(fn=clean_lora_cache)

demo.queue(max_size=10).launch(share=False)
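
# To run outside a Hugging Face Space, `pip install gradio spaces diffusers
# transformers peft` should suffice; off-Space the @spaces.GPU decorator is a
# no-op, so a local CUDA GPU is assumed for the .to("cuda") calls above.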