Coverage for encodermap/plot/utils.py: 15%

464 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/plot/utils.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22 

23"""Utils for plotting. Ideas and Code taken from the matplotlib documentation. 

24Menu: https://matplotlib.org/3.1.0/gallery/widgets/menu.html 

25LassoSelector: https://matplotlib.org/3.1.1/gallery/widgets/lasso_selector_demo_sgskip.html 

26PolygonSelector: https://matplotlib.org/3.1.3/gallery/widgets/polygon_selector_demo.html 

27Bezier: https://gist.github.com/gavincangan/b88a978e878e9bb1c0f8804e3af8de3c 

28 

29""" 

30 

31################################################################################ 

32# Imports 

33################################################################################ 

34 

35# Future Imports at the top 

36from __future__ import annotations 

37 

38# Standard Library Imports 

39import copy 

40import os 

41import shutil 

42import warnings 

43 

44# Third Party Imports 

45import matplotlib as mpl 

46import matplotlib.pyplot as plt 

47import numpy as np 

48from matplotlib.artist import Artist 

49from matplotlib.lines import Line2D 

50from matplotlib.patches import Rectangle 

51from matplotlib.path import Path 

52from matplotlib.widgets import LassoSelector, PolygonSelector, RectangleSelector 

53from optional_imports import _optional_import 

54from packaging import version 

55from scipy.special import binom 

56from tqdm import tqdm 

57 

58# Encodermap imports 

59from encodermap.misc.misc import _datetime_windows_and_linux_compatible 

60 

61 

62################################################################################ 

63# Optional Imports 

64################################################################################ 

65 

66 

67jinja2 = _optional_import("jinja2") 

68ngl = _optional_import("nglview") 

69mda = _optional_import("MDAnalysis") 

70md = _optional_import("mdtraj") 

71pd = _optional_import("pandas") 

72binom = _optional_import("scipy", "special.binom") 

73plotly_lasso = _optional_import("plotly", "callbacks.LassoSelector") 

74 

75 

76################################################################################ 

77# Typing 

78################################################################################ 

79 

80 

81# Standard Library Imports 

82from typing import TYPE_CHECKING, Any, Optional, Union 

83 

84 

85if TYPE_CHECKING: 

86 # Local Folder Imports 

87 from ..trajinfo.info_all import TrajEnsemble 

88 

89 

90################################################################################ 

91# Globals 

92################################################################################ 

93 

# NOTE(review): this name is presumably meant to be ``__all__``. With a single
# leading underscore Python does not treat it specially, so ``from ... import *``
# still exports every public name of this module. Renaming it would restrict
# star-imports to the names listed here -- confirm no caller relies on more
# before renaming.
_all__ = ["Props", "SelectFromCollection", "StatusMenu", "Menu", "ModeButton", "MenuItem"]

102 

103 

104################################################################################ 

105# Functions (mainly for ternary plotting) 

106################################################################################ 

107 

108 

def abc_to_rgb(A=0.0, B=0.0, C=0.0):
    """Convert three fractions into an RGB triple.

    Every channel is the sum of the two *other* inputs, capped at 1.0:
    red = B + C, green = A + C, blue = A + B.

    Args:
        A (float): First fraction, expected in [0, 1]. Defaults to 0.0.
        B (float): Second fraction, expected in [0, 1]. Defaults to 0.0.
        C (float): Third fraction, expected in [0, 1]. Defaults to 0.0.

    Returns:
        tuple: The ``(red, green, blue)`` values, each at most 1.0.

    """
    red = min(B + C, 1.0)
    green = min(A + C, 1.0)
    blue = min(A + B, 1.0)
    return red, green, blue

113 

114 

def digitize_dssp(trajs, dssp, imshow=True, bins=100, progbar=None):
    """Digitizes the DSSP array into RGB colors.

    For every frame the occurrences of each secondary-structure letter are
    counted, divided by the number of residues and mapped to an RGB triple
    with :func:`abc_to_rgb`.

    Args:
        trajs: Object exposing a 2D ``lowd`` array. Its columns 0 and 1 are
            used as x and y coordinates when ``imshow`` is True.
        dssp (np.ndarray): 2D array of DSSP letters, one row per frame and one
            column per residue.
        imshow (bool): If True, return a ``(bins, bins, 3)`` image in which
            each pixel holds the color of the mean letter fractions of all
            frames falling into that 2D histogram bin. Empty bins stay white.
            If False, return one color per frame. Defaults to True.
        bins (int): Number of histogram bins per axis. Defaults to 100.
        progbar (Optional[tqdm]): Progress bar to update. A new one is
            created if None.

    Returns:
        np.ndarray: Either the ``(bins, bins, 3)`` image or an
        ``(n_frames, 3)`` array with one color per frame.

    """
    if progbar is None:
        # This previously referenced the undefined name ``all_dssp``.
        progbar = tqdm(total=dssp.shape[0] * 4, position=0, leave=True)
    if np.any(np.isin(dssp, ["H", "E", "C"], assume_unique=False)) and "B" not in dssp:
        sorted_ = np.array(["H", "E", "C"])
    else:
        # NOTE(review): colors are built with ``abc_to_rgb``, which takes
        # exactly three values, so this 8-letter alphabet cannot be colored yet
        # and will raise during the color mapping below.
        sorted_ = np.array(["H", "B", "E", "G", "I", "T", "S", ""])
    n_letters = len(sorted_)
    num_aas = dssp.shape[1]
    uniques = correct_missing_uniques(
        [np.concatenate(np.unique(d, return_counts=True)) for d in dssp],
        sorted_=sorted_,
        progbar=progbar,
    )
    # Every row is [letter_0, ..., letter_n-1, count_0, ..., count_n-1].
    letters, letter_counts = uniques[:, :n_letters], uniques[:, n_letters:]

    # Rows can arrive in alphabetical (np.unique) order or in ``sorted_``
    # order. Look every letter up explicitly so the columns of ``counts``
    # always follow ``sorted_``.
    fractions = []
    for row_letters, row_counts in zip(letters, letter_counts):
        lookup = dict(zip(row_letters.tolist(), row_counts.tolist()))
        fractions.append([float(lookup.get(letter, 0)) for letter in sorted_.tolist()])
    counts = np.array(fractions, dtype=float) / num_aas

    if not imshow:
        return np.array([abc_to_rgb(a, b, c) for a, b, c in counts])

    print("digitizing")
    digitized = np.full((bins, bins, 3), (1.0, 1.0, 1.0))
    x = trajs.lowd[:, 0]
    y = trajs.lowd[:, 1]
    _, xedges, yedges = np.histogram2d(x=x, y=y, bins=bins)

    # Bin index of every point, such that edges[i] <= value < edges[i + 1].
    # Clipping puts points lying exactly on the upper edge into the last bin,
    # as np.histogram2d does; the former per-bin masks dropped them.
    x_bin = np.clip(np.searchsorted(xedges, x, side="right") - 1, 0, bins - 1)
    y_bin = np.clip(np.searchsorted(yedges, y, side="right") - 1, 0, bins - 1)

    # Accumulate per-bin sums and point counts in one pass (instead of
    # bins * bins boolean masks over all points).
    summed = np.zeros((bins, bins, counts.shape[1]))
    n_points = np.zeros((bins, bins))
    np.add.at(summed, (x_bin, y_bin), counts)
    np.add.at(n_points, (x_bin, y_bin), 1)

    for i, j in zip(*np.nonzero(n_points)):
        digitized[i, j] = abc_to_rgb(*(summed[i, j] / n_points[i, j]))
        progbar.update()
    return digitized

165 

166 

def correct_missing_uniques(uniques, sorted_, progbar=None):
    """Takes a list of DSSP letters and counts and adds zeros for all missing letters in sorted_.

    Args:
        uniques (list of np.ndarray): A list of np.ndarrays. For every frame the
            DSSP letters followed by their respective counts should be provided,
            for example ['H', 151] or ['H', 'E', 75, 76].
        sorted_ (np.ndarray): The sorted list of DSSP letters used to rebuild
            incomplete frames. When using ['H', 'E', 'C'], the two examples
            from above become ['H', 'E', 'C', 151, 0, 0] and
            ['H', 'E', 'C', 75, 76, 0], respectively.
            - Rows that already contain exactly the letters of ``sorted_`` are
              kept unchanged, including their letter order.
            - Letters not contained in ``sorted_`` are dropped.
        progbar (tqdm): A tqdm progbar. Defaults to None, in which case a new
            one is created.

    Returns:
        np.ndarray: The corrected uniques stacked into a 2D array. The input
        list itself is not modified.

    """
    if progbar is None:
        progbar = tqdm(total=len(uniques), position=0, leave=True)
    sorted_ = np.asarray(sorted_)
    expected_letters = set(sorted_.tolist())
    corrected = []
    for u in uniques:
        letters, counts = np.split(np.asarray(u), 2)
        # A row of the right length can still carry a foreign letter (e.g.
        # "NA") in place of a missing one, so compare letters, not only length.
        if len(u) != len(sorted_) * 2 or set(letters.tolist()) != expected_letters:
            newline = np.zeros(len(sorted_), dtype=int)
            for letter, count in zip(letters, counts.astype(int)):
                newline[sorted_ == letter] = count
            u = np.concatenate([sorted_, newline.astype(str)])
        corrected.append(u)
        progbar.update()
    return np.vstack(corrected)

197 

198 

def _get_system_info() -> dict[str, Any]:
    """Collect a best-effort description of the machine and user.

    The entries are gathered in a fixed order. As soon as one of them raises,
    collection stops and the entries gathered so far are returned.

    Returns:
        dict[str, Any]: Keys ``platform``, ``system_user``,
        ``platform_release``, ``platform_version``, ``architecture``,
        ``hostname``, ``ip_address``, ``mac_address``, ``processor`` and
        ``ram``, as far as they could be determined.

    """
    # Standard Library Imports
    import getpass
    import platform
    import re
    import socket
    import uuid

    # Third Party Imports
    import psutil

    def _mac_address():
        return ":".join(re.findall("..", "%012x" % uuid.getnode()))

    def _ram():
        return str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB"

    collectors = (
        ("platform", platform.system),
        ("system_user", getpass.getuser),
        ("platform_release", platform.release),
        ("platform_version", platform.version),
        ("architecture", platform.machine),
        ("hostname", socket.gethostname),
        ("ip_address", lambda: socket.gethostbyname(socket.gethostname())),
        ("mac_address", _mac_address),
        ("processor", platform.processor),
        ("ram", _ram),
    )

    info = {}
    for key, collect in collectors:
        try:
            info[key] = collect()
        except Exception:
            # deliberately best-effort: keep what was gathered so far
            break
    return info

225 

226 

227def _check_all_templates_defined(template, info_dict): 

228 # Standard Library Imports 

229 import re 

230 

231 regex = r"\{(.*?)\}" 

232 matches = re.finditer(regex, template, re.MULTILINE | re.DOTALL) 

233 min_matches = [] 

234 for matchNum, match in enumerate(matches): 

235 for groupNum in range(0, len(match.groups())): 

236 min_matches.append(match.group(groupNum)) 

237 min_matches = list( 

238 set(map(lambda x: x.lstrip("{{").rstrip("}}"), [i for i in min_matches])) 

239 ) 

240 if all(key in info_dict for key in min_matches): 

241 return True 

242 else: 

243 missing = set(min_matches).difference(info_dict) 

244 raise Exception( 

245 f"Not all expressions defined in template. Missing expressions: {missing}" 

246 ) 

247 

248 

def _create_readme(main_path, now, info_dict):  # pragma: no cover
    """Render the README.md template and write it into ``main_path``.

    Args:
        main_path (str): Directory in which ``README.md`` is written.
        now (str): Timestamp inserted into the template as ``now``.
        info_dict (dict): Template values. It is updated IN PLACE with
            ``pip_freeze``, ``filename``, ``now``, ``encodermap_version`` and
            the entries returned by ``_get_system_info()``.

    Raises:
        Exception: Raised by ``_check_all_templates_defined`` when the template
            references an expression missing from ``info_dict``.

    """
    # Third Party Imports
    from pip._internal.operations import freeze

    # Local Folder Imports
    from .._version import get_versions
    from .jinja_template import template

    __version__ = get_versions()["version"]

    # update info dict
    md_file = os.path.join(main_path, "README.md")
    pip_freeze = "".join(f" {requirement}\n" for requirement in freeze.freeze())
    info_dict.update(
        {
            "pip_freeze": pip_freeze,
            # splitext only drops the extension; ``split(".")[0]`` truncated
            # the path at the first dot anywhere in ``main_path``.
            "filename": os.path.splitext(md_file)[0],
            "now": now,
            "encodermap_version": __version__,
        }
    )
    info_dict.update(_get_system_info())

    # Explicit call instead of ``assert`` so the check also runs under ``python -O``.
    _check_all_templates_defined(template, info_dict)

    # jinja2
    msg = jinja2.Template(template).render(info_dict)

    # write
    with open(md_file, "w", encoding="utf-8") as f:
        f.write(msg)

283 

284 

def _unpack_cluster_info(
    trajs: TrajEnsemble,
    main_path: Union[Path, str],
    selector: Any,
    dummy_traj: TrajEnsemble,
    align_string: str,
    col: str,
    display: Any,
    progbar: Any,
) -> tuple[int, Path]:
    """Write all files describing the newest cluster into a timestamped directory.

    The cluster id is the maximum of ``trajs.CVs[col]``. Files are written to
    ``<main_path>/clusters/<now>/``:
    - the selector vertices;
    - an h5 file of ``dummy_traj``;
    - the list of all parent trajectories;
    - a csv with one row per frame of ``dummy_traj``;
    - npy files with the cluster's low-dimensional points, their indices and
      the full clustering;
    - one pdb/xtc pair per topology of ``dummy_traj``;
    - a README.md.

    Args:
        trajs (TrajEnsemble): The full ensemble that was clustered.
        main_path (Union[Path, str]): Parent directory of the ``clusters`` dir.
        selector (Any): The selector used to draw the cluster. It must expose
            ``xs`` and ``ys``.
        dummy_traj (TrajEnsemble): The frames belonging to the cluster.
        align_string (str): Not used by this function.
        col (str): Name of the CV in ``trajs.CVs`` holding the cluster ids.
        display (Any): Output widget that is cleared and printed into.
        progbar (Any): Progress bar exposing ``update()``; updated 10 times on
            success.

    Returns:
        tuple: The cluster id and the cluster directory (a ``str``, despite
        the annotation). ``None`` is returned if the selector vertices could
        not be read; the reason is printed into ``display``.

    """
    # Standard Library Imports
    from pathlib import Path

    main_path = Path(main_path)
    max_ = trajs.CVs[col].max()
    where = np.where(trajs.CVs[col] == max_)[0]
    length = len(where)
    now = _datetime_windows_and_linux_compatible()

    # make dirs
    os.makedirs(os.path.join(main_path, "clusters"), exist_ok=True)
    main_path = os.path.join(main_path, f"clusters/{now}")
    os.makedirs(main_path, exist_ok=True)
    progbar.update()

    # define names
    h5_name = os.path.join(main_path, f"cluster_id_{max_}_stacked_{length}_structs.h5")
    pdb_start_name = os.path.join(main_path, f"cluster_id_{max_}_start.pdb")
    pdb_origin_names = os.path.join(main_path, f"cluster_id_{max_}_pdb_origins.txt")
    xtc_name = os.path.join(main_path, f"cluster_id_{max_}.xtc")
    csv_name = os.path.join(main_path, f"cluster_id_{max_}_selected_points.csv")
    png_name = os.path.join(main_path, f"cluster_id_{max_}_image.png")
    lowd_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points.npy"
    )
    indices_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points_indices.npy"
    )
    current_clustering = os.path.join(
        main_path,
        f"cluster_id_{max_}_cluster_current_clustering_%s.npy" % col,
    )
    selector_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_selector_points.npy"
    )
    parents_trajs = os.path.join(
        main_path, f"cluster_id_{max_}_all_plotted_trajs_in_correct_order.txt"
    )

    # save edges of selector
    try:
        verts = np.vstack([selector.xs, selector.ys]).T
        # Name the file after the selector class. Only the file name is built
        # here; ``str.replace`` on the full path also rewrote any "selector"
        # occurring in the directory names.
        selector_npy_name = os.path.join(
            main_path,
            f"cluster_id_{max_}_{selector.__class__.__name__.lower()}_points.npy",
        )
        np.save(selector_npy_name, verts)
    except Exception as e:
        display.outputs = []
        with display:
            print(f"Currently only plotly's LassoSelector is available. Exception: {e}")
        return
    progbar.update()

    # save the output as a h5 file, so we can also save CVs and lowd
    dummy_traj.save(h5_name)
    progbar.update()

    # save all trajs
    with open(parents_trajs, "w") as f:
        for traj in trajs:
            f.write(
                f"{os.path.abspath(traj.traj_file)} {os.path.abspath(traj.top_file)} {traj.traj_num} {traj.common_str}\n"
            )

    # create df; the column names of the low-dimensional coordinates depend on
    # their dimensionality (this used to crash on the typo ``lowd.shaoe``)
    n_lowd = trajs.lowd.shape[-1]
    if n_lowd == 2:
        lowd_keys = ["x", "y"]
    elif n_lowd == 3:
        lowd_keys = ["x", "y", "z"]
    else:
        lowd_keys = [f"lowd_{i}" for i in range(n_lowd)]
    progbar.update()
    columns = [
        "trajectory file",
        "topology file",
        "frame number",
        "time",
        *lowd_keys,
        "cluster id",
        "trajectory number",
    ]
    progbar.update()

    # Collect the rows first and build the DataFrame once. Appending row by
    # row was quadratic and needed a separate code path per pandas version.
    rows = []
    for frame_num, frame in dummy_traj.iterframes():
        rows.append(
            {
                "trajectory file": os.path.abspath(frame.traj_file),
                "topology file": os.path.abspath(frame.top_file),
                "frame number": frame_num,
                "time": frame.time[0],
                "cluster id": max_,
                "trajectory number": frame.traj_num,
                **dict(zip(lowd_keys, frame.lowd[0])),
            }
        )
    if rows:
        display.outputs = []
    progbar.update()
    df = pd.DataFrame(rows, columns=columns).astype(
        dtype={
            "trajectory file": str,
            "topology file": str,
            "frame number": int,
            "time": float,
            **{k: float for k in lowd_keys},
            "cluster id": int,
            "trajectory number": int,
        }
    )
    df.to_csv(csv_name, index=False)
    progbar.update()

    # save npy; the "lowd points" file previously stored
    # ``trajs.CVs[col][where]``, i.e. ``length`` copies of the cluster id
    np.save(lowd_npy_name, trajs.lowd[where])
    np.save(indices_npy_name, where)
    np.save(current_clustering, trajs.CVs[col])
    progbar.update()

    # save full traj; file names are derived with splitext so that ".pdb" or
    # ".xtc" occurring in a directory name is left untouched
    progbar.update()
    pdb_stem = os.path.splitext(pdb_start_name)[0]
    xtc_stem = os.path.splitext(xtc_name)[0]
    with open(pdb_origin_names, "w") as f:
        for i, (top, value) in enumerate(dummy_traj.trajs_by_top.items()):
            _pdb_start_name = (
                f"{pdb_stem}_traj_{i}_from_{trajs.basename_fn(value.top_files[0])}.pdb"
            )
            _xtc_name = f"{xtc_stem}_traj_{i}.xtc"
            joined = value.join(progbar=False)[top]
            joined[0].save_pdb(_pdb_start_name)
            f.write(
                f"{_pdb_start_name} is a copy (`shutil.copyfile`) of "
                f"{value.top_files[0]}. The corresponding trajectory files might "
                f"originate from other places. Refer to {parents_trajs} for info about xtcs.\n"
            )
            joined.save_xtc(_xtc_name)

    # create an info dict
    # Local Folder Imports
    from .jinja_template import h5_parents, h5_rebuild, xtc_parents, xtc_rebuild

    info_dict = {
        "h5_name": h5_name,
        "pdb_start_name": pdb_start_name,
        "pdb_origin_names": pdb_origin_names,
        "xtc_name": xtc_name,
        "csv_name": csv_name,
        "png_name": png_name,
        "lowd_npy_name": lowd_npy_name,
        "indices_npy_name": indices_npy_name,
        "current_clustering": current_clustering,
        "selector_npy_name": selector_npy_name,
        "parents_trajs": parents_trajs,
    }

    if all(t.extension == ".h5" for t in trajs):
        render_values = {"h5_file": trajs[0]._traj_file.resolve(), **info_dict}
        rebuild_template, parents_template = h5_rebuild, h5_parents
    else:
        render_values = info_dict
        rebuild_template, parents_template = xtc_rebuild, xtc_parents
    rebuild_clustering_info = jinja2.Template(rebuild_template).render(render_values)
    parents_trajs_info = jinja2.Template(parents_template).render(render_values)

    info_dict["parents_trajs"] = parents_trajs_info
    info_dict["rebuild_clustering_info"] = rebuild_clustering_info
    # NOTE(review): ``os.path.basename`` is also applied to the two rendered
    # texts above; if they contain a "/", everything before the last one is
    # dropped. Confirm against the templates in ``jinja_template``.
    info_dict = {k: os.path.basename(v) for k, v in info_dict.items()}
    info_dict.update(
        {
            "cluster_id": max_,
            "n_trajs": trajs.n_trajs,
            "cluster_n_points": length,
            "basename": trajs[0].basename,
            "cluster_abspath": main_path,
        }
    )

    # create a readme
    _create_readme(main_path, now, info_dict)
    progbar.update()

    return max_, main_path

504 

505 

506################################################################################ 

507# Classes 

508################################################################################ 

509 

510 

class Props:
    """Class to contain the properties of MenuItems.

    Each MenuItem contains two copies of this class: one for the props used
    while the mouse cursor hovers over it, and one for the rest of the time.
    Can be used as class or as dict (``props.fontsize`` and
    ``props["fontsize"]`` are equivalent, for reading and for assignment).

    Attributes:
        labelcolor (str): The color of the text.
        labelcolor_rgb (tuple): The RGBA value of labelcolor.
        bgcolor (str): The color of the background.
        bgcolor_rgb (tuple): The RGBA value of bgcolor.
        fontsize (int): The fontsize.
        alpha (float): The alpha value of the background.
        defaults (dict): Class Variable of default values.

    Examples:
        >>> props = Props()
        >>> props.labelcolor
        'black'

    """

    defaults = {"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2}

    def __init__(self, **kwargs):
        """Instantiate the Props class.

        Takes a dict as input and overwrites the class defaults. The values
        are stored as attributes and can be accessed via dot-notation.

        Args:
            **kwargs: dict containing values. If unknown values are passed
                they are dropped (with a printed message).

        """
        self._setattr(self.defaults)
        for key, value in kwargs.items():
            if key not in self.__dict__:
                print(f"Dropping unknown dict entry for {{'{key}': {value}}}")
            else:
                setattr(self, key, value)

    @property
    def bgcolor_rgb(self):
        """tuple: The RGBA value of ``bgcolor``."""
        return mpl.colors.to_rgba(self.bgcolor)

    @property
    def labelcolor_rgb(self):
        """tuple: The RGBA value of ``labelcolor``."""
        return mpl.colors.to_rgba(self.labelcolor)

    def _setattr(self, dictionary):
        """Set every key/value pair of ``dictionary`` as an attribute."""
        for key, value in dictionary.items():
            setattr(self, key, value)

    def __setitem__(self, key, value):
        """Support ``props[key] = value``."""
        setattr(self, key, value)

    # Backward-compatible alias of the previously misspelled method name,
    # which prevented dict-style assignment from working.
    __setitiem__ = __setitem__

    def __getitem__(self, item):
        """Support ``props[key]``."""
        return getattr(self, item)

    def _string_summary(self):
        return "ep.plotting.interactive.Props class containing Info for the MenuItem class."

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return f"<{self._string_summary()} Object at 0x{id(self):02x}>"

580 

581 

class BezierBuilder(object):
    """Interactively build a Bézier curve by clicking control points.

    Every mouse click inside the axes of the control polygon appends a control
    point, redraws the Bézier curve on ``ax_main`` and, if given, redraws the
    Bernstein basis polynomials on ``ax_bernstein``.
    """

    def __init__(self, control_polygon, ax_main, ax_bernstein=None):
        """Store the initial control polygon and listen for mouse clicks.

        Args:
            control_polygon (matplotlib.lines.Line2D): Line whose data are the
                initial control points.
            ax_main (matplotlib.axes.Axes): Axes that receives the curve.
            ax_bernstein (Optional[matplotlib.axes.Axes]): Axes for the
                Bernstein basis plot. Defaults to None (no basis plot).

        """
        self.control_polygon = control_polygon
        self.xp = list(control_polygon.get_xdata())
        self.yp = list(control_polygon.get_ydata())
        self.canvas = control_polygon.figure.canvas
        self.ax_main = ax_main
        self.ax_bernstein = ax_bernstein

        # The instance itself is the click callback (see ``__call__``).
        self.cid = self.canvas.mpl_connect("button_press_event", self)

        # Start with an empty curve drawn in the polygon's marker edge color.
        curve_artist = Line2D([], [], c=control_polygon.get_markeredgecolor())
        self.bezier_curve = self.ax_main.add_line(curve_artist)

    def __call__(self, event):
        """Append the clicked point and redraw; clicks in other axes are ignored."""
        if event.inaxes != self.control_polygon.axes:
            return

        self.xp.append(event.xdata)
        self.yp.append(event.ydata)
        self.control_polygon.set_data(self.xp, self.yp)

        xs, ys = self._build_bezier()
        self.bezier_curve.set_data(xs, ys)
        self._update_bernstein()
        self._update_bezier()

    def _build_bezier(self):
        """Return the x and y coordinates of the curve for the current points."""
        curve = Bezier(list(zip(self.xp, self.yp)))
        return curve[:, 0], curve[:, 1]

    def _update_bezier(self):
        """Redraw the canvas."""
        self.canvas.draw()

    def _update_bernstein(self):
        """Replot all Bernstein basis polynomials of the current degree."""
        if self.ax_bernstein is None:
            return
        degree = len(self.xp) - 1
        t = np.linspace(0, 1, num=200)
        axis = self.ax_bernstein
        axis.clear()
        for k in range(degree + 1):
            axis.plot(t, Bernstein(degree, k)(t))
        axis.set_title(f"Bernstein basis, N = {degree}")
        axis.set_xlim(0, 1)
        axis.set_ylim(0, 1)

    @property
    def ind(self):
        """np.ndarray: The sampled curve as an ``(n_samples, 2)`` array."""
        return np.vstack(self.bezier_curve.get_data()).T

    def disconnect(self):
        """Stop listening for mouse clicks."""
        self.canvas.mpl_disconnect(self.cid)

643 

644 

def Bernstein(n, k):
    """Return the Bernstein basis polynomial ``b_{k,n}`` as a callable.

    ``b_{k,n}(x) = binom(n, k) * x**k * (1 - x)**(n - k)``

    Args:
        n (int): Degree of the polynomial.
        k (int): Index of the basis polynomial.

    Returns:
        Callable: Function of ``x`` (a scalar or an np.ndarray).

    """
    coefficient = binom(n, k)
    return lambda x: coefficient * x**k * (1 - x) ** (n - k)

653 

654 

def Bezier(points, num=200):
    """Sample the Bézier curve defined by the control ``points``.

    Args:
        points (Sequence): Control points, each an (x, y) pair.
        num (int): Number of samples of the curve parameter t in [0, 1].
            Defaults to 200.

    Returns:
        np.ndarray: Array of shape ``(num, 2)`` with the curve coordinates.

    """
    degree = len(points) - 1
    t = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    for index, point in enumerate(points):
        curve += np.outer(Bernstein(degree, index)(t), point)
    return curve

663 

664 

class DummyTool:
    """Placeholder tool whose ``disconnect`` does nothing.

    Presumably used wherever an object with a ``disconnect()`` method is
    expected but no real selector is active -- confirm with callers.
    """

    def disconnect(self):
        """Do nothing."""
        return None

668 

669 

class SelectFromCollection(object):
    """Select indices from a matplotlib collection using `LassoSelector`.

    Selected indices are saved in the `ind` attribute. This tool fades out the
    points that are not part of the selection (i.e., reduces their alpha
    values). If your collection has alpha < 1, this tool will permanently
    alter the alpha values.

    Note that this tool selects collection objects based on their *origins*
    (i.e., `offsets`).

    Args:
        ax (matplotlib.axes.Axes): Axes to interact with.
        collection (matplotlib.collections.Collection): Subclass of collection
            you want to select from.
        alpha_other (float): To highlight a selection, this tool sets all
            selected points to an alpha value of 1 and non-selected points to
            `alpha_other`. Needs to fulfill 0 <= alpha_other <= 1
        selector (type): The selector widget class. It is instantiated as
            ``selector(ax, onselect=..., useblit=False)`` and its ``onselect``
            callback must receive the vertices of the selected region.
            Defaults to ``LassoSelector``.

    Attributes:
        ind (Union[list, np.ndarray]): Indices of the currently selected
            points. Empty list until the first selection.

    """

    def __init__(self, ax, collection, alpha_other=0.3, selector=LassoSelector):
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.Npts = len(self.xys)

        # Ensure that we have separate colors for each object
        self.fc = collection.get_facecolors()
        if len(self.fc) == 0:
            raise ValueError("Collection must have a facecolor")
        elif len(self.fc) == 1:
            self.fc = np.tile(self.fc, (self.Npts, 1))

        self.lasso = selector(ax, onselect=self.onselect, useblit=False)
        self.ind = []

    def onselect(self, verts):
        """Select the points inside ``verts`` and fade out all others."""
        path = Path(verts)
        self.ind = np.nonzero(path.contains_points(self.xys))[0]
        # the last column of the RGBA face colors is alpha
        self.fc[:, -1] = self.alpha_other
        self.fc[self.ind, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def disconnect(self):
        """Detach the selector and restore full opacity of all points."""
        self.lasso.disconnect_events()
        self.fc[:, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def _string_summary(self):
        s = (
            f"encodermap.SelectFromCollection object. Selecting points "
            f"from a set of {self.Npts} points. The current selector tool "
            f"is matplotlib's {self.lasso} tool. The selected points are "
            f"{self.ind}"
        )
        return s

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return self._string_summary()

737 

738 

class MenuItem(mpl.patches.Rectangle, mpl.artist.Artist):
    """A clickable, labelled rectangle used as a menu entry.

    The rectangle is drawn with ``props`` normally and with ``hoverprops``
    while the cursor is over it. Releasing a mouse button over the item calls
    ``on_select(item)``.
    """

    def __init__(
        self,
        fig,
        xy,
        width,
        height,
        labelstr,
        props=None,
        hoverprops=None,
        on_select=None,
        standalone=False,
    ):
        """Instantiate the MenuItem.

        Args:
            fig (matplotlib.figure.Figure): Figure whose canvas events are used.
            xy (tuple[float, float]): Lower-left corner of the rectangle.
            width (float): Width of the rectangle.
            height (float): Height of the rectangle.
            labelstr (str): Text shown in the center of the rectangle.
            props (Optional[dict]): Overrides for the normal look, merged over
                ``{"labelcolor": "black", "bgcolor": "blue", "fontsize": 15,
                "alpha": 0.2}``. Defaults to None (no overrides).
            hoverprops (Optional[dict]): Overrides for the hover look, merged
                over ``{"labelcolor": "C0", "bgcolor": "yellow",
                "fontsize": 15, "alpha": 0.2}``. Defaults to None.
            on_select (Optional[Callable]): Called with this item when a mouse
                button is released over it. Defaults to None.
            standalone (bool): If True, the item connects its own hover
                handler and redraws the canvas itself. Items created by a Menu
                leave hover handling to the menu. Defaults to False.

        """
        # define props and hoverprops (None instead of a shared mutable {} default)
        self.labelstr = labelstr
        self.standalone = standalone
        props = {
            **{"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2},
            **({} if props is None else props),
        }
        hoverprops = {
            **{"labelcolor": "C0", "bgcolor": "yellow", "fontsize": 15, "alpha": 0.2},
            **({} if hoverprops is None else hoverprops),
        }
        self.props = Props(**props)
        self.hoverprops = Props(**hoverprops)

        # set the on_select method
        self.on_select = on_select

        # Add a select event. The hover event is only connected if standalone
        # is True, i.e. the MenuItem is not instantiated from a Menu; in that
        # case the menu handles hovering.
        self.cid_button = fig.canvas.mpl_connect(
            "button_release_event", self.check_select
        )
        if self.standalone:
            self.cid_move = fig.canvas.mpl_connect(
                "motion_notify_event", self.set_hover
            )

        # Instantiate text and Rectangle
        mpl.patches.Rectangle.__init__(
            self, xy, width, height, label=self.labelstr, zorder=1
        )
        self.text = mpl.text.Text(
            0,
            0,
            self.labelstr,
            color=self.props.labelcolor,
            fontproperties=dict(weight="bold", size=self.props.fontsize),
            zorder=2,
            verticalalignment="center",
            horizontalalignment="center",
        )

        # set the props of the rectangle for the "not hovered" state
        self.set_hover_props(False)

        # after the artist has been set up, get it ready for hover events
        self.hover = False

    def set_figure(self, figure):
        """Overwriting Base Class method to include labelstr"""
        mpl.patches.Rectangle.set_figure(self, figure)
        self.text.set_figure(figure)

    def set_axes(self, axes):
        """Overwriting Base Class method to include labelstr"""
        # NOTE(review): newer matplotlib versions may not provide
        # ``Artist.set_axes`` -- verify before relying on this method.
        mpl.patches.Rectangle.set_axes(self, axes)
        self.text.set_axes(axes)

    def set_transform(self, transform):
        """Overwriting Base Class method to include labelstr"""
        mpl.patches.Rectangle.set_transform(self, transform)
        # place the text at the center of the rectangle, in data coordinates
        bbox = self.get_bbox()
        x = bbox.x0 + 0.5 * bbox.width
        y = bbox.y0 + 0.5 * bbox.height
        texttrans = mpl.transforms.Affine2D().translate(x, y) + self.axes.transData
        self.text.set_transform(texttrans)

    def set_data(self, x, y):
        """Overwriting Base Class method to include labelstr"""
        # NOTE(review): matplotlib's Rectangle appears to have no ``set_data``
        # (that is a Line2D method), so this call likely raises
        # AttributeError. Confirm before using.
        mpl.patches.Rectangle.set_data(self, x, y)
        self.text.set_color(self.props.labelcolor_rgb)

    def check_select(self, event):
        """Call ``on_select`` if ``event`` happened over this item.

        Returns:
            Optional[bool]: True if the event was over the item, else None.

        """
        over, _ = self.contains(event)
        if not over:
            return
        if self.on_select is not None:
            self.on_select(self)
        return True

    def draw(self, renderer):
        """Overwriting Base Class method to include labelstr"""
        # draw the rectangle first, then the label on top of it
        mpl.patches.Rectangle.draw(self, renderer)
        self.text.draw(renderer)

    def set_hover_props(self, check):
        """Apply ``hoverprops`` if ``check`` is truthy, else ``props``."""
        if check:
            props = self.hoverprops
        else:
            props = self.props

        self.set(facecolor=props.bgcolor_rgb, alpha=props.alpha)
        self.text.set_color(props.labelcolor_rgb)

    def set_hover(self, event):
        """Check the hover status of MenuItem.

        Returns:
            bool: Whether the hover state changed with this event.

        """
        check, _ = self.contains(event)
        changed = check != self.hover
        if changed:
            self.set_hover_props(check)
        self.hover = check
        if changed and self.standalone:
            self.figure.canvas.draw()
        return changed

864 

865 

class ModeButton(MenuItem):
    """A MenuItem that toggles between pressed and released on every click.

    While pressed, the hover properties stay applied regardless of where the
    cursor is.
    """

    def __init__(self, *args, **kwargs):
        """Create the MenuItem (all arguments are forwarded) in released state."""
        super().__init__(*args, **kwargs)
        self.pressed = False

    def check_select(self, event, overwrite=False):
        """Toggle ``pressed`` when ``event`` happened over the button.

        Args:
            event: The matplotlib mouse event.
            overwrite (bool): If True, ignore the event entirely. Defaults to
                False.

        Returns:
            Optional[bool]: The new ``pressed`` state, or None if the event was
            ignored.

        """
        if overwrite:
            return None
        hit, _ = self.contains(event)
        if not hit:
            return None
        if self.on_select is not None:
            self.on_select(self)
        self.pressed = not self.pressed
        if self.pressed:
            self.set_hover_props(True)
        return self.pressed

    def set_hover(self, event):
        """Check the hover status of MenuItem.

        The look only follows the cursor while the button is not pressed.

        Returns:
            bool: Whether the hover state changed with this event.

        """
        inside, _ = self.contains(event)
        hover_changed = inside != self.hover
        if hover_changed and not self.pressed:
            self.set_hover_props(inside)
        self.hover = inside
        return hover_changed

    def _on_select_rotation(self):
        """Old function to rotate labelstrings."""
        # NOTE(review): ``status``, ``labelrotation`` and ``label`` appear not
        # to be set anywhere in this file, so calling this likely raises
        # AttributeError. Looks like dead code.
        last_index = len(self.labelrotation) - 1
        self.status = 0 if self.status == last_index else self.status + 1
        self.labelstr = self.labelrotation[self.status]
        self.label.set_text(self.labelstr)

904 # print(f"You pressed {self.labelstr}") 

905 

906 

class Menu:
    """A vertical stack of MenuItems drawn on one axes.

    Args:
        ax (matplotlib.axes.Axes): The axes hosting the items.
        items (Sequence[str]): Item labels, placed in order of increasing y.
            Defaults to ``("Reset", "Write", "Set Points")``.

    """

    def __init__(self, ax, items=("Reset", "Write", "Set Points")):
        # suppresscomposite
        self.ax = ax
        self.fig = self.ax.get_figure()
        self.fig.suppressComposite = True

        # add the menu items
        self.menuitems = {}
        self.add_items(items=items)

        # mpl_connect move to set hover stuff
        self.cid = self.fig.canvas.mpl_connect("motion_notify_event", self.on_move)

    def on_move(self, event):
        """Update the hover state of every item and redraw once if any changed."""
        # Build a list (not a short-circuiting any()) so that every item sees
        # the event. Previously the loop stopped at the first changed item, so
        # leaving one item and entering another left the second one stale.
        changed = [item.set_hover(event) for item in self.menuitems.values()]
        if any(changed):
            self.fig.canvas.draw()

    def add_items(self, items):
        """Create one MenuItem per label and add it to the axes."""
        coords = self.get_coords(len(items))
        for s, c in zip(items, coords):
            on_select = lambda item: True
            item = MenuItem(
                self.fig, (0, c[0]), 1, c[1] - c[0], labelstr=s, on_select=on_select
            )
            self.menuitems[s] = item
            self.ax.add_artist(item)

    def get_coords(self, no_items, gap_space=0.05):
        """Split [0, 1] into ``no_items`` equal segments separated by gaps.

        Args:
            no_items (int): Number of segments.
            gap_space (float): Width of the gap between neighbouring segments.
                Defaults to 0.05.

        Returns:
            list[list[float]]: ``[start, end]`` for each segment. Empty if
            ``no_items <= 0`` (previously a ZeroDivisionError for 0).

        """
        if no_items <= 0:
            return []
        no_gaps = no_items - 1
        total_length = 1 - no_gaps * gap_space
        length_per = total_length / no_items
        coords = [[0, length_per]]
        for _ in range(1, no_items):
            previous_end = coords[-1][1]
            coords.append(
                [previous_end + gap_space, previous_end + gap_space + length_per]
            )
        return coords

956 

957 

958class StatusMenu(Menu): 

959 def __init__(self, ax): 

960 items = [ 

961 "Lasso", 

962 "Rectangle", 

963 "Ellipse", 

964 "Polygon", 

965 "Path", 

966 "Bezier", 

967 "Idle", 

968 "Mode", 

969 ] 

970 # call the parent class to use its get_coords() method 

971 # the add_items method is overwritten 

972 super(StatusMenu, self).__init__(ax, items=items) 

973 

974 # set the label rotation and the current status 

975 self.set_idle() 

976 

977 # click notify event to change self.status and switch ModeButtons off 

978 self.cid = self.fig.canvas.mpl_connect("button_release_event", self.on_click) 

979 

980 def set_idle(self): 

981 """Sets the idle status. Called at __init__ and 

982 when nothing is pressed.""" 

983 self.status = "Idle" 

984 self.menuitems[self.status].set_hover_props(True) 

985 self.menuitems[self.status].pressed = True 

986 

987 def on_click(self, event): 

988 draw = False 

989 for key, item in self.menuitems.items(): 

990 draw = item.pressed 

991 if draw and key != self.status: 

992 print(f"Changing Mode to {key}") 

993 self.menuitems[self.status].set_hover_props(False) 

994 self.menuitems[self.status].pressed = False 

995 self.menuitems[self.status].check_select(event, overwrite=True) 

996 self.status = key 

997 else: 

998 if all([not item.pressed for item in self.menuitems.values()]): 

999 print(f"Changing Mode to Idle") 

1000 self.set_idle() 

1001 

1002 def add_items(self, items): 

1003 coords = self.get_coords(len(items)) 

1004 for i, (s, c) in enumerate(zip(items, coords)): 

1005 if s == "Mode": 

1006 on_select = lambda item: None # print(f"You pressed {item.labelstr}.") 

1007 # overwrite props and hoverprops 

1008 props = { 

1009 "labelcolor": "black", 

1010 "bgcolor": "orange", 

1011 "fontsize": 15, 

1012 "alpha": 1, 

1013 } 

1014 hoverprops = { 

1015 "labelcolor": "black", 

1016 "bgcolor": "orange", 

1017 "fontsize": 15, 

1018 "alpha": 1, 

1019 } 

1020 item = MenuItem( 

1021 self.fig, 

1022 (0, c[0]), 

1023 1, 

1024 c[1] - c[0], 

1025 labelstr=s, 

1026 on_select=on_select, 

1027 props=props, 

1028 hoverprops=hoverprops, 

1029 ) 

1030 item.pressed = False 

1031 else: 

1032 # on_select = lambda item: print(f"You pressed {item.labelstr}.") 

1033 on_select = lambda item: True 

1034 # item = ModeButton(self.fig, self.ax, s, on_select=on_select) 

1035 item = ModeButton( 

1036 self.fig, (0, c[0]), 1, c[1] - c[0], labelstr=s, on_select=on_select 

1037 ) 

1038 self.menuitems[s] = item 

1039 self.ax.add_artist(item)