NSTiwari commited on
Commit
68b64fb
·
verified ·
1 Parent(s): fb9e7b8

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +201 -0
  2. requirements.txt +12 -0
  3. run.sh +3 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional, Tuple
8
+ from urllib.request import urlopen, urlretrieve
9
+
10
+ import streamlit as st
11
+ from huggingface_hub import HfApi, whoami
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @dataclass
18
+ class Config:
19
+ """Application configuration."""
20
+
21
+ hf_token: str
22
+ hf_username: str
23
+ transformers_version: str = "3.0.0"
24
+ hf_base_url: str = "https://huggingface.co"
25
+ transformers_base_url: str = (
26
+ "https://github.com/xenova/transformers.js/archive/refs"
27
+ )
28
+ repo_path: Path = Path("./transformers.js")
29
+
30
+ @classmethod
31
+ def from_env(cls) -> "Config":
32
+ """Create config from environment variables and secrets."""
33
+ system_token = st.secrets.get("HF_TOKEN")
34
+ user_token = st.session_state.get("user_hf_token")
35
+ if user_token:
36
+ hf_username = whoami(token=user_token)["name"]
37
+ else:
38
+ hf_username = (
39
+ os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
40
+ )
41
+ hf_token = user_token or system_token
42
+
43
+ if not hf_token:
44
+ raise ValueError("HF_TOKEN must be set")
45
+
46
+ return cls(hf_token=hf_token, hf_username=hf_username)
47
+
48
+
49
+ class ModelConverter:
50
+ """Handles model conversion and upload operations."""
51
+
52
+ def __init__(self, config: Config):
53
+ self.config = config
54
+ self.api = HfApi(token=config.hf_token)
55
+
56
+ def _get_ref_type(self) -> str:
57
+ """Determine the reference type for the transformers repository."""
58
+ url = f"{self.config.transformers_base_url}/tags/{self.config.transformers_version}.tar.gz"
59
+ try:
60
+ return "tags" if urlopen(url).getcode() == 200 else "heads"
61
+ except Exception as e:
62
+ logger.warning(f"Failed to check tags, defaulting to heads: {e}")
63
+ return "heads"
64
+
65
+ def setup_repository(self) -> None:
66
+ """Download and setup transformers repository if needed."""
67
+ if self.config.repo_path.exists():
68
+ return
69
+
70
+ ref_type = self._get_ref_type()
71
+ archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
72
+ archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")
73
+
74
+ try:
75
+ urlretrieve(archive_url, archive_path)
76
+ self._extract_archive(archive_path)
77
+ logger.info("Repository downloaded and extracted successfully")
78
+ except Exception as e:
79
+ raise RuntimeError(f"Failed to setup repository: {e}")
80
+ finally:
81
+ archive_path.unlink(missing_ok=True)
82
+
83
+ def _extract_archive(self, archive_path: Path) -> None:
84
+ """Extract the downloaded archive."""
85
+ import tarfile
86
+ import tempfile
87
+
88
+ with tempfile.TemporaryDirectory() as tmp_dir:
89
+ with tarfile.open(archive_path, "r:gz") as tar:
90
+ tar.extractall(tmp_dir)
91
+
92
+ extracted_folder = next(Path(tmp_dir).iterdir())
93
+ extracted_folder.rename(self.config.repo_path)
94
+
95
+ def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
96
+ """Convert the model to ONNX format."""
97
+ try:
98
+ result = subprocess.run(
99
+ [
100
+ sys.executable,
101
+ "-m",
102
+ "scripts.convert",
103
+ "--quantize",
104
+ "--model_id",
105
+ input_model_id,
106
+ ],
107
+ cwd=self.config.repo_path,
108
+ capture_output=True,
109
+ text=True,
110
+ env={},
111
+ )
112
+
113
+ if result.returncode != 0:
114
+ return False, result.stderr
115
+
116
+ return True, result.stderr
117
+
118
+ except Exception as e:
119
+ return False, str(e)
120
+
121
+ def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
122
+ """Upload the converted model to Hugging Face."""
123
+ try:
124
+ self.api.create_repo(output_model_id, exist_ok=True, private=False)
125
+ model_folder_path = self.config.repo_path / "models" / input_model_id
126
+
127
+ self.api.upload_folder(
128
+ folder_path=str(model_folder_path), repo_id=output_model_id
129
+ )
130
+ return None
131
+ except Exception as e:
132
+ return str(e)
133
+ finally:
134
+ import shutil
135
+
136
+ shutil.rmtree(model_folder_path, ignore_errors=True)
137
+
138
+
139
+ def main():
140
+ """Main application entry point."""
141
+ st.write("## Convert a Hugging Face model to ONNX")
142
+
143
+ try:
144
+ config = Config.from_env()
145
+ converter = ModelConverter(config)
146
+ converter.setup_repository()
147
+
148
+ input_model_id = st.text_input(
149
+ "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
150
+ )
151
+
152
+ if not input_model_id:
153
+ return
154
+
155
+ st.text_input(
156
+ f"Optional: Your Hugging Face write token. Fill it if you want to upload the model under your account.",
157
+ type="password",
158
+ key="user_hf_token",
159
+ )
160
+
161
+ model_name = input_model_id.split("/")[-1]
162
+ output_model_id = f"{config.hf_username}/{model_name}-ONNX"
163
+ output_model_url = f"{config.hf_base_url}/{output_model_id}"
164
+
165
+ if converter.api.repo_exists(output_model_id):
166
+ st.write("This model has already been converted! 🎉")
167
+ st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
168
+ return
169
+
170
+ st.write(f"URL where the model will be converted and uploaded to:")
171
+ st.code(output_model_url, language="plaintext")
172
+
173
+ if not st.button(label="Proceed", type="primary"):
174
+ return
175
+
176
+ with st.spinner("Converting model..."):
177
+ success, stderr = converter.convert_model(input_model_id)
178
+ if not success:
179
+ st.error(f"Conversion failed: {stderr}")
180
+ return
181
+
182
+ st.success("Conversion successful!")
183
+ st.code(stderr)
184
+
185
+ with st.spinner("Uploading model..."):
186
+ error = converter.upload_model(input_model_id, output_model_id)
187
+ if error:
188
+ st.error(f"Upload failed: {error}")
189
+ return
190
+
191
+ st.success("Upload successful!")
192
+ st.write("You can now go and view the model on Hugging Face!")
193
+ st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
194
+
195
+ except Exception as e:
196
+ logger.exception("Application error")
197
+ st.error(f"An error occurred: {str(e)}")
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ streamlit
3
+ transformers[torch]==4.43.4
4
+ onnxruntime==1.19.2
5
+ optimum==1.21.3
6
+ onnx==1.16.2
7
+ onnxconverter-common==1.14.0
8
+ tqdm==4.66.5
9
+ onnxslim==0.1.31
10
+ --extra-index-url https://pypi.ngc.nvidia.com
11
+ onnx_graphsurgeon==0.3.27
12
+ timm
run.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/bin/sh
2
+ pip install -r requirements.txt
3
+ streamlit run app.py