1"""Collection of different extensions to keras framework."""
3__author__ = 'Lukas Leufen, Felix Kleinert'
4__date__ = '2020-01-31'
6import copy
7import logging
8import math
9import pickle
10from typing import Union, List
11from typing_extensions import TypedDict
12from time import time
14import numpy as np
15from tensorflow.keras import backend as K
16from tensorflow.keras.callbacks import History, ModelCheckpoint, Callback
18from mlair import helpers


class HistoryAdvanced(History):
    """
    This is an almost identical clone of the original History class.

    The only difference is that the attributes epoch and history are instantiated during the init phase and not during
    on_train_begin. This is required to resume an already started but interrupted training from a saved state. The
    HistoryAdvanced callback needs to be added separately as an additional callback. To get the full history, use this
    object for further steps instead of the default return value of training methods like fit_generator().

    .. code-block:: python

        hist = HistoryAdvanced()
        history = model.fit_generator(generator=.... , callbacks=[hist])
        history = hist

    If training is started from the beginning, this class is identical to the returned history object.
    """

    def __init__(self):
        """Set up HistoryAdvanced."""
        self.epoch = []
        self.history = {}
        super().__init__()

    def on_train_begin(self, logs=None):
        """Overload on_train_begin method to do nothing instead of resetting epoch and history."""
        pass
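
# A minimal sketch (not part of the original module) of resuming an interrupted training with HistoryAdvanced.
# The callback is pickled alongside the model and restored later, so hist.epoch and hist.history keep the
# pre-interruption entries; file names and the fit arguments are hypothetical. The model reference is stripped
# before pickling, mirroring what ModelCheckpointAdvanced does below.
#
#     hist = HistoryAdvanced()
#     model.fit(x, y, epochs=10, callbacks=[hist])
#     c = copy.copy(hist)
#     c.model = None  # remove the unpicklable model reference
#     with open("hist.pickle", "wb") as f:
#         pickle.dump(c, f)
#
#     # later: reload and resume; on_train_begin will not reset epoch/history
#     with open("hist.pickle", "rb") as f:
#         hist = pickle.load(f)
#     model.fit(x, y, epochs=20, initial_epoch=max(hist.epoch) + 1, callbacks=[hist])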


class LearningRateDecay(History):
    """
    Decay the learning rate during model training.

    Start with a base learning rate and lower this rate after every n (=epochs_drop) epochs by the factor drop from
    the interval (0, 1]; drop = 1 means no decay of the learning rate.

    :param base_lr: base learning rate to start with
    :param drop: ratio to drop after epochs_drop
    :param epochs_drop: number of epochs after which the drop takes place
    """

    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
        """Set up LearningRateDecay."""
        super().__init__()
        self.lr = {'lr': []}
        self.base_lr = self.check_param(base_lr, 'base_lr')
        self.drop = self.check_param(drop, 'drop')
        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
        self.epoch = []
        self.history = {}

    @staticmethod
    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
        """
        Check if the given value lies inside the interval.

        The left (lower) endpoint is open, the right (upper) endpoint is closed. To use only one side of the interval,
        set the other endpoint to None. If both endpoints are set to None, the value is returned without any check.

        :param value: value to check
        :param name: name of the variable to display in the error message
        :param lower: left (lower) endpoint of the interval, open
        :param upper: right (upper) endpoint of the interval, closed

        :return: unchanged value or raise ValueError
        """
        if lower is None:
            lower = -np.inf
        if upper is None:
            upper = np.inf
        if lower < value <= upper:
            return value
        else:
            raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
                             f"{name}={value}")

    def on_train_begin(self, logs=None):
        """Overload on_train_begin method to do nothing instead of resetting epoch and history."""
        pass

    def on_epoch_begin(self, epoch: int, logs=None):
        """
        Lower the learning rate every epochs_drop epochs by the factor drop.

        :param epoch: current epoch
        :param logs: not used here
        :return: the updated keras learning rate
        """
        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
        K.set_value(self.model.optimizer.lr, current_lr)
        self.lr['lr'].append(current_lr)
        logging.info(f"Set learning rate to {current_lr}")
        return K.get_value(self.model.optimizer.lr)
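
# A small sketch (not part of the original module) of the resulting step-decay schedule,
# lr = base_lr * drop ** floor(epoch / epochs_drop), for the defaults base_lr=0.01, drop=0.96, epochs_drop=8:
#
#     epochs  0-7 : 0.01
#     epochs  8-15: 0.01 * 0.96    = 0.0096
#     epochs 16-23: 0.01 * 0.96**2 = 0.009216
#
# Usage follows the plain keras callback pattern (x, y are placeholders):
#
#     lr_decay = LearningRateDecay(base_lr=0.01, drop=0.96, epochs_drop=8)
#     model.fit(x, y, epochs=24, callbacks=[lr_decay])
#     print(lr_decay.lr["lr"])  # learning rate used in each epoch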


class EpoTimingCallback(Callback):
    """Track the wall-clock duration of every training epoch in the epo_timing attribute."""

    def __init__(self):
        """Set up EpoTimingCallback."""
        self.epo_timing = {'epo_timing': []}
        self.logs = []
        self.starttime = None
        super().__init__()

    def on_epoch_begin(self, epoch: int, logs=None):
        """Store the start time of the current epoch."""
        self.starttime = time()

    def on_epoch_end(self, epoch: int, logs=None):
        """Append the elapsed time of the finished epoch to epo_timing."""
        self.epo_timing["epo_timing"].append(time() - self.starttime)
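
# Usage sketch (not part of the original module; x, y are placeholders):
#
#     epo_timing = EpoTimingCallback()
#     model.fit(x, y, epochs=5, callbacks=[epo_timing])
#     print(epo_timing.epo_timing["epo_timing"])  # list of seconds per epoch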


class ModelCheckpointAdvanced(ModelCheckpoint):
    """
    Enhance the standard ModelCheckpoint class by additionally saving given callbacks.

    **We recommend using CallbackHandler instead of ModelCheckpointAdvanced.** CallbackHandler will handle all your
    callbacks as well as the ModelCheckpointAdvanced and prevents pitfalls like a wrong ordering of callbacks.
    Internally, CallbackHandler makes use of ModelCheckpointAdvanced.

    However, if you want to use ModelCheckpointAdvanced explicitly, follow these instructions:

    .. code-block:: python

        # load your callbacks
        lr = CustomLearningRate()
        hist = CustomHistory()

        # set your callbacks with a list of dictionaries
        callbacks_name = "your_custom_path_%s.pickle"
        callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
                     {"callback": hist, "path": callbacks_name % "hist"}]
        # initialise ModelCheckpointAdvanced like the normal ModelCheckpoint (see keras callbacks)
        ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)

    Add ModelCheckpointAdvanced to the callback list like all other additional callbacks. IMPORTANT: Always add
    ModelCheckpointAdvanced as the last callback to properly update all tracked callbacks, e.g.

    .. code-block:: python

        # always add ModelCheckpointAdvanced as the last element
        fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])
    """

    def __init__(self, *args, **kwargs):
        """Initialise ModelCheckpointAdvanced and set the callbacks attribute."""
        self.callbacks = kwargs.pop("callbacks")
        self.epoch_best = None
        self.restore_best_weights = kwargs.pop("restore_best_weights", True)
        super().__init__(*args, **kwargs)

    def update_best(self, hist):
        """
        Update the internal best value when resuming a training process.

        If no best value is available, best is set to +/- inf depending on the performance metric, and the first model
        trained after resuming will always be saved as best model because its performance is better than infinity. To
        prevent this behaviour and compare the performance with the best model of the interrupted training, call this
        method before resuming the training process.

        :param hist: the History object from the previous (interrupted) training
        """
        if self.restore_best_weights:
            f = np.min if self.monitor_op.__name__ == "less" else np.max
            f_loc = lambda x: np.where(x == f(x))[0][-1]  # index of the last occurrence of the best value
            _d = hist.history.get(self.monitor)
            loc = f_loc(_d)
            assert f(_d) == _d[loc]
            self.epoch_best = loc
            self.best = _d[loc]
            logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
        else:
            _d = hist.history.get(self.monitor)[-1]
            self.best = _d
            logging.info(f"Set only best result ({self.monitor}={self.best}) without best epoch")
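
    # Resume sketch (not part of the original module): restore the pickled history, refresh the tracked
    # callbacks, and update the checkpoint's notion of "best" before continuing, so the first resumed epoch
    # does not silently overwrite the previous best model. File names are hypothetical.
    #
    #     with open("hist.pickle", "rb") as f:
    #         hist = pickle.load(f)
    #     ckpt_callbacks.update_callbacks([{"callback": hist, "path": "hist.pickle"}])
    #     ckpt_callbacks.update_best(hist)
    #     model.fit(x, y, epochs=20, initial_epoch=max(hist.epoch) + 1, callbacks=[hist, ckpt_callbacks])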

    def update_callbacks(self, callbacks):
        """
        Update all stored callback objects.

        The callbacks argument needs to follow the same convention as described in the class description (a list of
        dictionaries). This method must be run before resuming a training process.
        """
        self.callbacks = helpers.to_list(callbacks)

    def on_epoch_end(self, epoch, logs=None):
        """Save the model as usual (see ModelCheckpoint class), but also save the additional callbacks."""
        super().on_epoch_end(epoch, logs)

        for callback in self.callbacks:
            file_path = callback["path"]
            if self.epochs_since_last_save == 0 and epoch != 0:
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current == self.best:
                        if self.restore_best_weights:
                            self.epoch_best = epoch
                        if self.verbose > 0:  # pragma: no branch
                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                        with open(file_path, "wb") as f:
                            c = copy.copy(callback["callback"])
                            if hasattr(c, "model"):
                                c.model = None  # remove the unpicklable model reference before dumping
                            pickle.dump(c, f)
                else:
                    with open(file_path, "wb") as f:
                        if self.verbose > 0:  # pragma: no branch
                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                        c = copy.copy(callback["callback"])
                        if hasattr(c, "model"):
                            c.model = None  # remove the unpicklable model reference before dumping
                        pickle.dump(c, f)


clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
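
# Note (added for illustration, not part of the original module): the callback object itself is stored under a
# dynamic key equal to the value of "name", which TypedDict cannot express exactly. An entry for a history
# callback registered under the name "hist" would look like:
#
#     entry = {"name": "hist", "hist": History(), "path": "callbacks-hist.pickle"}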


class CallbackHandler:
    r"""Use the CallbackHandler for better control of custom callbacks.

    The callback handler will always keep your callbacks in the right order and adds a model checkpoint at the last
    position if required. You can add an arbitrary number of callbacks to the handler. First, add all callbacks and
    finally create the model checkpoint. Callbacks that are added after the checkpoint has been created would not be
    part of it. Therefore, the handler blocks adding new callbacks after the creation of the model checkpoint.

    .. code-block:: python

        # init callback handler
        callbacks = CallbackHandler()

        # set history object (add further elements like in this example)
        hist = keras.callbacks.History()
        callbacks.add_callback(hist, "callbacks-hist.pickle", "hist")

        # create advanced checkpoint (for details see ModelCheckpointAdvanced)
        ckpt_name = "model-best.h5"
        callbacks.create_model_checkpoint(filepath=ckpt_name, verbose=1, ...)

        # get checkpoint
        ckpt = callbacks.get_checkpoint()

        # fit the already compiled model and add the callbacks; it is important to call get_callbacks with
        # as_dict=False
        history = model.fit(..., callbacks=callbacks.get_callbacks(as_dict=False))

    If you want to continue a training, you can use the callback handler to load already stored callbacks. First, you
    need to reload all callbacks. Make sure that all callbacks from the previous training are available. If the
    callback handler was set up like in the former code example, this will work.

    .. code-block:: python

        # load callbacks and update checkpoint
        callbacks.load_callbacks()
        callbacks.update_checkpoint()

        # optional: load your model using the checkpoint path
        model = keras.models.load_model(ckpt.filepath)

        # extract history object and set starting epoch
        hist = callbacks.get_callback_by_name("hist")
        initial_epoch = max(hist.epoch) + 1

        # resume training (including initial_epoch) and use the callback handler's history object
        _ = model.fit(..., callbacks=callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch)
        history = hist

    Important note: Do not use the history object returned by model.fit, but use the history object from the callback
    handler. The fit history will only contain the new part, whereas the callback handler's history contains the full
    history including the resumed and the new part. For a correct epoch counting, you need to pass the initial epoch
    to the fit method as well.
    """

    def __init__(self):
        """Initialise CallbackHandler."""
        self.__callbacks: List[clbk_type] = []
        self._checkpoint = None
        self.editable = True

    @property
    def _callbacks(self):
        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]

    @_callbacks.setter
    def _callbacks(self, value):
        name, callback, callback_path = value
        self.__callbacks.append({"name": name, name: callback, "path": callback_path})

    def _update_callback(self, pos: int, value: Callback) -> None:
        """Update callback entry with given value."""
        name = self.__callbacks[pos]["name"]
        self.__callbacks[pos][name] = value

    def add_callback(self, callback: Callback, callback_path: str, name: str = "callback") -> None:
        """
        Add the given callback at the last position if the CallbackHandler is editable.

        The callback is stored under the given name. Raises a PermissionError if editable is False.

        :param callback: callback object to store
        :param callback_path: path to store the callback to
        :param name: name of the callback
        """
        if self.editable:
            self._callbacks = (name, callback, callback_path)
        else:
            raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")

    def get_callbacks(self, as_dict=True) -> Union[List[clbk_type], List[Callback]]:
        """
        Get all callbacks including the checkpoint at the last position.

        :param as_dict: set the return format, either as clbk_type dictionary structure (as_dict=True, default) or as
            a plain list of callback objects (as_dict=False)

        :return: all callbacks either as callback dictionary structure (embedded in a list) or as raw objects in a list
        """
        if as_dict:
            return self._get_callbacks()
        else:
            return [clb["callback"] for clb in self._get_callbacks()]

    def get_callback_by_name(self, obj_name: str) -> Union[Callback, History]:
        """
        Get a single callback by its name.

        Callbacks stored under the default name "callback" cannot be retrieved this way.

        :param obj_name: name of the callback to look for

        :return: the requested callback object
        """
        if obj_name != "callback":
            return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0]

    def _get_callbacks(self) -> List[clbk_type]:
        """Return all callbacks and append the checkpoint at the last position if available."""
        clbks = self._callbacks
        if self._checkpoint is not None:
            clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
        return clbks

    def get_checkpoint(self) -> ModelCheckpointAdvanced:
        """Return the current checkpoint if available."""
        if self._checkpoint is not None:
            return self._checkpoint

    def create_model_checkpoint(self, **kwargs):
        """Create a model checkpoint and disable further editing of the handler."""
        self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
        self.editable = False

    def load_callbacks(self) -> None:
        """Load all callbacks from their paths and store them in the callback attribute."""
        for pos, callback in enumerate(self.__callbacks):
            path = callback["path"]
            clb = pickle.load(open(path, "rb"))
            if clb.model is None and hasattr(self._checkpoint, "model"):
                clb.model = self._checkpoint.model  # restore the model reference that was stripped before pickling
            self._update_callback(pos, clb)

    def update_checkpoint(self, history_name: str = "hist") -> None:
        """
        Update the callbacks and the best values tracked by the checkpoint.

        :param history_name: name of the history object
        """
        self._checkpoint.update_callbacks(self._callbacks)
        self._checkpoint.update_best(self.get_callback_by_name(history_name))