Coverage for mlair/run_modules/post_processing.py: 6%

574 statements  

coverage.py v6.4.2, created at 2022-12-02 15:24 +0000

1"""Post-processing module.""" 

2 

3__author__ = "Lukas Leufen, Felix Kleinert" 

4__date__ = '2019-12-11' 

5 

6import inspect 

7import logging 

8import os 

9import sys 

10import traceback 

11import copy 

12from typing import Dict, Tuple, Union, List, Callable 

13 

14import numpy as np 

15import pandas as pd 

16import xarray as xr 

17import datetime as dt 

18 

19from mlair.configuration import path_config 

20from mlair.data_handler import Bootstraps, KerasIterator 

21from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope 

22from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables 

23from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel 

24from mlair.model_modules import AbstractModelClass 

25from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ 

26 PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \ 

27 PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric, PlotSeasonalMSEStack, \ 

28 PlotErrorsOnMap 

29from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ 

30 PlotPeriodogram, PlotDataHistogram 

31from mlair.run_modules.run_environment import RunEnvironment 

32 

33 

34class PostProcessing(RunEnvironment): 

35 """ 

36 Perform post-processing for performance evaluation. 

37 

38 Schedule of post-processing: 

39 #. train an ordinary least squared model (ols) for reference 

40 #. create forecasts for nn, ols, and persistence 

41 #. evaluate feature importance with bootstrapped predictions 

42 #. calculate skill scores 

43 #. create plots 

44 

45 Required objects [scope] from data store: 

46 * `model` [.] or locally saved model plus `model_name` [model] and `model` [model] 

47 * `generator` [train, val, test, train_val] 

48 * `forecast_path` [.] 

49 * `plot_path` [postprocessing] 

50 * `model_path` [.] 

51 * `target_var` [.] 

52 * `sampling` [.] 

53 * `output_shape` [model] 

54 * `evaluate_feature_importance` [postprocessing] and if enabled: 

55 

56 * `create_new_bootstraps` [postprocessing] 

57 * `bootstrap_path` [postprocessing] 

58 * `number_of_bootstraps` [postprocessing] 

59 

60 Optional objects 

61 * `batch_size` [model] 

62 

63 Creates 

64 * forecasts in `forecast_path` if enabled 

65 * bootstraps in `bootstrap_path` if enabled 

66 * plots in `plot_path` 

67 

68 """ 

69 

70 def __init__(self): 

71 """Initialise and run post-processing.""" 

72 super().__init__() 

73 self.model: AbstractModelClass = self._load_model() 

74 self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] 

75 self.ols_model = None 

76 self.persi_model = True 

77 self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) 

78 self.test_data = self.data_store.get("data_collection", "test") 

79 batch_path = self.data_store.get("batch_path", scope="test") 

80 self.test_data_distributed = KerasIterator(self.test_data, self.batch_size, model=self.model, name="test", 

81 batch_path=batch_path) 

82 self.train_data = self.data_store.get("data_collection", "train") 

83 self.val_data = self.data_store.get("data_collection", "val") 

84 self.train_val_data = self.data_store.get("data_collection", "train_val") 

85 self.forecast_path = self.data_store.get("forecast_path") 

86 self.plot_path: str = self.data_store.get("plot_path") 

87 self.target_var = self.data_store.get("target_var") 

88 self._sampling = self.data_store.get("sampling") 

89 self.window_lead_time = extract_value(self.data_store.get("output_shape", "model")) 

90 self.skill_scores = None 

91 self.errors = None 

92 self.feature_importance_skill_scores = None 

93 self.uncertainty_estimate = None 

94 self.uncertainty_estimate_seasons = {} 

95 self.block_mse_per_station = None 

96 self.block_mse = None 

97 self.competitor_path = self.data_store.get("competitor_path") 

98 self.competitors = to_list(self.data_store.get_default("competitors", default=[])) 

99 self.forecast_indicator = "nn" 

100 self.observation_indicator = "obs" 

101 self.ahead_dim = "ahead" 

102 self.boot_var_dim = "boot_var" 

103 self.uncertainty_estimate_boot_dim = "boots" 

104 self.model_type_dim = "type" 

105 self.index_dim = "index" 

106 self.model_display_name = self.data_store.get_default("model_display_name", default=self.model.model_name) 

107 self._run() 

108 

109 def _run(self): 

110 # ols model 

111 self.train_ols_model() 

112 

113 # persi model 

114 self.setup_persistence() 

115 

116 # create forecasts on test and train_val data 

117 self.make_prediction(self.test_data) 

118 self.make_prediction(self.train_val_data) 

119 

120 # calculate error metrics on test data 

121 self.calculate_test_score() 

122 

123 # calculate monthly block mse 

124 self.block_mse, self.block_mse_per_station = self.calculate_block_mse(evaluate_competitors=True, 

125 separate_ahead=True, block_length="1m") 

126 

127 # sample uncertainty 

128 if self.data_store.get("do_uncertainty_estimate", "postprocessing"): 

129 self.estimate_sample_uncertainty(separate_ahead=True) 

130 

131 # feature importance bootstraps 

132 if self.data_store.get("evaluate_feature_importance", "postprocessing"): 

133 with TimeTracking(name="evaluate_feature_importance", log_on_enter=True): 

134 create_new_bootstraps = self.data_store.get("create_new_bootstraps", "feature_importance") 

135 bootstrap_method = self.data_store.get("bootstrap_method", "feature_importance") 

136 bootstrap_type = self.data_store.get("bootstrap_type", "feature_importance") 

137 self.calculate_feature_importance(create_new_bootstraps, bootstrap_type=bootstrap_type, 

138 bootstrap_method=bootstrap_method) 

139 if self.feature_importance_skill_scores is not None: 

140 self.report_feature_importance_results(self.feature_importance_skill_scores) 

141 

142 # skill scores and error metrics 

143 with TimeTracking(name="calculate_error_metrics", log_on_enter=True): 

144 skill_score_competitive, _, skill_score_climatological, errors = self.calculate_error_metrics() 

145 self.skill_scores = (skill_score_competitive, skill_score_climatological) 

146 with TimeTracking(name="report_error_metrics", log_on_enter=True): 

147 self.report_error_metrics(errors) 

148 self.report_error_metrics({self.forecast_indicator: skill_score_climatological}) 

149 self.report_error_metrics({"skill_score": skill_score_competitive}) 

150 self.store_errors(errors) 

151 

152 # plotting 

153 self.plot() 

154 

155 @TimeTrackingWrapper 

156 def estimate_sample_uncertainty(self, separate_ahead=False): 

157 """ 

158 Estimate sample uncertainty using a bootstrap approach. Forecasts are split into individual blocks along time 

159 and randomly drawn with replacement. The resulting distribution of the error indicates the robustness of each 

160 analyzed model and quantifies which model might be superior compared to others. 

161 """ 

162 logging.info("start estimate_sample_uncertainty") 

163 n_boots = self.data_store.get_default("n_boots", default=1000, scope="uncertainty_estimate") 

164 block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") 

165 evaluate_competitors = self.data_store.get_default("evaluate_competitors", default=True, 

166 scope="uncertainty_estimate") 

167 if evaluate_competitors is True and separate_ahead is True and block_length == "1m": 

168 block_mse, block_mse_per_station = self.block_mse, self.block_mse_per_station 

169 else: 

170 block_mse, block_mse_per_station = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, 

171 separate_ahead=separate_ahead, 

172 block_length=block_length) 
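# create_n_bootstrap_realizations draws n_boots samples of the block-wise MSE with replacement and
# averages them, once for the full period (key "") and once per season (DJF, MAM, JJA, SON)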

173 estimate = statistics.create_n_bootstrap_realizations( 

174 block_mse, dim_name_time=self.index_dim, dim_name_model=self.model_type_dim, 

175 dim_name_boots=self.uncertainty_estimate_boot_dim, n_boots=n_boots, seasons=["DJF", "MAM", "JJA", "SON"]) 

176 self.uncertainty_estimate = estimate.pop("") 

177 self.uncertainty_estimate_seasons = estimate 

178 self.report_sample_uncertainty() 

179 

180 def report_sample_uncertainty(self, percentiles: list = None): 

181 """ 

182 Store the raw results of the uncertainty estimate, calculate aggregate statistics, and save them as raw data 

183 as well as markdown and LaTeX tables. 

184 """ 

185 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

186 path_config.check_path_and_create(report_path) 

187 

188 # store raw results as nc 

189 file_name = os.path.join(report_path, "uncertainty_estimate_raw_results.nc") 

190 self.uncertainty_estimate.to_netcdf(path=file_name) 

191 for season in self.uncertainty_estimate_seasons.keys(): 

192 file_name = os.path.join(report_path, f"uncertainty_estimate_raw_results_{season}.nc") 

193 self.uncertainty_estimate_seasons[season].to_netcdf(path=file_name) 

194 

195 # store block mse per station 

196 file_name = os.path.join(report_path, f"block_mse_raw_results.nc") 

197 self.block_mse_per_station.to_netcdf(path=file_name) 

198 

199 # store statistics 

200 if percentiles is None: 

201 percentiles = [.05, .1, .25, .5, .75, .9, .95] 

202 

203 for season in [None] + list(self.uncertainty_estimate_seasons.keys()): 

204 estimate = self.uncertainty_estimate if season is None else self.uncertainty_estimate_seasons[season] 

205 affix = "" if season is None else f"_{season}" 

206 for ahead_steps in ["single", "multi"]: 
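# "single" reports one summary table (averaged over the ahead dimension if present), while
# "multi" keeps separate statistics for each ahead step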

207 if ahead_steps == "single": 

208 try: 

209 df_descr = estimate.to_pandas().describe(percentiles=percentiles).astype("float32") 

210 except ValueError: 

211 df_descr = estimate.mean(self.ahead_dim).to_pandas().describe(percentiles=percentiles).astype("float32") 

212 else: 

213 if self.ahead_dim not in estimate.dims: 

214 continue 

215 df_descr = estimate.to_dataframe(self.model_type_dim).unstack().groupby(level=self.ahead_dim).describe( 

216 percentiles=percentiles).astype("float32") 

217 df_descr = df_descr.stack(-1) 

218 df_descr = df_descr.reorder_levels(df_descr.index.names[::-1]) 

219 df_sorter = ["count", "mean", "std", "min", *[f"{round(p * 100)}%" for p in percentiles], "max"] 

220 df_descr = df_descr.loc[df_sorter] 

221 column_format = tables.create_column_format_for_tex(df_descr) 

222 file_name = os.path.join(report_path, f"uncertainty_estimate_statistics_{ahead_steps}{affix}.%s") 

223 tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df_descr) 

224 tables.save_to_md(report_path, file_name % "md", df=df_descr) 

225 df_descr.to_csv(file_name % "csv", sep=";") 

226 

227 def calculate_block_mse(self, evaluate_competitors=True, separate_ahead=False, block_length="1m"): 

228 """ 

229 Transform data into blocks along the time axis. The block length can be any frequency like '1m' or '7d'. Data are 

230 only split along the time axis, which means that a single block can contain very diverse quantities regarding the 

231 number of stations or the actual amount of data. This is intended to analyze robustness not only over time but 

232 also with respect to the number of observations and the diversity of stations. 

233 """ 

234 all_stations = self.data_store.get("stations", "test") 

235 start = self.data_store.get("start", "test") 

236 end = self.data_store.get("end", "test") 

237 index_dim = self.index_dim 

238 coll_dim = "station" 

239 collector = [] 

240 for station in all_stations: 

241 # test data 

242 external_data = self._get_external_data(station, self.forecast_path) 

243 if external_data is None: 

244 logging.info(f"skip calculate_block_mse for {station} as no external_data are available") 

245 continue 

246 # competitors 

247 if evaluate_competitors is True: 

248 competitor = self.load_competitors(station) 

249 combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim) 

250 else: 

251 combined = external_data 

252 

253 if combined is None: 

254 continue 

255 else: 

256 combined = self.create_full_time_dim(combined, index_dim, self._sampling, start, end) 

257 # get squared errors 

258 errors = self.create_error_array(combined) 

259 # calc mse for each block (single station) 

260 mse = errors.resample(indexer={index_dim: block_length}).mean(skipna=True) 

261 collector.append(mse.assign_coords({coll_dim: station})) 

262 

263 # combine all mse blocks 

264 mse_blocks_per_station = xr.concat(collector, dim=coll_dim) 

265 # calc mse for each block (average over all stations) 

266 mse_blocks = mse_blocks_per_station.mean(dim=coll_dim, skipna=True) 

267 # average also on ahead steps 

268 if separate_ahead is False: 

269 mse_blocks = mse_blocks.mean(dim=self.ahead_dim, skipna=True) 

270 mse_blocks_per_station = mse_blocks_per_station.mean(dim=self.ahead_dim, skipna=True) 

271 return mse_blocks, mse_blocks_per_station 

272 

273 def create_error_array(self, data): 

274 """Calculate squared error of all given time series in relation to observation.""" 

275 errors = data.drop_sel({self.model_type_dim: self.observation_indicator}) 

276 errors1 = errors - data.sel({self.model_type_dim: self.observation_indicator}) 

277 errors2 = errors1 ** 2 

278 return errors2 

279 

280 @staticmethod 

281 def create_full_time_dim(data, dim, sampling, start, end): 

282 """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped.""" 

283 start_data = data.coords[dim].values[0] 

284 freq = {"daily": "1D", "hourly": "1H"}.get(sampling) 

285 _ind = pd.date_range(start, end, freq=freq) # two steps required to include all hours of end interval 

286 datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), 

287 closed="left", freq=freq)) 

288 t = data.sel({dim: start_data}, drop=True) 

289 res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords]) 

290 res = res.transpose(*data.dims) 

291 if data.shape == res.shape: 

292 res.loc[data.coords] = data 

293 else: 

294 _d = data.sel({dim: slice(start, end)}) 

295 res.loc[_d.coords] = _d 

296 return res 

297 

298 def load_competitors(self, station_name: str) -> xr.DataArray: 

299 """ 

300 Load all requested and available competitors for a given station. Forecasts must be available in the competitor 

301 path like `<competitor_path>/<target_var>/forecasts_<station_name>_test.nc`. The naming convention is the same for 

302 all MLAir forecasts, so forecasts from a different experiment can easily be copied into the competitor path 

303 without any change. 

304 

305 :param station_name: station indicator to load competitors for 

306 

307 :return: a single xarray with all competing forecasts 

308 """ 

309 competing_predictions = [] 

310 for competitor_name in self.competitors: 

311 try: 

312 prediction = self._create_competitor_forecast(station_name, competitor_name) 

313 competing_predictions.append(prediction) 

314 except (FileNotFoundError, KeyError): 

315 logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.") 

316 continue 

317 return xr.concat(competing_predictions, self.model_type_dim) if len(competing_predictions) > 0 else None 

318 

319 def calculate_feature_importance(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput", 

320 bootstrap_method="shuffle") -> None: 

321 """ 

322 Calculate skill scores of bootstrapped data. 

323 

324 Create bootstrapped data if create_new_bootstraps is true or if a failure occurred during skill score calculation 

325 (this happens by default if no bootstrapped data is available locally). Sets the class attribute 

326 feature_importance_skill_scores. This method is implemented in a recursive fashion, but is only allowed to call 

327 itself once. 

328 

329 :param create_new_bootstraps: calculate all bootstrap predictions and overwrite already available predictions 

330 :param _iter: internal counter to reduce unnecessary recursive calls (maximum number is 2, otherwise something 

331 went wrong). 

332 """ 

333 if _iter == 0: 

334 self.feature_importance_skill_scores = {} 

335 for boot_type in to_list(bootstrap_type): 

336 if _iter == 0: 

337 self.feature_importance_skill_scores[boot_type] = {} 

338 for boot_method in to_list(bootstrap_method): 

339 try: 

340 if create_new_bootstraps: 

341 self.create_feature_importance_bootstrap_forecast(bootstrap_type=boot_type, 

342 bootstrap_method=boot_method) 

343 boot_skill_score = self.calculate_feature_importance_skill_scores(bootstrap_type=boot_type, 

344 bootstrap_method=boot_method) 

345 self.feature_importance_skill_scores[boot_type][boot_method] = boot_skill_score 

346 except (FileNotFoundError, ValueError, OSError): 

347 if _iter != 0: 

348 raise RuntimeError(f"calculate_feature_importance ({boot_type}, {boot_method}) was called for " 

349 f"the 2nd time. This means, that something internally goes wrong. Please " 

350 f"check for possible errors.") 

351 logging.info(f"Could not load all files for feature importance ({boot_type}, {boot_method}), " 

352 f"restart calculate_feature_importance with create_new_bootstraps=True.") 

353 self.calculate_feature_importance(True, _iter=1, bootstrap_type=boot_type, 

354 bootstrap_method=boot_method) 

355 

356 def create_feature_importance_bootstrap_forecast(self, bootstrap_type, bootstrap_method) -> None: 

357 """ 

358 Create bootstrapped predictions for all stations and variables. 

359 

360 These forecasts are saved in bootstrap_path with the names `bootstraps_{var}_{station}.nc` and 

361 `bootstraps_labels_{station}.nc`. 

362 """ 

363 
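# _reshape selects the pos-th bootstrap realization from the trailing axis and is applied
# recursively if the input is a list (branched model input)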

364 def _reshape(d, pos): 

365 if isinstance(d, list): 

366 return list(map(lambda x: _reshape(x, pos), d)) 

367 else: 

368 return d[..., pos] 

369 

370 # forecast 

371 with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})", 

372 log_on_enter=True): 

373 # extract all requirements from data store 

374 number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") 

375 dims = [self.uncertainty_estimate_boot_dim, self.index_dim, self.ahead_dim, self.model_type_dim] 

376 for station in self.test_data: 

377 X, Y = None, None 

378 bootstraps = Bootstraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type, 

379 bootstrap_method=bootstrap_method) 

380 number_of_bootstraps = bootstraps.number_of_bootstraps 

381 for boot in bootstraps: 

382 X, Y, (index, dimension) = boot 

383 # make bootstrap predictions 

384 bootstrap_predictions = [self.model.predict(_reshape(X, pos)) for pos in range(number_of_bootstraps)] 

385 if isinstance(bootstrap_predictions[0], list): # if model is branched model 

386 bootstrap_predictions = list(map(lambda x: x[-1], bootstrap_predictions)) 

387 # save bootstrap predictions separately for each station and variable combination 

388 bootstrap_predictions = list(map(lambda x: np.expand_dims(x, axis=-1), bootstrap_predictions)) 

389 shape = bootstrap_predictions[0].shape 

390 coords = (range(number_of_bootstraps), range(shape[0]), range(1, shape[1] + 1)) 

391 var = f"{index}_{dimension}" if index is not None else str(dimension) 

392 tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims) 

393 file_name = os.path.join(self.forecast_path, 

394 f"bootstraps_{station}_{var}_{bootstrap_type}_{bootstrap_method}.nc") 

395 tmp.to_netcdf(file_name) 

396 else: 

397 # store also true labels for each station 

398 labels = np.expand_dims(Y[..., 0], axis=-1) 

399 file_name = os.path.join(self.forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc") 

400 labels = xr.DataArray(labels, coords=(*coords[1:], [self.observation_indicator]), dims=dims[1:]) 

401 labels.to_netcdf(file_name) 

402 

403 def calculate_feature_importance_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]: 

404 """ 

405 Calculate skill score of bootstrapped variables. 

406 

407 Use the already created bootstrap predictions and the original (non-bootstrapped) predictions to calculate 

408 skill scores for the bootstraps. The result is saved as an xarray DataArray in a dictionary structure separated 

409 by station (keys of the dictionary). 

410 

411 :return: The result dictionary with station-wise skill scores 

412 """ 

413 with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): 

414 # extract all requirements from data store 

415 number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") 

416 forecast_file = f"forecasts_norm_%s_test.nc" 

417 reference_name = "orig" 

418 branch_names = self.data_store.get_default("branch_names", None) 

419 

420 bootstraps = Bootstraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type, 

421 bootstrap_method=bootstrap_method) 

422 number_of_bootstraps = bootstraps.number_of_bootstraps 

423 bootstrap_iter = bootstraps.bootstraps() 

424 branch_length = self.get_distinct_branches_from_bootstrap_iter(bootstrap_iter) 

425 skill_scores = statistics.SkillScores(None, ahead_dim=self.ahead_dim, type_dim=self.model_type_dim, 

426 index_dim=self.index_dim, observation_name=self.observation_indicator) 

427 score = {} 

428 for station in self.test_data: 

429 # get station labels 

430 file_name = os.path.join(self.forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc") 

431 with xr.open_dataarray(file_name) as da: 

432 labels = da.load() 

433 

434 # get original forecasts 

435 orig = self.get_orig_prediction(self.forecast_path, forecast_file % str(station), 

436 reference_name=reference_name) 

437 orig.coords[self.index_dim] = labels.coords[self.index_dim] 

438 

439 # calculate skill scores for each variable 

440 skill = [] 

441 for boot_set in bootstrap_iter: 

442 boot_var = f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set) 

443 file_name = os.path.join(self.forecast_path, 

444 f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc") 

445 with xr.open_dataarray(file_name) as da: 

446 boot_data = da.load() 
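# merge the bootstrapped prediction with the observed labels and the original (non-bootstrapped)
# forecast into a single array before computing skill scores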

447 boot_data = boot_data.combine_first(labels).combine_first(orig) 

448 boot_scores = [] 

449 for ahead in range(1, self.window_lead_time + 1): 

450 data = boot_data.sel({self.ahead_dim: ahead}) 

451 boot_scores.append( 

452 skill_scores.general_skill_score(data, forecast_name=boot_var, 

453 reference_name=reference_name, dim=self.index_dim)) 

454 boot_var_renamed = self.rename_boot_var_with_branch(boot_var, bootstrap_type, branch_names, expected_len=branch_length) 

455 tmp = xr.DataArray(np.expand_dims(np.array(boot_scores), axis=-1), 

456 coords={self.ahead_dim: range(1, self.window_lead_time + 1), 

457 self.uncertainty_estimate_boot_dim: range(number_of_bootstraps), 

458 self.boot_var_dim: [boot_var_renamed]}, 

459 dims=[self.ahead_dim, self.uncertainty_estimate_boot_dim, self.boot_var_dim]) 

460 skill.append(tmp) 

461 

462 # collect all results in single dictionary 

463 score[str(station)] = xr.concat(skill, dim=self.boot_var_dim) 

464 return score 

465 

466 @staticmethod 

467 def get_distinct_branches_from_bootstrap_iter(bootstrap_iter): 

468 if isinstance(bootstrap_iter[0], tuple): 

469 return len(set(map(lambda x: x[0], bootstrap_iter))) 

470 else: 

471 return len(bootstrap_iter) 

472 
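# translate integer branch indicators inside boot_var (e.g. "0" or "0_<var>") into the configured
# branch names; falls back to the unchanged boot_var if no mapping is possible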

473 def rename_boot_var_with_branch(self, boot_var, bootstrap_type, branch_names=None, expected_len=0): 

474 if branch_names is None: 

475 return boot_var 

476 if bootstrap_type == "branch": 

477 try: 

478 assert len(branch_names) > int(boot_var) 

479 assert len(branch_names) == expected_len 

480 return branch_names[int(boot_var)] 

481 except (AssertionError, TypeError): 

482 return boot_var 

483 elif bootstrap_type == "singleinput": 

484 if "_" in boot_var: 

485 branch, other = boot_var.split("_", 1) 

486 branch = self.rename_boot_var_with_branch(branch, "branch", branch_names=branch_names, expected_len=expected_len) 

487 boot_var = "_".join([branch, other]) 

488 return boot_var 

489 return boot_var 

490 

491 def get_orig_prediction(self, path, file_name, prediction_name=None, reference_name=None): 

492 if prediction_name is None: 

493 prediction_name = self.forecast_indicator 

494 file = os.path.join(path, file_name) 

495 with xr.open_dataarray(file) as da: 

496 prediction = da.load().sel({self.model_type_dim: [prediction_name]}) 

497 if reference_name is not None: 

498 prediction.coords[self.model_type_dim] = [reference_name] 

499 return prediction.dropna(dim=self.index_dim) 

500 

501 @staticmethod 

502 def repeat_data(data, number_of_repetition): 

503 if isinstance(data, xr.DataArray): 

504 data = data.data 

505 return np.repeat(np.expand_dims(data, axis=-1), number_of_repetition, axis=-1) 

506 

507 def _get_model_name(self): 

508 """Return model name without path information.""" 

509 return self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] 

510 

511 def _load_model(self) -> AbstractModelClass: 

512 """ 

513 Load NN model either from data store or from local path. 

514 

515 :return: the model 

516 """ 

517 try: # is only available if a model was trained in training stage 

518 model = self.data_store.get("model") 

519 except (NameNotFoundInDataStore, NameNotFoundInScope): 

520 logging.info("No model was saved in data store. Try to load model from experiment path.") 

521 model_name = self.data_store.get("model_name", "model") 

522 model: AbstractModelClass = self.data_store.get("model", "model") 

523 model.load_model(model_name) 

524 return model 

525 

526 # noinspection PyBroadException 

527 def plot(self): 

528 """ 

529 Create all plots. 

530 

531 Plots are defined in the experiment setup via `plot_list`. By default, all of the following plots are enabled: 

532 

533 * :py:class:`PlotBootstrapSkillScore <src.plotting.postprocessing_plotting.PlotBootstrapSkillScore>` 

534 * :py:class:`PlotConditionalQuantiles <src.plotting.postprocessing_plotting.PlotConditionalQuantiles>` 

535 * :py:class:`PlotStationMap <src.plotting.postprocessing_plotting.PlotStationMap>` 

536 * :py:class:`PlotMonthlySummary <src.plotting.postprocessing_plotting.PlotMonthlySummary>` 

537 * :py:class:`PlotClimatologicalSkillScore <src.plotting.postprocessing_plotting.PlotClimatologicalSkillScore>` 

538 * :py:class:`PlotCompetitiveSkillScore <src.plotting.postprocessing_plotting.PlotCompetitiveSkillScore>` 

539 * :py:class:`PlotTimeSeries <src.plotting.postprocessing_plotting.PlotTimeSeries>` 

540 * :py:class:`PlotAvailability <src.plotting.postprocessing_plotting.PlotAvailability>` 

541 

542 .. note:: Bootstrap plots are only created if bootstraps are evaluated. 

543 

544 """ 

545 logging.info("Run plotting routines...") 

546 use_multiprocessing = self.data_store.get("use_multiprocessing") 

547 

548 plot_list = self.data_store.get("plot_list", "postprocessing") 

549 time_dim = self.data_store.get("time_dim") 

550 window_dim = self.data_store.get("window_dim") 

551 target_dim = self.data_store.get("target_dim") 

552 iter_dim = self.data_store.get("iter_dim") 

553 

554 try: 

555 if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ( 

556 "PlotSeparationOfScales" in plot_list): 

557 filter_dim = self.data_store.get_default("filter_dim", None) 

558 PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path, time_dim=time_dim, 

559 window_dim=window_dim, target_dim=target_dim, **{"filter_dim": filter_dim}) 

560 except Exception as e: 

561 logging.error(f"Could not create plot PlotSeparationOfScales due to the following error:" 

562 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

563 

564 try: 

565 if (self.feature_importance_skill_scores is not None) and ("PlotFeatureImportanceSkillScore" in plot_list): 

566 branch_names = self.data_store.get_default("branch_names", None) 

567 for boot_type, boot_data in self.feature_importance_skill_scores.items(): 

568 for boot_method, boot_skill_score in boot_data.items(): 

569 try: 

570 PlotFeatureImportanceSkillScore( 

571 boot_skill_score, plot_folder=self.plot_path, model_name=self.model_display_name, 

572 sampling=self._sampling, ahead_dim=self.ahead_dim, 

573 separate_vars=to_list(self.target_var), bootstrap_type=boot_type, 

574 bootstrap_method=boot_method, branch_names=branch_names) 

575 except Exception as e: 

576 logging.error(f"Could not create plot PlotFeatureImportanceSkillScore ({boot_type}, " 

577 f"{boot_method}) due to the following error:\n{sys.exc_info()[0]}\n" 

578 f"{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

579 except Exception as e: 

580 logging.error(f"Could not create plot PlotFeatureImportanceSkillScore due to the following error: {e}") 

581 

582 try: 

583 if "PlotConditionalQuantiles" in plot_list: 

584 PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=self.forecast_path, 

585 plot_folder=self.plot_path, forecast_indicator=self.forecast_indicator, 

586 obs_indicator=self.observation_indicator, competitors=self.competitors, 

587 model_type_dim=self.model_type_dim, index_dim=self.index_dim, 

588 ahead_dim=self.ahead_dim, competitor_path=self.competitor_path, 

589 model_name=self.model_display_name) 

590 except Exception as e: 

591 logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error:" 

592 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

593 

594 try: 

595 if "PlotMonthlySummary" in plot_list: 

596 PlotMonthlySummary(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc", self.target_var, 

597 plot_folder=self.plot_path) 

598 except Exception as e: 

599 logging.error(f"Could not create plot PlotMonthlySummary due to the following error:" 

600 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

601 

602 try: 

603 if "PlotClimatologicalSkillScore" in plot_list: 

604 PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, 

605 model_name=self.model_display_name) 

606 PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, 

607 extra_name_tag="all_terms_", model_name=self.model_display_name) 

608 except Exception as e: 

609 logging.error(f"Could not create plot PlotClimatologicalSkillScore due to the following error: {e}" 

610 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

611 

612 try: 

613 if "PlotCompetitiveSkillScore" in plot_list: 

614 PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, 

615 model_setup=self.model_display_name) 

616 except Exception as e: 

617 logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}" 

618 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

619 

620 try: 

621 if "PlotTimeSeries" in plot_list: 

622 PlotTimeSeries(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc", 

623 plot_folder=self.plot_path, sampling=self._sampling, ahead_dim=self.ahead_dim) 

624 except Exception as e: 

625 logging.error(f"Could not create plot PlotTimeSeries due to the following error:\n{sys.exc_info()[0]}\n" 

626 f"{sys.exc_info()[1]}\n{sys.exc_info()[2]}\n{traceback.format_exc()}") 

627 

628 try: 

629 if "PlotSampleUncertaintyFromBootstrap" in plot_list and self.uncertainty_estimate is not None: 

630 block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") 

631 for season in [None] + list(self.uncertainty_estimate_seasons.keys()): 

632 estimate = self.uncertainty_estimate if season is None else self.uncertainty_estimate_seasons[season] 

633 PlotSampleUncertaintyFromBootstrap( 

634 data=estimate, plot_folder=self.plot_path, model_type_dim=self.model_type_dim, 

635 dim_name_boots=self.uncertainty_estimate_boot_dim, error_measure="mean squared error", 

636 error_unit=r"ppb$^2$", block_length=block_length, model_name=self.model_display_name, 

637 model_indicator=self.forecast_indicator, sampling=self._sampling, season_annotation=season) 

638 except Exception as e: 

639 logging.error(f"Could not create plot PlotSampleUncertaintyFromBootstrap due to the following error: {e}" 

640 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

641 

642 try: 

643 if "PlotErrorMetrics" in plot_list and self.errors is not None: 

644 error_metric_units = statistics.get_error_metrics_units("ppb") 

645 error_metrics_name = statistics.get_error_metrics_long_name() 

646 for error_metric in self.errors.keys(): 

647 try: 

648 PlotSampleUncertaintyFromBootstrap( 

649 data=self.errors[error_metric], plot_folder=self.plot_path, model_type_dim=self.model_type_dim, 

650 dim_name_boots="station", error_measure=error_metrics_name[error_metric], 

651 error_unit=error_metric_units[error_metric], model_name=self.model_display_name, 

652 model_indicator=self.model_display_name, sampling=self._sampling, apply_root=False, 

653 plot_name=f"error_plot_{error_metric}") 

654 except Exception as e: 

655 logging.error(f"Could not create plot PlotErrorMetrics for {error_metric} due to the following " 

656 f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

657 except Exception as e: 

658 logging.error(f"Could not create plot PlotErrorMetrics due to the following error: {e}" 

659 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

660 

661 try: 

662 if "PlotStationMap" in plot_list: 

663 gens = [(self.train_data, {"marker": 5, "ms": 9}), 

664 (self.val_data, {"marker": 6, "ms": 9}), 

665 (self.test_data, {"marker": 4, "ms": 9})] 

666 PlotStationMap(generators=gens, plot_folder=self.plot_path) 

667 gens = [(self.train_val_data, {"marker": 8, "ms": 9}), 

668 (self.test_data, {"marker": 9, "ms": 9})] 

669 PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") 

670 except Exception as e: 

671 if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get("hostname")[:6] in self.data_store.get("hpc_hosts"): 

672 logging.info(f"PlotStationMap might have failed as current workflow is running on hpc node {self.data_store.get('hostname')}. To download geographic elements, please run PlotStationMap once on login node.") 

673 logging.error(f"Could not create plot PlotStationMap due to the following error: {e}" 

674 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

675 

676 try: 

677 if "PlotAvailability" in plot_list: 

678 avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} 

679 PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dim, 

680 window_dimension=window_dim) 

681 except Exception as e: 

682 logging.error(f"Could not create plot PlotAvailability due to the following error: {e}" 

683 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

684 

685 try: 

686 if "PlotAvailabilityHistogram" in plot_list: 

687 avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} 

688 PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, station_dim=iter_dim, 

689 history_dim=window_dim) 

690 except Exception as e: 

691 logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}" 

692 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

693 

694 try: 

695 if "PlotDataHistogram" in plot_list: 

696 upsampling = self.data_store.get_default("upsampling", scope="train", default=False) 

697 gens = {"train": self.train_data, "val": self.val_data, "test": self.test_data} 

698 PlotDataHistogram(gens, plot_folder=self.plot_path, time_dim=time_dim, variables_dim=target_dim, 

699 upsampling=upsampling) 

700 except Exception as e: 

701 logging.error(f"Could not create plot PlotDataHistogram due to the following error: {e}" 

702 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

703 

704 try: 

705 if "PlotPeriodogram" in plot_list: 

706 PlotPeriodogram(self.train_data, plot_folder=self.plot_path, time_dim=time_dim, 

707 variables_dim=target_dim, sampling=self._sampling, 

708 use_multiprocessing=use_multiprocessing) 

709 except Exception as e: 

710 logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}" 

711 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

712 

713 try: 

714 if "PlotTimeEvolutionMetric" in plot_list: 

715 PlotTimeEvolutionMetric(self.block_mse_per_station, plot_folder=self.plot_path, 

716 model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim, 

717 error_measure="Mean Squared Error", error_unit=r"ppb$^2$", 

718 model_indicator=self.forecast_indicator, model_name=self.model_display_name, 

719 time_dim=self.index_dim) 

720 except Exception as e: 

721 logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}" 

722 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

723 

724 try: 

725 if "PlotSeasonalMSEStack" in plot_list: 

726 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

727 PlotSeasonalMSEStack(data=self.block_mse_per_station, data_path=report_path, plot_folder=self.plot_path, 

728 boot_dim=self.uncertainty_estimate_boot_dim, ahead_dim=self.ahead_dim, 

729 sampling=self._sampling, error_measure="Mean Squared Error", error_unit=r"ppb$^2$", 

730 model_indicator=self.forecast_indicator, model_name=self.model_display_name, 

731 model_type_dim=self.model_type_dim) 

732 except Exception as e: 

733 logging.error(f"Could not create plot PlotSeasonalMSEStack due to the following error: {e}" 

734 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

735 

736 try: 

737 if "PlotErrorsOnMap" in plot_list and self.errors is not None: 

738 for error_metric in self.errors.keys(): 

739 try: 

740 PlotErrorsOnMap(self.test_data, self.errors[error_metric], error_metric, 

741 plot_folder=self.plot_path, sampling=self._sampling) 

742 except Exception as e: 

743 logging.error(f"Could not create plot PlotErrorsOnMap for {error_metric} due to the following " 

744 f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

745 except Exception as e: 

746 logging.error(f"Could not create plot PlotErrorsOnMap due to the following error: {e}" 

747 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

748 

749 

750 

751 

752 

753 @TimeTrackingWrapper 

754 def calculate_test_score(self): 

755 """Evaluate test score of model and save locally.""" 

756 logging.info(f"start to calculate test scores") 

757 

758 # test scores on transformed data 

759 test_score = self.model.evaluate(self.test_data_distributed, 

760 use_multiprocessing=True, verbose=0) 

761 path = self.data_store.get("model_path") 

762 with open(os.path.join(path, "test_scores.txt"), "a") as f: 

763 for index, item in enumerate(to_list(test_score)): 

764 logging.info(f"{self.model.metrics_names[index]} (test), {item}") 

765 f.write(f"{self.model.metrics_names[index]}, {item}\n") 

766 

767 @TimeTrackingWrapper 

768 def train_ols_model(self): 

769 """Train ordinary least squared model on train data.""" 

770 if "ols" in map(lambda x: x.lower(), self.competitors): 

771 logging.info(f"start train_ols_model on train data") 

772 self.ols_model = OrdinaryLeastSquaredModel(self.train_data) 

773 self.competitors = [e for e in self.competitors if e.lower() != "ols"] 

774 else: 

775 logging.info(f"Skip train ols model as it is not present in competitors.") 

776 

777 def setup_persistence(self): 

778 """Check if persistence is requested from competitors and store this information.""" 

779 self.persi_model = any(x in map(str.lower, self.competitors) for x in ["persi", "persistence"]) 

780 if self.persi_model is False: 

781 logging.info(f"Persistence is not calculated as it is not present in competitors.") 

782 

783 @TimeTrackingWrapper 

784 def make_prediction(self, subset): 

785 """ 

786 Create predictions for NN, OLS, and persistence and add true observation as reference. 

787 

788 Predictions are filled in an array with full index range. Therefore, predictions can have missing values. All 

789 predictions for a single station are stored locally under `<forecast/forecast_norm>_<station>_test.nc` and can 

790 be found inside `forecast_path`. 

791 """ 

792 subset_type = subset.name 

793 logging.info(f"start make_prediction for {subset_type}") 

794 time_dimension = self.data_store.get("time_dim") 

795 window_dim = self.data_store.get("window_dim") 

796 

797 for i, data in enumerate(subset): 

798 input_data = data.get_X() 

799 target_data = data.get_Y(as_numpy=False) 

800 observation_data = data.get_observation() 

801 

802 # get scaling parameters 

803 transformation_func = data.apply_transformation 

804 

805 nn_output = self.model.predict(input_data) 

806 

807 for normalised in [True, False]: 
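# each forecast is created twice: once in normalised space (prefix "forecasts_norm") and once
# inverse-transformed back to the original space (prefix "forecasts")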

808 # create empty arrays 

809 nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays( 

810 target_data, count=4) 

811 

812 # nn forecast 

813 nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised) 

814 

815 # persistence 

816 if self.persi_model is True: 

817 persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, 

818 transformation_func, normalised) 

819 else: 

820 persistence_prediction = None 

821 

822 # ols 

823 if self.ols_model is not None: 

824 ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_func, 

825 normalised) 

826 else: 

827 ols_prediction = None 

828 

829 # observation 

830 observation = self._create_observation(target_data, observation, transformation_func, normalised) 

831 

832 # merge all predictions 

833 full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency()) 

834 prediction_dict = {self.forecast_indicator: nn_prediction, 

835 "persi": persistence_prediction, 

836 self.observation_indicator: observation, 

837 "ols": ols_prediction} 

838 all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]), 

839 time_dimension, ahead_dim=self.ahead_dim, 

840 index_dim=self.index_dim, type_dim=self.model_type_dim, 

841 **prediction_dict) 

842 

843 # save all forecasts locally 

844 prefix = "forecasts_norm" if normalised is True else "forecasts" 

845 file = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_{subset_type}.nc") 

846 all_predictions.to_netcdf(file) 

847 

848 def _get_frequency(self) -> str: 

849 """Get frequency abbreviation.""" 

850 getter = {"daily": "1D", "hourly": "1H"} 

851 return getter.get(self._sampling, None) 

852 

853 def _create_competitor_forecast(self, station_name: str, competitor_name: str) -> xr.DataArray: 

854 """ 

855 Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station 

856 indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will 

857 raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either 

858 there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file. 

859 Forecast is trimmed on interval start and end of test subset. 

860 

861 :param station_name: name of the station to load data for 

862 :param competitor_name: name of the model 

863 :return: the forecast of the given competitor 

864 """ 

865 path = os.path.join(self.competitor_path, competitor_name) 

866 file = os.path.join(path, f"forecasts_{station_name}_test.nc") 

867 with xr.open_dataarray(file) as da: 

868 data = da.load() 

869 if self.forecast_indicator in data.coords[self.model_type_dim]: 

870 forecast = data.sel({self.model_type_dim: [self.forecast_indicator]}) 

871 forecast.coords[self.model_type_dim] = [competitor_name] 

872 else: 

873 forecast = data.sel({self.model_type_dim: [competitor_name]}) 

874 # limit forecast to time range of test subset 

875 start, end = self.data_store.get("start", "test"), self.data_store.get("end", "test") 

876 return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end) 

877 

878 def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray: 

879 """ 

880 Create observation as ground truth from given data. 

881 

882 Inverse transformation is applied to the ground truth to get the output in the original space. 

883 

884 :param data: observation 

885 :param transformation_func: a callable function to apply inverse transformation 

886 :param normalised: transform ground truth in original space if false, or use normalised predictions if true 

887 

888 :return: filled data array with observation 

889 """ 

890 if not normalised: 

891 data = transformation_func(data, "target", inverse=True) 

892 return data 

893 

894 def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, 

895 transformation_func: Callable, normalised: bool) -> xr.DataArray: 

896 """ 

897 Create ordinary least square model forecast with given input data. 

898 

899 Inverse transformation is applied to the forecast to get the output in the original space. 

900 

901 :param input_data: transposed history from DataPrep 

902 :param ols_prediction: empty array in right shape to fill with data 

903 :param transformation_func: a callable function to apply inverse transformation 

904 :param normalised: transform prediction in original space if false, or use normalised predictions if true 

905 

906 :return: filled data array with ols predictions 

907 """ 

908 tmp_ols = self.ols_model.predict(input_data) 

909 target_shape = ols_prediction.values.shape 

910 if target_shape != tmp_ols.shape: 

911 if len(target_shape) == 2: 

912 new_values = np.swapaxes(tmp_ols, 1, 0) 

913 else: 

914 new_values = np.swapaxes(tmp_ols, 2, 0) 

915 else: 

916 new_values = tmp_ols 

917 ols_prediction.values = new_values 

918 if not normalised: 

919 ols_prediction = transformation_func(ols_prediction, "target", inverse=True) 

920 return ols_prediction 

921 

922 def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, transformation_func: Callable, 

923 normalised: bool) -> xr.DataArray: 

924 """ 

925 Create persistence forecast with given data. 

926 

927 Persistence is derived from the value at t=0 and applied to all following time steps (t+1, ..., t+window). 

928 Inverse transformation is applied to the forecast to get the output in the original space. 

929 

930 :param data: observation 

931 :param persistence_prediction: empty array in right shape to fill with data 

932 :param transformation_func: a callable function to apply inverse transformation 

933 :param normalised: transform prediction in original space if false, or use normalised predictions if true 

934 

935 :return: filled data array with persistence predictions 

936 """ 

937 tmp_persi = data.copy() 

938 persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T 

939 if not normalised: 

940 persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True) 

941 return persistence_prediction 

942 

943 def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable, 

944 normalised: bool) -> xr.DataArray: 

945 """ 

946 Create NN forecast for given input data. 

947 

948 Inverse transformation is applied to the forecast to get the output in the original space. Furthermore, only the 

949 output of the main branch is returned (not all minor branches, if the network has multiple output branches). The 

950 main branch is defined to be the last entry of all outputs. 

951 

952 :param nn_output: Full NN model output 

953 :param nn_prediction: empty array in right shape to fill with data 

954 :param transformation_func: a callable function to apply inverse transformation 

955 :param normalised: transform prediction in original space if false, or use normalised predictions if true 

956 

957 :return: filled data array with nn predictions 

958 """ 

959 

960 if isinstance(nn_output, list): 

961 nn_prediction.values = nn_output[-1] 

962 elif nn_output.ndim == 3: 

963 nn_prediction.values = nn_output[-1, ...] 

964 elif nn_output.ndim == 2: 

965 nn_prediction.values = nn_output 

966 else: 

967 raise NotImplementedError(f"Number of dimensions of model output must be 2 or 3, but is {nn_output.ndim}.") 

968 if not normalised: 

969 nn_prediction = transformation_func(nn_prediction, base="target", inverse=True) 

970 return nn_prediction 

971 

972 @staticmethod 

973 def _create_empty_prediction_arrays(target_data, count=1): 

974 """ 

975 Create array to collect all predictions. Expand target data by a station dimension. """ 

976 return [target_data.copy() for _ in range(count)] 

977 

978 @staticmethod 

979 def create_fullindex(df: Union[xr.DataArray, pd.DataFrame, pd.DatetimeIndex], freq: str) -> pd.DataFrame: 

980 """ 

981 Create full index from first and last date inside df and resample with given frequency. 

982 

983 :param df: use time range of this data set 

984 :param freq: frequency of full index 

985 

986 :return: empty data frame with full index. 

987 """ 

988 if isinstance(df, pd.DataFrame): 

989 earliest = df.index[0] 

990 latest = df.index[-1] 

991 elif isinstance(df, xr.DataArray): 

992 earliest = df.index[0].values 

993 latest = df.index[-1].values 

994 elif isinstance(df, pd.DatetimeIndex): 

995 earliest = df[0] 

996 latest = df[-1] 

997 else: 

998 raise AttributeError(f"unknown array type. Only pandas dataframes, xarray dataarrays and pandas datetimes " 

999 f"are supported. Given type is {type(df)}.") 

1000 index = pd.DataFrame(index=pd.date_range(earliest, latest, freq=freq)) 

1001 return index 

1002 

1003 @staticmethod 

1004 def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, 

1005 ahead_dim="ahead", index_dim="index", type_dim="type", **kwargs): 

1006 """ 

1007 Combine different forecast types into single xarray. 

1008 

1009 :param index: index for forecasts (e.g. time) 

1010 :param ahead_names: names of ahead values (e.g. hours or days) 

1011 :param kwargs: as xarrays; data of forecasts 

1012 

1013 :return: xarray of dimension 3: index, ahead_names, # predictions 

1014 

1015 """ 

1016 kwargs = {k: v for k, v in kwargs.items() if v is not None} 

1017 keys = list(kwargs.keys()) 

1018 res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan), 

1019 coords=[index.index, ahead_names, keys], dims=[index_dim, ahead_dim, type_dim]) 

1020 for k, v in kwargs.items(): 
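# write each forecast only at timestamps that are present in both the full index and the forecast itself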

1021 intersection = set(res.index.values) & set(v.indexes[time_dimension].values) 

1022 match_index = np.array(list(intersection)) 

1023 res.loc[match_index, :, k] = v.loc[match_index] 

1024 return res 

1025 

1026 def _get_internal_data(self, station: str, path: str) -> Union[xr.DataArray, None]: 

1027 """ 

1028 Get internal data for given station. 

1029 

1030 Internal data is defined as data that is already known to the model. From an evaluation perspective, this 

1031 refers to data, that is no test data, and therefore to train and val data. 

1032 

1033 :param station: name of station to load internal data. 

1034 """ 

1035 try: 

1036 file = os.path.join(path, f"forecasts_{str(station)}_train_val.nc") 

1037 with xr.open_dataarray(file) as da: 

1038 return da.load() 

1039 except (IndexError, KeyError, FileNotFoundError): 

1040 return None 

1041 

1042 def _get_external_data(self, station: str, path: str) -> Union[xr.DataArray, None]: 

1043 """ 

1044 Get external data for given station. 

1045 

1046 External data is defined as data that is not known to the model. From an evaluation perspective, this refers to 

1047 data, that is not train or val data, and therefore to test data. 

1048 

1049 :param station: name of station to load external data. 

1050 """ 

1051 try: 

1052 file = os.path.join(path, f"forecasts_{str(station)}_test.nc") 

1053 with xr.open_dataarray(file) as da: 

1054 return da.load() 

1055 except (IndexError, KeyError, FileNotFoundError): 

1056 return None 

1057 

1058 def _combine_forecasts(self, forecast, competitor, dim=None): 

1059 """ 

1060 Combine forecast and competitor if both are xarray objects. If competitor is None, this returns the forecast and 

1061 vice versa. 

1062 """ 

1063 if dim is None: 

1064 dim = self.model_type_dim 

1065 try: 

1066 return xr.concat([forecast, competitor], dim=dim) 

1067 except (TypeError, AttributeError): 

1068 return forecast if competitor is None else competitor 

1069 

1070 def calculate_error_metrics(self) -> Tuple[Dict, Dict, Dict, Dict]: 

1071 """ 

1072 Calculate error metrics and skill scores of NN forecast. 

1073 

1074 The competitive skill score compares the NN prediction with persistence and ordinary least squares forecasts, 

1075 whereas the climatological skill score evaluates the NN prediction in terms of meaningfulness in comparison 

1076 to different climatological references. 

1077 

1078 :return: competitive and climatological skill scores, error metrics 

1079 """ 

1080 all_stations = self.data_store.get("stations") 

1081 skill_score_competitive = {} 

1082 skill_score_competitive_count = {} 

1083 skill_score_climatological = {} 

1084 errors = {} 

1085 for station in all_stations: 

1086 external_data = self._get_external_data(station, self.forecast_path) # test data 

1087 

1088 # test errors 

1089 if external_data is not None: 

1090 external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n) 

1091 for n in external_data.coords[self.model_type_dim].values] 

1092 model_type_list = external_data.coords[self.model_type_dim].values.tolist() 

1093 for model_type in remove_items(model_type_list, self.observation_indicator): 

1094 if model_type not in errors.keys(): 

1095 errors[model_type] = {} 

1096 errors[model_type][station] = statistics.calculate_error_metrics( 

1097 *map(lambda x: external_data.sel(**{self.model_type_dim: x}), 

1098 [model_type, self.observation_indicator]), dim=self.index_dim) 

1099 

1100 # load competitors 

1101 competitor = self.load_competitors(station) 

1102 combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim) 

1103 if combined is not None: 

1104 model_list = remove_items(combined.coords[self.model_type_dim].values.tolist(), 

1105 self.observation_indicator) 

1106 else: 

1107 model_list = None 

1108 

1109 # test errors of competitors 

1110 for model_type in (model_list or []): 

1111 if self.observation_indicator not in combined.coords[self.model_type_dim]: 

1112 continue 

1113 if model_type not in errors.keys(): 

1114 errors[model_type] = {} 

1115 errors[model_type][station] = statistics.calculate_error_metrics( 

1116 *map(lambda x: combined.sel(**{self.model_type_dim: x}), 

1117 [model_type, self.observation_indicator]), dim=self.index_dim) 

1118 

1119 # skill score 

1120 skill_score = statistics.SkillScores(combined, models=model_list, ahead_dim=self.ahead_dim, 

1121 type_dim=self.model_type_dim, index_dim=self.index_dim) 

1122 if external_data is not None: 

1123 skill_score_competitive[station], skill_score_competitive_count[station] = skill_score.skill_scores() 

1124 

1125 internal_data = self._get_internal_data(station, self.forecast_path) 

1126 if internal_data is not None: 

1127 skill_score_climatological[station] = skill_score.climatological_skill_scores( 

1128 internal_data, forecast_name=self.forecast_indicator) 

1129 

1130 for model_type in errors.keys(): 

1131 errors[model_type].update({"total": self.calculate_average_errors(errors[model_type])}) 

1132 skill_score_competitive.update({"total": self.calculate_average_skill_scores(skill_score_competitive, 

1133 skill_score_competitive_count)}) 

1134 return skill_score_competitive, skill_score_competitive_count, skill_score_climatological, errors 

1135 

1136 @staticmethod 

1137 def calculate_average_skill_scores(scores, counts): 
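# total skill score = sum over stations of the station score weighted by its share of samples (n_station / n_total)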

1138 avg_skill_score = 0 

1139 n_total = None 

1140 for vals in counts.values(): 

1141 n_total = vals if n_total is None else n_total.add(vals, fill_value=0) 

1142 for station, station_scores in scores.items(): 

1143 n = counts.get(station) 

1144 avg_skill_score = station_scores.mul(n.div(n_total, fill_value=0), fill_value=0).add(avg_skill_score, 

1145 fill_value=0) 

1146 return avg_skill_score 

1147 

1148 @staticmethod 

1149 def calculate_average_errors(errors): 
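# total error = sum over stations of the station error weighted by its share of samples (n_station / n_total)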

1150 avg_error = {} 

1151 n_total = sum([x.get("n", 0) for _, x in errors.items()]) 

1152 for station, station_errors in errors.items(): 

1153 n_station = station_errors.get("n") 

1154 for error_metric, val in station_errors.items(): 

1155 new_val = avg_error.get(error_metric, 0) + val * n_station / n_total 

1156 avg_error[error_metric] = new_val 

1157 return avg_error 

1158 

1159 def report_feature_importance_results(self, results): 

1160 """Create a csv file containing all results from feature importance.""" 

1161 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

1162 path_config.check_path_and_create(report_path) 

1163 res = [] 

1164 max_cols = 0 

1165 for boot_type, d0 in results.items(): 

1166 for boot_method, d1 in d0.items(): 

1167 for station_name, vals in d1.items(): 

1168 for boot_var in vals.coords[self.boot_var_dim].values.tolist(): 

1169 for ahead in vals.coords[self.ahead_dim].values.tolist(): 

1170 res.append([boot_type, boot_method, station_name, boot_var, ahead, 

1171 *vals.sel({self.boot_var_dim: boot_var, 

1172 self.ahead_dim: ahead}).values.round(5).tolist()]) 

1173 max_cols = max(max_cols, len(res[-1])) 

1174 col_names = [self.model_type_dim, "method", "station", self.boot_var_dim, self.ahead_dim, 

1175 *list(range(max_cols - 5))] 

1176 df = pd.DataFrame(res, columns=col_names) 

1177 file_name = "feature_importance_skill_score_report_raw.csv" 

1178 df.to_csv(os.path.join(report_path, file_name), sep=";") 

1179 

1180 def report_error_metrics(self, errors): 

1181 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

1182 path_config.check_path_and_create(report_path) 

1183 for model_type in errors.keys(): 

1184 metric_collection = {} 

1185 for station, station_errors in errors[model_type].items(): 

1186 if isinstance(station_errors, xr.DataArray): 

1187 dim = station_errors.dims[0] 

1188 sel_index = [sel for sel in station_errors.coords[dim] if "CASE" in str(sel)] 

1189 station_errors = {str(i.values): station_errors.sel(**{dim: i}) for i in sel_index} 

1190 elif isinstance(station_errors, pd.DataFrame): 

1191 sel_index = station_errors.index.tolist() 

1192 ahead = station_errors.columns.values 

1193 station_errors = {k: xr.DataArray(station_errors[station_errors.index == k].values.flatten(), 

1194 dims=["ahead"], coords={"ahead": ahead}).astype(float) 

1195 for k in sel_index} 

1196 for metric, vals in station_errors.items(): 

1197 if metric == "n": 

1198 metric = "count" 

1199 pd_vals = pd.DataFrame.from_dict({station: vals}).T 

1200 pd_vals.columns = [f"{metric}(t+{x})" for x in vals.coords["ahead"].values] 

1201 mc = metric_collection.get(metric, pd.DataFrame()) 

1202 mc = mc.append(pd_vals) 

1203 metric_collection[metric] = mc 

1204 for metric, error_df in metric_collection.items(): 

1205 df = error_df.sort_index() 

1206 if "total" in df.index: 

1207 df = df.reindex(df.index.drop(["total"]).to_list() + ["total"]) 

1208 column_format = tables.create_column_format_for_tex(df) 

1209 if model_type == "skill_score": 

1210 file_name = f"error_report_{model_type}_{metric}.%s".replace(' ', '_').replace('/', '_') 

1211 else: 

1212 file_name = f"error_report_{metric}_{model_type}.%s".replace(' ', '_').replace('/', '_') 

1213 tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df) 

1214 tables.save_to_md(report_path, file_name % "md", df=df) 

1215 

1216 def store_errors(self, errors): 

1217 metric_collection = {} 

1218 error_dim = "error_metric" 

1219 station_dim = "station" 

1220 for model_type in errors.keys(): 

1221 station_collection = {} 

1222 for station, station_errors in errors[model_type].items(): 

1223 if station == "total": 

1224 continue 

1225 station_collection[station] = xr.Dataset(station_errors).to_array(error_dim) 

1226 metric_collection[model_type] = xr.Dataset(station_collection).to_array(station_dim) 

1227 coll = xr.Dataset(metric_collection).to_array(self.model_type_dim) 

1228 coll = coll.transpose(station_dim, self.ahead_dim, self.model_type_dim, error_dim) 
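# self.errors holds one DataArray per error metric with dimensions (station, ahead, model type), used later in plot()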

1229 self.errors = {k: coll.sel({error_dim: k}, drop=True) for k in coll.coords[error_dim].values}