Coverage for mlair/plotting/tracker_plot.py: 99%

276 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:22 +0000

1from collections import OrderedDict 

2 

3import numpy as np 

4import os 

5from typing import Union, List, Optional, Dict 

6 

7from mlair.helpers import to_list 

8 

9from matplotlib import pyplot as plt, lines as mlines, ticker as ticker 

10from matplotlib.patches import Rectangle 

11 

12 

class TrackObject:
    """
    Node of a simple chain of objects.

    Links are kept symmetric: adding ``b`` as precursor of ``a`` also adds ``a`` as successor of ``b`` (and vice
    versa). Each link is stored only once.

    :param name: string or list of strings with a name describing the track object
    :param stage: additional meta information (can be used to highlight different blocks inside a chain)
    """

    def __init__(self, name: Union[List[str], str], stage: str):
        self.name = to_list(name)
        self.stage = stage
        # link lists stay None until the first link is added
        self.precursor: Optional[List[TrackObject]] = None
        self.successor: Optional[List[TrackObject]] = None
        # position of the object, filled in later (e.g. by the plotting code)
        self.x: Optional[float] = None
        self.y: Optional[float] = None

    def __repr__(self):
        return "/".join(self.name)

    @property
    def x(self):
        """Return the stored x position."""
        return self._x

    @x.setter
    def x(self, value: float):
        """Store a new x position."""
        self._x = value

    @property
    def y(self):
        """Return the stored y position."""
        return self._y

    @y.setter
    def y(self, value: float):
        """Store a new y position."""
        self._y = value

    def add_precursor(self, precursor: "TrackObject"):
        """Link ``precursor`` in front of this object; nothing happens if the link already exists."""
        already_linked = self.precursor is not None and precursor in self.precursor
        if already_linked:
            return
        if self.precursor is None:
            self.precursor = []
        self.precursor.append(precursor)
        # mirror the link on the other side (returns early there because the link then exists on both sides)
        precursor.add_successor(self)

    def add_successor(self, successor: "TrackObject"):
        """Link ``successor`` behind this object; nothing happens if the link already exists."""
        already_linked = self.successor is not None and successor in self.successor
        if already_linked:
            return
        if self.successor is None:
            self.successor = []
        self.successor.append(successor)
        # mirror the link on the other side
        successor.add_precursor(self)

74 

75 

class TrackChain:
    """
    Link the tracked set/get calls of all stages into chains of :class:`TrackObject`.

    :param track_list: list with one dictionary per stage of the form
        ``{stage: {variable: [{"method": "set" | "get", "scope": scope}, ...]}}``
    """

    def __init__(self, track_list):
        self.track_list = track_list
        self.scopes = self.get_all_scopes(self.track_list)
        self.dims = self.get_all_dims(self.scopes)

    def get_all_scopes(self, track_list) -> Dict:
        """Return dictionary with all distinct variables as keys and its unique scopes as values (sorted by key)."""
        dims = {}
        for track_dict in track_list:  # all stages
            for track in track_dict.values():  # single stage, all variables
                for k, v in track.items():  # single variable
                    scopes = self.get_unique_scopes(v)
                    if dims.get(k) is None:
                        dims[k] = scopes
                    else:
                        dims[k] = np.unique(scopes + dims[k]).tolist()
        return OrderedDict(sorted(dims.items()))

    @staticmethod
    def get_all_dims(scopes):
        """Return the number of scopes of each variable."""
        return {variable: len(scope_names) for variable, scope_names in scopes.items()}

    def create_track_chain(self):
        """
        Build the track chains of all stages in the order of ``track_list``.

        :return: ordered dictionary mapping each stage name to the list of objects its chains start from
        """
        control = self.control_dict(self.scopes)
        track_chain_dict = OrderedDict()
        for track_dict in self.track_list:
            # each stage dictionary is expected to hold exactly one entry: {stage: tracks}
            stage, stage_track = list(track_dict.items())[0]
            track_chain, control = self._create_track_chain(control, OrderedDict(sorted(stage_track.items())), stage)
            control = self.clean_control(control)
            track_chain_dict[stage] = track_chain
        return track_chain_dict

    def _create_track_chain(self, control, sorted_track_dict, stage):
        """
        Create the track objects of one stage and link them to their precursors.

        :param control: ``{variable: {scope: latest track object or None}}``, updated in place
        :param sorted_track_dict: tracked calls of this stage, ``{variable: [{"method": ..., "scope": ...}, ...]}``
        :param stage: name of the stage
        :return: tuple of the objects the stage's chains start from and the updated control dictionary
        """
        track_objects = []
        for variable, all_variable_tracks in sorted_track_dict.items():
            for track_details in all_variable_tracks:
                method, scope = track_details["method"], track_details["scope"]
                tr = TrackObject([variable, method, scope], stage)
                control_obj = control[variable][scope]
                if method == "set":
                    track_objects = self._add_set_object(track_objects, tr, control_obj)
                elif method == "get":  # pragma: no branch
                    track_objects, skip_control_update = self._add_get_object(track_objects, tr, control_obj,
                                                                              control, scope, variable)
                    if skip_control_update is True:
                        # nothing found to read from: this "get" is neither linked nor registered
                        continue
                else:  # pragma: no cover
                    raise ValueError(f"method must be either set or get but given was {method}.")
                self._update_control(control, variable, scope, tr)
        return track_objects, control

    @staticmethod
    def _update_control(control, variable, scope, tr_obj):
        """Register ``tr_obj`` as the latest object of ``variable`` in ``scope``."""
        control[variable][scope] = tr_obj

    @staticmethod
    def _add_track_object(track_objects, tr_obj, prev_obj):
        """Add ``prev_obj`` as chain start of the current stage if it belongs to a different stage than ``tr_obj``."""
        if tr_obj.stage != prev_obj.stage:
            track_objects.append(prev_obj)
        return track_objects

    def _add_precursor(self, track_objects, tr_obj, prev_obj):
        """Link ``prev_obj`` in front of ``tr_obj`` and update the chain starts of the current stage."""
        tr_obj.add_precursor(prev_obj)
        return self._add_track_object(track_objects, tr_obj, prev_obj)

    def _add_set_object(self, track_objects, tr_obj, control_obj):
        """Attach a "set" object to the latest object of its scope, or start a new chain if there is none."""
        if control_obj is not None:
            track_objects = self._add_precursor(track_objects, tr_obj, control_obj)
        else:
            track_objects.append(tr_obj)
        return track_objects

    def _recursive_decent(self, scope, control_obj_var):
        """
        Find the object a "get" in ``scope`` reads from by walking up the scope hierarchy.

        ``scope`` is shortened by its last dot-separated part (e.g. "general.train" -> "general") until a parent
        scope with a registered object is found. Parent scopes missing from ``control_obj_var`` are skipped. From the
        registered object the chain is followed backwards until a "set" or the chain start is reached.

        :param scope: dot-separated scope name
        :param control_obj_var: ``{scope: latest track object or None}`` of a single variable
        :return: track object to link to, or None if no parent scope holds a registered object
        """
        parts = scope.rsplit(".", 1)
        if len(parts) <= 1:
            return None  # top of the scope hierarchy reached without a match
        scope = parts[0]
        control_obj = control_obj_var.get(scope)
        if control_obj is None:
            return self._recursive_decent(scope, control_obj_var)
        pre, candidate = control_obj, control_obj
        while pre.precursor is not None and pre.name[1] != "set":
            # change candidate on scope border (name[2] holds the scope)
            if pre.name[2] != pre.precursor[0].name[2]:
                candidate = pre
            pre = pre.precursor[0]
        # correct pre if candidate is from same scope
        if candidate.name[2] == pre.name[2]:
            pre = candidate
        return pre

    def _add_get_object(self, track_objects, tr_obj, control_obj, control, scope, variable):
        """
        Attach a "get" object to the object it reads from.

        :return: tuple of the updated chain starts and a flag that is True if no object to read from was found
        """
        skip_control_update = False
        if control_obj is not None:
            track_objects = self._add_precursor(track_objects, tr_obj, control_obj)
        else:
            pre = self._recursive_decent(scope, control[variable])
            if pre is not None:
                track_objects = self._add_precursor(track_objects, tr_obj, pre)
            else:
                skip_control_update = True
        return track_objects, skip_control_update

    @staticmethod
    def control_dict(scopes):
        """Create empty control dictionary with variables and scopes as keys and None as default for all values."""
        return {variable: dict.fromkeys(scope_names) for variable, scope_names in scopes.items()}

    @staticmethod
    def clean_control(control):
        """
        Reset control entries whose object is linked to a precursor from a different scope.

        Entries that are None or have no precursor are left unchanged. ``control`` is modified in place and returned.
        """
        for scope_objects in control.values():  # variable -> {scope: track object}
            for scope, tr_obj in scope_objects.items():
                if tr_obj is None or not tr_obj.precursor:
                    continue
                if tr_obj.precursor[0].name[2] != tr_obj.name[2]:
                    # only the value of an existing key changes, so iterating is still safe
                    scope_objects[scope] = None
        return control

    @staticmethod
    def get_unique_scopes(track_list: List[Dict]) -> List[str]:
        """Get sorted list with all unique scopes from input including general scope if missing."""
        scopes = [e["scope"] for e in track_list] + ["general"]
        return np.unique(scopes).tolist()

214 

215 

class TrackPlot:
    """
    Plot tracked set/get calls of variables as chains of boxes, grouped by stage.

    Creating an instance builds the chains with :class:`TrackChain`, draws them and saves the figure to
    ``plot_folder/plot_name``.

    :param tracker_list: list with one ``{stage: {variable: [{"method": ..., "scope": ...}, ...]}}`` dictionary
        per stage
    :param sparse_conn_mode: if True, no connection lines are drawn into "set" boxes
    :param plot_folder: folder to save the plot in
    :param skip_run_env: if True, the stage named "RunEnvironment" is not drawn
    :param plot_name: file name of the plot, defaults to "tracking.pdf"
    """

    def __init__(self, tracker_list, sparse_conn_mode=True, plot_folder: str = ".", skip_run_env=True, plot_name=None):

        # box size and spacing in data coordinates
        self.width = 0.6
        self.height = 0.5
        self.space_intern_y = 0.2
        self.space_extern_y = 1
        self.space_intern_x = 0.4
        self.space_extern_x = 0.6
        self.y_pos = None
        self.anchor = None
        self.x_max = None

        track_chain_obj = TrackChain(tracker_list)
        track_chain_dict = track_chain_obj.create_track_chain()
        self.set_ypos_anchor(track_chain_obj.scopes, track_chain_obj.dims)
        fig_height = (self.anchor.max() - self.anchor.min()) / 3
        self.fig, self.ax = plt.subplots(figsize=(len(tracker_list) * 2, fig_height))
        self._plot(track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder, plot_name)

    def _plot(self, track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder, plot_name=None):
        """Draw all chains, decorate the axes, save the figure and release it from pyplot."""
        stages, v_lines = self.create_track_chain_plot(track_chain_dict, sparse_conn_mode=sparse_conn_mode,
                                                       skip_run_env=skip_run_env)
        self.set_lims()
        self.add_variable_names()
        self.add_stages(v_lines, stages)
        # work on this plot's own figure instead of pyplot's implicit "current" figure
        fig = self.ax.figure
        fig.tight_layout()
        plot_name = "tracking.pdf" if plot_name is None else plot_name
        plot_name = os.path.join(os.path.abspath(plot_folder), plot_name)
        fig.savefig(plot_name, dpi=600)
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)

    def _draw_with_border(self, pos_x, pos_y, color):
        """Draw a polyline in ``color`` on top of a wider white line that acts as a border."""
        self.ax.add_line(mlines.Line2D(pos_x, pos_y, color="white", linewidth=2.5))
        self.ax.add_line(mlines.Line2D(pos_x, pos_y, color=color, linewidth=1.4))

    def line(self, start_x, end_x, y, color="darkgrey"):
        """Draw grey horizontal connection line from start_x to end_x on y-pos."""
        # start at the right edge of the start box, at the vertical centre of the box row
        y_center = y + self.height / 2
        self._draw_with_border([start_x + self.width, end_x], [y_center, y_center], color)

    def step(self, start_x, end_x, start_y, end_y, color="black"):
        """Draw black connection step line from start_xy to end_xy. Step is taken shortly before end position."""
        # adjust start and end by width height
        start_x += self.width
        start_y += self.height / 2
        end_y += self.height / 2
        step_x = end_x - self.space_intern_x / 2  # step is taken shortly before end
        self._draw_with_border([start_x, step_x, step_x, end_x], [start_y, start_y, end_y, end_y], color)

    def rect(self, x, y, method="get"):
        """Draw rectangle with lower left at (x,y), size equal to width/height and label/color according to method."""
        # "get" boxes are orange, every other method is light blue
        color = {"get": "orange"}.get(method, "lightblue")
        r = Rectangle((x, y), self.width, self.height, color=color)
        self.ax.add_artist(r)
        # put the method name into the centre of the box
        rx, ry = r.get_xy()
        cx = rx + r.get_width() / 2.0
        cy = ry + r.get_height() / 2.0
        self.ax.annotate(method, (cx, cy), color='w', weight='bold', fontsize=6, ha='center', va='center')

    def set_ypos_anchor(self, scopes, dims):
        """
        Assign a y position (lower box edge) to every variable/scope row and compute the y-limits.

        Rows are placed top-down starting at ``sum(dims.values())``. Rows of the same variable are separated by
        ``space_intern_y``, different variables by ``space_extern_y``. Sets ``self.y_pos`` to
        ``{variable: {scope: y}}`` and ``self.anchor`` to the (lower, upper) y-limits.
        """
        anchor = sum(dims.values())
        pos_dict = {}
        d_y = 0
        for variable, scope_names in scopes.items():
            pos_dict[variable] = {}
            for scope in scope_names:
                pos_dict[variable][scope] = anchor + d_y
                d_y -= (self.space_intern_y + self.height)
            # widen the gap to the next variable from space_intern_y to space_extern_y
            d_y -= (self.space_extern_y - self.space_intern_y)
        self.y_pos = pos_dict
        self.anchor = np.array((d_y, self.height + self.space_extern_y)) + anchor

    def plot_track_chain(self, chain, y_pos, x_pos=0, prev=None, stage=None, sparse_conn_mode=False):
        """
        Recursively draw ``chain`` and its successors belonging to ``stage``.

        Objects of ``stage`` (or objects without successors) are drawn as a box at ``x_pos`` and connected to
        ``prev``. Other objects are not redrawn; their stored position is used as the start of the connections to
        their successors. At most 50 successors per object are followed.

        :param chain: track object to draw
        :param y_pos: ``{variable: {scope: y}}`` as created by :meth:`set_ypos_anchor`
        :param x_pos: x position (left box edge) for the object
        :param prev: (x, y) of the box the connection starts from, or None for no connection
        :param stage: stage currently drawn
        :param sparse_conn_mode: if True, no connection lines are drawn into "set" boxes
        :return: rightmost x position (left box edge) reached in this sub-chain
        """
        if (chain.successor is None) or (chain.stage == stage):
            var, method, scope = chain.name
            x, y = x_pos, y_pos[var][scope]
            self.rect(x, y, method=method)
            chain.x, chain.y = x, y
            has_prev = prev is not None and prev[0] is not None
            if has_prev and not (sparse_conn_mode is True and method == "set"):
                if y == prev[1]:
                    self.line(prev[0], x, prev[1])
                else:
                    self.step(prev[0], x, prev[1], y)
        else:
            x, y = chain.x, chain.y

        if chain.successor is None:
            return x

        x_max = None
        stage_count = 0
        for e in chain.successor:
            if e.stage == stage:
                stage_count += 1
                if stage_count > 50:
                    continue
                # successors in the same stage are placed one box further right
                shift = self.width + self.space_intern_x if chain.stage == e.stage else 0
                x_tmp = self.plot_track_chain(e, y_pos, x_pos + shift, prev=(x, y),
                                              stage=stage, sparse_conn_mode=sparse_conn_mode)
                # None entries become NaN in a float64 array and are ignored by nanmax
                x_max = np.nanmax(np.array([x_tmp, x_max], dtype=np.float64))
            else:
                x_max = np.nanmax(np.array([x, x_max, x_pos], dtype=np.float64))
        return x_max

    def _variable_tick_labels(self):
        """
        Collect y-tick positions and labels from ``self.y_pos``.

        The "general" scope of each variable gives a major tick labelled with the variable name; every other scope
        gives a minor tick labelled with the scope name after its first dot.

        :return: tuple (major positions, major labels, minor positions, minor labels)
        """
        labels, pos, labels_major, pos_major = [], [], [], []
        for variable, scope_pos in self.y_pos.items():
            for scope, y in scope_pos.items():
                center = y + self.height / 2
                if scope == "general":
                    labels_major.append(variable)
                    pos_major.append(center)
                else:
                    # drop everything up to the first dot; scopes without a dot are used unchanged
                    labels.append(scope.split(".", 1)[-1])
                    pos.append(center)
        return pos_major, labels_major, pos, labels

    def add_variable_names(self):
        """Label the y-axis with variable names (major ticks) and scope names (minor ticks)."""
        pos_major, labels_major, pos, labels = self._variable_tick_labels()
        self.ax.tick_params(axis="y", which="major", labelsize="large")
        self.ax.yaxis.set_major_locator(ticker.FixedLocator(pos_major))
        self.ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels_major))
        self.ax.yaxis.set_minor_locator(ticker.FixedLocator(pos))
        self.ax.yaxis.set_minor_formatter(ticker.FixedFormatter(labels))

    def _x_limit(self):
        """Return the right x-limit: rightmost box edge plus the inner spacing."""
        return self.x_max + self.space_intern_x + self.width

    def add_stages(self, vlines, stages):
        """Draw dashed stage separators at ``vlines`` and label each stage in the middle of its section."""
        for x_line in vlines:
            self.ax.vlines(x_line, *self.anchor, colors="black", linestyles="dashed")
        borders = [0] + vlines + [self._x_limit()]
        pos = [(left + right) / 2 for left, right in zip(borders[:-1], borders[1:])]
        self.ax.xaxis.set_major_locator(ticker.FixedLocator(pos))
        self.ax.xaxis.set_major_formatter(ticker.FixedFormatter(stages))

    def create_track_chain_plot(self, track_chain_dict, sparse_conn_mode=True, skip_run_env=True):
        """
        Draw the chains of all stages from left to right.

        :return: tuple of the drawn stage names and the x positions of the separating vertical lines
        """
        x, x_max = 0, 0
        v_lines, stages = [], []
        for stage, track_chain in track_chain_dict.items():
            if stage == "RunEnvironment" and skip_run_env is True:
                continue
            if x > 0:
                # separator in the middle of the gap between the previous and this stage
                v_lines.append(x - self.space_extern_x / 2)
            for e in track_chain:
                x_max = max(x_max, self.plot_track_chain(e, self.y_pos, x_pos=x, stage=stage,
                                                         sparse_conn_mode=sparse_conn_mode))
            x = x_max + self.space_extern_x + self.width
            stages.append(stage)
        self.x_max = x_max
        return stages, v_lines

    def set_lims(self):
        """Set x-limits from 0 to the rightmost box and y-limits to ``self.anchor``."""
        self.ax.set_xlim((0, self._x_limit()))
        self.ax.set_ylim(self.anchor)