|
import logging |
|
import os |
|
import json |
|
from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL |
|
from typing import Optional |
|
|
|
|
|
|
|
# Database handle constructed without a path (None); init_db() later points
# it at the cache file via db.init().
db = SqliteDatabase(None)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class _TranslationCache(Model):
    """Peewee model storing one cached translation per row.

    The table-level constraint declares the combination of
    ``translate_engine``, ``translate_engine_params`` and ``original_text``
    as UNIQUE with ``ON CONFLICT REPLACE``.
    """

    id = AutoField()
    # Engine name; the column is declared with max_length=20.
    translate_engine = CharField(max_length=20)
    # Engine parameters as text (TranslationCache writes a JSON string here).
    translate_engine_params = TextField()
    # Source text that was translated.
    original_text = TextField()
    # Cached translation of original_text.
    translation = TextField()

    class Meta:
        # Bound to the module-level database handle.
        database = db
        constraints = [
            SQL(
                """
            UNIQUE (
                translate_engine,
                translate_engine_params,
                original_text
            )
            ON CONFLICT REPLACE
            """
            )
        ]
|
|
|
|
|
class TranslationCache:
    """Translation cache scoped to one engine and one parameter set.

    Lookups and inserts go through the ``_TranslationCache`` table and are
    keyed by ``(translate_engine, translate_engine_params, original_text)``,
    where ``translate_engine_params`` is the JSON text of the parameters
    with dictionary keys sorted at every nesting level.

    Attributes:
        translate_engine: Engine name used as part of the cache key.
        params: This instance's own copy of the current parameter dict.
        translate_engine_params: JSON serialization of ``params``.
    """

    @staticmethod
    def _sort_dict_recursively(obj):
        """Return ``obj`` with every nested dict rebuilt in sorted key order.

        Lists are rebuilt element by element (their order is kept); any
        other value is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                key: TranslationCache._sort_dict_recursively(obj[key])
                for key in sorted(obj)
            }
        if isinstance(obj, list):
            return [TranslationCache._sort_dict_recursively(item) for item in obj]
        return obj

    def __init__(
        self,
        translate_engine: str,
        translate_engine_params: Optional[dict] = None,
    ):
        """Create a cache view for ``translate_engine`` and its parameters.

        Args:
            translate_engine: Engine name; must be shorter than 20 characters.
            translate_engine_params: Initial parameters; ``None`` means ``{}``.

        Raises:
            AssertionError: If the engine name has 20 or more characters.
        """
        # Left as an assert so the exception type callers may rely on does
        # not change. NOTE: asserts are skipped when Python runs with -O.
        assert (
            len(translate_engine) < 20
        ), "current cache require translate engine name less than 20 characters"
        self.translate_engine = translate_engine
        self.replace_params(translate_engine_params)

    def replace_params(self, params: Optional[dict] = None):
        """Replace all parameters and recompute the serialized cache key.

        Args:
            params: New parameters; ``None`` means ``{}``. The dict is
                shallow-copied, so later ``update_params`` / ``add_params``
                calls never modify the caller's object.

        Raises:
            TypeError: If a value cannot be serialized by ``json.dumps``
                (or keys cannot be ordered); the previous state is kept.
        """
        new_params = {} if params is None else dict(params)
        # Serialize before assigning so a failure cannot leave ``params``
        # and ``translate_engine_params`` out of sync.
        serialized = json.dumps(self._sort_dict_recursively(new_params))
        self.params = new_params
        self.translate_engine_params = serialized

    def update_params(self, params: Optional[dict] = None):
        """Merge ``params`` into the current parameters (``None`` is a no-op merge)."""
        if params is None:
            params = {}
        merged = dict(self.params)
        merged.update(params)
        self.replace_params(merged)

    def add_params(self, k: str, v):
        """Set the single parameter ``k`` to ``v`` and refresh the cache key."""
        merged = dict(self.params)
        merged[k] = v
        self.replace_params(merged)

    def get(self, original_text: str) -> Optional[str]:
        """Return the cached translation of ``original_text``, or ``None``."""
        result = _TranslationCache.get_or_none(
            translate_engine=self.translate_engine,
            translate_engine_params=self.translate_engine_params,
            original_text=original_text,
        )
        return result.translation if result else None

    def set(self, original_text: str, translation: str):
        """Store ``translation`` for ``original_text`` under the current key.

        Caching is best-effort: any error raised while writing the row is
        logged at debug level and not propagated.
        """
        try:
            _TranslationCache.create(
                translate_engine=self.translate_engine,
                translate_engine_params=self.translate_engine_params,
                original_text=original_text,
                translation=translation,
            )
        except Exception as e:
            # Deliberately broad: a failed cache write must not abort the
            # translation that produced the value.
            logger.debug("Error setting cache: %s", e)
|
|
|
|
|
def init_db(remove_exists=False):
    """Point the module-level ``db`` at the on-disk cache and create its table.

    The database file is ``~/.cache/pdf2zh/cache.v1.db``; the folder is
    created when it does not exist yet.

    Args:
        remove_exists: When true, delete an existing cache database (and its
            WAL sidecar files) before opening it, so the cache starts empty.
    """
    cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh")
    os.makedirs(cache_folder, exist_ok=True)

    cache_db_path = os.path.join(cache_folder, "cache.v1.db")
    if remove_exists:
        # Close an open connection before deleting the files it points at.
        if not db.is_closed():
            db.close()
        # The database runs in WAL mode, so "-wal" / "-shm" files may sit
        # next to it; remove them together with the main file (the same set
        # clean_test_db deletes) so no stale log is paired with a new file.
        for suffix in ("", "-wal", "-shm"):
            try:
                os.remove(cache_db_path + suffix)
            except FileNotFoundError:
                pass
    db.init(
        cache_db_path,
        pragmas={
            "journal_mode": "wal",
            "busy_timeout": 1000,
        },
    )
    db.create_tables([_TranslationCache], safe=True)
|
|
|
|
|
def init_test_db():
    """Create a temporary SQLite database and bind ``_TranslationCache`` to it.

    Returns:
        The connected ``SqliteDatabase`` with the cache table created. Pass
        it to ``clean_test_db`` to drop the table and delete the files.
    """
    import tempfile

    # mkstemp creates the file itself and returns an open descriptor, unlike
    # the deprecated, race-prone mktemp, which only returns an unused name.
    # SQLite opens the path on its own, so the descriptor is closed here.
    fd, cache_db_path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    test_db = SqliteDatabase(
        cache_db_path,
        pragmas={
            "journal_mode": "wal",
            "busy_timeout": 1000,
        },
    )
    test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
    test_db.connect()
    test_db.create_tables([_TranslationCache], safe=True)
    return test_db
|
|
|
|
|
def clean_test_db(test_db):
    """Drop the cache table, close ``test_db`` and delete its files.

    Removes the database file plus the ``-wal`` and ``-shm`` files next to
    it, each only if it exists.

    Args:
        test_db: Database object exposing ``drop_tables``, ``close`` and a
            ``database`` path attribute (as returned by ``init_test_db``).
    """
    test_db.drop_tables([_TranslationCache])
    test_db.close()

    base_path = test_db.database
    # Same order as before: main file first, then the two sidecar files.
    for suffix in ("", "-wal", "-shm"):
        target = base_path + suffix if suffix else base_path
        if os.path.exists(target):
            os.remove(target)
|
|
|
|
|
# Runs at import time: binds the module-level database to the cache file
# and creates the cache table if it does not exist.
init_db()
|
|