Coverage for mlair/plotting/training_monitoring.py: 100%
61 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
1"""Plots to monitor training."""
3__author__ = 'Felix Kleinert, Lukas Leufen'
4__date__ = '2019-12-11'
6from typing import Union, Dict, List
8import tensorflow.keras as keras
9import matplotlib
10import matplotlib.pyplot as plt
11import pandas as pd
13from mlair.model_modules.keras_extensions import LearningRateDecay
14from mlair.helpers.helpers import relative_round
16# matplotlib.use('Agg')
# Accepted training history input: either a plain dict or a keras History callback (its `.history` dict is used).
history_object = Union[Dict, keras.callbacks.History]
# Accepted learning rate input: either a plain dict or a LearningRateDecay object (its `.lr` attribute is used).
lr_object = Union[Dict, LearningRateDecay]
class PlotModelHistory:
    """
    Plot history of all plot_metrics (default: loss) for a training event.

    By default, plot_metric and val_plot_metric are plotted. If further metrics are provided (name must somehow
    include the word `<plot_metric>`), this additional information is added to the plot with a separate y-axis scale
    on the right side (shared for all additional metrics). The plot is saved locally. For a proper saving behaviour,
    the parameter filename must include the absolute path for the plot.
    """

    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False,
                 epoch_best: int = None):
        """
        Set attributes and create plot.

        :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
            format ending like .pdf or .png to work.
        :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from
        :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, loss, default: loss)
        :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences (default:
            False -> look for losses from all branches, not only from main)
        :param epoch_best: indicator at which epoch the best train result was achieved (should start counting at 0)
        """
        if isinstance(history, keras.callbacks.History):
            history = history.history
        self._data = pd.DataFrame.from_dict(history)
        # shift the index so that epochs are shown 1-based on the x-axis
        self._data.index += 1
        self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
        self._additional_columns = self._filter_columns(history)
        self._epoch_best = epoch_best
        self._plot(filename)

    def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True):
        """
        Find the history key that is used as main plot metric.

        The short names `mse` and `mae` are expanded to `mean_squared_error` and `mean_absolute_error` first. If no
        key matches the expanded name, the lookup is repeated once with the name exactly as given. From all matching
        keys, the shortest one is returned.

        :param history: dict-like history to search the keys of
        :param plot_metric: name (or substring) of the metric to look for
        :param main_branch: if True, only keys containing 'main' (case insensitive) are considered
        :param correct_names: expand the short names `mse` / `mae` before searching

        :return: the shortest key of history that contains the metric name
        :raises IndexError: if no key of history matches the metric name
        """
        requested_metric = plot_metric
        if correct_names is True:
            aliases = {"mse": "mean_squared_error", "mae": "mean_absolute_error"}
            plot_metric = aliases.get(plot_metric.lower(), plot_metric)
        available_keys = [k for k in history.keys() if
                          plot_metric in k and (not main_branch or "main" in k.lower())]
        available_keys.sort(key=len)
        if available_keys:
            return available_keys[0]
        if correct_names is True:
            # the expanded name did not match anything, retry with the name as given by the caller
            return self._get_plot_metric(history, requested_metric, main_branch, correct_names=False)
        raise IndexError(f"No key containing '{requested_metric}' found in history "
                         f"(main_branch={main_branch}, available keys: {list(history.keys())})")

    def _filter_columns(self, history: Dict) -> List[str]:
        """
        Select only columns named like %<plot_metric>%.

        The default metrics '<plot_metric>' and 'val_<plot_metric>' are removed too.

        :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras
            History.history)

        :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>.
        :raises ValueError: if '<plot_metric>' or 'val_<plot_metric>' is not a key of history
        """
        cols = [key for key in history.keys() if self._plot_metric in key]
        # heuristic: there is always val_<plot_metric> and <plot_metric> available in cols, because this is generated by
        # the keras framework. If this metric isn't available the self._get_plot_metric() will fail before (but only
        # because it is executed before)
        cols.remove(f"val_{self._plot_metric}")
        cols.remove(self._plot_metric)
        return cols

    def _plot(self, filename: str) -> None:
        """
        Create plot.

        Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided, they will be added with
        an additional yaxis on the right side. The plot is saved in filename. If epoch_best is set, this epoch is
        highlighted by a star marker and the title reports the validation value of exactly this epoch. Otherwise, the
        title reports the validation value of the last epoch.

        :param filename: name (including total path) of the plot to save.
        """
        val_metric = f"val_{self._plot_metric}"
        ax = self._data[[self._plot_metric, val_metric]].plot(linewidth=0.7)
        if self._epoch_best is not None:
            # epoch_best counts from 0 (positional access), whereas the x-axis is 1-based
            final_res = self._data[val_metric].iloc[self._epoch_best]
            ax.scatter(self._epoch_best + 1, final_res, s=100, marker="*", c="black")
            # the annotation keeps the 0-based number as handed over by the caller
            annotation = f"best epoch {self._epoch_best}"
        else:
            final_res = self._data[val_metric].iloc[-1]
            annotation = "final"
        ax.set_yscale('log')
        if len(self._additional_columns) > 0:
            self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
        title = f"Model {self._plot_metric} (val, {annotation}): {relative_round(final_res, 5)}"
        ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
        ax.axhline(y=0, color="gray", linewidth=0.5)
        try:
            plt.tight_layout()
            plt.savefig(filename)
        finally:
            # always release the figure, also if saving fails (e.g. because of an invalid path)
            plt.close("all")
class PlotModelLearningRate:
    """
    Plot the behaviour of the learning rate in dependence of the number of epochs.

    The plot is saved locally under the given filename. For a proper saving behaviour, the parameter filename must
    include the absolute path for the plot.
    """

    def __init__(self, filename: str, lr_sc: lr_object):
        """
        Set attributes and create plot.

        :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
            format ending like .pdf or .png to work.
        :param lr_sc: the learning rate object (or a dict with `lr` as key) to plot from
        """
        # a LearningRateDecay object carries the values to plot in its `lr` attribute, a dict is used directly
        learning_rates = lr_sc.lr if isinstance(lr_sc, LearningRateDecay) else lr_sc
        self._data = pd.DataFrame.from_dict(learning_rates)
        self._plot(filename)

    def _plot(self, filename: str) -> None:
        """
        Create plot.

        Draw all columns of the stored data as lines, label the axes and save the figure in filename. All open
        figures are closed afterwards.

        :param filename: name (including total path) of the plot to save.
        """
        axes = self._data.plot(linewidth=0.7)
        axes.set_xlabel("epoch")
        axes.set_ylabel("learning rate")
        plt.tight_layout()
        plt.savefig(filename)
        plt.close("all")