echarlaix HF staff commited on
Commit
c2e6a29
·
1 Parent(s): 1f756ba

simplify model loading

Browse files
Files changed (1) hide show
  1. app.py +10 -20
app.py CHANGED
@@ -24,11 +24,9 @@ from optimum.intel import (
24
  OVModelForSeq2SeqLM,
25
  OVModelForSequenceClassification,
26
  OVModelForTokenClassification,
27
- OVStableDiffusionPipeline,
28
- OVStableDiffusionXLPipeline,
29
- OVLatentConsistencyModelPipeline,
30
  OVModelForPix2Struct,
31
  OVWeightQuantizationConfig,
 
32
  )
33
  from diffusers import ConfigMixin
34
 
@@ -41,9 +39,6 @@ _HEAD_TO_AUTOMODELS = {
41
  "question-answering": "OVModelForQuestionAnswering",
42
  "image-classification": "OVModelForImageClassification",
43
  "audio-classification": "OVModelForAudioClassification",
44
- "stable-diffusion": "OVStableDiffusionPipeline",
45
- "stable-diffusion-xl": "OVStableDiffusionXLPipeline",
46
- "latent-consistency": "OVLatentConsistencyModelPipeline",
47
  }
48
 
49
 
@@ -61,27 +56,22 @@ def export(model_id: str, private_repo: bool, overwritte: bool, oauth_token: gr.
61
  library_name = TasksManager.infer_library_from_model(model_id, token=oauth_token.token)
62
 
63
  if library_name == "diffusers":
64
- ConfigMixin.config_name = "model_index.json"
65
- class_name = ConfigMixin.load_config(model_id, token=oauth_token.token)["_class_name"].lower()
66
- if "xl" in class_name:
67
- task = "stable-diffusion-xl"
68
- elif "consistency" in class_name:
69
- task = "latent-consistency"
70
- else:
71
- task = "stable-diffusion"
72
  else:
73
  task = TasksManager.infer_task_from_model(model_id, token=oauth_token.token)
74
 
75
- if task == "text2text-generation":
76
- return "Export of Seq2Seq models is currently disabled"
 
 
 
 
 
77
 
78
- if task not in _HEAD_TO_AUTOMODELS:
79
- return f"The task '{task}' is not supported, only {_HEAD_TO_AUTOMODELS.keys()} tasks are supported"
80
 
81
- auto_model_class = _HEAD_TO_AUTOMODELS[task]
82
  ov_files = _find_files_matching_pattern(
83
  model_id,
84
- pattern=r"(.*)?openvino(.*)?\_model.xml",
85
  use_auth_token=oauth_token.token,
86
  )
87
 
 
24
  OVModelForSeq2SeqLM,
25
  OVModelForSequenceClassification,
26
  OVModelForTokenClassification,
 
 
 
27
  OVModelForPix2Struct,
28
  OVWeightQuantizationConfig,
29
+ OVDiffusionPipeline,
30
  )
31
  from diffusers import ConfigMixin
32
 
 
39
  "question-answering": "OVModelForQuestionAnswering",
40
  "image-classification": "OVModelForImageClassification",
41
  "audio-classification": "OVModelForAudioClassification",
 
 
 
42
  }
43
 
44
 
 
56
  library_name = TasksManager.infer_library_from_model(model_id, token=oauth_token.token)
57
 
58
  if library_name == "diffusers":
59
+ auto_model_class = "OVDiffusionPipeline"
 
 
 
 
 
 
 
60
  else:
61
  task = TasksManager.infer_task_from_model(model_id, token=oauth_token.token)
62
 
63
+ if task == "text2text-generation":
64
+ return "Export of Seq2Seq models is currently disabled"
65
+
66
+ if task not in _HEAD_TO_AUTOMODELS:
67
+ return f"The task '{task}' is not supported, only {_HEAD_TO_AUTOMODELS.keys()} tasks are supported"
68
+
69
+ auto_model_class = _HEAD_TO_AUTOMODELS[task]
70
 
 
 
71
 
 
72
  ov_files = _find_files_matching_pattern(
73
  model_id,
74
+ pattern=r"(.*)?openvino(.*)?\_model(.*)?.xml$",
75
  use_auth_token=oauth_token.token,
76
  )
77