Coverage for mlair/run_modules/pre_processing.py: 59%
307 statements
coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
1"""Pre-processing module."""
3__author__ = "Lukas Leufen, Felix Kleinert"
4__date__ = '2019-11-25'
6import logging
7import os
8import traceback
9from typing import Tuple
10import multiprocessing
11import requests
12import psutil
13import random
14import dill
16import pandas as pd
18from mlair.data_handler import DataCollection, AbstractDataHandler
19from mlair.helpers import TimeTracking, to_list, tables, remove_items
20from mlair.configuration import path_config
21from mlair.helpers.data_sources.data_loader import EmptyQueryResult
22from mlair.helpers.testing import check_nested_equality
23from mlair.run_modules.run_environment import RunEnvironment
26class PreProcessing(RunEnvironment):
27 """
28 Pre-process your data by using this class.
30 Schedule of pre-processing:
31 #. load and check valid stations (either download or load from disk)
32 #. split data into subsets (train, val, test, train & val)
33 #. create a small report on data metrics
35 Required objects [scope] from data store:
36 * all elements from `DEFAULT_ARGS_LIST` in scope preprocessing for general data loading
37 * all elements from `DEFAULT_ARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings
38 * `fraction_of_training` [.]
39 * `experiment_path` [.]
40 * `use_all_stations_on_all_data_sets` [.]
42 Optional objects
43 * all elements from `DEFAULT_KWARGS_LIST` in scope preprocessing for general data loading
44 * all elements from `DEFAULT_KWARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings
46 Sets
47 * `stations` in [., train, val, test, train_val]
48 * `generator` in [train, val, test, train_val]
49 * `transformation` [.]
51 Creates
52 * all input and output data in `data_path`
53 * latex reports in `experiment_path/latex_report`
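    A rough usage sketch (illustrative only; it assumes the usual MLAir pattern in which a preceding setup stage has
    already filled the data store with the parameters listed above)::

        with RunEnvironment():
            ExperimentSetup(stations=["DEBW013", "DEBW076"])  # assumed setup stage providing `stations` etc.
            PreProcessing()  # runs immediately on instantiation, see __init__ / _run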
55 """
57 def __init__(self):
58 """Set up and run pre-processing."""
59 super().__init__()
60 self._run()
62 def _run(self):
63 snapshot_load_path = self.data_store.get_default("snapshot_load_path", default=None)
64 if snapshot_load_path is None:  # 64 ↛ 75: line 64 didn't jump to line 75, because the condition on line 64 was never false
65 stations = self.data_store.get("stations")
66 data_handler = self.data_store.get("data_handler")
67 self._load_apriori()
68 _, valid_stations = self.validate_station(data_handler, stations,
69 "preprocessing") # , store_processed_data=False)
70 if len(valid_stations) == 0:  # 70 ↛ 71: line 70 didn't jump to line 71, because the condition on line 70 was never true
71 raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.")
72 self.data_store.set("stations", valid_stations)
73 self.split_train_val_test()
74 else:
75 self.load_snapshot(snapshot_load_path)
76 self.report_pre_processing()
77 self.prepare_competitors()
78 if self.data_store.get_default("create_snapshot", False) is True:  # 78 ↛ 79: line 78 didn't jump to line 79, because the condition on line 78 was never true
79 self.create_snapshot()
81 def report_pre_processing(self):
82 """Log some metrics on the data and create the LaTeX report."""
83 logging.debug(20 * '##')
84 n_train = len(self.data_store.get('data_collection', 'train'))
85 n_val = len(self.data_store.get('data_collection', 'val'))
86 n_test = len(self.data_store.get('data_collection', 'test'))
87 n_total = n_train + n_val + n_test
88 logging.debug(f"Number of all stations: {n_total}")
89 logging.debug(f"Number of training stations: {n_train}")
90 logging.debug(f"Number of val stations: {n_val}")
91 logging.debug(f"Number of test stations: {n_test}")
92 self.create_latex_report()
94 def create_latex_report(self):
95 """
96 Create tables with information on the station meta data and a summary of subset sample sizes.
98 * station_sample_size.md: see table below as markdown
99 * station_sample_size.tex: same as table below as latex table
100 * station_sample_size_short.tex: reduced size table without any meta data besides station ID, as latex table
102 All tables are stored in the folder latex_report inside experiment_path. The table format (e.g. which meta
103 data is highlighted) is currently hardcoded to ensure a stable table style. If further styles are needed, it is
104 better to add an additional style than to modify the existing table styles.
106 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
107 | stat. ID | station_name | station_lon | station_lat | station_alt | train | val | test |
108 +============+===========================================+===============+===============+===============+=========+=======+========+
109 | DEBW013 | Stuttgart Bad Cannstatt | 9.2297 | 48.8088 | 235 | 1434 | 712 | 1080 |
110 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
111 | DEBW076 | Baden-Baden | 8.2202 | 48.7731 | 148 | 3037 | 722 | 710 |
112 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
113 | DEBW087 | Schwäbische_Alb | 9.2076 | 48.3458 | 798 | 3044 | 714 | 1087 |
114 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
115 | DEBW107 | Tübingen | 9.0512 | 48.5077 | 325 | 1803 | 715 | 1087 |
116 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
117 | DEBY081 | Garmisch-Partenkirchen/Kreuzeckbahnstraße | 11.0631 | 47.4764 | 735 | 2935 | 525 | 714 |
118 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
119 | # Stations | nan | nan | nan | nan | 6 | 6 | 6 |
120 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
121 | # Samples | nan | nan | nan | nan | 12253 | 3388 | 4678 |
122 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
124 """
125 meta_cols = ["name", "lat", "lon", "alt", "country", "state", "type", "type_of_area", "toar1_category"]
126 meta_round = ["lat", "lon", "alt"]
127 precision = 4
128 path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
129 path_config.check_path_and_create(path)
130 names_of_set = ["train", "val", "test"]
131 df = self.create_info_df(meta_cols, meta_round, names_of_set, precision)
132 column_format = tables.create_column_format_for_tex(df)
133 tables.save_to_tex(path=path, filename="station_sample_size.tex", column_format=column_format, df=df)
134 tables.save_to_md(path=path, filename="station_sample_size.md", df=df)
135 df_nometa = df.drop(meta_cols, axis=1)
136 column_format = tables.create_column_format_for_tex(df_nometa)  # format must match the reduced frame
137 tables.save_to_tex(path=path, filename="station_sample_size_short.tex", column_format=column_format,
138 df=df_nometa)
139 tables.save_to_md(path=path, filename="station_sample_size_short.md", df=df_nometa)
140 df_descr = self.create_describe_df(df_nometa)
141 column_format = tables.create_column_format_for_tex(df_descr)
142 tables.save_to_tex(path=path, filename="station_describe_short.tex", column_format=column_format, df=df_descr)
143 tables.save_to_md(path=path, filename="station_describe_short.md", df=df_descr)
145 @staticmethod
146 def create_describe_df(df, percentiles=None, ignore_last_lines: int = 2):
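        """Create a table of descriptive statistics (count and percentiles) over the numeric subset columns,
        ignoring the last ``ignore_last_lines`` rows (the '# Stations' and '# Samples' totals), and prepend the
        total '# Samples' row (renamed 'no. samples')."""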
147 if percentiles is None:  # 147 ↛ 149: line 147 didn't jump to line 149, because the condition on line 147 was never false
148 percentiles = [.05, .1, .25, .5, .75, .9, .95]
149 df_descr = df.iloc[:-ignore_last_lines].astype('float32').describe(
150 percentiles=percentiles).astype("int32", errors="ignore")
151 df_descr = pd.concat([df.loc[['# Samples']], df_descr]).T
152 df_descr.rename(columns={"# Samples": "no. samples", "count": "no. stations"}, inplace=True)
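        # swap the first two columns so that 'no. stations' appears before 'no. samples'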
153 df_descr_colnames = list(df_descr.columns)
154 df_descr_colnames = [df_descr_colnames[1]] + [df_descr_colnames[0]] + df_descr_colnames[2:]
155 df_descr = df_descr[df_descr_colnames]
156 return df_descr
158 def create_info_df(self, meta_cols, meta_round, names_of_set, precision):
159 use_multiprocessing = self.data_store.get("use_multiprocessing")
160 max_process = self.data_store.get("max_number_multiprocessing")
161 df = pd.DataFrame(columns=meta_cols + names_of_set)
162 for set_name in names_of_set:
163 data = self.data_store.get("data_collection", set_name)
164 n_process = min([psutil.cpu_count(logical=False), len(data), max_process]) # use only physical cpus
165 if n_process > 1 and use_multiprocessing is True:  # parallel solution; 165 ↛ 166: line 165 didn't jump to line 166, because the condition on line 165 was never true
166 logging.info(f"use parallel create_info_df ({set_name})")
167 pool = multiprocessing.Pool(n_process)
168 logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
169 output = [pool.apply_async(f_proc_create_info_df, args=(station, meta_cols)) for station in data]
170 for i, p in enumerate(output):
171 res = p.get()
172 station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
173 df.loc[station_name, set_name] = shape
174 if df.loc[station_name, meta_cols].isnull().any():
175 df.loc[station_name, meta_cols] = meta
176 logging.info(f"...finished: {station_name} ({int((i + 1.) / len(output) * 100)}%)")
177 pool.close()
178 pool.join()
179 else: # serial solution
180 logging.info(f"use serial create_info_df ({set_name})")
181 for station in data:
182 res = f_proc_create_info_df(station, meta_cols)
183 station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
184 df.loc[station_name, set_name] = shape
185 if df.loc[station_name, meta_cols].isnull().any():
186 df.loc[station_name, meta_cols] = meta
187 df.loc["# Samples", set_name] = df.loc[:, set_name].sum()
188 assert len(data) == df.loc[:, set_name].count() - 1
189 df.loc["# Stations", set_name] = len(data)
190 df[meta_round] = df[meta_round].astype(float).round(precision)
191 df.sort_index(inplace=True)
192 df = df.reindex(df.index.drop(["# Stations", "# Samples"]).to_list() + ["# Stations", "# Samples"], )
193 df.index.name = 'stat. ID'
194 return df
196 def split_train_val_test(self) -> None:
197 """
198 Split data into subsets.
200 Currently: train, val, test and train_val (the latter is only the merge of train and val, but kept as a separate
201 data_collection). IMPORTANT: Do not change the order of execution of create_set_split. The train subset must
202 always be processed first to set a proper transformation.
203 """
204 fraction_of_training = self.data_store.get("fraction_of_training")
205 stations = self.data_store.get("stations")
206 train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations),
207 fraction_of_training)
208 subset_names = ["train", "val", "test", "train_val"]
209 if subset_names[0] != "train": # pragma: no cover
210 raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset"
211 f"order was: {subset_names}.")
212 for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names):
213 self.create_set_split(ind, scope)
215 @staticmethod
216 def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]:
217 """
218 Create the training, validation and test subset slice indices for given total_length.
220 The test data consists of (1 - fraction) of total_length (fraction*len:end). Train and validation data therefore
221 are made from the remaining fraction of total_length (0:fraction*len). This train/val part is further split by the
222 factor 0.8 for train and 0.2 for validation. In addition, split_set_indices also returns the combination of the
223 training and validation subsets.
225 :param total_length: total number of objects to split
226 :param fraction: fraction of total_length used for the union of train and val data; the remainder forms the test data
228 :return: slices for each subset in the order: train, val, test, train_val
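        For illustration (following the computation below): with ``total_length=10`` and ``fraction=0.8``,
        ``pos_test_split = int(10 * 0.8) = 8``, so the returned slices are ``slice(0, 6)`` (train),
        ``slice(6, 8)`` (val), ``slice(8, 10)`` (test) and ``slice(0, 8)`` (train_val).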
229 """
230 pos_test_split = int(total_length * fraction)
231 train_index = slice(0, int(pos_test_split * 0.8))
232 val_index = slice(int(pos_test_split * 0.8), pos_test_split)
233 test_index = slice(pos_test_split, total_length)
234 train_val_index = slice(0, pos_test_split)
235 return train_index, val_index, test_index, train_val_index
237 def create_set_split(self, index_list: slice, set_name: str) -> None:
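        """Select the stations belonging to the given subset (or all stations if ``use_all_stations_on_all_data_sets``
        is set), validate them, and store the resulting station list and data collection under the subset scope."""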
238 # get set stations
239 stations = self.data_store.get("stations", scope=set_name)
240 if self.data_store.get("use_all_stations_on_all_data_sets"):
241 set_stations = stations
242 else:
243 set_stations = stations[index_list]
244 logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
245 # create set data_collection and store
246 data_handler = self.data_store.get("data_handler")
247 collection, valid_stations = self.validate_station(data_handler, set_stations, set_name)
248 self.data_store.set("stations", valid_stations, scope=set_name)
249 self.data_store.set("data_collection", collection, scope=set_name)
251 def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None,
252 store_processed_data=True):
253 """
254 Check if all given stations in `set_stations` are valid.
256 Valid means that data is available for the given time range (which is included in `kwargs`). The shape and the
257 loading time are logged in debug mode.
259 :return: DataCollection of all valid stations and the corrected list containing only valid station IDs.
260 """
261 t_outer = TimeTracking()
262 logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}")
263 # calculate transformation using train data
264 if set_name == "train":
265 logging.info("setup transformation using train data exclusively")
266 self.transformation(data_handler, set_stations)
267 # start station check
268 collection = DataCollection(name=set_name)
269 valid_stations = []
270 kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope=set_name)
271 use_multiprocessing = self.data_store.get("use_multiprocessing")
272 tmp_path = self.data_store.get("tmp_path")
274 max_process = self.data_store.get("max_number_multiprocessing")
275 n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus
276 if n_process > 1 and use_multiprocessing is True:  # parallel solution; 276 ↛ 277: line 276 didn't jump to line 277, because the condition on line 276 was never true
277 logging.info("use parallel validate station approach")
278 pool = multiprocessing.Pool(n_process)
279 logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
280 kwargs.update({"tmp_path": tmp_path, "return_strategy": "reference"})
281 output = [
282 pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs)
283 for station in set_stations]
284 for i, p in enumerate(output):
285 _res_file, s = p.get()
286 logging.info(f"...finished: {s} ({int((i + 1.) / len(output) * 100)}%)")
287 with open(_res_file, "rb") as f:
288 dh = dill.load(f)
289 os.remove(_res_file)
290 if dh is not None:
291 collection.add(dh)
292 valid_stations.append(s)
293 pool.close()
294 pool.join()
295 else: # serial solution
296 logging.info("use serial validate station approach")
297 kwargs.update({"return_strategy": "result"})
298 for station in set_stations:
299 dh, s = f_proc(data_handler, station, set_name, store_processed_data, **kwargs)
300 if dh is not None:
301 collection.add(dh)
302 valid_stations.append(s)
304 logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
305 f"{len(set_stations)} valid stations ({set_name}).")
306 if set_name == "train":
307 self.store_data_handler_attributes(data_handler, collection)
308 return collection, valid_stations
310 def store_data_handler_attributes(self, data_handler, collection):
311 store_attributes = data_handler.store_attributes()
312 if len(store_attributes) > 0:  # 312 ↛ 313: line 312 didn't jump to line 313, because the condition on line 312 was never true
313 logging.info(f"store following parameters ({len(store_attributes)}) requested by the data handler: "
314 f"{','.join(store_attributes)}")
315 attrs = {}
316 for dh in collection:
317 station = str(dh)
318 for k, v in dh.get_store_attributes().items():
319 attrs[k] = dict(attrs.get(k, {}), **{station: v})
320 for k, v in attrs.items():
321 self.data_store.set(k, v)
322 self._store_apriori()
324 def _store_apriori(self):
325 apriori = self.data_store.get_default("apriori", default=None)
326 if apriori:  # 326 ↛ 327: line 326 didn't jump to line 327, because the condition on line 326 was never true
327 experiment_path = self.data_store.get("experiment_path")
328 path = os.path.join(experiment_path, "data", "apriori")
329 store_file = os.path.join(path, "apriori.pickle")
330 if not os.path.exists(path):
331 path_config.check_path_and_create(path)
332 with open(store_file, "wb") as f:
333 dill.dump(apriori, f, protocol=4)
334 logging.debug(f"Store apriori options locally for later use at: {store_file}")
336 def _load_apriori(self):
337 if self.data_store.get_default("apriori", default=None) is None:  # 337 ↛ exit: line 337 didn't return from function '_load_apriori', because the condition on line 337 was never false
338 apriori_file = self.data_store.get_default("apriori_file", None)
339 if apriori_file is not None:  # 339 ↛ 340: line 339 didn't jump to line 340, because the condition on line 339 was never true
340 if os.path.exists(apriori_file):
341 logging.info(f"use apriori data from given file: {apriori_file}")
342 with open(apriori_file, "rb") as pickle_file:
343 self.data_store.set("apriori", dill.load(pickle_file))
344 else:
345 logging.info(f"cannot load apriori file: {apriori_file}. Use fresh calculation from data.")
347 def transformation(self, data_handler: AbstractDataHandler, stations):
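        """Set up the transformation (if the data handler supports one): either calculate the transformation
        parameters freshly on the given (train) stations or, if ``calculate_fresh_transformation`` is disabled, load
        previously stored options; the result is written to the data store and stored on disk for later runs."""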
348 calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True)
349 if hasattr(data_handler, "transformation"):
350 transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation()
351 if transformation_opts is None:  # 351 ↛ 357: line 351 didn't jump to line 357, because the condition on line 351 was never false
352 logging.info(f"start to calculate transformation parameters.")
353 kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope="train")
354 tmp_path = self.data_store.get_default("tmp_path", default=None)
355 transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs)
356 else:
357 logging.info("In case no valid train data could be found due to problems with transformation, please "
358 "check your provided transformation file for compatibility with your data.")
359 self.data_store.set("transformation", transformation_opts)
360 if transformation_opts is not None:
361 self._store_transformation(transformation_opts)
363 def _load_transformation(self):
364 """Try to load transformation options from file if transformation_file is provided."""
365 transformation_file = self.data_store.get_default("transformation_file", None)
366 if transformation_file is not None:
367 if os.path.exists(transformation_file):
368 logging.info(f"use transformation from given transformation file: {transformation_file}")
369 with open(transformation_file, "rb") as pickle_file:
370 return dill.load(pickle_file)
371 else:
372 logging.info(f"cannot load transformation file: {transformation_file}. Use fresh calculation of "
373 f"transformation from train data.")
375 def _store_transformation(self, transformation_opts):
376 """Store transformation options locally inside experiment_path if they do not already exist."""
377 experiment_path = self.data_store.get("experiment_path")
378 transformation_path = os.path.join(experiment_path, "data", "transformation")
379 transformation_file = os.path.join(transformation_path, "transformation.pickle")
380 calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True)
381 if not os.path.exists(transformation_file) or calculate_fresh_transformation:  # 381 ↛ exit: line 381 didn't return from function '_store_transformation', because the condition on line 381 was never false
382 path_config.check_path_and_create(transformation_path)
383 with open(transformation_file, "wb") as f:
384 dill.dump(transformation_opts, f, protocol=4)
385 logging.info(f"Store transformation options locally for later use at: {transformation_file}")
387 def prepare_competitors(self):
388 """
389 Prepare competitor models already in the preprocessing stage. This is performed here because some models might
390 need internet access, which, depending on the operating system, might not be possible during postprocessing.
391 Currently, this method only checks whether the IntelliO3-ts-v1 model or CAMS forecasts are requested as
392 competitors and downloads the data if required.
393 """
394 logging.info("Searching for competitors to be prepared for use.")
395 competitors = to_list(self.data_store.get_default("competitors", default=[]))
396 if len(competitors) > 0:  # 396 ↛ 427: line 396 didn't jump to line 427, because the condition on line 396 was never false
397 for competitor_name in competitors:
398 if competitor_name.lower() == "IntelliO3-ts-v1".lower():  # 398 ↛ 399: line 398 didn't jump to line 399, because the condition on line 398 was never true
399 logging.info("Prepare IntelliO3-ts-v1 model")
400 from mlair.reference_models.reference_model_intellio3_v1 import IntelliO3_ts_v1
401 path = os.path.join(self.data_store.get("competitor_path"), competitor_name)
402 IntelliO3_ts_v1("IntelliO3-ts-v1", ref_store_path=path).make_reference_available_locally(remove_tmp_dir=False)
403 elif competitor_name.lower() == "CAMS".lower():  # 403 ↛ 404: line 403 didn't jump to line 404, because the condition on line 403 was never true
404 logging.info("Prepare CAMS forecasts")
405 from mlair.reference_models.reference_model_cams import CAMSforecast
406 interp_method = self.data_store.get_default("cams_interp_method", default=None)
407 data_path = self.data_store.get_default("cams_data_path", default=None)
408 path = os.path.join(self.data_store.get("competitor_path"), competitor_name)
409 stations = {}
410 for subset in ["train", "val", "test"]:
411 data_collection = self.data_store.get("data_collection", subset)
412 stations.update({str(s): s.get_coordinates() for s in data_collection if s not in stations})
413 if interp_method is None:
414 CAMSforecast("CAMS", ref_store_path=path, data_path=data_path, interp_method=None
415 ).make_reference_available_locally(stations)
416 else:
417 competitors = remove_items(competitors, "CAMS")
418 for method in to_list(interp_method):
419 CAMSforecast(f"CAMS{method}", ref_store_path=path + method, data_path=data_path,
420 interp_method=method).make_reference_available_locally(stations)
421 competitors.append(f"CAMS{method}")
422 self.data_store.set("competitors", competitors)
423 else:
424 logging.info(f"No preparation required for competitor {competitor_name} as no specific instruction "
425 f"is provided.")
426 else:
427 logging.info("No preparation required because no competitor was provided to the workflow.")
429 def create_snapshot(self):
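        """Dump the current data store as a pickle snapshot under a randomly chosen (animal) name inside
        ``snapshot_path``, trying up to ten names to find one that is not already used."""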
430 logging.info("create snapshot for preprocessing")
431 from mlair.configuration.snapshot_names import animals
432 for i_try in range(10):
433 snapshot_name = random.choice(animals).lower()
434 snapshot_path = os.path.abspath(self.data_store.get("snapshot_path"))
435 path_config.check_path_and_create(snapshot_path, remove_existing=False)
436 _snapshot_file = os.path.join(snapshot_path, f"snapshot_preprocessing_{snapshot_name}.pickle")
437 if not os.path.exists(_snapshot_file):
438 logging.info(f"store snapshot at: {_snapshot_file}")
439 with open(_snapshot_file, "wb") as f:
440 dill.dump(self.data_store, f, protocol=4)
441 print(_snapshot_file)
442 return
443 logging.info(f"Could not create snapshot at {_snapshot_file} because the file already exists ({i_try + 1}/10)")
444 logging.info(f"Could not create any snapshot after 10/10 tries.")
446 def load_snapshot(self, file):
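        """Load a pickled data store snapshot and, if it is consistent with the current experiment setup (ignoring
        the excluded parameters below), transfer transformation, data_collection and stations from it; otherwise
        raise an error and abort the run."""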
447 logging.info(f"load snapshot for preprocessing from {file}")
448 with open(file, "rb") as f:
449 snapshot = dill.load(f)
450 excluded_params = ["activation", "activation_output", "add_dense_layer", "batch_normalization", "batch_path",
451 "batch_size", "block_length", "bootstrap_method", "bootstrap_path", "bootstrap_type",
452 "competitor_path", "competitors", "create_new_bootstraps", "create_new_model",
453 "create_snapshot", "data_collection", "debug_mode", "dense_layer_configuration",
454 "do_uncertainty_estimate", "dropout", "dropout_rnn", "early_stopping_epochs", "epochs",
455 "evaluate_competitors", "evaluate_feature_importance", "experiment_name", "experiment_path",
456 "exponent_last_layer", "forecast_path", "fraction_of_training", "hostname", "hpc_hosts",
457 "kernel_regularizer", "kernel_size", "layer_configuration", "log_level_stream",
458 "logging_path", "login_nodes", "loss_type", "loss_weights", "max_number_multiprocessing",
459 "model_class", "model_display_name", "model_path", "n_boots", "n_hidden", "n_layer",
460 "neighbors", "plot_list", "plot_path", "regularizer", "restore_best_model_weights",
461 "snapshot_load_path", "snapshot_path", "stations", "tmp_path", "train_model",
462 "transformation", "use_multiprocessing", "cams_data_path", "cams_interp_method",
463 "do_bias_free_evaluation", "apriori_file", "model_path", "model_load_path"]
464 data_handler = self.data_store.get("data_handler")
465 model_class = self.data_store.get("model_class")
466 excluded_params = list(set(excluded_params + data_handler.store_attributes() + model_class.requirements()))
468 if check_nested_equality(self.data_store._store, snapshot._store, skip_args=excluded_params) is True:
469 self.update_datastore(snapshot, excluded_params=remove_items(excluded_params, ["transformation",
470 "data_collection",
471 "stations"]))
472 else:
473 raise ReferenceError("provided snapshot does not match with the current experiment setup. Abort this run!")
476def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_path=None, **kwargs):
477 """
478 Try to create a data handler for given arguments. If build fails, this station does not fulfil all requirements and
479 therefore f_proc will return None as indication. On a successful build, f_proc returns the built data handler and
480 the station that was used. This function must be implemented globally to work together with multiprocessing.
481 """
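    # Note on the two return strategies (description of the code below; the call is shown only for illustration):
    #   * return_strategy="result"    -> the built data handler is returned directly, e.g.
    #         dh, station = f_proc(data_handler, "DEBW013", "train", True, return_strategy="result", **kwargs)
    #   * return_strategy="reference" -> the result is dumped to a temporary pickle file and only the file path is
    #     returned, which keeps the payload small when the call is dispatched via multiprocessing.Pool.apply_async.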
482 assert return_strategy in ["result", "reference"]
483 try:
484 res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, **kwargs)
485 except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError) as e:
486 formatted_lines = traceback.format_exc().splitlines()
487 logging.info(f"remove station {station} because it raised an error: {e} -> "
488 f"{' | '.join(f_inspect_error(formatted_lines))}")
489 logging.debug(f"detailed information for removal of station {station}: {traceback.format_exc()}")
490 res = None
491 if return_strategy == "result":  # 491 ↛ 494: line 491 didn't jump to line 494, because the condition on line 491 was never false
492 return res, station
493 else:
494 if tmp_path is None:
495 tmp_path = os.getcwd()
496 _tmp_file = os.path.join(tmp_path, f"{station}_{'%032x' % random.getrandbits(128)}.pickle")
497 with open(_tmp_file, "wb") as f:
498 dill.dump(res, f, protocol=4)
499 return _tmp_file, station
502def f_proc_create_info_df(data, meta_cols):
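    """Return the station name, the number of target samples (first dimension of the first Y element) and the
    reindexed meta information for a single data handler; used by create_info_df in both the serial and the
    parallel branch."""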
503 station_name = str(data.id_class)
504 meta = data.id_class.meta
505 res = {"station_name": station_name, "Y_shape": data.get_Y()[0].shape[0],
506 "meta": meta.reindex(meta_cols).values.flatten()}
507 return res
510def f_inspect_error(formatted):
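    """Return a short excerpt of the formatted traceback: the three lines preceding the last entry that does not
    point into the mlair package itself, used to keep the station-removal log message compact."""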
511 for i in range(len(formatted) - 1, -1, -1):  # 511 ↛ 514: line 511 didn't jump to line 514, because the loop on line 511 didn't complete
512 if "mlair/mlair" not in formatted[i]:  # 512 ↛ 511: line 512 didn't jump to line 511, because the condition on line 512 was never false
513 return formatted[i - 3:i]
514 return formatted[-3:]  # fallback: return the last traceback lines