Coverage for mlair/data_handler/default_data_handler.py: 69%

254 statements  

coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

__author__ = 'Lukas Leufen'
__date__ = '2020-09-21'

import copy
import inspect
import gc
import logging
import os
import pickle
import random
import dill
import shutil
from functools import reduce
from typing import Tuple, Union, List
import multiprocessing
import psutil
import dask

import numpy as np
import xarray as xr

from mlair.data_handler.abstract_data_handler import AbstractDataHandler
from mlair.helpers import remove_items, to_list, TimeTrackingWrapper
from mlair.helpers.data_sources.data_loader import EmptyQueryResult


number = Union[float, int]
num_or_list = Union[number, List[number]]


class DefaultDataHandler(AbstractDataHandler):
    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler
    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation

    _requirements = data_handler.requirements()
    _store_attributes = data_handler.store_attributes()
    _skip_args = AbstractDataHandler._skip_args + ["id_class"]

    DEFAULT_ITER_DIM = "Stations"
    DEFAULT_TIME_DIM = "datetime"
    MAX_NUMBER_MULTIPROCESSING = 16

    def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0,
                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None,
                 store_processed_data=True, iter_dim=DEFAULT_ITER_DIM, time_dim=DEFAULT_TIME_DIM,
                 use_multiprocessing=True, max_number_multiprocessing=MAX_NUMBER_MULTIPROCESSING):
        super().__init__()
        self.id_class = id_class
        self.time_dim = time_dim
        self.iter_dim = iter_dim
        self.min_length = min_length
        self._X = None
        self._Y = None
        self._X_extreme = None
        self._Y_extreme = None
        self._data_intersection = None
        self._len = None
        self._len_upsampling = None
        self._use_multiprocessing = use_multiprocessing
        self._max_number_multiprocessing = max_number_multiprocessing
        _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
        self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle")
        self._collection = self._create_collection()
        self.harmonise_X()
        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.time_dim)
        self._store(fresh_store=True, store_processed_data=store_processed_data)

    @classmethod
    def build(cls, station: str, **kwargs):
        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
        sp = cls.data_handler(station, **sp_keys)
        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, **dp_args)

    def _create_collection(self):
        return [self.id_class]

    def _reset_data(self):
        self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
        gc.collect()

    def _cleanup(self):
        directory = os.path.dirname(self._save_file)
        if os.path.exists(directory) is False:
            os.makedirs(directory, exist_ok=True)
        if os.path.exists(self._save_file):
            shutil.rmtree(self._save_file, ignore_errors=True)

    def _store(self, fresh_store=False, store_processed_data=True):
        if store_processed_data is True:
            self._cleanup() if fresh_store is True else None
            data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
            data = self._force_dask_computation(data)
            with open(self._save_file, "wb") as f:
                dill.dump(data, f, protocol=4)
            logging.debug(f"save pickle data to {self._save_file}")
            self._reset_data()

    def get_store_attributes(self):
        attr_dict = {}
        for attr in self.store_attributes():
            try:
                val = self.__getattribute__(attr)
            except AttributeError:
                val = self.id_class.__getattribute__(attr)
            attr_dict[attr] = val
        return attr_dict

    @staticmethod
    def _force_dask_computation(data):
        try:
            data = dask.compute(data)[0]
        except:
            # if dask cannot compute the data for any reason, fall back to the unchanged objects
            pass
        return data

    def _load(self):
        try:
            with open(self._save_file, "rb") as f:
                data = dill.load(f)
            logging.debug(f"load pickle data from {self._save_file}")
            self._X, self._Y = data["X"], data["Y"]
            self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
        except FileNotFoundError:
            pass

    def get_data(self, upsampling=False, as_numpy=True):
        self._load()
        as_numpy_X, as_numpy_Y = as_numpy if isinstance(as_numpy, tuple) else (as_numpy, as_numpy)
        X = self.get_X(upsampling, as_numpy_X)
        Y = self.get_Y(upsampling, as_numpy_Y)
        self._reset_data()
        return X, Y
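
    # Rough sketch of the intended storage lifecycle (a summary of the methods above, using placeholder
    # arguments only):
    #
    #     dh = DefaultDataHandler.build("DEBW107", experiment_path="...", ...)
    #         # __init__ harmonises and (optionally) upsamples the data, then _store() pickles
    #         # X/Y to <experiment_path>/data/ and _reset_data() frees the in-memory arrays
    #     X, Y = dh.get_data(upsampling=False)
    #         # _load() restores the pickle, get_X()/get_Y() return the (numpy) data,
    #         # and _reset_data() frees the memory again
    #
    # "DEBW107" and the remaining keyword arguments are placeholders; build() forwards whatever the
    # underlying DataHandlerSingleStation requires.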

    def __repr__(self):
        return str(self._collection[0])

    def __len__(self, upsampling=False):
        if upsampling is False:
            return self._len
        else:
            return self._len_upsampling

    def get_X_original(self):
        X = []
        for data in self._collection:
            X.append(data.get_X())
        return X

    def get_Y_original(self):
        Y = self._collection[0].get_Y()
        return Y

    @staticmethod
    def _to_numpy(d):
        return list(map(lambda x: np.copy(x), d))

    def get_X(self, upsampling=False, as_numpy=True):
        no_data = (self._X is None)
        self._load() if no_data is True else None
        X = self._X if upsampling is False else self._X_extreme
        self._reset_data() if no_data is True else None
        return self._to_numpy(X) if as_numpy is True else X

    def get_Y(self, upsampling=False, as_numpy=True):
        no_data = (self._Y is None)
        self._load() if no_data is True else None
        Y = self._Y if upsampling is False else self._Y_extreme
        self._reset_data() if no_data is True else None
        return self._to_numpy([Y]) if as_numpy is True else Y

    @TimeTrackingWrapper
    def harmonise_X(self):
        X_original, Y_original = self.get_X_original(), self.get_Y_original()
        dim = self.time_dim
        intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
        if len(intersect) < max(self.min_length, 1):
            raise ValueError(f"There is no intersection of X.")
        else:
            X = list(map(lambda x: x.sel({dim: intersect}), X_original))
            Y = Y_original.sel({dim: intersect})
        self._data_intersection = intersect
        self._X, self._Y = X, Y
        self._len = len(self._data_intersection)

    def get_observation(self):
        dim = self.time_dim
        if self._data_intersection is not None:
            return self.id_class.observation.sel({dim: self._data_intersection}).copy().squeeze()
        else:
            return self.id_class.observation.copy().squeeze()

    def apply_transformation(self, data, base="target", dim=0, inverse=False):
        return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse)

    @TimeTrackingWrapper
    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                          timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):
        """
        Multiply extremes.

        This method extracts extreme values from the labels (self._Y) as defined by the argument extreme_values. One
        can also decide to extract extremes only from the right tail of the distribution. If extreme_values is a list
        of floats/ints, all values larger than each threshold (and smaller than its negative, unless
        extremes_on_right_tail_only is set; extraction is performed in standardised space) are extracted iteratively.
        If for example extreme_values = [1., 2.], a value of 1.5 is extracted once (for the 0th entry in the list),
        while a value of 2.5 is extracted twice (once for each entry). Timedelta is used to mark the extracted values
        by adding one minute to each timestamp. As TOAR data are hourly, these "artificial" data points can easily be
        identified later. Extreme inputs and labels are stored in self._X_extreme and self._Y_extreme, respectively.

        :param extreme_values: user definition of extreme
        :param extremes_on_right_tail_only: if False, also multiply values that are smaller than -extreme_values;
            if True, only extract values larger than extreme_values
        :param timedelta: used as arguments for np.timedelta64 in order to mark extreme values on the datetime axis
        """
        if extreme_values is None:
            logging.debug(f"No extreme values given, skip multiply extremes")
            self._X_extreme, self._Y_extreme = self._X, self._Y
            self._len_upsampling = self._len
            return

        # check type of inputs
        extreme_values = to_list(extreme_values)
        for i in extreme_values:
            if not isinstance(i, number.__args__):
                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
                                f"{i} is type {type(i)}")

        extremes_X, extremes_Y = None, None
        for extr_val in sorted(extreme_values):
            # check if some extreme values are already extracted
            if (extremes_X is None) or (extremes_Y is None):
                X, Y = self._X, self._Y
                extremes_X, extremes_Y = X, Y
            else:  # at least one extreme-value iteration is done already: extremes_Y is not None
                X, Y = self._X_extreme, self._Y_extreme

            # extract extremes based on occurrence in labels
            other_dims = remove_items(list(Y.dims), dim)
            if extremes_on_right_tail_only:
                extreme_idx = (extremes_Y > extr_val).any(dim=other_dims)
            else:
                extreme_idx = xr.concat([(extremes_Y < -extr_val).any(dim=other_dims[0]),
                                         (extremes_Y > extr_val).any(dim=other_dims[0])],
                                        dim=other_dims[0]).any(dim=other_dims[0])

            sel = extreme_idx[extreme_idx].coords[dim].values
            extremes_X = list(map(lambda x: x.sel(**{dim: sel}), extremes_X))
            self._add_timedelta(extremes_X, dim, timedelta)
            extremes_Y = extremes_Y.sel(**{dim: extreme_idx})
            self._add_timedelta([extremes_Y], dim, timedelta)

            self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
            self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
            self._len_upsampling = len(self._X_extreme[0].coords[dim])

    @staticmethod
    def _add_timedelta(data, dim, timedelta):
        for d in data:
            d.coords[dim] = d.coords[dim].values + np.timedelta64(*timedelta)
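
    # Worked example for multiply_extremes (illustration only, mirroring the numbers in its docstring):
    # with extreme_values=[1., 2.] and standardised labels, a label of 1.5 exceeds only the first
    # threshold and is duplicated once, while a label of 2.5 exceeds both thresholds and is duplicated
    # twice (its second copy is taken from the already extracted extremes). Each extracted copy is
    # shifted by a further np.timedelta64(1, "m"), so the upsampled timestamps fall a few minutes past
    # the hourly originals and remain distinguishable from real observations.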

    @classmethod
    def transformation(cls, set_stations, tmp_path=None, dh_transformation=None, **kwargs):
        """
        ### supported transformation methods

        Currently supported methods are:

        * standardise (default, if method is not given)
        * centre
        * min_max
        * log

        ### mean and std estimation

        Mean and std (depending on the method) are estimated. For each station, mean and std are calculated and
        afterwards aggregated using the mean value over all station-wise metrics. This method is not exactly accurate,
        especially regarding the std calculation, but it is much faster in return. Furthermore, it is a weighted mean,
        weighted by the time series length / number of data points: a longer time series has more influence on the
        transformation settings than a short one. The estimation of the std is less accurate, because the unweighted
        mean of all stds is not equal to the true std, but the mean of all station-wise stds is still a decent
        estimate. Finally, the real accuracy of mean and std is less important, because it is "just" a transformation
        / scaling. (A toy numeric sketch of this aggregation can be found at the bottom of this module.)

        ### mean and std given

        If mean and std are not None, the default data handler expects these parameters to match the data and applies
        these values to the data. Make sure that all dimensions and/or coordinates are in agreement.

        ### min and max given

        If min and max are not None, the default data handler expects these parameters to match the data and applies
        these values to the data. Make sure that all dimensions and/or coordinates are in agreement.
        """

        if dh_transformation is None:
            dh_transformation = cls.data_handler_transformation

        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in dh_transformation.requirements() if k in kwargs}
        if "transformation" not in sp_keys.keys():
            return
        transformation_dict = ({}, {})

        max_process = kwargs.get("max_number_multiprocessing", 16)
        set_stations = to_list(set_stations)
        n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
        if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
            logging.info("use parallel transformation approach")
            pool = multiprocessing.Pool(n_process)  # use only physical cpus
            logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
            sp_keys.update({"tmp_path": tmp_path, "return_strategy": "reference"})
            output = [
                pool.apply_async(f_proc, args=(dh_transformation, station), kwds=sp_keys)
                for station in set_stations]
            for p in output:
                _res_file, s = p.get()
                with open(_res_file, "rb") as f:
                    dh = dill.load(f)
                os.remove(_res_file)
                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)
            pool.close()
            pool.join()
        else:  # serial solution
            logging.info("use serial transformation approach")
            sp_keys.update({"return_strategy": "result"})
            for station in set_stations:
                dh, s = f_proc(dh_transformation, station, **sp_keys)
                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)

        # aggregate all information
        iter_dim = sp_keys.get("iter_dim", cls.DEFAULT_ITER_DIM)
        transformation_dict = cls.aggregate_transformation(transformation_dict, iter_dim)
        return transformation_dict

    @classmethod
    def aggregate_transformation(cls, transformation_dict, iter_dim):
        pop_list = []
        for i, transformation in enumerate(transformation_dict):
            for k in transformation.keys():
                try:
                    if transformation[k]["mean"] is not None:
                        transformation_dict[i][k]["mean"] = transformation[k]["mean"].mean(iter_dim)
                    if transformation[k]["std"] is not None:
                        transformation_dict[i][k]["std"] = transformation[k]["std"].mean(iter_dim)
                    if transformation[k]["min"] is not None:
                        transformation_dict[i][k]["min"] = transformation[k]["min"].min(iter_dim)
                    if transformation[k]["max"] is not None:
                        transformation_dict[i][k]["max"] = transformation[k]["max"].max(iter_dim)
                    if "feature_range" in transformation[k].keys():
                        transformation_dict[i][k]["feature_range"] = transformation[k]["feature_range"]
                except KeyError:
                    pop_list.append((i, k))
        for (i, k) in pop_list:
            transformation_dict[i].pop(k)
        return transformation_dict

    @classmethod
    def update_transformation_dict(cls, dh, transformation_dict):
        """Inner method that is used by both the serial and the parallel approach."""
        if dh is not None:
            for i, transformation in enumerate(dh._transformation):
                for var in transformation.keys():
                    if var not in transformation_dict[i].keys():
                        transformation_dict[i][var] = {}
                    opts = transformation[var]
                    if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
                        # data handlers with filters are allowed to change transformation method to standardise
                        assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
                    transformation_dict[i][var]["method"] = opts["method"]
                    for k in ["mean", "std", "min", "max"]:
                        old = transformation_dict[i][var].get(k, None)
                        new = opts.get(k)
                        transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
                    if "feature_range" in opts.keys():
                        transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None)
        return transformation_dict
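
    # For orientation: ``transformation_dict`` as handled by the two methods above is a pair of
    # dictionaries (one per data group of the underlying single-station handler, e.g. inputs and
    # targets). Each maps a variable name to its transformation options, roughly of the form below
    # (a sketch inferred from the keys accessed above, not an exhaustive schema; "o3" is just an
    # example variable name):
    #
    #     ({"o3": {"method": "standardise", "mean": <xr.DataArray>, "std": <xr.DataArray>,
    #              "min": None, "max": None}, ...},
    #      {...})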

    def get_coordinates(self):
        return self.id_class.get_coordinates()


def f_proc(data_handler, station, return_strategy="", tmp_path=None, **sp_keys):
    """
    Try to create a data handler for the given arguments. If the build fails, this station does not fulfil all
    requirements and f_proc therefore returns None as an indication. On a successful build, f_proc returns the built
    data handler and the station that was used. This function must be implemented globally to work together with
    multiprocessing.
    """
    assert return_strategy in ["result", "reference"]
    try:
        res = data_handler(station, **sp_keys)
    except (AttributeError, EmptyQueryResult, KeyError, ValueError, IndexError) as e:
        logging.info(f"remove station {station} because it raised an error: {e}")
        res = None
    if return_strategy == "result":
        return res, station
    else:
        _tmp_file = os.path.join(tmp_path, f"{station}_{'%032x' % random.getrandbits(128)}.pickle")
        with open(_tmp_file, "wb") as f:
            dill.dump(res, f, protocol=4)
        return _tmp_file, station
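

if __name__ == "__main__":
    # Minimal, self-contained sketch (toy numbers, no MLAir data needed) of the aggregation strategy
    # described in DefaultDataHandler.transformation: per-station means and stds are averaged instead
    # of computing pooled statistics over all stations at once. The station sizes and distributions
    # below are arbitrary assumptions, chosen only to show that the averaged std is close to, but not
    # equal to, the pooled std.
    rng = np.random.default_rng(0)
    toy_stations = [rng.normal(loc=mu, scale=sigma, size=500)
                    for mu, sigma in ((18.0, 4.0), (20.0, 5.0), (22.0, 6.0))]

    pooled = np.concatenate(toy_stations)
    station_means = np.array([s.mean() for s in toy_stations])
    station_stds = np.array([s.std() for s in toy_stations])

    print(f"pooled mean / std:              {pooled.mean():.3f} / {pooled.std():.3f}")
    print(f"station-wise aggregated mean / std: {station_means.mean():.3f} / {station_stds.mean():.3f}")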