Coverage for mlair/run_modules/experiment_setup.py: 81%
157 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
1__author__ = "Lukas Leufen, Felix Kleinert"
2__date__ = '2019-11-15'
4import argparse
5import logging
6import os
7import sys
8from typing import Union, Dict, Any, List, Callable
9from dill.source import getsource
11from mlair.configuration import path_config
12from mlair import helpers
13from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_STATION_TYPE, \
14 DEFAULT_START, DEFAULT_END, DEFAULT_WINDOW_HISTORY_SIZE, DEFAULT_OVERWRITE_LOCAL_DATA, \
15 DEFAULT_HPC_LOGIN_LIST, DEFAULT_HPC_HOST_LIST, DEFAULT_CREATE_NEW_MODEL, DEFAULT_TRAIN_MODEL, \
16 DEFAULT_FRACTION_OF_TRAINING, DEFAULT_EXTREME_VALUES, DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, DEFAULT_PERMUTE_DATA, \
17 DEFAULT_BATCH_SIZE, DEFAULT_EPOCHS, DEFAULT_TARGET_VAR, DEFAULT_TARGET_DIM, DEFAULT_WINDOW_LEAD_TIME, \
18 DEFAULT_WINDOW_DIM, DEFAULT_DIMENSIONS, DEFAULT_TIME_DIM, DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_LIMIT, \
19 DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \
20 DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
21 DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_FEATURE_IMPORTANCE, DEFAULT_FEATURE_IMPORTANCE_CREATE_NEW_BOOTSTRAPS, \
22 DEFAULT_FEATURE_IMPORTANCE_N_BOOTS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \
23 DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
24 DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \
25 DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \
26 DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_CREATE_SNAPSHOT, \
27 DEFAULT_EARLY_STOPPING_EPOCHS, DEFAULT_RESTORE_BEST_MODEL_WEIGHTS, DEFAULT_COMPETITORS
28from mlair.data_handler import DefaultDataHandler
29from mlair.run_modules.run_environment import RunEnvironment
30from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
class ExperimentSetup(RunEnvironment):
    """
    Set up the model.

    Schedule of experiment setup:
        * set up experiment path
        * set up data path (according to host system)
        * set up forecast, bootstrap and plot path (inside experiment path)
        * set all parameters given in args (or use default values)
        * check target variable
        * check `variables` and `statistics_per_var` parameter for consistency

    Sets
        * `data_path` [.]
        * `create_new_model` [.]
        * `bootstrap_path` [.]
        * `train_model` [.]
        * `fraction_of_training` [.]
        * `extreme_values` [train]
        * `extremes_on_right_tail_only` [train]
        * `upsampling` [train]
        * `permute_data` [train]
        * `experiment_name` [.]
        * `experiment_path` [.]
        * `plot_path` [.]
        * `forecast_path` [.]
        * `stations` [.]
        * `statistics_per_var` [.]
        * `variables` [.]
        * `start` [.]
        * `end` [.]
        * `window_history_size` [.]
        * `overwrite_local_data` [preprocessing]
        * `sampling` [.]
        * `transformation` [., preprocessing]
        * `target_var` [.]
        * `target_dim` [.]
        * `window_lead_time` [.]

    Creates
        * plot of model architecture in `<model_name>.pdf`

    :param parser_args: argument parser, currently only accepting ``experiment_date argument`` to be used for
        experiment's name and path creation. Final experiment's name is derived from given name and the time series
        sampling as `<name>_network_<sampling>/` . All interim and final results, logging, plots, ... of this run are
        stored in this directory if not explicitly provided in kwargs. Only the data itself and data for bootstrap
        investigations are stored outside this structure.
    :param stations: list of stations or single station to use in experiment. If not provided, stations are set to
        :py:const:`default stations <DEFAULT_STATIONS>`.
    :param variables: list of all variables to use. Valid names can be found in
        `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_. If not provided, this
        parameter is filled with keys from ``statistics_per_var``.
    :param statistics_per_var: dictionary with statistics to use for variables (if data is daily and loaded from JOIN).
        If not provided, :py:const:`default statistics <DEFAULT_VAR_ALL_DICT>` is applied. ``statistics_per_var`` is
        compared with given ``variables`` and unused variables are removed. Therefore, statistics at least need to
        provide all variables from ``variables``. For more details on available statistics, we refer to
        `Section 3.3 List of statistics/metrics for stats service <https://join.fz-juelich.de/services/rest/surfacedata/>`_
        in the JOIN documentation. Valid parameter names can be found in
        `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_.
    :param start: start date of overall data (default `"1997-01-01"`)
    :param end: end date of overall data (default `"2017-12-31"`)
    :param window_history_size: number of time steps to use for input data (default 13). Time steps `t_0 - w` to `t_0`
        are used as input data (therefore actual data size is `w+1`).
    :param target_var: target variable to predict by model, currently only a single target variable is supported.
        Because this framework was originally designed to predict ozone, default is `"o3"`.
    :param target_dim: dimension of target variable (default `"variables"`).
    :param window_lead_time: number of time steps to predict by model (default 3). Time steps `t_0+1` to `t_0+w` are
        predicted.
    :param dimensions:
    :param time_dim:
    :param interpolation_method: The method to use for interpolation.
    :param interpolation_limit: The maximum number of subsequent time steps in a gap to fill by interpolation. If the
        gap exceeds this number, the gap is not filled by interpolation at all. The value of time steps is an arbitrary
        number that is applied depending on the `sampling` frequency. A limit of 2 means that either 2 hours or 2 days
        are allowed to be interpolated in dependency of the set sampling rate.
    :param train_start:
    :param train_end:
    :param val_start:
    :param val_end:
    :param test_start:
    :param test_end:
    :param use_all_stations_on_all_data_sets:
    :param train_model: train a new model from scratch or resume training with existing model if `True` (default) or
        freeze loaded model and do not perform any modification on it. ``train_model`` is set to `True` if
        ``create_new_model`` is `True`.
    :param fraction_of_train: given value is used to split between test data and train data (including validation data).
        The value of ``fraction_of_train`` must be in `(0, 1)` but is recommended to be in the interval `[0.6, 0.9]`.
        Default value is `0.8`. Split between train and validation is fixed to 80% - 20% and currently not changeable.
    :param experiment_path:
    :param plot_path: path to save all plots. If left blank, this will be included in the experiment path (recommended).
        Otherwise customise the location to save all plots.
    :param forecast_path: path to save all forecasts in files. It is recommended to leave this parameter blank, all
        forecasts will be the directory `forecasts` inside the experiment path (default). For customisation, add your
        path here.
    :param overwrite_local_data: Reload input and target data from web and replace local data if `True` (default
        `False`).
    :param sampling: set temporal sampling rate of data. You can choose from daily (default), monthly, seasonal,
        vegseason, summer and annual for aggregated values and hourly for the actual values. Note, that hourly values on
        JOIN are currently not accessible from outside. To access this data, you need to add your personal token in
        :py:mod:`join settings <src.configuration.join_settings>` and make sure to untrack this file!
    :param create_new_model: determine whether a new model will be created (`True`, default) or not (`False`). If this
        parameter is set to `False`, make sure, that a suitable model already exists in the experiment path. This model
        must fit in terms of input and output dimensions as well as ``window_history_size`` and ``window_lead_time`` and
        must be implemented as a :py:mod:`model class <src.model_modules.model_class>` and imported in
        :py:mod:`model setup <src.run_modules.model_setup>`. If ``create_new_model`` is `True`, parameter ``train_model``
        is automatically set to `True` too.
    :param bootstrap_path:
    :param permute_data_on_training: shuffle train data individually for each station if `True`. This is performed each
        iteration for new, so that each sample very likely differs from epoch to epoch. Train data permutation is
        disabled (`False`) per default. In the case of extreme value manifolding, data permutation is enabled anyway.
    :param transformation: set transformation options in dictionary style. All information about transformation options
        can be found in :py:meth:`setup transformation <src.data_handling.data_generator.DataGenerator.setup_transformation>`.
        If no transformation is provided, all options are set to :py:const:`default transformation <DEFAULT_TRANSFORMATION>`.
    :param train_min_length:
    :param val_min_length:
    :param test_min_length:
    :param extreme_values: augment target samples with values of lower occurrences indicated by its normalised
        deviation from mean by manifolding. These extreme values need to be indicated by a list of thresholds. For
        each entry in this list, all values outside an +/- interval will be added in the training (and only the
        training) set for a second time to the sample. If multiple values are given, a sample is added for each
        exceedance once. E.g. a sample with `value=2.5` occurs twice in the training set for given
        `extreme_values=[2, 3]`, whereas a sample with `value=5` occurs three times in the training set. For default,
        upsampling of extreme values is disabled (`None`). Upsampling can be modified to manifold only values that are
        actually larger than given values from ``extreme_values`` (apply only on right side of distribution) by using
        ``extremes_on_right_tail_only``. This can be useful for positive skew variables.
    :param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If ``extremes_on_right_tail_only``
        is `True`, only manifold values that are larger than given extremes (apply upsampling only on right side of
        distribution). In default mode, this is set to `False` to manifold extremes on both sides.
    :param evaluate_feature_importance:
    :param plot_list:
    :param feature_importance_n_boots:
    :param feature_importance_create_new_bootstraps:
    :param data_path: path to find and store meteorological and environmental / air quality data. Leave this parameter
        empty, if your host system is known and a suitable path was already hardcoded in the program (see
        :py:func:`prepare host <src.configuration.path_config.prepare_host>`).
    :param experiment_date:
    :param window_dim: "Temporal" dimension of the input and target data, that is provided for each sample. The number
        of samples provided in this dimension can be set using `window_history_size` for inputs and `window_lead_time`
        on target site.
    :param iter_dim:
    :param batch_path:
    :param login_nodes:
    :param hpc_hosts:
    :param model:
    :param batch_size:
    :param epochs: Number of epochs used in training. If a training is resumed and the number of epochs of the already
        (partly) trained model is lower than this parameter, training is continued. In case this number is higher than
        the given epochs parameter, no training is resumed. Epochs is set to 20 per default, but this value is just a
        placeholder that should be adjusted for a meaningful training.
    :param early_stopping_epochs: number of consecutive epochs with no improvement on val loss to stop training. When
        set to `np.inf` or not providing at all, training is not stopped before reaching `epochs`.
    :param restore_best_model_weights: indicates whether to use model state with best val loss (if True) or model state
        on ending of training (if False). The latter depends on the parameters `epochs` and `early_stopping_epochs`
        which trigger stopping of training.
    :param data_handler:
    :param data_origin:
    :param competitors: Provide names of reference models trained by MLAir that can be found in the `competitor_path`.
        These models will be used in the postprocessing for comparison.
    :param competitor_path: The path where MLAir can find competing models. If not provided, this path is assumed to be
        in the ´data_path´ directory as a subdirectory called `competitors` (default).
    :param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this
        parameter to `True` (default). If set to `False` the computation is performed in a serial approach.
        Multiprocessing is disabled when running in debug mode and cannot be switched on.
    :param transformation_file: Use transformation options from this file for transformation
    :param calculate_fresh_transformation: can either be True or False, indicates if new transformation options should
        be calculated in any case (transformation_file is not used in this case!).
    :param snapshot_path: path to store snapshot of current run (default inside experiment path)
    :param create_snapshot: indicate if a snapshot is taken from current run or not (default False)
    :param snapshot_load_path: path to load a snapshot from (default None). In contrast to `snapshot_path`, which is
        only for storing a snapshot, `snapshot_load_path` indicates where to load the snapshot from. If this parameter
        is not provided at all, no snapshot is loaded. Note, the workflow will apply the default preprocessing without
        loading a snapshot only if this parameter is None!
    """

    def __init__(self,
                 experiment_date=None,
                 stations: Union[str, List[str]] = None,
                 variables: Union[str, List[str]] = None,
                 statistics_per_var: Dict = None,
                 start: str = None,
                 end: str = None,
                 window_history_size: int = None,
                 target_var="o3",
                 target_dim=None,
                 window_lead_time: int = None,
                 window_dim=None,
                 dimensions=None,
                 time_dim=None,
                 iter_dim=None,
                 interpolation_method=None,
                 interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None,
                 test_start=None,
                 test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None,
                 fraction_of_train: float = None,
                 experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data=None,
                 sampling: str = None,
                 create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
                 train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                 extremes_on_right_tail_only: bool = None, evaluate_feature_importance: bool = None, plot_list=None,
                 feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None,
                 feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None,
                 data_path: str = None, batch_path: str = None, login_nodes=None,
                 hpc_hosts=None, model=None, batch_size=None, epochs=None,
                 early_stopping_epochs: int = None, restore_best_model_weights: bool = None,
                 data_handler=None,
                 data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                 use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
                 max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
                 overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None,
                 uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None,
                 do_uncertainty_estimate: bool = None, model_display_name: str = None, transformation_file: str = None,
                 calculate_fresh_transformation: bool = None, snapshot_load_path: str = None,
                 create_snapshot: bool = None, snapshot_path: str = None, **kwargs):

        # create run framework
        super().__init__()

        # experiment setup, hyperparameters
        self._set_param("data_path", path_config.prepare_host(data_path=data_path))
        self._set_param("hostname", path_config.get_host())
        self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST)
        self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST)
        self._set_param("create_new_model", create_new_model, default=DEFAULT_CREATE_NEW_MODEL)
        # a new model always implies training it, regardless of the train_model argument
        if self.data_store.get("create_new_model"):
            train_model = True
        data_path = self.data_store.get("data_path")
        bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path)
        self._set_param("bootstrap_path", bootstrap_path)
        self._set_param("train_model", train_model, default=DEFAULT_TRAIN_MODEL)
        self._set_param("fraction_of_training", fraction_of_train, default=DEFAULT_FRACTION_OF_TRAINING)
        self._set_param("extreme_values", extreme_values, default=DEFAULT_EXTREME_VALUES, scope="train")
        self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only,
                        default=DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, scope="train")
        self._set_param("upsampling", extreme_values is not None, scope="train")
        upsampling = self.data_store.get("upsampling", "train")
        permute_data = DEFAULT_PERMUTE_DATA if permute_data_on_training is None else permute_data_on_training
        # upsampling of extremes forces data permutation even if the user disabled it
        self._set_param("permute_data", permute_data or upsampling, scope="train")
        self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE)
        self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
        self._set_param("early_stopping_epochs", early_stopping_epochs, default=DEFAULT_EARLY_STOPPING_EPOCHS)
        self._set_param("restore_best_model_weights", restore_best_model_weights,
                        default=DEFAULT_RESTORE_BEST_MODEL_WEIGHTS)

        # set experiment name
        sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)  # always related to output sampling
        experiment_name = path_config.set_experiment_name(name=experiment_date, sampling=sampling)
        experiment_path = path_config.set_experiment_path(name=experiment_name, path=experiment_path)
        self._set_param("experiment_name", experiment_name)
        self._set_param("experiment_path", experiment_path)
        logging.info(f"Experiment path is: {experiment_path}")
        path_config.check_path_and_create(self.data_store.get("experiment_path"))

        # host system setup
        debug_mode = sys.gettrace() is not None
        self._set_param("debug_mode", debug_mode)
        if debug_mode is True:
            self._set_param("use_multiprocessing", use_multiprocessing_on_debug,
                            default=DEFAULT_USE_MULTIPROCESSING_ON_DEBUG)
        else:
            self._set_param("use_multiprocessing", use_multiprocessing, default=DEFAULT_USE_MULTIPROCESSING)
        self._set_param("max_number_multiprocessing", max_number_multiprocessing,
                        default=DEFAULT_MAX_NUMBER_MULTIPROCESSING)

        # batch path (temporary)
        self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))

        # set model path
        self._set_param("model_path", None, os.path.join(experiment_path, "model"))
        path_config.check_path_and_create(self.data_store.get("model_path"))

        # set plot path
        default_plot_path = os.path.join(experiment_path, "plots")
        self._set_param("plot_path", plot_path, default=default_plot_path)
        path_config.check_path_and_create(self.data_store.get("plot_path"))

        # set results path
        default_forecast_path = os.path.join(experiment_path, "forecasts")
        self._set_param("forecast_path", forecast_path, default_forecast_path)
        path_config.check_path_and_create(self.data_store.get("forecast_path"))

        # set logging path
        self._set_param("logging_path", None, os.path.join(experiment_path, "logging"))
        path_config.check_path_and_create(self.data_store.get("logging_path"))

        # set tmp path
        self._set_param("tmp_path", None, os.path.join(experiment_path, "tmp"))
        path_config.check_path_and_create(self.data_store.get("tmp_path"), remove_existing=True)

        # snapshot settings
        self._set_param("snapshot_path", snapshot_path, default=os.path.join(experiment_path, "snapshot"))
        path_config.check_path_and_create(self.data_store.get("snapshot_path"), remove_existing=False)
        self._set_param("create_snapshot", create_snapshot, default=DEFAULT_CREATE_SNAPSHOT)
        if snapshot_load_path is not None:
            self._set_param("snapshot_load_path", snapshot_load_path)

        # setup for data
        self._set_param("stations", stations, default=DEFAULT_STATIONS, apply=helpers.to_list)
        self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
        self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
        self._set_param("data_origin", data_origin, default=DEFAULT_DATA_ORIGIN)
        self._set_param("start", start, default=DEFAULT_START)
        self._set_param("end", end, default=DEFAULT_END)
        self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
        self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
                        scope="preprocessing")
        self._set_param("overwrite_lazy_data", overwrite_lazy_data, default=DEFAULT_OVERWRITE_LAZY_DATA,
                        scope="preprocessing")
        self._set_param("transformation", transformation, default={})
        self._set_param("transformation", None, scope="preprocessing")
        self._set_param("transformation_file", transformation_file, default=None)
        if calculate_fresh_transformation is not None:
            self._set_param("calculate_fresh_transformation", calculate_fresh_transformation)
        self._set_param("data_handler", data_handler, default=DefaultDataHandler)

        # iter and window dimension
        self._set_param("iter_dim", iter_dim, default=DEFAULT_ITER_DIM)
        self._set_param("window_dim", window_dim, default=DEFAULT_WINDOW_DIM)

        # target
        self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
        self._set_param("target_dim", target_dim, default=DEFAULT_TARGET_DIM)
        self._set_param("window_lead_time", window_lead_time, default=DEFAULT_WINDOW_LEAD_TIME)

        # interpolation
        self._set_param("dimensions", dimensions, default=DEFAULT_DIMENSIONS)
        self._set_param("time_dim", time_dim, default=DEFAULT_TIME_DIM)
        self._set_param("interpolation_method", interpolation_method, default=DEFAULT_INTERPOLATION_METHOD)
        self._set_param("interpolation_limit", interpolation_limit, default=DEFAULT_INTERPOLATION_LIMIT)

        # train set parameters
        self._set_param("start", train_start, default=DEFAULT_TRAIN_START, scope="train")
        self._set_param("end", train_end, default=DEFAULT_TRAIN_END, scope="train")
        self._set_param("min_length", train_min_length, default=DEFAULT_TRAIN_MIN_LENGTH, scope="train")

        # validation set parameters
        self._set_param("start", val_start, default=DEFAULT_VAL_START, scope="val")
        self._set_param("end", val_end, default=DEFAULT_VAL_END, scope="val")
        self._set_param("min_length", val_min_length, default=DEFAULT_VAL_MIN_LENGTH, scope="val")

        # test set parameters
        self._set_param("start", test_start, default=DEFAULT_TEST_START, scope="test")
        self._set_param("end", test_end, default=DEFAULT_TEST_END, scope="test")
        self._set_param("min_length", test_min_length, default=DEFAULT_TEST_MIN_LENGTH, scope="test")

        # train_val set parameters (spans from train start to val end)
        self._set_param("start", self.data_store.get("start", "train"), scope="train_val")
        self._set_param("end", self.data_store.get("end", "val"), scope="train_val")
        train_val_min_length = sum([self.data_store.get("min_length", s) for s in ["train", "val"]])
        self._set_param("min_length", train_val_min_length, default=DEFAULT_TRAIN_VAL_MIN_LENGTH, scope="train_val")

        # use all stations on all data sets (train, val, test)
        self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets,
                        default=DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS)

        # set post-processing instructions
        self._set_param("do_uncertainty_estimate", do_uncertainty_estimate,
                        default=DEFAULT_DO_UNCERTAINTY_ESTIMATE, scope="general.postprocessing")
        self._set_param("block_length", uncertainty_estimate_block_length,
                        default=DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, scope="uncertainty_estimate")
        self._set_param("evaluate_competitors", uncertainty_estimate_evaluate_competitors,
                        default=DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, scope="uncertainty_estimate")
        self._set_param("n_boots", uncertainty_estimate_n_boots,
                        default=DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, scope="uncertainty_estimate")

        self._set_param("evaluate_feature_importance", evaluate_feature_importance,
                        default=DEFAULT_EVALUATE_FEATURE_IMPORTANCE, scope="general.postprocessing")
        # new bootstraps are forced whenever a model is trained in this run
        feature_importance_create_new_bootstraps = max([self.data_store.get("train_model", "general"),
                                                        feature_importance_create_new_bootstraps or
                                                        DEFAULT_FEATURE_IMPORTANCE_CREATE_NEW_BOOTSTRAPS])
        self._set_param("create_new_bootstraps", feature_importance_create_new_bootstraps, scope="feature_importance")
        self._set_param("n_boots", feature_importance_n_boots, default=DEFAULT_FEATURE_IMPORTANCE_N_BOOTS,
                        scope="feature_importance")
        self._set_param("bootstrap_method", feature_importance_bootstrap_method,
                        default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, scope="feature_importance")
        self._set_param("bootstrap_type", feature_importance_bootstrap_type,
                        default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, scope="feature_importance")

        self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
        if model_display_name is not None:
            self._set_param("model_display_name", model_display_name)
        self._set_param("neighbors", ["DEBW030"])  # TODO: just for testing

        # set competitors
        if model_display_name is not None and competitors is not None and model_display_name in competitors:
            raise IndexError(f"Given model_display_name {model_display_name} is also present in the competitors "
                             f"variable {competitors}. To assure a proper workflow it is required to have unique names "
                             f"for each model and competitor. Please use a different model display name or competitor.")
        self._set_param("competitors", competitors, default=DEFAULT_COMPETITORS)
        competitor_path_default = os.path.join(self.data_store.get("data_path"), "competitors",
                                               "_".join(self.data_store.get("target_var")))
        self._set_param("competitor_path", competitor_path, default=competitor_path_default)

        # check variables, statistics and target variable
        self._check_target_var()
        self._compare_variables_and_statistics()

        # set model architecture class
        self._set_param("model_class", model, VanillaModel)

        # store starting script if provided
        if start_script is not None:
            self._store_start_script(start_script, experiment_path)

        # set remaining kwargs
        if len(kwargs) > 0:
            for k, v in kwargs.items():
                if len(self.data_store.search_name(k)) == 0:
                    self._set_param(k, v)
                else:
                    s = ", ".join([f"{k}({s})={self.data_store.get(k, scope=s)}"
                                   for s in self.data_store.search_name(k)])
                    raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a "
                                   f"conflict with an existing entry with same naming: {s}")

    def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general",
                   apply: Callable = None) -> Any:
        """Set given parameter and log in debug. Use apply parameter to adjust the stored value (e.g. to transform value
        to a list use apply=helpers.to_list)."""
        if value is None and default is not None:
            value = default
        if apply is not None:
            value = apply(value)
        self.data_store.set(param, value, scope)
        logging.debug(f"set experiment attribute: {param}({scope})={value}")
        return value

    @staticmethod
    def _store_start_script(start_script, store_path):
        """Store a copy of the starting script as `start_script.txt` in the experiment path. ``start_script`` can be
        a callable (its source code is stored) or a path to a script file (its content is copied)."""
        out_file = os.path.join(store_path, "start_script.txt")
        if isinstance(start_script, Callable):
            with open(out_file, "w") as fh:
                fh.write(getsource(start_script))
        if isinstance(start_script, str):
            with open(start_script, 'r') as f:
                with open(out_file, "w") as out:
                    for line in (f.readlines()):
                        print(line, end='', file=out)

    def _compare_variables_and_statistics(self):
        """
        Compare variables and statistics.

        * raise error, if a variable is missing.
        * remove unused variables from statistics.
        """
        logging.debug("check if all variables are included in statistics_per_var")
        stat = self.data_store.get("statistics_per_var")
        var = self.data_store.get("variables")
        # too less entries, raise error
        if not set(var).issubset(stat.keys()):
            missing = set(var).difference(stat.keys())
            raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested "
                             f"variables are part of statistics_per_var. Please add also information on the missing "
                             f"statistics for the variables: {missing}")
        # too much entries, remove unused
        target_var = helpers.to_list(self.data_store.get("target_var"))
        unused_vars = set(stat.keys()).difference(set(var).union(target_var))
        if len(unused_vars) > 0:
            logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}")
            stat_new = helpers.remove_items(stat, list(unused_vars))
            self._set_param("statistics_per_var", stat_new)

    def _check_target_var(self):
        """Check if target variable is in statistics_per_var dictionary."""
        target_var = helpers.to_list(self.data_store.get("target_var"))
        stat = self.data_store.get("statistics_per_var")
        var = self.data_store.get("variables")
        if not set(target_var).issubset(stat.keys()):
            raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
if __name__ == "__main__":
    # configure verbose logging with source location for standalone runs
    formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
    logging.basicConfig(format=formatter, level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
                        help="set experiment date as string")
    parser_args = parser.parse_args()
    # run setup inside a run environment so tracked parameters are cleaned up on exit
    with RunEnvironment():
        setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])