Spaces:
Running
Running
import os | |
import numpy as np | |
import random | |
from pathlib import Path | |
from PIL import Image | |
from insightface.app import FaceAnalysis | |
import streamlit as st | |
from huggingface_hub import InferenceClient, AsyncInferenceClient | |
from gradio_client import Client, handle_file | |
import asyncio | |
import insightface | |
from concurrent.futures import ThreadPoolExecutor | |
import yaml | |
# Load the login credentials from config.yaml. A missing or malformed file must
# not stop the app from starting, so the error is reported and empty credentials
# are used instead.
# NOTE(review): this st.error call runs at import time, i.e. before main() calls
# st.set_page_config; some Streamlit versions reject that ordering — confirm.
try:
    with open("config.yaml", "r", encoding="utf-8") as file:
        credentials = yaml.safe_load(file)
except (OSError, ValueError, yaml.YAMLError) as e:
    st.error(f"Error al cargar el archivo de configuración: {e}")
    credentials = None
# An empty file loads as None and a non-mapping document (e.g. a list) cannot
# be indexed by key; normalise to a dict that always has both keys.
if not isinstance(credentials, dict):
    credentials = {}
credentials.setdefault("username", "")
credentials.setdefault("password", "")

# Upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Token passed to the Finegrain upscaler Space.
HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")
# Async client used for text-to-image generation.
client = AsyncInferenceClient()
# Synchronous client used to expand prompts.
llm_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Directory where generated images and prompt files are stored.
DATA_PATH = Path("./data")
DATA_PATH.mkdir(exist_ok=True)
def prepare_face_app():
    """Create the InsightFace face analyser and load the face-swap model.

    Returns:
        An ``(analyser, swapper)`` pair: a ``FaceAnalysis`` built from the
        ``buffalo_l`` model pack and prepared with a 640x640 detection size,
        and the model returned by ``insightface.model_zoo.get_model`` for
        ``onix.onnx``.
    """
    detection_size = (640, 640)
    analyser = FaceAnalysis(name='buffalo_l')
    # NOTE(review): ctx_id=0 presumably selects the first accelerator device;
    # confirm the behaviour on CPU-only hardware.
    analyser.prepare(ctx_id=0, det_size=detection_size)
    # NOTE(review): confirm that 'onix.onnx' is available where the model zoo looks.
    face_swapper = insightface.model_zoo.get_model('onix.onnx')
    return analyser, face_swapper


# Built once at import time; swap_faces() relies on these module-level objects.
app, swapper = prepare_face_app()
def run_async(func):
    """Run the blocking callable *func* in a worker thread and wait for its result.

    A private event loop and a single-thread executor are created for each
    call, and both are shut down before returning. Repeated calls therefore do
    not leak loops or threads. The loop is not installed as the calling
    thread's current event loop.

    Args:
        func: Zero-argument callable to execute.

    Returns:
        Whatever ``func()`` returns. Exceptions raised by ``func`` propagate.
    """
    loop = asyncio.new_event_loop()
    try:
        # The with-block joins the worker thread once the result is available.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return loop.run_until_complete(loop.run_in_executor(executor, func))
    finally:
        loop.close()
async def generate_image(combined_prompt, model, width, height, scales, steps, seed):
    """Generate one image with the module-level async inference client.

    Args:
        combined_prompt: Text prompt for the image.
        model: Model id passed to ``text_to_image``.
        width: Output width in pixels.
        height: Output height in pixels.
        scales: Guidance scale.
        steps: Number of inference steps.
        seed: Generation seed; ``-1`` draws a random seed in ``[0, MAX_SEED]``.

    Returns:
        ``(image, seed)`` on success, where *seed* is the int sent to the
        model. On any failure it returns ``(error_message, None)``, and the
        message starts with ``"Error"``.
    """
    try:
        if seed == -1:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)
        image = await client.text_to_image(
            prompt=combined_prompt,
            height=height,
            width=width,
            guidance_scale=scales,
            num_inference_steps=steps,
            model=model,
            # Forward the seed. Otherwise the seed returned to the caller (and
            # used to name saved files) would not reproduce this image.
            seed=seed,
        )
        return image, seed
    except Exception as e:
        return f"Error al generar imagen: {e}", None
def get_upscale_finegrain(prompt, img_path, upscale_factor):
    """Upscale an image through the ``finegrain/finegrain-image-enhancer`` Space.

    Args:
        prompt: Prompt forwarded to the enhancer.
        img_path: Path of the image file to upscale.
        upscale_factor: Scale factor forwarded to the enhancer.

    Returns:
        The second element of the Space's result (a local file reference to
        the enhanced image), or ``None`` if the call failed or the result has
        an unexpected shape.
    """
    try:
        # Use a distinct local name so the module-level `client` is not shadowed.
        upscaler = Client("finegrain/finegrain-image-enhancer", hf_token=HF_TOKEN_UPSCALER)
        result = upscaler.predict(
            input_image=handle_file(img_path), prompt=prompt, upscale_factor=upscale_factor
        )
    except Exception as e:
        # Upscaling is optional: surface the problem instead of hiding it, but
        # return None so the caller can keep the original image.
        st.warning(f"Error al escalar la imagen: {e}")
        return None
    # gradio_client may hand back multi-value results as a tuple or a list.
    # NOTE(review): index 1 is assumed to be the enhanced image — confirm
    # against the Space's API description.
    if isinstance(result, (list, tuple)) and len(result) > 1:
        return result[1]
    return None
def save_prompt(prompt_text, seed):
    """Store *prompt_text* as ``DATA_PATH/prompt_<seed>.txt``.

    Args:
        prompt_text: Prompt to persist.
        seed: Seed used to name the file.

    Returns:
        The ``Path`` of the written file, or ``None`` if writing failed (the
        error is shown with ``st.error``).
    """
    prompt_file_path = DATA_PATH / f"prompt_{seed}.txt"
    try:
        # Explicit UTF-8: prompts are often Spanish (accents, ñ), and the
        # platform default encoding is not guaranteed to represent them.
        prompt_file_path.write_text(prompt_text, encoding="utf-8")
    except (OSError, ValueError) as e:
        st.error(f"Error al guardar el prompt: {e}")
        return None
    return prompt_file_path
def _save_upscaled(prompt, image_path, upscale_factor, seed):
    """Upscale *image_path* remotely and store the result as a JPEG.

    Args:
        prompt: Prompt forwarded to the upscaler.
        image_path: Path of the base image to upscale.
        upscale_factor: Scale factor forwarded to the upscaler.
        seed: Seed used to name ``upscale_image_<seed>.jpg``.

    Returns:
        The ``Path`` of the saved upscaled image, or ``None`` if upscaling or
        saving failed.
    """
    upscaled_source = get_upscale_finegrain(prompt, image_path, upscale_factor)
    if not upscaled_source:
        return None
    target = DATA_PATH / f"upscale_image_{seed}.jpg"
    try:
        with Image.open(upscaled_source) as upscaled:
            # JPEG cannot store alpha or palette data; normalise to RGB first.
            upscaled.convert("RGB").save(target, format="JPEG")
    except (OSError, ValueError) as e:
        st.error(f"Error al guardar la imagen escalada: {e}")
        return None
    return target


async def gen(prompt, basemodel, width, height, scales, steps, seed,
              upscale_factor, process_upscale, process_enhancer, language):
    """Generate an image, optionally enhancing the prompt and upscaling the result.

    The image and its prompt are saved under ``DATA_PATH``. Progress is shown
    with a Streamlit progress bar.

    Args:
        prompt: User prompt.
        basemodel: Model id used for generation.
        width: Output width in pixels.
        height: Output height in pixels.
        scales: Guidance scale.
        steps: Number of inference steps.
        seed: Generation seed; ``-1`` draws a random one.
        upscale_factor: Factor passed to the upscaler.
        process_upscale: Whether to upscale the generated image.
        process_enhancer: Whether to expand the prompt with the LLM first.
        language: ``"en"`` or ``"es"``, the language for the prompt enhancer.

    Returns:
        A two-element list ``[image_path, prompt_file_path]`` of strings.
        *prompt_file_path* is ``None`` if the prompt could not be saved. On a
        generation or save failure the error is shown in the UI and
        ``[None, None]`` is returned.
    """
    combined_prompt = prompt
    if process_enhancer:
        improved_prompt = await improve_prompt(prompt, language) or ""
        # improve_prompt reports failures as text rather than raising; never
        # splice that error message into the generation prompt.
        if improved_prompt.startswith("Error mejorando el prompt"):
            st.warning(improved_prompt)
        elif improved_prompt:
            combined_prompt = f"{prompt} {improved_prompt}"

    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    progress_bar = st.progress(0)
    image, seed = await generate_image(combined_prompt, basemodel, width, height, scales, steps, seed)
    progress_bar.progress(50)
    # generate_image returns an error string (and seed None) instead of raising.
    if image is None or isinstance(image, str):
        progress_bar.empty()
        st.error(image or "Error al generar imagen.")
        return [None, None]

    image_path = save_image(image, seed)
    if image_path is None:
        # save_image has already displayed the error.
        progress_bar.empty()
        return [None, None]
    prompt_file_path = save_prompt(combined_prompt, seed)
    prompt_file = str(prompt_file_path) if prompt_file_path else None

    if process_upscale:
        upscaled_path = _save_upscaled(combined_prompt, image_path, upscale_factor, seed)
        if upscaled_path is None:
            # Fall back to the non-upscaled image.
            progress_bar.empty()
            return [str(image_path), prompt_file]
        progress_bar.progress(100)
        # Only the upscaled version is kept on disk.
        image_path.unlink()
        return [str(upscaled_path), prompt_file]

    progress_bar.progress(100)
    return [str(image_path), prompt_file]
async def improve_prompt(prompt, language):
    """Expand *prompt* into a richer text-to-image prompt using the LLM client.

    Args:
        prompt: Short user idea.
        language: ``"en"`` selects the English instruction; any other value
            selects the Spanish one.

    Returns:
        The generated text, stripped and truncated to at most 500 characters.
        On any failure it returns a string starting with
        ``"Error mejorando el prompt: "``; it does not raise.
    """
    try:
        instruction_en = "With this idea, describe in English a detailed txt2img prompt in 500 characters at most, add illumination, atmosphere, cinematic elements, and characters if need it..."
        instruction_es = "Con esta idea, describe en español un prompt detallado de txt2img en un máximo de 500 caracteres, con iluminación, atmósfera, elementos cinematográficos y en su caso personajes..."
        instruction = instruction_en if language == "en" else instruction_es
        formatted_prompt = f"{prompt}: {instruction}"
        response = llm_client.text_generation(formatted_prompt, max_new_tokens=500)
        # Check the response type explicitly: on a plain string,
        # `'generated_text' in response` would be a substring test.
        if isinstance(response, str):
            improved_text = response
        elif isinstance(response, dict):
            improved_text = response.get("generated_text") or ""
        else:
            improved_text = getattr(response, "generated_text", None) or ""
        return improved_text.strip()[:500]
    except Exception as e:
        return f"Error mejorando el prompt: {e}"
def save_image(image, seed):
    """Save *image* as ``DATA_PATH/image_<seed>.jpg``.

    Args:
        image: PIL image to store.
        seed: Seed used to name the file.

    Returns:
        The ``Path`` of the saved JPEG, or ``None`` if saving failed (the
        error is shown with ``st.error``).
    """
    image_path = DATA_PATH / f"image_{seed}.jpg"
    try:
        # JPEG has no alpha channel or palette; Pillow refuses to write e.g.
        # RGBA or P images as JPEG, so convert anything other than RGB/L.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        image.save(image_path, format="JPEG")
    except Exception as e:
        st.error(f"Error al guardar la imagen: {e}")
        return None
    return image_path
def get_storage():
    """List the stored ``.jpg`` files, newest first, and their combined size.

    Returns:
        A tuple ``(paths, usage)``. *paths* holds the absolute path strings of
        the ``.jpg`` files directly in ``DATA_PATH``, sorted by modification
        time (newest first). *usage* is a label such as
        ``"Uso total: 0.001GB"`` giving their total size in GiB.
    """
    images = []
    for candidate in DATA_PATH.glob("*.jpg"):
        if candidate.is_file():
            images.append(candidate)
    images.sort(key=lambda img: img.stat().st_mtime, reverse=True)

    total_bytes = 0
    for img in images:
        total_bytes += img.stat().st_size
    gib = total_bytes / (1024.0 ** 3)

    return [str(img.resolve()) for img in images], f"Uso total: {gib:.3f}GB"
def get_prompts():
    """Index the ``.txt`` files in ``DATA_PATH`` by their stem minus ``"prompt_"``.

    For files written as ``prompt_<seed>.txt`` the key is the seed as a string.

    Returns:
        A dict mapping that key to the file's ``Path``.
    """
    index = {}
    for txt in DATA_PATH.glob("*.txt"):
        if not txt.is_file():
            continue
        index[txt.stem.replace("prompt_", "")] = txt
    return index
def delete_image(image_path):
    """Delete a stored file, refusing any path outside ``DATA_PATH``.

    *image_path* comes straight from a text box, so it is untrusted. It is
    resolved (following ``..`` and symlinks) and must land inside
    ``DATA_PATH``; otherwise nothing is deleted. All outcomes are reported
    through Streamlit messages and nothing is raised.

    Args:
        image_path: Path typed by the user, e.g. one listed by the storage view.
    """
    try:
        target = Path(image_path).resolve()
        base = DATA_PATH.resolve()
        try:
            target.relative_to(base)
        except ValueError:
            st.error("Solo se pueden borrar archivos dentro de la carpeta de datos.")
            return
        if target.is_file():
            target.unlink()
            st.success(f"Imagen {image_path} borrada.")
        else:
            st.error("El archivo de imagen no existe.")
    except Exception as e:
        st.error(f"Error al borrar la imagen: {e}")
def authenticate_user(username, password, credentials):
    """Check a username/password pair against the configured credentials.

    Args:
        username: Username entered by the user.
        password: Password entered by the user.
        credentials: Mapping with ``"username"`` and ``"password"`` entries.

    Returns:
        ``True`` only if both values match. Returns ``False`` if
        *credentials* is missing either entry or either configured value is
        blank, so an unconfigured app can never be entered with an empty login.
    """
    import hmac

    try:
        expected_user = credentials["username"]
        expected_pass = credentials["password"]
    except (KeyError, TypeError):
        return False
    if expected_user in (None, "") or expected_pass in (None, ""):
        return False
    # str(): YAML may load a numeric password as int. Encode to bytes because
    # compare_digest only accepts ASCII for str inputs. Constant-time
    # comparison avoids leaking how much of the secret matched.
    user_ok = hmac.compare_digest(str(username).encode("utf-8"), str(expected_user).encode("utf-8"))
    pass_ok = hmac.compare_digest(str(password).encode("utf-8"), str(expected_pass).encode("utf-8"))
    return user_ok and pass_ok
def login_form(credentials):
    """Render the login form and mark the session authenticated on success.

    Nothing is rendered once ``st.session_state['authenticated']`` is set. A
    failed attempt shows an error message.

    Args:
        credentials: Mapping passed through to ``authenticate_user``.
    """
    if st.session_state.get('authenticated'):
        return
    st.title("Iniciar Sesión")
    username = st.text_input("Usuario")
    password = st.text_input("Contraseña", type="password")
    if st.button("Iniciar Sesión"):
        if authenticate_user(username, password, credentials):
            st.session_state['authenticated'] = True
        else:
            st.error("Usuario o contraseña incorrectos.")
def sort_faces(faces):
    """Return a new list of *faces* ordered by ``face.bbox[0]``, ascending.

    ``bbox[0]`` is presumably the left x coordinate, which gives a
    left-to-right order (TODO confirm). The sort is stable, so faces with
    equal coordinates keep their input order.
    """
    def _left_x(face):
        return face.bbox[0]

    ordered = list(faces)
    ordered.sort(key=_left_x)
    return ordered
def get_face(faces, face_id):
    """Return the *face_id*-th face (1-based) from *faces*.

    Raises:
        ValueError: If *faces* is empty/None, or *face_id* is outside
            ``1..len(faces)``.
    """
    count = len(faces) if faces else 0
    if count == 0:
        raise ValueError("No se encontraron rostros.")
    if not 1 <= face_id <= count:
        raise ValueError(f"Solo hay {count} rostros, pediste el {face_id}.")
    return faces[face_id - 1]
def swap_faces(source_image, source_face_index, destination_image):
    """Paste a face from *source_image* onto the first face of *destination_image*.

    Faces in both images are ordered by ``sort_faces``. The source face is
    picked with the 1-based *source_face_index*; the destination always uses
    its first face.

    Raises:
        ValueError: From ``get_face`` when a requested face does not exist.

    Returns:
        The result of ``swapper.get(..., paste_back=True)``.
    """
    donor = get_face(sort_faces(app.get(source_image)), source_face_index)
    recipient = get_face(sort_faces(app.get(destination_image)), 1)
    return swapper.get(destination_image, recipient, donor, paste_back=True)
def _show_generated_image(image_path):
    """Display a generated image with a download button, or report an error.

    Args:
        image_path: Path string of the generated JPEG. A falsy value shows
            nothing; any other value that is not an existing file (e.g. an
            error message) is shown with ``st.error``.
    """
    if not image_path:
        return
    path = Path(image_path)
    try:
        is_file = path.is_file()
    except (OSError, ValueError):
        # Error messages are not valid paths (too long, NUL bytes, ...).
        is_file = False
    if not is_file:
        st.error(str(image_path))
        return
    st.image(str(path), caption="Imagen Generada", use_column_width=True)
    # Pass the file's bytes: a plain string is used as the download's
    # *content*, so the user would otherwise download the path text.
    st.download_button(
        "Descargar Imagen",
        data=path.read_bytes(),
        file_name=path.name,
        mime="image/jpeg",
    )


def main():
    """Render the Streamlit app.

    The app consists of a login gate, generation settings in the sidebar, and
    tools to list storage, list prompts and delete images.
    """
    st.set_page_config(layout="wide")
    login_form(credentials)
    if not st.session_state.get('authenticated'):
        st.warning("Por favor, inicia sesión para acceder a la aplicación.")
        return

    prompt = st.sidebar.text_input("Descripción de la imagen", max_chars=900)
    process_enhancer = st.sidebar.checkbox("Mejorar Prompt", value=False)
    language = st.sidebar.selectbox("Idioma", ["en", "es"])
    basemodel = st.sidebar.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-DEV", "black-forest-labs/FLUX.1-schnell"])
    format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
    process_upscale = st.sidebar.checkbox("Procesar Escalador", value=False)
    upscale_factor = st.sidebar.selectbox("Factor de Escala", [2, 4, 8], index=0)
    scales = st.sidebar.slider("Escalado", 1, 20, 10)
    steps = st.sidebar.slider("Pasos", 1, 100, 20)
    seed = st.sidebar.number_input("Semilla", value=-1)
    width, height = (1080, 1920) if format_option == "9:16" else (1920, 1080)

    if st.sidebar.button("Generar Imagen"):
        if not prompt.strip():
            st.warning("Por favor, escribe una descripción de la imagen.")
        else:
            with st.spinner("Generando..."):
                # gen() is a coroutine; asyncio.run drives it to completion here.
                result = asyncio.run(gen(prompt, basemodel, width, height, scales, steps, seed,
                                         upscale_factor, process_upscale, process_enhancer, language))
            # Only the image path (first element) is used. Indexing instead of
            # a fixed-size unpack tolerates result lists of any length.
            _show_generated_image(result[0] if result else None)

    if st.sidebar.button("Ver Almacenamiento"):
        files, usage = get_storage()
        st.write(usage)
        for file in files:
            st.write(file)

    if st.sidebar.button("Ver Prompts"):
        prompts = get_prompts()
        for key, path in prompts.items():
            st.write(f"{key}: {path}")

    # The path box lives outside the button branch. Streamlit reruns the whole
    # script on every interaction, so a widget created only while a button
    # reads True disappears before the user can type into it.
    image_to_delete = st.sidebar.text_input("Ruta de la imagen a borrar")
    if st.sidebar.button("Borrar Imagen"):
        if image_to_delete:
            delete_image(image_to_delete)
        else:
            st.sidebar.warning("Indica la ruta de la imagen a borrar.")


if __name__ == "__main__":
    main()