DrishtiSharma committed on
Commit 7e42dc2 · verified · 1 Parent(s): 8555124

Update lab/graph3.py

Files changed (1)
  1. lab/graph3.py +195 -0
lab/graph3.py CHANGED
@@ -0,0 +1,195 @@
+ __import__('pysqlite3')  # Workaround for sqlite3 error on live Streamlit.
+ import sys
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')  # Workaround for sqlite3 error on live Streamlit.
+ import os
+ import graphviz
+ import traceback
+ from langgraph.graph import StateGraph, END
+ from langchain_openai import ChatOpenAI
+ from pydantic import BaseModel, Field
+ from typing import TypedDict, List, Literal, Dict, Any
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
+ from langchain.prompts import PromptTemplate
+ from langchain.memory import ConversationBufferMemory
+ from pdf_writer import generate_pdf
+
+ from crew import CrewClass, Essay
+
+
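+ # Shared state passed between LangGraph nodes; each node returns a partial
+ # update that LangGraph merges into this dict.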
+ class GraphState(TypedDict):
+     topic: str
+     response: str
+     documents: List[str]
+     essay: Dict[str, Any]
+     pdf_name: str
+
+
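+ # Structured-output schema for the router: the model must choose exactly one route.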
+ class RouteQuery(BaseModel):
+     """Route a user query to write_essay, edit_essay, or answer."""
+
+     way: Literal["edit_essay", "write_essay", "answer"] = Field(
+         ...,
+         description="Given a user question, choose to route it to write_essay, edit_essay, or answer",
+     )
+
+
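+ # Orchestrates the workflow: builds the LangGraph state machine, keeps
+ # conversation memory, and exposes the compiled graph as self.graph.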
+ class EssayWriter:
+     def __init__(self):
+         self.model = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0)
+         self.crew = CrewClass(llm=ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0.5))
+
+         self.memory = ConversationBufferMemory()
+         self.essay = {}
+         self.router_prompt = """
+             You are a router, and your duty is to route the user to the correct expert.
+             Always check conversation history and consider your move based on it.
+             If the topic is something about memory or daily talk, route the user to the answer expert.
+             If the topic starts with something like "Can you write" or the user requests an article or essay, route the user to the write_essay expert.
+             If the topic is about editing an essay, route the user to the edit_essay expert.
+
+             \nConversation History: {memory}
+             \nTopic: {topic}
+         """
+
+         self.simple_answer_prompt = """
+             You are an expert, and you are providing a simple answer to the user's question.
+
+             \nConversation History: {memory}
+             \nTopic: {topic}
+         """
+
+         builder = StateGraph(GraphState)
+
+         builder.add_node("answer", self.answer)
+         builder.add_node("write_essay", self.write_essay)
+         builder.add_node("edit_essay", self.edit_essay)
+
+         builder.set_conditional_entry_point(self.router_query, {
+             "write_essay": "write_essay",
+             "answer": "answer",
+             "edit_essay": "edit_essay",
+         })
+
+         builder.add_edge("write_essay", END)
+         builder.add_edge("edit_essay", END)
+         builder.add_edge("answer", END)
+
+         self.graph = builder.compile()
+         self.save_workflow_graph()
+
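+     # Routing node: returns the name of the node to run next; the mapping in
+     # set_conditional_entry_point above resolves that name to a graph node.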
+     def router_query(self, state: GraphState):
+         print("**ROUTER**")
+         prompt = PromptTemplate.from_template(self.router_prompt)
+         memory = self.memory.load_memory_variables({})
+
+         router_query = self.model.with_structured_output(RouteQuery)
+         chain = prompt | router_query
+         result: RouteQuery = chain.invoke({"topic": state["topic"], "memory": memory})
+
+         print("Router Result: ", result.way)
+         return result.way
+
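+     # Direct-answer node: a plain LLM call with the conversation history injected.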
+     def answer(self, state: GraphState):
+         print("**ANSWER**")
+         prompt = PromptTemplate.from_template(self.simple_answer_prompt)
+         memory = self.memory.load_memory_variables({})
+         chain = prompt | self.model | StrOutputParser()
+         result = chain.invoke({"topic": state["topic"], "memory": memory})
+
+         self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": result})
+         return {"response": result}
+
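+     # Essay node: delegates drafting to the crew (CrewClass.kickoff), then
+     # renders the result to PDF for preview.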
+     def write_essay(self, state: GraphState):
+         print("**ESSAY COMPLETION**")
+         # Generate the essay using the crew
+         self.essay = self.crew.kickoff({"topic": state["topic"]})
+         # Save the conversation context
+         self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})
+         # Generate the PDF and return essay content for preview
+         pdf_name = generate_pdf(self.essay)
+         return {
+             "response": "Here is your essay! You can review it below before downloading.",
+             "essay": self.essay,
+             "pdf_name": pdf_name,
+         }
+
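+     # Edit node: asks the model to rewrite the stored essay JSON according to the
+     # user's request, re-parsed using the Essay schema's format instructions.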
+     def edit_essay(self, state: GraphState):
+         print("**ESSAY EDIT**")
+         memory = self.memory.load_memory_variables({})
+
+         user_request = state["topic"]
+         parser = JsonOutputParser(pydantic_object=Essay)
+         prompt = PromptTemplate(
+             template=(
+                 "Edit the JSON file as the user requested, and return the new JSON file."
+                 "\n Request: {user_request}"
+                 "\n Conversation History: {memory}"
+                 "\n JSON File: {essay}"
+                 "\n {format_instructions}"
+             ),
+             input_variables=["memory", "user_request", "essay"],
+             partial_variables={"format_instructions": parser.get_format_instructions()},
+         )
+
+         chain = prompt | self.model | parser
+
+         # Update the essay with the edits
+         self.essay = chain.invoke({"user_request": user_request, "memory": memory, "essay": self.essay})
+
+         # Save the conversation context
+         self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})
+
+         # Generate the PDF and return essay content for preview
+         pdf_name = generate_pdf(self.essay)
+         return {
+             "response": "Here is your edited essay! You can review it below before downloading.",
+             "essay": self.essay,
+             "pdf_name": pdf_name,
+         }
+
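+     # Note: this diagram is drawn by hand to mirror the compiled graph; it is not
+     # derived from the LangGraph structure, so keep it in sync when nodes change.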
+     def save_workflow_graph(self):
+         """Generate and save a Graphviz workflow visualization with logging."""
+         log_file = "/tmp/graph_debug.log"
+         try:
+             output_path = "/tmp/graph"
+             dot = graphviz.Digraph(format="png")
+             dot.attr(dpi='300')
+
+             # Define nodes
+             dot.node("Router", "🔀 Router")
+             dot.node("Write Essay", "📝 Write Essay")
+             dot.node("Edit Essay", "✏️ Edit Essay")
+             dot.node("Answer", "💬 Answer")
+             dot.node("Done", "✅ Done")
+
+             # Define edges
+             dot.edge("Router", "Write Essay")
+             dot.edge("Router", "Edit Essay")
+             dot.edge("Router", "Answer")
+             dot.edge("Write Essay", "Done")
+             dot.edge("Edit Essay", "Done")
+             dot.edge("Answer", "Done")
+
+             # Generate and save the graph in /tmp/
+             dot.render(output_path, format="png", cleanup=False)
+
+             # Check that the file was actually written
+             graph_path = "/tmp/graph.png"
+             if os.path.exists(graph_path):
+                 with open(log_file, "w") as f:
+                     f.write("✅ Graphviz successfully generated /tmp/graph.png\n")
+                 print("✅ Graphviz successfully generated /tmp/graph.png")
+             else:
+                 raise FileNotFoundError("❌ Graphviz failed to generate /tmp/graph.png")
+
+         except Exception:
+             # Capture the full traceback rather than just the exception message
+             error_message = f"❌ Error generating workflow visualization:\n{traceback.format_exc()}\n"
+             with open(log_file, "w") as f:
+                 f.write(error_message)
+             print(error_message)
+             return error_message
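+
+ # Example usage (a minimal sketch; assumes OPENAI_API_KEY is set and that crew.py
+ # and pdf_writer.py are importable — neither is shown in this diff):
+ #   writer = EssayWriter()
+ #   result = writer.graph.invoke({"topic": "Can you write an essay about solar power?"})
+ #   print(result["response"])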