Ashmi Banerjee
updates with gemini
420fa8a
raw
history blame
594 Bytes
import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
# Load variables from a .env file into the process environment at import time,
# so that gemma_predict can read HF_TOKEN via os.environ.
# NOTE(review): runs as an import side effect — confirm this is intended.
load_dotenv()
def gemma_predict(combined_information, model_name, max_new_tokens=2048):
    """Generate a completion for a prompt using a Hugging Face hosted model.

    Tokens are streamed from ``InferenceClient.text_generation`` and
    concatenated. The generated text is then cut at the first ``"<eos>"``
    marker, and the marker itself is dropped.

    Args:
        combined_information: Prompt text sent to the model.
        model_name: Model identifier passed to ``InferenceClient``.
        max_new_tokens: Maximum number of tokens to generate. The default of
            2048 matches the previously hard-coded value.

    Returns:
        The generated text up to, but not including, the first ``"<eos>"``.
        If no marker appears, the whole generated text is returned.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    hf_token = os.environ["HF_TOKEN"]
    client = InferenceClient(model_name, token=hf_token)
    # With ``return_full_text=False`` only the newly generated text is
    # requested. ``details=True`` makes each streamed item expose ``.token``.
    stream = client.text_generation(
        prompt=combined_information,
        details=True,
        stream=True,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
    )
    # Join once at the end instead of repeated ``str +=``, which is quadratic.
    generated = "".join(response.token.text for response in stream)
    # Keep only the text before the first end-of-sequence marker.
    output, _, _ = generated.partition("<eos>")
    return output