Coverage for mlair/helpers/testing.py: 96%

91 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:40 +0000

1"""Helper functions that are used to simplify testing.""" 

2import logging 

3import re 

4from typing import Union, Pattern, List 

5import inspect 

6 

7import numpy as np 

8import xarray as xr 

9 

10from mlair.helpers.helpers import remove_items, to_list 

11 

12 

class PyTestRegex:
    r"""
    Assert that a given string meets some expectations.

    The pattern is applied with ``re.match``, i.e. it has to match at the beginning of the compared string but
    does not need to consume all of it. Comparing against an object the pattern cannot be applied to (e.g.
    ``None`` or a number) evaluates to ``False`` instead of raising ``TypeError``.

    Use like

    >>> PyTestRegex(r"TestString\d+") == "TestString"
    False
    >>> PyTestRegex(r"TestString\d+") == "TestString2"
    True

    :param pattern: pattern or string to use for regular expression
    :param flags: python re flags
    """

    def __init__(self, pattern: Union[str, Pattern], flags: int = 0):
        """Construct PyTestRegex."""
        self._regex = re.compile(pattern, flags)

    def __eq__(self, actual: object) -> bool:
        """Return whether regex matches given string actual or not."""
        try:
            return bool(self._regex.match(actual))
        except TypeError:
            # actual is not a string (or not the str/bytes kind of the pattern): let Python fall back to its
            # default comparison, which yields False for == and True for != instead of crashing.
            return NotImplemented

    def __repr__(self) -> str:
        """Show regex pattern."""
        # str() keeps str patterns unchanged and turns bytes patterns into a valid repr string.
        return str(self._regex.pattern)

40 

41 

def PyTestAllEqual(check_list: List) -> bool:
    """
    Check if all elements in list are the same.

    Every element is compared against the first one. The comparison is selected from the type of the first
    element: nested lists are compared element-wise (their lengths must match) with the comparison selected for
    their own first element, numpy arrays use ``np.testing.assert_array_equal``, xarray DataArrays use
    ``xr.testing.assert_equal`` and everything else is compared with ``==``.

    :param check_list: list with elements to check
    :return: True if all elements are equal (an empty list counts as all equal)
    :raises AssertionError: if any element differs from the first one
    """

    def _assert(x, y):
        # Explicit raise instead of a bare ``assert`` so the check is not stripped under ``python -O``.
        if not x == y:
            raise AssertionError(f"{x!r} != {y!r}")

    def _select_test_function(sample: List):
        """Return a callable that raises AssertionError if its two arguments differ and returns None otherwise."""
        if len(sample) == 0:
            # Nothing to inspect (e.g. an empty nested list): fall back to plain equality.
            return _assert
        first = sample[0]
        if isinstance(first, list):
            inner = _select_test_function(first)

            def _nested(r, s):
                # Compare lengths first: zip/map alone would silently ignore trailing elements.
                if len(r) != len(s):
                    raise AssertionError(f"lengths {len(r)} and {len(s)} do not match")
                for x, y in zip(r, s):
                    inner(x, y)

            return _nested
        if isinstance(first, np.ndarray):
            return np.testing.assert_array_equal
        if isinstance(first, xr.DataArray):
            return xr.testing.assert_equal
        return _assert

    if len(check_list) == 0:
        return True
    test_function = _select_test_function(check_list)
    reference = check_list[0]
    for element in check_list:
        # Every comparator raises AssertionError on mismatch, so reaching the end means all are equal.
        test_function(reference, element)
    return True

93 

94 

def get_all_args(*args, remove=None, add=None):
    """
    Collect the argument names of all given callables.

    Positional (``args``) and keyword-only argument names reported by ``inspect.getfullargspec`` are gathered
    from every callable, de-duplicated and sorted alphabetically.

    :param args: callables to inspect
    :param remove: name(s) to drop from the sorted result (applied via ``remove_items``)
    :param add: name(s) appended to the end of the result after removal (neither sorted nor de-duplicated)
    :return: list of argument names
    """
    collected = set()
    for func in args:
        spec = inspect.getfullargspec(func)
        collected.update(spec.args)
        collected.update(spec.kwonlyargs)
    names = sorted(collected)
    if remove is not None:
        names = remove_items(names, remove)
    if add is not None:
        names += to_list(add)
    return names

107 

108 

def check_nested_equality(obj1, obj2, precision=None, skip_args=None):
    """
    Check for equality in nested structures.

    Tuples/lists are compared position by position and dicts key by key, recursing into the items. Both objects
    must have exactly the same type at every level (e.g. ``1`` and ``1.0`` are not equal). Leaves are compared
    with ``np.testing`` for numpy arrays and numbers (python ints/floats and numpy scalars), with ``xr.testing``
    for xarray DataArrays, and with ``==`` otherwise. A failure message is only formatted once a mismatch is
    found; it is logged with ``logging.info`` at every nesting level on the way up.

    :param obj1: first object to compare
    :param obj2: second object to compare
    :param precision: number of decimals to check for consistency of arrays, DataArrays and numbers; None
        requires exact equality. Objects compared with ``==`` ignore it.
    :param skip_args: dict key (str) or list/tuple/set of dict keys ignored at every nesting level; a value of
        any other type has no effect
    :return: True if both objects are equal, False otherwise
    :raises AssertionError: if precision is neither None nor an int
    """
    if precision is not None and not isinstance(precision, int):
        # Raised explicitly (not via ``assert``) so the validation also runs under ``python -O``.
        raise AssertionError(f"precision must be None or an int, got {precision!r}")
    message = _nested_inequality_message(obj1, obj2, precision, skip_args)
    if message is None:
        return True
    logging.info(message)
    return False


def _skipped_keys(skip_args) -> set:
    """Return the set of dict keys to ignore for a ``skip_args`` value of check_nested_equality."""
    if isinstance(skip_args, str):
        return {skip_args}
    if isinstance(skip_args, (list, tuple, set, frozenset)):
        return set(skip_args)
    return set()


def _display_order(keys) -> list:
    """Return keys sorted for log messages, ordering by repr if the keys cannot be compared with each other."""
    try:
        return sorted(keys)
    except TypeError:
        return sorted(keys, key=repr)


def _nested_inequality_message(obj1, obj2, precision, skip_args):
    """Return a message describing why obj1 and obj2 differ, or None if they are equal."""
    if type(obj1) is not type(obj2):
        return f"{type(obj1)}!={type(obj2)}\n{obj1} and {obj2} do not match"

    if isinstance(obj1, (tuple, list)):
        if len(obj1) != len(obj2):
            return f"{len(obj1)}!={len(obj2)}\nlengths of {obj1} and {obj2} do not match"
        for pos, (item1, item2) in enumerate(zip(obj1, obj2)):
            if not check_nested_equality(item1, item2, precision=precision, skip_args=skip_args):
                return f"{item1}!={item2}\nobjects on pos {pos} of {obj1} and {obj2} do not match"
        return None

    if isinstance(obj1, dict):
        skipped = _skipped_keys(skip_args)
        keys1 = [k for k in obj1 if k not in skipped]
        keys2 = [k for k in obj2 if k not in skipped]
        # Compare as sets: dict keys are unique, and sets also work for keys of mixed, unorderable types.
        if set(keys1) != set(keys2):
            return f"{_display_order(keys1)}!={_display_order(keys2)}\n" \
                   f"{set(keys1).symmetric_difference(keys2)} are not in both sorted key lists"
        for k in keys1:
            if not check_nested_equality(obj1[k], obj2[k], precision=precision, skip_args=skip_args):
                return f"{obj1[k]}!={obj2[k]}\nobjects for key {k} of {obj1} and {obj2} do not match"
        return None

    # Leaves: the np/xr testing helpers raise AssertionError on mismatch (also under ``python -O``).
    try:
        if isinstance(obj1, np.ndarray):
            if precision is None:
                np.testing.assert_array_equal(obj1, obj2)
            else:
                np.testing.assert_array_almost_equal(obj1, obj2, decimal=precision)
        elif isinstance(obj1, (int, float, np.number)):
            if precision is None:
                np.testing.assert_equal(obj1, obj2)
            else:
                np.testing.assert_almost_equal(obj1, obj2, decimal=precision)
        elif isinstance(obj1, xr.DataArray):
            if precision is None:
                xr.testing.assert_equal(obj1, obj2)
            else:
                xr.testing.assert_allclose(obj1, obj2, atol=10 ** (-precision))
        elif not obj1 == obj2:
            # Plain objects are always compared exactly; precision does not apply to them.
            return f"{obj1}!={obj2}\n{obj1} and {obj2} do not match"
    except AssertionError:
        precision_note = "" if precision is None else f" with precision {precision}"
        return f"{obj1}!={obj2}{precision_note}\n{obj1} and {obj2} do not match"
    return None