flux3 / app.py
salomonsky's picture
Update app.py
99a5876 verified
raw
history blame
7.04 kB
import os
import numpy as np
import random
from pathlib import Path
from PIL import Image
import streamlit as st
from huggingface_hub import InferenceClient, AsyncInferenceClient
from gradio_client import Client, handle_file
import asyncio
from concurrent.futures import ThreadPoolExecutor
# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Tokens read from the environment; each is None when the variable is unset.
# NOTE(review): HF_TOKEN is not passed explicitly to the clients below —
# confirm they pick up credentials from the environment on their own.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_TOKEN_UPSCALER = os.getenv("HF_TOKEN_UPSCALER")

# Shared inference clients: an async one for text-to-image and a sync one
# bound to the Mixtral instruct model for prompt enhancement.
client = AsyncInferenceClient()
llm_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Output folder for generated images and prompt files; created if missing.
DATA_PATH = Path("./data")
DATA_PATH.mkdir(exist_ok=True)
def run_async(func):
    """Run the blocking callable ``func`` on a worker thread and return its result.

    A fresh event loop is created and installed as the current loop, ``func``
    is scheduled with ``loop.run_in_executor`` on a single-worker thread pool,
    and this call blocks until it completes. Exceptions raised by ``func``
    propagate to the caller.

    The thread pool is shut down and the event loop is closed before
    returning, so repeated calls do not accumulate threads or loops.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        with ThreadPoolExecutor(max_workers=1) as executor:
            return loop.run_until_complete(loop.run_in_executor(executor, func))
    finally:
        # Same teardown order as asyncio.run(): uninstall the loop first so a
        # closed loop is never left as the current one, then close it.
        asyncio.set_event_loop(None)
        loop.close()
def enable_lora(lora_add, basemodel):
    """Return the LoRA model id when one is given (truthy), otherwise the base model id."""
    return lora_add or basemodel
async def generate_image(combined_prompt, model, width, height, scales, steps, seed):
    """Generate one image through the shared async inference ``client``.

    Args:
        combined_prompt: text prompt sent to the model.
        model: model (or LoRA) repo id passed to ``text_to_image``.
        width: output width in pixels.
        height: output height in pixels.
        scales: guidance scale.
        steps: number of inference steps.
        seed: generation seed; ``-1`` draws a random one in ``[0, MAX_SEED]``.

    Returns:
        ``(image, seed)`` on success, where ``seed`` is the int that was sent
        to the model. On any failure returns ``(message, None)`` where the
        message starts with ``"Error al generar imagen: "``; nothing is raised.
    """
    try:
        if seed == -1:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)
        image = await client.text_to_image(
            prompt=combined_prompt,
            height=height,
            width=width,
            guidance_scale=scales,
            num_inference_steps=steps,
            model=model,
            # Forward the seed so the returned value (used in output file
            # names) actually corresponds to the generated image.
            seed=seed,
        )
        return image, seed
    except Exception as e:
        return f"Error al generar imagen: {e}", None
def get_upscale_finegrain(prompt, img_path, upscale_factor):
    """Upscale ``img_path`` through the finegrain image-enhancer Gradio Space.

    Args:
        prompt: text prompt forwarded to the enhancer.
        img_path: path (``str`` or ``Path``) of the image to upscale.
        upscale_factor: scale factor forwarded to the enhancer.

    Returns:
        ``result[1]`` when the prediction result is a list with at least two
        items (presumably the enhanced image's local path — confirm against
        the Space's API), otherwise ``None``. This is best-effort: a failed
        call is reported as a Streamlit warning and ``None`` is returned.
    """
    try:
        upscaler = Client("finegrain/finegrain-image-enhancer", hf_token=HF_TOKEN_UPSCALER)
        result = upscaler.predict(
            input_image=handle_file(str(img_path)),
            prompt=prompt,
            upscale_factor=upscale_factor,
        )
    except Exception as e:
        # Previously swallowed silently; keep the fallback but tell the user why.
        st.warning(f"Error al escalar la imagen: {e}")
        return None
    if isinstance(result, list) and len(result) > 1:
        return result[1]
    return None
def save_prompt(prompt_text, seed):
    """Write ``prompt_text`` to ``DATA_PATH / f"prompt_{seed}.txt"`` as UTF-8.

    Args:
        prompt_text: the prompt to persist.
        seed: seed used to name the file.

    Returns:
        The ``Path`` written, or ``None`` if writing failed (an error is shown
        in the Streamlit UI instead of raising).
    """
    prompt_file_path = DATA_PATH / f"prompt_{seed}.txt"
    try:
        # Explicit UTF-8: prompts can contain non-ASCII text (e.g. Spanish or
        # LLM output) that the platform default encoding may not support.
        prompt_file_path.write_text(prompt_text, encoding="utf-8")
    except Exception as e:
        st.error(f"Error al guardar el prompt: {e}")
        return None
    return prompt_file_path
def save_image(image, seed):
    """Save ``image`` as ``DATA_PATH / f"generated_image_{seed}.jpg"`` and return that path.

    The ``.jpg`` suffix makes Pillow write JPEG, which cannot store alpha or
    palette modes; such images are converted to RGB first. RGB and grayscale
    ("L") images are saved unchanged.
    """
    image_path = DATA_PATH / f"generated_image_{seed}.jpg"
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    image.save(image_path)
    return image_path
def _generated_text(response):
    """Return the generated text from a ``text_generation`` response.

    Accepts a plain string, a mapping with a ``"generated_text"`` key, or an
    object exposing a ``generated_text`` attribute.
    """
    if isinstance(response, str):
        return response
    if isinstance(response, dict):
        return response["generated_text"]
    return response.generated_text


async def improve_prompt(prompt, language):
    """Ask the LLM to expand ``prompt`` into a detailed text-to-image prompt.

    Args:
        prompt: the user's short idea.
        language: ``"English"`` selects the English instruction; any other
            value selects the Spanish one.

    Returns:
        The generated text, stripped and cut to at most 300 characters. On any
        failure returns a string starting with ``"Error mejorando el prompt: "``
        instead of raising.
    """
    try:
        instruction = (
            "With this idea, describe in English a detailed txt2img prompt in 500 characters at most, add illumination, atmosphere, cinematic elements, and characters..."
            if language == "English"
            else "Con esta idea, describe en español un prompt detallado de txt2img en un máximo de 500 caracteres, añadiendo iluminación, atmósfera, elementos cinematográficos y personajes..."
        )
        formatted_prompt = f"{prompt}: {instruction}"
        # Synchronous call: blocks the event loop while the LLM responds.
        response = llm_client.text_generation(formatted_prompt, max_new_tokens=300)
        improved_text = _generated_text(response).strip()
        # NOTE(review): the instruction asks for up to 500 characters but the
        # result is truncated to 300 — confirm which limit is intended.
        return improved_text[:300]
    except Exception as e:
        return f"Error mejorando el prompt: {e}"
async def gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora, process_enhancer, prompt_language):
    """Run the full pipeline: enhance prompt, generate, save, optionally upscale.

    Args:
        prompt: the user's prompt.
        basemodel: base model repo id.
        width: output width in pixels.
        height: output height in pixels.
        scales: guidance scale.
        steps: number of inference steps.
        seed: generation seed, ``-1`` for random.
        upscale_factor: factor passed to the upscaler.
        process_upscale: whether to run the upscaler.
        lora_model: LoRA repo id used when ``process_lora`` is set.
        process_lora: whether to use ``lora_model`` instead of ``basemodel``.
        process_enhancer: whether to expand the prompt with the LLM.
        prompt_language: language passed to ``improve_prompt``.

    Returns:
        On success ``[image_path, prompt_file]``: ``image_path`` is the
        upscaled JPEG when upscaling succeeded, otherwise the generated image.
        ``prompt_file`` is the saved prompt path, or ``None`` if saving
        failed. When generation fails, returns ``[error_message, None,
        combined_prompt]``.
    """
    model = enable_lora(lora_model, basemodel) if process_lora else basemodel
    combined_prompt = prompt  # use the original prompt by default
    if process_enhancer:
        improved_prompt = await improve_prompt(prompt, prompt_language)
        # improve_prompt reports failures as a returned string, not an
        # exception; never send that error text to the image model.
        if improved_prompt.startswith("Error mejorando el prompt"):
            st.warning(improved_prompt)
        else:
            combined_prompt = f"{prompt} {improved_prompt}"
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    progress_bar = st.progress(0)
    image, seed = await generate_image(combined_prompt, model, width, height, scales, steps, seed)
    progress_bar.progress(50)
    if isinstance(image, str) and image.startswith("Error"):
        progress_bar.empty()
        return [image, None, combined_prompt]
    image_path = save_image(image, seed)
    prompt_file_path = save_prompt(combined_prompt, seed)
    # save_prompt returns None on failure; don't turn that into the string "None".
    prompt_file = str(prompt_file_path) if prompt_file_path else None
    if process_upscale:
        upscale_image_path = get_upscale_finegrain(combined_prompt, image_path, upscale_factor)
        if upscale_image_path:
            upscaled_path = DATA_PATH / f"upscale_image_{seed}.jpg"
            # Close the source file handle, and make sure the mode is JPEG-safe.
            with Image.open(upscale_image_path) as upscaled:
                if upscaled.mode not in ("RGB", "L"):
                    upscaled = upscaled.convert("RGB")
                upscaled.save(upscaled_path, format="JPEG")
            progress_bar.progress(100)
            # The upscaled copy replaces the original generation on disk.
            image_path.unlink()
            return [str(upscaled_path), prompt_file]
        # Upscaling failed: fall back to the un-upscaled image.
        progress_bar.empty()
        return [str(image_path), prompt_file]
    progress_bar.progress(100)
    return [str(image_path), prompt_file]
def main():
    """Build the Streamlit page and run a generation when the button is pressed.

    Sidebar widgets collect the prompt and generation options. Pressing
    "Generar Imagen" runs ``gen`` and renders the resulting image and prompt.
    """
    st.set_page_config(layout="wide")
    st.title("FLUX with enhancer and upscaler with LORA model training")
    prompt = st.sidebar.text_input("Descripción de la imagen", max_chars=200)
    process_enhancer = st.sidebar.checkbox("Mejorar Prompt", value=True)
    prompt_language = st.sidebar.selectbox("Idioma para mejorar el prompt", ["English", "Spanish"])
    # NOTE(review): the Hub repo id is usually written "FLUX.1-dev"; confirm
    # the "FLUX.1-DEV" casing resolves with the inference API.
    basemodel = st.sidebar.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-DEV"])
    lora_model = st.sidebar.selectbox("LORA Realismo", ["Shakker-Labs/FLUX.1-dev-LoRA-add-details", "XLabs-AI/flux-RealismLora"])
    format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
    process_lora = st.sidebar.checkbox("Procesar LORA", value=True)
    process_upscale = st.sidebar.checkbox("Procesar Escalador", value=True)
    upscale_factor = st.sidebar.selectbox("Factor de Escala", [2, 4, 8], index=0)
    scales = st.sidebar.slider("Escalado", 1, 20, 10)
    steps = st.sidebar.slider("Pasos", 1, 100, 20)
    seed = st.sidebar.number_input("Semilla", value=-1)
    # (width, height) in pixels: portrait for 9:16, landscape otherwise.
    width, height = (720, 1280) if format_option == "9:16" else (1280, 720)
    if not st.sidebar.button("Generar Imagen"):
        return
    if not prompt.strip():
        st.warning("Escribe una descripción de la imagen.")
        return
    with st.spinner("Mejorando y generando imagen..."):
        result = asyncio.run(gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora, process_enhancer, prompt_language))
    _show_result(result)


def _show_result(result):
    """Render the list returned by ``gen``: the image (or its error) and the prompt used."""
    image_paths = result[0]
    prompt_file = result[1]
    st.write(f"Image paths: {image_paths}")
    if not image_paths:
        st.error("No se pudo generar la imagen.")
        return
    # On generation failure gen puts the error message where the path would be.
    if isinstance(image_paths, str) and image_paths.startswith("Error"):
        st.error(image_paths)
        return
    if Path(image_paths).exists():
        st.image(image_paths, caption="Imagen Generada")
    else:
        st.error("El archivo de imagen no existe.")
    if prompt_file and Path(prompt_file).exists():
        # Prompt files are written as UTF-8; tolerate legacy locale-encoded files.
        prompt_text = Path(prompt_file).read_text(encoding="utf-8", errors="replace")
        st.write(f"Prompt utilizado: {prompt_text}")
    else:
        st.write("El archivo del prompt no está disponible.")
# Script entry point: build the UI when this file is executed directly.
# NOTE(review): presumably launched with `streamlit run`, which executes the
# file as __main__ — confirm.
if __name__ == "__main__":
    main()