Paul DAMPFHOEFFER commited on
Commit
cbe05b5
·
1 Parent(s): cd782ec

fix: small fix

Browse files
Files changed (1) hide show
  1. app.py +30 -10
app.py CHANGED
@@ -13,17 +13,23 @@ def greet_json():
13
 
14
  @app.post("/")
15
  async def aria_image_to_text(request: Request):
 
16
  data = await request.json()
 
17
  image_url = data.get("image_url")
 
 
 
18
  image = Image.open(requests.get(image_url, stream=True).raw)
19
-
20
  model_id_or_path = "rhymes-ai/Aria"
 
21
  model = AriaForConditionalGeneration.from_pretrained(
22
  model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
23
  )
24
-
25
  processor = AriaProcessor.from_pretrained(model_id_or_path)
26
-
27
  messages = [
28
  {
29
  "role": "user",
@@ -34,11 +40,15 @@ async def aria_image_to_text(request: Request):
34
  }
35
  ]
36
 
 
37
  text = processor.apply_chat_template(messages, add_generation_prompt=True)
 
38
  inputs = processor(text=text, images=image, return_tensors="pt")
 
39
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
 
40
  inputs.to(model.device)
41
-
42
  output = model.generate(
43
  **inputs,
44
  max_new_tokens=15,
@@ -47,21 +57,26 @@ async def aria_image_to_text(request: Request):
47
  do_sample=True,
48
  temperature=0.9,
49
  )
 
50
  output_ids = output[0][inputs["input_ids"].shape[1]:]
 
51
  response = processor.decode(output_ids, skip_special_tokens=True)
 
52
  return {"response": response}
53
 
54
  @app.get("/aria-test")
55
  def aria_test():
 
56
  model_id_or_path = "rhymes-ai/Aria"
 
57
  model = AriaForConditionalGeneration.from_pretrained(
58
  model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
59
  )
60
-
61
  processor = AriaProcessor.from_pretrained(model_id_or_path)
62
-
63
  image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
64
-
65
  messages = [
66
  {
67
  "role": "user",
@@ -71,12 +86,15 @@ def aria_test():
71
  ],
72
  }
73
  ]
74
-
75
  text = processor.apply_chat_template(messages, add_generation_prompt=True)
 
76
  inputs = processor(text=text, images=image, return_tensors="pt")
 
77
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
 
78
  inputs.to(model.device)
79
-
80
  output = model.generate(
81
  **inputs,
82
  max_new_tokens=15,
@@ -86,5 +104,7 @@ def aria_test():
86
  temperature=0.9,
87
  )
88
  output_ids = output[0][inputs["input_ids"].shape[1]:]
 
89
  response = processor.decode(output_ids, skip_special_tokens=True)
90
- return {"response": response}
 
 
13
 
14
  @app.post("/")
15
  async def aria_image_to_text(request: Request):
16
+ print(1)
17
  data = await request.json()
18
+ print(2)
19
  image_url = data.get("image_url")
20
+ print(3)
21
+ print('image_url')
22
+ print(image_url)
23
  image = Image.open(requests.get(image_url, stream=True).raw)
24
+ print(4)
25
  model_id_or_path = "rhymes-ai/Aria"
26
+ print(5)
27
  model = AriaForConditionalGeneration.from_pretrained(
28
  model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
29
  )
30
+ print(6)
31
  processor = AriaProcessor.from_pretrained(model_id_or_path)
32
+ print(7)
33
  messages = [
34
  {
35
  "role": "user",
 
40
  }
41
  ]
42
 
43
+ print(8)
44
  text = processor.apply_chat_template(messages, add_generation_prompt=True)
45
+ print(9)
46
  inputs = processor(text=text, images=image, return_tensors="pt")
47
+ print(10)
48
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
49
+ print(11)
50
  inputs.to(model.device)
51
+ print(12)
52
  output = model.generate(
53
  **inputs,
54
  max_new_tokens=15,
 
57
  do_sample=True,
58
  temperature=0.9,
59
  )
60
+ print(13)
61
  output_ids = output[0][inputs["input_ids"].shape[1]:]
62
+ print(14)
63
  response = processor.decode(output_ids, skip_special_tokens=True)
64
+ print(15)
65
  return {"response": response}
66
 
67
  @app.get("/aria-test")
68
  def aria_test():
69
+ print(1)
70
  model_id_or_path = "rhymes-ai/Aria"
71
+ print(2)
72
  model = AriaForConditionalGeneration.from_pretrained(
73
  model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
74
  )
75
+ print(3)
76
  processor = AriaProcessor.from_pretrained(model_id_or_path)
77
+ print(4)
78
  image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
79
+ print(5)
80
  messages = [
81
  {
82
  "role": "user",
 
86
  ],
87
  }
88
  ]
89
+ print(6)
90
  text = processor.apply_chat_template(messages, add_generation_prompt=True)
91
+ print(7)
92
  inputs = processor(text=text, images=image, return_tensors="pt")
93
+ print(8)
94
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
95
+ print(9)
96
  inputs.to(model.device)
97
+ print(10)
98
  output = model.generate(
99
  **inputs,
100
  max_new_tokens=15,
 
104
  temperature=0.9,
105
  )
106
  output_ids = output[0][inputs["input_ids"].shape[1]:]
107
+ print(11)
108
  response = processor.decode(output_ids, skip_special_tokens=True)
109
+ print(12)
110
+ return {"response": response}