#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ai_single_response.py - a script to generate a response to a prompt from a pretrained GPT model
example:
    python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time

query_gpt_model is the fundamental building block of the bot and is used throughout the code;
it is worth testing it with a few different models.
"""
import argparse
import pprint as pp
import sys
import time
import warnings
from datetime import datetime
from pathlib import Path
from typing import Union
import logging
logging.basicConfig(
filename=f"LOGFILE-{Path(__file__).stem}.log",
filemode="a",
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)
from utils import DisableLogger, print_spacer, remove_trailing_punctuation
with DisableLogger():
from cleantext import clean
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing.*")
from aitextgen import aitextgen
def extract_response(full_resp: list, plist: list, verbose: bool = False):
"""
extract_response - helper fn for ai_single_response.py. By default aitextgen returns the prompt and the response, we just want the response
Args:
full_resp (list): the full response from aitextgen
plist (list): the prompt list
verbose (bool, optional): Defaults to False.
Returns:
        response (list): the model-generated lines, with the prompt lines removed
"""
bot_response = []
for line in full_resp:
if line.lower() in plist and len(bot_response) < len(plist):
            first_loc = plist.index(line.lower())
del plist[first_loc]
continue
bot_response.append(line)
full_resp = [clean(ele, lower=False) for ele in bot_response]
if verbose:
print("the isolated responses are:\n")
pp.pprint(full_resp)
print_spacer()
print("the input prompt was:\n")
pp.pprint(plist)
print_spacer()
return full_resp # list of only the model generated responses
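# Illustrative sketch (not executed): aitextgen echoes the prompt before the new text,
# so extract_response() drops the prompt lines and keeps only what the model added, e.g.
#   full_resp = ["person alpha:", "hey, what's up?", "", "person beta:", "not much, you?"]
#   plist     = ["person alpha:", "hey, what's up?", "", "person beta:"]
#   -> returns ["not much, you?"]  (after cleaning each line with cleantext.clean)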
def get_bot_response(
name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
"""
get_bot_response - gets the bot response to a prompt, checking to ensure that additional statements by the "speaker" are not included in the response.
Args:
name_resp (str): the name of the responder
model_resp (list): the model response
name_spk (str): the name of the speaker
verbose (bool, optional): Defaults to False.
Returns:
        bot_response (list): the bot's response lines, isolated down to just text without the "name tokens" or further messages from the speaker.
"""
fn_resp = []
name_counter = 0
break_safe = False
for resline in model_resp:
if name_resp.lower() in resline.lower():
name_counter += 1
break_safe = True
continue
if ":" in resline and name_resp.lower() not in resline.lower():
break
if name_spk.lower() in resline.lower() and not break_safe:
break
else:
fn_resp.append(resline)
if verbose:
print("the full response is:\n")
print("\n".join(fn_resp))
return fn_resp
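# Illustrative sketch (not executed): with name_resp="person beta" and name_spk="person alpha",
# get_bot_response() keeps the responder's lines and stops as soon as another speaker appears:
#   model_resp = ["person beta:", "not much, you?", "person alpha:", "nice"]
#   -> returns ["not much, you?"]  (the "person alpha:" line triggers the break)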
def query_gpt_model(
    folder_path: Union[str, Path],
prompt_msg: str,
conversation_history: list = None,
speaker: str = None,
responder: str = None,
resp_length: int = 48,
kparam: int = 20,
temp: float = 0.4,
top_p: float = 0.9,
aitextgen_obj=None,
verbose: bool = False,
use_gpu: bool = False,
):
"""
query_gpt_model - queries the GPT model and returns the first response by <responder>
Args:
folder_path (str or Path): the path to the model folder
prompt_msg (str): the prompt message
conversation_history (list, optional): the conversation history. Defaults to None.
speaker (str, optional): the name of the speaker. Defaults to None.
responder (str, optional): the name of the responder. Defaults to None.
resp_length (int, optional): the length of the response in tokens. Defaults to 48.
        kparam (int, optional): the k parameter for top_k sampling. Defaults to 20.
        temp (float, optional): the temperature for the softmax. Defaults to 0.4.
top_p (float, optional): the top_p parameter for nucleus sampling. Defaults to 0.9.
aitextgen_obj (_type_, optional): a pre-loaded aitextgen object. Defaults to None.
verbose (bool, optional): Defaults to False.
use_gpu (bool, optional): Defaults to False.
Returns:
        model_resp (dict): the model response, with keys "out_text" (str, the generated reply) and "full_conv" (dict, the updated conversation history)
"""
try:
ai = (
aitextgen_obj
if aitextgen_obj
else aitextgen(
model_folder=folder_path,
to_gpu=use_gpu,
)
)
except Exception as e:
print(f"Unable to initialize aitextgen model: {e}")
print(
f"Check model folder: {folder_path}, run the download_models.py script to download the model files"
)
sys.exit(1)
    mpath = Path(folder_path)
    mpath_base = mpath.stem  # only want the base name of the model folder for the check below
# these models used person alpha and person beta in training
mod_ids = ["natqa", "dd", "trivqa", "wow", "conversational"]
if any(substring in str(mpath_base).lower() for substring in mod_ids):
speaker = "person alpha" if speaker is None else speaker
responder = "person beta" if responder is None else responder
else:
if verbose:
print("speaker and responder not set - using default")
speaker = "person" if speaker is None else speaker
responder = "george robot" if responder is None else responder
prompt_list = (
conversation_history if conversation_history is not None else []
) # track conversation
prompt_list.append(speaker.lower() + ":" + "\n")
prompt_list.append(prompt_msg.lower() + "\n")
prompt_list.append("\n")
prompt_list.append(responder.lower() + ":" + "\n")
this_prompt = "".join(prompt_list)
pr_len = len(this_prompt)
if verbose:
print("overall prompt:\n")
pp.pprint(prompt_list)
# call the model
print("\n... generating...")
this_result = ai.generate(
n=1,
top_k=kparam,
batch_size=128,
# the prompt input counts for text length constraints
max_length=resp_length + pr_len,
min_length=16 + pr_len,
prompt=this_prompt,
temperature=temp,
top_p=top_p,
do_sample=True,
return_as_list=True,
use_cache=True,
)
if verbose:
print("\n... generated:\n")
pp.pprint(this_result) # for debugging
# process the full result to get the ~bot response~ piece
this_result = str(this_result[0]).split("\n")
input_prompt = this_prompt.split("\n")
diff_list = extract_response(
this_result, input_prompt, verbose=verbose
) # isolate the responses from the prompts
# extract the bot response from the model generated text
bot_dialogue = get_bot_response(
name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
)
bot_resp = ", ".join(bot_dialogue)
bot_resp = remove_trailing_punctuation(
bot_resp.strip()
) # remove trailing punctuation to seem more natural
if verbose:
print("\n... bot response:\n")
pp.pprint(bot_resp)
prompt_list.append(bot_resp + "\n")
prompt_list.append("\n")
    conv_history = {i: line for i, line in enumerate(prompt_list)}
if verbose:
print("\n... conversation history:\n")
pp.pprint(conv_history)
print("\nfinished!")
# return the bot response and the full conversation
return {"out_text": bot_resp, "full_conv": conv_history}
# Set up the parsing of command-line arguments
def get_parser():
"""
get_parser [a helper function for the argparse module]
Returns: argparse.ArgumentParser
"""
parser = argparse.ArgumentParser(
description="submit a message and have a pretrained GPT model respond"
)
parser.add_argument(
"-p",
"--prompt",
required=True, # MUST HAVE A PROMPT
type=str,
help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
)
parser.add_argument(
"-m",
"--model",
required=False,
type=str,
default="distilgpt2-tiny-conversational",
help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
"config.json). You can also pass the huggingface model name (e.g. distilgpt2)",
)
parser.add_argument(
"-s",
"--speaker",
required=False,
default=None,
help="Who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
)
parser.add_argument(
"-r",
"--responder",
required=False,
default="person beta",
help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
)
parser.add_argument(
"--topk",
required=False,
type=int,
default=20,
help="how many responses to sample (positive integer). lower = more random responses",
)
parser.add_argument(
"--temp",
required=False,
type=float,
default=0.4,
help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
)
parser.add_argument(
"--topp",
required=False,
type=float,
default=0.9,
help="nucleus sampling frac (0-1). aka: what fraction of possible options are considered?",
)
parser.add_argument(
"--resp_length",
required=False,
type=int,
default=50,
help="max length of the response (positive integer)",
)
parser.add_argument(
"-v",
"--verbose",
default=False,
action="store_true",
help="pass this argument if you want all the printouts",
)
parser.add_argument(
"-rt",
"--time",
default=False,
action="store_true",
help="pass this argument if you want to know runtime",
)
parser.add_argument(
"--use_gpu",
required=False,
action="store_true",
help="use gpu if available",
)
return parser
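# Example invocations (assuming the default model folder has been downloaded, e.g. with
# the download_models.py script mentioned above):
#   python ai_single_response.py --prompt "hey, what's up?"
#   python ai_single_response.py --prompt "do you like pizza?" --temp 0.7 --topk 40 --time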
if __name__ == "__main__":
# parse the command line arguments
args = get_parser().parse_args()
query = args.prompt
model_dir = str(args.model)
model_loc = Path.cwd() / model_dir if "/" not in model_dir else model_dir
spkr = args.speaker
rspndr = args.responder
k_results = args.topk
my_temp = args.temp
my_top_p = args.topp
resp_length = args.resp_length
assert resp_length > 0, "response length must be positive"
want_verbose = args.verbose
want_rt = args.time
use_gpu = args.use_gpu
st = time.perf_counter()
resp = query_gpt_model(
folder_path=model_loc,
prompt_msg=query,
speaker=spkr,
responder=rspndr,
kparam=k_results,
temp=my_temp,
top_p=my_top_p,
resp_length=resp_length,
verbose=want_verbose,
use_gpu=use_gpu,
)
output = resp["out_text"]
pp.pprint(output, indent=4)
rt = round(time.perf_counter() - st, 1)
if want_rt:
print("took {runtime} seconds to generate. \n".format(runtime=rt))
if want_verbose:
print("finished - ", datetime.now())
p_list = resp["full_conv"]
print("A transcript of your chat is as follows: \n")
        p_list = [item.strip() for item in p_list.values()]
pp.pprint(p_list)