Spaces:
Sleeping
Sleeping
import pandas as pd | |
import streamlit as st | |
from infer import USPPPMModel, USPPPMDataset | |
import torch | |
def load_model(): | |
model = USPPPMModel('microsoft/deberta-v3-small') | |
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu'))) | |
model.eval() | |
ds = USPPPMDataset(model.tokenizer, 133) | |
return model, ds | |
def infer(anchor, target, title): | |
model, ds = load_model() | |
d = { | |
'anchor': anchor, | |
'target': target, | |
'title': title, | |
'label': 0 | |
} | |
x = ds[d][0] | |
with torch.no_grad(): | |
y = model(x) | |
return y.cpu().numpy()[0][0] | |
def get_context(): | |
df = pd.read_csv('./fold-0-train.csv') | |
l = list(set(list(df['title'].values))) | |
return l | |
st.set_page_config( | |
page_title="PatentMatch", | |
page_icon="π§", | |
layout="centered", | |
initial_sidebar_state="expanded", | |
) | |
# fix sidebar | |
st.markdown(""" | |
<style> | |
.css-vk3wp9 { | |
background-color: rgb(255 255 255); | |
} | |
.css-18l0hbk { | |
padding: 0.34rem 1.2rem !important; | |
margin: 0.125rem 2rem; | |
} | |
.css-nziaof { | |
padding: 0.34rem 1.2rem !important; | |
margin: 0.125rem 2rem; | |
background-color: rgb(181 197 227 / 18%) !important; | |
} | |
</style> | |
""", unsafe_allow_html=True | |
) | |
hide_st_style = """ | |
<style> | |
#MainMenu {visibility: hidden;} | |
footer {visibility: hidden;} | |
header {visibility: hidden;} | |
</style> | |
""" | |
st.markdown(hide_st_style, unsafe_allow_html=True) | |
def app(): | |
st.title("PatentMatch: Patent Semantic Similarity Matcher") | |
#st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai/<username>/<project_name>?workspace=user-<username>)") | |
st.markdown( | |
"""This project is focused on developing a Transformer based NLP model to match phrases | |
in U.S. patents based on their semantic similarity within a specific | |
technical domain context. The trained model achieved Pearson correlation coefficient score of 0.745. | |
[[Source Code]](https://github.com/dataraptor/PatentMatch) | |
""" | |
) | |
st.markdown('---') | |
# st.selectbox("Select from example", | |
# [ | |
# "Example 1", | |
# "Example 2", | |
# ]) | |
row1_col1, row1_col2, row1_col3 = st.columns( | |
[0.5, 0.4, 0.4] | |
) | |
# with row1_col1: | |
# frequency = st.selectbox("Section", | |
# [ | |
# "A: Human Necessities", | |
# "B: Operations and Transport", | |
# "C: Chemistry and Metallurgy", | |
# "D: Textiles", | |
# "E: Fixed Constructions", | |
# "F: Mechanical Engineering", | |
# "G: Physics", | |
# "H: Electricity", | |
# "Y: Emerging Cross-Sectional Technologies", | |
# ]) | |
# with row1_col2: | |
# class_box = st.selectbox("Class", | |
# [ | |
# "21", | |
# "14", | |
# "23", | |
# ]) | |
with row1_col1: | |
l = get_context() | |
context = st.selectbox("Context", l, l.index('basic electric elements')) | |
with row1_col2: | |
anchor = st.text_input("Anchor", "deflect light") | |
with row1_col3: | |
target = st.text_input("Target", "bending moment") | |
if st.button("Predict Scores", type="primary"): | |
with st.spinner("Predicting scores..."): | |
score = infer(anchor, target, context) | |
ss = st.success("Scores predicted successfully!") | |
score += 2.0 | |
fmt = "{:<.3f}".format(score) | |
st.subheader(f"Similarity Score: {fmt}") | |
app() | |
# Display a footer with links and credits | |
st.markdown("---") | |
st.markdown("Built by [Shamim Ahamed](https://www.shamimahamed.com/). Data provided by [Kaggle](https://www.kaggle.com/competitions/us-patent-phrase-to-phrase-matching)") | |
#st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)") | |