# NusaBERT / utils.py
import gradio as gr
from functools import partial
from transformers import pipeline, pipelines
from sentence_transformers import SentenceTransformer, util
import json
######################
##### INFERENCE ######
######################
class SentenceSimilarity:
    def __init__(self, model: str, corpus_path: str):
        # The corpus JSON is column-oriented:
        # {"id": [...], "url": [...], "title": [...], "text": [...]}
        with open(corpus_path) as f:
            data = json.load(f)
        self.id, self.url, self.title, self.text = (
            data["id"],
            data["url"],
            data["title"],
            data["text"],
        )
        self.model = SentenceTransformer(model)
        # Embed the corpus once up front so each query only embeds the query.
        self.corpus_embeddings = self.model.encode(self.text)
    def __call__(self, query: str, top_k: int = 5):
        # Returns a list of {"corpus_id": int, "score": float} dicts for the
        # query, sorted by descending cosine similarity. (The embeddings are
        # precomputed in __init__, so no corpus argument is needed here.)
        query_embedding = self.model.encode(query)
        output = util.semantic_search(
            query_embedding, self.corpus_embeddings, top_k=top_k
        )
        return output[0]
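
# A minimal usage sketch for SentenceSimilarity (the model name, corpus path,
# and query below are illustrative assumptions, not this Space's actual
# configuration):
#
#     searcher = SentenceSimilarity(
#         model="sentence-transformers/distiluse-base-multilingual-cased-v2",
#         corpus_path="corpus.json",
#     )
#     hits = searcher("gunung berapi di Jawa", top_k=3)  # "volcanoes in Java"
#     # -> [{"corpus_id": 17, "score": 0.72}, ...]
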
# Sentence Similarity
def sentence_similarity(
query: str,
texts: list[str],
titles: list[str],
urls: list[str],
pipe: SentenceSimilarity,
top_k: int,
) -> list[str]:
    answer = pipe(query=query, top_k=top_k)
    output = [
        f"""
        Cosine Similarity Score: {round(ans['score'], 3)}
        ## [{titles[ans['corpus_id']]} 🔗]({urls[ans['corpus_id']]})
        {texts[ans['corpus_id']]}
        """
        for ans in answer
    ]
return output
# Text Classification
def cls_inference(text: str, pipe: pipeline) -> dict:
    # top_k=None returns a score for every label; map them into the
    # {label: score} dict that gr.Label expects.
    results = pipe(text, top_k=None)
    return {x["label"]: x["score"] for x in results}
# POSP
def tagging(text: str, pipe: pipeline):
    # Wrap token-classification output in the dict format gr.HighlightedText expects.
    output = pipe(text)
    return {"text": text, "entities": output}
# Text Analysis
def text_analysis(text: str, pipes: list[pipeline]):
outputs = []
for pipe in pipes:
if isinstance(pipe, pipelines.token_classification.TokenClassificationPipeline):
outputs.append(tagging(text, pipe))
else:
outputs.append(cls_inference(text, pipe))
return outputs
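
# Sketch of text_analysis with a mixed pipeline list (the model names are
# placeholders): token-classification pipelines yield HighlightedText-style
# dicts, everything else yields {label: score} mappings.
#
#     pos = pipeline("token-classification", model="<pos-model>")
#     emotion = pipeline("text-classification", model="<emotion-model>")
#     results = text_analysis("Saya senang hari ini.", pipes=[pos, emotion])
#     # -> [{"text": ..., "entities": [...]}, {"happy": 0.93, ...}]
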
######################
##### INTERFACE ######
######################
def text_interface(
pipe: pipeline, examples: list[str], output_label: str, title: str, desc: str
):
return gr.Interface(
fn=partial(cls_inference, pipe=pipe),
inputs=[
gr.Textbox(lines=5, label="Input Text"),
],
title=title,
description=desc,
outputs=[gr.Label(label=output_label)],
examples=examples,
allow_flagging="never",
)
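
# Sketch of building a classification tab with text_interface (the model name
# and example sentence are placeholders):
#
#     sentiment_tab = text_interface(
#         pipe=pipeline("text-classification", model="<sentiment-model>"),
#         examples=["Filmnya seru sekali!"],
#         output_label="Sentiment",
#         title="Sentiment Analysis",
#         desc="Classify the sentiment of an Indonesian sentence.",
#     )
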
def search_interface(
pipe: SentenceSimilarity,
examples: list[str],
    output_label: str,  # note: unused in this builder
title: str,
desc: str,
top_k: int,
):
with gr.Blocks() as sentence_similarity_interface:
gr.Markdown(title)
gr.Markdown(desc)
with gr.Row():
# input on the left
with gr.Column():
input_text = gr.Textbox(lines=5, label="Query")
# display documents
df = gr.DataFrame(
[
[id, f"<a href='{url}' target='_blank'>{title} ๐Ÿ”—</a>"]
for id, title, url in zip(pipe.id, pipe.title, pipe.url)
],
headers=["ID", "Title"],
wrap=True,
datatype=["markdown", "html"],
interactive=False,
height=300,
)
button = gr.Button("Search...")
with gr.Column():
# outputs top_k results in accordion format
outputs = []
for i in range(top_k):
# open the first accordion
                    with gr.Accordion(label=f"Document {i + 1}", open=(i == 0)):
output = gr.Markdown()
outputs.append(output)
gr.Examples(examples, inputs=[input_text], outputs=outputs)
button.click(
fn=partial(
sentence_similarity,
pipe=pipe,
texts=pipe.text,
titles=pipe.title,
urls=pipe.url,
top_k=top_k,
),
inputs=[input_text],
outputs=outputs,
)
return sentence_similarity_interface
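
# Sketch of wiring search_interface to a corpus (model name, path, and example
# query are placeholders; output_label is accepted but unused, as noted above):
#
#     search_tab = search_interface(
#         pipe=SentenceSimilarity(model="<embedding-model>", corpus_path="corpus.json"),
#         examples=["gunung berapi di Jawa"],
#         output_label="Results",
#         title="# Document Search",
#         desc="Retrieve the most similar documents for a query.",
#         top_k=5,
#     )
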
def token_classification_interface(
pipe: pipeline, examples: list[str], output_label: str, title: str, desc: str
):
return gr.Interface(
fn=partial(tagging, pipe=pipe),
inputs=[
            gr.Textbox(placeholder="Enter a sentence here...", label="Input Text"),
],
outputs=[gr.HighlightedText(label=output_label)],
title=title,
examples=examples,
description=desc,
allow_flagging="never",
)
def text_analysis_interface(
    pipe: list, examples: list[str], output_label: list[str], title: str, desc: str
):
with gr.Blocks() as text_analysis_interface:
gr.Markdown(title)
gr.Markdown(desc)
input_text = gr.Textbox(lines=5, label="Input Text")
with gr.Row():
outputs = [
(
gr.HighlightedText(label=label)
if isinstance(
p, pipelines.token_classification.TokenClassificationPipeline
)
else gr.Label(label=label)
)
for label, p in zip(output_label, pipe)
]
btn = gr.Button("Analyze")
btn.click(
fn=partial(text_analysis, pipes=pipe),
inputs=[input_text],
outputs=outputs,
)
gr.Examples(
examples=examples,
inputs=input_text,
outputs=outputs,
)
return text_analysis_interface
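
# These builders are typically assembled into tabs in the Space's app.py,
# e.g. (a sketch; the tab variables come from the usage examples above):
#
#     demo = gr.TabbedInterface(
#         [sentiment_tab, search_tab],
#         tab_names=["Sentiment", "Document Search"],
#     )
#     demo.launch()
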
# Summary
# summary_interface = gr.Interface.from_pipeline(
# pipes["summarization"],
# title="Summarization",
# examples=details["summarization"]["examples"],
# description=details["summarization"]["description"],
# allow_flagging="never",
# )