Coverage for mlair/run_modules/model_setup.py: 41%

120 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-12-02 15:24 +0000

1"""Model setup module.""" 

2 

3__author__ = "Lukas Leufen, Felix Kleinert" 

4__date__ = '2019-12-02' 

5 

6import logging 

7import os 

8import re 

9from dill.source import getsource 

10 

11import tensorflow.keras as keras 

12import pandas as pd 

13import tensorflow as tf 

14import numpy as np 

15 

16from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler 

17from mlair.run_modules.run_environment import RunEnvironment 

18from mlair.configuration import path_config 

19 

20 

class ModelSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of model setup:
        #. set input and output shapes (taken from the first element of the train data collection)
        #. build imported model
        #. broadcast the model's custom objects to keras
        #. plot model architecture
        #. load model from disk if neither training nor creation of a new model is requested
        #. set callbacks and checkpoint
        #. compile model
        #. report model settings, summary and source code

    Required objects [scope] from data store:
        * `experiment_path` [.]
        * `experiment_name` [.]
        * `model_path` [.]
        * `train_model` [.]
        * `create_new_model` [.]
        * `data_collection` [train]
        * `model_class` [.]

    Optional objects
        * `lr_decay` [model]
        * `model_name` [model] (falls back to `my_model`)
        * `early_stopping_epochs` [.] (falls back to `np.inf`)
        * `restore_best_model_weights` [.] (falls back to `True`)

    Sets
        * `input_shape` [model]
        * `output_shape` [model]
        * `num_of_training_samples` [model] (only if listed in the model class requirements)
        * `model` [model]
        * `hist` [model]
        * `epo_timing` [model]
        * `callbacks` [model]
        * `model_name` [model]
        * all settings from model class like `dropout_rate`, `initial_lr`, and `optimizer` [model]

    Creates
        * plot of model architecture `<model_name>.pdf`
        * `model_settings.tex` and `model_settings.md` (in `<experiment_path>/latex_report` and in `model_path`)
        * `model_summary.txt` and `model_code.txt` (in `model_path`)

    """

    def __init__(self):
        """Initialise and run model setup."""
        super().__init__()
        self.model = None
        exp_name = self.data_store.get("experiment_name")
        self.path = self.data_store.get("model_path")
        self.scope = "model"
        # all file names share the prefix `<model_path>/<experiment_name>_`
        path = os.path.join(self.path, f"{exp_name}_%s")
        # the remaining `%s` is filled with the model's name in get_model_settings()
        self.model_name = path % "%s.h5"
        self.checkpoint_name = path % "model-best.h5"
        # the remaining `%s` is filled per callback in _set_callbacks()
        self.callbacks_name = path % "model-best-callbacks-%s.pickle"
        self._train_model = self.data_store.get("train_model")
        self._create_new_model = self.data_store.get("create_new_model")
        self._run()

    def _run(self):
        """Execute all setup steps in their fixed order (see class docstring for the schedule)."""
        # set input and output shapes depending on the train data
        self._set_shapes()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # broadcast custom objects
        self.broadcast_custom_objects()

        # plot model structure
        self.plot_model()

        # load weights if no training shall be performed
        if not self._train_model and not self._create_new_model:
            self.load_model()

        # create checkpoint
        self._set_callbacks()

        # compile model
        self.compile_model()

        # report settings
        self.report_model()

    def _set_shapes(self):
        """Set input and output shapes from train collection.

        The shapes are read from the first element of the train data collection; the leading axis of each array is
        dropped (`shape[1:]`).
        """
        shape = [x.shape[1:] for x in self.data_store.get("data_collection", "train")[0].get_X()]
        self.data_store.set("input_shape", shape, self.scope)
        shape = [y.shape[1:] for y in self.data_store.get("data_collection", "train")[0].get_Y()]
        self.data_store.set("output_shape", shape, self.scope)

    def _set_num_of_training_samples(self):
        """
        Count the number of training samples - needed for example for Bayesian NNs.

        :return: sum of the lengths of all elements in the train data collection
        """
        upsampling = self.data_store.create_args_dict(["upsampling"], "train")
        # __len__ is called explicitly because the builtin len() cannot forward keyword arguments
        return sum(data.__len__(**upsampling) for data in self.data_store.get("data_collection", "train"))

    def compile_model(self):
        """
        Compiles the keras model. Compile options are mandatory and have to be set by implementing set_compile() method
        in child class of AbstractModelClass.

        The compiled model is stored as `model` in the data store.
        """
        compile_options = self.model.compile_options
        self.model.compile(**compile_options)
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase.

        Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added. The
        resulting callback handler is stored as `callbacks` in the data store.
        """
        # create callback handler
        callbacks = CallbackHandler()

        # add callback: learning rate (optional, only if provided by the data store)
        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
        if lr is not None:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")

        # add callback: advanced history
        hist = HistoryAdvanced()
        self.data_store.set("hist", hist, scope="model")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")

        # add callback: epo timing
        epo_timing = EpoTimingCallback()
        self.data_store.set("epo_timing", epo_timing, scope="model")
        callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")

        # add callback: early stopping
        patience = self.data_store.get_default("early_stopping_epochs", default=np.inf)
        restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True)
        assert isinstance(patience, int) or np.isinf(patience), \
            f"early_stopping_epochs must be an int or infinity, got {patience!r}"
        cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights)
        callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping")

        # create model checkpoint
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto', restore_best_weights=restore_best_weights)

        # store callbacks
        self.data_store.set("callbacks", callbacks, self.scope)

    def load_model(self):
        """Try to load model from disk or skip if not possible."""
        try:
            self.model.load_model(self.model_name)
            # only reached if loading did not raise
            logging.info(f"reload model {self.model_name} from disk ...")
        except OSError:
            logging.info('no local model to load...')

    def build_model(self):
        """Build model using input and output shapes from data store.

        The model class is instantiated with all arguments it lists in `requirements()`. If the number of training
        samples is among them, it is counted and stored in the data store first.
        """
        model = self.data_store.get("model_class")
        args_list = model.requirements()
        if "num_of_training_samples" in args_list:
            num_of_training_samples = self._set_num_of_training_samples()
            self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
            logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
                         f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
                         f"'{self.scope}')")
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = model(**args)
        self.get_model_settings()

    def broadcast_custom_objects(self):
        """
        Broadcast custom objects to keras utils.

        This method is very important, because it adds the model's custom objects to the keras utils. By doing so, all
        custom objects can be treated as standard keras modules. Therefore, problems related to model or callback
        loading are solved.
        """
        keras.utils.get_custom_objects().update(self.model.custom_objects)

    def get_model_settings(self):
        """Load all model settings and store in data store.

        Also resolves the remaining placeholder in `self.model_name` with the `model_name` entry of the data store
        (default `my_model`) and stores the resulting file name as `model_name`.
        """
        model_settings = self.model.get_settings()
        self.data_store.set_from_dict(model_settings, self.scope, log=True)
        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
        self.data_store.set("model_name", self.model_name, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot model architecture as `<model_name>.pdf`.

        Plotting is best effort: any exception is logged and otherwise ignored.
        """
        try:
            with tf.device("/cpu:0"):
                # replace the file extension of the model file by `.pdf`
                file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
                keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
        except Exception as e:
            logging.info(f"Can not plot model due to: {e}")

    def report_model(self):
        """
        Report model settings, model summary and model source code to disk.

        Settings of the model, its compile options and its optimizer config are collected in a single-column table
        and written as `model_settings.tex` and `model_settings.md` to `<experiment_path>/latex_report` and to the
        model path. The keras model summary (`model_summary.txt`) and the source code of the model class
        (`model_code.txt`) are written to the model path.
        """
        # report model settings
        _f = self._clean_name
        model_settings = self.model.get_settings()
        model_settings.update(self.model.compile_options)
        model_settings.update(self.model.optimizer.get_config())
        df = pd.DataFrame(columns=["model setting"])
        for k, v in model_settings.items():
            if v is None:
                continue
            if isinstance(v, list):
                if len(v) > 0:
                    if isinstance(v[0], dict):
                        # list of dicts: render each dict as `{key:value,...}` with cleaned names
                        v = ["{" + ",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) + "}" for d in v]
                    else:
                        v = ",".join(_f(str(u)) for u in v)
                else:
                    v = "[]"
            if "<" in str(v):
                # looks like an object/class repr, reduce it to a readable name
                v = _f(str(v))
            df.loc[k] = str(v)
        df.loc["count params"] = str(self.model.count_params())
        df.sort_index(inplace=True)
        column_format = "ll"
        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
        path_config.check_path_and_create(path)
        for p in [path, self.path]:  # log to `latex_report` and `model`
            df.to_latex(os.path.join(p, "model_settings.tex"), na_rep='---', column_format=column_format)
            # use a context manager so that the markdown file handle is closed (and flushed) deterministically
            with open(os.path.join(p, "model_settings.md"), mode="w", encoding='utf-8') as fh:
                df.to_markdown(fh, tablefmt="github")
        # report model summary to file
        with open(os.path.join(self.path, "model_summary.txt"), "w", encoding="utf-8") as fh:
            self.model.summary(print_fn=lambda x: fh.write(x + "\n"))
        # print model code to file
        with open(os.path.join(self.path, "model_code.txt"), "w", encoding="utf-8") as fh:
            fh.write(getsource(self.data_store.get("model_class")))

    @staticmethod
    def _clean_name(orig_name: str):
        """
        Reduce the string representation of an object, class or function to a short readable name.

        Examples (traced from the implementation)::

            "<class 'tensorflow.python.keras.optimizer_v2.adam.Adam'>"  ->  "Adam"
            "<function mean_squared_error at 0x7f00>"                    ->  "mean_squared_error"
            "<tensorflow.python.keras.losses.Huber object at 0x7f00>"    ->  "Huber"
            "adam"                                                       ->  "adam"

        :param orig_name: string to clean
        :return: cleaned name
        """
        # drop a leading "<" and all single quotes, then tokenise at blanks
        mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
        # "class X>", "function X at ..." and alike: the name is the second token (mod_name becomes a str here)
        mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name
        # still a token list whose first token is a tensorflow / keras path: keep only the last path element
        mod_name = mod_name[0].split(".")[-1] if any(
            map(lambda x: x in mod_name[0], ["tensorflow", "keras"])) else mod_name
        # strip a trailing ">" or otherwise join whatever is left to a single string
        mod_name = mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
        # remove a remaining tensorflow / keras module path
        return mod_name.split(".")[-1] if any(map(lambda x: x in mod_name, ["tensorflow", "keras"])) else mod_name