davanstrien HF staff commited on
Commit
eb444c6
·
verified ·
1 Parent(s): 9a546d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -23
app.py CHANGED
@@ -38,30 +38,52 @@ def load_model():
38
  logger.error(f"Failed to load model: {e}")
39
  return False
40
 
41
- def get_card_info(hub_id: str) -> Tuple[str, str]:
42
  """Get card information from a Hugging Face hub_id."""
43
  model_exists = False
44
  dataset_exists = False
45
  model_text = None
46
  dataset_text = None
47
 
48
- # Try getting model card
49
- try:
50
- info = model_info(hub_id)
51
- card = ModelCard.load(hub_id)
52
- model_exists = True
53
- model_text = card.text
54
- except Exception as e:
55
- logger.debug(f"No model card found for {hub_id}: {e}")
56
-
57
- # Try getting dataset card
58
- try:
59
- info = dataset_info(hub_id)
60
- card = DatasetCard.load(hub_id)
61
- dataset_exists = True
62
- dataset_text = card.text
63
- except Exception as e:
64
- logger.debug(f"No dataset card found for {hub_id}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Handle different cases
67
  if model_exists and dataset_exists:
@@ -115,12 +137,12 @@ def generate_summary(card_text: str, card_type: str) -> str:
115
  """Cached wrapper for generate_summary with TTL."""
116
  return _generate_summary_gpu(card_text, card_type)
117
 
118
- def summarize(hub_id: str = "") -> str:
119
  """Interface function for Gradio. Returns JSON format."""
120
  try:
121
  if hub_id:
122
- # Fetch and infer card type automatically
123
- card_type, card_text = get_card_info(hub_id)
124
 
125
  if card_type == "both":
126
  model_text, dataset_text = card_text
@@ -148,7 +170,15 @@ def summarize(hub_id: str = "") -> str:
148
  def create_interface():
149
  interface = gr.Interface(
150
  fn=summarize,
151
- inputs=gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
 
 
 
 
 
 
 
 
152
  outputs=gr.JSON(label="Output"),
153
  title="Hugging Face Hub TLDR Generator",
154
  description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
@@ -160,4 +190,4 @@ if __name__ == "__main__":
160
  interface = create_interface()
161
  interface.launch()
162
  else:
163
- print("Failed to load model. Please check the logs for details.")
 
38
  logger.error(f"Failed to load model: {e}")
39
  return False
40
 
41
+ def get_card_info(hub_id: str, repo_type: str = "auto") -> Tuple[str, str]:
42
  """Get card information from a Hugging Face hub_id."""
43
  model_exists = False
44
  dataset_exists = False
45
  model_text = None
46
  dataset_text = None
47
 
48
+ # Handle based on repo type
49
+ if repo_type == "auto":
50
+ # Try getting model card
51
+ try:
52
+ info = model_info(hub_id)
53
+ card = ModelCard.load(hub_id)
54
+ model_exists = True
55
+ model_text = card.text
56
+ except Exception as e:
57
+ logger.debug(f"No model card found for {hub_id}: {e}")
58
+
59
+ # Try getting dataset card
60
+ try:
61
+ info = dataset_info(hub_id)
62
+ card = DatasetCard.load(hub_id)
63
+ dataset_exists = True
64
+ dataset_text = card.text
65
+ except Exception as e:
66
+ logger.debug(f"No dataset card found for {hub_id}: {e}")
67
+ elif repo_type == "model":
68
+ try:
69
+ info = model_info(hub_id)
70
+ card = ModelCard.load(hub_id)
71
+ model_exists = True
72
+ model_text = card.text
73
+ except Exception as e:
74
+ logger.error(f"Failed to get model card for {hub_id}: {e}")
75
+ raise ValueError(f"Could not find model with id {hub_id}")
76
+ elif repo_type == "dataset":
77
+ try:
78
+ info = dataset_info(hub_id)
79
+ card = DatasetCard.load(hub_id)
80
+ dataset_exists = True
81
+ dataset_text = card.text
82
+ except Exception as e:
83
+ logger.error(f"Failed to get dataset card for {hub_id}: {e}")
84
+ raise ValueError(f"Could not find dataset with id {hub_id}")
85
+ else:
86
+ raise ValueError(f"Invalid repo_type: {repo_type}. Must be 'auto', 'model', or 'dataset'")
87
 
88
  # Handle different cases
89
  if model_exists and dataset_exists:
 
137
  """Cached wrapper for generate_summary with TTL."""
138
  return _generate_summary_gpu(card_text, card_type)
139
 
140
+ def summarize(hub_id: str = "", repo_type: str = "auto") -> str:
141
  """Interface function for Gradio. Returns JSON format."""
142
  try:
143
  if hub_id:
144
+ # Fetch card information with specified repo_type
145
+ card_type, card_text = get_card_info(hub_id, repo_type)
146
 
147
  if card_type == "both":
148
  model_text, dataset_text = card_text
 
170
  def create_interface():
171
  interface = gr.Interface(
172
  fn=summarize,
173
+ inputs=[
174
+ gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
175
+ gr.Radio(
176
+ choices=["auto", "model", "dataset"],
177
+ value="auto",
178
+ label="Repository Type",
179
+ info="Choose 'auto' to detect automatically, or specify the repository type"
180
+ )
181
+ ],
182
  outputs=gr.JSON(label="Output"),
183
  title="Hugging Face Hub TLDR Generator",
184
  description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
 
190
  interface = create_interface()
191
  interface.launch()
192
  else:
193
+ print("Failed to load model. Please check the logs for details.")