{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2024-10-06T14:52:05.628410Z","iopub.status.busy":"2024-10-06T14:52:05.627990Z","iopub.status.idle":"2024-10-06T14:52:17.944735Z","shell.execute_reply":"2024-10-06T14:52:17.943542Z","shell.execute_reply.started":"2024-10-06T14:52:05.628372Z"},"trusted":true},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import math\n","from einops import rearrange\n","\n","# Patch Embedding with Dynamic Positional Encoding\n","class DynamicPatchEmbedding(nn.Module):\n","    def __init__(self, in_channels=3, patch_size=8, emb_dim=768, img_size=256):\n","        super(DynamicPatchEmbedding, self).__init__()\n","        self.patch_size = patch_size\n","        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)\n","        self.num_patches = (img_size // patch_size) ** 2\n","\n","    def forward(self, x):\n","        # Reshape the input to 4D if necessary\n","        if len(x.shape) == 2:\n","            batch_size = x.shape[0]\n","            channels = 3  # Assuming 3 feature channels\n","            h = w = int(math.sqrt(x.shape[1] // channels))  # Infer height and width\n","            x = x.view(batch_size, channels, h, w)  # Reshape to [batch_size, channels, height, width]\n","        \n","        x = self.proj(x)  # Apply Conv2d\n","        x = x.flatten(2).transpose(1, 2)  # (batch_size, num_patches, emb_dim)\n","        return x\n","    \n","# Style Adaptive Layer Normalization (SALN)\n","class StyleAdaptiveLayerNorm(nn.Module):\n","    def __init__(self, emb_dim):\n","        super(StyleAdaptiveLayerNorm, self).__init__()\n","        self.norm = nn.LayerNorm(emb_dim)\n","        self.fc = nn.Linear(emb_dim, emb_dim * 2)\n","\n","    def forward(self, x, style):\n","        style = self.fc(style).unsqueeze(1)\n","        gamma, beta = style.chunk(2, dim=-1)\n","        normalized_x = self.norm(x)\n","        return gamma * normalized_x + beta\n","\n","# LayerNorm-based Attention Conditioning using Pre-learned Attention Weights\n","class AttentionConditioning(nn.Module):\n","    def __init__(self, emb_dim, learned_attn_weights):\n","        super(AttentionConditioning, self).__init__()\n","        self.learned_attn_weights = learned_attn_weights\n","        self.norm = nn.LayerNorm(emb_dim)\n","\n","    def forward(self, untrained_attn_weights):\n","        # Condition untrained weights using pre-learned attention weights\n","        conditioned_weights = self.learned_attn_weights + untrained_attn_weights\n","        return self.norm(conditioned_weights)\n","\n","# Cross-Attention Layer\n","class CrossAttentionLayer(nn.Module):\n","    def __init__(self, emb_dim, num_heads, dropout=0.1):\n","        super(CrossAttentionLayer, self).__init__()\n","        self.attn = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, batch_first=True)\n","        self.dropout = nn.Dropout(dropout)\n","\n","    def forward(self, x, context):\n","        attn_output, _ = self.attn(x, context, context)\n","        return self.dropout(attn_output)\n","\n","# Transformer Encoder Block with Pre-learned Attention Conditioning\n","class TransformerEncoderBlock(nn.Module):\n","    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=3, learned_attn_weights=None, dropout=0.1):\n","        super(TransformerEncoderBlock, self).__init__()\n","        self.attn = CrossAttentionLayer(emb_dim, num_heads, dropout)\n","        self.ff = 
nn.Sequential(\n","            nn.Linear(emb_dim, hidden_dim),\n","            nn.ReLU(),\n","            nn.Linear(hidden_dim, emb_dim),\n","        )\n","        self.norm1 = AttentionConditioning(emb_dim, learned_attn_weights) if learned_attn_weights is not None else nn.LayerNorm(emb_dim)\n","        self.norm2 = StyleAdaptiveLayerNorm(emb_dim)\n","        self.dropout = nn.Dropout(dropout)\n","\n","    def forward(self, x, style):\n","        attn_output = self.attn(x, x)\n","        x = x + self.dropout(attn_output)\n","        x = self.norm1(x)\n","\n","        ff_output = self.ff(x)\n","        x = x + self.dropout(ff_output)\n","        x = self.norm2(x)\n","        return x\n","\n","# Transformer Decoder Block\n","class TransformerDecoderBlock(nn.Module):\n","    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=3, dropout=0.1):\n","        super(TransformerDecoderBlock, self).__init__()\n","        self.attn1 = CrossAttentionLayer(emb_dim, num_heads, dropout)\n","        self.attn2 = CrossAttentionLayer(emb_dim, num_heads, dropout)\n","        self.ff = nn.Sequential(\n","            nn.Linear(emb_dim, hidden_dim),\n","            nn.ReLU(),\n","            nn.Linear(hidden_dim, emb_dim),\n","        )\n","        self.norm1 = StyleAdaptiveLayerNorm(emb_dim)\n","        self.norm2 = StyleAdaptiveLayerNorm(emb_dim)\n","        self.norm3 = StyleAdaptiveLayerNorm(emb_dim)\n","\n","    def forward(self, x, enc_output, style):\n","        attn_output1 = self.attn1(x, x)\n","        x = x + attn_output1\n","        x = self.norm1(x)\n","\n","        attn_output2 = self.attn2(x, enc_output)\n","        x = x + attn_output2\n","        x = self.norm2(x)\n","\n","        ff_output = self.ff(x)\n","        x = x + ff_output\n","        x = self.norm3(x)\n","\n","        return x\n","\n","# Swin Transformer Block\n","class SwinTransformerBlock(nn.Module):\n","    def __init__(self, dim, num_heads, window_size=7, shift_size=2):\n","        super(SwinTransformerBlock, self).__init__()\n","        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)\n","        self.mlp = nn.Sequential(\n","            nn.Linear(dim, 4 * dim),\n","            nn.GELU(),\n","            nn.Linear(4 * dim, dim)\n","        )\n","        self.norm1 = nn.LayerNorm(dim)\n","        self.norm2 = nn.LayerNorm(dim)\n","\n","    def forward(self, x):\n","        shortcut = x\n","        x = self.norm1(x)\n","        x, _ = self.attn(x, x, x)\n","        x = shortcut + x\n","\n","        shortcut = x\n","        x = self.norm2(x)\n","        x = self.mlp(x)\n","        x = shortcut + x\n","\n","        return x\n","\n","# Refinement Block\n","class RefinementBlock(nn.Module):\n","    def __init__(self, in_channels=768, out_channels=3, kernel_size=3, stride=1, padding=1):\n","        super(RefinementBlock, self).__init__()\n","        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)\n","        self.bn = nn.BatchNorm2d(out_channels)\n","        self.relu = nn.ReLU(inplace=True)\n","\n","    def forward(self, x):\n","        x = self.conv(x)\n","        x = self.bn(x)\n","        x = self.relu(x)\n","        return x\n","\n","class MultiScaleStyleEncoder(nn.Module):\n","    def __init__(self, emb_dim):\n","        super().__init__()\n","        self.conv = nn.Conv2d(3, emb_dim, kernel_size=3, stride=2, padding=1)\n","        self.pool1 = nn.AdaptiveAvgPool2d(16)\n","        self.pool2 = nn.AdaptiveAvgPool2d(8)\n","        self.pool3 = 
nn.AdaptiveAvgPool2d(4)\n","        self.fc = nn.Linear(emb_dim * (16*16 + 8*8 + 4*4), emb_dim)\n","\n","    def forward(self, x):\n","        x = F.relu(self.conv(x))\n","        x1 = self.pool1(x).flatten(1)\n","        x2 = self.pool2(x).flatten(1)\n","        x3 = self.pool3(x).flatten(1)\n","        x = torch.cat([x1, x2, x3], dim=1)\n","        return self.fc(x)\n","\n","# Main RealFormer v3 with AGA Attention Conditioning\n","class RealFormerAGA(nn.Module):\n","    def __init__(self, img_size=512, patch_size=8, emb_dim=768, num_heads=12, num_layers=12, hidden_dim=3072, window_size=8, learned_attn_weights=None):\n","        super(RealFormerAGA, self).__init__()\n","        self.patch_embed = DynamicPatchEmbedding(in_channels=3, patch_size=patch_size, emb_dim=emb_dim, img_size=img_size)\n","\n","        # Encoder with pre-learned attention conditioning\n","        self.encoder_layers = nn.ModuleList([TransformerEncoderBlock(emb_dim, num_heads, hidden_dim, learned_attn_weights) for _ in range(num_layers)])\n","        self.decoder_layers = nn.ModuleList([TransformerDecoderBlock(emb_dim, num_heads, hidden_dim) for _ in range(num_layers)])\n","        self.swin_layers = nn.ModuleList([SwinTransformerBlock(emb_dim, num_heads, window_size) for _ in range(num_layers)])\n","\n","        self.refinement = RefinementBlock(in_channels=emb_dim, out_channels=3)\n","        self.final_layer = nn.Conv2d(3, 3, kernel_size=1)  # Adjust the input channels to 3\n","\n","        # Style encoder\n","        self.style_encoder = MultiScaleStyleEncoder(emb_dim)\n","\n","    def forward(self, frame_t, frame_t1, learned_attn_weights):\n","        # Patch embedding for consecutive frames\n","        x_t = self.patch_embed(frame_t)\n","        x_t1 = self.patch_embed(frame_t1)\n","\n","        # Style encoding from previous frames\n","        style_features = self.style_encoder(frame_t1)\n","\n","        # Conditioning learned attention weights prior to untrained attention\n","        for encoder in self.encoder_layers:\n","            x_t = encoder(x_t, style_features)\n","\n","        # Transformer decoder to reconstruct the next frame\n","        for decoder in self.decoder_layers:\n","            x_t1 = decoder(x_t1, x_t, style_features)\n","\n","        # Swin Transformer processing for temporal coherence\n","        for swin in self.swin_layers:\n","            x_t1 = swin(x_t1)\n","\n","        # Final refinement and output\n","        batch_size, num_patches, emb_dim = x_t1.shape\n","        h = w = int(math.sqrt(num_patches))  # Assuming square patches\n","        x_t1 = x_t1.transpose(1, 2).view(batch_size, emb_dim, h, w)\n","        \n","        x_t1 = self.refinement(x_t1)\n","        x_t1 = self.final_layer(x_t1)\n","        return x_t1\n","\n","# Loss functions remain the same\n","def total_variation_loss(x):\n","    return torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))\n","\n","def combined_loss(output, target):\n","    l1_loss = nn.L1Loss()(output, target)\n","    tv_loss = total_variation_loss(output)\n","    return l1_loss + 0.0001 * tv_loss\n","\n","def psnr(img1, img2):\n","    mse = torch.mean((img1 - img2) ** 2)\n","    if mse == 0:\n","        return float('inf')\n","    return 20 * torch.log10(1.0 / 
torch.sqrt(mse))"]},
{"cell_type":"code","execution_count":154,"metadata":{"execution":{"iopub.execute_input":"2024-10-06T14:46:11.045039Z","iopub.status.busy":"2024-10-06T14:46:11.044572Z","iopub.status.idle":"2024-10-06T14:46:11.868049Z","shell.execute_reply":"2024-10-06T14:46:11.867070Z","shell.execute_reply.started":"2024-10-06T14:46:11.044981Z"},"trusted":true},"outputs":[],"source":["from torch.utils.data import Dataset, DataLoader\n","import json\n","import os\n","\n","class FeatureMapDataset(Dataset):\n","    def __init__(self, frames_dir, real_dir, json_file):\n","        self.frames_dir = frames_dir\n","        self.real_dir = real_dir\n","\n","        with open(json_file, 'r') as f:\n","            self.mappings = json.load(f)\n","\n","        self.frame_files = list(self.mappings.keys())  # List of frame filenames\n","\n","    def __len__(self):\n","        return len(self.frame_files)\n","\n","    def __getitem__(self, idx):\n","        frame_file = self.frame_files[idx]\n","        real_images = self.mappings[frame_file]\n","\n","        # Load frame feature map\n","        frame_feature = torch.load(os.path.join(self.frames_dir, frame_file))\n","\n","        # Load top real world image feature maps\n","        real_features = [torch.load(os.path.join(self.real_dir, img[0])) for img in real_images]\n","\n","        # Extract the top real image and its similarity score\n","        top_real_feature = real_features[0]\n","        top_similarity = real_images[0][1]\n","\n","        return frame_feature, top_real_feature, top_similarity, real_features[1:]\n","\n","# Define data loaders\n","frames_dir = '/kaggle/working/frames_features/'\n","real_dir = '/kaggle/working/real_features/'\n","json_file = '/kaggle/working/*.json'  # or gtav_mapillary.json\n","\n","dataset = FeatureMapDataset(frames_dir, real_dir, json_file)\n","dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)"]},
{"cell_type":"code","execution_count":151,"metadata":{"execution":{"iopub.execute_input":"2024-10-06T14:45:42.810459Z","iopub.status.busy":"2024-10-06T14:45:42.810043Z","iopub.status.idle":"2024-10-06T14:45:42.815339Z","shell.execute_reply":"2024-10-06T14:45:42.814119Z","shell.execute_reply.started":"2024-10-06T14:45:42.810415Z"},"trusted":true},"outputs":[],"source":["import torch\n","torch.cuda.empty_cache()"]},
{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-06T15:51:43.264775Z","iopub.status.busy":"2024-10-06T15:51:43.263890Z","iopub.status.idle":"2024-10-06T15:51:43.422577Z","shell.execute_reply":"2024-10-06T15:51:43.421024Z","shell.execute_reply.started":"2024-10-06T15:51:43.264719Z"},"trusted":true},"outputs":[],"source":["import os\n","import torchvision\n","import torch.distributed as dist\n","\n","def setup_distributed():\n","    dist.init_process_group(backend='nccl')\n","    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))  # assume the launcher (e.g. torchrun) sets LOCAL_RANK\n","\n","def contrastive_loss(anchor, positive, negatives, margin=0.2):\n","    # Cosine similarity between anchor and positive\n","    pos_sim = F.cosine_similarity(anchor, positive, dim=-1)\n","\n","    # Cosine similarity between anchor and all negative examples\n","    neg_sims = [F.cosine_similarity(anchor, neg, dim=-1) for neg in negatives]\n","\n","    # Calculate loss\n","    loss = 0.0\n","    for neg_sim in neg_sims:\n","        loss += torch.clamp(margin + neg_sim - pos_sim, min=0.0)  # Margin-based contrastive loss\n","\n","    return loss.mean()\n","\n","# Training script\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class MultiClassNPairLoss(nn.Module):\n","    def __init__(self, temperature=0.07):\n","        super().__init__()\n","        self.temperature = temperature\n","\n","    def forward(self, anchor, positives, negatives):\n","        # Normalize embeddings\n","        anchor = F.normalize(anchor, dim=-1)\n","        positives = F.normalize(positives, dim=-1)\n","        negatives = F.normalize(negatives, dim=-1)\n","\n","        # Compute similarities\n","        pos_sim = torch.sum(anchor * positives, dim=-1) / self.temperature\n","        neg_sims = torch.sum(anchor.unsqueeze(1) * negatives, dim=-1) / self.temperature\n","\n","        # Compute loss\n","        logits = torch.cat([pos_sim.unsqueeze(1), neg_sims], dim=1)\n","        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)\n","        return F.cross_entropy(logits, labels)\n","\n","def train_contrastive(model, dataloader, feature_extractor, optimizer, num_epochs=10, alpha=1.0, beta=0.5):\n","    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","    model = nn.DataParallel(model).to(device)\n","    feature_extractor = feature_extractor.to(device)\n","    contrastive_loss_fn = MultiClassNPairLoss()\n","\n","    for epoch in range(num_epochs):\n","        model.train()\n","        running_loss = 0.0\n","\n","        for batch_idx, (synthetic_frame, real_images, similarity_scores) in enumerate(dataloader):\n","            synthetic_frame = synthetic_frame.to(device)\n","            real_images = [img.to(device) for img in real_images]\n","            similarity_scores = similarity_scores.to(device)\n","\n","            optimizer.zero_grad()\n","\n","            # Forward pass\n","            reconstructed_frame = model(synthetic_frame)\n","\n","            # Extract features from real images\n","            real_features = [feature_extractor(img) for img in real_images]\n","\n","            # Identify positive and negative pairs\n","            positive_idx = torch.argmax(similarity_scores, dim=1)\n","            positive_features = torch.stack([real_features[i][idx] for i, idx in enumerate(positive_idx)])\n","            negative_features = torch.stack([feat for i, feat in enumerate(real_features) if i != positive_idx])\n","\n","            # Compute losses\n","            reconstruction_loss = F.mse_loss(reconstructed_frame, synthetic_frame)\n","            contrastive_loss_value = contrastive_loss_fn(reconstructed_frame, positive_features, negative_features)  # use the instantiated loss module; assigning to the name contrastive_loss would shadow the function above\n","\n","            # Combine losses\n","            total_loss = alpha * reconstruction_loss + beta * contrastive_loss_value\n","\n","            # Backpropagation and optimization\n","            total_loss.backward()\n","            optimizer.step()\n","\n","            running_loss += total_loss.item()\n","\n","            if batch_idx % 10 == 0:\n","                print(f\"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], \"\n","                      f\"Loss: {total_loss.item():.4f}, Recon Loss: {reconstruction_loss.item():.4f}, \"\n","                      f\"Contrastive Loss: {contrastive_loss_value.item():.4f}\")\n","\n","        epoch_loss = running_loss / len(dataloader)\n","        print(f\"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {epoch_loss:.4f}\")\n","\n","    torch.save(model.state_dict(), 'realformer_contrastivev4.pth')\n","\n","model = RealFormerAGA(img_size=256, patch_size=1, emb_dim=768, num_heads=32, num_layers=16, hidden_dim=3072)\n","feature_extractor = nn.Sequential(*list(torchvision.models.resnet50(pretrained=True).children())[:-1])  # ResNet-50 has no .features attribute; keep every layer up to the final fc\n","optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n","\n","train_contrastive(model, dataloader, feature_extractor, optimizer, num_epochs=50, alpha=1.0, beta=0.5)"]},
{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-06T15:15:47.858959Z","iopub.status.busy":"2024-10-06T15:15:47.858119Z","iopub.status.idle":"2024-10-06T15:16:25.296512Z","shell.execute_reply":"2024-10-06T15:16:25.295405Z","shell.execute_reply.started":"2024-10-06T15:15:47.858916Z"},"trusted":true},"outputs":[],"source":["import os\n","from huggingface_hub import login, HfApi\n","\n","# Login to Hugging Face Hub\n","login(token=\"\")\n","\n","# Initialize the Hugging Face API\n","api = HfApi()\n","\n","# Specify the directory containing the models\n","model_directory = \"/kaggle/working/\"\n","repo_id = \"aoxo/RealFormer\"\n","repo_type = \"model\"\n","\n","# Loop through all files in the model directory\n","for filename in os.listdir(model_directory):\n","    # Only upload files that end with .pth\n","    if filename.endswith(\".pth\"):\n","        file_path = os.path.join(model_directory, filename)\n","        path_in_repo = filename  # Use the same filename in the repo\n","        \n","        # Upload the model file to the repository\n","        api.upload_file(\n","            path_or_fileobj=file_path,\n","            path_in_repo=path_in_repo,\n","            repo_id=repo_id,\n","            repo_type=repo_type,\n","        )\n","        print(f\"Uploaded {filename} to {repo_id} repository.\")"]},
{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-10-06T15:51:48.356977Z","iopub.status.busy":"2024-10-06T15:51:48.355877Z","iopub.status.idle":"2024-10-06T15:51:59.708112Z","shell.execute_reply":"2024-10-06T15:51:59.706952Z","shell.execute_reply.started":"2024-10-06T15:51:48.356933Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["651863061\n","RealFormerAGA(\n","  (patch_embed): DynamicPatchEmbedding(\n","    (proj): Conv2d(3, 768, kernel_size=(1, 1), stride=(1, 1))\n","  )\n","  (encoder_layers): ModuleList(\n","    (0-15): 16 x TransformerEncoderBlock(\n","      (attn): CrossAttentionLayer(\n","        (attn): MultiheadAttention(\n","          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n","        )\n","        (dropout): Dropout(p=0.1, inplace=False)\n","      )\n","      (ff): Sequential(\n","        (0): Linear(in_features=768, out_features=3072, bias=True)\n","        (1): ReLU()\n","        (2): Linear(in_features=3072, out_features=768, bias=True)\n","      )\n","      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","      (norm2): StyleAdaptiveLayerNorm(\n","        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (fc): Linear(in_features=768, out_features=1536, bias=True)\n","      )\n","      (dropout): Dropout(p=0.1, inplace=False)\n","    )\n","  )\n","  (decoder_layers): ModuleList(\n","    (0-15): 16 x TransformerDecoderBlock(\n","      (attn1): CrossAttentionLayer(\n","        (attn): MultiheadAttention(\n","          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n","        )\n","        (dropout): Dropout(p=0.1, inplace=False)\n","      )\n","      (attn2): CrossAttentionLayer(\n","        (attn): MultiheadAttention(\n","          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, 
out_features=768, bias=True)\n","        )\n","        (dropout): Dropout(p=0.1, inplace=False)\n","      )\n","      (ff): Sequential(\n","        (0): Linear(in_features=768, out_features=3072, bias=True)\n","        (1): ReLU()\n","        (2): Linear(in_features=3072, out_features=768, bias=True)\n","      )\n","      (norm1): StyleAdaptiveLayerNorm(\n","        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (fc): Linear(in_features=768, out_features=1536, bias=True)\n","      )\n","      (norm2): StyleAdaptiveLayerNorm(\n","        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (fc): Linear(in_features=768, out_features=1536, bias=True)\n","      )\n","      (norm3): StyleAdaptiveLayerNorm(\n","        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","        (fc): Linear(in_features=768, out_features=1536, bias=True)\n","      )\n","    )\n","  )\n","  (swin_layers): ModuleList(\n","    (0-15): 16 x SwinTransformerBlock(\n","      (attn): MultiheadAttention(\n","        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n","      )\n","      (mlp): Sequential(\n","        (0): Linear(in_features=768, out_features=3072, bias=True)\n","        (1): GELU(approximate='none')\n","        (2): Linear(in_features=3072, out_features=768, bias=True)\n","      )\n","      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n","    )\n","  )\n","  (refinement): RefinementBlock(\n","    (conv): Conv2d(768, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n","    (relu): ReLU(inplace=True)\n","  )\n","  (final_layer): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))\n","  (style_encoder): MultiScaleStyleEncoder(\n","    (conv): Conv2d(3, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n","    (pool1): AdaptiveAvgPool2d(output_size=16)\n","    (pool2): AdaptiveAvgPool2d(output_size=8)\n","    (pool3): AdaptiveAvgPool2d(output_size=4)\n","    (fc): Linear(in_features=258048, out_features=768, bias=True)\n","  )\n",")\n"]}],"source":["total_params = sum(p.numel() for p in model.parameters())\n","print(total_params)\n","print(model)\n","# torch.save(model.state_dict(), 'realformerv4.pth')\n","# Convert model to FP16 and save\n","model.half()\n","torch.save(model.state_dict(), 'realformerv4_fp16.pth')\n","# Convert model to BF16 and save\n","model.to(torch.bfloat16)\n","torch.save(model.state_dict(), 'realformerv4_bf16.pth')\n","import torch.quantization as quantization\n","\n","# Apply static quantization to the model\n","model_int8 = quantization.quantize_dynamic(\n","    model, {torch.nn.Linear}, dtype=torch.qint8\n",")\n","\n","# Save the INT8 quantized model\n","torch.save(model_int8.state_dict(), 'realformerv4_int8.pth')"]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"datasetId":5825636,"sourceId":9559983,"sourceType":"datasetVersion"}],"dockerImageVersionId":30786,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"tf","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"}},"nbformat":4,"nbformat_minor":4}