diff --git a/README.md b/README.md index fc9dadebc0952b26058118c8c5428366fe8a83c2..9ada45d5bacde3e1d4b4afffe4d74d52ce6885fc 100644 --- a/README.md +++ b/README.md @@ -8,5 +8,3 @@ sdk_version: 3.36.1 app_file: app.py pinned: false --- - -Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfd0d236498928249ce7f9ff5642ad5c9843ec6 --- /dev/null +++ b/app.py @@ -0,0 +1,230 @@ +import os +import numpy as np +import gradio as gr +from glob import glob +from functools import partial +from dataclasses import dataclass + +import torch +import torchvision +import torch.nn as nn +import lightning.pytorch as pl +import torchvision.transforms as TF + +from torchmetrics import MeanMetric +from torchmetrics.classification import MultilabelF1Score + + +@dataclass +class DatasetConfig: + IMAGE_SIZE: tuple = (384, 384) # (W, H) + CHANNELS: int = 3 + NUM_CLASSES: int = 10 + MEAN: tuple = (0.485, 0.456, 0.406) + STD: tuple = (0.229, 0.224, 0.225) + + +@dataclass +class TrainingConfig: + METRIC_THRESH: float = 0.4 + MODEL_NAME: str = "efficientnet_v2_s" + FREEZE_BACKBONE: bool = False + + +def get_model(model_name: str, num_classes: int, freeze_backbone: bool = True): + """A helper function to load and prepare any classification model + available in Torchvision for transfer learning or fine-tuning.""" + + model = getattr(torchvision.models, model_name)(weights="DEFAULT") + + if freeze_backbone: + # Set all layer to be non-trainable + for param in model.parameters(): + param.requires_grad = False + + model_childrens = [name for name, _ in model.named_children()] + + try: + final_layer_in_features = getattr(model, f"{model_childrens[-1]}")[-1].in_features + except Exception as e: + final_layer_in_features = getattr(model, f"{model_childrens[-1]}").in_features + + new_output_layer = nn.Linear(in_features=final_layer_in_features, out_features=num_classes) + + try: + getattr(model, f"{model_childrens[-1]}")[-1] = new_output_layer + except: + setattr(model, model_childrens[-1], new_output_layer) + + return model + + +class ProteinModel(pl.LightningModule): + def __init__( + self, + model_name: str, + num_classes: int = 10, + freeze_backbone: bool = False, + init_lr: float = 0.001, + optimizer_name: str = "Adam", + weight_decay: float = 1e-4, + use_scheduler: bool = False, + f1_metric_threshold: float = 0.4, + ): + super().__init__() + + # Save the arguments as hyperparameters. + self.save_hyperparameters() + + # Loading model using the function defined above. + self.model = get_model( + model_name=self.hparams.model_name, + num_classes=self.hparams.num_classes, + freeze_backbone=self.hparams.freeze_backbone, + ) + + # Intialize loss class. + self.loss_fn = nn.BCEWithLogitsLoss() + + # Initializing the required metric objects. + self.mean_train_loss = MeanMetric() + self.mean_train_f1 = MultilabelF1Score(num_labels=self.hparams.num_classes, average="macro", threshold=self.hparams.f1_metric_threshold) + self.mean_valid_loss = MeanMetric() + self.mean_valid_f1 = MultilabelF1Score(num_labels=self.hparams.num_classes, average="macro", threshold=self.hparams.f1_metric_threshold) + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, *args, **kwargs): + data, target = batch + logits = self(data) + loss = self.loss_fn(logits, target) + + self.mean_train_loss(loss, weight=data.shape[0]) + self.mean_train_f1(logits, target) + + self.log("train/batch_loss", self.mean_train_loss, prog_bar=True) + self.log("train/batch_f1", self.mean_train_f1, prog_bar=True) + return loss + + def on_train_epoch_end(self): + # Computing and logging the training mean loss & mean f1. + self.log("train/loss", self.mean_train_loss, prog_bar=True) + self.log("train/f1", self.mean_train_f1, prog_bar=True) + self.log("step", self.current_epoch) + + def validation_step(self, batch, *args, **kwargs): + data, target = batch # Unpacking validation dataloader tuple + logits = self(data) + loss = self.loss_fn(logits, target) + + self.mean_valid_loss.update(loss, weight=data.shape[0]) + self.mean_valid_f1.update(logits, target) + + def on_validation_epoch_end(self): + # Computing and logging the validation mean loss & mean f1. + self.log("valid/loss", self.mean_valid_loss, prog_bar=True) + self.log("valid/f1", self.mean_valid_f1, prog_bar=True) + self.log("step", self.current_epoch) + + def configure_optimizers(self): + optimizer = getattr(torch.optim, self.hparams.optimizer_name)( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.hparams.init_lr, + weight_decay=self.hparams.weight_decay, + ) + + if self.hparams.use_scheduler: + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[ + self.trainer.max_epochs // 2, + ], + gamma=0.1, + ) + + # The lr_scheduler_config is a dictionary that contains the scheduler + # and its associated configuration. + lr_scheduler_config = { + "scheduler": lr_scheduler, + "interval": "epoch", + "name": "multi_step_lr", + } + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} + + else: + return optimizer + + +@torch.inference_mode() +def predict(input_image, threshold=0.4, model=None, preprocess_fn=None, device="cpu", idx2labels=None): + input_tensor = preprocess_fn(input_image) + input_tensor = input_tensor.unsqueeze(0).to(device) + + # Generate predictions + output = model(input_tensor).cpu() + + probabilities = torch.sigmoid(output)[0].numpy().tolist() + + output_probs = dict() + predicted_classes = [] + + for idx, prob in enumerate(probabilities): + output_probs[idx2labels[idx]] = prob + if prob >= threshold: + predicted_classes.append(idx2labels[idx]) + + predicted_classes = "\n".join(predicted_classes) + return predicted_classes, output_probs + + +if __name__ == "__main__": + labels = { + 0: "Mitochondria", + 1: "Nuclear bodies", + 2: "Nucleoli", + 3: "Golgi apparatus", + 4: "Nucleoplasm", + 5: "Nucleoli fibrillar center", + 6: "Cytosol", + 7: "Plasma membrane", + 8: "Centrosome", + 9: "Nuclear speckles", + } + + DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + CKPT_PATH = os.path.join(os.getcwd(), r"ckpt_024-vloss_0.1816_vf1_0.7855.ckpt") + model = ProteinModel.load_from_checkpoint(CKPT_PATH) + model.to(DEVICE) + model.eval() + _ = model(torch.randn(1, DatasetConfig.CHANNELS, *DatasetConfig.IMAGE_SIZE[::-1], device=DEVICE)) + + preprocess = TF.Compose( + [ + TF.Resize(size=DatasetConfig.IMAGE_SIZE[::-1]), + TF.ToTensor(), + TF.Normalize(DatasetConfig.MEAN, DatasetConfig.STD, inplace=True), + ] + ) + + images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png") + examples = [[i, TrainingConfig.METRIC_THRESH] for i in np.random.choice(images_dir, size=8, replace=False)] + print(examples) + + iface = gr.Interface( + fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE, idx2labels=labels), + inputs=[ + gr.Image(type="pil", label="Image"), + gr.Slider(0.0, 1.0, value=0.4, label="Threshold", info="Select the cut-off threshold for a node to be considered as a valid output."), + ], + outputs=[ + gr.Textbox(label="Labels Present"), + gr.Label(label="Probabilities", show_label=False), + ], + examples=examples, + cache_examples=False, + allow_flagging="never", + title="Medical Multi-Label Image Classification", + ) + + iface.launch() diff --git a/ckpt_024-vloss_0.1816_vf1_0.7855.ckpt b/ckpt_024-vloss_0.1816_vf1_0.7855.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..d01410cf9b112aa65c9b179e243369415bfc5a7e --- /dev/null +++ b/ckpt_024-vloss_0.1816_vf1_0.7855.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eeba4764adb310bf3a35ba2479326fdbf38acaed3242a9f020ff2d7eba47b2ca +size 243578302 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..dbb5b355c08932e9a26ff53dac66dfc23174ca40 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +torch==2.0.0+cpu +torchvision==0.15.0 +torchmetrics==1.0.0 +lightning==2.0.4 \ No newline at end of file diff --git a/samples/10.png b/samples/10.png new file mode 100644 index 0000000000000000000000000000000000000000..349b14fc07ac30114136b4d2fd26f0a36ba63f31 Binary files /dev/null and b/samples/10.png differ diff --git a/samples/10267.png b/samples/10267.png new file mode 100644 index 0000000000000000000000000000000000000000..dd78cc3b3d744e38248595716fe644b71d8297fd Binary files /dev/null and b/samples/10267.png differ diff --git a/samples/10423.png b/samples/10423.png new file mode 100644 index 0000000000000000000000000000000000000000..174b77193cf8e47c9f109ce50f9baa7e33ee63b9 Binary files /dev/null and b/samples/10423.png differ diff --git a/samples/116.png b/samples/116.png new file mode 100644 index 0000000000000000000000000000000000000000..a4e208c34d8cc1477e94ad558e565d8b6b766fa0 Binary files /dev/null and b/samples/116.png differ diff --git a/samples/11603.png b/samples/11603.png new file mode 100644 index 0000000000000000000000000000000000000000..d5410f29ae6481924f7d2e7e1d67c20846c0d69f Binary files /dev/null and b/samples/11603.png differ diff --git a/samples/13698.png b/samples/13698.png new file mode 100644 index 0000000000000000000000000000000000000000..ee0bc39dd4714541fdda1b4106917b65c54cf07b Binary files /dev/null and b/samples/13698.png differ diff --git a/samples/14311.png b/samples/14311.png new file mode 100644 index 0000000000000000000000000000000000000000..6cb3f8e3074a3a225c1fd4a1a2fe21970b75583c Binary files /dev/null and b/samples/14311.png differ diff --git a/samples/14546.png b/samples/14546.png new file mode 100644 index 0000000000000000000000000000000000000000..30a2fbf271ed4fb2c4aacc91341f56237b3b6289 Binary files /dev/null and b/samples/14546.png differ diff --git a/samples/15528.png b/samples/15528.png new file mode 100644 index 0000000000000000000000000000000000000000..19d6cf00f110f71badac5be576a1a20b7bf644ea Binary files /dev/null and b/samples/15528.png differ diff --git a/samples/15561.png b/samples/15561.png new file mode 100644 index 0000000000000000000000000000000000000000..a68029d7afd72f4573bf6dc058fd1457984440c4 Binary files /dev/null and b/samples/15561.png differ diff --git a/samples/16150.png b/samples/16150.png new file mode 100644 index 0000000000000000000000000000000000000000..689571359cbd3537c67e46267eb8b654b658b77c Binary files /dev/null and b/samples/16150.png differ diff --git a/samples/16312.png b/samples/16312.png new file mode 100644 index 0000000000000000000000000000000000000000..4f6bfce7fc2c0d54950e04e3c2a9d4416f57d455 Binary files /dev/null and b/samples/16312.png differ diff --git a/samples/16411.png b/samples/16411.png new file mode 100644 index 0000000000000000000000000000000000000000..be67f1928076a862cf3f1ff270158a7170e267eb Binary files /dev/null and b/samples/16411.png differ diff --git a/samples/16621.png b/samples/16621.png new file mode 100644 index 0000000000000000000000000000000000000000..fcfeccb4374815ae090fe2420235dfcd590fc0c0 Binary files /dev/null and b/samples/16621.png differ diff --git a/samples/17289.png b/samples/17289.png new file mode 100644 index 0000000000000000000000000000000000000000..f967c26c92e9af4aeb39ef4661bb3cce5491e315 Binary files /dev/null and b/samples/17289.png differ diff --git a/samples/19682.png b/samples/19682.png new file mode 100644 index 0000000000000000000000000000000000000000..c08a394b9fe0523f4262e9842a83a292db5739d3 Binary files /dev/null and b/samples/19682.png differ diff --git a/samples/19884.png b/samples/19884.png new file mode 100644 index 0000000000000000000000000000000000000000..a2ac0f7ed95a4ffcc5a5dcc9fed3996e3a1b64a2 Binary files /dev/null and b/samples/19884.png differ diff --git a/samples/203.png b/samples/203.png new file mode 100644 index 0000000000000000000000000000000000000000..c428099a4e7f99614d07645509eaab9c0b762538 Binary files /dev/null and b/samples/203.png differ diff --git a/samples/21602.png b/samples/21602.png new file mode 100644 index 0000000000000000000000000000000000000000..2e6e3cba7c3085721bd8552d4890ff997ff188eb Binary files /dev/null and b/samples/21602.png differ diff --git a/samples/21920.png b/samples/21920.png new file mode 100644 index 0000000000000000000000000000000000000000..8e52fd63342b5f62198e49ceb76f1290a1401902 Binary files /dev/null and b/samples/21920.png differ diff --git a/samples/22594.png b/samples/22594.png new file mode 100644 index 0000000000000000000000000000000000000000..c1aec781baec9810729bb655c60e2aaf64af9a5c Binary files /dev/null and b/samples/22594.png differ diff --git a/samples/23625.png b/samples/23625.png new file mode 100644 index 0000000000000000000000000000000000000000..bf3af0403fd57b7d239b247137797b4057faefb1 Binary files /dev/null and b/samples/23625.png differ diff --git a/samples/24.png b/samples/24.png new file mode 100644 index 0000000000000000000000000000000000000000..e977717bcb6b6cae90e6aa0c758f91daff586c6a Binary files /dev/null and b/samples/24.png differ diff --git a/samples/24136.png b/samples/24136.png new file mode 100644 index 0000000000000000000000000000000000000000..b78397b0f70b243ab46a50f7368fe4c6f5737c53 Binary files /dev/null and b/samples/24136.png differ diff --git a/samples/24715.png b/samples/24715.png new file mode 100644 index 0000000000000000000000000000000000000000..88338e548948170088387fee0aeb808d6ae99312 Binary files /dev/null and b/samples/24715.png differ diff --git a/samples/24817.png b/samples/24817.png new file mode 100644 index 0000000000000000000000000000000000000000..c91040aab5a6155eb05aede8c662fbd9f459fa8d Binary files /dev/null and b/samples/24817.png differ diff --git a/samples/25140.png b/samples/25140.png new file mode 100644 index 0000000000000000000000000000000000000000..940146291b169f2457e8c74303c864e7065cadb2 Binary files /dev/null and b/samples/25140.png differ diff --git a/samples/2563.png b/samples/2563.png new file mode 100644 index 0000000000000000000000000000000000000000..159f133325fd13f554529904c6054f7194772e30 Binary files /dev/null and b/samples/2563.png differ diff --git a/samples/25826.png b/samples/25826.png new file mode 100644 index 0000000000000000000000000000000000000000..3746cf9d5802ffc1dd900dba0bd659886891d77d Binary files /dev/null and b/samples/25826.png differ diff --git a/samples/26591.png b/samples/26591.png new file mode 100644 index 0000000000000000000000000000000000000000..ffe80c61cbe21f9ba2188c7cc22cb4042705438d Binary files /dev/null and b/samples/26591.png differ diff --git a/samples/2694.png b/samples/2694.png new file mode 100644 index 0000000000000000000000000000000000000000..906b319daace30518d2288367aa8bd5f8d5c238d Binary files /dev/null and b/samples/2694.png differ diff --git a/samples/27926.png b/samples/27926.png new file mode 100644 index 0000000000000000000000000000000000000000..7e0c4b90f1ff23d39246ba17cd075b4aef15d5bc Binary files /dev/null and b/samples/27926.png differ diff --git a/samples/28.png b/samples/28.png new file mode 100644 index 0000000000000000000000000000000000000000..ff1a08fe78a7bbf3d24581773055c5ff9c416a81 Binary files /dev/null and b/samples/28.png differ diff --git a/samples/28661.png b/samples/28661.png new file mode 100644 index 0000000000000000000000000000000000000000..82469e3d7b0921fe79504a0d37908a2f64a0d189 Binary files /dev/null and b/samples/28661.png differ diff --git a/samples/28983.png b/samples/28983.png new file mode 100644 index 0000000000000000000000000000000000000000..3c90ac077c9b8e24e463b2baa639cb132c682314 Binary files /dev/null and b/samples/28983.png differ diff --git a/samples/30258.png b/samples/30258.png new file mode 100644 index 0000000000000000000000000000000000000000..c279ec390263ee46cdf1123940f2481d2206e0f4 Binary files /dev/null and b/samples/30258.png differ diff --git a/samples/30809.png b/samples/30809.png new file mode 100644 index 0000000000000000000000000000000000000000..d7640e96a3276eca01623f5f17105804d7c33dc1 Binary files /dev/null and b/samples/30809.png differ diff --git a/samples/3282.png b/samples/3282.png new file mode 100644 index 0000000000000000000000000000000000000000..8bac55f220d01453e568d4a3c0390fcdb1ebc97a Binary files /dev/null and b/samples/3282.png differ diff --git a/samples/3665.png b/samples/3665.png new file mode 100644 index 0000000000000000000000000000000000000000..145dcdc47b88b18cd9b8c743efe4a66449913fda Binary files /dev/null and b/samples/3665.png differ diff --git a/samples/381.png b/samples/381.png new file mode 100644 index 0000000000000000000000000000000000000000..e81ec9209ed62ea248aaa085aab5119b1f3eba43 Binary files /dev/null and b/samples/381.png differ diff --git a/samples/4595.png b/samples/4595.png new file mode 100644 index 0000000000000000000000000000000000000000..d7b51942e0a4656c383bb6db61eb08fbc2e43a18 Binary files /dev/null and b/samples/4595.png differ diff --git a/samples/483.png b/samples/483.png new file mode 100644 index 0000000000000000000000000000000000000000..c1242b41d82ed6e546a34e0ffebf133068b23cf1 Binary files /dev/null and b/samples/483.png differ diff --git a/samples/4928.png b/samples/4928.png new file mode 100644 index 0000000000000000000000000000000000000000..8286447e130937041dbc28b835d5eae9b23260b1 Binary files /dev/null and b/samples/4928.png differ diff --git a/samples/497.png b/samples/497.png new file mode 100644 index 0000000000000000000000000000000000000000..2ce787a7e88a7417fad3e4de10654a04c311331a Binary files /dev/null and b/samples/497.png differ diff --git a/samples/5378.png b/samples/5378.png new file mode 100644 index 0000000000000000000000000000000000000000..c8e8432ff0976cc55034b0fdf39b85c079091331 Binary files /dev/null and b/samples/5378.png differ diff --git a/samples/600.png b/samples/600.png new file mode 100644 index 0000000000000000000000000000000000000000..6a00b1cd689706a2786f5a27ddb8e23d6c98bebf Binary files /dev/null and b/samples/600.png differ diff --git a/samples/602.png b/samples/602.png new file mode 100644 index 0000000000000000000000000000000000000000..29c4b8c174024ba7d59987e5b061de039ecb1127 Binary files /dev/null and b/samples/602.png differ diff --git a/samples/604.png b/samples/604.png new file mode 100644 index 0000000000000000000000000000000000000000..c55448f76245473199141db56b986e839982c1e3 Binary files /dev/null and b/samples/604.png differ diff --git a/samples/61.png b/samples/61.png new file mode 100644 index 0000000000000000000000000000000000000000..81ad06fe1d6e9532b0b1df813b5b8e2553299d50 Binary files /dev/null and b/samples/61.png differ diff --git a/samples/6937.png b/samples/6937.png new file mode 100644 index 0000000000000000000000000000000000000000..a0e758da4b017702269cf51140a69c9c7f02172c Binary files /dev/null and b/samples/6937.png differ diff --git a/samples/9440.png b/samples/9440.png new file mode 100644 index 0000000000000000000000000000000000000000..bcce53c727ac8e0c28a43c2f16c6d68d7f296067 Binary files /dev/null and b/samples/9440.png differ