"""Gradio demo for KGT5 tail prediction (entity + relation -> tail entity)."""

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "apoorvumang/kgt5-base-wikikg90mv2"


def getScores(ids, scores, pad_token_id):
    """Get sequence scores from ``model.generate`` output.

    Args:
        ids: Generated token ids (``out.sequences``). The first column is
            dropped as the start token before scoring.
        scores: Sequence of per-step score tensors (``out.scores``); they are
            stacked along dim 1 and normalised with ``log_softmax`` over dim 2.
        pad_token_id: Token id whose positions contribute 0 to the sum.
            ``None`` disables the masking.

    Returns:
        A NumPy array holding, for each sequence, the sum of the log
        probabilities of its (non-pad) tokens.

    Raises:
        ValueError: If the ids (without the start token) and the stacked
            scores disagree on their first two dimensions.
    """
    step_scores = torch.stack(scores, dim=1)
    log_probs = torch.log_softmax(step_scores, dim=2)

    # remove start token
    ids = ids[:, 1:]
    if ids.shape != log_probs.shape[:2]:
        raise ValueError(
            "ids (without start token) have shape "
            f"{tuple(ids.shape)} but scores have shape "
            f"{tuple(log_probs.shape[:2])}"
        )

    # Pick the log-prob of each generated token. A (batch, steps, 1) index is
    # enough; there is no need to expand the index over the vocabulary axis.
    token_log_probs = torch.gather(log_probs, 2, ids.unsqueeze(-1)).squeeze(-1)

    # Padding positions must not contribute to the sequence score.
    if pad_token_id is not None:
        token_log_probs = token_log_probs.masked_fill(ids == pad_token_id, 0)

    return token_log_probs.sum(dim=-1).cpu().detach().numpy()


def topkSample(input, model, tokenizer, num_samples=5, num_beams=1,
               max_output_length=30):
    """Sample output sequences for ``input`` and rank them by score.

    Args:
        input: Text passed to the tokenizer.
        model: Model whose ``generate`` method is called with sampling on.
        tokenizer: Tokenizer used to encode the input and decode the outputs.
        num_samples: Value passed as ``num_return_sequences``.
        num_beams: Value passed as ``num_beams``.
        max_output_length: Value passed as ``max_length``.

    Returns:
        A list of ``(decoded_text, score)`` pairs sorted by score, highest
        first, where the score comes from ``getScores``.
    """
    tokenized = tokenizer(input, return_tensors="pt")
    out = model.generate(
        **tokenized,
        do_sample=True,
        num_return_sequences=num_samples,
        num_beams=num_beams,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        output_scores=True,
        return_dict_in_generate=True,
        max_length=max_output_length,
    )
    out_tokens = out.sequences
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)
    pair_list = list(zip(out_str, out_scores))
    return sorted(pair_list, key=lambda pair: pair[1], reverse=True)


def greedyPredict(input, model, tokenizer):
    """Return the first decoded output of ``model.generate`` for ``input``.

    ``generate`` is called with only the input ids, i.e. with whatever
    generation settings the model uses by default.
    """
    input_ids = tokenizer([input], return_tensors="pt").input_ids
    out_tokens = model.generate(input_ids)
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    return out_str[0]


def predict_tail(entity, relation):
    """Predict tails for ``(entity, relation)`` using the module-level model.

    Args:
        entity: Entity name typed by the user.
        relation: Relation name typed by the user.

    Returns:
        A dict mapping each decoded prediction to ``exp(score)``. Entries are
        inserted in descending score order; a repeated prediction keeps its
        first position and the value of its last occurrence.
    """
    input = entity + "| " + relation
    ranked = topkSample(input, model, tokenizer, num_samples=25)
    return {text: np.exp(score).item() for text, score in ranked}


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

ent_input = gr.inputs.Textbox(lines=1, default="Apoorv Umang Saxena")
rel_input = gr.inputs.Textbox(lines=1, default="country")
output = gr.outputs.Label()

examples = [
    ['Adrian Kochsiek', 'gender'],
    ['Apoorv Umang Saxena', 'family name'],
    ['World War II', 'followed by'],
    ['Apoorv Umang Saxena', 'country'],
]

title = "Interactive demo: KGT5"

description = (
    "Demo for Sequence-to-Sequence Knowledge Graph Completion and Question "
    "Answering (KGT5). This particular model is a T5-base model trained on "
    "the task of tail prediction on WikiKG90Mv2 dataset and obtains 0.239 "
    "validation MRR on this task (leaderboard, see paper for details). To "
    "use it, simply give an entity name and relation and click 'submit'. "
    "Upto 25 model predictions will show up in a few seconds. The model "
    "works best when the exact entity/relation names that it has been "
    "trained on are used. It is sometimes able to generalize to unseen "
    "entities as well (see examples). "
)

# Disabled alternative for the article text:
# article = """
# Sequence-to-Sequence Knowledge Graph Completion and Question Answering | Github Repo
# """

article = (
    " Under the hood, this demo concatenates the entity and relation, feeds "
    "it to the model and then samples 25 sequences, which are then ranked "
    "according to their sequence probabilities. For more details see the "
    "Github repo or the hf model page. "
)

iface = gr.Interface(
    fn=predict_tail,
    inputs=[ent_input, rel_input],
    outputs=output,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch()