# eventlet must monkey-patch the standard library before any other imports
# so blocking socket/thread calls cooperate with the event loop
import eventlet
eventlet.monkey_patch(socket=True, select=True, thread=True)
import eventlet.wsgi

from flask import Flask, render_template, request
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch
app = Flask(__name__)

socketio = SocketIO(
    app,
    async_mode='eventlet',
    message_queue=None,
    ping_timeout=60,
    ping_interval=25,
    cors_allowed_origins="*",
    logger=True,
    engineio_logger=True,
    async_handlers=True
)
# Initialize models and tokenizers
MODELS = {
    "qwen": {
        "name": "Qwen/Qwen2.5-0.5B-Instruct",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": True  # Qwen uses a chat template
    },
    "gpt2": {
        "name": "gpt2",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": False  # GPT2 doesn't use a chat template
    }
}
# Load models and tokenizers
for model_key, model_info in MODELS.items():
    model_info["tokenizer"] = AutoTokenizer.from_pretrained(model_info["name"])
    model_info["model"] = AutoModelForCausalLM.from_pretrained(
        model_info["name"],
        torch_dtype="auto",
        device_map="auto"
    )

    # Add a pad token for GPT2 if it doesn't have one
    if model_key == "gpt2" and model_info["tokenizer"].pad_token is None:
        model_info["tokenizer"].pad_token = model_info["tokenizer"].eos_token
        model_info["model"].config.pad_token_id = model_info["model"].config.eos_token_id

class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Custom streamer that sends updates through websockets with adjustable speed"""

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
        self.beam_texts = {i: "" for i in range(num_beams)}  # latest text per beam
        self.sleep_time = sleep_time  # optional delay (ms) to slow streaming down for the UI

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Called whenever a beam's text changes; pushes the updated text to the client."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            eventlet.sleep(self.sleep_time / 1000)
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        }, namespace='/', callback=lambda: eventlet.sleep(0))
        socketio.sleep(0)  # yield to the event loop so the update is flushed immediately

    def on_beam_finished(self, final_text: str):
        """Called when a beam completes; sends its final text to the client."""
        socketio.emit('beam_finished', {
            'text': final_text
        })

@app.route('/')
def index():
    return render_template('index.html')

# Event name assumed here; the client must emit the matching event to start generation
@socketio.on('generate')
def handle_generation(data):
    socketio.emit('generation_started')

    prompt = data['prompt']
    model_name = data.get('model', 'qwen')  # Default to qwen if not specified
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)
    # Get the selected model info
    model_info = MODELS[model_name]
    model = model_info["model"]
    tokenizer = model_info["tokenizer"]

    # Prepare input text based on model type
    if model_info["uses_chat_template"]:
        # For Qwen, use the chat template
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # For GPT2, use the prompt directly
        text = prompt

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Initialize streamer
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )
    try:
        # Generate with beam search
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
    except Exception as e:
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        socketio.emit('generation_completed')
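
# A minimal entry point sketch, not part of the original snippet: it starts the
# eventlet-backed server. Host and port are assumptions (7860 is the conventional
# Hugging Face Spaces port); adjust as needed for local runs.
if __name__ == '__main__':
    socketio.run(app, host='0.0.0.0', port=7860)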