"""This file should be imported only and only if you want to run the UI locally.""" import itertools import logging from pathlib import Path import subprocess from typing import Any import os from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import gradio as gr from fastapi import FastAPI from gradio.themes.utils.colors import slate from llama_index.llms import MessageRole, ChatMessage from app._config import settings from app.components.embedding.component import EmbeddingComponent from app.components.llm.component import LLMComponent from app.components.node_store.component import NodeStoreComponent from app.components.vector_store.component import VectorStoreComponent from app.enums import PROJECT_ROOT_PATH from app.server.chat.service import ChatService from app.server.ingest.service import IngestService from app.ui.schemas import Source from app.paths import local_data_path logger = logging.getLogger(__name__) THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH) AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "dodge_ava.jpg" UI_TAB_TITLE = "Agriculture Chatbot" SOURCES_SEPARATOR = "\n\n Sources: \n" model_name = "VietAI/envit5-translation" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSeq2SeqLM.from_pretrained(model_name) class PrivateGptUi: def __init__( self, ingest_service: IngestService, chat_service: ChatService, ) -> None: self._ingest_service = ingest_service self._chat_service = chat_service # Cache the UI blocks self._ui_block = None # Initialize system prompt self._system_prompt = self._get_default_system_prompt() def _chat( self, message: str, history: list[list[str]], upload_button: Any, system_prompt_input: Any, # show_image: bool, ) -> Any: # logger.info(f"Show image = {show_image}") if "#ảnh" in message: message = message.replace("#ảnh","") vi_message = "vi: " + message outputs = model.generate(tokenizer([vi_message], return_tensors="pt", padding=True).input_ids, max_length=512) en_message = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace('en:','') command = f""" cd {local_data_path} clip-retrieval filter --query "{en_message}" --output_folder "retrieved_folder" --indice_folder "index_folder" --num_results 1 """ logger.info(command) subprocess.run(command, shell=True, check=True) folder_path = f"{local_data_path}/retrieved_folder" files = os.listdir(folder_path) # sort images by most lately retrieved. 
        # logger.info(f"Show image = {show_image}")
        if "#ảnh" in message:
            message = message.replace("#ảnh", "")
            vi_message = "vi: " + message
            # Translate the Vietnamese query to English for CLIP retrieval.
            outputs = model.generate(
                tokenizer([vi_message], return_tensors="pt", padding=True).input_ids,
                max_length=512,
            )
            en_message = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace("en:", "")

            command = f"""
            cd {local_data_path}
            clip-retrieval filter --query "{en_message}" --output_folder "retrieved_folder" --indice_folder "index_folder" --num_results 1
            """
            logger.info(command)
            subprocess.run(command, shell=True, check=True)

            folder_path = f"{local_data_path}/retrieved_folder"
            files = os.listdir(folder_path)
            # Sort images by most recently retrieved; keep the old images
            # so they can still be shown in the chat history.
            files.sort(
                key=lambda x: os.path.getctime(os.path.join(folder_path, x)),
                reverse=True,
            )
            newest_image = files[0]
            logger.info(f"Retrieved image {newest_image}")
            return (os.path.relpath(f"{folder_path}/{newest_image}", PROJECT_ROOT_PATH),)

        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain(
                    *[
                        [
                            ChatMessage(content=interaction[0], role=MessageRole.USER),
                            ChatMessage(
                                # Remove the Sources information from the history content
                                content=(
                                    "[Image Output]"
                                    if isinstance(interaction[1], tuple)
                                    else interaction[1].split(SOURCES_SEPARATOR)[0]
                                ),
                                role=MessageRole.ASSISTANT,
                            ),
                        ]
                        for interaction in history
                    ]
                )
            )

            # Max 20 messages, to try to avoid context overflow
            return history_messages[:20]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, add it as a system message
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )
        completion = self._chat_service.chat(messages=all_messages)
        full_response = completion.response
        if completion.sources:
            full_response += SOURCES_SEPARATOR
            curated_sources = Source.curate_sources(completion.sources)
            sources_text = "\n\n\n".join(
                f"{index}. {source.file} (page {source.page})"
                for index, source in enumerate(curated_sources, start=1)
            )
            full_response += sources_text
        return full_response

    # On initialization this function sets the system prompt
    # to the default prompt based on settings.
    @staticmethod
    def _get_default_system_prompt() -> str:
        return settings.DEFAULT_QUERY_SYSTEM_PROMPT

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info(f"Setting system prompt to: {system_prompt_input}")
        self._system_prompt = system_prompt_input

    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skip documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]
        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])

    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }",
        ) as blocks:
            with gr.Row():
                gr.HTML(f"