|
--- |
|
library_name: transformers |
|
tags: [] |
|
--- |
|
# Description

DNA sequence summarization model.

The model was pretrained on a mixture of DNA sequences and English text, then finetuned on English summarization data.

It is used to test multilingual transfer ability from English to DNA: whether summarization skills learned on English text carry over to DNA sequences.
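
For a quick smoke test, the checkpoint loads like any GPT-2 causal LM. Below is a minimal, unconstrained sketch using the `TL;DR:` prompt format from the full example in the next section (the input is a fragment of the test sequence used there):

```python
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("dnagpt/gene_eng_gpt2_summary")
model = GPT2LMHeadModel.from_pretrained("dnagpt/gene_eng_gpt2_summary")

# Prompt = DNA sequence + "TL;DR:" summarization cue
dna = "GTTATAACCTGTGAGAGTATGTTGGCGG"
inputs = tokenizer(dna + " TL;DR:", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Without the DNA-token constraint shown below, the continuation may mix English and DNA tokens.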
|
|
|
# Code example
|
|
|
The example below summarizes a DNA sequence. A custom `LogitsProcessor` restricts generation to DNA tokens, so the generated summary is itself a DNA sequence.
|
|
|
```python
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
from transformers import LogitsProcessorList, LogitsProcessor

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("dnagpt/gene_eng_gpt2_summary")
tokenizer.pad_token = tokenizer.eos_token  # use the EOS token for padding

# Load the GPT-2 model
model = GPT2LMHeadModel.from_pretrained("dnagpt/gene_eng_gpt2_summary")
model.config.pad_token_id = model.config.eos_token_id

def classify_sequence(sequence):
    # Character sets (sequences are assumed to be upper-case)
    dna_chars = set('ACGT')
    protein_chars = set('ACDEFGHIKLMNPQRSTVWY')
    english_chars = set('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 ,.!?:;-"\'()')

    # Strip surrounding whitespace; an empty string cannot be classified
    sequence = sequence.strip()
    if not sequence:
        return "Unknown"

    # DNA sequence?
    if all(c in dna_chars for c in sequence):
        return "DNA"

    # Protein sequence?
    if all(c in protein_chars for c in sequence):
        return "Protein"

    # English text (letters, digits and common punctuation)?
    if all(c in english_chars for c in sequence):
        return "English"

    # None of the above: cannot classify
    return "Unknown"

# Build the list of DNA-only tokens from the tokenizer vocabulary
word_dict = tokenizer.get_vocab()

DNA_token_list = []
for word in word_dict:
    if classify_sequence(word) == "DNA":
        DNA_token_list.append(word)

class DNAOnlyLogitsProcessor(LogitsProcessor):
    def __init__(self, allowed_tokens, tokenizer):
        self.allowed_token_ids = tokenizer.convert_tokens_to_ids(allowed_tokens)

    def __call__(self, input_ids, scores):
        # Mask every token outside the allowed set by setting its score
        # to -inf; allowed tokens keep their original scores
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed_token_ids] = 0
        scores += mask
        return scores

def get_summary_with_constraints(text, DNA_token_list):
    # Normalize the input and append the summarization cue
    text = text.strip() + " TL;DR:"

    # Encode the input text
    encoded_input = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=256,  # maximum input length
    )

    # LogitsProcessor that restricts generation to DNA tokens
    logits_processor = LogitsProcessorList([
        DNAOnlyLogitsProcessor(DNA_token_list, tokenizer)
    ])

    # Use max_new_tokens to control the generation length
    output = model.generate(
        input_ids=encoded_input["input_ids"],
        attention_mask=encoded_input["attention_mask"],
        max_new_tokens=16,       # number of newly generated tokens
        num_beams=5,             # beam search
        logits_processor=logits_processor,
        no_repeat_ngram_size=3,  # avoid repeated n-grams
        early_stopping=True,     # stop beams early when finished
    )

    # Decode the generated output
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Extract the generated summary: skip the prompt, i.e. its character
    # length plus the extra spaces introduced by decoding, minus 1
    summary = generated_text[len(text) + len(encoded_input["input_ids"][0]) - 1:].strip()

    return summary

# Example usage
# test_text = "The DNA sequence analysis showed remarkable results."
test_text = "GTTATAACCTGTGAGAGTATGTTGGCGGTTTGTTGCACCTACCTTTCAAACCTCTTGTTCTTCCTGTGATTTATTTGAGGCACTCAAGTGGACAGAGACCATGAGAAATTTGAGTGGAGGCCATGTCGAAGAGTTTGTCTTGGTGGGTTTCCCTACCACTCCTCCCTTCCAGCTGCTCCTCTTTGTCCTTTTCTTTGCAATTTACCTTCTGACATTGTTGGAGAATGCACTCATTGTCTTCACAATATGGCTCACTCCAAGCCTTCATCGCCCCATGTACTTTTTCCTTGGCCATCTTTCTTTCCTGGAGCTTTGGTACATCAACGTCACCATTCCTCAGCTCTTGGCAGCCTTTCTTACCCAGGATAGTAGAGTCTCCTATGTAGGTTGCATGACCCAACTCTACTTCTTTATTGCCTTAGCCTGTACTGAATGTGTGCTGTTGGCAGTTATGGCCTATGACCGC"

print(get_summary_with_constraints(test_text, DNA_token_list))
```
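
Because each decoding step is masked to DNA tokens, the returned summary should contain only A/C/G/T characters, possibly separated by spaces introduced during decoding. A quick sanity check, reusing the functions defined above:

```python
summary = get_summary_with_constraints(test_text, DNA_token_list)
print(summary)
# With spaces stripped, the summary should classify as DNA
print(classify_sequence(summary.replace(" ", "")))  # expected: "DNA"
```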
|
|
|
|