Coverage for mlair/run_modules/model_setup.py: 41% of 120 statements (coverage.py v6.4.2, created at 2022-12-02 15:24 +0000)
1"""Model setup module."""
3__author__ = "Lukas Leufen, Felix Kleinert"
4__date__ = '2019-12-02'
6import logging
7import os
8import re
9from dill.source import getsource
11import tensorflow.keras as keras
12import pandas as pd
13import tensorflow as tf
14import numpy as np
16from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
17from mlair.run_modules.run_environment import RunEnvironment
18from mlair.configuration import path_config


class ModelSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of model setup:
        #. set channels (from variables dimension)
        #. build imported model
        #. plot model architecture
        #. load weights if enabled (e.g. to resume a training)
        #. set callbacks and checkpoint
        #. compile model

    Required objects [scope] from data store:
        * `experiment_path` [.]
        * `experiment_name` [.]
        * `train_model` [.]
        * `create_new_model` [.]
        * `generator` [train]
        * `model_class` [.]

    Optional objects
        * `lr_decay` [model]

    Sets
        * `channels` [model]
        * `model` [model]
        * `hist` [model]
        * `callbacks` [model]
        * `model_name` [model]
        * all settings from model class like `dropout_rate`, `initial_lr`, and `optimizer` [model]

    Creates
        * plot of model architecture `<model_name>.pdf`

    """

    def __init__(self):
        """Initialise and run model setup."""
        super().__init__()
        self.model = None
        exp_name = self.data_store.get("experiment_name")
        self.path = self.data_store.get("model_path")
        self.scope = "model"
        path = os.path.join(self.path, f"{exp_name}_%s")
        self.model_name = path % "%s.h5"
        self.checkpoint_name = path % "model-best.h5"
        self.callbacks_name = path % "model-best-callbacks-%s.pickle"
        self._train_model = self.data_store.get("train_model")
        self._create_new_model = self.data_store.get("create_new_model")
        self._run()
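
        # Illustrative resolution of the "%s" templates above (hypothetical
        # experiment name "testrun"):
        #   path                 -> ".../testrun_%s"
        #   self.model_name      -> ".../testrun_%s.h5"  (second "%s" filled in get_model_settings)
        #   self.checkpoint_name -> ".../testrun_model-best.h5"
        #   self.callbacks_name  -> ".../testrun_model-best-callbacks-%s.pickle"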

    def _run(self):
        # set channels depending on inputs
        self._set_shapes()

        # build model graph using settings from my_model_settings()
        self.build_model()

        # broadcast custom objects
        self.broadcast_custom_objects()

        # plot model structure
        self.plot_model()

        # load model weights if neither training nor a new model is requested
        if not self._train_model and not self._create_new_model:
            self.load_model()

        # create checkpoint
        self._set_callbacks()

        # compile model
        self.compile_model()

        # report settings
        self.report_model()

    def _set_shapes(self):
        """Set input and output shapes from train collection."""
        shape = list(map(lambda x: x.shape[1:], self.data_store.get("data_collection", "train")[0].get_X()))
        self.data_store.set("input_shape", shape, self.scope)
        shape = list(map(lambda y: y.shape[1:], self.data_store.get("data_collection", "train")[0].get_Y()))
        self.data_store.set("output_shape", shape, self.scope)

    def _set_num_of_training_samples(self):
        """Return the total number of training samples, needed e.g. for Bayesian NNs."""
        samples = 0
        upsampling = self.data_store.create_args_dict(["upsampling"], "train")
        for data in self.data_store.get("data_collection", "train"):
            length = data.__len__(**upsampling)  # len() cannot forward kwargs, hence the direct dunder call
            samples += length
        return samples

    def compile_model(self):
        """
        Compile the keras model. Compile options are mandatory and have to be set by implementing the set_compile()
        method in a child class of AbstractModelClass.
        """
        compile_options = self.model.compile_options
        self.model.compile(**compile_options)
        self.data_store.set("model", self.model, self.scope)

    def _set_callbacks(self):
        """
        Set all callbacks for the training phase.

        Add each callback via the .add_callback method. Finally, the advanced model checkpoint is added.
        """
        # create callback handler
        callbacks = CallbackHandler()

        # add callback: learning rate
        lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
        if lr is not None:
            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")

        # add callback: advanced history
        hist = HistoryAdvanced()
        self.data_store.set("hist", hist, scope="model")
        callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")

        # add callback: epo timing
        epo_timing = EpoTimingCallback()
        self.data_store.set("epo_timing", epo_timing, scope="model")
        callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")

        # add callback: early stopping
        patience = self.data_store.get_default("early_stopping_epochs", default=np.inf)
        restore_best_weights = self.data_store.get_default("restore_best_model_weights", default=True)
        assert isinstance(patience, int) or np.isinf(patience)
        cb = tf.keras.callbacks.EarlyStopping(patience=patience, restore_best_weights=restore_best_weights)
        callbacks.add_callback(cb, self.callbacks_name % "early_stopping", "early_stopping")

        # create model checkpoint
        callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                          save_best_only=True, mode='auto',
                                          restore_best_weights=restore_best_weights)

        # store callbacks
        self.data_store.set("callbacks", callbacks, self.scope)

    def load_model(self):
        """Try to load model from disk or skip if not possible."""
        try:
            self.model.load_model(self.model_name)
            logging.info(f"reload model {self.model_name} from disk ...")
        except OSError:
            logging.info('no local model to load...')

    def build_model(self):
        """Build model using input and output shapes from data store."""
        model = self.data_store.get("model_class")
        args_list = model.requirements()
        if "num_of_training_samples" in args_list:
            num_of_training_samples = self._set_num_of_training_samples()
            self.data_store.set("num_of_training_samples", num_of_training_samples, scope=self.scope)
            logging.info(f"Store number of training samples ({num_of_training_samples}) in data_store: "
                         f"self.data_store.set('num_of_training_samples', {num_of_training_samples}, scope="
                         f"'{self.scope}')")
        args = self.data_store.create_args_dict(args_list, self.scope)
        self.model = model(**args)
        self.get_model_settings()
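
    # Sketch of the build contract used above (hypothetical values):
    # model.requirements() lists the constructor argument names, and
    # create_args_dict() resolves each one from the data store, e.g.
    #
    #   args_list = ["input_shape", "output_shape", "dropout_rate"]
    #   args      = {"input_shape": [(14, 1, 5)], "output_shape": [(4,)], "dropout_rate": 0.2}
    #   self.model = model(**args)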

    def broadcast_custom_objects(self):
        """
        Broadcast custom objects to keras utils.

        This step registers the model's custom objects with the keras utils so that they can be treated as standard
        keras modules, which avoids problems when loading models or callbacks.
        """
        keras.utils.get_custom_objects().update(self.model.custom_objects)
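
    # Illustration (hypothetical entries): custom_objects typically maps the
    # names of non-standard layers or loss functions to their implementations,
    # e.g.
    #
    #   self.model.custom_objects = {"Padding2D": Padding2D, "var_loss": var_loss}
    #
    # Once broadcast, keras can deserialise these objects when loading a saved
    # model or checkpoint.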

    def get_model_settings(self):
        """Load all model settings and store in data store."""
        model_settings = self.model.get_settings()
        self.data_store.set_from_dict(model_settings, self.scope, log=True)
        # fill the remaining "%s" placeholder in the model file name
        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
        self.data_store.set("model_name", self.model_name, self.scope)

    def plot_model(self):  # pragma: no cover
        """Plot model architecture as `<model_name>.pdf`."""
        try:
            with tf.device("/cpu:0"):
                file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
                keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
        except Exception as e:
            logging.info(f"Cannot plot model due to: {e}")

    def report_model(self):
        """Report all model settings as LaTeX/markdown tables and dump model summary and source code to disk."""
        # report model settings
        _f = self._clean_name
        model_settings = self.model.get_settings()
        model_settings.update(self.model.compile_options)
        model_settings.update(self.model.optimizer.get_config())
        df = pd.DataFrame(columns=["model setting"])
        for k, v in model_settings.items():
            if v is None:
                continue
            if isinstance(v, list):
                if len(v) > 0:
                    if isinstance(v[0], dict):
                        v = ["{" + vi + "}" for vi in
                             [",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) for d in v]]
                    else:
                        v = ",".join(_f(str(u)) for u in v)
                else:
                    v = "[]"
            if "<" in str(v):
                v = _f(str(v))
            df.loc[k] = str(v)
        df.loc["count params"] = str(self.model.count_params())
        df.sort_index(inplace=True)
        column_format = "ll"
        path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
        path_config.check_path_and_create(path)
        for p in [path, self.path]:  # log to `latex_report` and `model`
            df.to_latex(os.path.join(p, "model_settings.tex"), na_rep='---', column_format=column_format)
            df.to_markdown(open(os.path.join(p, "model_settings.md"), mode="w", encoding='utf-8'),
                           tablefmt="github")
        # report model summary to file
        with open(os.path.join(self.path, "model_summary.txt"), "w") as fh:
            self.model.summary(print_fn=lambda x: fh.write(x + "\n"))
        # print model code to file
        with open(os.path.join(self.path, "model_code.txt"), "w") as fh:
            fh.write(getsource(self.data_store.get("model_class")))

    @staticmethod
    def _clean_name(orig_name: str):
        # drop a leading "<" and all single quotes, then split on spaces,
        # e.g. "<class 'keras.optimizers.Adam'>" -> ["class", "keras.optimizers.Adam>"]
        mod_name = re.sub(r'^{0}'.format(re.escape("<")), '', orig_name).replace("'", "").split(" ")
        # drop a leading "class"/"function"/"method" token if present
        mod_name = mod_name[1] if any(map(lambda x: x in mod_name[0], ["class", "function", "method"])) else mod_name
        # if the first token is a tensorflow/keras path, keep only its last component
        mod_name = mod_name[0].split(".")[-1] if any(
            map(lambda x: x in mod_name[0], ["tensorflow", "keras"])) else mod_name
        # strip a trailing ">" (string case) or re-join the remaining tokens
        mod_name = mod_name[:-1] if mod_name[-1] == ">" else "".join(mod_name)
        return mod_name.split(".")[-1] if any(map(lambda x: x in mod_name, ["tensorflow", "keras"])) else mod_name
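
    # Illustrative behaviour of _clean_name, traced from the logic above
    # (hypothetical inputs):
    #   _clean_name("<class 'keras.optimizers.optimizer_v2.adam.Adam'>") -> "Adam"
    #   _clean_name("mean_squared_error") -> "mean_squared_error"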