Coverage for mlair/model_modules/linear_model.py: 33%
38 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
"""Calculate an ordinary least squares (OLS) model using the statsmodels package."""

__author__ = "Felix Kleinert, Lukas Leufen"
__date__ = '2019-12-11'
6import numpy as np
7import statsmodels.api as sm
class OrdinaryLeastSquaredModel:
    """
    Implementation of an ordinary least squares (OLS) model.

    Inputs and outputs are retrieved from an iterable of data objects. Each element must provide a method
    ``get_data(as_numpy=True)`` returning a tuple ``(x, y)``, where ``x`` and ``y`` are sequences of numpy arrays whose
    first axis is the sample axis. The OLS model is fitted on initialisation using the statsmodels package. Train your
    personal OLS using:

    .. code-block:: python

        # every element of train_data must provide get_data(as_numpy=True) -> (x, y)
        my_ols_model = OrdinaryLeastSquaredModel(train_data)

    After calculation, use your OLS model with

    .. code-block:: python

        # input_data needs to be structured like the x part of the train data
        result_ols = my_ols_model.predict(input_data)

    :param generator: iterable of data objects whose ``get_data(as_numpy=True)`` returns a tuple of inputs and outputs
    """

    def __init__(self, generator):
        """Set up the OLS model and fit it immediately on all data provided by ``generator``."""
        self.x = []
        self.y = []
        self.generator = generator
        self.model = self._train_ols_model_from_generator()

    def _train_ols_model_from_generator(self):
        """Collect the training data, prepend a constant (intercept) column and fit the OLS model."""
        self._set_x_y_from_generator()
        # Always add the intercept column, exactly like predict() does. With statsmodels' default
        # has_constant="skip" no intercept would be added if x already contained a constant column, while
        # predict() would still add one, leaving the design matrices one column apart.
        self.x = sm.add_constant(self.x, has_constant="add")
        return self.ordinary_least_squared_model(self.x, self.y)

    def _set_x_y_from_generator(self):
        """
        Gather inputs and targets from all elements of the generator.

        Every input branch is flattened to 2D (samples, features) and the samples of all elements are stacked along
        axis 0. Afterwards all input branches are joined along the feature axis into ``self.x``. Only the first target
        branch is used and stored as ``self.y``.

        :raises ValueError: if the generator does not yield any element
        """
        x_items, y_items = [], []
        for item in self.generator:
            x, y = item.get_data(as_numpy=True)
            x_items.append(self.flatten(x))
            y_items.append(y)
        if not x_items:
            raise ValueError("Cannot train OLS model: the generator did not yield any data.")
        # Concatenate once per branch at the end instead of re-concatenating the accumulated data for every item
        # (which would copy the growing arrays over and over again).
        x_branches = [np.concatenate(branch, axis=0) for branch in zip(*x_items)]
        self.x = np.concatenate(x_branches, axis=1)
        self.y = np.concatenate([y[0] for y in y_items], axis=0)

    def _concatenate(self, new, old):
        """
        Append each array of ``new`` to its counterpart in ``old`` along axis 0.

        :param new: sequence of arrays to append
        :param old: sequence of already collected arrays, or None if nothing was collected yet
        :return: ``new`` itself if ``old`` is None, otherwise a list of pairwise concatenated arrays (pairs are formed
            up to the shorter of both sequences)
        """
        if old is None:
            return new
        return [np.concatenate((previous, current), axis=0) for previous, current in zip(old, new)]

    def predict(self, data):
        """
        Apply the fitted OLS model on data.

        :param data: sequence of numpy arrays structured like the inputs used for training (samples on axis 0)
        :return: model predictions as an at least 2D numpy array
        """
        features = np.concatenate(self.flatten(data), axis=1)
        # "add" forces the intercept column even if the data already contains a constant column, so the column
        # layout always matches the one used for training.
        features = sm.add_constant(features, has_constant="add")
        return np.atleast_2d(self.model.predict(features))

    @staticmethod
    def flatten(data):
        """
        Reshape every array of ``data`` to 2D.

        :param data: sequence of numpy arrays
        :return: list of arrays, each reshaped to (size of axis 0, product of all remaining axes)
        """
        return [array.reshape(array.shape[0], -1) for array in data]

    @staticmethod
    def reshape_xarray_to_numpy(data):
        """
        Reshape xarray data to numpy data and flatten all axes except the first one.

        For 4D data with a singleton third axis this is identical to the former ``(shape[0], shape[1] * shape[3])``
        layout; data of other shapes is flattened as well instead of failing.

        :param data: object exposing its array through ``.values`` (e.g. an xarray DataArray)
        :return: 2D numpy array of shape (shape[0], product of all remaining axes)
        """
        values = data.values
        return values.reshape(values.shape[0], -1)

    @staticmethod
    def ordinary_least_squared_model(x, y):
        """
        Fit an OLS model using statsmodels.

        :param x: 2D design matrix (samples, features), including an intercept column if one is wanted
        :param y: target values with samples on axis 0
        :return: fitted statsmodels regression results
        """
        ols_model = sm.OLS(y, x)
        return ols_model.fit()