pz / pdf2zh /cache.py
github-actions[bot]
GitHub deploy: ec524f8885a8220ac96f04d1c0f21b3e1e69500e
1bf3e6e
import logging
import os
import json
from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL
from typing import Optional
# we don't init the database here
db = SqliteDatabase(None)
logger = logging.getLogger(__name__)
class _TranslationCache(Model):
id = AutoField()
translate_engine = CharField(max_length=20)
translate_engine_params = TextField()
original_text = TextField()
translation = TextField()
class Meta:
database = db
constraints = [
SQL(
"""
UNIQUE (
translate_engine,
translate_engine_params,
original_text
)
ON CONFLICT REPLACE
"""
)
]
class TranslationCache:
@staticmethod
def _sort_dict_recursively(obj):
if isinstance(obj, dict):
return {
k: TranslationCache._sort_dict_recursively(v)
for k in sorted(obj.keys())
for v in [obj[k]]
}
elif isinstance(obj, list):
return [TranslationCache._sort_dict_recursively(item) for item in obj]
return obj
def __init__(self, translate_engine: str, translate_engine_params: dict = None):
assert (
len(translate_engine) < 20
), "current cache require translate engine name less than 20 characters"
self.translate_engine = translate_engine
self.replace_params(translate_engine_params)
# The program typically starts multi-threaded translation
# only after cache parameters are fully configured,
# so thread safety doesn't need to be considered here.
def replace_params(self, params: dict = None):
if params is None:
params = {}
self.params = params
params = self._sort_dict_recursively(params)
self.translate_engine_params = json.dumps(params)
def update_params(self, params: dict = None):
if params is None:
params = {}
self.params.update(params)
self.replace_params(self.params)
def add_params(self, k: str, v):
self.params[k] = v
self.replace_params(self.params)
# Since peewee and the underlying sqlite are thread-safe,
# get and set operations don't need locks.
def get(self, original_text: str) -> Optional[str]:
result = _TranslationCache.get_or_none(
translate_engine=self.translate_engine,
translate_engine_params=self.translate_engine_params,
original_text=original_text,
)
return result.translation if result else None
def set(self, original_text: str, translation: str):
try:
_TranslationCache.create(
translate_engine=self.translate_engine,
translate_engine_params=self.translate_engine_params,
original_text=original_text,
translation=translation,
)
except Exception as e:
logger.debug(f"Error setting cache: {e}")
def init_db(remove_exists=False):
cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh")
os.makedirs(cache_folder, exist_ok=True)
# The current version does not support database migration, so add the version number to the file name.
cache_db_path = os.path.join(cache_folder, "cache.v1.db")
if remove_exists and os.path.exists(cache_db_path):
os.remove(cache_db_path)
db.init(
cache_db_path,
pragmas={
"journal_mode": "wal",
"busy_timeout": 1000,
},
)
db.create_tables([_TranslationCache], safe=True)
def init_test_db():
import tempfile
cache_db_path = tempfile.mktemp(suffix=".db")
test_db = SqliteDatabase(
cache_db_path,
pragmas={
"journal_mode": "wal",
"busy_timeout": 1000,
},
)
test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
test_db.connect()
test_db.create_tables([_TranslationCache], safe=True)
return test_db
def clean_test_db(test_db):
test_db.drop_tables([_TranslationCache])
test_db.close()
db_path = test_db.database
if os.path.exists(db_path):
os.remove(test_db.database)
wal_path = db_path + "-wal"
if os.path.exists(wal_path):
os.remove(wal_path)
shm_path = db_path + "-shm"
if os.path.exists(shm_path):
os.remove(shm_path)
init_db()