SKB-Explorer / src /tools /api_lib /huggingface.py
shirwu's picture
rephrase
a00d62c
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Process-wide cache of loaded checkpoints: maps the model identifier that
# complete_text_hf passes to `from_pretrained` to its (model, tokenizer) pair,
# so repeated calls with the same model do not reload it.
loaded_hf_models = {}
def complete_text_hf(message,
                     model="huggingface/codellama/CodeLlama-7b-hf",
                     max_tokens=2000,
                     temperature=0.5,
                     json_object=False,
                     max_retry=1,
                     sleep_time=0,
                     stop_sequences=[],
                     **kwargs):
    """Sample a completion for ``message`` from a locally loaded causal LM.

    The model and tokenizer are loaded on first use and then reused from the
    module-level ``loaded_hf_models`` cache.

    Args:
        message: Prompt text to complete.
        model: Identifier of the form ``"<prefix>/<checkpoint>"``. Everything
            up to and including the first ``/`` is dropped and the remainder
            is handed to ``from_pretrained``. A name without any ``/`` is
            used unchanged.
        max_tokens: Forwarded to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature forwarded to ``generate``.
        json_object: When true, a sentence asking for JSON output is
            prepended to ``message``.
        max_retry: Number of generation attempts; must be at least 1.
        sleep_time: Seconds to wait between a failed attempt and the next.
        stop_sequences: Accepted for signature compatibility; it is neither
            read nor mutated by this function.
        **kwargs: Extra keyword arguments forwarded to ``generate``.

    Returns:
        The decoded text of the first generated sequence with the prompt
        tokens removed.

    Raises:
        ValueError: If ``max_retry`` is smaller than 1.
        Exception: The error raised by the last failed attempt once all
            ``max_retry`` attempts have failed.
    """
    # Fail fast, before the (expensive) model load, instead of falling
    # through an empty retry loop with nothing to return or raise.
    if max_retry < 1:
        raise ValueError(f"max_retry must be at least 1, got {max_retry}")

    if json_object:
        message = "You are a helpful assistant designed to output in JSON format." + message

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Drop the leading "<prefix>/" segment; a name without a slash is kept
    # as-is rather than raising IndexError.
    _, separator, remainder = model.partition("/")
    if separator:
        model = remainder

    if model in loaded_hf_models:
        hf_model, tokenizer = loaded_hf_models[model]
    else:
        hf_model = AutoModelForCausalLM.from_pretrained(model).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model)
        loaded_hf_models[model] = (hf_model, tokenizer)

    encoded_input = tokenizer(message,
                              return_tensors="pt",
                              return_token_type_ids=False
                              ).to(device)
    # generate() returns prompt + continuation; this many leading tokens
    # are sliced off each returned sequence.
    prompt_length = len(encoded_input.input_ids[0])

    # The name bound by `except ... as e` is deleted when the handler exits,
    # so the failure is kept here to be re-raised after the loop.
    last_error = None
    for cnt in range(max_retry):
        try:
            output = hf_model.generate(
                **encoded_input,
                temperature=temperature,
                max_new_tokens=max_tokens,
                do_sample=True,
                return_dict_in_generate=True,
                output_scores=True,
                **kwargs,
            )
            sequences = [sequence[prompt_length:] for sequence in output.sequences]
            all_decoded_text = tokenizer.batch_decode(sequences)
            return all_decoded_text[0]
        except Exception as e:
            print(cnt, "=>", e)
            last_error = e
            # Only wait when another attempt will actually follow.
            if cnt + 1 < max_retry:
                time.sleep(sleep_time)
    raise last_error