Fix
Browse files- app_inference.py +23 -6
app_inference.py
CHANGED
@@ -61,6 +61,13 @@ class InferenceUtil:
|
|
61 |
instance_prompt = getattr(card.data, 'instance_prompt', '')
|
62 |
return base_model, instance_prompt
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
def create_inference_demo(pipe: InferencePipeline,
|
66 |
hf_token: str | None = None) -> gr.Blocks:
|
@@ -117,12 +124,22 @@ def create_inference_demo(pipe: InferencePipeline,
|
|
117 |
with gr.Column():
|
118 |
result = gr.Image(label='Result')
|
119 |
|
120 |
-
model_source.change(
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
lora_model_id.change(fn=app.load_model_info,
|
127 |
inputs=lora_model_id,
|
128 |
outputs=[
|
|
|
61 |
instance_prompt = getattr(card.data, 'instance_prompt', '')
|
62 |
return base_model, instance_prompt
|
63 |
|
64 |
+
def reload_lora_model_list_and_update_model_info(
|
65 |
+
self, model_source: str) -> tuple[dict, str, str]:
|
66 |
+
model_list_update = self.reload_lora_model_list(model_source)
|
67 |
+
model_list = model_list_update['choices']
|
68 |
+
model_info = self.load_model_info(model_list[0] if model_list else '')
|
69 |
+
return model_list_update, *model_info
|
70 |
+
|
71 |
|
72 |
def create_inference_demo(pipe: InferencePipeline,
|
73 |
hf_token: str | None = None) -> gr.Blocks:
|
|
|
124 |
with gr.Column():
|
125 |
result = gr.Image(label='Result')
|
126 |
|
127 |
+
model_source.change(
|
128 |
+
fn=app.reload_lora_model_list_and_update_model_info,
|
129 |
+
inputs=model_source,
|
130 |
+
outputs=[
|
131 |
+
lora_model_id,
|
132 |
+
base_model_used_for_training,
|
133 |
+
instance_prompt_used_for_training,
|
134 |
+
])
|
135 |
+
reload_button.click(
|
136 |
+
fn=app.reload_lora_model_list_and_update_model_info,
|
137 |
+
inputs=model_source,
|
138 |
+
outputs=[
|
139 |
+
lora_model_id,
|
140 |
+
base_model_used_for_training,
|
141 |
+
instance_prompt_used_for_training,
|
142 |
+
])
|
143 |
lora_model_id.change(fn=app.load_model_info,
|
144 |
inputs=lora_model_id,
|
145 |
outputs=[
|