import torch import numpy as np import gradio as gr from nltk import sent_tokenize from transformers import RobertaTokenizer, RobertaForMaskedLM cuda = torch.cuda.is_available() tokenizer = RobertaTokenizer.from_pretrained("roberta-large") model = RobertaForMaskedLM.from_pretrained("roberta-large") if cuda: model = model.cuda() max_len = 20 top_k = 100 temperature = 1 burnin = 250 max_iter = 500 # adapted from https://github.com/nyu-dl/bert-gen def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True): """ Generate a word from from out[gen_idx] args: - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size - gen_idx (int): location for which to generate for - top_k (int): if >0, only sample from the top k most probable words - sample (Bool): if True, sample from full distribution. Overridden by top_k """ logits = out.logits[:, gen_idx] if temperature is not None: logits = logits / temperature if top_k > 0: kth_vals, kth_idx = logits.topk(top_k, dim=-1) dist = torch.distributions.categorical.Categorical(logits=kth_vals) idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1) elif sample: dist = torch.distributions.categorical.Categorical(logits=logits) idx = dist.sample() # removed superfluous squeeze(-1) else: idx = torch.argmax(logits, dim=-1) return idx.tolist() if return_list else idx # adapted from https://github.com/nyu-dl/bert-gen def parallel_sequential_generation(seed_text, seed_end_text, max_len=max_len, top_k=top_k, temperature=temperature, max_iter=max_iter, burnin=burnin): """ Generate for one random position at a timestep args: - burnin: during burn-in period, sample from full distribution; afterwards take argmax """ inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text, return_tensors='pt') masked_tokens = np.where( inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0] seed_len = masked_tokens[0] if cuda: inp = inp.to('cuda') for ii in range(max_iter): kk = np.random.randint(0, max_len) out = model(**inp) topk = top_k if (ii >= burnin) else 0 idxs = generate_step(out, gen_idx=seed_len + kk, top_k=topk, temperature=temperature, sample=(ii < burnin)) inp['input_ids'][0][seed_len + kk] = idxs[0] tokens = inp['input_ids'].cpu().numpy()[0][masked_tokens] tokens = tokens[(np.where((tokens != tokenizer.eos_token_id) & (tokens != tokenizer.bos_token_id)))] return tokenizer.decode(tokens) def inbertolate(doc, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200): new_doc = '' paras = doc.split('\n') for para in paras: para = sent_tokenize(para) if para == '': new_doc += '\n' continue para += [''] for sentence in range(len(para) - 1): new_doc += para[sentence] + ' ' new_doc += parallel_sequential_generation(para[sentence], para[sentence + 1], max_len=max_len, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter) + ' ' new_doc += '\n' return new_doc if __name__ == '__main__': block = gr.Blocks(css='.container') with block: gr.Markdown("