---
tags:
- generated_from_trainer
datasets:
- coco
metrics:
- rouge
- bleu
model-index:
- name: vit-swin-base-224-gpt2-image-captioning
  results: []
license: mit
language:
- en
pipeline_tag: image-to-text
---

# vit-swin-base-224-gpt2-image-captioning

This model is a [VisionEncoderDecoder](https://huggingface.co./docs/transformers/model_doc/vision-encoder-decoder) model fine-tuned on 60% of the [COCO2014](https://huggingface.co./datasets/HuggingFaceM4/COCO) dataset.

It achieves the following results on the test set:

- Loss: 0.7989
- Rouge1: 53.1153
- Rouge2: 24.2307
- Rougel: 51.5002
- Rougelsum: 51.4983
- Bleu: 17.7765
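
Scores like these are typically computed with the Hugging Face `evaluate` library after generating captions for the held-out split. The snippet below is only a minimal sketch of that workflow; the `predictions`/`references` values are placeholders, not the actual evaluation data:

```python
# Minimal sketch of scoring generated captions with the `evaluate` library.
# The predictions/references below are placeholders for illustration only.
import evaluate

rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

predictions = ["two cows laying in a field with a sky background"]
references = ["two cows lying in a grassy field under a cloudy sky"]

rouge_scores = rouge.compute(predictions=predictions, references=references)
bleu_scores = bleu.compute(predictions=predictions, references=[[r] for r in references])

# scores are in [0, 1]; multiply by 100 to match the scale reported above
print({k: round(v * 100, 4) for k, v in rouge_scores.items()})
print(round(bleu_scores["bleu"] * 100, 4))
```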
## Model description

The model was initialized from [microsoft/swin-base-patch4-window7-224-in22k](https://huggingface.co./microsoft/swin-base-patch4-window7-224-in22k) as the vision encoder and [gpt2](https://huggingface.co./gpt2) as the text decoder.
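
For context, a VisionEncoderDecoder model of this kind is usually assembled from the two pretrained checkpoints with `from_encoder_decoder_pretrained` before fine-tuning. The sketch below shows that standard API; it is illustrative, not necessarily the exact training script:

```python
from transformers import VisionEncoderDecoderModel, GPT2TokenizerFast, ViTImageProcessor

# Pair a pretrained Swin encoder with a pretrained GPT-2 decoder. The decoder's
# cross-attention layers are newly initialized and learned during fine-tuning.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "microsoft/swin-base-patch4-window7-224-in22k", "gpt2"
)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
image_processor = ViTImageProcessor.from_pretrained(
    "microsoft/swin-base-patch4-window7-224-in22k"
)

# GPT-2 has no padding token, so reuse the end-of-sequence token for padding
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
```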
## Intended uses & limitations

You can use this model for image captioning only.

## How to use

You can either use the simple pipeline API:

```python
from transformers import pipeline

image_captioner = pipeline("image-to-text", model="Abdou/vit-swin-base-224-gpt2-image-captioning")

# infer the caption
caption = image_captioner("http://images.cocodataset.org/test-stuff2017/000000000019.jpg")[0]['generated_text']
print(f"caption: {caption}")
```
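
The pipeline also accepts a list of images, which is convenient for captioning several files at once. A minimal sketch, continuing from the pipeline created above (the second URL is just a placeholder):

```python
# Caption a batch of images with the same pipeline object created above.
# The second URL is a placeholder used only for illustration.
urls = [
    "http://images.cocodataset.org/test-stuff2017/000000000019.jpg",
    "http://images.cocodataset.org/test-stuff2017/000000000128.jpg",
]
results = image_captioner(urls)
for url, result in zip(urls, results):
    print(url, "->", result[0]["generated_text"])
```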
Or initialize everything for more flexibility:

```python
from transformers import VisionEncoderDecoderModel, GPT2TokenizerFast, ViTImageProcessor
import torch
import os
import urllib.parse as parse
from PIL import Image
import requests

# a function to determine whether a string is a URL or not
def is_url(string):
    try:
        result = parse.urlparse(string)
        return all([result.scheme, result.netloc, result.path])
    except ValueError:
        return False

# a function to load an image from a URL or a local path
def load_image(image_path):
    if is_url(image_path):
        return Image.open(requests.get(image_path, stream=True).raw)
    elif os.path.exists(image_path):
        return Image.open(image_path)

# a function to perform inference
def get_caption(model, image_processor, tokenizer, image_path):
    image = load_image(image_path)
    # preprocess the image
    img = image_processor(image, return_tensors="pt").to(device)
    # generate the caption (using greedy decoding by default)
    output = model.generate(**img)
    # decode the output
    caption = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return caption

device = "cuda" if torch.cuda.is_available() else "cpu"
# load the fine-tuned image captioning model and corresponding tokenizer and image processor
model = VisionEncoderDecoderModel.from_pretrained("Abdou/vit-swin-base-224-gpt2-image-captioning").to(device)
tokenizer = GPT2TokenizerFast.from_pretrained("Abdou/vit-swin-base-224-gpt2-image-captioning")
image_processor = ViTImageProcessor.from_pretrained("Abdou/vit-swin-base-224-gpt2-image-captioning")

# target image
url = "http://images.cocodataset.org/test-stuff2017/000000000019.jpg"
# get the caption
caption = get_caption(model, image_processor, tokenizer, url)
print(f"caption: {caption}")
```

Output:

```
Two cows laying in a field with a sky background.
```
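
By default, `model.generate` uses greedy decoding. Continuing from the snippet above, you can switch to beam search by passing standard generation arguments; the values below are illustrative, not tuned settings:

```python
# Variation of get_caption that uses beam search instead of greedy decoding.
# num_beams and max_length here are illustrative values, not tuned settings.
def get_caption_beam(model, image_processor, tokenizer, image_path, num_beams=4, max_length=32):
    image = load_image(image_path)
    img = image_processor(image, return_tensors="pt").to(device)
    output = model.generate(**img, num_beams=num_beams, max_length=max_length, early_stopping=True)
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

print(f"caption: {get_caption_beam(model, image_processor, tokenizer, url)}")
```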
## Training procedure

You can check [this guide](https://www.thepythoncode.com/article/image-captioning-with-pytorch-and-transformers-in-python) to learn how this model was fine-tuned.

### Training hyperparameters

The following hyperparameters were used during training:

- learning_rate: 5e-05
- train_batch_size: 64
- eval_batch_size: 64
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 2
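
If you want to reproduce the setup, these hyperparameters map roughly onto `Seq2SeqTrainingArguments` as sketched below. The output directory and the evaluation/saving cadence are assumptions rather than values taken from the original run; the Adam betas and epsilon listed above are the optimizer defaults:

```python
from transformers import Seq2SeqTrainingArguments

# Sketch of training arguments matching the hyperparameters listed above.
# output_dir and the evaluation/saving cadence are assumptions, not values
# taken from the original run.
training_args = Seq2SeqTrainingArguments(
    output_dir="vit-swin-base-224-gpt2-image-captioning",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    seed=42,
    lr_scheduler_type="linear",
    num_train_epochs=2,
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=2000,
    save_steps=2000,
)
```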
### Training results

| Training Loss | Epoch | Step  | Validation Loss | Rouge1  | Rouge2  | Rougel  | Rougelsum | Bleu    | Gen Len |
|:-------------:|:-----:|:-----:|:---------------:|:-------:|:-------:|:-------:|:---------:|:-------:|:-------:|
| 1.0018        | 0.38  | 2000  | 0.8860          | 38.6537 | 13.8145 | 35.3932 | 35.3935   | 8.2448  | 11.2946 |
| 0.8827        | 0.75  | 4000  | 0.8395          | 40.0458 | 14.8829 | 36.5321 | 36.5366   | 9.1169  | 11.2946 |
| 0.8378        | 1.13  | 6000  | 0.8140          | 41.2736 | 15.9576 | 37.5504 | 37.5512   | 9.871   | 11.2946 |
| 0.7913        | 1.51  | 8000  | 0.8012          | 41.6642 | 16.1987 | 37.8786 | 37.8891   | 10.0786 | 11.2946 |
| 0.7794        | 1.89  | 10000 | 0.7933          | 41.9119 | 16.3738 | 38.1062 | 38.1292   | 10.288  | 11.2946 |

Total training time: ~5 hours on an NVIDIA A100 GPU.

### Framework versions

- Transformers 4.26.0
- Pytorch 1.13.1+cu116
- Datasets 2.9.0
- Tokenizers 0.13.2