import time import re import json import os from datetime import datetime import gradio as gr import torch import modules.shared as shared from modules import chat, ui as ui_module from modules.utils import gradio from modules.text_generation import generate_reply_HF, generate_reply_custom from .llm_web_search import get_webpage_content, langchain_search_duckduckgo, langchain_search_searxng, Generator from .langchain_websearch import LangchainCompressor params = { "display_name": "LLM Web Search", "is_tab": True, "enable": True, "search results per query": 5, "langchain similarity score threshold": 0.5, "instant answers": True, "regular search results": True, "search command regex": "", "default search command regex": r"Search_web\(\"(.*)\"\)", "open url command regex": "", "default open url command regex": r"Open_url\(\"(.*)\"\)", "display search results in chat": True, "display extracted URL content in chat": True, "searxng url": "", "cpu only": True, "chunk size": 500, "duckduckgo results per query": 10, "append current datetime": False, "default system prompt filename": None, "force search prefix": "Search_web", "ensemble weighting": 0.5, "keyword retriever": "bm25", "splade batch size": 2, "chunking method": "character-based", "chunker breakpoint_threshold_amount": 30 } custom_system_message_filename = None extension_path = os.path.dirname(os.path.abspath(__file__)) langchain_compressor = None update_history = None force_search = False def setup(): """ Is executed when the extension gets imported. :return: """ global params os.environ["TOKENIZERS_PARALLELISM"] = "true" os.environ["QDRANT__TELEMETRY_DISABLED"] = "true" try: with open(os.path.join(extension_path, "settings.json"), "r") as f: saved_params = json.load(f) params.update(saved_params) save_settings() # add keys of newly added feature to settings.json except FileNotFoundError: save_settings() if not os.path.exists(os.path.join(extension_path, "system_prompts")): os.makedirs(os.path.join(extension_path, "system_prompts")) toggle_extension(params["enable"]) def save_settings(): global params with open(os.path.join(extension_path, "settings.json"), "w") as f: json.dump(params, f, indent=4) current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") return gr.HTML(f' Settings were saved at {current_datetime}', visible=True) def toggle_extension(_enable: bool): global langchain_compressor, custom_system_message_filename if _enable: langchain_compressor = LangchainCompressor(device="cpu" if params["cpu only"] else "cuda", keyword_retriever=params["keyword retriever"], model_cache_dir=os.path.join(extension_path, "hf_models")) compressor_model = langchain_compressor.embeddings.client compressor_model.to(compressor_model._target_device) custom_system_message_filename = params.get("default system prompt filename") else: if not params["cpu only"] and 'langchain_compressor' in globals(): # free some VRAM model_attrs = ["embeddings", "splade_doc_model", "splade_query_model"] for model_attr in model_attrs: if hasattr(langchain_compressor, model_attr): model = getattr(langchain_compressor, model_attr) if hasattr(model, "client"): model.client.to("cpu") del model.client else: if hasattr(model, "to"): model.to("cpu") del model torch.cuda.empty_cache() params.update({"enable": _enable}) return _enable def get_available_system_prompts(): try: return ["None"] + sorted(os.listdir(os.path.join(extension_path, "system_prompts"))) except FileNotFoundError: return ["None"] def load_system_prompt(filename: str or None): global custom_system_message_filename if not filename: return if filename == "None" or filename == "Select custom system message to load...": custom_system_message_filename = None return "" with open(os.path.join(extension_path, "system_prompts", filename), "r") as f: prompt_str = f.read() if params["append current datetime"]: prompt_str += f"\nDate and time of conversation: {datetime.now().strftime('%A %d %B %Y %H:%M')}" shared.settings['custom_system_message'] = prompt_str custom_system_message_filename = filename return prompt_str def save_system_prompt(filename, prompt): if not filename: return with open(os.path.join(extension_path, "system_prompts", filename), "w") as f: f.write(prompt) return gr.HTML(f' Saved successfully', visible=True) def check_file_exists(filename): if filename == "": return gr.HTML("", visible=False) if os.path.exists(os.path.join(extension_path, "system_prompts", filename)): return gr.HTML(f' Warning: Filename already exists', visible=True) return gr.HTML("", visible=False) def timeout_save_message(): time.sleep(2) return gr.HTML("", visible=False) def deactivate_system_prompt(): shared.settings['custom_system_message'] = None return "None" def toggle_forced_search(value): global force_search force_search = value def ui(): """ Creates custom gradio elements when the UI is launched. :return: """ # Inject custom system message into the main textbox if a default one is set shared.gradio['custom_system_message'].value = load_system_prompt(custom_system_message_filename) def update_result_type_setting(choice: str): if choice == "Instant answers": params.update({"instant answers": True}) params.update({"regular search results": False}) elif choice == "Regular results": params.update({"instant answers": False}) params.update({"regular search results": True}) elif choice == "Regular results and instant answers": params.update({"instant answers": True}) params.update({"regular search results": True}) def update_regex_setting(input_str: str, setting_key: str, error_html_element: gr.component): if input_str == "": params.update({setting_key: params[f"default {setting_key}"]}) return {error_html_element: gr.HTML("", visible=False)} try: compiled = re.compile(input_str) if compiled.groups > 1: raise re.error(f"Only 1 capturing group allowed in regex, but there are {compiled.groups}.") params.update({setting_key: input_str}) return {error_html_element: gr.HTML("", visible=False)} except re.error as e: return {error_html_element: gr.HTML(f' Invalid regex. {str(e).capitalize()}', visible=True)} def update_default_custom_system_message(check: bool): if check: params.update({"default system prompt filename": custom_system_message_filename}) else: params.update({"default system prompt filename": None}) with gr.Row(): enable = gr.Checkbox(value=lambda: params['enable'], label='Enable LLM web search') use_cpu_only = gr.Checkbox(value=lambda: params['cpu only'], label='Run extension on CPU only ' '(Save settings and restart for the change to take effect)') with gr.Column(): save_settings_btn = gr.Button("Save settings") saved_success_elem = gr.HTML("", visible=False) with gr.Row(): result_radio = gr.Radio( ["Regular results", "Regular results and instant answers"], label="What kind of search results should be returned?", value=lambda: "Regular results and instant answers" if (params["regular search results"] and params["instant answers"]) else "Regular results" ) with gr.Column(): search_command_regex = gr.Textbox(label="Search command regex string", placeholder=params["default search command regex"], value=lambda: params["search command regex"]) search_command_regex_error_label = gr.HTML("", visible=False) with gr.Column(): open_url_command_regex = gr.Textbox(label="Open URL command regex string", placeholder=params["default open url command regex"], value=lambda: params["open url command regex"]) open_url_command_regex_error_label = gr.HTML("", visible=False) with gr.Column(): show_results = gr.Checkbox(value=lambda: params['display search results in chat'], label='Display search results in chat') show_url_content = gr.Checkbox(value=lambda: params['display extracted URL content in chat'], label='Display extracted URL content in chat') gr.Markdown(value='---') with gr.Row(): with gr.Column(): gr.Markdown(value='#### Load custom system message\n' 'Select a saved custom system message from within the system_prompts folder or "None" ' 'to clear the selection') system_prompt = gr.Dropdown( choices=get_available_system_prompts(), label="Select custom system message", value=lambda: 'Select custom system message to load...' if custom_system_message_filename is None else custom_system_message_filename, elem_classes='slim-dropdown') with gr.Row(): set_system_message_as_default = gr.Checkbox( value=lambda: custom_system_message_filename == params["default system prompt filename"], label='Set this custom system message as the default') refresh_button = ui_module.create_refresh_button(system_prompt, lambda: None, lambda: {'choices': get_available_system_prompts()}, 'refresh-button', interactive=True) refresh_button.elem_id = "custom-sysprompt-refresh" delete_button = gr.Button('🗑️', elem_classes='refresh-button', interactive=True) append_datetime = gr.Checkbox(value=lambda: params['append current datetime'], label='Append current date and time when loading custom system message') with gr.Column(): gr.Markdown(value='#### Create custom system message') system_prompt_text = gr.Textbox(label="Custom system message", lines=3, value=lambda: load_system_prompt(custom_system_message_filename)) sys_prompt_filename = gr.Text(label="Filename") sys_prompt_save_button = gr.Button("Save Custom system message") system_prompt_saved_success_elem = gr.HTML("", visible=False) gr.Markdown(value='---') with gr.Accordion("Advanced settings", open=False): ensemble_weighting = gr.Slider(minimum=0, maximum=1, step=0.05, value=lambda: params["ensemble weighting"], label="Ensemble Weighting", info="Smaller values = More keyword oriented, " "Larger values = More focus on semantic similarity") with gr.Row(): keyword_retriever = gr.Radio([("Okapi BM25", "bm25"),("SPLADE", "splade")], label="Sparse keyword retriever", info="For change to take effect, toggle the extension off and on again", value=lambda: params["keyword retriever"]) splade_batch_size = gr.Slider(minimum=2, maximum=256, step=2, value=lambda: params["splade batch size"], label="SPLADE batch size", info="Smaller values = Slower retrieval (but lower VRAM usage), " "Larger values = Faster retrieval (but higher VRAM usage). " "A good trade-off seems to be setting it = 8", precision=0) with gr.Row(): chunker = gr.Radio([("Character-based", "character-based"), ("Semantic", "semantic")], label="Chunking method", value=lambda: params["chunking method"]) chunker_breakpoint_threshold_amount = gr.Slider(minimum=1, maximum=100, step=1, value=lambda: params["chunker breakpoint_threshold_amount"], label="Semantic chunking: sentence split threshold (%)", info="Defines how different two consecutive sentences have" " to be for them to be split into separate chunks", precision=0) gr.Markdown("**Note: Changing the following might result in DuckDuckGo rate limiting or the LM being overwhelmed**") num_search_results = gr.Number(label="Max. search results to return per query", minimum=1, maximum=100, value=lambda: params["search results per query"], precision=0) num_process_search_results = gr.Number(label="Number of search results to process per query", minimum=1, maximum=100, value=lambda: params["duckduckgo results per query"], precision=0) langchain_similarity_threshold = gr.Number(label="Langchain Similarity Score Threshold", minimum=0., maximum=1., value=lambda: params["langchain similarity score threshold"]) chunk_size = gr.Number(label="Max. chunk size", info="The maximal size of the individual chunks that each webpage will" " be split into, in characters", minimum=2, maximum=10000, value=lambda: params["chunk size"], precision=0) with gr.Row(): searxng_url = gr.Textbox(label="SearXNG URL", value=lambda: params["searxng url"]) # Event functions to update the parameters in the backend enable.input(toggle_extension, enable, enable) use_cpu_only.change(lambda x: params.update({"cpu only": x}), use_cpu_only, None) save_settings_btn.click(save_settings, None, [saved_success_elem]) ensemble_weighting.change(lambda x: params.update({"ensemble weighting": x}), ensemble_weighting, None) keyword_retriever.change(lambda x: params.update({"keyword retriever": x}), keyword_retriever, None) splade_batch_size.change(lambda x: params.update({"splade batch size": x}), splade_batch_size, None) chunker.change(lambda x: params.update({"chunking method": x}), chunker, None) chunker_breakpoint_threshold_amount.change(lambda x: params.update({"chunker breakpoint_threshold_amount": x}), chunker_breakpoint_threshold_amount, None) num_search_results.change(lambda x: params.update({"search results per query": x}), num_search_results, None) num_process_search_results.change(lambda x: params.update({"duckduckgo results per query": x}), num_process_search_results, None) langchain_similarity_threshold.change(lambda x: params.update({"langchain similarity score threshold": x}), langchain_similarity_threshold, None) chunk_size.change(lambda x: params.update({"chunk size": x}), chunk_size, None) result_radio.change(update_result_type_setting, result_radio, None) search_command_regex.change(lambda x: update_regex_setting(x, "search command regex", search_command_regex_error_label), search_command_regex, search_command_regex_error_label, show_progress="hidden") open_url_command_regex.change(lambda x: update_regex_setting(x, "open url command regex", open_url_command_regex_error_label), open_url_command_regex, open_url_command_regex_error_label, show_progress="hidden") show_results.change(lambda x: params.update({"display search results in chat": x}), show_results, None) show_url_content.change(lambda x: params.update({"display extracted URL content in chat": x}), show_url_content, None) searxng_url.change(lambda x: params.update({"searxng url": x}), searxng_url, None) delete_button.click( lambda x: x, system_prompt, gradio('delete_filename')).then( lambda: os.path.join(extension_path, "system_prompts", ""), None, gradio('delete_root')).then( lambda: gr.update(visible=True), None, gradio('file_deleter')) shared.gradio['delete_confirm'].click( lambda: "None", None, system_prompt).then( None, None, None, _js="() => { document.getElementById('custom-sysprompt-refresh').click() }") system_prompt.change(load_system_prompt, system_prompt, shared.gradio['custom_system_message']) system_prompt.change(load_system_prompt, system_prompt, system_prompt_text) # restore checked state if chosen system prompt matches the default system_prompt.change(lambda x: x == params["default system prompt filename"], system_prompt, set_system_message_as_default) sys_prompt_filename.change(check_file_exists, sys_prompt_filename, system_prompt_saved_success_elem) sys_prompt_save_button.click(save_system_prompt, [sys_prompt_filename, system_prompt_text], system_prompt_saved_success_elem, show_progress="hidden").then(timeout_save_message, None, system_prompt_saved_success_elem, _js="() => { document.getElementById('custom-sysprompt-refresh').click() }", show_progress="hidden").then(lambda: "", None, sys_prompt_filename, show_progress="hidden") append_datetime.change(lambda x: params.update({"append current datetime": x}), append_datetime, None) # '.input' = only triggers when user changes the value of the component, not a function set_system_message_as_default.input(update_default_custom_system_message, set_system_message_as_default, None) # A dummy checkbox to enable the actual "Force web search" checkbox to trigger a gradio event force_search_checkbox = gr.Checkbox(value=False, visible=False, elem_id="Force-search-checkbox") force_search_checkbox.change(toggle_forced_search, force_search_checkbox, None) def custom_generate_reply(question, original_question, seed, state, stopping_strings, is_chat): """ Overrides the main text generation function. :return: """ global update_history, langchain_compressor if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']: generate_func = generate_reply_custom else: generate_func = generate_reply_HF if not params['enable']: for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat): yield reply return web_search = False read_webpage = False max_search_results = int(params["search results per query"]) instant_answers = params["instant answers"] # regular_search_results = params["regular search results"] langchain_compressor.num_results = int(params["duckduckgo results per query"]) langchain_compressor.similarity_threshold = params["langchain similarity score threshold"] langchain_compressor.chunk_size = params["chunk size"] langchain_compressor.ensemble_weighting = params["ensemble weighting"] langchain_compressor.splade_batch_size = params["splade batch size"] langchain_compressor.chunking_method = params["chunking method"] langchain_compressor.chunker_breakpoint_threshold_amount = params["chunker breakpoint_threshold_amount"] search_command_regex = params["search command regex"] open_url_command_regex = params["open url command regex"] searxng_url = params["searxng url"] display_search_results = params["display search results in chat"] display_webpage_content = params["display extracted URL content in chat"] if search_command_regex == "": search_command_regex = params["default search command regex"] if open_url_command_regex == "": open_url_command_regex = params["default open url command regex"] compiled_search_command_regex = re.compile(search_command_regex) compiled_open_url_command_regex = re.compile(open_url_command_regex) if force_search: question += f" {params['force search prefix']}" reply = None for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat): if force_search: reply = params["force search prefix"] + reply search_re_match = compiled_search_command_regex.search(reply) if search_re_match is not None: yield reply original_model_reply = reply web_search = True search_term = search_re_match.group(1) print(f"LLM_Web_search | Searching for {search_term}...") reply += "\n```plaintext" reply += "\nSearch tool:\n" if searxng_url == "": search_generator = Generator(langchain_search_duckduckgo(search_term, langchain_compressor, max_search_results, instant_answers)) else: search_generator = Generator(langchain_search_searxng(search_term, searxng_url, langchain_compressor, max_search_results)) try: for status_message in search_generator: yield original_model_reply + f"\n*{status_message}*" search_results = search_generator.value except Exception as exc: exception_message = str(exc) reply += f"The search tool encountered an error: {exception_message}" print(f'LLM_Web_search | {search_term} generated an exception: {exception_message}') else: if search_results != "": reply += search_results else: reply += f"\nThe search tool did not return any results." reply += "```" if display_search_results: yield reply break open_url_re_match = compiled_open_url_command_regex.search(reply) if open_url_re_match is not None: yield reply original_model_reply = reply read_webpage = True url = open_url_re_match.group(1) print(f"LLM_Web_search | Reading {url}...") reply += "\n```plaintext" reply += "\nURL opener tool:\n" try: webpage_content = get_webpage_content(url) except Exception as exc: reply += f"Couldn't open {url}. Error message: {str(exc)}" print(f'LLM_Web_search | {url} generated an exception: {str(exc)}') else: reply += f"\nText content of {url}:\n" reply += webpage_content reply += "```\n" if display_webpage_content: yield reply break yield reply if web_search or read_webpage: display_results = web_search and display_search_results or read_webpage and display_webpage_content # Add results to context and continue model output new_question = chat.generate_chat_prompt(f"{question}{reply}", state) new_reply = "" for new_reply in generate_func(new_question, new_question, seed, state, stopping_strings, is_chat=is_chat): if display_results: yield f"{reply}\n{new_reply}" else: yield f"{original_model_reply}\n{new_reply}" if not display_results: update_history = [state["textbox"], f"{reply}\n{new_reply}"] def output_modifier(string, state, is_chat=False): """ Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. :param string: :param state: :param is_chat: :return: """ return string def custom_css(): """ Returns custom CSS as a string. It is applied whenever the web UI is loaded. :return: """ return '' def custom_js(): """ Returns custom javascript as a string. It is applied whenever the web UI is loaded. :return: """ with open(os.path.join(extension_path, "script.js"), "r") as f: return f.read() def chat_input_modifier(text, visible_text, state): """ Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. :param text: :param visible_text: :param state: :return: """ return text, visible_text def state_modifier(state): """ Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. :param state: :return: """ return state def history_modifier(history): """ Modifies the chat history before the text generation in chat mode begins. :param history: :return: """ global update_history if update_history: history["internal"].append(update_history) update_history = None return history