__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-11-15'

import argparse
import logging
import os
import sys
from typing import Union, Dict, Any, List, Callable
from dill.source import getsource

from mlair.configuration import path_config
from mlair import helpers
from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_STATION_TYPE, \
    DEFAULT_START, DEFAULT_END, DEFAULT_WINDOW_HISTORY_SIZE, DEFAULT_OVERWRITE_LOCAL_DATA, \
    DEFAULT_HPC_LOGIN_LIST, DEFAULT_HPC_HOST_LIST, DEFAULT_CREATE_NEW_MODEL, DEFAULT_TRAIN_MODEL, \
    DEFAULT_FRACTION_OF_TRAINING, DEFAULT_EXTREME_VALUES, DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, DEFAULT_PERMUTE_DATA, \
    DEFAULT_BATCH_SIZE, DEFAULT_EPOCHS, DEFAULT_TARGET_VAR, DEFAULT_TARGET_DIM, DEFAULT_WINDOW_LEAD_TIME, \
    DEFAULT_WINDOW_DIM, DEFAULT_DIMENSIONS, DEFAULT_TIME_DIM, DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_LIMIT, \
    DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \
    DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
    DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_FEATURE_IMPORTANCE, DEFAULT_FEATURE_IMPORTANCE_CREATE_NEW_BOOTSTRAPS, \
    DEFAULT_FEATURE_IMPORTANCE_N_BOOTS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \
    DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
    DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \
    DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \
    DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE, DEFAULT_CREATE_SNAPSHOT, \
    DEFAULT_EARLY_STOPPING_EPOCHS, DEFAULT_RESTORE_BEST_MODEL_WEIGHTS, DEFAULT_COMPETITORS, \
    DEFAULT_DO_BIAS_FREE_EVALUATION
from mlair.data_handler import DefaultDataHandler
from mlair.run_modules.run_environment import RunEnvironment
from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel


class ExperimentSetup(RunEnvironment):
    """
    Set up the experiment.

    Schedule of experiment setup:
        * set up experiment path
        * set up data path (according to host system)
        * set up forecast, bootstrap and plot path (inside experiment path)
        * set all parameters given in args (or use default values)
        * check target variable
        * check `variables` and `statistics_per_var` parameter for consistency

    Sets
        * `data_path` [.]
        * `create_new_model` [.]
        * `bootstrap_path` [.]
        * `train_model` [.]
        * `fraction_of_training` [.]
        * `extreme_values` [train]
        * `extremes_on_right_tail_only` [train]
        * `upsampling` [train]
        * `permute_data` [train]
        * `experiment_name` [.]
        * `experiment_path` [.]
        * `plot_path` [.]
        * `forecast_path` [.]
        * `stations` [.]
        * `statistics_per_var` [.]
        * `variables` [.]
        * `start` [.]
        * `end` [.]
        * `window_history_size` [.]
        * `overwrite_local_data` [preprocessing]
        * `sampling` [.]
        * `transformation` [., preprocessing]
        * `target_var` [.]
        * `target_dim` [.]
        * `window_lead_time` [.]

    Creates
        * plot of model architecture in `<model_name>.pdf`

    :param parser_args: argument parser, currently only accepting the ``experiment_date`` argument, which is used for
        the experiment's name and path creation. The final experiment name is derived from the given name and the
        time series sampling as `<name>_network_<sampling>/`. All interim and final results, logging, plots, ... of
        this run are stored in this directory if not explicitly provided in kwargs. Only the data itself and data for
        bootstrap investigations are stored outside this structure.
    :param stations: list of stations or single station to use in experiment. If not provided, stations are set to
        :py:const:`default stations <DEFAULT_STATIONS>`.
    :param variables: list of all variables to use. Valid names can be found in
        `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_. If not provided, this
        parameter is filled with keys from ``statistics_per_var``.
    :param statistics_per_var: dictionary with statistics to use for variables (if data is daily and loaded from JOIN).
        If not provided, :py:const:`default statistics <DEFAULT_VAR_ALL_DICT>` is applied. ``statistics_per_var`` is
        compared with the given ``variables`` and unused variables are removed. The statistics therefore need to cover
        at least all variables from ``variables``. For more details on available statistics, we refer to
        `Section 3.3 List of statistics/metrics for stats service <https://join.fz-juelich.de/services/rest/surfacedata/>`_
        in the JOIN documentation. Valid parameter names can be found in
        `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_.
    :param start: start date of overall data (default `"1997-01-01"`)
    :param end: end date of overall data (default `"2017-12-31"`)
    :param window_history_size: number of time steps to use for input data (default 13). Time steps `t_0 - w` to `t_0`
        are used as input data (therefore actual data size is `w+1`).
    :param target_var: target variable to predict by model, currently only a single target variable is supported.
        Because this framework was originally designed to predict ozone, default is `"o3"`.
    :param target_dim: dimension of target variable (default `"variables"`).
    :param window_lead_time: number of time steps to predict by model (default 3). Time steps `t_0+1` to `t_0+w` are
        predicted.
    :param dimensions:
    :param time_dim:
    :param interpolation_method: The method to use for interpolation.
    :param interpolation_limit: The maximum number of subsequent time steps in a gap to fill by interpolation. If the
        gap exceeds this number, the gap is not filled by interpolation at all. The number of time steps is an
        arbitrary value that is interpreted depending on the `sampling` frequency: a limit of 2 means that either 2
        hours or 2 days are allowed to be interpolated, depending on the set sampling rate.
    :param train_start:
    :param train_end:
    :param val_start:
    :param val_end:
    :param test_start:
    :param test_end:
    :param use_all_stations_on_all_data_sets:
    :param train_model: train a new model from scratch or resume training with an existing model if `True` (default),
        or freeze the loaded model and do not perform any modification on it. ``train_model`` is set to `True` if
        ``create_new_model`` is `True`.
    :param fraction_of_train: given value is used to split between test data and train data (including validation
        data). The value of ``fraction_of_train`` must be in `(0, 1)`, but is recommended to be in the interval
        `[0.6, 0.9]`. Default value is `0.8`. The split between train and validation is fixed to 80% - 20% and
        currently not changeable.
    :param experiment_path:
    :param plot_path: path to save all plots. If left blank, this will be included in the experiment path
        (recommended). Otherwise customise the location to save all plots.
    :param forecast_path: path to save all forecasts in files. It is recommended to leave this parameter blank; all
        forecasts will then be stored in the directory `forecasts` inside the experiment path (default). For
        customisation, add your path here.
    :param overwrite_local_data: Reload input and target data from web and replace local data if `True` (default
        `False`).
    :param sampling: set temporal sampling rate of data. You can choose from daily (default), monthly, seasonal,
        vegseason, summer and annual for aggregated values and hourly for the actual values. Note that hourly values
        on JOIN are currently not accessible from outside. To access this data, you need to add your personal token in
        :py:mod:`join settings <src.configuration.join_settings>` and make sure to untrack this file!
    :param create_new_model: determine whether a new model will be created (`True`, default) or not (`False`). If this
        parameter is set to `False`, make sure that a suitable model already exists in the experiment path. This model
        must fit in terms of input and output dimensions as well as ``window_history_size`` and ``window_lead_time``,
        and must be implemented as a :py:mod:`model class <src.model_modules.model_class>` and imported in
        :py:mod:`model setup <src.run_modules.model_setup>`. If ``create_new_model`` is `True`, the parameter
        ``train_model`` is automatically set to `True` too.
    :param bootstrap_path:
    :param permute_data_on_training: shuffle train data individually for each station if `True`. This is performed
        anew in each iteration, so that each sample very likely differs from epoch to epoch. Train data permutation is
        disabled (`False`) by default. In the case of extreme value manifolding, data permutation is enabled anyway.
    :param transformation: set transformation options in dictionary style. All information about transformation options
        can be found in :py:meth:`setup transformation <src.data_handling.data_generator.DataGenerator.setup_transformation>`.
        If no transformation is provided, all options are set to :py:const:`default transformation <DEFAULT_TRANSFORMATION>`.
    :param train_min_length:
    :param val_min_length:
    :param test_min_length:
    :param extreme_values: augment target samples that occur rarely, as indicated by their normalised deviation from
        the mean, by manifolding. These extreme values need to be indicated by a list of thresholds. For each entry in
        this list, all values outside a +/- interval will be added a second time to the training (and only the
        training) set. If multiple values are given, a sample is added once for each exceedance. E.g. a sample with
        `value=2.5` occurs twice in the training set for given `extreme_values=[2, 3]`, whereas a sample with
        `value=5` occurs three times in the training set. By default, upsampling of extreme values is disabled
        (`None`). Upsampling can be modified to manifold only values that are actually larger than the given values
        from ``extreme_values`` (apply only on the right side of the distribution) by using
        ``extremes_on_right_tail_only``. This can be useful for positively skewed variables.
    :param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If ``extremes_on_right_tail_only``
        is `True`, only manifold values that are larger than the given extremes (apply upsampling only on the right
        side of the distribution). In default mode, this is set to `False` to manifold extremes on both sides.
    :param evaluate_bootstraps:
    :param plot_list:
    :param number_of_bootstraps:
    :param create_new_bootstraps:
    :param data_path: path to find and store meteorological and environmental / air quality data. Leave this parameter
        empty, if your host system is known and a suitable path was already hardcoded in the program (see
        :py:func:`prepare host <src.configuration.path_config.prepare_host>`).
    :param experiment_date:
    :param window_dim: "temporal" dimension of the input and target data that is provided for each sample. The number
        of time steps provided in this dimension can be set using `window_history_size` for inputs and
        `window_lead_time` on the target side.
    :param iter_dim:
    :param batch_path:
    :param login_nodes:
    :param hpc_hosts:
    :param model:
    :param batch_size:
    :param epochs: number of epochs used in training. If a training is resumed and the number of epochs of the already
        (partly) trained model is lower than this parameter, training is continued. In case this number is higher than
        the given epochs parameter, no training is resumed. Epochs is set to 20 by default, but this value is just a
        placeholder that should be adjusted for a meaningful training.
    :param early_stopping_epochs: number of consecutive epochs with no improvement on val loss to stop training. When
        set to `np.inf` or not provided at all, training is not stopped before reaching `epochs`.
    :param restore_best_model_weights: indicates whether to use the model state with the best val loss (if `True`) or
        the model state at the end of training (if `False`). The latter depends on the parameters `epochs` and
        `early_stopping_epochs`, which trigger stopping of training.
    :param data_handler:
    :param data_origin:
    :param competitors: provide names of reference models trained by MLAir that can be found in the `competitor_path`.
        These models will be used in the postprocessing for comparison.
    :param competitor_path: the path where MLAir can find competing models. If not provided, this path is assumed to
        be in the ``data_path`` directory as a subdirectory called `competitors` (default).
    :param use_multiprocessing: enable parallel preprocessing (postprocessing not implemented yet) by setting this
        parameter to `True` (default). If set to `False`, the computation is performed in a serial approach.
        Multiprocessing is disabled when running in debug mode and cannot be switched on.
    :param transformation_file: use transformation options from this file for transformation.
    :param calculate_fresh_transformation: can either be `True` or `False`; indicates whether new transformation
        options should be calculated in any case (``transformation_file`` is not used in this case!).
    :param snapshot_path: path to store a snapshot of the current run (default inside experiment path)
    :param create_snapshot: indicate whether a snapshot is taken from the current run or not (default `False`)
    :param snapshot_load_path: path to load a snapshot from (default `None`). In contrast to `snapshot_path`, which is
        only for storing a snapshot, `snapshot_load_path` indicates where to load the snapshot from. If this parameter
        is not provided at all, no snapshot is loaded. Note, the workflow will only apply the default preprocessing
        without loading a snapshot if this parameter is None!
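
    A minimal usage sketch (mirroring the ``__main__`` block at the bottom of this module; the station ids and the
    epoch count are placeholder values)::

        with RunEnvironment():
            ExperimentSetup(stations=["DEBW107", "DEBY081"], epochs=10)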
206 """

    def __init__(self,
                 experiment_date=None,
                 stations: Union[str, List[str]] = None,
                 variables: Union[str, List[str]] = None,
                 statistics_per_var: Dict = None,
                 start: str = None,
                 end: str = None,
                 window_history_size: int = None,
                 target_var="o3",
                 target_dim=None,
                 window_lead_time: int = None,
                 window_dim=None,
                 dimensions=None,
                 time_dim=None,
                 iter_dim=None,
                 interpolation_method=None,
                 interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None,
                 test_start=None,
                 test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None,
                 fraction_of_train: float = None,
                 experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data=None,
                 sampling: str = None,
                 create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
                 train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                 extremes_on_right_tail_only: bool = None, evaluate_feature_importance: bool = None, plot_list=None,
                 feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None,
                 feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None,
                 data_path: str = None, batch_path: str = None, login_nodes=None,
                 hpc_hosts=None, model=None, batch_size=None, epochs=None,
                 early_stopping_epochs: int = None, restore_best_model_weights: bool = None,
                 data_handler=None,
                 data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                 use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
                 max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
                 overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None,
                 uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None,
                 do_uncertainty_estimate: bool = None, do_bias_free_evaluation: bool = None,
                 model_display_name: str = None, transformation_file: str = None,
                 calculate_fresh_transformation: bool = None, snapshot_load_path: str = None,
                 create_snapshot: bool = None, snapshot_path: str = None, model_path: str = None, **kwargs):

        # create run framework
        super().__init__()

        # experiment setup, hyperparameters
        self._set_param("data_path", path_config.prepare_host(data_path=data_path))
        self._set_param("hostname", path_config.get_host())
        self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST)
        self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST)
        self._set_param("create_new_model", create_new_model, default=DEFAULT_CREATE_NEW_MODEL)
        if self.data_store.get("create_new_model"):
            train_model = True
        data_path = self.data_store.get("data_path")
        bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path)
        self._set_param("bootstrap_path", bootstrap_path)
        self._set_param("train_model", train_model, default=DEFAULT_TRAIN_MODEL)
        self._set_param("fraction_of_training", fraction_of_train, default=DEFAULT_FRACTION_OF_TRAINING)
        self._set_param("extreme_values", extreme_values, default=DEFAULT_EXTREME_VALUES, scope="train")
        self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only,
                        default=DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, scope="train")
        self._set_param("upsampling", extreme_values is not None, scope="train")
        upsampling = self.data_store.get("upsampling", "train")
        permute_data = DEFAULT_PERMUTE_DATA if permute_data_on_training is None else permute_data_on_training
        self._set_param("permute_data", permute_data or upsampling, scope="train")
        self._set_param("batch_size", batch_size, default=DEFAULT_BATCH_SIZE)
        self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
        self._set_param("early_stopping_epochs", early_stopping_epochs, default=DEFAULT_EARLY_STOPPING_EPOCHS)
        self._set_param("restore_best_model_weights", restore_best_model_weights,
                        default=DEFAULT_RESTORE_BEST_MODEL_WEIGHTS)

        # set experiment name
        sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)  # always related to output sampling
        experiment_name = path_config.set_experiment_name(name=experiment_date, sampling=sampling)
        experiment_path = path_config.set_experiment_path(name=experiment_name, path=experiment_path)
        self._set_param("experiment_name", experiment_name)
        self._set_param("experiment_path", experiment_path)
        logging.info(f"Experiment path is: {experiment_path}")
        path_config.check_path_and_create(self.data_store.get("experiment_path"))

        # host system setup
        debug_mode = sys.gettrace() is not None
        self._set_param("debug_mode", debug_mode)
        if debug_mode is True:
            self._set_param("use_multiprocessing", use_multiprocessing_on_debug,
                            default=DEFAULT_USE_MULTIPROCESSING_ON_DEBUG)
        else:
            self._set_param("use_multiprocessing", use_multiprocessing, default=DEFAULT_USE_MULTIPROCESSING)
        self._set_param("max_number_multiprocessing", max_number_multiprocessing,
                        default=DEFAULT_MAX_NUMBER_MULTIPROCESSING)

        # batch path (temporary)
        self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))

        # set model path
        self._set_param("model_load_path", model_path)
        self._set_param("model_path", None, os.path.join(experiment_path, "model"))
        path_config.check_path_and_create(self.data_store.get("model_path"))

        # set plot path
        default_plot_path = os.path.join(experiment_path, "plots")
        self._set_param("plot_path", plot_path, default=default_plot_path)
        path_config.check_path_and_create(self.data_store.get("plot_path"))

        # set results path
        default_forecast_path = os.path.join(experiment_path, "forecasts")
        self._set_param("forecast_path", forecast_path, default_forecast_path)
        path_config.check_path_and_create(self.data_store.get("forecast_path"))

        # set logging path
        self._set_param("logging_path", None, os.path.join(experiment_path, "logging"))
        path_config.check_path_and_create(self.data_store.get("logging_path"))

        # set tmp path
        self._set_param("tmp_path", None, os.path.join(experiment_path, "tmp"))
        path_config.check_path_and_create(self.data_store.get("tmp_path"), remove_existing=True)

        # snapshot settings
        self._set_param("snapshot_path", snapshot_path, default=os.path.join(experiment_path, "snapshot"))
        path_config.check_path_and_create(self.data_store.get("snapshot_path"), remove_existing=False)
        self._set_param("create_snapshot", create_snapshot, default=DEFAULT_CREATE_SNAPSHOT)
        if snapshot_load_path is not None:
            self._set_param("snapshot_load_path", snapshot_load_path)

        # setup for data
        self._set_param("stations", stations, default=DEFAULT_STATIONS, apply=helpers.to_list)
        self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
        self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
        self._set_param("data_origin", data_origin, default=DEFAULT_DATA_ORIGIN)
        self._set_param("start", start, default=DEFAULT_START)
        self._set_param("end", end, default=DEFAULT_END)
        self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
        self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
                        scope="preprocessing")
        self._set_param("overwrite_lazy_data", overwrite_lazy_data, default=DEFAULT_OVERWRITE_LAZY_DATA,
                        scope="preprocessing")
        self._set_param("transformation", transformation, default={})
        self._set_param("transformation", None, scope="preprocessing")
        self._set_param("transformation_file", transformation_file, default=None)
        if calculate_fresh_transformation is not None:
            self._set_param("calculate_fresh_transformation", calculate_fresh_transformation)
        self._set_param("data_handler", data_handler, default=DefaultDataHandler)

        # iter and window dimension
        self._set_param("iter_dim", iter_dim, default=DEFAULT_ITER_DIM)
        self._set_param("window_dim", window_dim, default=DEFAULT_WINDOW_DIM)

        # target
        self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
        self._set_param("target_dim", target_dim, default=DEFAULT_TARGET_DIM)
        self._set_param("window_lead_time", window_lead_time, default=DEFAULT_WINDOW_LEAD_TIME)

        # interpolation
        self._set_param("dimensions", dimensions, default=DEFAULT_DIMENSIONS)
        self._set_param("time_dim", time_dim, default=DEFAULT_TIME_DIM)
        self._set_param("interpolation_method", interpolation_method, default=DEFAULT_INTERPOLATION_METHOD)
        self._set_param("interpolation_limit", interpolation_limit, default=DEFAULT_INTERPOLATION_LIMIT)

        # train set parameters
        self._set_param("start", train_start, default=DEFAULT_TRAIN_START, scope="train")
        self._set_param("end", train_end, default=DEFAULT_TRAIN_END, scope="train")
        self._set_param("min_length", train_min_length, default=DEFAULT_TRAIN_MIN_LENGTH, scope="train")

        # validation set parameters
        self._set_param("start", val_start, default=DEFAULT_VAL_START, scope="val")
        self._set_param("end", val_end, default=DEFAULT_VAL_END, scope="val")
        self._set_param("min_length", val_min_length, default=DEFAULT_VAL_MIN_LENGTH, scope="val")

        # test set parameters
        self._set_param("start", test_start, default=DEFAULT_TEST_START, scope="test")
        self._set_param("end", test_end, default=DEFAULT_TEST_END, scope="test")
        self._set_param("min_length", test_min_length, default=DEFAULT_TEST_MIN_LENGTH, scope="test")

        # train_val set parameters
        self._set_param("start", self.data_store.get("start", "train"), scope="train_val")
        self._set_param("end", self.data_store.get("end", "val"), scope="train_val")
        train_val_min_length = sum([self.data_store.get("min_length", s) for s in ["train", "val"]])
        self._set_param("min_length", train_val_min_length, default=DEFAULT_TRAIN_VAL_MIN_LENGTH, scope="train_val")
390 # set post-processing instructions
391 self._set_param("do_uncertainty_estimate", do_uncertainty_estimate,
392 default=DEFAULT_DO_UNCERTAINTY_ESTIMATE, scope="general.postprocessing")
393 self._set_param("block_length", uncertainty_estimate_block_length,
394 default=DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, scope="uncertainty_estimate")
395 self._set_param("evaluate_competitors", uncertainty_estimate_evaluate_competitors,
396 default=DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, scope="uncertainty_estimate")
397 self._set_param("n_boots", uncertainty_estimate_n_boots,
398 default=DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, scope="uncertainty_estimate")
399 self._set_param("do_bias_free_evaluation", do_bias_free_evaluation,
400 default=DEFAULT_DO_BIAS_FREE_EVALUATION, scope="general.postprocessing")
402 self._set_param("evaluate_feature_importance", evaluate_feature_importance,
403 default=DEFAULT_EVALUATE_FEATURE_IMPORTANCE, scope="general.postprocessing")
404 feature_importance_create_new_bootstraps = max([self.data_store.get("train_model", "general"),
405 feature_importance_create_new_bootstraps or
406 DEFAULT_FEATURE_IMPORTANCE_CREATE_NEW_BOOTSTRAPS])
407 self._set_param("create_new_bootstraps", feature_importance_create_new_bootstraps, scope="feature_importance")
408 self._set_param("n_boots", feature_importance_n_boots, default=DEFAULT_FEATURE_IMPORTANCE_N_BOOTS,
409 scope="feature_importance")
410 self._set_param("bootstrap_method", feature_importance_bootstrap_method,
411 default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, scope="feature_importance")
412 self._set_param("bootstrap_type", feature_importance_bootstrap_type,
413 default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, scope="feature_importance")
415 self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
416 if model_display_name is not None: 416 ↛ 417line 416 didn't jump to line 417, because the condition on line 416 was never true
417 self._set_param("model_display_name", model_display_name)
418 self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing

        # set competitors
        if model_display_name is not None and competitors is not None and model_display_name in competitors:
            raise IndexError(f"Given model_display_name {model_display_name} is also present in the competitors "
                             f"variable {competitors}. To ensure a proper workflow it is required to have unique "
                             f"names for each model and competitor. Please use a different model display name or "
                             f"competitor.")
        self._set_param("competitors", competitors, default=DEFAULT_COMPETITORS)
        competitor_path_default = os.path.join(self.data_store.get("data_path"), "competitors",
                                               "_".join(self.data_store.get("target_var")))
        self._set_param("competitor_path", competitor_path, default=competitor_path_default)

        # check variables, statistics and target variable
        self._check_target_var()
        self._compare_variables_and_statistics()

        # set model architecture class
        self._set_param("model_class", model, VanillaModel)

        # store starting script if provided
        if start_script is not None:
            self._store_start_script(start_script, experiment_path)

        # set remaining kwargs
        if len(kwargs) > 0:
            for k, v in kwargs.items():
                if len(self.data_store.search_name(k)) == 0:
                    self._set_param(k, v)
                else:
                    s = ", ".join([f"{k}({s})={self.data_store.get(k, scope=s)}"
                                   for s in self.data_store.search_name(k)])
                    raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a "
                                   f"conflict with an existing entry with the same name: {s}")

    def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general",
                   apply: Callable = None) -> Any:
        """Set given parameter and log in debug. Use the ``apply`` parameter to adjust the stored value (e.g. to
        transform the value to a list, use apply=helpers.to_list)."""
        if value is None and default is not None:
            value = default
        if apply is not None:
            value = apply(value)
        self.data_store.set(param, value, scope)
        logging.debug(f"set experiment attribute: {param}({scope})={value}")
        return value
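
    # Illustrative semantics of _set_param (comment sketch only, values are hypothetical):
    #   self._set_param("stations", None, default="DEBW107", apply=helpers.to_list)
    # falls back to the default because value is None, then applies helpers.to_list,
    # stores ["DEBW107"] under scope "general", and returns the stored value.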

    @staticmethod
    def _store_start_script(start_script, store_path):
        out_file = os.path.join(store_path, "start_script.txt")
        if isinstance(start_script, Callable):
            with open(out_file, "w") as fh:
                fh.write(getsource(start_script))
        if isinstance(start_script, str):
            with open(start_script, 'r') as f:
                with open(out_file, "w") as out:
                    for line in f:
                        print(line, end='', file=out)
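
    # Note (illustrative): _store_start_script accepts either a callable, whose source is
    # extracted via dill.source.getsource, or a path to a script file, which is copied
    # verbatim into <experiment_path>/start_script.txt. A hypothetical call:
    #   ExperimentSetup(..., start_script=__file__)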

    def _compare_variables_and_statistics(self):
        """
        Compare variables and statistics.

        * raise an error, if a variable is missing.
        * remove unused variables from statistics.
        """
        logging.debug("check if all variables are included in statistics_per_var")
        stat = self.data_store.get("statistics_per_var")
        var = self.data_store.get("variables")
        # too few entries, raise error
        if not set(var).issubset(stat.keys()):
            missing = set(var).difference(stat.keys())
            raise ValueError(f"Comparison of given variables and statistics_per_var shows that not all requested "
                             f"variables are part of statistics_per_var. Please also add information on the missing "
                             f"statistics for the variables: {missing}")
        # too many entries, remove unused
        target_var = helpers.to_list(self.data_store.get("target_var"))
        unused_vars = set(stat.keys()).difference(set(var).union(target_var))
        if len(unused_vars) > 0:
            logging.info(f"There are unused keys in statistics_per_var. Therefore, removing keys: {unused_vars}")
            stat_new = helpers.remove_items(stat, list(unused_vars))
            self._set_param("statistics_per_var", stat_new)
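
    # Illustration of the consistency rules above (hypothetical values):
    #   variables=["o3", "temp"], statistics_per_var={"o3": "dma8eu"}
    #     -> ValueError, because "temp" has no statistics entry
    #   variables=["o3"], target_var="o3", statistics_per_var={"o3": "dma8eu", "no2": "maximum"}
    #     -> "no2" is unused and removed from statistics_per_var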

    def _check_target_var(self):
        """Check if target variable is in statistics_per_var dictionary."""
        target_var = helpers.to_list(self.data_store.get("target_var"))
        stat = self.data_store.get("statistics_per_var")
        var = self.data_store.get("variables")
        if not set(target_var).issubset(stat.keys()):
            raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")


if __name__ == "__main__":
    formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
    logging.basicConfig(format=formatter, level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
                        help="set experiment date as string")
    parser_args = parser.parse_args()
    with RunEnvironment():
        setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])