Coverage for mlair/data_handler/data_handler_single_station.py: 64%
378 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-30 10:40 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-30 10:40 +0000
1"""Data Preparation class to handle data processing for machine learning."""
3__author__ = 'Lukas Leufen, Felix Kleinert'
4__date__ = '2020-07-20'
6import copy
7import datetime as dt
8import gc
10import dill
11import hashlib
12import logging
13import os
14import ast
15from functools import reduce, partial
16from typing import Union, List, Iterable, Tuple, Dict, Optional
18import numpy as np
19import pandas as pd
20import xarray as xr
22from mlair.configuration import check_path_and_create
23from mlair import helpers
24from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict
25from mlair.data_handler.abstract_data_handler import AbstractDataHandler
26from mlair.helpers import data_sources, check_nested_equality
28# define a more general date type for type hinting
29date = Union[dt.date, dt.datetime]
30str_or_list = Union[str, List[str]]
31number = Union[float, int]
32num_or_list = Union[number, List[number]]
33data_or_none = Union[xr.DataArray, None]
class DataHandlerSingleStation(AbstractDataHandler):
    """
    Data handler preparing the time series of a single measurement station for machine learning.

    Loads (or downloads) raw data, interpolates gaps, optionally transforms the data, and creates
    history windows (X), labels (Y) and observations along the time dimension.

    :param window_history_offset: used to shift t0 according to the specified value.
    :param window_history_end: used to set the last time step that is used to create a sample. A negative value
        indicates that not all values up to t0 are used, a positive values indicates usage of values at t>t0. Default
        is 0.
    """
    # fallback mapping variable -> aggregation statistic, used when no statistics_per_var is given
    DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                            'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                            'pblheight': 'maximum'}
    # default window / dimension / sampling configuration
    DEFAULT_WINDOW_LEAD_TIME = 3
    DEFAULT_WINDOW_HISTORY_SIZE = 13
    DEFAULT_WINDOW_HISTORY_OFFSET = 0
    DEFAULT_WINDOW_HISTORY_END = 0
    DEFAULT_TIME_DIM = "datetime"
    DEFAULT_TARGET_VAR = "o3"
    DEFAULT_TARGET_DIM = "variables"
    DEFAULT_ITER_DIM = "Stations"
    DEFAULT_WINDOW_DIM = "window"
    DEFAULT_SAMPLING = "daily"
    DEFAULT_INTERPOLATION_LIMIT = 0
    DEFAULT_INTERPOLATION_METHOD = "linear"
    # chemical species that are clipped to non-negative values in check_for_negative_concentrations
    chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane",
                 "so2", "toluene"]

    # attributes whose string representations define the hash used as lazy-cache file name (see _get_hash)
    _hash = ["station", "statistics_per_var", "data_origin", "sampling", "target_dim", "target_var", "time_dim",
             "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time",
             "interpolation_limit", "interpolation_method", "variables", "window_history_end"]
65 def __init__(self, station, data_path, statistics_per_var=None, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
66 target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
67 iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM,
68 window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET,
69 window_history_end=DEFAULT_WINDOW_HISTORY_END, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
70 interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT,
71 interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
72 overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
73 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
74 lazy_preprocessing: bool = False, overwrite_lazy_data=False, era5_data_path=None, era5_file_names=None,
75 ifs_data_path=None, ifs_file_names=None, **kwargs):
76 super().__init__()
77 self.station = helpers.to_list(station)
78 self.path = self.setup_data_path(data_path, sampling)
79 self.lazy = lazy_preprocessing
80 self.lazy_path = None
81 if self.lazy is True: 81 ↛ 82line 81 didn't jump to line 82, because the condition on line 81 was never true
82 self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__)
83 check_path_and_create(self.lazy_path)
84 self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT
85 self.data_origin = data_origin
86 self.do_transformation = transformation is not None
87 self.input_data, self.target_data = None, None
88 self._transformation = self.setup_transformation(transformation)
90 self.sampling = sampling
91 self.target_dim = target_dim
92 self.target_var = target_var
93 self.time_dim = time_dim
94 self.iter_dim = iter_dim
95 self.window_dim = window_dim
96 self.window_history_size = window_history_size
97 self.window_history_offset = window_history_offset
98 self.window_history_end = window_history_end
99 self.window_lead_time = window_lead_time
101 self.interpolation_limit = interpolation_limit
102 self.interpolation_method = interpolation_method
104 self.overwrite_local_data = overwrite_local_data
105 self.overwrite_lazy_data = True if self.overwrite_local_data is True else overwrite_lazy_data
106 self.store_data_locally = store_data_locally
107 self.min_length = min_length
108 self.start = start
109 self.end = end
111 # internal
112 self._data: xr.DataArray = None # loaded raw data
113 self.meta = None
114 self.variables = sorted(list(statistics_per_var.keys())) if variables is None else variables
115 self.history = None
116 self.label = None
117 self.observation = None
119 self._era5_data_path = era5_data_path
120 self._era5_file_names = era5_file_names
121 self._ifs_data_path = ifs_data_path
122 self._ifs_file_names = ifs_file_names
124 # create samples
125 self.setup_samples()
126 self.clean_up()
128 def clean_up(self):
129 self._data = None
130 self.input_data = None
131 self.target_data = None
132 gc.collect()
134 def __str__(self):
135 return self.station[0]
137 def __len__(self):
138 assert len(self.get_X()) == len(self.get_Y())
139 return len(self.get_X())
141 @property
142 def shape(self):
143 return self._data.shape, self.get_X().shape, self.get_Y().shape
145 def __repr__(self):
146 return f"StationPrep(station={self.station}, data_path='{self.path}', data_origin={self.data_origin}, " \
147 f"statistics_per_var={self.statistics_per_var}, " \
148 f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
149 f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
150 f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
151 f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data})"
153 def get_transposed_history(self) -> xr.DataArray:
154 """Return history.
156 :return: history with dimensions datetime, window, Stations, variables.
157 """
158 return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim).copy()
160 def get_transposed_label(self) -> xr.DataArray:
161 """Return label.
163 :return: label with dimensions datetime*, window*, Stations, variables.
164 """
165 return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim).copy()
167 def get_X(self, **kwargs):
168 return self.get_transposed_history().sel({self.target_dim: self.variables})
170 def get_Y(self, **kwargs):
171 return self.get_transposed_label()
173 def get_coordinates(self):
174 try:
175 coords = self.meta.loc[["station_lon", "station_lat"]].astype(float)
176 coords = coords.rename(index={"station_lon": "lon", "station_lat": "lat"})
177 except KeyError:
178 coords = self.meta.loc[["lon", "lat"]].astype(float)
179 return coords.to_dict()[str(self)]
181 def call_transform(self, inverse=False):
182 opts_input = self._transformation[0]
183 self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
184 opts=opts_input, transformation_dim=self.target_dim)
185 opts_target = self._transformation[1]
186 self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
187 opts=opts_target, transformation_dim=self.target_dim)
188 self._transformation = (opts_input, opts_target)
190 def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
191 transformation_dim=DEFAULT_TARGET_DIM):
192 """
193 Transform data according to given transformation settings.
195 This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0
196 and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
197 (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This
198 method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the
199 internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').
201 :param string/int dim: This param is not used for inverse transformation.
202 | for xarray.DataArray as string: name of dimension which should be standardised
203 | for pandas.DataFrame as int: axis of dimension which should be standardised
204 :param inverse: Switch between transformation and inverse transformation.
206 :return: xarray.DataArrays or pandas.DataFrames:
207 #. mean: Mean of data
208 #. std: Standard deviation of data
209 #. data: Standardised data
210 """
212 def f(data, method="standardise", feature_range=None):
213 if method == "standardise": 213 ↛ 215line 213 didn't jump to line 215, because the condition on line 213 was never false
214 return statistics.standardise(data, dim)
215 elif method == "centre":
216 return statistics.centre(data, dim)
217 elif method == "min_max":
218 kwargs = {"feature_range": feature_range} if feature_range is not None else {}
219 return statistics.min_max(data, dim, **kwargs)
220 elif method == "log":
221 return statistics.log(data, dim)
222 else:
223 raise NotImplementedError
225 def f_apply(data, method, **kwargs):
226 for k, v in kwargs.items():
227 if not (isinstance(v, xr.DataArray) or v is None): 227 ↛ 228line 227 didn't jump to line 228, because the condition on line 227 was never true
228 _, opts = statistics.min_max(data, dim)
229 helper = xr.ones_like(opts['min'])
230 kwargs[k] = helper * v
231 mean = kwargs.pop('mean', None)
232 std = kwargs.pop('std', None)
233 min = kwargs.pop('min', None)
234 max = kwargs.pop('max', None)
235 feature_range = kwargs.pop('feature_range', None)
237 if method == "standardise": 237 ↛ 239line 237 didn't jump to line 239, because the condition on line 237 was never false
238 return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
239 elif method == "centre":
240 return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
241 elif method == "min_max":
242 kws = {"feature_range": feature_range} if feature_range is not None else {}
243 return statistics.min_max_apply(data, min, max, **kws), {"min": min, "max": max, "method": method,
244 "feature_range": feature_range}
245 elif method == "log":
246 return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
247 else:
248 raise NotImplementedError
250 opts = opts or {}
251 opts_updated = {}
252 if not inverse: 252 ↛ 263line 252 didn't jump to line 263, because the condition on line 252 was never false
253 transformed_values = []
254 for var in data_in.variables.values:
255 data_var = data_in.sel(**{transformation_dim: [var]})
256 var_opts = opts.get(var, {})
257 _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
258 values, new_var_opts = locals()["f_apply" if _apply else "f"](data_var, **var_opts)
259 opts_updated[var] = copy.deepcopy(new_var_opts)
260 transformed_values.append(values)
261 return xr.concat(transformed_values, dim=transformation_dim), opts_updated
262 else:
263 return self.inverse_transform(data_in, opts, transformation_dim)
265 @TimeTrackingWrapper
266 def setup_samples(self):
267 """
268 Setup samples. This method prepares and creates samples X, and labels Y.
269 """
270 if self.lazy is False: 270 ↛ 273line 270 didn't jump to line 273, because the condition on line 270 was never false
271 self.make_input_target()
272 else:
273 self.load_lazy()
274 self.store_lazy()
275 if self.do_transformation is True:
276 self.call_transform()
277 self.make_samples()
279 def store_lazy(self):
280 hash = self._get_hash()
281 filename = os.path.join(self.lazy_path, hash + ".pickle")
282 if not os.path.exists(filename):
283 dill.dump(self._create_lazy_data(), file=open(filename, "wb"), protocol=4)
285 def _create_lazy_data(self):
286 return [self._data, self.meta, self.input_data, self.target_data]
288 def load_lazy(self):
289 hash = self._get_hash()
290 filename = os.path.join(self.lazy_path, hash + ".pickle")
291 try:
292 if self.overwrite_lazy_data is True:
293 os.remove(filename)
294 raise FileNotFoundError
295 with open(filename, "rb") as pickle_file:
296 lazy_data = dill.load(pickle_file)
297 self._extract_lazy(lazy_data)
298 logging.debug(f"{self.station[0]}: used lazy data")
299 except FileNotFoundError:
300 logging.debug(f"{self.station[0]}: could not use lazy data")
301 self.make_input_target()
303 def _extract_lazy(self, lazy_data):
304 _data, self.meta, _input_data, _target_data = lazy_data
305 f_prep = partial(self._slice_prep, start=self.start, end=self.end)
306 self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
308 def make_input_target(self):
309 vars = [*self.variables, self.target_var]
310 stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars)
311 data_origin = helpers.select_from_dict(self.data_origin, vars)
312 data, self.meta = self.load_data(self.path, self.station, stats_per_var, self.sampling,
313 self.store_data_locally, data_origin, self.start, self.end)
314 self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
315 limit=self.interpolation_limit, sampling=self.sampling)
316 self.set_inputs_and_targets()
318 def set_inputs_and_targets(self):
319 inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
320 targets = self._data.sel(
321 {self.target_dim: helpers.to_list(self.target_var)}) # ToDo: is it right to expand this dim??
322 self.input_data = inputs
323 self.target_data = targets
    def make_samples(self):
        """Create history windows (X), labels (Y) and observations, then drop time steps containing NaNs."""
        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)
        self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
        self.make_observation(self.target_dim, self.target_var, self.time_dim)
        self.remove_nan(self.time_dim)
    def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False,
                  data_origin: Dict = None, start=None, end=None):
        """
        Load data and meta data either from local disk (preferred) or download new data by using a custom download method.

        Data is either downloaded, if no local data is available or parameter overwrite_local_data is true. In both
        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
        set, it is assumed, that data should be saved locally.

        :return: tuple of (data, meta), where data is sliced to [start, end] and negative concentrations are clipped
        """
        check_path_and_create(path)
        file_name = self._set_file_name(path, station, statistics_per_var)
        meta_file = self._set_meta_file_name(path, station, statistics_per_var)
        if self.overwrite_local_data is True:
            # forced refresh: drop any cached files and download from scratch
            logging.debug(f"{self.station[0]}: overwrite_local_data is true, therefore reload {file_name}")
            if os.path.exists(file_name):
                os.remove(file_name)
            if os.path.exists(meta_file):
                os.remove(meta_file)
            data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                    store_data_locally=store_data_locally, data_origin=data_origin,
                                                    time_dim=self.time_dim, target_dim=self.target_dim,
                                                    iter_dim=self.iter_dim, window_dim=self.window_dim,
                                                    era5_data_path=self._era5_data_path,
                                                    era5_file_names=self._era5_file_names,
                                                    ifs_data_path=self._ifs_data_path,
                                                    ifs_file_names=self._ifs_file_names)
            logging.debug(f"{self.station[0]}: loaded new data")
        else:
            try:
                logging.debug(f"{self.station[0]}: try to load local data from: {file_name}")
                data = xr.open_dataarray(file_name)
                meta = pd.read_csv(meta_file, index_col=0)
                # raises FileNotFoundError on mismatch to force a fresh download in the except branch
                self.check_station_meta(meta, station, data_origin, statistics_per_var)
                logging.debug(f"{self.station[0]}: loading finished")
            except FileNotFoundError as e:
                logging.debug(f"{self.station[0]}: {e}")
                logging.debug(f"{self.station[0]}: load new data")
                # NOTE(review): unlike the call in the overwrite branch above, this call does not pass
                # window_dim to download_data — confirm whether this asymmetry is intended.
                data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                        store_data_locally=store_data_locally, data_origin=data_origin,
                                                        time_dim=self.time_dim, target_dim=self.target_dim,
                                                        iter_dim=self.iter_dim, era5_data_path=self._era5_data_path,
                                                        era5_file_names=self._era5_file_names,
                                                        ifs_data_path=self._ifs_data_path,
                                                        ifs_file_names=self._ifs_file_names)
                logging.debug(f"{self.station[0]}: loading finished")
        # create slices and check for negative concentration.
        data = self._slice_prep(data, start=start, end=end)
        data = self.check_for_negative_concentrations(data)
        return data, meta
    @staticmethod
    def check_station_meta(meta, station, data_origin, statistics_per_var):
        """
        Search for the entries in meta data and compare the value with the requested values.

        Will raise a FileNotFoundError if the values mismatch.

        :param meta: locally cached meta data (pandas DataFrame, keys in the index, stations as columns)
        :param station: list of station names; only the first entry is checked
        :param data_origin: requested data origin mapping (or None to skip this check)
        :param statistics_per_var: requested statistics mapping (or None to skip this check)
        """
        check_dict = {"data_origin": data_origin, "statistics_per_var": statistics_per_var}
        for (k, v) in check_dict.items():
            # skip checks that were not requested or that are not present in the cached meta data
            if v is None or k not in meta.index:
                continue
            # cached values are stored as their string repr, so parse them back into Python objects
            m = ast.literal_eval(meta.at[k, station[0]])
            # only compare the keys that were requested; extra local entries are ignored
            if not check_nested_equality(select_from_dict(m, v.keys()), v):
                logging.debug(f"{station[0]}: meta data does not agree with given request for {k}: {v} (requested) != "
                              f"{m} (local). Raise FileNotFoundError to trigger new grapping from web.")
                raise FileNotFoundError
398 def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
399 """
400 Set all negative concentrations to zero.
402 Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
403 #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
404 "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
406 :param data: data array containing variables to check
407 :param minimum: minimum value, by default this should be 0
409 :return: corrected data
410 """
411 used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values))
412 if len(used_chem_vars) > 0: 412 ↛ 414line 412 didn't jump to line 414, because the condition on line 412 was never false
413 data = data.sel({self.target_dim: used_chem_vars}).clip(min=minimum).combine_first(data)
414 return data
416 def setup_data_path(self, data_path: str, sampling: str):
417 return os.path.join(os.path.abspath(data_path), sampling)
419 def shift(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray:
420 """
421 Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
423 :param data: data set to shift
424 :param dim: dimension along shift is applied
425 :param window: number of steps to shift (corresponds to the window length)
426 :param offset: use offset to move the window by as many time steps as given in offset. This can be used, if the
427 index time of a history element is not the last timestamp. E.g. you could use offset=23 when dealing with
428 hourly data in combination with daily data (values from 00 to 23 are aggregated on 00 the same day).
430 :return: shifted data
431 """
432 start = 1
433 end = 1
434 if window <= 0:
435 start = window
436 else:
437 end = window + 1
438 res = []
439 _range = list(map(lambda x: x + offset, range(start, end)))
440 for w in _range:
441 res.append(data.shift({dim: -w}))
442 window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim)
443 res = xr.concat(res, dim=window_array)
444 return res
446 @staticmethod
447 def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
448 """
449 Create an 1D xr.DataArray with given index name and value.
451 :param index_name: name of dimension
452 :param index_value: values of this dimension
454 :return: this array
455 """
456 ind = pd.DataFrame({'val': index_value}, index=index_value)
457 res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
458 dim=squeeze_dim, drop=True)
459 res.name = index_name
460 return res
462 @staticmethod
463 def _set_file_name(path, station, statistics_per_var):
464 all_vars = sorted(statistics_per_var.keys())
465 return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")
467 @staticmethod
468 def _set_meta_file_name(path, station, statistics_per_var):
469 all_vars = sorted(statistics_per_var.keys())
470 return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
472 def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None,
473 use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs):
474 """
475 Interpolate values according to different methods.
477 (Copy paste from dataarray.interpolate_na)
479 :param dim:
480 Specifies the dimension along which to interpolate.
481 :param method:
482 {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
483 'polynomial', 'barycentric', 'krog', 'pchip',
484 'spline', 'akima'}, optional
485 String indicating which method to use for interpolation:
487 - 'linear': linear interpolation (Default). Additional keyword
488 arguments are passed to ``numpy.interp``
489 - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
490 'polynomial': are passed to ``scipy.interpolate.interp1d``. If
491 method=='polynomial', the ``order`` keyword argument must also be
492 provided.
493 - 'barycentric', 'krog', 'pchip', 'spline', and `akima`: use their
494 respective``scipy.interpolate`` classes.
495 :param limit:
496 default None
497 Maximum number of consecutive NaNs to fill. Must be greater than 0
498 or None for no limit.
499 :param use_coordinate:
500 default True
501 Specifies which index to use as the x values in the interpolation
502 formulated as `y = f(x)`. If False, values are treated as if
503 eqaully-spaced along `dim`. If True, the IndexVariable `dim` is
504 used. If use_coordinate is a string, it specifies the name of a
505 coordinate variariable to use as the index.
506 :param kwargs:
508 :return: xarray.DataArray
509 """
510 data = self.create_full_time_dim(data, dim, sampling)
511 return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)
    @staticmethod
    def create_full_time_dim(data, dim, sampling):
        """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped."""
        start = data.coords[dim].values[0]
        end = data.coords[dim].values[-1]
        # NOTE(review): an unknown sampling value yields freq=None here — confirm date_range behaviour is acceptable
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
        # template slice used only to copy the non-time coordinates into the new, fully indexed array
        t = data.sel({dim: start}, drop=True)
        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        # fill existing values; time steps that were missing in `data` remain NaN
        res.loc[data.coords] = data
        return res
526 def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
527 """
528 Create a xr.DataArray containing history data.
530 Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
531 data. This is used to represent history in the data. Results are stored in history attribute.
533 :param dim_name_of_inputs: Name of dimension which contains the input variables
534 :param window: number of time steps to look back in history
535 Note: window will be treated as negative value. This should be in agreement with looking back on
536 a time line. Nonetheless positive values are allowed but they are converted to its negative
537 expression
538 :param dim_name_of_shift: Dimension along shift will be applied
539 """
540 window = -abs(window)
541 data = self.input_data
542 offset = self.window_history_offset + self.window_history_end
543 self.history = self.shift(data, dim_name_of_shift, window, offset=offset)
545 def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
546 window: int) -> None:
547 """
548 Create a xr.DataArray containing labels.
550 Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label
551 attribute.
553 :param dim_name_of_target: Name of dimension which contains the target variable
554 :param target_var: Name of target variable in 'dimension'
555 :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
556 :param window: lead time of label
557 """
558 window = abs(window)
559 data = self.target_data
560 self.label = self.shift(data, dim_name_of_shift, window)
562 def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
563 """
564 Create a xr.DataArray containing observations.
566 Observations are defined as value of the current time step t. Set observation attribute.
568 :param dim_name_of_target: Name of dimension which contains the observation variable
569 :param target_var: Name of observation variable(s) in 'dimension'
570 :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
571 """
572 data = self.target_data
573 self.observation = self.shift(data, dim_name_of_shift, 0)
    def remove_nan(self, dim: str) -> None:
        """
        Remove all NAs slices along dim which contain nans in history, label and observation.

        This is done to present only a full matrix to keras.fit. Update history, label, and observation attribute.

        :param dim: dimension along the remove is performed.
        :raises ValueError: if history, label or observation would become empty after dropping NaNs.
        """
        intersect = []
        if (self.history is not None) and (self.label is not None):
            non_nan_history = self.history.dropna(dim=dim)
            non_nan_label = self.label.dropna(dim=dim)
            non_nan_observation = self.observation.dropna(dim=dim)
            # fail loudly if any of the three arrays is NaN-only; this station cannot provide samples
            if non_nan_label.coords[dim].shape[0] == 0:
                raise ValueError(f'self.label consist of NaNs only - station {self.station} is therefore dropped')
            if non_nan_history.coords[dim].shape[0] == 0:
                raise ValueError(f'self.history consist of NaNs only - station {self.station} is therefore dropped')
            if non_nan_observation.coords[dim].shape[0] == 0:
                raise ValueError(f'self.observation consist of NaNs only - station {self.station} is therefore dropped')
            # keep only time steps that are complete in history, label AND observation
            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
                                                non_nan_observation.coords[dim].values))

        # too few remaining samples: discard the station's data entirely (min_length is user-configured)
        if len(intersect) < max(self.min_length, 1):
            self.history = None
            self.label = None
            self.observation = None
        else:
            self.history = self.history.sel({dim: intersect})
            self.label = self.label.sel({dim: intersect})
            self.observation = self.observation.sel({dim: intersect})
606 def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray:
607 """
608 Set start and end date for slicing and execute self._slice().
610 :param data: data to slice
611 :param coord: name of axis to slice
613 :return: sliced data
614 """
615 start = start if start is not None else data.coords[self.time_dim][0].values
616 end = end if end is not None else data.coords[self.time_dim][-1].values
617 return self._slice(data, start, end, self.time_dim)
619 @staticmethod
620 def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
621 """
622 Slice through a given data_item (for example select only values of 2011).
624 :param data: data to slice
625 :param start: start date of slice
626 :param end: end date of slice
627 :param coord: name of axis to slice
629 :return: sliced data
630 """
631 return data.loc[{coord: slice(str(start), str(end))}]
633 def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
634 """
635 Set up transformation by extracting all relevant information.
637 * Either return new empty DataClass instances if given transformation arg is None,
638 * or return given object twice if transformation is a DataClass instance,
639 * or return the inputs and targets attributes if transformation is a TransformationClass instance (default
640 design behaviour)
641 """
642 if transformation is None:
643 return None, None
644 elif isinstance(transformation, dict):
645 return copy.deepcopy(transformation), copy.deepcopy(transformation)
646 elif isinstance(transformation, tuple) and len(transformation) == 2:
647 return copy.deepcopy(transformation)
648 else:
649 raise NotImplementedError("Cannot handle this.")
651 @staticmethod
652 def check_inverse_transform_params(method: str, mean=None, std=None, min=None, max=None) -> None:
653 """
654 Support inverse_transformation method.
656 Validate if all required statistics are available for given method. E.g. centering requires mean only, whereas
657 normalisation requires mean and standard deviation. Will raise an AttributeError on missing requirements.
659 :param mean: data with all mean values
660 :param std: data with all standard deviation values
661 :param method: name of transformation method
662 """
663 msg = ""
664 if method in ['standardise', 'centre'] and mean is None:
665 msg += "mean, "
666 if method == 'standardise' and std is None:
667 msg += "std, "
668 if method == "min_max" and min is None:
669 msg += "min, "
670 if method == "min_max" and max is None:
671 msg += "max, "
672 if len(msg) > 0:
673 raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}")
    def inverse_transform(self, data_in, opts, transformation_dim) -> xr.DataArray:
        """
        Perform inverse transformation.

        Will raise an AssertionError, if no transformation was performed before. Checks first, if all required
        statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by
        new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate, that the
        current data is not transformed.

        :param data_in: transformed data to revert
        :param opts: per-variable options (method and statistics) as produced by transform()
        :param transformation_dim: dimension holding the variables
        :raises IndexError: if transformation_dim is not a coordinate of data_in
        """

        def f_inverse(data, method, mean=None, std=None, min=None, max=None, feature_range=None):
            # dispatch to the statistics helper matching the forward method
            if method == "standardise":
                return statistics.standardise_inverse(data, mean, std)
            elif method == "centre":
                return statistics.centre_inverse(data, mean)
            elif method == "min_max":
                return statistics.min_max_inverse(data, min, max, feature_range)
            elif method == "log":
                return statistics.log_inverse(data, mean, std)
            else:
                raise NotImplementedError

        transformed_values = []
        squeeze = False
        if transformation_dim in data_in.coords:
            # a scalar coordinate must be expanded to a real dimension for per-variable selection;
            # remember to squeeze it away again before returning
            if transformation_dim not in data_in.dims:
                data_in = data_in.expand_dims(transformation_dim)
                squeeze = True
        else:
            raise IndexError(f"Could not find given dimension: {transformation_dim}. Available is: {data_in.coords}")
        # NOTE(review): attribute access `data_in.variables` resolves to the coordinate named "variables" and
        # only works while transformation_dim uses that default name — confirm against transform().
        for var in data_in.variables.values:
            data_var = data_in.sel(**{transformation_dim: [var]})
            var_opts = opts.get(var, {})
            _method = var_opts.get("method", None)
            if _method is None:
                raise AssertionError(f"Inverse transformation method is not set for {var}.")
            self.check_inverse_transform_params(**var_opts)
            values = f_inverse(data_var, **var_opts)
            transformed_values.append(values)
        res = xr.concat(transformed_values, dim=transformation_dim)
        return res.squeeze(transformation_dim) if squeeze else res
717 def apply_transformation(self, data, base=None, dim=0, inverse=False):
718 """
719 Apply transformation on external data. Specify if transformation should be based on parameters related to input
720 or target data using `base`. This method can also apply inverse transformation.
722 :param data:
723 :param base:
724 :param dim:
725 :param inverse:
726 :return:
727 """
728 if base in ["target", 1]:
729 pos = 1
730 elif base in ["input", 0]:
731 pos = 0
732 else:
733 raise ValueError("apply transformation requires a reference for transformation options. Please specify if"
734 "you want to use input or target transformation using the parameter 'base'. Given was: " +
735 base)
736 return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
737 transformation_dim=self.target_dim)
739 def _hash_list(self):
740 return sorted(list(set(self._hash)))
742 def _get_hash(self):
743 hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
744 return hashlib.md5(hash).hexdigest()