Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -11,7 +11,7 @@ if torch.cuda.is_available():
|
|
11 |
model_id,
|
12 |
torch_dtype=torch.float16,
|
13 |
device_map='cuda',
|
14 |
-
)
|
15 |
else:
|
16 |
model = AutoModelForCausalLM.from_pretrained(
|
17 |
model_id,
|
@@ -44,7 +44,7 @@ def run(message: str,
|
|
44 |
chat_history: list[tuple[str, str]],
|
45 |
system_prompt: str,
|
46 |
max_new_tokens: int = 1024,
|
47 |
-
temperature: float = 0.
|
48 |
top_p: float = 0.95,
|
49 |
top_k: int = 50) -> Iterator[str]:
|
50 |
prompt = get_prompt(message, chat_history, system_prompt)
|
@@ -62,7 +62,6 @@ def run(message: str,
|
|
62 |
top_p=top_p,
|
63 |
top_k=top_k,
|
64 |
temperature=temperature,
|
65 |
-
num_beams=1,
|
66 |
eos_token_id=tokenizer.eos_token_id,
|
67 |
pad_token_id=tokenizer.pad_token_id,
|
68 |
)
|
|
|
11 |
model_id,
|
12 |
torch_dtype=torch.float16,
|
13 |
device_map='cuda',
|
14 |
+
).to("cuda")
|
15 |
else:
|
16 |
model = AutoModelForCausalLM.from_pretrained(
|
17 |
model_id,
|
|
|
44 |
chat_history: list[tuple[str, str]],
|
45 |
system_prompt: str,
|
46 |
max_new_tokens: int = 1024,
|
47 |
+
temperature: float = 0.2,
|
48 |
top_p: float = 0.95,
|
49 |
top_k: int = 50) -> Iterator[str]:
|
50 |
prompt = get_prompt(message, chat_history, system_prompt)
|
|
|
62 |
top_p=top_p,
|
63 |
top_k=top_k,
|
64 |
temperature=temperature,
|
|
|
65 |
eos_token_id=tokenizer.eos_token_id,
|
66 |
pad_token_id=tokenizer.pad_token_id,
|
67 |
)
|