File size: 1,294 Bytes
0d08077
 
 
 
 
 
 
bc65b96
 
 
0d08077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e7d5a4
 
0d08077
 
 
 
 
2e7d5a4
0d08077
bc65b96
 
 
 
 
 
 
 
0d08077
 
2e7d5a4
 
 
 
 
 
8ca734f
2e7d5a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from PIL import Image

from model import GitBaseCocoModel


# Models already constructed in this process, keyed by (device, checkpoint),
# so the captioning model is built once instead of on every request.
_MODEL_CACHE = {}


def _select_device():
	"""Return "cuda" when PyTorch reports a usable GPU, otherwise "cpu"."""
	try:
		# Imported locally so this module does not gain a hard top-level
		# dependency on torch.
		import torch
	except ImportError:
		return "cpu"
	return "cuda" if torch.cuda.is_available() else "cpu"


def _get_model(device, checkpoint):
	"""Return the cached model for (device, checkpoint), building it on first use."""
	key = (device, checkpoint)
	if key not in _MODEL_CACHE:
		_MODEL_CACHE[key] = GitBaseCocoModel(device, checkpoint)
	return _MODEL_CACHE[key]


def generate_captions(
	image,
	max_len,
	num_captions,
	):
	"""
	Generates captions for the given image.

	-----
	Parameters:
	image: PIL.Image
		The image to generate captions for.
	max_len: int
		The maximum length of the caption.
	num_captions: int
		The number of captions to generate.

	-----
	Returns:
	str
		The generated captions joined into one string, one caption per line.
	"""

	# The previous `gradio.use_gpu` lookup referenced an unbound name (the
	# package is imported as `gr`) and raised NameError on every call.
	device = _select_device()
	checkpoint = "microsoft/git-base-coco"

	model = _get_model(device, checkpoint)

	captions = model.generate(image, max_len, num_captions)
	# The output textbox shows a single string, so put each caption on its own line.
	return "\n".join(captions)

# Heading and blurb passed to the interface below.
title = "Git-Base-COCO Image Captioning"
description = "A model for generating captions for images."

# Components come from the top-level `gr` namespace: the `gr.inputs` /
# `gr.outputs` namespaces are deprecated since Gradio 3.0, and a slider's
# starting position is given with `value` rather than `default`.
# NOTE(review): assumes Gradio >= 3.0 is installed -- confirm against the
# project's pinned version.
interface = gr.Interface(
	fn=generate_captions,
	inputs=[
		gr.Image(type="pil", label="Image"),
		gr.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
		gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
	],
	outputs=[
		gr.Textbox(label="Caption"),
	],
	title=title,
	description=description,
	)


if __name__ == "__main__":
	# Script entry point: start the demo with queueing and debug mode switched on.
	# NOTE(review): newer Gradio releases appear to replace the `enable_queue`
	# launch flag with `.queue()` -- confirm against the installed version.
	launch_options = {"enable_queue": True, "debug": True}
	interface.launch(**launch_options)