aoxo
/

Image-to-Image
English
art
File size: 15,769 Bytes
c59452d
1
2
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T14:41:17.988745Z","iopub.status.busy":"2024-09-22T14:41:17.988306Z","iopub.status.idle":"2024-09-22T14:41:17.994222Z","shell.execute_reply":"2024-09-22T14:41:17.993225Z","shell.execute_reply.started":"2024-09-22T14:41:17.988710Z"},"trusted":true},"outputs":[],"source":["import os\n","import requests\n","import zipfile\n","import io\n","\n","def lesgooo(name_it):\n","    response = requests.get(name_it, stream=True)\n","\n","    # Create an in-memory bytes buffer for the zip file\n","    zip_buffer = io.BytesIO()\n","\n","    # Download the file in chunks and write to the in-memory buffer\n","    for chunk in response.iter_content(chunk_size=1024):\n","        if chunk:\n","            zip_buffer.write(chunk)\n","            \n","    # Seek to the beginning of the buffer\n","    zip_buffer.seek(0)\n","\n","    # Open the zip file in memory\n","    with zipfile.ZipFile(zip_buffer, 'r') as zip_ref:\n","        # Extract all files directly to your desired directory\n","        zip_ref.extractall('/kaggle/working/')\n","\n","# Function to download and unzip files\n","def download_and_unzip(index):\n","    images_zip = f\"https://download.visinf.tu-darmstadt.de/data/from_games/data/{str(index).zfill(2)}_images.zip\"\n","    labels_zip = f\"https://download.visinf.tu-darmstadt.de/data/from_games/data/{str(index).zfill(2)}_labels.zip\"\n","    \n","    lesgooo(images_zip)\n","    lesgooo(labels_zip)\n","\n","# Loop through indices 1 to 10\n","for i in range(1, 11):\n","    download_and_unzip(i)\n","    print(f\"Part {i} done\")"]},{"cell_type":"code","execution_count":11,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:21:09.902762Z","iopub.status.busy":"2024-09-22T16:21:09.902339Z","iopub.status.idle":"2024-09-22T16:21:09.929962Z","shell.execute_reply":"2024-09-22T16:21:09.929117Z","shell.execute_reply.started":"2024-09-22T16:21:09.902721Z"},"trusted":true},"outputs":[],"source":["import os\n","from PIL import Image\n","from torch.utils.data import Dataset, DataLoader\n","from torchvision import transforms\n","\n","# Custom Dataset class\n","class ImageLabelDataset(Dataset):\n","    def __init__(self, images_dir, labels_dir, transform=None):\n","        self.images_dir = images_dir\n","        self.labels_dir = labels_dir\n","        self.image_filenames = sorted(os.listdir(images_dir))\n","        self.label_filenames = sorted(os.listdir(labels_dir))\n","        self.transform = transform\n","\n","        assert len(self.image_filenames) == len(self.label_filenames), \"Mismatch in number of images and labels\"\n","\n","    def __len__(self):\n","        return len(self.image_filenames)\n","\n","    def __getitem__(self, idx):\n","        image_path = os.path.join(self.images_dir, self.image_filenames[idx])\n","        label_path = os.path.join(self.labels_dir, self.label_filenames[idx])\n","        \n","        # Open image and label using PIL\n","        image = Image.open(image_path).convert('RGB')  # Ensure image is RGB\n","        label = Image.open(label_path).convert('RGB')  # Ensure label is RGB (or use 'L' for grayscale)\n","\n","        if self.transform:\n","            # Apply the same transformation to both the image and the label\n","            image = self.transform(image)\n","            label = self.transform(label)\n","\n","        return image, label\n","\n","resize_dim = 512 # resize to dimensions\n","batch_size = 12 # Adjust batch size according to your GPU memory capacity\n","    \n","# Define transformations (resize, convert to tensor, normalize)\n","transform = transforms.Compose([\n","    transforms.Resize((resize_dim, resize_dim)),  # Resize images\n","    transforms.ToTensor(),          # Convert images to PyTorch tensors\n","    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1] for each channel\n","])\n","\n","# Create dataset instances for training\n","images_dir = '/kaggle/working/images/'  # Replace with your image directory path\n","labels_dir = '/kaggle/working/labels/'  # Replace with your label directory path\n","dataset = ImageLabelDataset(images_dir, labels_dir, transform=transform)\n","\n","# Create DataLoader with batch size control\n","dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)"]},{"cell_type":"code","execution_count":2,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2024-09-22T16:16:46.786170Z","iopub.status.busy":"2024-09-22T16:16:46.785554Z","iopub.status.idle":"2024-09-22T16:16:46.810539Z","shell.execute_reply":"2024-09-22T16:16:46.809524Z","shell.execute_reply.started":"2024-09-22T16:16:46.786122Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class DepthwiseSeparableConv(nn.Module):\n","    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):\n","        super(DepthwiseSeparableConv, self).__init__()\n","        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)\n","        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)\n","\n","    def forward(self, x):\n","        x = self.depthwise(x)\n","        x = self.pointwise(x)\n","        return x\n","\n","class LinearAttention(nn.Module):\n","    def __init__(self, dim, heads=4, dim_head=64):\n","        super(LinearAttention, self).__init__()\n","        self.heads = heads\n","        self.scale = dim_head ** -0.5\n","        inner_dim = dim_head * heads\n","        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)\n","        self.to_out = nn.Conv2d(inner_dim, dim, 1)\n","\n","    def forward(self, x):\n","        b, c, h, w = x.shape\n","        qkv = self.to_qkv(x).chunk(3, dim=1)\n","        q, k, v = map(lambda t: t.reshape(b, self.heads, -1, h * w), qkv)\n","        q = q * self.scale\n","        k = k.softmax(dim=-1)\n","        context = torch.einsum('bhdn,bhen->bhde', k, v)\n","        out = torch.einsum('bhde,bhdn->bhen', context, q)\n","        out = out.reshape(b, -1, h, w)\n","        return self.to_out(out)\n","\n","class ResidualBlock(nn.Module):\n","    def __init__(self, channels):\n","        super(ResidualBlock, self).__init__()\n","        self.conv1 = DepthwiseSeparableConv(channels, channels, 3, padding=1)\n","        self.in1 = nn.InstanceNorm2d(channels)\n","        self.conv2 = DepthwiseSeparableConv(channels, channels, 3, padding=1)\n","        self.in2 = nn.InstanceNorm2d(channels)\n","\n","    def forward(self, x):\n","        residual = x\n","        out = F.relu(self.in1(self.conv1(x)))\n","        out = self.in2(self.conv2(out))\n","        out += residual\n","        return F.relu(out)\n","\n","class EfficientRealformer(nn.Module):\n","    def __init__(self, input_channels=3, output_channels=3, base_channels=64, num_residuals=6):\n","        super(EfficientRealformer, self).__init__()\n","        \n","        # Encoder\n","        self.encoder = nn.Sequential(\n","            DepthwiseSeparableConv(input_channels, base_channels, 7, padding=3),\n","            nn.InstanceNorm2d(base_channels),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels, base_channels * 2, 3, stride=2, padding=1),\n","            nn.InstanceNorm2d(base_channels * 2),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels * 2, base_channels * 4, 3, stride=2, padding=1),\n","            nn.InstanceNorm2d(base_channels * 4),\n","            nn.ReLU(inplace=True)\n","        )\n","        \n","        # Transformer blocks\n","        self.transformer_blocks = nn.ModuleList([\n","            nn.Sequential(\n","                LinearAttention(base_channels * 4),\n","                ResidualBlock(base_channels * 4)\n","            ) for _ in range(num_residuals)\n","        ])\n","        \n","        # Decoder\n","        self.decoder = nn.Sequential(\n","            nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 3, stride=2, padding=1, output_padding=1),\n","            nn.InstanceNorm2d(base_channels * 2),\n","            nn.ReLU(inplace=True),\n","            nn.ConvTranspose2d(base_channels * 2, base_channels, 3, stride=2, padding=1, output_padding=1),\n","            nn.InstanceNorm2d(base_channels),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels, output_channels, 7, padding=3)\n","        )\n","        \n","        # Style encoder\n","        self.style_encoder = nn.Sequential(\n","            DepthwiseSeparableConv(input_channels, base_channels, 3, stride=2, padding=1),\n","            nn.InstanceNorm2d(base_channels),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels, base_channels * 2, 3, stride=2, padding=1),\n","            nn.InstanceNorm2d(base_channels * 2),\n","            nn.ReLU(inplace=True),\n","            DepthwiseSeparableConv(base_channels * 2, base_channels * 4, 3, stride=2, padding=1),\n","            nn.AdaptiveAvgPool2d(1),\n","            nn.Flatten(),\n","            nn.Linear(base_channels * 4, base_channels * 4)\n","        )\n","\n","    def forward(self, content, style):\n","        # Encode content\n","        x = self.encoder(content)\n","        \n","        # Extract style features\n","        style_features = self.style_encoder(style)\n","        \n","        # Apply transformer blocks with style injection\n","        for block in self.transformer_blocks:\n","            x = block(x)\n","            x = x + style_features.view(*style_features.shape, 1, 1)\n","        \n","        # Decode\n","        output = self.decoder(x)\n","        return torch.tanh(output)  # Ensure output is in [-1, 1] range"]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:16:46.811827Z","iopub.status.busy":"2024-09-22T16:16:46.811543Z","iopub.status.idle":"2024-09-22T16:16:46.822285Z","shell.execute_reply":"2024-09-22T16:16:46.821363Z","shell.execute_reply.started":"2024-09-22T16:16:46.811794Z"},"trusted":true},"outputs":[],"source":["def total_variation_loss(x):\n","    return torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))\n","\n","def combined_loss(output, target):\n","    l1_loss = nn.L1Loss()(output, target)\n","    tv_loss = total_variation_loss(output)\n","    return l1_loss + 0.0001 * tv_loss\n","\n","def psnr(img1, img2):\n","    mse = torch.mean((img1 - img2) ** 2)\n","    if mse == 0:\n","        return float('inf')\n","    return 20 * torch.log10(1.0 / torch.sqrt(mse))"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:21:12.853117Z","iopub.status.busy":"2024-09-22T16:21:12.852217Z","iopub.status.idle":"2024-09-22T17:16:46.272031Z","shell.execute_reply":"2024-09-22T17:16:46.268249Z","shell.execute_reply.started":"2024-09-22T16:21:12.853074Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.utils.data import DataLoader\n","from tqdm import tqdm\n","import os\n","\n","# Instantiate the model\n","model = EfficientRealformer(input_channels=3, output_channels=3, base_channels=64, num_residuals=6)\n","\n","# Move model to the appropriate device (GPU if available)\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","# If using 2 x T4 GPU only\n","model = nn.DataParallel(model, device_ids = [0,1])\n","\n","# Move model to device\n","model = model.to(device)\n","\n","# Optimizer\n","optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n","\n","# Number of epochs\n","num_epochs = 20\n","\n","best_loss = float('inf')\n","\n","# Path to save the best model\n","save_dir = 'saved_models'\n","os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist\n","best_model_path = os.path.join(save_dir, 'best_model.pth')\n","\n","# Initialize best loss as a large value\n","best_loss = float('inf')\n","\n","# Training loop with tqdm\n","for epoch in range(num_epochs):\n","    model.train()\n","    running_loss = 0.0\n","    running_psnr = 0.0\n","    \n","    # Wrap dataloader with tqdm for progress bar\n","    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f\"Epoch {epoch+1}/{num_epochs}\")\n","    \n","    for batch_idx, (input, target) in pbar:\n","        # Move data to the same device as the model\n","        input, target = input.to(device), target.to(device)\n","        \n","        optimizer.zero_grad()  # Clear the gradients from the last step\n","        output = model(input, target)  # Forward pass\n","#         print(f\"Input shape: {input.shape}, Output shape: {output.shape}, Target shape: {target.shape}\")\n","        loss = combined_loss(output, target)  # Compute the loss\n","        \n","        loss.backward()  # Backward pass (compute gradients)\n","        optimizer.step()  # Update the weights\n","        \n","        running_loss += loss.item()  # Accumulate the loss for this batch\n","        \n","        # Calculate metrics\n","        current_psnr = psnr(output, target).item()\n","        running_psnr += current_psnr\n","        \n","        # Update progress bar\n","        pbar.set_postfix({\n","            'loss': loss.item(),\n","            'psnr': current_psnr,\n","        })\n","    \n","    # Calculate the average loss and metrics for the epoch\n","    epoch_loss = running_loss / len(dataloader)\n","    avg_psnr = running_psnr / len(dataloader)\n","    \n","    print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, PSNR: {avg_psnr:.4f}\")\n","    \n","    # Save the best model based on the lowest loss\n","    if epoch_loss < best_loss:\n","        best_loss = epoch_loss\n","        torch.save(model.state_dict(), best_model_path)\n","        print(f\"New best model saved at epoch {epoch+1} with loss {best_loss:.4f}\")\n","\n","print(\"Training complete\")"]},{"cell_type":"code","execution_count":10,"metadata":{"execution":{"iopub.execute_input":"2024-09-22T16:20:38.886138Z","iopub.status.busy":"2024-09-22T16:20:38.885742Z","iopub.status.idle":"2024-09-22T16:20:39.010224Z","shell.execute_reply":"2024-09-22T16:20:39.009275Z","shell.execute_reply.started":"2024-09-22T16:20:38.886101Z"},"trusted":true},"outputs":[{"data":{"text/plain":["0"]},"execution_count":10,"metadata":{},"output_type":"execute_result"}],"source":["import gc\n","torch.cuda.empty_cache()\n","gc.collect()\n","#RuntimeError: The size of tensor a (768) must match the size of tensor b (512) at non-singleton dimension 3"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-22T16:17:14.709235Z","iopub.status.idle":"2024-09-22T16:17:14.709692Z","shell.execute_reply":"2024-09-22T16:17:14.709509Z","shell.execute_reply.started":"2024-09-22T16:17:14.709481Z"},"trusted":true},"outputs":[],"source":["total_params = sum(p.numel() for p in model.parameters())\n","print(total_params)\n","print(model)"]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30762,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"}},"nbformat":4,"nbformat_minor":4}