|
import pandas as pd |
|
import numpy as np |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.vectorstores import FAISS |
|
from langchain.schema import Document |
|
from rank_bm25 import BM25Okapi |
|
from kiwipiepy import Kiwi |
|
from typing import List |
|
import gradio as gr |
|
|
|
class ProductSearchSystem:
    """Hybrid (keyword + semantic) search over company/product records.

    Every record is indexed twice: in a BM25 keyword index built from Kiwi
    morphological tokens, and in a FAISS vector store built from HuggingFace
    embeddings. ``search`` blends the two normalized scores using
    ``bm25_weight`` and ``vector_weight``.
    """

    # Columns combined into each document and kept as its metadata; the CSV
    # loaded by load_sample_data() must provide all of them.
    _INDEXED_FIELDS = ('company_name', 'category', 'company_info',
                       'product_name', 'description')

    # Kiwi POS tags kept for keyword matching (Kiwi tag set): common and
    # proper nouns (NNG, NNP), verbs (VV), adjectives (VA) and Latin-alphabet
    # tokens (SL).
    _KEYWORD_POS_TAGS = frozenset({'NNG', 'NNP', 'VV', 'VA', 'SL'})

    def __init__(self,
                 model_name: str = "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
                 bm25_weight: float = 0.3,
                 vector_weight: float = 0.7):
        """Initialize the search system (indexes are built later).

        Args:
            model_name: HuggingFace embedding model used for the vector
                index; it is run on CPU with normalized embeddings.
            bm25_weight: Weight of the normalized BM25 keyword score.
            vector_weight: Weight of the normalized vector-similarity score.
        """
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
        self.bm25_weight = bm25_weight
        self.vector_weight = vector_weight
        # Populated by _create_search_index().
        self.vector_store = None
        self.bm25 = None
        self.documents = []
        # Populated by load_sample_data().
        self.df = None
        # Korean morphological analyzer used for BM25 tokenization.
        self.kiwi = Kiwi()

    def _tokenize_text(self, text: str) -> List[str]:
        """Tokenize ``text`` with Kiwi and keep only content-word forms.

        Returns the surface forms of tokens whose POS tag is in
        ``_KEYWORD_POS_TAGS``; every other token (e.g. particles, endings)
        is dropped.
        """
        # Compare only the base tag before any '-' suffix.
        # NOTE(review): newer kiwipiepy releases may annotate verb/adjective
        # tags with a regularity suffix (e.g. 'VV-I'); confirm against the
        # installed version. Plain tags are unaffected by this split.
        return [
            token.form
            for token in self.kiwi.tokenize(text)
            if token.tag.split('-', 1)[0] in self._KEYWORD_POS_TAGS
        ]

    def load_sample_data(self, csv_path: str = "sample_data.csv"):
        """Load product records from a CSV file and build both indexes.

        Args:
            csv_path: CSV to read; defaults to "sample_data.csv" in the
                current working directory.

        Returns:
            True once loading, preprocessing and indexing have finished.
        """
        self.df = pd.read_csv(csv_path)
        self._preprocess_data()
        self._create_search_index()
        return True

    def _preprocess_data(self):
        """Clean ``self.df`` in place before indexing.

        - missing ``category`` values become '미분류' ("uncategorized");
        - missing ``company_info`` / ``description`` values become '' and the
          literal token '_x000D_' in them becomes a newline;
        - every object-dtype column is whitespace-stripped.
        """
        self.df['category'] = self.df['category'].fillna('미분류')
        # '_x000D_' presumably is an escaped carriage return left by an
        # Excel/OOXML export. Replace it literally (no regex interpretation).
        for col in ('company_info', 'description'):
            self.df[col] = (
                self.df[col].fillna('').str.replace('_x000D_', '\n', regex=False)
            )

        for col in self.df.columns:
            if self.df[col].dtype == 'object':
                self.df[col] = self.df[col].str.strip()

    def _create_search_index(self):
        """Build the BM25 keyword index and FAISS vector store from ``self.df``.

        Each row becomes one Document. Its page_content is the indexed fields
        joined by spaces (the text both indexes see). Its metadata keeps those
        fields separately for result display.

        Raises:
            ValueError: If there are no rows to index (neither index can be
                built from an empty corpus).
        """
        if self.df is None or self.df.empty:
            raise ValueError(
                "No product data to index; load a non-empty dataset first."
            )

        self.documents = []
        tokenized_documents = []
        for _, row in self.df.iterrows():
            content = " ".join(f"{row[field]}" for field in self._INDEXED_FIELDS)
            tokenized_documents.append(self._tokenize_text(content))
            self.documents.append(
                Document(
                    page_content=content,
                    metadata={field: row[field] for field in self._INDEXED_FIELDS},
                )
            )

        self.bm25 = BM25Okapi(tokenized_documents)
        self.vector_store = FAISS.from_documents(self.documents, self.embeddings)

    def search(self, query: str, top_k: int = 3) -> List[dict]:
        """Return the ``top_k`` best hybrid matches for ``query``.

        Keyword score: the document's BM25 score divided by the best BM25
        score, floored at 0. Vector score: ``1 - distance / max_distance``
        over the FAISS results. The final score is
        ``bm25_weight * keyword + vector_weight * vector``. Results are sorted
        by final score and de-duplicated on (company_name, product_name),
        keeping the best-scoring row of each pair.

        Args:
            query: Free-text query; blank queries return an empty list.
            top_k: Maximum number of results (applied as a slice).

        Returns:
            Dicts with the indexed metadata fields plus ``bm25_score``,
            ``vector_score`` and ``final_score`` rounded to 3 decimals.

        Raises:
            RuntimeError: If the search index has not been built yet.
        """
        if not query.strip():
            return []
        if self.bm25 is None or self.vector_store is None:
            raise RuntimeError(
                "Search index has not been built; call load_sample_data() first."
            )

        # Keyword side: scale so the best BM25 match is 1.0. BM25Okapi scores
        # are not guaranteed to be non-negative, so floor at 0.
        raw_bm25 = np.asarray(
            self.bm25.get_scores(self._tokenize_text(query)), dtype=float
        )
        best_bm25 = raw_bm25.max() if raw_bm25.size else 0.0
        if best_bm25 > 0:
            keyword_scores = np.maximum(raw_bm25 / best_bm25, 0.0)
        else:
            keyword_scores = np.zeros_like(raw_bm25)

        # Vector side: score every document. The returned score is treated as
        # a distance (smaller = closer), as with LangChain's default L2 index.
        hits = self.vector_store.similarity_search_with_score(
            query, k=len(self.documents)
        )
        max_distance = max((dist for _, dist in hits), default=1)
        # Index hits by content (first hit wins) for O(1) per-document lookup.
        distance_by_content = {}
        for hit_doc, dist in hits:
            distance_by_content.setdefault(hit_doc.page_content, dist)

        scored = []
        for doc, keyword_score in zip(self.documents, keyword_scores):
            distance = distance_by_content.get(doc.page_content)
            if distance is None:
                continue  # document not returned by the vector store
            vector_score = 1 - distance / max_distance if max_distance > 0 else 0
            final_score = (self.bm25_weight * keyword_score
                           + self.vector_weight * vector_score)
            scored.append((final_score, keyword_score, vector_score, doc))

        # Rank first, then de-duplicate, so each (company, product) pair keeps
        # its best-scoring row instead of whichever row came first in the data.
        scored.sort(key=lambda item: item[0], reverse=True)

        results = []
        seen_products = set()
        for final_score, keyword_score, vector_score, doc in scored:
            meta = doc.metadata
            # Tuple key: avoids collisions like ("A-B", "C") vs ("A", "B-C").
            product_key = (f"{meta['company_name']}", f"{meta['product_name']}")
            if product_key in seen_products:
                continue
            seen_products.add(product_key)

            result = {field: meta[field] for field in self._INDEXED_FIELDS}
            result['bm25_score'] = round(float(keyword_score), 3)
            result['vector_score'] = round(float(vector_score), 3)
            result['final_score'] = round(float(final_score), 3)
            results.append(result)

        return results[:top_k]
|
|
|
def create_gradio_interface():
    """Build the Gradio Blocks UI for the hybrid product search.

    Creates one ProductSearchSystem and loads/indexes the sample CSV up front.
    It then wires a query box, a result-count slider and a keyword-weight
    slider to a callback that renders results both as an HTML table and as
    plain text.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    # NOTE(review): the Korean UI text below was restored from a mis-decoded
    # (mojibake) copy of this file. The header emoji could not be recovered
    # and is a best guess.
    search_system = ProductSearchSystem()
    search_system.load_sample_data()

    # Result-dict key -> table column label, in display order.
    column_labels = {
        'company_name': '회사명',
        'category': '카테고리',
        'company_info': '회사 설명',
        'product_name': '제품명',
        'bm25_score': '키워드 점수',
        'vector_score': '벡터 점수',
        'final_score': '최종 점수',
        'description': '설명',
    }
    no_results_text = "검색 결과가 없습니다."

    def search_products(query: str,
                        top_k: int,
                        bm25_weight: float) -> tuple:
        """Search with the current UI settings and format the results.

        Args:
            query: Text from the query box (None is treated as empty).
            top_k: Value of the result-count slider.
            bm25_weight: Keyword weight; the vector weight is set to
                ``1 - bm25_weight``.

        Returns:
            ``(html_table, detailed_text)`` for the HTML and Textbox outputs.
        """
        # NOTE(review): this mutates the shared search_system, so requests
        # served concurrently with different weights could interfere.
        search_system.bm25_weight = bm25_weight
        search_system.vector_weight = 1 - bm25_weight

        # int(): slider values may arrive as floats, and search() slices by top_k.
        results = search_system.search(query or "", top_k=int(top_k))
        if not results:
            return f"<p>{no_results_text}</p>", no_results_text

        df_results = (
            pd.DataFrame(results)[list(column_labels)]
            .rename(columns=column_labels)
        )
        # NOTE(review): escape=False renders CSV text as raw HTML; prefer
        # escape=True unless the data intentionally contains markup.
        html_table = df_results.to_html(
            classes=['table', 'table-striped'],
            escape=False,
            index=False,
            float_format=lambda x: '{:.3f}'.format(x),
        )

        # Built line by line so no source indentation leaks into the output.
        detailed_text = "\n\n".join(
            "\n".join([
                f"=== 검색결과 #{rank} ===",
                f"회사명: {result['company_name']}",
                f"카테고리: {result['category']}",
                f"회사 설명: {result['company_info']}",
                f"제품명: {result['product_name']}",
                f"키워드 점수: {result['bm25_score']:.3f}",
                f"벡터 점수: {result['vector_score']:.3f}",
                f"최종 점수: {result['final_score']:.3f}",
                f"설명: {result['description']}",
            ])
            for rank, result in enumerate(results, 1)
        )
        return html_table, detailed_text

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("\n".join([
            "# 🔍 코엑스 부스 추천 시스템",
            "",
            "하이브리드 방식을 이용한 기업 및 제품 검색/추천 시스템입니다.",
        ]))

        with gr.Row():
            with gr.Column(scale=4):
                query_input = gr.Textbox(
                    label="검색어를 입력하세요",
                    placeholder="예: AI 기술 회사, 센서, 자동화 등",
                )
            with gr.Column(scale=1):
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="검색 결과 수",
                )

        with gr.Row():
            bm25_weight_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.3,
                step=0.1,
                label="키워드 검색 가중치",
            )

        with gr.Row():
            search_button = gr.Button("검색", variant="primary")

        with gr.Row():
            with gr.Column():
                results_table = gr.HTML(label="검색 결과 테이블")
            with gr.Column():
                results_text = gr.Textbox(
                    label="상세 결과",
                    show_label=True,
                    interactive=False,
                    lines=10,
                )

        search_button.click(
            fn=search_products,
            inputs=[query_input, top_k_slider, bm25_weight_slider],
            outputs=[results_table, results_text],
        )

        gr.Markdown("\n".join([
            "### 사용 방법",
            "",
            "1. 검색어 입력: 찾고자 하는 기업, 제품, 기술 등의 키워드를 입력하세요",
            "2. 검색 결과 수 조절: 원하는 결과 수를 선택하세요",
            "3. 가중치 조정: 키워드 매칭과 의미적 유사도 간의 가중치를 조정하세요",
            "",
            "### 점수 설명",
            "",
            "- 키워드 점수: Kiwi 토크나이저를 사용한 키워드 기반 매칭 점수 (0~1)",
            "- 벡터 점수: 의미적 유사도 점수 (0~1)",
            "- 최종 점수: 키워드 점수와 벡터 점수의 가중 평균",
        ]))

    return demo
|
|
|
def main():
    """Script entry point: build the search UI and launch it with share=True.

    Per the Gradio documentation, share=True also requests a publicly
    shareable link, so the app may be reachable beyond the local machine.
    """
    create_gradio_interface().launch(share=True)


if __name__ == "__main__":  # launch only when run as a script, not on import
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|