Coverage for mlair/data_handler/data_handler_single_station.py: 65% (402 statements)


1"""Data Preparation class to handle data processing for machine learning.""" 

2 

3__author__ = 'Lukas Leufen, Felix Kleinert' 

4__date__ = '2020-07-20' 

5 

6import copy 

7import datetime as dt 

8import gc 

9 

10import dill 

11import hashlib 

12import logging 

13import os 

14from functools import reduce, partial 

15from typing import Union, List, Iterable, Tuple, Dict, Optional 

16 

17import numpy as np 

18import pandas as pd 

19import xarray as xr 

20 

21from mlair.configuration import check_path_and_create 

22from mlair import helpers 

23from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict 

24from mlair.data_handler.abstract_data_handler import AbstractDataHandler 

25from mlair.helpers import data_sources 

26 

27# define a more general date type for type hinting 

28date = Union[dt.date, dt.datetime] 

29str_or_list = Union[str, List[str]] 

30number = Union[float, int] 

31num_or_list = Union[number, List[number]] 

32data_or_none = Union[xr.DataArray, None] 

33 

34 

class DataHandlerSingleStation(AbstractDataHandler):
    """
    :param window_history_offset: used to shift t0 according to the specified value.
    :param window_history_end: used to set the last time step that is used to create a sample. A negative value
        indicates that not all values up to t0 are used, a positive value indicates usage of values at t>t0. Default
        is 0.
    """

    DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                            'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                            'pblheight': 'maximum'}
    DEFAULT_WINDOW_LEAD_TIME = 3
    DEFAULT_WINDOW_HISTORY_SIZE = 13
    DEFAULT_WINDOW_HISTORY_OFFSET = 0
    DEFAULT_WINDOW_HISTORY_END = 0
    DEFAULT_TIME_DIM = "datetime"
    DEFAULT_TARGET_VAR = "o3"
    DEFAULT_TARGET_DIM = "variables"
    DEFAULT_ITER_DIM = "Stations"
    DEFAULT_WINDOW_DIM = "window"
    DEFAULT_SAMPLING = "daily"
    DEFAULT_INTERPOLATION_LIMIT = 0
    DEFAULT_INTERPOLATION_METHOD = "linear"
    chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane",
                 "so2", "toluene"]

    _hash = ["station", "statistics_per_var", "data_origin", "sampling", "target_dim", "target_var", "time_dim",
             "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time",
             "interpolation_limit", "interpolation_method", "variables", "window_history_end"]

    def __init__(self, station, data_path, statistics_per_var=None, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
                 target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
                 iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM,
                 window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET,
                 window_history_end=DEFAULT_WINDOW_HISTORY_END, window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
                 interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT,
                 interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                 overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
                 min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
                 lazy_preprocessing: bool = False, overwrite_lazy_data=False, **kwargs):
        super().__init__()
        self.station = helpers.to_list(station)
        self.path = self.setup_data_path(data_path, sampling)
        self.lazy = lazy_preprocessing
        self.lazy_path = None
        if self.lazy is True:
            self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__)
            check_path_and_create(self.lazy_path)
        self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT
        self.data_origin = data_origin
        self.do_transformation = transformation is not None
        self.input_data, self.target_data = None, None
        self._transformation = self.setup_transformation(transformation)

        self.sampling = sampling
        self.target_dim = target_dim
        self.target_var = target_var
        self.time_dim = time_dim
        self.iter_dim = iter_dim
        self.window_dim = window_dim
        self.window_history_size = window_history_size
        self.window_history_offset = window_history_offset
        self.window_history_end = window_history_end
        self.window_lead_time = window_lead_time

        self.interpolation_limit = interpolation_limit
        self.interpolation_method = interpolation_method

        self.overwrite_local_data = overwrite_local_data
        self.overwrite_lazy_data = True if self.overwrite_local_data is True else overwrite_lazy_data
        self.store_data_locally = store_data_locally
        self.min_length = min_length
        self.start = start
        self.end = end

        # internal
        self._data: xr.DataArray = None  # loaded raw data
        self.meta = None
        # use self.statistics_per_var here so that the DEFAULT_VAR_ALL_DICT fallback also applies
        self.variables = sorted(list(self.statistics_per_var.keys())) if variables is None else variables
        self.history = None
        self.label = None
        self.observation = None

        # create samples
        self.setup_samples()
        self.clean_up()
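    # Illustrative usage sketch (not executed; the station id, path, and variable
    # selection below are hypothetical and only show the expected call shape):
    #
    #     dh = DataHandlerSingleStation("DEBW107", data_path="/tmp/data",
    #                                   statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
    #                                   window_history_size=7, window_lead_time=2)
    #     X, Y = dh.get_X(), dh.get_Y()
    #
    # After construction, samples are already created (setup_samples) and the raw
    # data is released again (clean_up).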

    def clean_up(self):
        self._data = None
        self.input_data = None
        self.target_data = None
        gc.collect()

    def __str__(self):
        return self.station[0]

    def __len__(self):
        assert len(self.get_X()) == len(self.get_Y())
        return len(self.get_X())

    @property
    def shape(self):
        return self._data.shape, self.get_X().shape, self.get_Y().shape

    def __repr__(self):
        return f"StationPrep(station={self.station}, data_path='{self.path}', data_origin={self.data_origin}, " \
               f"statistics_per_var={self.statistics_per_var}, " \
               f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
               f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \
               f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \
               f"interpolation_method='{self.interpolation_method}', overwrite_local_data={self.overwrite_local_data})"

    def get_transposed_history(self) -> xr.DataArray:
        """Return history.

        :return: history with dimensions datetime, window, Stations, variables.
        """
        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim).copy()

    def get_transposed_label(self) -> xr.DataArray:
        """Return label.

        :return: label with dimensions datetime*, window*, Stations, variables.
        """
        return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim).copy()
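    # Sketch of the resulting shapes (sizes hypothetical, assuming default dim
    # names): with window_history_size=13 and window_lead_time=3, the transposed
    # history has dims (datetime, window, Stations, variables), e.g. shape
    # (n_samples, 14, 1, n_variables), while the transposed label is squeezed to
    # (datetime, window), e.g. shape (n_samples, 3).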

    def get_X(self, **kwargs):
        return self.get_transposed_history().sel({self.target_dim: self.variables})

    def get_Y(self, **kwargs):
        return self.get_transposed_label()

    def get_coordinates(self):
        try:
            coords = self.meta.loc[["station_lon", "station_lat"]].astype(float)
            coords = coords.rename(index={"station_lon": "lon", "station_lat": "lat"})
        except KeyError:
            coords = self.meta.loc[["lon", "lat"]].astype(float)
        return coords.to_dict()[str(self)]
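    # get_coordinates returns a plain dict keyed by the renamed meta index, for
    # example (values hypothetical):
    #
    #     dh.get_coordinates()  # -> {"lon": 9.2, "lat": 48.8}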

    def call_transform(self, inverse=False):
        opts_input = self._transformation[0]
        self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
                                                     opts=opts_input, transformation_dim=self.target_dim)
        opts_target = self._transformation[1]
        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
                                                       opts=opts_target, transformation_dim=self.target_dim)
        self._transformation = (opts_input, opts_target)

    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
                  transformation_dim=DEFAULT_TARGET_DIM):
        """
        Transform data according to given transformation settings.

        This function transforms an xarray.DataArray (along dim) or pandas.DataFrame (along axis) either with mean=0
        and std=1 (`method=standardise`), centres the data with mean=0 and no change in data scale (`method=centre`),
        rescales it to a given feature range (`method=min_max`), or applies a log transformation (`method=log`).
        Furthermore, this sets an internal instance attribute for later inverse transformation. This method will raise
        an AssertionError if an internal transform method was already set ('inverse=False') or if the internal
        transform method, internal mean and internal standard deviation weren't set ('inverse=True').

        :param string/int dim: This param is not used for inverse transformation.
            | for xarray.DataArray as string: name of dimension which should be standardised
            | for pandas.DataFrame as int: axis of dimension which should be standardised
        :param inverse: Switch between transformation and inverse transformation.

        :return: transformed data as xarray.DataArray (or pandas.DataFrame) together with the updated transformation
            options per variable
        """

        def f(data, method="standardise", feature_range=None):
            if method == "standardise":
                return statistics.standardise(data, dim)
            elif method == "centre":
                return statistics.centre(data, dim)
            elif method == "min_max":
                kwargs = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max(data, dim, **kwargs)
            elif method == "log":
                return statistics.log(data, dim)
            else:
                raise NotImplementedError

        def f_apply(data, method, **kwargs):
            for k, v in kwargs.items():
                if not (isinstance(v, xr.DataArray) or v is None):
                    _, opts = statistics.min_max(data, dim)
                    helper = xr.ones_like(opts['min'])
                    kwargs[k] = helper * v
            mean = kwargs.pop('mean', None)
            std = kwargs.pop('std', None)
            min = kwargs.pop('min', None)
            max = kwargs.pop('max', None)
            feature_range = kwargs.pop('feature_range', None)

            if method == "standardise":
                return statistics.standardise_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            elif method == "centre":
                return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
            elif method == "min_max":
                kws = {"feature_range": feature_range} if feature_range is not None else {}
                return statistics.min_max_apply(data, min, max, **kws), {"min": min, "max": max, "method": method,
                                                                         "feature_range": feature_range}
            elif method == "log":
                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
            else:
                raise NotImplementedError

        opts = opts or {}
        opts_updated = {}
        if not inverse:
            transformed_values = []
            for var in data_in.variables.values:
                data_var = data_in.sel(**{transformation_dim: [var]})
                var_opts = opts.get(var, {})
                _apply = (var_opts.get("mean", None) is not None) or (var_opts.get("min") is not None)
                values, new_var_opts = locals()["f_apply" if _apply else "f"](data_var, **var_opts)
                opts_updated[var] = copy.deepcopy(new_var_opts)
                transformed_values.append(values)
            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
        else:
            return self.inverse_transform(data_in, opts, transformation_dim)
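    # Sketch of the `opts` bookkeeping (values hypothetical, exact statistics come
    # from the statistics helpers): after a forward call, the returned options map
    # every variable to the statistics needed for the inverse transformation, e.g.
    #
    #     opts = {"o3": {"mean": <xr.DataArray>, "std": <xr.DataArray>, "method": "standardise"},
    #             "temp": {"mean": <xr.DataArray>, "method": "centre"}}
    #
    # Passing such a dict back in via `opts` triggers the *_apply branch, i.e. the
    # stored statistics are reused instead of being recomputed.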

    @TimeTrackingWrapper
    def setup_samples(self):
        """
        Set up samples. This method prepares and creates samples X and labels Y.
        """
        if self.lazy is False:
            self.make_input_target()
        else:
            self.load_lazy()
            self.store_lazy()
        if self.do_transformation is True:
            self.call_transform()
        self.make_samples()

    def store_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        if not os.path.exists(filename):
            # use a context manager so the file handle is closed after dumping
            with open(filename, "wb") as pickle_file:
                dill.dump(self._create_lazy_data(), file=pickle_file, protocol=4)

    def _create_lazy_data(self):
        return [self._data, self.meta, self.input_data, self.target_data]

    def load_lazy(self):
        hash = self._get_hash()
        filename = os.path.join(self.lazy_path, hash + ".pickle")
        try:
            if self.overwrite_lazy_data is True:
                os.remove(filename)
                raise FileNotFoundError
            with open(filename, "rb") as pickle_file:
                lazy_data = dill.load(pickle_file)
            self._extract_lazy(lazy_data)
            logging.debug(f"{self.station[0]}: used lazy data")
        except FileNotFoundError:
            logging.debug(f"{self.station[0]}: could not use lazy data")
            self.make_input_target()

    def _extract_lazy(self, lazy_data):
        _data, self.meta, _input_data, _target_data = lazy_data
        f_prep = partial(self._slice_prep, start=self.start, end=self.end)
        self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))

    def make_input_target(self):
        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
                                         self.store_data_locally, self.data_origin, self.start, self.end)
        self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                      limit=self.interpolation_limit, sampling=self.sampling)
        self.set_inputs_and_targets()

    def set_inputs_and_targets(self):
        inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
        targets = self._data.sel(
            {self.target_dim: helpers.to_list(self.target_var)})  # ToDo: is it right to expand this dim?
        self.input_data = inputs
        self.target_data = targets

    def make_samples(self):
        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)  # ToDo: stopped here
        self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
        self.make_observation(self.target_dim, self.target_var, self.time_dim)
        self.remove_nan(self.time_dim)

    def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False,
                  data_origin: Dict = None, start=None, end=None):
        """
        Load data and meta data either from local disk (preferred) or download new data using a custom download method.

        Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both
        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
        set, it is assumed that data should be saved locally.
        """
        check_path_and_create(path)
        file_name = self._set_file_name(path, station, statistics_per_var)
        meta_file = self._set_meta_file_name(path, station, statistics_per_var)
        if self.overwrite_local_data is True:
            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}")
            if os.path.exists(file_name):
                os.remove(file_name)
            if os.path.exists(meta_file):
                os.remove(meta_file)
            data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                            store_data_locally=store_data_locally, data_origin=data_origin,
                                            time_dim=self.time_dim, target_dim=self.target_dim, iter_dim=self.iter_dim)
            logging.debug("loaded new data")
        else:
            try:
                logging.debug(f"try to load local data from: {file_name}")
                data = xr.open_dataarray(file_name)
                meta = pd.read_csv(meta_file, index_col=0)
                self.check_station_meta(meta, station, data_origin, statistics_per_var)
                logging.debug("loading finished")
            except FileNotFoundError as e:
                logging.debug(e)
                logging.debug("load new data")
                data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                store_data_locally=store_data_locally, data_origin=data_origin,
                                                time_dim=self.time_dim, target_dim=self.target_dim,
                                                iter_dim=self.iter_dim)
                logging.debug("loading finished")
        # create slices and check for negative concentrations
        data = self._slice_prep(data, start=start, end=end)
        data = self.check_for_negative_concentrations(data)
        return data, meta

    def download_data(self, file_name: str, meta_file: str, station, statistics_per_var, sampling,
                      store_data_locally=True, data_origin: Dict = None, time_dim=DEFAULT_TIME_DIM,
                      target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM) -> Tuple[xr.DataArray, pd.DataFrame]:
        """
        Download data from the TOAR database using the JOIN interface or load local era5 data.

        Data is transformed to an xarray dataset. If class attribute store_data_locally is true, data is additionally
        stored locally using given names for file and meta file.

        :param file_name: name of file to save data to (containing full path)
        :param meta_file: name of the meta data file (also containing full path)

        :return: downloaded data and its meta data
        """
        df_all = {}
        df_era5, df_toar = None, None
        meta_era5, meta_toar = None, None
        if data_origin is not None:
            era5_origin = filter_dict_by_value(data_origin, "era5", True)
            era5_stats = select_from_dict(statistics_per_var, era5_origin.keys())
            toar_origin = filter_dict_by_value(data_origin, "era5", False)
            toar_stats = select_from_dict(statistics_per_var, era5_origin.keys(), filter_cond=False)
            assert len(era5_origin) + len(toar_origin) == len(data_origin)
            assert len(era5_stats) + len(toar_stats) == len(statistics_per_var)
        else:
            era5_origin, toar_origin = None, None
            era5_stats, toar_stats = statistics_per_var, statistics_per_var

        # load data
        if era5_origin is not None and len(era5_stats) > 0:
            # load era5 data
            df_era5, meta_era5 = data_sources.era5.load_era5(station_name=station, stat_var=era5_stats,
                                                             sampling=sampling, data_origin=era5_origin)
        if toar_origin is None or len(toar_stats) > 0:
            # load combined data from toar-data (v2 & v1)
            df_toar, meta_toar = data_sources.toar_data.download_toar(station=station, toar_stats=toar_stats,
                                                                      sampling=sampling, data_origin=toar_origin)

        if df_era5 is None and df_toar is None:
            raise data_sources.toar_data.EmptyQueryResult("No data available for era5 and toar-data")

        df = pd.concat([df_era5, df_toar], axis=1, sort=True)
        if meta_era5 is not None and meta_toar is not None:
            meta = meta_era5.combine_first(meta_toar)
        else:
            meta = meta_era5 if meta_era5 is not None else meta_toar
        meta.loc["data_origin"] = str(data_origin)
        meta.loc["statistics_per_var"] = str(statistics_per_var)

        df_all[station[0]] = df
        # convert df_all to xarray
        xarr = {k: xr.DataArray(v, dims=[time_dim, target_dim]) for k, v in df_all.items()}
        xarr = xr.Dataset(xarr).to_array(dim=iter_dim)
        if store_data_locally is True:
            # save locally as nc/csv file
            xarr.to_netcdf(path=file_name)
            meta.to_csv(meta_file)
        return xarr, meta

    @staticmethod
    def check_station_meta(meta, station, data_origin, statistics_per_var):
        """
        Search for the entries in meta data and compare the values with the requested values.

        Will raise a FileNotFoundError if the values mismatch.
        """
        check_dict = {"data_origin": str(data_origin), "statistics_per_var": str(statistics_per_var)}
        for (k, v) in check_dict.items():
            if v is None or k not in meta.index:
                continue
            if meta.at[k, station[0]] != v:
                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                              f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
                              f"download from the web.")
                raise FileNotFoundError

    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
        """
        Set all negative concentrations to zero.

        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
        (#2.1 Parameters). Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".

        :param data: data array containing variables to check
        :param minimum: minimum allowed value (default 0)

        :return: corrected data
        """
        used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values))
        if len(used_chem_vars) > 0:
            data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
        return data

    def setup_data_path(self, data_path: str, sampling: str):
        return os.path.join(os.path.abspath(data_path), sampling)

    def shift(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray:
        """
        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).

        :param data: data set to shift
        :param dim: dimension along which the shift is applied
        :param window: number of steps to shift (corresponds to the window length)
        :param offset: use offset to move the window by as many time steps as given in offset. This can be used if the
            index time of a history element is not the last timestamp. E.g. you could use offset=23 when dealing with
            hourly data in combination with daily data (values from 00 to 23 are aggregated on 00 the same day).

        :return: shifted data
        """
        start = 1
        end = 1
        if window <= 0:
            start = window
        else:
            end = window + 1
        res = []
        _range = list(map(lambda x: x + offset, range(start, end)))
        for w in _range:
            res.append(data.shift({dim: -w}))
        window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim)
        res = xr.concat(res, dim=window_array)
        return res
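    # Worked example of the window index logic (no real data involved): for a
    # history window, window=-3 and offset=0 yield range(-3, 1), i.e. shifts
    # [-3, -2, -1, 0]; for labels, window=3 yields range(1, 4), i.e. shifts
    # [1, 2, 3]. Each shift becomes one slice along the new "window" dimension:
    #
    #     list(range(-3, 1))     # -> [-3, -2, -1, 0]  (history)
    #     list(range(1, 3 + 1))  # -> [1, 2, 3]        (labels)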

    @staticmethod
    def create_index_array(index_name: str, index_value: Iterable[int], squeeze_dim: str) -> xr.DataArray:
        """
        Create a 1D xr.DataArray with given index name and value.

        :param index_name: name of dimension
        :param index_value: values of this dimension

        :return: this array
        """
        ind = pd.DataFrame({'val': index_value}, index=index_value)
        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
            dim=squeeze_dim, drop=True)
        res.name = index_name
        return res

    @staticmethod
    def _set_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")

    @staticmethod
    def _set_meta_file_name(path, station, statistics_per_var):
        all_vars = sorted(statistics_per_var.keys())
        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
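    # Example of the resulting file names (station id and variables hypothetical):
    # for station ["DEBW107"] and statistics_per_var keys {"o3", "temp"}, the data
    # file becomes "<path>/DEBW107_o3_temp.nc" and the meta file
    # "<path>/DEBW107_o3_temp_meta.csv"; sorting the variables keeps the names
    # stable regardless of dict order.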

    def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None,
                    use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs):
        """
        Interpolate values according to different methods.

        (Copied from dataarray.interpolate_na)

        :param dim:
            Specifies the dimension along which to interpolate.
        :param method:
            {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
            'polynomial', 'barycentric', 'krog', 'pchip',
            'spline', 'akima'}, optional
            String indicating which method to use for interpolation:

            - 'linear': linear interpolation (Default). Additional keyword
              arguments are passed to ``numpy.interp``
            - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
              'polynomial': are passed to ``scipy.interpolate.interp1d``. If
              method=='polynomial', the ``order`` keyword argument must also be
              provided.
            - 'barycentric', 'krog', 'pchip', 'spline', and 'akima': use their
              respective ``scipy.interpolate`` classes.
        :param limit:
            default None
            Maximum number of consecutive NaNs to fill. Must be greater than 0
            or None for no limit.
        :param use_coordinate:
            default True
            Specifies which index to use as the x values in the interpolation
            formulated as `y = f(x)`. If False, values are treated as if
            equally-spaced along `dim`. If True, the IndexVariable `dim` is
            used. If use_coordinate is a string, it specifies the name of a
            coordinate variable to use as the index.
        :param kwargs:

        :return: xarray.DataArray
        """
        data = self.create_full_time_dim(data, dim, sampling)
        return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)

    @staticmethod
    def create_full_time_dim(data, dim, sampling):
        """Ensure the time dimension is equidistant. Dates may be missing if values have been dropped."""
        start = data.coords[dim].values[0]
        end = data.coords[dim].values[-1]
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
        t = data.sel({dim: start}, drop=True)
        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        res.loc[data.coords] = data
        return res
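    # Illustrative behaviour (dates hypothetical): a daily series covering
    # 2011-01-01, 2011-01-02, and 2011-01-04 is re-indexed onto the full range
    # 2011-01-01..2011-01-04, so 2011-01-03 reappears as a NaN-only time step
    # that interpolate() can then fill.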

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create an xr.DataArray containing history data.

        Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
        data. This is used to represent history in the data. Results are stored in the history attribute.

        :param dim_name_of_inputs: Name of dimension which contains the input variables
        :param window: number of time steps to look back in history
            Note: window will be treated as a negative value. This should be in agreement with looking back on
            a time line. Nonetheless, positive values are allowed but they are converted to their negative
            counterpart.
        :param dim_name_of_shift: Dimension along which the shift will be applied
        """
        window = -abs(window)
        data = self.input_data
        offset = self.window_history_offset + self.window_history_end
        self.history = self.shift(data, dim_name_of_shift, window, offset=offset)

    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
                    window: int) -> None:
        """
        Create an xr.DataArray containing labels.

        Labels are defined as the consecutive target values (t+1, ..., t+n) following the current time step t. Sets
        the label attribute.

        :param dim_name_of_target: Name of dimension which contains the target variable
        :param target_var: Name of target variable in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        :param window: lead time of label
        """
        window = abs(window)
        data = self.target_data
        self.label = self.shift(data, dim_name_of_shift, window)

    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
        """
        Create an xr.DataArray containing observations.

        Observations are defined as the value of the current time step t. Sets the observation attribute.

        :param dim_name_of_target: Name of dimension which contains the observation variable
        :param target_var: Name of observation variable(s) in 'dimension'
        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
        """
        data = self.target_data
        self.observation = self.shift(data, dim_name_of_shift, 0)
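    # Putting the three sample parts together (window sizes hypothetical): with
    # window_history_size=2 and window_lead_time=2, each time step t gets
    #
    #     history:     window indices [-2, -1, 0]  (inputs up to and including t)
    #     observation: window index   [0]          (target value at t)
    #     label:       window indices [1, 2]       (target values to predict)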

    def remove_nan(self, dim: str) -> None:
        """
        Remove all slices along dim which contain NaNs in history, label, or observation.

        This is done to present only a full matrix to keras.fit. Updates the history, label, and observation
        attributes.

        :param dim: dimension along which the removal is performed
        """
        intersect = []
        if (self.history is not None) and (self.label is not None):
            non_nan_history = self.history.dropna(dim=dim)
            non_nan_label = self.label.dropna(dim=dim)
            non_nan_observation = self.observation.dropna(dim=dim)
            if non_nan_label.coords[dim].shape[0] == 0:
                raise ValueError(f'self.label consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_history.coords[dim].shape[0] == 0:
                raise ValueError(f'self.history consists of NaNs only - station {self.station} is therefore dropped')
            if non_nan_observation.coords[dim].shape[0] == 0:
                raise ValueError(f'self.observation consists of NaNs only - station {self.station} is therefore '
                                 f'dropped')
            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
                                                non_nan_observation.coords[dim].values))

        if len(intersect) < max(self.min_length, 1):
            self.history = None
            self.label = None
            self.observation = None
        else:
            self.history = self.history.sel({dim: intersect})
            self.label = self.label.sel({dim: intersect})
            self.observation = self.observation.sel({dim: intersect})

    def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray:
        """
        Set start and end date for slicing and execute self._slice().

        :param data: data to slice
        :param start: start date of slice (defaults to the first time stamp in data)
        :param end: end date of slice (defaults to the last time stamp in data)

        :return: sliced data
        """
        start = start if start is not None else data.coords[self.time_dim][0].values
        end = end if end is not None else data.coords[self.time_dim][-1].values
        return self._slice(data, start, end, self.time_dim)

    @staticmethod
    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
        """
        Slice through a given data_item (for example select only values of 2011).

        :param data: data to slice
        :param start: start date of slice
        :param end: end date of slice
        :param coord: name of axis to slice

        :return: sliced data
        """
        return data.loc[{coord: slice(str(start), str(end))}]
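    # The slice is label based and inclusive on both ends, so selecting only the
    # year 2011 from a hypothetical array `arr` with a "datetime" coordinate reads:
    #
    #     sub = DataHandlerSingleStation._slice(arr, "2011-01-01", "2011-12-31", "datetime")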

    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
        """
        Set up transformation by extracting all relevant information.

        * Either return (None, None) if no transformation is given,
        * or return a deep copy of the given options twice (once for inputs, once for targets) if transformation is a
          single dict,
        * or return a deep copy of the given pair if transformation is a tuple of two option sets (inputs, targets).
        """
        if transformation is None:
            return None, None
        elif isinstance(transformation, dict):
            return copy.deepcopy(transformation), copy.deepcopy(transformation)
        elif isinstance(transformation, tuple) and len(transformation) == 2:
            return copy.deepcopy(transformation)
        else:
            raise NotImplementedError("Cannot handle this.")
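    # Accepted forms at a glance (option contents hypothetical, `dh` an existing
    # instance):
    #
    #     dh.setup_transformation(None)                               # -> (None, None)
    #     dh.setup_transformation({"o3": {"method": "standardise"}})  # -> (opts, opts) as independent copies
    #     dh.setup_transformation((input_opts, target_opts))          # -> deep copy of the pair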

    @staticmethod
    def check_inverse_transform_params(method: str, mean=None, std=None, min=None, max=None) -> None:
        """
        Support the inverse_transform method.

        Validate if all required statistics are available for the given method. E.g. centering requires mean only,
        whereas standardisation requires mean and standard deviation. Will raise an AttributeError on missing
        requirements.

        :param method: name of transformation method
        :param mean: data with all mean values
        :param std: data with all standard deviation values
        :param min: data with all minimum values
        :param max: data with all maximum values
        """
        msg = ""
        if method in ['standardise', 'centre'] and mean is None:
            msg += "mean, "
        if method == 'standardise' and std is None:
            msg += "std, "
        if method == "min_max" and min is None:
            msg += "min, "
        if method == "min_max" and max is None:
            msg += "max, "
        if len(msg) > 0:
            raise AttributeError(f"Inverse transform {method} can not be executed because the following is None: {msg}")

    def inverse_transform(self, data_in, opts, transformation_dim) -> xr.DataArray:
        """
        Perform inverse transformation.

        Will raise an AssertionError if no transformation was performed before. Checks first if all required
        statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by
        new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate that
        the current data is not transformed.
        """

        def f_inverse(data, method, mean=None, std=None, min=None, max=None, feature_range=None):
            if method == "standardise":
                return statistics.standardise_inverse(data, mean, std)
            elif method == "centre":
                return statistics.centre_inverse(data, mean)
            elif method == "min_max":
                return statistics.min_max_inverse(data, min, max, feature_range)
            elif method == "log":
                return statistics.log_inverse(data, mean, std)
            else:
                raise NotImplementedError

        transformed_values = []
        squeeze = False
        if transformation_dim in data_in.coords:
            if transformation_dim not in data_in.dims:
                data_in = data_in.expand_dims(transformation_dim)
                squeeze = True
        else:
            raise IndexError(f"Could not find given dimension: {transformation_dim}. Available is: {data_in.coords}")
        for var in data_in.variables.values:
            data_var = data_in.sel(**{transformation_dim: [var]})
            var_opts = opts.get(var, {})
            _method = var_opts.get("method", None)
            if _method is None:
                raise AssertionError(f"Inverse transformation method is not set for {var}.")
            self.check_inverse_transform_params(**var_opts)
            values = f_inverse(data_var, **var_opts)
            transformed_values.append(values)
        res = xr.concat(transformed_values, dim=transformation_dim)
        return res.squeeze(transformation_dim) if squeeze else res

    def apply_transformation(self, data, base=None, dim=0, inverse=False):
        """
        Apply transformation on external data. Specify if transformation should be based on parameters related to
        input or target data using `base`. This method can also apply inverse transformation.

        :param data: external data to transform
        :param base: reference to the transformation options, either "input" (or 0) or "target" (or 1)
        :param dim: dimension (name or axis) along which the transformation is applied
        :param inverse: set to True to apply the inverse transformation
        :return: transformed data
        """
        if base in ["target", 1]:
            pos = 1
        elif base in ["input", 0]:
            pos = 0
        else:
            raise ValueError("apply transformation requires a reference for transformation options. Please specify "
                             "if you want to use input or target transformation using the parameter 'base'. Given "
                             "was: " + str(base))
        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
                              transformation_dim=self.target_dim)
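    # Hedged usage sketch (arrays `external_obs` and `prediction` are
    # hypothetical): transform new data with the statistics fitted on this
    # station's target data, then map a prediction back to the original scale.
    # Note that the forward call returns a (data, opts) pair, while the inverse
    # call returns the array directly:
    #
    #     scaled, _opts = dh.apply_transformation(external_obs, base="target", dim="datetime")
    #     orig = dh.apply_transformation(prediction, base="target", dim="datetime", inverse=True)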

    def _hash_list(self):
        return sorted(list(set(self._hash)))

    def _get_hash(self):
        hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
        return hashlib.md5(hash).hexdigest()
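    # How the lazy-data cache key is built (sketch, attribute values
    # hypothetical): the string representations of all attributes listed in
    # _hash are concatenated in sorted attribute-name order and digested, e.g.
    #
    #     key = hashlib.md5("".join([str(v) for v in ("era5", 0, "linear")]).encode()).hexdigest()
    #
    # so any change to station, variables, window sizes, etc. yields a new
    # pickle file name under lazy_path.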