Coverage for mlair/plotting/data_insight_plotting.py: 100%

7 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-12-02 15:24 +0000

1"""Collection of plots to get more insight into data.""" 

2__author__ = "Lukas Leufen, Felix Kleinert" 

3__date__ = '2021-04-13' 

4 

5from typing import List, Dict 

6import dill 

7import os 

8import logging 

9import multiprocessing 

10import psutil 

11import sys 

12 

13import numpy as np 

14import pandas as pd 

15import xarray as xr 

16import matplotlib 

17# matplotlib.use("Agg") 

18from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates 

19from astropy.timeseries import LombScargle 

20 

21from mlair.data_handler import DataCollection 

22from mlair.helpers import TimeTrackingWrapper, to_list, remove_items 

23from mlair.plotting.abstract_plot_class import AbstractPlotClass 

24 

25 

@TimeTrackingWrapper
class PlotStationMap(AbstractPlotClass):  # pragma: no cover
    """
    Draw a geographical overview map with one marker per station.

    Every entry of ``generators`` is a data collection (optionally paired with a dictionary of plot options). All
    stations of a collection share the same marker style and the collection shows up once in the legend together
    with its number of stations. The map background consists of land, coastlines, lakes, ocean, rivers and country
    borders. The figure is stored in ``plot_folder`` using ``plot_name`` as file name.

    .. image:: ../../../../../_source/_plots/station_map.png
       :width: 400
    """

    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
        """
        Store defaults, draw the map and save it.

        :param generators: iterable whose elements are either a data collection or a tuple
            ``(data collection, plot options)``. Recognised plot options are ``marker``, ``ms``, ``mec`` and ``mfc``.
        :param plot_folder: path to save the plot (default: current directory)
        :param plot_name: name of the created plot file (default: ``station_map``)
        """
        super().__init__(plot_folder, plot_name)
        self._ax = None  # GeoAxes, created in _plot
        self._gl = None  # gridliner object, created in _plot
        self._plot(generators)
        self._save(bbox_inches="tight")

    def _draw_background(self):
        """Add land, coastline, lakes, ocean, rivers and country borders (all in 50m resolution) to the map."""

        import cartopy.feature as cfeature

        scale = "50m"
        axes = self._ax
        axes.add_feature(cfeature.LAND.with_scale(scale))
        # NOTE(review): natural_earth_shp looks like a legacy GeoAxes helper - confirm it is still available in the
        # cartopy version in use.
        axes.natural_earth_shp(resolution='50m')
        styled_features = (
            (cfeature.COASTLINE, {"edgecolor": 'black'}),
            (cfeature.LAKES, {}),
            (cfeature.OCEAN, {}),
            (cfeature.RIVERS, {}),
            (cfeature.BORDERS, {"facecolor": 'none', "edgecolor": 'black'}),
        )
        for feature, style in styled_features:
            axes.add_feature(feature.with_scale(scale), **style)

    def _plot_stations(self, generators):
        """
        Put a marker on the map for each station of each data collection and build the legend.

        Marker style is taken from the plot options of the element (``marker``, ``ms``, ``mec``, ``mfc``). If no face
        color is given, the color registered for the collection's name in ``get_dataset_colors`` is used (blue if the
        name is not registered).

        :param generators: iterable of data collections or ``(data collection, plot options)`` tuples; nothing is
            drawn if this is None.
        """

        import cartopy.crs as ccrs

        if generators is None:
            return

        fallback_colors = self.get_dataset_colors()
        handles = []
        for entry in generators:
            collection, opts = self._get_collection_and_opts(entry)
            label = collection.name or "unknown"
            marker = opts.get("marker", "s")
            size = opts.get("ms", 6)
            edge_color = opts.get("mec", "k")
            face_color = opts.get("mfc", fallback_colors.get(label, "b"))
            # the legend handle uses the adjusted marker, the map itself uses the marker as given
            handle = mlines.Line2D([], [], mfc=face_color, mec=edge_color, marker=self._adjust_marker(marker),
                                   ms=size, linestyle='None', label=f"{label} ({len(collection)})")
            handles.append(handle)
            for station in collection:
                position = station.get_coordinates()
                lon, lat = position["lon"], position["lat"]
                self._ax.plot(lon, lat, mfc=face_color, mec=edge_color, marker=marker, ms=size,
                              transform=ccrs.PlateCarree())
        if handles:
            self._ax.legend(handles=handles, loc='best')

    @staticmethod
    def _adjust_marker(marker):
        """Map the integer marker codes 4 to 11 onto the triangle markers "<", ">", "^", "v"; keep all others."""
        replacements = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
        if isinstance(marker, int):
            return replacements.get(marker, marker)
        return marker

    @staticmethod
    def _get_collection_and_opts(element):
        """
        Split an element of the generators list into data collection and plot options.

        A one-element tuple or a non-tuple is combined with an empty options dictionary, any other tuple is returned
        unchanged.
        """
        if not isinstance(element, tuple):
            return element, {}
        if len(element) == 1:
            return element[0], {}
        return element

    def _plot(self, generators: List):
        """
        Create the station map plot.

        Set up the figure with a PlateCarree projection and labelled grid lines, then draw background and stations
        and finally widen the map extent.

        :param generators: iterable of data collections or ``(data collection, plot options)`` tuples
        """

        import cartopy.crs as ccrs
        from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

        figure = plt.figure(figsize=(10, 5))
        self._ax = figure.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
        gridliner = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
        gridliner.xformatter = LONGITUDE_FORMATTER
        gridliner.yformatter = LATITUDE_FORMATTER
        self._gl = gridliner
        self._draw_background()
        self._plot_stations(generators)
        self._adjust_extent()
        plt.tight_layout()

    def _adjust_extent(self):
        """
        Enlarge the current map extent by a margin on all four sides.

        The margin is ``5 / delta`` for the smaller of the two axis spans (lon / lat), but never more than 5.
        """
        import cartopy.crs as ccrs

        reference = 5
        upper_limit = 5
        current = self._ax.get_extent(crs=ccrs.PlateCarree())
        span_x = current[1] - current[0]
        span_y = current[3] - current[2]
        margin = min(max(abs(reference / span_x), abs(reference / span_y)), upper_limit)
        padding = np.array([-1, 1, -1, 1]) * margin
        widened = current + padding
        self._ax.set_extent(widened, crs=ccrs.PlateCarree())

150 

151 

@TimeTrackingWrapper
class PlotAvailability(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plot similar to Gantt plot.

    Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal
    resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
    colored bar or a blank space.

    You can set different colors to highlight subsets for example by providing different generators for the same index
    using different keys in the input dictionary.

    Note: each bar is surrounded by a small white box to highlight gaps in between. This can result in too long gaps
    in display, if a gap is only very short. Also this appears on a (fluent) transition from one to another subset.

    Calling this class will create three versions of the availability plot.

    1) Data availability for each element
    2) Data availability as summary over all elements (is there at least a single element for each time step)
    3) Combination of single and overall availability

    .. image:: ../../../../../_source/_plots/data_availability.png
       :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_summary.png
       :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_combined.png
       :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
        """
        Prepare the data and create all three availability plots.

        :param generators: mapping from subset name to the data collection of this subset
        :param plot_folder: path to save the plots (default: current directory)
        :param sampling: sampling name, translated by ``_get_sampling`` into the resampling frequency
        :param summary_name: row label used for the summary bar
        :param time_dimension: name of the time dimension of the label data
        :param window_dimension: name of the window dimension of the label data
        """
        # create standard Gantt plot for all stations (currently in single pdf file with single page)
        super().__init__(plot_folder, "data_availability")
        self.time_dim = time_dimension
        self.window_dim = window_dimension
        self.sampling = self._get_sampling(sampling)[1]
        self.linewidth = None
        if self.sampling == 'h':
            self.linewidth = 0.001
        plot_dict = self._prepare_data(generators)
        lgd = self._plot(plot_dict)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # create summary Gantt plot (is data in at least one station available)
        self.plot_name += "_summary"
        plot_dict_summary = self._summarise_data(generators, summary_name)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
        # combination of station and summary plot, last element is summary broken bar
        self.plot_name = "data_availability_combined"
        plot_dict_summary.update(plot_dict)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")

    def _label_availability(self, station):
        """Return a boolean array along time that is True where the resampled label of window 1 is not null."""
        labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
        return labels.sel(**{self.window_dim: 1}).notnull()

    def _availability_blocks(self, avail):
        """
        Translate a boolean availability array into a list of ``(start, number of steps)`` tuples.

        Consecutive time steps with the same availability state form one block; only blocks whose state is True are
        returned.
        """
        # a new group starts whenever the state differs from the previous time step
        group = (avail != avail.shift({self.time_dim: 1})).cumsum()
        frame = pd.DataFrame({"avail": avail.values, "group": group.values},
                             index=avail.coords[self.time_dim].values)
        # .iloc[0] is a positional lookup; an integer key on the datetime-indexed series would depend on the
        # deprecated positional fallback of Series.__getitem__
        return [(block.index[0], block.shape[0])
                for _, block in frame.groupby("group") if block["avail"].iloc[0]]

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Collect the availability blocks of each station.

        :param generators: mapping from subset name to data collection
        :return: nested dictionary ``{station name: {subset: [(start, number of steps), ...]}}``
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            for station in data_collection:
                blocks = self._availability_blocks(self._label_availability(station))
                plt_dict.setdefault(str(station), {})[subset] = blocks
        return plt_dict

    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
        """
        Collect the availability blocks of each subset, where a time step counts as available if any station has data.

        :param generators: mapping from subset name to data collection
        :param summary_name: key (row label) under which the summary is stored
        :return: nested dictionary ``{summary_name: {subset: [(start, number of steps), ...]}}``
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            all_data = None
            for station in data_collection:
                labels_bool = self._label_availability(station)
                if all_data is None:
                    all_data = labels_bool
                else:
                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
                    # apply logical on merge and fill missing with all_data
                    all_data = np.logical_or(tmp, labels_bool).combine_first(all_data)
            plt_dict.setdefault(summary_name, {})[subset] = self._availability_blocks(all_data)
        return plt_dict

    def _plot(self, plt_dict):
        """
        Draw one row of broken bars per key of ``plt_dict`` and return the legend.

        :param plt_dict: nested dictionary ``{row label: {subset: [(start, number of steps), ...]}}``
        :return: the legend object (needed as extra artist when saving with a tight bounding box)
        """
        colors = self.get_dataset_colors()
        _used_colors = []
        pos = 0
        height = 0.8  # should be <= 1
        yticklabels = []
        number_of_stations = len(plt_dict)
        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
        for station, d in sorted(plt_dict.items(), reverse=True):
            pos += 1
            for subset, color in colors.items():
                plt_data = d.get(subset)
                if plt_data is None:
                    continue
                elif color not in _used_colors:  # this is required for a proper legend creation
                    _used_colors.append(color)
                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
            yticklabels.append(station)

        ax.set_ylim([height, number_of_stations + 1])
        ax.set_yticks(np.arange(number_of_stations) + 1 + height / 2)
        ax.set_yticklabels(yticklabels)
        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors]
        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
        return lgd

277 

278 

@TimeTrackingWrapper
class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plots as histogram.

    Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean).
    Calling this class creates two different types of histograms:

    1) data_availability_histogram: datetime (xaxis) vs. number of stations with available data (yaxis)
    2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number
       of samples (yaxis)

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
       :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
       :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
                 subset_dim: str = 'DataSet', history_dim: str = 'window',
                 station_dim: str = 'Stations', ):
        """
        Prepare the data and create one plot per allowed plot type.

        :param generators: mapping from subset name to the data collection of this subset
        :param plot_folder: path to save the plots (default: current directory)
        :param subset_dim: name of the dimension that is created to distinguish the subsets
        :param history_dim: name of the window (history) dimension of the input data
        :param station_dim: name of the dimension that is created to distinguish the stations
        """

        super().__init__(plot_folder, "data_availability_histogram")

        self.subset_dim = subset_dim
        self.history_dim = history_dim
        self.station_dim = station_dim

        # the following attributes are read from the data handlers in _prepare_data
        self.freq = None
        self.temporal_dim = None
        self.target_dim = None
        self._prepare_data(generators)

        base_name = self.plot_name
        for plt_type in self.allowed_plot_types:
            self.plot_name = f"{base_name}_{plt_type}"
            try:
                self._plot(plt_type=plt_type)
                self._save()
            finally:
                # always restore the base name so that a failing plot does not leak its suffix
                self.plot_name = base_name

    def _set_dims_from_datahandler(self, data_handler):
        """Read time dimension, target dimension and sampling frequency from the data handler's ``id_class``."""
        self.temporal_dim = data_handler.id_class.time_dim
        self.target_dim = data_handler.id_class.target_dim
        self.freq = self._get_sampling(data_handler.id_class.sampling)[1]

    @property
    def allowed_plot_types(self):
        """Return the names of all plot types that are created by this class."""
        plot_types = ['hist', 'hist_cum']
        return plot_types

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Prepares data to be used by plot methods.

        Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim
        """
        avail_data_time_sum = {}
        avail_data_station_sum = {}
        dataset_time_interval = {}
        for subset, generator in generators.items():
            avail_list = []
            for station in generator:
                self._set_dims_from_datahandler(data_handler=station)
                station_data_x = station.get_X(as_numpy=False)[0]
                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
                station_data_x = self._reduce_dims(station_data_x)
                avail_list.append(station_data_x.notnull())
            # NOTE(review): notnull() is applied a second time on the already boolean arrays. Presumably this marks
            # time steps that only exist for other stations as not available - confirm that this is intended.
            avail_data = xr.concat(avail_list, dim=self.station_dim).notnull()
            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
            )
        avail_data_amount = xr.concat(avail_data_time_sum.values(),
                                      pd.Index(list(avail_data_time_sum.keys()), name=self.subset_dim))
        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values,
                                                     freq=self.freq)
        self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(),
                                            pd.Index(list(avail_data_station_sum.keys()), name=self.subset_dim))
        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
        self.dataset_time_interval = dataset_time_interval

    def _reduce_dims(self, dataset):
        """Select the first entry along every dimension besides time and station if more than 2 dims are present."""
        if len(dataset.dims) > 2:
            required = {self.temporal_dim, self.station_dim}
            unimportant = set(dataset.dims).difference(required)
            sel_dict = {un: dataset[un].values[0] for un in unimportant}
            dataset = dataset.loc[sel_dict]
        return dataset

    @staticmethod
    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
        """
        Return first and last coordinate value of ``xarray`` along ``dim_name``.

        :param xarray: data array to inspect
        :param dim_name: name of the coordinate to read
        :param return_type: either 'as_tuple' (``(first, last)``) or 'as_dict' (``{'first': ..., 'last': ...}``)
        :raises TypeError: if ``xarray`` is not an ``xr.DataArray`` or if ``return_type`` is unknown
        """
        if isinstance(xarray, xr.DataArray):
            first = xarray.coords[dim_name].values[0]
            last = xarray.coords[dim_name].values[-1]
            if return_type == 'as_tuple':
                return first, last
            elif return_type == 'as_dict':
                return {'first': first, 'last': last}
            else:
                raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")
        else:
            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")

    @staticmethod
    def _make_full_time_index(irregular_time_index, freq):
        """Return a gapless date range from first to last element of the given index using frequency ``freq``."""
        full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
        return full_time_index

    def _plot(self, plt_type='hist', *args):
        """
        Dispatch to the plot routine that belongs to ``plt_type``.

        :param plt_type: either 'hist' or 'hist_cum'
        :raises ValueError: if ``plt_type`` is none of the supported types
        """
        if plt_type == 'hist':
            self._plot_hist()
        elif plt_type == 'hist_cum':
            self._plot_hist_cum()
        else:
            # report the offending value (the former message interpolated the builtin `type` by mistake)
            raise ValueError(f"plt_type must be 'hist' or 'hist_cum', but is {plt_type}")

    def _plot_hist(self, *args):
        """Plot the number of stations with available data over time, one filled step line per subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        for i, subset in enumerate(self.dataset_time_interval.keys()):
            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
                                                       self.temporal_dim: slice(
                                                           self.dataset_time_interval[subset]['first'],
                                                           self.dataset_time_interval[subset]['last']
                                                       )
                                                       }
                                                      )

            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])

        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                         facecolor='white', framealpha=1, edgecolor='black')
        for lgd_line in lgd.get_lines():
            lgd_line.set_linewidth(4.0)
        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
        plt.title('')
        plt.ylabel('Number of samples')
        plt.tight_layout()

    def _plot_hist_cum(self, *args):
        """Plot a reversed cumulative histogram (number of stations having at least n samples) per subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        n_bins = int(self.avail_data_cum_sum.max().values)
        bins = np.arange(0, n_bins + 1)
        # draw subsets with the largest maximum first, so that smaller histograms are plotted on top
        descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby(
            self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False
        ).coords[self.subset_dim].values

        for subset in descending_subsets:
            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
                                                                             bins=bins,
                                                                             label=subset,
                                                                             cumulative=-1,
                                                                             color=colors[subset],
                                                                             # alpha=.5
                                                                             )

        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                         facecolor='white', framealpha=1, edgecolor='black')
        plt.title('')
        plt.ylabel('Number of stations')
        plt.xlabel('Number of samples')
        plt.xlim((bins[0], bins[-1]))
        plt.tight_layout()

448 

449 

@TimeTrackingWrapper
class PlotDataHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Plot histogram on transformed input and target data. This data is the same that the model sees during training. No
    plots are created for the original values space (raw / unformatted data). This plot method will create a histogram
    for input and target each comparing the subsets train, val and test, as well as a distinct one for the three
    subsets.

    .. image:: ../../../../../_source/_plots/datahistogram.png
       :width: 400

    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", plot_name="histogram",
                 variables_dim="variables", time_dim="datetime", window_dim="window", upsampling=False):
        """
        Calculate histograms and create all input and target plots.

        :param generators: mapping from subset name to the data collection of this subset
        :param plot_folder: path to save the plots (default: current directory)
        :param plot_name: prefix of all created plot files
        :param variables_dim: name of the variables dimension
        :param time_dim: name of the time dimension
        :param window_dim: name of the window dimension
        :param upsampling: if True, the subset "train" is additionally evaluated as "train_upsampled"
        """
        super().__init__(plot_folder, plot_name)
        self.variables_dim = variables_dim
        self.time_dim = time_dim
        self.window_dim = window_dim
        self.inputs, self.targets, number_of_branches = self._get_inputs_targets(generators, self.variables_dim)
        self.bins = {}
        self.interval_width = {}
        self.bin_edges = {}
        if upsampling is True:
            # work on a shallow copy so that the additional subset does not leak into the caller's dictionary
            generators = dict(generators)
            self._handle_upsampling(generators)

        # input plots
        for branch_pos in range(number_of_branches):
            self._calculate_hist(generators, self.inputs, input_data=True, branch_pos=branch_pos)
            add_name = "input" if number_of_branches == 1 else f"input_branch_{branch_pos}"
            for subset in generators.keys():
                self._plot(add_name=add_name, subset=subset)
            self._plot_combined(add_name=add_name)

        # target plots
        self._calculate_hist(generators, self.targets, input_data=False)
        for subset in generators.keys():
            self._plot(add_name="target", subset=subset)
        self._plot_combined(add_name="target")

    @staticmethod
    def _handle_upsampling(generators):
        """Register the "train" collection a second time as "train_upsampled" (in place) if "train" exists."""
        if "train" in generators:
            generators.update({"train_upsampled": generators["train"]})

    @staticmethod
    def _get_inputs_targets(gens, dim):
        """
        Read input variables, target variables and number of input branches from the first element of the first subset.

        :param gens: mapping from subset name to data collection
        :param dim: name of the variables dimension
        :return: list of input variable names, list of target variable names, number of input branches
        """
        first_subset = next(iter(gens))
        gen = gens[first_subset][0]
        x_data = gen.get_X(as_numpy=False)
        inputs = list({name for branch in to_list(x_data) for name in branch.coords[dim].values.tolist()})
        targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist())
        n_branches = len(x_data)
        return inputs, targets, n_branches

    def _calculate_hist(self, generators, variables, input_data=True, branch_pos=0):
        """
        Calculate an aggregated histogram per subset and variable.

        A histogram is calculated for each element of a data collection using ``f_proc_hist``. Afterwards, all
        histograms of a variable are interpolated onto common bin edges (spanning the overall minimum and maximum) and
        summed up. Results are stored in ``self.bins``, ``self.interval_width`` and ``self.bin_edges`` using the subset
        name as key.

        :param generators: mapping from subset name to data collection
        :param variables: names of the variables to evaluate
        :param input_data: evaluate input data (branch ``branch_pos``) if True, target data otherwise
        :param branch_pos: index of the input branch (only used for input data)
        """
        n_bins = 100
        for set_type, generator in generators.items():
            upsampling = "upsampled" in set_type
            tmp_bins = {}
            tmp_edges = {}
            end = {}
            start = {}
            for gen in generator:
                # load the data only once per element (w and data are taken from the very same array)
                if input_data is True:
                    raw = gen.get_X(as_numpy=False, upsampling=upsampling)[branch_pos]
                else:
                    raw = gen.get_Y(as_numpy=False, upsampling=upsampling)
                w = min(abs(raw.coords[self.window_dim].values))  # window entry with smallest absolute value
                data = raw.sel({self.window_dim: w})
                res, _, g_edges = f_proc_hist(data, variables, n_bins, self.variables_dim)
                for var in res.keys():
                    tmp_bins.setdefault(var, []).append(res[var])
                    tmp_edges.setdefault(var, []).append(g_edges[var])
                    upper, lower = g_edges[var].max(), g_edges[var].min()
                    end[var] = max(end.get(var, upper), upper)
                    start[var] = min(start.get(var, lower), lower)
            # interpolate and aggregate
            bins = {}
            edges = {}
            interval_width = {}
            for var in tmp_bins.keys():
                bin_edges = np.linspace(start[var], end[var], n_bins + 1)
                interval_width[var] = bin_edges[1] - bin_edges[0]
                for i, e in enumerate(tmp_bins[var]):
                    bins_interp = np.interp(bin_edges[:-1], tmp_edges[var][i][:-1], e, left=0, right=0)
                    bins[var] = bins.get(var, np.zeros(n_bins)) + bins_interp
                edges[var] = bin_edges

            self.bins[set_type] = bins
            self.interval_width[set_type] = interval_width
            self.bin_edges[set_type] = edges

    def _subset_colors(self):
        """Return a copy of the data set colors extended by an entry for the subset "train_upsampled"."""
        colors = dict(self.get_dataset_colors())
        colors["train_upsampled"] = colors.get("train_val", "#000000")
        return colors

    def _plot(self, add_name, subset):
        """
        Create a multi-page pdf with one normalised histogram per variable for the given subset.

        :param add_name: suffix of the file name (e.g. "input" or "target")
        :param subset: name of the subset to plot
        """
        # import explicitly, `import matplotlib` alone does not guarantee that this submodule is loaded
        from matplotlib.backends.backend_pdf import PdfPages

        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{subset}_{add_name}.pdf")
        bins = self.bins[subset]
        bin_edges = self.bin_edges[subset]
        interval_width = self.interval_width[subset]
        colors = self._subset_colors()
        try:
            with PdfPages(plot_path) as pdf_pages:  # the pdf is closed even if plotting fails
                for var in bins.keys():
                    fig, ax = plt.subplots()
                    hist_var = bins[var]
                    n_var = sum(hist_var)
                    weights = hist_var / (interval_width[var] * n_var)  # normalise to a probability density
                    ax.hist(bin_edges[var][:-1], bin_edges[var], weights=weights, color=colors[subset])
                    ax.set_ylabel("probability density")
                    ax.set_xlabel("values")
                    ax.set_title(f"histogram {var} ({subset}, n={int(n_var)})")
                    pdf_pages.savefig()
        finally:
            # close all open figures / plots
            plt.close('all')

    def _plot_combined(self, add_name):
        """
        Create a multi-page pdf with one page per variable that compares the densities of all subsets.

        :param add_name: suffix of the file name (e.g. "input" or "target")
        """
        # import explicitly, `import matplotlib` alone does not guarantee that this submodule is loaded
        from matplotlib.backends.backend_pdf import PdfPages

        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}_{add_name}.pdf")
        variables = next(iter(self.bins.values())).keys()  # variables of the first subset
        colors = self._subset_colors()
        try:
            with PdfPages(plot_path) as pdf_pages:  # the pdf is closed even if plotting fails
                for var in variables:
                    fig, ax = plt.subplots()
                    for subset in self.bins.keys():
                        hist_var = self.bins[subset][var]
                        interval_width = self.interval_width[subset][var]
                        bin_edges = self.bin_edges[subset][var]
                        n_var = sum(hist_var)
                        weights = hist_var / (interval_width * n_var)  # normalise to a probability density
                        # draw the density at the center of each bin
                        ax.plot(bin_edges[:-1] + 0.5 * interval_width, weights, label=f"{subset}",
                                c=colors[subset])
                    ax.set_ylabel("probability density")
                    ax.set_xlabel("values")
                    ax.legend(loc="upper right")
                    ax.set_title(f"histogram {var}")
                    pdf_pages.savefig()
        finally:
            # close all open figures / plots
            plt.close('all')

591 

592 

593@TimeTrackingWrapper 

594class PlotPeriodogram(AbstractPlotClass): # pragma: no cover 

595 """ 

596 Create Lomb-Scargle periodogram in raw input and target data. The Lomb-Scargle version can deal with missing values. 

597 

598 This plot routine is creating the following plots: 

599 

600 * "raw": data is not aggregated, 1 graph per variable 

601 * "": single data lines are aggregated, 1 graph per variable 

602 * "total": data is aggregated on all variables, single graph 

603 

604 If data consists on different sampling rates, a separate plot is create for each sampling. 

605 

606 .. image:: ../../../../../_source/_plots/periodogram.png 

607 :width: 400 

608 

609 .. note:: 

610 This plot is not included in the default plot list. To use this plot, add "PlotPeriodogram" to the `plot_list`. 

611 

612 .. warning:: 

613 This plot is highly sensitive to the data handler structure. Therefore, it is highly likely that this method is 

614 not compatible with any custom data handler. Proven data handlers are `DefaultDataHandler`, 

615 `DataHandlerMixedSampling`, `DataHandlerMixedSamplingWithFilter`. To work properly, the data handler must have 

616 the attribute `.id_class._data`. 

617 

618 """ 

619 

def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="periodogram",
             variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False):
    """
    Create all periodogram plots, first for the input data and afterwards for the target data.

    :param generator: collection of data handlers; the element at index 0 is inspected for filtered data.
        NOTE(review): annotated as dictionary but accessed by position - confirm the expected type.
    :param plot_folder: path to save the plots (default: current directory)
    :param plot_name: prefix of all created plot files
    :param variables_dim: name of the variables dimension
    :param time_dim: name of the time dimension
    :param sampling: either a tuple ``(input sampling, target sampling)`` or a single value used for both
    :param use_multiprocessing: forwarded to the periodogram calculation
    """
    super().__init__(plot_folder, plot_name)
    self.variables_dim = variables_dim
    self.time_dim = time_dim

    samplings = sampling if isinstance(sampling, tuple) else (sampling, sampling)
    data_kind = {0: "input", 1: "target"}
    for pos, current_sampling in enumerate(samplings):
        self._sampling = current_sampling
        self._add_text = data_kind[pos]
        multiple, label_names = self._has_filter_dimension(generator[0], pos)
        self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing)
        for raw in (True, False):
            self._plot(raw=raw)
        for raw in (True, False):
            self._plot_total(raw=raw)
        if multiple > 1:
            # filtered data: compare the components using the last and afterwards the first input value
            self._plot_difference(label_names, plot_name_add="_last")
            self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing,
                                use_last_input_value=False)
            self._plot_difference(label_names, plot_name_add="_first")

640 

@staticmethod
def _has_filter_dimension(g, pos):
    """
    Check whether the data handler provides filtered data for the requested part.

    :param g: data handler whose ``id_class`` is inspected
    :param pos: 0 to inspect the input data, 1 to inspect the target data
    :return: number of filter components and their labels; ``(1, [])`` if there is no filter dimension
    """
    handler = g.id_class
    # inputs and targets are both loaded before the filter check
    candidates = [handler.get_X(as_numpy=False), handler.get_Y(as_numpy=False)]
    data = candidates[pos]
    if not hasattr(handler, "filter_dim"):  # data handler has no filtered data
        return 1, []
    filter_dim = handler.filter_dim
    if filter_dim not in data.coords.dims:  # current data has no filter (e.g. target data)
        return 1, []
    filter_coord = data.coords[filter_dim]
    return filter_coord.shape[0], filter_coord.values.tolist()

654 

@TimeTrackingWrapper
def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False, use_last_input_value=True):
    """
    Create periodogram data.

    All periodograms returned by ``_prepare_pgram_parallel_gen`` are interpolated onto the common frequency grid
    ``self.f_index``. Results are appended to ``self.plot_data`` (list of single periodograms per variable),
    ``self.plot_data_mean`` (frequency grid and mean periodogram per variable) and ``self.plot_data_raw`` (frequency
    grid and 2D array with one column per periodogram). Each of these lists gets one entry per processed component.

    :param generator: collection of data handlers
    :param pos: 0 for input data, 1 for target data
    :param multiple: number of filter components; for values other than 1, one additional component is processed
    :param use_multiprocessing: forwarded to ``_prepare_pgram_parallel_gen``
    :param use_last_input_value: forwarded to ``_prepare_pgram_parallel_gen``
    """
    self.raw_data = []
    self.plot_data = []
    self.plot_data_raw = []
    self.plot_data_mean = []
    # the frequency grid does not depend on the component and is therefore created only once
    upper_exponent = 0 if self._sampling == "daily" else np.log10(24)
    self.f_index = np.logspace(-3, upper_exponent, 1000)
    n_components = multiple if multiple == 1 else multiple + 1
    for m in range(n_components):
        plot_data_single = dict()
        plot_data_raw_single = dict()
        plot_data_mean_single = dict()
        raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing,
                                                           use_last_input_value=use_last_input_value)
        for var, all_data in raw_data_single.items():
            pgram_com = []
            pgram_mean = 0
            pgram_mean_raw = np.zeros((len(self.f_index), len(all_data)))
            for i, (f, pgram) in enumerate(all_data):
                d = np.interp(self.f_index, f, pgram)  # bring periodogram onto the common frequency grid
                pgram_com.append(d)
                pgram_mean += d
                pgram_mean_raw[:, i] = d
            pgram_mean /= len(all_data)
            plot_data_single[var] = pgram_com
            plot_data_mean_single[var] = (self.f_index, pgram_mean)
            plot_data_raw_single[var] = (self.f_index, pgram_mean_raw)
        self.plot_data.append(plot_data_single)
        self.plot_data_mean.append(plot_data_mean_single)
        self.plot_data_raw.append(plot_data_raw_single)

689 

def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing):
    """
    Implementation of data preprocessing using parallel variables element processing.

    :param generator: collection of data handlers
    :param m: 0 to use the data stored in ``id_class._data``, otherwise the ``m``-th filter component (1-based) of the
        input data together with the target data
    :param pos: 0 for input data, 1 for target data (only applied if the data is given as tuple)
    :param use_multiprocessing: process the variables of each data handler in a worker pool if more than one cpu is
        available
    :return: dictionary with a list of ``(f, pgram)`` tuples per variable
    """
    raw_data_single = dict()
    for g in generator:
        if m == 0:
            d = g.id_class._data
        else:
            gd = g.id_class
            filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
            d = (gd.input_data.sel(filter_sel), gd.target_data)
        d = d[pos] if isinstance(d, tuple) else d
        variables = d[self.variables_dim].values

        def select(var):
            """Return the time series of a single variable without missing values."""
            return d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)

        if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
            # psutil returns None if the number of physical cores cannot be determined
            n_physical = psutil.cpu_count(logical=False) or multiprocessing.cpu_count()
            n_workers = max(1, min(n_physical, len(variables), 16))  # use only physical cpus
            pool = multiprocessing.Pool(n_workers)
            try:
                output = [pool.apply_async(f_proc, args=(var, select(var))) for var in variables]
                res = [p.get() for p in output]
            finally:
                # shut down the workers also if one of the tasks has failed
                pool.close()
                pool.join()
        else:  # serial solution
            res = [f_proc(var, select(var)) for var in variables]
        for (var_str, f, pgram) in res:
            raw_data_single.setdefault(var_str, []).append((f, pgram))
    return raw_data_single

723 

def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing, use_last_input_value=True):
    """
    Implementation of data preprocessing using parallel generator element processing.

    Each element of ``generator`` is processed by ``f_proc_2`` and the returned dictionaries are merged by
    concatenating the values of equal keys.

    :param generator: collection of data handlers
    :param m: index of the component to process (forwarded to ``f_proc_2``)
    :param pos: 0 for input data, 1 for target data (forwarded to ``f_proc_2``)
    :param use_multiprocessing: process the elements in a worker pool if more than one cpu is available
    :param use_last_input_value: forwarded to ``f_proc_2``
    :return: dictionary with the merged results of all elements
    """
    raw_data_single = dict()
    shared_args = (m, pos, self.variables_dim, self.time_dim, self.f_index, use_last_input_value)
    if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
        # psutil returns None if the number of physical cores cannot be determined
        n_physical = psutil.cpu_count(logical=False) or multiprocessing.cpu_count()
        n_workers = max(1, min(n_physical, len(generator), 16))  # use only physical cpus
        pool = multiprocessing.Pool(n_workers)
        try:
            output = [pool.apply_async(f_proc_2, args=(g,) + shared_args) for g in generator]
            res = [p.get() for p in output]
        finally:
            # shut down the workers also if one of the tasks has failed
            pool.close()
            pool.join()
    else:  # serial solution
        res = [f_proc_2(g, *shared_args) for g in generator]
    for res_dict in res:
        for k, v in res_dict.items():
            if k in raw_data_single:
                raw_data_single[k] = raw_data_single[k] + v
            else:
                raw_data_single[k] = v
    return raw_data_single

749 

@staticmethod
def _add_annotation_line(ax, pos, div, lims, unit):
    """
    Mark characteristic frequencies with a black vertical line and a vertical text label.

    :param ax: ax object to draw on
    :param pos: single value or collection of values; each value divided by div gives the x position of a line
    :param div: divisor applied to each value of pos
    :param lims: limits passed to vlines as vertical extent, the first entry is also the y position of the label
    :param unit: unit string that is shown with exponent -1 in the label
    """
    for period in to_list(pos):
        x_pos = period / div
        label = r"%s$%s^{-1}$" % (period, unit)
        ax.vlines(x_pos, *lims, "black")
        ax.text(x_pos, lims[0], label, rotation="vertical", rotation_mode="anchor")

755 

def _format_figure(self, ax, var_name="total"):
    """
    Set log scale on both axis, add labels and annotation lines, and set title.

    :param ax: current ax object
    :param var_name: name of variable that will be included in the title
    """
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_ylabel("power spectral density", fontsize='x-large')  # unit depends on variable: [unit^2 day^-1]
    ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
    lims = ax.get_ylim()

    # (positions, divisor, unit) of all annotation lines, the x axis is given in 1/day
    annotations = [
        ([1, 2, 3], 365.25, "yr"),  # per year
        (1, 365.25 / 12, "m"),  # per month
        (1, 7, "w"),  # per week
        ([1, 0.5], 1, "d"),  # per day
    ]
    if self._sampling == "hourly":
        annotations.append((2, 1, "d"))  # per day
        annotations.append(([1, 0.5], 1 / 24., "h"))  # per hour
    for pos, div, unit in annotations:
        self._add_annotation_line(ax, pos, div, lims, unit)

    ax.set_title(f"Periodogram ({var_name})")

776 

def _plot(self, raw=True):
    """
    Create one periodogram page per variable and store all pages in a single pdf file.

    The file name is built from plot_name, an optional "_raw" suffix, the sampling and the additional text.

    :param raw: if True, draw every single periodogram in light blue and the stored mean in blue. Otherwise, apply a
        centred rolling window of 5 steps along the frequency axis, draw the rolling mean averaged over all
        periodograms in blue, and shade the area between the averaged rolling minimum and maximum.
    """
    # import the submodule explicitly, `import matplotlib` alone does not guarantee that the pdf backend is loaded
    from matplotlib.backends.backend_pdf import PdfPages

    plot_path = os.path.join(os.path.abspath(self.plot_folder),
                             f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf")
    plot_data = self.plot_data[0]
    plot_data_mean = self.plot_data_mean[0]
    try:
        # the context manager finalises the pdf file also if plotting fails
        with PdfPages(plot_path) as pdf_pages:
            for var in plot_data.keys():
                fig, ax = plt.subplots()
                if raw is True:
                    for pgram in plot_data[var]:
                        ax.plot(self.f_index, pgram, "lightblue")
                    ax.plot(*plot_data_mean[var], "blue")
                else:
                    # rows: frequencies, columns: single periodograms; rolling works along the rows by default
                    ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True)
                    mean = ma.mean().mean(axis=1).values.flatten()
                    upper = ma.max().mean(axis=1).values.flatten()
                    lower = ma.min().mean(axis=1).values.flatten()
                    ax.plot(self.f_index, mean, "blue")
                    ax.fill_between(self.f_index, lower, upper, color="lightblue")
                self._format_figure(ax, var)
                pdf_pages.savefig(fig)
    finally:
        # close all open figures / plots
        plt.close('all')

800 

def _plot_total(self, raw=True):
    """
    Create a single periodogram page that combines the data of all variables and store it as pdf file.

    For each variable, the second element of its entry in the raw plot data is taken and all of them are concatenated
    along the last axis.

    :param raw: if True, draw every column of the combined data in light blue and the mean over all columns in blue.
        Otherwise, apply a centred rolling window of 5 steps along the first axis, draw the rolling mean averaged over
        all columns in blue, and shade the area between the averaged rolling minimum and maximum.
    """
    # import the submodule explicitly, `import matplotlib` alone does not guarantee that the pdf backend is loaded
    from matplotlib.backends.backend_pdf import PdfPages

    plot_path = os.path.join(os.path.abspath(self.plot_folder),
                             f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}_total.pdf")
    plot_data_raw = self.plot_data_raw[0]
    # fails early with a clear error if there is no data at all
    res = np.concatenate([plot_data_raw[var][1] for var in plot_data_raw.keys()], axis=-1)
    try:
        # the context manager finalises the pdf file also if plotting fails
        with PdfPages(plot_path) as pdf_pages:
            fig, ax = plt.subplots()
            if raw is True:
                for i in range(res.shape[1]):
                    ax.plot(self.f_index, res[:, i], "lightblue")
                ax.plot(self.f_index, res.mean(axis=1), "blue")
            else:
                # rolling works along the rows by default
                ma = pd.DataFrame(np.vstack(res)).rolling(5, center=True)
                mean = ma.mean().mean(axis=1).values.flatten()
                upper = ma.max().mean(axis=1).values.flatten()
                lower = ma.min().mean(axis=1).values.flatten()
                ax.plot(self.f_index, mean, "blue")
                ax.fill_between(self.f_index, lower, upper, color="lightblue")
            self._format_figure(ax, "total")
            pdf_pages.savefig(fig)
    finally:
        # close all open figures / plots
        plt.close('all')

826 

def _plot_difference(self, label_names, plot_name_add=""):
    """
    Compare the smoothed mean periodograms of all stored data sets in one page per variable.

    The data sets are drawn in reversed order so that the first one ("orig") is drawn last. For the first data set,
    the area between the averaged rolling minimum and maximum is shaded in addition. A data set labelled
    "unfiltered" is skipped.

    :param label_names: labels of all data sets following the first one, "orig" is prepended as label of the first
    :param plot_name_add: text that is appended to the file name right before the file extension
    """
    # import the submodule explicitly, `import matplotlib` alone does not guarantee that the pdf backend is loaded
    from matplotlib.backends.backend_pdf import PdfPages

    plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter{plot_name_add}.pdf"
    plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
    logging.info(f"... plotting {plot_name}")
    colors = ["grey", "blue", "red", "green", "orange", "purple", "black"]
    label_names = ["orig", *label_names]  # accepts any iterable of labels
    max_iter = len(self.plot_data)
    var_keys = self.plot_data[0].keys()
    try:
        # the context manager finalises the pdf file also if plotting fails
        with PdfPages(plot_path) as pdf_pages:
            for var in var_keys:
                fig, ax = plt.subplots()
                for i in reversed(range(max_iter)):
                    if label_names[i] == "unfiltered":
                        continue  # do not include the filter 'unfiltered' because this is equal to the 'orig' data
                    plot_data = self.plot_data[i]
                    c = colors[i % len(colors)]  # reuse colors instead of failing for more than 7 data sets
                    # rows: frequencies, columns: single periodograms; rolling works along the rows by default
                    ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True)
                    mean = ma.mean().mean(axis=1).values.flatten()
                    ax.plot(self.f_index, mean, c, label=label_names[i])
                    if i < 1:
                        upper = ma.max().mean(axis=1).values.flatten()
                        lower = ma.min().mean(axis=1).values.flatten()
                        ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None)
                self._format_figure(ax, var)
                ax.legend(loc="upper center", ncol=max_iter)
                pdf_pages.savefig(fig)
    finally:
        # close all open figures / plots
        plt.close('all')

855 

856 

def f_proc(var, d_var, f_index, time_dim="datetime", use_last_value=True):  # pragma: no cover
    """
    Calculate the Lomb-Scargle periodogram (psd normalisation, single term) of a variable on given frequencies.

    :param var: name of the variable, returned as string
    :param d_var: data of the variable with a time coordinate named time_dim
    :param f_index: frequencies to evaluate the periodogram on
    :param time_dim: name of the time dimension
    :param use_last_value: if True, select the maximum coordinate of each remaining dimension, otherwise the minimum
    :return: name of variable as string, the unchanged f_index, and the power on f_index
    """
    # time since first time step in hourly resolution, expressed in days
    time_coord = d_var[time_dim]
    elapsed = (time_coord - time_coord[0]).astype("timedelta64[h]")
    t = elapsed.values / np.timedelta64(1, "D")

    if len(d_var.shape) > 1:  # use only max value if dimensions are remaining (e.g. max(window) -> latest value)
        for dim in to_list(remove_items(d_var.coords.dims, time_dim)):
            coord = d_var[dim]
            d_var = d_var.sel({dim: coord.max() if use_last_value is True else coord.min()})

    periodogram = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd")
    return str(var), f_index, periodogram.power(f_index)

867 

868 

def f_proc_2(g, m, pos, variables_dim, time_dim, f_index, use_last_value):  # pragma: no cover
    """
    Calculate the periodograms of all variables of a single data handler element.

    Lazy data is loaded first and cleaned up again afterwards, also if the calculation fails.

    :param g: element to process; its attributes containing "id_class" in their name are used if pos is 0, otherwise
        only the attribute "id_class"
    :param m: 0 to use the unfiltered data (`_data`, or last history entry / first label entry if `_data` is None),
        otherwise the input data of filter component m - 1 together with the target data
    :param pos: index to select from the data tuple
    :param variables_dim: name of the variables dimension
    :param time_dim: name of the time dimension
    :param f_index: frequencies to evaluate the periodogram on
    :param use_last_value: forwarded to ``f_proc``
    :return: dictionary with variable names as keys and a list holding a single (frequency, power) tuple as values
    :raises KeyError: if more than one periodogram is calculated for the same variable name
    """
    id_classes = [name for name in dir(g) if "id_class" in name] if pos == 0 else ["id_class"]

    def lazy_classes():
        """Yield all id classes that are flagged as lazy at the time of the call."""
        for id_cls_name in id_classes:
            id_cls = getattr(g, id_cls_name)
            if hasattr(id_cls, "lazy") and id_cls.lazy is True:
                yield id_cls

    # load lazy data
    for id_cls in lazy_classes():
        id_cls.load_lazy()

    raw_data_single = dict()
    try:
        for dh in (name for name in id_classes if "unfiltered" not in name):
            current_cls = getattr(g, dh)
            if m == 0:
                d = current_cls._data
                if d is None:
                    window_dim = current_cls.window_dim
                    history = current_cls.history
                    last_entry = history.coords[window_dim][-1]
                    d1 = history.sel({window_dim: last_entry}, drop=True)
                    label = current_cls.label
                    first_entry = label.coords[window_dim][0]
                    d2 = label.sel({window_dim: first_entry}, drop=True)
                    d = (d1, d2)
            else:
                filter_sel = {"filter": current_cls.input_data.coords["filter"][m - 1]}
                d = (current_cls.input_data.sel(filter_sel), current_cls.target_data)
            d = d[pos] if isinstance(d, tuple) else d
            for var in d[variables_dim].values:
                d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim)
                # forward time_dim, otherwise f_proc silently falls back to its default name
                var_str, f, pgram = f_proc(var, d_var, f_index, time_dim=time_dim, use_last_value=use_last_value)
                if var_str in raw_data_single:
                    raise KeyError(f"There are multiple pgrams for key {var_str}. Please check your data handler.")
                raw_data_single[var_str] = [(f, pgram)]
    finally:
        # perform clean up
        for id_cls in lazy_classes():
            id_cls.clean_up()

    return raw_data_single

911 

912 

def f_proc_hist(data, variables, n_bins, variables_dim):  # pragma: no cover
    """
    Calculate a histogram for each requested variable that is available in data.

    :param data: data with coordinates along variables_dim
    :param variables: names of variables to calculate a histogram for, names not found in data are skipped
    :param n_bins: bins argument that is passed to numpy's histogram
    :param variables_dim: name of the variables dimension
    :return: three dictionaries with variable names as keys: histogram counts, width of first bin, and bin edges
    """
    counts, widths, edges = {}, {}, {}
    for var in variables:
        if var not in data.coords[variables_dim]:
            continue
        if len(data.shape) > 1:
            selected = data.sel({variables_dim: var}).squeeze()
        else:
            selected = data
        hist, var_edges = np.histogram(selected.values, n_bins)
        counts[var] = hist
        edges[var] = var_edges
        widths[var] = var_edges[1] - var_edges[0]
    return counts, widths, edges

923 

924 

class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot climate FIR filter components.

    * Creates a separate folder climFIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the climate FIR input, the filter response, the true non-causal (ideal) filter
      input, and the corresponding ideal response (containing information about future)
    * A filter residuum plot include the climate FIR residuum and the ideal filter residuum.
    """

    def __init__(self, plot_folder, plot_data, sampling, name):
        """
        Set plot styles, restructure the plot data, create all plots, and store the plot data.

        :param plot_folder: parent directory of the plots, the plots are stored in the sub folder climFIR. Nothing is
            set up or plotted if plot_folder is None.
        :param plot_data: collection with one entry per filter component, each entry is a collection of dictionaries
            with the keys var, t0, filter_input, filter_input_nc, valid_range, time_range, new_dim, and h
        :param sampling: "1d" or "1H", used to select the unit of time deltas (days or hours)
        :param name: name that is used in the log message and in the file name of each plot
        """
        from mlair.helpers.filter import fir_filter_convolve

        logging.info(f"start PlotClimateFirFilter for ({name})")

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}
        if plot_folder is None:
            return

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "ideal": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        self.variables_list = []
        plot_folder = os.path.join(os.path.abspath(plot_folder), "climFIR")
        self.fir_filter_convolve = fir_filter_convolve
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict, new_dim = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict, sampling, new_dim)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        :param data: plot data as passed to the constructor
        :return: nested dictionary (variable -> t0 -> index of filter component -> data to plot) and the name of the
            new dimension that has to be identical for all entries
        """
        plot_dict = {}
        new_dim = None
        for i, plot_data in enumerate(data):
            for p_d in plot_data:
                current_dim = p_d.get("new_dim")
                if new_dim is None:
                    new_dim = current_dim
                else:
                    assert new_dim == current_dim
                h = p_d.get("h")
                plot_dict_order = {"filter_input": p_d.get("filter_input"),
                                   "filter_input_nc": p_d.get("filter_input_nc"),
                                   "valid_range": p_d.get("valid_range"),
                                   "time_range": p_d.get("time_range"),
                                   "order": len(h), "h": h}
                plot_dict.setdefault(p_d.get("var"), {}).setdefault(p_d.get("t0"), {})[i] = plot_dict_order
        self.variables_list = list(plot_dict.keys())
        return plot_dict, new_dim

    def _plot(self, plot_dict, sampling, new_dim="window"):
        """
        Create a filter component plot and a residuum plot for each variable, each t0, and each filter component.

        The filter components of a single t0 are processed in sorted order because from the second component on, the
        ideal residuum of the previous component is used as non-causal filter input. If any plot of a t0 fails, the
        reason is logged and the remaining components of this t0 are skipped.

        :param plot_dict: restructured plot data as returned by _prepare_data
        :param sampling: "1d" or "1H", used to select the unit of time deltas (days or hours)
        :param new_dim: name of the dimension the filter is applied along
        """
        td_type = {"1d": "D", "1H": "h"}.get(sampling)
        for var, vis_dict in plot_dict.items():
            for it0, (t0, vis_data) in enumerate(vis_dict.items()):
                residuum_true = None
                try:
                    for ifilter in sorted(vis_data.keys()):
                        data = vis_data[ifilter]
                        filter_input = data["filter_input"]
                        filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel(
                            {new_dim: filter_input.coords[new_dim]})
                        valid_range = data["valid_range"]
                        time_axis = data["time_range"]
                        filter_order = data["order"]
                        h = data["h"]
                        fig, ax = plt.subplots()

                        # plot backgrounds
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)

                        # original data
                        self._plot_original_data(ax, time_axis, filter_input_nc)

                        # clim apriori
                        self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter, offset=1)

                        # clim filter response
                        residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h,
                                                                    output_dtypes=filter_input.dtype)

                        # ideal filter response
                        residuum_true = self._plot_ideal_filter(ax, time_axis, filter_input_nc, new_dim, h,
                                                                output_dtypes=filter_input.dtype)

                        # set title, legend, and save plot
                        xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis)

                        plt.title(f"Input of ClimFilter ({str(var)})")
                        plt.legend()
                        fig.autofmt_xdate()
                        plt.tight_layout()
                        self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}"
                        self._save()

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_valid_area(ax, t0, valid_range, td_type)
                        self._plot_t0(ax, t0)
                        self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal")
                        self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim")
                        ax.set_xlim(xlims)
                        plt.title(f"Residuum of ClimFilter ({str(var)})")
                        plt.legend(loc="upper left")
                        fig.autofmt_xdate()
                        plt.tight_layout()

                        self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
                        self._save()
                except Exception:
                    # best effort: a failing plot must not stop the remaining plots
                    logging.info(f"Could not create plot because of:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
                    plt.close('all')  # do not keep the figure of the failed plot open

    def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis):
        """
        Set xlims

        Use order and valid_range to find a good zoom in that hides edges of filter values that are affected by reduced
        filter order. Limits are returned to be usable for other plots.

        :param ax: ax object to set the limits on
        :param t0: reference time of the plot
        :param order: order of the filter
        :param valid_range: object with the attributes start and stop
        :param td_type: unit of the time deltas
        :param time_axis: time axis of the plot, the limits never exceed its first and last entry
        :return: lower and upper limit of the x axis
        """
        base_delta = min(2 * (valid_range.stop - valid_range.start), 0.5 * order)
        t_minus_delta = max(base_delta, (-valid_range.start + 0.3 * order))
        t_plus_delta = max(base_delta, valid_range.stop + 0.3 * order)
        t_minus = t0 + np.timedelta64(-int(t_minus_delta), td_type)
        t_plus = t0 + np.timedelta64(int(t_plus_delta), td_type)
        ax_start = max(t_minus, time_axis[0])
        ax_end = min(t_plus, time_axis[-1])
        ax.set_xlim((ax_start, ax_end))
        return ax_start, ax_end

    def _plot_valid_area(self, ax, t0, valid_range, td_type):
        """Highlight the span from t0 + valid_range.start to t0 + valid_range.stop - 1 as background."""
        ax.axvspan(t0 + np.timedelta64(valid_range.start, td_type),
                   t0 + np.timedelta64(valid_range.stop - 1, td_type), **self.style_dict["valid_area"])

    def _plot_t0(self, ax, t0):
        """Mark t0 with a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot data against time_axis using the entry style of the style dictionary."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_original_data(self, ax, time_axis, data):
        """Plot the flattened non-causal filter input in style "original"."""
        self._plot_series(ax, time_axis, data.values.flatten(), style="original")

    def _plot_apriori(self, ax, time_axis, data, new_dim, ifilter, offset):
        """
        Plot the filter input in style "apriori".

        For the first filter component (ifilter is 0), only the part from offset to the maximum coordinate along
        new_dim is drawn. The data is aligned with the end of the time axis.
        """
        if ifilter == 0:
            d_tmp = data.sel({new_dim: slice(offset, data.coords[new_dim].values.max())}).values.flatten()
        else:
            d_tmp = data.values.flatten()
        self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, style="apriori")

    def _apply_filter(self, data, new_dim, h, output_dtypes):
        """Apply fir_filter_convolve with the filter coefficients h on data along new_dim."""
        return xr.apply_ufunc(self.fir_filter_convolve, data,
                              input_core_dims=[[new_dim]],
                              output_core_dims=[[new_dim]],
                              vectorize=True,
                              kwargs={"h": h},
                              output_dtypes=[output_dtypes])

    def _plot_clim_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """
        Plot the clim filter response.

        :return: residuum of the clim filter (filter input minus filter response)
        """
        filt = self._apply_filter(data, new_dim, h, output_dtypes)
        self._plot_series(ax, time_axis, filt.values.flatten(), style="clim")
        return data - filt

    def _plot_ideal_filter(self, ax, time_axis, data, new_dim, h, output_dtypes):
        """
        Plot the ideal filter response.

        :return: residuum of the ideal filter (non-causal filter input minus filter response)
        """
        filt = self._apply_filter(data, new_dim, h, output_dtypes)
        self._plot_series(ax, time_axis, filt.values.flatten(), style="ideal")
        return data - filt

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        # NOTE(review): there is no separator between the joined variable names and "plot_data.pickle"; kept unchanged
        # because existing files may be loaded by this name - confirm whether this is intended
        file = os.path.join(self.plot_folder, "_".join(self.variables_list) + "plot_data.pickle")
        with open(file, "wb") as f:
            dill.dump(data, f)

1146 

1147 

class PlotFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot FIR filter components.

    * Creates a separate folder FIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the FIR input and the filter response
    * A filter residuum plot include the FIR residuum
    """

    def __init__(self, plot_folder, plot_data, name):
        """
        Set plot styles, restructure the plot data, create all plots, and store the plot data.

        :param plot_folder: parent directory of the plots, the plots are stored in the sub folder FIR. Nothing is set
            up or plotted if plot_folder is None.
        :param plot_data: collection with one entry per filter component, each entry is a collection of dictionaries
            (one per t0) with the keys t0, filter_input, filtered, var_dim, and time_dim
        :param name: name that is used in the log message and in the file name of each plot
        """
        logging.info(f"start PlotFirFilter for ({name})")

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}
        if plot_folder is None:
            return

        # NOTE(review): only the styles "original", "FIR", and "t0" are used by this class; the label of "FIR" reads
        # "ideal filter" - confirm whether this label is intended
        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "FIR": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        plot_folder = os.path.join(os.path.abspath(plot_folder), "FIR")
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        :param data: plot data as passed to the constructor
        :return: nested dictionary (variable -> t0 -> index of filter component -> data to plot)
        """
        plot_dict = {}
        for i, component_data in enumerate(data):  # filter component
            for plot_data in component_data:  # t0 counter
                t0 = plot_data.get("t0")
                filter_input = plot_data.get("filter_input")
                filtered = plot_data.get("filtered")
                var_dim = plot_data.get("var_dim")
                time_dim = plot_data.get("time_dim")
                for var in filtered.coords[var_dim].values:
                    plot_dict_order = {"filter_input": filter_input.sel({var_dim: var}, drop=True),
                                       "filtered": filtered.sel({var_dim: var}, drop=True),
                                       "time_dim": time_dim}
                    plot_dict.setdefault(var, {}).setdefault(t0, {})[i] = plot_dict_order
        return plot_dict

    def _plot(self, plot_dict):
        """
        Create a filter component plot and a residuum plot for each variable, each t0, and each filter component.

        If any plot of a t0 fails, the reason is logged and the remaining components of this t0 are skipped.

        :param plot_dict: restructured plot data as returned by _prepare_data
        """
        for var, viz_date_dict in plot_dict.items():
            for it0, (t0, viz_data) in enumerate(viz_date_dict.items()):
                try:
                    for ifilter in sorted(viz_data.keys()):
                        data = viz_data[ifilter]
                        filter_input = data["filter_input"]
                        filtered = data["filtered"]
                        time_dim = data["time_dim"]
                        time_axis = filtered.coords[time_dim].values
                        xlims = (time_axis[0], time_axis[-1])
                        fig, ax = plt.subplots()

                        # plot backgrounds
                        self._plot_t0(ax, t0)

                        # original data
                        self._plot_data(ax, time_axis, filter_input, style="original")

                        # filter response
                        self._plot_data(ax, time_axis, filtered, style="FIR")

                        # set title, legend, and save plot
                        ax.set_xlim(xlims)

                        plt.title(f"Input of Filter ({str(var)})")
                        plt.legend()
                        fig.autofmt_xdate()
                        plt.tight_layout()
                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}"
                        self._save()

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_t0(ax, t0)
                        residuum = filter_input - filtered
                        self._plot_data(ax, time_axis, residuum, style="FIR")
                        ax.set_xlim(xlims)
                        plt.title(f"Residuum of Filter ({str(var)})")
                        plt.legend(loc="upper left")
                        fig.autofmt_xdate()
                        plt.tight_layout()

                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
                        self._save()
                except Exception:
                    # best effort: a failing plot must not stop the remaining plots
                    logging.info(f"Could not create plot because of:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
                    plt.close('all')  # do not keep the figure of the failed plot open

    def _plot_t0(self, ax, t0):
        """Mark t0 with a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Plot data against time_axis using the entry style of the style dictionary."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_data(self, ax, time_axis, data, style="original"):
        """Plot the flattened values of data against time_axis in the given style."""
        self._plot_series(ax, time_axis, data.values.flatten(), style=style)

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        file = os.path.join(self.plot_folder, "plot_data.pickle")
        with open(file, "wb") as f:
            dill.dump(data, f)