{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "429b26f3-8c61-46cc-b5fc-284add4d018f", "metadata": {}, "outputs": [], "source": [ "import json\n", "from tqdm.auto import tqdm\n", "from datasets import load_dataset\n", "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import os\n", "\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "2a927511-78a0-42d5-861d-9e7af50ff000", "metadata": {}, "outputs": [], "source": [ "import requests\n", "from bs4 import BeautifulSoup\n", "\n", "page = requests.get('https://arxiv.org/category_taxonomy')\n", "soup = BeautifulSoup(page.content)\n", "tag_to_name = {}\n", "for tag_html in soup.find_all('h4')[1:]:\n", " tag, name = tag_html.text.split(maxsplit=1)\n", " tag_to_name[tag] = name[1:-1]\n", "with open('tag_to_name.json', 'w') as fout:\n", " json.dump(tag_to_name, fout)" ] }, { "cell_type": "code", "execution_count": 3, "id": "19b75e52-15c0-472e-b737-72c5eea896ec", "metadata": {}, "outputs": [], "source": [ "tag_to_label = dict(zip(tag_to_name, range(len(tag_to_name))))" ] }, { "cell_type": "code", "execution_count": 4, "id": "fec2865f-2992-4b3e-9202-8e9b8c5a7da1", "metadata": {}, "outputs": [], "source": [ "def add_labels(row):\n", " tag_list = eval(row['tag'])\n", " label_ids, label_tags = [], []\n", " for tag_dict in tag_list:\n", " if tag_dict['term'] in tag_to_label:\n", " label_tags.append(tag_dict['term'])\n", " label_ids.append(tag_to_label[tag_dict['term']])\n", " return {'label_ids': label_ids, 'label_tags': label_tags}" ] }, { "cell_type": "code", "execution_count": 5, "id": "81dff335-093f-4a59-93b5-27d7c57aac9a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default-60d1f0f90275ae1e\n", "Found cached dataset json (/root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-66945521f8e38136.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-5298549794823409.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-6c93a706327f5678.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-ff58b61d0d461ac4.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-259b966b550351dc.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8f0ed2baf297a3db.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-845944d2885d6a34.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8ec43ba6cf3d3eba.arrow\n" ] } ], "source": [ "dataset = load_dataset(\"json\", data_files=\"arxivData.json\", split=\"train\")\n", "dataset = dataset.map(add_labels, num_proc=8)\n", "dataset = dataset.remove_columns(['author', 'day', 'id', 'link', 'month', 'tag', 'year'])" ] }, { "cell_type": "code", "execution_count": 6, "id": "c9a6ab6a-6a47-4377-a9d9-044c3a395ef3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
summarytitlelabel_idslabel_tags
0We propose an architecture for VQA which utili...Dual Recurrent Attention Units for Visual Ques...[0, 5, 7, 28, 152][cs.AI, cs.CL, cs.CV, cs.NE, stat.ML]
1In a physical neural system, where storage and...A Theory of Local Learning, the Learning Chann...[22, 28, 152][cs.LG, cs.NE, stat.ML]
2One way to approach end-to-end autonomous driv...Query-Efficient Imitation Learning for End-to-...[22, 0, 34][cs.LG, cs.AI, cs.RO]
\n", "
" ], "text/plain": [ " summary \\\n", "0 We propose an architecture for VQA which utili... \n", "1 In a physical neural system, where storage and... \n", "2 One way to approach end-to-end autonomous driv... \n", "\n", " title label_ids \\\n", "0 Dual Recurrent Attention Units for Visual Ques... [0, 5, 7, 28, 152] \n", "1 A Theory of Local Learning, the Learning Chann... [22, 28, 152] \n", "2 Query-Efficient Imitation Learning for End-to-... [22, 0, 34] \n", "\n", " label_tags \n", "0 [cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] \n", "1 [cs.LG, cs.NE, stat.ML] \n", "2 [cs.LG, cs.AI, cs.RO] " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(dataset.select([0, 1000, 10000]))" ] }, { "cell_type": "code", "execution_count": 7, "id": "c193d04b-5def-443f-b723-1e3cf9df4d9e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached split indices for dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-7ce5346705e1f437.arrow and /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-981e0a6e9da25ee7.arrow\n", "Loading cached split indices for dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-1ab388509804381c.arrow and /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-eac731b57f161563.arrow\n" ] }, { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['summary', 'title', 'label_ids', 'label_tags'],\n", " num_rows: 38952\n", " })\n", " val: Dataset({\n", " features: ['summary', 'title', 'label_ids', 'label_tags'],\n", " num_rows: 1024\n", " })\n", " test: Dataset({\n", " features: ['summary', 'title', 'label_ids', 'label_tags'],\n", " num_rows: 1024\n", " })\n", "})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import DatasetDict\n", "\n", "dataset = dataset.train_test_split(test_size=2048, seed=0)\n", "dataset_val = dataset['test'].train_test_split(test_size=1024, seed=0)\n", "\n", "dataset = DatasetDict({\n", " 'train': dataset['train'],\n", " 'val': dataset_val['train'],\n", " 'test': dataset_val['test'],\n", "})\n", "\n", "dataset" ] }, { "cell_type": "code", "execution_count": 17, "id": "2544c24b-d2ed-4fba-bb86-75469053db8c", "metadata": {}, "outputs": [], "source": [ "def get_collator(tokenizer, abstract_proba=0.5, num_labels=len(tag_to_label)):\n", " def collate_fn(rows):\n", " texts = []\n", " take_abstracts = np.random.rand(len(rows)) < abstract_proba\n", " for row, take_abstract in zip(rows, take_abstracts):\n", " if take_abstract:\n", " texts.append(row['title'] + '[SEP]' + row['summary'])\n", " else:\n", " texts.append(row['title'])\n", " processed = tokenizer(texts, truncation=True, return_tensors='pt', padding=True, max_length=512)\n", " labels = torch.zeros(size=(len(rows), num_labels), dtype=torch.float)\n", " for i, row in enumerate(rows):\n", " labels[i, row['label_ids']] = 1 / len(row['label_ids'])\n", " processed['labels'] = labels\n", " return processed\n", " return collate_fn" ] }, { "cell_type": "code", "execution_count": 9, "id": "33934717-57ca-49e8-8354-3eafe503bcf0", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "/usr/local/lib/python3.8/dist-packages/transformers/convert_slow_tokenizer.py:446: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n", " warnings.warn(\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Some weights of the model checkpoint at microsoft/deberta-v3-base were not used when initializing DebertaV2ForSequenceClassification: ['lm_predictions.lm_head.LayerNorm.bias', 'lm_predictions.lm_head.bias', 'mask_predictions.classifier.weight', 'mask_predictions.LayerNorm.bias', 'mask_predictions.classifier.bias', 'mask_predictions.dense.weight', 'lm_predictions.lm_head.dense.bias', 'mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.weight', 'mask_predictions.dense.bias', 'lm_predictions.lm_head.dense.weight']\n", "- This IS expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.weight', 'pooler.dense.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", " 'microsoft/deberta-v3-base',\n", " problem_type=None, # https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L1349\n", " num_labels=len(tag_to_label), id2label={v: k for k, v in tag_to_label.items()}, label2id=tag_to_label)\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "588d769e-d44b-4367-a22b-7b9b87cb5319", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "max_steps is given, it will override any value given in num_train_epochs\n" ] } ], "source": [ "from transformers import TrainingArguments, Trainer\n", "\n", "training_args = TrainingArguments(\n", " output_dir='checkpoints',\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=24,\n", " per_device_eval_batch_size=24,\n", " weight_decay=0.01,\n", " warmup_ratio=0.02,\n", " logging_steps=100,\n", " overwrite_output_dir=True,\n", " seed=0,\n", " dataloader_num_workers=8,\n", " do_train=True,\n", " do_eval=True,\n", " max_steps=5000,\n", " save_strategy=\"steps\",\n", " evaluation_strategy=\"steps\",\n", " eval_steps=100,\n", " save_steps=100,\n", " save_total_limit=2,\n", " lr_scheduler_type=\"linear\",\n", " load_best_model_at_end=True,\n", " report_to=\"tensorboard\",\n", " remove_unused_columns=False,\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=dataset['train'],\n", " eval_dataset=dataset['val'],\n", " tokenizer=tokenizer,\n", " data_collator=get_collator(tokenizer),\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "id": "04d3ccf6-193b-4ee2-ad4d-2cdbb3c7b737", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 38952\n", " Num Epochs = 4\n", " Instantaneous batch size per device = 24\n", " Total train batch size (w. parallel, distributed & accumulation) = 24\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 5000\n", " Number of trainable parameters = 184541339\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [5000/5000 22:52, Epoch 3/4]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation Loss
1004.2861002.809958
2002.3657002.110714
3002.0236002.046348
4002.0204001.982979
5001.9273001.915667
6001.9195001.927610
7001.8346001.929402
8001.8408001.861055
9001.8239001.819358
10001.7571001.798097
11001.7465001.779167
12001.7750001.774340
13001.6985001.764457
14001.6842001.741629
15001.7630001.680664
16001.6784001.712918
17001.6698001.710484
18001.6650001.698851
19001.6452001.663767
20001.6676001.674545
21001.6023001.680639
22001.6518001.667343
23001.6226001.659117
24001.6169001.645381
25001.6009001.642603
26001.5902001.657698
27001.6463001.644075
28001.6026001.626339
29001.5968001.646950
30001.5472001.622913
31001.5635001.611651
32001.5835001.608005
33001.5658001.626086
34001.5310001.626902
35001.5661001.607745
36001.5551001.594658
37001.5976001.597994
38001.4976001.590335
39001.5223001.588875
40001.5066001.572686
41001.4979001.602122
42001.5341001.576102
43001.5174001.578320
44001.5185001.588920
45001.5102001.596100
46001.4411001.576099
47001.5110001.575001
48001.4877001.579319
49001.4913001.591276
50001.4747001.572709

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-100\n", "Configuration saved in checkpoints/checkpoint-100/config.json\n", "Model weights saved in checkpoints/checkpoint-100/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-100/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-100/special_tokens_map.json\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-200\n", "Configuration saved in checkpoints/checkpoint-200/config.json\n", "Model weights saved in checkpoints/checkpoint-200/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-200/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-200/special_tokens_map.json\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-300\n", "Configuration saved in checkpoints/checkpoint-300/config.json\n", "Model weights saved in checkpoints/checkpoint-300/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-300/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-300/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-100] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-400\n", "Configuration saved in checkpoints/checkpoint-400/config.json\n", "Model weights saved in checkpoints/checkpoint-400/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-400/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-400/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-200] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-500\n", "Configuration saved in checkpoints/checkpoint-500/config.json\n", "Model weights saved in checkpoints/checkpoint-500/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-500/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-500/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-300] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-600\n", "Configuration saved in checkpoints/checkpoint-600/config.json\n", "Model weights saved in checkpoints/checkpoint-600/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-600/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-600/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-400] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-700\n", "Configuration saved in checkpoints/checkpoint-700/config.json\n", "Model weights saved in checkpoints/checkpoint-700/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-700/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-700/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-600] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-800\n", "Configuration saved in checkpoints/checkpoint-800/config.json\n", "Model weights saved in checkpoints/checkpoint-800/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-800/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-800/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-500] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-900\n", "Configuration saved in checkpoints/checkpoint-900/config.json\n", "Model weights saved in checkpoints/checkpoint-900/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-900/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-900/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-700] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1000\n", "Configuration saved in checkpoints/checkpoint-1000/config.json\n", "Model weights saved in checkpoints/checkpoint-1000/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1000/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1000/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-800] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1100\n", "Configuration saved in checkpoints/checkpoint-1100/config.json\n", "Model weights saved in checkpoints/checkpoint-1100/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1100/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1100/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-900] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1200\n", "Configuration saved in checkpoints/checkpoint-1200/config.json\n", "Model weights saved in checkpoints/checkpoint-1200/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1200/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1200/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1000] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1300\n", "Configuration saved in checkpoints/checkpoint-1300/config.json\n", "Model weights saved in checkpoints/checkpoint-1300/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1300/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1300/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1100] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1400\n", "Configuration saved in checkpoints/checkpoint-1400/config.json\n", "Model weights saved in checkpoints/checkpoint-1400/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1400/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1400/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1200] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1500\n", "Configuration saved in checkpoints/checkpoint-1500/config.json\n", "Model weights saved in checkpoints/checkpoint-1500/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1500/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1500/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1300] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1600\n", "Configuration saved in checkpoints/checkpoint-1600/config.json\n", "Model weights saved in checkpoints/checkpoint-1600/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1600/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1600/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1400] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1700\n", "Configuration saved in checkpoints/checkpoint-1700/config.json\n", "Model weights saved in checkpoints/checkpoint-1700/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1700/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1700/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1600] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-1800\n", "Configuration saved in checkpoints/checkpoint-1800/config.json\n", "Model weights saved in checkpoints/checkpoint-1800/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-1800/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-1800/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1700] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2200\n", "Configuration saved in checkpoints/checkpoint-2200/config.json\n", "Model weights saved in checkpoints/checkpoint-2200/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2200/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2200/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2100] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2300\n", "Configuration saved in checkpoints/checkpoint-2300/config.json\n", "Model weights saved in checkpoints/checkpoint-2300/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2300/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2300/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-1900] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2400\n", "Configuration saved in checkpoints/checkpoint-2400/config.json\n", "Model weights saved in checkpoints/checkpoint-2400/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2400/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2400/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2200] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2500\n", "Configuration saved in checkpoints/checkpoint-2500/config.json\n", "Model weights saved in checkpoints/checkpoint-2500/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2500/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2500/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2300] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2600\n", "Configuration saved in checkpoints/checkpoint-2600/config.json\n", "Model weights saved in checkpoints/checkpoint-2600/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2600/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2600/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2400] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2700\n", "Configuration saved in checkpoints/checkpoint-2700/config.json\n", "Model weights saved in checkpoints/checkpoint-2700/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2700/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2700/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2600] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2800\n", "Configuration saved in checkpoints/checkpoint-2800/config.json\n", "Model weights saved in checkpoints/checkpoint-2800/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2800/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2800/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2500] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-2900\n", "Configuration saved in checkpoints/checkpoint-2900/config.json\n", "Model weights saved in checkpoints/checkpoint-2900/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-2900/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-2900/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2700] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3000\n", "Configuration saved in checkpoints/checkpoint-3000/config.json\n", "Model weights saved in checkpoints/checkpoint-3000/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3000/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3000/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2800] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3100\n", "Configuration saved in checkpoints/checkpoint-3100/config.json\n", "Model weights saved in checkpoints/checkpoint-3100/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3100/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3100/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-2900] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3200\n", "Configuration saved in checkpoints/checkpoint-3200/config.json\n", "Model weights saved in checkpoints/checkpoint-3200/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3200/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3200/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3000] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3300\n", "Configuration saved in checkpoints/checkpoint-3300/config.json\n", "Model weights saved in checkpoints/checkpoint-3300/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3300/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3300/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3100] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3400\n", "Configuration saved in checkpoints/checkpoint-3400/config.json\n", "Model weights saved in checkpoints/checkpoint-3400/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3400/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3400/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3300] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3500\n", "Configuration saved in checkpoints/checkpoint-3500/config.json\n", "Model weights saved in checkpoints/checkpoint-3500/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3500/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3500/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3200] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3600\n", "Configuration saved in checkpoints/checkpoint-3600/config.json\n", "Model weights saved in checkpoints/checkpoint-3600/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3600/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3600/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3400] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3700\n", "Configuration saved in checkpoints/checkpoint-3700/config.json\n", "Model weights saved in checkpoints/checkpoint-3700/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3700/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3700/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3500] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3800\n", "Configuration saved in checkpoints/checkpoint-3800/config.json\n", "Model weights saved in checkpoints/checkpoint-3800/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3800/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3800/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3600] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-3900\n", "Configuration saved in checkpoints/checkpoint-3900/config.json\n", "Model weights saved in checkpoints/checkpoint-3900/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-3900/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-3900/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3700] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4000\n", "Configuration saved in checkpoints/checkpoint-4000/config.json\n", "Model weights saved in checkpoints/checkpoint-4000/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4000/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4000/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3800] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4100\n", "Configuration saved in checkpoints/checkpoint-4100/config.json\n", "Model weights saved in checkpoints/checkpoint-4100/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4100/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4100/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-3900] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4200\n", "Configuration saved in checkpoints/checkpoint-4200/config.json\n", "Model weights saved in checkpoints/checkpoint-4200/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4200/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4200/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4100] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4300\n", "Configuration saved in checkpoints/checkpoint-4300/config.json\n", "Model weights saved in checkpoints/checkpoint-4300/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4300/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4300/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4200] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4400\n", "Configuration saved in checkpoints/checkpoint-4400/config.json\n", "Model weights saved in checkpoints/checkpoint-4400/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4400/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4400/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4300] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4500\n", "Configuration saved in checkpoints/checkpoint-4500/config.json\n", "Model weights saved in checkpoints/checkpoint-4500/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4500/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4500/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4400] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4600\n", "Configuration saved in checkpoints/checkpoint-4600/config.json\n", "Model weights saved in checkpoints/checkpoint-4600/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4600/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4600/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4500] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4700\n", "Configuration saved in checkpoints/checkpoint-4700/config.json\n", "Model weights saved in checkpoints/checkpoint-4700/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4700/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4700/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4600] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4800\n", "Configuration saved in checkpoints/checkpoint-4800/config.json\n", "Model weights saved in checkpoints/checkpoint-4800/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4800/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4800/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4700] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-4900\n", "Configuration saved in checkpoints/checkpoint-4900/config.json\n", "Model weights saved in checkpoints/checkpoint-4900/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-4900/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-4900/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4800] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 1024\n", " Batch size = 24\n", "Saving model checkpoint to checkpoints/checkpoint-5000\n", "Configuration saved in checkpoints/checkpoint-5000/config.json\n", "Model weights saved in checkpoints/checkpoint-5000/pytorch_model.bin\n", "tokenizer config file saved in checkpoints/checkpoint-5000/tokenizer_config.json\n", "Special tokens file saved in checkpoints/checkpoint-5000/special_tokens_map.json\n", "Deleting older checkpoint [checkpoints/checkpoint-4900] due to args.save_total_limit\n", "\n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n", "Loading best model from checkpoints/checkpoint-4000 (score: 1.5726864337921143).\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=5000, training_loss=1.7068539916992187, metrics={'train_runtime': 1373.8884, 'train_samples_per_second': 87.343, 'train_steps_per_second': 3.639, 'total_flos': 1.9803672136145664e+16, 'train_loss': 1.7068539916992187, 'epoch': 3.08})" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": 12, "id": "86cbf6fb-8e38-4a54-bf1a-987e308ef97d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "checkpoint-4000 checkpoint-5000 runs\n" ] } ], "source": [ "!ls checkpoints/" ] }, { "cell_type": "code", "execution_count": 94, "id": "4a0a8e00-6b91-4c8e-af5c-8ab66c3e5648", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "def calc_metrics(model, dataset, abstract_proba):\n", " dataloader = DataLoader(\n", " dataset, batch_size=16, shuffle=False,\n", " collate_fn=get_collator(tokenizer, abstract_proba=abstract_proba)\n", " )\n", " precisions, recalls, top1_accs = [], [], []\n", " with torch.no_grad():\n", " for batch in tqdm(dataloader):\n", " outputs = model(**batch.to('cuda'))\n", " for labels, preds in zip(batch['labels'], outputs.logits.softmax(-1)):\n", " top_probs, top_inds = preds.sort(descending=True)\n", " mask = top_probs.cumsum(0) <= 0.95\n", " mask[0] = True\n", " a = set(top_inds[mask].tolist())\n", " y = set(labels.nonzero().flatten().tolist())\n", " top1_accs.append(int(top_inds[0]) in y)\n", " recalls.append(len(y & a) / len(y))\n", " precisions.append(len(y & a) / len(a))\n", " return {'Recall@0.95': np.mean(recalls),\n", " 'Precision@0.95': np.mean(precisions),\n", " 'Top-1 Accuracy': np.mean(top1_accs)}" ] }, { "cell_type": "code", "execution_count": 97, "id": "a374a862-3d6c-4d58-8871-a1195dd75e1c", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "tokenizer = AutoTokenizer.from_pretrained('checkpoints/checkpoint-4000/')\n", "model = AutoModelForSequenceClassification.from_pretrained('checkpoints/checkpoint-4000/')\n", "model.to('cuda')\n", "model.eval();" ] }, { "cell_type": "code", "execution_count": 98, "id": "1320d0fd-5302-4677-b9ae-4c9d5652db73", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7da9cb0fa4154813a7db717c9556e4f4", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/64 [00:00