Coverage for mlair/plotting/data_insight_plotting.py: 100%
7 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
"""Collection of plots that give additional insight into the data used by MLAir."""

__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = "2021-04-13"
5from typing import List, Dict
6import dill
7import os
8import logging
9import multiprocessing
10import psutil
11import sys
13import numpy as np
14import pandas as pd
15import xarray as xr
16import matplotlib
17# matplotlib.use("Agg")
18from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
19from astropy.timeseries import LombScargle
21from mlair.data_handler import DataCollection
22from mlair.helpers import TimeTrackingWrapper, to_list, remove_items
23from mlair.plotting.abstract_plot_class import AbstractPlotClass
@TimeTrackingWrapper
class PlotStationMap(AbstractPlotClass):  # pragma: no cover
    """
    Plot geographical overview of all used stations.

    Each element of ``generators`` is either a data collection or a tuple ``(data_collection, plot_opts)``. The
    optional dictionary ``plot_opts`` can set ``marker`` (default: square ``"s"``), ``ms`` (marker size, default: 6),
    ``mec`` (marker edge color, default: black) and ``mfc`` (marker face color). Without ``mfc``, the dataset color
    registered for the collection's name is used (blue as fallback). The map background shows land, ocean, lakes,
    rivers, coastlines and country borders at 50m resolution. The plot is saved in ``plot_folder`` under
    ``plot_name`` (default: station_map).

    .. image:: ../../../../../_source/_plots/station_map.png
        :width: 400
    """

    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
        """
        Set attributes and create plot.

        :param generators: list of data collections or ``(data_collection, plot_opts)`` tuples, one entry per data
            set to show on the map. Nothing is drawn if ``None``.
        :param plot_folder: path to save the plot (default: current directory)
        :param plot_name: name of the plot (default: station_map)
        """
        super().__init__(plot_folder, plot_name)
        self._ax = None  # cartopy GeoAxes, created in _plot
        self._gl = None  # gridliner of self._ax, created in _plot
        self._plot(generators)
        self._save(bbox_inches="tight")

    def _draw_background(self):
        """Draw land, coastline, lakes, ocean, rivers and country borders as background on the map."""
        import cartopy.feature as cfeature

        self._ax.add_feature(cfeature.LAND.with_scale("50m"))
        self._ax.natural_earth_shp(resolution='50m')
        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
        self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
        self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')

    def _plot_stations(self, generators):
        """
        Plot the position of each station of each data collection and add one legend entry per collection.

        :param generators: list of data collections or ``(data_collection, plot_opts)`` tuples (see class
            docstring). Nothing is drawn if ``None``.
        """
        import cartopy.crs as ccrs
        if generators is None:
            return
        legend_elements = []
        default_colors = self.get_dataset_colors()
        for element in generators:
            data_collection, plot_opts = self._get_collection_and_opts(element)
            name = data_collection.name or "unknown"
            marker = plot_opts.get("marker", "s")
            ms = plot_opts.get("ms", 6)
            mec = plot_opts.get("mec", "k")
            mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
            legend_elements.append(
                mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
                              label=f"{name} ({len(data_collection)})"))
            for station in data_collection:
                coords = station.get_coordinates()
                self._ax.plot(coords["lon"], coords["lat"], mfc=mfc, mec=mec, marker=marker, ms=ms,
                              transform=ccrs.PlateCarree())
        if len(legend_elements) > 0:
            self._ax.legend(handles=legend_elements, loc='best')

    @staticmethod
    def _adjust_marker(marker):
        """
        Map matplotlib's integer caret markers (4-11) to the triangle marker pointing in the same direction.

        Only used for the legend handles. All other markers are returned unchanged.
        """
        _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
        if isinstance(marker, int) and marker in _adjust:
            return _adjust[marker]
        return marker

    @staticmethod
    def _get_collection_and_opts(element):
        """
        Split an element of ``generators`` into its data collection and its plot options.

        Plain data collections and 1-tuples get empty plot options. A 2-tuple is interpreted as
        ``(data_collection, plot_opts)``, where ``None`` options are treated as empty options.

        :raises ValueError: if ``element`` is a tuple with neither one nor two entries
        """
        if not isinstance(element, tuple):
            return element, {}
        if len(element) == 1:
            return element[0], {}
        if len(element) == 2:
            data_collection, plot_opts = element
            return data_collection, plot_opts or {}
        raise ValueError(f"Expected a tuple (data_collection, plot_opts) but got a tuple of length {len(element)}.")

    def _plot(self, generators: List):
        """
        Create the station map plot.

        Set up the figure with a PlateCarree map and gridlines, then draw background, stations and adjust the extent.

        :param generators: list of data collections or ``(data_collection, plot_opts)`` tuples (see class docstring)
        """
        import cartopy.crs as ccrs
        from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
        fig = plt.figure(figsize=(10, 5))
        self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
        self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
        self._gl.xformatter = LONGITUDE_FORMATTER
        self._gl.yformatter = LATITUDE_FORMATTER
        self._draw_background()
        self._plot_stations(generators)
        self._adjust_extent()
        plt.tight_layout()

    def _adjust_extent(self):
        """
        Enlarge the current map extent by a margin on each side.

        The margin is the larger of ``5 / lon_range`` and ``5 / lat_range``, capped at 5. A zero sized range uses the
        capped maximum instead of dividing by zero.
        """
        import cartopy.crs as ccrs

        def diff(arr):
            return arr[1] - arr[0], arr[3] - arr[2]

        def find_ratio(delta, reference=5, max_ratio=5):
            ratios = [abs(reference / d) if d != 0 else np.inf for d in delta]
            return min(max(ratios), max_ratio)

        extent = self._ax.get_extent(crs=ccrs.PlateCarree())
        ratio = find_ratio(diff(extent))
        new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
        self._ax.set_extent(new_extent, crs=ccrs.PlateCarree())
@TimeTrackingWrapper
class PlotAvailability(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plot similar to Gantt plot.

    Each entry of the given generators results in a new line in the plot. Data is summarised for the given temporal
    resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
    colored bar or a blank space.

    You can set different colors to highlight subsets, for example by providing different generators for the same
    index using different keys in the input dictionary.

    Note: each bar is surrounded by a small white box to highlight gaps in between. This can result in too long gaps
    in display if a gap is only very short. This also appears on a (fluent) transition from one subset to another.

    Calling this class will create three versions of the availability plot:

    1) Data availability for each element
    2) Data availability as summary over all elements (is there at least a single element for each time step)
    3) Combination of single and overall availability

    .. image:: ../../../../../_source/_plots/data_availability.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_summary.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_combined.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
        """
        Initialise and create all three availability plots.

        :param generators: dictionary with the subset name as key and its data collection as value. Only subsets
            that have an entry in the dataset colors are drawn.
        :param plot_folder: path to save the plots (default: current directory)
        :param sampling: temporal resolution the labels are resampled to
        :param summary_name: row label of the summary bar
        :param time_dimension: name of the time dimension
        :param window_dimension: name of the window dimension (availability is checked at window 1)
        """
        # create standard Gantt plot for all stations (currently in single pdf file with single page)
        super().__init__(plot_folder, "data_availability")
        self.time_dim = time_dimension
        self.window_dim = window_dimension
        self.sampling = self._get_sampling(sampling)[1]
        self.linewidth = None
        if self.sampling == 'h':
            self.linewidth = 0.001
        plot_dict = self._prepare_data(generators)
        lgd = self._plot(plot_dict)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # create summary Gantt plot (is data in at least one station available)
        self.plot_name += "_summary"
        plot_dict_summary = self._summarise_data(generators, summary_name)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # combination of station and summary plot, last element is summary broken bar
        self.plot_name = "data_availability_combined"
        plot_dict_summary.update(plot_dict)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")

    def _get_label_availability(self, station):
        """Resample the labels of a station to self.sampling and return whether they are available (window 1)."""
        labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
        return labels.sel(**{self.window_dim: 1}).notnull()

    def _get_available_intervals(self, avail) -> List:
        """
        Convert a boolean availability series into ``(start, number_of_time_steps)`` tuples of available periods.

        Consecutive time steps with equal availability are grouped together, and only available groups are returned.
        """
        group = (avail != avail.shift({self.time_dim: 1})).cumsum()
        plot_data = pd.DataFrame({"avail": avail.values, "group": group.values},
                                 index=avail.coords[self.time_dim].values)
        # .iloc: positional access (plain [0] on a datetime index is deprecated positional fallback in pandas)
        intervals = plot_data.groupby("group").apply(lambda x: (x["avail"].iloc[0], x.index[0], x.shape[0]))
        return [i[1:] for i in intervals if i[0]]

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """Return ``{station: {subset: available intervals}}`` for each station of each subset."""
        plt_dict = {}
        for subset, data_collection in generators.items():
            for station in data_collection:
                intervals = self._get_available_intervals(self._get_label_availability(station))
                plt_dict.setdefault(str(station), {})[subset] = intervals
        return plt_dict

    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
        """
        Return ``{summary_name: {subset: available intervals}}``.

        A time step counts as available if at least one station of the subset has data. Empty subsets are skipped.
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            all_data = None
            for station in data_collection:
                labels_bool = self._get_label_availability(station)
                if all_data is None:
                    all_data = labels_bool
                else:
                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
                    # apply logical on merge and fill missing with all_data
                    all_data = np.logical_or(tmp, labels_bool).combine_first(all_data)
            if all_data is None:  # empty data collection, nothing to summarise
                continue
            plt_dict.setdefault(summary_name, {})[subset] = self._get_available_intervals(all_data)
        return plt_dict

    def _plot(self, plt_dict):
        """
        Draw one row of broken horizontal bars per key of ``plt_dict`` and return the legend.

        Rows are sorted by key, with the alphabetically first key on top. Only subsets with a dataset color are drawn.
        The legend lists the colors that were actually used.
        """
        colors = self.get_dataset_colors()
        _used_colors = []
        pos = 0
        height = 0.8  # should be <= 1
        yticklabels = []
        number_of_stations = len(plt_dict.keys())
        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
        for station, d in sorted(plt_dict.items(), reverse=True):
            pos += 1
            for subset, color in colors.items():
                plt_data = d.get(subset)
                if plt_data is None:
                    continue
                elif color not in _used_colors:  # this is required for a proper legend creation
                    _used_colors.append(color)
                # NOTE(review): bar widths are numbers of time steps; confirm matplotlib interprets them in the
                # intended x units for non-daily sampling
                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
            yticklabels.append(station)

        ax.set_ylim([height, number_of_stations + 1])
        ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2)
        ax.set_yticklabels(yticklabels)
        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors]
        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
        return lgd
@TimeTrackingWrapper
class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plots as histogram.

    Each entry of each generator is checked for `notnull()` values along the whole datetime axis (boolean). Calling
    this class creates two different types of histograms:

    1) data_availability_histogram_hist: datetime (xaxis) vs. number of stations with available data (yaxis)
    2) data_availability_histogram_hist_cum: number of samples (xaxis) vs. number of stations having at least this
       number of samples (yaxis)

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
                 subset_dim: str = 'DataSet', history_dim: str = 'window',
                 station_dim: str = 'Stations', ):
        """
        Prepare availability data and create one plot per entry of ``allowed_plot_types``.

        :param generators: dictionary with the subset name as key and its data collection as value
        :param plot_folder: path to save the plots (default: current directory)
        :param subset_dim: name of the new dimension collecting the subsets
        :param history_dim: name of the history (window) dimension of the inputs
        :param station_dim: name of the new dimension collecting the stations
        """
        super().__init__(plot_folder, "data_availability_histogram")

        self.subset_dim = subset_dim
        self.history_dim = history_dim
        self.station_dim = station_dim

        # set from the data handlers in _prepare_data
        self.freq = None
        self.temporal_dim = None
        self.target_dim = None
        self._prepare_data(generators)

        for plt_type in self.allowed_plot_types:
            plot_name_tmp = self.plot_name
            self.plot_name += '_' + plt_type
            self._plot(plt_type=plt_type)
            self._save()
            self.plot_name = plot_name_tmp

    def _set_dims_from_datahandler(self, data_handler):
        """Read time dimension, target dimension and sampling frequency from the data handler's id_class."""
        self.temporal_dim = data_handler.id_class.time_dim
        self.target_dim = data_handler.id_class.target_dim
        self.freq = self._get_sampling(data_handler.id_class.sampling)[1]

    @property
    def allowed_plot_types(self):
        """Names of the plot types created by this class."""
        plot_types = ['hist', 'hist_cum']
        return plot_types

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Prepare data to be used by plot methods.

        Create xarrays that are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim. Time
        steps missing for a station (introduced by merging stations with different time axes) count as not available.
        """
        avail_data_time_sum = {}
        avail_data_station_sum = {}
        dataset_time_interval = {}
        for subset, generator in generators.items():
            avail_list = []
            for station in generator:
                self._set_dims_from_datahandler(data_handler=station)
                station_data_x = station.get_X(as_numpy=False)[0]
                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
                station_data_x = self._reduce_dims(station_data_x)
                avail_list.append(station_data_x.notnull())
            # the outer join of concat fills gaps with NaN; treat these as "not available" while keeping the
            # per-station notnull result (a second notnull() would wrongly turn False into True)
            avail_data = xr.concat(avail_list, dim=self.station_dim).fillna(False).astype(bool)
            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
            )
        avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(),
                                                                             name=self.subset_dim))
        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
        self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(),
                                                                                      name=self.subset_dim))
        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
        self.dataset_time_interval = dataset_time_interval

    def _reduce_dims(self, dataset):
        """Select the first coordinate of every dimension except temporal_dim and station_dim (if more than 2 dims)."""
        if len(dataset.dims) > 2:
            required = {self.temporal_dim, self.station_dim}
            unimportant = set(dataset.dims).difference(required)
            sel_dict = {un: dataset[un].values[0] for un in unimportant}
            dataset = dataset.loc[sel_dict]
        return dataset

    @staticmethod
    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
        """
        Return first and last coordinate value of ``dim_name``.

        :return: ``(first, last)`` for ``return_type='as_tuple'`` or ``{'first': ..., 'last': ...}`` for 'as_dict'
        :raises TypeError: on an unknown return_type or if xarray is not an xr.DataArray
        """
        if isinstance(xarray, xr.DataArray):
            first = xarray.coords[dim_name].values[0]
            last = xarray.coords[dim_name].values[-1]
            if return_type == 'as_tuple':
                return first, last
            elif return_type == 'as_dict':
                return {'first': first, 'last': last}
            else:
                raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")
        else:
            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")

    @staticmethod
    def _make_full_time_index(irregular_time_index, freq):
        """Return a regular date range with frequency ``freq`` from the first to the last given time stamp."""
        full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
        return full_time_index

    def _plot(self, plt_type='hist', *args):
        """
        Dispatch to the plot method of ``plt_type``.

        :raises ValueError: if plt_type is neither 'hist' nor 'hist_cum'
        """
        if plt_type == 'hist':
            self._plot_hist()
        elif plt_type == 'hist_cum':
            self._plot_hist_cum()
        else:
            raise ValueError(f"plt_type must be 'hist' or 'hist_cum', but is {plt_type}")

    def _plot_hist(self, *args):
        """Plot the number of stations with available data over time as filled step plot per subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        for i, subset in enumerate(self.dataset_time_interval.keys()):
            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
                                                       self.temporal_dim: slice(
                                                           self.dataset_time_interval[subset]['first'],
                                                           self.dataset_time_interval[subset]['last']
                                                       )
                                                       })

            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])

        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                         facecolor='white', framealpha=1, edgecolor='black')
        for lgd_line in lgd.get_lines():
            lgd_line.set_linewidth(4.0)
        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
        plt.title('')
        plt.ylabel('Number of samples')
        plt.tight_layout()

    def _plot_hist_cum(self, *args):
        """Plot reverse cumulative histograms of available samples per station, largest subset drawn first."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        n_bins = int(self.avail_data_cum_sum.max().values)
        bins = np.arange(0, n_bins + 1)
        descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby(
            self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False
        ).coords[self.subset_dim].values

        for subset in descending_subsets:
            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
                                                                              bins=bins,
                                                                              label=subset,
                                                                              cumulative=-1,
                                                                              color=colors[subset],
                                                                              )

        fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                   facecolor='white', framealpha=1, edgecolor='black')
        plt.title('')
        plt.ylabel('Number of stations')
        plt.xlabel('Number of samples')
        plt.xlim((bins[0], bins[-1]))
        plt.tight_layout()
@TimeTrackingWrapper
class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Plot histograms of the transformed input and target data.

    This data is the same data the model sees during training. No plots are created for the original value space
    (raw / unformatted data). For inputs (one set per input branch) and targets, this class creates one histogram file
    per subset and one combined file comparing all subsets. If ``upsampling=True`` and a "train" subset exists, the
    upsampled training data is added as subset "train_upsampled".

    .. image:: ../../../../../_source/_plots/datahistogram.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
                 variables_dim="variables", time_dim="datetime", window_dim="window", upsampling=False):
        """
        Calculate histograms and create all plots.

        :param generators: dictionary with the subset name as key and its data collection as value (not modified)
        :param plot_folder: path to save the plots (default: current directory)
        :param plot_name: base name of the plot files
        :param variables_dim: name of the variables dimension
        :param time_dim: name of the time dimension
        :param window_dim: name of the window dimension
        :param upsampling: additionally evaluate the upsampled train data
        """
        super().__init__(plot_folder, plot_name)
        self.variables_dim = variables_dim
        self.time_dim = time_dim
        self.window_dim = window_dim
        # work on a copy so that adding "train_upsampled" does not leak into the caller's dictionary
        generators = dict(generators)
        self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim)
        self.bins = {}
        self.interval_width = {}
        self.bin_edges = {}
        if upsampling is True:
            self._handle_upsampling(generators)

        # input plots
        for branch_pos in range(number_of_branches):
            self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos)
            add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}"
            for subset in generators.keys():
                self._plot(add_name=add_name, subset=subset)
            self._plot_combined(add_name=add_name)

        # target plots
        self._calculate_hist(generators, self.targets, input_data=False)
        for subset in generators.keys():
            self._plot(add_name="target", subset=subset)
        self._plot_combined(add_name="target")

    @staticmethod
    def _handle_upsampling(generators):
        """Add the train data collection as "train_upsampled" to ``generators`` (in place) if present."""
        if "train" in generators:
            generators.update({"train_upsampled": generators["train"]})

    @staticmethod
    def _get_inputs_targets(gens, dim):
        """
        Inspect the first element of the first subset to get input and target variables and the number of branches.

        :return: list of unique input variables over all branches (first-seen order), list of target variables,
            number of input branches
        """
        k = list(gens.keys())[0]
        gen = gens[k][0]
        # dict.fromkeys removes duplicates like set() but keeps a deterministic order
        inputs = list(dict.fromkeys(y for x in to_list(gen.get_X(as_numpy=False))
                                    for y in x.coords[dim].values.tolist()))
        targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
        n_branches = len(gen.get_X(as_numpy=False))
        return inputs, targets, n_branches

    def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
        """
        Calculate aggregated histograms per subset and variable.

        Per data element, a histogram of the smallest absolute window value is calculated. All histograms of a subset
        are interpolated onto common bin edges spanning the overall value range and summed up. Results are stored in
        ``self.bins``, ``self.interval_width`` and ``self.bin_edges`` keyed by subset.
        """
        n_bins = 100
        for set_type, generator in generators.items():
            upsampling = "upsampled" in set_type
            tmp_bins = {}
            tmp_edges = {}
            end = {}
            start = {}
            if input_data is True:
                f = lambda x: x.get_X(as_numpy=False, upsampling=upsampling)[branch_pos]
            else:
                f = lambda x: x.get_Y(as_numpy=False, upsampling=upsampling)
            for gen in generator:
                gen_data = f(gen)  # load data only once per element
                w = min(abs(gen_data.coords[self.window_dim].values))
                data = gen_data.sel({self.window_dim: w})
                res, _, g_edges = f_proc_hist(data, variables, n_bins, self.variables_dim)
                for var in res.keys():
                    tmp_bins.setdefault(var, []).append(res[var])
                    tmp_edges.setdefault(var, []).append(g_edges[var])
                    end[var] = max([end.get(var, g_edges[var].max()), g_edges[var].max()])
                    start[var] = min([start.get(var, g_edges[var].min()), g_edges[var].min()])
            # interpolate and aggregate
            bins = {}
            edges = {}
            interval_width = {}
            for var in tmp_bins.keys():
                bin_edges = np.linspace(start[var], end[var], n_bins + 1)
                interval_width[var] = bin_edges[1] - bin_edges[0]
                for i, e in enumerate(tmp_bins[var]):
                    bins_interp = np.interp(bin_edges[:-1], tmp_edges[var][i][:-1], e, left=0, right=0)
                    bins[var] = bins.get(var, np.zeros(n_bins)) + bins_interp
                edges[var] = bin_edges

            self.bins[set_type] = bins
            self.interval_width[set_type] = interval_width
            self.bin_edges[set_type] = edges

    def _plot(self, add_name, subset):
        """Write one page per variable with the probability density histogram of ``subset``."""
        from matplotlib.backends.backend_pdf import PdfPages
        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{subset}_{add_name}.pdf")
        bins = self.bins[subset]
        bin_edges = self.bin_edges[subset]
        interval_width = self.interval_width[subset]
        colors = self.get_dataset_colors()
        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
        try:
            with PdfPages(plot_path) as pdf_pages:
                for var in bins.keys():
                    fig, ax = plt.subplots()
                    hist_var = bins[var]
                    n_var = sum(hist_var)
                    weights = hist_var / (interval_width[var] * n_var)
                    ax.hist(bin_edges[var][:-1], bin_edges[var], weights=weights, color=colors[subset])
                    ax.set_ylabel("probability density")
                    ax.set_xlabel("values")
                    ax.set_title(f"histogram {var} ({subset}, n={int(n_var)})")
                    pdf_pages.savefig(fig)
        finally:
            plt.close('all')  # close all open figures / plots

    def _plot_combined(self, add_name):
        """Write one page per variable comparing the probability densities of all subsets as lines."""
        from matplotlib.backends.backend_pdf import PdfPages
        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{add_name}.pdf")
        variables = next(iter(self.bins.values())).keys()
        colors = self.get_dataset_colors()
        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
        try:
            with PdfPages(plot_path) as pdf_pages:
                for var in variables:
                    fig, ax = plt.subplots()
                    for subset in self.bins.keys():
                        hist_var = self.bins[subset][var]
                        interval_width = self.interval_width[subset][var]
                        bin_edges = self.bin_edges[subset][var]
                        n_var = sum(hist_var)
                        weights = hist_var / (interval_width * n_var)
                        ax.plot(bin_edges[:-1] + 0.5 * interval_width, weights, label=f"{subset}",
                                c=colors[subset])
                    ax.set_ylabel("probability density")
                    ax.set_xlabel("values")
                    ax.legend(loc="upper right")
                    ax.set_title(f"histogram {var}")
                    pdf_pages.savefig(fig)
        finally:
            plt.close('all')  # close all open figures / plots
593@TimeTrackingWrapper
594class PlotPeriodogram(AbstractPlotClass): # pragma: no cover
595 """
596 Create Lomb-Scargle periodogram in raw input and target data. The Lomb-Scargle version can deal with missing values.
598 This plot routine is creating the following plots:
600 * "raw": data is not aggregated, 1 graph per variable
601 * "": single data lines are aggregated, 1 graph per variable
602 * "total": data is aggregated on all variables, single graph
604 If data consists on different sampling rates, a separate plot is create for each sampling.
606 .. image:: ../../../../../_source/_plots/periodogram.png
607 :width: 400
609 .. note::
610 This plot is not included in the default plot list. To use this plot, add "PlotPeriodogram" to the `plot_list`.
612 .. warning::
613 This plot is highly sensitive to the data handler structure. Therefore, it is highly likely that this method is
614 not compatible with any custom data handler. Proven data handlers are `DefaultDataHandler`,
615 `DataHandlerMixedSampling`, `DataHandlerMixedSamplingWithFilter`. To work properly, the data handler must have
616 the attribute `.id_class._data`.
618 """
620 def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="periodogram",
621 variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False):
622 super().__init__(plot_folder, plot_name)
623 self.variables_dim = variables_dim
624 self.time_dim = time_dim
626 for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling, sampling)):
627 self._sampling = s
628 self._add_text = {0: "input", 1: "target"}[pos]
629 multiple, label_names = self._has_filter_dimension(generator[0], pos)
630 self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing)
631 self._plot(raw=True)
632 self._plot(raw=False)
633 self._plot_total(raw=True)
634 self._plot_total(raw=False)
635 if multiple > 1:
636 self._plot_difference(label_names, plot_name_add="_last")
637 self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing,
638 use_last_input_value=False)
639 self._plot_difference(label_names, plot_name_add="_first")
641 @staticmethod
642 def _has_filter_dimension(g, pos):
643 """Inspect if filtered data is provided and return number and labels of filtered components."""
644 check_class = g.id_class
645 check_data = [check_class.get_X(as_numpy=False), check_class.get_Y(as_numpy=False)][pos]
646 if not hasattr(check_class, "filter_dim"): # data handler has no filtered data
647 return 1, []
648 else:
649 filter_dim = check_class.filter_dim
650 if filter_dim not in check_data.coords.dims: # current data has no filter (e.g. target data)
651 return 1, []
652 else:
653 return check_data.coords[filter_dim].shape[0], check_data.coords[filter_dim].values.tolist()
655 @TimeTrackingWrapper
656 def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False, use_last_input_value=True):
657 """
658 Create periodogram data.
659 """
660 self.raw_data = []
661 self.plot_data = []
662 self.plot_data_raw = []
663 self.plot_data_mean = []
664 iter = range(multiple if multiple == 1 else multiple + 1)
665 for m in iter:
666 plot_data_single = dict()
667 plot_data_raw_single = dict()
668 plot_data_mean_single = dict()
669 self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
670 raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing,
671 use_last_input_value=use_last_input_value)
672 for var in raw_data_single.keys():
673 pgram_com = []
674 pgram_mean = 0
675 all_data = raw_data_single[var]
676 pgram_mean_raw = np.zeros((len(self.f_index), len(all_data)))
677 for i, (f, pgram) in enumerate(all_data):
678 d = np.interp(self.f_index, f, pgram)
679 pgram_com.append(d)
680 pgram_mean += d
681 pgram_mean_raw[:, i] = d
682 pgram_mean /= len(all_data)
683 plot_data_single[var] = pgram_com
684 plot_data_mean_single[var] = (self.f_index, pgram_mean)
685 plot_data_raw_single[var] = (self.f_index, pgram_mean_raw)
686 self.plot_data.append(plot_data_single)
687 self.plot_data_mean.append(plot_data_mean_single)
688 self.plot_data_raw.append(plot_data_raw_single)
690 def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing):
691 """Implementation of data preprocessing using parallel variables element processing."""
692 raw_data_single = dict()
693 for g in generator:
694 if m == 0:
695 d = g.id_class._data
696 else:
697 gd = g.id_class
698 filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
699 d = (gd.input_data.sel(filter_sel), gd.target_data)
700 d = d[pos] if isinstance(d, tuple) else d
701 res = []
702 if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution
703 pool = multiprocessing.Pool(
704 min([psutil.cpu_count(logical=False), len(d[self.variables_dim].values),
705 16])) # use only physical cpus
706 output = [
707 pool.apply_async(f_proc,
708 args=(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
709 for var in d[self.variables_dim].values]
710 for i, p in enumerate(output):
711 res.append(p.get())
712 pool.close()
713 pool.join()
714 else: # serial solution
715 for var in d[self.variables_dim].values:
716 res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
717 for (var_str, f, pgram) in res:
718 if var_str not in raw_data_single.keys():
719 raw_data_single[var_str] = [(f, pgram)]
720 else:
721 raw_data_single[var_str] = raw_data_single[var_str] + [(f, pgram)]
722 return raw_data_single
724 def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing, use_last_input_value=True):
725 """Implementation of data preprocessing using parallel generator element processing."""
726 raw_data_single = dict()
727 res = []
728 if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution
729 pool = multiprocessing.Pool(
730 min([psutil.cpu_count(logical=False), len(generator), 16])) # use only physical cpus
731 output = [
732 pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim, self.f_index,
733 use_last_input_value))
734 for g in generator]
735 for i, p in enumerate(output):
736 res.append(p.get())
737 pool.close()
738 pool.join()
739 else:
740 for g in generator:
741 res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim, self.f_index, use_last_input_value))
742 for res_dict in res:
743 for k, v in res_dict.items():
744 if k not in raw_data_single.keys():
745 raw_data_single[k] = v
746 else:
747 raw_data_single[k] = raw_data_single[k] + v
748 return raw_data_single
750 @staticmethod
751 def _add_annotation_line(ax, pos, div, lims, unit):
752 for p in to_list(pos): # per year
753 ax.vlines(p / div, *lims, "black")
754 ax.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor")
756 def _format_figure(self, ax, var_name="total"):
757 """
758 Set log scale on both axis, add labels and annotation lines, and set title.
759 :param ax: current ax object
760 :param var_name: name of variable that will be included in the title
761 """
762 ax.set_yscale('log')
763 ax.set_xscale('log')
764 ax.set_ylabel("power spectral density", fontsize='x-large') # unit depends on variable: [unit^2 day^-1]
765 ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
766 lims = ax.get_ylim()
767 self._add_annotation_line(ax, [1, 2, 3], 365.25, lims, "yr") # per year
768 self._add_annotation_line(ax, 1, 365.25 / 12, lims, "m") # per month
769 self._add_annotation_line(ax, 1, 7, lims, "w") # per week
770 self._add_annotation_line(ax, [1, 0.5], 1, lims, "d") # per day
771 if self._sampling == "hourly":
772 self._add_annotation_line(ax, 2, 1, lims, "d") # per day
773 self._add_annotation_line(ax, [1, 0.5], 1 / 24., lims, "h") # per hour
774 title = f"Periodogram ({var_name})"
775 ax.set_title(title)
def _plot(self, raw=True):
    """
    Plot the periodograms of each variable into one pdf file, one page per variable.

    :param raw: if True, draw every single periodogram in light blue and the mean periodogram in blue. Otherwise draw
        the mean over all periodograms of a centred 5-point rolling mean in blue. The band between the averaged
        rolling minimum and maximum is filled in light blue.
    """
    plot_path = os.path.join(os.path.abspath(self.plot_folder),
                             f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf")
    plot_data = self.plot_data[0]
    plot_data_mean = self.plot_data_mean[0]
    try:
        # the context manager closes the pdf file even if drawing a page fails
        with matplotlib.backends.backend_pdf.PdfPages(plot_path) as pdf_pages:
            for var in plot_data.keys():
                fig, ax = plt.subplots()
                if raw is True:
                    for pgram in plot_data[var]:
                        ax.plot(self.f_index, pgram, "lightblue")
                    ax.plot(*plot_data_mean[var], "blue")
                else:
                    # rows: frequencies, columns: single periodograms -> smooth along the frequency axis
                    # (axis=0 is the rolling default; the explicit axis keyword is deprecated since pandas 2.1)
                    ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True)
                    mean = ma.mean().mean(axis=1).values.flatten()
                    upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
                    ax.plot(self.f_index, mean, "blue")
                    ax.fill_between(self.f_index, lower, upper, color="lightblue")
                self._format_figure(ax, var)
                pdf_pages.savefig(fig)
                # release the figure right away instead of keeping one open figure per variable
                plt.close(fig)
    finally:
        # close all open figures / plots
        plt.close('all')
def _plot_total(self, raw=True):
    """
    Plot the periodograms of all variables together on a single pdf page titled "total".

    The second entry of each variable's raw plot data (``self.plot_data_raw[0][var][1]``) is concatenated along the
    last axis. Each resulting column is treated as one periodogram over ``self.f_index``.

    :param raw: if True, draw every single periodogram in light blue and their mean in blue. Otherwise draw the mean
        of a centred 5-point rolling mean in blue and fill the band between the averaged rolling minimum and maximum.
    """
    plot_data_raw = self.plot_data_raw[0]
    if len(plot_data_raw) == 0:
        # nothing to concatenate: previously this failed with an AttributeError on None
        logging.info("... skip total periodogram plot because no data is available")
        return
    # concatenate once instead of growing the array inside a loop (columns keep the variable order)
    res = np.concatenate([plot_data_raw[var][1] for var in plot_data_raw.keys()], axis=-1)
    plot_path = os.path.join(os.path.abspath(self.plot_folder),
                             f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}_total.pdf")
    try:
        # the context manager closes the pdf file even if drawing fails
        with matplotlib.backends.backend_pdf.PdfPages(plot_path) as pdf_pages:
            fig, ax = plt.subplots()
            if raw is True:
                for i in range(res.shape[1]):
                    ax.plot(self.f_index, res[:, i], "lightblue")
                ax.plot(self.f_index, res.mean(axis=1), "blue")
            else:
                # axis=0 is the rolling default; the explicit axis keyword is deprecated since pandas 2.1
                ma = pd.DataFrame(np.vstack(res)).rolling(5, center=True)
                mean = ma.mean().mean(axis=1).values.flatten()
                upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
                ax.plot(self.f_index, mean, "blue")
                ax.fill_between(self.f_index, lower, upper, color="lightblue")
            self._format_figure(ax, "total")
            pdf_pages.savefig(fig)
    finally:
        # close all open figures / plots
        plt.close('all')
def _plot_difference(self, label_names, plot_name_add=""):
    """
    Compare the smoothed periodograms of the original data and of all filter components, one pdf page per variable.

    :param label_names: names of the filter components. "orig" is prepended so that entry i labels
        ``self.plot_data[i]``. A component named "unfiltered" is skipped because it equals the original data.
    :param plot_name_add: optional suffix appended to the plot file name
    """
    plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter{plot_name_add}.pdf"
    plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
    logging.info(f"... plotting {plot_name}")
    colors = ["grey", "blue", "red", "green", "orange", "purple", "black"]
    label_names = ["orig"] + list(label_names)
    max_iter = len(self.plot_data)
    var_keys = self.plot_data[0].keys()
    try:
        # the context manager closes the pdf file even if drawing a page fails
        with matplotlib.backends.backend_pdf.PdfPages(plot_path) as pdf_pages:
            for var in var_keys:
                fig, ax = plt.subplots()
                for i in reversed(range(max_iter)):
                    if label_names[i] == "unfiltered":
                        continue  # do not include the filter 'unfiltered' because this is equal to the 'orig' data
                    plot_data = self.plot_data[i]
                    # cycle through the colors instead of failing if there are more components than colors
                    c = colors[i % len(colors)]
                    # axis=0 is the rolling default; the explicit axis keyword is deprecated since pandas 2.1
                    ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True)
                    mean = ma.mean().mean(axis=1).values.flatten()
                    ax.plot(self.f_index, mean, c, label=label_names[i])
                    if i < 1:
                        # spread band only for the original data (drawn in "lightgrey")
                        upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
                        ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None)
                self._format_figure(ax, var)
                ax.legend(loc="upper center", ncol=max_iter)
                pdf_pages.savefig(fig)
                # release the figure right away instead of keeping one open figure per variable
                plt.close(fig)
    finally:
        # close all open figures / plots
        plt.close('all')
def f_proc(var, d_var, f_index, time_dim="datetime", use_last_value=True):  # pragma: no cover
    """
    Compute the Lomb-Scargle periodogram (psd normalisation) of a single variable.

    :param var: variable identifier, returned as its string representation
    :param d_var: data of the variable with the time dimension ``time_dim``
    :param f_index: frequencies (in 1/day) at which the periodogram is evaluated
    :param time_dim: name of the time dimension
    :param use_last_value: if ``d_var`` has further dimensions besides time, select the maximum (True) or the
        minimum (False) coordinate value of each of them
    :return: tuple of (variable name as string, f_index, periodogram power)
    """
    time_coord = d_var[time_dim]
    # elapsed time since the first sample, cast to whole hours and expressed in days
    t = (time_coord - time_coord[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
    if len(d_var.shape) > 1:  # reduce remaining dimensions (e.g. max(window) -> latest value)
        for dim in to_list(remove_items(d_var.coords.dims, time_dim)):
            coord = d_var[dim]
            d_var = d_var.sel({dim: coord.max() if use_last_value is True else coord.min()})
    lomb_scargle = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd")
    return str(var), f_index, lomb_scargle.power(f_index)
def f_proc_2(g, m, pos, variables_dim, time_dim, f_index, use_last_value):  # pragma: no cover
    """
    Calculate the periodograms of all variables provided by a single data handler.

    :param g: data handler. Its attributes with "id_class" in their name are processed (only "id_class" if
        ``pos != 0``); attributes containing "unfiltered" are loaded/cleaned but not evaluated.
    :param m: 0 to use the unfiltered data. Otherwise (index + 1) of the entry along the "filter" coordinate of
        ``input_data``; targets are then taken from ``target_data``.
    :param pos: index into the (input, target) tuple, i.e. 0 for inputs and 1 for targets
    :param variables_dim: name of the variables dimension
    :param time_dim: name of the time dimension (used to drop missing values and as time axis of the periodogram)
    :param f_index: frequencies at which the periodograms are evaluated
    :param use_last_value: forwarded to f_proc
    :return: dict mapping each variable name to a one-element list ``[(f_index, pgram)]``
    :raises KeyError: if more than one id class provides the same variable
    """
    # load lazy data
    if pos == 0:
        id_classes = [name for name in dir(g) if "id_class" in name]
    else:
        id_classes = ["id_class"]
    for id_cls_name in id_classes:
        id_cls = getattr(g, id_cls_name)
        if hasattr(id_cls, "lazy") and id_cls.lazy is True:
            id_cls.load_lazy()

    raw_data_single = dict()
    for dh in [name for name in id_classes if "unfiltered" not in name]:
        current_cls = getattr(g, dh)
        if m == 0:
            d = current_cls._data
            if d is None:
                # no stored data: rebuild (input, target) from the last history and the first label window entry
                window_dim = current_cls.window_dim
                history = current_cls.history
                last_entry = history.coords[window_dim][-1]
                d1 = history.sel({window_dim: last_entry}, drop=True)
                label = current_cls.label
                first_entry = label.coords[window_dim][0]
                d2 = label.sel({window_dim: first_entry}, drop=True)
                d = (d1, d2)
        else:
            filter_sel = {"filter": current_cls.input_data.coords["filter"][m - 1]}
            d = (current_cls.input_data.sel(filter_sel), current_cls.target_data)
        d = d[pos] if isinstance(d, tuple) else d
        for var in d[variables_dim].values:
            d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim)
            # forward time_dim, otherwise f_proc silently falls back to its default "datetime"
            var_str, f, pgram = f_proc(var, d_var, f_index, time_dim=time_dim, use_last_value=use_last_value)
            if var_str in raw_data_single:
                raise KeyError(f"There are multiple pgrams for key {var_str}. Please check your data handler.")
            raw_data_single[var_str] = [(f, pgram)]

    # perform clean up
    for id_cls_name in id_classes:
        id_cls = getattr(g, id_cls_name)
        if hasattr(id_cls, "lazy") and id_cls.lazy is True:
            id_cls.clean_up()

    return raw_data_single
def f_proc_hist(data, variables, n_bins, variables_dim):  # pragma: no cover
    """
    Compute a histogram for each requested variable that is present in ``data``.

    :param data: data with a coordinate named ``variables_dim``
    :param variables: variables to compute histograms for; variables missing in ``data`` are skipped
    :param n_bins: number of bins passed to ``np.histogram``
    :param variables_dim: name of the variables dimension
    :return: three dicts keyed by variable: bin counts, width of the first bin, and bin edges
    """
    counts, widths, edges = {}, {}, {}
    for var in variables:
        if var not in data.coords[variables_dim]:
            continue
        # multi-dimensional data is reduced to the selected variable, 1D data is used as it is
        selected = data.sel({variables_dim: var}).squeeze() if len(data.shape) > 1 else data
        counts[var], edges[var] = np.histogram(selected.values, n_bins)
        widths[var] = edges[var][1] - edges[var][0]
    return counts, widths, edges
class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot climate FIR filter components.

    * Creates a separate folder climFIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the climate FIR input, the filter response, the true non-causal (ideal) filter
      input, and the corresponding ideal response (containing information about future)
    * A filter residuum plot includes the climate FIR residuum and the ideal filter residuum.
    """

    def __init__(self, plot_folder, plot_data, sampling, name):
        """
        Create all climate FIR filter plots and store the plot data.

        :param plot_folder: base plot directory; plots are saved into its subfolder "climFIR". Nothing is done if None.
        :param plot_data: list (indexed by filter component) of lists of dicts providing the keys "var", "t0",
            "filter_input", "filter_input_nc", "valid_range", "time_range", "new_dim", and "h"
        :param sampling: sampling of the data; only "1d" and "1H" are mapped to time units (day / hour)
        :param name: name that is included in all plot file names
        """
        from mlair.helpers.filter import fir_filter_convolve

        logging.info(f"start PlotClimateFirFilter for ({name})")

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}
        if plot_folder is None:
            return

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "ideal": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        self.variables_list = []
        plot_folder = os.path.join(os.path.abspath(plot_folder), "climFIR")
        self.fir_filter_convolve = fir_filter_convolve
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict, new_dim = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict, sampling, new_dim)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        :param data: list (indexed by filter component) of lists of plot data dicts
        :return: nested dict var -> t0 -> filter component index -> plot information, and the name of the new
            dimension (which must be identical for all entries)
        """
        plot_dict = {}
        new_dim = None
        for i, plot_data in enumerate(data):
            for p_d in plot_data:
                var = p_d.get("var")
                t0 = p_d.get("t0")
                filter_input = p_d.get("filter_input")
                filter_input_nc = p_d.get("filter_input_nc")
                valid_range = p_d.get("valid_range")
                time_range = p_d.get("time_range")
                if new_dim is None:
                    new_dim = p_d.get("new_dim")
                else:
                    assert new_dim == p_d.get("new_dim")
                h = p_d.get("h")
                plot_dict_var = plot_dict.get(var, {})
                plot_dict_t0 = plot_dict_var.get(t0, {})
                plot_dict_order = {"filter_input": filter_input,
                                   "filter_input_nc": filter_input_nc,
                                   "valid_range": valid_range,
                                   "time_range": time_range,
                                   "order": len(h), "h": h}
                plot_dict_t0[i] = plot_dict_order
                plot_dict_var[t0] = plot_dict_t0
                plot_dict[var] = plot_dict_var
        self.variables_list = list(plot_dict.keys())
        return plot_dict, new_dim

    def _plot(self, plot_dict, sampling, new_dim="window"):
        """
        Draw input/response plot and residuum plot for each variable, t0 and filter component.

        A failure is logged and skips the remaining filter components of the affected (variable, t0) pair.
        """
        td_type = {"1d": "D", "1H": "h"}.get(sampling)
        for var, vis_dict in plot_dict.items():
            for it0, t0 in enumerate(vis_dict.keys()):
                vis_data = vis_dict[t0]
                residuum_true = None
                try:
                    for ifilter in sorted(vis_data.keys()):
                        data = vis_data[ifilter]
                        filter_input = data["filter_input"]
                        # the ideal input of each further component is the ideal residuum of the previous one
                        filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel(
                            {new_dim: filter_input.coords[new_dim]})
                        valid_range = data["valid_range"]
                        time_axis = data["time_range"]
                        filter_order = data["order"]
                        h = data["h"]
                        fig, ax = plt.subplots()

                        # plot backgrounds
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)

                        # original data
                        self._plot_original_data(ax, time_axis, filter_input_nc)

                        # clim apriori
                        self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter, offset=1)

                        # clim filter response
                        residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h,
                                                                    output_dtypes=filter_input.dtype)

                        # ideal filter response
                        residuum_true = self._plot_ideal_filter(ax, time_axis, filter_input_nc, new_dim, h,
                                                                output_dtypes=filter_input.dtype)

                        # set title, legend, and save plot
                        xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis)

                        plt.title(f"Input of ClimFilter ({str(var)})")
                        plt.legend()
                        fig.autofmt_xdate()
                        plt.tight_layout()
                        self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}"
                        self._save()

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)
                        self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal")
                        self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim")
                        ax.set_xlim(xlims)
                        plt.title(f"Residuum of ClimFilter ({str(var)})")
                        plt.legend(loc="upper left")
                        fig.autofmt_xdate()
                        plt.tight_layout()

                        self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
                        self._save()
                except Exception as e:
                    # best effort: log the failure with its full traceback and continue with the next example
                    logging.info(f"Could not create plot because of:\n{type(e)}\n{e}", exc_info=True)
                    # discard figures left half-drawn by the failed attempt
                    plt.close("all")

    def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis):
        """
        Set xlims

        Use order and valid_range to find a good zoom in that hides edges of filter values that are affected by reduced
        filter order. Limits are returned to be usable for other plots.
        """
        t_minus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order),
                            (-valid_range.start + 0.3 * order))
        t_plus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order),
                           valid_range.stop + 0.3 * order)
        t_minus = t0 + np.timedelta64(-int(t_minus_delta), td_type)
        t_plus = t0 + np.timedelta64(int(t_plus_delta), td_type)
        ax_start = max(t_minus, time_axis[0])
        ax_end = min(t_plus, time_axis[-1])
        ax.set_xlim((ax_start, ax_end))
        return ax_start, ax_end

    def _plot_valid_area(self, ax, t0, valid_range, td_type):
        """Shade the span from t0 + valid_range.start to t0 + valid_range.stop - 1."""
        ax.axvspan(t0 + np.timedelta64(valid_range.start, td_type),
                   t0 + np.timedelta64(valid_range.stop - 1, td_type), **self.style_dict["valid_area"])

    def _plot_t0(self, ax, t0):
        """Mark t0 with a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot data over time_axis using the named entry of style_dict."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_original_data(self, ax, time_axis, data):
        """Plot the original (non-causal) filter input."""
        self._plot_series(ax, time_axis, data.values.flatten(), style="original")

    def _plot_apriori(self, ax, time_axis, data, new_dim, ifilter, offset):
        """
        Plot the climatological a priori estimate of the future.

        For the first filter component only values from ``offset`` onwards along ``new_dim`` are drawn. The values
        are right-aligned to the end of ``time_axis``.
        """
        if ifilter == 0:
            d_tmp = data.sel({new_dim: slice(offset, data.coords[new_dim].values.max())}).values.flatten()
        else:
            d_tmp = data.values.flatten()
        self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, style="apriori")

    def _plot_clim_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """Plot the clim filter response of data along new_dim and return the estimated residuum (data - response)."""
        filt = xr.apply_ufunc(self.fir_filter_convolve, data,
                              input_core_dims=[[new_dim]],
                              output_core_dims=[[new_dim]],
                              vectorize=True,
                              kwargs={"h": h},
                              output_dtypes=[output_dtypes])
        self._plot_series(ax, time_axis, filt.values.flatten(), style="clim")
        residuum_estimated = data - filt
        return residuum_estimated

    def _plot_ideal_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """Plot the ideal filter response of data along new_dim and return the true residuum (data - response)."""
        filt = xr.apply_ufunc(self.fir_filter_convolve, data,
                              input_core_dims=[[new_dim]],
                              output_core_dims=[[new_dim]],
                              vectorize=True,
                              kwargs={"h": h},
                              output_dtypes=[output_dtypes])
        self._plot_series(ax, time_axis, filt.values.flatten(), style="ideal")
        residuum_true = data - filt
        return residuum_true

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        # NOTE(review): the variable names are joined directly with "plot_data.pickle" (no separator) and the name
        # does not contain self._name, so instances plotting the same variables overwrite each other's file.
        file = os.path.join(self.plot_folder, "_".join(map(str, self.variables_list)) + "plot_data.pickle")
        with open(file, "wb") as f:
            dill.dump(data, f)
class PlotFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot FIR filter components.

    * Creates a separate folder FIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the FIR input and the filter response
    * A filter residuum plot includes the FIR residuum
    """

    def __init__(self, plot_folder, plot_data, name):
        """
        Create all FIR filter plots and store the plot data.

        :param plot_folder: base plot directory; plots are saved into its subfolder "FIR". Nothing is done if None.
        :param plot_data: list (indexed by filter component) of lists (one entry per t0) of dicts providing the keys
            "t0", "filter_input", "filtered", "var_dim", and "time_dim"
        :param name: name that is included in all plot file names
        """
        logging.info(f"start PlotFirFilter for ({name})")

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}
        if plot_folder is None:
            return

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "FIR": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        plot_folder = os.path.join(os.path.abspath(plot_folder), "FIR")
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        :param data: list (indexed by filter component) of lists (one entry per t0) of plot data dicts
        :return: nested dict var -> t0 -> filter component index -> {"filter_input", "filtered", "time_dim"}, where
            the variable dimension is dropped from both arrays
        """
        plot_dict = {}
        for i, component_data in enumerate(data):  # filter component
            for plot_data in component_data:  # t0 counter
                t0 = plot_data.get("t0")
                filter_input = plot_data.get("filter_input")
                filtered = plot_data.get("filtered")
                var_dim = plot_data.get("var_dim")
                time_dim = plot_data.get("time_dim")
                for var in filtered.coords[var_dim].values:
                    plot_dict_var = plot_dict.get(var, {})
                    plot_dict_t0 = plot_dict_var.get(t0, {})
                    plot_dict_order = {"filter_input": filter_input.sel({var_dim: var}, drop=True),
                                       "filtered": filtered.sel({var_dim: var}, drop=True),
                                       "time_dim": time_dim}
                    plot_dict_t0[i] = plot_dict_order
                    plot_dict_var[t0] = plot_dict_t0
                    plot_dict[var] = plot_dict_var
        return plot_dict

    def _plot(self, plot_dict):
        """
        Draw input/response plot and residuum plot for each variable, t0 and filter component.

        A failure is logged and skips the remaining filter components of the affected (variable, t0) pair.
        """
        for var, viz_data_dict in plot_dict.items():
            for it0, t0 in enumerate(viz_data_dict.keys()):
                viz_data = viz_data_dict[t0]
                try:
                    for ifilter in sorted(viz_data.keys()):
                        data = viz_data[ifilter]
                        filter_input = data["filter_input"]
                        filtered = data["filtered"]
                        time_dim = data["time_dim"]
                        time_axis = filtered.coords[time_dim].values
                        fig, ax = plt.subplots()

                        # plot backgrounds
                        self._plot_t0(ax, t0)

                        # original data
                        self._plot_data(ax, time_axis, filter_input, style="original")

                        # filter response
                        self._plot_data(ax, time_axis, filtered, style="FIR")

                        # set title, legend, and save plot
                        ax.set_xlim((time_axis[0], time_axis[-1]))

                        plt.title(f"Input of Filter ({str(var)})")
                        plt.legend()
                        fig.autofmt_xdate()
                        plt.tight_layout()
                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}"
                        self._save()

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_t0(ax, t0)
                        self._plot_data(ax, time_axis, filter_input - filtered, style="FIR")
                        ax.set_xlim((time_axis[0], time_axis[-1]))
                        plt.title(f"Residuum of Filter ({str(var)})")
                        plt.legend(loc="upper left")
                        fig.autofmt_xdate()
                        plt.tight_layout()

                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
                        self._save()
                except Exception as e:
                    # best effort: log the failure with its full traceback and continue with the next example
                    logging.info(f"Could not create plot because of:\n{type(e)}\n{e}", exc_info=True)
                    # discard figures left half-drawn by the failed attempt
                    plt.close("all")

    def _plot_t0(self, ax, t0):
        """Mark t0 with a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot data over time_axis using the named entry of style_dict."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_data(self, ax, time_axis, data, style="original"):
        """Plot the flattened values of data over time_axis."""
        self._plot_series(ax, time_axis, data.values.flatten(), style=style)

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        # NOTE(review): the file name does not contain self._name, so every instance writing to the same FIR folder
        # overwrites the previous file.
        file = os.path.join(self.plot_folder, "plot_data.pickle")
        with open(file, "wb") as f:
            dill.dump(data, f)