harpreetsahota commited on
Commit
267e635
·
verified ·
1 Parent(s): 41f63c2

Update prompt_template.py

Browse files
Files changed (1) hide show
  1. prompt_template.py +24 -113
prompt_template.py CHANGED
@@ -1,118 +1,29 @@
1
- import re
2
  import yaml
3
- from dataclasses import dataclass, field
4
- from typing import Dict, List, Set, Optional
 
5
 
6
  @dataclass
7
  class PromptTemplate:
8
- """
9
- A template class for managing and validating LLM prompts.
10
-
11
- This class handles:
12
- - Storing system and user prompts
13
- - Validating required template variables
14
- - Formatting prompts with provided variables
15
-
16
- Attributes:
17
- system_prompt (str): The system-level instructions for the LLM
18
- user_template (str): Template string with variables in {variable} format
19
- """
20
- system_prompt: str
21
- user_template: str
22
-
23
- def __post_init__(self):
24
- """Initialize the set of required variables from the template."""
25
- self.required_variables: Set[str] = self._get_required_variables()
26
-
27
- def _get_required_variables(self) -> set:
28
- """
29
- Extract required variables from the template using regex.
30
-
31
- Returns:
32
- set: Set of variable names found in the template
33
-
34
- Example:
35
- Template "Write about {topic} in {style}" returns {'topic', 'style'}
36
- """
37
- return set(re.findall(r'\{(\w+)\}', self.user_template))
38
-
39
- def _validate_variables(self, provided_vars: Dict):
40
- """
41
- Ensure all required template variables are provided.
42
-
43
- Args:
44
- provided_vars: Dictionary of variable names and values
45
-
46
- Raises:
47
- ValueError: If any required variables are missing
48
- """
49
- provided_keys = set(provided_vars.keys())
50
- missing_vars = self.required_variables - provided_keys
51
- if missing_vars:
52
- error_msg = (
53
- f"\nPrompt Template Error:\n"
54
- f"Missing required variables: {', '.join(missing_vars)}\n"
55
- f"Template requires: {', '.join(self.required_variables)}\n"
56
- f"You provided: {', '.join(provided_keys)}\n"
57
- f"Template string: '{self.user_template}'"
58
  )
59
- raise ValueError(error_msg)
60
-
61
- def format(self, **kwargs) -> List[Dict[str, str]]:
62
- """
63
- Format the prompt template with provided variables.
64
-
65
- Args:
66
- **kwargs: Key-value pairs for template variables
67
-
68
- Returns:
69
- List[Dict[str, str]]: Formatted messages ready for LLM API
70
-
71
- Example:
72
- template.format(topic="AI", style="academic")
73
- """
74
- self._validate_variables(kwargs)
75
-
76
- try:
77
- formatted_user_message = self.user_template.format(**kwargs)
78
- except Exception as e:
79
- raise ValueError(f"Error formatting template: {str(e)}")
80
-
81
- return [
82
- {"role": "system", "content": self.system_prompt},
83
- {"role": "user", "content": formatted_user_message}
84
- ]
85
-
86
-
87
- def load_prompt(yaml_path: str, version: str = None) -> tuple[PromptTemplate, dict]:
88
- """
89
- Load prompt configuration from YAML file.
90
-
91
- Args:
92
- yaml_path: Path to YAML configuration file
93
- version: Specific version to load (defaults to 'current_version')
94
-
95
- Returns:
96
- tuple: (PromptTemplate instance, generation parameters dictionary)
97
-
98
- Example:
99
- prompt, params = load_prompt('prompts.yaml', version='v2')
100
- """
101
- with open(yaml_path, 'r') as f:
102
- data = yaml.safe_load(f)
103
-
104
- # Use specified version or fall back to current_version
105
- version_to_use = version or data.get('current_version')
106
- if version_to_use not in data:
107
- raise KeyError(f"Version '{version_to_use}' not found in {yaml_path}")
108
-
109
- version_data = data[version_to_use]
110
-
111
- prompt = PromptTemplate(
112
- system_prompt=version_data['system_prompt'],
113
- user_template=version_data['user_template']
114
- )
115
-
116
- generation_params = version_data.get('generation_params', {})
117
-
118
- return prompt, generation_params
 
 
1
  import yaml
2
+ from typing import Dict, Any, Optional
3
+ from openai import OpenAI
4
+ from dataclasses import dataclass
5
 
6
  @dataclass
7
  class PromptTemplate:
8
+ """A simple prompt template that handles parameter substitution and generation settings."""
9
+ template: str
10
+ parameters: Dict[str, Any]
11
+
12
+ def format(self, **kwargs) -> str:
13
+ return self.template.format(**kwargs)
14
+
15
+ class PromptLoader:
16
+ """Loads prompts and their parameters from a YAML file."""
17
+
18
+ @staticmethod
19
+ def load_prompts(yaml_path: str) -> Dict[str, PromptTemplate]:
20
+ with open(yaml_path, 'r') as f:
21
+ data = yaml.safe_load(f)
22
+
23
+ prompts = {}
24
+ for name, content in data.items():
25
+ prompts[name] = PromptTemplate(
26
+ template=content['template'],
27
+ parameters=content.get('parameters', {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  )
29
+ return prompts