import nltk
import torch
from summarizer import Summarizer
from sumy.nlp.tokenizers import Tokenizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.summarizers.lsa import LsaSummarizer
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.sum_basic import SumBasicSummarizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# sumy's Tokenizer("en") relies on NLTK's sentence tokenizer models.
nltk.download("punkt")


def extractive(summarizer, text, sentence_count=5):
    """Run a sumy extractive summarizer and join its top sentences into one string."""
    document = PlaintextParser(text, Tokenizer("en")).document
    sentences = [str(sentence) for sentence in summarizer(document, sentence_count)]
    return " ".join(sentences)


def summarize(file, model):
    """Summarize the uploaded plain-text file with the selected method."""
    with open(file.name) as f:
        doc = f.read()

    if model == "Pegasus":
        # Abstractive: Pegasus fine-tuned on BillSum.
        checkpoint = "google/pegasus-billsum"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        pegasus = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        inputs = tokenizer(doc, max_length=1024, truncation=True, return_tensors="pt")
        summary_ids = pegasus.generate(inputs["input_ids"])
        summary = tokenizer.batch_decode(
            summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

    elif model in ("LEDBill", "ILC"):
        # Abstractive: two Longformer Encoder-Decoder checkpoints that are
        # driven identically, so they share one code path.
        checkpoint = "d0r1h/LEDBill" if model == "LEDBill" else "d0r1h/led-base-ilc"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        led = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, return_dict_in_generate=True)
        input_ids = tokenizer(doc, return_tensors="pt").input_ids
        # LED expects global attention on at least one token; the first token
        # is the conventional choice.
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[:, 0] = 1
        sequences = led.generate(input_ids, global_attention_mask=global_attention_mask).sequences
        summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

    elif model == "TextRank":
        # Extractive, graph-based; sumy's LexRank (a close TextRank relative) backs this option.
        summary = extractive(LexRankSummarizer(), doc)
    elif model == "SumBasic":
        # Extractive, word-frequency based.
        summary = extractive(SumBasicSummarizer(), doc)
    elif model == "Lsa":
        # Extractive, latent semantic analysis.
        summary = extractive(LsaSummarizer(), doc)
    elif model == "BERT":
        # Extractive: DistilBERT sentence embeddings, concatenating the last two hidden layers.
        bert_model = Summarizer("distilbert-base-uncased", hidden=[-1, -2], hidden_concat=True)
        summary = bert_model(doc)
    else:
        raise ValueError(f"Unknown summarization method: {model}")

    return summary
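

# Usage sketch (illustrative only, not part of the app): a minimal smoke test
# for summarize(). The file name "sample_judgment.txt" is hypothetical, and
# SimpleNamespace stands in for the uploaded-file object summarize() expects,
# i.e. anything exposing a .name attribute (as a Gradio upload does).
if __name__ == "__main__":
    from types import SimpleNamespace

    upload = SimpleNamespace(name="sample_judgment.txt")  # hypothetical local file
    for choice in ("TextRank", "SumBasic", "Lsa"):  # CPU-friendly extractive options
        print(f"--- {choice} ---")
        print(summarize(upload, choice))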