Coverage for mlair/data_handler/data_handler_mixed_sampling.py: 28%
282 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
1__author__ = 'Lukas Leufen'
2__date__ = '2020-11-05'
4from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
5from mlair.data_handler.data_handler_with_filter import DataHandlerFirFilterSingleStation, \
6 DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation
7from mlair.data_handler.data_handler_with_filter import DataHandlerClimateFirFilter, DataHandlerFirFilter
8from mlair.data_handler import DefaultDataHandler
9from mlair import helpers
10from mlair.helpers import to_list, sort_like
11from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD
12from mlair.helpers.filter import filter_width_kzf
14import copy
15import datetime as dt
16from typing import Any
17from functools import partial
19import pandas as pd
20import xarray as xr
class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
    """Single-station data handler that allows different sampling settings for inputs and targets."""

    def __init__(self, *args, **kwargs):
        """
        This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a 2D tuple
        for input and target data. If one of these kwargs is only a single argument, it will be applied to inputs and
        targets with this value. If one of these kwargs is a 2-dim tuple, the first element is applied to inputs and the
        second to targets respectively. If one of these kwargs is not provided, it is filled up with the same default
        value for inputs and targets.
        """
        self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs)
        self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs)
        self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs)
        super().__init__(*args, **kwargs)

    @staticmethod
    def update_kwargs(parameter_name: str, default: Any, kwargs: dict):
        """
        Update a single element of kwargs inplace to be usable for inputs and targets.

        The updated value in the kwargs dictionary is a tuple consisting on the value applicable to the inputs as first
        element and the target's value as second element: (<value_input>, <value_target>). If the value for the given
        parameter_name is already a tuple, it is checked to have exact two entries. If the parameter_name is not
        included in kwargs, the given default value is used and applied to both elements of the update tuple.

        :param parameter_name: name of the parameter that should be transformed to 2-dim
        :param default: the default value to fill if parameter is not in kwargs
        :param kwargs: the kwargs dictionary containing parameters
        :raises ValueError: if the given parameter is a tuple but not of length 2
        """
        parameter = kwargs.get(parameter_name, default)
        if not isinstance(parameter, tuple):
            parameter = (parameter, parameter)
        # raise instead of assert: assertions are stripped when running with -O
        if len(parameter) != 2:  # must be (inputs, targets)
            raise ValueError(f"Parameter '{parameter_name}' must have exactly two entries (inputs, targets) "
                             f"but is {parameter}.")
        kwargs.update({parameter_name: parameter})

    def make_input_target(self):
        """Load and interpolate input (index 0) and target (index 1) data, then assign both on the handler."""
        self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
        self.set_inputs_and_targets()

    def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
        """
        Load data for either inputs (ind=0) or targets (ind=1) and interpolate missing values.

        :param ind: 0 to load input data, 1 to load target data
        :return: interpolated data for the selected branch
        """
        var_list = [self.variables, self.target_var]  # renamed from `vars` to avoid shadowing the builtin
        stats_per_var = helpers.select_from_dict(self.statistics_per_var, var_list[ind])
        data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
                                limit=self.interpolation_limit[ind], sampling=self.sampling[ind])
        return data

    def set_inputs_and_targets(self):
        """Select input variables from the first and target variables from the second element of `_data`."""
        self.input_data = self._data[0].sel({self.target_dim: helpers.to_list(self.variables)})
        self.target_data = self._data[1].sel({self.target_dim: helpers.to_list(self.target_var)})

    def setup_data_path(self, data_path, sampling):
        """Sets two paths instead of single path. Expects sampling arg to be a list with two entries"""
        assert len(sampling) == 2
        # explicit two-arg super() is required here because zero-arg super() is not available inside the lambda
        return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling))

    def _extract_lazy(self, lazy_data):
        """Restore data from the lazy store and apply start/end slicing to data, inputs, and targets."""
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data = f_prep(_data[0]), f_prep(_data[1])
        self.input_data, self.target_data = list(map(f_prep, [_input_data, _target_data]))
class DataHandlerMixedSampling(DefaultDataHandler):
    """Data handler using mixed sampling for input and target."""

    # single-station class used to build the data for each station
    data_handler = DataHandlerMixedSamplingSingleStation
    # class used to compute the transformation statistics
    data_handler_transformation = DataHandlerMixedSamplingSingleStation
    # parameters the single-station class expects (forwarded from kwargs on build)
    _requirements = data_handler.requirements()
class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation,
                                                      DataHandlerFilterSingleStation):
    """
    Abstract mixed-sampling handler that applies a temporal filter on the input data.

    Inputs are expected in hourly, targets in daily resolution. Subclasses must implement the filter-related
    methods `apply_filter`, `create_filter_index`, `estimate_filter_width`, and `_create_lazy_data`.
    """
    # NOTE: the redundant pass-through __init__ of the original was removed; base-class init is used directly.

    def _check_sampling(self, **kwargs):
        # this handler is only defined for hourly inputs and daily targets
        assert kwargs.get("sampling") == ("hourly", "daily")

    def apply_filter(self):
        """Apply the filter on the input data (subclass responsibility)."""
        raise NotImplementedError

    def create_filter_index(self) -> pd.Index:
        """Create name for filter dimension."""
        raise NotImplementedError

    def _create_lazy_data(self):
        raise NotImplementedError

    def make_input_target(self):
        """
        A FIR filter is applied on the input data that has hourly resolution. Labels Y are provided as aggregated values
        with daily resolution.
        """
        self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
        self.set_inputs_and_targets()
        self.apply_filter()

    def estimate_filter_width(self):
        """Return maximum filter width."""
        raise NotImplementedError

    @staticmethod
    def _add_time_delta(date, delta):
        """Shift a '%Y-%m-%d' date string by `delta` hours and return it in the same format."""
        new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta)
        return new_date.strftime("%Y-%m-%d")

    def update_start_end(self, ind):
        """
        Return (start, end) of the loading period for branch `ind`.

        The input period (ind=0) is extended by the estimated filter width on both sides so the filter has enough
        data at the boundaries; targets (ind=1) keep the original period.
        """
        if ind == 0:  # for inputs
            estimated_filter_width = self.estimate_filter_width()
            start = self._add_time_delta(self.start, -estimated_filter_width)
            end = self._add_time_delta(self.end, estimated_filter_width)
        else:  # target
            start, end = self.start, self.end
        return start, end

    def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
        """Load input (ind=0) or target (ind=1) data for the possibly extended period and interpolate gaps."""
        start, end = self.update_start_end(ind)
        var_list = [self.variables, self.target_var]  # renamed from `vars` to avoid shadowing the builtin
        stats_per_var = helpers.select_from_dict(self.statistics_per_var, var_list[ind])

        data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
                                         self.store_data_locally, self.data_origin, start, end)
        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
                                limit=self.interpolation_limit[ind], sampling=self.sampling[ind])
        return data

    def _extract_lazy(self, lazy_data):
        """Restore from lazy store; inputs are sliced with the extended, targets with the original period."""
        _data, self.meta, _input_data, _target_data = lazy_data
        start_inp, end_inp = self.update_start_end(0)
        self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
        self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
        self.target_data = self._slice_prep(_target_data, self.start, self.end)
class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
                                                         DataHandlerFirFilterSingleStation):
    """Mixed-sampling handler applying a FIR filter on the hourly input data.

    Several methods call base-class implementations explicitly to select a specific branch of the MRO;
    do not replace these with plain super() calls.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def estimate_filter_width(self):
        """Filter width is determined by the filter with the highest order."""
        # tuple entries are kzf filter parameters and are expanded into filter_width_kzf
        if isinstance(self.filter_order[0], tuple):
            return max([filter_width_kzf(*e) for e in self.filter_order])
        else:
            return max(self.filter_order)

    def apply_filter(self):
        # explicit base call: use the FIR filter implementation, not the abstract mixed-sampling one
        DataHandlerFirFilterSingleStation.apply_filter(self)

    def create_filter_index(self, add_unfiltered_index=True) -> pd.Index:
        """Create name for filter dimension."""
        return DataHandlerFirFilterSingleStation.create_filter_index(self, add_unfiltered_index=add_unfiltered_index)

    def _extract_lazy(self, lazy_data):
        # unpack the FIR-specific entries first, then delegate the common part to the mixed-sampling base
        _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data
        DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))

    def _create_lazy_data(self):
        return DataHandlerFirFilterSingleStation._create_lazy_data(self)

    @staticmethod
    def _get_fs(**kwargs):
        """Return frequency in 1/day (not Hz)"""
        # sampling is a 2-dim tuple; the filter operates on the inputs (first element)
        sampling = kwargs.get("sampling")[0]
        if sampling == "daily":
            return 1
        elif sampling == "hourly":
            return 24
        else:
            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")
class DataHandlerMixedSamplingWithFirFilter(DataHandlerFirFilter):
    """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""

    # single-station class used to build the data for each station
    data_handler = DataHandlerMixedSamplingWithFirFilterSingleStation
    # class used to compute the transformation statistics
    data_handler_transformation = DataHandlerMixedSamplingWithFirFilterSingleStation
    # parameters the single-station class expects (forwarded from kwargs on build)
    _requirements = data_handler.requirements()
class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerClimateFirFilterSingleStation,
                                                                DataHandlerMixedSamplingWithFirFilterSingleStation):
    """Mixed-sampling handler using the climate FIR filter (FIR filter with apriori climate information)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _extract_lazy(self, lazy_data):
        # unpack the climate-FIR-specific entries first, then delegate the common part via explicit base call
        _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \
            self.filter_dim_order = lazy_data
        DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))
class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
    """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""

    data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
    data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
    # single-station class used when the additional unfiltered branch is requested
    data_handler_unfiltered = DataHandlerMixedSamplingSingleStation
    _requirements = list(set(data_handler.requirements() + data_handler_unfiltered.requirements()))
    DEFAULT_FILTER_ADD_UNFILTERED = False

    def __init__(self, *args, data_handler_class_unfiltered: data_handler_unfiltered = None,
                 filter_add_unfiltered: bool = DEFAULT_FILTER_ADD_UNFILTERED, **kwargs):
        """
        :param data_handler_class_unfiltered: already built unfiltered single-station handler (or None)
        :param filter_add_unfiltered: if True, the unfiltered handler is appended to the collection
        """
        # NOTE(review): attributes are set before super().__init__ — presumably the parent init already uses
        # _create_collection, which reads them; verify before reordering
        self.dh_unfiltered = data_handler_class_unfiltered
        self.filter_add_unfiltered = filter_add_unfiltered
        super().__init__(*args, **kwargs)

    def _create_collection(self):
        """Append the unfiltered data handler to the collection if requested and available."""
        collection = super()._create_collection()
        if self.filter_add_unfiltered is True and self.dh_unfiltered is not None:
            collection.append(self.dh_unfiltered)
        return collection

    @classmethod
    def build(cls, station: str, **kwargs):
        """Build a handler for `station`, optionally with an additional unfiltered branch."""
        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler.requirements() if k in kwargs}
        filter_add_unfiltered = kwargs.get("filter_add_unfiltered", False)
        sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered")
        sp = cls.data_handler(station, **sp_keys)
        if filter_add_unfiltered is True:
            sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs}
            sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered")
            sp_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)
        else:
            sp_unfiltered = None
        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, data_handler_class_unfiltered=sp_unfiltered, **dp_args)

    @classmethod
    def build_update_transformation(cls, kwargs_dict, dh_type="filtered"):
        """Replace a transformation dict by its entry for `dh_type` (inplace on kwargs_dict, which is returned).

        If the dict has no entry for `dh_type`, the transformation is set to None for that branch.
        """
        if "transformation" in kwargs_dict:
            trafo_opts = kwargs_dict.get("transformation")
            if isinstance(trafo_opts, dict):
                kwargs_dict["transformation"] = trafo_opts.get(dh_type)
        return kwargs_dict

    @classmethod
    def transformation(cls, set_stations, tmp_path=None, dh_transformation=None, **kwargs):
        """
        Compute transformation for the filtered branch and, if requested, also for the unfiltered branch.

        :return: a single transformation, or a dict with keys "filtered"/"unfiltered", or None if no
            transformation is given in kwargs
        """
        if "transformation" not in kwargs.keys():
            return

        # normalize dh_transformation to a 2-tuple: (filtered handler class, unfiltered handler class)
        if dh_transformation is None:
            dh_transformation = (cls.data_handler_transformation, cls.data_handler_unfiltered)
        elif not isinstance(dh_transformation, tuple):
            dh_transformation = (dh_transformation, dh_transformation)
        transformation_filtered = super().transformation(set_stations, tmp_path=tmp_path,
                                                         dh_transformation=dh_transformation[0], **kwargs)
        if kwargs.get("filter_add_unfiltered", False) is False:
            return transformation_filtered
        else:
            transformation_unfiltered = super().transformation(set_stations, tmp_path=tmp_path,
                                                               dh_transformation=dh_transformation[1], **kwargs)
            return {"filtered": transformation_filtered, "unfiltered": transformation_unfiltered}
class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWithClimateFirFilter):
    """
    Data handler that splits variables into a chemical and a meteorological branch.

    Chemical variables are handled by the climate FIR filter; meteorological variables either by the plain FIR
    filter or by the climate FIR filter, depending on `extend_length_opts` (see `set_data_handler_fir_pos`).
    Optionally, unfiltered counterparts can be added for both branches.
    """
    # (dead commented-out attribute block of the original removed)

    data_handler_climate_fir = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
    # position 0: faster FIR without climate estimate, position 1: slower climate FIR
    data_handler_fir = (DataHandlerMixedSamplingWithFirFilterSingleStation,
                        DataHandlerMixedSamplingWithClimateFirFilterSingleStation)
    data_handler_fir_pos = None  # class-level cache, set once by set_data_handler_fir_pos
    data_handler = None
    data_handler_unfiltered = DataHandlerMixedSamplingSingleStation
    _requirements = list(set(data_handler_climate_fir.requirements() + data_handler_fir[0].requirements() +
                             data_handler_fir[1].requirements() + data_handler_unfiltered.requirements()))
    chem_indicator = "chem"
    meteo_indicator = "meteo"

    def __init__(self, data_handler_class_chem, data_handler_class_meteo, data_handler_class_chem_unfiltered,
                 data_handler_class_meteo_unfiltered, chem_vars, meteo_vars, *args, **kwargs):
        """Use the chem branch as id_class if chem variables are present, otherwise the meteo branch."""
        if len(chem_vars) > 0:
            id_class, id_class_unfiltered = data_handler_class_chem, data_handler_class_chem_unfiltered
            self.id_class_other = data_handler_class_meteo
            self.id_class_other_unfiltered = data_handler_class_meteo_unfiltered
        else:
            id_class, id_class_unfiltered = data_handler_class_meteo, data_handler_class_meteo_unfiltered
            self.id_class_other = data_handler_class_chem
            self.id_class_other_unfiltered = data_handler_class_chem_unfiltered
        super().__init__(id_class, *args, data_handler_class_unfiltered=id_class_unfiltered, **kwargs)

    @classmethod
    def _split_chem_and_meteo_variables(cls, **kwargs):
        """
        Select all used variables and split them into categories chem and other.

        Chemical variables are indicated by `cls.data_handler_climate_fir.chem_vars`. To indicate used variables, this
        method uses 1) parameter `variables`, 2) keys from `statistics_per_var`, 3) keys from
        `cls.data_handler_climate_fir.DEFAULT_VAR_ALL_DICT`. Option 3) is also applied if 1) or 2) are given but None.
        """
        if "variables" in kwargs:
            variables = kwargs.get("variables")
        elif "statistics_per_var" in kwargs:
            variables = kwargs.get("statistics_per_var").keys()
        else:
            variables = None
        if variables is None:
            variables = cls.data_handler_climate_fir.DEFAULT_VAR_ALL_DICT.keys()
        chem_vars = cls.data_handler_climate_fir.chem_vars
        chem = set(variables).intersection(chem_vars)
        meteo = set(variables).difference(chem_vars)
        # preserve the original variable order in both result lists
        return sort_like(to_list(chem), variables), sort_like(to_list(meteo), variables)

    @classmethod
    def build(cls, station: str, **kwargs):
        """
        Build a data handler with separate chem and meteo branches, plus optional unfiltered branches.

        :param station: station to build the handler for
        :return: class instance composed of up to four single-station handlers
        """
        chem_vars, meteo_vars = cls._split_chem_and_meteo_variables(**kwargs)
        filter_add_unfiltered = kwargs.get("filter_add_unfiltered", False)
        sp_chem, sp_chem_unfiltered = None, None
        sp_meteo, sp_meteo_unfiltered = None, None

        if len(chem_vars) > 0:
            sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_climate_fir.requirements() if k in kwargs}
            sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered_chem")

            cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator)
            sp_chem = cls.data_handler_climate_fir(station, **sp_keys)
            if filter_add_unfiltered is True:
                sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs}
                sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_chem")
                cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator)
                # presumably avoids re-downloading data the filtered branch just stored — TODO confirm
                cls.correct_overwrite_option(sp_keys)
                sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)
        if len(meteo_vars) > 0:
            cls.set_data_handler_fir_pos(**kwargs)
            sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_fir[cls.data_handler_fir_pos].requirements() if k in kwargs}
            sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered_meteo")
            cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator)
            sp_meteo = cls.data_handler_fir[cls.data_handler_fir_pos](station, **sp_keys)
            if filter_add_unfiltered is True:
                sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs}
                sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_meteo")
                cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator)
                cls.correct_overwrite_option(sp_keys)
                sp_meteo_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)

        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp_chem, sp_meteo, sp_chem_unfiltered, sp_meteo_unfiltered, chem_vars, meteo_vars, **dp_args)

    @classmethod
    def correct_overwrite_option(cls, kwargs):
        """Set `overwrite_local_data=False`."""
        if "overwrite_local_data" in kwargs:
            kwargs["overwrite_local_data"] = False

    @classmethod
    def set_data_handler_fir_pos(cls, **kwargs):
        """
        Set position of fir data handler to use either faster FIR version or slower climate FIR.

        This method will set data handler indicator to 0 if either no parameter "extend_length_opts" is given or the
        parameter is of type dict but has no entry for the meteo_indicator. In all other cases, indicator is set to 1.
        """
        p_name = "extend_length_opts"
        if cls.data_handler_fir_pos is None:  # decided only once, then cached on the class
            if p_name in kwargs:
                if isinstance(kwargs[p_name], dict) and cls.meteo_indicator not in kwargs[p_name].keys():
                    cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate
                else:
                    cls.data_handler_fir_pos = 1  # use slower fir version with climate estimate
            else:
                cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate

    @classmethod
    def prepare_build(cls, kwargs, var_list, var_type):
        """
        Prepares for build of class.

        `variables` parameter is updated by `var_list`, which should only include variables of a specific type (e.g.
        only chemical variables) indicated by `var_type`. Furthermore, this method cleans the `kwargs` dictionary as
        follows: For all parameters provided as dict to separate between chem and meteo options (dict must have keys
        from `cls.chem_indicator` and/or `cls.meteo_indicator`), this parameter is removed from kwargs and its value
        related to `var_type` added again. In case there is no value for given `var_type`, the parameter is not added
        at all (as this parameter is assumed to affect only other types of variables).
        """
        kwargs.update({"variables": var_list})
        for k in list(kwargs.keys()):
            v = kwargs[k]
            if isinstance(v, dict):
                if len(set(v.keys()).intersection({cls.chem_indicator, cls.meteo_indicator})) > 0:
                    try:
                        new_v = kwargs.pop(k)
                        kwargs[k] = new_v[var_type]
                    except KeyError:
                        # no entry for this var_type: parameter stays removed (affects only the other branch)
                        pass

    def _create_collection(self):
        """Extend the parent collection by the other branch and, if enabled, its unfiltered counterpart."""
        collection = super()._create_collection()
        if self.id_class_other is not None:
            collection.append(self.id_class_other)
        if self.filter_add_unfiltered is True and self.id_class_other_unfiltered is not None:
            collection.append(self.id_class_other_unfiltered)
        return collection

    @classmethod
    def transformation(cls, set_stations, tmp_path=None, **kwargs):
        """
        Compute transformation statistics separately for the chem and meteo branches.

        :return: dict with keys "filtered_chem"/"unfiltered_chem" and/or "filtered_meteo"/"unfiltered_meteo",
            or None if nothing was computed
        """
        if "transformation" not in kwargs.keys():
            return

        chem_vars, meteo_vars = cls._split_chem_and_meteo_variables(**kwargs)
        transformation_chem, transformation_meteo = None, None
        # chem transformation
        if len(chem_vars) > 0:
            kwargs_chem = copy.deepcopy(kwargs)
            cls.prepare_build(kwargs_chem, chem_vars, cls.chem_indicator)
            dh_transformation = (cls.data_handler_climate_fir, cls.data_handler_unfiltered)
            transformation_chem = super().transformation(set_stations, tmp_path=tmp_path,
                                                         dh_transformation=dh_transformation, **kwargs_chem)

        # meteo transformation
        if len(meteo_vars) > 0:
            cls.set_data_handler_fir_pos(**kwargs)
            kwargs_meteo = copy.deepcopy(kwargs)
            cls.prepare_build(kwargs_meteo, meteo_vars, cls.meteo_indicator)
            dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos], cls.data_handler_unfiltered)
            transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path,
                                                          dh_transformation=dh_transformation, **kwargs_meteo)

        # combine all transformations; a dict result carries both filtered and unfiltered entries,
        # a non-dict result is the filtered transformation only
        transformation_res = {}
        if transformation_chem is not None:
            if isinstance(transformation_chem, dict):
                if len(transformation_chem) > 0:
                    transformation_res["filtered_chem"] = transformation_chem.pop("filtered")
                    transformation_res["unfiltered_chem"] = transformation_chem.pop("unfiltered")
            else:  # if no unfiltered chem branch
                transformation_res["filtered_chem"] = transformation_chem
        if transformation_meteo is not None:
            if isinstance(transformation_meteo, dict):
                if len(transformation_meteo) > 0:
                    transformation_res["filtered_meteo"] = transformation_meteo.pop("filtered")
                    transformation_res["unfiltered_meteo"] = transformation_meteo.pop("unfiltered")
            else:  # if no unfiltered meteo branch
                transformation_res["filtered_meteo"] = transformation_meteo
        return transformation_res if len(transformation_res) > 0 else None