aoxo committed
Commit
6fb541c
1 Parent(s): 62d76b0

Upload realformerv4.ipynb

Files changed (1)
  1. realformerv4.ipynb +1 -0
realformerv4.ipynb ADDED
Cell 1: model definition (patch embedding, style-adaptive normalization, transformer encoder/decoder blocks, Swin blocks, the RealFormerAGA model, and the loss/metric helpers).

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange

# Patch Embedding with Dynamic Positional Encoding
class DynamicPatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=8, emb_dim=768, img_size=256):
        super(DynamicPatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        # Reshape the input to 4D if necessary
        if len(x.shape) == 2:
            batch_size = x.shape[0]
            channels = 3  # Assuming 3 feature channels
            h = w = int(math.sqrt(x.shape[1] // channels))  # Infer height and width
            x = x.view(batch_size, channels, h, w)  # Reshape to [batch_size, channels, height, width]

        x = self.proj(x)  # Apply Conv2d
        x = x.flatten(2).transpose(1, 2)  # (batch_size, num_patches, emb_dim)
        return x

# Style Adaptive Layer Normalization (SALN)
class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super(StyleAdaptiveLayerNorm, self).__init__()
        self.norm = nn.LayerNorm(emb_dim)
        self.fc = nn.Linear(emb_dim, emb_dim * 2)

    def forward(self, x, style):
        style = self.fc(style).unsqueeze(1)
        gamma, beta = style.chunk(2, dim=-1)
        normalized_x = self.norm(x)
        return gamma * normalized_x + beta

# LayerNorm-based Attention Conditioning using Pre-learned Attention Weights
class AttentionConditioning(nn.Module):
    def __init__(self, emb_dim, learned_attn_weights):
        super(AttentionConditioning, self).__init__()
        self.learned_attn_weights = learned_attn_weights
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, untrained_attn_weights):
        # Condition untrained weights using pre-learned attention weights
        conditioned_weights = self.learned_attn_weights + untrained_attn_weights
        return self.norm(conditioned_weights)

# Cross-Attention Layer
class CrossAttentionLayer(nn.Module):
    def __init__(self, emb_dim, num_heads, dropout=0.1):
        super(CrossAttentionLayer, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context):
        attn_output, _ = self.attn(x, context, context)
        return self.dropout(attn_output)

# Transformer Encoder Block with Pre-learned Attention Conditioning
class TransformerEncoderBlock(nn.Module):
    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=3072, learned_attn_weights=None, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()
        self.attn = CrossAttentionLayer(emb_dim, num_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
        )
        self.norm1 = AttentionConditioning(emb_dim, learned_attn_weights) if learned_attn_weights is not None else nn.LayerNorm(emb_dim)
        self.norm2 = StyleAdaptiveLayerNorm(emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, style):
        attn_output = self.attn(x, x)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        ff_output = self.ff(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x, style)  # StyleAdaptiveLayerNorm needs the style vector
        return x

# Transformer Decoder Block
class TransformerDecoderBlock(nn.Module):
    def __init__(self, emb_dim=768, num_heads=8, hidden_dim=3072, dropout=0.1):
        super(TransformerDecoderBlock, self).__init__()
        self.attn1 = CrossAttentionLayer(emb_dim, num_heads, dropout)
        self.attn2 = CrossAttentionLayer(emb_dim, num_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
        )
        self.norm1 = StyleAdaptiveLayerNorm(emb_dim)
        self.norm2 = StyleAdaptiveLayerNorm(emb_dim)
        self.norm3 = StyleAdaptiveLayerNorm(emb_dim)

    def forward(self, x, enc_output, style):
        attn_output1 = self.attn1(x, x)
        x = x + attn_output1
        x = self.norm1(x, style)

        attn_output2 = self.attn2(x, enc_output)
        x = x + attn_output2
        x = self.norm2(x, style)

        ff_output = self.ff(x)
        x = x + ff_output
        x = self.norm3(x, style)

        return x

# Swin Transformer Block
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7, shift_size=2):
        super(SwinTransformerBlock, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x)
        x = shortcut + x

        shortcut = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = shortcut + x

        return x

# Refinement Block
class RefinementBlock(nn.Module):
    def __init__(self, in_channels=768, out_channels=3, kernel_size=3, stride=1, padding=1):
        super(RefinementBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class MultiScaleStyleEncoder(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.conv = nn.Conv2d(3, emb_dim, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.AdaptiveAvgPool2d(16)
        self.pool2 = nn.AdaptiveAvgPool2d(8)
        self.pool3 = nn.AdaptiveAvgPool2d(4)
        self.fc = nn.Linear(emb_dim * (16*16 + 8*8 + 4*4), emb_dim)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x1 = self.pool1(x).flatten(1)
        x2 = self.pool2(x).flatten(1)
        x3 = self.pool3(x).flatten(1)
        x = torch.cat([x1, x2, x3], dim=1)
        return self.fc(x)

# Main RealFormer v4 with AGA Attention Conditioning
class RealFormerAGA(nn.Module):
    def __init__(self, img_size=512, patch_size=8, emb_dim=768, num_heads=12, num_layers=12, hidden_dim=3072, window_size=8, learned_attn_weights=None):
        super(RealFormerAGA, self).__init__()
        self.patch_embed = DynamicPatchEmbedding(in_channels=3, patch_size=patch_size, emb_dim=emb_dim, img_size=img_size)

        # Encoder with pre-learned attention conditioning
        self.encoder_layers = nn.ModuleList([TransformerEncoderBlock(emb_dim, num_heads, hidden_dim, learned_attn_weights) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([TransformerDecoderBlock(emb_dim, num_heads, hidden_dim) for _ in range(num_layers)])
        self.swin_layers = nn.ModuleList([SwinTransformerBlock(emb_dim, num_heads, window_size) for _ in range(num_layers)])

        self.refinement = RefinementBlock(in_channels=emb_dim, out_channels=3)
        self.final_layer = nn.Conv2d(3, 3, kernel_size=1)  # Adjust the input channels to 3

        # Style encoder
        self.style_encoder = MultiScaleStyleEncoder(emb_dim)

    def forward(self, frame_t, frame_t1, learned_attn_weights=None):
        # learned_attn_weights is optional: conditioning is wired in at construction time,
        # and the training loop below calls the model with two frames only.
        # Patch embedding for consecutive frames
        x_t = self.patch_embed(frame_t)
        x_t1 = self.patch_embed(frame_t1)

        # Style encoding from previous frames
        style_features = self.style_encoder(frame_t1)

        # Conditioning learned attention weights prior to untrained attention
        for encoder in self.encoder_layers:
            x_t = encoder(x_t, style_features)

        # Transformer decoder to reconstruct the next frame
        for decoder in self.decoder_layers:
            x_t1 = decoder(x_t1, x_t, style_features)

        # Swin Transformer processing for temporal coherence
        for swin in self.swin_layers:
            x_t1 = swin(x_t1)

        # Final refinement and output
        batch_size, num_patches, emb_dim = x_t1.shape
        h = w = int(math.sqrt(num_patches))  # Assuming square patches
        # reshape (not view) because transpose makes the tensor non-contiguous
        x_t1 = x_t1.transpose(1, 2).reshape(batch_size, emb_dim, h, w)

        x_t1 = self.refinement(x_t1)
        x_t1 = self.final_layer(x_t1)
        return x_t1

# Loss functions remain the same
def total_variation_loss(x):
    return torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))

def combined_loss(output, target):
    l1_loss = nn.L1Loss()(output, target)
    tv_loss = total_variation_loss(output)
    return l1_loss + 0.0001 * tv_loss

def psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * torch.log10(1.0 / torch.sqrt(mse))
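Not part of the notebook: a minimal smoke test, sketched under the assumption of 256x256 RGB frames, a patch size of 8 and a reduced num_layers, just to make the tensor shapes concrete. With 8-pixel patches the decoder works on a 32x32 token grid, so the reconstruction comes back at patch resolution rather than full image resolution.

# Hypothetical shape check (not in the original notebook); small num_layers so it runs on CPU.
model = RealFormerAGA(img_size=256, patch_size=8, emb_dim=768,
                      num_heads=12, num_layers=2, hidden_dim=3072)
model.eval()

frame_t = torch.randn(1, 3, 256, 256)   # frame at time t
frame_t1 = torch.randn(1, 3, 256, 256)  # frame at time t+1

with torch.no_grad():
    out = model(frame_t, frame_t1)

print(out.shape)  # torch.Size([1, 3, 32, 32]): 256 / 8 = 32 patches per side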
Cell 2: feature-map dataset and data loader.

from torch.utils.data import Dataset, DataLoader
import json
import os  # needed for os.path.join below

class FeatureMapDataset(Dataset):
    def __init__(self, frames_dir, real_dir, json_file):
        self.frames_dir = frames_dir
        self.real_dir = real_dir

        with open(json_file, 'r') as f:
            self.mappings = json.load(f)

        self.frame_files = list(self.mappings.keys())  # List of frame filenames

    def __len__(self):
        return len(self.frame_files)

    def __getitem__(self, idx):
        frame_file = self.frame_files[idx]
        real_images = self.mappings[frame_file]

        # Load frame feature map
        frame_feature = torch.load(os.path.join(self.frames_dir, frame_file))

        # Load top real-world image feature maps
        real_features = [torch.load(os.path.join(self.real_dir, img[0])) for img in real_images]

        # Extract the top real image and its similarity score
        top_real_feature = real_features[0]
        top_similarity = real_images[0][1]

        return frame_feature, top_real_feature, top_similarity, real_features[1:]

# Define data loaders
frames_dir = '/kaggle/working/frames_features/'
real_dir = '/kaggle/working/real_features/'
json_file = '/kaggle/working/*.json'  # or gtav_mapillary.json; open() needs a concrete filename here, not a glob

dataset = FeatureMapDataset(frames_dir, real_dir, json_file)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
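__getitem__ indexes self.mappings[frame_file] as a ranked list of [real_feature_filename, similarity_score] pairs, so the JSON file presumably has the structure sketched below; the filenames and scores are made up for illustration.

# Illustrative only: placeholder names, but the nesting matches what FeatureMapDataset expects.
example_mappings = {
    "frame_000001.pt": [
        ["real_04213.pt", 0.91],   # top match: used as the positive / target feature
        ["real_01187.pt", 0.87],   # remaining entries become negatives
        ["real_09934.pt", 0.85],
    ],
}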
Cell 3: free cached GPU memory between runs.

import torch
torch.cuda.empty_cache()

Cell 4: contrastive training loop. The saved output shows this cell was executed before the data-loader cell, so the run ended with NameError: name 'dataloader' is not defined.

import os
import torch.distributed as dist  # required by setup_distributed below

def setup_distributed():
    # Not called in this notebook; assumes an `args` namespace providing local_rank (e.g. from argparse)
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

def contrastive_loss(anchor, positive, negatives, margin=0.2):
    # Cosine similarity between anchor and positive
    pos_sim = F.cosine_similarity(anchor, positive, dim=-1)

    # Cosine similarity between anchor and all negative examples
    neg_sims = [F.cosine_similarity(anchor, neg, dim=-1) for neg in negatives]

    # Calculate loss
    loss = 0.0
    for neg_sim in neg_sims:
        loss += torch.clamp(margin + neg_sim - pos_sim, min=0.0)  # Margin-based contrastive loss

    return loss.mean()

# Training script
def train_contrastive(model, dataloader, optimizer, num_epochs=10, margin=0.2):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.DataParallel(model, device_ids=[0, 1])
    model.to(device)

    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_psnr = 0.0

        for batch_idx, (frame_feature, top_real_feature, top_similarity, other_real_features) in enumerate(dataloader):
            frame_feature = frame_feature.to(device)
            top_real_feature = top_real_feature.to(device)
            other_real_features = [neg.to(device) for neg in other_real_features]

            optimizer.zero_grad()

            # Forward pass
            output = model(frame_feature, top_real_feature)

            # Compute contrastive loss
            loss = contrastive_loss(output, top_real_feature, other_real_features, margin=margin)

            # Backpropagation and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # PSNR metric computation
            psnr_value = psnr(output, top_real_feature)
            running_psnr += psnr_value

            # Print training status
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item()}, PSNR: {psnr_value:.4f}")

        # Epoch-level metrics
        epoch_loss = running_loss / len(dataloader)
        avg_psnr = running_psnr / len(dataloader)

        print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {epoch_loss:.4f}, Avg PSNR: {avg_psnr:.4f}")

        # Save the best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), 'realformerv4.pth')
            print(f"Model saved at epoch {epoch+1} with loss {best_loss:.4f}")

# Optimizer setup
model = RealFormerAGA(img_size=256, patch_size=1, emb_dim=768, num_heads=32, num_layers=16, hidden_dim=3072)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Start training
train_contrastive(model, dataloader, optimizer, num_epochs=50, margin=0.2)

Cell 5: upload every .pth checkpoint in the working directory to the Hugging Face Hub.

import os
from huggingface_hub import login, HfApi

# Login to Hugging Face Hub
login(token="")

# Initialize the Hugging Face API
api = HfApi()

# Specify the directory containing the models
model_directory = "/kaggle/working/"
repo_id = "aoxo/RealFormer"
repo_type = "model"

# Loop through all files in the model directory
for filename in os.listdir(model_directory):
    # Only upload files that end with .pth
    if filename.endswith(".pth"):
        file_path = os.path.join(model_directory, filename)
        path_in_repo = filename  # Use the same filename in the repo

        # Upload the model file to the repository
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
        )
        print(f"Uploaded {filename} to {repo_id} repository.")
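A rough sketch, not in the notebook, of how a checkpoint pushed this way could be pulled back and loaded; the repo id and filename follow the cells above, and the snippet is untested against the actual repository.

from huggingface_hub import hf_hub_download
import torch

ckpt_path = hf_hub_download(repo_id="aoxo/RealFormer", filename="realformerv4.pth")
state_dict = torch.load(ckpt_path, map_location="cpu")

# Checkpoints written inside train_contrastive come from an nn.DataParallel wrapper,
# so their keys carry a "module." prefix; stripping it is a no-op for plain saves.
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

# Same hyperparameters as used at save time
model = RealFormerAGA(img_size=256, patch_size=1, emb_dim=768,
                      num_heads=32, num_layers=16, hidden_dim=3072)
model.load_state_dict(state_dict)
model.eval()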
Cell 6: parameter count, checkpoint export in FP32/FP16/BF16, and dynamic INT8 quantization.

total_params = sum(p.numel() for p in model.parameters())
print(total_params)
print(model)
torch.save(model.state_dict(), 'realformerv4.pth')
# Convert model to FP16 and save
model.half()
torch.save(model.state_dict(), 'realformerv4_fp16.pth')
# Convert model to BF16 and save
model.to(torch.bfloat16)
torch.save(model.state_dict(), 'realformerv4_bf16.pth')

import torch.quantization as quantization

# Apply dynamic quantization to the model's Linear layers
model_int8 = quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Save the INT8 quantized model
torch.save(model_int8.state_dict(), 'realformerv4_int8.pth')

Saved output: 651863061 (total parameter count), followed by the printed RealFormerAGA module tree: DynamicPatchEmbedding with Conv2d(3, 768, kernel_size 1x1); 16 x TransformerEncoderBlock (768-dim MultiheadAttention, 768 -> 3072 -> 768 feed-forward, LayerNorm plus StyleAdaptiveLayerNorm with a 768 -> 1536 fc); 16 x TransformerDecoderBlock (two cross-attention layers, three StyleAdaptiveLayerNorms); 16 x SwinTransformerBlock (768-dim attention, GELU MLP); RefinementBlock with Conv2d(768, 3, 3x3), BatchNorm2d and ReLU; final Conv2d(3, 3, 1x1); MultiScaleStyleEncoder with Conv2d(3, 768, stride 2), adaptive pools at 16/8/4 and Linear(258048, 768).

Notebook metadata: Kaggle kernel (NVIDIA Tesla T4, GPU and internet enabled), Python 3.10.11, one attached dataset version (datasetId 5825636, sourceId 9559983).
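To reuse realformerv4_int8.pth later, the same dynamic-quantization transform has to be applied to a freshly built model before loading the state dict; a sketch under the assumption that the checkpoint was written exactly as in the cell above.

import torch
import torch.quantization as quantization

# Rebuild the architecture with the same hyperparameters used at save time
model = RealFormerAGA(img_size=256, patch_size=1, emb_dim=768,
                      num_heads=32, num_layers=16, hidden_dim=3072)

# Re-apply dynamic quantization so the module structure matches the saved state dict;
# non-quantized weights are cast to the fresh model's dtype when loaded.
model_int8 = quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
model_int8.load_state_dict(torch.load('realformerv4_int8.pth', map_location='cpu'))
model_int8.eval()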