Coverage for mlair/run_modules/model_setup.py: 10%

136 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-11-30 10:51 +0000

1"""Model setup module.""" 

2 

3__author__ = "Lukas Leufen, Felix Kleinert" 

4__date__ = '2019-12-02' 

5 

6import logging 

7import os 

8import re 

9import shutil 

10 

11from dill.source import getsource 

12 

13import tensorflow.keras as keras 

14import pandas as pd 

15import tensorflow as tf 

16import numpy as np 

17 

18from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler 

19from mlair.run_modules.run_environment import RunEnvironment 

20from mlair.configuration import path_config 

21 

22 

class ModelSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of model setup:
        #. set input and output shapes (from first element of train data collection)
        #. build imported model and store its settings
        #. broadcast the model's custom objects to keras
        #. plot model architecture
        #. load weights if enabled (neither `train_model` nor `create_new_model` is set)
        #. set callbacks and checkpoint
        #. compile model
        #. report model settings, summary and source code

    Required objects [scope] from data store:
        * `experiment_path` [.]
        * `experiment_name` [.]
        * `train_model` [.]
        * `create_new_model` [.]
        * `model_path` [.]
        * `data_collection` [train]
        * `model_class` [.]

    Optional objects
        * `lr_decay` [model]
        * `model_display_name` [.]
        * `model_load_path` [.]
        * `early_stopping_epochs` [.]
        * `restore_best_model_weights` [.]

    Sets
        * `input_shape` [model]
        * `output_shape` [model]
        * `num_of_training_samples` [model] (only if requested by the model class)
        * `model` [model]
        * `hist` [model]
        * `epo_timing` [model]
        * `callbacks` [model]
        * `model_name` [model]
        * all settings returned by the model's `get_settings()` [model]

    Creates
        * plot of model architecture `<model_name>.pdf`
        * `model_settings.tex` and `model_settings.md` (in `latex_report` and in the model path)
        * `model_summary.txt` and `model_code.txt` (in the model path)

    """

    def __init__(self):
        """Initialise and run model setup."""
        super().__init__()
        self.model = None
        self.scope = "model"
        self._train_model = self.data_store.get("train_model")
        self._create_new_model = self.data_store.get("create_new_model")
        self.path = self.data_store.get("model_path")
        self.model_display_name = self.data_store.get_default("model_display_name", default=None)
        self.model_load_path = None
        # template still contains one `%s` placeholder that is filled per artefact below
        path = self._set_model_path()
        # `model_path` keeps a `%s` placeholder until `get_model_settings` inserts the model name
        self.model_path = path % "%s.h5"
        self.checkpoint_name = path % "model-best.h5"
        # `callbacks_name` keeps a `%s` placeholder that is filled with the callback name
        self.callbacks_name = path % "model-best-callbacks-%s.pickle"
        self._run()

    def _run(self):
        """Execute all setup steps in their fixed order."""
        # set shapes depending on inputs
        self._set_shapes()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # broadcast custom objects
        self.broadcast_custom_objects()

        # plot model structure
        self.plot_model()

        # load weights if no training shall be performed
        if not self._train_model and not self._create_new_model:
            self.load_model()

        # create checkpoint
        self._set_callbacks()

        # compile model
        self.compile_model()

        # report settings
        self.report_model()

    def _set_model_path(self):
        """
        Validate an optional external model path and return the path template for all model files.

        :return: path inside `self.path` of form `<experiment_name>_%s`
        :raises FileNotFoundError: if an external model path is given that does not end with `.h5`
        :raises ValueError: if an external model path is combined with `train_model` or `create_new_model`
        """
        exp_name = self.data_store.get("experiment_name")
        self.model_load_path = self.data_store.get_default("model_load_path", default=None)
        if self.model_load_path is not None:
            if not self.model_load_path.endswith(".h5"):
                raise FileNotFoundError(f"When providing external models, you need to provide full path including the "
                                        f".h5 file. Given path is not valid: {self.model_load_path}")
            if self._train_model or self._create_new_model:
                raise ValueError(f"Providing `model_path` along with parameters train_model={self._train_model} and "
                                 f"create_new_model={self._create_new_model} is not possible. Either set both "
                                 f"parameters to False or remove `model_path` parameter. Given was: model_path = "
                                 f"{self.model_load_path}")
        return os.path.join(self.path, f"{exp_name}_%s")

    def _set_shapes(self):
        """Set input and output shapes (without leading dimension) from first element of train collection."""
        # fetch the collection only once, both shapes are derived from the same element
        first_element = self.data_store.get("data_collection", "train")[0]
        input_shape = [x.shape[1:] for x in first_element.get_X()]
        self.data_store.set("input_shape", input_shape, self.scope)
        output_shape = [y.shape[1:] for y in first_element.get_Y()]
        self.data_store.set("output_shape", output_shape, self.scope)

    def _set_num_of_training_samples(self):
        """
        Count number of training samples - needed for example for Bayesian NNs.

        :return: sum of the lengths of all elements in the train data collection
        """
        upsampling = self.data_store.create_args_dict(["upsampling"], "train")
        # `__len__` is called explicitly because builtin `len` cannot forward keyword arguments
        return sum(data.__len__(**upsampling) for data in self.data_store.get("data_collection", "train"))

    def compile_model(self):
        """
        Compiles the keras model. Compile options are mandatory and have to be set by implementing set_compile() method
        in child class of AbstractModelClass.

        The compiled model is stored in the data store as `model`.
        """
        compile_options = self.model.compile_options
        self.model.compile(**compile_options)
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase.

        Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.

        :raises AssertionError: if `early_stopping_epochs` is neither an int nor infinite
        """
        # create callback handler
        callbacks = CallbackHandler()

        # add callback: learning rate
        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
        if lr is not None:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")

        # add callback: advanced history
        hist = HistoryAdvanced()
        self.data_store.set("hist", hist, scope="model")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")

        # add callback: epo timing
        epo_timing = EpoTimingCallback()
        self.data_store.set("epo_timing", epo_timing, scope="model")
        callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")

        # add callback: early stopping
        patience = self.data_store.get_default("early_stopping_epochs", default=np.inf)
        restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True)
        assert isinstance(patience, int) or np.isinf(patience), \
            f"early_stopping_epochs must be an int or inf, but given was {patience!r}"
        cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights)
        callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping")

        # create model checkpoint
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto', restore_best_weights=restore_best_weights)

        # store callbacks
        self.data_store.set("callbacks", callbacks, self.scope)

    def copy_model(self):
        """Copy external model to internal experiment structure (only if an external model path is set)."""
        if self.model_load_path is None:
            return
        logging.info(f"Copy external model file: {self.model_load_path} -> {self.model_path}")
        shutil.copyfile(self.model_load_path, self.model_path)

    def load_model(self):
        """Try to load model from disk or skip if not possible (an OSError is logged and ignored)."""
        self.copy_model()
        try:
            self.model.load_model(self.model_path)
        except OSError:
            logging.info('no local model to load...')
        else:
            logging.info(f"reload model {self.model_path} from disk ...")

    def build_model(self):
        """
        Build model using the arguments requested by the model class.

        If the model class requires `num_of_training_samples`, this number is calculated and stored in the data store
        first. Afterwards, the model settings are stored in the data store.
        """
        model = self.data_store.get("model_class")
        args_list = model.requirements()
        if "num_of_training_samples" in args_list:
            num_of_training_samples = self._set_num_of_training_samples()
            self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
            logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
                         f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
                         f"'{self.scope}')")
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = model(**args)
        self.get_model_settings()

    def broadcast_custom_objects(self):
        """
        Broadcast custom objects to keras utils.

        This method is very important, because it adds the model's custom objects to the keras utils. By doing so, all
        custom objects can be treated as standard keras modules. Therefore, problems related to model or callback
        loading are solved.
        """
        keras.utils.get_custom_objects().update(self.model.custom_objects)

    def get_model_settings(self):
        """
        Load all model settings and store in data store.

        Additionally, the placeholder in `self.model_path` is replaced by the model name (`my_model` if not set) and
        the resulting path is stored in the data store as `model_name`.
        """
        model_settings = self.model.get_settings()
        self.data_store.set_from_dict(model_settings, self.scope, log=True)
        generic_model_name = self.data_store.get_default("model_name", self.scope, "my_model")
        self.model_path = self.model_path % generic_model_name
        self.data_store.set("model_name", self.model_path, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot model architecture as `<model_name>.pdf`. Any error during plotting is logged and ignored."""
        try:
            with tf.device("/cpu:0"):
                file_name = f"{self.model_path.rsplit('.', 1)[0]}.pdf"
                keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
        except Exception as e:
            # plotting is best effort only and must never stop the model setup
            logging.info(f"Can not plot model due to: {e}")

    def _format_setting(self, value):
        """
        Convert a single model setting into its string representation used in the settings report.

        Lists are joined element-wise (lists of dicts as `{key:value,...}` entries, empty lists as `[]`). Every
        representation that contains a `<` is shortened by `_clean_name`.

        :param value: setting to convert (must not be None)
        :return: string representation of given setting
        """
        _f = self._clean_name
        if isinstance(value, list):
            if len(value) == 0:
                value = "[]"
            elif isinstance(value[0], dict):
                value = ["{" + ",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) + "}" for d in value]
            else:
                value = ",".join(_f(str(u)) for u in value)
        if "<" in str(value):
            value = _f(str(value))
        return str(value)

    def report_model(self):
        """
        Report model settings, model summary, and model source code to files.

        Settings of the model, its compile options and its optimizer config are collected in a single table (entries
        that are None are skipped) and written as `model_settings.tex` and `model_settings.md` to `latex_report` and to
        the model path. In addition, `model_summary.txt` and `model_code.txt` are written to the model path.
        """
        # report model settings
        model_settings = self.model.get_settings()
        model_settings.update(self.model.compile_options)
        model_settings.update(self.model.optimizer.get_config())
        df = pd.DataFrame(columns=["model setting"])
        for k, v in model_settings.items():
            if v is None:
                continue
            df.loc[k] = self._format_setting(v)
        df.loc["count params"] = str(self.model.count_params())
        df.sort_index(inplace=True)
        column_format = "ll"
        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
        path_config.check_path_and_create(path)
        for p in [path, self.path]:  # log to `latex_report` and `model`
            df.to_latex(os.path.join(p, "model_settings.tex"), na_rep='---', column_format=column_format)
            # use context manager to ensure that the file handle is closed again
            with open(os.path.join(p, "model_settings.md"), mode="w", encoding='utf-8') as fh:
                df.to_markdown(fh, tablefmt="github")
        # report model summary to file
        with open(os.path.join(self.path, "model_summary.txt"), "w", encoding='utf-8') as fh:
            self.model.summary(print_fn=lambda x: fh.write(x + "\n"))
        # print model code to file
        with open(os.path.join(self.path, "model_code.txt"), "w", encoding='utf-8') as fh:
            fh.write(getsource(self.data_store.get("model_class")))

    @staticmethod
    def _clean_name(orig_name: str):
        """
        Shorten the string representation of an object (e.g. `<class 'a.b.C'>`) for the settings report.

        Steps: remove a single leading `<` and all single quotes, split at blanks, skip a leading token containing
        `class`, `function`, or `method`, strip a trailing `>`, and reduce names containing `tensorflow` or `keras` to
        the part after the last dot.

        :param orig_name: string representation to shorten
        :return: shortened name
        """
        type_markers = ("class", "function", "method")
        frameworks = ("tensorflow", "keras")
        # NOTE: `mod_name` is a list of tokens at first and becomes a string in one of the following steps
        mod_name = re.sub(r'^<', '', orig_name).replace("'", "").split(" ")
        if any(marker in mod_name[0] for marker in type_markers):
            mod_name = mod_name[1]
        if any(framework in mod_name[0] for framework in frameworks):
            mod_name = mod_name[0].split(".")[-1]
        mod_name = mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
        if any(framework in mod_name for framework in frameworks):
            return mod_name.split(".")[-1]
        return mod_name