"""Librarian Bot: open pull requests that add ``base_model`` metadata to model cards.

For a given user or organisation the app lists their ``transformers`` models,
keeps those whose card metadata has no ``base_model``, searches each
``README.md`` for the sentence "This model is a fine-tuned version of [X](...)",
checks that X exists on the Hub, and opens a pull request adding
``base_model: X`` to the model card metadata.
"""

import asyncio
import os
import random
import re
from typing import Dict

import gradio as gr
import httpx
from cachetools import TTLCache, cached
from cashews import NOT_NONE, cache
from dotenv import load_dotenv
from httpx import AsyncClient, Limits
from huggingface_hub import (
    ModelCard,
    ModelFilter,
    get_repo_discussions,
    hf_hub_url,
    list_models,
    logging,
)
from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
from tqdm.asyncio import tqdm as atqdm
from tqdm.auto import tqdm

cache.setup("mem://")
load_dotenv()

logger = logging.get_logger()

token = os.environ["HUGGINGFACE_TOKEN"]
user_agent = os.environ["USER_AGENT"]
# Explicit check instead of ``assert`` so validation is not stripped by ``python -O``.
if not token or not user_agent:
    raise RuntimeError("HUGGINGFACE_TOKEN and USER_AGENT must be set to non-empty values")

headers = {"user-agent": user_agent, "authorization": f"Bearer {token}"}
limits = Limits(max_keepalive_connections=10, max_connections=50)

# Used both as the commit message of the PRs we open and to recognise PRs
# opened on earlier runs, so the two must stay identical.
PR_TITLE = "Librarian Bot: Add base_model information to model"

# Captures X from "This model is a fine-tuned version of [X](...)".
MODEL_ID_RE_PATTERN = re.compile(
    r"This model is a fine-tuned version of \[(.*?)\]\(.*?\)"
)
# Any "base_model:" entry in the README text (e.g. in its YAML header).
BASE_MODEL_PATTERN = re.compile(r"base_model:\s+(.+)")

# Transient network failures that are treated as "no result" rather than errors.
_NETWORK_ERRORS = (httpx.ConnectError, httpx.ReadTimeout, httpx.ConnectTimeout)


def create_client():
    """Return a new ``AsyncClient`` carrying the bot's user-agent and auth headers.

    The caller owns the client and must close it, e.g.
    ``async with create_client() as client: ...``.
    """
    return AsyncClient(headers=headers, limits=limits, http2=True)


@cached(cache=TTLCache(maxsize=100, ttl=60 * 10))
def get_models(user_or_org):
    """Return every ``transformers`` model authored by *user_or_org*.

    Models are sorted by downloads, descending, and include card data. Results
    are cached per ``user_or_org`` for 10 minutes.
    """
    model_filter = ModelFilter(library="transformers", author=user_or_org)
    return list(
        tqdm(
            list_models(
                filter=model_filter,
                sort="downloads",
                direction=-1,
                cardData=True,
                full=True,
            )
        )
    )


def filter_models(models):
    """Keep only models that have card data but no (or an empty) ``base_model`` entry.

    Models without card data, or objects lacking a ``cardData`` attribute, are dropped.
    """
    new_models = []
    for model in tqdm(models):
        try:
            if card_data := model.cardData:
                if not card_data.get("base_model", None):
                    new_models.append(model)
        except AttributeError:
            continue
    return new_models


@cached(cache=TTLCache(maxsize=100, ttl=60 * 3))
def has_model_card(model):
    """Return True if *model*'s file listing (``siblings``) contains ``README.md``."""
    return any(sibling.rfilename == "README.md" for sibling in model.siblings or ())


@cached(cache=TTLCache(maxsize=100, ttl=60))
def check_already_has_base_model(text):
    """Return True if *text* already contains a ``base_model:`` entry."""
    return bool(re.search(BASE_MODEL_PATTERN, text))


@cached(cache=TTLCache(maxsize=100, ttl=60))
def extract_model_name(text):
    """Return the model id linked in the "fine-tuned version of [...]" sentence, else None."""
    match = re.search(MODEL_ID_RE_PATTERN, text)
    return match.group(1) if match else None


@cache(ttl=120, condition=NOT_NONE)
async def check_readme_for_match(model):
    """Fetch *model*'s README and return the base model id it mentions, or None.

    Returns None when there is no README, the fetch fails or is not HTTP 200,
    the README already declares ``base_model``, or no match is found.
    """
    if not has_model_card(model):
        return None
    model_card_url = hf_hub_url(model.modelId, "README.md")
    try:
        # ``async with`` closes the client's connection pool after the request.
        async with create_client() as client:
            resp = await client.get(model_card_url)
    except _NETWORK_ERRORS:
        return None
    except Exception as e:
        logger.warning(f"Failed to fetch README for {model.modelId}: {e}")
        return None
    if resp.status_code != 200:
        return None
    if check_already_has_base_model(resp.text):
        return None
    return extract_model_name(resp.text)


@cache(ttl=120, condition=NOT_NONE)
async def check_model_exists(model, match):
    """Check that *match* is a model id on the Hub.

    Returns ``{"modelid": ..., "match": ...}`` on HTTP 200, ``False`` on HTTP
    401, and None for any other status or a network error.
    """
    url = f"https://huggingface.co./api/models/{match}"
    try:
        async with create_client() as client:
            resp = await client.get(url)
    except _NETWORK_ERRORS:
        return None
    except Exception as e:
        logger.warning(f"Failed to look up model {match}: {e}")
        return None
    if resp.status_code == 200:
        return {"modelid": model.modelId, "match": match}
    if resp.status_code == 401:
        return False
    return None


@cache(ttl=120, condition=NOT_NONE)
async def check_model(model):
    """Run the README match and Hub existence checks for *model*.

    Returns ``check_model_exists``'s result when the README yields a match,
    otherwise None.
    """
    match = await check_readme_for_match(model)
    if match:
        return await check_model_exists(model, match)
    return None


async def prep_tasks(models):
    """Run ``check_model`` concurrently for *models*.

    Shows a progress bar and returns results in completion order.
    """
    tasks = [asyncio.create_task(check_model(model)) for model in models]
    return [await f for f in atqdm.as_completed(tasks)]


def get_data_for_user(user_or_org):
    """Return a PR payload dict for each of *user_or_org*'s models that qualifies for a PR."""
    models = filter_models(get_models(user_or_org))
    results = asyncio.run(prep_tasks(models))
    # Drop both None and False (check_model_exists returns False on HTTP 401);
    # a False here would otherwise crash update_metadata.
    return [r for r in results if r]


def generate_issue_text(based_model_regex_match, opened_by=None):
    """Build the markdown description for a PR adding *based_model_regex_match* as ``base_model``."""
    return f"""This pull request aims to enrich the metadata of your model by adding [`{based_model_regex_match}`](https://huggingface.co./{based_model_regex_match}) as a `base_model` field, situated in the `YAML` block of your model's `README.md`.

How did we find this information? We performed a regular expression match on your `README.md` file to determine the connection.

**Why add this?** Enhancing your model's metadata in this way:
- **Boosts Discoverability** - It becomes straightforward to trace the relationships between various models on the Hugging Face Hub.
- **Highlights Impact** - It showcases the contributions and influences different models have within the community.

For a hands-on example of how such metadata can play a pivotal role in mapping model connections, take a look at [librarian-bots/base_model_explorer](https://huggingface.co./spaces/librarian-bots/base_model_explorer).

This PR comes courtesy of [Librarian Bot](https://huggingface.co./librarian-bot) by request of {opened_by}"""


def _has_previous_base_model_pr(repo_id):
    """Return True if ``librarian-bot`` already opened a PR titled ``PR_TITLE`` on *repo_id*."""
    previous_discussions = list(get_repo_discussions(repo_id))
    if not previous_discussions:
        return False
    logger.info("found previous discussions")
    prs = [d for d in previous_discussions if d.is_pull_request]
    if not prs:
        return False
    logger.info("found previous pull requests")
    for pr in prs:
        if pr.author != "librarian-bot":
            continue
        logger.info("previously opened PR")
        if pr.title == PR_TITLE:
            logger.info("previously opened PR to add base_model tag")
            return True
    return False


def update_metadata(metadata_payload: Dict[str, str], user_making_request=None):
    """Open a PR adding ``base_model`` to the card of ``metadata_payload["modelid"]``.

    Args:
        metadata_payload: dict with ``"modelid"`` (repo to update) and
            ``"match"`` (base model id). It is mutated in place: an
            ``"opened_pr"`` flag is added.
        user_making_request: name shown in the PR description.

    Returns:
        The same dict. ``opened_pr`` is True when a PR was pushed, or when an
        earlier librarian-bot PR with the same title already exists.
        ``opened_pr`` is False when the repo is missing or a Hub HTTP error
        occurs while checking or pushing.
    """
    metadata_payload["opened_pr"] = False
    regex_match = metadata_payload["match"]
    repo_id = metadata_payload["modelid"]
    try:
        model_card = ModelCard.load(repo_id)
    except RepositoryNotFoundError:
        return metadata_payload
    model_card.data["base_model"] = regex_match
    template = generate_issue_text(regex_match, opened_by=user_making_request)
    try:
        if _has_previous_base_model_pr(repo_id):
            # NOTE: an existing PR is reported as opened, so open_prs' count
            # includes PRs opened on earlier runs.
            metadata_payload["opened_pr"] = True
            return metadata_payload
        model_card.push_to_hub(
            repo_id,
            token=token,
            repo_type="model",
            create_pr=True,
            commit_message=PR_TITLE,
            commit_description=template,
        )
        metadata_payload["opened_pr"] = True
        return metadata_payload
    except HfHubHTTPError:
        return metadata_payload


def open_prs(profile: gr.OAuthProfile | None, user_or_org: str = None):
    """Gradio handler: open base_model PRs and return a markdown status string.

    ``profile`` is supplied by Gradio's OAuth login; it is not among the click
    inputs. PRs target *user_or_org* when given, otherwise the logged-in user.
    For an explicit target, at most 10 randomly chosen candidates are processed.
    """
    if not profile:
        return "Please login to open PR requests"
    username = profile.preferred_username
    user_to_receive_prs = user_or_org or username
    data = get_data_for_user(user_to_receive_prs)
    if user_or_org:
        # random.sample returns a new list; it must be assigned for the cap to apply.
        data = random.sample(data, min(10, len(data)))
    if not data:
        return "No PRs to open"
    results = []
    for metadata_payload in data:
        try:
            results.append(
                update_metadata(metadata_payload, user_making_request=username)
            )
        except Exception as e:
            logger.error(e)
    opened = sum(1 for r in results if r["opened_pr"])
    return f"Opened {opened} PRs"


with gr.Blocks() as demo:
    gr.Markdown("# Librarian Bot")
    gr.LoginButton()
    gr.LogoutButton()
    user = gr.Textbox(label="user or org to Open PRs for")
    button = gr.Button()
    results = gr.Markdown()
    button.click(open_prs, [user], results)

demo.queue(concurrency_count=1).launch()