from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import gradio as gr
import spaces
import os
import requests
import hashlib
import re
import random

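# Fallback LoRA used when no usable LoRA URL/path is supplied, and the
# on-disk cache directory for LoRAs downloaded from Civitai.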
DEFAULT_LORA = "OedoSoldier/detail-tweaker-lora"
LORA_CACHE_DIR = "lora_cache"

def download_lora(url):
    """Download a LoRA file from a Civitai URL and cache it locally."""
    os.makedirs(LORA_CACHE_DIR, exist_ok=True)

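    # Key the cache by an MD5 hash of the URL so each URL maps to one file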
    url_hash = hashlib.md5(url.encode()).hexdigest()
    local_path = os.path.join(LORA_CACHE_DIR, f"{url_hash}.safetensors")

    if os.path.exists(local_path):
        print("\n********** LoRA already cached **********\n")
        return local_path

    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))

        with open(local_path, 'wb') as f:
            if total_size == 0:
                # No Content-Length header, so read the whole body at once
                f.write(response.content)
            else:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        print("\n********** LoRA downloaded successfully **********\n")
        return local_path
    except Exception as e:
        # Remove any partially written file so a failed download is not
        # mistaken for a valid cached LoRA on the next call
        if os.path.exists(local_path):
            os.remove(local_path)
        print(f"\nError downloading LoRA: {str(e)}\n")
        return None

def is_civitai_url(url):
    """Check whether the URL is a valid Civitai model-download URL."""
    return bool(re.match(r'https?://civitai\.com/api/download/models/\d+', url))

@spaces.GPU
def generate_image(prompt, negative_prompt, lora_url, num_inference_steps=30, guidance_scale=7.0,
                   model="Real6.0", num_images=1, width=512, height=512, seed=None):

    # Map the UI model name to its Hugging Face Hub repository
    if model == "Real5.0":
        model_id = "SG161222/Realistic_Vision_V5.0_noVAE"
    elif model == "Real5.1":
        model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
    elif model == "majicv7":
        model_id = "digiplay/majicMIX_realistic_v7"
    else:
        model_id = "SG161222/Realistic_Vision_V6.0_B1_noVAE"

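    # Note: every component below is reloaded on each request. That keeps each
    # @spaces.GPU call self-contained, at the cost of re-reading the weights
    # per generation.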
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda")

    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")

    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")

    # Hand all preloaded components to the pipeline, including the UNet,
    # so it does not load duplicate copies of them
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        unet=unet
    ).to("cuda")

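    # LoRA selection: a Civitai download URL is fetched and cached, a bare
    # "username/repo" string is treated as a Hugging Face Hub path, and
    # anything else falls back to DEFAULT_LORA.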
    try:
        if lora_url and lora_url.strip():
            if is_civitai_url(lora_url):
                lora_path = download_lora(lora_url)
                if lora_path:
                    pipe.load_lora_weights(lora_path)
                    print("\n********** Civitai LoRA loaded **********\n")
                else:
                    pipe.load_lora_weights(DEFAULT_LORA)
                    print("\n********** Default LoRA loaded **********\n")
            elif '/' in lora_url and not lora_url.startswith('http'):
                pipe.load_lora_weights(lora_url)
                print("\n********** Hub LoRA loaded **********\n")
            else:
                pipe.load_lora_weights(DEFAULT_LORA)
                print("\n********** Default LoRA loaded **********\n")
        else:
            pipe.load_lora_weights(DEFAULT_LORA)
    except Exception as e:
        print(f"\nError loading LoRA weights: {str(e)}\n")
        pipe.load_lora_weights(DEFAULT_LORA)

    if model == "Real6.0":
        # Stub out the safety checker: return the images unchanged and report
        # none of them as NSFW
        pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

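    # DPM-Solver++ with Karras sigmas, a scheduler pairing commonly used with
    # SD 1.5-family checkpoints for good quality at around 30 steps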
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True
    )

    # Draw a fresh seed when none is supplied so it can be reported back to the UI
    if seed is None:
        seed = random.randint(0, 2**32 - 1)

    generator = torch.Generator(device="cuda").manual_seed(seed)

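    # Tokenize and encode the prompts manually (CLIP truncates at 77 tokens);
    # the embeddings are passed to the pipeline as prompt_embeds below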
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to("cuda")

    negative_text_inputs = tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to("cuda")

    # no_grad avoids building an autograd graph during inference
    with torch.no_grad():
        prompt_embeds = text_encoder(text_inputs.input_ids)[0]
        negative_prompt_embeds = text_encoder(negative_text_inputs.input_ids)[0]

    result = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        cross_attention_kwargs={"scale": 1},  # LoRA scale (1 = full strength)
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        num_images_per_prompt=num_images,
        generator=generator
    )
    # Release cached GPU memory now that generation is done
    torch.cuda.empty_cache()
    return result.images, seed

def clean_lora_cache():
    """Delete all cached LoRA files from the cache directory."""
    if os.path.exists(LORA_CACHE_DIR):
        for file in os.listdir(LORA_CACHE_DIR):
            file_path = os.path.join(LORA_CACHE_DIR, file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except Exception as e:
                print(f"Error deleting {file_path}: {str(e)}")

title = """<h1 align="center">ProFaker</h1>"""

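# Gradio UI: prompt controls and LoRA input on the left, generated images and
# the seed that was used on the right.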
with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                info="Enter your image description here...",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="Enter what you don't want in the image...",
                lines=3
            )
            lora_input = gr.Textbox(
                label="LoRA URL/Path",
                info="Enter a Civitai download URL or Hugging Face path (e.g., 'username/model-name')",
                value=DEFAULT_LORA
            )
            clear_cache = gr.Button("Clear LoRA Cache")
            generate_button = gr.Button("Generate Image")

            with gr.Accordion("Advanced Options", open=False):
                model = gr.Dropdown(
                    choices=["Real6.0", "Real5.1", "Real5.0", "majicv7"],
                    value="Real6.0",
                    label="Model",
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Images to Generate"
                )
                width = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Width"
                )
                height = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Height"
                )
                steps_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=30,
                    step=1,
                    label="Number of Steps"
                )
                guidance_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=7.0,
                    step=0.5,
                    label="Guidance Scale"
                )
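                # Leave the seed empty for a random one; the seed actually used
                # is echoed in the "Seed Used" box so results can be reproduced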
                seed_input = gr.Number(value=None, label="Seed (optional)", precision=0)

        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2
            )
            seed_display = gr.Textbox(label="Seed Used", interactive=False)

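    # Wire the buttons: generate_image returns (images, seed), matching the
    # [gallery, seed_display] outputs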
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, lora_input, steps_slider, guidance_slider,
                model, num_images, width, height, seed_input],
        outputs=[gallery, seed_display]
    )

    clear_cache.click(fn=clean_lora_cache)

demo.queue(max_size=10).launch(share=False)