:py:mod:`mlair.model_modules.keras_extensions`
==============================================

.. py:module:: mlair.model_modules.keras_extensions

.. autoapi-nested-parse::

   Collection of different extensions to the keras framework.


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   mlair.model_modules.keras_extensions.HistoryAdvanced
   mlair.model_modules.keras_extensions.LearningRateDecay
   mlair.model_modules.keras_extensions.EpoTimingCallback
   mlair.model_modules.keras_extensions.ModelCheckpointAdvanced
   mlair.model_modules.keras_extensions.CallbackHandler


Attributes
~~~~~~~~~~

.. autoapisummary::

   mlair.model_modules.keras_extensions.__author__
   mlair.model_modules.keras_extensions.__date__
   mlair.model_modules.keras_extensions.clbk_type


.. py:data:: __author__
   :annotation: = Lukas Leufen, Felix Kleinert


.. py:data:: __date__
   :annotation: = 2020-01-31


.. py:class:: HistoryAdvanced

   Bases: :py:obj:`tensorflow.keras.callbacks.History`

   This is almost an identical clone of the original History class.

   The only difference is that the attributes epoch and history are instantiated during the init phase and not during
   on_train_begin. This is required to resume an already started but disrupted training from a saved state. This
   HistoryAdvanced callback needs to be added separately as an additional callback. To get the full history, use this
   object for further steps instead of the default return of training methods like fit_generator().

   .. code-block:: python

       hist = HistoryAdvanced()
       history = model.fit_generator(generator=.... , callbacks=[hist])
       history = hist

   If the training was started from the beginning, this class is identical to the returned history class object.

   .. py:method:: on_train_begin(self, logs=None)

      Overload on_train_begin method to do nothing instead of resetting epoch and history.


.. py:class:: LearningRateDecay(base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8)

   Bases: :py:obj:`tensorflow.keras.callbacks.History`

   Decay learning rate during model training.

   Start with a base learning rate and lower this rate after every n (=epochs_drop) epochs by the factor drop, taken
   from the interval (0, 1]; drop = 1 means no decay in learning rate.

   :param base_lr: base learning rate to start with
   :param drop: ratio to drop after epochs_drop
   :param epochs_drop: number of epochs after which the drop takes place

   .. py:method:: check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1)
      :staticmethod:

      Check if given value is in interval.

      The left (lower) endpoint is open, the right (upper) endpoint is closed. To use only one side of the interval,
      set the other endpoint to None. If both ends are set to None, just return the value without any check.

      :param value: value to check
      :param name: name of the variable to display in error message
      :param lower: left (lower) endpoint of interval, opened
      :param upper: right (upper) endpoint of interval, closed
      :return: unchanged value or raise ValueError

   .. py:method:: on_train_begin(self, logs=None)

      Overload on_train_begin method to do nothing instead of resetting epoch and history.

   .. py:method:: on_epoch_begin(self, epoch: int, logs=None)

      Lower learning rate every epochs_drop epochs by factor drop.

      :param epoch: current epoch
      :param logs: ?
      :return: update keras learning rate
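   A minimal usage sketch, assuming the step decay described above (the effective rate presumably follows
   ``base_lr * drop ** (epoch // epochs_drop)``); the model and data below are placeholders only:

   .. code-block:: python

       import numpy as np
       import tensorflow as tf
       from mlair.model_modules.keras_extensions import LearningRateDecay

       # halve the learning rate every 10 epochs; drop must lie in (0, 1]
       lr_decay = LearningRateDecay(base_lr=0.01, drop=0.5, epochs_drop=10)

       # placeholder model and data, only to make the sketch self-contained
       model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
       model.compile(optimizer=tf.keras.optimizers.SGD(), loss="mse")
       x, y = np.random.rand(32, 1), np.random.rand(32, 1)

       # epochs 0-9 presumably run with lr=0.01, 10-19 with 0.005, 20-29 with 0.0025
       model.fit(x, y, epochs=30, callbacks=[lr_decay])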
.. py:class:: EpoTimingCallback

   Bases: :py:obj:`tensorflow.keras.callbacks.Callback`

   Abstract base class used to build new callbacks.

   Callbacks can be passed to keras methods such as `fit`, `evaluate`, and `predict` in order to hook into the various
   stages of the model training and inference lifecycle. To create a custom callback, subclass
   `keras.callbacks.Callback` and override the method associated with the stage of interest. See
   https://www.tensorflow.org/guide/keras/custom_callback for more information.

   Example:

   >>> training_finished = False
   >>> class MyCallback(tf.keras.callbacks.Callback):
   ...   def on_train_end(self, logs=None):
   ...     global training_finished
   ...     training_finished = True
   >>> model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
   >>> model.compile(loss='mean_squared_error')
   >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
   ...           callbacks=[MyCallback()])
   >>> assert training_finished == True

   If you want to use `Callback` objects in a custom training loop:

   1. You should pack all your callbacks into a single `callbacks.CallbackList` so they can all be called together.
   2. You will need to manually call all the `on_*` methods at the appropriate locations in your loop. Like this:

   ```
   callbacks = tf.keras.callbacks.CallbackList([...])
   callbacks.append(...)

   callbacks.on_train_begin(...)
   for epoch in range(EPOCHS):
     callbacks.on_epoch_begin(epoch)
     for i, data in dataset.enumerate():
       callbacks.on_train_batch_begin(i)
       batch_logs = model.train_step(data)
       callbacks.on_train_batch_end(i, batch_logs)
     epoch_logs = ...
     callbacks.on_epoch_end(epoch, epoch_logs)
   final_logs = ...
   callbacks.on_train_end(final_logs)
   ```

   .. attribute:: params

      Dict. Training parameters (e.g. verbosity, batch size, number of epochs...).

   .. attribute:: model

      Instance of `keras.models.Model`. Reference of the model being trained.

   The `logs` dictionary that callback methods take as argument will contain keys for quantities relevant to the
   current batch or epoch (see method-specific docstrings).

   .. py:method:: on_epoch_begin(self, epoch: int, logs=None)

      Called at the start of an epoch.

      Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

      :param epoch: Integer, index of epoch.
      :param logs: Dict. Currently no data is passed to this argument for this method but that may change in the
                   future.

   .. py:method:: on_epoch_end(self, epoch: int, logs=None)

      Called at the end of an epoch.

      Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

      :param epoch: Integer, index of epoch.
      :param logs: Dict, metric results for this training epoch, and for the validation epoch if validation is
                   performed. Validation result keys are prefixed with `val_`. For the training epoch, the values of
                   the `Model`'s metrics are returned. Example: `{'loss': 0.2, 'accuracy': 0.7}`.


.. py:class:: ModelCheckpointAdvanced(*args, **kwargs)

   Bases: :py:obj:`tensorflow.keras.callbacks.ModelCheckpoint`

   Enhance the standard ModelCheckpoint class by additional saves of given callbacks.

   **We recommend using CallbackHandler instead of ModelCheckpointAdvanced.** CallbackHandler will handle all your
   callbacks and the ModelCheckpointAdvanced and prevents you from pitfalls like a wrong ordering of callbacks.
   Actually, CallbackHandler makes use of ModelCheckpointAdvanced.

   However, if you want to use the ModelCheckpointAdvanced explicitly, follow these instructions:

   .. code-block:: python

       # load your callbacks
       lr = CustomLearningRate()
       hist = CustomHistory()

       # set your callbacks with a list dictionary structure
       callbacks_name = "your_custom_path_%s.pickle"
       callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
                    {"callback": hist, "path": callbacks_name % "hist"}]

       # initialise ModelCheckpointAdvanced like the normal ModelCheckpoint (see keras callbacks)
       ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)

   Add ModelCheckpointAdvanced like all other additional callbacks to the callback list. IMPORTANT: Always add
   ModelCheckpointAdvanced as last callback to properly update all tracked callbacks, e.g.

   .. code-block:: python

       # always add ModelCheckpointAdvanced as last element
       fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])

   .. py:method:: update_best(self, hist)

      Update internal best on resuming a training process.

      If no best object is available, best is set to +/- inf depending on the performance metric, and the first
      trained model (the first of the resumed training process) will always be saved as best model because its
      performance will be better than infinity. To prevent this behaviour and compare the performance with the best
      model performance, call this method before resuming the training process.

      :param hist: The History object from the previous (interrupted) training.

   .. py:method:: update_callbacks(self, callbacks)

      Update all stored callback objects.

      The argument callbacks needs to follow the same convention as described in the class description (list of
      dictionaries). Must be run before resuming a training process.

   .. py:method:: on_epoch_end(self, epoch, logs=None)

      Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
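   Resuming an interrupted training with ModelCheckpointAdvanced alone therefore requires restoring the stored
   callbacks and the best score by hand. The sketch below continues the example above and assumes the tracked
   callbacks were stored as pickle files (as the ``.pickle`` paths suggest); the fit call itself is a placeholder:

   .. code-block:: python

       import pickle

       # assumption: the callback objects were written as pickle files by ModelCheckpointAdvanced
       with open(callbacks_name % "lr", "rb") as f:
           lr = pickle.load(f)
       with open(callbacks_name % "hist", "rb") as f:
           hist = pickle.load(f)

       # hand the reloaded objects back to the checkpoint and restore its best score
       callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
                    {"callback": hist, "path": callbacks_name % "hist"}]
       ckpt_callbacks.update_callbacks(callbacks)
       ckpt_callbacks.update_best(hist)

       # resume training, again with the checkpoint as last callback
       fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks], initial_epoch=max(hist.epoch) + 1)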
.. py:data:: clbk_type


.. py:class:: CallbackHandler

   Use the CallbackHandler for better control of custom callbacks.

   The callback handler will always keep your callbacks in the right order and adds a model checkpoint at the last
   position if required. You can add an arbitrary number of callbacks to the handler. First, add all callbacks and
   finally create the model checkpoint. Callbacks that are added after the checkpoint has been created would not be
   part of it. Therefore, the handler blocks adding new callbacks after the creation of the model checkpoint.

   .. code-block:: python

       # init callbacks handler
       callbacks = CallbackHandler()

       # set history object (add further elements like this example)
       hist = keras.callbacks.History()
       callbacks.add_callback(hist, "callbacks-hist.pickle", "hist")

       # create advanced checkpoint (details see ModelCheckpointAdvanced)
       ckpt_name = "model-best.h5"
       callbacks.create_model_checkpoint(filepath=ckpt_name, verbose=1, ...)

       # get checkpoint
       ckpt = callbacks.get_checkpoint()

       # fit already compiled model and add callbacks, it is important to call get_callbacks with as_dict=False
       history = model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False))

   If you want to continue a training, you can use the callback handler to load already stored callbacks. First you
   need to reload all callbacks. Make sure that all callbacks are available from the previous training. If the
   callback handler was set up like in the former code example, this will work.

   .. code-block:: python

       # load callbacks and update checkpoint
       callbacks.load_callbacks()
       callbacks.update_checkpoint()

       # optional: load your model using checkpoint path
       model = keras.models.load_model(ckpt.filepath)

       # extract history object and set starting epoch
       hist = callbacks.get_callback_by_name("hist")
       initial_epoch = max(hist.epoch) + 1

       # resume training (including initial_epoch) and use callback handler's history object
       _ = self.model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False),
                          initial_epoch=initial_epoch)
       history = hist

   Important notes: Do not use the returned history object of model.fit, but use the history object from the callback
   handler. The fit history will only contain the new history, whereas the callback handler's history contains the
   full history including the resumed and the new history. For a correct epoch counting, you need to pass the initial
   epoch to the fit method too.

   .. py:method:: _callbacks(self)
      :property:

   .. py:method:: _update_callback(self, pos: int, value: tensorflow.keras.callbacks.Callback) -> None

      Update callback entry with given value.

   .. py:method:: add_callback(self, callback: tensorflow.keras.callbacks.Callback, callback_path: str, name: str = 'callback') -> None

      Add given callback on last position if CallbackHandler is editable. Save callback with given name. Will raise
      a PermissionError if editable is False.

      :param callback: callback object to store
      :param callback_path: path to callback
      :param name: name of the callback

   .. py:method:: get_callbacks(self, as_dict=True) -> Union[List[clbk_type], List[tensorflow.keras.callbacks.Callback]]

      Get all callbacks including checkpoint on last position.

      :param as_dict: set return format, either clbk_type with dictionary structure (as_dict=True, default) or list
      :return: all callbacks either as callback dictionary structure (embedded in a list) or as raw objects in a list

   .. py:method:: get_callback_by_name(self, obj_name: str) -> Union[tensorflow.keras.callbacks.Callback, tensorflow.keras.callbacks.History]

      Get single callback by its name.

      :param obj_name: name of callback to look for
      :return: requested callback object

   .. py:method:: _get_callbacks(self) -> List[clbk_type]

      Return all callbacks and append checkpoint if available on last position.

   .. py:method:: get_checkpoint(self) -> ModelCheckpointAdvanced

      Return current checkpoint if available.

   .. py:method:: create_model_checkpoint(self, **kwargs)

      Create a model checkpoint and enable edit.

   .. py:method:: load_callbacks(self) -> None

      Load callbacks from path and save in callback attribute.

   .. py:method:: update_checkpoint(self, history_name: str = 'hist') -> None

      Update callbacks and history's best elements.

      :param history_name: name of history object
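   The two return formats of ``get_callbacks`` differ only in their wrapping. A short sketch, continuing the handler
   from the example above; the dictionary entries follow ``clbk_type`` and presumably bundle each callback with the
   path and name given to ``add_callback``:

   .. code-block:: python

       # as_dict=True (default): a list of clbk_type dictionaries, one per callback plus the checkpoint
       callback_dicts = callbacks.get_callbacks()

       # as_dict=False: only the raw callback objects, the format expected by model.fit(callbacks=...)
       raw_callbacks = callbacks.get_callbacks(as_dict=False)
       history = model.fit(..., callbacks=raw_callbacks)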