import functools

import gradio as gr
import torch
from PIL import Image

from model import GitBaseCocoModel, BlipBaseModel

# Registry of selectable captioning models, keyed by the label shown in the
# UI dropdown. Each value is called as ``cls(device)`` to build a model.
MODELS = {
    "Git-Base-COCO": GitBaseCocoModel,
    "Blip Base": BlipBaseModel,
}


@functools.lru_cache(maxsize=1)
def _load_model(model_name, device):
    """Instantiate and cache the model registered under *model_name*.

    Constructing a model wrapper presumably loads its weights, which is too
    expensive to repeat on every request. Only the most recently used model
    is cached (maxsize=1), so at most one cached instance is held at a time.

    NOTE(review): this assumes ``generate`` keeps no per-request state on the
    instance, so reusing one instance across requests is safe — confirm
    against the model wrappers.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``. Failed lookups
            are not cached.
    """
    return MODELS[model_name](device)


def generate_captions(
    image,
    num_captions,
    max_length,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    diversity_penalty,
    model_name,
):
    """
    Generates captions for the given image.
    -----
    Parameters:
    image: PIL.Image
        The image to generate captions for.
    num_captions: int
        The number of captions to generate.
    max_length: int
        The maximum length of each caption.
    temperature, top_p, repetition_penalty, diversity_penalty: float
    top_k: int
        Generation settings forwarded to the model's ``generate`` method
        after int/float coercion. Their exact semantics are defined by the
        model wrapper.
    model_name: str
        Key into ``MODELS`` selecting the captioning model. Falls back to
        the first registered model when empty or None.
    -----
    Returns:
    str
        The generated captions, one per line.
    """
    if not model_name:
        # Guard against an empty selection (e.g. a Dropdown with no default).
        model_name = next(iter(MODELS))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(model_name, device)

    # UI sliders may deliver whole numbers as int or float depending on the
    # front end's JSON encoding. Coerce to the types the generation settings
    # presumably expect: counts/lengths as int, the rest as float.
    captions = model.generate(
        image=image,
        max_length=int(max_length),
        num_captions=int(num_captions),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        diversity_penalty=float(diversity_penalty),
    )

    # A single Textbox output is used, so join the captions with newlines.
    return "\n".join(captions)


title = "Git-Base-COCO Image Captioning"
description = "A model for generating captions for images."
# Gradio UI. The input widgets are listed in the same order as the
# positional parameters of generate_captions, which Gradio relies on when
# calling fn.
interface = gr.Interface(
    fn=generate_captions,
    inputs=[
        gr.inputs.Image(type="pil", label="Image"),
        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Captions to Generate"),
        gr.inputs.Slider(minimum=20, maximum=100, step=5, default=50, label="Maximum Caption Length"),
        gr.inputs.Slider(minimum=0.1, maximum=10.0, step=0.1, default=1.0, label="Temperature"),
        gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top K"),
        # top_p is a cumulative probability mass, so only [0, 1] is
        # meaningful. Nucleus-sampling implementations (e.g. Hugging Face
        # transformers) reject values outside that range.
        gr.inputs.Slider(minimum=0.0, maximum=1.0, step=0.05, default=1.0, label="Top P"),
        gr.inputs.Slider(minimum=1.0, maximum=10.0, step=0.1, default=1.0, label="Repetition Penalty"),
        gr.inputs.Slider(minimum=0.0, maximum=10.0, step=0.1, default=0.0, label="Diversity Penalty"),
        # Pass a real list of choices (not a dict_keys view). Preselect the
        # first registered model so the function never receives None.
        gr.inputs.Dropdown(list(MODELS), default=next(iter(MODELS)), label="Model"),
    ],
    outputs=[
        gr.outputs.Textbox(label="Caption"),
    ],
    title=title,
    description=description,
)

if __name__ == "__main__":
    interface.launch(enable_queue=True, debug=True)