{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "429b26f3-8c61-46cc-b5fc-284add4d018f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "import os\n",
    "\n",
    "# CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so pin the GPU before torch is imported.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import requests\n",
    "import torch\n",
    "from bs4 import BeautifulSoup\n",
    "from datasets import load_dataset\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2a927511-78a0-42d5-861d-9e7af50ff000",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build {category tag: human-readable name} from arXiv's taxonomy page and cache it to disk.\n",
    "page = requests.get('https://arxiv.org/category_taxonomy', timeout=30)\n",
    "page.raise_for_status()  # fail loudly instead of parsing an error page into an empty mapping\n",
    "soup = BeautifulSoup(page.content, 'html.parser')  # explicit parser: same result in every environment\n",
    "tag_to_name = {}\n",
    "# The first <h4> is skipped; the rest presumably read like \"cs.AI (Artificial Intelligence)\" -- verify against the live page.\n",
    "for tag_html in soup.find_all('h4')[1:]:\n",
    "    tag, name = tag_html.text.split(maxsplit=1)\n",
    "    tag_to_name[tag] = name[1:-1]  # drop the surrounding parentheses\n",
    "assert tag_to_name, 'no categories parsed from the arXiv taxonomy page'\n",
    "with open('tag_to_name.json', 'w') as fout:\n",
    "    json.dump(tag_to_name, fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "19b75e52-15c0-472e-b737-72c5eea896ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label ids follow the order of tag_to_name (dicts preserve insertion order).\n",
    "tag_to_label = {tag: label for label, tag in enumerate(tag_to_name)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fec2865f-2992-4b3e-9202-8e9b8c5a7da1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_labels(row):\n",
    "    \"\"\"Attach taxonomy labels to one dataset row (used with `Dataset.map`).\n",
    "\n",
    "    `row['tag']` is a string encoding a list of dicts with a 'term' key; it is assumed to be a\n",
    "    Python-literal repr such as \"[{'term': 'cs.AI', ...}]\" -- verify against arxivData.json.\n",
    "    Terms not present in `tag_to_label` are dropped; the rest keep their original order.\n",
    "\n",
    "    Returns {'label_ids': [...], 'label_tags': [...]} as parallel lists.\n",
    "    \"\"\"\n",
    "    # literal_eval accepts only literals, so a malformed or malicious field cannot execute code (unlike eval).\n",
    "    tag_list = ast.literal_eval(row['tag'])\n",
    "    label_ids, label_tags = [], []\n",
    "    for tag_dict in tag_list:\n",
    "        term = tag_dict['term']\n",
    "        if term in tag_to_label:\n",
    "            label_tags.append(term)\n",
    "            label_ids.append(tag_to_label[term])\n",
    "    return {'label_ids': label_ids, 'label_tags': label_tags}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "81dff335-093f-4a59-93b5-27d7c57aac9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-60d1f0f90275ae1e\n",
      "Found cached dataset json (/root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
     ]
    },
    {
"name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-66945521f8e38136.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-5298549794823409.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-6c93a706327f5678.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-ff58b61d0d461ac4.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-259b966b550351dc.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8f0ed2baf297a3db.arrow\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-845944d2885d6a34.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8ec43ba6cf3d3eba.arrow\n" ] } ], "source": [ "# Load the raw arXiv metadata, derive label_ids/label_tags via add_labels, then drop the remaining metadata columns.\n", "dataset = (\n", "    load_dataset(\"json\", data_files=\"arxivData.json\", split=\"train\")\n", "    .map(add_labels, num_proc=8)\n", "    .remove_columns(['author', 'day', 'id', 'link', 'month', 'tag', 'year'])\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "c9a6ab6a-6a47-4377-a9d9-044c3a395ef3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | summary | \n", "title | \n", "label_ids | \n", "label_tags | \n", "
---|---|---|---|---|
0 | \n", "We propose an architecture for VQA which utili... | \n", "Dual Recurrent Attention Units for Visual Ques... | \n", "[0, 5, 7, 28, 152] | \n", "[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] | \n", "
1 | \n", "In a physical neural system, where storage and... | \n", "A Theory of Local Learning, the Learning Chann... | \n", "[22, 28, 152] | \n", "[cs.LG, cs.NE, stat.ML] | \n", "
2 | \n", "One way to approach end-to-end autonomous driv... | \n", "Query-Efficient Imitation Learning for End-to-... | \n", "[22, 0, 34] | \n", "[cs.LG, cs.AI, cs.RO] | \n", "
Step | \n", "Training Loss | \n", "Validation Loss | \n", "
---|---|---|
100 | \n", "4.286100 | \n", "2.809958 | \n", "
200 | \n", "2.365700 | \n", "2.110714 | \n", "
300 | \n", "2.023600 | \n", "2.046348 | \n", "
400 | \n", "2.020400 | \n", "1.982979 | \n", "
500 | \n", "1.927300 | \n", "1.915667 | \n", "
600 | \n", "1.919500 | \n", "1.927610 | \n", "
700 | \n", "1.834600 | \n", "1.929402 | \n", "
800 | \n", "1.840800 | \n", "1.861055 | \n", "
900 | \n", "1.823900 | \n", "1.819358 | \n", "
1000 | \n", "1.757100 | \n", "1.798097 | \n", "
1100 | \n", "1.746500 | \n", "1.779167 | \n", "
1200 | \n", "1.775000 | \n", "1.774340 | \n", "
1300 | \n", "1.698500 | \n", "1.764457 | \n", "
1400 | \n", "1.684200 | \n", "1.741629 | \n", "
1500 | \n", "1.763000 | \n", "1.680664 | \n", "
1600 | \n", "1.678400 | \n", "1.712918 | \n", "
1700 | \n", "1.669800 | \n", "1.710484 | \n", "
1800 | \n", "1.665000 | \n", "1.698851 | \n", "
1900 | \n", "1.645200 | \n", "1.663767 | \n", "
2000 | \n", "1.667600 | \n", "1.674545 | \n", "
2100 | \n", "1.602300 | \n", "1.680639 | \n", "
2200 | \n", "1.651800 | \n", "1.667343 | \n", "
2300 | \n", "1.622600 | \n", "1.659117 | \n", "
2400 | \n", "1.616900 | \n", "1.645381 | \n", "
2500 | \n", "1.600900 | \n", "1.642603 | \n", "
2600 | \n", "1.590200 | \n", "1.657698 | \n", "
2700 | \n", "1.646300 | \n", "1.644075 | \n", "
2800 | \n", "1.602600 | \n", "1.626339 | \n", "
2900 | \n", "1.596800 | \n", "1.646950 | \n", "
3000 | \n", "1.547200 | \n", "1.622913 | \n", "
3100 | \n", "1.563500 | \n", "1.611651 | \n", "
3200 | \n", "1.583500 | \n", "1.608005 | \n", "
3300 | \n", "1.565800 | \n", "1.626086 | \n", "
3400 | \n", "1.531000 | \n", "1.626902 | \n", "
3500 | \n", "1.566100 | \n", "1.607745 | \n", "
3600 | \n", "1.555100 | \n", "1.594658 | \n", "
3700 | \n", "1.597600 | \n", "1.597994 | \n", "
3800 | \n", "1.497600 | \n", "1.590335 | \n", "
3900 | \n", "1.522300 | \n", "1.588875 | \n", "
4000 | \n", "1.506600 | \n", "1.572686 | \n", "
4100 | \n", "1.497900 | \n", "1.602122 | \n", "
4200 | \n", "1.534100 | \n", "1.576102 | \n", "
4300 | \n", "1.517400 | \n", "1.578320 | \n", "
4400 | \n", "1.518500 | \n", "1.588920 | \n", "
4500 | \n", "1.510200 | \n", "1.596100 | \n", "
4600 | \n", "1.441100 | \n", "1.576099 | \n", "
4700 | \n", "1.511000 | \n", "1.575001 | \n", "
4800 | \n", "1.487700 | \n", "1.579319 | \n", "
4900 | \n", "1.491300 | \n", "1.591276 | \n", "
5000 | \n", "1.474700 | \n", "1.572709 | \n", "
"
],
"text/plain": [
"