"""Entry-point script: teacher training followed by student training.

When executed directly, the script:

1. builds a ``TeacherTrainer`` for ``DatasetName.w300`` and trains it with the
   'efficientNet' architecture;
2. builds a ``StudentTrainer`` for each of the 300W, COFW and WFLW datasets (in
   that order) and trains a 'mobileNetV2' student, handing every run the paths
   of two 'efficientNet' teacher weight files (the "tough" and "tol" teachers,
   as the keyword arguments name them).

Nothing runs when the module is merely imported.
"""

from configuration import DatasetName, D300wConf, InputDataSize, CofwConf, WflwConf
from cnn_model import CNNModel
from pca_utility import PCAUtility
from image_utility import ImageUtility
from student_train import StudentTrainer
from test import Test
from teacher_trainer import TeacherTrainer


if __name__ == '__main__':
    # ---- test models ----
    # No test step is invoked by this script at present.

    # ---- train Teacher Networks ----
    trainer = TeacherTrainer(dataset_name=DatasetName.w300)
    trainer.train(arch='efficientNet', weight_path=None)

    # ---- Training Student Network ----
    # One row per student run: (dataset, tough-teacher weights, tol-teacher weights).
    # The rows are processed top to bottom, so the run order is 300W, COFW, WFLW.
    student_runs = (
        # 300W
        (
            DatasetName.w300,
            './models/teachers/ds_300w_ef_tou.h5',
            './models/teachers/ds_300w_ef_tol.h5',
        ),
        # COFW
        (
            DatasetName.cofw,
            './models/teachers/ds_cofw_ef_tou.h5',
            './models/teachers/ds_cofw_ef_tol.h5',
        ),
        # WFLW
        (
            DatasetName.wflw,
            './models/teachers/ds_wflw_ef_tou.h5',
            './models/teachers/ds_wflw_ef_tol.h5',
        ),
    )

    for dataset, tough_weights, tol_weights in student_runs:
        # NOTE(review): 'use_augmneted' is spelled exactly as in the original
        # call; presumably it matches StudentTrainer's parameter name, so it
        # must not be "corrected" here without changing the callee as well.
        st_trainer = StudentTrainer(dataset_name=dataset, use_augmneted=True)
        st_trainer.train(
            arch_student='mobileNetV2',
            weight_path_student=None,
            loss_weight_student=2.0,
            arch_tough_teacher='efficientNet',
            weight_path_tough_teacher=tough_weights,
            loss_weight_tough_teacher=1,
            arch_tol_teacher='efficientNet',
            weight_path_tol_teacher=tol_weights,
            loss_weight_tol_teacher=1,
        )