OpenAI API-compatible server

#6
by devops724 - opened

Here is code you can use to run this model as an OpenAI API-compatible server:

import torch
import base64
import uvicorn
import json
import argparse
import os
from io import BytesIO
from typing import List, Dict, Any, Optional, Union
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
import time
import uuid

# Command-line options: where the server listens, which device runs the
# model, and which checkpoint is loaded. Parsed once at import time.
parser = argparse.ArgumentParser(description="OLMoCR API Server")
parser.add_argument(
    "--host",
    type=str,
    default="0.0.0.0",
    help="Host to bind the server to",
)
parser.add_argument(
    "--port",
    type=int,
    default=8000,
    help="Port to bind the server to",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="Device to use (cuda or cpu)",
)
parser.add_argument(
    "--model-id",
    type=str,
    default="allenai/olmOCR-7B-0225-preview",
    help="Model ID to load",
)
args = parser.parse_args()

# Create FastAPI app
# The title/description/version below are passed to FastAPI as the
# application's metadata.
app = FastAPI(
    title="OLMoCR API",
    description="OpenAPI-compatible REST API for OLMoCR OCR and document understanding model",
    version="0.1.0",
)

# Enable CORS
# Every origin, method and header is allowed so that browser-based clients
# hosted anywhere can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; this server defines no
# cookie/auth handling in this file, so credentials are presumably not
# needed -- confirm and consider allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define models for API
class ImageUrl(BaseModel):
    # Wrapper for the image reference inside an "image_url" content item.
    # The chat-completions handler in this file accepts a
    # "data:image...;base64," URI, a URL ending in ".pdf", or any other URL
    # that it downloads as an image.
    url: str

class ContentItem(BaseModel):
    # One part of a multi-part message. The chat-completions handler in this
    # file only acts on type == "text" (reads `text`) and
    # type == "image_url" (reads `image_url.url`); other types are skipped.
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    # A single chat message. `content` is either a plain string or a list of
    # typed content items (text and/or image references).
    role: str
    content: Union[str, List[ContentItem]]

class ChatCompletionRequest(BaseModel):
    # Request body for POST /v1/chat/completions.
    #
    # model:       echoed back in the response; the server always runs the
    #              model it was started with.
    # messages:    conversation; the handler only looks at the last entry.
    # max_tokens:  forwarded to generation as max_new_tokens. Must be > 0 so
    #              an invalid value is rejected at validation time instead of
    #              failing inside model.generate().
    # temperature: sampling temperature. Must be >= 0; negative values are
    #              rejected at validation time for the same reason.
    # stream:      accepted for client compatibility; the handler in this
    #              file does not read it.
    model: str
    messages: List[Message]
    max_tokens: int = Field(300, gt=0)
    temperature: float = Field(0.8, ge=0.0)
    stream: bool = False
    
class CompletionChoice(BaseModel):
    # One generated completion: its position in `choices`, the assistant
    # message that was produced, and why generation ended (defaults to
    # "stop").
    index: int
    message: Message
    finish_reason: str = "stop"

class CompletionUsage(BaseModel):
    # Token accounting for one request. The handler in this file fills
    # total_tokens with prompt_tokens + completion_tokens.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class ChatCompletionResponse(BaseModel):
    # Response body for POST /v1/chat/completions.
    # `id` is required (no default); `object` defaults to "chat.completion";
    # `created` is an integer timestamp (the handler uses int(time.time())).
    id: str = Field(..., example="chatcmpl-123456789")
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: CompletionUsage

# Load the checkpoint selected on the command line in bfloat16, switch it to
# inference mode, load the processor (tokenizer + image preprocessing) and
# finally move the weights to the requested device.
print(f"Loading model: {args.model_id}")
device = torch.device(args.device)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    args.model_id,
    torch_dtype=torch.bfloat16,
)
model.eval()
# The processor is always taken from the base Qwen2-VL instruct checkpoint,
# independent of --model-id.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model.to(device)
print(f"Model loaded on {device}")

# Define routes


@app.get("/")
async def root():
    """Return a small static payload showing that the server is up."""
    return {"message": "OLMoCR OpenAPI Server", "status": "running"}



@app.get("/v1/models")
async def get_models():
    """List available models.

    Returns a list-shaped payload with exactly one entry: the model ID the
    server was started with. `created` is the time of the call, not the
    time the model was loaded.
    """
    return {
        "object": "list",
        "data": [
            {
                "id": args.model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "allenai"
            }
        ]
    }



@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    """Process a chat completion request with image inputs.

    Only the last message of ``request.messages`` is used. Its content may be
    a plain string (text prompt only) or a list of content items, of which
    ``text`` items set the prompt and ``image_url`` items supply the image.
    Supported image references:

    * ``data:image...`` URIs carrying base64 data,
    * URLs whose path ends in ``.pdf`` (page 1 is rendered to a PNG and, if
      no prompt has been seen yet, an anchor-text prompt is built from it),
    * any other URL, which is downloaded and treated as an image.

    If several text or image items are present, the last one of each wins.
    ``request.stream`` is not consulted; the full completion is always
    returned in one response. ``raw_request`` is accepted but unused.

    Args:
        request: Parsed chat-completion request body.
        raw_request: The underlying HTTP request (unused).

    Returns:
        ChatCompletionResponse with a single assistant choice and token usage.

    Raises:
        HTTPException: 400 when the request has no messages, an
            ``image_url`` item lacks its ``image_url`` field, or no image
            data could be obtained; 500 for any other failure.
    """
    # Function-scope imports, as in the rest of this handler's dependencies.
    import tempfile
    import urllib.parse
    import urllib.request

    try:
        if not request.messages:
            raise HTTPException(status_code=400, detail="No messages provided")

        # Only support the last message for now
        last_message = request.messages[-1]

        prompt = None
        image_data = None

        if isinstance(last_message.content, str):
            # If content is just a string, treat it as a text prompt
            prompt = last_message.content
        else:
            for item in last_message.content:
                if item.type == "text":
                    prompt = item.text
                elif item.type == "image_url":
                    if item.image_url is None:
                        raise HTTPException(
                            status_code=400,
                            detail="image_url content item is missing its image_url field",
                        )
                    image_url = item.image_url.url

                    if image_url.startswith("data:image"):
                        # Everything after the first comma is the base64 payload.
                        # A URI without a comma yields "" and is rejected below.
                        _, _, image_data = image_url.partition(",")
                    elif urllib.parse.urlparse(image_url).path.lower().endswith(".pdf"):
                        # Look at the URL path only, so "doc.pdf?token=..." and
                        # "DOC.PDF" are still recognised as PDFs.
                        # NOTE(security): the URL is client-supplied and fetched
                        # server-side with no scheme or host restriction.
                        tmp_file = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
                        # Close our handle first; urlretrieve reopens the path.
                        tmp_file.close()
                        try:
                            urllib.request.urlretrieve(image_url, tmp_file.name)
                            # Convert PDF to image
                            image_data = render_pdf_to_base64png(
                                tmp_file.name, 1, target_longest_image_dim=1024
                            )

                            # Build anchor text if no prompt given
                            if not prompt:
                                anchor_text = get_anchor_text(
                                    tmp_file.name, 1, pdf_engine="pdfreport", target_length=4000
                                )
                                prompt = build_finetuning_prompt(anchor_text)
                        finally:
                            # Remove the temp file even when download/render fails.
                            os.unlink(tmp_file.name)
                    else:
                        # Download the image
                        # NOTE(security): same unrestricted server-side fetch as above.
                        with urllib.request.urlopen(image_url) as response:
                            image_data = base64.b64encode(response.read()).decode("utf-8")

        if not image_data:
            raise HTTPException(status_code=400, detail="No image data provided")

        if not prompt:
            prompt = "Extract and describe the text content from this image."

        # Process the input for the model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}},
                ],
            }
        ]

        # Apply chat template
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Process the image
        main_image = Image.open(BytesIO(base64.b64decode(image_data)))
        inputs = processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="pt",
        )
        inputs = {key: value.to(device) for (key, value) in inputs.items()}

        # Prompt length in tokens; used both for usage accounting and to
        # slice the prompt off the generated sequence.
        input_token_count = inputs["input_ids"].shape[1]

        # temperature == 0 means deterministic decoding: sampling with a zero
        # temperature is rejected by generate(), so fall back to greedy.
        generation_kwargs = {
            "max_new_tokens": request.max_tokens,
            "num_return_sequences": 1,
        }
        if request.temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = request.temperature
        else:
            generation_kwargs["do_sample"] = False

        # NOTE: generate() is a synchronous call inside an async handler, so
        # other requests wait while it runs.
        start_time = time.time()
        output = model.generate(**inputs, **generation_kwargs)
        end_time = time.time()

        # Decode only the newly generated tokens
        new_tokens = output[:, input_token_count:]
        text_output = processor.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )[0]

        # Count output tokens
        output_token_count = new_tokens.shape[1]

        # Heuristic: reaching the token budget is reported as truncation so
        # clients can tell a cut-off page from a complete one.
        finish_reason = "length" if output_token_count >= request.max_tokens else "stop"

        # Create response
        completion_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
        response = ChatCompletionResponse(
            id=completion_id,
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[
                CompletionChoice(
                    index=0,
                    message=Message(
                        role="assistant",
                        content=text_output
                    ),
                    finish_reason=finish_reason
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=input_token_count,
                completion_tokens=output_token_count,
                total_tokens=input_token_count + output_token_count
            )
        )

        print(f"Request processed in {end_time - start_time:.2f}s, generated {output_token_count} tokens")
        return response

    except HTTPException:
        # Deliberate client errors (4xx) must keep their status code.
        raise
    except Exception as e:
        print(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

# Script entry point: announce the bind address, then hand the app to uvicorn.
if __name__ == "__main__":
    print(f"Starting OLMoCR API server on {args.host}:{args.port}")
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
    )

Sign up or log in to comment