|
import torch |
|
import data_setup, model_builder, engine, utils, plotting |
|
|
|
from torchvision import transforms |
|
import argparse |
|
|
|
|
|
def set_memory_limit():
    """Select the device to train on and report the choice.

    Despite its (historical) name, this function does not set any memory
    limit; it only chooses a device.

    CUDA is chosen only when ``torch.cuda.is_available()`` is True *and* a
    tiny tensor can actually be allocated on the GPU. In every other case
    (no CUDA, or the probe fails) the function falls back to the CPU.

    Returns:
        str: ``"cuda"`` or ``"cpu"`` — never ``None``.
    """
    if torch.cuda.is_available():
        try:
            # Allocating a tiny tensor forces CUDA initialisation, so a CUDA
            # runtime that reports itself available but cannot be used is
            # detected here rather than in the middle of training.
            torch.tensor([1], device="cuda")
        except Exception as err:  # best-effort probe: any failure means CPU
            print(f"CUDA reported as available but unusable ({err}).")
        else:
            print("Device is GPU/CUDA.")
            return "cuda"

    print("Device is CPU.")
    return "cpu"
|
|
|
|
|
# Location of the trained weights. The confusion-matrix step reloads the model
# from this same location, so both are derived from these two constants.
MODEL_DIR = "models"
MODEL_NAME = "Trash_Classification_Model_COLOURED.pth"


def _parse_args():
    """Parse the command-line options for a training run.

    Returns:
        argparse.Namespace: data directories plus optimiser / model
        hyper-parameters (learning rate, batch size, epochs, hidden units).
    """
    parser = argparse.ArgumentParser(description="Train a model for Classification of types of Trash.")
    parser.add_argument("--train_dir", type=str, default="data/train", help="Directory containing training images")
    parser.add_argument("--test_dir", type=str, default="data/test", help="Directory containing testing images")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for training")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
    parser.add_argument("--num_epochs", type=int, default=20, help="Number of epochs to train for")
    parser.add_argument("--hidden_units", type=int, default=15,
                        help="Number of hidden units passed to the CNN model")
    return parser.parse_args()


def main():
    """Train the trash-classification CNN, save it, then plot its results.

    Steps:
    1. Parse the CLI options.
    2. Build the train/test dataloaders.
    3. Pick a device.
    4. Train with Adam + cross-entropy.
    5. Save the weights under ``MODEL_DIR/MODEL_NAME``.
    6. Plot a confusion matrix from the saved file, plus the training metrics.
    """
    args = _parse_args()

    # Every image is resized to a fixed 112x112 and converted to a tensor.
    data_transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
    ])

    train_dataloader, test_dataloader, class_names = data_setup.train_test_dataloader(
        train_dir=args.train_dir,
        test_dir=args.test_dir,
        transform=data_transform,
        batch_size=args.batch_size,
    )

    device = set_memory_limit()

    # input_shape=3 presumably means 3 colour channels (RGB) — confirm
    # against model_builder.TrashClassificationCNNModel.
    model = model_builder.TrashClassificationCNNModel(
        input_shape=3,
        hidden_units=args.hidden_units,
        output_shape=len(class_names),
    ).to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    metrics = engine.train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        epochs=args.num_epochs,
        device=device,
    )

    utils.save_model(model=model, target_dir=MODEL_DIR, model_name=MODEL_NAME)

    # Release cached GPU memory before evaluation; a no-op when CUDA is unused.
    torch.cuda.empty_cache()

    # '/' is accepted as a separator on every OS, so this path works on both
    # Windows and POSIX. It assumes utils.save_model writes to
    # target_dir/model_name — confirm against utils.
    plotting.plot_confusion_Matrix(
        model_path=f"{MODEL_DIR}/{MODEL_NAME}",
        dataloader=test_dataloader,
        class_names=class_names,
        device=device,
    )

    plotting.plot_metrics(metrics)


# Guard the entry point. On spawn-based platforms (Windows, macOS), DataLoader
# worker processes re-import the main module. Without this guard, training
# would be re-executed in every worker if data_setup uses num_workers > 0.
if __name__ == "__main__":
    main()