wangjin2000 committed (verified)
Commit cebaaeb · 1 parent: 73969e8

Update app.py

Files changed (1):
  1. app.py +4 -97
app.py CHANGED
@@ -150,55 +150,18 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
          "weight_decay": 0.2,
          # Add other hyperparameters as needed
      }
-     # The base model you will train a LoRA on top of
-     #base_model_path = "facebook/esm2_t12_35M_UR50D"
-
-     # Define labels and model
-     #id2label = {0: "No binding site", 1: "Binding site"}
-     #label2id = {v: k for k, v in id2label.items()}

-
      base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)
-
-     '''
-     # Load the data from pickle files (replace with your local paths)
-     with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
-         train_sequences = pickle.load(f)
-
-     with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f:
-         test_sequences = pickle.load(f)
-
-     with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
-         train_labels = pickle.load(f)
-
-     with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
-         test_labels = pickle.load(f)
-     '''

      # Tokenization
      tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
-     #max_sequence_length = 1000

      train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
      test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

-     # Directly truncate the entire list of labels
-     #train_labels = truncate_labels(train_labels, max_sequence_length)
-     #test_labels = truncate_labels(test_labels, max_sequence_length)
-
      train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
      test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

-     '''
-     # Compute Class Weights
-     classes = [0, 1]
-     flat_train_labels = [label for sublist in train_labels for label in sublist]
-     class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
-     accelerator = Accelerator()
-     class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
-     print(" class_weights:", class_weights)
-     '''
-
      # Convert the model into a PeftModel
      peft_config = LoraConfig(
          task_type=TaskType.TOKEN_CLS,
@@ -217,7 +180,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
      test_dataset = accelerator.prepare(test_dataset)

      model_name_base = base_model_path.split("/")[1]
-     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H')

      # Training setup
      training_args = TrainingArguments(
@@ -262,9 +225,6 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)

      # Train and Save Model
      trainer.train()
-     #save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
-     #trainer.save_model(save_path)
-     #tokenizer.save_pretrained(save_path)

      return save_path

@@ -279,8 +239,8 @@ MODEL_OPTIONS = [
  ] # models users can choose from

  PEFT_MODEL_OPTIONS = [
-     "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
      "wangjin2000/esm2_t6_8M-lora-binding-sites_2024-07-02_09-26-54",
+     "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
  ] # finetuned models


@@ -297,21 +257,12 @@ with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
  with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
      test_labels = pickle.load(f)

- ## Tokenization
- #tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
  max_sequence_length = 1000

- #train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
- #test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
-
  # Directly truncate the entire list of labels
  train_labels = truncate_labels(train_labels, max_sequence_length)
  test_labels = truncate_labels(test_labels, max_sequence_length)

- #train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
- #test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
-
-
  # Compute Class Weights
  classes = [0, 1]
  flat_train_labels = [label for sublist in train_labels for label in sublist]
@@ -324,48 +275,6 @@ id2label = {0: "No binding site", 1: "Binding site"}
  label2id = {v: k for k, v in id2label.items()}

  '''
- # inference
- # Path to the saved LoRA model
- model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
- # ESM2 base model
- base_model_path = "facebook/esm2_t12_35M_UR50D"
-
- # Load the model
- base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
- loaded_model = PeftModel.from_pretrained(base_model, model_path)
-
- # Ensure the model is in evaluation mode
- loaded_model.eval()
-
- # Protein sequence for inference
- protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
-
- # Tokenize the sequence
- inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
-
- # Run the model
- with torch.no_grad():
-     logits = loaded_model(**inputs).logits
-
- # Get predictions
- tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
- predictions = torch.argmax(logits, dim=2)
-
-
- # Define labels
- id2label = {
-     0: "No binding site",
-     1: "Binding site"
- }
-
- # Print the predicted labels for each token
- for token, prediction in zip(tokens, predictions[0].numpy()):
-     if token not in ['<pad>', '<cls>', '<eos>']:
-         print((token, id2label[prediction]))
-
- # train
- saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_dataset)
-
  # debug result
  dubug_result = saved_path #predictions #class_weights
  '''
@@ -376,12 +285,9 @@ with demo:
      gr.Markdown("# DEMO FOR ESM2Bind")
      #gr.Textbox(dubug_result)

-     #gr.Markdown("## Finetune Pre-trained Model")
      with gr.Column():
          gr.Markdown("## Select a base model and a corresponding PEFT finetune model")
-         #gr.Markdown(
-         #    """ Pick a base model and press **Finetune Pre-trained Model!"""
-         #)
+
          with gr.Row():
              with gr.Column(scale=5, variant="compact"):
                  base_model_name = gr.Dropdown(
@@ -462,6 +368,7 @@ with demo:
          inputs=[base_model_name,PEFT_model_name,input_seq],
          outputs = [output_text],
      )
+
      # "Finetune Pre-trained Model" actions
      finetune_button.click(
          fn = train_function_no_sweeps,
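
Note: the adapters listed in PEFT_MODEL_OPTIONS are applied on top of a base ESM2 checkpoint at inference time. Below is a minimal sketch of that flow, modeled on the commented-out inference block shown in the diff above; the base/adapter pairing, label map, and protein sequence are illustrative placeholders rather than the app's exact runtime code.

# Minimal sketch (illustrative): load a LoRA adapter onto an ESM2 base model
# and predict per-residue binding sites.
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

base_model_path = "facebook/esm2_t12_35M_UR50D"                          # base checkpoint
adapter_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"  # finetuned LoRA adapter

id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
base_model = AutoModelForTokenClassification.from_pretrained(
    base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id
)
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKK"  # placeholder sequence
inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1000)
with torch.no_grad():
    logits = model(**inputs).logits           # shape: (1, seq_len, num_labels)
predictions = torch.argmax(logits, dim=2)[0]  # per-token class ids

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for token, pred in zip(tokens, predictions.tolist()):
    if token not in ("<pad>", "<cls>", "<eos>"):
        print(token, id2label[pred])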
 
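A similarly small sketch of the balanced class-weight computation that app.py applies to the flattened training labels (binary: 0 = no binding site, 1 = binding site); the label data below is made up purely for illustration.

# Minimal sketch (illustrative data): balanced class weights over flattened token labels.
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight

train_labels = [[0, 0, 1, 0], [0, 1, 1, 0, 0]]  # placeholder per-sequence labels
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]

class_weights = compute_class_weight(
    class_weight="balanced", classes=np.array(classes), y=flat_train_labels
)
class_weights = torch.tensor(class_weights, dtype=torch.float32)
print("class_weights:", class_weights)  # the rarer "Binding site" class gets the larger weight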