TuringsSolutions committed
Commit 3cb8cd6 · verified · 1 Parent(s): 2bfc99d

Create app.py

Files changed (1)
  app.py +198 -0
app.py ADDED
@@ -0,0 +1,198 @@
import os
import gradio as gr
import numpy as np
import tensorflow as tf
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from keras.models import Model
import matplotlib.pyplot as plt
import logging
from skimage.transform import resize
from PIL import Image, ImageEnhance, ImageFilter
from tqdm import tqdm

# Disable GPU usage by default
os.environ['CUDA_VISIBLE_DEVICES'] = ''


class SwarmAgent:
    def __init__(self, position, velocity):
        self.position = position
        self.velocity = velocity
        # First/second moment buffers (Adam-style); not used by the current update rule
        self.m = np.zeros_like(position)
        self.v = np.zeros_like(position)


class SwarmNeuralNetwork:
    def __init__(self, num_agents, image_shape, target_image):
        self.image_shape = image_shape
        self.resized_shape = (256, 256, 3)  # Resolution fed to the MobileNetV2 feature extractor
        self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
        # A target may be absent (e.g. when only generating from a saved model);
        # fall back to a zero image so the update rule still has a reference.
        self.target_image = self.load_target_image(target_image) if target_image is not None else np.zeros(image_shape)
        self.generated_image = np.random.randn(*image_shape)  # Start with noise
        self.mobilenet = self.load_mobilenet_model()
        self.current_epoch = 0
        self.noise_schedule = np.linspace(0.1, 0.002, 1000)  # Linearly decaying noise schedule

    def random_position(self):
        return np.random.randn(*self.image_shape)  # Gaussian noise

    def random_velocity(self):
        return np.random.randn(*self.image_shape) * 0.01

    def load_target_image(self, img_path):
        # Accept either a file path or an already-opened PIL image
        # (train_snn below passes an enhanced PIL image, not a path).
        img = img_path if isinstance(img_path, Image.Image) else Image.open(img_path)
        img = img.convert('RGB').resize((self.image_shape[1], self.image_shape[0]))
        img_array = np.array(img) / 127.5 - 1  # Normalize to [-1, 1]
        plt.imshow((img_array + 1) / 2)  # Convert back to [0, 1] for display
        plt.title('Target Image')
        plt.show()
        return img_array

    def resize_image(self, image):
        return resize(image, self.resized_shape, anti_aliasing=True)

    def load_mobilenet_model(self):
        mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=self.resized_shape)
        return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)

    def add_positional_encoding(self, image):
        h, w, c = image.shape
        pos_enc = np.zeros_like(image)
        for i in range(h):
            for j in range(w):
                pos_enc[i, j, :] = [i / h, j / w, 0]
        return image + pos_enc

    def multi_head_attention(self, agent, num_heads=4):
        attention_scores = []
        for _ in range(num_heads):
            similarity = np.exp(-np.sum((agent.position - self.target_image)**2, axis=-1))
            attention_score = similarity / np.sum(similarity)
            attention_scores.append(attention_score)
        attention = np.mean(attention_scores, axis=0)
        return np.expand_dims(attention, axis=-1)

    def multi_scale_perceptual_loss(self, agent_positions):
        target_image_resized = self.resize_image((self.target_image + 1) / 2)  # Convert to [0, 1] for MobileNet
        target_image_preprocessed = preprocess_input(target_image_resized[np.newaxis, ...] * 255)  # MobileNet expects [0, 255]
        target_features = self.mobilenet.predict(target_image_preprocessed)

        losses = []
        for agent_position in agent_positions:
            agent_image_resized = self.resize_image((agent_position + 1) / 2)
            agent_image_preprocessed = preprocess_input(agent_image_resized[np.newaxis, ...] * 255)
            agent_features = self.mobilenet.predict(agent_image_preprocessed)

            loss = np.mean((target_features - agent_features)**2)
            losses.append(1 / (1 + loss))

        return np.array(losses)

    def update_agents(self, timestep):
        noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]

        for agent in self.agents:
            # Treat the gap to the target as the "predicted noise"
            predicted_noise = agent.position - self.target_image

            # Denoise: remove a noise_level-sized fraction of that gap
            denoised = (agent.position - noise_level * predicted_noise) / (1 - noise_level)

            # Re-inject scaled Gaussian noise for the next step (diffusion-style update)
            agent.position = denoised + np.random.randn(*self.image_shape) * np.sqrt(noise_level)

            # Keep pixel values in the model's [-1, 1] range
            agent.position = np.clip(agent.position, -1, 1)

    def generate_image(self):
        # Average all agent positions into a single image
        self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
        # Normalize to [0, 1] range for display
        self.generated_image = (self.generated_image + 1) / 2
        self.generated_image = np.clip(self.generated_image, 0, 1)

        # Apply a sharpening filter
        image_pil = Image.fromarray((self.generated_image * 255).astype(np.uint8))
        image_pil = image_pil.filter(ImageFilter.SHARPEN)
        self.generated_image = np.array(image_pil) / 255.0

    def train(self, epochs):
        logging.basicConfig(filename='training.log', level=logging.INFO)

        for epoch in tqdm(range(epochs), desc="Training Epochs"):
            self.update_agents(epoch)
            self.generate_image()

            mse = np.mean(((self.generated_image * 2 - 1) - self.target_image)**2)
            logging.info(f"Epoch {epoch}, MSE: {mse}")

            if epoch % 5 == 0:
                print(f"Epoch {epoch}, MSE: {mse}")
                self.display_image(self.generated_image, title=f'Epoch {epoch}')
            self.current_epoch += 1

    def display_image(self, image, title=''):
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.show()

    def display_agent_positions(self, epoch):
        fig, ax = plt.subplots()
        positions = np.array([agent.position for agent in self.agents])
        ax.imshow(self.generated_image, extent=[0, self.image_shape[1], 0, self.image_shape[0]])
        ax.scatter(positions[:, :, 0].flatten(), positions[:, :, 1].flatten(), s=1, c='red')
        plt.title(f'Agent Positions at Epoch {epoch}')
        plt.show()

    def save_model(self, filename):
        model_state = {
            'agents': self.agents,
            'target_image': self.target_image,  # stored so a reloaded model can regenerate without the original file
            'generated_image': self.generated_image,
            'current_epoch': self.current_epoch
        }
        np.save(filename, model_state)

    def load_model(self, filename):
        model_state = np.load(filename, allow_pickle=True).item()
        self.agents = model_state['agents']
        self.target_image = model_state.get('target_image', self.target_image)
        self.generated_image = model_state['generated_image']
        self.current_epoch = model_state['current_epoch']

    def generate_new_image(self, num_steps=500):
        # Re-seed every agent with Gaussian noise, then run the denoising schedule in reverse
        for agent in self.agents:
            agent.position = np.random.randn(*self.image_shape)

        for step in tqdm(range(num_steps), desc="Generating Image"):
            self.update_agents(num_steps - step - 1)  # Reverse order

        self.generate_image()
        return self.generated_image

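# Hypothetical usage sketch: train the swarm directly, without the Gradio UI.
# The file name and the epoch count below are illustrative assumptions only;
# the other values mirror defaults that appear elsewhere in this file.
#
#   snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(256, 256, 3),
#                            target_image='target.jpg')
#   snn.train(epochs=10)
#   snn.save_model('snn_model.npy')
#   regenerated = snn.generate_new_image(num_steps=500)
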
# Gradio Interface
def train_snn(image_path, num_agents, epochs, arm_position, leg_position, brightness, contrast, color):
    # int() guards in case the UI passes slider values as floats
    snn = SwarmNeuralNetwork(num_agents=int(num_agents), image_shape=(256, 256, 3), target_image=image_path)

    # Apply user-specified adjustments to the target image
    image = Image.open(image_path)
    image = ImageEnhance.Brightness(image).enhance(brightness)
    image = ImageEnhance.Contrast(image).enhance(contrast)
    image = ImageEnhance.Color(image).enhance(color)

    # Mock adjustment for arm and leg positions (actual logic still to be implemented);
    # for now the values are only logged.
    print(f"Adjusting arm position: {arm_position}, leg position: {leg_position}")

    # load_target_image accepts the enhanced PIL image directly (see above)
    snn.target_image = snn.load_target_image(image)
    snn.train(epochs=int(epochs))
    snn.save_model('snn_model.npy')
    generated_image = snn.generated_image
    return generated_image

def generate_new_image():
    # target_image=None: the target stored in the checkpoint is restored by load_model
    snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(256, 256, 3), target_image=None)
    snn.load_model('snn_model.npy')
    new_image = snn.generate_new_image()
    return new_image

interface = gr.Interface(
    fn=train_snn,
    inputs=[
        gr.Image(type="filepath", label
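
The listing is cut off here, partway through the gr.Interface call. The sketch below shows one plausible way the remaining wiring could look, based only on the signature of train_snn; every label, range, and default value is an assumption, and how generate_new_image is exposed in the UI is not visible in this commit.

# Hedged completion sketch: labels, ranges, and defaults are assumptions.
interface = gr.Interface(
    fn=train_snn,
    inputs=[
        gr.Image(type="filepath", label="Target Image"),
        gr.Slider(100, 3000, value=2000, step=100, label="Number of Agents"),
        gr.Slider(1, 100, value=10, step=1, label="Epochs"),
        gr.Slider(-1.0, 1.0, value=0.0, label="Arm Position"),
        gr.Slider(-1.0, 1.0, value=0.0, label="Leg Position"),
        gr.Slider(0.5, 1.5, value=1.0, label="Brightness"),
        gr.Slider(0.5, 1.5, value=1.0, label="Contrast"),
        gr.Slider(0.5, 1.5, value=1.0, label="Color"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Swarm Neural Network Image Generator",
)

if __name__ == "__main__":
    interface.launch()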