# Word memorization
import asyncio
import random
from collections import defaultdict
from typing import List, Tuple

from loguru import logger
from sqlalchemy.orm import Session
from tqdm import tqdm

from common.util import date_str, multiprocessing_mapping
from database import schema
from database.operation import *
from story_agent import generate_story_and_translated_story


def get_words_for_book(db: Session, user_book: UserBook) -> List[schema.Word]:
    book = get_book(db, user_book.book_id)
    if book is None:
        logger.warning("book not found")
        return []
    q = db.query(schema.Word).join(schema.Unit, schema.Unit.bv_voc_id == schema.Word.vc_id)
    words = q.filter(schema.Unit.bv_book_id == book.bk_id).order_by(schema.Word.vc_difficulty).all()
    return words


def save_words_as_book(db: Session, user_id: str, words: List[schema.Word], title: str):
    # The Chinese suffix means "(words to learn, auto-saved as a word book)".
    book = create_book(db, BookCreate(
        bk_name=f"{title}(待学单词自动保存为单词书)",
        bk_item_num=len(words),
        creator=user_id,
    ))
    for i, word in tqdm(enumerate(words)):
        unit = UnitCreate(bv_book_id=book.bk_id, bv_voc_id=word.vc_id)
        db_unit = schema.Unit(**unit.dict())
        db.add(db_unit)
        if i % 500 == 0:  # commit in chunks to keep transactions small
            db.commit()
    db.commit()
    return book


def save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    batch_words_str_list = [word.vc_vocabulary for word in batch_words]
    # Only the first batch generates its story here. Later batches generate
    # stories based on the user's memorization progress, three batches ahead.
    story, translated_story = generate_story_and_translated_story(batch_words_str_list)
    return save_batch_words_with_story(db, i, user_book_id, batch_words, story, translated_story)


def save_batch_words_with_story(db: Session, i: int, user_book_id: str,
                                batch_words: List[schema.Word],
                                story: str, translated_story: str):
    batch_words_str_list = [word.vc_vocabulary for word in batch_words]
    logger.info(f"{i}, {batch_words_str_list}\n{story}")
    user_memory_batch = create_user_memory_batch(db, UserMemoryBatchCreate(
        user_book_id=user_book_id,
        story=story,
        translated_story=translated_story
    ))
    create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
        batch_id=user_memory_batch.id,
        story=story,
        translated_story=translated_story
    ))
    for word in batch_words:
        memory_word = UserMemoryWordCreate(
            batch_id=user_memory_batch.id,
            word_id=word.vc_id
        )
        db_memory_word = schema.UserMemoryWord(**memory_word.dict())
        db.add(db_memory_word)
    db.commit()
    return user_memory_batch


async def async_save_batch_words(db: Session, i: int, user_book_id: str, batch_words: List[schema.Word]):
    save_batch_words(db, i, user_book_id, batch_words)


async def async_save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    # Fire-and-forget: schedules each batch on the running event loop.
    for i, batch_words in enumerate(batch_words_list):
        asyncio.ensure_future(async_save_batch_words(db, i + 1, user_book_id, batch_words))


def transform(batch_words: List[str]):
    story, translated_story = generate_story_and_translated_story(batch_words)
    return {
        "story": story,
        "translated_story": translated_story,
        "words": batch_words
    }


def save_batch_words_list(db: Session, user_book_id: str, batch_words_list: List[List[schema.Word]]):
    word_str_list = []
    for batch_words in batch_words_list:
        word_str_list.append([word.vc_vocabulary for word in batch_words])
    story_list = multiprocessing_mapping(
        transform, word_str_list,
        tmp_filepath=f"./output/logs/save_batch_words_list_{date_str}.xlsx")
    logger.info(f"story_list: {len(story_list)}")
    for i, (batch_words, story) in tqdm(enumerate(zip(batch_words_list, story_list))):
        save_batch_words_with_story(db, i, user_book_id, batch_words,
                                    story['story'], story['translated_story'])
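
# Illustrative sketch (not used by the pipeline): the batch-splitting logic
# that `track` below applies to the word list, shown on plain integers so it
# can run standalone.
def _split_into_batches_example():
    words = list(range(7))  # stand-ins for schema.Word rows
    batch_size = 3
    batches = [words[i:i + batch_size] for i in range(0, len(words), batch_size)]
    assert batches == [[0, 1, 2], [3, 4, 5], [6]]  # the last batch may be short
    return batches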
def track(db: Session, user_book: schema.UserBook, words: List[schema.Word]):
    batch_size = user_book.batch_size
    logger.debug(f"{[w.vc_vocabulary for w in words]}")
    logger.debug(f"batch_size: {batch_size}")
    logger.debug(f"words count: {len(words)}")
    if user_book.random:
        random.shuffle(words)
    else:
        # Sort by word frequency: higher-frequency words are easier to remember.
        words.sort(key=lambda x: x.vc_frequency, reverse=True)
    logger.debug("saving words as book")
    save_words_as_book(db, user_book.owner_id, words, user_book.title)
    logger.debug(f"saved words as book [{user_book.title}]")
    batch_words_list = []
    for i in range(0, len(words), batch_size):
        batch_words = words[i:i + batch_size]
        batch_words_list.append(batch_words)
    logger.debug(f"batch_words_list: {len(batch_words_list)}")
    if len(batch_words_list) == 0:
        return
    # Save the first batch synchronously so the user can start immediately,
    # then generate the remaining batches.
    first_batch_words = batch_words_list[0]
    user_memory_batch = save_batch_words(db, 0, user_book.id, first_batch_words)
    user_book.memorizing_batch = user_memory_batch.id
    db.commit()
    save_batch_words_list(db, user_book.id, batch_words_list[1:])
    # asyncio.run(async_save_batch_words_list(db, user_book.id, batch_words_list[1:]))


def remember(db: Session, batch_id: str, word_id: str):
    return create_user_memory_action(db, UserMemoryActionCreate(
        batch_id=batch_id,
        word_id=word_id,
        action="remember"
    ))


# Backwards-compatible alias preserving the original misspelled name.
remenber = remember


def forget(db: Session, batch_id: str, word_id: str):
    return create_user_memory_action(db, UserMemoryActionCreate(
        batch_id=batch_id,
        word_id=word_id,
        action="forget"
    ))


def save_memorizing_word_action(db: Session, batch_id: str, actions: List[Tuple[str, str]]):
    """
    actions: [(word_id, remember | forget)]
    """
    for word_id, action in actions:
        memory_action = UserMemoryActionCreate(
            batch_id=batch_id,
            word_id=word_id,
            action=action
        )
        db_memory_action = schema.UserMemoryAction(**memory_action.dict())
        db.add(db_memory_action)
    db.commit()


def on_batch_start(db: Session, user_memory_batch_id: str):
    return create_user_memory_batch_action(db, UserMemoryBatchActionCreate(
        batch_id=user_memory_batch_id,
        action="start"
    ))


def on_batch_end(db: Session, user_memory_batch_id: str):
    return create_user_memory_batch_action(db, UserMemoryBatchActionCreate(
        batch_id=user_memory_batch_id,
        action="end"
    ))
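
# Worked example (illustration only) of the per-word memory-efficiency metric
# used by `generate_next_batch` below: efficiency = remembers / (remembers +
# forgets), bucketed against left_bound=0.3 and right_bound=0.6.
def _efficiency_example():
    actions = [("w1", "remember"), ("w1", "remember"), ("w1", "forget"),
               ("w2", "forget"), ("w2", "forget")]
    remember_count, forget_count = defaultdict(int), defaultdict(int)
    for word_id, action in actions:
        (remember_count if action == "remember" else forget_count)[word_id] += 1
    efficiency = {w: remember_count[w] / (remember_count[w] + forget_count[w])
                  for w in remember_count.keys() | forget_count.keys()}
    assert abs(efficiency["w1"] - 2 / 3) < 1e-9  # > 0.6: treated as learned
    assert efficiency["w2"] == 0.0               # < 0.3: candidate for review
    return efficiency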
# def generate_recall_batch(db: Session, user_book: schema.UserBook):
def _create_batch_with_story(db: Session, user_book_id: str,
                             batch_words: List[schema.Word], batch_type: str):
    """Generate a story for batch_words and persist the batch plus its words.

    Shared by the review and recall branches of generate_next_batch below.
    """
    batch_words_str_list = [word.vc_vocabulary for word in batch_words]
    story, translated_story = generate_story_and_translated_story(batch_words_str_list)
    user_memory_batch = create_user_memory_batch(db, UserMemoryBatchCreate(
        user_book_id=user_book_id,
        story=story,
        translated_story=translated_story,
        batch_type=batch_type,
    ))
    create_user_memory_batch_generation_history(db, UserMemoryBatchGenerationHistoryCreate(
        batch_id=user_memory_batch.id,
        story=story,
        translated_story=translated_story
    ))
    for word in batch_words:
        memory_word = UserMemoryWordCreate(
            batch_id=user_memory_batch.id,
            word_id=word.vc_id
        )
        db_memory_word = schema.UserMemoryWord(**memory_word.dict())
        db.add(db_memory_word)
    db.commit()
    return user_memory_batch


def generate_next_batch(db: Session, user_book: schema.UserBook, minutes: int = 60, k: int = 3):
    """Generate the next batch: a recall batch or a review batch.

    Returns None when the next batch should be a new-word batch.
    """
    left_bound, right_bound = 0.3, 0.6
    user_book_id = user_book.id
    batch_size = user_book.batch_size
    # actions, batch_id_to_batch, batch_id_to_words = get_user_memory_batch_history_in_minutes(db, user_book_id, minutes)
    # memorizing_words = sum(list(batch_id_to_words.values()), [])
    memorizing_words = get_user_memory_word_history_in_minutes(db, user_book_id, minutes)
    if len(memorizing_words) < k * batch_size:
        # 1. Too few new words memorized recently -> new-word batch
        logger.info("new-word batch")
        return None

    # Compute the memorization efficiency of each word.
    memory_actions = get_actions_at_each_word(db, [w.vc_id for w in memorizing_words])
    remember_count = defaultdict(int)
    forget_count = defaultdict(int)
    for a in memory_actions:
        if a.action == "remember":
            remember_count[a.word_id] += 1
        else:
            forget_count[a.word_id] += 1
    word_id_to_efficiency = {}
    for word in memorizing_words:
        total = remember_count[word.vc_id] + forget_count[word.vc_id]
        # Guard against division by zero for words with no recorded actions
        # yet; they are treated as fully forgotten (efficiency 0.0).
        efficiency = remember_count[word.vc_id] / total if total else 0.0
        word_id_to_efficiency[word.vc_id] = efficiency
    logger.info(sorted(
        [(w.vc_vocabulary, word_id_to_efficiency[w.vc_id]) for w in memorizing_words],
        key=lambda x: x[1]))

    if all([efficiency > right_bound for efficiency in word_id_to_efficiency.values()]
           + [count > 3 for count in remember_count.values()]):
        # 2. Memorization efficiency is uniformly high -> new-word batch
        logger.info("new-word batch")
        return None

    forgot_word_ids = [word_id for word_id, efficiency in word_id_to_efficiency.items()
                       if efficiency < left_bound]
    forgot_word_ids.sort(key=lambda x: word_id_to_efficiency[x])
    if len(forgot_word_ids) >= batch_size:
        # 4. Normal case -> review batch
        logger.info("review batch")
        batch_words = [word for word in memorizing_words if word.vc_id in forgot_word_ids][:batch_size]
        batch_words.sort(key=lambda x: x.vc_difficulty, reverse=True)
        return _create_batch_with_story(db, user_book_id, batch_words, batch_type="复习")  # review

    unfamiliar_word_ids = [word_id for word_id, efficiency in word_id_to_efficiency.items()
                           if left_bound <= efficiency < right_bound]
    if len(unfamiliar_word_ids) < batch_size:
        # Also pull in words remembered fewer than 3 times; dedupe so repeated
        # ids do not inflate the size check below.
        extra = [word_id for word_id, count in remember_count.items() if count < 3]
        unfamiliar_word_ids = list(dict.fromkeys(unfamiliar_word_ids + extra))
    unfamiliar_word_ids.sort(key=lambda x: word_id_to_efficiency[x])
    if len(unfamiliar_word_ids) >= batch_size:
        # 3. Memorization efficiency is low -> recall batch
        logger.info("recall batch")
        batch_words = [word for word in memorizing_words if word.vc_id in unfamiliar_word_ids][:batch_size]
        batch_words.sort(key=lambda x: x.vc_difficulty, reverse=True)
        return _create_batch_with_story(db, user_book_id, batch_words, batch_type="回忆")  # recall

    # 5. Normal case -> new-word batch
    logger.info("new-word batch")
    return None
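
# Hedged usage sketch (not called anywhere in this module): one way a caller
# might wire these functions together for a single study session. The order
# of calls is an assumption based on the function names, not a documented API.
def _usage_sketch(db: Session, user_book: schema.UserBook):
    words = get_words_for_book(db, user_book)
    track(db, user_book, words)                  # first batch + stories saved
    batch_id = user_book.memorizing_batch
    on_batch_start(db, batch_id)
    save_memorizing_word_action(db, batch_id, [(words[0].vc_id, "remember")])
    on_batch_end(db, batch_id)
    # None means the caller should serve the next new-word batch instead.
    return generate_next_batch(db, user_book)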