TuringsSolutions commited on
Commit
1f52ded
·
verified ·
1 Parent(s): 12a587e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -27
app.py CHANGED
@@ -11,6 +11,9 @@ from skimage.transform import resize
11
  from PIL import Image, ImageEnhance, ImageFilter
12
  from tqdm import tqdm
13
 
 
 
 
14
  class SwarmAgent:
15
  def __init__(self, position, velocity):
16
  self.position = position
@@ -21,7 +24,7 @@ class SwarmAgent:
21
  class SwarmNeuralNetwork:
22
  def __init__(self, num_agents, image_shape, target_image):
23
  self.image_shape = image_shape
24
- self.resized_shape = (256, 256, 3) # Increased resolution
25
  self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
26
  self.target_image = self.load_target_image(target_image)
27
  self.generated_image = np.random.randn(*image_shape) # Start with noise
@@ -83,7 +86,7 @@ class SwarmNeuralNetwork:
83
 
84
  return np.array(losses)
85
 
86
- @spaces.GPU(duration=120)
87
  def update_agents(self, timestep):
88
  noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]
89
 
@@ -100,7 +103,7 @@ class SwarmNeuralNetwork:
100
  # Clip values
101
  agent.position = np.clip(agent.position, -1, 1)
102
 
103
- @spaces.GPU(duration=120)
104
  def generate_image(self):
105
  self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
106
  # Normalize to [0, 1] range for display
@@ -110,7 +113,6 @@ class SwarmNeuralNetwork:
110
  # Apply sharpening filter
111
  image_pil = Image.fromarray((self.generated_image * 255).astype(np.uint8))
112
  image_pil = image_pil.filter(ImageFilter.SHARPEN)
113
- image_pil = image_pil.filter(ImageFilter.DETAIL)
114
  self.generated_image = np.array(image_pil) / 255.0
115
 
116
  def train(self, epochs):
@@ -123,7 +125,7 @@ class SwarmNeuralNetwork:
123
  mse = np.mean(((self.generated_image * 2 - 1) - self.target_image)**2)
124
  logging.info(f"Epoch {epoch}, MSE: {mse}")
125
 
126
- if epoch % 10 == 0:
127
  print(f"Epoch {epoch}, MSE: {mse}")
128
  self.display_image(self.generated_image, title=f'Epoch {epoch}')
129
  self.current_epoch += 1
@@ -156,8 +158,8 @@ class SwarmNeuralNetwork:
156
  self.generated_image = model_state['generated_image']
157
  self.current_epoch = model_state['current_epoch']
158
 
159
- @spaces.GPU(duration=120)
160
- def generate_new_image(self, num_steps=1000):
161
  for agent in self.agents:
162
  agent.position = np.random.randn(*self.image_shape)
163
 
@@ -167,21 +169,10 @@ class SwarmNeuralNetwork:
167
  self.generate_image()
168
  return self.generated_image
169
 
170
- @spaces.GPU(duration=60)
171
- def apply_super_resolution(self, image):
172
- import cv2
173
- sr = cv2.dnn_superres.DnnSuperResImpl_create()
174
- path = "EDSR_x3.pb" # Path to a pre-trained super-resolution model file (download it from OpenCV's repository)
175
- sr.readModel(path)
176
- sr.setModel("edsr", 3) # Use EDSR model with a scale factor of 3
177
- image = (image * 255).astype(np.uint8)
178
- upscaled = sr.upsample(image)
179
- return upscaled / 255.0
180
-
181
  # Gradio Interface
182
  @spaces.GPU(duration=120)
183
  def train_snn(image, num_agents, epochs, brightness, contrast, color):
184
- snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(256, 256, 3), target_image=image)
185
 
186
  # Apply user-specified adjustments to the target image
187
  image = ImageEnhance.Brightness(image).enhance(brightness)
@@ -192,23 +183,21 @@ def train_snn(image, num_agents, epochs, brightness, contrast, color):
192
  snn.train(epochs=epochs)
193
  snn.save_model('snn_model.npy')
194
  generated_image = snn.generated_image
195
- upscaled_image = snn.apply_super_resolution(generated_image)
196
- return upscaled_image
197
 
198
  @spaces.GPU(duration=120)
199
  def generate_new_image():
200
- snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(256, 256, 3), target_image=None)
201
  snn.load_model('snn_model.npy')
202
  new_image = snn.generate_new_image()
203
- upscaled_image = snn.apply_super_resolution(new_image)
204
- return upscaled_image
205
 
206
  interface = gr.Interface(
207
  fn=train_snn,
208
  inputs=[
209
  gr.Image(type="pil", label="Upload Target Image"),
210
- gr.Slider(minimum=500, maximum=3000, value=2000, label="Number of Agents"),
211
- gr.Slider(minimum=10, maximum=200, value=100, label="Number of Epochs"),
212
  gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Brightness"),
213
  gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Contrast"),
214
  gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Color Balance")
@@ -218,4 +207,4 @@ interface = gr.Interface(
218
  description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image. Adjust brightness, contrast, and color balance for personalization."
219
  )
220
 
221
- interface.launch()
 
11
  from PIL import Image, ImageEnhance, ImageFilter
12
  from tqdm import tqdm
13
 
14
+ # Disable GPU usage by default
15
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
16
+
17
  class SwarmAgent:
18
  def __init__(self, position, velocity):
19
  self.position = position
 
24
  class SwarmNeuralNetwork:
25
  def __init__(self, num_agents, image_shape, target_image):
26
  self.image_shape = image_shape
27
+ self.resized_shape = (256, 256, 3) # High resolution
28
  self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
29
  self.target_image = self.load_target_image(target_image)
30
  self.generated_image = np.random.randn(*image_shape) # Start with noise
 
86
 
87
  return np.array(losses)
88
 
89
+ @spaces.GPU(duration=90)
90
  def update_agents(self, timestep):
91
  noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]
92
 
 
103
  # Clip values
104
  agent.position = np.clip(agent.position, -1, 1)
105
 
106
+ @spaces.GPU(duration=90)
107
  def generate_image(self):
108
  self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
109
  # Normalize to [0, 1] range for display
 
113
  # Apply sharpening filter
114
  image_pil = Image.fromarray((self.generated_image * 255).astype(np.uint8))
115
  image_pil = image_pil.filter(ImageFilter.SHARPEN)
 
116
  self.generated_image = np.array(image_pil) / 255.0
117
 
118
  def train(self, epochs):
 
125
  mse = np.mean(((self.generated_image * 2 - 1) - self.target_image)**2)
126
  logging.info(f"Epoch {epoch}, MSE: {mse}")
127
 
128
+ if epoch % 5 == 0:
129
  print(f"Epoch {epoch}, MSE: {mse}")
130
  self.display_image(self.generated_image, title=f'Epoch {epoch}')
131
  self.current_epoch += 1
 
158
  self.generated_image = model_state['generated_image']
159
  self.current_epoch = model_state['current_epoch']
160
 
161
+ @spaces.GPU(duration=90)
162
+ def generate_new_image(self, num_steps=500): # Optimized number of steps
163
  for agent in self.agents:
164
  agent.position = np.random.randn(*self.image_shape)
165
 
 
169
  self.generate_image()
170
  return self.generated_image
171
 
 
 
 
 
 
 
 
 
 
 
 
172
  # Gradio Interface
173
  @spaces.GPU(duration=120)
174
  def train_snn(image, num_agents, epochs, brightness, contrast, color):
175
+ snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(256, 256, 3), target_image=image) # High resolution
176
 
177
  # Apply user-specified adjustments to the target image
178
  image = ImageEnhance.Brightness(image).enhance(brightness)
 
183
  snn.train(epochs=epochs)
184
  snn.save_model('snn_model.npy')
185
  generated_image = snn.generated_image
186
+ return generated_image
 
187
 
188
  @spaces.GPU(duration=120)
189
  def generate_new_image():
190
+ snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(256, 256, 3), target_image=None) # High resolution and optimized number of agents
191
  snn.load_model('snn_model.npy')
192
  new_image = snn.generate_new_image()
193
+ return new_image
 
194
 
195
  interface = gr.Interface(
196
  fn=train_snn,
197
  inputs=[
198
  gr.Image(type="pil", label="Upload Target Image"),
199
+ gr.Slider(minimum=500, maximum=2000, value=1000, label="Number of Agents"), # Adjusted range for number of agents
200
+ gr.Slider(minimum=10, maximum=100, value=50, label="Number of Epochs"), # Adjusted range for number of epochs
201
  gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Brightness"),
202
  gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Contrast"),
203
  gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Color Balance")
 
207
  description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image. Adjust brightness, contrast, and color balance for personalization."
208
  )
209
 
210
+ interface.launch()