File size: 1,984 Bytes
74b7d77 f641aea 74b7d77 f641aea 74b7d77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import json
from typing import Dict
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
class JsonPromptSegment:
    """One prompt entry parsed from the JSON prompt configuration.

    Attributes
    ----------
    segment_index: int
        Index of the audio segment this prompt is meant for, as given in the JSON.
    prompt: str
        The prompt text for that segment.
    format_prompt: bool
        Whether ``prompt`` should be treated as a ``str.format`` template
        (with the Whisper prompt as ``{0}``) rather than concatenated.
        Defaults to False.
    """

    def __init__(self, segment_index: int, prompt: str, format_prompt: bool = False):
        self.segment_index = segment_index
        self.prompt = prompt
        self.format_prompt = format_prompt
class JsonPromptStrategy(AbstractPromptStrategy):
    """Prompt strategy that serves per-segment prompts from a JSON list.

    Segments without a configured entry fall back to the prompt Whisper
    would otherwise use.
    """

    def __init__(self, initial_json_prompt: str):
        """
        Parameters
        ----------
        initial_json_prompt: str
            The initial prompts for each segment in JSON form.
            Format:
            [
                {"segment_index": 0, "prompt": "Hello, how are you?"},
                {"segment_index": 1, "prompt": "I'm doing well, how are you?"},
                {"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
            ]
            "format_prompt" is optional and defaults to false.

        Raises
        ------
        json.JSONDecodeError
            If initial_json_prompt is not valid JSON.
        ValueError
            If the decoded JSON value is not a list of prompt entries.
        KeyError
            If an entry is missing "segment_index" or "prompt".
        """
        parsed_json = json.loads(initial_json_prompt)

        # A single object (or scalar) would otherwise fail later with an opaque
        # "string indices must be integers" TypeError while iterating its keys.
        if not isinstance(parsed_json, list):
            raise ValueError(
                "initial_json_prompt must be a JSON list of prompt entries, got "
                + type(parsed_json).__name__
            )

        # Keys are stringified so an index written as 0 or "0" in the JSON
        # both match the str(segment_index) lookup in get_segment_prompt.
        self.segment_lookup: Dict[str, JsonPromptSegment] = dict()

        for prompt_entry in parsed_json:
            segment_index = prompt_entry["segment_index"]
            prompt = prompt_entry["prompt"]
            format_prompt = prompt_entry.get("format_prompt", False)
            # A later entry with the same segment_index replaces an earlier one.
            self.segment_lookup[str(segment_index)] = JsonPromptSegment(segment_index, prompt, format_prompt)

    def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
        """
        Return the prompt to use for the given segment.

        Parameters
        ----------
        segment_index: int
            Index of the segment being transcribed.
        whisper_prompt: str
            The prompt Whisper would otherwise use for this segment (possibly None).
        detected_language: str
            Detected language of the segment; not used by this strategy.

        Returns
        -------
        str
            - whisper_prompt unchanged, if no JSON entry exists for segment_index;
            - the entry's prompt with whisper_prompt substituted for ``{0}``
              (an empty string when whisper_prompt is None), if "format_prompt" was set;
            - otherwise the entry's prompt combined with whisper_prompt via
              self._concat_prompt (not defined in this class; presumably
              provided by AbstractPromptStrategy).
        """
        segment = self.segment_lookup.get(str(segment_index), None)

        if segment is None:
            # No prompt configured for this segment, so fall back to Whisper's prompt.
            print(f"Could not find prompt for segment {segment_index}, returning whisper prompt")
            return whisper_prompt

        if segment.format_prompt:
            # str.format renders None as the literal text "None", so substitute "" instead.
            return segment.prompt.format(whisper_prompt if whisper_prompt is not None else "")

        return self._concat_prompt(segment.prompt, whisper_prompt)
|