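"""Flask-SocketIO demo: stream beam-search hypotheses to the browser live.

A WebSocketBeamStreamer (built on transformers' MultiBeamTextStreamer)
pushes every beam's partial text over a websocket while model.generate()
runs. eventlet.monkey_patch() is called before the remaining imports so
Flask-SocketIO's eventlet workers cooperate with blocking socket and
thread calls.
"""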
import eventlet
eventlet.monkey_patch(socket=True, select=True, thread=True)
import eventlet.wsgi
from flask import Flask, render_template, request
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch
app = Flask(__name__)
socketio = SocketIO(
    app,
    async_mode='eventlet',
    message_queue=None,
    ping_timeout=60,
    ping_interval=25,
    cors_allowed_origins="*",
    logger=True,
    engineio_logger=True,
    async_handlers=True
)
# Initialize models and tokenizers
MODELS = {
    "qwen": {
        "name": "Qwen/Qwen2.5-0.5B-Instruct",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": True  # Qwen uses a chat template
    },
    "gpt2": {
        "name": "gpt2",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": False  # GPT-2 takes the raw prompt
    }
}
# Load models and tokenizers
for model_key, model_info in MODELS.items():
    model_info["tokenizer"] = AutoTokenizer.from_pretrained(model_info["name"])
    model_info["model"] = AutoModelForCausalLM.from_pretrained(
        model_info["name"],
        torch_dtype="auto",
        device_map="auto"
    )

    # Add pad token for GPT-2 if it doesn't have one
    if model_key == "gpt2" and model_info["tokenizer"].pad_token is None:
        model_info["tokenizer"].pad_token = model_info["tokenizer"].eos_token
        model_info["model"].config.pad_token_id = model_info["model"].config.eos_token_id
class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Streamer that forwards beam updates over the websocket, optionally throttled."""

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.sleep_time = sleep_time  # throttle delay in milliseconds

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Record the beam's latest text and push it to all connected clients."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            eventlet.sleep(self.sleep_time / 1000)
        # Broadcast emits don't support ack callbacks in Flask-SocketIO, so the
        # original callback kwarg is dropped; socketio.sleep(0) yields to the
        # eventlet loop so the frame is flushed immediately.
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        }, namespace='/')
        socketio.sleep(0)

    def on_beam_finished(self, final_text: str):
        """Announce a completed beam with its final text."""
        socketio.emit('beam_finished', {
            'text': final_text
        })
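
# Socket.IO events emitted by the server:
#   generation_started    - generation request accepted
#   beam_update           - {'beam_idx': int, 'text': str}, one beam's latest text
#   beam_finished         - {'text': str}, a beam reached EOS
#   generation_error      - {'error': str}, generate() raised an exception
#   generation_completed  - always sent when generation ends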
@app.route('/')
def index():
    return render_template('index.html')
@socketio.on('generate')
def handle_generation(data):
    socketio.emit('generation_started')

    prompt = data['prompt']
    model_name = data.get('model', 'qwen')  # default to Qwen if not specified
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)

    # Get the selected model info
    model_info = MODELS[model_name]
    model = model_info["model"]
    tokenizer = model_info["tokenizer"]

    # Prepare input text based on model type
    if model_info["uses_chat_template"]:
        # For Qwen, wrap the prompt in its chat template
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # For GPT-2, use the raw prompt
        text = prompt

    # Tokenize and move inputs to the model's device
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize the streamer that relays beam updates to the client
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Generate with beam search; the streamer emits updates as tokens arrive
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
    except Exception as e:
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        socketio.emit('generation_completed')
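
# Minimal entry point so the module runs as a script. The original project
# may serve the app differently (eventlet.wsgi is imported above); the host
# and port here are assumptions.
if __name__ == '__main__':
    socketio.run(app, host='0.0.0.0', port=5000)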