Coverage for mlair/run_modules/training.py: 93%

109 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

"""
Training module.

Provides the run module `Training`, which fits the model stored in the data store on the train and val data
collections (or resumes a previously interrupted training), saves the model, and writes the training history,
monitoring plots and a training report.
"""

__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-05'

5 

6import json 

7import logging 

8import os 

9from typing import Union 

10 

11import tensorflow.keras as keras 

12from tensorflow.keras.callbacks import Callback, History 

13import psutil 

14import pandas as pd 

15 

16from mlair.data_handler import KerasIterator 

17from mlair.model_modules import AbstractModelClass 

18from mlair.model_modules.keras_extensions import CallbackHandler 

19from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate 

20from mlair.run_modules.run_environment import RunEnvironment 

21from mlair.configuration import path_config 

22from mlair.helpers import to_list, tables, TimeTrackingWrapper 

23 

24 

class Training(RunEnvironment):
    """
    Train your model with this module.

    This module isn't required to run, if only a fresh post-processing is performed. Either remove the training call
    from your run script or set create_new_model and train_model both to false.

    Schedule of training:
        #. make_predict_function(): create predict function before distribution on multiple nodes (detailed
           information in method description)
        #. set_generators(): set generators for training and validation and distribute according to batch size
        #. train(): start or resume training of model and save callbacks
        #. save_model(): save best model from training as final model
        #. report_training(): store the training settings and the validation scores

    All steps after make_predict_function() are only executed if `train_model` is true.

    Required objects [scope] from data store:
        * `model` [model]
        * `batch_size` [.]
        * `epochs` [.]
        * `callbacks` [model]
        * `model_name` [model]
        * `experiment_name` [.]
        * `experiment_path` [.]
        * `model_path` [.]
        * `train_model` [.]
        * `create_new_model` [.]
        * `data_collection` [train, val]
        * `plot_path` [.]

    Objects forwarded to KerasIterator (read via `create_args_dict`)
        * `upsampling` [train, val]
        * `shuffle_batches` [train, val]
        * `batch_path` [train, val]
        * `use_multiprocessing` [train, val]
        * `max_number_multiprocessing` [train, val]

    Sets
        * `model` [.]

    Creates
        * `<exp_name>_model-best.h5`
        * `<exp_name>_model-best-callbacks-<name>.h5` (all callbacks from CallbackHandler)
        * `history.json`
        * `history_lr.json` (optional)
        * `epo_timing.json` (optional)
        * `val_scores.txt`
        * `latex_report/training_settings.tex` and `latex_report/training_settings.md`
        * `<exp_name>_history_<name>.pdf` (different monitoring plots depending on loss metrics and callbacks)

    """

    def __init__(self):
        """Set up and run training."""
        super().__init__()
        self.model: AbstractModelClass = self.data_store.get("model", "model")
        # only train and val iterators are required for training (see set_generators)
        self.train_set: Union[KerasIterator, None] = None
        self.val_set: Union[KerasIterator, None] = None
        self.batch_size = self.data_store.get("batch_size")
        self.epochs = self.data_store.get("epochs")
        self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model")
        self.experiment_name = self.data_store.get("experiment_name")
        self._train_model = self.data_store.get("train_model")
        self._create_new_model = self.data_store.get("create_new_model")
        self._run()

    def _run(self) -> None:
        """Run training. Details in class description."""
        self.make_predict_function()
        if self._train_model:
            self.set_generators()
            self.train()
            self.save_model()
            self.report_training()
        else:
            logging.info("No training has started, because train_model parameter was false.")

    def make_predict_function(self) -> None:
        """
        Create predict function.

        Must be called before distributing. This is necessary, because tf will compile the predict function just in
        the moment it is used the first time. This can cause problems, if the model is distributed on different
        workers. To prevent this, the function is pre-compiled. See discussion @
        https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
        """
        self.model.make_predict_function()

    @TimeTrackingWrapper
    def _set_gen(self, mode: str) -> None:
        """
        Set and distribute the generator for given mode regarding batch size.

        The resulting KerasIterator is stored as attribute `<mode>_set`.

        :param mode: name of set, should be from ["train", "val"]
        """
        collection = self.data_store.get("data_collection", mode)
        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path", "use_multiprocessing",
                                                   "max_number_multiprocessing"], scope=mode)
        setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, name=mode, **kwargs))

    @TimeTrackingWrapper
    def set_generators(self) -> None:
        """
        Set all generators for training and validation subsets.

        The called sub-method will automatically distribute the data according to the batch size. The subsets can be
        accessed as class variables train_set and val_set. No test set is created, because it is not needed for
        training.
        """
        logging.info("set generators for training and validation")
        for mode in ["train", "val"]:
            self._set_gen(mode)

    def train(self) -> None:
        """
        Perform training using keras fit().

        Callbacks are stored locally in the experiment directory. Best model from training is saved for class
        variable model. If a file exists at the checkpoint's file path (and `create_new_model` is false), this method
        assumes, that this is not a new training starting from the very beginning, but a resumption from a previous
        started but interrupted training (or a stopped and now continued training). Train will automatically load the
        locally stored information and the corresponding model and proceed with the already started training.

        Afterwards, the callbacks are saved as json (see save_callbacks_as_json) and monitoring plots are created (see
        create_monitoring_plots).
        """
        logging.info(f"Train with {len(self.train_set)} mini batches.")
        logging.info(f"Train with option upsampling={self.train_set.upsampling}.")
        logging.info(f"Train with option shuffle={self.train_set.shuffle}.")

        checkpoint = self.callbacks.get_checkpoint()
        if not os.path.exists(checkpoint.filepath) or self._create_new_model:
            history = self._fit_model()
        else:
            logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
            self.callbacks.load_callbacks()
            self.callbacks.update_checkpoint()
            self.model.load_model(checkpoint.filepath, compile=True)
            hist: History = self.callbacks.get_callback_by_name("hist")
            # continue with the epoch following the last one recorded in the restored history callback
            initial_epoch = max(hist.epoch) + 1
            self._fit_model(initial_epoch=initial_epoch)
            history = hist
        epoch_best = checkpoint.epoch_best
        if epoch_best is not None:
            logging.info(f"best epoch: {epoch_best + 1}")
        lr = self._get_optional_callback("lr")
        epo_timing = self._get_optional_callback("epo_timing")
        self.save_callbacks_as_json(history, lr, epo_timing)
        self.create_monitoring_plots(history, lr, epoch_best)

    def _fit_model(self, initial_epoch: int = 0):
        """
        Run keras fit() on train_set with validation on val_set, starting at `initial_epoch`.

        :param initial_epoch: epoch at which to start training (0 for a fresh training)
        :return: return value of the model's fit() call
        """
        return self.model.fit(self.train_set,
                              steps_per_epoch=len(self.train_set),
                              epochs=self.epochs,
                              verbose=2,
                              validation_data=self.val_set,
                              validation_steps=len(self.val_set),
                              callbacks=self.callbacks.get_callbacks(as_dict=False),
                              initial_epoch=initial_epoch,
                              workers=self._get_number_of_workers())

    @staticmethod
    def _get_number_of_workers() -> int:
        """
        Return the number of workers to use in keras fit().

        Uses the number of physical cores. psutil returns None if this number cannot be determined; in that case a
        single worker is used instead of passing None on to fit().
        """
        return psutil.cpu_count(logical=False) or 1

    def _get_optional_callback(self, name: str):
        """
        Return the callback stored under `name` in the CallbackHandler.

        :param name: name of the callback, e.g. "lr" or "epo_timing"
        :return: the callback, or None if the CallbackHandler raises IndexError for this name
        """
        try:
            return self.callbacks.get_callback_by_name(name)
        except IndexError:
            return None

    def save_model(self) -> None:
        """
        Save model in local experiment directory and store it in the data store under `model` [.].

        The model is saved at the path given by `model_name` [model] (named as
        `<experiment_name>_<custom_model_name>.h5`), once with save_format "h5" and once with save_format "tf".
        """
        model_name = self.data_store.get("model_name", "model")
        logging.debug(f"save model to {model_name}")
        self.model.save(model_name, save_format="h5")
        # NOTE(review): both formats are written to the same path `model_name`; confirm that the "tf" export does not
        # clash with the h5 file written in the line above
        self.model.save(model_name, save_format="tf")
        self.data_store.set("model", self.model)

    def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
        """
        Save callbacks (history, learning rate, epoch timing) of training as json files in `model_path`.

        * history.history -> history.json
        * lr_sc.lr -> history_lr.json (only if lr_sc is given)
        * epo_timing.epo_timing -> epo_timing.json (only if epo_timing is given)

        :param history: history object of training
        :param lr_sc: learning rate object (may be None)
        :param epo_timing: epoch timing object (may be None)
        """
        logging.debug("saving callbacks")
        path = self.data_store.get("model_path")
        with open(os.path.join(path, "history.json"), "w") as f:
            json.dump(history.history, f)
        if lr_sc:
            with open(os.path.join(path, "history_lr.json"), "w") as f:
                json.dump(lr_sc.lr, f)
        if epo_timing is not None:
            with open(os.path.join(path, "epo_timing.json"), "w") as f:
                json.dump(epo_timing.epo_timing, f)

    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None:
        """
        Create plots of history and learning rate in dependence of the number of epochs.

        The plots are saved in the experiment's plot_path:
            * `<exp_name>_history_loss.pdf`: loss (always)
            * `<exp_name>_history_main_loss.pdf`: loss of the main branch (only if the model has multiple outputs)
            * `<exp_name>_history_main_mse.pdf`: mean squared error (only if tracked as "mean_squared_error" or "mse")
            * `<exp_name>_history_learning_rate.pdf`: learning rate (only if lr_sc is given)

        :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`)
        :param lr_sc: learning rate decay object with 'lr' attribute
        :param epoch_best: number of best epoch (starts counting as 0)
        """
        path = self.data_store.get("plot_path")
        name = self.data_store.get("experiment_name")

        # plot history of loss and mse (if available)
        filename = os.path.join(path, f"{name}_history_loss.pdf")
        PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best)
        multiple_branches_used = len(history.model.output_names) > 1  # means that there are multiple output branches
        if multiple_branches_used:
            filename = os.path.join(path, f"{name}_history_main_loss.pdf")
            PlotModelHistory(filename=filename, history=history, main_branch=True, epoch_best=epoch_best)
        # select the mse metric in a fixed order; a set intersection would pick by hash-dependent order
        metrics_names = history.model.metrics_names
        mse_indicator = [metric for metric in ["mean_squared_error", "mse"] if metric in metrics_names]
        if len(mse_indicator) > 0:
            filename = os.path.join(path, f"{name}_history_main_mse.pdf")
            PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0],
                             main_branch=multiple_branches_used, epoch_best=epoch_best)

        # plot learning rate
        if lr_sc:
            PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)

    def report_training(self):
        """
        Report training settings and validation scores.

        The training settings are stored as `training_settings.tex` and `training_settings.md` in
        `<experiment_path>/latex_report`. Afterwards, the model is evaluated on the validation set; each score is
        logged and appended to `val_scores.txt` in `model_path`.
        """
        # create training summary
        data = {"mini batches": len(self.train_set),
                "upsampling extremes": self.train_set.upsampling,
                "shuffling": self.train_set.shuffle,
                "created new model": self._create_new_model,
                "epochs": self.epochs,
                "batch size": self.batch_size}
        df = pd.DataFrame.from_dict(data, orient="index", columns=["training setting"])
        df.sort_index(inplace=True)
        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
        path_config.check_path_and_create(path)

        # store as .tex and .md
        tables.save_to_tex(path, "training_settings.tex", column_format="ll", df=df)
        tables.save_to_md(path, "training_settings.md", df=df)

        # calculate val scores
        val_score = self.model.evaluate(self.val_set, use_multiprocessing=True, verbose=0)
        metric_names = self.model.metrics_names
        path = self.data_store.get("model_path")
        # append mode: existing content of val_scores.txt is kept and the new scores are added at the end
        with open(os.path.join(path, "val_scores.txt"), "a") as f:
            for index, item in enumerate(to_list(val_score)):
                logging.info(f"{metric_names[index]} (val), {item}")
                f.write(f"{metric_names[index]}, {item}\n")
268 f.write(f"{self.model.metrics_names[index]}, {item}\n")