Coverage for mlair/model_modules/flatten.py: 100%
26 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-12-18 17:51 +0000
1__author__ = "Felix Kleinert, Lukas Leufen"
2__date__ = '2019-12-02'
4from typing import Union, Callable
6import tensorflow.keras as keras
def get_activation(input_to_activate: keras.layers, activation: Union[Callable, str], **kwargs):
    """
    Apply activation on a given input layer.

    This helper function is able to handle advanced keras activations as well as strings for standard activations.

    :param input_to_activate: keras layer to apply activation on
    :param activation: activation to apply on `input_to_activate'. Can be a standard keras strings or activation layers
    :param kwargs: keyword arguments used inside activation layer (e.g. `name')

    :return: output of the activation applied on `input_to_activate'

    .. code-block:: python

        input_x = ... # your input data
        x_in = keras.layer(<without activation>)(input_x)

        # get activation via string
        x_act_string = get_activation(x_in, 'relu')
        # or get activation via layer callable
        x_act_layer = get_activation(x_in, keras.layers.advanced_activations.ELU)

    """
    if isinstance(activation, str):
        name = kwargs.pop('name', None)
        if name is not None:
            # suffix the layer name with the activation for readable model summaries
            kwargs['name'] = f'{name}_{activation}'
        # NOTE: previously a missing `name' produced the literal layer name
        # "None_<activation>" (and duplicate-name clashes for repeated calls);
        # now keras auto-generates a unique name instead.
        act = keras.layers.Activation(activation, **kwargs)(input_to_activate)
    else:
        # advanced activations are layer classes and are instantiated with kwargs
        act = activation(**kwargs)(input_to_activate)
    return act
def flatten_tail(input_x: keras.layers, inner_neurons: int, activation: Union[Callable, str],
                 output_neurons: int, output_activation: Union[Callable, str],
                 reduction_filter: int = None,
                 name: str = None,
                 bound_weight: bool = False,
                 dropout_rate: float = None,
                 kernel_regularizer: keras.regularizers = None
                 ):
    """
    Flatten output of convolutional layers.

    :param input_x: Multidimensional keras layer (ConvLayer)
    :param inner_neurons: Number of neurons in inner dense layer
    :param activation: activation to apply after conv and dense layers
    :param output_neurons: Number of neurons in the last layer (must fit the shape of labels)
    :param output_activation: final activation function
    :param reduction_filter: number of filters used for information compression on `input_x' before flatten()
    :param name: Name of the flatten tail.
    :param bound_weight: Use `tanh' as inner activation if set to True, otherwise `activation'
    :param dropout_rate: Dropout rate to be applied between trainable layers
    :param kernel_regularizer: regularizer to apply on conv and dense layers

    :return: flatten branch with size n=output_neurons

    .. code-block:: python

        input_x = ... # your input data
        conv_out = Conv2D(*args)(input_x) # your convolution stack
        out = flatten_tail(conv_out, inner_neurons=64, activation=keras.layers.advanced_activations.ELU,
                           output_neurons=4,
                           output_activation='linear', reduction_filter=64,
                           name='Main', bound_weight=False, dropout_rate=.3,
                           kernel_regularizer=keras.regularizers.l2()
                           )
        model = keras.Model(inputs=input_x, outputs=[out])

    """
    # compression layer: optional 1x1 convolution to reduce the channel
    # dimension before flattening
    if reduction_filter is None:
        x_in = input_x
    else:
        x_in = keras.layers.Conv2D(reduction_filter, (1, 1), name=f'{name}_Conv_1x1',
                                   kernel_regularizer=kernel_regularizer)(input_x)
        x_in = get_activation(x_in, activation, name=f'{name}_conv_act')

    x_in = keras.layers.Flatten(name=f'{name}')(x_in)

    if dropout_rate is not None:
        x_in = keras.layers.Dropout(dropout_rate, name=f'{name}_Dropout_1')(x_in)
    x_in = keras.layers.Dense(inner_neurons, kernel_regularizer=kernel_regularizer,
                              name=f'{name}_inner_Dense')(x_in)
    if bound_weight:
        # bound the inner representation to [-1, 1]
        x_in = keras.layers.Activation('tanh')(x_in)
    else:
        x_in = get_activation(x_in, activation, name=f'{name}_act')

    if dropout_rate is not None:
        x_in = keras.layers.Dropout(dropout_rate, name=f'{name}_Dropout_2')(x_in)
    out = keras.layers.Dense(output_neurons, kernel_regularizer=kernel_regularizer,
                             name=f'{name}_out_Dense')(x_in)
    out = get_activation(out, output_activation, name=f'{name}_final_act')
    return out