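# Gradio demo that matches a free-text query against budget-item names from
# 67_all_ministry.csv, using mean token embeddings taken from the shared
# embedding layer of napatswift/mt5-fixpdftext and ranking by L2 distance.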
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import pandas as pd
import numpy as np
import gradio as gr

model_repo = "napatswift/mt5-fixpdftext"

tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo)

# Keep only the shared input-embedding layer; the rest of the seq2seq model
# is not needed for similarity search, so free it.
embedding = model.get_input_embeddings()
del model

def get_embedding(text):
    """Return the mean of the token embeddings for `text` (one vector per string)."""
    with torch.no_grad():
        input_ids = tokenizer(text, return_tensors='pt').input_ids[0]
        return embedding(input_ids).mean(dim=0)

df = pd.read_csv('67_all_ministry.csv')

def get_name(row):
    """Return the first non-empty string among the `name_*` columns, or None."""
    for col, val in row.items():
        if col.startswith('name_') and isinstance(val, str) and val:
            return val
    return None

# Unique budget-item names; drop rows where no `name_*` column was filled in.
budget_items = df.apply(get_name, axis=1).dropna().unique().tolist()

# Precompute one embedding per budget item so each query only embeds the input text.
budget_item_embeddings = torch.stack([get_embedding(item) for item in budget_items])

def get_closest_budget_item(text, num_results=5):
    """Return the `num_results` budget items with the smallest L2 distance to `text`."""
    text_embedding = get_embedding(text)
    scores = torch.norm(budget_item_embeddings - text_embedding, dim=1)
    top_idx = scores.argsort()[:int(num_results)].tolist()
    return pd.DataFrame({
        'budget_item': np.array(budget_items)[top_idx],
        'score': scores[top_idx].tolist(),
    })
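
# Quick sanity check outside Gradio (hypothetical query string; lower score = closer):
# print(get_closest_budget_item("งบดำเนินงาน", num_results=3))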

demo = gr.Interface(
    fn=get_closest_budget_item,
    inputs=[
        gr.Textbox(label="Budget item text"),
        # step=1 so the default value of 5 lies on the slider grid.
        gr.Slider(minimum=1, maximum=50, step=1, value=5, label="Number of results"),
    ],
    outputs='dataframe',
)

if __name__ == "__main__":
    demo.launch()