Spaces:
Running
Running
Fangrui Liu
commited on
Commit
Β·
d5a4cb4
1
Parent(s):
526644e
fix callback
Browse files- callbacks/arxiv_callbacks.py +10 -3
callbacks/arxiv_callbacks.py
CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
|
|
2 |
from typing import Dict, Any
|
3 |
from sql_formatter.core import format_sql
|
4 |
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
|
|
5 |
|
6 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
7 |
def __init__(self) -> None:
|
@@ -62,8 +63,14 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
|
|
62 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
63 |
pass
|
64 |
|
65 |
-
def
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
st.write('We generated Vector SQL for you:')
|
68 |
st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
|
69 |
print(f"Vector SQL: {text}")
|
@@ -83,4 +90,4 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
|
|
83 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
84 |
self.status_bar = st.empty()
|
85 |
self.prog_value = 0
|
86 |
-
self.prog_interval = 0.1
|
|
|
2 |
from typing import Dict, Any
|
3 |
from sql_formatter.core import format_sql
|
4 |
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
5 |
+
from langchain.schema.output import LLMResult
|
6 |
|
7 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
8 |
def __init__(self) -> None:
|
|
|
63 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
64 |
pass
|
65 |
|
66 |
+
def on_llm_end(
|
67 |
+
self,
|
68 |
+
response: LLMResult,
|
69 |
+
*args,
|
70 |
+
**kwargs,
|
71 |
+
):
|
72 |
+
text = response.generations[0][0].text
|
73 |
+
if text.replace(' ', '').upper().startswith('SELECT'):
|
74 |
st.write('We generated Vector SQL for you:')
|
75 |
st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
|
76 |
print(f"Vector SQL: {text}")
|
|
|
90 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
91 |
self.status_bar = st.empty()
|
92 |
self.prog_value = 0
|
93 |
+
self.prog_interval = 0.1
|