Output of ZymCTRL
Hi Noelia,
it's me, Kurioscity from the ProtGPT2 discussion. I tried to use ZymCTRL after downloading the model files manually, using this wrapper script:
#!/usr/bin/env python
#@title Import libraries and initialize model
from transformers import GPT2LMHeadModel, AutoTokenizer
import torch
import os
print("Libraries imported successfully.")
enzyme_class = "1.1.3.4"
device = torch.device('cpu')
tokenizer = AutoTokenizer.from_pretrained('/mnt/e/ZymCTRL/model')
model = GPT2LMHeadModel.from_pretrained('/mnt/e/ZymCTRL/model').to(device)
input_ids = tokenizer.encode(enzyme_class, return_tensors='pt').to(device)
output = model.generate(input_ids, top_k=8, repetition_penalty=1.2, max_length=1024,
                        eos_token_id=1, pad_token_id=0, do_sample=True, num_return_sequences=100)
print(output)
f = open("output.test", "w")
f.write(output)
f.close()
I have not figured out how to enable the GPU on our local desktop (we're running WSL2 on a Windows machine that does have a GPU, but I haven't managed to get the GPU working inside the Ubuntu OS of WSL2), so I changed the device parameter to cpu instead. The output I got was not a list of sequences but some kind of a matrix.
I guess there is something wrong on my side with the script, but I would very much appreciate it if you could have a look at it.
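For reference, once CUDA works under WSL2 I was planning to switch the device line to the usual PyTorch availability check (just a guess on my part that this is all that's needed), so the script falls back to the CPU when no GPU is visible:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  # 'cuda' once the GPU is usable inside WSL2, 'cpu' otherwise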
Thanks a lot,
Best regards,
Kurioscity
Hi Kurioscity!
Great to hear from you again :)
I'd love to help you with the GPU issue; what is the error you are getting? (I might have no idea, though.)
The output you get is indeed a matrix of tokens. The model encodes each amino acid as a token, so you get a list of tokens per sequence, and it returns all the sequences at once, which is why you see a matrix of tokens. These tokens can be decoded with the tokenizer; here you have an example script:
# 1. Generate sequences
enzyme_class = "1.1.3.4"
input_ids = tokenizer.encode(enzyme_class, return_tensors='pt').to(device)
outputs = model.generate(
    input_ids,
    top_k=9,
    repetition_penalty=1.2,
    max_length=1024,
    eos_token_id=1,
    pad_token_id=0,
    do_sample=True,
    num_return_sequences=100)
# This step makes sure that the sequences weren't truncated during generation. The last token should be a padding token.
new_outputs = [output for output in outputs if output[-1] == 0]
if not new_outputs:
    print("not enough sequences with short lengths!!")
# To decode the sequences, you have to use the tokenizer:
tokenizer.decode(new_outputs[0]) # for example the first sequence in the list
# You can also decode and compute perplexity for all sequences at once:
import math  # needed for math.exp below

def calculatePerplexity(input_ids, model, tokenizer):
    '''
    Function to compute perplexity
    '''
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss, logits = outputs[:2]
    return math.exp(loss)
ppls = [(tokenizer.decode(output), calculatePerplexity(output, model, tokenizer)) for output in new_outputs]
# After this, one possibility is to sort the sequences by perplexity, the lower the better:
ppls.sort(key=lambda i: i[1])
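For example, after sorting, ppls[0] is the (sequence, perplexity) pair with the lowest perplexity:
best_sequence, best_ppl = ppls[0]
print(best_ppl)
print(best_sequence)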
Let me know if something throws an error or is unclear.
Best
Noelia
Hi Noelia,
thanks a lot for your help with decoding the tokens. Your suggestion did help, and I wrote a small script to get the sequences, as follows (as you can already see from my code, I'm an occasional coder with limited knowledge of Python):
#!/usr/bin/env python
#@title Import libraries and initialize model
from transformers import GPT2LMHeadModel, AutoTokenizer
import torch
import os
from datetime import datetime
print("Libraries imported successfully.")
start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
print(start_time)
enzyme_class = "1.1.3.4"
device = torch.device('cpu')
tokenizer = AutoTokenizer.from_pretrained('/mnt/e/ZymCTRL/model')
model = GPT2LMHeadModel.from_pretrained('/mnt/e/ZymCTRL/model').to(device)
input_ids = tokenizer.encode(enzyme_class, return_tensors='pt').to(device)
outputs = model.generate(input_ids, top_k=8, repetition_penalty=1.2, max_length=1024,
                         eos_token_id=1, pad_token_id=0, do_sample=True, num_return_sequences=50)
print(outputs)
new_outputs = [output for output in outputs if output[-1] == 0]
if not new_outputs:
    print("Not enough sequences with short lengths!!")
fastaname = enzyme_class + '_' + str(start_time)
f = open(fastaname + ".fasta", "w")
fasta_records = []
for count, seq in enumerate(new_outputs):
    seq = tokenizer.decode(seq)
    print(seq)
    # Strip the special tokens and the EC label so only the amino acid sequence remains
    write_seq = seq.replace(' ', '').replace('<pad>', '').replace('<sep>', '').replace('<start>', '').replace(enzyme_class, '').replace('<|endoftext|>', '').replace('<end>', '')
    print(write_seq)
    fasta_record = ">" + enzyme_class + "_" + start_time + "_" + str(count) + "\n" + write_seq + "\n"
    print(fasta_record)
    fasta_records.append(fasta_record)
print(fasta_records)
# Join the records directly; each record already ends with a newline, so no separator is needed
fasta_file = "".join(fasta_records)
print(fasta_file)
f.write(fasta_file)
f.close()
I will try to include the perplexity calculation as you suggested, and I will open a new thread to ask for your help with the GPU issue (or is there a better way/channel to discuss the GPU issue with you, since I don't think it is very relevant to ZymCTRL?).
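Just so you can tell me if I'm on the wrong track, this is roughly what I have in mind for the perplexity part (untested; it reuses your calculatePerplexity function and the variables from my script above, and the header format is just my own choice):
import math

def calculatePerplexity(input_ids, model, tokenizer):
    # Same function as in your example above
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss, logits = outputs[:2]
    return math.exp(loss)

ppl_records = []
for count, seq_tokens in enumerate(new_outputs):
    ppl = calculatePerplexity(seq_tokens, model, tokenizer)
    seq = tokenizer.decode(seq_tokens)
    write_seq = seq.replace(' ', '').replace('<pad>', '').replace('<sep>', '').replace('<start>', '').replace(enzyme_class, '').replace('<|endoftext|>', '').replace('<end>', '')
    # Put the perplexity into the FASTA header so I can sort/filter later
    ppl_records.append(">" + enzyme_class + "_" + start_time + "_" + str(count) + "_ppl_" + str(round(ppl, 2)) + "\n" + write_seq + "\n")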
Great to hear it helped! Let me know if you need some help with the perplexity. In my experience, it makes a huge difference to only keep the sequences with the lowest perplexity values, e.g., the best 10%.
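Something along these lines (just a sketch, assuming ppls is the sorted list from the script above):
# Keep only the lowest-perplexity sequences, here the best 10% (at least one)
n_keep = max(1, len(ppls) // 10)
best = ppls[:n_keep]
for sequence, ppl in best:
    print(ppl, sequence)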
For the GPU error (which I might have no idea about), I guess you can post it here in case it helps other users :) or send me an email at noelia [dot] ferruz [at] udg [dot] edu