Coverage for mlair/plotting/training_monitoring.py: 100%

61 statements  

coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

1"""Plots to monitor training.""" 

2 

3__author__ = 'Felix Kleinert, Lukas Leufen' 

4__date__ = '2019-12-11' 

5 

6from typing import Union, Dict, List 

7 

8import tensorflow.keras as keras 

9import matplotlib 

10import matplotlib.pyplot as plt 

11import pandas as pd 

12 

13from mlair.model_modules.keras_extensions import LearningRateDecay 

14from mlair.helpers.helpers import relative_round 

15 

16# matplotlib.use('Agg') 

17history_object = Union[Dict, keras.callbacks.History] 

18lr_object = Union[Dict, LearningRateDecay] 


class PlotModelHistory:
    """
    Plot the history of all plot_metrics (default: loss) for a training event.

    By default, plot_metric and val_plot_metric are plotted. If further metrics are provided (their names must
    include the word `<plot_metric>`), this additional information is added to the plot with a separate y-axis
    scale on the right side (shared by all additional metrics). The plot is saved locally. For proper saving
    behaviour, the parameter filename must include the absolute path for the plot. An illustrative example of a
    suitable history dict is sketched in the comment block after this class.
    """

    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False,
                 epoch_best: int = None):
        """
        Set attributes and create plot.

        :param filename: saving name of the plot to create (preferably an absolute path); the filename needs a
            format ending like .pdf or .png to work.
        :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from
        :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, loss; default: loss)
        :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences
            (default: False -> look for losses from all branches, not only from main)
        :param epoch_best: indicator at which epoch the best train result was achieved (counting starts at 0)
        """
        if isinstance(history, keras.callbacks.History):
            history = history.history
        self._data = pd.DataFrame.from_dict(history)
        self._data.index += 1
        self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
        self._additional_columns = self._filter_columns(history)
        self._epoch_best = epoch_best
        self._plot(filename)

    def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True):
        """Return the shortest history key containing plot_metric (restricted to 'main' keys if main_branch)."""
        _plot_metric = plot_metric
        if correct_names is True:
            if plot_metric.lower() == "mse":
                plot_metric = "mean_squared_error"
            elif plot_metric.lower() == "mae":
                plot_metric = "mean_absolute_error"
        available_keys = [k for k in history.keys() if
                          plot_metric in k and ("main" in k.lower() if main_branch else True)]
        available_keys.sort(key=len)
        if len(available_keys) == 0 and correct_names is True:
            # nothing matched the corrected name, retry once with the metric name as originally passed
            return self._get_plot_metric(history, _plot_metric, main_branch, correct_names=False)
        return available_keys[0]

    def _filter_columns(self, history: Dict) -> List[str]:
        """
        Select only columns named like %<plot_metric>%.

        The default metrics '<plot_metric>' and 'val_<plot_metric>' are removed too.

        :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from
            keras History.history)

        :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>.
        """
        cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
        # heuristic: val_<plot_metric> and <plot_metric> are always present in cols because keras generates them.
        # If the metric were missing, self._get_plot_metric() would already have failed (but only because it is
        # executed first).
        cols.remove(f"val_{self._plot_metric}")
        cols.remove(self._plot_metric)
        return cols

    def _plot(self, filename: str) -> None:
        """
        Create plot.

        Plots <plot_metric> and val_<plot_metric> by default. If more plot_metrics are provided, they are added
        with an additional y-axis on the right side. The plot is saved to filename.

        :param filename: name (including total path) of the plot to save.
        """
        ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
        if self._epoch_best is not None:
            ax.scatter(self._epoch_best + 1, self._data[[f"val_{self._plot_metric}"]].iloc[self._epoch_best],
                       s=100, marker="*", c="black")
        ax.set_yscale('log')
        if len(self._additional_columns) > 0:
            self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
        if self._epoch_best is not None:
            final_res = self._data[[f'val_{self._plot_metric}']].min().values[0]
            annotation = f"best epoch {self._epoch_best}"
        else:
            final_res = self._data[[f'val_{self._plot_metric}']].values[-1][0]
            annotation = "final"
        title = f"Model {self._plot_metric} (val, {annotation}): {relative_round(final_res, 5)}"
        ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
        ax.axhline(y=0, color="gray", linewidth=0.5)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close("all")
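
# Illustrative note (not part of the original module; the key names below are hypothetical): with a history dict
# such as
#     {"loss": [...], "val_loss": [...], "branch_1_loss": [...], "val_branch_1_loss": [...]}
# and plot_metric="loss", PlotModelHistory draws "loss" and "val_loss" on the left (logarithmic) axis and adds the
# remaining *loss* columns on a shared secondary axis; with main_branch=True, only keys whose names contain "main"
# are considered when resolving the plot metric.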


class PlotModelLearningRate:
    """
    Plot the behaviour of the learning rate as a function of the number of epochs.

    The plot is saved locally as pdf. For proper saving behaviour, the parameter filename must include the
    absolute path for the plot. A minimal usage sketch for both plot classes is given in the __main__ block at
    the end of this module.
    """

    def __init__(self, filename: str, lr_sc: lr_object):
        """
        Set attributes and create plot.

        :param filename: saving name of the plot to create (preferably an absolute path); the filename needs a
            format ending like .pdf or .png to work.
        :param lr_sc: the learning rate object (or a dict with `lr` as key) to plot from
        """
        if isinstance(lr_sc, LearningRateDecay):
            lr_sc = lr_sc.lr
        self._data = pd.DataFrame.from_dict(lr_sc)
        self._plot(filename)

    def _plot(self, filename: str) -> None:
        """
        Create plot.

        Plot the learning rate as a function of epoch.

        :param filename: name (including total path) of the plot to save.
        """
        ax = self._data.plot(linewidth=0.7)
        ax.set(xlabel="epoch", ylabel="learning rate")
        plt.tight_layout()
        plt.savefig(filename)
        plt.close("all")
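

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the history values, learning rates and file names
    # below are hypothetical and only illustrate how the two plot classes are typically called.
    demo_history = {"loss": [0.9, 0.5, 0.3, 0.25],
                    "val_loss": [1.0, 0.6, 0.4, 0.35],
                    "val_branch_1_loss": [1.1, 0.7, 0.5, 0.45]}
    # plots loss/val_loss on a log scale, marks the best epoch (epoch_best counts from 0) with a star and adds
    # the branch loss on a secondary axis
    PlotModelHistory(filename="demo_history_loss.pdf", history=demo_history, plot_metric="loss", epoch_best=3)

    demo_lr = {"lr": [1e-2, 1e-2, 1e-3, 1e-3]}
    # plots the learning rate per epoch
    PlotModelLearningRate(filename="demo_learning_rate.pdf", lr_sc=demo_lr)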