File size: 2,426 Bytes
dfd4a53
662ed4b
611a3ed
 
d0f55c6
523fad9
2f4d877
d0f55c6
523fad9
 
d0f55c6
719c272
 
 
 
dfd4a53
 
 
 
 
719c272
 
dfd4a53
 
 
719c272
 
 
 
dfd4a53
 
 
 
 
 
 
 
 
 
 
719c272
 
7e32ac7
 
 
 
 
 
d0f55c6
 
fae0e19
d0f55c6
 
719c272
 
d0f55c6
 
 
fae0e19
2f4d877
 
719c272
 
662ed4b
 
2f4d877
 
d0f55c6
523fad9
 
 
 
 
 
 
 
719c272
 
523fad9
 
 
 
 
 
 
 
719c272
 
523fad9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import asyncio
import io
import json

import httpx
from huggingface_hub import HfFileSystem, ModelCard, hf_hub_url
from huggingface_hub.utils import build_hf_headers

import src.constants as constants


class Client:
    """Thin async HTTP GET helper built on a shared ``httpx.AsyncClient``.

    ``get`` is best-effort: it returns the response on success and ``None``
    when the request ultimately fails, so callers test for ``None`` instead
    of handling exceptions.
    """

    def __init__(self):
        self.client = httpx.AsyncClient(follow_redirects=True)

    async def _get(self, url, headers=None, params=None):
        """Send a single GET request.

        Returns the response; propagates whatever the underlying client or
        ``raise_for_status()`` raises.
        """
        r = await self.client.get(url, headers=headers, params=params)
        r.raise_for_status()
        return r

    async def get(self, url, headers=None, params=None):
        """GET ``url``, retrying on read timeouts.

        Args:
            url: Target URL.
            headers: Optional request headers.
            params: Optional query parameters.

        Returns:
            The response, or ``None`` if the request failed with an
            ``httpx.HTTPError`` or all retries timed out.
        """
        try:
            return await self._get(url, headers=headers, params=params)
        except httpx.ReadTimeout:
            # Fall through to the retry path below. The retry is deliberately
            # run outside this handler so that its own failures are caught too.
            pass
        except httpx.HTTPError:
            return None
        try:
            return await self.retry(self._get, url, headers=headers, params=params)
        except httpx.HTTPError:
            # A retried attempt can fail with a non-timeout error (for example
            # a bad status); treat it like a failure of the first attempt.
            return None

    async def retry(self, func, url, max_retries=4, max_wait_time=8, wait_time=1, **kwargs):
        """Call ``func(url, **kwargs)`` with exponential backoff on read timeouts.

        Sleeps ``wait_time`` seconds before each attempt and doubles it after
        every ``httpx.ReadTimeout``. Gives up after ``max_retries`` attempts or
        once the next wait would exceed ``max_wait_time``, whichever comes
        first.

        Returns:
            The result of ``func`` on success, or ``None`` after giving up.
            Exceptions other than ``httpx.ReadTimeout`` propagate.
        """
        for _ in range(max_retries):
            await asyncio.sleep(wait_time)
            try:
                return await func(url, **kwargs)
            except httpx.ReadTimeout:
                wait_time = wait_time * 2
                if wait_time > max_wait_time:
                    break
        # Reached on both give-up paths, so the failure is always reported.
        print("HTTP Timeout: max retries exceeded with url:", url)
        return None


# Module-level singletons shared by the helper functions in this module.
client = Client()  # async HTTP helper used by the load_* and list_models coroutines
fs = HfFileSystem()  # Hub filesystem object used by glob()


def glob(path):
    """Return the paths matching the glob pattern ``path``.

    Delegates to the module-level ``HfFileSystem`` instance ``fs``.
    """
    return fs.glob(path)


async def load_json_file(path):
    """Download the file at ``path`` and return its parsed JSON content.

    ``path`` is converted to a download URL with ``to_url``. Returns ``None``
    when ``client.get`` yields no response.
    """
    response = await client.get(to_url(path))
    return None if response is None else response.json()


async def load_jsonlines_file(path):
    """Download the JSON Lines file at ``path`` and return its parsed records.

    The request is sent with the headers produced by ``build_hf_headers()``.

    Args:
        path: File path, converted to a download URL with ``to_url``.

    Returns:
        A list with one decoded JSON value per non-blank line, or ``None``
        when ``client.get`` yields no response.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    url = to_url(path)
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    # Iterate through StringIO so lines are split on "\n" only.
    # Blank or whitespace-only lines (such as a doubled trailing newline)
    # are skipped instead of crashing json.loads.
    return [json.loads(line) for line in io.StringIO(r.text) if line.strip()]


def to_url(path):
    """Build the download URL for a file path.

    Accepted shapes:
        ``<org>/<repo>/<filename>``: no repo type is passed on
        (``repo_type=None``).
        ``<prefix>/<org>/<repo>/<filename>``: ``prefix`` is ``datasets``,
        ``spaces`` or ``models`` and is passed on in singular form.

    ``filename`` may itself contain ``/`` (a file in a subfolder) in both
    shapes.

    Args:
        path: Slash-separated path as described above.

    Returns:
        The value returned by ``hf_hub_url`` for the parsed repo id, filename
        and repo type.

    Raises:
        ValueError: If ``path`` has fewer than three segments.
    """
    repo_type = None
    head, _, tail = path.partition("/")
    # Only a known plural prefix followed by a full org/repo/filename triple
    # counts as a repo type. Otherwise the first segment is the org name, so
    # a nested file such as "org/model/sub/file.json" parses correctly.
    if head in ("datasets", "spaces", "models") and tail.count("/") >= 2:
        repo_type = head[:-1]  # drop the plural "s": "datasets" -> "dataset"
        path = tail
    org_name, ds_name, filename = path.split("/", 2)
    return hf_hub_url(repo_id=f"{org_name}/{ds_name}", filename=filename, repo_type=repo_type)


async def load_model_card(model_id):
    """Download the ``README.md`` of ``model_id`` and wrap it in a ``ModelCard``.

    The card is built with ``ignore_metadata_errors=True``. Returns ``None``
    when ``client.get`` yields no response.
    """
    readme_url = to_url(f"{model_id}/README.md")
    response = await client.get(readme_url)
    if response is not None:
        return ModelCard(response.text, ignore_metadata_errors=True)
    return None


async def list_models(filtering=None):
    """Query the ``/models`` endpoint under ``constants.HF_API_URL``.

    A truthy ``filtering`` value is sent as the ``filter`` query parameter;
    otherwise no parameters are sent. Returns the decoded JSON body, or
    ``None`` when ``client.get`` yields no response.
    """
    query = {"filter": filtering} if filtering else {}
    response = await client.get(f"{constants.HF_API_URL}/models", params=query)
    if response is not None:
        return response.json()
    return None