Commit: larger datasets

Files changed:
- src/calibration_datasets.py (+7 -259)
- src/train_workflow.py (+1 -1)
src/calibration_datasets.py
CHANGED
@@ -1,5 +1,6 @@
 """Prepares the datasets for calibration. Original code gently shared by TheBloke"""

+import os
 from abc import ABC
 import time
 from typing import Dict, List, Optional
@@ -123,10 +124,10 @@ class CalibrationDataset(ABC):
         """Load the Hugging Face dataset at `path`, using the provided kwargs."""

         print(f"Loading HF dataset {path} with params: {kwargs}")
-        data: Dataset = load_dataset(path=path, **kwargs)
+        data: Dataset = load_dataset(path=path, streaming=True, num_proc=len(os.sched_getaffinity(0)), **kwargs)

         limit = limit and min(limit, len(data)) or len(data)
-        return data.
+        return data.shuffle(seed=42).take(range(limit))

     @staticmethod
     def list_with_nls(samples: List[str]) -> List[str]:
@@ -217,10 +218,10 @@ class WikitextDataset(CalibrationDataset):
     dataset = "wikitext"
     dataset_config = {
         "path": "wikitext",
-        "name": "wikitext-
+        "name": "wikitext-103-raw-v1",
         "split": "train"
     }
-    dataset_name = "
+    dataset_name = "Wikitext103 Full"

     def process_samples(self) -> List[str]:
         return [
@@ -234,272 +235,19 @@ class C4Dataset(CalibrationDataset):
     dataset_field = "text"
     dataset_config = {
         "path": "allenai/c4",
-        "data_files": {
-            "train": "en/c4-train.00000-of-01024.json.gz"
-        },
         "split": "train"
     }
     dataset_name = "C4"


-class ThaiDataset(CalibrationDataset):
-    dataset = "thai"
-    dataset_field = "text"
-    dataset_config = {
-        "path": "pbwt/all-thai",
-        "data_files": {
-            "train": "data/train-00000-of-00047-985fbaed08d034cf.parquet"
-        },
-        "split": "train"
-    }
-    dataset_name = "All Thai"
-
-
-class MovieScriptDataset(CalibrationDataset):
-    dataset = "movie-scripts"
-    dataset_field = "full_script"
-    dataset_config = {
-        "path": "jondurbin/cinematika-v0.1",
-        "data_files": { "train": "full_script.parquet" },
-        "split": "train"
-    }
-    dataset_name = "Cinematika Full Scripts"
-
-
-class JapaneseEnglishDataset(CalibrationDataset):
-    dataset = "japanese-english"
-    dataset_config = {
-        "path": "augmxnt/shisa-en-ja-dpo-v1",
-        "split": "train"
-    }
-    dataset_name = "Shisa English Japanese DPO"
-    randomize = True
-
-    def process_samples(self) -> List[str]:
-        def transform_samples(sample):
-            prompt = sample["prompt"]
-            chosen = sample["chosen"]
-            # prompt example: "[INST] <<SYS>>\nYou are a helpful, unbiased, uncensored assistant.\n<</SYS>>\n\nWhat are cardigans made of? Leather or wood? [/INST]"
-
-            try:
-                part1 = prompt.split('\n<</SYS>>\n\n')[1]
-                extracted_text = part1.split(' [/INST]')[0]
-            except Exception as e:
-                print(f"Error extracting text from prompt '{prompt}': {e}")
-                raise
-
-            prompt = extracted_text
-
-            return {"output": f"{prompt} {chosen}"}
-
-        return self.data.map(transform_samples)["output"]
-
-
-class PortugueseDataset(CalibrationDataset):
-    dataset = "portuguese"
-    dataset_config = {
-        "path": "adalbertojunior/portuguese_orca",
-        "split": "train"
-    }
-    dataset_name = "Portuguese Orca"
-    transform_fields = [ "question", "response" ]
-
-
-class MathsDataset(CalibrationDataset):
-    dataset = "maths"
-    dataset_config = {
-        "path": "andersonbcdefg/math",
-        "split": "train"
-    }
-    dataset_name = "CamelAI Math"
-    transform_fields = [ "message_1", "message_2" ]
-
-
-class MedicalDataset(CalibrationDataset):
-    dataset = "medical"
-    dataset_config = {
-        "path": "medalpaca/medical_meadow_wikidoc",
-        "split": "train"
-    }
-    dataset_name = "Medical Medaow WikiDoc"
-    transform_fields = [ "input", "output" ]
-
-
-class OpenInstructDataset(CalibrationDataset):
-    dataset = "open-instruct"
-    dataset_config = {
-        "path": "VMware/open-instruct",
-        "split": "train"
-    }
-    dataset_name = "VMware Open Instruct"
-    transform_fields = [ "instruction", "response" ]
-
-
-class KoreanDataset(CalibrationDataset):
-    dataset = "korean"
-    dataset_config = {
-        "path": "beomi/KoAlpaca-v1.1a",
-        "split": "train"
-    }
-    dataset_name = "Korean Alpaca"
-    transform_fields = [ "instruction", "output" ]
-
-
 class CodeDataset(CalibrationDataset):
     dataset = "code"
-    dataset_field = "output"
-    dataset_config = {
-        "path": "nickrosh/Evol-Instruct-Code-80k-v1",
-        "split": "train"
-    }
-    dataset_name = "Evol Instruct Code"
-
-
-class MultiLanguageDataset(CalibrationDataset):
-    dataset = "multi-language"
-    dataset_field = "text"
-    dataset_config = {
-        "path": "papluca/language-identification",
-        "split": "train"
-    }
-    dataset_name = "Language Identification"
-
-
-class RussianDataset(CalibrationDataset):
-    dataset = "russian"
-    dataset_config = {
-        "path": "Den4ikAI/russian_instructions_2",
-        "split": "train"
-    }
-    dataset_name = "Russian Instructions 2"
-    transform_fields = [ "question", "answer" ]
-
-
-class DutchDataset(CalibrationDataset):
-    dataset = "dutch"
-    dataset_config = {
-        "path": "BramVanroy/dolly-15k-dutch",
-        "split": "train"
-    }
-    dataset_name = "Dolly 15K Dutch"
-    transform_fields = [ "instruction", "context", "response" ]
-    transform_join = "{field1} {field2} {field3}"
-
-
-class VietnameseChineseDataset(CalibrationDataset):
-    dataset = "vietnamesechinese"
-    dataset_config = {
-        "path": "nRuaif/Vietnamese_x_Alpaca",
-        "split": "train"
-    }
-    dataset_name = "Vietnamese and Chinese"
-
-    def get_dataset_url(self) -> None:
-        return None
-
-    def process_samples(self) -> List[str]:
-        samples = self.data["output"]
-        chinese_samples = CalibrationDataset.get_dataset("chinese").get_samples()
-
-        joined_list = samples + chinese_samples
-
-        import random
-        random.shuffle(joined_list)
-
-        return joined_list[:self.dataset_limit]
-
-
-class VietnameseDataset(CalibrationDataset):
-    dataset = "vietnamese"
-    dataset_field = "output"
-    dataset_config = {
-        "path": "nRuaif/Vietnamese_x_Alpaca",
-        "split": "train"
-    }
-    dataset_name = "Alpaca Vietnamese"
-
-
-class ChineseDataset(CalibrationDataset):
-    dataset = "chinese"
-    dataset_config = {
-        "path": "TigerResearch/tigerbot-alpaca-zh-0.5m",
-        "split": "train"
-    }
-    dataset_name = "Tiger Alpaca ZH"
-    transform_fields = [ "instruction", "input", "output" ]
-    transform_join = "{field1} {field2} {field3}"
-
-
-class LatinEnglishDataset(CalibrationDataset):
-    dataset = "latin-english"
-    dataset_config = {
-        "path": "grosenthal/latin_english_parallel",
-        "split": "train"
-    }
-    dataset_name = "Latin English Parallel"
-    transform_fields = [ "la", "en" ]
-    transform_join = "{field1}\n{field2}"
-
-
-class PolishDataset(CalibrationDataset):
-    dataset = "polish"
     dataset_field = "content"
     dataset_config = {
-        "path": "
-        "split": "train"
-    }
-    dataset_name = "Polish News"
-
-
-class JapaneseDataset(CalibrationDataset):
-    dataset = "japanese"
-    dataset_field = "output"
-    dataset_config = {
-        "path": "fujiki/japanese_alpaca_data",
-        "split": "train"
-    }
-    dataset_name = "Alpaca Japanese"
-
-
-class SpanishDataset(CalibrationDataset):
-    dataset = "spanish"
-    dataset_field = "output"
-    dataset_config = {
-        "path": "bertin-project/alpaca-spanish",
-        "split": "train"
-    }
-    dataset_name = "Alpaca Spanish"
-
-
-class GermanDataset(CalibrationDataset):
-    dataset = "german"
-    dataset_config = {
-        "path": "deepset/germanquad",
-        "split": "train"
-    }
-    dataset_name = "German Quad"
-
-    def process_samples(self) -> List[str]:
-        def transform_samples(sample):
-            split_context = sample["context"].split("===")
-            if len(split_context) >= 3:
-                trans_context = split_context[2]
-            else:
-                trans_context = sample["context"]
-            return {"output": trans_context.strip()}
-
-        return self.data.map(transform_samples)["output"]
-
-
-class FrenchDataset(CalibrationDataset):
-    dataset = "french"
-    dataset_field = "text"
-    dataset_config = {
-        "path": "Kant1/French_Wikipedia_articles",
-        "data_files": { "wiki_00.txt" },
+        "path": "bigcode/the-stack",
         "split": "train"
     }
-    dataset_name = "
+    dataset_name = "The Stack"


 def validate_dataset(dataset_name: str, **kwargs):
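Note on the new loading code in the second hunk: `load_dataset(..., streaming=True)` returns an `IterableDataset`, which has no `len()`, so the untouched `limit = limit and min(limit, len(data)) or len(data)` line will not work unchanged; `IterableDataset.take()` also expects an integer count rather than a `range` object, and some `datasets` releases reject `num_proc` when combined with `streaming=True`. A minimal sketch of the streaming pattern, using an illustrative dataset path and limit that are not taken from this repo:

# Sketch only: assumes datasets>=2.x; dataset path and limit are placeholders.
from datasets import load_dataset

limit = 128  # hypothetical number of calibration samples

# streaming=True returns an IterableDataset that yields records lazily (no len()).
data = load_dataset("wikitext", name="wikitext-103-raw-v1", split="train", streaming=True)

# On a streaming dataset, shuffle() draws from a finite buffer and take() wants an int.
subset = data.shuffle(seed=42, buffer_size=10_000).take(limit)

samples = [row["text"] for row in subset]
print(f"collected {len(samples)} samples")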
src/train_workflow.py
CHANGED
@@ -31,7 +31,7 @@ DEFAULT_TRAINING_ARGS = \
     --num_train_epochs 1
     --per_device_train_batch_size 64
     --per_device_eval_batch_size 64
-    --gradient_accumulation_steps
+    --gradient_accumulation_steps 8
     --evaluation_strategy no
     --save_strategy no
     --weight_decay 0.0
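With `--per_device_train_batch_size 64` and `--gradient_accumulation_steps 8`, each optimizer step now accumulates 64 * 8 = 512 samples per device (times the number of devices when training data-parallel). A hedged sketch of how these flags map onto `transformers.TrainingArguments`, with a placeholder `output_dir` that is not part of the repo:

# Sketch only: mirrors the CLI flags above; output_dir is a placeholder.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./out",              # placeholder, not from the repo
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=8,   # effective per-device batch: 64 * 8 = 512
    evaluation_strategy="no",
    save_strategy="no",
    weight_decay=0.0,
)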