---
license: mit
---
# SQLMaster

SQLMaster is a LoRA adapter (`Danjie/SQLMaster_13b`) for a Llama 2 base model that generates SQL queries from a natural-language question and a table schema. A minimum of 10 GB of VRAM is required when the base model is loaded with the 4-bit quantization shown below.
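If you want to check up front that your GPU meets the 10 GB requirement, a minimal sketch (not part of the original card) using PyTorch's CUDA utilities:

```python
import torch

# Report total VRAM on the first CUDA device; the 10 GB figure above is the
# minimum needed for the 4-bit quantized setup described below.
if torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU 0 has {total_gb:.1f} GB of VRAM")
else:
    print("No CUDA device found")
```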
## Colab Example

https://colab.research.google.com/drive/1Nvwie-klMNPPWI4o7Nae4l5spxEX1PaD?usp=sharing
## Install Prerequisites

```bash
!pip install peft
!pip install transformers
!pip install bitsandbytes
!pip install accelerate
```
## Log In Using a Hugging Face Token

```python
# You need a Hugging Face token that has been granted access to the Llama 2 models
from huggingface_hub import notebook_login

notebook_login()
```
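Outside a notebook, the same login can be done with `huggingface_hub.login()` (or the `huggingface-cli login` command); this is an alternative to the cell above, not part of the original card, and `HF_TOKEN` is an assumed environment-variable name.

```python
import os

from huggingface_hub import login

# Read the token from an environment variable (HF_TOKEN is an assumed name)
# instead of pasting it interactively.
login(token=os.environ["HF_TOKEN"])
```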
## Download Model

```python
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 4-bit NF4 quantization so the base model fits in roughly 10 GB of VRAM
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Download the base model referenced by the adapter config and quantize it on load
peft_model_id = "Danjie/SQLMaster_13b"
config = PeftConfig.from_pretrained(peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map='auto', quantization_config=bnb_config)

# Resize the embedding matrix to match the vocabulary size expected by the adapter (one extra token)
model.resize_token_embeddings(len(tokenizer) + 1)

# Load the LoRA adapter on top of the quantized base model
model = PeftModel.from_pretrained(model, peft_model_id)
```
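To confirm the quantized model actually fits on your GPU, one option (not in the original card) is the `get_memory_footprint()` helper exposed by transformers models:

```python
# Switch to inference mode and report the approximate memory used by the weights.
model.eval()
print(f"Model memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GB")
```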
## Inference

```python
def create_sql_query(question: str, context: str) -> str:
    # Build the prompt in the format the adapter was trained on
    prompt = "Question: " + question + "\nContext:" + context + "\nAnswer"

    # Encode and move the tensors onto the GPU if one is available
    encoded_input = tokenizer(prompt, return_tensors='pt')
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

    # Generate, decode, and strip the prompt from the output
    output = model.generate(**encoded_input, max_new_tokens=256)
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    response = response[len(prompt):]
    return response
```
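The `generate` call above uses the transformers defaults (greedy decoding). If you want to experiment with sampling, that line inside `create_sql_query` can be swapped for a variant like the sketch below; the temperature and top-p values are illustrative, not tuned for this model.

```python
# Drop-in replacement for the generate() line above; wrapping it in
# torch.inference_mode() avoids tracking gradients during generation.
with torch.inference_mode():
    output = model.generate(
        **encoded_input,
        max_new_tokens=256,
        do_sample=True,   # sample instead of greedy decoding
        temperature=0.7,  # illustrative value
        top_p=0.9,        # illustrative value
    )
```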
## Example

```python
create_sql_query("What is the highest age of users with name Danjie", "CREATE TABLE user (age INTEGER, name STRING)")
```
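In a script you would typically capture and print the returned string; the question and schema below are hypothetical, added only to show reuse of the function.

```python
sql = create_sql_query(
    "How many orders were placed in 2023?",              # hypothetical question
    "CREATE TABLE orders (id INTEGER, placed_at DATE)",  # hypothetical schema
)
print(sql)
```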