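"""ETL script: fetch YouTube transcripts for the videos listed in a JSON file,
chunk them into overlapping segments, and load them into a persistent
ChromaDB collection ("huberman_videos")."""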
import json
import math
import os
from datetime import datetime

import chromadb
from youtube_transcript_api import YouTubeTranscriptApi

from utils.embedding_utils import MyEmbeddingFunction
from utils.general_utils import timeit

@timeit
def run_etl(json_path="data/videos.json", db=None, batch_size=None, overlap=None):
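    """Fetch, chunk, and optionally load transcripts for every video listed in
    the JSON file at json_path. If db is None, the chunks are printed instead."""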
    with open(json_path) as f:
        video_info = json.load(f)

    videos = []
    for video in video_info:
        video_id = video["id"]
        video_title = video["title"]
        transcript = get_video_transcript(video_id)
        print(f"Transcript for video {video_id} fetched.")
        if transcript:
            formatted_transcript = format_transcript(
                transcript, video_id, video_title,
                batch_size=batch_size, overlap=overlap,
            )
            videos.extend(formatted_transcript)

    if db:
        initialize_db(db)
        load_data_to_db(db, videos)
        log_data_load(json_path, db, batch_size, overlap)
    else:
        print("No database specified. Skipping database load.")
        print(videos)

@timeit
def get_video_transcript(video_id):
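    """Return the English transcript for video_id, or None if unavailable."""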
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en', 'en-US'])
        return transcript
    except Exception as e:
        print(f"Error fetching transcript for video {video_id}: {str(e)}")
        return None

def format_transcript(transcript, video_id, video_title, batch_size=None, overlap=None):
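    """Chunk a raw transcript into segments of batch_size entries overlapping
    by overlap entries, attaching a timestamped YouTube URL to each segment."""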
    formatted_data = []
    base_url = f"https://www.youtube.com/watch?v={video_id}"
    query_params = "&t={start}s"

    # Default to one entry per segment with no overlap; a None overlap would
    # otherwise break the step arithmetic below.
    if not batch_size:
        batch_size = 1
        overlap = 0
    if overlap is None:
        overlap = 0
    step = batch_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than batch_size")

    for i in range(0, len(transcript), step):
        batch = list(transcript[i : i + batch_size])
        # YouTube's t parameter expects whole seconds.
        start_time = int(batch[0]["start"])
        text = " ".join(entry["text"] for entry in batch)
        url = base_url + query_params.format(start=start_time)
        metadata = {
            "video_id": video_id,
            "segment_id": video_id + "__" + str(i),
            "title": video_title,
            "source": url,
        }
        segment = {"text": text, "metadata": metadata}
        formatted_data.append(segment)
    return formatted_data

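# Shared embedding function, used both when creating the collection and when
# adding documents to it.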
embed_text = MyEmbeddingFunction()
def initialize_db(db_path, distance_metric="cosine"):
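    """Create the huberman_videos collection in a persistent Chroma database."""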
    client = chromadb.PersistentClient(path=db_path)
    # Clear existing data
    # client.reset()
    # get_or_create_collection makes reruns idempotent; create_collection
    # raises if the collection already exists.
    client.get_or_create_collection(
        name="huberman_videos",
        embedding_function=embed_text,
        metadata={"hnsw:space": distance_metric},
    )
    print(f"Database created at {db_path}")

def load_data_to_db(db_path, data):
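    """Add formatted transcript segments to the huberman_videos collection in batches."""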
    client = chromadb.PersistentClient(path=db_path)
    # Pass the embedding function here too; otherwise Chroma falls back to its
    # default embedder when adding documents.
    collection = client.get_collection(
        name="huberman_videos",
        embedding_function=embed_text,
    )
    num_rows = len(data)
    # Chroma caps the number of records per add() call; 5461 is that limit in
    # the version this was written against.
    batch_size = 5461
    num_batches = math.ceil(num_rows / batch_size)
    for i in range(num_batches):
        batch_data = data[i * batch_size : (i + 1) * batch_size]
        documents = [segment["text"] for segment in batch_data]
        metadatas = [segment["metadata"] for segment in batch_data]
        ids = [segment["metadata"]["segment_id"] for segment in batch_data]
        collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
        print(f"Batch {i+1} of {num_batches} loaded to database.")
    print(f"Data loaded to database at {db_path}.")

def log_data_load(json_path, db_path, batch_size, overlap):
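    """Write a small JSON log recording what was loaded, with what settings, and when."""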
    log_json = json.dumps({
        "videos_info_path": json_path,
        "db_path": db_path,
        "batch_size": batch_size,
        "overlap": overlap,
        "load_time": str(datetime.now()),
    })
    db_file = db_path.split("/")[-1]
    db_name = db_file.split(".")[0]
    log_path = f"data/logs/{db_name}_load_log.json"
    # Make sure the log directory exists before writing.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, "w") as f:
        f.write(log_json)
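

if __name__ == "__main__":
    # Minimal usage sketch, assuming the default JSON path from run_etl; the
    # db path and chunking parameters here are illustrative assumptions, not
    # from the original file.
    run_etl(
        json_path="data/videos.json",
        db="data/huberman_videos.db",  # hypothetical Chroma persistence path
        batch_size=10,
        overlap=2,
    )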