lihuigu committed on
Commit
4117eaa
·
1 Parent(s): cc402d9

update embedding model load

Browse files
Files changed (3) hide show
  1. app.py +2 -1
  2. configs/datasets.yaml +2 -6
  3. src/utils/hash.py +14 -0
app.py CHANGED
@@ -3,10 +3,11 @@ import os
3
  sys.path.append("./src")
4
  import streamlit as st
5
  from pages import button_interface, step_by_step_generation, one_click_generation
6
- from utils.hash import check_env
7
 
8
  if __name__ == "__main__":
9
  check_env()
 
10
  backend = button_interface.Backend()
11
  st.set_page_config(layout="wide")
12
  def fn1():
 
3
  sys.path.append("./src")
4
  import streamlit as st
5
  from pages import button_interface, step_by_step_generation, one_click_generation
6
+ from utils.hash import check_env, check_embedding
7
 
8
  if __name__ == "__main__":
9
  check_env()
10
+ check_embedding()
11
  backend = button_interface.Backend()
12
  st.set_page_config(layout="wide")
13
  def fn1():
configs/datasets.yaml CHANGED
@@ -3,7 +3,7 @@ DEFAULT:
3
  ignore_paper_id_list: ./assets/data/ignore_paper_id_list.json
4
  log_level: "DEBUG"
5
  log_dir: ./log
6
- embedding: "sentence-transformers/all-MiniLM-L6-v2"
7
  device: "cpu" # "cpu"
8
 
9
  ARTICLE:
@@ -31,8 +31,4 @@ RETRIEVE:
31
  s_bg: 0
32
  s_contribution: 0.5
33
  s_summary: 0.5
34
- similarity_threshold: 0.55
35
-
36
- used_llms_apis:
37
- summarization: ZhipuAI
38
- generation: OpenAI
 
3
  ignore_paper_id_list: ./assets/data/ignore_paper_id_list.json
4
  log_level: "DEBUG"
5
  log_dir: ./log
6
+ embedding: ./assets/model/sentence-transformers/all-MiniLM-L6-v2
7
  device: "cpu" # "cpu"
8
 
9
  ARTICLE:
 
31
  s_bg: 0
32
  s_contribution: 0.5
33
  s_summary: 0.5
34
+ similarity_threshold: 0.55
 
 
 
 
src/utils/hash.py CHANGED
@@ -3,7 +3,21 @@ import os
3
  import hashlib
4
  import struct
5
  from collections import Counter
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def check_env():
9
  env_name_list = [
 
3
  import hashlib
4
  import struct
5
  from collections import Counter
6
+ from huggingface_hub import hf_hub_download
7
 
8
+ def check_embedding():
9
+ # Define the repository and files to download
10
+ repo_id = "sentence-transformers/all-MiniLM-L6-v2" # "BAAI/bge-small-en-v1.5"
11
+ files_to_download = [
12
+ "config.json",
13
+ "pytorch_model.bin",
14
+ "tokenizer_config.json",
15
+ "vocab.txt",
16
+ ]
17
+ # Download each file and save it under ./assets/model/<repo_id> (the path configs/datasets.yaml points `embedding` at)
18
+ for file_name in files_to_download:
19
+ print("Checking for file: ", file_name)
20
+ hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=f"./assets/model/{repo_id}")
21
 
22
  def check_env():
23
  env_name_list = [