damerajee committed on
Commit
f2de4bf
·
verified ·
1 Parent(s): d9d60c5

Update vision_encoder.py

Browse files
Files changed (1) hide show
  1. vision_encoder.py +6 -6
vision_encoder.py CHANGED
@@ -1,10 +1,10 @@
1
- import torch.nn as nn
2
- from transformers import ViTModel
3
  from torchvision import transforms
4
- import torch
5
 
6
  import transformers
7
 
 
8
  transformers.logging.set_verbosity_error()
9
 
10
  class VisionEncoder(nn.Module):
@@ -17,9 +17,9 @@ class VisionEncoder(nn.Module):
17
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
  ])
19
 
20
- def forward(self, image, device):
21
- processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
22
  with torch.no_grad():
23
- pixel_values = self.vision_model(processed_image)
24
  image_features = pixel_values.last_hidden_state
25
  return image_features
 
1
+ from transformers import ViTModel
 
2
  from torchvision import transforms
3
+ import torch
4
 
5
  import transformers
6
 
7
+
8
  transformers.logging.set_verbosity_error()
9
 
10
  class VisionEncoder(nn.Module):
 
17
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
  ])
19
 
20
+ def forward(self, images,device):
21
+ processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
22
  with torch.no_grad():
23
+ pixel_values = self.vision_model(processed_images)
24
  image_features = pixel_values.last_hidden_state
25
  return image_features