Coverage for mlair/helpers/helpers.py: 83%
132 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
1"""Collection of different help functions."""
2__author__ = 'Lukas Leufen, Felix Kleinert'
3__date__ = '2019-10-21'
5import inspect
6import math
7import argparse
9import numpy as np
10import xarray as xr
11import dask.array as da
13from typing import Dict, Callable, Union, List, Any, Tuple
15from tensorflow.keras.models import Model
16from tensorflow.python.keras.layers import deserialize, serialize
17from tensorflow.python.keras.saving import saving_utils
# The following code is copied from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883
# and is a hotfix to make keras Model instances serializable/picklable.
def unpack(model, training_config, weights):
    """
    Rebuild a keras model from the pieces stored by the pickling hotfix.

    :param model: serialized model as produced by ``serialize``
    :param training_config: training config taken from the model metadata, or None (then the model is not compiled)
    :param weights: weights as returned by ``get_weights`` of the original model

    :return: deserialized model with restored weights (compiled if a training config was given)
    """
    restored = deserialize(model)
    if training_config is not None:
        compile_kwargs = saving_utils.compile_args_from_training_config(training_config)
        restored.compile(**compile_kwargs)
    restored.set_weights(weights)
    return restored
# Hotfix function
def make_keras_pickable():
    """
    Patch ``Model.__reduce__`` so that keras models can be pickled.

    Once patched, pickling a ``Model`` stores its serialized form, the training config from its metadata (None if
    absent) and its weights; unpickling passes these to ``unpack`` to rebuild the model.
    """

    def __reduce__(self):
        metadata = saving_utils.model_metadata(self)
        training_config = metadata.get("training_config", None)
        return unpack, (serialize(self), training_config, self.get_weights())

    Model.__reduce__ = __reduce__

# end of hotfix
def to_list(obj: Any) -> List:
    """
    Wrap obj into a list unless it already is one.

    A list is returned as-is (same object). Sets, tuples and dict key views are converted element-wise into a new
    list. Any other object (including strings and dicts) becomes the single element of a new list.

    :param obj: object to transform to list

    :return: obj itself (if it was a list), a list of its elements (set, tuple, dict keys) or [obj]
    """
    if isinstance(obj, list):
        return obj
    if isinstance(obj, (set, tuple, type({}.keys()))):
        return list(obj)
    return [obj]
def sort_like(list_obj: list, sorted_obj: list):
    """
    Sort elements of list_obj as ordered in sorted_obj.

    The length of sorted_obj is allowed to be higher than the length of list_obj, but sorted_obj must contain at least
    all objects of list_obj. Both lists are required to contain only unique (hashable) elements.

    :param list_obj: list to sort
    :param sorted_obj: list to use ordering from

    :raises AssertionError: if not all elements of list_obj are in sorted_obj or if either list contains duplicates

    :return: sorted list
    """
    members = set(list_obj)
    # Explicit raises instead of assert statements, so these checks are not stripped under ``python -O``.
    if not members.issubset(sorted_obj):
        raise AssertionError("All elements of list_obj must also be contained in sorted_obj.")
    if len(members) != len(list_obj):
        raise AssertionError("list_obj must not contain duplicate elements.")
    if len(set(sorted_obj)) != len(sorted_obj):
        raise AssertionError("sorted_obj must not contain duplicate elements.")
    # Set lookup keeps this linear instead of O(len(sorted_obj) * len(list_obj)).
    return [e for e in sorted_obj if e in members]
def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
    """
    Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>.

    The dictionary keys (in iteration order) become the values of the new coordinate. The arrays stored in d are not
    modified.

    :param d: dictionary with 2D-xarrays
    :param coordinate_name: name of the new created axis (2D -> 3D)

    :return: combined xarray (None if d is empty)
    """
    keys = list(d.keys())
    if len(keys) == 1:
        return d[keys[0]].expand_dims(dim={coordinate_name: keys}, axis=0)
    if not keys:
        return None
    # assign_coords returns a new object, so the caller's arrays do not get a scalar coordinate attached as a side
    # effect (the previous in-place ``coords[...] = k`` assignment modified the dict values).
    labelled = [v.assign_coords({coordinate_name: k}) for k, v in d.items()]
    # One concat over all arrays instead of growing the result pairwise inside a loop.
    return xr.concat(labelled, coordinate_name)
def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float:
    """
    Round number to the given number of decimals using an arbitrary rounding function.

    The number is scaled by 10 ** decimals, rounded with round_type and scaled back.

    :param number: the number to round
    :param decimals: number of decimals to keep (default 0 -> round to an integer value)
    :param round_type: rounding operation applied to the scaled number, e.g. math.ceil (default), math.floor or the
        built-in round

    :return: rounded number with desired precision
    """
    scale = 10. ** decimals
    scaled = number * scale
    return round_type(scaled) / scale
def relative_round(x: float, sig: int, ceil=False, floor=False) -> float:
    """
    Round a number to `sig` significant digits.

    Example: relative_round(0.03112, 2) -> 0.031, relative_round(0.03112, 1) -> 0.03

    :param x: number to round
    :param sig: number of significant digits to keep (must be >= 1)
    :param ceil: if exactly True, round up at the given significance instead of to nearest
    :param floor: if exactly True, round down at the given significance instead of to nearest (ceil and floor must not
        both be set)

    :return: rounded number (0 if x is 0)
    """
    assert sig >= 1
    assert not (ceil and floor)
    if x == 0:
        return 0
    order = get_order(x)
    decimals = sig - order - 1
    # size of one unit in the last kept significant digit
    step = 10 ** (order - sig + 1)
    nearest = round(x, decimals)
    # Rounding with one extra digit is the reference for deciding whether `nearest` overshot (floor) or undershot
    # (ceil) x. NOTE(review): the adjustment is therefore skipped when the extra-digit rounding equals `nearest`,
    # e.g. relative_round(7.9996, 2, floor=True) gives 8.0 - confirm whether this tolerance is intended.
    finer = round(x, decimals + 1)
    if floor is True and nearest > finer:
        result = nearest - step
    elif ceil is True and nearest < finer:
        result = nearest + step
    else:
        result = nearest
    return round(result, sig - get_order(result) - 1)
def get_order(x: float):
    """
    Get order of number (as power of 10), i.e. floor(log10(abs(x))).

    Returns an int for non-zero x and -np.inf for x == 0.
    """
    return -np.inf if x == 0 else int(np.floor(np.log10(abs(x))))
def remove_items(obj: Union[List, Dict, Tuple], items: Any):
    """
    Remove item(s) from either list, tuple or dictionary.

    For lists and tuples every occurrence of each item is removed (compared with ``==``); for dictionaries all entries
    whose key is in items are removed. obj itself is never modified, a new list / tuple / dict is returned.

    :param obj: object to remove items from (list, tuple or dictionary)
    :param items: elements to remove from obj. Can either be a list or single entry / key (converted with to_list)

    :raises TypeError: if obj is neither a list, a tuple nor a dictionary

    :return: new object without items
    """

    def remove_from_list(list_obj, item_list):
        """Return a new list with all elements of list_obj that are not in item_list."""
        # Single and multiple items share one code path, so every occurrence is removed in both cases and the
        # result is always a new list (never the caller's object).
        return [e for e in list_obj if e not in item_list]

    def remove_from_dict(dict_obj, key_list):
        """Return a new dict without the keys listed in key_list."""
        return {k: v for k, v in dict_obj.items() if k not in key_list}

    items = to_list(items)
    if isinstance(obj, list):
        return remove_from_list(obj, items)
    if isinstance(obj, dict):
        return remove_from_dict(obj, items)
    if isinstance(obj, tuple):
        return tuple(remove_from_list(obj, items))
    # Literal function name: same message as before without the cost of inspecting the whole call stack.
    raise TypeError(f"remove_items does not support type {type(obj)}.")
def select_from_dict(dict_obj: dict, sel_list: Any, remove_none: bool = False, filter_cond: bool = True) -> dict:
    """
    Extract all key values pairs whose key is contained in the sel_list.

    Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore, with `filter_cond` True, the
    number of pairs in the returned dict is always smaller or equal to the number of elements in the sel_list. With
    `filter_cond` False the complement is returned, i.e. all pairs whose key is NOT in `sel_list`.

    :param dict_obj: dictionary to select from
    :param sel_list: key or list of keys (converted with to_list)
    :param remove_none: if True, additionally drop all pairs whose value is None
    :param filter_cond: keep pairs whose key is in sel_list (truthy, default) or whose key is not in sel_list (falsy);
        non-bool flags such as 1 or numpy booleans are accepted

    :raises AssertionError: if dict_obj is not a dict

    :return: new dictionary with the selected pairs
    """
    sel_list = to_list(sel_list)
    if not isinstance(dict_obj, dict):
        # explicit raise keeps the check active under ``python -O``
        raise AssertionError(f"dict_obj must be of type dict, but is of type {type(dict_obj)}.")
    # Normalise the flag: an identity check against a non-bool (e.g. 1 or numpy.True_) would select nothing.
    keep_selected = bool(filter_cond)
    sel_dict = {k: v for k, v in dict_obj.items() if (k in sel_list) == keep_selected}
    if remove_none:
        sel_dict = {k: v for k, v in sel_dict.items() if v is not None}
    return sel_dict
def extract_value(encapsulated_value):
    """
    Unwrap a value that is (possibly repeatedly) encapsulated in single-element containers.

    Example: extract_value([[23]]) -> 23, extract_value(("test",)) -> "test", extract_value(5) -> 5

    Strings and bytes-like objects are treated as atomic values and returned unchanged. The same applies to any object
    without a length or that cannot be indexed with an integer (e.g. numbers or sets).

    :param encapsulated_value: value to unwrap

    :raises NotImplementedError: if a container on the way is empty or holds more than a single entry

    :return: the innermost encapsulated value
    """
    # Text and binary strings are sequences too, but unwrapping them would yield characters / byte ordinals.
    if isinstance(encapsulated_value, (str, bytes, bytearray)):
        return encapsulated_value
    try:
        size = len(encapsulated_value)
    except TypeError:
        return encapsulated_value
    if size == 1:
        try:
            inner = encapsulated_value[0]
        except TypeError:
            # sized but not subscriptable (e.g. a set): nothing to unwrap
            return encapsulated_value
        return extract_value(inner)
    if size == 0:
        raise NotImplementedError("Trying to extract an encapsulated value from an empty object is not supported by "
                                  "this function.")
    raise NotImplementedError("Trying to extract an encapsulated value from objects with more than a single "
                              "entry is not supported by this function.")
def is_xarray(arr) -> bool:
    """
    Check whether arr is an xarray object.

    :param arr: variable in question
    :type arr: Any
    :return: True if arr is an instance of xarray.DataArray or xarray.Dataset, False otherwise
    :rtype: bool
    """
    xarray_types = (xr.DataArray, xr.Dataset)
    return isinstance(arr, xarray_types)
def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float],
                 use_1d_default: bool = False, **kwargs) -> Union[xr.DataArray, xr.Dataset]:
    """
    Converts np.array, int or float object to xarray.DataArray.

    If a xarray.DataArray or xarray.Dataset is passed, returns that unchanged.
    :param arr: object to convert
    :type arr: xr.DataArray, xr.Dataset, np.ndarray, int, float
    :param use_1d_default: if True, fill in default 1D dims/coords: dims defaults to 'points' and, if no coords are
        given, coords are range(arr.shape[0]) along that dim (range(1) for int/float). Not supported for dask arrays.
    :type use_1d_default: bool
    :param kwargs: Any additional kwargs which are accepted by xr.DataArray()
    :type kwargs:
    :raises TypeError: if use_1d_default is True and arr is a dask array, or if arr has no usable first axis and is
        neither int nor float
    :return: arr itself if it already is an xarray object, otherwise a new xr.DataArray
    :rtype: xr.DataArray, xr.DataSet
    """
    if is_xarray(arr):
        return arr
    else:
        if use_1d_default:
            if isinstance(arr, da.core.Array):
                raise TypeError(f"`use_1d_default=True' is used with `arr' of type da.array. For da.arrays please "
                                f"pass `use_1d_default=False' and specify keywords for xr.DataArray via kwargs.")
            # pop dims/coords so they are not passed twice; both are re-inserted via kwargs.update below
            dims = kwargs.pop('dims', 'points')
            coords = kwargs.pop('coords', None)
            try:
                if coords is None:
                    # AttributeError: arr has no shape (e.g. int/float); IndexError: 0-d shape
                    coords = {dims: range(arr.shape[0])}
            except (AttributeError, IndexError):
                if isinstance(arr, int) or isinstance(arr, float):
                    # NOTE(review): 'coords' was already popped above, so this pop always returns the default.
                    coords = kwargs.pop('coords', {dims: range(1)})
                    dims = to_list(dims)
                else:
                    # NOTE(review): "arry-like" is a typo in the message; kept as-is in case callers match on it.
                    raise TypeError(f"`arr' must be arry-like, int or float. But is of type {type(arr)}")
            # NOTE(review): dims is only converted to a list in the int/float branch; otherwise it stays as given.
            kwargs.update({'dims': dims, 'coords': coords})

        return xr.DataArray(arr, **kwargs)
def filter_dict_by_value(dictionary: dict, filter_val: Any, filter_cond: bool) -> dict:
    """
    Filter dictionary by its values.

    :param dictionary: dict to filter
    :param filter_val: search only for key value pair with a value equal to filter_val
    :param filter_cond: indicate to use either all dict entries that fulfil the filter_val criteria (if `True`) or that
        do not match the criteria (if `False`); any truthy / falsy flag (e.g. numpy booleans) is accepted
    :raises ValueError: if comparing a value with filter_val gives an ambiguous truth value (e.g. multi-element arrays)
    :returns: a filtered dict with either matching or non-matching elements depending on the `filter_cond`
    """
    keep_matches = bool(filter_cond)
    # bool() normalises comparison results such as numpy.bool_ (returned for numpy scalar values); an identity check
    # against True/False would treat those entries as non-matching in both modes.
    return {k: v for k, v in dictionary.items() if bool(v == filter_val) == keep_matches}
def str2bool(v):
    """
    Interpret v as a boolean.

    Booleans are returned unchanged. Strings are matched case-insensitively: 'yes', 'true', 't', 'y', '1' give True
    and 'no', 'false', 'f', 'n', '0' give False. Errors are raised as argparse.ArgumentTypeError, so the function can
    serve as an argparse ``type=`` converter.

    :param v: value to interpret
    :raises argparse.ArgumentTypeError: if v is neither a bool nor one of the recognised strings
    :return: boolean value of v
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        token = v.lower()
        if token in ('yes', 'true', 't', 'y', '1'):
            return True
        if token in ('no', 'false', 'f', 'n', '0'):
            return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def squeeze_coords(d):
    """
    Drop all coordinates of d that are not one of its dimensions.

    Meant for xarray DataArrays. If anything fails (e.g. d has no ``coords`` attribute), d is returned unchanged.
    """
    try:
        unused_coords = set(d.coords.keys()) - set(d.dims)
        return d.drop(unused_coords)
    except Exception:
        return d
311# def convert_size(size_bytes):
312# if size_bytes == 0:
313# return "0B"
314# size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
315# i = int(math.floor(math.log(size_bytes, 1024)))
316# p = math.pow(1024, i)
317# s = round(size_bytes / p, 2)
318# return "%s %s" % (s, size_name[i])
319#
320#
321# def get_size(obj, seen=None):
322# """Recursively finds size of objects"""
323# size = sys.getsizeof(obj)
324# if seen is None:
325# seen = set()
326# obj_id = id(obj)
327# if obj_id in seen:
328# return 0
329# # Important mark as seen *before* entering recursion to gracefully handle
330# # self-referential objects
331# seen.add(obj_id)
332# if isinstance(obj, dict):
333# size += sum([get_size(v, seen) for v in obj.values()])
334# size += sum([get_size(k, seen) for k in obj.keys()])
335# elif hasattr(obj, '__dict__'):
336# size += get_size(obj.__dict__, seen)
337# elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
338# size += sum([get_size(i, seen) for i in obj])
339# return size