Coverage for mlair/data_handler/default_data_handler.py: 69%
254 statements
coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
__author__ = 'Lukas Leufen'
__date__ = '2020-09-21'

import copy
import inspect
import gc
import logging
import os
import pickle
import random
import dill
import shutil
from functools import reduce
from typing import Tuple, Union, List
import multiprocessing
import psutil
import dask

import numpy as np
import xarray as xr

from mlair.data_handler.abstract_data_handler import AbstractDataHandler
from mlair.helpers import remove_items, to_list, TimeTrackingWrapper
from mlair.helpers.data_sources.toar_data import EmptyQueryResult

number = Union[float, int]
num_or_list = Union[number, List[number]]

class DefaultDataHandler(AbstractDataHandler):
    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler
    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation

    _requirements = data_handler.requirements()
    _store_attributes = data_handler.store_attributes()
    _skip_args = AbstractDataHandler._skip_args + ["id_class"]

    DEFAULT_ITER_DIM = "Stations"
    DEFAULT_TIME_DIM = "datetime"
    MAX_NUMBER_MULTIPROCESSING = 16

    def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0,
                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None,
                 store_processed_data=True, iter_dim=DEFAULT_ITER_DIM, time_dim=DEFAULT_TIME_DIM,
                 use_multiprocessing=True, max_number_multiprocessing=MAX_NUMBER_MULTIPROCESSING):
        super().__init__()
        self.id_class = id_class
        self.time_dim = time_dim
        self.iter_dim = iter_dim
        self.min_length = min_length
        self._X = None
        self._Y = None
        self._X_extreme = None
        self._Y_extreme = None
        self._data_intersection = None
        self._len = None
        self._len_upsampling = None
        self._use_multiprocessing = use_multiprocessing
        self._max_number_multiprocessing = max_number_multiprocessing
        _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
        self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle")
        self._collection = self._create_collection()
        self.harmonise_X()
        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.time_dim)
        self._store(fresh_store=True, store_processed_data=store_processed_data)

    @classmethod
    def build(cls, station: str, **kwargs):
        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
        sp = cls.data_handler(station, **sp_keys)
        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, **dp_args)

    def _create_collection(self):
        return [self.id_class]

    def _reset_data(self):
        self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
        gc.collect()

    def _cleanup(self):
        directory = os.path.dirname(self._save_file)
        if os.path.exists(directory) is False:
            os.makedirs(directory, exist_ok=True)
        if os.path.exists(self._save_file):
            shutil.rmtree(self._save_file, ignore_errors=True)

    def _store(self, fresh_store=False, store_processed_data=True):
        if store_processed_data is True:
            self._cleanup() if fresh_store is True else None
            data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
            data = self._force_dask_computation(data)
            with open(self._save_file, "wb") as f:
                dill.dump(data, f, protocol=4)
            logging.debug(f"save pickle data to {self._save_file}")
            self._reset_data()

    def get_store_attributes(self):
        attr_dict = {}
        for attr in self.store_attributes():
            try:
                val = self.__getattribute__(attr)
            except AttributeError:
                val = self.id_class.__getattribute__(attr)
            attr_dict[attr] = val
        return attr_dict

    @staticmethod
    def _force_dask_computation(data):
        try:
            data = dask.compute(data)[0]
        except Exception:  # data is not a dask collection or computation failed, keep data unchanged
            pass
        return data

    def _load(self):
        try:
            with open(self._save_file, "rb") as f:
                data = dill.load(f)
            logging.debug(f"load pickle data from {self._save_file}")
            self._X, self._Y = data["X"], data["Y"]
            self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
        except FileNotFoundError:
            pass

    def get_data(self, upsampling=False, as_numpy=True):
        self._load()
        as_numpy_X, as_numpy_Y = as_numpy if isinstance(as_numpy, tuple) else (as_numpy, as_numpy)
        X = self.get_X(upsampling, as_numpy_X)
        Y = self.get_Y(upsampling, as_numpy_Y)
        self._reset_data()
        return X, Y

    def __repr__(self):
        return str(self._collection[0])

    def __len__(self, upsampling=False):
        if upsampling is False:
            return self._len
        else:
            return self._len_upsampling

    def get_X_original(self):
        X = []
        for data in self._collection:
            X.append(data.get_X())
        return X

    def get_Y_original(self):
        Y = self._collection[0].get_Y()
        return Y

    @staticmethod
    def _to_numpy(d):
        return list(map(lambda x: np.copy(x), d))

    def get_X(self, upsampling=False, as_numpy=True):
        no_data = (self._X is None)
        self._load() if no_data is True else None
        X = self._X if upsampling is False else self._X_extreme
        self._reset_data() if no_data is True else None
        return self._to_numpy(X) if as_numpy is True else X

    def get_Y(self, upsampling=False, as_numpy=True):
        no_data = (self._Y is None)
        self._load() if no_data is True else None
        Y = self._Y if upsampling is False else self._Y_extreme
        self._reset_data() if no_data is True else None
        return self._to_numpy([Y]) if as_numpy is True else Y

    @TimeTrackingWrapper
    def harmonise_X(self):
        X_original, Y_original = self.get_X_original(), self.get_Y_original()
        dim = self.time_dim
        intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
        if len(intersect) < max(self.min_length, 1):
            raise ValueError("There is no intersection of X.")
        else:
            X = list(map(lambda x: x.sel({dim: intersect}), X_original))
            Y = Y_original.sel({dim: intersect})
        self._data_intersection = intersect
        self._X, self._Y = X, Y
        self._len = len(self._data_intersection)

    def get_observation(self):
        dim = self.time_dim
        if self._data_intersection is not None:
            return self.id_class.observation.sel({dim: self._data_intersection}).copy().squeeze()
        else:
            return self.id_class.observation.copy().squeeze()

    def apply_transformation(self, data, base="target", dim=0, inverse=False):
        return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse)

    @TimeTrackingWrapper
    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                          timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):
        """
        Multiply extremes.

        This method extracts extreme values from self.labels as defined by the argument extreme_values. One can also
        decide to only extract extremes on the right tail of the distribution. When extreme_values is a list of
        floats/ints, all values larger than extreme_values (and smaller than -extreme_values; extraction is performed
        in standardised space) are extracted iteratively. If for example extreme_values = [1., 2.], a value of 1.5 is
        extracted once (for the 0th entry in the list), while a 2.5 is extracted twice (once for each entry).
        Timedelta is used to mark the extracted values by adding one minute to each timestamp. As TOAR data are
        hourly, these "artificial" data points can easily be identified later. Extreme inputs and labels are stored
        in self.extremes_history and self.extreme_labels, respectively.

        :param extreme_values: user definition of extreme
        :param extremes_on_right_tail_only: if False, also multiply values that are smaller than -extreme_values;
            if True, only extract values larger than extreme_values
        :param timedelta: used as arguments for np.timedelta64 in order to mark extreme values on datetime
        """
        if extreme_values is None:
            logging.debug("No extreme values given, skip multiply extremes")
            self._X_extreme, self._Y_extreme = self._X, self._Y
            self._len_upsampling = self._len
            return

        # check type of inputs
        extreme_values = to_list(extreme_values)
        for i in extreme_values:
            if not isinstance(i, number.__args__):
                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
                                f"{i} is type {type(i)}")

        extremes_X, extremes_Y = None, None
        for extr_val in sorted(extreme_values):
            # check if some extreme values are already extracted
            if (extremes_X is None) or (extremes_Y is None):
                X, Y = self._X, self._Y
                extremes_X, extremes_Y = X, Y
            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
                X, Y = self._X_extreme, self._Y_extreme

            # extract extremes based on occurrence in labels
            other_dims = remove_items(list(Y.dims), dim)
            if extremes_on_right_tail_only:
                extreme_idx = (extremes_Y > extr_val).any(dim=other_dims)
            else:
                extreme_idx = xr.concat([(extremes_Y < -extr_val).any(dim=other_dims[0]),
                                         (extremes_Y > extr_val).any(dim=other_dims[0])],
                                        dim=other_dims[0]).any(dim=other_dims[0])

            sel = extreme_idx[extreme_idx].coords[dim].values
            extremes_X = list(map(lambda x: x.sel(**{dim: sel}), extremes_X))
            self._add_timedelta(extremes_X, dim, timedelta)
            extremes_Y = extremes_Y.sel(**{dim: extreme_idx})
            self._add_timedelta([extremes_Y], dim, timedelta)

            self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
            self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
            self._len_upsampling = len(self._X_extreme[0].coords[dim])

    @staticmethod
    def _add_timedelta(data, dim, timedelta):
        for d in data:
            d.coords[dim] = d.coords[dim].values + np.timedelta64(*timedelta)

    @classmethod
    def transformation(cls, set_stations, tmp_path=None, dh_transformation=None, **kwargs):
        """
        ### supported transformation methods

        Currently supported methods are:

        * standardise (default, if method is not given)
        * centre
        * min_max
        * log

        ### mean and std estimation

        Mean and std (depending on method) are estimated. For each station, mean and std are calculated and
        afterwards aggregated using the mean value over all station-wise metrics. This method is not exactly
        accurate, especially regarding the std calculation, but is therefore much faster. Furthermore, it is a
        weighted mean, weighted by the time series length / number of data points: a longer time series has more
        influence on the transformation settings than a short time series. The estimation of the std is less
        accurate, because the unweighted mean of all stds is not equal to the true std, but the mean of all
        station-wise stds is still a decent estimate. Finally, the real accuracy of mean and std is less important,
        because it is "just" a transformation / scaling.

        ### mean and std given

        If mean and std are not None, the default data handler expects these parameters to match the data and
        applies these values to the data. Make sure that all dimensions and/or coordinates are in agreement.

        ### min and max given

        If min and max are not None, the default data handler expects these parameters to match the data and
        applies these values to the data. Make sure that all dimensions and/or coordinates are in agreement.
        """
        if dh_transformation is None:
            dh_transformation = cls.data_handler_transformation
        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in dh_transformation.requirements() if k in kwargs}
        if "transformation" not in sp_keys.keys():
            return
        transformation_dict = ({}, {})

        max_process = kwargs.get("max_number_multiprocessing", 16)
        set_stations = to_list(set_stations)
        n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
        if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
            logging.info("use parallel transformation approach")
            pool = multiprocessing.Pool(n_process)  # use only physical cpus
            logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
            sp_keys.update({"tmp_path": tmp_path, "return_strategy": "reference"})
            output = [
                pool.apply_async(f_proc, args=(dh_transformation, station), kwds=sp_keys)
                for station in set_stations]
            for p in output:
                _res_file, s = p.get()
                with open(_res_file, "rb") as f:
                    dh = dill.load(f)
                os.remove(_res_file)
                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)
            pool.close()
            pool.join()
        else:  # serial solution
            logging.info("use serial transformation approach")
            sp_keys.update({"return_strategy": "result"})
            for station in set_stations:
                dh, s = f_proc(dh_transformation, station, **sp_keys)
                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)

        # aggregate all information
        iter_dim = sp_keys.get("iter_dim", cls.DEFAULT_ITER_DIM)
        transformation_dict = cls.aggregate_transformation(transformation_dict, iter_dim)
        return transformation_dict

    @classmethod
    def aggregate_transformation(cls, transformation_dict, iter_dim):
        pop_list = []
        for i, transformation in enumerate(transformation_dict):
            for k in transformation.keys():
                try:
                    if transformation[k]["mean"] is not None:
                        transformation_dict[i][k]["mean"] = transformation[k]["mean"].mean(iter_dim)
                    if transformation[k]["std"] is not None:
                        transformation_dict[i][k]["std"] = transformation[k]["std"].mean(iter_dim)
                    if transformation[k]["min"] is not None:
                        transformation_dict[i][k]["min"] = transformation[k]["min"].min(iter_dim)
                    if transformation[k]["max"] is not None:
                        transformation_dict[i][k]["max"] = transformation[k]["max"].max(iter_dim)
                    if "feature_range" in transformation[k].keys():
                        transformation_dict[i][k]["feature_range"] = transformation[k]["feature_range"]
                except KeyError:
                    pop_list.append((i, k))
        for (i, k) in pop_list:
            transformation_dict[i].pop(k)
        return transformation_dict

    @classmethod
    def update_transformation_dict(cls, dh, transformation_dict):
        """Inner method that is performed in both the serial and the parallel approach."""
        if dh is not None:
            for i, transformation in enumerate(dh._transformation):
                for var in transformation.keys():
                    if var not in transformation_dict[i].keys():
                        transformation_dict[i][var] = {}
                    opts = transformation[var]
                    if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
                        # data handlers with filters are allowed to change transformation method to standardise
                        assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
                    transformation_dict[i][var]["method"] = opts["method"]
                    for k in ["mean", "std", "min", "max"]:
                        old = transformation_dict[i][var].get(k, None)
                        new = opts.get(k)
                        transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
                    if "feature_range" in opts.keys():
                        transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None)
        return transformation_dict

    def get_coordinates(self):
        return self.id_class.get_coordinates()


def f_proc(data_handler, station, return_strategy="", tmp_path=None, **sp_keys):
    """
    Try to create a data handler for the given arguments. If the build fails, the station does not fulfil all
    requirements and f_proc returns None as indication. On a successful build, f_proc returns the built data handler
    and the station that was used. This function must be implemented globally to work together with multiprocessing.
    """
    assert return_strategy in ["result", "reference"]
    try:
        res = data_handler(station, **sp_keys)
    except (AttributeError, EmptyQueryResult, KeyError, ValueError, IndexError) as e:
        logging.info(f"remove station {station} because it raised an error: {e}")
        res = None
    if return_strategy == "result":
        return res, station
    else:
        _tmp_file = os.path.join(tmp_path, f"{station}_{'%032x' % random.getrandbits(128)}.pickle")
        with open(_tmp_file, "wb") as f:
            dill.dump(res, f, protocol=4)
        return _tmp_file, station
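

# Usage sketch (illustrative only; the station id, experiment path, and the **single_station_kwargs
# placeholder are hypothetical and must be replaced by the actual requirements of DataHandlerSingleStation):
#
#     dh = DefaultDataHandler.build("DEBW107", experiment_path="/tmp/experiment",
#                                   extreme_values=[1., 2.], store_processed_data=False,
#                                   **single_station_kwargs)
#     X, Y = dh.get_data(upsampling=True, as_numpy=True)   # upsampled inputs/targets as numpy arrays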