Spaces:
Build error
Build error
import copy | |
import json | |
import string | |
import random | |
import asyncio | |
from modules.llms import get_llm_factory | |
from pingpong.context import CtxLastWindowStrategy | |
def add_side_character_to_export(
    characters, enable, img,
    name, age, personality, job
):
    """Record a side character's image and name in ``characters`` if enabled.

    Only ``img`` and ``name`` are stored; ``age``, ``personality`` and ``job``
    are accepted for signature compatibility but not used here.

    Args:
        characters: List that is extended in place.
        enable: Truthy to add the character, falsy to leave the list alone.
        img: Value stored under the ``'img'`` key.
        name: Value stored under the ``'name'`` key.
        age, personality, job: Unused.

    Returns:
        The same ``characters`` list object that was passed in.
    """
    if not enable:
        return characters
    entry = {'img': img, 'name': name}
    characters.append(entry)
    return characters
def add_side_character(enable, name, age, personality, job, llm_type="PaLM"):
    """Build the prompt fragment describing every enabled side character.

    ``enable``, ``name``, ``age``, ``personality`` and ``job`` are indexed in
    parallel by slot. Only slots whose ``enable`` entry is truthy are rendered
    with the ``prompts['story_gen']['add_side_character']`` template. Rendered
    characters are numbered consecutively from 1 via ``cur_side_chars``,
    independent of their slot index.

    Args:
        enable: Per-slot flags selecting which characters to include.
        name, age, personality, job: Per-slot values fed to the template.
        llm_type: Key passed to ``get_llm_factory`` to pick the prompt set.

    Returns:
        A leading newline followed by the concatenated fragments, or an empty
        string when nothing was rendered.
    """
    prompts = get_llm_factory(llm_type).create_prompt_manager().prompts
    parts = []
    for idx, enabled in enumerate(enable):
        if not enabled:
            continue
        parts.append(
            prompts['story_gen']['add_side_character'].format(
                # Number only the characters actually included, starting at 1.
                cur_side_chars=len(parts) + 1,
                name=name[idx],
                job=job[idx],
                age=age[idx],
                personality=personality[idx]
            )
        )
    prompt = "".join(parts)
    return "\n" + prompt if prompt else ""
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    """Return a string of ``size`` characters picked at random from ``chars``.

    Draws from the non-cryptographic ``random`` module, one ``random.choice``
    call per character.
    """
    picked = [random.choice(chars) for _ in range(size)]
    return ''.join(picked)
def parse_first_json_code_snippet(code_snippet):
    """Decode JSON from ``code_snippet``, falling back to the first ```json fence.

    The whole input is tried first. If that fails, the text between the first
    "```json" marker and the next "```" is decoded instead. Both attempts use
    ``strict=False`` so control characters inside JSON strings are accepted.

    Args:
        code_snippet: Raw model output (``str``; ``bytes`` only works when it
            is bare JSON).

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: When the input is neither bare JSON nor contains a complete
            ```json fence, or when the decoded value is JSON ``null``. A decode
            failure inside the fence surfaces as ``json.JSONDecodeError``, which
            is a ``ValueError`` subclass.
    """
    no_json_msg = 'No JSON code snippet found in string.'
    try:
        parsed = json.loads(code_snippet, strict=False)
    except (TypeError, ValueError):
        # Not bare JSON: look for a fenced block of the form ```json ... ```.
        if not isinstance(code_snippet, str):
            raise ValueError(no_json_msg) from None
        fence = '```json'
        start = code_snippet.find(fence)
        if start < 0:
            raise ValueError(no_json_msg) from None
        body_start = start + len(fence)
        end = code_snippet.find('```', body_start)
        if end < 0:
            raise ValueError(no_json_msg) from None
        parsed = json.loads(code_snippet[body_start:end], strict=False)
    # JSON "null" decodes to None; callers treat that as "no JSON found".
    if parsed is None:
        raise ValueError(no_json_msg)
    return parsed
async def retry_until_valid_json(prompt, parameters=None, llm_type="PaLM"):
    """Ask the LLM for a JSON answer, retrying up to three times.

    Each attempt awaits ``llm_service.gen_text`` with a 10 second timeout and
    parses the reply with ``parse_first_json_code_snippet``. The loop stops as
    soon as a truthy JSON value is parsed. Other generation errors and parse
    failures are retried.

    Args:
        prompt: Prompt forwarded to ``gen_text``.
        parameters: Optional generation parameters forwarded unchanged.
        llm_type: Key passed to ``get_llm_factory`` to pick the backend.

    Returns:
        The parsed JSON value.

    Raises:
        TimeoutError: Immediately, without retrying, if any attempt exceeds
            the 10 second limit.
        ValueError: If the last received response reports content filters,
            or no attempt produced parseable JSON (including when every
            attempt raised).
    """
    response = None
    response_txt = None
    response_json = None
    factory = get_llm_factory(llm_type)
    llm_service = factory.create_llm_service()
    for _ in range(3):
        try:
            response, response_txt = await asyncio.wait_for(
                llm_service.gen_text(prompt, mode="text", parameters=parameters),
                timeout=10,
            )
            print(response_txt)
        except asyncio.TimeoutError as err:
            raise TimeoutError(f"The response time for {llm_type} API exceeded the limit.") from err
        except Exception:
            print(f"{llm_type} API has encountered an error. Retrying...")
            continue

        try:
            response_json = parse_first_json_code_snippet(response_txt)
        except ValueError:
            print("Parsing JSON failed. Retrying...")
            continue
        if response_json:
            # Got a usable answer; don't spend further API calls.
            break
        print("Parsing JSON failed. Retrying...")

    # `response` stays None only if every attempt raised before returning.
    if response is not None and len(response.filters) > 0:
        raise ValueError(f"{llm_type} API has withheld a response due to content safety concerns.")
    if response_json is None:
        print("=== Failed to generate valid JSON response. ===")
        print(response_txt)
        raise ValueError("Failed to generate valid JSON response.")
    return response_json
def build_prompts(ppm, win_size=3):
    """Run a ``CtxLastWindowStrategy`` over a deep copy of ``ppm``.

    The strategy only ever receives the copy, so the caller's ``ppm`` is not
    handed to it. ``win_size`` is passed to the strategy's constructor.

    Returns:
        Whatever calling the strategy instance on the copy returns
        (presumably the assembled prompt; confirm in ``pingpong.context``).
    """
    snapshot = copy.deepcopy(ppm)
    window_strategy = CtxLastWindowStrategy(win_size)
    built = window_strategy(snapshot)
    return built
async def get_chat_response(prompt, ctx=None, llm_type="PaLM"):
    """Generate a chat-mode reply for ``prompt`` and return its text.

    Sampling is fixed at temperature 1.0, top_k 50, top_p 0.9 via
    ``make_params(mode="chat", ...)``. ``ctx`` is accepted but not used.

    Args:
        prompt: Prompt forwarded to ``gen_text``.
        ctx: Unused.
        llm_type: Key passed to ``get_llm_factory`` to pick the backend.

    Returns:
        The text element of the ``(response, text)`` pair from ``gen_text``.
    """
    llm_service = get_llm_factory(llm_type).create_llm_service()
    chat_params = llm_service.make_params(
        mode="chat",
        temperature=1.0,
        top_k=50,
        top_p=0.9,
    )
    _raw, reply_text = await llm_service.gen_text(prompt, parameters=chat_params)
    return reply_text