1"""Data Preparation class to handle data processing for machine learning."""
3__author__ = 'Lukas Leufen, Felix Kleinert'
4__date__ = '2020-07-20'
6import copy
7import datetime as dt
8import gc
10import dill
11import hashlib
12import logging
13import os
14from functools import reduce, partial
15from typing import Union, List, Iterable, Tuple, Dict, Optional
17import numpy as np
18import pandas as pd
19import xarray as xr
21from mlair.configuration import check_path_and_create
22from mlair import helpers
23from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict
24from mlair.data_handler.abstract_data_handler import AbstractDataHandler
25from mlair.helpers import data_sources
27# define a more general date type for type hinting
28date = Union[dt.date, dt.datetime]
29str_or_list = Union[str, List[str]]
30number = Union[float, int]
31num_or_list = Union[number, List[number]]
32data_or_none = Union[xr.DataArray, None]


class DataHandlerSingleStation(AbstractDataHandler):
    """
    :param window_history_offset: used to shift t0 according to the specified value.
    :param window_history_end: used to set the last time step that is used to create a sample. A negative value
        indicates that not all values up to t0 are used, a positive value indicates usage of values at t>t0. Default
        is 0.
    """
    DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                            'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                            'pblheight': 'maximum'}
    DEFAULT_WINDOW_LEAD_TIME = 3
    DEFAULT_WINDOW_HISTORY_SIZE = 13
    DEFAULT_WINDOW_HISTORY_OFFSET = 0
    DEFAULT_WINDOW_HISTORY_END = 0
    DEFAULT_TIME_DIM = "datetime"
    DEFAULT_TARGET_VAR = "o3"
    DEFAULT_TARGET_DIM = "variables"
    DEFAULT_ITER_DIM = "Stations"
    DEFAULT_WINDOW_DIM = "window"
    DEFAULT_SAMPLING = "daily"
    DEFAULT_INTERPOLATION_LIMIT = 0
    DEFAULT_INTERPOLATION_METHOD = "linear"
    chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane",
                 "so2", "toluene"]

    _hash = ["station", "statistics_per_var", "data_origin", "sampling", "target_dim", "target_var", "time_dim",
             "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time",
             "interpolation_limit", "interpolation_method", "variables", "window_history_end"]

    def __init__(self, station, data_path, statistics_per_var=None, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
                 target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
                 iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM,
                 window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET,
                 window_history_end=DEFAULT_WINDOW_HISTORY_END, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
                 interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT,
                 interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                 overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
                 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
                 lazy_preprocessing: bool = False, overwrite_lazy_data=False, **kwargs):
        super().__init__()
        self.station = helpers.to_list(station)
        self.path = self.setup_data_path(data_path, sampling)
        self.lazy = lazy_preprocessing
        self.lazy_path = None
        if self.lazy is True:
            self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__)
            check_path_and_create(self.lazy_path)
        self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT
        self.data_origin = data_origin
        self.do_transformation = transformation is not None
        self.input_data, self.target_data = None, None
        self._transformation = self.setup_transformation(transformation)

        self.sampling = sampling
        self.target_dim = target_dim
        self.target_var = target_var
        self.time_dim = time_dim
        self.iter_dim = iter_dim
        self.window_dim = window_dim
        self.window_history_size = window_history_size
        self.window_history_offset = window_history_offset
        self.window_history_end = window_history_end
        self.window_lead_time = window_lead_time

        self.interpolation_limit = interpolation_limit
        self.interpolation_method = interpolation_method

        self.overwrite_local_data = overwrite_local_data
        self.overwrite_lazy_data = True if self.overwrite_local_data is True else overwrite_lazy_data
        self.store_data_locally = store_data_locally
        self.min_length = min_length
        self.start = start
        self.end = end

        # internal
        self._data: xr.DataArray = None  # loaded raw data
        self.meta = None
        # use self.statistics_per_var so the default dict applies if statistics_per_var was not given
        self.variables = sorted(list(self.statistics_per_var.keys())) if variables is None else variables
        self.history = None
        self.label = None
        self.observation = None

        # create samples
        self.setup_samples()
        self.clean_up()
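
    # Usage sketch (hypothetical station id, path, and statistics; shown for
    # illustration only and not part of the original module):
    #
    #     dh = DataHandlerSingleStation("DEBW107", data_path="/path/to/data",
    #                                   statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
    #                                   window_history_size=13, window_lead_time=3)
    #     X = dh.get_X()  # dims: datetime, window, Stations, variables
    #     Y = dh.get_Y()  # dims: datetime, window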

    def clean_up(self):
        self._data = None
        self.input_data = None
        self.target_data = None
        gc.collect()

    def __str__(self):
        return self.station[0]

    def __len__(self):
        assert len(self.get_X()) == len(self.get_Y())
        return len(self.get_X())

    @property
    def shape(self):
        return self._data.shape, self.get_X().shape, self.get_Y().shape

    def __repr__(self):
        return f"StationPrep(station={self.station}, data_path='{self.path}', data_origin={self.data_origin}, " \
               f"statistics_per_var={self.statistics_per_var}, " \
               f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
               f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
               f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
               f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data})"

    def get_transposed_history(self) -> xr.DataArray:
        """Return history.

        :return: history with dimensions datetime, window, Stations, variables.
        """
        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim).copy()

    def get_transposed_label(self) -> xr.DataArray:
        """Return label.

        :return: label with dimensions datetime*, window*, Stations, variables.
        """
        return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim).copy()

    def get_X(self, **kwargs):
        return self.get_transposed_history().sel({self.target_dim: self.variables})

    def get_Y(self, **kwargs):
        return self.get_transposed_label()

    def get_coordinates(self):
        try:
            coords = self.meta.loc[["station_lon", "station_lat"]].astype(float)
            coords = coords.rename(index={"station_lon": "lon", "station_lat": "lat"})
        except KeyError:
            coords = self.meta.loc[["lon", "lat"]].astype(float)
        return coords.to_dict()[str(self)]

    def call_transform(self, inverse=False):
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
                                                     opts=opts_input, transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=DEFAULT_TARGET_DIM):
        """
        Transform data according to given transformation settings.

        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either to mean=0
        and std=1 (`method=standardise`), centres the data to mean=0 without changing the data scale
        (`method=centre`), scales the data to a given feature range (`method=min_max`), or applies a log
        transformation (`method=log`). Furthermore, this sets an internal instance attribute for later inverse
        transformation. This method will raise an AssertionError if an internal transform method was already set
        ('inverse=False') or if the internal transform method, internal mean and internal standard deviation weren't
        set ('inverse=True').

        :param string/int dim: This param is not used for inverse transformation.
            | for xarray.DataArray as string: name of dimension which should be standardised
            | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param inverse: Switch between transformation and inverse transformation.

        :return: xarray.DataArrays or pandas.DataFrames:
            #. mean: Mean of data
            #. std: Standard deviation of data
            #. data: Standardised data
        """

        def f(data, method="standardise", feature_range=None):
            # compute the statistics from the data itself and apply the transformation
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            # apply the transformation with externally provided statistics; scalar
            # statistics are broadcast to DataArrays matching the data's coordinates
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            min = kwargs.pop('min', None)
            max = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                kws = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max_apply(data, min, max, **kws), {"min": min, "max": max, "method": method,
                                                                         "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                # dispatch to f_apply if statistics are already provided in opts, otherwise compute them with f
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                values, new_var_opts = locals()["f_apply" if _apply else "f"](data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)
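
    # Sketch of the opts structure consumed by transform() (values illustrative):
    # on the first pass, per-variable statistics are computed by f() and returned as
    # opts_updated; on later calls, providing e.g. "mean"/"std" routes the variable
    # to f_apply() instead.
    #
    #     opts = {"o3": {"method": "standardise", "mean": <xr.DataArray>, "std": <xr.DataArray>},
    #             "temp": {"method": "min_max", "min": <xr.DataArray>, "max": <xr.DataArray>}}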

    @TimeTrackingWrapper
    def setup_samples(self):
        """
        Setup samples. This method prepares and creates samples X, and labels Y.
        """
        if self.lazy is False:
            self.make_input_target()
        else:
            self.load_lazy()
            self.store_lazy()
        if self.do_transformation is True:
            self.call_transform()
        self.make_samples()

    def store_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        if not os.path.exists(filename):
            dill.dump(self._create_lazy_data(), file=open(filename, "wb"), protocol=4)

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data]

    def load_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        try:
            if self.overwrite_lazy_data is True:
                os.remove(filename)
                raise FileNotFoundError
            with open(filename, "rb") as pickle_file:
                lazy_data = dill.load(pickle_file)
            self._extract_lazy(lazy_data)
            logging.debug(f"{self.station[0]}: used lazy data")
        except FileNotFoundError:
            logging.debug(f"{self.station[0]}: could not use lazy data")
            self.make_input_target()

    def _extract_lazy(self, lazy_data):
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))

    def make_input_target(self):
        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit, sampling=self.sampling)
        self.set_inputs_and_targets()

    def set_inputs_and_targets(self):
        inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
        targets = self._data.sel(
            {self.target_dim: helpers.to_list(self.target_var)})  # ToDo: is it right to expand this dim??
        self.input_data = inputs
        self.target_data = targets

    def make_samples(self):
        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)  # ToDo: stopped here
        self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
        self.make_observation(self.target_dim, self.target_var, self.time_dim)
        self.remove_nan(self.time_dim)

    def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False,
                  data_origin: Dict = None, start=None, end=None):
        """
        Load data and meta data either from local disk (preferred) or download new data using a custom download method.

        Data is only downloaded if no local data is available or if the parameter overwrite_local_data is true. In
        both cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter
        is not set, it is assumed that data should be saved locally.
        """
        check_path_and_create(path)
        file_name = self._set_file_name(path, station, statistics_per_var)
        meta_file = self._set_meta_file_name(path, station, statistics_per_var)
        if self.overwrite_local_data is True:
            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}")
            if os.path.exists(file_name):
                os.remove(file_name)
            if os.path.exists(meta_file):
                os.remove(meta_file)
            data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                            store_data_locally=store_data_locally, data_origin=data_origin,
                                            time_dim=self.time_dim, target_dim=self.target_dim, iter_dim=self.iter_dim)
            logging.debug("loaded new data")
        else:
            try:
                logging.debug(f"try to load local data from: {file_name}")
                data = xr.open_dataarray(file_name)
                meta = pd.read_csv(meta_file, index_col=0)
                self.check_station_meta(meta, station, data_origin, statistics_per_var)
                logging.debug("loading finished")
            except FileNotFoundError as e:
                logging.debug(e)
                logging.debug("load new data")
                data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                store_data_locally=store_data_locally, data_origin=data_origin,
                                                time_dim=self.time_dim, target_dim=self.target_dim,
                                                iter_dim=self.iter_dim)
                logging.debug("loading finished")
        # create slices and check for negative concentration.
        data = self._slice_prep(data, start=start, end=end)
        data = self.check_for_negative_concentrations(data)
        return data, meta

    def download_data(self, file_name: str, meta_file: str, station, statistics_per_var, sampling,
                      store_data_locally=True, data_origin: Dict = None, time_dim=DEFAULT_TIME_DIM,
                      target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM) -> [xr.DataArray, pd.DataFrame]:
        """
        Download data from the TOAR database using the JOIN interface or load local era5 data.

        Data is transformed to an xarray dataset. If the class attribute store_data_locally is true, data is
        additionally stored locally using the given names for file and meta file.

        :param file_name: name of file to save data to (containing full path)
        :param meta_file: name of the meta data file (also containing full path)

        :return: downloaded data and its meta data
        """
        df_all = {}
        df_era5, df_toar = None, None
        meta_era5, meta_toar = None, None
        if data_origin is not None:
            era5_origin = filter_dict_by_value(data_origin, "era5", True)
            era5_stats = select_from_dict(statistics_per_var, era5_origin.keys())
            toar_origin = filter_dict_by_value(data_origin, "era5", False)
            toar_stats = select_from_dict(statistics_per_var, era5_origin.keys(), filter_cond=False)
            assert len(era5_origin) + len(toar_origin) == len(data_origin)
            assert len(era5_stats) + len(toar_stats) == len(statistics_per_var)
        else:
            era5_origin, toar_origin = None, None
            era5_stats, toar_stats = statistics_per_var, statistics_per_var

        # load data
        if era5_origin is not None and len(era5_stats) > 0:
            # load era5 data
            df_era5, meta_era5 = data_sources.era5.load_era5(station_name=station, stat_var=era5_stats,
                                                             sampling=sampling, data_origin=era5_origin)
        if toar_origin is None or len(toar_stats) > 0:
            # load combined data from toar-data (v2 & v1)
            df_toar, meta_toar = data_sources.toar_data.download_toar(station=station, toar_stats=toar_stats,
                                                                      sampling=sampling, data_origin=toar_origin)

        if df_era5 is None and df_toar is None:
            raise data_sources.toar_data.EmptyQueryResult("No data available for era5 and toar-data")
        df = pd.concat([df_era5, df_toar], axis=1, sort=True)
        if meta_era5 is not None and meta_toar is not None:
            meta = meta_era5.combine_first(meta_toar)
        else:
            meta = meta_era5 if meta_era5 is not None else meta_toar
        meta.loc["data_origin"] = str(data_origin)
        meta.loc["statistics_per_var"] = str(statistics_per_var)

        df_all[station[0]] = df
        # convert df_all to xarray
        xarr = {k: xr.DataArray(v, dims=[time_dim, target_dim]) for k, v in df_all.items()}
        xarr = xr.Dataset(xarr).to_array(dim=iter_dim)
        if store_data_locally is True:
            # save locally as nc/csv file
            xarr.to_netcdf(path=file_name)
            meta.to_csv(meta_file)
        return xarr, meta

    @staticmethod
    def check_station_meta(meta, station, data_origin, statistics_per_var):
        """
        Search for the entries in meta data and compare the values with the requested values.

        Will raise a FileNotFoundError if the values mismatch.
        """
        check_dict = {"data_origin": str(data_origin), "statistics_per_var": str(statistics_per_var)}
        for (k, v) in check_dict.items():
            if v is None or k not in meta.index:
                continue
            if meta.at[k, station[0]] != v:
                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                              f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
                              f"grabbing from web.")
                raise FileNotFoundError

    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
        """
        Set all negative concentrations to zero.

        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".

        :param data: data array containing variables to check
        :param minimum: minimum value, by default this should be 0

        :return: corrected data
        """
        used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values))
        if len(used_chem_vars) > 0:
            data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
        return data
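
    # Clipping sketch (hypothetical values): for a chemical variable such as "no2",
    # a series like [-0.3, 1.2, -0.1, 4.0] becomes [0.0, 1.2, 0.0, 4.0]; variables
    # not listed in chem_vars (e.g. "temp") are left untouched.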

    def setup_data_path(self, data_path: str, sampling: str):
        return os.path.join(os.path.abspath(data_path), sampling)

    def shift(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray:
        """
        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).

        :param data: data set to shift
        :param dim: dimension along which shift is applied
        :param window: number of steps to shift (corresponds to the window length)
        :param offset: use offset to move the window by as many time steps as given in offset. This can be used if the
            index time of a history element is not the last timestamp. E.g. you could use offset=23 when dealing with
            hourly data in combination with daily data (values from 00 to 23 are aggregated on 00 the same day).

        :return: shifted data
        """
        start = 1
        end = 1
        if window <= 0:
            start = window
        else:
            end = window + 1
        res = []
        _range = list(map(lambda x: x + offset, range(start, end)))
        for w in _range:
            res.append(data.shift({dim: -w}))
        window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim)
        res = xr.concat(res, dim=window_array)
        return res
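
    # Window semantics of shift() (illustrative): with window=-2 and offset=0 the
    # resulting "window" coordinate is [-2, -1, 0], i.e. history including t0;
    # window=3 yields [1, 2, 3], i.e. lead times after t0.
    #
    #     hist = self.shift(data, "datetime", window=-2)  # history
    #     lead = self.shift(data, "datetime", window=3)   # labels / lead time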

    @staticmethod
    def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
        """
        Create a 1D xr.DataArray with given index name and value.

        :param index_name: name of dimension
        :param index_value: values of this dimension
        :param squeeze_dim: name of the helper dimension that is squeezed out again after conversion

        :return: this array
        """
        ind = pd.DataFrame({'val': index_value}, index=index_value)
        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
            dim=squeeze_dim, drop=True)
        res.name = index_name
        return res

    @staticmethod
    def _set_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")

    @staticmethod
    def _set_meta_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
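
    # Naming sketch: for station ["DEBW107"] and statistics_per_var
    # {"o3": "dma8eu", "temp": "maximum"} these helpers resolve to
    # "<path>/DEBW107_o3_temp.nc" and "<path>/DEBW107_o3_temp_meta.csv"
    # (variables sorted alphabetically, so the names are stable).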

    def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None,
                    use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs):
        """
        Interpolate values according to different methods.

        (Copied from dataarray.interpolate_na)

        :param dim:
            Specifies the dimension along which to interpolate.
        :param method:
            {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
            'polynomial', 'barycentric', 'krog', 'pchip',
            'spline', 'akima'}, optional
            String indicating which method to use for interpolation:

            - 'linear': linear interpolation (Default). Additional keyword
              arguments are passed to ``numpy.interp``
            - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
              'polynomial': are passed to ``scipy.interpolate.interp1d``. If
              method=='polynomial', the ``order`` keyword argument must also be
              provided.
            - 'barycentric', 'krog', 'pchip', 'spline', and 'akima': use their
              respective ``scipy.interpolate`` classes.
        :param limit:
            default None
            Maximum number of consecutive NaNs to fill. Must be greater than 0
            or None for no limit.
        :param use_coordinate:
            default True
            Specifies which index to use as the x values in the interpolation
            formulated as `y = f(x)`. If False, values are treated as if
            equally-spaced along `dim`. If True, the IndexVariable `dim` is
            used. If use_coordinate is a string, it specifies the name of a
            coordinate variable to use as the index.
        :param kwargs: additional keyword arguments passed on to the interpolation method

        :return: xarray.DataArray
        """
        data = self.create_full_time_dim(data, dim, sampling)
        return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)

    @staticmethod
    def create_full_time_dim(data, dim, sampling):
        """Ensure the time dimension is equidistant. Sometimes dates have been dropped where values were missing."""
        start = data.coords[dim].values[0]
        end = data.coords[dim].values[-1]
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
        t = data.sel({dim: start}, drop=True)
        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        res.loc[data.coords] = data
        return res
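
    # Illustration (hypothetical): given daily data with timestamps
    # [2011-01-01, 2011-01-02, 2011-01-04], create_full_time_dim re-indexes onto
    # pd.date_range("2011-01-01", "2011-01-04", freq="1D"), so 2011-01-03 appears
    # as an all-NaN slice that interpolate_na can then fill.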

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create an xr.DataArray containing history data.

        Shift the data window+1 times and return an xarray which has a new dimension 'window' containing the shifted
        data. This is used to represent history in the data. Results are stored in the history attribute.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: number of time steps to look back in history
            Note: window will be treated as a negative value. This should be in agreement with looking back on
            a time line. Nonetheless, positive values are allowed but converted to their negative
            expression.
        :param dim_name_of_shift: Dimension along which shift will be applied
        """
        window = -abs(window)
        data = self.input_data
        offset = self.window_history_offset + self.window_history_end
        self.history = self.shift(data, dim_name_of_shift, window, offset=offset)
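
    # Offset arithmetic sketch: with window_history_size=13, window_history_offset=0,
    # and window_history_end=0, the history window spans t0-13 ... t0 (window
    # coordinate [-13, ..., 0]); setting window_history_end=-2 moves the last used
    # step to t0-2, i.e. the window coordinate becomes [-15, ..., -2].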

    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
                    window: int) -> None:
        """
        Create an xr.DataArray containing labels.

        Labels are defined as the consecutive target values (t+1, ..., t+n) following the current time step t. Sets
        the label attribute.

        :param dim_name_of_target: Name of dimension which contains the target variable
        :param target_var: Name of target variable in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        :param window: lead time of label
        """
        window = abs(window)
        data = self.target_data
        self.label = self.shift(data, dim_name_of_shift, window)

    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
        """
        Create an xr.DataArray containing observations.

        Observations are defined as the value of the current time step t. Sets the observation attribute.

        :param dim_name_of_target: Name of dimension which contains the observation variable
        :param target_var: Name of observation variable(s) in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        """
        data = self.target_data
        self.observation = self.shift(data, dim_name_of_shift, 0)

    def remove_nan(self, dim: str) -> None:
        """
        Remove all slices along dim which contain NaNs in history, label, or observation.

        This is done to present only a full matrix to keras.fit. Updates the history, label, and observation
        attributes.

        :param dim: dimension along which the removal is performed.
        """
        intersect = []
        if (self.history is not None) and (self.label is not None):
            non_nan_history = self.history.dropna(dim=dim)
            non_nan_label = self.label.dropna(dim=dim)
            non_nan_observation = self.observation.dropna(dim=dim)
            if non_nan_label.coords[dim].shape[0] == 0:
                raise ValueError(f'self.label consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_history.coords[dim].shape[0] == 0:
                raise ValueError(f'self.history consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_observation.coords[dim].shape[0] == 0:
                raise ValueError(f'self.observation consists of NaNs only - station {self.station} is therefore '
                                 f'dropped')
            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
                                                non_nan_observation.coords[dim].values))

        if len(intersect) < max(self.min_length, 1):
            self.history = None
            self.label = None
            self.observation = None
        else:
            self.history = self.history.sel({dim: intersect})
            self.label = self.label.sel({dim: intersect})
            self.observation = self.observation.sel({dim: intersect})

    def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray:
        """
        Set start and end date for slicing and execute self._slice().

        :param data: data to slice
        :param start: start date of slice (defaults to the first timestamp of data)
        :param end: end date of slice (defaults to the last timestamp of data)

        :return: sliced data
        """
        start = start if start is not None else data.coords[self.time_dim][0].values
        end = end if end is not None else data.coords[self.time_dim][-1].values
        return self._slice(data, start, end, self.time_dim)

    @staticmethod
    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
        """
        Slice through a given data_item (for example select only values of 2011).

        :param data: data to slice
        :param start: start date of slice
        :param end: end date of slice
        :param coord: name of axis to slice

        :return: sliced data
        """
        return data.loc[{coord: slice(str(start), str(end))}]
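
    # Example (illustrative): keep only values of 2011 along the time dimension.
    #
    #     sliced = DataHandlerSingleStation._slice(data, "2011-01-01", "2011-12-31", "datetime")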

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Set up transformation by extracting all relevant information.

        * Either return (None, None) if the given transformation arg is None,
        * or return a deep copy of the given dict twice if transformation is a single dict (the same options are
          used for inputs and targets),
        * or return a deep copy of the given 2-tuple holding separate options for inputs and targets.
        """
        if transformation is None:
            return None, None
        elif isinstance(transformation, dict):
            return copy.deepcopy(transformation), copy.deepcopy(transformation)
        elif isinstance(transformation, tuple) and len(transformation) == 2:
            return copy.deepcopy(transformation)
        else:
            raise NotImplementedError(f"Cannot handle transformation of type {type(transformation)}.")
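
    # Accepted forms (sketch, hypothetical options; dh is an instance): a single dict
    # applies the same options to inputs and targets, a 2-tuple separates them.
    #
    #     opts_in, opts_tgt = dh.setup_transformation({"o3": {"method": "standardise"}})
    #     opts_in, opts_tgt = dh.setup_transformation(({"o3": {"method": "standardise"}},
    #                                                  {"o3": {"method": "centre"}}))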

    @staticmethod
    def check_inverse_transform_params(method: str, mean=None, std=None, min=None, max=None) -> None:
        """
        Support inverse_transformation method.

        Validate if all required statistics are available for the given method. E.g. centering requires mean only,
        whereas normalisation requires mean and standard deviation. Will raise an AttributeError on missing
        requirements.

        :param method: name of transformation method
        :param mean: data with all mean values
        :param std: data with all standard deviation values
        :param min: data with all minimum values (required for min_max)
        :param max: data with all maximum values (required for min_max)
        """
        msg = ""
        if method in ['standardise', 'centre'] and mean is None:
            msg += "mean, "
        if method == 'standardise' and std is None:
            msg += "std, "
        if method == "min_max" and min is None:
            msg += "min, "
        if method == "min_max" and max is None:
            msg += "max, "
        if len(msg) > 0:
            raise AttributeError(f"Inverse transform {method} cannot be executed because the following is None: {msg}")

    def inverse_transform(self, data_in, opts, transformation_dim) -> xr.DataArray:
        """
        Perform inverse transformation.

        Will raise an AssertionError if no transformation was performed before. Checks first if all required
        statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by
        new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate that
        the current data is not transformed.
        """

        def f_inverse(data, method, mean=None, std=None, min=None, max=None, feature_range=None):
            if method == "standardise":
                return statistics.standardise_inverse(data, mean, std)
            elif method == "centre":
                return statistics.centre_inverse(data, mean)
            elif method == "min_max":
                return statistics.min_max_inverse(data, min, max, feature_range)
            elif method == "log":
                return statistics.log_inverse(data, mean, std)
            else:
                raise NotImplementedError

        transformed_values = []
        squeeze = False
        if transformation_dim in data_in.coords:
            if transformation_dim not in data_in.dims:
                data_in = data_in.expand_dims(transformation_dim)
                squeeze = True
        else:
            raise IndexError(f"Could not find given dimension: {transformation_dim}. Available is: {data_in.coords}")
        for var in data_in.variables.values:
            data_var = data_in.sel(**{transformation_dim: [var]})
            var_opts = opts.get(var, {})
            _method = var_opts.get("method", None)
            if _method is None:
                raise AssertionError(f"Inverse transformation method is not set for {var}.")
            self.check_inverse_transform_params(**var_opts)
            values = f_inverse(data_var, **var_opts)
            transformed_values.append(values)
        res = xr.concat(transformed_values, dim=transformation_dim)
        return res.squeeze(transformation_dim) if squeeze else res

    def apply_transformation(self, data, base=None, dim=0, inverse=False):
        """
        Apply transformation on external data. Specify if the transformation should be based on the parameters
        related to input or target data using `base`. This method can also apply the inverse transformation.

        :param data: data to transform
        :param base: either "input" / 0 or "target" / 1 to select which transformation options to use
        :param dim: dimension along which the transformation is applied
        :param inverse: set to True to apply the inverse transformation
        :return: transformed data
        """
        if base in ["target", 1]:
            pos = 1
        elif base in ["input", 0]:
            pos = 0
        else:
            raise ValueError("apply transformation requires a reference for transformation options. Please specify "
                             "if you want to use input or target transformation using the parameter 'base'. Given "
                             "was: " + base)
        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
                              transformation_dim=self.target_dim)

    def _hash_list(self):
        return sorted(list(set(self._hash)))

    def _get_hash(self):
        hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
        return hashlib.md5(hash).hexdigest()
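
    # Hash sketch: _get_hash concatenates the string representation of every attribute
    # listed in _hash (sorted, duplicates removed) and returns the md5 digest of that
    # string. Two handlers configured identically therefore map to the same lazy
    # pickle file, e.g. "<lazy_path>/d41d8cd98f00b204e9800998ecf8427e.pickle"
    # (digest value illustrative).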