File size: 1,617 Bytes
745a6b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import tensorflow as tf

def residual_block(inputs, filters):
    """Apply a two-convolution residual block.

    Main path: Conv2D(3x3, ReLU) -> Conv2D(3x3, no activation). The result is
    summed with a shortcut of ``inputs`` and passed through a final ReLU. Both
    convolutions use stride 1 and ``padding='same'``, so spatial size is kept.

    Args:
        inputs: Keras tensor fed to the block. Assumes channels-last layout
            (the Keras default data_format); TODO confirm if the project
            changes ``image_data_format``.
        filters: Number of output channels of both convolutions.

    Returns:
        Keras tensor with ``filters`` channels.
    """
    shortcut = inputs
    # `add` requires matching channel counts. When the block changes the
    # width, project the shortcut with a 1x1 conv instead of failing at
    # graph-construction time. Identity shortcuts (the common case) create
    # no extra layer.
    if inputs.shape[-1] != filters:
        shortcut = tf.keras.layers.Conv2D(filters, (1, 1), padding='same')(inputs)

    x = tf.keras.layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(inputs)
    x = tf.keras.layers.Conv2D(filters, (3, 3), padding='same')(x)
    x = tf.keras.layers.add([shortcut, x])
    return tf.keras.layers.Activation('relu')(x)

def get_model():
    """Build the fully-convolutional image-to-image residual network.

    Every layer uses 3x3 kernels, stride 1 and ``padding='same'``, so the
    spatial size is preserved end to end. The layers, in order:

    * three Conv2D pairs with 32, 64 and 128 filters (ReLU);
    * five 128-filter residual blocks (see ``residual_block``);
    * two Conv2DTranspose layers with 64 and 32 filters (ReLU);
    * a 3-filter sigmoid Conv2D whose output is added to the input image.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping a ``(batch, H, W, 3)``
        tensor to a tensor of the same shape; H and W may be any size.
    """
    inputs = tf.keras.layers.Input(shape=(None, None, 3))

    conv1 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)
    conv1 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(conv1)

    conv2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(conv1)
    conv2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(conv2)

    conv3 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(conv2)
    # Chain from conv3, like the conv1/conv2 pairs above. Feeding conv2 here
    # would drop the previous 128-filter conv from the graph entirely.
    conv3 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(conv3)

    x = conv3
    for _ in range(5):
        x = residual_block(x, 128)

    deconv1 = tf.keras.layers.Conv2DTranspose(64, (3, 3), padding='same', activation='relu')(x)
    deconv2 = tf.keras.layers.Conv2DTranspose(32, (3, 3), padding='same', activation='relu')(deconv1)

    outputs = tf.keras.layers.Conv2D(3, (3, 3), padding='same', activation='sigmoid')(deconv2)
    # NOTE(review): the sigmoid head lies in (0, 1) and is added to the raw
    # input, so the output is not confined to the input's value range.
    # Confirm this is intended; a linear head is more usual for residual
    # prediction.
    outputs = tf.keras.layers.add([inputs, outputs])

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return model