Update services/llama_generator.py
Browse files- services/llama_generator.py +1 -11
services/llama_generator.py
CHANGED
@@ -45,19 +45,9 @@ class LlamaGenerator(BaseGenerator):
|
|
45 |
):
|
46 |
print(llama_model_name)
|
47 |
print(prm_model_path)
|
48 |
-
@observe()
|
49 |
-
def load_model(self, model_name: str):
|
50 |
-
# Code to load your model, e.g., Hugging Face's transformers library
|
51 |
-
from transformers import AutoModelForCausalLM
|
52 |
-
return AutoModelForCausalLM.from_pretrained(model_name)
|
53 |
|
54 |
-
@observe()
|
55 |
-
def load_tokenizer(self, model_name: str):
|
56 |
-
# Load the tokenizer associated with the model
|
57 |
-
from transformers import AutoTokenizer
|
58 |
-
return AutoTokenizer.from_pretrained(model_name)
|
59 |
|
60 |
-
self.tokenizer = load_tokenizer(llama_model_name) # Add this line to initialize the tokenizer
|
61 |
|
62 |
super().__init__(
|
63 |
llama_model_name,
|
|
|
45 |
):
|
46 |
print(llama_model_name)
|
47 |
print(prm_model_path)
|
|
|
|
|
|
|
|
|
|
|
48 |
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
+
self.tokenizer = model_manager.load_tokenizer(llama_model_name) # Add this line to initialize the tokenizer
|
51 |
|
52 |
super().__init__(
|
53 |
llama_model_name,
|