Rojban's picture
Update app.py
0c2f8e1
import gradio as gr
import os
from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
import torch
import random
import uuid
#token = os.getenv("token")
#model = gr.load("models/Rojban/dreambooth4", hf_token=token)
# Location of the LoRA weights handed to `load_lora_weights` below, and the
# identifier of the base checkpoint the pipeline is built from.
prj_path = "dreambooth4"
model = "stabilityai/stable-diffusion-xl-base-1.0"

# Base pipeline: load in half precision and move it to the GPU in one step
# (`.to` returns the pipeline itself), then apply the LoRA weights on top.
pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights(prj_path, weight_name="pytorch_lora_weights.safetensors")

# Refiner pipeline: an img2img pipeline that takes the base pipeline's output
# image as its input (see `generate_image`). Loaded after the LoRA step, in the
# same order as before.
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")
def generate_image(prompt, seed=None):
    """Generate an image for *prompt*, refine it, and save it as a PNG.

    The prompt is run through the module-level ``pipe`` (25 inference steps)
    and the resulting image is then passed through the module-level
    ``refiner`` with the same prompt and the same random generator.

    Args:
        prompt: Text prompt forwarded to both pipelines.
        seed: Seed for the CUDA random generator. Anything ``int()`` accepts
            (the UI number field may deliver a float); ``None`` falls back
            to 253.

    Returns:
        The relative path ``images/<seed>_<uuid4>.png`` of the saved image.
    """
    if seed is None:
        seed = 253
    seed = int(seed)

    # Make sure the output directory exists *before* spending GPU time:
    # `image.save` does not create missing parent directories and would
    # otherwise fail with FileNotFoundError after the image is generated.
    os.makedirs("images", exist_ok=True)

    generator = torch.Generator("cuda").manual_seed(seed)
    image = pipe(prompt=prompt, generator=generator, num_inference_steps=25).images[0]
    image = refiner(prompt=prompt, generator=generator, image=image).images[0]

    # The uuid suffix keeps repeated runs with the same seed from
    # overwriting each other's files.
    name = f"{seed}_{uuid.uuid4()}.png"
    save_path = f"images/{name}"
    image.save(save_path)
    return save_path
# Build the Gradio UI: a prompt text box and a numeric seed field feed
# `generate_image`, whose returned file path is displayed as an image.
prompt_input = gr.Textbox(label="Prompt")
seed_input = gr.Number(label="Seed")

interface = gr.Interface(
    fn=generate_image,
    inputs=[prompt_input, seed_input],
    outputs=gr.Image(type="filepath"),
    title="Custom Stable Diffusion Model",
    description="Generate images using a custom Stable Diffusion model.",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()