more robust column names

#1
by davanstrien HF staff - opened
Files changed (1) hide show
  1. main.py +53 -21
main.py CHANGED
@@ -75,16 +75,16 @@ async def get_first_config_and_split_name(hub_id: str):
75
  return data["splits"][0]["config"], data["splits"][0]["split"]
76
  except Exception as e:
77
  logger.error(f"Failed to get splits for {hub_id}: {e}")
78
- return None
79
 
80
 
81
  async def get_dataset_info(hub_id: str, config: str | None = None):
82
  if config is None:
83
- config = get_first_config_and_split_name(hub_id)
84
- if config is None:
85
  return None
86
  else:
87
- config = config[0]
88
  resp = await async_client.get(
89
  f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
90
  )
@@ -229,18 +229,10 @@ def predict_rows(
229
  return default_data
230
 
231
 
232
- # @app.get("/", response_class=HTMLResponse)
233
- # async def read_index():
234
- # html_content = Path("index.html").read_text()
235
- # return HTMLResponse(content=html_content)
236
-
237
-
238
  # Redirect the bare root URL to the auto-generated interactive API docs.
  @app.get("/", include_in_schema=False)
239
  def root():
240
  return RedirectResponse(url="/docs")
241
 
242
- # item_id: Annotated[int, Path(title="The ID of the item to get", ge=1)], q: str
243
-
244
 
245
  @app.get("/predict_dataset_language/{hub_id:path}")
246
  @cache(ttl=timedelta(minutes=10))
@@ -257,31 +249,66 @@ async def predict_language(
257
  is_valid = datasets_server_valid_rows(hub_id)
258
  if not is_valid:
259
  logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
 
 
260
  if not config and not split:
261
- config, split = await get_first_config_and_split_name(hub_id)
262
- if not config:
263
- config, _ = await get_first_config_and_split_name(hub_id)
264
- if not split:
265
- _, split = await get_first_config_and_split_name(hub_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  info = await get_dataset_info(hub_id, config)
267
  if info is None:
268
  logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
269
  return None
 
270
  if dataset_info := info.get("dataset_info"):
271
  total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
272
  features = dataset_info.get("features")
 
 
273
  column_names = set(features.keys())
274
  logger.info(f"Column names: {column_names}")
275
- if not set(column_names).intersection(TARGET_COLUMN_NAMES):
 
 
 
 
 
 
 
 
276
  logger.error(
277
- f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
278
  )
279
  return None
 
 
 
280
  for column in TARGET_COLUMN_NAMES:
281
- if column in column_names:
282
- target_column = column
 
283
  logger.info(f"Using column {target_column} for language detection")
284
  break
 
 
 
 
 
285
  random_rows = await get_random_rows(
286
  hub_id,
287
  total_rows_for_split,
@@ -290,6 +317,7 @@ async def predict_language(
290
  config,
291
  split,
292
  )
 
293
  logger.info(f"Predicting language for {len(random_rows)} rows")
294
  predictions = predict_rows(
295
  random_rows,
@@ -300,3 +328,7 @@ async def predict_language(
300
  predictions["config"] = config
301
  predictions["split"] = split
302
  return predictions
 
 
 
 
 
75
  return data["splits"][0]["config"], data["splits"][0]["split"]
76
  except Exception as e:
77
  logger.error(f"Failed to get splits for {hub_id}: {e}")
78
+ return (None, None) # Return a tuple of None values
79
 
80
 
81
  async def get_dataset_info(hub_id: str, config: str | None = None):
82
  if config is None:
83
+ config_tuple, _ = await get_first_config_and_split_name(hub_id)
84
+ if config_tuple is None:
85
  return None
86
  else:
87
+ config = config_tuple
88
  resp = await async_client.get(
89
  f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
90
  )
 
229
  return default_data
230
 
231
 
 
 
 
 
 
 
232
  # Redirect the bare root URL to the auto-generated interactive API docs.
  @app.get("/", include_in_schema=False)
233
  def root():
234
  return RedirectResponse(url="/docs")
235
 
 
 
236
 
237
  @app.get("/predict_dataset_language/{hub_id:path}")
238
  @cache(ttl=timedelta(minutes=10))
 
249
  is_valid = datasets_server_valid_rows(hub_id)
250
  if not is_valid:
251
  logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
252
+ return None # Return early if dataset is not valid
253
+
254
  if not config and not split:
255
+ config_tuple, split_tuple = await get_first_config_and_split_name(hub_id)
256
+ if config_tuple is None:
257
+ logger.error(f"Could not retrieve configuration for dataset {hub_id}")
258
+ return None
259
+ config, split = config_tuple, split_tuple
260
+ elif not config:
261
+ config_tuple, _ = await get_first_config_and_split_name(hub_id)
262
+ if config_tuple is None:
263
+ logger.error(f"Could not retrieve configuration for dataset {hub_id}")
264
+ return None
265
+ config = config_tuple
266
+ elif not split:
267
+ _, split_tuple = await get_first_config_and_split_name(hub_id)
268
+ if split_tuple is None:
269
+ logger.error(f"Could not retrieve split for dataset {hub_id}")
270
+ return None
271
+ split = split_tuple
272
+
273
  info = await get_dataset_info(hub_id, config)
274
  if info is None:
275
  logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
276
  return None
277
+
278
  if dataset_info := info.get("dataset_info"):
279
  total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
280
  features = dataset_info.get("features")
281
+
282
+ # Get original column names
283
  column_names = set(features.keys())
284
  logger.info(f"Column names: {column_names}")
285
+
286
+ # Create a mapping of lowercase column names to their original casing
287
+ lowercase_to_original = {col.lower(): col for col in column_names}
288
+
289
+ # Check intersection with lowercase versions
290
+ lowercase_column_names = set(lowercase_to_original.keys())
291
+ lowercase_target_columns = {col.lower() for col in TARGET_COLUMN_NAMES}
292
+
293
+ if not lowercase_column_names.intersection(lowercase_target_columns):
294
  logger.error(
295
+ f"Dataset {hub_id} {column_names} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
296
  )
297
  return None
298
+
299
+ # Find target column with case-insensitive matching
300
+ target_column = None
301
  for column in TARGET_COLUMN_NAMES:
302
+ if column.lower() in lowercase_column_names:
303
+ # Use the original casing from the dataset
304
+ target_column = lowercase_to_original[column.lower()]
305
  logger.info(f"Using column {target_column} for language detection")
306
  break
307
+
308
+ if target_column is None:
309
+ logger.error(f"Could not find a suitable column for language detection")
310
+ return None
311
+
312
  random_rows = await get_random_rows(
313
  hub_id,
314
  total_rows_for_split,
 
317
  config,
318
  split,
319
  )
320
+
321
  logger.info(f"Predicting language for {len(random_rows)} rows")
322
  predictions = predict_rows(
323
  random_rows,
 
328
  predictions["config"] = config
329
  predictions["split"] = split
330
  return predictions
331
+
332
+ else:
333
+ logger.error(f"No dataset_info available for {hub_id}")
334
+ return None