---
license: mit
---

# ESM-2 for Post-Translational Modification

This is a LoRA fine-tuned version of `esm2_t12_35M_UR50D` for predicting post-translational modification sites.

## Using the Model

To use this model, run the following:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_ptm_lora_2100K"
# ESM-2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model and apply the LoRA adapter
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get per-token predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Map class indices to labels
id2label = {
    0: "No ptm site",
    1: "ptm site"
}

# Print the predicted label for each residue, skipping special tokens
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
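
If you need the positions of predicted sites rather than a printed token list, the loop above can be refactored into a small helper. This is a minimal sketch: the helper name `ptm_positions` is illustrative (not part of the model's API), and it takes the same `tokens` list and per-token label ids produced by the snippet above.

```python
def ptm_positions(tokens, predicted_ids, special_tokens=('<pad>', '<cls>', '<eos>')):
    """Return 1-based residue positions predicted as PTM sites (label id 1)."""
    positions = []
    residue_index = 0
    for token, label_id in zip(tokens, predicted_ids):
        if token in special_tokens:
            continue  # skip padding and sequence-marker tokens
        residue_index += 1  # count only real residues
        if label_id == 1:
            positions.append(residue_index)
    return positions

# Toy example (illustrative tokens/labels, not real model output)
toks = ['<cls>', 'M', 'A', 'S', 'T', '<eos>', '<pad>']
ids = [0, 0, 0, 1, 1, 0, 0]
print(ptm_positions(toks, ids))  # [3, 4]
```

In the full pipeline above you would call `ptm_positions(tokens, predictions[0].numpy())` after running the model.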