{"cells":[{"cell_type":"markdown","id":"lw1cWgq-DI5k","metadata":{"id":"lw1cWgq-DI5k"},"source":["# Fine-tune FLAN-T5 using `bitsandbytes`, `peft` & `transformers` ๐Ÿค—"]},{"cell_type":"markdown","id":"kBFPA3-aDT7H","metadata":{"id":"kBFPA3-aDT7H"},"source":["In this notebook we will see how to properly use `peft` , `transformers` & `bitsandbytes` to fine-tune `flan-t5-large` in a google colab!\n","\n","We will finetune the model on [`financial_phrasebank`](https://huggingface.co./datasets/financial_phrasebank) dataset, that consists of pairs of text-labels to classify financial-related sentences, if they are either `positive`, `neutral` or `negative`.\n","\n","Note that you could use the same notebook to fine-tune `flan-t5-xl` as well, but you would need to shard the models first to avoid CPU RAM issues on Google Colab, check [these weights](https://huggingface.co./ybelkada/flan-t5-xl-sharded-bf16)."]},{"cell_type":"markdown","source":["## TODO #1\n","\n","`google/flan-t5-large` ๋ชจ๋ธ์€ ๋ฌด์—‡์„ ๋ชฉํ‘œ๋กœ ๋งŒ๋“ค์–ด์กŒ๊ณ  ๊ธฐ๋Œ€ํ•  ์ˆ˜ ์žˆ๋Š” ๊ธฐ๋Šฅ์€ ๋ฌด์—‡์ธ์ง€ ์กฐ์‚ฌํ•˜์‹œ์˜ค\n","- ๋งˆํฌ๋‹ค์šด ์Šคํƒ€์ผ๋กœ ์ž‘์„ฑํ•˜์‹œ์˜ค"],"metadata":{"id":"5TXx1vj8kJSu"},"id":"5TXx1vj8kJSu"},{"cell_type":"markdown","source":["## 'google/flan-t5-large' ๋ชจ๋ธ ๊ฐœ์š”\n","\n","- 'google/flan-t5-large' ๋ชจ๋ธ์€ T5 ์•„ํ‚คํ…์ฒ˜๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•ฉ๋‹ˆ๋‹ค.\n","- T5 ๋ชจ๋ธ์€ ํ…์ŠคํŠธ ์ž…๋ ฅ์„ ๋ฐ›์•„ ์ถœ๋ ฅ์„ ์ƒ์„ฑํ•˜๋Š” ์‹œํ€€์Šค ํˆฌ ์‹œํ€€์Šค ๋ชจ๋ธ๋กœ, NLP ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.\n","- ์ด ๋ชจ๋ธ์€ \"๋ชจ๋“  ๊ฒƒ์€ ํ…์ŠคํŠธ\"๋ผ๋Š” ์ ‘๊ทผ์„ ๋”ฐ๋ฅด๋ฉฐ ์ž…๋ ฅ ํ…์ŠคํŠธ์™€ ์ถœ๋ ฅ ํ…์ŠคํŠธ๋ฅผ ๋™์ผํ•œ ํ˜•์‹์œผ๋กœ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.\n","\n","## ๊ธฐ๋Œ€ ๊ธฐ๋Šฅ๊ณผ ํ™œ์šฉ\n","\n","- 'google/flan-t5-large' ๋ชจ๋ธ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๋‹ค์–‘ํ•œ NLP ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:\n"," - ํ…์ŠคํŠธ ์ƒ์„ฑ: ์ž…๋ ฅ ํ…์ŠคํŠธ๋กœ๋ถ€ํ„ฐ ๋‹ค์–‘ํ•œ ์ข…๋ฅ˜์˜ ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n"," - ์š”์•ฝ: ๊ธด ๋ฌธ์„œ๋‚˜ ํ…์ŠคํŠธ๋ฅผ ๊ฐ„๊ฒฐํ•œ ์š”์•ฝ์œผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.\n"," - ๋ฒˆ์—ญ: ๋‹ค๊ตญ์–ด ๋ฒˆ์—ญ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๋ฉฐ ์ž…๋ ฅ ํ…์ŠคํŠธ๋ฅผ ๋‹ค๋ฅธ ์–ธ์–ด๋กœ ๋ฒˆ์—ญํ•ฉ๋‹ˆ๋‹ค.\n"," - ์งˆ๋ฌธ ์‘๋‹ต: ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ์ƒ์„ฑํ•˜๊ณ , ์ง€๋ฌธ๊ณผ ์งˆ๋ฌธ์„ ์ดํ•ดํ•˜์—ฌ ๋‹ต๋ณ€ํ•ฉ๋‹ˆ๋‹ค.\n"," - ๋ฌธ์žฅ ๋ถ„๋ฅ˜: ์ฃผ์–ด์ง„ ๋ฌธ์žฅ์„ ์นดํ…Œ๊ณ ๋ฆฌ ๋˜๋Š” ํด๋ž˜์Šค๋กœ ๋ถ„๋ฅ˜ํ•ฉ๋‹ˆ๋‹ค.\n","\n","'google/flan-t5-large' ๋ชจ๋ธ์„ ํ†ตํ•ด ๋‹ค์–‘ํ•œ NLP ์ž‘์—…์„ ์ž๋™ํ™”ํ•˜๊ณ  ํ–ฅ์ƒ์‹œํ‚ค๊ธฐ ์œ„ํ•ด, ๋ชจ๋ธ์˜ ํŠน์ • ๊ธฐ๋Šฅ๊ณผ ์ž‘์—…์— ๋”ฐ๋ฅธ ์„ค์ • ๋ฐ ๋ฐ์ดํ„ฐ๊ฐ€ ํ•„์š”ํ•˜๋ฉฐ, ์ด๋ฅผ ํ†ตํ•ด ์ •ํ™•ํ•˜๊ณ  ํšจ์œจ์ ์ธ ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."],"metadata":{"id":"gNdrvxdIM83V"},"id":"gNdrvxdIM83V"},{"cell_type":"markdown","id":"ShAuuHCDDkvk","metadata":{"id":"ShAuuHCDDkvk"},"source":["## Install requirements"]},{"cell_type":"code","execution_count":null,"id":"DRQ4ZrJTDkSy","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DRQ4ZrJTDkSy","outputId":"3b98c09a-6889-4cdc-dddf-a7bb231b1f1d"},"outputs":[{"output_type":"stream","name":"stdout","text":[" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n"," Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n"," Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"]}],"source":["!pip install -q bitsandbytes datasets accelerate\n","!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main"]},{"cell_type":"markdown","id":"QBdCIrizDxFw","metadata":{"id":"QBdCIrizDxFw"},"source":["## Import model and tokenizer"]},{"cell_type":"code","execution_count":null,"id":"dd3c5acc","metadata":{"id":"dd3c5acc"},"outputs":[],"source":["# Select CUDA device index\n","import os\n","import torch\n","\n","os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n","\n","from datasets import load_dataset\n","from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n","\n","model_name = \"google/flan-t5-large\"\n","\n","model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True)\n","tokenizer = AutoTokenizer.from_pretrained(model_name)"]},{"cell_type":"markdown","id":"VwcHieQzD_dl","metadata":{"id":"VwcHieQzD_dl"},"source":["## Prepare model for training"]},{"cell_type":"markdown","id":"4o3ePxrjEDzv","metadata":{"id":"4o3ePxrjEDzv"},"source":["Some pre-processing needs to be done before training such an int8 model using `peft`, therefore let's import an utiliy function `prepare_model_for_int8_training` that will:\n","- Casts all the non `int8` modules to full precision (`fp32`) for stability\n","- Add a `forward_hook` to the input embedding layer to enable gradient computation of the input hidden states\n","- Enable gradient checkpointing for more memory-efficient training"]},{"cell_type":"code","execution_count":null,"id":"1629ebcb","metadata":{"id":"1629ebcb"},"outputs":[],"source":["from peft import prepare_model_for_int8_training\n","\n","model = prepare_model_for_int8_training(model)"]},{"cell_type":"markdown","id":"iCpAgawAEieu","metadata":{"id":"iCpAgawAEieu"},"source":["## Load your `PeftModel`\n","\n","Here we will use LoRA (Low-Rank Adaptators) to train our model"]},{"cell_type":"code","execution_count":null,"id":"17566ae3","metadata":{"id":"17566ae3"},"outputs":[],"source":["from peft import LoraConfig, get_peft_model, TaskType\n","\n","\n","def print_trainable_parameters(model):\n"," \"\"\"\n"," Prints the number of trainable parameters in the model.\n"," \"\"\"\n"," trainable_params = 0\n"," all_param = 0\n"," for _, param in model.named_parameters():\n"," all_param += param.numel()\n"," if param.requires_grad:\n"," trainable_params += param.numel()\n"," print(\n"," f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n"," )\n","\n","\n","lora_config = LoraConfig(\n"," r=16, lora_alpha=32, target_modules=[\"q\", \"v\"], lora_dropout=0.05, bias=\"none\", task_type=\"SEQ_2_SEQ_LM\"\n",")\n","\n","\n","model = get_peft_model(model, lora_config)\n","print_trainable_parameters(model)"]},{"cell_type":"markdown","id":"mGkwIgNXyS7U","metadata":{"id":"mGkwIgNXyS7U"},"source":["As you can see, here we are only training 0.6% of the parameters of the model! This is a huge memory gain that will enable us to fine-tune the model without any memory issue."]},{"cell_type":"markdown","source":["## TODO #2\n","\n","์œ„์™€ ๊ฐ™์ด 0.6%๋กœ ํ•™์Šต ํŒŒ๋ผ๋ฏธํ„ฐ์˜ ๊ฐฏ์ˆ˜๊ฐ€ ๋Œ€ํญ ์ถ•์†Œ๋œ ์›๋ฆฌ์— ๋Œ€ํ•ด ๊ฐœ๋žต์ ์œผ๋กœ ์กฐ์‚ฌํ•˜์‹œ์˜ค.\n","- ๋งˆํฌ๋‹ค์šด ์Šคํƒ€์ผ๋กœ ์ž‘์„ฑํ•˜์‹œ์˜ค"],"metadata":{"id":"9kkyrzsakn2b"},"id":"9kkyrzsakn2b"},{"cell_type":"markdown","source":["## ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ํฌ๊ธฐ ์ถ•์†Œ ๋ฐ ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์„ฑ ๊ฐœ์„ \n","\n","์ œ๊ณต๋œ ์ฝ”๋“œ์—์„œ ์‚ฌ์šฉ๋œ ๊ธฐ์ˆ ์€ ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ์˜ ํฌ๊ธฐ๋ฅผ ์ถ•์†Œํ•˜๋ฉด์„œ ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์„ฑ์„ ๊ฐœ์„ ํ•˜๊ณ  ๋ชจ๋ธ์„ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ์ƒํƒœ๋กœ ์œ ์ง€ํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค. ์ด ๊ธฐ์ˆ ์€ ๋ฉ”๋ชจ๋ฆฌ ์ œ์•ฝ์ด ์žˆ๋Š” ํ™˜๊ฒฝ์—์„œ ํšจ๊ณผ์ ์œผ๋กœ ๋ชจ๋ธ์„ ํ™œ์šฉํ•˜๋Š” ๋ฐ ๋„์›€์„ ์ค๋‹ˆ๋‹ค.\n","\n","- **๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ถœ๋ ฅ ๋ฐ ๋ฉ”๋ชจ๋ฆฌ ์ด๋“**:\n"," - `print_trainable_parameters` ํ•จ์ˆ˜๋Š” ๋ชจ๋ธ์˜ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค.\n"," - ๊ฒฐ๊ณผ์—์„œ \"trainable params\"๋Š” ์‹ค์ œ๋กœ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.\n"," - \"all params\"๋Š” ๋ชจ๋ธ์˜ ์ด ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.\n"," - \"trainable%\"์€ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ์˜ ๋ฐฑ๋ถ„์œจ์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.\n"," - ๊ฒฐ๊ณผ์—์„œ \"trainable%\"๊ฐ€ ๋งค์šฐ ๋‚ฎ๊ฒŒ ๋‚˜ํƒ€๋‚˜๋ฉด, ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ค‘ ์ผ๋ถ€๋งŒ์ด ํ•™์Šต ๊ฐ€๋Šฅํ•œ ์ƒํƒœ๋กœ ์œ ์ง€๋˜๊ณ , ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ ์‚ฌํ•ญ์ด ํฌ๊ฒŒ ์ค„์–ด๋“ค๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n","\n","์ด ์ ‘๊ทผ ๋ฐฉ์‹์€ ๋ฉ”๋ชจ๋ฆฌ ์ œ์•ฝ์ด ์žˆ๋Š” ํ™˜๊ฒฝ์—์„œ ํฐ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ณ ์ž ํ•  ๋•Œ ํšจ๊ณผ์ ์ด๋ฉฐ, ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ํšจ์œจ์ ์œผ๋กœ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•ด์ค๋‹ˆ๋‹ค. ๋˜ํ•œ ๋ชจ๋ธ์„ ํšจ์œจ์ ์œผ๋กœ ํ•™์Šตํ•˜๊ณ  ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ๋„์™€์ค๋‹ˆ๋‹ค."],"metadata":{"id":"Yd8VN8RGNCmH"},"id":"Yd8VN8RGNCmH"},{"cell_type":"markdown","source":["## TODO #3\n","\n","๋ชจ๋ธ ๋กœ๋“œ์‹œ `load_in_8bit=True` ์˜ต์…˜์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์œผ๋ฉด ์›๋ณธ์„ ๋กœ๋”ฉํ•œ๋‹ค.\n","\n","์ด ๋•Œ์˜ ๋ชจ๋ธ ๊ตฌ์กฐ์™€, `load_in_8bit=True` ์„ ์‚ฌ์šฉํ–ˆ์„ ๋•Œ์˜ ๋ฌด๋ธ ๊ตฌ์กฐ๋ฅผ ๋น„๊ตํ•˜์—ฌ ์–ด๋–ค ์ฐจ์ด์ ์ด ์žˆ๋Š”์ง€๋ฅผ ์กฐ์‚ฌํ•˜์‹œ์˜ค.\n","- ๋งˆํฌ๋‹ค์šด ์Šคํƒ€์ผ๋กœ ์ž‘์„ฑํ•˜์‹œ์˜ค"],"metadata":{"id":"wgvqtHnFlNAl"},"id":"wgvqtHnFlNAl"},{"cell_type":"markdown","source":["## `load_in_8bit=True`์™€ `load_in_8bit=False` ๋ชจ๋ธ ๋กœ๋“œ ์˜ต์…˜ ๋น„๊ต\n","\n","`load_in_8bit=True` ์˜ต์…˜์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ๋กœ๋“œํ•˜๋Š” ๊ฒฝ์šฐ์™€ ๊ทธ๋ ‡์ง€ ์•Š์€ ๊ฒฝ์šฐ ๋ชจ๋ธ ๊ตฌ์กฐ์— ์ฐจ์ด๊ฐ€ ์žˆ์„ ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ฃผ์š” ์ฐจ์ด์ ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:\n","\n","1. **๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ์˜ ๋ฐ์ดํ„ฐ ์œ ํ˜•**:\n"," - `load_in_8bit=True`๋ฅผ ์‚ฌ์šฉํ•œ ๊ฒฝ์šฐ: ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ๋Š” 8๋น„ํŠธ ์ •๋ฐ€๋„๋กœ ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. ์ด๋Š” ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜์™€ ํŽธํ–ฅ์„ ํ‘œํ˜„ํ•˜๋Š”๋ฐ ์‚ฌ์šฉ๋˜๋Š” ์ˆซ์ž๊ฐ€ ์ƒ๋Œ€์ ์œผ๋กœ ์ž‘์Œ์„ ์˜๋ฏธํ•˜๋ฉฐ, ์ด๋Š” ๋ชจ๋ธ์ด ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ ๊ฒŒ ์‚ฌ์šฉํ•˜๋Š” ์žฅ์ ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.\n"," - `load_in_8bit=True`๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š์€ ๊ฒฝ์šฐ: ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ 32๋น„ํŠธ ๋˜๋Š” 16๋น„ํŠธ๋กœ ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. ์ด๋Š” ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜์™€ ํŽธํ–ฅ์ด ์ƒ๋Œ€์ ์œผ๋กœ ํฐ ์ˆซ์ž๋ฅผ ๊ฐ€์งˆ ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ด๋กœ ์ธํ•ด ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ ์‚ฌํ•ญ์ด ์ฆ๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.\n","\n","2. **๋ชจ๋ธ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ ์‚ฌํ•ญ**:\n"," - `load_in_8bit=True`๋ฅผ ์‚ฌ์šฉํ•œ ๊ฒฝ์šฐ: ๋ชจ๋ธ์ด ์‚ฌ์šฉํ•˜๋Š” ๋ฉ”๋ชจ๋ฆฌ ์–‘์ด ๊ฐ์†Œํ•˜๋ฏ€๋กœ ๋” ํšจ์œจ์ ์œผ๋กœ ์ž‘๋™ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.\n"," - `load_in_8bit=True`๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š์€ ๊ฒฝ์šฐ: ๋ชจ๋ธ์ด ์‚ฌ์šฉํ•˜๋Š” ๋ฉ”๋ชจ๋ฆฌ ์–‘์ด ์ฆ๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.\n","\n","3. **์„ฑ๋Šฅ ๋ฐ ์ •ํ™•๋„**:\n"," - `load_in_8bit=True`๋ฅผ ์‚ฌ์šฉํ•œ ๊ฒฝ์šฐ: ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ์˜ 8๋น„ํŠธ ์ •๋ฐ€๋„๋กœ ์ธํ•ด ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ๊ณผ ์ •ํ™•๋„๊ฐ€ ๊ฐ์†Œํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋กœ ์ธํ•ด ์˜ˆ์ธก์˜ ์ •ํ™•๋„๊ฐ€ ์ €ํ•˜๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.\n"," - `load_in_8bit=True`๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š์€ ๊ฒฝ์šฐ: ์›๋ณธ ์ •๋ฐ€๋„๋กœ ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ๋กœ๋“œ๋˜๋ฏ€๋กœ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด ๋” ๋†’์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.\n","\n","๋”ฐ๋ผ์„œ `load_in_8bit=True`๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์„ฑ์ด ๊ฐœ์„ ๋˜์ง€๋งŒ, ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด ๊ฐ์†Œํ•  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ์ ๋‹นํ•˜๊ฒŒ ๊ณ ๋ คํ•ด์„œ ์‚ฌ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค."],"metadata":{"id":"m08rbbKxPAby"},"id":"m08rbbKxPAby"},{"cell_type":"markdown","id":"HsG0x6Z7FwjZ","metadata":{"id":"HsG0x6Z7FwjZ"},"source":["## Load and process data\n","\n","Here we will use [`financial_phrasebank`](https://huggingface.co./datasets/financial_phrasebank) dataset to fine-tune our model on sentiment classification on financial sentences. We will load the split `sentences_allagree`, which corresponds according to the model card to the split where there is a 100% annotator agreement."]},{"cell_type":"code","execution_count":null,"id":"242cdfae","metadata":{"id":"242cdfae"},"outputs":[],"source":["# loading dataset\n","dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")\n","dataset = dataset[\"train\"].train_test_split(test_size=0.1)\n","dataset[\"validation\"] = dataset[\"test\"]\n","del dataset[\"test\"]\n","\n","classes = dataset[\"train\"].features[\"label\"].names\n","dataset = dataset.map(\n"," lambda x: {\"text_label\": [classes[label] for label in x[\"label\"]]},\n"," batched=True,\n"," num_proc=1,\n",")"]},{"cell_type":"markdown","id":"qzwyi-Z9yzRF","metadata":{"id":"qzwyi-Z9yzRF"},"source":["Let's also apply some pre-processing of the input data, the labels needs to be pre-processed, the tokens corresponding to `pad_token_id` needs to be set to `-100` so that the `CrossEntropy` loss associated with the model will correctly ignore these tokens."]},{"cell_type":"code","execution_count":null,"id":"6b7ea44c","metadata":{"id":"6b7ea44c"},"outputs":[],"source":["# data preprocessing\n","text_column = \"sentence\"\n","label_column = \"text_label\"\n","max_length = 128\n","\n","\n","def preprocess_function(examples):\n"," inputs = examples[text_column]\n"," targets = examples[label_column]\n"," model_inputs = tokenizer(inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n"," labels = tokenizer(targets, max_length=3, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n"," labels = labels[\"input_ids\"]\n"," labels[labels == tokenizer.pad_token_id] = -100\n"," model_inputs[\"labels\"] = labels\n"," return model_inputs\n","\n","\n","processed_datasets = dataset.map(\n"," preprocess_function,\n"," batched=True,\n"," num_proc=1,\n"," remove_columns=dataset[\"train\"].column_names,\n"," load_from_cache_file=False,\n"," desc=\"Running tokenizer on dataset\",\n",")\n","\n","train_dataset = processed_datasets[\"train\"]\n","eval_dataset = processed_datasets[\"validation\"]"]},{"cell_type":"markdown","source":["## TODO #4\n","\n","์œ„ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋”ฉ/๊ฐ€๊ณต์—์„œ ์‚ฌ์šฉํ•œ ํ—ˆ๋ธŒ์˜ ๋ฐ์ดํ„ฐ์…‹ `financial_phrasebank` ๊ตฌ์กฐ์™€ ์ด ์…‹์ด ์–ด๋–ป๊ฒŒ ๋ฏธ์„ธํŠœ๋‹์— ํ™œ์šฉ๋˜์—ˆ๋Š”์ง€ ๊ฐœ๋žต์ ์œผ๋กœ ์กฐ์‚ฌํ•˜์‹œ์˜ค.\n","- ๋งˆํฌ๋‹ค์šด ์Šคํƒ€์ผ๋กœ ์ž‘์„ฑํ•˜์‹œ์˜ค"],"metadata":{"id":"zmh21tjCm01z"},"id":"zmh21tjCm01z"},{"cell_type":"markdown","source":["## 'financial_phrasebank' ๋ฐ์ดํ„ฐ์…‹์„ ํ™œ์šฉํ•œ NLP ๋ชจ๋ธ ๋ฏธ์„ธ ํŠœ๋‹ ์˜ˆ์ œ\n","\n","์•„๋ž˜ ์ฝ”๋“œ์˜ ๊ฐœ๋žต์ ์ธ ๊ฐœ์š”๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:\n","\n","1. **๋ฐ์ดํ„ฐ์…‹ ๋กœ๋”ฉ ๋ฐ ๊ฐ€๊ณต**:\n"," - `financial_phrasebank` ๋ฐ์ดํ„ฐ์…‹์€ ๊ธˆ์œต ๊ด€๋ จ ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ์ด๋ฉฐ Hugging Face Datasets ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ํ™œ์šฉํ•ฉ๋‹ˆ๋‹ค.\n"," - ๋ฐ์ดํ„ฐ๋Š” ํ•™์Šต ๋ฐ ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ๋ถ„ํ• ๋˜๊ณ , ๋ ˆ์ด๋ธ”์ด ์ฒ˜๋ฆฌ๋˜์–ด ๋ชจ๋ธ ํ•™์Šต์— ๋งž๊ฒŒ ์ค€๋น„๋ฉ๋‹ˆ๋‹ค.\n","\n","2. **๋ชจ๋ธ ๋ฏธ์„ธ ํŠœ๋‹**:\n"," - `TrainingArguments`๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ•™์Šต ์„ค์ •์ด ์ •์˜๋ฉ๋‹ˆ๋‹ค. ์ด ์„ค์ •์€ ํ•™์Šต๋ฅ , ๋ฐฐ์น˜ ํฌ๊ธฐ, ํ•™์Šต ์—ํฌํฌ, ์ €์žฅ ๋ฐ ํ‰๊ฐ€ ์ฃผ๊ธฐ ๋“ฑ์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.\n"," - `Trainer` ํด๋ž˜์Šค๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ํŠœ๋‹ํ•ฉ๋‹ˆ๋‹ค. ์ด๋•Œ ๋ชจ๋ธ, ํ•™์Šต ์„ค์ •, ํ•™์Šต ๋ฐ์ดํ„ฐ์…‹ ๋ฐ ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ์…‹์ด ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.\n"," - `trainer.train()`์„ ํ˜ธ์ถœํ•˜์—ฌ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ต๋‹ˆ๋‹ค.\n","\n","3. **๋ชจ๋ธ ์ถ”๋ก **:\n"," - ํ•™์Šต์ด ์™„๋ฃŒ๋œ ๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•˜๊ณ  ์ถ”๋ก ํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.\n"," - `model.eval()`์„ ํ˜ธ์ถœํ•˜์—ฌ ๋ชจ๋ธ์„ ์ถ”๋ก  ๋ชจ๋“œ๋กœ ์„ค์ •ํ•˜๊ณ , ์ž…๋ ฅ ๋ฌธ์žฅ์ด ์ •์˜๋ฉ๋‹ˆ๋‹ค.\n"," - ์ž…๋ ฅ ๋ฌธ์žฅ์„ ํ† ํฐํ™”ํ•˜๊ณ  ๋ชจ๋ธ์— ์ „๋‹ฌํ•˜์—ฌ ๋ชจ๋ธ์˜ ์ถœ๋ ฅ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n"," - ๋ชจ๋ธ์˜ ์ถœ๋ ฅ์„ ํ•ด๋…ํ•˜์—ฌ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ์–ป์Šต๋‹ˆ๋‹ค.\n","\n","4. **๊ฒฐ๊ณผ ์ถœ๋ ฅ**:\n"," - ์ž…๋ ฅ ๋ฌธ์žฅ๊ณผ ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๊ฐ€ ์ถœ๋ ฅ๋ฉ๋‹ˆ๋‹ค.\n","\n","์ด ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด ๊ธˆ์œต ๊ด€๋ จ ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ NLP ๋ชจ๋ธ์„ ๋ฏธ์„ธ ํŠœ๋‹ํ•˜๊ณ  ์ด๋ฅผ ํ†ตํ•ด ์ •ํ™•ํ•œ ์˜ˆ์ธก์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฐ„๋‹จํ•œ ์˜ˆ์ œ๊ฐ€ ์ œ์‹œ๋ฉ๋‹ˆ๋‹ค. ๋ฏธ์„ธ ํŠœ๋‹์„ ํ†ตํ•ด ๋ชจ๋ธ์€ ํŠน์ • ๋ฐ์ดํ„ฐ์…‹๊ณผ ์ž‘์—…์— ๋” ์ž˜ ์ ์‘ํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ด๋Š” ๋” ๋†’์€ ์„ฑ๋Šฅ๊ณผ ์ •ํ™•๋„๋ฅผ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค."],"metadata":{"id":"PzXUprxPPbI9"},"id":"PzXUprxPPbI9"},{"cell_type":"markdown","id":"bcNTdVypGEPb","metadata":{"id":"bcNTdVypGEPb"},"source":["## Train our model!\n","\n","Let's now train our model, run the cells below.\n","Note that for T5 since some layers are kept in `float32` for stability purposes there is no need to call autocast on the trainer."]},{"cell_type":"code","execution_count":null,"id":"69c756ac","metadata":{"id":"69c756ac"},"outputs":[],"source":["from transformers import TrainingArguments, Trainer\n","\n","training_args = TrainingArguments(\n"," \"temp\",\n"," evaluation_strategy=\"epoch\",\n"," learning_rate=1e-3,\n"," gradient_accumulation_steps=1,\n"," auto_find_batch_size=True,\n"," num_train_epochs=1,\n"," save_steps=100,\n"," save_total_limit=8,\n",")\n","trainer = Trainer(\n"," model=model,\n"," args=training_args,\n"," train_dataset=train_dataset,\n"," eval_dataset=eval_dataset,\n",")\n","model.config.use_cache = False # silence the warnings. Please re-enable for inference!"]},{"cell_type":"code","execution_count":null,"id":"ab52b651","metadata":{"id":"ab52b651"},"outputs":[],"source":["trainer.train()"]},{"cell_type":"markdown","id":"r98VtofiGXtO","metadata":{"id":"r98VtofiGXtO"},"source":["## Qualitatively test our model"]},{"cell_type":"markdown","id":"NIm7z3UNzGPP","metadata":{"id":"NIm7z3UNzGPP"},"source":["Let's have a quick qualitative evaluation of the model, by taking a sample from the dataset that corresponds to a positive label. Run your generation similarly as you were running your model from `transformers`:"]},{"cell_type":"code","execution_count":null,"id":"c95d6173","metadata":{"id":"c95d6173"},"outputs":[],"source":["model.eval()\n","input_text = \"In January-September 2009 , the Group 's net interest income increased to EUR 112.4 mn from EUR 74.3 mn in January-September 2008 .\"\n","inputs = tokenizer(input_text, return_tensors=\"pt\")\n","\n","outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n","\n","print(\"input sentence: \", input_text)\n","print(\" output prediction: \", tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"]},{"cell_type":"markdown","source":["## TODO #5\n","\n","๋ณธ์ธ์˜ ํ—ˆ๊น…ํŽ˜์ด์Šค ๊ณ„์ •์„ ๋งŒ๋“ค๊ณ  ์•„๋ž˜ ํ—ˆ๋ธŒ์— ์—…๋กœ๋“œ/๋‹ค์šด๋กœ๋“œ/ํ™•์ธ ๊ณผ์ •์„ ๋ณธ์ธ ๊ณ„์ • ๊ธฐ์ค€์œผ๋กœ ์ง„ํ–‰ํ•˜์‹œ์˜ค.\n","\n","์ง„ํ–‰ ํ›„ ์—…๋ฅด๋“œํ•œ ํ—ˆ๊น…ํŽ˜์ด์Šค ํ—ˆ๋ธŒ์˜ ๋ชจ๋ธ id๋ฅผ ์ ์œผ์‹œ์˜ค.\n","- ๋งˆํฌ๋‹ค์šด ์Šคํƒ€์ผ๋กœ ์ž‘์„ฑํ•˜์‹œ์˜ค."],"metadata":{"id":"ubwn2Qdbl3Fb"},"id":"ubwn2Qdbl3Fb"},{"cell_type":"markdown","source":[],"metadata":{"id":"hK-Mdl4VgKcN"},"id":"hK-Mdl4VgKcN"},{"cell_type":"markdown","id":"9QqBlwzoGZ3f","metadata":{"id":"9QqBlwzoGZ3f"},"source":["## Share your adapters on ๐Ÿค— Hub"]},{"cell_type":"markdown","id":"NT-C8SjcKqUx","metadata":{"id":"NT-C8SjcKqUx"},"source":["Once you have trained your adapter, you can easily share it on the Hub using the method `push_to_hub` . Note that only the adapter weights and config will be pushed"]},{"cell_type":"code","execution_count":null,"id":"bcbfa1f9","metadata":{"id":"bcbfa1f9"},"outputs":[],"source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"]},{"cell_type":"code","execution_count":null,"id":"rFKJ4vHNGkJw","metadata":{"id":"rFKJ4vHNGkJw"},"outputs":[],"source":["model.push_to_hub(\"yysspp/flan-t5-large-financial-phrasebank-lora\", use_auth_token=True)"]},{"cell_type":"markdown","id":"xHuDmbCYJ89f","metadata":{"id":"xHuDmbCYJ89f"},"source":["## Load your adapter from the Hub"]},{"cell_type":"markdown","id":"ANFo6DdfKlU3","metadata":{"id":"ANFo6DdfKlU3"},"source":["You can load the model together with the adapter with few lines of code! Check the snippet below to load the adapter from the Hub and run the example evaluation!"]},{"cell_type":"code","execution_count":null,"id":"j097aaPWJ-9u","metadata":{"id":"j097aaPWJ-9u"},"outputs":[],"source":["import torch\n","from peft import PeftModel, PeftConfig\n","from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n","\n","peft_model_id = \"yysspp/flan-t5-large-financial-phrasebank-lora\"\n","config = PeftConfig.from_pretrained(peft_model_id)\n","\n","model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, torch_dtype=\"auto\", device_map=\"auto\")\n","tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n","\n","# Load the Lora model\n","model = PeftModel.from_pretrained(model, peft_model_id)"]},{"cell_type":"code","execution_count":null,"id":"jmjwWYt0KI_I","metadata":{"id":"jmjwWYt0KI_I"},"outputs":[],"source":["model.eval()\n","input_text = \"In January-September 2009 , the Group 's net interest income increased to EUR 112.4 mn from EUR 74.3 mn in January-September 2008 .\"\n","inputs = tokenizer(input_text, return_tensors=\"pt\")\n","\n","outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n","\n","print(\"input sentence: \", input_text)\n","print(\" output prediction: \", tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"]}],"metadata":{"accelerator":"GPU","colab":{"provenance":[],"gpuType":"T4","toc_visible":true},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"},"vscode":{"interpreter":{"hash":"1219a10c7def3e2ad4f431cfa6f49d569fcc5949850132f23800e792129eefbb"}}},"nbformat":4,"nbformat_minor":5}