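# FastAPI app (e.g. the app.py of a Hugging Face Space) serving two models:
#   - Qwen/Qwen2.5-VL-3B-Instruct: vision-language requests (/predict) and text chat (/chat)
#   - Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2: text-only chat (/llama_chat)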
from fastapi import FastAPI, Query
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
)
from qwen_vl_utils import process_vision_info
import torch
import logging

logging.basicConfig(level=logging.INFO)

app = FastAPI()
# Qwen2.5-VL vision-language model; min_pixels/max_pixels bound the visual
# token budget the processor allocates per image.
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",  # optional; requires the flash-attn package
)
# LLaMA Model Setup
llama_model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_model_name, torch_dtype=torch.float16, device_map="auto"
)
@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict, /chat, or /llama_chat endpoints."}
@app.get("/predict")
def predict(image_url: str = Query(...), prompt: str = Query(...)):
    messages = [
        {"role": "system", "content": "You are a helpful assistant with vision abilities."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_url},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    with torch.no_grad():
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}
@app.get("/chat")
def chat(prompt: str = Query(...)):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    with torch.no_grad():
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}
@app.get("/llama_chat")
def llama_chat(prompt: str = Query(...)):
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    with torch.no_grad():
        outputs = llama_model.generate(**inputs, max_new_tokens=128)
    # Strip the echoed prompt tokens, matching the trimming done in the Qwen endpoints.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    response = llama_tokenizer.decode(generated, skip_special_tokens=True)
    return {"response": response}