Commit
·
fd0708d
1
Parent(s):
ca6376e
Upload refactored_data_preprocessing_notebook (1).ipynb
Browse files
refactored_data_preprocessing_notebook (1).ipynb
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "c743a143-9d3e-4199-a984-5ad51014c168",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# UniProt Data Preprocessing\n",
|
9 |
+
"\n",
|
10 |
+
"This notebook is for preprocessing a UniProt TSV file with columns (Protein families, Binding site, Active site, Sequence). If the family annotation is missing, the code will filter out this sequence. Missing binding sites are not acceptable for this notebook, so make sure all of your suequences have binding site annotations. If the Active site annotation is missing, the sequence will be included without issue. Missing sequences are not handled by this notebook. "
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 1,
|
16 |
+
"id": "b595fc6e-ef53-47fa-a517-ea3bd1066a1c",
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"import pandas as pd\n",
|
21 |
+
"import numpy as np\n",
|
22 |
+
"\n",
|
23 |
+
"# Load the dataset\n",
|
24 |
+
"file_path = 'uniprotkb_family_AND_ft_binding_AND_pro_2023_09_19.tsv'\n",
|
25 |
+
"data = pd.read_csv(file_path, sep='\\t')\n",
|
26 |
+
"\n",
|
27 |
+
"# Display the first few rows of the dataframe\n",
|
28 |
+
"data.head()"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 2,
|
34 |
+
"id": "45666b3a-45e7-41e5-834b-bf7a0ca8b3de",
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"data.shape[0]"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": 4,
|
44 |
+
"id": "7bff9929-9423-4d58-9c5f-6c1758e50da7",
|
45 |
+
"metadata": {},
|
46 |
+
"outputs": [],
|
47 |
+
"source": [
|
48 |
+
"import pandas as pd\n",
|
49 |
+
"\n",
|
50 |
+
"# Load the dataset\n",
|
51 |
+
"file_path = 'uniprotkb_family_AND_ft_binding_AND_pro_2023_09_19.tsv'\n",
|
52 |
+
"data = pd.read_csv(file_path, sep='\\t')\n",
|
53 |
+
"\n",
|
54 |
+
"# Filter out rows with NaN values in the 'Protein families' column\n",
|
55 |
+
"data = data[pd.notna(data['Protein families'])]\n",
|
56 |
+
"\n",
|
57 |
+
"# Display the first few rows of the modified dataframe\n",
|
58 |
+
"data.head()"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"execution_count": 5,
|
64 |
+
"id": "f87e25da-12b2-4002-959b-f3be7c5b4928",
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [],
|
67 |
+
"source": [
|
68 |
+
"data.shape[0]"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": 6,
|
74 |
+
"id": "062d44bd-2aa4-40e2-9662-9d7cfacabc80",
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"# Group the data by 'Protein families' and get the size of each group\n",
|
79 |
+
"family_sizes = data.groupby('Protein families').size()\n",
|
80 |
+
"\n",
|
81 |
+
"# Create a new column with the size of each family\n",
|
82 |
+
"data['Family size'] = data['Protein families'].map(family_sizes)\n",
|
83 |
+
"\n",
|
84 |
+
"# Sort the data by 'Family size' in descending order and then by 'Protein families'\n",
|
85 |
+
"data_sorted = data.sort_values(by=['Family size', 'Protein families'], ascending=[False, True])\n",
|
86 |
+
"\n",
|
87 |
+
"# Drop the 'Family size' column as it is no longer needed\n",
|
88 |
+
"data_sorted.drop(columns='Family size', inplace=True)\n",
|
89 |
+
"\n",
|
90 |
+
"# Define a function to extract the location from the binding and active site columns\n",
|
91 |
+
"def extract_location(site_info):\n",
|
92 |
+
" if pd.isnull(site_info):\n",
|
93 |
+
" return None\n",
|
94 |
+
" locations = []\n",
|
95 |
+
" for info in site_info.split(';'):\n",
|
96 |
+
" if 'BINDING' in info or 'ACT_SITE' in info:\n",
|
97 |
+
" locations.append(info.split()[1])\n",
|
98 |
+
" return '; '.join(locations)\n",
|
99 |
+
"\n",
|
100 |
+
"# Apply the function to the 'Binding site' and 'Active site' columns to extract the locations\n",
|
101 |
+
"data_sorted['Binding site'] = data_sorted['Binding site'].apply(extract_location)\n",
|
102 |
+
"data_sorted['Active site'] = data_sorted['Active site'].apply(extract_location)\n",
|
103 |
+
"\n",
|
104 |
+
"# Display the first few rows of the modified dataframe\n",
|
105 |
+
"data_sorted.head()"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": 7,
|
111 |
+
"id": "70e04892-19a8-4e55-8b0d-bd5d9108b8d2",
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [],
|
114 |
+
"source": [
|
115 |
+
"# Create a new column that combines the 'Binding site' and 'Active site' columns\n",
|
116 |
+
"data_sorted['Binding-Active site'] = data_sorted['Binding site'].astype(str) + '; ' + data_sorted['Active site'].astype(str)\n",
|
117 |
+
"\n",
|
118 |
+
"# Replace 'nan' values with None\n",
|
119 |
+
"data_sorted['Binding-Active site'] = data_sorted['Binding-Active site'].replace('nan; nan', None)\n",
|
120 |
+
"\n",
|
121 |
+
"# Display the first few rows of the updated dataframe\n",
|
122 |
+
"data_sorted.head()"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": 8,
|
128 |
+
"id": "c7022fff-b445-47df-afc8-f5a2e3659be7",
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [],
|
131 |
+
"source": [
|
132 |
+
"# Find entries in the \"Binding-Active site\" column containing '<' or '>'\n",
|
133 |
+
"entries_with_angle_brackets = data_sorted['Binding-Active site'].str.contains('<|>', na=False)\n",
|
134 |
+
"\n",
|
135 |
+
"# Get the number of such entries\n",
|
136 |
+
"num_entries_with_angle_brackets = entries_with_angle_brackets.sum()\n",
|
137 |
+
"\n",
|
138 |
+
"# Display the number of entries containing '<' or '>'\n",
|
139 |
+
"print(f\"Number of entries with angle brackets: {num_entries_with_angle_brackets}\")\n",
|
140 |
+
"\n",
|
141 |
+
"# Remove all rows where the \"Binding-Active site\" column contains '<' or '>'\n",
|
142 |
+
"data_filtered = data_sorted[~entries_with_angle_brackets]\n",
|
143 |
+
"\n",
|
144 |
+
"# Get the number of remaining rows\n",
|
145 |
+
"num_remaining_rows = data_filtered.shape[0]\n",
|
146 |
+
"\n",
|
147 |
+
"# Display the number of remaining rows\n",
|
148 |
+
"print(f\"Number of remaining rows: {num_remaining_rows}\")\n",
|
149 |
+
"\n",
|
150 |
+
"# Get the number of distinct protein families\n",
|
151 |
+
"num_distinct_families = data_filtered['Protein families'].nunique()\n",
|
152 |
+
"\n",
|
153 |
+
"# Display the number of distinct protein families\n",
|
154 |
+
"# Display the number of distinct protein families\n",
|
155 |
+
"print(f\"Number of distinct protein families: {num_distinct_families}\")\n",
|
156 |
+
"\n",
|
157 |
+
"# Define the target number of rows for the test set (approximately 20% of the data)\n",
|
158 |
+
"target_test_rows = int(0.20 * num_remaining_rows)\n",
|
159 |
+
"\n",
|
160 |
+
"# Get unique protein families\n",
|
161 |
+
"unique_families = data_filtered['Protein families'].unique()\n",
|
162 |
+
"\n",
|
163 |
+
"# Shuffle the unique families to randomize the selection\n",
|
164 |
+
"np.random.shuffle(unique_families)\n",
|
165 |
+
"\n",
|
166 |
+
"# Initialize variables to keep track of the selected rows for the test and train sets\n",
|
167 |
+
"test_rows = []\n",
|
168 |
+
"current_test_rows = 0\n",
|
169 |
+
"\n",
|
170 |
+
"# Loop through the shuffled families and add rows to the test set until we reach the target number of rows\n",
|
171 |
+
"for family in unique_families:\n",
|
172 |
+
" family_rows = data_filtered[data_filtered['Protein families'] == family].index.tolist()\n",
|
173 |
+
" if current_test_rows + len(family_rows) < target_test_rows:\n",
|
174 |
+
" test_rows.extend(family_rows)\n",
|
175 |
+
" current_test_rows += len(family_rows)\n",
|
176 |
+
" else:\n",
|
177 |
+
" # If adding the current family exceeds the target, we add it anyway and break the loop\n",
|
178 |
+
" test_rows.extend(family_rows)\n",
|
179 |
+
" break\n",
|
180 |
+
"\n",
|
181 |
+
"# Get the indices of the rows for the train set (all rows not in the test set)\n",
|
182 |
+
"train_rows = [i for i in data_filtered.index if i not in test_rows]\n",
|
183 |
+
"\n",
|
184 |
+
"# Create the test and train datasets\n",
|
185 |
+
"test_df = data_filtered.loc[test_rows]\n",
|
186 |
+
"train_df = data_filtered.loc[train_rows]\n",
|
187 |
+
"\n",
|
188 |
+
"test_df.shape[0], train_df.shape[0]"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"execution_count": 9,
|
194 |
+
"id": "6a5e747a-5af3-4eec-ba83-7d520d753e1f",
|
195 |
+
"metadata": {},
|
196 |
+
"outputs": [],
|
197 |
+
"source": [
|
198 |
+
"# Print the first few rows of each dataset to understand their structure\n",
|
199 |
+
"test_df.head()"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"cell_type": "code",
|
204 |
+
"execution_count": 10,
|
205 |
+
"id": "796c592c-4e4f-4403-b43d-843b9972170f",
|
206 |
+
"metadata": {},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"train_df.head()"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": 11,
|
215 |
+
"id": "ae5884c1-376c-4f53-b638-f63d7e4ea5ad",
|
216 |
+
"metadata": {},
|
217 |
+
"outputs": [],
|
218 |
+
"source": [
|
219 |
+
"# Find rows where the \"Binding-Active site\" column contains the character \"?\", treating \"?\" as a literal character\n",
|
220 |
+
"test_rows_with_question_mark = test_df[test_df['Binding-Active site'].str.contains('\\?', na=False, regex=True)]\n",
|
221 |
+
"train_rows_with_question_mark = train_df[train_df['Binding-Active site'].str.contains('\\?', na=False, regex=True)]\n",
|
222 |
+
"\n",
|
223 |
+
"# Get the number of such rows in both datasets\n",
|
224 |
+
"num_test_rows_with_question_mark = len(test_rows_with_question_mark)\n",
|
225 |
+
"num_train_rows_with_question_mark = len(train_rows_with_question_mark)\n",
|
226 |
+
"\n",
|
227 |
+
"print(f\"Number of test rows with question mark: {num_test_rows_with_question_mark}\")\n",
|
228 |
+
"print(f\"Number of train rows with question mark: {num_train_rows_with_question_mark}\")\n",
|
229 |
+
"\n",
|
230 |
+
"# Delete the rows containing '?' in the \"Binding-Active site\" column\n",
|
231 |
+
"test_df = test_df.drop(test_rows_with_question_mark.index)\n",
|
232 |
+
"train_df = train_df.drop(train_rows_with_question_mark.index)\n",
|
233 |
+
"\n",
|
234 |
+
"# Check the number of remaining rows in both datasets\n",
|
235 |
+
"remaining_test_rows = test_df.shape[0]\n",
|
236 |
+
"remaining_train_rows = train_df.shape[0]\n",
|
237 |
+
"\n",
|
238 |
+
"print(f\"Number of remaining test rows: {remaining_test_rows}\")\n",
|
239 |
+
"print(f\"Number of remaining train rows: {remaining_train_rows}\")\n",
|
240 |
+
"\n",
|
241 |
+
"import re\n",
|
242 |
+
"\n",
|
243 |
+
"def expand_ranges(s):\n",
|
244 |
+
" \"\"\"Expand ranges in a string.\"\"\"\n",
|
245 |
+
" return re.sub(r'(\\d+)\\.\\.(\\d+)', lambda m: ', '.join(map(str, range(int(m.group(1)), int(m.group(2))+1))), str(s))\n",
|
246 |
+
"\n",
|
247 |
+
"# Apply the function to expand ranges in the \"Binding-Active site\" column in both datasets\n",
|
248 |
+
"test_df['Binding-Active site'] = test_df['Binding-Active site'].apply(expand_ranges)\n",
|
249 |
+
"train_df['Binding-Active site'] = train_df['Binding-Active site'].apply(expand_ranges)\n",
|
250 |
+
"\n",
|
251 |
+
"# Display the first few rows of each dataset to verify the changes\n",
|
252 |
+
"# print(test_df.head())\n",
|
253 |
+
"# print(train_df.head())"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"cell_type": "code",
|
258 |
+
"execution_count": 12,
|
259 |
+
"id": "2d76022f-b8d6-4a9c-81a4-7eae34af4732",
|
260 |
+
"metadata": {},
|
261 |
+
"outputs": [],
|
262 |
+
"source": [
|
263 |
+
"def convert_to_binary_list(binding_active_str, sequence_len):\n",
|
264 |
+
" \"\"\"Convert a Binding-Active site string to a binary list based on the sequence length.\"\"\"\n",
|
265 |
+
" # Step 2: Create a list of 0s with length equal to the sequence length\n",
|
266 |
+
" binary_list = [0] * sequence_len\n",
|
267 |
+
" \n",
|
268 |
+
" # Step 3: Retrieve the indices and set the corresponding positions to 1\n",
|
269 |
+
" if pd.notna(binding_active_str):\n",
|
270 |
+
" # Get the indices from the binding-active site string\n",
|
271 |
+
" indices = [int(x) - 1 for segment in binding_active_str.split(';') for x in segment.split(',') if x.strip().isdigit()]\n",
|
272 |
+
" for idx in indices:\n",
|
273 |
+
" # Ensure the index is within the valid range\n",
|
274 |
+
" if 0 <= idx < sequence_len:\n",
|
275 |
+
" binary_list[idx] = 1\n",
|
276 |
+
" \n",
|
277 |
+
" # Step 4: Return the binary list\n",
|
278 |
+
" return binary_list\n",
|
279 |
+
"\n",
|
280 |
+
"# Apply the function to both datasets\n",
|
281 |
+
"test_df['Binding-Active site'] = test_df.apply(lambda row: convert_to_binary_list(row['Binding-Active site'], len(row['Sequence'])), axis=1)\n",
|
282 |
+
"train_df['Binding-Active site'] = train_df.apply(lambda row: convert_to_binary_list(row['Binding-Active site'], len(row['Sequence'])), axis=1)\n"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "code",
|
287 |
+
"execution_count": 13,
|
288 |
+
"id": "0c01fca9-a558-4919-bd97-73b41acd4fc9",
|
289 |
+
"metadata": {},
|
290 |
+
"outputs": [],
|
291 |
+
"source": [
|
292 |
+
"test_df.head()"
|
293 |
+
]
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"cell_type": "code",
|
297 |
+
"execution_count": 14,
|
298 |
+
"id": "f686b656-313e-4ac3-a4ac-59fa87c6cbc9",
|
299 |
+
"metadata": {},
|
300 |
+
"outputs": [],
|
301 |
+
"source": [
|
302 |
+
"train_df.head()"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "code",
|
307 |
+
"execution_count": 15,
|
308 |
+
"id": "b4d27236-49ef-4244-81d2-4c5f120a97d5",
|
309 |
+
"metadata": {},
|
310 |
+
"outputs": [],
|
311 |
+
"source": [
|
312 |
+
"import pickle\n",
|
313 |
+
"import random\n",
|
314 |
+
"\n",
|
315 |
+
"def split_into_chunks(sequences, labels):\n",
|
316 |
+
" \"\"\"Split sequences and labels into chunks of size 1000 or less.\"\"\"\n",
|
317 |
+
" chunk_size = 1000\n",
|
318 |
+
" new_sequences = []\n",
|
319 |
+
" new_labels = []\n",
|
320 |
+
" \n",
|
321 |
+
" for seq, lbl in zip(sequences, labels):\n",
|
322 |
+
" if len(seq) > chunk_size:\n",
|
323 |
+
" # Split the sequence and labels into chunks of size 1000 or less\n",
|
324 |
+
" for i in range(0, len(seq), chunk_size):\n",
|
325 |
+
" new_sequences.append(seq[i:i+chunk_size])\n",
|
326 |
+
" new_labels.append(lbl[i:i+chunk_size])\n",
|
327 |
+
" else:\n",
|
328 |
+
" new_sequences.append(seq)\n",
|
329 |
+
" new_labels.append(lbl)\n",
|
330 |
+
" \n",
|
331 |
+
" return new_sequences, new_labels\n",
|
332 |
+
"\n",
|
333 |
+
"# Extract the necessary columns to create lists of sequences and labels\n",
|
334 |
+
"test_sequences_by_family = test_df['Sequence'].tolist()\n",
|
335 |
+
"test_labels_by_family = test_df['Binding-Active site'].tolist()\n",
|
336 |
+
"train_sequences_by_family = train_df['Sequence'].tolist()\n",
|
337 |
+
"train_labels_by_family = train_df['Binding-Active site'].tolist()\n",
|
338 |
+
"\n",
|
339 |
+
"# Get the number of samples in each dataset\n",
|
340 |
+
"num_test_samples = len(test_sequences_by_family)\n",
|
341 |
+
"num_train_samples = len(train_sequences_by_family)\n",
|
342 |
+
"\n",
|
343 |
+
"# Generate random indices representing 50% of each dataset\n",
|
344 |
+
"random_test_indices = random.sample(range(num_test_samples), num_test_samples // 26.66)\n",
|
345 |
+
"random_train_indices = random.sample(range(num_train_samples), num_train_samples // 26.66)\n",
|
346 |
+
"\n",
|
347 |
+
"# Create smaller datasets using the random indices\n",
|
348 |
+
"test_sequences_small = [test_sequences_by_family[i] for i in random_test_indices]\n",
|
349 |
+
"test_labels_small = [test_labels_by_family[i] for i in random_test_indices]\n",
|
350 |
+
"train_sequences_small = [train_sequences_by_family[i] for i in random_train_indices]\n",
|
351 |
+
"train_labels_small = [train_labels_by_family[i] for i in random_train_indices]\n",
|
352 |
+
"\n",
|
353 |
+
"# Apply the function to create new datasets with chunks of size 1000 or less\n",
|
354 |
+
"test_sequences_chunked, test_labels_chunked = split_into_chunks(test_sequences_small, test_labels_small)\n",
|
355 |
+
"train_sequences_chunked, train_labels_chunked = split_into_chunks(train_sequences_small, train_labels_small)\n",
|
356 |
+
"\n",
|
357 |
+
"# Paths to save the new chunked pickle files\n",
|
358 |
+
"test_labels_chunked_path = '600K_data/test_labels_chunked_by_family.pkl'\n",
|
359 |
+
"test_sequences_chunked_path = '600K_data/test_sequences_chunked_by_family.pkl'\n",
|
360 |
+
"train_labels_chunked_path = '600K_data/train_labels_chunked_by_family.pkl'\n",
|
361 |
+
"train_sequences_chunked_path = '600K_data/train_sequences_chunked_by_family.pkl'\n",
|
362 |
+
"\n",
|
363 |
+
"# Save the chunked datasets as new pickle files\n",
|
364 |
+
"with open(test_labels_chunked_path, 'wb') as file:\n",
|
365 |
+
" pickle.dump(test_labels_chunked, file)\n",
|
366 |
+
"with open(test_sequences_chunked_path, 'wb') as file:\n",
|
367 |
+
" pickle.dump(test_sequences_chunked, file)\n",
|
368 |
+
"with open(train_labels_chunked_path, 'wb') as file:\n",
|
369 |
+
" pickle.dump(train_labels_chunked, file)\n",
|
370 |
+
"with open(train_sequences_chunked_path, 'wb') as file:\n",
|
371 |
+
" pickle.dump(train_sequences_chunked, file)\n",
|
372 |
+
"\n",
|
373 |
+
"test_labels_chunked_path, test_sequences_chunked_path, train_labels_chunked_path, train_sequences_chunked_path\n"
|
374 |
+
]
|
375 |
+
},
|
376 |
+
{
|
377 |
+
"cell_type": "code",
|
378 |
+
"execution_count": 16,
|
379 |
+
"id": "bea3056d-c72c-420f-9036-c9f5069312d6",
|
380 |
+
"metadata": {},
|
381 |
+
"outputs": [],
|
382 |
+
"source": [
|
383 |
+
"# Load each pickle file and get the number of entries in each\n",
|
384 |
+
"with open(test_labels_chunked_path, 'rb') as file:\n",
|
385 |
+
" test_labels_chunked = pickle.load(file)\n",
|
386 |
+
" num_test_labels_chunked = len(test_labels_chunked)\n",
|
387 |
+
"\n",
|
388 |
+
"with open(test_sequences_chunked_path, 'rb') as file:\n",
|
389 |
+
" test_sequences_chunked = pickle.load(file)\n",
|
390 |
+
" num_test_sequences_chunked = len(test_sequences_chunked)\n",
|
391 |
+
"\n",
|
392 |
+
"with open(train_labels_chunked_path, 'rb') as file:\n",
|
393 |
+
" train_labels_chunked = pickle.load(file)\n",
|
394 |
+
" num_train_labels_chunked = len(train_labels_chunked)\n",
|
395 |
+
"\n",
|
396 |
+
"with open(train_sequences_chunked_path, 'rb') as file:\n",
|
397 |
+
" train_sequences_chunked = pickle.load(file)\n",
|
398 |
+
" num_train_sequences_chunked = len(train_sequences_chunked)\n",
|
399 |
+
"\n",
|
400 |
+
"num_test_labels_chunked, num_test_sequences_chunked, num_train_labels_chunked, num_train_sequences_chunked\n"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "code",
|
405 |
+
"execution_count": null,
|
406 |
+
"id": "eb10699a-0441-48be-bafd-c3a1a4d113af",
|
407 |
+
"metadata": {},
|
408 |
+
"outputs": [],
|
409 |
+
"source": []
|
410 |
+
}
|
411 |
+
],
|
412 |
+
"metadata": {
|
413 |
+
"kernelspec": {
|
414 |
+
"display_name": "esm2_binding_py38b",
|
415 |
+
"language": "python",
|
416 |
+
"name": "esm2_binding_py38b"
|
417 |
+
},
|
418 |
+
"language_info": {
|
419 |
+
"codemirror_mode": {
|
420 |
+
"name": "ipython",
|
421 |
+
"version": 3
|
422 |
+
},
|
423 |
+
"file_extension": ".py",
|
424 |
+
"mimetype": "text/x-python",
|
425 |
+
"name": "python",
|
426 |
+
"nbconvert_exporter": "python",
|
427 |
+
"pygments_lexer": "ipython3",
|
428 |
+
"version": "3.8.17"
|
429 |
+
}
|
430 |
+
},
|
431 |
+
"nbformat": 4,
|
432 |
+
"nbformat_minor": 5
|
433 |
+
}
|