Commit bc828c5
Parent(s): a2d40e3

Add language parsing functionality and update dependencies

Files changed:
- main.py +33 -10
- requirements.in +2 -1
- requirements.txt +2 -0
main.py CHANGED

@@ -15,6 +15,7 @@ from starlette.responses import RedirectResponse
 from cashews import cache
 from datetime import timedelta
 import logging
+from iso639 import Lang
 
 cache.setup("mem://")
 
@@ -93,6 +94,19 @@ async def get_dataset_info(hub_id: str, config: str | None = None):
     return resp.json()
 
 
+@cache(ttl=timedelta(minutes=5))
+async def fetch_rows(url: str) -> list[dict]:
+    response = await async_client.get(url)
+    if response.status_code == 200:
+        data = response.json()
+        return data.get("rows")
+    else:
+        print(f"Failed to fetch data: {response.status_code}")
+        print(url)
+        return []
+
+
+# Function to get random rows from the dataset
 async def get_random_rows(
     hub_id: str,
     total_length: int,
@@ -110,15 +124,8 @@ async def get_random_rows(
         offset = random.randint(0, total_length - rows_per_call)
         url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
         logger.info(f"Fetching {url}")
-
-
-        if response.status_code == 200:
-            data = response.json()
-            batch_rows = data.get("rows")
-            rows.extend(batch_rows)
-        else:
-            print(f"Failed to fetch data: {response.status_code}")
-            print(url)
+        batch_rows = await fetch_rows(url)
+        rows.extend(batch_rows)
         if len(rows) >= number_of_rows:
             break
     return [row.get("row") for row in rows]
@@ -181,6 +188,17 @@ def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
     return {k for k, v in counts_dict.items() if v >= threshold}
 
 
+def try_parse_language(lang: str) -> str | None:
+    try:
+        split = lang.split("_")
+        lang = split[0]
+        lang = Lang(lang)
+        return lang.pt1
+    except Exception as e:
+        logger.error(f"Failed to parse language {lang}: {e}")
+        return None
+
+
 def predict_rows(
     rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
 ):
@@ -196,8 +214,13 @@ def predict_rows(
         langues_counts, threshold_percent=language_threshold_percent
     )
     filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
+    raw_model_prediction_summary = dict(valmap(get_mean_score, filtered_dict))
+    parsed_langs = {
+        try_parse_language(k): v for k, v in raw_model_prediction_summary.items()
+    }
     default_data = {
-        "
+        "language_prediction_summary": parsed_langs,
+        "raw_model_prediction_summary": raw_model_prediction_summary,
         "hub_id": "hub_id",
         "config": "config",
     }
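For reference, a minimal sketch of what the new try_parse_language helper does, assuming the upstream classifier emits FastText-style labels such as "eng_Latn" (the label format is an assumption; the iso639.Lang API is what the diff itself imports):

from iso639 import Lang

label = "eng_Latn"          # hypothetical model label; only the part before "_" is used
code = label.split("_")[0]  # -> "eng" (an ISO 639-3 code)
print(Lang(code).pt1)       # -> "en" (pt1 is the two-letter ISO 639-1 code)

One caveat visible in the diff: try_parse_language rebinds lang before the except block runs, so the logged error reports the partially parsed value rather than the original input. Note also that fetch_rows is wrapped in @cache(ttl=timedelta(minutes=5)) against the mem:// backend configured at startup, so repeated fetches of the same URL within five minutes are served from memory instead of hitting the datasets server again.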
requirements.in CHANGED

@@ -8,4 +8,5 @@ huggingface_hub
 python-dotenv
 rich
 toolz
-uvicorn[standard]
+uvicorn[standard]
+iso639-lang
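(iso639-lang is the PyPI distribution that provides the iso639 module, and thus the Lang class, imported at the top of main.py.)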
requirements.txt CHANGED

@@ -51,6 +51,8 @@ idna==3.6
     # anyio
     # httpx
     # requests
+iso639-lang==2.2.2
+    # via -r requirements.in
 markdown-it-py==3.0.0
     # via rich
 mdurl==0.1.2
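The "# via -r requirements.in" annotations indicate that requirements.txt is compiled from requirements.in, pip-compile style; after adding iso639-lang to requirements.in, the pinned entry above would be regenerated by running pip-compile requirements.in.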