Coverage for mlair/model_modules/branched_input_networks.py: 100%

0 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

1import logging 

2from functools import partial, reduce 

3import copy 

4from typing import Union 

5 

6from tensorflow import keras as keras 

7 

8from mlair import AbstractModelClass 

9from mlair.helpers import select_from_dict, to_list 

10from mlair.model_modules.loss import var_loss 

11from mlair.model_modules.recurrent_networks import RNN 

12from mlair.model_modules.convolutional_networks import CNNfromConfig 

13from mlair.model_modules.residual_networks import ResNet 

14from mlair.model_modules.u_networks import UNet 

15 

16 

class BranchedInputCNN(CNNfromConfig):  # pragma: no cover
    """
    A convolutional neural network with multiple input branches.

    All branches are built from the same layer configuration up to the first entry of type ``"Concatenate"``. The
    branch outputs are then merged by a concatenate layer, and the remaining configuration entries are applied to the
    merged tensor. A final dense layer with ``self._output_shape`` units and the output activation complete the model.
    """

    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
        """
        Create a branched CNN.

        :param input_shape: list of input shapes, one entry per branch
        :param output_shape: output shape, forwarded unchanged to CNNfromConfig
        :param layer_configuration: list of layer option dicts shared by all branches; an entry
            ``{"type": "Concatenate"}`` marks the position where the branches are merged
        :param optimizer: name of the optimizer (Default adam)
        """
        # The list of branch shapes is wrapped once more; set_model expects self._input_shape to hold all branch
        # shapes. NOTE(review): presumably the parent keeps only the first element of its input_shape — confirm.
        super().__init__([input_shape], output_shape, layer_configuration, optimizer=optimizer, **kwargs)

    def set_model(self):
        """
        Build the branched network and store it in ``self.model``.

        Each branch gets its own input layer and applies the layer configuration until the first ``Concatenate``
        entry, which must sit at the same position for all branches. If there is no such entry, the whole
        configuration is applied per branch and the branches are concatenated afterwards. Entries following the
        ``Concatenate`` entry act on the merged tensor. Every created layer is recorded in ``self._layer_save``.
        """
        x_input = []
        x_in = []
        stop_pos = None

        for branch, shape_b in enumerate(self._input_shape):
            logging.debug("build input branch %s", branch)
            x_input_b = keras.layers.Input(shape=shape_b, name=f"input_branch{branch + 1}")
            x_input.append(x_input_b)
            x_in_b = x_input_b
            # every branch works on its own copy of the configuration
            b_conf = copy.deepcopy(self.conf)

            for pos, layer_opts in enumerate(b_conf):
                logging.debug("layer options: %s", layer_opts)
                if layer_opts.get("type") == "Concatenate":
                    # remember the merge position; all branches have to agree on it
                    if stop_pos is None:
                        stop_pos = pos
                    else:
                        assert pos == stop_pos
                    break
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                layer_name = self._get_layer_name(layer, layer_kwargs, pos, branch)
                x_in_b = layer(**layer_kwargs, name=layer_name)(x_in_b)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos, branch)
                        x_in_b = follow_up(name=layer_name)(x_in_b)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": branch})
            x_in.append(x_in_b)

        logging.debug("concatenate input branches")
        x_concat = keras.layers.Concatenate()(x_in)

        if stop_pos is not None:
            # work on a copy as in the branch loop, in case _extract_layer_conf modifies the options it receives
            tail_conf = copy.deepcopy(self.conf[stop_pos + 1:])
            for pos, layer_opts in enumerate(tail_conf):
                logging.debug("layer options: %s", layer_opts)
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                # offset by stop_pos so that numbering continues after the branch layers
                layer_name = self._get_layer_name(layer, layer_kwargs, pos + stop_pos, None)
                x_concat = layer(**layer_kwargs, name=layer_name)(x_concat)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos + stop_pos, None)
                        x_concat = follow_up(name=layer_name)(x_concat)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": "concat"})

        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        # summary() prints by itself and returns None, so it is not wrapped in print()
        self.model.summary()

    @staticmethod
    def _get_layer_name(layer: keras.layers, layer_kwargs: Union[dict, None], pos: int, branch: int = None):
        """
        Create a descriptive layer name.

        The name is built from the layer class name (or the activation name for ``partial(Activation, ...)``), the
        kernel/pool size for convolution/pooling layers, an optional ``_branch{branch + 1}`` suffix and the 1-based
        position ``pos + 1``.

        :param layer: layer class or ``functools.partial`` wrapping a layer class
        :param layer_kwargs: keyword arguments of the layer (read for ``kernel_size`` / ``pool_size``) or None
        :param pos: 0-based position of the layer in the configuration
        :param branch: 0-based branch index, or None for layers acting on the merged tensor
        :return: layer name
        """

        def size_str(size):
            # Keras accepts a single int as well as a tuple/list for kernel_size and pool_size
            return "x".join(map(str, size if isinstance(size, (list, tuple)) else [size]))

        if isinstance(layer, partial):
            name = layer.args[0] if layer.func.__name__ == "Activation" else layer.func.__name__
        else:
            name = layer.__name__
        if isinstance(layer_kwargs, dict):
            if "Conv" in name and "kernel_size" in layer_kwargs:
                name = name + "_" + size_str(layer_kwargs["kernel_size"])
            if "Pooling" in name and "pool_size" in layer_kwargs:
                name = name + "_" + size_str(layer_kwargs["pool_size"])
        if branch is not None:
            name += f"_branch{branch + 1}"
        name += f"_{pos + 1}"
        return name

92 

93 

class BranchedInputRNN(RNN):  # pragma: no cover
    """
    A recurrent neural network with multiple input branches.

    Each branch reshapes its input to ``(first axis, product of all remaining axes)`` and applies its own stack of
    recurrent layers. The branch outputs are concatenated, optionally followed by dense layers, and mapped to the
    output by a final dense layer and the output activation.
    """

    def __init__(self, input_shape, output_shape, *args, **kwargs):
        """
        Create a branched RNN.

        :param input_shape: list of input shapes, one entry per branch
        :param output_shape: output shape, forwarded unchanged to RNN
        """
        super().__init__([input_shape], output_shape, *args, **kwargs)

    def set_model(self):
        """
        Build the model and store it in ``self.model``.

        ``self.layer_configuration`` is either a tuple ``(n_layer, n_hidden)`` or a list with the number of units of
        each recurrent layer. All recurrent layers except the last one return full sequences.
        """
        if isinstance(self.layer_configuration, tuple):
            n_layer, n_hidden = self.layer_configuration
            conf = [n_hidden for _ in range(n_layer)]
        else:
            assert isinstance(self.layer_configuration, list)
            conf = self.layer_configuration

        x_input = []
        x_in = []

        for branch, shape_b in enumerate(self._input_shape):
            x_input_b = keras.layers.Input(shape=shape_b)
            x_input.append(x_input_b)
            # merge all axes except the first one into a single feature axis
            x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])),
                                         name=f"reshape_branch{branch + 1}")(x_input_b)

            for layer, n_hidden in enumerate(conf):
                # only the last recurrent layer drops the sequence dimension
                return_sequences = (layer < len(conf) - 1)
                x_in_b = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn,
                                  name=f"{self.RNN.__name__}_branch{branch + 1}_{layer + 1}",
                                  kernel_regularizer=self.kernel_regularizer)(x_in_b)
                if self.bn is True:
                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
                x_in_b = self.activation_rnn(
                    name=f"{self.activation_rnn_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
                if self.dropout is not None:
                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
            x_in.append(x_in_b)
        x_concat = keras.layers.Concatenate()(x_in)

        if self.add_dense_layer is True:
            if len(self.dense_layer_configuration) == 0:
                x_concat = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}",
                                              kernel_initializer=self.kernel_initializer, )(x_concat)
                x_concat = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_concat)
                if self.dropout is not None:
                    x_concat = self.dropout(self.dropout_rate)(x_concat)
            else:
                for layer, n_hidden in enumerate(self.dense_layer_configuration):
                    # stop adding dense layers once they would be narrower than the output
                    if n_hidden < self._output_shape:
                        break
                    x_concat = keras.layers.Dense(n_hidden, name=f"Dense_{len(conf) + layer + 1}",
                                                  kernel_initializer=self.kernel_initializer, )(x_concat)
                    x_concat = self.activation(name=f"{self.activation_name}_{len(conf) + layer + 1}")(x_concat)
                    if self.dropout is not None:
                        x_concat = self.dropout(self.dropout_rate)(x_concat)

        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        # summary() prints by itself and returns None, so it is not wrapped in print()
        self.model.summary()

    def set_compile_options(self):
        """Use mean squared error as loss and track mse, mae and var_loss as metrics."""
        self.compile_options = {"loss": [keras.losses.mean_squared_error],
                                "metrics": ["mse", "mae", var_loss]}

    def _update_model_name(self, rnn_type):
        """
        Compose ``self.model_name`` from the RNN type, the input size, the layer sizes and the output size.

        :param rnn_type: name of the recurrent layer type; used upper-cased as the name prefix
        """
        n_input = f"{len(self._input_shape)}x{self._input_shape[0][0]}x" \
                  f"{str(reduce(lambda x, y: x * y, self._input_shape[0][1:]))}"
        n_output = str(self._output_shape)
        self.model_name = rnn_type.upper()
        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
            n_layer, n_hidden = self.layer_configuration
            rnn_units = [n_hidden for _ in range(n_layer)]
        else:
            rnn_units = list(self.layer_configuration)
        branch = [f"r{n}" for n in rnn_units]

        concat = []
        if self.add_dense_layer is True:
            if len(self.dense_layer_configuration) == 0:
                # use the numeric unit count; the "r<n>" labels in `branch` cannot be converted with int()
                n_hidden = min(self._output_shape ** 2, int(rnn_units[-1]))
                concat.append(f"1x{n_hidden}")
            else:
                for n_hidden in self.dense_layer_configuration:
                    if n_hidden < self._output_shape:
                        break
                    if len(concat) == 0:
                        concat.append(f"1x{n_hidden}")
                    else:
                        concat.append(str(n_hidden))
        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])

186 

187 

class BranchedInputFCN(AbstractModelClass):  # pragma: no cover
    """
    A fully connected network that uses multiple input branches that are combined by a concatenate layer.

    Each branch flattens its input and applies its own stack of dense layers. After concatenation, dense layers with
    ``output_shape ** exp`` units (``exp`` counting down from the number of branches to 2) are added whenever they are
    narrower than the concatenated tensor, followed by the output layer.
    """

    _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
                   "sigmoid": partial(keras.layers.Activation, "sigmoid"),
                   "linear": partial(keras.layers.Activation, "linear"),
                   "selu": partial(keras.layers.Activation, "selu"),
                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
                   "leakyrelu": partial(keras.layers.LeakyReLU)}
    _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
                    "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
                    "prelu": keras.initializers.he_normal()}
    _optimizer = {"adam": keras.optimizers.Adam, "sgd": keras.optimizers.SGD}
    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
    _dropout = {"selu": keras.layers.AlphaDropout}

    def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
                 batch_normalization=False, **kwargs):
        """
        Sets model and loss depending on the given arguments.

        :param input_shape: list of input shapes, one entry per branch (each e.g. shape=(window_hist, station,
            variables))
        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast)); only the first entry
            is used

        Customize this FCN model via the following parameters:

        :param activation: set your desired activation function. Choose from relu, tanh, sigmoid, linear, selu,
            prelu, leakyrelu. (Default relu)
        :param activation_output: same as activation parameter but exclusively applied on output layer only.
            (Default linear)
        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
        :param n_layer: define number of hidden layers in each branch. Given number of hidden neurons are used in each
            layer. (Default 1)
        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer.
            (Default 10)
        :param regularizer: kernel regularizer of the branch layers: l1, l2 or l1_l2. The factors are read from the
            keyword arguments l1 and l2. None or "none" disables regularization. (Default None)
        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
            network at all. (Default None)
        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
            settings from n_layer and n_hidden. Provide a list where each element represents the number of units in
            the hidden layer. The number of hidden layers is equal to the total length of this list.
        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are
            inserted between the linear part of a layer (the nn part) and the non-linear part (activation function).
            No BN layer is added if set to false. (Default false)
        :param kwargs: optimizer options (lr, beta_1, beta_2, epsilon, decay, amsgrad, momentum, nesterov) and
            regularizer factors (l1, l2)
        :raises AttributeError: if the activation, optimizer or regularizer is not supported
        """

        super().__init__(input_shape, output_shape[0])

        # settings
        self.activation = self._set_activation(activation)
        self.activation_name = activation
        self.activation_output = self._set_activation(activation_output)
        self.activation_output_name = activation_output
        self.optimizer = self._set_optimizer(optimizer, **kwargs)
        self.bn = batch_normalization
        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
        self._update_model_name()
        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
        self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
        self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)

        # apply to model
        self.set_model()
        self.set_compile_options()
        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)

    def _set_activation(self, activation):
        """Return the layer factory for the (case-insensitive) activation name; raise AttributeError if unknown."""
        try:
            # index instead of dict.get so that unknown names actually reach the except branch
            return self._activation[activation.lower()]
        except KeyError:
            raise AttributeError(f"Given activation {activation} is not supported in this model class.")

    def _set_optimizer(self, optimizer, **kwargs):
        """Instantiate the named optimizer with its supported options from kwargs; raise AttributeError if unknown."""
        opt_name = optimizer.lower()
        try:
            opt = self._optimizer[opt_name]
        except KeyError:
            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
        opt_kwargs = {}
        if opt_name == "adam":
            opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
        elif opt_name == "sgd":
            opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
        return opt(**opt_kwargs)

    def _set_regularizer(self, regularizer, **kwargs):
        """
        Instantiate the named kernel regularizer.

        :return: None if regularizer is None or "none", otherwise the regularizer built with factors from kwargs
        :raises AttributeError: if the regularizer name is not supported
        """
        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
            return None
        reg_name = regularizer.lower()
        try:
            reg = self._regularizer[reg_name]
        except KeyError:
            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
        reg_kwargs = {}
        if reg_name in ["l1", "l2"]:
            reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
            # the single-factor regularizers expect their factor as keyword "l"
            if reg_name in reg_kwargs:
                reg_kwargs["l"] = reg_kwargs.pop(reg_name)
        elif reg_name == "l1_l2":
            reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
        return reg(**reg_kwargs)

    def _set_dropout(self, activation, dropout_rate):
        """Return (dropout layer class, rate), or (None, None) if dropout is disabled; AlphaDropout for selu."""
        if dropout_rate is None:
            return None, None
        assert 0 <= dropout_rate < 1
        return self._dropout.get(activation, keras.layers.Dropout), dropout_rate

    def _get_concat_layer_sizes(self, n_hidden_last) -> list:
        """
        Return the unit counts of the dense layers that follow the concatenation.

        Candidates are ``self._output_shape ** exp`` for ``exp`` from the number of branches down to 2. A candidate is
        kept if it is smaller than the width of the concatenated tensor (``n_hidden_last`` times the number of
        branches).

        :param n_hidden_last: number of units of the last branch layer (int or numeric string)
        """
        n_branches = len(self._input_shape)
        n_neurons_concat = int(n_hidden_last) * n_branches
        candidates = (self._output_shape ** exp for exp in reversed(range(2, n_branches + 1)))
        return [n_neurons for n_neurons in candidates if n_neurons < n_neurons_concat]

    def _update_model_name(self):
        """Append input size, branch layer sizes, concat layer sizes and output size to ``self.model_name``."""
        n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
        n_output = str(self._output_shape)

        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
            n_layer, n_hidden = self.layer_configuration
            branch = [f"{n_hidden}" for _ in range(n_layer)]
        else:
            branch = [f"{n}" for n in self.layer_configuration]

        # the first concat layer is marked with a "1x" prefix
        concat = [f"1x{n}" if i == 0 else str(n) for i, n in enumerate(self._get_concat_layer_sizes(branch[-1]))]
        self.model_name += "_".join(["", n_input, *branch, *concat, n_output])

    def set_model(self):
        """
        Build the model and store it in ``self.model``.
        """
        if isinstance(self.layer_configuration, tuple):
            n_layer, n_hidden = self.layer_configuration
            conf = [n_hidden for _ in range(n_layer)]
        else:
            assert isinstance(self.layer_configuration, list)
            conf = self.layer_configuration

        x_input = []
        x_in = []

        for branch, shape_b in enumerate(self._input_shape):
            x_input_b = keras.layers.Input(shape=shape_b)
            x_input.append(x_input_b)
            x_in_b = keras.layers.Flatten()(x_input_b)

            for layer, n_hidden in enumerate(conf):
                x_in_b = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
                                            kernel_regularizer=self.kernel_regularizer,
                                            name=f"Dense_branch{branch + 1}_{layer + 1}")(x_in_b)
                if self.bn is True:
                    x_in_b = keras.layers.BatchNormalization()(x_in_b)
                x_in_b = self.activation(name=f"{self.activation_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
                if self.dropout is not None:
                    x_in_b = self.dropout(self.dropout_rate)(x_in_b)
            x_in.append(x_in_b)
        x_concat = keras.layers.Concatenate()(x_in)

        for layer_concat, n_neurons in enumerate(self._get_concat_layer_sizes(conf[-1]), start=1):
            x_concat = keras.layers.Dense(n_neurons, name=f"Dense_{layer_concat}")(x_concat)
            if self.bn is True:
                x_concat = keras.layers.BatchNormalization()(x_concat)
            x_concat = self.activation(name=f"{self.activation_name}_{layer_concat}")(x_concat)
            if self.dropout is not None:
                x_concat = self.dropout(self.dropout_rate)(x_concat)
        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        # summary() prints by itself and returns None, so it is not wrapped in print()
        self.model.summary()

    def set_compile_options(self):
        """Use mean squared error as loss and track mse, mae and var_loss as metrics."""
        self.compile_options = {"loss": [keras.losses.mean_squared_error],
                                "metrics": ["mse", "mae", var_loss]}

373 

374 

class BranchedInputUNet(UNet, BranchedInputCNN):  # pragma: no cover
    """
    A U-net neural network with multiple input branches.

    ```python

    input_shape = [(72,1,9),(72,1,9),]
    output_shape = [(4, )]

    # model
    layer_configuration=[

        # 1st block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 2nd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 3rd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "blocksave"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},

        # 4th block (final down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},

        # 5th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 64, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},

        # 6th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 32, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},

        # 7th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 16, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},

        # Tail
        {"type": "Concatenate"},
        {"type": "Flatten"},
        {"type": "Dense", "units": 128, "activation": "relu"}
    ]

    model = BranchedInputUNet(input_shape, output_shape, layer_configuration)
    ```

    """

    def __init__(self, input_shape, output_shape, layer_configuration: list, optimizer="adam", **kwargs):
        """
        Create a branched U-net.

        :param input_shape: list of input shapes, one entry per branch
        :param output_shape: output shape, forwarded unchanged to the parent classes
        :param layer_configuration: list of layer option dicts shared by all branches (see class docstring)
        :param optimizer: name of the optimizer (Default adam)
        """
        super().__init__(input_shape, output_shape, layer_configuration, optimizer=optimizer, **kwargs)

    def set_model(self):
        """
        Build the branched U-net and store it in ``self.model``.

        Each branch applies the layer configuration until the first ``Concatenate`` entry, which must sit at the same
        position for all branches. A ``blocksave`` entry stores the current tensor; a layer whose generated name
        contains ``Concatenate`` (presumably what ``ConcatenateUNet`` entries map to — confirm in
        _extract_layer_conf) merges the current tensor with the most recently stored one. If stored tensors are left
        over when the branches are merged, a warning is logged and they are discarded. The remaining entries are
        applied to the merged tensor.
        """
        x_input = []
        x_in = []
        stop_pos = None
        block_save = []

        for branch, shape_b in enumerate(self._input_shape):
            logging.debug("build input branch %s", branch)
            # skip-connection tensors are collected per branch
            block_save = []
            x_input_b = keras.layers.Input(shape=shape_b, name=f"input_branch{branch + 1}")
            x_input.append(x_input_b)
            x_in_b = x_input_b
            # every branch works on its own copy of the configuration
            b_conf = copy.deepcopy(self.conf)

            for pos, layer_opts in enumerate(b_conf):
                logging.debug("layer options: %s", layer_opts)
                if layer_opts.get("type") == "Concatenate":
                    # remember the merge position; all branches have to agree on it
                    if stop_pos is None:
                        stop_pos = pos
                    else:
                        assert pos == stop_pos
                    break
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                if layer == "blocksave":
                    block_save.append(x_in_b)
                    continue
                layer_name = self._get_layer_name(layer, layer_kwargs, pos, branch)
                if "Concatenate" in layer_name:
                    x_in_b = layer(name=layer_name)([x_in_b, block_save.pop(-1)])
                    self._layer_save.append({"layer": layer, "follow_up_layer": follow_up_layer})
                    continue
                x_in_b = layer(**layer_kwargs, name=layer_name)(x_in_b)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos, branch)
                        x_in_b = follow_up(name=layer_name)(x_in_b)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": branch})
            x_in.append(x_in_b)

        logging.debug("concatenate input branches")
        x_concat = keras.layers.Concatenate()(x_in)
        if len(block_save) > 0:
            logging.warning("Branches of BranchedInputUNet are concatenated before last upsampling block is applied.")
        # tensors stored inside the branches are not reused on the merged tensor
        block_save = []

        if stop_pos is not None:
            # work on a copy as in the branch loop, in case _extract_layer_conf modifies the options it receives
            tail_conf = copy.deepcopy(self.conf[stop_pos + 1:])
            for pos, layer_opts in enumerate(tail_conf):
                logging.debug("layer options: %s", layer_opts)
                layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                if layer == "blocksave":
                    block_save.append(x_concat)
                    continue
                # offset by stop_pos, like the follow-up layers below, so numbering continues after the branches
                layer_name = self._get_layer_name(layer, layer_kwargs, pos + stop_pos, None)
                if "Concatenate" in layer_name:
                    x_concat = layer(name=layer_name)([x_concat, block_save.pop(-1)])
                    self._layer_save.append({"layer": layer, "follow_up_layer": follow_up_layer})
                    continue
                x_concat = layer(**layer_kwargs, name=layer_name)(x_concat)
                if follow_up_layer is not None:
                    for follow_up in to_list(follow_up_layer):
                        layer_name = self._get_layer_name(follow_up, None, pos + stop_pos, None)
                        x_concat = follow_up(name=layer_name)(x_concat)
                self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                         "branch": "concat"})

        x_concat = keras.layers.Dense(self._output_shape)(x_concat)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        # summary() prints by itself and returns None, so it is not wrapped in print()
        self.model.summary()

524 

525 

class BranchedInputResNet(ResNet, BranchedInputCNN):  # pragma: no cover
    """
    Multi-branch convolutional network built from residual blocks (skip connections).

    The class only combines its two parents (ResNet first, then BranchedInputCNN in the method resolution order);
    it adds no behavior of its own. Example configuration:

    ```python
    input_shape = [(65,1,9), (65,1,9)]
    output_shape = [(4, )]

    # model
    layer_configuration=[
        {"type": "Conv2D", "activation": "relu", "kernel_size": (7, 1), "filters": 32, "padding": "same"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "strides": (1, 1), "kernel_regularizer": "l2", "use_1x1conv": True},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "strides": (1, 1), "kernel_regularizer": "l2", "use_1x1conv": True},
        {"type": "residual_block", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "strides": (1, 1), "kernel_regularizer": "l2"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Flatten"},
        {"type": "Concatenate"},
        {"type": "Dense", "units": 128, "activation": "relu"}
    ]

    model = BranchedInputResNet(input_shape, output_shape, layer_configuration)
    ```

    """

    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
        """
        Forward all arguments unchanged to the next initializer in the method resolution order.

        :param input_shape: list of input shapes, one entry per branch
        :param output_shape: output shape
        :param layer_configuration: list of layer option dicts (see class docstring)
        :param optimizer: name of the optimizer (Default adam)
        """
        super(BranchedInputResNet, self).__init__(input_shape, output_shape, layer_configuration,
                                                  optimizer=optimizer, **kwargs)