chenglu commited on
Commit
11067f9
·
1 Parent(s): ca867fb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import requests
4
+ import re
5
+ import emoji
6
+ import nltk
7
+ import lxml
8
+ import os
9
+ from bs4 import BeautifulSoup
10
+ from markdown import markdown
11
+ from nltk.corpus import stopwords
12
+ from datasets import load_dataset
13
+ from sentence_transformers import SentenceTransformer, util
14
+ from retry import retry
15
+
16
+ # 确保已下载 nltk 的停用词
17
+ nltk.download('stopwords')
18
+
19
+ # 从环境变量中获取 hf_token
20
+ hf_token = os.getenv('HF_TOKEN')
21
+
22
+ model_id = "BAAI/bge-large-en-v1.5"
23
+ api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
24
+ headers = {"Authorization": f"Bearer {hf_token}"}
25
+
26
+ @retry(tries=3, delay=10)
27
+ def query(texts):
28
+ response = requests.post(api_url, headers=headers, json={"inputs": texts})
29
+ if response.status_code == 200:
30
+ result = response.json()
31
+ if isinstance(result, list):
32
+ return result
33
+ elif 'error' in result:
34
+ raise RuntimeError("Error from Hugging Face API: " + result['error'])
35
+ else:
36
+ raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code))
37
+
38
+ # 加载嵌入向量数据集
39
+ faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
40
+ df = faqs_embeddings_dataset["train"].to_pandas()
41
+ embeddings_array = df.T.to_numpy()
42
+ dataset_embeddings = torch.from_numpy(embeddings_array).to(torch.float)
43
+
44
+ # 加载原始数据集
45
+ original_dataset = load_dataset("chenglu/hf-blogs")['train']
46
+
47
+ # 定义英语停用词集
48
+ stop_words = set(stopwords.words('english'))
49
+
50
+ def remove_stopwords(text):
51
+ return ' '.join([word for word in text.split() if word.lower() not in stop_words])
52
+
53
+ def clean_content(content):
54
+ content = re.sub(r"(```.*?```|`.*?`)", "", content, flags=re.DOTALL)
55
+ content = BeautifulSoup(content, "html.parser").get_text()
56
+ content = emoji.replace_emoji(content, replace='')
57
+ content = re.sub(r"[^a-zA-Z\s]", "", content)
58
+ content = re.sub(r"http\S+|www\S+|https\S+", '', content, flags=re.MULTILINE)
59
+ content = markdown(content)
60
+ content = ''.join(BeautifulSoup(content, 'lxml').findAll(text=True))
61
+ content = re.sub(r'\s+', ' ', content)
62
+ return content
63
+
64
+ def get_tags_for_local(dataset, local_value):
65
+ entry = next((item for item in dataset if item['local'] == local_value), None)
66
+ if entry:
67
+ return entry['tags']
68
+ else:
69
+ return None
70
+
71
+ def gradio_query_interface(input_text):
72
+ cleaned_text = clean_content(input_text)
73
+ no_stopwords_text = remove_stopwords(cleaned_text)
74
+ new_embedding = query(no_stopwords_text)
75
+ query_embeddings = torch.FloatTensor(new_embedding)
76
+ hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=5)
77
+ if all(hit['score'] < 0.6 for hit in hits[0]):
78
+ return "Content Not related"
79
+ else:
80
+ highest_score_result = max(hits[0], key=lambda x: x['score'])
81
+ highest_score_corpus_id = highest_score_result['corpus_id']
82
+ local = df.columns[highest_score_corpus_id]
83
+ recommended_tags = get_tags_for_local(original_dataset, local)
84
+ return f"Recommended category tags: {recommended_tags}"
85
+
86
+ iface = gr.Interface(
87
+ fn=gradio_query_interface,
88
+ inputs="text",
89
+ outputs="label"
90
+ )
91
+
92
+ iface.launch()