# NOTE(review): the lines below are web-scrape residue from the hosting page
# (page chrome, commit-hash gutter, line-number gutter) — commented out so
# this module parses as Python; content preserved verbatim.
# Spaces:
# Build error
# Build error
# File size: 3,281 Bytes
# a1ca2de 3332aa4 a1ca2de 3332aa4 936d161 a1ca2de ceefdf5 936d161 ceefdf5 3332aa4 936d161 3332aa4 936d161 a1ca2de 3332aa4 a1ca2de 3332aa4 a1ca2de 3332aa4 a1ca2de 936d161 3332aa4 936d161 3332aa4 936d161 3332aa4 936d161 a1ca2de 3332aa4 a1ca2de 936d161 a1ca2de 936d161 3332aa4 936d161 a1ca2de 3332aa4 a1ca2de 3332aa4 a1ca2de |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import copy
import json
import string
import random
import asyncio
from modules.llms import get_llm_factory
from pingpong.context import CtxLastWindowStrategy
def add_side_character_to_export(
    characters, enable, img,
    name, age, personality, job
):
    """Append a side character's export record to `characters` when enabled.

    Only `img` and `name` end up in the exported record; `age`,
    `personality`, and `job` are accepted for signature parity with the
    UI callback but are not exported.

    Returns the (possibly mutated) `characters` list.
    """
    if not enable:
        return characters
    characters.append({'img': img, 'name': name})
    return characters
def add_side_character(enable, name, age, personality, job, llm_type="PaLM"):
    """Build the side-character section of a story-generation prompt.

    Args:
        enable: parallel list of booleans, one per candidate character.
        name, age, personality, job: parallel lists of character attributes.
        llm_type: key understood by get_llm_factory (default "PaLM").

    Returns:
        "" when no character is enabled; otherwise the concatenated
        formatted entries prefixed with a newline.
    """
    prompts = get_llm_factory(llm_type).create_prompt_manager().prompts
    template = prompts['story_gen']['add_side_character']
    # zip over the parallel lists instead of range(len(...)) indexing.
    parts = []
    for on, n, a, p, j in zip(enable, name, age, personality, job):
        if on:
            parts.append(template.format(
                # 1-based running count of enabled characters so far
                cur_side_chars=len(parts) + 1,
                name=n,
                job=j,
                age=a,
                personality=p,
            ))
    prompt = "".join(parts)
    return "\n" + prompt if prompt else ""
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    """Return a random identifier of `size` characters drawn from `chars`."""
    picks = [random.choice(chars) for _ in range(size)]
    return ''.join(picks)
def parse_first_json_code_snippet(code_snippet):
    """Parse the first JSON value found in `code_snippet`.

    Tries the whole string first; on failure, extracts the first
    ```json ... ``` fenced block and parses that instead.

    Args:
        code_snippet: raw model output that should contain JSON somewhere.

    Returns:
        The parsed JSON value (never None).

    Raises:
        ValueError: if no fenced block exists, the JSON is malformed
            (json.JSONDecodeError is a ValueError subclass), or the
            parsed value is None.
    """
    try:
        # Fast path: the whole string is already valid JSON.
        parsed = json.loads(code_snippet, strict=False)
    except json.JSONDecodeError:
        fence = '```json'
        start = code_snippet.find(fence)
        end = code_snippet.find('```', start + len(fence))
        if start < 0 or end < 0:
            raise ValueError('No JSON code snippet found in string.')
        # JSONDecodeError raised here propagates as a ValueError to callers.
        parsed = json.loads(code_snippet[start + len(fence):end], strict=False)
    if parsed is None:
        # A bare JSON `null` is not a usable result (matches the original
        # post-check that rejected a None parse).
        raise ValueError('No JSON code snippet found in string.')
    return parsed
async def retry_until_valid_json(prompt, parameters=None, llm_type="PaLM"):
    """Call the LLM up to 3 times until the response contains parseable JSON.

    Args:
        prompt: text prompt sent to the LLM.
        parameters: optional generation parameters forwarded to gen_text.
        llm_type: key understood by get_llm_factory (default "PaLM").

    Returns:
        The parsed JSON value from the first valid model response.

    Raises:
        TimeoutError: a single generation call exceeded 10 seconds.
        ValueError: the response was withheld by content filters, or no
            valid JSON was produced within the retry budget.
    """
    # Pre-bind so the post-loop checks are safe even if every attempt
    # raised before assignment (the original crashed with
    # UnboundLocalError on `response` in that case).
    response = None
    response_txt = None
    response_json = None
    factory = get_llm_factory(llm_type)
    llm_service = factory.create_llm_service()
    for _ in range(3):
        try:
            response, response_txt = await asyncio.wait_for(
                llm_service.gen_text(prompt, mode="text", parameters=parameters),
                timeout=10,
            )
            print(response_txt)
        except asyncio.TimeoutError:
            raise TimeoutError(f"The response time for {llm_type} API exceeded the limit.")
        except Exception as e:
            # Surface the swallowed error detail before retrying.
            print(f"{llm_type} API has encountered an error. Retrying... ({e})")
            continue
        try:
            response_json = parse_first_json_code_snippet(response_txt)
        except ValueError:
            print("Parsing JSON failed. Retrying...")
            continue
        if not response_json:
            print("Parsing JSON failed. Retrying...")
            continue
        # Stop retrying on the first valid JSON (the original looped all
        # 3 times and burned extra API calls even after success).
        break
    # NOTE(review): `response.filters` looks PaLM-specific — confirm other
    # llm_types expose the same attribute.
    if response is not None and len(response.filters) > 0:
        raise ValueError(f"{llm_type} API has withheld a response due to content safety concerns.")
    if response_json is None:
        print("=== Failed to generate valid JSON response. ===")
        print(response_txt)
        raise ValueError("Failed to generate valid JSON response.")
    return response_json
def build_prompts(ppm, win_size=3):
    """Render prompts from the last `win_size` exchanges of `ppm`.

    Works on a deep copy so the caller's prompt manager is untouched.
    """
    strategy = CtxLastWindowStrategy(win_size)
    return strategy(copy.deepcopy(ppm))
async def get_chat_response(prompt, ctx=None, llm_type="PaLM"):
    """Generate a single chat-mode completion for `prompt`.

    `ctx` is accepted for interface compatibility but is not used here.
    Returns only the response text, discarding the raw response object.
    """
    llm_service = get_llm_factory(llm_type).create_llm_service()
    chat_params = llm_service.make_params(
        mode="chat", temperature=1.0, top_k=50, top_p=0.9
    )
    _, response_txt = await llm_service.gen_text(prompt, parameters=chat_params)
    return response_txt
|