dejanseo committed
Commit fffb0cd
1 parent: 4937b51

Upload 12 files

training/dataset_dict.json ADDED
@@ -0,0 +1 @@
+ {"splits": ["train", "test", "validation"]}
training/test/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c9afc5d689789eb1388168b261564e63112bb8f0a9d3f9c96a1cf590f73a9449
+ size 190736
training/test/dataset_info.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "builder_name": "google_wellformed_query",
+   "citation": "@misc{faruqui2018identifying,\n title={Identifying Well-formed Natural Language Questions},\n author={Manaal Faruqui and Dipanjan Das},\n year={2018},\n eprint={1808.09419},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n",
+   "config_name": "default",
+   "dataset_name": "google_wellformed_query",
+   "dataset_size": 1230988,
+   "description": "Google's query wellformedness dataset was created by crowdsourcing well-formedness annotations for 25,100 queries from the Paralex corpus. Every query was annotated by five raters each with 1/0 rating of whether or not the query is well-formed.\n",
+   "download_checksums": {
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/train.tsv": {
+       "num_bytes": 805818,
+       "checksum": null
+     },
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/test.tsv": {
+       "num_bytes": 178070,
+       "checksum": null
+     },
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/dev.tsv": {
+       "num_bytes": 173131,
+       "checksum": null
+     }
+   },
+   "download_size": 1157019,
+   "features": {
+     "rating": {
+       "dtype": "float32",
+       "_type": "Value"
+     },
+     "content": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "https://github.com/google-research-datasets/query-wellformedness",
+   "license": "",
+   "size_in_bytes": 2388007,
+   "splits": {
+     "train": {
+       "name": "train",
+       "num_bytes": 857383,
+       "num_examples": 17500,
+       "dataset_name": "google_wellformed_query"
+     },
+     "test": {
+       "name": "test",
+       "num_bytes": 189499,
+       "num_examples": 3850,
+       "dataset_name": "google_wellformed_query"
+     },
+     "validation": {
+       "name": "validation",
+       "num_bytes": 184106,
+       "num_examples": 3750,
+       "dataset_name": "google_wellformed_query"
+     }
+   },
+   "version": {
+     "version_str": "0.0.0",
+     "major": 0,
+     "minor": 0,
+     "patch": 0
+   }
+ }
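Note that `rating` is stored as a float32, not a binary label: per the description above, each query was scored 1/0 by five raters, so the stored value is presumably the mean rating, giving values in {0.0, 0.2, 0.4, 0.6, 0.8, 1.0}. train.py (below) binarizes it with a 0.5 threshold; a minimal sketch of that mapping, with illustrative values:

    def binarize(rating: float, threshold: float = 0.5) -> int:
        # 1 = well-formed (a majority of the five raters agreed), 0 = not
        return int(rating >= threshold)

    assert binarize(0.8) == 1
    assert binarize(0.4) == 0
    # An exact 0.5 cannot occur with five raters, so the threshold is unambiguous.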
training/test/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "007669a06fb24065",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": "test"
+ }
training/train.py ADDED
@@ -0,0 +1,134 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, Dataset
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification, AdamW, AlbertConfig
+ from datasets import Dataset as HFDataset
+ import pandas as pd
+ import os
+
+ # Ensure the /model/ directory exists
+ model_dir = 'model'
+ os.makedirs(model_dir, exist_ok=True)
+
+ # Load datasets from the Arrow files
+ train_dataset = HFDataset.from_file('train/data-00000-of-00001.arrow')
+ val_dataset = HFDataset.from_file('validation/data-00000-of-00001.arrow')
+ test_dataset = HFDataset.from_file('test/data-00000-of-00001.arrow')
+
+ # Convert datasets to pandas DataFrames
+ train_df = train_dataset.to_pandas()
+ val_df = val_dataset.to_pandas()
+ test_df = test_dataset.to_pandas()
+
+ # Remove question marks at the end of each query
+ train_df['content'] = train_df['content'].str.rstrip('?')
+ val_df['content'] = val_df['content'].str.rstrip('?')
+ test_df['content'] = test_df['content'].str.rstrip('?')
+
+ # Convert labels to integers (0 or 1)
+ train_df['rating'] = train_df['rating'].apply(lambda x: int(x >= 0.5))
+ val_df['rating'] = val_df['rating'].apply(lambda x: int(x >= 0.5))
+ test_df['rating'] = test_df['rating'].apply(lambda x: int(x >= 0.5))
+
+ # Initialize ALBERT tokenizer
+ tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+
+ # Custom Dataset class for PyTorch
+ class QueryDataset(Dataset):
+     def __init__(self, texts, labels, tokenizer, max_length=32):
+         self.texts = texts
+         self.labels = labels
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         text = str(self.texts[idx])
+         label = int(self.labels[idx])  # Ensure label is an integer
+         encoding = self.tokenizer.encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=self.max_length,
+             padding='max_length',  # Ensure consistent length
+             truncation=True,  # Truncate longer sequences
+             return_attention_mask=True,
+             return_tensors='pt'
+         )
+
+         return {
+             'input_ids': encoding['input_ids'].flatten(),
+             'attention_mask': encoding['attention_mask'].flatten(),
+             'label': torch.tensor(label, dtype=torch.long)
+         }
+
+ # Prepare datasets
+ train_dataset = QueryDataset(train_df['content'].values, train_df['rating'].values, tokenizer)
+ val_dataset = QueryDataset(val_df['content'].values, val_df['rating'].values, tokenizer)
+ test_dataset = QueryDataset(test_df['content'].values, test_df['rating'].values, tokenizer)
+
+ # DataLoaders
+ batch_size = 128
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+ val_loader = DataLoader(val_dataset, batch_size=batch_size)
+ test_loader = DataLoader(test_dataset, batch_size=batch_size)
+
+ # Load ALBERT model
+ model = AlbertForSequenceClassification.from_pretrained('albert-base-v2', num_labels=2)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device)
+
+ # Optimizer and loss function
+ optimizer = AdamW(model.parameters(), lr=1e-5)
+ criterion = nn.CrossEntropyLoss()
+
+ # Training loop
+ epochs = 4
+ for epoch in range(epochs):
+     model.train()
+     total_loss = 0
+     for batch in train_loader:
+         input_ids = batch['input_ids'].to(device)
+         attention_mask = batch['attention_mask'].to(device)
+         labels = batch['label'].to(device)
+
+         optimizer.zero_grad()
+         outputs = model(input_ids, attention_mask=attention_mask)
+         loss = criterion(outputs.logits, labels)
+         loss.backward()
+         optimizer.step()
+
+         total_loss += loss.item()
+
+     avg_loss = total_loss / len(train_loader)
+     print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}')
+
+     # Validation step at the end of each epoch
+     model.eval()
+     correct_predictions = 0
+     total_predictions = 0
+     with torch.no_grad():
+         for batch in val_loader:
+             input_ids = batch['input_ids'].to(device)
+             attention_mask = batch['attention_mask'].to(device)
+             labels = batch['label'].to(device)
+
+             outputs = model(input_ids, attention_mask=attention_mask)
+             preds = torch.argmax(outputs.logits, dim=1)
+             correct_predictions += (preds == labels).sum().item()
+             total_predictions += labels.size(0)
+
+     accuracy = correct_predictions / total_predictions
+     print(f'Validation Accuracy after Epoch {epoch + 1}: {accuracy:.4f}')
+
+ # Save the model, tokenizer, and config to /model/ directory
+ model.save_pretrained(model_dir, safe_serialization=True)  # Save model weights in safetensors format
+ tokenizer.save_pretrained(model_dir)
+
+ # Update config with correct classifier details
+ config = AlbertConfig.from_pretrained('albert-base-v2')
+ config.num_labels = 2  # Set the number of labels for classification
+ config.save_pretrained(model_dir)
+
+ print(f"Model and all required files saved to {model_dir}")
training/train/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f13a5a3621b3b4b062d3b6f1958162b1c4f6c9235cf3f22b3841f5e4a23704d2
+ size 861704
training/train/dataset_info.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "builder_name": "google_wellformed_query",
+   "citation": "@misc{faruqui2018identifying,\n title={Identifying Well-formed Natural Language Questions},\n author={Manaal Faruqui and Dipanjan Das},\n year={2018},\n eprint={1808.09419},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n",
+   "config_name": "default",
+   "dataset_name": "google_wellformed_query",
+   "dataset_size": 1230988,
+   "description": "Google's query wellformedness dataset was created by crowdsourcing well-formedness annotations for 25,100 queries from the Paralex corpus. Every query was annotated by five raters each with 1/0 rating of whether or not the query is well-formed.\n",
+   "download_checksums": {
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/train.tsv": {
+       "num_bytes": 805818,
+       "checksum": null
+     },
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/test.tsv": {
+       "num_bytes": 178070,
+       "checksum": null
+     },
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/dev.tsv": {
+       "num_bytes": 173131,
+       "checksum": null
+     }
+   },
+   "download_size": 1157019,
+   "features": {
+     "rating": {
+       "dtype": "float32",
+       "_type": "Value"
+     },
+     "content": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "https://github.com/google-research-datasets/query-wellformedness",
+   "license": "",
+   "size_in_bytes": 2388007,
+   "splits": {
+     "train": {
+       "name": "train",
+       "num_bytes": 857383,
+       "num_examples": 17500,
+       "dataset_name": "google_wellformed_query"
+     },
+     "test": {
+       "name": "test",
+       "num_bytes": 189499,
+       "num_examples": 3850,
+       "dataset_name": "google_wellformed_query"
+     },
+     "validation": {
+       "name": "validation",
+       "num_bytes": 184106,
+       "num_examples": 3750,
+       "dataset_name": "google_wellformed_query"
+     }
+   },
+   "version": {
+     "version_str": "0.0.0",
+     "major": 0,
+     "minor": 0,
+     "patch": 0
+   }
+ }
training/train/google_wellformed_query_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
training/train/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "5aec13d80b0bb552",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": "train"
+ }
training/validation/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e07235be93ddd52fa59494a2ec702b7fbdd405e96c6069ba10d8588a45071d6
+ size 185352
training/validation/dataset_info.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "builder_name": "google_wellformed_query",
+   "citation": "@misc{faruqui2018identifying,\n title={Identifying Well-formed Natural Language Questions},\n author={Manaal Faruqui and Dipanjan Das},\n year={2018},\n eprint={1808.09419},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n",
+   "config_name": "default",
+   "dataset_name": "google_wellformed_query",
+   "dataset_size": 1230988,
+   "description": "Google's query wellformedness dataset was created by crowdsourcing well-formedness annotations for 25,100 queries from the Paralex corpus. Every query was annotated by five raters each with 1/0 rating of whether or not the query is well-formed.\n",
+   "download_checksums": {
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/train.tsv": {
+       "num_bytes": 805818,
+       "checksum": null
+     },
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/test.tsv": {
+       "num_bytes": 178070,
+       "checksum": null
+     },
+     "https://raw.githubusercontent.com/google-research-datasets/query-wellformedness/master/dev.tsv": {
+       "num_bytes": 173131,
+       "checksum": null
+     }
+   },
+   "download_size": 1157019,
+   "features": {
+     "rating": {
+       "dtype": "float32",
+       "_type": "Value"
+     },
+     "content": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "https://github.com/google-research-datasets/query-wellformedness",
+   "license": "",
+   "size_in_bytes": 2388007,
+   "splits": {
+     "train": {
+       "name": "train",
+       "num_bytes": 857383,
+       "num_examples": 17500,
+       "dataset_name": "google_wellformed_query"
+     },
+     "test": {
+       "name": "test",
+       "num_bytes": 189499,
+       "num_examples": 3850,
+       "dataset_name": "google_wellformed_query"
+     },
+     "validation": {
+       "name": "validation",
+       "num_bytes": 184106,
+       "num_examples": 3750,
+       "dataset_name": "google_wellformed_query"
+     }
+   },
+   "version": {
+     "version_str": "0.0.0",
+     "major": 0,
+     "minor": 0,
+     "patch": 0
+   }
+ }
training/validation/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "cc2d3fe0964202f3",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": "validation"
+ }