Coverage for encodermap/plot/utils.py: 15%
464 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/plot/utils.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
23"""Utils for plotting. Ideas and Code taken from the matplotlib documentation.
24Menu: https://matplotlib.org/3.1.0/gallery/widgets/menu.html
25LassoSelector: https://matplotlib.org/3.1.1/gallery/widgets/lasso_selector_demo_sgskip.html
26PolygonSelector: https://matplotlib.org/3.1.3/gallery/widgets/polygon_selector_demo.html
27Bezier: https://gist.github.com/gavincangan/b88a978e878e9bb1c0f8804e3af8de3c
29"""
31################################################################################
32# Imports
33################################################################################
35# Future Imports at the top
36from __future__ import annotations
38# Standard Library Imports
39import copy
40import os
41import shutil
42import warnings
44# Third Party Imports
45import matplotlib as mpl
46import matplotlib.pyplot as plt
47import numpy as np
48from matplotlib.artist import Artist
49from matplotlib.lines import Line2D
50from matplotlib.patches import Rectangle
51from matplotlib.path import Path
52from matplotlib.widgets import LassoSelector, PolygonSelector, RectangleSelector
53from optional_imports import _optional_import
54from packaging import version
55from scipy.special import binom
56from tqdm import tqdm
58# Encodermap imports
59from encodermap.misc.misc import _datetime_windows_and_linux_compatible
62################################################################################
63# Optional Imports
64################################################################################
67jinja2 = _optional_import("jinja2")
68ngl = _optional_import("nglview")
69mda = _optional_import("MDAnalysis")
70md = _optional_import("mdtraj")
71pd = _optional_import("pandas")
72binom = _optional_import("scipy", "special.binom")
73plotly_lasso = _optional_import("plotly", "callbacks.LassoSelector")
76################################################################################
77# Typing
78################################################################################
81# Standard Library Imports
82from typing import TYPE_CHECKING, Any, Optional, Union
85if TYPE_CHECKING:
86 # Local Folder Imports
87 from ..trajinfo.info_all import TrajEnsemble
90################################################################################
91# Globals
92################################################################################
# NOTE(review): this name has a single leading underscore (`_all__`), so Python
# does not treat it as the module's `__all__`; a star-import of this module
# therefore exports every public name. Renaming it to `__all__` would restrict
# star-imports to the six names below -- confirm that no caller relies on the
# other public names before changing it.
_all__ = [
    "Props",
    "SelectFromCollection",
    "StatusMenu",
    "Menu",
    "ModeButton",
    "MenuItem",
]
104################################################################################
105# Functions (mainly for ternary plotting)
106################################################################################
def abc_to_rgb(A=0.0, B=0.0, C=0.0):
    """Convert three weights into a (red, green, blue) tuple.

    Each color channel is the sum of the two *other* weights, capped at 1.0:
    red is built from ``B`` and ``C``, green from ``A`` and ``C`` and blue
    from ``A`` and ``B``. The inputs are expected to lie in ``[0, 1]``.

    """
    red = min(B + C, 1.0)
    green = min(A + C, 1.0)
    blue = min(A + B, 1.0)
    return (red, green, blue)
def digitize_dssp(trajs, dssp, imshow=True, bins=100, progbar=None):
    """Turn per-frame DSSP letters into RGB colors, optionally binned on a 2D grid.

    For every frame (row of ``dssp``) the occurrences of each DSSP letter are
    counted, divided by the number of residues (``dssp.shape[1]``) and passed
    to ``abc_to_rgb``.

    Args:
        trajs: Object exposing ``lowd``; its first two columns are used as the
            x and y coordinates of the frames.
        dssp (np.ndarray): 2D array of DSSP letters, one row per frame.
        imshow (bool): If True, return a ``(bins, bins, 3)`` image in which
            every bin holds the color of the mean letter fractions of the
            frames falling into it (empty bins stay white). If False, return
            one color per frame.
        bins (int): Number of bins per axis of the image.
        progbar: Object with an ``update()`` method. A ``tqdm`` bar is created
            when None is given.

    Returns:
        np.ndarray: The binned image when ``imshow`` is True, otherwise the
        per-frame colors.

    """
    if progbar is None:
        # The original referenced an undefined name (`all_dssp`) here.
        progbar = tqdm(total=dssp.shape[0] * 4, position=0, leave=True)
    if np.any(np.isin(dssp, ["H", "E", "C"], assume_unique=False)) and "B" not in dssp:
        sorted_ = np.array(["H", "E", "C"])
    else:
        sorted_ = np.array(["H", "B", "E", "G", "I", "T", "S", ""])
    num_aas = dssp.shape[1]
    uniques = correct_missing_uniques(
        [np.concatenate(np.unique(d, return_counts=True)) for d in dssp],
        sorted_=sorted_,
        progbar=progbar,
    )
    # NOTE(review): the hard-coded 3 (and the 3-argument `abc_to_rgb`) only fits
    # the three-letter alphabet; the full-alphabet branch above looks
    # unsupported by the code below -- confirm.
    uniques, counts = uniques[:, :3], uniques[:, 3:]
    # Rows may list their letters in a different order than `sorted_`; these
    # indices bring the counts of every row into the order of `sorted_`.
    indices = np.vstack([u.argsort()[sorted_.argsort()] for u in uniques])
    # `np.float` was removed in NumPy 1.24; the builtin `float` is equivalent.
    counts = (
        np.vstack([c[i] for c, i in zip(counts, indices)]).astype(float) / num_aas
    )
    colors = np.array([abc_to_rgb(a, b, c) for a, b, c in zip(*counts.T)])
    if not imshow:
        return colors

    print("digitizing")
    digitized = np.full((bins, bins, 3), (1.0, 1.0, 1.0))
    x = trajs.lowd[:, 0]
    y = trajs.lowd[:, 1]
    # Only the bin edges are needed, the histogram itself is discarded.
    _, xedges, yedges = np.histogram2d(x=x, y=y, bins=bins)
    for i in range(bins):
        # The x-mask does not depend on j, so it is computed once per column.
        in_x = (x >= xedges[i]) & (x < xedges[i + 1])
        for j in range(bins):
            where = np.where(in_x & ((y >= yedges[j]) & (y < yedges[j + 1])))[0]
            if len(where) != 0:
                counts_ = np.mean(counts[where], axis=0)
                try:
                    color = abc_to_rgb(*counts_)
                except TypeError:
                    # Show the offending values before propagating the error.
                    print(counts[where], counts_)
                    raise
                digitized[i, j] = color
            progbar.update()
    return digitized
def correct_missing_uniques(uniques, sorted_, progbar=None):
    """Pad per-frame DSSP letter/count rows with zero counts for absent letters.

    Args:
        uniques (list of np.ndarray): One array per frame holding the DSSP
            letters followed by their counts, e.g. ``['H', 151]`` or
            ``['H', 'E', 75, 76]``. Rows that get padded are replaced in this
            list in place.
        sorted_ (np.ndarray): The DSSP letters every row should contain. With
            ``['H', 'E', 'C']`` the examples above become
            ``['H', 'E', 'C', 151, 0, 0]`` and ``['H', 'E', 'C', 75, 76, 0]``.
            Rows that already have ``2 * len(sorted_)`` entries are left
            untouched.
        progbar (tqdm): Progress bar that is advanced once per row. A new one
            is created when None is given. Defaults to None.

    Returns:
        np.ndarray: All rows stacked into a single array.

    """
    if progbar is None:
        progbar = tqdm(total=len(uniques), position=0, leave=True)
    n_letters = len(sorted_)
    for row_idx, row in enumerate(uniques):
        if len(row) != n_letters * 2:
            # first half: letters, second half: their counts
            present_letters, present_counts = np.split(row, 2)
            padded = np.zeros(n_letters, dtype=int)
            for letter, count in zip(present_letters, present_counts.astype(int)):
                padded[np.where(sorted_ == letter)] = count
            uniques[row_idx] = np.concatenate([sorted_, padded.astype(str)])
        progbar.update()
    return np.vstack(uniques)
199def _get_system_info() -> dict[str, Any]:
200 # Standard Library Imports
201 import getpass
202 import platform
203 import re
204 import socket
205 import uuid
207 # Third Party Imports
208 import psutil
210 info = {}
211 try:
212 info["platform"] = platform.system()
213 info["system_user"] = getpass.getuser()
214 info["platform_release"] = platform.release()
215 info["platform_version"] = platform.version()
216 info["architecture"] = platform.machine()
217 info["hostname"] = socket.gethostname()
218 info["ip_address"] = socket.gethostbyname(socket.gethostname())
219 info["mac_address"] = ":".join(re.findall("..", "%012x" % uuid.getnode()))
220 info["processor"] = platform.processor()
221 info["ram"] = str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB"
222 return info
223 except Exception:
224 return info
227def _check_all_templates_defined(template, info_dict):
228 # Standard Library Imports
229 import re
231 regex = r"\{(.*?)\}"
232 matches = re.finditer(regex, template, re.MULTILINE | re.DOTALL)
233 min_matches = []
234 for matchNum, match in enumerate(matches):
235 for groupNum in range(0, len(match.groups())):
236 min_matches.append(match.group(groupNum))
237 min_matches = list(
238 set(map(lambda x: x.lstrip("{{").rstrip("}}"), [i for i in min_matches]))
239 )
240 if all(key in info_dict for key in min_matches):
241 return True
242 else:
243 missing = set(min_matches).difference(info_dict)
244 raise Exception(
245 f"Not all expressions defined in template. Missing expressions: {missing}"
246 )
def _create_readme(main_path, now, info_dict):  # pragma: no cover
    """Render the README template and write it to ``<main_path>/README.md``.

    ``info_dict`` is extended in place with the output of ``pip freeze``, the
    README's filename without extension, ``now``, the EncoderMap version and
    the entries of ``_get_system_info`` before the template is rendered.

    Args:
        main_path (str): Directory into which the README is written.
        now: Value stored under the key "now" of ``info_dict``.
        info_dict (dict): Values for the placeholders of the template.

    Raises:
        Exception: If the template contains placeholders that are missing
            from ``info_dict`` (see ``_check_all_templates_defined``).

    """
    # Third Party Imports
    from pip._internal.operations import freeze

    # Local Folder Imports
    from .._version import get_versions
    from .jinja_template import template

    __version__ = get_versions()["version"]

    # update info dict
    md_file = os.path.join(main_path, "README.md")
    pip_freeze = "".join(f" {pkg}\n" for pkg in freeze.freeze())
    info_dict.update({"pip_freeze": pip_freeze})
    # `splitext` only removes the extension; splitting at the first "." would
    # truncate paths that contain dots in any of their directories.
    info_dict.update({"filename": os.path.splitext(md_file)[0]})
    info_dict.update({"now": now})
    info_dict.update({"encodermap_version": __version__})
    info_dict.update(_get_system_info())

    # check (raises on its own; an `assert` would be skipped with `python -O`)
    _check_all_templates_defined(template, info_dict)

    # jinja2
    msg = jinja2.Template(template).render(
        info_dict,
    )

    # write
    with open(md_file, "w") as f:
        f.write(msg)
def _unpack_cluster_info(
    trajs: TrajEnsemble,
    main_path: Union[Path, str],
    selector: Any,
    dummy_traj: TrajEnsemble,
    align_string: str,
    col: str,
    display: Any,
    progbar: Any,
) -> tuple[int, Path]:
    """Write all files describing the most recently selected cluster to disk.

    The cluster is identified by the maximum value of ``trajs.CVs[col]``. A
    new directory ``<main_path>/clusters/<timestamp>`` is created and filled
    with the selector's vertices (npy), ``dummy_traj`` (h5), a listing of all
    trajectories in ``trajs`` (txt), a table of the cluster's frames (csv),
    the cluster membership arrays (npy), one pdb/xtc pair per topology of
    ``dummy_traj`` and a README.

    Args:
        trajs (TrajEnsemble): The ensemble that was clustered. Needs to
            provide ``CVs[col]`` and ``lowd``.
        main_path (Union[Path, str]): Directory in which the ``clusters``
            directory is created.
        selector (Any): Object with ``xs`` and ``ys`` attributes holding the
            vertices of the selection.
        dummy_traj (TrajEnsemble): The frames of the cluster.
        align_string (str): Not used by this function.
        col (str): Key of ``trajs.CVs`` that holds the cluster ids.
        display (Any): Context manager with an ``outputs`` attribute into
            which messages are printed.
        progbar (Any): Object with an ``update()`` method.

    Returns:
        tuple[int, Path]: The cluster id and the directory with the written
        files. NOTE: if the selector's vertices can't be saved, a message is
        printed into ``display`` and None is returned instead.

    """
    # Standard Library Imports
    from pathlib import Path

    main_path = Path(main_path)
    max_ = trajs.CVs[col].max()
    where = np.where(trajs.CVs[col] == max_)[0]
    length = len(where)
    now = _datetime_windows_and_linux_compatible()

    # make dirs
    os.makedirs(os.path.join(main_path, "clusters"), exist_ok=True)
    main_path = os.path.join(main_path, f"clusters/{now}")
    os.makedirs(main_path, exist_ok=True)
    progbar.update()

    # define names
    h5_name = os.path.join(main_path, f"cluster_id_{max_}_stacked_{length}_structs.h5")
    pdb_start_name = os.path.join(main_path, f"cluster_id_{max_}_start.pdb")
    pdb_origin_names = os.path.join(main_path, f"cluster_id_{max_}_pdb_origins.txt")
    xtc_name = os.path.join(main_path, f"cluster_id_{max_}.xtc")
    csv_name = os.path.join(main_path, f"cluster_id_{max_}_selected_points.csv")
    png_name = os.path.join(main_path, f"cluster_id_{max_}_image.png")
    lowd_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points.npy"
    )
    indices_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points_indices.npy"
    )
    current_clustering = os.path.join(
        main_path,
        f"cluster_id_{max_}_cluster_current_clustering_%s.npy" % col,
    )
    selector_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_selector_points.npy"
    )
    parents_trajs = os.path.join(
        main_path, f"cluster_id_{max_}_all_plotted_trajs_in_correct_order.txt"
    )

    # save edges of selector
    try:
        verts = np.vstack([selector.xs, selector.ys]).T
        selector_npy_name = selector_npy_name.replace(
            "selector", f"{selector.__class__.__name__.lower()}"
        )
        np.save(selector_npy_name, verts)
    except Exception as e:
        display.outputs = []
        with display:
            print(f"Currently only plotly's LassoSelector is available. Exception: {e}")
        return
    progbar.update()

    # save the output as a h5 file, so we can also save CVs and lowd
    dummy_traj.save(h5_name)
    progbar.update()

    # render png
    # plot_cluster(trajs, h5_name, png_name, max_)

    # save all trajs
    with open(parents_trajs, "w") as f:
        for traj in trajs:
            f.write(
                f"{os.path.abspath(traj.traj_file)} {os.path.abspath(traj.top_file)} {traj.traj_num} {traj.common_str}\n"
            )

    # create df
    n_lowd_dims = trajs.lowd.shape[-1]
    if n_lowd_dims == 2:
        lowd_coords = {"x": [], "y": []}
    elif n_lowd_dims == 3:
        # The original accessed the misspelled attribute `shaoe` here.
        lowd_coords = {"x": [], "y": [], "z": []}
    else:
        lowd_coords = {f"lowd_{i}": [] for i in range(n_lowd_dims)}
    progbar.update()
    df = pd.DataFrame(
        {
            "trajectory file": [],
            "topology file": [],
            "frame number": [],
            "time": [],
            **lowd_coords,
            "cluster id": [],
            "trajectory number": [],
        }
    )
    # display.outputs = []
    # with display:
    #     print(f"Dataframe created {df.shape=}. {where=}")

    progbar.update()
    # `DataFrame.append` is only used for pandas versions below 2.0.0
    use_loc = version.parse(pd.__version__) >= version.parse("2.0.0")
    for frame_num, frame in dummy_traj.iterframes():
        row = {
            "trajectory file": os.path.abspath(frame.traj_file),
            "topology file": os.path.abspath(frame.top_file),
            "frame number": frame_num,
            "time": frame.time[0],
            "cluster id": max_,
            "trajectory number": frame.traj_num,
            **{k: v for k, v in zip(lowd_coords.keys(), frame.lowd[0])},
        }
        if use_loc:
            df.loc[len(df)] = pd.Series(row)
        else:
            df = df.append(row, ignore_index=True)
    display.outputs = []
    progbar.update()
    df = df.astype(
        dtype={
            "trajectory file": str,
            "topology file": str,
            "frame number": int,
            "time": float,
            **{k: float for k in lowd_coords},
            "cluster id": int,
            "trajectory number": int,
        }
    )
    df.to_csv(csv_name, index=False)
    progbar.update()

    # save npy
    np.save(lowd_npy_name, trajs.CVs[col][where])
    np.save(indices_npy_name, where)
    np.save(current_clustering, trajs.CVs[col])
    progbar.update()

    # save full traj
    progbar.update()
    with open(pdb_origin_names, "w") as f:
        for i, (top, value) in enumerate(dummy_traj.trajs_by_top.items()):
            _pdb_start_name = pdb_start_name.replace(
                ".pdb", f"_traj_{i}_from_{trajs.basename_fn(value.top_files[0])}.pdb"
            )
            _xtc_name = xtc_name.replace(".xtc", f"_traj_{i}.xtc")
            joined = value.join(progbar=False)[top]
            joined[0].save_pdb(_pdb_start_name)
            # shutil.copyfile(key, _pdb_start_name)
            f.write(
                f"{_pdb_start_name} is a copy (`shutil.copyfile`) of "
                f"{value.top_files[0]}. The corresponding trajectory files might "
                f"originate from other places. Refer to {parents_trajs} for info about xtcs.\n"
            )
            joined.save_xtc(_xtc_name)

    # create an info dict
    # Local Folder Imports
    from .jinja_template import h5_parents, h5_rebuild, xtc_parents, xtc_rebuild

    info_dict = {
        "h5_name": h5_name,
        "pdb_start_name": pdb_start_name,
        "pdb_origin_names": pdb_origin_names,
        "xtc_name": xtc_name,
        "csv_name": csv_name,
        "png_name": png_name,
        "lowd_npy_name": lowd_npy_name,
        "indices_npy_name": indices_npy_name,
        "current_clustering": current_clustering,
        "selector_npy_name": selector_npy_name,
        "parents_trajs": parents_trajs,
    }

    if all(t.extension == ".h5" for t in trajs):
        template = jinja2.Template(h5_rebuild)
        rebuild_clustering_info = template.render(
            {"h5_file": trajs[0]._traj_file.resolve(), **info_dict},
        )
        template = jinja2.Template(h5_parents)
        parents_trajs = template.render(
            {"h5_file": trajs[0]._traj_file.resolve(), **info_dict},
        )
    else:
        template = jinja2.Template(xtc_rebuild)
        rebuild_clustering_info = template.render(
            info_dict,
        )
        template = jinja2.Template(xtc_parents)
        parents_trajs = template.render(
            info_dict,
        )

    info_dict["parents_trajs"] = parents_trajs
    info_dict["rebuild_clustering_info"] = rebuild_clustering_info
    # NOTE(review): `os.path.basename` is also applied to the two rendered
    # texts added above, not only to the file paths -- confirm this is intended.
    info_dict = {k: os.path.basename(v) for k, v in info_dict.items()}
    info_dict.update({"cluster_id": max_})
    info_dict.update({"n_trajs": trajs.n_trajs})
    info_dict.update({"cluster_n_points": len(where)})
    info_dict.update({"basename": trajs[0].basename})
    info_dict.update({"cluster_abspath": main_path})

    # create a readme
    _create_readme(main_path, now, info_dict)
    progbar.update()

    return max_, main_path
506################################################################################
507# Classes
508################################################################################
class Props:
    """Class to contain the properties of MenuItems.

    Each MenuItem contains two copies of this class.
    One for props when the mouse cursor hovers over them.
    One for the rest of times.
    Can be used as class or as dict, i.e. values can be read and written
    with dot-notation (``props.alpha``) as well as with item access
    (``props["alpha"]``).

    Attributes:
        labelcolor (str): The color of the text.
        labelcolor_rgb (tuple): The RGBA value of labelcolor.
        bgcolor (str): The color of the background.
        bgcolor_rgb (tuple): The RGBA value of bgcolor.
        fontsize (int): The fontsize.
        alpha (float): The alpha value of the background.
        defaults (dict): Class Variable of default values.

    Examples:
        >>> props = Props()
        >>> props.labelcolor
        'black'

    """

    defaults = {"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2}

    def __init__(self, **kwargs):
        """Instantiate the Props class.

        Takes a dict as input and overwrites the class defaults. The values are
        directly stored as attributes and can be accessed via dot-notation.

        Args:
            **kwargs: dict containing values. If unknown values are passed,
                a message is printed and they are dropped.

        """
        self._setattr(self.defaults)
        for key, value in kwargs.items():
            if key not in self.__dict__:
                print(f"Dropping unknown dict entry for {{'{key}': {value}}}")
            else:
                setattr(self, key, value)

    @property
    def bgcolor_rgb(self):
        """tuple: The RGBA value of ``bgcolor``."""
        return mpl.colors.to_rgba(self.bgcolor)

    @property
    def labelcolor_rgb(self):
        """tuple: The RGBA value of ``labelcolor``."""
        return mpl.colors.to_rgba(self.labelcolor)

    def _setattr(self, dictionary):
        """Store every key-value pair of ``dictionary`` as an attribute."""
        for key, value in dictionary.items():
            setattr(self, key, value)

    def __setitem__(self, key, value):
        """Enable ``props[key] = value`` as an alias of ``setattr``."""
        setattr(self, key, value)

    # The method used to be misspelled, which disabled item assignment. The
    # old name is kept as an alias in case it is called explicitly somewhere.
    __setitiem__ = __setitem__

    def __getitem__(self, item):
        """Enable ``props[item]`` as an alias of ``getattr``."""
        return getattr(self, item)

    def _string_summary(self):
        return "ep.plotting.interactive.Props class containing Info for the MenuItem class."

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return f"<{self._string_summary()} Object at 0x{id(self):02x}>"
class BezierBuilder:
    """Interactively builds a Bézier curve from mouse clicks.

    Every click inside the axes of the control polygon appends a control
    point, after which the curve (and, if an axes for it was given, the plot
    of the Bernstein basis polynomials) is rebuilt and the canvas redrawn.

    """

    def __init__(self, control_polygon, ax_main, ax_bernstein=None):
        """Set up the builder from the initial control polygon of the curve.

        Args:
            control_polygon: Line whose data are the control points. Its
                figure's canvas receives the click callback.
            ax_main: Axes into which the Bézier curve is drawn.
            ax_bernstein: Optional axes for the Bernstein basis polynomials.

        """
        self.control_polygon = control_polygon
        self.ax_main = ax_main
        self.ax_bernstein = ax_bernstein
        self.xp = list(control_polygon.get_xdata())
        self.yp = list(control_polygon.get_ydata())
        self.canvas = control_polygon.figure.canvas

        # the instance itself is the handler for mouse clicks
        self.cid = self.canvas.mpl_connect("button_press_event", self)

        # the (still empty) curve shares its color with the polygon's markers
        curve_color = control_polygon.get_markeredgecolor()
        self.bezier_curve = self.ax_main.add_line(Line2D([], [], c=curve_color))

    def __call__(self, event):
        """Append the clicked position as control point and redraw."""
        clicked_inside = event.inaxes == self.control_polygon.axes
        if not clicked_inside:
            return

        # extend the control polygon
        self.xp.append(event.xdata)
        self.yp.append(event.ydata)
        self.control_polygon.set_data(self.xp, self.yp)

        # recompute the curve, then refresh both plots
        curve_x, curve_y = self._build_bezier()
        self.bezier_curve.set_data(curve_x, curve_y)
        self._update_bernstein()
        self._update_bezier()

    def _build_bezier(self):
        """Return the x and y coordinates of the curve for the current points."""
        control_points = list(zip(self.xp, self.yp))
        x, y = Bezier(control_points).T
        return x, y

    def _update_bezier(self):
        """Redraw the canvas."""
        self.canvas.draw()

    def _update_bernstein(self):
        """Replot the Bernstein basis, if an axes for it was provided."""
        ax = self.ax_bernstein
        if ax is None:
            return
        degree = len(self.xp) - 1
        t = np.linspace(0, 1, num=200)
        ax.clear()
        for kk in range(degree + 1):
            ax.plot(t, Bernstein(degree, kk)(t))
        ax.set_title(f"Bernstein basis, N = {degree}")
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

    @property
    def ind(self):
        """The points of the current curve, one ``(x, y)`` pair per row."""
        curve_data = self.bezier_curve.get_data()
        return np.vstack(curve_data).T

    def disconnect(self):
        """Remove the click callback from the canvas."""
        self.canvas.mpl_disconnect(self.cid)
def Bernstein(n, k):
    """Return the ``k``-th Bernstein basis polynomial of degree ``n``.

    The returned callable evaluates ``binom(n, k) * x**k * (1 - x)**(n - k)``.
    The binomial coefficient is computed once, when this function is called.

    """
    n_choose_k = binom(n, k)

    def _bpoly(x):
        return n_choose_k * x**k * (1 - x) ** (n - k)

    return _bpoly
def Bezier(points, num=200):
    """Evaluate the Bézier curve defined by 2D control ``points``.

    The curve is sampled at ``num`` evenly spaced parameter values in
    ``[0, 1]`` and is the sum of all control points weighted with their
    Bernstein basis polynomial of degree ``len(points) - 1``.

    Returns:
        np.ndarray: Array of shape ``(num, 2)`` holding the curve.

    """
    degree = len(points) - 1
    t = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    for k, control_point in enumerate(points):
        # weight of this control point at every sampled parameter value
        weights = binom(degree, k) * t**k * (1 - t) ** (degree - k)
        curve += np.outer(weights, control_point)
    return curve
class DummyTool:
    """Placeholder tool that offers a ``disconnect`` method without any effect."""

    def disconnect(self):
        """Do nothing."""
        return None
class SelectFromCollection:
    """Pick points of a matplotlib collection with a selector widget.

    The indices of the picked points are stored in the `ind` attribute.
    Points outside the selection are faded out by lowering their alpha
    value. Collections with alpha < 1 are permanently altered by this tool.

    The decision whether a point is selected is made based on the *origins*
    (i.e., `offsets`) of the collection's objects.

    Args:
        ax (matplotlib.axes.Axes): Axes to interact with.
        collection (matplotlib.collections.Collection): Subclass of collection
            you want to select from.
        alpha_other (float): Alpha value given to the non-selected points,
            while selected points get an alpha value of 1.
            Needs to fulfill 0 <= alpha_other <= 1
        selector: The selector widget class. It is instantiated with the
            axes, this object's `onselect` as callback and `useblit=False`.
            Defaults to matplotlib's `LassoSelector`.

    """

    def __init__(self, ax, collection, alpha_other=0.3, selector=LassoSelector):
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.Npts = len(self.xys)

        # every point needs its own color, so a single color is repeated
        self.fc = collection.get_facecolors()
        n_colors = len(self.fc)
        if n_colors == 0:
            raise ValueError("Collection must have a facecolor")
        if n_colors == 1:
            self.fc = np.tile(self.fc, (self.Npts, 1))

        self.lasso = selector(ax, onselect=self.onselect, useblit=False)
        self.ind = []

    def onselect(self, verts):
        """Store the indices of the points inside `verts` and highlight them."""
        inside = Path(verts).contains_points(self.xys)
        self.ind = np.nonzero(inside)[0]
        self.fc[:, -1] = self.alpha_other
        self.fc[self.ind, -1] = 1
        self._apply_facecolors()

    def disconnect(self):
        """Disconnect the selector and reset the alpha of all points to 1."""
        self.lasso.disconnect_events()
        self.fc[:, -1] = 1
        self._apply_facecolors()

    def _apply_facecolors(self):
        """Push the stored colors to the collection and request a redraw."""
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def _string_summary(self):
        return (
            f"encodermap.SelectFromCollection object. Selecting points "
            f"from a set of {self.Npts} points. The current selector tool "
            f"is matplotlotlib's {self.lasso} tool. The selected points are "
            f"{self.ind}"
        )

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return self._string_summary()
class MenuItem(mpl.patches.Rectangle, mpl.artist.Artist):
    """A clickable rectangle with a centered text label.

    The item keeps two sets of properties (see ``Props``): ``props`` for the
    normal state and ``hoverprops`` for the state in which the mouse cursor
    hovers over it.

    """

    def __init__(
        self,
        fig,
        xy,
        width,
        height,
        labelstr,
        props=None,
        hoverprops=None,
        on_select=None,
        standalone=False,
    ):
        """Instantiate the MenuItem.

        Args:
            fig: The figure to whose canvas the callbacks are connected.
            xy: Anchor point of the rectangle.
            width: Width of the rectangle.
            height: Height of the rectangle.
            labelstr (str): The text of the label.
            props (dict, optional): Overrides for the default properties
                (labelcolor "black", bgcolor "blue", fontsize 15, alpha 0.2).
            hoverprops (dict, optional): Overrides for the default hover
                properties (labelcolor "C0", bgcolor "yellow", fontsize 15,
                alpha 0.2).
            on_select (callable, optional): Called with this item when a mouse
                button is released over it.
            standalone (bool): Whether the item handles hovering itself. Set
                this to False, when a menu takes care of it.

        """
        # define props and hoverprops
        # (None instead of mutable dicts as defaults of the arguments)
        self.labelstr = labelstr
        self.standalone = standalone
        props = {
            **{"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2},
            **(props or {}),
        }
        hoverprops = {
            **{"labelcolor": "C0", "bgcolor": "yellow", "fontsize": 15, "alpha": 0.2},
            **(hoverprops or {}),
        }
        self.props = Props(**props)
        self.hoverprops = Props(**hoverprops)

        # set the on_select method
        self.on_select = on_select

        # add a select event
        # hover event is only included if standalone is true
        # i.e. the MenuItem is not instantiated from a Menu
        # In the case this is instantiated from a menu, the menu will handle the hover
        self.cid_button = fig.canvas.mpl_connect(
            "button_release_event", self.check_select
        )
        if self.standalone:
            self.cid_move = fig.canvas.mpl_connect(
                "motion_notify_event", self.set_hover
            )

        # Instantiate text and Rectangle
        mpl.patches.Rectangle.__init__(
            self, xy, width, height, label=self.labelstr, zorder=1
        )
        self.text = mpl.text.Text(
            0,
            0,
            self.labelstr,
            color=self.props.labelcolor,
            fontproperties=dict(weight="bold", size=self.props.fontsize),
            zorder=2,
            verticalalignment="center",
            horizontalalignment="center",
        )

        # final thing to do is set the props of
        # the rectangle based whether a cursor hovers
        self.set_hover_props(False)

        # after artist has been placed get it ready
        # for accepting hover events
        self.hover = False

    def set_figure(self, figure):
        """Overwriting Base Class method to include labelstr"""
        mpl.patches.Rectangle.set_figure(self, figure)
        self.text.set_figure(figure)

    # NOTE(review): `set_axes` and `set_data` below delegate to methods of
    # `Rectangle` with the same names -- verify that the supported matplotlib
    # versions still provide them.
    def set_axes(self, axes):
        """Overwriting Base Class method to include labelstr"""
        mpl.patches.Rectangle.set_axes(self, axes)
        self.text.set_axes(axes)

    def set_transform(self, transform):
        """Overwriting Base Class method to include labelstr"""
        mpl.patches.Rectangle.set_transform(self, transform)
        # set text to center of self(.rect)
        bbox = self.get_bbox()
        x = bbox.x0 + 0.5 * bbox.width
        y = bbox.y0 + 0.5 * bbox.height
        texttrans = mpl.transforms.Affine2D().translate(x, y) + self.axes.transData
        self.text.set_transform(texttrans)

    def set_data(self, x, y):
        """Overwriting Base Class method to include labelstr"""
        # if len(x):
        #     self.text.set_position((x[-1], y[-1]))
        mpl.patches.Rectangle.set_data(self, x, y)
        self.text.set_color(self.props.labelcolor_rgb)

    def check_select(self, event):
        """Call ``on_select``, if ``event`` happened over this item.

        Returns:
            Optional[bool]: True, if the event happened over this item,
            None otherwise.

        """
        over, _ = self.contains(event)
        if not over:
            return
        if self.on_select is not None:
            self.on_select(self)
        return True

    def draw(self, renderer):
        """Overwriting Base Class method to include labelstr"""
        # the rectangle first, the label on top of it
        mpl.patches.Rectangle.draw(self, renderer)
        self.text.draw(renderer)

    def set_hover_props(self, check):
        """Apply ``hoverprops``, if ``check`` is True, else ``props``."""
        if check:
            props = self.hoverprops
        else:
            props = self.props

        self.set(facecolor=props.bgcolor_rgb, alpha=props.alpha)
        self.text.set_color(props.labelcolor_rgb)

    def set_hover(self, event):
        """Check the hover status of MenuItem

        Returns:
            bool: Whether the hover status changed with this event.

        """
        check, _ = self.contains(event)
        changed = check != self.hover
        if changed:
            self.set_hover_props(check)
        self.hover = check
        if changed and self.standalone:
            self.figure.canvas.draw()
        return changed
class ModeButton(MenuItem):
    """MenuItem that toggles between a pressed and a released state."""

    def __init__(self, *args, **kwargs):
        """Takes the same arguments as ``MenuItem``; starts released."""
        super().__init__(*args, **kwargs)

        # the button is released after instantiation
        self.pressed = False

    def check_select(self, event, overwrite=False):
        """Toggle the button, if ``event`` happened over it.

        Args:
            event: The mouse event.
            overwrite (bool): If True, nothing is done and None is returned.

        Returns:
            Optional[bool]: The new pressed state, or None, if the button was
            not toggled.

        """
        if overwrite:
            return
        over, _ = self.contains(event)
        if not over:
            return
        if self.on_select is not None:
            self.on_select(self)
        self.pressed = not self.pressed
        if self.pressed:
            # a pressed button is displayed like a hovered one
            self.set_hover_props(True)
        return self.pressed

    def set_hover(self, event):
        """Check the hover status of MenuItem"""
        check, _ = self.contains(event)
        changed = check != self.hover
        # the look of a pressed button does not follow the cursor
        if changed and not self.pressed:
            self.set_hover_props(check)
        self.hover = check
        return changed

    def _on_select_rotation(self):
        """Old function to rotate labelstrings."""
        last_index = len(self.labelrotation) - 1
        self.status = 0 if self.status == last_index else self.status + 1
        self.labelstr = self.labelrotation[self.status]
        self.label.set_text(self.labelstr)
        # print(f"You pressed {self.labelstr}")
class Menu:
    """A vertical stack of ``MenuItem`` instances inside an axes.

    Attributes:
        ax: The axes containing the items.
        fig: The figure of ``ax``.
        menuitems (dict): The items of the menu, keyed by their label.
        cid: Connection id of the ``motion_notify_event`` callback.

    """

    def __init__(self, ax, items=("Reset", "Write", "Set Points")):
        """Instantiate the Menu.

        Args:
            ax: The axes into which the items are placed.
            items (Sequence[str]): The labels of the items.

        """
        # suppresscomposite
        self.ax = ax
        self.fig = self.ax.get_figure()
        self.fig.suppressComposite = True

        # add the menu items
        self.menuitems = {}
        self.add_items(items=items)

        # mpl_connect move to set hover stuff
        self.cid = self.fig.canvas.mpl_connect("motion_notify_event", self.on_move)

    def on_move(self, event):
        """Update the hover state of all items and redraw if any changed.

        All items are updated on every event. Stopping at the first changed
        item would leave the remaining ones in a stale state until the next
        mouse movement.

        """
        needs_redraw = False
        for item in self.menuitems.values():
            if item.set_hover(event):
                needs_redraw = True
        if needs_redraw:
            self.fig.canvas.draw()

    def add_items(self, items):
        """Create one ``MenuItem`` per label and add it to the axes."""
        coords = self.get_coords(len(items))
        for s, c in zip(items, coords):
            # on_select = lambda item: print(f"You pressed {item.labelstr}.")
            on_select = lambda item: True
            item = MenuItem(
                self.fig, (0, c[0]), 1, c[1] - c[0], labelstr=s, on_select=on_select
            )
            self.menuitems[s] = item
            self.ax.add_artist(item)

    def get_coords(self, no_items, gap_space=0.05):
        """Split the interval [0, 1] into equally sized, gap-separated parts.

        Args:
            no_items (int): The number of parts.
            gap_space (float): The space between two neighboring parts.

        Returns:
            list[list[float]]: A ``[start, end]`` pair for every part. Empty,
            if ``no_items`` is 0.

        """
        if no_items <= 0:
            # nothing to place; also avoids a division by zero
            return []
        no_gaps = no_items - 1
        total_length = 1 - no_gaps * gap_space
        length_per = total_length / no_items
        coords = [[0, length_per]]
        for _ in range(1, no_items):
            start = coords[-1][1] + gap_space
            coords.append([start, start + length_per])
        return coords
class StatusMenu(Menu):
    """Menu of mutually exclusive mode buttons.

    All items except "Mode" are ``ModeButton`` instances. The label of the
    currently active button is stored in ``status``; "Idle" is active
    whenever no button is pressed.

    Attributes:
        status (str): The label of the active button.
        cid_move: Connection id of the ``motion_notify_event`` callback.
        cid: Connection id of the ``button_release_event`` callback.

    """

    def __init__(self, ax):
        """Instantiate the StatusMenu inside the axes ``ax``."""
        items = [
            "Lasso",
            "Rectangle",
            "Ellipse",
            "Polygon",
            "Path",
            "Bezier",
            "Idle",
            "Mode",
        ]
        # call the parent class to use its get_coords() method
        # the add_items method is overwritten
        super().__init__(ax, items=items)

        # the parent stored the id of its hover callback in `cid`, which is
        # reassigned below; keep it, so the callback can still be disconnected
        self.cid_move = self.cid

        # set the label rotation and the current status
        self.set_idle()

        # click notify event to change self.status and switch ModeButtons off
        self.cid = self.fig.canvas.mpl_connect("button_release_event", self.on_click)

    def set_idle(self):
        """Sets the idle status. Called at __init__ and
        when nothing is pressed."""
        self.status = "Idle"
        self.menuitems[self.status].set_hover_props(True)
        self.menuitems[self.status].pressed = True

    def on_click(self, event):
        """Make a newly pressed button the active one and release the old one.

        Falls back to "Idle", if no button is pressed afterwards.

        """
        for key, item in self.menuitems.items():
            if item.pressed and key != self.status:
                print(f"Changing Mode to {key}")
                previous = self.menuitems[self.status]
                previous.set_hover_props(False)
                previous.pressed = False
                previous.check_select(event, overwrite=True)
                self.status = key
        # runs after every loop, as the loop is never left with a `break`
        if not any(item.pressed for item in self.menuitems.values()):
            print("Changing Mode to Idle")
            self.set_idle()

    def add_items(self, items):
        """Create the items: a plain ``MenuItem`` for "Mode", else ``ModeButton``."""
        coords = self.get_coords(len(items))
        for s, c in zip(items, coords):
            if s == "Mode":
                on_select = lambda item: None  # print(f"You pressed {item.labelstr}.")
                # overwrite props and hoverprops
                props = {
                    "labelcolor": "black",
                    "bgcolor": "orange",
                    "fontsize": 15,
                    "alpha": 1,
                }
                hoverprops = {
                    "labelcolor": "black",
                    "bgcolor": "orange",
                    "fontsize": 15,
                    "alpha": 1,
                }
                item = MenuItem(
                    self.fig,
                    (0, c[0]),
                    1,
                    c[1] - c[0],
                    labelstr=s,
                    on_select=on_select,
                    props=props,
                    hoverprops=hoverprops,
                )
                # a MenuItem has no pressed state, but on_click reads it
                item.pressed = False
            else:
                # on_select = lambda item: print(f"You pressed {item.labelstr}.")
                on_select = lambda item: True
                # item = ModeButton(self.fig, self.ax, s, on_select=on_select)
                item = ModeButton(
                    self.fig, (0, c[0]), 1, c[1] - c[0], labelstr=s, on_select=on_select
                )
            self.menuitems[s] = item
            self.ax.add_artist(item)