|
import requests |
|
import torch |
|
from transformers import AutoModelForCausalLM |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
"q-future/one-align", |
|
trust_remote_code=True, |
|
attn_implementation="eager", |
|
torch_dtype=torch.float16, |
|
device_map="auto" |
|
) |
|
|
|
def score_image(image, task): |
|
""" |
|
对输入的图像进行评分 |
|
:param image: 输入的图像 |
|
:param task: 任务类型,可以是 "quality" 或 "aesthetics" |
|
:return: 评分结果 |
|
""" |
|
if task not in ["quality", "aesthetics"]: |
|
return "任务类型必须是 'quality' 或 'aesthetics'" |
|
|
|
|
|
if isinstance(image, str): |
|
image = Image.open(requests.get(image, stream=True).raw) |
|
elif isinstance(image, Image.Image): |
|
pass |
|
else: |
|
return "输入必须是图像 URL 或 PIL 图像" |
|
|
|
|
|
result = model.score([image], task_=task, input_="image") |
|
return result |
|
|
|
|
|
iface = gr.Interface( |
|
fn=score_image, |
|
inputs=[ |
|
gr.Image(label="输入图像", type="pil"), |
|
gr.Dropdown(choices=["quality", "aesthetics"], label="任务类型") |
|
], |
|
outputs="text", |
|
title="图像评分模型", |
|
description="上传图像并选择任务类型(quality 或 aesthetics)来获取评分。" |
|
) |
|
|
|
|
|
iface.launch() |