menghanxia commited on
Commit
2a10b61
·
1 Parent(s): 565e612

fixed the randomness issue

Browse files
Files changed (3) hide show
  1. app.py +18 -13
  2. inference.py +3 -7
  3. models/clusterkit.py +1 -0
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import gradio as gr
2
  import os, requests
 
3
  from inference import setup_model, colorize_grayscale, predict_anchors
4
 
5
- os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/disco-beta.pth.rar")
6
- os.rename("disco-beta.pth.rar", "./checkpoints/disco-beta.pth.rar")
 
 
7
 
8
  ## step 1: set up model
9
  device = "cpu"
@@ -30,7 +33,7 @@ def switch_states(is_checked):
30
  demo = gr.Blocks(title="DISCO")
31
  with demo:
32
  gr.Markdown(value="""
33
- **Gradio demo for DISCO: Disentangled Image Colorization via Global Anchors. [Project Page](https://menghanxia.github.io/projects/disco.html)**.
34
  """)
35
  with gr.Row():
36
  with gr.Column(scale=1):
@@ -39,13 +42,13 @@ with demo:
39
  with gr.Row():
40
  Num_anchor = gr.Number(type="int", value=8, label="Num. of anchors (3~14)")
41
  Radio_resolution = gr.Radio(type="index", choices=["Low (256x256)", "High (512x512)"], \
42
- label="Colorization resolution", value="Low (256x256)")
43
  Ckeckbox_editable = gr.Checkbox(default=False, label='Show editable anchors')
44
  with gr.Row():
45
  Button_show_anchor = gr.Button(value="Predict anchors", visible=False)
46
  Button_run = gr.Button(value="Colorize")
47
  with gr.Column(scale=1):
48
- Image_output = gr.Image(type="numpy", label="Output", shape=[100,100])
49
 
50
  Ckeckbox_editable.change(fn=switch_states, inputs=Ckeckbox_editable, outputs=[Image_anchor, Button_show_anchor])
51
  Button_show_anchor.click(fn=click_predanchors, inputs=[Image_input, Num_anchor, Radio_resolution, Ckeckbox_editable], outputs=Image_anchor)
@@ -55,14 +58,16 @@ with demo:
55
  gr.Markdown(value="""
56
  **Guideline**
57
  1. upload your image;
58
- 2. Set up the arguments: "Num. of anchors" and "Colorization resolution";
59
- 3. Run the colorization (two modes supported):
60
- - **Automatic**: click "Colorize" to get the automatically colorized output.
61
- - **Editable**: check ""Show editable anchors" and click "Predict anchors". Then, modify the colors of the predicted anchors (only anchor region will be used). Finally, click "Colorize" to get the result.
62
  """)
63
  gr.HTML(value="""
64
- <p style='text-align: center'><a href='https://menghanxia.github.io/projects/disco.html' target='_blank'>DISCO Project Page</a> | <a href='https://github.com/MenghanXia/DisentangledColorization' target='_blank'>Github Repo</a></p>
65
  """)
66
-
67
- #demo.launch(server_name='9.134.253.83',server_port=7788)
68
- demo.launch()
 
 
 
1
  import gradio as gr
2
  import os, requests
3
+ import numpy as np
4
  from inference import setup_model, colorize_grayscale, predict_anchors
5
 
6
+ RUN_MODE = "remote"
7
+ if RUN_MODE != "local":
8
+ os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/disco-beta.pth.rar")
9
+ os.rename("disco-beta.pth.rar", "./checkpoints/disco-beta.pth.rar")
10
 
11
  ## step 1: set up model
12
  device = "cpu"
 
33
  demo = gr.Blocks(title="DISCO")
34
  with demo:
35
  gr.Markdown(value="""
36
+ **Gradio demo for DISCO: Disentangled Image Colorization via Global Anchors**. Check our project page [*Here*](https://menghanxia.github.io/projects/disco.html).
37
  """)
38
  with gr.Row():
39
  with gr.Column(scale=1):
 
42
  with gr.Row():
43
  Num_anchor = gr.Number(type="int", value=8, label="Num. of anchors (3~14)")
44
  Radio_resolution = gr.Radio(type="index", choices=["Low (256x256)", "High (512x512)"], \
45
+ label="Colorization resolution (Low is more stable)", value="Low (256x256)")
46
  Ckeckbox_editable = gr.Checkbox(default=False, label='Show editable anchors')
47
  with gr.Row():
48
  Button_show_anchor = gr.Button(value="Predict anchors", visible=False)
49
  Button_run = gr.Button(value="Colorize")
50
  with gr.Column(scale=1):
51
+ Image_output = gr.Image(type="numpy", label="Output").style(height=480)
52
 
53
  Ckeckbox_editable.change(fn=switch_states, inputs=Ckeckbox_editable, outputs=[Image_anchor, Button_show_anchor])
54
  Button_show_anchor.click(fn=click_predanchors, inputs=[Image_input, Num_anchor, Radio_resolution, Ckeckbox_editable], outputs=Image_anchor)
 
58
  gr.Markdown(value="""
59
  **Guideline**
60
  1. upload your image;
61
+ 2. set up the arguments: "Num. of anchors" and "Colorization resolution";
62
+ 3. run the colorization (two modes supported):
63
+ - *Automatic mode*: click "Colorize" to get the automatically colorized output.
64
+ - *Editable mode*: check ""Show editable anchors" and click "Predict anchors". Then, modify the colors of the predicted anchors (only anchor region will be used). Finally, click "Colorize" to get the result.
65
  """)
66
  gr.HTML(value="""
67
+ <p style="text-align:center; color:orange"><a href='https://menghanxia.github.io/projects/disco.html' target='_blank'>DISCO Project Page</a> | <a href='https://github.com/MenghanXia/DisentangledColorization' target='_blank'>Github Repo</a></p>
68
  """)
69
+
70
+ if RUN_MODE == "local":
71
+ demo.launch(server_name='9.134.253.83',server_port=7788)
72
+ else:
73
+ demo.launch()
inference.py CHANGED
@@ -11,17 +11,13 @@ from utils import util
11
 
12
 
13
  def setup_model(checkpt_path, device="cuda"):
14
- seed = 130
15
- np.random.seed(seed)
16
- torch.manual_seed(seed)
17
- torch.cuda.manual_seed(seed)
18
  #print('--------------', torch.cuda.is_available())
19
  """Load the model into memory to make running multiple predictions efficient"""
20
  colorLabeler = basic.ColorLabel(device=device)
21
  colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
22
  colorizer = colorizer.to(device)
23
  #checkpt_path = "./checkpoints/disco-beta.pth.rar"
24
- assert os.path.exists(checkpt_path)
25
  data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
26
  colorizer.load_state_dict(data_dict['state_dict'])
27
  colorizer.eval()
@@ -89,8 +85,8 @@ def predict_anchors(colorizer, color_class, rgb_img, n_anchors, is_high_res, is_
89
  n_anchors = min(n_anchors, 14)
90
  target_res = (512,512) if is_high_res else (256,256)
91
  input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
92
- input_grays = input_grays.cuda(non_blocking=True)
93
- input_colors = input_colors.cuda(non_blocking=True)
94
 
95
  sampled_T, sp_size = 0, 16
96
  pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
 
11
 
12
 
13
  def setup_model(checkpt_path, device="cuda"):
 
 
 
 
14
  #print('--------------', torch.cuda.is_available())
15
  """Load the model into memory to make running multiple predictions efficient"""
16
  colorLabeler = basic.ColorLabel(device=device)
17
  colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
18
  colorizer = colorizer.to(device)
19
  #checkpt_path = "./checkpoints/disco-beta.pth.rar"
20
+ assert os.path.exists(checkpt_path), "No checkpoint found!"
21
  data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
22
  colorizer.load_state_dict(data_dict['state_dict'])
23
  colorizer.eval()
 
85
  n_anchors = min(n_anchors, 14)
86
  target_res = (512,512) if is_high_res else (256,256)
87
  input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
88
+ input_grays = input_grays.to(device)
89
+ input_colors = input_colors.to(device)
90
 
91
  sampled_T, sp_size = 0, 16
92
  pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
models/clusterkit.py CHANGED
@@ -103,6 +103,7 @@ def initialize(X, num_clusters):
103
  :param num_clusters: (int) number of clusters
104
  :return: (np.array) initial state
105
  """
 
106
  num_samples = len(X)
107
  indices = np.random.choice(num_samples, num_clusters, replace=False)
108
  initial_state = X[indices]
 
103
  :param num_clusters: (int) number of clusters
104
  :return: (np.array) initial state
105
  """
106
+ np.random.seed(1)
107
  num_samples = len(X)
108
  indices = np.random.choice(num_samples, num_clusters, replace=False)
109
  initial_state = X[indices]