# coex-prj / run.py
# Uploaded by harheem using huggingface_hub (commit d2941e6, verified)
import pandas as pd
import numpy as np
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from rank_bm25 import BM25Okapi
from kiwipiepy import Kiwi
from typing import List
import gradio as gr
class ProductSearchSystem:
    """Hybrid product search combining BM25 keyword scores with vector similarity.

    Each CSV row (one company/product) becomes one document. ``search`` scores
    every document with BM25 over Kiwi morpheme tokens and with FAISS distance
    over sentence embeddings, normalizes both into [0, 1], and ranks by
    ``bm25_weight * bm25_score + vector_weight * vector_score``.

    Call ``load_sample_data()`` before ``search()``.
    """

    # Kiwi POS tags kept for BM25: NNG (common noun), NNP (proper noun),
    # VV (verb), VA (adjective) and SL (alphabetic/foreign-script tokens,
    # e.g. "AI").
    _KEEP_POS_TAGS = frozenset({'NNG', 'NNP', 'VV', 'VA', 'SL'})

    def __init__(self,
                 model_name: str = "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
                 bm25_weight: float = 0.3,
                 vector_weight: float = 0.7):
        """Create the embedding model and Kiwi tokenizer; indexes start empty.

        Args:
            model_name: Hugging Face model name passed to HuggingFaceEmbeddings
                (run on CPU, producing normalized embeddings).
            bm25_weight: Weight of the normalized BM25 score in ``final_score``.
            vector_weight: Weight of the normalized vector score in ``final_score``.
        """
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
        self.bm25_weight = bm25_weight
        self.vector_weight = vector_weight
        self.vector_store = None  # FAISS store, built by _create_search_index()
        self.bm25 = None          # BM25Okapi index, built by _create_search_index()
        self.documents = []       # one Document per DataFrame row, in row order
        self.df = None            # cleaned source DataFrame
        # Korean morphological analyzer used to tokenize text for BM25.
        self.kiwi = Kiwi()

    def _tokenize_text(self, text: str) -> List[str]:
        """Tokenize *text* with Kiwi, keeping surface forms of content morphemes.

        Kiwi may append a conjugation suffix to predicate tags (e.g. ``'VV-I'``
        for an irregular verb), so only the part before ``'-'`` is compared
        with ``_KEEP_POS_TAGS``; otherwise such verbs/adjectives would be lost.

        Returns:
            The ``form`` of every token whose base tag is in ``_KEEP_POS_TAGS``.
        """
        return [
            token.form
            for token in self.kiwi.tokenize(text)
            if token.tag.split('-', 1)[0] in self._KEEP_POS_TAGS
        ]

    def load_sample_data(self, path: str = "sample_data.csv") -> bool:
        """Load the product CSV, clean it, and build both search indexes.

        Args:
            path: CSV file to read (default ``"sample_data.csv"``, relative to
                the current working directory).

        Returns:
            ``True``; read/index errors propagate as exceptions.
        """
        self.df = pd.read_csv(path)
        self._preprocess_data()
        self._create_search_index()
        return True

    def _preprocess_data(self):
        """Clean ``self.df`` in place before indexing.

        Fills missing ``category`` values with the default label below, fills
        missing ``company_info``/``description`` with ``''`` and turns the
        literal ``'_x000D_'`` escape in them into ``'\\n'``, then strips
        surrounding whitespace from every object-dtype column.
        """
        self.df['category'] = self.df['category'].fillna('λ―ΈλΆ„λ₯˜')
        # '_x000D_' is the escape Excel/OOXML exports use for a carriage return
        # (presumably where this CSV came from); restore it as a line break.
        for col in ('company_info', 'description'):
            self.df[col] = self.df[col].fillna('').str.replace('_x000D_', '\n', regex=False)
        for col in self.df.columns:
            if self.df[col].dtype == 'object':
                self.df[col] = self.df[col].str.strip()

    def _create_search_index(self):
        """Build one Document per row, then the BM25 and FAISS indexes over them.

        A document's text is company name, category, company info, product
        name and description joined by single spaces; the same five fields are
        stored as metadata so results can be returned field by field.

        Raises:
            ValueError: if there are no rows to index.
        """
        if self.df is None or self.df.empty:
            raise ValueError("No data to index; load a non-empty DataFrame first.")
        self.documents = []
        tokenized_documents = []  # Kiwi tokens per document, for BM25
        for _, row in self.df.iterrows():
            content = (f"{row['company_name']} {row['category']} {row['company_info']} "
                       f"{row['product_name']} {row['description']}")
            tokenized_documents.append(self._tokenize_text(content))
            self.documents.append(
                Document(
                    page_content=content,
                    metadata={
                        'company_name': row['company_name'],
                        'category': row['category'],
                        'company_info': row['company_info'],
                        'product_name': row['product_name'],
                        'description': row['description'],
                    },
                )
            )
        self.bm25 = BM25Okapi(tokenized_documents)
        self.vector_store = FAISS.from_documents(self.documents, self.embeddings)

    def search(self, query: str, top_k: int = 3) -> List[dict]:
        """Run a hybrid search and return the best-scoring products.

        Args:
            query: Free-text search query.
            top_k: Maximum number of results (cast to int; values < 0 act as 0).

        Returns:
            Up to ``top_k`` dicts, sorted by ``final_score`` descending. Each dict
            holds the row's five metadata fields plus ``bm25_score``,
            ``vector_score`` and ``final_score`` rounded to 3 decimals. Each
            company/product pair appears once, using its best-scoring row. An
            empty, whitespace-only or ``None`` query returns ``[]``.

        Raises:
            RuntimeError: if the search index has not been built yet.
        """
        if not query or not query.strip():
            return []
        if self.bm25 is None or self.vector_store is None:
            raise RuntimeError(
                "Search index has not been built; call load_sample_data() first."
            )

        # Keyword side: one BM25 score per document, in self.documents order,
        # normalized by the best score of this query.
        bm25_scores = self.bm25.get_scores(self._tokenize_text(query))
        max_bm25 = float(max(bm25_scores, default=0.0))

        # Vector side: score the whole corpus. similarity_search_with_score
        # embeds the query itself. LangChain's FAISS store returns a distance
        # here (lower = more similar).
        hits = self.vector_store.similarity_search_with_score(query, k=len(self.documents))
        max_distance = max((float(dist) for _, dist in hits), default=0.0)
        # Normalize relative to the farthest hit (farthest -> 0, distance 0 -> 1)
        # and index by page_content for O(1) lookup; the first hit wins on ties.
        vector_by_content = {}
        for hit_doc, dist in hits:
            vector_by_content.setdefault(
                hit_doc.page_content,
                1 - float(dist) / max_distance if max_distance > 0 else 0.0,
            )

        candidates = []
        for i, doc in enumerate(self.documents):
            vector_score = vector_by_content.get(doc.page_content)
            if vector_score is None:
                continue  # not returned by the vector store
            bm25_score = float(bm25_scores[i]) / max_bm25 if max_bm25 > 0 else 0.0
            final_score = self.bm25_weight * bm25_score + self.vector_weight * vector_score
            candidates.append((final_score, bm25_score, vector_score, doc))

        # Rank first, then de-duplicate, so each company/product keeps its
        # best-scoring row. list.sort is stable (also with reverse=True), so
        # equal scores keep document order.
        candidates.sort(key=lambda c: c[0], reverse=True)

        limit = max(int(top_k), 0)
        results = []
        seen_products = set()
        for final_score, bm25_score, vector_score, doc in candidates:
            if len(results) >= limit:
                break
            meta = doc.metadata
            product_key = f"{meta['company_name']}-{meta['product_name']}"
            if product_key in seen_products:
                continue
            seen_products.add(product_key)
            results.append({
                'company_name': meta['company_name'],
                'category': meta['category'],
                'company_info': meta['company_info'],
                'product_name': meta['product_name'],
                'description': meta['description'],
                'bm25_score': round(bm25_score, 3),
                'vector_score': round(vector_score, 3),
                'final_score': round(final_score, 3),
            })
        return results
def create_gradio_interface():
    """Build the Gradio Blocks app for the hybrid booth/product search.

    Loads the sample data into one shared ``ProductSearchSystem`` and wires a
    query box, a result-count slider and a keyword-weight slider to it.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    # Build the search system and its indexes once, when the app is created.
    search_system = ProductSearchSystem()
    search_system.load_sample_data()

    def search_products(query: str,
                        top_k: int,
                        bm25_weight: float) -> tuple:
        """Run a search with the UI settings and format the results.

        Args:
            query: Text from the search box.
            top_k: Number of results requested by the slider.
            bm25_weight: Keyword weight; the vector weight becomes ``1 - bm25_weight``.

        Returns:
            ``(html_table, detailed_text)`` for the HTML and Textbox outputs.
        """
        # Weights are stored on the shared system, so they persist into later searches.
        search_system.bm25_weight = bm25_weight
        search_system.vector_weight = 1 - bm25_weight
        # Slider values may arrive as floats (e.g. 3.0); slicing needs an int.
        results = search_system.search(query, top_k=int(top_k))

        if results:
            # Column display order for the table.
            columns_order = ['company_name', 'category', 'company_info', 'product_name', 'bm25_score', 'vector_score', 'final_score', 'description']
            df_results = pd.DataFrame(results)[columns_order]
            # Localized (Korean) column headers, in the same order as columns_order.
            df_results.columns = ['νšŒμ‚¬λͺ…', 'μΉ΄ν…Œκ³ λ¦¬', 'νšŒμ‚¬ μ„€λͺ…', 'μ œν’ˆλͺ…', 'ν‚€μ›Œλ“œ 점수', '벑터 점수', 'μ΅œμ’… 점수', 'μ„€λͺ…']
            # Row values are plain text from the CSV: escape them so '<' or '&'
            # is shown literally instead of being interpreted as HTML markup.
            html_table = df_results.to_html(
                classes=['table', 'table-striped'],
                escape=True,
                index=False,
                float_format='{:.3f}'.format,  # show 3 decimal places
            )
        else:
            html_table = "<p>검색 κ²°κ³Όκ°€ μ—†μŠ΅λ‹ˆλ‹€.</p>"

        # Plain-text, per-result detail view.
        detailed_results = []
        for i, result in enumerate(results, 1):
            detailed_results.append(f"""
=== 검색결과 #{i} ===
νšŒμ‚¬λͺ…: {result['company_name']}
μΉ΄ν…Œκ³ λ¦¬: {result['category']}
νšŒμ‚¬ μ„€λͺ…: {result['company_info']}
μ œν’ˆλͺ…: {result['product_name']}
ν‚€μ›Œλ“œ 점수: {result['bm25_score']:.3f}
벑터 점수: {result['vector_score']:.3f}
μ΅œμ’… 점수: {result['final_score']:.3f}
μ„€λͺ…: {result['description']}
""")
        detailed_text = "\n".join(detailed_results) if detailed_results else "검색 κ²°κ³Όκ°€ μ—†μŠ΅λ‹ˆλ‹€."
        return html_table, detailed_text

    # UI layout: query + result count, keyword-weight slider, search button, outputs.
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("""
# πŸ” μ½”μ—‘μŠ€ λΆ€μŠ€ μΆ”μ²œ μ‹œμŠ€ν…œ
ν•˜μ΄λΈŒλ¦¬λ“œ 방식을 μ΄μš©ν•œ κΈ°μ—… 및 μ œν’ˆ 검색/μΆ”μ²œ μ‹œμŠ€ν…œμž…λ‹ˆλ‹€.
""")
        with gr.Row():
            with gr.Column(scale=4):
                query_input = gr.Textbox(
                    label="검색어λ₯Ό μž…λ ₯ν•˜μ„Έμš”",
                    placeholder="예: AI 기술 νšŒμ‚¬, μ„Όμ„œ, μžλ™ν™” λ“±",
                )
            with gr.Column(scale=1):
                top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="검색 κ²°κ³Ό 수",
                )
        with gr.Row():
            bm25_weight = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.3,
                step=0.1,
                label="ν‚€μ›Œλ“œ 검색 κ°€μ€‘μΉ˜",
            )
        with gr.Row():
            search_button = gr.Button("검색", variant="primary")
        with gr.Row():
            with gr.Column():
                results_table = gr.HTML(label="검색 κ²°κ³Ό ν…Œμ΄λΈ”")
            with gr.Column():
                results_text = gr.Textbox(
                    label="상세 κ²°κ³Ό",
                    show_label=True,
                    interactive=False,
                    lines=10
                )
        # Event wiring: the button runs the search and fills both outputs.
        search_button.click(
            fn=search_products,
            inputs=[query_input, top_k, bm25_weight],
            outputs=[results_table, results_text],
        )
        gr.Markdown("""
### μ‚¬μš© 방법
1. 검색어 μž…λ ₯: 찾고자 ν•˜λŠ” κΈ°μ—…, μ œν’ˆ, 기술 λ“±μ˜ ν‚€μ›Œλ“œλ₯Ό μž…λ ₯ν•˜μ„Έμš”
2. 검색 κ²°κ³Ό 수 μ‘°μ •: μ›ν•˜λŠ” κ²°κ³Ό 수λ₯Ό μ„ νƒν•˜μ„Έμš”
3. κ°€μ€‘μΉ˜ μ‘°μ •: ν‚€μ›Œλ“œ 맀칭과 의미적 μœ μ‚¬λ„ κ°„μ˜ κ°€μ€‘μΉ˜λ₯Ό μ‘°μ ˆν•˜μ„Έμš”
### 점수 μ„€λͺ…
- ν‚€μ›Œλ“œ 점수: Kiwi ν† ν¬λ‚˜μ΄μ €λ₯Ό μ‚¬μš©ν•œ ν‚€μ›Œλ“œ 기반 맀칭 점수 (0~1)
- 벑터 점수: 의미적 μœ μ‚¬λ„ 점수 (0~1)
- μ΅œμ’… 점수: ν‚€μ›Œλ“œ μ μˆ˜μ™€ 벑터 점수의 가쀑 평균
""")
    return demo
def main():
    """Build the search UI and serve it, requesting a public Gradio share link."""
    create_gradio_interface().launch(share=True)


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# TODO
# - OCR: deep-learning-based OCR vs. OCR processing
# - Test the tokenizer's output
# - Check the part-of-speech tagging results