stzhao commited on
Commit
4193b64
·
verified ·
1 Parent(s): 10b53ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py CHANGED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import torch
3
+ from transformers import AutoModelForCausalLM
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ # 加载模型
8
+ model = AutoModelForCausalLM.from_pretrained(
9
+ "q-future/one-align",
10
+ trust_remote_code=True,
11
+ attn_implementation="eager",
12
+ torch_dtype=torch.float16,
13
+ device_map="auto"
14
+ )
15
+
16
+ def score_image(image, task):
17
+ """
18
+ 对输入的图像进行评分
19
+ :param image: 输入的图像
20
+ :param task: 任务类型,可以是 "quality" 或 "aesthetics"
21
+ :return: 评分结果
22
+ """
23
+ if task not in ["quality", "aesthetics"]:
24
+ return "任务类型必须是 'quality' 或 'aesthetics'"
25
+
26
+ # 将图像转换为模型所需的格式
27
+ if isinstance(image, str):
28
+ image = Image.open(requests.get(image, stream=True).raw)
29
+ elif isinstance(image, Image.Image):
30
+ pass
31
+ else:
32
+ return "输入必须是图像 URL 或 PIL 图像"
33
+
34
+ # 调用模型进行评分
35
+ result = model.score([image], task_=task, input_="image")
36
+ return result
37
+
38
+ # 创建 Gradio 界面
39
+ iface = gr.Interface(
40
+ fn=score_image,
41
+ inputs=[
42
+ gr.Image(label="输入图像", type="pil"),
43
+ gr.Dropdown(choices=["quality", "aesthetics"], label="任务类型")
44
+ ],
45
+ outputs="text",
46
+ title="图像评分模型",
47
+ description="上传图像并选择任务类型(quality 或 aesthetics)来获取评分。"
48
+ )
49
+
50
+ # 启动 Gradio 应用
51
+ iface.launch()