File size: 3,727 Bytes
f4a41d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
"""This module provides methods to save and load checkpoints and prompts in a JSON file."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
from scripts.Logger import Logger
import json
from typing import Dict, List, Tuple
class Save():
    """Save and load named pairs of checkpoints and prompts in a JSON file.

    Every save is stored under its name as a ``(checkpoints, prompts)`` pair
    inside ``self.file_name``. The file name is relative, so it is resolved
    against the current working directory of the process.
    """

    def __init__(self) -> None:
        # Name of the JSON file that holds all saves.
        self.file_name = "batchCheckpointPromptValues.json"
        self.logger = Logger()
        self.logger.debug = False

    def _load_saves(self) -> Dict[str, Tuple[str, str]]:
        """Parse the JSON file and return its content unchanged.

        Raises:
            FileNotFoundError: if ``self.file_name`` does not exist.
            json.JSONDecodeError: if the file does not contain valid JSON.
        """
        # Files written by store_values are ASCII-only (json.dump escapes
        # non-ASCII characters by default), so an explicit utf-8 decode reads
        # them identically on every platform.
        with open(self.file_name, 'r', encoding="utf-8") as f:
            return json.load(f)

    def read_file(self) -> Dict[str, Tuple[str, str]]:
        """Read the JSON file and return the data

        Returns:
            Dict[str, Tuple[str, str]]: the data from the JSON file.
            The key is the name of the save and the value is the pair of
            checkpoints and prompts. JSON has no tuple type, so pairs loaded
            from disk are two-element lists. When the file does not exist a
            single placeholder entry ``{"None": ("", "")}`` is returned.
        """
        try:
            return self._load_saves()
        except FileNotFoundError:
            return {"None": ("", "")}

    def store_values(self, name: str, checkpoints: str, prompts: str, overwrite_existing_save: bool, append_existing_save: bool) -> str:
        """Store the checkpoints and prompts in a JSON file

        Args:
            name (str): the name of the save
            checkpoints (str): the checkpoints
            prompts (str): the prompts
            overwrite_existing_save (bool): if True, overwrite an existing save with the same name
            append_existing_save (bool): if True, append to an existing save with the same name.
                The stored checkpoints are joined with ``",\\n"`` and the stored prompts
                with ``";\\n"``. If no save with that name exists yet, the values are
                stored as a new save instead.

        Returns:
            str: a message that indicates if the save was successful
        """
        # Load the existing saves once; a missing file simply means "no saves
        # yet". Loading directly (instead of exists() + read_file()) keeps the
        # "None" placeholder of read_file out of the written file.
        try:
            data = self._load_saves()
        except FileNotFoundError:
            data = {}

        name_exists = name in data
        if name_exists and not overwrite_existing_save and not append_existing_save:
            self.logger.log_info("Name already exists")
            return f'Name "{name}" already exists'

        # Appending only makes sense when there is something to append to;
        # otherwise fall through to a plain save instead of failing on the
        # lookup of a name that is not there.
        appended = append_existing_save and name_exists
        if appended:
            self.logger.debug_log(f"Name: {name}")
            old_checkpoints, old_prompts = tuple(data[name])
            self.logger.pretty_debug_log((old_checkpoints, old_prompts))
            checkpoints = ",\n".join([old_checkpoints, checkpoints])
            prompts = ";\n".join([old_prompts, prompts])

        data[name] = (checkpoints, prompts)

        # The whole dictionary is rewritten on every save.
        with open(self.file_name, 'w', encoding="utf-8") as f:
            json.dump(data, f)
        self.logger.log_info("saved checkpoints and Prompts")

        if appended:
            return f'Appended "{name}"'
        if overwrite_existing_save:
            # Message is kept flag-based: it is also returned when the name
            # did not exist before, exactly as callers have always seen it.
            return f'Overwrote "{name}"'
        return f'Saved "{name}"'

    def read_value(self, name: str) -> Tuple[str, str]:
        """Get the checkpoints and prompts from a save

        Args:
            name (str): the name of the save

        Returns:
            Tuple[str, str]: the checkpoints and prompts

        Raises:
            RuntimeError: if the JSON file does not exist.
            KeyError: if the file contains no save called ``name``.
        """
        try:
            data = self._load_saves()
        except FileNotFoundError as err:
            raise RuntimeError("no save file found") from err
        checkpoints, prompts = tuple(data[name])
        self.logger.log_info("loaded save")
        return checkpoints, prompts

    def get_keys(self) -> List[str]:
        """Get the keys from the JSON file

        Returns:
            List[str]: a list of save names; ``["None"]`` when the file does not exist
        """
        return list(self.read_file())
|