hysts HF staff commited on
Commit
c11895e
1 Parent(s): f2dba90
Files changed (1) hide show
  1. app_inference.py +23 -6
app_inference.py CHANGED
@@ -61,6 +61,13 @@ class InferenceUtil:
61
  instance_prompt = getattr(card.data, 'instance_prompt', '')
62
  return base_model, instance_prompt
63
 
 
 
 
 
 
 
 
64
 
65
  def create_inference_demo(pipe: InferencePipeline,
66
  hf_token: str | None = None) -> gr.Blocks:
@@ -117,12 +124,22 @@ def create_inference_demo(pipe: InferencePipeline,
117
  with gr.Column():
118
  result = gr.Image(label='Result')
119
 
120
- model_source.change(fn=app.reload_lora_model_list,
121
- inputs=model_source,
122
- outputs=lora_model_id)
123
- reload_button.click(fn=app.reload_lora_model_list,
124
- inputs=model_source,
125
- outputs=lora_model_id)
 
 
 
 
 
 
 
 
 
 
126
  lora_model_id.change(fn=app.load_model_info,
127
  inputs=lora_model_id,
128
  outputs=[
 
61
  instance_prompt = getattr(card.data, 'instance_prompt', '')
62
  return base_model, instance_prompt
63
 
64
+ def reload_lora_model_list_and_update_model_info(
65
+ self, model_source: str) -> tuple[dict, str, str]:
66
+ model_list_update = self.reload_lora_model_list(model_source)
67
+ model_list = model_list_update['choices']
68
+ model_info = self.load_model_info(model_list[0] if model_list else '')
69
+ return model_list_update, *model_info
70
+
71
 
72
  def create_inference_demo(pipe: InferencePipeline,
73
  hf_token: str | None = None) -> gr.Blocks:
 
124
  with gr.Column():
125
  result = gr.Image(label='Result')
126
 
127
+ model_source.change(
128
+ fn=app.reload_lora_model_list_and_update_model_info,
129
+ inputs=model_source,
130
+ outputs=[
131
+ lora_model_id,
132
+ base_model_used_for_training,
133
+ instance_prompt_used_for_training,
134
+ ])
135
+ reload_button.click(
136
+ fn=app.reload_lora_model_list_and_update_model_info,
137
+ inputs=model_source,
138
+ outputs=[
139
+ lora_model_id,
140
+ base_model_used_for_training,
141
+ instance_prompt_used_for_training,
142
+ ])
143
  lora_model_id.change(fn=app.load_model_info,
144
  inputs=lora_model_id,
145
  outputs=[