sooh-j committed on
Commit
419b4c1
1 Parent(s): b912c2a

Create handler.py

Files changed (1)
  1. handler.py +66 -0
handler.py ADDED
@@ -0,0 +1,66 @@
+ from typing import Any, Dict, List
+ from io import BytesIO
+ import base64
+
+ import requests
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         print("device:", self.device)
+         self.model_base = "Salesforce/blip2-opt-2.7b"
+         self.model_name = "sooh-j/blip2-vizwizqa"
+         self.processor = AutoProcessor.from_pretrained(self.model_name)
+         # device_map="auto" lets accelerate place the weights, so no extra .to() call is needed.
+         self.model = Blip2ForConditionalGeneration.from_pretrained(
+             self.model_name,
+             device_map="auto",
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`dict`):
+                 image (:obj:`str`): an http(s) URL or a base64-encoded image
+                 question (:obj:`str`): the question to ask about the image
+         Return:
+             A :obj:`list` of :obj:`dict`: will be serialized and returned
+         """
+         # Expected client-side call, e.g. with huggingface.js:
+         # await hf.visualQuestionAnswering({
+         #     model: 'dandelin/vilt-b32-finetuned-vqa',
+         #     inputs: {
+         #         question: 'How many cats are lying down?',
+         #         image: await (await fetch('https://placekitten.com/300/300')).blob()
+         #     }
+         # })
+
+         inputs = data.get("inputs")
+         image_input = inputs.get("image")
+         question = inputs.get("question")
+
+         if image_input.startswith(("http:", "https:")):
+             # Fetch the image from a URL.
+             image = Image.open(requests.get(image_input, stream=True).raw)
+         else:
+             # Decode a base64 payload, dropping any "data:image/...;base64," prefix.
+             image = Image.open(BytesIO(base64.b64decode(image_input.split(",")[-1])))
+
+         prompt = f"Question: {question}, Answer:"
+         processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
+
+         with torch.no_grad():
+             out = self.model.generate(
+                 **processed,
+                 max_new_tokens=50,
+                 temperature=0.5,
+                 do_sample=True,
+                 top_k=50,
+                 top_p=0.9,
+                 repetition_penalty=1.2,
+             )
+
+         text_output = self.processor.decode(out[0], skip_special_tokens=True)
+         # No real confidence is available from generate(); report a placeholder score.
+         score = 0
+
+         return [{"answer": text_output, "score": score}]
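
For a quick local sanity check of the handler, a minimal sketch of calling it outside an Inference Endpoint; the payload mirrors the `inputs` dict that `__call__` expects, and the image URL and question here are placeholders taken from the comment in the handler, not part of the committed file:

# hypothetical local test for handler.py
from handler import EndpointHandler

handler = EndpointHandler()
payload = {
    "inputs": {
        # either an http(s) URL or a base64-encoded image string
        "image": "https://placekitten.com/300/300",
        "question": "How many cats are lying down?",
    }
}
print(handler(payload))
# -> [{"answer": "...", "score": 0}]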