|
"""This module provides methods to save and load checkpoints and prompts in a JSON file.""" |
|
import sys |
|
import os |
|
|
|
# Prepend the "scripts" directory located next to this file to the module
# search path. NOTE(review): presumably this lets modules inside scripts/
# import one another by bare name — confirm before removing.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
|
from scripts.Logger import Logger |
|
|
|
import json |
|
from typing import Dict, List, Tuple |
|
|
|
class Save:
    """Save and load named checkpoint/prompt pairs in a JSON file.

    Each entry in the JSON file maps a save name to a two-element
    ``[checkpoints, prompts]`` pair of strings.
    """

    def __init__(self, file_name: str = "batchCheckpointPromptValues.json") -> None:
        """Create a store backed by ``file_name``.

        Args:
            file_name (str): path of the JSON file. A relative path is resolved
                against the current working directory at each read/write.
        """
        self.file_name = file_name
        self.logger = Logger()
        self.logger.debug = False

    def read_file(self) -> Dict[str, Tuple[str, str]]:
        """Read the JSON file and return its contents.

        Returns:
            Dict[str, Tuple[str, str]]: the key is the name of the save and the
            value is its (checkpoints, prompts) pair. After the JSON round trip
            the pair is a two-element list rather than a tuple.
            If the file does not exist, the placeholder ``{"None": ("", "")}``
            is returned instead.

        Raises:
            json.JSONDecodeError: if the file exists but is not valid JSON.
        """
        try:
            with open(self.file_name, 'r', encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return {"None": ("", "")}

    def store_values(
        self,
        name: str,
        checkpoints: str,
        prompts: str,
        overwrite_existing_save: bool,
        append_existing_save: bool,
    ) -> str:
        """Store the checkpoints and prompts under ``name`` in the JSON file.

        Args:
            name (str): the name of the save
            checkpoints (str): the checkpoints
            prompts (str): the prompts
            overwrite_existing_save (bool): if True, overwrite an existing save
                with the same name
            append_existing_save (bool): if True, append to an existing save
                with the same name (checkpoints joined with ",\\n", prompts with
                ";\\n"). If no save with that name exists yet, the values are
                saved as a new entry. Takes precedence over overwriting.

        Returns:
            str: a message that indicates what was done, or that the name
            already exists and nothing was written.
        """
        # Start from an empty dict (not read_file's "None" placeholder) when
        # there is no file yet, so the placeholder is never persisted.
        data = self.read_file() if os.path.exists(self.file_name) else {}

        if name in data and not overwrite_existing_save and not append_existing_save:
            self.logger.log_info("Name already exists")
            return f'Name "{name}" already exists'

        # Only append when there is something to append to; previously a
        # missing entry raised KeyError/RuntimeError and nothing was saved.
        appended = append_existing_save and name in data
        if appended:
            self.logger.debug_log(f"Name: {name}")
            old_checkpoints, old_prompts = data[name]
            self.logger.log_info("loaded save")
            self.logger.pretty_debug_log((old_checkpoints, old_prompts))
            checkpoints = ",\n".join([old_checkpoints, checkpoints])
            prompts = ";\n".join([old_prompts, prompts])

        data[name] = (checkpoints, prompts)

        with open(self.file_name, 'w', encoding="utf-8") as f:
            json.dump(data, f)

        self.logger.log_info("saved checkpoints and Prompts")
        if appended:
            return f'Appended "{name}"'
        if overwrite_existing_save:
            return f'Overwrote "{name}"'
        return f'Saved "{name}"'

    def read_value(self, name: str) -> Tuple[str, str]:
        """Get the checkpoints and prompts stored under ``name``.

        Args:
            name (str): the name of the save

        Returns:
            Tuple[str, str]: the checkpoints and prompts

        Raises:
            RuntimeError: if the save file does not exist.
            KeyError: if there is no save called ``name``.
        """
        if not os.path.exists(self.file_name):
            raise RuntimeError("no save file found")

        checkpoints, prompts = self.read_file()[name]
        self.logger.log_info("loaded save")
        return checkpoints, prompts

    def get_keys(self) -> List[str]:
        """Get the save names from the JSON file.

        Returns:
            List[str]: a list of save names; ``["None"]`` when the file does
            not exist (see ``read_file``).
        """
        return list(self.read_file())
|
|