Coverage for mlair/data_handler/data_handler_with_filter.py: 23%
239 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-12-18 17:51 +0000
1"""Data Handler using kz-filtered data."""
3__author__ = 'Lukas Leufen'
4__date__ = '2020-08-26'
6import copy
7import numpy as np
8import pandas as pd
9import xarray as xr
10from typing import List, Union, Tuple, Optional
11from functools import partial
12import logging
13from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
14from mlair.data_handler import DefaultDataHandler
15from mlair.helpers import to_list, TimeTrackingWrapper, statistics
16from mlair.helpers.filter import FIRFilter, ClimateFIRFilter, omega_null_kzf
18# define a more general date type for type hinting
19str_or_list = Union[str, List[str]]
22# cutoff_p = [(None, 14), (8, 6), (2, 0.8), (0.8, None)]
23# cutoff = list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), cutoff_p))
24# fs = 24.
25# # order = int(60 * fs) + 1
26# order = np.array([int(14 * fs) + 1, int(14 * fs) + 1, int(4 * fs) + 1, int(2 * fs) + 1])
27# print("cutoff period", cutoff_p)
28# print("cutoff", cutoff)
29# print("fs", fs)
30# print("order", order)
31# print("delay", 0.5 * (order-1) / fs)
32# window = ("kaiser", 5)
33# # low pass
34# y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low = cutoff[0][0], cutoff_high = cutoff[0][1], window=window)
35# filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape)
class DataHandlerFilterSingleStation(DataHandlerSingleStation):
    """General data handler for a single station to be used by a superior data handler."""

    _hash = DataHandlerSingleStation._hash + ["filter_dim"]

    DEFAULT_FILTER_DIM = "filter"

    def __init__(self, *args, filter_dim=DEFAULT_FILTER_DIM, **kwargs):
        """
        :param filter_dim: name of the dimension that enumerates the filter components (default "filter")
        """
        # self.original_data = None  # ToDo: implement here something to store unfiltered data
        self.filter_dim = filter_dim
        self.filter_dim_order = None  # ordered filter index labels, set by create_filter_index in subclasses
        super().__init__(*args, **kwargs)

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Adjust setup of transformation because filtered data will have negative values which is not compatible with
        the log transformation. Therefore, replace all log transformation methods by a default standardization. This is
        only applied on input side.
        """
        transformation = super().setup_transformation(transformation)
        if transformation[0] is not None:
            for k, v in transformation[0].items():
                # both log and min_max are replaced: filtered inputs can be negative, which breaks log,
                # and min_max was mapped to standardise in the original implementation as well
                if v["method"] in ("log", "min_max"):
                    transformation[0][k]["method"] = "standardise"
        return transformation

    def _check_sampling(self, **kwargs):
        # NOTE(review): assert is stripped under `python -O`; kept as assert to preserve the exception
        # type (AssertionError) that callers may rely on.
        assert kwargs.get("sampling") == "hourly"  # This data handler requires hourly data resolution, does it?

    def make_input_target(self):
        """Load and interpolate the raw data, split into inputs/targets, and apply the subclass filter on inputs."""
        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit)
        self.set_inputs_and_targets()
        self.apply_filter()

    def apply_filter(self):
        """Apply the concrete filter on the input data. Must be implemented by subclasses."""
        raise NotImplementedError

    def create_filter_index(self) -> pd.Index:
        """Create name for filter dimension."""
        raise NotImplementedError

    def get_transposed_history(self) -> xr.DataArray:
        """Return history.

        :return: history with dimensions datetime, window, Stations, variables, filter.
        """
        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
                                      self.filter_dim).copy()

    def _create_lazy_data(self):
        # subclasses define which attributes are stored for lazy loading
        raise NotImplementedError

    def _extract_lazy(self, lazy_data):
        """Restore data from a lazy store and re-apply the start/end slicing."""
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
class DataHandlerFilter(DefaultDataHandler):
    """Data handler using FIR filtered data."""

    data_handler = DataHandlerFilterSingleStation
    data_handler_transformation = DataHandlerFilterSingleStation
    _requirements = data_handler.requirements()

    def __init__(self, *args, use_filter_branches=False, **kwargs):
        """
        :param use_filter_branches: if True, expose each filter component as a separate input branch
        """
        self.use_filter_branches = use_filter_branches
        super().__init__(*args, **kwargs)

    def get_X_original(self):
        """Return original inputs; with filter branches enabled, one entry per filter component."""
        if self.use_filter_branches is not True:
            return super().get_X_original()
        collected = []
        for data in self._collection:
            if not hasattr(data, "filter_dim"):
                # handler without filter dimension contributes its inputs unchanged
                collected.append(data.get_X())
                continue
            data_all_filters = data.get_X()
            dim = data.filter_dim
            for filter_name in data.filter_dim_order:
                # select the single component and drop the now-scalar filter coordinate
                collected.append(data_all_filters.sel({dim: filter_name}, drop=True))
        return collected
class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
    """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered."""

    _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"]

    DEFAULT_WINDOW_TYPE = ("kaiser", 5)

    def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE,
                 plot_path=None, filter_plot_dates=None, **kwargs):
        """
        :param filter_cutoff_period: cutoff periods in days, one per filter component
        :param filter_order: filter orders in days (scalars) or kz settings as (m, k) tuples
        :param filter_window_type: window type passed to the FIR design, or "kzf" for kz-style filters
        :param plot_path: use this path to create insight plots
        :param filter_plot_dates: dates to show in the filter insight plots
        """
        # self.original_data = None  # ToDo: implement here something to store unfiltered data
        self.fs = self._get_fs(**kwargs)
        if filter_window_type == "kzf":
            # for kz-style filters the cutoff periods are derived from the (m, k) settings
            filter_cutoff_period = self._get_kzf_cutoff_period(filter_order, self.fs)
        self.filter_cutoff_period, removed_index = self._prepare_filter_cutoff_period(filter_cutoff_period, self.fs)
        self.filter_cutoff_freq = self._period_to_freq(self.filter_cutoff_period)
        assert len(self.filter_cutoff_period) == (len(filter_order) - len(removed_index))
        self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs)
        self.filter_window_type = filter_window_type
        self.unfiltered_name = "unfiltered"
        self.plot_path = plot_path  # use this path to create insight plots
        self.plot_dates = filter_plot_dates
        super().__init__(*args, **kwargs)

    @staticmethod
    def _prepare_filter_order(filter_order, removed_index, fs):
        """
        Convert filter orders from days into samples and enforce odd orders for scalar entries.

        Entries whose cutoff period was removed (index in removed_index) are skipped. Tuple entries
        (kz settings) are scaled but get no parity adjustment.
        """
        order = []
        for i, o in enumerate(filter_order):
            if i not in removed_index:
                if isinstance(o, tuple):
                    # kz settings (m, k): scale m only, a tuple has no parity requirement
                    fo = (o[0] * fs, o[1])
                else:
                    fo = int(o * fs)
                    # bug fix: apply the odd-order adjustment only to scalar orders; `fo % 2` on a
                    # tuple raised TypeError before. FIR filter order must be odd.
                    fo = fo + 1 if fo % 2 == 0 else fo
                order.append(fo)
        return order

    @staticmethod
    def _prepare_filter_cutoff_period(filter_cutoff_period, fs):
        """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair."""
        cutoff = []
        removed = []
        for i, period in enumerate(to_list(filter_cutoff_period)):
            if period > 2. / fs:  # keep only periods above the Nyquist period
                cutoff.append(period)
            else:
                removed.append(i)
        return cutoff, removed

    @staticmethod
    def _get_kzf_cutoff_period(kzf_settings, fs):
        """Derive cutoff periods (in days) from kz filter settings (m, k) via the null frequency."""
        cutoff = []
        for (m, k) in kzf_settings:
            w0 = omega_null_kzf(m * fs, k) * fs
            cutoff.append(1. / w0)
        return cutoff

    @staticmethod
    def _period_to_freq(cutoff_p):
        """Convert periods (in days) to frequencies (in 1/day)."""
        return [1. / x for x in cutoff_p]

    @staticmethod
    def _get_fs(**kwargs):
        """Return frequency in 1/day (not Hz)"""
        sampling = kwargs.get("sampling")
        if sampling == "daily":
            return 1
        elif sampling == "hourly":
            return 24
        else:
            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")

    @TimeTrackingWrapper
    def apply_filter(self):
        """Apply FIR filter only on inputs."""
        fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq,
                        self.filter_window_type, self.target_dim, self.time_dim, display_name=self.station[0],
                        minimum_length=self.window_history_size, offset=self.window_history_offset,
                        plot_path=self.plot_path, plot_dates=self.plot_dates)
        self.fir_coeff = fir.filter_coefficients
        filter_data = fir.filtered_data
        # stack the filter components along the (new) filter dimension
        input_data = xr.concat(filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
        self.input_data = input_data.sel({self.target_dim: self.variables})

    def create_filter_index(self, add_unfiltered_index=True) -> pd.Index:
        """
        Round cut off periods in days and append 'res' for residuum index.

        Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append
        'res' for residuum index. Add index unfiltered if the raw / unfiltered data is appended to data in addition.
        """
        def _fmt(x):
            # "14d" for long periods, "0.8d" style for short ones
            return int(np.round(x)) if x >= 10 else np.round(x, 1)
        rounded = np.round(self.filter_cutoff_period, 1)
        index = [str(_fmt(x)) + "d" for x in rounded.tolist()] + ["res"]
        self.filter_dim_order = index
        return pd.Index(index, name=self.filter_dim)

    def _create_lazy_data(self):
        """Collect everything needed to restore this handler without recomputing the FIR filter."""
        return [self._data, self.meta, self.input_data, self.target_data, self.fir_coeff, self.filter_dim_order]

    def _extract_lazy(self, lazy_data):
        _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data
        super()._extract_lazy((_data, _meta, _input_data, _target_data))

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=None):
        """
        Transform data according to given transformation settings.

        This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0
        and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
        (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This
        method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the
        internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').

        :param string/int dim: This param is not used for inverse transformation.
            | for xarray.DataArray as string: name of dimension which should be standardised
            | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param inverse: Switch between transformation and inverse transformation.

        :return: xarray.DataArrays or pandas.DataFrames:
            #. mean: Mean of data
            #. std: Standard deviation of data
            #. data: Standardised data
        """
        if transformation_dim is None:
            transformation_dim = self.DEFAULT_TARGET_DIM

        def f(data, method="standardise", feature_range=None):
            # estimate transformation parameters from the data itself (endogenous)
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            # apply externally provided transformation parameters (exogenous)
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    # broadcast scalar parameters to the data's coordinate layout
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            minimum = kwargs.pop('min', None)  # renamed locals, do not shadow builtins min/max
            maximum = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                return statistics.min_max_apply(data, minimum, maximum), {"min": minimum, "max": maximum,
                                                                          "method": method,
                                                                          "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                # parameters provided -> apply them; otherwise estimate from the data
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                transform_func = f_apply if _apply else f  # direct reference instead of locals() lookup
                values, new_var_opts = transform_func(data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)
class DataHandlerFirFilter(DataHandlerFilter):
    """Data handler using FIR filtered data."""

    # plug the FIR-filtering single-station handler in for both data loading and transformation
    data_handler = DataHandlerFirFilterSingleStation
    data_handler_transformation = DataHandlerFirFilterSingleStation
    # expose the station handler's required keyword arguments to the superior handler
    _requirements = data_handler.requirements()
class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation):
    """
    Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered. In contrast to
    the simple DataHandlerFirFilterSingleStation, this data handler is centered around t0 to have no time delay. For
    values in the future (t > t0), this data handler assumes a climatological value for the low pass data and values of
    0 for all residuum components.

    :param apriori: Data to use as apriori information. This should be either a xarray dataarray containing monthly or
        any other heuristic to support the clim filter, or a list of such arrays containing heuristics for all residua
        in addition. The 2nd can be used together with apriori_type `residuum_stats` which estimates the error of the
        residuum when the clim filter should be applied with exogenous parameters. If apriori_type is None/`zeros` data
        can be provided, but this is not required in this case.
    :param apriori_type: set type of information that is provided to the clim filter. For the first low pass always a
        calculated or given statistic is used. For residuum prediction a constant value of zero is assumed if
        apriori_type is None or `zeros`, and a climatology of the residuum is used for `residuum_stats`.
    :param apriori_diurnal: use diurnal anomalies of each hour as addition to the apriori information type chosen by
        parameter apriori_type. This is only applicable for hourly resolution data.
    :param apriori_sel_opts: specify some parameters to select a subset of data before calculating the apriori
        information. Use this parameter for example, if apriori shall only calculated on a shorter time period than
        available in given data.
    :param extend_length_opts: use this parameter to use future data in the filter calculation. This parameter does not
        affect the size of the history samples as this is handled by the window_history_size parameter. Example: set
        extend_length_opts=7*24 to use the observation of the next 7 days to calculate the filtered components. Which
        data are finally used for the input samples is not affected by these 7 days. In case the range of history sample
        exceeds the horizon of extend_length_opts, the history sample will also include data from climatological
        estimates.
    """

    DEFAULT_EXTEND_LENGTH_OPTS = 0
    _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal",
                                                       "extend_length_opts"]
    _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"]

    # NOTE: create_filter_index, _prepare_filter_cutoff_period and _period_to_freq are inherited unchanged
    # from DataHandlerFirFilterSingleStation (they were previously duplicated here verbatim).

    def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None,
                 extend_length_opts=DEFAULT_EXTEND_LENGTH_OPTS, **kwargs):
        self.apriori_type = apriori_type
        self.climate_filter_coeff = None  # coefficients of the used FIR filter
        self.apriori = apriori  # exogenous apriori information or None to calculate from data (endogenous)
        self.apriori_diurnal = apriori_diurnal
        self.all_apriori = None  # collection of all apriori information
        self.apriori_sel_opts = apriori_sel_opts  # ensure to separate exogenous and endogenous information
        self.extend_length_opts = extend_length_opts
        super().__init__(*args, **kwargs)

    @TimeTrackingWrapper
    def apply_filter(self):
        """Apply FIR filter only on inputs."""
        # a dict apriori maps station name -> apriori data; resolve it for this station
        self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori
        logging.info(f"{self.station[0]}: call ClimateFIRFilter")
        climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order,
                                          self.filter_cutoff_freq,
                                          self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim,
                                          apriori_type=self.apriori_type, apriori=self.apriori,
                                          apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts,
                                          plot_path=self.plot_path,
                                          minimum_length=self.window_history_size, new_dim=self.window_dim,
                                          display_name=self.station[0], extend_length_opts=self.extend_length_opts,
                                          extend_end=self.window_history_end, plot_dates=self.plot_dates,
                                          offset=self.window_history_offset)
        self.climate_filter_coeff = climate_filter.filter_coefficients

        # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori
        if self.apriori_type == "residuum_stats":
            self.apriori = climate_filter.apriori_data
        else:
            self.apriori = climate_filter.initial_apriori_data
        self.all_apriori = climate_filter.apriori_data

        # restrict each filter component to the history window range
        climate_filter_data = [c.sel({self.window_dim: slice(self.window_history_end - self.window_history_size,
                                                             self.window_history_end)})
                               for c in climate_filter.filtered_data]

        # create input data with filter index
        input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(add_unfiltered_index=False),
                                                             name=self.filter_dim))

        self.input_data = input_data.sel({self.target_dim: self.variables})

    def _create_lazy_data(self):
        """Collect everything needed to restore this handler without recomputing the climate FIR filter."""
        return [self._data, self.meta, self.input_data, self.target_data, self.climate_filter_coeff,
                self.apriori, self.all_apriori, self.filter_dim_order]

    def _extract_lazy(self, lazy_data):
        _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \
            self.filter_dim_order = lazy_data
        # deliberately bypass DataHandlerFirFilterSingleStation._extract_lazy: its lazy layout (fir_coeff)
        # differs from the climate filter's layout restored above
        DataHandlerSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing history data. As 'input_data' already consists of a dimension 'window', this
        method only shifts the data along 'window' dimension x times where x is given by 'window_history_offset'.
        Results are stored in history attribute.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: this parameter is not used in the inherited method
        :param dim_name_of_shift: Dimension along shift will be applied
        """
        # the climate filter already produced data with a window dimension, so the history is the input itself
        self.history = self.input_data

    def call_transform(self, inverse=False):
        """Transform inputs (over time and window dims) and targets (over time dim), then store updated opts."""
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=[self.time_dim, self.window_dim],
                                                     inverse=inverse, opts=opts_input,
                                                     transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)
class DataHandlerClimateFirFilter(DataHandlerFilter):
    """Data handler using climatic adjusted FIR filtered data."""

    # plug the climate-FIR single-station handler in for both data loading and transformation
    data_handler = DataHandlerClimateFirFilterSingleStation
    data_handler_transformation = DataHandlerClimateFirFilterSingleStation
    # expose the station handler's required kwargs and persisted attributes to the superior handler
    _requirements = data_handler.requirements()
    _store_attributes = data_handler.store_attributes()