import gradio as gr
import torch
from PIL import Image
from model import BlipBaseModel, GitBaseCocoModel
MODELS = {
    "Git-Base-COCO": GitBaseCocoModel,
    "Blip Base": BlipBaseModel,
}
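# Sketch of the interface `model.py` is assumed to expose, inferred from how
# these classes are used below (the actual implementations may differ):
#
#     class BlipBaseModel:  # and likewise GitBaseCocoModel
#         def __init__(self, device: str): ...
#         def generate(self, image, max_length, num_captions, temperature,
#                      top_k, top_p, repetition_penalty,
#                      diversity_penalty) -> list[str]: ...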
def generate_captions(
    image,
    num_captions,
    model_name,
    max_length,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    diversity_penalty,
):
"""
Generates captions for the given image.
-----
Parameters:
image: PIL.Image
The image to generate captions for.
num_captions: int
The number of captions to generate.
** Rest of the parameters are the same as in the model.generate method. **
-----
Returns:
list[str]
"""
    # Gradio Sliders return whole-number positions as ints (e.g. 1.0 arrives
    # as 1), so coerce the float-valued parameters back to float before
    # passing them to the model.
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    diversity_penalty = float(diversity_penalty)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MODELS[model_name](device)
captions = model.generate(
image=image,
max_length=max_length,
num_captions=num_captions,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
diversity_penalty=diversity_penalty,
)
# Convert list to a single string separated by newlines.
captions = "\n".join(captions)
return captions
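# Hypothetical direct call, bypassing the Gradio UI (the file name and
# parameter values are illustrative; they mirror the interface defaults below):
#
#     caption = generate_captions(
#         Image.open("Image1.png"), 1, "Blip Base", 50, 1.0, 50, 1.0, 2.0, 2.0
#     )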
title = "AI tool for generating captions for images"
description = "This tool uses pretrained models to generate captions for images."
interface = gr.Interface(
    fn=generate_captions,
    inputs=[
        gr.components.Image(type="pil", label="Image"),
        gr.components.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
        gr.components.Dropdown(list(MODELS.keys()), label="Model", value="Blip Base"),
        gr.components.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
        gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Temperature"),
        gr.components.Slider(minimum=1, maximum=100, step=1, value=50, label="Top K"),
        gr.components.Slider(minimum=0.1, maximum=5.0, step=0.1, value=1.0, label="Top P"),
        gr.components.Slider(minimum=1.0, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"),
        gr.components.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Diversity Penalty"),
    ],
    outputs=[
        gr.components.Textbox(label="Caption"),
    ],
    # Example images, paired with default parameter values, displayed in the interface.
    examples=[
        ["Image1.png", 1, "Blip Base", 50, 1.0, 50, 1.0, 2.0, 2.0],
        ["Image2.png", 1, "Blip Base", 50, 1.0, 50, 1.0, 2.0, 2.0],
        ["Image3.png", 1, "Blip Base", 50, 1.0, 50, 1.0, 2.0, 2.0],
    ],
    title=title,
    description=description,
    allow_flagging="never",
)
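# Note: the example rows above assume Image1.png, Image2.png, and Image3.png
# exist next to this script (e.g. at the root of the Space repository).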
if __name__ == "__main__":
# Launch the interface.
interface.launch(
enable_queue=True,
debug=True,
) |