JadenFK commited on
Commit
87cd610
2 Parent(s): 7021212 7c71893

Merge branch 'main' of https://huggingface.co./spaces/baulab/Erasing-Concepts-In-Diffusion into main

Browse files
Files changed (2) hide show
  1. app.py +16 -8
  2. train.py +2 -1
app.py CHANGED
@@ -5,10 +5,11 @@ from StableDiffuser import StableDiffuser
5
  from tqdm import tqdm
6
  from train import train
7
 
8
- model_map = {
9
- 'Car' : 'models/car.pt',
10
- 'Van Gogh' : 'models/vangogh.pt',
11
- }
 
12
 
13
 
14
  class Demo:
@@ -50,7 +51,7 @@ class Demo:
50
 
51
  self.model_dropdown = gr.Dropdown(
52
  label="ESD Model",
53
- choices=['Van Gogh', 'Car'],
54
  value='Van Gogh',
55
  interactive=True
56
  )
@@ -151,10 +152,13 @@ class Demo:
151
  )
152
 
153
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
154
-
155
  if self.training:
156
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
157
-
 
 
 
158
  if train_method == 'ESD-x':
159
 
160
  modules = ".*attn2$"
@@ -184,11 +188,14 @@ class Demo:
184
 
185
  model_map['Custom'] = save_path
186
 
 
 
 
187
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
188
 
189
 
190
  def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
191
-
192
  self.diffuser._seed = seed or 42
193
 
194
  model_path = model_map[model_name]
@@ -219,6 +226,7 @@ class Demo:
219
 
220
  edited_image = images[0][0]
221
 
 
222
  torch.cuda.empty_cache()
223
 
224
  return edited_image, orig_image
 
5
  from tqdm import tqdm
6
  from train import train
7
 
8
+ model_map = {'Car' : 'models/car.pt',
9
+ 'Van Gogh' : 'models/vangogh.pt',
10
+ 'Kilian Eng' : 'models/kilianeng.pt',
11
+ 'Thomas Kinkade' : 'models/thomaskinkade.pt',
12
+ 'Tyler Edlin' : 'models/tyleredlin.pt'}
13
 
14
 
15
  class Demo:
 
51
 
52
  self.model_dropdown = gr.Dropdown(
53
  label="ESD Model",
54
+ choices= list(model_map.keys()),
55
  value='Van Gogh',
56
  interactive=True
57
  )
 
152
  )
153
 
154
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
155
+ # self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda').eval().half()
156
  if self.training:
157
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
158
+ # clear the diffusers
159
+ # del self.diffuser
160
+ # torch.cuda.empty_cache()
161
+
162
  if train_method == 'ESD-x':
163
 
164
  modules = ".*attn2$"
 
188
 
189
  model_map['Custom'] = save_path
190
 
191
+ # del self.diffuser
192
+ torch.cuda.empty_cache()
193
+ # self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda').eval().half()
194
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
195
 
196
 
197
  def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
198
+
199
  self.diffuser._seed = seed or 42
200
 
201
  model_path = model_map[model_name]
 
226
 
227
  edited_image = images[0][0]
228
 
229
+ del self.finetuner
230
  torch.cuda.empty_cache()
231
 
232
  return edited_image, orig_image
train.py CHANGED
@@ -76,7 +76,8 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
76
  optimizer.step()
77
 
78
  torch.save(finetuner.state_dict(), save_path)
79
-
 
80
  if __name__ == '__main__':
81
 
82
  import argparse
 
76
  optimizer.step()
77
 
78
  torch.save(finetuner.state_dict(), save_path)
79
+ del diffuser
80
+ torch.cuda.empty_cache()
81
  if __name__ == '__main__':
82
 
83
  import argparse