"""Generate a synthetic QA dataset from policy documents via RAGAS SDG and upload it to LangSmith."""
import os | |
from dotenv import load_dotenv | |
load_dotenv() | |
from typing import Dict, Tuple | |
from collections.abc import Callable | |
import yaml | |
import argparse | |
import asyncio | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from policy_rag.text_utils import DocLoader | |
from policy_rag.text_utils import get_recursive_token_chunks, get_semantic_chunks | |
from policy_rag.sdg_utils import ragas_sdg, upload_dataset_langsmith | |
from policy_rag.chains import get_qa_chain | |
# Config Options
# Dispatch table: config 'method' value -> chunking function.
CHUNK_METHOD = {
    'token-overlap': get_recursive_token_chunks,
    'semantic': get_semantic_chunks,
}
# Dispatch table: config 'model_source' value -> embedding model class.
EMBEDDING_MODEL_SOURCE = {
    'openai': OpenAIEmbeddings,
    'huggingface': HuggingFaceEmbeddings,
}
# Helpers
def get_chunk_func(chunk_method: Dict) -> Tuple[Callable, Dict]:
    """Resolve a chunking config mapping into a callable and its kwargs.

    Args:
        chunk_method: Mapping with a 'method' key (a key of CHUNK_METHOD) and
            an 'args' mapping of method-specific options. For 'semantic',
            'args' must contain 'model_source', 'model_name' and
            'breakpoint_type'.

    Returns:
        Tuple of (chunking function, kwargs dict to pass to it).

    Raises:
        ValueError: If 'method' names an unknown chunking method.
    """
    method = chunk_method['method']
    if method not in CHUNK_METHOD:
        # Fail fast with a clear message; previously an unknown method fell
        # through and raised UnboundLocalError on chunk_func_args.
        raise ValueError(
            f"Unknown chunk method {method!r}; expected one of {sorted(CHUNK_METHOD)}"
        )
    chunk_func = CHUNK_METHOD[method]
    if method == 'token-overlap':
        # Token-based chunker takes its config args verbatim.
        chunk_func_args = chunk_method['args']
    else:  # method == 'semantic' — the only other key in CHUNK_METHOD
        args = chunk_method['args']
        # The semantic chunker needs an instantiated embedding model, built
        # from the configured source/class and model name.
        chunk_func_args = {
            'embedding_model': EMBEDDING_MODEL_SOURCE[args['model_source']](model=args['model_name']),
            'breakpoint_type': args['breakpoint_type'],
        }
    return chunk_func, chunk_func_args
if __name__ == "__main__":
    # --- CLI / config -----------------------------------------------------
    parser = argparse.ArgumentParser(
        description='Generate a synthetic QA dataset and upload it to LangSmith.'
    )
    # required=True: previously a missing --config passed None to open() and
    # crashed with an opaque TypeError instead of a usage message.
    parser.add_argument('--config', required=True, help='YAML config file')
    args = parser.parse_args()

    with open(args.config, 'r', encoding='utf-8') as file:
        config_yml = yaml.safe_load(file)

    data_dir = config_yml['data_dir']
    chunk_method = config_yml['chunk_method']
    n_qa_pairs = config_yml['n_qa_pairs']
    ls_project = config_yml['ls_project']
    ls_dataset_name = config_yml['ls_dataset_name']
    ls_dataset_description = config_yml['ls_dataset_description']

    # --- Load raw documents ------------------------------------------------
    print('Loading Docs')
    loader = DocLoader()
    docs = loader.load_dir(data_dir)

    # --- Chunk documents ----------------------------------------------------
    print('Chunking Docs')
    chunk_func, chunk_func_args = get_chunk_func(chunk_method)
    chunks = chunk_func(docs=docs, **chunk_func_args)
    print(f"len of chunks: {len(chunks)}")

    # --- Synthetic data generation (RAGAS) ----------------------------------
    print('RAGAS SDG')
    test_set = asyncio.run(ragas_sdg(
        context_docs=chunks,
        n_qa_pairs=n_qa_pairs,
        embedding_model=OpenAIEmbeddings(model='text-embedding-3-small')
    ))

    # --- Save as LangSmith dataset ------------------------------------------
    # Set the project before upload so the dataset lands in the right place.
    os.environ['LANGCHAIN_PROJECT'] = ls_project
    print('Uploading to LangSmith')
    upload_dataset_langsmith(
        dataset=test_set,
        dataset_name=ls_dataset_name,
        description=ls_dataset_description
    )