Coverage for encodermap/plot/utils.py: 38%
465 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/plot/utils.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
23"""Utils for plotting. Ideas and Code taken from the matplotlib documentation.
24Menu: https://matplotlib.org/3.1.0/gallery/widgets/menu.html
25LassoSelector: https://matplotlib.org/3.1.1/gallery/widgets/lasso_selector_demo_sgskip.html
26PolygonSelector: https://matplotlib.org/3.1.3/gallery/widgets/polygon_selector_demo.html
27Bezier: https://gist.github.com/gavincangan/b88a978e878e9bb1c0f8804e3af8de3c
29"""
31##############################################################################
32# Imports
33##############################################################################
35import os
36import shutil
38import matplotlib as mpl
39import matplotlib.pyplot as plt
40import numpy as np
41from matplotlib.artist import Artist
42from matplotlib.lines import Line2D
43from matplotlib.patches import Rectangle
44from matplotlib.path import Path
45from matplotlib.widgets import LassoSelector, PolygonSelector, RectangleSelector
46from scipy.special import binom
47from tqdm import tqdm
49from .._optional_imports import _optional_import
50from ..encodermap_tf1.backmapping import dihedral_backmapping
51from ..misc.clustering import gen_dummy_traj, get_cluster_frames
52from ..misc.misc import _datetime_windows_and_linux_compatible
53from .plotting import plot_cluster
55##############################################################################
56# Optional Imports
57##############################################################################
60jinja2 = _optional_import("jinja2")
61ngl = _optional_import("nglview")
62mda = _optional_import("MDAnalysis")
63md = _optional_import("mdtraj")
64pd = _optional_import("pandas")
65binom = _optional_import("scipy", "special.binom")
67##############################################################################
68# Globals
69##############################################################################
# Names intended as the public API of this module.
# NOTE(review): the variable is spelled `_all__` (single leading underscore),
# so Python does not treat it as the module's `__all__` export list. Confirm
# whether `__all__` was intended before renaming, because that would change
# what `from ... import *` exports.
_all__ = [
    "Props",
    "SelectFromCollection",
    "StatusMenu",
    "Menu",
    "ModeButton",
    "MenuItem",
]
81##############################################################################
82# Functions (mainly for ternary plotting)
83##############################################################################
def calculate_dssps(trajs, simplified=True):
    """Calculates dssp from TrajEnsemble.

    Calculates dssps from ep.TrajEnsemble and returns a numpy array.

    Args:
        trajs (ep.TrajEnsemble): The TrajEnsemble class to calculate the dssps from.
        simplified (bool, optional): Whether to return [H, E, C] dssp values (True),
            or [H, B, E, G, I, T, S, ''] dssp values (False). Defaults to True.
            The flag is forwarded to ``md.compute_dssp``.

    Returns:
        np.ndarray: The dssp array of shape (trajs.n_frames, trajs.n_residues);

    ToDo:
        Make it work with different protein lengths.

    """
    all_dssp = []
    # The context manager closes the progress bar even if compute_dssp raises.
    with tqdm(total=trajs.n_trajs, position=0, leave=True) as pbar:
        for traj in trajs:
            all_dssp.append(md.compute_dssp(traj, simplified=simplified))
            pbar.update()
    # One block of rows per trajectory, stacked along the frame axis.
    return np.vstack(all_dssp)
def abc_to_rgb(A=0.0, B=0.0, C=0.0):
    """Convert three weights A, B, C (each expected in [0, 1]) to an RGB tuple.

    Every channel is the sum of the two *other* weights, capped at 1.0:
    red = B + C, green = A + C, blue = A + B.
    """
    ceiling = 1.0
    red = min(B + C, ceiling)
    green = min(A + C, ceiling)
    blue = min(A + B, ceiling)
    return red, green, blue
def digitize_dssp(trajs, dssp, imshow=True, bins=100, progbar=None):
    """Digitizes the DSSP array.

    For every frame the fraction of residues per DSSP letter is computed and
    mapped to a color with `abc_to_rgb`. Depending on `imshow`, either these
    per-frame colors are returned, or a (bins, bins, 3) image in which every
    bin of the 2D low-dimensional projection holds the color of the mean
    fractions of all frames falling into that bin.

    Args:
        trajs: Object providing `lowd`; `trajs.lowd[:, 0]` and
            `trajs.lowd[:, 1]` are used as x and y coordinates when `imshow`
            is True.
        dssp (np.ndarray): 2D array of DSSP letters; `dssp.shape[1]` is used
            as the number of residues.
        imshow (bool, optional): Whether to return the binned image (True)
            or the per-frame colors (False). Defaults to True.
        bins (int, optional): Number of bins per axis. Defaults to 100.
        progbar (tqdm, optional): Progress bar to update. If None, a new one
            is created. Defaults to None.

    Returns:
        np.ndarray: The binned color image of shape (bins, bins, 3), with
            empty bins set to white, if `imshow` is True. Otherwise the array
            of per-frame colors.

    Note:
        The color mapping uses `abc_to_rgb`, which takes exactly three
        values, so only the three-letter alphabet [H, E, C] can be colored.

    """
    if progbar is None:
        progbar = tqdm(total=dssp.shape[0] * 4, position=0, leave=True)
    if np.any(np.isin(dssp, ["H", "E", "C"], assume_unique=False)) and "B" not in dssp:
        sorted_ = np.array(["H", "E", "C"])
    else:
        sorted_ = np.array(["H", "B", "E", "G", "I", "T", "S", ""])
    n_letters = len(sorted_)
    num_aas = dssp.shape[1]
    uniques = correct_missing_uniques(
        [np.concatenate(np.unique(d, return_counts=True)) for d in dssp],
        sorted_=sorted_,
        progbar=progbar,
    )
    # Each row is [letters..., counts...]; split at the alphabet size.
    uniques, counts = uniques[:, :n_letters], uniques[:, n_letters:]
    # Reorder the counts of every frame so they follow the order of sorted_.
    # NOTE(review): u.argsort()[sorted_.argsort()] only yields that order when
    # sorted_.argsort() is its own inverse (true for [H, E, C]) - confirm
    # before relying on the eight-letter alphabet.
    indices = np.vstack([u.argsort()[sorted_.argsort()] for u in uniques])
    counts = (
        np.vstack([c[i] for c, i in zip(counts, indices)]).astype(float) / num_aas
    )
    colors = np.array([abc_to_rgb(a, b, c) for a, b, c in zip(*counts.T)])
    if not imshow:
        return colors

    print("digitizing")
    # Start from a white image; only populated bins get a color.
    digitized = np.full((bins, bins, 3), (1.0, 1.0, 1.0))
    x = trajs.lowd[:, 0]
    y = trajs.lowd[:, 1]
    _, xedges, yedges = np.histogram2d(x=x, y=y, bins=bins)
    for i in range(bins):
        # The x-mask does not depend on j, so compute it once per column.
        in_x = (x >= xedges[i]) & (x < xedges[i + 1])
        for j in range(bins):
            where = np.where(in_x & ((y >= yedges[j]) & (y < yedges[j + 1])))[0]
            if len(where) != 0:
                counts_ = np.mean(counts[where], axis=0)
                try:
                    color = abc_to_rgb(*counts_)
                except TypeError:
                    # Print the offending values before re-raising.
                    print(counts[where], counts_)
                    raise
                digitized[i, j] = color
            progbar.update()
    return digitized
def correct_missing_uniques(uniques, sorted_, progbar=None):
    """Takes a list of DSSP letters and counts and adds zeros for all missing letters in sorted_.

    The input list is not modified; a new stacked array is returned.

    Args:
        uniques (list of np.ndarray): A list of np.ndarrays. For every frame the DSSP letters
            and their respective counts should be provided. So for example
            ['H', 151] or ['H', 'E', 75, 76].
        sorted_ (np.ndarray): The sorted list of DSSP letters. This list will be used to update every
            frame. So using ['H', 'E', 'C'], the two examples from above will become
            ['H', 'E', 'C', 151, 0, 0] or ['H', 'E', 'C', 75, 76, 0], respectively.
        progbar (tqdm): A tqdm progbar. It is updated once per frame. If None,
            a new one is created. Defaults to None.

    Returns:
        np.ndarray: The corrected uniques, one row per frame. Rows that already
            contain all letters are passed through unchanged (in their own
            letter order).

    """
    if progbar is None:
        progbar = tqdm(total=len(uniques), position=0, leave=True)
    n_letters = len(sorted_)
    corrected = []
    for row in uniques:
        if len(row) != n_letters * 2:
            # First half of the row holds the letters, second half the counts.
            letters, counts = np.split(row, 2)
            counts = counts.astype(int)
            filled = np.zeros(n_letters, dtype=int)
            for letter, count in zip(letters, counts):
                filled[np.where(sorted_ == letter)] = count
            row = np.concatenate([sorted_, filled.astype(str)])
        corrected.append(row)
        progbar.update()
    return np.vstack(corrected)
204def _get_system_info():
205 import getpass
206 import platform
207 import re
208 import socket
209 import uuid
211 import psutil
213 info = {}
214 try:
215 info["platform"] = platform.system()
216 info["system_user"] = getpass.getuser()
217 info["platform_release"] = platform.release()
218 info["platform_version"] = platform.version()
219 info["architecture"] = platform.machine()
220 info["hostname"] = socket.gethostname()
221 info["ip_address"] = socket.gethostbyname(socket.gethostname())
222 info["mac_address"] = ":".join(re.findall("..", "%012x" % uuid.getnode()))
223 info["processor"] = platform.processor()
224 info["ram"] = str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB"
225 return info
226 except Exception:
227 return info
230def _check_all_templates_defined(template, info_dict):
231 import re
233 regex = r"\{(.*?)\}"
234 matches = re.finditer(regex, template, re.MULTILINE | re.DOTALL)
235 min_matches = []
236 for matchNum, match in enumerate(matches):
237 for groupNum in range(0, len(match.groups())):
238 min_matches.append(match.group(groupNum))
239 min_matches = list(
240 set(map(lambda x: x.lstrip("{{").rstrip("}}"), [i for i in min_matches]))
241 )
242 if all(key in info_dict for key in min_matches): 242 ↛ 245line 242 didn't jump to line 245, because the condition on line 242 was never false
243 return True
244 else:
245 missing = set(min_matches).difference(info_dict)
246 raise Exception(
247 f"Not all expressions defined in template. Missing expressions: {missing}"
248 )
def _create_readme(main_path, now, info_dict):  # pragma: no cover
    """Renders the README template and writes it to `main_path`/README.md.

    `info_dict` is extended in place with the keys 'pip_freeze', 'filename',
    'now', 'encodermap_version' and the output of `_get_system_info()`
    before it is passed to the template.

    Args:
        main_path (str): Directory in which README.md is created.
        now: Timestamp value stored under the key 'now'.
        info_dict (dict): Values for the template. Updated in place.

    Raises:
        Exception: If the template contains expressions that are not defined
            in `info_dict` (raised by `_check_all_templates_defined`).

    """
    from pip._internal.operations import freeze

    from .._version import __version__
    from .jinja_template import template

    # update info dict
    md_file = os.path.join(main_path, "README.md")
    pip_freeze = "".join(f" {i}\n" for i in freeze.freeze())
    info_dict.update(
        {
            "pip_freeze": pip_freeze,
            # splitext only removes the extension; splitting on the first "."
            # would truncate paths that contain dots in a directory name
            "filename": os.path.splitext(md_file)[0],
            "now": now,
            "encodermap_version": __version__,
        }
    )
    info_dict.update(_get_system_info())

    # check (called directly, because an `assert` is stripped under `python -O`)
    _check_all_templates_defined(template, info_dict)

    # jinja2
    msg = jinja2.Template(template).render(
        info_dict,
    )

    # write
    with open(md_file, "w", encoding="utf-8") as f:
        f.write(msg)
def _match_tops_and_trajs(min_topfiles, list_of_frames):
    """Groups frames by the topology file their topology was loaded from.

    Every file in `min_topfiles` is loaded with `md.load` and its `.top` is
    used as lookup key for the `.top` of the frames.

    Args:
        min_topfiles (list): Topology files.
        list_of_frames (list): Objects with a `.top` attribute.

    Returns:
        dict: Maps every topology file to the list of frames (possibly
            empty) whose `.top` matches the topology of that file.

    """
    grouped = {}
    file_by_topology = {}
    for top_file in min_topfiles:
        grouped[top_file] = []
        file_by_topology[md.load(top_file).top] = top_file
    for frame in list_of_frames:
        grouped[file_by_topology[frame.top]].append(frame)
    return grouped
def _unpack_cluster_info(trajs, main_path, selector, dummy_traj, align_string):
    """Writes all files describing the most recently selected cluster to disk.

    The cluster is the set of points for which
    `trajs.CVs["user_selected_points"]` equals its maximum. All files are
    written to `main_path`/clusters/<timestamp>/.

    Args:
        trajs: Trajectory ensemble. Read attributes: `CVs`, `lowd`, `n_trajs`,
            `get_single_frame`, `basename_fn`; iterating yields objects with
            `traj_file`, `top_file`, `common_str` and `basename`.
        main_path (str): Directory in which the `clusters` directory is created.
        selector: Object with a `lasso` attribute. The vertices are only saved
            if `selector.lasso` is a matplotlib `PolygonSelector`.
        dummy_traj: Either a list of objects with `save_pdb`, or a single such
            object. Saved as pdb file(s).
        align_string: Forwarded to `get_cluster_frames`.

    Returns:
        tuple: (cluster id, path of the directory the files were written to).

    """
    import warnings

    def _insert_before_ext(path, extra):
        # Inserts `extra` between the root of `path` and its extension
        root, ext = os.path.splitext(path)
        return f"{root}{extra}{ext}"

    max_ = trajs.CVs["user_selected_points"].max()
    where = np.where(trajs.CVs["user_selected_points"] == max_)[0]
    now = _datetime_windows_and_linux_compatible()

    # make dirs
    os.makedirs(os.path.join(main_path, "clusters"), exist_ok=True)
    main_path = os.path.join(main_path, f"clusters/{now}")
    os.makedirs(main_path, exist_ok=True)

    # define names
    pdb_name = os.path.join(main_path, f"cluster_id_{max_}_stacked_10_structs.pdb")
    pdb_start_name = os.path.join(main_path, f"cluster_id_{max_}_start.pdb")
    pdb_origin_names = os.path.join(main_path, f"cluster_id_{max_}_pdb_origins.txt")
    xtc_name = os.path.join(main_path, f"cluster_id_{max_}.xtc")
    csv_name = os.path.join(main_path, f"cluster_id_{max_}_selected_points.csv")
    png_name = os.path.join(main_path, f"cluster_id_{max_}_image.png")
    lowd_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points.npy"
    )
    indices_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points_indices.npy"
    )
    current_clustering = os.path.join(
        main_path,
        f"cluster_id_{max_}_cluster_current_clustering_user_selected_points.npy",
    )
    selector_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_selector_points.npy"
    )
    parents_trajs = os.path.join(
        main_path, f"cluster_id_{max_}_all_plotted_trajs_in_correct_order.txt"
    )

    # save edges of selector
    if isinstance(selector.lasso, PolygonSelector):
        verts = np.vstack([selector.lasso._xs, selector.lasso._ys]).T
        # Only the file name carries the selector type. A str.replace on the
        # full path would also rewrite directories containing "selector".
        selector_kind = selector.lasso.__class__.__name__.lower()
        selector_npy_name = os.path.join(
            main_path, f"cluster_id_{max_}_{selector_kind}_points.npy"
        )
        np.save(selector_npy_name, verts)
    else:
        warnings.warn("Can currenlty only save the vertices of Polygon.")

    # save pdb
    # if list is provided the pdbs of the list need to be saved
    if isinstance(dummy_traj, list):
        for i, traj in enumerate(dummy_traj):
            traj.save_pdb(_insert_before_ext(pdb_name, f"_struct_{i}"))
    else:
        dummy_traj.save_pdb(pdb_name)

    # render png
    plot_cluster(trajs, pdb_name, png_name, max_)

    # save all trajs
    with open(parents_trajs, "w") as f:
        for traj in trajs:
            f.write(
                f"{os.path.abspath(traj.traj_file)} {os.path.abspath(traj.top_file)} {traj.common_str}\n"
            )

    # create df
    n_lowd = trajs.lowd.shape[-1]
    if n_lowd == 2:
        lowd_keys = ["x", "y"]
    elif n_lowd == 3:
        lowd_keys = ["x", "y", "z"]
    else:
        lowd_keys = [f"lowd_{i}" for i in range(n_lowd)]
    unique_id_key = f"unique id in set of {trajs.n_trajs} trajs"
    columns = [
        "trajectory file",
        "topology file",
        "frame number",
        "time",
        *lowd_keys,
        "cluster id",
        "trajectory number",
        unique_id_key,
    ]
    # Rows are collected first and the frame is built once, because
    # DataFrame.append is not available in current pandas.
    rows = []
    for w in where:
        frame = trajs.get_single_frame(w)
        if frame.traj_num is None:
            traj_frame = frame.id[0]
        else:
            traj_frame = frame.id[0, 1]
        rows.append(
            {
                "trajectory file": os.path.abspath(frame.traj_file),
                "topology file": os.path.abspath(frame.top_file),
                "frame number": traj_frame,
                "time": frame.time[0],
                "cluster id": max_,
                "trajectory number": frame.traj_num,
                unique_id_key: w,
                **{k: v for k, v in zip(lowd_keys, frame.lowd)},
            }
        )
    df = pd.DataFrame(rows, columns=columns)
    df = df.astype(
        dtype={
            "trajectory file": str,
            "topology file": str,
            "frame number": int,
            "time": float,
            **{k: float for k in lowd_keys},
            "cluster id": int,
            "trajectory number": int,
            unique_id_key: int,
        }
    )
    df.to_csv(csv_name, index=False)

    # save npy
    np.save(lowd_npy_name, trajs.CVs["user_selected_points"][where])
    np.save(indices_npy_name, where)
    np.save(current_clustering, trajs.CVs["user_selected_points"])

    # save full traj
    _, dummy_traj = get_cluster_frames(
        trajs,
        max_,
        nglview=False,
        shorten=False,
        stack_atoms=True,
        col="user_selected_points",
        align_string=align_string,
    )
    min_topfiles = list(set([i.top_file for i in trajs]))
    top_traj_dict = _match_tops_and_trajs(min_topfiles, dummy_traj)
    basenames = [trajs.basename_fn(i) for i in top_traj_dict.keys()]
    with open(pdb_origin_names, "w") as f:
        for i, (key, value) in enumerate(top_traj_dict.items()):
            _pdb_start_name = _insert_before_ext(
                pdb_start_name, f"_traj_{i}_from_{basenames[i]}"
            )
            _xtc_name = _insert_before_ext(xtc_name, f"_traj_{i}")
            shutil.copyfile(key, _pdb_start_name)
            f.write(
                f"{_pdb_start_name} is a copy (`shutil.copyfile`) of {key}. The corresponding trajectory files might originate from other places. Refer to {parents_trajs} for info about xtcs.\n"
            )
            if value:
                # join all frames belonging to this topology into one traj
                traj_out = value[0]
                for frame in value[1:]:
                    traj_out = traj_out.join(frame)
                traj_out.save_xtc(_xtc_name)

    # create an info dict
    info_dict = {
        "pdb_name": pdb_name,
        "pdb_start_name": pdb_start_name,
        "pdb_origin_names": pdb_origin_names,
        "xtc_name": xtc_name,
        "csv_name": csv_name,
        "png_name": png_name,
        "lowd_npy_name": lowd_npy_name,
        "indices_npy_name": indices_npy_name,
        "current_clustering": current_clustering,
        "selector_npy_name": selector_npy_name,
        "parents_trajs": parents_trajs,
    }
    # the readme only refers to the file names, not to the full paths
    info_dict = {k: os.path.basename(v) for k, v in info_dict.items()}
    info_dict.update(
        {
            "cluster_id": max_,
            "n_trajs": trajs.n_trajs,
            "cluster_n_points": len(where),
            "basename": trajs[0].basename,
            "cluster_abspath": main_path,
        }
    )

    # create a readme
    _create_readme(main_path, now, info_dict)

    return max_, main_path
470def _unpack_path_info(path):
471 pass
474##############################################################################
475# Classes
476##############################################################################
class Props:
    """Class to contain the properties of MenuItems.

    Each MenuItem contains two copies of this class.
    One for props when the mouse cursor hovers over them.
    One for the rest of times.
    Can be used as class or as dict.

    Attributes:
        labelcolor (str): The color of the text.
        labelcolor_rgb (tuple): The RGBA value of labelcolor.
        bgcolor (str): The color of the background.
        bgcolor_rgb (tuple): The RGBA value of bgcolor.
        fontsize (int): The fontsize.
        alpha (float): The alpha value of the background.
        defaults (dict): Class Variable of default values.

    Examples:
        >>> props = Props()
        >>> props.labelcolor
        'black'
        >>> props["fontsize"] = 20
        >>> props["fontsize"]
        20

    """

    defaults = {"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2}

    def __init__(self, **kwargs):
        """Instantiate the Props class.

        Takes a dict as input and overwrites the class defaults. The values are
        stored as attributes and can be accessed via dot-notation.

        Args:
            **kwargs: Dict containing values. If unknown values are passed they
                will be dropped (a message is printed for each of them).

        """
        self._setattr(self.defaults)
        for key, value in kwargs.items():
            # Only keys that were set from the defaults are known.
            if key in vars(self):
                setattr(self, key, value)
            else:
                print(f"Dropping unknown dict entry for {{'{key}': {value}}}")

    @property
    def bgcolor_rgb(self):
        """tuple: The RGBA value of `bgcolor`."""
        return mpl.colors.to_rgba(self.bgcolor)

    @property
    def labelcolor_rgb(self):
        """tuple: The RGBA value of `labelcolor`."""
        return mpl.colors.to_rgba(self.labelcolor)

    def _setattr(self, dictionary):
        """Sets every key of `dictionary` as an attribute of this instance."""
        for key, value in dictionary.items():
            setattr(self, key, value)

    def __setitem__(self, key, value):
        """Dict-style write access: `props[key] = value`."""
        setattr(self, key, value)

    # The method used to be available only under this misspelled name, which
    # Python never calls for item assignment. Kept as an alias for callers
    # that used it directly.
    __setitiem__ = __setitem__

    def __getitem__(self, item):
        """Dict-style read access: `props[item]`."""
        return getattr(self, item)

    def _string_summary(self):
        return "ep.plotting.interactive.Props class containing Info for the MenuItem class."

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return f"<{self._string_summary()} Object at 0x{id(self):02x}>"
class BezierBuilder:
    """Interactive builder for a Bézier curve.

    Every mouse click inside the axes of the control polygon adds a control
    point; the curve (and optionally a plot of the Bernstein basis) is
    rebuilt after each click.
    """

    def __init__(self, control_polygon, ax_main, ax_bernstein=None):
        """Stores the initial control polygon and hooks into the canvas.

        Args:
            control_polygon: Line whose x/y data are the initial control
                points and whose figure canvas receives the click events.
            ax_main: Axes the Bézier curve is added to.
            ax_bernstein: Optional axes for the Bernstein basis functions.
        """
        self.control_polygon = control_polygon
        self.ax_main = ax_main
        self.ax_bernstein = ax_bernstein
        self.canvas = control_polygon.figure.canvas
        self.xp = list(control_polygon.get_xdata())
        self.yp = list(control_polygon.get_ydata())

        # The instance itself is the handler for mouse clicks.
        self.cid = self.canvas.mpl_connect("button_press_event", self)

        # Empty line in the color of the polygon markers; holds the curve.
        self.bezier_curve = self.ax_main.add_line(
            Line2D([], [], c=control_polygon.get_markeredgecolor())
        )

    def __call__(self, event):
        """Click handler: adds the clicked point and redraws."""
        if event.inaxes != self.control_polygon.axes:
            # Click happened outside the axes of the control polygon.
            return

        self.xp.append(event.xdata)
        self.yp.append(event.ydata)
        self.control_polygon.set_data(self.xp, self.yp)

        curve_x, curve_y = self._build_bezier()
        self.bezier_curve.set_data(curve_x, curve_y)
        self._update_bernstein()
        self._update_bezier()

    def _build_bezier(self):
        """Returns x and y values of the curve through the control points."""
        control_points = list(zip(self.xp, self.yp))
        xs, ys = Bezier(control_points).T
        return xs, ys

    def _update_bezier(self):
        """Redraws the canvas."""
        self.canvas.draw()

    def _update_bernstein(self):
        """Replots all Bernstein basis functions, if axes were provided."""
        ax = self.ax_bernstein
        if ax is None:
            return
        N = len(self.xp) - 1
        t = np.linspace(0, 1, num=200)
        ax.clear()
        for kk in range(N + 1):
            ax.plot(t, Bernstein(N, kk)(t))
        ax.set_title(f"Bernstein basis, N = {N}")
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

    @property
    def ind(self):
        """Points of the current curve, one (x, y) row per point."""
        return np.vstack(self.bezier_curve.get_data()).T

    def disconnect(self):
        """Removes the click handler from the canvas."""
        self.canvas.mpl_disconnect(self.cid)
def Bernstein(n, k):
    """Returns the k-th Bernstein basis polynomial of degree n as a callable.

    The returned function evaluates binom(n, k) * x**k * (1 - x)**(n - k).
    """
    coeff = binom(n, k)
    remaining = n - k

    def _bpoly(x):
        rising = x**k
        falling = (1 - x) ** remaining
        return coeff * rising * falling

    return _bpoly
def Bezier(points, num=200):
    """Evaluates the Bézier curve defined by the control `points`.

    Args:
        points: Sequence of (x, y) control points.
        num (int): Number of curve points, evenly spaced in the curve
            parameter from 0 to 1. Defaults to 200.

    Returns:
        np.ndarray: Array of shape (num, 2) with the points of the curve.

    """
    degree = len(points) - 1
    t = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    # Sum of the control points weighted by their Bernstein basis function.
    for idx, point in enumerate(points):
        weights = Bernstein(degree, idx)(t)
        curve += np.outer(weights, point)
    return curve
class DummyTool:
    """Stand-in tool that only offers a `disconnect` method doing nothing."""

    def disconnect(self):
        """Does nothing and returns None."""
        return None
class SelectFromCollection:
    """Select indices from a matplotlib collection using a selector widget.

    The indices of the selected points are stored in the `ind` attribute.
    Points outside of the selection are faded out by lowering their alpha
    value. If the collection has alpha < 1, the alpha values are altered
    permanently by this tool.

    The collection objects are selected based on their *origins*
    (i.e., `offsets`).

    Args:
        ax (matplotlib.axes.Axes): Axes to interact with.
        collection (matplotlib.collections.Collection): Subclass of collection
            you want to select from.
        alpha_other (float): Alpha value given to all non-selected points;
            selected points get an alpha value of 1. Needs to fulfill
            0 <= alpha_other <= 1.
        selector: Selector class, called as
            `selector(ax, onselect=..., useblit=False)`. Defaults to
            `LassoSelector`.

    """

    def __init__(self, ax, collection, alpha_other=0.3, selector=LassoSelector):
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.Npts = len(self.xys)

        # Every point needs its own color row, so alpha can be set per point.
        self.fc = collection.get_facecolors()
        n_colors = len(self.fc)
        if n_colors == 0:
            raise ValueError("Collection must have a facecolor")
        if n_colors == 1:
            self.fc = np.tile(self.fc, (self.Npts, 1))

        self.lasso = selector(ax, onselect=self.onselect, useblit=False)
        self.ind = []

    def onselect(self, verts):
        """Callback of the selector: highlights the points inside `verts`."""
        inside = Path(verts).contains_points(self.xys)
        self.ind = np.nonzero(inside)[0]
        # Fade everything, then restore full alpha for the selection.
        self.fc[:, -1] = self.alpha_other
        self.fc[self.ind, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def disconnect(self):
        """Disconnects the selector and resets all alpha values to 1."""
        self.lasso.disconnect_events()
        self.fc[:, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def _string_summary(self):
        parts = [
            "encodermap.SelectFromCollection object. Selecting points ",
            f"from a set of {self.Npts} points. The current selector tool ",
            f"is matplotlotlib's {self.lasso} tool. The selected points are ",
            f"{self.ind}",
        ]
        return "".join(parts)

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return self._string_summary()
class MenuItem(mpl.patches.Rectangle, mpl.artist.Artist):
    """A clickable rectangle with a centered text label.

    The item keeps two `Props` instances: `props` for the normal appearance
    and `hoverprops` for the appearance while the cursor hovers over it.

    Args:
        fig: Figure whose canvas delivers the mouse events.
        xy: Anchor point of the rectangle.
        width: Width of the rectangle.
        height: Height of the rectangle.
        labelstr (str): Text of the label.
        props (dict, optional): Overrides for the normal appearance.
            Defaults to None (no overrides).
        hoverprops (dict, optional): Overrides for the hover appearance.
            Defaults to None (no overrides).
        on_select (callable, optional): Called with this item when a mouse
            button is released over it. Defaults to None.
        standalone (bool): If True, the item connects to motion events and
            redraws the canvas itself when its hover state changes. If False,
            something else (e.g. a Menu) has to call `set_hover`.

    """

    def __init__(
        self,
        fig,
        xy,
        width,
        height,
        labelstr,
        props=None,
        hoverprops=None,
        on_select=None,
        standalone=False,
    ):
        # None instead of a shared mutable dict as default argument
        if props is None:
            props = {}
        if hoverprops is None:
            hoverprops = {}

        # define props and hoverprops
        self.labelstr = labelstr
        self.standalone = standalone
        props = {
            **{"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2},
            **props,
        }
        hoverprops = {
            **{"labelcolor": "C0", "bgcolor": "yellow", "fontsize": 15, "alpha": 0.2},
            **hoverprops,
        }
        self.props = Props(**props)
        self.hoverprops = Props(**hoverprops)

        # set the on_select method
        self.on_select = on_select

        # add a select event
        # hover event is only included if standalone is true
        # i.e. the MenuItem is not instantiated from a Menu
        # In the case this is instantiated from a menu, the menu will handle the hover
        self.cid_button = fig.canvas.mpl_connect(
            "button_release_event", self.check_select
        )
        if self.standalone:
            self.cid_move = fig.canvas.mpl_connect(
                "motion_notify_event", self.set_hover
            )

        # Instantiate text and Rectangle
        mpl.patches.Rectangle.__init__(
            self, xy, width, height, label=self.labelstr, zorder=1
        )
        self.text = mpl.text.Text(
            0,
            0,
            self.labelstr,
            color=self.props.labelcolor,
            fontproperties=dict(weight="bold", size=self.props.fontsize),
            zorder=2,
            verticalalignment="center",
            horizontalalignment="center",
        )

        # final thing to do is set the props of
        # the rectangle based on whether a cursor hovers
        self.set_hover_props(False)

        # after artist has been placed get it ready
        # for accepting hover events
        self.hover = False

    def set_figure(self, figure):
        """Overwriting Base Class method to include labelstr"""
        mpl.patches.Rectangle.set_figure(self, figure)
        self.text.set_figure(figure)

    def set_axes(self, axes):
        """Overwriting Base Class method to include labelstr"""
        # NOTE(review): relies on Rectangle/Text providing `set_axes` -
        # confirm this exists in the supported matplotlib versions.
        mpl.patches.Rectangle.set_axes(self, axes)
        self.text.set_axes(axes)

    def set_transform(self, transform):
        """Overwriting Base Class method to include labelstr"""
        mpl.patches.Rectangle.set_transform(self, transform)
        # set text to center of self(.rect)
        bbox = self.get_bbox()
        x = bbox.x0 + 0.5 * bbox.width
        y = bbox.y0 + 0.5 * bbox.height
        texttrans = mpl.transforms.Affine2D().translate(x, y) + self.axes.transData
        self.text.set_transform(texttrans)

    def set_data(self, x, y):
        """Overwriting Base Class method to include labelstr"""
        # NOTE(review): relies on Rectangle providing `set_data` - confirm
        # this method is reachable in the supported matplotlib versions.
        mpl.patches.Rectangle.set_data(self, x, y)
        self.text.set_color(self.props.labelcolor_rgb)

    def check_select(self, event):
        """Calls `on_select` if `event` happened over this item.

        Returns:
            True if the event was over the item, otherwise None.
        """
        over, _ = self.contains(event)
        if not over:
            return
        if self.on_select is not None:
            self.on_select(self)
        return True

    def draw(self, renderer):
        """Overwriting Base Class method to include labelstr"""
        # draw the rectangle first, then the label on top of it
        mpl.patches.Rectangle.draw(self, renderer)
        self.text.draw(renderer)

    def set_hover_props(self, check):
        """Applies `hoverprops` if `check` is true, otherwise `props`."""
        if check:
            props = self.hoverprops
        else:
            props = self.props

        self.set(facecolor=props.bgcolor_rgb, alpha=props.alpha)
        self.text.set_color(props.labelcolor_rgb)

    def set_hover(self, event):
        """Check the hover status of MenuItem

        Returns:
            bool: Whether the hover status changed with this event.
        """
        check, _ = self.contains(event)
        changed = check != self.hover
        if changed:
            self.set_hover_props(check)
        self.hover = check
        if changed and self.standalone:
            self.figure.canvas.draw()
        return changed
class ModeButton(MenuItem):
    """MenuItem that toggles a `pressed` state every time it is selected.

    While pressed, the button keeps its hover appearance.
    """

    def __init__(self, *args, **kwargs):
        """Takes the same arguments as `MenuItem`; starts unpressed."""
        super().__init__(*args, **kwargs)
        self.pressed = False

    def check_select(self, event, overwrite=False):
        """Toggles `pressed` if `event` happened over this button.

        Args:
            event: The mouse event.
            overwrite (bool): If True, nothing is done at all.

        Returns:
            The new `pressed` state, or None if nothing was done.
        """
        if overwrite or not self.contains(event)[0]:
            return
        callback = self.on_select
        if callback is not None:
            callback(self)
        self.pressed = not self.pressed
        if self.pressed:
            self.set_hover_props(True)
        return self.pressed

    def set_hover(self, event):
        """Check the hover status of MenuItem"""
        is_over = self.contains(event)[0]
        changed = is_over != self.hover
        # A pressed button keeps its hover look, regardless of the cursor.
        if changed and not self.pressed:
            self.set_hover_props(is_over)
        self.hover = is_over
        return changed

    def _on_select_rotation(self):
        """Old function to rotate labelstrings."""
        # NOTE(review): `status`, `labelrotation` and `label` are not set
        # anywhere in this class - confirm whether this is still in use.
        last_index = len(self.labelrotation) - 1
        self.status = 0 if self.status == last_index else self.status + 1
        self.labelstr = self.labelrotation[self.status]
        self.label.set_text(self.labelstr)
class Menu:
    """A vertical stack of `MenuItem` objects inside one axes.

    Args:
        ax: The axes the items are added to.
        items (list of str, optional): The labels of the items. Defaults to
            ["Reset", "Write", "Set Points"].

    Attributes:
        menuitems (dict): Maps every label to its item.
        cid: Connection id of the motion event used for hovering.

    """

    def __init__(self, ax, items=None):
        # None instead of a shared mutable list as default argument
        if items is None:
            items = ["Reset", "Write", "Set Points"]

        # suppresscomposite
        self.ax = ax
        self.fig = self.ax.get_figure()
        self.fig.suppressComposite = True

        # add the menu items
        self.menuitems = {}
        self.add_items(items=items)

        # mpl_connect move to set hover stuff
        self.cid = self.fig.canvas.mpl_connect("motion_notify_event", self.on_move)

    def on_move(self, event):
        """Updates the hover state of all items; redraws once if any changed."""
        # Every item has to see the event. Stopping at the first changed item
        # would leave the hover state of the remaining items stale when the
        # cursor jumps from one item to another within a single event.
        changed = [item.set_hover(event) for item in self.menuitems.values()]
        if any(changed):
            self.fig.canvas.draw()

    def add_items(self, items):
        """Creates one `MenuItem` per label and adds it to the axes."""

        def on_select(item):
            return True

        for label, (lower, upper) in zip(items, self.get_coords(len(items))):
            item = MenuItem(
                self.fig, (0, lower), 1, upper - lower, labelstr=label, on_select=on_select
            )
            self.menuitems[label] = item
            self.ax.add_artist(item)

    def get_coords(self, no_items, gap_space=0.05):
        """Splits the interval [0, 1] into `no_items` equal spans separated by gaps.

        Args:
            no_items (int): Number of spans.
            gap_space (float): Size of the gap between two spans.

        Returns:
            list: One [start, end] pair per item; empty if `no_items` is
                not positive.
        """
        if no_items <= 0:
            return []
        length_per = (1 - (no_items - 1) * gap_space) / no_items
        coords = []
        start = 0
        for _ in range(no_items):
            coords.append([start, start + length_per])
            start = start + length_per + gap_space
        return coords
class StatusMenu(Menu):
    """Menu of mode buttons of which one is the current `status`.

    All entries except "Mode" are `ModeButton` objects; "Mode" is a plain
    `MenuItem` with its own fixed colors.
    """

    def __init__(self, ax):
        labels = [
            "Lasso",
            "Rectangle",
            "Ellipse",
            "Polygon",
            "Path",
            "Bezier",
            "Idle",
            "Mode",
        ]
        # The parent constructor provides get_coords() and calls the
        # add_items() method overwritten below.
        super().__init__(ax, items=labels)

        # start in the idle state
        self.set_idle()

        # Button releases change self.status and switch ModeButtons off.
        # NOTE(review): this replaces the motion-event id stored in self.cid
        # by the parent constructor - confirm that id is not needed later.
        self.cid = self.fig.canvas.mpl_connect("button_release_event", self.on_click)

    def set_idle(self):
        """Sets the idle status. Called at __init__ and
        when nothing is pressed."""
        self.status = "Idle"
        idle_item = self.menuitems[self.status]
        idle_item.set_hover_props(True)
        idle_item.pressed = True

    def on_click(self, event):
        """Switches the status to a newly pressed item, or back to Idle."""
        for key, item in self.menuitems.items():
            if item.pressed and key != self.status:
                print(f"Changing Mode to {key}")
                # release the item that held the status so far
                previous = self.menuitems[self.status]
                previous.set_hover_props(False)
                previous.pressed = False
                previous.check_select(event, overwrite=True)
                self.status = key
                continue
            if not any(entry.pressed for entry in self.menuitems.values()):
                print("Changing Mode to Idle")
                self.set_idle()

    def add_items(self, items):
        """Creates the items: a `MenuItem` for "Mode", `ModeButton` otherwise."""
        spans = self.get_coords(len(items))
        for label, span in zip(items, spans):
            lower, upper = span[0], span[1]
            if label == "Mode":
                # same fixed look, whether hovered or not
                mode_look = {
                    "labelcolor": "black",
                    "bgcolor": "orange",
                    "fontsize": 15,
                    "alpha": 1,
                }
                item = MenuItem(
                    self.fig,
                    (0, lower),
                    1,
                    upper - lower,
                    labelstr=label,
                    on_select=lambda item: None,
                    props=dict(mode_look),
                    hoverprops=dict(mode_look),
                )
                item.pressed = False
            else:
                item = ModeButton(
                    self.fig,
                    (0, lower),
                    1,
                    upper - lower,
                    labelstr=label,
                    on_select=lambda item: True,
                )
            self.menuitems[label] = item
            self.ax.add_artist(item)