Coverage for mlair/model_modules/u_networks.py: 0%
2 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
# Module metadata: author and creation date of this module.
__author__ = "Lukas Leufen"
__date__ = "2022-08-29"
5from functools import partial
7from mlair.helpers import select_from_dict, to_list
8from mlair.model_modules.convolutional_networks import CNNfromConfig
9import tensorflow.keras as keras
class UNet(CNNfromConfig):  # pragma: no cover
    """
    A U-net neural network.

    The network is built from a sequential ``layer_configuration`` in the same way as :class:`CNNfromConfig`, extended
    by two pseudo layer types that create the U-net skip connections:

    * ``{"type": "blocksave"}`` pushes the current tensor onto a stack of saved tensors.
    * ``{"type": "ConcatenateUNet"}`` pops the most recently saved tensor (last saved, first used) and concatenates it
      with the current tensor.

    Every ``ConcatenateUNet`` therefore needs a preceding, not yet consumed ``blocksave``. Because the tensors are
    concatenated (keras ``Concatenate``, last axis by default), the saved tensor and the upsampled tensor must agree in
    all dimensions except the channel axis. In the example below this is achieved by saving the tensor *before* each
    pooling step and by choosing an input length that is divisible by 2 ** 4 = 16 (four (2, 1) poolings, mirrored by
    four (2, 1) strided transposed convolutions).

    ```python
    input_shape = [(64,1,9)]
    output_shape = [(4, )]

    # model
    layer_configuration=[

        # 1st block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 2nd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 3rd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 4th block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 5th block (final down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 256, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 256, "padding": "same"},

        # 6th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 128, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},

        # 7th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 64, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},

        # 8th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 32, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},

        # 9th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 16, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},

        # Tail
        {"type": "Flatten"},
        {"type": "Dense", "units": 128, "activation": "relu"}
    ]

    model = UNet(input_shape, output_shape, layer_configuration)
    ```

    """

    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
        """
        Create the U-net.

        All arguments are passed unchanged to :class:`CNNfromConfig`.

        :param input_shape: list with the input shape(s), see class docstring for an example
        :param output_shape: list with the output shape(s), see class docstring for an example
        :param layer_configuration: list of layer option dicts, may contain the pseudo types ``blocksave`` and
            ``ConcatenateUNet``
        :param optimizer: optimizer passed on to the parent class
        """
        super().__init__(input_shape, output_shape, layer_configuration, optimizer=optimizer, **kwargs)

    def _extract_layer_conf(self, layer_opts):
        """
        Translate a single layer option dict into ``(layer, layer_kwargs, follow_up_layer)``.

        The U-net pseudo layers are handled here; every other type is delegated to the parent class.

        * ``ConcatenateUNet`` -> ``(keras.layers.Concatenate, None, None)``
        * ``blocksave`` -> ``("blocksave", None, None)`` (a string marker, not a keras layer)

        :param layer_opts: layer option dict containing at least the key ``"type"``
        :return: tuple of layer (class or marker), keyword arguments for the layer, and follow-up layer(s)
        """
        layer_type = layer_opts["type"]
        if layer_type == "ConcatenateUNet":
            # Reference the class directly so that a missing attribute fails loudly instead of yielding None.
            return keras.layers.Concatenate, None, None
        if layer_type == "blocksave":
            return "blocksave", None, None
        return super()._extract_layer_conf(layer_opts)

    def set_model(self):
        """
        Build the functional keras model described by ``self.conf`` and store it in ``self.model``.

        Layers are applied one after another. A ``blocksave`` entry pushes the current tensor onto a stack. A
        concatenation layer pops the most recently saved tensor and concatenates it with the current tensor. The
        network is closed by a dense layer of size ``self._output_shape`` followed by the configured output
        activation. Finally, the model summary is printed.

        :raises IndexError: if a concatenation is requested while no saved tensor is left on the stack
        """
        x_input = keras.layers.Input(shape=self._input_shape)
        x_in = x_input
        block_save = []  # stack of skip-connection tensors, consumed last-in-first-out

        for pos, layer_opts in enumerate(self.conf):
            layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
            if layer == "blocksave":
                block_save.append(x_in)
                continue
            layer_name = self._get_layer_name(layer, layer_kwargs, pos)
            if "Concatenate" in layer_name:
                if not block_save:
                    raise IndexError(f"Layer {layer_name} at position {pos} requires a preceding 'blocksave' entry, "
                                     f"but no saved tensor is left. Each concatenation must be matched by exactly one "
                                     f"earlier 'blocksave'.")
                x_in = layer(name=layer_name)([x_in, block_save.pop()])
                self._layer_save.append({"layer": layer, "follow_up_layer": follow_up_layer})
                continue
            x_in = layer(**layer_kwargs, name=layer_name)(x_in)
            if follow_up_layer is not None:
                for follow_up in to_list(follow_up_layer):
                    layer_name = self._get_layer_name(follow_up, None, pos)
                    x_in = follow_up(name=layer_name)(x_in)
            self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer})

        x_in = keras.layers.Dense(self._output_shape)(x_in)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        # summary() prints by itself and returns None; wrapping it in print() would emit a stray "None".
        self.model.summary()