Coverage for mlair/run_modules/pre_processing.py: 59%

307 statements  

coverage.py v6.4.2, created at 2023-06-01 13:03 +0000

1"""Pre-processing module.""" 

2 

3__author__ = "Lukas Leufen, Felix Kleinert" 

4__date__ = '2019-11-25' 

5 

6import logging 

7import os 

8import traceback 

9from typing import Tuple 

10import multiprocessing 

11import requests 

12import psutil 

13import random 

14import dill 

15 

16import pandas as pd 

17 

18from mlair.data_handler import DataCollection, AbstractDataHandler 

19from mlair.helpers import TimeTracking, to_list, tables, remove_items 

20from mlair.configuration import path_config 

21from mlair.helpers.data_sources.data_loader import EmptyQueryResult 

22from mlair.helpers.testing import check_nested_equality 

23from mlair.run_modules.run_environment import RunEnvironment 

24 

25 

26class PreProcessing(RunEnvironment): 

27 """ 

28 Pre-process your data by using this class. 

29 

30 Schedule of pre-processing: 

31 #. load and check valid stations (either download or load from disk) 

32 #. split subsets (train, val, test, train & val) 

33 #. create small report on data metrics 

34 

35 Required objects [scope] from data store: 

36 * all elements from `DEFAULT_ARGS_LIST` in scope preprocessing for general data loading 

37 * all elements from `DEFAULT_ARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings 

38 * `fraction_of_training` [.] 

39 * `experiment_path` [.] 

40 * `use_all_stations_on_all_data_sets` [.] 

41 

42 Optional objects 

43 * all elements from `DEFAULT_KWARGS_LIST` in scope preprocessing for general data loading 

44 * all elements from `DEFAULT_KWARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings 

45 

46 Sets 

47 * `stations` in [., train, val, test, train_val] 

48 * `generator` in [train, val, test, train_val] 

49 * `transformation` [.] 

50 

51 Creates 

52 * all input and output data in `data_path` 

53 * latex reports in `experiment_path/latex_report` 

54 

55 """ 

56 
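# Illustrative usage sketch (not part of this module): in an MLAir workflow, PreProcessing is
# instantiated after the experiment setup stage and runs immediately from __init__. The import
# path and keyword arguments of ExperimentSetup are assumptions borrowed from the default MLAir
# workflow; only RunEnvironment and PreProcessing are actually used by this file.
#
#     from mlair.run_modules.run_environment import RunEnvironment
#     from mlair.run_modules.experiment_setup import ExperimentSetup  # assumed import path
#     from mlair.run_modules.pre_processing import PreProcessing
#
#     with RunEnvironment():
#         ExperimentSetup(stations=["DEBW107", "DEBW013"])  # fills the data store
#         PreProcessing()  # loads/validates stations and creates the subset splits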

57 def __init__(self): 

58 """Set up and run pre-processing.""" 

59 super().__init__() 

60 self._run() 

61 

62 def _run(self): 

63 snapshot_load_path = self.data_store.get_default("snapshot_load_path", default=None) 

64 if snapshot_load_path is None:  # 64 ↛ 75: line 64 didn't jump to line 75, because the condition on line 64 was never false

65 stations = self.data_store.get("stations") 

66 data_handler = self.data_store.get("data_handler") 

67 self._load_apriori() 

68 _, valid_stations = self.validate_station(data_handler, stations, 

69 "preprocessing") # , store_processed_data=False) 

70 if len(valid_stations) == 0:  # 70 ↛ 71: line 70 didn't jump to line 71, because the condition on line 70 was never true

71 raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.") 

72 self.data_store.set("stations", valid_stations) 

73 self.split_train_val_test() 

74 else: 

75 self.load_snapshot(snapshot_load_path) 

76 self.report_pre_processing() 

77 self.prepare_competitors() 

78 if self.data_store.get_default("create_snapshot", False) is True:  # 78 ↛ 79: line 78 didn't jump to line 79, because the condition on line 78 was never true

79 self.create_snapshot() 

80 

81 def report_pre_processing(self): 

82 """Log some metrics on data and create latex report.""" 

83 logging.debug(20 * '##') 

84 n_train = len(self.data_store.get('data_collection', 'train')) 

85 n_val = len(self.data_store.get('data_collection', 'val')) 

86 n_test = len(self.data_store.get('data_collection', 'test')) 

87 n_total = n_train + n_val + n_test 

88 logging.debug(f"Number of all stations: {n_total}") 

89 logging.debug(f"Number of training stations: {n_train}") 

90 logging.debug(f"Number of val stations: {n_val}") 

91 logging.debug(f"Number of test stations: {n_test}") 

92 self.create_latex_report() 

93 

94 def create_latex_report(self): 

95 """ 

96 Create tables with information on the station metadata and a summary of subset sample sizes. 

97 

98 * station_sample_size.md: see table below as markdown 

99 * station_sample_size.tex: same as table below as latex table 

100 * station_sample_size_short.tex: reduced size table without any meta data besides station ID, as latex table 

101 

102 All tables are stored inside experiment_path in the folder latex_report. The table format (e.g. which meta 

103 data is highlighted) is currently hardcoded to keep a stable table style. If further styles are needed, it is 

104 better to add an additional style than to modify the existing table styles. 

105 

106 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

107 | stat. ID | station_name | station_lon | station_lat | station_alt | train | val | test | 

108 +============+===========================================+===============+===============+===============+=========+=======+========+ 

109 | DEBW013 | Stuttgart Bad Cannstatt | 9.2297 | 48.8088 | 235 | 1434 | 712 | 1080 | 

110 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

111 | DEBW076 | Baden-Baden | 8.2202 | 48.7731 | 148 | 3037 | 722 | 710 | 

112 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

113 | DEBW087 | Schwäbische_Alb | 9.2076 | 48.3458 | 798 | 3044 | 714 | 1087 | 

114 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

115 | DEBW107 | Tübingen | 9.0512 | 48.5077 | 325 | 1803 | 715 | 1087 | 

116 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

117 | DEBY081 | Garmisch-Partenkirchen/Kreuzeckbahnstraße | 11.0631 | 47.4764 | 735 | 2935 | 525 | 714 | 

118 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

119 | # Stations | nan | nan | nan | nan | 6 | 6 | 6 | 

120 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

121 | # Samples | nan | nan | nan | nan | 12253 | 3388 | 4678 | 

122 +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ 

123 

124 """ 

125 meta_cols = ["name", "lat", "lon", "alt", "country", "state", "type", "type_of_area", "toar1_category"] 

126 meta_round = ["lat", "lon", "alt"] 

127 precision = 4 

128 path = os.path.join(self.data_store.get("experiment_path"), "latex_report") 

129 path_config.check_path_and_create(path) 

130 names_of_set = ["train", "val", "test"] 

131 df = self.create_info_df(meta_cols, meta_round, names_of_set, precision) 

132 column_format = tables.create_column_format_for_tex(df) 

133 tables.save_to_tex(path=path, filename="station_sample_size.tex", column_format=column_format, df=df) 

134 tables.save_to_md(path=path, filename="station_sample_size.md", df=df) 

135 df_nometa = df.drop(meta_cols, axis=1) 

136 column_format = tables.create_column_format_for_tex(df_nometa) 

137 tables.save_to_tex(path=path, filename="station_sample_size_short.tex", column_format=column_format, 

138 df=df_nometa) 

139 tables.save_to_md(path=path, filename="station_sample_size_short.md", df=df_nometa) 

140 df_descr = self.create_describe_df(df_nometa) 

141 column_format = tables.create_column_format_for_tex(df_descr) 

142 tables.save_to_tex(path=path, filename="station_describe_short.tex", column_format=column_format, df=df_descr) 

143 tables.save_to_md(path=path, filename="station_describe_short.md", df=df_descr) 

144 

145 @staticmethod 

146 def create_describe_df(df, percentiles=None, ignore_last_lines: int = 2): 

147 if percentiles is None:  # 147 ↛ 149: line 147 didn't jump to line 149, because the condition on line 147 was never false

148 percentiles = [.05, .1, .25, .5, .75, .9, .95] 

149 df_descr = df.iloc[:-ignore_last_lines].astype('float32').describe( 

150 percentiles=percentiles).astype("int32", errors="ignore") 

151 df_descr = pd.concat([df.loc[['# Samples']], df_descr]).T 

152 df_descr.rename(columns={"# Samples": "no. samples", "count": "no. stations"}, inplace=True) 

153 df_descr_colnames = list(df_descr.columns) 

154 df_descr_colnames = [df_descr_colnames[1]] + [df_descr_colnames[0]] + df_descr_colnames[2:] 

155 df_descr = df_descr[df_descr_colnames] 

156 return df_descr 

157 
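# Minimal sketch of create_describe_df above, using made-up sample counts in the same layout
# that create_info_df produces (station rows plus the "# Stations"/"# Samples" footer rows,
# which are excluded from describe() via ignore_last_lines=2).
import pandas as pd

toy = pd.DataFrame({"train": [1434, 3037], "val": [712, 722], "test": [1080, 710]},
                   index=["DEBW013", "DEBW076"])
toy.loc["# Stations"] = [2, 2, 2]
toy.loc["# Samples"] = toy.loc[["DEBW013", "DEBW076"]].sum()
described = PreProcessing.create_describe_df(toy)
# -> one row per subset (train/val/test) with "no. stations", "no. samples", mean, std, min,
#    the requested percentiles and max over the per-station sample counts.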

158 def create_info_df(self, meta_cols, meta_round, names_of_set, precision): 

159 use_multiprocessing = self.data_store.get("use_multiprocessing") 

160 max_process = self.data_store.get("max_number_multiprocessing") 

161 df = pd.DataFrame(columns=meta_cols + names_of_set) 

162 for set_name in names_of_set: 

163 data = self.data_store.get("data_collection", set_name) 

164 n_process = min([psutil.cpu_count(logical=False), len(data), max_process]) # use only physical cpus 

165 if n_process > 1 and use_multiprocessing is True:  # parallel solution; 165 ↛ 166: line 165 didn't jump to line 166, because the condition on line 165 was never true

166 logging.info(f"use parallel create_info_df ({set_name})") 

167 pool = multiprocessing.Pool(n_process) 

168 logging.info(f"running {getattr(pool, '_processes')} processes in parallel") 

169 output = [pool.apply_async(f_proc_create_info_df, args=(station, meta_cols)) for station in data] 

170 for i, p in enumerate(output): 

171 res = p.get() 

172 station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"] 

173 df.loc[station_name, set_name] = shape 

174 if df.loc[station_name, meta_cols].isnull().any(): 

175 df.loc[station_name, meta_cols] = meta 

176 logging.info(f"...finished: {station_name} ({int((i + 1.) / len(output) * 100)}%)") 

177 pool.close() 

178 pool.join() 

179 else: # serial solution 

180 logging.info(f"use serial create_info_df ({set_name})") 

181 for station in data: 

182 res = f_proc_create_info_df(station, meta_cols) 

183 station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"] 

184 df.loc[station_name, set_name] = shape 

185 if df.loc[station_name, meta_cols].isnull().any(): 

186 df.loc[station_name, meta_cols] = meta 

187 df.loc["# Samples", set_name] = df.loc[:, set_name].sum() 

188 assert len(data) == df.loc[:, set_name].count() - 1 

189 df.loc["# Stations", set_name] = len(data) 

190 df[meta_round] = df[meta_round].astype(float).round(precision) 

191 df.sort_index(inplace=True) 

192 df = df.reindex(df.index.drop(["# Stations", "# Samples"]).to_list() + ["# Stations", "# Samples"], ) 

193 df.index.name = 'stat. ID' 

194 return df 

195 

196 def split_train_val_test(self) -> None: 

197 """ 

198 Split data into subsets. 

199 

200 Currently: train, val, test and train_val (actually this is only the merge of train and val, but kept as a 

201 separate data_collection). IMPORTANT: Do not change the order of execution of create_set_split. The train 

202 subset must always be processed first to set a proper transformation. 

203 """ 

204 fraction_of_training = self.data_store.get("fraction_of_training") 

205 stations = self.data_store.get("stations") 

206 train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), 

207 fraction_of_training) 

208 subset_names = ["train", "val", "test", "train_val"] 

209 if subset_names[0] != "train": # pragma: no cover 

210 raise AssertionError(f"Make sure that the train subset is always at the first execution position! Given " 

211 f"subset order was: {subset_names}.") 

212 for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names): 

213 self.create_set_split(ind, scope) 

214 

215 @staticmethod 

216 def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]: 

217 """ 

218 Create the training, validation and test subset slice indices for given total_length. 

219 

220 The test data consists of (1-fraction) of total_length (fraction*len:end). Train and validation data are 

221 therefore taken from the first fraction of total_length (0:fraction*len). This part is split further by the 

222 factor 0.8 for train and 0.2 for validation. In addition, split_set_indices also returns the combination of 

223 the training and validation subsets. 

224 

225 :param total_length: total number of objects to split 

226 :param fraction: ratio between test and union of train/val data 

227 

228 :return: slices for each subset in the order: train, val, test, train_val 

229 """ 

230 pos_test_split = int(total_length * fraction) 

231 train_index = slice(0, int(pos_test_split * 0.8)) 

232 val_index = slice(int(pos_test_split * 0.8), pos_test_split) 

233 test_index = slice(pos_test_split, total_length) 

234 train_val_index = slice(0, pos_test_split) 

235 return train_index, val_index, test_index, train_val_index 

236 
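# Worked example for split_set_indices above (the numbers are chosen for illustration): with
# 100 stations and fraction_of_training = 0.8, the first 80 stations form train_val, of which
# 80% go to train and 20% to val; the remaining 20 stations form the test subset.
train, val, test, train_val = PreProcessing.split_set_indices(total_length=100, fraction=0.8)
# train     -> slice(0, 64)
# val       -> slice(64, 80)
# test      -> slice(80, 100)
# train_val -> slice(0, 80)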

237 def create_set_split(self, index_list: slice, set_name: str) -> None: 

238 # get set stations 

239 stations = self.data_store.get("stations", scope=set_name) 

240 if self.data_store.get("use_all_stations_on_all_data_sets"): 

241 set_stations = stations 

242 else: 

243 set_stations = stations[index_list] 

244 logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") 

245 # create set data_collection and store 

246 data_handler = self.data_store.get("data_handler") 

247 collection, valid_stations = self.validate_station(data_handler, set_stations, set_name) 

248 self.data_store.set("stations", valid_stations, scope=set_name) 

249 self.data_store.set("data_collection", collection, scope=set_name) 

250 

251 def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None, 

252 store_processed_data=True): 

253 """ 

254 Check if all given stations in `set_stations` are valid. 

255 

256 Valid means that there is data available for the given time range (which is included in `kwargs`). The shape 

257 and the loading time are logged in debug mode. 

258 

259 :return: data collection and corrected list containing only valid station IDs. 

260 """ 

261 t_outer = TimeTracking() 

262 logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") 

263 # calculate transformation using train data 

264 if set_name == "train": 

265 logging.info("setup transformation using train data exclusively") 

266 self.transformation(data_handler, set_stations) 

267 # start station check 

268 collection = DataCollection(name=set_name) 

269 valid_stations = [] 

270 kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope=set_name) 

271 use_multiprocessing = self.data_store.get("use_multiprocessing") 

272 tmp_path = self.data_store.get("tmp_path") 

273 

274 max_process = self.data_store.get("max_number_multiprocessing") 

275 n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus 

276 if n_process > 1 and use_multiprocessing is True:  # parallel solution; 276 ↛ 277: line 276 didn't jump to line 277, because the condition on line 276 was never true

277 logging.info("use parallel validate station approach") 

278 pool = multiprocessing.Pool(n_process) 

279 logging.info(f"running {getattr(pool, '_processes')} processes in parallel") 

280 kwargs.update({"tmp_path": tmp_path, "return_strategy": "reference"}) 

281 output = [ 

282 pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs) 

283 for station in set_stations] 

284 for i, p in enumerate(output): 

285 _res_file, s = p.get() 

286 logging.info(f"...finished: {s} ({int((i + 1.) / len(output) * 100)}%)") 

287 with open(_res_file, "rb") as f: 

288 dh = dill.load(f) 

289 os.remove(_res_file) 

290 if dh is not None: 

291 collection.add(dh) 

292 valid_stations.append(s) 

293 pool.close() 

294 pool.join() 

295 else: # serial solution 

296 logging.info("use serial validate station approach") 

297 kwargs.update({"return_strategy": "result"}) 

298 for station in set_stations: 

299 dh, s = f_proc(data_handler, station, set_name, store_processed_data, **kwargs) 

300 if dh is not None: 

301 collection.add(dh) 

302 valid_stations.append(s) 

303 

304 logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/" 

305 f"{len(set_stations)} valid stations ({set_name}).") 

306 if set_name == "train": 

307 self.store_data_handler_attributes(data_handler, collection) 

308 return collection, valid_stations 

309 
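# Minimal sketch of the parallel/serial switch used in validate_station and create_info_df
# above; the concrete values are made up. Only physical cores are counted, and the number of
# worker processes is additionally capped by the number of stations and by the data store
# setting "max_number_multiprocessing"; with use_multiprocessing=False the serial loop is used.
import psutil

use_multiprocessing = True                        # data store setting "use_multiprocessing"
max_process = 16                                  # data store setting "max_number_multiprocessing"
set_stations = ["DEBW013", "DEBW076", "DEBW087"]
n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])
run_parallel = n_process > 1 and use_multiprocessing is True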

310 def store_data_handler_attributes(self, data_handler, collection): 

311 store_attributes = data_handler.store_attributes() 

312 if len(store_attributes) > 0:  # 312 ↛ 313: line 312 didn't jump to line 313, because the condition on line 312 was never true

313 logging.info(f"store following parameters ({len(store_attributes)}) requested by the data handler: " 

314 f"{','.join(store_attributes)}") 

315 attrs = {} 

316 for dh in collection: 

317 station = str(dh) 

318 for k, v in dh.get_store_attributes().items(): 

319 attrs[k] = dict(attrs.get(k, {}), **{station: v}) 

320 for k, v in attrs.items(): 

321 self.data_store.set(k, v) 

322 self._store_apriori() 

323 

324 def _store_apriori(self): 

325 apriori = self.data_store.get_default("apriori", default=None) 

326 if apriori:  # 326 ↛ 327: line 326 didn't jump to line 327, because the condition on line 326 was never true

327 experiment_path = self.data_store.get("experiment_path") 

328 path = os.path.join(experiment_path, "data", "apriori") 

329 store_file = os.path.join(path, "apriori.pickle") 

330 if not os.path.exists(path): 

331 path_config.check_path_and_create(path) 

332 with open(store_file, "wb") as f: 

333 dill.dump(apriori, f, protocol=4) 

334 logging.debug(f"Store apriori options locally for later use at: {store_file}") 

335 

336 def _load_apriori(self): 

337 if self.data_store.get_default("apriori", default=None) is None:  # 337 ↛ exit: line 337 didn't return from function '_load_apriori', because the condition on line 337 was never false

338 apriori_file = self.data_store.get_default("apriori_file", None) 

339 if apriori_file is not None:  # 339 ↛ 340: line 339 didn't jump to line 340, because the condition on line 339 was never true

340 if os.path.exists(apriori_file): 

341 logging.info(f"use apriori data from given file: {apriori_file}") 

342 with open(apriori_file, "rb") as pickle_file: 

343 self.data_store.set("apriori", dill.load(pickle_file)) 

344 else: 

345 logging.info(f"cannot load apriori file: {apriori_file}. Use fresh calculation from data.") 

346 

347 def transformation(self, data_handler: AbstractDataHandler, stations): 

348 calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True) 

349 if hasattr(data_handler, "transformation"): 

350 transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation() 

351 if transformation_opts is None:  # 351 ↛ 357: line 351 didn't jump to line 357, because the condition on line 351 was never false

352 logging.info(f"start to calculate transformation parameters.") 

353 kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope="train") 

354 tmp_path = self.data_store.get_default("tmp_path", default=None) 

355 transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) 

356 else: 

357 logging.info("In case no valid train data could be found due to problems with transformation, please " 

358 "check your provided transformation file for compability with your data.") 

359 self.data_store.set("transformation", transformation_opts) 

360 if transformation_opts is not None: 

361 self._store_transformation(transformation_opts) 

362 

363 def _load_transformation(self): 

364 """Try to load transformation options from file if transformation_file is provided.""" 

365 transformation_file = self.data_store.get_default("transformation_file", None) 

366 if transformation_file is not None: 

367 if os.path.exists(transformation_file): 

368 logging.info(f"use transformation from given transformation file: {transformation_file}") 

369 with open(transformation_file, "rb") as pickle_file: 

370 return dill.load(pickle_file) 

371 else: 

372 logging.info(f"cannot load transformation file: {transformation_file}. Use fresh calculation of " 

373 f"transformation from train data.") 

374 

375 def _store_transformation(self, transformation_opts): 

376 """Store transformation options locally inside experiment_path if not exists already.""" 

377 experiment_path = self.data_store.get("experiment_path") 

378 transformation_path = os.path.join(experiment_path, "data", "transformation") 

379 transformation_file = os.path.join(transformation_path, "transformation.pickle") 

380 calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True) 

381 if not os.path.exists(transformation_file) or calculate_fresh_transformation:  # 381 ↛ exit: line 381 didn't return from function '_store_transformation', because the condition on line 381 was never false

382 path_config.check_path_and_create(transformation_path) 

383 with open(transformation_file, "wb") as f: 

384 dill.dump(transformation_opts, f, protocol=4) 

385 logging.info(f"Store transformation options locally for later use at: {transformation_file}") 

386 
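# Minimal sketch of the transformation round trip implemented in _store_transformation and
# _load_transformation above. The options dict and the target directory are stand-ins only; in
# a real experiment the pickle is written to <experiment_path>/data/transformation/ and reused
# by pointing the "transformation_file" parameter at it (with calculate_fresh_transformation
# set to False).
import os
import tempfile
import dill

transformation_opts = {"o3": {"method": "standardise"}}          # stand-in, not real MLAir output
transformation_path = os.path.join(tempfile.gettempdir(), "transformation")
os.makedirs(transformation_path, exist_ok=True)
transformation_file = os.path.join(transformation_path, "transformation.pickle")
with open(transformation_file, "wb") as f:                       # mirrors _store_transformation
    dill.dump(transformation_opts, f, protocol=4)
with open(transformation_file, "rb") as f:                       # mirrors _load_transformation
    assert dill.load(f) == transformation_opts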

387 def prepare_competitors(self): 

388 """ 

389 Prepare competitor models already in the preprocessing stage. This is performed here because some models might 

390 need internet access, which, depending on the operating system, is not possible during postprocessing. 

391 Currently, this method only checks whether the IntelliO3-ts-v1 model or CAMS forecasts are requested as 

392 competitors and downloads the data if required. 

393 """ 

394 logging.info("Searching for competitors to be prepared for use.") 

395 competitors = to_list(self.data_store.get_default("competitors", default=[])) 

396 if len(competitors) > 0:  # 396 ↛ 427: line 396 didn't jump to line 427, because the condition on line 396 was never false

397 for competitor_name in competitors: 

398 if competitor_name.lower() == "IntelliO3-ts-v1".lower():  # 398 ↛ 399: line 398 didn't jump to line 399, because the condition on line 398 was never true

399 logging.info("Prepare IntelliO3-ts-v1 model") 

400 from mlair.reference_models.reference_model_intellio3_v1 import IntelliO3_ts_v1 

401 path = os.path.join(self.data_store.get("competitor_path"), competitor_name) 

402 IntelliO3_ts_v1("IntelliO3-ts-v1", ref_store_path=path).make_reference_available_locally(remove_tmp_dir=False) 

403 elif competitor_name.lower() == "CAMS".lower():  # 403 ↛ 404: line 403 didn't jump to line 404, because the condition on line 403 was never true

404 logging.info("Prepare CAMS forecasts") 

405 from mlair.reference_models.reference_model_cams import CAMSforecast 

406 interp_method = self.data_store.get_default("cams_interp_method", default=None) 

407 data_path = self.data_store.get_default("cams_data_path", default=None) 

408 path = os.path.join(self.data_store.get("competitor_path"), competitor_name) 

409 stations = {} 

410 for subset in ["train", "val", "test"]: 

411 data_collection = self.data_store.get("data_collection", subset) 

412 stations.update({str(s): s.get_coordinates() for s in data_collection if s not in stations}) 

413 if interp_method is None: 

414 CAMSforecast("CAMS", ref_store_path=path, data_path=data_path, interp_method=None 

415 ).make_reference_available_locally(stations) 

416 else: 

417 competitors = remove_items(competitors, "CAMS") 

418 for method in to_list(interp_method): 

419 CAMSforecast(f"CAMS{method}", ref_store_path=path + method, data_path=data_path, 

420 interp_method=method).make_reference_available_locally(stations) 

421 competitors.append(f"CAMS{method}") 

422 self.data_store.set("competitors", competitors) 

423 else: 

424 logging.info(f"No preparation required for competitor {competitor_name} as no specific instruction " 

425 f"is provided.") 

426 else: 

427 logging.info("No preparation required because no competitor was provided to the workflow.") 

428 
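# Sketch of how the competitor list is rewritten in prepare_competitors above when CAMS
# forecasts are requested with explicit interpolation methods. The method names "linear" and
# "nearest" are placeholders, not a statement about what CAMSforecast actually supports.
from mlair.helpers import to_list, remove_items

competitors = ["IntelliO3-ts-v1", "CAMS"]
interp_method = ["linear", "nearest"]
competitors = remove_items(competitors, "CAMS")
for method in to_list(interp_method):
    competitors.append(f"CAMS{method}")
# competitors -> ["IntelliO3-ts-v1", "CAMSlinear", "CAMSnearest"]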

429 def create_snapshot(self): 

430 logging.info("create snapshot for preprocessing") 

431 from mlair.configuration.snapshot_names import animals 

432 for i_try in range(10): 

433 snapshot_name = random.choice(animals).lower() 

434 snapshot_path = os.path.abspath(self.data_store.get("snapshot_path")) 

435 path_config.check_path_and_create(snapshot_path, remove_existing=False) 

436 _snapshot_file = os.path.join(snapshot_path, f"snapshot_preprocessing_{snapshot_name}.pickle") 

437 if not os.path.exists(_snapshot_file): 

438 logging.info(f"store snapshot at: {_snapshot_file}") 

439 with open(_snapshot_file, "wb") as f: 

440 dill.dump(self.data_store, f, protocol=4) 

441 print(_snapshot_file) 

442 return 

443 logging.info(f"Could not create snapshot at {_snapshot_file} as file is already existing ({i_try + 1}/10)") 

444 logging.info(f"Could not create any snapshot after 10/10 tries.") 

445 

446 def load_snapshot(self, file): 

447 logging.info(f"load snapshot for preprocessing from {file}") 

448 with open(file, "rb") as f: 

449 snapshot = dill.load(f) 

450 excluded_params = ["activation", "activation_output", "add_dense_layer", "batch_normalization", "batch_path", 

451 "batch_size", "block_length", "bootstrap_method", "bootstrap_path", "bootstrap_type", 

452 "competitor_path", "competitors", "create_new_bootstraps", "create_new_model", 

453 "create_snapshot", "data_collection", "debug_mode", "dense_layer_configuration", 

454 "do_uncertainty_estimate", "dropout", "dropout_rnn", "early_stopping_epochs", "epochs", 

455 "evaluate_competitors", "evaluate_feature_importance", "experiment_name", "experiment_path", 

456 "exponent_last_layer", "forecast_path", "fraction_of_training", "hostname", "hpc_hosts", 

457 "kernel_regularizer", "kernel_size", "layer_configuration", "log_level_stream", 

458 "logging_path", "login_nodes", "loss_type", "loss_weights", "max_number_multiprocessing", 

459 "model_class", "model_display_name", "model_path", "n_boots", "n_hidden", "n_layer", 

460 "neighbors", "plot_list", "plot_path", "regularizer", "restore_best_model_weights", 

461 "snapshot_load_path", "snapshot_path", "stations", "tmp_path", "train_model", 

462 "transformation", "use_multiprocessing", "cams_data_path", "cams_interp_method", 

463 "do_bias_free_evaluation", "apriori_file", "model_path", "model_load_path"] 

464 data_handler = self.data_store.get("data_handler") 

465 model_class = self.data_store.get("model_class") 

466 excluded_params = list(set(excluded_params + data_handler.store_attributes() + model_class.requirements())) 

467 

468 if check_nested_equality(self.data_store._store, snapshot._store, skip_args=excluded_params) is True: 

469 self.update_datastore(snapshot, excluded_params=remove_items(excluded_params, ["transformation", 

470 "data_collection", 

471 "stations"])) 

472 else: 

473 raise ReferenceError("provided snapshot does not match with the current experiment setup. Abort this run!") 

474 
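# Sketch for inspecting a preprocessing snapshot outside of a run; the file name is made up
# (create_snapshot above writes snapshot_preprocessing_<animal>.pickle into snapshot_path).
# The `_store` attribute is the same internal dict that load_snapshot compares against the
# current data store via check_nested_equality.
import dill

with open("snapshot_preprocessing_lynx.pickle", "rb") as f:
    snapshot = dill.load(f)
print(sorted(map(str, snapshot._store.keys())))  # top-level keys of the snapshot's data store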

475 

476def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_path=None, **kwargs): 

477 """ 

478 Try to create a data handler for the given arguments. If the build fails, this station does not fulfil all 

479 requirements and therefore f_proc will return None as an indication. On a successful build, f_proc returns the built 

480 data handler and the station that was used. This function must be defined at module level to work with multiprocessing. 

481 """ 

482 assert return_strategy in ["result", "reference"] 

483 try: 

484 res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, **kwargs) 

485 except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError) as e: 

486 formatted_lines = traceback.format_exc().splitlines() 

487 logging.info(f"remove station {station} because it raised an error: {e} -> " 

488 f"{' | '.join(f_inspect_error(formatted_lines))}") 

489 logging.debug(f"detailed information for removal of station {station}: {traceback.format_exc()}") 

490 res = None 

491 if return_strategy == "result":  # 491 ↛ 494: line 491 didn't jump to line 494, because the condition on line 491 was never false

492 return res, station 

493 else: 

494 if tmp_path is None: 

495 tmp_path = os.getcwd() 

496 _tmp_file = os.path.join(tmp_path, f"{station}_{'%032x' % random.getrandbits(128)}.pickle") 

497 with open(_tmp_file, "wb") as f: 

498 dill.dump(res, f, protocol=4) 

499 return _tmp_file, station 

500 
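# Illustrative round trip through both return strategies of f_proc above. _DummyHandler is a
# stand-in (not part of MLAir) so the example runs without station data: "result" hands the
# built object back directly, while "reference" dumps it to a dill file in tmp_path and returns
# the file path, which validate_station later loads and deletes again.
import os
import tempfile
import dill

class _DummyHandler:
    @classmethod
    def build(cls, station, name_affix=None, store_processed_data=True, **kwargs):
        return {"station": station, "subset": name_affix}

res, station = f_proc(_DummyHandler, "DEBW107", "train", store=True, return_strategy="result")
tmp_file, station = f_proc(_DummyHandler, "DEBW107", "train", store=True,
                           return_strategy="reference", tmp_path=tempfile.gettempdir())
with open(tmp_file, "rb") as f:
    dh = dill.load(f)
os.remove(tmp_file)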

501 

502def f_proc_create_info_df(data, meta_cols): 

503 station_name = str(data.id_class) 

504 meta = data.id_class.meta 

505 res = {"station_name": station_name, "Y_shape": data.get_Y()[0].shape[0], 

506 "meta": meta.reindex(meta_cols).values.flatten()} 

507 return res 

508 

509 

510def f_inspect_error(formatted): 

511 for i in range(len(formatted) - 1, -1, -1):  # 511 ↛ 514: line 511 didn't jump to line 514, because the loop on line 511 didn't complete

512 if "mlair/mlair" not in formatted[i]: 512 ↛ 511line 512 didn't jump to line 511, because the condition on line 512 was never false

513 return formatted[i - 3:i] 

514 return formatted[-3:]
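# Usage sketch for f_inspect_error: given the splitlines() of a formatted traceback, it walks
# backwards to the last line that does not point into the mlair package and returns the three
# lines just before it; f_proc above joins them into the one-line station removal message.
import traceback

try:
    raise KeyError("o3")  # stand-in error, as it could surface from data_handler.build
except KeyError:
    short_reason = " | ".join(f_inspect_error(traceback.format_exc().splitlines()))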