import time
import re
import json
import os
from datetime import datetime
from typing import Optional

import gradio as gr
import torch

import modules.shared as shared
from modules import chat, ui as ui_module
from modules.utils import gradio
from modules.text_generation import generate_reply_HF, generate_reply_custom
from .llm_web_search import get_webpage_content, langchain_search_duckduckgo, langchain_search_searxng, Generator
from .langchain_websearch import LangchainCompressor

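# Default settings. setup() merges in any values found in settings.json and
# writes the combined result back to that file.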
params = {
    "display_name": "LLM Web Search",
    "is_tab": True,
    "enable": True,
    "search results per query": 5,
    "langchain similarity score threshold": 0.5,
    "instant answers": True,
    "regular search results": True,
    "search command regex": "",
    "default search command regex": r"Search_web\(\"(.*)\"\)",
    "open url command regex": "",
    "default open url command regex": r"Open_url\(\"(.*)\"\)",
    "display search results in chat": True,
    "display extracted URL content in chat": True,
    "searxng url": "",
    "cpu only": True,
    "chunk size": 500,
    "duckduckgo results per query": 10,
    "append current datetime": False,
    "default system prompt filename": None,
    "force search prefix": "Search_web",
    "ensemble weighting": 0.5,
    "keyword retriever": "bm25",
    "splade batch size": 2,
    "chunking method": "character-based",
    "chunker breakpoint_threshold_amount": 30
}
custom_system_message_filename = None
extension_path = os.path.dirname(os.path.abspath(__file__))
langchain_compressor = None
update_history = None
force_search = False


def setup():
    """
    Executed when the extension gets imported.
    """
    global params
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    os.environ["QDRANT__TELEMETRY_DISABLED"] = "true"

    try:
        with open(os.path.join(extension_path, "settings.json"), "r") as f:
            saved_params = json.load(f)
        params.update(saved_params)
        save_settings()
    except FileNotFoundError:
        save_settings()

    if not os.path.exists(os.path.join(extension_path, "system_prompts")):
        os.makedirs(os.path.join(extension_path, "system_prompts"))

    toggle_extension(params["enable"])


def save_settings():
    global params
    with open(os.path.join(extension_path, "settings.json"), "w") as f:
        json.dump(params, f, indent=4)
    current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return gr.HTML(f'<span style="color:lawngreen"> Settings were saved at {current_datetime}</span>',
                   visible=True)


def toggle_extension(_enable: bool):
    global langchain_compressor, custom_system_message_filename
    if _enable:
        langchain_compressor = LangchainCompressor(device="cpu" if params["cpu only"] else "cuda",
                                                   keyword_retriever=params["keyword retriever"],
                                                   model_cache_dir=os.path.join(extension_path, "hf_models"))
        compressor_model = langchain_compressor.embeddings.client
        compressor_model.to(compressor_model._target_device)
        custom_system_message_filename = params.get("default system prompt filename")
    else:
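        # When disabling the extension, free VRAM: move every retrieval model
        # that lives on the GPU back to the CPU, drop it and clear the cache.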
        if not params["cpu only"] and 'langchain_compressor' in globals():
            model_attrs = ["embeddings", "splade_doc_model", "splade_query_model"]
            for model_attr in model_attrs:
                if hasattr(langchain_compressor, model_attr):
                    model = getattr(langchain_compressor, model_attr)
                    if hasattr(model, "client"):
                        model.client.to("cpu")
                        del model.client
                    else:
                        if hasattr(model, "to"):
                            model.to("cpu")
                        del model
            torch.cuda.empty_cache()
    params.update({"enable": _enable})
    return _enable


def get_available_system_prompts():
    try:
        return ["None"] + sorted(os.listdir(os.path.join(extension_path, "system_prompts")))
    except FileNotFoundError:
        return ["None"]


def load_system_prompt(filename: Optional[str]):
    global custom_system_message_filename
    if not filename:
        return
    if filename == "None" or filename == "Select custom system message to load...":
        custom_system_message_filename = None
        return ""
    with open(os.path.join(extension_path, "system_prompts", filename), "r") as f:
        prompt_str = f.read()

    if params["append current datetime"]:
        prompt_str += f"\nDate and time of conversation: {datetime.now().strftime('%A %d %B %Y %H:%M')}"

    shared.settings['custom_system_message'] = prompt_str
    custom_system_message_filename = filename
    return prompt_str


def save_system_prompt(filename, prompt):
    if not filename:
        return

    with open(os.path.join(extension_path, "system_prompts", filename), "w") as f:
        f.write(prompt)

    return gr.HTML('<span style="color:lawngreen"> Saved successfully</span>',
                   visible=True)


def check_file_exists(filename):
    if filename == "":
        return gr.HTML("", visible=False)
    if os.path.exists(os.path.join(extension_path, "system_prompts", filename)):
        return gr.HTML('<span style="color:orange"> Warning: Filename already exists</span>', visible=True)
    return gr.HTML("", visible=False)


def timeout_save_message():
    time.sleep(2)
    return gr.HTML("", visible=False)


def deactivate_system_prompt():
    shared.settings['custom_system_message'] = None
    return "None"


def toggle_forced_search(value):
    global force_search
    force_search = value


def ui():
    """
    Creates custom gradio elements when the UI is launched.
    """

    shared.gradio['custom_system_message'].value = load_system_prompt(custom_system_message_filename)

    def update_result_type_setting(choice: str):
        if choice == "Instant answers":
            params.update({"instant answers": True})
            params.update({"regular search results": False})
        elif choice == "Regular results":
            params.update({"instant answers": False})
            params.update({"regular search results": True})
        elif choice == "Regular results and instant answers":
            params.update({"instant answers": True})
            params.update({"regular search results": True})

    def update_regex_setting(input_str: str, setting_key: str, error_html_element: gr.component):
        if input_str == "":
            params.update({setting_key: params[f"default {setting_key}"]})
            return {error_html_element: gr.HTML("", visible=False)}
        try:
            compiled = re.compile(input_str)
            if compiled.groups > 1:
                raise re.error(f"Only 1 capturing group allowed in regex, but there are {compiled.groups}.")
            params.update({setting_key: input_str})
            return {error_html_element: gr.HTML("", visible=False)}
        except re.error as e:
            return {error_html_element: gr.HTML(f'<span style="color:red"> Invalid regex. {str(e).capitalize()}</span>',
                                                visible=True)}

    def update_default_custom_system_message(check: bool):
        if check:
            params.update({"default system prompt filename": custom_system_message_filename})
        else:
            params.update({"default system prompt filename": None})

    with gr.Row():
        enable = gr.Checkbox(value=lambda: params['enable'], label='Enable LLM web search')
        use_cpu_only = gr.Checkbox(value=lambda: params['cpu only'],
                                   label='Run extension on CPU only '
                                         '(Save settings and restart for the change to take effect)')
        with gr.Column():
            save_settings_btn = gr.Button("Save settings")
            saved_success_elem = gr.HTML("", visible=False)

    with gr.Row():
        result_radio = gr.Radio(
            ["Regular results", "Regular results and instant answers"],
            label="What kind of search results should be returned?",
            value=lambda: "Regular results and instant answers" if
            (params["regular search results"] and params["instant answers"]) else "Regular results"
        )
        with gr.Column():
            search_command_regex = gr.Textbox(label="Search command regex string",
                                              placeholder=params["default search command regex"],
                                              value=lambda: params["search command regex"])
            search_command_regex_error_label = gr.HTML("", visible=False)

        with gr.Column():
            open_url_command_regex = gr.Textbox(label="Open URL command regex string",
                                                placeholder=params["default open url command regex"],
                                                value=lambda: params["open url command regex"])
            open_url_command_regex_error_label = gr.HTML("", visible=False)

        with gr.Column():
            show_results = gr.Checkbox(value=lambda: params['display search results in chat'],
                                       label='Display search results in chat')
            show_url_content = gr.Checkbox(value=lambda: params['display extracted URL content in chat'],
                                           label='Display extracted URL content in chat')
    gr.Markdown(value='---')
    with gr.Row():
        with gr.Column():
            gr.Markdown(value='#### Load custom system message\n'
                              'Select a saved custom system message from within the system_prompts folder or "None" '
                              'to clear the selection')
            system_prompt = gr.Dropdown(
                choices=get_available_system_prompts(), label="Select custom system message",
                value=lambda: 'Select custom system message to load...' if custom_system_message_filename is None else
                custom_system_message_filename, elem_classes='slim-dropdown')
            with gr.Row():
                set_system_message_as_default = gr.Checkbox(
                    value=lambda: custom_system_message_filename == params["default system prompt filename"],
                    label='Set this custom system message as the default')
                refresh_button = ui_module.create_refresh_button(system_prompt, lambda: None,
                                                                 lambda: {'choices': get_available_system_prompts()},
                                                                 'refresh-button', interactive=True)
                refresh_button.elem_id = "custom-sysprompt-refresh"
                delete_button = gr.Button('🗑️', elem_classes='refresh-button', interactive=True)
            append_datetime = gr.Checkbox(value=lambda: params['append current datetime'],
                                          label='Append current date and time when loading custom system message')
        with gr.Column():
            gr.Markdown(value='#### Create custom system message')
            system_prompt_text = gr.Textbox(label="Custom system message", lines=3,
                                            value=lambda: load_system_prompt(custom_system_message_filename))
            sys_prompt_filename = gr.Text(label="Filename")
            sys_prompt_save_button = gr.Button("Save Custom system message")
            system_prompt_saved_success_elem = gr.HTML("", visible=False)

    gr.Markdown(value='---')
    with gr.Accordion("Advanced settings", open=False):
        ensemble_weighting = gr.Slider(minimum=0, maximum=1, step=0.05, value=lambda: params["ensemble weighting"],
                                       label="Ensemble Weighting", info="Smaller values = More keyword oriented, "
                                                                        "Larger values = More focus on semantic similarity")
        with gr.Row():
            keyword_retriever = gr.Radio([("Okapi BM25", "bm25"), ("SPLADE", "splade")], label="Sparse keyword retriever",
                                         info="For the change to take effect, toggle the extension off and on again",
                                         value=lambda: params["keyword retriever"])
            splade_batch_size = gr.Slider(minimum=2, maximum=256, step=2, value=lambda: params["splade batch size"],
                                          label="SPLADE batch size",
                                          info="Smaller values = Slower retrieval (but lower VRAM usage), "
                                               "Larger values = Faster retrieval (but higher VRAM usage). "
                                               "A good trade-off seems to be setting it = 8",
                                          precision=0)
        with gr.Row():
            chunker = gr.Radio([("Character-based", "character-based"),
                                ("Semantic", "semantic")], label="Chunking method",
                               value=lambda: params["chunking method"])
            chunker_breakpoint_threshold_amount = gr.Slider(minimum=1, maximum=100, step=1,
                                                            value=lambda: params["chunker breakpoint_threshold_amount"],
                                                            label="Semantic chunking: sentence split threshold (%)",
                                                            info="Defines how different two consecutive sentences have"
                                                                 " to be for them to be split into separate chunks",
                                                            precision=0)
        gr.Markdown("**Note: Changing the following might result in DuckDuckGo rate limiting or the LLM being overwhelmed**")
        num_search_results = gr.Number(label="Max. search results to return per query", minimum=1, maximum=100,
                                       value=lambda: params["search results per query"], precision=0)
        num_process_search_results = gr.Number(label="Number of search results to process per query", minimum=1,
                                               maximum=100, value=lambda: params["duckduckgo results per query"],
                                               precision=0)
        langchain_similarity_threshold = gr.Number(label="Langchain Similarity Score Threshold", minimum=0., maximum=1.,
                                                   value=lambda: params["langchain similarity score threshold"])
        chunk_size = gr.Number(label="Max. chunk size", info="The maximum size of the individual chunks that each webpage will"
                                                             " be split into, in characters", minimum=2, maximum=10000,
                               value=lambda: params["chunk size"], precision=0)

    with gr.Row():
        searxng_url = gr.Textbox(label="SearXNG URL",
                                 value=lambda: params["searxng url"])


    enable.input(toggle_extension, enable, enable)
    use_cpu_only.change(lambda x: params.update({"cpu only": x}), use_cpu_only, None)
    save_settings_btn.click(save_settings, None, [saved_success_elem])
    ensemble_weighting.change(lambda x: params.update({"ensemble weighting": x}), ensemble_weighting, None)
    keyword_retriever.change(lambda x: params.update({"keyword retriever": x}), keyword_retriever, None)
    splade_batch_size.change(lambda x: params.update({"splade batch size": x}), splade_batch_size, None)
    chunker.change(lambda x: params.update({"chunking method": x}), chunker, None)
    chunker_breakpoint_threshold_amount.change(lambda x: params.update({"chunker breakpoint_threshold_amount": x}),
                                               chunker_breakpoint_threshold_amount, None)
    num_search_results.change(lambda x: params.update({"search results per query": x}), num_search_results, None)
    num_process_search_results.change(lambda x: params.update({"duckduckgo results per query": x}),
                                      num_process_search_results, None)
    langchain_similarity_threshold.change(lambda x: params.update({"langchain similarity score threshold": x}),
                                          langchain_similarity_threshold, None)
    chunk_size.change(lambda x: params.update({"chunk size": x}), chunk_size, None)
    result_radio.change(update_result_type_setting, result_radio, None)

    search_command_regex.change(lambda x: update_regex_setting(x, "search command regex",
                                                               search_command_regex_error_label),
                                search_command_regex, search_command_regex_error_label, show_progress="hidden")

    open_url_command_regex.change(lambda x: update_regex_setting(x, "open url command regex",
                                                                 open_url_command_regex_error_label),
                                  open_url_command_regex, open_url_command_regex_error_label, show_progress="hidden")

    show_results.change(lambda x: params.update({"display search results in chat": x}), show_results, None)
    show_url_content.change(lambda x: params.update({"display extracted URL content in chat": x}), show_url_content,
                            None)
    searxng_url.change(lambda x: params.update({"searxng url": x}), searxng_url, None)

    delete_button.click(
        lambda x: x, system_prompt, gradio('delete_filename')).then(
        lambda: os.path.join(extension_path, "system_prompts", ""), None, gradio('delete_root')).then(
        lambda: gr.update(visible=True), None, gradio('file_deleter'))
    shared.gradio['delete_confirm'].click(
        lambda: "None", None, system_prompt).then(
        None, None, None, _js="() => { document.getElementById('custom-sysprompt-refresh').click() }")
    system_prompt.change(load_system_prompt, system_prompt, shared.gradio['custom_system_message'])
    system_prompt.change(load_system_prompt, system_prompt, system_prompt_text)

    system_prompt.change(lambda x: x == params["default system prompt filename"], system_prompt,
                         set_system_message_as_default)
    sys_prompt_filename.change(check_file_exists, sys_prompt_filename, system_prompt_saved_success_elem)
    sys_prompt_save_button.click(save_system_prompt, [sys_prompt_filename, system_prompt_text],
                                 system_prompt_saved_success_elem,
                                 show_progress="hidden").then(timeout_save_message,
                                                              None,
                                                              system_prompt_saved_success_elem,
                                                              _js="() => { document.getElementById('custom-sysprompt-refresh').click() }",
                                                              show_progress="hidden").then(lambda: "", None,
                                                                                           sys_prompt_filename,
                                                                                           show_progress="hidden")
    append_datetime.change(lambda x: params.update({"append current datetime": x}), append_datetime, None)

    set_system_message_as_default.input(update_default_custom_system_message, set_system_message_as_default, None)

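    # Hidden checkbox, presumably toggled by this extension's JavaScript
    # (loaded via custom_js); changing it sets the module-level force_search
    # flag that custom_generate_reply checks.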
    force_search_checkbox = gr.Checkbox(value=False, visible=False, elem_id="Force-search-checkbox")
    force_search_checkbox.change(toggle_forced_search, force_search_checkbox, None)


def custom_generate_reply(question, original_question, seed, state, stopping_strings, is_chat):
    """
    Overrides the main text generation function.
    """
    global update_history, langchain_compressor
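    # Pick the generation backend that matches the loaded model class.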
    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model',
                                           'CtransformersModel']:
        generate_func = generate_reply_custom
    else:
        generate_func = generate_reply_HF

    if not params['enable']:
        for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
            yield reply
        return

    web_search = False
    read_webpage = False
    max_search_results = int(params["search results per query"])
    instant_answers = params["instant answers"]

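    # Push the current UI settings into the compressor before any search runs.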
    langchain_compressor.num_results = int(params["duckduckgo results per query"])
    langchain_compressor.similarity_threshold = params["langchain similarity score threshold"]
    langchain_compressor.chunk_size = params["chunk size"]
    langchain_compressor.ensemble_weighting = params["ensemble weighting"]
    langchain_compressor.splade_batch_size = params["splade batch size"]
    langchain_compressor.chunking_method = params["chunking method"]
    langchain_compressor.chunker_breakpoint_threshold_amount = params["chunker breakpoint_threshold_amount"]

    search_command_regex = params["search command regex"]
    open_url_command_regex = params["open url command regex"]
    searxng_url = params["searxng url"]
    display_search_results = params["display search results in chat"]
    display_webpage_content = params["display extracted URL content in chat"]

    if search_command_regex == "":
        search_command_regex = params["default search command regex"]
    if open_url_command_regex == "":
        open_url_command_regex = params["default open url command regex"]

    compiled_search_command_regex = re.compile(search_command_regex)
    compiled_open_url_command_regex = re.compile(open_url_command_regex)

    if force_search:
        question += f" {params['force search prefix']}"

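    # Stream the model's reply, watching each partial reply for a search or
    # open-URL command.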
    reply = None
    for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):

        if force_search:
            reply = params["force search prefix"] + reply

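        # A search command was issued: yield the reply as-is once, then run
        # the search tool and append its output inside a plaintext code block.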
        search_re_match = compiled_search_command_regex.search(reply)
        if search_re_match is not None:
            yield reply
            original_model_reply = reply
            web_search = True
            search_term = search_re_match.group(1)
            print(f"LLM_Web_search | Searching for {search_term}...")
            reply += "\n```plaintext"
            reply += "\nSearch tool:\n"
            if searxng_url == "":
                search_generator = Generator(langchain_search_duckduckgo(search_term,
                                                                         langchain_compressor,
                                                                         max_search_results,
                                                                         instant_answers))
            else:
                search_generator = Generator(langchain_search_searxng(search_term,
                                                                      searxng_url,
                                                                      langchain_compressor,
                                                                      max_search_results))
            try:
                for status_message in search_generator:
                    yield original_model_reply + f"\n*{status_message}*"
                search_results = search_generator.value
            except Exception as exc:
                exception_message = str(exc)
                reply += f"The search tool encountered an error: {exception_message}"
                print(f'LLM_Web_search | {search_term} generated an exception: {exception_message}')
            else:
                if search_results != "":
                    reply += search_results
                else:
                    reply += "\nThe search tool did not return any results."
            reply += "```"
            if display_search_results:
                yield reply
            break

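        # Same pattern for the URL-opening command: fetch the page text and
        # append it inside a plaintext code block.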
        open_url_re_match = compiled_open_url_command_regex.search(reply)
        if open_url_re_match is not None:
            yield reply
            original_model_reply = reply
            read_webpage = True
            url = open_url_re_match.group(1)
            print(f"LLM_Web_search | Reading {url}...")
            reply += "\n```plaintext"
            reply += "\nURL opener tool:\n"
            try:
                webpage_content = get_webpage_content(url)
            except Exception as exc:
                reply += f"Couldn't open {url}. Error message: {str(exc)}"
                print(f'LLM_Web_search | {url} generated an exception: {str(exc)}')
            else:
                reply += f"\nText content of {url}:\n"
                reply += webpage_content
            reply += "```\n"
            if display_webpage_content:
                yield reply
            break
        yield reply

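    # A tool ran: build a new prompt that includes the tool output and let the
    # model generate a follow-up reply that uses it.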
    if web_search or read_webpage:
        display_results = (web_search and display_search_results) or (read_webpage and display_webpage_content)

        new_question = chat.generate_chat_prompt(f"{question}{reply}", state)
        new_reply = ""
        for new_reply in generate_func(new_question, new_question, seed, state,
                                       stopping_strings, is_chat=is_chat):
            if display_results:
                yield f"{reply}\n{new_reply}"
            else:
                yield f"{original_model_reply}\n{new_reply}"

        if not display_results:
            update_history = [state["textbox"], f"{reply}\n{new_reply}"]


def output_modifier(string, state, is_chat=False):
    """
    Modifies the output string before it is presented in the UI. In chat mode,
    it is applied to the bot's reply. Otherwise, it is applied to the entire
    output.
    """
    return string


def custom_css():
    """
    Returns custom CSS as a string. It is applied whenever the web UI is loaded.
    """
    return ''


def custom_js():
    """
    Returns custom javascript as a string. It is applied whenever the web UI is
    loaded.
    """
    with open(os.path.join(extension_path, "script.js"), "r") as f:
        return f.read()


def chat_input_modifier(text, visible_text, state):
    """
    Modifies both the visible and internal inputs in chat mode. Can be used to
    hijack the chat input with custom content.
    """
    return text, visible_text


def state_modifier(state):
    """
    Modifies the dictionary containing the UI input parameters before it is
    used by the text generation functions.
    """
    return state


def history_modifier(history):
    """
    Modifies the chat history before the text generation in chat mode begins.
    """
    global update_history
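    # If the tool output was hidden in the UI, append the user message and the
    # full reply (including the tool output) to the internal history, so the
    # model keeps access to the results in later turns.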
    if update_history:
        history["internal"].append(update_history)
        update_history = None
    return history