import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
#from interface import GemmaLLMInterface
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import Settings, ServiceContext, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, PromptTemplate, load_index_from_storage, StorageContext
from llama_index.core.node_parser import SentenceSplitter
import spaces
from huggingface_hub import login
from llama_index.core.memory import ChatMemoryBuffer
from typing import Iterator, List, Any
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole, CompletionResponse
from IPython.display import Markdown, display
from langchain_huggingface import HuggingFaceEmbeddings
#from llama_index import LangchainEmbedding, ServiceContext
#from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceInferenceAPI, HuggingFaceLLM
from dotenv import load_dotenv
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
# Load environment variables (e.g. HUGGINGFACE_TOKEN) before reading them
load_dotenv()

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
"""huggingface_token = os.getenv("HUGGINGFACE_TOKEN") | |
login(huggingface_token) | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
model_id = "google/gemma-2-2b-it" | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
device_map="auto", | |
torch_dtype= torch.bfloat16 if torch.cuda.is_available() else torch.float32, | |
token=True) | |
tokenizer= AutoTokenizer.from_pretrained("google/gemma-2b-it") | |
model.tokenizer = tokenizer | |
model.eval()""" | |
system_prompt = """
You are a Q&A assistant. Your goal is to answer questions as
accurately as possible based on the instructions and context provided.
"""
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=256,
    generate_kwargs={"temperature": 0.1, "do_sample": True},
    system_prompt=system_prompt,
    tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
    model_name="meta-llama/Llama-2-7b-chat-hf",
    device_map="auto",
    # load the model weights in float16 to reduce memory usage
    model_kwargs={"torch_dtype": torch.float16},
)
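# meta-llama/Llama-2-7b-chat-hf is a gated checkpoint: the login above must use a
# token whose account has been granted access to it, or the weights cannot be downloaded.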
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

Settings.llm = llm
Settings.embed_model = embed_model
#Settings.node_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20, paragraph_separator="\n\n")
Settings.num_output = 512
Settings.context_window = 3900
documents = SimpleDirectoryReader('./data').load_data()
nodes = SentenceSplitter(chunk_size=512, chunk_overlap=20, paragraph_separator="\n\n").get_nodes_from_documents(documents)

# Alternative embedding models / LLMs that could be used by LlamaIndex:
#Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
#Settings.embed_model = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
#Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
#Settings.llm = GemmaLLMInterface()
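# Per-topic source files; at the moment they are only referenced by the
# commented-out clarification flow inside handle_query, not by the active code path.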
documents_paths = {
    'blockchain': 'data/blockchainprova.txt',
    'metaverse': 'data/metaverseprova.txt',
    'payment': 'data/paymentprova.txt'
}

global session_state
session_state = {"index": False,
                 "documents_loaded": False,
                 "document_db": None,
                 "original_message": None,
                 "clarification": False}
PERSIST_DIR = "./db"
os.makedirs(PERSIST_DIR, exist_ok=True)

ISTR = "In italiano, chiedi molto brevemente se la domanda si riferisce agli 'Osservatori Blockchain', 'Osservatori Payment' oppure 'Osservatori Metaverse'."
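# PERSIST_DIR is currently only created, never written to. A minimal sketch
# (an assumption about intended use, kept commented out like the other drafts)
# of how the index could be cached there instead of being rebuilt on every query:
#
# if os.listdir(PERSIST_DIR):
#     storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
#     index = load_index_from_storage(storage_context)
# else:
#     index = VectorStoreIndex(nodes)
#     index.storage_context.persist(persist_dir=PERSIST_DIR)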
############################---------------------------------
# Get the parser
"""parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)

def build_index(path: str):
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=[path]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    #storage_context = StorageContext.from_defaults()
    #index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index"""
def handle_query(query_str: str,
                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
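    """Stream an answer to query_str grounded in the indexed documents.

    chat_history is the list of (user, assistant) turns provided by Gradio; it is
    replayed as ChatMessages so the chat engine keeps the conversational context.
    Note that the vector index is rebuilt from the pre-split nodes on every call;
    the persistence sketch near PERSIST_DIR above shows a cheaper alternative.
    """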
    #index = build_index("data/blockchainprova.txt")
    index = VectorStoreIndex(nodes, show_progress=True)

    conversation: List[ChatMessage] = []
    for user, assistant in chat_history:
        conversation.extend([
            ChatMessage(role=MessageRole.USER, content=user),
            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
        ])
"""if not session_state["index"]: | |
matched_path = None | |
words = query_str.lower() | |
for key, path in documents_paths.items(): | |
if key in words: | |
matched_path = path | |
break | |
if matched_path: | |
index = build_index(matched_path) | |
gr.Info("index costruito con la path sulla base della query") | |
session_state["index"] = True | |
else: ## CHIEDI CHIARIMENTO | |
conversation.append(ChatMessage(role=MessageRole.SYSTEM, content=ISTR)) | |
index = build_index("data/blockchainprova.txt") | |
gr.Info("index costruito con richiesta di chiarimento") | |
else: | |
index = build_index(matched_path) | |
#storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR) | |
#index = load_index_from_storage(storage_context) | |
gr.Info("index is true")""" | |
    try:
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)

        """chat_engine = index.as_chat_engine(
            chat_mode="condense_plus_context",
            memory=memory,
            similarity_top_k=3,
            response_mode="tree_summarize",  # Good for summarization purposes
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
                " Ecco i documenti rilevanti per il contesto:\n"
                "{context_str}"
                "\nIstruzione: Usa la cronologia della chat, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
            verbose=False,
        )"""

        print("chat engine..")
        gr.Info("chat engine..")
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            similarity_top_k=3,
            memory=memory,
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Usa la cronologia della chat, o il contesto fornito, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
        )
"""retriever = index.as_retriever(similarity_top_k=3) | |
# Let's test it out | |
relevant_chunks = relevant_chunks = retriever.retrieve(query_str) | |
print(f"Found: {len(relevant_chunks)} relevant chunks") | |
for idx, chunk in enumerate(relevant_chunks): | |
info_message += f"{idx + 1}) {chunk.text[:64]}...\n" | |
print(info_message) | |
gr.Info(info_message)""" | |
#prompts_dict = chat_engine.get_prompts() | |
#display_prompt_dict(prompts_dict) | |
#chat_engine.reset() | |
        outputs = []
        #response = query_engine.query(query_str)
        response = chat_engine.stream_chat(query_str, chat_history=conversation)
        sources = []  # Use a list to collect multiple sources if present
        #response = chat_engine.chat(query_str)

        for token in response.response_gen:
            # Remove the "assistant:" prefix if the model prepends it to the stream
            if token.startswith("assistant:"):
                token = token[len("assistant:"):]
            outputs.append(token)
            print(f"Generated token: {token}")
            yield "".join(outputs)
        #yield CompletionResponse(text=''.join(outputs), delta=token)
"""if sources: | |
sources_str = ", ".join(sources) | |
outputs.append(f"Fonti utilizzate: {sources_str}") | |
else: | |
outputs.append("Nessuna fonte specifica utilizzata.") | |
yield "".join(outputs)""" | |
except Exception as e: | |
yield f"Error processing query: {str(e)}" | |