Trent commited on
Commit
98e7562
1 Parent(s): 587ab22

Global lock to avoid concurrent caching

Browse files
Files changed (2) hide show
  1. global_session.py +17 -0
  2. utils.py +15 -2
global_session.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hacky way to provide global session.
2
+ import sys
3
+
4
+ GLOBAL_CONTAINER = sys
5
+
6
+
7
+ class GlobalState(object):
8
+
9
+ def __new__(cls, key="default"):
10
+ if not hasattr(GLOBAL_CONTAINER, '_global_states'):
11
+ GLOBAL_CONTAINER._global_states = {}
12
+ print("Global state container created")
13
+
14
+ if not GLOBAL_CONTAINER._global_states.get(key):
15
+ GLOBAL_CONTAINER._global_states[key] = super(GlobalState, cls).__new__(cls)
16
+
17
+ return GLOBAL_CONTAINER._global_states[key]
utils.py CHANGED
@@ -5,7 +5,8 @@ from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
5
 
6
  from config import MODEL_LIST
7
  from koclip import FlaxHybridCLIP
8
-
 
9
 
10
  @st.cache(allow_output_mutation=True)
11
  def load_index(img_file):
@@ -24,8 +25,20 @@ def load_index(img_file):
24
  return filenames, index
25
 
26
 
27
- @st.cache(allow_output_mutation=True)
28
  def load_model(model_name="koclip/koclip-base"):
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  assert model_name in {f"koclip/{model}" for model in MODEL_LIST}
30
  model = FlaxHybridCLIP.from_pretrained(model_name)
31
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
5
 
6
  from config import MODEL_LIST
7
  from koclip import FlaxHybridCLIP
8
+ from global_session import GlobalState
9
+ from threading import Lock
10
 
11
  @st.cache(allow_output_mutation=True)
12
  def load_index(img_file):
 
25
  return filenames, index
26
 
27
 
 
28
  def load_model(model_name="koclip/koclip-base"):
29
+ state = GlobalState(model_name)
30
+ if not hasattr(state, '_lock'):
31
+ state._lock = Lock()
32
+ print(f"Locking loading of model : {model_name} to avoid concurrent caching.")
33
+
34
+ with state._lock:
35
+ cached_model = load_model_cached(model_name)
36
+
37
+ print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")
38
+ return cached_model
39
+
40
+ @st.cache(allow_output_mutation=True)
41
+ def load_model_cached(model_name):
42
  assert model_name in {f"koclip/{model}" for model in MODEL_LIST}
43
  model = FlaxHybridCLIP.from_pretrained(model_name)
44
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")