Commit 8576dce
Parent(s): 7edd2b8
Missing self
pipeline.py +3 -3
pipeline.py CHANGED
@@ -17,7 +17,7 @@ class PreTrainedPipeline():
         self.model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
         self.model = blip_decoder(pretrained=self.model_url, image_size=384, vit='large')
         self.model.eval()
-        self.model = model.to(device)
+        self.model = self.model.to(device)
 
         image_size = 384
         self.transform = transforms.Compose([
@@ -43,8 +43,8 @@ class PreTrainedPipeline():
 
         # decode base64 image to PIL
         image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
-        image = transform(image).unsqueeze(0).to(device)
+        image = self.transform(image).unsqueeze(0).to(device)
         with torch.no_grad():
             caption = self.model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
         # postprocess the prediction
-        return caption
+        return caption
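For orientation, below is a minimal sketch of how the fixed lines sit inside the pipeline class. Only the lines shown in the diff above come from this commit; the imports, the `device` definition, the `__init__`/`__call__` scaffolding, and the body of the transform are assumptions modeled on the BLIP caption demo and the usual custom-pipeline layout, not part of the committed change.

```python
import base64
from io import BytesIO

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from models.blip import blip_decoder  # provided by the BLIP repository

# Assumption: device selection is not shown in the diff
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class PreTrainedPipeline():
    def __init__(self, path=""):
        self.model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
        self.model = blip_decoder(pretrained=self.model_url, image_size=384, vit='large')
        self.model.eval()
        self.model = self.model.to(device)  # fixed: was `model.to(device)`

        image_size = 384
        # Assumption: the Compose body is not shown in the diff; these steps
        # follow BLIP's demo preprocessing (bicubic resize + CLIP-style normalization)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, inputs):
        # decode base64 image to PIL
        image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
        image = self.transform(image).unsqueeze(0).to(device)  # fixed: was `transform(image)`
        with torch.no_grad():
            caption = self.model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
        # postprocess the prediction
        return caption
```

Both fixes address the same class of bug named in the commit message: `model` and `transform` are instance attributes, so referencing them without the `self.` prefix raises a `NameError` at runtime.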