Coverage for mlair/data_handler/data_handler_single_station.py: 64%
378 statements
coverage.py v6.4.2, created at 2023-06-30 10:40 +0000

1"""Data Preparation class to handle data processing for machine learning.""" 

2 

3__author__ = 'Lukas Leufen, Felix Kleinert' 

4__date__ = '2020-07-20' 

5 

6import copy 

7import datetime as dt 

8import gc 

9 

10import dill 

11import hashlib 

12import logging 

13import os 

14import ast 

15from functools import reduce, partial 

16from typing import Union, List, Iterable, Tuple, Dict, Optional 

17 

18import numpy as np 

19import pandas as pd 

20import xarray as xr 

21 

22from mlair.configuration import check_path_and_create 

23from mlair import helpers 

24from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict 

25from mlair.data_handler.abstract_data_handler import AbstractDataHandler 

26from mlair.helpers import data_sources, check_nested_equality 

27 

28# define a more general date type for type hinting 

29date = Union[dt.date, dt.datetime] 

30str_or_list = Union[str, List[str]] 

31number = Union[float, int] 

32num_or_list = Union[number, List[number]] 

33data_or_none = Union[xr.DataArray, None] 

34 

35 

36class DataHandlerSingleStation(AbstractDataHandler): 

37 """ 

38 :param window_history_offset: used to shift t0 according to the specified value. 

39 :param window_history_end: used to set the last time step that is used to create a sample. A negative value 

40 indicates that not all values up to t0 are used, a positive values indicates usage of values at t>t0. Default 

41 is 0. 

42 """ 

43 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', 

44 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', 

45 'pblheight': 'maximum'} 

46 DEFAULT_WINDOW_LEAD_TIME = 3 

47 DEFAULT_WINDOW_HISTORY_SIZE = 13 

48 DEFAULT_WINDOW_HISTORY_OFFSET = 0 

49 DEFAULT_WINDOW_HISTORY_END = 0 

50 DEFAULT_TIME_DIM = "datetime" 

51 DEFAULT_TARGET_VAR = "o3" 

52 DEFAULT_TARGET_DIM = "variables" 

53 DEFAULT_ITER_DIM = "Stations" 

54 DEFAULT_WINDOW_DIM = "window" 

55 DEFAULT_SAMPLING = "daily" 

56 DEFAULT_INTERPOLATION_LIMIT = 0 

57 DEFAULT_INTERPOLATION_METHOD = "linear" 

58 chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", 

59 "so2", "toluene"] 

60 

61 _hash = ["station", "statistics_per_var", "data_origin", "sampling", "target_dim", "target_var", "time_dim", 

62 "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time", 

63 "interpolation_limit", "interpolation_method", "variables", "window_history_end"] 

64 

65 def __init__(self, station, data_path, statistics_per_var=None, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, 

66 target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, 

67 iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, 

68 window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, 

69 window_history_end=DEFAULT_WINDOW_HISTORY_END, window_lead_time=DEFAULT_WINDOW_LEAD_TIME, 

70 interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT, 

71 interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD, 

72 overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True, 

73 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None, 

74 lazy_preprocessing: bool = False, overwrite_lazy_data=False, era5_data_path=None, era5_file_names=None, 

75 ifs_data_path=None, ifs_file_names=None, **kwargs): 

76 super().__init__() 

77 self.station = helpers.to_list(station) 

78 self.path = self.setup_data_path(data_path, sampling) 

79 self.lazy = lazy_preprocessing 

80 self.lazy_path = None 

81 if self.lazy is True: 81 ↛ 82line 81 didn't jump to line 82, because the condition on line 81 was never true

82 self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__) 

83 check_path_and_create(self.lazy_path) 

84 self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT 

85 self.data_origin = data_origin 

86 self.do_transformation = transformation is not None 

87 self.input_data, self.target_data = None, None 

88 self._transformation = self.setup_transformation(transformation) 

89 

90 self.sampling = sampling 

91 self.target_dim = target_dim 

92 self.target_var = target_var 

93 self.time_dim = time_dim 

94 self.iter_dim = iter_dim 

95 self.window_dim = window_dim 

96 self.window_history_size = window_history_size 

97 self.window_history_offset = window_history_offset 

98 self.window_history_end = window_history_end 

99 self.window_lead_time = window_lead_time 

100 

101 self.interpolation_limit = interpolation_limit 

102 self.interpolation_method = interpolation_method 

103 

104 self.overwrite_local_data = overwrite_local_data 

105 self.overwrite_lazy_data = True if self.overwrite_local_data is True else overwrite_lazy_data 

106 self.store_data_locally = store_data_locally 

107 self.min_length = min_length 

108 self.start = start 

109 self.end = end 

110 

111 # internal 

112 self._data: xr.DataArray = None # loaded raw data 

113 self.meta = None 

114 self.variables = sorted(list(statistics_per_var.keys())) if variables is None else variables 

115 self.history = None 

116 self.label = None 

117 self.observation = None 

118 

119 self._era5_data_path = era5_data_path 

120 self._era5_file_names = era5_file_names 

121 self._ifs_data_path = ifs_data_path 

122 self._ifs_file_names = ifs_file_names 

123 

124 # create samples 

125 self.setup_samples() 

126 self.clean_up() 

127 

128 def clean_up(self): 

129 self._data = None 

130 self.input_data = None 

131 self.target_data = None 

132 gc.collect() 

133 

134 def __str__(self): 

135 return self.station[0] 

136 

137 def __len__(self): 

138 assert len(self.get_X()) == len(self.get_Y()) 

139 return len(self.get_X()) 

140 

141 @property 

142 def shape(self): 

143 return self._data.shape, self.get_X().shape, self.get_Y().shape 

144 

145 def __repr__(self): 

146 return f"StationPrep(station={self.station}, data_path='{self.path}', data_origin={self.data_origin}, " \ 

147 f"statistics_per_var={self.statistics_per_var}, " \ 

148 f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \ 

149 f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \ 

150 f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \ 

151 f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data})" 

152 

153 def get_transposed_history(self) -> xr.DataArray: 

154 """Return history. 

155 

156 :return: history with dimensions datetime, window, Stations, variables. 

157 """ 

158 return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim).copy() 

159 

160 def get_transposed_label(self) -> xr.DataArray: 

161 """Return label. 

162 

163 :return: label with dimensions datetime*, window*, Stations, variables. 

164 """ 

165 return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim).copy() 

166 

167 def get_X(self, **kwargs): 

168 return self.get_transposed_history().sel({self.target_dim: self.variables}) 

169 

170 def get_Y(self, **kwargs): 

171 return self.get_transposed_label() 

172 
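    # Usage sketch (hypothetical station id and data path; assumes the requested
    # data can be downloaded or is already cached locally):
    #
    #     dh = DataHandlerSingleStation("DEBW107", data_path="/tmp/data",
    #                                   statistics_per_var={"o3": "dma8eu", "temp": "maximum"})
    #     X = dh.get_X()  # dims: (datetime, window, Stations, variables)
    #     Y = dh.get_Y()  # dims: (datetime, window)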

    def get_coordinates(self):
        try:
            coords = self.meta.loc[["station_lon", "station_lat"]].astype(float)
            coords = coords.rename(index={"station_lon": "lon", "station_lat": "lat"})
        except KeyError:
            coords = self.meta.loc[["lon", "lat"]].astype(float)
        return coords.to_dict()[str(self)]

    def call_transform(self, inverse=False):
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
                                                     opts=opts_input, transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=DEFAULT_TARGET_DIM):
        """
        Transform data according to given transformation settings.

        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either to mean=0
        and std=1 (`method=standardise`) or centres the data to mean=0 without changing the data scale
        (`method=centre`). Furthermore, it sets an internal instance attribute for later inverse transformation. This
        method will raise an AssertionError if an internal transform method was already set (`inverse=False`) or if
        the internal transform method, internal mean, and internal standard deviation were not set (`inverse=True`).

        :param string/int dim: This parameter is not used for inverse transformation.
            | for xarray.DataArray as string: name of dimension which should be standardised
            | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param inverse: Switch between transformation and inverse transformation.

        :return: xarray.DataArrays or pandas.DataFrames:
            #. mean: Mean of data
            #. std: Standard deviation of data
            #. data: Standardised data
        """

        def f(data, method="standardise", feature_range=None):
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            min = kwargs.pop('min', None)
            max = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                kws = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max_apply(data, min, max, **kws), {"min": min, "max": max, "method": method,
                                                                         "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                # use f_apply when statistics for this variable are already provided, otherwise compute them with f
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                values, new_var_opts = locals()["f_apply" if _apply else "f"](data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)

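    # Sketch of the dispatch in transform() (assumes an initialised handler `dh`
    # and hypothetical arrays `data` and `other_data`): a first call without opts
    # computes the statistics via `f`; passing the returned opts, e.g.
    # {"o3": {"mean": ..., "std": ..., "method": "standardise"}}, routes the call
    # to `f_apply`, which reuses the stored statistics instead of recomputing them.
    #
    #     transformed, opts = dh.transform(data, dim=dh.time_dim)
    #     same_scaling, _ = dh.transform(other_data, dim=dh.time_dim, opts=opts)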

    @TimeTrackingWrapper
    def setup_samples(self):
        """Set up samples. This method prepares and creates samples X and labels Y."""
        if self.lazy is False:
            self.make_input_target()
        else:
            self.load_lazy()
            self.store_lazy()
        if self.do_transformation is True:
            self.call_transform()
        self.make_samples()

    def store_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        if not os.path.exists(filename):
            dill.dump(self._create_lazy_data(), file=open(filename, "wb"), protocol=4)

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data]

    def load_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        try:
            if self.overwrite_lazy_data is True:
                os.remove(filename)
                raise FileNotFoundError
            with open(filename, "rb") as pickle_file:
                lazy_data = dill.load(pickle_file)
            self._extract_lazy(lazy_data)
            logging.debug(f"{self.station[0]}: used lazy data")
        except FileNotFoundError:
            logging.debug(f"{self.station[0]}: could not use lazy data")
            self.make_input_target()

    def _extract_lazy(self, lazy_data):
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))

    def make_input_target(self):
        vars = [*self.variables, self.target_var]
        stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars)
        data_origin = helpers.select_from_dict(self.data_origin, vars)
        data, self.meta = self.load_data(self.path, self.station, stats_per_var, self.sampling,
                                         self.store_data_locally, data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit, sampling=self.sampling)
        self.set_inputs_and_targets()

    def set_inputs_and_targets(self):
        inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
        targets = self._data.sel(
            {self.target_dim: helpers.to_list(self.target_var)})  # ToDo: is it right to expand this dim??
        self.input_data = inputs
        self.target_data = targets

    def make_samples(self):
        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)
        self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
        self.make_observation(self.target_dim, self.target_var, self.time_dim)
        self.remove_nan(self.time_dim)

    def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False,
                  data_origin: Dict = None, start=None, end=None):
        """
        Load data and meta data either from local disk (preferred) or download new data using a custom download method.

        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
        set, it is assumed that data should be saved locally.
        """
        check_path_and_create(path)
        file_name = self._set_file_name(path, station, statistics_per_var)
        meta_file = self._set_meta_file_name(path, station, statistics_per_var)
        if self.overwrite_local_data is True:
            logging.debug(f"{self.station[0]}: overwrite_local_data is true, therefore reload {file_name}")
            if os.path.exists(file_name):
                os.remove(file_name)
            if os.path.exists(meta_file):
                os.remove(meta_file)
            data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                    store_data_locally=store_data_locally, data_origin=data_origin,
                                                    time_dim=self.time_dim, target_dim=self.target_dim,
                                                    iter_dim=self.iter_dim, window_dim=self.window_dim,
                                                    era5_data_path=self._era5_data_path,
                                                    era5_file_names=self._era5_file_names,
                                                    ifs_data_path=self._ifs_data_path,
                                                    ifs_file_names=self._ifs_file_names)
            logging.debug(f"{self.station[0]}: loaded new data")
        else:
            try:
                logging.debug(f"{self.station[0]}: try to load local data from: {file_name}")
                data = xr.open_dataarray(file_name)
                meta = pd.read_csv(meta_file, index_col=0)
                self.check_station_meta(meta, station, data_origin, statistics_per_var)
                logging.debug(f"{self.station[0]}: loading finished")
            except FileNotFoundError as e:
                logging.debug(f"{self.station[0]}: {e}")
                logging.debug(f"{self.station[0]}: load new data")
                data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                        store_data_locally=store_data_locally, data_origin=data_origin,
                                                        time_dim=self.time_dim, target_dim=self.target_dim,
                                                        iter_dim=self.iter_dim, era5_data_path=self._era5_data_path,
                                                        era5_file_names=self._era5_file_names,
                                                        ifs_data_path=self._ifs_data_path,
                                                        ifs_file_names=self._ifs_file_names)
                logging.debug(f"{self.station[0]}: loading finished")
        # create slices and check for negative concentration.
        data = self._slice_prep(data, start=start, end=end)
        data = self.check_for_negative_concentrations(data)
        return data, meta

    @staticmethod
    def check_station_meta(meta, station, data_origin, statistics_per_var):
        """
        Search for the entries in the meta data and compare the values with the requested values.

        Will raise a FileNotFoundError if the values mismatch.
        """
        check_dict = {"data_origin": data_origin, "statistics_per_var": statistics_per_var}
        for (k, v) in check_dict.items():
            if v is None or k not in meta.index:
                continue
            m = ast.literal_eval(meta.at[k, station[0]])
            if not check_nested_equality(select_from_dict(m, v.keys()), v):
                logging.debug(f"{station[0]}: meta data does not agree with given request for {k}: {v} (requested) != "
                              f"{m} (local). Raise FileNotFoundError to trigger new grabbing from web.")
                raise FileNotFoundError

    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
        """
        Set all negative concentrations to zero.

        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
        #2.1 Parameters. Currently, this check is applied to "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".

        :param data: data array containing variables to check
        :param minimum: minimum value, by default this should be 0

        :return: corrected data
        """
        used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values))
        if len(used_chem_vars) > 0:
            data = data.sel({self.target_dim: used_chem_vars}).clip(min=minimum).combine_first(data)
        return data

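    # Self-contained sketch of the clip-and-merge pattern used above (synthetic
    # data, not tied to any station):
    #
    #     da = xr.DataArray([[-1.0, 2.0], [3.0, -4.0]], dims=("datetime", "variables"),
    #                       coords={"variables": ["o3", "temp"]})
    #     da.sel(variables=["o3"]).clip(min=0).combine_first(da)
    #     # negative "o3" values become 0; "temp" is not a chem var and stays untouched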

    def setup_data_path(self, data_path: str, sampling: str):
        return os.path.join(os.path.abspath(data_path), sampling)

    def shift(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray:
        """
        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).

        :param data: data set to shift
        :param dim: dimension along which the shift is applied
        :param window: number of steps to shift (corresponds to the window length)
        :param offset: use offset to move the window by as many time steps as given in offset. This can be used if the
            index time of a history element is not the last timestamp. E.g. you could use offset=23 when dealing with
            hourly data in combination with daily data (values from 00 to 23 are aggregated on 00 the same day).

        :return: shifted data
        """
        start = 1
        end = 1
        if window <= 0:
            start = window
        else:
            end = window + 1
        res = []
        _range = list(map(lambda x: x + offset, range(start, end)))
        for w in _range:
            res.append(data.shift({dim: -w}))
        window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim)
        res = xr.concat(res, dim=window_array)
        return res

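    # Worked example of the shift logic: for a history window (window=-3, offset=0),
    # _range is [-3, -2, -1, 0], so each sample at t0 stacks the values at
    # t0-3, ..., t0 along the new "window" dimension; for labels (window=3),
    # _range is [1, 2, 3], i.e. the lead times t0+1, ..., t0+3.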

    @staticmethod
    def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
        """
        Create a 1D xr.DataArray with given index name and values.

        :param index_name: name of dimension
        :param index_value: values of this dimension

        :return: this array
        """
        ind = pd.DataFrame({'val': index_value}, index=index_value)
        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
            dim=squeeze_dim, drop=True)
        res.name = index_name
        return res

    @staticmethod
    def _set_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")

    @staticmethod
    def _set_meta_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")

    def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None,
                    use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs):
        """
        Interpolate values according to different methods.

        (Copied from dataarray.interpolate_na.)

        :param dim:
            Specifies the dimension along which to interpolate.
        :param method:
            {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
            'polynomial', 'barycentric', 'krog', 'pchip', 'spline', 'akima'}, optional
            String indicating which method to use for interpolation:

            - 'linear': linear interpolation (Default). Additional keyword
              arguments are passed to ``numpy.interp``
            - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
              'polynomial': are passed to ``scipy.interpolate.interp1d``. If
              method=='polynomial', the ``order`` keyword argument must also be
              provided.
            - 'barycentric', 'krog', 'pchip', 'spline', and 'akima': use their
              respective ``scipy.interpolate`` classes.
        :param limit:
            default None
            Maximum number of consecutive NaNs to fill. Must be greater than 0
            or None for no limit.
        :param use_coordinate:
            default True
            Specifies which index to use as the x values in the interpolation
            formulated as `y = f(x)`. If False, values are treated as if
            equally-spaced along `dim`. If True, the IndexVariable `dim` is
            used. If use_coordinate is a string, it specifies the name of a
            coordinate variable to use as the index.
        :param kwargs:

        :return: xarray.DataArray
        """
        data = self.create_full_time_dim(data, dim, sampling)
        return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)

    @staticmethod
    def create_full_time_dim(data, dim, sampling):
        """Ensure the time dimension is equidistant. Dates can be missing if values were dropped beforehand."""
        start = data.coords[dim].values[0]
        end = data.coords[dim].values[-1]
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
        t = data.sel({dim: start}, drop=True)
        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        res.loc[data.coords] = data
        return res

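    # Sketch of what create_full_time_dim achieves (synthetic example; the
    # conceptually equivalent reindex is shown for brevity): a daily series with
    # 2020-01-02 missing is re-embedded into a complete index, so the gap becomes
    # NaN and can then be filled by interpolate_na.
    #
    #     idx = pd.to_datetime(["2020-01-01", "2020-01-03"])
    #     da = xr.DataArray([1.0, 3.0], dims="datetime", coords={"datetime": idx})
    #     full = da.reindex(datetime=pd.date_range(idx[0], idx[-1], freq="1D"))
    #     full.interpolate_na(dim="datetime")  # -> [1.0, 2.0, 3.0]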

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing history data.

        Shift the data window+1 times and return an xarray which has a new dimension 'window' containing the shifted
        data. This is used to represent history in the data. Results are stored in the history attribute.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: number of time steps to look back in history
            Note: window will be treated as a negative value. This should be in agreement with looking back on
            a time line. Nonetheless, positive values are allowed but are converted to their negative
            expression.
        :param dim_name_of_shift: Dimension along which the shift will be applied
        """
        window = -abs(window)
        data = self.input_data
        offset = self.window_history_offset + self.window_history_end
        self.history = self.shift(data, dim_name_of_shift, window, offset=offset)

    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
                    window: int) -> None:
        """
        Create a xr.DataArray containing labels.

        Labels are defined as the consecutive target values (t+1, ..., t+n) following the current time step t. Sets
        the label attribute.

        :param dim_name_of_target: Name of dimension which contains the target variable
        :param target_var: Name of target variable in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        :param window: lead time of label
        """
        window = abs(window)
        data = self.target_data
        self.label = self.shift(data, dim_name_of_shift, window)

    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing observations.

        Observations are defined as the value of the current time step t. Sets the observation attribute.

        :param dim_name_of_target: Name of dimension which contains the observation variable
        :param target_var: Name of observation variable(s) in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        """
        data = self.target_data
        self.observation = self.shift(data, dim_name_of_shift, 0)

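    # Putting the three builders together (with the class defaults, history size 13
    # and lead time 3): for each time step t0 the history spans t0-13, ..., t0, the
    # label spans t0+1, ..., t0+3, and the observation is the single value at t0;
    # remove_nan() below then keeps only those t0 where all three are complete.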

    def remove_nan(self, dim: str) -> None:
        """
        Remove all slices along dim which contain NaNs in history, label, or observation.

        This is done to present only a full matrix to keras.fit. Updates the history, label, and observation
        attributes.

        :param dim: dimension along which the removal is performed.
        """
        intersect = []
        if (self.history is not None) and (self.label is not None):
            non_nan_history = self.history.dropna(dim=dim)
            non_nan_label = self.label.dropna(dim=dim)
            non_nan_observation = self.observation.dropna(dim=dim)
            if non_nan_label.coords[dim].shape[0] == 0:
                raise ValueError(f'self.label consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_history.coords[dim].shape[0] == 0:
                raise ValueError(f'self.history consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_observation.coords[dim].shape[0] == 0:
                raise ValueError(f'self.observation consists of NaNs only - station {self.station} is therefore dropped')
            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
                                                non_nan_observation.coords[dim].values))

        if len(intersect) < max(self.min_length, 1):
            self.history = None
            self.label = None
            self.observation = None
        else:
            self.history = self.history.sel({dim: intersect})
            self.label = self.label.sel({dim: intersect})
            self.observation = self.observation.sel({dim: intersect})

    def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray:
        """
        Set start and end date for slicing and execute self._slice().

        :param data: data to slice
        :param start: start date of slice; defaults to the first time step of data
        :param end: end date of slice; defaults to the last time step of data

        :return: sliced data
        """
        start = start if start is not None else data.coords[self.time_dim][0].values
        end = end if end is not None else data.coords[self.time_dim][-1].values
        return self._slice(data, start, end, self.time_dim)

    @staticmethod
    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
        """
        Slice through a given data_item (for example select only values of 2011).

        :param data: data to slice
        :param start: start date of slice
        :param end: end date of slice
        :param coord: name of axis to slice

        :return: sliced data
        """
        return data.loc[{coord: slice(str(start), str(end))}]

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Set up transformation by extracting all relevant information.

        * Either return (None, None) if the given transformation argument is None,
        * or return a deep copy of the given dict twice (the same settings are then used for inputs and targets),
        * or return a deep copy of the given 2-tuple, holding separate settings for inputs and targets.
        """
        if transformation is None:
            return None, None
        elif isinstance(transformation, dict):
            return copy.deepcopy(transformation), copy.deepcopy(transformation)
        elif isinstance(transformation, tuple) and len(transformation) == 2:
            return copy.deepcopy(transformation)
        else:
            raise NotImplementedError("Cannot handle this.")

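    # The three accepted shapes, as a sketch (hypothetical settings dicts):
    #
    #     dh.setup_transformation(None)                    # -> (None, None)
    #     dh.setup_transformation({"o3": {"method": "standardise"}})
    #     # -> two independent deep copies of the same dict (inputs, targets)
    #     dh.setup_transformation((opts_in, opts_target))  # -> deep copy of the tuple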

    @staticmethod
    def check_inverse_transform_params(method: str, mean=None, std=None, min=None, max=None) -> None:
        """
        Support the inverse_transform method.

        Validate if all required statistics are available for the given method. E.g. centring requires the mean only,
        whereas standardisation requires mean and standard deviation. Will raise an AttributeError on missing
        requirements.

        :param mean: data with all mean values
        :param std: data with all standard deviation values
        :param method: name of transformation method
        """
        msg = ""
        if method in ['standardise', 'centre'] and mean is None:
            msg += "mean, "
        if method == 'standardise' and std is None:
            msg += "std, "
        if method == "min_max" and min is None:
            msg += "min, "
        if method == "min_max" and max is None:
            msg += "max, "
        if len(msg) > 0:
            raise AttributeError(f"Inverse transform {method} can not be executed because the following is None: {msg}")

    def inverse_transform(self, data_in, opts, transformation_dim) -> xr.DataArray:
        """
        Perform inverse transformation.

        Will raise an AssertionError if no transformation was performed before. Checks first if all required
        statistics are available for the inverse transformation. The class attributes data, mean, and std are
        overwritten by new data afterwards. Thereby, mean, std, and the private transform method are set to None to
        indicate that the current data is not transformed.
        """

        def f_inverse(data, method, mean=None, std=None, min=None, max=None, feature_range=None):
            if method == "standardise":
                return statistics.standardise_inverse(data, mean, std)
            elif method == "centre":
                return statistics.centre_inverse(data, mean)
            elif method == "min_max":
                return statistics.min_max_inverse(data, min, max, feature_range)
            elif method == "log":
                return statistics.log_inverse(data, mean, std)
            else:
                raise NotImplementedError

        transformed_values = []
        squeeze = False
        if transformation_dim in data_in.coords:
            if transformation_dim not in data_in.dims:
                data_in = data_in.expand_dims(transformation_dim)
                squeeze = True
        else:
            raise IndexError(f"Could not find given dimension: {transformation_dim}. Available is: {data_in.coords}")
        for var in data_in.variables.values:
            data_var = data_in.sel(**{transformation_dim: [var]})
            var_opts = opts.get(var, {})
            _method = var_opts.get("method", None)
            if _method is None:
                raise AssertionError(f"Inverse transformation method is not set for {var}.")
            self.check_inverse_transform_params(**var_opts)
            values = f_inverse(data_var, **var_opts)
            transformed_values.append(values)
        res = xr.concat(transformed_values, dim=transformation_dim)
        return res.squeeze(transformation_dim) if squeeze else res

    def apply_transformation(self, data, base=None, dim=0, inverse=False):
        """
        Apply transformation on external data. Specify whether the transformation should be based on parameters
        related to input or target data using `base`. This method can also apply the inverse transformation.

        :param data: external data to transform
        :param base: "input" (or 0) to use input transformation options, "target" (or 1) for target options
        :param dim: dimension along which the transformation is applied
        :param inverse: set to True to apply the inverse transformation
        :return: transformed data
        """
        if base in ["target", 1]:
            pos = 1
        elif base in ["input", 0]:
            pos = 0
        else:
            raise ValueError("apply transformation requires a reference for transformation options. Please specify if "
                             "you want to use input or target transformation using the parameter 'base'. Given was: " +
                             str(base))
        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
                              transformation_dim=self.target_dim)

    def _hash_list(self):
        return sorted(list(set(self._hash)))

    def _get_hash(self):
        hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
        return hashlib.md5(hash).hexdigest()
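
# Illustrative sketch of the lazy-cache key built by _get_hash (hypothetical
# attribute values): the md5 digest of all _hash attributes' string
# representations, concatenated in sorted attribute-name order, e.g.
#
#     hashlib.md5("".join(["['DEBW107']", "daily"]).encode()).hexdigest()
#
# so two handlers with identical settings resolve to the same pickle file.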