ga89tiy commited on
Commit
a697138
1 Parent(s): 56f4e99
LLAVA_Biovil/biovil_t/encoder.py CHANGED
@@ -10,7 +10,6 @@
10
 
11
  import torch
12
  import torch.nn as nn
13
- from health_multimodal.common.device import get_module_device
14
  from timm.models.layers import trunc_normal_
15
 
16
  from .resnet import resnet18, resnet50
@@ -97,7 +96,7 @@ def __init__(self, img_encoder_type: str):
97
  output_dim = 256 # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
98
  grid_shape = (14, 14) # Spatial dimensions of patch grid.
99
 
100
- backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))
101
 
102
  self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
103
  kernel_size=1, stride=1, padding=0, bias=False)
 
10
 
11
  import torch
12
  import torch.nn as nn
 
13
  from timm.models.layers import trunc_normal_
14
 
15
  from .resnet import resnet18, resnet50
 
96
  output_dim = 256 # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
97
  grid_shape = (14, 14) # Spatial dimensions of patch grid.
98
 
99
+ backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=torch.device("cuda"))
100
 
101
  self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
102
  kernel_size=1, stride=1, padding=0, bias=False)
LLAVA_Biovil/biovil_t/model.py CHANGED
@@ -12,7 +12,6 @@
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
- from health_multimodal.common.device import get_module_device
16
 
17
  from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
18
  from .modules import MLP, MultiTaskModel
@@ -43,7 +42,7 @@ def __init__(self,
43
 
44
  # Initiate encoder, projector, and classifier
45
  self.encoder = get_encoder_from_type(img_encoder_type)
46
- self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
47
  self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
48
  hidden_dim=joint_feature_size, use_1x1_convs=True)
49
  self.downstream_classifier_kwargs = downstream_classifier_kwargs
 
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
 
15
 
16
  from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
17
  from .modules import MLP, MultiTaskModel
 
42
 
43
  # Initiate encoder, projector, and classifier
44
  self.encoder = get_encoder_from_type(img_encoder_type)
45
+ self.feature_size = get_encoder_output_dim(self.encoder, device=torch.device("cuda"))
46
  self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
47
  hidden_dim=joint_feature_size, use_1x1_convs=True)
48
  self.downstream_classifier_kwargs = downstream_classifier_kwargs
__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.69 kB). View file
 
findings_classifier/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes). View file
 
findings_classifier/__pycache__/chexpert_dataset.cpython-310.pyc ADDED
Binary file (5.95 kB). View file
 
findings_classifier/__pycache__/chexpert_model.cpython-310.pyc ADDED
Binary file (1.09 kB). View file
 
findings_classifier/__pycache__/chexpert_train.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
simple_test.py CHANGED
@@ -1,6 +1,5 @@
1
  from pathlib import Path
2
 
3
- from skimage import io as io_img
4
  import io
5
 
6
  import requests
@@ -14,26 +13,45 @@ from LLAVA_Biovil.llava.model.builder import load_pretrained_model
14
  from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
15
 
16
  from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
17
- from utils import create_chest_xray_transform_for_inference
18
 
19
- def load_model_from_huggingface(repo_id, model_filename):
 
20
  # Download model files
21
- model_path = snapshot_download(repo_id=repo_id, revision="main")
22
- model_path = Path(model_path) / model_filename
23
 
24
  tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
25
  model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
26
 
 
27
  return tokenizer, model, image_processor, context_len
28
 
 
 
29
  if __name__ == '__main__':
30
  # config = None
31
  # model_path = "/home/guests/chantal_pellegrini/RaDialog_LLaVA/LLAVA/checkpoints/llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5/checkpoint-21000" #TODO hardcoded in huggingface repo probably
32
  # model_name = get_model_name_from_path(model_path)
33
- tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation", model_filename="model")
 
 
 
 
 
 
 
 
 
34
  model.config.tokenizer_padding_side = "left"
35
 
36
- findings = "edema, pleural effusion" #TODO should these come from chexpert classifier? Or not needed for this demo/test?
 
 
 
 
 
 
37
 
38
  conv = conv_vicuna_v1.copy()
39
  REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
@@ -44,12 +62,6 @@ if __name__ == '__main__':
44
 
45
  # get the image
46
  vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
47
- sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma" #TODO find good image
48
-
49
- response = requests.get(sample_img_path)
50
- image = Image.open(io.BytesIO(response.content))
51
- image = remap_to_uint8(np.array(image))
52
- image = Image.fromarray(image).convert("L")
53
  image_tensor = vis_transforms_biovil(image).unsqueeze(0)
54
 
55
  image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
 
1
  from pathlib import Path
2
 
 
3
  import io
4
 
5
  import requests
 
13
  from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
14
 
15
  from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
16
+ from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor
17
 
18
+
19
+ def load_model_from_huggingface(repo_id):
20
  # Download model files
21
+ model_path = snapshot_download(repo_id=repo_id, revision="main", force_download=True)
22
+ model_path = Path(model_path)
23
 
24
  tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
25
  model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
26
 
27
+
28
  return tokenizer, model, image_processor, context_len
29
 
30
+
31
+
32
  if __name__ == '__main__':
33
  # config = None
34
  # model_path = "/home/guests/chantal_pellegrini/RaDialog_LLaVA/LLAVA/checkpoints/llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5/checkpoint-21000" #TODO hardcoded in huggingface repo probably
35
  # model_name = get_model_name_from_path(model_path)
36
+ sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma" #TODO find good image
37
+
38
+ response = requests.get(sample_img_path)
39
+ image = Image.open(io.BytesIO(response.content))
40
+ image = remap_to_uint8(np.array(image))
41
+ image = Image.fromarray(image).convert("L")
42
+
43
+ tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
44
+ cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()
45
+
46
  model.config.tokenizer_padding_side = "left"
47
 
48
+ cp_image = cp_transforms(image)
49
+ logits = cp_model(cp_image[None].half().cuda())
50
+ preds_probs = torch.sigmoid(logits)
51
+ preds = preds_probs > 0.5
52
+ pred = preds[0].cpu().numpy()
53
+ findings = cp_class_names[pred].tolist()
54
+ findings = ', '.join(findings).lower().strip()
55
 
56
  conv = conv_vicuna_v1.copy()
57
  REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
 
62
 
63
  # get the image
64
  vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
 
 
 
 
 
 
65
  image_tensor = vis_transforms_biovil(image).unsqueeze(0)
66
 
67
  image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
utils.py CHANGED
@@ -2,6 +2,9 @@ import numpy as np
2
  import torch
3
  from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, transforms
4
 
 
 
 
5
  class ExpandChannels:
6
  """
7
  Transforms an image with one channel to an image with three channels by copying
@@ -60,3 +63,20 @@ def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
60
  array /= array.max()
61
  array *= 255
62
  return array.astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, transforms
4
 
5
+ from huggingface.findings_classifier.chexpert_train import LitIGClassifier
6
+
7
+
8
  class ExpandChannels:
9
  """
10
  Transforms an image with one channel to an image with three channels by copying
 
63
  array /= array.max()
64
  array *= 255
65
  return array.astype(np.uint8)
66
+
67
+ def init_chexpert_predictor():
68
+ ckpt_path = f"findings_classifier/checkpoints/chexpert_train/ChexpertClassifier.ckpt"
69
+ chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
70
+ "Cardiomegaly", "Lung Opacity",
71
+ "Lung Lesion", "Edema",
72
+ "Consolidation", "Pneumonia",
73
+ "Atelectasis", "Pneumothorax",
74
+ "Pleural Effusion", "Pleural Other",
75
+ "Fracture", "Support Devices"]
76
+ model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=14, class_names=chexpert_cols, strict=False)
77
+ model.eval()
78
+ model.cuda()
79
+ model.half()
80
+ cp_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])
81
+
82
+ return model, np.asarray(model.class_names), cp_transforms