File size: 1,969 Bytes
61d1cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Image captioning on Modal with the nlpconnect/vit-gpt2-image-captioning model:
# https://huggingface.co/nlpconnect/vit-gpt2-image-captioning

import urllib.request
import modal

# Modal app that the functions and the local entrypoint in this file are
# registered on (via @stub.function / @stub.local_entrypoint).
stub = modal.Stub("vit-gpt2-image-captioning")
# Shared volume persisted under the label "shared_vol"; `predict` mounts it at
# /root/model_cache. NOTE(review): presumably meant to hold the Hugging Face
# model cache across runs — confirm the weights are actually written there.
volume = modal.SharedVolume().persist("shared_vol")

@stub.function(
    gpu="any",
    image=modal.Image.debian_slim().pip_install("Pillow", "transformers", "torch"),
    shared_volumes={"/root/model_cache": volume},
    retries=3,
)
def predict(image, max_length=16, num_beams=4):
    """Generate a caption for one encoded image.

    Runs in a Modal container (GPU requested via ``gpu="any"``; falls back to
    CPU if CUDA is unavailable). The model, image processor and tokenizer are
    loaded from the Hugging Face Hub into the shared volume mounted at
    ``/root/model_cache``.

    Args:
        image: Raw bytes of an encoded image file that Pillow can decode
            (e.g. PNG, JPEG).
        max_length: Maximum length, in tokens, of the generated caption.
        num_beams: Beam width used by beam-search decoding.

    Returns:
        A list of caption strings (one per input image, so a single element
        here), with surrounding whitespace stripped.
    """
    import io

    import torch
    from PIL import Image
    from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor

    model_id = "nlpconnect/vit-gpt2-image-captioning"
    # Must match the shared-volume mount point in the decorator. Without an
    # explicit cache_dir, transformers uses its default cache (under ~/.cache
    # unless HF_HOME/TRANSFORMERS_CACHE say otherwise), which is outside the
    # volume, so the weights would be re-downloaded on every cold start.
    cache_dir = "/root/model_cache"

    model = VisionEncoderDecoderModel.from_pretrained(model_id, cache_dir=cache_dir)
    feature_extractor = ViTImageProcessor.from_pretrained(model_id, cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    with Image.open(io.BytesIO(image)) as input_img:
        # The ViT processor expects 3-channel RGB input; PNGs are frequently
        # RGBA or palette-based and some JPEGs are grayscale.
        rgb_img = input_img.convert("RGB")
    pixel_values = feature_extractor(images=[rgb_img], return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]


@stub.local_entrypoint()
def main():
    """Caption a sample image with the remote ``predict`` function.

    Uses ``sample.png`` next to this file when it exists; otherwise downloads
    a sample image. Prints the first caption returned by ``predict``.

    Raises:
        SystemExit: If the sample image must be downloaded and the download
            fails, since there is then nothing to caption.
    """
    from pathlib import Path
    from urllib.error import URLError

    image_filepath = Path(__file__).parent / "sample.png"
    if image_filepath.exists():
        image = image_filepath.read_bytes()
    else:
        try:
            with urllib.request.urlopen(
                "https://drive.google.com/uc?id=0B0TjveMhQDhgLTlpOENiOTZ6Y00&export=download"
            ) as response:
                image = response.read()
        except URLError as e:
            # Stop here instead of continuing with `image` unbound.
            raise SystemExit(f"Failed to download sample image: {e.reason}") from e
    print(predict.call(image)[0])