from typing import List, Optional, Union, Literal
from fastapi import FastAPI, Body
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image as PILImage
import torch
import base64
import gc
import io
import os
from starlette.responses import FileResponse

app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")
# Initialize model and processor
MODEL_NAME = "bytedance-research/UI-TARS-7B-DPO"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Prefer float16 to halve the memory footprint; low_cpu_mem_usage avoids a
    # full extra copy of the weights in host RAM while loading
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True
    ).to(device)
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        print("Warning: loading the model in float16 failed due to insufficient GPU memory. Falling back to CPU and float32.")
        device = "cpu"
        # Reload in the default float32 dtype on CPU
        model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True).to(device)
        # Release whatever the failed GPU attempt left allocated
        gc.collect()
        torch.cuda.empty_cache()
    else:
        raise

processor = AutoProcessor.from_pretrained(MODEL_NAME)
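
# A quick smoke-test sketch (commented out; "test.png" is a hypothetical local
# file and the prompt is only an example) exercising the same inference path
# the endpoint below uses:
#
# img = PILImage.open("test.png").convert("RGB")
# batch = processor(text="Describe this screenshot.", images=img, return_tensors="pt").to(device)
# out = model.generate(**batch, max_new_tokens=32)
# print(processor.batch_decode(out, skip_special_tokens=True)[0])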
# Pydantic models for the OpenAI-style request schema
class ImageUrl(BaseModel):
    url: str

class Image(BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: ImageUrl

class Content(BaseModel):
    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    role: Literal["user", "system", "assistant"]
    content: Union[str, List[Content]]

class ChatCompletionRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 128
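
# For illustration, a request body matching the models above might look like
# this (the base64 payload is truncated; any data:image/... URL works):
#
# {
#   "messages": [
#     {
#       "role": "user",
#       "content": [
#         {"type": "text", "text": "Describe this screenshot."},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}}
#       ]
#     }
#   ],
#   "max_tokens": 128
# }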

@app.post("/v1/chat/completions")  # path assumed: OpenAI-compatible, matching the request/response shape
async def chat_completion(request: ChatCompletionRequest = Body(...)):
    # Only the first message is inspected, for its text and an optional image
    messages = request.messages
    max_tokens = request.max_tokens
    first_message = messages[0]
    image_url = None
    text_content = None
    if isinstance(first_message.content, str):
        text_content = first_message.content
    else:
        for content_item in first_message.content:
            if content_item.type == "image_url":
                image_url = content_item.image_url.url
            elif content_item.type == "text":
                text_content = content_item.text
    # Decode the image if one was provided as a base64 data URL
    pil_image = None
    if image_url:
        try:
            if image_url.startswith("data:image"):
                header, encoded = image_url.split(",", 1)
                image_data = base64.b64decode(encoded)
                pil_image = PILImage.open(io.BytesIO(image_data)).convert("RGB")
            else:
                print("Image URL provided, but a base64 data URL was expected.")
        except Exception as e:
            print(f"Error processing image: {e}")
            raise
    # Run inference and decode the generated tokens
    try:
        inputs = processor(text=text_content, images=pil_image, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=max_tokens)
        response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise
    return {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": response
            }
        }]
    }

@app.get("/")  # route assumed; serves the static landing page
def index():
    return FileResponse("static/index.html")

@app.on_event("startup")  # runs once when the server starts
def startup_event():
    # On Hugging Face Spaces the app is served at https://<space-name>.hf.space;
    # the Space name is assumed to be 'api-UI-TARS-7B-DPO' here
    public_url = "https://api-UI-TARS-7B-DPO.hf.space"
    print(f"Public URL: {public_url}")