Coverage for mlair/plotting/tracker_plot.py: 99%
276 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-12-02 15:24 +0000
1from collections import OrderedDict
3import numpy as np
4import os
5from typing import Union, List, Optional, Dict
7from mlair.helpers import to_list
9from matplotlib import pyplot as plt, lines as mlines, ticker as ticker
10from matplotlib.patches import Rectangle
class TrackObject:
    """
    Node of a simple chain of track objects.

    Objects are linked in both directions: registering a precursor automatically registers this object as successor
    of the precursor, and vice versa.

    :param name: string or list of strings with a name describing the track object
    :param stage: additional meta information (can be used to highlight different blocks inside a chain)
    """

    def __init__(self, name: Union[List[str], str], stage: str):
        self.name = to_list(name)
        self.stage = stage
        # neighbours in the chain, None until the first link is added
        self.precursor: Optional[List[TrackObject]] = None
        self.successor: Optional[List[TrackObject]] = None
        # plot position, filled in when the object is drawn
        self.x: Optional[float] = None
        self.y: Optional[float] = None

    def __repr__(self):
        return "/".join(self.name)

    @property
    def x(self):
        """Return the stored x position."""
        return self._x

    @x.setter
    def x(self, value: float):
        """Store a new x position."""
        self._x = value

    @property
    def y(self):
        """Return the stored y position."""
        return self._y

    @y.setter
    def y(self, value: float):
        """Store a new y position."""
        self._y = value

    def add_precursor(self, precursor: "TrackObject"):
        """Link ``precursor`` in front of this object (no-op if it is already linked)."""
        if self.precursor is not None and precursor in self.precursor:
            return
        if self.precursor is None:
            self.precursor = []
        self.precursor.append(precursor)
        # mirror the link on the other side; the membership guard above stops the mutual recursion
        precursor.add_successor(self)

    def add_successor(self, successor: "TrackObject"):
        """Link ``successor`` behind this object (no-op if it is already linked)."""
        if self.successor is not None and successor in self.successor:
            return
        if self.successor is None:
            self.successor = []
        self.successor.append(successor)
        # mirror the link on the other side; the membership guard above stops the mutual recursion
        successor.add_precursor(self)
class TrackChain:
    """
    Build chains of TrackObjects from a list of tracking dictionaries.

    Each element of ``track_list`` is a dictionary whose (first) item maps a stage name to a dictionary
    ``{variable: [{"method": "set" | "get", "scope": str}, ...]}``.

    :param track_list: list of tracking dictionaries, one per stage
    """

    def __init__(self, track_list):
        self.track_list = track_list
        self.scopes = self.get_all_scopes(self.track_list)
        self.dims = self.get_all_dims(self.scopes)

    def get_all_scopes(self, track_list) -> Dict:
        """Return dictionary with all distinct variables as keys and its unique scopes as values."""
        all_scopes = {}
        for track_dict in track_list:  # all stages
            for track in track_dict.values():  # single stage, all variables
                for variable, variable_tracks in track.items():  # single variable
                    scopes = self.get_unique_scopes(variable_tracks)
                    if variable in all_scopes:
                        all_scopes[variable] = np.unique(scopes + all_scopes[variable]).tolist()
                    else:
                        all_scopes[variable] = scopes
        return OrderedDict(sorted(all_scopes.items()))

    @staticmethod
    def get_all_dims(scopes):
        """Return the number of scopes per variable."""
        return {variable: len(scope_names) for variable, scope_names in scopes.items()}

    def create_track_chain(self):
        """
        Create the track chains for all stages.

        :return: ordered dictionary mapping each stage name to the list of TrackObjects the chains of this stage start
            from
        """
        control = self.control_dict(self.scopes)
        track_chain_dict = OrderedDict()
        for track_dict in self.track_list:
            stage, stage_track = list(track_dict.items())[0]
            track_chain, control = self._create_track_chain(control, OrderedDict(sorted(stage_track.items())), stage)
            control = self.clean_control(control)
            track_chain_dict[stage] = track_chain
        return track_chain_dict

    def _create_track_chain(self, control, sorted_track_dict, stage):
        """Create the TrackObjects of a single stage and link them to the last object stored in ``control``."""
        track_objects = []
        for variable, all_variable_tracks in sorted_track_dict.items():
            for track_details in all_variable_tracks:
                method, scope = track_details["method"], track_details["scope"]
                tr = TrackObject([variable, method, scope], stage)
                control_obj = control[variable][scope]
                if method == "set":
                    track_objects = self._add_set_object(track_objects, tr, control_obj)
                elif method == "get":  # pragma: no branch
                    track_objects, skip_control_update = self._add_get_object(track_objects, tr, control_obj,
                                                                              control, scope, variable)
                    if skip_control_update is True:
                        continue
                else:  # pragma: no cover
                    raise ValueError(f"method must be either set or get but given was {method}.")
                self._update_control(control, variable, scope, tr)
        return track_objects, control

    @staticmethod
    def _update_control(control, variable, scope, tr_obj):
        """Remember ``tr_obj`` as the most recent object for (variable, scope)."""
        control[variable][scope] = tr_obj

    @staticmethod
    def _add_track_object(track_objects, tr_obj, prev_obj):
        """Add ``prev_obj`` as chain start if it belongs to another stage than ``tr_obj``."""
        if tr_obj.stage != prev_obj.stage:
            track_objects.append(prev_obj)
        return track_objects

    def _add_precursor(self, track_objects, tr_obj, prev_obj):
        """Link ``prev_obj`` in front of ``tr_obj`` and update the list of chain starts."""
        tr_obj.add_precursor(prev_obj)
        return self._add_track_object(track_objects, tr_obj, prev_obj)

    def _add_set_object(self, track_objects, tr_obj, control_obj):
        """Attach a "set" object to its predecessor, or start a new chain if there is none."""
        if control_obj is not None:
            track_objects = self._add_precursor(track_objects, tr_obj, control_obj)
        else:
            track_objects.append(tr_obj)
        return track_objects

    def _recursive_decent(self, scope, control_obj_var):
        """
        Search the parent scopes of ``scope`` for the object a "get" should be linked to.

        The scope is shortened at its last "." step by step (e.g. "general.a.b" -> "general.a" -> "general") until a
        stored object is found. Parent scopes that are not present in ``control_obj_var`` are skipped.

        :param scope: dotted scope name of the "get" access
        :param control_obj_var: mapping scope -> most recent TrackObject (or None) of a single variable
        :return: TrackObject to link to, or None if no parent scope holds an object
        """
        parts = scope.rsplit(".", 1)
        if len(parts) < 2:
            # no parent scope left to search
            return None
        parent_scope = parts[0]
        # use get(): intermediate scopes (e.g. "general.a" for "general.a.b") need not exist as keys
        control_obj = control_obj_var.get(parent_scope)
        if control_obj is None:
            return self._recursive_decent(parent_scope, control_obj_var)
        pre, candidate = control_obj, control_obj
        # walk back along the first precursors until a "set" is reached (name is [variable, method, scope])
        while pre.precursor is not None and pre.name[1] != "set":
            # change candidate on scope border
            if pre.name[2] != pre.precursor[0].name[2]:
                candidate = pre
            pre = pre.precursor[0]
        # correct pre if candidate is from same scope
        if candidate.name[2] == pre.name[2]:
            pre = candidate
        return pre

    def _add_get_object(self, track_objects, tr_obj, control_obj, control, scope, variable):
        """
        Attach a "get" object to its predecessor.

        If there is no object for the exact scope, the parent scopes are searched. If nothing is found there either,
        ``skip_control_update`` is True and the object is not linked.

        :return: tuple of (updated list of chain starts, skip_control_update)
        """
        skip_control_update = False
        if control_obj is not None:
            track_objects = self._add_precursor(track_objects, tr_obj, control_obj)
        else:
            pre = self._recursive_decent(scope, control[variable])
            if pre is not None:
                track_objects = self._add_precursor(track_objects, tr_obj, pre)
            else:
                skip_control_update = True
        return track_objects, skip_control_update

    @staticmethod
    def control_dict(scopes):
        """Create empty control dictionary with variables and scopes as keys and None as default for all values."""
        return {variable: {s: None for s in scope_names} for variable, scope_names in scopes.items()}

    @staticmethod
    def clean_control(control):
        """
        Reset control entries whose object is directly preceded by an object from another scope.

        Entries that are None or whose object has no precursor are kept unchanged. The dictionary is modified in place
        and also returned.
        """
        for scope_objs in control.values():  # var. scopes
            for scope, tr_obj in scope_objs.items():  # scope tr_obj
                if tr_obj is None or tr_obj.precursor is None:
                    continue
                if tr_obj.precursor[0].name[2] != tr_obj.name[2]:
                    # only replaces the value of an existing key, safe while iterating
                    scope_objs[scope] = None
        return control

    @staticmethod
    def get_unique_scopes(track_list: List[Dict]) -> List[str]:
        """Get list with all unique elements from input including general scope if missing."""
        scopes = [e["scope"] for e in track_list] + ["general"]
        return np.unique(scopes).tolist()
class TrackPlot:
    """
    Plot the tracked data flow of all variables across stages and save it to disk.

    Every (variable, scope) pair gets its own row. Stages are separated by dashed vertical lines. "get" accesses are
    drawn as orange boxes, all other methods as light blue boxes.

    :param tracker_list: list of tracking dictionaries, one per stage (structure as consumed by TrackChain)
    :param sparse_conn_mode: if True, no connection line is drawn into "set" boxes
    :param plot_folder: folder to save the plot in
    :param skip_run_env: if True, the stage named "RunEnvironment" is not drawn
    :param plot_name: file name of the plot, "tracking.pdf" if None
    """

    def __init__(self, tracker_list, sparse_conn_mode=True, plot_folder: str = ".", skip_run_env=True, plot_name=None):
        # layout constants in axis data units
        self.width = 0.6
        self.height = 0.5
        self.space_intern_y = 0.2
        self.space_extern_y = 1
        self.space_intern_x = 0.4
        self.space_extern_x = 0.6
        self.y_pos = None
        self.anchor = None
        self.x_max = None

        track_chain_obj = TrackChain(tracker_list)
        track_chain_dict = track_chain_obj.create_track_chain()
        self.set_ypos_anchor(track_chain_obj.scopes, track_chain_obj.dims)
        self.fig, self.ax = plt.subplots(figsize=(len(tracker_list) * 2,
                                                  (self.anchor.max() - self.anchor.min()) / 3))
        self._plot(track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder, plot_name)

    def _plot(self, track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder, plot_name=None):
        """Draw all chains, decorate the axes, save the figure and release it from pyplot."""
        stages, v_lines = self.create_track_chain_plot(track_chain_dict, sparse_conn_mode=sparse_conn_mode,
                                                       skip_run_env=skip_run_env)
        self.set_lims()
        self.add_variable_names()
        self.add_stages(v_lines, stages)
        # operate on this object's figure instead of pyplot's "current" figure
        self.fig.tight_layout()
        plot_name = "tracking.pdf" if plot_name is None else plot_name
        plot_path = os.path.join(os.path.abspath(plot_folder), plot_name)
        self.fig.savefig(plot_path, dpi=600)
        # close the figure so repeated plotting does not accumulate open pyplot figures
        plt.close(self.fig)

    def _draw_outlined_line(self, pos_x, pos_y, color):
        """Draw a line in ``color`` on top of a wider white border line."""
        for line_color, line_width in (("white", 2.5), (color, 1.4)):
            self.ax.add_line(mlines.Line2D(pos_x, pos_y, color=line_color, linewidth=line_width))

    def line(self, start_x, end_x, y, color="darkgrey"):
        """Draw grey horizontal connection line from start_x to end_x on y-pos."""
        y_mid = y + self.height / 2
        self._draw_outlined_line([start_x + self.width, end_x], [y_mid, y_mid], color)

    def step(self, start_x, end_x, start_y, end_y, color="black"):
        """Draw black connection step line from start_xy to end_xy. Step is taken shortly before end position."""
        # adjust start and end by width height
        start_x += self.width
        start_y += self.height / 2
        end_y += self.height / 2
        step_x = end_x - self.space_intern_x / 2  # step is taken shortly before end
        self._draw_outlined_line([start_x, step_x, step_x, end_x], [start_y, start_y, end_y, end_y], color)

    def rect(self, x, y, method="get"):
        """Draw rectangle with lower left at (x,y), size equal to width/height and label/color according to method."""
        color = {"get": "orange"}.get(method, "lightblue")
        r = Rectangle((x, y), self.width, self.height, color=color)
        self.ax.add_artist(r)
        # add label in the centre of the rectangle
        rx, ry = r.get_xy()
        cx = rx + r.get_width() / 2.0
        cy = ry + r.get_height() / 2.0
        self.ax.annotate(method, (cx, cy), color='w', weight='bold', fontsize=6, ha='center', va='center')

    def set_ypos_anchor(self, scopes, dims):
        """
        Assign a y position to every (variable, scope) row and compute the vertical plot limits.

        Rows of one variable are spaced by ``height + space_intern_y``. Between variables an extra gap of
        ``space_extern_y - space_intern_y`` is added.

        :param scopes: mapping variable -> list of scopes
        :param dims: mapping variable -> number of scopes
        """
        anchor = sum(dims.values())
        pos_dict = {}
        d_y = 0
        for variable, variable_scopes in scopes.items():
            pos_dict[variable] = {}
            for scope in variable_scopes:
                pos_dict[variable][scope] = anchor + d_y
                d_y -= (self.space_intern_y + self.height)
            d_y -= (self.space_extern_y - self.space_intern_y)
        self.y_pos = pos_dict
        # (lower, upper) y limits of the plot
        self.anchor = np.array((d_y, self.height + self.space_extern_y)) + anchor

    def plot_track_chain(self, chain, y_pos, x_pos=0, prev=None, stage=None, sparse_conn_mode=False):
        """
        Recursively draw ``chain`` and all of its successors that belong to ``stage``.

        :param chain: TrackObject to start from
        :param y_pos: nested mapping variable -> scope -> y position (see set_ypos_anchor)
        :param x_pos: x position to draw ``chain`` at
        :param prev: (x, y) of the previously drawn box to connect from, or None
        :param stage: name of the stage currently drawn
        :param sparse_conn_mode: if True, skip connection lines into "set" boxes
        :return: largest x position (left box edge) reached along this chain, computed NaN-aware
        """
        if chain.successor is None or chain.stage == stage:
            var, method, scope = chain.name
            x, y = x_pos, y_pos[var][scope]
            self.rect(x, y, method=method)
            chain.x, chain.y = x, y
            if prev is not None and prev[0] is not None:
                if not (sparse_conn_mode is True and method == "set"):
                    if y == prev[1]:
                        self.line(prev[0], x, prev[1])
                    else:
                        self.step(prev[0], x, prev[1], y)
        else:
            # object of another stage: reuse the position stored when it was drawn (None if it never was)
            x, y = chain.x, chain.y

        if chain.successor is None:
            return x

        x_max = None
        same_stage_count = 0
        for successor in chain.successor:
            if successor.stage == stage:
                same_stage_count += 1
                if same_stage_count > 50:  # hard-coded limit of followed successors per stage
                    continue
                shift = self.width + self.space_intern_x if chain.stage == successor.stage else 0
                x_tmp = self.plot_track_chain(successor, y_pos, x_pos + shift, prev=(x, y),
                                              stage=stage, sparse_conn_mode=sparse_conn_mode)
                # float64 conversion turns None into NaN, which nanmax ignores
                x_max = np.nanmax(np.array([x_tmp, x_max], dtype=np.float64))
            else:
                x_max = np.nanmax(np.array([x, x_max, x_pos], dtype=np.float64))
        return x_max

    def add_variable_names(self):
        """Label the y axis: variable names as major ticks ("general" rows), sub-scope names as minor ticks."""
        labels = []
        pos = []
        labels_major = []
        pos_major = []
        for variable, scope_pos in self.y_pos.items():
            for scope, y in scope_pos.items():
                if scope == "general":
                    labels_major.append(variable)
                    pos_major.append(y + self.height / 2)
                else:
                    # drop the leading scope level; a scope without "." is used as it is
                    labels.append(scope.split(".", 1)[-1])
                    pos.append(y + self.height / 2)
        self.ax.tick_params(axis="y", which="major", labelsize="large")
        self.ax.yaxis.set_major_locator(ticker.FixedLocator(pos_major))
        self.ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels_major))
        self.ax.yaxis.set_minor_locator(ticker.FixedLocator(pos))
        self.ax.yaxis.set_minor_formatter(ticker.FixedFormatter(labels))

    def add_stages(self, vlines, stages):
        """Draw dashed stage separators at ``vlines`` and place the stage names centred between them."""
        x_max = self.x_max + self.space_intern_x + self.width
        for v_line in vlines:
            self.ax.vlines(v_line, *self.anchor, colors="black", linestyles="dashed")
        borders = [0] + vlines + [x_max]
        pos = [(left + right) / 2 for left, right in zip(borders[:-1], borders[1:])]
        self.ax.xaxis.set_major_locator(ticker.FixedLocator(pos))
        self.ax.xaxis.set_major_formatter(ticker.FixedFormatter(stages))

    def create_track_chain_plot(self, track_chain_dict, sparse_conn_mode=True, skip_run_env=True):
        """
        Draw the chains of all stages side by side.

        :return: tuple of (list of drawn stage names, x positions of the stage separators)
        """
        x, x_max = 0, 0
        v_lines, stages = [], []
        for stage, track_chain in track_chain_dict.items():
            if stage == "RunEnvironment" and skip_run_env is True:
                continue
            if x > 0:
                v_lines.append(x - self.space_extern_x / 2)
            for e in track_chain:
                x_max = max(x_max, self.plot_track_chain(e, self.y_pos, x_pos=x, stage=stage,
                                                         sparse_conn_mode=sparse_conn_mode))
            x = x_max + self.space_extern_x + self.width
            stages.append(stage)
        self.x_max = x_max
        return stages, v_lines

    def set_lims(self):
        """Set the axis limits from the drawn extent and the y anchor."""
        x_max = self.x_max + self.space_intern_x + self.width
        self.ax.set_xlim((0, x_max))
        self.ax.set_ylim(self.anchor)