Coverage for mlair/plotting/postprocessing_plotting.py: 13%
357 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
1"""Collection of plots to evaluate a model, create overviews on data or forecasts."""
2__author__ = "Lukas Leufen, Felix Kleinert"
3__date__ = '2020-11-23'
5import logging
6import math
7import os
8import warnings
9from typing import Dict, List, Tuple, Union
10import itertools
12import matplotlib
13import matplotlib.pyplot as plt
14import matplotlib.colors as colors
15import numpy as np
16import pandas as pd
17import seaborn as sns
18import xarray as xr
19from matplotlib.backends.backend_pdf import PdfPages
20from matplotlib.offsetbox import AnchoredText
21import matplotlib.dates as mdates
22from scipy.stats import mannwhitneyu
23import datetime as dt
25from mlair import helpers
26from mlair.data_handler.iterator import DataCollection
27from mlair.helpers import TimeTrackingWrapper
28from mlair.helpers.helpers import relative_round
29from mlair.plotting.abstract_plot_class import AbstractPlotClass
30from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks
# Raise the level of matplotlib's logger to WARNING so that its lower-level records are not emitted.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

# import matplotlib
# matplotlib.use("TkAgg")
# import matplotlib.pyplot as plt
@TimeTrackingWrapper
class PlotMonthlySummary(AbstractPlotClass):  # pragma: no cover
    """
    Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot.

    The plot is saved in plot_folder with name monthly_summary_box_plot.pdf and 500dpi resolution.

    .. image:: ../../../../../_source/_plots/monthly_summary_box_plot.png
        :width: 400

    :param stations: all stations to plot
    :param data_path: path, where the data is located
    :param name: full name of the local files with a % as placeholder for the station name
    :param target_var: display name of the target variable on plot's axis
    :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not given
        the maximum lead time from data is used. (default None -> use maximum lead time from data).
    :param plot_folder: path to save the plot (default: current directory)
    :param target_var_unit: unit of target var for plot legend (default= ppb)
    :param model_name: label of the model forecast on the `type` axis of the stored data (default "nn")

    """

    def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
                 plot_folder: str = ".", target_var_unit: str = 'ppb', model_name="nn"):
        """Set attributes and create plot."""
        super().__init__(plot_folder, "monthly_summary_box_plot")
        self._data_path = data_path
        self._data_name = name
        self._model_name = model_name
        self._data = self._prepare_data(stations)
        self._window_lead_time = self._get_window_lead_time(window_lead_time)
        self._plot(target_var, target_var_unit)
        self._save()

    def _prepare_data(self, stations: List) -> xr.DataArray:
        """
        Pre-process data required to plot.

        For each station, load locally saved predictions, extract the model prediction and the observation and group
        them into monthly bins (no aggregation, only sorting them).

        :param stations: all stations to plot
        :return: The entire data set, flagged with the corresponding month (None if no station is given).
        """
        collected = []
        for station in stations:
            logging.debug(f"... preprocess station {station}")
            file_name = os.path.join(self._data_path, self._data_name % station)
            # read the file completely inside a context manager, so that the file handle is released immediately
            with xr.open_dataarray(file_name) as da:
                data = da.load()

            data_nn = data.sel(type=self._model_name).squeeze()
            if len(data_nn.shape) > 1:
                data_nn = data_nn.assign_coords(ahead=[f"{days}d" for days in data_nn.coords["ahead"].values])
            else:
                # only a single lead time is left after squeeze, "ahead" is a scalar coordinate in this case
                data_nn.coords["ahead"].values = str(data_nn.coords["ahead"].values) + "d"

            # observations are taken from the first lead time and get their own entry on the ahead axis
            data_obs = data.sel(type="obs", ahead=1).squeeze()
            data_obs.coords["ahead"] = "obs"

            data_concat = xr.concat([data_obs, data_nn], dim="ahead")
            data_concat = data_concat.drop_vars("type")

            # replace each time stamp by its month number (1..12)
            new_index = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
            data_concat = data_concat.assign_coords(index=new_index)
            data_concat = data_concat.clip(min=0)

            collected.append(data_concat)
        # concatenate once at the end instead of growing the array station by station
        return xr.concat(collected, 'index') if collected else None

    def _get_window_lead_time(self, window_lead_time: int):
        """
        Extract the lead time from data and arguments.

        If window_lead_time is not given, extract this information from data itself by the number of ahead dimensions.
        If given, check if data supports the given length. If the number of ahead dimensions in data is lower than the
        given lead time, data's lead time is used.

        :param window_lead_time: lead time from arguments to validate
        :return: validated lead time, comes either from given argument or from data itself
        """
        ahead_steps = len(self._data.ahead)
        if window_lead_time is None:
            window_lead_time = ahead_steps
        return min(ahead_steps, window_lead_time)

    @staticmethod
    def _spell_out_chemical_concentrations(short_name: str):
        """
        Create the axis label for a target variable.

        Known short names are replaced by their spelled out name. Any other name is used as it is, so that plotting
        does not fail for target variables that are not listed here.

        :param short_name: short name of the target variable (e.g. "o3")
        :return: label in the form "<name> concentration"
        """
        short2long = {'o3': 'ozone', 'no': 'nitrogen oxide', 'no2': 'nitrogen dioxide', 'nox': 'nitrogen oxides'}
        return f"{short2long.get(short_name, short_name)} concentration"

    def _plot(self, target_var: str, target_var_unit: str):
        """
        Create a monthly grouped box plot over all stations but with separate boxes for each lead time step.

        :param target_var: display name of the target variable on plot's axis
        :param target_var_unit: unit of the target variable, shown in brackets in the y axis label
        """
        data = self._data.to_dataset(name='values').to_dask_dataframe()
        logging.debug("... start plotting")
        # first color (green) is followed by a blue palette with one entry per window lead time step
        color_palette = [colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
        ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1.5, palette=color_palette,
                         flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
                         meanprops={'markersize': 1, 'markeredgecolor': 'k'})
        ylabel = self._spell_out_chemical_concentrations(target_var)
        ax.set(xlabel='month', ylabel=f'{ylabel} (in {target_var_unit})')
        plt.tight_layout()
@TimeTrackingWrapper
class PlotConditionalQuantiles(AbstractPlotClass):  # pragma: no cover
    """
    Create cond.quantile plots as originally proposed by Murphy, Brown and Chen (1989) [But in log scale].

    Link to paper: https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2

    .. image:: ../../../../../_source/_plots/conditional_quantiles_cali-ref_plot.png
        :width: 400

    .. image:: ../../../../../_source/_plots/conditional_quantiles_like-bas_plot.png
        :width: 400

    For each time step ahead a separate plot is created. If parameter plot_per_season is true, data is split by season
    and conditional quantiles are plotted for each season in addition.

    :param stations: all stations to plot
    :param data_pred_path: path to dir which contains the forecasts as .nc files
    :param plot_folder: path where the plots are stored
    :param plot_per_seasons: if `True' create cond. quantile plots for _seasons (DJF, MAM, JJA, SON) individually
    :param rolling_window: smoothing of quantiles (3 is used by Murphy et al.)
    :param forecast_indicator: name of the model prediction as stored in netCDF file (for example "nn")
    :param obs_indicator: name of observation as stored in netCDF file (for example "obs")
    :param competitors: names of competing models to plot in addition (default None -> no competitors)
    :param model_type_dim: name of the dimension that holds the model types (default "type")
    :param index_dim: name of the time dimension (default "index")
    :param ahead_dim: name of the lead time dimension (default "ahead")
    :param competitor_path: path to dir with competitor forecasts (default None -> use data_pred_path)
    :param sampling: sampling of the data, either "daily" or "hourly" (default "daily")
    :param model_name: display name of the model used in the plot title (default "nn")
    :param kwargs: Some further arguments which are listed in self._opts
    """

    # ignore warnings if nans appear in quantile grouping
    warnings.filterwarnings("ignore", message="All-NaN slice encountered")
    # ignore warnings if mean is calculated on nans
    warnings.filterwarnings("ignore", message="Mean of empty slice")
    # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar)
    warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")

    def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons=True,
                 rolling_window: int = 3, forecast_indicator: str = "nn", obs_indicator: str = "obs",
                 competitors=None, model_type_dim: str = "type", index_dim: str = "index", ahead_dim: str = "ahead",
                 competitor_path: str = None, sampling: str = "daily", model_name: str = "nn", **kwargs):
        """Initialise."""
        super().__init__(plot_folder, "conditional_quantiles")
        self._data_pred_path = data_pred_path
        self._stations = stations
        self._rolling_window = rolling_window
        self._forecast_indicator = forecast_indicator
        self.model_type_dim = model_type_dim
        self.index_dim = index_dim
        self.ahead_dim = ahead_dim
        self.iter_dim = "station"
        self.model_name = model_name
        self._obs_name = obs_indicator
        self._sampling = sampling
        self._opts = self._get_opts(kwargs)
        self._seasons = ['DJF', 'MAM', 'JJA', 'SON'] if plot_per_seasons is True else ""
        self.competitors = self._correct_persi_name(competitors or [])
        self.competitor_path = competitor_path or data_pred_path
        self._data = self._load_data()
        self._bins = self._get_bins_from_rage_of_data()
        self._plot()

    @staticmethod
    def _get_opts(kwargs):
        """Extract options from kwargs."""
        return {"q": kwargs.get("q", [.1, .25, .5, .75, .9]),
                "linetype": kwargs.get("linetype", [':', '-.', '--', '-.', ':']),
                "legend": kwargs.get("legend", ['.10th and .90th quantile', '.25th and .75th quantile',
                                                '.50th quantile', 'reference 1:1']),
                "data_unit": kwargs.get("data_unit", "ppb"), }

    def _load_data(self) -> xr.DataArray:
        """
        Load plot data.

        For each station the forecast file is read, combined with all available competitors and reduced to the
        forecast, the observation and the competitors. All stations are joined along `self.iter_dim`.

        :return: plot data
        """
        logging.debug("... load data")
        data_collector = []
        for station in self._stations:
            file = os.path.join(self._data_pred_path, f"forecasts_{station}_test.nc")
            # read the file completely inside a context manager, so that the file handle is released immediately
            with xr.open_dataarray(file) as da:
                data_tmp = da.load()
            start = data_tmp.coords[self.index_dim].min().values
            end = data_tmp.coords[self.index_dim].max().values
            competitor = self.load_competitors(station, start, end)
            combined = self._combine_forecasts(data_tmp, competitor, dim=self.model_type_dim)
            sel = combined.sel({self.model_type_dim: [self._forecast_indicator, self._obs_name, *self.competitors]})
            data_collector.append(sel.assign_coords({self.iter_dim: station}))
        res = xr.concat(data_collector, dim=self.iter_dim).transpose(self.index_dim, self.model_type_dim,
                                                                     self.ahead_dim, self.iter_dim)
        return res

    def _combine_forecasts(self, forecast, competitor, dim=None):
        """
        Combine forecast and competitor if both are xarray. If competitor is None, this returns forecasts and vice
        versa.

        :param forecast: forecast data
        :param competitor: competitor data
        :param dim: dimension to concatenate along (default None -> use self.model_type_dim)
        """
        if dim is None:
            dim = self.model_type_dim
        try:
            return xr.concat([forecast, competitor], dim=dim)
        except (TypeError, AttributeError):
            return forecast if competitor is None else competitor

    def load_competitors(self, station_name: str, start, end) -> xr.DataArray:
        """
        Load all requested and available competitors for a given station. Forecasts must be available in the competitor
        path like `<competitor_path>/<competitor_name>/forecasts_<station_name>_test.nc`. The naming style is equal for
        all forecasts of MLAir, so that forecasts of a different experiment can easily be copied into the competitor
        path without any change.

        :param station_name: station indicator to load competitors for
        :param start: first time stamp of the interval the competitors are trimmed to
        :param end: last time stamp of the interval the competitors are trimmed to

        :return: a single xarray with all competing forecasts or None if no competitor could be loaded
        """
        competing_predictions = []
        for competitor_name in self.competitors:
            try:
                prediction = self._create_competitor_forecast(station_name, competitor_name, start, end)
                competing_predictions.append(prediction)
            except (FileNotFoundError, KeyError):
                logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.")
                continue
        return xr.concat(competing_predictions, self.model_type_dim) if len(competing_predictions) > 0 else None

    @staticmethod
    def create_full_time_dim(data, dim, sampling, start, end):
        """
        Ensure time dimension to be equidistant. Sometimes dates are missing if values have been dropped.

        :param data: data to put on the equidistant time axis
        :param dim: name of the time dimension
        :param sampling: either "daily" or "hourly", used to select the frequency of the time axis
        :param start: first time stamp of the new time axis
        :param end: last time stamp of the new time axis

        :return: new array on the full time axis, filled with the values of data
        """
        start_data = data.coords[dim].values[0]
        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
        _ind = pd.date_range(start, end, freq=freq)  # two steps required to include all hours of end interval
        # Build the left-closed interval [min, max + 1 day) by hand. The former keyword `closed="left"` is not
        # available in all pandas versions, whereas dropping the upper bound explicitly works for all of them.
        upper = _ind.max() + dt.timedelta(days=1)
        full_index = pd.date_range(_ind.min(), upper, freq=freq)
        full_index = full_index[full_index < upper]
        t = data.sel({dim: start_data}, drop=True)
        res = xr.DataArray(coords=[full_index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
        res = res.transpose(*data.dims)
        if data.shape == res.shape:
            res.loc[data.coords] = data
        else:
            _d = data.sel({dim: slice(start, end)})
            res.loc[_d.coords] = _d
        return res

    def _create_competitor_forecast(self, station_name: str, competitor_name: str, start, end) -> xr.DataArray:
        """
        Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station
        indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will
        raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either
        there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file.
        Forecast is trimmed on interval start and end of test subset.

        :param station_name: name of the station to load data for
        :param competitor_name: name of the model
        :param start: first time stamp of the test subset
        :param end: last time stamp of the test subset
        :return: the forecast of the given competitor
        """
        path = os.path.join(self.competitor_path, competitor_name)
        file = os.path.join(path, f"forecasts_{station_name}_test.nc")
        with xr.open_dataarray(file) as da:
            data = da.load()
        if self._forecast_indicator in data.coords[self.model_type_dim]:
            forecast = data.sel({self.model_type_dim: [self._forecast_indicator]})
            forecast.coords[self.model_type_dim] = [competitor_name]
        else:
            forecast = data.sel({self.model_type_dim: [competitor_name]})
        # limit forecast to time range of test subset
        return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end)

    @staticmethod
    def _correct_persi_name(competitors):
        """Replace the competitor name "Persistence" by "persi" and keep all other names untouched."""
        return ["persi" if x == "Persistence" else x for x in competitors]

    def _segment_data(self, data: xr.DataArray, x_model: str) -> xr.DataArray:
        """
        Segment data into bins.

        :param data: data to segment
        :param x_model: name of x dimension

        :return: segmented data
        """
        logging.debug("... segment data")
        # combine index and station to multi index
        data = data.stack(z=[self.index_dim, self.iter_dim])
        # replace multi index by simple position index (order is not relevant anymore)
        data.coords['z'] = range(len(data.coords['z']))
        # segment data of x_model into bins
        data_sel = data.sel({self.model_type_dim: x_model})
        data.loc[{self.model_type_dim: x_model}] = data_sel.to_pandas().T.apply(pd.cut, bins=self._bins,
                                                                                  labels=self._bins[1:]).T.values
        return data

    @staticmethod
    def _labels(plot_type: str, data_unit: str = "ppb") -> Tuple[str, str]:
        """
        Assign (x,y) labels to plots correctly, depending on like-base or cali-ref factorization.

        :param plot_type: type of plot, either `obs` or a model name
        :param data_unit: unit of data to add to labels (default ppb)

        :return: tuple with y and x labels
        """
        names = (f"forecasted concentration (in {data_unit})", f"observed concentration (in {data_unit})")
        if plot_type == "obs":
            return names
        else:
            return names[::-1]

    def _get_bins_from_rage_of_data(self) -> np.ndarray:
        """
        Get array of bins to use for quantiles.

        :return: integer range from 0 up to and including data's maximum (rounded up)
        """
        return np.arange(0, math.ceil(self._data.max().max()) + 1, 1).astype(int)

    def _create_quantile_panel(self, data: xr.DataArray, x_model: str, y_model: str) -> xr.DataArray:
        """
        Calculate quantiles.

        :param data: data to calculate quantiles
        :param x_model: name of x dimension
        :param y_model: name of y dimension

        :return: quantile panel with binned data
        """
        logging.debug("... create quantile panel")
        # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step)
        quantile_panel = xr.DataArray(
            np.full([data.ahead.shape[0], len(self._opts["q"]), self._bins[1:].shape[0]], np.nan),
            coords=[data.ahead, self._opts["q"], self._bins[1:]], dims=[self.ahead_dim, 'quantiles', 'categories'])
        # ensure that the coordinates are in the right order
        quantile_panel = quantile_panel.transpose(self.ahead_dim, 'quantiles', 'categories')
        # calculate for each bin of the x_model data the quantiles of the y_model data
        for category in self._bins[1:]:
            mask = (data.loc[x_model, ...] == category)
            quantile_panel.loc[..., category] = data.loc[y_model, ...].where(mask).quantile(self._opts["q"],
                                                                                            dim=['z']).T
        return quantile_panel

    @staticmethod
    def add_affix(affix: str) -> str:
        """
        Add additional information to plot name with leading underscore or add empty string if affix is empty.

        :param affix: string to add

        :return: affix with leading underscore or empty string.
        """
        return f"_{affix}" if len(affix) > 0 else ""

    def _prepare_plots(self, data: xr.DataArray, x_model: str, y_model: str) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Get segmented data and quantile panel.

        :param data: plot data
        :param x_model: name of x dimension
        :param y_model: name of y dimension

        :return: segmented data and quantile panel
        """
        segmented_data = self._segment_data(data, x_model)
        quantile_panel = self._create_quantile_panel(segmented_data, x_model, y_model)
        return segmented_data, quantile_panel

    def _plot(self):
        """Start plotting routines: overall plot and seasonal (if enabled)."""
        logging.info(f"start plotting {self.__class__.__name__}, scheduled number of plots: "
                     f"{(len(self._seasons) + 1) * (len(self.competitors) + 1) * 2}")
        if len(self._seasons) > 0:
            self._plot_seasons()
        self._plot_all()

    def _plot_seasons(self):
        """Create seasonal plots."""
        for model in [self._forecast_indicator, *self.competitors]:
            for season in self._seasons:
                # the season mask is identical for both factorizations, therefore calculate it only once
                season_mask = self._data[f"{self.index_dim}.season"] == season
                self._plot_base(data=self._data.where(season_mask),
                                x_model=model, y_model=self._obs_name, plot_name_affix="cali-ref",
                                season=season, model_name=model)
                self._plot_base(data=self._data.where(season_mask),
                                x_model=self._obs_name, y_model=model, plot_name_affix="like-base",
                                season=season, model_name=model)

    def _plot_all(self):
        """Plot overall conditional quantiles on full data."""
        for model in [self._forecast_indicator, *self.competitors]:
            self._plot_base(data=self._data, x_model=model, y_model=self._obs_name,
                            plot_name_affix="cali-ref", model_name=model)
            self._plot_base(data=self._data, x_model=self._obs_name, y_model=model,
                            plot_name_affix="like-base", model_name=model)

    @TimeTrackingWrapper
    def _plot_base(self, data: xr.DataArray, x_model: str, y_model: str, plot_name_affix: str, season: str = "",
                   model_name: str = ""):
        """
        Create conditional quantile plots.

        All lead times are written as separate pages into a single pdf file. The file and all figures are closed
        in any case, also if plotting fails.

        :param data: data which is used to create cond. quantile plot
        :param x_model: name of model on x axis (can also be obs)
        :param y_model: name of model on y axis (can also be obs)
        :param plot_name_affix: should be `cali-ref' or `like-base'
        :param season: name of the season the data belongs to, used for file name and title (default "")
        :param model_name: name of the model, used for file name and title (default "")
        """
        segmented_data, quantile_panel = self._prepare_plots(data, x_model, y_model)
        ylabel, xlabel = self._labels(x_model, self._opts["data_unit"])
        plot_name = f"{self.plot_name}{self.add_affix(season)}{self.add_affix(plot_name_affix)}" \
                    f"{self.add_affix(model_name)}.pdf"
        plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
        pdf_pages = PdfPages(plot_path)
        logging.debug(f"... plot path is {plot_path}")

        try:
            # create plot for each time step ahead
            y2_max = 0
            for iteration, d in enumerate(segmented_data.ahead):
                logging.debug(f"... plotting {d.values} time step(s) ahead")
                # plot smoothed lines with rolling mean
                smooth_data = quantile_panel.loc[d, ...].rolling(categories=self._rolling_window,
                                                                 center=True).mean().to_pandas().T
                ax = smooth_data.plot(style=self._opts["linetype"], color='black', legend=False)
                ax2 = ax.twinx()
                # add reference line
                ax.plot([0, self._bins.max()], [0, self._bins.max()], color='k', label='reference 1:1', linewidth=.8)
                # add histogram of the segmented data (x_model)
                handles, labels = ax.get_legend_handles_labels()
                segmented_data.loc[x_model, d, :].to_pandas().hist(bins=self._bins, ax=ax2, color='k', alpha=.3,
                                                                   grid=False, rwidth=1)
                # add legend
                plt.legend(handles[:3] + [handles[-1]], self._opts["legend"], loc='upper left', fontsize='large')
                # adjust limits and set labels
                ax.set(xlim=(0, self._bins.max()), ylim=(0, self._bins.max()))
                ax.set_xlabel(xlabel, fontsize='x-large')
                ax.tick_params(axis='x', which='major', labelsize=15)
                ax.set_ylabel(ylabel, fontsize='x-large')
                ax.tick_params(axis='y', which='major', labelsize=15)
                ax2.yaxis.label.set_color('gray')
                ax2.tick_params(axis='y', colors='gray')
                ax2.yaxis.labelpad = -15
                ax2.set_yscale('log')
                # upper limit of the histogram axis is taken from the first lead time and reused for all others
                if iteration == 0:
                    y2_max = ax2.get_ylim()[1] + 100
                ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5))
                ax2.set_ylabel(' sample size', fontsize='x-large')
                ax2.tick_params(axis='y', which='major', labelsize=15)
                # set title and save current figure
                sampling_letter = {"daily": "D", "hourly": "H"}.get(self._sampling)
                model_name = self.model_name if model_name == self._forecast_indicator else model_name
                title = f"{model_name} ({sampling_letter}{d.values}{f', {season}' if len(season) > 0 else ''})"
                plt.title(title)
                pdf_pages.savefig()
        finally:
            # close pdf file and all open figures / plots
            pdf_pages.close()
            plt.close('all')
@TimeTrackingWrapper
class PlotClimatologicalSkillScore(AbstractPlotClass):  # pragma: no cover
    """
    Create plot of climatological skill score after Murphy (1988) as box plot over all stations.

    A forecast time step (called "ahead") is separately shown to highlight the differences for each prediction time
    step. Either each single term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed
    (score_only=True, default). Y-axis is adjusted following the data and not hard coded. The plot is saved under
    plot_folder path with name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi.

    .. image:: ../../../../../_source/_plots/skill_score_clim_all_terms_CNN.png
        :width: 400

    .. image:: ../../../../../_source/_plots/skill_score_clim_CNN.png
        :width: 400

    :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms.
    :param plot_folder: path to save the plot (default: current directory)
    :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
    :param extra_name_tag: additional tag that can be included in the plot name (default "")
    :param model_name: architecture type to specify plot name (default "")

    """

    def __init__(self, data: Dict, plot_folder: str = ".", score_only: bool = True, extra_name_tag: str = "",
                 model_name: str = ""):
        """Prepare the data, draw the box plot and store it."""
        plot_name = f"skill_score_clim_{extra_name_tag}{model_name}"
        super().__init__(plot_folder, plot_name)
        self._labels = None
        self._data = self._prepare_data(data, score_only)
        self._plot(score_only)
        self._save()

    def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame:
        """
        Bring the skill scores into a flat data frame and derive the legend labels.

        The per-station arrays are merged along a new "station" axis. The legend labels are created from the values
        of the "ahead" coordinate. If only the scores are of interest, all terms except CASE I to IV are dropped.

        :param data: dictionary with station names as keys and 2D xarrays as values
        :param score_only: if true only scores of CASE I to IV are relevant
        :return: pre-processed data set
        """
        stacked = helpers.dict_to_xarray(data, "station")
        self._labels = [f"{step!s}d" for step in stacked.coords["ahead"].values]
        if score_only:
            cases = ["CASE I", "CASE II", "CASE III", "CASE IV"]
            stacked = stacked.loc[:, cases, :]
        frame = stacked.to_dataframe("data")
        return frame.reset_index(level=[0, 1, 2])

    def _label_add(self, score_only: bool):
        """
        Return the prefix of the y axis label.

        :param score_only: if false all terms are relevant, otherwise only CASE I to IV
        :return: empty string if score_only is enabled, otherwise the phrase "terms and "
        """
        if score_only:
            return ""
        return "terms and "

    def _plot(self, score_only, xlim=5):
        """
        Draw the climatological skill score as box plot grouped by term and lead time.

        :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms
        :param xlim: not used inside this method
        """
        figure, axis = plt.subplots()
        if not score_only:
            # use a larger canvas if all terms are shown
            figure.set_size_inches(11.7, 8.27)
        mean_style = {"markersize": 1, "markeredgecolor": "k"}
        outlier_style = {"marker": "."}
        sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=axis, whis=1.5, palette="Blues_d",
                    showmeans=True, meanprops=mean_style, flierprops=outlier_style)
        axis.axhline(y=0, color="grey", linewidth=.5)
        y_label = f"{self._label_add(score_only)}skill score"
        axis.set(ylabel=y_label, xlabel="", title="summary of all stations", ylim=self._lim())
        legend_handles = axis.get_legend_handles_labels()[0]
        axis.legend(legend_handles, self._labels)
        plt.tight_layout()

    def _lim(self) -> Tuple[float, float]:
        """
        Calculate axis limits from data (Can be used to set axis extend).

        Lower limit is the minimum of 0 and data's minimum (reduced by small subtrahend) and upper limit is data's
        maximum (increased by a small addend). Both limits are restricted to the interval [-5, 5].

        :return: lower and upper limit
        """
        limit = 5
        values = self._data["data"]
        smallest = helpers.float_round(values.min() - 0.1, 2)
        largest = helpers.float_round(values.max() + 0.1, 2)
        lower = np.max([-limit, np.min([0, smallest])])
        upper = np.min([limit, largest])
        return lower, upper
@TimeTrackingWrapper
class PlotCompetitiveSkillScore(AbstractPlotClass):  # pragma: no cover
    """
    Create competitive skill score plot.

    Create this plot for the given model setup and the reference models ordinary least squared ("ols") and the
    persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name
    skill_score_competitive_{model_setup}.pdf and resolution of 500dpi. In addition, a vertical version and full
    detail versions (horizontal and vertical) including all comparisons are created.

    .. image:: ../../../../../_source/_plots/skill_score_competitive.png
        :width: 400

    :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre-
        calculated comparisons for cnn, persistence and ols.
    :param plot_folder: path to save the plot (default: current directory)
    :param model_setup: architecture type (default "NN")

    """

    def __init__(self, data: Dict[str, pd.DataFrame], plot_folder=".", model_setup="NN"):
        """Initialise."""
        super().__init__(plot_folder, f"skill_score_competitive_{model_setup}")
        self._model_setup = model_setup
        self._labels = None
        self._data = self._prepare_data(helpers.remove_items(data, "total"))
        default_plot_name = self.plot_name
        # draw full detail plot
        self.plot_name = default_plot_name + "_full_detail"
        self._plot()
        self._save()
        # draw also a vertical full detail version
        self.plot_name = default_plot_name + "_full_detail_vertical"
        self._plot_vertical()
        self._save()
        # draw default plot with only model comparison
        self.plot_name = default_plot_name
        self._plot(single_model_comparison=True)
        self._save()
        # draw also a vertical version of the default plot
        self.plot_name = default_plot_name + "_vertical"
        self._plot_vertical(single_model_comparison=True)
        self._save()

    def _prepare_data(self, data: Dict[str, pd.DataFrame]) -> pd.DataFrame:
        """
        Reformat given data and create plot labels and introduce the dimensions stations and comparison.

        :param data: data frames with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the
            pre-calculated comparisons for cnn, persistence and ols. The keys of the dictionary are used as
            stations dimension.
        :return: processed data with the columns "comparison", "ahead" and "data"
        """
        data = pd.concat(data, axis=0)
        data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations")
        data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"})
        data = data.to_dataframe("data").unstack(level=1).swaplevel()
        data.columns = data.columns.levels[1]
        self._labels = [str(i) + "d" for i in data.index.levels[1].values]
        data = data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data")
        return data.astype({"comparison": str, "ahead": int, "data": float})

    def _plot(self, single_model_comparison=False):
        """
        Plot skill scores of the comparisons.

        :param single_model_comparison: if true, show only comparisons that include the own model (model_setup)
        """
        data = self._filter_comparisons(self._data) if single_model_comparison is True else self._data
        # figure size depends on the number of comparisons and the length of the longest tick label
        max_label_size = len(max(np.unique(data.comparison).tolist(), key=len))
        size = max([len(np.unique(data.comparison)), 6])
        fig, ax = plt.subplots(figsize=(size, 5 * max(0.8, max_label_size/20)))
        order = self._create_pseudo_order(data)
        sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d",
                    showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                    order=order)
        ax.axhline(y=0, color="grey", linewidth=.5)
        ax.set(ylabel="skill score", xlabel="competing models", title="summary of all stations", ylim=self._lim(data))
        handles, _ = ax.get_legend_handles_labels()
        plt.xticks(rotation=90)
        ax.legend(handles, self._labels)
        plt.tight_layout()

    def _plot_vertical(self, single_model_comparison=False):
        """
        Plot skill scores of the comparisons, but vertically aligned.

        :param single_model_comparison: if true, show only comparisons that include the own model (model_setup)
        """
        data = self._filter_comparisons(self._data) if single_model_comparison is True else self._data
        # figure size depends on the number of comparisons and the length of the longest tick label
        max_label_size = len(max(np.unique(data.comparison).tolist(), key=len))
        size = max([len(np.unique(data.comparison)), 6])
        fig, ax = plt.subplots(figsize=(5 * max(0.8, max_label_size/20), size))
        order = self._create_pseudo_order(data)
        sns.boxplot(y="comparison", x="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d",
                    showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                    order=order)
        ax.axvline(x=0, color="grey", linewidth=.5)
        ax.set(xlabel="skill score", ylabel="competing models", title="summary of all stations", xlim=self._lim(data))
        handles, _ = ax.get_legend_handles_labels()
        ax.legend(handles, self._labels)
        plt.tight_layout()

    def _create_pseudo_order(self, data):
        """Provide first predefined elements and append all remaining."""
        first_elements = [f"{self._model_setup} - persi", "ols - persi", f"{self._model_setup} - ols"]
        first_elements = list(filter(lambda x: x in data.comparison.tolist(), first_elements))
        # np.unique sorts the names, the index of the first occurrence is used to restore the intended order
        uniq, index = np.unique(first_elements + data.comparison.unique().tolist(), return_index=True)
        return uniq[index.argsort()]

    def _filter_comparisons(self, data):
        """Keep only comparisons whose name contains "<model_setup> - "."""
        filtered_headers = list(filter(lambda x: f"{self._model_setup} - " in x, data.comparison.unique()))
        return data[data.comparison.isin(filtered_headers)]

    @staticmethod
    def _lim(data) -> Tuple[float, float]:
        """
        Calculate axis limits from data (Can be used to set axis extend).

        Lower limit is the minimum of 0 and data's minimum (reduced by small subtrahend) and upper limit is data's
        maximum (increased by a small addend). Both limits are restricted to the interval [-5, 5].

        :param data: data frame with the skill scores stored in column "data"
        :return: lower and upper limit
        """
        limit = 5
        # select the score column by name, integer keys on a label based index are not supported by all pandas versions
        scores = data["data"]
        lower = np.max([-limit, np.min([0, helpers.float_round(scores.min(), 2) - 0.1])])
        upper = np.min([limit, helpers.float_round(scores.max(), 2) + 0.1])
        return lower, upper
@TimeTrackingWrapper
class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
    """
    Create plot of feature importance analysis.

    All plots are created and saved during initialisation. By passing a list `separate_vars` containing variable
    names, a second plot is created showing the `separate_vars` and the remaining variables side by side with
    different scaling.

    .. image:: ../../../../../_source/_plots/skill_score_bootstrap.png
        :width: 400

    .. image:: ../../../../../_source/_plots/skill_score_bootstrap_separated.png
        :width: 400

    """

    def __init__(self, data: Dict, plot_folder: str = ".", separate_vars: List = None, sampling: str = "daily",
                 ahead_dim: str = "ahead", bootstrap_type: str = None, bootstrap_method: str = None,
                 boot_dim: str = "boots", model_name: str = "NN", branch_names: list = None, ylim: tuple = None):
        """
        Set attributes and create plot.

        :param data: dictionary with station names as keys and xarrays as values. After stacking along a station
            dimension, the data is transposed to (station, `ahead_dim`, `boot_dim`, "boot_var").
        :param plot_folder: path to save the plot (default: current directory)
        :param separate_vars: variables to plot separated (default: ['o3'])
        :param sampling: type of sampling rate, should be either hourly or daily (default: "daily")
        :param ahead_dim: name of the ahead dimensions (default: "ahead")
        :param bootstrap_type: type of bootstrapping, used for the file name and the plot title ("singleinput" is
            displayed as "single input"; default: None)
        :param bootstrap_method: method of bootstrapping, used for the file name (default: None)
        :param boot_dim: name of the bootstrap realisation dimension (default: "boots")
        :param model_name: architecture type to specify plot name (default "NN")
        :param branch_names: names shown in the title instead of the branch index; only used if the number of names
            equals the number of branches (default: None)
        :param ylim: y-axis limits, either a single tuple or a list of two tuples (first entry for the separated
            variables, second entry for the remaining variables; default: None)
        """
        annotation = "_".join(s for s in ("", bootstrap_type, bootstrap_method) if s is not None)
        super().__init__(plot_folder, f"feature_importance_{model_name}{annotation}")
        if separate_vars is None:
            separate_vars = ['o3']
        self._labels = None
        self._x_name = "boot_var"
        self._ahead_dim = ahead_dim
        self._boot_dim = boot_dim
        self._boot_type = self._set_bootstrap_type(bootstrap_type)
        self._boot_method = self._set_bootstrap_method(bootstrap_method)
        self._number_of_bootstraps = 0
        self._branches_names = branch_names
        self._ylim = ylim

        self._data = self._prepare_data(data, sampling)
        self._set_title(model_name)
        # the separated plot is only created if at least one of the requested variables is part of the data
        has_separate_vars = len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0
        if "branch" in self._data.columns:
            plot_name = self.plot_name
            branches = self._data["branch"].unique()
            for branch in branches:
                self._set_title(model_name, branch, len(branches))
                self.plot_name = f"{plot_name}_{branch}"
                self._plot_and_save(branch=branch)
                if has_separate_vars:
                    self.plot_name += '_separated'
                    self._plot_and_save(branch=branch, separate_vars=separate_vars, bbox_inches='tight')
        else:
            self._plot_and_save()
            if has_separate_vars:
                self.plot_name += '_separated'
                self._plot_and_save(separate_vars=separate_vars, bbox_inches='tight')

    def _plot_and_save(self, branch=None, separate_vars=None, **save_kwargs):
        """Create and save a single plot; a ValueError is logged on info level instead of being raised."""
        try:
            self._plot(branch=branch, separate_vars=separate_vars)
            self._save(**save_kwargs)
        except ValueError as e:
            logging.info(f"Did not plot {self.plot_name} because of {e}")

    @staticmethod
    def _set_bootstrap_type(boot_type):
        """Translate the bootstrap type into its display name (unknown types are returned unchanged)."""
        return {"singleinput": "single input"}.get(boot_type, boot_type)

    def _set_title(self, model_name, branch=None, n_branches=None):
        """
        Compose the plot title and store it in `self._title`.

        The title names the model and the bootstrap type. The branch (name) and the number of bootstrap
        realisations (only if larger than 1) are appended in parentheses.

        :param model_name: name of the model shown in the first title line
        :param branch: branch identifier to mention in the title (default: None)
        :param n_branches: total number of branches, used to validate `branch_names` (default: None)
        """
        title_d = {"single input": "Single Inputs", "branch": "Input Branches", "variable": "Variables"}
        # fall back to the raw type (or a generic term if no type is given) instead of failing with a KeyError
        fallback = self._boot_type if self._boot_type is not None else "Features"
        base_title = f"{model_name}\nImportance of {title_d.get(self._boot_type, fallback)}"

        additional = []
        if branch is not None:
            additional.append(str(self._resolve_branch_name(branch, n_branches)))
        if self._number_of_bootstraps > 1:
            additional.append(f"n={self._number_of_bootstraps}")
        additional_title = ", ".join(additional)
        if len(additional_title) > 0:
            additional_title = f" ({additional_title})"
        self._title = base_title + additional_title

    def _resolve_branch_name(self, branch, n_branches):
        """Return the user defined name of a branch or the branch identifier itself if no valid name exists."""
        try:
            # explicit check instead of an assert statement, which would be skipped with python -O
            if n_branches != len(self._branches_names):
                return branch
            return self._branches_names[int(branch)]
        except (IndexError, TypeError, ValueError):
            return branch

    @staticmethod
    def _set_bootstrap_method(boot_method):
        """Translate the bootstrap method into its display name (unknown methods are returned unchanged)."""
        return {"zero_mean": "zero mean", "shuffle": "shuffled"}.get(boot_method, boot_method)

    def _prepare_data(self, data: Dict, sampling: str) -> pd.DataFrame:
        """
        Transform data to a plot friendly format.

        For bootstrap type "single input" the leading number tag of each variable name is moved to a new "branch"
        dimension. Otherwise, the number tag is removed from the variable names if possible. Also set plot labels
        depending on the lead time dimension and store the number of bootstrap realisations.

        :param data: dictionary with station names as keys and xarrays as values
        :param sampling: sampling rate of the data, used for the legend labels
        :return: pre-processed data set as data frame with all dimensions as columns and values in column "data"
        """
        station_dim = "station"
        data = helpers.dict_to_xarray(data, station_dim).sortby(self._x_name)
        data = data.transpose(station_dim, self._ahead_dim, self._boot_dim, self._x_name)
        if self._boot_type == "single input":
            x_values = data.coords[self._x_name].values
            number_tags = self._get_number_tag(x_values, split_by='_')
            new_boot_coords = self._return_vars_without_number_tag(x_values, split_by='_', keep=1, as_unique=True)
            try:
                # all branches share the same variables: split the variable axis into (branch, variable)
                values = data.values.reshape((*data.shape[:3], len(number_tags), len(new_boot_coords)))
                data = self._as_branch_array(values, data, station_dim, number_tags, new_boot_coords)
            except ValueError:
                # branches differ in their variables: build each branch on its own and concatenate afterwards
                data_coll = []
                for nr in number_tags:
                    # compare the full tag, a substring test would also match e.g. tag "10" for branch "1"
                    filtered_coords = [x for x in x_values if x.split("_")[0] == nr]
                    new_boot_coords = self._return_vars_without_number_tag(filtered_coords, split_by='_', keep=1,
                                                                           as_unique=True)
                    sel_data = data.sel({self._x_name: filtered_coords})
                    values = sel_data.values.reshape((*data.shape[:3], 1, len(new_boot_coords)))
                    data_coll.append(self._as_branch_array(values, data, station_dim, [nr], new_boot_coords))
                data = xr.concat(data_coll, "branch")
        else:
            try:
                new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values,
                                                                       split_by='_', keep=1)
                data = data.assign_coords({self._x_name: new_boot_coords})
            except NotImplementedError:
                pass  # keep the original variable names if the number tags cannot be removed
        _, sampling_letter = self._get_sampling(sampling, 1)
        sampling_letter = {"d": "D", "h": "H"}.get(sampling_letter, sampling_letter)
        self._labels = [sampling_letter + str(i) for i in data.coords[self._ahead_dim].values]
        if station_dim not in data.dims:
            data = data.expand_dims(station_dim)
        self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0]
        return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna()

    def _as_branch_array(self, values, reference, station_dim, branches, variables):
        """Wrap values in a DataArray with an additional "branch" dimension, reusing coordinates of reference."""
        coords = {station_dim: reference.coords[station_dim], self._x_name: variables, "branch": branches,
                  self._ahead_dim: reference.coords[self._ahead_dim],
                  self._boot_dim: reference.coords[self._boot_dim]}
        dims = [station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]
        return xr.DataArray(values, coords=coords, dims=dims)

    def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False):
        """
        Remove the leading number tag from all given names.

        :param values: names to split
        :param split_by: separator between number tag and name
        :param keep: position of the element to keep after splitting
        :param as_unique: return unique names if the number tags differ (default: False)
        :return: names without number tag
        :raises NotImplementedError: if the number tags differ and `as_unique` is not True
        """
        arr = np.array([v.split(split_by) for v in values])
        num = arr[:, 0]
        if arr.shape[keep] == 1:  # keep dim has only length 1, no number tags required
            return num
        new_val = arr[:, keep]
        if self._all_values_are_equal(num, axis=0):
            return new_val
        if as_unique is True:
            return np.unique(new_val)
        raise NotImplementedError

    @staticmethod
    def _get_number_tag(values, split_by):
        """Return the sorted unique leading elements of all names after splitting by `split_by`."""
        arr = np.array([v.split(split_by) for v in values])
        num = arr[:, 0]
        return np.unique(num).tolist()

    @staticmethod
    def _all_values_are_equal(arr, axis=0):
        """Return True if all entries along axis equal the first entry."""
        return bool(np.all(arr == arr[0], axis=axis))

    def _label_add(self, score_only: bool):
        """
        Add the phrase "terms and " if score_only is disabled or empty string (if score_only=True).

        :param score_only: if false all terms are relevant, otherwise only CASE I to IV
        :return: additional label
        """
        return "" if score_only else "terms and "

    def _plot(self, branch=None, separate_vars=None):
        """Plot feature importance skill scores, either for all variables or with `separate_vars` set apart."""
        if separate_vars is None:
            self._plot_all_variables(branch)
        else:
            self._plot_selected_variables(separate_vars, branch)

    def _plot_selected_variables(self, separate_vars: List, branch=None):
        """
        Plot `separate_vars` and all remaining variables side by side using individual y-axes.

        :param separate_vars: variables to show in the left panel
        :param branch: restrict data to this branch (default: None, use all data)
        :raises ValueError: if `separate_vars` or the remaining variables are empty or not part of the data
        """
        data = self._data if branch is None else self._data[self._data["branch"] == str(branch)]
        self.raise_error_if_vars_do_not_exist(data, separate_vars, self._x_name, name="separate_vars")
        all_variables = self._get_unique_values_from_column_of_df(data, self._x_name)
        remaining_vars = helpers.remove_items(all_variables, separate_vars)
        self.raise_error_if_vars_do_not_exist(data, remaining_vars, self._x_name, name="remaining_vars")
        data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name)
        data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name)

        # panel widths are proportional to the number of variables shown in each panel
        fig, ax = plt.subplots(nrows=1, ncols=2,
                               gridspec_kw={'width_ratios': [len(separate_vars), len(remaining_vars)]},
                               figsize=(len(remaining_vars), len(remaining_vars) / 2.))
        box_style = dict(whis=1.5, palette="Blues_d", showmeans=True,
                         meanprops={"markersize": 1, "markeredgecolor": "k"}, showfliers=False)

        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], width=.8,
                    **box_style)
        ax[0].set(ylabel="skill score", xlabel="")
        if self._ylim is not None:
            _ylim = self._ylim if isinstance(self._ylim, tuple) else self._ylim[0]
            ax[0].set(ylim=_ylim)

        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], **box_style)
        ax[1].set(ylabel="", xlabel="")
        ax[1].yaxis.tick_right()
        if self._ylim is not None and isinstance(self._ylim, list):
            ax[1].set(ylim=self._ylim[1])

        handles, _ = ax[1].get_legend_handles_labels()
        for sax in ax:
            matplotlib.pyplot.sca(sax)
            sax.axhline(y=0, color="grey", linewidth=.5)
            plt.xticks(rotation=45, ha='right')
            sax.legend_.remove()

        # single legend for both panels
        ax[1].legend(handles, self._labels, loc='lower center', ncol=len(handles) + 1, fontsize="medium")

        # the alignment is applied twice, as in the initial implementation
        self._align_yaxis(ax[0], ax[1])
        self._align_yaxis(ax[0], ax[1])
        plt.subplots_adjust(right=0.8)
        plt.title(self._title)

    @staticmethod
    def _align_yaxis(ax1, ax2):
        """
        Align zeros of the two axes, zooming them out by same ratio.

        This function is copy pasted from https://stackoverflow.com/a/41259922
        """
        axes = (ax1, ax2)
        extrema = [axis.get_ylim() for axis in axes]
        tops = [extr[1] / (extr[1] - extr[0]) for extr in extrema]
        # Ensure that plots (intervals) are ordered bottom to top:
        if tops[0] > tops[1]:
            axes, extrema, tops = [list(reversed(seq)) for seq in (axes, extrema, tops)]

        # How much would the plot overflow if we kept current zoom levels?
        tot_span = tops[1] + 1 - tops[0]

        b_new_t = extrema[0][0] + tot_span * (extrema[0][1] - extrema[0][0])
        t_new_b = extrema[1][1] - tot_span * (extrema[1][1] - extrema[1][0])
        axes[0].set_ylim(extrema[0][0], b_new_t)
        axes[1].set_ylim(t_new_b, extrema[1][1])

    @staticmethod
    def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame:
        """
        Select all rows of df whose `column_name` equals one of the variables, grouped in order of variables.

        :return: selected rows, or None if no variables are given
        """
        frames = [df.loc[df[column_name] == variable] for variable in variables]
        if not frames:
            return None
        return pd.concat(frames, axis=0)

    def raise_error_if_vars_do_not_exist(self, data, vars, column_name, name="separate_vars"):
        """
        Check that vars is not empty and that all entries exist in the given column of data.

        :raises ValueError: if `vars` is empty or at least one entry is not part of `data[column_name]`
        """
        if len(vars) == 0:
            msg = f"No variables are given for `{name}' to check in `self.data' "
            raise ValueError(msg)
        if not self._variables_exist_in_df(df=data, variables=vars, column_name=column_name):
            msg = f"At least one entry of `{name}' does not exist in `self.data' "
            raise ValueError(msg)

    @staticmethod
    def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List:
        """Return the unique values of a data frame column as list."""
        return list(df[column_name].unique())

    def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str):
        """Return True if all variables are present in the given column of df."""
        vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name))
        return set(variables).issubset(vars_in_df)

    def _plot_all_variables(self, branch=None):
        """
        Plot the skill scores of all variables in a single panel.

        :param branch: restrict data to this branch (default: None, use all data)
        """
        plot_data = self._data if branch is None else self._data[self._data["branch"] == str(branch)]
        mean_style = {"markersize": 1, "markeredgecolor": "k"}
        if self._boot_type == "branch":
            # scale the figure width with the number of boxes
            n_vars = len(plot_data[self._x_name].unique())
            fig, ax = plt.subplots(figsize=(0.5 + 2 / n_vars + n_vars, 4))
            sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.,
                        palette="Blues_d", showmeans=True, meanprops=mean_style, showfliers=False, width=0.8)
        else:
            fig, ax = plt.subplots()
            sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.5,
                        palette="Blues_d", showmeans=True, meanprops=mean_style, showfliers=False)
        ax.axhline(y=0, color="grey", linewidth=.5)

        if self._ylim is not None:
            if isinstance(self._ylim, tuple):
                _ylim = self._ylim
            else:
                # two limits are given (separated and remaining variables): use the range covering both
                _ylim = (min(self._ylim[0][0], self._ylim[1][0]), max(self._ylim[0][1], self._ylim[1][1]))
            ax.set(ylim=_ylim)

        if self._boot_type == "branch":
            plt.xticks()
        else:
            plt.xticks(rotation=45)
        ax.set(ylabel="skill score", xlabel="", title=self._title)
        handles, _ = ax.get_legend_handles_labels()
        ax.legend(handles, self._labels)
        plt.tight_layout()
@TimeTrackingWrapper
class PlotTimeSeries:  # pragma: no cover
    """
    Create time series plot.

    All stations are written to a single pdf file (one page per station) during initialisation.

    Currently, plots are under development and not well designed for any use in public.
    """

    def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = ".",
                 sampling="daily", model_name="nn", obs_name="obs", ahead_dim="ahead"):
        """
        Initialise and create the plot.

        :param stations: all stations to plot
        :param data_path: path, where the data is located
        :param name: name of the local files with a % placeholder that is filled with the station name
        :param window_lead_time: number of lead times to use for the color palette (default: None, taken from data)
        :param plot_folder: path to save the plot (default: current directory)
        :param sampling: sampling rate, either "daily" or "hourly" (default: "daily")
        :param model_name: entry of the `type` coordinate holding the model forecast (default: "nn")
        :param obs_name: entry of the `type` coordinate holding the observations (default: "obs")
        :param ahead_dim: name of the ahead dimensions (default: "ahead")
        """
        self._data_path = data_path
        self._data_name = name
        self._stations = stations
        self._model_name = model_name
        self._obs_name = obs_name
        self._ahead_dim = ahead_dim
        self._window_lead_time = self._get_window_lead_time(window_lead_time)
        self._sampling = self._get_sampling(sampling)
        self._plot(plot_folder)

    @staticmethod
    def _get_sampling(sampling):
        """Return "D" for daily and "h" for hourly sampling (None for any other value)."""
        if sampling == "daily":
            return "D"
        elif sampling == "hourly":
            return "h"

    def _get_window_lead_time(self, window_lead_time: int):
        """
        Extract the lead time from data and arguments.

        If window_lead_time is not given, extract this information from data itself by the number of ahead dimensions.
        If given, check if data supports the give length. If the number of ahead dimensions in data is lower than the
        given lead time, data's lead time is used.

        :param window_lead_time: lead time from arguments to validate
        :return: validated lead time, comes either from given argument or from data itself
        """
        ahead_steps = len(self._load_data(self._stations[0]).coords[self._ahead_dim])
        if window_lead_time is None:
            window_lead_time = ahead_steps
        return min(ahead_steps, window_lead_time)

    def _load_data(self, station):
        """Open the data file of a station and select the model forecast and the observations."""
        logging.debug(f"... preprocess station {station}")
        file_name = os.path.join(self._data_path, self._data_name % station)
        data = xr.open_dataarray(file_name)
        return data.sel(type=[self._model_name, self._obs_name])

    def _plot(self, plot_folder):
        """
        Create one pdf page per station with one panel per year (two panels per year for hourly sampling).

        Panels without any data or panels that cannot be drawn are removed from the page.
        """
        pdf_pages = self._create_pdf_pages(plot_folder)
        try:
            for station in self._stations:
                data = self._load_data(station)
                start, end = self._get_time_range(data)
                fig, axes, factor = self._create_subplots(start, end)
                nan_list = []
                for i_year in range(end - start + 1):
                    data_year = data.sel(index=f"{start + i_year}")
                    for i_half_of_year in range(factor):
                        pos = factor * i_year + i_half_of_year
                        try:
                            plot_data = self._create_plot_data(data_year, factor, i_half_of_year)
                            self._plot_obs(axes[pos], plot_data)
                            self._plot_ahead(axes[pos], plot_data)
                            if np.isnan(plot_data.values).all():
                                nan_list.append(pos)
                        except Exception as e:
                            # best effort: a failing panel is dropped, but the reason is kept in the debug log
                            logging.debug(f"... skip panel {pos} of station {station} because of {e}")
                            nan_list.append(pos)
                self._clean_up_axes(nan_list, axes, fig)
                self._save_page(station, pdf_pages)
                plt.close(fig)  # release the figure as soon as its page is written
        finally:
            # always finalise the pdf file, also if plotting of a station fails
            pdf_pages.close()
            plt.close('all')

    @staticmethod
    def _clean_up_axes(nan_list, axes, fig):
        """Remove all axes listed in nan_list from the figure."""
        for i in reversed(nan_list):
            fig.delaxes(axes[i])

    @staticmethod
    def _save_page(station, pdf_pages):
        """Add title and legend to the current figure and append it as new page to the pdf file."""
        plt.suptitle(station)
        plt.legend()
        plt.tight_layout()
        pdf_pages.savefig(dpi=500)

    @staticmethod
    def _create_plot_data(data, factor, running_index):
        """
        Mask data to a half of the year if more than one panel per year is used.

        :param data: data of a single year
        :param factor: number of panels per year
        :param running_index: index of the panel within the year (0: months before July, otherwise from July on)
        :return: data with all values outside the selected period masked (unchanged if factor is 1)
        """
        if factor > 1:
            if running_index == 0:
                data = data.where(data["index.month"] < 7)
            else:
                data = data.where(data["index.month"] >= 7)
        return data

    def _create_subplots(self, start, end):
        """Create a figure with one row per year (two rows per year for hourly sampling)."""
        factor = 1
        if self._sampling == "h":
            factor = 2
        f, ax = plt.subplots((end - start + 1) * factor, sharey=True, figsize=(50, 30), squeeze=False)
        return f, ax[:, 0], factor

    def _plot_ahead(self, ax, data):
        """Plot the forecast of each lead time, shifted by its lead time along the index dimension."""
        color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
        sampling_letter = {"d": "D", "h": "H"}.get(self._sampling, self._sampling)
        for ahead in data.coords[self._ahead_dim].values:
            plot_data = data.sel({"type": self._model_name, self._ahead_dim: ahead})
            plot_data = plot_data.drop(["type", self._ahead_dim]).squeeze().shift(index=ahead)
            label = f"{sampling_letter}{ahead}"
            ax.plot(plot_data, color=color[ahead - 1], label=label)

    def _plot_obs(self, ax, data):
        """Plot the observations taken from the first lead time, shifted by one step along the index dimension."""
        ahead = 1
        # select by the configured names, as done in _load_data, instead of hard coded "obs" / "ahead"
        obs_data = data.sel({"type": self._obs_name, self._ahead_dim: ahead}).shift(index=ahead)
        ax.plot(obs_data, color=colors.cnames["green"], label="obs")

    @staticmethod
    def _get_time_range(data):
        """Return the year of the smallest and of the largest index value of data."""
        index_values = data.index.values
        first_year = pd.to_datetime(min(index_values)).year
        last_year = pd.to_datetime(max(index_values)).year
        return first_year, last_year

    @staticmethod
    def _create_pdf_pages(plot_folder: str):
        """
        Create the multi-page pdf file `timeseries_plot.pdf` to store the plots locally.

        :param plot_folder: path to save the plot
        :return: pdf pages object
        """
        plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf')
        logging.debug(f"... save plot to {plot_name}")
        return matplotlib.backends.backend_pdf.PdfPages(plot_name)
@TimeTrackingWrapper
class PlotSeparationOfScales(AbstractPlotClass):  # pragma: no cover
    """
    Plot the first input of each data handler of a collection as grid of panels.

    Each panel shows the time dimension on the x-axis and the window dimension on the y-axis. Panels are arranged
    in columns along the filter dimension and in rows along the target dimension. One plot per data handler is
    saved in the sub-folder `separation_of_scales` of the plot folder.
    """

    def __init__(self, collection: DataCollection, plot_folder: str = ".", time_dim="datetime", window_dim="window",
                 filter_dim="filter", target_dim="variables"):
        """
        Initialise and create all plots.

        :param collection: data handlers to plot
        :param plot_folder: parent folder of the `separation_of_scales` sub-folder (default: current directory)
        :param time_dim: name of the dimension shown on the x-axis (default: "datetime")
        :param window_dim: name of the dimension shown on the y-axis (default: "window")
        :param filter_dim: name of the dimension used for the panel columns (default: "filter")
        :param target_dim: name of the dimension used for the panel rows (default: "variables")
        """
        sub_folder = os.path.join(plot_folder, "separation_of_scales")
        super().__init__(sub_folder, "separation_of_scales")
        self.time_dim, self.window_dim = time_dim, window_dim
        self.filter_dim, self.target_dim = filter_dim, target_dim
        self._plot(collection)

    def _plot(self, collection: DataCollection):
        """Create and save one plot per data handler, the station name is appended to the plot name."""
        base_name = self.plot_name
        panel_layout = {"x": self.time_dim, "y": self.window_dim, "col": self.filter_dim, "row": self.target_dim}
        for handler in collection:
            first_input = handler.get_X(as_numpy=False)[0]
            station = handler.id_class.station[0]
            station_data = first_input.sel(Stations=station)
            station_data.plot(robust=True, **panel_layout)
            self.plot_name = "_".join([f"{base_name}", f"{station}"])
            self._save()
@TimeTrackingWrapper
class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
    """
    Plot the distribution of a bootstrapped error measure for all model types.

    During initialisation, box plots (vertical / horizontal, with / without Mann-Whitney U test annotation,
    aggregated / per lead time / as panels) and kernel density plots are created and saved. If `apply_root` is
    enabled, all plots are created a second time for the square root of the error measure (file name tag `_sqrt`).
    """

    def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_type_dim: str = "type",
                 error_measure: str = "mse", error_unit: str = None, dim_name_boots: str = 'boots',
                 block_length: str = None, model_name: str = "NN", model_indicator: str = "nn",
                 ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = "", season_annotation: str = None,
                 apply_root: bool = True, plot_name="sample_uncertainty_from_bootstrap"):
        """
        Set attributes and create all plots.

        :param data: bootstrapped error values, must contain the dimension `model_type_dim`
        :param plot_folder: path to save the plots (default: current directory)
        :param model_type_dim: name of the dimension holding the model types (default: "type")
        :param error_measure: name of the error measure used as axis label (default: "mse")
        :param error_unit: unit of the error measure, appended to the axis label if given (default: None)
        :param dim_name_boots: name of the bootstrap dimension (default: "boots")
        :param block_length: block length, shown in the annotation box if given (default: None)
        :param model_name: display name of the own model (default: "NN")
        :param model_indicator: entry of `model_type_dim` that is replaced by `model_name` (default: "nn")
        :param ahead_dim: name of the ahead dimensions (default: "ahead")
        :param sampling: sampling rate ("daily" / "hourly"); if a tuple is given its second entry is used
        :param season_annotation: season name used in the file name and the annotation box (default: None)
        :param apply_root: additionally plot the square root of the error measure (default: True)
        :param plot_name: base name of all created plots (default: "sample_uncertainty_from_bootstrap")
        """
        super().__init__(plot_folder, plot_name)
        self.default_plot_name = self.plot_name
        self.model_type_dim = model_type_dim
        self.ahead_dim = ahead_dim
        self.error_measure = error_measure
        self.dim_name_boots = dim_name_boots
        self.error_unit = error_unit
        self.block_length = block_length
        self.model_name = model_name
        # declare all attributes before they are filled by prepare_data
        self._data_table = None
        self._n_boots = None
        self._factor = None
        _season = season_annotation or ""
        self.sampling = {"daily": "d", "hourly": "H"}.get(sampling[1] if isinstance(sampling, tuple) else sampling, "")
        data = self.rename_model_indicator(data, model_name, model_indicator)
        self.prepare_data(data)

        # create all combinations to plot (h/v, utest/notest, single/multi)
        agg_types = ["single", "multi", "panel"]
        variants = list(itertools.product(*[["v", "h"], [True, False], agg_types]))

        # plot raw metric (mse)
        self._plot_all_variants(variants, agg_types, tag="", season=_season)

        if apply_root is True:
            # plot root of metric (rmse)
            self._apply_root()
            self._plot_all_variants(variants, agg_types, tag="_sqrt", season=_season)

        # NOTE(review): the prepared data is dropped again after plotting, as in the initial implementation
        # (presumably to release memory) - confirm that no caller relies on these attributes afterwards
        self._data_table = None
        self._n_boots = None
        self._factor = None

    def _plot_all_variants(self, variants, agg_types, tag, season):
        """Create all box plot variants and afterwards one kernel density plot per aggregation type."""
        for orientation, utest, agg_type in variants:
            self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag=tag, season=season)
        # the kde plot only depends on the aggregation type, so it is created once per type
        for agg_type in agg_types:
            self._plot_kde(agg_type=agg_type, tag=tag, season=season)

    @property
    def get_asteriks_from_mann_whitney_u_result(self):
        """Return the asterisk representation of the Mann-Whitney U test p-values (last row of the test result)."""
        return represent_p_values_as_asteriks(mann_whitney_u_test(data=self._data_table,
                                                                  reference_col_name=self.model_name,
                                                                  axis=0, alternative="two-sided").iloc[-1])

    def rename_model_indicator(self, data, model_name, model_indicator):
        """Replace `model_indicator` by `model_name` in the model type coordinate (data is modified in place)."""
        data.coords[self.model_type_dim] = [{model_indicator: model_name}.get(n, n)
                                            for n in data.coords[self.model_type_dim].values]
        return data

    def prepare_data(self, data: xr.DataArray):
        """
        Transform data into a table with one column per model type, columns sorted by their mean value.

        Also store the number of lead times (1 if data has no ahead dimension) and the number of bootstrap
        realisations (number of rows divided by the number of lead times).
        """
        data_table = data.to_dataframe(self.model_type_dim).unstack()
        factor = len(data.coords[self.ahead_dim]) if self.ahead_dim in data.dims else 1
        self._data_table = data_table[data_table.mean().sort_values().index].droplevel(0, axis=1)
        self._n_boots = int(self._data_table.shape[0] / factor)
        self._factor = factor

    def _apply_root(self):
        """Replace the data by its square root and adjust error measure name and unit accordingly."""
        self._data_table = np.sqrt(self._data_table)
        self.error_measure = f"root {self.error_measure}"
        if self.error_unit is not None:  # the unit is optional, nothing to adjust if it is not given
            self.error_unit = self.error_unit.replace("$^2$", "")

    @staticmethod
    def _season_suffix(season):
        """Return the file name suffix for a season (empty string if no season is given)."""
        return {"": ""}.get(season, f"_{season}")

    def _error_label(self):
        """Return the axis label consisting of the error measure and, if available, its unit."""
        if self.error_unit is None:
            return self.error_measure
        return f"{self.error_measure} (in {self.error_unit})"

    def _annotation_text(self, season):
        """Return the annotation text: number of bootstraps, preceded by block length and season if given."""
        text = f"n={self._n_boots}"
        if self.block_length is not None:
            text = f"{self.block_length}, {text}"
        if len(season) > 0:
            text = f"{season}, {text}"
        return text

    def _plot_kde(self, agg_type="single", tag="", season=""):
        """
        Plot the kernel density estimate of the error measure for each model type.

        :param agg_type: "single" plots the average over all lead times in one panel, "panel" creates one panel
            per lead time (requires the ahead dimension), "multi" is not supported and skipped
        :param tag: additional tag for the file name (default: "")
        :param season: season used for the file name and the annotation box (default: "")
        """
        self.plot_name = self.default_plot_name + "_kde" + "_" + agg_type + tag + self._season_suffix(season)
        data_table = self._data_table
        if agg_type == "multi":
            return  # nothing to do
        if self.ahead_dim not in data_table.index.names and agg_type == "panel":
            return  # nothing to do
        error_label = self._error_label()
        sampling_letter = {"d": "D", "h": "H"}.get(self.sampling, self.sampling)
        text = self._annotation_text(season)
        if agg_type == "single":
            fig, ax = plt.subplots()
            if self.ahead_dim in data_table.index.names:
                data_table = data_table.groupby(level=0).mean()
            sns.kdeplot(data=data_table, ax=ax)
            ylims = list(ax.get_ylim())
            ax.set_ylim([ylims[0], ylims[1]*1.025])
            ax.set_xlabel(error_label)

            text_box = AnchoredText(text, frameon=True, loc="upper left", pad=0.5, bbox_to_anchor=(0., 1.0),
                                    bbox_transform=ax.transAxes)
            plt.setp(text_box.patch, edgecolor='k', facecolor='w')
            ax.add_artist(text_box)
            ax.get_legend().set_title(None)
            plt.tight_layout()
        else:
            g = sns.FacetGrid(data_table.stack(self.model_type_dim).reset_index(), **{"col": self.ahead_dim},
                              hue=self.model_type_dim, legend_out=True)
            g.map(sns.kdeplot, 0)
            g.add_legend(title="")
            fig = plt.gcf()
            _labels = [sampling_letter + str(i) for i in data_table.index.levels[1].values]
            for axi, title in zip(g.axes.flatten(), _labels):
                axi.set_title(title)
            for axi in g.axes.flatten():
                axi.set_xlabel(None)
            fig.supxlabel(error_label)
            text_box = AnchoredText(text, frameon=True, loc="upper right")
            plt.setp(text_box.patch, edgecolor='k', facecolor='w')
            g.axes.flatten()[0].add_artist(text_box)
        self._save()
        plt.close("all")

    def _plot(self, orientation: str = "v", apply_u_test: bool = False, agg_type="single", tag="", season=""):
        """
        Plot the error measure of all model types as box plot.

        Nothing is plotted for the combinations u-test with "multi" or "panel" and for "multi" / "panel" if the data
        has no ahead dimension.

        :param orientation: "v" for vertical or "h" for horizontal boxes (default: "v")
        :param apply_u_test: add significance bars of the Mann-Whitney U test (default: False)
        :param agg_type: "single" averages over all lead times, "multi" shows lead times as hue, "panel" creates
            one panel per lead time (default: "single")
        :param tag: additional tag for the file name (default: "")
        :param season: season used for the file name and the annotation box (default: "")
        """
        self.plot_name = self.default_plot_name + {"v": "_vertical", "h": "_horizontal"}[orientation] + \
            {True: "_u_test", False: ""}[apply_u_test] + "_" + agg_type + tag + self._season_suffix(season)
        if apply_u_test is True and agg_type == "multi":
            return  # not implemented
        data_table = self._data_table
        if self.ahead_dim not in data_table.index.names and agg_type in ["multi", "panel"]:
            return  # nothing to do
        if apply_u_test is True and agg_type == "panel":
            return  # nothing to do
        size = len(np.unique(data_table.columns))
        asteriks = self.get_asteriks_from_mann_whitney_u_result if apply_u_test is True else None
        sampling_letter = {"d": "D", "h": "H"}.get(self.sampling, self.sampling)
        if orientation == "v":
            figsize, width = (size, 5), 0.4
        elif orientation == "h":
            if agg_type == "multi":
                size *= np.sqrt(len(data_table.index.unique(self.ahead_dim)))
                # NOTE(review): minimum size assumed to belong to the multi case only - confirm indentation
                size = max(size, 8)
            figsize, width = (7, (1+.5*size)), 0.65
        else:
            raise ValueError(f"orientation must be `v' or `h' but is: {orientation}")

        flier_style = {"marker": "o", "markerfacecolor": "black", "markeredgecolor": "none", "markersize": 3}
        if agg_type in ("single", "multi"):
            # panel plots bring their own figure (FacetGrid), so only create a figure where it is used
            fig, ax = plt.subplots(figsize=figsize)
        if agg_type == "single":
            if self.ahead_dim in data_table.index.names:
                data_table = data_table.groupby(level=0).mean()
            sns.boxplot(data=data_table, ax=ax, whis=1.5, color="white",
                        showmeans=True, meanprops={"markersize": 6, "markeredgecolor": "k"},
                        flierprops=flier_style,
                        boxprops={'facecolor': 'none', 'edgecolor': 'k'}, width=width, orient=orientation)
        elif agg_type == "multi":
            color_palette = sns.color_palette("Blues_d", self._factor).as_hex()
            xy = {"x": self.model_type_dim, "y": 0} if orientation == "v" else {"x": 0, "y": self.model_type_dim}
            sns.boxplot(data=data_table.stack(self.model_type_dim).reset_index(), ax=ax, whis=1.5,
                        palette=color_palette, showmeans=True,
                        meanprops={"markersize": 6, "markeredgecolor": "k", "markerfacecolor": "white"},
                        flierprops=flier_style,
                        boxprops={'edgecolor': 'k'}, width=.8, orient=orientation, **xy, hue=self.ahead_dim)

            _labels = [sampling_letter + str(i) for i in data_table.index.levels[1].values]
            handles, _ = ax.get_legend_handles_labels()
            ax.legend(handles, _labels)
        else:
            xy = (self.model_type_dim, 0) if orientation == "v" else (0, self.model_type_dim)
            col_or_row = {"col": self.ahead_dim} if orientation == "v" else {"row": self.ahead_dim}
            aspect = figsize[0] / figsize[1]
            height = figsize[1] * 0.8
            ax = sns.FacetGrid(data_table.stack(self.model_type_dim).reset_index(), **col_or_row, aspect=aspect,
                               height=height)
            ax.map(sns.boxplot, *xy, whis=1.5, color="white", showmeans=True, order=data_table.mean().index.to_list(),
                   meanprops={"markersize": 6, "markeredgecolor": "k"},
                   flierprops=flier_style,
                   boxprops={'facecolor': 'none', 'edgecolor': 'k'}, width=width, orient=orientation)

            _labels = [sampling_letter + str(i) for i in data_table.index.levels[1].values]
            for axi, title in zip(ax.axes.flatten(), _labels):
                axi.set_title(title)
                plt.setp(axi.lines, color='k')

        error_label = self._error_label()
        text = self._annotation_text(season)
        if agg_type == "panel":
            if orientation == "v":
                for axi in ax.axes.flatten():
                    axi.set_xlabel(None)
                    axi.set_xticklabels(axi.get_xticklabels(), rotation=45)
                ax.set_ylabels(error_label)
                loc = "upper left"
            else:
                for axi in ax.axes.flatten():
                    axi.set_ylabel(None)
                ax.set_xlabels(error_label)
                loc = "upper right"
            text_box = AnchoredText(text, frameon=True, loc=loc)
            plt.setp(text_box.patch, edgecolor='k', facecolor='w')
            ax.axes.flatten()[0].add_artist(text_box)
        else:
            if apply_u_test:
                ax = self.set_significance_bars(asteriks, ax, data_table, orientation)
            if orientation == "v":
                # add some head room above the boxes
                ylims = list(ax.get_ylim())
                ax.set_ylim([ylims[0], ylims[1]*1.025])
                ax.set_ylabel(error_label)
                ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
                ax.set_xlabel(None)
            else:
                ax.set_xlabel(error_label)
                xlims = list(ax.get_xlim())
                ax.set_xlim([xlims[0], xlims[1] * 1.015])
                ax.set_ylabel(None)
            # place the annotation box on top of the upper left corner of the axes
            text_box = AnchoredText(text, frameon=True, loc="lower left", pad=0.5, bbox_to_anchor=(0., 1.0),
                                    bbox_transform=ax.transAxes)
            plt.setp(text_box.patch, edgecolor='k', facecolor='w')
            ax.add_artist(text_box)
            plt.setp(ax.lines, color='k')
        plt.tight_layout()
        self._save()
        plt.close("all")

    def set_significance_bars(self, asteriks, ax, data_table, orientation):
        """
        Draw a bracket with the significance level between the own model and each other model type.

        Each bracket is placed beyond the largest value of both models and beyond the previous bracket.

        :param asteriks: significance level per model type, index must contain `self.model_name`
        :param ax: axes to draw on
        :param data_table: plotted data with one column per model type
        :param orientation: orientation of the box plot, either "v" or "h"
        :return: the given axes
        """
        p1 = list(asteriks.index).index(self.model_name)
        q_prev = 0.
        factor = 0.025
        h = 0.01 * data_table.max().max()  # length of the bracket ends
        for p2, ast in enumerate(asteriks):
            if p2 == p1:
                continue  # no comparison of the model with itself
            q = data_table[[self.model_name, data_table.columns[p2]]].max().max()
            q = max(q, q_prev) * (1 + factor)
            if abs(q - q_prev) < q * factor:
                q = q * (1 + factor)
            if orientation == "h":
                ax.plot([q, q + h, q + h, q], [p1, p1, p2, p2], c="k")
                ax.text(q + h, (p1 + p2) * 0.5, ast, ha="left", va="center", color="k", rotation=-90)
            elif orientation == "v":
                ax.plot([p1, p1, p2, p2], [q, q + h, q + h, q], c="k")
                ax.text((p1 + p2) * 0.5, q + h, ast, ha="center", va="bottom", color="k")
            q_prev = q
        return ax
1450@TimeTrackingWrapper
1451class PlotTimeEvolutionMetric(AbstractPlotClass):
def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type", plot_folder=".",
             error_measure: str = "mse", error_unit: str = None, model_name: str = "NN",
             model_indicator: str = "nn", time_dim="index"):
    """
    Create the time evolution plots of an error measure.

    A detailed plot is created for each model type (averaged over the ahead dimension), followed by a summary
    plot with all model types (averaged over all remaining dimensions) and a line plot.

    :param data: error values, must contain the dimensions `model_type_dim`, `time_dim` and `ahead_dim`
    :param ahead_dim: name of the ahead dimensions (default: "ahead")
    :param model_type_dim: name of the dimension holding the model types (default: "type")
    :param plot_folder: path to save the plots (default: current directory)
    :param error_measure: name of the error measure used for the title (default: "mse")
    :param error_unit: unit of the error measure, appended to the title if given (default: None)
    :param model_name: display name of the own model, passed to `_prepare_data` (default: "NN")
    :param model_indicator: indicator of the own model, passed to `_prepare_data` (default: "nn")
    :param time_dim: name of the time dimension (default: "index")
    """
    super().__init__(plot_folder, "time_evolution_mse")
    # the unit is optional, the error measure itself is always part of the title
    unit_suffix = f" (in {error_unit})" if error_unit is not None else ""
    self.title = error_measure + unit_suffix
    plot_name = self.plot_name
    # value range of the detailed plots is taken from the unprepared data
    vmin = int(data.quantile(0.05))
    vmax = int(data.quantile(0.95))
    data = self._prepare_data(data, time_dim, model_type_dim, model_indicator, model_name)

    # detailed plot for each model type
    for t in data[model_type_dim]:
        # note: could be expanded to create plot per ahead step
        plot_data = data.sel({model_type_dim: t}).mean(ahead_dim).to_pandas()
        years = plot_data.columns.strftime("%Y").to_list()
        months = plot_data.columns.strftime("%b").to_list()
        plot_data.columns = plot_data.columns.strftime("%b %Y")
        self.plot_name = f"{plot_name}_{t.values}"
        self._plot(plot_data, years, months, vmin, vmax, str(t.values))

    # aggregated version with all model types
    remaining_dim = set(data.dims).difference((model_type_dim, time_dim))
    _data = data.mean(remaining_dim, skipna=True).transpose(model_type_dim, time_dim)
    vmin = int(_data.quantile(0.05))
    vmax = int(_data.quantile(0.95))
    plot_data = _data.to_pandas()
    years = plot_data.columns.strftime("%Y").to_list()
    months = plot_data.columns.strftime("%b").to_list()
    plot_data.columns = plot_data.columns.strftime("%b %Y")
    self.plot_name = f"{plot_name}_summary"
    self._plot(plot_data, years, months, vmin, vmax, None)

    # line plot version
    y_dim = "error"
    plot_data = data.to_dataset(name=y_dim).to_dataframe().reset_index()
    self.plot_name = f"{plot_name}_line_plot"
    self._plot_summary_line(plot_data, x_dim=time_dim, y_dim=y_dim, hue_dim=model_type_dim)
1491 @staticmethod
1492 def _find_nan_edge(data, time_dim):
1493 coll = []
1494 for i in data:
1495 if bool(i) is False:
1496 break
1497 else:
1498 coll.append(i[time_dim].values)
1499 return coll
1501 def _prepare_data(self, data, time_dim, model_type_dim, model_indicator, model_name):
1502 # remove nans at begin and end
1503 nan_locs = data.isnull().all(helpers.remove_items(data.dims, time_dim))
1504 nans_at_end = self._find_nan_edge(reversed(nan_locs), time_dim)
1505 nans_at_begin = self._find_nan_edge(nan_locs, time_dim)
1506 data = data.drop(nans_at_begin + nans_at_end, time_dim)
1507 # rename nn model
1508 data[model_type_dim] = [v if v != model_indicator else model_name for v in data[model_type_dim].data.tolist()]
1509 return data
1511 @staticmethod
1512 def _set_ticks(ax, years, months):
1513 from matplotlib.ticker import IndexLocator
1514 ax.xaxis.set_major_locator(IndexLocator(1, 0.5))
1515 locs = ax.get_xticks(minor=False).tolist()[:len(months)]
1516 ax.set_xticks(locs, minor=True)
1517 ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0)
1518 locs_major = []
1519 labels_major = []
1520 for l, major, minor in zip(locs, years, months):
1521 if minor == "Jan":
1522 locs_major.append(l + 0.001)
1523 labels_major.append(major)
1524 if len(locs_major) == 0: # in case there is less than a year and no Jan included
1525 locs_major = locs[0] + 0.001
1526 labels_major = years[0]
1527 ax.set_xticks(locs_major)
1528 ax.set_xticklabels(labels_major, minor=False, rotation=0)
1529 ax.tick_params(axis="x", which="major", pad=15)
1531 @staticmethod
1532 def _aspect_cbar(val):
1533 return min(max(1.25 * val + 7.5, 5), 30)
1535 def _plot(self, data, years, months, vmin=None, vmax=None, subtitle=None):
1536 fig, ax = plt.subplots(figsize=(max(data.shape[1] / 6, 12), max(data.shape[0] / 3.5, 2)))
1537 data.sort_index(inplace=True)
1538 sns.heatmap(data, linewidths=1, cmap="coolwarm", ax=ax, vmin=vmin, vmax=vmax,
1539 cbar_kws={"aspect": self._aspect_cbar(data.shape[0])})
1540 # or cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm",
1541 # square=True
1542 self._set_ticks(ax, years, months)
1543 ax.set(xlabel=None, ylabel=None, title=self.title if subtitle is None else f"{subtitle}\n{self.title}")
1544 plt.tight_layout()
1545 self._save()
1547 def _plot_summary_line(self, data, x_dim, y_dim, hue_dim):
1548 data[x_dim] = pd.to_datetime(data[x_dim].dt.strftime('%Y-%m')) #???
1549 n = len(data[hue_dim].unique())
1550 ax = sns.lineplot(data=data, x=x_dim, y=y_dim, hue=hue_dim, errorbar=("pi", 50),
1551 palette=sns.color_palette()[:n], style=hue_dim, dashes=False, markers=["X"]*n)
1552 ax.set(xlabel=None, ylabel=self.title)
1553 ax.get_legend().set_title(None)
1554 ax.xaxis.set_major_locator(mdates.YearLocator())
1555 ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%Y'))
1556 ax.xaxis.set_minor_locator(mdates.MonthLocator(bymonth=range(1, 13, 3)))
1557 ax.xaxis.set_minor_formatter(mdates.DateFormatter('%b'))
1558 ax.margins(x=0)
1559 plt.tight_layout()
1560 self._save()
@TimeTrackingWrapper
class PlotSeasonalMSEStack(AbstractPlotClass):
    """
    Show the contribution of each season to the overall error as stacked bar plots.

    Two sets of plots are created on initialisation: one from the error data passed in and one from the uncertainty
    estimate files found in ``data_path``. Each set consists of a horizontal and a vertical version, both with one
    panel per lead time and averaged over all lead times.
    """

    def __init__(self, data, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead",
                 sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$", time_dim="index",
                 model_type_dim: str = "type", model_name: str = "NN", model_indicator: str = "nn",):
        """Set attributes and create plot."""
        super().__init__(plot_folder, "seasonal_mse_stack_plot")
        self._data_path = data_path
        self.season_dim = "season"
        self.time_dim = time_dim
        self.ahead_dim = ahead_dim
        self.error_unit = error_unit
        self.error_measure = error_measure
        self.dim_order = [self.season_dim, ahead_dim, model_type_dim]

        # mse from monthly blocks
        self.plot_name_orig = "seasonal_mse_stack_plot"
        self._data = self._prepare_data(data)
        self._plot_all_variants(ahead_dim, sampling)

        # mse from resampling
        self.plot_name_orig = "seasonal_mse_from_uncertainty_stack_plot"
        self._data = self._prepare_data_from_uncertainty(boot_dim, data_path, model_type_dim, model_indicator,
                                                         model_name)
        self._plot_all_variants(ahead_dim, sampling)

    def _plot_all_variants(self, ahead_dim, sampling):
        """Create and save the plot for each combination of orientation and lead time splitting."""
        for orientation in ["horizontal", "vertical"]:
            for split_ahead in [True, False]:
                self._plot(ahead_dim, split_ahead, sampling, orientation)
                self._save(bbox_inches="tight")

    def _prepare_data(self, data):
        """
        Distribute the temporal mean of ``data`` among the seasons proportionally to the seasonal means.

        Dimensions that are not part of ``self.dim_order`` are averaged out and the result is sorted by its sum over
        seasons and lead times.
        """
        season_mean = data.groupby(f"{self.time_dim}.{self.season_dim}").mean()
        total_mean = data.mean(self.time_dim)
        factor = season_mean / season_mean.sum(self.season_dim)
        season_share = (total_mean * factor).reindex({self.season_dim: ["DJF", "MAM", "JJA", "SON"]})
        season_share = season_share.mean(set(season_share.dims).difference(self.dim_order))
        return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order)

    def _prepare_data_from_uncertainty(self, boot_dim, data_path, model_type_dim, model_indicator, model_name):
        """
        Calculate the seasonal shares from the uncertainty estimate files stored in ``data_path``.

        The file ``uncertainty_estimate_raw_results.nc`` provides the total error, the files with the suffixes
        ``_DJF``, ``_MAM``, ``_JJA`` and ``_SON`` provide the seasonal errors. Each of them is averaged over
        ``boot_dim``. The total is then distributed proportionally to the seasonal values.
        """
        season_dim = self.season_dim
        mean = {}
        for season in ["total", "DJF", "MAM", "JJA", "SON"]:
            if season == "total":
                file_name = "uncertainty_estimate_raw_results.nc"
            else:
                file_name = f"uncertainty_estimate_raw_results_{season}.nc"
            with xr.open_dataarray(os.path.join(data_path, file_name)) as d:
                # reduce and load while the file is open so that no lazily loaded data is accessed after closing
                mean[season] = d.mean(boot_dim).load()
        xr_data = xr.Dataset(mean).to_array(season_dim)
        xr_data[model_type_dim] = [v if v != model_indicator else model_name for v in xr_data[model_type_dim].values]
        xr_season = xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]})
        factor = xr_season / xr_season.sum(season_dim)
        season_share = xr_data.sel({season_dim: "total"}) * factor
        return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order)

    @staticmethod
    def _set_bar_label(ax):
        """Label each segment of the stacked bars in ``ax`` with its rounded percentage of the whole bar."""
        values = [list(container.datavalues) for container in ax.containers]
        totals = {}
        for segment_values in values:
            totals = {i: totals.get(i, 0) + v for i, v in enumerate(segment_values)}
        for container, segment_values in zip(ax.containers, values):
            shares = [f"{round(100 * v / totals[i])}%" for i, v in enumerate(segment_values)]
            ax.bar_label(container, labels=shares, label_type='center')

    def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"):
        """
        Create the stacked bar plot from ``self._data`` and set ``self.plot_name`` accordingly.

        :param dim: name of the lead time dimension
        :param split_ahead: if True each lead time gets its own panel, if False the mean over ``dim`` is shown
        :param sampling: sampling of the data, it determines the letter in front of the lead time in panel titles
        :param orientation: ``"vertical"`` for upright bars, any other value creates horizontal bars
        """
        _, sampling_letter = self._get_sampling(sampling, 1)
        sampling_letter = {"d": "D", "h": "H"}.get(sampling_letter, sampling_letter)
        error_label = f"{self.error_measure} (in {self.error_unit})"
        if split_ahead is False:
            self.plot_name = self.plot_name_orig + "_total_" + orientation
            data = self._data.mean(dim)
            if orientation == "vertical":
                fig, ax = plt.subplots(1, 1)
                data.to_pandas().T.plot.bar(ax=ax, stacked=True, cmap="Dark2", legend=False)
                ax.xaxis.label.set_visible(False)
                ax.set_ylabel(error_label)
                self._set_bar_label(ax)
            else:
                m = data.to_pandas().T.shape[0]
                fig, ax = plt.subplots(1, 1, figsize=(6, m))
                data.to_pandas().T.plot.barh(ax=ax, stacked=True, cmap="Dark2", legend=False)
                ax.yaxis.label.set_visible(False)
                ax.set_xlabel(error_label)
                self._set_bar_label(ax)
            fig.legend(*ax.get_legend_handles_labels(), loc="upper center", ncol=4)
            fig.tight_layout(rect=[0, 0, 1, 0.9])
        else:
            self.plot_name = self.plot_name_orig + "_" + orientation
            data = self._data
            n = len(data.coords[dim])
            m = data.max(self.season_dim).shape
            if orientation == "vertical":
                # squeeze=False always returns an array of axes, also if there is only a single lead time
                fig, ax = plt.subplots(1, n, sharey=True, figsize=(np.prod(m) / 0.8, 5), squeeze=False)
                ax = ax.ravel()
                for i, sel in enumerate(data.coords[dim].values):
                    data.sel({dim: sel}).to_pandas().T.plot.bar(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
                    label = sampling_letter + str(sel)
                    ax[i].set_title(label)
                    ax[i].xaxis.label.set_visible(False)
                    self._set_bar_label(ax[i])
                ax[0].set_ylabel(error_label)
                fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
                fig.tight_layout(rect=[0, 0, 1, 0.9])
            else:
                fig, ax = plt.subplots(n, 1, sharex=True, figsize=(6, np.prod(m) * 0.6), squeeze=False)
                ax = ax.ravel()
                for i, sel in enumerate(data.coords[dim].values):
                    data.sel({dim: sel}).to_pandas().T.plot.barh(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
                    label = sampling_letter + str(sel)
                    ax[i].set_title(label)
                    ax[i].yaxis.label.set_visible(False)
                    self._set_bar_label(ax[i])
                ax[-1].set_xlabel(error_label)
                fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
                fig.tight_layout(rect=[0, 0, 1, 0.95])
1686@TimeTrackingWrapper
1687class PlotErrorsOnMap(AbstractPlotClass):
1688 from mlair.plotting.data_insight_plotting import PlotStationMap
def __init__(self, data_gen, errors, error_metric, plot_folder: str = ".", iter_dim: str = "station",
             model_type_dim: str = "type", ahead_dim: str = "ahead", sampling: str = "daily"):
    """
    Create a multi-page pdf with one map of station errors per model type and, in addition, per lead time.

    :param data_gen: iterable of stations, each has to provide ``get_coordinates()`` and a string representation
    :param errors: error values, the code accesses the dimensions ``model_type_dim`` and ``ahead_dim``
    :param error_metric: key of the error metric, used for plot name, unit lookup and as column name
    :param plot_folder: path to save the plot
    :param iter_dim: not used in this method
    :param model_type_dim: name of the dimension that distinguishes the models
    :param ahead_dim: name of the lead time dimension
    :param sampling: sampling of the data, it determines the letter in front of the lead time in titles
    """
    super().__init__(plot_folder, f"map_plot_{error_metric}")
    plot_path = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
    # resolve everything that can fail before the pdf file is opened
    error_metric_units = helpers.statistics.get_error_metrics_units("ppb")[error_metric]
    error_metric_name = helpers.statistics.get_error_metrics_long_name()[error_metric]
    self.sampling = self._get_sampling(sampling, 1)[1]
    coords = self._extract_coords(data_gen)

    try:
        # the context manager closes the pdf file even if creating one of the pages fails
        with matplotlib.backends.backend_pdf.PdfPages(plot_path) as pdf_pages:
            for split_ahead in [False, True]:
                error_data = {}
                for model_type in errors.coords[model_type_dim].values:
                    error_data[model_type] = self._prepare_data(errors, model_type_dim, model_type, ahead_dim,
                                                                error_metric, split_ahead=split_ahead)
                # common limits so that all pages of this pass share the same colour scale
                limits = self._calculate_limits(error_data)
                for model_type, error in error_data.items():
                    if split_ahead is True:
                        for ahead in error.index.unique(1).to_list():
                            error_ahead = error.query(f"{ahead_dim} == {ahead}").droplevel(1)
                            plot_data = pd.concat([coords, error_ahead], axis=1)
                            self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type,
                                      limits, ahead=ahead)
                            pdf_pages.savefig()
                            plt.close()  # release the page's figure right away instead of collecting all of them
                    else:
                        plot_data = pd.concat([coords, error], axis=1)
                        self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type,
                                  limits)
                        pdf_pages.savefig()
                        plt.close()
    finally:
        plt.close('all')
@staticmethod
def _calculate_limits(data):
    """
    Return the rounded overall minimum and maximum of all tables in ``data``.

    :param data: mapping whose values provide ``min()`` and ``max()`` (data frames in the visible caller)
    :return: tuple of lower limit (rounded down) and upper limit (rounded up)
    """
    vmin, vmax = np.inf, -np.inf
    for v in data.values():
        # reduce twice to get a real scalar: converting a one-element array with float() is deprecated in numpy
        # and a table with more than one column could not be compared at all
        vmin = min(vmin, float(v.min().min()))
        vmax = max(vmax, float(v.max().max()))
    return relative_round(float(vmin), 2, floor=True), relative_round(float(vmax), 2, ceil=True)
@staticmethod
def _set_bounds(limits, ncolors, error_metric):
    """
    Calculate the sorted boundaries of the colour intervals.

    The metrics ``"ioa"`` and ``"mnmb"`` use the fixed ranges [0, 1] and [-2, 2], all other metrics use ``limits``.

    :param limits: lower and upper limit used for metrics without a fixed range
    :param ncolors: number of intervals the range is divided by to get the step width
    :param error_metric: key of the error metric
    :return: boundaries in ascending order
    """
    fixed_ranges = {"ioa": [0, 1], "mnmb": [-2, 2]}
    lower, upper = fixed_ranges.get(error_metric, limits)[:2]
    vmin = relative_round(lower, 2, floor=True)
    vmax = relative_round(upper, 2, ceil=True)
    step = relative_round((vmax - vmin) / ncolors, 1, ceil=True)
    # count downwards from the upper limit, then bring the values into ascending order
    return np.sort(np.arange(vmax, vmin, -step))
@staticmethod
def _get_colorpalette(error_metric):
    """
    Return the colour map for the given error metric.

    ``"mnmb"`` uses "coolwarm", ``"ioa"`` the reversed "coolwarm_r" and all other metrics use "magma_r".
    """
    if error_metric == "mnmb":
        return sns.mpl_palette("coolwarm", as_cmap=True)
    if error_metric == "ioa":
        return sns.mpl_palette("coolwarm_r", as_cmap=True)
    return sns.color_palette("magma_r", as_cmap=True)
def plot(self, plot_data, error_metric, error_long_name, error_units, model_type, limits, ahead=None):
    """
    Draw a map with one coloured marker per station on a new figure; the figure is neither saved nor closed here.

    :param plot_data: table with the columns ``"lon"``, ``"lat"`` and ``error_metric``
    :param error_metric: key of the error metric, selects the column to plot, the colour map and the bounds
    :param error_long_name: name of the error metric shown at the colour bar
    :param error_units: unit shown at the colour bar, omitted if None
    :param model_type: name of the model, used as title
    :param limits: lower and upper limit of the colour scale for metrics without a fixed range
    :param ahead: lead time, appended to the title if given
    """
    import cartopy.crs as ccrs
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

    figure = plt.figure(figsize=(10, 5))
    ax = figure.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    grid = ax.gridlines(xlocs=range(-180, 180, 5), ylocs=range(-90, 90, 2), draw_labels=True)
    grid.xformatter = LONGITUDE_FORMATTER
    grid.yformatter = LATITUDE_FORMATTER
    self._draw_background(ax)

    # discrete colour scale with up to 20 intervals
    cmap = self._get_colorpalette(error_metric)
    bounds = self._set_bounds(limits, 20, error_metric)
    norm = colors.BoundaryNorm(bounds, cmap.N, extend='both')
    markers = ax.scatter(plot_data["lon"], plot_data["lat"], c=plot_data[error_metric], marker='o', s=50,
                         transform=ccrs.PlateCarree(), zorder=2, cmap=cmap, norm=norm)
    if error_units is not None:
        cbar_label = f"{error_long_name} (in {error_units})"
    else:
        cbar_label = error_long_name
    plt.colorbar(markers, label=cbar_label)
    self._adjust_extent(ax)

    if ahead is None:
        title = model_type
    else:
        sampling_letter = {"d": "D", "h": "H"}.get(self.sampling, self.sampling)
        title = f"{model_type} ({sampling_letter}{ahead})"
    plt.title(title)
    plt.tight_layout()
@staticmethod
def _adjust_extent(ax):
    """
    Enlarge the map extent of ``ax`` by the same margin on all four sides.

    The margin is 5 divided by the smaller of the two side lengths of the current extent, but not more than 5.
    """
    import cartopy.crs as ccrs

    extent = ax.get_extent(crs=ccrs.PlateCarree())
    width = extent[1] - extent[0]
    height = extent[3] - extent[2]
    reference = 5
    margin = min(max(abs(reference / width), abs(reference / height)), 5)
    # lower bounds move outwards by subtracting, upper bounds by adding the margin
    ax.set_extent(extent + np.array([-1, 1, -1, 1]) * margin, crs=ccrs.PlateCarree())
1791 @staticmethod
1792 def _extract_coords(gen):
1793 coll = []
1794 for station in gen:
1795 coords = station.get_coordinates()
1796 coll.append((str(station), coords["lon"], coords["lat"]))
1797 return pd.DataFrame(coll, columns=["station", "lon", "lat"]).set_index("station")
1799 @staticmethod
1800 def _prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric, split_ahead=False):
1801 e = errors.sel({model_type_dim: model_type}, drop=True)
1802 if split_ahead is False:
1803 e = e.mean(ahead_dim)
1804 return e.to_dataframe(error_metric)
@staticmethod
def _draw_background(ax):
    """Add land, coastline, lakes, ocean, rivers and country borders in 50m resolution as map background."""

    import cartopy.feature as cfeature

    scale = "50m"
    ax.add_feature(cfeature.LAND.with_scale(scale))
    # NOTE(review): ``natural_earth_shp`` is possibly deprecated in recent cartopy releases - confirm
    ax.natural_earth_shp(resolution='50m')
    layers = (
        (cfeature.COASTLINE, {"edgecolor": 'black'}),
        (cfeature.LAKES, {}),
        (cfeature.OCEAN, {}),
        (cfeature.RIVERS, {}),
        (cfeature.BORDERS, {"facecolor": 'none', "edgecolor": 'black'}),
    )
    for feature, style in layers:
        ax.add_feature(feature.with_scale(scale), **style)
def _plot_individual(self):
    """
    Create one pdf per competitor with a skill score map on each page.

    NOTE(review): this method cannot run as it is. It uses the names ``df``, ``norm``, ``cmap``, ``ticks`` and
    ``mpl`` as well as ``self._ax`` without defining them, and it relies on the attributes ``reference_models``,
    ``skill_score_report_path`` and ``model_name`` that are not set anywhere in the visible part of the class.
    It looks like a leftover of another plot class - confirm whether it can be removed.
    """
    import cartopy.feature as cfeature
    import cartopy.crs as ccrs
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    for competitor in self.reference_models:
        # NOTE(review): file_name is built but never read in this method
        file_name = os.path.join(self.skill_score_report_path,
                                 f"error_report_skill_score_{self.model_name}_-_{competitor}.csv"
                                 )

        plot_path = os.path.join(os.path.abspath(self.plot_folder),
                                 f"{self.plot_name}_{self.model_name}_-_{competitor}.pdf")
        pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)

        # NOTE(review): the nesting of the statements below is reconstructed from a listing without indentation
        for i, lead_name in enumerate(df.columns[:-2]):  # last two are lat lon
            fig = plt.figure()
            self._ax.scatter(df.lon.values, df.lat.values, c=df[lead_name],
                             transform=ccrs.PlateCarree(),
                             norm=norm, cmap=cmap)
            self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
            self._gl.xformatter = LONGITUDE_FORMATTER
            self._gl.yformatter = LATITUDE_FORMATTER
            label = f"Skill Score: {lead_name.replace('-', 'vs.').replace('(t+', ' (').replace(')', 'd)')}"
            self._cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                                      orientation='horizontal', ticks=ticks,
                                      label=label,
                                      # cax=cax
                                      )

            # store the current figure as a new page
            pdf_pages.savefig()
        # close the pdf file and all open figures / plots
        pdf_pages.close()
        plt.close('all')
def _plot(self, ncol: int = 2):
    """
    Create one figure per competitor with a grid of skill score maps and a shared horizontal colour bar.

    NOTE(review): this method cannot run as it is. It uses ``mpl`` without a visible import of that name and it
    relies on ``self.reference_models``, ``self.skill_score_report_path``, ``self.model_name`` and
    ``self.open_data`` that are not defined anywhere in the visible part of the class. It looks like a leftover
    of another plot class - confirm whether it can be removed.

    :param ncol: number of columns of the grid of maps
    """
    import cartopy.feature as cfeature
    import cartopy.crs as ccrs
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
    import string
    base_plot_name = self.plot_name
    for competitor in self.reference_models:
        file_name = os.path.join(self.skill_score_report_path,
                                 f"error_report_skill_score_{self.model_name}_-_{competitor}.csv"
                                 )

        self.plot_name = f"{base_plot_name}_{self.model_name}_-_{competitor}"
        df = self.open_data(file_name)

        # one panel per column except for the last two, which are expected to hold lat and lon
        nrow = int(np.ceil(len(df.columns[:-2])/ncol))
        bounds = np.linspace(-1, 1, 100)
        cmap = mpl.cm.coolwarm
        norm = colors.BoundaryNorm(bounds, cmap.N, extend='both')
        ticks = np.arange(norm.vmin, norm.vmax + .2, .2)
        fig, self._axes = plt.subplots(nrows=nrow, ncols=ncol, subplot_kw={'projection': ccrs.PlateCarree()})
        # NOTE(review): if the grid has more panels than skill score columns, df.columns[i] reaches the lat / lon
        # columns or exceeds the number of columns - confirm
        for i, ax in enumerate(self._axes.reshape(-1)):  # last two are lat lon

            sub_name = f"({string.ascii_lowercase[i]})"
            lead_name = df.columns[i]
            ax.add_feature(cfeature.LAND.with_scale("50m"))
            ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
            ax.add_feature(cfeature.OCEAN.with_scale("50m"))
            ax.add_feature(cfeature.RIVERS.with_scale("50m"))
            ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')
            ax.scatter(df.lon.values, df.lat.values, c=df[lead_name],
                       marker='.',
                       transform=ccrs.PlateCarree(),
                       norm=norm, cmap=cmap)
            gl = ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
            gl.xformatter = LONGITUDE_FORMATTER
            gl.yformatter = LATITUDE_FORMATTER
            gl.top_labels = []
            gl.right_labels = []
            # panel caption: running letter and the part of the column name between "+" and the last character
            ax.text(0.01, 1.09, f'{sub_name} {lead_name.split("+")[1][:-1]}d',
                    verticalalignment='top', horizontalalignment='left',
                    transform=ax.transAxes,
                    color='black',
                    )
            label = f"Skill Score: {lead_name.replace('-', 'vs.').split('(')[0]}"

        # make room below the maps and place the colour bar there; the label of the last panel is used
        fig.subplots_adjust(bottom=0.18)
        cax = fig.add_axes([0.15, 0.1, 0.7, 0.02])
        self._cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                                  orientation='horizontal',
                                  ticks=ticks,
                                  label=label,
                                  cax=cax
                                  )

        fig.subplots_adjust(wspace=.001, hspace=.2)
        self._save(bbox_inches="tight")
        plt.close('all')
1919 @staticmethod
1920 def get_coords_from_index(name_string: str) -> List[float]:
1921 """
1923 :param name_string:
1924 :type name_string:
1925 :return: List of coords [lat, lon]
1926 :rtype: List
1927 """
1928 res = [float(frac.replace("_", ".")) for frac in name_string.split(sep="__")[1:]]
1929 return res