Coverage for mlair/model_modules/branched_input_networks.py: 100%
import logging
from functools import partial, reduce
import copy
from typing import Union

from tensorflow import keras as keras

from mlair import AbstractModelClass
from mlair.helpers import select_from_dict, to_list
from mlair.model_modules.loss import var_loss
from mlair.model_modules.recurrent_networks import RNN
from mlair.model_modules.convolutional_networks import CNNfromConfig
from mlair.model_modules.residual_networks import ResNet
from mlair.model_modules.u_networks import UNet


class BranchedInputCNN(CNNfromConfig):  # pragma: no cover
    """A convolutional neural network with multiple input branches."""

    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):

        super().__init__([input_shape], output_shape, layer_configuration, optimizer=optimizer, **kwargs)
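    # Illustrative usage only (not part of the original module). The dict-based layer_configuration format is
    # assumed from the BranchedInputResNet example further down in this file; the exact set of supported "type"
    # values is defined by CNNfromConfig._extract_layer_conf. Layers listed after the "Concatenate" entry are
    # applied to the merged branches (see set_model below).
    #
    #     input_shape = [(65, 1, 9), (65, 1, 9)]
    #     output_shape = [(4, )]
    #     layer_configuration = [
    #         {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
    #         {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
    #         {"type": "Flatten"},
    #         {"type": "Concatenate"},
    #         {"type": "Dense", "units": 128, "activation": "relu"}
    #     ]
    #     model = BranchedInputCNN(input_shape, output_shape, layer_configuration)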

    def set_model(self):

        x_input = []
        x_in = []
        stop_pos = None

        for branch in range(len(self._input_shape)):
            print(branch)
            shape_b = self._input_shape[branch]
            x_input_b = keras.layers.Input(shape=shape_b, name=f"input_branch{branch + 1}")
            x_input.append(x_input_b)
            x_in_b = x_input_b
            b_conf = copy.deepcopy(self.conf)

            for pos, layer_opts in enumerate(b_conf):
                print(layer_opts)
                if layer_opts.get("type") == "Concatenate":
                    if stop_pos is None:
                        stop_pos = pos
                    else:
                        assert pos == stop_pos
                    break
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                layer_name = self._get_layer_name(layer, layer_kwargs, pos, branch)
                x_in_b = layer(**layer_kwargs, name=layer_name)(x_in_b)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos, branch)
                        x_in_b = follow_up(name=layer_name)(x_in_b)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": branch})
            x_in.append(x_in_b)

        print("concat")
        x_concat = keras.layers.Concatenate()(x_in)

        if stop_pos is not None:
            for pos, layer_opts in enumerate(self.conf[stop_pos + 1:]):
                print(layer_opts)
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                layer_name = self._get_layer_name(layer, layer_kwargs, pos + stop_pos, None)
                x_concat = layer(**layer_kwargs, name=layer_name)(x_concat)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos + stop_pos, None)
                        x_concat = follow_up(name=layer_name)(x_concat)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": "concat"})

        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        print(self.model.summary())

    @staticmethod
    def _get_layer_name(layer: keras.layers, layer_kwargs: Union[dict, None], pos: int, branch: int = None):
        if isinstance(layer, partial):
            name = layer.args[0] if layer.func.__name__ == "Activation" else layer.func.__name__
        else:
            name = layer.__name__
        if "Conv" in name and isinstance(layer_kwargs, dict) and "kernel_size" in layer_kwargs:
            name = name + "_" + "x".join(map(str, layer_kwargs["kernel_size"]))
        if "Pooling" in name and isinstance(layer_kwargs, dict) and "pool_size" in layer_kwargs:
            name = name + "_" + "x".join(map(str, layer_kwargs["pool_size"]))
        if branch is not None:
            name += f"_branch{branch + 1}"
        name += f"_{pos + 1}"
        return name
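    # Examples of the naming scheme above (derived from the rules in _get_layer_name): a Conv2D layer with
    # kernel_size=(3, 1) at position 1 of branch 0 is named "Conv2D_3x1_branch1_2"; a Dense layer placed after
    # the Concatenate entry (branch=None) at position 4 is named "Dense_5".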


class BranchedInputRNN(RNN):  # pragma: no cover
    """A recurrent neural network with multiple input branches."""

    def __init__(self, input_shape, output_shape, *args, **kwargs):

        super().__init__([input_shape], output_shape, *args, **kwargs)
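    # Illustrative usage only (not part of the original module). Keyword arguments such as rnn_type, n_layer,
    # n_hidden, add_dense_layer and dropout are assumed from the attributes used in set_model and from the parent
    # RNN class; check the RNN implementation for the authoritative signature.
    #
    #     model = BranchedInputRNN([(65, 1, 9), (65, 1, 9)], [(4, )], rnn_type="lstm", n_layer=2, n_hidden=32,
    #                              add_dense_layer=True, dropout=0.2)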

    def set_model(self):
        """
        Build the model.
        """
        if isinstance(self.layer_configuration, tuple) is True:
            n_layer, n_hidden = self.layer_configuration
            conf = [n_hidden for _ in range(n_layer)]
        else:
            assert isinstance(self.layer_configuration, list) is True
            conf = self.layer_configuration

        x_input = []
        x_in = []

        for branch in range(len(self._input_shape)):
            shape_b = self._input_shape[branch]
            x_input_b = keras.layers.Input(shape=shape_b)
            x_input.append(x_input_b)
            x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])),
                                          name=f"reshape_branch{branch + 1}")(x_input_b)

            for layer, n_hidden in enumerate(conf):
                return_sequences = (layer < len(conf) - 1)
                x_in_b = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn,
                                  name=f"{self.RNN.__name__}_branch{branch + 1}_{layer + 1}",
                                  kernel_regularizer=self.kernel_regularizer)(x_in_b)
                if self.bn is True:
                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
                x_in_b = self.activation_rnn(name=f"{self.activation_rnn_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
                if self.dropout is not None:
                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
            x_in.append(x_in_b)
        x_concat = keras.layers.Concatenate()(x_in)

        if self.add_dense_layer is True:
            if len(self.dense_layer_configuration) == 0:
                x_concat = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}",
                                              kernel_initializer=self.kernel_initializer, )(x_concat)
                x_concat = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_concat)
                if self.dropout is not None:
                    x_concat = self.dropout(self.dropout_rate)(x_concat)
            else:
                for layer, n_hidden in enumerate(self.dense_layer_configuration):
                    if n_hidden < self._output_shape:
                        break
                    x_concat = keras.layers.Dense(n_hidden, name=f"Dense_{len(conf) + layer + 1}",
                                                  kernel_initializer=self.kernel_initializer, )(x_concat)
                    x_concat = self.activation(name=f"{self.activation_name}_{len(conf) + layer + 1}")(x_concat)
                    if self.dropout is not None:
                        x_concat = self.dropout(self.dropout_rate)(x_concat)

        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        print(self.model.summary())

    def set_compile_options(self):
        self.compile_options = {"loss": [keras.losses.mean_squared_error],
                                "metrics": ["mse", "mae", var_loss]}

    def _update_model_name(self, rnn_type):
        n_input = f"{len(self._input_shape)}x{self._input_shape[0][0]}x" \
                  f"{str(reduce(lambda x, y: x * y, self._input_shape[0][1:]))}"
        n_output = str(self._output_shape)
        self.model_name = rnn_type.upper()
        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
            n_layer, n_hidden = self.layer_configuration
            branch = [f"r{n_hidden}" for _ in range(n_layer)]
        else:
            branch = [f"r{n}" for n in self.layer_configuration]

        concat = []
        if self.add_dense_layer is True:
            if len(self.dense_layer_configuration) == 0:
                # branch entries carry a leading "r" (e.g. "r32"), strip it before converting to int
                n_hidden = min(self._output_shape ** 2, int(branch[-1][1:]))
                concat.append(f"1x{n_hidden}")
            else:
                for n_hidden in self.dense_layer_configuration:
                    if n_hidden < self._output_shape:
                        break
                    if len(concat) == 0:
                        concat.append(f"1x{n_hidden}")
                    else:
                        concat.append(str(n_hidden))
        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])
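    # Illustrative result of the naming scheme above: with two input branches of shape (65, 1, 9), an output shape
    # of 4, layer_configuration=(2, 32), rnn_type="lstm", add_dense_layer=True and an empty
    # dense_layer_configuration, the resulting model name is "LSTM_2x65x9_r32_r32_1x16_4".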


class BranchedInputFCN(AbstractModelClass):  # pragma: no cover
    """
    A fully connected network that uses multiple input branches that are combined by a concatenate layer.
    """

    _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
                   "sigmoid": partial(keras.layers.Activation, "sigmoid"),
                   "linear": partial(keras.layers.Activation, "linear"),
                   "selu": partial(keras.layers.Activation, "selu"),
                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
                   "leakyrelu": partial(keras.layers.LeakyReLU)}
    _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
                    "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
                    "prelu": keras.initializers.he_normal()}
    _optimizer = {"adam": keras.optimizers.Adam, "sgd": keras.optimizers.SGD}
    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
    _dropout = {"selu": keras.layers.AlphaDropout}

    def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
                 batch_normalization=False, **kwargs):
        """
        Sets model and loss depending on the given arguments.

        :param input_shape: list of input shapes, one per branch (each with shape=(window_hist, station, variables))
        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))

        Customize this FCN model via the following parameters:

        :param activation: set your desired activation function. Choose from relu, tanh, sigmoid, linear, selu,
            prelu, leakyrelu. (Default relu)
        :param activation_output: same as the activation parameter, but applied to the output layer only. (Default
            linear)
        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
        :param n_layer: define the number of hidden layers in the network. The given number of hidden neurons is
            used in each layer. (Default 1)
        :param n_hidden: define the number of hidden units per layer. This number is used in each hidden layer.
            (Default 10)
        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
            settings from n_layer and n_hidden. Provide a list where each element represents the number of units in
            one hidden layer. The number of hidden layers is equal to the length of this list.
        :param dropout: use dropout with the given rate. If no value is provided, dropout layers are not added to
            the network at all. (Default None)
        :param batch_normalization: add batch normalization layers to the network if enabled. These layers are
            inserted between the linear part of a layer (the nn part) and the non-linear part (activation function).
            No BN layers are added if set to False. (Default False)
        """

        super().__init__(input_shape, output_shape[0])

        # settings
        self.activation = self._set_activation(activation)
        self.activation_name = activation
        self.activation_output = self._set_activation(activation_output)
        self.activation_output_name = activation_output
        self.optimizer = self._set_optimizer(optimizer, **kwargs)
        self.bn = batch_normalization
        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
        self._update_model_name()
        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
        self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)

        # apply to model
        self.set_model()
        self.set_compile_options()
        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
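    # Illustrative usage only (not part of the original module), based on the parameters documented above:
    #
    #     model = BranchedInputFCN([(65, 1, 9), (65, 1, 9)], [(4, )], activation="prelu", n_layer=2, n_hidden=64,
    #                              dropout=0.2, batch_normalization=True)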

    def _set_activation(self, activation):
        try:
            return self._activation.get(activation.lower())
        except KeyError:
            raise AttributeError(f"Given activation {activation} is not supported in this model class.")

    def _set_optimizer(self, optimizer, **kwargs):
        try:
            opt_name = optimizer.lower()
            opt = self._optimizer.get(opt_name)
            opt_kwargs = {}
            if opt_name == "adam":
                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
            elif opt_name == "sgd":
                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
            return opt(**opt_kwargs)
        except KeyError:
            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")

    def _set_regularizer(self, regularizer, **kwargs):
        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
            return None
        try:
            reg_name = regularizer.lower()
            reg = self._regularizer.get(reg_name)
            reg_kwargs = {}
            if reg_name in ["l1", "l2"]:
                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
                if reg_name in reg_kwargs:
                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
            elif reg_name == "l1_l2":
                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
            return reg(**reg_kwargs)
        except KeyError:
            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")

    def _set_dropout(self, activation, dropout_rate):
        if dropout_rate is None:
            return None, None
        assert 0 <= dropout_rate < 1
        return self._dropout.get(activation, keras.layers.Dropout), dropout_rate

    def _update_model_name(self):
        n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
        n_output = str(self._output_shape)

        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
            n_layer, n_hidden = self.layer_configuration
            branch = [f"{n_hidden}" for _ in range(n_layer)]
        else:
            branch = [f"{n}" for n in self.layer_configuration]

        concat = []
        n_neurons_concat = int(branch[-1]) * len(self._input_shape)
        for exp in reversed(range(2, len(self._input_shape) + 1)):
            n_neurons = self._output_shape ** exp
            if n_neurons < n_neurons_concat:
                if len(concat) == 0:
                    concat.append(f"1x{n_neurons}")
                else:
                    concat.append(str(n_neurons))
        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])
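    # Illustrative result of the naming scheme above: with two input branches of shape (65, 1, 9) (585 values each),
    # an output shape of 4 and layer_configuration=[64, 64], the concatenation part contributes "1x16" because
    # 4 ** 2 = 16 < 64 * 2 = 128, so the suffix "_2x585_64_64_1x16_4" is appended to the model name.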

    def set_model(self):
        """
        Build the model.
        """

        if isinstance(self.layer_configuration, tuple) is True:
            n_layer, n_hidden = self.layer_configuration
            conf = [n_hidden for _ in range(n_layer)]
        else:
            assert isinstance(self.layer_configuration, list) is True
            conf = self.layer_configuration

        x_input = []
        x_in = []

        for branch in range(len(self._input_shape)):
            x_input_b = keras.layers.Input(shape=self._input_shape[branch])
            x_input.append(x_input_b)
            x_in_b = keras.layers.Flatten()(x_input_b)

            for layer, n_hidden in enumerate(conf):
                x_in_b = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
                                            kernel_regularizer=self.kernel_regularizer,
                                            name=f"Dense_branch{branch + 1}_{layer + 1}")(x_in_b)
                if self.bn is True:
                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
                x_in_b = self.activation(name=f"{self.activation_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
                if self.dropout is not None:
                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
            x_in.append(x_in_b)
        x_concat = keras.layers.Concatenate()(x_in)

        n_neurons_concat = int(conf[-1]) * len(self._input_shape)
        layer_concat = 0
        for exp in reversed(range(2, len(self._input_shape) + 1)):
            n_neurons = self._output_shape ** exp
            if n_neurons < n_neurons_concat:
                layer_concat += 1
                x_concat = keras.layers.Dense(n_neurons, name=f"Dense_{layer_concat}")(x_concat)
                if self.bn is True:
                    x_concat = keras.layers.BatchNormalization()(x_concat)
                x_concat = self.activation(name=f"{self.activation_name}_{layer_concat}")(x_concat)
                if self.dropout is not None:
                    x_concat = self.dropout(self.dropout_rate)(x_concat)
        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        print(self.model.summary())

    def set_compile_options(self):
        self.compile_options = {"loss": [keras.losses.mean_squared_error],
                                "metrics": ["mse", "mae", var_loss]}
        # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])],
        #                         "metrics": ["mse", "mae", var_loss]}


class BranchedInputUNet(UNet, BranchedInputCNN):  # pragma: no cover
    """
    A U-net neural network with multiple input branches.

    ```python

    input_shape = [(72,1,9),(72,1,9),]
    output_shape = [(4, )]

    # model
    layer_configuration=[

        # 1st block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 2nd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 3rd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 4th block (final down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},

        # 5th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 64, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},

        # 6th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 32, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},

        # 7th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 16, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},

        # Tail
        {"type": "Concatenate"},
        {"type": "Flatten"},
        {"type": "Dense", "units": 128, "activation": "relu"}
    ]

    model = BranchedInputUNet(input_shape, output_shape, layer_configuration)
    ```

    """

    def __init__(self, input_shape, output_shape, layer_configuration: list, optimizer="adam", **kwargs):

        super(BranchedInputUNet, self).__init__(input_shape, output_shape, layer_configuration, optimizer=optimizer, **kwargs)

    def set_model(self):

        x_input = []
        x_in = []
        stop_pos = None
        block_save = []

        for branch in range(len(self._input_shape)):
            print(branch)
            block_save = []
            shape_b = self._input_shape[branch]
            x_input_b = keras.layers.Input(shape=shape_b, name=f"input_branch{branch + 1}")
            x_input.append(x_input_b)
            x_in_b = x_input_b
            b_conf = copy.deepcopy(self.conf)

            for pos, layer_opts in enumerate(b_conf):
                print(layer_opts)
                if layer_opts.get("type") == "Concatenate":
                    if stop_pos is None:
                        stop_pos = pos
                    else:
                        assert pos == stop_pos
                    break
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                if layer == "blocksave":
                    block_save.append(x_in_b)
                    continue
                layer_name = self._get_layer_name(layer, layer_kwargs, pos, branch)
                if "Concatenate" in layer_name:
                    x_in_b = layer(name=layer_name)([x_in_b, block_save.pop(-1)])
                    self._layer_save.append({"layer": layer, "follow_up_layer": follow_up_layer})
                    continue
                x_in_b = layer(**layer_kwargs, name=layer_name)(x_in_b)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos, branch)
                        x_in_b = follow_up(name=layer_name)(x_in_b)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": branch})
            x_in.append(x_in_b)

        print("concat")
        x_concat = keras.layers.Concatenate()(x_in)
        if len(block_save) > 0:
            logging.warning(f"Branches of BranchedInputUNet are concatenated before last upsampling block is applied.")
            block_save = []

        if stop_pos is not None:
            for pos, layer_opts in enumerate(self.conf[stop_pos + 1:]):
                print(layer_opts)
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                if layer == "blocksave":
                    block_save.append(x_concat)
                    continue
                layer_name = self._get_layer_name(layer, layer_kwargs, pos, None)
                if "Concatenate" in layer_name:
                    x_concat = layer(name=layer_name)([x_concat, block_save.pop(-1)])
                    self._layer_save.append({"layer": layer, "follow_up_layer": follow_up_layer})
                    continue
                x_concat = layer(**layer_kwargs, name=layer_name)(x_concat)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos + stop_pos, None)
                        x_concat = follow_up(name=layer_name)(x_concat)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": "concat"})

        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        print(self.model.summary())


class BranchedInputResNet(ResNet, BranchedInputCNN):  # pragma: no cover
    """
    A convolutional neural network with multiple input branches and residual blocks (skip connections).

    ```python
    input_shape = [(65,1,9), (65,1,9)]
    output_shape = [(4, )]

    # model
    layer_configuration=[
        {"type": "Conv2D", "activation": "relu", "kernel_size": (7, 1), "filters": 32, "padding": "same"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "strides": (1, 1), "kernel_regularizer": "l2", "use_1x1conv": True},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "strides": (1, 1), "kernel_regularizer": "l2", "use_1x1conv": True},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Flatten"},
        {"type": "Concatenate"},
        {"type": "Dense", "units": 128, "activation": "relu"}
    ]

    model = BranchedInputResNet(input_shape, output_shape, layer_configuration)
    ```

    """

    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):

        super().__init__(input_shape, output_shape, layer_configuration, optimizer=optimizer, **kwargs)