## DOWNLOADING DATASETS

In [None]:
# Download the 01 image and label archives, extract both, then delete the zip
# files so they stop occupying disk space.
# NOTE(review): wget is given no output path while unzip/rm use absolute
# /kaggle/working/ paths, so this presumably assumes the notebook's working
# directory is /kaggle/working/ — confirm.
!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/01_images.zip
!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/01_labels.zip
!unzip /kaggle/working/01_images.zip
!unzip /kaggle/working/01_labels.zip
!rm /kaggle/working/01_images.zip
!rm /kaggle/working/01_labels.zip

In [None]:
# Same steps as the previous cell, for the 02 image and label archives:
# download, extract, then remove the zip files.
# NOTE(review): presumably assumes the working directory is /kaggle/working/,
# because wget is given no output path — confirm.
!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/02_images.zip
!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/02_labels.zip
!unzip /kaggle/working/02_images.zip
!unzip /kaggle/working/02_labels.zip
!rm /kaggle/working/02_images.zip
!rm /kaggle/working/02_labels.zip

## IN-MEMORY DOWNLOAD AND UNZIPPING: REDUCING SECONDARY STORAGE FOOTPRINT & OVERHEAD

In [None]:
import requests
import zipfile
import io

# URL of the zip file
url = 'https://download.visinf.tu-darmstadt.de/data/from_games/data/03_labels.zip'

# Stream the archive into RAM instead of writing the .zip to disk first.
# The context manager releases the HTTP connection once the body is consumed,
# and the (connect, read) timeout keeps a stalled server from hanging the cell.
with requests.get(url, stream=True, timeout=(10, 60)) as response:
    # Fail here on 4xx/5xx; otherwise an HTML error page would be handed to
    # ZipFile and surface later as a confusing BadZipFile error.
    response.raise_for_status()

    # In-memory bytes buffer that receives the whole archive
    zip_buffer = io.BytesIO()

    # 1 MiB chunks: far fewer Python-level iterations than 1 KiB chunks
    for chunk in response.iter_content(chunk_size=1024 * 1024):
        if chunk:
            zip_buffer.write(chunk)

# Rewind so ZipFile reads the archive from its first byte
zip_buffer.seek(0)

# Open the archive from memory and extract every member to the target directory
with zipfile.ZipFile(zip_buffer, 'r') as zip_ref:
    zip_ref.extractall('/kaggle/working/')

# No zip file to delete afterwards: the archive only ever existed in memory

In [None]:
import os

# os.walk yields (dirpath, dirnames, filenames); the first tuple describes the
# top-level directory, so index 2 is the list of files directly inside labels/.
files = next(os.walk("/kaggle/working/labels/"))[2]
print(len(files))

In [None]:
import torch
import sys
# Print a diagnostic report of the session: Python and PyTorch versions, the
# CUDA compiler version, the cuDNN version and the visible GPUs.
print('__Python VERSION:', sys.version)
print('__pyTorch VERSION:', torch.__version__)
print('__CUDA VERSION')
from subprocess import call
# call(["nvcc", "--version"]) does not work
# so nvcc is invoked through the IPython shell escape instead.
! nvcc --version
print('__CUDNN VERSION:', torch.backends.cudnn.version())
print('__Number CUDA Devices:', torch.cuda.device_count())
print('__Devices')
# Ask nvidia-smi for a CSV table with index, name, driver version and
# total/used/free memory of each GPU.
call(["nvidia-smi", "--format=csv", "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"])
# NOTE(review): torch.cuda.current_device() presumably raises on a session
# without a CUDA device — confirm before running this cell on CPU only.
print('Active CUDA Device: GPU', torch.cuda.current_device())
# The two lines below print the device count and current device a second time.
print ('Available devices ', torch.cuda.device_count())
print ('Current cuda device ', torch.cuda.current_device())

## DEFINING STYLE TRANSFER VISION TRANSFORMER ARCHITECTURE

In [143]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PatchEmbedding(nn.Module):
    """Embed an image as a sequence of non-overlapping patch tokens.

    A strided convolution (kernel size == stride == ``patch_size``) projects
    every patch to ``emb_dim`` channels; a learnable, zero-initialised
    positional table with one row per patch is then added.

    Args:
        in_channels: Number of channels of the input image.
        patch_size: Side length of each square patch.
        emb_dim: Size of every patch embedding.
        img_size: Side length of the (square) images the positional table is
            sized for; it holds ``(img_size // patch_size) ** 2`` rows.
    """

    def __init__(self, in_channels=3, patch_size=8, emb_dim=768, img_size=256):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_encoding = nn.Parameter(torch.zeros(1, num_patches, emb_dim))

    def forward(self, x):
        """Return patch embeddings of shape (batch_size, num_patches, emb_dim).

        Args:
            x: Image batch of shape (batch_size, channels, height, width).
                The number of patches it produces must equal the number of
                rows in the positional table, otherwise the addition fails.
        """
        x = self.proj(x)                  # (batch_size, emb_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (batch_size, num_patches, emb_dim)
        # Out-of-place add: the tensor above is a view of the convolution
        # output, so `+=` would silently overwrite that output in place.
        return x + self.positional_encoding

In [144]:
class TransformerEncoderBlock(nn.Module):
    """Style-conditioned encoder layer: self-attention followed by an MLP.

    Each of the two sub-layers is residual, followed by LayerNorm and then by
    AdaIN driven by the ``style`` vector. A single AdaIN module is shared by
    both sub-layers.
    """

    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=2048, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()
        self.attn = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)
        feed_forward = [
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
        ]
        self.ff = nn.Sequential(*feed_forward)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.adain = AdaIN(emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, style):
        """Return ``(output, intermediate)``.

        ``intermediate`` is the activation after the attention sub-layer and
        its AdaIN; this block only returns it (the model in this file feeds it
        to the decoder as a skip connection).
        """
        # Attention sub-layer: residual -> LayerNorm -> AdaIN
        attended = self.norm1(x + self.dropout(self.attn(x, x, x)))
        intermediate = self.adain(attended, style)

        # Feed-forward sub-layer: residual -> LayerNorm -> AdaIN
        transformed = self.norm2(intermediate + self.dropout(self.ff(intermediate)))
        return self.adain(transformed, style), intermediate
    
class TransformerDecoderBlock(nn.Module):
    """Style-conditioned decoder layer.

    Runs self-attention, then attention over ``enc_output``, then an MLP.
    The two attention sub-layers are each followed by LayerNorm and their own
    AdaIN; the MLP sub-layer is followed by LayerNorm only. Finally a
    layer-normalised copy of ``skip_connection`` is added to the result.
    """

    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=2048, dropout=0.1):
        super(TransformerDecoderBlock, self).__init__()
        self.attn1 = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)
        self.attn2 = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)
        feed_forward = [
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
        ]
        self.ff = nn.Sequential(*feed_forward)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.norm3 = nn.LayerNorm(emb_dim)
        self.norm4 = nn.LayerNorm(emb_dim)
        self.adain1 = AdaIN(emb_dim)
        self.adain2 = AdaIN(emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, skip_connection, style):
        """Decode ``x`` using ``enc_output`` as keys/values of the second attention."""
        # 1) self-attention, residual, LayerNorm, AdaIN
        h = self.norm1(x + self.dropout(self.attn1(x, x, x)))
        h = self.adain1(h, style)

        # 2) attention with queries from h and keys/values from enc_output
        h = self.norm2(h + self.dropout(self.attn2(h, enc_output, enc_output)))
        h = self.adain2(h, style)

        # 3) position-wise MLP, residual, LayerNorm
        h = self.norm3(h + self.dropout(self.ff(h)))

        # 4) add the normalised skip connection
        return h + self.norm4(skip_connection)

In [145]:
class LocationBasedMultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a learnable q/k offset.

    ``pos_enc`` has shape (1, 1, emb_dim), so the same learned vector is
    broadcast-added to queries and keys at every batch element and position
    before the projections. Values are projected without that offset.

    Args:
        emb_dim: Embedding size of queries, keys, values and output.
        num_heads: Number of attention heads; must divide ``emb_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_dim, num_heads, dropout=0.1):
        super(LocationBasedMultiheadAttention, self).__init__()
        if emb_dim % num_heads != 0:
            # Without this check the mismatch only shows up in forward() as
            # an opaque `view` size error.
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        self.q_proj = nn.Linear(emb_dim, emb_dim)
        self.k_proj = nn.Linear(emb_dim, emb_dim)
        self.v_proj = nn.Linear(emb_dim, emb_dim)
        self.out_proj = nn.Linear(emb_dim, emb_dim)

        self.dropout = nn.Dropout(dropout)

        # Learnable offset shared by all positions (see class docstring)
        self.pos_enc = nn.Parameter(torch.randn(1, 1, emb_dim))

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Queries of shape (batch_size, query_len, emb_dim).
            k: Keys of shape (batch_size, key_len, emb_dim).
            v: Values of shape (batch_size, key_len, emb_dim).

        Returns:
            Tensor of shape (batch_size, query_len, emb_dim).
        """
        batch_size, query_len, _ = q.shape
        # Keys/values may have a different length than the queries
        # (cross-attention), so they are reshaped with their own length.
        key_len = k.shape[1]

        # Add the learned offset to queries and keys
        q = q + self.pos_enc
        k = k + self.pos_enc

        # Project and split into heads: (batch_size, num_heads, len, head_dim)
        q = self.q_proj(q).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention weights
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        # Weighted sum of values, then merge the heads back together
        out = torch.matmul(attn_probs, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, query_len, self.emb_dim)
        out = self.out_proj(out)

        return out

In [146]:
class AdaIN(nn.Module):
    """Adaptive instance normalisation for token sequences.

    The token tensor is instance-normalised (no learned affine), then scaled
    and shifted by ``gamma`` and ``beta`` that a linear layer predicts from
    the style vector.
    """

    def __init__(self, emb_dim):
        super(AdaIN, self).__init__()
        self.norm = nn.InstanceNorm1d(emb_dim, affine=False)
        self.fc = nn.Linear(emb_dim, emb_dim * 2)

    def forward(self, x, style):
        """Return ``gamma * instance_norm(x) + beta``.

        ``x`` is transposed on dims 1 and 2 around the normalisation, and the
        predicted parameters get a singleton dim 1 so they broadcast over it.
        """
        # One linear layer predicts scale and shift together; split in half
        modulation = torch.unsqueeze(self.fc(style), 1)
        gamma, beta = torch.chunk(modulation, 2, dim=-1)

        # InstanceNorm1d treats dim 1 as channels, hence the transposes
        channels_first = torch.transpose(x, 1, 2)
        normalised = torch.transpose(self.norm(channels_first), 1, 2)

        return gamma * normalised + beta

In [147]:
class RefinementBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU head.

    With the default arguments it maps 768 feature channels to 3 channels
    using a 3x3 convolution with stride 1 and padding 1. Because the last
    operation is a ReLU, the block never outputs negative values.
    """

    def __init__(self, in_channels=768, out_channels=3, kernel_size=3, stride=1, padding=1):
        super(RefinementBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))

In [148]:
# !pip install einops

In [149]:
import torch.nn.functional as F
from einops import rearrange

# Swin Transformer Block
class SwinTransformerBlock(nn.Module):
    """Window attention block with an optional cyclic shift.

    The (B, H, W, C) feature map is cyclically shifted along its two spatial
    axes, cut into non-overlapping ``window_size`` x ``window_size`` windows,
    processed per window by pre-norm multi-head self-attention and an MLP
    (both residual), merged back and shifted back.

    The shift is applied on the spatial axes *before* partitioning and undone
    after merging. Applying it after partitioning would roll the
    tokens-in-window and channel axes instead, and would never be reversed.

    Args:
        dim: Channel size C of the feature map.
        num_heads: Number of attention heads.
        window_size: Side length of each square attention window.
        shift_size: Cyclic shift in positions along H and W; 0 disables it.

    NOTE(review): this block does not build an attention mask for the
    wrapped-around regions of shifted windows; callers that need one must pass
    it through ``mask``.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=2):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size

        # Layer normalization
        self.norm1 = nn.LayerNorm(dim)

        # Multi-head self-attention applied independently inside each window
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

        # Feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim)
        )

        # Another layer normalization
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        """Process a (B, H, W, C) feature map and return the same shape.

        Args:
            x: Feature map of shape (B, H, W, C); H and W must be multiples
                of ``window_size``.
            mask: Optional mask forwarded as ``attn_mask`` to the attention.

        Raises:
            ValueError: If H or W is not divisible by ``window_size``.
        """
        B, H, W, C = x.shape
        ws = self.window_size
        if H % ws != 0 or W % ws != 0:
            raise ValueError(
                f"Feature map size ({H}, {W}) must be divisible by window_size ({ws})"
            )
        rows, cols = H // ws, W // ws

        # Cyclic shift of the feature map along its spatial axes
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

        # Partition: (B, H, W, C) -> (B * rows * cols, ws * ws, C)
        x = x.reshape(B, rows, ws, cols, ws, C).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B * rows * cols, ws * ws, C)

        # Pre-norm self-attention inside each window, with residual
        shortcut = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x, attn_mask=mask)
        x = shortcut + x

        # Pre-norm MLP, with residual
        shortcut = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = shortcut + x

        # Merge windows: (B * rows * cols, ws * ws, C) -> (B, H, W, C)
        x = x.reshape(B, rows, cols, ws, ws, C).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, H, W, C)

        # Undo the cyclic shift so positions line up with the input again
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        return x

In [150]:
class ViTImage2Image(nn.Module):
    """Style-conditioned image-to-image transformer.

    Pipeline of ``forward``: patch embedding + learned positions ->
    ``num_layers`` encoder blocks -> ``num_layers`` decoder blocks (with skip
    connections from the encoder, deepest first) -> ``num_layers`` window
    attention blocks -> LayerNorm + MLP -> convolutional refinement to 3
    channels -> bilinear resize to ``img_size`` x ``img_size``.

    Args:
        img_size: Side length of the square input and output images.
        patch_size: Side length of each patch.
        emb_dim: Token embedding size.
        num_heads: Attention heads in every transformer block.
        num_layers: Number of encoder, decoder and window-attention blocks.
        hidden_dim: Hidden size of the MLPs.
        window_size: Window side length of the window-attention blocks; must
            divide ``img_size // patch_size``.

    Raises:
        ValueError: If the patch grid is not divisible by ``window_size``.

    NOTE(review): the refinement block in this file ends with a ReLU, so the
    output is non-negative, while the training transform normalises targets
    with mean 0.5 / std 0.5 — confirm the intended output range.
    """

    def __init__(self, img_size=512, patch_size=8, emb_dim=768, num_heads=12, num_layers=12, hidden_dim=3072, window_size=8):
        super(ViTImage2Image, self).__init__()
        grid_size = img_size // patch_size
        if grid_size % window_size != 0:
            # Fail at construction time: such a configuration can never
            # complete a forward pass through the window-attention blocks.
            raise ValueError(
                f"img_size // patch_size ({grid_size}) must be divisible by "
                f"window_size ({window_size})"
            )

        self.img_size = img_size
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.window_size = window_size
        self.num_patches = grid_size ** 2

        self.patch_embed = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, emb_dim))

        self.encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(emb_dim, num_heads, hidden_dim, dropout=0.1)
            for _ in range(num_layers)
        ])

        self.decoder_layers = nn.ModuleList([
            TransformerDecoderBlock(emb_dim, num_heads, hidden_dim, dropout=0.1)
            for _ in range(num_layers)
        ])

        self.swin_layers = nn.ModuleList([
            SwinTransformerBlock(emb_dim, num_heads, window_size=window_size)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(emb_dim)
        self.mlp_head = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, emb_dim)
        )
        self.refinement = RefinementBlock(in_channels=emb_dim, out_channels=3, kernel_size=3, stride=1, padding=1)

        # Reduces the style image to one emb_dim-sized vector per sample
        self.style_encoder = nn.Sequential(
            nn.Conv2d(3, emb_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(emb_dim, emb_dim)
        )

    def forward(self, content, style):
        """Translate ``content`` conditioned on ``style``.

        Args:
            content: Image batch of shape (B, 3, img_size, img_size).
            style: 3-channel image batch with the same batch size; it is
                pooled to a single style vector per sample.

        Returns:
            Tensor of shape (B, 3, img_size, img_size).
        """
        x = self.patch_embed(content)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        x = x + self.pos_embed

        style_features = self.style_encoder(style)

        # Transformer encoder; keep each layer's intermediate activation
        skip_connections = []
        for encoder in self.encoder_layers:
            x, skip = encoder(x, style_features)
            skip_connections.append(skip)

        # Every decoder layer cross-attends to the final encoder output.
        # Passing the running `x` instead would make the second attention of
        # each layer attend to the decoder's own input rather than the
        # encoder output.
        enc_output = x

        # Transformer decoder, consuming the skips deepest-first
        for decoder, skip in zip(self.decoder_layers, reversed(skip_connections)):
            x = decoder(x, enc_output, skip, style_features)

        # Window attention operates on the 2-D token grid
        x = x.reshape(B, H, W, C)
        for layer in self.swin_layers:
            x = layer(x)

        # reshape (rather than view) does not depend on the memory layout
        # produced by the previous step
        x = x.reshape(B, -1, C)
        x = self.norm(x)
        x = self.mlp_head(x)
        x = x.transpose(1, 2).reshape(B, self.emb_dim, H, W)
        x = self.refinement(x)
        x = nn.functional.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        return x

## DEFINE LOSS METRICS

In [151]:
def total_variation_loss(x):
    """Return the anisotropic total variation of a 4-D tensor.

    Sums the absolute differences between neighbouring elements along dim 2
    and along dim 3. The result is a sum, not a mean, so it grows with the
    tensor size.
    """
    along_dim2 = (x[:, :, :-1, :] - x[:, :, 1:, :]).abs().sum()
    along_dim3 = (x[:, :, :, :-1] - x[:, :, :, 1:]).abs().sum()
    return along_dim2 + along_dim3

def combined_loss(output, target):
    """Return mean L1 distance to ``target`` plus 0.0001 x total variation of ``output``."""
    tv_weight = 0.0001
    # Mean absolute error, the same computation nn.L1Loss() performs by default
    reconstruction = nn.functional.l1_loss(output, target)
    smoothness = total_variation_loss(output)
    return reconstruction + tv_weight * smoothness

## DEFINE TRAINING METRICS

In [152]:
def psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio between two tensors, in dB.

    Args:
        img1: First tensor.
        img2: Second tensor, broadcastable against ``img1``.
        max_val: Peak signal value used in the ratio. The default of 1.0
            keeps the previous behaviour; pass the actual dynamic range of
            the data when it is not 1.

    Returns:
        A 0-dim tensor. Identical inputs give ``inf`` as a tensor, so callers
        can always use ``.item()`` on the result.
    """
    mse = torch.mean((img1 - img2) ** 2)
    # No special case for mse == 0: sqrt(0) = 0 and max_val / 0 = inf in
    # floating point, so the result is already +inf and stays a tensor.
    # This also avoids a Python-level comparison on the tensor.
    return 20 * torch.log10(max_val / torch.sqrt(mse))

## DATALOADERS

In [181]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Custom Dataset class
class ImageLabelDataset(Dataset):
    """Dataset of (image, label) pairs read from two directories.

    The entries of each directory are sorted by name and paired by position,
    so the n-th image is returned with the n-th label. Only the number of
    entries is checked, not that the names correspond.
    """

    def __init__(self, images_dir, labels_dir, transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.image_filenames = sorted(os.listdir(images_dir))
        self.label_filenames = sorted(os.listdir(labels_dir))
        self.transform = transform

        n_images = len(self.image_filenames)
        n_labels = len(self.label_filenames)
        assert n_images == n_labels, "Mismatch in number of images and labels"

    def __len__(self):
        """Number of (image, label) pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB and apply ``transform`` to both members."""
        sources = (
            (self.images_dir, self.image_filenames),
            (self.labels_dir, self.label_filenames),
        )
        full_paths = [os.path.join(folder, names[idx]) for folder, names in sources]

        # Both files are forced to 3-channel RGB, image first and then label
        image, label = [Image.open(path).convert('RGB') for path in full_paths]

        if self.transform:
            # The same transform object is applied to the image and the label
            image, label = self.transform(image), self.transform(label)

        return image, label

# Data pipeline configuration. `resize_dim` is also passed to the model as
# img_size by the training cell, so images and model agree on the resolution.
resize_dim = 512 # resize to dimensions
batch_size = 4  # Adjust batch size according to your GPU memory capacity

# Define transformations (resize, convert to tensor, normalize).
# ImageLabelDataset applies this same pipeline to both images and labels.
transform = transforms.Compose([
    transforms.Resize((resize_dim, resize_dim)),  # Resize images
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1] for each channel
])

# Create dataset instances for training
images_dir = '/kaggle/working/images/'  # Replace with your image directory path
labels_dir = '/kaggle/working/labels/'  # Replace with your label directory path
dataset = ImageLabelDataset(images_dir, labels_dir, transform=transform)

# Create DataLoader with batch size control; batches are reshuffled every
# epoch and loaded by 8 worker processes.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

## TRAINING SCRIPT

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# Instantiate the model
model = ViTImage2Image(img_size=resize_dim, patch_size=16, emb_dim=768, num_heads=16, num_layers=8, hidden_dim=3072)

# Select the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Replicate the model across GPUs only when more than one is visible
# (e.g. 2 x T4). Wrapping unconditionally with device_ids=[0, 1] fails on
# single-GPU and CPU-only sessions.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    model = nn.DataParallel(model, device_ids=list(range(gpu_count)))

# Move model to device
model = model.to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Number of epochs
num_epochs = 100

# Path to save the best model
save_dir = 'saved_models'
os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist
best_model_path = os.path.join(save_dir, 'best_model.pth')

# Lowest epoch loss seen so far
best_loss = float('inf')

# Training loop with tqdm
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_psnr = 0.0

    # Wrap dataloader with tqdm for progress bar
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}")

    # `inputs` rather than `input`, so the builtin input() is not shadowed
    for batch_idx, (inputs, target) in pbar:
        # Move data to the same device as the model
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()  # Clear the gradients from the last step
        # NOTE(review): the target batch is also passed as the style input,
        # so the network sees the target during training — confirm intended.
        output = model(inputs, target)  # Forward pass
        loss = combined_loss(output, target)  # Compute the loss

        loss.backward()  # Backward pass (compute gradients)
        optimizer.step()  # Update the weights

        batch_loss = loss.item()
        running_loss += batch_loss  # Accumulate the loss for this batch

        # The metric is only reported, so compute it without autograd.
        # float() accepts both a tensor and a plain Python float result.
        with torch.no_grad():
            current_psnr = float(psnr(output, target))
        running_psnr += current_psnr

        # Update progress bar
        pbar.set_postfix({
            'loss': batch_loss,
            'psnr': current_psnr,
        })

    # Calculate the average loss and metrics for the epoch
    epoch_loss = running_loss / len(dataloader)
    avg_psnr = running_psnr / len(dataloader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, PSNR: {avg_psnr:.4f}")

    # Save the best model based on the lowest loss
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        # Save the underlying module's weights: a DataParallel state_dict
        # prefixes every key with "module.", which a plain ViTImage2Image
        # (as built by the inference cell) cannot load.
        model_to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(model_to_save.state_dict(), best_model_path)
        print(f"New best model saved at epoch {epoch+1} with loss {best_loss:.4f}")

print("Training complete")

## CLEAR CUDA CACHE

In [193]:
import gc
# Release PyTorch's cached GPU memory, then run a full garbage collection
# (the trailing gc.collect() value is what the cell displays).
# NOTE(review): running gc.collect() before empty_cache() would presumably let
# memory held by just-collected tensors be released as well — confirm.
torch.cuda.empty_cache()
gc.collect()
# Error message pasted from an earlier run; the code that raised it is not shown here:
#RuntimeError: The size of tensor a (768) must match the size of tensor b (512) at non-singleton dimension 3

0

In [180]:
# Count every element of every parameter tensor of the model and print the
# total, followed by the module tree.
total_params = sum(p.numel() for p in model.parameters())
print(total_params)
print(model)

223543305
DataParallel(
  (module): ViTImage2Image(
    (patch_embed): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder_layers): ModuleList(
      (0-7): 8 x TransformerEncoderBlock(
        (attn): LocationBasedMultiheadAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): ReLU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (adain): AdaIN(
          (norm): InstanceNorm1d(

## INFERENCE

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

def load_image(image_path, img_size=1024):
    """Load an image file as a normalised single-image batch.

    The file is opened as RGB, resized to ``img_size`` x ``img_size``,
    converted to a tensor, normalised with mean 0.5 / std 0.5 per channel,
    and given a leading batch dimension of size 1.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    preprocessing_steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    preprocess = transforms.Compose(preprocessing_steps)
    tensor = preprocess(rgb_image)
    return tensor.unsqueeze(0)

def perform_inference(model, input_image, style_image=None):
    """Run ``model`` in eval mode without gradient tracking.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        input_image: Tensor passed as the first positional argument.
        style_image: Optional second positional argument, for models whose
            forward takes ``(content, style)`` such as ViTImage2Image. When
            omitted the model is called with ``input_image`` only, exactly as
            before.

    Returns:
        The model output.
    """
    model.eval()
    with torch.no_grad():
        if style_image is None:
            output = model(input_image)
        else:
            output = model(input_image, style_image)
    return output

def visualize_tensor(tensor, title):
    """Display an image tensor with matplotlib.

    Size-1 dimensions are dropped, the result is permuted from channel-first
    to channel-last, and values are min-max rescaled to [0, 1] for display.

    Args:
        tensor: Image tensor that is (C, H, W) once size-1 dims are removed.
        title: Title shown above the image.
    """
    # detach() so tensors that still carry autograd history can be converted
    img = tensor.detach().squeeze().cpu().permute(1, 2, 0).numpy()

    lowest = img.min()
    value_range = img.max() - lowest
    img = img - lowest
    # A constant image has zero range; dividing would yield NaNs (0 / 0), so
    # it is left as all zeros instead.
    if value_range > 0:
        img = img / value_range  # Normalize to [0, 1]

    plt.imshow(img)
    plt.title(title)
    plt.axis('off')
    plt.show()

def main():
    """Load the best checkpoint, run one image through the model and plot it.

    The output is shown raw and after sigmoid, tanh and min-max
    post-processing for visual comparison.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The architecture must match the one the checkpoint was trained with
    # (training cell: img_size=512, patch_size=16, num_heads=16,
    # num_layers=8); different values change parameter shapes and layer
    # counts, so load_state_dict would fail.
    img_size = 512
    model = ViTImage2Image(img_size=img_size, patch_size=16, emb_dim=768, num_heads=16, num_layers=8, hidden_dim=3072)

    # map_location lets a checkpoint written on GPU load on a CPU-only session
    state_dict = torch.load('/kaggle/working/saved_models/best_model.pth', map_location=device)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # "module."; strip it so the keys match the bare model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model = model.to(device)

    # Load and preprocess input image at the resolution the model expects
    input_image = load_image('/kaggle/working/labels/00143.png', img_size=img_size)
    input_image = input_image.to(device)

    # The model's forward takes (content, style).
    # NOTE(review): no separate style image is available here, so the input
    # doubles as the style — confirm this is the intended usage.
    style_image = input_image

    # Perform inference. The model is called directly rather than through
    # perform_inference(model, input_image), which forwards only the content
    # tensor.
    model.eval()
    with torch.no_grad():
        output = model(input_image, style_image)

    # Debug: Print shape and statistics of the output
    print("Output shape:", output.shape)
    print("Output min:", output.min().item())
    print("Output max:", output.max().item())
    print("Output mean:", output.mean().item())
    print("Output std:", output.std().item())

    # Visualize input and output
    visualize_tensor(input_image, "Input Image")
    visualize_tensor(output, "Output Image (Before Processing)")

    # Try different post-processing steps
    # 1. Sigmoid activation
    output_sigmoid = torch.sigmoid(output)
    visualize_tensor(output_sigmoid, "Output Image (After Sigmoid)")

    # 2. Tanh activation
    output_tanh = torch.tanh(output)
    visualize_tensor(output_tanh, "Output Image (After Tanh)")

    # 3. Min-Max Normalization (the range is clamped away from zero so a
    #    constant output does not divide by zero)
    output_range = (output.max() - output.min()).clamp_min(1e-8)
    output_normalized = (output - output.min()) / output_range
    visualize_tensor(output_normalized, "Output Image (After Min-Max Normalization)")

if __name__ == "__main__":
    main()