terapyon committed
Commit 49e6454 · unverified · 2 Parent(s): 0c03b81 b9061bc

Merge pull request #2 from terapyon/terada/mt-239-duckdb-structure

Files changed (4):
  1. src/config.py +6 -0
  2. src/embedding.py +20 -0
  3. src/episode.py +1 -1
  4. src/store.py +108 -0
src/config.py ADDED
@@ -0,0 +1,6 @@
+ from pathlib import Path
+ # import logging
+
+
+ HERE = Path(__file__).resolve().parent
+ DUCKDB_FILE = HERE.parent / "db" / "terapyon-podcast.duckdb"
src/embedding.py ADDED
@@ -0,0 +1,20 @@
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+ MODEL_NAME = "cl-nagoya/ruri-large"
+ PREFIX_QUERY = "クエリ: "  # "query: "
+ PASSAGE_QUERY = "文章: "  # "passage: "
+
+ model = SentenceTransformer(MODEL_NAME)
+
+
+ def get_embeddings(texts: list[str], query=False, passage=False) -> np.ndarray:
+     if query:
+         texts = [PREFIX_QUERY + text for text in texts]
+     if passage:
+         texts = [PASSAGE_QUERY + text for text in texts]
+     # texts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
+     embeddings = model.encode(texts)
+     # print(embeddings.shape)
+     # print(type(embeddings))
+     return embeddings
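
Usage note (not part of the commit): the ruri models expect the "クエリ: " / "文章: " prefixes at encode time, so transcript chunks should be embedded with passage=True and search queries with query=True so both land in the same vector space. A minimal sketch with illustrative example strings:

    from embedding import get_embeddings

    # Transcript chunks get the "文章: " prefix before encoding.
    passage_vecs = get_embeddings(["今回はDuckDBの構成について話しました。"], passage=True)

    # Search queries get the "クエリ: " prefix so they share the passages' space.
    query_vecs = get_embeddings(["DuckDBの回はどれ?"], query=True)

    # cl-nagoya/ruri-large yields 1024-dim vectors, matching FLOAT[1024] in store.py.
    print(passage_vecs.shape)  # (1, 1024)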
src/episode.py CHANGED
@@ -87,7 +87,7 @@ def make_df(episode: Episode) -> pd.DataFrame:
      data = []
      for text in episode.texts:
          data.append([episode.id_, text.part, text.start, text.end, text.text])
-     df = pd.DataFrame(data, columns=["id", "part", "start", "end", "text"])
+     df = pd.DataFrame(data, columns=["id", "part", "start", "end_", "text"])
      return df
src/store.py ADDED
@@ -0,0 +1,108 @@
+ from pathlib import Path
+ import duckdb
+ from embedding import get_embeddings
+ from config import DUCKDB_FILE
+
+
+ HERE = Path(__file__).parent
+ STORE_DIR = HERE.parent / "store"
+
+
+ def create_table():
+     conn = duckdb.connect(DUCKDB_FILE)
+     podcasts_create = """CREATE TABLE podcasts (
+         id BIGINT PRIMARY KEY,
+         title TEXT, date DATE, guests TEXT[], length BIGINT, audio TEXT
+     );
+     """
+     episodes_create = """CREATE TABLE episodes (
+         id BIGINT, part BIGINT, start BIGINT, end_ BIGINT, text TEXT,
+         PRIMARY KEY (id, part)
+     );
+     """
+     embeddings_create = """CREATE TABLE embeddings (
+         id BIGINT, part BIGINT, embedding FLOAT[1024],
+         PRIMARY KEY (id, part)
+     );
+     """
+     conn.execute(podcasts_create)
+     conn.execute(episodes_create)
+     conn.execute(embeddings_create)
+     conn.commit()
+     conn.close()
+     print("Tables created.")
+
+
+ def insert_podcast():
+     conn = duckdb.connect(DUCKDB_FILE)
+     sql = """INSERT INTO podcasts
+         SELECT id, title, date, [], length, audio
+         FROM read_parquet(?);
+     """
+     conn.execute(sql, [str(STORE_DIR / 'title-list-202301-202501.parquet')])
+     conn.commit()
+     conn.close()
+
+
+ def insert_episodes():
+     conn = duckdb.connect(DUCKDB_FILE)
+     sql = """INSERT INTO episodes
+         SELECT id, part, start, end_, text
+         FROM read_parquet(?);
+     """
+     conn.execute(sql, [str(STORE_DIR / 'podcast-*.parquet')])
+     conn.commit()
+     conn.close()
+
+
+ def embed_store():
+     conn = duckdb.connect(DUCKDB_FILE)
+     sql_select = """SELECT id, part, text FROM episodes;"""
+     data = conn.execute(sql_select).df()
+     targets = data["text"].tolist()
+     embeddings = get_embeddings(targets)
+     for id_, part, emb in zip(data["id"], data["part"], embeddings):
+         # Store one 1024-dim vector per (id, part) transcript chunk.
+         conn.execute(
+             "INSERT INTO embeddings VALUES (?, ?, ?)", (id_, part, emb.tolist())
+         )
+     conn.commit()
+     conn.close()
+
+
+ def create_index():
+     conn = duckdb.connect(DUCKDB_FILE)
+     conn.execute("LOAD vss;")
+     conn.execute("SET hnsw_enable_experimental_persistence=true;")
+     conn.execute("""CREATE INDEX embeddings_index
+         ON embeddings USING HNSW (embedding);""")
+     conn.commit()
+     conn.close()
+
+
+ if __name__ == "__main__":
+     import sys
+     args = sys.argv
+     if len(args) == 2:
+         if args[1] == "create":
+             create_table()
+         elif args[1] == "podcastinsert":
+             insert_podcast()
+         elif args[1] == "episodeinsert":
+             insert_episodes()
+         elif args[1] == "embed":
+             embed_store()
+         elif args[1] == "index":
+             create_index()
+         elif args[1] == "all":
+             create_table()
+             insert_podcast()
+             insert_episodes()
+             embed_store()
+             create_index()
+         else:
+             print("Usage: python store.py all")
+             sys.exit(1)
+     else:
+         print("Usage: python store.py create")
+         sys.exit(1)
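
Usage note (not part of the commit): once create_index() has run, nearest-neighbour lookups follow DuckDB's ORDER BY array_distance(...) LIMIT k pattern, which the vss HNSW index can accelerate. A minimal search sketch under that assumption; the search() helper and the ?::FLOAT[1024] parameter cast are illustrative, not part of this diff:

    import duckdb
    from config import DUCKDB_FILE
    from embedding import get_embeddings


    def search(query: str, top_k: int = 5):
        # Embed the query with the "クエリ: " prefix so it matches the stored passages.
        query_vec = get_embeddings([query], query=True)[0].tolist()
        conn = duckdb.connect(DUCKDB_FILE)
        conn.execute("LOAD vss;")
        # Ordering by array_distance against the FLOAT[1024] column returns the
        # closest (id, part) chunks; join back to episodes for the text if needed.
        rows = conn.execute(
            """SELECT id, part, array_distance(embedding, ?::FLOAT[1024]) AS dist
               FROM embeddings
               ORDER BY dist
               LIMIT ?;""",
            [query_vec, top_k],
        ).fetchall()
        conn.close()
        return rows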