Coverage for mlair/plotting/postprocessing_plotting.py: 13%

357 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

1"""Collection of plots to evaluate a model, create overviews on data or forecasts.""" 

2__author__ = "Lukas Leufen, Felix Kleinert" 

3__date__ = '2020-11-23' 

4 

5import logging 

6import math 

7import os 

8import warnings 

9from typing import Dict, List, Tuple, Union 

10import itertools 

11 

12import matplotlib 

13import matplotlib.pyplot as plt 

14import matplotlib.colors as colors 

15import numpy as np 

16import pandas as pd 

17import seaborn as sns 

18import xarray as xr 

19from matplotlib.backends.backend_pdf import PdfPages 

20from matplotlib.offsetbox import AnchoredText 

21import matplotlib.dates as mdates 

22from scipy.stats import mannwhitneyu 

23import datetime as dt 

24 

25from mlair import helpers 

26from mlair.data_handler.iterator import DataCollection 

27from mlair.helpers import TimeTrackingWrapper 

28from mlair.helpers.helpers import relative_round 

29from mlair.plotting.abstract_plot_class import AbstractPlotClass 

30from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks 

31 

32 

# Raise the threshold of matplotlib's logger so that only warnings and more severe records pass it.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

34 

35 

36# import matplotlib 

37# matplotlib.use("TkAgg") 

38# import matplotlib.pyplot as plt 

39 

40 

@TimeTrackingWrapper
class PlotMonthlySummary(AbstractPlotClass):  # pragma: no cover
    """
    Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot.

    The plot is registered under the name monthly_summary_box_plot in plot_folder; writing the file is delegated to
    the `_save` method inherited from the base class.

    .. image:: ../../../../../_source/_plots/monthly_summary_box_plot.png
        :width: 400

    :param stations: all stations to plot
    :param data_path: path, where the data is located
    :param name: file name pattern of the local files with a %-style placeholder (e.g. ``%s``) for the station name
    :param target_var: short name of the target variable (e.g. "o3"); spelled out for the y-axis label if known
    :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not given
        the maximum lead time from data is used. (default None -> use maximum lead time from data).
    :param plot_folder: path to save the plot (default: current directory)
    :param target_var_unit: unit of target var for the axis label (default: ppb)
    :param model_name: value on the ``type`` coordinate that selects the forecast to plot (default: "nn")

    """

    def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
                 plot_folder: str = ".", target_var_unit: str = 'ppb', model_name="nn"):
        """Set attributes and create plot."""
        super().__init__(plot_folder, "monthly_summary_box_plot")
        self._data_path = data_path
        self._data_name = name
        self._model_name = model_name
        self._data = self._prepare_data(stations)
        self._window_lead_time = self._get_window_lead_time(window_lead_time)
        self._plot(target_var, target_var_unit)
        self._save()

    def _prepare_data(self, stations: List) -> xr.DataArray:
        """
        Pre-process data required to plot.

        For each station, load locally saved predictions, extract the model prediction and the observation and group
        them into monthly bins (no aggregation, only sorting them). Negative values are clipped to 0.

        :param stations: all stations to plot
        :return: The entire data set, flagged with the corresponding month (``None`` if `stations` is empty).
        """
        forecasts = None
        for station in stations:
            logging.debug(f"... preprocess station {station}")
            file_name = os.path.join(self._data_path, self._data_name % station)
            # load into memory and release the file handle right away instead of keeping one open per station
            with xr.open_dataarray(file_name) as da:
                data = da.load()

            data_nn = data.sel(type=self._model_name).squeeze()
            if len(data_nn.shape) > 1:
                # several lead times: label each of them as "<n>d"
                data_nn = data_nn.assign_coords(ahead=[f"{days}d" for days in data_nn.coords["ahead"].values])
            else:
                # a single lead time was squeezed to a scalar coordinate
                data_nn.coords["ahead"].values = str(data_nn.coords["ahead"].values) + "d"

            # observations are taken from the first lead time and get their own "obs" entry on the ahead axis
            data_obs = data.sel(type="obs", ahead=1).squeeze()
            data_obs.coords["ahead"] = "obs"

            data_concat = xr.concat([data_obs, data_nn], dim="ahead")
            data_concat = data_concat.drop_vars("type")

            # replace the time stamps by their month number (1..12)
            new_index = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
            data_concat = data_concat.assign_coords(index=new_index)
            data_concat = data_concat.clip(min=0)

            forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat
        return forecasts

    def _get_window_lead_time(self, window_lead_time: int):
        """
        Extract the lead time from data and arguments.

        If window_lead_time is not given, extract this information from data itself by the number of entries on the
        ahead dimension. If given, check if data supports the given length. If the number of ahead entries in data is
        lower than the given lead time, data's lead time is used.

        :param window_lead_time: lead time from arguments to validate
        :return: validated lead time, comes either from given argument or from data itself
        """
        ahead_steps = len(self._data.ahead)
        if window_lead_time is None:
            window_lead_time = ahead_steps
        return min(ahead_steps, window_lead_time)

    @staticmethod
    def _spell_out_chemical_concentrations(short_name: str):
        """
        Build the axis label text for a target variable.

        Known short names are replaced by their spelled out name. Any other name is used unchanged so that an
        unlisted target variable does not abort the plot with a KeyError.

        :param short_name: short name of the target variable, e.g. "o3"
        :return: label of the form "<name> concentration"
        """
        short2long = {'o3': 'ozone', 'no': 'nitrogen oxide', 'no2': 'nitrogen dioxide', 'nox': 'nitrogen oxides'}
        return f"{short2long.get(short_name, short_name)} concentration"

    def _plot(self, target_var: str, target_var_unit: str):
        """
        Create a monthly grouped box plot over all stations but with separate boxes for each lead time step.

        :param target_var: short name of the target variable, used to build the y-axis label
        :param target_var_unit: unit that is appended to the y-axis label
        """
        data = self._data.to_dataset(name='values').to_dask_dataframe()
        logging.debug("... start plotting")
        # first color (green) is used for the observations, the blue shades for the lead times
        color_palette = [colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
        ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1.5, palette=color_palette,
                         flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
                         meanprops={'markersize': 1, 'markeredgecolor': 'k'})
        ylabel = self._spell_out_chemical_concentrations(target_var)
        ax.set(xlabel='month', ylabel=f'{ylabel} (in {target_var_unit})')
        plt.tight_layout()

145 

146 

@TimeTrackingWrapper
class PlotConditionalQuantiles(AbstractPlotClass):  # pragma: no cover
    """
    Create cond.quantile plots as originally proposed by Murphy, Brown and Chen (1989) [But in log scale].

    Link to paper: https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2

    .. image:: ../../../../../_source/_plots/conditional_quantiles_cali-ref_plot.png
        :width: 400

    .. image:: ../../../../../_source/_plots/conditional_quantiles_like-bas_plot.png
        :width: 400

    For each time step ahead a separate plot is created. If parameter plot_per_seasons is true, data is split by season
    and conditional quantiles are plotted for each season in addition.

    :param stations: all stations to plot
    :param data_pred_path: path to dir which contains the forecasts as .nc files
    :param plot_folder: path where the plots are stored
    :param plot_per_seasons: if `True' create cond. quantile plots for _seasons (DJF, MAM, JJA, SON) individually
    :param rolling_window: smoothing of quantiles (3 is used by Murphy et al.)
    :param forecast_indicator: name of the model prediction as stored in netCDF file (for example "nn")
    :param obs_indicator: name of observation as stored in netCDF file (for example "obs")
    :param competitors: names of competing models to plot in addition (default None -> no competitors)
    :param model_type_dim: name of the dimension that holds the model / observation names (default "type")
    :param index_dim: name of the time dimension (default "index")
    :param ahead_dim: name of the lead time dimension (default "ahead")
    :param competitor_path: dir that contains one sub dir per competitor (default None -> use data_pred_path)
    :param sampling: temporal resolution of the data, either "daily" or "hourly"
    :param model_name: name shown in the plot title for the forecast given by forecast_indicator
    :param kwargs: Some further arguments which are listed in self._opts
    """

    # ignore warnings if nans appear in quantile grouping
    warnings.filterwarnings("ignore", message="All-NaN slice encountered")
    # ignore warnings if mean is calculated on nans
    warnings.filterwarnings("ignore", message="Mean of empty slice")
    # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar)
    warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")

    def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons=True,
                 rolling_window: int = 3, forecast_indicator: str = "nn", obs_indicator: str = "obs",
                 competitors=None, model_type_dim: str = "type", index_dim: str = "index", ahead_dim: str = "ahead",
                 competitor_path: str = None, sampling: str = "daily", model_name: str = "nn", **kwargs):
        """Initialise, load all data and create the plots."""
        super().__init__(plot_folder, "conditional_quantiles")
        self._data_pred_path = data_pred_path
        self._stations = stations
        self._rolling_window = rolling_window
        self._forecast_indicator = forecast_indicator
        self.model_type_dim = model_type_dim
        self.index_dim = index_dim
        self.ahead_dim = ahead_dim
        self.iter_dim = "station"
        self.model_name = model_name
        self._obs_name = obs_indicator
        self._sampling = sampling
        self._opts = self._get_opts(kwargs)
        self._seasons = ['DJF', 'MAM', 'JJA', 'SON'] if plot_per_seasons is True else ""
        self.competitors = self._correct_persi_name(competitors or [])
        self.competitor_path = competitor_path or data_pred_path
        self._data = self._load_data()
        self._bins = self._get_bins_from_rage_of_data()
        self._plot()

    @staticmethod
    def _get_opts(kwargs):
        """Extract the plot options q, linetype, legend and data_unit from kwargs (with defaults)."""
        return {"q": kwargs.get("q", [.1, .25, .5, .75, .9]),
                "linetype": kwargs.get("linetype", [':', '-.', '--', '-.', ':']),
                "legend": kwargs.get("legend", ['.10th and .90th quantile', '.25th and .75th quantile',
                                                '.50th quantile', 'reference 1:1']),
                "data_unit": kwargs.get("data_unit", "ppb"), }

    def _load_data(self) -> xr.DataArray:
        """
        Load plot data.

        For each station the forecast file and all available competitors are loaded, restricted to forecast,
        observation and competitors, and finally stacked along the station dimension.

        :return: plot data with dimensions ordered as (index, model type, ahead, station)
        """
        logging.debug("... load data")
        data_collector = []
        for station in self._stations:
            file = os.path.join(self._data_pred_path, f"forecasts_{station}_test.nc")
            # load into memory and release the file handle right away instead of keeping one open per station
            with xr.open_dataarray(file) as da:
                data_tmp = da.load()
            start = data_tmp.coords[self.index_dim].min().values
            end = data_tmp.coords[self.index_dim].max().values
            competitor = self.load_competitors(station, start, end)
            combined = self._combine_forecasts(data_tmp, competitor, dim=self.model_type_dim)
            sel = combined.sel({self.model_type_dim: [self._forecast_indicator, self._obs_name, *self.competitors]})
            data_collector.append(sel.assign_coords({self.iter_dim: station}))
        res = xr.concat(data_collector, dim=self.iter_dim).transpose(self.index_dim, self.model_type_dim,
                                                                     self.ahead_dim, self.iter_dim)
        return res

    def _combine_forecasts(self, forecast, competitor, dim=None):
        """
        Combine forecast and competitor if both are xarray.

        If concatenation fails with a TypeError or AttributeError (e.g. because one of both is None), forecast is
        returned if competitor is None, otherwise competitor.

        :param forecast: forecast data
        :param competitor: competitor data or None
        :param dim: dimension to concatenate along (default None -> use model_type_dim)
        """
        if dim is None:
            dim = self.model_type_dim
        try:
            return xr.concat([forecast, competitor], dim=dim)
        except (TypeError, AttributeError):
            return forecast if competitor is None else competitor

    def load_competitors(self, station_name: str, start, end) -> xr.DataArray:
        """
        Load all requested and available competitors for a given station.

        Forecasts must be available in the competitor path like
        `<competitor_path>/<competitor_name>/forecasts_<station_name>_test.nc`. The naming style is equal for all
        forecasts of MLAir, so that forecasts of a different experiment can easily be copied into the competitor path
        without any change. Competitors that cannot be found are skipped (logged on debug level).

        :param station_name: station indicator to load competitors for
        :param start: first time stamp of the interval the competitors are restricted to
        :param end: last time stamp of the interval the competitors are restricted to

        :return: a single xarray with all competing forecasts or None if no competitor is available
        """
        competing_predictions = []
        for competitor_name in self.competitors:
            try:
                prediction = self._create_competitor_forecast(station_name, competitor_name, start, end)
                competing_predictions.append(prediction)
            except (FileNotFoundError, KeyError):
                logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.")
                continue
        return xr.concat(competing_predictions, self.model_type_dim) if len(competing_predictions) > 0 else None

    @staticmethod
    def create_full_time_dim(data, dim, sampling, start, end):
        """
        Ensure time dimension to be equidistant.

        Time stamps can be missing in data, e.g. if missing values have been dropped. The returned array has a
        complete time axis and holds the values of data where available (all other entries stay unset / NaN).

        :param data: data to place on the full time axis
        :param dim: name of the time dimension
        :param sampling: "daily" or "hourly", determines the step width of the time axis
        :param start: first time stamp of the time axis
        :param end: last time stamp, the full day of end is included
        :return: data on an equidistant time axis
        """
        start_data = data.coords[dim].values[0]
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        _ind = pd.date_range(start, end, freq=freq)  # two steps required to include all hours of end interval
        first, last = _ind.min(), _ind.max() + dt.timedelta(days=1)
        try:
            full_index = pd.date_range(first, last, inclusive="left", freq=freq)
        except TypeError:
            # older pandas releases only know the former keyword `closed`
            full_index = pd.date_range(first, last, closed="left", freq=freq)
        # use a single time slice as template for all remaining coordinates
        t = data.sel({dim: start_data}, drop=True)
        res = xr.DataArray(coords=[full_index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        if data.shape == res.shape:
            res.loc[data.coords] = data
        else:
            _d = data.sel({dim: slice(start, end)})
            res.loc[_d.coords] = _d
        return res

    def _create_competitor_forecast(self, station_name: str, competitor_name: str, start, end) -> xr.DataArray:
        """
        Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station
        indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will
        raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either
        there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file.
        Forecast is trimmed on interval start and end of test subset.

        :param station_name: name of the station to load data for
        :param competitor_name: name of the model
        :param start: first time stamp of the test subset
        :param end: last time stamp of the test subset
        :return: the forecast of the given competitor
        """
        path = os.path.join(self.competitor_path, competitor_name)
        file = os.path.join(path, f"forecasts_{station_name}_test.nc")
        with xr.open_dataarray(file) as da:
            data = da.load()
        if self._forecast_indicator in data.coords[self.model_type_dim]:
            # file stems from another experiment: its own forecast becomes the competitor
            forecast = data.sel({self.model_type_dim: [self._forecast_indicator]})
            forecast.coords[self.model_type_dim] = [competitor_name]
        else:
            forecast = data.sel({self.model_type_dim: [competitor_name]})
        # limit forecast to time range of test subset
        return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end)

    @staticmethod
    def _correct_persi_name(competitors):
        """Replace the name "Persistence" by "persi" and leave all other competitor names untouched."""
        return ["persi" if x == "Persistence" else x for x in competitors]

    def _segment_data(self, data: xr.DataArray, x_model: str) -> xr.DataArray:
        """
        Segment data into bins.

        :param data: data to segment
        :param x_model: name of x dimension

        :return: segmented data
        """
        logging.debug("... segment data")
        # combine index and station to multi index
        data = data.stack(z=[self.index_dim, self.iter_dim])
        # replace multi index by simple position index (order is not relevant anymore)
        data.coords['z'] = range(len(data.coords['z']))
        # segment data of x_model into bins
        data_sel = data.sel({self.model_type_dim: x_model})
        data.loc[{self.model_type_dim: x_model}] = data_sel.to_pandas().T.apply(pd.cut, bins=self._bins,
                                                                                labels=self._bins[1:]).T.values
        return data

    @staticmethod
    def _labels(plot_type: str, data_unit: str = "ppb") -> Tuple[str, str]:
        """
        Assign (x,y) labels to plots correctly, depending on like-base or cali-ref factorization.

        :param plot_type: type of plot, either `obs` or a model name
        :param data_unit: unit of data to add to labels (default ppb)

        :return: tuple with y and x labels
        """
        names = (f"forecasted concentration (in {data_unit})", f"observed concentration (in {data_unit})")
        if plot_type == "obs":
            return names
        else:
            return names[::-1]

    def _get_bins_from_rage_of_data(self) -> np.ndarray:
        """
        Get array of bins to use for quantiles.

        :return: integer range from 0 to data's maximum (rounded up), both included
        """
        return np.arange(0, math.ceil(self._data.max().max()) + 1, 1).astype(int)

    def _create_quantile_panel(self, data: xr.DataArray, x_model: str, y_model: str) -> xr.DataArray:
        """
        Calculate quantiles.

        :param data: data to calculate quantiles
        :param x_model: name of x dimension
        :param y_model: name of y dimension

        :return: quantile panel with binned data
        """
        logging.debug("... create quantile panel")
        # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step)
        quantile_panel = xr.DataArray(
            np.full([data.ahead.shape[0], len(self._opts["q"]), self._bins[1:].shape[0]], np.nan),
            coords=[data.ahead, self._opts["q"], self._bins[1:]], dims=[self.ahead_dim, 'quantiles', 'categories'])
        # ensure that the coordinates are in the right order
        quantile_panel = quantile_panel.transpose(self.ahead_dim, 'quantiles', 'categories')
        # calculate for each bin of the x_model data the quantiles of the y_model data
        for category in self._bins[1:]:
            mask = (data.loc[x_model, ...] == category)
            quantile_panel.loc[..., category] = data.loc[y_model, ...].where(mask).quantile(self._opts["q"],
                                                                                            dim=['z']).T
        return quantile_panel

    @staticmethod
    def add_affix(affix: str) -> str:
        """
        Add additional information to plot name with leading underscore or add empty string if affix is empty.

        :param affix: string to add

        :return: affix with leading underscore or empty string.
        """
        return f"_{affix}" if len(affix) > 0 else ""

    def _prepare_plots(self, data: xr.DataArray, x_model: str, y_model: str) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Get segmented data and quantile panel.

        :param data: plot data
        :param x_model: name of x dimension
        :param y_model: name of y dimension

        :return: segmented data and quantile panel
        """
        segmented_data = self._segment_data(data, x_model)
        quantile_panel = self._create_quantile_panel(segmented_data, x_model, y_model)
        return segmented_data, quantile_panel

    def _plot(self):
        """Start plotting routines: overall plot and seasonal (if enabled)."""
        logging.info(f"start plotting {self.__class__.__name__}, scheduled number of plots: "
                     f"{(len(self._seasons) + 1) * (len(self.competitors) + 1) * 2}")
        if len(self._seasons) > 0:
            self._plot_seasons()
        self._plot_all()

    def _plot_seasons(self):
        """Create seasonal plots (cali-ref and like-base) for the forecast and each competitor."""
        for model in [self._forecast_indicator, *self.competitors]:
            for season in self._seasons:
                self._plot_base(data=self._data.where(self._data[f"{self.index_dim}.season"] == season),
                                x_model=model, y_model=self._obs_name, plot_name_affix="cali-ref",
                                season=season, model_name=model)
                self._plot_base(data=self._data.where(self._data[f"{self.index_dim}.season"] == season),
                                x_model=self._obs_name, y_model=model, plot_name_affix="like-base",
                                season=season, model_name=model)

    def _plot_all(self):
        """Plot overall conditional quantiles on full data."""
        for model in [self._forecast_indicator, *self.competitors]:
            self._plot_base(data=self._data, x_model=model, y_model=self._obs_name,
                            plot_name_affix="cali-ref", model_name=model)
            self._plot_base(data=self._data, x_model=self._obs_name, y_model=model,
                            plot_name_affix="like-base", model_name=model)

    @TimeTrackingWrapper
    def _plot_base(self, data: xr.DataArray, x_model: str, y_model: str, plot_name_affix: str, season: str = "",
                   model_name: str = ""):
        """
        Create conditional quantile plots.

        All lead times are written as separate pages into a single pdf file.

        :param data: data which is used to create cond. quantile plot
        :param x_model: name of model on x axis (can also be obs)
        :param y_model: name of model on y axis (can also be obs)
        :param plot_name_affix: should be `cali-ref' or `like-base'
        :param season: name of the season the data belongs to, empty string for the full data set
        :param model_name: name of the model, used for the file name and (replaced by the display name for the own
            forecast) for the title
        """
        segmented_data, quantile_panel = self._prepare_plots(data, x_model, y_model)
        ylabel, xlabel = self._labels(x_model, self._opts["data_unit"])
        plot_name = f"{self.plot_name}{self.add_affix(season)}{self.add_affix(plot_name_affix)}" \
                    f"{self.add_affix(model_name)}.pdf"
        plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
        logging.debug(f"... plot path is {plot_path}")

        # parts of the title that do not depend on the lead time
        sampling_letter = {"daily": "D", "hourly": "H"}.get(self._sampling)
        title_name = self.model_name if model_name == self._forecast_indicator else model_name
        season_info = f', {season}' if len(season) > 0 else ''

        y2_max = 0
        try:
            # the context manager closes the pdf file even if plotting fails
            with PdfPages(plot_path) as pdf_pages:
                # create plot for each time step ahead
                for iteration, d in enumerate(segmented_data.ahead):
                    logging.debug(f"... plotting {d.values} time step(s) ahead")
                    # plot smoothed lines with rolling mean
                    smooth_data = quantile_panel.loc[d, ...].rolling(categories=self._rolling_window,
                                                                     center=True).mean().to_pandas().T
                    ax = smooth_data.plot(style=self._opts["linetype"], color='black', legend=False)
                    ax2 = ax.twinx()
                    # add reference line
                    ax.plot([0, self._bins.max()], [0, self._bins.max()], color='k', label='reference 1:1',
                            linewidth=.8)
                    # add histogram of the segmented data (x_model)
                    handles, _ = ax.get_legend_handles_labels()
                    segmented_data.loc[x_model, d, :].to_pandas().hist(bins=self._bins, ax=ax2, color='k', alpha=.3,
                                                                       grid=False, rwidth=1)
                    # add legend
                    plt.legend(handles[:3] + [handles[-1]], self._opts["legend"], loc='upper left', fontsize='large')
                    # adjust limits and set labels
                    ax.set(xlim=(0, self._bins.max()), ylim=(0, self._bins.max()))
                    ax.set_xlabel(xlabel, fontsize='x-large')
                    ax.tick_params(axis='x', which='major', labelsize=15)
                    ax.set_ylabel(ylabel, fontsize='x-large')
                    ax.tick_params(axis='y', which='major', labelsize=15)
                    ax2.yaxis.label.set_color('gray')
                    ax2.tick_params(axis='y', colors='gray')
                    ax2.yaxis.labelpad = -15
                    ax2.set_yscale('log')
                    if iteration == 0:
                        # reuse the histogram range of the first lead time on all following pages
                        y2_max = ax2.get_ylim()[1] + 100
                    ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5))
                    ax2.set_ylabel(' sample size', fontsize='x-large')
                    ax2.tick_params(axis='y', which='major', labelsize=15)
                    # set title and save current figure
                    title = f"{title_name} ({sampling_letter}{d.values}{season_info})"
                    plt.title(title)
                    pdf_pages.savefig()
        finally:
            # close all open figures / plots
            plt.close('all')

494 

495 

@TimeTrackingWrapper
class PlotClimatologicalSkillScore(AbstractPlotClass):  # pragma: no cover
    """
    Box plot of the climatological skill score after Murphy (1988), summarised over all stations.

    Each forecast time step (called "ahead") gets its own box so that differences between the prediction time steps
    become visible. With score_only=True (default) only the resulting scores CASE I to IV are shown, with
    score_only=False every single term is plotted. The y-axis follows the data (see `_lim`). The plot is registered
    under the name skill_score_clim_{extra_name_tag}{model_name} in plot_folder.

    .. image:: ../../../../../_source/_plots/skill_score_clim_all_terms_CNN.png
        :width: 400

    .. image:: ../../../../../_source/_plots/skill_score_clim_CNN.png
        :width: 400

    :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms.
    :param plot_folder: path to save the plot (default: current directory)
    :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
    :param extra_name_tag: additional tag that can be included in the plot name (default "")
    :param model_name: architecture type to specify plot name (default "")

    """

    def __init__(self, data: Dict, plot_folder: str = ".", score_only: bool = True, extra_name_tag: str = "",
                 model_name: str = ""):
        """Prepare the data, draw the plot and save it."""
        plot_name = f"skill_score_clim_{extra_name_tag}{model_name}"
        super().__init__(plot_folder, plot_name)
        self._labels = None
        self._data = self._prepare_data(data, score_only)
        self._plot(score_only)
        self._save()

    def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame:
        """
        Convert the station dictionary into a flat data frame.

        The legend labels are derived from the lead times as a side effect. If only scores are requested, all terms
        except CASE I to IV are dropped beforehand.

        :param data: dictionary with station names as keys and 2D xarrays as values
        :param score_only: if true only scores of CASE I to IV are relevant
        :return: pre-processed data set
        """
        stacked = helpers.dict_to_xarray(data, "station")
        lead_times = stacked.coords["ahead"].values
        self._labels = [str(step) + "d" for step in lead_times]
        if score_only:
            cases = ["CASE I", "CASE II", "CASE III", "CASE IV"]
            stacked = stacked.loc[:, cases, :]
        frame = stacked.to_dataframe("data")
        return frame.reset_index(level=[0, 1, 2])

    def _label_add(self, score_only: bool):
        """
        Return the prefix of the y-axis label.

        :param score_only: if false all terms are relevant, otherwise only CASE I to IV
        :return: empty string if score_only is set, otherwise the phrase "terms and "
        """
        if score_only:
            return ""
        return "terms and "

    def _plot(self, score_only, xlim=5):
        """
        Draw the box plot of the climatological skill score.

        :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms
        :param xlim: not used inside this method
        """
        fig, ax = plt.subplots()
        if not score_only:
            # more boxes to show, use a larger canvas
            fig.set_size_inches(11.7, 8.27)
        mean_style = {"markersize": 1, "markeredgecolor": "k"}
        flier_style = {"marker": "."}
        sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=ax, whis=1.5, palette="Blues_d",
                    showmeans=True, meanprops=mean_style, flierprops=flier_style)
        ax.axhline(y=0, color="grey", linewidth=.5)
        y_label = f"{self._label_add(score_only)}skill score"
        ax.set(ylabel=y_label, xlabel="", title="summary of all stations", ylim=self._lim())
        legend_handles, _ = ax.get_legend_handles_labels()
        ax.legend(legend_handles, self._labels)
        plt.tight_layout()

    def _lim(self) -> Tuple[float, float]:
        """
        Calculate axis limits from data (Can be used to set axis extend).

        Lower limit is the minimum of 0 and data's minimum (reduced by small subtrahend) and upper limit is data's
        maximum (increased by a small addend). Both limits are restricted to the interval [-5, 5].

        :return: lower and upper axis limit
        """
        bound = 5
        values = self._data["data"]
        padded_min = helpers.float_round(values.min() - 0.1, 2)
        padded_max = helpers.float_round(values.max() + 0.1, 2)
        lower = np.max([-bound, np.min([0, padded_min])])
        upper = np.min([bound, padded_max])
        return lower, upper

586 

587 

588@TimeTrackingWrapper 

589class PlotCompetitiveSkillScore(AbstractPlotClass): # pragma: no cover 

590 """ 

591 Create competitive skill score plot. 

592 

593 Create this plot for the given model setup and the reference models ordinary least squared ("ols") and the 

594 persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name 

595 skill_score_competitive_{model_setup}.pdf and resolution of 500dpi. 

596 

597 .. image:: ../../../../../_source/_plots/skill_score_competitive.png 

598 :width: 400 

599 

600 :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- 

601 calculated comparisons for cnn, persistence and ols. 

602 :param plot_folder: path to save the plot (default: current directory) 

603 :param model_setup: architecture type (default "CNN") 

604 

605 """ 

606 

607 def __init__(self, data: Dict[str, pd.DataFrame], plot_folder=".", model_setup="NN"): 

608 """Initialise.""" 

609 super().__init__(plot_folder, f"skill_score_competitive_{model_setup}") 

610 self._model_setup = model_setup 

611 self._labels = None 

612 self._data = self._prepare_data(helpers.remove_items(data, "total")) 

613 default_plot_name = self.plot_name 

614 # draw full detail plot 

615 self.plot_name = default_plot_name + "_full_detail" 

616 self._plot() 

617 self._save() 

618 # draw also a vertical full detail version 

619 self.plot_name = default_plot_name + "_full_detail_vertical" 

620 self._plot_vertical() 

621 self._save() 

622 # draw default plot with only model comparison 

623 self.plot_name = default_plot_name 

624 self._plot(single_model_comparison=True) 

625 self._save() 

626 # draw also a vertical full detail version 

627 self.plot_name = default_plot_name + "_vertical" 

628 self._plot_vertical(single_model_comparison=True) 

629 self._save() 

630 

631 def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame: 

632 """ 

633 Reformat given data and create plot labels and introduce the dimensions stations and comparison. 

634 

635 :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- 

636 calculated comparisons for cnn, persistence and ols. 

637 :return: processed data 

638 """ 

639 data = pd.concat(data, axis=0) 

640 data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations") 

641 data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"}) 

642 data = data.to_dataframe("data").unstack(level=1).swaplevel() 

643 data.columns = data.columns.levels[1] 

644 self._labels = [str(i) + "d" for i in data.index.levels[1].values] 

645 data = data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data") 

646 return data.astype({"comparison": str, "ahead": int, "data": float}) 

647 

648 def _plot(self, single_model_comparison=False): 

649 """Plot skill scores of the comparisons.""" 

650 data = self._filter_comparisons(self._data) if single_model_comparison is True else self._data 

651 max_label_size = len(max(np.unique(data.comparison).tolist(), key=len)) 

652 size = max([len(np.unique(data.comparison)), 6]) 

653 fig, ax = plt.subplots(figsize=(size, 5 * max(0.8, max_label_size/20))) 

654 order = self._create_pseudo_order(data) 

655 sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d", 

656 showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, 

657 order=order) 

658 ax.axhline(y=0, color="grey", linewidth=.5) 

659 ax.set(ylabel="skill score", xlabel="competing models", title="summary of all stations", ylim=self._lim(data)) 

660 handles, _ = ax.get_legend_handles_labels() 

661 plt.xticks(rotation=90) 

662 ax.legend(handles, self._labels) 

663 plt.tight_layout() 

664 

def _plot_vertical(self, single_model_comparison=False):
    """
    Draw the skill scores of all comparisons as box plots with the comparisons on the y-axis.

    Same content as `_plot`, but with swapped axes so that the comparison names are listed vertically.

    :param single_model_comparison: if True, plot only the comparisons returned by `_filter_comparisons`,
        otherwise plot everything stored in `self._data`
    """
    if single_model_comparison is True:
        plot_data = self._filter_comparisons(self._data)
    else:
        plot_data = self._data
    comparison_names = np.unique(plot_data.comparison)
    longest_name = max(len(name) for name in comparison_names.tolist())
    height = max(len(comparison_names), 6)
    # labels are written horizontally here, therefore the width scales with the longest label
    width = 5 * max(0.8, longest_name / 20)
    fig, ax = plt.subplots(figsize=(width, height))
    sns.boxplot(
        y="comparison", x="data", hue="ahead", data=plot_data, whis=1.5, ax=ax, palette="Blues_d",
        showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
        order=self._create_pseudo_order(plot_data),
    )
    ax.axvline(x=0, color="grey", linewidth=.5)
    ax.set(xlabel="skill score", ylabel="competing models", title="summary of all stations",
           xlim=self._lim(plot_data))
    legend_handles, _ = ax.get_legend_handles_labels()
    # replace the raw lead time numbers in the legend by the prepared labels
    ax.legend(legend_handles, self._labels)
    plt.tight_layout()

680 

681 def _create_pseudo_order(self, data): 

682 """Provide first predefined elements and append all remaining.""" 

683 first_elements = [f"{self._model_setup} - persi", "ols - persi", f"{self._model_setup} - ols"] 

684 first_elements = list(filter(lambda x: x in data.comparison.tolist(), first_elements)) 

685 uniq, index = np.unique(first_elements + data.comparison.unique().tolist(), return_index=True) 

686 return uniq[index.argsort()] 

687 

688 def _filter_comparisons(self, data): 

689 filtered_headers = list(filter(lambda x: f"{self._model_setup} - " in x, data.comparison.unique())) 

690 return data[data.comparison.isin(filtered_headers)] 

691 

692 @staticmethod 

def _lim(data) -> Tuple[float, float]:
    """
    Calculate axis limits from data (can be used to set the axis extent).

    The lower limit is the minimum of 0 and the rounded minimum of the third column reduced by 0.1, the upper
    limit is the rounded maximum of the third column increased by 0.1. Both limits are clipped to the range
    [-5, 5] so that extreme values cannot stretch the axis.

    :param data: data frame whose third column (position 2) holds the plotted values
    :return: lower and upper axis limit
    """
    limit = 5
    # select the third entry explicitly by position: plain integer keys on a label-indexed Series are deprecated
    smallest = helpers.float_round(data.min().iloc[2], 2) - 0.1
    largest = helpers.float_round(data.max().iloc[2], 2) + 0.1
    lower = np.max([-limit, np.min([0, smallest])])
    upper = np.min([limit, largest])
    return lower, upper

706 

707 

708@TimeTrackingWrapper 

709class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover 

710 """ 

711 Create plot of feature importance analysis. 

712 

713 By passing a list `separate_vars` containing variable names, a second plot is created showing the `separate_vars` 

714 and the remaining variables side by side with different scaling. 

715 

716 .. image:: ../../../../../_source/_plots/skill_score_bootstrap.png 

717 :width: 400 

718 

719 .. image:: ../../../../../_source/_plots/skill_score_bootstrap_separated.png 

720 :width: 400 

721 

722 """ 

723 

def __init__(self, data: Dict, plot_folder: str = ".", separate_vars: List = None, sampling: str = "daily",
             ahead_dim: str = "ahead", bootstrap_type: str = None, bootstrap_method: str = None,
             boot_dim: str = "boots", model_name: str = "NN", branch_names: list = None, ylim: tuple = None):
    """
    Set attributes and create all plots.

    If the prepared data contains a column `branch`, one plot (and optionally one separated plot) is created
    per branch, otherwise a single plot (and optionally a single separated plot) is created. A plot that fails
    with a ValueError is skipped and reported on logging level info.

    :param data: dictionary with station names as keys and xarrays as values
    :param plot_folder: path to save the plot (default: current directory)
    :param separate_vars: variables to plot separated (default: ['o3'])
    :param sampling: type of sampling rate, should be either hourly or daily (default: "daily")
    :param ahead_dim: name of the ahead dimensions (default: "ahead")
    :param bootstrap_type: type of bootstrapping, used in the file name and in the plot title (default: None)
    :param bootstrap_method: method of bootstrapping, used in the file name (default: None)
    :param boot_dim: name of the bootstrap dimension (default: "boots")
    :param model_name: architecture type to specify plot name (default "NN")
    :param branch_names: names to use in the title instead of the branch number (default: None)
    :param ylim: limits of the y-axis, either a single tuple or a list of two tuples, one for each panel of the
        separated plot (default: None)
    """
    name_parts = ("", bootstrap_type, bootstrap_method)
    annotation = "_".join(part for part in name_parts if part is not None)
    super().__init__(plot_folder, f"feature_importance_{model_name}{annotation}")
    separate_vars = ['o3'] if separate_vars is None else separate_vars
    self._labels = None
    self._x_name = "boot_var"
    self._ahead_dim = ahead_dim
    self._boot_dim = boot_dim
    self._boot_type = self._set_bootstrap_type(bootstrap_type)
    self._boot_method = self._set_bootstrap_method(bootstrap_method)
    self._number_of_bootstraps = 0
    self._branches_names = branch_names
    self._ylim = ylim

    self._data = self._prepare_data(data, sampling)
    self._set_title(model_name)

    def plot_and_save(save_kwargs, **plot_kwargs):
        # a failing plot must not stop the remaining plots, therefore only report the problem
        try:
            self._plot(**plot_kwargs)
            self._save(**save_kwargs)
        except ValueError as e:
            logging.info(f"Did not plot {self.plot_name} because of {e}")

    def separated_plot_requested():
        return len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0

    if "branch" in self._data.columns:
        base_name = self.plot_name
        for branch in self._data["branch"].unique():
            self._set_title(model_name, branch, len(self._data["branch"].unique()))
            self.plot_name = f"{base_name}_{branch}"
            plot_and_save({}, branch=branch)
            if separated_plot_requested():
                self.plot_name += '_separated'
                plot_and_save({"bbox_inches": 'tight'}, branch=branch, separate_vars=separate_vars)
    else:
        plot_and_save({})
        if separated_plot_requested():
            self.plot_name += '_separated'
            plot_and_save({"bbox_inches": 'tight'}, separate_vars=separate_vars)

784 

785 @staticmethod 

786 def _set_bootstrap_type(boot_type): 

787 return {"singleinput": "single input"}.get(boot_type, boot_type) 

788 

789 def _set_title(self, model_name, branch=None, n_branches=None): 

790 title_d = {"single input": "Single Inputs", "branch": "Input Branches", "variable": "Variables"} 

791 base_title = f"{model_name}\nImportance of {title_d[self._boot_type]}" 

792 

793 additional = [] 

794 if branch is not None: 

795 try: 

796 assert n_branches == len(self._branches_names) 

797 branch_name = self._branches_names[int(branch)] 

798 except (IndexError, TypeError, ValueError, AssertionError): 

799 branch_name = branch 

800 additional.append(branch_name) 

801 if self._number_of_bootstraps > 1: 

802 additional.append(f"n={self._number_of_bootstraps}") 

803 additional_title = ", ".join(additional) 

804 if len(additional_title) > 0: 

805 additional_title = f" ({additional_title})" 

806 self._title = base_title + additional_title 

807 

808 @staticmethod 

809 def _set_bootstrap_method(boot_method): 

810 return {"zero_mean": "zero mean", "shuffle": "shuffled"}.get(boot_method, boot_method) 

811 

def _prepare_data(self, data: Dict, sampling: str) -> pd.DataFrame:
    """
    Transform data to a plot friendly format.

    The station-wise data is combined into a single array with dimensions station, ahead, boots, and boot_var.
    For bootstrap type "single input", the coordinates of the boot_var dimension are split at "_" into a number
    tag (used as new dimension `branch`) and the remaining variable name. For all other types, the number tag
    is removed from the coordinates if possible. As a side effect, the plot labels (`self._labels`) and the
    number of bootstraps (`self._number_of_bootstraps`) are set.

    :param data: dictionary with station names as keys and xarrays as values
    :param sampling: sampling of the data, used to create the labels of the lead times
    :return: pre-processed data as data frame without missing values
    """
    station_dim = "station"
    data = helpers.dict_to_xarray(data, station_dim).sortby(self._x_name)
    data = data.transpose(station_dim, self._ahead_dim, self._boot_dim, self._x_name)
    if self._boot_type == "single input":
        number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_')
        new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_',
                                                               keep=1, as_unique=True)
        try:
            # works only if each number tag holds the same variables
            values = data.values.reshape((*data.shape[:3], len(number_tags), len(new_boot_coords)))
            data = xr.DataArray(values, coords={station_dim: data.coords[station_dim],
                                                self._x_name: new_boot_coords,
                                                "branch": number_tags,
                                                self._ahead_dim: data.coords[self._ahead_dim],
                                                self._boot_dim: data.coords[self._boot_dim]},
                                dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name])
        except ValueError:
            # number of variables differs between the number tags, therefore reshape each tag on its own
            data_coll = []
            for nr in number_tags:
                # compare the complete tag, a substring check would let tag "1" also select "10", "11", ...
                filtered_coords = [x for x in data.coords[self._x_name].values if x.split("_")[0] == nr]
                new_boot_coords = self._return_vars_without_number_tag(filtered_coords, split_by='_', keep=1,
                                                                       as_unique=True)
                sel_data = data.sel({self._x_name: filtered_coords})
                values = sel_data.values.reshape((*data.shape[:3], 1, len(new_boot_coords)))
                sel_data = xr.DataArray(values, coords={station_dim: data.coords[station_dim],
                                                        self._x_name: new_boot_coords,
                                                        "branch": [nr],
                                                        self._ahead_dim: data.coords[self._ahead_dim],
                                                        self._boot_dim: data.coords[self._boot_dim]},
                                        dims=[station_dim, self._ahead_dim, self._boot_dim, "branch",
                                              self._x_name])
                data_coll.append(sel_data)
            data = xr.concat(data_coll, "branch")
    else:
        try:
            new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_',
                                                                   keep=1)
            data = data.assign_coords({self._x_name: new_boot_coords})
        except NotImplementedError:
            # number tags are not identical and cannot be dropped, keep the original coordinates
            pass
    _, sampling_letter = self._get_sampling(sampling, 1)
    sampling_letter = {"d": "D", "h": "H"}.get(sampling_letter, sampling_letter)
    self._labels = [sampling_letter + str(i) for i in data.coords[self._ahead_dim].values]
    if station_dim not in data.dims:
        data = data.expand_dims(station_dim)
    self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0]
    return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna()

863 

864 def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False): 

865 arr = np.array([v.split(split_by) for v in values]) 

866 num = arr[:, 0] 

867 if arr.shape[keep] == 1: # keep dim has only length 1, no number tags required 

868 return num 

869 new_val = arr[:, keep] 

870 if self._all_values_are_equal(num, axis=0): 

871 return new_val 

872 elif as_unique is True: 

873 return np.unique(new_val) 

874 else: 

875 raise NotImplementedError 

876 

877 @staticmethod 

878 def _get_number_tag(values, split_by): 

879 arr = np.array([v.split(split_by) for v in values]) 

880 num = arr[:, 0] 

881 return np.unique(num).tolist() 

882 

883 @staticmethod 

884 def _all_values_are_equal(arr, axis=0): 

885 if np.all(arr == arr[0], axis=axis): 

886 return True 

887 else: 

888 return False 

889 

890 def _label_add(self, score_only: bool): 

891 """ 

892 Add the phrase "terms and " if score_only is disabled or empty string (if score_only=True). 

893 

894 :param score_only: if false all terms are relevant, otherwise only CASE I to IV 

895 :return: additional label 

896 """ 

897 return "" if score_only else "terms and " 

898 

899 def _plot(self, branch=None, separate_vars=None): 

900 """Plot climatological skill score.""" 

901 if separate_vars is None: 

902 self._plot_all_variables(branch) 

903 else: 

904 self._plot_selected_variables(separate_vars, branch) 

905 

def _plot_selected_variables(self, separate_vars: List, branch=None):
    """
    Plot the separated variables and all remaining variables side by side with individual scaling.

    The left panel shows `separate_vars`, the right panel all remaining variables. The zero lines of both
    panels are aligned.

    :param separate_vars: variables to show in the left panel
    :param branch: restrict the plot to this branch (default: None)
    :raises ValueError: if either the separated or the remaining variables are empty or not part of the data
    """
    if branch is None:
        data = self._data
    else:
        data = self._data[self._data["branch"] == str(branch)]
    self.raise_error_if_vars_do_not_exist(data, separate_vars, self._x_name, name="separate_vars")
    all_variables = self._get_unique_values_from_column_of_df(data, self._x_name)
    remaining_vars = helpers.remove_items(all_variables, separate_vars)
    self.raise_error_if_vars_do_not_exist(data, remaining_vars, self._x_name, name="remaining_vars")
    data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name)
    data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name)

    n_separate, n_remaining = len(separate_vars), len(remaining_vars)
    # panel widths are proportional to the number of variables shown in each panel
    fig, ax = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [n_separate, n_remaining]},
                           figsize=(n_remaining, n_remaining / 2.))
    first_box_width = .8

    # left panel: separated variables
    sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], whis=1.5,
                palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
                showfliers=False, width=first_box_width)
    ax[0].set(ylabel="skill score", xlabel="")
    if self._ylim is not None:
        # a list of limits holds one entry per panel, a tuple is used for the left panel only
        _ylim = self._ylim if isinstance(self._ylim, tuple) else self._ylim[0]
        ax[0].set(ylim=_ylim)

    # right panel: remaining variables
    sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1.5,
                palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
                showfliers=False)
    ax[1].set(ylabel="", xlabel="")
    ax[1].yaxis.tick_right()
    if isinstance(self._ylim, list):
        _ylim = self._ylim[1]
        ax[1].set(ylim=_ylim)

    handles, _ = ax[1].get_legend_handles_labels()
    for sax in ax:
        plt.sca(sax)
        sax.axhline(y=0, color="grey", linewidth=.5)
        plt.xticks(rotation=45, ha='right')
        sax.legend_.remove()

    # single legend for both panels, placed in the right panel
    ax[1].legend(handles, self._labels, loc='lower center', ncol=len(handles) + 1, fontsize="medium")

    def align_yaxis(ax1, ax2):
        """
        Align zeros of the two axes, zooming them out by same ratio

        This function is copy pasted from https://stackoverflow.com/a/41259922
        """
        axes = (ax1, ax2)
        extrema = [axis.get_ylim() for axis in axes]
        tops = [upper / (upper - lower) for lower, upper in extrema]
        # Ensure that plots (intervals) are ordered bottom to top:
        if tops[0] > tops[1]:
            axes, extrema, tops = [list(reversed(seq)) for seq in (axes, extrema, tops)]

        # How much would the plot overflow if we kept current zoom levels?
        tot_span = tops[1] + 1 - tops[0]

        b_new_t = extrema[0][0] + tot_span * (extrema[0][1] - extrema[0][0])
        t_new_b = extrema[1][1] - tot_span * (extrema[1][1] - extrema[1][0])
        axes[0].set_ylim(extrema[0][0], b_new_t)
        axes[1].set_ylim(t_new_b, extrema[1][1])

    # NOTE(review): the alignment is applied twice as in the initial implementation, presumably the second call
    # does not change the limits any more - confirm before removing
    align_yaxis(ax[0], ax[1])
    align_yaxis(ax[0], ax[1])
    plt.subplots_adjust(right=0.8)
    plt.title(self._title)

975 

976 @staticmethod 

977 def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame: 

978 selected_data = None 

979 for i, variable in enumerate(variables): 

980 if i == 0: 

981 selected_data = df.loc[df[column_name] == variable] 

982 else: 

983 tmp_var = df.loc[df[column_name] == variable] 

984 selected_data = pd.concat([selected_data, tmp_var], axis=0) 

985 return selected_data 

986 

def raise_error_if_vars_do_not_exist(self, data, vars, column_name, name="separate_vars"):
    """
    Ensure that at least one variable is given and that all given variables are present in data.

    :param data: data frame to check
    :param vars: variables that must be present in column `column_name`
    :param column_name: name of the column holding the variable names
    :param name: name of the checked variable collection, only used in the error message
    :raises ValueError: if vars is empty or if any variable is missing in data
    """
    if len(vars) == 0:
        raise ValueError(f"No variables are given for `{name}' to check in `self.data' ")
    all_present = self._variables_exist_in_df(df=data, variables=vars, column_name=column_name)
    if not all_present:
        raise ValueError(f"At least one entry of `{name}' does not exist in `self.data' ")

994 

995 @staticmethod 

996 def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List: 

997 return list(df[column_name].unique()) 

998 

999 def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str): 

1000 vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name)) 

1001 return set(variables).issubset(vars_in_df) 

1002 

def _plot_all_variables(self, branch=None):
    """
    Plot the skill scores of all variables in a single panel.

    For bootstrap type "branch", the figure width depends on the number of boxes and the tick labels are not
    rotated. For all other types, the default figure size is used and the tick labels are rotated by 45 degrees.

    :param branch: restrict the plot to this branch (default: None)
    """
    if branch is None:
        plot_data = self._data
    else:
        plot_data = self._data[self._data["branch"] == str(branch)]
    is_branch_plot = self._boot_type == "branch"
    if is_branch_plot:
        n_boxes = len(plot_data[self._x_name].unique())
        fig, ax = plt.subplots(figsize=(0.5 + 2 / n_boxes + n_boxes, 4))
        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.,
                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
                    showfliers=False, width=0.8)
    else:
        fig, ax = plt.subplots()
        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.5,
                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
                    showfliers=False)
    ax.axhline(y=0, color="grey", linewidth=.5)

    if self._ylim is not None:
        if isinstance(self._ylim, tuple):
            _ylim = self._ylim
        else:
            # two limits are given (one per panel of the separated plot), use the range covering both
            lower_limits = (self._ylim[0][0], self._ylim[1][0])
            upper_limits = (self._ylim[0][1], self._ylim[1][1])
            _ylim = (min(*lower_limits), max(*upper_limits))
        ax.set(ylim=_ylim)

    if is_branch_plot:
        plt.xticks()
    else:
        plt.xticks(rotation=45)
    ax.set(ylabel="skill score", xlabel="", title=self._title)
    legend_handles, _ = ax.get_legend_handles_labels()
    ax.legend(legend_handles, self._labels)
    plt.tight_layout()

1034 

1035 

1036@TimeTrackingWrapper 

1037class PlotTimeSeries: # pragma: no cover 

1038 """ 

1039 Create time series plot. 

1040 

1041 Currently, plots are under development and not well designed for any use in public. 

1042 """ 

1043 

1044 def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = ".", 

1045 sampling="daily", model_name="nn", obs_name="obs", ahead_dim="ahead"): 

1046 """Initialise.""" 

1047 self._data_path = data_path 

1048 self._data_name = name 

1049 self._stations = stations 

1050 self._model_name = model_name 

1051 self._obs_name = obs_name 

1052 self._ahead_dim = ahead_dim 

1053 self._window_lead_time = self._get_window_lead_time(window_lead_time) 

1054 self._sampling = self._get_sampling(sampling) 

1055 self._plot(plot_folder) 

1056 

1057 @staticmethod 

1058 def _get_sampling(sampling): 

1059 if sampling == "daily": 

1060 return "D" 

1061 elif sampling == "hourly": 

1062 return "h" 

1063 

1064 def _get_window_lead_time(self, window_lead_time: int): 

1065 """ 

1066 Extract the lead time from data and arguments. 

1067 

1068 If window_lead_time is not given, extract this information from data itself by the number of ahead dimensions. 

1069 If given, check if data supports the give length. If the number of ahead dimensions in data is lower than the 

1070 given lead time, data's lead time is used. 

1071 

1072 :param window_lead_time: lead time from arguments to validate 

1073 :return: validated lead time, comes either from given argument or from data itself 

1074 """ 

1075 ahead_steps = len(self._load_data(self._stations[0]).coords[self._ahead_dim]) 

1076 if window_lead_time is None: 

1077 window_lead_time = ahead_steps 

1078 return min(ahead_steps, window_lead_time) 

1079 

def _load_data(self, station):
    """
    Open the data file of the given station and select the model forecast and the observations.

    :param station: name of the station, inserted into the file name pattern `self._data_name`
    :return: data restricted to the types `self._model_name` and `self._obs_name`
    """
    logging.debug(f"... preprocess station {station}")
    station_file = self._data_name % station
    data_array = xr.open_dataarray(os.path.join(self._data_path, station_file))
    wanted_types = [self._model_name, self._obs_name]
    return data_array.sel(type=wanted_types)

1085 

def _plot(self, plot_folder):
    """
    Create one pdf page per station showing observations and forecasts of all years.

    Each year is drawn in its own subplot (two subplots per year for hourly data). Subplots that contain no
    data or that cannot be drawn are removed from the page. The pdf file is always closed, even if plotting
    of a station fails.

    :param plot_folder: path to save the plot
    """
    pdf_pages = self._create_pdf_pages(plot_folder)
    try:
        for station in self._stations:
            data = self._load_data(station)
            start, end = self._get_time_range(data)
            fig, axes, factor = self._create_subplots(start, end)
            nan_list = []
            for i_year in range(end - start + 1):
                data_year = data.sel(index=f"{start + i_year}")
                for i_half_of_year in range(factor):
                    pos = factor * i_year + i_half_of_year
                    try:
                        plot_data = self._create_plot_data(data_year, factor, i_half_of_year)
                        self._plot_obs(axes[pos], plot_data)
                        self._plot_ahead(axes[pos], plot_data)
                        if np.isnan(plot_data.values).all():
                            nan_list.append(pos)
                    except Exception as e:
                        # best effort: drop the subplot, but leave a trace of the reason
                        logging.debug(f"... skip subplot {pos} of station {station} because of {e}")
                        nan_list.append(pos)
            self._clean_up_axes(nan_list, axes, fig)
            self._save_page(station, pdf_pages)
            # the page is stored, release the figure to avoid collecting one large figure per station
            plt.close(fig)
    finally:
        pdf_pages.close()
        plt.close('all')

1109 

1110 @staticmethod 

1111 def _clean_up_axes(nan_list, axes, fig): 

1112 for i in reversed(nan_list): 

1113 fig.delaxes(axes[i]) 

1114 

1115 @staticmethod 

def _save_page(station, pdf_pages):
    """
    Finalise the current figure and append it as new page to the pdf.

    :param station: name of the station, used as figure title
    :param pdf_pages: opened pdf to append the page to
    """
    # all calls act on the current figure / axes of pyplot
    plt.suptitle(station)
    plt.legend()
    plt.tight_layout()
    pdf_pages.savefig(dpi=500)

1121 

1122 @staticmethod 

1123 def _create_plot_data(data, factor, running_index): 

1124 if factor > 1: 

1125 if running_index == 0: 

1126 data = data.where(data["index.month"] < 7) 

1127 else: 

1128 data = data.where(data["index.month"] >= 7) 

1129 return data 

1130 

def _create_subplots(self, start, end):
    """
    Create a figure with one subplot per year, or two subplots per year for hourly data.

    :param start: first year to plot
    :param end: last year to plot
    :return: the figure, the subplots as one-dimensional array, and the number of subplots per year
    """
    factor = 2 if self._sampling == "h" else 1
    n_rows = (end - start + 1) * factor
    # squeeze=False always returns a 2D array of axes, even for a single subplot
    f, ax = plt.subplots(n_rows, sharey=True, figsize=(50, 30), squeeze=False)
    return f, ax[:, 0], factor

1137 

def _plot_ahead(self, ax, data):
    """
    Plot the forecast of each lead time as individual line.

    Each forecast is shifted by its lead time along the index. Only the first `self._window_lead_time` lead
    times are drawn because the color palette provides exactly this number of colors.

    :param ax: axes to draw on
    :param data: data to plot, must contain the type `self._model_name`
    """
    color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
    sampling_letter = {"d": "D", "h": "H"}.get(self._sampling, self._sampling)
    # NOTE(review): assumes lead times are consecutive integers starting at 1, as required by the color index
    lead_times = data.coords[self._ahead_dim].values[:self._window_lead_time]
    for ahead in lead_times:
        forecast = data.sel({"type": self._model_name, self._ahead_dim: ahead})
        plot_data = forecast.drop(["type", self._ahead_dim]).squeeze().shift(index=ahead)
        ax.plot(plot_data, color=color[ahead - 1], label=f"{sampling_letter}{ahead}")

1145 

def _plot_obs(self, ax, data):
    """
    Plot the observations as green line.

    The observations are taken from the first lead time and shifted by one step along the index. The
    configured names of the observation type and of the lead time dimension are used for the selection.

    :param ax: axes to draw on
    :param data: data to plot, must contain the type `self._obs_name`
    """
    ahead = 1
    obs_data = data.sel({"type": self._obs_name, self._ahead_dim: ahead}).shift(index=ahead)
    ax.plot(obs_data, color=colors.cnames["green"], label="obs")

1150 

1151 @staticmethod 

1152 def _get_time_range(data): 

1153 def f(x, f_x): 

1154 return pd.to_datetime(f_x(x.index.values)).year 

1155 

1156 return f(data, min), f(data, max) 

1157 

1158 @staticmethod 

def _create_pdf_pages(plot_folder: str):
    """
    Open the pdf file `timeseries_plot.pdf` inside the plot folder to store the plots.

    :param plot_folder: path to save the plot
    :return: opened multi-page pdf
    """
    target_dir = os.path.abspath(plot_folder)
    plot_name = os.path.join(target_dir, 'timeseries_plot.pdf')
    logging.debug(f"... save plot to {plot_name}")
    return matplotlib.backends.backend_pdf.PdfPages(plot_name)

1168 

1169 

@TimeTrackingWrapper
class PlotSeparationOfScales(AbstractPlotClass):  # pragma: no cover
    """
    Plot the input data of each data handler as grid of panels, one file per station.

    Panels are arranged by filter (columns) and target variable (rows), each panel shows time against window.
    All plots are stored in the sub folder `separation_of_scales` of the plot folder.
    """

    def __init__(self, collection: DataCollection, plot_folder: str = ".", time_dim="datetime", window_dim="window",
                 filter_dim="filter", target_dim="variables"):
        """
        Set attributes and create the plots.

        :param collection: data handlers to plot
        :param plot_folder: parent folder of the plots (default: current directory)
        :param time_dim: name of the dimension shown on the x-axis (default: "datetime")
        :param window_dim: name of the dimension shown on the y-axis (default: "window")
        :param filter_dim: name of the dimension spread over the columns (default: "filter")
        :param target_dim: name of the dimension spread over the rows (default: "variables")
        """
        sub_folder = os.path.join(plot_folder, "separation_of_scales")
        super().__init__(sub_folder, "separation_of_scales")
        self.time_dim, self.window_dim = time_dim, window_dim
        self.filter_dim, self.target_dim = filter_dim, target_dim
        self._plot(collection)

    def _plot(self, collection: DataCollection):
        """Create and save one plot per data handler, the station name is appended to the file name."""
        base_name = self.plot_name
        for dh in collection:
            # only the first input of the data handler is shown
            first_input = dh.get_X(as_numpy=False)[0]
            station = dh.id_class.station[0]
            station_data = first_input.sel(Stations=station)
            station_data.plot(x=self.time_dim, y=self.window_dim, col=self.filter_dim, row=self.target_dim,
                              robust=True)
            self.plot_name = f"{base_name}_{station}"
            self._save()

1194 

1195 

1196@TimeTrackingWrapper 

1197class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover 

1198 

def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_type_dim: str = "type",
             error_measure: str = "mse", error_unit: str = None, dim_name_boots: str = 'boots',
             block_length: str = None, model_name: str = "NN", model_indicator: str = "nn",
             ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = "", season_annotation: str = None,
             apply_root: bool = True, plot_name="sample_uncertainty_from_bootstrap"):
    """
    Set attributes and create all plots.

    Box plots are created for all combinations of orientation (vertical / horizontal), u-test (with / without)
    and aggregation type (single / multi / panel). In addition, one kernel density plot is created per
    aggregation type. If `apply_root` is enabled, all plots are created a second time for the square root of
    the error measure and are stored with the tag "_sqrt".

    :param data: bootstrapped errors of all models
    :param plot_folder: path to save the plots (default: current directory)
    :param model_type_dim: name of the dimension holding the model names (default: "type")
    :param error_measure: name of the error measure, used as axis label (default: "mse")
    :param error_unit: unit of the error measure, used in the axis label (default: None)
    :param dim_name_boots: name of the bootstrap dimension (default: 'boots')
    :param block_length: block length of the bootstrap, shown in the text box of the plots (default: None)
    :param model_name: name to show for the model given by model_indicator (default: "NN")
    :param model_indicator: name of the model as used inside data (default: "nn")
    :param ahead_dim: name of the lead time dimension (default: "ahead")
    :param sampling: sampling of the data, if a tuple is given its second entry is used (default: "")
    :param season_annotation: name of the season, used in file name and text box (default: None)
    :param apply_root: also plot the square root of the error measure (default: True)
    :param plot_name: base name of all plot files (default: "sample_uncertainty_from_bootstrap")
    """
    super().__init__(plot_folder, plot_name)
    self.default_plot_name = self.plot_name
    self.model_type_dim = model_type_dim
    self.ahead_dim = ahead_dim
    self.error_measure = error_measure
    self.dim_name_boots = dim_name_boots
    self.error_unit = error_unit
    self.block_length = block_length
    self.model_name = model_name
    _season = season_annotation or ""
    self.sampling = {"daily": "d", "hourly": "H"}.get(sampling[1] if isinstance(sampling, tuple) else sampling, "")
    data = self.rename_model_indicator(data, model_name, model_indicator)
    self.prepare_data(data)

    # create all combinations to plot (h/v, utest/notest, single/multi/panel)
    orientations, u_tests, agg_types = ["v", "h"], [True, False], ["single", "multi", "panel"]
    variants = list(itertools.product(orientations, u_tests, agg_types))

    def plot_all(tag):
        for orientation, utest, agg_type in variants:
            self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag=tag, season=_season)
        # the kde plot depends only on the aggregation type, therefore create it once per type and not once
        # per variant
        for agg_type in agg_types:
            self._plot_kde(agg_type=agg_type, tag=tag, season=_season)

    # plot raw metric (mse)
    plot_all("")

    if apply_root is True:
        # plot root of metric (rmse)
        self._apply_root()
        plot_all("_sqrt")

    # NOTE(review): prepared data is dropped after plotting, presumably to release memory - confirm
    self._data_table = None
    self._n_boots = None
    self._factor = None

1236 

1237 @property 

def get_asteriks_from_mann_whitney_u_result(self):
    """
    Apply a two-sided Mann-Whitney U test with `self.model_name` as reference column to the prepared data.

    :return: last row of the test result converted by `represent_p_values_as_asteriks`
    """
    u_test_result = mann_whitney_u_test(data=self._data_table, reference_col_name=self.model_name, axis=0,
                                        alternative="two-sided")
    return represent_p_values_as_asteriks(u_test_result.iloc[-1])

1242 

def rename_model_indicator(self, data, model_name, model_indicator):
    """
    Replace the model indicator by the model name in the coordinates of the model type dimension.

    The coordinates are replaced in place on the given data.

    :param data: data with coordinates for dimension `self.model_type_dim`
    :param model_name: new name of the model
    :param model_indicator: current name of the model inside data
    :return: the given data with updated coordinates
    """
    replacement = {model_indicator: model_name}
    current_names = data.coords[self.model_type_dim].values
    data.coords[self.model_type_dim] = [replacement.get(entry, entry) for entry in current_names]
    return data

1247 

def prepare_data(self, data: xr.DataArray):
    """
    Convert data to a table with one column per model, ordered by ascending mean error.

    Sets `self._data_table`, the number of lead times (`self._factor`, 1 if data has no lead time dimension)
    and `self._n_boots`, which is the number of table rows divided by the number of lead times.

    :param data: bootstrapped errors of all models
    """
    data_table = data.to_dataframe(self.model_type_dim).unstack()
    if self.ahead_dim in data.dims:
        factor = len(data.coords[self.ahead_dim])
    else:
        factor = 1
    columns_by_mean = data_table.mean().sort_values().index
    self._data_table = data_table[columns_by_mean].droplevel(0, axis=1)
    self._n_boots = int(self._data_table.shape[0] / factor)
    self._factor = factor

1254 

1255 def _apply_root(self): 

1256 self._data_table = np.sqrt(self._data_table) 

1257 self.error_measure = f"root {self.error_measure}" 

1258 self.error_unit = self.error_unit.replace("$^2$", "") 

1259 

def _plot_kde(self, agg_type="single", tag="", season=""):
    """
    Plot the kernel density estimate of the bootstrapped errors of all models and save the plot.

    For aggregation type "single", all lead times are averaged and shown in a single panel. For "panel", each
    lead time is shown in its own panel, which requires a lead time level in the index of the data. Nothing is
    plotted for "multi".

    :param agg_type: aggregation type, either "single", "multi", or "panel" (default: "single")
    :param tag: additional tag to append to the file name (default: "")
    :param season: name of the season, used in file name and text box (default: "")
    """
    season_suffix = {"": ""}.get(season, f"_{season}")
    self.plot_name = self.default_plot_name + "_kde" + "_" + agg_type + tag + season_suffix
    data_table = self._data_table
    if agg_type == "multi":
        return  # nothing to do
    if self.ahead_dim not in data_table.index.names and agg_type == "panel":
        return  # nothing to do
    n_boots = self._n_boots
    if self.error_unit is None:
        error_label = self.error_measure
    else:
        error_label = f"{self.error_measure} (in {self.error_unit})"
    sampling_letter = {"d": "D", "h": "H"}.get(self.sampling, self.sampling)

    def annotation_text():
        # content of the text box: number of bootstraps, optionally preceded by block length and season
        text = f"n={n_boots}"
        if self.block_length is not None:
            text = f"{self.block_length}, {text}"
        if len(season) > 0:
            text = f"{season}, {text}"
        return text

    if agg_type == "single":
        fig, ax = plt.subplots()
        if self.ahead_dim in data_table.index.names:
            # average over all lead times
            data_table = data_table.groupby(level=0).mean()
        sns.kdeplot(data=data_table, ax=ax)
        lower, upper = ax.get_ylim()
        ax.set_ylim([lower, upper * 1.025])
        ax.set_xlabel(error_label)

        text_box = AnchoredText(annotation_text(), frameon=True, loc="upper left", pad=0.5,
                                bbox_to_anchor=(0., 1.0), bbox_transform=ax.transAxes)
        plt.setp(text_box.patch, edgecolor='k', facecolor='w')
        ax.add_artist(text_box)
        ax.get_legend().set_title(None)
        plt.tight_layout()
    else:
        long_table = data_table.stack(self.model_type_dim).reset_index()
        g = sns.FacetGrid(long_table, col=self.ahead_dim, hue=self.model_type_dim, legend_out=True)
        g.map(sns.kdeplot, 0)
        g.add_legend(title="")
        fig = plt.gcf()
        panel_titles = [sampling_letter + str(i) for i in data_table.index.levels[1].values]
        for axi, title in zip(g.axes.flatten(), panel_titles):
            axi.set_title(title)
        for axi in g.axes.flatten():
            axi.set_xlabel(None)
        # single x label for all panels
        fig.supxlabel(error_label)
        text_box = AnchoredText(annotation_text(), frameon=True, loc="upper right")
        plt.setp(text_box.patch, edgecolor='k', facecolor='w')
        g.axes.flatten()[0].add_artist(text_box)
    self._save()
    plt.close("all")

1314 

1315 def _plot(self, orientation: str = "v", apply_u_test: bool = False, agg_type="single", tag="", season=""): 

1316 self.plot_name = self.default_plot_name + {"v": "_vertical", "h": "_horizontal"}[orientation] + \ 

1317 {True: "_u_test", False: ""}[apply_u_test] + "_" + agg_type + tag + \ 

1318 {"": ""}.get(season, f"_{season}") 

1319 if apply_u_test is True and agg_type == "multi": 

1320 return # not implemented 

1321 data_table = self._data_table 

1322 if self.ahead_dim not in data_table.index.names and agg_type in ["multi", "panel"]: 

1323 return # nothing to do 

1324 if apply_u_test is True and agg_type == "panel": 

1325 return # nothing to do 

1326 n_boots = self._n_boots 

1327 size = len(np.unique(data_table.columns)) 

1328 asteriks = self.get_asteriks_from_mann_whitney_u_result if apply_u_test is True else None 

1329 color_palette = sns.color_palette("Blues_d", self._factor).as_hex() 

1330 sampling_letter = {"d": "D", "h": "H"}.get(self.sampling, self.sampling) 

1331 if orientation == "v": 

1332 figsize, width = (size, 5), 0.4 

1333 elif orientation == "h": 

1334 if agg_type == "multi": 

1335 size *= np.sqrt(len(data_table.index.unique(self.ahead_dim))) 

1336 size = max(size, 8) 

1337 figsize, width = (7, (1+.5*size)), 0.65 

1338 else: 

1339 raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") 

1340 fig, ax = plt.subplots(figsize=figsize) 

1341 if agg_type == "single": 

1342 if self.ahead_dim in data_table.index.names: 

1343 data_table = data_table.groupby(level=0).mean() 

1344 sns.boxplot(data=data_table, ax=ax, whis=1.5, color="white", 

1345 showmeans=True, meanprops={"markersize": 6, "markeredgecolor": "k"}, 

1346 flierprops={"marker": "o", "markerfacecolor": "black", "markeredgecolor": "none", "markersize": 3}, 

1347 boxprops={'facecolor': 'none', 'edgecolor': 'k'}, width=width, orient=orientation) 

1348 elif agg_type == "multi": 

1349 xy = {"x": self.model_type_dim, "y": 0} if orientation == "v" else {"x": 0, "y": self.model_type_dim} 

1350 sns.boxplot(data=data_table.stack(self.model_type_dim).reset_index(), ax=ax, whis=1.5, palette=color_palette, 

1351 showmeans=True, meanprops={"markersize": 6, "markeredgecolor": "k", "markerfacecolor": "white"}, 

1352 flierprops={"marker": "o", "markerfacecolor": "black", "markeredgecolor": "none", "markersize": 3}, 

1353 boxprops={'edgecolor': 'k'}, width=.8, orient=orientation, **xy, hue=self.ahead_dim) 

1354 

1355 _labels = [sampling_letter + str(i) for i in data_table.index.levels[1].values] 

1356 handles, _ = ax.get_legend_handles_labels() 

1357 ax.legend(handles, _labels) 

1358 else: 

1359 xy = (self.model_type_dim, 0) if orientation == "v" else (0, self.model_type_dim) 

1360 col_or_row = {"col": self.ahead_dim} if orientation == "v" else {"row": self.ahead_dim} 

1361 aspect = figsize[0] / figsize[1] 

1362 height = figsize[1] * 0.8 

1363 ax = sns.FacetGrid(data_table.stack(self.model_type_dim).reset_index(), **col_or_row, aspect=aspect, height=height) 

1364 ax.map(sns.boxplot, *xy, whis=1.5, color="white", showmeans=True, order=data_table.mean().index.to_list(), 

1365 meanprops={"markersize": 6, "markeredgecolor": "k"}, 

1366 flierprops={"marker": "o", "markerfacecolor": "black", "markeredgecolor": "none", "markersize": 3}, 

1367 boxprops={'facecolor': 'none', 'edgecolor': 'k'}, width=width, orient=orientation) 

1368 

1369 _labels = [sampling_letter + str(i) for i in data_table.index.levels[1].values] 

1370 for axi, title in zip(ax.axes.flatten(), _labels): 

1371 axi.set_title(title) 

1372 plt.setp(axi.lines, color='k') 

1373 

1374 error_label = self.error_measure if self.error_unit is None else f"{self.error_measure} (in {self.error_unit})" 

1375 if agg_type == "panel": 

1376 if orientation == "v": 

1377 for axi in ax.axes.flatten(): 

1378 axi.set_xlabel(None) 

1379 axi.set_xticklabels(axi.get_xticklabels(), rotation=45) 

1380 ax.set_ylabels(error_label) 

1381 loc = "upper left" 

1382 else: 

1383 for axi in ax.axes.flatten(): 

1384 axi.set_ylabel(None) 

1385 ax.set_xlabels(error_label) 

1386 loc = "upper right" 

1387 text = f"n={n_boots}" 

1388 if self.block_length is not None: 

1389 text = f"{self.block_length}, {text}" 

1390 if len(season) > 0: 

1391 text = f"{season}, {text}" 

1392 text_box = AnchoredText(text, frameon=True, loc=loc) 

1393 plt.setp(text_box.patch, edgecolor='k', facecolor='w') 

1394 ax.axes.flatten()[0].add_artist(text_box) 

1395 else: 

1396 if orientation == "v": 

1397 if apply_u_test: 

1398 ax = self.set_significance_bars(asteriks, ax, data_table, orientation) 

1399 ylims = list(ax.get_ylim()) 

1400 ax.set_ylim([ylims[0], ylims[1]*1.025]) 

1401 ax.set_ylabel(error_label) 

1402 ax.set_xticklabels(ax.get_xticklabels(), rotation=45) 

1403 ax.set_xlabel(None) 

1404 elif orientation == "h": 

1405 if apply_u_test: 

1406 ax = self.set_significance_bars(asteriks, ax, data_table, orientation) 

1407 ax.set_xlabel(error_label) 

1408 xlims = list(ax.get_xlim()) 

1409 ax.set_xlim([xlims[0], xlims[1] * 1.015]) 

1410 ax.set_ylabel(None) 

1411 else: 

1412 raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") 

1413 text = f"n={n_boots}" 

1414 if self.block_length is not None: 

1415 text = f"{self.block_length}, {text}" 

1416 if len(season) > 0: 

1417 text = f"{season}, {text}" 

1418 loc = "lower left" 

1419 text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5, bbox_to_anchor=(0., 1.0), 

1420 bbox_transform=ax.transAxes) 

1421 plt.setp(text_box.patch, edgecolor='k', facecolor='w') 

1422 ax.add_artist(text_box) 

1423 plt.setp(ax.lines, color='k') 

1424 plt.tight_layout() 

1425 self._save() 

1426 plt.close("all") 

1427 

def set_significance_bars(self, asteriks, ax, data_table, orientation):
    """
    Draw a bracket with a significance label between the reference model and each other model.

    For every entry of `asteriks` that is not located at the position of `self.model_name`, a bracket is drawn
    from the reference model to that entry and annotated with the entry's value. Brackets are stacked outwards so
    that a later bracket never starts below the previous one.

    :param asteriks: labelled sequence (e.g. a pandas Series) whose index contains `self.model_name`; positions
        are assumed to match the column order of `data_table` - TODO confirm with caller
    :param ax: axes object providing `plot` and `text`
    :param data_table: table with one column per model, used to find the value above which brackets are drawn
    :param orientation: `"v"` draws brackets above the boxes, `"h"` to the right; other values draw nothing
    :return: the axes object that was passed in
    """
    # position of the reference model and bracket tick length do not change inside the loop
    p1 = list(asteriks.index).index(self.model_name)
    h = 0.01 * data_table.max().max()
    factor = 0.025
    q_prev = 0.
    for p2, ast in enumerate(asteriks):
        if p2 == p1:
            continue  # no bracket from the reference model to itself
        q = data_table[[self.model_name, data_table.columns[p2]]].max().max()
        q = max(q, q_prev) * (1 + factor)
        if abs(q - q_prev) < q * factor:
            # too close to the previous bracket, shift a bit further out
            q = q * (1 + factor)
        if orientation == "h":
            ax.plot([q, q + h, q + h, q], [p1, p1, p2, p2], c="k")
            ax.text(q + h, (p1 + p2) * 0.5, ast, ha="left", va="center", color="k", rotation=-90)
        elif orientation == "v":
            ax.plot([p1, p1, p2, p2], [q, q + h, q + h, q], c="k")
            ax.text((p1 + p2) * 0.5, q + h, ast, ha="center", va="bottom", color="k")
        q_prev = q
    return ax

1448 

1449 

@TimeTrackingWrapper
class PlotTimeEvolutionMetric(AbstractPlotClass):
    """
    Visualise how an error metric evolves over time.

    During initialisation three kinds of plots are created and stored by calling `_save` with an adjusted
    `plot_name`: one heat map per model type (mean over all lead times), one summary heat map containing all
    model types, and one line plot with a line per model type.
    """

    def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type", plot_folder=".",
                 error_measure: str = "mse", error_unit: str = None, model_name: str = "NN",
                 model_indicator: str = "nn", time_dim="index"):
        """
        Prepare the data and create all plots.

        :param data: error values with the dimensions `ahead_dim`, `model_type_dim` and `time_dim`; the time
            coordinate is formatted with `strftime` and therefore assumed to be datetime-like
        :param ahead_dim: name of the lead time dimension
        :param model_type_dim: name of the dimension listing the model types
        :param plot_folder: folder passed on to the parent plot class
        :param error_measure: name of the error measure, used as title / axis label
        :param error_unit: unit of the error measure; if None only the measure name is shown
        :param model_name: display name that replaces `model_indicator` in `model_type_dim`
        :param model_indicator: label of the own model inside `model_type_dim`
        :param time_dim: name of the time dimension
        """
        super().__init__(plot_folder, "time_evolution_mse")
        # the unit is optional; a missing unit must not blank the entire title
        self.title = error_measure if error_unit is None else f"{error_measure} (in {error_unit})"
        plot_name = self.plot_name
        # common colour range of all detailed heat maps, taken from the data before trimming
        vmin = int(data.quantile(0.05))
        vmax = int(data.quantile(0.95))
        data = self._prepare_data(data, time_dim, model_type_dim, model_indicator, model_name)

        # detailed plot for each model type
        for t in data[model_type_dim]:
            # note: could be expanded to create plot per ahead step
            plot_data = data.sel({model_type_dim: t}).mean(ahead_dim).to_pandas()
            years = plot_data.columns.strftime("%Y").to_list()
            months = plot_data.columns.strftime("%b").to_list()
            plot_data.columns = plot_data.columns.strftime("%b %Y")
            self.plot_name = f"{plot_name}_{t.values}"
            self._plot(plot_data, years, months, vmin, vmax, str(t.values))

        # aggregated version with all model types
        remaining_dim = set(data.dims).difference((model_type_dim, time_dim))
        _data = data.mean(remaining_dim, skipna=True).transpose(model_type_dim, time_dim)
        vmin = int(_data.quantile(0.05))
        vmax = int(_data.quantile(0.95))
        plot_data = _data.to_pandas()
        years = plot_data.columns.strftime("%Y").to_list()
        months = plot_data.columns.strftime("%b").to_list()
        plot_data.columns = plot_data.columns.strftime("%b %Y")
        self.plot_name = f"{plot_name}_summary"
        self._plot(plot_data, years, months, vmin, vmax, None)

        # line plot version
        y_dim = "error"
        plot_data = data.to_dataset(name=y_dim).to_dataframe().reset_index()
        self.plot_name = f"{plot_name}_line_plot"
        self._plot_summary_line(plot_data, x_dim=time_dim, y_dim=y_dim, hue_dim=model_type_dim)

    @staticmethod
    def _find_nan_edge(data, time_dim):
        """
        Collect the time labels of all leading truthy entries.

        Iteration stops at the first entry that evaluates to False.

        :param data: iterable of boolean elements that carry a `time_dim` coordinate
        :param time_dim: name of the coordinate to read from each element
        :return: list with the coordinate values of the leading truthy elements
        """
        coll = []
        for i in data:
            if bool(i) is False:
                break
            else:
                coll.append(i[time_dim].values)
        return coll

    def _prepare_data(self, data, time_dim, model_type_dim, model_indicator, model_name):
        """
        Drop all-NaN time steps at both edges and replace the model indicator by the model name.

        :return: the trimmed data with updated `model_type_dim` labels
        """
        # remove nans at begin and end
        nan_locs = data.isnull().all(helpers.remove_items(data.dims, time_dim))
        nans_at_end = self._find_nan_edge(reversed(nan_locs), time_dim)
        nans_at_begin = self._find_nan_edge(nan_locs, time_dim)
        data = data.drop(nans_at_begin + nans_at_end, time_dim)
        # rename nn model
        data[model_type_dim] = [v if v != model_indicator else model_name for v in data[model_type_dim].data.tolist()]
        return data

    @staticmethod
    def _set_ticks(ax, years, months):
        """
        Label the x axis with month initials (minor ticks) and years (major ticks).

        A major tick is placed at every January. If no January is included, the first column is labelled with
        its year instead.

        :param ax: axes of the heat map
        :param years: year string of each column
        :param months: abbreviated month name of each column
        """
        from matplotlib.ticker import IndexLocator
        ax.xaxis.set_major_locator(IndexLocator(1, 0.5))
        locs = ax.get_xticks(minor=False).tolist()[:len(months)]
        ax.set_xticks(locs, minor=True)
        ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0)
        locs_major = []
        labels_major = []
        for loc, major, minor in zip(locs, years, months):
            if minor == "Jan":
                # small offset so that major and minor tick do not share the exact same position
                locs_major.append(loc + 0.001)
                labels_major.append(major)
        if len(locs_major) == 0:  # in case there is less than a year and no Jan included
            # tick positions and labels must be sequences, also for a single tick
            locs_major = [locs[0] + 0.001]
            labels_major = [years[0]]
        ax.set_xticks(locs_major)
        ax.set_xticklabels(labels_major, minor=False, rotation=0)
        ax.tick_params(axis="x", which="major", pad=15)

    @staticmethod
    def _aspect_cbar(val):
        """Return the colour bar aspect as 1.25 * val + 7.5, limited to the range from 5 to 30."""
        return min(max(1.25 * val + 7.5, 5), 30)

    def _plot(self, data, years, months, vmin=None, vmax=None, subtitle=None):
        """
        Create and save a heat map of `data`.

        :param data: table to plot; it is sorted by its index in place
        :param years: year string of each column, used for the major tick labels
        :param months: abbreviated month name of each column, used for the minor tick labels
        :param vmin: lower limit of the colour range
        :param vmax: upper limit of the colour range
        :param subtitle: optional first title line placed above `self.title`
        """
        fig, ax = plt.subplots(figsize=(max(data.shape[1] / 6, 12), max(data.shape[0] / 3.5, 2)))
        data.sort_index(inplace=True)
        sns.heatmap(data, linewidths=1, cmap="coolwarm", ax=ax, vmin=vmin, vmax=vmax,
                    cbar_kws={"aspect": self._aspect_cbar(data.shape[0])})
        # or cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm",
        # square=True
        self._set_ticks(ax, years, months)
        ax.set(xlabel=None, ylabel=None, title=self.title if subtitle is None else f"{subtitle}\n{self.title}")
        plt.tight_layout()
        self._save()

    def _plot_summary_line(self, data, x_dim, y_dim, hue_dim):
        """
        Create and save a line plot of `y_dim` over `x_dim` with one line per value of `hue_dim`.

        :param data: long-format table with the columns `x_dim`, `y_dim` and `hue_dim`; column `x_dim` is
            overwritten in place
        :param x_dim: name of the datetime column shown on the x axis
        :param y_dim: name of the column holding the error values
        :param hue_dim: name of the column that separates the lines
        """
        # reduce each time stamp to year and month, i.e. map it onto the first day of its month
        data[x_dim] = pd.to_datetime(data[x_dim].dt.strftime('%Y-%m'))
        n = len(data[hue_dim].unique())
        ax = sns.lineplot(data=data, x=x_dim, y=y_dim, hue=hue_dim, errorbar=("pi", 50),
                          palette=sns.color_palette()[:n], style=hue_dim, dashes=False, markers=["X"]*n)
        ax.set(xlabel=None, ylabel=self.title)
        ax.get_legend().set_title(None)
        ax.xaxis.set_major_locator(mdates.YearLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%Y'))
        ax.xaxis.set_minor_locator(mdates.MonthLocator(bymonth=range(1, 13, 3)))
        ax.xaxis.set_minor_formatter(mdates.DateFormatter('%b'))
        ax.margins(x=0)
        plt.tight_layout()
        self._save()

1561 

1562 

@TimeTrackingWrapper
class PlotSeasonalMSEStack(AbstractPlotClass):
    """
    Show the contribution of each season to the total error as stacked bar plots.

    Two data sources are plotted: the error data passed in directly, and the uncertainty estimate files found in
    `data_path`. For each source, plots are created in horizontal and vertical orientation, once split by lead
    time and once averaged over all lead times.
    """

    def __init__(self, data, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead",
                 sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$", time_dim="index",
                 model_type_dim: str = "type", model_name: str = "NN", model_indicator: str = "nn",):
        """Set attributes and create plot."""
        super().__init__(plot_folder, "seasonal_mse_stack_plot")
        self._data_path = data_path
        self.season_dim = "season"
        self.time_dim = time_dim
        self.ahead_dim = ahead_dim
        self.error_unit = error_unit
        self.error_measure = error_measure
        self.dim_order = [self.season_dim, ahead_dim, model_type_dim]

        # mse from monthly blocks
        self.plot_name_orig = "seasonal_mse_stack_plot"
        self._data = self._prepare_data(data)
        for orientation in ["horizontal", "vertical"]:
            for split_ahead in [True, False]:
                self._plot(ahead_dim, split_ahead, sampling, orientation)
                self._save(bbox_inches="tight")

        # mse from resampling
        self.plot_name_orig = "seasonal_mse_from_uncertainty_stack_plot"
        self._data = self._prepare_data_from_uncertainty(boot_dim, data_path, model_type_dim, model_indicator,
                                                         model_name)
        for orientation in ["horizontal", "vertical"]:
            for split_ahead in [True, False]:
                self._plot(ahead_dim, split_ahead, sampling, orientation)
                self._save(bbox_inches="tight")

    def _prepare_data(self, data):
        """
        Split the temporal mean of `data` into seasonal shares.

        The share of a season is the total mean multiplied by the seasonal mean divided by the sum of all
        seasonal means. All dimensions not listed in `self.dim_order` are averaged out.

        :param data: error data with a datetime-like `self.time_dim` coordinate
        :return: seasonal shares ordered as `self.dim_order`, sorted by the sum over season and lead time
        """
        season_mean = data.groupby(f"{self.time_dim}.{self.season_dim}").mean()
        total_mean = data.mean(self.time_dim)
        factor = season_mean / season_mean.sum(self.season_dim)
        season_share = (total_mean * factor).reindex({self.season_dim: ["DJF", "MAM", "JJA", "SON"]})
        season_share = season_share.mean(set(season_share.dims).difference(self.dim_order))
        return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order)

    def _prepare_data_from_uncertainty(self, boot_dim, data_path, model_type_dim, model_indicator, model_name):
        """
        Calculate seasonal shares from the stored uncertainty estimate results.

        Reads `uncertainty_estimate_raw_results.nc` (total) and one file per season from `data_path`, averages
        each over `boot_dim` and scales the total by the relative weight of every season.

        :return: seasonal shares ordered as `self.dim_order`, sorted by the sum over season and lead time
        """
        season_dim = self.season_dim
        mean = {}
        for season in ["total", "DJF", "MAM", "JJA", "SON"]:
            if season == "total":
                file_name = "uncertainty_estimate_raw_results.nc"
            else:
                file_name = f"uncertainty_estimate_raw_results_{season}.nc"
            with xr.open_dataarray(os.path.join(data_path, file_name)) as d:
                # reduce while the file is still open so that no values are read from a closed file later on
                mean[season] = d.mean(boot_dim)
        xr_data = xr.Dataset(mean).to_array(season_dim)
        xr_data[model_type_dim] = [v if v != model_indicator else model_name for v in xr_data[model_type_dim].values]
        xr_season = xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]})
        factor = xr_season / xr_season.sum(season_dim)
        season_share = xr_data.sel({season_dim: "total"}) * factor
        return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order)

    @staticmethod
    def _set_bar_label(ax):
        """
        Label every bar segment with its percentage of the full stacked bar.

        Segments whose share cannot be calculated (total of zero or non-finite value) get an empty label.

        :param ax: axes holding the bar containers of a stacked bar plot
        """
        segments = [(container, list(container.datavalues)) for container in ax.containers]
        totals = {}
        for _, values in segments:
            for i, value in enumerate(values):
                totals[i] = totals.get(i, 0) + value

        def share(value, total):
            percent = 100 * value / total if total else float("nan")
            return f"{round(percent)}%" if math.isfinite(percent) else ""

        for container, values in segments:
            labels = [share(value, totals[i]) for i, value in enumerate(values)]
            ax.bar_label(container, labels=labels, label_type='center')

    def _error_label(self):
        """Return the axis label consisting of the error measure and, if available, its unit."""
        if self.error_unit is None:
            return self.error_measure
        return f"{self.error_measure} (in {self.error_unit})"

    def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"):
        """
        Create a stacked bar plot of the seasonal shares stored in `self._data`.

        Saving is left to the caller; only `self.plot_name` is set here.

        :param dim: dimension (lead time) to split by or to average over
        :param split_ahead: if True create one panel per entry of `dim`, if False average over `dim`
        :param sampling: sampling passed to `_get_sampling` to derive the letter used in panel titles
        :param orientation: `"vertical"` for vertical bars, any other value results in horizontal bars
        """
        _, sampling_letter = self._get_sampling(sampling, 1)
        sampling_letter = {"d": "D", "h": "H"}.get(sampling_letter, sampling_letter)
        error_label = self._error_label()
        vertical = orientation == "vertical"
        bar_opts = {"stacked": True, "cmap": "Dark2", "legend": False}
        if split_ahead is False:
            self.plot_name = self.plot_name_orig + "_total_" + orientation
            table = self._data.mean(dim).to_pandas().T
            if vertical:
                fig, ax = plt.subplots(1, 1)
                table.plot.bar(ax=ax, **bar_opts)
                ax.xaxis.label.set_visible(False)
                ax.set_ylabel(error_label)
            else:
                fig, ax = plt.subplots(1, 1, figsize=(6, table.shape[0]))
                table.plot.barh(ax=ax, **bar_opts)
                ax.yaxis.label.set_visible(False)
                ax.set_xlabel(error_label)
            self._set_bar_label(ax)
            fig.legend(*ax.get_legend_handles_labels(), loc="upper center", ncol=4)
            fig.tight_layout(rect=[0, 0, 1, 0.9])
        else:
            self.plot_name = self.plot_name_orig + "_" + orientation
            data = self._data
            leads = data.coords[dim].values
            n = len(leads)
            m = data.max(self.season_dim).shape
            # squeeze=False always returns an array of axes, also if there is only a single lead time
            if vertical:
                fig, grid = plt.subplots(1, n, sharey=True, squeeze=False, figsize=(np.prod(m) / 0.8, 5))
            else:
                fig, grid = plt.subplots(n, 1, sharex=True, squeeze=False, figsize=(6, np.prod(m) * 0.6))
            axes = grid.ravel()
            for axi, sel in zip(axes, leads):
                table = data.sel({dim: sel}).to_pandas().T
                if vertical:
                    table.plot.bar(ax=axi, **bar_opts)
                    axi.xaxis.label.set_visible(False)
                else:
                    table.plot.barh(ax=axi, **bar_opts)
                    axi.yaxis.label.set_visible(False)
                axi.set_title(sampling_letter + str(sel))
                self._set_bar_label(axi)
            if vertical:
                axes[0].set_ylabel(error_label)
                top = 0.9
            else:
                axes[-1].set_xlabel(error_label)
                top = 0.95
            fig.legend(*axes[0].get_legend_handles_labels(), loc="upper center", ncol=4)
            fig.tight_layout(rect=[0, 0, 1, top])

1684 

1685 

1686@TimeTrackingWrapper 

1687class PlotErrorsOnMap(AbstractPlotClass): 

1688 from mlair.plotting.data_insight_plotting import PlotStationMap 

1689 

def __init__(self, data_gen, errors, error_metric, plot_folder: str = ".", iter_dim: str = "station",
             model_type_dim: str = "type", ahead_dim: str = "ahead", sampling: str = "daily"):
    """
    Create a multi-page PDF with one error map per model type and, additionally, per lead time.

    First all maps averaged over the lead times are written, followed by the maps for each single lead time.
    The colour limits are shared by all maps of the same kind.

    :param data_gen: iterable of station objects providing `get_coordinates()`
    :param errors: error values with the dimensions `model_type_dim` and `ahead_dim`
    :param error_metric: key of the error metric, used for unit / long name lookup and as column name
    :param plot_folder: folder passed on to the parent plot class
    :param iter_dim: not used inside this method
    :param model_type_dim: name of the dimension listing the model types
    :param ahead_dim: name of the lead time dimension
    :param sampling: sampling passed to `_get_sampling` to derive the letter used in plot titles
    """
    super().__init__(plot_folder, f"map_plot_{error_metric}")
    # resolve everything that may fail before the pdf file is opened
    error_metric_units = helpers.statistics.get_error_metrics_units("ppb")[error_metric]
    error_metric_name = helpers.statistics.get_error_metrics_long_name()[error_metric]
    self.sampling = self._get_sampling(sampling, 1)[1]
    coords = self._extract_coords(data_gen)
    plot_path = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")

    try:
        # the context manager closes the pdf file also if plotting raises
        with PdfPages(plot_path) as pdf_pages:
            for split_ahead in [False, True]:
                error_data = {}
                for model_type in errors.coords[model_type_dim].values:
                    error_data[model_type] = self._prepare_data(errors, model_type_dim, model_type, ahead_dim,
                                                                error_metric, split_ahead=split_ahead)
                limits = self._calculate_limits(error_data)
                for model_type, error in error_data.items():
                    if split_ahead is True:
                        for ahead in error.index.unique(1).to_list():
                            error_ahead = error.query(f"{ahead_dim} == {ahead}").droplevel(1)
                            plot_data = pd.concat([coords, error_ahead], axis=1)
                            self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type,
                                      limits, ahead=ahead)
                            pdf_pages.savefig()
                            plt.close()  # page is stored, release the figure
                    else:
                        plot_data = pd.concat([coords, error], axis=1)
                        self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type,
                                  limits)
                        pdf_pages.savefig()
                        plt.close()  # page is stored, release the figure
    finally:
        plt.close('all')

1721 

@staticmethod
def _calculate_limits(data):
    """
    Find the common value range of all given error tables.

    :param data: dictionary whose values provide `min()` / `max()` returning an object with `.values`
    :return: tuple of the overall minimum rounded down and the overall maximum rounded up (both by
        `relative_round` with 2 digits)
    """
    vmin, vmax = np.inf, -np.inf
    for v in data.values():
        # reduce to plain floats explicitly instead of comparing and converting 1-element arrays
        vmin = min(vmin, float(np.min(v.min().values)))
        vmax = max(vmax, float(np.max(v.max().values)))
    return relative_round(vmin, 2, floor=True), relative_round(vmax, 2, ceil=True)

1729 

@staticmethod
def _set_bounds(limits, ncolors, error_metric):
    """
    Calculate the sorted colour level boundaries for the given value range.

    The metrics "ioa" and "mnmb" use the fixed ranges [0, 1] and [-2, 2]; all other metrics use `limits`.

    :param limits: lower and upper limit used for metrics without a fixed range
    :param ncolors: number of colour levels the range is divided into
    :param error_metric: key of the error metric
    :return: ascending array of boundaries
    """
    fixed_ranges = {"ioa": [0, 1], "mnmb": [-2, 2]}
    value_range = fixed_ranges[error_metric] if error_metric in fixed_ranges else limits
    lower = relative_round(value_range[0], 2, floor=True)
    upper = relative_round(value_range[1], 2, ceil=True)
    step = relative_round((upper - lower) / ncolors, 1, ceil=True)
    # count downwards from the upper limit, then bring the levels into ascending order
    descending = np.arange(upper, lower, -step)
    return np.sort(descending)

1738 

@staticmethod
def _get_colorpalette(error_metric):
    """
    Select the colour map for the given error metric.

    "mnmb" uses "coolwarm", "ioa" the reversed "coolwarm_r" and every other metric "magma_r".

    :param error_metric: key of the error metric
    :return: colour map object
    """
    diverging = {"mnmb": "coolwarm", "ioa": "coolwarm_r"}
    if error_metric in diverging:
        return sns.mpl_palette(diverging[error_metric], as_cmap=True)
    return sns.color_palette("magma_r", as_cmap=True)

1752 

def plot(self, plot_data, error_metric, error_long_name, error_units, model_type, limits, ahead=None):
    """
    Draw a single map with one coloured marker per row of `plot_data`.

    The figure is only created, not saved or closed; this is left to the caller.

    :param plot_data: table with the columns "lon", "lat" and `error_metric`
    :param error_metric: key of the error metric, selects value column, colour map and colour levels
    :param error_long_name: name of the metric shown at the colour bar
    :param error_units: unit shown at the colour bar, skipped if None
    :param model_type: name of the model, used as title
    :param limits: value range handed to `_set_bounds`
    :param ahead: lead time appended to the title if given
    """
    import cartopy.crs as ccrs
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

    # map canvas with labelled grid lines and background features
    figure = plt.figure(figsize=(10, 5))
    map_ax = figure.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    grid = map_ax.gridlines(xlocs=range(-180, 180, 5), ylocs=range(-90, 90, 2), draw_labels=True)
    grid.xformatter = LONGITUDE_FORMATTER
    grid.yformatter = LATITUDE_FORMATTER
    self._draw_background(map_ax)

    # discrete colour levels
    palette = self._get_colorpalette(error_metric)
    n_levels = 20
    levels = self._set_bounds(limits, n_levels, error_metric)
    level_norm = colors.BoundaryNorm(levels, palette.N, extend='both')

    markers = map_ax.scatter(plot_data["lon"], plot_data["lat"], c=plot_data[error_metric], marker='o', s=50,
                             transform=ccrs.PlateCarree(), zorder=2, cmap=palette, norm=level_norm)
    if error_units is None:
        cbar_label = error_long_name
    else:
        cbar_label = f"{error_long_name} (in {error_units})"
    plt.colorbar(markers, label=cbar_label)
    self._adjust_extent(map_ax)

    letter = {"d": "D", "h": "H"}.get(self.sampling, self.sampling)
    if ahead is None:
        plt.title(model_type)
    else:
        plt.title(f"{model_type} ({letter}{ahead})")
    plt.tight_layout()

1775 

@staticmethod
def _adjust_extent(ax):
    """
    Enlarge the map extent by a margin on all four sides.

    The margin is the larger of 5 / width and 5 / height of the current extent, limited to at most 5. An
    extent without any width or height gets the maximum margin.

    :param ax: map axes providing `get_extent` and `set_extent`
    """
    import cartopy.crs as ccrs

    max_margin = 5

    def diff(arr):
        return arr[1] - arr[0], arr[3] - arr[2]

    def find_ratio(delta, reference=5):
        if delta[0] == 0 or delta[1] == 0:
            # the ratio grows without bound for a vanishing span, so the cap applies
            return max_margin
        return min(max(abs(reference / delta[0]), abs(reference / delta[1])), max_margin)

    extent = ax.get_extent(crs=ccrs.PlateCarree())
    ratio = find_ratio(diff(extent))
    new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
    ax.set_extent(new_extent, crs=ccrs.PlateCarree())

1790 

1791 @staticmethod 

1792 def _extract_coords(gen): 

1793 coll = [] 

1794 for station in gen: 

1795 coords = station.get_coordinates() 

1796 coll.append((str(station), coords["lon"], coords["lat"])) 

1797 return pd.DataFrame(coll, columns=["station", "lon", "lat"]).set_index("station") 

1798 

@staticmethod
def _prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric, split_ahead=False):
    """
    Select the errors of a single model type and return them as table.

    :param errors: error values containing `model_type_dim` and `ahead_dim`
    :param model_type_dim: name of the dimension listing the model types
    :param model_type: model type to select; the selected dimension is dropped
    :param ahead_dim: name of the lead time dimension
    :param error_metric: name of the value column of the returned table
    :param split_ahead: keep all lead times if not False, otherwise average over `ahead_dim`
    :return: table with the single value column `error_metric`
    """
    selected = errors.sel({model_type_dim: model_type}, drop=True)
    reduced = selected.mean(ahead_dim) if split_ahead is False else selected
    return reduced.to_dataframe(error_metric)

1805 

@staticmethod
def _draw_background(ax):
    """
    Add land, coastline, lakes, ocean, rivers and country borders in 50m resolution as map background.

    :param ax: map axes providing `add_feature` and `natural_earth_shp`
    """
    import cartopy.feature as cfeature

    scale = "50m"
    ax.add_feature(cfeature.LAND.with_scale(scale))
    # NOTE(review): natural_earth_shp seems to be deprecated in recent cartopy releases - confirm
    ax.natural_earth_shp(resolution='50m')
    # remaining features in drawing order, each with its own style options
    layers = (
        (cfeature.COASTLINE, {"edgecolor": 'black'}),
        (cfeature.LAKES, {}),
        (cfeature.OCEAN, {}),
        (cfeature.RIVERS, {}),
        (cfeature.BORDERS, {"facecolor": 'none', "edgecolor": 'black'}),
    )
    for feature, style in layers:
        ax.add_feature(feature.with_scale(scale), **style)

1819 

1820 

1821 

1822 

1823 

1824 

def _plot_individual(self):
    """
    Write one PDF per reference model with one skill score map page per lead time.

    NOTE(review): this method looks like an unfinished leftover. `df`, `norm`, `cmap`, `ticks` and `mpl` are
    used but never bound inside this method, `file_name` is built but not read, and `self._ax` is not created
    here. Unless these names are provided elsewhere, calling this method fails - confirm before use.
    """
    import cartopy.feature as cfeature
    import cartopy.crs as ccrs
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    for competitor in self.reference_models:
        # skill score report of the own model against this competitor (currently not loaded, see note above)
        file_name = os.path.join(self.skill_score_report_path,
                                 f"error_report_skill_score_{self.model_name}_-_{competitor}.csv"
                                 )

        # each competitor gets its own pdf file
        plot_path = os.path.join(os.path.abspath(self.plot_folder),
                                 f"{self.plot_name}_{self.model_name}_-_{competitor}.pdf")
        pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)

        for i, lead_name in enumerate(df.columns[:-2]):  # last two are lat lon
            fig = plt.figure()
            self._ax.scatter(df.lon.values, df.lat.values, c=df[lead_name],
                             transform=ccrs.PlateCarree(),
                             norm=norm, cmap=cmap)
            self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
            self._gl.xformatter = LONGITUDE_FORMATTER
            self._gl.yformatter = LATITUDE_FORMATTER
            # turn a column name like "a-b(t+1)" into a label like "avs.b (1d)"
            label = f"Skill Score: {lead_name.replace('-', 'vs.').replace('(t+', ' (').replace(')', 'd)')}"
            self._cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                                      orientation='horizontal', ticks=ticks,
                                      label=label,
                                      # cax=cax
                                      )

            # store the current figure as new page
            pdf_pages.savefig()
        # close the pdf file and all open figures / plots
        pdf_pages.close()
        plt.close('all')

1859 

def _plot(self, ncol: int = 2):
    """
    Create one figure per reference model with a grid of skill score maps, one panel per lead time.

    The skill scores are read with `self.open_data` from the report file of the own model against the
    respective reference model. All panels share a single horizontal colour bar. Panels of the grid that are
    not needed are hidden.

    NOTE(review): relies on `self.reference_models`, `self.skill_score_report_path`, `self.model_name` and
    `self.open_data`, which are assumed to be provided elsewhere - confirm before use.

    :param ncol: number of panel columns
    """
    import cartopy.feature as cfeature
    import cartopy.crs as ccrs
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
    import string
    base_plot_name = self.plot_name
    for competitor in self.reference_models:
        file_name = os.path.join(self.skill_score_report_path,
                                 f"error_report_skill_score_{self.model_name}_-_{competitor}.csv"
                                 )

        self.plot_name = f"{base_plot_name}_{self.model_name}_-_{competitor}"
        df = self.open_data(file_name)

        lead_names = df.columns[:-2]  # last two are lat lon
        nrow = int(np.ceil(len(lead_names) / ncol))
        bounds = np.linspace(-1, 1, 100)
        cmap = matplotlib.cm.coolwarm
        norm = colors.BoundaryNorm(bounds, cmap.N, extend='both')
        ticks = np.arange(norm.vmin, norm.vmax + .2, .2)
        fig, self._axes = plt.subplots(nrows=nrow, ncols=ncol, subplot_kw={'projection': ccrs.PlateCarree()})
        # a 1x1 grid returns a single axes object, therefore always convert to a flat array
        axes = np.atleast_1d(self._axes).reshape(-1)
        for ax in axes[len(lead_names):]:
            ax.set_visible(False)  # grid has more panels than lead times
        for i, (ax, lead_name) in enumerate(zip(axes, lead_names)):
            sub_name = f"({string.ascii_lowercase[i]})"
            ax.add_feature(cfeature.LAND.with_scale("50m"))
            ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
            ax.add_feature(cfeature.OCEAN.with_scale("50m"))
            ax.add_feature(cfeature.RIVERS.with_scale("50m"))
            ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')
            ax.scatter(df.lon.values, df.lat.values, c=df[lead_name],
                       marker='.',
                       transform=ccrs.PlateCarree(),
                       norm=norm, cmap=cmap)
            gl = ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
            gl.xformatter = LONGITUDE_FORMATTER
            gl.yformatter = LATITUDE_FORMATTER
            gl.top_labels = []
            gl.right_labels = []
            # panel title: sub figure letter and the lead time taken from the part after "+" of the column name
            ax.text(0.01, 1.09, f'{sub_name} {lead_name.split("+")[1][:-1]}d',
                    verticalalignment='top', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='black',
                    )
            label = f"Skill Score: {lead_name.replace('-', 'vs.').split('(')[0]}"

        fig.subplots_adjust(bottom=0.18)
        cax = fig.add_axes([0.15, 0.1, 0.7, 0.02])
        self._cbar = fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap),
                                  orientation='horizontal',
                                  ticks=ticks,
                                  label=label,
                                  cax=cax
                                  )

        fig.subplots_adjust(wspace=.001, hspace=.2)
        self._save(bbox_inches="tight")
        plt.close('all')

1917 

1918 

1919 @staticmethod 

1920 def get_coords_from_index(name_string: str) -> List[float]: 

1921 """ 

1922 

1923 :param name_string: 

1924 :type name_string: 

1925 :return: List of coords [lat, lon] 

1926 :rtype: List 

1927 """ 

1928 res = [float(frac.replace("_", ".")) for frac in name_string.split(sep="__")[1:]] 

1929 return res