Spaces: Running on Zero
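The snippet below is a minimal ZeroGPU Space: the quantized model is loaded lazily the first time the GPU-decorated generate method runs, and tokens are streamed back to the Gradio textbox through a TextIteratorStreamer fed by a background generation thread. Note that model_id is a placeholder; it is assumed to point at an INT4 GPTQ export of Mistral-Large-Instruct-2407.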
import threading

import gradio as gr
import spaces
import torch
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer

print("Is CUDA available?", torch.cuda.is_available())

# Placeholder repo id: point this at an INT4 GPTQ export of Mistral-Large-Instruct-2407
model_id = "<your-gptq-model-repo>"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ModelWrapper:
    def __init__(self):
        self.model = None  # Model will be loaded when the GPU is allocated
    @spaces.GPU  # ZeroGPU attaches a GPU only while this call is running
    def generate(self, prompt):
        if self.model is None:
            # Explicitly pin the quantized model to the allocated GPU
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map={'': 'cuda:0'},
                trust_remote_code=True,
            )
        print("Model is on device:", next(self.model.parameters()).device)
        # Tokenize the input prompt and move it to the GPU
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
        print("Inputs are on device:", inputs['input_ids'].device)

        # Set up the streamer that hands tokens back as they are generated
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Start generation in a separate thread to enable streaming
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield the accumulated text in real time
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
model_wrapper = ModelWrapper()

# A generator function streams to the output component automatically,
# so gr.Interface needs no extra streaming flag
interface = gr.Interface(
    fn=model_wrapper.generate,
    inputs=gr.Textbox(lines=5, label="Input Prompt"),
    outputs=gr.Textbox(label="Generated Text"),
    title="Mistral-Large-Instruct-2407 Text Completion",
    description="Enter a prompt and receive a text completion using the Mistral-Large-Instruct-2407 INT4 model.",
    allow_flagging='never',
    live=False,
    cache_examples=False,
)
if __name__ == "__main__":
    interface.launch()
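If it helps to see the ZeroGPU contract in isolation, here is a minimal sketch, assuming the spaces package available on ZeroGPU Spaces: the spaces.GPU decorator is what requests a GPU for the wrapped call, and the optional duration argument caps how long the allocation may be held. The function name which_gpu is just for illustration.

import spaces
import torch

@spaces.GPU(duration=120)  # optional cap, in seconds, on the GPU allocation
def which_gpu():
    # CUDA is only guaranteed to be usable inside a spaces.GPU-decorated call
    return torch.cuda.get_device_name(0)

Other dependencies such as gradio, transformers, and auto-gptq still need to be declared in the Space's requirements.txt as usual.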