1"""Data Preparation class to handle data processing for machine learning."""
3__author__ = 'Lukas Leufen, Felix Kleinert'
4__date__ = '2020-07-20'
6import copy
7import datetime as dt
8import gc
10import dill
11import hashlib
12import logging
13import os
14from functools import reduce, partial
15from typing import Union, List, Iterable, Tuple, Dict, Optional
17import numpy as np
18import pandas as pd
19import xarray as xr
21from mlair.configuration import check_path_and_create
22from mlair import helpers
23from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict
24from mlair.data_handler.abstract_data_handler import AbstractDataHandler
25from mlair.helpers import data_sources
27# define a more general date type for type hinting
28date = Union[dt.date, dt.datetime]
29str_or_list = Union[str, List[str]]
30number = Union[float, int]
31num_or_list = Union[number, List[number]]
32data_or_none = Union[xr.DataArray, None]


class DataHandlerSingleStation(AbstractDataHandler):
    """
    :param window_history_offset: used to shift t0 according to the specified value.
    :param window_history_end: used to set the last time step that is used to create a sample. A negative value
        indicates that not all values up to t0 are used; a positive value indicates usage of values at t > t0. Default
        is 0.
    """
    DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                            'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                            'pblheight': 'maximum'}
    DEFAULT_WINDOW_LEAD_TIME = 3
    DEFAULT_WINDOW_HISTORY_SIZE = 13
    DEFAULT_WINDOW_HISTORY_OFFSET = 0
    DEFAULT_WINDOW_HISTORY_END = 0
    DEFAULT_TIME_DIM = "datetime"
    DEFAULT_TARGET_VAR = "o3"
    DEFAULT_TARGET_DIM = "variables"
    DEFAULT_ITER_DIM = "Stations"
    DEFAULT_WINDOW_DIM = "window"
    DEFAULT_SAMPLING = "daily"
    DEFAULT_INTERPOLATION_LIMIT = 0
    DEFAULT_INTERPOLATION_METHOD = "linear"
    chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane",
                 "so2", "toluene"]

    _hash = ["station", "statistics_per_var", "data_origin", "sampling", "target_dim", "target_var", "time_dim",
             "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time",
             "interpolation_limit", "interpolation_method", "variables", "window_history_end"]

    def __init__(self, station, data_path, statistics_per_var=None, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
                 target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
                 iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM,
                 window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET,
                 window_history_end=DEFAULT_WINDOW_HISTORY_END, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
                 interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT,
                 interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                 overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
                 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
                 lazy_preprocessing: bool = False, overwrite_lazy_data=False, **kwargs):
        super().__init__()
        self.station = helpers.to_list(station)
        self.path = self.setup_data_path(data_path, sampling)
        self.lazy = lazy_preprocessing
        self.lazy_path = None
        if self.lazy is True:
            self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__)
            check_path_and_create(self.lazy_path)
        self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT
        self.data_origin = data_origin
        self.do_transformation = transformation is not None
        self.input_data, self.target_data = None, None
        self._transformation = self.setup_transformation(transformation)

        self.sampling = sampling
        self.target_dim = target_dim
        self.target_var = target_var
        self.time_dim = time_dim
        self.iter_dim = iter_dim
        self.window_dim = window_dim
        self.window_history_size = window_history_size
        self.window_history_offset = window_history_offset
        self.window_history_end = window_history_end
        self.window_lead_time = window_lead_time

        self.interpolation_limit = interpolation_limit
        self.interpolation_method = interpolation_method

        self.overwrite_local_data = overwrite_local_data
        self.overwrite_lazy_data = True if self.overwrite_local_data is True else overwrite_lazy_data
        self.store_data_locally = store_data_locally
        self.min_length = min_length
        self.start = start
        self.end = end

        # internal
        self._data: xr.DataArray = None  # loaded raw data
        self.meta = None
        self.variables = sorted(list(self.statistics_per_var.keys())) if variables is None else variables
        self.history = None
        self.label = None
        self.observation = None

        # create samples
        self.setup_samples()
        self.clean_up()

    def clean_up(self):
        self._data = None
        self.input_data = None
        self.target_data = None
        gc.collect()

    def __str__(self):
        return self.station[0]

    def __len__(self):
        assert len(self.get_X()) == len(self.get_Y())
        return len(self.get_X())

    @property
    def shape(self):
        return self._data.shape, self.get_X().shape, self.get_Y().shape

    def __repr__(self):
        return f"StationPrep(station={self.station}, data_path='{self.path}', data_origin={self.data_origin}, " \
               f"statistics_per_var={self.statistics_per_var}, " \
               f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
               f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
               f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
               f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data})"

    def get_transposed_history(self) -> xr.DataArray:
        """Return history.

        :return: history with dimensions datetime, window, Stations, variables.
        """
        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim).copy()

    def get_transposed_label(self) -> xr.DataArray:
        """Return label.

        :return: label with dimensions datetime and window (iter and target dimensions are squeezed out).
        """
        return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim).copy()

    def get_X(self, **kwargs):
        return self.get_transposed_history().sel({self.target_dim: self.variables})

    def get_Y(self, **kwargs):
        return self.get_transposed_label()

    def get_coordinates(self):
        try:
            coords = self.meta.loc[["station_lon", "station_lat"]].astype(float)
            coords = coords.rename(index={"station_lon": "lon", "station_lat": "lat"})
        except KeyError:
            coords = self.meta.loc[["lon", "lat"]].astype(float)
        return coords.to_dict()[str(self)]

    def call_transform(self, inverse=False):
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
                                                     opts=opts_input, transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=DEFAULT_TARGET_DIM):
        """
        Transform data according to given transformation settings.

        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either to mean=0
        and std=1 (`method=standardise`) or centres the data to mean=0 without changing the data scale
        (`method=centre`); `min_max` scaling and a `log` transformation are supported as well. Furthermore, this sets
        an internal instance attribute for later inverse transformation. This method will raise an AssertionError if
        an internal transform method was already set ('inverse=False') or if the internal transform method, internal
        mean, and internal standard deviation weren't set ('inverse=True').

        :param string/int dim: This param is not used for inverse transformation.
            | for xarray.DataArray as string: name of dimension which should be standardised
            | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param inverse: Switch between transformation and inverse transformation.

        :return: transformed data and the updated transformation options (per variable, including the applied
            statistics such as mean and std).
        """

        def f(data, method="standardise", feature_range=None):
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            min = kwargs.pop('min', None)
            max = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                kws = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max_apply(data, min, max, **kws), {"min": min, "max": max, "method": method,
                                                                         "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                values, new_var_opts = locals()["f_apply" if _apply else "f"](data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)
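
    # Hedged usage sketch (not executed; ``dh`` is an already constructed handler
    # and ``arr`` a DataArray carrying a "variables" coordinate -- both names are
    # illustrative only):
    #
    #   transformed, new_opts = dh.transform(arr, dim="datetime",
    #                                        opts={"o3": {"method": "standardise"}})
    #   # new_opts["o3"] now holds the fitted statistics, e.g.
    #   # {"mean": ..., "std": ..., "method": "standardise"}, and can be passed
    #   # back via ``opts`` to apply the identical scaling to other data.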

    @TimeTrackingWrapper
    def setup_samples(self):
        """
        Set up samples. This method prepares and creates samples X and labels Y.
        """
        if self.lazy is False:
            self.make_input_target()
        else:
            self.load_lazy()
            self.store_lazy()
        if self.do_transformation is True:
            self.call_transform()
        self.make_samples()

    def store_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        if not os.path.exists(filename):
            with open(filename, "wb") as dill_file:
                dill.dump(self._create_lazy_data(), file=dill_file, protocol=4)

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data]

    def load_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        try:
            if self.overwrite_lazy_data is True:
                os.remove(filename)
                raise FileNotFoundError
            with open(filename, "rb") as pickle_file:
                lazy_data = dill.load(pickle_file)
            self._extract_lazy(lazy_data)
            logging.debug(f"{self.station[0]}: used lazy data")
        except FileNotFoundError:
            logging.debug(f"{self.station[0]}: could not use lazy data")
            self.make_input_target()

    def _extract_lazy(self, lazy_data):
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
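
    # Hedged sketch of the lazy-data round trip (illustrative, not executed):
    # the pickle name is derived from the handler's hash, so two handlers with
    # identical ``_hash`` attribute values share one cache entry:
    #
    #   filename = os.path.join(self.lazy_path, self._get_hash() + ".pickle")
    #
    # ``store_lazy`` writes [_data, meta, input_data, target_data] once, while
    # ``load_lazy`` re-slices the loaded arrays to the current start/end.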

    def make_input_target(self):
        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit, sampling=self.sampling)
        self.set_inputs_and_targets()

    def set_inputs_and_targets(self):
        inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
        targets = self._data.sel(
            {self.target_dim: helpers.to_list(self.target_var)})  # ToDo: is it right to expand this dim??
        self.input_data = inputs
        self.target_data = targets

    def make_samples(self):
        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)  # todo stopped here
        self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
        self.make_observation(self.target_dim, self.target_var, self.time_dim)
        self.remove_nan(self.time_dim)

    def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False,
                  data_origin: Dict = None, start=None, end=None):
        """
        Load data and meta data either from local disk (preferred) or download new data using a custom download method.

        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
        set, it is assumed that data should be saved locally.
        """
        check_path_and_create(path)
        file_name = self._set_file_name(path, station, statistics_per_var)
        meta_file = self._set_meta_file_name(path, station, statistics_per_var)
        if self.overwrite_local_data is True:
            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}")
            if os.path.exists(file_name):
                os.remove(file_name)
            if os.path.exists(meta_file):
                os.remove(meta_file)
            data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                    store_data_locally=store_data_locally, data_origin=data_origin,
                                                    time_dim=self.time_dim, target_dim=self.target_dim,
                                                    iter_dim=self.iter_dim)
            logging.debug("loaded new data")
        else:
            try:
                logging.debug(f"try to load local data from: {file_name}")
                data = xr.open_dataarray(file_name)
                meta = pd.read_csv(meta_file, index_col=0)
                self.check_station_meta(meta, station, data_origin, statistics_per_var)
                logging.debug("loading finished")
            except FileNotFoundError as e:
                logging.debug(e)
                logging.debug("load new data")
                data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                        store_data_locally=store_data_locally, data_origin=data_origin,
                                                        time_dim=self.time_dim, target_dim=self.target_dim,
                                                        iter_dim=self.iter_dim)
                logging.debug("loading finished")
        # create slices and check for negative concentration.
        data = self._slice_prep(data, start=start, end=end)
        data = self.check_for_negative_concentrations(data)
        return data, meta

    @staticmethod
    def check_station_meta(meta, station, data_origin, statistics_per_var):
        """
        Search for the entries in meta data and compare the values with the requested values.

        Will raise a FileNotFoundError if the values mismatch.
        """
        check_dict = {"data_origin": str(data_origin), "statistics_per_var": str(statistics_per_var)}
        for (k, v) in check_dict.items():
            if v is None or k not in meta.index:
                continue
            if meta.at[k, station[0]] != v:
                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                              f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
                              f"grabbing from web.")
                raise FileNotFoundError

    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
        """
        Set all negative concentrations to zero.

        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".

        :param data: data array containing variables to check
        :param minimum: minimum value, by default this should be 0

        :return: corrected data
        """
        used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values))
        if len(used_chem_vars) > 0:
            data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
        return data
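
    # Hedged sketch of the clipping step (illustrative values): an o3 series
    # [-1.2, 0.5, -0.1, 3.0] clipped at minimum=0 becomes [0.0, 0.5, 0.0, 3.0];
    # variables not listed in ``chem_vars`` (e.g. "temp") are left untouched.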

    def setup_data_path(self, data_path: str, sampling: str):
        return os.path.join(os.path.abspath(data_path), sampling)

    def shift(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray:
        """
        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).

        :param data: data set to shift
        :param dim: dimension along which shift is applied
        :param window: number of steps to shift (corresponds to the window length)
        :param offset: use offset to move the window by as many time steps as given in offset. This can be used if the
            index time of a history element is not the last timestamp. E.g. you could use offset=23 when dealing with
            hourly data in combination with daily data (values from 00 to 23 are aggregated on 00 the same day).

        :return: shifted data
        """
        start = 1
        end = 1
        if window <= 0:
            start = window
        else:
            end = window + 1
        res = []
        _range = list(map(lambda x: x + offset, range(start, end)))
        for w in _range:
            res.append(data.shift({dim: -w}))
        window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim)
        res = xr.concat(res, dim=window_array)
        return res
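
    # Hedged sketch of the shift semantics (illustrative, not executed): with
    # window=-2 and offset=0, ``_range`` is [-2, -1, 0], so the result stacks
    # the data shifted by +2, +1, and 0 steps along a new "window" dimension;
    # at timestamp t, window coordinate -2 holds the value observed at t-2 and
    # coordinate 0 holds the value at t itself.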

    @staticmethod
    def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
        """
        Create a 1D xr.DataArray with given index name and value.

        :param index_name: name of dimension
        :param index_value: values of this dimension
        :param squeeze_dim: name of the auxiliary dimension used during conversion and squeezed out afterwards

        :return: this array
        """
        ind = pd.DataFrame({'val': index_value}, index=index_value)
        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
            dim=squeeze_dim, drop=True)
        res.name = index_name
        return res

    @staticmethod
    def _set_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")

    @staticmethod
    def _set_meta_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
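
    # Hedged naming sketch (station id and variables are illustrative): for
    # station ["DEBW107"] and statistics_per_var keys {"o3", "temp"},
    # _set_file_name yields ".../DEBW107_o3_temp.nc" and _set_meta_file_name
    # the matching ".../DEBW107_o3_temp_meta.csv".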

    def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None,
                    use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs):
        """
        Interpolate values according to different methods.

        (Copied from dataarray.interpolate_na)

        :param dim:
            Specifies the dimension along which to interpolate.
        :param method:
            {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
            'polynomial', 'barycentric', 'krog', 'pchip',
            'spline', 'akima'}, optional
            String indicating which method to use for interpolation:

            - 'linear': linear interpolation (Default). Additional keyword
              arguments are passed to ``numpy.interp``
            - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
              'polynomial': are passed to ``scipy.interpolate.interp1d``. If
              method=='polynomial', the ``order`` keyword argument must also be
              provided.
            - 'barycentric', 'krog', 'pchip', 'spline', and 'akima': use their
              respective ``scipy.interpolate`` classes.
        :param limit:
            default None
            Maximum number of consecutive NaNs to fill. Must be greater than 0
            or None for no limit.
        :param use_coordinate:
            default True
            Specifies which index to use as the x values in the interpolation
            formulated as `y = f(x)`. If False, values are treated as if
            equally-spaced along `dim`. If True, the IndexVariable `dim` is
            used. If use_coordinate is a string, it specifies the name of a
            coordinate variable to use as the index.
        :param kwargs:

        :return: xarray.DataArray
        """
        data = self.create_full_time_dim(data, dim, sampling)
        return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)

    @staticmethod
    def create_full_time_dim(data, dim, sampling):
        """Ensure the time dimension is equidistant; dates can be missing if all their values have been dropped."""
        start = data.coords[dim].values[0]
        end = data.coords[dim].values[-1]
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
        t = data.sel({dim: start}, drop=True)
        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        res.loc[data.coords] = data
        return res
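
    # Hedged sketch of the reindexing effect (illustrative dates): a daily
    # series with timestamps [2020-01-01, 2020-01-02, 2020-01-04] is expanded
    # to [2020-01-01 .. 2020-01-04] with NaN on 2020-01-03, so the subsequent
    # ``interpolate_na`` call operates on one equidistant time axis.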

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing history data.

        Shift the data window+1 times and return an xarray which has a new dimension 'window' containing the shifted
        data. This is used to represent history in the data. Results are stored in the history attribute.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: number of time steps to look back in history
            Note: window is treated as a negative value, in agreement with looking back on a time line.
            Positive values are allowed, but they are converted to their negative counterpart.
        :param dim_name_of_shift: Dimension along which the shift will be applied
        """
        window = -abs(window)
        data = self.input_data
        offset = self.window_history_offset + self.window_history_end
        self.history = self.shift(data, dim_name_of_shift, window, offset=offset)
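
    # Hedged sketch (illustrative): with window_history_size=13 and both
    # offset terms at 0, the history window coordinates run from -13 to 0,
    # i.e. each sample sees t0 plus the 13 preceding input time steps.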

    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
                    window: int) -> None:
        """
        Create a xr.DataArray containing labels.

        Labels are defined as the consecutive target values (t+1, ..., t+n) following the current time step t. Set the
        label attribute.

        :param dim_name_of_target: Name of dimension which contains the target variable
        :param target_var: Name of target variable in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        :param window: lead time of label
        """
        window = abs(window)
        data = self.target_data
        self.label = self.shift(data, dim_name_of_shift, window)

    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing observations.

        Observations are defined as the value of the current time step t. Set the observation attribute.

        :param dim_name_of_target: Name of dimension which contains the observation variable
        :param target_var: Name of observation variable(s) in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        """
        data = self.target_data
        self.observation = self.shift(data, dim_name_of_shift, 0)

    def remove_nan(self, dim: str) -> None:
        """
        Remove all slices along dim that contain NaNs in history, label, or observation.

        This is done to present only a full matrix to keras.fit. Updates the history, label, and observation
        attributes.

        :param dim: dimension along which the removal is performed
        """
        intersect = []
        if (self.history is not None) and (self.label is not None):
            non_nan_history = self.history.dropna(dim=dim)
            non_nan_label = self.label.dropna(dim=dim)
            non_nan_observation = self.observation.dropna(dim=dim)
            if non_nan_label.coords[dim].shape[0] == 0:
                raise ValueError(f'self.label consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_history.coords[dim].shape[0] == 0:
                raise ValueError(f'self.history consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_observation.coords[dim].shape[0] == 0:
                raise ValueError(f'self.observation consists of NaNs only - station {self.station} is therefore dropped')
            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
                                                non_nan_observation.coords[dim].values))

        if len(intersect) < max(self.min_length, 1):
            self.history = None
            self.label = None
            self.observation = None
        else:
            self.history = self.history.sel({dim: intersect})
            self.label = self.label.sel({dim: intersect})
            self.observation = self.observation.sel({dim: intersect})
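
    # Hedged sketch (illustrative): if history is complete on days {1, 2, 3},
    # labels on {2, 3, 4}, and observations on {1, 2, 3, 4}, the intersection
    # kept is {2, 3}; all three arrays are then aligned to those timestamps.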

    def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray:
        """
        Set start and end date for slicing and execute self._slice().

        :param data: data to slice
        :param start: start date of slice (defaults to the first timestamp of data)
        :param end: end date of slice (defaults to the last timestamp of data)

        :return: sliced data
        """
        start = start if start is not None else data.coords[self.time_dim][0].values
        end = end if end is not None else data.coords[self.time_dim][-1].values
        return self._slice(data, start, end, self.time_dim)

    @staticmethod
    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
        """
        Slice through a given data_item (for example select only values of 2011).

        :param data: data to slice
        :param start: start date of slice
        :param end: end date of slice
        :param coord: name of axis to slice

        :return: sliced data
        """
        return data.loc[{coord: slice(str(start), str(end))}]

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Set up transformation by extracting all relevant information.

        * Either return (None, None) if the given transformation arg is None,
        * or return a deep copy of the given dict twice (shared options for inputs and targets),
        * or return the two elements of a given 2-tuple as separate input and target options.
        """
        if transformation is None:
            return None, None
        elif isinstance(transformation, dict):
            return copy.deepcopy(transformation), copy.deepcopy(transformation)
        elif isinstance(transformation, tuple) and len(transformation) == 2:
            return copy.deepcopy(transformation)
        else:
            raise NotImplementedError("Cannot handle this.")
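
    # Hedged configuration sketch (variable names are illustrative):
    #
    #   transformation = {"o3": {"method": "standardise"},
    #                     "temp": {"method": "min_max", "feature_range": (0, 1)}}
    #
    # passed as a single dict, the options apply to inputs and targets alike;
    # a 2-tuple (input_opts, target_opts) keeps the two sides independent.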

    @staticmethod
    def check_inverse_transform_params(method: str, mean=None, std=None, min=None, max=None) -> None:
        """
        Support inverse_transformation method.

        Validate if all required statistics are available for the given method. E.g. centring requires the mean only,
        whereas standardisation requires mean and standard deviation. Will raise an AttributeError on missing
        requirements.

        :param method: name of transformation method
        :param mean: data with all mean values
        :param std: data with all standard deviation values
        :param min: data with all minimum values (required for min_max)
        :param max: data with all maximum values (required for min_max)
        """
        msg = ""
        if method in ['standardise', 'centre'] and mean is None:
            msg += "mean, "
        if method == 'standardise' and std is None:
            msg += "std, "
        if method == "min_max" and min is None:
            msg += "min, "
        if method == "min_max" and max is None:
            msg += "max, "
        if len(msg) > 0:
            raise AttributeError(f"Inverse transform {method} cannot be executed because the following is None: {msg}")

    def inverse_transform(self, data_in, opts, transformation_dim) -> xr.DataArray:
        """
        Perform inverse transformation.

        Checks first if all required statistics are available for the inverse transformation and raises an
        AssertionError if the transformation method is not set for a variable (i.e. no transformation was performed
        before). Returns the back-transformed data.
        """

        def f_inverse(data, method, mean=None, std=None, min=None, max=None, feature_range=None):
            if method == "standardise":
                return statistics.standardise_inverse(data, mean, std)
            elif method == "centre":
                return statistics.centre_inverse(data, mean)
            elif method == "min_max":
                return statistics.min_max_inverse(data, min, max, feature_range)
            elif method == "log":
                return statistics.log_inverse(data, mean, std)
            else:
                raise NotImplementedError

        transformed_values = []
        squeeze = False
        if transformation_dim in data_in.coords:
            if transformation_dim not in data_in.dims:
                data_in = data_in.expand_dims(transformation_dim)
                squeeze = True
        else:
            raise IndexError(f"Could not find given dimension: {transformation_dim}. Available is: {data_in.coords}")
        for var in data_in.variables.values:
            data_var = data_in.sel(**{transformation_dim: [var]})
            var_opts = opts.get(var, {})
            _method = var_opts.get("method", None)
            if _method is None:
                raise AssertionError(f"Inverse transformation method is not set for {var}.")
            self.check_inverse_transform_params(**var_opts)
            values = f_inverse(data_var, **var_opts)
            transformed_values.append(values)
        res = xr.concat(transformed_values, dim=transformation_dim)
        return res.squeeze(transformation_dim) if squeeze else res

    def apply_transformation(self, data, base=None, dim=0, inverse=False):
        """
        Apply transformation on external data. Specify whether the transformation should be based on parameters
        related to input or target data using `base`. This method can also apply the inverse transformation.

        :param data: external data to transform
        :param base: "input" (or 0) to use the input transformation options, "target" (or 1) for the target options
        :param dim: dimension (name or axis) along which the transformation is applied
        :param inverse: Switch between transformation and inverse transformation.
        :return: transformed data and updated transformation options
        """
        if base in ["target", 1]:
            pos = 1
        elif base in ["input", 0]:
            pos = 0
        else:
            raise ValueError("apply transformation requires a reference for transformation options. Please specify if "
                             "you want to use input or target transformation using the parameter 'base'. Given was: " +
                             str(base))
        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
                              transformation_dim=self.target_dim)
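
    # Hedged usage sketch (``dh`` and ``external`` are illustrative names):
    #
    #   scaled, _ = dh.apply_transformation(external, base="input", dim="datetime")
    #
    # reuses the statistics fitted on this station's inputs, putting external
    # data on the same scale as the training inputs.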

    def _hash_list(self):
        return sorted(list(set(self._hash)))

    def _get_hash(self):
        hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
        return hashlib.md5(hash).hexdigest()
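
    # Hedged sketch of the hashing scheme: the attribute values listed in
    # ``_hash`` (station, sampling, window sizes, ...) are stringified,
    # concatenated in sorted attribute-name order, and reduced to an md5 hex
    # digest; this digest is the key used by store_lazy/load_lazy to locate
    # the cached pickle file.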