GPTfree api
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -2,30 +2,30 @@ import os
|
|
2 |
import tempfile
|
3 |
from vllm import LLM
|
4 |
from vllm.sampling_params import SamplingParams
|
5 |
-
from huggingface_hub import hf_hub_download
|
6 |
from datetime import datetime, timedelta
|
7 |
|
8 |
# モデル名
|
9 |
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
10 |
|
11 |
-
# SYSTEM_PROMPTのロード関数
|
12 |
-
def load_system_prompt(
|
13 |
-
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
14 |
with open(file_path, 'r') as file:
|
15 |
system_prompt = file.read()
|
16 |
today = datetime.today().strftime('%Y-%m-%d')
|
17 |
yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
|
18 |
-
model_name = repo_id.split("/")[-1]
|
19 |
return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
|
20 |
|
|
|
|
|
|
|
|
|
21 |
# 一時ディレクトリをキャッシュ用に設定
|
22 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
23 |
os.environ["TRANSFORMERS_CACHE"] = tmpdirname
|
24 |
os.environ["HF_HOME"] = tmpdirname
|
25 |
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") # 環境変数からトークンを取得
|
26 |
|
27 |
-
|
28 |
-
|
29 |
messages = [
|
30 |
{"role": "system", "content": SYSTEM_PROMPT},
|
31 |
{
|
@@ -40,4 +40,5 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
40 |
llm = LLM(model=model_name, trust_remote_code=True, tensor_parallel_size=1, device="cpu")
|
41 |
outputs = llm.chat(messages, sampling_params=sampling_params)
|
42 |
|
|
|
43 |
print(outputs[0].outputs[0].text)
|
|
|
2 |
import tempfile
|
3 |
from vllm import LLM
|
4 |
from vllm.sampling_params import SamplingParams
|
|
|
5 |
from datetime import datetime, timedelta
|
6 |
|
7 |
# モデル名
|
8 |
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
9 |
|
10 |
+
# SYSTEM_PROMPTのロード関数 (ローカルファイルを読み込む)
|
11 |
+
def load_system_prompt(file_path: str) -> str:
|
|
|
12 |
with open(file_path, 'r') as file:
|
13 |
system_prompt = file.read()
|
14 |
today = datetime.today().strftime('%Y-%m-%d')
|
15 |
yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
|
|
|
16 |
return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
|
17 |
|
18 |
# SYSTEM_PROMPT.txt is expected to live in the current working directory
system_prompt_path = "./SYSTEM_PROMPT.txt"  # path to the template file in the current directory
SYSTEM_PROMPT = load_system_prompt(system_prompt_path)
21 |
+
|
22 |
# 一時ディレクトリをキャッシュ用に設定
|
23 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
24 |
os.environ["TRANSFORMERS_CACHE"] = tmpdirname
|
25 |
os.environ["HF_HOME"] = tmpdirname
|
26 |
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") # 環境変数からトークンを取得
|
27 |
|
28 |
+
# メッセージリストの作成
|
|
|
29 |
messages = [
|
30 |
{"role": "system", "content": SYSTEM_PROMPT},
|
31 |
{
|
|
|
40 |
llm = LLM(model=model_name, trust_remote_code=True, tensor_parallel_size=1, device="cpu")
|
41 |
outputs = llm.chat(messages, sampling_params=sampling_params)
|
42 |
|
43 |
+
# 出力を表示
|
44 |
print(outputs[0].outputs[0].text)
|