Coverage for mlair/data_handler/abstract_data_handler.py: 88%

49 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:40 +0000

1 

2__author__ = 'Lukas Leufen' 

3__date__ = '2020-09-21' 

4 

5import inspect 

6from typing import Union, Dict 

7 

8from mlair.helpers import remove_items, to_list 

9 

10 

class AbstractDataHandler(object):
    """
    Base class that defines the interface of a data handler.

    Subclasses are expected to override the methods that raise NotImplementedError (`__len__`,
    `apply_transformation`, `get_X`, `get_Y`) and may override the class variables below.
    """

    # Names of additional arguments that are reported by `requirements` on top of the constructor arguments.
    _requirements = []
    # Names of instance attributes that are reported by `store_attributes` / `get_store_attributes`.
    _store_attributes = []
    # Argument names that are always removed from the result of `requirements`.
    _skip_args = ["self"]

    def __init__(self, *args, **kwargs):
        """Accept arbitrary arguments and do nothing; subclasses implement the actual initialisation."""
        pass

    @classmethod
    def build(cls, *args, **kwargs):
        """Return initialised class."""
        return cls(*args, **kwargs)

    def __len__(self, upsampling=False):
        """Must be implemented by subclasses; the base class always raises NotImplementedError."""
        raise NotImplementedError

    @classmethod
    def requirements(cls, skip_args=None):
        """
        Return requirements and own arguments without duplicates.

        :param skip_args: additional argument name(s) to exclude from the result on top of `cls._skip_args`;
            the value is passed through `to_list` before it is appended
        :return: list of argument names (built from a set, so the order is not defined)
        """
        skip_args = cls._skip_args if skip_args is None else cls._skip_args + to_list(skip_args)
        return remove_items(list(set(cls._requirements + cls.own_args())), skip_args)

    @classmethod
    def own_args(cls, *args):
        """
        Return all arguments (including kwonlyargs).

        The arguments of the class constructor are read with `inspect.getfullargspec` and merged with the
        arguments collected by `super_args`.

        :param args: argument names that are removed from the result
        :return: list of argument names without duplicates (built from a set, so the order is not defined)
        """
        arg_spec = inspect.getfullargspec(cls)
        list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args()
        return list(set(remove_items(list_of_args, list(args))))

    @classmethod
    def super_args(cls):
        """
        Return the arguments of all other classes in the method resolution order that provide `own_args`.

        :return: list of argument names without duplicates (built from a set, so the order is not defined)
        """
        args = []
        for super_cls in cls.__mro__:
            # skip the class itself, only the other MRO entries are of interest here
            if super_cls is cls:
                continue
            if hasattr(super_cls, "own_args"):
                args.extend(super_cls.own_args())
        return list(set(args))

    @classmethod
    def store_attributes(cls) -> list:
        """
        Let MLAir know that some data should be stored in the data store. This is used for calculations on the train
        subset that should be applied to validation and test subset.

        To work properly, add a class variable cls._store_attributes to your data handler. If your custom data handler
        is constructed on different data handlers (e.g. like the DefaultDataHandler), it is required to overwrite the
        get_store_attributes method in addition, to return attributes from the corresponding subclasses. This is not
        required, if only attributes from the main class are to be returned.

        Note, that MLAir will store these attributes with the data handler's identification. This depends on the custom
        data handler setting. When loading an attribute from the data handler, it is therefore required to extract the
        right information by using the class identification. In case of the DefaultDataHandler this can be achieved by
        converting all keys of the attribute to string and comparing these with the station parameter.
        """
        return list(set(cls._store_attributes))

    def get_store_attributes(self):
        """Returns all attribute names and values that are indicated by the store_attributes method."""
        # getattr (instead of a direct __getattribute__ call) also honours a __getattr__ fallback of subclasses
        return {attr: getattr(self, attr) for attr in self.store_attributes()}

    @classmethod
    def transformation(cls, *args, **kwargs):
        """Return None in the base class; all arguments are ignored."""
        return None

    def apply_transformation(self, data, inverse=False, **kwargs):
        """
        This method must return transformed data. The flag inverse can be used to trigger either transformation or its
        inverse method.
        """
        raise NotImplementedError

    def get_X(self, upsampling=False, as_numpy=False):
        """Must be implemented by subclasses; the base class always raises NotImplementedError."""
        raise NotImplementedError

    def get_Y(self, upsampling=False, as_numpy=False):
        """Must be implemented by subclasses; the base class always raises NotImplementedError."""
        raise NotImplementedError

    def get_data(self, upsampling=False, as_numpy=False):
        """Return the tuple (get_X(...), get_Y(...)), both called with the given upsampling and as_numpy values."""
        return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)

    def get_coordinates(self) -> Union[None, Dict]:
        """Return coordinates as dictionary with keys `lon` and `lat`."""
        return None

    def _hash_list(self):
        """Return an empty list in the base class."""
        return []