NusaBERT / utils.py
StevenLimcorn's picture
Added NERP tagging and document retrieval wikipedia
2ef1af8
raw
history blame
5.41 kB
import gradio as gr
from functools import partial
from transformers import pipeline, pipelines
from sentence_transformers import SentenceTransformer, util
from scipy.special import softmax
import os
import json
######################
##### INFERENCE ######
######################
class SentenceSimilarity:
    """Semantic search over a list of texts using a SentenceTransformer model."""

    def __init__(self, model: str):
        """Load the embedding model.

        Args:
            model: Model identifier handed straight to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model)

    def __call__(self, query: str, corpus: list[str], top_k: int = 5):
        """Rank ``corpus`` entries against ``query``.

        Both the query and the whole corpus are encoded on every call.

        Args:
            query: Text to search for.
            corpus: Candidate texts to rank.
            top_k: Maximum number of hits requested from
                ``util.semantic_search``. Defaults to 5, the value that used
                to be hard-coded here.

        Returns:
            The hits for the single query, i.e. the first element of the
            ``util.semantic_search`` result.
        """
        query_embedding = self.model.encode(query)
        corpus_embeddings = self.model.encode(corpus)
        hits_per_query = util.semantic_search(
            query_embedding, corpus_embeddings, top_k=top_k
        )
        # Only one query is submitted, so only the first result list matters.
        return hits_per_query[0]
# Sentence Similarity
def sentence_similarity(
    query: str,
    texts: list[str],
    titles: list[str],
    urls: list[str],
    pipe: SentenceSimilarity,
):
    """Search ``texts`` for ``query`` and return the hits as HTML link rows.

    Args:
        query: Search text forwarded to ``pipe``.
        texts: Corpus passed to ``pipe`` as ``corpus``.
        titles: Link captions, indexed by each hit's ``corpus_id``.
        urls: Link targets, indexed by each hit's ``corpus_id``.
        pipe: Callable invoked as ``pipe(query=..., corpus=...)``; it must
            return an iterable of mappings that each carry a ``corpus_id``.

    Returns:
        A list of single-cell rows, one per hit and in the order ``pipe``
        returned them, each holding an ``<a>`` element that opens in a new tab.
    """
    hits = pipe(query=query, corpus=texts)
    rows = []
    for hit in hits:
        idx = hit["corpus_id"]
        # The href value previously lacked its closing quote, which made the
        # following target attribute part of the URL.
        rows.append([f"<a href='{urls[idx]}' target='_blank'>{titles[idx]}</a>"])
    return rows
# Text Classification
def cls_inference(input: list[str], pipe: pipeline) -> dict:
    """Run ``pipe`` on ``input`` and map every returned label to its score.

    ``pipe`` is called with ``top_k=None``; each item it yields must expose
    ``"label"`` and ``"score"`` keys.
    """
    scores = {}
    for prediction in pipe(input, top_k=None):
        scores[prediction["label"]] = prediction["score"]
    return scores
# Token Classification (e.g. POSP, NERP)
def tagging(text: str, pipe: pipeline):
    """Tag ``text`` with ``pipe`` and bundle the text with the pipeline output.

    Returns:
        A dict with the original ``text`` under ``"text"`` and whatever
        ``pipe(text)`` returned under ``"entities"``.
    """
    return dict(text=text, entities=pipe(text))
# Text Analysis
def text_analysis(text, pipes: list[pipeline]):
    """Run ``text`` through every pipeline and collect one result per pipeline.

    Token-classification pipelines are routed to ``tagging``; every other
    pipeline goes through ``cls_inference``. Results keep the order of
    ``pipes``.
    """
    return [
        (
            tagging(text, current)
            if isinstance(
                current, pipelines.token_classification.TokenClassificationPipeline
            )
            else cls_inference(text, current)
        )
        for current in pipes
    ]
######################
##### INTERFACE ######
######################
def text_interface(
    pipe: pipeline, examples: list[str], output_label: str, title: str, desc: str
):
    """Build a ``gr.Interface`` that feeds a textbox into ``cls_inference``.

    Args:
        pipe: Pipeline bound to ``cls_inference`` as its ``pipe`` argument.
        examples: Passed through as the interface's ``examples``.
        output_label: Label of the ``gr.Label`` output component.
        title: Interface title.
        desc: Interface description.
    """
    classify = partial(cls_inference, pipe=pipe)
    text_input = gr.Textbox(lines=5, label="Input Text")
    label_output = gr.Label(label=output_label)
    return gr.Interface(
        fn=classify,
        inputs=[text_input],
        outputs=[label_output],
        title=title,
        description=desc,
        examples=examples,
        allow_flagging="never",
    )
def search_interface(
    pipe: SentenceSimilarity,
    examples: list[str],
    output_label: str,
    title: str,
    desc: str,
    sample: str,
):
    """Build the document-retrieval Blocks UI backed by ``sentence_similarity``.

    Args:
        pipe: Search backend bound to ``sentence_similarity``.
        examples: Accepted for signature parity with the other builders; not
            used here.
        output_label: Accepted for signature parity; not used here.
        title: Markdown shown at the top of the page.
        desc: Markdown shown below the title.
        sample: Path to a JSON file providing ``"id"``, ``"title"``, ``"url"``
            and ``"text"`` entries.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    # A context manager guarantees the handle is closed (the previous bare
    # open() leaked it), and an explicit encoding avoids depending on the
    # platform's default locale when decoding the JSON text.
    with open(sample, encoding="utf-8") as sample_file:
        data = json.load(sample_file)

    # Browsable table of the searchable documents: ID plus a linked title.
    corpus_rows = [
        [doc_id, f"<a href='{doc_url}' target='_blank'>{doc_title}</a>"]
        for doc_id, doc_title, doc_url in zip(
            data["id"], data["title"], data["url"]
        )
    ]

    search = partial(
        sentence_similarity,
        pipe=pipe,
        texts=data["text"],
        titles=data["title"],
        urls=data["url"],
    )

    # NOTE(review): the nesting below was reconstructed from a source whose
    # indentation had been lost; confirm the intended Row/Column layout.
    with gr.Blocks() as sentence_similarity_interface:
        gr.Markdown(title)
        gr.Markdown(desc)
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(lines=5, label="Query")
                df = gr.DataFrame(
                    corpus_rows,
                    headers=["ID", "Title"],
                    wrap=True,
                    datatype=["markdown", "html"],
                    interactive=False,
                    height=300,
                )
                button = gr.Button("Search...")
            output = gr.DataFrame(
                headers=["Title"],
                wrap=True,
                datatype=["html"],
                interactive=False,
            )
        button.click(
            fn=search,
            inputs=[input_text],
            outputs=[output],
        )
    return sentence_similarity_interface
def token_classification_interface(
    pipe: pipeline, examples: list[str], output_label: str, title: str, desc: str
):
    """Build a ``gr.Interface`` that highlights the output of ``tagging``.

    Args:
        pipe: Pipeline bound to ``tagging`` as its ``pipe`` argument.
        examples: Passed through as the interface's ``examples``.
        output_label: Label of the ``gr.HighlightedText`` output component.
        title: Interface title.
        desc: Interface description.
    """
    tag_text = partial(tagging, pipe=pipe)
    sentence_box = gr.Textbox(
        placeholder="Masukan kalimat di sini...", label="Input Text"
    )
    highlighted = gr.HighlightedText(label=output_label)
    return gr.Interface(
        fn=tag_text,
        inputs=[sentence_box],
        outputs=[highlighted],
        title=title,
        examples=examples,
        description=desc,
        allow_flagging="never",
    )
def text_analysis_interface(
    pipe: list, examples: list[str], output_label: str, title: str, desc: str
):
    """Build a Blocks UI that runs one input text through several pipelines.

    Args:
        pipe: Pipelines handed to ``text_analysis`` as ``pipes``.
        examples: Example inputs registered through ``gr.Examples``.
        output_label: Iterated in lockstep with ``pipe`` to label each output
            component (one label per pipeline), despite the ``str``
            annotation.
        title: Markdown shown at the top of the page.
        desc: Markdown shown below the title.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    token_pipeline_cls = pipelines.token_classification.TokenClassificationPipeline
    analyze = partial(text_analysis, pipes=pipe)

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(desc)
        input_text = gr.Textbox(lines=5, label="Input Text")
        with gr.Row():
            outputs = []
            for label, current in zip(output_label, pipe):
                # Token taggers get highlighted text; everything else a label.
                if isinstance(current, token_pipeline_cls):
                    component = gr.HighlightedText(label=label)
                else:
                    component = gr.Label(label=label)
                outputs.append(component)
        btn = gr.Button("Analyze")
        btn.click(
            fn=analyze,
            inputs=[input_text],
            outputs=outputs,
        )
        gr.Examples(
            examples=examples,
            inputs=input_text,
            outputs=outputs,
        )
    return demo
# Summary
# summary_interface = gr.Interface.from_pipeline(
# pipes["summarization"],
# title="Summarization",
# examples=details["summarization"]["examples"],
# description=details["summarization"]["description"],
# allow_flagging="never",
# )