Coverage for mlair/plotting/abstract_plot_class.py: 100%

2 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:40 +0000

1"""Abstract plot class that should be used for preprocessing and postprocessing plots.""" 

2__author__ = "Lukas Leufen" 

3__date__ = '2021-04-13' 

4 

5import logging 

6import os 

7 

8from matplotlib import pyplot as plt 

9 

10 

class AbstractPlotClass:  # pragma: no cover
    """
    Abstract class for all plotting routines to unify plot workflow.

    Each inheritance requires a _plot method. Create a plot class like:

    .. code-block:: python

        class MyCustomPlot(AbstractPlotClass):

            def __init__(self, plot_folder, *args, **kwargs):
                super().__init__(plot_folder, "custom_plot_name")
                self._data = self._prepare_data(*args, **kwargs)
                self._plot(*args, **kwargs)
                self._save()

            def _prepare_data(self, *args, **kwargs):
                <your custom data preparation>
                return data

            def _plot(self, *args, **kwargs):
                <your custom plotting without saving>

    The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are
    using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be
    set in super class initialisation).

    Methods like the shown _prepare_data() are optional. The only method required to implement is _plot.

    If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot
    class. It will log the spent time if you call your plotting without saving the returned object.

    .. code-block:: python

        @TimeTrackingWrapper
        class MyCustomPlot(AbstractPlotClass):
            pass

    Let's assume it takes a while to create this very special plot.

    >>> MyCustomPlot()
    INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss)

    """

    def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None):
        """
        Set up plot folder and name, plot resolution (default 500dpi) and matplotlib rc parameters.

        :param plot_folder: directory to store plots in; converted to an absolute path and created if missing
        :param plot_name: file name of the plot without extension; every "/" is replaced by "_" so the name cannot
            point into a subdirectory. None is kept as None.
        :param resolution: dpi value passed to savefig in _save
        :param rc_params: dict of matplotlib rc parameters, applied to the global plt.rcParams on initialisation.
            Defaults to "large" font sizes for axis labels, tick labels, legend and axes titles.
        """
        self.plot_folder = self._prepare_plot_folder(plot_folder)
        # "/" would act as a path separator once the name is joined with the plot folder in _save
        self.plot_name = plot_name.replace("/", "_") if plot_name is not None else plot_name
        self.resolution = resolution
        if rc_params is None:
            rc_params = {'axes.labelsize': 'large',
                         'xtick.labelsize': 'large',
                         'ytick.labelsize': 'large',
                         'legend.fontsize': 'large',
                         'axes.titlesize': 'large',
                         }
        self.rc_params = rc_params
        self._update_rc_params()

    @staticmethod
    def _prepare_plot_folder(plot_folder):
        """Return plot_folder as absolute path and make sure this directory exists."""
        plot_folder = os.path.abspath(plot_folder)
        # exist_ok=True avoids the race of a separate existence check followed by creation (e.g. when several
        # processes create the same plot folder concurrently)
        os.makedirs(plot_folder, exist_ok=True)
        return plot_folder

    def __del__(self):
        """Close all open matplotlib figures (not only this object's) when the object is garbage-collected."""
        try:
            plt.close('all')
        except (ImportError, AttributeError, TypeError):
            # Best-effort cleanup: during interpreter shutdown matplotlib or the module global ``plt`` may already be
            # torn down. Exceptions raised in __del__ cannot be handled by any caller, so skip cleanup instead.
            pass

    def _plot(self, *args, **kwargs):
        """Abstract plot method (plotting only, no saving); needs to be implemented in inheritance."""
        raise NotImplementedError

    def _save(self, **kwargs):
        """
        Store the current matplotlib figure as pdf and close all figures afterwards.

        Name of and path to plot need to be set on initialisation. The file is written to
        <plot_folder>/<plot_name>.pdf using the resolution given on initialisation.

        :param kwargs: additional keyword arguments forwarded to plt.savefig
        """
        plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
        logging.debug(f"... save plot to {plot_name}")
        plt.savefig(plot_name, dpi=self.resolution, **kwargs)
        plt.close('all')

    def _update_rc_params(self):
        """Apply self.rc_params to matplotlib's global rcParams (affects all subsequently created plots)."""
        plt.rcParams.update(self.rc_params)

    @staticmethod
    def _get_sampling(sampling, pos=1):
        """
        Normalise sampling to a pair and look up a short frequency letter for one of its entries.

        :param sampling: either a single sampling string (used for both entries) or an indexable pair of strings
        :param pos: index of the entry to translate into a letter
        :return: tuple of (sampling pair, letter) where letter is "H" for "hourly", "d" for "daily" and "" otherwise
        """
        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
        return sampling, sampling_letter

    @staticmethod
    def get_dataset_colors():
        """
        Standard colors used for train-, val-, and test-sets during postprocessing.

        :return: new dict mapping set names ("train", "val", "test", "train_val") to hex color codes
        """
        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"}  # hex code
        return colors