from mistral_inference.model import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest


def main():
    # Load the v3 tokenizer and the base model weights from the "model" folder.
    tokenizer = MistralTokenizer.from_file("model/tokenizer.model.v3")
    model = Transformer.from_folder("model")
    # Apply the trained LoRA adapter on top of the base weights.
    model.load_lora("lora/lora.safetensors")

    # Build an instruct-style request and encode it to token ids.
    completion_request = ChatCompletionRequest(
        messages=[UserMessage(content="Explain Machine Learning to me in a nutshell.")]
    )
    tokens = tokenizer.encode_chat_completion(completion_request).tokens

    # Greedy decoding (temperature=0.0), capped at 64 generated tokens.
    out_tokens, _ = generate(
        [tokens],
        model,
        max_tokens=64,
        temperature=0.0,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
    print(result)


if __name__ == "__main__":
    main()