Update train_script.py
Browse files - train_script.py (+2 -1)
train_script.py
CHANGED
@@ -28,7 +28,7 @@ model = SentenceTransformer(
|
|
28 |
# 3. Load a dataset to finetune on
|
29 |
dataset = load_dataset("sentence-transformers/gooaq", split="train")
|
30 |
dataset = dataset.add_column("id", range(len(dataset)))
|
31 |
-
dataset_dict = dataset.train_test_split(test_size=10_000)
|
32 |
train_dataset: Dataset = dataset_dict["train"]
|
33 |
eval_dataset: Dataset = dataset_dict["test"]
|
34 |
|
@@ -62,6 +62,7 @@ args = SentenceTransformerTrainingArguments(
|
|
62 |
# 6. (Optional) Create an evaluator & evaluate the base model
|
63 |
# The full corpus, but only the evaluation queries
|
64 |
# corpus = dict(zip(dataset["id"], dataset["answer"]))
|
|
|
65 |
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
|
66 |
corpus = (
|
67 |
{qid: dataset[qid]["answer"] for qid in queries} |
|
|
|
28 |
# 3. Load a dataset to finetune on
|
29 |
dataset = load_dataset("sentence-transformers/gooaq", split="train")
|
30 |
dataset = dataset.add_column("id", range(len(dataset)))
|
31 |
+
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
|
32 |
train_dataset: Dataset = dataset_dict["train"]
|
33 |
eval_dataset: Dataset = dataset_dict["test"]
|
34 |
|
|
|
62 |
# 6. (Optional) Create an evaluator & evaluate the base model
|
63 |
# The full corpus, but only the evaluation queries
|
64 |
# corpus = dict(zip(dataset["id"], dataset["answer"]))
|
65 |
+
random.seed(12)
|
66 |
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
|
67 |
corpus = (
|
68 |
{qid: dataset[qid]["answer"] for qid in queries} |
|