omer.danziger committed on
Commit
e766b50
1 Parent(s): 29d1cfb

set model and context from outside the LLM class

Browse files
Files changed (2) hide show
  1. LLM.py +8 -6
  2. app.py +5 -2
LLM.py CHANGED
@@ -33,18 +33,20 @@ class LLM:
33
  MODEL = "mosaicml/mpt-7b-chat"
34
  CONSTEXT = "You are an helpfull assistante in a school. You are helping a student with his homework."
35
 
36
- def __init__(self):
37
- self.load_model()
 
 
38
 
39
- def load_model(self):
40
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
41
- tokenizer = AutoTokenizer.from_pretrained(LLM.MODEL)
42
 
43
  if device == "cuda:0":
44
- model = AutoModelForCausalLM.from_pretrained(LLM.MODEL, trust_remote_code=True,
45
  torch_dtype=torch.float16, device_map="auto", load_in_8bit=True)
46
  else:
47
- model = AutoModelForCausalLM.from_pretrained(LLM.MODEL, trust_remote_code=True)
48
  pipe = pipeline(
49
  "text-generation",
50
  model=model,
 
33
  MODEL = "mosaicml/mpt-7b-chat"
34
  CONSTEXT = "You are an helpfull assistante in a school. You are helping a student with his homework."
35
 
36
+ def __init__(self, model_name=None):
37
+ if model_name is None:
38
+ model_name = LLM.MODEL
39
+ self.load_model(model_name)
40
 
41
+ def load_model(self, model_name):
42
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
43
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
44
 
45
  if device == "cuda:0":
46
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,
47
  torch_dtype=torch.float16, device_map="auto", load_in_8bit=True)
48
  else:
49
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
50
  pipe = pipeline(
51
  "text-generation",
52
  model=model,
app.py CHANGED
@@ -1,7 +1,10 @@
1
  from LLM import LLM
2
 
3
- llm = LLM()
4
- chat = llm.get_chat(context=LLM.CONSTEXT)
 
 
 
5
  while True:
6
  qn = input("Question: ")
7
  if qn == "exit":
 
1
  from LLM import LLM
2
 
3
+ model = "mosaicml/mpt-7b-chat"
4
+ context = "You are an helpfully assistant in a school. You are helping a student with his homework."
5
+
6
+ llm = LLM(model)
7
+ chat = llm.get_chat(context=context)
8
  while True:
9
  qn = input("Question: ")
10
  if qn == "exit":