oh-my-words / story_agent.py
linxy's picture
init
bb48ea5
import traceback
from typing import List, Tuple
from loguru import logger
from pydantic import BaseModel, Field
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.output_parsers import (
PydanticOutputParser,
OutputFixingParser,
)
# LLM.py is my own language-model wrapper; you can use the openai one directly instead
# from langchain.llms.openai import OpenAIChat
from LLM import OpenAIChat
# Prompt text for the story request. `{words}` is supplied per call, and
# `{format}` is pre-filled with the output parser's format instructions when
# the PromptTemplate is built. The trailing lines ask the model to wrap each
# requested word in backticks in both the English and the Chinese story.
TEMPLATE = """\
please write a story at least 5 sentences long, using the words [{words}].
{format}
Attention! The words should be highlighted, surrounded by "`". Therefore, the story should be in the following format.
English: ... `word1` ... `word2` ...
Chinese: ... `单词1` ... `单词2` ...
"""
# Structured result expected back from the LLM: a story plus its translation.
# NOTE(review): documented with comments rather than a class docstring on
# purpose; Pydantic is expected to copy a model docstring into the generated
# schema, which would alter the format instructions placed in the prompt.
class Story(BaseModel):
    # Story text (the prompt template asks for it in English).
    story: str = Field(description="the story")
    # Translation of the story (the prompt template asks for Chinese).
    translated_story: str = Field(description="the translated story")
# Chat model shared by the story chain and by the output-fixing parser below.
llm = OpenAIChat(model_name="gpt-3.5-turbo", temperature=0.3)
# Parser that turns the model's reply into a Story instance.
parser = PydanticOutputParser(pydantic_object=Story)
prompt_template = PromptTemplate(
    template=TEMPLATE,
    # Only `words` has to be provided when the chain is run.
    input_variables=["words"],
    partial_variables={
        # Read from the plain Pydantic parser, before `parser` is rebound below.
        "format": parser.get_format_instructions(),
    }
)
# Rebind `parser` to a wrapper around the Pydantic parser that is given the
# same LLM. NOTE(review): per the LangChain docs, OutputFixingParser asks the
# LLM to repair a reply that fails to parse -- confirm for the pinned version.
parser = OutputFixingParser.from_llm(parser=parser, llm=llm)
# Chain that formats the prompt, calls the LLM and applies the parser.
chain = LLMChain(
    llm=llm,
    prompt=prompt_template,
    output_parser=parser,
    verbose=False,
)
def tell_story(words: List[str], max_retries: int = 10) -> "Story":
    """Ask the LLM chain for a story that uses ``words``, retrying on failure.

    Each attempt runs the chain with the words joined by ``", "``. An attempt
    is discarded and retried when the chain raises or when either the story
    or its translation is blank. This function never raises for a failed
    attempt; failures are logged instead.

    Args:
        words: Words the story has to use.
        max_retries: Maximum number of attempts before giving up. Defaults to
            10, which was previously a hard-coded limit.

    Returns:
        The first response whose ``story`` and ``translated_story`` are both
        non-blank, or a ``Story`` holding two empty strings once every
        attempt has failed.
    """
    for _ in range(max_retries):
        try:
            resp: Story = chain.run(", ".join(words))
            # Blank output counts as a failed attempt: fall through and retry.
            if resp.story.strip() and resp.translated_story.strip():
                return resp
        except Exception as e:
            # Best-effort boundary: LLM and parsing errors come in many types,
            # so log whatever happened and try again rather than propagate.
            logger.error(e)
            logger.error(traceback.format_exc())
            logger.error("retrying...")
    # Every attempt failed; callers receive an empty story instead of an error.
    return Story(story="", translated_story="")
def generate_story_and_translated_story(words: List[str]) -> Tuple[str, str]:
    """Generate a story for ``words`` and return it as plain strings.

    Delegates to ``tell_story`` and unpacks the result.

    Args:
        words: Words the story has to use.

    Returns:
        A ``(story, translated_story)`` tuple taken from the ``tell_story``
        result.
    """
    result = tell_story(words)
    story_text, translated_text = result.story, result.translated_story
    return story_text, translated_text