EV5V's picture
Update app.py
88bc7dd verified
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
# Função para carregar o modelo base com LoRa
def load_model_with_lora(base_model_path, lora_model_path):
# Carrega o modelo base
pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float16).to("cuda")
# Carrega a LoRa e aplica ao modelo base
pipeline.load_lora_weights(lora_model_path)
return pipeline
# Função de inferência de imagem com parâmetros ajustáveis
def infer_image(prompt, steps, cfg_scale, seed, width, height):
pipeline = load_model_with_lora("black-forest-labs/FLUX.1-dev", "rorito/testSCG-Anatomy-Flux1")
# Configurações adicionais
generator = torch.manual_seed(seed)
result = pipeline(prompt, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator)
return result.images[0]
# Interface do Gradio com parâmetros ajustáveis
interface = gr.Interface(
fn=infer_image,
inputs=[
gr.Textbox(label="Prompt"),
gr.Slider(minimum=10, maximum=150, label="Número de Passos", value=50),
gr.Slider(minimum=1.0, maximum=20.0, label="Escala CFG", value=7.5),
gr.Number(label="Seed (Semente)", value=42),
gr.Slider(minimum=256, maximum=1024, label="Largura da Imagem", value=512),
gr.Slider(minimum=256, maximum=1024, label="Altura da Imagem", value=512)
],
outputs="image"
)
# Lançar o aplicativo Gradio
interface.launch()