yuukicammy committed on
Commit
61d1cb6
·
1 Parent(s): 53eb83f

Simplify the names.

Browse files
vit_gpt2_image_caption.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
2
+
3
+ import urllib.request
4
+ import modal
5
+
6
# Modal app object; the remote `predict` function and the local entrypoint are registered on it.
+ stub = modal.Stub("vit-gpt2-image-captioning")
7
# Persisted shared volume named "shared_vol"; `predict` mounts it at /root/model_cache.
+ volume = modal.SharedVolume().persist("shared_vol")
8
+
9
@stub.function(
    gpu="any",
    image=modal.Image.debian_slim().pip_install("Pillow", "transformers", "torch"),
    shared_volumes={"/root/model_cache": volume},
    retries=3,
)
def predict(image):
    """Generate a caption for one encoded image with the ViT-GPT2 model.

    Args:
        image: Raw bytes of an encoded image file (any format Pillow can open).

    Returns:
        A list containing one whitespace-stripped caption string.
    """
    import io

    import torch
    from PIL import Image
    from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor

    model_name = "nlpconnect/vit-gpt2-image-captioning"
    # Must match the mount point in the decorator above. Without an explicit
    # cache_dir, transformers downloads into its default cache directory and
    # the persisted shared volume is never used.
    cache_dir = "/root/model_cache"

    model = VisionEncoderDecoderModel.from_pretrained(model_name, cache_dir=cache_dir)
    feature_extractor = ViTImageProcessor.from_pretrained(
        model_name, cache_dir=cache_dir
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    gen_kwargs = {"max_length": 16, "num_beams": 4}

    input_img = Image.open(io.BytesIO(image))
    # PNGs are frequently RGBA, palette, or grayscale; the model card converts
    # every input to RGB before feature extraction, so do the same here.
    if input_img.mode != "RGB":
        input_img = input_img.convert(mode="RGB")

    pixel_values = feature_extractor(
        images=[input_img], return_tensors="pt"
    ).pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [caption.strip() for caption in captions]
46
+
47
+
48
@stub.local_entrypoint()
def main():
    """Caption a sample image and print the result.

    Reads ``sample.png`` from the directory containing this file when it
    exists; otherwise downloads a sample image. If the download fails, the
    error reason is printed and the function returns without calling the
    remote ``predict`` function.
    """
    # Imported explicitly: the handler below references urllib.error, which the
    # original only reached as a side effect of importing urllib.request.
    import urllib.error
    import urllib.request
    from pathlib import Path

    sample_url = (
        "https://drive.google.com/uc?id=0B0TjveMhQDhgLTlpOENiOTZ6Y00&export=download"
    )

    image_filepath = Path(__file__).parent / "sample.png"
    if image_filepath.exists():
        image = image_filepath.read_bytes()
    else:
        try:
            # The context manager closes the HTTP response once it is read.
            with urllib.request.urlopen(sample_url) as response:
                image = response.read()
        except urllib.error.URLError as e:
            print(e.reason)
            # Stop here: there is no image to caption, and falling through
            # would fail on an unbound `image` variable.
            return

    print(predict.call(image)[0])
vit_gpt2_image_caption_webapp.py CHANGED
@@ -11,7 +11,7 @@ web_app = fastapi.FastAPI()
11
 
12
  @web_app.post("/parse")
13
  async def parse(request: fastapi.Request):
14
- predict_step = Function.lookup("vit-gpt2-image-captioning", "predict_step")
15
 
16
  form = await request.form()
17
  image = await form["image"].read() # type: ignore
 
11
 
12
  @web_app.post("/parse")
13
  async def parse(request: fastapi.Request):
14
+ predict_step = Function.lookup("vit-gpt2-image-caption", "predict")
15
 
16
  form = await request.form()
17
  image = await form["image"].read() # type: ignore