Voice-CPU / app.py
import gradio as gr
import torch
import threading
import spaces
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM  # provides from_quantized(), used below

# NOTE: model_id and the tokenizer are not defined elsewhere in this file; a GPTQ INT4
# build of Mistral-Large-Instruct-2407 is assumed here (placeholder repo id).
model_id = "Mistral-Large-Instruct-2407-GPTQ-INT4"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print("Is CUDA available?", torch.cuda.is_available())
class ModelWrapper:
    def __init__(self):
        self.model = None  # Model will be loaded when GPU is allocated

    @spaces.GPU
    def generate(self, prompt):
        if self.model is None:
            # Explicitly map the quantized model onto the GPU
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map={'': 'cuda:0'},
                trust_remote_code=True,
            )
            print("Model is on device:", next(self.model.parameters()).device)

        # Tokenize the input prompt and move it to the GPU
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
        print("Inputs are on device:", inputs['input_ids'].device)

        # Set up the streamer
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Start generation in a separate thread to enable streaming
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield generated text in real time
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
model_wrapper = ModelWrapper()

# Because generate() is a generator, Gradio streams its output automatically;
# gr.Interface takes no separate streaming flag.
interface = gr.Interface(
    fn=model_wrapper.generate,
    inputs=gr.Textbox(lines=5, label="Input Prompt"),
    outputs=gr.Textbox(label="Generated Text"),
    title="Mistral-Large-Instruct-2407 Text Completion",
    description="Enter a prompt and receive a text completion using the Mistral-Large-Instruct-2407 INT4 model.",
    allow_flagging='never',
    live=False,
    cache_examples=False,
)
if __name__ == "__main__":
    interface.launch()