diff --git a/.github/workflows/push-huggingface.yml b/.github/workflows/push-huggingface.yml
new file mode 100644
index 0000000000000000000000000000000000000000..317156dfc9b2d7cbdf2c0ca5ac12d4e3b070108b
--- /dev/null
+++ b/.github/workflows/push-huggingface.yml
@@ -0,0 +1,22 @@
+name: Push to Hugging Face
+
+on:
+ push:
+ branches: [ "master" ]
+
+jobs:
+ push:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Push repository to Hugging Face
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ git config --global user.email "phuochungus@gmail.com"
+ git config --global user.name "HungNP"
+ git remote add space https://huggingface.co./spaces/phuochungus/PyCIL_Stanford_Car
+ git checkout -b main
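+          # Replace the branch history with a single squashed commit before force-pushing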
+ git reset $(git commit-tree HEAD^{tree} -m "New single commit message")
+ git push --force https://phuochungus:$HF_TOKEN@huggingface.co/spaces/phuochungus/PyCIL_Stanford_Car main
+ git push --force https://phuochungus:$HF_TOKEN@huggingface.co/spaces/DevSecOpAI/PyCIL main
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e746b5623d4ac019731bcf2e3796edf99ea87b15
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+data/
+__pycache__/
+logs/
+.env
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..045b2a92de917713413b8c59bd165d7f10f8fb00
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,30 @@
+FROM python:3.8.5
+
+RUN useradd -m -u 1000 user
+ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
+WORKDIR $HOME
+
+RUN apt-get update && apt-get install -y unzip && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir --upgrade pip
+RUN pip install Cython
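+# CPU-only PyTorch wheels (no CUDA), pinned for reproducibility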
+RUN pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+
+COPY --chown=user requirements.txt requirements.txt
+
+RUN pip install -r requirements.txt
+
+COPY --chown=user download_dataset.sh download_dataset.sh
+
+RUN chmod +x download_dataset.sh
+
+RUN ./download_dataset.sh
+
+COPY --chown=user . .
+
+RUN chmod +x install_awscli.sh && ./install_awscli.sh
+
+RUN chmod +x entrypoint.sh upload_s3.sh simple_train.sh train_from_working.sh
+
+ENTRYPOINT [ "./entrypoint.sh" ]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..1946f39efccb4fceae1752928617ea8fd99552d2
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,43 @@
+MIT License
+
+Copyright (c) 2020 Changhong Zhong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+MIT License
+
+Copyright (c) 2021 Fu-Yun Wang.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ffe0fb3989e2c1f8c52b9c25c7860703743b373
--- /dev/null
+++ b/README.md
@@ -0,0 +1,248 @@
+---
+title: Pycil
+emoji: 🍳
+colorFrom: red
+colorTo: red
+sdk: docker
+pinned: false
+---
+# PyCIL: A Python Toolbox for Class-Incremental Learning
+
+---
+
+
+ Introduction •
+ Methods Reproduced •
+ Reproduced Results •
+ How To Use •
+ License •
+ Acknowledgments •
+ Contact
+
+
+
+
+
+
+---
+
+
+
+
+
+[![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](https://github.com/yaoyao-liu/class-incremental-learning/blob/master/LICENSE)[![Python](https://img.shields.io/badge/python-3.8-blue.svg?style=flat-square&logo=python&color=3776AB&logoColor=3776AB)](https://www.python.org/) [![PyTorch](https://img.shields.io/badge/pytorch-1.8-%237732a8?style=flat-square&logo=PyTorch&color=EE4C2C)](https://pytorch.org/) [![method](https://img.shields.io/badge/Reproduced-20-success)]() [![CIL](https://img.shields.io/badge/ClassIncrementalLearning-SOTA-success??style=for-the-badge&logo=appveyor)](https://paperswithcode.com/task/incremental-learning)
+![visitors](https://visitor-badge.laobi.icu/badge?page_id=LAMDA.PyCIL&left_color=green&right_color=red)
+
+
+
+Welcome to PyCIL, perhaps the class-incremental learning toolbox with the **most** implemented methods. This is the code repository for "PyCIL: A Python Toolbox for Class-Incremental Learning" [[paper]](https://arxiv.org/abs/2112.12533), implemented in PyTorch. If you use any content of this repo in your work, please cite the following bib entry:
+
+ @article{zhou2023pycil,
+ author = {Da-Wei Zhou and Fu-Yun Wang and Han-Jia Ye and De-Chuan Zhan},
+ title = {PyCIL: a Python toolbox for class-incremental learning},
+ journal = {SCIENCE CHINA Information Sciences},
+ year = {2023},
+ volume = {66},
+ number = {9},
+ pages = {197101-},
+ doi = {https://doi.org/10.1007/s11432-022-3600-y}
+ }
+
+ @article{zhou2023class,
+ author = {Zhou, Da-Wei and Wang, Qi-Wei and Qi, Zhi-Hong and Ye, Han-Jia and Zhan, De-Chuan and Liu, Ziwei},
+ title = {Deep Class-Incremental Learning: A Survey},
+ journal = {arXiv preprint arXiv:2302.03648},
+ year = {2023}
+ }
+
+
+## What's New
+- [2024-03]🌟 Check out our [latest work](https://arxiv.org/abs/2403.12030) on pre-trained model-based class-incremental learning!
+- [2024-01]🌟 Check out our [latest survey](https://arxiv.org/abs/2401.16386) on pre-trained model-based continual learning!
+- [2023-09]🌟 We have released [PILOT](https://github.com/sun-hailong/LAMDA-PILOT) toolbox for class-incremental learning with pre-trained models. Have a try!
+- [2023-07]🌟 Add [MEMO](https://openreview.net/forum?id=S07feAlQHgM), [BEEF](https://openreview.net/forum?id=iP77_axu0h3), and [SimpleCIL](https://arxiv.org/abs/2303.07338). State-of-the-art methods of 2023!
+- [2023-05]🌟 Check out our recent work about [class-incremental learning with vision-language models](https://arxiv.org/abs/2305.19270)!
+- [2023-02]🌟 Check out our [rigorous and unified survey](https://arxiv.org/abs/2302.03648) about class-incremental learning, which introduces some memory-agnostic measures with holistic evaluations from multiple aspects!
+- [2022-12]🌟 Add FeTrIL, PASS, IL2A, and SSRE.
+- [2022-10]🌟 PyCIL has been published in [SCIENCE CHINA Information Sciences](https://link.springer.com/article/10.1007/s11432-022-3600-y). Check out the [official introduction](https://mp.weixin.qq.com/s/h1qu2LpdvjeHAPLOnG478A)!
+- [2022-08]🌟 Add RMM.
+- [2022-07]🌟 Add [FOSTER](https://arxiv.org/abs/2204.04662). State-of-the-art method with a single backbone!
+- [2021-12]🌟 **Call For Feedback**: We add a section to introduce awesome works using PyCIL. If you are using PyCIL to publish your work in top-tier conferences/journals, feel free to [contact us](mailto:zhoudw@lamda.nju.edu.cn) for details!
+
+## Introduction
+
+Traditional machine learning systems are deployed under the closed-world setting, which requires all training data to be available before the offline training process. However, real-world applications often face incoming new classes, and a model should incorporate them continually. This learning paradigm is called Class-Incremental Learning (CIL). We propose a Python toolbox that implements several key algorithms for class-incremental learning to ease the burden of researchers in the machine learning community. The toolbox contains implementations of a number of founding works of CIL, such as EWC and iCaRL, but also provides current state-of-the-art algorithms that can be used for conducting novel fundamental research. This toolbox, named PyCIL for Python Class-Incremental Learning, is open source under an MIT license.
+
+For more information about incremental learning, you can refer to these reading materials:
+- A brief introduction (in Chinese) about CIL is available [here](https://zhuanlan.zhihu.com/p/490308909).
+- A PyTorch Tutorial to Class-Incremental Learning (with explicit codes and detailed explanations) is available [here](https://github.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning).
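+
+To make the setting concrete, the loop below is a minimal, self-contained sketch of the CIL protocol (toy stand-ins only, not the PyCIL API): each task introduces a disjoint set of new classes, the model trains on them, and evaluation always covers every class seen so far.
+
+```python
+# Toy sketch of the class-incremental protocol; train_on/evaluate_on are
+# hypothetical stand-ins, not functions from this repository.
+num_tasks, classes_per_task = 5, 2
+
+def train_on(classes):
+    print(f"  training on classes {classes}")
+
+def evaluate_on(classes):
+    return 1.0 / len(classes)  # dummy score
+
+seen = []
+for task_id in range(num_tasks):
+    new = list(range(task_id * classes_per_task, (task_id + 1) * classes_per_task))
+    seen += new
+    train_on(new)            # the model only trains on the new classes...
+    acc = evaluate_on(seen)  # ...but is evaluated on ALL classes seen so far
+    print(f"task {task_id}: evaluated on {len(seen)} classes, score={acc:.2f}")
+```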
+
+## Methods Reproduced
+
+- `FineTune`: Baseline method which simply updates parameters on new tasks.
+- `EWC`: Overcoming catastrophic forgetting in neural networks. PNAS2017 [[paper](https://arxiv.org/abs/1612.00796)]
+- `LwF`: Learning without Forgetting. ECCV2016 [[paper](https://arxiv.org/abs/1606.09282)]
+- `Replay`: Baseline method with exemplar replay.
+- `GEM`: Gradient Episodic Memory for Continual Learning. NIPS2017 [[paper](https://arxiv.org/abs/1706.08840)]
+- `iCaRL`: Incremental Classifier and Representation Learning. CVPR2017 [[paper](https://arxiv.org/abs/1611.07725)]
+- `BiC`: Large Scale Incremental Learning. CVPR2019 [[paper](https://arxiv.org/abs/1905.13260)]
+- `WA`: Maintaining Discrimination and Fairness in Class Incremental Learning. CVPR2020 [[paper](https://arxiv.org/abs/1911.07053)]
+- `PODNet`: PODNet: Pooled Outputs Distillation for Small-Tasks Incremental Learning. ECCV2020 [[paper](https://arxiv.org/abs/2004.13513)]
+- `DER`: DER: Dynamically Expandable Representation for Class Incremental Learning. CVPR2021 [[paper](https://arxiv.org/abs/2103.16788)]
+- `PASS`: Prototype Augmentation and Self-Supervision for Incremental Learning. CVPR2021 [[paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhu_Prototype_Augmentation_and_Self-Supervision_for_Incremental_Learning_CVPR_2021_paper.pdf)]
+- `RMM`: RMM: Reinforced Memory Management for Class-Incremental Learning. NeurIPS2021 [[paper](https://proceedings.neurips.cc/paper/2021/hash/1cbcaa5abbb6b70f378a3a03d0c26386-Abstract.html)]
+- `IL2A`: Class-Incremental Learning via Dual Augmentation. NeurIPS2021 [[paper](https://proceedings.neurips.cc/paper/2021/file/77ee3bc58ce560b86c2b59363281e914-Paper.pdf)]
+- `SSRE`: Self-Sustaining Representation Expansion for Non-Exemplar Class-Incremental Learning. CVPR2022 [[paper](https://arxiv.org/abs/2203.06359)]
+- `FeTrIL`: Feature Translation for Exemplar-Free Class-Incremental Learning. WACV2023 [[paper](https://arxiv.org/abs/2211.13131)]
+- `Coil`: Co-Transport for Class-Incremental Learning. ACM MM2021 [[paper](https://arxiv.org/abs/2107.12654)]
+- `FOSTER`: Feature Boosting and Compression for Class-incremental Learning. ECCV 2022 [[paper](https://arxiv.org/abs/2204.04662)]
+- `MEMO`: A Model or 603 Exemplars: Towards Memory-Efficient Class-Incremental Learning. ICLR 2023 Spotlight [[paper](https://openreview.net/forum?id=S07feAlQHgM)]
+- `BEEF`: BEEF: Bi-Compatible Class-Incremental Learning via Energy-Based Expansion and Fusion. ICLR 2023 [[paper](https://openreview.net/forum?id=iP77_axu0h3)]
+- `SimpleCIL`: Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need. arXiv 2023 [[paper](https://arxiv.org/abs/2303.07338)]
+
+> Authors are welcome to contact us about reproducing their methods in our repo. Feel free to merge your algorithm into PyCIL if you are using our codebase!
+
+## Reproduced Results
+
+#### CIFAR-100
+
+
+
+
+
+
+#### ImageNet-100
+
+
+
+
+
+#### ImageNet-100 (Top-5 Accuracy)
+
+
+
+
+
+> More experimental details and results can be found in our [survey](https://arxiv.org/abs/2302.03648).
+
+## How To Use
+
+### Clone
+
+Clone this GitHub repository:
+
+```
+git clone https://github.com/G-U-N/PyCIL.git
+cd PyCIL
+```
+
+### Dependencies
+
+1. [torch 1.8.1](https://github.com/pytorch/pytorch)
+2. [torchvision 0.6.0](https://github.com/pytorch/vision)
+3. [tqdm](https://github.com/tqdm/tqdm)
+4. [numpy](https://github.com/numpy/numpy)
+5. [scipy](https://github.com/scipy/scipy)
+6. [quadprog](https://github.com/quadprog/quadprog)
+7. [POT](https://github.com/PythonOT/POT)
+
+### Run experiment
+
+1. Edit the `[MODEL NAME].json` file for global settings.
+2. Edit the hyperparameters in the corresponding `[MODEL NAME].py` file (e.g., `models/icarl.py`).
+3. Run:
+
+```bash
+python main.py --config=./exps/[MODEL NAME].json
+```
+
+where `[MODEL NAME]` should be chosen from `finetune`, `ewc`, `lwf`, `replay`, `gem`, `icarl`, `bic`, `wa`, `podnet`, `der`, etc.
+
+4. Hyper-parameters
+
+When using PyCIL, you can edit the global parameters and algorithm-specific hyper-parameters in the corresponding JSON file.
+
+These parameters include:
+
+- **memory-size**: The total number of exemplars kept during the incremental learning process. Assuming there are $K$ classes at the current stage, the model will preserve $\left\lfloor\frac{\text{memory-size}}{K}\right\rfloor$ exemplars per class.
+- **init-cls**: The number of classes in the first incremental stage. Since different CIL settings use different numbers of classes in the first stage, our framework lets you choose how to define the initial stage.
+- **increment**: The number of classes in each incremental stage $i$ ($i > 1$). By default, the number of classes is the same in every incremental stage.
+- **convnet-type**: The backbone network for the incremental model. According to the benchmark setting, `ResNet32` is utilized for `CIFAR100`, and `ResNet18` is used for `ImageNet`.
+- **seed**: The random seed adopted for shuffling the class order. According to the benchmark setting, it is set to 1993 by default.
+
+Other optimization-related parameters, e.g., batch size, number of epochs, learning rate, learning rate decay, weight decay, milestones, and temperature, can be modified in the corresponding Python file.
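+
+As a concrete illustration, the snippet below derives the per-class exemplar budget and the class schedule from these parameters (a minimal sketch; the key names mirror the list above, so check them against the shipped `exps/*.json` files for the exact schema):
+
+```python
+# Hypothetical config mirroring the parameters above; see exps/*.json
+# in the repo for the authoritative key names.
+config = {"memory_size": 2000, "init_cls": 10, "increment": 10, "seed": 1993}
+
+total_classes = 100  # e.g., CIFAR-100
+stages = [config["init_cls"]]
+while sum(stages) < total_classes:
+    stages.append(config["increment"])
+
+for stage in range(len(stages)):
+    seen = sum(stages[:stage + 1])             # K classes seen so far
+    per_class = config["memory_size"] // seen  # floor(memory-size / K)
+    print(f"stage {stage}: {seen} classes seen, {per_class} exemplars per class")
+```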
+
+### Datasets
+
+We have implemented the pre-processing of `CIFAR100`, `imagenet100`, and `imagenet1000`. When training on `CIFAR100`, this framework will automatically download it. When training on `imagenet100/1000`, you should specify the folder of your dataset in `utils/data.py`.
+
+```python
+    def download_data(self):
+        assert 0, "You should specify the folder of your dataset"
+        train_dir = '[DATA-PATH]/train/'
+        test_dir = '[DATA-PATH]/val/'
+```
+[Here](https://drive.google.com/drive/folders/1RBrPGrZzd1bHU5YG8PjdfwpHANZR_lhJ?usp=sharing) is the file list of ImageNet100 (or say ImageNet-Sub).
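+
+If you are unsure whether your folders follow the expected layout, here is a quick sanity check (a sketch assuming the usual one-subfolder-per-class structure; the paths are placeholders for your own `[DATA-PATH]`):
+
+```python
+from torchvision import datasets
+
+# Placeholder paths -- substitute your own dataset location.
+train_set = datasets.ImageFolder('/path/to/imagenet100/train/')
+test_set = datasets.ImageFolder('/path/to/imagenet100/val/')
+print(f"{len(train_set.classes)} classes, "
+      f"{len(train_set)} train / {len(test_set)} test images")
+```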
+
+## Awesome Papers using PyCIL
+
+### Our Papers
+- Expandable Subspace Ensemble for Pre-Trained Model-Based Class-Incremental Learning (**CVPR 2024**) [[paper](https://arxiv.org/abs/2403.12030 )] [[code](https://github.com/sun-hailong/CVPR24-Ease)]
+
+- Continual Learning with Pre-Trained Models: A Survey (**arXiv 2024**) [[paper](https://arxiv.org/abs/2401.16386)] [[code](https://github.com/sun-hailong/LAMDA-PILOT)]
+
+- Deep Class-Incremental Learning: A Survey (**arXiv 2023**) [[paper](https://arxiv.org/abs/2302.03648)] [[code](https://github.com/zhoudw-zdw/CIL_Survey/)]
+
+- Learning without Forgetting for Vision-Language Models (**arXiv 2023**) [[paper](https://arxiv.org/abs/2305.19270)]
+
+- Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need (**arXiv 2023**) [[paper](https://arxiv.org/abs/2303.07338)] [[code](https://github.com/zhoudw-zdw/RevisitingCIL)]
+
+- PILOT: A Pre-Trained Model-Based Continual Learning Toolbox (**arXiv 2023**) [[paper](https://arxiv.org/abs/2309.07117)] [[code](https://github.com/sun-hailong/LAMDA-PILOT)]
+
+- Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration (**NeurIPS 2023**)[[paper](https://arxiv.org/abs/2312.05229)] [[Code](https://github.com/wangkiw/TEEN)]
+
+- BEEF: Bi-Compatible Class-Incremental Learning via Energy-Based Expansion and Fusion (**ICLR 2023**) [[paper](https://openreview.net/forum?id=iP77_axu0h3)] [[code](https://github.com/G-U-N/ICLR23-BEEF/)]
+
+- A model or 603 exemplars: Towards memory-efficient class-incremental learning (**ICLR 2023**) [[paper](https://arxiv.org/abs/2205.13218)] [[code](https://github.com/wangkiw/ICLR23-MEMO/)]
+
+- Few-shot class-incremental learning by sampling multi-phase tasks (**TPAMI 2022**) [[paper](https://arxiv.org/pdf/2203.17030.pdf)] [[code](https://github.com/zhoudw-zdw/TPAMI-Limit)]
+
+- Foster: Feature Boosting and Compression for Class-incremental Learning (**ECCV 2022**) [[paper](https://arxiv.org/abs/2204.04662)] [[code](https://github.com/G-U-N/ECCV22-FOSTER/)]
+
+- Forward compatible few-shot class-incremental learning (**CVPR 2022**) [[paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhou_Forward_Compatible_Few-Shot_Class-Incremental_Learning_CVPR_2022_paper.pdf)] [[code](https://github.com/zhoudw-zdw/CVPR22-Fact)]
+
+- Co-Transport for Class-Incremental Learning (**ACM MM 2021**) [[paper](https://arxiv.org/abs/2107.12654)] [[code](https://github.com/zhoudw-zdw/MM21-Coil)]
+
+### Other Awesome Works
+
+- Towards Realistic Evaluation of Industrial Continual Learning Scenarios with an Emphasis on Energy Consumption and Computational Footprint (**ICCV 2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Chavan_Towards_Realistic_Evaluation_of_Industrial_Continual_Learning_Scenarios_with_an_ICCV_2023_paper.pdf)][[code](https://github.com/Vivek9Chavan/RECIL)]
+
+- Dynamic Residual Classifier for Class Incremental Learning (**ICCV 2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Chen_Dynamic_Residual_Classifier_for_Class_Incremental_Learning_ICCV_2023_paper.pdf)][[code](https://github.com/chen-xw/DRC-CIL)]
+
+- S-Prompts Learning with Pre-trained Transformers: An Occam's Razor for Domain Incremental Learning (**NeurIPS 2022**) [[paper](https://openreview.net/forum?id=ZVe_WeMold)] [[code](https://github.com/iamwangyabin/S-Prompts)]
+
+
+## License
+
+Please check the MIT [license](./LICENSE) that is listed in this repository.
+
+## Acknowledgments
+
+We thank the following repos for providing helpful components/functions used in our work.
+
+- [Continual-Learning-Reproduce](https://github.com/zhchuu/continual-learning-reproduce)
+- [GEM](https://github.com/hursung1/GradientEpisodicMemory)
+- [FACIL](https://github.com/mmasana/FACIL)
+
+The training flow and data configurations are based on Continual-Learning-Reproduce. The original information of the repo is available in the base branch.
+
+
+## Contact
+
+If you have any questions, please feel free to propose new features by opening an issue, or contact the authors: **Da-Wei Zhou** ([zhoudw@lamda.nju.edu.cn](mailto:zhoudw@lamda.nju.edu.cn)) and **Fu-Yun Wang** ([wangfuyun@smail.nju.edu.cn](mailto:wangfuyun@smail.nju.edu.cn)). Enjoy the code.
+
+
+## Star History 🚀
+
+[![Star History Chart](https://api.star-history.com/svg?repos=G-U-N/PyCIL&type=Date)](https://star-history.com/#G-U-N/PyCIL&Date)
+
diff --git a/convs/__init__.py b/convs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/convs/cifar_resnet.py b/convs/cifar_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2668dd45ccd05b6a881b544b6008abd8d8b58af
--- /dev/null
+++ b/convs/cifar_resnet.py
@@ -0,0 +1,207 @@
+'''
+Reference:
+https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
+'''
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DownsampleA(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleA, self).__init__()
+ assert stride == 2
+ self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
+
+ def forward(self, x):
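+        # Option-A shortcut: subsample spatially with a stride-2 pool, then
+        # zero-pad the channels (concatenating x with x*0 doubles the width).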
+ x = self.avg(x)
+ return torch.cat((x, x.mul(0)), 1)
+
+
+class DownsampleB(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleB, self).__init__()
+ self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
+ self.bn = nn.BatchNorm2d(nOut)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class DownsampleC(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleC, self).__init__()
+ assert stride != 1 or nIn != nOut
+ self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class DownsampleD(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleD, self).__init__()
+ assert stride == 2
+ self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
+ self.bn = nn.BatchNorm2d(nOut)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class ResNetBasicblock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(ResNetBasicblock, self).__init__()
+
+ self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn_a = nn.BatchNorm2d(planes)
+
+ self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn_b = nn.BatchNorm2d(planes)
+
+ self.downsample = downsample
+
+ def forward(self, x):
+ residual = x
+
+ basicblock = self.conv_a(x)
+ basicblock = self.bn_a(basicblock)
+ basicblock = F.relu(basicblock, inplace=True)
+
+ basicblock = self.conv_b(basicblock)
+ basicblock = self.bn_b(basicblock)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ return F.relu(residual + basicblock, inplace=True)
+
+
+class CifarResNet(nn.Module):
+ """
+    ResNet optimized for the CIFAR dataset, as specified in
+    https://arxiv.org/abs/1512.03385
+ """
+
+ def __init__(self, block, depth, channels=3):
+ super(CifarResNet, self).__init__()
+
+ # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
+ assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
+ layer_blocks = (depth - 2) // 6
+
+ self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn_1 = nn.BatchNorm2d(16)
+
+ self.inplanes = 16
+ self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
+ self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
+ self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
+ self.avgpool = nn.AvgPool2d(8)
+ self.out_dim = 64 * block.expansion
+ self.fc = nn.Linear(64*block.expansion, 10)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ # m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv_1_3x3(x) # [bs, 16, 32, 32]
+ x = F.relu(self.bn_1(x), inplace=True)
+
+ x_1 = self.stage_1(x) # [bs, 16, 32, 32]
+ x_2 = self.stage_2(x_1) # [bs, 32, 16, 16]
+ x_3 = self.stage_3(x_2) # [bs, 64, 8, 8]
+
+ pooled = self.avgpool(x_3) # [bs, 64, 1, 1]
+ features = pooled.view(pooled.size(0), -1) # [bs, 64]
+
+ return {
+ 'fmaps': [x_1, x_2, x_3],
+ 'features': features
+ }
+
+ @property
+ def last_conv(self):
+ return self.stage_3[-1].conv_b
+
+
+def resnet20mnist():
+ """Constructs a ResNet-20 model for MNIST."""
+ model = CifarResNet(ResNetBasicblock, 20, 1)
+ return model
+
+
+def resnet32mnist():
+ """Constructs a ResNet-32 model for MNIST."""
+ model = CifarResNet(ResNetBasicblock, 32, 1)
+ return model
+
+
+def resnet20():
+ """Constructs a ResNet-20 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 20)
+ return model
+
+
+def resnet32():
+ """Constructs a ResNet-32 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 32)
+ return model
+
+
+def resnet44():
+ """Constructs a ResNet-44 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 44)
+ return model
+
+
+def resnet56():
+ """Constructs a ResNet-56 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 56)
+ return model
+
+
+def resnet110():
+ """Constructs a ResNet-110 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 110)
+ return model
+
+# for auc
+def resnet14():
+ model = CifarResNet(ResNetBasicblock, 14)
+ return model
+
+def resnet26():
+ model = CifarResNet(ResNetBasicblock, 26)
+ return model
\ No newline at end of file
diff --git a/convs/conv_cifar.py b/convs/conv_cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c337270b92312b0dc7ed37b2f7c937345f40696
--- /dev/null
+++ b/convs/conv_cifar.py
@@ -0,0 +1,77 @@
+'''
+For MEMO implementations of CIFAR-ConvNet
+Reference:
+https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_cifar.py
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# for cifar
+def conv_block(in_channels, out_channels):
+ return nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+
+class ConvNet2(nn.Module):
+ def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
+ super().__init__()
+ self.out_dim = 64
+ self.avgpool = nn.AvgPool2d(8)
+ self.encoder = nn.Sequential(
+ conv_block(x_dim, hid_dim),
+ conv_block(hid_dim, z_dim),
+ )
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.avgpool(x)
+ features = x.view(x.shape[0], -1)
+ return {
+ "features":features
+ }
+
+class GeneralizedConvNet2(nn.Module):
+ def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
+ super().__init__()
+ self.encoder = nn.Sequential(
+ conv_block(x_dim, hid_dim),
+ )
+
+ def forward(self, x):
+ base_features = self.encoder(x)
+ return base_features
+
+class SpecializedConvNet2(nn.Module):
+ def __init__(self,hid_dim=64,z_dim=64):
+ super().__init__()
+ self.feature_dim = 64
+ self.avgpool = nn.AvgPool2d(8)
+ self.AdaptiveBlock = conv_block(hid_dim,z_dim)
+
+ def forward(self,x):
+ base_features = self.AdaptiveBlock(x)
+ pooled = self.avgpool(base_features)
+ features = pooled.view(pooled.size(0),-1)
+ return features
+
+def conv2():
+ return ConvNet2()
+
+def get_conv_a2fc():
+ basenet = GeneralizedConvNet2()
+ adaptivenet = SpecializedConvNet2()
+ return basenet,adaptivenet
+
+if __name__ == '__main__':
+ a, b = get_conv_a2fc()
+ _base = sum(p.numel() for p in a.parameters())
+ _adap = sum(p.numel() for p in b.parameters())
+ print(f"conv :{_base+_adap}")
+
+    model = conv2()  # avoid shadowing the conv2 factory defined above
+    conv2_sum = sum(p.numel() for p in model.parameters())
+    print(f"conv2 :{conv2_sum}")
\ No newline at end of file
diff --git a/convs/conv_imagenet.py b/convs/conv_imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..59793b2184b938250d0a55fff1786d1fc724ed73
--- /dev/null
+++ b/convs/conv_imagenet.py
@@ -0,0 +1,82 @@
+'''
+For MEMO implementations of ImageNet-ConvNet
+Reference:
+https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py
+'''
+import torch.nn as nn
+import torch
+
+# for imagenet
+def first_block(in_channels, out_channels):
+ return nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+
+def conv_block(in_channels, out_channels):
+ return nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+
+class ConvNet(nn.Module):
+ def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
+ super().__init__()
+ self.block1 = first_block(x_dim, hid_dim)
+ self.block2 = conv_block(hid_dim, hid_dim)
+ self.block3 = conv_block(hid_dim, hid_dim)
+ self.block4 = conv_block(hid_dim, z_dim)
+ self.avgpool = nn.AvgPool2d(7)
+ self.out_dim = 512
+
+ def forward(self, x):
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.block3(x)
+ x = self.block4(x)
+
+ x = self.avgpool(x)
+ features = x.view(x.shape[0], -1)
+
+ return {
+ "features": features
+ }
+
+class GeneralizedConvNet(nn.Module):
+ def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
+ super().__init__()
+ self.block1 = first_block(x_dim, hid_dim)
+ self.block2 = conv_block(hid_dim, hid_dim)
+ self.block3 = conv_block(hid_dim, hid_dim)
+
+ def forward(self, x):
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.block3(x)
+ return x
+
+class SpecializedConvNet(nn.Module):
+ def __init__(self, hid_dim=128,z_dim=512):
+ super().__init__()
+ self.block4 = conv_block(hid_dim, z_dim)
+ self.avgpool = nn.AvgPool2d(7)
+ self.feature_dim = 512
+
+ def forward(self, x):
+ x = self.block4(x)
+ x = self.avgpool(x)
+ features = x.view(x.shape[0], -1)
+ return features
+
+def conv4():
+ model = ConvNet()
+ return model
+
+def conv_a2fc_imagenet():
+ _base = GeneralizedConvNet()
+ _adaptive_net = SpecializedConvNet()
+ return _base, _adaptive_net
\ No newline at end of file
diff --git a/convs/linears.py b/convs/linears.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2eb0a316b68d7f520a6b1ff41613d3387fd49bc
--- /dev/null
+++ b/convs/linears.py
@@ -0,0 +1,167 @@
+'''
+Reference:
+https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
+'''
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class SimpleLinear(nn.Module):
+ '''
+ Reference:
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
+ '''
+ def __init__(self, in_features, out_features, bias=True):
+ super(SimpleLinear, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_features))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0)
+
+ def forward(self, input):
+ return {'logits': F.linear(input, self.weight, self.bias)}
+
+
+class CosineLinear(nn.Module):
+ def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
+ super(CosineLinear, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features * nb_proxy
+ self.nb_proxy = nb_proxy
+ self.to_reduce = to_reduce
+ self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
+ if sigma:
+ self.sigma = nn.Parameter(torch.Tensor(1))
+ else:
+ self.register_parameter('sigma', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ stdv = 1. / math.sqrt(self.weight.size(1))
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.sigma is not None:
+ self.sigma.data.fill_(1)
+
+ def forward(self, input):
+ out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
+
+ if self.to_reduce:
+ # Reduce_proxy
+ out = reduce_proxies(out, self.nb_proxy)
+
+ if self.sigma is not None:
+ out = self.sigma * out
+
+ return {'logits': out}
+
+
+class SplitCosineLinear(nn.Module):
+ def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True):
+ super(SplitCosineLinear, self).__init__()
+ self.in_features = in_features
+ self.out_features = (out_features1 + out_features2) * nb_proxy
+ self.nb_proxy = nb_proxy
+ self.fc1 = CosineLinear(in_features, out_features1, nb_proxy, False, False)
+ self.fc2 = CosineLinear(in_features, out_features2, nb_proxy, False, False)
+ if sigma:
+ self.sigma = nn.Parameter(torch.Tensor(1))
+ self.sigma.data.fill_(1)
+ else:
+ self.register_parameter('sigma', None)
+
+ def forward(self, x):
+ out1 = self.fc1(x)
+ out2 = self.fc2(x)
+
+ out = torch.cat((out1['logits'], out2['logits']), dim=1) # concatenate along the channel
+
+ # Reduce_proxy
+ out = reduce_proxies(out, self.nb_proxy)
+
+ if self.sigma is not None:
+ out = self.sigma * out
+
+ return {
+ 'old_scores': reduce_proxies(out1['logits'], self.nb_proxy),
+ 'new_scores': reduce_proxies(out2['logits'], self.nb_proxy),
+ 'logits': out
+ }
+
+
+def reduce_proxies(out, nb_proxy):
+ if nb_proxy == 1:
+ return out
+ bs = out.shape[0]
+ nb_classes = out.shape[1] / nb_proxy
+ assert nb_classes.is_integer(), 'Shape error'
+ nb_classes = int(nb_classes)
+
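+    # View as [bs, classes, proxies]; softmax over each class's proxies and
+    # use the resulting weights to pool proxy similarities into one score per class.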
+ simi_per_class = out.view(bs, nb_classes, nb_proxy)
+ attentions = F.softmax(simi_per_class, dim=-1)
+
+ return (attentions * simi_per_class).sum(-1)
+
+
+'''
+class CosineLinear(nn.Module):
+ def __init__(self, in_features, out_features, sigma=True):
+ super(CosineLinear, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
+ if sigma:
+ self.sigma = nn.Parameter(torch.Tensor(1))
+ else:
+ self.register_parameter('sigma', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ stdv = 1. / math.sqrt(self.weight.size(1))
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.sigma is not None:
+ self.sigma.data.fill_(1)
+
+ def forward(self, input):
+ out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
+ if self.sigma is not None:
+ out = self.sigma * out
+ return {'logits': out}
+
+
+class SplitCosineLinear(nn.Module):
+ def __init__(self, in_features, out_features1, out_features2, sigma=True):
+ super(SplitCosineLinear, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features1 + out_features2
+ self.fc1 = CosineLinear(in_features, out_features1, False)
+ self.fc2 = CosineLinear(in_features, out_features2, False)
+ if sigma:
+ self.sigma = nn.Parameter(torch.Tensor(1))
+ self.sigma.data.fill_(1)
+ else:
+ self.register_parameter('sigma', None)
+
+ def forward(self, x):
+ out1 = self.fc1(x)
+ out2 = self.fc2(x)
+
+ out = torch.cat((out1['logits'], out2['logits']), dim=1) # concatenate along the channel
+ if self.sigma is not None:
+ out = self.sigma * out
+
+ return {
+ 'old_scores': out1['logits'],
+ 'new_scores': out2['logits'],
+ 'logits': out
+ }
+'''
diff --git a/convs/memo_cifar_resnet.py b/convs/memo_cifar_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d585519f60bd2d421d277aed4affcb16e9d4af5
--- /dev/null
+++ b/convs/memo_cifar_resnet.py
@@ -0,0 +1,164 @@
+'''
+For MEMO implementations of CIFAR-ResNet
+Reference:
+https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
+'''
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class DownsampleA(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleA, self).__init__()
+ assert stride == 2
+ self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
+
+ def forward(self, x):
+ x = self.avg(x)
+ return torch.cat((x, x.mul(0)), 1)
+
+class ResNetBasicblock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(ResNetBasicblock, self).__init__()
+
+ self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn_a = nn.BatchNorm2d(planes)
+
+ self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn_b = nn.BatchNorm2d(planes)
+
+ self.downsample = downsample
+
+ def forward(self, x):
+ residual = x
+
+ basicblock = self.conv_a(x)
+ basicblock = self.bn_a(basicblock)
+ basicblock = F.relu(basicblock, inplace=True)
+
+ basicblock = self.conv_b(basicblock)
+ basicblock = self.bn_b(basicblock)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ return F.relu(residual + basicblock, inplace=True)
+
+
+
+class GeneralizedResNet_cifar(nn.Module):
+ def __init__(self, block, depth, channels=3):
+ super(GeneralizedResNet_cifar, self).__init__()
+ assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
+ layer_blocks = (depth - 2) // 6
+ self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn_1 = nn.BatchNorm2d(16)
+
+ self.inplanes = 16
+ self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
+ self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
+
+ self.out_dim = 64 * block.expansion
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ # m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv_1_3x3(x) # [bs, 16, 32, 32]
+ x = F.relu(self.bn_1(x), inplace=True)
+
+ x_1 = self.stage_1(x) # [bs, 16, 32, 32]
+ x_2 = self.stage_2(x_1) # [bs, 32, 16, 16]
+ return x_2
+
+class SpecializedResNet_cifar(nn.Module):
+ def __init__(self, block, depth, inplanes=32, feature_dim=64):
+ super(SpecializedResNet_cifar, self).__init__()
+ self.inplanes = inplanes
+ self.feature_dim = feature_dim
+ layer_blocks = (depth - 2) // 6
+ self.final_stage = self._make_layer(block, 64, layer_blocks, 2)
+ self.avgpool = nn.AvgPool2d(8)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ # m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=2):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, base_feature_map):
+ final_feature_map = self.final_stage(base_feature_map)
+ pooled = self.avgpool(final_feature_map)
+ features = pooled.view(pooled.size(0), -1) #bs x 64
+ return features
+
+#For cifar & MEMO
+def get_resnet8_a2fc():
+ basenet = GeneralizedResNet_cifar(ResNetBasicblock,8)
+ adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,8)
+ return basenet,adaptivenet
+
+def get_resnet14_a2fc():
+ basenet = GeneralizedResNet_cifar(ResNetBasicblock,14)
+ adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,14)
+ return basenet,adaptivenet
+
+def get_resnet20_a2fc():
+ basenet = GeneralizedResNet_cifar(ResNetBasicblock,20)
+ adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,20)
+ return basenet,adaptivenet
+
+def get_resnet26_a2fc():
+ basenet = GeneralizedResNet_cifar(ResNetBasicblock,26)
+ adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,26)
+ return basenet,adaptivenet
+
+def get_resnet32_a2fc():
+ basenet = GeneralizedResNet_cifar(ResNetBasicblock,32)
+ adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,32)
+ return basenet,adaptivenet
+
+
diff --git a/convs/memo_resnet.py b/convs/memo_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..507b0bd60e35528b93a5820af373b1a1a06d2aff
--- /dev/null
+++ b/convs/memo_resnet.py
@@ -0,0 +1,322 @@
+'''
+For MEMO implementations of ImageNet-ResNet
+Reference:
+https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+'''
+import torch
+import torch.nn as nn
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+ from torch.hub import load_state_dict_from_url
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class GeneralizedResNet_imagenet(nn.Module):
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None):
+ super(GeneralizedResNet_imagenet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, # stride=2 -> stride=1 for cifar
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Removed in _forward_impl for cifar
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.out_dim = 512 * block.expansion
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+ return nn.Sequential(*layers)
+ def _forward_impl(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ x_1 = self.layer1(x)
+ x_2 = self.layer2(x_1)
+ x_3 = self.layer3(x_2)
+ return x_3
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+class SpecializedResNet_imagenet(nn.Module):
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None):
+ super(SpecializedResNet_imagenet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+ self.feature_dim = 512 * block.expansion
+ self.inplanes = 256 * block.expansion
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.out_dim = 512 * block.expansion
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def forward(self,x):
+ x_4 = self.layer4(x) # [bs, 512, 4, 4]
+ pooled = self.avgpool(x_4) # [bs, 512, 1, 1]
+ features = torch.flatten(pooled, 1) # [bs, 512]
+ return features
+
+def get_resnet10_imagenet():
+ basenet = GeneralizedResNet_imagenet(BasicBlock,[1, 1, 1, 1])
+ adaptivenet = SpecializedResNet_imagenet(BasicBlock, [1, 1, 1, 1])
+ return basenet,adaptivenet
+
+def get_resnet18_imagenet():
+ basenet = GeneralizedResNet_imagenet(BasicBlock,[2, 2, 2, 2])
+ adaptivenet = SpecializedResNet_imagenet(BasicBlock, [2, 2, 2, 2])
+ return basenet,adaptivenet
+
+def get_resnet26_imagenet():
+ basenet = GeneralizedResNet_imagenet(Bottleneck,[2, 2, 2, 2])
+ adaptivenet = SpecializedResNet_imagenet(Bottleneck, [2, 2, 2, 2])
+ return basenet,adaptivenet
+
+def get_resnet34_imagenet():
+ basenet = GeneralizedResNet_imagenet(BasicBlock,[3, 4, 6, 3])
+ adaptivenet = SpecializedResNet_imagenet(BasicBlock, [3, 4, 6, 3])
+ return basenet,adaptivenet
+
+def get_resnet50_imagenet():
+ basenet = GeneralizedResNet_imagenet(Bottleneck,[3, 4, 6, 3])
+ adaptivenet = SpecializedResNet_imagenet(Bottleneck, [3, 4, 6, 3])
+ return basenet,adaptivenet
+
+
+if __name__ == '__main__':
+ model2imagenet = 3*224*224
+
+ a, b = get_resnet10_imagenet()
+ _base = sum(p.numel() for p in a.parameters())
+ _adap = sum(p.numel() for p in b.parameters())
+ print(f"resnet10 #params:{_base+_adap}")
+
+ a, b = get_resnet18_imagenet()
+ _base = sum(p.numel() for p in a.parameters())
+ _adap = sum(p.numel() for p in b.parameters())
+ print(f"resnet18 #params:{_base+_adap}")
+
+ a, b = get_resnet26_imagenet()
+ _base = sum(p.numel() for p in a.parameters())
+ _adap = sum(p.numel() for p in b.parameters())
+ print(f"resnet26 #params:{_base+_adap}")
+
+ a, b = get_resnet34_imagenet()
+ _base = sum(p.numel() for p in a.parameters())
+ _adap = sum(p.numel() for p in b.parameters())
+ print(f"resnet34 #params:{_base+_adap}")
+
+ a, b = get_resnet50_imagenet()
+ _base = sum(p.numel() for p in a.parameters())
+ _adap = sum(p.numel() for p in b.parameters())
+ print(f"resnet50 #params:{_base+_adap}")
\ No newline at end of file
diff --git a/convs/modified_represnet.py b/convs/modified_represnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b451cadcd15b25420dabfc6d662d6f440879c533
--- /dev/null
+++ b/convs/modified_represnet.py
@@ -0,0 +1,177 @@
+import torch
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+__all__ = ['ResNet', 'resnet18_rep', 'resnet34_rep']
+
+# Weight URLs for resnet18_rep/resnet34_rep below; without this dict,
+# calling them with pretrained=True would raise a NameError.
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=True)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)
+
+class conv_block(nn.Module):
+
+ def __init__(self, in_planes, planes, mode, stride=1):
+ super(conv_block, self).__init__()
+ self.conv = conv3x3(in_planes, planes, stride)
+ self.mode = mode
+ if mode == 'parallel_adapters':
+ self.adapter = conv1x1(in_planes, planes, stride)
+
+
+ def re_init_conv(self):
+ nn.init.kaiming_normal_(self.adapter.weight, mode='fan_out', nonlinearity='relu')
+ return
+ def forward(self, x):
+ y = self.conv(x)
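+        # 'parallel_adapters' mode adds a 1x1 adapter branch alongside the
+        # 3x3 conv and sums their outputs.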
+ if self.mode == 'parallel_adapters':
+ y = y + self.adapter(x)
+
+ return y
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, mode, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv_block(inplanes, planes, mode, stride)
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv_block(planes, planes, mode)
+ self.norm2 = nn.BatchNorm2d(planes)
+ self.mode = mode
+
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.norm2(out)
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=100, args = None):
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ assert args is not None
+ self.mode = args["mode"]
+
+ if 'cifar' in args["dataset"]:
+ self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
+ print("use cifar")
+ elif 'imagenet' in args["dataset"] or 'stanfordcar' in args["dataset"]:
+ if args["init_cls"] == args["increment"]:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+ else:
+                # Following the PODNet implementation
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.feature = nn.AvgPool2d(4, stride=1)
+ self.out_dim = 512
+
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=True),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, self.mode, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, self.mode))
+
+ return nn.Sequential(*layers)
+
+ def switch(self, mode='normal'):
+ for name, module in self.named_modules():
+ if hasattr(module, 'mode'):
+ module.mode = mode
+
+ def re_init_params(self):
+ for name, module in self.named_modules():
+ if hasattr(module, 're_init_conv'):
+ module.re_init_conv()
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ dim = x.size()[-1]
+ pool = nn.AvgPool2d(dim, stride=1)
+ x = pool(x)
+ x = x.view(x.size(0), -1)
+ return {"features": x}
+
+
+def resnet18_rep(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
+ now_state_dict = model.state_dict()
+ now_state_dict.update(pretrained_state_dict)
+ model.load_state_dict(now_state_dict)
+ return model
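+# e.g., with a hypothetical config dict: resnet18_rep(args={"mode": "parallel_adapters", "dataset": "cifar100"})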
+
+
+def resnet34_rep(pretrained=False, **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
+ now_state_dict = model.state_dict()
+ now_state_dict.update(pretrained_state_dict)
+ model.load_state_dict(now_state_dict)
+ return model
\ No newline at end of file
diff --git a/convs/resnet.py b/convs/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d205be57c859f3db2843eebbbecb4d8328e7bd1
--- /dev/null
+++ b/convs/resnet.py
@@ -0,0 +1,395 @@
+'''
+Reference:
+https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+'''
+import torch
+import torch.nn as nn
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+ from torch.hub import load_state_dict_from_url
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None,args=None):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+
+ assert args is not None, "you should pass args to resnet"
+ if 'cifar' in args["dataset"]:
+ if args["model_name"] == "memo":
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+ else:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True))
+ elif 'imagenet' in args["dataset"] or 'stanfordcar' in args['dataset'] or 'general_dataset' in args['dataset']:
+ if args["init_cls"] == args["increment"]:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+ else:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.out_dim = 512 * block.expansion
+ # self.fc = nn.Linear(512 * block.expansion, num_classes) # Removed in _forward_impl
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x):
+ # See note [TorchScript super()]
+ x = self.conv1(x) # [bs, 64, 32, 32]
+
+ x_1 = self.layer1(x) # [bs, 64, 32, 32] (BasicBlock widths, 32x32 input; x4 channels for Bottleneck)
+ x_2 = self.layer2(x_1) # [bs, 128, 16, 16]
+ x_3 = self.layer3(x_2) # [bs, 256, 8, 8]
+ x_4 = self.layer4(x_3) # [bs, 512, 4, 4]
+
+ pooled = self.avgpool(x_4) # [bs, 512, 1, 1]
+ features = torch.flatten(pooled, 1) # [bs, 512]
+ # x = self.fc(x)
+
+ return {
+ 'fmaps': [x_1, x_2, x_3, x_4],
+ 'features': features
+ }
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ @property
+ def last_conv(self):
+ if hasattr(self.layer4[-1], 'conv3'):
+ return self.layer4[-1].conv3
+ else:
+ return self.layer4[-1].conv2
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
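+ # NOTE: torchvision checkpoints key the stem as 'conv1.weight'/'bn1.*' and include 'fc.*', while this
+ # model wraps the stem in nn.Sequential ('conv1.0.*') and drops fc, so a strict load is likely to fail
+ # without filtering the state_dict or passing strict=False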
+ return model
+
+def resnet10(pretrained=False, progress=True, **kwargs):
+ """
+ For MEMO implementations of ResNet-10
+ """
+ return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress,
+ **kwargs)
+
+def resnet26(pretrained=False, progress=True, **kwargs):
+ """
+ For MEMO implementations of ResNet-26
+ """
+ return _resnet('resnet26', Bottleneck, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
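+# NOTE: model_urls has no 'resnet10'/'resnet26' entries, so pretrained=True would raise a KeyError for these two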
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
diff --git a/convs/resnet_cbam.py b/convs/resnet_cbam.py
new file mode 100644
index 0000000000000000000000000000000000000000..240c430fb6b103cc9885f479bafaaad18f691cd1
--- /dev/null
+++ b/convs/resnet_cbam.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
+ 'resnet152_cbam']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class ChannelAttention(nn.Module):
+ def __init__(self, in_planes, ratio=16):
+ super(ChannelAttention, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
+
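+ # shared two-layer 1x1-conv MLP applied to both the avg- and max-pooled descriptors (CBAM channel gate)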
+ self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+ self.relu1 = nn.ReLU()
+ self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+ max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+ out = avg_out + max_out
+ return self.sigmoid(out)
+
+
+class SpatialAttention(nn.Module):
+ def __init__(self, kernel_size=7):
+ super(SpatialAttention, self).__init__()
+
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+ padding = 3 if kernel_size == 7 else 1
+
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = torch.mean(x, dim=1, keepdim=True)
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
+ x = torch.cat([avg_out, max_out], dim=1)
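+ # the 2-channel map (channel-wise mean and max) is reduced to a 1-channel spatial gate by the conv below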
+ x = self.conv1(x)
+ return self.sigmoid(x)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.ca = ChannelAttention(planes)
+ self.sa = SpatialAttention()
+
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ # CBAM: apply the channel gate, then the spatial gate, before the residual sum (mirrors Bottleneck)
+ out = self.ca(out) * out
+ out = self.sa(out) * out
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.ca = ChannelAttention(planes * 4)
+ self.sa = SpatialAttention()
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+ out = self.conv3(out)
+ out = self.bn3(out)
+ out = self.ca(out) * out
+ out = self.sa(out) * out
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=100, args=None):
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ assert args is not None, "you should pass args to resnet"
+ if 'cifar' in args["dataset"]:
+ self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
+ elif 'imagenet' in args["dataset"] or 'stanfordcar' in args['dataset']:
+ if args["init_cls"] == args["increment"]:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+ else:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.feature = nn.AvgPool2d(4, stride=1)
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
+ self.out_dim = 512 * block.expansion
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ dim = x.size()[-1]
+ pool = nn.AvgPool2d(dim, stride=1)
+ x = pool(x)
+ x = x.view(x.size(0), -1)
+ return {"features": x}
+
+def resnet18_cbam(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
+ now_state_dict = model.state_dict()
+ now_state_dict.update(pretrained_state_dict)
+ model.load_state_dict(now_state_dict)
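+ # NOTE: the merged dict still carries torchvision's 'conv1.weight'/'fc.*' keys, which this model lacks,
+ # so a strict load will likely raise; the CBAM gates have no pretrained weights in any case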
+ return model
+
+
+def resnet34_cbam(pretrained=False, **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
+ now_state_dict = model.state_dict()
+ now_state_dict.update(pretrained_state_dict)
+ model.load_state_dict(now_state_dict)
+ return model
+
+
+def resnet50_cbam(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
+ now_state_dict = model.state_dict()
+ now_state_dict.update(pretrained_state_dict)
+ model.load_state_dict(now_state_dict)
+ return model
+
+
+def resnet101_cbam(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+ if pretrained:
+ pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
+ now_state_dict = model.state_dict()
+ now_state_dict.update(pretrained_state_dict)
+ model.load_state_dict(now_state_dict)
+ return model
+
+
+def resnet152_cbam(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+ if pretrained:
+ pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
+ now_state_dict = model.state_dict()
+ now_state_dict.update(pretrained_state_dict)
+ model.load_state_dict(now_state_dict)
+ return model
\ No newline at end of file
diff --git a/convs/ucir_cifar_resnet.py b/convs/ucir_cifar_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e71b742f56925d3a717228a8450c43e59dca39f
--- /dev/null
+++ b/convs/ucir_cifar_resnet.py
@@ -0,0 +1,204 @@
+'''
+Reference:
+https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
+https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_resnet_cifar.py
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from convs.modified_linear import CosineLinear
+
+
+class DownsampleA(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleA, self).__init__()
+ assert stride == 2
+ self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
+
+ def forward(self, x):
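+ # option-A shortcut: stride-2 average pooling, then doubling the channels by concatenating a zero tensor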
+ x = self.avg(x)
+ return torch.cat((x, x.mul(0)), 1)
+
+
+class DownsampleB(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleB, self).__init__()
+ self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
+ self.bn = nn.BatchNorm2d(nOut)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class DownsampleC(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleC, self).__init__()
+ assert stride != 1 or nIn != nOut
+ self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class DownsampleD(nn.Module):
+ def __init__(self, nIn, nOut, stride):
+ super(DownsampleD, self).__init__()
+ assert stride == 2
+ self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
+ self.bn = nn.BatchNorm2d(nOut)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class ResNetBasicblock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
+ super(ResNetBasicblock, self).__init__()
+
+ self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn_a = nn.BatchNorm2d(planes)
+
+ self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn_b = nn.BatchNorm2d(planes)
+
+ self.downsample = downsample
+ self.last = last
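+ # blocks marked 'last' skip the closing ReLU so UCIR can read pre-activation features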
+
+ def forward(self, x):
+ residual = x
+
+ basicblock = self.conv_a(x)
+ basicblock = self.bn_a(basicblock)
+ basicblock = F.relu(basicblock, inplace=True)
+
+ basicblock = self.conv_b(basicblock)
+ basicblock = self.bn_b(basicblock)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out = residual + basicblock
+ if not self.last:
+ out = F.relu(out, inplace=True)
+
+ return out
+
+
+class CifarResNet(nn.Module):
+ """
+ ResNet optimized for the CIFAR datasets, as specified in
+ https://arxiv.org/abs/1512.03385
+ """
+
+ def __init__(self, block, depth, channels=3):
+ super(CifarResNet, self).__init__()
+
+ # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
+ assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
+ layer_blocks = (depth - 2) // 6
+
+ self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn_1 = nn.BatchNorm2d(16)
+
+ self.inplanes = 16
+ self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
+ self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
+ self.stage_3 = self._make_layer(block, 64, layer_blocks, 2, last_phase=True)
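+ # last_phase=True: the final block keeps its pre-ReLU output for UCIR's cosine classifier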
+ self.avgpool = nn.AvgPool2d(8)
+ self.out_dim = 64 * block.expansion
+ # self.fc = CosineLinear(64*block.expansion, 10)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = DownsampleB(self.inplanes, planes * block.expansion, stride) # DownsampleA => DownsampleB
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ if last_phase:
+ for i in range(1, blocks-1):
+ layers.append(block(self.inplanes, planes))
+ layers.append(block(self.inplanes, planes, last=True))
+ else:
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv_1_3x3(x) # [bs, 16, 32, 32]
+ x = F.relu(self.bn_1(x), inplace=True)
+
+ x_1 = self.stage_1(x) # [bs, 16, 32, 32]
+ x_2 = self.stage_2(x_1) # [bs, 32, 16, 16]
+ x_3 = self.stage_3(x_2) # [bs, 64, 8, 8]
+
+ pooled = self.avgpool(x_3) # [bs, 64, 1, 1]
+ features = pooled.view(pooled.size(0), -1) # [bs, 64]
+ # out = self.fc(vector)
+
+ return {
+ 'fmaps': [x_1, x_2, x_3],
+ 'features': features
+ }
+
+ @property
+ def last_conv(self):
+ return self.stage_3[-1].conv_b
+
+
+def resnet20mnist():
+ """Constructs a ResNet-20 model for MNIST."""
+ model = CifarResNet(ResNetBasicblock, 20, 1)
+ return model
+
+
+def resnet32mnist():
+ """Constructs a ResNet-32 model for MNIST."""
+ model = CifarResNet(ResNetBasicblock, 32, 1)
+ return model
+
+
+def resnet20():
+ """Constructs a ResNet-20 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 20)
+ return model
+
+
+def resnet32():
+ """Constructs a ResNet-32 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 32)
+ return model
+
+
+def resnet44():
+ """Constructs a ResNet-44 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 44)
+ return model
+
+
+def resnet56():
+ """Constructs a ResNet-56 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 56)
+ return model
+
+
+def resnet110():
+ """Constructs a ResNet-110 model for CIFAR-10."""
+ model = CifarResNet(ResNetBasicblock, 110)
+ return model
diff --git a/convs/ucir_resnet.py b/convs/ucir_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..73b4dbb32e6ad4a5b7b8ff16c47efa9f7ea31740
--- /dev/null
+++ b/convs/ucir_resnet.py
@@ -0,0 +1,299 @@
+'''
+Reference:
+https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+'''
+import torch
+import torch.nn as nn
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+ from torch.hub import load_state_dict_from_url
+
+__all__ = ['resnet50']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, last=False):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.last = last
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ if not self.last:
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, last=False):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.last = last
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ if not self.last:
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None, args=None):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+
+ assert args is not None, "you should pass args to resnet"
+ if 'cifar' in args["dataset"]:
+ self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
+ elif 'imagenet' in args["dataset"] or 'stanfordcar' in args["dataset"] or 'general_dataset' in args['dataset']:
+ if args["init_cls"] == args["increment"]:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+ else:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ )
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2], last_phase=True)
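+ # last_phase=True: the closing block omits its final ReLU (see the 'last' flag in BasicBlock/Bottleneck)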
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.out_dim = 512 * block.expansion
+ self.fc = nn.Linear(512 * block.expansion, num_classes) # defined but not applied in _forward_impl
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_phase=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ if last_phase:
+ for _ in range(1, blocks-1):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer, last=True))
+ else:
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x):
+ # See note [TorchScript super()]
+ x = self.conv1(x) # [bs, 64, 32, 32]
+
+ x_1 = self.layer1(x) # [bs, 64, 32, 32] (BasicBlock widths, 32x32 input; x4 channels for Bottleneck)
+ x_2 = self.layer2(x_1) # [bs, 128, 16, 16]
+ x_3 = self.layer3(x_2) # [bs, 256, 8, 8]
+ x_4 = self.layer4(x_3) # [bs, 512, 4, 4]
+
+ pooled = self.avgpool(x_4) # [bs, 512, 1, 1]
+ features = torch.flatten(pooled, 1) # [bs, 512]
+ # x = self.fc(x)
+
+ return {
+ 'fmaps': [x_1, x_2, x_3, x_4],
+ 'features': features
+ }
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ @property
+ def last_conv(self):
+ if hasattr(self.layer4[-1], 'conv3'):
+ return self.layer4[-1].conv3
+ else:
+ return self.layer4[-1].conv2
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
diff --git a/download_dataset.sh b/download_dataset.sh
new file mode 100644
index 0000000000000000000000000000000000000000..be8d19530c2f30e2a49872edc496e488c5c978a1
--- /dev/null
+++ b/download_dataset.sh
@@ -0,0 +1,8 @@
+#!/bin/sh
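+# Assumes Kaggle API credentials are configured (~/.kaggle/kaggle.json or KAGGLE_USERNAME/KAGGLE_KEY)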
+kaggle datasets download -d senemanu/stanfordcarsfcs
+
+unzip -qq stanfordcarsfcs.zip
+
+rm -rf ./car_data/car_data/train/models
+
+mv ./car_data/car_data/test ./car_data/car_data/val
diff --git a/download_file_from_s3.py b/download_file_from_s3.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c1ce923aa1d1d18f2ef4706ac7f92ed751c34e
--- /dev/null
+++ b/download_file_from_s3.py
@@ -0,0 +1,49 @@
+import os
+import boto3
+from botocore.exceptions import NoCredentialsError
+
+
+def download_from_s3(bucket_name, s3_key, local_path, is_directory=False):
+ """
+ Download a file or directory from S3 to a local path.
+
+ :param bucket_name: str. The name of the S3 bucket.
+ :param s3_key: str. The S3 key (path to the file or directory).
+ :param local_path: str. The local file path or directory to download to.
+ :param is_directory: bool. Set to True if s3_key is a directory.
+ """
+ s3 = boto3.client("s3")
+
+ if is_directory:
+ # Ensure the local directory exists
+ if not os.path.exists(local_path):
+ os.makedirs(local_path)
+
+ # List all objects in the specified S3 directory
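+ # NOTE: list_objects_v2 returns at most 1000 keys per call; larger prefixes would need
+ # s3.get_paginator('list_objects_v2') to walk all pages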
+ result = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_key)
+ print(result)
+
+ if "Contents" in result:
+ for obj in result["Contents"]:
+ s3_object_key = obj["Key"]
+ # Remove the directory prefix to get the relative file path
+ relative_path = os.path.relpath(s3_object_key, s3_key)
+ local_file_path = os.path.join(local_path, relative_path)
+
+ # Ensure the local directory for the file exists
+ local_file_dir = os.path.dirname(local_file_path)
+ if not os.path.exists(local_file_dir):
+ os.makedirs(local_file_dir)
+
+ # Download the file
+ s3.download_file(bucket_name, s3_object_key, local_file_path)
+ print(f"Downloaded {s3_object_key} to {local_file_path}")
+ else:
+ # Download a single file
+ print(f"Downloaded {s3_key} to {local_path}")
+ s3.download_file(bucket_name, s3_key, local_path)
+
+
+# Example usage:
+# download_from_s3('my-bucket', 'path/to/myfile.txt', 'local/path/to/myfile.txt')
+# download_from_s3('my-bucket', 'path/to/mydirectory/', 'local/path/to/mydirectory', is_directory=True)
diff --git a/download_s3_path.py b/download_s3_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..e103bf7a1cb6a92172027e5135e79eb0b08b819a
--- /dev/null
+++ b/download_s3_path.py
@@ -0,0 +1,58 @@
+import os
+import boto3
+from botocore.exceptions import NoCredentialsError, PartialCredentialsError
+
+def download_s3_folder(bucket_name, s3_folder, local_dir):
+ # Convert local_dir to an absolute path
+ local_dir = os.path.abspath(local_dir)
+
+ # Ensure local directory exists
+ if not os.path.exists(local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+
+ s3 = boto3.client('s3')
+
+ try:
+ # List objects within the specified folder
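+ # NOTE: a single list_objects_v2 call is capped at 1000 objects; bigger folders would need pagination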
+ objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
+ if 'Contents' not in objects:
+ print(f"The folder '{s3_folder}' does not contain any files.")
+ return
+
+ for obj in objects['Contents']:
+ # Formulate the local file path
+ s3_file_path = obj['Key']
+ if s3_file_path.endswith('/'):
+ # Skip directories
+ continue
+
+ local_file_path = os.path.join(local_dir, os.path.relpath(s3_file_path, s3_folder))
+
+ # Create local directories if they do not exist
+ os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+
+ # Download the file
+ s3.download_file(bucket_name, s3_file_path, local_file_path)
+ print(f'Downloaded {s3_file_path} to {local_file_path}')
+
+ except KeyError:
+ print(f"The folder '{s3_folder}' does not contain any files.")
+ except NoCredentialsError:
+ print("Credentials not available.")
+ except PartialCredentialsError:
+ print("Incomplete credentials provided.")
+ except PermissionError as e:
+ print(f"Permission error: {e}. Please check your directory permissions.")
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Download an S3 folder to a local directory.')
+ parser.add_argument('-bucket', type=str, required=True, help='The S3 bucket name.')
+ parser.add_argument('-s3_folder', type=str, required=True, help='The folder path within the S3 bucket.')
+ parser.add_argument('-local_dir', type=str, required=True, help='The local directory to download the files to.')
+ args = parser.parse_args()
+
+ download_s3_folder(args.bucket, args.s3_folder, args.local_dir)
diff --git a/entrypoint.sh b/entrypoint.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4fecceffa04b4d5f7024dd628499eac7fdc56875
--- /dev/null
+++ b/entrypoint.sh
@@ -0,0 +1,8 @@
+#!/bin/sh
+set -e
+
+chmod +x train.sh install_awscli.sh
+
+mkdir -p upload # -p keeps re-runs from aborting under set -e
+
+python server.py
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..a162c70e147cc4ae3d79dc579d69e46de14deb82
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,133 @@
+import sys
+import logging
+import copy
+import torch
+from PIL import Image
+import torchvision.transforms as transforms
+from utils import factory
+from utils.data_manager import DataManager
+from torch.utils.data import DataLoader
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+import json
+import argparse
+import torch.multiprocessing
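+# the 'file_system' strategy avoids "Too many open files" errors when DataLoader workers share many tensors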
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+def _set_device(args):
+ device_type = args["device"]
+ gpus = []
+
+ for device in device_type:
+ if device == -1:
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:{}".format(device))
+
+ gpus.append(device)
+
+ args["device"] = gpus
+
+
+
+def get_methods(obj, spacing=20):
+ """Debug helper: print each attribute of obj alongside the start of its docstring."""
+ methodList = []
+ for method_name in dir(obj):
+ try:
+ if callable(getattr(obj, method_name)):
+ methodList.append(str(method_name))
+ except Exception:
+ methodList.append(str(method_name))
+ processFunc = lambda s: ' '.join(s.split()) # collapse whitespace runs
+ for method in methodList:
+ try:
+ print(str(method.ljust(spacing)) + ' ' +
+ processFunc(str(getattr(obj, method).__doc__)[0:90]))
+ except Exception:
+ print(method.ljust(spacing) + ' ' + ' getattr() failed')
+
+def load_model(args):
+ _set_device(args)
+ model = factory.get_model(args["model_name"], args)
+ model.load_checkpoint(args["checkpoint"])
+ return model
+
+def evaluate(args):
+ logs_name = "logs/{}/{}_{}/{}/{}".format(args["model_name"],args["dataset"], args['data'], args['init_cls'], args['increment'])
+
+ if not os.path.exists(logs_name):
+ os.makedirs(logs_name)
+ logfilename = "logs/{}/{}_{}/{}/{}/{}_{}_{}".format(
+ args["model_name"],
+ args["dataset"],
+ args['data'],
+ args['init_cls'],
+ args["increment"],
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+ args['logfilename'] = logs_name
+ args['csv_name'] = "{}_{}_{}".format(
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(filename)s] => %(message)s",
+ handlers=[
+ logging.FileHandler(filename=logfilename + ".log"),
+ logging.StreamHandler(sys.stdout),
+ ],
+ )
+ _set_random()
+ print_args(args)
+ model = load_model(args)
+
+ data_manager = DataManager(
+ args["dataset"],
+ False,
+ args["seed"],
+ args["init_cls"],
+ args["increment"],
+ path = args["data"]
+ )
+ loader = DataLoader(data_manager.get_dataset(model.class_list, source = "test", mode = "test"), batch_size=args['batch_size'], shuffle=True, num_workers=8)
+
+ cnn_acc, nme_acc = model.eval_task(loader, group=1, mode="test")
+ print(cnn_acc, nme_acc)
+
+
+def main():
+ args = setup_parser().parse_args()
+ param = load_json(args.config)
+ args = vars(args) # Converting argparse Namespace to a dict.
+ args.update(param) # Add parameters from json
+ evaluate(args)
+
+def load_json(settings_path):
+ with open(settings_path) as data_file:
+ param = json.load(data_file)
+
+ return param
+
+def _set_random():
+ torch.manual_seed(1)
+ torch.cuda.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+def setup_parser():
+ parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
+ parser.add_argument('--config', type=str, default='./exps/finetune.json',
+ help='Json file of settings.')
+ parser.add_argument('-d','--data', type=str, help='Path of the data folder')
+ parser.add_argument('-c','--checkpoint', type=str, help='Path of checkpoint file if resume training')
+ return parser
+
+def print_args(args):
+ for key, value in args.items():
+ logging.info("{}: {}".format(key, value))
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/exps/beef.json b/exps/beef.json
new file mode 100644
index 0000000000000000000000000000000000000000..c28195c56e11dc54cbb6a4832af30faeb738c5cf
--- /dev/null
+++ b/exps/beef.json
@@ -0,0 +1,28 @@
+{
+ "prefix": "fusion-energy-0.01-1.7-fixed",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "beefiso",
+ "convnet_type": "resnet18",
+ "device": ["0", "1"],
+ "seed": [2003],
+ "logits_alignment": 1.7,
+ "energy_weight": 0.01,
+ "is_compress":false,
+ "reduce_batch_size": false,
+ "init_epochs": 1,
+ "init_lr" : 0.1,
+ "init_weight_decay" : 5e-4,
+ "expansion_epochs" : 1,
+ "fusion_epochs" : 1,
+ "lr" : 0.1,
+ "batch_size" : 32,
+ "weight_decay" : 5e-4,
+ "num_workers" : 8,
+ "T" : 2
+}
\ No newline at end of file
diff --git a/exps/bic.json b/exps/bic.json
new file mode 100644
index 0000000000000000000000000000000000000000..6510ef908d06c6d55932ebb68e04da33296ff3d5
--- /dev/null
+++ b/exps/bic.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "cifar100",
+ "memory_size": 2000,
+ "memory_per_class": 20,
+ "fixed_memory": false,
+ "shuffle": true,
+ "init_cls": 10,
+ "increment": 10,
+ "model_name": "bic",
+ "convnet_type": "resnet32",
+ "device": ["0","1","2","3"],
+ "seed": [1993]
+}
diff --git a/exps/coil.json b/exps/coil.json
new file mode 100644
index 0000000000000000000000000000000000000000..98a3d048bda89a74dcd09ddacbe6622f46c7b4fe
--- /dev/null
+++ b/exps/coil.json
@@ -0,0 +1,18 @@
+{
+ "prefix": "reproduce",
+ "dataset": "stanfordcar",
+ "memory_size": 2000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "sinkhorn":0.464,
+ "calibration_term":1.5,
+ "norm_term":3.0,
+ "reg_term":1e-3,
+ "model_name": "coil",
+ "convnet_type": "cosine_resnet18",
+ "device": ["0","1"],
+ "seed": [2003]
+}
diff --git a/exps/der.json b/exps/der.json
new file mode 100644
index 0000000000000000000000000000000000000000..a3ade0b06b91e7970c5e94200f1cbf8ad5685f7b
--- /dev/null
+++ b/exps/der.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "der",
+ "convnet_type": "resnet18",
+ "device": ["0","1"],
+ "seed": [1993]
+}
\ No newline at end of file
diff --git a/exps/ewc.json b/exps/ewc.json
new file mode 100644
index 0000000000000000000000000000000000000000..76d865417f23b15ac62b191c49ca5b6edd0a765f
--- /dev/null
+++ b/exps/ewc.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "cifar100",
+ "memory_size": 2000,
+ "memory_per_class": 20,
+ "fixed_memory": false,
+ "shuffle": true,
+ "init_cls": 10,
+ "increment": 10,
+ "model_name": "ewc",
+ "convnet_type": "resnet32",
+ "device": ["0","1","2","3"],
+ "seed": [1993]
+}
\ No newline at end of file
diff --git a/exps/fetril.json b/exps/fetril.json
new file mode 100644
index 0000000000000000000000000000000000000000..b922528da2dd0afb85364acb036fcd9acc697b5d
--- /dev/null
+++ b/exps/fetril.json
@@ -0,0 +1,21 @@
+{
+ "prefix": "train",
+ "dataset": "stanfordcar",
+ "memory_size": 0,
+ "shuffle": true,
+ "init_cls": 40,
+ "increment": 1,
+ "model_name": "fetril",
+ "convnet_type": "resnet18",
+ "device": ["0"],
+ "seed": [2003],
+ "init_epochs": 100,
+ "init_lr" : 0.1,
+ "init_weight_decay" : 5e-4,
+ "epochs" : 80,
+ "lr" : 0.05,
+ "batch_size" : 32,
+ "weight_decay" : 5e-4,
+ "num_workers" : 8,
+ "T" : 2
+}
\ No newline at end of file
diff --git a/exps/finetune.json b/exps/finetune.json
new file mode 100644
index 0000000000000000000000000000000000000000..f0c5a5cdc225009dc5dbe818aa3952e9cd519dad
--- /dev/null
+++ b/exps/finetune.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "finetune",
+ "convnet_type": "resnet18",
+ "device": ["0"],
+ "seed": [2003]
+}
\ No newline at end of file
diff --git a/exps/foster.json b/exps/foster.json
new file mode 100644
index 0000000000000000000000000000000000000000..dfbcb38ce4cca206ee0fabe68068dd461d91ed83
--- /dev/null
+++ b/exps/foster.json
@@ -0,0 +1,31 @@
+{
+ "prefix": "cil",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "foster",
+ "convnet_type": "resnet18",
+ "device": ["0"],
+ "seed": [2003],
+ "beta1":0.96,
+ "beta2":0.97,
+ "oofc":"ft",
+ "is_teacher_wa":false,
+ "is_student_wa":false,
+ "lambda_okd":1,
+ "wa_value":1,
+ "init_epochs": 100,
+ "init_lr" : 0.1,
+ "init_weight_decay" : 5e-4,
+ "boosting_epochs" : 80,
+ "compression_epochs" : 50,
+ "lr" : 0.1,
+ "batch_size" : 32,
+ "weight_decay" : 5e-4,
+ "num_workers" : 8,
+ "T" : 2
+}
\ No newline at end of file
diff --git a/exps/foster_general.json b/exps/foster_general.json
new file mode 100644
index 0000000000000000000000000000000000000000..47eebb6f73bc1c62ccef639c7a43bf4da0828595
--- /dev/null
+++ b/exps/foster_general.json
@@ -0,0 +1,31 @@
+{
+ "prefix": "cil",
+ "dataset": "general_dataset",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "foster",
+ "convnet_type": "resnet18",
+ "device": ["0"],
+ "seed": [2003],
+ "beta1":0.96,
+ "beta2":0.97,
+ "oofc":"ft",
+ "is_teacher_wa":false,
+ "is_student_wa":false,
+ "lambda_okd":1,
+ "wa_value":1,
+ "init_epochs": 100,
+ "init_lr" : 0.1,
+ "init_weight_decay" : 5e-4,
+ "boosting_epochs" : 80,
+ "compression_epochs" : 50,
+ "lr" : 0.1,
+ "batch_size" : 32,
+ "weight_decay" : 5e-4,
+ "num_workers" : 8,
+ "T" : 2
+}
\ No newline at end of file
diff --git a/exps/gem.json b/exps/gem.json
new file mode 100644
index 0000000000000000000000000000000000000000..ddec0a3f41761ff94b17ed6b27711cc8ab7bce07
--- /dev/null
+++ b/exps/gem.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "gem",
+ "convnet_type": "resnet18",
+ "device": [ "0", "1"],
+ "seed": [2003]
+}
\ No newline at end of file
diff --git a/exps/icarl.json b/exps/icarl.json
new file mode 100644
index 0000000000000000000000000000000000000000..2129645841f9dc322646b4b3583475baf02b4853
--- /dev/null
+++ b/exps/icarl.json
@@ -0,0 +1,15 @@
+{
+ "prefix": "reproduce",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "icarl",
+ "convnet_type": "resnet18",
+ "device": ["0"],
+ "seed": [2003]
+}
+
diff --git a/exps/il2a.json b/exps/il2a.json
new file mode 100644
index 0000000000000000000000000000000000000000..c644999c7a42b31abf2215ee614398fa89ad3f2a
--- /dev/null
+++ b/exps/il2a.json
@@ -0,0 +1,24 @@
+{
+ "prefix": "cil",
+ "dataset": "stanfordcar",
+ "memory_size": 0,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "il2a",
+ "convnet_type": "resnet18_cbam",
+ "device": ["0", "1"],
+ "seed": [2003],
+ "lambda_fkd":10,
+ "lambda_proto":10,
+ "temp":0.1,
+ "epochs" : 1,
+ "lr" : 0.001,
+ "batch_size" : 32,
+ "weight_decay" : 2e-4,
+ "step_size":45,
+ "gamma":0.1,
+ "num_workers" : 8,
+ "ratio": 2.5,
+ "T" : 2
+}
\ No newline at end of file
diff --git a/exps/lwf.json b/exps/lwf.json
new file mode 100644
index 0000000000000000000000000000000000000000..0d55f4e0c0065b5e5345ccc98588211b65628b32
--- /dev/null
+++ b/exps/lwf.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 10,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "lwf",
+ "convnet_type": "resnet18",
+ "device":["0", "1"],
+ "seed": [2003]
+}
diff --git a/exps/memo.json b/exps/memo.json
new file mode 100644
index 0000000000000000000000000000000000000000..c8eef1668374c0eb8470151beb5b979efc193d9a
--- /dev/null
+++ b/exps/memo.json
@@ -0,0 +1,33 @@
+{
+ "prefix": "benchmark",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class":20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "memo",
+ "convnet_type": "memo_resnet18",
+ "train_base": true,
+ "train_adaptive": true,
+ "debug": false,
+ "skip": false,
+ "device": ["0", "1"],
+ "seed":[2003],
+ "scheduler": "steplr",
+ "init_epoch": 100,
+ "t_max": null,
+ "init_lr" : 0.1,
+ "init_weight_decay" : 5e-4,
+ "init_lr_decay" : 0.1,
+ "init_milestones" : [40,60,80],
+ "milestones" : [30,50,70],
+ "epochs": 80,
+ "lrate" : 0.1,
+ "batch_size" : 32,
+ "weight_decay" : 2e-4,
+ "lrate_decay" : 0.1,
+ "alpha_aux" : 1.0,
+ "backbone" : "models/finetune/reproduce_2003_resnet18_9.pkl"
+}
\ No newline at end of file
diff --git a/exps/pass.json b/exps/pass.json
new file mode 100644
index 0000000000000000000000000000000000000000..b82509b11635b23bf6aa86f5db0ee7c0c8ccec29
--- /dev/null
+++ b/exps/pass.json
@@ -0,0 +1,23 @@
+{
+ "prefix": "train",
+ "dataset": "stanfordcar",
+ "memory_size": 0,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "pass",
+ "convnet_type": "resnet18_cbam",
+ "device": ["0"],
+ "seed": [2003],
+ "lambda_fkd":10,
+ "lambda_proto":10,
+ "temp":0.1,
+ "epochs" : 100,
+ "lr" : 0.001,
+ "batch_size" : 16,
+ "weight_decay" : 2e-4,
+ "step_size":45,
+ "gamma":0.1,
+ "num_workers" : 8,
+ "T" : 2
+}
\ No newline at end of file
diff --git a/exps/podnet.json b/exps/podnet.json
new file mode 100644
index 0000000000000000000000000000000000000000..d33f8c3de36e474d1918b9bf5b0f7f6aefefe4ae
--- /dev/null
+++ b/exps/podnet.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "increment",
+ "dataset": "stanfordcar",
+ "memory_size": 2000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "podnet",
+ "convnet_type": "cosine_resnet18",
+ "device": ["0","1"],
+ "seed": [2003]
+}
diff --git a/exps/replay.json b/exps/replay.json
new file mode 100644
index 0000000000000000000000000000000000000000..11c8ce967040d3d4072bb821028af16f0ee973ac
--- /dev/null
+++ b/exps/replay.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "stanfordcar",
+ "memory_size": 4000,
+ "memory_per_class": 20,
+ "fixed_memory": true,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "replay",
+ "convnet_type": "resnet18",
+ "device": ["0"],
+ "seed": [2003]
+}
\ No newline at end of file
diff --git a/exps/rmm-foster.json b/exps/rmm-foster.json
new file mode 100644
index 0000000000000000000000000000000000000000..671d4d48ed32517ee743b578e462d53215c5e5bd
--- /dev/null
+++ b/exps/rmm-foster.json
@@ -0,0 +1,31 @@
+{
+ "prefix": "rmm-foster",
+ "dataset": "stanfordcar",
+ "memory_size": 2000,
+ "m_rate_list":[0.3, 0.3, 0.3, 0.4, 0.4, 0.4],
+ "c_rate_list":[0.0, 0.0, 0.1, 0.1, 0.1, 0.0],
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "rmm-foster",
+ "convnet_type": "resnet18",
+ "device": ["0", "1"],
+ "seed": [2003],
+ "beta1":0.97,
+ "beta2":0.97,
+ "oofc":"ft",
+ "is_teacher_wa":false,
+ "is_student_wa":false,
+ "lambda_okd":1,
+ "wa_value":1,
+ "init_epochs": 1,
+ "init_lr" : 0.1,
+ "init_weight_decay" : 5e-4,
+ "boosting_epochs" : 1,
+ "compression_epochs" : 1,
+ "lr" : 0.1,
+ "batch_size" : 32,
+ "weight_decay" : 5e-4,
+ "num_workers" : 8,
+ "T" : 2
+}
diff --git a/exps/rmm-icarl.json b/exps/rmm-icarl.json
new file mode 100644
index 0000000000000000000000000000000000000000..d117fcbfc4d9d8d9c49e3a448947c5e94a6673a5
--- /dev/null
+++ b/exps/rmm-icarl.json
@@ -0,0 +1,15 @@
+{
+ "prefix": "reproduce",
+ "dataset": "cifar100",
+ "m_rate_list":[0.8, 0.8, 0.6, 0.6, 0.6, 0.6],
+ "c_rate_list":[0.0, 0.0, 0.1, 0.1, 0.1, 0.0],
+ "memory_size": 2000,
+ "shuffle": true,
+ "init_cls": 50,
+ "increment": 10,
+ "model_name": "rmm-icarl",
+ "convnet_type": "resnet32",
+ "device": ["0"],
+ "seed": [1993]
+}
+
diff --git a/exps/rmm-pretrain.json b/exps/rmm-pretrain.json
new file mode 100644
index 0000000000000000000000000000000000000000..14b92ceb178ec413b1810dc60627100c7693ad6a
--- /dev/null
+++ b/exps/rmm-pretrain.json
@@ -0,0 +1,10 @@
+{
+ "prefix": "pretrain-rmm",
+ "dataset": "cifar100",
+ "memory_size": 2000,
+ "shuffle": true,
+ "model_name": "rmm-icarl",
+ "convnet_type": "resnet32",
+ "device": ["0"],
+ "seed": [1993]
+}
diff --git a/exps/simplecil.json b/exps/simplecil.json
new file mode 100644
index 0000000000000000000000000000000000000000..b037f432151deaf9e3eff0074d3546c3be80bb17
--- /dev/null
+++ b/exps/simplecil.json
@@ -0,0 +1,23 @@
+{
+ "prefix": "simplecil",
+ "dataset": "stanfordcar",
+ "memory_size": 0,
+ "memory_per_class": 0,
+ "fixed_memory": false,
+ "shuffle": true,
+ "init_cls": 50,
+ "increment": 20,
+ "model_name": "simplecil",
+ "convnet_type": "cosine_resnet18",
+ "device": ["0"],
+ "seed": [2003],
+ "checkpoint": "./models/simplecil/stanfordcar/0/20/simplecil_0.pkl",
+ "init_epoch": 1,
+ "init_lr": 0.01,
+ "batch_size": 32,
+ "weight_decay": 0.05,
+ "init_lr_decay": 0.1,
+ "init_weight_decay": 5e-4,
+ "min_lr": 0
+}
+
diff --git a/exps/simplecil_general.json b/exps/simplecil_general.json
new file mode 100644
index 0000000000000000000000000000000000000000..96acb01c8cd916e103623457deb112ef197ab4bc
--- /dev/null
+++ b/exps/simplecil_general.json
@@ -0,0 +1,22 @@
+{
+ "prefix": "simplecil",
+ "dataset": "general_dataset",
+ "memory_size": 0,
+ "memory_per_class": 0,
+ "fixed_memory": false,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "simplecil",
+ "convnet_type": "cosine_resnet18",
+ "device": [-1],
+ "seed": [2003],
+ "init_epoch": 1,
+ "init_lr": 0.01,
+ "batch_size": 32,
+ "weight_decay": 0.05,
+ "init_lr_decay": 0.1,
+ "init_weight_decay": 5e-4,
+ "min_lr": 0
+}
+
diff --git a/exps/simplecil_resume.json b/exps/simplecil_resume.json
new file mode 100644
index 0000000000000000000000000000000000000000..cd18739955d5b9366d860b10328e489417e4c6dd
--- /dev/null
+++ b/exps/simplecil_resume.json
@@ -0,0 +1,24 @@
+{
+ "prefix": "simplecil",
+ "dataset": "general_dataset",
+ "memory_size": 0,
+ "memory_per_class": 0,
+ "fixed_memory": false,
+ "shuffle": true,
+ "init_cls": 50,
+ "increment": 20,
+ "model_name": "simplecil",
+ "convnet_type": "cosine_resnet18",
+ "device": ["0"],
+ "seed": [2003],
+ "checkpoint": "./models/simplecil/stanfordcar/50/20/simplecil_0.pkl",
+ "data": "./car_data/car_data",
+ "init_epoch": 1,
+ "init_lr": 0.01,
+ "batch_size": 32,
+ "weight_decay": 0.05,
+ "init_lr_decay": 0.1,
+ "init_weight_decay": 5e-4,
+ "min_lr": 0
+}
+
diff --git a/exps/ssre.json b/exps/ssre.json
new file mode 100644
index 0000000000000000000000000000000000000000..d9f6935cf7d8f4dcc23effbaf20b7bfee15772da
--- /dev/null
+++ b/exps/ssre.json
@@ -0,0 +1,25 @@
+{
+ "prefix": "ssre",
+ "dataset": "stanfordcar",
+ "memory_size": 0,
+ "shuffle": true,
+ "init_cls": 20,
+ "increment": 20,
+ "model_name": "ssre",
+ "convnet_type": "resnet18_rep",
+ "device": ["0"],
+ "seed": [2003],
+ "lambda_fkd":1,
+ "lambda_proto":10,
+ "temp":0.1,
+ "mode": "parallel_adapters",
+ "epochs" : 1,
+ "lr" : 0.0001,
+ "batch_size" : 32,
+ "weight_decay" : 5e-4,
+ "step_size":45,
+ "gamma":0.1,
+ "threshold": 0.8,
+ "num_workers" : 8,
+ "T" : 2
+}
\ No newline at end of file
diff --git a/exps/wa.json b/exps/wa.json
new file mode 100644
index 0000000000000000000000000000000000000000..16e2ca86e9d7700a10a9738bd1dc7a1910fa0a3f
--- /dev/null
+++ b/exps/wa.json
@@ -0,0 +1,14 @@
+{
+ "prefix": "reproduce",
+ "dataset": "cifar100",
+ "memory_size": 2000,
+ "memory_per_class": 20,
+ "fixed_memory": false,
+ "shuffle": true,
+ "init_cls": 10,
+ "increment": 10,
+ "model_name": "wa",
+ "convnet_type": "resnet32",
+ "device": ["0","1","2","3"],
+ "seed": [1993]
+}
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..207c4ebf35e1c2670d8c548732c2ebc5a7c46eb2
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,115 @@
+import sys
+import logging
+import copy
+import torch
+from PIL import Image
+import torchvision.transforms as transforms
+from torchvision.transforms.functional import pil_to_tensor
+from utils import factory
+from utils.data_manager import DataManager
+from utils.toolkit import count_parameters
+from utils.data_manager import pil_loader
+import os
+import numpy as np
+import json
+import argparse
+import imghdr
+import time
+
+def is_image_imghdr(path):
+ """
+ Checks if a path points to a valid image using imghdr.
+
+ Args:
+ path: The path to the file.
+
+ Returns:
+ True if the path is a valid image, False otherwise.
+ """
+ if not os.path.isfile(path):
+ return False
+ return imghdr.what(path) in ['jpeg', 'png']
+
+def _set_device(args):
+ device_type = args["device"]
+ gpus = []
+
+ for device in device_type:
+ if device == -1:
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:{}".format(device))
+
+ gpus.append(device)
+
+ args["device"] = gpus
+
+def get_methods(object, spacing=20):
+ methodList = []
+ for method_name in dir(object):
+ try:
+ if callable(getattr(object, method_name)):
+ methodList.append(str(method_name))
+ except Exception:
+ methodList.append(str(method_name))
+    processFunc = lambda s: ' '.join(s.split())  # collapse runs of whitespace
+ for method in methodList:
+ try:
+ print(str(method.ljust(spacing)) + ' ' +
+ processFunc(str(getattr(object, method).__doc__)[0:90]))
+ except Exception:
+ print(method.ljust(spacing) + ' ' + ' getattr() failed')
+
+def load_model(args):
+ _set_device(args)
+ model = factory.get_model(args["model_name"], args)
+ model.load_checkpoint(args["checkpoint"])
+ return model
+
+def main():
+ args = setup_parser().parse_args()
+ param = load_json(args.config)
+ args = vars(args) # Converting argparse Namespace to a dict.
+ args.update(param) # Add parameters from json
+    assert args['output'].endswith(".json") or os.path.isdir(args['output']), "output must be a .json path or an existing directory"
+ model = load_model(args)
+ result = []
+ if is_image_imghdr(args['input']):
+ img = pil_to_tensor(pil_loader(args['input']))
+ img = img.unsqueeze(0)
+ predictions = model.inference(img)
+ out = {"img": args['input'].split("/")[-1]}
+ out.update({"predictions": [{"confident": confident, "index": pred, "label": label } for pred, label, confident in zip(predictions[0], predictions[1], predictions[2])]})
+ result.append(out)
+ else:
+ image_list = filter(lambda x: is_image_imghdr(os.path.join(args['input'], x)), os.listdir(args['input']))
+ for image in image_list:
+ print("Inference on image", image)
+ img = pil_to_tensor(pil_loader(os.path.join(args['input'], image)))
+ img = img.unsqueeze(0)
+ predictions = model.inference(img)
+ out = {"img": image.split("/")[-1]}
+ out.update({"predictions": [{"confident": confident, "index": pred, "label": label } for pred, label, confident in zip(predictions[0], predictions[1], predictions[2])]})
+ result.append(out)
+    if args['output'].endswith(".json"):
+ with open(args['output'], "w+") as f:
+ json.dump(result, f, indent=4)
+ else:
+ with open(os.path.join(args['output'], "output_model_{}.json".format(time.time())), "w+") as f:
+ json.dump(result, f, indent=4)
+
+def load_json(settings_path):
+ with open(settings_path) as data_file:
+ param = json.load(data_file)
+ return param
+
+
+def setup_parser():
+    parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
+    parser.add_argument('--config', type=str, help='JSON file of settings.')
+    parser.add_argument('--checkpoint', type=str, help='Path to the checkpoint file (a .pkl file written by save_checkpoint).')
+    parser.add_argument('--input', type=str, help='Path to the input: either a single image file or a folder of images.')
+    parser.add_argument('--output', type=str, help='Output path for saving predictions (a .json file or an existing directory).')
+ return parser
+
+if __name__ == '__main__':
+ main()
+
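A note on `imghdr`: it is deprecated since Python 3.11 and removed in 3.13. The pinned python:3.8.5 base image is unaffected, but a Pillow-based check is a drop-in alternative. A minimal sketch (hypothetical helper, not part of this diff):

import os
from PIL import Image

def is_image_pil(path):
    """Return True if path is a readable JPEG or PNG file."""
    if not os.path.isfile(path):
        return False
    try:
        with Image.open(path) as im:
            fmt = im.format          # set by Image.open before verify()
            im.verify()              # cheap integrity check, no full decode
        return fmt in ("JPEG", "PNG")
    except Exception:
        return False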
diff --git a/install_awscli.sh b/install_awscli.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dc0acc69c82403826c436e30b1caea9258e2322f
--- /dev/null
+++ b/install_awscli.sh
@@ -0,0 +1,7 @@
+#!/bin/sh
+
+curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
+
+unzip awscliv2.zip
+
+./aws/install
diff --git a/load.sh b/load.sh
new file mode 100644
index 0000000000000000000000000000000000000000..429b50227fb951e7be17231d8d687ecfc08935d2
--- /dev/null
+++ b/load.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+for arg in "$@"; do
+    python ./load_model.py --config="$arg"
+done
\ No newline at end of file
diff --git a/load_model.py b/load_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d2b4e02b20ec76c1d47a7dfce9fad32ef6ce2aa
--- /dev/null
+++ b/load_model.py
@@ -0,0 +1,73 @@
+import sys
+import logging
+import copy
+import torch
+from PIL import Image
+import torchvision.transforms as transforms
+from utils import factory
+from utils.data_manager import DataManager
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+import json
+import argparse
+
+def _set_device(args):
+ device_type = args["device"]
+ gpus = []
+
+ for device in device_type:
+ if device == -1:
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:{}".format(device))
+
+ gpus.append(device)
+
+ args["device"] = gpus
+
+def get_methods(object, spacing=20):
+ methodList = []
+ for method_name in dir(object):
+ try:
+ if callable(getattr(object, method_name)):
+ methodList.append(str(method_name))
+ except Exception:
+ methodList.append(str(method_name))
+    processFunc = lambda s: ' '.join(s.split())  # collapse runs of whitespace
+ for method in methodList:
+ try:
+ print(str(method.ljust(spacing)) + ' ' +
+ processFunc(str(getattr(object, method).__doc__)[0:90]))
+ except Exception:
+ print(method.ljust(spacing) + ' ' + ' getattr() failed')
+
+def load_model(args):
+ _set_device(args)
+ model = factory.get_model(args["model_name"], args)
+ model.load_checkpoint(args["checkpoint"])
+ return model
+
+def main():
+ args = setup_parser().parse_args()
+ param = load_json(args.config)
+ args = vars(args) # Converting argparse Namespace to a dict.
+ args.update(param) # Add parameters from json
+
+ load_model(args)
+
+def load_json(settings_path):
+ with open(settings_path) as data_file:
+ param = json.load(data_file)
+
+ return param
+
+
+def setup_parser():
+    parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
+    parser.add_argument('--config', type=str, default='./exps/finetune.json',
+                        help='JSON file of settings.')
+
+ return parser
+
+if __name__ == '__main__':
+ main()
+
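load_model.py parses only `--config`, so the JSON itself must supply every field the loader touches downstream. A minimal config sketch, written as a Python dict for brevity (values borrowed from exps/simplecil.json above; `device: [-1]` selects CPU per `_set_device`):

minimal_config = {
    "model_name": "simplecil",         # resolved by factory.get_model
    "convnet_type": "cosine_resnet18",
    "device": ["0"],                   # [-1] runs on CPU
    "memory_size": 0,                  # required by BaseLearner.__init__
    "checkpoint": "./models/simplecil/stanfordcar/0/20/simplecil_0.pkl",
}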
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..358266a524f04de2ab57e27ebf67f925d99a43b5
--- /dev/null
+++ b/main.py
@@ -0,0 +1,38 @@
+import json
+import argparse
+from trainer import train
+from train_more import train_more
+
+def main():
+ args = setup_parser().parse_args()
+ param = load_json(args.config)
+ args = vars(args) # Converting argparse Namespace to a dict.
+ args.update(param) # Add parameters from json
+    if args['dataset'] != "general_dataset":
+        train(args)
+    else:
+        assert args['data'] is not None, "general_dataset requires --data"
+ if not args['checkpoint']:
+ args.pop('checkpoint')
+ train(args)
+ else:
+ train_more(args)
+
+def load_json(settings_path):
+ with open(settings_path) as data_file:
+ param = json.load(data_file)
+
+ return param
+
+
+def setup_parser():
+    parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
+    parser.add_argument('--config', type=str, default='./exps/finetune.json',
+                        help='JSON file of settings.')
+    parser.add_argument('-d', '--data', nargs='?', type=str, help='Path of the data folder.')
+    parser.add_argument('-c', '--checkpoint', nargs='?', type=str, help='Path of the checkpoint file when resuming training.')
+ return parser
+
+
+if __name__ == "__main__":
+ main()
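One subtlety in main(): `args.update(param)` is called with the JSON parameters last, so on a key collision the config file wins over the CLI flags. A quick sketch of that merge:

import json

cli = {"config": "./exps/icarl.json", "data": None, "checkpoint": None}
with open(cli["config"]) as f:
    cli.update(json.load(f))   # JSON keys overwrite CLI values on collision

print(cli["model_name"])       # "icarl", supplied by the config file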
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/base.py b/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..349a06d335b1941ea4d7f9d6e9b9e88555c1d8e9
--- /dev/null
+++ b/models/base.py
@@ -0,0 +1,421 @@
+import copy
+import logging
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from utils.toolkit import tensor2numpy, accuracy
+from scipy.spatial.distance import cdist
+import os
+
+EPSILON = 1e-8
+batch_size = 64
+
+
+class BaseLearner(object):
+ def __init__(self, args):
+ self.args = args
+ self._cur_task = -1
+ self._known_classes = 0
+ self._total_classes = 0
+ self.class_list = []
+ self._network = None
+ self._old_network = None
+ self._data_memory, self._targets_memory = np.array([]), np.array([])
+ self.topk = 5
+
+ self._memory_size = args["memory_size"]
+ self._memory_per_class = args.get("memory_per_class", None)
+ self._fixed_memory = args.get("fixed_memory", False)
+ self._device = args["device"][0]
+ self._multiple_gpus = args["device"]
+
+ @property
+ def exemplar_size(self):
+ assert len(self._data_memory) == len(
+ self._targets_memory
+ ), "Exemplar size error."
+ return len(self._targets_memory)
+
+ @property
+ def samples_per_class(self):
+ if self._fixed_memory:
+ return self._memory_per_class
+ else:
+ assert self._total_classes != 0, "Total classes is 0"
+ return self._memory_size // self._total_classes
+
+ @property
+ def feature_dim(self):
+ if isinstance(self._network, nn.DataParallel):
+ return self._network.module.feature_dim
+ else:
+ return self._network.feature_dim
+
+    def build_rehearsal_memory(self, data_manager, per_class):
+        if self._fixed_memory:
+            self._construct_exemplar_unified(data_manager, per_class)
+        else:
+            self._reduce_exemplar(data_manager, per_class)
+            self._construct_exemplar(data_manager, per_class)
+
+    def load_checkpoint(self, filename):
+        pass
+
+ def save_checkpoint(self, filename):
+ self._network.cpu()
+ save_dict = {
+ "tasks": self._cur_task,
+ "model_state_dict": self._network.state_dict(),
+ }
+ torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task))
+
+ def after_task(self):
+ pass
+
+ def _evaluate(self, y_pred, y_true, group = 10):
+ ret = {}
+ grouped = accuracy(y_pred.T[0], y_true, self._known_classes, increment = group)
+ ret["grouped"] = grouped
+ ret["top1"] = grouped["total"]
+ ret["top{}".format(self.topk)] = np.around(
+ (y_pred.T == np.tile(y_true, (self.topk, 1))).sum() * 100 / len(y_true),
+ decimals=2,
+ )
+
+ return ret
+
+ def eval_task(self, data=None, save_conf=False, group = 10, mode = "train"):
+ if data is None:
+ data = self.test_loader
+ y_pred, y_true = self._eval_cnn(data, mode = mode)
+ cnn_accy = self._evaluate(y_pred, y_true, group = group)
+
+ if hasattr(self, "_class_means"):
+ y_pred, y_true = self._eval_nme(data, self._class_means)
+ nme_accy = self._evaluate(y_pred, y_true)
+ else:
+ nme_accy = None
+
+ if save_conf:
+ _pred = y_pred.T[0]
+ _pred_path = os.path.join(self.args['logfilename'], "pred.npy")
+ _target_path = os.path.join(self.args['logfilename'], "target.npy")
+ np.save(_pred_path, _pred)
+ np.save(_target_path, y_true)
+
+ _save_dir = os.path.join(f"./results/{self.args['model_name']}/conf_matrix/{self.args['prefix']}")
+ os.makedirs(_save_dir, exist_ok=True)
+ _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv")
+ with open(_save_path, "a+") as f:
+ f.write(f"{self.args['model_name']},{_pred_path},{_target_path} \n")
+
+ return cnn_accy, nme_accy
+
+ def incremental_train(self):
+ pass
+
+ def _train(self):
+ pass
+
+ def _get_memory(self):
+ if len(self._data_memory) == 0:
+ return None
+ else:
+ return (self._data_memory, self._targets_memory)
+
+ def _compute_accuracy(self, model, loader):
+ model.eval()
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(loader):
+ inputs = inputs.to(self._device)
+ with torch.no_grad():
+ outputs = model(inputs)["logits"]
+ predicts = torch.max(outputs, dim=1)[1]
+ correct += (predicts.cpu() == targets).sum()
+ total += len(targets)
+
+ return np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ def _eval_cnn(self, loader, mode = "train"):
+ self._network.eval()
+ y_pred, y_true = [], []
+ for _, (_, inputs, targets) in enumerate(loader):
+ inputs = inputs.to(self._device)
+ with torch.no_grad():
+ outputs = self._network(inputs)["logits"]
+ if self.topk > self._total_classes:
+ self.topk = self._total_classes
+ predicts = torch.topk(
+ outputs, k=self.topk, dim=1, largest=True, sorted=True
+ )[
+ 1
+ ] # [bs, topk]
+ refine_predicts = predicts.cpu().numpy()
+ if mode == "test":
+ refine_predicts = self.class_list[refine_predicts]
+ y_pred.append(refine_predicts)
+ y_true.append(targets.cpu().numpy())
+ return np.concatenate(y_pred), np.concatenate(y_true) # [N, topk]
+
+    def inference(self, image):
+ self._network.eval()
+ self._network.to(self._device)
+ image = image.to(self._device, dtype=torch.float32)
+ with torch.no_grad():
+ output = self._network(image)["logits"]
+ if self.topk > self._total_classes:
+ self.topk = self._total_classes
+ predict = torch.topk(
+ output, k=self.topk, dim=1, largest=True, sorted=True
+ )[1]
+ confidents = softmax(output.cpu().numpy())
+ if self.class_list is not None:
+ self.class_list = np.array(self.class_list)
+ predicts = predict.cpu().numpy()
+ result = self.class_list[predicts].tolist()
+ result.append([self.label_list[item] for item in result[0]])
+ result.append(confidents[0][predicts][0].tolist())
+ return result
+ elif self.data_manager is not None:
+ return self.data_manager.class_list[predict.cpu().numpy()]
+
+        # Fallback when neither class_list nor data_manager provides a mapping.
+        predicts = predict.cpu().numpy().tolist()
+        predicts.append([self.label_list[index] for index in predicts[0]])
+        return predicts
+
+ def _eval_nme(self, loader, class_means):
+ self._network.eval()
+ vectors, y_true = self._extract_vectors(loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+
+ dists = cdist(class_means, vectors, "sqeuclidean") # [nb_classes, N]
+ scores = dists.T # [N, nb_classes], choose the one with the smallest distance
+
+ return np.argsort(scores, axis=1)[:, : self.topk], y_true # [N, topk]
+
+ def _extract_vectors(self, loader):
+ self._network.eval()
+ vectors, targets = [], []
+ for _, _inputs, _targets in loader:
+ _targets = _targets.numpy()
+ if isinstance(self._network, nn.DataParallel):
+ _vectors = tensor2numpy(
+ self._network.module.extract_vector(_inputs.to(self._device))
+ )
+ else:
+ _vectors = tensor2numpy(
+ self._network.extract_vector(_inputs.to(self._device))
+ )
+
+ vectors.append(_vectors)
+ targets.append(_targets)
+ return np.concatenate(vectors), np.concatenate(targets)
+
+ def _reduce_exemplar(self, data_manager, m):
+ logging.info("Reducing exemplars...({} per classes)".format(m))
+ dummy_data, dummy_targets = copy.deepcopy(self._data_memory), copy.deepcopy(
+ self._targets_memory
+ )
+ self._class_means = np.zeros((self._total_classes, self.feature_dim))
+ self._data_memory, self._targets_memory = np.array([]), np.array([])
+
+ for class_idx in range(self._known_classes):
+ mask = np.where(dummy_targets == class_idx)[0]
+ dd, dt = dummy_data[mask][:m], dummy_targets[mask][:m]
+ self._data_memory = (
+ np.concatenate((self._data_memory, dd))
+ if len(self._data_memory) != 0
+ else dd
+ )
+ self._targets_memory = (
+ np.concatenate((self._targets_memory, dt))
+ if len(self._targets_memory) != 0
+ else dt
+ )
+
+ # Exemplar mean
+ idx_dataset = data_manager.get_dataset(
+ [], source="train", mode="test", appendent=(dd, dt)
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ self._class_means[class_idx, :] = mean
+
+ def _construct_exemplar(self, data_manager, m):
+ logging.info("Constructing exemplars...({} per classes)".format(m))
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = data_manager.get_dataset(
+ np.arange(class_idx, class_idx + 1),
+ source="train",
+ mode="test",
+ ret_data=True,
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+
+ # Select
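+            # Herding selection (iCaRL): at step k, greedily pick the sample
+            # whose addition keeps the running mean of chosen exemplars
+            # closest to the true class mean, then drop it from the pool.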
+ selected_exemplars = []
+ exemplar_vectors = [] # [n, feature_dim]
+ for k in range(1, m + 1):
+ S = np.sum(
+ exemplar_vectors, axis=0
+ ) # [feature_dim] sum of selected exemplars vectors
+ mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors
+ i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
+ selected_exemplars.append(
+ np.array(data[i])
+ ) # New object to avoid passing by inference
+ exemplar_vectors.append(
+ np.array(vectors[i])
+ ) # New object to avoid passing by inference
+
+ vectors = np.delete(
+ vectors, i, axis=0
+ ) # Remove it to avoid duplicative selection
+ data = np.delete(
+ data, i, axis=0
+ ) # Remove it to avoid duplicative selection
+
+ # uniques = np.unique(selected_exemplars, axis=0)
+ # print('Unique elements: {}'.format(len(uniques)))
+ selected_exemplars = np.array(selected_exemplars)
+ exemplar_targets = np.full(m, class_idx)
+ self._data_memory = (
+ np.concatenate((self._data_memory, selected_exemplars))
+ if len(self._data_memory) != 0
+ else selected_exemplars
+ )
+ self._targets_memory = (
+ np.concatenate((self._targets_memory, exemplar_targets))
+ if len(self._targets_memory) != 0
+ else exemplar_targets
+ )
+
+ # Exemplar mean
+ idx_dataset = data_manager.get_dataset(
+ [],
+ source="train",
+ mode="test",
+ appendent=(selected_exemplars, exemplar_targets),
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ self._class_means[class_idx, :] = mean
+
+ def _construct_exemplar_unified(self, data_manager, m):
+ logging.info(
+ "Constructing exemplars for new classes...({} per classes)".format(m)
+ )
+ _class_means = np.zeros((self._total_classes, self.feature_dim))
+
+ # Calculate the means of old classes with newly trained network
+ for class_idx in range(self._known_classes):
+ mask = np.where(self._targets_memory == class_idx)[0]
+ class_data, class_targets = (
+ self._data_memory[mask],
+ self._targets_memory[mask],
+ )
+
+ class_dset = data_manager.get_dataset(
+ [], source="train", mode="test", appendent=(class_data, class_targets)
+ )
+ class_loader = DataLoader(
+ class_dset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(class_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ _class_means[class_idx, :] = mean
+
+ # Construct exemplars for new classes and calculate the means
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, class_dset = data_manager.get_dataset(
+ np.arange(class_idx, class_idx + 1),
+ source="train",
+ mode="test",
+ ret_data=True,
+ )
+ class_loader = DataLoader(
+ class_dset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+
+ vectors, _ = self._extract_vectors(class_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+
+ # Select
+ selected_exemplars = []
+ exemplar_vectors = []
+ for k in range(1, m + 1):
+ S = np.sum(
+ exemplar_vectors, axis=0
+ ) # [feature_dim] sum of selected exemplars vectors
+ mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors
+ i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
+
+ selected_exemplars.append(
+ np.array(data[i])
+ ) # New object to avoid passing by inference
+ exemplar_vectors.append(
+ np.array(vectors[i])
+ ) # New object to avoid passing by inference
+
+ vectors = np.delete(
+ vectors, i, axis=0
+ ) # Remove it to avoid duplicative selection
+ data = np.delete(
+ data, i, axis=0
+ ) # Remove it to avoid duplicative selection
+
+ selected_exemplars = np.array(selected_exemplars)
+ exemplar_targets = np.full(m, class_idx)
+ self._data_memory = (
+ np.concatenate((self._data_memory, selected_exemplars))
+ if len(self._data_memory) != 0
+ else selected_exemplars
+ )
+ self._targets_memory = (
+ np.concatenate((self._targets_memory, exemplar_targets))
+ if len(self._targets_memory) != 0
+ else exemplar_targets
+ )
+
+ # Exemplar mean
+ exemplar_dset = data_manager.get_dataset(
+ [],
+ source="train",
+ mode="test",
+ appendent=(selected_exemplars, exemplar_targets),
+ )
+ exemplar_loader = DataLoader(
+ exemplar_dset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(exemplar_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ _class_means[class_idx, :] = mean
+
+ self._class_means = _class_means
+
+
+def softmax(x):
+    """Softmax over the last axis (inference passes a [1, num_classes] array)."""
+    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
+    return e_x / (e_x.sum(axis=-1, keepdims=True) + 1e-7)
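The exemplar loops above all implement the same greedy mean-matching (herding) criterion. A self-contained toy check of that criterion on random unit vectors, illustrative only:

import numpy as np

rng = np.random.default_rng(0)
vectors = rng.normal(size=(50, 8))
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
class_mean = vectors.mean(axis=0)

selected, pool = [], vectors.copy()
for k in range(1, 6):                      # choose 5 exemplars
    S = np.sum(selected, axis=0) if selected else 0.0
    mu_p = (pool + S) / k                  # candidate running means
    i = int(np.argmin(np.linalg.norm(class_mean - mu_p, axis=1)))
    selected.append(pool[i])
    pool = np.delete(pool, i, axis=0)      # avoid duplicate selection

print(np.linalg.norm(class_mean - np.mean(selected, axis=0)))  # small residual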
diff --git a/models/beef_iso.py b/models/beef_iso.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f72ef56e5525c1c6898441efd13626518856cea
--- /dev/null
+++ b/models/beef_iso.py
@@ -0,0 +1,684 @@
+import copy
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import BEEFISONet
+from utils.toolkit import count_parameters, target2onehot, tensor2numpy
+
+
+EPSILON = 1e-8
+
+
+class BEEFISO(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+ self._network = BEEFISONet(args, False)
+ self._snet = None
+ self.logits_alignment = args["logits_alignment"]
+ self.val_loader = None
+ self.reduce_batch_size = args["reduce_batch_size"]
+ self.random = args.get("random",None)
+ self.imbalance = args.get("imbalance",None)
+
+ def after_task(self):
+ self._network_module_ptr.update_fc_after()
+ self._known_classes = self._total_classes
+        if self.reduce_batch_size and self._cur_task > 0:
+            self.args["batch_size"] = self.args["batch_size"] * (self._cur_task + 1) // (self._cur_task + 2)
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self.data_manager = data_manager
+ self._cur_task += 1
+ if self._cur_task > 1 and self.args["is_compress"]:
+ self._network = self._snet
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc_before(self._total_classes)
+ self._network_module_ptr = self._network
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ if self._cur_task > 0:
+ for id in range(self._cur_task):
+ for p in self._network.convnets[id].parameters():
+ p.requires_grad = False
+ for p in self._network.old_fc.parameters():
+ p.requires_grad = False
+
+
+ logging.info("All params: {}".format(count_parameters(self._network)))
+ logging.info(
+ "Trainable params: {}".format(count_parameters(self._network, True))
+ )
+
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ self.train_loader = DataLoader(
+ train_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=True,
+ num_workers=self.args["num_workers"],
+ pin_memory=True,
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=False,
+ num_workers=self.args["num_workers"],
+ pin_memory=True,
+ )
+ if self._cur_task > 0:
+ if self.random or self.imbalance:
+ val_dset = data_manager.get_finetune_dataset(known_classes=self._known_classes, total_classes=self._total_classes,
+ source="train", mode='train', appendent=self._get_memory(), type="ratio")
+ else:
+ _, val_dset = data_manager.get_dataset_with_split(np.arange(self._known_classes, self._total_classes),
+ source='train', mode='train',
+ appendent=self._get_memory(),
+ val_samples_per_class=int(
+ self.samples_old_class))
+ self.val_loader = DataLoader(
+ val_dset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader,self.val_loader)
+ if self.random or self.imbalance:
+ self.build_rehearsal_memory_imbalance(data_manager,self.samples_per_class)
+ else:
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def train(self):
+ self._network_module_ptr.train()
+ self._network_module_ptr.convnets[-1].train()
+ if self._cur_task >= 1:
+ self._network_module_ptr.convnets[0].eval()
+
+ def _train(self, train_loader, test_loader, val_loader=None):
+ self._network.to(self._device)
+ if hasattr(self._network, "module"):
+ self._network_module_ptr = self._network.module
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ momentum=0.9,
+ lr=self.args["init_lr"],
+ weight_decay=self.args["init_weight_decay"],
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args["init_epochs"]
+ )
+ self.epochs = self.args["init_epochs"]
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ lr=self.args["lr"],
+ momentum=0.9,
+ weight_decay=self.args["weight_decay"],
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args["expansion_epochs"]
+ )
+
+ self.epochs = self.args["expansion_epochs"]
+ self.state = "expansion"
+ if len(self._multiple_gpus) > 1:
+ network = self._network.module
+ else:
+ network = self._network
+ for p in network.biases.parameters():
+ p.requires_grad = False
+ self._expansion(train_loader, test_loader, optimizer, scheduler)
+
+
+
+ for p in self._network_module_ptr.forward_prototypes.parameters():
+ p.requires_grad = False
+ for p in self._network_module_ptr.backward_prototypes.parameters():
+ p.requires_grad = False
+ for p in self._network_module_ptr.new_fc.parameters():
+ p.requires_grad = False
+ for p in self._network_module_ptr.convnets[-1].parameters():
+ p.requires_grad = False
+ for p in self._network.biases.parameters():
+ p.requires_grad = True
+ self.state = "fusion"
+ self.epochs = self.args["fusion_epochs"]
+ self.per_cls_weights = torch.ones(self._total_classes).to(self._device)
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ lr=0.05,
+ momentum=0.9,
+ weight_decay=self.args["weight_decay"],
+ )
+ for n, p in self._network.named_parameters():
+            if p.requires_grad:
+ print(n)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args["fusion_epochs"]
+ )
+ self._fusion(val_loader,test_loader,optimizer,scheduler)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self.epochs))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ losses = 0.0
+ losses_en = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True
+ ), targets.to(self._device, non_blocking=True)
+ logits = self._network(inputs)["logits"]
+ loss_en = self.args["energy_weight"] * self.get_energy_loss(inputs,targets,targets)
+ loss = F.cross_entropy(logits, targets)
+ loss = loss + loss_en
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_en += loss_en.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["init_epochs"],
+ losses / len(train_loader),
+ losses_en / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["init_epochs"],
+ losses / len(train_loader),
+ losses_en / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def _expansion(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self.epochs))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ losses = 0.0
+ losses_clf = 0.0
+ losses_fe = 0.0
+ losses_en = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True
+ ), targets.to(self._device, non_blocking=True)
+                targets = targets.long()
+ outputs = self._network(inputs)
+ logits,train_logits = (
+ outputs["logits"],
+ outputs["train_logits"]
+ )
+                pseudo_targets = targets.clone()
+                # Map raw labels onto the train_logits layout: the first
+                # self._cur_task columns are one prototype per old task, and
+                # the remaining columns are the current task's new classes.
+                for task_id in range(self._cur_task + 1):
+                    if task_id == self._cur_task:
+                        pseudo_targets = torch.where(targets >= self._known_classes, targets - self._known_classes + task_id, pseudo_targets)
+                    else:
+                        pseudo_targets = torch.where((targets < self.data_manager.get_accumulate_tasksize(task_id)) & (targets > self.data_manager.get_accumulate_tasksize(task_id - 1) - 1), task_id, pseudo_targets)
+
+ train_logits[:, list(range(self._cur_task))] /= self.logits_alignment
+ loss_clf = F.cross_entropy(train_logits.float(), pseudo_targets)
+                loss_fe = torch.tensor(0.0).to(self._device)
+ loss_en = self.args["energy_weight"] * self.get_energy_loss(inputs,targets,pseudo_targets)
+ loss = loss_clf + loss_fe + loss_en
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_fe += loss_fe.item()
+ losses_clf += loss_clf.item()
+ losses_en += loss_en.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.epochs,
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_fe / len(train_loader),
+ losses_en / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.epochs,
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_fe / len(train_loader),
+ losses_en / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def _fusion(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self.epochs))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ # self.
+ losses = 0.0
+ losses_clf = 0.0
+ losses_fe = 0.0
+ losses_kd = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True
+ ), targets.to(self._device, non_blocking=True)
+ outputs = self._network(inputs)
+ logits,train_logits = (
+ outputs["logits"],
+ outputs["train_logits"]
+ )
+
+ loss_clf = F.cross_entropy(logits,targets)
+                loss_fe = torch.tensor(0.0).to(self._device)
+                loss_kd = torch.tensor(0.0).to(self._device)
+ loss = loss_clf + loss_fe + loss_kd
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_fe += loss_fe.item()
+ losses_clf += loss_clf.item()
+ losses_kd += (
+ self._known_classes / self._total_classes
+ ) * loss_kd.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.epochs,
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_fe / len(train_loader),
+ losses_kd / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.epochs,
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_fe / len(train_loader),
+ losses_kd / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+
+
+ @property
+ def samples_old_class(self):
+ if self._fixed_memory:
+ return self._memory_per_class
+ else:
+ assert self._total_classes != 0, "Total classes is 0"
+ return self._memory_size // self._known_classes
+
+ def samples_new_class(self, index):
+ if self.args["dataset"] == "cifar100":
+ return 500
+ else:
+ return self.data_manager.getlen(index)
+
+ def BKD(self, pred, soft, T):
+ pred = torch.log_softmax(pred / T, dim=1)
+ soft = torch.softmax(soft / T, dim=1)
+ soft = soft * self.per_cls_weights
+ soft = soft / soft.sum(1)[:, None]
+ return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
+
+
+ def get_energy_loss(self,inputs,targets,pseudo_targets):
+ inputs = self.sample_q(inputs)
+
+ out = self._network(inputs)
+ if self._cur_task == 0:
+ targets = targets + self._total_classes
+ train_logits, energy_logits = out["logits"], out["energy_logits"]
+ else:
+ targets = targets + (self._total_classes - self._known_classes) + self._cur_task
+ train_logits, energy_logits = out["train_logits"], out["energy_logits"]
+
+ logits = torch.cat([train_logits,energy_logits],dim=1)
+
+ logits[:,pseudo_targets] = 1e-9
+ energy_loss = F.cross_entropy(logits,targets)
+ return energy_loss
+
+ def sample_q(self, replay_buffer, n_steps=3):
+ """this func takes in replay_buffer now so we have the option to sample from
+ scratch (i.e. replay_buffer==[]). See test_wrn_ebm.py for example.
+ """
+ self._network_copy = self._network_module_ptr.copy().freeze()
+ init_sample = replay_buffer
+ init_sample = torch.rot90(init_sample, 2, (2, 3))
+ embedding_k = init_sample.clone().detach().requires_grad_(True)
+ optimizer_gen = torch.optim.SGD(
+ [embedding_k], lr=1e-2)
+ for k in range(1, n_steps + 1):
+ out = self._network_copy(embedding_k)
+ if self._cur_task == 0:
+ energy_logits, train_logits = out["energy_logits"], out["logits"]
+ else:
+ energy_logits, train_logits = out["energy_logits"], out["train_logits"]
+ num_forwards = energy_logits.shape[1]
+ logits = torch.cat([train_logits,energy_logits],dim=1)
+ negative_energy = torch.log(torch.sum(torch.softmax(logits,dim=1)[:,-num_forwards:]))
+ optimizer_gen.zero_grad()
+ negative_energy.sum().backward()
+ optimizer_gen.step()
+ embedding_k.data += 1e-3 * \
+ torch.randn_like(embedding_k)
+ final_samples = embedding_k.detach()
+ return final_samples
+
+
+ def build_rehearsal_memory_imbalance(self, data_manager, per_class):
+ if self._fixed_memory:
+ self._construct_exemplar_unified_imbalance(data_manager, per_class,self.random,self.imbalance)
+ else:
+ self._reduce_exemplar_imbalance(data_manager, per_class,self.random,self.imbalance)
+ self._construct_exemplar_imbalance(data_manager, per_class,self.random,self.imbalance)
+
+
+ def _reduce_exemplar_imbalance(self, data_manager, m,random,imbalance):
+        logging.info('Reducing exemplars...({} per class)'.format(m))
+ dummy_data, dummy_targets = copy.deepcopy(self._data_memory), copy.deepcopy(self._targets_memory)
+ self._class_means = np.zeros((self._total_classes, self.feature_dim))
+ self._data_memory, self._targets_memory = np.array([]), np.array([])
+
+ for class_idx in range(self._known_classes):
+ mask = np.where(dummy_targets == class_idx)[0]
+            if len(mask) == 0:
+                continue
+ if random or imbalance is not None:
+ dd, dt = dummy_data[mask][:-1], dummy_targets[mask][:-1]
+ else:
+ dd, dt = dummy_data[mask][:m], dummy_targets[mask][:m]
+ self._data_memory = np.concatenate((self._data_memory, dd)) if len(self._data_memory) != 0 else dd
+ self._targets_memory = np.concatenate((self._targets_memory, dt)) if len(self._targets_memory) != 0 else dt
+
+ # Exemplar mean
+ idx_dataset = data_manager.get_dataset([], source='train', mode='test', appendent=(dd, dt))
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ self._class_means[class_idx, :] = mean
+
+ def _construct_exemplar_imbalance(self, data_manager, m, random=False,imbalance=None):
+ increment = self._total_classes - self._known_classes
+
+ if random:
+ '''
+ uniform random type
+ '''
+ selected_exemplars = []
+ selected_targets = []
+ logging.info("Contructing exmplars, totally random...({} total instances {} classes)".format(increment*m, increment))
+ data, targets, idx_dataset = data_manager.get_dataset(np.arange(self._known_classes,self._total_classes),source="train",mode="test",ret_data=True)
+            selected_indices = np.random.choice(list(range(len(data))), m*increment, replace=False)
+ for idx in selected_indices:
+ selected_exemplars.append(data[idx])
+ selected_targets.append(targets[idx])
+ selected_exemplars = np.array(selected_exemplars)[:m*increment]
+ selected_targets = np.array(selected_targets)[:m*increment]
+ self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \
+ else selected_exemplars
+ self._targets_memory = np.concatenate((self._targets_memory, selected_targets)) if \
+ len(self._targets_memory) != 0 else selected_targets
+ else:
+ if imbalance is None:
+                logging.info('Constructing exemplars...({} per class)'.format(m))
+ ms = np.ones(increment,dtype=int)*m
+ elif imbalance>=1:
+ '''
+ half-half type
+ '''
+ ms=[m for _ in range(increment)]
+ for i in range(increment//2):
+ ms[i]-=m//imbalance
+ for i in range(increment//2,increment):
+ ms[i]+=m//imbalance
+ np.random.shuffle(ms)
+ ms = np.array(ms,dtype=int)
+ logging.info("Constructing exmplars, Imbalance...({} or {} per classes)".format(m-m//imbalance,(m+m//imbalance)))
+ elif imbalance<1:
+ '''
+ exp type
+ '''
+ ms = np.array([imbalance**i for i in range(increment)])
+ ms = ms/ms.sum()
+ tot = m*increment
+ ms = (tot*ms).astype(int)
+ np.random.shuffle(ms)
+
+ else:
+ assert 0, "not implemented yet"
+ logging.info("ms {}".format(ms))
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
+ mode='test', ret_data=True)
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+
+ # Select
+ selected_exemplars = []
+ exemplar_vectors = [] # [n, feature_dim]
+ for k in range(1, ms[class_idx-self._known_classes]+1):
+ S = np.sum(exemplar_vectors, axis=0) # [feature_dim] sum of selected exemplars vectors
+ mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors
+ i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
+ selected_exemplars.append(np.array(data[i])) # New object to avoid passing by inference
+ exemplar_vectors.append(np.array(vectors[i])) # New object to avoid passing by inference
+
+ vectors = np.delete(vectors, i, axis=0) # Remove it to avoid duplicative selection
+ data = np.delete(data, i, axis=0) # Remove it to avoid duplicative selection
+
+ # uniques = np.unique(selected_exemplars, axis=0)
+ selected_exemplars = np.array(selected_exemplars)
+ if len(selected_exemplars)==0:
+ continue
+ exemplar_targets = np.full(ms[class_idx-self._known_classes], class_idx)
+ self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \
+ else selected_exemplars
+ self._targets_memory = np.concatenate((self._targets_memory, exemplar_targets)) if \
+ len(self._targets_memory) != 0 else exemplar_targets
+
+ # Exemplar mean
+ idx_dataset = data_manager.get_dataset([], source='train', mode='test',
+ appendent=(selected_exemplars, exemplar_targets))
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4,pin_memory=True)
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ self._class_means[class_idx, :] = mean
+ # self._class_means[class_idx, :] = class_mean
+
+ def _construct_exemplar_unified_imbalance(self, data_manager, m,random,imbalance):
+        logging.info('Constructing exemplars for new classes...({} per class)'.format(m))
+ _class_means = np.zeros((self._total_classes, self.feature_dim))
+ increment = self._total_classes - self._known_classes
+
+ # Calculate the means of old classes with newly trained network
+ for class_idx in range(self._known_classes):
+ mask = np.where(self._targets_memory == class_idx)[0]
+            if len(mask) == 0:
+                continue
+ class_data, class_targets = self._data_memory[mask], self._targets_memory[mask]
+
+ class_dset = data_manager.get_dataset([], source='train', mode='test',
+ appendent=(class_data, class_targets))
+ class_loader = DataLoader(class_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(class_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ _class_means[class_idx, :] = mean
+
+ if random:
+ '''
+ uniform sample type
+ '''
+ selected_exemplars = []
+ selected_targets = []
+ logging.info("Contructing exmplars, totally random...({} total instances {} classes)".format(increment*m, increment))
+ data, targets, idx_dataset = data_manager.get_dataset(np.arange(self._known_classes,self._total_classes),source="train",mode="test",ret_data=True)
+ selected_indices = np.random.choice(list(range(len(data))),m*increment,replace=False)
+ for idx in selected_indices:
+ selected_exemplars.append(data[idx])
+ selected_targets.append(targets[idx])
+ selected_exemplars = np.array(selected_exemplars)
+ selected_targets = np.array(selected_targets)
+ self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \
+ else selected_exemplars
+ self._targets_memory = np.concatenate((self._targets_memory, selected_targets)) if \
+ len(self._targets_memory) != 0 else selected_targets
+ else:
+ if imbalance is None:
+                logging.info('Constructing exemplars...({} per class)'.format(m))
+ ms = np.ones(increment,dtype=int)*m
+ elif imbalance>=1:
+ '''
+ half-half type
+ '''
+ ms=[m for _ in range(increment)]
+ for i in range(increment//2):
+ ms[i]-=m//imbalance
+ for i in range(increment//2,increment):
+ ms[i]+=m//imbalance
+ np.random.shuffle(ms)
+ ms = np.array(ms,dtype=int)
+ logging.info("Constructing exmplars, Imbalance...({} or {} per classes)".format(m-m//imbalance,(m+m//imbalance)))
+ elif imbalance<1:
+ '''
+ exp type
+ '''
+ ms = np.array([imbalance**i for i in range(increment)])
+ ms = ms/ms.sum()
+ tot = m*increment
+ ms = (tot*ms).astype(int)
+ np.random.shuffle(ms)
+
+ else:
+ assert 0, "not implemented yet"
+ logging.info("ms {}".format(ms))
+ # Construct exemplars for new classes and calculate the means
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, class_dset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
+ mode='test', ret_data=True)
+ class_loader = DataLoader(class_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4,pin_memory=True)
+
+ vectors, _ = self._extract_vectors(class_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+
+ # Select
+ selected_exemplars = []
+ exemplar_vectors = []
+ for k in range(1, ms[class_idx-self._known_classes]+1):
+ S = np.sum(exemplar_vectors, axis=0) # [feature_dim] sum of selected exemplars vectors
+ mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors
+ i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
+
+ selected_exemplars.append(np.array(data[i])) # New object to avoid passing by inference
+ exemplar_vectors.append(np.array(vectors[i])) # New object to avoid passing by inference
+
+ vectors = np.delete(vectors, i, axis=0) # Remove it to avoid duplicative selection
+ data = np.delete(data, i, axis=0) # Remove it to avoid duplicative selection
+
+ selected_exemplars = np.array(selected_exemplars)
+ if len(selected_exemplars)==0:
+ continue
+ exemplar_targets = np.full(ms[class_idx-self._known_classes], class_idx)
+ self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \
+ else selected_exemplars
+ self._targets_memory = np.concatenate((self._targets_memory, exemplar_targets)) if \
+ len(self._targets_memory) != 0 else exemplar_targets
+
+ # Exemplar mean
+ exemplar_dset = data_manager.get_dataset([], source='train', mode='test',
+ appendent=(selected_exemplars, exemplar_targets))
+ exemplar_loader = DataLoader(exemplar_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(exemplar_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ _class_means[class_idx, :] = mean
+ # _class_means[class_idx,:] = class_mean
+
+ self._class_means = _class_means
+
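For intuition on the `imbalance < 1` branch above: the per-class exemplar budgets decay geometrically before shuffling. A toy evaluation with hypothetical numbers:

import numpy as np

m, increment, imbalance = 20, 5, 0.5        # illustrative values only
ms = np.array([imbalance ** i for i in range(increment)])
ms = (m * increment * ms / ms.sum()).astype(int)
print(ms)                                    # [51 25 12  6  3] before shuffling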
diff --git a/models/bic.py b/models/bic.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57aba6c7a106e38aa6c3bd0aeeffee969f3a058
--- /dev/null
+++ b/models/bic.py
@@ -0,0 +1,206 @@
+import logging
+import numpy as np
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import IncrementalNetWithBias
+
+
+epochs = 170
+lrate = 0.1
+milestones = [60, 100, 140]
+lrate_decay = 0.1
+batch_size = 128
+split_ratio = 0.1
+T = 2
+weight_decay = 2e-4
+num_workers = 8
+
+
+class BiC(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = IncrementalNetWithBias(
+ args, False, bias_correction=True
+ )
+ self._class_means = None
+
+ def after_task(self):
+ self._old_network = self._network.copy().freeze()
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ if self._cur_task >= 1:
+ train_dset, val_dset = data_manager.get_dataset_with_split(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ val_samples_per_class=int(
+ split_ratio * self._memory_size / self._known_classes
+ ),
+ )
+ self.val_loader = DataLoader(
+ val_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ logging.info(
+ "Stage1 dset: {}, Stage2 dset: {}".format(
+ len(train_dset), len(val_dset)
+ )
+ )
+ self.lamda = self._known_classes / self._total_classes
+ logging.info("Lambda: {:.3f}".format(self.lamda))
+ else:
+ train_dset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ test_dset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+
+ self.train_loader = DataLoader(
+ train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ self.test_loader = DataLoader(
+ test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ self._log_bias_params()
+ self._stage1_training(self.train_loader, self.test_loader)
+ if self._cur_task >= 1:
+ self._stage2_bias_correction(self.val_loader, self.test_loader)
+
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+ self._log_bias_params()
+
+ def _run(self, train_loader, test_loader, optimizer, scheduler, stage):
+ for epoch in range(1, epochs + 1):
+ self._network.train()
+ losses = 0.0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
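+ # Stage 1 ("training"): CE on the current data plus distillation from the
+ # frozen old network, mixed by lamda = known / total classes.
+ # Stage 2 ("bias_correction"): only the bias layer for the newest task is
+ # trained, on the class-balanced validation split.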
+ if stage == "training":
+ clf_loss = F.cross_entropy(logits, targets)
+ if self._old_network is not None:
+ old_logits = self._old_network(inputs)["logits"].detach()
+ hat_pai_k = F.softmax(old_logits / T, dim=1)
+ log_pai_k = F.log_softmax(
+ logits[:, : self._known_classes] / T, dim=1
+ )
+ distill_loss = -torch.mean(
+ torch.sum(hat_pai_k * log_pai_k, dim=1)
+ )
+ loss = distill_loss * self.lamda + clf_loss * (1 - self.lamda)
+ else:
+ loss = clf_loss
+ elif stage == "bias_correction":
+ # F.cross_entropy applies log-softmax internally, so pass raw logits here
+ loss = F.cross_entropy(logits, targets)
+ else:
+ raise NotImplementedError()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ scheduler.step()
+ train_acc = self._compute_accuracy(self._network, train_loader)
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "{} => Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}".format(
+ stage,
+ self._cur_task,
+ epoch,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ logging.info(info)
+
+ def _stage1_training(self, train_loader, test_loader):
+ """
+ if self._cur_task == 0:
+ loaded_dict = torch.load('./dict_0.pkl')
+ self._network.load_state_dict(loaded_dict['model_state_dict'])
+ self._network.to(self._device)
+ return
+ """
+
+ ignored_params = list(map(id, self._network.bias_layers.parameters()))
+ base_params = filter(
+ lambda p: id(p) not in ignored_params, self._network.parameters()
+ )
+ network_params = [
+ {"params": base_params, "lr": lrate, "weight_decay": weight_decay},
+ {
+ "params": self._network.bias_layers.parameters(),
+ "lr": 0,
+ "weight_decay": 0,
+ },
+ ]
+ optimizer = optim.SGD(
+ network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._network.to(self._device)
+ if self._old_network is not None:
+ self._old_network.to(self._device)
+
+ self._run(train_loader, test_loader, optimizer, scheduler, stage="training")
+
+ def _stage2_bias_correction(self, val_loader, test_loader):
+ if isinstance(self._network, nn.DataParallel):
+ self._network = self._network.module
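+ # Only the newest bias layer's affine parameters (a scale and a shift) are
+ # optimized here; all other parameters keep their stage-1 values.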
+ network_params = [
+ {
+ "params": self._network.bias_layers[-1].parameters(),
+ "lr": lrate,
+ "weight_decay": weight_decay,
+ }
+ ]
+ optimizer = optim.SGD(
+ network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._network.to(self._device)
+
+ self._run(
+ val_loader, test_loader, optimizer, scheduler, stage="bias_correction"
+ )
+
+ def _log_bias_params(self):
+ logging.info("Parameters of bias layer:")
+ params = self._network.get_bias_params()
+ for i, param in enumerate(params):
+ logging.info("{} => {:.3f}, {:.3f}".format(i, param[0], param[1]))
diff --git a/models/coil.py b/models/coil.py
new file mode 100644
index 0000000000000000000000000000000000000000..b000510f23ee033f60a2f9a58a73bae680ad8c63
--- /dev/null
+++ b/models/coil.py
@@ -0,0 +1,332 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import SimpleCosineIncrementalNet
+from utils.toolkit import tensor2numpy
+import ot
+
+EPSILON = 1e-8
+
+epochs = 100
+lrate = 0.1
+milestones = [40, 80]
+lrate_decay = 0.1
+batch_size = 32
+memory_size = 2000
+T = 2
+
+
+class COIL(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = SimpleCosineIncrementalNet(args, False)
+ self.data_manager = None
+ self.nextperiod_initialization = None
+ self.sinkhorn_reg = args["sinkhorn"]
+ self.calibration_term = args["calibration_term"]
+ self.args = args
+
+ def after_task(self):
+ self.nextperiod_initialization = self.solving_ot()
+ self._old_network = self._network.copy().freeze()
+ self._known_classes = self._total_classes
+
+ def solving_ot(self):
+ with torch.no_grad():
+ if self._total_classes == self.data_manager.get_total_classnum():
+ print("training over, no more ot solving")
+ return None
+ each_time_class_num = self.data_manager.get_task_size(1)
+ self._extract_class_means(
+ self.data_manager, 0, self._total_classes + each_time_class_num
+ )
+ former_class_means = torch.tensor(
+ self._ot_prototype_means[: self._total_classes]
+ )
+ next_period_class_means = torch.tensor(
+ self._ot_prototype_means[
+ self._total_classes : self._total_classes + each_time_class_num
+ ]
+ )
+ Q_cost_matrix = torch.cdist(
+ former_class_means, next_period_class_means, p=self.args["norm_term"]
+ )
+ # solving ot
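+ # Sinkhorn OT between masses on old and new class prototypes. Both marginals
+ # are deliberately divided by len(former_class_means) so that, after the
+ # final "* len(former_class_means)" rescaling, each transported row is
+ # (approximately) a convex combination of the normalized old classifier
+ # vectors.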
+ _mu1_vec = (
+ torch.ones(len(former_class_means)) / len(former_class_means) * 1.0
+ )
+ _mu2_vec = (
+ torch.ones(len(next_period_class_means)) / len(former_class_means) * 1.0
+ )
+ plan = ot.sinkhorn(_mu1_vec, _mu2_vec, Q_cost_matrix, self.sinkhorn_reg)
+ # use the configured device instead of a hard-coded .cuda(), and avoid
+ # shadowing the module-level temperature constant T
+ plan = torch.as_tensor(plan).float().to(self._device)
+ transformed_hat_W = torch.mm(
+ plan.T, F.normalize(self._network.fc.weight, p=2, dim=1)
+ )
+ oldnorm = torch.norm(self._network.fc.weight, p=2, dim=1)
+ newnorm = torch.norm(
+ transformed_hat_W * len(former_class_means), p=2, dim=1
+ )
+ meannew = torch.mean(newnorm)
+ meanold = torch.mean(oldnorm)
+ gamma = meanold / meannew
+ self.calibration_term = gamma
+ self._ot_new_branch = (
+ transformed_hat_W * len(former_class_means) * self.calibration_term
+ )
+ return transformed_hat_W * len(former_class_means) * self.calibration_term
+
+ def solving_ot_to_old(self):
+ current_class_num = self.data_manager.get_task_size(self._cur_task)
+ self._extract_class_means_with_memory(
+ self.data_manager, self._known_classes, self._total_classes
+ )
+ former_class_means = torch.tensor(
+ self._ot_prototype_means[: self._known_classes]
+ )
+ next_period_class_means = torch.tensor(
+ self._ot_prototype_means[self._known_classes : self._total_classes]
+ )
+ Q_cost_matrix = (
+ torch.cdist(
+ next_period_class_means, former_class_means, p=self.args["norm_term"]
+ )
+ + EPSILON
+ ) # in case of numerical err
+ _mu1_vec = torch.ones(len(former_class_means)) / len(former_class_means) * 1.0
+ _mu2_vec = (
+ torch.ones(len(next_period_class_means)) / len(former_class_means) * 1.0
+ )
+ plan = ot.sinkhorn(_mu2_vec, _mu1_vec, Q_cost_matrix, self.sinkhorn_reg)
+ plan = torch.as_tensor(plan).float().to(self._device)
+ transformed_hat_W = torch.mm(
+ plan.T,
+ F.normalize(self._network.fc.weight[-current_class_num:, :], p=2, dim=1),
+ )
+ return transformed_hat_W * len(former_class_means) * self.calibration_term
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+
+ self._network.update_fc(self._total_classes, self.nextperiod_initialization)
+ self.data_manager = data_manager
+
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+ self.lamda = self._known_classes / self._total_classes
+ # Loader
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+
+ self._train(self.train_loader, self.test_loader)
+
+ if self.args["fixed_memory"]:
+ exemplar_size = self.args["memory_per_class"]
+ else:
+ exemplar_size = memory_size // self._total_classes
+ self._reduce_exemplar(data_manager, exemplar_size)
+ self._construct_exemplar(data_manager, exemplar_size)
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._old_network is not None:
+ self._old_network.to(self._device)
+ optimizer = optim.SGD(
+ self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=5e-4
+ ) # 1e-5
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ weight_ot_init = max(1.0 - (epoch / 2) ** 2, 0)
+ weight_ot_co_tuning = (epoch / epochs) ** 2.0
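+ # Schedules for the two OT-guided distillation terms: the initialization
+ # branch fades out within the first two epochs, while the co-tuning branch
+ # ramps up quadratically over training.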
+
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ output = self._network(inputs)
+ logits = output["logits"]
+
+ clf_loss = F.cross_entropy(logits, targets)
+ if self._old_network is not None:
+
+ old_logits = self._old_network(inputs)["logits"].detach()
+ hat_pai_k = F.softmax(old_logits / T, dim=1)
+ log_pai_k = F.log_softmax(
+ logits[:, : self._known_classes] / T, dim=1
+ )
+ distill_loss = -torch.mean(torch.sum(hat_pai_k * log_pai_k, dim=1))
+
+ if epoch < 1:
+ features = F.normalize(output["features"], p=2, dim=1)
+ current_logit_new = F.log_softmax(
+ logits[:, self._known_classes :] / T, dim=1
+ )
+ new_logit_by_wnew_init_by_ot = F.linear(
+ features, F.normalize(self._ot_new_branch, p=2, dim=1)
+ )
+ new_logit_by_wnew_init_by_ot = F.softmax(
+ new_logit_by_wnew_init_by_ot / T, dim=1
+ )
+ new_branch_distill_loss = -torch.mean(
+ torch.sum(
+ current_logit_new * new_logit_by_wnew_init_by_ot, dim=1
+ )
+ )
+
+ loss = (
+ distill_loss * self.lamda
+ + clf_loss * (1 - self.lamda)
+ + 0.001 * (weight_ot_init * new_branch_distill_loss)
+ )
+ else:
+ features = F.normalize(output["features"], p=2, dim=1)
+ if i % 30 == 0:
+ with torch.no_grad():
+ self._ot_old_branch = self.solving_ot_to_old()
+ old_logit_by_wold_init_by_ot = F.linear(
+ features, F.normalize(self._ot_old_branch, p=2, dim=1)
+ )
+ old_logit_by_wold_init_by_ot = F.log_softmax(
+ old_logit_by_wold_init_by_ot / T, dim=1
+ )
+ old_branch_distill_loss = -torch.mean(
+ torch.sum(hat_pai_k * old_logit_by_wold_init_by_ot, dim=1)
+ )
+ loss = (
+ distill_loss * self.lamda
+ + clf_loss * (1 - self.lamda)
+ + self.args["reg_term"]
+ * (weight_ot_co_tuning * old_branch_distill_loss)
+ )
+ else:
+ loss = clf_loss
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _extract_class_means(self, data_manager, low, high):
+ self._ot_prototype_means = np.zeros(
+ (data_manager.get_total_classnum(), self._network.feature_dim)
+ )
+ with torch.no_grad():
+ for class_idx in range(low, high):
+ data, targets, idx_dataset = data_manager.get_dataset(
+ np.arange(class_idx, class_idx + 1),
+ source="train",
+ mode="test",
+ ret_data=True,
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+ class_mean = class_mean / (np.linalg.norm(class_mean))
+ self._ot_prototype_means[class_idx, :] = class_mean
+ self._network.train()
+
+ def _extract_class_means_with_memory(self, data_manager, low, high):
+
+ self._ot_prototype_means = np.zeros(
+ (data_manager.get_total_classnum(), self._network.feature_dim)
+ )
+ memoryx, memoryy = self._data_memory, self._targets_memory
+ with torch.no_grad():
+ for class_idx in range(0, low):
+ idxes = np.where(
+ np.logical_and(memoryy >= class_idx, memoryy < class_idx + 1)
+ )[0]
+ data, targets = memoryx[idxes], memoryy[idxes]
+ # idx_dataset=TensorDataset(data,targets)
+ # idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+ _, _, idx_dataset = data_manager.get_dataset(
+ [],
+ source="train",
+ appendent=(data, targets),
+ mode="test",
+ ret_data=True,
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+ class_mean = class_mean / np.linalg.norm(class_mean)
+ self._ot_prototype_means[class_idx, :] = class_mean
+
+ for class_idx in range(low, high):
+ data, targets, idx_dataset = data_manager.get_dataset(
+ np.arange(class_idx, class_idx + 1),
+ source="train",
+ mode="test",
+ ret_data=True,
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+ class_mean = class_mean / np.linalg.norm(class_mean)
+ self._ot_prototype_means[class_idx, :] = class_mean
+ self._network.train()
diff --git a/models/der.py b/models/der.py
new file mode 100644
index 0000000000000000000000000000000000000000..3943b2e023c3660f3d67420a55f031157860705f
--- /dev/null
+++ b/models/der.py
@@ -0,0 +1,230 @@
+# Note: this implementation of DER contains only the dynamic-expansion process; masking and pruning are not implemented in the source repo.
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import DERNet
+from utils.toolkit import count_parameters, tensor2numpy
+
+EPSILON = 1e-8
+
+init_epoch = 100
+init_lr = 0.1
+init_milestones = [40, 60, 80]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 80
+lrate = 0.1
+milestones = [30, 50, 70]
+lrate_decay = 0.1
+batch_size = 32
+weight_decay = 2e-4
+num_workers = 8
+T = 2
+
+
+class DER(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = DERNet(args, False)
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ if self._cur_task > 0:
+ for i in range(self._cur_task):
+ for p in self._network.convnets[i].parameters():
+ p.requires_grad = False
+
+ logging.info("All params: {}".format(count_parameters(self._network)))
+ logging.info(
+ "Trainable params: {}".format(count_parameters(self._network, True))
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def train(self):
+ self._network.train()
+ if len(self._multiple_gpus) > 1 :
+ self._network_module_ptr = self._network.module
+ else:
+ self._network_module_ptr = self._network
+ self._network_module_ptr.convnets[-1].train()
+ if self._cur_task >= 1:
+ for i in range(self._cur_task):
+ self._network_module_ptr.convnets[i].eval()
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+ if len(self._multiple_gpus) > 1:
+ self._network.module.weight_align(
+ self._total_classes - self._known_classes
+ )
+ else:
+ self._network.weight_align(self._total_classes - self._known_classes)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ losses = 0.0
+ losses_clf = 0.0
+ losses_aux = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ outputs = self._network(inputs)
+ logits, aux_logits = outputs["logits"], outputs["aux_logits"]
+ loss_clf = F.cross_entropy(logits, targets)
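+ # Auxiliary head: remap labels so all old classes collapse to 0 and the
+ # new classes become 1..|new|, pushing the newly added branch to separate
+ # new classes from everything old.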
+ aux_targets = targets.clone()
+ aux_targets = torch.where(
+ aux_targets - self._known_classes + 1 > 0,
+ aux_targets - self._known_classes + 1,
+ torch.tensor([0]).to(self._device),
+ )
+ loss_aux = F.cross_entropy(aux_logits, aux_targets)
+ loss = loss_clf + loss_aux
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_aux += loss_aux.item()
+ losses_clf += loss_clf.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_aux / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_aux / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
diff --git a/models/ewc.py b/models/ewc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5493e09eca5449ef299025b1518932285883a911
--- /dev/null
+++ b/models/ewc.py
@@ -0,0 +1,254 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import IncrementalNet
+from utils.toolkit import tensor2numpy
+
+EPSILON = 1e-8
+
+init_epoch = 200
+init_lr = 0.1
+init_milestones = [60, 120, 170]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 180
+lrate = 0.1
+milestones = [70, 120, 150]
+lrate_decay = 0.1
+batch_size = 128
+weight_decay = 2e-4
+num_workers = 4
+T = 2
+lamda = 1000
+fishermax = 0.0001
+
+
+class EWC(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self.fisher = None
+ self._network = IncrementalNet(args, False)
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
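+ # Consolidate the Fisher information across tasks: keep a running average
+ # of the diagonal Fisher, weighting the old estimate by the fraction of
+ # classes it covers (alpha = known / total).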
+ if self.fisher is None:
+ self.fisher = self.getFisherDiagonal(self.train_loader)
+ else:
+ alpha = self._known_classes / self._total_classes
+ new_fisher = self.getFisherDiagonal(self.train_loader)
+ for n, p in new_fisher.items():
+ new_fisher[n][: len(self.fisher[n])] = (
+ alpha * self.fisher[n]
+ + (1 - alpha) * new_fisher[n][: len(self.fisher[n])]
+ )
+ self.fisher = new_fisher
+ self.mean = {
+ n: p.clone().detach()
+ for n, p in self._network.named_parameters()
+ if p.requires_grad
+ }
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss_clf = F.cross_entropy(
+ logits[:, self._known_classes :], targets - self._known_classes
+ )
+ loss_ewc = self.compute_ewc()
+ loss = loss_clf + lamda * loss_ewc
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def compute_ewc(self):
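+ # EWC quadratic penalty: sum_n F_n * (theta_n - theta*_n)^2 / 2, where
+ # theta* are the parameter values saved after the previous task and F is
+ # the diagonal Fisher information.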
+ loss = 0
+ if len(self._multiple_gpus) > 1:
+ for n, p in self._network.module.named_parameters():
+ if n in self.fisher.keys():
+ loss += (
+ torch.sum(
+ (self.fisher[n])
+ * (p[: len(self.mean[n])] - self.mean[n]).pow(2)
+ )
+ / 2
+ )
+ else:
+ for n, p in self._network.named_parameters():
+ if n in self.fisher.keys():
+ loss += (
+ torch.sum(
+ (self.fisher[n])
+ * (p[: len(self.mean[n])] - self.mean[n]).pow(2)
+ )
+ / 2
+ )
+ return loss
+
+ def getFisherDiagonal(self, train_loader):
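+ # Empirical diagonal Fisher: accumulate squared gradients of the CE loss
+ # over the training set, average over batches, and clamp at fishermax for
+ # numerical stability.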
+ fisher = {
+ n: torch.zeros(p.shape).to(self._device)
+ for n, p in self._network.named_parameters()
+ if p.requires_grad
+ }
+ self._network.train()
+ optimizer = optim.SGD(self._network.parameters(), lr=lrate)
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+ loss = torch.nn.functional.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ for n, p in self._network.named_parameters():
+ if p.grad is not None:
+ fisher[n] += p.grad.pow(2).clone()
+ for n, p in fisher.items():
+ fisher[n] = p / len(train_loader)
+ fisher[n] = torch.min(fisher[n], torch.tensor(fishermax))
+ return fisher
diff --git a/models/fetril.py b/models/fetril.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cdec8a43778a89ebdd29d34eb983dbbf9caeb0e
--- /dev/null
+++ b/models/fetril.py
@@ -0,0 +1,227 @@
+'''
+Results on CIFAR-100 (reported numbers use ResNet18, reproduced numbers use ResNet32):
+
+Protocol | Reported FC | Reported SVM | Reproduced FC | Reproduced SVM
+T = 5 | 64.7 | 66.3 | 65.775 | 65.375
+T = 10 | 63.4 | 65.2 | 64.91 | 65.10
+T = 60 | 50.8 | 59.8 | 62.09 | 61.72
+'''
+
+
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from models.base import BaseLearner
+from utils.inc_net import IncrementalNet
+from utils.toolkit import count_parameters, tensor2numpy
+from sklearn.svm import LinearSVC
+from torchvision import transforms
+from utils.autoaugment import ImageNetPolicy
+from utils.ops import Cutout
+
+EPSILON = 1e-8
+
+
+class FeTrIL(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+ self._network = IncrementalNet(args, False)
+ self._means = []
+ self._svm_accs = []
+
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+
+ def incremental_train(self, data_manager):
+ self.data_manager = data_manager
+ self.data_manager._train_trsf = [
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(brightness=63/255),
+ ImageNetPolicy(),
+ Cutout(n_holes=1, length=16),
+ ]
+ self._cur_task += 1
+
+ self._total_classes = self._known_classes + \
+ data_manager.get_task_size(self._cur_task)
+ self._network.update_fc(self._total_classes)
+ self._network_module_ptr = self._network
+ logging.info(
+ 'Learning on {}-{}'.format(self._known_classes, self._total_classes))
+
+ if self._cur_task > 0:
+ for p in self._network.convnet.parameters():
+ p.requires_grad = False
+
+ logging.info('All params: {}'.format(count_parameters(self._network)))
+ logging.info('Trainable params: {}'.format(
+ count_parameters(self._network, True)))
+
+ train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
+ mode='train', appendent=self._get_memory())
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source='test', mode='test')
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if hasattr(self._network, "module"):
+ self._network_module_ptr = self._network.module
+ if self._cur_task == 0:
+ self._epoch_num = self.args["init_epochs"]
+ optimizer = optim.SGD(filter(lambda p: p.requires_grad, self._network.parameters(
+ )), momentum=0.9, lr=self.args["init_lr"], weight_decay=self.args["init_weight_decay"])
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args["init_epochs"])
+ self._train_function(train_loader, test_loader, optimizer, scheduler)
+ self._compute_means()
+ self._build_feature_set()
+ else:
+ self._epoch_num = self.args["epochs"]
+ self._compute_means()
+ self._compute_relations()
+ self._build_feature_set()
+
+ train_loader = DataLoader(self._feature_trainset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
+ optimizer = optim.SGD(self._network_module_ptr.fc.parameters(),momentum=0.9,lr=self.args["lr"],weight_decay=self.args["weight_decay"])
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max = self.args["epochs"])
+
+ self._train_function(train_loader, test_loader, optimizer, scheduler)
+ self._train_svm(self._feature_trainset,self._feature_testset)
+
+
+ def _compute_means(self):
+ with torch.no_grad():
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
+ mode='test', ret_data=True)
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ class_mean = np.mean(vectors, axis=0)
+ self._means.append(class_mean)
+
+ def _compute_relations(self):
+ old_means = np.array(self._means[:self._known_classes])
+ new_means = np.array(self._means[self._known_classes:])
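+ # Pair every old class with its most similar new class (cosine similarity
+ # of L2-normalized class means).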
+ old_means_norm = old_means / np.linalg.norm(old_means, axis=1)[:, None]
+ new_means_norm = new_means / np.linalg.norm(new_means, axis=1)[:, None]
+ self._relations = np.argmax(old_means_norm @ new_means_norm.T, axis=1) + self._known_classes
+
+ def _build_feature_set(self):
+ self.vectors_train = []
+ self.labels_train = []
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
+ mode='test', ret_data=True)
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ self.vectors_train.append(vectors)
+ self.labels_train.append([class_idx]*len(vectors))
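+ # FeTrIL pseudo-features: old-class features are synthesized by shifting
+ # the paired new class's features by the difference of class means:
+ # f_old = f_new - mu_new + mu_old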
+ for class_idx in range(0,self._known_classes):
+ new_idx = self._relations[class_idx]
+ self.vectors_train.append(self.vectors_train[new_idx-self._known_classes]-self._means[new_idx]+self._means[class_idx])
+ self.labels_train.append([class_idx]*len(self.vectors_train[-1]))
+
+ self.vectors_train = np.concatenate(self.vectors_train)
+ self.labels_train = np.concatenate(self.labels_train)
+ self._feature_trainset = FeatureDataset(self.vectors_train,self.labels_train)
+
+ self.vectors_test = []
+ self.labels_test = []
+ for class_idx in range(0, self._total_classes):
+ data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='test',
+ mode='test', ret_data=True)
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ self.vectors_test.append(vectors)
+ self.labels_test.append([class_idx]*len(vectors))
+ self.vectors_test = np.concatenate(self.vectors_test)
+ self.labels_test = np.concatenate(self.labels_test)
+
+ self._feature_testset = FeatureDataset(self.vectors_test,self.labels_test)
+
+ def _train_function(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self._epoch_num))
+ for _, epoch in enumerate(prog_bar):
+ if self._cur_task == 0:
+ self._network.train()
+ else:
+ self._network.eval()
+ losses = 0.
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True), targets.to(self._device, non_blocking=True)
+ if self._cur_task == 0:
+ logits = self._network(inputs)['logits']
+ else:
+ # later tasks: inputs are pre-extracted features, so only the fc head is trained
+ logits = self._network_module_ptr.fc(inputs)['logits']
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(
+ correct)*100 / total, decimals=2)
+ if epoch % 5 != 0:
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), train_acc)
+ else:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), train_acc, test_acc)
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def _train_svm(self, train_set, test_set):
+ train_features = train_set.features.numpy()
+ train_labels = train_set.labels.numpy()
+ test_features = test_set.features.numpy()
+ test_labels = test_set.labels.numpy()
+ train_features = train_features/np.linalg.norm(train_features,axis=1)[:,None]
+ test_features = test_features/np.linalg.norm(test_features,axis=1)[:,None]
+ svm_classifier = LinearSVC(random_state=42)
+ svm_classifier.fit(train_features,train_labels)
+ logging.info("svm train: acc: {}".format(np.around(svm_classifier.score(train_features,train_labels)*100,decimals=2)))
+ acc = svm_classifier.score(test_features,test_labels)
+ self._svm_accs.append(np.around(acc*100,decimals=2))
+ logging.info("svm evaluation: acc_list: {}".format(self._svm_accs))
+
+class FeatureDataset(Dataset):
+ def __init__(self, features, labels):
+ assert len(features) == len(labels), "Data size error!"
+ self.features = torch.from_numpy(features)
+ self.labels = torch.from_numpy(labels)
+
+ def __len__(self):
+ return len(self.features)
+
+ def __getitem__(self, idx):
+ feature = self.features[idx]
+ label = self.labels[idx]
+
+ return idx, feature, label
diff --git a/models/finetune.py b/models/finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..b85338f62e163b432f06dce8a16c5206330bf6af
--- /dev/null
+++ b/models/finetune.py
@@ -0,0 +1,206 @@
+import logging
+import numpy as np
+import torch
+import copy
+from torch import nn
+from tqdm import tqdm
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from utils.inc_net import IncrementalNet
+from models.base import BaseLearner
+from utils.toolkit import tensor2numpy
+
+
+init_epoch = 100
+init_lr = 0.1
+init_milestones = [40, 60, 80]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 80
+lrate = 0.1
+milestones = [40, 70]
+lrate_decay = 0.1
+batch_size = 32
+weight_decay = 2e-4
+num_workers = 8
+
+
+class Finetune(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = IncrementalNet(args, False)
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+
+ def save_checkpoint(self, test_acc):
+ assert self.args['model_name'] == 'finetune'
+ checkpoint_name = f"models/finetune/{self.args['csv_name']}"
+ _checkpoint_cpu = copy.deepcopy(self._network)
+ if isinstance(_checkpoint_cpu, nn.DataParallel):
+ _checkpoint_cpu = _checkpoint_cpu.module
+ _checkpoint_cpu.cpu()
+ save_dict = {
+ "tasks": self._cur_task,
+ "convnet": _checkpoint_cpu.convnet.state_dict(),
+ "fc":_checkpoint_cpu.fc.state_dict(),
+ "test_acc": test_acc
+ }
+ torch.save(save_dict, "{}_{}.pkl".format(checkpoint_name, self._cur_task))
+
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ ) # 1e-5
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
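+ # Plain fine-tuning: only the current task's logits receive a CE loss,
+ # with targets shifted into the new-class index range; nothing protects
+ # old classes, so this serves as the forgetting baseline.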
+ fake_targets = targets - self._known_classes
+ loss_clf = F.cross_entropy(
+ logits[:, self._known_classes :], fake_targets
+ )
+
+ loss = loss_clf
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
diff --git a/models/foster.py b/models/foster.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9c2ea78bcd4e3b47bcf8ec54d5ce84bef62b61a
--- /dev/null
+++ b/models/foster.py
@@ -0,0 +1,435 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import FOSTERNet
+from utils.toolkit import count_parameters, tensor2numpy
+
+# Please refer to https://github.com/G-U-N/ECCV22-FOSTER for the full source code to reproduce foster.
+
+EPSILON = 1e-8
+
+
+class FOSTER(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+ self._network = FOSTERNet(args, False)
+ self._snet = None
+ self.beta1 = args["beta1"]
+ self.beta2 = args["beta2"]
+ self.per_cls_weights = None
+ self.is_teacher_wa = args["is_teacher_wa"]
+ self.is_student_wa = args["is_student_wa"]
+ self.lambda_okd = args["lambda_okd"]
+ self.wa_value = args["wa_value"]
+ self.oofc = args["oofc"].lower()
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def load_checkpoint(self, filename):
+ checkpoint = torch.load(filename)
+ self._known_classes = len(checkpoint["classes"])
+ self.class_list = np.array(checkpoint["classes"])
+ self.label_list = checkpoint["label_list"]
+ self._network.update_fc(self._known_classes)
+ self._network.load_checkpoint(checkpoint["network"])
+ self._network.to(self._device)
+ self._cur_task = 0
+
+ def save_checkpoint(self, filename):
+ self._network.cpu()
+ save_dict = {
+ "classes": self.data_manager.get_class_list(self._cur_task),
+ "network": {
+ "convnet": self._network.convnets[0].state_dict(),
+ "fc": self._network.fc.state_dict()
+ },
+ "label_list": self.data_manager.get_label_list(self._cur_task),
+ }
+ torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task))
+
+ def incremental_train(self, data_manager):
+ self.data_manager = data_manager
+ if hasattr(self.data_manager,'label_list') and hasattr(self,'label_list'):
+ self.data_manager.label_list = list(self.label_list.values()) + self.data_manager.label_list
+ self._cur_task += 1
+ if self._cur_task > 1:
+ self._network = self._snet
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ self._network_module_ptr = self._network
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ if self._cur_task > 0:
+ for p in self._network.convnets[0].parameters():
+ p.requires_grad = False
+ for p in self._network.oldfc.parameters():
+ p.requires_grad = False
+
+ logging.info("All params: {}".format(count_parameters(self._network)))
+ logging.info(
+ "Trainable params: {}".format(count_parameters(self._network, True))
+ )
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+
+ self.train_loader = DataLoader(
+ train_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=True,
+ num_workers=self.args["num_workers"],
+ pin_memory=True,
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=False,
+ num_workers=self.args["num_workers"],
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ #self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def train(self):
+ self._network_module_ptr.train()
+ self._network_module_ptr.convnets[-1].train()
+ if self._cur_task >= 1:
+ self._network_module_ptr.convnets[0].eval()
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if hasattr(self._network, "module"):
+ self._network_module_ptr = self._network.module
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ momentum=0.9,
+ lr=self.args["init_lr"],
+ weight_decay=self.args["init_weight_decay"],
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args["init_epochs"]
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ cls_num_list = [self.samples_old_class] * self._known_classes + [
+ self.samples_new_class(i)
+ for i in range(self._known_classes, self._total_classes)
+ ]
+ effective_num = 1.0 - np.power(self.beta1, cls_num_list)
+ per_cls_weights = (1.0 - self.beta1) / np.array(effective_num)
+ per_cls_weights = (
+ per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
+ )
+
+ logging.info("per cls weights : {}".format(per_cls_weights))
+ self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)
+
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ lr=self.args["lr"],
+ momentum=0.9,
+ weight_decay=self.args["weight_decay"],
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args["boosting_epochs"]
+ )
+ if self.oofc == "az":
+ for i, p in enumerate(self._network_module_ptr.fc.parameters()):
+ if i == 0:
+ p.data[
+ self._known_classes :, : self._network_module_ptr.out_dim
+ ] = torch.tensor(0.0)
+ elif self.oofc != "ft":
+ assert 0, "not implemented"
+ self._feature_boosting(train_loader, test_loader, optimizer, scheduler)
+ if self.is_teacher_wa:
+ self._network_module_ptr.weight_align(
+ self._known_classes,
+ self._total_classes - self._known_classes,
+ self.wa_value,
+ )
+ else:
+ logging.info("do not weight align teacher!")
+
+ cls_num_list = [self.samples_old_class] * self._known_classes + [
+ self.samples_new_class(i)
+ for i in range(self._known_classes, self._total_classes)
+ ]
+ effective_num = 1.0 - np.power(self.beta2, cls_num_list)
+ per_cls_weights = (1.0 - self.beta2) / np.array(effective_num)
+ per_cls_weights = (
+ per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
+ )
+ logging.info("per cls weights : {}".format(per_cls_weights))
+ self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device)
+ self._feature_compression(train_loader, test_loader)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self.args["init_epochs"]))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True
+ ), targets.to(self._device, non_blocking=True)
+ logits = self._network(inputs)["logits"]
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["init_epochs"],
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["init_epochs"],
+ losses / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def _feature_boosting(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self.args["boosting_epochs"]))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ losses = 0.0
+ losses_clf = 0.0
+ losses_fe = 0.0
+ losses_kd = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True
+ ), targets.to(self._device, non_blocking=True)
+ outputs = self._network(inputs)
+ logits, fe_logits, old_logits = (
+ outputs["logits"],
+ outputs["fe_logits"],
+ outputs["old_logits"].detach(),
+ )
+ loss_clf = F.cross_entropy(logits / self.per_cls_weights, targets)
+ loss_fe = F.cross_entropy(fe_logits, targets)
+ loss_kd = self.lambda_okd * _KD_loss(
+ logits[:, : self._known_classes], old_logits, self.args["T"]
+ )
+ loss = loss_clf + loss_fe + loss_kd
+ optimizer.zero_grad()
+ loss.backward()
+ if self.oofc == "az":
+ for i, p in enumerate(self._network_module_ptr.fc.parameters()):
+ if i == 0:
+ p.grad.data[
+ self._known_classes :,
+ : self._network_module_ptr.out_dim,
+ ] = torch.tensor(0.0)
+ elif self.oofc != "ft":
+ assert 0, "not implemented"
+ optimizer.step()
+ losses += loss.item()
+ losses_fe += loss_fe.item()
+ losses_clf += loss_clf.item()
+ losses_kd += (
+ self._known_classes / self._total_classes
+ ) * loss_kd.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["boosting_epochs"],
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_fe / len(train_loader),
+ losses_kd / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["boosting_epochs"],
+ losses / len(train_loader),
+ losses_clf / len(train_loader),
+ losses_fe / len(train_loader),
+ losses_kd / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def _feature_compression(self, train_loader, test_loader):
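+ # Distill the expanded (two-backbone) boosted teacher into a
+ # single-backbone student (SNet) with balanced knowledge distillation,
+ # so the model size stays constant across tasks.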
+ self._snet = FOSTERNet(self.args, False)
+ self._snet.update_fc(self._total_classes)
+ if len(self._multiple_gpus) > 1:
+ self._snet = nn.DataParallel(self._snet, self._multiple_gpus)
+ if hasattr(self._snet, "module"):
+ self._snet_module_ptr = self._snet.module
+ else:
+ self._snet_module_ptr = self._snet
+ self._snet.to(self._device)
+ self._snet_module_ptr.convnets[0].load_state_dict(
+ self._network_module_ptr.convnets[0].state_dict()
+ )
+ self._snet_module_ptr.copy_fc(self._network_module_ptr.oldfc)
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._snet.parameters()),
+ lr=self.args["lr"],
+ momentum=0.9,
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args["compression_epochs"]
+ )
+ self._network.eval()
+ prog_bar = tqdm(range(self.args["compression_epochs"]))
+ for _, epoch in enumerate(prog_bar):
+ self._snet.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True
+ ), targets.to(self._device, non_blocking=True)
+ dark_logits = self._snet(inputs)["logits"]
+ with torch.no_grad():
+ outputs = self._network(inputs)
+ logits, old_logits, fe_logits = (
+ outputs["logits"],
+ outputs["old_logits"],
+ outputs["fe_logits"],
+ )
+ loss_dark = self.BKD(dark_logits, logits, self.args["T"])
+ loss = loss_dark
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ _, preds = torch.max(dark_logits[: targets.shape[0]], dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._snet, test_loader)
+ info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["compression_epochs"],
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args["compression_epochs"],
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+ if len(self._multiple_gpus) > 1:
+ self._snet = self._snet.module
+ if self.is_student_wa:
+ self._snet.weight_align(
+ self._known_classes,
+ self._total_classes - self._known_classes,
+ self.wa_value,
+ )
+ else:
+ logging.info("do not weight align student!")
+ if self._cur_task > 1:
+ self._network = self._snet
+ self._snet.eval()
+ y_pred, y_true = [], []
+ for _, (_, inputs, targets) in enumerate(test_loader):
+ inputs = inputs.to(self._device, non_blocking=True)
+ with torch.no_grad():
+ outputs = self._snet(inputs)["logits"]
+ predicts = torch.topk(
+ outputs, k=self.topk, dim=1, largest=True, sorted=True
+ )[1]
+ y_pred.append(predicts.cpu().numpy())
+ y_true.append(targets.cpu().numpy())
+ y_pred = np.concatenate(y_pred)
+ y_true = np.concatenate(y_true)
+ cnn_accy = self._evaluate(y_pred, y_true)
+ logging.info("darknet eval: ")
+ logging.info("CNN top1 curve: {}".format(cnn_accy["top1"]))
+ logging.info("CNN top5 curve: {}".format(cnn_accy["top5"]))
+
+ @property
+ def samples_old_class(self):
+ if self._fixed_memory:
+ return self._memory_per_class
+ else:
+ assert self._total_classes != 0, "Total classes is 0"
+ return self._memory_size // self._known_classes
+
+ def samples_new_class(self, index):
+ if self.args["dataset"] == "cifar100":
+ return 500
+ else:
+ return self.data_manager.getlen(index)
+
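+ # Balanced knowledge distillation: the teacher's softened probabilities
+ # are re-weighted per class (per_cls_weights) and re-normalized before
+ # the temperature-scaled KD cross-entropy.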
+ def BKD(self, pred, soft, T):
+ pred = torch.log_softmax(pred / T, dim=1)
+ soft = torch.softmax(soft / T, dim=1)
+ soft = soft * self.per_cls_weights
+ soft = soft / soft.sum(1)[:, None]
+ return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
+
+
+def _KD_loss(pred, soft, T):
+ pred = torch.log_softmax(pred / T, dim=1)
+ soft = torch.softmax(soft / T, dim=1)
+ return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
diff --git a/models/gem.py b/models/gem.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d42769ec12ac43a8a7d2fa1590eedbab89c6d39
--- /dev/null
+++ b/models/gem.py
@@ -0,0 +1,304 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import IncrementalNet
+from utils.inc_net import CosineIncrementalNet
+from utils.toolkit import target2onehot, tensor2numpy
+try:
+ from quadprog import solve_qp
+except ImportError: # quadprog is optional; only the GEM projection needs it
+ solve_qp = None
+
+
+EPSILON = 1e-8
+
+
+init_epoch = 1
+init_lr = 0.1
+init_milestones = [40, 60, 80]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 1
+lrate = 0.1
+milestones = [20, 40, 60]
+lrate_decay = 0.1
+batch_size = 16
+weight_decay = 2e-4
+num_workers = 4
+
+
+class GEM(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = IncrementalNet(args, False)
+ self.previous_data = None
+ self.previous_label = None
+
+ def after_task(self):
+ self._old_network = self._network.copy().freeze()
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ if self._cur_task > 0:
+ previous_dataset = data_manager.get_dataset(
+ [], source="train", mode="train", appendent=self._get_memory()
+ )
+
+ self.previous_data = []
+ self.previous_label = []
+ for i in previous_dataset:
+ _, data_, label_ = i
+ self.previous_data.append(data_)
+ self.previous_label.append(label_)
+ self.previous_data = torch.stack(self.previous_data)
+ self.previous_label = torch.tensor(self.previous_label)
+ # Procedure
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._old_network is not None:
+ self._old_network.to(self._device)
+
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ ) # 1e-5
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
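+ # GEM (Lopez-Paz & Ranzato, 2017): before each update, compute the
+ # gradient of the loss on every previous task's memory (columns of G).
+ # If the current gradient conflicts with any of them (negative dot
+ # product), project it to the closest gradient that agrees with all
+ # past-task gradients by solving a small QP. Note the task-boundary
+ # arithmetic below assumes all tasks have the same number of classes.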
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(epochs))
+ grad_numels = []
+ for params in self._network.parameters():
+ grad_numels.append(params.data.numel())
+ G = torch.zeros((sum(grad_numels), self._cur_task + 1)).to(self._device)
+
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ incremental_step = self._total_classes - self._known_classes
+ for k in range(0, self._cur_task):
+ optimizer.zero_grad()
+ mask = torch.where(
+ (self.previous_label >= k * incremental_step)
+ & (self.previous_label < (k + 1) * incremental_step)
+ )[0]
+ data_ = self.previous_data[mask].to(self._device)
+ label_ = self.previous_label[mask].to(self._device)
+ pred_ = self._network(data_)["logits"]
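+ # Restrict the loss to task k's classes by pushing all other
+ # logits to a large negative value before the cross-entropy.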
+ pred_[:, : k * incremental_step].data.fill_(-10e10)
+ pred_[:, (k + 1) * incremental_step :].data.fill_(-10e10)
+ loss_ = F.cross_entropy(pred_, label_)
+ loss_.backward()
+
+ j = 0
+ for params in self._network.parameters():
+ # Copy only parameters that received a gradient, but always
+ # advance j so the offsets stay aligned with grad_numels.
+ if params.grad is not None:
+ stpt = 0 if j == 0 else sum(grad_numels[:j])
+ endpt = sum(grad_numels[: j + 1])
+ G[stpt:endpt, k].data.copy_(params.grad.data.view(-1))
+ j += 1
+
+ optimizer.zero_grad()
+
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+ logits[:, : self._known_classes].data.fill_(-10e10)
+ loss_clf = F.cross_entropy(logits, targets)
+
+ loss = loss_clf
+
+ optimizer.zero_grad()
+ loss.backward()
+
+ j = 0
+ for params in self._network.parameters():
+ if params.grad is not None:
+ stpt = 0 if j == 0 else sum(grad_numels[:j])
+ endpt = sum(grad_numels[: j + 1])
+ G[stpt:endpt, self._cur_task].data.copy_(
+ params.grad.data.view(-1)
+ )
+ j += 1
+
+ dotprod = torch.mm(
+ G[:, self._cur_task].unsqueeze(0), G[:, : self._cur_task]
+ )
+
+ if (dotprod < 0).sum() > 0:
+
+ old_grad = G[:, : self._cur_task].cpu().t().double().numpy()
+ cur_grad = G[:, self._cur_task].cpu().contiguous().double().numpy()
+
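+ # quadprog.solve_qp(G, a, C, b) minimizes 1/2 v^T G v - a^T v
+ # subject to C^T v >= b. This solves the dual of GEM's projection:
+ # the multipliers v >= 0 recombine the old-task gradients, and the
+ # projected gradient is old_grad^T v + cur_grad.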
+ C = old_grad @ old_grad.T
+ p = old_grad @ cur_grad
+ A = np.eye(old_grad.shape[0])
+ b = np.zeros(old_grad.shape[0])
+
+ v = solve_qp(C, -p, A, b)[0]
+
+ new_grad = old_grad.T @ v + cur_grad
+ new_grad = torch.tensor(new_grad).float().to(self._device)
+
+ new_dotprod = torch.mm(
+ new_grad.unsqueeze(0), G[:, : self._cur_task]
+ )
+ if (new_dotprod < -0.01).sum() > 0:
+ assert 0, "GEM projection failed: gradient still conflicts with a previous task"
+ j = 0
+ for params in self._network.parameters():
+ if params.grad is not None:
+ stpt = 0 if j == 0 else sum(grad_numels[:j])
+ endpt = sum(grad_numels[: j + 1])
+ params.grad.data.copy_(
+ new_grad[stpt:endpt]
+ .contiguous()
+ .view(params.grad.data.size())
+ )
+ j += 1
+
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
diff --git a/models/icarl.py b/models/icarl.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd400b0b7d2a62ed58603eea5bf58f2a60c9545a
--- /dev/null
+++ b/models/icarl.py
@@ -0,0 +1,205 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import IncrementalNet
+from utils.inc_net import CosineIncrementalNet
+from utils.toolkit import target2onehot, tensor2numpy
+
+EPSILON = 1e-8
+
+init_epoch = 100
+init_lr = 0.1
+init_milestones = [40, 60, 80]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 80
+lrate = 0.1
+milestones = [40, 60]
+lrate_decay = 0.1
+batch_size = 32
+weight_decay = 2e-4
+num_workers = 8
+T = 2
+
+
+class iCaRL(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = IncrementalNet(args, False)
+
+ def after_task(self):
+ self._old_network = self._network.copy().freeze()
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._old_network is not None:
+ self._old_network.to(self._device)
+
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ ) # 1e-5
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss_clf = F.cross_entropy(logits, targets)
+ loss_kd = _KD_loss(
+ logits[:, : self._known_classes],
+ self._old_network(inputs)["logits"],
+ T,
+ )
+
+ loss = loss_clf + loss_kd
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+
+
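+# Temperature-scaled knowledge distillation (Hinton et al.): softened
+# teacher probabilities against the student's log-probabilities, averaged
+# over the batch. Note the usual T^2 gradient-rescaling factor is omitted.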
+def _KD_loss(pred, soft, T):
+ pred = torch.log_softmax(pred / T, dim=1)
+ soft = torch.softmax(soft / T, dim=1)
+ return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
diff --git a/models/il2a.py b/models/il2a.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a9ded99ecaf32ad5f51cca7cd691a2f9774021
--- /dev/null
+++ b/models/il2a.py
@@ -0,0 +1,250 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import CosineIncrementalNet, FOSTERNet, IL2ANet, IncrementalNet
+from utils.toolkit import count_parameters, target2onehot, tensor2numpy
+
+EPSILON = 1e-8
+
+
+class IL2A(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+ self._network = IL2ANet(args, False)
+ self._protos = []
+ self._covs = []
+
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+ self._old_network = self._network.copy().freeze()
+ if hasattr(self._old_network,"module"):
+ self.old_network_module_ptr = self._old_network.module
+ else:
+ self.old_network_module_ptr = self._old_network
+ # self.save_checkpoint("{}_{}_{}".format(self.args["model_name"], self.args["init_cls"], self.args["increment"]))
+
+ def incremental_train(self, data_manager):
+ self.data_manager = data_manager
+ self._cur_task += 1
+
+ task_size = self.data_manager.get_task_size(self._cur_task)
+ self._total_classes = self._known_classes + task_size
+ self._network.update_fc(self._known_classes,self._total_classes,int((task_size-1)*task_size/2))
+ self._network_module_ptr = self._network
+ logging.info(
+ 'Learning on {}-{}'.format(self._known_classes, self._total_classes))
+
+
+ logging.info('All params: {}'.format(count_parameters(self._network)))
+ logging.info('Trainable params: {}'.format(
+ count_parameters(self._network, True)))
+
+ train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
+ mode='train', appendent=self._get_memory())
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source='test', mode='test')
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+
+ def _train(self, train_loader, test_loader):
+
+ resume = False
+ if self._cur_task in []:
+ self._network.load_state_dict(torch.load("{}_{}_{}_{}.pkl".format(self.args["model_name"],self.args["init_cls"],self.args["increment"],self._cur_task))["model_state_dict"])
+ resume = True
+ self._network.to(self._device)
+ if hasattr(self._network, "module"):
+ self._network_module_ptr = self._network.module
+ if not resume:
+ self._epoch_num = self.args["epochs"]
+ optimizer = torch.optim.Adam(self._network.parameters(), lr=self.args["lr"], weight_decay=self.args["weight_decay"])
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"])
+ self._train_function(train_loader, test_loader, optimizer, scheduler)
+ self._build_protos()
+
+
+ def _build_protos(self):
+ with torch.no_grad():
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
+ mode='test', ret_data=True)
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ class_mean = np.mean(vectors, axis=0)
+ self._protos.append(class_mean)
+ cov = np.cov(vectors.T)
+ self._covs.append(cov)
+
+ def _train_function(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self._epoch_num))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.
+ losses_clf, losses_fkd, losses_proto = 0., 0., 0.
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True), targets.to(self._device, non_blocking=True)
+ inputs,targets = self._class_aug(inputs,targets)
+ logits, loss_clf, loss_fkd, loss_proto = self._compute_il2a_loss(inputs,targets)
+ loss = loss_clf + loss_fkd + loss_proto
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_clf += loss_clf.item()
+ losses_fkd += loss_fkd.item()
+ losses_proto += loss_proto.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(
+ correct)*100 / total, decimals=2)
+ if epoch % 5 != 0:
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc)
+ else:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc, test_acc)
+ prog_bar.set_description(info)
+ logging.info(info)
+
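+ # IL2A loss = CE over real plus mixed virtual classes, feature
+ # distillation to the frozen old backbone, and CE on replayed
+ # class prototypes.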
+ def _compute_il2a_loss(self,inputs, targets):
+ logits = self._network(inputs)["logits"]
+ loss_clf = F.cross_entropy(logits/self.args["temp"], targets)
+
+ if self._cur_task == 0:
+ return logits, loss_clf, torch.tensor(0.), torch.tensor(0.)
+
+ features = self._network_module_ptr.extract_vector(inputs)
+ features_old = self.old_network_module_ptr.extract_vector(inputs)
+ loss_fkd = self.args["lambda_fkd"] * torch.dist(features, features_old, 2)
+
+ index = np.random.choice(range(self._known_classes),size=self.args["batch_size"],replace=True)
+
+ proto_features = np.array(self._protos)[index]
+ proto_targets = index
+ proto_features = torch.from_numpy(proto_features).float().to(self._device,non_blocking=True)
+ proto_targets = torch.from_numpy(proto_targets).to(self._device,non_blocking=True)
+
+ proto_logits = self._network_module_ptr.fc(proto_features)["logits"][:,:self._total_classes]
+
+
+ proto_logits = self._semantic_aug(proto_logits,proto_targets,self.args["ratio"])
+
+ loss_proto = self.args["lambda_proto"] * F.cross_entropy(proto_logits/self.args["temp"], proto_targets)
+ return logits, loss_clf, loss_fkd, loss_proto
+
+
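+ # Semantic augmentation (ISDA-style): instead of sampling perturbed
+ # features, add the closed-form correction
+ # ratio/2 * (w_c - w_y)^T Cov_y (w_c - w_y) to each logit, using the
+ # stored per-class feature covariances.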
+ def _semantic_aug(self,proto_logits,proto_targets,ratio):
+ # weight_fc = self._network_module_ptr.fc.weight.data[:self._total_classes] # don't use it ! data is not involved in back propagation
+ weight_fc = self._network_module_ptr.fc.weight[:self._total_classes]
+ N,C,D = self.args["batch_size"], self._total_classes, weight_fc.shape[1]
+
+ N_weight = weight_fc.expand(N,C,D) # NCD
+ N_target_weight = torch.gather(N_weight, 1, proto_targets[:,None,None].expand(N,C,D)) # NCD
+ N_v = N_weight-N_target_weight
+ N_cov = torch.from_numpy(np.array(self._covs))[proto_targets].float().to(self._device) # NDD
+
+ proto_logits = proto_logits + ratio/2* torch.diagonal(N_v @ N_cov @ N_v.permute(0,2,1),dim1=1,dim2=2) # NC
+
+ return proto_logits
+
+
+
+
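+ # classAug: mix pairs of samples from different classes and give each
+ # unordered class pair its own virtual label (see _map_targets);
+ # lambdas are clamped near 0.5 so the mix stays ambiguous between the
+ # two source classes.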
+ def _class_aug(self,inputs,targets,alpha=20., mix_time=4):
+
+ mixup_inputs = []
+ mixup_targets = []
+ for _ in range(mix_time):
+ index = torch.randperm(inputs.shape[0])
+ perm_inputs = inputs[index]
+ perm_targets = targets[index]
+ mask = perm_targets!= targets
+
+ select_inputs = inputs[mask]
+ select_targets = targets[mask]
+ perm_inputs = perm_inputs[mask]
+ perm_targets = perm_targets[mask]
+
+ lams = np.random.beta(alpha,alpha,sum(mask))
+ lams = np.where((lams<0.4)|(lams>0.6),0.5,lams)
+ lams = torch.from_numpy(lams).to(self._device)[:,None,None,None].float()
+
+
+ mixup_inputs.append(lams*select_inputs+(1-lams)*perm_inputs)
+ mixup_targets.append(self._map_targets(select_targets,perm_targets))
+ mixup_inputs = torch.cat(mixup_inputs,dim=0)
+ mixup_targets = torch.cat(mixup_targets,dim=0)
+
+ inputs = torch.cat([inputs,mixup_inputs],dim=0)
+ targets = torch.cat([targets,mixup_targets],dim=0)
+ return inputs,targets
+
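+ # Triangular-number encoding of the unordered pair (small, large)
+ # within the current task: index = large*(large-1)/2 + small, offset
+ # by _total_classes so virtual classes follow the real ones.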
+ def _map_targets(self,select_targets,perm_targets):
+ assert (select_targets != perm_targets).all()
+ large_targets = torch.max(select_targets,perm_targets)-self._known_classes
+ small_targets = torch.min(select_targets,perm_targets)-self._known_classes
+
+ mixup_targets = large_targets * (large_targets - 1) // 2 + small_targets + self._total_classes
+ return mixup_targets
+
+ def _compute_accuracy(self, model, loader):
+ model.eval()
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(loader):
+ inputs = inputs.to(self._device)
+ with torch.no_grad():
+ outputs = model(inputs)["logits"][:,:self._total_classes]
+ predicts = torch.max(outputs, dim=1)[1]
+ correct += (predicts.cpu() == targets).sum()
+ total += len(targets)
+
+ return np.around(tensor2numpy(correct)*100 / total, decimals=2)
+
+ def _eval_cnn(self, loader):
+ self._network.eval()
+ y_pred, y_true = [], []
+ for _, (_, inputs, targets) in enumerate(loader):
+ inputs = inputs.to(self._device)
+ with torch.no_grad():
+ outputs = self._network(inputs)["logits"][:,:self._total_classes]
+ predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]
+ y_pred.append(predicts.cpu().numpy())
+ y_true.append(targets.cpu().numpy())
+
+ return np.concatenate(y_pred), np.concatenate(y_true)
+
+ def eval_task(self, save_conf=False):
+ y_pred, y_true = self._eval_cnn(self.test_loader)
+ cnn_accy = self._evaluate(y_pred, y_true)
+
+ if hasattr(self, '_class_means'):
+ y_pred, y_true = self._eval_nme(self.test_loader, self._class_means)
+ nme_accy = self._evaluate(y_pred, y_true)
+ elif hasattr(self, '_protos'):
+ y_pred, y_true = self._eval_nme(self.test_loader, self._protos/np.linalg.norm(self._protos,axis=1)[:,None])
+ nme_accy = self._evaluate(y_pred, y_true)
+ else:
+ nme_accy = None
+
+ return cnn_accy, nme_accy
diff --git a/models/lwf.py b/models/lwf.py
new file mode 100644
index 0000000000000000000000000000000000000000..e618803cab955b43160ea91e7c631b602105cbb9
--- /dev/null
+++ b/models/lwf.py
@@ -0,0 +1,205 @@
+import logging
+import numpy as np
+import torch
+from torch import nn
+from tqdm import tqdm
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from utils.inc_net import IncrementalNet
+from models.base import BaseLearner
+from utils.toolkit import target2onehot, tensor2numpy
+
+init_epoch = 200
+init_lr = 0.1
+init_milestones = [60, 120, 160]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 250
+lrate = 0.1
+milestones = [60, 120, 180, 220]
+lrate_decay = 0.1
+batch_size = 128
+weight_decay = 2e-4
+num_workers = 8
+T = 2
+lamda = 3
+
+
+class LwF(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = IncrementalNet(args, False)
+
+ def after_task(self):
+ self._old_network = self._network.copy().freeze()
+ self._known_classes = self._total_classes
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._old_network is not None:
+ self._old_network.to(self._device)
+
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
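+ # CE is computed only over the new-class logits, so targets are
+ # shifted into [0, task_size); old classes are handled by KD below.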
+ fake_targets = targets - self._known_classes
+ loss_clf = F.cross_entropy(
+ logits[:, self._known_classes :], fake_targets
+ )
+ loss_kd = _KD_loss(
+ logits[:, : self._known_classes],
+ self._old_network(inputs)["logits"],
+ T,
+ )
+
+ loss = lamda * loss_kd + loss_clf
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ with torch.no_grad():
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+
+
+def _KD_loss(pred, soft, T):
+ pred = torch.log_softmax(pred / T, dim=1)
+ soft = torch.softmax(soft / T, dim=1)
+ return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
diff --git a/models/memo.py b/models/memo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a26e5ce29e79f6a57e4f4c2c737f1d262648a534
--- /dev/null
+++ b/models/memo.py
@@ -0,0 +1,337 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+import copy
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import AdaptiveNet
+from utils.toolkit import count_parameters, target2onehot, tensor2numpy
+
+num_workers=8
+EPSILON = 1e-8
+batch_size = 32
+
+class MEMO(BaseLearner):
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+ self._old_base = None
+ self._network = AdaptiveNet(args, True)
+ logging.info(f'>>> train generalized blocks:{self.args["train_base"]} train_adaptive:{self.args["train_adaptive"]}')
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+ if self._cur_task == 0:
+ if self.args['train_base']:
+ logging.info("Train Generalized Blocks...")
+ self._network.TaskAgnosticExtractor.train()
+ for param in self._network.TaskAgnosticExtractor.parameters():
+ param.requires_grad = True
+ else:
+ logging.info("Fix Generalized Blocks...")
+ self._network.TaskAgnosticExtractor.eval()
+ for param in self._network.TaskAgnosticExtractor.parameters():
+ param.requires_grad = False
+
+ logging.info('Exemplar size: {}'.format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
+ self._network.update_fc(self._total_classes)
+
+ logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes))
+
+ if self._cur_task>0:
+ for i in range(self._cur_task):
+ for p in self._network.AdaptiveExtractors[i].parameters():
+ # `i` only ranges over the old extractors, so the original
+ # `i == self._cur_task` test could never be true and the old
+ # extractors were always frozen; honor train_adaptive here,
+ # matching set_network() below.
+ p.requires_grad = self.args['train_adaptive']
+
+ logging.info('All params: {}'.format(count_parameters(self._network)))
+ logging.info('Trainable params: {}'.format(count_parameters(self._network, True)))
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source='train',
+ mode='train',
+ appendent=self._get_memory()
+ )
+ self.train_loader = DataLoader(
+ train_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=True,
+ num_workers=num_workers
+ )
+
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes),
+ source='test',
+ mode='test'
+ )
+ self.test_loader = DataLoader(
+ test_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=False,
+ num_workers=num_workers
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def set_network(self):
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+ self._network.train() # switch every module from eval to train first
+ if self.args['train_base']:
+ self._network.TaskAgnosticExtractor.train()
+ else:
+ self._network.TaskAgnosticExtractor.eval()
+
+ # set adaptive extractor's status
+ self._network.AdaptiveExtractors[-1].train()
+ if self._cur_task >= 1:
+ for i in range(self._cur_task):
+ if self.args['train_adaptive']:
+ self._network.AdaptiveExtractors[i].train()
+ else:
+ self._network.AdaptiveExtractors[i].eval()
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._cur_task==0:
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ momentum=0.9,
+ lr=self.args["init_lr"],
+ weight_decay=self.args["init_weight_decay"]
+ )
+ if self.args['scheduler'] == 'steplr':
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer,
+ milestones=self.args['init_milestones'],
+ gamma=self.args['init_lr_decay']
+ )
+ elif self.args['scheduler'] == 'cosine':
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer,
+ T_max=self.args['init_epoch']
+ )
+ else:
+ raise NotImplementedError
+
+ if not self.args['skip']:
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ if isinstance(self._network, nn.DataParallel):
+ self._network = self._network.module
+ load_acc = self._network.load_checkpoint(self.args)
+ self._network.to(self._device)
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+
+ cur_test_acc = self._compute_accuracy(self._network, self.test_loader)
+ logging.info(f"Loaded_Test_Acc:{load_acc} Cur_Test_Acc:{cur_test_acc}")
+ else:
+ optimizer = optim.SGD(
+ filter(lambda p: p.requires_grad, self._network.parameters()),
+ lr=self.args['lrate'],
+ momentum=0.9,
+ weight_decay=self.args['weight_decay']
+ )
+ if self.args['scheduler'] == 'steplr':
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer,
+ milestones=self.args['milestones'],
+ gamma=self.args['lrate_decay']
+ )
+ elif self.args['scheduler'] == 'cosine':
+ assert self.args['t_max'] is not None
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer,
+ T_max=self.args['t_max']
+ )
+ else:
+ raise NotImplementedError
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+ if len(self._multiple_gpus) > 1:
+ self._network.module.weight_align(self._total_classes-self._known_classes)
+ else:
+ self._network.weight_align(self._total_classes-self._known_classes)
+
+
+ def _init_train(self,train_loader,test_loader,optimizer,scheduler):
+ prog_bar = tqdm(range(self.args["init_epoch"]))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)['logits']
+
+ loss=F.cross_entropy(logits,targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct)*100 / total, decimals=2)
+ if epoch%5==0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self.args['init_epoch'], losses/len(train_loader), train_acc, test_acc)
+ else:
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self.args['init_epoch'], losses/len(train_loader), train_acc)
+ # prog_bar.set_description(info)
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self.args["epochs"]))
+ for _, epoch in enumerate(prog_bar):
+ self.set_network()
+ losses = 0.
+ losses_clf=0.
+ losses_aux=0.
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+
+ outputs= self._network(inputs)
+ logits,aux_logits=outputs["logits"],outputs["aux_logits"]
+ loss_clf=F.cross_entropy(logits,targets)
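+ # Auxiliary head: map all old classes to label 0 and new classes to
+ # 1..task_size, so the newest extractor learns new-vs-old structure.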
+ aux_targets = targets.clone()
+ aux_targets = torch.where(
+ aux_targets - self._known_classes + 1.0 > 0,
+ aux_targets - self._known_classes + 1.0,
+ torch.Tensor([0.0]).to(self.args["device"][0]),
+ )
+ loss_aux=F.cross_entropy(aux_logits,aux_targets.long())
+ loss=loss_clf+self.args['alpha_aux']*loss_aux
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_aux+=loss_aux.item()
+ losses_clf+=loss_clf.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct)*100 / total, decimals=2)
+ if epoch%5==0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self.args["epochs"], losses/len(train_loader),losses_clf/len(train_loader),losses_aux/len(train_loader),train_acc, test_acc)
+ else:
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self.args["epochs"], losses/len(train_loader), losses_clf/len(train_loader),losses_aux/len(train_loader),train_acc)
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def save_checkpoint(self, test_acc):
+ assert self.args['model_name'] == 'finetune'
+ checkpoint_name = f"checkpoints/finetune_{self.args['csv_name']}"
+ _checkpoint_cpu = copy.deepcopy(self._network)
+ if isinstance(_checkpoint_cpu, nn.DataParallel):
+ _checkpoint_cpu = _checkpoint_cpu.module
+ _checkpoint_cpu.cpu()
+ save_dict = {
+ "tasks": self._cur_task,
+ "convnet": _checkpoint_cpu.convnet.state_dict(),
+ "fc":_checkpoint_cpu.fc.state_dict(),
+ "test_acc": test_acc
+ }
+ torch.save(save_dict, "{}_{}.pkl".format(checkpoint_name, self._cur_task))
+
+ def _construct_exemplar(self, data_manager, m):
+ logging.info("Constructing exemplars...({} per classes)".format(m))
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = data_manager.get_dataset(
+ np.arange(class_idx, class_idx + 1),
+ source="train",
+ mode="test",
+ ret_data=True,
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+
+ # Herding selection (iCaRL): greedily pick exemplars whose running
+ # mean best approximates the true class mean in feature space.
+ selected_exemplars = []
+ exemplar_vectors = [] # [n, feature_dim]
+ for k in range(1, m + 1):
+ S = np.sum(
+ exemplar_vectors, axis=0
+ ) # [feature_dim] sum of the exemplar vectors selected so far
+ mu_p = (vectors + S) / k # [n, feature_dim] candidate means if each vector were added
+ i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
+ selected_exemplars.append(
+ np.array(data[i])
+ ) # New object to avoid passing by reference
+ exemplar_vectors.append(
+ np.array(vectors[i])
+ ) # New object to avoid passing by reference
+
+ vectors = np.delete(
+ vectors, i, axis=0
+ ) # Remove it to avoid duplicative selection
+ data = np.delete(
+ data, i, axis=0
+ ) # Remove it to avoid duplicative selection
+
+ if len(vectors) == 0:
+ break
+ selected_exemplars = np.array(selected_exemplars)
+ exemplar_targets = np.full(selected_exemplars.shape[0], class_idx)
+ self._data_memory = (
+ np.concatenate((self._data_memory, selected_exemplars))
+ if len(self._data_memory) != 0
+ else selected_exemplars
+ )
+ self._targets_memory = (
+ np.concatenate((self._targets_memory, exemplar_targets))
+ if len(self._targets_memory) != 0
+ else exemplar_targets
+ )
+
+ # Exemplar mean
+ idx_dataset = data_manager.get_dataset(
+ [],
+ source="train",
+ mode="test",
+ appendent=(selected_exemplars, exemplar_targets),
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ self._class_means[class_idx, :] = mean
\ No newline at end of file
diff --git a/models/pa2s.py b/models/pa2s.py
new file mode 100644
index 0000000000000000000000000000000000000000..2caa6ca579d9cd537377cd660b8c18605db7bf60
--- /dev/null
+++ b/models/pa2s.py
@@ -0,0 +1,216 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import os
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet
+from utils.toolkit import count_parameters, target2onehot, tensor2numpy
+
+EPSILON = 1e-8
+
+
+class PASS(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+ self._network = IncrementalNet(args, False)
+ self._protos = []
+ self._radius = 0
+ self._radiuses = []
+
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+ self._old_network = self._network.copy().freeze()
+ if hasattr(self._old_network,"module"):
+ self.old_network_module_ptr = self._old_network.module
+ else:
+ self.old_network_module_ptr = self._old_network
+ # self.save_checkpoint("{}_{}_{}".format(self.args["model_name"], self.args["init_cls"], self.args["increment"]))
+
+ def incremental_train(self, data_manager):
+ self.data_manager = data_manager
+ self._cur_task += 1
+
+ self._total_classes = self._known_classes + \
+ data_manager.get_task_size(self._cur_task)
+ self._network.update_fc(self._total_classes*4)
+ self._network_module_ptr = self._network
+ logging.info(
+ 'Learning on {}-{}'.format(self._known_classes, self._total_classes))
+
+
+ logging.info('All params: {}'.format(count_parameters(self._network)))
+ logging.info('Trainable params: {}'.format(
+ count_parameters(self._network, True)))
+
+ train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
+ mode='train', appendent=self._get_memory())
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source='test', mode='test')
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+
+ def _train(self, train_loader, test_loader):
+
+ resume = False
+ if self._cur_task in []:
+ self._network.load_state_dict(torch.load("{}_{}_{}_{}.pkl".format(self.args["model_name"],self.args["init_cls"],self.args["increment"],self._cur_task))["model_state_dict"])
+ resume = True
+ self._network.to(self._device)
+ if hasattr(self._network, "module"):
+ self._network_module_ptr = self._network.module
+ if not resume:
+ self._epoch_num = self.args["epochs"]
+ optimizer = torch.optim.Adam(self._network.parameters(), lr=self.args["lr"], weight_decay=self.args["weight_decay"])
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"])
+ self._train_function(train_loader, test_loader, optimizer, scheduler)
+ self._build_protos()
+
+
+ def _build_protos(self):
+ with torch.no_grad():
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
+ mode='test', ret_data=True)
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ class_mean = np.mean(vectors, axis=0)
+ self._protos.append(class_mean)
+ cov = np.cov(vectors.T)
+ self._radiuses.append(np.trace(cov)/vectors.shape[1])
+ self._radius = np.sqrt(np.mean(self._radiuses))
+
+ def _train_function(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self._epoch_num))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.
+ losses_clf, losses_fkd, losses_proto = 0., 0., 0.
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True), targets.to(self._device, non_blocking=True)
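+ # Self-supervised label augmentation (PASS): rotate each image by
+ # 0/90/180/270 degrees and give every (class, rotation) pair its own
+ # label, hence the 4x classifier head. The 320x320 reshape assumes
+ # this fork's fixed input resolution.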
+ inputs = torch.stack([torch.rot90(inputs, k, (2, 3)) for k in range(4)], 1)
+ inputs = inputs.view(-1, 3, 320, 320)
+ targets = torch.stack([targets * 4 + k for k in range(4)], 1).view(-1)
+ logits, loss_clf, loss_fkd, loss_proto = self._compute_pass_loss(inputs,targets)
+ loss = loss_clf + loss_fkd + loss_proto
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_clf += loss_clf.item()
+ losses_fkd += loss_fkd.item()
+ losses_proto += loss_proto.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(
+ correct)*100 / total, decimals=2)
+ if epoch % 5 != 0:
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc)
+ else:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc, test_acc)
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def _compute_pass_loss(self,inputs, targets):
+ logits = self._network(inputs)["logits"]
+ loss_clf = F.cross_entropy(logits/self.args["temp"], targets)
+
+ if self._cur_task == 0:
+ return logits, loss_clf, torch.tensor(0.), torch.tensor(0.)
+
+ features = self._network_module_ptr.extract_vector(inputs)
+ features_old = self.old_network_module_ptr.extract_vector(inputs)
+ loss_fkd = self.args["lambda_fkd"] * torch.dist(features, features_old, 2)
+
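+ # Prototype replay: oversample old-class means in proportion to the
+ # number of old classes, jitter them with Gaussian noise scaled by
+ # the average within-class radius, and train the rotation-0 head on
+ # the result.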
+ index = np.random.choice(range(self._known_classes),size=self.args["batch_size"]*int(self._known_classes/(self._total_classes-self._known_classes)),replace=True)
+ proto_features = np.array(self._protos)[index]
+ proto_targets = 4*index
+ proto_features = proto_features + np.random.normal(0,1,proto_features.shape)*self._radius
+ proto_features = torch.from_numpy(proto_features).float().to(self._device,non_blocking=True)
+ proto_targets = torch.from_numpy(proto_targets).to(self._device,non_blocking=True)
+
+
+ proto_logits = self._network_module_ptr.fc(proto_features)["logits"]
+ loss_proto = self.args["lambda_proto"] * F.cross_entropy(proto_logits/self.args["temp"], proto_targets)
+ return logits, loss_clf, loss_fkd, loss_proto
+
+
+
+ def _compute_accuracy(self, model, loader):
+ model.eval()
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(loader):
+ inputs = inputs.to(self._device)
+ with torch.no_grad():
+ outputs = model(inputs)["logits"][:,::4]
+ predicts = torch.max(outputs, dim=1)[1]
+ correct += (predicts.cpu() == targets).sum()
+ total += len(targets)
+
+ return np.around(tensor2numpy(correct)*100 / total, decimals=2)
+
+ def _eval_cnn(self, loader):
+ self._network.eval()
+ y_pred, y_true = [], []
+ for _, (_, inputs, targets) in enumerate(loader):
+ inputs = inputs.to(self._device)
+ with torch.no_grad():
+ outputs = self._network(inputs)["logits"][:,::4]
+ predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]
+ y_pred.append(predicts.cpu().numpy())
+ y_true.append(targets.cpu().numpy())
+
+ return np.concatenate(y_pred), np.concatenate(y_true)
+
+ def eval_task(self, save_conf=True):
+ y_pred, y_true = self._eval_cnn(self.test_loader)
+ cnn_accy = self._evaluate(y_pred, y_true)
+
+ if hasattr(self, '_class_means'):
+ y_pred, y_true = self._eval_nme(self.test_loader, self._class_means)
+ nme_accy = self._evaluate(y_pred, y_true)
+ elif hasattr(self, '_protos'):
+ y_pred, y_true = self._eval_nme(self.test_loader, self._protos/np.linalg.norm(self._protos,axis=1)[:,None])
+ nme_accy = self._evaluate(y_pred, y_true)
+ else:
+ nme_accy = None
+ if save_conf:
+ _pred = y_pred.T[0]
+ _pred_path = os.path.join(self.args['logfilename'], "pred.npy")
+ _target_path = os.path.join(self.args['logfilename'], "target.npy")
+ np.save(_pred_path, _pred)
+ np.save(_target_path, y_true)
+
+ _save_dir = os.path.join(f"./results/{self.args['model_name']}/conf_matrix/{self.args['prefix']}")
+ os.makedirs(_save_dir, exist_ok=True)
+ _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv")
+ with open(_save_path, "a+") as f:
+ f.write(f"{self.args['model_name']},{_pred_path},{_target_path} \n")
+ return cnn_accy, nme_accy
\ No newline at end of file
diff --git a/models/podnet.py b/models/podnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..847090b5bfca6acc144fb45492b9dc69050b7995
--- /dev/null
+++ b/models/podnet.py
@@ -0,0 +1,324 @@
+import math
+import logging
+import numpy as np
+import torch
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import CosineIncrementalNet
+from utils.toolkit import tensor2numpy
+
+epochs = 100
+lrate = 0.1
+ft_epochs = 20
+ft_lrate = 0.005
+batch_size = 32
+lambda_c_base = 5
+lambda_f_base = 1
+nb_proxy = 10
+weight_decay = 5e-4
+num_workers = 4
+
+"""
+Distillation losses: POD-flat (lambda_f=1) + POD-spatial (lambda_c=5)
+NME results are shown.
+The reproduced results are not in line with the reported results.
+Maybe I missed something...
++--------------------+--------------------+--------------------+--------------------+
+| Classifier | Steps | Reported (%) | Reproduced (%) |
++--------------------+--------------------+--------------------+--------------------+
+| Cosine (k=1) | 50 | 56.69 | 55.49 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-CE (k=10) | 50 | 59.86 | 55.69 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-NCA (k=10) | 50 | 61.40 | 56.50 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-CE (k=10) | 25 | ----- | 59.16 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-NCA (k=10) | 25 | 62.71 | 59.79 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-CE (k=10) | 10 | ----- | 62.59 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-NCA (k=10) | 10 | 64.03 | 62.81 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-CE (k=10) | 5 | ----- | 64.16 |
++--------------------+--------------------+--------------------+--------------------+
+| LSC-NCA (k=10) | 5 | 64.48 | 64.37 |
++--------------------+--------------------+--------------------+--------------------+
+"""
+
+
+class PODNet(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = CosineIncrementalNet(
+ args, pretrained=False, nb_proxy=nb_proxy
+ )
+ self._class_means = None
+
+ def after_task(self):
+ self._old_network = self._network.copy().freeze()
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self.task_size = self._total_classes - self._known_classes
+ self._network.update_fc(self._total_classes, self._cur_task)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ train_dset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ test_dset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.train_loader = DataLoader(
+ train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ self.test_loader = DataLoader(
+ test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ self._train(data_manager, self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+
+ def _train(self, data_manager, train_loader, test_loader):
+ if self._cur_task == 0:
+ self.factor = 0
+ else:
+ self.factor = math.sqrt(
+ self._total_classes / (self._total_classes - self._known_classes)
+ )
+ logging.info("Adaptive factor: {}".format(self.factor))
+
+ self._network.to(self._device)
+ if self._old_network is not None:
+ self._old_network.to(self._device)
+
+ if self._cur_task == 0:
+ network_params = self._network.parameters()
+ else:
+ ignored_params = list(map(id, self._network.fc.fc1.parameters()))
+ base_params = filter(
+ lambda p: id(p) not in ignored_params, self._network.parameters()
+ )
+ network_params = [
+ {"params": base_params, "lr": lrate, "weight_decay": weight_decay},
+ {
+ "params": self._network.fc.fc1.parameters(),
+ "lr": 0,
+ "weight_decay": 0,
+ },
+ ]
+ optimizer = optim.SGD(
+ network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=epochs
+ )
+ self._run(train_loader, test_loader, optimizer, scheduler, epochs)
+
+ if self._cur_task == 0:
+ return
+ logging.info(
+ "Finetune the network (classifier part) with the undersampled dataset!"
+ )
+ if self._fixed_memory:
+ finetune_samples_per_class = self._memory_per_class
+ self._construct_exemplar_unified(data_manager, finetune_samples_per_class)
+ else:
+ finetune_samples_per_class = self._memory_size // self._known_classes
+ self._reduce_exemplar(data_manager, finetune_samples_per_class)
+ self._construct_exemplar(data_manager, finetune_samples_per_class)
+
+ finetune_train_dataset = data_manager.get_dataset(
+ [], source="train", mode="train", appendent=self._get_memory()
+ )
+ finetune_train_loader = DataLoader(
+ finetune_train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ )
+ logging.info(
+ "The size of finetune dataset: {}".format(len(finetune_train_dataset))
+ )
+
+ ignored_params = list(map(id, self._network.fc.fc1.parameters()))
+ base_params = filter(
+ lambda p: id(p) not in ignored_params, self._network.parameters()
+ )
+ network_params = [
+ {"params": base_params, "lr": ft_lrate, "weight_decay": weight_decay},
+ {"params": self._network.fc.fc1.parameters(), "lr": 0, "weight_decay": 0},
+ ]
+ optimizer = optim.SGD(
+ network_params, lr=ft_lrate, momentum=0.9, weight_decay=weight_decay
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=ft_epochs
+ )
+ self._run(finetune_train_loader, test_loader, optimizer, scheduler, ft_epochs)
+
+ if self._fixed_memory:
+ self._data_memory = self._data_memory[
+ : -self._memory_per_class * self.task_size
+ ]
+ self._targets_memory = self._targets_memory[
+ : -self._memory_per_class * self.task_size
+ ]
+ assert (
+ len(
+ np.setdiff1d(
+ self._targets_memory, np.arange(0, self._known_classes)
+ )
+ )
+ == 0
+ ), "Exemplar error!"
+
+ def _run(self, train_loader, test_loader, optimizer, scheduler, epk):
+ for epoch in range(1, epk + 1):
+ self._network.train()
+ lsc_losses = 0.0
+ spatial_losses = 0.0
+ flat_losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ outputs = self._network(inputs)
+ logits = outputs["logits"]
+ features = outputs["features"]
+ fmaps = outputs["fmaps"]
+ lsc_loss = nca(logits, targets)
+
+ spatial_loss = 0.0
+ flat_loss = 0.0
+ if self._old_network is not None:
+ with torch.no_grad():
+ old_outputs = self._old_network(inputs)
+ old_features = old_outputs["features"]
+ old_fmaps = old_outputs["fmaps"]
+ flat_loss = (
+ F.cosine_embedding_loss(
+ features,
+ old_features.detach(),
+ torch.ones(inputs.shape[0]).to(self._device),
+ )
+ * self.factor
+ * lambda_f_base
+ )
+ spatial_loss = (
+ pod_spatial_loss(old_fmaps, fmaps) * self.factor * lambda_c_base
+ )
+
+ loss = lsc_loss + flat_loss + spatial_loss
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ lsc_losses += lsc_loss.item()
+ spatial_losses += (
+ spatial_loss.item() if self._cur_task != 0 else spatial_loss
+ )
+ flat_losses += flat_loss.item() if self._cur_task != 0 else flat_loss
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ if scheduler is not None:
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info1 = "Task {}, Epoch {}/{} (LR {:.5f}) => ".format(
+ self._cur_task, epoch, epk, optimizer.param_groups[0]["lr"]
+ )
+ info2 = "LSC_loss {:.2f}, Spatial_loss {:.2f}, Flat_loss {:.2f}, Train_acc {:.2f}, Test_acc {:.2f}".format(
+ lsc_losses / (i + 1),
+ spatial_losses / (i + 1),
+ flat_losses / (i + 1),
+ train_acc,
+ test_acc,
+ )
+ logging.info(info1 + info2)
+
+
+def pod_spatial_loss(old_fmaps, fmaps, normalize=True):
+ """
+ a, b: list of [bs, c, w, h]
+ """
+ loss = torch.tensor(0.0).to(fmaps[0].device)
+ for i, (a, b) in enumerate(zip(old_fmaps, fmaps)):
+ assert a.shape == b.shape, "Shape error"
+
+ a = torch.pow(a, 2)
+ b = torch.pow(b, 2)
+
+ a_h = a.sum(dim=3).view(a.shape[0], -1) # [bs, c*w]
+ b_h = b.sum(dim=3).view(b.shape[0], -1) # [bs, c*w]
+ a_w = a.sum(dim=2).view(a.shape[0], -1) # [bs, c*h]
+ b_w = b.sum(dim=2).view(b.shape[0], -1) # [bs, c*h]
+
+ a = torch.cat([a_h, a_w], dim=-1)
+ b = torch.cat([b_h, b_w], dim=-1)
+
+ if normalize:
+ a = F.normalize(a, dim=1, p=2)
+ b = F.normalize(b, dim=1, p=2)
+
+ layer_loss = torch.mean(torch.frobenius_norm(a - b, dim=-1))
+ loss += layer_loss
+
+ return loss / len(fmaps)
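+
+# Illustrative shape check (hypothetical tensors, not part of the pipeline):
+#   >>> a = [torch.rand(2, 8, 4, 4), torch.rand(2, 16, 2, 2)]
+#   >>> b = [torch.rand(2, 8, 4, 4), torch.rand(2, 16, 2, 2)]
+#   >>> pod_spatial_loss(a, b).shape
+#   torch.Size([])
+# Each map is squared, pooled along height and width separately, concatenated,
+# optionally L2-normalized, and compared with a per-sample Frobenius norm.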
+
+
+def nca(
+ similarities,
+ targets,
+ class_weights=None,
+ focal_gamma=None,
+ scale=1.0,
+ margin=0.6,
+ exclude_pos_denominator=True,
+ hinge_proxynca=False,
+ memory_flags=None,
+):
+ margins = torch.zeros_like(similarities)
+ margins[torch.arange(margins.shape[0]), targets] = margin
+ similarities = scale * (similarities - margins) # subtract the margin on the target logit only
+
+ if exclude_pos_denominator:
+ similarities = similarities - similarities.max(1)[0].view(-1, 1)
+
+ disable_pos = torch.zeros_like(similarities)
+ disable_pos[torch.arange(len(similarities)), targets] = similarities[
+ torch.arange(len(similarities)), targets
+ ]
+
+ numerator = similarities[torch.arange(similarities.shape[0]), targets]
+ denominator = similarities - disable_pos
+
+ losses = numerator - torch.log(torch.exp(denominator).sum(-1))
+ if class_weights is not None:
+ losses = class_weights[targets] * losses
+
+ losses = -losses
+ if hinge_proxynca:
+ losses = torch.clamp(losses, min=0.0)
+
+ loss = torch.mean(losses)
+ return loss
+
+ return F.cross_entropy(
+ similarities, targets, weight=class_weights, reduction="mean"
+ )
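+
+# In the default exclude_pos_denominator branch above, the per-sample loss is
+# (a sketch of the code path, with s the scaled, margin-shifted similarities,
+# numerically stabilized by first subtracting the row max):
+#   L_i = -[ s_{i, y_i} - log(1 + sum_{c != y_i} exp(s_{i, c})) ]
+# Note the target logit is zeroed rather than dropped, so it contributes
+# exp(0) = 1 to the denominator.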
diff --git a/models/replay.py b/models/replay.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee3e9cf9f88bf21147e569b3a36764fbb8afb7ce
--- /dev/null
+++ b/models/replay.py
@@ -0,0 +1,193 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import IncrementalNet
+from utils.toolkit import target2onehot, tensor2numpy
+
+EPSILON = 1e-8
+
+
+init_epoch = 100
+init_lr = 0.1
+init_milestones = [40, 60, 80]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 70
+lrate = 0.1
+milestones = [30, 50]
+lrate_decay = 0.1
+batch_size = 32
+weight_decay = 2e-4
+num_workers = 8
+T = 2 # distillation temperature; unused by Replay, which trains with plain cross-entropy
+
+
+class Replay(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = IncrementalNet(args, False)
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ # Loader
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ # Procedure
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ ) # 1e-5
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss_clf = F.cross_entropy(logits, targets)
+ loss = loss_clf
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ # acc
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
diff --git a/models/rmm.py b/models/rmm.py
new file mode 100644
index 0000000000000000000000000000000000000000..bacddb6628b2db13dedc853adf066b873bb0afed
--- /dev/null
+++ b/models/rmm.py
@@ -0,0 +1,285 @@
+import copy
+import logging
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from models.foster import FOSTER
+from utils.toolkit import count_parameters, tensor2numpy, accuracy
+from utils.inc_net import IncrementalNet
+from scipy.spatial.distance import cdist
+from models.base import BaseLearner
+from models.icarl import iCaRL
+from tqdm import tqdm
+import torch.optim as optim
+
+
+EPSILON = 1e-8
+batch_size = 32
+weight_decay = 2e-4
+num_workers = 8
+
+
+class RMMBase(BaseLearner):
+ def __init__(self, args):
+ self._args = args
+ self._m_rate_list = args.get("m_rate_list", [])
+ self._c_rate_list = args.get("c_rate_list", [])
+
+ @property
+ def samples_per_class(self):
+ return int(self.memory_size // self._total_classes)
+
+ @property
+ def memory_size(self):
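+ # RMM's dynamic memory budget: the base budget grows by
+ # m_rate * increment * img_per_cls extra slots for the current task,
+ # capped at one image short of the full training set seen so far.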
+ if self._args["dataset"] == "cifar100":
+ img_per_cls = 500
+ else:
+ img_per_cls = 1300
+
+ if self._m_rate_list[self._cur_task] != 0:
+ self._memory_size = min(
+ int(self._total_classes * img_per_cls - 1),
+ self._args["memory_size"]
+ + int(self._m_rate_list[self._cur_task] * self._args["increment"] * img_per_cls),
+ )
+ return self._memory_size
+
+ @property
+ def new_memory_size(self):
+ if self._args["dataset"] == "cifar100":
+ img_per_cls = 500
+ else:
+ img_per_cls = 1300
+ return int(
+ (1 - self._m_rate_list[self._cur_task])
+ * self._args["increment"]
+ * img_per_cls
+ )
+
+ def build_rehearsal_memory(self, data_manager, per_class):
+ self._reduce_exemplar(data_manager, per_class)
+ self._construct_exemplar(data_manager, per_class)
+
+ def _construct_exemplar(self, data_manager, m):
+ if self._args["dataset"] == "cifar100":
+ img_per_cls = 500
+ else:
+ img_per_cls = 1300
+ ns = [
+ min(img_per_cls, int(m * (1 - self._c_rate_list[self._cur_task]))),
+ min(img_per_cls, int(m * (1 + self._c_rate_list[self._cur_task]))),
+ ]
+ logging.info(
+ "Constructing exemplars...({} or {} per classes)".format(ns[0], ns[1])
+ )
+
+ all_cls_entropies = []
+ ms = []
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = data_manager.get_dataset(
+ np.arange(class_idx, class_idx + 1),
+ source="train",
+ mode="test",
+ ret_data=True,
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ with torch.no_grad():
+ cidx_cls_entropies = []
+ for idx, (_, inputs, targets) in enumerate(idx_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+ cross_entropy = (
+ F.cross_entropy(logits, targets, reduction="none")
+ .detach()
+ .cpu()
+ .numpy()
+ )
+ cidx_cls_entropies.append(cross_entropy)
+ cidx_cls_entropies = np.mean(np.concatenate(cidx_cls_entropies))
+ all_cls_entropies.append(cidx_cls_entropies)
+ entropy_median = np.median(all_cls_entropies)
+ for the_entropy in all_cls_entropies:
+ if the_entropy > entropy_median:
+ ms.append(ns[0])
+ else:
+ ms.append(ns[1])
+
+ logging.info(f"ms: {ms}")
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = data_manager.get_dataset(
+ np.arange(class_idx, class_idx + 1),
+ source="train",
+ mode="test",
+ ret_data=True,
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ class_mean = np.mean(vectors, axis=0)
+ # Select
+ selected_exemplars = []
+ exemplar_vectors = [] # [n, feature_dim]
+ for k in range(1, ms[class_idx - self._known_classes] + 1):
+ S = np.sum(
+ exemplar_vectors, axis=0
+ ) # [feature_dim] sum of selected exemplars vectors
+ mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors
+ i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
+ selected_exemplars.append(
+ np.array(data[i])
+ ) # New object to avoid passing by reference
+ exemplar_vectors.append(
+ np.array(vectors[i])
+ ) # New object to avoid passing by reference
+
+ vectors = np.delete(
+ vectors, i, axis=0
+ ) # Remove it to avoid duplicative selection
+ data = np.delete(
+ data, i, axis=0
+ ) # Remove it to avoid duplicative selection
+
+ selected_exemplars = np.array(selected_exemplars)
+ exemplar_targets = np.full(ms[class_idx - self._known_classes], class_idx)
+ self._data_memory = (
+ np.concatenate((self._data_memory, selected_exemplars))
+ if len(self._data_memory) != 0
+ else selected_exemplars
+ )
+ self._targets_memory = (
+ np.concatenate((self._targets_memory, exemplar_targets))
+ if len(self._targets_memory) != 0
+ else exemplar_targets
+ )
+
+ # Exemplar mean
+ idx_dataset = data_manager.get_dataset(
+ [],
+ source="train",
+ mode="test",
+ appendent=(selected_exemplars, exemplar_targets),
+ )
+ idx_loader = DataLoader(
+ idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+ )
+ vectors, _ = self._extract_vectors(idx_loader)
+ vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+ mean = np.mean(vectors, axis=0)
+ mean = mean / np.linalg.norm(mean)
+
+ self._class_means[class_idx, :] = mean
+
+
+class RMM_iCaRL(
+ RMMBase, iCaRL
+): # RMMBase must come before the original method in the MRO.
+ def __init__(self, args):
+ RMMBase.__init__(self, args)
+ iCaRL.__init__(self, args)
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
+ )
+ self.train_loader = DataLoader(
+ train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+
+class RMM_FOSTER(RMMBase, FOSTER):
+ def __init__(self, args):
+ RMMBase.__init__(self, args)
+ FOSTER.__init__(self, args)
+
+ def incremental_train(self, data_manager):
+ self.data_manager = data_manager
+ self._cur_task += 1
+ if self._cur_task > 1:
+ self._network = self._snet
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ self._network_module_ptr = self._network
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ if self._cur_task > 0:
+ for p in self._network.convnets[0].parameters():
+ p.requires_grad = False
+ for p in self._network.oldfc.parameters():
+ p.requires_grad = False
+
+ logging.info("All params: {}".format(count_parameters(self._network)))
+ logging.info(
+ "Trainable params: {}".format(count_parameters(self._network, True))
+ )
+
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
+ )
+ self.train_loader = DataLoader(
+ train_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=True,
+ num_workers=self.args["num_workers"],
+ pin_memory=True,
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset,
+ batch_size=self.args["batch_size"],
+ shuffle=False,
+ num_workers=self.args["num_workers"],
+ )
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
diff --git a/models/simplecil.py b/models/simplecil.py
new file mode 100644
index 0000000000000000000000000000000000000000..f62cb40ef17fcda62952bc714f734b8c2cfed791
--- /dev/null
+++ b/models/simplecil.py
@@ -0,0 +1,175 @@
+'''
+Re-implementation of SimpleCIL (https://arxiv.org/abs/2303.07338) without pre-trained weights.
+The training process is as follows: train the model with cross-entropy in the first stage and replace the classifier with prototypes for all the classes in the subsequent stages.
+Please refer to the original implementation (https://github.com/zhoudw-zdw/RevisitingCIL) if you are using pre-trained weights.
+'''
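+
+# A minimal sketch of the prototype step implemented by replace_fc below
+# (illustrative only; `features` and `labels` are placeholder names):
+#   for c in classes_seen_this_task:
+#       fc.weight.data[c] = features[labels == c].mean(0)
+# i.e. each class weight becomes the mean embedding (prototype) of that class.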
+import logging
+import numpy as np
+import torch
+from torch import nn
+from torch.serialization import load
+from tqdm import tqdm
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from utils.inc_net import SimpleCosineIncrementalNet
+from models.base import BaseLearner
+from utils.toolkit import target2onehot, tensor2numpy
+
+
+num_workers = 8
+batch_size = 32
+milestones = [40, 80]
+
+class SimpleCIL(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = SimpleCosineIncrementalNet(args, False)
+ self.min_lr = args['min_lr'] if args['min_lr'] is not None else 1e-8
+ self.args = args
+
+ def load_checkpoint(self, filename):
+ checkpoint = torch.load(filename)
+ self._total_classes = len(checkpoint["classes"])
+ self.class_list = np.array(checkpoint["classes"])
+ self.label_list = checkpoint["label_list"]
+ print("Class list: ", self.class_list)
+ self._network.update_fc(self._total_classes)
+ self._network.load_checkpoint(checkpoint["network"])
+ self._network.to(self._device)
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+
+ def save_checkpoint(self, filename):
+ self._network.cpu()
+ save_dict = {
+ "classes": self.data_manager.get_class_list(self._cur_task),
+ "network": {
+ "convnet": self._network.convnet.state_dict(),
+ "fc": self._network.fc.state_dict()
+ },
+ "label_list": self.data_manager.get_label_list(self._cur_task),
+ }
+ torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task))
+
+ def replace_fc(self,trainloader, model, args):
+ model = model.eval()
+ embedding_list = []
+ label_list = []
+ with torch.no_grad():
+ for i, batch in enumerate(trainloader):
+ (_, data, label) = batch
+ data = data.to(self._device) # use the configured device so CPU-only runs also work
+ label = label.to(self._device)
+ embedding = model(data)["features"]
+ embedding_list.append(embedding.cpu())
+ label_list.append(label.cpu())
+ embedding_list = torch.cat(embedding_list, dim=0)
+ label_list = torch.cat(label_list, dim=0)
+
+ class_list = np.unique(self.train_dataset.labels)
+ proto_list = []
+ for class_index in class_list:
+ data_index = torch.nonzero(label_list == class_index).squeeze(-1)
+ embedding = embedding_list[data_index]
+ proto = embedding.mean(0)
+ if len(self._multiple_gpus) > 1:
+ self._network.module.fc.weight.data[class_index] = proto
+ else:
+ self._network.fc.weight.data[class_index] = proto
+ return model
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
+ self._network.update_fc(self._total_classes)
+ logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
+ self.class_list = np.array(data_manager.get_class_list(self._cur_task))
+ train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="train")
+ self.train_dataset = train_dataset
+ self.data_manager = data_manager
+ self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+ test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
+ self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+ train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="test")
+ self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+
+ if len(self._multiple_gpus) > 1:
+ print('Multiple GPUs')
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def _train(self, train_loader, test_loader, train_loader_for_protonet):
+ self._network.to(self._device)
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=self.args["init_lr"],
+ weight_decay=self.args["init_weight_decay"]
+ )
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer, T_max=self.args['init_epoch'], eta_min=self.min_lr
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ self.replace_fc(train_loader_for_protonet, self._network, None)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self.args["init_epoch"]))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args['init_epoch'],
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ self.args['init_epoch'],
+ losses / len(train_loader),
+ train_acc,
+ )
+ elapsed = prog_bar.format_dict["elapsed"]
+ rate = prog_bar.format_dict["rate"]
+ remaining = (prog_bar.total - prog_bar.n) / rate if rate and prog_bar.total else 0 # Seconds*
+ prog_bar.set_description(info)
+ logging.info("Working on task {}: {:.2f}:{:.2f}".format(
+ self._cur_task,
+ elapsed,
+ remaining))
+ logging.info(info)
+ logging.info("Finised on task {}: {:.2f}".format(
+ self._cur_task, elapsed))
+
+
diff --git a/models/ssre.py b/models/ssre.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cd93f752530f9c4581190b3d9df93b803629ca
--- /dev/null
+++ b/models/ssre.py
@@ -0,0 +1,253 @@
+import logging
+import numpy as np
+import os
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader,Dataset
+from models.base import BaseLearner
+from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet
+from utils.toolkit import count_parameters, target2onehot, tensor2numpy
+from utils.autoaugment import CIFAR10Policy,ImageNetPolicy
+from utils.ops import Cutout
+from torchvision import datasets, transforms
+
+EPSILON = 1e-8
+
+
+class SSRE(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+ self._network = IncrementalNet(args, False)
+ self._protos = []
+
+ def after_task(self):
+ self._known_classes = self._total_classes
+ self._old_network = self._network.copy().freeze()
+ if hasattr(self._old_network,"module"):
+ self.old_network_module_ptr = self._old_network.module
+ else:
+ self.old_network_module_ptr = self._old_network
+ # self.save_checkpoint("{}_{}_{}".format(self.args["model_name"], self.args["init_cls"], self.args["increment"]))
+
+ def incremental_train(self, data_manager):
+ self.data_manager = data_manager
+ if self._cur_task == 0:
+ self.data_manager._train_trsf = [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(brightness=63/255),
+ CIFAR10Policy(),
+ Cutout(n_holes=1, length=16)
+ ]
+ else:
+ self.data_manager._train_trsf = [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(brightness=63/255),
+ ]
+ self._cur_task += 1
+ self._total_classes = self._known_classes + \
+ data_manager.get_task_size(self._cur_task)
+ self._network.update_fc(self._total_classes)
+ self._network_module_ptr = self._network
+
+ logging.info("Model Expansion!")
+ self._network_expansion()
+
+ logging.info(
+ 'Learning on {}-{}'.format(self._known_classes, self._total_classes))
+
+ logging.info('All params: {}'.format(count_parameters(self._network)))
+ logging.info('Trainable params: {}'.format(
+ count_parameters(self._network, True)))
+
+ train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',mode='train', appendent=self._get_memory())
+ if self._cur_task == 0:
+ batch_size = 64
+ else:
+ batch_size = self.args["batch_size"]
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source='test', mode='test')
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])
+
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+
+ self._train(self.train_loader, self.test_loader)
+
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ logging.info("Model Compression!")
+
+ self._network_compression()
+
+ def _train(self, train_loader, test_loader):
+ resume = False
+ if self._cur_task in []: # list the task ids whose checkpoints should be loaded instead of retrained
+ self._network.load_state_dict(torch.load("{}_{}_{}_{}.pkl".format(self.args["model_name"],self.args["init_cls"],self.args["increment"],self._cur_task))["model_state_dict"])
+ resume = True
+ self._network.to(self._device)
+ if hasattr(self._network, "module"):
+ self._network_module_ptr = self._network.module
+ if not resume:
+ self._epoch_num = self.args["epochs"]
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self._network.parameters(
+ )), lr=self.args["lr"], weight_decay=self.args["weight_decay"])
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"])
+ self._train_function(train_loader, test_loader, optimizer, scheduler)
+ self._build_protos()
+
+
+ def _build_protos(self):
+ with torch.no_grad():
+ for class_idx in range(self._known_classes, self._total_classes):
+ data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
+ mode='test', ret_data=True)
+ idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4)
+ vectors, _ = self._extract_vectors(idx_loader)
+ class_mean = np.mean(vectors, axis=0)
+ self._protos.append(class_mean)
+
+ def train(self):
+ if self._cur_task > 0:
+ self._network.eval()
+ return
+ self._network.train()
+
+ def _train_function(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(self._epoch_num))
+ for _, epoch in enumerate(prog_bar):
+ self.train()
+ losses = 0.
+ losses_clf, losses_fkd, losses_proto = 0., 0., 0.
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(
+ self._device, non_blocking=True), targets.to(self._device, non_blocking=True)
+ logits, loss_clf, loss_fkd, loss_proto = self._compute_ssre_loss(inputs,targets)
+ loss = loss_clf + loss_fkd + loss_proto
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+ losses_clf += loss_clf.item()
+ losses_fkd += loss_fkd.item()
+ losses_proto += loss_proto.item()
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(
+ correct)*100 / total, decimals=2)
+ if epoch % 5 != 0:
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc)
+ else:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
+ self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc, test_acc)
+ prog_bar.set_description(info)
+ logging.info(info)
+
+ def _compute_ssre_loss(self,inputs, targets):
+ if self._cur_task == 0:
+ logits = self._network(inputs)["logits"]
+ loss_clf = F.cross_entropy(logits/self.args["temp"], targets)
+ return logits, loss_clf, torch.tensor(0.), torch.tensor(0.)
+
+ features = self._network_module_ptr.extract_vector(inputs) # N D
+
+ with torch.no_grad():
+ features_old = self.old_network_module_ptr.extract_vector(inputs)
+
+ protos = torch.from_numpy(np.array(self._protos)).to(self._device) # C D
+ with torch.no_grad():
+ weights = F.normalize(features,p=2,dim=1,eps=1e-12) @ F.normalize(protos,p=2,dim=1,eps=1e-12).T
+ weights = torch.max(weights,dim=1)[0]
+ # mask = weights > self.args["threshold"]
+ mask = weights
+ logits = self._network(inputs)["logits"]
+ loss_clf = F.cross_entropy(logits/self.args["temp"],targets,reduction="none")
+ # loss_clf = torch.mean(loss_clf * ~mask)
+ loss_clf = torch.mean(loss_clf * (1-mask))
+
+ loss_fkd = torch.norm(features - features_old, p=2, dim=1)
+ loss_fkd = self.args["lambda_fkd"] * torch.sum(loss_fkd * mask)
+
+ index = np.random.choice(range(self._known_classes),size=self.args["batch_size"],replace=True)
+
+ proto_features = np.array(self._protos)[index]
+ proto_targets = index
+ proto_features = torch.from_numpy(proto_features).float().to(self._device, non_blocking=True)
+ proto_targets = torch.from_numpy(proto_targets).to(self._device, non_blocking=True)
+
+ proto_logits = self._network_module_ptr.fc(proto_features)["logits"]
+ loss_proto = self.args["lambda_proto"] * F.cross_entropy(proto_logits/self.args["temp"], proto_targets)
+ return logits, loss_clf, loss_fkd, loss_proto
+
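+ # Summary of the loss above for incremental tasks (as implemented, with
+ # w = max cosine similarity between each feature and the stored prototypes):
+ #   loss = mean((1 - w) * CE(logits/temp, y))           # weight novel-looking samples
+ #        + lambda_fkd * sum(w * ||f - f_old||_2)        # feature distillation on familiar ones
+ #        + lambda_proto * CE(fc(protos)/temp, proto_y)  # prototype replay for old classes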
+
+ def eval_task(self, save_conf=False):
+ y_pred, y_true = self._eval_cnn(self.test_loader)
+ cnn_accy = self._evaluate(y_pred, y_true)
+
+ if hasattr(self, '_class_means'):
+ y_pred, y_true = self._eval_nme(self.test_loader, self._class_means)
+ nme_accy = self._evaluate(y_pred, y_true)
+ elif hasattr(self, '_protos'):
+ y_pred, y_true = self._eval_nme(self.test_loader, self._protos / np.linalg.norm(self._protos, axis=1)[:, None])
+ nme_accy = self._evaluate(y_pred, y_true)
+ else:
+ nme_accy = None
+ if save_conf:
+ _pred = y_pred.T[0]
+ _pred_path = os.path.join(self.args['logfilename'], "pred.npy")
+ _target_path = os.path.join(self.args['logfilename'], "target.npy")
+ np.save(_pred_path, _pred)
+ np.save(_target_path, y_true)
+
+ _save_dir = os.path.join(f"./results/{self.args['model_name']}/conf_matrix/{self.args['prefix']}")
+ os.makedirs(_save_dir, exist_ok=True)
+ _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv")
+ with open(_save_path, "a+") as f:
+ f.write(f"{self.args['model_name']},{_pred_path},{_target_path} \n")
+ return cnn_accy, nme_accy
+
+ def _network_expansion(self):
+ if self._cur_task > 0:
+ for p in self._network.convnet.parameters():
+ p.requires_grad = True
+ for k, v in self._network.convnet.named_parameters():
+ if 'adapter' not in k:
+ v.requires_grad = False
+ # self._network.convnet.re_init_params() # do not use!
+ self._network.convnet.switch("parallel_adapters")
+
+ def _network_compression(self):
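+ # Folds each learned 1x1 "adapter" kernel into its parallel 3x3 conv by
+ # zero-padding it to 3x3 and adding the weights (biases are added directly);
+ # by linearity of convolution this matches summing the two branches' outputs,
+ # so the adapters can be zeroed and the backbone switched back to "normal".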
+ model_dict = self._network.state_dict()
+ for k, v in model_dict.items():
+ if 'adapter' in k:
+ k_conv3 = k.replace('adapter', 'conv')
+ if 'weight' in k:
+ model_dict[k_conv3] = model_dict[k_conv3] + F.pad(v, [1, 1, 1, 1], 'constant', 0)
+ model_dict[k] = torch.zeros_like(v)
+ elif 'bias' in k:
+ model_dict[k_conv3] = model_dict[k_conv3] + v
+ model_dict[k] = torch.zeros_like(v)
+ else:
+ assert 0
+ self._network.load_state_dict(model_dict)
+ self._network.convnet.switch("normal")
\ No newline at end of file
diff --git a/models/wa.py b/models/wa.py
new file mode 100644
index 0000000000000000000000000000000000000000..23de65687736c4882fc9d18d0fe7068e57bfd730
--- /dev/null
+++ b/models/wa.py
@@ -0,0 +1,217 @@
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from utils.inc_net import IncrementalNet
+from utils.toolkit import target2onehot, tensor2numpy
+
+EPSILON = 1e-8
+
+
+init_epoch = 200
+init_lr = 0.1
+init_milestones = [60, 120, 170]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 170
+lrate = 0.1
+milestones = [60, 100, 140]
+lrate_decay = 0.1
+batch_size = 128
+weight_decay = 2e-4
+num_workers = 8
+T = 2
+
+
+class WA(BaseLearner):
+ def __init__(self, args):
+ super().__init__(args)
+ self._network = IncrementalNet(args, False)
+
+ def after_task(self):
+ if self._cur_task > 0:
+ self._network.weight_align(self._total_classes - self._known_classes)
+ self._old_network = self._network.copy().freeze()
+ self._known_classes = self._total_classes
+ logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+ def incremental_train(self, data_manager):
+ self._cur_task += 1
+ self._total_classes = self._known_classes + data_manager.get_task_size(
+ self._cur_task
+ )
+ self._network.update_fc(self._total_classes)
+ logging.info(
+ "Learning on {}-{}".format(self._known_classes, self._total_classes)
+ )
+
+ # Loader
+ train_dataset = data_manager.get_dataset(
+ np.arange(self._known_classes, self._total_classes),
+ source="train",
+ mode="train",
+ appendent=self._get_memory(),
+ )
+ self.train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_dataset = data_manager.get_dataset(
+ np.arange(0, self._total_classes), source="test", mode="test"
+ )
+ self.test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ # Procedure
+ if len(self._multiple_gpus) > 1:
+ self._network = nn.DataParallel(self._network, self._multiple_gpus)
+ self._train(self.train_loader, self.test_loader)
+ self.build_rehearsal_memory(data_manager, self.samples_per_class)
+ if len(self._multiple_gpus) > 1:
+ self._network = self._network.module
+
+ def _train(self, train_loader, test_loader):
+ self._network.to(self._device)
+ if self._old_network is not None:
+ self._old_network.to(self._device)
+
+ if self._cur_task == 0:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ momentum=0.9,
+ lr=init_lr,
+ weight_decay=init_weight_decay,
+ )
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+ )
+ self._init_train(train_loader, test_loader, optimizer, scheduler)
+ else:
+ optimizer = optim.SGD(
+ self._network.parameters(),
+ lr=lrate,
+ momentum=0.9,
+ weight_decay=weight_decay,
+ ) # 1e-5
+ scheduler = optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+ )
+ self._update_representation(train_loader, test_loader, optimizer, scheduler)
+ if len(self._multiple_gpus) > 1:
+ self._network.module.weight_align(
+ self._total_classes - self._known_classes
+ )
+ else:
+ self._network.weight_align(self._total_classes - self._known_classes)
+
+ def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+ prog_bar = tqdm(range(init_epoch))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss = F.cross_entropy(logits, targets)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ init_epoch,
+ losses / len(train_loader),
+ train_acc,
+ )
+
+ prog_bar.set_description(info)
+
+ logging.info(info)
+
+ def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
+ kd_lambda = self._known_classes / self._total_classes
+ prog_bar = tqdm(range(epochs))
+ for _, epoch in enumerate(prog_bar):
+ self._network.train()
+ losses = 0.0
+ correct, total = 0, 0
+ for i, (_, inputs, targets) in enumerate(train_loader):
+ inputs, targets = inputs.to(self._device), targets.to(self._device)
+ logits = self._network(inputs)["logits"]
+
+ loss_clf = F.cross_entropy(logits, targets)
+ loss_kd = _KD_loss(
+ logits[:, : self._known_classes],
+ self._old_network(inputs)["logits"],
+ T,
+ )
+
+ loss = (1-kd_lambda) * loss_clf + kd_lambda * loss_kd
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ losses += loss.item()
+
+ # acc
+ _, preds = torch.max(logits, dim=1)
+ correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+ total += len(targets)
+
+ scheduler.step()
+ train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+ if epoch % 5 == 0:
+ test_acc = self._compute_accuracy(self._network, test_loader)
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ test_acc,
+ )
+ else:
+ info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+ self._cur_task,
+ epoch + 1,
+ epochs,
+ losses / len(train_loader),
+ train_acc,
+ )
+ prog_bar.set_description(info)
+ logging.info(info)
+
+
+def _KD_loss(pred, soft, T):
+ pred = torch.log_softmax(pred / T, dim=1)
+ soft = torch.softmax(soft / T, dim=1)
+ return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
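+
+# _KD_loss is temperature-scaled knowledge distillation:
+#   KD(pred, soft) = -(1/B) * sum_i sum_c softmax(soft_i/T)_c * log_softmax(pred_i/T)_c
+# In _update_representation it is mixed with cross-entropy using
+# kd_lambda = known_classes / total_classes, so the distillation term grows
+# with the fraction of old classes.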
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c08ad666431d8cc241f603b279817c23a6a0375e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+kaggle
+numpy==1.21.0
+Pillow==10.3.0
+POT==0.4.0
+quadprog==0.1.12
+scipy==1.3.3
+tqdm==4.66.2
+Flask
+flask_autoindex
+boto3
+scikit-learn
+python-dotenv
\ No newline at end of file
diff --git a/resources/ImageNet100.png b/resources/ImageNet100.png
new file mode 100644
index 0000000000000000000000000000000000000000..ccde98b17c48ccc3cc45faf9fb468058537f4da3
Binary files /dev/null and b/resources/ImageNet100.png differ
diff --git a/resources/cifar100.png b/resources/cifar100.png
new file mode 100644
index 0000000000000000000000000000000000000000..2f4ce302a4eb092487aac62dada1056a022fd7da
Binary files /dev/null and b/resources/cifar100.png differ
diff --git a/resources/imagenet20st5.png b/resources/imagenet20st5.png
new file mode 100644
index 0000000000000000000000000000000000000000..e1206cbfe140694ba16747ca567648c07b9520c0
Binary files /dev/null and b/resources/imagenet20st5.png differ
diff --git a/resources/logo.png b/resources/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..66c6d5a887d4439158653a73bbe6c19ae312688e
Binary files /dev/null and b/resources/logo.png differ
diff --git a/rmm_train.py b/rmm_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9ecd6f1307471545cc91d31f7a60e03cfa6e92
--- /dev/null
+++ b/rmm_train.py
@@ -0,0 +1,232 @@
+'''
+We implemented `iCaRL+RMM`, `FOSTER+RMM` in [rmm.py](models/rmm.py). We implemented the `Pretraining Stage` of `RMM` in [rmm_train.py](rmm_train.py).
+Use the following training script to run it.
+```bash
+python rmm_train.py --config=./exps/rmm-pretrain.json
+```
+'''
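+
+# How the pretraining stage is framed as RL in this file (summary of the code):
+# - state: [next_task_size / 100, memory_size / (memory_size + new_memory_size)]
+# - action: (m_rate, c_rate), appended to the model's rate lists each task
+# - reward: CNN top-1 accuracy / 100 after training the task
+# A DDPG agent (utils.rl_utils) is trained over randomly sampled CIL settings.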
+import json
+import argparse
+import sys
+import logging
+import copy
+import torch
+from utils import factory
+from utils.data_manager import DataManager
+from utils.rl_utils.ddpg import DDPG
+from utils.rl_utils.rl_utils import ReplayBuffer
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+import random
+
+
+class CILEnv:
+ def __init__(self, args) -> None:
+ self._args = copy.deepcopy(args)
+ self.settings = [(50, 2), (50, 5), (50, 10), (50, 20), (10, 10), (20, 20), (5, 5)]
+ # self.settings = [(5,5)] # Debug
+ self._args["init_cls"], self._args["increment"] = self.settings[np.random.randint(len(self.settings))]
+ self.data_manager = DataManager(
+ self._args["dataset"],
+ self._args["shuffle"],
+ self._args["seed"],
+ self._args["init_cls"],
+ self._args["increment"],
+ )
+ self.model = factory.get_model(self._args["model_name"], self._args)
+
+ @property
+ def nb_task(self):
+ return self.data_manager.nb_tasks
+
+ @property
+ def cur_task(self):
+ return self.model._cur_task
+
+ def get_task_size(self, task_id):
+ return self.data_manager.get_task_size(task_id)
+
+ def reset(self):
+ self._args["init_cls"], self._args["increment"] = self.settings[np.random.randint(len(self.settings))]
+ self.data_manager = DataManager(
+ self._args["dataset"],
+ self._args["shuffle"],
+ self._args["seed"],
+ self._args["init_cls"],
+ self._args["increment"],
+ )
+ self.model = factory.get_model(self._args["model_name"], self._args)
+
+ info = "start new task: dataset: {}, init_cls: {}, increment: {}".format(
+ self._args["dataset"], self._args["init_cls"], self._args["increment"]
+ )
+ return np.array([self.get_task_size(0) / 100, 0]), None, False, info
+
+ def step(self, action):
+ self.model._m_rate_list.append(action[0])
+ self.model._c_rate_list.append(action[1])
+ self.model.incremental_train(self.data_manager)
+ cnn_accy, nme_accy = self.model.eval_task()
+ self.model.after_task()
+ done = self.cur_task == self.nb_task - 1
+ info = "running task [{}/{}]: dataset: {}, increment: {}, cnn_accy top1: {}, top5: {}".format(
+ self.model._known_classes,
+ 100,
+ self._args["dataset"],
+ self._args["increment"],
+ cnn_accy["top1"],
+ cnn_accy["top5"],
+ )
+ return (
+ np.array(
+ [
+ self.get_task_size(self.cur_task+1)/100 if not done else 0.,
+ self.model.memory_size
+ / (self.model.memory_size + self.model.new_memory_size),
+ ]
+ ),
+ cnn_accy["top1"]/100,
+ done,
+ info,
+ )
+
+
+def _train(args):
+ logs_name = "logs/RL-CIL/{}/".format(args["model_name"])
+ if not os.path.exists(logs_name):
+ os.makedirs(logs_name)
+
+ logfilename = "logs/RL-CIL/{}/{}_{}_{}_{}_{}".format(
+ args["model_name"],
+ args["prefix"],
+ args["seed"],
+ args["model_name"],
+ args["convnet_type"],
+ args["dataset"],
+ )
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(filename)s] => %(message)s",
+ handlers=[
+ logging.FileHandler(filename=logfilename + ".log"),
+ logging.StreamHandler(sys.stdout),
+ ],
+ )
+
+ _set_random()
+ _set_device(args)
+ print_args(args)
+
+ actor_lr = 5e-4
+ critic_lr = 5e-3
+ num_episodes = 200
+ hidden_dim = 32
+ gamma = 0.98
+ tau = 0.005
+ buffer_size = 1000
+ minimal_size = 50
+ batch_size = 32
+ sigma = 0.2 # action noise, encouraging the off-policy algo to explore.
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ env = CILEnv(args)
+ replay_buffer = ReplayBuffer(buffer_size)
+ agent = DDPG(
+ 2, 1, 4, hidden_dim, False, 1, sigma, actor_lr, critic_lr, tau, gamma, device
+ )
+ for iteration in range(num_episodes):
+ state, *_, info = env.reset()
+ logging.info(info)
+ done = False
+ while not done:
+ action = agent.take_action(state)
+ logging.info(f"take action: m_rate {action[0]}, c_rate {action[1]}")
+ next_state, reward, done, info = env.step(action)
+ logging.info(info)
+ replay_buffer.add(state, action, reward, next_state, done)
+ state = next_state
+ if replay_buffer.size() > minimal_size:
+ b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
+ transition_dict = {
+ "states": b_s,
+ "actions": b_a,
+ "next_states": b_ns,
+ "rewards": b_r,
+ "dones": b_d,
+ }
+ agent.update(transition_dict)
+
+
+def _set_device(args):
+ device_type = args["device"]
+ gpus = []
+
+ for device in device_type:
+ if device == -1:
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:{}".format(device))
+
+ gpus.append(device)
+
+ args["device"] = gpus
+
+
+def _set_random():
+ random.seed(1)
+ torch.manual_seed(1)
+ torch.cuda.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def print_args(args):
+ for key, value in args.items():
+ logging.info("{}: {}".format(key, value))
+
+
+def train(args):
+ seed_list = copy.deepcopy(args["seed"])
+ device = copy.deepcopy(args["device"])
+
+ for seed in seed_list:
+ args["seed"] = seed
+ args["device"] = device
+ _train(args)
+
+
+def main():
+ args = setup_parser().parse_args()
+ param = load_json(args.config)
+ args = vars(args) # Converting argparse Namespace to a dict.
+ args.update(param) # Add parameters from json
+
+ train(args)
+
+
+def load_json(settings_path):
+ with open(settings_path) as data_file:
+ param = json.load(data_file)
+
+ return param
+
+
+def setup_parser():
+ parser = argparse.ArgumentParser(
+ description="Reproduction of multiple continual learning algorithms."
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="./exps/finetune.json",
+ help="Json file of settings.",
+ )
+
+ return parser
+
+
+if __name__ == "__main__":
+ main()
diff --git a/server.py b/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fcc0f6bf4b339b73ec177dab05b3df3ae93941e
--- /dev/null
+++ b/server.py
@@ -0,0 +1,89 @@
+from flask import Flask, send_from_directory, request, send_file
+from flask_autoindex import AutoIndex
+import subprocess, os
+
+from download_s3_path import download_s3_folder
+from download_file_from_s3 import download_from_s3
+from split import split_data
+import shutil
+import json
+import time
+
+app = Flask(__name__)
+app.config["UPLOAD_FOLDER"] = "upload"
+AutoIndex(app, browse_root=os.path.curdir)
+
+
+@app.route("/train", methods=["GET"])
+def train():
+ try:
+ subprocess.Popen(["./simple_train.sh"])
+ return "Bash script triggered successfully!"
+ except OSError as e: # Popen raises OSError, not CalledProcessError, if the script cannot be launched
+ return f"An error occurred: {str(e)}", 500
+
+
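+# Example request (assuming the server is reachable on the port configured below, 7860):
+#   curl http://localhost:7860/train/workings/<working_id>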
+@app.route("/train/workings/", methods=["GET"])
+def train_with_working_id(working_id):
+ path = f"working/{working_id}"
+ delete_folder(path)
+ download_s3_folder(os.getenv("S3_BUCKET_NAME", "pycil.com"), path, path)
+
+ data_path = path + "/data"
+ config_path = path + "/config.json"
+ output_path = f"s3://pycil.com/output/{working_id}"
+
+ split_data(data_path)
+
+ subprocess.Popen(
+ [
+ "./train_from_working.sh",
+ config_path,
+ data_path,
+ "models",
+ f"s3://pycil.com/output/{working_id}/{int(time.time())}",
+ ]
+ )
+
+ return f"Training started with working id {working_id}!"
+
+
+@app.route("/inference", methods=["POST"])
+def inference():
+ file = request.files["image"]
+ file.save(os.path.join(app.config["UPLOAD_FOLDER"], file.filename))
+
+ input_path = os.path.join(app.config["UPLOAD_FOLDER"], file.filename)
+ config_path = request.form["config_path"]
+ checkpoint_path = request.form["checkpoint_path"]
+
+ download_from_s3("pycil.com", config_path, "config.json")
+ download_from_s3("pycil.com", checkpoint_path, "checkpoint.pkl")
+ subprocess.call(
+ [
+ "python",
+ "inference.py",
+ "--config",
+ "config.json",
+ "--checkpoint",
+ "checkpoint.pkl",
+ "--input",
+ input_path,
+ "--output",
+ "output.json",
+ ]
+ )
+ return send_file("output.json")
+
+
+def delete_folder(folder_path):
+ if os.path.exists(folder_path):
+ shutil.rmtree(folder_path)
+ print(f"Folder '{folder_path}' has been deleted.")
+ else:
+ print(f"Folder '{folder_path}' does not exist.")
+
+
+if __name__ == "__main__":
+ app.run(host="0.0.0.0", port=7860, debug=True)
diff --git a/simple_train.sh b/simple_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2bc939de85c343ea30af06ead4a428d9b6b7288a
--- /dev/null
+++ b/simple_train.sh
@@ -0,0 +1,5 @@
+#!/bin/sh
+
+python main.py --config ./exps/simplecil_general.json --data ./car_data/car_data
+
+./upload_s3.sh
diff --git a/split.py b/split.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e4360d9e773e3882729d2465258c02e9ccf5eb7
--- /dev/null
+++ b/split.py
@@ -0,0 +1,55 @@
+import os
+import shutil
+import sys
+from sklearn.model_selection import train_test_split
+
+
+def split_data(data_dir, train_ratio=0.8, seed=42):
+ train_dir = os.path.join(data_dir, "train")
+ val_dir = os.path.join(data_dir, "val")
+
+ # Ensure the train and val directories exist
+ os.makedirs(train_dir, exist_ok=True)
+ os.makedirs(val_dir, exist_ok=True)
+
+ # Iterate over each class folder
+ for class_name in os.listdir(data_dir):
+ class_path = os.path.join(data_dir, class_name)
+ if os.path.isdir(class_path) and class_name not in ["train", "val"]:
+ # Get a list of all files in the class directory
+ files = os.listdir(class_path)
+ files = [f for f in files if os.path.isfile(os.path.join(class_path, f))]
+
+ # Split the files into training and validation sets
+ train_files, val_files = train_test_split(
+ files, train_size=train_ratio, random_state=seed
+ )
+
+ # Create class directories in train and val directories
+ train_class_dir = os.path.join(train_dir, class_name)
+ val_class_dir = os.path.join(val_dir, class_name)
+ os.makedirs(train_class_dir, exist_ok=True)
+ os.makedirs(val_class_dir, exist_ok=True)
+
+ # Move training files
+ for file in train_files:
+ shutil.move(
+ os.path.join(class_path, file), os.path.join(train_class_dir, file)
+ )
+
+ # Move validation files
+ for file in val_files:
+ shutil.move(
+ os.path.join(class_path, file), os.path.join(val_class_dir, file)
+ )
+
+ print("Data split complete.")
+
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print("Usage: python split_data.py ")
+ sys.exit(1)
+
+ data_dir = sys.argv[1]
+ split_data(data_dir)
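+
+# Illustrative result: given <data_dir>/<class_x>/*.jpg folders,
+# `python split.py <data_dir>` moves ~80% of each class into
+# <data_dir>/train/<class_x>/ and the rest into <data_dir>/val/<class_x>/,
+# leaving the original class folders empty.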
diff --git a/static/test.log b/static/test.log
new file mode 100644
index 0000000000000000000000000000000000000000..c7b8f380dcfd08cdcdece7a95198f653ea0915ee
--- /dev/null
+++ b/static/test.log
@@ -0,0 +1 @@
+this is a test line
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/test.py
@@ -0,0 +1 @@
+
diff --git a/test_blur.py b/test_blur.py
new file mode 100644
index 0000000000000000000000000000000000000000..95eda4a6e1d0fd08f3625b9fcf27d1cca677ee57
--- /dev/null
+++ b/test_blur.py
@@ -0,0 +1,30 @@
+from torchvision import transforms
+from PIL import Image
+import argparse
+
+def main():
+ args = setup_parser().parse_args()
+ path = args.path
+ img = Image.open(path)
+ trf = transforms.Compose([
+ transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8),
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(brightness = 0.3, saturation = 0.2),
+        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.5, 2.0))], p=1),  # p=1 always applies the blur; only sigma is sampled randomly
+ ])
+ img = trf(img)
+ img.save("blur.jpg")
+
+def setup_parser():
+    parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
+ parser.add_argument('--path', type=str,
+ help='Image file.')
+
+ return parser
+
+
+
+if __name__ == '__main__':
+ main()
+
\ No newline at end of file
diff --git a/test_upload/test.txt b/test_upload/test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f872797f1bb936f015eaa34908d5744b420a7b8
--- /dev/null
+++ b/test_upload/test.txt
@@ -0,0 +1 @@
+this file will be upload to s3
\ No newline at end of file
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..67d0586705cbc27e4aa3a87af770a135ea399dce
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,7 @@
+#!/bin/sh
+for arg in "$@"; do
+    python ./main.py --config="$arg"
+done
+
+./upload_s3.sh
\ No newline at end of file
diff --git a/train_from_working.sh b/train_from_working.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c3159358f8c5a648a8688d08b986679cc8f3333c
--- /dev/null
+++ b/train_from_working.sh
@@ -0,0 +1,16 @@
+#!/bin/sh
+
+# Ensure the script exits on the first error and prints each command before executing it
+set -e
+set -x
+
+# Check if config, data, upload_s3_arg, and s3_path arguments were provided, if not, set default values
+config=${1:-./exps/simplecil_general.json}
+data=${2:-./car_data/car_data}
+upload_s3_arg=${3:-./models}
+s3_path=${4:-s3://pycil.com/"$(date -u +"%Y-%m-%dT%H:%M:%SZ")"}
+
+# Run the training script with the provided or default config and data arguments
+python main.py --config "$config" --data "$data"
+
+./upload_s3.sh "$upload_s3_arg" "$s3_path"
diff --git a/train_memo.py b/train_memo.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5d3f79e1853188e3aca0edbad7664a6e043aeff
--- /dev/null
+++ b/train_memo.py
@@ -0,0 +1,187 @@
+import sys
+import logging
+import copy
+import torch
+from utils import factory
+from utils.data_manager import DataManager
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+
+
+def train(args):
+ seed_list = copy.deepcopy(args["seed"])
+ device = copy.deepcopy(args["device"])
+
+ for seed in seed_list:
+ args["seed"] = seed
+ args["device"] = device
+ _train(args)
+
+
+def _train(args):
+
+ init_cls = 0 if args ["init_cls"] == args["increment"] else args["init_cls"]
+ logs_name = "logs/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment'])
+
+ if not os.path.exists(logs_name):
+ os.makedirs(logs_name)
+
+ save_name = "models/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment'])
+
+ if not os.path.exists(save_name):
+ os.makedirs(save_name)
+ logfilename = "logs/{}/{}/{}/{}/{}_{}_{}".format(
+ args["model_name"],
+ args["dataset"],
+ init_cls,
+ args["increment"],
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+ args['logfilename'] = logs_name
+ args['csv_name'] = "{}_{}_{}".format(
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(filename)s] => %(message)s",
+ handlers=[
+ logging.FileHandler(filename=logfilename + ".log"),
+ logging.StreamHandler(sys.stdout),
+ ],
+ )
+
+ _set_random()
+ _set_device(args)
+ print_args(args)
+ data_manager = DataManager(
+ args["dataset"],
+ args["shuffle"],
+ args["seed"],
+ args["init_cls"],
+ args["increment"],
+ )
+ model = factory.get_model(args["model_name"], args)
+
+ cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []}
+ cnn_matrix, nme_matrix = [], []
+
+ for task in range(data_manager.nb_tasks):
+ print(args["device"])
+ logging.info("All params: {}".format(count_parameters(model._network)))
+ logging.info(
+ "Trainable params: {}".format(count_parameters(model._network, True))
+ )
+ model.incremental_train(data_manager)
+ cnn_accy, nme_accy = model.eval_task(save_conf=True)
+ model.after_task()
+
+ if nme_accy is not None:
+ logging.info("CNN: {}".format(cnn_accy["grouped"]))
+ logging.info("NME: {}".format(nme_accy["grouped"]))
+
+ cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+ cnn_keys_sorted = sorted(cnn_keys)
+ cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted]
+ cnn_matrix.append(cnn_values)
+
+ nme_keys = [key for key in nme_accy["grouped"].keys() if '-' in key]
+ nme_keys_sorted = sorted(nme_keys)
+ nme_values = [nme_accy["grouped"][key] for key in nme_keys_sorted]
+ nme_matrix.append(nme_values)
+
+
+ cnn_curve["top1"].append(cnn_accy["top1"])
+ cnn_curve["top5"].append(cnn_accy["top5"])
+
+ nme_curve["top1"].append(nme_accy["top1"])
+ nme_curve["top5"].append(nme_accy["top5"])
+
+ logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
+ logging.info("CNN top5 curve: {}".format(cnn_curve["top5"]))
+ logging.info("NME top1 curve: {}".format(nme_curve["top1"]))
+ logging.info("NME top5 curve: {}\n".format(nme_curve["top5"]))
+
+ print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"]))
+ print('Average Accuracy (NME):', sum(nme_curve["top1"])/len(nme_curve["top1"]))
+
+ logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))
+ logging.info("Average Accuracy (NME): {}".format(sum(nme_curve["top1"])/len(nme_curve["top1"])))
+ else:
+ logging.info("No NME accuracy.")
+ logging.info("CNN: {}".format(cnn_accy["grouped"]))
+
+ cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+ cnn_keys_sorted = sorted(cnn_keys)
+ cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted]
+ cnn_matrix.append(cnn_values)
+
+ cnn_curve["top1"].append(cnn_accy["top1"])
+ cnn_curve["top5"].append(cnn_accy["top5"])
+
+ logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
+ logging.info("CNN top5 curve: {}\n".format(cnn_curve["top5"]))
+
+ print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"]))
+ logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))
+ model.save_checkpoint(save_name)
+
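+    # cnn_matrix / nme_matrix hold, per task, the grouped accuracy of every class
+    # range seen so far; below they are stacked into a (group x task) table and
+    # forgetting = mean over old groups of (best past accuracy - final accuracy).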
+ if len(cnn_matrix)>0:
+ np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))])
+ for idxx, line in enumerate(cnn_matrix):
+ idxy = len(line)
+ np_acctable[idxx, :idxy] = np.array(line)
+ np_acctable = np_acctable.T
+ forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1])
+ logging.info('Forgetting (CNN): {}'.format(forgetting))
+ logging.info('Accuracy Matrix (CNN): {}'.format(np_acctable))
+ print('Accuracy Matrix (CNN):')
+ print(np_acctable)
+ print('Forgetting (CNN):', forgetting)
+ if len(nme_matrix)>0:
+ np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))])
+ for idxx, line in enumerate(nme_matrix):
+ idxy = len(line)
+ np_acctable[idxx, :idxy] = np.array(line)
+ np_acctable = np_acctable.T
+ forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1])
+ logging.info('Forgetting (NME): {}'.format(forgetting))
+ logging.info('Accuracy Matrix (NME): {}'.format(np_acctable))
+ print('Accuracy Matrix (NME):')
+ print(np_acctable)
+ print('Forgetting (NME):', forgetting)
+
+
+def _set_device(args):
+ device_type = args["device"]
+ gpus = []
+
+ for device in device_type:
+        if device == -1:
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:{}".format(device))
+
+ gpus.append(device)
+
+ args["device"] = gpus
+
+
+def _set_random():
+ torch.manual_seed(1)
+ torch.cuda.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def print_args(args):
+ for key, value in args.items():
+ logging.info("{}: {}".format(key, value))
+
diff --git a/train_more.py b/train_more.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba2a0cd24bdbe11e7ee37c73663c57b5e6058712
--- /dev/null
+++ b/train_more.py
@@ -0,0 +1,186 @@
+import sys
+import logging
+import copy
+import torch
+from utils import factory
+from utils.data_manager import DataManager
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+from load_model import load_model, get_methods
+
+def train_more(args):
+ seed_list = copy.deepcopy(args["seed"])
+ device = copy.deepcopy(args["device"])
+
+ for seed in seed_list:
+ args["seed"] = seed
+ args["device"] = device
+ _train_more(args)
+
+
+def _train_more(args):
+
+ init_cls = 0 if args ["init_cls"] == args["increment"] else args["init_cls"]
+ logs_name = "logs/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment'])
+
+ if not os.path.exists(logs_name):
+ os.makedirs(logs_name)
+
+ save_name = "models/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment'])
+
+ if not os.path.exists(save_name):
+ os.makedirs(save_name)
+ logfilename = "logs/{}/{}/{}/{}/{}_{}_{}".format(
+ args["model_name"],
+ args["dataset"],
+ init_cls,
+ args["increment"],
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+ args['logfilename'] = logs_name
+ args['csv_name'] = "{}_{}_{}".format(
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(filename)s] => %(message)s",
+ handlers=[
+ logging.FileHandler(filename=logfilename + ".log"),
+ logging.StreamHandler(sys.stdout),
+ ],
+ )
+
+ _set_random()
+ print_args(args)
+ model = load_model(args)
+ data_manager = DataManager(
+ args["dataset"],
+ args["shuffle"],
+ args["seed"],
+ args["init_cls"],
+ args["increment"],
+ resume = True,
+ path = args["data"],
+ class_list = model.class_list
+ )
+ cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []}
+ cnn_matrix, nme_matrix = [], []
+
+ for task in range(data_manager.nb_tasks):
+ print(args["device"])
+ logging.info("All params: {}".format(count_parameters(model._network)))
+ logging.info(
+ "Trainable params: {}".format(count_parameters(model._network, True))
+ )
+ model.incremental_train(data_manager)
+ cnn_accy, nme_accy = model.eval_task(save_conf=True)
+ model.after_task()
+
+ if nme_accy is not None:
+ logging.info("CNN: {}".format(cnn_accy["grouped"]))
+ logging.info("NME: {}".format(nme_accy["grouped"]))
+
+ cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+ cnn_keys_sorted = sorted(cnn_keys)
+ cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted]
+ cnn_matrix.append(cnn_values)
+
+ nme_keys = [key for key in nme_accy["grouped"].keys() if '-' in key]
+ nme_keys_sorted = sorted(nme_keys)
+ nme_values = [nme_accy["grouped"][key] for key in nme_keys_sorted]
+ nme_matrix.append(nme_values)
+
+
+ cnn_curve["top1"].append(cnn_accy["top1"])
+ cnn_curve["top5"].append(cnn_accy["top5"])
+
+ nme_curve["top1"].append(nme_accy["top1"])
+ nme_curve["top5"].append(nme_accy["top5"])
+
+ logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
+ logging.info("CNN top5 curve: {}".format(cnn_curve["top5"]))
+ logging.info("NME top1 curve: {}".format(nme_curve["top1"]))
+ logging.info("NME top5 curve: {}\n".format(nme_curve["top5"]))
+
+ print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"]))
+ print('Average Accuracy (NME):', sum(nme_curve["top1"])/len(nme_curve["top1"]))
+
+ logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))
+ logging.info("Average Accuracy (NME): {}".format(sum(nme_curve["top1"])/len(nme_curve["top1"])))
+ else:
+ logging.info("No NME accuracy.")
+ logging.info("CNN: {}".format(cnn_accy["grouped"]))
+
+ cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+ cnn_keys_sorted = sorted(cnn_keys)
+ cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted]
+ cnn_matrix.append(cnn_values)
+
+ cnn_curve["top1"].append(cnn_accy["top1"])
+ cnn_curve["top5"].append(cnn_accy["top5"])
+
+ logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
+ logging.info("CNN top5 curve: {}\n".format(cnn_curve["top5"]))
+
+ print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"]))
+ logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))
+ model.save_checkpoint(save_name)
+ if len(cnn_matrix)>0:
+ np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))])
+ for idxx, line in enumerate(cnn_matrix):
+ idxy = len(line)
+ np_acctable[idxx, :idxy] = np.array(line)
+ np_acctable = np_acctable.T
+ forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1])
+ logging.info('Forgetting (CNN): {}'.format(forgetting))
+ logging.info('Accuracy Matrix (CNN): {}'.format(np_acctable))
+ print('Accuracy Matrix (CNN):')
+ print(np_acctable)
+ print('Forgetting (CNN):', forgetting)
+ if len(nme_matrix)>0:
+ np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))])
+ for idxx, line in enumerate(nme_matrix):
+ idxy = len(line)
+ np_acctable[idxx, :idxy] = np.array(line)
+ np_acctable = np_acctable.T
+ forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1])
+ logging.info('Forgetting (NME): {}'.format(forgetting))
+ logging.info('Accuracy Matrix (NME): {}'.format(np_acctable))
+ print('Accuracy Matrix (NME):')
+ print(np_acctable)
+ print('Forgetting (NME):', forgetting)
+
+def _set_device(args):
+ device_type = args["device"]
+ gpus = []
+
+ for device in device_type:
+ if device == -1:
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:{}".format(device))
+
+ gpus.append(device)
+
+ args["device"] = gpus
+
+
+def _set_random():
+ torch.manual_seed(1)
+ torch.cuda.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def print_args(args):
+ for key, value in args.items():
+ logging.info("{}: {}".format(key, value))
+
diff --git a/train_more.sh b/train_more.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e3d10e2617d8cf36b85348dbcf2adb5b067fd83b
--- /dev/null
+++ b/train_more.sh
@@ -0,0 +1,5 @@
+#!/bin/sh
+for arg in "$@"; do
+    python ./main.py --config="$arg" --resume
+done
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..09cd4ee86e3e48deea77add4d9b8679c3113dc3c
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,192 @@
+import sys
+import logging
+import copy
+import torch
+from utils import factory
+from utils.data_manager import DataManager
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+
+
+def train(args):
+ seed_list = copy.deepcopy(args["seed"])
+ device = copy.deepcopy(args["device"])
+
+ for seed in seed_list:
+ args["seed"] = seed
+ args["device"] = device
+ _train(args)
+
+
+def _train(args):
+
+ init_cls = 0 if args ["init_cls"] == args["increment"] else args["init_cls"]
+ logs_name = "logs/{}/{}_{}/{}/{}".format(args["model_name"],args["dataset"], args['data'], init_cls, args['increment'])
+
+ if not os.path.exists(logs_name):
+ os.makedirs(logs_name)
+
+ save_name = "models/{}/{}_{}/{}/{}".format(args["model_name"],args["dataset"], args['data'], init_cls, args['increment'])
+
+ if not os.path.exists(save_name):
+ os.makedirs(save_name)
+ logfilename = "logs/{}/{}_{}/{}/{}/{}_{}_{}".format(
+ args["model_name"],
+ args["dataset"],
+ args['data'],
+ init_cls,
+ args["increment"],
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(filename)s] => %(message)s",
+ handlers=[
+ logging.FileHandler(filename=logfilename + ".log"),
+ logging.StreamHandler(sys.stdout),
+ ],
+ force=True
+ )
+ args['logfilename'] = logs_name
+ args['csv_name'] = "{}_{}_{}".format(
+ args["prefix"],
+ args["seed"],
+ args["convnet_type"],
+ )
+
+
+ _set_random()
+ _set_device(args)
+ print_args(args)
+ model = factory.get_model(args["model_name"], args)
+ data_manager = DataManager(
+ args["dataset"],
+ args["shuffle"],
+ args["seed"],
+ args["init_cls"],
+ args["increment"],
+ path = args["data"],
+ )
+ if data_manager.get_task_size(0) < 5:
+ top_string = "top{}".format(data_manager.get_task_size(0))
+ else:
+ top_string = "top5"
+ cnn_curve, nme_curve = {"top1": [], top_string: []}, {"top1": [], top_string: []}
+ cnn_matrix, nme_matrix = [], []
+
+ for task in range(data_manager.nb_tasks):
+ print(args["device"])
+ logging.info("All params: {}".format(count_parameters(model._network)))
+ logging.info(
+ "Trainable params: {}".format(count_parameters(model._network, True))
+ )
+ model.incremental_train(data_manager)
+ cnn_accy, nme_accy = model.eval_task(save_conf=True)
+ model.after_task()
+
+ if nme_accy is not None:
+ logging.info("CNN: {}".format(cnn_accy["grouped"]))
+ logging.info("NME: {}".format(nme_accy["grouped"]))
+
+ cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+ cnn_keys_sorted = sorted(cnn_keys)
+ cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted]
+ cnn_matrix.append(cnn_values)
+
+ nme_keys = [key for key in nme_accy["grouped"].keys() if '-' in key]
+ nme_keys_sorted = sorted(nme_keys)
+ nme_values = [nme_accy["grouped"][key] for key in nme_keys_sorted]
+ nme_matrix.append(nme_values)
+
+
+ cnn_curve["top1"].append(cnn_accy["top1"])
+ cnn_curve[top_string].append(cnn_accy["top{}".format(model.topk)])
+
+ nme_curve["top1"].append(nme_accy["top1"])
+ nme_curve[top_string].append(nme_accy["top{}".format(model.topk)])
+
+ logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
+ logging.info("CNN top5 curve: {}".format(cnn_curve[top_string]))
+ logging.info("NME top1 curve: {}".format(nme_curve["top1"]))
+ logging.info("NME top5 curve: {}\n".format(nme_curve[top_string]))
+
+ print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"]))
+ print('Average Accuracy (NME):', sum(nme_curve["top1"])/len(nme_curve["top1"]))
+
+ logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))
+ logging.info("Average Accuracy (NME): {}".format(sum(nme_curve["top1"])/len(nme_curve["top1"])))
+ else:
+ logging.info("No NME accuracy.")
+ logging.info("CNN: {}".format(cnn_accy["grouped"]))
+
+ cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+ cnn_keys_sorted = sorted(cnn_keys)
+ cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted]
+ cnn_matrix.append(cnn_values)
+
+ cnn_curve["top1"].append(cnn_accy["top1"])
+ cnn_curve[top_string].append(cnn_accy["top{}".format(model.topk)])
+
+ logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
+ logging.info("CNN top5 curve: {}\n".format(cnn_curve[top_string]))
+
+ print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"]))
+ logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))
+ model.save_checkpoint(save_name)
+ if len(cnn_matrix)>0:
+ np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))])
+ for idxx, line in enumerate(cnn_matrix):
+ idxy = len(line)
+ np_acctable[idxx, :idxy] = np.array(line)
+ np_acctable = np_acctable.T
+ forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1])
+ logging.info('Forgetting (CNN): {}'.format(forgetting))
+ logging.info('Accuracy Matrix (CNN): {}'.format(np_acctable))
+ print('Accuracy Matrix (CNN):')
+ print(np_acctable)
+ print('Forgetting (CNN):', forgetting)
+ if len(nme_matrix)>0:
+ np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))])
+ for idxx, line in enumerate(nme_matrix):
+ idxy = len(line)
+ np_acctable[idxx, :idxy] = np.array(line)
+ np_acctable = np_acctable.T
+ forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1])
+ logging.info('Forgetting (NME): {}'.format(forgetting))
+ logging.info('Accuracy Matrix (NME): {}'.format(np_acctable))
+ print('Accuracy Matrix (NME):')
+ print(np_acctable)
+ print('Forgetting (NME):', forgetting)
+
+def _set_device(args):
+ device_type = args["device"]
+ gpus = []
+
+ for device in device_type:
+ if device == -1:
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:{}".format(device))
+
+ gpus.append(device)
+
+ args["device"] = gpus
+
+
+def _set_random():
+ torch.manual_seed(1)
+ torch.cuda.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def print_args(args):
+ for key, value in args.items():
+ logging.info("{}: {}".format(key, value))
+
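+# Minimal self-contained sketch of the forgetting metric computed above,
+# assuming acc_matrix rows are class groups and columns are tasks (i.e. the
+# transposed np_acctable). Defined for illustration only; nothing calls it.
+def _forgetting_sketch(acc_matrix):
+    acc = np.asarray(acc_matrix, dtype=float)
+    # best past accuracy per group minus its final accuracy, skipping the group
+    # introduced in the last task (its forgetting is trivially zero)
+    return np.mean((np.max(acc, axis=1) - acc[:, -1])[:-1])
+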
diff --git a/upload_s3.sh b/upload_s3.sh
new file mode 100644
index 0000000000000000000000000000000000000000..97d4e956d07481b9e060b6f7353bfd04f111de84
--- /dev/null
+++ b/upload_s3.sh
@@ -0,0 +1,12 @@
+#!/bin/sh
+
+# Ensure the script exits on the first error and prints each command before executing it
+set -e
+set -x
+
+# Check if local directory and s3 path arguments were provided, if not, set default values
+local_dir=${1:-./models}
+s3_path=${2:-s3://pycil.com/"$(date -u +"%Y-%m-%dT%H:%M:%SZ")"}
+
+# Perform the S3 copy operation with the provided or default s3 path
+aws s3 cp "$local_dir" "$s3_path" --recursive
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/autoaugment.py b/utils/autoaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee8c90b1eba9867324212132edbdb5c570c911f5
--- /dev/null
+++ b/utils/autoaugment.py
@@ -0,0 +1,215 @@
+import numpy as np
+from .ops import *
+
+
+class ImageNetPolicy(object):
+ """ Randomly choose one of the best 24 Sub-policies on ImageNet.
+
+ Example:
+ >>> policy = ImageNetPolicy()
+ >>> transformed = policy(image)
+
+ Example as a PyTorch Transform:
+ >>> transform = transforms.Compose([
+ >>> transforms.Resize(256),
+ >>> ImageNetPolicy(),
+ >>> transforms.ToTensor()])
+ """
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.policies = [
+ SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
+ SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
+ SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
+ SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
+ SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
+
+ SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
+ SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
+ SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
+ SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
+ SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
+
+ SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
+ SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
+ SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
+ SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
+ SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
+
+ SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
+ SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
+ SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
+ SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
+ SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
+
+ SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
+ SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
+ SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
+ SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
+ SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
+ ]
+
+ def __call__(self, img):
+ policy_idx = random.randint(0, len(self.policies) - 1)
+ return self.policies[policy_idx](img)
+
+ def __repr__(self):
+ return "AutoAugment ImageNet Policy"
+
+
+class CIFAR10Policy(object):
+ """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
+
+ Example:
+ >>> policy = CIFAR10Policy()
+ >>> transformed = policy(image)
+
+ Example as a PyTorch Transform:
+ >>> transform=transforms.Compose([
+ >>> transforms.Resize(256),
+ >>> CIFAR10Policy(),
+ >>> transforms.ToTensor()])
+ """
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.policies = [
+ SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
+ SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
+ SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
+ SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
+ SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
+
+ SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
+ SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
+ SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
+ SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
+ SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
+
+ SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
+ SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
+ SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
+ SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
+ SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
+
+ SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
+ SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
+ SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
+ SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
+ SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
+
+ SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
+ SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
+ SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
+ SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
+ SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
+ ]
+
+ def __call__(self, img):
+ policy_idx = random.randint(0, len(self.policies) - 1)
+ return self.policies[policy_idx](img)
+
+ def __repr__(self):
+ return "AutoAugment CIFAR10 Policy"
+
+
+class SVHNPolicy(object):
+ """ Randomly choose one of the best 25 Sub-policies on SVHN.
+
+ Example:
+ >>> policy = SVHNPolicy()
+ >>> transformed = policy(image)
+
+ Example as a PyTorch Transform:
+ >>> transform=transforms.Compose([
+ >>> transforms.Resize(256),
+ >>> SVHNPolicy(),
+ >>> transforms.ToTensor()])
+ """
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.policies = [
+ SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
+ SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
+ SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
+ SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
+ SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
+
+ SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
+ SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
+ SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
+ SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
+ SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
+
+ SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
+ SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
+ SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
+ SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
+ SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
+
+ SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
+ SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
+ SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
+ SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
+ SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
+
+ SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
+ SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
+ SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
+ SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
+ SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
+ ]
+
+ def __call__(self, img):
+ policy_idx = random.randint(0, len(self.policies) - 1)
+ return self.policies[policy_idx](img)
+
+ def __repr__(self):
+ return "AutoAugment SVHN Policy"
+
+
+class SubPolicy(object):
+ def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
+ ranges = {
+ "shearX": np.linspace(0, 0.3, 10),
+ "shearY": np.linspace(0, 0.3, 10),
+ "translateX": np.linspace(0, 150 / 331, 10),
+ "translateY": np.linspace(0, 150 / 331, 10),
+ "rotate": np.linspace(0, 30, 10),
+ "color": np.linspace(0.0, 0.9, 10),
+ "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
+ "solarize": np.linspace(256, 0, 10),
+ "contrast": np.linspace(0.0, 0.9, 10),
+ "sharpness": np.linspace(0.0, 0.9, 10),
+ "brightness": np.linspace(0.0, 0.9, 10),
+ "autocontrast": [0] * 10,
+ "equalize": [0] * 10,
+ "invert": [0] * 10
+ }
+
+ func = {
+ "shearX": ShearX(fillcolor=fillcolor),
+ "shearY": ShearY(fillcolor=fillcolor),
+ "translateX": TranslateX(fillcolor=fillcolor),
+ "translateY": TranslateY(fillcolor=fillcolor),
+ "rotate": Rotate(),
+ "color": Color(),
+ "posterize": Posterize(),
+ "solarize": Solarize(),
+ "contrast": Contrast(),
+ "sharpness": Sharpness(),
+ "brightness": Brightness(),
+ "autocontrast": AutoContrast(),
+ "equalize": Equalize(),
+ "invert": Invert()
+ }
+
+ self.p1 = p1
+ self.operation1 = func[operation1]
+ self.magnitude1 = ranges[operation1][magnitude_idx1]
+ self.p2 = p2
+ self.operation2 = func[operation2]
+ self.magnitude2 = ranges[operation2][magnitude_idx2]
+
+ def __call__(self, img):
+ if random.random() < self.p1:
+ img = self.operation1(img, self.magnitude1)
+ if random.random() < self.p2:
+ img = self.operation2(img, self.magnitude2)
+ return img
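+
+# Usage sketch (assumes a PIL image `img`): the two ops are rolled independently,
+# op1 firing with probability p1 at magnitude ranges[op1][idx1], then op2 with
+# probability p2 at its own magnitude.
+# policy = SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9)
+# augmented = policy(img)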
diff --git a/utils/data.py b/utils/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c161973d316392412f346520c729e19664d04a5a
--- /dev/null
+++ b/utils/data.py
@@ -0,0 +1,199 @@
+import numpy as np
+from torchvision import datasets, transforms
+from utils.toolkit import split_images_labels
+
+import os
+
+class iData(object):
+ train_trsf = []
+ test_trsf = []
+ common_trsf = []
+ class_order = None
+
+class iCIFAR10(iData):
+ use_path = False
+ train_trsf = [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.ColorJitter(brightness=63 / 255),
+ transforms.ToTensor(),
+ ]
+ test_trsf = [transforms.ToTensor()]
+ common_trsf = [
+ transforms.Normalize(
+ mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)
+ ),
+ ]
+
+ class_order = np.arange(10).tolist()
+
+ def download_data(self):
+ train_dataset = datasets.cifar.CIFAR10("./data", train=True, download=True)
+ test_dataset = datasets.cifar.CIFAR10("./data", train=False, download=True)
+ self.train_data, self.train_targets = train_dataset.data, np.array(
+ train_dataset.targets
+ )
+ self.test_data, self.test_targets = test_dataset.data, np.array(
+ test_dataset.targets
+ )
+
+
+class iCIFAR100(iData):
+ use_path = False
+ train_trsf = [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(brightness=63 / 255),
+ transforms.ToTensor()
+ ]
+ test_trsf = [transforms.ToTensor()]
+ common_trsf = [
+ transforms.Normalize(
+ mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)
+ ),
+ ]
+
+ class_order = np.arange(100).tolist()
+
+ def download_data(self):
+ train_dataset = datasets.cifar.CIFAR100("./data", train=True, download=True)
+ test_dataset = datasets.cifar.CIFAR100("./data", train=False, download=True)
+ self.train_data, self.train_targets = train_dataset.data, np.array(
+ train_dataset.targets
+ )
+ self.test_data, self.test_targets = test_dataset.data, np.array(
+ test_dataset.targets
+ )
+
+
+class iImageNet1000(iData):
+ use_path = True
+ train_trsf = [
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8),
+ transforms.ColorJitter(),
+ ]
+ test_trsf = [
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ ]
+ common_trsf = [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.470, 0.460, 0.455],
+ std=[0.267, 0.266, 0.270]
+ ),
+ ]
+
+ class_order = np.arange(1000).tolist()
+
+ def download_data(self):
+ assert 0, "You should specify the folder of your dataset"
+ train_dir = "[DATA-PATH]/train/"
+ test_dir = "[DATA-PATH]/val/"
+
+ train_dset = datasets.ImageFolder(train_dir)
+ test_dset = datasets.ImageFolder(test_dir)
+
+ self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
+ self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
+
+
+class StanfordCar(iData):
+ use_path = True
+ train_trsf = [
+ transforms.Resize(320),
+ transforms.CenterCrop(320),
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8),
+ transforms.ColorJitter(),
+ ]
+ test_trsf = [
+ transforms.Resize(320),
+ transforms.CenterCrop(320),
+ ]
+ common_trsf = [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.470, 0.460, 0.455],
+ std=[0.267, 0.266, 0.270]
+ ),
+ ]
+ class_order = np.arange(196).tolist()
+ def download_data(self):
+ path = './car_data/car_data'
+ train_dset = datasets.ImageFolder(os.path.join(path, "train"))
+ test_dset = datasets.ImageFolder(os.path.join(path, "test"))
+ self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
+ self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
+
+class GeneralDataset(iData):
+ def __init__(
+ self,
+ path,
+ init_class_list = [-1],
+ train_transform = None,
+ test_transform = None,
+ common_transform = None):
+ self.use_path = True
+ self.path = path
+ self.train_trsf = train_transform
+        if self.train_trsf is None:
+ self.train_trsf = [
+ transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8),
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(brightness = 0.3, saturation = 0.2),
+ ]
+ self.test_trsf = test_transform
+        if self.test_trsf is None:
+ self.test_trsf = [
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ ]
+ self.common_trsf = common_transform
+        if self.common_trsf is None:
+ self.common_trsf = [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.5, 0.5, 0.5],
+ std=[0.5, 0.5, 0.5]
+ ),
+ ]
+ self.init_index = max(init_class_list) + 1
+ self.class_order = np.arange(self.init_index, self.init_index + len(os.listdir(os.path.join(self.path, "train"))))
+
+ def download_data(self):
+ train_dset = datasets.ImageFolder(os.path.join(self.path, "train"))
+ test_dset = datasets.ImageFolder(os.path.join(self.path, "val"))
+ self.train_data, self.train_targets = split_images_labels(train_dset.imgs, start_index = self.init_index)
+ self.test_data, self.test_targets = split_images_labels(test_dset.imgs, start_index = self.init_index)
+ return train_dset.classes
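+
+# Usage sketch (path and class ids are illustrative): the classes found under
+# working/abc123/data/train are appended after existing label ids 0-9, so the
+# new targets start at index 10.
+# ds = GeneralDataset("working/abc123/data", init_class_list=list(range(10)))
+# class_names = ds.download_data()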
+
+class iImageNet100(iData):
+ use_path = True
+ train_trsf = [
+ transforms.Resize(320),
+ transforms.CenterCrop(320),
+ ]
+ test_trsf = [
+ transforms.Resize(320),
+ transforms.CenterCrop(320),
+ ]
+ common_trsf = [
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ]
+
+ class_order = np.arange(1000).tolist()
+
+ def download_data(self):
+ assert 0, "You should specify the folder of your dataset"
+ train_dir = "[DATA-PATH]/train/"
+ test_dir = "[DATA-PATH]/val/"
+
+ train_dset = datasets.ImageFolder(train_dir)
+ test_dset = datasets.ImageFolder(test_dir)
+
+ self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
+ self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
diff --git a/utils/data_manager.py b/utils/data_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b46eb9d9f3c2a3a192e68a472288a12b2f6056
--- /dev/null
+++ b/utils/data_manager.py
@@ -0,0 +1,335 @@
+import logging
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000, StanfordCar, GeneralDataset
+
+
+class DataManager(object):
+ def __init__(self, dataset_name, shuffle, seed, init_cls, increment, resume = False, path = None, class_list = [-1]):
+ self.dataset_name = dataset_name
+ self.init_class_list = class_list
+ if not resume:
+ data = {
+ "path": path,
+ "class_list": [-1],
+ }
+ self._setup_data(dataset_name, shuffle, seed, data = data)
+ if len(self._class_order) < init_cls:
+ self._increments = [len(self._class_order)]
+ else:
+ self._increments = [init_cls]
+ while sum(self._increments) + increment < len(self._class_order):
+ self._increments.append(increment)
+ offset = len(self._class_order) - sum(self._increments)
+ if offset > 0:
+ self._increments.append(offset)
+ else:
+ self._increments = [max(class_list)]
+ data = {
+ "path": path,
+ "class_list": class_list,
+ }
+ self._setup_data(dataset_name, shuffle, seed, data = data)
+ while sum(self._increments) + increment < len(self._class_order):
+ self._increments.append(increment)
+ offset = len(self._class_order) - sum(self._increments) - 1
+ if offset > 0:
+ self._increments.append(offset)
+
+    def get_class_list(self, task):
+        return self._class_order[: sum(self._increments[: task + 1])]
+
+    def get_label_list(self, task):
+        cls_list = self.get_class_list(task)
+        return {i: self.label_list[i] for i in cls_list}
+
+ @property
+ def nb_tasks(self):
+ return len(self._increments)
+
+ def get_task_size(self, task):
+ return self._increments[task]
+
+ def get_accumulate_tasksize(self,task):
+ return float(sum(self._increments[:task+1]))
+
+ def get_total_classnum(self):
+ return len(self._class_order)
+
+ def get_dataset(
+ self, indices, source, mode, appendent=None, ret_data=False, m_rate=None
+ ):
+ if source == "train":
+ x, y = self._train_data, self._train_targets
+ elif source == "test":
+ x, y = self._test_data, self._test_targets
+ else:
+ raise ValueError("Unknown data source {}.".format(source))
+
+ if mode == "train":
+ trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
+ elif mode == "flip":
+ trsf = transforms.Compose(
+ [
+ *self._test_trsf,
+ transforms.RandomHorizontalFlip(p=1.0),
+ *self._common_trsf,
+ ]
+ )
+ elif mode == "test":
+ trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
+ else:
+ raise ValueError("Unknown mode {}.".format(mode))
+
+ data, targets = [], []
+ for idx in indices:
+ if m_rate is None:
+ class_data, class_targets = self._select(
+ x, y, low_range=idx, high_range=idx + 1
+ )
+ else:
+ class_data, class_targets = self._select_rmm(
+ x, y, low_range=idx, high_range=idx + 1, m_rate=m_rate
+ )
+ data.append(class_data)
+ targets.append(class_targets)
+
+ if appendent is not None and len(appendent) != 0:
+ appendent_data, appendent_targets = appendent
+ data.append(appendent_data)
+ targets.append(appendent_targets)
+
+ data, targets = np.concatenate(data), np.concatenate(targets)
+ if ret_data:
+ return data, targets, DummyDataset(data, targets, trsf, self.use_path)
+ else:
+ return DummyDataset(data, targets, trsf, self.use_path)
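+
+    # Typical call (arguments are illustrative): build the training set holding
+    # the classes of task 0.
+    # dm = DataManager("general_dataset", True, 42, 10, 10, path="working/abc123/data")
+    # train_set = dm.get_dataset(np.arange(0, dm.get_task_size(0)), source="train", mode="train")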
+
+
+ def get_finetune_dataset(self,known_classes,total_classes,source,mode,appendent,type="ratio"):
+ if source == 'train':
+ x, y = self._train_data, self._train_targets
+ elif source == 'test':
+ x, y = self._test_data, self._test_targets
+ else:
+ raise ValueError('Unknown data source {}.'.format(source))
+
+ if mode == 'train':
+ trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
+ elif mode == 'test':
+ trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
+ else:
+ raise ValueError('Unknown mode {}.'.format(mode))
+ val_data = []
+ val_targets = []
+
+ old_num_tot = 0
+ appendent_data, appendent_targets = appendent
+
+ for idx in range(0, known_classes):
+ append_data, append_targets = self._select(appendent_data, appendent_targets,
+ low_range=idx, high_range=idx+1)
+ num=len(append_data)
+ if num == 0:
+ continue
+ old_num_tot += num
+ val_data.append(append_data)
+ val_targets.append(append_targets)
+ if type == "ratio":
+ new_num_tot = int(old_num_tot*(total_classes-known_classes)/known_classes)
+ elif type == "same":
+ new_num_tot = old_num_tot
+ else:
+ assert 0, "not implemented yet"
+ new_num_average = int(new_num_tot/(total_classes-known_classes))
+ for idx in range(known_classes,total_classes):
+ class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1)
+ val_indx = np.random.choice(len(class_data),new_num_average, replace=False)
+ val_data.append(class_data[val_indx])
+ val_targets.append(class_targets[val_indx])
+ val_data=np.concatenate(val_data)
+ val_targets = np.concatenate(val_targets)
+ return DummyDataset(val_data, val_targets, trsf, self.use_path)
+
+ def get_dataset_with_split(
+ self, indices, source, mode, appendent=None, val_samples_per_class=0
+ ):
+ if source == "train":
+ x, y = self._train_data, self._train_targets
+ elif source == "test":
+ x, y = self._test_data, self._test_targets
+ else:
+ raise ValueError("Unknown data source {}.".format(source))
+
+ if mode == "train":
+ trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
+ elif mode == "test":
+ trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
+ else:
+ raise ValueError("Unknown mode {}.".format(mode))
+
+ train_data, train_targets = [], []
+ val_data, val_targets = [], []
+ for idx in indices:
+ class_data, class_targets = self._select(
+ x, y, low_range=idx, high_range=idx + 1
+ )
+ val_indx = np.random.choice(
+ len(class_data), val_samples_per_class, replace=False
+ )
+ train_indx = list(set(np.arange(len(class_data))) - set(val_indx))
+ val_data.append(class_data[val_indx])
+ val_targets.append(class_targets[val_indx])
+ train_data.append(class_data[train_indx])
+ train_targets.append(class_targets[train_indx])
+
+ if appendent is not None:
+ appendent_data, appendent_targets = appendent
+ for idx in range(0, int(np.max(appendent_targets)) + 1):
+ append_data, append_targets = self._select(
+ appendent_data, appendent_targets, low_range=idx, high_range=idx + 1
+ )
+ val_indx = np.random.choice(
+ len(append_data), val_samples_per_class, replace=False
+ )
+ train_indx = list(set(np.arange(len(append_data))) - set(val_indx))
+ val_data.append(append_data[val_indx])
+ val_targets.append(append_targets[val_indx])
+ train_data.append(append_data[train_indx])
+ train_targets.append(append_targets[train_indx])
+
+ train_data, train_targets = np.concatenate(train_data), np.concatenate(
+ train_targets
+ )
+ val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets)
+
+ return DummyDataset(
+ train_data, train_targets, trsf, self.use_path
+ ), DummyDataset(val_data, val_targets, trsf, self.use_path)
+
+ def _setup_data(self, dataset_name, shuffle, seed, data = None):
+ idata = _get_idata(dataset_name, data = data)
+ self.label_list = idata.download_data()
+ # Data
+ self._train_data, self._train_targets = idata.train_data, idata.train_targets
+ self._test_data, self._test_targets = idata.test_data, idata.test_targets
+ self.use_path = idata.use_path
+ # Transforms
+ self._train_trsf = idata.train_trsf
+ self._test_trsf = idata.test_trsf
+ self._common_trsf = idata.common_trsf
+
+ # Order
+ order = np.unique(self._train_targets)
+ if shuffle:
+ np.random.seed(seed)
+ order = np.random.permutation(order).tolist()
+ else:
+            order = list(idata.class_order)  # class_order may be a list or an ndarray
+ if data['class_list'][0] != -1:
+ self._class_order = np.concatenate((np.array(data['class_list']), order)).tolist()
+ else:
+ self._class_order = order
+ logging.info(self._class_order)
+ # Map indices
+ self._train_targets = _map_new_class_index(
+ self._train_targets, self._class_order,
+ )
+ self._test_targets = _map_new_class_index(self._test_targets, self._class_order)
+
+ def _select(self, x, y, low_range, high_range):
+ idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
+ if isinstance(x,np.ndarray):
+ x_return = x[idxes]
+ else:
+            x_return = []
+            for i in idxes:
+                x_return.append(x[i])
+        return x_return, y[idxes]
+
+ def _select_rmm(self, x, y, low_range, high_range, m_rate):
+ assert m_rate is not None
+ if m_rate != 0:
+ idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
+ selected_idxes = np.random.randint(
+ 0, len(idxes), size=int((1 - m_rate) * len(idxes))
+ )
+ new_idxes = idxes[selected_idxes]
+ new_idxes = np.sort(new_idxes)
+ else:
+ new_idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
+ return x[new_idxes], y[new_idxes]
+
+    def getlen(self, index):
+        # number of training samples labeled `index`
+        y = self._train_targets
+        return np.sum(y == index)
+
+
+class DummyDataset(Dataset):
+ def __init__(self, images, labels, trsf, use_path=False):
+ assert len(images) == len(labels), "Data size error!"
+ self.images = images
+ self.labels = labels
+ self.trsf = trsf
+ self.use_path = use_path
+
+ def __len__(self):
+ return len(self.images)
+
+ def __getitem__(self, idx):
+ if self.use_path:
+ image = self.trsf(pil_loader(self.images[idx]))
+ else:
+ image = self.trsf(Image.fromarray(self.images[idx]))
+ label = self.labels[idx]
+
+ return idx, image, label
+
+
+def _map_new_class_index(y, order):
+ return np.array(list(map(lambda x: order.index(x), y)))
+
+
+def _get_idata(dataset_name, data = None):
+ name = dataset_name.lower()
+ if name == "cifar10":
+ return iCIFAR10()
+ elif name == "cifar100":
+ return iCIFAR100()
+ elif name == "imagenet1000":
+ return iImageNet1000()
+ elif name == "imagenet100":
+ return iImageNet100()
+ elif name == 'stanfordcar':
+ return StanfordCar()
+ elif name == 'general_dataset':
+ print(data)
+ return GeneralDataset(data["path"], init_class_list = data["class_list"]);
+ else:
+ raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
+
+
+def pil_loader(path):
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+ with open(path, "rb") as f:
+ img = Image.open(f)
+ return img.convert("RGB")
+
+
+def accimage_loader(path):
+ import accimage
+
+ try:
+ return accimage.Image(path)
+ except IOError:
+ # Potentially a decoding problem, fall back to PIL.Image
+ return pil_loader(path)
+
+
+def default_loader(path):
+ from torchvision import get_image_backend
+
+ if get_image_backend() == "accimage":
+ return accimage_loader(path)
+ else:
+ return pil_loader(path)
diff --git a/utils/factory.py b/utils/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ed69d14cbf7736b38791d4bb10759dd06b2cc4
--- /dev/null
+++ b/utils/factory.py
@@ -0,0 +1,67 @@
+def get_model(model_name, args):
+ name = model_name.lower()
+ if name == "icarl":
+ from models.icarl import iCaRL
+ return iCaRL(args)
+ elif name == "bic":
+ from models.bic import BiC
+ return BiC(args)
+ elif name == "podnet":
+ from models.podnet import PODNet
+ return PODNet(args)
+ elif name == "lwf":
+ from models.lwf import LwF
+ return LwF(args)
+ elif name == "ewc":
+ from models.ewc import EWC
+ return EWC(args)
+ elif name == "wa":
+ from models.wa import WA
+ return WA(args)
+ elif name == "der":
+ from models.der import DER
+ return DER(args)
+ elif name == "finetune":
+ from models.finetune import Finetune
+ return Finetune(args)
+ elif name == "replay":
+ from models.replay import Replay
+ return Replay(args)
+ elif name == "gem":
+ from models.gem import GEM
+ return GEM(args)
+ elif name == "coil":
+ from models.coil import COIL
+ return COIL(args)
+ elif name == "foster":
+ from models.foster import FOSTER
+ return FOSTER(args)
+ elif name == "rmm-icarl":
+ from models.rmm import RMM_FOSTER, RMM_iCaRL
+ return RMM_iCaRL(args)
+ elif name == "rmm-foster":
+ from models.rmm import RMM_FOSTER, RMM_iCaRL
+ return RMM_FOSTER(args)
+ elif name == "fetril":
+ from models.fetril import FeTrIL
+ return FeTrIL(args)
+ elif name == "pass":
+ from models.pa2s import PASS
+ return PASS(args)
+ elif name == "il2a":
+ from models.il2a import IL2A
+ return IL2A(args)
+ elif name == "ssre":
+ from models.ssre import SSRE
+ return SSRE(args)
+ elif name == "memo":
+ from models.memo import MEMO
+ return MEMO(args)
+ elif name == "beefiso":
+ from models.beef_iso import BEEFISO
+ return BEEFISO(args)
+ elif name == "simplecil":
+ from models.simplecil import SimpleCIL
+ return SimpleCIL(args)
+    else:
+        raise NotImplementedError("Unknown model {}.".format(model_name))
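+
+# Example (model names come from the "model_name" field of a config JSON such
+# as exps/simplecil_general.json):
+# model = get_model("simplecil", args)  # args is the parsed config dict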
diff --git a/utils/inc_net.py b/utils/inc_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..53fe3721aedc95ccf0169b911369fa6f525f83d5
--- /dev/null
+++ b/utils/inc_net.py
@@ -0,0 +1,799 @@
+import copy
+import logging
+import torch
+from torch import nn
+from convs.cifar_resnet import resnet32
+from convs.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
+from convs.ucir_cifar_resnet import resnet32 as cosine_resnet32
+from convs.ucir_resnet import resnet18 as cosine_resnet18
+from convs.ucir_resnet import resnet34 as cosine_resnet34
+from convs.ucir_resnet import resnet50 as cosine_resnet50
+from convs.linears import SimpleLinear, SplitCosineLinear, CosineLinear
+from convs.modified_represnet import resnet18_rep,resnet34_rep
+from convs.resnet_cbam import resnet18_cbam,resnet34_cbam,resnet50_cbam
+from convs.memo_resnet import get_resnet18_imagenet as get_memo_resnet18 #for MEMO imagenet
+from convs.memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 #for MEMO cifar
+
+def get_convnet(args, pretrained=False):
+ name = args["convnet_type"].lower()
+ if name == "resnet32":
+ return resnet32()
+ elif name == "resnet18":
+ return resnet18(pretrained=pretrained,args=args)
+ elif name == "resnet34":
+ return resnet34(pretrained=pretrained,args=args)
+ elif name == "resnet50":
+ return resnet50(pretrained=pretrained,args=args)
+ elif name == "cosine_resnet18":
+ return cosine_resnet18(pretrained=pretrained,args=args)
+ elif name == "cosine_resnet32":
+ return cosine_resnet32()
+ elif name == "cosine_resnet34":
+ return cosine_resnet34(pretrained=pretrained,args=args)
+ elif name == "cosine_resnet50":
+ return cosine_resnet50(pretrained=pretrained,args=args)
+ elif name == "resnet18_rep":
+ return resnet18_rep(pretrained=pretrained,args=args)
+ elif name == "resnet18_cbam":
+ return resnet18_cbam(pretrained=pretrained,args=args)
+ elif name == "resnet34_cbam":
+ return resnet34_cbam(pretrained=pretrained,args=args)
+ elif name == "resnet50_cbam":
+ return resnet50_cbam(pretrained=pretrained,args=args)
+
+ # MEMO benchmark backbone
+ elif name == 'memo_resnet18':
+ _basenet, _adaptive_net = get_memo_resnet18()
+ return _basenet, _adaptive_net
+ elif name == 'memo_resnet32':
+ _basenet, _adaptive_net = get_memo_resnet32()
+ return _basenet, _adaptive_net
+
+ else:
+ raise NotImplementedError("Unknown type {}".format(name))
+
+
+class BaseNet(nn.Module):
+ def __init__(self, args, pretrained):
+ super(BaseNet, self).__init__()
+
+ self.convnet = get_convnet(args, pretrained)
+ self.fc = None
+
+ @property
+ def feature_dim(self):
+ return self.convnet.out_dim
+
+ def extract_vector(self, x):
+ return self.convnet(x)["features"]
+
+ def forward(self, x):
+ x = self.convnet(x)
+ out = self.fc(x["features"])
+ """
+ {
+ 'fmaps': [x_1, x_2, ..., x_n],
+ 'features': features
+ 'logits': logits
+ }
+ """
+ out.update(x)
+
+ return out
+
+ def update_fc(self, nb_classes):
+ pass
+
+ def generate_fc(self, in_dim, out_dim):
+ pass
+
+ def copy(self):
+ return copy.deepcopy(self)
+
+ def freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self.eval()
+
+ return self
+
+ def load_checkpoint(self, args):
+ if args["init_cls"] == 50:
+ pkl_name = "{}_{}_{}_B{}_Inc{}".format(
+ args["dataset"],
+ args["seed"],
+ args["convnet_type"],
+ 0,
+ args["init_cls"],
+ )
+ checkpoint_name = f"checkpoints/finetune_{pkl_name}_0.pkl"
+ else:
+ checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
+ model_infos = torch.load(checkpoint_name)
+ self.convnet.load_state_dict(model_infos['convnet'])
+ self.fc.load_state_dict(model_infos['fc'])
+ test_acc = model_infos['test_acc']
+ return test_acc
+
+class IncrementalNet(BaseNet):
+ def __init__(self, args, pretrained, gradcam=False):
+ super().__init__(args, pretrained)
+ self.gradcam = gradcam
+ if hasattr(self, "gradcam") and self.gradcam:
+ self._gradcam_hooks = [None, None]
+ self.set_gradcam_hook()
+
+ def update_fc(self, nb_classes):
+ fc = self.generate_fc(self.feature_dim, nb_classes)
+ if self.fc is not None:
+ nb_output = self.fc.out_features
+ weight = copy.deepcopy(self.fc.weight.data)
+ bias = copy.deepcopy(self.fc.bias.data)
+ fc.weight.data[:nb_output] = weight
+ fc.bias.data[:nb_output] = bias
+
+ del self.fc
+ self.fc = fc
+
+ def weight_align(self, increment):
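+        # Weight Aligning: rescale the newly added classes' classifier weights so
+        # their mean L2 norm matches the old classes', reducing the prediction
+        # bias toward new classes after each incremental step.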
+ weights = self.fc.weight.data
+ newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
+ oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
+ meannew = torch.mean(newnorm)
+ meanold = torch.mean(oldnorm)
+ gamma = meanold / meannew
+ print("alignweights,gamma=", gamma)
+ self.fc.weight.data[-increment:, :] *= gamma
+
+ def generate_fc(self, in_dim, out_dim):
+ fc = SimpleLinear(in_dim, out_dim)
+
+ return fc
+
+ def forward(self, x):
+ x = self.convnet(x)
+ out = self.fc(x["features"])
+ out.update(x)
+ if hasattr(self, "gradcam") and self.gradcam:
+ out["gradcam_gradients"] = self._gradcam_gradients
+ out["gradcam_activations"] = self._gradcam_activations
+
+ return out
+
+ def unset_gradcam_hook(self):
+ self._gradcam_hooks[0].remove()
+ self._gradcam_hooks[1].remove()
+ self._gradcam_hooks[0] = None
+ self._gradcam_hooks[1] = None
+ self._gradcam_gradients, self._gradcam_activations = [None], [None]
+
+ def set_gradcam_hook(self):
+ self._gradcam_gradients, self._gradcam_activations = [None], [None]
+
+ def backward_hook(module, grad_input, grad_output):
+ self._gradcam_gradients[0] = grad_output[0]
+ return None
+
+ def forward_hook(module, input, output):
+ self._gradcam_activations[0] = output
+ return None
+
+ self._gradcam_hooks[0] = self.convnet.last_conv.register_backward_hook(
+ backward_hook
+ )
+ self._gradcam_hooks[1] = self.convnet.last_conv.register_forward_hook(
+ forward_hook
+ )
+
+class IL2ANet(IncrementalNet):
+
+ def update_fc(self, num_old, num_total, num_aux):
+ fc = self.generate_fc(self.feature_dim, num_total+num_aux)
+ if self.fc is not None:
+ weight = copy.deepcopy(self.fc.weight.data)
+ bias = copy.deepcopy(self.fc.bias.data)
+ fc.weight.data[:num_old] = weight[:num_old]
+ fc.bias.data[:num_old] = bias[:num_old]
+ del self.fc
+ self.fc = fc
+
+class CosineIncrementalNet(BaseNet):
+ def __init__(self, args, pretrained, nb_proxy=1):
+ super().__init__(args, pretrained)
+ self.nb_proxy = nb_proxy
+
+ def update_fc(self, nb_classes, task_num):
+ fc = self.generate_fc(self.feature_dim, nb_classes)
+ if self.fc is not None:
+ if task_num == 1:
+ fc.fc1.weight.data = self.fc.weight.data
+ fc.sigma.data = self.fc.sigma.data
+ else:
+ prev_out_features1 = self.fc.fc1.out_features
+ fc.fc1.weight.data[:prev_out_features1] = self.fc.fc1.weight.data
+ fc.fc1.weight.data[prev_out_features1:] = self.fc.fc2.weight.data
+ fc.sigma.data = self.fc.sigma.data
+
+ del self.fc
+ self.fc = fc
+
+ def generate_fc(self, in_dim, out_dim):
+ if self.fc is None:
+ fc = CosineLinear(in_dim, out_dim, self.nb_proxy, to_reduce=True)
+ else:
+ prev_out_features = self.fc.out_features // self.nb_proxy
+ # prev_out_features = self.fc.out_features
+ fc = SplitCosineLinear(
+ in_dim, prev_out_features, out_dim - prev_out_features, self.nb_proxy
+ )
+
+ return fc
+
+
+class BiasLayer_BIC(nn.Module):
+ def __init__(self):
+ super(BiasLayer_BIC, self).__init__()
+ self.alpha = nn.Parameter(torch.ones(1, requires_grad=True))
+ self.beta = nn.Parameter(torch.zeros(1, requires_grad=True))
+
+ def forward(self, x, low_range, high_range):
+ ret_x = x.clone()
+ ret_x[:, low_range:high_range] = (
+ self.alpha * x[:, low_range:high_range] + self.beta
+ )
+ return ret_x
+
+ def get_params(self):
+ return (self.alpha.item(), self.beta.item())
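+
+ # Usage sketch (hypothetical ranges): apply the learned affine correction to
+ # the newest task's logit columns only:
+ # layer = BiasLayer_BIC()
+ # corrected = layer(logits, low_range=50, high_range=60)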
+
+
+class IncrementalNetWithBias(BaseNet):
+ def __init__(self, args, pretrained, bias_correction=False):
+ super().__init__(args, pretrained)
+
+ # Bias layer
+ self.bias_correction = bias_correction
+ self.bias_layers = nn.ModuleList([])
+ self.task_sizes = []
+
+ def forward(self, x):
+ x = self.convnet(x)
+ out = self.fc(x["features"])
+ if self.bias_correction:
+ logits = out["logits"]
+ for i, layer in enumerate(self.bias_layers):
+ logits = layer(
+ logits, sum(self.task_sizes[:i]), sum(self.task_sizes[: i + 1])
+ )
+ out["logits"] = logits
+
+ out.update(x)
+
+ return out
+
+ def update_fc(self, nb_classes):
+ fc = self.generate_fc(self.feature_dim, nb_classes)
+ if self.fc is not None:
+ nb_output = self.fc.out_features
+ weight = copy.deepcopy(self.fc.weight.data)
+ bias = copy.deepcopy(self.fc.bias.data)
+ fc.weight.data[:nb_output] = weight
+ fc.bias.data[:nb_output] = bias
+
+ del self.fc
+ self.fc = fc
+
+ new_task_size = nb_classes - sum(self.task_sizes)
+ self.task_sizes.append(new_task_size)
+ self.bias_layers.append(BiasLayer_BIC())
+
+ def generate_fc(self, in_dim, out_dim):
+ fc = SimpleLinear(in_dim, out_dim)
+
+ return fc
+
+ def get_bias_params(self):
+ params = []
+ for layer in self.bias_layers:
+ params.append(layer.get_params())
+
+ return params
+
+ def unfreeze(self):
+ for param in self.parameters():
+ param.requires_grad = True
+
+
+class DERNet(nn.Module):
+ def __init__(self, args, pretrained):
+ super(DERNet, self).__init__()
+ self.convnet_type = args["convnet_type"]
+ self.convnets = nn.ModuleList()
+ self.pretrained = pretrained
+ self.out_dim = None
+ self.fc = None
+ self.aux_fc = None
+ self.task_sizes = []
+ self.args = args
+
+ @property
+ def feature_dim(self):
+ if self.out_dim is None:
+ return 0
+ return self.out_dim * len(self.convnets)
+
+ def extract_vector(self, x):
+ features = [convnet(x)["features"] for convnet in self.convnets]
+ features = torch.cat(features, 1)
+ return features
+
+ def forward(self, x):
+ features = [convnet(x)["features"] for convnet in self.convnets]
+ features = torch.cat(features, 1)
+
+ out = self.fc(features) # {"logits": ...}
+
+ aux_logits = self.aux_fc(features[:, -self.out_dim :])["logits"]
+
+ out.update({"aux_logits": aux_logits, "features": features})
+ # forward() returns:
+ # {
+ # "features": features,
+ # "logits": logits,
+ # "aux_logits": aux_logits,
+ # }
+ return out
+
+ def update_fc(self, nb_classes):
+ if len(self.convnets) == 0:
+ self.convnets.append(get_convnet(self.args))
+ else:
+ self.convnets.append(get_convnet(self.args))
+ self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())
+
+ if self.out_dim is None:
+ self.out_dim = self.convnets[-1].out_dim
+ fc = self.generate_fc(self.feature_dim, nb_classes)
+ if self.fc is not None:
+ nb_output = self.fc.out_features
+ weight = copy.deepcopy(self.fc.weight.data)
+ bias = copy.deepcopy(self.fc.bias.data)
+ fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
+ fc.bias.data[:nb_output] = bias
+
+ del self.fc
+ self.fc = fc
+
+ new_task_size = nb_classes - sum(self.task_sizes)
+ self.task_sizes.append(new_task_size)
+
+ self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)
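+
+ # Shape note (hypothetical dims): with out_dim = 512 and three stored
+ # convnets, feature_dim is 1536; the old fc fills the top-left block of the
+ # new (nb_classes x 1536) weight, while aux_fc scores only the newest 512-d
+ # slice over new_task_size + 1 outputs (the extra output presumably standing
+ # in for "any old class").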
+
+ def generate_fc(self, in_dim, out_dim):
+ fc = SimpleLinear(in_dim, out_dim)
+
+ return fc
+
+ def copy(self):
+ return copy.deepcopy(self)
+
+ def freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self.eval()
+
+ return self
+
+ def freeze_conv(self):
+ for param in self.convnets.parameters():
+ param.requires_grad = False
+ self.convnets.eval()
+
+ def weight_align(self, increment):
+ weights = self.fc.weight.data
+ newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
+ oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
+ meannew = torch.mean(newnorm)
+ meanold = torch.mean(oldnorm)
+ gamma = meanold / meannew
+ print("alignweights,gamma=", gamma)
+ self.fc.weight.data[-increment:, :] *= gamma
+
+ def load_checkpoint(self, args):
+ checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
+ model_infos = torch.load(checkpoint_name)
+ assert len(self.convnets) == 1
+ self.convnets[0].load_state_dict(model_infos['convnet'])
+ self.fc.load_state_dict(model_infos['fc'])
+ test_acc = model_infos['test_acc']
+ return test_acc
+
+
+class SimpleCosineIncrementalNet(BaseNet):
+ def __init__(self, args, pretrained):
+ super().__init__(args, pretrained)
+
+ def update_fc(self, nb_classes, nextperiod_initialization=None):
+ fc = self.generate_fc(self.feature_dim, nb_classes).cuda()
+ if self.fc is not None:
+ nb_output = self.fc.out_features
+ weight = copy.deepcopy(self.fc.weight.data)
+ fc.sigma.data = self.fc.sigma.data
+ if nextperiod_initialization is not None:
+ weight = torch.cat([weight.cuda(), nextperiod_initialization.cuda()])
+ else:
+ weight = torch.cat([weight.cuda(), torch.zeros(nb_classes - nb_output, self.feature_dim).cuda()])
+ fc.weight = nn.Parameter(weight)
+ del self.fc
+ self.fc = fc
+
+ def load_checkpoint(self, checkpoint):
+ self.convnet.load_state_dict(checkpoint["convnet"])
+ self.fc.load_state_dict(checkpoint["fc"])
+
+ def generate_fc(self, in_dim, out_dim):
+ fc = CosineLinear(in_dim, out_dim)
+ return fc
+
+
+class FOSTERNet(nn.Module):
+ def __init__(self, args, pretrained):
+ super(FOSTERNet, self).__init__()
+ self.convnet_type = args["convnet_type"]
+ self.convnets = nn.ModuleList()
+ self.pretrained = pretrained
+ self.out_dim = None
+ self.fc = None
+ self.fe_fc = None
+ self.task_sizes = []
+ self.oldfc = None
+ self.args = args
+
+ @property
+ def feature_dim(self):
+ if self.out_dim is None:
+ return 0
+ return self.out_dim * len(self.convnets)
+
+ def extract_vector(self, x):
+ features = [convnet(x)["features"] for convnet in self.convnets]
+ features = torch.cat(features, 1)
+ return features
+
+ def load_checkpoint(self, checkpoint):
+ if len(self.convnets) == 0:
+ self.convnets.append(get_convnet(self.args))
+ self.convnets[0].load_state_dict(checkpoint["convnet"])
+ self.fc.load_state_dict(checkpoint["fc"])
+
+ def forward(self, x):
+ features = [convnet(x)["features"] for convnet in self.convnets]
+ features = torch.cat(features, 1)
+ out = self.fc(features)
+ fe_logits = self.fe_fc(features[:, -self.out_dim :])["logits"]
+
+ out.update({"fe_logits": fe_logits, "features": features})
+
+ if self.oldfc is not None:
+ old_logits = self.oldfc(features[:, : -self.out_dim])["logits"]
+ out.update({"old_logits": old_logits})
+
+ out.update({"eval_logits": out["logits"]})
+ return out
+
+ def update_fc(self, nb_classes):
+ self.convnets.append(get_convnet(self.args))
+ if self.out_dim is None:
+ self.out_dim = self.convnets[-1].out_dim
+ fc = self.generate_fc(self.feature_dim, nb_classes)
+ if self.fc is not None:
+ nb_output = self.fc.out_features
+ weight = copy.deepcopy(self.fc.weight.data)
+ bias = copy.deepcopy(self.fc.bias.data)
+ fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
+ fc.bias.data[:nb_output] = bias
+ self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())
+
+ self.oldfc = self.fc
+ self.fc = fc
+ new_task_size = nb_classes - sum(self.task_sizes)
+ self.task_sizes.append(new_task_size)
+ self.fe_fc = self.generate_fc(self.out_dim, nb_classes)
+
+ def generate_fc(self, in_dim, out_dim):
+ fc = SimpleLinear(in_dim, out_dim)
+ return fc
+
+ def copy(self):
+ return copy.deepcopy(self)
+
+ def copy_fc(self, fc):
+ weight = copy.deepcopy(fc.weight.data)
+ bias = copy.deepcopy(fc.bias.data)
+ n, m = weight.shape[0], weight.shape[1]
+ self.fc.weight.data[:n, :m] = weight
+ self.fc.bias.data[:n] = bias
+
+ def freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self.eval()
+ return self
+
+ def freeze_conv(self):
+ for param in self.convnets.parameters():
+ param.requires_grad = False
+ self.convnets.eval()
+
+ def weight_align(self, old, increment, value):
+ weights = self.fc.weight.data
+ newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
+ oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
+ meannew = torch.mean(newnorm)
+ meanold = torch.mean(oldnorm)
+ gamma = meanold / meannew * (value ** (old / increment))
+ logging.info("align weights, gamma = {} ".format(gamma))
+ self.fc.weight.data[-increment:, :] *= gamma
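+
+ # Worked example (hypothetical numbers): with old=50, increment=10 and
+ # value=0.5, gamma = (meanold / meannew) * 0.5 ** 5, shrinking the new rows
+ # far more aggressively than plain norm alignment would.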
+
+
+class BiasLayer(nn.Module):
+ def __init__(self):
+ super(BiasLayer, self).__init__()
+ self.alpha = nn.Parameter(torch.zeros(1, requires_grad=True))
+ self.beta = nn.Parameter(torch.zeros(1, requires_grad=True))
+
+ def forward(self, x, bias=True):
+ ret_x = (self.alpha + 1) * x
+ if bias:
+ ret_x = ret_x + self.beta
+ return ret_x
+
+ def get_params(self):
+ return (self.alpha.item(), self.beta.item())
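+
+ # Note: alpha is initialised to zero, so (alpha + 1) starts as an identity
+ # scale; beta is added only when forward() is called with bias=True.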
+
+
+class BEEFISONet(nn.Module):
+ def __init__(self, args, pretrained):
+ super(BEEFISONet, self).__init__()
+ self.convnet_type = args["convnet_type"]
+ self.convnets = nn.ModuleList()
+ self.pretrained = pretrained
+ self.out_dim = None
+ self.old_fc = None
+ self.new_fc = None
+ self.task_sizes = []
+ self.forward_prototypes = None
+ self.backward_prototypes = None
+ self.args = args
+ self.biases = nn.ModuleList()
+
+ @property
+ def feature_dim(self):
+ if self.out_dim is None:
+ return 0
+ return self.out_dim * len(self.convnets)
+
+ def extract_vector(self, x):
+ features = [convnet(x)["features"] for convnet in self.convnets]
+ features = torch.cat(features, 1)
+ return features
+
+ def forward(self, x):
+ features = [convnet(x)["features"] for convnet in self.convnets]
+ features = torch.cat(features, 1)
+
+ if self.old_fc is None:
+ fc = self.new_fc
+ out = fc(features)
+ else:
+ # merge the old and new classifier weights into one evaluation head
+ new_task_size = self.task_sizes[-1]
+ fc_weight = torch.cat([self.old_fc.weight, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()], dim=0)
+ new_fc_weight = self.new_fc.weight
+ new_fc_bias = self.new_fc.bias
+ for i in range(len(self.task_sizes) - 2, -1, -1):
+ new_fc_weight = torch.cat([*[self.biases[i](self.backward_prototypes.weight[i].unsqueeze(0), bias=False) for _ in range(self.task_sizes[i])], new_fc_weight], dim=0)
+ new_fc_bias = torch.cat([*[self.biases[i](self.backward_prototypes.bias[i].unsqueeze(0), bias=True) for _ in range(self.task_sizes[i])], new_fc_bias])
+ fc_weight = torch.cat([fc_weight, new_fc_weight], dim=1)
+ fc_bias = torch.cat([self.old_fc.bias, torch.zeros(new_task_size).cuda()])
+ fc_bias += new_fc_bias
+ logits = features @ fc_weight.permute(1, 0) + fc_bias
+ out = {"logits": logits}
+
+ new_fc_weight = self.new_fc.weight
+ new_fc_bias = self.new_fc.bias
+ for i in range(len(self.task_sizes) - 2, -1, -1):
+ new_fc_weight = torch.cat([self.backward_prototypes.weight[i].unsqueeze(0), new_fc_weight], dim=0)
+ new_fc_bias = torch.cat([self.backward_prototypes.bias[i].unsqueeze(0), new_fc_bias])
+ out["train_logits"] = features[:, -self.out_dim:] @ new_fc_weight.permute(1, 0) + new_fc_bias
+ out.update({"eval_logits": out["logits"], "energy_logits": self.forward_prototypes(features[:, -self.out_dim:])["logits"]})
+ return out
+
+ def update_fc_before(self, nb_classes):
+ new_task_size = nb_classes - sum(self.task_sizes)
+ self.biases = nn.ModuleList([BiasLayer() for i in range(len(self.task_sizes))])
+ self.convnets.append(get_convnet(self.args))
+ if self.out_dim is None:
+ self.out_dim = self.convnets[-1].out_dim
+ if self.new_fc is not None:
+ self.fe_fc = self.generate_fc(self.out_dim, nb_classes)
+ self.backward_prototypes = self.generate_fc(self.out_dim, len(self.task_sizes))
+ self.convnets[-1].load_state_dict(self.convnets[0].state_dict())
+ self.forward_prototypes = self.generate_fc(self.out_dim, nb_classes)
+ self.new_fc = self.generate_fc(self.out_dim, new_task_size)
+ self.task_sizes.append(new_task_size)
+
+ def generate_fc(self, in_dim, out_dim):
+ fc = SimpleLinear(in_dim, out_dim)
+ return fc
+
+ def update_fc_after(self):
+ if self.old_fc is not None:
+ old_fc = self.generate_fc(self.feature_dim, sum(self.task_sizes))
+ new_task_size = self.task_sizes[-1]
+ old_fc.weight.data = torch.cat([self.old_fc.weight.data, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()], dim=0)
+ new_fc_weight = self.new_fc.weight.data
+ new_fc_bias = self.new_fc.bias.data
+ for i in range(len(self.task_sizes) - 2, -1, -1):
+ new_fc_weight = torch.cat([*[self.biases[i](self.backward_prototypes.weight.data[i].unsqueeze(0), bias=False) for _ in range(self.task_sizes[i])], new_fc_weight], dim=0)
+ new_fc_bias = torch.cat([*[self.biases[i](self.backward_prototypes.bias.data[i].unsqueeze(0), bias=True) for _ in range(self.task_sizes[i])], new_fc_bias])
+ old_fc.weight.data = torch.cat([old_fc.weight.data, new_fc_weight], dim=1)
+ old_fc.bias.data = torch.cat([self.old_fc.bias.data, torch.zeros(new_task_size).cuda()])
+ old_fc.bias.data += new_fc_bias
+ self.old_fc = old_fc
+ else:
+ self.old_fc = self.new_fc
+
+ def copy(self):
+ return copy.deepcopy(self)
+
+ def copy_fc(self, fc):
+ weight = copy.deepcopy(fc.weight.data)
+ bias = copy.deepcopy(fc.bias.data)
+ n, m = weight.shape[0], weight.shape[1]
+ self.fc.weight.data[:n, :m] = weight
+ self.fc.bias.data[:n] = bias
+
+ def freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self.eval()
+ return self
+
+ def freeze_conv(self):
+ for param in self.convnets.parameters():
+ param.requires_grad = False
+ self.convnets.eval()
+
+ def weight_align(self, old, increment, value):
+ weights = self.fc.weight.data
+ newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
+ oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
+ meannew = torch.mean(newnorm)
+ meanold = torch.mean(oldnorm)
+ gamma = meanold / meannew * (value ** (old / increment))
+ logging.info("align weights, gamma = {} ".format(gamma))
+ self.fc.weight.data[-increment:, :] *= gamma
+
+
+class AdaptiveNet(nn.Module):
+ def __init__(self, args, pretrained):
+ super(AdaptiveNet, self).__init__()
+ self.convnet_type = args["convnet_type"]
+ self.TaskAgnosticExtractor, _network = get_convnet(args, pretrained) # generalized blocks
+ self.TaskAgnosticExtractor.train()
+ self.AdaptiveExtractors = nn.ModuleList() # specialized blocks
+ self.AdaptiveExtractors.append(_network)
+ self.pretrained = pretrained
+ if args["backbone"] is not None and pretrained:
+ self.load_checkpoint(args)
+ self.out_dim = None
+ self.fc = None
+ self.aux_fc = None
+ self.task_sizes = []
+ self.args = args
+
+ @property
+ def feature_dim(self):
+ if self.out_dim is None:
+ return 0
+ return self.out_dim*len(self.AdaptiveExtractors)
+
+ def extract_vector(self, x):
+ base_feature_map = self.TaskAgnosticExtractor(x)
+ features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
+ features = torch.cat(features, 1)
+ return features
+
+ def forward(self, x):
+ base_feature_map = self.TaskAgnosticExtractor(x)
+ features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
+ features = torch.cat(features, 1)
+ out = self.fc(features) # {"logits": ...}
+
+ aux_logits = self.aux_fc(features[:, -self.out_dim:])["logits"]
+
+ out.update({"aux_logits": aux_logits, "features": features})
+ out.update({"base_features": base_feature_map})
+ return out
+
+ # forward() returns:
+ # {
+ # "features": features,
+ # "logits": logits,
+ # "aux_logits": aux_logits,
+ # "base_features": base_feature_map,
+ # }
+
+ def update_fc(self, nb_classes):
+ _, _new_extractor = get_convnet(self.args)
+ if len(self.AdaptiveExtractors) == 0:
+ self.AdaptiveExtractors.append(_new_extractor)
+ else:
+ self.AdaptiveExtractors.append(_new_extractor)
+ self.AdaptiveExtractors[-1].load_state_dict(self.AdaptiveExtractors[-2].state_dict())
+
+ if self.out_dim is None:
+ logging.info(self.AdaptiveExtractors[-1])
+ self.out_dim = self.AdaptiveExtractors[-1].feature_dim
+ fc = self.generate_fc(self.feature_dim, nb_classes)
+ if self.fc is not None:
+ nb_output = self.fc.out_features
+ weight = copy.deepcopy(self.fc.weight.data)
+ bias = copy.deepcopy(self.fc.bias.data)
+ fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
+ fc.bias.data[:nb_output] = bias
+
+ del self.fc
+ self.fc = fc
+
+ new_task_size = nb_classes - sum(self.task_sizes)
+ self.task_sizes.append(new_task_size)
+ self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)
+
+ def generate_fc(self, in_dim, out_dim):
+ fc = SimpleLinear(in_dim, out_dim)
+ return fc
+
+ def copy(self):
+ return copy.deepcopy(self)
+
+ def weight_align(self, increment):
+ weights = self.fc.weight.data
+ newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
+ oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
+ meannew = torch.mean(newnorm)
+ meanold = torch.mean(oldnorm)
+ gamma = meanold / meannew
+ print("alignweights,gamma=", gamma)
+ self.fc.weight.data[-increment:, :] *= gamma
+
+ def load_checkpoint(self, args):
+ checkpoint_name = args["backbone"]
+ model_infos = torch.load(checkpoint_name)
+ model_dict = model_infos['convnet']
+ assert len(self.AdaptiveExtractors) == 1
+
+ base_state_dict = self.TaskAgnosticExtractor.state_dict()
+ adap_state_dict = self.AdaptiveExtractors[0].state_dict()
+
+ pretrained_base_dict = {
+ k: v for k, v in model_dict.items() if k in base_state_dict
+ }
+
+ pretrained_adap_dict = {
+ k: v for k, v in model_dict.items() if k in adap_state_dict
+ }
+
+ base_state_dict.update(pretrained_base_dict)
+ adap_state_dict.update(pretrained_adap_dict)
+
+ self.TaskAgnosticExtractor.load_state_dict(base_state_dict)
+ self.AdaptiveExtractors[0].load_state_dict(adap_state_dict)
+ #self.fc.load_state_dict(model_infos['fc'])
+ test_acc = model_infos['test_acc']
+ return test_acc
diff --git a/utils/ops.py b/utils/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..66dcb775caed0ebd9105ac530acd3e70e138d72e
--- /dev/null
+++ b/utils/ops.py
@@ -0,0 +1,121 @@
+from PIL import Image, ImageEnhance, ImageOps
+import random
+import torch
+import numpy as np
+class Cutout(object):
+ def __init__(self, n_holes, length):
+ self.n_holes = n_holes
+ self.length = length
+
+ def __call__(self, img):
+ h = img.size(1)
+ w = img.size(2)
+
+ mask = np.ones((h, w), np.float32)
+
+ for n in range(self.n_holes):
+ y = np.random.randint(h)
+ x = np.random.randint(w)
+
+ y1 = np.clip(y - self.length // 2, 0, h)
+ y2 = np.clip(y + self.length // 2, 0, h)
+ x1 = np.clip(x - self.length // 2, 0, w)
+ x2 = np.clip(x + self.length // 2, 0, w)
+
+ mask[y1: y2, x1: x2] = 0.
+
+ mask = torch.from_numpy(mask)
+ mask = mask.expand_as(img)
+ img = img * mask
+
+ return img
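+
+ # Usage sketch (assumes torchvision): Cutout expects a (C, H, W) tensor, so
+ # it goes after ToTensor in a pipeline:
+ # from torchvision import transforms
+ # transform = transforms.Compose([
+ # transforms.ToTensor(),
+ # Cutout(n_holes=1, length=16),
+ # ])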
+
+class ShearX(object):
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.fillcolor = fillcolor
+
+ def __call__(self, x, magnitude):
+ return x.transform(
+ x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
+ Image.BICUBIC, fillcolor=self.fillcolor)
+
+
+class ShearY(object):
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.fillcolor = fillcolor
+
+ def __call__(self, x, magnitude):
+ return x.transform(
+ x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
+ Image.BICUBIC, fillcolor=self.fillcolor)
+
+
+class TranslateX(object):
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.fillcolor = fillcolor
+
+ def __call__(self, x, magnitude):
+ return x.transform(
+ x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0),
+ fillcolor=self.fillcolor)
+
+
+class TranslateY(object):
+ def __init__(self, fillcolor=(128, 128, 128)):
+ self.fillcolor = fillcolor
+
+ def __call__(self, x, magnitude):
+ return x.transform(
+ x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])),
+ fillcolor=self.fillcolor)
+
+
+class Rotate(object):
+ def __call__(self, x, magnitude):
+ rot = x.convert("RGBA").rotate(magnitude * random.choice([-1, 1]))
+ return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode)
+
+
+class Color(object):
+ def __call__(self, x, magnitude):
+ return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Posterize(object):
+ def __call__(self, x, magnitude):
+ return ImageOps.posterize(x, magnitude)
+
+
+class Solarize(object):
+ def __call__(self, x, magnitude):
+ return ImageOps.solarize(x, magnitude)
+
+
+class Contrast(object):
+ def __call__(self, x, magnitude):
+ return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Sharpness(object):
+ def __call__(self, x, magnitude):
+ return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Brightness(object):
+ def __call__(self, x, magnitude):
+ return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class AutoContrast(object):
+ def __call__(self, x, magnitude):
+ return ImageOps.autocontrast(x)
+
+
+class Equalize(object):
+ def __call__(self, x, magnitude):
+ return ImageOps.equalize(x)
+
+
+class Invert(object):
+ def __call__(self, x, magnitude):
+ return ImageOps.invert(x)
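+
+ # Note: AutoContrast, Equalize and Invert accept a magnitude argument only
+ # to keep a uniform operator interface; the value is ignored.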
diff --git a/utils/rl_utils/ddpg.py b/utils/rl_utils/ddpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..555e46645d91bc08e7e90e6b3d262248a8d2900f
--- /dev/null
+++ b/utils/rl_utils/ddpg.py
@@ -0,0 +1,206 @@
+import logging
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class PolicyNet(torch.nn.Module):
+ def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
+ super(PolicyNet, self).__init__()
+ self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
+ self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
+ self.action_bound = action_bound
+
+ def forward(self, x):
+ x = F.relu(self.fc1(x))
+ return torch.tanh(self.fc2(x)) * self.action_bound
+
+
+class RMMPolicyNet(torch.nn.Module):
+ def __init__(self, state_dim, hidden_dim, action_dim):
+ super(RMMPolicyNet, self).__init__()
+ self.fc1 = nn.Sequential(
+ nn.Linear(state_dim, hidden_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hidden_dim, action_dim),
+ )
+ self.fc2 = nn.Sequential(
+ nn.Linear(state_dim+action_dim, hidden_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(hidden_dim, action_dim),
+ )
+
+ def forward(self, x):
+ a1 = torch.sigmoid(self.fc1(x)) # first action head, in [0, 1]
+ x = torch.cat([x, a1], dim=1)
+ a2 = torch.tanh(self.fc2(x)) # second action head, in [-1, 1]
+ return torch.cat([a1, a2], dim=1)
+
+class QValueNet(torch.nn.Module):
+ def __init__(self, state_dim, hidden_dim, action_dim):
+ super(QValueNet, self).__init__()
+ self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
+ self.fc2 = torch.nn.Linear(hidden_dim, 1)
+
+ def forward(self, x, a):
+ cat = torch.cat([x, a], dim=1)
+ x = F.relu(self.fc1(cat))
+ return self.fc2(x)
+
+
+class TwoLayerFC(torch.nn.Module):
+ def __init__(
+ self, num_in, num_out, hidden_dim, activation=F.relu, out_fn=lambda x: x
+ ):
+ super().__init__()
+ self.fc1 = nn.Linear(num_in, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+ self.fc3 = nn.Linear(hidden_dim, num_out)
+
+ self.activation = activation
+ self.out_fn = out_fn
+
+ def forward(self, x):
+ x = self.activation(self.fc1(x))
+ x = self.activation(self.fc2(x))
+ x = self.out_fn(self.fc3(x))
+ return x
+
+
+class DDPG:
+ """DDPG algo"""
+
+ def __init__(
+ self,
+ num_in_actor,
+ num_out_actor,
+ num_in_critic,
+ hidden_dim,
+ discrete,
+ action_bound,
+ sigma,
+ actor_lr,
+ critic_lr,
+ tau,
+ gamma,
+ device,
+ use_rmm=True,
+ ):
+
+ out_fn = (lambda x: x) if discrete else (lambda x: torch.tanh(x) * action_bound)
+
+ if use_rmm:
+ self.actor = RMMPolicyNet(
+ num_in_actor,
+ hidden_dim,
+ num_out_actor,
+ ).to(device)
+ self.target_actor = RMMPolicyNet(
+ num_in_actor,
+ hidden_dim,
+ num_out_actor,
+ ).to(device)
+ else:
+ self.actor = TwoLayerFC(
+ num_in_actor,
+ num_out_actor,
+ hidden_dim,
+ activation=F.relu,
+ out_fn=out_fn,
+ ).to(device)
+ self.target_actor = TwoLayerFC(
+ num_in_actor,
+ num_out_actor,
+ hidden_dim,
+ activation=F.relu,
+ out_fn=out_fn,
+ ).to(device)
+
+ self.critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
+ self.target_critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
+ self.target_critic.load_state_dict(self.critic.state_dict())
+ self.target_actor.load_state_dict(self.actor.state_dict())
+ self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
+ self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
+ self.gamma = gamma
+ self.sigma = sigma
+ self.action_bound = action_bound
+ self.tau = tau
+ self.action_dim = num_out_actor
+ self.device = device
+
+ def take_action(self, state):
+ state = torch.tensor(np.expand_dims(state, 0), dtype=torch.float).to(self.device)
+ action = self.actor(state)[0].detach().cpu().numpy()
+ # Gaussian exploration noise, then clip each head to its valid range.
+ action = action + self.sigma * np.random.randn(self.action_dim)
+ action[0] = np.clip(action[0], 0, 1)
+ action[1] = np.clip(action[1], -1, 1)
+ return action
+
+ def save_state_dict(self, name):
+ dicts = {
+ "critic": self.critic.state_dict(),
+ "target_critic": self.target_critic.state_dict(),
+ "actor": self.actor.state_dict(),
+ "target_actor": self.target_actor.state_dict(),
+ }
+ torch.save(dicts, name)
+
+ def load_state_dict(self, name):
+ dicts = torch.load(name)
+ self.critic.load_state_dict(dicts["critic"])
+ self.target_critic.load_state_dict(dicts["target_critic"])
+ self.actor.load_state_dict(dicts["actor"])
+ self.target_actor.load_state_dict(dicts["target_actor"])
+
+ def soft_update(self, net, target_net):
+ for param_target, param in zip(target_net.parameters(), net.parameters()):
+ param_target.data.copy_(
+ param_target.data * (1.0 - self.tau) + param.data * self.tau
+ )
+
+ def update(self, transition_dict):
+ states = torch.tensor(transition_dict["states"], dtype=torch.float).to(
+ self.device
+ )
+ actions = (
+ torch.tensor(transition_dict["actions"], dtype=torch.float)
+ .to(self.device)
+ )
+ rewards = (
+ torch.tensor(transition_dict["rewards"], dtype=torch.float)
+ .view(-1, 1)
+ .to(self.device)
+ )
+ next_states = torch.tensor(
+ transition_dict["next_states"], dtype=torch.float
+ ).to(self.device)
+ dones = (
+ torch.tensor(transition_dict["dones"], dtype=torch.float)
+ .view(-1, 1)
+ .to(self.device)
+ )
+
+ next_q_values = self.target_critic(
+ torch.cat([next_states, self.target_actor(next_states)], dim=1)
+ )
+ q_targets = rewards + self.gamma * next_q_values * (1 - dones)
+ critic_loss = torch.mean(
+ F.mse_loss(
+ self.critic(torch.cat([states, actions], dim=1)),
+ q_targets,
+ )
+ )
+ self.critic_optimizer.zero_grad()
+ critic_loss.backward()
+ self.critic_optimizer.step()
+
+ actor_loss = -torch.mean(
+ self.critic(
+ torch.cat([states, self.actor(states)], dim=1)
+ )
+ )
+ self.actor_optimizer.zero_grad()
+ actor_loss.backward()
+ self.actor_optimizer.step()
+ logging.info(f"update DDPG: actor loss {actor_loss.item():.3f}, critic loss {critic_loss.item():.3f}, ")
+ self.soft_update(self.actor, self.target_actor) # soft-update the target policy net
+ self.soft_update(self.critic, self.target_critic) # soft-update the target Q value net
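+
+ # Usage sketch (hypothetical hyper-parameters), pairing DDPG with the
+ # ReplayBuffer from rl_utils:
+ # agent = DDPG(num_in_actor=4, num_out_actor=2, num_in_critic=6,
+ # hidden_dim=128, discrete=False, action_bound=1.0, sigma=0.1,
+ # actor_lr=1e-3, critic_lr=1e-2, tau=0.005, gamma=0.98,
+ # device=torch.device("cpu"))
+ # s, a, r, ns, d = buffer.sample(batch_size=64)
+ # agent.update({"states": s, "actions": a, "rewards": r,
+ # "next_states": ns, "dones": d})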
diff --git a/utils/rl_utils/rl_utils.py b/utils/rl_utils/rl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..33799bbd129890e2a8b14c48ff5086bfd51f7e97
--- /dev/null
+++ b/utils/rl_utils/rl_utils.py
@@ -0,0 +1,20 @@
+from tqdm import tqdm
+import numpy as np
+import torch
+import collections
+import random
+
+class ReplayBuffer:
+ def __init__(self, capacity):
+ self.buffer = collections.deque(maxlen=capacity)
+
+ def add(self, state, action, reward, next_state, done):
+ self.buffer.append((state, action, reward, next_state, done))
+
+ def sample(self, batch_size):
+ transitions = random.sample(self.buffer, batch_size)
+ state, action, reward, next_state, done = zip(*transitions)
+ return np.array(state), np.array(action), reward, np.array(next_state), done
+
+ def size(self):
+ return len(self.buffer)
\ No newline at end of file
diff --git a/utils/toolkit.py b/utils/toolkit.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9792e5180b445bfbf1ce1cf923f213d3d852b4b
--- /dev/null
+++ b/utils/toolkit.py
@@ -0,0 +1,116 @@
+import os
+import numpy as np
+import torch
+import json
+from enum import Enum
+
+class ConfigEncoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, type):
+ return {'$class': o.__module__ + "." + o.__name__}
+ elif isinstance(o, Enum):
+ return {
+ '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name
+ }
+ elif callable(o):
+ return {
+ '$function': o.__module__ + "." + o.__name__
+ }
+ return json.JSONEncoder.default(self, o)
+
+def count_parameters(model, trainable=False):
+ if trainable:
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return sum(p.numel() for p in model.parameters())
+
+
+def tensor2numpy(x):
+ return x.cpu().data.numpy() if x.is_cuda else x.data.numpy()
+
+
+def target2onehot(targets, n_classes):
+ onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
+ onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.0)
+ return onehot
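+
+ # Example: target2onehot(torch.tensor([0, 2]), 3) returns
+ # [[1., 0., 0.], [0., 0., 1.]].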
+
+
+def makedirs(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def accuracy(y_pred, y_true, nb_old, increment=10):
+ assert len(y_pred) == len(y_true), "Data length error."
+ all_acc = {}
+ all_acc["total"] = np.around(
+ (y_pred == y_true).sum() * 100 / len(y_true), decimals=2
+ )
+
+ # Grouped accuracy
+ for class_id in range(0, np.max(y_true), increment):
+ idxes = np.where(
+ np.logical_and(y_true >= class_id, y_true < class_id + increment)
+ )[0]
+ if increment == 1:
+ label = "{}".format(
+ str(class_id).rjust(2, "0")
+ )
+ else:
+ label = "{}-{}".format(
+ str(class_id).rjust(2, "0"), str(class_id + increment - 1).rjust(2, "0")
+ )
+ all_acc[label] = np.around(
+ (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
+ )
+
+ # Old accuracy
+ idxes = np.where(y_true < nb_old)[0]
+ all_acc["old"] = (
+ 0
+ if len(idxes) == 0
+ else np.around(
+ (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
+ )
+ )
+
+ # New accuracy
+ idxes = np.where(y_true >= nb_old)[0]
+ all_acc["new"] = np.around(
+ (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
+ )
+
+ return all_acc
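+
+ # Example (hypothetical labels): with increment=10 and labels in [0, 19],
+ # accuracy() reports "total", "00-09", "10-19", "old" and "new".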
+
+
+ def split_images_labels(imgs, start_index=0):
+ # split trainset.imgs of an ImageFolder into arrays of paths and labels
+ images = []
+ labels = []
+ for item in imgs:
+ images.append(item[0])
+ labels.append(item[1] + start_index)
+ return np.array(images), np.array(labels)
+
+def save_fc(args, model):
+ _path = os.path.join(args['logfilename'], "fc.pt")
+ if len(args['device']) > 1:
+ fc_weight = model._network.fc.weight.data
+ else:
+ fc_weight = model._network.fc.weight.data.cpu()
+ torch.save(fc_weight, _path)
+
+ _save_dir = os.path.join(f"./results/fc_weights/{args['prefix']}")
+ os.makedirs(_save_dir, exist_ok=True)
+ _save_path = os.path.join(_save_dir, f"{args['csv_name']}.csv")
+ with open(_save_path, "a+") as f:
+ f.write(f"{args['time_str']},{args['model_name']},{_path} \n")
+
+
+def save_model(args, model):
+ # used in PODNet
+ _path = os.path.join(args['logfilename'], "model.pt")
+ if len(args['device']) > 1:
+ weight = model._network
+ else:
+ weight = model._network.cpu()
+ torch.save(weight, _path)
\ No newline at end of file