Coverage for mlair/run_modules/post_processing.py: 6%

626 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-12-18 17:51 +0000

1"""Post-processing module.""" 

2 

3__author__ = "Lukas Leufen, Felix Kleinert" 

4__date__ = '2019-12-11' 

5 

6import inspect 

7import logging 

8import os 

9import sys 

10import traceback 

11import copy 

12from typing import Dict, Tuple, Union, List, Callable 

13 

14import numpy as np 

15import pandas as pd 

16import xarray as xr 

17import datetime as dt 

18 

19from mlair.configuration import path_config 

20from mlair.data_handler import Bootstraps, KerasIterator 

21from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope 

22from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables 

23from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel 

24from mlair.model_modules import AbstractModelClass 

25from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ 

26 PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \ 

27 PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric, PlotSeasonalMSEStack, \ 

28 PlotErrorsOnMap 

29from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ 

30 PlotPeriodogram, PlotDataHistogram, PlotDataMonthlyDistribution 

31from mlair.run_modules.run_environment import RunEnvironment 

32 

33 

34class PostProcessing(RunEnvironment): 

35 """ 

36 Perform post-processing for performance evaluation. 

37 

38 Schedule of post-processing: 

39 #. train an ordinary least squared model (ols) for reference 

40 #. create forecasts for nn, ols, and persistence 

41 #. evaluate feature importance with bootstrapped predictions 

42 #. calculate skill scores 

43 #. create plots 

44 

45 Required objects [scope] from data store: 

46 * `model` [.] or locally saved model plus `model_name` [model] and `model` [model] 

47 * `generator` [train, val, test, train_val] 

48 * `forecast_path` [.] 

49 * `plot_path` [postprocessing] 

50 * `model_path` [.] 

51 * `target_var` [.] 

52 * `sampling` [.] 

53 * `output_shape` [model] 

54 * `evaluate_feature_importance` [postprocessing] and if enabled: 

55 

56 * `create_new_bootstraps` [postprocessing] 

57 * `bootstrap_path` [postprocessing] 

58 * `number_of_bootstraps` [postprocessing] 

59 

60 Optional objects 

61 * `batch_size` [model] 

62 

63 Creates 

64 * forecasts in `forecast_path` if enabled 

65 * bootstraps in `bootstrap_path` if enabled 

66 * plots in `plot_path` 

67 

68 """ 

69 

70 def __init__(self): 

71 """Initialise and run post-processing.""" 

72 super().__init__() 

73 self.model: AbstractModelClass = self._load_model() 

74 self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] 

75 self.ols_model = None 

76 self.persi_model = True 

77 self.batch_size: int = self.data_store.get_default("batch_size", "model", 64) 

78 self.test_data = self.data_store.get("data_collection", "test") 

79 batch_path = self.data_store.get("batch_path", scope="test") 

80 self.test_data_distributed = KerasIterator(self.test_data, self.batch_size, model=self.model, name="test", 

81 batch_path=batch_path) 

82 self.train_data = self.data_store.get("data_collection", "train") 

83 self.val_data = self.data_store.get("data_collection", "val") 

84 self.train_val_data = self.data_store.get("data_collection", "train_val") 

85 self.forecast_path = self.data_store.get("forecast_path") 

86 self.plot_path: str = self.data_store.get("plot_path") 

87 self.target_var = self.data_store.get("target_var") 

88 self._sampling = self.data_store.get("sampling") 

89 self.window_lead_time = extract_value(self.data_store.get("output_shape", "model")) 

90 self.skill_scores = None 

91 self.errors = None 

92 self.bias_free_errors = None 

93 self.feature_importance_skill_scores = None 

94 self.uncertainty_estimate = None 

95 self.uncertainty_estimate_seasons = {} 

96 self.block_mse_per_station = None 

97 self.block_mse = None 

98 self.competitor_path = self.data_store.get("competitor_path") 

99 self.competitors = to_list(self.data_store.get_default("competitors", default=[])) 

100 self.forecast_indicator = "nn" 

101 self.observation_indicator = "obs" 

102 self.ahead_dim = "ahead" 

103 self.boot_var_dim = "boot_var" 

104 self.uncertainty_estimate_boot_dim = "boots" 

105 self.model_type_dim = "type" 

106 self.index_dim = "index" 

107 self.model_display_name = self.data_store.get_default("model_display_name", default=self.model.model_name) 

108 self._run() 

109 

110 def _run(self): 

111 # ols model 

112 self.train_ols_model() 

113 

114 # persi model 

115 self.setup_persistence() 

116 

117 # forecasts on test data 

118 self.make_prediction(self.test_data) 

119 self.make_prediction(self.train_val_data) 

120 

121 # calculate error metrics on test data 

122 self.calculate_test_score() 

123 

124 # calculate monthly block mse 

125 self.block_mse, self.block_mse_per_station = self.calculate_block_mse(evaluate_competitors=True, 

126 separate_ahead=True, block_length="1m") 

127 

128 # sample uncertainty 

129 if self.data_store.get("do_uncertainty_estimate", "postprocessing"): 

130 self.estimate_sample_uncertainty(separate_ahead=True) 

131 

132 # feature importance bootstraps 

133 if self.data_store.get("evaluate_feature_importance", "postprocessing"): 

134 with TimeTracking(name="evaluate_feature_importance", log_on_enter=True): 

135 create_new_bootstraps = self.data_store.get("create_new_bootstraps", "feature_importance") 

136 bootstrap_method = self.data_store.get("bootstrap_method", "feature_importance") 

137 bootstrap_type = self.data_store.get("bootstrap_type", "feature_importance") 

138 self.calculate_feature_importance(create_new_bootstraps, bootstrap_type=bootstrap_type, 

139 bootstrap_method=bootstrap_method) 

140 if self.feature_importance_skill_scores is not None: 

141 self.report_feature_importance_results(self.feature_importance_skill_scores) 

142 

143 # skill scores and error metrics 

144 with TimeTracking(name="calculate_error_metrics", log_on_enter=True): 

145 skill_score_competitive, _, skill_score_climatological, errors = self.calculate_error_metrics() 

146 self.skill_scores = (skill_score_competitive, skill_score_climatological) 

147 with TimeTracking(name="report_error_metrics", log_on_enter=True): 

148 self.report_error_metrics(errors) 

149 self.report_error_metrics({self.forecast_indicator: skill_score_climatological}) 

150 self.report_error_metrics({"skill_score": skill_score_competitive}) 

151 self.errors = self.store_errors(errors) 

152 

153 # bias free evaluation 

154 if self.data_store.get("do_bias_free_evaluation", "postprocessing") is True: 

155 bias_free_errors = self.calculate_bias_free_error_metrics() 

156 self.report_error_metrics(bias_free_errors[0], tag="bias_free") 

157 self.report_error_metrics(bias_free_errors[1], tag="seasonal_bias_free") 

158 self.bias_free_errors = [self.store_errors(bias_free_errors[0]), self.store_errors(bias_free_errors[1])] 

159 

160 # plotting 

161 self.plot() 

162 

163 @TimeTrackingWrapper 

164 def estimate_sample_uncertainty(self, separate_ahead=False): 

165 """ 

166 Estimate sample uncertainty by using a bootstrap approach. Forecasts are split into individual blocks along time 

167 and randomly drawn with replacement. The resulting behaviour of the error indicates the robustness of each 

168 analyzed model to quantify which model might be superior compared to others. 

169 """ 

170 logging.info("start estimate_sample_uncertainty") 

171 n_boots = self.data_store.get_default("n_boots", default=1000, scope="uncertainty_estimate") 

172 block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") 

173 evaluate_competitors = self.data_store.get_default("evaluate_competitors", default=True, 

174 scope="uncertainty_estimate") 

175 if evaluate_competitors is True and separate_ahead is True and block_length == "1m": 

176 block_mse, block_mse_per_station = self.block_mse, self.block_mse_per_station 

177 else: 

178 block_mse, block_mse_per_station = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, 

179 separate_ahead=separate_ahead, 

180 block_length=block_length) 

181 estimate = statistics.create_n_bootstrap_realizations( 

182 block_mse, dim_name_time=self.index_dim, dim_name_model=self.model_type_dim, 

183 dim_name_boots=self.uncertainty_estimate_boot_dim, n_boots=n_boots, seasons=["DJF", "MAM", "JJA", "SON"]) 

184 self.uncertainty_estimate = estimate.pop("") 

185 self.uncertainty_estimate_seasons = estimate 

186 self.report_sample_uncertainty() 

187 

188 def report_sample_uncertainty(self, percentiles: list = None): 

189 """ 

190 Store raw results of uncertainty estimate and calculate aggregate statistics and store as raw data but also as 

191 markdown and latex. 

192 """ 

193 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

194 path_config.check_path_and_create(report_path) 

195 

196 # store raw results as nc 

197 file_name = os.path.join(report_path, "uncertainty_estimate_raw_results.nc") 

198 self.uncertainty_estimate.to_netcdf(path=file_name) 

199 for season in self.uncertainty_estimate_seasons.keys(): 

200 file_name = os.path.join(report_path, f"uncertainty_estimate_raw_results_{season}.nc") 

201 self.uncertainty_estimate_seasons[season].to_netcdf(path=file_name) 

202 

203 # store block mse per station 

204 file_name = os.path.join(report_path, f"block_mse_raw_results.nc") 

205 self.block_mse_per_station.to_netcdf(path=file_name) 

206 

207 # store statistics 

208 if percentiles is None: 

209 percentiles = [.05, .1, .25, .5, .75, .9, .95] 

210 

211 for season in [None] + list(self.uncertainty_estimate_seasons.keys()): 

212 estimate = self.uncertainty_estimate if season is None else self.uncertainty_estimate_seasons[season] 

213 affix = "" if season is None else f"_{season}" 

214 for ahead_steps in ["single", "multi"]: 

215 if ahead_steps == "single": 

216 try: 

217 df_descr = estimate.to_pandas().describe(percentiles=percentiles).astype("float32") 

218 except ValueError: 

219 df_descr = estimate.mean(self.ahead_dim).to_pandas().describe(percentiles=percentiles).astype("float32") 

220 else: 

221 if self.ahead_dim not in estimate.dims: 

222 continue 

223 df_descr = estimate.to_dataframe(self.model_type_dim).unstack().groupby(level=self.ahead_dim).describe( 

224 percentiles=percentiles).astype("float32") 

225 df_descr = df_descr.stack(-1) 

226 df_descr = df_descr.reorder_levels(df_descr.index.names[::-1]) 

227 df_sorter = ["count", "mean", "std", "min", *[f"{round(p * 100)}%" for p in percentiles], "max"] 

228 df_descr = df_descr.loc[df_sorter] 

229 column_format = tables.create_column_format_for_tex(df_descr) 

230 file_name = os.path.join(report_path, f"uncertainty_estimate_statistics_{ahead_steps}{affix}.%s") 

231 tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df_descr) 

232 tables.save_to_md(report_path, file_name % "md", df=df_descr) 

233 df_descr.to_csv(file_name % "csv", sep=";") 

234 

235 def calculate_block_mse(self, evaluate_competitors=True, separate_ahead=False, block_length="1m"): 

236 """ 

237 Transform data into blocks along time axis. Block length can be any frequency like '1m' or '7d. Data are only 

238 split along time axis, which means that a single block can have very diverse quantities regarding the number of 

239 station or actual data contained. This is intended to analyze not only the robustness against the time but also 

240 against the number of observations and diversity ot stations. 

241 """ 

242 all_stations = self.data_store.get("stations", "test") 

243 start = self.data_store.get("start", "test") 

244 end = self.data_store.get("end", "test") 

245 index_dim = self.index_dim 

246 coll_dim = "station" 

247 collector = [] 

248 for station in all_stations: 

249 # test data 

250 external_data = self._get_external_data(station, self.forecast_path) 

251 if external_data is None: 

252 logging.info(f"skip calculate_block_mse for {station} as no external_data are available") 

253 continue 

254 # competitors 

255 if evaluate_competitors is True: 

256 competitor = self.load_competitors(station) 

257 combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim) 

258 else: 

259 combined = external_data 

260 

261 if combined is None: 

262 continue 

263 else: 

264 combined = self.create_full_time_dim(combined, index_dim, self._sampling, start, end) 

265 # get squared errors 

266 errors = self.create_error_array(combined) 

267 # calc mse for each block (single station) 

268 mse = errors.resample(indexer={index_dim: block_length}).mean(skipna=True) 

269 collector.append(mse.assign_coords({coll_dim: station})) 

270 

271 # combine all mse blocks 

272 mse_blocks_per_station = xr.concat(collector, dim=coll_dim) 

273 # calc mse for each block (average over all stations) 

274 mse_blocks = mse_blocks_per_station.mean(dim=coll_dim, skipna=True) 

275 # average also on ahead steps 

276 if separate_ahead is False: 

277 mse_blocks = mse_blocks.mean(dim=self.ahead_dim, skipna=True) 

278 mse_blocks_per_station = mse_blocks_per_station.mean(dim=self.ahead_dim, skipna=True) 

279 return mse_blocks, mse_blocks_per_station 

280 

281 def create_error_array(self, data): 

282 """Calculate squared error of all given time series in relation to observation.""" 

283 errors = data.drop_sel({self.model_type_dim: self.observation_indicator}) 

284 errors1 = errors - data.sel({self.model_type_dim: self.observation_indicator}) 

285 errors2 = errors1 ** 2 

286 return errors2 

287 

288 @staticmethod 

289 def create_full_time_dim(data, dim, sampling, start, end): 

290 """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped.""" 

291 start_data = data.coords[dim].values[0] 

292 freq = {"daily": "1D", "hourly": "1H"}.get(sampling) 

293 _ind = pd.date_range(start, end, freq=freq) # two steps required to include all hours of end interval 

294 datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), 

295 closed="left", freq=freq)) 

296 t = data.sel({dim: start_data}, drop=True) 

297 res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords]) 

298 res = res.transpose(*data.dims) 

299 if data.shape == res.shape: 

300 res.loc[data.coords] = data 

301 else: 

302 _d = data.sel({dim: slice(start, end)}) 

303 res.loc[_d.coords] = _d 

304 return res 

305 

306 def load_competitors(self, station_name: str) -> xr.DataArray: 

307 """ 

308 Load all requested and available competitors for a given station. Forecasts must be available in the competitor 

309 path like `<competitor_path>/<target_var>/forecasts_<station_name>_test.nc`. The naming style is equal for all 

310 forecasts of MLAir, so that forecasts of a different experiment can easily be copied into the competitor path 

311 without any change. 

312 

313 :param station_name: station indicator to load competitors for 

314 

315 :return: a single xarray with all competing forecasts 

316 """ 

317 competing_predictions = [] 

318 for competitor_name in self.competitors: 

319 try: 

320 prediction = self._create_competitor_forecast(station_name, competitor_name) 

321 competing_predictions.append(prediction) 

322 except (FileNotFoundError, KeyError): 

323 logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.") 

324 continue 

325 return xr.concat(competing_predictions, self.model_type_dim) if len(competing_predictions) > 0 else None 

326 

327 def calculate_feature_importance(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput", 

328 bootstrap_method="shuffle") -> None: 

329 """ 

330 Calculate skill scores of bootstrapped data. 

331 

332 Create bootstrapped data if create_new_bootstraps is true or a failure occurred during skill score calculation 

333 (this will happen by default, if no bootstrapped data is available locally). Set class attribute 

334 bootstrap_skill_scores. This method is implemented in a recursive fashion, but is only allowed to call itself 

335 once. 

336 

337 :param create_new_bootstraps: calculate all bootstrap predictions and overwrite already available predictions 

338 :param _iter: internal counter to reduce unnecessary recursive calls (maximum number is 2, otherwise something 

339 went wrong). 

340 """ 

341 if _iter == 0: 

342 self.feature_importance_skill_scores = {} 

343 for boot_type in to_list(bootstrap_type): 

344 if _iter == 0: 

345 self.feature_importance_skill_scores[boot_type] = {} 

346 for boot_method in to_list(bootstrap_method): 

347 try: 

348 if create_new_bootstraps: 

349 self.create_feature_importance_bootstrap_forecast(bootstrap_type=boot_type, 

350 bootstrap_method=boot_method) 

351 boot_skill_score = self.calculate_feature_importance_skill_scores(bootstrap_type=boot_type, 

352 bootstrap_method=boot_method) 

353 self.feature_importance_skill_scores[boot_type][boot_method] = boot_skill_score 

354 except (FileNotFoundError, ValueError, OSError): 

355 if _iter != 0: 

356 raise RuntimeError(f"calculate_feature_importance ({boot_type}, {boot_method}) was called for " 

357 f"the 2nd time. This means, that something internally goes wrong. Please " 

358 f"check for possible errors.") 

359 logging.info(f"Could not load all files for feature importance ({boot_type}, {boot_method}), " 

360 f"restart calculate_feature_importance with create_new_bootstraps=True.") 

361 self.calculate_feature_importance(True, _iter=1, bootstrap_type=boot_type, 

362 bootstrap_method=boot_method) 

363 

364 def create_feature_importance_bootstrap_forecast(self, bootstrap_type, bootstrap_method) -> None: 

365 """ 

366 Create bootstrapped predictions for all stations and variables. 

367 

368 These forecasts are saved in bootstrap_path with the names `bootstraps_{var}_{station}.nc` and 

369 `bootstraps_labels_{station}.nc`. 

370 """ 

371 

372 def _reshape(d, pos): 

373 if isinstance(d, list): 

374 return list(map(lambda x: _reshape(x, pos), d)) 

375 else: 

376 return d[..., pos] 

377 

378 # forecast 

379 with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})", 

380 log_on_enter=True): 

381 # extract all requirements from data store 

382 number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") 

383 dims = [self.uncertainty_estimate_boot_dim, self.index_dim, self.ahead_dim, self.model_type_dim] 

384 for station in self.test_data: 

385 X, Y = None, None 

386 bootstraps = Bootstraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type, 

387 bootstrap_method=bootstrap_method) 

388 number_of_bootstraps = bootstraps.number_of_bootstraps 

389 for boot in bootstraps: 

390 X, Y, (index, dimension) = boot 

391 # make bootstrap predictions 

392 bootstrap_predictions = [self.model.predict(_reshape(X, pos)) for pos in range(number_of_bootstraps)] 

393 if isinstance(bootstrap_predictions[0], list): # if model is branched model 

394 bootstrap_predictions = list(map(lambda x: x[-1], bootstrap_predictions)) 

395 # save bootstrap predictions separately for each station and variable combination 

396 bootstrap_predictions = list(map(lambda x: np.expand_dims(x, axis=-1), bootstrap_predictions)) 

397 shape = bootstrap_predictions[0].shape 

398 coords = (range(number_of_bootstraps), range(shape[0]), range(1, shape[1] + 1)) 

399 var = f"{index}_{dimension}" if index is not None else str(dimension) 

400 tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims) 

401 file_name = os.path.join(self.forecast_path, 

402 f"bootstraps_{station}_{var}_{bootstrap_type}_{bootstrap_method}.nc") 

403 tmp.to_netcdf(file_name) 

404 else: 

405 # store also true labels for each station 

406 labels = np.expand_dims(Y[..., 0], axis=-1) 

407 file_name = os.path.join(self.forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc") 

408 labels = xr.DataArray(labels, coords=(*coords[1:], [self.observation_indicator]), dims=dims[1:]) 

409 labels.to_netcdf(file_name) 

410 

411 def calculate_feature_importance_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]: 

412 """ 

413 Calculate skill score of bootstrapped variables. 

414 

415 Use already created bootstrap predictions and the original predictions (the not-bootstrapped ones) and calculate 

416 skill scores for the bootstraps. The result is saved as a xarray DataArray in a dictionary structure separated 

417 for each station (keys of dictionary). 

418 

419 :return: The result dictionary with station-wise skill scores 

420 """ 

421 with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): 

422 # extract all requirements from data store 

423 number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") 

424 forecast_file = f"forecasts_norm_%s_test.nc" 

425 reference_name = "orig" 

426 branch_names = self.data_store.get_default("branch_names", None) 

427 

428 bootstraps = Bootstraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type, 

429 bootstrap_method=bootstrap_method) 

430 number_of_bootstraps = bootstraps.number_of_bootstraps 

431 bootstrap_iter = bootstraps.bootstraps() 

432 branch_length = self.get_distinct_branches_from_bootstrap_iter(bootstrap_iter) 

433 skill_scores = statistics.SkillScores(None, ahead_dim=self.ahead_dim, type_dim=self.model_type_dim, 

434 index_dim=self.index_dim, observation_name=self.observation_indicator) 

435 score = {} 

436 for station in self.test_data: 

437 # get station labels 

438 file_name = os.path.join(self.forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc") 

439 with xr.open_dataarray(file_name) as da: 

440 labels = da.load() 

441 

442 # get original forecasts 

443 orig = self.get_orig_prediction(self.forecast_path, forecast_file % str(station), 

444 reference_name=reference_name) 

445 orig.coords[self.index_dim] = labels.coords[self.index_dim] 

446 

447 # calculate skill scores for each variable 

448 skill = [] 

449 for boot_set in bootstrap_iter: 

450 boot_var = f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set) 

451 file_name = os.path.join(self.forecast_path, 

452 f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc") 

453 with xr.open_dataarray(file_name) as da: 

454 boot_data = da.load() 

455 boot_data = boot_data.combine_first(labels).combine_first(orig) 

456 boot_scores = [] 

457 for ahead in range(1, self.window_lead_time + 1): 

458 data = boot_data.sel({self.ahead_dim: ahead}) 

459 boot_scores.append( 

460 skill_scores.general_skill_score(data, forecast_name=boot_var, 

461 reference_name=reference_name, dim=self.index_dim)) 

462 boot_var_renamed = self.rename_boot_var_with_branch(boot_var, bootstrap_type, branch_names, expected_len=branch_length) 

463 tmp = xr.DataArray(np.expand_dims(np.array(boot_scores), axis=-1), 

464 coords={self.ahead_dim: range(1, self.window_lead_time + 1), 

465 self.uncertainty_estimate_boot_dim: range(number_of_bootstraps), 

466 self.boot_var_dim: [boot_var_renamed]}, 

467 dims=[self.ahead_dim, self.uncertainty_estimate_boot_dim, self.boot_var_dim]) 

468 skill.append(tmp) 

469 

470 # collect all results in single dictionary 

471 score[str(station)] = xr.concat(skill, dim=self.boot_var_dim) 

472 return score 

473 

474 @staticmethod 

475 def get_distinct_branches_from_bootstrap_iter(bootstrap_iter): 

476 if isinstance(bootstrap_iter[0], tuple): 

477 return len(set(map(lambda x: x[0], bootstrap_iter))) 

478 else: 

479 return len(bootstrap_iter) 

480 

481 def rename_boot_var_with_branch(self, boot_var, bootstrap_type, branch_names=None, expected_len=0): 

482 if branch_names is None: 

483 return boot_var 

484 if bootstrap_type == "branch": 

485 try: 

486 assert len(branch_names) > int(boot_var) 

487 assert len(branch_names) == expected_len 

488 return branch_names[int(boot_var)] 

489 except (AssertionError, TypeError): 

490 return boot_var 

491 elif bootstrap_type == "singleinput": 

492 if "_" in boot_var: 

493 branch, other = boot_var.split("_", 1) 

494 branch = self.rename_boot_var_with_branch(branch, "branch", branch_names=branch_names, expected_len=expected_len) 

495 boot_var = "_".join([branch, other]) 

496 return boot_var 

497 return boot_var 

498 

499 def get_orig_prediction(self, path, file_name, prediction_name=None, reference_name=None): 

500 if prediction_name is None: 

501 prediction_name = self.forecast_indicator 

502 file = os.path.join(path, file_name) 

503 with xr.open_dataarray(file) as da: 

504 prediction = da.load().sel({self.model_type_dim: [prediction_name]}) 

505 if reference_name is not None: 

506 prediction.coords[self.model_type_dim] = [reference_name] 

507 return prediction.dropna(dim=self.index_dim) 

508 

509 @staticmethod 

510 def repeat_data(data, number_of_repetition): 

511 if isinstance(data, xr.DataArray): 

512 data = data.data 

513 return np.repeat(np.expand_dims(data, axis=-1), number_of_repetition, axis=-1) 

514 

515 def _get_model_name(self): 

516 """Return model name without path information.""" 

517 return self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] 

518 

519 def _load_model(self) -> AbstractModelClass: 

520 """ 

521 Load NN model either from data store or from local path. 

522 

523 :return: the model 

524 """ 

525 try: # is only available if a model was trained in training stage 

526 model = self.data_store.get("model") 

527 except (NameNotFoundInDataStore, NameNotFoundInScope): 

528 logging.info("No model was saved in data store. Try to load model from experiment path.") 

529 model_name = self.data_store.get("model_name", "model") 

530 model: AbstractModelClass = self.data_store.get("model", "model") 

531 model.load_model(model_name) 

532 return model 

533 

534 # noinspection PyBroadException 

535 def plot(self): 

536 """ 

537 Create all plots. 

538 

539 Plots are defined in experiment set up by `plot_list`. As default, all (following) plots are enabled: 

540 

541 * :py:class:`PlotBootstrapSkillScore <src.plotting.postprocessing_plotting.PlotBootstrapSkillScore>` 

542 * :py:class:`PlotConditionalQuantiles <src.plotting.postprocessing_plotting.PlotConditionalQuantiles>` 

543 * :py:class:`PlotStationMap <src.plotting.postprocessing_plotting.PlotStationMap>` 

544 * :py:class:`PlotMonthlySummary <src.plotting.postprocessing_plotting.PlotMonthlySummary>` 

545 * :py:class:`PlotClimatologicalSkillScore <src.plotting.postprocessing_plotting.PlotClimatologicalSkillScore>` 

546 * :py:class:`PlotCompetitiveSkillScore <src.plotting.postprocessing_plotting.PlotCompetitiveSkillScore>` 

547 * :py:class:`PlotTimeSeries <src.plotting.postprocessing_plotting.PlotTimeSeries>` 

548 * :py:class:`PlotAvailability <src.plotting.postprocessing_plotting.PlotAvailability>` 

549 

550 .. note:: Bootstrap plots are only created if bootstraps are evaluated. 

551 

552 """ 

553 logging.info("Run plotting routines...") 

554 use_multiprocessing = self.data_store.get("use_multiprocessing") 

555 

556 plot_list = self.data_store.get("plot_list", "postprocessing") 

557 time_dim = self.data_store.get("time_dim") 

558 window_dim = self.data_store.get("window_dim") 

559 target_dim = self.data_store.get("target_dim") 

560 iter_dim = self.data_store.get("iter_dim") 

561 

562 try: 

563 if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ( 

564 "PlotSeparationOfScales" in plot_list): 

565 filter_dim = self.data_store.get_default("filter_dim", None) 

566 PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path, time_dim=time_dim, 

567 window_dim=window_dim, target_dim=target_dim, **{"filter_dim": filter_dim}) 

568 except Exception as e: 

569 logging.error(f"Could not create plot PlotSeparationOfScales due to the following error:" 

570 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

571 

572 try: 

573 if (self.feature_importance_skill_scores is not None) and ("PlotFeatureImportanceSkillScore" in plot_list): 

574 branch_names = self.data_store.get_default("branch_names", None) 

575 for boot_type, boot_data in self.feature_importance_skill_scores.items(): 

576 for boot_method, boot_skill_score in boot_data.items(): 

577 try: 

578 PlotFeatureImportanceSkillScore( 

579 boot_skill_score, plot_folder=self.plot_path, model_name=self.model_display_name, 

580 sampling=self._sampling, ahead_dim=self.ahead_dim, 

581 separate_vars=to_list(self.target_var), bootstrap_type=boot_type, 

582 bootstrap_method=boot_method, branch_names=branch_names) 

583 except Exception as e: 

584 logging.error(f"Could not create plot PlotFeatureImportanceSkillScore ({boot_type}, " 

585 f"{boot_method}) due to the following error:\n{sys.exc_info()[0]}\n" 

586 f"{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

587 except Exception as e: 

588 logging.error(f"Could not create plot PlotFeatureImportanceSkillScore due to the following error: {e}") 

589 

590 try: 

591 if "PlotConditionalQuantiles" in plot_list: 

592 PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=self.forecast_path, 

593 plot_folder=self.plot_path, forecast_indicator=self.forecast_indicator, 

594 obs_indicator=self.observation_indicator, competitors=self.competitors, 

595 model_type_dim=self.model_type_dim, index_dim=self.index_dim, 

596 ahead_dim=self.ahead_dim, competitor_path=self.competitor_path, 

597 model_name=self.model_display_name) 

598 except Exception as e: 

599 logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error:" 

600 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

601 

602 try: 

603 if "PlotMonthlySummary" in plot_list: 

604 PlotMonthlySummary(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc", self.target_var, 

605 plot_folder=self.plot_path) 

606 except Exception as e: 

607 logging.error(f"Could not create plot PlotMonthlySummary due to the following error:" 

608 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

609 

610 try: 

611 if "PlotClimatologicalSkillScore" in plot_list: 

612 PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, 

613 model_name=self.model_display_name) 

614 PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, 

615 extra_name_tag="all_terms_", model_name=self.model_display_name) 

616 except Exception as e: 

617 logging.error(f"Could not create plot PlotClimatologicalSkillScore due to the following error: {e}" 

618 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

619 

620 try: 

621 if "PlotCompetitiveSkillScore" in plot_list: 

622 PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, 

623 model_setup=self.model_display_name) 

624 except Exception as e: 

625 logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}" 

626 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

627 

628 try: 

629 if "PlotTimeSeries" in plot_list: 

630 PlotTimeSeries(self.test_data.keys(), self.forecast_path, r"forecasts_%s_test.nc", 

631 plot_folder=self.plot_path, sampling=self._sampling, ahead_dim=self.ahead_dim) 

632 except Exception as e: 

633 logging.error(f"Could not create plot PlotTimeSeries due to the following error:\n{sys.exc_info()[0]}\n" 

634 f"{sys.exc_info()[1]}\n{sys.exc_info()[2]}\n{traceback.format_exc()}") 

635 

636 try: 

637 if "PlotSampleUncertaintyFromBootstrap" in plot_list and self.uncertainty_estimate is not None: 

638 block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") 

639 for season in [None] + list(self.uncertainty_estimate_seasons.keys()): 

640 estimate = self.uncertainty_estimate if season is None else self.uncertainty_estimate_seasons[season] 

641 PlotSampleUncertaintyFromBootstrap( 

642 data=estimate, plot_folder=self.plot_path, model_type_dim=self.model_type_dim, 

643 dim_name_boots=self.uncertainty_estimate_boot_dim, error_measure="mean squared error", 

644 error_unit=r"ppb$^2$", block_length=block_length, model_name=self.model_display_name, 

645 model_indicator=self.forecast_indicator, sampling=self._sampling, season_annotation=season) 

646 except Exception as e: 

647 logging.error(f"Could not create plot PlotSampleUncertaintyFromBootstrap due to the following error: {e}" 

648 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

649 

650 try: 

651 if "PlotErrorMetrics" in plot_list and self.errors is not None: 

652 error_metric_units = statistics.get_error_metrics_units("ppb") 

653 error_metrics_name = statistics.get_error_metrics_long_name() 

654 for error_metric in self.errors.keys(): 

655 try: 

656 PlotSampleUncertaintyFromBootstrap( 

657 data=self.errors[error_metric], plot_folder=self.plot_path, model_type_dim=self.model_type_dim, 

658 dim_name_boots="station", error_measure=error_metrics_name[error_metric], 

659 error_unit=error_metric_units[error_metric], model_name=self.model_display_name, 

660 model_indicator=self.model_display_name, sampling=self._sampling, apply_root=False, 

661 plot_name=f"error_plot_{error_metric}") 

662 except Exception as e: 

663 logging.error(f"Could not create plot PlotErrorMetrics for {error_metric} due to the following " 

664 f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

665 except Exception as e: 

666 logging.error(f"Could not create plot PlotErrorMetrics due to the following error: {e}" 

667 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

668 

669 try: 

670 if "PlotErrorMetrics" in plot_list and self.bias_free_errors is not None: 

671 error_metric_units = statistics.get_error_metrics_units("ppb") 

672 error_metrics_name = statistics.get_error_metrics_long_name() 

673 tag = {0: "", 1: "seasonal_"} 

674 for i, errors in enumerate(self.bias_free_errors): 

675 for error_metric in errors.keys(): 

676 try: 

677 PlotSampleUncertaintyFromBootstrap( 

678 data=errors[error_metric], plot_folder=self.plot_path, 

679 model_type_dim=self.model_type_dim, 

680 dim_name_boots="station", error_measure=error_metrics_name[error_metric], 

681 error_unit=error_metric_units[error_metric], model_name=self.model_display_name, 

682 model_indicator=self.model_display_name, sampling=self._sampling, apply_root=False, 

683 plot_name=f"{tag[i]}bias_free_error_plot_{error_metric}") 

684 except Exception as e: 

685 logging.error(f"Could not create plot PlotErrorMetrics for {error_metric} (bias free " 

686 f"{tag[i]}) due to the following error: " 

687 f"{e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

688 except Exception as e: 

689 logging.error(f"Could not create plot PlotErrorMetrics (bias free) due to the following error: {e}" 

690 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

691 

692 try: 

693 if "PlotStationMap" in plot_list: 

694 gens = [(self.train_data, {"marker": 5, "ms": 9}), 

695 (self.val_data, {"marker": 6, "ms": 9}), 

696 (self.test_data, {"marker": 4, "ms": 9})] 

697 PlotStationMap(generators=gens, plot_folder=self.plot_path) 

698 gens = [(self.train_val_data, {"marker": 8, "ms": 9}), 

699 (self.test_data, {"marker": 9, "ms": 9})] 

700 PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") 

701 except Exception as e: 

702 if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get("hostname")[:6] in self.data_store.get("hpc_hosts"): 

703 logging.info(f"PlotStationMap might have failed as current workflow is running on hpc node {self.data_store.get('hostname')}. To download geographic elements, please run PlotStationMap once on login node.") 

704 logging.error(f"Could not create plot PlotStationMap due to the following error: {e}" 

705 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

706 

707 try: 

708 if "PlotAvailability" in plot_list: 

709 avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} 

710 PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dim, 

711 window_dimension=window_dim) 

712 except Exception as e: 

713 logging.error(f"Could not create plot PlotAvailability due to the following error: {e}" 

714 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

715 

716 try: 

717 if "PlotAvailabilityHistogram" in plot_list: 

718 avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} 

719 PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, station_dim=iter_dim, 

720 history_dim=window_dim) 

721 except Exception as e: 

722 logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}" 

723 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

724 

725 try: 

726 if "PlotDataMonthlyDistribution" in plot_list: 

727 avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} 

728 PlotDataMonthlyDistribution(avail_data, plot_folder=self.plot_path, time_dim=time_dim, 

729 variables_dim=target_dim, window_dim=window_dim, target_var=self.target_var, 

730 ) 

731 except Exception as e: 

732 logging.error(f"Could not create plot PlotDataMonthlyDistribution due to the following error:" 

733 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

734 

735 try: 

736 if "PlotDataHistogram" in plot_list: 

737 upsampling = self.data_store.get_default("upsampling", scope="train", default=False) 

738 gens = {"train": self.train_data, "val": self.val_data, "test": self.test_data} 

739 PlotDataHistogram(gens, plot_folder=self.plot_path, time_dim=time_dim, variables_dim=target_dim, 

740 upsampling=upsampling) 

741 except Exception as e: 

742 logging.error(f"Could not create plot PlotDataHistogram due to the following error: {e}" 

743 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

744 

745 try: 

746 if "PlotPeriodogram" in plot_list: 

747 PlotPeriodogram(self.train_data, plot_folder=self.plot_path, time_dim=time_dim, 

748 variables_dim=target_dim, sampling=self._sampling, 

749 use_multiprocessing=use_multiprocessing) 

750 except Exception as e: 

751 logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}" 

752 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

753 

754 try: 

755 if "PlotTimeEvolutionMetric" in plot_list: 

756 PlotTimeEvolutionMetric(self.block_mse_per_station, plot_folder=self.plot_path, 

757 model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim, 

758 error_measure="Mean Squared Error", error_unit=r"ppb$^2$", 

759 model_indicator=self.forecast_indicator, model_name=self.model_display_name, 

760 time_dim=self.index_dim) 

761 except Exception as e: 

762 logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}" 

763 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

764 

765 try: 

766 if "PlotSeasonalMSEStack" in plot_list: 

767 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

768 PlotSeasonalMSEStack(data=self.block_mse_per_station, data_path=report_path, plot_folder=self.plot_path, 

769 boot_dim=self.uncertainty_estimate_boot_dim, ahead_dim=self.ahead_dim, 

770 sampling=self._sampling, error_measure="Mean Squared Error", error_unit=r"ppb$^2$", 

771 model_indicator=self.forecast_indicator, model_name=self.model_display_name, 

772 model_type_dim=self.model_type_dim) 

773 except Exception as e: 

774 logging.error(f"Could not create plot PlotSeasonalMSEStack due to the following error: {e}" 

775 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

776 

777 try: 

778 if "PlotErrorsOnMap" in plot_list and self.errors is not None: 

779 for error_metric in self.errors.keys(): 

780 try: 

781 PlotErrorsOnMap(self.test_data, self.errors[error_metric], error_metric, 

782 plot_folder=self.plot_path, sampling=self._sampling) 

783 except Exception as e: 

784 logging.error(f"Could not create plot PlotErrorsOnMap for {error_metric} due to the following " 

785 f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

786 except Exception as e: 

787 logging.error(f"Could not create plot PlotErrorsOnMap due to the following error: {e}" 

788 f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") 

789 

790 @TimeTrackingWrapper 

791 def calculate_test_score(self): 

792 """Evaluate test score of model and save locally.""" 

793 logging.info(f"start to calculate test scores") 

794 

795 # test scores on transformed data 

796 test_score = self.model.evaluate(self.test_data_distributed, 

797 use_multiprocessing=True, verbose=0) 

798 path = self.data_store.get("model_path") 

799 with open(os.path.join(path, "test_scores.txt"), "a") as f: 

800 for index, item in enumerate(to_list(test_score)): 

801 logging.info(f"{self.model.metrics_names[index]} (test), {item}") 

802 f.write(f"{self.model.metrics_names[index]}, {item}\n") 

803 

804 @TimeTrackingWrapper 

805 def train_ols_model(self): 

806 """Train ordinary least squared model on train data.""" 

807 if "ols" in map(lambda x: x.lower(), self.competitors): 

808 logging.info(f"start train_ols_model on train data") 

809 self.ols_model = OrdinaryLeastSquaredModel(self.train_data) 

810 self.competitors = [e for e in self.competitors if e.lower() != "ols"] 

811 else: 

812 logging.info(f"Skip train ols model as it is not present in competitors.") 

813 

814 def setup_persistence(self): 

815 """Check if persistence is requested from competitors and store this information.""" 

816 self.persi_model = any(x in map(str.lower, self.competitors) for x in ["persi", "persistence"]) 

817 if self.persi_model is False: 

818 logging.info(f"Persistence is not calculated as it is not present in competitors.") 

819 

820 @TimeTrackingWrapper 

821 def make_prediction(self, subset): 

822 """ 

823 Create predictions for NN, OLS, and persistence and add true observation as reference. 

824 

825 Predictions are filled in an array with full index range. Therefore, predictions can have missing values. All 

826 predictions for a single station are stored locally under `<forecast/forecast_norm>_<station>_test.nc` and can 

827 be found inside `forecast_path`. 

828 """ 

829 subset_type = subset.name 

830 logging.info(f"start make_prediction for {subset_type}") 

831 time_dimension = self.data_store.get("time_dim") 

832 window_dim = self.data_store.get("window_dim") 

833 

834 for i, data in enumerate(subset): 

835 input_data = data.get_X() 

836 target_data = data.get_Y(as_numpy=False) 

837 observation_data = data.get_observation() 

838 

839 # get scaling parameters 

840 transformation_func = data.apply_transformation 

841 

842 nn_output = self.model.predict(input_data) 

843 

844 for normalised in [True, False]: 

845 # create empty arrays 

846 nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays( 

847 target_data, count=4) 

848 

849 # nn forecast 

850 nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised) 

851 

852 # persistence 

853 if self.persi_model is True: 

854 persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, 

855 transformation_func, normalised) 

856 else: 

857 persistence_prediction = None 

858 

859 # ols 

860 if self.ols_model is not None: 

861 ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_func, 

862 normalised) 

863 else: 

864 ols_prediction = None 

865 

866 # observation 

867 observation = self._create_observation(target_data, observation, transformation_func, normalised) 

868 

869 # merge all predictions 

870 full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency()) 

871 prediction_dict = {self.forecast_indicator: nn_prediction, 

872 "persi": persistence_prediction, 

873 self.observation_indicator: observation, 

874 "ols": ols_prediction} 

875 all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]), 

876 time_dimension, ahead_dim=self.ahead_dim, 

877 index_dim=self.index_dim, type_dim=self.model_type_dim, 

878 **prediction_dict) 

879 

880 # save all forecasts locally 

881 prefix = "forecasts_norm" if normalised is True else "forecasts" 

882 file = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_{subset_type}.nc") 

883 all_predictions.to_netcdf(file) 

884 

885 def _get_frequency(self) -> str: 

886 """Get frequency abbreviation.""" 

887 getter = {"daily": "1D", "hourly": "1H"} 

888 return getter.get(self._sampling, None) 

889 

890 def _create_competitor_forecast(self, station_name: str, competitor_name: str) -> xr.DataArray: 

891 """ 

892 Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station 

893 indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will 

894 raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either 

895 there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file. 

896 Forecast is trimmed on interval start and end of test subset. 

897 

898 :param station_name: name of the station to load data for 

899 :param competitor_name: name of the model 

900 :return: the forecast of the given competitor 

901 """ 

902 path = os.path.join(self.competitor_path, competitor_name) 

903 file = os.path.join(path, f"forecasts_{station_name}_test.nc") 

904 with xr.open_dataarray(file) as da: 

905 data = da.load() 

906 if self.forecast_indicator in data.coords[self.model_type_dim]: 

907 forecast = data.sel({self.model_type_dim: [self.forecast_indicator]}) 

908 forecast.coords[self.model_type_dim] = [competitor_name] 

909 else: 

910 forecast = data.sel({self.model_type_dim: [competitor_name]}) 

911 # limit forecast to time range of test subset 

912 start, end = self.data_store.get("start", "test"), self.data_store.get("end", "test") 

913 return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end) 

914 

915 def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray: 

916 """ 

917 Create observation as ground truth from given data. 

918 

919 Inverse transformation is applied to the ground truth to get the output in the original space. 

920 

921 :param data: observation 

922 :param transformation_func: a callable function to apply inverse transformation 

923 :param normalised: transform ground truth in original space if false, or use normalised predictions if true 

924 

925 :return: filled data array with observation 

926 """ 

927 if not normalised: 

928 data = transformation_func(data, "target", inverse=True) 

929 return data 

930 

931 def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, 

932 transformation_func: Callable, normalised: bool) -> xr.DataArray: 

933 """ 

934 Create ordinary least square model forecast with given input data. 

935 

936 Inverse transformation is applied to the forecast to get the output in the original space. 

937 

938 :param input_data: transposed history from DataPrep 

939 :param ols_prediction: empty array in right shape to fill with data 

940 :param transformation_func: a callable function to apply inverse transformation 

941 :param normalised: transform prediction in original space if false, or use normalised predictions if true 

942 

943 :return: filled data array with ols predictions 

944 """ 

945 tmp_ols = self.ols_model.predict(input_data) 

946 target_shape = ols_prediction.values.shape 

947 if target_shape != tmp_ols.shape: 

948 if len(target_shape) == 2: 

949 new_values = np.swapaxes(tmp_ols, 1, 0) 

950 else: 

951 new_values = np.swapaxes(tmp_ols, 2, 0) 

952 else: 

953 new_values = tmp_ols 

954 ols_prediction.values = new_values 

955 if not normalised: 

956 ols_prediction = transformation_func(ols_prediction, "target", inverse=True) 

957 return ols_prediction 

958 

959 def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, transformation_func: Callable, 

960 normalised: bool) -> xr.DataArray: 

961 """ 

962 Create persistence forecast with given data. 

963 

964 Persistence is deviated from the value at t=0 and applied to all following time steps (t+1, ..., t+window). 

965 Inverse transformation is applied to the forecast to get the output in the original space. 

966 

967 :param data: observation 

968 :param persistence_prediction: empty array in right shape to fill with data 

969 :param transformation_func: a callable function to apply inverse transformation 

970 :param normalised: transform prediction in original space if false, or use normalised predictions if true 

971 

972 :return: filled data array with persistence predictions 

973 """ 

974 tmp_persi = data.copy() 

975 persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T 

976 if not normalised: 

977 persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True) 

978 return persistence_prediction 

979 

980 def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable, 

981 normalised: bool) -> xr.DataArray: 

982 """ 

983 Create NN forecast for given input data. 

984 

985 Inverse transformation is applied to the forecast to get the output in the original space. Furthermore, only the 

986 output of the main branch is returned (not all minor branches, if the network has multiple output branches). The 

987 main branch is defined to be the last entry of all outputs. 

988 

989 :param nn_output: Full NN model output 

990 :param nn_prediction: empty array in right shape to fill with data 

991 :param transformation_func: a callable function to apply inverse transformation 

992 :param normalised: transform prediction in original space if false, or use normalised predictions if true 

993 

994 :return: filled data array with nn predictions 

995 """ 

996 

997 if isinstance(nn_output, list): 

998 nn_prediction.values = nn_output[-1] 

999 elif nn_output.ndim == 3: 

1000 nn_prediction.values = nn_output[-1, ...] 

1001 elif nn_output.ndim == 2: 

1002 nn_prediction.values = nn_output 

1003 else: 

1004 raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {nn_output.dims}.") 

1005 if not normalised: 

1006 nn_prediction = transformation_func(nn_prediction, base="target", inverse=True) 

1007 return nn_prediction 

1008 

1009 @staticmethod 

1010 def _create_empty_prediction_arrays(target_data, count=1): 

1011 """ 

1012 Create array to collect all predictions. Expand target data by a station dimension. """ 

1013 return [target_data.copy() for _ in range(count)] 

1014 

1015 @staticmethod 

1016 def create_fullindex(df: Union[xr.DataArray, pd.DataFrame, pd.DatetimeIndex], freq: str) -> pd.DataFrame: 

1017 """ 

1018 Create full index from first and last date inside df and resample with given frequency. 

1019 

1020 :param df: use time range of this data set 

1021 :param freq: frequency of full index 

1022 

1023 :return: empty data frame with full index. 

1024 """ 

1025 if isinstance(df, pd.DataFrame): 

1026 earliest = df.index[0] 

1027 latest = df.index[-1] 

1028 elif isinstance(df, xr.DataArray): 

1029 earliest = df.index[0].values 

1030 latest = df.index[-1].values 

1031 elif isinstance(df, pd.DatetimeIndex): 

1032 earliest = df[0] 

1033 latest = df[-1] 

1034 else: 

1035 raise AttributeError(f"unknown array type. Only pandas dataframes, xarray dataarrays and pandas datetimes " 

1036 f"are supported. Given type is {type(df)}.") 

1037 index = pd.DataFrame(index=pd.date_range(earliest, latest, freq=freq)) 

1038 return index 

1039 

1040 @staticmethod 

1041 def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, 

1042 ahead_dim="ahead", index_dim="index", type_dim="type", **kwargs): 

1043 """ 

1044 Combine different forecast types into single xarray. 

1045 

1046 :param index: index for forecasts (e.g. time) 

1047 :param ahead_names: names of ahead values (e.g. hours or days) 

1048 :param kwargs: as xarrays; data of forecasts 

1049 

1050 :return: xarray of dimension 3: index, ahead_names, # predictions 

1051 

1052 """ 

1053 kwargs = {k: v for k, v in kwargs.items() if v is not None} 

1054 keys = list(kwargs.keys()) 

1055 res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan), 

1056 coords=[index.index, ahead_names, keys], dims=[index_dim, ahead_dim, type_dim]) 

1057 for k, v in kwargs.items(): 

1058 intersection = set(res.index.values) & set(v.indexes[time_dimension].values) 

1059 match_index = np.array(list(intersection)) 

1060 res.loc[match_index, :, k] = v.loc[match_index] 

1061 return res 

1062 

1063 def _get_internal_data(self, station: str, path: str) -> Union[xr.DataArray, None]: 

1064 """ 

1065 Get internal data for given station. 

1066 

1067 Internal data is defined as data that is already known to the model. From an evaluation perspective, this 

1068 refers to data, that is no test data, and therefore to train and val data. 

1069 

1070 :param station: name of station to load internal data. 

1071 """ 

1072 try: 

1073 file = os.path.join(path, f"forecasts_{str(station)}_train_val.nc") 

1074 with xr.open_dataarray(file) as da: 

1075 return da.load() 

1076 except (IndexError, KeyError, FileNotFoundError): 

1077 return None 

1078 

1079 def _get_external_data(self, station: str, path: str) -> Union[xr.DataArray, None]: 

1080 """ 

1081 Get external data for given station. 

1082 

1083 External data is defined as data that is not known to the model. From an evaluation perspective, this refers to 

1084 data, that is not train or val data, and therefore to test data. 

1085 

1086 :param station: name of station to load external data. 

1087 """ 

1088 try: 

1089 file = os.path.join(path, f"forecasts_{str(station)}_test.nc") 

1090 with xr.open_dataarray(file) as da: 

1091 return da.load() 

1092 except (IndexError, KeyError, FileNotFoundError): 

1093 return None 

1094 

1095 def _combine_forecasts(self, forecast, competitor, dim=None): 

1096 """ 

1097 Combine forecast and competitor if both are xarray. If competitor is None, this returns forecasts and vise 

1098 versa. 

1099 """ 

1100 if dim is None: 

1101 dim = self.model_type_dim 

1102 try: 

1103 return xr.concat([forecast, competitor], dim=dim) 

1104 except (TypeError, AttributeError): 

1105 return forecast if competitor is None else competitor 

1106 

1107 def calculate_bias_free_error_metrics(self): 

1108 all_stations = self.data_store.get("stations") 

1109 errors = [{}, {}] # errors_total_bias_free, errors_seasonal_bias_free 

1110 for station in all_stations: 

1111 external_data = self._get_external_data(station, self.forecast_path) # test data 

1112 

1113 # test errors 

1114 if external_data is not None: 

1115 external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n) 

1116 for n in external_data.coords[self.model_type_dim].values] 

1117 model_type_list = external_data.coords[self.model_type_dim].values.tolist() 

1118 

1119 # load competitors 

1120 competitor = self.load_competitors(station) 

1121 combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim) 

1122 if combined is not None: 

1123 model_list = remove_items(combined.coords[self.model_type_dim].values.tolist(), 

1124 self.observation_indicator) 

1125 else: 

1126 model_list = None 

1127 continue 

1128 

1129 # data_total_bias_free, data_seasonal_bias_free 

1130 bias_free_data = statistics.calculate_bias_free_data(combined, time_dim=self.index_dim, window_size=30) 

1131 

1132 # test errors of competitors 

1133 for model_type in (model_list or []): 

1134 for data, e in zip(bias_free_data, errors): 

1135 if self.observation_indicator not in data.coords[self.model_type_dim]: 

1136 continue 

1137 if model_type not in e.keys(): 

1138 e[model_type] = {} 

1139 e[model_type][station] = statistics.calculate_error_metrics( 

1140 *map(lambda x: data.sel(**{self.model_type_dim: x}), 

1141 [model_type, self.observation_indicator]), dim=self.index_dim) 

1142 for e in errors: 

1143 for model_type in e.keys(): 

1144 e[model_type].update({"total": self.calculate_average_errors(e[model_type])}) 

1145 return errors 

1146 

1147 def calculate_error_metrics(self) -> Tuple[Dict, Dict, Dict, Dict]: 

1148 """ 

1149 Calculate error metrics and skill scores of NN forecast. 

1150 

1151 The competitive skill score compares the NN prediction with persistence and ordinary least squares forecasts. 

1152 Whereas, the climatological skill scores evaluates the NN prediction in terms of meaningfulness in comparison 

1153 to different climatological references. 

1154 

1155 :return: competitive and climatological skill scores, error metrics 

1156 """ 

1157 all_stations = self.data_store.get("stations") 

1158 skill_score_competitive = {} 

1159 skill_score_competitive_count = {} 

1160 skill_score_climatological = {} 

1161 errors = {} 

1162 for station in all_stations: 

1163 external_data = self._get_external_data(station, self.forecast_path) # test data 

1164 

1165 # test errors 

1166 if external_data is not None: 

1167 external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n) 

1168 for n in external_data.coords[self.model_type_dim].values] 

1169 model_type_list = external_data.coords[self.model_type_dim].values.tolist() 

1170 for model_type in remove_items(model_type_list, self.observation_indicator): 

1171 if model_type not in errors.keys(): 

1172 errors[model_type] = {} 

1173 errors[model_type][station] = statistics.calculate_error_metrics( 

1174 *map(lambda x: external_data.sel(**{self.model_type_dim: x}), 

1175 [model_type, self.observation_indicator]), dim=self.index_dim) 

1176 

1177 # load competitors 

1178 competitor = self.load_competitors(station) 

1179 combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim) 

1180 if combined is not None: 

1181 model_list = remove_items(combined.coords[self.model_type_dim].values.tolist(), 

1182 self.observation_indicator) 

1183 else: 

1184 model_list = None 

1185 

1186 # test errors of competitors 

1187 for model_type in (model_list or []): 

1188 if self.observation_indicator not in combined.coords[self.model_type_dim]: 

1189 continue 

1190 if model_type not in errors.keys(): 

1191 errors[model_type] = {} 

1192 errors[model_type][station] = statistics.calculate_error_metrics( 

1193 *map(lambda x: combined.sel(**{self.model_type_dim: x}), 

1194 [model_type, self.observation_indicator]), dim=self.index_dim) 

1195 

1196 # skill score 

1197 skill_score = statistics.SkillScores(combined, models=model_list, ahead_dim=self.ahead_dim, 

1198 type_dim=self.model_type_dim, index_dim=self.index_dim) 

1199 if external_data is not None: 

1200 skill_score_competitive[station], skill_score_competitive_count[station] = skill_score.skill_scores() 

1201 

1202 internal_data = self._get_internal_data(station, self.forecast_path) 

1203 if internal_data is not None: 

1204 skill_score_climatological[station] = skill_score.climatological_skill_scores( 

1205 internal_data, forecast_name=self.forecast_indicator) 

1206 

1207 for model_type in errors.keys(): 

1208 errors[model_type].update({"total": self.calculate_average_errors(errors[model_type])}) 

1209 skill_score_competitive.update({"total": self.calculate_average_skill_scores(skill_score_competitive, 

1210 skill_score_competitive_count)}) 

1211 return skill_score_competitive, skill_score_competitive_count, skill_score_climatological, errors 

1212 

1213 @staticmethod 

1214 def calculate_average_skill_scores(scores, counts): 

1215 avg_skill_score = 0 

1216 n_total = None 

1217 for vals in counts.values(): 

1218 n_total = vals if n_total is None else n_total.add(vals, fill_value=0) 

1219 for station, station_scores in scores.items(): 

1220 n = counts.get(station) 

1221 avg_skill_score = station_scores.mul(n.div(n_total, fill_value=0), fill_value=0).add(avg_skill_score, 

1222 fill_value=0) 

1223 return avg_skill_score 

1224 

1225 @staticmethod 

1226 def calculate_average_errors(errors): 

1227 avg_error = {} 

1228 n_total = sum([x.get("n", 0) for _, x in errors.items()]) 

1229 for station, station_errors in errors.items(): 

1230 n_station = station_errors.get("n") 

1231 for error_metric, val in station_errors.items(): 

1232 new_val = avg_error.get(error_metric, 0) + val * n_station / n_total 

1233 avg_error[error_metric] = new_val 

1234 return avg_error 

1235 

1236 def report_feature_importance_results(self, results): 

1237 """Create a csv file containing all results from feature importance.""" 

1238 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

1239 path_config.check_path_and_create(report_path) 

1240 res = [] 

1241 max_cols = 0 

1242 for boot_type, d0 in results.items(): 

1243 for boot_method, d1 in d0.items(): 

1244 for station_name, vals in d1.items(): 

1245 for boot_var in vals.coords[self.boot_var_dim].values.tolist(): 

1246 for ahead in vals.coords[self.ahead_dim].values.tolist(): 

1247 res.append([boot_type, boot_method, station_name, boot_var, ahead, 

1248 *vals.sel({self.boot_var_dim: boot_var, 

1249 self.ahead_dim: ahead}).values.round(5).tolist()]) 

1250 max_cols = max(max_cols, len(res[-1])) 

1251 col_names = [self.model_type_dim, "method", "station", self.boot_var_dim, self.ahead_dim, 

1252 *list(range(max_cols - 5))] 

1253 df = pd.DataFrame(res, columns=col_names) 

1254 file_name = "feature_importance_skill_score_report_raw.csv" 

1255 df.to_csv(os.path.join(report_path, file_name), sep=";") 

1256 

1257 def report_error_metrics(self, errors, tag=None): 

1258 report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

1259 path_config.check_path_and_create(report_path) 

1260 base_file_name = "error_report" if tag is None else f"error_report_{tag}" 

1261 for model_type in errors.keys(): 

1262 metric_collection = {} 

1263 for station, station_errors in errors[model_type].items(): 

1264 if isinstance(station_errors, xr.DataArray): 

1265 dim = station_errors.dims[0] 

1266 sel_index = [sel for sel in station_errors.coords[dim] if "CASE" in str(sel)] 

1267 station_errors = {str(i.values): station_errors.sel(**{dim: i}) for i in sel_index} 

1268 elif isinstance(station_errors, pd.DataFrame): 

1269 sel_index = station_errors.index.tolist() 

1270 ahead = station_errors.columns.values 

1271 station_errors = {k: xr.DataArray(station_errors[station_errors.index == k].values.flatten(), 

1272 dims=["ahead"], coords={"ahead": ahead}).astype(float) 

1273 for k in sel_index} 

1274 for metric, vals in station_errors.items(): 

1275 if metric == "n": 

1276 metric = "count" 

1277 pd_vals = pd.DataFrame.from_dict({station: vals}).T 

1278 pd_vals.columns = [f"{metric}(t+{x})" for x in vals.coords["ahead"].values] 

1279 mc = metric_collection.get(metric, pd.DataFrame()) 

1280 mc = mc.append(pd_vals) 

1281 metric_collection[metric] = mc 

1282 for metric, error_df in metric_collection.items(): 

1283 df = error_df.sort_index() 

1284 if "total" in df.index: 

1285 df.reindex(df.index.drop(["total"]).to_list() + ["total"], ) 

1286 column_format = tables.create_column_format_for_tex(df) 

1287 if model_type == "skill_score": 

1288 file_name = f"{base_file_name}_{model_type}_{metric}.%s".replace(' ', '_').replace('/', '_') 

1289 else: 

1290 file_name = f"{base_file_name}_{metric}_{model_type}.%s".replace(' ', '_').replace('/', '_') 

1291 tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df) 

1292 tables.save_to_md(report_path, file_name % "md", df=df) 

1293 

1294 def store_errors(self, errors): 

1295 metric_collection = {} 

1296 error_dim = "error_metric" 

1297 station_dim = "station" 

1298 for model_type in errors.keys(): 

1299 station_collection = {} 

1300 for station, station_errors in errors[model_type].items(): 

1301 if station == "total": 

1302 continue 

1303 station_collection[station] = xr.Dataset(station_errors).to_array(error_dim) 

1304 metric_collection[model_type] = xr.Dataset(station_collection).to_array(station_dim) 

1305 coll = xr.Dataset(metric_collection).to_array(self.model_type_dim) 

1306 coll = coll.transpose(station_dim, self.ahead_dim, self.model_type_dim, error_dim) 

1307 return {k: coll.sel({error_dim: k}, drop=True) for k in coll.coords[error_dim].values}