:py:mod:`mlair.run_modules.training`
====================================

.. py:module:: mlair.run_modules.training

.. autoapi-nested-parse::

   Training module.


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   mlair.run_modules.training.Training


Attributes
~~~~~~~~~~

.. autoapisummary::

   mlair.run_modules.training.__author__
   mlair.run_modules.training.__date__


.. py:data:: __author__
   :annotation: = Lukas Leufen, Felix Kleinert

.. py:data:: __date__
   :annotation: = 2019-12-05

.. py:class:: Training

   Bases: :py:obj:`mlair.run_modules.run_environment.RunEnvironment`

   Train your model with this module.

   This module is not required if only a fresh post-processing is performed. Either remove the
   training call from your run script or set both `create_new_model` and `train_model` to false.

   Schedule of training:

   #. set_generators(): set generators for the training, validation, and test subsets and
      distribute them according to batch size
   #. make_predict_function(): create the predict function before distribution on multiple nodes
      (detailed information in the method description)
   #. train(): start or resume training of the model and save callbacks
   #. save_model(): save the best model from training as the final model

   Required objects [scope] from data store:

   * `model` [model]
   * `batch_size` [.]
   * `epochs` [.]
   * `callbacks` [model]
   * `model_name` [model]
   * `experiment_name` [.]
   * `experiment_path` [.]
   * `train_model` [.]
   * `create_new_model` [.]
   * `generator` [train, val, test]
   * `plot_path` [.]

   Optional objects

   * `permute_data` [train, val, test]
   * `upsampling` [train, val, test]

   Sets

   * `model` [.]

   Creates

   * `_model-best.h5`
   * `_model-best-callbacks-.h5` (all callbacks from CallbackHandler)
   * `history.json`
   * `history_lr.json` (optional)
   * `_history_.pdf` (different monitoring plots depending on loss metrics and callbacks)

   .. py:method:: _run(self) -> None

      Run training. Details in class description.

   .. py:method:: make_predict_function(self) -> None

      Create predict function.

      Must be called before distributing. This is necessary because tf compiles the predict
      function only at the moment it is first used, which can cause problems if the model is
      distributed across different workers. To prevent this, the function is pre-compiled. See
      discussion @ https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252

   .. py:method:: _set_gen(self, mode: str) -> None

      Set and distribute the generator for the given mode regarding batch size.

      :param mode: name of set, should be from ["train", "val", "test"]

   .. py:method:: set_generators(self) -> None

      Set all generators for the training, validation, and testing subsets.

      The called sub-method will automatically distribute the data according to the batch size.
      The subsets can be accessed as class variables train_set, val_set, and test_set.

   .. py:method:: train(self) -> None

      Perform training using keras fit().

      Callbacks are stored locally in the experiment directory. The best model from training is
      saved to the class variable model. If the checkpoint file path is not empty, this method
      assumes that this is not a new training starting from the very beginning, but the
      resumption of a previously started but interrupted training (or a stopped and now continued
      training). Train will automatically load the locally stored information and the
      corresponding model and proceed with the already started training.

   .. py:method:: save_model(self) -> None

      Save model in local experiment directory. Model is named as `_.h5`.

   .. py:method:: save_callbacks_as_json(self, history: tensorflow.keras.callbacks.Callback, lr_sc: tensorflow.keras.callbacks.Callback, epo_timing: tensorflow.keras.callbacks.Callback) -> None

      Save callbacks (history, learning rate) of training.

      * history.history -> history.json
      * lr_sc.lr -> history_lr.json

      :param history: history object of training
      :param lr_sc: learning rate object

   .. py:method:: create_monitoring_plots(self, history: tensorflow.keras.callbacks.Callback, lr_sc: tensorflow.keras.callbacks.Callback, epoch_best: int = None) -> None

      Create plots of history and learning rate as a function of the number of epochs.

      The plots are saved in the experiment's plot_path. The history plot is named
      `_history_loss_val_loss.pdf`, the learning rate plot `_history_learning_rate.pdf`.

      :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`)
      :param lr_sc: learning rate decay object with 'lr' attribute
      :param epoch_best: number of best epoch (counting starts at 0)

   .. py:method:: report_training(self)
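The `history.history -> history.json` mapping performed by `save_callbacks_as_json` can be sketched without tensorflow: a keras `History` callback exposes `.history` as a plain dict of per-epoch lists, so persisting it reduces to a JSON dump. The function below is a simplified stand-in, not MLAir's actual implementation; its name, signature, and the flat-list learning-rate input are assumptions for illustration.

.. code-block:: python

   import json
   from pathlib import Path

   def save_callbacks_as_json(history_dict, lr_list, experiment_path):
       """Sketch of the serialization in Training.save_callbacks_as_json.

       history_dict: mapping like keras' History.history,
           e.g. {"loss": [...], "val_loss": [...]}
       lr_list: per-epoch learning rates, like a learning rate callback's `lr`
       """
       path = Path(experiment_path)
       # history.history -> history.json
       (path / "history.json").write_text(json.dumps(history_dict))
       # lr_sc.lr -> history_lr.json
       (path / "history_lr.json").write_text(json.dumps(lr_list))

   # usage with dummy training results
   import tempfile
   with tempfile.TemporaryDirectory() as tmp:
       save_callbacks_as_json(
           {"loss": [0.9, 0.5, 0.3], "val_loss": [1.0, 0.6, 0.4]},
           [0.01, 0.01, 0.005],
           tmp,
       )
       restored = json.loads((Path(tmp) / "history.json").read_text())
       print(restored["val_loss"])  # the stored validation losses

Because JSON round-trips dicts of numeric lists losslessly, the files written this way can later be reloaded to resume an interrupted training or to feed `create_monitoring_plots`.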