from huggingface_hub import InferenceClient
from dotenv import load_dotenv
import configparser
import os


class LLMManager:
    """Wraps a Hugging Face InferenceClient and the model/prompt configuration in config.ini."""

    def __init__(self, settings):
        # Load the Hugging Face token from a .env file if one is present,
        # otherwise fall back to the existing environment.
        try:
            load_dotenv()
        except Exception:
            print("No .env file")

        HF_TOKEN = os.environ.get('HF_TOKEN')
        self.client = InferenceClient(token=HF_TOKEN)

        # Model ids, display names and prompt templates live in config.ini.
        self.config = configparser.ConfigParser()
        self.config.read("config.ini")

        self.set = settings

        self.defaultLLM = self.set.defaultLLM

        self.listLLM = self.get_llm()
        self.listLLMMap = self.get_llm_map()

        self.currentLLM = self.listLLM[self.defaultLLM]

    def selectLLM(self, llm):
        """Switch the active model, given its display name from the LLM_Map section."""
        print("Selected {llm} LLM".format(llm=llm))
        llmIndex = self.listLLMMap.index(llm)
        self.currentLLM = self.listLLM[llmIndex]

    def get_llm(self):
        """Return the list of model ids from the [LLM] section of config.ini."""
        llm_section = 'LLM'
        if llm_section in self.config:
            return [self.config.get(llm_section, llm) for llm in self.config[llm_section]]
        else:
            return []

    def get_llm_prompts(self):
        """Return the list of prompt templates from the [Prompt_map] section of config.ini."""
        prompt_section = 'Prompt_map'
        if prompt_section in self.config:
            return [self.config.get(prompt_section, llm) for llm in self.config[prompt_section]]
        else:
            return []

    def get_llm_map(self):
        """Return the list of model display names from the [LLM_Map] section of config.ini."""
        llm_map_section = 'LLM_Map'
        if llm_map_section in self.config:
            return [self.config.get(llm_map_section, llm) for llm in self.config[llm_map_section]]
        else:
            return []

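    # The three accessors above assume a config.ini shaped roughly like the sketch
    # below (hypothetical keys, model id and template; the real file is project
    # specific). The [LLM], [LLM_Map] and [Prompt_map] sections are expected to
    # list their entries in the same order, since the lists are cross-indexed
    # positionally by selectLLM() and get_prompt():
    #
    #   [LLM]
    #   llm1 = mistralai/Mistral-7B-Instruct-v0.2
    #
    #   [LLM_Map]
    #   llm1 = Mistral 7B
    #
    #   [Prompt_map]
    #   llm1 = [INST] {sys_prompt} [/INST]
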
    def get_text(self, question):
        """Generate a completion for the fully formatted prompt with the current LLM."""
        print("temp={temp}".format(temp=self.set.temperature))
        print("Repetition={rep}".format(rep=self.set.repetition_penalty))
        generate_kwargs = dict(
            temperature=self.set.temperature,
            max_new_tokens=self.set.max_new_token,
            top_p=self.set.top_p,
            repetition_penalty=self.set.repetition_penalty,
            do_sample=True,
            seed=42,
        )

        # stream=False, so this returns the generated text as a single string.
        response = self.client.text_generation(model=self.currentLLM, prompt=question,
                                               **generate_kwargs, stream=False,
                                               details=False, return_full_text=False)
        return response

    def get_query_terms(self, question):
        """Generate query terms using the RAG-specific sampling settings."""
        generate_kwargs = dict(
            temperature=self.set.RAG_temperature,
            max_new_tokens=self.set.RAG_max_new_token,
            top_p=self.set.RAG_top_p,
            repetition_penalty=self.set.RAG_repetition_penalty,
            do_sample=True,
        )
        response = self.client.text_generation(model=self.currentLLM, prompt=question,
                                               **generate_kwargs, stream=False,
                                               details=False, return_full_text=False)
        return response

    def get_prompt(self, user_input, rag_contex, chat_history, system_prompt=None):
        """Return the prompt for the current LLM, formatted with the system prompt,
        RAG context, chat history and user question."""
        prompts = self.get_llm_prompts()
        prompt = ""

        if system_prompt is None:
            system_prompt = self.set.system_prompt
        else:
            print("System prompt set to:\n {sys_prompt}".format(sys_prompt=system_prompt))

        try:
            # Use the template that matches the current model's position in the LLM list.
            prompt = prompts[self.listLLM.index(self.currentLLM)].format(sys_prompt=system_prompt)
        except Exception:
            print("Warning: prompt map for {llm} has not been defined".format(llm=self.currentLLM))
            prompt = "{sys_prompt}".format(sys_prompt=system_prompt)

        print("Prompt={pro}".format(pro=prompt))
        return prompt.format(context=rag_contex, history=chat_history, question=user_input)
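

# Minimal usage sketch (an illustration, not part of the class): the settings object
# is assumed to expose the attributes read above (defaultLLM, temperature,
# max_new_token, top_p, repetition_penalty, system_prompt and their RAG_* variants);
# the values and the question are hypothetical, and running this needs a config.ini
# plus an HF_TOKEN in the environment or a .env file.
if __name__ == "__main__":
    from types import SimpleNamespace

    settings = SimpleNamespace(
        defaultLLM=0,
        temperature=0.7,
        max_new_token=512,
        top_p=0.95,
        repetition_penalty=1.1,
        system_prompt="You are a helpful assistant.",
        RAG_temperature=0.2,
        RAG_max_new_token=64,
        RAG_top_p=0.9,
        RAG_repetition_penalty=1.0,
    )

    manager = LLMManager(settings)
    prompt = manager.get_prompt("What is retrieval-augmented generation?",
                                rag_contex="", chat_history="")
    print(manager.get_text(prompt))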