Coverage for mlair/run_modules/model_setup.py: 37%
136 statements
« prev ^ index » next — coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
"""Model setup module."""

__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-02'

import logging
import os
import re
import shutil

from dill.source import getsource

import tensorflow.keras as keras
import pandas as pd
import tensorflow as tf
import numpy as np

from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
from mlair.run_modules.run_environment import RunEnvironment
from mlair.configuration import path_config
class ModelSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of model setup:
        #. set channels (from variables dimension)
        #. build imported model
        #. plot model architecture
        #. load weights if enabled (e.g. to resume a training)
        #. set callbacks and checkpoint
        #. compile model

    Required objects [scope] from data store:
        * `experiment_path` [.]
        * `experiment_name` [.]
        * `train_model` [.]
        * `create_new_model` [.]
        * `generator` [train]
        * `model_class` [.]

    Optional objects
        * `lr_decay` [model]

    Sets
        * `channels` [model]
        * `model` [model]
        * `hist` [model]
        * `callbacks` [model]
        * `model_name` [model]
        * all settings from model class like `dropout_rate`, `initial_lr`, and `optimizer` [model]

    Creates
        * plot of model architecture `<model_name>.pdf`

    """

    def __init__(self):
        """Initialise and run model setup."""
        super().__init__()
        self.model = None
        self.scope = "model"
        self._train_model = self.data_store.get("train_model")
        self._create_new_model = self.data_store.get("create_new_model")
        self.path = self.data_store.get("model_path")
        self.model_display_name = self.data_store.get_default("model_display_name", default=None)
        self.model_load_path = None
        # `path` contains a single "%s" placeholder, filled with the concrete file name below
        # (and, for model_path, later with the generic model name in get_model_settings).
        path = self._set_model_path()
        self.model_path = path % "%s.h5"
        self.checkpoint_name = path % "model-best.h5"
        self.callbacks_name = path % "model-best-callbacks-%s.pickle"
        self._run()

    def _run(self):
        """Execute the full setup schedule documented on the class."""
        # set channels depending on inputs
        self._set_shapes()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # broadcast custom objects
        self.broadcast_custom_objects()

        # plot model structure
        self.plot_model()

        # load weights if no training shall be performed
        if not self._train_model and not self._create_new_model:
            self.load_model()

        # create checkpoint
        self._set_callbacks()

        # compile model
        self.compile_model()

        # report settings
        self.report_model()

    def _set_model_path(self):
        """
        Build the model path template ``<model_path>/<experiment_name>_%s``.

        If an external `model_load_path` is configured, validate that it points to an ``.h5``
        file and that neither training nor model creation is requested (loading an external
        model is incompatible with both).

        :return: path template with a single ``%s`` placeholder for the file name
        :raises FileNotFoundError: if `model_load_path` does not end with ``.h5``
        :raises ValueError: if `model_load_path` is combined with train_model/create_new_model
        """
        exp_name = self.data_store.get("experiment_name")
        self.model_load_path = self.data_store.get_default("model_load_path", default=None)
        if self.model_load_path is not None:
            if not self.model_load_path.endswith(".h5"):
                raise FileNotFoundError(f"When providing external models, you need to provide full path including the "
                                        f".h5 file. Given path is not valid: {self.model_load_path}")
            if any([self._train_model, self._create_new_model]):
                raise ValueError(f"Providing `model_path` along with parameters train_model={self._train_model} and "
                                 f"create_new_model={self._create_new_model} is not possible. Either set both "
                                 f"parameters to False or remove `model_path` parameter. Given was: model_path = "
                                 f"{self.model_load_path}")
        return os.path.join(self.path, f"{exp_name}_%s")

    def _set_shapes(self):
        """Set input and output shapes from train collection (first element, batch axis stripped)."""
        first_train_element = self.data_store.get("data_collection", "train")[0]
        shape = list(map(lambda x: x.shape[1:], first_train_element.get_X()))
        self.data_store.set("input_shape", shape, self.scope)
        shape = list(map(lambda y: y.shape[1:], first_train_element.get_Y()))
        self.data_store.set("output_shape", shape, self.scope)

    def _set_num_of_training_samples(self):
        """
        Count training samples over the whole train collection - needed for example for Bayesian NNs.

        Each element's ``__len__`` accepts the train-scope upsampling arguments, so the count
        reflects upsampled data if enabled.
        """
        upsampling = self.data_store.create_args_dict(["upsampling"], "train")
        return sum(data.__len__(**upsampling) for data in self.data_store.get("data_collection", "train"))

    def compile_model(self):
        """
        Compiles the keras model. Compile options are mandatory and have to be set by implementing set_compile() method
        in child class of AbstractModelClass.
        """
        compile_options = self.model.compile_options
        self.model.compile(**compile_options)
        # store the compiled model so downstream stages (training, postprocessing) use it
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase.

        Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added.

        :raises ValueError: if `early_stopping_epochs` is neither an int nor infinity
        """
        # create callback handler
        callbacks = CallbackHandler()

        # add callback: learning rate (only if a decay is configured)
        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
        if lr is not None:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")

        # add callback: advanced history
        hist = HistoryAdvanced()
        self.data_store.set("hist", hist, scope="model")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")

        # add callback: epo timing
        epo_timing = EpoTimingCallback()
        self.data_store.set("epo_timing", epo_timing, scope="model")
        callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")

        # add callback: early stopping (patience defaults to inf, i.e. effectively disabled)
        patience = self.data_store.get_default("early_stopping_epochs", default=np.inf)
        restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True)
        # explicit raise instead of assert: validation must survive `python -O`
        if not (isinstance(patience, int) or np.isinf(patience)):
            raise ValueError(f"Parameter `early_stopping_epochs` must be an integer or infinity but is given as "
                             f"{patience} ({type(patience)}).")
        cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights)
        callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping")

        # create model checkpoint
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto', restore_best_weights=restore_best_weights)

        # store callbacks
        self.data_store.set("callbacks", callbacks, self.scope)

    def copy_model(self):
        """Copy external model to internal experiment structure."""
        if self.model_load_path is not None:
            logging.info(f"Copy external model file: {self.model_load_path} -> {self.model_path}")
            shutil.copyfile(self.model_load_path, self.model_path)

    def load_model(self):
        """Try to load model from disk or skip if not possible."""
        self.copy_model()
        try:
            self.model.load_model(self.model_path)
            logging.info(f"reload model {self.model_path} from disk ...")
        except OSError:
            # no stored model is a regular case (e.g. first run) - continue with freshly built model
            logging.info('no local model to load...')

    def build_model(self):
        """Build model using input and output shapes from data store."""
        model = self.data_store.get("model_class")
        args_list = model.requirements()
        if "num_of_training_samples" in args_list:
            # some model classes (e.g. Bayesian NNs) require the sample count as constructor argument
            num_of_training_samples = self._set_num_of_training_samples()
            self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
            logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
                         f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
                         f"'{self.scope}')")
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = model(**args)
        self.get_model_settings()

    def broadcast_custom_objects(self):
        """
        Broadcast custom objects to keras utils.

        This method is very important, because it adds the model's custom objects to the keras utils. By doing so, all
        custom objects can be treated as standard keras modules. Therefore, problems related to model or callback
        loading are solved.
        """
        keras.utils.get_custom_objects().update(self.model.custom_objects)

    def get_model_settings(self):
        """Load all model settings and store in data store."""
        model_settings = self.model.get_settings()
        self.data_store.set_from_dict(model_settings, self.scope, log=True)
        generic_model_name = self.data_store.get_default("model_name", self.scope, "my_model")
        # fill the remaining "%s" placeholder of model_path with the generic model name
        self.model_path = self.model_path % generic_model_name
        self.data_store.set("model_name", self.model_path, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot model architecture as `<model_name>.pdf`."""
        try:
            with tf.device("/cpu:0"):
                file_name = f"{self.model_path.rsplit('.', 1)[0]}.pdf"
                keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
        except Exception as e:
            # plotting is best effort only (e.g. graphviz may be missing) - never abort setup
            logging.info(f"Can not plot model due to: {e}")

    def report_model(self):
        """Write model settings (tex/markdown), the keras summary, and the model source code to disk."""
        # report model settings
        _f = self._clean_name
        model_settings = self.model.get_settings()
        model_settings.update(self.model.compile_options)
        model_settings.update(self.model.optimizer.get_config())
        df = pd.DataFrame(columns=["model setting"])
        for k, v in model_settings.items():
            if v is None:
                continue
            if isinstance(v, list):
                if len(v) > 0:
                    if isinstance(v[0], dict):
                        v = ["{" + vi + "}" for vi in [",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) for d in v]]
                    else:
                        v = ",".join(_f(str(u)) for u in v)
                else:
                    v = "[]"
            # "<" indicates a repr like "<class ...>" - clean it up for the report
            if "<" in str(v):
                v = _f(str(v))
            df.loc[k] = str(v)
        df.loc["count params"] = str(self.model.count_params())
        df.sort_index(inplace=True)
        column_format = "ll"
        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
        path_config.check_path_and_create(path)
        for p in [path, self.path]:  # log to `latex_report` and `model`
            df.to_latex(os.path.join(p, "model_settings.tex"), na_rep='---', column_format=column_format)
            # use a context manager so the file handle is closed deterministically
            with open(os.path.join(p, "model_settings.md"), mode="w", encoding='utf-8') as md_file:
                df.to_markdown(md_file, tablefmt="github")
        # report model summary to file
        with open(os.path.join(self.path, "model_summary.txt"), "w") as fh:
            self.model.summary(print_fn=lambda x: fh.write(x + "\n"))
        # print model code to file
        with open(os.path.join(self.path, "model_code.txt"), "w") as fh:
            fh.write(getsource(self.data_store.get("model_class")))

    @staticmethod
    def _clean_name(orig_name: str):
        """
        Reduce a repr-like string (e.g. "<class 'tensorflow.python...Adam'>") to a short name.

        Strips the leading "<", drops "class"/"function"/"method" markers, and keeps only the
        last dotted component for tensorflow/keras qualified names.
        """
        mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
        mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name
        mod_name = mod_name[0].split(".")[-1] if any(
            map(lambda x: x in mod_name[0], ["tensorflow", "keras"])) else mod_name
        # drop a trailing ">" if still present, otherwise flatten the remaining parts
        mod_name = mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
        return mod_name.split(".")[-1] if any(map(lambda x: x in mod_name, ["tensorflow", "keras"])) else mod_name