Coverage for mlair/model_modules/abstract_model_class.py: 80%

103 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-12-18 17:51 +0000

1import inspect 

2from abc import ABC 

3from typing import Any, Dict, Callable 

4 

5import tensorflow.keras as keras 

6import tensorflow as tf 

7 

8from mlair.helpers import remove_items, make_keras_pickable 

9 

10 

class AbstractModelClass(ABC):
    """
    The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow.

    The model can always be accessed by calling ModelClass.model or directly by an model method without parsing the
    model attribute name (e.g. ModelClass.model.compile -> ModelClass.compile). Beside the model, this class provides
    the corresponding loss function.
    """

    # argument names collected by requirements(); subclasses may extend this list
    _requirements = []

    def __init__(self, input_shape, output_shape) -> None:
        """Predefine internal attributes for model and loss.

        :param input_shape: shape of the model input (stored unmodified)
        :param output_shape: shape of the model output; a 1-tuple is unwrapped to its single element
        """
        make_keras_pickable()
        self.__model = None
        self.model_name = self.__class__.__name__
        self.__custom_objects = {}
        self.__allowed_compile_options = {'optimizer': None,
                                          'loss': None,
                                          'metrics': None,
                                          'loss_weights': None,
                                          'sample_weight_mode': None,
                                          'weighted_metrics': None,
                                          'target_tensors': None
                                          }
        # copy so that writes into __compile_options (done by the compile_options setter)
        # cannot mutate the reference dict of allowed option keys; previously both names
        # aliased the same dict object
        self.__compile_options = dict(self.__allowed_compile_options)
        self.__compile_options_is_set = False
        self._input_shape = input_shape
        self._output_shape = self.__extract_from_tuple(output_shape)

    def load_model(self, name: str, compile: bool = False) -> None:
        """Load model weights from file `name`, preserving the current training history.

        :param name: path of the weights file passed to keras' load_weights
        :param compile: if True, re-compile the model with self.compile_options afterwards
        """
        # load_weights replaces model state; keep the history object across the reload
        hist = self.model.history
        self.model.load_weights(name)
        self.model.history = hist
        if compile is True:
            self.model.compile(**self.compile_options)

    def __getattr__(self, name: str) -> Any:
        """
        Is called if __getattribute__ is not able to find requested attribute.

        Normally, the model class is saved into a variable like `model = ModelClass()`. To bypass a call like
        `model.model` to access the _model attribute, this method tries to search for the named attribute in the
        self.model namespace and returns this attribute if available. Therefore, following expression is true:
        `ModelClass().compile == ModelClass().model.compile` as long the called attribute/method is not part if the
        ModelClass itself.

        :param name: name of the attribute or method to call

        :return: attribute or method from self.model namespace
        """
        return self.model.__getattribute__(name)

    @property
    def model(self) -> "keras.Model":
        """
        The model property containing a keras.Model instance.

        :return: the keras model
        """
        return self.__model

    @model.setter
    def model(self, value):
        self.__model = value

    @property
    def custom_objects(self) -> Dict:
        """
        The custom objects property collects all non-keras utilities that are used in the model class.

        To load such a customised and already compiled model (e.g. from local disk), this information is required.

        :return: custom objects in a dictionary
        """
        return self.__custom_objects

    @custom_objects.setter
    def custom_objects(self, value) -> None:
        self.__custom_objects = value

    @property
    def compile_options(self) -> Dict:
        """
        The compile options property allows the user to use all keras.compile() arguments. They can ether be passed as
        dictionary (1), as attribute, without setting compile_options (2) or as mixture (partly defined as instance
        attributes and partly parsing a dictionary) of both of them (3).
        The method will raise an Error when the same parameter is set differently.

        Example (1) Recommended (includes check for valid keywords which are used as args in keras.compile)
        .. code-block:: python
            def set_compile_options(self):
                self.compile_options = {"optimizer": keras.optimizers.SGD(),
                                        "loss": keras.losses.mean_squared_error,
                                        "metrics": ["mse", "mae"]}

        Example (2)
        .. code-block:: python
            def set_compile_options(self):
                self.optimizer = keras.optimizers.SGD()
                self.loss = keras.losses.mean_squared_error
                self.metrics = ["mse", "mae"]

        Example (3)
        Correct:
        .. code-block:: python
            def set_compile_options(self):
                self.optimizer = keras.optimizers.SGD()
                self.loss = keras.losses.mean_squared_error
                self.compile_options = {"metrics": ["mse", "mae"]}

        Incorrect: (Will raise an error)
        .. code-block:: python
            def set_compile_options(self):
                self.optimizer = keras.optimizers.SGD()
                self.loss = keras.losses.mean_squared_error
                self.compile_options = {"optimizer": keras.optimizers.Adam(), "metrics": ["mse", "mae"]}

        Note:
        * As long as the attribute and the dict value have exactly the same values, the setter method will not raise
          an error
        * For example (2) there is no check implemented, if the attributes are valid compile options

        :return:
        """
        # lazily run the setter once so attribute-style options (example 2) are collected
        if self.__compile_options_is_set is False:
            self.compile_options = None
        return self.__compile_options

    @compile_options.setter
    def compile_options(self, value: Dict) -> None:
        if isinstance(value, dict):
            if not (set(value.keys()) <= set(self.__allowed_compile_options.keys())):
                raise ValueError(f"Got invalid key for compile_options. {value.keys()}")
            # work on a copy so the caller's dict is not drained by the pop() calls below
            value = dict(value)

        for allow_k in self.__allowed_compile_options.keys():
            # candidate from an instance attribute of the same name (example 2)
            if hasattr(self, allow_k):
                new_v_attr = getattr(self, allow_k)
                if new_v_attr == list():
                    # treat an empty list attribute as "not set"
                    new_v_attr = None
            else:
                new_v_attr = None
            # candidate from the passed dictionary (example 1)
            if isinstance(value, dict):
                new_v_dic = value.pop(allow_k, None)
            elif value is None:
                new_v_dic = None
            else:
                raise TypeError(f"`compile_options' must be `dict' or `None', but is {type(value)}.")
            ## self.__compare_keras_optimizers() foremost disabled, because it does not work as expected
            #if (new_v_attr == new_v_dic or self.__compare_keras_optimizers(new_v_attr, new_v_dic)) or (
            #        (new_v_attr is None) ^ (new_v_dic is None)):
            # accept when both sources agree, or when exactly one of them provides a value
            if (new_v_attr == new_v_dic) or ((new_v_attr is None) ^ (new_v_dic is None)):
                if new_v_attr is not None:
                    self.__compile_options[allow_k] = new_v_attr
                else:
                    self.__compile_options[allow_k] = new_v_dic
            else:
                raise ValueError(
                    f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}")
        self.__compile_options_is_set = True

    @staticmethod
    def __extract_from_tuple(tup):
        """Return element of tuple if it contains only a single element."""
        return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup

    @staticmethod
    def __compare_keras_optimizers(first, second):
        """
        Compares if optimiser and all settings of the optimisers are exactly equal.

        NOTE(review): currently disabled at the call site in the compile_options setter because it
        does not work as expected.

        :return True if optimisers are interchangeable, or False if optimisers are distinguishable.
        """
        # bugfix: original read `isinstance(list, type(second))`, which tests whether the
        # class `list` is an instance of type(second) (true for any class argument) instead
        # of whether `second` is a list
        if isinstance(second, list):
            res = False
        else:
            if first.__class__ == second.__class__ and '.'.join(
                    first.__module__.split('.')[0:4]) == 'tensorflow.python.keras.optimizer_v2':
                res = True
                init = tf.compat.v1.global_variables_initializer()
                with tf.compat.v1.Session() as sess:
                    sess.run(init)
                    for k, v in first.__dict__.items():
                        try:
                            # multiply into res so any mismatching setting flips it to falsy
                            res *= sess.run(v) == sess.run(second.__dict__[k])
                        except TypeError:
                            res *= v == second.__dict__[k]
            else:
                res = False
        return bool(res)

    def get_settings(self) -> Dict:
        """
        Get all class attributes that are not protected in the AbstractModelClass as dictionary.

        :return: all class attributes
        """
        return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__"))

    def set_model(self):
        """Abstract method to set model."""
        raise NotImplementedError

    def set_compile_options(self):
        """
        This method only has to be defined in child class, when additional compile options should be used ()
        (other options than optimizer and loss)
        Has to be set as dictionary: {'optimizer': None,
                                      'loss': None,
                                      'metrics': None,
                                      'loss_weights': None,
                                      'sample_weight_mode': None,
                                      'weighted_metrics': None,
                                      'target_tensors': None
                                      }

        :return:
        """
        raise NotImplementedError

    def set_custom_objects(self, **kwargs) -> None:
        """
        Set custom objects that are not part of keras framework.

        These custom objects are needed if an already compiled model is loaded from disk. There is a special treatment
        for the Padding2D class, which is a base class for different padding types. For a correct behaviour, all
        supported subclasses are added as custom objects in addition to the given ones.

        :param kwargs: all custom objects, that should be saved
        """
        if "Padding2D" in kwargs.keys():
            kwargs.update(kwargs["Padding2D"].allowed_paddings)
        self.custom_objects = kwargs

    @classmethod
    def requirements(cls):
        """Return requirements and own arguments without duplicates."""
        return list(set(cls._requirements + cls.own_args()))

    @classmethod
    def own_args(cls, *args):
        """Return all arguments (including kwonlyargs)."""
        arg_spec = inspect.getfullargspec(cls)
        list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args()
        return list(set(remove_items(list_of_args, ["self"] + list(args))))

    @classmethod
    def super_args(cls):
        """Collect own_args() from all superclasses in the MRO (excluding cls itself)."""
        args = []
        for super_cls in cls.__mro__:
            if super_cls == cls:
                continue
            if hasattr(super_cls, "own_args"):
                # args.extend(super_cls.own_args())
                args.extend(getattr(super_cls, "own_args")())
        return list(set(args))