
Commit 5b6d16d by aksell (1 parent: 4ad80db)

WIP: Start adding ProtT5

Files changed (2):
  1. hexviz/app.py +1 -0
  2. hexviz/attention.py +14 -4
hexviz/app.py CHANGED
@@ -10,6 +10,7 @@ st.title("pLM Attention Visualization")
 # Define list of model types
 models = [
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
+    # Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
 selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
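For context, a minimal sketch of the Model and ModelType definitions implied by the usage above. This is a hypothetical reconstruction, not the actual hexviz code, and the real enum values may differ; the 24 layers and 32 heads in the commented-out entry match the ProtT5-XL (T5-3B) encoder configuration.

# Hypothetical reconstruction of the types used in hexviz/app.py above;
# the real definitions live elsewhere in the hexviz package and may differ.
from dataclasses import dataclass
from enum import Enum


class ModelType(str, Enum):
    TAPE_BERT = "TAPE-BERT"  # assumed display value
    PROT_T5 = "ProtT5"       # assumed display value


@dataclass
class Model:
    name: ModelType
    layers: int
    heads: int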
hexviz/attention.py CHANGED
@@ -48,6 +48,7 @@ def get_sequences(structure: Structure) -> List[str]:
         sequences.append(list(residues_single_letter))
     return sequences
 
+@st.cache
 def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     tokenizer = T5Tokenizer.from_pretrained(
@@ -69,7 +70,7 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
 
 @st.cache
 def get_attention(
-    sequence: List[str], model_type: ModelType = ModelType.TAPE_BERT
+    sequence: str, model_type: ModelType = ModelType.TAPE_BERT
 ):
     if model_type == ModelType.TAPE_BERT:
         tokenizer, model = get_tape_bert()
@@ -81,9 +82,18 @@ def get_attention(
         attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
         attns = torch.stack([attn.squeeze(0) for attn in attns])
     elif model_type == ModelType.PROT_T5:
-        attns = None
-        # Space separate sequences
-        sequences = [" ".join(sequence) for sequence in sequences]
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        # Introduce white-space between all amino acids
+        sequence = " ".join(sequence)
+        # tokenize sequences and pad up to the longest sequence in the batch
+        ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
+
+        input_ids = torch.tensor(ids['input_ids']).to(device)
+        attention_mask = torch.tensor(ids['attention_mask']).to(device)
+
+        with torch.no_grad():
+            attns = model(input_ids=input_ids, attention_mask=attention_mask)[-1]
+
         tokenizer, model = get_protT5()
     else:
         raise ValueError(f"Model {model_type} not supported")
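For reference, a minimal standalone sketch of the ProtT5 attention-extraction flow this WIP is heading toward, with the model loaded before it is used (in the committed diff the get_protT5() call still comes after the forward pass). The checkpoint name (Rostlab/prot_t5_xl_uniref50), the helper name get_prot_t5_attention, and the explicit output_attentions=True flag are assumptions; the from_pretrained call inside get_protT5 is truncated in this diff and may differ.

# Sketch only: names and checkpoint are assumptions, not the hexviz implementation.
from typing import Tuple

import torch
from transformers import T5EncoderModel, T5Tokenizer

CHECKPOINT = "Rostlab/prot_t5_xl_uniref50"  # assumed checkpoint


def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT, do_lower_case=False)
    model = T5EncoderModel.from_pretrained(CHECKPOINT).to(device)
    model.eval()
    return tokenizer, model


def get_prot_t5_attention(sequence: str) -> torch.Tensor:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load tokenizer and model before the forward pass
    tokenizer, model = get_protT5()

    # ProtT5 expects whitespace-separated amino acids, e.g. "M K T A ..."
    spaced = " ".join(sequence)
    ids = tokenizer.encode_plus(spaced, add_special_tokens=True, padding="longest")

    # Add a batch dimension before moving to the model's device
    input_ids = torch.tensor(ids["input_ids"]).unsqueeze(0).to(device)
    attention_mask = torch.tensor(ids["attention_mask"]).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
    # outputs.attentions is a tuple of (batch, heads, seq, seq), one per layer
    attns = torch.stack([attn.squeeze(0) for attn in outputs.attentions])
    return attns

One difference from the TAPE branch worth noting: the ProtT5 tokenizer only appends a single end-of-sequence token rather than wrapping the sequence in start and end tokens, so the special-token trimming would not be the [:, :, 1:-1, 1:-1] slice used for the BERT-style model.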