import json
import math
import os
from datetime import datetime

import chromadb
from youtube_transcript_api import YouTubeTranscriptApi

from utils.general_utils import timeit
from utils.embedding_utils import MyEmbeddingFunction


@timeit
def run_etl(json_path="data/videos.json", db=None, batch_size=None, overlap=None):
    """Extract transcripts for the videos listed in `json_path`, chunk them,
    and (optionally) load them into a Chroma database at `db`."""
    with open(json_path) as f:
        video_info = json.load(f)

    videos = []
    for video in video_info:
        video_id = video["id"]
        video_title = video["title"]
        transcript = get_video_transcript(video_id)
        print(f"Transcript for video {video_id} fetched.")
        if transcript:
            formatted_transcript = format_transcript(
                transcript, video_id, video_title,
                batch_size=batch_size, overlap=overlap,
            )
            videos.extend(formatted_transcript)

    if db:
        initialize_db(db)
        load_data_to_db(db, videos)
        log_data_load(json_path, db, batch_size, overlap)
    else:
        print("No database specified. Skipping database load.")
        print(videos)


@timeit
def get_video_transcript(video_id):
    """Fetch the English transcript for a video; return None on failure."""
    try:
        return YouTubeTranscriptApi.get_transcript(video_id, languages=["en", "en-US"])
    except Exception as e:
        print(f"Error fetching transcript for video {video_id}: {e}")
        return None


def format_transcript(transcript, video_id, video_title, batch_size=None, overlap=None):
    """Group transcript entries into overlapping chunks of `batch_size` entries,
    attaching a timestamped YouTube URL and metadata to each chunk."""
    formatted_data = []
    base_url = f"https://www.youtube.com/watch?v={video_id}"
    query_params = "&t={start}s"

    # Default to one entry per segment with no overlap.
    if not batch_size:
        batch_size = 1
        overlap = 0
    overlap = overlap or 0  # guard against overlap=None when batch_size is set
    step = batch_size - overlap
    if step < 1:
        raise ValueError("overlap must be smaller than batch_size")

    for i in range(0, len(transcript), step):
        batch = list(transcript[i : i + batch_size])
        start_time = batch[0]["start"]
        text = " ".join(entry["text"] for entry in batch)
        # YouTube's t= query parameter expects whole seconds.
        url = base_url + query_params.format(start=int(start_time))
        metadata = {
            "video_id": video_id,
            "segment_id": f"{video_id}__{i}",
            "title": video_title,
            "source": url,
        }
        formatted_data.append({"text": text, "metadata": metadata})
    return formatted_data


embed_text = MyEmbeddingFunction()


def initialize_db(db_path, distance_metric="cosine"):
    """Create the `huberman_videos` collection in a persistent Chroma client.

    Note: `create_collection` raises if the collection already exists; use
    `get_or_create_collection` instead when re-running against an existing
    database."""
    client = chromadb.PersistentClient(path=db_path)
    # Clear existing data
    # client.reset()
    client.create_collection(
        name="huberman_videos",
        embedding_function=embed_text,
        metadata={"hnsw:space": distance_metric},
    )
    print(f"Database created at {db_path}")


def load_data_to_db(db_path, data):
    """Add the formatted segments to the collection in batches."""
    client = chromadb.PersistentClient(path=db_path)
    # Pass the embedding function again so documents are embedded with the
    # same model the collection was created with, not Chroma's default.
    collection = client.get_collection("huberman_videos", embedding_function=embed_text)

    num_rows = len(data)
    batch_size = 5461  # stay under Chroma's max batch size for a single add()
    num_batches = math.ceil(num_rows / batch_size)
    for i in range(num_batches):
        batch_data = data[i * batch_size : (i + 1) * batch_size]
        collection.add(
            documents=[segment["text"] for segment in batch_data],
            metadatas=[segment["metadata"] for segment in batch_data],
            ids=[segment["metadata"]["segment_id"] for segment in batch_data],
        )
        print(f"Batch {i + 1} of {num_batches} loaded to database.")
    print(f"Data loaded to database at {db_path}.")


def log_data_load(json_path, db_path, batch_size, overlap):
    """Write a small JSON log recording what was loaded and when."""
    log_json = json.dumps({
        "videos_info_path": json_path,
        "db_path": db_path,
        "batch_size": batch_size,
        "overlap": overlap,
        "load_time": str(datetime.now()),
    })
    db_name = os.path.splitext(os.path.basename(db_path))[0]
    log_path = f"data/logs/{db_name}_load_log.json"
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, "w") as f:
        f.write(log_json)
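

# Usage sketch (hypothetical values, not from the original project): running
# the module as a script kicks off the full ETL. `json_path` is expected to be
# a JSON list of objects with "id" and "title" keys, e.g.
#   [{"id": "<video_id>", "title": "Example episode"}]
# The db path and chunking parameters below are illustrative assumptions.
if __name__ == "__main__":
    run_etl(
        json_path="data/videos.json",
        db="data/huberman.db",  # assumed Chroma persistence path
        batch_size=20,          # 20 transcript entries per segment
        overlap=5,              # 5-entry overlap between consecutive segments
    )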