TuringsSolutions commited on
Commit
12a587e
·
verified ·
1 Parent(s): 7612874

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import gradio as gr
3
  import numpy as np
4
  import tensorflow as tf
@@ -10,9 +11,6 @@ from skimage.transform import resize
10
  from PIL import Image, ImageEnhance, ImageFilter
11
  from tqdm import tqdm
12
 
13
- # Ensure TensorFlow runs on CPU
14
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
15
-
16
  class SwarmAgent:
17
  def __init__(self, position, velocity):
18
  self.position = position
@@ -85,6 +83,7 @@ class SwarmNeuralNetwork:
85
 
86
  return np.array(losses)
87
 
 
88
  def update_agents(self, timestep):
89
  noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]
90
 
@@ -101,6 +100,7 @@ class SwarmNeuralNetwork:
101
  # Clip values
102
  agent.position = np.clip(agent.position, -1, 1)
103
 
 
104
  def generate_image(self):
105
  self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
106
  # Normalize to [0, 1] range for display
@@ -156,6 +156,7 @@ class SwarmNeuralNetwork:
156
  self.generated_image = model_state['generated_image']
157
  self.current_epoch = model_state['current_epoch']
158
 
 
159
  def generate_new_image(self, num_steps=1000):
160
  for agent in self.agents:
161
  agent.position = np.random.randn(*self.image_shape)
@@ -166,6 +167,7 @@ class SwarmNeuralNetwork:
166
  self.generate_image()
167
  return self.generated_image
168
 
 
169
  def apply_super_resolution(self, image):
170
  import cv2
171
  sr = cv2.dnn_superres.DnnSuperResImpl_create()
@@ -177,6 +179,7 @@ class SwarmNeuralNetwork:
177
  return upscaled / 255.0
178
 
179
  # Gradio Interface
 
180
  def train_snn(image, num_agents, epochs, brightness, contrast, color):
181
  snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(256, 256, 3), target_image=image)
182
 
@@ -192,6 +195,7 @@ def train_snn(image, num_agents, epochs, brightness, contrast, color):
192
  upscaled_image = snn.apply_super_resolution(generated_image)
193
  return upscaled_image
194
 
 
195
  def generate_new_image():
196
  snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(256, 256, 3), target_image=None)
197
  snn.load_model('snn_model.npy')
 
1
  import os
2
+ import spaces
3
  import gradio as gr
4
  import numpy as np
5
  import tensorflow as tf
 
11
  from PIL import Image, ImageEnhance, ImageFilter
12
  from tqdm import tqdm
13
 
 
 
 
14
  class SwarmAgent:
15
  def __init__(self, position, velocity):
16
  self.position = position
 
83
 
84
  return np.array(losses)
85
 
86
+ @spaces.GPU(duration=120)
87
  def update_agents(self, timestep):
88
  noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]
89
 
 
100
  # Clip values
101
  agent.position = np.clip(agent.position, -1, 1)
102
 
103
+ @spaces.GPU(duration=120)
104
  def generate_image(self):
105
  self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
106
  # Normalize to [0, 1] range for display
 
156
  self.generated_image = model_state['generated_image']
157
  self.current_epoch = model_state['current_epoch']
158
 
159
+ @spaces.GPU(duration=120)
160
  def generate_new_image(self, num_steps=1000):
161
  for agent in self.agents:
162
  agent.position = np.random.randn(*self.image_shape)
 
167
  self.generate_image()
168
  return self.generated_image
169
 
170
+ @spaces.GPU(duration=60)
171
  def apply_super_resolution(self, image):
172
  import cv2
173
  sr = cv2.dnn_superres.DnnSuperResImpl_create()
 
179
  return upscaled / 255.0
180
 
181
  # Gradio Interface
182
+ @spaces.GPU(duration=120)
183
  def train_snn(image, num_agents, epochs, brightness, contrast, color):
184
  snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(256, 256, 3), target_image=image)
185
 
 
195
  upscaled_image = snn.apply_super_resolution(generated_image)
196
  return upscaled_image
197
 
198
+ @spaces.GPU(duration=120)
199
  def generate_new_image():
200
  snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(256, 256, 3), target_image=None)
201
  snn.load_model('snn_model.npy')