ItsMeBell commited on
Commit
a1d9188
·
verified ·
1 Parent(s): dbf59aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -102,6 +102,7 @@ class Predictor:
102
  def __init__(self):
103
  self.model_target_size = None
104
  self.last_loaded_repo = None
 
105
 
106
  def download_model(self, model_repo):
107
  csv_path = huggingface_hub.hf_hub_download(
@@ -117,11 +118,11 @@ class Predictor:
117
  return csv_path, model_path
118
 
119
  def load_model(self, model_repo):
120
- if model_repo == self.last_loaded_repo:
 
121
  return
122
-
123
  csv_path, model_path = self.download_model(model_repo)
124
-
125
  tags_df = pd.read_csv(csv_path)
126
  sep_tags = load_labels(tags_df)
127
 
@@ -130,12 +131,23 @@ class Predictor:
130
  self.general_indexes = sep_tags[2]
131
  self.character_indexes = sep_tags[3]
132
 
133
- model = rt.InferenceSession(model_path)
134
- _, height, width, _ = model.get_inputs()[0].shape
135
- self.model_target_size = height
136
 
137
- self.last_loaded_repo = model_repo
138
- self.model = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  def prepare_image(self, image):
141
  target_size = self.model_target_size
@@ -179,6 +191,9 @@ class Predictor:
179
  ):
180
  self.load_model(model_repo)
181
 
 
 
 
182
  image = self.prepare_image(image)
183
 
184
  input_name = self.model.get_inputs()[0].name
@@ -347,4 +362,4 @@ def main():
347
 
348
 
349
  if __name__ == "__main__":
350
- main()
 
102
  def __init__(self):
103
  self.model_target_size = None
104
  self.last_loaded_repo = None
105
+ self.model = None # Inisialisasi model di sini
106
 
107
  def download_model(self, model_repo):
108
  csv_path = huggingface_hub.hf_hub_download(
 
118
  return csv_path, model_path
119
 
120
  def load_model(self, model_repo):
121
+ # Cek apakah model sudah dimuat
122
+ if model_repo == self.last_loaded_repo and self.model is not None:
123
  return
124
+
125
  csv_path, model_path = self.download_model(model_repo)
 
126
  tags_df = pd.read_csv(csv_path)
127
  sep_tags = load_labels(tags_df)
128
 
 
131
  self.general_indexes = sep_tags[2]
132
  self.character_indexes = sep_tags[3]
133
 
 
 
 
134
 
135
+ # Gunakan CPU execution provider jika GPU tidak tersedia
136
+ providers = ["CPUExecutionProvider"]
137
+ if rt.get_device() == "GPU":
138
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
139
+ try:
140
+ model = rt.InferenceSession(model_path, providers=providers)
141
+ _, height, width, _ = model.get_inputs()[0].shape
142
+ self.model_target_size = height
143
+ self.last_loaded_repo = model_repo
144
+ self.model = model
145
+
146
+ except Exception as e:
147
+ print(f"Error loading model with given providers: {e}")
148
+ self.model = None
149
+ self.last_loaded_repo = None
150
+
151
 
152
  def prepare_image(self, image):
153
  target_size = self.model_target_size
 
191
  ):
192
  self.load_model(model_repo)
193
 
194
+ if self.model is None:
195
+ return "", {}, {}, {}
196
+
197
  image = self.prepare_image(image)
198
 
199
  input_name = self.model.get_inputs()[0].name
 
362
 
363
 
364
  if __name__ == "__main__":
365
+ main()