{
"cells": [
{
"cell_type": "markdown",
"id": "1c71aba7-c0f3-4378-9b63-55529e0994b4",
"metadata": {},
"source": [
"# Data\n",
"\n",
"We use the following dataset for fine-tuning:\n",
"\n",
"- [arXiv papers](https://www.kaggle.com/datasets/neelshah18/arxivdataset)\n",
"\n",
"arXiv also hosts papers on computational biology, genomics, etc.\n",
"\n",
"An alternative is the [dataset](https://zenodo.org/record/7695390) from a [recent study](https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1.full.pdf) with titles and labels of PubMed papers. It contains 20 million papers, but only their titles are included (no abstracts).\n",
"\n",
"In this notebook we use the arXiv data and tags."
]
},
{
"cell_type": "markdown",
"id": "e9874f4a-3898-4c89-a0f7-04eeabf2b389",
"metadata": {
"tags": []
},
"source": [
"# Models\n",
"\n",
"As the base model we use BERT pretrained on biomedical data (from PubMed).\n",
"\n",
"- [BiomedNLP-PubMedBERT](https://huggingface.co./microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract)"
]
},
{
"cell_type": "markdown",
"id": "991e48e7-897f-45a3-8a0b-539ea67b4eb5",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"id": "2f130f05-21ee-46f9-889f-488e8c676aba",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "757a0582-1b8c-4f1c-b26f-544688e391f4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"import evaluate\n",
"from datasets import Dataset, ClassLabel\n",
"from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
"from transformers import TrainingArguments, Trainer\n",
"from transformers import pipeline"
]
},
{
"cell_type": "markdown",
"id": "03847b87-d096-49a5-b6e2-023fa08b94c2",
"metadata": {},
"source": [
"# Load data"
]
},
{
"cell_type": "markdown",
"id": "b3e902ea-4e0f-4d76-b27b-59e472b2b556",
"metadata": {},
"source": [
"Load the data for fine-tuning: in particular, we need the paper titles, their abstracts, and their tags."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1be8f69e-bd7d-4ca9-ba9f-044b8e7bc497",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"df = pd.read_json(\"arxivData.json\")"
]
},
{
"cell_type": "markdown",
"id": "d5b6158a-728e-4ada-bcdc-a4a49328f002",
"metadata": {},
"source": [
"Concatenate the titles and abstracts and store the result in the `text` column:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c8709a7b-becf-4f19-8b4f-8773cd5c60f1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"df['text'] = df['title'] + \"\\n\" + df['summary']"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ed0ed687-6439-494a-a5a8-c572bc2e4059",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>author</th>\n",
"      <th>day</th>\n",
"      <th>id</th>\n",
"      <th>link</th>\n",
"      <th>month</th>\n",
"      <th>summary</th>\n",
"      <th>tag</th>\n",
"      <th>title</th>\n",
"      <th>year</th>\n",
"      <th>text</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>[{'name': 'Ahmed Osman'}, {'name': 'Wojciech S...</td>\n",
"      <td>1</td>\n",
"      <td>1802.00209v1</td>\n",
"      <td>[{'rel': 'alternate', 'href': 'http://arxiv.or...</td>\n",
"      <td>2</td>\n",
"      <td>We propose an architecture for VQA which utili...</td>\n",
"      <td>[{'term': 'cs.AI', 'scheme': 'http://arxiv.org...</td>\n",
"      <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
"      <td>2018</td>\n",
"      <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>[{'name': 'Ji Young Lee'}, {'name': 'Franck De...</td>\n",
"      <td>12</td>\n",
"      <td>1603.03827v1</td>\n",
"      <td>[{'rel': 'alternate', 'href': 'http://arxiv.or...</td>\n",
"      <td>3</td>\n",
"      <td>Recent approaches based on artificial neural n...</td>\n",
"      <td>[{'term': 'cs.CL', 'scheme': 'http://arxiv.org...</td>\n",
"      <td>Sequential Short-Text Classification with Recu...</td>\n",
"      <td>2016</td>\n",
"      <td>Sequential Short-Text Classification with Recu...</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" author day id \\\n",
"0 [{'name': 'Ahmed Osman'}, {'name': 'Wojciech S... 1 1802.00209v1 \n",
"1 [{'name': 'Ji Young Lee'}, {'name': 'Franck De... 12 1603.03827v1 \n",
"\n",
" link month \\\n",
"0 [{'rel': 'alternate', 'href': 'http://arxiv.or... 2 \n",
"1 [{'rel': 'alternate', 'href': 'http://arxiv.or... 3 \n",
"\n",
" summary \\\n",
"0 We propose an architecture for VQA which utili... \n",
"1 Recent approaches based on artificial neural n... \n",
"\n",
" tag \\\n",
"0 [{'term': 'cs.AI', 'scheme': 'http://arxiv.org... \n",
"1 [{'term': 'cs.CL', 'scheme': 'http://arxiv.org... \n",
"\n",
" title year \\\n",
"0 Dual Recurrent Attention Units for Visual Ques... 2018 \n",
"1 Sequential Short-Text Classification with Recu... 2016 \n",
"\n",
" text \n",
"0 Dual Recurrent Attention Units for Visual Ques... \n",
"1 Sequential Short-Text Classification with Recu... "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head(2)"
]
},
{
"cell_type": "markdown",
"id": "ce1de806-a4d2-4e58-a3a8-f3542392f22e",
"metadata": {},
"source": [
"## Labels"
]
},
{
"cell_type": "markdown",
"id": "b5183517-8b02-47bc-812a-415b5651e07d",
"metadata": {},
"source": [
"We use arXiv categories as labels, such as `astro-ph` for astrophysics papers or `cs.CV` for computer vision (computer science). Each paper gets a single label: the first term in its `tag` list."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ba4e7197-23b6-4cb4-9b44-620c6b730eb7",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total: 126 labels such as adap-org, astro-ph, ..., stat.OT\n"
]
}
],
"source": [
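"import ast\n",
"\n",
"# Parse the stringified list of tags safely (no eval) and keep the first term as the category\n",
"df['category'] = [ast.literal_eval(i)[0]['term'].strip() for i in df['tag']]\n",
"categories = np.unique(df['category'])\n",
"num_labels = len(categories)\n",
"print(f\"Total: {num_labels} labels such as {categories[0]}, {categories[1]}, ..., {categories[-1]}\")"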
]
},
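{
"cell_type": "markdown",
"id": "sketch-tag-parsing",
"metadata": {},
"source": [
"A minimal sketch of the label extraction above, assuming (as the `df.head(2)` output suggests) that each `tag` entry is a string holding a Python-style list of dicts with a `term` key. The helper `first_term` and the sample tags are hypothetical. `ast.literal_eval` only parses literals, so unlike `eval` it cannot execute code hidden in the data:\n",
"\n",
"```python\n",
"import ast\n",
"\n",
"def first_term(tag_string):\n",
"    # Parse the string as a Python literal (a list of dicts) without running any code\n",
"    tags = ast.literal_eval(tag_string)\n",
"    # Keep only the first category, stripped of surrounding whitespace\n",
"    return tags[0]['term'].strip()\n",
"\n",
"# Hypothetical sample with the same shape as the tag column\n",
"sample = str([{'term': 'cs.AI'}, {'term': 'cs.LG'}])\n",
"print(first_term(sample))  # cs.AI\n",
"```\n",
"\n",
"Only the first tag is kept, so papers cross-listed in several categories are assigned to their first-listed one."
]
},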
{
"cell_type": "code",
"execution_count": 6,
"id": "1508a6d9-856d-4ecf-a0f3-895d3ffbe99b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
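"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>category</th>\n",
"      <th>category_index</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>adap-org</td>\n",
"      <td>0</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>astro-ph</td>\n",
"      <td>1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>astro-ph.CO</td>\n",
"      <td>2</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>astro-ph.EP</td>\n",
"      <td>3</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>astro-ph.GA</td>\n",
"      <td>4</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],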
"text/plain": [
" category category_index\n",
"0 adap-org 0\n",
"1 astro-ph 1\n",
"2 astro-ph.CO 2\n",
"3 astro-ph.EP 3\n",
"4 astro-ph.GA 4"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame({\n",
"    \"category\": categories,\n",
"    \"category_index\": np.arange(num_labels),\n",
"}).head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5c082c3a-7b0e-4320-b62d-f75a6c9f2398",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Attach to every paper the index of its category in the sorted `categories` array\n",
"df = pd.DataFrame({\n",
"    \"category\": categories,\n",
"    \"category_index\": np.arange(num_labels),\n",
"}).set_index(\"category\").join(df.set_index(\"category\"), how=\"right\", sort=False).reset_index()"
]
},
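{
"cell_type": "markdown",
"id": "sketch-category-index",
"metadata": {},
"source": [
"The join above gives every paper the position of its category in the sorted `categories` array. The same mapping can be sketched in one call with `np.unique(..., return_inverse=True)`, which returns the sorted unique values plus, for each element, its index into them. The per-paper categories below are made up:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical per-paper categories\n",
"paper_categories = ['cs.LG', 'astro-ph', 'cs.AI', 'cs.LG']\n",
"\n",
"# categories: sorted unique labels; category_index: position of each paper's label in it\n",
"categories, category_index = np.unique(paper_categories, return_inverse=True)\n",
"\n",
"print(categories.tolist())      # ['astro-ph', 'cs.AI', 'cs.LG']\n",
"print(category_index.tolist())  # [2, 0, 1, 2]\n",
"```\n",
"\n",
"Because `np.unique` sorts its output, these indices agree with a `category_index` built from `categories = np.unique(df['category'])`."
]
},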
{
"cell_type": "markdown",
"id": "76d8ccb9-a993-4d82-9dd3-689380e92e55",
"metadata": {},
"source": [
"# Model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a0c154f7-d2fa-46a1-8b69-57174bf00632",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"id": "2bf6513d-664d-4b94-8b05-7e8df205e3ec",
"metadata": {},
"source": [
"Tokenizer (title + abstract -> tokens):"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "12fa49a7-2ac5-4f78-84fe-93305926692e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")"
]
},
{
"cell_type": "markdown",
"id": "0ea1b4e5-9067-4292-ba12-8f560bbf26fd",
"metadata": {},
"source": [
"The model itself: `AutoModelForSequenceClassification` replaces the pretraining head with a newly initialized classification head with `num_labels` outputs:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d6eb92bc-c293-47ad-b9cc-2a63e8f1de69",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = AutoModelForSequenceClassification.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\", num_labels=num_labels).to(device)"
]
},
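{
"cell_type": "markdown",
"id": "sketch-classifier-head",
"metadata": {},
"source": [
"According to the printed architecture, the new head is a `dropout` layer followed by `classifier`, a `Linear(768 -> 126)` layer applied to the pooled `[CLS]` representation from `BertPooler`. A NumPy sketch of the shapes involved at inference time (when dropout is inactive); the random values only stand in for real activations and for the newly initialized weights:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"hidden_size, num_labels = 768, 126  # sizes from the printed model\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Stand-in for the BertPooler output (tanh of a dense layer over [CLS]), batch of 2\n",
"pooled = np.tanh(rng.normal(size=(2, hidden_size)))\n",
"\n",
"# Random classifier weights, stored as (out_features, in_features) like torch.nn.Linear\n",
"W = rng.normal(scale=0.02, size=(num_labels, hidden_size))\n",
"b = np.zeros(num_labels)\n",
"\n",
"# Linear layer: one logit per arXiv category\n",
"logits = pooled @ W.T + b\n",
"print(logits.shape)  # (2, 126)\n",
"```"
]
},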
{
"cell_type": "code",
"execution_count": 11,
"id": "f5c79846-e6fc-42c0-bb8d-949678f5e60a",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BertForSequenceClassification(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0-11): 12 x BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (classifier): Linear(in_features=768, out_features=126, bias=True)\n",
")\n"
]
}
],
"source": [
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "5ce6eefc-91ce-4486-9568-b686d04adcc7",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "markdown",
"id": "71add72c-eafb-491a-8820-31ce7336524f",
"metadata": {},
"source": [
"## Data Loaders"
]
},
{
"cell_type": "markdown",
"id": "2a0b579c-998a-4d2e-bf0e-d4c7406d22da",
"metadata": {},
"source": [
"When working with `transformers`, it may be more convenient to handle the data with the `datasets` library."
]
},
{
"cell_type": "markdown",
"id": "47b0e14a-866b-49ac-8b95-49a91a0bcc22",
"metadata": {},
"source": [
"Split the data into train and test sets and create (Hugging Face) [datasets](https://huggingface.co./docs/datasets/tabular_load#pandas-dataframes) from them:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "dc1a3f33-0ef9-43c9-ab5f-eb9ae304b897",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"train_indices = np.sort(np.random.choice(np.arange(len(df)), size=37_000, replace=False))\n",
"# All remaining rows go to the test set (np.setdiff1d returns them sorted)\n",
"test_indices = np.setdiff1d(np.arange(len(df)), train_indices)"
]
},
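{
"cell_type": "markdown",
"id": "sketch-train-test-split",
"metadata": {},
"source": [
"The split samples 37,000 row positions without replacement for training and assigns every remaining row to the test set (4,000 papers, judging by the `Map` progress bars further down). A toy-sized sketch of the same procedure, with an arbitrary 10 rows and 7 training rows, that also checks the two index sets are disjoint and together cover every row:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"n_rows, train_size = 10, 7  # arbitrary toy sizes\n",
"np.random.seed(42)\n",
"\n",
"# Sample train positions without replacement, then sort them\n",
"train_indices = np.sort(np.random.choice(np.arange(n_rows), size=train_size, replace=False))\n",
"# Every position that was not sampled becomes a test position (sorted)\n",
"test_indices = np.setdiff1d(np.arange(n_rows), train_indices)\n",
"\n",
"# Disjoint and exhaustive\n",
"assert len(np.intersect1d(train_indices, test_indices)) == 0\n",
"assert len(train_indices) + len(test_indices) == n_rows\n",
"print(train_indices, test_indices)\n",
"```"
]
},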
{
"cell_type": "code",
"execution_count": 14,
"id": "d948f8a6-1a7a-4baa-88a0-418596a1f275",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"train_df = df.loc[:,[\"text\", \"category\"]].iloc[train_indices]\n",
"test_df = df.loc[:,[\"text\", \"category\"]].iloc[test_indices]\n",
"\n",
"train_ds = Dataset.from_pandas(train_df, split=\"train\")\n",
"test_ds = Dataset.from_pandas(test_df, split=\"test\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "50242a35-3067-41e5-8de8-f7e6a4fb6e9c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/37000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/4000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def tokenize_text(row):\n",
"    return tokenizer(\n",
"        row[\"text\"],\n",
"        max_length=512,\n",
"        truncation=True,\n",
"        padding='max_length',\n",
"    )\n",
"\n",
"train_ds = train_ds.map(tokenize_text, batched=True)\n",
"test_ds = test_ds.map(tokenize_text, batched=True)"
]
},
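{
"cell_type": "markdown",
"id": "sketch-padding",
"metadata": {},
"source": [
"With `truncation=True` and `padding='max_length'`, every example becomes exactly `max_length` token ids plus an `attention_mask` that marks real tokens with 1 and padding with 0. A simplified pure-Python sketch of that behavior: the helper `pad_and_truncate`, the `[CLS]`/`[SEP]` ids and the toy `max_length` are assumptions (the real tokenizer also performs WordPiece splitting and returns `token_type_ids`), while pad id 0 matches `padding_idx=0` in the printed embeddings:\n",
"\n",
"```python\n",
"def pad_and_truncate(token_ids, max_length, cls_id=2, sep_id=3, pad_id=0):\n",
"    # cls_id and sep_id are hypothetical; the real ids come from the tokenizer\n",
"    # Leave room for [CLS] and [SEP], dropping tokens beyond the limit\n",
"    body = token_ids[:max_length - 2]\n",
"    input_ids = [cls_id] + body + [sep_id]\n",
"    # 1 marks a real token, 0 marks padding\n",
"    attention_mask = [1] * len(input_ids)\n",
"    n_pad = max_length - len(input_ids)\n",
"    input_ids += [pad_id] * n_pad\n",
"    attention_mask += [0] * n_pad\n",
"    return {'input_ids': input_ids, 'attention_mask': attention_mask}\n",
"\n",
"print(pad_and_truncate([11, 12, 13], max_length=8))\n",
"# {'input_ids': [2, 11, 12, 13, 3, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 0, 0, 0]}\n",
"print(pad_and_truncate(list(range(100, 120)), max_length=8)['input_ids'])\n",
"# [2, 100, 101, 102, 103, 104, 105, 3]\n",
"```\n",
"\n",
"Padding everything to 512 keeps tensor shapes fixed, at the cost of extra computation on short abstracts."
]
},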
{
"cell_type": "code",
"execution_count": 77,
"id": "35d454d1-fbdc-4847-8b60-4c6c442364b1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/37000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/4000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Casting the dataset: 0%| | 0/37000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Casting the dataset: 0%| | 0/4000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"labels_map = ClassLabel(num_classes=num_labels, names=list(categories))\n",
"\n",
"def transform_labels(row):\n",
"    # The default data collator expects the target column to be named `label` (or `label_ids`)\n",
"    return {\"label\": labels_map.str2int(row[\"category\"])}\n",
"\n",
"# OR:\n",
"#\n",
"# labels_map = pd.Series(\n",
"#     np.arange(num_labels),\n",
"#     index=categories,\n",
"# )\n",
"#\n",
"# def transform_labels(row):\n",
"#     return {\"label\": labels_map[row[\"category\"]].tolist()}\n",
"\n",
"train_ds = train_ds.map(transform_labels, batched=True)\n",
"test_ds = test_ds.map(transform_labels, batched=True)\n",
"\n",
"train_ds = train_ds.cast_column('label', labels_map)\n",
"test_ds = test_ds.cast_column('label', labels_map)"
]
},
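{
"cell_type": "markdown",
"id": "sketch-label-mapping",
"metadata": {},
"source": [
"`ClassLabel.str2int` turns category names into integer ids, and the `id2label`/`label2id` dicts passed to the model further down store the same mapping in both directions. A plain-Python sketch of that bidirectional mapping over a made-up subset of categories (the helper `str2int` here is hypothetical, not the `datasets` method):\n",
"\n",
"```python\n",
"# Hypothetical subset of the sorted category names\n",
"names = ['adap-org', 'astro-ph', 'astro-ph.CO']\n",
"\n",
"# Integer id -> name and name -> integer id, built from list positions\n",
"id2label = {i: name for i, name in enumerate(names)}\n",
"label2id = {name: i for i, name in id2label.items()}\n",
"\n",
"def str2int(values):\n",
"    # Accept a single name or a batch (list) of names, as in a batched map()\n",
"    if isinstance(values, str):\n",
"        return label2id[values]\n",
"    return [label2id[v] for v in values]\n",
"\n",
"print(str2int('astro-ph'))                   # 1\n",
"print(str2int(['astro-ph.CO', 'adap-org']))  # [2, 0]\n",
"print(id2label[2])                           # astro-ph.CO\n",
"```"
]
},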
{
"cell_type": "markdown",
"id": "811c5fe3-218e-4187-878d-65abc157f802",
"metadata": {},
"source": [
"## Prepare training"
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "d2160c7d-4130-47ae-9d6d-6684e4ba7e9b",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = AutoModelForSequenceClassification.from_pretrained(\n",
"    \"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\",\n",
"    num_labels=num_labels,\n",
"    id2label={i: labels_map.names[i] for i in range(len(categories))},\n",
"    label2id={labels_map.names[i]: i for i in range(len(categories))},\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": 111,
"id": "72e74c2b-89d7-4c17-8df1-dcfd40ead01e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")"
]
},
{
"cell_type": "markdown",
"id": "ebb91037-fbdf-4453-87de-6da5eec3304f",
"metadata": {},
"source": [
"We will compute accuracy:"
]
},
{
"cell_type": "code",
"execution_count": 112,
"id": "630f6fa5-4c53-4962-b36d-5ee9aad6e29d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"metric = evaluate.load(\"accuracy\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
"    logits, labels = eval_pred\n",
"    predictions = np.argmax(logits, axis=-1)\n",
"    return metric.compute(predictions=predictions, references=labels)"
]
},
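{
"cell_type": "markdown",
"id": "sketch-accuracy",
"metadata": {},
"source": [
"`compute_metrics` receives the raw logits and the true label ids, takes the arg-max class per example and lets the `evaluate` accuracy metric compare them. Accuracy is simply the fraction of matching ids, so a NumPy-only sketch on made-up logits for 3 examples and 4 classes looks like this:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def accuracy_from_logits(logits, labels):\n",
"    # Predicted class = index of the largest logit in each row\n",
"    predictions = np.argmax(logits, axis=-1)\n",
"    # Fraction of examples whose prediction equals the reference label\n",
"    return {'accuracy': float(np.mean(predictions == labels))}\n",
"\n",
"logits = np.array([\n",
"    [2.0, 0.1, -1.0, 0.3],  # predicts class 0\n",
"    [0.0, 0.2, 3.0, 0.1],   # predicts class 2\n",
"    [1.5, 1.0, 0.0, 0.2],   # predicts class 0\n",
"])\n",
"labels = np.array([0, 2, 1])\n",
"print(accuracy_from_logits(logits, labels))  # {'accuracy': 0.6666666666666666}\n",
"```"
]
},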
{
"cell_type": "code",
"execution_count": 113,
"id": "f64425b7-72b7-466a-8e3e-cd7624893139",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"training_args = TrainingArguments(\n",
"    output_dir=\"bert-paper-classifier-arxiv\",\n",
"    evaluation_strategy=\"epoch\",  # called eval_strategy in newer transformers releases\n",
"    per_device_train_batch_size=64,\n",
"    num_train_epochs=10,\n",
"    logging_steps=10,\n",
")"
]
},
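{
"cell_type": "markdown",
"id": "sketch-training-steps",
"metadata": {},
"source": [
"A back-of-the-envelope count of optimizer steps for these arguments, assuming a single device and no gradient accumulation (with several GPUs the effective batch size grows and the step count shrinks accordingly). The last, partial batch of each epoch still counts as a step, hence the ceiling:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"train_examples = 37_000  # size of train_ds\n",
"batch_size = 64          # per_device_train_batch_size, assuming 1 device\n",
"epochs = 10              # num_train_epochs\n",
"logging_steps = 10\n",
"\n",
"steps_per_epoch = math.ceil(train_examples / batch_size)\n",
"total_steps = steps_per_epoch * epochs\n",
"print(steps_per_epoch, total_steps, total_steps // logging_steps)  # 579 5790 579\n",
"```\n",
"\n",
"With `logging_steps=10`, that comes to roughly 579 logged training-loss values over the whole run."
]
},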
{
"cell_type": "code",
"execution_count": 114,
"id": "b850cd9b-eb36-40ec-8cf2-26206fedcf27",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_ds,\n",
"    eval_dataset=test_ds,\n",
"    compute_metrics=compute_metrics,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6b88166-d82e-4502-acef-494fbb206d30",
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ed8c94a-e3ef-47f9-96a8-c112eb7f11bc",
"metadata": {},
"outputs": [],
"source": [
"# Convert to a python file and run training:\n",
"#! jupyter nbconvert finetuning-arxiv.ipynb --to python"
]
},
{
"cell_type": "markdown",
"id": "cc8dad7d-8105-4f37-9087-615314c35afb",
"metadata": {},
"source": [
"# Save and share"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "38d24722-d5c6-40ac-b568-3cd7fd9f225e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"trainer.args.hub_model_id = \"bert-paper-classifier-arxiv\""
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "9530790c-bc63-48f4-9a01-8c534fa90e00",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"('bert-paper-classifier/tokenizer_config.json',\n",
" 'bert-paper-classifier/special_tokens_map.json',\n",
" 'bert-paper-classifier/vocab.txt',\n",
" 'bert-paper-classifier/added_tokens.json',\n",
" 'bert-paper-classifier/tokenizer.json')"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.save_pretrained(\"bert-paper-classifier-arxiv\")"
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "0498df97-cd2c-4732-9d07-ee2013f8bd55",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"trainer.save_model(\"bert-paper-classifier-arxiv\")"
]
},
{
"cell_type": "markdown",
"id": "7af12b9e-0d77-48ec-af6f-38556e13b067",
"metadata": {
"tags": []
},
"source": [
"Push the model to the HF Hub:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "5de0e91f-bc23-4413-b22e-5aa32b09ef12",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"To https://huggingface.co./oracat/bert-paper-classifier\n",
" 915ccf0..862abb7 main -> main\n",
"\n"
]
}
],
"source": [
"trainer.push_to_hub()"
]
},
{
"cell_type": "markdown",
"id": "b1a1029f-543c-409e-9aaf-35bcefe49988",
"metadata": {},
"source": [
"# Inference"
]
},
{
"cell_type": "markdown",
"id": "e7b0cd5a-2e17-49f3-b2a9-5ae4e8511969",
"metadata": {},
"source": [
"Now let's try loading the model from the HF Hub:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b7fe37b9-61a9-4796-af24-092f6722cd61",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "36afc9d465f54c80ab01698f5a687388",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)okenizer_config.json: 0%| | 0.00/394 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "df18b9d22fc14a0c81e8cb557f88a848",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)solve/main/vocab.txt: 0%| | 0.00/225k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4ba2236cf89d4159bcc9740d4654b16d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)/main/tokenizer.json: 0%| | 0.00/679k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cae249ea1c2946a89fffdb80ff1d7b7b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)cial_tokens_map.json: 0%| | 0.00/125 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b860284eb1ff4cb08b5c8d54ab1a33b9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)lve/main/config.json: 0%| | 0.00/6.04k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3607b2b6f85b49b0a03844df69077d7e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading pytorch_model.bin: 0%| | 0.00/438M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"inference_tokenizer = AutoTokenizer.from_pretrained(\"oracat/bert-paper-classifier-arxiv\")\n",
"inference_model = AutoModelForSequenceClassification.from_pretrained(\"oracat/bert-paper-classifier-arxiv\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "34495235-4dca-4635-b468-5b15647a6682",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pipe = pipeline(\"text-classification\", model=inference_model, tokenizer=inference_tokenizer, top_k=None)"
]
},
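{
"cell_type": "markdown",
"id": "sketch-softmax-scores",
"metadata": {},
"source": [
"With `top_k=None` the pipeline returns a score for every one of the 126 labels instead of only the best one. Assuming the default single-label setup, those scores are a softmax over the logits, paired with names from the model's `id2label`. A NumPy sketch of that post-processing on made-up logits for three labels; the helper `scores_from_logits` is hypothetical, and it sorts best first here (`top_pct` below re-sorts anyway):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def scores_from_logits(logits, id2label):\n",
"    # Numerically stable softmax: subtract the max before exponentiating\n",
"    shifted = logits - np.max(logits)\n",
"    probs = np.exp(shifted) / np.sum(np.exp(shifted))\n",
"    # One {label, score} dict per class, sorted best first\n",
"    preds = [{'label': id2label[i], 'score': float(p)} for i, p in enumerate(probs)]\n",
"    return sorted(preds, key=lambda x: -x['score'])\n",
"\n",
"id2label = {0: 'cs.AI', 1: 'cs.LG', 2: 'cs.NE'}  # hypothetical subset\n",
"preds = scores_from_logits(np.array([1.0, 3.0, 0.0]), id2label)\n",
"print([p['label'] for p in preds])  # ['cs.LG', 'cs.AI', 'cs.NE']\n",
"```"
]
},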
{
"cell_type": "code",
"execution_count": 4,
"id": "052b5070-c1ee-4419-8a6d-127925c95cce",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
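"def top_pct(preds, threshold=.95):\n",
"    \"\"\"\n",
"    Keep the highest-scoring predictions until their cumulative score reaches `threshold`\n",
"    \"\"\"\n",
"    preds = sorted(preds, key=lambda x: -x[\"score\"])\n",
"\n",
"    cum_score = 0\n",
"    for i, item in enumerate(preds):\n",
"        cum_score += item[\"score\"]\n",
"        if cum_score >= threshold:\n",
"            return preds[:(i+1)]\n",
"\n",
"    # Threshold never reached (or no predictions at all): keep everything\n",
"    return preds"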
]
},
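{
"cell_type": "markdown",
"id": "sketch-top-pct",
"metadata": {},
"source": [
"`top_pct` keeps the smallest set of highest-scoring labels whose scores add up to at least `threshold` (the same cutoff idea as top-p sampling). A vectorized NumPy sketch of that rule with made-up scores; the name `top_pct_np` is hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def top_pct_np(labels, scores, threshold=0.95):\n",
"    scores = np.asarray(scores)\n",
"    # Indices sorted by descending score\n",
"    order = np.argsort(-scores)\n",
"    cum = np.cumsum(scores[order])\n",
"    # First position where the running total reaches the threshold, kept inclusively\n",
"    cutoff = int(np.searchsorted(cum, threshold)) + 1\n",
"    keep = order[:cutoff]\n",
"    return [(labels[i], float(scores[i])) for i in keep]\n",
"\n",
"labels = ['cs.AI', 'cs.LG', 'cs.NE', 'stat.ML']\n",
"scores = [0.10, 0.80, 0.06, 0.04]\n",
"print([label for label, _ in top_pct_np(labels, scores)])  # ['cs.LG', 'cs.AI', 'cs.NE']\n",
"```\n",
"\n",
"If the threshold is never reached, `searchsorted` returns the array length and every label is kept, matching the loop-based version."
]
},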
{
"cell_type": "code",
"execution_count": 5,
"id": "ed3545b6-e043-4dfb-aeb2-7559eac37f7c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def format_predictions(preds) -> str:\n",
"    \"\"\"\n",
"    Prepare predictions and their scores for printing to the user\n",
"    \"\"\"\n",
"    out = \"\"\n",
"    for i, item in enumerate(preds):\n",
"        out += f\"{i+1}. {item['label']} (score {item['score']:.2f})\\n\"\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "870d593a-a298-4d55-87b0-cb2813cc1fad",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1. cs.LG (score 0.88)\n",
"2. cs.AI (score 0.07)\n",
"3. cs.NE (score 0.03)\n",
"\n"
]
}
],
"source": [
"print(\n",
"    format_predictions(\n",
"        top_pct(\n",
"            pipe(\"Attention Is All You Need\\nThe dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.\")[0]\n",
"        )\n",
"    )\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}