{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true,
   "machine_shape": "hm",
   "gpuType": "A100",
   "authorship_tag": "ABX9TyNMzCSw8XLVSOI/aj2QMEti",
   "include_colab_link": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/ruslanmv/ai-medical-chatbot/blob/master/Chatbot-Medical-Llama3-v2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Medical AI Chatbot\n",
    "## [ruslanmv/Medical-Llama3-v2](https://github.com/ruslanmv/ai-medical-chatbot/blob/master/Chatbot-Medical-Llama3-v2.ipynb)\n",
    "\n",
    "Loads the `ruslanmv/Medical-Llama3-v2` model with 4-bit (NF4) quantization and serves it through a Gradio chat interface.\n",
    "\n",
    "**Disclaimer:** the answers are informational only and are not a substitute for professional medical advice."
   ],
   "metadata": {
    "id": "D2JxjUcy8nZg"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Setup\n",
    "\n",
    "Install the quantization and UI dependencies. `torch` and `transformers` are assumed to be preinstalled (as on Colab)."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "%pip install -q bitsandbytes accelerate gradio"
   ],
   "metadata": {
    "id": "eS2NsgQgvhZQ"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import gradio as gr\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Load the model\n",
    "\n",
    "Loads the tokenizer and the 4-bit quantized model. Run this on a GPU runtime."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "MODEL_NAME = \"ruslanmv/Medical-Llama3-v2\"\n",
    "MAX_INPUT_TOKENS = 1000  # prompt budget; older chat turns are dropped beyond it\n",
    "\n",
    "# 4-bit NF4 weights, with computation carried out in float16.\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.float16,\n",
    ")\n",
    "\n",
    "# Quantization is a model-loading option, not a tokenizer option.\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
    "\n",
    "# Quantization settings belong in `quantization_config` (`config` expects the model's\n",
    "# architecture config). Placement is handled by `device_map`, because bitsandbytes\n",
    "# 4-bit models cannot be moved with `.to(device)`.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    quantization_config=bnb_config,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "model.device"
   ],
   "metadata": {
    "id": "teoE-Zmv4LlP"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Chat function\n",
    "\n",
    "`respond` is the `gr.ChatInterface` callback: it rebuilds the conversation with the model's chat template, samples a reply and returns only the newly generated text."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "def _history_to_messages(history):\n",
    "    \"\"\"Convert Gradio chat history into chat-template messages.\n",
    "\n",
    "    Accepts both the legacy ``[(user, assistant), ...]`` pairs and the\n",
    "    ``[{\"role\": ..., \"content\": ...}, ...]`` format used by\n",
    "    ``gr.ChatInterface(type=\"messages\")``. Empty turns are skipped.\n",
    "    \"\"\"\n",
    "    messages = []\n",
    "    for turn in history:\n",
    "        if isinstance(turn, dict):\n",
    "            if turn.get(\"content\"):\n",
    "                messages.append({\"role\": turn[\"role\"], \"content\": turn[\"content\"]})\n",
    "            continue\n",
    "        user_text, assistant_text = turn[0], turn[1]\n",
    "        if user_text:\n",
    "            messages.append({\"role\": \"user\", \"content\": user_text})\n",
    "        if assistant_text:\n",
    "            messages.append({\"role\": \"assistant\", \"content\": assistant_text})\n",
    "    return messages\n",
    "\n",
    "\n",
    "def _encode_conversation(messages, max_input_tokens=MAX_INPUT_TOKENS):\n",
    "    \"\"\"Tokenize ``messages`` with the model's chat template.\n",
    "\n",
    "    While the prompt exceeds ``max_input_tokens``, the oldest turn after the\n",
    "    system message is dropped, so the system prompt, the latest question and\n",
    "    the generation prompt are never cut off. Once only the system message and\n",
    "    the latest question remain, the prompt is returned even if still too long.\n",
    "    \"\"\"\n",
    "    messages = list(messages)\n",
    "    while True:\n",
    "        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "        # The rendered template already contains the special tokens (e.g. BOS),\n",
    "        # so the tokenizer must not add them a second time.\n",
    "        inputs = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=False)\n",
    "        if inputs[\"input_ids\"].shape[-1] <= max_input_tokens or len(messages) <= 2:\n",
    "            return inputs\n",
    "        del messages[1]\n",
    "\n",
    "\n",
    "def respond(\n",
    "    message: str,\n",
    "    history: list,\n",
    "    system_message: str,\n",
    "    max_tokens: int,\n",
    "    temperature: float,\n",
    "    top_p: float,\n",
    ") -> str:\n",
    "    \"\"\"Generate the assistant's reply for one chat turn (``gr.ChatInterface`` callback).\n",
    "\n",
    "    Args:\n",
    "        message: The latest user message.\n",
    "        history: Previous turns as passed by ``gr.ChatInterface``.\n",
    "        system_message: System prompt placed at the start of the conversation.\n",
    "        max_tokens: Maximum number of new tokens to generate.\n",
    "        temperature: Sampling temperature.\n",
    "        top_p: Nucleus-sampling probability mass.\n",
    "\n",
    "    Returns:\n",
    "        Only the newly generated reply, without the prompt.\n",
    "    \"\"\"\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}]\n",
    "    messages += _history_to_messages(history)\n",
    "    messages.append({\"role\": \"user\", \"content\": message})\n",
    "\n",
    "    inputs = _encode_conversation(messages).to(model.device)\n",
    "    # The tokenizer may define no pad token; fall back to EOS so generate() doesn't warn.\n",
    "    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n",
    "\n",
    "    with torch.no_grad():\n",
    "        output_ids = model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=int(max_tokens),  # the slider sets new tokens, not total length\n",
    "            do_sample=True,  # without sampling, temperature and top_p are ignored\n",
    "            temperature=float(temperature),\n",
    "            top_p=float(top_p),\n",
    "            pad_token_id=pad_token_id,\n",
    "            use_cache=True,\n",
    "        )\n",
    "\n",
    "    # Decode only the tokens generated after the prompt.\n",
    "    prompt_length = inputs[\"input_ids\"].shape[-1]\n",
    "    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True).strip()"
   ],
   "metadata": {
    "id": "7PPuaI3C-FUg"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Gradio interface"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "DEFAULT_SYSTEM_MESSAGE = (\n",
    "    \"You are a Medical AI Assistant. Please be thorough and provide an informative answer. \"\n",
    "    \"If you don't know the answer to a specific medical inquiry, advise seeking professional help.\"\n",
    ")\n",
    "\n",
    "demo = gr.ChatInterface(\n",
    "    respond,\n",
    "    additional_inputs=[\n",
    "        gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label=\"System message\"),\n",
    "        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label=\"Max new tokens\"),\n",
    "        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label=\"Temperature\"),\n",
    "        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label=\"Top-p (nucleus sampling)\"),\n",
    "    ],\n",
    "    title=\"Medical AI Assistant\",\n",
    "    description=(\n",
    "        \"Ask any medical-related questions and get informative answers. \"\n",
    "        \"If the AI doesn't know the answer, it will advise seeking professional help.\"\n",
    "    ),\n",
    "    examples=[\n",
    "        [\"I have a headache and a fever. What should I do?\"],\n",
    "        [\"What are the symptoms of diabetes?\"],\n",
    "        [\"How can I improve my sleep?\"],\n",
    "    ],\n",
    ")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# In Colab, Gradio enables a temporary public share link automatically.\n",
    "demo.launch()"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }
 ]
}