class Summarizer:
    """Callable wrapper that summarizes Czech texts with a seq2seq model.

    The instance owns a model (anything exposing ``to(device)`` and
    ``generate(**kwargs)``) and its tokenizer. Calling the instance
    sentence-splits each text, joins the sentences with the tokenizer's EOS
    token, tokenizes to at most ``enc_max_len`` tokens and decodes the
    generated ids back to a string.

    Args:
        model: Sequence-to-sequence model used for generation.
        tokenizer: Tokenizer matching ``model``. It must provide ``eos_token``,
            ``lang_code_to_id`` and ``batch_decode``.
        inference_cfg: Dict of default keyword arguments for ``_summarize``
            (``num_beams``, ``top_k``, ...). ``None`` means "use the
            ``_summarize`` defaults".
        device: Device to run on (string or ``torch.device``). When ``None``
            the model runs on CUDA if it is available, otherwise on CPU, so the
            class no longer has to be edited by hand for CPU-only machines.
    """

    # Language code handed to the sentence splitter.
    SPLITTER_LANGUAGE = "cs"
    # Key looked up in ``tokenizer.lang_code_to_id`` to force the first decoder token.
    TARGET_LANG_CODE = "cs_CZ"

    def __init__(self, model, tokenizer, inference_cfg, device=None):
        if device is None:
            device = "cuda" if pt.cuda.is_available() else "cpu"
        self.device = pt.device(device)
        self.model = model
        self.model.to(self.device)
        self.tokenizer = tokenizer
        self.inference_cfg = inference_cfg
        # Encoder inputs are truncated / padded to this many tokens.
        self.enc_max_len = 512

    def __call__(self, texts, inference_cfg=None):
        """Tokenize and summarize one text or a list of texts.

        Args:
            texts: A single string or a list of strings.
            inference_cfg: Optional dict of generation settings. When given it
                REPLACES ``self.inference_cfg`` and therefore also applies to
                later calls that do not pass a config.

        Returns:
            A list with one summary string per input text (an empty list for
            empty input).

        Raises:
            AssertionError: If ``texts`` is neither a string nor a list of
                strings.
        """
        if isinstance(texts, str):
            texts = [texts]
        # Validate every element, not just the first, and do not index into a
        # possibly empty list.
        assert isinstance(texts, list) and all(isinstance(t, str) for t in texts), \
            "Expected string or list of strings"

        if inference_cfg is not None:
            self.inference_cfg = inference_cfg
        if not texts:
            return []

        # A missing configuration falls back to the ``_summarize`` defaults.
        cfg = self.inference_cfg or {}
        # Build the splitter once per call instead of once per text.
        splitter = SentenceSplitter(language=self.SPLITTER_LANGUAGE)

        summaries = []
        for text in texts:
            joined = self.tokenizer.eos_token.join(splitter.split(text))
            encoded = self.tokenizer(
                joined,
                max_length=self.enc_max_len,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
            summaries.append(self._summarize(encoded, **cfg)[0])
        return summaries

    def _summarize(self, data, num_beams=1, do_sample=False,
                   top_k=50,
                   top_p=1.0,
                   temperature=1.0,
                   repetition_penalty=1.0,
                   no_repeat_ngram_size=None,
                   max_length=1024,
                   min_length=10,
                   decode_decoder_ids=False,
                   early_stopping=False, **kwargs):
        """Generate and decode summaries for one tokenized batch.

        Args:
            data: Mapping with ``"input_ids"`` and ``"attention_mask"`` entries
                that support ``.to(device)``.
            num_beams, do_sample, top_k, top_p, temperature,
            repetition_penalty, max_length, min_length, early_stopping:
                Forwarded unchanged to ``self.model.generate``.
            no_repeat_ngram_size: Forwarded to ``generate`` only when it is not
                ``None``.
            decode_decoder_ids: Accepted for configuration compatibility; not
                used by this method.
            **kwargs: Any other configuration keys are accepted and ignored.

        Returns:
            List of decoded strings, one per generated sequence, with special
            tokens skipped.
        """
        generate_args = dict(
            input_ids=data["input_ids"].to(self.device),
            attention_mask=data["attention_mask"].to(self.device),
            num_beams=num_beams,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            max_length=max_length,
            min_length=min_length,
            early_stopping=early_stopping,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[self.TARGET_LANG_CODE],
        )
        if no_repeat_ngram_size is not None:
            generate_args["no_repeat_ngram_size"] = no_repeat_ngram_size

        # Use the instance's own model and tokenizer, never same-named globals
        # that may or may not exist in the notebook kernel.
        summary = self.model.generate(**generate_args)
        return self.tokenizer.batch_decode(summary, skip_special_tokens=True)