aoxo
/

Image-to-Image
English
art
File size: 35,028 Bytes
c59452d
1
2
{"cells":[{"cell_type":"markdown","metadata":{},"source":["## DOWNLOADING DATASETS"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/01_images.zip\n","!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/01_labels.zip\n","!unzip /kaggle/working/01_images.zip\n","!unzip /kaggle/working/01_labels.zip\n","!rm /kaggle/working/01_images.zip\n","!rm /kaggle/working/01_labels.zip"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/02_images.zip\n","!wget https://download.visinf.tu-darmstadt.de/data/from_games/data/02_labels.zip\n","!unzip /kaggle/working/02_images.zip\n","!unzip /kaggle/working/02_labels.zip\n","!rm /kaggle/working/02_images.zip\n","!rm /kaggle/working/02_labels.zip"]},{"cell_type":"markdown","metadata":{},"source":["## IN-MEMORY DOWNLOAD AND ZIPPING: REDUCING SECONDARY STORAGE FOOTPRINT & OVERHEAD"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["import requests\n","import zipfile\n","import io\n","\n","# URL of the zip file\n","url = 'https://download.visinf.tu-darmstadt.de/data/from_games/data/03_labels.zip'\n","\n","# Create a streaming GET request\n","response = requests.get(url, stream=True)\n","\n","# Create an in-memory bytes buffer for the zip file\n","zip_buffer = io.BytesIO()\n","\n","# Download the file in chunks and write to the in-memory buffer\n","for chunk in response.iter_content(chunk_size=1024):\n","    if chunk:\n","        zip_buffer.write(chunk)\n","\n","# Seek to the beginning of the buffer\n","zip_buffer.seek(0)\n","\n","# Open the zip file in memory\n","with zipfile.ZipFile(zip_buffer, 'r') as zip_ref:\n","    # Extract all files directly to your desired directory\n","    zip_ref.extractall('/kaggle/working/')\n","\n","# No need to explicitly delete the zip file, it's only in memory"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["import os\n","\n","_, _, files = next(os.walk(\"/kaggle/working/labels/\"))\n","print(len(files))"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["import torch\n","import sys\n","print('__Python VERSION:', sys.version)\n","print('__pyTorch VERSION:', torch.__version__)\n","print('__CUDA VERSION')\n","from subprocess import call\n","# call([\"nvcc\", \"--version\"]) does not work\n","! nvcc --version\n","print('__CUDNN VERSION:', torch.backends.cudnn.version())\n","print('__Number CUDA Devices:', torch.cuda.device_count())\n","print('__Devices')\n","call([\"nvidia-smi\", \"--format=csv\", \"--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free\"])\n","print('Active CUDA Device: GPU', torch.cuda.current_device())\n","print ('Available devices ', torch.cuda.device_count())\n","print ('Current cuda device ', torch.cuda.current_device())"]},{"cell_type":"markdown","metadata":{},"source":["## DEFINING STYLE TRANSFER VISION TRANSFORMER ARCHITECTURE"]},{"cell_type":"code","execution_count":143,"metadata":{"_kg_hide-output":true,"execution":{"iopub.execute_input":"2024-09-21T23:16:41.009559Z","iopub.status.busy":"2024-09-21T23:16:41.008700Z","iopub.status.idle":"2024-09-21T23:16:41.019269Z","shell.execute_reply":"2024-09-21T23:16:41.017876Z","shell.execute_reply.started":"2024-09-21T23:16:41.009514Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import math\n","\n","class PatchEmbedding(nn.Module):\n","    def __init__(self, in_channels=3, patch_size=8, emb_dim=768, img_size=256):\n","        super(PatchEmbedding, self).__init__()\n","        self.patch_size = patch_size\n","        num_patches = (img_size // patch_size) ** 2\n","        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)\n","        self.positional_encoding = nn.Parameter(torch.zeros(1, num_patches, emb_dim))\n","\n","    def forward(self, x):\n","        # Shape of x: (batch_size, channels, height, width)\n","        batch_size = x.shape[0]\n","        x = self.proj(x)  # (batch_size, emb_dim, H/P, W/P)\n","        x = x.flatten(2)  # Flatten spatial dimensions\n","        x = x.transpose(1, 2)  # (batch_size, num_patches, emb_dim)\n","        x += self.positional_encoding  # Add positional encoding\n","        return x"]},{"cell_type":"code","execution_count":144,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.036391Z","iopub.status.busy":"2024-09-21T23:16:41.035657Z","iopub.status.idle":"2024-09-21T23:16:41.055122Z","shell.execute_reply":"2024-09-21T23:16:41.053886Z","shell.execute_reply.started":"2024-09-21T23:16:41.036334Z"},"trusted":true},"outputs":[],"source":["class TransformerEncoderBlock(nn.Module):\n","    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=2048, dropout=0.1):\n","        super(TransformerEncoderBlock, self).__init__()\n","        self.attn = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)\n","        self.ff = nn.Sequential(\n","            nn.Linear(emb_dim, hidden_dim),\n","            nn.ReLU(),\n","            nn.Linear(hidden_dim, emb_dim),\n","        )\n","        self.norm1 = nn.LayerNorm(emb_dim)\n","        self.norm2 = nn.LayerNorm(emb_dim)\n","        self.adain = AdaIN(emb_dim)\n","        self.dropout = nn.Dropout(dropout)\n","\n","    def forward(self, x, style):\n","        attn_output = self.attn(x, x, x)\n","        x = x + self.dropout(attn_output)\n","        x = self.norm1(x)\n","        x = self.adain(x, style)\n","        intermediate = x  # This is the intermediate representation for the skip connection\n","        ff_output = self.ff(x)\n","        x = x + self.dropout(ff_output)\n","        x = self.norm2(x)\n","        x = self.adain(x, style)\n","        return x, intermediate\n","    \n","class TransformerDecoderBlock(nn.Module):\n","    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=2048, dropout=0.1):\n","        super(TransformerDecoderBlock, self).__init__()\n","        self.attn1 = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)\n","        self.attn2 = LocationBasedMultiheadAttention(emb_dim, num_heads, dropout=dropout)\n","        self.ff = nn.Sequential(\n","            nn.Linear(emb_dim, hidden_dim),\n","            nn.ReLU(),\n","            nn.Linear(hidden_dim, emb_dim),\n","        )\n","        self.norm1 = nn.LayerNorm(emb_dim)\n","        self.norm2 = nn.LayerNorm(emb_dim)\n","        self.norm3 = nn.LayerNorm(emb_dim)\n","        self.norm4 = nn.LayerNorm(emb_dim)\n","        self.adain1 = AdaIN(emb_dim)\n","        self.adain2 = AdaIN(emb_dim)\n","        self.dropout = nn.Dropout(dropout)\n","\n","    def forward(self, x, enc_output, skip_connection, style):\n","        attn_output1 = self.attn1(x, x, x)\n","        x = x + self.dropout(attn_output1)\n","        x = self.norm1(x)\n","        x = self.adain1(x, style)\n","        attn_output2 = self.attn2(x, enc_output, enc_output)\n","        x = x + self.dropout(attn_output2)\n","        x = self.norm2(x)\n","        x = self.adain2(x, style)\n","        ff_output = self.ff(x)\n","        x = x + self.dropout(ff_output)\n","        x = self.norm3(x)\n","        x = x + self.norm4(skip_connection)\n","        return x"]},{"cell_type":"code","execution_count":145,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.057990Z","iopub.status.busy":"2024-09-21T23:16:41.057280Z","iopub.status.idle":"2024-09-21T23:16:41.072000Z","shell.execute_reply":"2024-09-21T23:16:41.070842Z","shell.execute_reply.started":"2024-09-21T23:16:41.057950Z"},"trusted":true},"outputs":[],"source":["class LocationBasedMultiheadAttention(nn.Module):\n","    def __init__(self, emb_dim, num_heads, dropout=0.1):\n","        super(LocationBasedMultiheadAttention, self).__init__()\n","        self.emb_dim = emb_dim\n","        self.num_heads = num_heads\n","        self.head_dim = emb_dim // num_heads\n","        \n","        self.q_proj = nn.Linear(emb_dim, emb_dim)\n","        self.k_proj = nn.Linear(emb_dim, emb_dim)\n","        self.v_proj = nn.Linear(emb_dim, emb_dim)\n","        self.out_proj = nn.Linear(emb_dim, emb_dim)\n","        \n","        self.dropout = nn.Dropout(dropout)\n","        \n","        # Learnable position encodings\n","        self.pos_enc = nn.Parameter(torch.randn(1, 1, emb_dim))\n","\n","    def forward(self, q, k, v):\n","        batch_size, seq_len, _ = q.shape\n","        \n","        # Add positional encodings\n","        q = q + self.pos_enc\n","        k = k + self.pos_enc\n","        \n","        # Separate heads\n","        q = self.q_proj(q).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n","        k = self.k_proj(k).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n","        v = self.v_proj(v).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n","        \n","        # Compute attention scores\n","        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n","        attn_probs = F.softmax(scores, dim=-1)\n","        attn_probs = self.dropout(attn_probs)\n","        \n","        # Compute output\n","        out = torch.matmul(attn_probs, v)\n","        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.emb_dim)\n","        out = self.out_proj(out)\n","        \n","        return out"]},{"cell_type":"code","execution_count":146,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.073857Z","iopub.status.busy":"2024-09-21T23:16:41.073403Z","iopub.status.idle":"2024-09-21T23:16:41.085968Z","shell.execute_reply":"2024-09-21T23:16:41.084912Z","shell.execute_reply.started":"2024-09-21T23:16:41.073817Z"},"trusted":true},"outputs":[],"source":["class AdaIN(nn.Module):\n","    def __init__(self, emb_dim):\n","        super(AdaIN, self).__init__()\n","        self.norm = nn.InstanceNorm1d(emb_dim, affine=False)\n","        self.fc = nn.Linear(emb_dim, emb_dim * 2)\n","\n","    def forward(self, x, style):\n","        style = self.fc(style).unsqueeze(1)\n","        gamma, beta = style.chunk(2, dim=-1)\n","        out = self.norm(x.transpose(1, 2)).transpose(1, 2)\n","        out = gamma * out + beta\n","        return out"]},{"cell_type":"code","execution_count":147,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.088410Z","iopub.status.busy":"2024-09-21T23:16:41.088010Z","iopub.status.idle":"2024-09-21T23:16:41.099022Z","shell.execute_reply":"2024-09-21T23:16:41.097846Z","shell.execute_reply.started":"2024-09-21T23:16:41.088350Z"},"trusted":true},"outputs":[],"source":["class RefinementBlock(nn.Module):\n","    def __init__(self, in_channels=768, out_channels=3, kernel_size=3, stride=1, padding=1):\n","        super(RefinementBlock, self).__init__()\n","        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)\n","        self.bn = nn.BatchNorm2d(out_channels)\n","        self.relu = nn.ReLU(inplace=True)\n","        \n","    def forward(self, x):\n","        x = self.conv(x)\n","        x = self.bn(x)\n","        x = self.relu(x)\n","        return x"]},{"cell_type":"code","execution_count":148,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.101869Z","iopub.status.busy":"2024-09-21T23:16:41.101083Z","iopub.status.idle":"2024-09-21T23:16:41.107762Z","shell.execute_reply":"2024-09-21T23:16:41.106685Z","shell.execute_reply.started":"2024-09-21T23:16:41.101815Z"},"trusted":true},"outputs":[],"source":["# !pip install einops"]},{"cell_type":"code","execution_count":149,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.117428Z","iopub.status.busy":"2024-09-21T23:16:41.116861Z","iopub.status.idle":"2024-09-21T23:16:41.132062Z","shell.execute_reply":"2024-09-21T23:16:41.131002Z","shell.execute_reply.started":"2024-09-21T23:16:41.117388Z"},"trusted":true},"outputs":[],"source":["import torch.nn.functional as F\n","from einops import rearrange\n","\n","# Swin Transformer Block\n","class SwinTransformerBlock(nn.Module):\n","    def __init__(self, dim, num_heads, window_size=7, shift_size=2):\n","        super(SwinTransformerBlock, self).__init__()\n","        self.dim = dim\n","        self.num_heads = num_heads\n","        self.window_size = window_size\n","        self.shift_size = shift_size\n","\n","        # Layer normalization\n","        self.norm1 = nn.LayerNorm(dim)\n","        \n","        # Window-based multi-head self-attention (W-MSA) or shifted W-MSA\n","        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)\n","\n","        # Feed-forward network\n","        self.mlp = nn.Sequential(\n","            nn.Linear(dim, 4 * dim),\n","            nn.GELU(),\n","            nn.Linear(4 * dim, dim)\n","        )\n","\n","        # Another layer normalization\n","        self.norm2 = nn.LayerNorm(dim)\n","\n","    def forward(self, x, mask=None):\n","        # Input size\n","        B, H, W, C = x.shape\n","\n","        # Partition windows\n","        x = rearrange(x, 'b (h ws1) (w ws2) c -> (b h w) (ws1 ws2) c', ws1=self.window_size, ws2=self.window_size)\n","\n","        # Add cyclic shift if shift_size > 0\n","        if self.shift_size > 0:\n","            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n","\n","        # Layer norm\n","        shortcut = x\n","        x = self.norm1(x)\n","        \n","        # Multi-head self-attention\n","        x, _ = self.attn(x, x, x, attn_mask=mask)\n","        \n","        # Residual connection\n","        x = shortcut + x\n","\n","        # MLP with another residual connection\n","        shortcut = x\n","        x = self.norm2(x)\n","        x = self.mlp(x)\n","        x = shortcut + x\n","\n","        # Reverse window partitioning\n","        x = rearrange(x, '(b h w) (ws1 ws2) c -> b (h ws1) (w ws2) c', h=H // self.window_size, w=W // self.window_size, ws1=self.window_size, ws2=self.window_size)\n","\n","        return x"]},{"cell_type":"code","execution_count":150,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.153207Z","iopub.status.busy":"2024-09-21T23:16:41.152730Z","iopub.status.idle":"2024-09-21T23:16:41.174097Z","shell.execute_reply":"2024-09-21T23:16:41.172886Z","shell.execute_reply.started":"2024-09-21T23:16:41.153164Z"},"trusted":true},"outputs":[],"source":["class ViTImage2Image(nn.Module):\n","    def __init__(self, img_size=512, patch_size=8, emb_dim=768, num_heads=12, num_layers=12, hidden_dim=3072, window_size=8):\n","        super(ViTImage2Image, self).__init__()\n","        self.img_size = img_size\n","        self.patch_size = patch_size\n","        self.emb_dim = emb_dim\n","        self.num_heads = num_heads\n","        self.num_layers = num_layers\n","        self.hidden_dim = hidden_dim\n","        self.window_size = window_size\n","        self.num_patches = (img_size // patch_size) ** 2\n","        \n","        self.patch_embed = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)\n","        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, emb_dim))\n","        \n","        self.encoder_layers = nn.ModuleList([\n","            TransformerEncoderBlock(emb_dim, num_heads, hidden_dim, dropout=0.1)\n","            for _ in range(num_layers)\n","        ])\n","        \n","        self.decoder_layers = nn.ModuleList([\n","            TransformerDecoderBlock(emb_dim, num_heads, hidden_dim, dropout=0.1)\n","            for _ in range(num_layers)\n","        ])\n","        \n","        self.swin_layers = nn.ModuleList([\n","            SwinTransformerBlock(emb_dim, num_heads, window_size=window_size)\n","            for _ in range(num_layers)\n","        ])\n","        \n","        self.norm = nn.LayerNorm(emb_dim)\n","        self.mlp_head = nn.Sequential(\n","            nn.Linear(emb_dim, hidden_dim),\n","            nn.GELU(),\n","            nn.Linear(hidden_dim, emb_dim)\n","        )\n","        self.refinement = RefinementBlock(in_channels=emb_dim, out_channels=3, kernel_size=3, stride=1, padding=1)\n","        \n","        self.style_encoder = nn.Sequential(\n","            nn.Conv2d(3, emb_dim, kernel_size=3, stride=2, padding=1),\n","            nn.ReLU(),\n","            nn.AdaptiveAvgPool2d(1),\n","            nn.Flatten(),\n","            nn.Linear(emb_dim, emb_dim)\n","        )\n","\n","    def forward(self, content, style):\n","        x = self.patch_embed(content)\n","        B, C, H, W = x.shape\n","        x = x.flatten(2).transpose(1, 2)\n","        x = x + self.pos_embed\n","        \n","        style_features = self.style_encoder(style)\n","        \n","        skip_connections = []\n","        # Transformer encoder\n","        for encoder in self.encoder_layers:\n","            x, skip = encoder(x, style_features)\n","            skip_connections.append(skip)\n","        \n","        # Transformer decoder\n","        for decoder, skip in zip(self.decoder_layers, reversed(skip_connections)):\n","            x = decoder(x, x, skip, style_features)\n","        \n","        # Swin Transformer encoding\n","        x = x.view(B, H, W, C)\n","        for layer in self.swin_layers:\n","            x = layer(x)\n","        \n","        x = x.view(B, -1, C)\n","        x = self.norm(x)\n","        x = self.mlp_head(x)\n","        x = x.transpose(1, 2).view(B, self.emb_dim, H, W)\n","        x = self.refinement(x)\n","        x = nn.functional.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)\n","        return x"]},{"cell_type":"markdown","metadata":{},"source":["## DEFINE LOSS METRICS"]},{"cell_type":"code","execution_count":151,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.183842Z","iopub.status.busy":"2024-09-21T23:16:41.183339Z","iopub.status.idle":"2024-09-21T23:16:41.191559Z","shell.execute_reply":"2024-09-21T23:16:41.190562Z","shell.execute_reply.started":"2024-09-21T23:16:41.183800Z"},"trusted":true},"outputs":[],"source":["def total_variation_loss(x):\n","    return torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))\n","\n","def combined_loss(output, target):\n","    l1_loss = nn.L1Loss()(output, target)\n","    tv_loss = total_variation_loss(output)\n","    return l1_loss + 0.0001 * tv_loss"]},{"cell_type":"markdown","metadata":{},"source":["## DEFINE TRAINING METRICS"]},{"cell_type":"code","execution_count":152,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:16:41.213172Z","iopub.status.busy":"2024-09-21T23:16:41.212746Z","iopub.status.idle":"2024-09-21T23:16:41.219196Z","shell.execute_reply":"2024-09-21T23:16:41.218074Z","shell.execute_reply.started":"2024-09-21T23:16:41.213131Z"},"trusted":true},"outputs":[],"source":["def psnr(img1, img2):\n","    mse = torch.mean((img1 - img2) ** 2)\n","    if mse == 0:\n","        return float('inf')\n","    return 20 * torch.log10(1.0 / torch.sqrt(mse))"]},{"cell_type":"markdown","metadata":{},"source":["## DATALOADERS"]},{"cell_type":"code","execution_count":181,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:24:17.558825Z","iopub.status.busy":"2024-09-21T23:24:17.557763Z","iopub.status.idle":"2024-09-21T23:24:17.588295Z","shell.execute_reply":"2024-09-21T23:24:17.587164Z","shell.execute_reply.started":"2024-09-21T23:24:17.558778Z"},"trusted":true},"outputs":[],"source":["import os\n","from PIL import Image\n","from torch.utils.data import Dataset, DataLoader\n","from torchvision import transforms\n","\n","# Custom Dataset class\n","class ImageLabelDataset(Dataset):\n","    def __init__(self, images_dir, labels_dir, transform=None):\n","        self.images_dir = images_dir\n","        self.labels_dir = labels_dir\n","        self.image_filenames = sorted(os.listdir(images_dir))\n","        self.label_filenames = sorted(os.listdir(labels_dir))\n","        self.transform = transform\n","\n","        assert len(self.image_filenames) == len(self.label_filenames), \"Mismatch in number of images and labels\"\n","\n","    def __len__(self):\n","        return len(self.image_filenames)\n","\n","    def __getitem__(self, idx):\n","        image_path = os.path.join(self.images_dir, self.image_filenames[idx])\n","        label_path = os.path.join(self.labels_dir, self.label_filenames[idx])\n","        \n","        # Open image and label using PIL\n","        image = Image.open(image_path).convert('RGB')  # Ensure image is RGB\n","        label = Image.open(label_path).convert('RGB')  # Ensure label is RGB (or use 'L' for grayscale)\n","\n","        if self.transform:\n","            # Apply the same transformation to both the image and the label\n","            image = self.transform(image)\n","            label = self.transform(label)\n","\n","        return image, label\n","\n","resize_dim = 512 # resize to dimensions\n","batch_size = 4  # Adjust batch size according to your GPU memory capacity\n","    \n","# Define transformations (resize, convert to tensor, normalize)\n","transform = transforms.Compose([\n","    transforms.Resize((resize_dim, resize_dim)),  # Resize images\n","    transforms.ToTensor(),          # Convert images to PyTorch tensors\n","    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1] for each channel\n","])\n","\n","# Create dataset instances for training\n","images_dir = '/kaggle/working/images/'  # Replace with your image directory path\n","labels_dir = '/kaggle/working/labels/'  # Replace with your label directory path\n","dataset = ImageLabelDataset(images_dir, labels_dir, transform=transform)\n","\n","# Create DataLoader with batch size control\n","dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)"]},{"cell_type":"markdown","metadata":{},"source":["## TRAINING SCRIPT"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:25:08.871609Z","iopub.status.busy":"2024-09-21T23:25:08.871150Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.utils.data import DataLoader\n","from tqdm import tqdm\n","import os\n","\n","# Instantiate the model\n","model = ViTImage2Image(img_size=resize_dim, patch_size=16, emb_dim=768, num_heads=16, num_layers=8, hidden_dim=3072)\n","\n","# Move model to the appropriate device (GPU if available)\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","# If using 2 x T4 GPU only\n","model = nn.DataParallel(model, device_ids = [0,1])\n","\n","# Move model to device\n","model = model.to(device)\n","\n","# Optimizer\n","optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n","\n","# Number of epochs\n","num_epochs = 100\n","\n","best_loss = float('inf')\n","\n","# Path to save the best model\n","save_dir = 'saved_models'\n","os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist\n","best_model_path = os.path.join(save_dir, 'best_model.pth')\n","\n","# Initialize best loss as a large value\n","best_loss = float('inf')\n","\n","# Training loop with tqdm\n","for epoch in range(num_epochs):\n","    model.train()\n","    running_loss = 0.0\n","    running_psnr = 0.0\n","    \n","    # Wrap dataloader with tqdm for progress bar\n","    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f\"Epoch {epoch+1}/{num_epochs}\")\n","    \n","    for batch_idx, (input, target) in pbar:\n","        # Move data to the same device as the model\n","        input, target = input.to(device), target.to(device)\n","        \n","        optimizer.zero_grad()  # Clear the gradients from the last step\n","        output = model(input, target)  # Forward pass\n","#         print(f\"Input shape: {input.shape}, Output shape: {output.shape}, Target shape: {target.shape}\")\n","        loss = combined_loss(output, target)  # Compute the loss\n","        \n","        loss.backward()  # Backward pass (compute gradients)\n","        optimizer.step()  # Update the weights\n","        \n","        running_loss += loss.item()  # Accumulate the loss for this batch\n","        \n","        # Calculate metrics\n","        current_psnr = psnr(output, target).item()\n","        running_psnr += current_psnr\n","        \n","        # Update progress bar\n","        pbar.set_postfix({\n","            'loss': loss.item(),\n","            'psnr': current_psnr,\n","        })\n","    \n","    # Calculate the average loss and metrics for the epoch\n","    epoch_loss = running_loss / len(dataloader)\n","    avg_psnr = running_psnr / len(dataloader)\n","    \n","    print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, PSNR: {avg_psnr:.4f}\")\n","    \n","    # Save the best model based on the lowest loss\n","    if epoch_loss < best_loss:\n","        best_loss = epoch_loss\n","        torch.save(model.state_dict(), best_model_path)\n","        print(f\"New best model saved at epoch {epoch+1} with loss {best_loss:.4f}\")\n","\n","print(\"Training complete\")"]},{"cell_type":"markdown","metadata":{},"source":["## CLEAR CUDA CACHE"]},{"cell_type":"code","execution_count":193,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:25:05.712036Z","iopub.status.busy":"2024-09-21T23:25:05.711264Z","iopub.status.idle":"2024-09-21T23:25:05.899856Z","shell.execute_reply":"2024-09-21T23:25:05.898677Z","shell.execute_reply.started":"2024-09-21T23:25:05.711993Z"},"trusted":true},"outputs":[{"data":{"text/plain":["0"]},"execution_count":193,"metadata":{},"output_type":"execute_result"}],"source":["import gc\n","torch.cuda.empty_cache()\n","gc.collect()\n","#RuntimeError: The size of tensor a (768) must match the size of tensor b (512) at non-singleton dimension 3"]},{"cell_type":"code","execution_count":180,"metadata":{"execution":{"iopub.execute_input":"2024-09-21T23:24:04.448611Z","iopub.status.busy":"2024-09-21T23:24:04.448163Z","iopub.status.idle":"2024-09-21T23:24:04.462532Z","shell.execute_reply":"2024-09-21T23:24:04.461381Z","shell.execute_reply.started":"2024-09-21T23:24:04.448567Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["223543305\n","DataParallel(\n","  (module): ViTImage2Image(\n","    (patch_embed): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n","    (encoder_layers): ModuleList(\n","      (0-7): 8 x TransformerEncoderBlock(\n","        (attn): LocationBasedMultiheadAttention(\n","          (q_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (k_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (v_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (out_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (dropout): Dropout(p=0.1, inplace=False)\n","        )\n","        (ff): Sequential(\n","          (0): Linear(in_features=768, out_features=3072, bias=True)\n","          (1): ReLU()\n","          (2): Linear(in_features=3072, out_features=768, bias=True)\n","        )\n","        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (adain): AdaIN(\n","          (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n","          (fc): Linear(in_features=768, out_features=1536, bias=True)\n","        )\n","        (dropout): Dropout(p=0.1, inplace=False)\n","      )\n","    )\n","    (decoder_layers): ModuleList(\n","      (0-7): 8 x TransformerDecoderBlock(\n","        (attn1): LocationBasedMultiheadAttention(\n","          (q_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (k_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (v_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (out_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (dropout): Dropout(p=0.1, inplace=False)\n","        )\n","        (attn2): LocationBasedMultiheadAttention(\n","          (q_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (k_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (v_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (out_proj): Linear(in_features=768, out_features=768, bias=True)\n","          (dropout): Dropout(p=0.1, inplace=False)\n","        )\n","        (ff): Sequential(\n","          (0): Linear(in_features=768, out_features=3072, bias=True)\n","          (1): ReLU()\n","          (2): Linear(in_features=3072, out_features=768, bias=True)\n","        )\n","        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (norm4): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (adain1): AdaIN(\n","          (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n","          (fc): Linear(in_features=768, out_features=1536, bias=True)\n","        )\n","        (adain2): AdaIN(\n","          (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n","          (fc): Linear(in_features=768, out_features=1536, bias=True)\n","        )\n","        (dropout): Dropout(p=0.1, inplace=False)\n","      )\n","    )\n","    (swin_layers): ModuleList(\n","      (0-7): 8 x SwinTransformerBlock(\n","        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (attn): MultiheadAttention(\n","          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n","        )\n","        (mlp): Sequential(\n","          (0): Linear(in_features=768, out_features=3072, bias=True)\n","          (1): GELU(approximate='none')\n","          (2): Linear(in_features=3072, out_features=768, bias=True)\n","        )\n","        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","      )\n","    )\n","    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","    (mlp_head): Sequential(\n","      (0): Linear(in_features=768, out_features=3072, bias=True)\n","      (1): GELU(approximate='none')\n","      (2): Linear(in_features=3072, out_features=768, bias=True)\n","    )\n","    (refinement): RefinementBlock(\n","      (conv): Conv2d(768, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","      (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n","      (relu): ReLU(inplace=True)\n","    )\n","    (style_encoder): Sequential(\n","      (0): Conv2d(3, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n","      (1): ReLU()\n","      (2): AdaptiveAvgPool2d(output_size=1)\n","      (3): Flatten(start_dim=1, end_dim=-1)\n","      (4): Linear(in_features=768, out_features=768, bias=True)\n","    )\n","  )\n",")\n"]}],"source":["total_params = sum(p.numel() for p in model.parameters())\n","print(total_params)\n","print(model)"]},{"cell_type":"markdown","metadata":{},"source":["## INFERENCE"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-21T23:16:44.175025Z","iopub.status.idle":"2024-09-21T23:16:44.184902Z","shell.execute_reply":"2024-09-21T23:16:44.184625Z","shell.execute_reply.started":"2024-09-21T23:16:44.184592Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torchvision.transforms as transforms\n","from PIL import Image\n","import matplotlib.pyplot as plt\n","\n","def load_image(image_path, img_size=1024):\n","    image = Image.open(image_path).convert('RGB')\n","    transform = transforms.Compose([\n","        transforms.Resize((img_size, img_size)),\n","        transforms.ToTensor(),\n","        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n","    ])\n","    return transform(image).unsqueeze(0)\n","\n","def perform_inference(model, input_image):\n","    model.eval()\n","    with torch.no_grad():\n","        output = model(input_image)\n","    return output\n","\n","def visualize_tensor(tensor, title):\n","    img = tensor.squeeze().cpu().permute(1, 2, 0).numpy()\n","    img = (img - img.min()) / (img.max() - img.min())  # Normalize to [0, 1]\n","    plt.imshow(img)\n","    plt.title(title)\n","    plt.axis('off')\n","    plt.show()\n","\n","def main():\n","    # Load the model (adjust parameters as needed)\n","    model = ViTImage2Image(img_size=1024, patch_size=8, emb_dim=768, num_heads=12, num_layers=12, hidden_dim=3072)\n","    model.load_state_dict(torch.load('/kaggle/working/saved_models/best_model.pth'))\n","    \n","    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","    model = model.to(device)\n","    \n","    # Load and preprocess input image\n","    input_image = load_image('/kaggle/working/labels/00143.png')\n","    input_image = input_image.to(device)\n","    \n","    # Perform inference\n","    output = perform_inference(model, input_image)\n","    \n","    # Debug: Print shape and statistics of the output\n","    print(\"Output shape:\", output.shape)\n","    print(\"Output min:\", output.min().item())\n","    print(\"Output max:\", output.max().item())\n","    print(\"Output mean:\", output.mean().item())\n","    print(\"Output std:\", output.std().item())\n","    \n","    # Visualize input and output\n","    visualize_tensor(input_image, \"Input Image\")\n","    visualize_tensor(output, \"Output Image (Before Processing)\")\n","    \n","    # Try different post-processing steps\n","    # 1. Sigmoid activation\n","    output_sigmoid = torch.sigmoid(output)\n","    visualize_tensor(output_sigmoid, \"Output Image (After Sigmoid)\")\n","    \n","    # 2. Tanh activation\n","    output_tanh = torch.tanh(output)\n","    visualize_tensor(output_tanh, \"Output Image (After Tanh)\")\n","    \n","    # 3. Min-Max Normalization\n","    output_normalized = (output - output.min()) / (output.max() - output.min())\n","    visualize_tensor(output_normalized, \"Output Image (After Min-Max Normalization)\")\n","\n","if __name__ == \"__main__\":\n","    main()"]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30761,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"}},"nbformat":4,"nbformat_minor":4}