Coverage for mlair/model_modules/abstract_model_class.py: 23%
103 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-11-30 10:51 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-11-30 10:51 +0000
1import inspect
2from abc import ABC
3from typing import Any, Dict, Callable
5import tensorflow.keras as keras
6import tensorflow as tf
8from mlair.helpers import remove_items, make_keras_pickable
class AbstractModelClass(ABC):
    """
    The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow.

    The model can always be accessed by calling ModelClass.model or directly by a model method without passing
    through the model attribute name (e.g. ModelClass.model.compile -> ModelClass.compile). Besides the model, this
    class provides the corresponding compile options (loss, optimizer, ...) and custom objects.
    """

    _requirements = []

    def __init__(self, input_shape, output_shape) -> None:
        """
        Predefine internal attributes for model and loss.

        :param input_shape: stored unchanged as `_input_shape`
        :param output_shape: stored as `_output_shape`; a one-element tuple is reduced to its single element
        """
        make_keras_pickable()
        self.__model = None
        self.model_name = self.__class__.__name__
        self.__custom_objects = {}
        self.__allowed_compile_options = {'optimizer': None,
                                          'loss': None,
                                          'metrics': None,
                                          'loss_weights': None,
                                          'sample_weight_mode': None,
                                          'weighted_metrics': None,
                                          'target_tensors': None
                                          }
        # Work on a copy: the resolved options must not overwrite the template of allowed keys.
        self.__compile_options = dict(self.__allowed_compile_options)
        self.__compile_options_is_set = False
        self._input_shape = input_shape
        self._output_shape = self.__extract_from_tuple(output_shape)

    def load_model(self, name: str, compile: bool = False) -> None:
        """
        Load weights into the already existing model and keep the model's current history attribute.

        :param name: passed to `self.model.load_weights`
        :param compile: if exactly `True`, the model is compiled afterwards using `self.compile_options`
        """
        # load_weights is bracketed by a save/restore of `history` so the attribute survives the call
        hist = self.model.history
        self.model.load_weights(name)
        self.model.history = hist
        if compile is True:
            self.model.compile(**self.compile_options)

    def __getattr__(self, name: str) -> Any:
        """
        Is called if __getattribute__ is not able to find requested attribute.

        Normally, the model class is saved into a variable like `model = ModelClass()`. To bypass a call like
        `model.model` to access the model attribute, this method searches for the named attribute in the
        self.model namespace and returns this attribute if available. Therefore, the following expression is true:
        `ModelClass().compile == ModelClass().model.compile` as long as the called attribute/method is not part of
        the ModelClass itself.

        :param name: name of the attribute or method to call

        :return: attribute or method from self.model namespace
        :raises AttributeError: if no model is set or the model does not provide the attribute
        """
        # Read the private slot directly from the instance dict. Going through the `model` property would re-enter
        # __getattr__ (and recurse endlessly) on instances whose __init__ has not run, e.g. during unpickling.
        model = self.__dict__.get("_AbstractModelClass__model")
        if model is None:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        return getattr(model, name)

    @property
    def model(self) -> "keras.Model":
        """
        The model property containing a keras.Model instance.

        :return: the keras model
        """
        return self.__model

    @model.setter
    def model(self, value):
        self.__model = value

    @property
    def custom_objects(self) -> Dict:
        """
        The custom objects property collects all non-keras utilities that are used in the model class.

        To load such a customised and already compiled model (e.g. from local disk), this information is required.

        :return: custom objects in a dictionary
        """
        return self.__custom_objects

    @custom_objects.setter
    def custom_objects(self, value) -> None:
        self.__custom_objects = value

    @property
    def compile_options(self) -> Dict:
        """
        The compile options property allows the user to use all keras.compile() arguments.

        They can either be passed as dictionary (1), as attributes without setting compile_options (2), or as a
        mixture of both (3), i.e. partly defined as instance attributes and partly passed as dictionary.
        The setter will raise an error when the same parameter is set differently.

        Example (1) Recommended (includes check for valid keywords which are used as args in keras.compile)

        .. code-block:: python

            def set_compile_options(self):
                self.compile_options = {"optimizer": keras.optimizers.SGD(),
                                        "loss": keras.losses.mean_squared_error,
                                        "metrics": ["mse", "mae"]}

        Example (2)

        .. code-block:: python

            def set_compile_options(self):
                self.optimizer = keras.optimizers.SGD()
                self.loss = keras.losses.mean_squared_error
                self.metrics = ["mse", "mae"]

        Example (3)
        Correct:

        .. code-block:: python

            def set_compile_options(self):
                self.optimizer = keras.optimizers.SGD()
                self.loss = keras.losses.mean_squared_error
                self.compile_options = {"metrics": ["mse", "mae"]}

        Incorrect: (Will raise an error)

        .. code-block:: python

            def set_compile_options(self):
                self.optimizer = keras.optimizers.SGD()
                self.loss = keras.losses.mean_squared_error
                self.compile_options = {"optimizer": keras.optimizers.Adam(), "metrics": ["mse", "mae"]}

        Note:
        * As long as the attribute and the dict value have exactly the same values, the setter method will not raise
          an error
        * For example (2) there is no check implemented, if the attributes are valid compile options

        :return: dictionary holding one entry per allowed compile option
        """
        # First access without an explicit assignment: collect the options from the instance attributes only.
        if self.__compile_options_is_set is False:
            self.compile_options = None
        return self.__compile_options

    @compile_options.setter
    def compile_options(self, value: Dict) -> None:
        """
        Merge compile options given as dictionary with options defined as attributes.

        :param value: dictionary of compile options or `None` (use attributes only); it is not modified
        :raises TypeError: if value is neither a dictionary nor `None`
        :raises ValueError: if value holds a key that is not allowed, or if an option is given both as attribute
            and in the dictionary with different values
        """
        if value is not None and not isinstance(value, dict):
            raise TypeError(f"`compile_options' must be `dict' or `None', but is {type(value)}.")
        if value is not None:
            if not (set(value.keys()) <= set(self.__allowed_compile_options.keys())):
                raise ValueError(f"Got invalid key for compile_options. {value.keys()}")

        for allow_k in self.__allowed_compile_options.keys():
            # hasattr also looks into the model namespace (see __getattr__)
            if hasattr(self, allow_k):
                new_v_attr = getattr(self, allow_k)
                # an empty list counts as "not set"
                if new_v_attr == list():
                    new_v_attr = None
            else:
                new_v_attr = None
            # plain lookup instead of pop: the dictionary handed in by the caller stays untouched
            new_v_dic = None if value is None else value.get(allow_k, None)
            # self.__compare_keras_optimizers() is not used here for now, because it does not work as expected
            # accept equal values, or a value that is given by exactly one of both sources
            if (new_v_attr == new_v_dic) or ((new_v_attr is None) ^ (new_v_dic is None)):
                if new_v_attr is not None:
                    self.__compile_options[allow_k] = new_v_attr
                else:
                    self.__compile_options[allow_k] = new_v_dic
            else:
                raise ValueError(
                    f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}")
        self.__compile_options_is_set = True

    @staticmethod
    def __extract_from_tuple(tup):
        """Return element of tuple if it contains only a single element."""
        return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup

    @staticmethod
    def __compare_keras_optimizers(first, second):
        """
        Compares if optimiser and all settings of the optimisers are exactly equal.

        Currently not called by the compile_options setter.

        :return True if optimisers are interchangeable, or False if optimisers are distinguishable.
        """
        # a list can never be an optimiser
        if isinstance(second, list):
            return False
        # only objects of the same class that live in the keras optimizer_v2 package are compared in detail;
        # the module is taken from the type, because instances of builtin types carry no __module__ attribute
        if not (first.__class__ == second.__class__ and '.'.join(
                type(first).__module__.split('.')[0:4]) == 'tensorflow.python.keras.optimizer_v2'):
            return False
        res = True
        init = tf.compat.v1.global_variables_initializer()
        with tf.compat.v1.Session() as sess:
            sess.run(init)
            for k, v in first.__dict__.items():
                try:
                    res *= sess.run(v) == sess.run(second.__dict__[k])
                except TypeError:
                    # value cannot be evaluated in the session, compare the raw objects instead
                    res *= v == second.__dict__[k]
        return bool(res)

    def get_settings(self) -> Dict:
        """
        Get all instance attributes that are not private to the AbstractModelClass as dictionary.

        :return: all instance attributes whose name does not start with `_AbstractModelClass__`
        """
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_AbstractModelClass__")}

    def set_model(self):
        """Abstract method to set model."""
        raise NotImplementedError

    def set_compile_options(self):
        """
        This method only has to be defined in child class, when additional compile options should be used
        (other options than optimizer and loss).

        Has to be set as dictionary: {'optimizer': None,
                                      'loss': None,
                                      'metrics': None,
                                      'loss_weights': None,
                                      'sample_weight_mode': None,
                                      'weighted_metrics': None,
                                      'target_tensors': None
                                      }

        :return:
        """
        raise NotImplementedError

    def set_custom_objects(self, **kwargs) -> None:
        """
        Set custom objects that are not part of keras framework.

        These custom objects are needed if an already compiled model is loaded from disk. There is a special treatment
        for the Padding2D class, which is a base class for different padding types. For a correct behaviour, all
        supported subclasses are added as custom objects in addition to the given ones.

        :param kwargs: all custom objects, that should be saved
        """
        if "Padding2D" in kwargs:
            kwargs.update(kwargs["Padding2D"].allowed_paddings)
        self.custom_objects = kwargs

    @classmethod
    def requirements(cls):
        """Return requirements and own arguments without duplicates."""
        return list(set(cls._requirements + cls.own_args()))

    @classmethod
    def own_args(cls, *args):
        """
        Return all arguments (including kwonlyargs) of this class and its super classes.

        :param args: names that are removed from the result (in addition to `self`)
        """
        arg_spec = inspect.getfullargspec(cls)
        list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args()
        return list(set(remove_items(list_of_args, ["self"] + list(args))))

    @classmethod
    def super_args(cls):
        """Return the arguments of all super classes that provide `own_args`, without duplicates."""
        args = []
        for super_cls in cls.__mro__:
            if super_cls is cls:
                continue
            if hasattr(super_cls, "own_args"):
                args.extend(getattr(super_cls, "own_args")())
        return list(set(args))