# qwen25-api/main.py
from fastapi import FastAPI, Query
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers import Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import torch
import logging

logging.basicConfig(level=logging.INFO)

app = FastAPI()

# Qwen2.5-VL vision-language model setup.
# min_pixels / max_pixels bound the image resolution the processor hands to the
# vision encoder, trading visual detail for memory and speed.
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",
)

# LLaMA model setup (text-only generation for /llama_chat)
llama_model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_model_name, torch_dtype=torch.float16, device_map="auto"
)
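
# Note: both models are loaded eagerly at import time, so the host needs enough
# GPU memory to hold the Qwen2.5-VL and LLaMA weights simultaneously.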


@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict, /chat, or /llama_chat endpoints."}


@app.get("/predict")
def predict(image_url: str = Query(...), prompt: str = Query(...)):
    # Build a multimodal chat turn: the image referenced by image_url plus the text prompt.
    messages = [
        {"role": "system", "content": "You are a helpful assistant with vision abilities."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_url},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    with torch.no_grad():
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}
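
# Example request for /predict (illustrative host, port, and image URL):
#   curl "http://localhost:8000/predict?image_url=https://example.com/cat.jpg&prompt=Describe+this+image"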


@app.get("/chat")
def chat(prompt: str = Query(...)):
    # Text-only chat routed through the Qwen2.5-VL model (no image or video inputs).
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    with torch.no_grad():
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}
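
# Example request for /chat (illustrative host and port):
#   curl "http://localhost:8000/chat?prompt=Tell+me+a+joke"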


@app.get("/llama_chat")
def llama_chat(prompt: str = Query(...)):
    # Plain completion on the raw prompt; no chat template is applied here.
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    with torch.no_grad():
        outputs = llama_model.generate(**inputs, max_new_tokens=128)
    # Note: input tokens are not trimmed, so the decoded response echoes the prompt.
    response = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": response}
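
# Example request for /llama_chat (illustrative host and port):
#   curl "http://localhost:8000/llama_chat?prompt=Write+a+haiku+about+the+sea"


if __name__ == "__main__":
    # Minimal local launcher: a sketch, assuming uvicorn is installed
    # (equivalently: `uvicorn main:app --host 0.0.0.0 --port 8000`).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)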