Laurent1 commited on
Commit
7771883
1 Parent(s): 4261dd0

Upload laurent-restaurant-adaptation-mistral-7b-tuned.ipynb

Browse files
laurent-restaurant-adaptation-mistral-7b-tuned.ipynb ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Libraries"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "! pip install bitsandbytes\n",
17
+ "! pip install einops\n",
18
+ "! pip install peft\n",
19
+ "! pip install datasets==2.14.6"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 1,
25
+ "metadata": {
26
+ "tags": []
27
+ },
28
+ "outputs": [
29
+ {
30
+ "name": "stdout",
31
+ "output_type": "stream",
32
+ "text": [
33
+ "2.14.6\n",
34
+ "4.35.0\n"
35
+ ]
36
+ }
37
+ ],
38
+ "source": [
39
+ "# Check the versions\n",
40
+ "import datasets\n",
41
+ "print(datasets.__version__)\n",
42
+ "import transformers\n",
43
+ "print(transformers.__version__)"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {},
49
+ "source": [
50
+ "# Restaurant dataset"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 2,
56
+ "metadata": {},
57
+ "outputs": [
58
+ {
59
+ "name": "stdout",
60
+ "output_type": "stream",
61
+ "text": [
62
+ "bin C:\\Users\\Utilisateur\\anaconda3\\lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda117.dll\n"
63
+ ]
64
+ }
65
+ ],
66
+ "source": [
67
+ "import einops\n",
68
+ "import torch\n",
69
+ "import pandas as pd\n",
70
+ "import numpy as np\n",
71
+ "from datasets import load_dataset,Dataset\n",
72
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, Trainer, DataCollatorForLanguageModeling\n",
73
+ "from peft import LoraConfig,get_peft_model,AutoPeftModelForCausalLM"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {},
79
+ "source": [
80
+ "# Load of the dataset for domain adaptation"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 3,
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "dataset0 = load_dataset(\"Argen7um/restrant-qa\")#.select(range(877))"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {},
95
+ "source": [
96
+ "## Adaptation of the data for training"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 4,
102
+ "metadata": {},
103
+ "outputs": [
104
+ {
105
+ "data": {
106
+ "text/plain": [
107
+ "Dataset({\n",
108
+ " features: ['text'],\n",
109
+ " num_rows: 877\n",
110
+ "})"
111
+ ]
112
+ },
113
+ "execution_count": 4,
114
+ "metadata": {},
115
+ "output_type": "execute_result"
116
+ }
117
+ ],
118
+ "source": [
119
+ "text = []\n",
120
+ "for i in range(877): \n",
121
+ " text.append('At Laurent restaurant : '+ dataset0['train'][i]['Prompt'].split('[question]:')[1].replace(' [/INST]\\n',''))\n",
122
+ "\n",
123
+ "data_text = pd.DataFrame(columns = ['text'])\n",
124
+ "data_text['text'] = text\n",
125
+ "\n",
126
+ "dataset_text = Dataset.from_pandas(data_text)\n",
127
+ "dataset_text"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "markdown",
132
+ "metadata": {},
133
+ "source": [
134
+ "## Check the distribution of the length of the rows (truncation impact ?)"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 5,
140
+ "metadata": {
141
+ "tags": []
142
+ },
143
+ "outputs": [
144
+ {
145
+ "data": {
146
+ "text/plain": [
147
+ "array([[<Axes: title={'center': '0'}>]], dtype=object)"
148
+ ]
149
+ },
150
+ "execution_count": 5,
151
+ "metadata": {},
152
+ "output_type": "execute_result"
153
+ },
154
+ {
155
+ "data": {
156
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGzCAYAAAAFROyYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApJklEQVR4nO3df3RU5Z3H8c+ETAZSM4QQQpI1gfgTLRL5UdKsroWSAMGDRakrPzyLlAPFBVtJu0W6ogm2G1Zb1uqycNhVdA+ktPYIKioQQEDWEOVHFnFdJBTESgIFNhmSyDiQZ//gMHVMIJlwJ/ME3q9z5pD7PM+989zvueN8vHPnjssYYwQAAGCRmGhPAAAA4OsIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAFbw+/2aO3eu0tPT1a1bN+Xk5KisrCza0wIQJQQUAFZ46KGHtGjRIk2ePFm/+c1v1KVLF40ZM0bbt2+P9tQARIGLHwsEEG3vv/++cnJy9Mwzz+inP/2pJOnMmTPq37+/UlJS9N5770V5hgA6GmdQAETdH/7wB3Xp0kUzZswItnXt2lXTpk1TeXm5PvvssyjODkA0EFAARN2ePXt00003yev1hrQPHTpUklRZWRmFWQGIJgIKgKirrq5WWlpas/YLbUePHu3oKQGIMgIKgKj74osv5PF4mrV37do12A/g6kJAARB13bp1k9/vb9Z+5syZYD+AqwsBBUDUpaWlqbq6uln7hbb09PSOnhKAKCOgAIi622+/XZ988ol8Pl9Ie0VFRbAfwNWFgAIg6r7//e/r3LlzWrZsWbDN7/dr+fLlysnJUUZGRhRnByAaYqM9AQDIycnR/fffr3nz5un48eO64YYb9PLLL+vw4cN64YUXoj09AFHAnWQBWOHMmTOaP3++VqxYof/7v//TgAED9NRTT2nUqFHRnhqAKCCgAAAA63ANCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdTrljdqampp09OhRJSQkyOVyRXs6AACgDYwxOn36tNLT0xUTc+lzJJ0yoBw9epRbXwMA0El99tlnuvbaay85plMGlISEBEnnd9Dr9QbbA4GANmzYoJEjR8rtdkdrelcEaukcaukcaukcaukcatl2Pp9PGRkZwffxS+mUAeXCxzper7dZQImPj5fX6+UguUzU0jnU0jnU0jnU0jnUMnxtuTyDi2QBAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOuEHVC2bdumsWPHKj09XS6XS2vWrAnpd7lcLT6eeeaZ4Ji+ffs261+4cOFl7wwAALgyhB1QGhoalJ2drcWLF7fYX11dHfJ48cUX5XK5NH78+JBxCxYsCBn3yCOPtG8PAADAFSfsW90XFBSooKDgov2pqakhy6+99pqGDx+u6667LqQ9ISGh2VgAAAApwr/Fc+zYMb355pt6+eWXm/UtXLhQTz31lDIzMzVp0iTNmTNHsbEtT8fv98vv9weXfT6fpPO/fxAIBILtF/7+ahvah1o6h1o6h1o6h1o6h1q2XTg1imhAefnll5WQkKD77rsvpP1HP/qRBg0apKSkJL333nuaN2+eqqurtWjRoha3U1JSouLi4mbtGzZsUHx8fLP2srIyZ3YA1NJB1NI51NI51NI51LJ1jY2NbR7rMsaY9j6Ry+XS6tWrNW7cuBb7+/Xrp/z8fD3//POX3M6LL76oH/7wh6qvr5fH42nW39IZlIyMDJ04caLZrxmXlZUpPz+fX5S8TNTSOdTSOdTSOdTSOdSy7Xw+n5KTk1VXVxfy/t2SiJ1Beffdd7V//3797ne/a3VsTk6Ozp49q8OHD+vmm29u1u/xeFoMLm63u8WD4WLtCF8ka9n3sTfbve7hhXc7OJOOwXHpHGrpHGrpHGrZunDqE7H7oLzwwgsaPHiwsrOzWx1bWVmpmJgYpaSkRGo6AACgEwn7DEp9fb2qqqqCy4cOHVJlZaWSkpKUmZkp6fwpnFdeeUW//vWvm61fXl6uiooKDR8+XAkJCSovL9ecOXP04IMPqkePHpexKwAA4EoRdkDZuXOnhg8fHlwuLCyUJE2ZMkUvvfSSJGnVqlUyxmjixInN1vd4PFq1apWKiork9/uVlZWlOXPmBLcDAAAQdkAZNmyYWruudsaMGZoxY0aLfYMGDdKOHTvCfVoAAHAV4bd4AACAdQgoAADAOgQUAABgnYjeSRaIlKvtHioAcLXhDAoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1gk7oGzbtk1jx45Venq6XC6X1qxZE9L/0EMPyeVyhTxGjx4dMubUqVOaPHmyvF6vEhMTNW3aNNXX11/WjgAAgCtH2AGloaFB2dnZWrx48UXHjB49WtXV1cHHb3/725D+yZMn66OPPlJZWZnWrl2rbdu2acaMGeHPHgAAXJFiw12hoKBABQUFlxzj8XiUmpraYt/HH3+sdevW6YMPPtCQIUMkSc8//7zGjBmjX/3qV0pPTw93SgAA4AoTdkBpiy1btiglJUU9evTQd7/7Xf3iF79Qz549JUnl5eVKTEwMhhNJysvLU0xMjCoqKnTvvfc2257f75ff7w8u+3w+SVIgEFAgEAi2X/j7q21on46opaeLidi2L6Wjjw+OS+dQS+dQS+dQy7YLp0aOB5TRo0frvvvuU1ZWlg4ePKif//znKigoUHl5ubp06aKamhqlpKSETiI2VklJSaqpqWlxmyUlJSouLm7WvmHDBsXHxzdrLysrc2ZnENFaPj00Ypu+pLfeeisqz8tx6Rxq6Rxq6Rxq2brGxsY2j3U8oEyYMCH492233aYBAwbo+uuv15YtWzRixIh2bXPevHkqLCwMLvt8PmVkZGjkyJHyer3B9kAgoLKyMuXn58vtdrd/J9AhtexftD4i223NvqJRHfp8HJfOoZbOoZbOoZZtd+ETkLaIyEc8X3XdddcpOTlZVVVVGjFihFJTU3X8+PGQMWfPntWpU6cuet2Kx+ORx+Np1u52u1s8GC7WjvBFspb+c66IbLc1N87f0O51Dy+8u93rclw6h1o6h1o6h1q2Lpz6RPw+KH/605908uRJpaWlSZJyc3NVW1urXbt2Bcds3rxZTU1NysnJifR0AABAJxD2GZT6+npVVVUFlw8dOqTKykolJSUpKSlJxcXFGj9+vFJTU3Xw4EH97Gc/0w033KBRo86fVr/llls0evRoTZ8+XUuXLlUgENDs2bM1YcIEvsEDAAAkteMMys6dOzVw4EANHDhQklRYWKiBAwfqiSeeUJcuXbR3717dc889uummmzRt2jQNHjxY7777bshHNCtXrlS/fv00YsQIjRkzRnfeeaeWLVvm3F4BAIBOLewzKMOGDZMxF/966Pr1rV/4mJSUpNLS0nCfGgAAXCX4LR4AAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOmEHlG3btmns2LFKT0+Xy+XSmjVrgn2BQEBz587Vbbfdpm984xtKT0/X3/3d3+no0aMh2+jbt69cLlfIY+HChZe9MwAA4MoQdkBpaGhQdna2Fi9e3KyvsbFRu3fv1vz587V79269+uqr2r9/v+65555mYxcsWKDq6urg45FHHmnfHgAAgCtObLgrFBQUqKCgoMW+7t27q6ysLKTtX//1XzV06FAdOXJEmZmZwfaEhASlpqa26Tn9fr/8fn9w2efzSTp/xiYQCATbL/z91Ta0T0fU0tPFRGzbkdKeenBcOodaOodaOodatl04NXIZY9r9LuFyubR69WqNGzfuomM2btyokSNHqra2Vl6vV9L5j3jOnDmjQCCgzMxMTZo0SXPmzFFsbMt5qaioSMXFxc3aS0tLFR8f397pAwCADtTY2KhJkyaprq4umAkuJqIB5cyZM7rjjjvUr18/rVy5Mti+aNEiDRo0SElJSXrvvfc0b948TZ06VYsWLWpxOy2dQcnIyNCJEydCdjAQCKisrEz5+flyu93t3S2oY2rZv2h9RLYbSfuKRoW9Dselc6ilc6ilc6hl2/l8PiUnJ7cpoIT9EU9bBQIB/e3f/q2MMVqyZElIX2FhYfDvAQMGKC4uTj/84Q9VUlIij8fTbFsej6fFdrfb3eLBcLF2hK+1WvZ97M3L2LrrMtaNjss5rjgunUMtnUMtnUMtWxdOfSLyNeML4eTTTz9VWVlZqykpJydHZ8+e1eHDhyMxHQAA0Mk4fgblQjg5cOCA3nnnHfXs2bPVdSorKxUTE6OUlBSnpwMAADqhsANKfX29qqqqgsuHDh1SZWWlkpKSlJaWpu9///vavXu31q5dq3PnzqmmpkaSlJSUpLi4OJWXl6uiokLDhw9XQkKCysvLNWfOHD344IPq0aOHc3sGAAA6rbADys6dOzV8+PDg8oXrSaZMmaKioiK9/vrrkqTbb789ZL133nlHw4YNk8fj0apVq1RUVCS/36+srCzNmTMn5LoUAABwdQs7oAwbNkyX+uJPa18KGjRokHbs2BHu0wIAgKsIv8UDAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsE7YAWXbtm0aO3as0tPT5XK5tGbNmpB+Y4yeeOIJpaWlqVu3bsrLy9OBAwdCxpw6dUqTJ0+W1+tVYmKipk2bpvr6+svaEQAAcOUIO6A0NDQoOztbixcvbrH/6aef1nPPPaelS5eqoqJC3/jGNzRq1CidOXMmOGby5Mn66KOPVFZWprVr12rbtm2aMWNG+/cCAABcUWLDXaGgoEAFBQUt9hlj9Oyzz+rxxx/X9773PUnSf/7nf6p3795as2aNJkyYoI8//ljr1q3TBx98oCFDhkiSnn/+eY0ZM0a/+tWvlJ6efhm7AwAArgRhB5RLOXTokGpqapSXlxds6969u3JyclReXq4JEyaovLxciYmJwXAiSXl5eYqJiVFFRYXuvffeZtv1+/3y+/3BZZ/PJ0kKBAIKBALB9gt/f7UN7dPWWnq6mI6YjjXac2xxXDqHWjqHWjqHWrZdODVyNKDU1NRIknr37h3S3rt372BfTU2NUlJSQicRG6ukpKTgmK8rKSlRcXFxs/YNGzYoPj6+WXtZWVm75o/mWqvl00M7aCKWeOutt9q9Lselc6ilc6ilc6hl6xobG9s81tGAEinz5s1TYWFhcNnn8ykjI0MjR46U1+sNtgcCAZWVlSk/P19utzsaU71itLWW/YvWd+Csom9f0aiw1+G4dA61dA61dA61bLsLn4C0haMBJTU1VZJ07NgxpaWlBduPHTum22+/PTjm+PHjIeudPXtWp06dCq7/dR6PRx6Pp1m72+1u8WC4WDvC11ot/edcHTib6Luc44rj0jnU0jnU0jnUsnXh1MfR+6BkZWUpNTVVmzZtCrb5fD5VVFQoNzdXkpSbm6va2lrt2rUrOGbz5s1qampSTk6Ok9MBAACdVNhnUOrr61VVVRVcPnTokCorK5WUlKTMzEw9+uij+sUvfqEbb7xRWVlZmj9/vtLT0zVu3DhJ0i233KLRo0dr+vTpWrp0qQKBgGbPnq0JEybwDR4AACCpHQFl586dGj58eHD5wrUhU6ZM0UsvvaSf/exnamho0IwZM1RbW6s777xT69atU9euXYPrrFy5UrNnz9aIESMUExOj8ePH67nnnnNgdwAAwJUg7IAybNgwGXPxr5a6XC4tWLBACxYsuOiYpKQklZaWhvvUAADgKsFv8QAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOmH/Fg9wNev72Jthr+PpYvT00AhMBgCuYJxBAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOo4HlL59+8rlcjV7zJo1S5I0bNiwZn0zZ850ehoAAKATi3V6gx988IHOnTsXXN63b5/y8/N1//33B9umT5+uBQsWBJfj4+OdngYAAOjEHA8ovXr1ClleuHChrr/+en3nO98JtsXHxys1NdXppwYAAFcIxwPKV3355ZdasWKFCgsL5XK5gu0rV67UihUrlJqaqrFjx2r+/PmXPIvi9/vl9/uDyz6fT5IUCAQUCASC7Rf+/mob2qettfR0MR0xnU7NE3O+RhyXl4/XuHOopXOoZduFUyOXMSZi7zC///3vNWnSJB05ckTp6emSpGXLlqlPnz5KT0/X3r17NXfuXA0dOlSvvvrqRbdTVFSk4uLiZu2lpaV8PAQAQCfR2NioSZMmqa6uTl6v95JjIxpQRo0apbi4OL3xxhsXHbN582aNGDFCVVVVuv7661sc09IZlIyMDJ04cSJkBwOBgMrKypSfny+32+3cjlyF2lrL/kXrO3BWnZMnxuipIU0clw7gNe4caukcatl2Pp9PycnJbQooEfuI59NPP9XGjRsveWZEknJyciTpkgHF4/HI4/E0a3e73S0eDBdrR/haq6X/nOuifQjFcekcaukcaukcatm6cOoTsfugLF++XCkpKbr77rsvOa6yslKSlJaWFqmpAACATiYiZ1Campq0fPlyTZkyRbGxf3mKgwcPqrS0VGPGjFHPnj21d+9ezZkzR3fddZcGDBgQiakAAIBOKCIBZePGjTpy5Ih+8IMfhLTHxcVp48aNevbZZ9XQ0KCMjAyNHz9ejz/+eCSmAQAAOqmIBJSRI0eqpWtvMzIytHXr1kg8JQAAuILwWzwAAMA6BBQAAGAdAgoAALBORG91j86h72NvNmvzdDF6euj5G7FxrxMAQEfjDAoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArON4QCkqKpLL5Qp59OvXL9h/5swZzZo1Sz179tQ111yj8ePH69ixY05PAwAAdGIROYPyzW9+U9XV1cHH9u3bg31z5szRG2+8oVdeeUVbt27V0aNHdd9990ViGgAAoJOKjchGY2OVmprarL2urk4vvPCCSktL9d3vfleStHz5ct1yyy3asWOHvv3tb0diOgAAoJOJSEA5cOCA0tPT1bVrV+Xm5qqkpESZmZnatWuXAoGA8vLygmP79eunzMxMlZeXXzSg+P1++f3+4LLP55MkBQIBBQKBYPuFv7/ahtZ5upjmbTEm5F+034UaclxePl7jzqGWzqGWbRdOjVzGGEffgd5++23V19fr5ptvVnV1tYqLi/X5559r3759euONNzR16tSQsCFJQ4cO1fDhw/XP//zPLW6zqKhIxcXFzdpLS0sVHx/v5PQBAECENDY2atKkSaqrq5PX673kWMcDytfV1taqT58+WrRokbp169augNLSGZSMjAydOHEiZAcDgYDKysqUn58vt9sdmR26AvUvWt+szRNj9NSQJs3fGSN/kysKs7pyXKglx+Xl4zXuHGrpHGrZdj6fT8nJyW0KKBH5iOerEhMTddNNN6mqqkr5+fn68ssvVVtbq8TExOCYY8eOtXjNygUej0cej6dZu9vtbvFguFg7WuY/d/EA4m9yXbIfbcdx6Rxq6Rxq6Rxq2bpw6hPx+6DU19fr4MGDSktL0+DBg+V2u7Vp06Zg//79+3XkyBHl5uZGeioAAKCTcPwMyk9/+lONHTtWffr00dGjR/Xkk0+qS5cumjhxorp3765p06apsLBQSUlJ8nq9euSRR5Sbm8s3eAAAQJDjAeVPf/qTJk6cqJMnT6pXr1668847tWPHDvXq1UuS9C//8i+KiYnR+PHj5ff7NWrUKP3bv/2b09MAAACdmOMBZdWqVZfs79q1qxYvXqzFixc7/dQAAOAKwW/xAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANZxPKCUlJToW9/6lhISEpSSkqJx48Zp//79IWOGDRsml8sV8pg5c6bTUwEAAJ2U4wFl69atmjVrlnbs2KGysjIFAgGNHDlSDQ0NIeOmT5+u6urq4OPpp592eioAAKCTinV6g+vWrQtZfumll5SSkqJdu3bprrvuCrbHx8crNTXV6acHrNW/aL3851ztWvfwwrsdng0A2M3xgPJ1dXV1kqSkpKSQ9pUrV2rFihVKTU3V2LFjNX/+fMXHx7e4Db/fL7/fH1z2+XySpEAgoEAgEGy/8PdX29A6TxfTvC3GhPyL9nOilhzT5/Eadw61dA61bLtwauQyxkTsHaipqUn33HOPamtrtX379mD7smXL1KdPH6Wnp2vv3r2aO3euhg4dqldffbXF7RQVFam4uLhZe2lp6UVDDQAAsEtjY6MmTZqkuro6eb3eS46NaEB5+OGH9fbbb2v79u269tprLzpu8+bNGjFihKqqqnT99dc362/pDEpGRoZOnDgRsoOBQEBlZWXKz8+X2+12dmeuYP2L1jdr88QYPTWkSfN3xsjf1L6PJXCeE7XcVzTK4Vl1TrzGnUMtnUMt287n8yk5OblNASViH/HMnj1ba9eu1bZt2y4ZTiQpJydHki4aUDwejzweT7N2t9vd4sFwsXa07FLXRfibXO2+bgKhLqeWHM+heI07h1o6h1q2Lpz6OB5QjDF65JFHtHr1am3ZskVZWVmtrlNZWSlJSktLc3o6AACgE3I8oMyaNUulpaV67bXXlJCQoJqaGklS9+7d1a1bNx08eFClpaUaM2aMevbsqb1792rOnDm66667NGDAAKenc9Xo+9ib0Z4CAACOcTygLFmyRNL5m7F91fLly/XQQw8pLi5OGzdu1LPPPquGhgZlZGRo/Pjxevzxx52eCgAA6KQi8hHPpWRkZGjr1q1OPy0AALiC8Fs8AADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDqx0Z4A/qLvY29GewoAAFiBMygAAMA6BBQAAGAdAgoAALAO16A4jOtIAAC4fJxBAQAA1uEMCtAJXM6ZucML73ZwJgDQMTiDAgAArENAAQAA1iGgAAAA63ANCnCF4/oVAJ0RZ1AAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANbhPigt4BeJAQCILs6gAAAA6xBQAACAdQgoAADAOlG9BmXx4sV65plnVFNTo+zsbD3//PMaOnRoNKcEoJPjt4eAv+jMr4eonUH53e9+p8LCQj355JPavXu3srOzNWrUKB0/fjxaUwIAAJaIWkBZtGiRpk+frqlTp+rWW2/V0qVLFR8frxdffDFaUwIAAJaIykc8X375pXbt2qV58+YF22JiYpSXl6fy8vJm4/1+v/x+f3C5rq5OknTq1CkFAoFgeyAQUGNjo06ePCm3293u+cWebWj3uleK2CajxsYmxQZidK7JFe3pdGqduZYnT56M9hRCtOU1fjmvX9v2N5Kc+u8l7K6lba+H06dPS5KMMa2OjUpAOXHihM6dO6fevXuHtPfu3Vv/+7//22x8SUmJiouLm7VnZWVFbI6QJkV7AleQzlrL5F9HewYd62rbX+BSIvl6OH36tLp3737JMZ3iRm3z5s1TYWFhcLmpqUmnTp1Sz5495XL95f9IfT6fMjIy9Nlnn8nr9UZjqlcMaukcaukcaukcaukcatl2xhidPn1a6enprY6NSkBJTk5Wly5ddOzYsZD2Y8eOKTU1tdl4j8cjj8cT0paYmHjR7Xu9Xg4Sh1BL51BL51BL51BL51DLtmntzMkFUblINi4uToMHD9amTZuCbU1NTdq0aZNyc3OjMSUAAGCRqH3EU1hYqClTpmjIkCEaOnSonn32WTU0NGjq1KnRmhIAALBE1ALKAw88oD//+c964oknVFNTo9tvv13r1q1rduFsODwej5588slmHwchfNTSOdTSOdTSOdTSOdQyMlymLd/1AQAA6ED8Fg8AALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOtYH1C2bdumsWPHKj09XS6XS2vWrAnpN8boiSeeUFpamrp166a8vDwdOHAgZMypU6c0efJkeb1eJSYmatq0aaqvr+/AvbBDSUmJvvWtbykhIUEpKSkaN26c9u/fHzLmzJkzmjVrlnr27KlrrrlG48ePb3bH3yNHjujuu+9WfHy8UlJS9A//8A86e/ZsR+5K1C1ZskQDBgwI3jkyNzdXb7/9drCfOrbPwoUL5XK59OijjwbbqGXbFRUVyeVyhTz69esX7KeWbff555/rwQcfVM+ePdWtWzfddttt2rlzZ7Cf954OYCz31ltvmX/8x380r776qpFkVq9eHdK/cOFC0717d7NmzRrz3//93+aee+4xWVlZ5osvvgiOGT16tMnOzjY7duww7777rrnhhhvMxIkTO3hPom/UqFFm+fLlZt++faaystKMGTPGZGZmmvr6+uCYmTNnmoyMDLNp0yazc+dO8+1vf9v89V//dbD/7Nmzpn///iYvL8/s2bPHvPXWWyY5OdnMmzcvGrsUNa+//rp58803zSeffGL2799vfv7znxu322327dtnjKGO7fH++++bvn37mgEDBpgf//jHwXZq2XZPPvmk+eY3v2mqq6uDjz//+c/BfmrZNqdOnTJ9+vQxDz30kKmoqDB//OMfzfr1601VVVVwDO89kWd9QPmqrweUpqYmk5qaap555plgW21trfF4POa3v/2tMcaY//mf/zGSzAcffBAc8/bbbxuXy2U+//zzDpu7jY4fP24kma1btxpjztfO7XabV155JTjm448/NpJMeXm5MeZ8YIyJiTE1NTXBMUuWLDFer9f4/f6O3QHL9OjRw/zHf/wHdWyH06dPmxtvvNGUlZWZ73znO8GAQi3D8+STT5rs7OwW+6hl282dO9fceeedF+3nvadjWP8Rz6UcOnRINTU1ysvLC7Z1795dOTk5Ki8vlySVl5crMTFRQ4YMCY7Jy8tTTEyMKioqOnzONqmrq5MkJSUlSZJ27dqlQCAQUs9+/fopMzMzpJ633XZbyB1/R40aJZ/Pp48++qgDZ2+Pc+fOadWqVWpoaFBubi51bIdZs2bp7rvvDqmZxDHZHgcOHFB6erquu+46TZ48WUeOHJFELcPx+uuva8iQIbr//vuVkpKigQMH6t///d+D/bz3dIxOHVBqamokqdnt8Xv37h3sq6mpUUpKSkh/bGyskpKSgmOuRk1NTXr00Ud1xx13qH///pLO1youLq7ZL0V/vZ4t1ftC39Xkww8/1DXXXCOPx6OZM2dq9erVuvXWW6ljmFatWqXdu3erpKSkWR+1DE9OTo5eeuklrVu3TkuWLNGhQ4f0N3/zNzp9+jS1DMMf//hHLVmyRDfeeKPWr1+vhx9+WD/60Y/08ssvS+K9p6NE7bd4EF2zZs3Svn37tH379mhPpdO6+eabVVlZqbq6Ov3hD3/QlClTtHXr1mhPq1P57LPP9OMf/1hlZWXq2rVrtKfT6RUUFAT/HjBggHJyctSnTx/9/ve/V7du3aI4s86lqalJQ4YM0T/90z9JkgYOHKh9+/Zp6dKlmjJlSpRnd/Xo1GdQUlNTJanZVejHjh0L9qWmpur48eMh/WfPntWpU6eCY642s2fP1tq1a/XOO+/o2muvDbanpqbqyy+/VG1tbcj4r9ezpXpf6LuaxMXF6YYbbtDgwYNVUlKi7Oxs/eY3v6GOYdi1a5eOHz+uQYMGKTY2VrGxsdq6dauee+45xcbGqnfv3tTyMiQmJuqmm25SVVUVx2UY0tLSdOutt4a03XLLLcGPy3jv6RidOqBkZWUpNTVVmzZtCrb5fD5VVFQoNzdXkpSbm6va2lrt2rUrOGbz5s1qampSTk5Oh885mowxmj17tlavXq3NmzcrKysrpH/w4MFyu90h9dy/f7+OHDkSUs8PP/ww5IVXVlYmr9fb7AV9tWlqapLf76eOYRgxYoQ+/PBDVVZWBh9DhgzR5MmTg39Ty/arr6/XwYMHlZaWxnEZhjvuuKPZLRg++eQT9enTRxLvPR0m2lfptub06dNmz549Zs+ePUaSWbRokdmzZ4/59NNPjTHnv+qVmJhoXnvtNbN3717zve99r8Wveg0cONBUVFSY7du3mxtvvPGq/KrXww8/bLp37262bNkS8jXExsbG4JiZM2eazMxMs3nzZrNz506Tm5trcnNzg/0XvoY4cuRIU1lZadatW2d69ep11X0N8bHHHjNbt241hw4dMnv37jWPPfaYcblcZsOGDcYY6ng5vvotHmOoZTh+8pOfmC1btphDhw6Z//qv/zJ5eXkmOTnZHD9+3BhDLdvq/fffN7GxseaXv/ylOXDggFm5cqWJj483K1asCI7hvSfyrA8o77zzjpHU7DFlyhRjzPmve82fP9/07t3beDweM2LECLN///6QbZw8edJMnDjRXHPNNcbr9ZqpU6ea06dPR2FvoqulOkoyy5cvD4754osvzN///d+bHj16mPj4eHPvvfea6urqkO0cPnzYFBQUmG7dupnk5GTzk5/8xAQCgQ7em+j6wQ9+YPr06WPi4uJMr169zIgRI4LhxBjqeDm+HlCoZds98MADJi0tzcTFxZm/+qu/Mg888EDIvTuoZdu98cYbpn///sbj8Zh+/fqZZcuWhfTz3hN5LmOMic65GwAAgJZ16mtQAADAlYmAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADW+X93U5a4KAP3LQAAAABJRU5ErkJggg==\n",
157
+ "text/plain": [
158
+ "<Figure size 640x480 with 1 Axes>"
159
+ ]
160
+ },
161
+ "metadata": {},
162
+ "output_type": "display_data"
163
+ }
164
+ ],
165
+ "source": [
166
+ "LEN = []\n",
167
+ "for i in range(877):\n",
168
+ " LEN.append(len(dataset_text['text'][i]))\n",
169
+ "import numpy as np\n",
170
+ "import pandas as pd\n",
171
+ "\n",
172
+ "pd.DataFrame(np.array(LEN)).hist(bins = 30)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "metadata": {},
178
+ "source": [
179
+ "# Tokenization of the dataset"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": 6,
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "data": {
189
+ "application/vnd.jupyter.widget-view+json": {
190
+ "model_id": "b9919f91c04f49428be62ac33921ec7d",
191
+ "version_major": 2,
192
+ "version_minor": 0
193
+ },
194
+ "text/plain": [
195
+ "Map: 0%| | 0/877 [00:00<?, ? examples/s]"
196
+ ]
197
+ },
198
+ "metadata": {},
199
+ "output_type": "display_data"
200
+ },
201
+ {
202
+ "data": {
203
+ "text/plain": [
204
+ "Dataset({\n",
205
+ " features: ['input_ids', 'attention_mask'],\n",
206
+ " num_rows: 877\n",
207
+ "})"
208
+ ]
209
+ },
210
+ "execution_count": 6,
211
+ "metadata": {},
212
+ "output_type": "execute_result"
213
+ }
214
+ ],
215
+ "source": [
216
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
217
+ "\n",
218
+ "tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-Instruct-v0.1\")\n",
219
+ "tokenizer.pad_token = tokenizer.eos_token\n",
220
+ "tokenizer.padding_side = \"right\" \n",
221
+ "\n",
222
+ "def tokenize_function(examples):\n",
223
+ " result = tokenizer(examples[\"text\"])\n",
224
+ " return result\n",
225
+ "\n",
226
+ "tokenized_datasets = dataset_text.map(\n",
227
+ " tokenize_function, batched=True, remove_columns=[\"text\"]\n",
228
+ ")\n",
229
+ "tokenized_datasets"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 7,
235
+ "metadata": {
236
+ "tags": []
237
+ },
238
+ "outputs": [],
239
+ "source": [
240
+ "tokenizer.mask_token = '<MASK>'\n",
241
+ "collator = DataCollatorForLanguageModeling(mlm = True,mlm_probability=0.15,tokenizer = tokenizer)"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "markdown",
246
+ "metadata": {
247
+ "id": "rjOMoSbGSxx9"
248
+ },
249
+ "source": [
250
+ "# Foundation model"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 8,
256
+ "metadata": {
257
+ "id": "ZwXZbQ2dSwzI",
258
+ "outputId": "a57e521a-a8a3-48e9-a478-63334083f94a"
259
+ },
260
+ "outputs": [
261
+ {
262
+ "name": "stderr",
263
+ "output_type": "stream",
264
+ "text": [
265
+ "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
266
+ ]
267
+ },
268
+ {
269
+ "data": {
270
+ "application/vnd.jupyter.widget-view+json": {
271
+ "model_id": "87d196af4c864c2f9381a18ceb5720e5",
272
+ "version_major": 2,
273
+ "version_minor": 0
274
+ },
275
+ "text/plain": [
276
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
277
+ ]
278
+ },
279
+ "metadata": {},
280
+ "output_type": "display_data"
281
+ }
282
+ ],
283
+ "source": [
284
+ "bnb_config = BitsAndBytesConfig(\n",
285
+ " load_in_4bit=True,\n",
286
+ " bnb_4bit_quant_type=\"nf4\",\n",
287
+ " bnb_4bit_compute_dtype=torch.float16,\n",
288
+ ")\n",
289
+ "\n",
290
+ "model = AutoModelForCausalLM.from_pretrained(\n",
291
+ " \"mistralai/Mistral-7B-Instruct-v0.1\",\n",
292
+ " device_map=\"auto\",\n",
293
+ " torch_dtype=torch.float16, #torch.bfloat16,\n",
294
+ " trust_remote_code=True\n",
295
+ " )"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "markdown",
300
+ "metadata": {
301
+ "id": "NuAx3zBeUL1q"
302
+ },
303
+ "source": [
304
+ "## LoRa configuration"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": 9,
310
+ "metadata": {
311
+ "id": "dQdvjTYTT1vQ",
312
+ "tags": []
313
+ },
314
+ "outputs": [],
315
+ "source": [
316
+ "lora_alpha = 16\n",
317
+ "lora_dropout = 0.1\n",
318
+ "lora_r = 64\n",
319
+ "\n",
320
+ "peft_config = LoraConfig(\n",
321
+ " lora_alpha=lora_alpha,\n",
322
+ " lora_dropout=lora_dropout,\n",
323
+ " r=lora_r,\n",
324
+ " bias=\"none\",\n",
325
+ " task_type=\"CAUSAL_LM\",\n",
326
+ " target_modules=[\n",
327
+ " \"Wqkv\",\n",
328
+ " \"out_proj\",\n",
329
+ " \"up_proj\",\n",
330
+ " \"down_proj\",\n",
331
+ " ])\n"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "markdown",
336
+ "metadata": {},
337
+ "source": [
338
+ "# Training parameters"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 10,
344
+ "metadata": {
345
+ "tags": []
346
+ },
347
+ "outputs": [],
348
+ "source": [
349
+ "output_dir = \"/MY_DIRECTORY\"\n",
350
+ "per_device_train_batch_size = 1\n",
351
+ "gradient_accumulation_steps = 16 \n",
352
+ "optim = \"paged_adamw_32bit\"\n",
353
+ "save_steps = 55 \n",
354
+ "logging_steps = 55\n",
355
+ "learning_rate = 1e-4\n",
356
+ "max_grad_norm = 0.3\n",
357
+ "max_steps = 55 * 15 \n",
358
+ "warmup_ratio = 0.03\n",
359
+ "lr_scheduler_type = \"linear\"\n",
360
+ "\n",
361
+ "training_arguments = TrainingArguments(\n",
362
+ " output_dir=output_dir,\n",
363
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
364
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
365
+ " optim=optim,\n",
366
+ " logging_steps=logging_steps,\n",
367
+ " save_strategy= 'no', #''epoch',\n",
368
+ " #save_steps=save_steps,\n",
369
+ " #evaluation_strategy = \"steps\",#\"epoch\",\n",
370
+ " learning_rate=learning_rate,\n",
371
+ " fp16=True,\n",
372
+ " max_grad_norm=max_grad_norm,\n",
373
+ " max_steps=max_steps,\n",
374
+ " warmup_ratio=warmup_ratio,\n",
375
+ " group_by_length=True,\n",
376
+ " lr_scheduler_type=lr_scheduler_type,\n",
377
+ " report_to = 'none',\n",
378
+ " save_total_limit = 1\n",
379
+ ")"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "markdown",
384
+ "metadata": {},
385
+ "source": [
386
+ "# Training"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": 11,
392
+ "metadata": {},
393
+ "outputs": [],
394
+ "source": [
395
+ "model = get_peft_model(model, peft_config)"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "code",
400
+ "execution_count": 12,
401
+ "metadata": {
402
+ "tags": []
403
+ },
404
+ "outputs": [],
405
+ "source": [
406
+ "trainer = Trainer(\n",
407
+ " model=model,\n",
408
+ " tokenizer=tokenizer,\n",
409
+ " data_collator=collator,\n",
410
+ " train_dataset=tokenized_datasets,\n",
411
+ " #eval_dataset=\n",
412
+ " args=training_arguments,\n",
413
+ ")"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": null,
419
+ "metadata": {},
420
+ "outputs": [],
421
+ "source": [
422
+ "trainer.train()"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "markdown",
427
+ "metadata": {},
428
+ "source": [
429
+ "# Save model"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": 28,
435
+ "metadata": {
436
+ "execution": {
437
+ "iopub.execute_input": "2023-11-12T18:43:12.964677Z",
438
+ "iopub.status.busy": "2023-11-12T18:43:12.964270Z",
439
+ "iopub.status.idle": "2023-11-12T18:43:13.685390Z",
440
+ "shell.execute_reply": "2023-11-12T18:43:13.684268Z",
441
+ "shell.execute_reply.started": "2023-11-12T18:43:12.964645Z"
442
+ }
443
+ },
444
+ "outputs": [],
445
+ "source": [
446
+ "trainer.save_model(output_dir)"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "markdown",
451
+ "metadata": {},
452
+ "source": [
453
+ "# Reload the model"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "metadata": {},
460
+ "outputs": [],
461
+ "source": [
462
+ "model1 = AutoPeftModelForCausalLM.from_pretrained(output_dir, load_in_4bit=True)\n",
463
+ "tokenizer1 = AutoTokenizer.from_pretrained(output_dir)"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "markdown",
468
+ "metadata": {},
469
+ "source": [
470
+ "# Prompt preparation"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "markdown",
475
+ "metadata": {},
476
+ "source": [
477
+ "## Criteria for early stopping during generation"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": [
486
+ "from transformers import StoppingCriteria,StoppingCriteriaList"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "metadata": {},
493
+ "outputs": [],
494
+ "source": [
495
+ "class StopOnTokens(StoppingCriteria):\n",
496
+ " def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n",
497
+ " stop_ids = [28723] # corresponding to '.'\n",
498
+ " for stop_id in stop_ids:\n",
499
+ " if input_ids[0][-1] == stop_id:\n",
500
+ " return True\n",
501
+ " return False"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "markdown",
506
+ "metadata": {},
507
+ "source": [
508
+ "## Prompt answers"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": null,
514
+ "metadata": {},
515
+ "outputs": [],
516
+ "source": [
517
+ "\n",
518
+ "text = \"At Laurent restaurant : do you have any vegetarian options?\"\n",
519
+ "#text = \"At Laurent restaurant: do you have Apple pie?\"\n",
520
+ "#text = \"At Laurent restaurant: what is included in the Premium Sweetheart Set for Two?\"\n",
521
+ "#text = \"At Laurent restaurant: do you have Seafood Paella?\"\n",
522
+ "#text = \"At Laurent restaurant: what is the best menu?\"\n",
523
+ "\n",
524
+ "inputs = tokenizer1(text, return_tensors=\"pt\").to('cuda')\n",
525
+ "out = model1.generate(**inputs, \n",
526
+ " pad_token_id=tokenizer.eos_token_id,\n",
527
+ " stopping_criteria = StoppingCriteriaList([StopOnTokens()]),\n",
528
+ " max_new_tokens=100\n",
529
+ " )\n",
530
+ "\n",
531
+ "tokenizer1.decode(out[0],skip_special_tokens=True).split(\"[answer]:\")[1]\n"
532
+ ]
533
+ }
534
+ ],
535
+ "metadata": {
536
+ "kernelspec": {
537
+ "display_name": "Python 3 (ipykernel)",
538
+ "language": "python",
539
+ "name": "python3"
540
+ },
541
+ "language_info": {
542
+ "codemirror_mode": {
543
+ "name": "ipython",
544
+ "version": 3
545
+ },
546
+ "file_extension": ".py",
547
+ "mimetype": "text/x-python",
548
+ "name": "python",
549
+ "nbconvert_exporter": "python",
550
+ "pygments_lexer": "ipython3",
551
+ "version": "3.9.13"
552
+ }
553
+ },
554
+ "nbformat": 4,
555
+ "nbformat_minor": 4
556
+ }