Spaces:
Runtime error
Runtime error
from sqlalchemy.orm import Session | |
from typing import List, Tuple | |
from tqdm import tqdm | |
from database.operation import * | |
from database import schema | |
import random | |
from loguru import logger | |
import math | |
# 记单词 (word memorization)
from story_agent import generate_story_and_translated_story | |
from common.util import date_str, multiprocessing_mapping | |
def get_words_for_book(db: Session, user_book: UserBook) -> List[schema.Word]:
    """Return the words belonging to the book referenced by ``user_book``.

    Words are found by joining ``Unit`` rows (``Unit.bv_voc_id == Word.vc_id``)
    whose ``bv_book_id`` equals the book's ``bk_id``, ordered by ascending
    ``vc_difficulty``.

    Returns:
        The matching words, or an empty list (after logging a warning) when
        ``get_book`` finds no book for ``user_book.book_id``.
    """
    target_book = get_book(db, user_book.book_id)
    if target_book is not None:
        unit_matches_word = schema.Unit.bv_voc_id == schema.Word.vc_id
        return (
            db.query(schema.Word)
            .join(schema.Unit, unit_matches_word)
            .filter(schema.Unit.bv_book_id == target_book.bk_id)
            .order_by(schema.Word.vc_difficulty)
            .all()
        )
    logger.warning("book not found")
    return []
def save_words_as_book(db: Session, user_id: str, words: List[schema.Word], title: str,
                       commit_every: int = 500):
    """Save ``words`` as a new word book created by ``user_id`` and return the book.

    The book is created with ``create_book``. Its name is ``title`` plus a fixed
    suffix, and its item count is ``len(words)``. One ``schema.Unit`` linking the
    book to each word is then added to the session. The session is committed
    every ``commit_every`` added units and once more at the end.

    Args:
        db: Database session.
        user_id: Stored as the book's ``creator``.
        words: Words to put in the book, in order.
        title: Prefix of the book name.
        commit_every: Commit after every this many added units. A value <= 0
            disables the intermediate commits; only the final commit happens.

    Returns:
        The book returned by ``create_book``.
    """
    # NOTE(review): the suffix literal appears mis-encoded in this copy (UTF-8
    # Chinese read as Thai); it presumably reads "（待学单词自动保存为单词书）"
    # ("to-learn words auto-saved as a word book"). Confirm against the
    # repository copy before changing it.
    book = create_book(db, BookCreate(bk_name=f"{title}๏ผๅพ ๅญฆๅ่ฏ่ชๅจไฟๅญไธบๅ่ฏไนฆ๏ผ", bk_item_num=len(words), creator=user_id))
    # Wrap the list itself (not an ``enumerate`` object) so tqdm knows the total.
    for added, word in enumerate(tqdm(words), start=1):
        unit = UnitCreate(bv_book_id=book.bk_id, bv_voc_id=word.vc_id)
        db.add(schema.Unit(**unit.dict()))
        # Commit in chunks, presumably to keep the pending set small for big books.
        if commit_every > 0 and added % commit_every == 0:
            db.commit()
    db.commit()
    return book
def save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    """Generate a story for ``batch_words`` and save them as one memory batch.

    The story and its translation come from ``generate_story_and_translated_story``.
    Persistence is delegated to ``save_batch_words_with_story``, whose created
    batch is returned. ``i`` is only forwarded (used there for logging).
    """
    # Intended design, translated from the original comment: only the first
    # batch's story is generated up front; later batches are generated from the
    # user's memory state, a few (about 3) batches at a time.
    # NOTE(review): ``track`` currently pre-generates all remaining batches via
    # ``save_batch_words_list`` — confirm which behavior is intended.
    story, translated_story = generate_story_and_translated_story(
        [w.vc_vocabulary for w in batch_words]
    )
    return save_batch_words_with_story(db, i, user_book_id, batch_words, story, translated_story)
def save_batch_words_with_story(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word], story: str, translated_story: str):
    """Persist a batch whose story has already been generated.

    The function proceeds as follows:

    - Log ``i``, the batch's vocabulary and the story.
    - Create the batch through ``create_user_memory_batch``.
    - Record the story in the batch's generation history.
    - Add one ``schema.UserMemoryWord`` per word.
    - Commit the session.

    Returns:
        The batch returned by ``create_user_memory_batch``.
    """
    vocab = [w.vc_vocabulary for w in batch_words]
    logger.info(f"{i}, {vocab}\n{story}")

    batch_request = UserMemoryBatchCreate(
        user_book_id=user_book_id,
        story=story,
        translated_story=translated_story
    )
    batch = create_user_memory_batch(db, batch_request)

    history_request = UserMemoryBatchGenerationHistoryCreate(
        batch_id=batch.id,
        story=story,
        translated_story=translated_story
    )
    create_user_memory_batch_generation_history(db, history_request)

    # Built lazily so each row is constructed and added in turn.
    pending_rows = (
        schema.UserMemoryWord(**UserMemoryWordCreate(batch_id=batch.id, word_id=w.vc_id).dict())
        for w in batch_words
    )
    for row in pending_rows:
        db.add(row)
    db.commit()
    return batch
async def async_save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    """Coroutine wrapper around ``save_batch_words``.

    Nothing is awaited: the synchronous call runs to completion inside the
    coroutine, and its created batch is discarded (the coroutine yields ``None``).
    """
    save_batch_words(db=db, i=i, user_book_id=user_book_id, batch_words=batch_words)
    return None
import asyncio | |
async def async_save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    """Save every batch in ``batch_words_list`` via ``async_save_batch_words`` and wait for all of them.

    Batch ``n`` (0-based) of ``batch_words_list`` is saved with index ``n + 1``.
    All saves share the same ``db`` session.
    """
    # Keep references to the tasks and await them. Unreferenced tasks may be
    # garbage-collected mid-execution. Also, ``asyncio.run`` cancels any task
    # still pending when this coroutine returns, so fire-and-forget would save
    # nothing under ``asyncio.run``.
    tasks = [
        asyncio.ensure_future(async_save_batch_words(db, n + 1, user_book_id, batch_words))
        for n, batch_words in enumerate(batch_words_list)
    ]
    if tasks:
        await asyncio.gather(*tasks)
def transform(batch_words: List[str]):
    """Generate a story for one batch of vocabulary strings.

    Used as the worker function for ``multiprocessing_mapping`` in
    ``save_batch_words_list``.

    Returns:
        A dict with keys ``"story"``, ``"translated_story"`` and ``"words"``
        (the input list itself).
    """
    generated_story, translated = generate_story_and_translated_story(batch_words)
    return dict(story=generated_story, translated_story=translated, words=batch_words)
def save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    """Generate stories for many batches in parallel, then save each batch with its story.

    Stories are produced by ``multiprocessing_mapping(transform, ...)``, one per
    batch. Each batch is then persisted with ``save_batch_words_with_story``,
    using its 0-based position as the log index.

    Batches without a matching story (when fewer stories come back than batches
    were sent) are not saved, and a warning is logged. An empty
    ``batch_words_list`` is a no-op.
    """
    if not batch_words_list:
        # Nothing to generate (e.g. ``track`` with a single batch); skip the
        # multiprocessing helper entirely.
        return
    word_str_list = [[word.vc_vocabulary for word in batch_words] for batch_words in batch_words_list]
    # NOTE(review): if ``date_str`` is a function rather than a string, this
    # embeds its repr in the file path — confirm against common.util.
    story_list = multiprocessing_mapping(transform, word_str_list, tmp_filepath=f"./output/logs/save_batch_words_list_{date_str}.xlsx")
    logger.info(f"story_list: {len(story_list)}")
    if len(story_list) != len(batch_words_list):
        # zip() below silently drops unmatched batches; make that visible.
        logger.warning(f"expected {len(batch_words_list)} stories, got {len(story_list)}; unmatched batches are not saved")
    pairs = enumerate(zip(batch_words_list, story_list))
    for i, (batch_words, story) in tqdm(pairs, total=min(len(batch_words_list), len(story_list))):
        save_batch_words_with_story(db, i, user_book_id, batch_words, story['story'], story['translated_story'])
def track(db: Session, user_book: schema.UserBook, words: List[schema.Word]):
    """Prepare ``words`` for memorization under ``user_book``.

    Steps:
      1. Reorder ``words`` in place: shuffled when ``user_book.random`` is
         truthy, otherwise sorted by descending ``vc_frequency``.
      2. Save them as a new word book via ``save_words_as_book``.
      3. Split them into consecutive batches of ``user_book.batch_size`` words.
      4. Save the first batch, set ``user_book.memorizing_batch`` to its id and
         commit. Then save the remaining batches with ``save_batch_words_list``.

    Nothing beyond step 2 happens when ``words`` is empty.

    Raises:
        ValueError: If ``user_book.batch_size`` is missing or not positive. This
            is checked before anything is written, so no book is created for an
            unusable batch size.
    """
    batch_size = user_book.batch_size
    logger.debug(f"{[w.vc_vocabulary for w in words]}")
    logger.debug(f"batch_size: {batch_size}")
    logger.debug(f"words count: {len(words)}")
    # range() below would raise (0/None) or yield nothing (<0) only after the
    # book had already been saved; fail fast instead.
    if not batch_size or batch_size < 1:
        raise ValueError(f"user_book.batch_size must be a positive integer, got {batch_size!r}")

    if user_book.random:
        random.shuffle(words)
    else:
        # Author's rationale: more frequent words are easier to remember, so they come first.
        words.sort(key=lambda x: x.vc_frequency, reverse=True)

    logger.debug("saving words as book")
    save_words_as_book(db, user_book.owner_id, words, user_book.title)
    logger.debug(f"saved words as book [{user_book.title}]")

    batch_words_list = [words[start:start + batch_size] for start in range(0, len(words), batch_size)]
    logger.debug(f"batch_words_list: {len(batch_words_list)}")
    if not batch_words_list:
        return

    user_memory_batch = save_batch_words(db, 0, user_book.id, batch_words_list[0])
    user_book.memorizing_batch = user_memory_batch.id
    db.commit()
    save_batch_words_list(db, user_book.id, batch_words_list[1:])
def remenber(db: Session, batch_id: str, word_id: str):
    """Record a ``"remember"`` action for ``word_id`` in batch ``batch_id``.

    The function name is misspelled but kept for existing callers. The
    correctly spelled ``remember`` alias defined right after it refers to the
    same function.

    Returns:
        Whatever ``create_user_memory_action`` returns.
    """
    return create_user_memory_action(db, UserMemoryActionCreate(
        batch_id=batch_id,
        word_id=word_id,
        action="remember"
    ))


# Correctly spelled alias for ``remenber``.
remember = remenber
def forget(db: Session, batch_id: str, word_id: str):
    """Record a ``"forget"`` action for ``word_id`` in batch ``batch_id``.

    Returns:
        Whatever ``create_user_memory_action`` returns.
    """
    forget_action = UserMemoryActionCreate(batch_id=batch_id, word_id=word_id, action="forget")
    return create_user_memory_action(db, forget_action)
def save_memorizing_word_action(db: Session, batch_id: str, actions: List[Tuple[str, str]]):
    """Record several word actions for one batch, committing once at the end.

    Args:
        db: Database session.
        batch_id: Batch the actions belong to.
        actions: ``(word_id, action)`` pairs, where ``action`` is expected to be
            ``"remember"`` or ``"forget"``. The values are passed through
            without validation.
    """
    for wid, kind in actions:
        payload = UserMemoryActionCreate(batch_id=batch_id, word_id=wid, action=kind)
        db.add(schema.UserMemoryAction(**payload.dict()))
    db.commit()
def on_batch_start(db: Session, user_memory_batch_id: str):
    """Record a ``"start"`` action for batch ``user_memory_batch_id``.

    Returns:
        Whatever ``create_user_memory_batch_action`` returns.
    """
    start_event = UserMemoryBatchActionCreate(batch_id=user_memory_batch_id, action="start")
    return create_user_memory_batch_action(db, start_event)
def on_batch_end(db: Session, user_memory_batch_id: str):
    """Record an ``"end"`` action for batch ``user_memory_batch_id``.

    Returns:
        Whatever ``create_user_memory_batch_action`` returns.
    """
    end_event = UserMemoryBatchActionCreate(batch_id=user_memory_batch_id, action="end")
    return create_user_memory_batch_action(db, end_event)
# def generate_recall_batch(db: Session, user_book: schema.UserBook): | |
def generate_next_batch(db: Session, user_book: schema.UserBook,
                        minutes: int = 60, k: int = 3,
                        left_bound: float = 0.3, right_bound: float = 0.6):
    """Create the next review or recall batch for ``user_book``, or return ``None``.

    ``None`` means the next batch should be a new-word batch.

    Each word the user memorized in the last ``minutes`` minutes gets a memory
    efficiency score, ``remembers / (remembers + forgets)``, computed from
    ``get_actions_at_each_word``. Words with no remember/forget action yet
    (e.g. from a batch that was generated but not studied) have no score and
    are left out of the classification.

    Decision order:
      1. Fewer than ``k * batch_size`` recent words: return ``None``.
      2. Every scored word has efficiency > ``right_bound`` and was remembered
         more than 3 times: return ``None``.
      3. At least ``batch_size`` words with efficiency < ``left_bound``: return
         a review batch of the ``batch_size`` lowest-efficiency such words.
      4. At least ``batch_size`` unfamiliar words: return a recall batch of the
         ``batch_size`` lowest-efficiency ones. Unfamiliar means
         ``left_bound <= efficiency < right_bound``; when there are too few,
         the list is topped up with words remembered fewer than 3 times.
      5. Otherwise: return ``None``.

    Words inside a created batch are ordered by descending ``vc_difficulty``.

    Args:
        db: Database session.
        user_book: The user's book; its ``id`` and ``batch_size`` are used.
        minutes: Look-back window for ``get_user_memory_word_history_in_minutes``.
        k: Minimum number of batches' worth of recent words before review or
            recall is considered.
        left_bound: Efficiency below which a word counts as forgotten.
        right_bound: Efficiency from which a word counts as well remembered.

    Returns:
        The created user memory batch, or ``None`` for a new-word batch.
    """
    # NOTE(review): the Chinese log/batch_type literals below appear mis-encoded
    # in this copy (UTF-8 read as Thai). They presumably read 新词批 (new-word
    # batch), 复习批/复习 (review) and 回忆批/回忆 (recall). They are kept
    # byte-for-byte because batch_type is handed to create_user_memory_batch;
    # confirm against the repository copy.
    from collections import defaultdict

    user_book_id = user_book.id
    batch_size = user_book.batch_size

    memorizing_words = get_user_memory_word_history_in_minutes(db, user_book_id, minutes)
    if len(memorizing_words) < k * batch_size:
        # Too few words memorized recently: move on to new words.
        logger.info("ๆฐ่ฏๆน")
        return None

    # Count remember / forget actions per word.
    memory_actions = get_actions_at_each_word(db, [w.vc_id for w in memorizing_words])
    remember_count = defaultdict(int)
    forget_count = defaultdict(int)
    for a in memory_actions:
        if a.action == "remember":
            remember_count[a.word_id] += 1
        else:
            forget_count[a.word_id] += 1

    # Efficiency per studied word. Words without any action are skipped instead
    # of dividing by zero.
    id_to_word = {}
    word_id_to_efficiency = {}
    for word in memorizing_words:
        id_to_word[word.vc_id] = word
        remembered = remember_count.get(word.vc_id, 0)
        total = remembered + forget_count.get(word.vc_id, 0)
        if total:
            word_id_to_efficiency[word.vc_id] = remembered / total
    # Remember counts for scored words only (0 for words that were only ever forgotten).
    scored_remember_count = {wid: remember_count.get(wid, 0) for wid in word_id_to_efficiency}

    logger.info(sorted(
        ((w.vc_vocabulary, word_id_to_efficiency[w.vc_id])
         for w in memorizing_words if w.vc_id in word_id_to_efficiency),
        key=lambda x: x[1],
    ))

    if (all(e > right_bound for e in word_id_to_efficiency.values())
            and all(c > 3 for c in scored_remember_count.values())):
        # Memory efficiency is high across the board: move on to new words.
        logger.info("ๆฐ่ฏๆน")
        return None

    def by_efficiency(word_id):
        return word_id_to_efficiency[word_id]

    forgot_word_ids = sorted(
        (wid for wid, e in word_id_to_efficiency.items() if e < left_bound),
        key=by_efficiency,
    )
    if len(forgot_word_ids) >= batch_size:
        # Enough forgotten words: review batch made of the least-remembered ones.
        logger.info("ๅคไน ๆน")
        batch_words = _pick_batch_words(forgot_word_ids, id_to_word, batch_size)
        return _save_typed_batch(db, user_book_id, batch_words, "ๅคไน ")

    unfamiliar_word_ids = sorted(
        (wid for wid, e in word_id_to_efficiency.items() if left_bound <= e < right_bound),
        key=by_efficiency,
    )
    if len(unfamiliar_word_ids) < batch_size:
        # Also include words remembered fewer than 3 times, without repeating ids
        # (duplicates would inflate the count checked below).
        already = set(unfamiliar_word_ids)
        unfamiliar_word_ids += [wid for wid, c in scored_remember_count.items()
                                if c < 3 and wid not in already]
        unfamiliar_word_ids.sort(key=by_efficiency)
    if len(unfamiliar_word_ids) >= batch_size:
        # Memory efficiency is low: recall batch.
        logger.info("ๅๅฟๆน")
        batch_words = _pick_batch_words(unfamiliar_word_ids, id_to_word, batch_size)
        return _save_typed_batch(db, user_book_id, batch_words, "ๅๅฟ")

    # Normal case: move on to new words.
    logger.info("ๆฐ่ฏๆน")
    return None


def _pick_batch_words(word_ids, id_to_word, batch_size):
    """Return the words for the first ``batch_size`` of ``word_ids``, ordered by descending ``vc_difficulty``."""
    picked = [id_to_word[wid] for wid in word_ids[:batch_size]]
    picked.sort(key=lambda x: x.vc_difficulty, reverse=True)
    return picked


def _save_typed_batch(db: Session, user_book_id: str, batch_words: List[schema.Word], batch_type: str):
    """Generate a story for ``batch_words``, save it as a batch of ``batch_type``, and return the batch."""
    batch_words_str_list = [word.vc_vocabulary for word in batch_words]
    story, translated_story = generate_story_and_translated_story(batch_words_str_list)
    user_memory_batch = create_user_memory_batch(db, UserMemoryBatchCreate(
        user_book_id=user_book_id,
        story=story,
        translated_story=translated_story,
        batch_type=batch_type,
    ))
    create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
        batch_id=user_memory_batch.id,
        story=story,
        translated_story=translated_story
    ))
    for word in batch_words:
        memory_word = UserMemoryWordCreate(
            batch_id=user_memory_batch.id,
            word_id=word.vc_id
        )
        db.add(schema.UserMemoryWord(**memory_word.dict()))
    db.commit()
    return user_memory_batch