1"""Data Preparation class to handle data processing for machine learning.""" 

2 

3__author__ = 'Lukas Leufen, Felix Kleinert' 

4__date__ = '2020-07-20' 

5 

6import copy 

7import datetime as dt 

8import gc 

9 

10import dill 

11import hashlib 

12import logging 

13import os 

14from functools import reduce, partial 

15from typing import Union, List, Iterable, Tuple, Dict, Optional 

16 

17import numpy as np 

18import pandas as pd 

19import xarray as xr 

20 

21from mlair.configuration import check_path_and_create 

22from mlair import helpers 

23from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict 

24from mlair.data_handler.abstract_data_handler import AbstractDataHandler 

25from mlair.helpers import data_sources 

26 

27# define a more general date type for type hinting 

28date = Union[dt.date, dt.datetime] 

29str_or_list = Union[str, List[str]] 

30number = Union[float, int] 

31num_or_list = Union[number, List[number]] 

32data_or_none = Union[xr.DataArray, None] 

33 

34 

35class DataHandlerSingleStation(AbstractDataHandler): 

36 """ 

37 :param window_history_offset: used to shift t0 according to the specified value. 

38 :param window_history_end: used to set the last time step that is used to create a sample. A negative value 

39 indicates that not all values up to t0 are used, a positive values indicates usage of values at t>t0. Default 

40 is 0. 

41 """ 

42 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', 

43 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', 

44 'pblheight': 'maximum'} 

45 DEFAULT_WINDOW_LEAD_TIME = 3 

46 DEFAULT_WINDOW_HISTORY_SIZE = 13 

47 DEFAULT_WINDOW_HISTORY_OFFSET = 0 

48 DEFAULT_WINDOW_HISTORY_END = 0 

49 DEFAULT_TIME_DIM = "datetime" 

50 DEFAULT_TARGET_VAR = "o3" 

51 DEFAULT_TARGET_DIM = "variables" 

52 DEFAULT_ITER_DIM = "Stations" 

53 DEFAULT_WINDOW_DIM = "window" 

54 DEFAULT_SAMPLING = "daily" 

55 DEFAULT_INTERPOLATION_LIMIT = 0 

56 DEFAULT_INTERPOLATION_METHOD = "linear" 

57 chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", 

58 "so2", "toluene"] 

59 

60 _hash = ["station", "statistics_per_var", "data_origin", "sampling", "target_dim", "target_var", "time_dim", 

61 "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time", 

62 "interpolation_limit", "interpolation_method", "variables", "window_history_end"] 

63 

64 def __init__(self, station, data_path, statistics_per_var=None, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, 

65 target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, 

66 iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, 

67 window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, 

68 window_history_end=DEFAULT_WINDOW_HISTORY_END, window_lead_time=DEFAULT_WINDOW_LEAD_TIME, 

69 interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT, 

70 interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD, 

71 overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True, 

72 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None, 

73 lazy_preprocessing: bool = False, overwrite_lazy_data=False, **kwargs): 

74 super().__init__() 

75 self.station = helpers.to_list(station) 

76 self.path = self.setup_data_path(data_path, sampling) 

77 self.lazy = lazy_preprocessing 

78 self.lazy_path = None 

79 if self.lazy is True: 79 ↛ 80line 79 didn't jump to line 80, because the condition on line 79 was never true

80 self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__) 

81 check_path_and_create(self.lazy_path) 

82 self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT 

83 self.data_origin = data_origin 

84 self.do_transformation = transformation is not None 

85 self.input_data, self.target_data = None, None 

86 self._transformation = self.setup_transformation(transformation) 

87 

88 self.sampling = sampling 

89 self.target_dim = target_dim 

90 self.target_var = target_var 

91 self.time_dim = time_dim 

92 self.iter_dim = iter_dim 

93 self.window_dim = window_dim 

94 self.window_history_size = window_history_size 

95 self.window_history_offset = window_history_offset 

96 self.window_history_end = window_history_end 

97 self.window_lead_time = window_lead_time 

98 

99 self.interpolation_limit = interpolation_limit 

100 self.interpolation_method = interpolation_method 

101 

102 self.overwrite_local_data = overwrite_local_data 

103 self.overwrite_lazy_data = True if self.overwrite_local_data is True else overwrite_lazy_data 

104 self.store_data_locally = store_data_locally 

105 self.min_length = min_length 

106 self.start = start 

107 self.end = end 

108 

109 # internal 

110 self._data: xr.DataArray = None # loaded raw data 

111 self.meta = None 

112 self.variables = sorted(list(statistics_per_var.keys())) if variables is None else variables 

113 self.history = None 

114 self.label = None 

115 self.observation = None 

116 

117 # create samples 

118 self.setup_samples() 

119 self.clean_up() 

120 
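
    # A minimal usage sketch (hypothetical station id, path, and variable choice;
    # construction triggers data loading or download via the configured
    # data_sources backend):
    #
    #     dh = DataHandlerSingleStation("DEBW107", data_path="/tmp/data",
    #                                   statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
    #                                   target_var="o3", sampling="daily")
    #     X, Y = dh.get_X(), dh.get_Y()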

    def clean_up(self):
        self._data = None
        self.input_data = None
        self.target_data = None
        gc.collect()

    def __str__(self):
        return self.station[0]

    def __len__(self):
        assert len(self.get_X()) == len(self.get_Y())
        return len(self.get_X())

    @property
    def shape(self):
        return self._data.shape, self.get_X().shape, self.get_Y().shape

    def __repr__(self):
        return f"StationPrep(station={self.station}, data_path='{self.path}', data_origin={self.data_origin}, " \
               f"statistics_per_var={self.statistics_per_var}, " \
               f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
               f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
               f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
               f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data})"

    def get_transposed_history(self) -> xr.DataArray:
        """Return history.

        :return: history with dimensions datetime, window, Stations, variables.
        """
        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim).copy()

    def get_transposed_label(self) -> xr.DataArray:
        """Return label.

        :return: label with dimensions datetime*, window*, Stations, variables.
        """
        return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim).copy()

    def get_X(self, **kwargs):
        return self.get_transposed_history().sel({self.target_dim: self.variables})

    def get_Y(self, **kwargs):
        return self.get_transposed_label()

    def get_coordinates(self):
        try:
            coords = self.meta.loc[["station_lon", "station_lat"]].astype(float)
            coords = coords.rename(index={"station_lon": "lon", "station_lat": "lat"})
        except KeyError:
            coords = self.meta.loc[["lon", "lat"]].astype(float)
        return coords.to_dict()[str(self)]

    def call_transform(self, inverse=False):
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
                                                     opts=opts_input, transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=DEFAULT_TARGET_DIM):
        """
        Transform data according to given transformation settings.

        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either to mean=0
        and std=1 (`method=standardise`) or centres the data to mean=0 without changing the data scale
        (`method=centre`); `min_max` scaling and `log` transformation are supported as well. For each variable, the
        applied transformation options are returned alongside the data so they can be reused later, e.g. for inverse
        transformation.

        :param string/int dim: This param is not used for inverse transformation.
            | for xarray.DataArray as string: name of dimension which should be standardised
            | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param inverse: Switch between transformation and inverse transformation.

        :return: transformed data and a dict holding the transformation options (e.g. mean, std, method) per variable
        """

        def f(data, method="standardise", feature_range=None):
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            min = kwargs.pop('min', None)
            max = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                kws = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max_apply(data, min, max, **kws), {"min": min, "max": max, "method": method,
                                                                         "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                values, new_var_opts = (f_apply if _apply else f)(data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)
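
    # Shape of the per-variable options dict returned by transform() (illustrative
    # only; the statistics are xr.DataArrays computed from the data):
    #
    #     opts = {"o3": {"mean": <xr.DataArray>, "std": <xr.DataArray>, "method": "standardise"},
    #             "temp": {"mean": <xr.DataArray>, "method": "centre"}}
    #
    # Passing such a dict back in via `opts` routes through f_apply, so the stored
    # statistics are reused instead of being recomputed.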

    @TimeTrackingWrapper
    def setup_samples(self):
        """
        Setup samples. This method prepares and creates samples X, and labels Y.
        """
        if self.lazy is False:
            self.make_input_target()
        else:
            self.load_lazy()
            self.store_lazy()
        if self.do_transformation is True:
            self.call_transform()
        self.make_samples()

    def store_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        if not os.path.exists(filename):
            with open(filename, "wb") as f:  # context manager ensures the file handle is closed
                dill.dump(self._create_lazy_data(), file=f, protocol=4)
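
    # store_lazy()/load_lazy() cache preprocessed data on disk, keyed by a hash over
    # the parameters listed in _hash. A sketch of the resulting layout (hypothetical
    # digest):
    #
    #     <data_path>/lazy_data/DataHandlerSingleStation/d41d8cd98f00b204e9800998ecf8427e.pickle
    #
    # Any changed parameter (e.g. a different interpolation_limit) produces a new
    # hash, i.e. a cache miss, and make_input_target() is used as fallback.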

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data]

    def load_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        try:
            if self.overwrite_lazy_data is True:
                os.remove(filename)
                raise FileNotFoundError
            with open(filename, "rb") as pickle_file:
                lazy_data = dill.load(pickle_file)
            self._extract_lazy(lazy_data)
            logging.debug(f"{self.station[0]}: used lazy data")
        except FileNotFoundError:
            logging.debug(f"{self.station[0]}: could not use lazy data")
            self.make_input_target()

    def _extract_lazy(self, lazy_data):
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))

    def make_input_target(self):
        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit, sampling=self.sampling)
        self.set_inputs_and_targets()

    def set_inputs_and_targets(self):
        inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
        targets = self._data.sel(
            {self.target_dim: helpers.to_list(self.target_var)})  # ToDo: is it right to expand this dim??
        self.input_data = inputs
        self.target_data = targets

    def make_samples(self):
        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)  # ToDo: stopped here
        self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
        self.make_observation(self.target_dim, self.target_var, self.time_dim)
        self.remove_nan(self.time_dim)

    def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False,
                  data_origin: Dict = None, start=None, end=None):
        """
        Load data and meta data either from local disk (preferred) or download new data using a custom download method.

        Data is only downloaded if no local data is available or if the parameter overwrite_local_data is true. In
        both cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter
        is not set, it is assumed that data should be saved locally.
        """
        check_path_and_create(path)
        file_name = self._set_file_name(path, station, statistics_per_var)
        meta_file = self._set_meta_file_name(path, station, statistics_per_var)
        if self.overwrite_local_data is True:
            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}")
            if os.path.exists(file_name):
                os.remove(file_name)
            if os.path.exists(meta_file):
                os.remove(meta_file)
            data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                    store_data_locally=store_data_locally, data_origin=data_origin,
                                                    time_dim=self.time_dim, target_dim=self.target_dim,
                                                    iter_dim=self.iter_dim)
            logging.debug("loaded new data")
        else:
            try:
                logging.debug(f"try to load local data from: {file_name}")
                data = xr.open_dataarray(file_name)
                meta = pd.read_csv(meta_file, index_col=0)
                self.check_station_meta(meta, station, data_origin, statistics_per_var)
                logging.debug("loading finished")
            except FileNotFoundError as e:
                logging.debug(e)
                logging.debug("load new data")
                data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                        store_data_locally=store_data_locally, data_origin=data_origin,
                                                        time_dim=self.time_dim, target_dim=self.target_dim,
                                                        iter_dim=self.iter_dim)
                logging.debug("loading finished")
        # create slices and check for negative concentration.
        data = self._slice_prep(data, start=start, end=end)
        data = self.check_for_negative_concentrations(data)
        return data, meta

    @staticmethod
    def check_station_meta(meta, station, data_origin, statistics_per_var):
        """
        Search for the entries in meta data and compare the values with the requested values.

        Will raise a FileNotFoundError if the values mismatch.
        """
        check_dict = {"data_origin": str(data_origin), "statistics_per_var": str(statistics_per_var)}
        for (k, v) in check_dict.items():
            if v is None or k not in meta.index:
                continue
            if meta.at[k, station[0]] != v:
                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                              f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger a new "
                              f"download from the web.")
                raise FileNotFoundError

    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
        """
        Set all negative concentrations to zero.

        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".

        :param data: data array containing variables to check
        :param minimum: minimum value, by default this should be 0

        :return: corrected data
        """
        used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values))
        if len(used_chem_vars) > 0:
            data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
        return data
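
    # A self-contained sketch of the clipping step above (hypothetical values, run
    # e.g. in a REPL):
    #
    #     import xarray as xr
    #     da = xr.DataArray([[-1.2, 3.0], [0.5, -0.1]], dims=("datetime", "variables"),
    #                       coords={"variables": ["o3", "no2"]})
    #     da.loc[..., ["o3", "no2"]] = da.loc[..., ["o3", "no2"]].clip(min=0)
    #     # all negative concentrations are now 0.0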

    def setup_data_path(self, data_path: str, sampling: str):
        return os.path.join(os.path.abspath(data_path), sampling)

    def shift(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray:
        """
        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).

        :param data: data set to shift
        :param dim: dimension along which the shift is applied
        :param window: number of steps to shift (corresponds to the window length)
        :param offset: use offset to move the window by as many time steps as given in offset. This can be used if
            the index time of a history element is not the last timestamp. E.g. you could use offset=23 when dealing
            with hourly data in combination with daily data (values from 00 to 23 are aggregated on 00 the same day).

        :return: shifted data
        """
        start = 1
        end = 1
        if window <= 0:
            start = window
        else:
            end = window + 1
        res = []
        _range = list(map(lambda x: x + offset, range(start, end)))
        for w in _range:
            res.append(data.shift({dim: -w}))
        window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim)
        res = xr.concat(res, dim=window_array)
        return res
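
    # How the shift range is built (derived from the branches above, shown with the
    # default offset=0):
    #
    #     window = -3  ->  range(-3, 1)  ->  [-3, -2, -1, 0]   # history incl. t0
    #     window =  3  ->  range(1, 4)   ->  [1, 2, 3]         # lead times after t0
    #
    # Each entry becomes one slice along the new "window" dimension.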

    @staticmethod
    def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
        """
        Create a 1D xr.DataArray with given index name and values.

        :param index_name: name of dimension
        :param index_value: values of this dimension
        :param squeeze_dim: name of a helper dimension used during construction and dropped again before returning

        :return: this array
        """
        ind = pd.DataFrame({'val': index_value}, index=index_value)
        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
            dim=squeeze_dim, drop=True)
        res.name = index_name
        return res

    @staticmethod
    def _set_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")

    @staticmethod
    def _set_meta_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
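
    # Resulting file names for a hypothetical station "DEBW107" with
    # statistics_per_var={"o3": "dma8eu", "temp": "maximum"}:
    #
    #     <path>/DEBW107_o3_temp.nc        # data
    #     <path>/DEBW107_o3_temp_meta.csv  # meta data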

    def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None,
                    use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs):
        """
        Interpolate values according to different methods.

        (Parameter description copied from xarray's DataArray.interpolate_na.)

        :param dim:
            Specifies the dimension along which to interpolate.
        :param method:
            {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
            'polynomial', 'barycentric', 'krog', 'pchip',
            'spline', 'akima'}, optional
            String indicating which method to use for interpolation:

            - 'linear': linear interpolation (Default). Additional keyword
              arguments are passed to ``numpy.interp``
            - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
              'polynomial': are passed to ``scipy.interpolate.interp1d``. If
              method=='polynomial', the ``order`` keyword argument must also be
              provided.
            - 'barycentric', 'krog', 'pchip', 'spline', and 'akima': use their
              respective ``scipy.interpolate`` classes.
        :param limit:
            default None
            Maximum number of consecutive NaNs to fill. Must be greater than 0
            or None for no limit.
        :param use_coordinate:
            default True
            Specifies which index to use as the x values in the interpolation
            formulated as `y = f(x)`. If False, values are treated as if
            equally-spaced along `dim`. If True, the IndexVariable `dim` is
            used. If use_coordinate is a string, it specifies the name of a
            coordinate variable to use as the index.
        :param kwargs: additional keyword arguments forwarded to interpolate_na

        :return: xarray.DataArray
        """
        data = self.create_full_time_dim(data, dim, sampling)
        return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)

    @staticmethod
    def create_full_time_dim(data, dim, sampling):
        """Ensure an equidistant time dimension. Dates may have been dropped when all their values were missing."""
        start = data.coords[dim].values[0]
        end = data.coords[dim].values[-1]
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
        t = data.sel({dim: start}, drop=True)
        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        res.loc[data.coords] = data
        return res
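
    # Sketch of what create_full_time_dim repairs (hypothetical daily series with
    # one missing date):
    #
    #     input index:  2017-01-01, 2017-01-02, 2017-01-04   # Jan 3rd missing
    #     full index:   pd.date_range("2017-01-01", "2017-01-04", freq="1D")
    #
    # The re-inserted date carries NaN, so interpolate_na can fill it afterwards
    # (subject to `limit`).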

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing history data.

        Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
        data. This is used to represent history in the data. Results are stored in the history attribute.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: number of time steps to look back in history
            Note: window will be treated as a negative value. This should be in agreement with looking back on
            a time line. Nonetheless, positive values are allowed; they are converted to their negative
            expression.
        :param dim_name_of_shift: Dimension along which the shift will be applied
        """
        window = -abs(window)
        data = self.input_data
        offset = self.window_history_offset + self.window_history_end
        self.history = self.shift(data, dim_name_of_shift, window, offset=offset)
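
    # With the class defaults (window_history_size=13, offset=0, end=0) the history
    # spans the window steps -13, ..., 0, while make_labels below spans the lead
    # times +1, ..., +window_lead_time; for daily sampling:
    #
    #     history at t0: values from t0 - 13 days up to t0
    #     label at t0:   values from t0 + 1 up to t0 + 3 days (DEFAULT_WINDOW_LEAD_TIME = 3)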

    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
                    window: int) -> None:
        """
        Create a xr.DataArray containing labels.

        Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Sets
        the label attribute.

        :param dim_name_of_target: Name of dimension which contains the target variable
        :param target_var: Name of target variable in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        :param window: lead time of label
        """
        window = abs(window)
        data = self.target_data
        self.label = self.shift(data, dim_name_of_shift, window)

    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing observations.

        Observations are defined as the value of the current time step t. Sets the observation attribute.

        :param dim_name_of_target: Name of dimension which contains the observation variable
        :param target_var: Name of observation variable(s) in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        """
        data = self.target_data
        self.observation = self.shift(data, dim_name_of_shift, 0)

    def remove_nan(self, dim: str) -> None:
        """
        Remove along dim all slices that contain NaNs in history, label, or observation.

        This is done to present only a full matrix to keras.fit. Updates the history, label, and observation
        attributes.

        :param dim: dimension along which the removal is performed.
        """
        intersect = []
        if (self.history is not None) and (self.label is not None):
            non_nan_history = self.history.dropna(dim=dim)
            non_nan_label = self.label.dropna(dim=dim)
            non_nan_observation = self.observation.dropna(dim=dim)
            if non_nan_label.coords[dim].shape[0] == 0:
                raise ValueError(f'self.label consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_history.coords[dim].shape[0] == 0:
                raise ValueError(f'self.history consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_observation.coords[dim].shape[0] == 0:
                raise ValueError(f'self.observation consists of NaNs only - station {self.station} is therefore '
                                 f'dropped')
            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
                                                non_nan_observation.coords[dim].values))

        if len(intersect) < max(self.min_length, 1):
            self.history = None
            self.label = None
            self.observation = None
        else:
            self.history = self.history.sel({dim: intersect})
            self.label = self.label.sel({dim: intersect})
            self.observation = self.observation.sel({dim: intersect})
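
    # The intersection step in isolation (runnable standalone; hypothetical
    # timestamps):
    #
    #     from functools import reduce
    #     import numpy as np
    #     a = np.array(["2017-01-01", "2017-01-02", "2017-01-03"])
    #     b = np.array(["2017-01-02", "2017-01-03"])
    #     c = np.array(["2017-01-03"])
    #     reduce(np.intersect1d, (a, b, c))  # -> array(['2017-01-03'], dtype='<U10')
    #
    # Only time steps that are NaN-free in history, label, and observation survive.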

    def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray:
        """
        Set start and end date for slicing and execute self._slice().

        :param data: data to slice
        :param start: start date of slice; defaults to the first timestamp in data
        :param end: end date of slice; defaults to the last timestamp in data

        :return: sliced data
        """
        start = start if start is not None else data.coords[self.time_dim][0].values
        end = end if end is not None else data.coords[self.time_dim][-1].values
        return self._slice(data, start, end, self.time_dim)

    @staticmethod
    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
        """
        Slice through a given data_item (for example select only values of 2011).

        :param data: data to slice
        :param start: start date of slice
        :param end: end date of slice
        :param coord: name of axis to slice

        :return: sliced data
        """
        return data.loc[{coord: slice(str(start), str(end))}]
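
    # Label-based slicing accepts partial date strings, and both ends are inclusive
    # (a hypothetical call):
    #
    #     _slice(data, "2011", "2012", "datetime")
    #     # equivalent to data.loc[{"datetime": slice("2011", "2012")}], i.e.
    #     # everything from 2011-01-01 through 2012-12-31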

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Set up transformation by extracting all relevant information.

        * Either return (None, None) if the given transformation arg is None,
        * or return a deep copy of the given dict twice (same options for inputs and targets),
        * or return a deep copy of the given 2-tuple (separate options for inputs and targets).
        """
        if transformation is None:
            return None, None
        elif isinstance(transformation, dict):
            return copy.deepcopy(transformation), copy.deepcopy(transformation)
        elif isinstance(transformation, tuple) and len(transformation) == 2:
            return copy.deepcopy(transformation)
        else:
            raise NotImplementedError("Cannot handle this.")
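
    # Accepted shapes of the `transformation` argument (variable names and methods
    # are illustrative; methods must be among those handled in transform()):
    #
    #     setup_transformation(None)                                # -> (None, None)
    #     setup_transformation({"o3": {"method": "standardise"}})   # same opts for inputs and targets
    #     setup_transformation(({"o3": {"method": "standardise"}},  # separate opts for
    #                           {"o3": {"method": "centre"}}))      # inputs / targets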

    @staticmethod
    def check_inverse_transform_params(method: str, mean=None, std=None, min=None, max=None) -> None:
        """
        Support the inverse_transform method.

        Validate if all required statistics are available for the given method. E.g. centering requires mean only,
        whereas standardisation requires mean and standard deviation. Will raise an AttributeError on missing
        requirements.

        :param method: name of transformation method
        :param mean: data with all mean values
        :param std: data with all standard deviation values
        :param min: data with all minimum values
        :param max: data with all maximum values
        """
        msg = ""
        if method in ['standardise', 'centre'] and mean is None:
            msg += "mean, "
        if method == 'standardise' and std is None:
            msg += "std, "
        if method == "min_max" and min is None:
            msg += "min, "
        if method == "min_max" and max is None:
            msg += "max, "
        if len(msg) > 0:
            raise AttributeError(f"Inverse transform {method} can not be executed because the following is None: "
                                 f"{msg}")

    def inverse_transform(self, data_in, opts, transformation_dim) -> xr.DataArray:
        """
        Perform inverse transformation.

        Will raise an AssertionError if no transformation was performed before. Checks first if all required
        statistics are available for the inverse transformation.
        """

        def f_inverse(data, method, mean=None, std=None, min=None, max=None, feature_range=None):
            if method == "standardise":
                return statistics.standardise_inverse(data, mean, std)
            elif method == "centre":
                return statistics.centre_inverse(data, mean)
            elif method == "min_max":
                return statistics.min_max_inverse(data, min, max, feature_range)
            elif method == "log":
                return statistics.log_inverse(data, mean, std)
            else:
                raise NotImplementedError

        transformed_values = []
        squeeze = False
        if transformation_dim in data_in.coords:
            if transformation_dim not in data_in.dims:
                data_in = data_in.expand_dims(transformation_dim)
                squeeze = True
        else:
            raise IndexError(f"Could not find given dimension: {transformation_dim}. Available is: {data_in.coords}")
        for var in data_in.variables.values:
            data_var = data_in.sel(**{transformation_dim: [var]})
            var_opts = opts.get(var, {})
            _method = var_opts.get("method", None)
            if _method is None:
                raise AssertionError(f"Inverse transformation method is not set for {var}.")
            self.check_inverse_transform_params(**var_opts)
            values = f_inverse(data_var, **var_opts)
            transformed_values.append(values)
        res = xr.concat(transformed_values, dim=transformation_dim)
        return res.squeeze(transformation_dim) if squeeze else res

    def apply_transformation(self, data, base=None, dim=0, inverse=False):
        """
        Apply transformation on external data. Specify whether the transformation should be based on parameters
        related to input or target data using `base`. This method can also apply the inverse transformation.

        :param data: data to transform
        :param base: either "input" (or 0) or "target" (or 1)
        :param dim: dimension along which the transformation is applied
        :param inverse: switch between transformation and inverse transformation
        :return: transformed data
        """
        if base in ["target", 1]:
            pos = 1
        elif base in ["input", 0]:
            pos = 0
        else:
            raise ValueError("apply transformation requires a reference for transformation options. Please specify "
                             "if you want to use input or target transformation using the parameter 'base'. Given "
                             "was: " + str(base))
        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
                              transformation_dim=self.target_dim)

    def _hash_list(self):
        return sorted(list(set(self._hash)))

    def _get_hash(self):
        hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
        return hashlib.md5(hash).hexdigest()
726 return hashlib.md5(hash).hexdigest()