hidehisa-arai committed
Commit 391228d
1 Parent(s): c66a69b
Files changed (2):
  1. README.md +12 -7
  2. modeling_japanese_clip.py +12 -3
README.md CHANGED
@@ -7,7 +7,6 @@ tags:
 - clip
 - japanese-clip
 ---
-
 # recruit-jp/japanese-clip-vit-b-32-roberta-base
 
 ## Overview
@@ -41,17 +40,19 @@ pip install pillow requests transformers torch torchvision sentencepiece
 ```python
 import io
 import requests
-from PIL import Image
 
 import torch
 import torchvision
+from PIL import Image
 from transformers import AutoTokenizer, AutoModel
 
+
 model_name = "recruit-jp/japanese-clip-vit-b-32-roberta-base"
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
 
+
 def _convert_to_rgb(image):
     return image.convert('RGB')
 
@@ -68,25 +69,29 @@ preprocess = torchvision.transforms.Compose([
 def tokenize(tokenizer, texts):
     texts = ["[CLS]" + text for text in texts]
     encodings = [
+        # NOTE: the maximum token length that can be fed into this model is 77
         tokenizer(text, max_length=77, padding="max_length", truncation=True, add_special_tokens=False)["input_ids"]
         for text in texts
     ]
     return torch.LongTensor(encodings)
 
+
 # Run!
-image = Image.open(io.BytesIO(requests.get('https://images.pexels.com/photos/2253275/pexels-photo-2253275.jpeg?auto=compress&cs=tinysrgb&dpr=3&h=750&w=1260').content))
+image = Image.open(
+    io.BytesIO(
+        requests.get(
+            'https://images.pexels.com/photos/2253275/pexels-photo-2253275.jpeg?auto=compress&cs=tinysrgb&dpr=3&h=750&w=1260'
+        ).content
+    )
+)
 image = preprocess(image).unsqueeze(0).to(device)
 text = tokenize(tokenizer, texts=["犬", "猫", "象"]).to(device)
-
 with torch.inference_mode():
     image_features = model.get_image_features(image)
     image_features /= image_features.norm(dim=-1, keepdim=True)
-
     text_features = model.get_text_features(input_ids=text)
     text_features /= text_features.norm(dim=-1, keepdim=True)
-
     probs = image_features @ text_features.T
-
     print("Label probs:", probs.cpu().numpy()[0])
 ```
 
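The added NOTE comment documents the 77-token context limit that the README's `tokenize()` helper enforces. As a quick sanity check of that limit (a minimal sketch, assuming only the tokenizer from the README snippet; the over-long input string is purely illustrative), the padded/truncated length can be inspected directly:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("recruit-jp/japanese-clip-vit-b-32-roberta-base")

# Same call as in the README's tokenize() helper: every text is padded or
# truncated to exactly 77 input ids before being fed to the model.
ids = tokenizer(
    "[CLS]" + "これは非常に長いキャプションです。" * 20,  # deliberately over-long input
    max_length=77,
    padding="max_length",
    truncation=True,
    add_special_tokens=False,
)["input_ids"]
print(len(ids))  # 77
```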
modeling_japanese_clip.py CHANGED
@@ -84,7 +84,9 @@ class AttentionalPooler(nn.Module):
     ):
         super().__init__()
         self.query = nn.Parameter(torch.randn(n_queries, d_model))
-        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
+        self.attn = nn.MultiheadAttention(
+            d_model, n_head, kdim=context_dim, vdim=context_dim
+        )
         self.ln_q = norm_layer(d_model)
         self.ln_k = norm_layer(context_dim)
 
@@ -92,7 +94,9 @@
         x = self.ln_k(x).permute(1, 0, 2)  # NLD -> LND
         N = x.shape[1]
         q = self.ln_q(self.query)
-        out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]
+        out = self.attn(
+            q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False
+        )[0]
         return out.permute(1, 0, 2)  # LND -> NLD
 
 
@@ -187,7 +191,12 @@ class Transformer(nn.Module):
 
         self.resblocks = nn.ModuleList([
             ResidualAttentionBlock(
-                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
+                width,
+                heads,
+                mlp_ratio,
+                ls_init_value=ls_init_value,
+                act_layer=act_layer,
+                norm_layer=norm_layer)
             for _ in range(layers)
         ])
 
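Both modeling_japanese_clip.py hunks only wrap long lines across multiple lines; the `nn.MultiheadAttention` constructor and the attention call keep the same arguments. A minimal, self-contained sketch of the same attentional-pooling pattern is shown below, with illustrative sizes that are not taken from the actual model config:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the model config.
d_model, context_dim, n_head = 512, 768, 8
n_queries, seq_len, batch = 1, 50, 2

query = nn.Parameter(torch.randn(n_queries, d_model))
attn = nn.MultiheadAttention(
    d_model, n_head, kdim=context_dim, vdim=context_dim
)
ln_q, ln_k = nn.LayerNorm(d_model), nn.LayerNorm(context_dim)

x = torch.randn(batch, seq_len, context_dim)  # NLD
x = ln_k(x).permute(1, 0, 2)                  # NLD -> LND
N = x.shape[1]
q = ln_q(query)
# Learned queries attend over the (wider) context sequence; need_weights=False
# skips returning the attention map, so [0] picks the pooled output.
out = attn(
    q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False
)[0]
out = out.permute(1, 0, 2)                    # LND -> NLD
print(out.shape)  # torch.Size([2, 1, 512]): one pooled token per sample
```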