Coverage for mlair/plotting/data_insight_plotting.py: 100%
8 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
1"""Collection of plots to get more insight into data."""
2__author__ = "Lukas Leufen, Felix Kleinert"
3__date__ = '2021-04-13'
5from typing import List, Dict
6import dill
7import os
8import logging
9import multiprocessing
10import psutil
11import sys
13import numpy as np
14import pandas as pd
15import xarray as xr
16import seaborn as sns
17import matplotlib
18# matplotlib.use("Agg")
19from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
20from astropy.timeseries import LombScargle
22from mlair.data_handler import DataCollection
23from mlair.helpers import TimeTrackingWrapper, to_list, remove_items
24from mlair.plotting.abstract_plot_class import AbstractPlotClass
@TimeTrackingWrapper
class PlotStationMap(AbstractPlotClass):  # pragma: no cover
    """
    Plot geographical overview of all used stations as squares.

    Each element of `generators` is drawn as a separate data set. The marker face color defaults to the data set
    color registered for the collection's name and can be overwritten per element via plot options. Currently, there
    is only a white background, but this can be adjusted by loading locally stored topography data (not implemented
    yet). The plot is saved under plot_path with the name station_map.pdf

    .. image:: ../../../../../_source/_plots/station_map.png
        :width: 400
    """

    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
        """
        Set attributes and create plot.

        :param generators: iterable of data collections or of tuples `(data_collection, plot_options)`; may be None to
            draw the background only
        :param plot_folder: path to save the plot (default: current directory)
        :param plot_name: name of the plot passed to the parent class (default: station_map)
        """
        super().__init__(plot_folder, plot_name)
        self._ax = None
        self._gl = None
        self._plot(generators)
        self._save(bbox_inches="tight")

    def _draw_background(self):
        """Draw coastline, lakes, ocean, rivers and country borders as background on the map."""

        import cartopy.feature as cfeature

        self._ax.add_feature(cfeature.LAND.with_scale("50m"))
        # NOTE(review): `natural_earth_shp` is presumably deprecated in recent cartopy releases - confirm against the
        # pinned cartopy version before upgrading.
        self._ax.natural_earth_shp(resolution='50m')
        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
        self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
        self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')

    def _plot_stations(self, generators):
        """
        Loop over all elements in generators and plot the position of each contained station.

        For each element a legend entry with the collection's name and its number of stations is created. Supported
        plot options are `marker`, `ms`, `mec` and `mfc`.

        :param generators: iterable of data collections or of tuples `(data_collection, plot_options)`, or None
        """

        import cartopy.crs as ccrs
        if generators is None:
            return
        legend_elements = []
        default_colors = self.get_dataset_colors()
        for element in generators:
            data_collection, plot_opts = self._get_collection_and_opts(element)
            name = data_collection.name or "unknown"
            marker = plot_opts.get("marker", "s")
            ms = plot_opts.get("ms", 6)
            mec = plot_opts.get("mec", "k")
            mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
            # the legend handle uses the adjusted marker whereas the map itself uses the marker as given
            legend_elements.append(
                mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
                              label=f"{name} ({len(data_collection)})"))
            for station in data_collection:
                coords = station.get_coordinates()
                self._ax.plot(coords["lon"], coords["lat"], mfc=mfc, mec=mec, marker=marker, ms=ms,
                              transform=ccrs.PlateCarree())
        if legend_elements:
            self._ax.legend(handles=legend_elements, loc='best')

    @staticmethod
    def _adjust_marker(marker):
        """Map the integer marker codes 4 to 11 to a triangle symbol, return any other marker unchanged."""
        _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
        if isinstance(marker, int) and marker in _adjust:
            return _adjust[marker]
        return marker

    @staticmethod
    def _get_collection_and_opts(element):
        """
        Split an element of generators into data collection and plot options.

        :param element: either a data collection, a 1-tuple `(data_collection,)` or a tuple that already consists of
            data collection and plot options
        :return: data collection and plot options (empty dict if no options are given)
        """
        if not isinstance(element, tuple):
            return element, {}
        if len(element) == 1:
            return element[0], {}
        return element  # returned as is, the caller unpacks it into collection and options

    def _plot(self, generators: List):
        """
        Create the station map plot.

        Set figure and call all required sub-methods.

        :param generators: iterable of data collections or of tuples `(data_collection, plot_options)`, or None
        """

        import cartopy.crs as ccrs
        from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
        fig = plt.figure(figsize=(10, 5))
        self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
        self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
        self._gl.xformatter = LONGITUDE_FORMATTER
        self._gl.yformatter = LATITUDE_FORMATTER
        self._draw_background()
        self._plot_stations(generators)
        self._adjust_extent()
        plt.tight_layout()

    def _adjust_extent(self):
        """Enlarge the current map extent on all four sides by a margin that depends on the extent's size."""
        import cartopy.crs as ccrs

        def diff(arr):
            return arr[1] - arr[0], arr[3] - arr[2]

        def find_ratio(delta, reference=5):
            # a degenerate extent (zero width or height) is treated as infinitely small and therefore receives the
            # maximum margin instead of raising a division by zero
            ratios = [abs(reference / d) if d != 0 else float("inf") for d in delta]
            return min(max(ratios), 5)

        extent = self._ax.get_extent(crs=ccrs.PlateCarree())
        ratio = find_ratio(diff(extent))
        new_extent = np.asarray(extent) + np.array([-1, 1, -1, 1]) * ratio
        self._ax.set_extent(new_extent, crs=ccrs.PlateCarree())
@TimeTrackingWrapper
class PlotAvailability(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plot similar to Gantt plot.

    Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal
    resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
    colored bar or a blank space.

    You can set different colors to highlight subsets for example by providing different generators for the same index
    using different keys in the input dictionary.

    Note: each bar is surrounded by a small white box to highlight gaps in between. This can result in too long gaps
    in display, if a gap is only very short. Also this appears on a (fluent) transition from one to another subset.

    Calling this class will create three versions of the availability plot.

    1) Data availability for each element
    2) Data availability as summary over all elements (is there at least a single element for each time step)
    3) Combination of single and overall availability

    .. image:: ../../../../../_source/_plots/data_availability.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_summary.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_combined.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
        """
        Set attributes and create all three availability plots.

        :param generators: data collections to plot with the subset name as key
        :param plot_folder: path to save the plots (default: current directory)
        :param sampling: sampling of the data, translated by `_get_sampling` into the resampling frequency
        :param summary_name: label of the summary row
        :param time_dimension: name of the time dimension of the label data
        :param window_dimension: name of the window dimension of the label data
        """
        # create standard Gantt plot for all stations (currently in single pdf file with single page)
        super().__init__(plot_folder, "data_availability")
        self.time_dim = time_dimension
        self.window_dim = window_dimension
        self.sampling = self._get_sampling(sampling)[1]
        self.linewidth = 0.001 if self.sampling == 'h' else None
        plot_dict = self._prepare_data(generators)
        lgd = self._plot(plot_dict)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # create summary Gantt plot (is data in at least one station available)
        self.plot_name += "_summary"
        plot_dict_summary = self._summarise_data(generators, summary_name)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # combination of station and summary plot, last element is summary broken bar
        self.plot_name = "data_availability_combined"
        plot_dict_summary.update(plot_dict)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")

    def _get_availability(self, station):
        """Return boolean availability of the station's labels at window position 1 after resampling (mean)."""
        labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
        return labels.sel(**{self.window_dim: 1}).notnull()

    def _get_available_blocks(self, avail):
        """
        Convert a boolean availability series into a list of consecutive blocks with available data.

        :param avail: boolean data array along the time dimension
        :return: list of tuples `(first time step of block, number of time steps in block)`
        """
        # a new group starts whenever the availability flag differs from the previous time step
        group = (avail != avail.shift({self.time_dim: 1})).cumsum()
        plot_data = pd.DataFrame({"avail": avail.values, "group": group.values},
                                 index=avail.coords[self.time_dim].values)
        # use positional access (iloc): an integer key on a datetime indexed series is a deprecated fallback
        blocks = plot_data.groupby("group").apply(lambda x: (x["avail"].iloc[0], x.index[0], x.shape[0]))
        return [block[1:] for block in blocks if block[0]]

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Collect the blocks of available data for each station and subset.

        :param generators: data collections with the subset name as key
        :return: nested dictionary `{station: {subset: blocks}}`
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            for station in data_collection:
                blocks = self._get_available_blocks(self._get_availability(station))
                plt_dict.setdefault(str(station), {})[subset] = blocks
        return plt_dict

    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
        """
        Collect the blocks of time steps for which at least a single station of a subset has data.

        :param generators: data collections with the subset name as key
        :param summary_name: key of the summary row in the returned dictionary
        :return: nested dictionary `{summary_name: {subset: blocks}}`
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            all_data = None
            for station in data_collection:
                labels_bool = self._get_availability(station)
                if all_data is None:
                    all_data = labels_bool
                else:
                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
                    all_data = np.logical_or(tmp, labels_bool).combine_first(
                        all_data)  # apply logical on merge and fill missing with all_data
            # a subset without any station has no available data at all
            blocks = [] if all_data is None else self._get_available_blocks(all_data)
            plt_dict.setdefault(summary_name, {})[subset] = blocks
        return plt_dict

    def _plot(self, plt_dict):
        """
        Draw one row of broken bars for each entry of plt_dict.

        :param plt_dict: nested dictionary `{row label: {subset: blocks}}`
        :return: the created legend (required to save the figure with the legend included)
        """
        colors = self.get_dataset_colors()
        used_colors = []
        height = 0.8  # should be <= 1
        yticklabels = []
        number_of_stations = len(plt_dict)
        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
        for pos, (station, d) in enumerate(sorted(plt_dict.items(), reverse=True), start=1):
            for subset, color in colors.items():
                plt_data = d.get(subset)
                if plt_data is None:
                    continue
                if color not in used_colors:  # this is required for a proper legend creation
                    used_colors.append(color)
                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
            yticklabels.append(station)

        ax.set_ylim([height, number_of_stations + 1])
        ax.set_yticks(np.arange(number_of_stations) + 1 + height / 2)
        ax.set_yticklabels(yticklabels)
        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in used_colors]
        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
        return lgd
@TimeTrackingWrapper
class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plots as histogram.

    Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean).
    Calling this class creates two different types of histograms for each generator

    1) data_availability_histogram: datetime (xaxis) vs. number of stations with available data (yaxis)
    2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number
       of samples (yaxis)

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
                 subset_dim: str = 'DataSet', history_dim: str = 'window',
                 station_dim: str = 'Stations', ):
        """
        Set attributes, prepare data and create one plot for each allowed plot type.

        :param generators: data collections with the subset name as key
        :param plot_folder: path to save the plots (default: current directory)
        :param subset_dim: name of the new dimension that distinguishes the subsets
        :param history_dim: name of the window / history dimension of the input data
        :param station_dim: name of the new dimension that distinguishes the stations
        """
        super().__init__(plot_folder, "data_availability_histogram")

        self.subset_dim = subset_dim
        self.history_dim = history_dim
        self.station_dim = station_dim

        self.freq = None
        self.temporal_dim = None
        self.target_dim = None
        self._prepare_data(generators)

        for plt_type in self.allowed_plot_types:
            plot_name_tmp = self.plot_name
            self.plot_name += '_' + plt_type
            try:
                self._plot(plt_type=plt_type)
                self._save()
            finally:
                self.plot_name = plot_name_tmp  # always restore the base name for the next plot type

    def _set_dims_from_datahandler(self, data_handler):
        """Read time dimension, target dimension and sampling frequency from the data handler's `id_class`."""
        self.temporal_dim = data_handler.id_class.time_dim
        self.target_dim = data_handler.id_class.target_dim
        self.freq = self._get_sampling(data_handler.id_class.sampling)[1]

    @property
    def allowed_plot_types(self):
        """Names of all plot types supported by `_plot`."""
        return ['hist', 'hist_cum']

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Prepares data to be used by plot methods.

        Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim. Results
        are stored in `avail_data_amount`, `avail_data_cum_sum` and `dataset_time_interval`.
        """
        avail_data_time_sum = {}
        avail_data_station_sum = {}
        dataset_time_interval = {}
        for subset, generator in generators.items():
            avail_list = []
            for station in generator:
                self._set_dims_from_datahandler(data_handler=station)
                station_data_x = station.get_X(as_numpy=False)[0]
                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
                station_data_x = self._reduce_dims(station_data_x)
                avail_list.append(station_data_x.notnull())
            # NOTE(review): `notnull` is applied a second time on the concatenated boolean arrays, so a False entry
            # of a single station presumably counts as available here - confirm that this is intended.
            avail_data = xr.concat(avail_list, dim=self.station_dim).notnull()
            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
            )
        # dict views are converted to lists to pass plain sequences to xarray and pandas
        avail_data_amount = xr.concat(list(avail_data_time_sum.values()),
                                      pd.Index(list(avail_data_time_sum.keys()), name=self.subset_dim))
        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
        self.avail_data_cum_sum = xr.concat(list(avail_data_station_sum.values()),
                                            pd.Index(list(avail_data_station_sum.keys()), name=self.subset_dim))
        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
        self.dataset_time_interval = dataset_time_interval

    def _reduce_dims(self, dataset):
        """Select the first entry along each dimension that is neither temporal nor station dimension."""
        if len(dataset.dims) > 2:
            required = {self.temporal_dim, self.station_dim}
            unimportant = set(dataset.dims).difference(required)
            sel_dict = {un: dataset[un].values[0] for un in unimportant}
            dataset = dataset.loc[sel_dict]
        return dataset

    @staticmethod
    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
        """
        Return first and last coordinate value of xarray along dim_name.

        :param xarray: data array to inspect
        :param dim_name: name of the coordinate
        :param return_type: either 'as_tuple' or 'as_dict' (keys `first` and `last`)
        :raises TypeError: if xarray is not a xr.DataArray or if return_type is unknown
        """
        if not isinstance(xarray, xr.DataArray):
            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")
        first = xarray.coords[dim_name].values[0]
        last = xarray.coords[dim_name].values[-1]
        if return_type == 'as_tuple':
            return first, last
        elif return_type == 'as_dict':
            return {'first': first, 'last': last}
        else:
            raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")

    @staticmethod
    def _make_full_time_index(irregular_time_index, freq):
        """Create a regular time index with frequency freq from first to last entry of irregular_time_index."""
        return pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)

    def _plot(self, plt_type='hist', *args):
        """
        Dispatch to the plot method of given type.

        :param plt_type: either 'hist' or 'hist_cum'
        :raises ValueError: if plt_type is unknown
        """
        if plt_type == 'hist':
            self._plot_hist()
        elif plt_type == 'hist_cum':
            self._plot_hist_cum()
        else:
            raise ValueError(f"plt_type must be 'hist' or 'hist_cum', but is {plt_type}")

    def _plot_hist(self, *args):
        """Plot the number of stations with available data over time as filled step line for each subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        for subset, interval in self.dataset_time_interval.items():
            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
                                                       self.temporal_dim: slice(interval['first'], interval['last'])})

            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])

        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                         facecolor='white', framealpha=1, edgecolor='black')
        for lgd_line in lgd.get_lines():
            lgd_line.set_linewidth(4.0)
        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
        plt.title('')
        plt.ylabel('Number of samples')
        plt.tight_layout()

    def _plot_hist_cum(self, *args):
        """Plot a reversed cumulative histogram of the number of samples per station for each subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        n_bins = int(self.avail_data_cum_sum.max().values)
        bins = np.arange(0, n_bins + 1)
        # subsets are drawn in descending order of their maximum number of samples
        subset_max = self.avail_data_cum_sum.max(dim=self.station_dim)
        descending_subsets = subset_max.sortby(subset_max, ascending=False).coords[self.subset_dim].values

        for subset in descending_subsets:
            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
                                                                             bins=bins,
                                                                             label=subset,
                                                                             cumulative=-1,
                                                                             color=colors[subset],
                                                                             )

        fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                   facecolor='white', framealpha=1, edgecolor='black')
        plt.title('')
        plt.ylabel('Number of stations')
        plt.xlabel('Number of samples')
        plt.xlim((bins[0], bins[-1]))
        plt.tight_layout()
@TimeTrackingWrapper
class PlotDataMonthlyDistribution(AbstractPlotClass):  # pragma: no cover
    """Create a box plot of the observations grouped by month with separate boxes for each subset."""

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", variables_dim="variables",
                 time_dim="datetime", window_dim="window", target_var: str = "", target_var_unit: str = "ppb"):
        """
        Set attributes and create plot.

        :param generators: data collections with the subset name as key
        :param plot_folder: path to save the plot (default: current directory)
        :param variables_dim: name of the variables dimension
        :param time_dim: name of the time dimension
        :param window_dim: name of the window dimension
        :param target_var: short name of the target variable used for the y-axis label
        :param target_var_unit: unit of the target variable used for the y-axis label
        """
        super().__init__(plot_folder, "monthly_data_distribution")
        self.variables_dim = variables_dim
        self.time_dim = time_dim
        self.window_dim = window_dim
        self.coll_dim = "coll"
        self.subset_dim = "subset"
        self._data = self._prepare_data(generators)
        self._plot(target_var, target_var_unit)
        self._save()

    def _prepare_data(self, generators) -> pd.DataFrame:
        """
        Preprocess data required to plot.

        :param generators: data collections with the subset name as key
        :return: data frame with the columns `month`, `subset` and `values`
        """
        forecasts = []
        for set_type, generator in generators.items():
            forecasts_set = None
            forecasts_monthly = {}
            for gen in generator:
                data = gen.get_observation()
                data = gen.apply_transformation(data, inverse=True)
                data = data.clip(min=0).reset_coords(drop=True)
                # replace each time stamp by its month of year (1 to 12)
                new_index = data.coords[self.time_dim].values.astype("datetime64[M]").astype(int) % 12 + 1
                data = data.assign_coords({self.time_dim: new_index})
                forecasts_set = xr.concat([forecasts_set, data], self.time_dim) if forecasts_set is not None else data
            for month in set(forecasts_set.coords[self.time_dim].values):
                monthly_values = forecasts_set.sel({self.time_dim: month}).values
                forecasts_monthly[month] = np.concatenate((forecasts_monthly.get(month, []), monthly_values))
            forecasts_monthly = pd.DataFrame.from_dict(forecasts_monthly, orient="index")
            forecasts_monthly[self.coll_dim] = set_type
            forecasts.append(forecasts_monthly.set_index(self.coll_dim, append=True))
        forecasts = pd.concat(forecasts).stack().rename_axis(["month", "subset", "index"])
        forecasts = forecasts.to_frame(name="values").reset_index(level=[0, 1])
        return forecasts

    @staticmethod
    def _spell_out_chemical_concentrations(short_name: str, add_concentration: bool = False):
        """
        Return the long name of a chemical species.

        Unknown short names (including the empty default of `target_var`) are returned unchanged instead of raising a
        KeyError.

        :param short_name: short name of the species, e.g. `o3`
        :param add_concentration: append the word "concentration" to the returned name
        """
        short2long = {'o3': 'ozone', 'no': 'nitrogen oxide', 'no2': 'nitrogen dioxide', 'nox': 'nitrogen oxides'}
        _suffix = "" if add_concentration is False else " concentration"
        return f"{short2long.get(short_name, short_name)}{_suffix}"

    def _plot(self, target_var: str, target_var_unit: str):
        """
        Create a monthly grouped box plot over all stations with separate boxes for each subset.

        :param target_var: short name of the target variable to display on the plot's axis
        :param target_var_unit: unit of the target variable to display on the plot's axis
        """
        ax = sns.boxplot(data=self._data, x="month", y="values", hue="subset", whis=1.5,
                         palette=self.get_dataset_colors(), flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
                         meanprops={'markersize': 1, 'markeredgecolor': 'k'})
        ylabel = self._spell_out_chemical_concentrations(target_var)
        ax.set(xlabel='month', ylabel=f'dma8 {ylabel} (in {target_var_unit})')
        plt.tight_layout()
@TimeTrackingWrapper
class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Plot histogram on transformed input and target data. This data is the same that the model sees during training. No
    plots are created for the original values space (raw / unformatted data). This plot method will create a histogram
    for input and target each comparing the subsets train, val and test, as well as a distinct one for the three
    subsets.

    .. image:: ../../../../../_source/_plots/datahistogram.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
                 variables_dim="variables", time_dim="datetime", window_dim="window", upsampling=False):
        """
        Set attributes and create all input and target histogram plots.

        :param generators: data collections with the subset name as key
        :param plot_folder: path to save the plots (default: current directory)
        :param plot_name: first part of all plot file names
        :param variables_dim: name of the variables dimension
        :param time_dim: name of the time dimension
        :param window_dim: name of the window dimension
        :param upsampling: if True, the train subset is additionally evaluated as `train_upsampled`
        """
        super().__init__(plot_folder, plot_name)
        self.variables_dim = variables_dim
        self.time_dim = time_dim
        self.window_dim = window_dim
        self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim)
        self.bins = {}
        self.interval_width = {}
        self.bin_edges = {}
        if upsampling is True:
            generators = dict(generators)  # work on a copy to leave the caller's dictionary untouched
            self._handle_upsampling(generators)

        # input plots
        for branch_pos in range(number_of_branches):
            self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos)
            add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}"
            for subset in generators.keys():
                self._plot(add_name=add_name, subset=subset)
            self._plot_combined(add_name=add_name)

        # target plots
        self._calculate_hist(generators, self.targets, input_data=False)
        for subset in generators.keys():
            self._plot(add_name="target", subset=subset)
        self._plot_combined(add_name="target")

    @staticmethod
    def _handle_upsampling(generators):
        """Add the train collection a second time with key `train_upsampled` (modifies generators in place)."""
        if "train" in generators:
            generators.update({"train_upsampled": generators["train"]})

    @staticmethod
    def _get_inputs_targets(gens, dim):
        """
        Extract variable names and number of input branches from the first element of the first collection.

        :param gens: data collections with the subset name as key
        :param dim: name of the variables dimension
        :return: input variable names, target variable names, number of input branches
        """
        k = list(gens.keys())[0]
        gen = gens[k][0]
        x_data = gen.get_X(as_numpy=False)  # loaded only once and reused for names and number of branches
        inputs = list({y for x in to_list(x_data) for y in x.coords[dim].values.tolist()})
        targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
        n_branches = len(x_data)
        return inputs, targets, n_branches

    def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
        """
        Calculate an aggregated histogram for each subset and variable.

        Histograms are first calculated for each element of a collection and afterwards interpolated on common bin
        edges and summed up. Results are stored in `bins`, `interval_width` and `bin_edges` with the subset as key.

        :param generators: data collections with the subset name as key
        :param variables: names of the variables to calculate the histogram for
        :param input_data: use input data of branch `branch_pos` if True, else target data
        :param branch_pos: position of the input branch (only used for input data)
        """
        n_bins = 100
        for set_type, generator in generators.items():
            upsampling = "upsampled" in set_type
            tmp_bins = {}
            tmp_edges = {}
            end = {}
            start = {}
            for gen in generator:
                if input_data is True:
                    all_data = gen.get_X(as_numpy=False, upsampling=upsampling)[branch_pos]
                else:
                    all_data = gen.get_Y(as_numpy=False, upsampling=upsampling)
                # use the window position that is closest to zero
                w = min(abs(all_data.coords[self.window_dim].values))
                data = all_data.sel({self.window_dim: w})
                res, _, g_edges = f_proc_hist(data, variables, n_bins, self.variables_dim)
                for var in res.keys():
                    tmp_bins.setdefault(var, []).append(res[var])
                    tmp_edges.setdefault(var, []).append(g_edges[var])
                    end[var] = max([end.get(var, g_edges[var].max()), g_edges[var].max()])
                    start[var] = min([start.get(var, g_edges[var].min()), g_edges[var].min()])
            # interpolate and aggregate
            bins = {}
            edges = {}
            interval_width = {}
            for var in tmp_bins.keys():
                bin_edges = np.linspace(start[var], end[var], n_bins + 1)
                interval_width[var] = bin_edges[1] - bin_edges[0]
                for single_bins, single_edges in zip(tmp_bins[var], tmp_edges[var]):
                    bins_interp = np.interp(bin_edges[:-1], single_edges[:-1], single_bins, left=0, right=0)
                    bins[var] = bins.get(var, np.zeros(n_bins)) + bins_interp
                edges[var] = bin_edges

            self.bins[set_type] = bins
            self.interval_width[set_type] = interval_width
            self.bin_edges[set_type] = edges

    def _get_colors(self):
        """Return a copy of the data set colors extended by a color for `train_upsampled`."""
        colors = dict(self.get_dataset_colors())
        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
        return colors

    def _plot(self, add_name, subset):
        """
        Create a pdf with one histogram page for each variable of given subset.

        :param add_name: suffix of the file name (e.g. input or target)
        :param subset: name of the subset to plot
        """
        # explicit import: the submodule is not guaranteed to be loaded by `import matplotlib`
        from matplotlib.backends.backend_pdf import PdfPages

        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{subset}_{add_name}.pdf")
        bins = self.bins[subset]
        bin_edges = self.bin_edges[subset]
        interval_width = self.interval_width[subset]
        colors = self._get_colors()
        with PdfPages(plot_path) as pdf_pages:  # the file is closed even if plotting fails
            for var in bins.keys():
                fig, ax = plt.subplots()
                hist_var = bins[var]
                n_var = sum(hist_var)
                # normalise counts to a probability density
                weights = hist_var / (interval_width[var] * n_var)
                ax.hist(bin_edges[var][:-1], bin_edges[var], weights=weights, color=colors[subset])
                ax.set_ylabel("probability density")
                ax.set_xlabel(f"values")
                ax.set_title(f"histogram {var} ({subset}, n={int(n_var)})")
                pdf_pages.savefig(fig)
                plt.close(fig)  # do not accumulate open figures
        # close all open figures / plots
        plt.close('all')

    def _plot_combined(self, add_name):
        """
        Create a pdf with one page for each variable showing the density of all subsets as lines.

        :param add_name: suffix of the file name (e.g. input or target)
        """
        # explicit import: the submodule is not guaranteed to be loaded by `import matplotlib`
        from matplotlib.backends.backend_pdf import PdfPages

        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{add_name}.pdf")
        variables = self.bins[list(self.bins.keys())[0]].keys()
        colors = self._get_colors()
        with PdfPages(plot_path) as pdf_pages:  # the file is closed even if plotting fails
            for var in variables:
                fig, ax = plt.subplots()
                for subset in self.bins.keys():
                    hist_var = self.bins[subset][var]
                    interval_width = self.interval_width[subset][var]
                    bin_edges = self.bin_edges[subset][var]
                    n_var = sum(hist_var)
                    weights = hist_var / (interval_width * n_var)
                    # draw the density at the center of each bin
                    ax.plot(bin_edges[:-1] + 0.5 * interval_width, weights, label=f"{subset}",
                            c=colors[subset])
                ax.set_ylabel("probability density")
                ax.set_xlabel("values")
                ax.legend(loc="upper right")
                ax.set_title(f"histogram {var}")
                pdf_pages.savefig(fig)
                plt.close(fig)  # do not accumulate open figures
        # close all open figures / plots
        plt.close('all')
659@TimeTrackingWrapper
660class PlotPeriodogram(AbstractPlotClass): # pragma: no cover
661 """
662 Create Lomb-Scargle periodogram in raw input and target data. The Lomb-Scargle version can deal with missing values.
664 This plot routine is creating the following plots:
666 * "raw": data is not aggregated, 1 graph per variable
667 * "": single data lines are aggregated, 1 graph per variable
668 * "total": data is aggregated on all variables, single graph
670 If data consists on different sampling rates, a separate plot is create for each sampling.
672 .. image:: ../../../../../_source/_plots/periodogram.png
673 :width: 400
675 .. note::
676 This plot is not included in the default plot list. To use this plot, add "PlotPeriodogram" to the `plot_list`.
678 .. warning::
679 This plot is highly sensitive to the data handler structure. Therefore, it is highly likely that this method is
680 not compatible with any custom data handler. Proven data handlers are `DefaultDataHandler`,
681 `DataHandlerMixedSampling`, `DataHandlerMixedSamplingWithFilter`. To work properly, the data handler must have
682 the attribute `.id_class._data`.
684 """
686 def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="periodogram",
687 variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False):
688 super().__init__(plot_folder, plot_name)
689 self.variables_dim = variables_dim
690 self.time_dim = time_dim
692 for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling, sampling)):
693 self._sampling = s
694 self._add_text = {0: "input", 1: "target"}[pos]
695 multiple, label_names = self._has_filter_dimension(generator[0], pos)
696 self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing)
697 self._plot(raw=True)
698 self._plot(raw=False)
699 self._plot_total(raw=True)
700 self._plot_total(raw=False)
701 if multiple > 1:
702 self._plot_difference(label_names, plot_name_add="_last")
703 self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing,
704 use_last_input_value=False)
705 self._plot_difference(label_names, plot_name_add="_first")
707 @staticmethod
708 def _has_filter_dimension(g, pos):
709 """Inspect if filtered data is provided and return number and labels of filtered components."""
710 check_class = g.id_class
711 check_data = [check_class.get_X(as_numpy=False), check_class.get_Y(as_numpy=False)][pos]
712 if not hasattr(check_class, "filter_dim"): # data handler has no filtered data
713 return 1, []
714 else:
715 filter_dim = check_class.filter_dim
716 if filter_dim not in check_data.coords.dims: # current data has no filter (e.g. target data)
717 return 1, []
718 else:
719 return check_data.coords[filter_dim].shape[0], check_data.coords[filter_dim].values.tolist()
721 @TimeTrackingWrapper
722 def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False, use_last_input_value=True):
723 """
724 Create periodogram data.
725 """
726 self.raw_data = []
727 self.plot_data = []
728 self.plot_data_raw = []
729 self.plot_data_mean = []
730 iter = range(multiple if multiple == 1 else multiple + 1)
731 for m in iter:
732 plot_data_single = dict()
733 plot_data_raw_single = dict()
734 plot_data_mean_single = dict()
735 self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
736 raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing,
737 use_last_input_value=use_last_input_value)
738 for var in raw_data_single.keys():
739 pgram_com = []
740 pgram_mean = 0
741 all_data = raw_data_single[var]
742 pgram_mean_raw = np.zeros((len(self.f_index), len(all_data)))
743 for i, (f, pgram) in enumerate(all_data):
744 d = np.interp(self.f_index, f, pgram)
745 pgram_com.append(d)
746 pgram_mean += d
747 pgram_mean_raw[:, i] = d
748 pgram_mean /= len(all_data)
749 plot_data_single[var] = pgram_com
750 plot_data_mean_single[var] = (self.f_index, pgram_mean)
751 plot_data_raw_single[var] = (self.f_index, pgram_mean_raw)
752 self.plot_data.append(plot_data_single)
753 self.plot_data_mean.append(plot_data_mean_single)
754 self.plot_data_raw.append(plot_data_raw_single)
756 def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing):
757 """Implementation of data preprocessing using parallel variables element processing."""
758 raw_data_single = dict()
759 for g in generator:
760 if m == 0:
761 d = g.id_class._data
762 else:
763 gd = g.id_class
764 filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
765 d = (gd.input_data.sel(filter_sel), gd.target_data)
766 d = d[pos] if isinstance(d, tuple) else d
767 res = []
768 if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution
769 pool = multiprocessing.Pool(
770 min([psutil.cpu_count(logical=False), len(d[self.variables_dim].values),
771 16])) # use only physical cpus
772 output = [
773 pool.apply_async(f_proc,
774 args=(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
775 for var in d[self.variables_dim].values]
776 for i, p in enumerate(output):
777 res.append(p.get())
778 pool.close()
779 pool.join()
780 else: # serial solution
781 for var in d[self.variables_dim].values:
782 res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
783 for (var_str, f, pgram) in res:
784 if var_str not in raw_data_single.keys():
785 raw_data_single[var_str] = [(f, pgram)]
786 else:
787 raw_data_single[var_str] = raw_data_single[var_str] + [(f, pgram)]
788 return raw_data_single
790 def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing, use_last_input_value=True):
791 """Implementation of data preprocessing using parallel generator element processing."""
792 raw_data_single = dict()
793 res = []
794 if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution
795 pool = multiprocessing.Pool(
796 min([psutil.cpu_count(logical=False), len(generator), 16])) # use only physical cpus
797 output = [
798 pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim, self.f_index,
799 use_last_input_value))
800 for g in generator]
801 for i, p in enumerate(output):
802 res.append(p.get())
803 pool.close()
804 pool.join()
805 else:
806 for g in generator:
807 res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim, self.f_index, use_last_input_value))
808 for res_dict in res:
809 for k, v in res_dict.items():
810 if k not in raw_data_single.keys():
811 raw_data_single[k] = v
812 else:
813 raw_data_single[k] = raw_data_single[k] + v
814 return raw_data_single
816 @staticmethod
817 def _add_annotation_line(ax, pos, div, lims, unit):
818 for p in to_list(pos): # per year
819 ax.vlines(p / div, *lims, "black")
820 ax.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor")
822 def _format_figure(self, ax, var_name="total"):
823 """
824 Set log scale on both axis, add labels and annotation lines, and set title.
825 :param ax: current ax object
826 :param var_name: name of variable that will be included in the title
827 """
828 ax.set_yscale('log')
829 ax.set_xscale('log')
830 ax.set_ylabel("power spectral density", fontsize='x-large') # unit depends on variable: [unit^2 day^-1]
831 ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
832 lims = ax.get_ylim()
833 self._add_annotation_line(ax, [1, 2, 3], 365.25, lims, "yr") # per year
834 self._add_annotation_line(ax, 1, 365.25 / 12, lims, "m") # per month
835 self._add_annotation_line(ax, 1, 7, lims, "w") # per week
836 self._add_annotation_line(ax, [1, 0.5], 1, lims, "d") # per day
837 if self._sampling == "hourly":
838 self._add_annotation_line(ax, 2, 1, lims, "d") # per day
839 self._add_annotation_line(ax, [1, 0.5], 1 / 24., lims, "h") # per hour
840 title = f"Periodogram ({var_name})"
841 ax.set_title(title)
843 def _plot(self, raw=True):
844 plot_path = os.path.join(os.path.abspath(self.plot_folder),
845 f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf")
846 pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
847 plot_data = self.plot_data[0]
848 plot_data_mean = self.plot_data_mean[0]
849 for var in plot_data.keys():
850 fig, ax = plt.subplots()
851 if raw is True:
852 for pgram in plot_data[var]:
853 ax.plot(self.f_index, pgram, "lightblue")
854 ax.plot(*plot_data_mean[var], "blue")
855 else:
856 ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0)
857 mean = ma.mean().mean(axis=1).values.flatten()
858 upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
859 ax.plot(self.f_index, mean, "blue")
860 ax.fill_between(self.f_index, lower, upper, color="lightblue")
861 self._format_figure(ax, var)
862 pdf_pages.savefig()
863 # close all open figures / plots
864 pdf_pages.close()
865 plt.close('all')
867 def _plot_total(self, raw=True):
868 plot_path = os.path.join(os.path.abspath(self.plot_folder),
869 f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}_total.pdf")
870 pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
871 plot_data_raw = self.plot_data_raw[0]
872 fig, ax = plt.subplots()
873 res = None
874 for var in plot_data_raw.keys():
875 d_var = plot_data_raw[var][1]
876 res = d_var if res is None else np.concatenate((res, d_var), axis=-1)
877 if raw is True:
878 for i in range(res.shape[1]):
879 ax.plot(self.f_index, res[:, i], "lightblue")
880 ax.plot(self.f_index, res.mean(axis=1), "blue")
881 else:
882 ma = pd.DataFrame(np.vstack(res)).rolling(5, center=True, axis=0)
883 mean = ma.mean().mean(axis=1).values.flatten()
884 upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
885 ax.plot(self.f_index, mean, "blue")
886 ax.fill_between(self.f_index, lower, upper, color="lightblue")
887 self._format_figure(ax, "total")
888 pdf_pages.savefig()
889 # close all open figures / plots
890 pdf_pages.close()
891 plt.close('all')
893 def _plot_difference(self, label_names, plot_name_add = ""):
894 plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter{plot_name_add}.pdf"
895 plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
896 logging.info(f"... plotting {plot_name}")
897 pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
898 colors = ["grey", "blue", "red", "green", "orange", "purple", "black"]
899 label_names = ["orig"] + label_names
900 max_iter = len(self.plot_data)
901 var_keys = self.plot_data[0].keys()
902 for var in var_keys:
903 fig, ax = plt.subplots()
904 for i in reversed(range(max_iter)):
905 if label_names[i] == "unfiltered":
906 continue # do not include the filter 'unfiltered' because this is equal to the 'orig' data
907 plot_data = self.plot_data[i]
908 c = colors[i]
909 ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0)
910 mean = ma.mean().mean(axis=1).values.flatten()
911 ax.plot(self.f_index, mean, c, label=label_names[i])
912 if i < 1:
913 upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
914 ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None)
915 self._format_figure(ax, var)
916 ax.legend(loc="upper center", ncol=max_iter)
917 pdf_pages.savefig()
918 # close all open figures / plots
919 pdf_pages.close()
920 plt.close('all')
def f_proc(var, d_var, f_index, time_dim="datetime", use_last_value=True):  # pragma: no cover
    """
    Calculate the Lomb-Scargle periodogram of a single variable on the given frequencies.

    :param var: name of the variable, returned as string
    :param d_var: data of the variable with a time coordinate named time_dim
    :param f_index: frequencies at which the periodogram is evaluated
    :param time_dim: name of the time dimension
    :param use_last_value: if further dimensions beside time_dim exist, select the entry with the largest
        coordinate value along each of them if True, and the entry with the smallest value otherwise
    :return: variable name as string, f_index, and the periodogram evaluated on f_index
    """
    time_coord = d_var[time_dim]
    # time since the first sample: converted to whole hours and expressed in days
    elapsed = (time_coord - time_coord[0]).astype("timedelta64[h]")
    t = elapsed.values / np.timedelta64(1, "D")
    if len(d_var.shape) > 1:  # use only max value if dimensions are remaining (e.g. max(window) -> latest value)
        for e in to_list(remove_items(d_var.coords.dims, time_dim)):
            coord = d_var[e]
            d_var = d_var.sel({e: coord.max() if use_last_value is True else coord.min()})
    spectrum = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd")
    return str(var), f_index, spectrum.power(f_index)
def f_proc_2(g, m, pos, variables_dim, time_dim, f_index, use_last_value):  # pragma: no cover
    """
    Calculate the periodograms of all variables of a single generator element.

    :param g: object providing `id_class` (and, for pos == 0, possibly further attributes containing "id_class"
        in their name)
    :param m: if 0, the `_data` attribute is used (or, if this is None, the last `history` and first `label` entry
        along the window dimension); otherwise entry `m - 1` of the "filter" coordinate of `input_data` is
        selected and combined with `target_data` as tuple
    :param pos: index used to pick a single entry if the selected data is a tuple; for pos == 0, all attributes of
        g containing "id_class" in their name are considered, otherwise only `id_class`
    :param variables_dim: name of the variables dimension
    :param time_dim: name of the time dimension
    :param f_index: frequencies at which the periodograms are evaluated
    :param use_last_value: passed through to f_proc
    :return: dictionary with the variable name as key and a list with a single (frequency, periodogram) tuple as
        value
    :raises KeyError: if more than one periodogram is calculated for the same variable name
    """
    if pos == 0:
        id_classes = [name for name in dir(g) if "id_class" in name]
    else:
        id_classes = ["id_class"]

    # load lazy data
    for id_cls_name in id_classes:
        id_cls = getattr(g, id_cls_name)
        if hasattr(id_cls, "lazy") and id_cls.lazy is True:
            id_cls.load_lazy()

    raw_data_single = dict()
    try:
        for dh in [name for name in id_classes if "unfiltered" not in name]:
            current_cls = getattr(g, dh)
            if m == 0:
                d = current_cls._data
                if d is None:
                    window_dim = current_cls.window_dim
                    history = current_cls.history
                    last_entry = history.coords[window_dim][-1]
                    d1 = history.sel({window_dim: last_entry}, drop=True)
                    label = current_cls.label
                    first_entry = label.coords[window_dim][0]
                    d2 = label.sel({window_dim: first_entry}, drop=True)
                    d = (d1, d2)
            else:
                filter_sel = {"filter": current_cls.input_data.coords["filter"][m - 1]}
                d = (current_cls.input_data.sel(filter_sel), current_cls.target_data)
            d = d[pos] if isinstance(d, tuple) else d
            for var in d[variables_dim].values:
                d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim)
                # forward time_dim, otherwise f_proc silently falls back to its default name "datetime"
                var_str, f, pgram = f_proc(var, d_var, f_index, time_dim=time_dim, use_last_value=use_last_value)
                if var_str in raw_data_single:
                    raise KeyError(f"There are multiple pgrams for key {var_str}. Please check your data handler.")
                raw_data_single[var_str] = [(f, pgram)]
    finally:
        # perform clean up, also if the calculation failed
        for id_cls_name in id_classes:
            id_cls = getattr(g, id_cls_name)
            if hasattr(id_cls, "lazy") and id_cls.lazy is True:
                id_cls.clean_up()

    return raw_data_single
def f_proc_hist(data, variables, n_bins, variables_dim):  # pragma: no cover
    """
    Calculate a histogram for each requested variable that is available in data.

    :param data: data with a coordinate named variables_dim
    :param variables: variables to calculate a histogram for; variables not found in data are skipped
    :param n_bins: passed to np.histogram as bins argument
    :param variables_dim: name of the variables dimension
    :return: histogram counts, width of the first bin, and bin edges, each as dictionary with variable as key
    """
    counts, edges, widths = {}, {}, {}
    for var in variables:
        if var not in data.coords[variables_dim]:
            continue
        if len(data.shape) > 1:
            d = data.sel({variables_dim: var}).squeeze()
        else:
            d = data
        counts[var], edges[var] = np.histogram(d.values, n_bins)
        widths[var] = edges[var][1] - edges[var][0]
    return counts, widths, edges
class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot climate FIR filter components.

    * Creates a separate folder climFIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the climate FIR input, the filter response, the true non-causal (ideal) filter
      input, and the corresponding ideal response (containing information about future)
    * A filter residuum plot includes the climate FIR residuum and the ideal filter residuum.
    """

    def __init__(self, plot_folder, plot_data, sampling, name):
        """
        Set attributes, create all plots, and store the plot data.

        Nothing is done (beside logging) if plot_folder is None.

        :param plot_folder: directory in which the sub folder climFIR is used
        :param plot_data: nested plot data, see _prepare_data for the keys that are read
        :param sampling: sampling of the data, "1d" and "1H" are mapped to a time delta unit
        :param name: name that is included in the file name of each plot
        """
        from mlair.helpers.filter import fir_filter_convolve

        logging.info(f"start PlotClimateFirFilter for ({name})")

        if plot_folder is None:
            return

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "ideal": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        self.variables_list = []
        plot_folder = os.path.join(os.path.abspath(plot_folder), "climFIR")
        self.fir_filter_convolve = fir_filter_convolve
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict, new_dim = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict, sampling, new_dim)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        The result is nested as plot_dict[var][t0][filter index]. All entries must share the same "new_dim".

        :param data: indexable collection (one entry per filter) of iterables of dictionaries
        :return: restructured plot data and the name of the new dimension
        """
        plot_dict = {}
        new_dim = None
        # index access is kept on purpose, the container type of data is not visible here
        for i in range(len(data)):
            for p_d in data[i]:
                var = p_d.get("var")
                t0 = p_d.get("t0")
                filter_input = p_d.get("filter_input")
                filter_input_nc = p_d.get("filter_input_nc")
                valid_range = p_d.get("valid_range")
                time_range = p_d.get("time_range")
                if new_dim is None:
                    new_dim = p_d.get("new_dim")
                else:
                    assert new_dim == p_d.get("new_dim")
                h = p_d.get("h")
                entry = {"filter_input": filter_input,
                         "filter_input_nc": filter_input_nc,
                         "valid_range": valid_range,
                         "time_range": time_range,
                         "order": len(h), "h": h}
                plot_dict.setdefault(var, {}).setdefault(t0, {})[i] = entry
        self.variables_list = list(plot_dict.keys())
        return plot_dict, new_dim

    def _plot(self, plot_dict, sampling, new_dim="window"):
        """
        Create a filter component plot and a residuum plot for each variable, t0, and filter.

        From the second filter on, the ideal residuum of the previous filter is used as non-causal filter input.
        If any exception is raised for a t0, it is logged and the remaining filters of this t0 are skipped.
        """
        td_type = {"1d": "D", "1H": "h"}.get(sampling)
        for var, vis_dict in plot_dict.items():
            for it0, (t0, vis_data) in enumerate(vis_dict.items()):
                residuum_true = None
                try:
                    for ifilter in sorted(vis_data):
                        entry = vis_data[ifilter]
                        filter_input = entry["filter_input"]
                        if residuum_true is None:
                            filter_input_nc = entry["filter_input_nc"]
                        else:
                            filter_input_nc = residuum_true.sel({new_dim: filter_input.coords[new_dim]})
                        valid_range = entry["valid_range"]
                        time_axis = entry["time_range"]
                        filter_order = entry["order"]
                        h = entry["h"]

                        # filter component plot: backgrounds, original data, clim apriori, and both responses
                        fig, ax = plt.subplots()
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)
                        self._plot_original_data(ax, time_axis, filter_input_nc)
                        self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter, offset=1)
                        residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h,
                                                                    output_dtypes=filter_input.dtype)
                        residuum_true = self._plot_ideal_filter(ax, time_axis, filter_input_nc, new_dim, h,
                                                                output_dtypes=filter_input.dtype)
                        xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis)
                        self._finalize_clim_fir_figure(fig, f"Input of ClimFilter ({str(var)})",
                                                       f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}")

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)
                        self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal")
                        self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim")
                        ax.set_xlim(xlims)
                        self._finalize_clim_fir_figure(fig, f"Residuum of ClimFilter ({str(var)})",
                                                       f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum",
                                                       loc="upper left")
                except Exception:
                    exc_type, exc_value, exc_tb = sys.exc_info()
                    logging.info(f"Could not create plot because of:\n{exc_type}\n{exc_value}\n{exc_tb}")

    def _finalize_clim_fir_figure(self, fig, title, plot_name, **legend_kwargs):
        """Set title and legend, adjust the layout, and save the current figure as plot_name."""
        plt.title(title)
        plt.legend(**legend_kwargs)
        fig.autofmt_xdate()
        plt.tight_layout()
        self.plot_name = plot_name
        self._save()

    def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis):
        """
        Set xlims

        Use order and valid_range to find a good zoom in that hides edges of filter values that are effected by reduced
        filter order. Limits are returned to be usable for other plots.
        """
        base_delta = min(2 * (valid_range.stop - valid_range.start), 0.5 * order)
        t_minus_delta = max(base_delta, (-valid_range.start + 0.3 * order))
        t_plus_delta = max(base_delta, valid_range.stop + 0.3 * order)
        ax_start = max(t0 + np.timedelta64(-int(t_minus_delta), td_type), time_axis[0])
        ax_end = min(t0 + np.timedelta64(int(t_plus_delta), td_type), time_axis[-1])
        ax.set_xlim((ax_start, ax_end))
        return ax_start, ax_end

    def _plot_valid_area(self, ax, t0, valid_range, td_type):
        """Highlight the span from t0 + valid_range.start to t0 + valid_range.stop - 1."""
        span_start = t0 + np.timedelta64(valid_range.start, td_type)
        span_end = t0 + np.timedelta64(valid_range.stop - 1, td_type)
        ax.axvspan(span_start, span_end, **self.style_dict["valid_area"])

    def _plot_t0(self, ax, t0):
        """Mark t0 with a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot data against time_axis using the entry style of the style dictionary."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_original_data(self, ax, time_axis, data):
        """Plot the flattened non-causal filter input with style original."""
        self._plot_series(ax, time_axis, data.values.flatten(), style="original")

    def _plot_apriori(self, ax, time_axis, data, new_dim, ifilter, offset):
        """
        Plot the clim apriori information, aligned to the end of time_axis.

        For the first filter (ifilter == 0), only values from offset on along new_dim are shown.
        """
        if ifilter == 0:
            upper = data.coords[new_dim].values.max()
            d_tmp = data.sel({new_dim: slice(offset, upper)}).values.flatten()
        else:
            d_tmp = data.values.flatten()
        self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, style="apriori")

    def _plot_filter_response(self, ax, time_axis, data, new_dim, h, output_dtypes, style):
        """Apply the FIR filter with coefficients h along new_dim, plot the response, and return the residuum."""
        filt = xr.apply_ufunc(self.fir_filter_convolve, data,
                              input_core_dims=[[new_dim]],
                              output_core_dims=[[new_dim]],
                              vectorize=True,
                              kwargs={"h": h},
                              output_dtypes=[output_dtypes])
        self._plot_series(ax, time_axis, filt.values.flatten(), style=style)
        return data - filt

    def _plot_clim_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """Plot the clim filter response and return the estimated residuum (input minus response)."""
        return self._plot_filter_response(ax, time_axis, data, new_dim, h, output_dtypes, "clim")

    def _plot_ideal_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """Plot the ideal filter response and return the true residuum (input minus response)."""
        return self._plot_filter_response(ax, time_axis, data, new_dim, h, output_dtypes, "ideal")

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        file_name = "_".join(self.variables_list) + "plot_data.pickle"
        with open(os.path.join(self.plot_folder, file_name), "wb") as f:
            dill.dump(data, f)
class PlotFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot FIR filter components.

    * Creates a separate folder FIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the FIR input and the filter response
    * A filter residuum plot includes the FIR residuum
    """

    def __init__(self, plot_folder, plot_data, name):
        """
        Set attributes, create all plots, and store the plot data.

        Nothing is done (beside logging) if plot_folder is None.

        :param plot_folder: directory in which the sub folder FIR is used
        :param plot_data: nested plot data, see _prepare_data for the keys that are read
        :param name: name that is included in the file name of each plot
        """
        logging.info(f"start PlotFirFilter for ({name})")

        if plot_folder is None:
            return

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "FIR": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        plot_folder = os.path.join(os.path.abspath(plot_folder), "FIR")
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        The result is nested as plot_dict[var][t0][filter index] with one entry per variable found along the
        variables dimension of the filtered data.

        :param data: indexable collection (filter component) of indexable collections (t0 counter) of dictionaries
        :return: restructured plot data
        """
        plot_dict = {}
        # index access is kept on purpose, the container type of data is not visible here
        for i in range(len(data)):  # filter component
            for j in range(len(data[i])):  # t0 counter
                plot_data = data[i][j]
                t0 = plot_data.get("t0")
                filter_input = plot_data.get("filter_input")
                filtered = plot_data.get("filtered")
                var_dim = plot_data.get("var_dim")
                time_dim = plot_data.get("time_dim")
                for var in filtered.coords[var_dim].values:
                    entry = {"filter_input": filter_input.sel({var_dim: var}, drop=True),
                             "filtered": filtered.sel({var_dim: var}, drop=True),
                             "time_dim": time_dim}
                    plot_dict.setdefault(var, {}).setdefault(t0, {})[i] = entry
        return plot_dict

    def _plot(self, plot_dict):
        """
        Create a filter component plot and a residuum plot for each variable, t0, and filter.

        If any exception is raised for a t0, it is logged and the remaining filters of this t0 are skipped.
        """
        for var, viz_date_dict in plot_dict.items():
            for it0, (t0, viz_data) in enumerate(viz_date_dict.items()):
                try:
                    for ifilter in sorted(viz_data):
                        entry = viz_data[ifilter]
                        filter_input = entry["filter_input"]
                        filtered = entry["filtered"]
                        time_dim = entry["time_dim"]
                        time_axis = filtered.coords[time_dim].values

                        # filter component plot: t0 marker, original data, and filter response
                        fig, ax = plt.subplots()
                        self._plot_t0(ax, t0)
                        self._plot_data(ax, time_axis, filter_input, style="original")
                        self._plot_data(ax, time_axis, filtered, style="FIR")
                        ax.set_xlim((time_axis[0], time_axis[-1]))
                        self._finalize_fir_figure(fig, f"Input of Filter ({str(var)})",
                                                  f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}")

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_t0(ax, t0)
                        self._plot_data(ax, time_axis, filter_input - filtered, style="FIR")
                        ax.set_xlim((time_axis[0], time_axis[-1]))
                        self._finalize_fir_figure(fig, f"Residuum of Filter ({str(var)})",
                                                  f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum",
                                                  loc="upper left")
                except Exception:
                    exc_type, exc_value, exc_tb = sys.exc_info()
                    logging.info(f"Could not create plot because of:\n{exc_type}\n{exc_value}\n{exc_tb}")

    def _finalize_fir_figure(self, fig, title, plot_name, **legend_kwargs):
        """Set title and legend, adjust the layout, and save the current figure as plot_name."""
        plt.title(title)
        plt.legend(**legend_kwargs)
        fig.autofmt_xdate()
        plt.tight_layout()
        self.plot_name = plot_name
        self._save()

    def _plot_t0(self, ax, t0):
        """Mark t0 with a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot data against time_axis using the entry style of the style dictionary."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_data(self, ax, time_axis, data, style="original"):
        """Plot the flattened values of data against time_axis with the given style."""
        self._plot_series(ax, time_axis, data.values.flatten(), style=style)

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        file_path = os.path.join(self.plot_folder, "plot_data.pickle")
        with open(file_path, "wb") as f:
            dill.dump(data, f)