Coverage for mlair/plotting/abstract_plot_class.py: 100%
2 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
"""
Abstract base class for plots created during preprocessing and postprocessing.

The class defined here bundles the shared plot workflow: plot folder handling, matplotlib rc parameter setup and
saving the finished figure.
"""
__author__ = "Lukas Leufen"
__date__ = "2021-04-13"
5import logging
6import os
8from matplotlib import pyplot as plt
class AbstractPlotClass:  # pragma: no cover
    """
    Abstract class for all plotting routines to unify plot workflow.

    Each inheritance requires a _plot method. Create a plot class like:

    .. code-block:: python

        class MyCustomPlot(AbstractPlotClass):

            def __init__(self, plot_folder, *args, **kwargs):
                super().__init__(plot_folder, "custom_plot_name")
                self._data = self._prepare_data(*args, **kwargs)
                self._plot(*args, **kwargs)
                self._save()

            def _prepare_data(self, *args, **kwargs):
                <your custom data preparation>
                return data

            def _plot(self, *args, **kwargs):
                <your custom plotting without saving>

    The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are
    using matplotlib's PdfPages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per
    default (can be set in super class initialisation).

    Methods like the shown _prepare_data() are optional. The only method required to implement is _plot.

    If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot
    class. It will log the spent time if you call your plotting without saving the returned object.

    .. code-block:: python

        @TimeTrackingWrapper
        class MyCustomPlot(AbstractPlotClass):
            pass

    Let's assume it takes a while to create this very special plot.

    >>> MyCustomPlot(plot_folder)
    INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss)

    """

    def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None):
        """
        Set up plot folder and name, plot resolution and matplotlib rc parameters.

        :param plot_folder: directory to store plots in. It is converted to an absolute path and created (including
            missing parents) if it does not exist yet.
        :param plot_name: file name of the plot without extension. Slashes are replaced by underscores so the name
            cannot point into a sub directory. None is kept as None (note that the default _save would then write
            a file called "None.pdf").
        :param resolution: resolution in dpi used by _save (default 500)
        :param rc_params: dictionary of matplotlib rc parameters applied globally via plt.rcParams. If None, the
            label, tick, legend, and title font sizes are set to 'large'.
        """
        plot_folder = os.path.abspath(plot_folder)
        # exist_ok avoids a race between a separate existence check and the creation (e.g. parallel plot jobs)
        os.makedirs(plot_folder, exist_ok=True)
        self.plot_folder = plot_folder
        self.plot_name = plot_name.replace("/", "_") if plot_name is not None else plot_name
        self.resolution = resolution
        if rc_params is None:
            rc_params = {'axes.labelsize': 'large',
                         'xtick.labelsize': 'large',
                         'ytick.labelsize': 'large',
                         'legend.fontsize': 'large',
                         'axes.titlesize': 'large',
                         }
        self.rc_params = rc_params
        self._update_rc_params()

    def __del__(self):
        """Close all open matplotlib figures (not only this plot's) when the object is garbage collected."""
        try:
            plt.close('all')
        except ImportError:
            # NOTE(review): presumably raised when this runs during interpreter shutdown; closing is best-effort
            pass

    def _plot(self, *args, **kwargs):
        """Create the plot without saving it. Needs to be implemented in every inheritance."""
        raise NotImplementedError(f"{type(self).__name__} has to implement the _plot method.")

    def _save(self, **kwargs):
        """
        Store plot locally as <plot_folder>/<plot_name>.pdf and close all figures afterwards.

        Name of and path to plot need to be set on initialisation.

        :param kwargs: additional keyword arguments forwarded to plt.savefig. dpi defaults to self.resolution but can
            be overridden here.
        """
        plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
        logging.debug("... save plot to %s", plot_name)
        # setdefault lets callers override dpi instead of raising a duplicate keyword TypeError
        kwargs.setdefault("dpi", self.resolution)
        plt.savefig(plot_name, **kwargs)
        plt.close('all')

    def _update_rc_params(self):
        """Apply self.rc_params to the global matplotlib rc configuration."""
        plt.rcParams.update(self.rc_params)

    @staticmethod
    def _get_sampling(sampling, pos=1):
        """
        Normalise sampling information and derive a short frequency letter.

        :param sampling: either a single sampling string (used for both entries) or an indexable pair of sampling
            strings
        :param pos: index of the entry in the sampling pair that determines the returned letter (default 1)
        :return: tuple of (sampling pair, letter) where letter is "H" for "hourly", "d" for "daily" and "" for any
            other value
        """
        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
        return sampling, sampling_letter

    @staticmethod
    def get_dataset_colors():
        """
        Standard colors used for train-, val-, and test-sets during postprocessing.

        :return: new dictionary mapping the subset names "train", "val", "test", and "train_val" to hex color codes
        """
        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"}  # hex code
        return colors