Coverage for mlair/model_modules/linear_model.py: 33%

38 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

1"""Calculate ordinary least squared model.""" 

2 

3__author__ = "Felix Kleinert, Lukas Leufen" 

4__date__ = '2019-12-11' 

5 

6import numpy as np 

7import statsmodels.api as sm 

8 

9 

class OrdinaryLeastSquaredModel:
    """
    Implementation of an ordinary least squares (OLS) model.

    Inputs and outputs are collected from an iterable ``generator``. Every item yielded by the generator has to
    provide ``get_data(as_numpy=True)`` returning a pair ``(x, y)``, where ``x`` and ``y`` are sequences of arrays.
    The OLS model is fitted on initialisation using the statsmodels package. Train your personal OLS using:

    .. code-block:: python

        # each item of train_data must support item.get_data(as_numpy=True) -> (x, y)
        my_ols_model = OrdinaryLeastSquaredModel(train_data)

    After calculation, use your OLS model with

    .. code-block:: python

        # input_data needs to be structured like the x part of the train data
        result_ols = my_ols_model.predict(input_data)

    :param generator: iterable whose items return inputs and outputs via ``get_data(as_numpy=True)``
    """

    def __init__(self, generator):
        """Collect all data from the generator and fit the OLS model."""
        self.x = []
        self.y = []
        self.generator = generator
        self.model = self._train_ols_model_from_generator()

    def _train_ols_model_from_generator(self):
        """
        Load the training data, prepend the intercept column and fit the model.

        :return: fitted statsmodels OLS results object
        """
        self._set_x_y_from_generator()
        # Use has_constant="add" exactly like predict() does. With the default ("skip") statsmodels leaves the
        # design matrix untouched if it already contains a constant column, so the fitted model would have one
        # parameter less than the matrix built in predict() and prediction would fail on a shape mismatch.
        self.x = sm.add_constant(self.x, has_constant="add")
        return self.ordinary_least_squared_model(self.x, self.y)

    def _set_x_y_from_generator(self):
        """
        Gather inputs and outputs of all generator items and store them in ``self.x`` and ``self.y``.

        Every input array is flattened to 2D. Arrays of consecutive items are stacked along axis 0, afterwards all
        inputs are joined along axis 1. Only the first output array is used as target.

        :raises ValueError: if the generator does not yield any item
        """
        data_x, data_y = None, None
        for item in self.generator:
            x, y = item.get_data(as_numpy=True)
            data_x = self._concatenate(self.flatten(x), data_x)
            data_y = self._concatenate(y, data_y)
        if data_x is None:
            # Fail with a clear message instead of the obscure error np.concatenate(None) would produce.
            raise ValueError("Cannot train OLS model: generator did not yield any data.")
        self.x, self.y = np.concatenate(data_x, axis=1), data_y[0]

    def _concatenate(self, new, old):
        """
        Append each array in ``new`` to its counterpart in ``old`` along axis 0.

        :param new: sequence of arrays of the current item
        :param old: sequence of arrays accumulated so far, or None on the first call
        :return: ``new`` itself if ``old`` is None, otherwise a list of the pairwise concatenated arrays
        """
        if old is None:
            return new
        return [np.concatenate((previous, current), axis=0) for previous, current in zip(old, new)]

    def predict(self, data):
        """
        Apply OLS model on data.

        :param data: sequence of input arrays, structured like the inputs used for training
        :return: model prediction as array with at least two dimensions
        """
        features = np.concatenate(self.flatten(data), axis=1)
        data = sm.add_constant(features, has_constant="add")
        return np.atleast_2d(self.model.predict(data))

    @staticmethod
    def flatten(data):
        """
        Reshape every array in data to 2D, keeping the first axis and merging all remaining axes.

        :param data: iterable of arrays
        :return: list of 2D arrays
        """
        return [array.reshape(array.shape[0], -1) for array in data]

    @staticmethod
    def reshape_xarray_to_numpy(data):
        """
        Reshape xarray data to numpy data and flatten.

        The first axis is kept, all remaining axes are merged into one. For the formerly required layout (4D with
        an axis 2 of length 1) the result is unchanged.

        :param data: object exposing its numpy array as ``data.values``
        :return: 2D numpy array
        """
        values = data.values
        return values.reshape(values.shape[0], -1)

    @staticmethod
    def ordinary_least_squared_model(x, y):
        """
        Calculate ols model using statsmodels.

        :param x: design matrix (exogenous variables)
        :param y: target values (endogenous variable)
        :return: fitted statsmodels OLS results object
        """
        ols_model = sm.OLS(y, x)
        return ols_model.fit()

77 return ols_model.fit()