{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true,
"machine_shape": "hm",
"gpuType": "A100",
"authorship_tag": "ABX9TyNMzCSw8XLVSOI/aj2QMEti",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"source": [
"# Medical AI Chatbot\n",
"## [ruslanmv/Medical-Llama3-v2](https://github.com/ruslanmv/ai-medical-chatbot/blob/master/Chatbot-Medical-Llama3-v2.ipynb)"
],
"metadata": {
"id": "D2JxjUcy8nZg"
}
},
{
"cell_type": "code",
"source": [
"from IPython.display import clear_output\n",
"!pip install bitsandbytes accelerate gradio\n",
"clear_output()"
],
"metadata": {
"id": "eS2NsgQgvhZQ"
},
"execution_count": 2,
"outputs": []
},
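{
"cell_type": "markdown",
"source": [
"A quick optional sanity check: the 4-bit quantized model loaded below runs on the GPU, so it is worth confirming that a CUDA device is visible to this runtime before continuing."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"# Sanity check: the 4-bit quantized model below expects a CUDA device\n",
"print(\"CUDA available:\", torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
"    print(\"Device:\", torch.cuda.get_device_name(0))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},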
{
"cell_type": "code",
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
"import torch\n",
"\n",
"# Define BitsAndBytesConfig\n",
"bnb_config = BitsAndBytesConfig(load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.float16)\n",
"\n",
"# Model name\n",
"model_name = \"ruslanmv/Medical-Llama3-v2\"\n",
"\n",
"# Load tokenizer and model with BitsAndBytesConfig\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, bnb_config=bnb_config)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, config=bnb_config)\n",
"\n",
"# Ensure model is on the correct device\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)"
],
"metadata": {
"id": "teoE-Zmv4LlP"
},
"execution_count": null,
"outputs": []
},
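{
"cell_type": "markdown",
"source": [
"As a minimal smoke test before wiring up the chat UI, the loaded model can be queried directly with a single question. This is just a sketch of the same chat-template flow used by the `respond` function below; the example question and generation settings are arbitrary."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Minimal single-turn smoke test: build a chat prompt, generate, decode only the new tokens\n",
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"You are a Medical AI Assistant.\"},\n",
"    {\"role\": \"user\", \"content\": \"What are common causes of a persistent cough?\"}\n",
"]\n",
"prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
"\n",
"with torch.no_grad():\n",
"    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True,\n",
"                             temperature=0.7, top_p=0.95,\n",
"                             pad_token_id=tokenizer.eos_token_id)\n",
"\n",
"# Print only the newly generated part of the sequence\n",
"print(tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[-1]:], skip_special_tokens=True))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},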
{
"cell_type": "code",
"source": [
"# Define the respond function\n",
"def respond(\n",
" message,\n",
" history: list[tuple[str, str]],\n",
" system_message,\n",
" max_tokens,\n",
" temperature,\n",
" top_p,\n",
"):\n",
" messages = [{\"role\": \"system\", \"content\": system_message}]\n",
"\n",
" for val in history:\n",
" if val[0]:\n",
" messages.append({\"role\": \"user\", \"content\": val[0]})\n",
" if val[1]:\n",
" messages.append({\"role\": \"assistant\", \"content\": val[1]})\n",
"\n",
" messages.append({\"role\": \"user\", \"content\": message})\n",
"\n",
" # Format the conversation as a single string for the model\n",
" prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, padding=True, max_length=1000)\n",
"\n",
" # Move inputs to device\n",
" input_ids = inputs['input_ids'].to(device)\n",
" attention_mask = inputs['attention_mask'].to(device)\n",
"\n",
" # Generate the response\n",
" with torch.no_grad():\n",
" outputs = model.generate(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" max_length=max_tokens,\n",
" temperature=temperature,\n",
" top_p=top_p,\n",
" use_cache=True\n",
" )\n",
"\n",
" # Extract the response\n",
" response_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n",
"\n",
" # Remove the prompt and system message from the response\n",
" response_text = response_text.replace(system_message, '').strip()\n",
" response_text = response_text.replace(f\"Human: {message}\\n\\nAssistant: \", '').strip()\n",
"\n",
" return response_text\n",
"\n",
"# Create the Gradio interface\n",
"demo = gr.ChatInterface(\n",
" respond,\n",
" additional_inputs=[\n",
" gr.Textbox(value=\"You are a Medical AI Assistant. Please be thorough and provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.\", label=\"System message\"),\n",
" gr.Slider(minimum=1, maximum=2048, value=512, step=1, label=\"Max new tokens\"),\n",
" gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label=\"Temperature\"),\n",
" gr.Slider(\n",
" minimum=0.1,\n",
" maximum=1.0,\n",
" value=0.95,\n",
" step=0.05,\n",
" label=\"Top-p (nucleus sampling)\",\n",
" ),\n",
" ],\n",
" title=\"Medical AI Assistant\",\n",
" description=\"Ask any medical-related questions and get informative answers. If the AI doesn't know the answer, it will advise seeking professional help.\",\n",
" examples=[[\"I have a headache and a fever. What should I do?\"], [\"What are the symptoms of diabetes?\"], [\"How can I improve my sleep?\"]],\n",
"\n",
")\n",
"\n",
"if __name__ == \"__main__\":\n",
" demo.launch()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 625
},
"id": "7PPuaI3C-FUg",
"outputId": "b5722b5f-f2f2-4e23-fca5-d801378efa82"
},
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
"\n",
"Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
"Running on public URL: https://12a24debf148400150.gradio.live\n",
"\n",
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co./spaces)\n"
]
}
]
}
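,
{
"cell_type": "markdown",
"source": [
"The `respond` handler can also be called directly, without launching the Gradio UI, for quick scripted checks. A rough usage sketch (the question and generation settings are arbitrary):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Call the chat handler directly with an empty history (no UI involved)\n",
"answer = respond(\n",
"    message=\"What are the symptoms of diabetes?\",\n",
"    history=[],\n",
"    system_message=\"You are a Medical AI Assistant.\",\n",
"    max_tokens=256,\n",
"    temperature=0.7,\n",
"    top_p=0.95,\n",
")\n",
"print(answer)"
],
"metadata": {},
"execution_count": null,
"outputs": []
}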
]
}