Coverage for mlair/data_handler/data_handler_with_filter.py: 23% (239 statements)

1"""Data Handler using kz-filtered data.""" 

2 

3__author__ = 'Lukas Leufen' 

4__date__ = '2020-08-26' 

5 

6import copy 

7import numpy as np 

8import pandas as pd 

9import xarray as xr 

10from typing import List, Union, Tuple, Optional 

11from functools import partial 

12import logging 

13from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation 

14from mlair.data_handler import DefaultDataHandler 

15from mlair.helpers import to_list, TimeTrackingWrapper, statistics 

16from mlair.helpers.filter import FIRFilter, ClimateFIRFilter, omega_null_kzf 

17 

18# define a more general date type for type hinting 

19str_or_list = Union[str, List[str]] 


# Scratch snippet kept for reference: manual FIR low-pass experiment.
# Note that fir_filter and station_data are not defined in this module.
# cutoff_p = [(None, 14), (8, 6), (2, 0.8), (0.8, None)]
# cutoff = list(map(lambda x: (1. / x[0] if x[0] is not None else None,
#                              1. / x[1] if x[1] is not None else None), cutoff_p))
# fs = 24.
# # order = int(60 * fs) + 1
# order = np.array([int(14 * fs) + 1, int(14 * fs) + 1, int(4 * fs) + 1, int(2 * fs) + 1])
# print("cutoff period", cutoff_p)
# print("cutoff", cutoff)
# print("fs", fs)
# print("order", order)
# print("delay", 0.5 * (order - 1) / fs)
# window = ("kaiser", 5)
# # low pass
# y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low=cutoff[0][0],
#                   cutoff_high=cutoff[0][1], window=window)
# filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape)


class DataHandlerFilterSingleStation(DataHandlerSingleStation):
    """General data handler for a single station to be used by a superior data handler."""

    _hash = DataHandlerSingleStation._hash + ["filter_dim"]

    DEFAULT_FILTER_DIM = "filter"

    def __init__(self, *args, filter_dim=DEFAULT_FILTER_DIM, **kwargs):
        # self.original_data = None  # ToDo: implement here something to store unfiltered data
        self.filter_dim = filter_dim
        self.filter_dim_order = None
        super().__init__(*args, **kwargs)

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Adjust the transformation setup because filtered data can contain negative values, which are not compatible
        with the log transformation. Therefore, replace all log (and min_max) transformation methods with a default
        standardisation. This is only applied on the input side.
        """
        transformation = super(__class__, self).setup_transformation(transformation)
        if transformation[0] is not None:
            for k, v in transformation[0].items():
                if v["method"] in ("log", "min_max"):
                    transformation[0][k]["method"] = "standardise"
        return transformation
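
    # Hedged illustration (assumption, not part of the original source): an input-side
    # setting such as {"o3": {"method": "log"}} leaves this method as
    # {"o3": {"method": "standardise"}}, while the target-side settings (second element
    # of the returned tuple) pass through unchanged.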

    def _check_sampling(self, **kwargs):
        # ToDo: check whether hourly resolution is really a hard requirement
        assert kwargs.get("sampling") == "hourly", "This data handler requires data in hourly resolution."

    def make_input_target(self):
        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit)
        self.set_inputs_and_targets()
        self.apply_filter()
        # this is just a code snippet to check the results of the kz filter
        # import matplotlib
        # matplotlib.use("TkAgg")
        # import matplotlib.pyplot as plt
        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")

    def apply_filter(self):
        raise NotImplementedError

    def create_filter_index(self) -> pd.Index:
        """Create name for filter dimension."""
        raise NotImplementedError

    def get_transposed_history(self) -> xr.DataArray:
        """Return history.

        :return: history with dimensions datetime, window, Stations, variables, filter.
        """
        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
                                      self.filter_dim).copy()

    def _create_lazy_data(self):
        raise NotImplementedError

    def _extract_lazy(self, lazy_data):
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))


class DataHandlerFilter(DefaultDataHandler):
    """Data handler using FIR filtered data."""

    data_handler = DataHandlerFilterSingleStation
    data_handler_transformation = DataHandlerFilterSingleStation
    _requirements = data_handler.requirements()

    def __init__(self, *args, use_filter_branches=False, **kwargs):
        self.use_filter_branches = use_filter_branches
        super().__init__(*args, **kwargs)

    def get_X_original(self):
        if self.use_filter_branches is True:
            X = []
            for data in self._collection:
                if hasattr(data, "filter_dim"):
                    X_total = data.get_X()
                    filter_dim = data.filter_dim
                    for filter_name in data.filter_dim_order:
                        X.append(X_total.sel({filter_dim: filter_name}, drop=True))
                else:
                    X.append(data.get_X())
            return X
        else:
            return super().get_X_original()
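
    # Hedged illustration (assumption, not part of the original source): with
    # filter_dim_order = ["21d", "1.6d", "res"], each collection entry contributes three
    # elements to X, one per filter component with the filter dimension dropped, e.g.
    #     X_branch = X_total.sel({filter_dim: "res"}, drop=True)
    # so a downstream model can define one input branch per filter component.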


class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
    """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered."""

    _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"]

    DEFAULT_WINDOW_TYPE = ("kaiser", 5)

    def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE,
                 plot_path=None, filter_plot_dates=None, **kwargs):
        # self.original_data = None  # ToDo: implement here something to store unfiltered data
        self.fs = self._get_fs(**kwargs)
        if filter_window_type == "kzf":
            filter_cutoff_period = self._get_kzf_cutoff_period(filter_order, self.fs)
        self.filter_cutoff_period, removed_index = self._prepare_filter_cutoff_period(filter_cutoff_period, self.fs)
        self.filter_cutoff_freq = self._period_to_freq(self.filter_cutoff_period)
        assert len(self.filter_cutoff_period) == (len(filter_order) - len(removed_index))
        self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs)
        self.filter_window_type = filter_window_type
        self.unfiltered_name = "unfiltered"
        self.plot_path = plot_path  # use this path to create insight plots
        self.plot_dates = filter_plot_dates
        super().__init__(*args, **kwargs)

    @staticmethod
    def _prepare_filter_order(filter_order, removed_index, fs):
        order = []
        for i, o in enumerate(filter_order):
            if i not in removed_index:
                if isinstance(o, tuple):  # kzf settings (m, k): scale window length m to samples
                    fo = (o[0] * fs, o[1])
                else:  # plain FIR order in days: scale to samples and force an odd order
                    fo = int(o * fs)
                    fo = fo + 1 if fo % 2 == 0 else fo
                order.append(fo)
        return order
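
    # Hedged worked example (assumption, not part of the original source): with hourly
    # data (fs = 24), filter_order = [14, 2], and nothing removed, the orders scale to
    # samples and are forced odd so the FIR filter has an integer group delay:
    #     >>> DataHandlerFirFilterSingleStation._prepare_filter_order([14, 2], [], 24)
    #     [337, 49]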

    @staticmethod
    def _prepare_filter_cutoff_period(filter_cutoff_period, fs):
        """Frequency must be smaller than the sampling frequency fs. Otherwise, remove the given cutoff period."""
        cutoff = []
        removed = []
        for i, period in enumerate(to_list(filter_cutoff_period)):
            if period > 2. / fs:  # keep only periods above the Nyquist period of 2 / fs
                cutoff.append(period)
            else:
                removed.append(i)
        return cutoff, removed
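
    # Hedged worked example (assumption, not part of the original source): for hourly
    # data the Nyquist period is 2 / 24 day (2 hours), so a 0.05-day period is dropped:
    #     >>> DataHandlerFirFilterSingleStation._prepare_filter_cutoff_period([21, 0.05], 24)
    #     ([21], [1])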

    @staticmethod
    def _get_kzf_cutoff_period(kzf_settings, fs):
        cutoff = []
        for (m, k) in kzf_settings:
            w0 = omega_null_kzf(m * fs, k) * fs
            cutoff.append(1. / w0)
        return cutoff

    @staticmethod
    def _period_to_freq(cutoff_p):
        return [1. / x for x in cutoff_p]

    @staticmethod
    def _get_fs(**kwargs):
        """Return the sampling frequency in 1/day (not Hz)."""
        sampling = kwargs.get("sampling")
        if sampling == "daily":
            return 1
        elif sampling == "hourly":
            return 24
        else:
            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")
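
    # Hedged doctest-style example (assumption, not part of the original source):
    #     >>> DataHandlerFirFilterSingleStation._get_fs(sampling="hourly")
    #     24
    # Because fs is expressed in 1/day, the Nyquist period 2 / fs used above equals
    # 2 days for daily data and 2 hours for hourly data.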

    @TimeTrackingWrapper
    def apply_filter(self):
        """Apply FIR filter only on inputs."""
        fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq,
                        self.filter_window_type, self.target_dim, self.time_dim, display_name=self.station[0],
                        minimum_length=self.window_history_size, offset=self.window_history_offset,
                        plot_path=self.plot_path, plot_dates=self.plot_dates)
        self.fir_coeff = fir.filter_coefficients
        filter_data = fir.filtered_data
        input_data = xr.concat(filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
        self.input_data = input_data.sel({self.target_dim: self.variables})
        # this is just a code snippet to check the results of the kz filter
        # import matplotlib
        # matplotlib.use("TkAgg")
        # import matplotlib.pyplot as plt
        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
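
    # Hedged note (assumption, not part of the original source): after the concat, the
    # input data carries the new filter dimension. For cutoff periods [21, 1.6] the
    # coordinate holds ["21d", "1.6d", "res"], and a single component can be selected via
    #     self.input_data.sel({self.filter_dim: "res"})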

    def create_filter_index(self, add_unfiltered_index=True) -> pd.Index:
        """
        Round cutoff periods in days and append 'res' for the residuum index.

        Round small numbers (< 10) to a single decimal and larger numbers to int. Transform to a list of str and
        append 'res' for the residuum index. Add an 'unfiltered' index if the raw / unfiltered data is appended to
        the data in addition.
        """
        index = np.round(self.filter_cutoff_period, 1)
        f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
        index = list(map(f, index.tolist()))
        index = list(map(lambda x: str(x) + "d", index)) + ["res"]
        self.filter_dim_order = index
        return pd.Index(index, name=self.filter_dim)
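
    # Hedged worked example (assumption, not part of the original source): for
    # filter_cutoff_period = [21.0, 1.63] this yields the index ["21d", "1.6d", "res"],
    # i.e. one label per cutoff period plus the residuum.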

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data, self.fir_coeff, self.filter_dim_order]

    def _extract_lazy(self, lazy_data):
        _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data
        super()._extract_lazy((_data, _meta, _input_data, _target_data))

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=None):
        """
        Transform data according to the given transformation settings.

        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either to mean=0
        and std=1 (`method=standardise`) or centres the data to mean=0 without changing the data scale
        (`method=centre`). Furthermore, it sets an internal instance attribute for the later inverse transformation.
        This method raises an AssertionError if an internal transform method was already set ('inverse=False') or if
        the internal transform method, internal mean, and internal standard deviation were not set ('inverse=True').

        :param string/int dim: This parameter is not used for the inverse transformation.
            | for xarray.DataArray as string: name of the dimension to standardise along
            | for pandas.DataFrame as int: axis to standardise along
        :param inverse: Switch between transformation and inverse transformation.

        :return: xarray.DataArrays or pandas.DataFrames:
            #. mean: mean of the data
            #. std: standard deviation of the data
            #. data: standardised data
        """

        if transformation_dim is None:
            transformation_dim = self.DEFAULT_TARGET_DIM

        def f(data, method="standardise", feature_range=None):
            # fit a new transformation on the data
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            # apply an already fitted transformation with the given parameters
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            min = kwargs.pop('min', None)
            max = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                return statistics.min_max_apply(data, min, max), {"min": min, "max": max, "method": method,
                                                                  "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                transform_fn = f_apply if _apply else f
                values, new_var_opts = transform_fn(data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)
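
    # Hedged usage sketch (assumption, not part of the original source), mirroring
    # call_transform in DataHandlerClimateFirFilterSingleStation below:
    #     data, opts = dh.transform(dh.input_data, dim=[dh.time_dim, dh.window_dim],
    #                               opts={"o3": {"method": "standardise"}},
    #                               transformation_dim=dh.target_dim)
    # Afterwards, opts["o3"] carries the fitted "mean" and "std" DataArrays, which can
    # be passed back in to re-apply the same transformation to new data.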


class DataHandlerFirFilter(DataHandlerFilter):
    """Data handler using FIR filtered data."""

    data_handler = DataHandlerFirFilterSingleStation
    data_handler_transformation = DataHandlerFirFilterSingleStation
    _requirements = data_handler.requirements()


class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation):
    """
    Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered. In contrast to
    the simple DataHandlerFirFilterSingleStation, this data handler is centred around t0 and therefore has no time
    delay. For values in the future (t > t0), this data handler assumes a climatological value for the low-pass data
    and values of 0 for all residuum components.

    :param apriori: Data to use as apriori information. This should be either an xarray DataArray containing monthly
        (or any other) heuristics to support the clim filter, or a list of such arrays additionally containing
        heuristics for all residua. The latter can be used together with apriori_type `residuum_stats`, which
        estimates the error of the residuum when the clim filter is applied with exogenous parameters. If
        apriori_type is None or `zeros`, data can be provided but is not required.
    :param apriori_type: Set the type of information that is provided to the clim filter. For the first low pass, a
        calculated or given statistic is always used. For the residuum prediction, a constant value of zero is
        assumed if apriori_type is None or `zeros`, and a climatology of the residuum is used for `residuum_stats`.
    :param apriori_diurnal: Use diurnal anomalies of each hour in addition to the apriori information type chosen by
        the parameter apriori_type. This is only applicable for hourly resolved data.
    :param apriori_sel_opts: Specify parameters to select a subset of the data before calculating the apriori
        information. Use this parameter, for example, if the apriori shall only be calculated on a shorter time
        period than is available in the given data.
    :param extend_length_opts: Use this parameter to include future data in the filter calculation. It does not
        affect the size of the history samples, as this is handled by the window_history_size parameter. Example:
        set extend_length_opts=7*24 to use the observations of the next 7 days to calculate the filtered components.
        Which data are finally used for the input samples is not affected by these 7 days. In case the range of a
        history sample exceeds the horizon of extend_length_opts, the history sample will also include data from
        climatological estimates.
    """

    DEFAULT_EXTEND_LENGTH_OPTS = 0
    _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal",
                                                       "extend_length_opts"]
    _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"]

    def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None,
                 extend_length_opts=DEFAULT_EXTEND_LENGTH_OPTS, **kwargs):
        self.apriori_type = apriori_type
        self.climate_filter_coeff = None  # coefficients of the used FIR filter
        self.apriori = apriori  # exogenous apriori information, or None to calculate it from the data (endogenous)
        self.apriori_diurnal = apriori_diurnal
        self.all_apriori = None  # collection of all apriori information
        self.apriori_sel_opts = apriori_sel_opts  # ensure to separate exogenous and endogenous information
        self.extend_length_opts = extend_length_opts
        super().__init__(*args, **kwargs)

    @TimeTrackingWrapper
    def apply_filter(self):
        """Apply the climate FIR filter on inputs only."""
        self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori
        logging.info(f"{self.station[0]}: call ClimateFIRFilter")
        climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order,
                                          self.filter_cutoff_freq,
                                          self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim,
                                          apriori_type=self.apriori_type, apriori=self.apriori,
                                          apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts,
                                          plot_path=self.plot_path,
                                          minimum_length=self.window_history_size, new_dim=self.window_dim,
                                          display_name=self.station[0], extend_length_opts=self.extend_length_opts,
                                          extend_end=self.window_history_end, plot_dates=self.plot_dates,
                                          offset=self.window_history_offset)
        self.climate_filter_coeff = climate_filter.filter_coefficients

        # store apriori information: store all if the residuum_stats method was used, otherwise store only the
        # initial apriori
        if self.apriori_type == "residuum_stats":
            self.apriori = climate_filter.apriori_data
        else:
            self.apriori = climate_filter.initial_apriori_data
        self.all_apriori = climate_filter.apriori_data

        climate_filter_data = [c.sel({self.window_dim: slice(self.window_history_end - self.window_history_size,
                                                             self.window_history_end)})
                               for c in climate_filter.filtered_data]

        # create input data with filter index
        input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(add_unfiltered_index=False),
                                                             name=self.filter_dim))

        self.input_data = input_data.sel({self.target_dim: self.variables})

        # this is just a code snippet to check the results of the filter
        # import matplotlib
        # matplotlib.use("TkAgg")
        # import matplotlib.pyplot as plt
        # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot()
        # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")

    def create_filter_index(self, add_unfiltered_index=True) -> pd.Index:
        """
        Round cutoff periods in days and append 'res' for the residuum index.

        Round small numbers (< 10) to a single decimal and larger numbers to int. Transform to a list of str and
        append 'res' for the residuum index. Add an 'unfiltered' index if the raw / unfiltered data is appended to
        the data in addition.
        """
        # note: identical to DataHandlerFirFilterSingleStation.create_filter_index
        index = np.round(self.filter_cutoff_period, 1)
        f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
        index = list(map(f, index.tolist()))
        index = list(map(lambda x: str(x) + "d", index)) + ["res"]
        self.filter_dim_order = index
        return pd.Index(index, name=self.filter_dim)

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data, self.climate_filter_coeff,
                self.apriori, self.all_apriori, self.filter_dim_order]

    def _extract_lazy(self, lazy_data):
        _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori, \
            self.filter_dim_order = lazy_data
        DataHandlerSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))

    @staticmethod
    def _prepare_filter_cutoff_period(filter_cutoff_period, fs):
        """Frequency must be smaller than the sampling frequency fs. Otherwise, remove the given cutoff period."""
        # note: identical to DataHandlerFirFilterSingleStation._prepare_filter_cutoff_period
        cutoff = []
        removed = []
        for i, period in enumerate(to_list(filter_cutoff_period)):
            if period > 2. / fs:
                cutoff.append(period)
            else:
                removed.append(i)
        return cutoff, removed

    @staticmethod
    def _period_to_freq(cutoff_p):
        return [1. / x for x in cutoff_p]

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create an xr.DataArray containing the history data. As 'input_data' already consists of a 'window'
        dimension, this method only shifts the data along the 'window' dimension x times, where x is given by
        'window_history_offset'. Results are stored in the history attribute.

        :param dim_name_of_inputs: name of the dimension which contains the input variables
        :param window: this parameter is not used in the inherited method
        :param dim_name_of_shift: dimension along which the shift is applied
        """
        self.history = self.input_data
        # code snippet to compare the filtered history against the raw data:
        # from matplotlib import pyplot as plt
        # d = self.load_and_interpolate(0)
        # data.sel(datetime="2007-07-07 00:00").sum("filter").plot()
        # plt.plot(data.sel(datetime="2007-07-07 00:00").sum("filter").window.values,
        #          d.sel(datetime=slice("2007-07-05 00:00", "2007-07-07 16:00")).values.flatten())
        # plt.plot(data.sel(datetime="2007-07-07 00:00").sum("filter").window.values,
        #          d.sel(datetime=slice("2007-07-05 00:00", "2007-07-11 16:00")).values.flatten())

    def call_transform(self, inverse=False):
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=[self.time_dim, self.window_dim],
                                                     inverse=inverse, opts=opts_input,
                                                     transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)


class DataHandlerClimateFirFilter(DataHandlerFilter):
    """Data handler using climatically adjusted FIR filtered data."""

    data_handler = DataHandlerClimateFirFilterSingleStation
    data_handler_transformation = DataHandlerClimateFirFilterSingleStation
    _requirements = data_handler.requirements()
    _store_attributes = data_handler.store_attributes()