# gfgf / app.py
# Ffftdtd5dtft's picture
# Update app.py
# 1654627 verified
import io
import json
import multiprocessing
import os
import uuid
import zipfile

import gradio as gr
import numpy as np
import redis
import scipy
import torch
from datasets import load_dataset
from diffusers import (
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from diffusers.utils import export_to_video
from dotenv import load_dotenv
from PIL import Image
from scipy.io import wavfile
from scipy.signal import resample_poly
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    MarianMTModel,
    MarianTokenizer,
    MusicgenForConditionalGeneration,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)
# Pull variables from a local .env file (if any) into the process environment
# before the settings below are read.
load_dotenv()

# Shared Redis connection used by every store/load helper in this module.
# os.getenv yields strings (or None when the variable is unset), so fall back
# to the conventional local host/port and convert the port to an int here
# instead of failing later when the first command opens a connection.
redis_client = redis.Redis(
    host=os.getenv("REDIS_HOST") or "localhost",
    port=int(os.getenv("REDIS_PORT") or 6379),
    password=os.getenv("REDIS_PASSWORD"),
)

# Hugging Face access token from the environment; None when HF_TOKEN is unset.
huggingface_token = os.getenv("HF_TOKEN")
def generate_unique_id():
    """Return a new random (version 4) UUID rendered as a string."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
def store_special_tokens(tokenizer, model_name):
    """Save the tokenizer's special tokens and their ids in a Redis hash.

    The hash is written under ``tokenizer_special_tokens:<model_name>`` with
    one field per attribute (``pad_token``, ``pad_token_id``, ``eos_token``,
    ``eos_token_id``, ``unk_token``, ``unk_token_id``, ``bos_token``,
    ``bos_token_id``).

    Attributes the tokenizer leaves as ``None`` are skipped: Redis can only
    encode bytes, strings and numbers, so sending ``None`` would make the
    whole write fail.

    Args:
        tokenizer: Object exposing the eight attributes listed above.
        model_name: Name used to build the Redis key.
    """
    candidates = {
        'pad_token': tokenizer.pad_token,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token': tokenizer.eos_token,
        'eos_token_id': tokenizer.eos_token_id,
        'unk_token': tokenizer.unk_token,
        'unk_token_id': tokenizer.unk_token_id,
        'bos_token': tokenizer.bos_token,
        'bos_token_id': tokenizer.bos_token_id,
    }
    special_tokens = {
        field: value for field, value in candidates.items() if value is not None
    }
    if not special_tokens:
        # hset rejects an empty mapping, and there is nothing to record anyway.
        return
    # hset(mapping=...) is the supported replacement for the deprecated hmset.
    redis_client.hset(f"tokenizer_special_tokens:{model_name}", mapping=special_tokens)
def load_special_tokens(tokenizer, model_name):
    """Restore special tokens saved under ``tokenizer_special_tokens:<model_name>``.

    Reads the Redis hash and assigns every field that is present back onto
    ``tokenizer`` (token strings as ``str``, ``*_id`` fields as ``int``).
    Fields missing from the hash are left untouched on the tokenizer, and an
    absent/empty hash makes this a no-op.

    Args:
        tokenizer: Tokenizer whose special-token attributes are updated in place.
        model_name: Name used to build the Redis key.
    """
    stored = redis_client.hgetall(f"tokenizer_special_tokens:{model_name}")
    if not stored:
        return

    # Without decode_responses the client returns bytes for both keys and
    # values; normalise to str so the lookups below work in either mode.
    fields = {
        (key.decode("utf-8") if isinstance(key, bytes) else key):
            (value.decode("utf-8") if isinstance(value, bytes) else value)
        for key, value in stored.items()
    }

    for token_name in ("pad_token", "eos_token", "unk_token", "bos_token"):
        if token_name in fields:
            setattr(tokenizer, token_name, fields[token_name])
        id_name = f"{token_name}_id"
        if id_name in fields:
            setattr(tokenizer, id_name, int(fields[id_name]))
def train_and_store_transformers_model(model_name, data):
    """Snapshot a pretrained causal language model and its tokenizer into Redis.

    Writes three things:
      * the tokenizer's special tokens (via ``store_special_tokens``),
      * the serialized ``state_dict`` under
        ``transformers_model:<model_name>:state_dict``,
      * a JSON list of the tokenizer files saved to the local
        ``transformers_tokenizer`` directory, under
        ``transformers_tokenizer:<model_name>``.

    NOTE: no optimisation step runs here. ``model.train()`` only switches the
    module to training mode, and ``data`` is accepted but not used.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.
        data: Unused.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.train()
    store_special_tokens(tokenizer, model_name)

    # Serialize in memory; a fixed-name temp file would be clobbered by
    # concurrent calls.
    weights = io.BytesIO()
    torch.save(model.state_dict(), weights)
    redis_client.set(f"transformers_model:{model_name}:state_dict", weights.getvalue())

    # save_pretrained returns the saved file paths, which Redis cannot store
    # directly, so record them as a JSON list.
    saved_files = tokenizer.save_pretrained("transformers_tokenizer")
    redis_client.set(
        f"transformers_tokenizer:{model_name}",
        json.dumps(list(saved_files or [])),
    )
def generate_transformers_response_from_redis(model_name, prompt):
    """Generate text with the causal-LM weights stored in Redis for ``model_name``.

    The weights are read from ``transformers_model:<model_name>:state_dict``,
    the tokenizer is loaded from the local ``transformers_tokenizer``
    directory, and the decoded output is saved under
    ``transformers_response:<uuid>``.

    Args:
        model_name: Model id used both for ``from_pretrained`` and the Redis key.
        prompt: Text to continue.

    Returns:
        The decoded generation (special tokens stripped).

    Raises:
        ValueError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"transformers_model:{model_name}:state_dict")
    if model_data is None:
        raise ValueError(
            f"No stored weights found in Redis for model {model_name!r}; "
            "store the model before generating."
        )

    model = AutoModelForCausalLM.from_pretrained(model_name)
    # NOTE(security): torch.load unpickles its input; only point this at a
    # Redis instance whose contents are trusted.
    model.load_state_dict(torch.load(io.BytesIO(model_data)))

    tokenizer = AutoTokenizer.from_pretrained("transformers_tokenizer")
    load_special_tokens(tokenizer, model_name)

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(inputs.input_ids, max_length=50)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    redis_client.set(f"transformers_response:{unique_id}", response)
    return response
def train_and_store_diffusers_model(model_name, data):
    """Save a Flux pipeline locally and store an archive of it in Redis.

    The pipeline is written to the ``diffusers_model`` directory with
    ``save_pretrained`` and the directory is then packed into a ZIP archive
    stored under ``diffusers_model:<model_name>``.

    NOTE: no training happens here; ``data`` is accepted but unused. Diffusers
    pipelines are not ``nn.Module`` objects and have no ``train()`` method, and
    ``save_pretrained`` produces a directory tree rather than a single ``.pt``
    file, hence the archive.
    NOTE(review): Redis caps a single string value at 512 MB, so large
    pipelines will not fit — confirm the intended storage backend.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.
        data: Unused.
    """
    pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()
    pipe.save_pretrained("diffusers_model")

    archive = io.BytesIO()
    # ZIP_STORED: model weights barely compress, so skip the CPU cost.
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_STORED) as bundle:
        for folder, _subfolders, file_names in os.walk("diffusers_model"):
            for file_name in file_names:
                file_path = os.path.join(folder, file_name)
                bundle.write(file_path, os.path.relpath(file_path, "diffusers_model"))
    redis_client.set(f"diffusers_model:{model_name}", archive.getvalue())
def generate_diffusers_image_from_redis(model_name, prompt):
    """Render one image with the Flux pipeline stored in Redis for ``model_name``.

    The payload under ``diffusers_model:<model_name>`` is materialised in the
    local ``diffusers_model`` directory: a ZIP archive is unpacked there, any
    other payload is written to ``diffusers_model/flux_pipeline.pt``. The
    pipeline is then loaded from that directory, the image is saved to
    ``images/diffusers_<uuid>.png`` and its path recorded under
    ``diffusers_image:<uuid>``.

    Args:
        model_name: Name used to build the Redis key.
        prompt: Text prompt for the pipeline.

    Returns:
        The generated image object (first entry of the pipeline's ``images``).

    Raises:
        ValueError: If nothing is stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"diffusers_model:{model_name}")
    if model_data is None:
        raise ValueError(
            f"No stored pipeline found in Redis for model {model_name!r}; "
            "store the model before generating."
        )

    os.makedirs("diffusers_model", exist_ok=True)
    payload = io.BytesIO(model_data)
    if zipfile.is_zipfile(payload):
        with zipfile.ZipFile(payload) as bundle:
            bundle.extractall("diffusers_model")
    else:
        with open("diffusers_model/flux_pipeline.pt", "wb") as f:
            f.write(model_data)

    pipe = FluxPipeline.from_pretrained("diffusers_model", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()
    image = pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        # Fixed CPU seed, so repeated calls with the same prompt start from
        # the same noise.
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]

    # image.save does not create missing parent directories.
    os.makedirs("images", exist_ok=True)
    image_path = f"images/diffusers_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"diffusers_image:{unique_id}", image_path)
    return image
def train_and_store_musicgen_model(model_name, data):
    """Snapshot a pretrained MusicGen model and its processor into Redis.

    The serialized ``state_dict`` goes under
    ``musicgen_model:<model_name>:state_dict``; the processor is saved to the
    local ``musicgen_processor`` directory and a JSON list of whatever paths
    ``save_pretrained`` reports is stored under
    ``musicgen_processor:<model_name>``.

    NOTE: no optimisation step runs here. ``model.train()`` only switches the
    module to training mode, and ``data`` is accepted but not used.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.
        data: Unused.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    model.train()

    # Serialize in memory; a fixed-name temp file would be clobbered by
    # concurrent calls.
    weights = io.BytesIO()
    torch.save(model.state_dict(), weights)
    redis_client.set(f"musicgen_model:{model_name}:state_dict", weights.getvalue())

    # The return value of save_pretrained is not bytes/str, so it cannot be
    # handed to Redis as-is; store it as a JSON list instead.
    saved_files = processor.save_pretrained("musicgen_processor")
    redis_client.set(
        f"musicgen_processor:{model_name}",
        json.dumps(list(saved_files or [])),
    )
def generate_musicgen_audio_from_redis(model_name, text_prompts):
    """Generate audio with the MusicGen weights stored in Redis and save it as WAV.

    Weights come from ``musicgen_model:<model_name>:state_dict`` and the
    processor from the local ``musicgen_processor`` directory. The first
    generated sample is written to ``audio/musicgen_<uuid>.wav`` and the path
    recorded under ``musicgen_audio:<uuid>``.

    Args:
        model_name: Model id used both for ``from_pretrained`` and the Redis key.
        text_prompts: Prompt string (or list of prompts) passed to the
            processor. Only the first generated sample is written to disk.

    Returns:
        Path of the WAV file that was written.

    Raises:
        ValueError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"musicgen_model:{model_name}:state_dict")
    if model_data is None:
        raise ValueError(
            f"No stored weights found in Redis for model {model_name!r}; "
            "store the model before generating."
        )

    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    # NOTE(security): torch.load unpickles its input; only point this at a
    # Redis instance whose contents are trusted.
    model.load_state_dict(torch.load(io.BytesIO(model_data)))
    processor = AutoProcessor.from_pretrained("musicgen_processor")

    inputs = processor(text=text_prompts, padding=True, return_tensors="pt")
    # generate() returns a tensor of waveforms, not a dict; the sampling rate
    # lives on the audio encoder config.
    audio_values = model.generate(**inputs, max_new_tokens=256)
    sampling_rate = model.config.audio_encoder.sampling_rate

    # wavfile.write does not create missing parent directories.
    os.makedirs("audio", exist_ok=True)
    audio_path = f"audio/musicgen_{unique_id}.wav"
    scipy.io.wavfile.write(
        audio_path,
        rate=sampling_rate,
        data=audio_values[0, 0].cpu().numpy(),
    )
    redis_client.set(f"musicgen_audio:{unique_id}", audio_path)
    return audio_path
def train_and_store_stable_diffusion_model(model_name, data):
    """Save a Stable Diffusion pipeline locally and store an archive of it in Redis.

    The pipeline (with its scheduler swapped for ``DPMSolverMultistepScheduler``)
    is written to the ``stable_diffusion_model`` directory and the directory is
    packed into a ZIP archive stored under ``stable_diffusion_model:<model_name>``.

    NOTE: no training happens here; ``data`` is accepted but unused. Diffusers
    pipelines have no ``train()`` method, and ``save_pretrained`` produces a
    directory tree rather than a single ``.pt`` file, hence the archive.
    NOTE(review): Redis caps a single string value at 512 MB, so full
    pipelines may not fit — confirm the intended storage backend.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.
        data: Unused.
    """
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    pipe.save_pretrained("stable_diffusion_model")

    archive = io.BytesIO()
    # ZIP_STORED: model weights barely compress, so skip the CPU cost.
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_STORED) as bundle:
        for folder, _subfolders, file_names in os.walk("stable_diffusion_model"):
            for file_name in file_names:
                file_path = os.path.join(folder, file_name)
                bundle.write(
                    file_path, os.path.relpath(file_path, "stable_diffusion_model")
                )
    redis_client.set(f"stable_diffusion_model:{model_name}", archive.getvalue())
def generate_stable_diffusion_image_from_redis(model_name, prompt):
    """Render one image with the Stable Diffusion pipeline stored in Redis.

    The payload under ``stable_diffusion_model:<model_name>`` is materialised
    in the local ``stable_diffusion_model`` directory: a ZIP archive is
    unpacked there, any other payload is written to
    ``stable_diffusion_model/stable_diffusion_pipeline.pt``. The pipeline is
    loaded from that directory and moved to CUDA; the image is saved to
    ``images/stable_diffusion_<uuid>.png`` and its path recorded under
    ``stable_diffusion_image:<uuid>``.

    Args:
        model_name: Name used to build the Redis key.
        prompt: Text prompt for the pipeline.

    Returns:
        The generated image object (first entry of the pipeline's ``images``).

    Raises:
        ValueError: If nothing is stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"stable_diffusion_model:{model_name}")
    if model_data is None:
        raise ValueError(
            f"No stored pipeline found in Redis for model {model_name!r}; "
            "store the model before generating."
        )

    os.makedirs("stable_diffusion_model", exist_ok=True)
    payload = io.BytesIO(model_data)
    if zipfile.is_zipfile(payload):
        with zipfile.ZipFile(payload) as bundle:
            bundle.extractall("stable_diffusion_model")
    else:
        with open("stable_diffusion_model/stable_diffusion_pipeline.pt", "wb") as f:
            f.write(model_data)

    pipe = StableDiffusionPipeline.from_pretrained(
        "stable_diffusion_model", torch_dtype=torch.float16
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    image = pipe(prompt).images[0]

    # image.save does not create missing parent directories.
    os.makedirs("images", exist_ok=True)
    image_path = f"images/stable_diffusion_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"stable_diffusion_image:{unique_id}", image_path)
    return image
def train_and_store_img2img_model(model_name, data):
    """Save a Stable Diffusion img2img pipeline locally and archive it into Redis.

    The pipeline is written to the ``img2img_model`` directory and the
    directory is packed into a ZIP archive stored under
    ``img2img_model:<model_name>``.

    NOTE: no training happens here; ``data`` is accepted but unused. Diffusers
    pipelines have no ``train()`` method, and ``save_pretrained`` produces a
    directory tree rather than a single ``.pt`` file, hence the archive.
    NOTE(review): Redis caps a single string value at 512 MB, so full
    pipelines may not fit — confirm the intended storage backend.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.
        data: Unused.
    """
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    pipe.save_pretrained("img2img_model")

    archive = io.BytesIO()
    # ZIP_STORED: model weights barely compress, so skip the CPU cost.
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_STORED) as bundle:
        for folder, _subfolders, file_names in os.walk("img2img_model"):
            for file_name in file_names:
                file_path = os.path.join(folder, file_name)
                bundle.write(file_path, os.path.relpath(file_path, "img2img_model"))
    redis_client.set(f"img2img_model:{model_name}", archive.getvalue())
def generate_img2img_from_redis(model_name, init_image, prompt, strength=0.75):
    """Transform an input image with the img2img pipeline stored in Redis.

    The payload under ``img2img_model:<model_name>`` is materialised in the
    local ``img2img_model`` directory: a ZIP archive is unpacked there, any
    other payload is written to ``img2img_model/img2img_pipeline.pt``. The
    pipeline is loaded from that directory and moved to CUDA; the result is
    saved to ``images/img2img_<uuid>.png`` and its path recorded under
    ``img2img_image:<uuid>``.

    Args:
        model_name: Name used to build the Redis key.
        init_image: Path (or file object) of the starting image; opened with
            PIL and converted to RGB.
        prompt: Text prompt guiding the transformation.
        strength: Passed through to the pipeline's ``strength`` argument.

    Returns:
        The generated image object (first entry of the pipeline's ``images``).

    Raises:
        ValueError: If nothing is stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"img2img_model:{model_name}")
    if model_data is None:
        raise ValueError(
            f"No stored pipeline found in Redis for model {model_name!r}; "
            "store the model before generating."
        )

    os.makedirs("img2img_model", exist_ok=True)
    payload = io.BytesIO(model_data)
    if zipfile.is_zipfile(payload):
        with zipfile.ZipFile(payload) as bundle:
            bundle.extractall("img2img_model")
    else:
        with open("img2img_model/img2img_pipeline.pt", "wb") as f:
            f.write(model_data)

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained("img2img_model", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    source_image = Image.open(init_image).convert("RGB")
    # The pipeline's keyword for the starting picture is `image`; the old
    # `init_image` keyword is no longer accepted by current diffusers releases.
    image = pipe(prompt=prompt, image=source_image, strength=strength).images[0]

    # image.save does not create missing parent directories.
    os.makedirs("images", exist_ok=True)
    image_path = f"images/img2img_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"img2img_image:{unique_id}", image_path)
    return image
def train_and_store_marianmt_model(model_name, data):
    """Snapshot a pretrained MarianMT model and its tokenizer into Redis.

    The serialized ``state_dict`` goes under
    ``marianmt_model:<model_name>:state_dict``; the tokenizer is saved to the
    local ``marianmt_tokenizer`` directory and a JSON list of the saved file
    paths is stored under ``marianmt_tokenizer:<model_name>``.

    NOTE: no optimisation step runs here. ``model.train()`` only switches the
    module to training mode, and ``data`` is accepted but not used.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.
        data: Unused.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    model.train()

    # Serialize in memory; a fixed-name temp file would be clobbered by
    # concurrent calls.
    weights = io.BytesIO()
    torch.save(model.state_dict(), weights)
    redis_client.set(f"marianmt_model:{model_name}:state_dict", weights.getvalue())

    # save_pretrained returns the saved file paths, which Redis cannot store
    # directly, so record them as a JSON list.
    saved_files = tokenizer.save_pretrained("marianmt_tokenizer")
    redis_client.set(
        f"marianmt_tokenizer:{model_name}",
        json.dumps(list(saved_files or [])),
    )
def translate_text_from_redis(model_name, text, src_lang, tgt_lang):
    """Translate ``text`` with the MarianMT weights stored in Redis.

    Weights come from ``marianmt_model:<model_name>:state_dict`` and the
    tokenizer from the local ``marianmt_tokenizer`` directory. The translation
    is saved under ``marianmt_translation:<uuid>``.

    Args:
        model_name: Model id used both for ``from_pretrained`` and the Redis key.
        text: Source text.
        src_lang: Source language code, forwarded to the tokenizer call.
        tgt_lang: Target language code, forwarded to the tokenizer call.

    Returns:
        The decoded translation (special tokens stripped).

    Raises:
        ValueError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"marianmt_model:{model_name}:state_dict")
    if model_data is None:
        raise ValueError(
            f"No stored weights found in Redis for model {model_name!r}; "
            "store the model before translating."
        )

    model = MarianMTModel.from_pretrained(model_name)
    # NOTE(security): torch.load unpickles its input; only point this at a
    # Redis instance whose contents are trusted.
    model.load_state_dict(torch.load(io.BytesIO(model_data)))
    tokenizer = MarianTokenizer.from_pretrained("marianmt_tokenizer")

    # NOTE(review): src_lang/tgt_lang are passed straight to the tokenizer
    # call — confirm MarianTokenizer actually honours them for this checkpoint.
    inputs = tokenizer(text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang)
    translated_tokens = model.generate(**inputs)
    translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    redis_client.set(f"marianmt_translation:{unique_id}", translation)
    return translation
def train_and_store_bart_model(model_name, data):
    """Snapshot a pretrained BART model and its tokenizer into Redis.

    The serialized ``state_dict`` goes under
    ``bart_model:<model_name>:state_dict``; the tokenizer is saved to the
    local ``bart_tokenizer`` directory and a JSON list of the saved file paths
    is stored under ``bart_tokenizer:<model_name>``.

    NOTE: no optimisation step runs here. ``model.train()`` only switches the
    module to training mode, and ``data`` is accepted but not used.

    Args:
        model_name: Hub id or local path understood by ``from_pretrained``.
        data: Unused.
    """
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    model.train()

    # Serialize in memory; a fixed-name temp file would be clobbered by
    # concurrent calls.
    weights = io.BytesIO()
    torch.save(model.state_dict(), weights)
    redis_client.set(f"bart_model:{model_name}:state_dict", weights.getvalue())

    # save_pretrained returns the saved file paths, which Redis cannot store
    # directly, so record them as a JSON list.
    saved_files = tokenizer.save_pretrained("bart_tokenizer")
    redis_client.set(
        f"bart_tokenizer:{model_name}",
        json.dumps(list(saved_files or [])),
    )
def summarize_text_from_redis(model_name, text):
    """Summarize ``text`` with the BART weights stored in Redis.

    Weights come from ``bart_model:<model_name>:state_dict`` and the tokenizer
    from the local ``bart_tokenizer`` directory (with any special tokens saved
    for ``model_name`` re-applied). The summary is stored under
    ``bart_summary:<uuid>``.

    Args:
        model_name: Model id used both for ``from_pretrained`` and the Redis key.
        text: Text to summarize; truncated by the tokenizer to its maximum length.

    Returns:
        The decoded summary (special tokens stripped).

    Raises:
        ValueError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"bart_model:{model_name}:state_dict")
    if model_data is None:
        raise ValueError(
            f"No stored weights found in Redis for model {model_name!r}; "
            "store the model before summarizing."
        )

    model = BartForConditionalGeneration.from_pretrained(model_name)
    # NOTE(security): torch.load unpickles its input; only point this at a
    # Redis instance whose contents are trusted.
    model.load_state_dict(torch.load(io.BytesIO(model_data)))

    tokenizer = BartTokenizer.from_pretrained("bart_tokenizer")
    load_special_tokens(tokenizer, model_name)

    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    redis_client.set(f"bart_summary:{unique_id}", summary)
    return summary
def auto_train_and_store(model_name, task, data):
    """Route a store request to the helper registered for ``task``.

    Recognised labels are ``text-generation``, ``diffusers``, ``musicgen``,
    ``stable-diffusion``, ``img2img``, ``translation`` and ``summarization``.
    Any other label is ignored: nothing is called and ``None`` is returned.
    """
    # Each entry pairs a task label with a thunk; the lambdas defer the helper
    # lookup until a label actually matches.
    routes = (
        ("text-generation", lambda: train_and_store_transformers_model(model_name, data)),
        ("diffusers", lambda: train_and_store_diffusers_model(model_name, data)),
        ("musicgen", lambda: train_and_store_musicgen_model(model_name, data)),
        ("stable-diffusion", lambda: train_and_store_stable_diffusion_model(model_name, data)),
        ("img2img", lambda: train_and_store_img2img_model(model_name, data)),
        ("translation", lambda: train_and_store_marianmt_model(model_name, data)),
        ("summarization", lambda: train_and_store_bart_model(model_name, data)),
    )
    for label, run in routes:
        if task == label:
            run()
            break
def transcribe_audio_from_redis(audio_file):
    """Transcribe a WAV recording with ``openai/whisper-small``.

    Args:
        audio_file: Either a filesystem path to a WAV file (what the Gradio
            ``Audio(type="filepath")`` input passes) or the raw bytes of a
            WAV file.

    Returns:
        The transcription of the recording as a string.

    Raises:
        ValueError: Propagated from ``scipy.io.wavfile.read`` when the input
            is not a readable WAV file.
    """
    # wavfile.read accepts a path or a file-like object, so wrap raw bytes.
    if isinstance(audio_file, (bytes, bytearray)):
        source = io.BytesIO(audio_file)
    else:
        source = audio_file
    sample_rate, samples = wavfile.read(source)
    samples = np.asarray(samples)

    # Convert PCM integers to float32 in [-1, 1]; 8-bit WAV is unsigned.
    if samples.dtype == np.uint8:
        waveform = (samples.astype(np.float32) - 128.0) / 128.0
    elif np.issubdtype(samples.dtype, np.integer):
        waveform = samples.astype(np.float32) / float(np.iinfo(samples.dtype).max)
    else:
        waveform = samples.astype(np.float32)

    # Multi-channel data arrives as (n_samples, n_channels); mix down to mono.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)

    # The processor is called with sampling_rate=16000 below, so the samples
    # must really be at that rate.
    target_rate = 16000
    if sample_rate != target_rate:
        waveform = resample_poly(waveform, target_rate, sample_rate).astype(np.float32)

    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    model.config.forced_decoder_ids = None

    input_features = processor(
        waveform, sampling_rate=target_rate, return_tensors="pt"
    ).input_features
    predicted_ids = model.generate(input_features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]
def generate_image_from_redis(model_name, prompt, model_type):
    """Generate an image with the backend selected by ``model_type``.

    Args:
        model_name: Name of the stored model, forwarded to the backend.
        prompt: Text prompt, forwarded to the backend.
        model_type: One of ``"diffusers"``, ``"stable-diffusion"`` or
            ``"img2img"``. The img2img backend always starts from the local
            file ``init_image.png``.

    Returns:
        Whatever the selected backend returns (the generated image).

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type == "diffusers":
        return generate_diffusers_image_from_redis(model_name, prompt)
    if model_type == "stable-diffusion":
        return generate_stable_diffusion_image_from_redis(model_name, prompt)
    if model_type == "img2img":
        return generate_img2img_from_redis(model_name, "init_image.png", prompt)
    # Previously an unknown type fell through to `return image` with `image`
    # never assigned, surfacing as an UnboundLocalError.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'diffusers', "
        "'stable-diffusion' or 'img2img'."
    )
def generate_video_from_redis(prompt):
    """Generate a short video for ``prompt`` and record its path in Redis.

    Loads the ``damo-vilab/text-to-video-ms-1.7b`` pipeline in fp16 with a
    ``DPMSolverMultistepScheduler``, runs 25 inference steps, exports the
    frames with ``export_to_video`` and stores the resulting path under
    ``video_<uuid>``.

    Args:
        prompt: Text prompt describing the video.

    Returns:
        Path of the exported video file.
    """
    pipe = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()

    # `frames` is batched (one entry per prompt); export_to_video expects the
    # frame sequence of a single video, so take the first entry.
    video_frames = pipe(prompt, num_inference_steps=25).frames[0]
    video_path = export_to_video(video_frames)

    unique_id = generate_unique_id()
    redis_client.set(f"video_{unique_id}", video_path)
    return video_path
def generate_random_response(prompts, generator):
    """Run ``generator`` on every prompt and collect the generated texts.

    ``generator`` is called as ``generator(prompt, max_length=50)`` and must
    return a sequence whose first item is a mapping with a
    ``'generated_text'`` entry. The result list follows the order of
    ``prompts``.
    """
    return [
        generator(text, max_length=50)[0]['generated_text']
        for text in prompts
    ]
def _run_task(task):
    """Call one zero-argument task and return its result.

    Defined at module level because ``multiprocessing`` pickles the function
    it sends to workers, and lambdas/closures cannot be pickled.
    """
    return task()


def process_parallel(tasks):
    """Run zero-argument callables in a process pool and return their results.

    Args:
        tasks: Iterable of picklable zero-argument callables (module-level
            functions, builtins, ``functools.partial`` objects, ...).

    Returns:
        List of results in the same order as ``tasks``; an empty list when
        there are no tasks.
    """
    pending = list(tasks)
    if not pending:
        # Nothing to do, so skip spawning worker processes altogether.
        return []
    # Never start more workers than there are tasks to hand out.
    worker_count = min(len(pending), multiprocessing.cpu_count())
    with multiprocessing.Pool(processes=worker_count) as pool:
        return pool.map(_run_task, pending)
def generate_response_from_prompt(prompt, model_name="google/flan-t5-xl"):
    """Generate a text response to ``prompt`` with a Transformers pipeline.

    The pipeline task is chosen from the model's configuration:
    encoder-decoder checkpoints (such as the default FLAN-T5) need the
    ``text2text-generation`` task, while decoder-only checkpoints use
    ``text-generation``. Both tasks return items with a ``generated_text``
    entry, which is what ``generate_random_response`` reads.

    Args:
        prompt: Input text.
        model_name: Hub id or local path of the model and tokenizer.

    Returns:
        The generated text for ``prompt``.
    """
    config = AutoConfig.from_pretrained(model_name)
    task = "text2text-generation" if config.is_encoder_decoder else "text-generation"
    generator = pipeline(task, model=model_name, tokenizer=model_name)
    responses = generate_random_response([prompt], generator)
    return responses[0]
def generate_image_from_prompt(prompt, image_type, model_name="CompVis/stable-diffusion-v1-4"):
    """Generate an image for ``prompt`` with the backend named by ``image_type``.

    Args:
        prompt: Text prompt, forwarded to the backend.
        image_type: One of ``"diffusers"``, ``"stable-diffusion"`` or
            ``"img2img"``. The img2img backend always starts from the local
            file ``init_image.png``.
        model_name: Name of the stored model, forwarded to the backend.

    Returns:
        Whatever the selected backend returns (the generated image).

    Raises:
        ValueError: If ``image_type`` is not one of the supported values.
    """
    if image_type == "diffusers":
        return generate_diffusers_image_from_redis(model_name, prompt)
    if image_type == "stable-diffusion":
        return generate_stable_diffusion_image_from_redis(model_name, prompt)
    if image_type == "img2img":
        return generate_img2img_from_redis(model_name, "init_image.png", prompt)
    # Previously an unknown type fell through to `return image` with `image`
    # never assigned, surfacing as an UnboundLocalError.
    raise ValueError(
        f"Unsupported image_type {image_type!r}; expected 'diffusers', "
        "'stable-diffusion' or 'img2img'."
    )
def gradio_app():
    """Build the Gradio UI (one tab per capability) and launch it.

    Tabs and the handlers they call:
      * Texto          -> ``generate_response_from_prompt``
      * Imagen         -> ``generate_image_from_prompt``
      * Video          -> ``generate_video_from_redis``
      * Audio          -> ``generate_musicgen_audio_from_redis``
      * Transcripción  -> ``transcribe_audio_from_redis``
      * Traducción     -> ``translate_text_from_redis``
      * Resumen        -> ``summarize_text_from_redis``

    The user-facing Spanish strings had been mis-decoded (e.g. ``aqu铆`` for
    ``aquí``); they are restored to proper UTF-8 text here.
    """
    with gr.Blocks() as app:
        gr.Markdown(
            "\n# IA Generativa con Transformers y Diffusers\n"
            "Explora diferentes modelos de IA para generar texto, imágenes, audio, video y más.\n"
        )

        with gr.Tab("Texto"):
            with gr.Row():
                with gr.Column():
                    prompt_text = gr.Textbox(
                        label="Texto de Entrada",
                        placeholder="Ingresa tu prompt de texto aquí...",
                    )
                    text_button = gr.Button("Generar Texto", variant="primary")
                with gr.Column():
                    text_output = gr.Textbox(label="Respuesta")
            text_button.click(generate_response_from_prompt, inputs=prompt_text, outputs=text_output)

        with gr.Tab("Imagen"):
            with gr.Row():
                with gr.Column():
                    prompt_image = gr.Textbox(
                        label="Prompt de Imagen",
                        placeholder="Ingresa tu prompt de imagen aquí...",
                    )
                    image_type = gr.Dropdown(
                        ["diffusers", "stable-diffusion", "img2img"],
                        label="Tipo de Modelo",
                        value="stable-diffusion",
                    )
                    model_name_image = gr.Textbox(
                        label="Nombre del Modelo",
                        value="CompVis/stable-diffusion-v1-4",
                    )
                    image_button = gr.Button("Generar Imagen", variant="primary")
                with gr.Column():
                    image_output = gr.Image(label="Imagen Generada")
            # Input order matches generate_image_from_prompt(prompt, image_type, model_name).
            image_button.click(
                generate_image_from_prompt,
                inputs=[prompt_image, image_type, model_name_image],
                outputs=image_output,
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    prompt_video = gr.Textbox(
                        label="Prompt de Video",
                        placeholder="Ingresa tu prompt de video aquí...",
                    )
                    video_button = gr.Button("Generar Video", variant="primary")
                with gr.Column():
                    video_output = gr.Video(label="Video Generado")
            video_button.click(generate_video_from_redis, inputs=prompt_video, outputs=video_output)

        with gr.Tab("Audio"):
            with gr.Row():
                with gr.Column():
                    model_name_audio = gr.Textbox(label="Nombre del Modelo", value="facebook/musicgen-small")
                    text_prompts_audio = gr.Textbox(
                        label="Prompts de Audio",
                        placeholder="Ingresa tus prompts de audio aquí...",
                    )
                    audio_button = gr.Button("Generar Audio", variant="primary")
                with gr.Column():
                    audio_output = gr.Audio(label="Audio Generado")
            audio_button.click(
                generate_musicgen_audio_from_redis,
                inputs=[model_name_audio, text_prompts_audio],
                outputs=audio_output,
            )

        with gr.Tab("Transcripción"):
            with gr.Row():
                with gr.Column():
                    audio_file = gr.Audio(type="filepath", label="Archivo de Audio")
                    # Own name for this button; it used to reuse `audio_button`
                    # from the Audio tab, which made the wiring hard to follow.
                    transcribe_button = gr.Button("Transcribir Audio", variant="primary")
                with gr.Column():
                    transcription_output = gr.Textbox(label="Transcripción")
            transcribe_button.click(
                transcribe_audio_from_redis,
                inputs=audio_file,
                outputs=transcription_output,
            )

        with gr.Tab("Traducción"):
            with gr.Row():
                with gr.Column():
                    model_name_translate = gr.Textbox(
                        label="Nombre del Modelo",
                        value="Helsinki-NLP/opus-mt-en-es",
                    )
                    text_input = gr.Textbox(
                        label="Texto a Traducir",
                        placeholder="Ingresa el texto a traducir...",
                    )
                    src_lang_input = gr.Textbox(label="Idioma de Origen", value="en")
                    tgt_lang_input = gr.Textbox(label="Idioma de Destino", value="es")
                    translate_button = gr.Button("Traducir Texto", variant="primary")
                with gr.Column():
                    translation_output = gr.Textbox(label="Traducción")
            translate_button.click(
                translate_text_from_redis,
                inputs=[model_name_translate, text_input, src_lang_input, tgt_lang_input],
                outputs=translation_output,
            )

        with gr.Tab("Resumen"):
            with gr.Row():
                with gr.Column():
                    model_name_summarize = gr.Textbox(
                        label="Nombre del Modelo",
                        value="facebook/bart-large-cnn",
                    )
                    text_to_summarize = gr.Textbox(
                        label="Texto para Resumir",
                        placeholder="Ingresa el texto a resumir...",
                    )
                    summarize_button = gr.Button("Generar Resumen", variant="primary")
                with gr.Column():
                    summary_output = gr.Textbox(label="Resumen")
            summarize_button.click(
                summarize_text_from_redis,
                inputs=[model_name_summarize, text_to_summarize],
                outputs=summary_output,
            )

    app.launch()


# Start the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    gradio_app()