LTEnjoy committed on
Commit
616a082
1 Parent(s): 2b951c3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +79 -2
README.md CHANGED
@@ -4,7 +4,7 @@ license: mit
4
  We provide two ways to use SaProt: through the Hugging Face classes and
5
  in the same way as in the [esm github](https://github.com/facebookresearch/esm) repository. Users can choose either one.
6
 
7
- ### Huggingface model
8
  The following code shows how to load the model.
9
  ```
10
  from transformers import EsmTokenizer, EsmForMaskedLM
@@ -33,11 +33,88 @@ torch.Size([1, 11, 446])
33
  """
34
  ```
35
 
36
- ### esm model
37
  The esm version is also stored in the same folder, named `SaProt_650M_AF2.pt`. We provide a function to load the model.
38
  ```
39
  from utils.esm_loader import load_esm_saprot
40
 
41
  model_path = "/your/path/to/SaProt_650M_AF2.pt"
42
  model, alphabet = load_esm_saprot(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ```
 
4
  We provide two ways to use SaProt: through the Hugging Face classes and
5
  in the same way as in the [esm github](https://github.com/facebookresearch/esm) repository. Users can choose either one.
6
 
7
+ ## Huggingface model
8
  The following code shows how to load the model.
9
  ```
10
  from transformers import EsmTokenizer, EsmForMaskedLM
 
33
  """
34
  ```
35
 
36
+ ## esm model
37
  The esm version is also stored in the same folder, named `SaProt_650M_AF2.pt`. We provide a function to load the model.
38
  ```
39
  from utils.esm_loader import load_esm_saprot
40
 
41
  model_path = "/your/path/to/SaProt_650M_AF2.pt"
42
  model, alphabet = load_esm_saprot(model_path)
43
+ ```
44
+
45
+ ## Predict mutational effect
46
+ We provide a function to predict the mutational effect of a protein sequence. The example below shows how to predict
47
+ the mutational effect at a specific position. If you use an AF2 structure, we strongly recommend that you add a pLDDT mask (see below).
48
+ ```python
49
+ from model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel
50
+
51
+
52
+ config = {
53
+ "foldseek_path": None,
54
+ "config_path": "/your/path/to/SaProt_650M_AF2", # Note this is the directory path of SaProt, not the ".pt" file
55
+ "load_pretrained": True,
56
+ }
57
+ model = SaprotFoldseekMutationModel(**config)
58
+ tokenizer = model.tokenizer
59
+
60
+ device = "cuda"
61
+ model.eval()
62
+ model.to(device)
63
+
64
+ seq = "M#EvVpQpL#VyQdYaKv" # Here "#" marks low-pLDDT regions (pLDDT < 70)
65
+
66
+ # Predict the effect of mutating the 3rd amino acid to A
67
+ mut_info = "V3A"
68
+ mut_value = model.predict_mut(seq, mut_info)
69
+ print(mut_value)
70
+
71
+ # Predict the effects of all possible mutations at the 3rd position
72
+ mut_pos = 3
73
+ mut_dict = model.predict_pos_mut(seq, mut_pos)
74
+ print(mut_dict)
75
+
76
+ # Predict the probabilities of all amino acids at the 3rd position
77
+ mut_pos = 3
78
+ mut_dict = model.predict_pos_prob(seq, mut_pos)
79
+ print(mut_dict)
80
+
81
+ """
82
+ 0.7908501625061035
83
+
84
+ {'V3A': 0.7908501625061035, 'V3C': -0.9117952585220337, 'V3D': 2.7700226306915283, 'V3E': 2.3255627155303955, 'V3F': 0.2094242423772812, 'V3G': 2.699633836746216, 'V3H': 1.240191102027893, 'V3I': 0.10231903940439224, 'V3K': 1.804598093032837,
85
+ 'V3L': 1.3324960470199585, 'V3M': -0.18938277661800385, 'V3N': 2.8249857425689697, 'V3P': 0.40185314416885376, 'V3Q': 1.8361762762069702, 'V3R': 1.1899691820144653, 'V3S': 2.2159857749938965, 'V3T': 0.8813426494598389, 'V3V': 0.0, 'V3W': 0.5853186249732971, 'V3Y': 0.17449656128883362}
86
+
87
+ {'A': 0.021275954321026802, 'C': 0.0038764977362006903, 'D': 0.15396881103515625, 'E': 0.0987202599644661, 'F': 0.011895398609340191, 'G': 0.14350374042987823, 'H': 0.03334535285830498, 'I': 0.010687196627259254, 'K': 0.058634623885154724, 'L': 0.03656982257962227, 'M': 0.00798324216157198, 'N': 0.16266827285289764, 'P': 0.014419485814869404, 'Q': 0.06051575019955635, 'R': 0.03171204403042793, 'S': 0.08847439289093018, 'T': 0.023291070014238358, 'V': 0.009647775441408157, 'W': 0.017323188483715057, 'Y': 0.011487090960144997}
88
+ """
89
+ ```
90
+
91
+ ## Get protein embeddings
92
+ If you want to generate protein embeddings, you can refer to the following code. The embeddings are the average of
93
+ the hidden states of the last layer.
94
+ ```python
95
+ from model.saprot.base import SaprotBaseModel
96
+ from transformers import EsmTokenizer
97
+
98
+
99
+ config = {
100
+ "task": "base",
101
+ "config_path": "/your/path/to/SaProt_650M_AF2", # Note this is the directory path of SaProt, not the ".pt" file
102
+ "load_pretrained": True,
103
+ }
104
+
105
+ model = SaprotBaseModel(**config)
106
+ tokenizer = EsmTokenizer.from_pretrained(config["config_path"])
107
+
108
+ device = "cuda"
109
+ model.to(device)
110
+
111
+ seq = "M#EvVpQpL#VyQdYaKv" # Here "#" marks low-pLDDT regions (pLDDT < 70)
112
+ tokens = tokenizer.tokenize(seq)
113
+ print(tokens)
114
+
115
+ inputs = tokenizer(seq, return_tensors="pt")
116
+ inputs = {k: v.to(device) for k, v in inputs.items()}
117
+
118
+ embeddings = model.get_hidden_states(inputs, reduction="mean")
119
+ print(embeddings[0].shape)
120
  ```