muellerzr HF staff commited on
Commit
be6343c
·
1 Parent(s): dd566c3

Allow URLs

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -7,6 +7,7 @@ from huggingface_hub import HfApi
7
  from huggingface_hub.utils import RepositoryNotFoundError, GatedRepoError
8
  from accelerate.commands.estimate import create_empty_model, check_has_model
9
  from accelerate.utils import convert_bytes, calculate_maximum_sizes
 
10
 
11
  # We need to store them as globals because gradio doesn't have a way for us to pass them in to the button
12
  HAS_DISCUSSION = True
@@ -54,12 +55,20 @@ When training with `Adam`, you can expect roughly 4x the reported results to be
54
  discussion = api.create_discussion(MODEL_NAME, "[AUTOMATED] Model Memory Requirements", description=post)
55
  webbrowser.open_new_tab(discussion.url)
56
 
57
- def convert_url_to_name(url:str):
58
- "Converts a model URL to its name on the Hub"
59
- results = re.findall(r"huggingface.co\/(.*?)#", url)
60
- if len(results) < 1:
61
- raise ValueError(f"URL {url} is not a valid model URL to the Hugging Face Hub")
62
- return results[0]
 
 
 
 
 
 
 
 
63
 
64
  def calculate_memory(model_name:str, library:str, options:list, access_token:str, raw=False):
65
  "Calculates the memory usage for a model"
@@ -67,11 +76,7 @@ def calculate_memory(model_name:str, library:str, options:list, access_token:str
67
  model_name = translate_llama2(model_name)
68
  if library == "auto":
69
  library = None
70
- if "http" in model_name and "//" in model_name:
71
- try:
72
- model_name = convert_url_to_name(model_name)
73
- except ValueError:
74
- raise gr.Error(f"URL `{model_name}` is not a valid model URL to the Hugging Face Hub")
75
  try:
76
  model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
77
  except GatedRepoError:
 
7
  from huggingface_hub.utils import RepositoryNotFoundError, GatedRepoError
8
  from accelerate.commands.estimate import create_empty_model, check_has_model
9
  from accelerate.utils import convert_bytes, calculate_maximum_sizes
10
+ from urllib.parse import urlparse
11
 
12
  # We need to store them as globals because gradio doesn't have a way for us to pass them in to the button
13
  HAS_DISCUSSION = True
 
55
  discussion = api.create_discussion(MODEL_NAME, "[AUTOMATED] Model Memory Requirements", description=post)
56
  webbrowser.open_new_tab(discussion.url)
57
 
58
+ def extract_from_url(name:str):
59
+ "Checks if `name` is a URL, and if so converts it to a model name"
60
+ is_url = False
61
+ try:
62
+ result = urlparse(name)
63
+ is_url = all([result.scheme, result.netloc])
64
+ except:
65
+ is_url = False
66
+ # Pass through if not a URL
67
+ if not is_url:
68
+ return name
69
+ else:
70
+ path = result.path
71
+ return path[1:]
72
 
73
  def calculate_memory(model_name:str, library:str, options:list, access_token:str, raw=False):
74
  "Calculates the memory usage for a model"
 
76
  model_name = translate_llama2(model_name)
77
  if library == "auto":
78
  library = None
79
+ model_name = extract_from_url(model_name)
 
 
 
 
80
  try:
81
  model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
82
  except GatedRepoError: