1"""Data Handler using kz-filtered data."""
3__author__ = 'Lukas Leufen'
4__date__ = '2020-08-26'
6import copy
7import numpy as np
8import pandas as pd
9import xarray as xr
10from typing import List, Union, Tuple, Optional
11from functools import partial
12import logging
13from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
14from mlair.data_handler import DefaultDataHandler
15from mlair.helpers import to_list, TimeTrackingWrapper, statistics
16from mlair.helpers.filter import FIRFilter, ClimateFIRFilter, omega_null_kzf
18# define a more general date type for type hinting
19str_or_list = Union[str, List[str]]

# cutoff_p = [(None, 14), (8, 6), (2, 0.8), (0.8, None)]
# cutoff = list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None),
#                   cutoff_p))
# fs = 24.
# # order = int(60 * fs) + 1
# order = np.array([int(14 * fs) + 1, int(14 * fs) + 1, int(4 * fs) + 1, int(2 * fs) + 1])
# print("cutoff period", cutoff_p)
# print("cutoff", cutoff)
# print("fs", fs)
# print("order", order)
# print("delay", 0.5 * (order - 1) / fs)
# window = ("kaiser", 5)
# # low pass
# y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low=cutoff[0][0], cutoff_high=cutoff[0][1],
#                   window=window)
# filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape)


class DataHandlerFilterSingleStation(DataHandlerSingleStation):
    """General data handler for a single station to be used by a superior data handler."""

    _hash = DataHandlerSingleStation._hash + ["filter_dim"]

    DEFAULT_FILTER_DIM = "filter"

    def __init__(self, *args, filter_dim=DEFAULT_FILTER_DIM, **kwargs):
        # self.original_data = None  # ToDo: implement here something to store unfiltered data
        self.filter_dim = filter_dim
        self.filter_dim_order = None
        super().__init__(*args, **kwargs)

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Adjust the transformation setup because filtered data will have negative values, which are not compatible
        with the log transformation. Therefore, replace all log (and min_max) transformation methods by a default
        standardisation. This is only applied on the input side.
        """
        transformation = super(__class__, self).setup_transformation(transformation)
        if transformation[0] is not None:
            for k, v in transformation[0].items():
                if v["method"] in ("log", "min_max"):
                    transformation[0][k]["method"] = "standardise"
        return transformation
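
    # Sketch of the effect (hypothetical transformation dict, not a pipeline default): an input-side
    # setting such as
    #   ({"o3": {"method": "log"}, "temp": {"method": "min_max"}}, None)
    # is returned as
    #   ({"o3": {"method": "standardise"}, "temp": {"method": "standardise"}}, None)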

    def _check_sampling(self, **kwargs):
        assert kwargs.get("sampling") == "hourly"  # this data handler requires data in hourly resolution

    def make_input_target(self):
        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit)
        self.set_inputs_and_targets()
        self.apply_filter()
        # this is just a code snippet to check the results of the kz filter
        # import matplotlib
        # matplotlib.use("TkAgg")
        # import matplotlib.pyplot as plt
        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")

    def apply_filter(self):
        raise NotImplementedError

    def create_filter_index(self) -> pd.Index:
        """Create name for filter dimension."""
        raise NotImplementedError

    def get_transposed_history(self) -> xr.DataArray:
        """Return history.

        :return: history with dimensions datetime, window, Stations, variables, filter.
        """
        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
                                      self.filter_dim).copy()

    def _create_lazy_data(self):
        raise NotImplementedError

    def _extract_lazy(self, lazy_data):
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))


class DataHandlerFilter(DefaultDataHandler):
    """Data handler using FIR filtered data."""

    data_handler = DataHandlerFilterSingleStation
    data_handler_transformation = DataHandlerFilterSingleStation
    _requirements = data_handler.requirements()

    def __init__(self, *args, use_filter_branches=False, **kwargs):
        self.use_filter_branches = use_filter_branches
        super().__init__(*args, **kwargs)

    def get_X_original(self):
        if self.use_filter_branches is True:
            X = []
            for data in self._collection:
                if hasattr(data, "filter_dim"):
                    X_total = data.get_X()
                    filter_dim = data.filter_dim
                    for filter_name in data.filter_dim_order:
                        X.append(X_total.sel({filter_dim: filter_name}, drop=True))
                else:
                    X.append(data.get_X())
            return X
        else:
            return super().get_X_original()
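
    # Branching sketch (hypothetical labels, not a pipeline default): with use_filter_branches=True
    # and a station whose filter_dim_order is ["21d", "3d", "res"], the single filtered X of that
    # station is split into three arrays without the filter dimension, one per component, so a
    # branched network can receive one input head per filter component.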


class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
    """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered."""

    _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"]

    DEFAULT_WINDOW_TYPE = ("kaiser", 5)

    def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE,
                 plot_path=None, filter_plot_dates=None, **kwargs):
        # self.original_data = None  # ToDo: implement here something to store unfiltered data
        self.fs = self._get_fs(**kwargs)
        if filter_window_type == "kzf":
            filter_cutoff_period = self._get_kzf_cutoff_period(filter_order, self.fs)
        self.filter_cutoff_period, removed_index = self._prepare_filter_cutoff_period(filter_cutoff_period, self.fs)
        self.filter_cutoff_freq = self._period_to_freq(self.filter_cutoff_period)
        assert len(self.filter_cutoff_period) == (len(filter_order) - len(removed_index))
        self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs)
        self.filter_window_type = filter_window_type
        self.unfiltered_name = "unfiltered"
        self.plot_path = plot_path  # use this path to create insight plots
        self.plot_dates = filter_plot_dates
        super().__init__(*args, **kwargs)

    @staticmethod
    def _prepare_filter_order(filter_order, removed_index, fs):
        order = []
        for i, o in enumerate(filter_order):
            if i not in removed_index:
                if isinstance(o, tuple):
                    fo = (o[0] * fs, o[1])
                else:
                    fo = int(o * fs)
                    fo = fo + 1 if fo % 2 == 0 else fo  # enforce odd order for a symmetric FIR filter
                order.append(fo)
        return order
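
    # Worked example (hypothetical values): for hourly data (fs=24 samples per day) and
    # filter_order=[14, 2] given in days, the orders are scaled to samples (int(14 * 24) = 336,
    # int(2 * 24) = 48) and even results are shifted to the next odd number so each filter stays
    # symmetric around its centre tap:
    # _prepare_filter_order([14, 2], removed_index=[], fs=24)  # -> [337, 49]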

    @staticmethod
    def _prepare_filter_cutoff_period(filter_cutoff_period, fs):
        """Cutoff periods must be longer than the Nyquist period 2 / fs (i.e. the cutoff frequency must stay below
        fs / 2); otherwise, remove the given cutoff period."""
        cutoff = []
        removed = []
        for i, period in enumerate(to_list(filter_cutoff_period)):
            if period > 2. / fs:
                cutoff.append(period)
            else:
                removed.append(i)
        return cutoff, removed
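
    # Worked example (hypothetical values): for hourly data (fs=24 per day) the Nyquist period is
    # 2 / 24 days = 2 hours. Requesting periods [21, 3, 0.05] (in days) keeps 21 d and 3 d but drops
    # 0.05 d (= 1.2 h):
    # _prepare_filter_cutoff_period([21, 3, 0.05], fs=24)  # -> ([21, 3], [2])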

    @staticmethod
    def _get_kzf_cutoff_period(kzf_settings, fs):
        cutoff = []
        for (m, k) in kzf_settings:
            w0 = omega_null_kzf(m * fs, k) * fs
            cutoff.append(1. / w0)
        return cutoff

    @staticmethod
    def _period_to_freq(cutoff_p):
        return [1. / x for x in cutoff_p]

    @staticmethod
    def _get_fs(**kwargs):
        """Return frequency in 1/day (not Hz)."""
        sampling = kwargs.get("sampling")
        if sampling == "daily":
            return 1
        elif sampling == "hourly":
            return 24
        else:
            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")

    @TimeTrackingWrapper
    def apply_filter(self):
        """Apply FIR filter only on inputs."""
        fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq,
                        self.filter_window_type, self.target_dim, self.time_dim, display_name=self.station[0],
                        minimum_length=self.window_history_size, offset=self.window_history_offset,
                        plot_path=self.plot_path, plot_dates=self.plot_dates)
        self.fir_coeff = fir.filter_coefficients
        filter_data = fir.filtered_data
        input_data = xr.concat(filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
        self.input_data = input_data.sel({self.target_dim: self.variables})
        # this is just a code snippet to check the results of the kz filter
        # import matplotlib
        # matplotlib.use("TkAgg")
        # import matplotlib.pyplot as plt
        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")

    def create_filter_index(self, add_unfiltered_index=True) -> pd.Index:
        """
        Round cutoff periods in days and append 'res' for the residuum index.

        Round small numbers (< 10) to a single decimal and higher numbers to int. Transform to a list of str and
        append 'res' for the residuum index. Add the index 'unfiltered' if the raw / unfiltered data is appended to
        the data in addition.
        """
        index = np.round(self.filter_cutoff_period, 1)
        f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
        index = list(map(f, index.tolist()))
        index = list(map(lambda x: str(x) + "d", index)) + ["res"]
        self.filter_dim_order = index
        return pd.Index(index, name=self.filter_dim)
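
    # Worked example (hypothetical cutoff periods): filter_cutoff_period=[21.0, 0.8] produces the
    # index ["21d", "0.8d", "res"], i.e. one label per low-pass component plus the residuum.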

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data, self.fir_coeff, self.filter_dim_order]

    def _extract_lazy(self, lazy_data):
        _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data
        super()._extract_lazy((_data, _meta, _input_data, _target_data))

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=None):
        """
        Transform data according to given transformation settings.

        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either with mean=0
        and std=1 (`method=standardise`) or centres the data with mean=0 and no change in data scale
        (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation.
        This method will raise an AssertionError if an internal transform method was already set ('inverse=False') or
        if the internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').

        :param string/int dim: This param is not used for inverse transformation.
                | for xarray.DataArray as string: name of dimension which should be standardised
                | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param inverse: Switch between transformation and inverse transformation.

        :return: xarray.DataArrays or pandas.DataFrames:
                #. mean: Mean of data
                #. std: Standard deviation of data
                #. data: Standardised data
        """

        if transformation_dim is None:
            transformation_dim = self.DEFAULT_TARGET_DIM

        def f(data, method="standardise", feature_range=None):
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            min = kwargs.pop('min', None)
            max = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                return statistics.min_max_apply(data, min, max), {"min": min, "max": max, "method": method,
                                                                  "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                # apply stored statistics (f_apply) if they are provided, otherwise fit them (f)
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                transform_function = f_apply if _apply else f
                values, new_var_opts = transform_function(data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)
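
    # Sketch of the opts round trip (hypothetical variable name): a first call with opts={} fits and
    # returns the statistics, e.g. {"o3": {"method": "standardise", "mean": <DataArray>,
    # "std": <DataArray>}}; passing that dict back in as opts re-applies the stored statistics via
    # f_apply instead of fitting them again.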


class DataHandlerFirFilter(DataHandlerFilter):
    """Data handler using FIR filtered data."""

    data_handler = DataHandlerFirFilterSingleStation
    data_handler_transformation = DataHandlerFirFilterSingleStation
    _requirements = data_handler.requirements()


class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation):
    """
    Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered. In contrast to
    the simple DataHandlerFirFilterSingleStation, this data handler is centered around t0 to have no time delay. For
    values in the future (t > t0), this data handler assumes a climatological value for the low pass data and values
    of 0 for all residuum components.

    :param apriori: Data to use as apriori information. This should be either an xarray DataArray containing monthly
        statistics or any other heuristic to support the clim filter, or a list of such arrays containing heuristics
        for all residua in addition. The second option can be used together with apriori_type `residuum_stats`, which
        estimates the error of the residuum when the clim filter should be applied with exogenous parameters. If
        apriori_type is None or `zeros`, data can be provided, but this is not required in this case.
    :param apriori_type: Set the type of information that is provided to the clim filter. For the first low pass, a
        calculated or given statistic is always used. For residuum prediction, a constant value of zero is assumed if
        apriori_type is None or `zeros`, and a climatology of the residuum is used for `residuum_stats`.
    :param apriori_diurnal: Use diurnal anomalies of each hour in addition to the apriori information type chosen by
        parameter apriori_type. This is only applicable for hourly resolution data.
    :param apriori_sel_opts: Specify some parameters to select a subset of data before calculating the apriori
        information. Use this parameter, for example, if the apriori shall only be calculated on a shorter time
        period than available in the given data.
    :param extend_length_opts: Use this parameter to include future data in the filter calculation. This parameter
        does not affect the size of the history samples, as this is handled by the window_history_size parameter.
        Example: set extend_length_opts=7*24 to use the observations of the next 7 days to calculate the filtered
        components. Which data are finally used for the input samples is not affected by these 7 days. In case the
        range of the history sample exceeds the horizon of extend_length_opts, the history sample will also include
        data from climatological estimates.
    """

    DEFAULT_EXTEND_LENGTH_OPTS = 0
    _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal",
                                                       "extend_length_opts"]
    _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"]

    def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None,
                 extend_length_opts=DEFAULT_EXTEND_LENGTH_OPTS, **kwargs):
        self.apriori_type = apriori_type
        self.climate_filter_coeff = None  # coefficients of the used FIR filter
        self.apriori = apriori  # exogenous apriori information or None to calculate from data (endogenous)
        self.apriori_diurnal = apriori_diurnal
        self.all_apriori = None  # collection of all apriori information
        self.apriori_sel_opts = apriori_sel_opts  # ensure to separate exogenous and endogenous information
        self.extend_length_opts = extend_length_opts
        super().__init__(*args, **kwargs)

    @TimeTrackingWrapper
    def apply_filter(self):
        """Apply FIR filter only on inputs."""
        self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori
        logging.info(f"{self.station}: call ClimateFIRFilter")
        climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order,
                                          self.filter_cutoff_freq,
                                          self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim,
                                          apriori_type=self.apriori_type, apriori=self.apriori,
                                          apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts,
                                          plot_path=self.plot_path,
                                          minimum_length=self.window_history_size, new_dim=self.window_dim,
                                          display_name=self.station[0], extend_length_opts=self.extend_length_opts,
                                          offset=self.window_history_end, plot_dates=self.plot_dates)
        self.climate_filter_coeff = climate_filter.filter_coefficients

        # store apriori information: store all if residuum_stats method was used, otherwise just store initial apriori
        if self.apriori_type == "residuum_stats":
            self.apriori = climate_filter.apriori_data
        else:
            self.apriori = climate_filter.initial_apriori_data
        self.all_apriori = climate_filter.apriori_data

        climate_filter_data = [c.sel({self.window_dim: slice(self.window_history_end - self.window_history_size,
                                                             self.window_history_end)})
                               for c in climate_filter.filtered_data]

        # create input data with filter index
        input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(add_unfiltered_index=False),
                                                             name=self.filter_dim))

        self.input_data = input_data.sel({self.target_dim: self.variables})

        # this is just a code snippet to check the results of the filter
        # import matplotlib
        # matplotlib.use("TkAgg")
        # import matplotlib.pyplot as plt
        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
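
    # Slice sketch for climate_filter_data above (hypothetical sizes): with window_history_size=72
    # and window_history_end=0, each filtered component is cut to the window range slice(-72, 0),
    # i.e. exactly the history range that is later served as model input.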

    def create_filter_index(self, add_unfiltered_index=True) -> pd.Index:
        """
        Round cutoff periods in days and append 'res' for the residuum index.

        Round small numbers (< 10) to a single decimal and higher numbers to int. Transform to a list of str and
        append 'res' for the residuum index. Add the index 'unfiltered' if the raw / unfiltered data is appended to
        the data in addition.
        """
        index = np.round(self.filter_cutoff_period, 1)
        f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
        index = list(map(f, index.tolist()))
        index = list(map(lambda x: str(x) + "d", index)) + ["res"]
        self.filter_dim_order = index
        return pd.Index(index, name=self.filter_dim)

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data, self.climate_filter_coeff,
                self.apriori, self.all_apriori, self.filter_dim_order]

    def _extract_lazy(self, lazy_data):
        _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \
            self.filter_dim_order = lazy_data
        DataHandlerSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))

    @staticmethod
    def _prepare_filter_cutoff_period(filter_cutoff_period, fs):
        """Cutoff periods must be longer than the Nyquist period 2 / fs (i.e. the cutoff frequency must stay below
        fs / 2); otherwise, remove the given cutoff period."""
        cutoff = []
        removed = []
        for i, period in enumerate(to_list(filter_cutoff_period)):
            if period > 2. / fs:
                cutoff.append(period)
            else:
                removed.append(i)
        return cutoff, removed

    @staticmethod
    def _period_to_freq(cutoff_p):
        return [1. / x for x in cutoff_p]

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create an xr.DataArray containing history data. As 'input_data' already consists of a dimension 'window',
        this method only shifts the data along the 'window' dimension x times where x is given by
        'window_history_offset'. Results are stored in the history attribute.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: this parameter is not used in the inherited method
        :param dim_name_of_shift: Dimension along which the shift will be applied
        """
        data = self.input_data
        sampling = {"daily": "D", "hourly": "h"}.get(to_list(self.sampling)[0])
        data.coords[dim_name_of_shift] = data.coords[dim_name_of_shift] - np.timedelta64(self.window_history_offset,
                                                                                         sampling)
        data.coords[self.window_dim] = data.coords[self.window_dim] + self.window_history_offset
        self.history = data
        # from matplotlib import pyplot as plt
        # d = self.load_and_interpolate(0)
        # data.sel(datetime="2007-07-07 00:00").sum("filter").plot()
        # plt.plot(data.sel(datetime="2007-07-07 00:00").sum("filter").window.values,
        #          d.sel(datetime=slice("2007-07-05 00:00", "2007-07-07 16:00")).values.flatten())
        # plt.plot(data.sel(datetime="2007-07-07 00:00").sum("filter").window.values,
        #          d.sel(datetime=slice("2007-07-05 00:00", "2007-07-11 16:00")).values.flatten())
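
    # Shift sketch (hypothetical offset): with window_history_offset=16 and hourly sampling, the
    # datetime coordinate is moved 16 hours back and the window coordinate 16 steps forward, so each
    # sample stays anchored at t0 while its window axis covers the offset history range.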

    def call_transform(self, inverse=False):
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=[self.time_dim, self.window_dim],
                                                     inverse=inverse, opts=opts_input,
                                                     transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)


class DataHandlerClimateFirFilter(DataHandlerFilter):
    """Data handler using climatic adjusted FIR filtered data."""

    data_handler = DataHandlerClimateFirFilterSingleStation
    data_handler_transformation = DataHandlerClimateFirFilterSingleStation
    _requirements = data_handler.requirements()
    _store_attributes = data_handler.store_attributes()