Coverage for mlair/run_modules/pre_processing.py: 59%

308 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:40 +0000

1"""Pre-processing module.""" 

2 

3__author__ = "Lukas Leufen, Felix Kleinert" 

4__date__ = '2019-11-25' 

5 

6import logging 

7import os 

8import traceback 

9from typing import Tuple 

10import multiprocessing 

11import requests 

12import psutil 

13import random 

14import dill 

15 

16import pandas as pd 

17 

18from mlair.data_handler import DataCollection, AbstractDataHandler 

19from mlair.helpers import TimeTracking, to_list, tables, remove_items 

20from mlair.configuration import path_config 

21from mlair.helpers.data_sources.data_loader import EmptyQueryResult 

22from mlair.helpers.testing import check_nested_equality 

23from mlair.run_modules.run_environment import RunEnvironment 

24 

25 

26class PreProcessing(RunEnvironment): 

27 """ 

28 Pre-process your data by using this class. 

29 

30 Schedule of pre-processing: 

31 #. load and check valid stations (either download or load from disk) 

32 #. split subsets (train, val, test, train & val) 

33 #. create small report on data metrics 

34 

35 Required objects [scope] from data store: 

36 * all elements from `DEFAULT_ARGS_LIST` in scope preprocessing for general data loading 

37 * all elements from `DEFAULT_ARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings 

38 * `fraction_of_training` [.] 

39 * `experiment_path` [.] 

40 * `use_all_stations_on_all_data_sets` [.] 

41 

42 Optional objects 

43 * all elements from `DEFAULT_KWARGS_LIST` in scope preprocessing for general data loading 

44 * all elements from `DEFAULT_KWARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings 

45 

46 Sets 

47 * `stations` in [., train, val, test, train_val] 

48 * `generator` in [train, val, test, train_val] 

49 * `transformation` [.] 

50 

51 Creates 

52 * all input and output data in `data_path` 

53 * latex reports in `experiment_path/latex_report` 

54 

55 """ 

56 

57 def __init__(self): 

58 """Set up and run pre-processing.""" 

59 super().__init__() 

60 self._run() 

61 

62 def _run(self): 

63 snapshot_load_path = self.data_store.get_default("snapshot_load_path", default=None) 

64 if snapshot_load_path is None: 64 ↛ 75line 64 didn't jump to line 75, because the condition on line 64 was never false

65 stations = self.data_store.get("stations") 

66 data_handler = self.data_store.get("data_handler") 

67 self._load_apriori() 

68 _, valid_stations = self.validate_station(data_handler, stations, 

69 "preprocessing") # , store_processed_data=False) 

70 if len(valid_stations) == 0: 70 ↛ 71line 70 didn't jump to line 71, because the condition on line 70 was never true

71 raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.") 

72 self.data_store.set("stations", valid_stations) 

73 self.split_train_val_test() 

74 else: 

75 self.load_snapshot(snapshot_load_path) 

76 self.report_pre_processing() 

77 self.prepare_competitors() 

78 if self.data_store.get_default("create_snapshot", False) is True: 78 ↛ 79line 78 didn't jump to line 79, because the condition on line 78 was never true

79 self.create_snapshot() 

80 

81 def report_pre_processing(self): 

82 """Log some metrics on data and create latex report.""" 

83 logging.debug(20 * '##') 

84 n_train = len(self.data_store.get('data_collection', 'train')) 

85 n_val = len(self.data_store.get('data_collection', 'val')) 

86 n_test = len(self.data_store.get('data_collection', 'test')) 

87 n_total = n_train + n_val + n_test 

88 logging.debug(f"Number of all stations: {n_total}") 

89 logging.debug(f"Number of training stations: {n_train}") 

90 logging.debug(f"Number of val stations: {n_val}") 

91 logging.debug(f"Number of test stations: {n_test}") 

92 self.create_latex_report() 

93 

94 def create_latex_report(self): 

95 """ 

96 Create tables with information on the station meta data and a summary on subset sample sizes. 

97 

98 * station_sample_size.md: see table below as markdown 

99 * station_sample_size.tex: same as table below as latex table 

100 * station_sample_size_short.tex: reduced size table without any meta data besides station ID, as latex table 

101 

102 All tables are stored inside experiment_path inside the folder latex_report. The table format (e.g. which meta 

103 data is highlighted) is currently hardcoded to have a stable table style. If further styles are needed, it is 

104 better to add an additional style than modifying the existing table styles. 

105 

106 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

107 | stat. ID | station_name | station_lon | station_lat | station_alt | train | val | test | 

108 +============+===========================================+===============+===============+===============+=========+=======+========+ 

109 | DEBW013 | Stuttgart Bad Cannstatt | 9.2297 | 48.8088 | 235 | 1434 | 712 | 1080 | 

110 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

111 | DEBW076 | Baden-Baden | 8.2202 | 48.7731 | 148 | 3037 | 722 | 710 | 

112 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

113 | DEBW087 | Schwäbische_Alb | 9.2076 | 48.3458 | 798 | 3044 | 714 | 1087 | 

114 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

115 | DEBW107 | Tübingen | 9.0512 | 48.5077 | 325 | 1803 | 715 | 1087 | 

116 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

117 | DEBY081 | Garmisch-Partenkirchen/Kreuzeckbahnstraße | 11.0631 | 47.4764 | 735 | 2935 | 525 | 714 | 

118 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

119 | # Stations | nan | nan | nan | nan | 6 | 6 | 6 | 

120 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

121 | # Samples | nan | nan | nan | nan | 12253 | 3388 | 4678 | 

122 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

123 

124 """ 

125 meta_cols = ["name", "lat", "lon", "alt", "country", "state", "type", "type_of_area", "toar1_category"] 

126 meta_round = ["lat", "lon", "alt"] 

127 precision = 4 

128 path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

129 path_config.check_path_and_create(path) 

130 names_of_set = ["train", "val", "test"] 

131 df = self.create_info_df(meta_cols, meta_round, names_of_set, precision) 

132 column_format = tables.create_column_format_for_tex(df) 

133 tables.save_to_tex(path=path, filename="station_sample_size.tex", column_format=column_format, df=df) 

134 tables.save_to_md(path=path, filename="station_sample_size.md", df=df) 

135 df_nometa = df.drop(meta_cols, axis=1) 

136 column_format = tables.create_column_format_for_tex(df) 

137 tables.save_to_tex(path=path, filename="station_sample_size_short.tex", column_format=column_format, 

138 df=df_nometa) 

139 tables.save_to_md(path=path, filename="station_sample_size_short.md", df=df_nometa) 

140 df_descr = self.create_describe_df(df_nometa) 

141 column_format = tables.create_column_format_for_tex(df_descr) 

142 tables.save_to_tex(path=path, filename="station_describe_short.tex", column_format=column_format, df=df_descr) 

143 tables.save_to_md(path=path, filename="station_describe_short.md", df=df_descr) 

144 

145 @staticmethod 

146 def create_describe_df(df, percentiles=None, ignore_last_lines: int = 2): 

147 if percentiles is None: 147 ↛ 149line 147 didn't jump to line 149, because the condition on line 147 was never false

148 percentiles = [.05, .1, .25, .5, .75, .9, .95] 

149 df_descr = df.iloc[:-ignore_last_lines].astype('float32').describe( 

150 percentiles=percentiles).astype("int32", errors="ignore") 

151 df_descr = pd.concat([df.loc[['# Samples']], df_descr]).T 

152 df_descr.rename(columns={"# Samples": "no. samples", "count": "no. stations"}, inplace=True) 

153 df_descr_colnames = list(df_descr.columns) 

154 df_descr_colnames = [df_descr_colnames[1]] + [df_descr_colnames[0]] + df_descr_colnames[2:] 

155 df_descr = df_descr[df_descr_colnames] 

156 return df_descr 

157 

158 def create_info_df(self, meta_cols, meta_round, names_of_set, precision): 

159 use_multiprocessing = self.data_store.get("use_multiprocessing") 

160 max_process = self.data_store.get("max_number_multiprocessing") 

161 df = pd.DataFrame(columns=meta_cols + names_of_set) 

162 for set_name in names_of_set: 

163 data = self.data_store.get("data_collection", set_name) 

164 n_process = min([psutil.cpu_count(logical=False), len(data), max_process]) # use only physical cpus 

165 if n_process > 1 and use_multiprocessing is True: # parallel solution 165 ↛ 166line 165 didn't jump to line 166, because the condition on line 165 was never true

166 logging.info(f"use parallel create_info_df ({set_name})") 

167 pool = multiprocessing.Pool(n_process) 

168 logging.info(f"running {getattr(pool, '_processes')} processes in parallel") 

169 output = [pool.apply_async(f_proc_create_info_df, args=(station, meta_cols)) for station in data] 

170 for i, p in enumerate(output): 

171 res = p.get() 

172 station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"] 

173 df.loc[station_name, set_name] = shape 

174 if df.loc[station_name, meta_cols].isnull().any(): 

175 df.loc[station_name, meta_cols] = meta 

176 logging.info(f"...finished: {station_name} ({int((i + 1.) / len(output) * 100)}%)") 

177 pool.close() 

178 pool.join() 

179 else: # serial solution 

180 logging.info(f"use serial create_info_df ({set_name})") 

181 for station in data: 

182 res = f_proc_create_info_df(station, meta_cols) 

183 station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"] 

184 df.loc[station_name, set_name] = shape 

185 if df.loc[station_name, meta_cols].isnull().any(): 

186 df.loc[station_name, meta_cols] = meta 

187 df.loc["# Samples", set_name] = df.loc[:, set_name].sum() 

188 assert len(data) == df.loc[:, set_name].count() - 1 

189 df.loc["# Stations", set_name] = len(data) 

190 df[meta_round] = df[meta_round].astype(float).round(precision) 

191 df.sort_index(inplace=True) 

192 df = df.reindex(df.index.drop(["# Stations", "# Samples"]).to_list() + ["# Stations", "# Samples"], ) 

193 df.index.name = 'stat. ID' 

194 return df 

195 

196 def split_train_val_test(self) -> None: 

197 """ 

198 Split data into subsets. 

199 

200 Currently: train, val, test and train_val (actually this is only the merge of train and val, but as an separate 

201 data_collection). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs 

202 always to be executed at first, to set a proper transformation. 

203 """ 

204 fraction_of_training = self.data_store.get("fraction_of_training") 

205 stations = self.data_store.get("stations") 

206 train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), 

207 fraction_of_training) 

208 subset_names = ["train", "val", "test", "train_val"] 

209 if subset_names[0] != "train": # pragma: no cover 

210 raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset" 

211 f"order was: {subset_names}.") 

212 for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names): 

213 self.create_set_split(ind, scope) 

214 

215 @staticmethod 

216 def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]: 

217 """ 

218 Create the training, validation and test subset slice indices for given total_length. 

219 

220 The test data consists on (1-fraction) of total_length (fraction*len:end). Train and validation data therefore 

221 are made from fraction of total_length (0:fraction*len). Train and validation data is split by the factor 0.8 

222 for train and 0.2 for validation. In addition, split_set_indices returns also the combination of training and 

223 validation subset. 

224 

225 :param total_length: list with all objects to split 

226 :param fraction: ratio between test and union of train/val data 

227 

228 :return: slices for each subset in the order: train, val, test, train_val 

229 """ 

230 pos_test_split = int(total_length * fraction) 

231 train_index = slice(0, int(pos_test_split * 0.8)) 

232 val_index = slice(int(pos_test_split * 0.8), pos_test_split) 

233 test_index = slice(pos_test_split, total_length) 

234 train_val_index = slice(0, pos_test_split) 

235 return train_index, val_index, test_index, train_val_index 

236 

237 def create_set_split(self, index_list: slice, set_name: str) -> None: 

238 # get set stations 

239 stations = self.data_store.get("stations", scope=set_name) 

240 if self.data_store.get("use_all_stations_on_all_data_sets"): 

241 set_stations = stations 

242 else: 

243 set_stations = stations[index_list] 

244 logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") 

245 # create set data_collection and store 

246 data_handler = self.data_store.get("data_handler") 

247 collection, valid_stations = self.validate_station(data_handler, set_stations, set_name) 

248 self.data_store.set("stations", valid_stations, scope=set_name) 

249 self.data_store.set("data_collection", collection, scope=set_name) 

250 

251 def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None, 

252 store_processed_data=True): 

253 """ 

254 Check if all given stations in `all_stations` are valid. 

255 

256 Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the 

257 loading time are logged in debug mode. 

258 

259 :return: Corrected list containing only valid station IDs. 

260 """ 

261 t_outer = TimeTracking() 

262 logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") 

263 # calculate transformation using train data 

264 if set_name == "train": 

265 logging.info("setup transformation using train data exclusively") 

266 self.transformation(data_handler, set_stations) 

267 # start station check 

268 collection = DataCollection(name=set_name) 

269 valid_stations = [] 

270 kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope=set_name) 

271 use_multiprocessing = self.data_store.get("use_multiprocessing") 

272 tmp_path = self.data_store.get("tmp_path") 

273 

274 max_process = self.data_store.get("max_number_multiprocessing") 

275 n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus 

276 if n_process > 1 and use_multiprocessing is True: # parallel solution 276 ↛ 277line 276 didn't jump to line 277, because the condition on line 276 was never true

277 logging.info("use parallel validate station approach") 

278 pool = multiprocessing.Pool(n_process) 

279 logging.info(f"running {getattr(pool, '_processes')} processes in parallel") 

280 kwargs.update({"tmp_path": tmp_path, "return_strategy": "reference"}) 

281 output = [ 

282 pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs) 

283 for station in set_stations] 

284 for i, p in enumerate(output): 

285 _res_file, s = p.get() 

286 logging.info(f"...finished: {s} ({int((i + 1.) / len(output) * 100)}%)") 

287 with open(_res_file, "rb") as f: 

288 dh = dill.load(f) 

289 os.remove(_res_file) 

290 if dh is not None: 

291 collection.add(dh) 

292 valid_stations.append(s) 

293 pool.close() 

294 pool.join() 

295 else: # serial solution 

296 logging.info("use serial validate station approach") 

297 kwargs.update({"return_strategy": "result"}) 

298 for i, station in enumerate(set_stations): 

299 dh, s = f_proc(data_handler, station, set_name, store_processed_data, **kwargs) 

300 if dh is not None: 

301 collection.add(dh) 

302 valid_stations.append(s) 

303 logging.info(f"...finished: {s} ({int((i + 1.) / len(set_stations) * 100)}%)") 

304 

305 logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/" 

306 f"{len(set_stations)} valid stations ({set_name}).") 

307 if set_name == "train": 

308 self.store_data_handler_attributes(data_handler, collection) 

309 return collection, valid_stations 

310 

311 def store_data_handler_attributes(self, data_handler, collection): 

312 store_attributes = data_handler.store_attributes() 

313 if len(store_attributes) > 0: 313 ↛ 314line 313 didn't jump to line 314, because the condition on line 313 was never true

314 logging.info(f"store following parameters ({len(store_attributes)}) requested by the data handler: " 

315 f"{','.join(store_attributes)}") 

316 attrs = {} 

317 for dh in collection: 

318 station = str(dh) 

319 for k, v in dh.get_store_attributes().items(): 

320 attrs[k] = dict(attrs.get(k, {}), **{station: v}) 

321 for k, v in attrs.items(): 

322 self.data_store.set(k, v) 

323 self._store_apriori() 

324 

325 def _store_apriori(self): 

326 apriori = self.data_store.get_default("apriori", default=None) 

327 if apriori: 327 ↛ 328line 327 didn't jump to line 328, because the condition on line 327 was never true

328 experiment_path = self.data_store.get("experiment_path") 

329 path = os.path.join(experiment_path, "data", "apriori") 

330 store_file = os.path.join(path, "apriori.pickle") 

331 if not os.path.exists(path): 

332 path_config.check_path_and_create(path) 

333 with open(store_file, "wb") as f: 

334 dill.dump(apriori, f, protocol=4) 

335 logging.debug(f"Store apriori options locally for later use at: {store_file}") 

336 

337 def _load_apriori(self): 

338 if self.data_store.get_default("apriori", default=None) is None: 338 ↛ exitline 338 didn't return from function '_load_apriori', because the condition on line 338 was never false

339 apriori_file = self.data_store.get_default("apriori_file", None) 

340 if apriori_file is not None: 340 ↛ 341line 340 didn't jump to line 341, because the condition on line 340 was never true

341 if os.path.exists(apriori_file): 

342 logging.info(f"use apriori data from given file: {apriori_file}") 

343 with open(apriori_file, "rb") as pickle_file: 

344 self.data_store.set("apriori", dill.load(pickle_file)) 

345 else: 

346 logging.info(f"cannot load apriori file: {apriori_file}. Use fresh calculation from data.") 

347 

348 def transformation(self, data_handler: AbstractDataHandler, stations): 

349 calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True) 

350 if hasattr(data_handler, "transformation"): 

351 transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation() 

352 if transformation_opts is None: 352 ↛ 358line 352 didn't jump to line 358, because the condition on line 352 was never false

353 logging.info(f"start to calculate transformation parameters.") 

354 kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope="train") 

355 tmp_path = self.data_store.get_default("tmp_path", default=None) 

356 transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) 

357 else: 

358 logging.info("In case no valid train data could be found due to problems with transformation, please " 

359 "check your provided transformation file for compability with your data.") 

360 self.data_store.set("transformation", transformation_opts) 

361 if transformation_opts is not None: 

362 self._store_transformation(transformation_opts) 

363 

364 def _load_transformation(self): 

365 """Try to load transformation options from file if transformation_file is provided.""" 

366 transformation_file = self.data_store.get_default("transformation_file", None) 

367 if transformation_file is not None: 

368 if os.path.exists(transformation_file): 

369 logging.info(f"use transformation from given transformation file: {transformation_file}") 

370 with open(transformation_file, "rb") as pickle_file: 

371 return dill.load(pickle_file) 

372 else: 

373 logging.info(f"cannot load transformation file: {transformation_file}. Use fresh calculation of " 

374 f"transformation from train data.") 

375 

376 def _store_transformation(self, transformation_opts): 

377 """Store transformation options locally inside experiment_path if not exists already.""" 

378 experiment_path = self.data_store.get("experiment_path") 

379 transformation_path = os.path.join(experiment_path, "data", "transformation") 

380 transformation_file = os.path.join(transformation_path, "transformation.pickle") 

381 calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True) 

382 if not os.path.exists(transformation_file) or calculate_fresh_transformation: 382 ↛ exitline 382 didn't return from function '_store_transformation', because the condition on line 382 was never false

383 path_config.check_path_and_create(transformation_path) 

384 with open(transformation_file, "wb") as f: 

385 dill.dump(transformation_opts, f, protocol=4) 

386 logging.info(f"Store transformation options locally for later use at: {transformation_file}") 

387 

388 def prepare_competitors(self): 

389 """ 

390 Prepare competitor models already in the preprocessing stage. This is performed here, because some models might 

391 need to have internet access, which is depending on the operating system not possible during postprocessing. 

392 This method checks currently only, if the Intelli03-ts-v1 model is requested as competitor and downloads the 

393 data if required. 

394 """ 

395 logging.info("Searching for competitors to be prepared for use.") 

396 competitors = to_list(self.data_store.get_default("competitors", default=[])) 

397 if len(competitors) > 0: 397 ↛ 428line 397 didn't jump to line 428, because the condition on line 397 was never false

398 for competitor_name in competitors: 

399 if competitor_name.lower() == "IntelliO3-ts-v1".lower(): 399 ↛ 400line 399 didn't jump to line 400, because the condition on line 399 was never true

400 logging.info("Prepare IntelliO3-ts-v1 model") 

401 from mlair.reference_models.reference_model_intellio3_v1 import IntelliO3_ts_v1 

402 path = os.path.join(self.data_store.get("competitor_path"), competitor_name) 

403 IntelliO3_ts_v1("IntelliO3-ts-v1", ref_store_path=path).make_reference_available_locally(remove_tmp_dir=False) 

404 elif competitor_name.lower() == "CAMS".lower(): 404 ↛ 405line 404 didn't jump to line 405, because the condition on line 404 was never true

405 logging.info("Prepare CAMS forecasts") 

406 from mlair.reference_models.reference_model_cams import CAMSforecast 

407 interp_method = self.data_store.get_default("cams_interp_method", default=None) 

408 data_path = self.data_store.get_default("cams_data_path", default=None) 

409 path = os.path.join(self.data_store.get("competitor_path"), competitor_name) 

410 stations = {} 

411 for subset in ["train", "val", "test"]: 

412 data_collection = self.data_store.get("data_collection", subset) 

413 stations.update({str(s): s.get_coordinates() for s in data_collection if s not in stations}) 

414 if interp_method is None: 

415 CAMSforecast("CAMS", ref_store_path=path, data_path=data_path, interp_method=None 

416 ).make_reference_available_locally(stations) 

417 else: 

418 competitors = remove_items(competitors, "CAMS") 

419 for method in to_list(interp_method): 

420 CAMSforecast(f"CAMS{method}", ref_store_path=path + method, data_path=data_path, 

421 interp_method=method).make_reference_available_locally(stations) 

422 competitors.append(f"CAMS{method}") 

423 self.data_store.set("competitors", competitors) 

424 else: 

425 logging.info(f"No preparation required for competitor {competitor_name} as no specific instruction " 

426 f"is provided.") 

427 else: 

428 logging.info("No preparation required because no competitor was provided to the workflow.") 

429 

430 def create_snapshot(self): 

431 logging.info("create snapshot for preprocessing") 

432 from mlair.configuration.snapshot_names import animals 

433 for i_try in range(10): 

434 snapshot_name = random.choice(animals).lower() 

435 snapshot_path = os.path.abspath(self.data_store.get("snapshot_path")) 

436 path_config.check_path_and_create(snapshot_path, remove_existing=False) 

437 _snapshot_file = os.path.join(snapshot_path, f"snapshot_preprocessing_{snapshot_name}.pickle") 

438 if not os.path.exists(_snapshot_file): 

439 logging.info(f"store snapshot at: {_snapshot_file}") 

440 with open(_snapshot_file, "wb") as f: 

441 dill.dump(self.data_store, f, protocol=4) 

442 print(_snapshot_file) 

443 return 

444 logging.info(f"Could not create snapshot at {_snapshot_file} as file is already existing ({i_try + 1}/10)") 

445 logging.info(f"Could not create any snapshot after 10/10 tries.") 

446 

447 def load_snapshot(self, file): 

448 logging.info(f"load snapshot for preprocessing from {file}") 

449 with open(file, "rb") as f: 

450 snapshot = dill.load(f) 

451 excluded_params = ["activation", "activation_output", "add_dense_layer", "batch_normalization", "batch_path", 

452 "batch_size", "block_length", "bootstrap_method", "bootstrap_path", "bootstrap_type", 

453 "competitor_path", "competitors", "create_new_bootstraps", "create_new_model", 

454 "create_snapshot", "data_collection", "debug_mode", "dense_layer_configuration", 

455 "do_uncertainty_estimate", "dropout", "dropout_rnn", "early_stopping_epochs", "epochs", 

456 "evaluate_competitors", "evaluate_feature_importance", "experiment_name", "experiment_path", 

457 "exponent_last_layer", "forecast_path", "fraction_of_training", "hostname", "hpc_hosts", 

458 "kernel_regularizer", "kernel_size", "layer_configuration", "log_level_stream", 

459 "logging_path", "login_nodes", "loss_type", "loss_weights", "max_number_multiprocessing", 

460 "model_class", "model_display_name", "model_path", "n_boots", "n_hidden", "n_layer", 

461 "neighbors", "plot_list", "plot_path", "regularizer", "restore_best_model_weights", 

462 "snapshot_load_path", "snapshot_path", "stations", "tmp_path", "train_model", 

463 "transformation", "use_multiprocessing", "cams_data_path", "cams_interp_method", 

464 "do_bias_free_evaluation", "apriori_file", "model_path", "model_load_path", "era5_data_path", 

465 "era5_file_names", "ifs_data_path", "ifs_file_names"] 

466 data_handler = self.data_store.get("data_handler") 

467 model_class = self.data_store.get("model_class") 

468 excluded_params = list(set(excluded_params + data_handler.store_attributes() + model_class.requirements())) 

469 

470 if check_nested_equality(self.data_store._store, snapshot._store, skip_args=excluded_params) is True: 

471 self.update_datastore(snapshot, excluded_params=remove_items(excluded_params, ["transformation", 

472 "data_collection", 

473 "stations"])) 

474 else: 

475 raise ReferenceError("provided snapshot does not match with the current experiment setup. Abort this run!") 

476 

477 

478def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_path=None, **kwargs): 

479 """ 

480 Try to create a data handler for given arguments. If build fails, this station does not fulfil all requirements and 

481 therefore f_proc will return None as indication. On a successful build, f_proc returns the built data handler and 

482 the station that was used. This function must be implemented globally to work together with multiprocessing. 

483 """ 

484 assert return_strategy in ["result", "reference"] 

485 try: 

486 res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, **kwargs) 

487 except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError) as e: 

488 formatted_lines = traceback.format_exc().splitlines() 

489 logging.info(f"remove station {station} because it raised an error: {e} -> " 

490 f"{' | '.join(f_inspect_error(formatted_lines))}") 

491 logging.debug(f"detailed information for removal of station {station}: {traceback.format_exc()}") 

492 res = None 

493 if return_strategy == "result": 493 ↛ 496line 493 didn't jump to line 496, because the condition on line 493 was never false

494 return res, station 

495 else: 

496 if tmp_path is None: 

497 tmp_path = os.getcwd() 

498 _tmp_file = os.path.join(tmp_path, f"{station}_{'%032x' % random.getrandbits(128)}.pickle") 

499 with open(_tmp_file, "wb") as f: 

500 dill.dump(res, f, protocol=4) 

501 return _tmp_file, station 

502 

503 

504def f_proc_create_info_df(data, meta_cols): 

505 station_name = str(data.id_class) 

506 meta = data.id_class.meta 

507 res = {"station_name": station_name, "Y_shape": data.get_Y()[0].shape[0], 

508 "meta": meta.reindex(meta_cols).values.flatten()} 

509 return res 

510 

511 

512def f_inspect_error(formatted): 

513 for i in range(len(formatted) - 1, -1, -1): 513 ↛ 516line 513 didn't jump to line 516, because the loop on line 513 didn't complete

514 if "mlair/mlair" not in formatted[i]: 514 ↛ 513line 514 didn't jump to line 513, because the condition on line 514 was never false

515 return formatted[i - 3:i] 

516 return formatted[-3:0]