Coverage for mlair/data_handler/iterator.py: 91% (160 statements)

coverage.py v6.4.2, created at 2023-06-30 10:40 +0000

__author__ = 'Lukas Leufen'
__date__ = '2020-07-07'

from collections.abc import Iterator, Iterable  # collections.Iterator was removed in Python 3.10
import tensorflow.keras as keras
import numpy as np
import math
import os
import shutil
import psutil
import multiprocessing
import logging
import dill
from typing import Tuple, List


class StandardIterator(Iterator):

    _position: int = None

    def __init__(self, collection: list):
        assert isinstance(collection, list)
        self._collection = collection
        self._position = 0

    def __next__(self):
        """Return next element or stop iteration."""
        try:
            value = self._collection[self._position]
            self._position += 1
        except IndexError:
            raise StopIteration()
        return value
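
# A minimal usage sketch of StandardIterator (hypothetical values, not part
# of the original module):
#
#     it = StandardIterator([1, 2, 3])
#     next(it)  # -> 1
#     list(it)  # -> [2, 3], after which the iterator is exhausted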


class DataCollection(Iterable):

    def __init__(self, collection: list = None, name: str = None):
        if collection is None:
            collection = []
        assert isinstance(collection, list)
        self._collection = collection.copy()
        self._mapping = {}
        self._set_mapping()
        self._name = name

    @property
    def name(self):
        return self._name

    def __len__(self):
        return len(self._collection)

    def __iter__(self) -> Iterator:
        return StandardIterator(self._collection)

    def __getitem__(self, index):
        if isinstance(index, int):
            return self._collection[index]
        else:
            return self._collection[self._mapping[str(index)]]

    def add(self, element):
        self._collection.append(element)
        self._mapping[str(element)] = len(self._collection) - 1

    def _set_mapping(self):
        for i, e in enumerate(self._collection):
            self._mapping[str(e)] = i

    def keys(self):
        return list(self._mapping.keys())
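
# A minimal usage sketch of DataCollection (hypothetical values): elements are
# retrievable by position or by their string representation, which is what the
# internal mapping stores.
#
#     dc = DataCollection(["stationA", "stationB"], name="train")
#     dc[0]           # -> "stationA" (by position)
#     dc["stationB"]  # -> "stationB" (by str(element) via the mapping)
#     dc.add("stationC")
#     dc.keys()       # -> ["stationA", "stationB", "stationC"]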


class KerasIterator(keras.utils.Sequence):

    def __init__(self, collection: DataCollection, batch_size: int, batch_path: str, shuffle_batches: bool = False,
                 model=None, upsampling=False, name=None, use_multiprocessing=False, max_number_multiprocessing=1):
        self._collection = collection
        batch_path = os.path.join(batch_path, str(name if name is not None else id(self)))
        self._path = os.path.join(batch_path, "%i.pickle")
        self.batch_size = batch_size
        self.model = model
        self.shuffle = shuffle_batches
        self.upsampling = upsampling
        self.indexes: list = []
        self._cleanup_path(batch_path)
        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)

    def __len__(self) -> int:
        return len(self.indexes)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Get batch for given index."""
        return self.__data_generation(self.indexes[index])

    def _get_model_rank(self):
        if self.model is not None:
            try:
                mod_out = self.model.output_shape
            except AttributeError:
                # ToDo: replace this except statement with something meaningful. Depending on the BNN architecture,
                # the attribute output_shape might not be defined. We use it here to check the number of tails ->
                # make sure multiple tails would also work with BNNs in future versions.
                mod_out = (None, None)
            if isinstance(mod_out, tuple):  # only one output branch: (None, ahead)
                mod_rank = 1
            elif isinstance(mod_out, list):  # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
                mod_rank = len(mod_out)
            else:  # pragma: no cover
                raise TypeError("model output shape must either be tuple or list.")
            return mod_rank
        else:  # no model provided, assume a single output
            return 1
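
    # A sketch of how _get_model_rank interprets the model's output_shape
    # (hypothetical shapes): a tuple means a single output branch, a list one
    # branch per entry.
    #
    #     model.output_shape == (None, 4)               -> rank 1
    #     model.output_shape == [(None, 4), (None, 4)]  -> rank 2
    #     model is None                                 -> rank 1 (default)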

    def __data_generation(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Load pickle data from disk."""
        file = self._path % index
        with open(file, "rb") as f:
            data = dill.load(f)
        return data["X"], data["Y"]

    @staticmethod
    def _concatenate(new: List[np.ndarray], old: List[np.ndarray]) -> List[np.ndarray]:
        """Concatenate two lists of data along axis=0."""
        return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))

    @staticmethod
    def _concatenate_multi(*args: List[np.ndarray]) -> List[np.ndarray]:
        """Concatenate an arbitrary number of lists of data along axis=0."""
        return list(map(lambda *_args: np.concatenate(_args, axis=0), *args))
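
    # A minimal sketch of the concatenation helpers (hypothetical arrays).
    # Each list entry is one input/output branch; concatenation runs per
    # branch along the sample axis:
    #
    #     a, b = [np.zeros((2, 3))], [np.ones((4, 3))]
    #     KerasIterator._concatenate_multi(a, b)[0].shape  # -> (6, 3)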

    def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
        """
        Prepare all batches as locally stored files.

        Walk through all elements of the collection and split (or merge) the data according to the batch size. Data
        sets that are too long are divided into multiple batches. Batches that are not fully filled are retained
        together with the remainders of the next collection elements; these retained data are concatenated and split
        into batches again. If data still remain afterwards, they are saved as a final, smaller batch. All batches
        are enumerated by a running index starting at 0. A list with all batch numbers is stored in the class
        attribute indexes. This method can either use a serial approach or multiprocessing to decrease computation
        time.
        """
        index = 0
        remaining = []
        mod_rank = self._get_model_rank()
        n_process = min([psutil.cpu_count(logical=False), len(self._collection), max_process])  # physical cpus only
        if n_process > 1 and use_multiprocessing is True:  # parallel solution (coverage: branch never taken)
            pool = multiprocessing.Pool(n_process)
            output = []
        else:
            pool = None
            output = None
        for data in self._collection:
            length = data.__len__(self.upsampling)  # custom __len__ accepting the upsampling flag
            batches = _get_number_of_mini_batches(length, self.batch_size)
            if pool is None:  # coverage: condition was always true in the recorded run
                res = f_proc(data, self.upsampling, mod_rank, self.batch_size, self._path, index)
                if res is not None:
                    remaining.append(res)
            else:
                output.append(pool.apply_async(
                    f_proc, args=(data, self.upsampling, mod_rank, self.batch_size, self._path, index)))
            index += batches
        if output is not None:  # coverage: branch never taken in the recorded run
            for p in output:
                res = p.get()
                if res is not None:
                    remaining.append(res)
            pool.close()
        if len(remaining) > 0:
            X = self._concatenate_multi(*[e[0] for e in remaining])
            Y = self._concatenate_multi(*[e[1] for e in remaining])
            length = X[0].shape[0]
            batches = _get_number_of_mini_batches(length, self.batch_size)
            remaining = f_proc((X, Y), self.upsampling, mod_rank, self.batch_size, self._path, index)
            index += batches
            if remaining is not None:  # coverage: condition was always true in the recorded run
                _save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
                index += 1
        self.indexes = np.arange(0, index).tolist()
        if pool is not None:  # coverage: branch never taken in the recorded run
            pool.join()
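
    # A worked example of the batching arithmetic above (hypothetical sizes).
    # With batch_size=100 and two collection elements of 250 and 170 samples:
    #   element 1: 2 full batches (0.pickle, 1.pickle), 50 samples retained
    #   element 2: 1 full batch (2.pickle), 70 samples retained
    #   retained 50 + 70 = 120 samples: 1 full batch (3.pickle) plus a final
    #   smaller batch of 20 samples (4.pickle)
    # -> index ends at 5, so self.indexes == [0, 1, 2, 3, 4].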

    @staticmethod
    def _cleanup_path(path: str, create_new: bool = True) -> None:
        """First remove an existing path, then create an empty path if enabled."""
        if os.path.exists(path):
            shutil.rmtree(path)
        if create_new is True:
            os.makedirs(path)

    def on_epoch_end(self) -> None:
        """Randomly shuffle indexes if enabled."""
        if self.shuffle is True:
            np.random.shuffle(self.indexes)
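
# A minimal usage sketch of KerasIterator (hypothetical names; collection
# elements are assumed to provide get_data(upsampling=...) and a __len__ that
# accepts the upsampling flag, as f_proc and _prepare_batches require):
#
#     iterator = KerasIterator(collection, batch_size=512, batch_path="/tmp/batches",
#                              shuffle_batches=True, model=model, name="train")
#     model.fit(iterator, epochs=10)  # keras pulls one pickled batch at a time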


def _save_to_pickle(path, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
    """Save data as a pickle file with variables X and Y under the name <index>.pickle."""
    data = {"X": X, "Y": Y}
    file = path % index
    with open(file, "wb") as f:
        dill.dump(data, f)


def _get_batch(data_list: List[np.ndarray], b: int, batch_size: int) -> List[np.ndarray]:
    """Get the b-th batch of size batch_size from each array in the data list."""
    return list(map(lambda data: data[b * batch_size:(b + 1) * batch_size, ...], data_list))
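
# A minimal sketch of _get_batch (hypothetical array): for b=1 and
# batch_size=2 it slices samples 2:4 from every branch:
#
#     _get_batch([np.arange(10).reshape(5, 2)], b=1, batch_size=2)
#     # -> [array([[4, 5], [6, 7]])]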


def _permute_data(X, Y):
    """Apply the same random permutation along the sample axis to all arrays in X and Y."""
    p = np.random.permutation(len(X[0]))  # equiv to .shape[0]
    X = list(map(lambda x: x[p], X))
    Y = list(map(lambda x: x[p], Y))
    return X, Y
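
# _permute_data applies one shared permutation so inputs and targets stay
# aligned (hypothetical arrays):
#
#     X, Y = [np.arange(4)], [np.arange(4) * 10]
#     Xp, Yp = _permute_data(X, Y)
#     assert np.all(Yp[0] == Xp[0] * 10)  # pairing preserved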


def _get_number_of_mini_batches(number_of_samples: int, batch_size: int) -> int:
    """Return the number of mini batches as the floored ratio of number of samples to batch size."""
    return math.floor(number_of_samples / batch_size)


def f_proc(data, upsampling, mod_rank, batch_size, _path, index):
    if isinstance(data, tuple) is True:
        X, _Y = data
    else:
        X, _Y = data.get_data(upsampling=upsampling)
    Y = [_Y[0] for _ in range(mod_rank)]  # duplicate the first target once per model output branch
    if upsampling:
        X, Y = _permute_data(X, Y)
    length = X[0].shape[0]
    batches = _get_number_of_mini_batches(length, batch_size)
    for b in range(batches):
        batch_X, batch_Y = _get_batch(X, b, batch_size), _get_batch(Y, b, batch_size)
        _save_to_pickle(_path, X=batch_X, Y=batch_Y, index=index)
        index += 1
    if (batches * batch_size) < length:  # keep remainder to concatenate with the next data element
        remaining = (_get_batch(X, batches, batch_size), _get_batch(Y, batches, batch_size))
    else:
        remaining = None
    return remaining
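
# A self-contained sketch of f_proc on in-memory data (hypothetical arrays and
# a temporary path; not part of the original module). Passing an (X, Y) tuple
# skips the get_data call, mirroring how _prepare_batches handles remainders:
#
#     import tempfile
#     path = os.path.join(tempfile.mkdtemp(), "%i.pickle")
#     X, _Y = [np.random.rand(5, 3)], [np.random.rand(5, 2)]
#     rest = f_proc((X, _Y), upsampling=False, mod_rank=1, batch_size=2, _path=path, index=0)
#     # writes 0.pickle and 1.pickle; rest holds the final odd sample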