Coverage for mlair/model_modules/u_networks.py: 0%

2 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

1__author__ = "Lukas Leufen" 

2__date__ = "2022-08-29" 

3 

4 

5from functools import partial 

6 

7from mlair.helpers import select_from_dict, to_list 

8from mlair.model_modules.convolutional_networks import CNNfromConfig 

9import tensorflow.keras as keras 

10 

11 

class UNet(CNNfromConfig):  # pragma: no cover
    """
    A U-net neural network that is assembled from a list of layer configurations.

    The configuration is processed entry by entry like in `CNNfromConfig`, with two additional pseudo layer types
    that are handled by this class:

    * `{"type": "blocksave"}` does not add a layer. It stores the current tensor on a stack.
    * `{"type": "ConcatenateUNet"}` adds a `keras.layers.Concatenate` layer that joins the current tensor with the
      tensor that was stored most recently by `blocksave` (the stored tensor is removed from the stack).

    All other entries are passed on to `CNNfromConfig._extract_layer_conf`. After the last configured layer, a dense
    layer with `self._output_shape` units and the output activation are appended.

    ```python
    input_shape = [(65,1,9)]
    output_shape = [(4, )]

    # model
    layer_configuration=[

        # 1st block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "blocksave"},

        # 2nd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "blocksave"},

        # 3rd block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "blocksave"},

        # 4th block (down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "MaxPooling2D", "pool_size": (2, 1), "strides": (2, 1)},
        {"type": "blocksave"},

        # 5th block (final down)
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 256, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 256, "padding": "same"},

        # 6th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 128, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 128, "padding": "same"},

        # 7th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 64, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 64, "padding": "same"},

        # 8th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 32, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 32, "padding": "same"},

        # 9th block (up)
        {"type": "Conv2DTranspose", "activation": "relu", "kernel_size": (2, 1), "filters": 16, "strides": (2, 1),
         "padding": "same"},
        {"type": "ConcatenateUNet"},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},
        {"type": "Dropout", "rate": 0.25},
        {"type": "Conv2D", "activation": "relu", "kernel_size": (3, 1), "filters": 16, "padding": "same"},

        # Tail
        {"type": "Flatten"},
        {"type": "Dense", "units": 128, "activation": "relu"}
    ]

    model = UNet(input_shape, output_shape, layer_configuration)
    ```

    NOTE(review): in this example every `blocksave` follows the pooling layer, so each stored tensor presumably has
    half the length of the tensor produced by the matching `Conv2DTranspose`. `Concatenate` needs equal sizes on all
    axes except the concatenation axis - TODO confirm that this example builds as written.

    """

    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
        """
        Set up the network by forwarding all arguments to `CNNfromConfig`.

        :param input_shape: input shape(s), e.g. `[(65, 1, 9)]` as used in the class example
        :param output_shape: output shape(s), e.g. `[(4, )]` as used in the class example
        :param layer_configuration: list of dictionaries, each describing one layer (or pseudo layer) by its "type"
            and further layer arguments
        :param optimizer: optimizer passed to the parent class (default "adam")
        :param kwargs: any further keyword arguments are passed to the parent class unchanged
        """
        super().__init__(input_shape, output_shape, layer_configuration, optimizer=optimizer, **kwargs)

    def _extract_layer_conf(self, layer_opts):
        """
        Translate a single configuration entry into a layer description.

        :param layer_opts: configuration dictionary of a single entry, must contain the key "type"
        :return: tuple of (layer, layer kwargs, follow-up layer). For "ConcatenateUNet" this is the keras
            `Concatenate` class without kwargs, for "blocksave" the marker string "blocksave" without kwargs. All
            other types return whatever the parent implementation returns.
        """
        layer_type = layer_opts["type"]
        if layer_type == "ConcatenateUNet":
            # direct attribute access: a missing layer class fails here and not later as an obscure error on None
            return keras.layers.Concatenate, None, None
        if layer_type == "blocksave":
            return "blocksave", None, None
        return super()._extract_layer_conf(layer_opts)

    def set_model(self):
        """
        Build the keras model from `self.conf` and store it in `self.model`.

        Each processed layer is also recorded in `self._layer_save`. A summary of the final model is printed.

        :raises IndexError: if a concatenate entry is reached while no tensor stored by "blocksave" is left
        """
        x_input = keras.layers.Input(shape=self._input_shape)
        x_in = x_input
        block_save = []  # stack of tensors stored by "blocksave", consumed last-in-first-out by the concatenations

        for pos, layer_opts in enumerate(self.conf):
            print(layer_opts)
            layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
            if layer == "blocksave":
                # pseudo layer: only remember the current tensor, nothing is added to the graph
                block_save.append(x_in)
                continue
            layer_name = self._get_layer_name(layer, layer_kwargs, pos)
            if "Concatenate" in layer_name:
                if not block_save:
                    raise IndexError(f"Configuration entry {pos} ({layer_opts['type']!r}) has no stored 'blocksave' "
                                     f"tensor left to concatenate with.")
                x_in = layer(name=layer_name)([x_in, block_save.pop()])
                self._layer_save.append({"layer": layer, "follow_up_layer": follow_up_layer})
                continue
            x_in = layer(**layer_kwargs, name=layer_name)(x_in)
            if follow_up_layer is not None:
                for follow_up in to_list(follow_up_layer):
                    layer_name = self._get_layer_name(follow_up, None, pos)
                    x_in = follow_up(name=layer_name)(x_in)
            self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer})

        x_in = keras.layers.Dense(self._output_shape)(x_in)
        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
        self.model = keras.Model(inputs=x_input, outputs=[out])
        # summary() prints on its own and returns None, wrapping it in print() would add a stray "None" line
        self.model.summary()