jannisborn's picture
update
69c3e34 unverified
raw
history blame
3.48 kB
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.pgt import (
PGT,
PGTCoherenceChecker,
PGTEditor,
PGTGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
# Module-level logger; the NullHandler means nothing is emitted unless the
# application embedding this module configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Lookup from the user-facing model-type name to its PGT configuration class.
# NOTE(review): not referenced by run_inference in this file (it dispatches
# with an explicit if/elif); kept for any external users of this mapping.
MODEL_FN = dict(
    PGTGenerator=PGTGenerator,
    PGTEditor=PGTEditor,
    PGTCoherenceChecker=PGTCoherenceChecker,
)
def run_inference(
    model_type: str,
    generator_task: str,
    editor_task: str,
    checker_task: str,
    prompt: str,
    second_prompt: str,
    length: int,
    k: int,
    p: float,
):
    """Run a single PGT inference and return the produced text.

    Args:
        model_type: PGT application to use; one of ``"PGTGenerator"``,
            ``"PGTEditor"`` or ``"PGTCoherenceChecker"``.
        generator_task: ``task`` passed to ``PGTGenerator``
            (e.g. ``"title-to-abstract"``).
        editor_task: ``input_type`` passed to ``PGTEditor`` (e.g. ``"abstract"``).
        checker_task: ``coherence_type`` passed to ``PGTCoherenceChecker``
            (e.g. ``"title-abstract"``).
        prompt: Primary input text (``input_text``, or ``input_a`` for the
            coherence checker).
        second_prompt: Second input text (``input_b``); only used by the
            coherence checker.
        length: Maximal generation length, forwarded as ``max_length``.
        k: Top-k sampling parameter, forwarded as ``top_k``.
        p: Top-p sampling parameter, forwarded as ``top_p``.

    Returns:
        The first item yielded by ``PGT(config).sample(1)``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # UI sliders may deliver integral settings as floats (e.g. 512.0); cast so
    # max_length/top_k are always passed on as ints.
    kwargs = {"max_length": int(length), "top_k": int(k), "top_p": float(p)}

    if model_type == "PGTGenerator":
        config = PGTGenerator(task=generator_task, input_text=prompt, **kwargs)
    elif model_type == "PGTEditor":
        config = PGTEditor(input_type=editor_task, input_text=prompt, **kwargs)
    elif model_type == "PGTCoherenceChecker":
        config = PGTCoherenceChecker(
            coherence_type=checker_task, input_a=prompt, input_b=second_prompt, **kwargs
        )
    else:
        # Fail loudly here instead of with an UnboundLocalError on `config` below.
        raise ValueError(
            f"Unsupported model type {model_type!r}; expected one of "
            "'PGTGenerator', 'PGTEditor', 'PGTCoherenceChecker'."
        )

    model = PGT(config)
    # Only one sample is requested; return the first (and only) item.
    text = list(model.sample(1))[0]
    return text
if __name__ == "__main__":
    # Preparation: collect the application names of every registered algorithm
    # whose algorithm name mentions "PGT"; they populate the "Model type" dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_application"]
        for algo in all_algos
        if "PGT" in algo["algorithm_name"]
    ]

    # Load metadata (examples, article, description) from ./model_cards.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep="|", header=None
    ).fillna("")
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)
    # Explicit UTF-8 so the markdown does not depend on the platform locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    gen_tasks = [
        "title-to-abstract",
        "abstract-to-title",
        "abstract-to-claim",
        "claim-to-abstract",
    ]

    # Input order must match run_inference's positional parameters.
    demo = gr.Interface(
        fn=run_inference,
        title="Patent Generative Transformer",
        inputs=[
            gr.Dropdown(algos, label="Model type", value="PGTGenerator"),
            gr.Dropdown(gen_tasks, label="Generator task", value="title-to-abstract"),
            gr.Dropdown(["abstract", "claim"], label="Editor task", value="abstract"),
            gr.Dropdown(
                ["title-abstract", "title-claim", "abstract-claim"],
                label="Checker task",
                value="title-abstract",
            ),
            gr.Textbox(
                label="Primary Text prompt",
                placeholder="Artificial intelligence and machine learning infrastructure",
                lines=5,
            ),
            gr.Textbox(
                label="Secondary text prompt (only coherence checker)",
                placeholder="",
                lines=1,
            ),
            gr.Slider(
                minimum=5, maximum=1024, value=512, label="Maximal length", step=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            # A step of 1 over [0.5, 1] left no usable intermediate values;
            # top-p needs fine-grained steps.
            gr.Slider(minimum=0.5, maximum=1, value=1.0, label="Top-p", step=0.01),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)