OpenAI API compatible server
#6
by devops724 - opened
Here is code you can run to serve this model behind an OpenAI API-compatible server:
import torch
import base64
import uvicorn
import argparse
import os
from io import BytesIO
from typing import List, Optional, Union
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
import time
import uuid
# Parse command line arguments
parser = argparse.ArgumentParser(description="OLMoCR API Server")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")
parser.add_argument("--model-id", type=str, default="allenai/olmOCR-7B-0225-preview", help="Model ID to load")
args = parser.parse_args()
# Create FastAPI app
app = FastAPI(
title="OLMoCR API",
    description="OpenAI-compatible REST API for the OLMoCR OCR and document understanding model",
version="0.1.0",
)
# Enable CORS
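# (wide-open origins with credentials are convenient for local testing; tighten for production)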
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define models for API
class ImageUrl(BaseModel):
url: str
class ContentItem(BaseModel):
type: str
text: Optional[str] = None
image_url: Optional[ImageUrl] = None
class Message(BaseModel):
role: str
content: Union[str, List[ContentItem]]
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
max_tokens: int = 300
temperature: float = 0.8
stream: bool = False
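    # Note: stream is accepted for OpenAI API compatibility, but streaming is not implemented below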
class CompletionChoice(BaseModel):
index: int
message: Message
finish_reason: str = "stop"
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str = Field(..., example="chatcmpl-123456789")
object: str = "chat.completion"
created: int
model: str
choices: List[CompletionChoice]
usage: CompletionUsage
# Load the model and processor
print(f"Loading model: {args.model_id}")
device = torch.device(args.device)
model = Qwen2VLForConditionalGeneration.from_pretrained(args.model_id, torch_dtype=torch.bfloat16).eval()
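# olmOCR-7B is fine-tuned from Qwen2-VL-7B-Instruct, so the base model's processor is reused here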
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model.to(device)
print(f"Model loaded on {device}")
# Define routes
@app.get("/")
async def root():
    return {"message": "OLMoCR OpenAI-compatible API server", "status": "running"}
@app.get("/v1/models")
async def get_models():
"""List available models"""
return {
"object": "list",
"data": [
{
"id": args.model_id,
"object": "model",
"created": int(time.time()),
"owned_by": "allenai"
}
]
}
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
"""Process a chat completion request with image inputs"""
try:
# Only support the last message for now
last_message = request.messages[-1]
# Process the content
if isinstance(last_message.content, str):
# If content is just a string, treat it as a text prompt
prompt = last_message.content
image_data = None
else:
# Process the content items
prompt = None
image_data = None
for item in last_message.content:
if item.type == "text":
prompt = item.text
elif item.type == "image_url":
image_url = item.image_url.url
# Handle base64 encoded images
if image_url.startswith("data:image"):
# Extract the base64 part
image_data = image_url.split(",")[1]
# Handle PDF URLs (very basic detection)
elif image_url.endswith(".pdf"):
# Download the PDF and convert to image
import urllib.request
import tempfile
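                        # delete=False keeps the file on disk for the helpers below; it is cleaned up with os.unlink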
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
urllib.request.urlretrieve(image_url, tmp_file.name)
# Convert PDF to image
image_data = render_pdf_to_base64png(tmp_file.name, 1, target_longest_image_dim=1024)
# Build anchor text if no prompt given
if not prompt:
anchor_text = get_anchor_text(tmp_file.name, 1, pdf_engine="pdfreport", target_length=4000)
prompt = build_finetuning_prompt(anchor_text)
# Clean up
os.unlink(tmp_file.name)
else:
# Download the image
import urllib.request
with urllib.request.urlopen(image_url) as response:
image_data = base64.b64encode(response.read()).decode("utf-8")
if not image_data:
raise HTTPException(status_code=400, detail="No image data provided")
if not prompt:
prompt = "Extract and describe the text content from this image."
# Process the input for the model
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}},
],
}
]
# Apply chat template
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Process the image
main_image = Image.open(BytesIO(base64.b64decode(image_data)))
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = {key: value.to(device) for (key, value) in inputs.items()}
# Count input tokens
input_token_count = inputs["input_ids"].shape[1]
# Generate output
start_time = time.time()
output = model.generate(
**inputs,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
num_return_sequences=1,
do_sample=True,
)
end_time = time.time()
        # Decode only the newly generated tokens (everything after the prompt)
        new_tokens = output[:, input_token_count:]
text_output = processor.tokenizer.batch_decode(
new_tokens, skip_special_tokens=True
)[0]
# Count output tokens
output_token_count = new_tokens.shape[1]
# Create response
completion_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
response = ChatCompletionResponse(
id=completion_id,
object="chat.completion",
created=int(time.time()),
model=request.model,
choices=[
CompletionChoice(
index=0,
message=Message(
role="assistant",
content=text_output
),
finish_reason="stop"
)
],
usage=CompletionUsage(
prompt_tokens=input_token_count,
completion_tokens=output_token_count,
total_tokens=input_token_count + output_token_count
)
)
print(f"Request processed in {end_time - start_time:.2f}s, generated {output_token_count} tokens")
return response
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) propagate instead of being rewrapped as 500s
        raise
    except Exception as e:
        print(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
print(f"Starting OLMoCR API server on {args.host}:{args.port}")
uvicorn.run(app, host=args.host, port=args.port)
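For reference, here is a minimal client sketch for exercising the server above; it assumes the server is running locally on the default port 8000, that the requests package is available, and that test.png stands in for whatever image you want to OCR.

import base64
import requests

# Encode a local image as a data URL, matching the format the endpoint expects
with open("test.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "allenai/olmOCR-7B-0225-preview",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Extract the text from this image."},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
            ],
        }
    ],
    "max_tokens": 512,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

Because the request and response shapes mirror the OpenAI chat completions schema, the official openai client should also work when pointed at base_url http://localhost:8000/v1.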