Coverage for mlair/reference_models/abstract_reference_model.py: 86%
31 statements
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2023-06-01 13:03 +0000
1__author__ = "Felix Kleinert"
2__date__ = "2021-01-29"
4import os
5import sys
6from abc import ABC
8import wget
10from mlair.configuration import check_path_and_create
class AbstractReferenceModel(ABC):
    """
    Abstract reference model. All classes providing some reference or competitor models must inherit from this class.
    """

    def __init__(self, *args, **kwargs):
        # Accepts and ignores any arguments; concrete subclasses set up their own state.
        pass

    def make_reference_available_locally(self, *args):
        """
        Make the reference data available on the local file system.

        Must be implemented by concrete subclasses.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @staticmethod
    def is_reference_available_locally(reference_path) -> bool:
        """
        Checks if reference is available locally.

        The reference counts as available if ``reference_path`` is an existing directory containing at least one entry.

        :param reference_path: look in this path for data
        :return: True if ``reference_path`` is a non-empty directory; False if it is empty, does not exist, or is not
            a directory
        """
        try:
            return bool(os.listdir(reference_path))
        except (FileNotFoundError, NotADirectoryError):
            # A missing path, or a path pointing to a regular file, cannot hold the reference data.
            return False
class AbstractReferenceB2share(AbstractReferenceModel):
    """
    Abstract class for reference models located on b2share (eudat or fz-juelich).

    See also https://github.com/EUDAT-Training/B2SHARE-Training/blob/master/api/01_Retrieve_existing_record.md
    """

    def __init__(self, b2share_hosturl: str, b2share_bucket: str, b2share_key: str):
        """
        :param b2share_hosturl: base URL of the b2share host
        :param b2share_bucket: id of the b2share file bucket
        :param b2share_key: name (key) of the file inside the bucket to download
        """
        super().__init__()
        self.b2share_hosturl = b2share_hosturl
        self.b2share_bucket = b2share_bucket
        self.b2share_key = b2share_key

    @property
    def b2share_url(self):
        """URL of the file bucket, i.e. ``<b2share_hosturl>/api/files/<b2share_bucket>``."""
        return f"{self.b2share_hosturl}/api/files/{self.b2share_bucket}"

    def bar_custom(self, current, total, width=80):
        """
        Progress callback for ``wget.download``; writes the download status to stdout on a single, rewritten line.

        :param current: number of bytes downloaded so far
        :param total: total size in bytes; can be 0 (empty body) or -1 (size not reported by the server)
        :param width: console width supplied by wget; unused, but required by wget's bar-callback signature
        """
        if total > 0:
            percent = round(current / total * 100)
            progress_message = f"Downloading {self.b2share_key}: {percent}% [{current} / {total}] bytes"
        else:
            # Size unknown or zero: avoid ZeroDivisionError / a meaningless percentage.
            progress_message = f"Downloading {self.b2share_key}: [{current} / unknown] bytes"
        sys.stdout.write("\r" + progress_message)
        sys.stdout.flush()

    def download_from_b2share(self, tmp_download_path: str):
        """
        Download the file ``b2share_key`` from the b2share bucket into the directory ``tmp_download_path``.

        :param tmp_download_path: target directory; passed to ``check_path_and_create`` first (presumably creates it
            if missing — confirm)
        """
        check_path_and_create(tmp_download_path)
        # Join with a path separator so that a directory given without trailing slash still receives the file
        # inside it (plain string concatenation would create a sibling file instead).
        wget.download(f"{self.b2share_url}/{self.b2share_key}",
                      out=os.path.join(tmp_download_path, self.b2share_key),
                      bar=self.bar_custom
                      )

    def make_reference_available_locally(self):
        """
        Make the reference data available locally. Must be implemented by concrete subclasses.

        :raises NotImplementedError: always, in this abstract class
        """
        raise NotImplementedError