1"""Collection of different extensions to keras framework.""" 

2 

3__author__ = 'Lukas Leufen, Felix Kleinert' 

4__date__ = '2020-01-31' 

5 

6import copy 

7import logging 

8import math 

9import pickle 

10from typing import Union, List 

11from typing_extensions import TypedDict 

12from time import time 

13 

14import numpy as np 

15from tensorflow.keras import backend as K 

16from tensorflow.keras.callbacks import History, ModelCheckpoint, Callback 

17 

18from mlair import helpers 

19 


class HistoryAdvanced(History):
    """
    This is an almost identical clone of the original History class.

    The only difference is that the attributes epoch and history are instantiated during the init phase and not during
    on_train_begin. This is required to resume an already started but interrupted training from a saved state. This
    HistoryAdvanced callback needs to be added separately as an additional callback. To get the full history, use this
    object in further steps instead of the default return value of training methods like fit_generator().

    .. code-block:: python

        hist = HistoryAdvanced()
        history = model.fit_generator(generator=..., callbacks=[hist])
        history = hist

    If training is started from the beginning, this class is identical to the returned history object.
    """

    def __init__(self):
        """Set up HistoryAdvanced."""
        self.epoch = []
        self.history = {}
        super().__init__()

    def on_train_begin(self, logs=None):
        """Overload on_train_begin method to do nothing instead of resetting epoch and history."""
        pass


class LearningRateDecay(History):
    """
    Decay the learning rate during model training.

    Start with a base learning rate and lower this rate after every n (= epochs_drop) epochs by the factor drop from
    the interval (0, 1]; drop = 1 means no decay of the learning rate.

    :param base_lr: base learning rate to start with
    :param drop: ratio to drop after epochs_drop
    :param epochs_drop: number of epochs after which the drop takes place
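
    A minimal usage sketch (model and training data are placeholders, not part of this module):

    .. code-block:: python

        lr_decay = LearningRateDecay(base_lr=0.01, drop=0.96, epochs_drop=8)
        history = model.fit(x_train, y_train, epochs=24, callbacks=[lr_decay])
        # epochs 0-7 run with lr = 0.01, epochs 8-15 with 0.01 * 0.96, epochs 16-23 with 0.01 * 0.96 ** 2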

60 """ 

61 

62 def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8): 

63 """Set up LearningRateDecay.""" 

64 super().__init__() 

65 self.lr = {'lr': []} 

66 self.base_lr = self.check_param(base_lr, 'base_lr') 

67 self.drop = self.check_param(drop, 'drop') 

68 self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None) 

69 self.epoch = [] 

70 self.history = {} 

71 

72 @staticmethod 

73 def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1): 

74 """ 

75 Check if given value is in interval. 

76 

77 The left (lower) endpoint is open, right (upper) endpoint is closed. To use only one side of the interval, set 

78 the other endpoint to None. If both ends are set to None, just return the value without any check. 

79 

80 :param value: value to check 

81 :param name: name of the variable to display in error message 

82 :param lower: left (lower) endpoint of interval, opened 

83 :param upper: right (upper) endpoint of interval, closed 

84 

85 :return: unchanged value or raise ValueError 
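
        For example (a sketch of the expected behaviour, following the interval logic below):

        .. code-block:: python

            LearningRateDecay.check_param(0.5, "drop")  # returns 0.5, since 0 < 0.5 <= 1
            LearningRateDecay.check_param(0, "drop")    # raises ValueError, the lower endpoint is open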

86 """ 

87 if lower is None: 

88 lower = -np.inf 

89 if upper is None: 

90 upper = np.inf 

91 if lower < value <= upper: 

92 return value 

93 else: 

94 raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: " 

95 f"{name}={value}") 

96 

97 def on_train_begin(self, logs=None): 

98 """Overload on_train_begin method to do nothing instead of resetting epoch and history.""" 

99 pass 

100 

101 def on_epoch_begin(self, epoch: int, logs=None): 

102 """ 

103 Lower learning rate every epochs_drop epochs by factor drop. 

104 

105 :param epoch: current epoch 

106 :param logs: ? 

107 :return: update keras learning rate 
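
        A worked example of the decay formula below, using the defaults base_lr=0.01, drop=0.96, epochs_drop=8:
        at epoch 17, lr = 0.01 * 0.96 ** floor(17 / 8) = 0.01 * 0.96 ** 2 ≈ 0.00922.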

108 """ 

109 current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop)) 

110 K.set_value(self.model.optimizer.lr, current_lr) 

111 self.lr['lr'].append(current_lr) 

112 logging.info(f"Set learning rate to {current_lr}") 

113 return K.get_value(self.model.optimizer.lr) 

114 

115 

116class EpoTimingCallback(Callback): 

117 def __init__(self): 

118 self.epo_timing = {'epo_timing': []} 

119 self.logs = [] 

120 self.starttime = None 

121 super().__init__() 

122 

123 def on_epoch_begin(self, epoch: int, logs=None): 

124 self.starttime = time() 

125 

126 def on_epoch_end(self, epoch: int, logs=None): 

127 self.epo_timing["epo_timing"].append(time()-self.starttime) 

128 

129 

130class ModelCheckpointAdvanced(ModelCheckpoint): 

131 """ 

132 Enhance the standard ModelCheckpoint class by additional saves of given callbacks. 

133 

134 **We recommend to use CallbackHandler instead of ModelCheckpointAdvanced.** CallbackHandler will handler all your 

135 callbacks and the ModelCheckpointAdvanced and prevent you from pitfalls like wrong ordering of callbacks. Actually, 

136 CallbackHandler makes use of ModelCheckpointAdvanced. 

137 

138 However, if you want to use the ModelCheckpointAdvanced explicitly, follow these instructions: 

139 

140 .. code-block:: python 

141 

142 # load your callbacks 

143 lr = CustomLearningRate() 

144 hist = CustomHistory() 

145 

146 # set your callbacks with a list dictionary structure 

147 callbacks_name = "your_custom_path_%s.pickle" 

148 callbacks = [{"callback": lr, "path": callbacks_name % "lr"}, 

149 {"callback": hist, "path": callbacks_name % "hist"}] 

150 # initialise ModelCheckpointAdvanced like the normal ModelCheckpoint (see keras callbacks) 

151 ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks) 

152 

153 Add ModelCheckpointAdvanced as all other additional callbacks to the callback list. IMPORTANT: Always add 

154 ModelCheckpointAdvanced as last callback to properly update all tracked callbacks, e.g. 

155 

156 .. code-block:: python 

157 

158 # always add ModelCheckpointAdvanced as last element 

159 fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks]) 

160 

161 """ 

162 

163 def __init__(self, *args, **kwargs): 

164 """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" 

165 self.callbacks = kwargs.pop("callbacks") 

166 self.epoch_best = None 

167 self.restore_best_weights = kwargs.pop("restore_best_weights", True) 

168 super().__init__(*args, **kwargs) 

169 

170 def update_best(self, hist): 

171 """ 

172 Update internal best on resuming a training process. 

173 

174 If no best object is available, best is set to +/- inf depending on the performance metric and the first trained 

175 model (first of the resuming training process) will always saved as best model because its performance will be 

176 better than infinity. To prevent this behaviour and compare the performance with the best model performance, 

177 call this method before resuming the training process. 

178 

179 :param hist: The History object from the previous (interrupted) training. 
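
        A minimal resume sketch (assuming hist holds the reloaded history of the interrupted run, e.g. restored via
        CallbackHandler.load_callbacks):

        .. code-block:: python

            ckpt_callbacks.update_best(hist)  # best/epoch_best now reflect the interrupted run
            model.fit(..., callbacks=[lr, hist, ckpt_callbacks], initial_epoch=max(hist.epoch) + 1)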

180 """ 

181 if self.restore_best_weights: 181 ↛ 191line 181 didn't jump to line 191, because the condition on line 181 was never false

182 f = np.min if self.monitor_op.__name__ == "less" else np.max 

183 f_loc = lambda x: np.where(x == f(x))[0][-1] 

184 _d = hist.history.get(self.monitor) 

185 loc = f_loc(_d) 

186 assert f(_d) == _d[loc] 

187 self.epoch_best = loc 

188 self.best = _d[loc] 

189 logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}") 

190 else: 

191 _d = hist.history.get(self.monitor)[-1] 

192 self.best = _d 

193 logging.info(f"Set only best result ({self.monitor}={self.best}) without best epoch") 

194 

195 def update_callbacks(self, callbacks): 

196 """ 

197 Update all stored callback objects. 

198 

199 The argument callbacks needs to follow the same convention like described in the class description (list of 

200 dictionaries). Must be run before resuming a training process. 
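
        A sketch of the expected argument structure (lr and hist are placeholder callback objects):

        .. code-block:: python

            ckpt_callbacks.update_callbacks([{"callback": lr, "path": "your_custom_path_lr.pickle"},
                                             {"callback": hist, "path": "your_custom_path_hist.pickle"}])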

201 """ 

202 self.callbacks = helpers.to_list(callbacks) 

203 

204 def on_epoch_end(self, epoch, logs=None): 

205 """Save model as usual (see ModelCheckpoint class), but also save additional callbacks.""" 

206 super().on_epoch_end(epoch, logs) 

207 

208 for callback in self.callbacks: 

209 file_path = callback["path"] 

210 if self.epochs_since_last_save == 0 and epoch != 0: 

211 if self.save_best_only: 

212 current = logs.get(self.monitor) 

213 if current == self.best: 

214 if self.restore_best_weights: 214 ↛ 216line 214 didn't jump to line 216, because the condition on line 214 was never false

215 self.epoch_best = epoch 

216 if self.verbose > 0: # pragma: no branch 

217 print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) 

218 with open(file_path, "wb") as f: 

219 c = copy.copy(callback["callback"]) 

220 if hasattr(c, "model"): 220 ↛ 222line 220 didn't jump to line 222, because the condition on line 220 was never false

221 c.model = None 

222 pickle.dump(c, f) 

223 else: 

224 with open(file_path, "wb") as f: 

225 if self.verbose > 0: # pragma: no branch 

226 print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) 

227 c = copy.copy(callback["callback"]) 

228 if hasattr(c, "model"): 228 ↛ 230line 228 didn't jump to line 230, because the condition on line 228 was never false

229 c.model = None 

230 pickle.dump(c, f) 


# Besides the fixed keys "name" and "path", each clbk_type entry also stores the callback object itself under a
# dynamic key (the callback's name), which TypedDict cannot express precisely.
clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})


class CallbackHandler:
    r"""Use the CallbackHandler for better control of custom callbacks.

    The callback handler will always keep your callbacks in the right order and adds a model checkpoint at the last
    position if required. You can add an arbitrary number of callbacks to the handler. First add all callbacks and
    finally create the model checkpoint. Callbacks that are added after the checkpoint has been created would not be
    part of it. Therefore, the handler blocks adding new callbacks once the model checkpoint has been created.

    .. code-block:: python

        # init callback handler
        callbacks = CallbackHandler()

        # set history object (add further elements like in this example)
        hist = keras.callbacks.History()
        callbacks.add_callback(hist, "callbacks-hist.pickle", "hist")

        # create advanced checkpoint (for details see ModelCheckpointAdvanced)
        ckpt_name = "model-best.h5"
        callbacks.create_model_checkpoint(filepath=ckpt_name, verbose=1, ...)

        # get checkpoint
        ckpt = callbacks.get_checkpoint()

        # fit the already compiled model and add the callbacks; it is important to call get_callbacks with
        # as_dict=False
        history = model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False))

    If you want to continue a training, you can use the callback handler to load the already stored callbacks. First
    you need to reload all callbacks. Make sure that all callbacks from the previous training are available. If the
    callback handler was set up like in the former code example, this will work.

    .. code-block:: python

        # load callbacks and update checkpoint
        callbacks.load_callbacks()
        callbacks.update_checkpoint()

        # optional: load your model using the checkpoint path
        model = keras.models.load_model(ckpt.filepath)

        # extract history object and set starting epoch
        hist = callbacks.get_callback_by_name("hist")
        initial_epoch = max(hist.epoch) + 1

        # resume training (including initial_epoch) and use the callback handler's history object
        _ = self.model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch)
        history = hist

    Important notes: Do not use the history object returned by model.fit, but the history object from the callback
    handler. The fit history will only contain the new part, whereas the callback handler's history contains the full
    history, including the resumed and the new part. For correct epoch counting, you also need to pass initial_epoch
    to the fit method.

    """

    def __init__(self):
        """Initialise CallbackHandler."""
        self.__callbacks: List[clbk_type] = []
        self._checkpoint = None
        self.editable = True

    @property
    def _callbacks(self):
        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]

    @_callbacks.setter
    def _callbacks(self, value):
        name, callback, callback_path = value
        self.__callbacks.append({"name": name, name: callback, "path": callback_path})

    def _update_callback(self, pos: int, value: Callback) -> None:
        """Update the callback entry at the given position with the given value."""
        name = self.__callbacks[pos]["name"]
        self.__callbacks[pos][name] = value

    def add_callback(self, callback: Callback, callback_path: str, name: str = "callback") -> None:
        """
        Add the given callback at the last position if the CallbackHandler is editable.

        Save the callback under the given name. Raises a PermissionError if editable is False.

        :param callback: callback object to store
        :param callback_path: path where the callback will be stored
        :param name: name of the callback
        """
        if self.editable:
            self._callbacks = (name, callback, callback_path)
        else:
            raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")

    def get_callbacks(self, as_dict=True) -> Union[List[clbk_type], List[Callback]]:
        """
        Get all callbacks including the checkpoint at the last position.

        :param as_dict: set the return format, either as clbk_type dictionary structure (as_dict=True, default) or as
            a plain list of callback objects

        :return: all callbacks, either as a list of callback dictionaries or as a list of raw callback objects
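
        A sketch of the two return formats (hist and ckpt stand for the stored history callback and the created
        checkpoint):

        .. code-block:: python

            callbacks.get_callbacks()               # [{"callback": hist, "path": "callbacks-hist.pickle"}, ...]
            callbacks.get_callbacks(as_dict=False)  # [hist, ..., ckpt]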

333 """ 

334 if as_dict: 

335 return self._get_callbacks() 

336 else: 

337 return [clb["callback"] for clb in self._get_callbacks()] 

338 

339 def get_callback_by_name(self, obj_name: str) -> Union[Callback, History]: 

340 """ 

341 Get single callback by its name. 

342 

343 :param obj_name: name of callback to look for 

344 

345 :return: requested callback object 

346 """ 

347 if obj_name != "callback": 

348 return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0] 

349 

350 def _get_callbacks(self) -> List[clbk_type]: 

351 """Return all callbacks and append checkpoint if available on last position.""" 

352 clbks = self._callbacks 

353 if self._checkpoint is not None: 

354 clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}] 

355 return clbks 

356 

357 def get_checkpoint(self) -> ModelCheckpointAdvanced: 

358 """Return current checkpoint if available.""" 

359 if self._checkpoint is not None: 

360 return self._checkpoint 

361 

362 def create_model_checkpoint(self, **kwargs): 

363 """Create a model checkpoint and enable edit.""" 

364 self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs) 

365 self.editable = False 

366 

367 def load_callbacks(self) -> None: 

368 """Load callbacks from path and save in callback attribute.""" 

369 for pos, callback in enumerate(self.__callbacks): 

370 path = callback["path"] 

371 clb = pickle.load(open(path, "rb")) 

372 if clb.model is None and hasattr(self._checkpoint, "model"): 

373 clb.model = self._checkpoint.model 

374 self._update_callback(pos, clb) 

375 

376 def update_checkpoint(self, history_name: str = "hist") -> None: 

377 """ 

378 Update callbacks and history's best elements. 

379 

380 :param history_name: name of history object 

381 """ 

382 self._checkpoint.update_callbacks(self._callbacks) 

383 self._checkpoint.update_best(self.get_callback_by_name(history_name))