Upload Finetune_flan_t5_large_bnb_peft (1).ipynb
Finetune_flan_t5_large_bnb_peft (1).ipynb
ADDED
@@ -0,0 +1 @@
{"cells":[{"cell_type":"markdown","metadata":{"id":"lw1cWgq-DI5k"},"source":["# Fine-tune FLAN-T5 using `bitsandbytes`, `peft` \u0026 `transformers` π€"]},{"cell_type":"markdown","metadata":{"id":"kBFPA3-aDT7H"},"source":["In this notebook we will see how to properly use `peft` , `transformers` \u0026 `bitsandbytes` to fine-tune `flan-t5-large` in a google colab!\n","\n","We will finetune the model on [`financial_phrasebank`](https://huggingface.co/datasets/financial_phrasebank) dataset, that consists of pairs of text-labels to classify financial-related sentences, if they are either `positive`, `neutral` or `negative`.\n","\n","Note that you could use the same notebook to fine-tune `flan-t5-xl` as well, but you would need to shard the models first to avoid CPU RAM issues on Google Colab, check [these weights](https://huggingface.co/ybelkada/flan-t5-xl-sharded-bf16)."]},{"cell_type":"markdown","metadata":{"id":"5TXx1vj8kJSu"},"source":["## TODO #1\n","\n","`google/flan-t5-large` λͺ¨λΈμ 무μμ λͺ©νλ‘ λ§λ€μ΄μ‘κ³ κΈ°λν μ μλ κΈ°λ₯μ 무μμΈμ§ μ‘°μ¬νμμ€\n","\n","- λκ·λͺ¨ μΈμ΄ λͺ¨λΈμ νκ³ κ·Ήλ³΅(μΌλ° GPUμμ μ΄κ±°λ LLMμ λ‘λ©, λ―ΈμΈνλ νλ κ²μ λΆκ°λ₯)\n","- λ²μ, μμ½, CoLA, STSB μμ
{"cell_type":"markdown","metadata":{"id":"5TXx1vj8kJSu"},"source":["## TODO #1\n","\n","Investigate what the `google/flan-t5-large` model was created to achieve and what capabilities can be expected from it.\n","\n","- Overcoming the limitations of large language models (loading and fine-tuning a huge LLM on an ordinary GPU is not feasible).\n","- Performs tasks such as translation, summarization, CoLA, and STSB (see the prompting sketch in the next cells)."]},
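{"cell_type":"markdown","metadata":{},"source":["As a quick check of the instruction-following behaviour mentioned above, the cell below is an illustrative sketch (not part of the original notebook). The prompts are made-up examples and a separate full-precision copy of the model is loaded."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Illustrative sketch: zero-shot prompting of flan-t5-large for a few task styles.\n","# Prompts are made-up examples; this loads its own full-precision copy of the model.\n","from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n","\n","t5_tok = AutoTokenizer.from_pretrained(\"google/flan-t5-large\")\n","t5 = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-large\")\n","\n","prompts = [\n","    \"Translate English to German: How old are you?\",\n","    \"Summarize: Net interest income rose sharply while operating costs stayed flat.\",\n","    \"Is this sentence grammatically acceptable? Answer yes or no: The book was wrote by him.\",\n","]\n","for prompt in prompts:\n","    ids = t5_tok(prompt, return_tensors=\"pt\").input_ids\n","    out = t5.generate(ids, max_new_tokens=20)\n","    print(prompt, \"->\", t5_tok.decode(out[0], skip_special_tokens=True))"]},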
{"cell_type":"markdown","metadata":{"id":"ShAuuHCDDkvk"},"source":["## Install requirements"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/"},"id":"DRQ4ZrJTDkSy"},"outputs":[],"source":["!pip install -q bitsandbytes datasets accelerate\n","!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main"]},{"cell_type":"markdown","metadata":{"id":"QBdCIrizDxFw"},"source":["## Import model and tokenizer"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dd3c5acc"},"outputs":[],"source":["# Select CUDA device index\n","import os\n","import torch\n","\n","os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n","\n","from datasets import load_dataset\n","from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n","\n","model_name = \"google/flan-t5-large\"\n","\n","model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True)\n","tokenizer = AutoTokenizer.from_pretrained(model_name)"]},
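{"cell_type":"markdown","metadata":{},"source":["Optional sanity check (not in the original notebook): `transformers` can report the memory footprint of the loaded 8-bit model, which is useful context for TODO #3 below."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Optional sanity check: memory used by the 8-bit model's parameters and buffers.\n","# Assumes `model` is the 8-bit model loaded in the previous cell; the number is indicative only.\n","print(f\"8-bit model footprint: {model.get_memory_footprint() / 1e9:.2f} GB\")"]},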
{"cell_type":"markdown","metadata":{"id":"VwcHieQzD_dl"},"source":["## Prepare model for training"]},{"cell_type":"markdown","metadata":{"id":"4o3ePxrjEDzv"},"source":["Some pre-processing needs to be done before training such an int8 model with `peft`, so let's import the utility function `prepare_model_for_int8_training`, which will:\n","- Cast all the non-`int8` modules to full precision (`fp32`) for stability\n","- Add a `forward_hook` to the input embedding layer to enable gradient computation of the input hidden states\n","- Enable gradient checkpointing for more memory-efficient training"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"1629ebcb"},"outputs":[],"source":["from peft import prepare_model_for_int8_training\n","\n","model = prepare_model_for_int8_training(model)"]},{"cell_type":"markdown","metadata":{"id":"iCpAgawAEieu"},"source":["## Load your `PeftModel`\n","\n","Here we will use LoRA (Low-Rank Adapters) to train our model."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"17566ae3"},"outputs":[],"source":["from peft import LoraConfig, get_peft_model, TaskType\n","\n","\n","def print_trainable_parameters(model):\n","    \"\"\"\n","    Prints the number of trainable parameters in the model.\n","    \"\"\"\n","    trainable_params = 0\n","    all_param = 0\n","    for _, param in model.named_parameters():\n","        all_param += param.numel()\n","        if param.requires_grad:\n","            trainable_params += param.numel()\n","    print(\n","        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n","    )\n","\n","\n","lora_config = LoraConfig(\n","    r=16, lora_alpha=32, target_modules=[\"q\", \"v\"], lora_dropout=0.05, bias=\"none\", task_type=\"SEQ_2_SEQ_LM\"\n",")\n","\n","\n","model = get_peft_model(model, lora_config)\n","print_trainable_parameters(model)"]},{"cell_type":"markdown","metadata":{"id":"mGkwIgNXyS7U"},"source":["As you can see, here we are only training 0.6% of the parameters of the model! This is a huge memory gain that will enable us to fine-tune the model without any memory issues."]},{"cell_type":"markdown","metadata":{"id":"9kkyrzsakn2b"},"source":["## TODO #2\n","\n","Briefly investigate the principle by which the number of trainable parameters is reduced so drastically, to 0.6% as shown above (a worked parameter count follows in the next cells).\n","- Packing is used to combine multiple training examples into a single sequence, with an end-of-sequence token separating inputs from targets.\n","- Masking is applied to prevent tokens from attending to other tokens across packed example boundaries."]},
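{"cell_type":"markdown","metadata":{},"source":["As a back-of-the-envelope illustration (not part of the original notebook) of the LoRA reduction: each adapted weight matrix stays frozen and only two small low-rank matrices are trained, so an adapter over a `d_out x d_in` weight adds `r * (d_in + d_out)` trainable parameters instead of `d_in * d_out`."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Illustrative arithmetic for a single attention projection of flan-t5-large\n","# (d_model = 1024 for this checkpoint), using r=16 from the LoraConfig above.\n","d_in, d_out, r = 1024, 1024, 16\n","\n","full_params = d_in * d_out          # parameters in the frozen base weight\n","lora_params = r * (d_in + d_out)    # parameters in the trained A and B matrices\n","print(f\"full: {full_params:,}  lora: {lora_params:,}  ratio: {lora_params / full_params:.2%}\")\n","# Only the \"q\" and \"v\" projections receive adapters, so the overall trainable\n","# fraction of the whole model ends up well below 1%."]},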
{"cell_type":"markdown","metadata":{"id":"wgvqtHnFlNAl"},"source":["## TODO #3\n","\n","If the `load_in_8bit=True` option is not used when loading the model, the original full-precision weights are loaded.\n","\n","Compare the model structure in that case with the structure when `load_in_8bit=True` is used, and investigate the differences (a comparison sketch follows in the next cells).\n","- By default the model is loaded in 32-bit precision.\n","- `load_in_8bit=True`: the model is loaded in 8-bit precision -> faster, lower memory usage, but some accuracy loss is possible."]},
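{"cell_type":"markdown","metadata":{},"source":["The cells below are an illustrative sketch (not part of the original notebook) of the comparison TODO #3 asks for. They load two additional copies of the model, which needs extra memory."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Sketch: compare how linear layers are represented with and without load_in_8bit=True.\n","# This loads two extra copies of flan-t5-large and is for illustration only.\n","from transformers import AutoModelForSeq2SeqLM\n","\n","fp32_model = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-large\")  # default: float32 weights\n","int8_model = AutoModelForSeq2SeqLM.from_pretrained(\"google/flan-t5-large\", load_in_8bit=True)\n","\n","# The fp32 variant uses plain torch.nn.Linear modules; the 8-bit variant replaces them\n","# with bitsandbytes Linear8bitLt modules, which is where the memory saving comes from.\n","print({type(m).__name__ for m in fp32_model.modules() if \"Linear\" in type(m).__name__})\n","print({type(m).__name__ for m in int8_model.modules() if \"Linear\" in type(m).__name__})\n","print(fp32_model.get_memory_footprint(), int8_model.get_memory_footprint())"]},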
{"cell_type":"markdown","metadata":{"id":"HsG0x6Z7FwjZ"},"source":["## Load and process data\n","\n","Here we will use the [`financial_phrasebank`](https://huggingface.co/datasets/financial_phrasebank) dataset to fine-tune our model for sentiment classification of financial sentences. We will load the `sentences_allagree` split, which according to the dataset card is the subset with 100% annotator agreement."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"242cdfae"},"outputs":[],"source":["# loading dataset\n","dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")\n","dataset = dataset[\"train\"].train_test_split(test_size=0.1)\n","dataset[\"validation\"] = dataset[\"test\"]\n","del dataset[\"test\"]\n","\n","classes = dataset[\"train\"].features[\"label\"].names\n","dataset = dataset.map(\n","    lambda x: {\"text_label\": [classes[label] for label in x[\"label\"]]},\n","    batched=True,\n","    num_proc=1,\n",")"]},{"cell_type":"markdown","metadata":{"id":"qzwyi-Z9yzRF"},"source":["Let's also apply some pre-processing to the input data: the labels need to be pre-processed, and the tokens corresponding to `pad_token_id` need to be set to `-100` so that the `CrossEntropy` loss associated with the model will correctly ignore these tokens."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6b7ea44c"},"outputs":[],"source":["# data preprocessing\n","text_column = \"sentence\"\n","label_column = \"text_label\"\n","max_length = 128\n","\n","\n","def preprocess_function(examples):\n","    inputs = examples[text_column]\n","    targets = examples[label_column]\n","    model_inputs = tokenizer(inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n","    labels = tokenizer(targets, max_length=3, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n","    labels = labels[\"input_ids\"]\n","    labels[labels == tokenizer.pad_token_id] = -100\n","    model_inputs[\"labels\"] = labels\n","    return model_inputs\n","\n","\n","processed_datasets = dataset.map(\n","    preprocess_function,\n","    batched=True,\n","    num_proc=1,\n","    remove_columns=dataset[\"train\"].column_names,\n","    load_from_cache_file=False,\n","    desc=\"Running tokenizer on dataset\",\n",")\n","\n","train_dataset = processed_datasets[\"train\"]\n","eval_dataset = processed_datasets[\"validation\"]"]},
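{"cell_type":"markdown","metadata":{},"source":["A tiny illustration (not part of the original notebook) of why the padded label positions are set to `-100`: PyTorch's cross-entropy loss ignores that index by default, so padding does not contribute to the loss or gradients. The sizes below are toy values."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Toy example: positions labelled -100 are ignored (ignore_index defaults to -100).\n","import torch\n","import torch.nn.functional as F\n","\n","logits = torch.randn(4, 32)                # 4 label positions, toy vocabulary of 32\n","labels = torch.tensor([5, 7, -100, -100])  # last two positions are padding\n","print(F.cross_entropy(logits, labels))     # averaged over the two real positions only"]},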
{"cell_type":"markdown","metadata":{"id":"zmh21tjCm01z"},"source":["## TODO #4\n","\n","Briefly investigate the structure of the Hub dataset `financial_phrasebank` used in the data loading/processing above, and how this set was used for fine-tuning (an inspection snippet follows in the next cells).\n","- financial_phrasebank: a sentiment dataset of financial news sentences -> positive, negative, neutral\n","- 90% training data, 10% validation data\n","- Label conversion: a new field called \"text_label\" is added, and each data point's \"label\" value is converted into the corresponding class name."]},
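{"cell_type":"markdown","metadata":{},"source":["A short inspection snippet (not part of the original notebook; it assumes the data-loading cells above have been run) that backs up the answer to TODO #4 by printing the splits, the label names, and one mapped example."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Sketch: inspect the structure described in TODO #4 (assumes `dataset` from the cells above).\n","print(dataset)                                   # DatasetDict with 'train' and 'validation' splits\n","print(dataset[\"train\"].features[\"label\"].names)  # ['negative', 'neutral', 'positive']\n","print(dataset[\"train\"][0])                       # sentence, integer label, and mapped text_label"]},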
{"cell_type":"markdown","metadata":{"id":"bcNTdVypGEPb"},"source":["## Train our model!\n","\n","Let's now train our model by running the cells below.\n","Note that for T5, since some layers are kept in `float32` for stability purposes, there is no need to call autocast on the trainer."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"69c756ac"},"outputs":[],"source":["from transformers import TrainingArguments, Trainer\n","\n","training_args = TrainingArguments(\n","    \"temp\",\n","    evaluation_strategy=\"epoch\",\n","    learning_rate=1e-3,\n","    gradient_accumulation_steps=1,\n","    auto_find_batch_size=True,\n","    num_train_epochs=1,\n","    save_steps=100,\n","    save_total_limit=8,\n",")\n","trainer = Trainer(\n","    model=model,\n","    args=training_args,\n","    train_dataset=train_dataset,\n","    eval_dataset=eval_dataset,\n",")\n","model.config.use_cache = False  # silence the warnings. Please re-enable for inference!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ab52b651"},"outputs":[],"source":["trainer.train()"]},{"cell_type":"markdown","metadata":{"id":"r98VtofiGXtO"},"source":["## Qualitatively test our model"]},{"cell_type":"markdown","metadata":{"id":"NIm7z3UNzGPP"},"source":["Let's have a quick qualitative evaluation of the model by taking a sample from the dataset that corresponds to a positive label. Run your generation just as you would run a model from `transformers`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"c95d6173"},"outputs":[],"source":["model.eval()\n","input_text = \"In January-September 2009 , the Group 's net interest income increased to EUR 112.4 mn from EUR 74.3 mn in January-September 2008 .\"\n","inputs = tokenizer(input_text, return_tensors=\"pt\")\n","\n","outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n","\n","print(\"input sentence: \", input_text)\n","print(\" output prediction: \", tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"]},{"cell_type":"markdown","metadata":{"id":"ubwn2Qdbl3Fb"},"source":["## TODO #5\n","\n","Create your own Hugging Face account and carry out the Hub upload/download/verification steps below using your own account.\n","\n","Afterwards, write down the model id of the model you uploaded to the Hugging Face Hub.\n","- Write it in Markdown style."]},{"cell_type":"markdown","metadata":{"id":"9QqBlwzoGZ3f"},"source":["## Share your adapters on 🤗 Hub"]},{"cell_type":"markdown","metadata":{"id":"NT-C8SjcKqUx"},"source":["Once you have trained your adapter, you can easily share it on the Hub using the `push_to_hub` method. Note that only the adapter weights and config will be pushed."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bcbfa1f9"},"outputs":[],"source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"rFKJ4vHNGkJw"},"outputs":[],"source":["model.push_to_hub(\"ybelkada/flan-t5-large-financial-phrasebank-lora\", use_auth_token=True)"]},{"cell_type":"markdown","metadata":{"id":"xHuDmbCYJ89f"},"source":["## Load your adapter from the Hub"]},{"cell_type":"markdown","metadata":{"id":"ANFo6DdfKlU3"},"source":["You can load the model together with the adapter with a few lines of code! Check the snippet below to load the adapter from the Hub and run the example evaluation!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"j097aaPWJ-9u"},"outputs":[],"source":["import torch\n","from peft import PeftModel, PeftConfig\n","from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n","\n","peft_model_id = \"ybelkada/flan-t5-large-financial-phrasebank-lora\"\n","config = PeftConfig.from_pretrained(peft_model_id)\n","\n","model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, torch_dtype=\"auto\", device_map=\"auto\")\n","tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n","\n","# Load the Lora model\n","model = PeftModel.from_pretrained(model, peft_model_id)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"jmjwWYt0KI_I"},"outputs":[],"source":["model.eval()\n","input_text = \"In January-September 2009 , the Group 's net interest income increased to EUR 112.4 mn from EUR 74.3 mn in January-September 2008 .\"\n","inputs = tokenizer(input_text, return_tensors=\"pt\")\n","\n","outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n","\n","print(\"input sentence: \", input_text)\n","print(\" output prediction: \", tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"]}],"metadata":{"accelerator":"GPU","colab":{"name":"","provenance":[{"file_id":"1-1-LIlaEF8ENrJfcID6S1p7ZJy1Ur1LY","timestamp":1695805141441}],"version":""},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"},"vscode":{"interpreter":{"hash":"1219a10c7def3e2ad4f431cfa6f49d569fcc5949850132f23800e792129eefbb"}}},"nbformat":4,"nbformat_minor":5}