File size: 1,092 Bytes
7bf4b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
from typing import Union
from stark_qa.retrieval import STaRKDataset

# Dataset names accepted by `load_qa`; any name not listed here is rejected
# with a ValueError.
REGISTERED_DATASETS = ['amazon', 'prime', 'mag']


def load_qa(name: str,
            root: Union[str, None] = None,
            human_generated_eval: bool = False) -> STaRKDataset:
    """
    Build the STaRK question-answering dataset registered under ``name``.

    Args:
        name (str): Dataset identifier; must be an entry of
            ``REGISTERED_DATASETS`` ('amazon', 'prime', or 'mag').
        root (Union[str, None]): Directory handed to ``STaRKDataset`` as the
            dataset root. A relative path is resolved to an absolute path
            before being forwarded. ``None`` is forwarded unchanged
            (presumably selecting the default Hugging Face cache location --
            that behavior lives in ``STaRKDataset``, not here).
        human_generated_eval (bool): Forwarded to ``STaRKDataset`` to request
            the human-generated evaluation data. Default is False.

    Returns:
        STaRKDataset: The dataset object constructed for ``name``.

    Raises:
        ValueError: If ``name`` is not one of the registered datasets.
    """
    # Resolve a relative root first so the dataset always receives either
    # None or an absolute path.
    if root is not None and not os.path.isabs(root):
        root = os.path.abspath(root)

    # Guard clause: reject unregistered names before constructing anything.
    if name not in REGISTERED_DATASETS:
        raise ValueError(f"Unknown dataset {name}")

    return STaRKDataset(name, root, human_generated_eval=human_generated_eval)