Spaces:
Runtime error
Runtime error
File size: 3,066 Bytes
b442155 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
from typing import Optional, List
from dataclasses import dataclass, field
from omegaconf import OmegaConf
@dataclass
class DataConfig:
dataset: Optional[str] = None
tokenizer_type: str = 'CharBPE'
context_length: int = 64
image_resolution: int = 256
transforms: str = 'dalle-vqvae'
bpe_pdrop: Optional[float] = None
@dataclass
class Stage1Hparams:
double_z: bool = False
z_channels: int = 256
resolution: int = 256
in_channels: int = 3
out_ch: int = 3
ch: int = 128
ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
num_res_blocks: int = 2
attn_resolutions: List[int] = field(default_factory=lambda: [16])
pdrop: float = 0.0
@dataclass
class Stage2Hparams:
embed_dim: int = 1536
n_layers: int = 42
n_heads: int = 24
n_dense_layers: int = 42
ctx_len_img: int = 256
ctx_len_txt: int = 64
embd_pdrop: float = 0.0
resid_pdrop: float = 0.0
attn_pdrop: float = 0.0
mlp_bias: bool = True
attn_bias: bool = True
gelu_use_approx: bool = False
use_head_txt: bool = True
n_classes: Optional[int] = None
@dataclass
class Stage1Config:
type: str = 'vqgan'
embed_dim: int = 256
n_embed: int = 16384
hparams: Stage1Hparams = Stage1Hparams()
@dataclass
class Stage2Config:
type: str = 'transformer1d'
vocab_size_txt: int = 16384
vocab_size_img: int = 16384
use_cls_cond: Optional[bool] = None
hparams: Stage2Hparams = Stage2Hparams()
@dataclass
class WarmupConfig:
epoch: int = 1
multiplier: int = 1
buffer_epoch: int = 0
min_lr: float = 0.0
mode: str = 'fix'
peak_lr: float = 1e-4
start_from_zero: bool = True
@dataclass
class OptConfig:
opt_type: str = 'adamW'
base_lr: float = 1e-4
weight_decay: float = 1e-4
betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
grad_clip_norm: float = 1.0
sched_type: str = 'cosine'
max_steps: int = 0
min_lr: float = 0.0
@dataclass
class ExpConfig:
local_batch_size: int = 4
total_batch_size: int = 512
valid_batch_size: int = 32
epochs: int = 10
save_ckpt_freq: int = 2
test_freq: int = 1
use_amp: bool = True
@dataclass
class DefaultConfig:
dataset: DataConfig = DataConfig()
stage1: Stage1Config = Stage1Config()
stage2: Stage2Config = Stage2Config()
@dataclass
class FineTuningConfig:
dataset: DataConfig = DataConfig()
stage1: Stage1Config = Stage1Config()
stage2: Stage2Config = Stage2Config()
optimizer: OptConfig = OptConfig()
experiment: ExpConfig = ExpConfig()
def get_base_config(use_default=True):
return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
|