|
from pathlib import Path |
|
|
|
import yaml |
|
|
|
from modules import utils |
|
from modules.text_generation import get_encoded_length |
|
|
|
|
|
def load_prompt(fname): |
|
if fname in ['None', '']: |
|
return '' |
|
else: |
|
file_path = Path(f'prompts/{fname}.txt') |
|
if not file_path.exists(): |
|
return '' |
|
|
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
text = f.read() |
|
if text[-1] == '\n': |
|
text = text[:-1] |
|
|
|
return text |
|
|
|
|
|
def load_instruction_prompt_simple(fname): |
|
file_path = Path(f'instruction-templates/{fname}.yaml') |
|
if not file_path.exists(): |
|
return '' |
|
|
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
data = yaml.safe_load(f) |
|
output = '' |
|
if 'context' in data: |
|
output += data['context'] |
|
|
|
replacements = { |
|
'<|user|>': data['user'], |
|
'<|bot|>': data['bot'], |
|
'<|user-message|>': 'Input', |
|
} |
|
|
|
output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements) |
|
return output.rstrip(' ') |
|
|
|
|
|
def count_tokens(text): |
|
try: |
|
tokens = get_encoded_length(text) |
|
return str(tokens) |
|
except: |
|
return '0' |
|
|