diff --git "a/train.ipynb" "b/train.ipynb" new file mode 100644--- /dev/null +++ "b/train.ipynb" @@ -0,0 +1,2447 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Install" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install uv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!uv pip install dagshub setuptools accelerate toml torch torchvision transformers mlflow datasets ipywidgets python-dotenv evaluate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<pre>Initialized MLflow to track repo \"amaye15/CanineNet\"\n",
+       "</pre>
\n" + ], + "text/plain": [ + "Initialized MLflow to track repo \u001b[32m\"amaye15/CanineNet\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
<pre>Repository amaye15/CanineNet initialized!\n",
+       "</pre>
\n" + ], + "text/plain": [ + "Repository amaye15/CanineNet initialized!\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import os\n", + "import toml\n", + "import torch\n", + "import mlflow\n", + "import dagshub\n", + "import datasets\n", + "import evaluate\n", + "from dotenv import load_dotenv\n", + "from torchvision.transforms import v2\n", + "from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer\n", + "\n", + "ENV_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/.env\"\n", + "CONFIG_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/code/config.toml\"\n", + "CONFIG = toml.load(CONFIG_PATH)\n", + "\n", + "load_dotenv(ENV_PATH)\n", + "\n", + "dagshub.init(repo_name=os.environ['MLFLOW_TRACKING_PROJECTNAME'], repo_owner=os.environ['MLFLOW_TRACKING_USERNAME'], mlflow=True, dvc=True)\n", + "\n", + "os.environ['MLFLOW_TRACKING_USERNAME'] = \"amaye15\"\n", + "\n", + "mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME']\n", + " + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')\n", + "\n", + "CREATE_DATASET = True\n", + "ORIGINAL_DATASET = \"Alanox/stanford-dogs\"\n", + "MODIFIED_DATASET = \"amaye15/stanford-dogs\"\n", + "REMOVE_COLUMNS = [\"name\", \"annotations\"]\n", + "RENAME_COLUMNS = {\"image\":\"pixel_values\", \"target\":\"label\"}\n", + "SPLIT = 0.2\n", + "\n", + "METRICS = [\"accuracy\", \"f1\", \"precision\", \"recall\"]\n", + "# MODELS = 'google/vit-base-patch16-224'\n", + "# MODELS = \"google/siglip-base-patch16-224\"\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Affenpinscher: 0\n", + "Afghan Hound: 1\n", + "African Hunting Dog: 2\n", + "Airedale: 3\n", + "American Staffordshire Terrier: 4\n", + "Appenzeller: 5\n", + "Australian Terrier: 6\n", + "Basenji: 7\n", + "Basset: 8\n", + "Beagle: 9\n", + "Bedlington Terrier: 10\n", + "Bernese Mountain Dog: 11\n", + "Black And Tan Coonhound: 12\n", + "Blenheim Spaniel: 13\n", + "Bloodhound: 14\n", + "Bluetick: 15\n", + "Border Collie: 16\n", + "Border Terrier: 17\n", + "Borzoi: 18\n", + "Boston Bull: 19\n", + "Bouvier Des Flandres: 20\n", + "Boxer: 21\n", + "Brabancon Griffon: 22\n", + "Briard: 23\n", + "Brittany Spaniel: 24\n", + "Bull Mastiff: 25\n", + "Cairn: 26\n", + "Cardigan: 27\n", + "Chesapeake Bay Retriever: 28\n", + "Chihuahua: 29\n", + "Chow: 30\n", + "Clumber: 31\n", + "Cocker Spaniel: 32\n", + "Collie: 33\n", + "Curly Coated Retriever: 34\n", + "Dandie Dinmont: 35\n", + "Dhole: 36\n", + "Dingo: 37\n", + "Doberman: 38\n", + "English Foxhound: 39\n", + "English Setter: 40\n", + "English Springer: 41\n", + "Entlebucher: 42\n", + "Eskimo Dog: 43\n", + "Flat Coated Retriever: 44\n", + "French Bulldog: 45\n", + "German Shepherd: 46\n", + "German Short Haired Pointer: 47\n", + "Giant Schnauzer: 48\n", + "Golden Retriever: 49\n", + "Gordon Setter: 50\n", + "Great Dane: 51\n", + "Great Pyrenees: 52\n", + "Greater Swiss Mountain Dog: 53\n", + "Groenendael: 54\n", + "Ibizan Hound: 55\n", + "Irish Setter: 56\n", + "Irish Terrier: 57\n", + "Irish Water Spaniel: 58\n", + "Irish Wolfhound: 59\n", + "Italian Greyhound: 60\n", + "Japanese Spaniel: 61\n", + "Keeshond: 62\n", + "Kelpie: 63\n", + "Kerry Blue Terrier: 64\n", + "Komondor: 65\n", + "Kuvasz: 66\n", + "Labrador Retriever: 67\n", + 
"Lakeland Terrier: 68\n", + "Leonberg: 69\n", + "Lhasa: 70\n", + "Malamute: 71\n", + "Malinois: 72\n", + "Maltese Dog: 73\n", + "Mexican Hairless: 74\n", + "Miniature Pinscher: 75\n", + "Miniature Poodle: 76\n", + "Miniature Schnauzer: 77\n", + "Newfoundland: 78\n", + "Norfolk Terrier: 79\n", + "Norwegian Elkhound: 80\n", + "Norwich Terrier: 81\n", + "Old English Sheepdog: 82\n", + "Otterhound: 83\n", + "Papillon: 84\n", + "Pekinese: 85\n", + "Pembroke: 86\n", + "Pomeranian: 87\n", + "Pug: 88\n", + "Redbone: 89\n", + "Rhodesian Ridgeback: 90\n", + "Rottweiler: 91\n", + "Saint Bernard: 92\n", + "Saluki: 93\n", + "Samoyed: 94\n", + "Schipperke: 95\n", + "Scotch Terrier: 96\n", + "Scottish Deerhound: 97\n", + "Sealyham Terrier: 98\n", + "Shetland Sheepdog: 99\n", + "Shih Tzu: 100\n", + "Siberian Husky: 101\n", + "Silky Terrier: 102\n", + "Soft Coated Wheaten Terrier: 103\n", + "Staffordshire Bullterrier: 104\n", + "Standard Poodle: 105\n", + "Standard Schnauzer: 106\n", + "Sussex Spaniel: 107\n", + "Tibetan Mastiff: 108\n", + "Tibetan Terrier: 109\n", + "Toy Poodle: 110\n", + "Toy Terrier: 111\n", + "Vizsla: 112\n", + "Walker Hound: 113\n", + "Weimaraner: 114\n", + "Welsh Springer Spaniel: 115\n", + "West Highland White Terrier: 116\n", + "Whippet: 117\n", + "Wire Haired Fox Terrier: 118\n", + "Yorkshire Terrier: 119\n" + ] + } + ], + "source": [ + "if CREATE_DATASET:\n", + " ds = datasets.load_dataset(ORIGINAL_DATASET, token=os.getenv(\"HF_TOKEN\"), split=\"full\", trust_remote_code=True)\n", + " ds = ds.remove_columns(REMOVE_COLUMNS).rename_columns(RENAME_COLUMNS)\n", + "\n", + " labels = ds.select_columns(\"label\").to_pandas().sort_values(\"label\").get(\"label\").unique().tolist()\n", + " numbers = range(len(labels))\n", + " label2int = dict(zip(labels, numbers))\n", + " int2label = dict(zip(numbers, labels))\n", + "\n", + " for key, val in label2int.items():\n", + " print(f\"{key}: {val}\")\n", + "\n", + " ds = ds.class_encode_column(\"label\")\n", + " ds = ds.align_labels_with_mapping(label2int, \"label\")\n", + "\n", + " ds = ds.train_test_split(test_size=SPLIT, stratify_by_column = \"label\")\n", + " #ds.push_to_hub(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"))\n", + "\n", + " CONFIG[\"label2int\"] = str(label2int)\n", + " CONFIG[\"int2label\"] = str(int2label)\n", + "\n", + " # with open(\"output.toml\", \"w\") as toml_file:\n", + " # toml.dump(toml.dumps(CONFIG), toml_file)\n", + "\n", + " #ds = datasets.load_dataset(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"), trust_remote_code=True, streaming=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n", + "- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([120]) in the model instantiated\n", + "- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([120, 768]) in the model instantiated\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8956098b0d16497b98dd963b5ee39e30", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00
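
The diff is cut off inside this training cell, but the logged warnings pin down the model setup: the `google/vit-base-patch16-224` checkpoint's 1000-class ImageNet head is discarded and a fresh 120-class head is initialized. A minimal sketch of a load call consistent with those warnings — `num_labels`, `ignore_mismatched_sizes`, and the commented label maps are reconstructions from the Setup and Dataset cells, not the notebook's exact source:

```python
from transformers import AutoImageProcessor, AutoModelForImageClassification

# One of the checkpoints from the commented-out MODELS lines in the Setup cell.
MODEL_NAME = "google/vit-base-patch16-224"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# The "classifier.weight: found shape torch.Size([1000, 768]) ... and
# torch.Size([120, 768]) in the model instantiated" warning is what
# transformers emits when the head is re-initialized this way.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=120,                # 120 Stanford Dogs breeds
    ignore_mismatched_sizes=True,  # drop the 1000-class ImageNet head
    # id2label=int2label, label2id=label2int,  # maps built in the Dataset cell
)
```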
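The Setup cell imports `torchvision.transforms.v2`, but the transform pipeline itself is not visible in the diff. A sketch of a typical v2 pipeline for a 224×224 ViT, applied lazily via `datasets`' `set_transform`; the mean/std values are the ViT processor defaults and in practice should be read from `processor.image_mean` / `processor.image_std`:

```python
import torch
from torchvision.transforms import v2

# Assumed normalization stats (google/vit-base-patch16-224 defaults).
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]

transforms = v2.Compose([
    v2.ToImage(),                           # PIL image -> tv_tensors.Image
    v2.Resize((224, 224)),
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize(mean=MEAN, std=STD),
])

def preprocess(batch):
    # "pixel_values" is the column produced by RENAME_COLUMNS in the Setup cell.
    batch["pixel_values"] = [transforms(img.convert("RGB")) for img in batch["pixel_values"]]
    return batch

# set_transform runs the function on the fly at access time,
# leaving the dataset on disk untouched:
# ds.set_transform(preprocess)
```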
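The `METRICS` list in the Setup cell, together with the `evaluate` import, suggests accuracy, F1, precision, and recall are computed during evaluation. A sketch of a `compute_metrics` callback built from that list; the `macro` averaging choice is an assumption, since multi-class F1/precision/recall require one and the diff does not show it:

```python
import numpy as np
import evaluate

METRICS = ["accuracy", "f1", "precision", "recall"]
loaded = {name: evaluate.load(name) for name in METRICS}

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    results = {}
    for name, metric in loaded.items():
        if name == "accuracy":
            results.update(metric.compute(predictions=preds, references=labels))
        else:
            # f1/precision/recall over 120 classes need an averaging strategy;
            # "macro" (unweighted per-class mean) is assumed here.
            results.update(metric.compute(predictions=preds, references=labels, average="macro"))
    return results
```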
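Finally, the `max_steps is given, it will override any value given in num_train_epochs` warning and the `0/1000` progress bar imply a `Trainer` configured with `max_steps=1000`. A hedged sketch of such a configuration; `output_dir` is a placeholder, and `report_to="mlflow"` is assumed to match the DagsHub MLflow setup from the Setup cell:

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="vit-stanford-dogs",  # hypothetical directory name
    max_steps=1000,                  # matches the 0/1000 progress bar; overrides num_train_epochs
    remove_unused_columns=False,     # keep raw image columns for on-the-fly transforms
    report_to="mlflow",              # assumed: logs to the DagsHub tracking server set above
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    compute_metrics=compute_metrics,
)
# trainer.train()
```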