import torch
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

import config


def load_esm2_model(model_name):
    """Load an ESM-2 checkpoint in two variants, plus its tokenizer.

    Returns the tokenizer, a masked-LM head model (per-position token
    logits), and a bare encoder (hidden-state embeddings).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
    embedding_model = AutoModel.from_pretrained(model_name)
    return tokenizer, masked_model, embedding_model


def get_latents(model, tokenizer, sequence, device):
    """Return per-residue hidden states for a single sequence.

    The output keeps the CLS/EOS special-token positions that the
    ESM-2 tokenizer adds; `model` is assumed to already be on `device`.
    """
    inputs = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only, no gradient tracking
        outputs = model(**inputs).last_hidden_state.squeeze(0)  # drop batch dim
    return outputs
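

# Minimal usage sketch, assuming network access to the Hugging Face Hub;
# the checkpoint name and sequence below are illustrative, not from this repo.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, masked_model, embedding_model = load_esm2_model(
        "facebook/esm2_t12_35M_UR50D"  # example ESM-2 checkpoint
    )
    embedding_model = embedding_model.to(device).eval()
    latents = get_latents(embedding_model, tokenizer, "MKTAYIAKQR", device)
    print(latents.shape)  # (seq length + special tokens, hidden size)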