Coverage for mlair/model_modules/loss.py: 78%

15 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-01 13:03 +0000

1"""Collection of different customised loss functions.""" 

2 

3from tensorflow.keras import backend as K 

4 

5from typing import Callable 

6 

7 

def l_p_loss(power: int) -> Callable:
    """
    Create an L<p> loss function for the given exponent.

    The returned function raises the absolute error ``|y_pred - y_true|`` to ``power`` and averages the result over
    the last axis. With power 1 this is the mean absolute error (MAE), with power 2 the mean squared error (MSE), ...

    :param power: exponent applied to the absolute error

    :return: loss function taking ``(y_true, y_pred)``
    """

    def l_p_loss(y_true, y_pred):
        # The inner name deliberately matches the outer one so the returned function's __name__ stays
        # "l_p_loss" (it may be used when the loss is named or serialised).
        absolute_error = K.abs(y_pred - y_true)
        return K.mean(K.pow(absolute_error, power), axis=-1)

    return l_p_loss

23 

24 

def var_loss(y_true, y_pred):
    """
    Calculate the squared difference between the variance of ``y_true`` and the variance of ``y_pred``.

    Only the spread of the two inputs is compared, not the point-wise error between them.

    :param y_true: true values
    :param y_pred: predicted values

    :return: loss value (a tensor computed with the Keras backend, not a callable)
    """
    # K.var is called without an axis, so (per the Keras backend API) it reduces over all elements of each input
    # and yields a scalar; the outer K.mean therefore leaves that scalar unchanged.
    return K.mean(K.square(K.var(y_true) - K.var(y_pred)))

27 

28 

def custom_loss(loss_list, loss_weights=None) -> Callable:
    """
    Combine several loss functions into a single weighted-sum loss.

    The weights are normalised to sum to one, so only their ratios matter. If no weights are given, all losses are
    weighted equally.

    :param loss_list: sequence of loss functions, each called as ``loss_fn(y_true, y_pred)``
    :param loss_weights: optional sequence of weights, one per entry of ``loss_list``

    :raises AssertionError: if ``loss_weights`` does not have the same length as ``loss_list``

    :return: loss function returning the weighted sum of all losses in ``loss_list``
    """
    n = len(loss_list)
    if loss_weights is None:
        loss_weights = [1. / n for _ in range(n)]
    else:
        # Explicit check instead of ``assert`` so the validation also runs under ``python -O``; AssertionError is
        # kept so existing callers that catch it keep working.
        if len(loss_weights) != n:
            raise AssertionError(
                f"loss_weights must have the same length as loss_list ({len(loss_weights)} != {n})")
        # Normalise once. Weights summing to zero cannot be normalised (plain Python numbers raise
        # ZeroDivisionError here).
        total = sum(loss_weights)
        loss_weights = [w / total for w in loss_weights]

    # Snapshot (weight, loss) pairs so later changes to the caller's list cannot alter this loss.
    weighted_losses = tuple(zip(loss_weights, loss_list))

    def loss(y_true, y_pred):
        return sum(weight * loss_fn(y_true, y_pred) for weight, loss_fn in weighted_losses)

    return loss