chenglu commited on
Commit
5f36c2e
·
1 Parent(s): 3f9de00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -19,21 +19,26 @@ nltk.download('stopwords')
19
  # 从环境变量中获取 hf_token
20
  hf_token = os.getenv('HF_TOKEN')
21
 
 
22
  model_id = "BAAI/bge-large-en-v1.5"
23
- api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
24
- headers = {"Authorization": f"Bearer {hf_token}"}
25
-
26
- @retry(tries=3, delay=10)
27
- def query(texts):
28
- response = requests.post(api_url, headers=headers, json={"inputs": texts})
29
- if response.status_code == 200:
30
- result = response.json()
31
- if isinstance(result, list):
32
- return result
33
- elif 'error' in result:
34
- raise RuntimeError("Error from Hugging Face API: " + result['error'])
35
- else:
36
- raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code))
 
 
 
 
37
 
38
  # 加载嵌入向量数据集
39
  faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
@@ -71,7 +76,8 @@ def get_tags_for_local(dataset, local_value):
71
  def gradio_query_interface(input_text):
72
  cleaned_text = clean_content(input_text)
73
  no_stopwords_text = remove_stopwords(cleaned_text)
74
- new_embedding = query(no_stopwords_text)
 
75
  query_embeddings = torch.FloatTensor(new_embedding)
76
  hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=5)
77
  if all(hit['score'] < 0.6 for hit in hits[0]):
 
19
  # 从环境变量中获取 hf_token
20
  hf_token = os.getenv('HF_TOKEN')
21
 
22
+
23
  model_id = "BAAI/bge-large-en-v1.5"
24
+ feature_extraction_pipeline = pipeline("feature-extraction", model=model_id)
25
+
26
+
27
+ # model_id = "BAAI/bge-large-en-v1.5"
28
+ # api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
29
+ # headers = {"Authorization": f"Bearer {hf_token}"}
30
+
31
+ # @retry(tries=3, delay=10)
32
+ # def query(texts):
33
+ # response = requests.post(api_url, headers=headers, json={"inputs": texts})
34
+ # if response.status_code == 200:
35
+ # result = response.json()
36
+ # if isinstance(result, list):
37
+ # return result
38
+ # elif 'error' in result:
39
+ # raise RuntimeError("Error from Hugging Face API: " + result['error'])
40
+ # else:
41
+ # raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code))
42
 
43
  # 加载嵌入向量数据集
44
  faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
 
76
  def gradio_query_interface(input_text):
77
  cleaned_text = clean_content(input_text)
78
  no_stopwords_text = remove_stopwords(cleaned_text)
79
+ # new_embedding = query(no_stopwords_text)
80
+ new_embedding = feature_extraction_pipeline(input_text)
81
  query_embeddings = torch.FloatTensor(new_embedding)
82
  hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=5)
83
  if all(hit['score'] < 0.6 for hit in hits[0]):