import torch

import config  # not referenced in this module's visible code; kept in case other code relies on the import — TODO confirm
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer


def load_esm2_model(model_name: str):
    """Load a tokenizer plus masked-LM and base-encoder models for one checkpoint.

    Args:
        model_name: Hugging Face hub id or local path passed unchanged to
            every ``from_pretrained`` call (presumably an ESM-2 checkpoint,
            going by the function name — confirm with callers).

    Returns:
        A ``(tokenizer, masked_model, embedding_model)`` tuple:
        ``AutoTokenizer``, ``AutoModelForMaskedLM`` and ``AutoModel``
        instances, all loaded from ``model_name``. Models are returned on
        whatever device ``from_pretrained`` places them; this function does
        not move them.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): the same weights are loaded twice (once per head), which
    # doubles memory use; kept as-is because callers receive both objects.
    masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
    embedding_model = AutoModel.from_pretrained(model_name)
    return tokenizer, masked_model, embedding_model


def get_latents(model, tokenizer, sequence, device) -> torch.Tensor:
    """Return the final-layer hidden states of ``model`` for ``sequence``.

    Args:
        model: A model whose output exposes ``last_hidden_state``, e.g. the
            ``embedding_model`` returned by ``load_esm2_model``. It is NOT
            moved here and must already reside on ``device``.
        tokenizer: Tokenizer matching ``model``.
        sequence: Input passed directly to ``tokenizer``.
        device: Target device for the tokenized inputs.

    Returns:
        ``last_hidden_state`` with a leading size-1 batch dimension removed
        by ``squeeze(0)``. For a single sequence this presumably has shape
        ``(num_tokens, hidden_size)``, including any special tokens the
        tokenizer adds — confirm against the tokenizer in use.
    """
    inputs = tokenizer(sequence, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs).last_hidden_state.squeeze(0)
    return outputs