omer.danziger committed
Commit e766b50 · Parent(s): 29d1cfb
set model and context from outside the model
LLM.py
CHANGED
@@ -33,18 +33,20 @@ class LLM:
     MODEL = "mosaicml/mpt-7b-chat"
     CONSTEXT = "You are an helpfull assistante in a school. You are helping a student with his homework."
 
-    def __init__(self):
-
+    def __init__(self, model_name=None):
+        if model_name is None:
+            model_name = LLM.MODEL
+        self.load_model(model_name)
 
-    def load_model(self):
+    def load_model(self, model_name):
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         if device == "cuda:0":
-            model = AutoModelForCausalLM.from_pretrained(
+            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,
                 torch_dtype=torch.float16, device_map="auto", load_in_8bit=True)
         else:
-            model = AutoModelForCausalLM.from_pretrained(
+            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
         pipe = pipeline(
             "text-generation",
             model=model,
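Read together, this hunk turns the hard-coded checkpoint into a constructor argument. Below is a minimal sketch of how the updated class plausibly fits together after this commit; the hunk stops at model=model, so the tokenizer argument, the self.pipe attribute and everything past the pipeline() call are assumptions, not part of the diff above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

class LLM:
    MODEL = "mosaicml/mpt-7b-chat"
    CONSTEXT = "You are an helpfull assistante in a school. You are helping a student with his homework."

    def __init__(self, model_name=None):
        # Fall back to the class default when the caller does not pick a checkpoint.
        if model_name is None:
            model_name = LLM.MODEL
        self.load_model(model_name)

    def load_model(self, model_name):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if device == "cuda:0":
            # On GPU, load the weights in 8-bit with fp16 activations to cut memory use.
            model = AutoModelForCausalLM.from_pretrained(
                model_name, trust_remote_code=True,
                torch_dtype=torch.float16, device_map="auto", load_in_8bit=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
        # Assumption: the pipeline is kept on the instance; the diff is cut off here.
        self.pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,  # assumed, not visible in the hunk
        )

With this change, LLM() keeps the old behaviour by falling back to LLM.MODEL, while e.g. LLM("mosaicml/mpt-7b-instruct") would load a different checkpoint (the alternative id is only an illustration).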
app.py
CHANGED
@@ -1,7 +1,10 @@
 from LLM import LLM
 
-
-
+model = "mosaicml/mpt-7b-chat"
+context = "You are an helpfully assistant in a school. You are helping a student with his homework."
+
+llm = LLM(model)
+chat = llm.get_chat(context=context)
 while True:
     qn = input("Question: ")
     if qn == "exit":
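The app.py hunk ends mid-loop at if qn == "exit":. For context, a short sketch of how the whole script plausibly reads after this commit; the break and the chat(qn) call are assumptions about the part of the file outside the hunk, as is printing the answer.

from LLM import LLM

model = "mosaicml/mpt-7b-chat"
context = "You are an helpfully assistant in a school. You are helping a student with his homework."

llm = LLM(model)                       # model id now passed in from the caller
chat = llm.get_chat(context=context)   # context (system prompt) passed in as well
while True:
    qn = input("Question: ")
    if qn == "exit":
        break                          # assumed: leave the loop on "exit"
    print(chat(qn))                    # assumed: the chat object is callable and returns the answer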