Coverage for mlair/run_modules/training.py: 92%
109 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
1"""Training module."""
3__author__ = "Lukas Leufen, Felix Kleinert"
4__date__ = '2019-12-05'
6import json
7import logging
8import os
9from typing import Union
11import tensorflow.keras as keras
12from tensorflow.keras.callbacks import Callback, History
13import psutil
14import pandas as pd
16from mlair.data_handler import KerasIterator
17from mlair.model_modules import AbstractModelClass
18from mlair.model_modules.keras_extensions import CallbackHandler
19from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
20from mlair.run_modules.run_environment import RunEnvironment
21from mlair.configuration import path_config
22from mlair.helpers import to_list, tables, TimeTrackingWrapper
class Training(RunEnvironment):
    """
    Train your model with this module.

    This module isn't required to run, if only a fresh post-processing is performed. Either remove training call from
    your run script or set create_new_model and train_model both to false.

    Schedule of training (all steps after the first one are only executed if `train_model` is true):
        #. make_predict_function(): create predict function before distribution on multiple nodes (detailed
           information in method description)
        #. set_generators(): set generators for training and validation and distribute according to batch size
        #. train(): start or resume training of model and save callbacks
        #. save_model(): save best model from training as final model
        #. report_training(): store training settings and validation scores

    Required objects [scope] from data store:
        * `model` [model]
        * `batch_size` [.]
        * `epochs` [.]
        * `callbacks` [model]
        * `model_name` [model]
        * `model_path` [.]
        * `experiment_name` [.]
        * `experiment_path` [.]
        * `train_model` [.]
        * `create_new_model` [.]
        * `data_collection` [train, val]
        * `plot_path` [.]

    Optional objects (collected per subset and forwarded to the KerasIterator)
        * `upsampling` [train, val]
        * `shuffle_batches` [train, val]
        * `batch_path` [train, val]
        * `use_multiprocessing` [train, val]
        * `max_number_multiprocessing` [train, val]

    Sets
        * `model` [.]

    Creates
        * `<exp_name>_model-best.h5`
        * `<exp_name>_model-best-callbacks-<name>.h5` (all callbacks from CallbackHandler)
        * `history.json`
        * `history_lr.json` (optional)
        * `epo_timing.json` (optional)
        * `val_scores.txt`
        * `latex_report/training_settings.tex` and `latex_report/training_settings.md`
        * `<exp_name>_history_<name>.pdf` (different monitoring plots depending on loss metrics and callbacks)

    """

    def __init__(self):
        """Set up and run training."""
        super().__init__()
        self.model: AbstractModelClass = self.data_store.get("model", "model")
        # Both iterators are created lazily in set_generators(); a test iterator is not set up by this module.
        self.train_set: Union[KerasIterator, None] = None
        self.val_set: Union[KerasIterator, None] = None
        self.batch_size = self.data_store.get("batch_size")
        self.epochs = self.data_store.get("epochs")
        self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model")
        self.experiment_name = self.data_store.get("experiment_name")
        self._train_model = self.data_store.get("train_model")
        self._create_new_model = self.data_store.get("create_new_model")
        self._run()

    def _run(self) -> None:
        """Run training. Details in class description."""
        self.make_predict_function()
        if self._train_model:
            self.set_generators()
            self.train()
            self.save_model()
            self.report_training()
        else:
            logging.info("No training has started, because train_model parameter was false.")

    def make_predict_function(self) -> None:
        """
        Create predict function.

        Must be called before distributing. This is necessary, because tf will compile the predict function just in
        the moment it is used the first time. This can cause problems, if the model is distributed on different
        workers. To prevent this, the function is pre-compiled. See discussion @
        https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
        """
        self.model.make_predict_function()

    @TimeTrackingWrapper
    def _set_gen(self, mode: str) -> None:
        """
        Set and distribute the generators for given mode regarding batch size.

        The iterator is stored as attribute `<mode>_set` on this instance.

        :param mode: name of set, e.g. "train" or "val"
        """
        collection = self.data_store.get("data_collection", mode)
        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path", "use_multiprocessing",
                                                   "max_number_multiprocessing"], scope=mode)
        setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, name=mode, **kwargs))

    @TimeTrackingWrapper
    def set_generators(self) -> None:
        """
        Set the generators for the training and validation subsets.

        The called sub-method will automatically distribute the data according to the batch size. The subsets can be
        accessed as instance attributes train_set and val_set. No generator is created for the test subset.
        """
        logging.info("set generators for training and validation")
        for mode in ["train", "val"]:
            self._set_gen(mode)

    def _fit(self, initial_epoch: int = 0) -> History:
        """
        Call keras fit() on the training set with validation on the validation set.

        :param initial_epoch: epoch at which to start training (0 for a fresh training, >0 to resume)

        :return: the object returned by the model's fit() call
        """
        # psutil.cpu_count(logical=False) returns None if the number of physical cores cannot be determined; fall
        # back to a single worker in that case instead of handing None to fit().
        workers = psutil.cpu_count(logical=False) or 1
        return self.model.fit(self.train_set,
                              steps_per_epoch=len(self.train_set),
                              epochs=self.epochs,
                              verbose=2,
                              validation_data=self.val_set,
                              validation_steps=len(self.val_set),
                              callbacks=self.callbacks.get_callbacks(as_dict=False),
                              initial_epoch=initial_epoch,
                              workers=workers)

    def _get_optional_callback(self, name: str) -> Union[Callback, None]:
        """
        Return the callback registered under `name` or None if the lookup raises an IndexError.

        :param name: name of the callback inside the callback handler
        """
        try:
            return self.callbacks.get_callback_by_name(name)
        except IndexError:
            return None

    def train(self) -> None:
        """
        Perform training using keras fit().

        Callbacks are stored locally in the experiment directory. Best model from training is saved for class
        variable model. If the file path of checkpoint is not empty, this method assumes, that this is not a new
        training starting from the very beginning, but a resumption from a previous started but interrupted training
        (or a stopped and now continued training). Train will automatically load the locally stored information and the
        corresponding model and proceed with the already started training. Setting `create_new_model` forces a fresh
        training even if a checkpoint file exists.
        """
        logging.info(f"Train with {len(self.train_set)} mini batches.")
        logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
        logging.info(f"Train with option shuffle={self.train_set.shuffle}.")

        checkpoint = self.callbacks.get_checkpoint()
        if not os.path.exists(checkpoint.filepath) or self._create_new_model:
            history = self._fit()
        else:
            logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
            self.callbacks.load_callbacks()
            self.callbacks.update_checkpoint()
            self.model.load_model(checkpoint.filepath, compile=True)
            hist: History = self.callbacks.get_callback_by_name("hist")
            # continue with the epoch following the last recorded one; start at 0 if nothing has been recorded yet
            initial_epoch = max(hist.epoch, default=-1) + 1
            self._fit(initial_epoch=initial_epoch)
            # the restored history callback is used for reporting instead of the return value of this fit() call
            history = hist
        epoch_best = checkpoint.epoch_best
        if epoch_best is not None:
            # epoch_best is counted from 0, report it 1-based
            logging.info(f"best epoch: {epoch_best + 1}")
        lr = self._get_optional_callback("lr")
        epo_timing = self._get_optional_callback("epo_timing")
        self.save_callbacks_as_json(history, lr, epo_timing)
        self.create_monitoring_plots(history, lr, epoch_best)

    def save_model(self) -> None:
        """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`."""
        model_name = self.data_store.get("model_name", "model")
        logging.debug(f"save model to {model_name}")
        self.model.save(model_name, save_format="h5")
        self.model.save(model_name, save_format="tf")
        self.data_store.set("model", self.model)

    def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
        """
        Save callbacks (history, learning rate, epoch timing) of training.

        * history.history -> history.json
        * lr_sc.lr -> history_lr.json (skipped if lr_sc is falsy)
        * epo_timing.epo_timing -> epo_timing.json (skipped if epo_timing is None)

        All files are written to the `model_path` taken from the data store.

        :param history: history object of training
        :param lr_sc: learning rate object
        :param epo_timing: epoch timing object
        """
        logging.debug("saving callbacks")
        path = self.data_store.get("model_path")
        with open(os.path.join(path, "history.json"), "w") as f:
            json.dump(history.history, f)
        if lr_sc:
            with open(os.path.join(path, "history_lr.json"), "w") as f:
                json.dump(lr_sc.lr, f)
        if epo_timing is not None:
            with open(os.path.join(path, "epo_timing.json"), "w") as f:
                json.dump(epo_timing.epo_timing, f)

    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None:
        """
        Create plot of history and learning rate in dependence of the number of epochs.

        The plots are saved in the experiment's plot_path. History plot is named `<exp_name>_history_loss.pdf`,
        the learning rate with `<exp_name>_history_learning_rate.pdf`. If the model has more than one output branch,
        `<exp_name>_history_main_loss.pdf` is created in addition. If the model tracks a mean squared error metric,
        `<exp_name>_history_main_mse.pdf` is created as well.

        :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`)
        :param lr_sc: learning rate decay object with 'lr' attribute
        :param epoch_best: number of best epoch (starts counting as 0)
        """
        path = self.data_store.get("plot_path")
        name = self.data_store.get("experiment_name")

        # plot history of loss and mse (if available)
        filename = os.path.join(path, f"{name}_history_loss.pdf")
        PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best)
        multiple_branches_used = len(history.model.output_names) > 1  # means that there are multiple output branches
        if multiple_branches_used:
            filename = os.path.join(path, f"{name}_history_main_loss.pdf")
            PlotModelHistory(filename=filename, history=history, main_branch=True, epoch_best=epoch_best)
        # use a fixed preference order so that the selected metric does not depend on set ordering
        metrics_names = history.model.metrics_names
        mse_indicator = next((metric for metric in ("mean_squared_error", "mse") if metric in metrics_names), None)
        if mse_indicator is not None:
            filename = os.path.join(path, f"{name}_history_main_mse.pdf")
            PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator,
                             main_branch=multiple_branches_used, epoch_best=epoch_best)

        # plot learning rate
        if lr_sc:
            PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)

    def report_training(self) -> None:
        """
        Store a summary of the training settings and the validation scores.

        The settings are written as `training_settings.tex` and `training_settings.md` into the folder `latex_report`
        inside the experiment path. The validation scores returned by the model's evaluate() on the validation set
        are logged and appended to `val_scores.txt` inside the model path.
        """
        # create training summary
        data = {"mini batches": len(self.train_set),
                "upsampling extremes": self.train_set.upsampling,
                "shuffling": self.train_set.shuffle,
                "created new model": self._create_new_model,
                "epochs": self.epochs,
                "batch size": self.batch_size}
        df = pd.DataFrame.from_dict(data, orient="index", columns=["training setting"])
        df.sort_index(inplace=True)
        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
        path_config.check_path_and_create(path)

        # store as .tex and .md
        tables.save_to_tex(path, "training_settings.tex", column_format="ll", df=df)
        tables.save_to_md(path, "training_settings.md", df=df)

        # calculate val scores
        val_score = self.model.evaluate(self.val_set, use_multiprocessing=True, verbose=0)
        path = self.data_store.get("model_path")
        # file is opened in append mode, scores of earlier runs are kept
        with open(os.path.join(path, "val_scores.txt"), "a") as f:
            for metric_name, item in zip(self.model.metrics_names, to_list(val_score)):
                logging.info(f"{metric_name} (val), {item}")
                f.write(f"{metric_name}, {item}\n")