Coverage for mlair/data_handler/data_handler_neighbors.py: 100%
4 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
1__author__ = 'Lukas Leufen'
2__date__ = '2020-07-17'
4"""
5WARNING: This data handler is just a prototype and has not been validated to work properly! Use it with caution!
6"""
8import datetime as dt
10import numpy as np
11import pandas as pd
12import xarray as xr
14from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
15from mlair.helpers import to_list
16from mlair.data_handler import DefaultDataHandler, AbstractDataHandler
17import os
18import copy
20from typing import Union, List
# Scalar numeric value: a float or an int.
number = Union[float, int]
# Either a single scalar number or a list of them (used for ``extreme_values``).
num_or_list = Union[float, int, List[number]]
class DataHandlerNeighbors(DefaultDataHandler):  # pragma: no cover
    """Data handler including neighboring stations.

    Wraps the data handler of a central station (``id_class``) together with a list of neighbor
    data handlers. The central station followed by all neighbors forms the collection returned by
    ``_create_collection``.
    """

    def __init__(self, id_class, data_path, neighbors=None, min_length=0,
                 extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False):
        """Store the neighbors and initialise the parent handler.

        :param id_class: data handler of the central station
        :param data_path: forwarded to ``DefaultDataHandler.__init__``
        :param neighbors: a single neighbor data handler, a list of them, or None for no neighbors
        :param min_length: forwarded to ``DefaultDataHandler.__init__``
        :param extreme_values: forwarded to ``DefaultDataHandler.__init__``
        :param extremes_on_right_tail_only: forwarded to ``DefaultDataHandler.__init__``
        """
        # Assigned before the parent constructor runs -- presumably because the parent's __init__
        # already calls _create_collection(), which reads self.neighbors. TODO confirm.
        self.neighbors = to_list(neighbors) if neighbors is not None else []
        super().__init__(id_class, data_path, min_length=min_length, extreme_values=extreme_values,
                         extremes_on_right_tail_only=extremes_on_right_tail_only)

    @classmethod
    def build(cls, station, **kwargs):
        """Create handlers for the central station and every neighbor, then wrap them in ``cls``.

        :param station: id of the central station
        :param kwargs: keys listed in ``cls._requirements`` are forwarded to ``cls.data_handler`` for
            every station; ``neighbors`` holds the neighbor station id(s) (single id, list, or None);
            keys returned by ``cls.own_args("id_class")`` are forwarded to ``cls.__init__``
        :return: new instance of ``cls``
        """
        def _station_kwargs():
            # A fresh deep copy per station, so no two handlers share (and mutate) the same
            # argument objects, and the caller's kwargs stay untouched.
            return {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}

        sp = cls.data_handler(station, **_station_kwargs())

        # Accept None or a single station id as well as a list of ids.
        neighbor_ids = kwargs.get("neighbors")
        neighbor_ids = to_list(neighbor_ids) if neighbor_ids is not None else []
        n_list = [cls.data_handler(neighbor, **_station_kwargs()) for neighbor in neighbor_ids]

        # Replace the station ids by the created handlers before forwarding to __init__.
        kwargs["neighbors"] = n_list if len(n_list) > 0 else None
        dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
        return cls(sp, **dp_args)

    def _create_collection(self):
        """Return the central station handler followed by all neighbor handlers."""
        return [self.id_class] + self.neighbors

    def get_coordinates(self, include_neighbors=False):
        """Return the central station's coordinates, optionally followed by the neighbors' coordinates.

        :param include_neighbors: if truthy, append ``get_coordinates()`` of every neighbor
        :return: list whose first element is ``DefaultDataHandler.get_coordinates()`` of this
            handler, followed by one entry per neighbor when requested
        """
        # The previous implementation returned the result of list.append(), i.e. always None.
        coords = [super().get_coordinates()]
        if include_neighbors:
            coords.extend(n.get_coordinates() for n in self.neighbors)
        return coords
def run_data_prep():  # pragma: no cover
    """Exercise the dummy handlers once; a starting point for writing real test routines.

    First a standalone ``DummyDataHandler`` is created and its inputs and targets are read.
    Then a fresh dummy central handler and two dummy neighbors are wrapped in a
    ``DataHandlerNeighbors``, and ``get_data(upsampling=False)`` is called on it.
    Nothing is returned.
    """
    standalone = DummyDataHandler("main_class")
    for getter in (standalone.get_X, standalone.get_Y):
        getter()

    testdata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    central = DummyDataHandler("main_class")
    neighbor_handlers = [DummyDataHandler(label) for label in ("neighbor1", "neighbor2")]
    wrapped = DataHandlerNeighbors(central, testdata_dir, neighbors=neighbor_handlers,
                                   extreme_values=[1., 1.2])
    wrapped.get_data(upsampling=False)
def create_data_prep():  # pragma: no cover
    """Build three neighbor handlers from the stations DEBW011, DEBW013 and DEBW034.

    Each single-station handler serves once as central station, with the other two as its
    neighbors. Intended as a starting point for writing real test routines.

    :return: list of three ``DataHandlerNeighbors`` centred on DEBW011, DEBW013 and DEBW034,
        in that order
    """
    testdata_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    # Positional arguments shared by every DataHandlerSingleStation below, in order:
    # station_type, network, sampling, target_dim, target_var, interpolation_dim,
    # window_history_size, window_lead_time.
    shared_args = (None, 'UBA', 'daily', 'variables', 'o3', 'datetime', 7, 3)

    def single_station(station_id, stats_per_var):
        # A new empty dict is passed as the fourth positional argument on every call.
        return DataHandlerSingleStation(station_id, testdata_dir, stats_per_var, {}, *shared_args)

    central_station = single_station("DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'})
    neighbor1 = single_station("DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'})
    neighbor2 = single_station("DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'})

    return [
        DataHandlerNeighbors(central_station, testdata_dir, neighbors=[neighbor1, neighbor2]),
        DataHandlerNeighbors(neighbor1, testdata_dir, neighbors=[central_station, neighbor2]),
        DataHandlerNeighbors(neighbor2, testdata_dir, neighbors=[neighbor1, central_station]),
    ]
class DummyDataHandler(AbstractDataHandler):  # pragma: no cover
    """Data handler that serves randomly generated inputs and targets."""

    def __init__(self, name, number_of_samples=None):
        """Generate random inputs and targets.

        :param name: label returned by ``str(handler)``
        :param number_of_samples: length of the ``datetime`` dimension; if None, it is drawn with
            ``np.random.randint(100, 150)``, i.e. an integer from 100 to 149
        """
        super().__init__()
        self.name = name
        if number_of_samples is None:
            number_of_samples = np.random.randint(100, 150)
        self.number_of_samples = number_of_samples
        self._X = self.create_X()
        self._Y = self.create_Y()

    def _hourly_datetimes(self):
        """Return ``number_of_samples`` hourly timestamps starting at today's (local) date."""
        start = dt.datetime.today().date()
        return pd.date_range(start, periods=self.number_of_samples, freq="H").tolist()

    def create_X(self):
        """Inputs: random integers in [0, 10) with dims (datetime, window=14, variables=5)."""
        values = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))
        coords = {"datetime": self._hourly_datetimes(), "window": range(14), "variables": range(5)}
        return xr.DataArray(values, dims=['datetime', 'window', 'variables'], coords=coords)

    def create_Y(self):
        """Targets: normal draws scaled by 0.5, rounded to one decimal, dims (datetime, window=5, variables=1)."""
        values = np.round(0.5 * np.random.randn(self.number_of_samples, 5, 1), 1)
        coords = {"datetime": self._hourly_datetimes(), "window": range(5), "variables": range(1)}
        return xr.DataArray(values, dims=['datetime', 'window', 'variables'], coords=coords)

    def get_X(self, upsampling=False, as_numpy=False):
        """Return the inputs; ``upsampling`` is ignored. A numpy copy is returned if ``as_numpy is True``."""
        if as_numpy is True:
            return np.copy(self._X)
        return self._X

    def get_Y(self, upsampling=False, as_numpy=False):
        """Return the targets; ``upsampling`` is ignored. A numpy copy is returned if ``as_numpy is True``."""
        if as_numpy is True:
            return np.copy(self._Y)
        return self._Y

    def __str__(self):
        return self.name
if __name__ == "__main__":
    # Ad-hoc manual check; parts of this may be reused for proper test routines.
    testdata_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    handler_cls = DataHandlerNeighbors
    requirements = handler_cls.requirements()

    build_kwargs = dict(path=testdata_path,
                        station_type=None,
                        network='UBA',
                        sampling='daily',
                        target_dim='variables',
                        target_var='o3',
                        time_dim='datetime',
                        window_history_size=7,
                        window_lead_time=3,
                        neighbors=["DEBW034"],
                        data_path=testdata_path,
                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
                        transformation=None)
    built_handler = handler_cls.build("DEBW011", **build_kwargs)
    print(built_handler)

    from mlair.data_handler.iterator import KerasIterator, DataCollection

    data_collection = DataCollection(create_data_prep())
    for entry in data_collection:
        print(entry)
    keras_path = os.path.join(testdata_path, "keras")
    keras_it = KerasIterator(data_collection, 100, keras_path, upsampling=True)
    keras_it[2]