|
import argparse |
|
import json |
|
import os |
|
|
|
from typing import Optional, Tuple |
|
from tqdm.auto import tqdm |
|
|
|
import torch |
|
|
|
from datasets import DatasetDict, load_dataset |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
def check_base_path(path: str) -> Optional[str]:
    """Validate that the directory that will contain ``path`` exists.

    Used as an argparse ``type=`` callback for the ``--output`` option, so it
    must accept the raw string and return it unchanged when valid.

    Args:
        path: Candidate output file path (may be ``None``).

    Returns:
        The same ``path`` when its parent directory exists, or ``None``
        when ``path`` is ``None``.

    Raises:
        Exception: If the parent directory does not exist.
    """
    if path is None:
        return path
    # Bug fix: the original used os.path.basename, which tested whether the
    # *file name itself* already exists in the cwd — a fresh output file
    # always failed that check. The intent is to verify the containing
    # directory; an empty dirname means the cwd, which always exists.
    base_dir = os.path.dirname(path) or '.'
    if not os.path.exists(base_dir):
        raise Exception(f'Path not found {base_dir}')
    return path
|
|
|
|
|
def parse_args():
    """Parse the command-line options for the query-parser test script.

    Returns:
        argparse.Namespace with model/dataset identifiers, instruction
        field and delimiter settings, and the validated output path.
    """
    default_model_id = 'EmbeddingStudio/query-parser-falcon-7b-instruct'
    default_dataset = 'EmbeddingStudio/query-parsing-instructions-falcon'
    default_split = 'test'
    default_instruction_field = 'text'
    default_response_delimiter = '## Response:\n'
    default_category_delimiter = '## Category:'
    default_output_path = f'{default_model_id.split("/")[-1]}-test.json'

    # One row per option: (flag, help text, default value, value converter).
    option_specs = [
        ("--model-id",
         f"Huggingface model ID (default: {default_model_id})",
         default_model_id, str),
        ("--dataset-name",
         f"Huggingface dataset name which contains instructions (default: {default_dataset})",
         default_dataset, str),
        ("--dataset-split",
         f"Huggingface dataset split name (default: {default_split})",
         default_split, str),
        ("--dataset-instructions-field",
         f"Huggingface dataset field with instructions (default: {default_instruction_field})",
         default_instruction_field, str),
        ("--instructions-response-delimiter",
         f"Instruction response delimiter (default: {default_response_delimiter})",
         default_response_delimiter, str),
        ("--instructions-category-delimiter",
         f"Instruction category name delimiter (default: {default_category_delimiter})",
         default_category_delimiter, str),
        ("--output",
         f"JSON file with test results (default: {default_output_path})",
         default_output_path, check_base_path),
    ]

    parser = argparse.ArgumentParser(
        description='EmbeddingStudio script for testing Zero-Shot Search Query Parsers')
    for flag, help_text, default_value, converter in option_specs:
        parser.add_argument(flag, help=help_text, default=default_value, type=converter)
    return parser.parse_args()
|
|
|
|
|
def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Download and return the tokenizer and causal-LM for ``model_id``.

    Args:
        model_id: Huggingface Hub model identifier.

    Returns:
        ``(tokenizer, model)`` pair; the model is placed on device 0 via
        ``device_map`` and the tokenizer's pad token is aliased to EOS.
    """
    tok = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        add_prefix_space=True,
        use_fast=False,
    )
    # Reuse EOS as the padding token so generation/padding works even when
    # the checkpoint ships without a dedicated pad token.
    tok.pad_token = tok.eos_token
    lm = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": 0})
    return tok, lm
|
|
|
|
|
@torch.no_grad()
def predict(
        tokenizer: AutoTokenizer,
        model: AutoModelForCausalLM,
        dataset: DatasetDict,
        index: int,
        field_name: str = 'text',
        response_delimiter: str = '## Response:\n',
        category_delimiter: str = '## Category: '
) -> Tuple[dict, dict, str]:
    """Generate a parsed query for one dataset sample and return it with
    the ground truth.

    Args:
        tokenizer: Tokenizer matching ``model``.
        model: Causal LM used for generation (already on its target device).
        dataset: Indexable dataset whose rows contain ``field_name``.
        index: Row index to evaluate.
        field_name: Column holding the full instruction text.
        response_delimiter: Marker separating prompt from the JSON response.
        category_delimiter: Marker preceding the category name line.

    Returns:
        ``(parsed, real, category)`` — model output JSON, ground-truth JSON,
        and the sample's category string.

    Raises:
        json.JSONDecodeError: If the ground truth or the generated text is
            not valid JSON (callers are expected to catch and skip).
    """
    sample = dataset[index][field_name]
    # Prompt is everything up to (and including) the response delimiter.
    input_text = sample.split(response_delimiter)[0] + response_delimiter
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    real = json.loads(sample.split(response_delimiter)[-1])
    category = sample.split(category_delimiter)[-1].split('\n')[0]

    # Bug fixes: use the model's actual device instead of hard-coding
    # 'cuda', and take the pad id from the tokenizer (load_model aliases
    # pad to EOS) instead of hard-coding 50256, which is GPT-2's EOS id
    # and not correct for other tokenizers.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 50256
    output = model.generate(
        input_ids.to(model.device),
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.05,
        pad_token_id=pad_id,
    )
    parsed = json.loads(
        tokenizer.decode(output[0], skip_special_tokens=True).split(response_delimiter)[-1])

    # Return a tuple as the annotation promises (the original returned a list).
    return parsed, real, category
|
|
|
|
|
@torch.no_grad()
def test_model(model_id: str,
               dataset_name: str,
               split_name: str,
               field_name: str,
               response_delimiter: str,
               category_delimiter: str,
               output_path: str,
               ):
    """Run the model over every sample of a dataset split and dump results.

    Args:
        model_id: Huggingface model identifier to evaluate.
        dataset_name: Huggingface dataset with instruction texts.
        split_name: Split to evaluate (e.g. ``'test'``).
        field_name: Dataset column holding the instruction text.
        response_delimiter: Marker separating prompt from JSON response.
        category_delimiter: Marker preceding the category name.
        output_path: Where the JSON list of per-sample results is written.
    """
    # load_dataset(..., split=...) already returns the single requested
    # split as a Dataset; the original then indexed it again with
    # dataset[split_name], which fails on a Dataset. Use it directly.
    dataset = load_dataset(dataset_name, split=split_name)
    tokenizer, model = load_model(model_id)
    model.eval()

    test_results = []
    for index in tqdm(range(len(dataset))):
        try:
            test_results.append(
                predict(tokenizer, model, dataset, index,
                        field_name, response_delimiter, category_delimiter))
        except Exception:
            # Best-effort evaluation: skip samples whose response is not
            # valid JSON or whose generation fails, keep the run going.
            continue

    with open(output_path, 'w') as f:
        # Bug fix: json.dump requires the file object as its second
        # argument; the original call raised TypeError after the full run.
        json.dump(test_results, f)
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Entry point: parse CLI options and run the full evaluation pass.
    args = parse_args()
    test_model(
        model_id=args.model_id,
        dataset_name=args.dataset_name,
        split_name=args.dataset_split,
        field_name=args.dataset_instructions_field,
        response_delimiter=args.instructions_response_delimiter,
        category_delimiter=args.instructions_category_delimiter,
        output_path=args.output,
    )
|
|