Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import tensorflow as tf | |
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input | |
from keras.models import Model | |
import matplotlib.pyplot as plt | |
import logging | |
from skimage.transform import resize | |
from PIL import Image, ImageEnhance, ImageFilter | |
from tqdm import tqdm | |
class SwarmAgent:
    """One member of the swarm: a position, a velocity and two zeroed buffers.

    Attributes:
        position: Array-like state of the agent, stored as passed in.
        velocity: Array-like velocity of the agent, stored as passed in.
        m: Zero array with the shape and dtype of ``position``.
        v: Zero array with the shape and dtype of ``position``.
    """

    def __init__(self, position, velocity):
        self.position, self.velocity = position, velocity
        # Two independent zero buffers shaped like `position`; this class
        # only initialises them.
        self.m, self.v = (np.zeros_like(position) for _ in range(2))
class SwarmNeuralNetwork:
    """Swarm of image-shaped agents that is iteratively driven by a target image.

    Every agent holds a position array of shape ``image_shape``.
    ``update_agents`` modifies each position using the target image and a
    noise schedule, and ``generate_image`` averages all positions into the
    output image.

    Attributes:
        image_shape: Shape of every agent position, e.g. ``(128, 128, 3)``.
        resized_shape: Input shape used for the MobileNetV2 feature extractor.
        agents: List of ``SwarmAgent`` instances.
        target_image: Target as a float array in [-1, 1], or ``None`` when no
            target has been provided yet.
        generated_image: Latest output image (values in [0, 1] once
            ``generate_image`` has run).
        mobilenet: Truncated MobileNetV2 model used by the perceptual loss.
        current_epoch: Number of epochs completed by ``train``.
        noise_schedule: 1000 noise levels decreasing from 0.1 to 0.002.
    """

    def __init__(self, num_agents, image_shape, target_image):
        """Create the agents, load the target and build the feature model.

        Args:
            num_agents: Number of agents to create (cast to ``int``).
            image_shape: Shape of each agent position, ``(height, width, channels)``.
            target_image: PIL image used as the target, or ``None`` to set
                the target later (for example through ``load_model``).
        """
        self.image_shape = image_shape
        self.resized_shape = (128, 128, 3)  # Increased resolution
        # UI sliders may deliver floats, and range() only accepts integers.
        self.agents = [
            SwarmAgent(self.random_position(), self.random_velocity())
            for _ in range(int(num_agents))
        ]
        self.target_image = self.load_target_image(target_image)
        self.generated_image = np.random.randn(*image_shape)  # Start with noise
        self.mobilenet = self.load_mobilenet_model()
        self.current_epoch = 0
        self.noise_schedule = np.linspace(0.1, 0.002, 1000)  # Noise schedule

    def _require_target(self):
        """Return the target image, raising ``ValueError`` when none is set."""
        if self.target_image is None:
            raise ValueError(
                "No target image is set; pass one to the constructor or "
                "restore one with load_model()."
            )
        return self.target_image

    def random_position(self):
        """Return Gaussian noise of shape ``image_shape``."""
        return np.random.randn(*self.image_shape)  # Use Gaussian noise

    def random_velocity(self):
        """Return Gaussian noise of shape ``image_shape`` scaled by 0.01."""
        return np.random.randn(*self.image_shape) * 0.01

    def load_target_image(self, img):
        """Resize a PIL image to ``image_shape`` and normalise it to [-1, 1].

        The converted target is also plotted with matplotlib.

        Args:
            img: PIL image, or ``None``.

        Returns:
            Float array in [-1, 1], or ``None`` when ``img`` is ``None``.
        """
        if img is None:
            return None
        height, width = self.image_shape[0], self.image_shape[1]
        if len(self.image_shape) == 3 and self.image_shape[2] == 3:
            # RGBA / greyscale / palette inputs would otherwise produce an
            # array whose channel count does not match image_shape.
            img = img.convert("RGB")
        img = img.resize((width, height))
        img_array = np.array(img) / 127.5 - 1  # Normalize to [-1, 1]
        plt.imshow((img_array + 1) / 2)  # Convert back to [0, 1] for display
        plt.title('Target Image')
        plt.show()
        # Non-interactive backends keep the figure alive after show().
        plt.close()
        return img_array

    def resize_image(self, image):
        """Resize ``image`` to ``resized_shape`` with anti-aliasing."""
        return resize(image, self.resized_shape, anti_aliasing=True)

    def load_mobilenet_model(self):
        """Build MobileNetV2 (ImageNet weights) cut at ``block_13_expand_relu``."""
        mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=self.resized_shape)
        return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)

    def add_positional_encoding(self, image):
        """Add normalised row/column coordinates to the first two channels.

        Channel 0 receives ``row / height`` and channel 1 receives
        ``column / width``; any further channels are left unchanged.

        Args:
            image: Array of shape ``(height, width, channels)`` with at least
                two channels.

        Returns:
            ``image`` plus the positional encoding.
        """
        h, w, _ = image.shape
        pos_enc = np.zeros_like(image)
        # Broadcasting replaces the per-pixel Python loop.
        pos_enc[..., 0] = (np.arange(h) / h)[:, np.newaxis]
        pos_enc[..., 1] = (np.arange(w) / w)[np.newaxis, :]
        return image + pos_enc

    def multi_head_attention(self, agent, num_heads=4):
        """Return a normalised similarity map between an agent and the target.

        Args:
            agent: Object with a ``position`` array shaped like the target.
            num_heads: Number of scores that are averaged.

        Returns:
            Array shaped like the target with the last axis reduced to 1;
            the per-pixel scores sum to 1.
        """
        target = self._require_target()
        attention_scores = []
        # Every head evaluates the same expression; there is no per-head
        # projection, so the mean equals a single head's score.
        for _ in range(num_heads):
            similarity = np.exp(-np.sum((agent.position - target)**2, axis=-1))
            attention_score = similarity / np.sum(similarity)
            attention_scores.append(attention_score)
        attention = np.mean(attention_scores, axis=0)
        return np.expand_dims(attention, axis=-1)

    def multi_scale_perceptual_loss(self, agent_positions):
        """Score agent positions by feature similarity to the target.

        Args:
            agent_positions: Iterable of arrays in [-1, 1].

        Returns:
            Array with one value ``1 / (1 + mse)`` per position, where
            ``mse`` is the mean squared difference between the model
            features of the position and those of the target.
        """
        target = self._require_target()
        target_image_resized = self.resize_image((target + 1) / 2)  # Convert to [0, 1] for MobileNet
        target_image_preprocessed = preprocess_input(target_image_resized[np.newaxis, ...] * 255)  # MobileNet expects [0, 255]
        target_features = self.mobilenet.predict(target_image_preprocessed)
        losses = []
        for agent_position in agent_positions:
            agent_image_resized = self.resize_image((agent_position + 1) / 2)
            agent_image_preprocessed = preprocess_input(agent_image_resized[np.newaxis, ...] * 255)
            agent_features = self.mobilenet.predict(agent_image_preprocessed)
            loss = np.mean((target_features - agent_features)**2)
            losses.append(1 / (1 + loss))
        return np.array(losses)

    def update_agents(self, timestep):
        """Apply one update step to every agent position.

        Args:
            timestep: Index into ``noise_schedule``; values past the end of
                the schedule use its last entry.

        Raises:
            ValueError: If no target image is set.
        """
        target = self._require_target()
        noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]
        noise_scale = np.sqrt(noise_level)  # identical for every agent
        for agent in self.agents:
            # Predict noise
            predicted_noise = agent.position - target
            # Denoise
            # NOTE(review): algebraically this equals
            # position + noise_level / (1 - noise_level) * target; confirm
            # this is the intended update.
            denoised = (agent.position - noise_level * predicted_noise) / (1 - noise_level)
            # Add scaled noise for next step
            noisy = denoised + np.random.randn(*self.image_shape) * noise_scale
            # Clip values
            agent.position = np.clip(noisy, -1, 1)

    def generate_image(self):
        """Average all agent positions into ``generated_image`` and sharpen it.

        The result is stored in ``self.generated_image`` as a float array in
        [0, 1].

        Raises:
            ValueError: If there are no agents.
        """
        if not self.agents:
            raise ValueError("Cannot generate an image without agents.")
        # A running sum avoids stacking every position into one
        # (num_agents, ...) temporary array.
        total = np.zeros_like(self.agents[0].position, dtype=float)
        for agent in self.agents:
            total += agent.position
        mean_position = total / len(self.agents)
        # Normalize to [0, 1] range for display
        normalised = np.clip((mean_position + 1) / 2, 0, 1)
        # Apply sharpening filter
        image_pil = Image.fromarray((normalised * 255).astype(np.uint8))
        image_pil = image_pil.filter(ImageFilter.SHARPEN)
        image_pil = image_pil.filter(ImageFilter.DETAIL)
        self.generated_image = np.array(image_pil) / 255.0

    def train(self, epochs):
        """Run ``epochs`` update steps, logging the MSE against the target.

        The MSE is written to ``training.log`` every epoch; every tenth
        epoch it is also printed and the current image is displayed.

        Args:
            epochs: Number of epochs to run (cast to ``int``).

        Raises:
            ValueError: If no target image is set.
        """
        target = self._require_target()
        logging.basicConfig(filename='training.log', level=logging.INFO)
        for epoch in tqdm(range(int(epochs)), desc="Training Epochs"):
            self.update_agents(epoch)
            self.generate_image()
            mse = np.mean(((self.generated_image * 2 - 1) - target)**2)
            logging.info("Epoch %s, MSE: %s", epoch, mse)
            if epoch % 10 == 0:
                print(f"Epoch {epoch}, MSE: {mse}")
                self.display_image(self.generated_image, title=f'Epoch {epoch}')
            self.current_epoch += 1

    def display_image(self, image, title=''):
        """Show ``image`` with matplotlib, without axes, under ``title``."""
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.show()
        # Release the figure so repeated calls do not accumulate artists.
        plt.close()

    def display_agent_positions(self, epoch):
        """Overlay a scatter of agent position values on the generated image."""
        fig, ax = plt.subplots()
        positions = np.array([agent.position for agent in self.agents])
        ax.imshow(self.generated_image, extent=[0, self.image_shape[1], 0, self.image_shape[0]])
        # NOTE(review): with stacked (agents, H, W, C) positions this scatters
        # the values at index 0 of the third axis against those at index 1,
        # not pixel coordinates; confirm the intent.
        ax.scatter(positions[:, :, 0].flatten(), positions[:, :, 1].flatten(), s=1, c='red')
        plt.title(f'Agent Positions at Epoch {epoch}')
        plt.show()
        plt.close(fig)

    def save_model(self, filename):
        """Save agents, images and epoch counter to ``filename`` via ``np.save``.

        The target image is stored as well so that a restored model can
        keep running ``update_agents``.
        """
        model_state = {
            'agents': self.agents,
            'generated_image': self.generated_image,
            'current_epoch': self.current_epoch,
            'target_image': self.target_image,
        }
        np.save(filename, model_state)

    def load_model(self, filename):
        """Restore the state written by ``save_model``.

        Files without a stored target image leave the current
        ``target_image`` untouched.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # files produced by save_model() from a trusted location.
        model_state = np.load(filename, allow_pickle=True).item()
        self.agents = model_state['agents']
        self.generated_image = model_state['generated_image']
        self.current_epoch = model_state['current_epoch']
        saved_target = model_state.get('target_image')
        if saved_target is not None:
            self.target_image = saved_target

    def generate_new_image(self, num_steps=1000):
        """Re-seed all agents with noise and run the schedule in reverse.

        Args:
            num_steps: Number of update steps; step ``i`` uses timestep
                ``num_steps - i - 1``.

        Returns:
            The generated image as a float array in [0, 1].

        Raises:
            ValueError: If no target image is set.
        """
        self._require_target()
        for agent in self.agents:
            agent.position = np.random.randn(*self.image_shape)
        for step in tqdm(range(num_steps), desc="Generating Image"):
            self.update_agents(num_steps - step - 1)  # Reverse order
        self.generate_image()
        return self.generated_image
# Gradio Interface | |
def train_snn(image, num_agents, epochs, brightness, contrast, color):
    """Train a swarm network on an uploaded image and return the result.

    The brightness, contrast and colour adjustments are applied to the
    image first, and the adjusted image is used as the training target.
    The trained state is saved to ``snn_model.npy``.

    Args:
        image: PIL image uploaded by the user.
        num_agents: Number of agents (cast to ``int``).
        epochs: Number of training epochs (cast to ``int``).
        brightness: Enhancement factor passed to ``ImageEnhance.Brightness``.
        contrast: Enhancement factor passed to ``ImageEnhance.Contrast``.
        color: Enhancement factor passed to ``ImageEnhance.Color``.

    Returns:
        The generated image held by the trained network.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("A target image is required.")
    # Apply user-specified adjustments to the target image before building
    # the network, so the target is loaded (and plotted) only once.
    image = ImageEnhance.Brightness(image).enhance(brightness)
    image = ImageEnhance.Contrast(image).enhance(contrast)
    image = ImageEnhance.Color(image).enhance(color)
    # Slider values may arrive as floats; agent and epoch counts must be ints.
    snn = SwarmNeuralNetwork(num_agents=int(num_agents), image_shape=(128, 128, 3), target_image=image)
    snn.train(epochs=int(epochs))
    snn.save_model('snn_model.npy')
    return snn.generated_image
def generate_new_image():
    """Generate an image from the state saved in ``snn_model.npy``.

    Returns:
        The image produced by ``SwarmNeuralNetwork.generate_new_image``.
    """
    # The constructor resizes its target image, so it cannot be given None;
    # a blank placeholder keeps construction valid.
    # NOTE(review): if the saved state does not carry a target image, the
    # placeholder remains the target used during generation.
    placeholder = Image.new("RGB", (128, 128))
    # One agent is enough here: load_model() replaces the agent list, so a
    # large initial swarm would only be allocated and then discarded.
    snn = SwarmNeuralNetwork(num_agents=1, image_shape=(128, 128, 3), target_image=placeholder)
    snn.load_model('snn_model.npy')
    return snn.generate_new_image()
# Gradio UI: one image upload plus sliders feeding train_snn's arguments in order.
interface = gr.Interface(
    fn=train_snn,
    inputs=[
        gr.Image(type="pil", label="Upload Target Image"),
        # step=1 keeps the two count sliders on whole numbers.
        gr.Slider(minimum=500, maximum=3000, value=2000, step=1, label="Number of Agents"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Number of Epochs"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Brightness"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Contrast"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Color Balance"),
    ],
    outputs=gr.Image(type="numpy", label="Generated Image"),
    title="Swarm Neural Network Image Generation",
    description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image. Adjust brightness, contrast, and color balance for personalization."
)
interface.launch()