Donghao Huang committed · Commit 6b469d2 · Parent(s): 571afe2

fixed bug on llama-2

Files changed:
- app_modules/llm_inference.py +9 -8
- app_modules/llm_loader.py +1 -1
- test.py +5 -1
app_modules/llm_inference.py
CHANGED
@@ -35,7 +35,12 @@ class LLMInference(metaclass=abc.ABCMeta):
         return self.chain
 
     def call_chain(
-        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
+        self,
+        inputs,
+        streaming_handler,
+        q: Queue = None,
+        tracing: bool = False,
+        testing: bool = False,
     ):
         print(inputs)
         if self.llm_loader.streamer.for_huggingface:
@@ -46,11 +51,7 @@ class LLMInference(metaclass=abc.ABCMeta):
 
         chain = self.get_chain(tracing)
         result = (
-            self._run_chain(
-                chain,
-                inputs,
-                streaming_handler,
-            )
+            self._run_chain(chain, inputs, streaming_handler, testing)
             if streaming_handler is not None
             else chain(inputs)
         )
@@ -74,7 +75,7 @@ class LLMInference(metaclass=abc.ABCMeta):
     def _execute_chain(self, chain, inputs, q, sh):
        q.put(chain(inputs, callbacks=[sh]))
 
-    def _run_chain(self, chain, inputs, streaming_handler):
+    def _run_chain(self, chain, inputs, streaming_handler, testing):
         que = Queue()
 
         t = Thread(
@@ -83,7 +84,7 @@ class LLMInference(metaclass=abc.ABCMeta):
         )
         t.start()
 
-        if self.llm_loader.streamer.for_huggingface:
+        if self.llm_loader.streamer.for_huggingface and not testing:
             count = (
                 2
                 if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
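For context on the new `testing` flag: `_run_chain` executes the chain on a worker thread and hands the result back through a queue, while the main thread normally drains the HuggingFace streamer as tokens arrive; the flag lets test runs skip that streamer bookkeeping. A minimal sketch of the pattern, assuming only a callable `chain` that accepts `callbacks` (the streamer-draining loop of the real class is omitted):

from queue import Queue
from threading import Thread


def run_chain(chain, inputs, streaming_handler, testing=False):
    # Run the chain on a worker thread; the result comes back via the queue.
    que = Queue()
    t = Thread(
        target=lambda: que.put(chain(inputs, callbacks=[streaming_handler]))
    )
    t.start()
    if not testing:
        # The real class drains self.llm_loader.streamer here while the
        # worker runs; skipped when `testing` is set (and in this sketch).
        pass
    t.join()
    return que.get()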
app_modules/llm_loader.py
CHANGED
@@ -227,6 +227,7 @@ class LLMLoader:
             if "gpt4all-j" in MODEL_NAME_OR_PATH
             or "dolly" in MODEL_NAME_OR_PATH
             or "Qwen" in MODEL_NAME_OR_PATH
+            or "Llama-2" in MODEL_NAME_OR_PATH
             else 0
         )
         use_fast = (
@@ -452,7 +453,6 @@ class LLMLoader:
             top_p=0.95,
             top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
             repetition_penalty=1.115,
-            use_auth_token=token,
             token=token,
         )
     )
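Dropping `use_auth_token=token` matches the `transformers` deprecation: recent releases take `token=` instead, and supplying both arguments can raise a ValueError. A minimal sketch of the surviving call shape, with a placeholder model name and token (assumptions for illustration, not taken from this repo):

from transformers import AutoTokenizer

# `token` supersedes the deprecated `use_auth_token`; passing both
# can raise ValueError in recent transformers releases.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # gated repo, needs an access token
    token="hf_...",  # placeholder; substitute a real Hugging Face token
)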
test.py
CHANGED
@@ -69,7 +69,11 @@ while True:
 
     start = timer()
     result = qa_chain.call_chain(
-        {"question": query, "chat_history": chat_history}, custom_handler
+        {"question": query, "chat_history": chat_history},
+        custom_handler,
+        None,
+        False,
+        True,
     )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
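The trailing positional arguments `None, False, True` bind to `q`, `tracing`, and `testing` in the new `call_chain` signature. An equivalent keyword-argument form makes the intent of the test call explicit:

result = qa_chain.call_chain(
    {"question": query, "chat_history": chat_history},
    custom_handler,
    q=None,  # no external result queue
    tracing=False,  # no tracing for this run
    testing=True,  # skip the HuggingFace streamer bookkeeping in _run_chain
)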