Coverage for mlair/helpers/testing.py: 96%
91 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
1"""Helper functions that are used to simplify testing."""
2import logging
3import re
4from typing import Union, Pattern, List
5import inspect
7import numpy as np
8import xarray as xr
10from mlair.helpers.helpers import remove_items, to_list
class PyTestRegex:
    r"""
    Assert that a given string meets some expectations.

    Use like

    >>> PyTestRegex(r"TestString\d+") == "TestString"
    False
    >>> PyTestRegex(r"TestString\d+") == "TestString2"
    True

    Objects that cannot be matched against the pattern at all (e.g. numbers or None) never compare equal.

    :param pattern: pattern or string to use for regular expression
    :param flags: python re flags
    """

    def __init__(self, pattern: Union[str, Pattern], flags: int = 0):
        """Construct PyTestRegex."""
        self._regex = re.compile(pattern, flags)

    def __eq__(self, actual: object) -> bool:
        """
        Return whether regex matches given string actual or not.

        The pattern is applied with ``match``, i.e. it is anchored at the beginning of actual only.

        :param actual: object to compare with, usually a string
        :return: True if the pattern matches, False if not, NotImplemented if actual cannot be matched at all
        """
        try:
            return bool(self._regex.match(actual))
        except TypeError:
            # re refuses objects it cannot match against. Returning NotImplemented lets python fall back to its
            # default comparison, so that `PyTestRegex(...) == 5` evaluates to False instead of raising.
            return NotImplemented

    def __repr__(self) -> str:
        """Show regex pattern."""
        return self._regex.pattern
def PyTestAllEqual(check_list: List) -> bool:
    """
    Check if all elements in list are the same.

    Each element of check_list is compared with the first element. Lists are compared element by element on any
    nesting depth, numpy arrays with np.testing.assert_array_equal, xarray DataArrays with xr.testing.assert_equal
    and all other objects with ``==``.

    :param check_list: list with elements to check
    :return: True if all elements are equal; an empty list is considered to be equal, too
    :raises AssertionError: if any element differs from the first element
    """

    def _assert_equal(expected, actual):
        """Raise an AssertionError if actual differs from expected, descend into lists recursively."""
        if isinstance(expected, list):
            # check the length explicitly, zip would silently ignore the surplus elements of the longer list
            if len(expected) != len(actual):
                raise AssertionError(f"lengths do not match: {len(expected)} != {len(actual)}")
            for expected_item, actual_item in zip(expected, actual):
                _assert_equal(expected_item, actual_item)
        elif isinstance(expected, np.ndarray):
            np.testing.assert_array_equal(expected, actual)
        elif isinstance(expected, xr.DataArray):
            xr.testing.assert_equal(expected, actual)
        elif not expected == actual:
            # raise explicitly instead of using an assert statement to stay independent of python's -O flag
            raise AssertionError(f"{expected!r} != {actual!r}")

    # The first element is also compared with itself, so objects that are not equal to themselves (e.g. nan) are
    # reported as well. For an empty list the loop body is never entered and check_list[0] is never evaluated.
    for element in check_list:
        _assert_equal(check_list[0], element)
    return True
def get_all_args(*args, remove=None, add=None):
    """
    Collect the argument names of all given callables.

    The names of positional-or-keyword and keyword-only arguments reported by inspect.getfullargspec are gathered for
    each callable, freed from duplicates and sorted.

    :param args: callables to inspect
    :param remove: if not None, passed together with the collected names to remove_items
    :param add: if not None, converted with to_list and appended to the end of the result (neither sorted nor
        checked for duplicates)
    :return: list of argument names
    """
    names = set()
    for func in args:
        spec = inspect.getfullargspec(func)
        names.update(spec.args, spec.kwonlyargs)
    collected = sorted(names)
    if remove is not None:
        collected = remove_items(collected, remove)
    if add is not None:
        collected += to_list(add)
    return collected
def _describe_nested_mismatch(obj1, obj2, precision, skip_args):
    """
    Describe the first difference found between two (nested) objects.

    Elements of tuples, lists and dicts are compared by calling check_nested_equality again, so that a mismatch is
    logged on each nesting level. The message is only built if a mismatch is actually found.

    :param obj1: first object
    :param obj2: second object
    :param precision: number of decimals to check for numbers and arrays, None for an exact comparison
    :param skip_args: dict keys to ignore, only taken into account if given as str or list
    :return: message describing the mismatch, or None if both objects are considered to be equal
    """

    def _sorted_if_possible(keys):
        # keys of mixed types (e.g. int and str) cannot be ordered, use their repr to still get a stable message
        try:
            return sorted(keys)
        except TypeError:
            return sorted(keys, key=repr)

    if type(obj1) is not type(obj2):
        return f"{type(obj1)}!={type(obj2)}\n{obj1} and {obj2} do not match"

    if isinstance(obj1, (tuple, list)):
        if len(obj1) != len(obj2):
            return f"{len(obj1)}!={len(obj2)}\nlengths of {obj1} and {obj2} do not match"
        for pos, (item1, item2) in enumerate(zip(obj1, obj2)):
            if not check_nested_equality(item1, item2, precision=precision, skip_args=skip_args):
                return f"{item1}!={item2}\nobjects on pos {pos} of {obj1} and {obj2} do not match"
        return None

    if isinstance(obj1, dict):
        obj1_keys, obj2_keys = obj1.keys(), obj2.keys()
        if skip_args is not None and isinstance(skip_args, (str, list)):
            skip_args = to_list(skip_args)
            obj1_keys = list(set(obj1_keys).difference(skip_args))
            obj2_keys = list(set(obj2_keys).difference(skip_args))
        # dict keys are unique, therefore comparing sets is equivalent to comparing sorted key lists but does not
        # require the keys to be orderable
        if set(obj1_keys) != set(obj2_keys):
            return (f"{_sorted_if_possible(obj1_keys)}!={_sorted_if_possible(obj2_keys)}\n"
                    f"{set(obj1_keys).symmetric_difference(obj2_keys)} are not in both sorted key lists")
        for k in obj1_keys:
            if not check_nested_equality(obj1[k], obj2[k], precision=precision, skip_args=skip_args):
                return f"{obj1[k]}!={obj2[k]}\nobjects for key {k} of {obj1} and {obj2} do not match"
        return None

    # no container: compare the objects themselves, precision is only used for arrays and numbers
    tolerant = precision is not None
    try:
        if isinstance(obj1, np.ndarray):
            if tolerant:
                np.testing.assert_array_almost_equal(obj1, obj2, decimal=precision)
            else:
                np.testing.assert_array_equal(obj1, obj2)
        elif isinstance(obj1, (int, float)) and isinstance(obj2, (int, float)):
            if tolerant:
                np.testing.assert_almost_equal(obj1, obj2, decimal=precision)
            else:
                np.testing.assert_equal(obj1, obj2)
        elif isinstance(obj1, xr.DataArray):
            if tolerant:
                xr.testing.assert_allclose(obj1, obj2, atol=10 ** (-precision))
            else:
                xr.testing.assert_equal(obj1, obj2)
        else:
            tolerant = False  # all other objects are always compared exactly
            if not obj1 == obj2:
                raise AssertionError
    except AssertionError:
        precision_info = f" with precision {precision}" if tolerant else ""
        return f"{obj1}!={obj2}{precision_info}\n{obj1} and {obj2} do not match"
    return None


def check_nested_equality(obj1, obj2, precision=None, skip_args=None):
    """
    Check for equality in nested structures. Use precision to indicate number of decimals to check for consistency.

    Both objects must be of exactly the same type. Tuples, lists and dicts are compared element by element, numpy
    arrays, xarray DataArrays and numbers (int, float) are compared exactly or up to the given precision, and all
    other objects are compared with ``==``. A description of each mismatch is written to the log on info level.

    The checks of this function are implemented with explicit conditions and not with assert statements, so they
    are not removed when python is started with -O.

    :param obj1: first object
    :param obj2: second object
    :param precision: number of decimals to check for numbers and arrays, None (default) for an exact comparison
    :param skip_args: dict keys to ignore on comparison, only taken into account if given as str or list
    :return: True if both objects are equal, False otherwise
    :raises AssertionError: if precision is neither None nor an integer
    """
    if not (precision is None or isinstance(precision, int)):
        # keep raising AssertionError as callers may rely on this exception type
        raise AssertionError(f"precision must be None or an integer but is of type {type(precision)}")
    mismatch = _describe_nested_mismatch(obj1, obj2, precision, skip_args)
    if mismatch is not None:
        logging.info(mismatch)
        return False
    return True