fffiloni committed on
Commit
a6ba22e
·
1 Parent(s): fa91213

Create img_cap.py

Browse files
Files changed (1) hide show
  1. tasks/img_cap.py +54 -0
tasks/img_cap.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import cv2
9
+ import torch
10
+ import numpy as np
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+
14
+
15
# Resize pipeline applied to the image before it is handed to the model:
# an int size makes torchvision scale the shorter edge to 224 px (bicubic).
transform = transforms.Compose(
    [transforms.Resize(224, interpolation=Image.BICUBIC)]
)

# Resize pipeline for the preview that the caption is drawn onto:
# shorter edge scaled to 512 px (bicubic).
transform_v = transforms.Compose(
    [transforms.Resize(512, interpolation=Image.BICUBIC)]
)
22
+
23
def image_captioning(model, image, texts, inpainting_text, *args, **kwargs):
    """Caption *image* with *model* and draw the caption on a preview image.

    Args:
        model: Object exposing ``model.model.evaluate_captioning``; it is
            called with a one-element list holding the batch dict.
        image: PIL image to caption. It is converted to RGB first.
        texts: Unused by this function.
        inpainting_text: Unused by this function.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        A 3-tuple ``(preview, caption, None)``: ``preview`` is a PIL image
        (the ``transform_v`` output with a black banner and the caption
        rendered along its bottom edge) and ``caption`` is the
        ``'captioning_text'`` entry of the last model output.
    """
    # Force three channels: a grayscale input would yield a 2-D array and
    # break the HWC -> CHW permute below, and an RGBA input would hand the
    # model a 4-channel tensor.
    image = image.convert("RGB")

    with torch.no_grad():
        preview = transform_v(image)
        width, height = preview.size
        # np.array copies, so the canvas is writable for the cv2 calls below.
        canvas = np.array(preview)

        model_pixels = np.asarray(transform(image))
        images = torch.from_numpy(model_pixels.copy()).permute(2, 0, 1)
        # Fall back to CPU instead of raising on hosts without CUDA.
        # NOTE(review): assumes the model sits on the GPU whenever CUDA is
        # available, as the original unconditional .cuda() did.
        if torch.cuda.is_available():
            images = images.cuda()

        # 'height'/'width' are the preview's dimensions, not those of the
        # smaller tensor passed as 'image'.
        batch_inputs = [{'image': images, 'height': height, 'width': width, 'image_id': 0}]
        outputs = model.model.evaluate_captioning(batch_inputs)
        text = outputs[-1]['captioning_text']

    # Solid black banner across the bottom 60 px to host the caption.
    cv2.rectangle(canvas, (0, height - 60), (width, height), (0, 0, 0), -1)

    font = cv2.FONT_HERSHEY_DUPLEX
    font_scale = 1.2
    thickness = 2
    line_type = 2
    margin = 10
    font_color = [255, 255, 255]

    # Shrink the font when the caption would run past the right edge;
    # at the fixed default scale longer captions were silently clipped.
    (text_width, _), _ = cv2.getTextSize(text, font, font_scale, thickness)
    available = width - 2 * margin
    if text_width > available > 0:
        font_scale *= available / text_width

    cv2.putText(
        canvas,
        text,
        (margin, height - 20),
        font,
        font_scale,
        font_color,
        thickness,
        line_type,
    )

    torch.cuda.empty_cache()
    return Image.fromarray(canvas), text, None