# Add the files and updates for the UI-TARS-7B-DPO application
# commit e655ddc
from typing import List, Optional, Union, Literal
from fastapi import FastAPI, Body, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image as PILImage
import torch
import base64
import io
import gc
from starlette.responses import FileResponse
app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")
# Initialize model and processor
MODEL_NAME = "bytedance-research/UI-TARS-7B-DPO"
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # Prefer float16 on GPU with low CPU memory usage during loading.
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True
    ).to(device)
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        print("Warning: loading the model in float16 failed due to insufficient GPU memory; falling back to CPU and float32.")
        device = "cpu"
        # Release the partially allocated GPU memory before retrying on CPU.
        gc.collect()
        torch.cuda.empty_cache()
        # Reload in float32 (the default) on the CPU, still with low CPU memory usage.
        model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True).to(device)
    else:
        raise
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Pydantic models
class ImageUrl(BaseModel):
    url: str

class Image(BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: ImageUrl

class Content(BaseModel):
    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    role: Literal["user", "system", "assistant"]
    content: Union[str, List[Content]]

class ChatCompletionRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 128
@app.post("/chat/completions")
async def chat_completion(request: ChatCompletionRequest = Body(...)):
    # Extract the text and optional image from the first message.
    messages = request.messages
    max_tokens = request.max_tokens
    first_message = messages[0]
    image_url = None
    text_content = None
    if isinstance(first_message.content, str):
        text_content = first_message.content
    else:
        for content_item in first_message.content:
            if content_item.type == "image_url":
                image_url = content_item.image_url.url
            elif content_item.type == "text":
                text_content = content_item.text
    # Decode the image if one was provided.
    pil_image = None
    if image_url:
        try:
            if image_url.startswith("data:image"):
                header, encoded = image_url.split(",", 1)
                image_data = base64.b64decode(encoded)
                pil_image = PILImage.open(io.BytesIO(image_data)).convert("RGB")
            else:
                # Only base64-encoded data URLs are supported; remote URLs are not fetched.
                raise HTTPException(status_code=400, detail="Image URL provided, but a base64 data URL is expected.")
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error processing image: {e}")
            raise HTTPException(status_code=400, detail=f"Error processing image: {e}")
    # Generate the response.
    try:
        inputs = processor(text=text_content, images=pil_image, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=max_tokens)
        # The model is decoder-only, so generate() returns prompt + completion;
        # drop the echoed prompt tokens and decode only the new text.
        generated = outputs[:, inputs["input_ids"].shape[-1]:]
        response = processor.batch_decode(generated, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"Error during model inference: {e}")
    return {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": response
            }
        }]
    }
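# For illustration, a request to this endpoint could look like the following
# sketch (the host and port are assumptions for a local run, and the base64
# payload is truncated):
#
#   curl -X POST http://localhost:7860/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": [
#           {"type": "text", "text": "Describe this screenshot."},
#           {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
#         ]}], "max_tokens": 128}'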
@app.get("/")
def index():
    return FileResponse("static/index.html")
@app.on_event("startup")
def startup_event():
    # In Hugging Face Spaces, the application is usually accessible at https://<space_name>.hf.space.
    # Here we assume the Space name is 'api-UI-TARS-7B-DPO'.
    public_url = "https://api-UI-TARS-7B-DPO.hf.space"
    print(f"Public URL: {public_url}")