Coverage for mlair/plotting/data_insight_plotting.py: 100%

8 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-12-18 17:51 +0000

1"""Collection of plots to get more insight into data.""" 

2__author__ = "Lukas Leufen, Felix Kleinert" 

3__date__ = '2021-04-13' 

4 

5from typing import List, Dict 

6import dill 

7import os 

8import logging 

9import multiprocessing 

10import psutil 

11import sys 

12 

13import numpy as np 

14import pandas as pd 

15import xarray as xr 

16import seaborn as sns 

17import matplotlib 

18# matplotlib.use("Agg") 

19from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates 

20from astropy.timeseries import LombScargle 

21 

22from mlair.data_handler import DataCollection 

23from mlair.helpers import TimeTrackingWrapper, to_list, remove_items 

24from mlair.plotting.abstract_plot_class import AbstractPlotClass 

25 

26 

@TimeTrackingWrapper
class PlotStationMap(AbstractPlotClass):  # pragma: no cover
    """
    Plot geographical overview of all used stations.

    Each element of ``generators`` is drawn as its own data set: every station of a data collection is marked at its
    (lon, lat) position, and the data set is added to the legend together with its number of stations. An element is
    either a data collection or a tuple ``(data_collection, plot_opts)``, where the dict ``plot_opts`` may overwrite
    the marker style via the keys ``marker``, ``ms``, ``mec`` and ``mfc``. Without an explicit ``mfc`` the face color is
    looked up by the collection's name in the data set colors of this plot class, falling back to blue.

    The background shows land, coastlines, lakes, ocean, rivers and country borders (cartopy features at 50m scale).
    The plot is saved under plot_folder with the name given by plot_name (default: station_map).

    .. image:: ../../../../../_source/_plots/station_map.png
        :width: 400
    """

    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
        """
        Set attributes and create plot.

        :param generators: list whose elements are either a data collection or a tuple (data collection, dict of plot
            options). If None, only the background map is drawn.
        :param plot_folder: path to save the plot (default: current directory)
        :param plot_name: name of the plot file (default: station_map)
        """
        super().__init__(plot_folder, plot_name)
        self._ax = None
        self._gl = None
        self._plot(generators)
        self._save(bbox_inches="tight")

    def _draw_background(self):
        """Draw coastline, lakes, ocean, rivers and country borders as background on the map."""

        import cartopy.feature as cfeature

        self._ax.add_feature(cfeature.LAND.with_scale("50m"))
        self._ax.natural_earth_shp(resolution='50m')
        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
        self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
        self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')

    def _plot_stations(self, generators):
        """
        Plot the position of every station of every data set and create the legend.

        Each data set gets one legend entry labelled "<name> (<number of stations>)".

        :param generators: list whose elements are either a data collection or a tuple (data collection, dict of plot
            options). Nothing is plotted if None.
        """

        import cartopy.crs as ccrs
        if generators is None:
            return
        legend_elements = []
        default_colors = self.get_dataset_colors()
        for element in generators:
            data_collection, plot_opts = self._get_collection_and_opts(element)
            name = data_collection.name or "unknown"
            marker = plot_opts.get("marker", "s")
            ms = plot_opts.get("ms", 6)
            mec = plot_opts.get("mec", "k")
            mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
            legend_elements.append(
                mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
                              label=f"{name} ({len(data_collection)})"))
            for station in data_collection:
                coords = station.get_coordinates()
                lon, lat = coords["lon"], coords["lat"]
                self._ax.plot(lon, lat, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
        if len(legend_elements) > 0:
            self._ax.legend(handles=legend_elements, loc='best')

    @staticmethod
    def _adjust_marker(marker):
        """
        Return the marker to use in the legend.

        Integer markers 4-11 are replaced by the corresponding triangle glyph ("<", ">", "^", "v"); all other markers
        are returned unchanged.
        """
        _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
        if isinstance(marker, int) and marker in _adjust.keys():
            return _adjust[marker]
        else:
            return marker

    @staticmethod
    def _get_collection_and_opts(element):
        """
        Split a generators element into (data collection, plot options).

        A plain data collection, a 1-tuple, or a tuple whose options entry is None yields an empty options dict.
        Additional tuple entries beyond the options dict are ignored.
        """
        if not isinstance(element, tuple):
            return element, {}
        collection, *rest = element  # an empty tuple raises ValueError, as before
        plot_opts = rest[0] if rest else None
        return collection, plot_opts if plot_opts is not None else {}

    def _plot(self, generators: List):
        """
        Create the station map plot.

        Set figure and call all required sub-methods.

        :param generators: list whose elements are either a data collection or a tuple (data collection, dict of plot
            options).
        """

        import cartopy.crs as ccrs
        from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
        fig = plt.figure(figsize=(10, 5))
        self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
        # gridlines are only placed at fixed positions (lon 0-20 in steps of 5, lat 44-58 in steps of 2)
        self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
        self._gl.xformatter = LONGITUDE_FORMATTER
        self._gl.yformatter = LATITUDE_FORMATTER
        self._draw_background()
        self._plot_stations(generators)
        self._adjust_extent()
        plt.tight_layout()

    def _adjust_extent(self):
        """
        Enlarge the current map extent by a margin on every side.

        The margin (in PlateCarree coordinates) is reference / side length of the smaller side, capped at 5, so small
        extents get a relatively larger margin.
        """
        import cartopy.crs as ccrs

        def diff(arr):
            return arr[1] - arr[0], arr[3] - arr[2]

        def find_ratio(delta, reference=5):
            # a degenerate (zero-width) side would raise ZeroDivisionError; treat it as "use maximum margin"
            ratios = [abs(reference / d) if d != 0 else np.inf for d in delta]
            return min(max(ratios), 5)

        extent = self._ax.get_extent(crs=ccrs.PlateCarree())
        ratio = find_ratio(diff(extent))
        new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
        self._ax.set_extent(new_extent, crs=ccrs.PlateCarree())

151 

152 

@TimeTrackingWrapper
class PlotAvailability(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plot similar to a Gantt chart.

    Each station of the given generators results in a new line in the plot. Data is summarised for the given temporal
    resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
    colored bar or a blank space.

    You can set different colors to highlight subsets, for example by providing different generators for the same
    index using different keys in the input dictionary.

    Note: each bar is surrounded by a small white box to highlight gaps in between. This can result in too long gaps
    in display if a gap is only very short. This also appears on a (fluent) transition from one subset to another.

    Calling this class will create three versions of the availability plot:

    1) Data availability for each element
    2) Data availability as summary over all elements (is there at least a single element for each time step)
    3) Combination of single and overall availability

    .. image:: ../../../../../_source/_plots/data_availability.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_summary.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_combined.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
        """
        Initialise and create all three availability plots.

        :param generators: dict with subset name as key and data collection as value
        :param plot_folder: path to save the plots
        :param sampling: temporal resolution, resolved via the plot class' sampling lookup
        :param summary_name: label of the summary bar
        :param time_dimension: name of the time dimension of the targets
        :param window_dimension: name of the window dimension of the targets (window 1 is evaluated)
        """
        # create standard Gantt plot for all stations (currently in single pdf file with single page)
        super().__init__(plot_folder, "data_availability")
        self.time_dim = time_dimension
        self.window_dim = window_dimension
        self.sampling = self._get_sampling(sampling)[1]
        self.linewidth = None
        if self.sampling == 'h':
            self.linewidth = 0.001
        plot_dict = self._prepare_data(generators)
        lgd = self._plot(plot_dict)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # create summary Gantt plot (is data in at least one station available)
        self.plot_name += "_summary"
        plot_dict_summary = self._summarise_data(generators, summary_name)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # combination of station and summary plot, last element is summary broken bar
        self.plot_name = "data_availability_combined"
        plot_dict_summary.update(plot_dict)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")

    def _get_label_availability(self, station) -> xr.DataArray:
        """Resample the station's targets to self.sampling and flag time steps with a valid value at window 1."""
        labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
        return labels.sel(**{self.window_dim: 1}).notnull()

    def _get_available_intervals(self, avail: xr.DataArray) -> List[tuple]:
        """
        Convert a boolean availability series into the intervals used by ``broken_barh``.

        Consecutive time steps with equal availability form a run. For every available run a tuple
        (first time stamp of run, number of time steps in run) is returned.
        """
        group = (avail != avail.shift({self.time_dim: 1})).cumsum()
        plot_data = pd.DataFrame({"avail": avail.values, "group": group.values},
                                 index=avail.coords[self.time_dim].values)
        # use .iloc for positional access: Series[int] on a non-integer index is a deprecated positional fallback
        runs = plot_data.groupby("group").apply(lambda x: (x["avail"].iloc[0], x.index[0], x.shape[0]))
        return [run[1:] for run in runs if run[0]]

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Collect the available time intervals of every station and subset.

        :return: dict {station name: {subset: [(start, number of time steps), ...]}}
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            for station in data_collection:
                intervals = self._get_available_intervals(self._get_label_availability(station))
                plt_dict.setdefault(str(station), {})[subset] = intervals
        return plt_dict

    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
        """
        Collect the time intervals where at least one station of a subset has data.

        :return: dict {summary_name: {subset: [(start, number of time steps), ...]}}; subsets without any station
            are omitted.
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            all_data = None
            for station in data_collection:
                labels_bool = self._get_label_availability(station)
                if all_data is None:
                    all_data = labels_bool
                else:
                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
                    # apply logical on merge and fill missing with all_data
                    all_data = np.logical_or(tmp, labels_bool).combine_first(all_data)
            if all_data is None:  # empty data collection, nothing to summarise
                continue
            plt_dict.setdefault(summary_name, {})[subset] = self._get_available_intervals(all_data)
        return plt_dict

    def _plot(self, plt_dict):
        """
        Draw one broken horizontal bar per entry and subset, and return the legend.

        :param plt_dict: dict {label: {subset: [(start, number of time steps), ...]}} as returned by _prepare_data or
            _summarise_data
        :return: legend artist (passed to _save to include it in the saved figure)
        """
        colors = self.get_dataset_colors()
        _used_colors = []
        pos = 0
        height = 0.8  # should be <= 1
        yticklabels = []
        number_of_stations = len(plt_dict.keys())
        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
        for station, d in sorted(plt_dict.items(), reverse=True):
            pos += 1
            for subset, color in colors.items():
                plt_data = d.get(subset)
                if plt_data is None:
                    continue
                elif color not in _used_colors:  # this is required for a proper legend creation
                    _used_colors.append(color)
                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
            yticklabels.append(station)

        ax.set_ylim([height, number_of_stations + 1])
        ax.set_yticks(np.arange(number_of_stations) + 1 + height / 2)
        ax.set_yticklabels(yticklabels)
        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors]
        # ncol must be >= 1, even if no subset matched any of the data set colors
        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center",
                         ncol=max(len(handles), 1))
        return lgd

278 

279 

@TimeTrackingWrapper
class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plots as histogram.

    For each station of each generator, the input data of the most recent window frame (history_dim = 0) and of the
    first element along the target dimension is checked for `notnull()` values along the datetime axis (boolean).
    Calling this class creates two different types of histograms with one graph per generator:

    1) data_availability_histogram_hist: datetime (x-axis) vs. number of stations with available data (y-axis)
    2) data_availability_histogram_hist_cum: number of samples (x-axis) vs. number of stations having at least this
       number of samples (y-axis)

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
        :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
                 subset_dim: str = 'DataSet', history_dim: str = 'window',
                 station_dim: str = 'Stations', ):
        """
        Prepare the availability data and create one plot per allowed plot type.

        :param generators: dict with subset name as key and data collection as value
        :param plot_folder: path to save the plots
        :param subset_dim: name of the new dimension holding the subsets
        :param history_dim: name of the window dimension of the inputs
        :param station_dim: name of the new dimension holding the stations
        """

        super().__init__(plot_folder, "data_availability_histogram")

        self.subset_dim = subset_dim
        self.history_dim = history_dim
        self.station_dim = station_dim

        self.freq = None
        self.temporal_dim = None
        self.target_dim = None
        self._prepare_data(generators)

        for plt_type in self.allowed_plot_types:
            plot_name_tmp = self.plot_name
            self.plot_name += '_' + plt_type
            self._plot(plt_type=plt_type)
            self._save()
            self.plot_name = plot_name_tmp

    def _set_dims_from_datahandler(self, data_handler):
        """Read time dimension, target dimension and sampling frequency from the data handler's id_class."""
        self.temporal_dim = data_handler.id_class.time_dim
        self.target_dim = data_handler.id_class.target_dim
        self.freq = self._get_sampling(data_handler.id_class.sampling)[1]

    @property
    def allowed_plot_types(self):
        """Plot types created by this class: 'hist' and 'hist_cum'."""
        plot_types = ['hist', 'hist_cum']
        return plot_types

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Prepare data to be used by plot methods.

        Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim.
        """
        avail_data_time_sum = {}
        avail_data_station_sum = {}
        dataset_time_interval = {}
        for subset, generator in generators.items():
            station_data_list = []
            for station in generator:
                self._set_dims_from_datahandler(data_handler=station)
                station_data_x = station.get_X(as_numpy=False)[0]
                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
                station_data_x = self._reduce_dims(station_data_x)
                station_data_list.append(station_data_x)
            # Concatenate the raw values and evaluate notnull() once: time steps missing for a station are filled with
            # NaN by the concat, and missing values inside a station stay NaN. Concatenating already boolean arrays
            # and calling notnull() again would turn False (missing value) into True.
            avail_data = xr.concat(station_data_list, dim=self.station_dim).notnull()
            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
            )
        avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(),
                                                                             name=self.subset_dim)
                                      )
        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
        self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(),
                                                                                      name=self.subset_dim))
        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
        self.dataset_time_interval = dataset_time_interval

    def _reduce_dims(self, dataset):
        """Select the first element of every dimension except temporal_dim and station_dim (if more than 2 dims)."""
        if len(dataset.dims) > 2:
            required = {self.temporal_dim, self.station_dim}
            unimportant = set(dataset.dims).difference(required)
            sel_dict = {un: dataset[un].values[0] for un in unimportant}
            dataset = dataset.loc[sel_dict]
        return dataset

    @staticmethod
    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
        """
        Return the first and last coordinate value of dimension dim_name.

        :param return_type: 'as_tuple' -> (first, last); 'as_dict' -> {'first': first, 'last': last}
        :raises TypeError: if xarray is not a DataArray or return_type is unknown
        """
        if isinstance(xarray, xr.DataArray):
            first = xarray.coords[dim_name].values[0]
            last = xarray.coords[dim_name].values[-1]
            if return_type == 'as_tuple':
                return first, last
            elif return_type == 'as_dict':
                return {'first': first, 'last': last}
            else:
                raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")
        else:
            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")

    @staticmethod
    def _make_full_time_index(irregular_time_index, freq):
        """Return a regular pd.DatetimeIndex from the first to the last element of irregular_time_index."""
        full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
        return full_time_index

    def _plot(self, plt_type='hist', *args):
        """
        Dispatch to the plot method of the given plot type.

        :raises ValueError: if plt_type is not 'hist' or 'hist_cum'
        """
        if plt_type == 'hist':
            self._plot_hist()
        elif plt_type == 'hist_cum':
            self._plot_hist_cum()
        else:
            raise ValueError(f"plt_type must be 'hist' or 'hist_cum', but is {plt_type}")

    def _plot_hist(self, *args):
        """Plot the number of stations with available data over time, one filled step graph per subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        for subset, interval in self.dataset_time_interval.items():
            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
                                                       self.temporal_dim: slice(interval['first'], interval['last'])})

            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])

        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                         facecolor='white', framealpha=1, edgecolor='black')
        for lgd_line in lgd.get_lines():
            lgd_line.set_linewidth(4.0)
        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
        plt.title('')
        plt.ylabel('Number of samples')
        plt.tight_layout()

    def _plot_hist_cum(self, *args):
        """Plot a reverse cumulative histogram: number of stations having at least a given number of samples."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        n_bins = int(self.avail_data_cum_sum.max().values)
        bins = np.arange(0, n_bins + 1)
        # draw subsets with the most samples first, so that smaller histograms stay visible on top
        max_per_subset = self.avail_data_cum_sum.max(dim=self.station_dim)
        descending_subsets = max_per_subset.sortby(max_per_subset, ascending=False).coords[self.subset_dim].values

        for subset in descending_subsets:
            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
                                                                              bins=bins,
                                                                              label=subset,
                                                                              cumulative=-1,
                                                                              color=colors[subset],
                                                                              )

        fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                   facecolor='white', framealpha=1, edgecolor='black')
        plt.title('')
        plt.ylabel('Number of stations')
        plt.xlabel('Number of samples')
        plt.xlim((bins[0], bins[-1]))
        plt.tight_layout()

449 

450 

@TimeTrackingWrapper
class PlotDataMonthlyDistribution(AbstractPlotClass):  # pragma: no cover
    """
    Plot the monthly distribution of the observations as box plots, one box per subset and month.

    For every station, the observation is transformed back with the data handler's inverse transformation and clipped
    at zero. Time stamps are then replaced by their calendar month.
    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", variables_dim="variables",
                 time_dim="datetime", window_dim="window", target_var: str = "", target_var_unit: str = "ppb"):
        """
        Set attributes and create plot.

        :param generators: dict with subset name as key and data collection as value
        :param plot_folder: path to save the plot
        :param target_var: short name of the target variable (e.g. "o3"), used for the y-axis label
        :param target_var_unit: unit shown in the y-axis label
        """
        super().__init__(plot_folder, "monthly_data_distribution")
        self.variables_dim = variables_dim
        self.time_dim = time_dim
        self.window_dim = window_dim
        self.coll_dim = "coll"
        self.subset_dim = "subset"
        self._data = self._prepare_data(generators)
        self._plot(target_var, target_var_unit)
        self._save()

    def _prepare_data(self, generators) -> pd.DataFrame:
        """
        Pre-process data required to plot.

        :param generators: dict with subset name as key and data collection as value
        :return: data frame with columns "month", "subset" and "values" holding all observations of all stations.
            Subsets without any station are skipped.
        """
        forecasts = []
        for set_type, generator in generators.items():
            forecasts_set = None
            for gen in generator:
                data = gen.get_observation()
                data = gen.apply_transformation(data, inverse=True)
                data = data.clip(min=0).reset_coords(drop=True)
                # datetime64[M] counts months since 1970-01, so "% 12 + 1" gives the calendar month (1 = January)
                new_index = data.coords[self.time_dim].values.astype("datetime64[M]").astype(int) % 12 + 1
                data = data.assign_coords({self.time_dim: new_index})
                forecasts_set = xr.concat([forecasts_set, data], self.time_dim) if forecasts_set is not None else data
            if forecasts_set is None:  # subset without any station: nothing to add
                continue
            forecasts_monthly = {}
            for month in set(forecasts_set.coords[self.time_dim].values):
                monthly_values = forecasts_set.sel({self.time_dim: month}).values
                forecasts_monthly[month] = np.concatenate((forecasts_monthly.get(month, []), monthly_values))
            forecasts_monthly = pd.DataFrame.from_dict(forecasts_monthly, orient="index")
            forecasts_monthly[self.coll_dim] = set_type
            forecasts.append(forecasts_monthly.set_index(self.coll_dim, append=True))
        forecasts = pd.concat(forecasts).stack().rename_axis(["month", "subset", "index"])
        forecasts = forecasts.to_frame(name="values").reset_index(level=[0, 1])
        return forecasts

    @staticmethod
    def _spell_out_chemical_concentrations(short_name: str, add_concentration: bool = False):
        """
        Return the long name of a chemical species, e.g. "ozone" for "o3".

        Unknown short names (including the empty default target_var) are returned unchanged instead of raising a
        KeyError.

        :param add_concentration: append " concentration" to the name
        """
        short2long = {'o3': 'ozone', 'no': 'nitrogen oxide', 'no2': 'nitrogen dioxide', 'nox': 'nitrogen oxides'}
        _suffix = "" if add_concentration is False else " concentration"
        return f"{short2long.get(short_name, short_name)}{_suffix}"

    def _plot(self, target_var: str, target_var_unit: str):
        """
        Create a monthly grouped box plot over all stations with separate boxes for each subset.

        :param target_var: short name of the target variable, spelled out on the plot's y-axis
        :param target_var_unit: unit shown on the plot's y-axis
        """
        ax = sns.boxplot(data=self._data, x="month", y="values", hue="subset", whis=1.5,
                         palette=self.get_dataset_colors(), flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
                         meanprops={'markersize': 1, 'markeredgecolor': 'k'})
        ylabel = self._spell_out_chemical_concentrations(target_var)
        ax.set(xlabel='month', ylabel=f'dma8 {ylabel} (in {target_var_unit})')
        plt.tight_layout()

514 

515 

@TimeTrackingWrapper
class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Plot histogram on transformed input and target data.

    This data is the same that the model sees during training. No plots are created for the original value space
    (raw / unformatted data). This plot method will create a histogram for input and target each comparing the
    subsets train, val and test, as well as a distinct one for each of the three subsets.

    .. image:: ../../../../../_source/_plots/datahistogram.png
        :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
                 variables_dim="variables", time_dim="datetime", window_dim="window", upsampling=False):
        """
        Calculate histograms and create all input and target plots.

        :param generators: dict with subset name as key and data collection as value (not modified)
        :param upsampling: if True and a "train" subset exists, additionally plot the upsampled train data
        """
        super().__init__(plot_folder, plot_name)
        self.variables_dim = variables_dim
        self.time_dim = time_dim
        self.window_dim = window_dim
        self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim)
        self.bins = {}
        self.interval_width = {}
        self.bin_edges = {}
        # work on a shallow copy so that adding "train_upsampled" does not modify the caller's dict
        generators = dict(generators)
        if upsampling is True:
            self._handle_upsampling(generators)

        # input plots
        for branch_pos in range(number_of_branches):
            self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos)
            add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}"
            for subset in generators.keys():
                self._plot(add_name=add_name, subset=subset)
            self._plot_combined(add_name=add_name)

        # target plots
        self._calculate_hist(generators, self.targets, input_data=False)
        for subset in generators.keys():
            self._plot(add_name="target", subset=subset)
        self._plot_combined(add_name="target")

    @staticmethod
    def _handle_upsampling(generators):
        """Add the train data collection under the key "train_upsampled" (in place) if a "train" subset exists."""
        if "train" in generators:
            generators.update({"train_upsampled": generators["train"]})

    @staticmethod
    def _get_inputs_targets(gens, dim):
        """Return input variable names, target variable names and number of input branches of the first station."""
        k = list(gens.keys())[0]
        gen = gens[k][0]
        inputs = list(set([y for x in to_list(gen.get_X(as_numpy=False)) for y in x.coords[dim].values.tolist()]))
        targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
        n_branches = len(gen.get_X(as_numpy=False))
        return inputs, targets, n_branches

    def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
        """
        Calculate per subset and variable an aggregated histogram over all stations.

        Per station, the window with the smallest absolute index is used. The station histograms are interpolated on
        a common grid of 100 bins spanning the overall min/max edges and summed up. Results are stored in self.bins,
        self.interval_width and self.bin_edges, keyed by subset.
        """
        n_bins = 100
        for set_type, generator in generators.items():
            upsampling = "upsampled" in set_type
            tmp_bins = {}
            tmp_edges = {}
            end = {}
            start = {}
            if input_data is True:
                f = lambda x: x.get_X(as_numpy=False, upsampling=upsampling)[branch_pos]
            else:
                f = lambda x: x.get_Y(as_numpy=False, upsampling=upsampling)
            for gen in generator:
                station_data = f(gen)  # fetch once, data retrieval can be expensive
                w = min(abs(station_data.coords[self.window_dim].values))
                data = station_data.sel({self.window_dim: w})
                res, _, g_edges = f_proc_hist(data, variables, n_bins, self.variables_dim)
                for var in res.keys():
                    tmp_bins.setdefault(var, []).append(res[var])
                    tmp_edges.setdefault(var, []).append(g_edges[var])
                    end[var] = max([end.get(var, g_edges[var].max()), g_edges[var].max()])
                    start[var] = min([start.get(var, g_edges[var].min()), g_edges[var].min()])
            # interpolate and aggregate
            bins = {}
            edges = {}
            interval_width = {}
            for var in tmp_bins.keys():
                bin_edges = np.linspace(start[var], end[var], n_bins + 1)
                interval_width[var] = bin_edges[1] - bin_edges[0]
                for i, e in enumerate(tmp_bins[var]):
                    bins_interp = np.interp(bin_edges[:-1], tmp_edges[var][i][:-1], e, left=0, right=0)
                    bins[var] = bins.get(var, np.zeros(n_bins)) + bins_interp
                edges[var] = bin_edges

            self.bins[set_type] = bins
            self.interval_width[set_type] = interval_width
            self.bin_edges[set_type] = edges

    def _plot(self, add_name, subset):
        """Write one pdf page per variable with the probability density histogram of the given subset."""
        from matplotlib.backends.backend_pdf import PdfPages
        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{subset}_{add_name}.pdf")
        bins = self.bins[subset]
        bin_edges = self.bin_edges[subset]
        interval_width = self.interval_width[subset]
        colors = self.get_dataset_colors()
        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
        try:
            with PdfPages(plot_path) as pdf_pages:
                for var in bins.keys():
                    fig, ax = plt.subplots()
                    hist_var = bins[var]
                    n_var = sum(hist_var)
                    weights = hist_var / (interval_width[var] * n_var)
                    ax.hist(bin_edges[var][:-1], bin_edges[var], weights=weights, color=colors[subset])
                    ax.set_ylabel("probability density")
                    ax.set_xlabel("values")
                    ax.set_title(f"histogram {var} ({subset}, n={int(n_var)})")
                    pdf_pages.savefig(fig)
        finally:
            # close all open figures / plots, also if plotting failed
            plt.close('all')

    def _plot_combined(self, add_name):
        """Write one pdf page per variable comparing the probability densities of all subsets as lines."""
        from matplotlib.backends.backend_pdf import PdfPages
        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{add_name}.pdf")
        variables = next(iter(self.bins.values())).keys()
        colors = self.get_dataset_colors()
        colors.update({"train_upsampled": colors.get("train_val", "#000000")})
        try:
            with PdfPages(plot_path) as pdf_pages:
                for var in variables:
                    fig, ax = plt.subplots()
                    for subset in self.bins.keys():
                        hist_var = self.bins[subset][var]
                        interval_width = self.interval_width[subset][var]
                        bin_edges = self.bin_edges[subset][var]
                        n_var = sum(hist_var)
                        weights = hist_var / (interval_width * n_var)
                        # plot density at the bin centers
                        ax.plot(bin_edges[:-1] + 0.5 * interval_width, weights, label=f"{subset}",
                                c=colors[subset])
                    ax.set_ylabel("probability density")
                    ax.set_xlabel("values")
                    ax.legend(loc="upper right")
                    ax.set_title(f"histogram {var}")
                    pdf_pages.savefig(fig)
        finally:
            # close all open figures / plots, also if plotting failed
            plt.close('all')

657 

658 

659@TimeTrackingWrapper 

660class PlotPeriodogram(AbstractPlotClass): # pragma: no cover 

661 """ 

662 Create Lomb-Scargle periodogram in raw input and target data. The Lomb-Scargle version can deal with missing values. 

663 

664 This plot routine is creating the following plots: 

665 

666 * "raw": data is not aggregated, 1 graph per variable 

667 * "": single data lines are aggregated, 1 graph per variable 

668 * "total": data is aggregated on all variables, single graph 

669 

670 If data consists on different sampling rates, a separate plot is create for each sampling. 

671 

672 .. image:: ../../../../../_source/_plots/periodogram.png 

673 :width: 400 

674 

675 .. note:: 

676 This plot is not included in the default plot list. To use this plot, add "PlotPeriodogram" to the `plot_list`. 

677 

678 .. warning:: 

679 This plot is highly sensitive to the data handler structure. Therefore, it is highly likely that this method is 

680 not compatible with any custom data handler. Proven data handlers are `DefaultDataHandler`, 

681 `DataHandlerMixedSampling`, `DataHandlerMixedSamplingWithFilter`. To work properly, the data handler must have 

682 the attribute `.id_class._data`. 

683 

684 """ 

685 

686 def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="periodogram", 

687 variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False): 

688 super().__init__(plot_folder, plot_name) 

689 self.variables_dim = variables_dim 

690 self.time_dim = time_dim 

691 

692 for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling, sampling)): 

693 self._sampling = s 

694 self._add_text = {0: "input", 1: "target"}[pos] 

695 multiple, label_names = self._has_filter_dimension(generator[0], pos) 

696 self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing) 

697 self._plot(raw=True) 

698 self._plot(raw=False) 

699 self._plot_total(raw=True) 

700 self._plot_total(raw=False) 

701 if multiple > 1: 

702 self._plot_difference(label_names, plot_name_add="_last") 

703 self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing, 

704 use_last_input_value=False) 

705 self._plot_difference(label_names, plot_name_add="_first") 

706 

707 @staticmethod 

708 def _has_filter_dimension(g, pos): 

709 """Inspect if filtered data is provided and return number and labels of filtered components.""" 

710 check_class = g.id_class 

711 check_data = [check_class.get_X(as_numpy=False), check_class.get_Y(as_numpy=False)][pos] 

712 if not hasattr(check_class, "filter_dim"): # data handler has no filtered data 

713 return 1, [] 

714 else: 

715 filter_dim = check_class.filter_dim 

716 if filter_dim not in check_data.coords.dims: # current data has no filter (e.g. target data) 

717 return 1, [] 

718 else: 

719 return check_data.coords[filter_dim].shape[0], check_data.coords[filter_dim].values.tolist() 

720 

@TimeTrackingWrapper
def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False, use_last_input_value=True):
    """
    Create periodogram data.

    Component 0 is the original data; if ``multiple > 1``, components ``1..multiple`` are the filter components.
    For each component one entry is appended to:

    * ``self.plot_data``: per variable a list of periodograms interpolated onto ``self.f_index``
    * ``self.plot_data_mean``: per variable a tuple ``(f_index, mean periodogram)``
    * ``self.plot_data_raw``: per variable a tuple ``(f_index, array with one periodogram per column)``

    :param generator: data collection to process
    :param pos: 0 to use inputs, 1 to use targets
    :param multiple: number of filter components (1 if the data is not filtered)
    :param use_multiprocessing: compute periodograms in parallel
    :param use_last_input_value: forwarded to the periodogram computation, selects the maximum (True) or minimum
        (False) coordinate of any remaining non-time dimension
    """
    self.raw_data = []
    self.plot_data = []
    self.plot_data_raw = []
    self.plot_data_mean = []
    # common frequency grid (loop invariant); must be set before _prepare_pgram_parallel_gen, which reads it
    self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
    n_components = multiple if multiple == 1 else multiple + 1  # +1 for the unfiltered data at index 0
    for m in range(n_components):
        plot_data_single = dict()
        plot_data_raw_single = dict()
        plot_data_mean_single = dict()
        raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing,
                                                           use_last_input_value=use_last_input_value)
        for var, all_data in raw_data_single.items():
            # interpolate every single periodogram onto the common frequency grid
            pgram_com = [np.interp(self.f_index, f, pgram) for f, pgram in all_data]
            pgram_mean = sum(pgram_com) / len(pgram_com)  # sequential sum, same order as an explicit loop
            pgram_mean_raw = np.column_stack(pgram_com)  # shape (len(f_index), number of periodograms)
            plot_data_single[var] = pgram_com
            plot_data_mean_single[var] = (self.f_index, pgram_mean)
            plot_data_raw_single[var] = (self.f_index, pgram_mean_raw)
        self.plot_data.append(plot_data_single)
        self.plot_data_mean.append(plot_data_mean_single)
        self.plot_data_raw.append(plot_data_raw_single)

755 

def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing):
    """
    Implementation of data preprocessing using parallel variables element processing.

    :param generator: data collection to process
    :param m: component index, 0 for the original data, ``m > 0`` selects filter ``m - 1``
    :param pos: 0 to use inputs, 1 to use targets
    :param use_multiprocessing: compute the variables of each data handler in parallel
    :return: dict mapping variable name to a list of ``(frequencies, periodogram)`` tuples
    """
    raw_data_single = dict()
    for g in generator:
        if m == 0:
            d = g.id_class._data
        else:
            gd = g.id_class
            filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
            d = (gd.input_data.sel(filter_sel), gd.target_data)
        d = d[pos] if isinstance(d, tuple) else d
        variables = d[self.variables_dim].values
        # f_proc needs the frequency grid and the time dimension used by this plot
        args_list = [(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim), self.f_index,
                      self.time_dim) for var in variables]
        if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
            n_physical = psutil.cpu_count(logical=False) or 1  # psutil returns None if undeterminable
            pool = multiprocessing.Pool(max(1, min([n_physical, len(variables), 16])))  # use only physical cpus
            try:
                output = [pool.apply_async(f_proc, args=args) for args in args_list]
                res = [p.get() for p in output]
            finally:  # release worker processes even if a task failed
                pool.close()
                pool.join()
        else:  # serial solution
            res = [f_proc(*args) for args in args_list]
        for var_str, f, pgram in res:
            raw_data_single.setdefault(var_str, []).append((f, pgram))
    return raw_data_single

789 

def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing, use_last_input_value=True):
    """
    Implementation of data preprocessing using parallel generator element processing.

    :param generator: data collection to process, one task per element
    :param m: component index, 0 for the original data, ``m > 0`` selects filter ``m - 1``
    :param pos: 0 to use inputs, 1 to use targets
    :param use_multiprocessing: process the elements of ``generator`` in parallel
    :param use_last_input_value: forwarded to ``f_proc_2``
    :return: dict mapping variable name to the concatenated lists of ``(frequencies, periodogram)`` tuples of all
        elements
    """
    task_args = (m, pos, self.variables_dim, self.time_dim, self.f_index, use_last_input_value)
    if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
        n_physical = psutil.cpu_count(logical=False) or 1  # psutil returns None if undeterminable
        # at least one worker, Pool(0) raises a ValueError for an empty generator
        pool = multiprocessing.Pool(max(1, min([n_physical, len(generator), 16])))  # use only physical cpus
        try:
            output = [pool.apply_async(f_proc_2, args=(g, *task_args)) for g in generator]
            res = [p.get() for p in output]
        finally:  # release worker processes even if a task failed
            pool.close()
            pool.join()
    else:
        res = [f_proc_2(g, *task_args) for g in generator]
    raw_data_single = dict()
    for res_dict in res:
        for k, v in res_dict.items():
            raw_data_single[k] = raw_data_single.get(k, []) + v
    return raw_data_single

815 

@staticmethod
def _add_annotation_line(ax, pos, div, lims, unit):
    """
    Draw labelled vertical reference lines.

    :param ax: axes to draw on
    :param pos: single value or list of values; each line is drawn at ``value / div`` and labelled with the value
    :param div: divisor converting ``pos`` into the unit of the x axis
    :param lims: y limits passed to ``vlines``; the label is anchored at ``lims[0]``
    :param unit: unit shown in the label as ``unit^{-1}``
    """
    for value in to_list(pos):
        x_position = value / div
        label = r"%s$%s^{-1}$" % (value, unit)
        ax.vlines(x_position, *lims, "black")
        ax.text(x_position, lims[0], label, rotation="vertical", rotation_mode="anchor")

821 

def _format_figure(self, ax, var_name="total"):
    """
    Set log scale on both axes, add labels, reference lines of typical periods, and a title.

    :param ax: current ax object
    :param var_name: name of variable that will be included in the title
    """
    for set_scale in (ax.set_yscale, ax.set_xscale):
        set_scale('log')
    ax.set_ylabel("power spectral density", fontsize='x-large')  # unit depends on variable: [unit^2 day^-1]
    ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
    y_limits = ax.get_ylim()
    # (positions, divisor, unit) of reference lines; x axis is given in 1/day
    reference_lines = [
        ([1, 2, 3], 365.25, "yr"),  # per year
        (1, 365.25 / 12, "m"),  # per month
        (1, 7, "w"),  # per week
        ([1, 0.5], 1, "d"),  # per day
    ]
    if self._sampling == "hourly":
        reference_lines.append((2, 1, "d"))  # per day
        reference_lines.append(([1, 0.5], 1 / 24., "h"))  # per hour
    for positions, divisor, unit in reference_lines:
        self._add_annotation_line(ax, positions, divisor, y_limits, unit)
    ax.set_title(f"Periodogram ({var_name})")

842 

def _plot(self, raw=True):
    """
    Plot the periodogram of each variable on a separate page of one pdf file.

    :param raw: if True, draw every single periodogram (light blue) and their mean (blue); otherwise draw the row
        mean of a centered rolling mean (window 5) and a band spanned by the row means of the rolling min / max
    """
    # explicit import: relying on backend_pdf having been loaded as side effect of another import is fragile
    from matplotlib.backends.backend_pdf import PdfPages
    plot_path = os.path.join(os.path.abspath(self.plot_folder),
                             f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf")
    plot_data = self.plot_data[0]
    plot_data_mean = self.plot_data_mean[0]
    with PdfPages(plot_path) as pdf_pages:  # file is closed even if plotting fails
        for var in plot_data.keys():
            fig, ax = plt.subplots()
            if raw is True:
                for pgram in plot_data[var]:
                    ax.plot(self.f_index, pgram, "lightblue")
                ax.plot(*plot_data_mean[var], "blue")
            else:
                # rolling along rows is the default; the `axis` keyword is deprecated in pandas
                ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True)
                mean = ma.mean().mean(axis=1).values.flatten()
                upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
                ax.plot(self.f_index, mean, "blue")
                ax.fill_between(self.f_index, lower, upper, color="lightblue")
            self._format_figure(ax, var)
            pdf_pages.savefig(fig)
            plt.close(fig)  # do not keep one open figure per variable
    # close all open figures / plots
    plt.close('all')

866 

def _plot_total(self, raw=True):
    """
    Plot the periodograms of all variables together in a single-page pdf file.

    :param raw: if True, draw every single periodogram (light blue) and their mean (blue); otherwise draw the row
        mean of a centered rolling mean (window 5) and a band spanned by the row means of the rolling min / max
    """
    # explicit import: relying on backend_pdf having been loaded as side effect of another import is fragile
    from matplotlib.backends.backend_pdf import PdfPages
    plot_path = os.path.join(os.path.abspath(self.plot_folder),
                             f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}_total.pdf")
    plot_data_raw = self.plot_data_raw[0]
    if not plot_data_raw:
        logging.warning(f"No periodogram data available, skip creation of {plot_path}")
        return
    # combine the (n_freq, n_pgram) arrays of all variables column-wise
    res = np.concatenate([plot_data_raw[var][1] for var in plot_data_raw.keys()], axis=-1)
    with PdfPages(plot_path) as pdf_pages:  # file is closed even if plotting fails
        fig, ax = plt.subplots()
        if raw is True:
            for pgram in res.T:
                ax.plot(self.f_index, pgram, "lightblue")
            ax.plot(self.f_index, res.mean(axis=1), "blue")
        else:
            # rolling along rows is the default; the `axis` keyword is deprecated in pandas
            ma = pd.DataFrame(res).rolling(5, center=True)
            mean = ma.mean().mean(axis=1).values.flatten()
            upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
            ax.plot(self.f_index, mean, "blue")
            ax.fill_between(self.f_index, lower, upper, color="lightblue")
        self._format_figure(ax, "total")
        pdf_pages.savefig(fig)
    # close all open figures / plots
    plt.close('all')

892 

def _plot_difference(self, label_names, plot_name_add = ""):
    """
    Plot the smoothed mean periodogram of the original data and of each filter component per variable.

    :param label_names: labels of the filter components; a component labelled "unfiltered" is skipped because it
        equals the original data
    :param plot_name_add: suffix appended to the plot file name
    """
    # explicit import: relying on backend_pdf having been loaded as side effect of another import is fragile
    from matplotlib.backends.backend_pdf import PdfPages
    plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter{plot_name_add}.pdf"
    plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
    logging.info(f"... plotting {plot_name}")
    colors = ["grey", "blue", "red", "green", "orange", "purple", "black"]
    label_names = ["orig"] + label_names
    max_iter = len(self.plot_data)
    var_keys = self.plot_data[0].keys()
    with PdfPages(plot_path) as pdf_pages:  # file is closed even if plotting fails
        for var in var_keys:
            fig, ax = plt.subplots()
            for i in reversed(range(max_iter)):
                if label_names[i] == "unfiltered":
                    continue  # do not include the filter 'unfiltered' because this is equal to the 'orig' data
                plot_data = self.plot_data[i]
                c = colors[i % len(colors)]  # reuse colors if there are more components than colors
                # rolling along rows is the default; the `axis` keyword is deprecated in pandas
                ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True)
                mean = ma.mean().mean(axis=1).values.flatten()
                ax.plot(self.f_index, mean, c, label=label_names[i])
                if i < 1:  # uncertainty band only for the original data
                    upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
                    ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None)
            self._format_figure(ax, var)
            ax.legend(loc="upper center", ncol=max_iter)
            pdf_pages.savefig(fig)
            plt.close(fig)  # do not keep one open figure per variable
    # close all open figures / plots
    plt.close('all')

921 

922 

def f_proc(var, d_var, f_index, time_dim="datetime", use_last_value=True):  # pragma: no cover
    """
    Compute the Lomb-Scargle periodogram (psd normalization) of a single variable on a given frequency grid.

    :param var: variable identifier, returned as string
    :param d_var: time series of the variable, may contain further dimensions besides ``time_dim``
    :param f_index: frequencies (in 1/day) at which the periodogram is evaluated
    :param time_dim: name of the time dimension
    :param use_last_value: if further dimensions remain, select their maximum coordinate (True) or their minimum
        coordinate (False)
    :return: tuple ``(str(var), f_index, periodogram power at f_index)``
    """
    time_coord = d_var[time_dim]
    elapsed = time_coord - time_coord[0]
    # elapsed time in days, resolved to full hours
    t = elapsed.astype("timedelta64[h]").values / np.timedelta64(1, "D")
    if len(d_var.shape) > 1:  # use only max value if dimensions are remaining (e.g. max(window) -> latest value)
        for dim in to_list(remove_items(d_var.coords.dims, time_dim)):
            selected = d_var[dim].max() if use_last_value is True else d_var[dim].min()
            d_var = d_var.sel({dim: selected})
    model = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd")
    return str(var), f_index, model.power(f_index)

933 

934 

def f_proc_2(g, m, pos, variables_dim, time_dim, f_index, use_last_value):  # pragma: no cover
    """
    Compute the periodograms of all variables of a single data collection element.

    :param g: element of the data collection; for inputs (``pos == 0``) all attributes containing "id_class" are
        used, otherwise only ``id_class``
    :param m: component index, 0 for the original data, ``m > 0`` selects filter ``m - 1``
    :param pos: 0 to use inputs, 1 to use targets
    :param variables_dim: name of the variables dimension
    :param time_dim: name of the time dimension
    :param f_index: frequencies at which the periodograms are evaluated
    :param use_last_value: forwarded to ``f_proc``
    :return: dict mapping variable name to a single-element list ``[(frequencies, periodogram)]``
    :raises KeyError: if the same variable is provided by more than one data handler of ``g``
    """
    id_classes = list(filter(lambda x: "id_class" in x, dir(g))) if pos == 0 else ["id_class"]

    # load lazy data
    for id_cls_name in id_classes:
        id_cls = getattr(g, id_cls_name)
        if getattr(id_cls, "lazy", False) is True:
            id_cls.load_lazy()

    try:
        raw_data_single = dict()
        for dh in [x for x in id_classes if "unfiltered" not in x]:
            current_cls = getattr(g, dh)
            if m == 0:
                d = current_cls._data
                if d is None:
                    # rebuild (inputs, targets) from the last history entry and the first label entry
                    window_dim = current_cls.window_dim
                    history = current_cls.history
                    last_entry = history.coords[window_dim][-1]
                    d1 = history.sel({window_dim: last_entry}, drop=True)
                    label = current_cls.label
                    first_entry = label.coords[window_dim][0]
                    d2 = label.sel({window_dim: first_entry}, drop=True)
                    d = (d1, d2)
            else:
                filter_sel = {"filter": current_cls.input_data.coords["filter"][m - 1]}
                d = (current_cls.input_data.sel(filter_sel), current_cls.target_data)
            d = d[pos] if isinstance(d, tuple) else d
            for var in d[variables_dim].values:
                d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim)
                # forward time_dim, f_proc would otherwise fall back to "datetime"
                var_str, f, pgram = f_proc(var, d_var, f_index, time_dim=time_dim, use_last_value=use_last_value)
                if var_str in raw_data_single:
                    raise KeyError(f"There are multiple pgrams for key {var_str}. Please check your data handler.")
                raw_data_single[var_str] = [(f, pgram)]
    finally:
        # perform clean up, also if the computation failed, so lazily loaded data is released
        for id_cls_name in id_classes:
            id_cls = getattr(g, id_cls_name)
            if getattr(id_cls, "lazy", False) is True:
                id_cls.clean_up()

    return raw_data_single

977 

978 

def f_proc_hist(data, variables, n_bins, variables_dim):  # pragma: no cover
    """
    Compute a histogram for each requested variable that is present in ``data``.

    :param data: data containing a coordinate ``variables_dim``
    :param variables: variables to process; variables missing in ``data`` are skipped
    :param n_bins: number of histogram bins
    :param variables_dim: name of the variables dimension
    :return: tuple of three dicts keyed by variable: bin counts, width of the first bin, bin edges
    """
    counts, widths, edges = {}, {}, {}
    for var in variables:
        if var not in data.coords[variables_dim]:
            continue
        subset = data if len(data.shape) <= 1 else data.sel({variables_dim: var}).squeeze()
        counts[var], edges[var] = np.histogram(subset.values, n_bins)
        widths[var] = edges[var][1] - edges[var][0]
    return counts, widths, edges

989 

990 

class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot climate FIR filter components.

    * Creates a separate folder climFIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the climate FIR input, the filter response, the true non-causal (ideal) filter
      input, and the corresponding ideal response (containing information about future)
    * A filter residuum plot includes the climate FIR residuum and the ideal filter residuum.
    """

    # numpy timedelta unit per supported sampling
    _timedelta_units = {"1d": "D", "1D": "D", "daily": "D", "1H": "h", "1h": "h", "hourly": "h"}

    def __init__(self, plot_folder, plot_data, sampling, name):
        """
        Create all climate FIR filter plots and store the plot data.

        :param plot_folder: base plot directory, plots are stored in its subfolder climFIR; nothing is done if None
        :param plot_data: list with one entry per filter component, each a list of dicts (one per example t0)
        :param sampling: sampling of the data, e.g. "1d" or "1H"
        :param name: name used in log messages and plot file names
        """

        from mlair.helpers.filter import fir_filter_convolve

        logging.info(f"start PlotClimateFirFilter for ({name})")

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}
        if plot_folder is None:
            return

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "ideal": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        self.variables_list = []
        plot_folder = os.path.join(os.path.abspath(plot_folder), "climFIR")
        self.fir_filter_convolve = fir_filter_convolve
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict, new_dim = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict, sampling, new_dim)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        :param data: list indexed by filter component, each entry a list of dicts with the keys var, t0, filter_input,
            filter_input_nc, valid_range, time_range, new_dim and h
        :return: tuple of the nested dict ``{var: {t0: {filter index: plot information}}}`` and the name of the new
            dimension (asserted to be identical for all entries)
        """
        plot_dict = {}
        new_dim = None
        for i in range(len(data)):
            plot_data = data[i]
            for p_d in plot_data:
                var = p_d.get("var")
                t0 = p_d.get("t0")
                filter_input = p_d.get("filter_input")
                filter_input_nc = p_d.get("filter_input_nc")
                valid_range = p_d.get("valid_range")
                time_range = p_d.get("time_range")
                if new_dim is None:
                    new_dim = p_d.get("new_dim")
                else:
                    assert new_dim == p_d.get("new_dim")
                h = p_d.get("h")
                plot_dict_var = plot_dict.get(var, {})
                plot_dict_t0 = plot_dict_var.get(t0, {})
                plot_dict_order = {"filter_input": filter_input,
                                   "filter_input_nc": filter_input_nc,
                                   "valid_range": valid_range,
                                   "time_range": time_range,
                                   "order": len(h), "h": h}
                plot_dict_t0[i] = plot_dict_order
                plot_dict_var[t0] = plot_dict_t0
                plot_dict[var] = plot_dict_var
        self.variables_list = list(plot_dict.keys())
        return plot_dict, new_dim

    def _plot(self, plot_dict, sampling, new_dim="window"):
        """
        Create filter input and residuum plots for each variable, t0 and filter component.

        Plotting is best effort: if a plot fails, the problem is logged and the remaining filter components of this
        t0 are skipped.

        :param plot_dict: nested dict as returned by ``_prepare_data``
        :param sampling: sampling of the data, must be a key of ``_timedelta_units``
        :param new_dim: name of the dimension along which the filter is applied
        """
        td_type = self._timedelta_units.get(sampling)
        if td_type is None:
            # a generic timedelta would silently be interpreted in the unit of t0 and shift all plot elements
            logging.warning(f"PlotClimateFirFilter: unsupported sampling {sampling}, skip plotting for ({self._name})")
            return
        for var, vis_dict in plot_dict.items():
            for it0, t0 in enumerate(vis_dict.keys()):
                vis_data = vis_dict[t0]
                residuum_true = None
                try:
                    for ifilter in sorted(vis_data.keys()):
                        data = vis_data[ifilter]
                        filter_input = data["filter_input"]
                        # for higher components, the ideal input is the ideal residuum of the previous component
                        filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel(
                            {new_dim: filter_input.coords[new_dim]})
                        valid_range = data["valid_range"]
                        time_axis = data["time_range"]
                        filter_order = data["order"]
                        h = data["h"]
                        fig, ax = plt.subplots()

                        # plot backgrounds
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)

                        # original data
                        self._plot_original_data(ax, time_axis, filter_input_nc)

                        # clim apriori
                        self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter, offset=1)

                        # get ax lims
                        ylims = ax.get_ylim()

                        # clim filter response
                        residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h,
                                                                    output_dtypes=filter_input.dtype)

                        # ideal filter response
                        residuum_true = self._plot_ideal_filter(ax, time_axis, filter_input_nc, new_dim, h,
                                                                output_dtypes=filter_input.dtype)

                        # set title, legend, and save plot
                        xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis)
                        ax.set_ylim(ylims)

                        plt.title(f"Input of ClimFilter ({str(var)})")
                        plt.legend()
                        fig.autofmt_xdate()
                        plt.tight_layout()
                        self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}"
                        self._save()

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)
                        self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal")
                        self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim")
                        ax.set_xlim(xlims)
                        self._set_ylim_by_valid_range(ax, residuum_true, residuum_estimated, new_dim, valid_range)
                        plt.title(f"Residuum of ClimFilter ({str(var)})")
                        plt.legend(loc="upper left")
                        fig.autofmt_xdate()
                        plt.tight_layout()

                        self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
                        self._save()
                except Exception:
                    logging.info(f"Could not create plot because of:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
                    plt.close("all")  # discard half-finished figures so they do not accumulate

    @staticmethod
    def _set_ylim_by_valid_range(ax, a, b, dim, valid_range):
        """Set y limits with a 10 % margin around the extremes of ``a`` and ``b`` inside ``valid_range``."""
        ymax = max(a.sel({dim: valid_range}).max(),
                   b.sel({dim: valid_range}).max())
        ymin = min(a.sel({dim: valid_range}).min(),
                   b.sel({dim: valid_range}).min())
        ymax = 1.1 * ymax if ymax > 0 else 0.9 * ymax
        ymin = 0.9 * ymin if ymin > 0 else 1.1 * ymin
        ax.set_ylim((ymin, ymax))

    def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis):
        """
        Set xlims

        Use order and valid_range to find a good zoom in that hides edges of filter values that are effected by reduced
        filter order. Limits are returned to be usable for other plots.
        """
        t_minus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order), (-valid_range.start + 0.3 * order))
        t_plus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order), valid_range.stop + 0.3 * order)
        t_minus = t0 + np.timedelta64(-int(t_minus_delta), td_type)
        t_plus = t0 + np.timedelta64(int(t_plus_delta), td_type)
        ax_start = max(t_minus, time_axis[0])
        ax_end = min(t_plus, time_axis[-1])
        ax.set_xlim((ax_start, ax_end))
        return ax_start, ax_end

    def _plot_valid_area(self, ax, t0, valid_range, td_type):
        """Shade the span from ``t0 + valid_range.start`` to ``t0 + valid_range.stop - 1`` (in units of td_type)."""
        ax.axvspan(t0 + np.timedelta64(valid_range.start, td_type),
                   t0 + np.timedelta64(valid_range.stop - 1, td_type), **self.style_dict["valid_area"])

    def _plot_t0(self, ax, t0):
        """Mark t0 by a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot ``data`` over ``time_axis`` using the entry ``style`` of the style dictionary."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_original_data(self, ax, time_axis, data):
        """Plot the (non-causal) original filter input."""
        self._plot_series(ax, time_axis, data.values.flatten(), style="original")

    def _plot_apriori(self, ax, time_axis, data, new_dim, ifilter, offset):
        """
        Plot the climatological a priori estimate of the future, aligned to the end of ``time_axis``.

        For the first filter component only values with ``new_dim`` coordinate >= ``offset`` are shown.
        """
        filter_input = data
        if ifilter == 0:
            d_tmp = filter_input.sel(
                {new_dim: slice(offset, filter_input.coords[new_dim].values.max())}).values.flatten()
        else:
            d_tmp = filter_input.values.flatten()
        self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, style="apriori")

    def _plot_clim_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """Plot the climate FIR filter response along ``new_dim`` and return the estimated residuum."""
        filter_input = data
        filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input,
                              input_core_dims=[[new_dim]],
                              output_core_dims=[[new_dim]],
                              vectorize=True,
                              kwargs={"h": h},
                              output_dtypes=[output_dtypes])
        self._plot_series(ax, time_axis, filt.values.flatten(), style="clim")
        residuum_estimated = filter_input - filt
        return residuum_estimated

    def _plot_ideal_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """Plot the ideal (non-causal) filter response along ``new_dim`` and return the true residuum."""
        filter_input_nc = data
        filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input_nc,
                              input_core_dims=[[new_dim]],
                              output_core_dims=[[new_dim]],
                              vectorize=True,
                              kwargs={"h": h},
                              output_dtypes=[output_dtypes])
        self._plot_series(ax, time_axis, filt.values.flatten(), style="ideal")
        residuum_true = filter_input_nc - filt
        return residuum_true

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        file = os.path.join(self.plot_folder, "_".join(self.variables_list) + "plot_data.pickle")
        with open(file, "wb") as f:
            dill.dump(data, f)

1227 

1228 

class PlotFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot FIR filter components.

    * Creates a separate folder FIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the FIR input and the filter response
    * A filter residuum plot includes the FIR residuum
    """

    def __init__(self, plot_folder, plot_data, name):
        """
        Create all FIR filter plots and store the plot data.

        :param plot_folder: base plot directory, plots are stored in its subfolder FIR; nothing is done if None
        :param plot_data: nested list indexed by filter component and t0 counter, each entry a dict with the keys t0,
            filter_input, filtered, var_dim and time_dim
        :param name: name used in log messages and plot file names
        """

        logging.info(f"start PlotFirFilter for ({name})")

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}
        if plot_folder is None:
            return

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "FIR": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        plot_folder = os.path.join(os.path.abspath(plot_folder), "FIR")
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        :param data: nested list indexed by filter component and t0 counter
        :return: nested dict ``{var: {t0: {filter index: {"filter_input", "filtered", "time_dim"}}}}``
        """
        plot_dict = {}
        for i in range(len(data)):  # filter component
            for j in range(len(data[i])):  # t0 counter
                plot_data = data[i][j]
                t0 = plot_data.get("t0")
                filter_input = plot_data.get("filter_input")
                filtered = plot_data.get("filtered")
                var_dim = plot_data.get("var_dim")
                time_dim = plot_data.get("time_dim")
                for var in filtered.coords[var_dim].values:
                    plot_dict_var = plot_dict.get(var, {})
                    plot_dict_t0 = plot_dict_var.get(t0, {})
                    plot_dict_order = {"filter_input": filter_input.sel({var_dim: var}, drop=True),
                                       "filtered": filtered.sel({var_dim: var}, drop=True),
                                       "time_dim": time_dim}
                    plot_dict_t0[i] = plot_dict_order
                    plot_dict_var[t0] = plot_dict_t0
                    plot_dict[var] = plot_dict_var
        return plot_dict

    def _plot(self, plot_dict):
        """
        Create filter input and residuum plots for each variable, t0 and filter component.

        Plotting is best effort: if a plot fails, the problem is logged and the remaining filter components of this
        t0 are skipped.
        """
        for var, viz_date_dict in plot_dict.items():
            for it0, t0 in enumerate(viz_date_dict.keys()):
                viz_data = viz_date_dict[t0]
                try:
                    for ifilter in sorted(viz_data.keys()):
                        data = viz_data[ifilter]
                        filter_input = data["filter_input"]
                        filtered = data["filtered"]
                        time_dim = data["time_dim"]
                        time_axis = filtered.coords[time_dim].values
                        fig, ax = plt.subplots()

                        # plot backgrounds
                        self._plot_t0(ax, t0)

                        # original data
                        self._plot_data(ax, time_axis, filter_input, style="original")

                        # filter response
                        self._plot_data(ax, time_axis, filtered, style="FIR")

                        # set title, legend, and save plot
                        ax.set_xlim((time_axis[0], time_axis[-1]))

                        plt.title(f"Input of Filter ({str(var)})")
                        plt.legend()
                        fig.autofmt_xdate()
                        plt.tight_layout()
                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}"
                        self._save()

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_t0(ax, t0)
                        self._plot_data(ax, time_axis, filter_input - filtered, style="FIR")
                        ax.set_xlim((time_axis[0], time_axis[-1]))
                        plt.title(f"Residuum of Filter ({str(var)})")
                        plt.legend(loc="upper left")
                        fig.autofmt_xdate()
                        plt.tight_layout()

                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
                        self._save()
                except Exception:
                    logging.info(f"Could not create plot because of:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
                    plt.close("all")  # discard half-finished figures so they do not accumulate

    def _plot_t0(self, ax, t0):
        """Mark t0 by a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot ``data`` over ``time_axis`` using the entry ``style`` of the style dictionary."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_data(self, ax, time_axis, data, style="original"):
        """Plot the flattened values of an array-like object (original data by default)."""
        self._plot_series(ax, time_axis, data.values.flatten(), style=style)

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        file = os.path.join(self.plot_folder, "plot_data.pickle")
        with open(file, "wb") as f:
            dill.dump(data, f)