# chilly-magician's picture
# [add]: test parser script
# 6a7f508
import argparse
import json
import os
from typing import Optional, Tuple
from tqdm.auto import tqdm
import torch
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
def check_base_path(path: Optional[str]) -> Optional[str]:
    """Validate that the directory which will contain ``path`` exists.

    Intended as an ``argparse`` ``type=`` converter for an output file: the
    file itself does not have to exist yet (it is created later), only the
    directory it will be written into.

    Args:
        path: File path to validate, or ``None`` to skip validation.

    Returns:
        ``path`` unchanged.

    Raises:
        argparse.ArgumentTypeError: If the parent directory of ``path`` does
            not exist. This is an ``Exception`` subclass that argparse turns
            into a regular usage error instead of a traceback.
    """
    if path is None:
        return path

    # An empty dirname means a bare file name, i.e. the current working
    # directory, which needs no check.
    base_path = os.path.dirname(path)
    if base_path and not os.path.isdir(base_path):
        raise argparse.ArgumentTypeError(f'Path not found {base_path}')
    return path
def parse_args():
    """Declare the command-line options of the test script and parse them.

    Returns:
        The ``argparse.Namespace`` produced by the parser, with the attributes
        ``model_id``, ``dataset_name``, ``dataset_split``,
        ``dataset_instructions_field``, ``instructions_response_delimiter``,
        ``instructions_category_delimiter`` and ``output``.
    """
    model_id = 'EmbeddingStudio/query-parser-falcon-7b-instruct'
    dataset = 'EmbeddingStudio/query-parsing-instructions-falcon'
    split = 'test'
    instruction_field = 'text'
    response_delimiter = '## Response:\n'
    category_delimiter = '## Category:'
    # The default report name is derived from the model name without its organisation prefix.
    output_path = f'{model_id.split("/")[-1]}-test.json'

    # One row per option: (flag, help text, default value, converter).
    options = (
        ("--model-id",
         f"Huggingface model ID (default: {model_id})",
         model_id,
         str),
        ("--dataset-name",
         f"Huggingface dataset name which contains instructions (default: {dataset})",
         dataset,
         str),
        ("--dataset-split",
         f"Huggingface dataset split name (default: {split})",
         split,
         str),
        ("--dataset-instructions-field",
         f"Huggingface dataset field with instructions (default: {instruction_field})",
         instruction_field,
         str),
        ("--instructions-response-delimiter",
         f"Instruction response delimiter (default: {response_delimiter})",
         response_delimiter,
         str),
        ("--instructions-category-delimiter",
         f"Instruction category name delimiter (default: {category_delimiter})",
         category_delimiter,
         str),
        ("--output",
         f"JSON file with test results (default: {output_path})",
         output_path,
         check_base_path),
    )

    cli = argparse.ArgumentParser(description='EmbeddingStudio script for testing Zero-Shot Search Query Parsers')
    for flag, help_text, default_value, converter in options:
        cli.add_argument(flag, help=help_text, default=default_value, type=converter)
    return cli.parse_args()
def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load the tokenizer and the causal language model for ``model_id``.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` pair. The tokenizer's padding token is set to
        its end-of-sequence token.
    """
    tokenizer_options = {
        'trust_remote_code': True,
        'add_prefix_space': True,
        'use_fast': False,
    }
    tok = AutoTokenizer.from_pretrained(model_id, **tokenizer_options)
    # Reuse the end-of-sequence token for padding.
    tok.pad_token = tok.eos_token

    # The {"": 0} device map requests the whole model on device index 0.
    placement = {"": 0}
    lm = AutoModelForCausalLM.from_pretrained(model_id, device_map=placement)
    return tok, lm
@torch.no_grad()
def predict(
        tokenizer: AutoTokenizer,
        model: AutoModelForCausalLM,
        dataset: DatasetDict,
        index: int,
        field_name: str = 'text',
        response_delimiter: str = '## Response:\n',
        category_delimiter: str = '## Category: '
) -> list:
    """Generate the parsed query for one dataset record and pair it with the reference.

    The record text is expected to hold the instruction, then
    ``response_delimiter``, then the reference answer as JSON. Everything up
    to and including the delimiter is used as the prompt.

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model whose ``generate`` method produces the completion.
        dataset: Collection indexed by integer position; each record must
            provide ``field_name``.
        index: Position of the record to evaluate.
        field_name: Record field that holds the full instruction text.
        response_delimiter: Marker separating the instruction from the answer.
        category_delimiter: Marker preceding the category name; the category
            is the rest of that line.

    Returns:
        A three-element list ``[parsed, real, category]``: the JSON decoded
        from the model output, the JSON decoded from the reference answer, and
        the category name.

    Raises:
        json.JSONDecodeError: If the reference answer or the generated text is
            not valid JSON.
    """
    text = dataset[index][field_name]
    segments = text.split(response_delimiter)

    prompt = segments[0] + response_delimiter
    # Decode the reference before generating so a malformed record fails fast.
    real = json.loads(segments[-1])
    # Strip so the category is the same whether or not the delimiter carries a trailing space.
    category = text.split(category_delimiter)[-1].split('\n')[0].strip()

    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Take the padding id from the tokenizer instead of a model-specific constant;
    # fall back to EOS when the tokenizer defines no padding token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generating text on whichever device the model lives on.
    output = model.generate(
        input_ids.to(model.device),
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.05,
        pad_token_id=pad_token_id,
    )

    # The decoded sequence still contains the prompt; keep only what follows the delimiter.
    generated = tokenizer.decode(output[0], skip_special_tokens=True)
    parsed = json.loads(generated.split(response_delimiter)[-1])
    return [parsed, real, category]
@torch.no_grad()
def test_model(model_id: str,
               dataset_name: str,
               split_name: str,
               field_name: str,
               response_delimiter: str,
               category_delimiter: str,
               output_path: str,
               ) -> None:
    """Run the model over one dataset split and write the results to a JSON file.

    Args:
        model_id: Model identifier passed to ``load_model``.
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split_name: Name of the split to evaluate.
        field_name: Record field that holds the instruction text.
        response_delimiter: Marker separating the instruction from the answer.
        category_delimiter: Marker preceding the category name.
        output_path: Destination of the JSON report, a list with one
            ``predict`` result per successfully processed record.
    """
    # With an explicit ``split`` the loader already returns that split, so it
    # must not be indexed by the split name a second time.
    dataset = load_dataset(dataset_name, split=split_name)
    tokenizer, model = load_model(model_id)
    model.eval()

    total = len(dataset)
    skipped = 0
    test_results = []
    for index in tqdm(range(total)):
        try:
            test_results.append(
                predict(tokenizer, model, dataset, index, field_name, response_delimiter, category_delimiter)
            )
        except Exception as e:
            # Best effort: one bad record (e.g. invalid JSON in the output)
            # must not abort the run, but it should not vanish silently either.
            skipped += 1
            tqdm.write(f'Skipping record {index}: {type(e).__name__}: {e}')

    if skipped:
        tqdm.write(f'Skipped {skipped} of {total} records')

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(test_results, f)
if __name__ == '__main__':
    # Script entry point: read the command-line options and run the evaluation with them.
    cli_args = parse_args()
    test_model(
        model_id=cli_args.model_id,
        dataset_name=cli_args.dataset_name,
        split_name=cli_args.dataset_split,
        field_name=cli_args.dataset_instructions_field,
        response_delimiter=cli_args.instructions_response_delimiter,
        category_delimiter=cli_args.instructions_category_delimiter,
        output_path=cli_args.output,
    )