Coverage for mlair/data_handler/data_handler_neighbors.py: 100%

4 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

1__author__ = 'Lukas Leufen' 

2__date__ = '2020-07-17' 

3 

4""" 

5WARNING: This data handler is just a prototype and has not been validated to work properly! Use it with caution! 

6""" 

7 

8import datetime as dt 

9 

10import numpy as np 

11import pandas as pd 

12import xarray as xr 

13 

14from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation 

15from mlair.helpers import to_list 

16from mlair.data_handler import DefaultDataHandler, AbstractDataHandler 

17import os 

18import copy 

19 

20from typing import Union, List 

21 

# Scalar numeric value (float or int).
number = Union[float, int]
# Either a scalar number or a list of numbers; typing flattens the nested Union, so this equals
# Union[number, List[number]].
num_or_list = Union[float, int, List[number]]

24 

25 

class DataHandlerNeighbors(DefaultDataHandler):  # pragma: no cover
    """Data handler including neighboring stations.

    Wraps the data handler of a central station (``id_class``) together with a list of data handlers of
    neighboring stations. The collection of this handler is the central station followed by all neighbors.
    """

    def __init__(self, id_class, data_path, neighbors=None, min_length=0,
                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
        """
        :param id_class: data handler of the central station (forwarded to DefaultDataHandler)
        :param data_path: forwarded to DefaultDataHandler
        :param neighbors: a single neighbor data handler or a list of them; ``None`` means no neighbors
        :param min_length: forwarded to DefaultDataHandler
        :param extreme_values: forwarded to DefaultDataHandler
        :param extremes_on_right_tail_only: forwarded to DefaultDataHandler
        """
        # NOTE(review): assigned before super().__init__, presumably because the parent may already call
        # _create_collection during initialisation — confirm before reordering.
        self.neighbors = to_list(neighbors) if neighbors is not None else []
        super().__init__(id_class, data_path, min_length=min_length, extreme_values=extreme_values,
                         extremes_on_right_tail_only=extremes_on_right_tail_only)

    @classmethod
    def build(cls, station, **kwargs):
        """Create handlers for the central station and each neighbor station, then wrap them in this class.

        :param station: identifier of the central station
        :param kwargs: all parameters. Keys listed in ``cls._requirements`` are deep-copied once and passed to
            ``cls.data_handler`` for the central station and for every neighbor. ``neighbors`` may be a single
            station id, a list of station ids, or None. Keys matching this class' own arguments are deep-copied
            and passed to the constructor.
        :return: new instance of this class
        """
        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
        sp = cls.data_handler(station, **sp_keys)
        # Normalise like __init__ does: None -> no neighbors, a single id -> one-element list. Iterating a bare
        # string would otherwise build one handler per character.
        neighbor_ids = kwargs.get("neighbors")
        neighbor_ids = to_list(neighbor_ids) if neighbor_ids is not None else []
        n_list = [cls.data_handler(neighbor, **sp_keys) for neighbor in neighbor_ids]
        # Replace the station ids by their data handlers; None (not []) when there are no neighbors.
        kwargs["neighbors"] = n_list if len(n_list) > 0 else None
        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, **dp_args)

    def _create_collection(self):
        """Return the handlers of this collection: the central station first, then all neighbors."""
        return [self.id_class] + self.neighbors

    def get_coordinates(self, include_neighbors=False):
        """Return a list of coordinates, starting with the central station's.

        :param include_neighbors: if exactly ``True``, the result of ``get_coordinates()`` of every neighbor is
            appended, in neighbor order
        :return: list whose first element is the central station's coordinates (from the parent class),
            optionally followed by one entry per neighbor
        """
        coords = [super().get_coordinates()]
        if include_neighbors is True:
            coords.extend(n.get_coordinates() for n in self.neighbors)
        return coords

53 

54 

def run_data_prep():  # pragma: no cover
    """Smoke-run dummy handlers and a neighbor handler built from them.

    Kept as a starting point for writing meaningful test routines.
    """
    standalone = DummyDataHandler("main_class")
    standalone.get_X()
    standalone.get_Y()

    test_data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    central = DummyDataHandler("main_class")
    neighbor_handlers = [DummyDataHandler(label) for label in ("neighbor1", "neighbor2")]
    handler = DataHandlerNeighbors(central, test_data_dir, neighbors=neighbor_handlers, extreme_values=[1., 1.2])
    handler.get_data(upsampling=False)

68 

69 

def create_data_prep():  # pragma: no cover
    """Build three neighbor handlers over the stations DEBW011, DEBW013 and DEBW034.

    Each handler uses one of the stations as central station and the other two as neighbors.
    Kept as a starting point for writing meaningful test routines.

    :return: list of three DataHandlerNeighbors instances
    """
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")

    station_type, network, sampling = None, 'UBA', 'daily'
    target_dim, target_var, interpolation_dim = 'variables', 'o3', 'datetime'
    window_history_size, window_lead_time = 7, 3
    # Positional arguments shared by every single-station handler (they follow the per-station arguments).
    shared_args = (station_type, network, sampling, target_dim, target_var, interpolation_dim,
                   window_history_size, window_lead_time)

    def single_station(station_id, stats):
        """Create one single-station handler with the shared configuration."""
        return DataHandlerSingleStation(station_id, path, stats, {}, *shared_args)

    central_station = single_station("DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'})
    neighbor1 = single_station("DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'})
    neighbor2 = single_station("DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'})

    layouts = [(central_station, [neighbor1, neighbor2]),
               (neighbor1, [central_station, neighbor2]),
               (neighbor2, [neighbor1, central_station])]
    return [DataHandlerNeighbors(center, path, neighbors=others) for center, others in layouts]

96 

97 

class DummyDataHandler(AbstractDataHandler):  # pragma: no cover
    """Data handler filled with random inputs and targets on an hourly datetime axis starting today."""

    def __init__(self, name, number_of_samples=None):
        """Create a dummy handler called ``name`` holding ``number_of_samples`` random samples.

        If ``number_of_samples`` is None, a random integer in [100, 150) is drawn (``np.random.randint``
        excludes its upper bound).
        """
        super().__init__()
        self.name = name
        if number_of_samples is None:
            number_of_samples = np.random.randint(100, 150)
        self.number_of_samples = number_of_samples
        self._X = self.create_X()
        self._Y = self.create_Y()

    def _as_data_array(self, values, window_size, n_variables):
        """Wrap a (samples, window, variables) array with an hourly datetime coordinate starting today."""
        start = dt.datetime.today().date()
        datelist = pd.date_range(start, periods=self.number_of_samples, freq="H").tolist()
        coords = {"datetime": datelist, "window": range(window_size), "variables": range(n_variables)}
        return xr.DataArray(values, dims=['datetime', 'window', 'variables'], coords=coords)

    def create_X(self):
        """Return random integer inputs in [0, 10) with shape (no_samples, window=14, variables=5)."""
        window_size, n_variables = 14, 5
        values = np.random.randint(0, 10, size=(self.number_of_samples, window_size, n_variables))
        return self._as_data_array(values, window_size, n_variables)

    def create_Y(self):
        """Return normally distributed targets (std 0.5, rounded to one decimal), shape (no_samples, 5, 1)."""
        window_size, n_variables = 5, 1
        values = np.round(0.5 * np.random.randn(self.number_of_samples, window_size, n_variables), 1)
        return self._as_data_array(values, window_size, n_variables)

    def get_X(self, upsampling=False, as_numpy=False):
        """Return the inputs; ``upsampling`` is ignored, ``as_numpy is True`` yields a numpy copy."""
        if as_numpy is True:
            return np.copy(self._X)
        return self._X

    def get_Y(self, upsampling=False, as_numpy=False):
        """Return the targets; ``upsampling`` is ignored, ``as_numpy is True`` yields a numpy copy."""
        if as_numpy is True:
            return np.copy(self._Y)
        return self._Y

    def __str__(self):
        """Use the handler's name as its string form."""
        return self.name

135 

136 

if __name__ == "__main__":
    # Manual test run; parts of this may be reused for the testing routines.
    handler_cls = DataHandlerNeighbors
    requirements = handler_cls.requirements()

    testdata_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    build_kwargs = {
        "path": testdata_path,
        "station_type": None,
        "network": 'UBA',
        "sampling": 'daily',
        "target_dim": 'variables',
        "target_var": 'o3',
        "time_dim": 'datetime',
        "window_history_size": 7,
        "window_lead_time": 3,
        "neighbors": ["DEBW034"],
        "data_path": testdata_path,
        "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'},
        "transformation": None,
    }
    built = handler_cls.build("DEBW011", **build_kwargs)
    print(built)

    from mlair.data_handler.iterator import KerasIterator, DataCollection

    collection = DataCollection(create_data_prep())
    for item in collection:
        print(item)
    keras_dir = os.path.join(testdata_path, "keras")
    keras_it = KerasIterator(collection, 100, keras_dir, upsampling=True)
    keras_it[2]