import eventlet
eventlet.monkey_patch(socket=True, select=True, thread=True)

import eventlet.wsgi
from flask import Flask, render_template, request
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch

app = Flask(__name__)
socketio = SocketIO(
    app,
    async_mode='eventlet',
    message_queue=None,
    ping_timeout=60,
    ping_interval=25,
    cors_allowed_origins="*",
    logger=True,
    engineio_logger=True,
    async_handlers=True
)

# Initialize models and tokenizers
MODELS = {
    "qwen": {
        "name": "Qwen/Qwen2.5-0.5B-Instruct",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": True  # Qwen uses a chat template
    },
    "gpt2": {
        "name": "gpt2",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": False  # GPT-2 doesn't use a chat template
    }
}

# Load models and tokenizers
for model_key, model_info in MODELS.items():
    model_info["tokenizer"] = AutoTokenizer.from_pretrained(model_info["name"])
    model_info["model"] = AutoModelForCausalLM.from_pretrained(
        model_info["name"],
        torch_dtype="auto",
        device_map="auto"
    )

    # Add a pad token for GPT-2 if it doesn't have one
    if model_key == "gpt2" and model_info["tokenizer"].pad_token is None:
        model_info["tokenizer"].pad_token = model_info["tokenizer"].eos_token
        model_info["model"].config.pad_token_id = model_info["model"].config.eos_token_id


class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Custom streamer that sends updates through websockets with adjustable speed"""

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.sleep_time = sleep_time

    def on_beam_update(self, beam_idx: int, new_text: str):
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            eventlet.sleep(self.sleep_time / 1000)
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        }, namespace='/', callback=lambda: eventlet.sleep(0))
        socketio.sleep(0)

    def on_beam_finished(self, final_text: str):
        socketio.emit('beam_finished', {
            'text': final_text
        })
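
# --- Socket event protocol (summary comments added for clarity; not in the
# original snippet). Server -> client events emitted by this app:
#   'generation_started'                 when a 'generate' request begins
#   'beam_update'      {beam_idx, text}  streamed whenever a beam's text changes
#   'beam_finished'    {text}            when a beam completes
#   'generation_error' {error}           if model.generate() raises
#   'generation_completed'               always, once generation ends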

@app.route('/')
def index():
    return render_template('index.html')


@socketio.on('generate')
def handle_generation(data):
    socketio.emit('generation_started')

    prompt = data['prompt']
    model_name = data.get('model', 'qwen')  # Default to qwen if not specified
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)

    # Get the selected model info
    model_info = MODELS[model_name]
    model = model_info["model"]
    tokenizer = model_info["tokenizer"]

    # Prepare input text based on model type
    if model_info["uses_chat_template"]:
        # For Qwen, use the chat template
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # For GPT-2, use the prompt directly
        text = prompt

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize streamer
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Generate with beam search
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
    except Exception as e:
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        socketio.emit('generation_completed')
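
# --- Entry point (added; not part of the original snippet) ---
# A minimal way to serve the app, assuming the default dev host/port. With
# async_mode='eventlet', socketio.run() serves via the eventlet.wsgi server
# imported above.
if __name__ == '__main__':
    socketio.run(app, host='0.0.0.0', port=5000)

# A minimal client sketch for manual testing (assumes the python-socketio
# client package is installed; event names match the handlers above):
#
#   import socketio
#   sio = socketio.Client()
#   sio.on('beam_update', lambda d: print(d['beam_idx'], d['text']))
#   sio.on('generation_completed', lambda: sio.disconnect())
#   sio.connect('http://localhost:5000')
#   sio.emit('generate', {'prompt': 'Hello', 'model': 'qwen', 'num_beams': 3})
#   sio.wait()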