ga89tiy committed
Commit 0a8703d
1 Parent(s): 420f27f

module device

LLAVA_Biovil/biovil_t/encoder.py CHANGED
@@ -81,6 +81,15 @@ def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Se
         new_encoder.load_state_dict(self.encoder.state_dict())
         self.encoder = new_encoder
 
+def get_module_device(module: torch.nn.Module) -> torch.device:
+    """
+    Returns the device of the module
+    """
+    device = next(module.parameters()).device  # type: ignore
+    assert isinstance(device, torch.device)
+
+    return device
+
 
 class MultiImageEncoder(ImageEncoder):
     """Multi-image encoder trunk module for the ``ImageModel`` class.
@@ -96,7 +105,7 @@ def __init__(self, img_encoder_type: str):
         output_dim = 256  # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
         grid_shape = (14, 14)  # Spatial dimensions of patch grid.
 
-        backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=torch.device("cuda"))
+        backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))
 
         self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
                                          kernel_size=1, stride=1, padding=0, bias=False)
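
The helper added here derives the device from the module's own parameters instead of hard-coding torch.device("cuda"), so the encoder can also be constructed on CPU-only machines. A minimal sketch of the behaviour, using a stand-in nn.Linear rather than the BioViL encoder:

    import torch
    from torch import nn

    def get_module_device(module: torch.nn.Module) -> torch.device:
        """Same logic as the helper added in this commit."""
        device = next(module.parameters()).device
        assert isinstance(device, torch.device)
        return device

    encoder = nn.Linear(16, 8)                    # stand-in module; parameters start on CPU
    print(get_module_device(encoder))             # device(type='cpu')
    if torch.cuda.is_available():
        print(get_module_device(encoder.cuda()))  # device(type='cuda', index=0)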
LLAVA_Biovil/biovil_t/model.py CHANGED
@@ -28,6 +28,14 @@ def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
     def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
         raise NotImplementedError
 
+def get_module_device(module: torch.nn.Module) -> torch.device:
+    """
+    Returns the device of the module
+    """
+    device = next(module.parameters()).device  # type: ignore
+    assert isinstance(device, torch.device)
+
+    return device
 
 class ImageModel(BaseImageModel):
     """Image encoder module"""
@@ -42,7 +50,7 @@ def __init__(self,
 
         # Initiate encoder, projector, and classifier
         self.encoder = get_encoder_from_type(img_encoder_type)
-        self.feature_size = get_encoder_output_dim(self.encoder, device=torch.device("cuda"))
+        self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
         self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
                              hidden_dim=joint_feature_size, use_1x1_convs=True)
        self.downstream_classifier_kwargs = downstream_classifier_kwargs
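
model.py gains the same helper, here used to size the projector from the encoder's own device. One caveat of the next(module.parameters()) pattern: it raises StopIteration on modules without parameters (e.g. nn.ReLU). A defensive variant, shown only as a hypothetical sketch and not part of this commit, could fall back to a default device:

    import torch

    def get_module_device_safe(module: torch.nn.Module,
                               default: torch.device = torch.device("cpu")) -> torch.device:
        # next() with a default avoids StopIteration for parameter-less modules
        first_param = next(module.parameters(), None)
        return first_param.device if first_param is not None else default

    print(get_module_device_safe(torch.nn.ReLU()))        # falls back to cpu
    print(get_module_device_safe(torch.nn.Linear(4, 4)))  # device of the weights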
findings_classifier/chexpert_train.py CHANGED
@@ -6,18 +6,12 @@ from collections import defaultdict
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import wandb
-from pytorch_lightning.callbacks import ModelCheckpoint
 from sklearn.metrics import accuracy_score, classification_report, jaccard_score, roc_auc_score
 from torch.nn import BCEWithLogitsLoss
-from torch.utils.data import DataLoader
-from torchinfo import summary
-from tqdm import tqdm
+
 from transformers import AdamW
 
-from findings_classifier.chexpert_dataset import Chexpert_Dataset
 from findings_classifier.chexpert_model import ChexpertClassifier
-from local_config import WANDB_ENTITY
 
 class ExpandChannels:
     """
@@ -54,203 +48,6 @@ class LitIGClassifier(pl.LightningModule):
     def forward(self, x):
         return self.model(x)
 
-    def step(self, batch, batch_idx):
-        x, y = batch['image'].to(self.device), batch['labels'].to(self.device)
-        logits = self(x)
-        loss = self.criterion(logits, y)
-
-        # Apply sigmoid to get probabilities
-        preds_probs = torch.sigmoid(logits)
-
-        # Get predictions as boolean values
-        preds = preds_probs > 0.5
-
-        # calculate jaccard index
-        jaccard = jaccard_score(y.cpu().numpy(), preds.detach().cpu().numpy(), average='samples')
-
-        class_report = classification_report(y.cpu().numpy(), preds.detach().cpu().numpy(), output_dict=True)
-        # scores = class_report['micro avg']
-        scores = class_report['macro avg']
-        metrics_per_label = {label: metrics for label, metrics in class_report.items() if label.isdigit()}
-
-        f1 = scores['f1-score']
-        rec = scores['recall']
-        prec = scores['precision']
-        acc = accuracy_score(y.cpu().numpy().flatten(), preds.detach().cpu().numpy().flatten())
-        try:
-            auc = roc_auc_score(y.cpu().numpy().flatten(), preds_probs.detach().cpu().numpy().flatten())
-        except Exception as e:
-            auc = 0.
-
-        return loss, acc, f1, rec, prec, jaccard, auc, metrics_per_label
-
-    def training_step(self, batch, batch_idx):
-        loss, acc, f1, rec, prec, jaccard, auc, _ = self.step(batch, batch_idx)
-        train_stats = {'loss': loss, 'train_acc': acc, 'train_f1': f1, 'train_rec': rec, 'train_prec': prec, 'train_jaccard': jaccard,
-                       'train_auc': auc}
-        wandb_run.log(train_stats)
-        return train_stats
-
-    def training_epoch_end(self, outputs):
-        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
-        avg_acc = np.mean([x['train_acc'] for x in outputs])
-        avg_f1 = np.mean([x['train_f1'] for x in outputs])
-        avg_rec = np.mean([x['train_rec'] for x in outputs])
-        avg_prec = np.mean([x['train_prec'] for x in outputs])
-        avg_jaccard = np.mean([x['train_jaccard'] for x in outputs])
-        avg_auc = np.mean([x['train_auc'] for x in outputs])
-        wandb_run.log({'epoch_train_loss': avg_loss, 'epoch_train_acc': avg_acc, 'epoch_train_f1': avg_f1, 'epoch_train_rec': avg_rec,
-                       'epoch_train_prec': avg_prec, 'epoch_train_jaccard': avg_jaccard, 'epoch_train_auc': avg_auc})
-
-    def validation_step(self, batch, batch_idx):
-        loss, acc, f1, rec, prec, jaccard, auc, metrics_per_label = self.step(batch, batch_idx)
-        # log f1 for checkpoint callback
-        self.log('val_f1', f1)
-        return {'val_loss': loss, 'val_acc': acc, 'val_f1': f1, 'val_rec': rec, 'val_prec': prec, 'val_jaccard': jaccard,
-                'val_auc': auc}, metrics_per_label
-
-    def validation_epoch_end(self, outputs):
-        outputs, per_label_metrics_outputs = zip(*outputs)
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-        avg_acc = np.mean([x['val_acc'] for x in outputs])
-        avg_f1 = np.mean([x['val_f1'] for x in outputs])
-        avg_rec = np.mean([x['val_rec'] for x in outputs])
-        avg_prec = np.mean([x['val_prec'] for x in outputs])
-        avg_jaccard = np.mean([x['val_jaccard'] for x in outputs])
-        avg_auc = np.mean([x['val_auc'] for x in outputs])
-
-        per_label_metrics = defaultdict(lambda: defaultdict(float))
-        label_counts = defaultdict(int)
-        for metrics_per_label in per_label_metrics_outputs:
-            for label, metrics in metrics_per_label.items():
-                label_name = self.class_names[int(label)]
-                per_label_metrics[label_name]['precision'] += metrics['precision']
-                per_label_metrics[label_name]['recall'] += metrics['recall']
-                per_label_metrics[label_name]['f1-score'] += metrics['f1-score']
-                per_label_metrics[label_name]['support'] += metrics['support']
-                label_counts[label_name] += 1
-
-        # Average the metrics
-        for label, metrics in per_label_metrics.items():
-            for metric_name in ['precision', 'recall', 'f1-score']:
-                if metrics['support'] > 0:
-                    per_label_metrics[label][metric_name] /= label_counts[label]
-
-        val_stats = {'val_loss': avg_loss, 'val_acc': avg_acc, 'val_f1': avg_f1, 'val_rec': avg_rec, 'val_prec': avg_prec, 'val_jaccard': avg_jaccard,
-                     'val_auc': avg_auc}
-        wandb_run.log(val_stats)
-
-    def test_step(self, batch, batch_idx):
-        loss, acc, f1, rec, prec, jaccard, auc, _ = self.step(batch, batch_idx)
-        return {'test_loss': loss, 'test_acc': acc, 'test_f1': f1, 'test_rec': rec, 'test_prec': prec, 'test_jaccard': jaccard, 'test_auc': auc}
-
-    def test_epoch_end(self, outputs):
-        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
-        avg_acc = np.mean([x['test_acc'] for x in outputs])
-        avg_f1 = np.mean([x['test_f1'] for x in outputs])
-        avg_rec = np.mean([x['test_rec'] for x in outputs])
-        avg_prec = np.mean([x['test_prec'] for x in outputs])
-        avg_jaccard = np.mean([x['test_jaccard'] for x in outputs])
-        avg_auc = np.mean([x['test_auc'] for x in outputs])
-
-        test_stats = {'test_loss': avg_loss, 'test_acc': avg_acc, 'test_f1': avg_f1, 'test_rec': avg_rec, 'test_prec': avg_prec,
-                      'test_jaccard': avg_jaccard, 'test_auc': avg_auc}
-        wandb_run.log(test_stats)
-
     def configure_optimizers(self):
         optimizer = AdamW(self.parameters(), lr=self.learning_rate)
         return optimizer
-
-
-def save_preds(dataloader, split):
-    # load checkpoint
-    ckpt_path = f"findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"
-    model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=num_classes, class_weights=val_dataset.get_class_weights(),
-                                                 class_names=class_names, learning_rate=args.lr)
-    model.eval()
-    model.cuda()
-    model.half()
-    class_names_np = np.asarray(class_names)
-
-    # get predictions for all study ids
-    structured_preds = {}
-    for batch in tqdm(dataloader):
-        dicom_ids = batch['dicom_id']
-        logits = model(batch['image'].half().cuda())
-        preds_probs = torch.sigmoid(logits)
-        preds = preds_probs > 0.5
-
-        # iterate over each study id in the batch
-        for i, (dicom_id, pred) in enumerate(zip(dicom_ids, preds.detach().cpu())):
-            # get all positive labels
-            findings = class_names_np[pred].tolist()
-            structured_preds[dicom_id] = findings
-
-    # save predictions
-    with open(f"findings_classifier/predictions/structured_preds_chexpert_log_weighting_macro_{split}.json", "w") as f:
-        json.dump(structured_preds, f, indent=4)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--run_name", type=str, default="debug")
-    parser.add_argument("--lr", type=float, default=5e-5)
-    parser.add_argument("--epochs", type=int, default=6)
-    parser.add_argument("--loss_weighting", type=str, default="log", choices=["lin", "log", "none"])
-    parser.add_argument("--truncate", type=int, default=None)
-    parser.add_argument("--batch_size", type=int, default=64)
-    parser.add_argument("--num_workers", type=int, default=12)
-    parser.add_argument("--use_augs", action="store_true", default=False)
-    parser.add_argument("--train", action="store_true", default=False)
-    args = parser.parse_args()
-
-    TRAIN = args.train
-
-    # fix all seeds
-    pl.seed_everything(42, workers=True)
-
-    # Create DataLoaders
-    train_dataset = Chexpert_Dataset(split='train', truncate=args.truncate, loss_weighting=args.loss_weighting, use_augs=args.use_augs)
-    val_dataset = Chexpert_Dataset(split='validate', truncate=args.truncate)
-    test_dataset = Chexpert_Dataset(split='test')
-
-    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
-    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
-    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
-
-    # Number of classes for IGClassifier
-    num_classes = len(train_dataset.chexpert_cols)
-    class_names = train_dataset.chexpert_cols
-
-    if TRAIN:
-        class_weights = torch.tensor(train_dataset.get_class_weights(), dtype=torch.float32)
-        # Define the model
-        lit_model = LitIGClassifier(num_classes, class_weights, class_names, learning_rate=args.lr)
-        print(summary(lit_model))
-
-        # WandB logger
-        wandb_run = wandb.init(
-            project="ChexpertClassifier",
-            entity=WANDB_ENTITY,
-            name=args.run_name
-        )
-
-        # checkpoint callback
-        checkpoint_callback = ModelCheckpoint(
-            monitor='val_f1',
-            dirpath=f'findings_classifier/checkpoints/{args.run_name}',
-            filename='ChexpertClassifier-{epoch:02d}-{val_f1:.2f}',
-            save_top_k=1,
-            save_last=True,
-            mode='max',
-        )
-        # Train the model
-        trainer = pl.Trainer(max_epochs=args.epochs, gpus=1, callbacks=[checkpoint_callback], benchmark=False, deterministic=True, precision=16)
-        trainer.fit(lit_model, train_dataloader, val_dataloader)
-
-        # Test the model
-        # trainer.validate(lit_model, val_dataloader, ckpt_path="checkpoints_IGCLassifier/lr_5e-5_to0_log_weighting_patches_augs_imgemb/IGClassifier-epoch=09-val_f1=0.65.ckpt")
-    else:
-        save_preds(train_dataloader, "train")
-        save_preds(val_dataloader, "val")
-        save_preds(test_dataloader, "test")
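With the wandb setup, training loop, and save_preds driver removed, chexpert_train.py now exposes only the LitIGClassifier module, and callers load checkpoints themselves. A minimal inference sketch under assumed values: the checkpoint path, label set, and input shape are placeholders, while load_from_checkpoint and the constructor arguments (num_classes, class_weights, class_names, learning_rate) mirror the code deleted above:

    import torch
    from findings_classifier.chexpert_train import LitIGClassifier

    class_names = ["Atelectasis", "Cardiomegaly"]  # placeholder label set
    model = LitIGClassifier.load_from_checkpoint(
        "findings_classifier/checkpoints/chexpert_train/last.ckpt",  # hypothetical path
        num_classes=len(class_names),
        class_weights=torch.ones(len(class_names)),
        class_names=class_names,
        learning_rate=5e-5,
    )
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # dummy image batch
        preds = torch.sigmoid(logits) > 0.5          # same 0.5 threshold as the removed code
    findings = [name for name, hit in zip(class_names, preds[0]) if hit]
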
utils.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
 import torch
 from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, transforms
 
-from huggingface.findings_classifier.chexpert_train import LitIGClassifier
+from findings_classifier.chexpert_train import LitIGClassifier
 
 
 class ExpandChannels: