whoami02 committed on
Commit
8460174
·
verified ·
1 Parent(s): 01cfee6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -73
app.py CHANGED
@@ -1,82 +1,105 @@
1
  import os
2
- import urllib.request
3
  import gradio as gr
4
- from llama_cpp import Llama
5
- from langchain.llms import llamacpp
6
- from huggingface_hub import login, hf_hub_download
7
- from dotenv import load_dotenv
 
 
 
 
 
 
 
8
 
9
- MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF"
10
- MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf"
11
- # MODEL_ID = "TheBloke/Wizard-Vicuna-7B-Uncensored-GGUF"
12
- # MODEL_BASENAME = "Wizard-Vicuna-7B-Uncensored.Q4_K_M.gguf"
13
- CONTEXT_WINDOW_SIZE = 8000
14
- MAX_NEW_TOKENS = 2000
15
- N_BATCH = 128
16
- # load_dotenv()
17
- os.getenv('hf_token')
18
- def load_quantized_model(model_id, model_basename):
19
- try:
20
- model_path = hf_hub_download(
21
- repo_id=model_id,
22
- filename=model_basename,
23
- resume_download=True,
24
- cache_dir="./models"
25
- )
26
- kwargs = {
27
- 'model_path': model_path,
28
- 'c_ctx': CONTEXT_WINDOW_SIZE,
29
- 'max_tokens': MAX_NEW_TOKENS,
30
- 'n_batch': N_BATCH
31
- }
32
- return llamacpp.LlamaCpp(**kwargs)
33
- except TypeError:
34
- return None
35
 
36
- def load_model(model_id, model_basename=None):
37
- if ".gguf" in model_basename.lower():
38
- llm = load_quantized_model(model_id, model_basename)
39
- return llm
 
 
 
 
 
40
  else:
41
- print("currently only .gguf models supported")
42
-
43
 
44
-
45
- def generate_text(prompt="Who is the CEO of Apple?"):
46
- llm = load_model(MODEL_ID, MODEL_BASENAME)
47
- output = llm(
48
- prompt,
49
- max_tokens=256,
50
- temperature=0.1,
51
- top_p=0.5,
52
- echo=False,
53
- stop=["#"],
 
 
 
 
 
 
 
54
  )
55
- print(output)
56
- return output
57
- # output_text = output["choices"][0]["text"].strip()
58
-
59
- # # Remove Prompt Echo from Generated Text
60
- # cleaned_output_text = output_text.replace(prompt, "")
61
- # return cleaned_output_text
62
-
63
-
64
- description = "Zephyr-beta"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- examples = [
67
- ["What is the capital of France?", "The capital of France is Paris."],
68
- [
69
- "Who wrote the novel 'Pride and Prejudice'?",
70
- "The novel 'Pride and Prejudice' was written by Jane Austen.",
71
- ],
72
- ["What is the square root of 64?", "The square root of 64 is 8."],
73
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- gradio_interface = gr.Interface(
76
- fn=generate_text,
77
- inputs="text",
78
- outputs="text",
79
- examples=examples,
80
- title="Zephyr-B",
81
- )
82
- gradio_interface.launch(share=True)
 
1
  import os
 
2
  import gradio as gr
3
+ from dotenv import load_dotenv, find_dotenv
4
+ from langchain.utilities.sql_database import SQLDatabase
5
+ from langchain_google_genai import ChatGoogleGenerativeAI
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langchain_core.runnables import RunnablePassthrough
9
+ from langchain_core.tracers import ConsoleCallbackHandler
10
+ from langchain_community.llms.llamacpp import LlamaCpp
11
+ from huggingface_hub import login
12
+ from langchain.globals import set_verbose
13
+ set_verbose(True)
14
 
15
+ # load_dotenv(find_dotenv(r"E:\AW\LLMs\.env"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
def load_model(model_id):
    """Return a chat LLM instance for *model_id*.

    Currently only ``"gemini"`` is supported: a ``ChatGoogleGenerativeAI``
    client built from the ``GOOGLE_API_KEY`` environment variable.  For any
    other id a notice is printed and ``None`` is returned.

    Raises:
        KeyError: if ``GOOGLE_API_KEY`` is not set when "gemini" is requested.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model='gemini-pro',
            # BUG FIX: the previous code passed google_api_key=login(...),
            # but huggingface_hub.login() returns None, so the Google key
            # was never actually supplied.  Pass the env var directly.
            google_api_key=os.environ["GOOGLE_API_KEY"],
            convert_system_message_to_human=True,
            temperature=0.05,  # near-deterministic output for SQL generation
            verbose=True,
        )
    else:
        print("only gemini supported aofn")
        return None  # explicit: unsupported ids yield no model
 
28
 
29
def chain(db, llm):
    """Build a question -> SQL -> natural-language-answer runnable for *db*.

    The returned chain first asks *llm* to write an MS-SQL query from the
    table schema and the user's question, executes that query against *db*,
    then asks *llm* again to phrase the raw result as a natural-language
    answer.
    """

    def fetch_schema(_):
        # Injected into both prompts so the model knows the available tables.
        return db.get_table_info()

    query_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
{schema}
Question: {question}
Query:"""

    write_query_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", query_template),
        ]
    )

    answer_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
{schema}

Question: {question}
MS-SQL Query: {query}
MS-SQL Response: {response}"""

    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    )

    # Stage 1: schema + question -> SQL text (stop before any echoed result).
    write_query = (
        RunnablePassthrough.assign(schema=fetch_schema)
        | write_query_prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )

    # Stage 2: run the generated SQL and have the LLM verbalise the result.
    return (
        RunnablePassthrough.assign(query=write_query)
        | RunnablePassthrough.assign(
            schema=fetch_schema,
            response=lambda x: db.run(x["query"]),
        )
        | answer_prompt
        | llm
    )
79
 
80
def main():
    """Serve two SQL Q/A chatbots (downtime and production) in a tabbed Gradio UI."""
    llm = load_model("gemini")

    db_path = r"OPPI_shift.db"  # \OPPI_down.db"
    downtime_db = SQLDatabase.from_uri(
        f"sqlite:///{db_path}",
        include_tables=['ShiftDownTimeDetails'],
        sample_rows_in_table_info=0,
    )
    production_db = SQLDatabase.from_uri(
        f"sqlite:///{db_path}",
        include_tables=['ShiftProductionDetails'],
        sample_rows_in_table_info=0,
    )

    downtime_chain = chain(db=downtime_db, llm=llm)
    production_chain = chain(db=production_db, llm=llm)

    # ChatInterface callbacks: run the chain, trace to console, return text.
    def ask_downtime(message, history):
        result = downtime_chain.invoke(
            {"question": message},
            config={"callbacks": [ConsoleCallbackHandler()]},
        )
        return str(result)

    def ask_production(message, history):
        result = production_chain.invoke(
            {"question": message},
            config={"callbacks": [ConsoleCallbackHandler()]},
        )
        return str(result)

    downtime_tab = gr.ChatInterface(
        fn=ask_downtime, title="SQL-Chatbot", description="Q/A on Downtime details table"
    )
    production_tab = gr.ChatInterface(
        fn=ask_production, title="SQL-Chatbot", description="Q/A on Production details table"
    )

    app = gr.TabbedInterface(
        [downtime_tab, production_tab],
        ['ShiftDownTimeDetails', 'ShiftProductionDetails'],
    )
    app.launch(debug=True)


if __name__ == "__main__":
    main()