# hf-blog-tags / app.py
# Last commit: "Update app.py" (aff403e) by chenglu
import gradio as gr
import torch
import requests
import re
import emoji
import nltk
import lxml
import os
from bs4 import BeautifulSoup
from markdown import markdown
from nltk.corpus import stopwords
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from retry import retry
from transformers import pipeline
# NOTE(review): `pipe` (English -> Spanish translation) is not referenced
# anywhere else in this file, yet creating it loads a second model at startup;
# consider removing it if nothing outside this file imports it.
pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")
# Ensure the NLTK stop-word corpus has been downloaded.
nltk.download('stopwords')
# Read the Hugging Face token from the environment (None when unset).
# NOTE(review): only the commented-out Inference API code uses this token.
hf_token = os.getenv('HF_TOKEN')
# Embedding model, run locally through a transformers feature-extraction
# pipeline; query() mean-pools the per-token features it returns.
model_id = "BAAI/bge-large-en-v1.5"
feature_extraction_pipeline = pipeline("feature-extraction", model=model_id)
# Legacy configuration for the hosted Inference API, superseded by the local
# pipeline above and kept commented out for reference.
# model_id = "BAAI/bge-large-en-v1.5"
# api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
# headers = {"Authorization": f"Bearer {hf_token}"}
@retry(tries=3, delay=10)
def query(texts):
    """Embed one text, or a list of texts, with the local feature-extraction pipeline.

    The per-token features returned by ``feature_extraction_pipeline`` are
    mean-pooled into one vector per input text.

    Args:
        texts: a single string or a list of strings.

    Returns:
        A 2-D float tensor with one row per input text. A single string
        yields a tensor with exactly one row.

    The ``retry`` decorator attempts the call up to 3 times, sleeping 10
    seconds between attempts, whenever it raises.
    """
    # truncation=True caps each input at the tokenizer's maximum length.
    # Without it, long inputs such as whole blog posts exceed the model's
    # position embeddings and the forward pass raises instead of returning
    # features.
    features = feature_extraction_pipeline(texts, truncation=True)

    embeddings = []
    for feature in features:
        token_vectors = torch.tensor(feature)
        # Collapse any leading batch axis so the mean is always taken over
        # tokens. A single-string call iterates (tokens, hidden) items, while a
        # list call yields (1, tokens, hidden) per text; averaging dim 0 of the
        # latter would average the batch axis instead of the tokens.
        token_vectors = token_vectors.reshape(-1, token_vectors.shape[-1])
        embeddings.append(token_vectors.mean(dim=0))
    return torch.stack(embeddings)
# def query(texts):
# response = requests.post(api_url, headers=headers, json={"inputs": texts})
# if response.status_code == 200:
# result = response.json()
# if isinstance(result, list):
# return result
# elif 'error' in result:
# raise RuntimeError("Error from Hugging Face API: " + result['error'])
# else:
# raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code))
# Load the precomputed embedding dataset.
faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
df = faqs_embeddings_dataset["train"].to_pandas()
# Transpose so each DataFrame column becomes one row of the embedding matrix
# (presumably one column per blog post -- confirm against the dataset card).
# Row i of dataset_embeddings therefore corresponds to df.columns[i], which is
# how gradio_query_interface maps a search hit back to a blog's "local" id.
embeddings_array = df.T.to_numpy()
dataset_embeddings = torch.from_numpy(embeddings_array).to(torch.float)
# Load the original blog dataset; its "local" field is matched to find tags.
original_dataset = load_dataset("chenglu/hf-blogs")['train']
# Set of English stop words from NLTK, used by remove_stopwords.
stop_words = set(stopwords.words('english'))
def remove_stopwords(text):
    """Remove English stop words from *text*.

    The text is split on whitespace. Each token whose lower-cased form is in
    the module-level ``stop_words`` set is discarded, and the remaining tokens
    (original casing kept) are re-joined with single spaces.
    """
    kept_tokens = (token for token in text.split() if token.lower() not in stop_words)
    return " ".join(kept_tokens)
def clean_content(content):
    """Reduce raw markdown/HTML blog text to plain, letters-only prose.

    Steps, in order:
      1. drop fenced (```...```) and inline (`...`) code;
      2. render markdown to HTML and keep only its visible text, which also
         strips any raw HTML embedded in the source;
      3. remove emoji;
      4. remove URLs (tokens starting with "http" or "www");
      5. keep only ASCII letters and whitespace;
      6. collapse whitespace runs to single spaces.

    Args:
        content: raw text, typically markdown and/or HTML.

    Returns:
        The cleaned string.
    """
    # DOTALL lets a code fence span multiple lines.
    content = re.sub(r"(```.*?```|`.*?`)", "", content, flags=re.DOTALL)
    # Render markdown while its syntax is still intact; stripping punctuation
    # first would leave markdown() nothing to interpret.
    html = markdown(content)
    content = BeautifulSoup(html, "html.parser").get_text()
    content = emoji.replace_emoji(content, replace='')
    # Remove URLs before the letters-only filter, while they are still whole.
    content = re.sub(r"(?:http|www)\S+", "", content)
    content = re.sub(r"[^a-zA-Z\s]", "", content)
    # Also turns newlines left between rendered HTML blocks into spaces.
    content = re.sub(r"\s+", " ", content)
    return content
def get_tags_for_local(dataset, local_value):
    """Return the ``tags`` of the first record whose ``local`` equals *local_value*.

    Args:
        dataset: iterable of mapping-like records with ``local`` and ``tags`` keys.
        local_value: the identifier to search for.

    Returns:
        The first matching record's ``tags`` value, or None when no record
        matches (or the matching record is falsy).
    """
    for record in dataset:
        if record['local'] == local_value:
            return record['tags'] if record else None
    return None
# Minimum cosine-similarity score for a hit to count as related content.
SCORE_THRESHOLD = 0.6
# Number of nearest blog embeddings retrieved per query.
TOP_K = 5


def gradio_query_interface(input_text):
    """Recommend category tags for a piece of blog text.

    The text is cleaned, stripped of stop words, and embedded with ``query``.
    The embedding is searched against ``dataset_embeddings`` with
    ``util.semantic_search`` (``TOP_K`` hits). If no hit scores at least
    ``SCORE_THRESHOLD``, the text is reported as unrelated. Otherwise the
    tags of the best-scoring blog are returned.

    Args:
        input_text: raw text entered in the Gradio textbox.

    Returns:
        A human-readable string for the Gradio output component.
    """
    cleaned_text = clean_content(input_text)
    no_stopwords_text = remove_stopwords(cleaned_text)
    new_embedding = query(no_stopwords_text)
    # Unlike the legacy torch.FloatTensor(tensor) constructor, which raises
    # for non-float32 tensors, as_tensor converts any dtype and avoids a copy
    # when the input is already float32.
    query_embeddings = torch.as_tensor(new_embedding, dtype=torch.float)
    hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=TOP_K)
    top_hits = hits[0]  # results for the single query embedding

    if all(hit['score'] < SCORE_THRESHOLD for hit in top_hits):
        return "Content Not related"

    best_hit = max(top_hits, key=lambda hit: hit['score'])
    # dataset_embeddings was built from df.T, so corpus id i is df.columns[i].
    local = df.columns[best_hit['corpus_id']]
    recommended_tags = get_tags_for_local(original_dataset, local)
    return f"Recommended category tags: {recommended_tags}"
# Single-textbox UI: the entered text is passed to gradio_query_interface and
# the returned string is shown in a Label component.
iface = gr.Interface(
    fn=gradio_query_interface,
    inputs="text",
    outputs="label",
)

# Start the server only when run as a script (e.g. `python app.py`), so the
# module can be imported without blocking inside launch().
if __name__ == "__main__":
    iface.launch()