---
library_name: peft
base_model: stabilityai/stablelm-3b-4e1t
license: mit
language:
- en
metrics:
- bleu
- bertscore
- accuracy
tags:
- medical
---

# Model Card for StableMed

Welcome to StableMed, a StableLM 3B LLM (alpha) fine-tuned for medical question answering.

## Model Details

### Model Description

StableMed is a fine-tune of StableLM 3B for medical question answering, trained on the MedQuad dataset. It is intended for education in public health and sanitation, specifically to improve our understanding of outreach and communication.

- **Developed by:** [Tonic](https://huggingface.co./Tonic)
- **Shared by:** [Tonic](https://huggingface.co./Tonic)
- **Model type:** StableLM 3B (alpha)
- **Language(s) (NLP):** English
- **License:** MIT
- **Finetuned from model:** [stabilityai/stablelm-3b-4e1t](https://huggingface.co./stabilityai/stablelm-3b-4e1t)

### Model Sources

- **Repository:** [Tonic/stablemed](https://huggingface.co./Tonic/stablemed)
- **Demo:** [Tonic/StableMed_Chat](https://huggingface.co./Tonic/StableMed_Chat)

## Uses

Use this model for educational purposes only; do not use it for decision support in the wild. It is intended for medical Q&A and as an educational tool for studying "miniature" models.

### Direct Use

Medical question answering.

### Downstream Use

Fine-tune this model to work in a network or swarm of medical fine-tunes.

### Out-of-Scope Use

- Do not use this model in the wild.
- Do not use this model directly (that is, without further evaluation and fine-tuning).
- Do not use this model for real-world decision support.

## Bias, Risks, and Limitations

Evaluation with Giskard is coming soon.

### Recommendations

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model.

- DO NOT USE THIS MODEL WITHOUT EVALUATION
- DO NOT USE THIS MODEL WITHOUT BENCHMARKING
- DO NOT USE THIS MODEL WITHOUT FURTHER FINETUNING

## How to Get Started with the Model

Use the code below to get started with the model.
```Python
import os
import textwrap

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_token = os.environ.get("HUGGINGFACE_TOKEN")

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# The base model's ID and the StableMed adapter repository
base_model_id = "stabilityai/stablelm-3b-4e1t"
model_directory = "Tonic/stablemed"

# Instantiate the tokenizer; reuse EOS as the pad token and pad on the left
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Load the base model. StableLM-3B-4E1T is not a Mistral model, so load it with
# AutoModelForCausalLM; trust_remote_code lets it use the repository's own model
# class (StableLMEpochForCausalLM at the time of training).
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, trust_remote_code=True)

# Attach the StableMed LoRA adapter and move the model to the device
peft_model = PeftModel.from_pretrained(base_model, model_directory, token=hf_token)
peft_model.to(device)
peft_model.eval()


# Helper to wrap long responses for display
def wrap_text(text, width=90):
    lines = text.split("\n")
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    return "\n".join(wrapped_lines)


# Single-turn helper using low-temperature sampling
def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", max_new_tokens=400):
    # Combine user input and system prompt
    formatted_input = f"[INSTRUCTION]{system_prompt}[QUESTION]{user_input}"
    # Encode the input text
    model_inputs = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False).to(device)
    # Generate a response using the model
    output = peft_model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.1,
        do_sample=True,
    )
    # Decode the response
    return tokenizer.decode(output[0], skip_special_tokens=True)


class ChatBot:
    def __init__(self):
        # Token IDs of the conversation so far (prompts and replies)
        self.history = None

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        # Combine user input and system prompt
        formatted_input = f"[INSTRUCTION:]{system_prompt}[QUESTION:] {user_input}"

        # Encode user input
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt").to(device)

        # Concatenate the user input with the chat history
        if self.history is not None:
            chat_history_ids = torch.cat([self.history, user_input_ids], dim=-1)
        else:
            chat_history_ids = user_input_ids

        # Generate a response; max_new_tokens bounds the reply, not the whole conversation
        response = peft_model.generate(
            input_ids=chat_history_ids,
            attention_mask=torch.ones_like(chat_history_ids),
            max_new_tokens=400,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Keep prompt and reply so the next turn sees the full conversation
        self.history = response

        # Decode only the newly generated tokens
        new_tokens = response[0, chat_history_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)


bot = ChatBot()

title = "👋🏻Welcome to Tonic's StableMed Chat🚀"
description = """
You can use this Space to test out the current model [StableMed](https://huggingface.co./Tonic/stablemed).
You can also use 😷StableMed⚕️ on your own data & in your own way by cloning this space. 🧬🔬🔍 Simply click here: Duplicate Space

# Join us:
🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community on 👻Discord: [Discord](https://discord.gg/GWpVpekp)
On 🤗Huggingface: [TeamTonic](https://huggingface.co./TeamTonic) & [MultiTransformer](https://huggingface.co./MultiTransformer)
On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
"""

examples = [
    [
        "What is the proper treatment for buccal herpes?",
        "Please provide information on the most effective antiviral medications and home remedies for treating buccal herpes.",
    ]
]

iface = gr.Interface(
    fn=bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "text"],  # Take user input and system prompt separately
    outputs="text",
    theme="ParityError/Anime",
)

iface.launch()
```

## Training Details

### Training Data

[Dataset](https://huggingface.co./datasets/keivalya/MedQuad-MedicalQnADataset)

```json
Dataset({
    features: ['qtype', 'Question', 'Answer'],
    num_rows: 16407
})
```

### Training Procedure

The model was fine-tuned with LoRA adapters on top of a 4-bit quantized base model:

```json
trainable params: 12940288 || all params: 1539606528 || trainable%: 0.8404931886596937
```

#### Preprocessing

Original model configuration, with the base model loaded in 4-bit:

```json
StableLMEpochForCausalLM(
  (model): StableLMEpochModel(
    (embed_tokens): Embedding(50304, 2560)
    (layers): ModuleList(
      (0-31): 32 x DecoderLayer(
        (self_attn): Attention(
          (q_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (v_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (o_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (mlp): MLP(
          (gate_proj): Linear4bit(in_features=2560, out_features=6912, bias=False)
          (up_proj): Linear4bit(in_features=2560, out_features=6912, bias=False)
          (down_proj): Linear4bit(in_features=6912, out_features=2560, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2560, out_features=50304, bias=False)
)
```

Data formatting:

```json
Given a target sentence construct the underlying meaning representation of the input sentence as a single function with attributes and attribute values.
This function should describe the target string accurately and the function must be one of the following ['inform', 'request', 'give_opinion', 'confirm', 'verify_attribute', 'suggest', 'request_explanation', 'recommend', 'request_attribute'].
The attributes must be one of the following: ['name', 'pathology', 'therapeutic', 'dosage', 'side_effects', 'contraindications', 'manufacturer', 'price', 'availability', 'administration', 'warnings', 'interactions', 'storage', 'expiration_date', 'formulation', 'strength', 'route_of_administration', 'class', 'prescription_required', 'generic_name', 'brand_name', 'patient_instructions']
```

#### Training Hyperparameters

- **Training regime:** [More Information Needed]

#### Speeds, Sizes, Times

Training ran for 2,051 steps (half an epoch) in about 22,971 seconds, roughly 6.4 hours:

```json
TrainOutput(
    global_step=2051,
    training_loss=0.6156479549198718,
    metrics={
        'train_runtime': 22971.4974,
        'train_samples_per_second': 0.357,
        'train_steps_per_second': 0.089,
        'total_flos': 6.5950444363776e+16,
        'train_loss': 0.6156479549198718,
        'epoch': 0.5
    }
)
```

## Results

Training loss, logged every 50 steps:

| Step | Training loss |
|------|---------------|
| 50 | 1.427000 |
| 100 | 0.763200 |
| 150 | 0.708200 |
| 200 | 0.662300 |
| 250 | 0.650900 |
| 300 | 0.617400 |
| 350 | 0.602900 |
| 400 | 0.608900 |
| 450 | 0.596100 |
| 500 | 0.602000 |
| 550 | 0.594700 |
| 600 | 0.584700 |
| 650 | 0.611000 |
| 700 | 0.558700 |
| 750 | 0.616300 |
| 800 | 0.568700 |
| 850 | 0.597300 |
| 900 | 0.607400 |
| 950 | 0.563200 |
| 1000 | 0.602900 |
| 1050 | 0.594900 |
| 1100 | 0.583000 |
| 1150 | 0.604500 |
| 1200 | 0.547400 |
| 1250 | 0.586600 |
| 1300 | 0.554300 |
| 1350 | 0.581000 |
| 1400 | 0.578900 |
| 1450 | 0.563200 |
| 1500 | 0.556800 |
| 1550 | 0.570300 |
| 1600 | 0.599800 |
| 1650 | 0.556000 |
| 1700 | 0.592500 |
| 1750 | 0.597200 |
| 1800 | 0.559100 |
| 1850 | 0.586100 |
| 1900 | 0.581100 |
| 1950 | 0.589400 |
| 2000 | 0.581100 |
| 2050 | 0.533100 |

## Environmental Impact

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications

### Model Architecture and Objective

Model architecture with the LoRA adapters attached:

```json
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): StableLMEpochForCausalLM(
      (model): StableLMEpochModel(
        (embed_tokens): Embedding(50304, 2560)
        (layers): ModuleList(
          (0-31): 32 x DecoderLayer(
            (self_attn): Attention(
              (q_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (k_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (v_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (o_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (rotary_emb): RotaryEmbedding()
            )
            (mlp): MLP(
              (gate_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=6912, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=6912, bias=False)
              )
              (up_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=6912, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=6912, bias=False)
              )
              (down_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=6912, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=6912, out_features=2560, bias=False)
              )
              (act_fn): SiLU()
            )
            (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
          )
        )
        (norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
      (lm_head): Linear(
        in_features=2560, out_features=50304, bias=False
        (lora_dropout): ModuleDict(
          (default): Dropout(p=0.05, inplace=False)
        )
        (lora_A): ModuleDict(
          (default): Linear(in_features=2560, out_features=8, bias=False)
        )
        (lora_B): ModuleDict(
          (default): Linear(in_features=8, out_features=50304, bias=False)
        )
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
      )
    )
  )
)
```

### Compute Infrastructure

GCS

#### Hardware

NVIDIA T4

#### Software

- transformers
- peft
- torch
- datasets

## Model Card Authors

[Tonic](https://huggingface.co./Tonic)

## Model Card Contact

[Tonic](https://huggingface.co./Tonic)

## Training procedure

The following `bitsandbytes` quantization config was used during training:

- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: bfloat16

### Framework versions

- PEFT 0.6.2.dev0
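
The parameter counts reported under Training Procedure can be cross-checked against the LoRA architecture printed under Technical Specifications. A LoRA adapter on a `Linear(in, out)` layer adds a matrix A of shape (r × in) and a matrix B of shape (out × r), i.e. r·(in + out) parameters. The sketch below is plain arithmetic and does not load the model. It assumes rank r = 8, read off the `lora_A`/`lora_B` shapes, and reconstructs the `all params` figure under one further assumption: that each 4-bit quantized weight is counted at half its element count, consistent with bitsandbytes packing two 4-bit values into each stored byte.

```python
# Layer sizes read off the printed StableLMEpochForCausalLM architecture.
NUM_LAYERS = 32
HIDDEN = 2560
INTERMEDIATE = 6912
VOCAB = 50304
RANK = 8  # assumption: LoRA rank, inferred from lora_A mapping in_features -> 8


def lora_params(in_features: int, out_features: int, rank: int = RANK) -> int:
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * in_features + out_features * rank


# Adapters per decoder layer: q/k/v/o projections plus gate/up/down MLP projections.
attention = 4 * lora_params(HIDDEN, HIDDEN)
mlp = 2 * lora_params(HIDDEN, INTERMEDIATE) + lora_params(INTERMEDIATE, HIDDEN)
lm_head_lora = lora_params(HIDDEN, VOCAB)  # lm_head also carries an adapter
trainable = NUM_LAYERS * (attention + mlp) + lm_head_lora

# Frozen weights. Assumption: each Linear4bit weight is counted at half its
# element count (two 4-bit values packed per stored byte).
linear4bit = NUM_LAYERS * (4 * HIDDEN * HIDDEN + 3 * HIDDEN * INTERMEDIATE) // 2
embeddings = VOCAB * HIDDEN                     # embed_tokens
lm_head = HIDDEN * VOCAB                        # unquantized Linear, no bias
layernorms = (2 * NUM_LAYERS + 1) * 2 * HIDDEN  # weight + bias per LayerNorm
all_params = linear4bit + embeddings + lm_head + layernorms + trainable

print(f"trainable params: {trainable}")
print(f"all params: {all_params}")
print(f"trainable%: {100 * trainable / all_params:.4f}")
```

Under these assumptions the totals come out to exactly 12,940,288 trainable and 1,539,606,528 total parameters, matching the training log. The half-size counting of the 4-bit layers would also explain why `all params` is well below the roughly 2.8B parameters the same layers hold unquantized.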