veb-101 commited on
Commit
2f747ba
1 Parent(s): 9fa4623

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
README.md CHANGED
@@ -8,5 +8,3 @@ sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
 
app.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import gradio as gr
4
+ from glob import glob
5
+ from functools import partial
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ import torchvision
10
+ import torch.nn as nn
11
+ import lightning.pytorch as pl
12
+ import torchvision.transforms as TF
13
+
14
+ from torchmetrics import MeanMetric
15
+ from torchmetrics.classification import MultilabelF1Score
16
+
17
+
18
+ @dataclass
19
+ class DatasetConfig:
20
+ IMAGE_SIZE: tuple = (384, 384) # (W, H)
21
+ CHANNELS: int = 3
22
+ NUM_CLASSES: int = 10
23
+ MEAN: tuple = (0.485, 0.456, 0.406)
24
+ STD: tuple = (0.229, 0.224, 0.225)
25
+
26
+
27
+ @dataclass
28
+ class TrainingConfig:
29
+ METRIC_THRESH: float = 0.4
30
+ MODEL_NAME: str = "efficientnet_v2_s"
31
+ FREEZE_BACKBONE: bool = False
32
+
33
+
34
+ def get_model(model_name: str, num_classes: int, freeze_backbone: bool = True):
35
+ """A helper function to load and prepare any classification model
36
+ available in Torchvision for transfer learning or fine-tuning."""
37
+
38
+ model = getattr(torchvision.models, model_name)(weights="DEFAULT")
39
+
40
+ if freeze_backbone:
41
+ # Set all layer to be non-trainable
42
+ for param in model.parameters():
43
+ param.requires_grad = False
44
+
45
+ model_childrens = [name for name, _ in model.named_children()]
46
+
47
+ try:
48
+ final_layer_in_features = getattr(model, f"{model_childrens[-1]}")[-1].in_features
49
+ except Exception as e:
50
+ final_layer_in_features = getattr(model, f"{model_childrens[-1]}").in_features
51
+
52
+ new_output_layer = nn.Linear(in_features=final_layer_in_features, out_features=num_classes)
53
+
54
+ try:
55
+ getattr(model, f"{model_childrens[-1]}")[-1] = new_output_layer
56
+ except:
57
+ setattr(model, model_childrens[-1], new_output_layer)
58
+
59
+ return model
60
+
61
+
62
+ class ProteinModel(pl.LightningModule):
63
+ def __init__(
64
+ self,
65
+ model_name: str,
66
+ num_classes: int = 10,
67
+ freeze_backbone: bool = False,
68
+ init_lr: float = 0.001,
69
+ optimizer_name: str = "Adam",
70
+ weight_decay: float = 1e-4,
71
+ use_scheduler: bool = False,
72
+ f1_metric_threshold: float = 0.4,
73
+ ):
74
+ super().__init__()
75
+
76
+ # Save the arguments as hyperparameters.
77
+ self.save_hyperparameters()
78
+
79
+ # Loading model using the function defined above.
80
+ self.model = get_model(
81
+ model_name=self.hparams.model_name,
82
+ num_classes=self.hparams.num_classes,
83
+ freeze_backbone=self.hparams.freeze_backbone,
84
+ )
85
+
86
+ # Intialize loss class.
87
+ self.loss_fn = nn.BCEWithLogitsLoss()
88
+
89
+ # Initializing the required metric objects.
90
+ self.mean_train_loss = MeanMetric()
91
+ self.mean_train_f1 = MultilabelF1Score(num_labels=self.hparams.num_classes, average="macro", threshold=self.hparams.f1_metric_threshold)
92
+ self.mean_valid_loss = MeanMetric()
93
+ self.mean_valid_f1 = MultilabelF1Score(num_labels=self.hparams.num_classes, average="macro", threshold=self.hparams.f1_metric_threshold)
94
+
95
+ def forward(self, x):
96
+ return self.model(x)
97
+
98
+ def training_step(self, batch, *args, **kwargs):
99
+ data, target = batch
100
+ logits = self(data)
101
+ loss = self.loss_fn(logits, target)
102
+
103
+ self.mean_train_loss(loss, weight=data.shape[0])
104
+ self.mean_train_f1(logits, target)
105
+
106
+ self.log("train/batch_loss", self.mean_train_loss, prog_bar=True)
107
+ self.log("train/batch_f1", self.mean_train_f1, prog_bar=True)
108
+ return loss
109
+
110
+ def on_train_epoch_end(self):
111
+ # Computing and logging the training mean loss & mean f1.
112
+ self.log("train/loss", self.mean_train_loss, prog_bar=True)
113
+ self.log("train/f1", self.mean_train_f1, prog_bar=True)
114
+ self.log("step", self.current_epoch)
115
+
116
+ def validation_step(self, batch, *args, **kwargs):
117
+ data, target = batch # Unpacking validation dataloader tuple
118
+ logits = self(data)
119
+ loss = self.loss_fn(logits, target)
120
+
121
+ self.mean_valid_loss.update(loss, weight=data.shape[0])
122
+ self.mean_valid_f1.update(logits, target)
123
+
124
+ def on_validation_epoch_end(self):
125
+ # Computing and logging the validation mean loss & mean f1.
126
+ self.log("valid/loss", self.mean_valid_loss, prog_bar=True)
127
+ self.log("valid/f1", self.mean_valid_f1, prog_bar=True)
128
+ self.log("step", self.current_epoch)
129
+
130
+ def configure_optimizers(self):
131
+ optimizer = getattr(torch.optim, self.hparams.optimizer_name)(
132
+ filter(lambda p: p.requires_grad, self.model.parameters()),
133
+ lr=self.hparams.init_lr,
134
+ weight_decay=self.hparams.weight_decay,
135
+ )
136
+
137
+ if self.hparams.use_scheduler:
138
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
139
+ optimizer,
140
+ milestones=[
141
+ self.trainer.max_epochs // 2,
142
+ ],
143
+ gamma=0.1,
144
+ )
145
+
146
+ # The lr_scheduler_config is a dictionary that contains the scheduler
147
+ # and its associated configuration.
148
+ lr_scheduler_config = {
149
+ "scheduler": lr_scheduler,
150
+ "interval": "epoch",
151
+ "name": "multi_step_lr",
152
+ }
153
+ return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
154
+
155
+ else:
156
+ return optimizer
157
+
158
+
159
+ @torch.inference_mode()
160
+ def predict(input_image, threshold=0.4, model=None, preprocess_fn=None, device="cpu", idx2labels=None):
161
+ input_tensor = preprocess_fn(input_image)
162
+ input_tensor = input_tensor.unsqueeze(0).to(device)
163
+
164
+ # Generate predictions
165
+ output = model(input_tensor).cpu()
166
+
167
+ probabilities = torch.sigmoid(output)[0].numpy().tolist()
168
+
169
+ output_probs = dict()
170
+ predicted_classes = []
171
+
172
+ for idx, prob in enumerate(probabilities):
173
+ output_probs[idx2labels[idx]] = prob
174
+ if prob >= threshold:
175
+ predicted_classes.append(idx2labels[idx])
176
+
177
+ predicted_classes = "\n".join(predicted_classes)
178
+ return predicted_classes, output_probs
179
+
180
+
181
+ if __name__ == "__main__":
182
+ labels = {
183
+ 0: "Mitochondria",
184
+ 1: "Nuclear bodies",
185
+ 2: "Nucleoli",
186
+ 3: "Golgi apparatus",
187
+ 4: "Nucleoplasm",
188
+ 5: "Nucleoli fibrillar center",
189
+ 6: "Cytosol",
190
+ 7: "Plasma membrane",
191
+ 8: "Centrosome",
192
+ 9: "Nuclear speckles",
193
+ }
194
+
195
+ DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
196
+ CKPT_PATH = os.path.join(os.getcwd(), r"ckpt_024-vloss_0.1816_vf1_0.7855.ckpt")
197
+ model = ProteinModel.load_from_checkpoint(CKPT_PATH)
198
+ model.to(DEVICE)
199
+ model.eval()
200
+ _ = model(torch.randn(1, DatasetConfig.CHANNELS, *DatasetConfig.IMAGE_SIZE[::-1], device=DEVICE))
201
+
202
+ preprocess = TF.Compose(
203
+ [
204
+ TF.Resize(size=DatasetConfig.IMAGE_SIZE[::-1]),
205
+ TF.ToTensor(),
206
+ TF.Normalize(DatasetConfig.MEAN, DatasetConfig.STD, inplace=True),
207
+ ]
208
+ )
209
+
210
+ images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
211
+ examples = [[i, TrainingConfig.METRIC_THRESH] for i in np.random.choice(images_dir, size=8, replace=False)]
212
+ print(examples)
213
+
214
+ iface = gr.Interface(
215
+ fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE, idx2labels=labels),
216
+ inputs=[
217
+ gr.Image(type="pil", label="Image"),
218
+ gr.Slider(0.0, 1.0, value=0.4, label="Threshold", info="Select the cut-off threshold for a node to be considered as a valid output."),
219
+ ],
220
+ outputs=[
221
+ gr.Textbox(label="Labels Present"),
222
+ gr.Label(label="Probabilities", show_label=False),
223
+ ],
224
+ examples=examples,
225
+ cache_examples=False,
226
+ allow_flagging="never",
227
+ title="Medical Multi-Label Image Classification",
228
+ )
229
+
230
+ iface.launch()
ckpt_024-vloss_0.1816_vf1_0.7855.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeba4764adb310bf3a35ba2479326fdbf38acaed3242a9f020ff2d7eba47b2ca
3
+ size 243578302
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ torch==2.0.0+cpu
3
+ torchvision==0.15.0
4
+ torchmetrics==1.0.0
5
+ lightning==2.0.4
samples/10.png ADDED
samples/10267.png ADDED
samples/10423.png ADDED
samples/116.png ADDED
samples/11603.png ADDED
samples/13698.png ADDED
samples/14311.png ADDED
samples/14546.png ADDED
samples/15528.png ADDED
samples/15561.png ADDED
samples/16150.png ADDED
samples/16312.png ADDED
samples/16411.png ADDED
samples/16621.png ADDED
samples/17289.png ADDED
samples/19682.png ADDED
samples/19884.png ADDED
samples/203.png ADDED
samples/21602.png ADDED
samples/21920.png ADDED
samples/22594.png ADDED
samples/23625.png ADDED
samples/24.png ADDED
samples/24136.png ADDED
samples/24715.png ADDED
samples/24817.png ADDED
samples/25140.png ADDED
samples/2563.png ADDED
samples/25826.png ADDED
samples/26591.png ADDED
samples/2694.png ADDED
samples/27926.png ADDED
samples/28.png ADDED
samples/28661.png ADDED
samples/28983.png ADDED
samples/30258.png ADDED
samples/30809.png ADDED
samples/3282.png ADDED
samples/3665.png ADDED
samples/381.png ADDED
samples/4595.png ADDED
samples/483.png ADDED
samples/4928.png ADDED
samples/497.png ADDED
samples/5378.png ADDED
samples/600.png ADDED