# ProHelper / llm.py
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
import configparser
import os


class LLMManager:
    def __init__(self, settings):
        # Load the HF token from a .env file, if one is present
        try:
            load_dotenv()
        except Exception:
            print("No .env file")
        # Initialize the Hugging Face Inference client
        HF_TOKEN = os.environ.get('HF_TOKEN')
        self.client = InferenceClient(token=HF_TOKEN)
        # Create the parser and load the config file
        self.config = configparser.ConfigParser()
        self.config.read("config.ini")
        # Store the settings object
        self.set = settings
        # Default index of the LLM to use
        self.defaultLLM = self.set.defaultLLM
        # Load the available LLMs and their display-name map
        self.listLLM = self.get_llm()
        self.listLLMMap = self.get_llm_map()
        # Set the current model
        self.currentLLM = self.listLLM[self.defaultLLM]
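        # Note (assumption): `settings` is expected to expose the attributes read
        # throughout this class: defaultLLM, temperature, max_new_token, top_p,
        # repetition_penalty, the RAG_* variants used by get_query_terms, and
        # system_prompt.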

    # Function used to select the LLM by its display name
    def selectLLM(self, llm):
        print("Selected {llm} LLM".format(llm=llm))
        llmIndex = self.listLLMMap.index(llm)
        self.currentLLM = self.listLLM[llmIndex]

    # Function used to get a list of available LLMs (model identifiers)
    def get_llm(self):
        llm_section = 'LLM'
        if llm_section in self.config:
            return [self.config.get(llm_section, llm) for llm in self.config[llm_section]]
        else:
            return []

    # Function used to get the list of prompt templates
    def get_llm_prompts(self):
        prompt_section = 'Prompt_map'
        if prompt_section in self.config:
            return [self.config.get(prompt_section, llm) for llm in self.config[prompt_section]]
        else:
            return []

    # Function used to get the list of LLM display names (the LLM map)
    def get_llm_map(self):
        llm_map_section = 'LLM_Map'
        if llm_map_section in self.config:
            return [self.config.get(llm_map_section, llm) for llm in self.config[llm_map_section]]
        else:
            return []
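
    # The three config sections are assumed to be parallel lists: entry i of [LLM]
    # is a model id, entry i of [LLM_Map] is its display name, and entry i of
    # [Prompt_map] is its prompt template. A purely hypothetical config.ini
    # (not the repo's actual file) could look like:
    #
    #   [LLM]
    #   llm1 = mistralai/Mixtral-8x7B-Instruct-v0.1
    #
    #   [LLM_Map]
    #   llm1 = Mixtral 7B
    #
    #   [Prompt_map]
    #   prompt1 = <s>[INST] {sys_prompt}
    #       Context: {{context}}
    #       History: {{history}}
    #       Question: {{question}} [/INST]
    #
    # The later placeholders are written with doubled braces here because
    # get_prompt() fills the template in two passes.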

    # This function is used to retrieve the reply to a question
    def get_text(self, question):
        print("temp={temp}".format(temp=self.set.temperature))
        print("Repetition={rep}".format(rep=self.set.repetition_penalty))
        generate_kwargs = dict(
            temperature=self.set.temperature,
            max_new_tokens=self.set.max_new_token,
            top_p=self.set.top_p,
            repetition_penalty=self.set.repetition_penalty,
            do_sample=True,
            seed=42,
        )
        stream = self.client.text_generation(model=self.currentLLM, prompt=question, **generate_kwargs,
                                             stream=False, details=False, return_full_text=False)
        # output = ""
        return stream
        # for response in stream:
        #     output += response.token.text
        #     yield output
        # return output

    # This function is used to retrieve the best search terms for RAG
    def get_query_terms(self, question):
        generate_kwargs = dict(
            temperature=self.set.RAG_temperature,
            max_new_tokens=self.set.RAG_max_new_token,
            top_p=self.set.RAG_top_p,
            repetition_penalty=self.set.RAG_repetition_penalty,
            do_sample=True,
        )
        stream = self.client.text_generation(model=self.currentLLM, prompt=question, **generate_kwargs,
                                             stream=False, details=False, return_full_text=False)
        return stream

    # This function is used to generate the prompt for the LLM
    def get_prompt(self, user_input, rag_contex, chat_history, system_prompt=None):
        """Returns the formatted prompt for a specific LLM"""
        prompts = self.get_llm_prompts()
        prompt = ""
        if system_prompt is None:
            system_prompt = self.set.system_prompt
        else:
            print("System prompt set to : \n {sys_prompt}".format(sys_prompt=system_prompt))
        try:
            prompt = prompts[self.listLLM.index(self.currentLLM)].format(sys_prompt=system_prompt)
        except Exception:
            print("Warning: prompt map for {llm} has not been defined".format(llm=self.currentLLM))
            prompt = "{sys_prompt}".format(sys_prompt=system_prompt)
        print("Prompt={pro}".format(pro=prompt))
        return prompt.format(context=rag_contex, history=chat_history, question=user_input)
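        # Illustration of the two-pass formatting above, with a hypothetical
        # template (doubled braces survive the first pass as single braces):
        #   "Hi {sys_prompt} {{question}}"                        (raw template)
        #   .format(sys_prompt="Be brief")       -> "Hi Be brief {question}"
        #   .format(context="", history="", question="Why?") -> "Hi Be brief Why?"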

# Example Usage:
# if __name__ == "__main__":
#     llm_manager = LLMManager(settings)
#     print(llm_manager.config.get('Prompt_map', 'prompt1').format(
#         sys_prompt="You are a good AI",
#         history="",
#         context="",
#         question=""))
#     llm_manager.selectLLM("Mixtral 7B")
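
# A minimal, hedged usage sketch (assumptions: a config.ini shaped like the
# hypothetical example above sits next to this file, and HF_TOKEN is available
# in the environment or a .env file; the Settings stand-in below is purely
# illustrative, the real application supplies its own settings object).
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_settings = SimpleNamespace(
        defaultLLM=0,                  # index into the [LLM] section
        temperature=0.7,
        max_new_token=512,
        top_p=0.95,
        repetition_penalty=1.1,
        RAG_temperature=0.2,
        RAG_max_new_token=64,
        RAG_top_p=0.9,
        RAG_repetition_penalty=1.0,
        system_prompt="You are a helpful assistant.",
    )

    manager = LLMManager(demo_settings)
    demo_prompt = manager.get_prompt(
        user_input="What does this class do?",
        rag_contex="",
        chat_history="",
    )
    print(manager.get_text(demo_prompt))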