{"cells":[{"cell_type":"markdown","metadata":{},"source":["## DOWNLOADING DATASETS"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/01_images.zip\n","!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/01_labels.zip\n","!unzip /kaggle/working/01_images.zip\n","!unzip /kaggle/working/01_labels.zip\n","!rm /kaggle/working/01_images.zip\n","!rm /kaggle/working/01_labels.zip"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/02_images.zip\n","!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/02_labels.zip\n","!unzip /kaggle/working/02_images.zip\n","!unzip /kaggle/working/02_labels.zip\n","!rm /kaggle/working/02_images.zip\n","!rm /kaggle/working/02_labels.zip"]},{"cell_type":"markdown","metadata":{},"source":["## IN-MEMORY DOWNLOAD AND ZIPPING: REDUCING SECONDARY STORAGE FOOTPRINT & OVERHEAD"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["import requests\n","import zipfile\n","import io\n","\n","# URL of the zip file\n","url = 'https://download.visinf.tu-darmstadt.de/data/from_games/data/03_labels.zip'\n","\n","# Create a streaming GET request\n","response = requests.get(url, stream=True)\n","\n","# Create an in-memory bytes buffer for the zip file\n","zip_buffer = io.BytesIO()\n","\n","# Download the file in chunks and write to the in-memory buffer\n","for chunk in response.iter_content(chunk_size=1024):\n"," if chunk:\n"," zip_buffer.write(chunk)\n","\n","# Seek to the beginning of the buffer\n","zip_buffer.seek(0)\n","\n","# Open the zip file in memory\n","with zipfile.ZipFile(zip_buffer, 'r') as zip_ref:\n"," # Extract all files directly to your desired directory\n"," zip_ref.extractall('/kaggle/working/')\n","\n","# No need to explicitly delete the zip file, it's only in memory"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["import os\n","\n","_, _, files = next(os.walk(\"/kaggle/working/labels/\"))\n","print(len(files))"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["import torch\n","import sys\n","print('__Python VERSION:', sys.version)\n","print('__pyTorch VERSION:', torch.__version__)\n","print('__CUDA VERSION')\n","from subprocess import call\n","# call([\"nvcc\", \"--version\"]) does not work\n","! 
nvcc --version\n","print('__CUDNN VERSION:', torch.backends.cudnn.version())\n","print('__Number CUDA Devices:', torch.cuda.device_count())\n","print('__Devices')\n","call([\"nvidia-smi\", \"--format=csv\", \"--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free\"])\n","print('Active CUDA Device: GPU', torch.cuda.current_device())\n","print ('Available devices ', torch.cuda.device_count())\n","print ('Current cuda device ', torch.cuda.current_device())"]},{"cell_type":"markdown","metadata":{},"source":["## DEFINING STYLE TRANSFER VISION TRANSFORMER ARCHITECTURE"]},{"cell_type":"code","execution_count":143,"metadata":{"_kg_hide-output":true,"execution":{"iopub.execute_input":"2024-09-21T23:16:41.009559Z","iopub.status.busy":"2024-09-21T23:16:41.008700Z","iopub.status.idle":"2024-09-21T23:16:41.019269Z","shell.execute_reply":"2024-09-21T23:16:41.017876Z","shell.execute_reply.started":"2024-09-21T23:16:41.009514Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import math\n","\n","class PatchEmbedding(nn.Module):\n"," def __init__(self, in_channels=3, patch_size=8, emb_dim=768, img_size=256):\n"," super(PatchEmbedding, self).__init__()\n"," self.patch_size = patch_size\n"," num_patches = (img_size // patch_size) ** 2\n"," self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)\n"," self.positional_encoding = nn.Parameter(torch.zeros(1, num_patches, emb_dim))\n","\n"," def forward(self, x):\n"," # Shape of x: (batch_size, channels, height, width)\n"," batch_size = x.shape[0]\n"," x = self.proj(x) # (batch_size, emb_dim, H/P, W/P)\n"," x = x.flatten(2) # Flatten spatial dimensions\n"," x = x.transpose(1, 2) # (batch_size, num_patches, emb_dim)\n"," x += self.positional_encoding # Add positional encoding\n"," return x"]},{"cell_type":"code","execution_count":144,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.036391Z","iopub.status.busy":"2024-09-21T23:16:41.035657Z","iopub.status.idle":"2024-09-21T23:16:41.055122Z","shell.execute_reply":"2024-09-21T23:16:41.053886Z","shell.execute_reply.started":"2024-09-21T23:16:41.036334Z"},"trusted":true},"outputs":[],"source":["class TransformerEncoderBlock(nn.Module):\n"," def __init__(self, emb_dim=768, num_heads=8, hidden_dim=2048, dropout=0.1):\n"," super(TransformerEncoderBlock, self).__init__()\n"," self.attn = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)\n"," self.ff = nn.Sequential(\n"," nn.Linear(emb_dim, hidden_dim),\n"," nn.ReLU(),\n"," nn.Linear(hidden_dim, emb_dim),\n"," )\n"," self.norm1 = nn.LayerNorm(emb_dim)\n"," self.norm2 = nn.LayerNorm(emb_dim)\n"," self.adain = AdaIN(emb_dim)\n"," self.dropout = nn.Dropout(dropout)\n","\n"," def forward(self, x, style):\n"," attn_output = self.attn(x, x, x)\n"," x = x + self.dropout(attn_output)\n"," x = self.norm1(x)\n"," x = self.adain(x, style)\n"," intermediate = x # This is the intermediate representation for the skip connection\n"," ff_output = self.ff(x)\n"," x = x + self.dropout(ff_output)\n"," x = self.norm2(x)\n"," x = self.adain(x, style)\n"," return x, intermediate\n"," \n","class TransformerDecoderBlock(nn.Module):\n"," def __init__(self, emb_dim=768, num_heads=8, hidden_dim=2048, dropout=0.1):\n"," super(TransformerDecoderBlock, self).__init__()\n"," self.attn1 = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)\n"," self.attn2 = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)\n"," 
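# attn1 performs self-attention over the decoder tokens; attn2 performs cross-attention against\n"," # the encoder output. In forward() each is followed by LayerNorm and AdaIN style modulation.\n"," 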
self.ff = nn.Sequential(\n"," nn.Linear(emb_dim, hidden_dim),\n"," nn.ReLU(),\n"," nn.Linear(hidden_dim, emb_dim),\n"," )\n"," self.norm1 = nn.LayerNorm(emb_dim)\n"," self.norm2 = nn.LayerNorm(emb_dim)\n"," self.norm3 = nn.LayerNorm(emb_dim)\n"," self.norm4 = nn.LayerNorm(emb_dim)\n"," self.adain1 = AdaIN(emb_dim)\n"," self.adain2 = AdaIN(emb_dim)\n"," self.dropout = nn.Dropout(dropout)\n","\n"," def forward(self, x, enc_output, skip_connection, style):\n"," attn_output1 = self.attn1(x, x, x)\n"," x = x + self.dropout(attn_output1)\n"," x = self.norm1(x)\n"," x = self.adain1(x, style)\n"," attn_output2 = self.attn2(x, enc_output, enc_output)\n"," x = x + self.dropout(attn_output2)\n"," x = self.norm2(x)\n"," x = self.adain2(x, style)\n"," ff_output = self.ff(x)\n"," x = x + self.dropout(ff_output)\n"," x = self.norm3(x)\n"," x = x + self.norm4(skip_connection)\n"," return x"]},{"cell_type":"code","execution_count":145,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.057990Z","iopub.status.busy":"2024-09-21T23:16:41.057280Z","iopub.status.idle":"2024-09-21T23:16:41.072000Z","shell.execute_reply":"2024-09-21T23:16:41.070842Z","shell.execute_reply.started":"2024-09-21T23:16:41.057950Z"},"trusted":true},"outputs":[],"source":["class LocationBasedMultiheadAttention(nn.Module):\n"," def __init__(self, emb_dim, num_heads, dropout=0.1):\n"," super(LocationBasedMultiheadAttention, self).__init__()\n"," self.emb_dim = emb_dim\n"," self.num_heads = num_heads\n"," self.head_dim = emb_dim // num_heads\n"," \n"," self.q_proj = nn.Linear(emb_dim, emb_dim)\n"," self.k_proj = nn.Linear(emb_dim, emb_dim)\n"," self.v_proj = nn.Linear(emb_dim, emb_dim)\n"," self.out_proj = nn.Linear(emb_dim, emb_dim)\n"," \n"," self.dropout = nn.Dropout(dropout)\n"," \n"," # Learnable position encodings\n"," self.pos_enc = nn.Parameter(torch.randn(1, 1, emb_dim))\n","\n"," def forward(self, q, k, v):\n"," batch_size, seq_len, _ = q.shape\n"," \n"," # Add positional encodings\n"," q = q + self.pos_enc\n"," k = k + self.pos_enc\n"," \n"," # Separate heads\n"," q = self.q_proj(q).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n"," k = self.k_proj(k).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n"," v = self.v_proj(v).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n"," \n"," # Compute attention scores\n"," scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n"," attn_probs = F.softmax(scores, dim=-1)\n"," attn_probs = self.dropout(attn_probs)\n"," \n"," # Compute output\n"," out = torch.matmul(attn_probs, v)\n"," out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.emb_dim)\n"," out = self.out_proj(out)\n"," \n"," return out"]},{"cell_type":"code","execution_count":146,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.073857Z","iopub.status.busy":"2024-09-21T23:16:41.073403Z","iopub.status.idle":"2024-09-21T23:16:41.085968Z","shell.execute_reply":"2024-09-21T23:16:41.084912Z","shell.execute_reply.started":"2024-09-21T23:16:41.073817Z"},"trusted":true},"outputs":[],"source":["class AdaIN(nn.Module):\n"," def __init__(self, emb_dim):\n"," super(AdaIN, self).__init__()\n"," self.norm = nn.InstanceNorm1d(emb_dim, affine=False)\n"," self.fc = nn.Linear(emb_dim, emb_dim * 2)\n","\n"," def forward(self, x, style):\n"," style = self.fc(style).unsqueeze(1)\n"," gamma, beta = style.chunk(2, dim=-1)\n"," out = self.norm(x.transpose(1, 2)).transpose(1, 2)\n"," out = gamma * out + beta\n"," 
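# The normalized content tokens are rescaled and shifted by the style-predicted gamma/beta\n"," 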
return out"]},{"cell_type":"code","execution_count":147,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.088410Z","iopub.status.busy":"2024-09-21T23:16:41.088010Z","iopub.status.idle":"2024-09-21T23:16:41.099022Z","shell.execute_reply":"2024-09-21T23:16:41.097846Z","shell.execute_reply.started":"2024-09-21T23:16:41.088350Z"},"trusted":true},"outputs":[],"source":["class RefinementBlock(nn.Module):\n"," def __init__(self, in_channels=768, out_channels=3, kernel_size=3, stride=1, padding=1):\n"," super(RefinementBlock, self).__init__()\n"," self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)\n"," self.bn = nn.BatchNorm2d(out_channels)\n"," self.relu = nn.ReLU(inplace=True)\n"," \n"," def forward(self, x):\n"," x = self.conv(x)\n"," x = self.bn(x)\n"," x = self.relu(x)\n"," return x"]},{"cell_type":"code","execution_count":148,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.101869Z","iopub.status.busy":"2024-09-21T23:16:41.101083Z","iopub.status.idle":"2024-09-21T23:16:41.107762Z","shell.execute_reply":"2024-09-21T23:16:41.106685Z","shell.execute_reply.started":"2024-09-21T23:16:41.101815Z"},"trusted":true},"outputs":[],"source":["# !pip install einops"]},{"cell_type":"code","execution_count":149,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.117428Z","iopub.status.busy":"2024-09-21T23:16:41.116861Z","iopub.status.idle":"2024-09-21T23:16:41.132062Z","shell.execute_reply":"2024-09-21T23:16:41.131002Z","shell.execute_reply.started":"2024-09-21T23:16:41.117388Z"},"trusted":true},"outputs":[],"source":["import torch.nn.functional as F\n","from einops import rearrange\n","\n","# Swin Transformer Block\n","class SwinTransformerBlock(nn.Module):\n"," def __init__(self, dim, num_heads, window_size=7, shift_size=2):\n"," super(SwinTransformerBlock, self).__init__()\n"," self.dim = dim\n"," self.num_heads = num_heads\n"," self.window_size = window_size\n"," self.shift_size = shift_size\n","\n"," # Layer normalization\n"," self.norm1 = nn.LayerNorm(dim)\n"," \n"," # Window-based multi-head self-attention (W-MSA) or shifted W-MSA\n"," self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)\n","\n"," # Feed-forward network\n"," self.mlp = nn.Sequential(\n"," nn.Linear(dim, 4 * dim),\n"," nn.GELU(),\n"," nn.Linear(4 * dim, dim)\n"," )\n","\n"," # Another layer normalization\n"," self.norm2 = nn.LayerNorm(dim)\n","\n"," def forward(self, x, mask=None):\n"," # Input size\n"," B, H, W, C = x.shape\n","\n"," # Partition windows\n"," x = rearrange(x, 'b (h ws1) (w ws2) c -> (b h w) (ws1 ws2) c', ws1=self.window_size, ws2=self.window_size)\n","\n"," # Add cyclic shift if shift_size > 0\n"," if self.shift_size > 0:\n"," x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n","\n"," # Layer norm\n"," shortcut = x\n"," x = self.norm1(x)\n"," \n"," # Multi-head self-attention\n"," x, _ = self.attn(x, x, x, attn_mask=mask)\n"," \n"," # Residual connection\n"," x = shortcut + x\n","\n"," # MLP with another residual connection\n"," shortcut = x\n"," x = self.norm2(x)\n"," x = self.mlp(x)\n"," x = shortcut + x\n","\n"," # Reverse window partitioning\n"," x = rearrange(x, '(b h w) (ws1 ws2) c -> b (h ws1) (w ws2) c', h=H // self.window_size, w=W // self.window_size, ws1=self.window_size, ws2=self.window_size)\n","\n"," return 
x"]},{"cell_type":"code","execution_count":150,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.153207Z","iopub.status.busy":"2024-09-21T23:16:41.152730Z","iopub.status.idle":"2024-09-21T23:16:41.174097Z","shell.execute_reply":"2024-09-21T23:16:41.172886Z","shell.execute_reply.started":"2024-09-21T23:16:41.153164Z"},"trusted":true},"outputs":[],"source":["class ViTImage2Image(nn.Module):\n"," def __init__(self, img_size=512, patch_size=8, emb_dim=768, num_heads=12, num_layers=12, hidden_dim=3072, window_size=8):\n"," super(ViTImage2Image, self).__init__()\n"," self.img_size = img_size\n"," self.patch_size = patch_size\n"," self.emb_dim = emb_dim\n"," self.num_heads = num_heads\n"," self.num_layers = num_layers\n"," self.hidden_dim = hidden_dim\n"," self.window_size = window_size\n"," self.num_patches = (img_size // patch_size) ** 2\n"," \n"," self.patch_embed = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)\n"," self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, emb_dim))\n"," \n"," self.encoder_layers = nn.ModuleList([\n"," TransformerEncoderBlock(emb_dim, num_heads, hidden_dim, dropout=0.1)\n"," for _ in range(num_layers)\n"," ])\n"," \n"," self.decoder_layers = nn.ModuleList([\n"," TransformerDecoderBlock(emb_dim, num_heads, hidden_dim, dropout=0.1)\n"," for _ in range(num_layers)\n"," ])\n"," \n"," self.swin_layers = nn.ModuleList([\n"," SwinTransformerBlock(emb_dim, num_heads, window_size=window_size)\n"," for _ in range(num_layers)\n"," ])\n"," \n"," self.norm = nn.LayerNorm(emb_dim)\n"," self.mlp_head = nn.Sequential(\n"," nn.Linear(emb_dim, hidden_dim),\n"," nn.GELU(),\n"," nn.Linear(hidden_dim, emb_dim)\n"," )\n"," self.refinement = RefinementBlock(in_channels=emb_dim, out_channels=3, kernel_size=3, stride=1, padding=1)\n"," \n"," self.style_encoder = nn.Sequential(\n"," nn.Conv2d(3, emb_dim, kernel_size=3, stride=2, padding=1),\n"," nn.ReLU(),\n"," nn.AdaptiveAvgPool2d(1),\n"," nn.Flatten(),\n"," nn.Linear(emb_dim, emb_dim)\n"," )\n","\n"," def forward(self, content, style):\n"," x = self.patch_embed(content)\n"," B, C, H, W = x.shape\n"," x = x.flatten(2).transpose(1, 2)\n"," x = x + self.pos_embed\n"," \n"," style_features = self.style_encoder(style)\n"," \n"," skip_connections = []\n"," # Transformer encoder\n"," for encoder in self.encoder_layers:\n"," x, skip = encoder(x, style_features)\n"," skip_connections.append(skip)\n"," \n"," # Transformer decoder\n"," for decoder, skip in zip(self.decoder_layers, reversed(skip_connections)):\n"," x = decoder(x, x, skip, style_features)\n"," \n"," # Swin Transformer encoding\n"," x = x.view(B, H, W, C)\n"," for layer in self.swin_layers:\n"," x = layer(x)\n"," \n"," x = x.view(B, -1, C)\n"," x = self.norm(x)\n"," x = self.mlp_head(x)\n"," x = x.transpose(1, 2).view(B, self.emb_dim, H, W)\n"," x = self.refinement(x)\n"," x = nn.functional.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)\n"," return x"]},{"cell_type":"markdown","metadata":{},"source":["## DEFINE LOSS METRICS"]},{"cell_type":"code","execution_count":151,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.183842Z","iopub.status.busy":"2024-09-21T23:16:41.183339Z","iopub.status.idle":"2024-09-21T23:16:41.191559Z","shell.execute_reply":"2024-09-21T23:16:41.190562Z","shell.execute_reply.started":"2024-09-21T23:16:41.183800Z"},"trusted":true},"outputs":[],"source":["def total_variation_loss(x):\n"," return torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + 
torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))\n","\n","def combined_loss(output, target):\n"," l1_loss = nn.L1Loss()(output, target)\n"," tv_loss = total_variation_loss(output)\n"," return l1_loss + 0.0001 * tv_loss"]},{"cell_type":"markdown","metadata":{},"source":["## DEFINE TRAINING METRICS"]},{"cell_type":"code","execution_count":152,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.213172Z","iopub.status.busy":"2024-09-21T23:16:41.212746Z","iopub.status.idle":"2024-09-21T23:16:41.219196Z","shell.execute_reply":"2024-09-21T23:16:41.218074Z","shell.execute_reply.started":"2024-09-21T23:16:41.213131Z"},"trusted":true},"outputs":[],"source":["def psnr(img1, img2):\n"," mse = torch.mean((img1 - img2) ** 2)\n"," if mse == 0:\n"," return float('inf')\n"," return 20 * torch.log10(1.0 / torch.sqrt(mse))"]},{"cell_type":"markdown","metadata":{},"source":["## DATALOADERS"]},{"cell_type":"code","execution_count":181,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:24:17.558825Z","iopub.status.busy":"2024-09-21T23:24:17.557763Z","iopub.status.idle":"2024-09-21T23:24:17.588295Z","shell.execute_reply":"2024-09-21T23:24:17.587164Z","shell.execute_reply.started":"2024-09-21T23:24:17.558778Z"},"trusted":true},"outputs":[],"source":["import os\n","from PIL import Image\n","from torch.utils.data import Dataset, DataLoader\n","from torchvision import transforms\n","\n","# Custom Dataset class\n","class ImageLabelDataset(Dataset):\n"," def __init__(self, images_dir, labels_dir, transform=None):\n"," self.images_dir = images_dir\n"," self.labels_dir = labels_dir\n"," self.image_filenames = sorted(os.listdir(images_dir))\n"," self.label_filenames = sorted(os.listdir(labels_dir))\n"," self.transform = transform\n","\n"," assert len(self.image_filenames) == len(self.label_filenames), \"Mismatch in number of images and labels\"\n","\n"," def __len__(self):\n"," return len(self.image_filenames)\n","\n"," def __getitem__(self, idx):\n"," image_path = os.path.join(self.images_dir, self.image_filenames[idx])\n"," label_path = os.path.join(self.labels_dir, self.label_filenames[idx])\n"," \n"," # Open image and label using PIL\n"," image = Image.open(image_path).convert('RGB') # Ensure image is RGB\n"," label = Image.open(label_path).convert('RGB') # Ensure label is RGB (or use 'L' for grayscale)\n","\n"," if self.transform:\n"," # Apply the same transformation to both the image and the label\n"," image = self.transform(image)\n"," label = self.transform(label)\n","\n"," return image, label\n","\n","resize_dim = 512 # resize to dimensions\n","batch_size = 4 # Adjust batch size according to your GPU memory capacity\n"," \n","# Define transformations (resize, convert to tensor, normalize)\n","transform = transforms.Compose([\n"," transforms.Resize((resize_dim, resize_dim)), # Resize images\n"," transforms.ToTensor(), # Convert images to PyTorch tensors\n"," transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize to [-1, 1] for each channel\n","])\n","\n","# Create dataset instances for training\n","images_dir = '/kaggle/working/images/' # Replace with your image directory path\n","labels_dir = '/kaggle/working/labels/' # Replace with your label directory path\n","dataset = ImageLabelDataset(images_dir, labels_dir, transform=transform)\n","\n","# Create DataLoader with batch size control\n","dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)"]},{"cell_type":"markdown","metadata":{},"source":["## TRAINING 
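SANITY CHECK"]},{"cell_type":"markdown","metadata":{},"source":["Before launching the full run it is worth pulling a single batch to confirm shapes and value ranges. The cell below is a minimal sketch that only assumes the `dataloader` defined above and the `/kaggle/working/images/` and `/kaggle/working/labels/` folders; it does not touch the model."]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["# Minimal data-pipeline sanity check (added sketch): fetch one batch and report shapes/ranges.\n","images, labels = next(iter(dataloader))\n","print(f'image batch: {tuple(images.shape)}, label batch: {tuple(labels.shape)}')\n","print(f'image range: [{images.min().item():.3f}, {images.max().item():.3f}]')\n","print(f'label range: [{labels.min().item():.3f}, {labels.max().item():.3f}]')\n","# With Normalize(mean=0.5, std=0.5) both tensors should lie in [-1, 1]\n","assert images.shape == labels.shape, 'paired image/label batches should have identical shapes'"]},{"cell_type":"markdown","metadata":{},"source":["## TRAINING 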
SCRIPT"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:25:08.871609Z","iopub.status.busy":"2024-09-21T23:25:08.871150Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.utils.data import DataLoader\n","from tqdm import tqdm\n","import os\n","\n","# Instantiate the model\n","model = ViTImage2Image(img_size=resize_dim, patch_size=16, emb_dim=768, num_heads=16, num_layers=8, hidden_dim=3072)\n","\n","# Move model to the appropriate device (GPU if available)\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","# If using 2 x T4 GPU only\n","model = nn.DataParallel(model, device_ids = [0,1])\n","\n","# Move model to device\n","model = model.to(device)\n","\n","# Optimizer\n","optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n","\n","# Number of epochs\n","num_epochs = 100\n","\n","best_loss = float('inf')\n","\n","# Path to save the best model\n","save_dir = 'saved_models'\n","os.makedirs(save_dir, exist_ok=True) # Create directory if it doesn't exist\n","best_model_path = os.path.join(save_dir, 'best_model.pth')\n","\n","# Initialize best loss as a large value\n","best_loss = float('inf')\n","\n","# Training loop with tqdm\n","for epoch in range(num_epochs):\n"," model.train()\n"," running_loss = 0.0\n"," running_psnr = 0.0\n"," \n"," # Wrap dataloader with tqdm for progress bar\n"," pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f\"Epoch {epoch+1}/{num_epochs}\")\n"," \n"," for batch_idx, (input, target) in pbar:\n"," # Move data to the same device as the model\n"," input, target = input.to(device), target.to(device)\n"," \n"," optimizer.zero_grad() # Clear the gradients from the last step\n"," output = model(input, target) # Forward pass\n","# print(f\"Input shape: {input.shape}, Output shape: {output.shape}, Target shape: {target.shape}\")\n"," loss = combined_loss(output, target) # Compute the loss\n"," \n"," loss.backward() # Backward pass (compute gradients)\n"," optimizer.step() # Update the weights\n"," \n"," running_loss += loss.item() # Accumulate the loss for this batch\n"," \n"," # Calculate metrics\n"," current_psnr = psnr(output, target).item()\n"," running_psnr += current_psnr\n"," \n"," # Update progress bar\n"," pbar.set_postfix({\n"," 'loss': loss.item(),\n"," 'psnr': current_psnr,\n"," })\n"," \n"," # Calculate the average loss and metrics for the epoch\n"," epoch_loss = running_loss / len(dataloader)\n"," avg_psnr = running_psnr / len(dataloader)\n"," \n"," print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, PSNR: {avg_psnr:.4f}\")\n"," \n"," # Save the best model based on the lowest loss\n"," if epoch_loss < best_loss:\n"," best_loss = epoch_loss\n"," torch.save(model.state_dict(), best_model_path)\n"," print(f\"New best model saved at epoch {epoch+1} with loss {best_loss:.4f}\")\n","\n","print(\"Training complete\")"]},{"cell_type":"markdown","metadata":{},"source":["## CLEAR CUDA CACHE"]},{"cell_type":"code","execution_count":193,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:25:05.712036Z","iopub.status.busy":"2024-09-21T23:25:05.711264Z","iopub.status.idle":"2024-09-21T23:25:05.899856Z","shell.execute_reply":"2024-09-21T23:25:05.898677Z","shell.execute_reply.started":"2024-09-21T23:25:05.711993Z"},"trusted":true},"outputs":[{"data":{"text/plain":["0"]},"execution_count":193,"metadata":{},"output_type":"execute_result"}],"source":["import 
gc\n","torch.cuda.empty_cache()\n","gc.collect()\n","#RuntimeError: The size of tensor a (768) must match the size of tensor b (512) at non-singleton dimension 3"]},{"cell_type":"code","execution_count":180,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:24:04.448611Z","iopub.status.busy":"2024-09-21T23:24:04.448163Z","iopub.status.idle":"2024-09-21T23:24:04.462532Z","shell.execute_reply":"2024-09-21T23:24:04.461381Z","shell.execute_reply.started":"2024-09-21T23:24:04.448567Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["223543305\n","DataParallel(\n"," (module): ViTImage2Image(\n"," (patch_embed): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n"," (encoder_layers): ModuleList(\n"," (0-7): 8 x TransformerEncoderBlock(\n"," (attn): LocationBasedMultiheadAttention(\n"," (q_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (k_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (v_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (out_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (ff): Sequential(\n"," (0): Linear(in_features=768, out_features=3072, bias=True)\n"," (1): ReLU()\n"," (2): Linear(in_features=3072, out_features=768, bias=True)\n"," )\n"," (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (adain): AdaIN(\n"," (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n"," (fc): Linear(in_features=768, out_features=1536, bias=True)\n"," )\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," (decoder_layers): ModuleList(\n"," (0-7): 8 x TransformerDecoderBlock(\n"," (attn1): LocationBasedMultiheadAttention(\n"," (q_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (k_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (v_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (out_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (attn2): LocationBasedMultiheadAttention(\n"," (q_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (k_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (v_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (out_proj): Linear(in_features=768, out_features=768, bias=True)\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (ff): Sequential(\n"," (0): Linear(in_features=768, out_features=3072, bias=True)\n"," (1): ReLU()\n"," (2): Linear(in_features=3072, out_features=768, bias=True)\n"," )\n"," (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (norm4): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (adain1): AdaIN(\n"," (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n"," (fc): Linear(in_features=768, out_features=1536, bias=True)\n"," )\n"," (adain2): AdaIN(\n"," (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n"," (fc): Linear(in_features=768, out_features=1536, bias=True)\n"," )\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," (swin_layers): ModuleList(\n"," (0-7): 8 x SwinTransformerBlock(\n"," (norm1): LayerNorm((768,), eps=1e-05, 
elementwise_affine=True)\n"," (attn): MultiheadAttention(\n"," (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n"," )\n"," (mlp): Sequential(\n"," (0): Linear(in_features=768, out_features=3072, bias=True)\n"," (1): GELU(approximate='none')\n"," (2): Linear(in_features=3072, out_features=768, bias=True)\n"," )\n"," (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," )\n"," )\n"," (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (mlp_head): Sequential(\n"," (0): Linear(in_features=768, out_features=3072, bias=True)\n"," (1): GELU(approximate='none')\n"," (2): Linear(in_features=3072, out_features=768, bias=True)\n"," )\n"," (refinement): RefinementBlock(\n"," (conv): Conv2d(768, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (relu): ReLU(inplace=True)\n"," )\n"," (style_encoder): Sequential(\n"," (0): Conv2d(3, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n"," (1): ReLU()\n"," (2): AdaptiveAvgPool2d(output_size=1)\n"," (3): Flatten(start_dim=1, end_dim=-1)\n"," (4): Linear(in_features=768, out_features=768, bias=True)\n"," )\n"," )\n",")\n"]}],"source":["total_params = sum(p.numel() for p in model.parameters())\n","print(total_params)\n","print(model)"]},{"cell_type":"markdown","metadata":{},"source":["## INFERENCE"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-21T23:16:44.175025Z","iopub.status.idle":"2024-09-21T23:16:44.184902Z","shell.execute_reply":"2024-09-21T23:16:44.184625Z","shell.execute_reply.started":"2024-09-21T23:16:44.184592Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torchvision.transforms as transforms\n","from PIL import Image\n","import matplotlib.pyplot as plt\n","\n","def load_image(image_path, img_size=1024):\n"," image = Image.open(image_path).convert('RGB')\n"," transform = transforms.Compose([\n"," transforms.Resize((img_size, img_size)),\n"," transforms.ToTensor(),\n"," transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n"," ])\n"," return transform(image).unsqueeze(0)\n","\n","def perform_inference(model, input_image):\n"," model.eval()\n"," with torch.no_grad():\n"," output = model(input_image)\n"," return output\n","\n","def visualize_tensor(tensor, title):\n"," img = tensor.squeeze().cpu().permute(1, 2, 0).numpy()\n"," img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]\n"," plt.imshow(img)\n"," plt.title(title)\n"," plt.axis('off')\n"," plt.show()\n","\n","def main():\n"," # Load the model (adjust parameters as needed)\n"," model = ViTImage2Image(img_size=1024, patch_size=8, emb_dim=768, num_heads=12, num_layers=12, hidden_dim=3072)\n"," model.load_state_dict(torch.load('/kaggle/working/saved_models/best_model.pth'))\n"," \n"," device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"," model = model.to(device)\n"," \n"," # Load and preprocess input image\n"," input_image = load_image('/kaggle/working/labels/00143.png')\n"," input_image = input_image.to(device)\n"," \n"," # Perform inference\n"," output = perform_inference(model, input_image)\n"," \n"," # Debug: Print shape and statistics of the output\n"," print(\"Output shape:\", output.shape)\n"," print(\"Output min:\", output.min().item())\n"," print(\"Output max:\", output.max().item())\n"," print(\"Output mean:\", output.mean().item())\n"," print(\"Output std:\", output.std().item())\n"," \n"," 
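# NOTE (added): best_model.pth was saved from an nn.DataParallel-wrapped model trained with\n"," # img_size=512, patch_size=16, num_heads=16, num_layers=8, so load_state_dict above will likely fail\n"," # unless the constructor arguments match and the 'module.' prefix is stripped from the checkpoint keys.\n"," # Also, ViTImage2Image.forward expects (content, style), so perform_inference should pass a style image\n"," # rather than calling model(input_image) with a single argument.\n"," 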
# Visualize input and output\n"," visualize_tensor(input_image, \"Input Image\")\n"," visualize_tensor(output, \"Output Image (Before Processing)\")\n"," \n"," # Try different post-processing steps\n"," # 1. Sigmoid activation\n"," output_sigmoid = torch.sigmoid(output)\n"," visualize_tensor(output_sigmoid, \"Output Image (After Sigmoid)\")\n"," \n"," # 2. Tanh activation\n"," output_tanh = torch.tanh(output)\n"," visualize_tensor(output_tanh, \"Output Image (After Tanh)\")\n"," \n"," # 3. Min-Max Normalization\n"," output_normalized = (output - output.min()) / (output.max() - output.min())\n"," visualize_tensor(output_normalized, \"Output Image (After Min-Max Normalization)\")\n","\n","if __name__ == \"__main__\":\n"," main()"]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30761,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"}},"nbformat":4,"nbformat_minor":4}