{ "cells": [ { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7862\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import AutoProcessor, AutoModelForCausalLM\n", "import gradio as gr\n", "import torch\n", "\n", "# Load the processor and model\n", "processor = AutoProcessor.from_pretrained(\"microsoft/git-base\")\n", "model = AutoModelForCausalLM.from_pretrained(\"./\")\n", "\n", "def predict(image):\n", " try:\n", " # Prepare the image using the processor\n", " inputs = processor(images=image, return_tensors=\"pt\")\n", "\n", " # Move inputs to the appropriate device\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " inputs = {key: value.to(device) for key, value in inputs.items()}\n", " model.to(device)\n", "\n", " # Generate the caption\n", " outputs = model.generate(**inputs)\n", "\n", " # Decode the generated caption\n", " caption = processor.batch_decode(outputs, skip_special_tokens=True)[0]\n", "\n", " return caption\n", "\n", " except Exception as e:\n", " print(\"Error during prediction:\", str(e))\n", " return \"Error: \" + str(e)\n", "\n", "# https://www.gradio.app/guides\n", "with gr.Blocks() as demo:\n", " image = gr.Image(type=\"pil\")\n", " predict_btn = gr.Button(\"Predict\", variant=\"primary\")\n", " output = gr.Label(label=\"Generated Caption\")\n", "\n", " inputs = [image]\n", " outputs = [output]\n", "\n", " predict_btn.click(predict, inputs=inputs, outputs=outputs)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch() # Local machine only\n", " # demo.launch(server_name=\"0.0.0.0\") # LAN access to local machine\n", " # demo.launch(share=True) # Public access to local machine\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 2 }