Coverage for mlair/helpers/helpers.py: 83%

133 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-12-18 17:51 +0000

1"""Collection of different help functions.""" 

2__author__ = 'Lukas Leufen, Felix Kleinert' 

3__date__ = '2019-10-21' 

4 

5import inspect 

6import math 

7import argparse 

8 

9import numpy as np 

10import xarray as xr 

11import dask.array as da 

12 

13from typing import Dict, Callable, Union, List, Any, Tuple 

14 

15from tensorflow.keras.models import Model 

16from tensorflow.python.keras.layers import deserialize, serialize 

17from tensorflow.python.keras.saving import saving_utils 

18 

19""" 

20The following code is copied from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883 

21and is a hotfix to make tensorflow.keras.models.Model objects serializable/picklable 

22""" 

23 

24 

def unpack(model, training_config, weights):
    """Rebuild a Keras model from the pieces stored by the pickling hotfix.

    This is the reconstruction callable returned by the ``__reduce__`` hook that ``make_keras_pickable`` installs,
    so unpickling calls it with the stored arguments.

    :param model: serialized model as returned by keras ``serialize``
    :param training_config: training configuration taken from the model metadata, or None when the metadata did not
        contain one (then the restored model is not compiled)
    :param weights: weights as returned by ``get_weights`` of the original model
    :return: the deserialized model with the given weights
    """
    rebuilt = deserialize(model)
    if training_config is not None:
        compile_kwargs = saving_utils.compile_args_from_training_config(training_config)
        rebuilt.compile(**compile_kwargs)
    rebuilt.set_weights(weights)
    return rebuilt

35 

# Hotfix function
def make_keras_pickable():
    """Patch ``tensorflow.keras.models.Model`` so that its instances can be pickled.

    A ``__reduce__`` method is assigned to the ``Model`` class itself, so it applies to every instance (and to
    subclasses that do not define their own). Pickling then stores the serialized model, its training config (if the
    metadata has one) and its weights; unpickling rebuilds the model through ``unpack``.
    """

    def __reduce__(self):
        metadata = saving_utils.model_metadata(self)
        train_cfg = metadata.get("training_config", None)
        payload = serialize(self)
        return unpack, (payload, train_cfg, self.get_weights())

    Model.__reduce__ = __reduce__

48 

49 

50" end of hotfix " 

51 

52 

def to_list(obj: Any) -> List:
    """
    Wrap obj into a list unless it already is one.

    Sets, tuples and dict key views are converted element-wise into a new list. A list is returned as it is (the very
    same object). Every other object, including strings and dicts, becomes the only element of a new list.

    :param obj: object to transform to list

    :return: list representation of obj (obj itself if it already was a list)
    """
    if isinstance(obj, list):
        return obj
    if isinstance(obj, (set, tuple, type({}.keys()))):
        return list(obj)
    return [obj]

66 

67 

def sort_like(list_obj: list, sorted_obj: list):
    """
    Sort elements of list_obj as ordered in sorted_obj.

    sorted_obj is allowed to be longer than list_obj, but must contain at least all objects of list_obj. Both lists
    are required to contain only unique (and therefore hashable) elements.

    :param list_obj: list to sort
    :param sorted_obj: list to use ordering from

    :raises AssertionError: if not all elements of list_obj are in sorted_obj, or if either list contains duplicates

    :return: sorted list
    """
    members = set(list_obj)
    # explicit raises instead of `assert` statements, so the validation also runs under `python -O`
    if not members.issubset(sorted_obj):
        raise AssertionError("all elements of list_obj must be contained in sorted_obj")
    if len(members) != len(list_obj):
        raise AssertionError("list_obj must contain only unique elements")
    if len(set(sorted_obj)) != len(sorted_obj):
        raise AssertionError("sorted_obj must contain only unique elements")
    # set lookup keeps this linear instead of O(len(sorted_obj) * len(list_obj))
    return [e for e in sorted_obj if e in members]

83 

84 

def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
    """
    Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>.

    The dictionary keys become the labels along the new axis, which is placed as first dimension. The arrays stored
    in d are not modified.

    :param d: dictionary with 2D-xarrays
    :param coordinate_name: name of the new created axis (2D -> 3D)

    :return: combined xarray, or None if d is empty (kept for backwards compatibility)
    """
    if len(d) == 1:
        k = list(d.keys())
        xarray: xr.DataArray = d[k[0]]
        return xarray.expand_dims(dim={coordinate_name: k}, axis=0)
    if len(d) == 0:
        return None
    # assign_coords returns new objects, so the caller's arrays are left untouched (writing into `v.coords` would
    # add the scalar coordinate to the input arrays). Concatenating everything at once avoids copying the growing
    # result again on every iteration.
    labelled = [v.assign_coords(**{coordinate_name: k}) for k, v in d.items()]
    return xr.concat(labelled, coordinate_name)

109 

110 

def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float:
    """
    Round number to a fixed number of decimals with an arbitrary rounding function.

    number is scaled by 10**decimals, rounded with round_type and scaled back, e.g. float_round(2.123, 2) -> 2.13
    (default math.ceil) or float_round(2.127, 2, math.floor) -> 2.12.

    :param number: the number to round
    :param decimals: number of decimals to keep (default 0 -> round to an integer value)
    :param round_type: rounding callable applied to the scaled number, e.g. math.ceil, math.floor or built-in round

    :return: rounded number with desired precision
    """
    scale = 10. ** decimals
    scaled = number * scale
    return round_type(scaled) / scale

124 

125 

def relative_round(x: float, sig: int, ceil=False, floor=False) -> float:
    """
    Round numbers according to given "significance" (number of significant digits).

    Example: relative_round(0.03112, 2) -> 0.031, relative_round(0.03112, 1) -> 0.03,
    relative_round(0.03112, 2, ceil=True) -> 0.032, relative_round(0.03196, 2, floor=True) -> 0.031

    :param x: number to round
    :param sig: "significance" (>= 1) to determine number of decimals
    :param ceil: if True, round towards +inf instead of to the nearest value
    :param floor: if True, round towards -inf instead of to the nearest value (ceil and floor are mutually exclusive)

    :return: rounded number
    """
    assert sig >= 1
    assert not (ceil and floor)
    if x == 0:
        return 0
    order = get_order(x)
    rounded_number = round(x, sig - order - 1)
    step = 10 ** (order - sig + 1)  # value of one unit in the last kept significant digit
    # Compare with x itself so floor never ends above x and ceil never below x. A comparison with x rounded to one
    # more digit misses cases like 0.03196 (floor, sig=2). isclose tolerates pure float noise (e.g. 0.1 + 0.2).
    if floor is True and rounded_number > x and not math.isclose(rounded_number, x):
        res = rounded_number - step
    elif ceil is True and rounded_number < x and not math.isclose(rounded_number, x):
        res = rounded_number + step
    else:
        res = rounded_number
    return round(res, sig - get_order(res) - 1)

150 

151 

def get_order(x: float):
    """Return the order of magnitude of x as power of 10, i.e. int(floor(log10(|x|))), or -inf if x is zero."""
    if x != 0:
        return int(np.floor(np.log10(abs(x))))
    return -np.inf

158 

159 

def remove_items(obj: Union[List, Dict, Tuple], items: Any):
    """
    Remove item(s) from either list, tuple or dictionary.

    For lists and tuples, every occurrence of each given item is removed. For dictionaries, the items are interpreted
    as keys. A new object is returned in all cases; obj itself is not modified.

    :param obj: object to remove items from (either list, tuple or dictionary)
    :param items: elements to remove from obj. Can either be a list or single entry / key

    :raises TypeError: if obj is neither a list, a tuple nor a dictionary

    :return: object without items
    """

    def remove_from_list(list_obj, item_list):
        """Return a new list without any occurrence of the elements in item_list."""
        return [e for e in list_obj if e not in item_list]

    def remove_from_dict(dict_obj, key_list):
        """Return a new dictionary without the keys contained in key_list."""
        return {k: v for k, v in dict_obj.items() if k not in key_list}

    items = to_list(items)
    if isinstance(obj, list):
        return remove_from_list(obj, items)
    elif isinstance(obj, dict):
        return remove_from_dict(obj, items)
    elif isinstance(obj, tuple):
        return tuple(remove_from_list(to_list(obj), items))
    else:
        # literal name instead of inspect.stack(), which inspects (and reads source for) the entire call stack
        raise TypeError(f"remove_items does not support type {type(obj)}.")

197 

198 

def select_from_dict(dict_obj: dict, sel_list: Any, remove_none: bool = False, filter_cond: bool = True) -> dict:
    """
    Extract all key value pairs whose key is (or, with filter_cond=False, is not) contained in the sel_list.

    Does not perform a check if all elements of sel_list are keys of dict_obj. With filter_cond=True (default), the
    returned dict therefore never holds more pairs than sel_list has elements. With filter_cond=False, it holds all
    pairs whose key is not contained in sel_list.

    :param dict_obj: dictionary to select from
    :param sel_list: key or list of keys (a single key is wrapped into a list)
    :param remove_none: if True, additionally drop all pairs whose value is None
    :param filter_cond: keep keys contained in sel_list (True) or keys not contained in sel_list (False); other values
        are interpreted by their truthiness

    :return: dictionary with the selected key value pairs
    """
    sel_list = to_list(sel_list)
    assert isinstance(dict_obj, dict)
    # compare against a real bool: `(k in sel_list) is filter_cond` silently selects nothing for e.g. filter_cond=1
    keep_contained = bool(filter_cond)
    return {k: v for k, v in dict_obj.items()
            if (k in sel_list) is keep_contained and not (remove_none and v is None)}

212 

213 

def extract_value(encapsulated_value):
    """
    Unwrap a value nested in single-element containers, e.g. [[5]] -> 5.

    Element [0] is taken repeatedly as long as the current value has length one. Strings are never unwrapped. Objects
    without len() or without integer indexing (e.g. numbers or sets) are returned as they are. A KeyError or
    IndexError raised while indexing [0] is not caught.

    :param encapsulated_value: value to unwrap
    :raises NotImplementedError: if a (non-string) container with a length other than one is encountered
    :return: the innermost value
    """
    if isinstance(encapsulated_value, str):
        return encapsulated_value
    try:
        if len(encapsulated_value) != 1:
            raise NotImplementedError("Trying to extract an encapsulated value from objects with more than a single "
                                      "entry is not supported by this function.")
        inner = encapsulated_value[0]
    except TypeError:
        return encapsulated_value
    return extract_value(inner)

225 

226 

def is_xarray(arr) -> bool:
    """
    Check whether arr is an xarray object.

    :param arr: variable in question
    :type arr: Any
    :return: True if arr is an instance of xarray.DataArray or xarray.Dataset (subclasses included), else False
    :rtype: bool
    """
    return isinstance(arr, (xr.DataArray, xr.Dataset))

236 

237 

def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float],
                 use_1d_default: bool = False, **kwargs) -> Union[xr.DataArray, xr.Dataset]:
    """
    Converts np.array, int or float object to xarray.DataArray.

    If a xarray.DataArray or xarray.Dataset is passed, returns that unchanged.

    With use_1d_default=True, a one-dimensional layout is set up: the dimension name defaults to 'points' (kwarg
    `dims`). Unless `coords` is given, the coordinates are 0..arr.shape[0]-1 for array-like input, or a single
    coordinate 0 for scalar numbers (Python int/float or numpy integer/floating scalars).

    :param arr: data to convert
    :type arr: xr.DataArray, xr.Dataset, np.ndarray, int, float
    :param use_1d_default: create default 1D dims/coords as described above
    :type use_1d_default: bool
    :param kwargs: Any additional kwargs which are accepted by xr.DataArray()
    :type kwargs: Any
    :raises TypeError: if use_1d_default is True and arr is a dask array, or arr is neither array-like nor a number
    :return: arr itself if it already is an xarray object, otherwise a new xr.DataArray
    :rtype: xr.DataArray, xr.DataSet
    """
    if is_xarray(arr):
        return arr
    if use_1d_default:
        if isinstance(arr, da.core.Array):
            raise TypeError(f"`use_1d_default=True' is used with `arr' of type da.array. For da.arrays please "
                            f"pass `use_1d_default=False' and specify keywords for xr.DataArray via kwargs.")
        dims = kwargs.pop('dims', 'points')
        coords = kwargs.pop('coords', None)
        try:
            if coords is None:
                coords = {dims: range(arr.shape[0])}
        except (AttributeError, IndexError):
            # scalars have no shape (Python numbers) or a 0-d shape (numpy scalars); numpy integer scalars are not
            # subclasses of int and therefore need to be listed explicitly
            if isinstance(arr, (int, float, np.integer, np.floating)):
                coords = {dims: range(1)}
                dims = to_list(dims)
            else:
                raise TypeError(f"`arr' must be array-like, int or float. But is of type {type(arr)}")
        kwargs.update({'dims': dims, 'coords': coords})
    return xr.DataArray(arr, **kwargs)

274 

275 

def filter_dict_by_value(dictionary: dict, filter_val: Any, filter_cond: bool) -> dict:
    """
    Filter dictionary by its values.

    :param dictionary: dict to filter
    :param filter_val: search only for key value pair with a value equal to filter_val (single value or list of values)
    :param filter_cond: indicate to use either all dict entries that fulfil the filter_val criteria (if `True`) or that
        do not match the criteria (if `False`); other values are interpreted by their truthiness
    :returns: a filtered dict with either matching or non-matching elements depending on the `filter_cond`
    """
    filter_val = to_list(filter_val)
    # compare against a real bool: `(v in filter_val) is filter_cond` silently matches nothing for e.g. filter_cond=1
    keep_matching = bool(filter_cond)
    return {k: v for k, v in dictionary.items() if (v in filter_val) is keep_matching}

288 

289 

def str2bool(v):
    """
    Interpret v as a boolean value.

    bool values are passed through unchanged. Strings are matched case-insensitively: 'yes', 'true', 't', 'y', '1'
    give True and 'no', 'false', 'f', 'n', '0' give False. Errors are raised as argparse.ArgumentTypeError, presumably
    so the function can serve as an argparse `type=` converter.

    :param v: value to interpret
    :raises argparse.ArgumentTypeError: for any other string and for any value that is neither bool nor str
    :return: boolean value of v
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        normalized = v.lower()
        if normalized in ('yes', 'true', 't', 'y', '1'):
            return True
        if normalized in ('no', 'false', 'f', 'n', '0'):
            return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

302 

303 

def squeeze_coords(d):
    """
    Look for unused coords and remove them. Does only work for xarray objects.

    All coordinates whose name is not also a dimension name of d are dropped. This covers scalar coordinates as well
    as non-dimension coordinates along existing dimensions. Objects without `coords`/`dims` attributes (i.e. anything
    that is not an xarray object) are returned unchanged.

    :param d: xarray object (any other object is returned as it is)
    :return: d without coordinates that are not dimensions
    """
    try:
        unused = set(d.coords.keys()).difference(d.dims)
    except AttributeError:
        # not an xarray object -> nothing to squeeze (narrowed from a blanket `except Exception`)
        return d
    # drop_vars is the replacement for the deprecated `drop` when removing variables/coordinates by name
    return d.drop_vars(unused)

310 

311 

312# def convert_size(size_bytes): 

313# if size_bytes == 0: 

314# return "0B" 

315# size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") 

316# i = int(math.floor(math.log(size_bytes, 1024))) 

317# p = math.pow(1024, i) 

318# s = round(size_bytes / p, 2) 

319# return "%s %s" % (s, size_name[i]) 

320# 

321# 

322# def get_size(obj, seen=None): 

323# """Recursively finds size of objects""" 

324# size = sys.getsizeof(obj) 

325# if seen is None: 

326# seen = set() 

327# obj_id = id(obj) 

328# if obj_id in seen: 

329# return 0 

330# # Important mark as seen *before* entering recursion to gracefully handle 

331# # self-referential objects 

332# seen.add(obj_id) 

333# if isinstance(obj, dict): 

334# size += sum([get_size(v, seen) for v in obj.values()]) 

335# size += sum([get_size(k, seen) for k in obj.keys()]) 

336# elif hasattr(obj, '__dict__'): 

337# size += get_size(obj.__dict__, seen) 

338# elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): 

339# size += sum([get_size(i, seen) for i in obj]) 

340# return size