Coverage for mlair/reference_models/abstract_reference_model.py: 86%

31 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-12-18 17:51 +0000

1__author__ = "Felix Kleinert" 

2__date__ = "2021-01-29" 

3 

4import os 

5import sys 

6from abc import ABC 

7 

8import wget 

9 

10from mlair.configuration import check_path_and_create 

11 

12 

class AbstractReferenceModel(ABC):
    """
    Abstract reference model.

    All classes providing some reference or competitor models must inherit from this class.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore arbitrary arguments; this base class stores no state."""

    def make_reference_available_locally(self, *args):
        """
        Make the reference available locally.

        :raises NotImplementedError: always; subclasses have to provide the implementation
        """
        raise NotImplementedError

    @staticmethod
    def is_reference_available_locally(reference_path) -> bool:
        """
        Check if reference is available locally.

        :param reference_path: look in this path for data
        :return: True if reference_path is a directory containing at least one entry, False if the
            directory is empty, does not exist, or the path is not a directory
        """
        try:
            return bool(os.listdir(reference_path))
        except (FileNotFoundError, NotADirectoryError):
            # A missing path or a path pointing to a regular file both mean "no local reference data".
            return False

38 

39 

class AbstractReferenceB2share(AbstractReferenceModel):
    """
    Abstract class for reference models located on b2share (eudat or fz-juelich).

    See also https://github.com/EUDAT-Training/B2SHARE-Training/blob/master/api/01_Retrieve_existing_record.md
    """

    def __init__(self, b2share_hosturl: str, b2share_bucket: str, b2share_key: str):
        """
        Store the location of the reference file on b2share.

        :param b2share_hosturl: base url of the b2share host
        :param b2share_bucket: bucket identifier used to build the file api url
        :param b2share_key: key (file name) of the file inside the bucket
        """
        super().__init__()
        self.b2share_hosturl = b2share_hosturl
        self.b2share_bucket = b2share_bucket
        self.b2share_key = b2share_key

    @property
    def b2share_url(self) -> str:
        """Return the file api url of the bucket: ``<hosturl>/api/files/<bucket>``."""
        return f"{self.b2share_hosturl}/api/files/{self.b2share_bucket}"

    def bar_custom(self, current, total, width=80):
        """
        Write a single-line download progress message to stdout.

        :param current: number of bytes downloaded so far
        :param total: total number of bytes; if not positive, the percentage is omitted
        :param width: unused, kept so the method matches the progress-bar callback signature
        """
        if total > 0:
            progress_message = f"Downloading {self.b2share_key}: {round(current / total * 100)}% [{current} / {total}] bytes"
        else:
            # Total size unknown or zero: a percentage cannot be computed (avoids ZeroDivisionError).
            progress_message = f"Downloading {self.b2share_key}: [{current} / unknown] bytes"
        # Carriage return rewrites the same console line on every update.
        sys.stdout.write("\r" + progress_message)
        sys.stdout.flush()

    def download_from_b2share(self, tmp_download_path: str):
        """
        Download the file ``b2share_key`` from the bucket url into tmp_download_path.

        :param tmp_download_path: directory to store the downloaded file in (with or without trailing separator)
        """
        # NOTE(review): presumably creates the directory if it is missing - confirm in mlair.configuration.
        check_path_and_create(tmp_download_path)
        wget.download(f"{self.b2share_url}/{self.b2share_key}",
                      # os.path.join keeps the file inside the directory even without a trailing separator.
                      out=os.path.join(tmp_download_path, self.b2share_key),
                      bar=self.bar_custom
                      )

    def make_reference_available_locally(self, *args):
        """
        Make the reference available locally.

        :raises NotImplementedError: always; concrete b2share reference classes have to implement this
        """
        raise NotImplementedError

70 raise NotImplementedError