Coverage for encodermap/plot/utils.py: 38%

465 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/plot/utils.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22 

23"""Utils for plotting. Ideas and Code taken from the matplotlib documentation. 

24Menu: https://matplotlib.org/3.1.0/gallery/widgets/menu.html 

25LassoSelector: https://matplotlib.org/3.1.1/gallery/widgets/lasso_selector_demo_sgskip.html 

26PolygonSelector: https://matplotlib.org/3.1.3/gallery/widgets/polygon_selector_demo.html 

27Bezier: https://gist.github.com/gavincangan/b88a978e878e9bb1c0f8804e3af8de3c 

28 

29""" 

30 

31############################################################################## 

32# Imports 

33############################################################################## 

34 

35import os 

36import shutil 

37 

38import matplotlib as mpl 

39import matplotlib.pyplot as plt 

40import numpy as np 

41from matplotlib.artist import Artist 

42from matplotlib.lines import Line2D 

43from matplotlib.patches import Rectangle 

44from matplotlib.path import Path 

45from matplotlib.widgets import LassoSelector, PolygonSelector, RectangleSelector 

46from scipy.special import binom 

47from tqdm import tqdm 

48 

49from .._optional_imports import _optional_import 

50from ..encodermap_tf1.backmapping import dihedral_backmapping 

51from ..misc.clustering import gen_dummy_traj, get_cluster_frames 

52from ..misc.misc import _datetime_windows_and_linux_compatible 

53from .plotting import plot_cluster 

54 

55############################################################################## 

56# Optional Imports 

57############################################################################## 

58 

59 

60jinja2 = _optional_import("jinja2") 

61ngl = _optional_import("nglview") 

62mda = _optional_import("MDAnalysis") 

63md = _optional_import("mdtraj") 

64pd = _optional_import("pandas") 

65binom = _optional_import("scipy", "special.binom") 

66 

67############################################################################## 

68# Globals 

69############################################################################## 

70 

# NOTE(review): this name is spelled `_all__`, not `__all__`, so Python does
# not treat it as the module's export list and `from ... import *` currently
# exports every name that does not start with an underscore. Renaming it to
# `__all__` would restrict star-imports to the names below (which omit e.g.
# BezierBuilder, calculate_dssps, ...) — confirm the intended public API first.
_all__ = [
    "Props",
    "SelectFromCollection",
    "StatusMenu",
    "Menu",
    "ModeButton",
    "MenuItem",
]

79 

80 

81############################################################################## 

82# Functions (mainly for ternary plotting) 

83############################################################################## 

84 

85 

def calculate_dssps(trajs, simplified=True):
    """Calculates dssp from TrajEnsemble.

    Runs ``md.compute_dssp`` on every trajectory of an ep.TrajEnsemble and
    stacks the per-trajectory results into one numpy array.

    Args:
        trajs (ep.TrajEnsemble): The TrajEnsemble class to calculate the dssps from.
        simplified (bool, optional): Whether to return the simplified
            [H, E, C] dssp values (True) or the full DSSP alphabet (False).
            Forwarded to ``md.compute_dssp``. Defaults to True.

    Returns:
        np.ndarray: The dssp array of shape (trajs.n_frames, trajs.n_residues).

    ToDo:
        Make it work with different protein lengths (``np.vstack`` requires
        every trajectory to have the same number of residues).

    """
    all_dssp = []
    # The context manager closes the progress bar even if compute_dssp raises.
    with tqdm(total=trajs.n_trajs, position=0, leave=True) as pbar:
        for traj in trajs:
            # `simplified` used to be accepted but never forwarded, so the
            # full DSSP alphabet could not be requested.
            all_dssp.append(md.compute_dssp(traj, simplified=simplified))
            pbar.update()
    return np.vstack(all_dssp)

112 

113 

def abc_to_rgb(A=0.0, B=0.0, C=0.0):
    """Convert three fractions A, B, C (each expected in [0, 1]) into an RGB triple.

    Every channel is the sum of the two *other* inputs, capped at 1.0:
    red = B + C, green = A + C, blue = A + B.

    Returns:
        tuple: The (red, green, blue) values.
    """
    red = B + C
    green = A + C
    blue = A + B
    return tuple(min(channel, 1.0) for channel in (red, green, blue))

118 

119 

def digitize_dssp(trajs, dssp, imshow=True, bins=100, progbar=None):
    """Digitizes the DSSP array into RGB colors.

    For every frame (row of ``dssp``) the fraction of residues assigned to
    each of the letters "H", "E" and "C" is computed and mapped to a color
    with ``abc_to_rgb``.

    Args:
        trajs: Object with a ``lowd`` array; ``lowd[:, 0]`` and ``lowd[:, 1]``
            are used as x/y coordinates when ``imshow`` is True.
        dssp (np.ndarray): 2D array of DSSP letters, one row per frame and one
            column per residue.
        imshow (bool): If True, return a (bins, bins, 3) image in which each
            pixel is the mean color of the frames that fall into that bin of
            a 2D histogram over ``trajs.lowd[:, :2]`` (empty bins stay white).
            If False, return one color per frame. Defaults to True.
        bins (int): Number of histogram bins per axis. Defaults to 100.
        progbar (tqdm, optional): Progress bar to update. If None, a new one is
            created here and closed before returning.

    Returns:
        np.ndarray: The (bins, bins, 3) image if ``imshow`` else the
        (n_frames, 3) per-frame colors.

    Raises:
        ValueError: If ``dssp`` does not use the simplified H/E/C alphabet;
            only three classes can be mapped onto the three color channels.

    """
    own_progbar = progbar is None
    if own_progbar:
        # Previously referenced the undefined name `all_dssp` (NameError).
        progbar = tqdm(total=dssp.shape[0] * 4, position=0, leave=True)
    try:
        if not (np.any(np.isin(dssp, ["H", "E", "C"])) and "B" not in dssp):
            # The full alphabet has eight letters and cannot be mapped to
            # three color channels; this path used to fail with obscure
            # IndexError/ValueError exceptions further down.
            raise ValueError(
                "digitize_dssp only supports simplified DSSP values (H, E, C). "
                "Compute the DSSP with `simplified=True`."
            )
        sorted_ = np.array(["H", "E", "C"])
        n_letters = len(sorted_)
        num_aas = dssp.shape[1]
        uniques = correct_missing_uniques(
            [np.concatenate(np.unique(d, return_counts=True)) for d in dssp],
            sorted_=sorted_,
            progbar=progbar,
        )
        letters, raw_counts = uniques[:, :n_letters], uniques[:, n_letters:]

        # Rows that correct_missing_uniques left untouched keep np.unique's
        # alphabetical order (and may contain letters outside `sorted_`).
        # Look every letter up explicitly so column k always refers to
        # sorted_[k]; an absent letter counts as 0.
        counts = np.zeros((len(uniques), n_letters), dtype=float)
        for row, (row_letters, row_counts) in enumerate(zip(letters, raw_counts)):
            for col, letter in enumerate(sorted_):
                hits = np.flatnonzero(row_letters == letter)
                if hits.size:
                    counts[row, col] = float(row_counts[hits[0]])
        counts /= num_aas

        colors = np.array([abc_to_rgb(*fractions) for fractions in counts])
        if not imshow:
            return colors

        print("digitizing")
        digitized = np.full((bins, bins, 3), (1.0, 1.0, 1.0))
        x = trajs.lowd[:, 0]
        y = trajs.lowd[:, 1]
        _, xedges, yedges = np.histogram2d(x=x, y=y, bins=bins)
        for i in range(bins):
            # The last bin is closed on the right (as in np.histogram2d), so
            # the points at the maximum x / y are no longer dropped.
            upper_x = x <= xedges[i + 1] if i == bins - 1 else x < xedges[i + 1]
            in_x = (x >= xedges[i]) & upper_x
            for j in range(bins):
                upper_y = y <= yedges[j + 1] if j == bins - 1 else y < yedges[j + 1]
                where = np.flatnonzero(in_x & (y >= yedges[j]) & upper_y)
                if where.size:
                    digitized[i, j] = abc_to_rgb(*np.mean(counts[where], axis=0))
                progbar.update()
        return digitized
    finally:
        if own_progbar:
            progbar.close()

170 

171 

def correct_missing_uniques(uniques, sorted_, progbar=None):
    """Takes a list of DSSP letters and counts and adds zeros for all missing letters in sorted_.

    Args:
        uniques (list of np.ndarray): A list of np.ndarrays. For every frame the
            DSSP letters followed by their respective counts should be provided,
            e.g. ['H', 151] or ['H', 'E', 75, 76].
        sorted_ (np.ndarray): The sorted list of DSSP letters. Every row whose
            length is not ``2 * len(sorted_)`` is rebuilt in this order. Using
            ['H', 'E', 'C'], the two examples from above become
            ['H', 'E', 'C', 151, 0, 0] and ['H', 'E', 'C', 75, 76, 0],
            respectively. Rows that already have the full length are kept
            unchanged (in their original letter order). Letters that are not in
            ``sorted_`` are dropped from rebuilt rows.
        progbar (tqdm): A tqdm progbar. If None, a new one is created and
            closed before returning. Defaults to None.

    Returns:
        np.ndarray: The corrected uniques, one row per input entry. Rebuilt
        rows hold the counts as strings.

    """
    sorted_ = np.asarray(sorted_)
    n_letters = len(sorted_)
    own_progbar = progbar is None
    if own_progbar:
        progbar = tqdm(total=len(uniques), position=0, leave=True)
    # Build a new list instead of assigning into `uniques`, so the caller's
    # list is not modified.
    corrected = []
    try:
        for u in uniques:
            if len(u) != n_letters * 2:
                letters, counts = np.split(u, 2)
                newline = np.zeros(n_letters, dtype=int)
                for letter, count in zip(letters, counts.astype(int)):
                    newline[sorted_ == letter] = count
                u = np.concatenate([sorted_, newline.astype(str)])
            corrected.append(u)
            progbar.update()
    finally:
        if own_progbar:
            progbar.close()
    return np.vstack(corrected)

202 

203 

204def _get_system_info(): 

205 import getpass 

206 import platform 

207 import re 

208 import socket 

209 import uuid 

210 

211 import psutil 

212 

213 info = {} 

214 try: 

215 info["platform"] = platform.system() 

216 info["system_user"] = getpass.getuser() 

217 info["platform_release"] = platform.release() 

218 info["platform_version"] = platform.version() 

219 info["architecture"] = platform.machine() 

220 info["hostname"] = socket.gethostname() 

221 info["ip_address"] = socket.gethostbyname(socket.gethostname()) 

222 info["mac_address"] = ":".join(re.findall("..", "%012x" % uuid.getnode())) 

223 info["processor"] = platform.processor() 

224 info["ram"] = str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB" 

225 return info 

226 except Exception: 

227 return info 

228 

229 

230def _check_all_templates_defined(template, info_dict): 

231 import re 

232 

233 regex = r"\{(.*?)\}" 

234 matches = re.finditer(regex, template, re.MULTILINE | re.DOTALL) 

235 min_matches = [] 

236 for matchNum, match in enumerate(matches): 

237 for groupNum in range(0, len(match.groups())): 

238 min_matches.append(match.group(groupNum)) 

239 min_matches = list( 

240 set(map(lambda x: x.lstrip("{{").rstrip("}}"), [i for i in min_matches])) 

241 ) 

242 if all(key in info_dict for key in min_matches): 242 ↛ 245line 242 didn't jump to line 245, because the condition on line 242 was never false

243 return True 

244 else: 

245 missing = set(min_matches).difference(info_dict) 

246 raise Exception( 

247 f"Not all expressions defined in template. Missing expressions: {missing}" 

248 ) 

249 

250 

def _create_readme(main_path, now, info_dict):  # pragma: no cover
    """Render the README jinja2 template into ``<main_path>/README.md``.

    ``info_dict`` is updated in place with ``pip_freeze``, ``filename``,
    ``now``, ``encodermap_version`` and the entries of ``_get_system_info()``
    before rendering.

    Args:
        main_path (str): Directory the README is written to.
        now (str): Timestamp string inserted into the template.
        info_dict (dict): Values for the template; modified in place.

    Raises:
        Exception: If the template references a key not present in the
            updated ``info_dict`` (raised by ``_check_all_templates_defined``).

    """
    # NOTE(review): pip._internal is not a public API and may change between
    # pip versions.
    from pip._internal.operations import freeze

    from .._version import __version__
    from .jinja_template import template

    md_file = os.path.join(main_path, "README.md")
    pip_freeze = "".join(f" {requirement}\n" for requirement in freeze.freeze())
    info_dict.update(
        {
            "pip_freeze": pip_freeze,
            # splitext removes only the extension; the former
            # `md_file.split(".")[0]` cut the path at its first dot, so e.g.
            # "./clusters/.../README.md" became "".
            "filename": os.path.splitext(md_file)[0],
            "now": now,
            "encodermap_version": __version__,
        }
    )
    info_dict.update(_get_system_info())

    # Called directly instead of via `assert`, which is stripped under
    # `python -O`; the function raises on its own if keys are missing.
    _check_all_templates_defined(template, info_dict)

    msg = jinja2.Template(template).render(info_dict)

    # Explicit encoding: pip freeze output and system info may contain
    # non-ASCII characters regardless of the platform's default encoding.
    with open(md_file, "w", encoding="utf-8") as f:
        f.write(msg)

280 

281 

def _match_tops_and_trajs(min_topfiles, list_of_frames):
    """Group frames by the topology file whose topology they share.

    Args:
        min_topfiles (list): Topology file paths. Each one is loaded with
            ``md.load`` and its ``.top`` is used as the lookup key.
        list_of_frames (iterable): Objects with a ``top`` attribute that is
            looked up among the loaded topologies.

    Returns:
        dict: Maps every path in ``min_topfiles`` to the list of frames whose
        ``top`` matched that file's topology (possibly empty).

    Raises:
        KeyError: If a frame's ``top`` matches none of the loaded topologies.
    """
    grouped = {top_path: [] for top_path in min_topfiles}
    path_by_topology = {}
    for top_path in min_topfiles:
        path_by_topology[md.load(top_path).top] = top_path
    for frame in list_of_frames:
        grouped[path_by_topology[frame.top]].append(frame)
    return grouped

289 

290 

def _unpack_cluster_info(trajs, main_path, selector, dummy_traj, align_string):
    """Write all files describing the most recently selected cluster to disk.

    The cluster id is the maximum of ``trajs.CVs["user_selected_points"]``.
    Everything is written to a new directory ``<main_path>/clusters/<now>``:

    * the selector's vertices (only for ``PolygonSelector``),
    * pdb file(s) of ``dummy_traj`` and a png rendered by ``plot_cluster``,
    * a text file listing every trajectory in ``trajs``,
    * a csv with one row per selected frame,
    * npy files with the selected values, their indices and the complete
      ``user_selected_points`` array,
    * per-topology copies of the topology files, xtc files of the cluster
      frames and a text file describing where they come from,
    * a README rendered by ``_create_readme``.

    Args:
        trajs (ep.TrajEnsemble): The trajectories. Must contain the
            ``"user_selected_points"`` CV.
        main_path (str): Directory in which the ``clusters`` folder is created.
        selector: The selection tool; its ``lasso`` attribute is inspected.
        dummy_traj: A trajectory, or a list of trajectories, saved as pdb.
        align_string (str): Forwarded to ``get_cluster_frames``.

    Returns:
        tuple: ``(max_, main_path)`` — the cluster id and the directory the
        files were written to.

    """
    # `warnings` was previously used below without being imported (NameError).
    import warnings

    max_ = trajs.CVs["user_selected_points"].max()
    where = np.where(trajs.CVs["user_selected_points"] == max_)[0]
    now = _datetime_windows_and_linux_compatible()

    # make dirs
    os.makedirs(os.path.join(main_path, "clusters"), exist_ok=True)
    main_path = os.path.join(main_path, f"clusters/{now}")
    os.makedirs(main_path, exist_ok=True)

    # define names
    pdb_name = os.path.join(main_path, f"cluster_id_{max_}_stacked_10_structs.pdb")
    pdb_start_name = os.path.join(main_path, f"cluster_id_{max_}_start.pdb")
    pdb_origin_names = os.path.join(main_path, f"cluster_id_{max_}_pdb_origins.txt")
    xtc_name = os.path.join(main_path, f"cluster_id_{max_}.xtc")
    csv_name = os.path.join(main_path, f"cluster_id_{max_}_selected_points.csv")
    png_name = os.path.join(main_path, f"cluster_id_{max_}_image.png")
    lowd_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points.npy"
    )
    indices_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_cluster_lowd_points_indices.npy"
    )
    current_clustering = os.path.join(
        main_path,
        f"cluster_id_{max_}_cluster_current_clustering_user_selected_points.npy",
    )
    selector_npy_name = os.path.join(
        main_path, f"cluster_id_{max_}_selector_points.npy"
    )
    parents_trajs = os.path.join(
        main_path, f"cluster_id_{max_}_all_plotted_trajs_in_correct_order.txt"
    )

    # Stems used to derive numbered file names. Only the file extension is
    # replaced, so directory names containing ".pdb"/".xtc" stay intact.
    pdb_stem = os.path.splitext(pdb_name)[0]
    pdb_start_stem = os.path.splitext(pdb_start_name)[0]
    xtc_stem = os.path.splitext(xtc_name)[0]

    # save edges of selector
    if isinstance(selector.lasso, PolygonSelector):
        verts = np.vstack([selector.lasso._xs, selector.lasso._ys]).T
        # Build the name from its parts instead of `str.replace("selector", ...)`
        # on the full path, which also rewrote matching directory names.
        tool_name = selector.lasso.__class__.__name__.lower()
        selector_npy_name = os.path.join(
            main_path, f"cluster_id_{max_}_{tool_name}_points.npy"
        )
        np.save(selector_npy_name, verts)
    else:
        warnings.warn("Can currently only save the vertices of Polygon.")

    # save pdb
    # if list is provided the pdbs of the list need to be saved
    if isinstance(dummy_traj, list):
        for i, traj in enumerate(dummy_traj):
            traj.save_pdb(f"{pdb_stem}_struct_{i}.pdb")
    else:
        dummy_traj.save_pdb(pdb_name)

    # render png
    # NOTE(review): when `dummy_traj` is a list, `pdb_name` itself is never
    # written (only the `_struct_{i}` files are) — confirm plot_cluster copes.
    plot_cluster(trajs, pdb_name, png_name, max_)

    # save all trajs
    with open(parents_trajs, "w") as f:
        for traj in trajs:
            f.write(
                f"{os.path.abspath(traj.traj_file)} {os.path.abspath(traj.top_file)} {traj.common_str}\n"
            )

    # create df
    n_lowd = trajs.lowd.shape[-1]
    if n_lowd == 2:
        lowd_keys = ["x", "y"]
    elif n_lowd == 3:  # was `trajs.lowd.shaoe`, an AttributeError
        lowd_keys = ["x", "y", "z"]
    else:
        lowd_keys = [f"lowd_{i}" for i in range(n_lowd)]
    unique_id_col = f"unique id in set of {trajs.n_trajs} trajs"
    columns = [
        "trajectory file",
        "topology file",
        "frame number",
        "time",
        *lowd_keys,
        "cluster id",
        "trajectory number",
        unique_id_col,
    ]
    rows = []
    for w in where:
        frame = trajs.get_single_frame(w)
        if frame.traj_num is None:
            traj_frame = frame.id[0]
        else:
            traj_frame = frame.id[0, 1]
        rows.append(
            {
                "trajectory file": os.path.abspath(frame.traj_file),
                "topology file": os.path.abspath(frame.top_file),
                "frame number": traj_frame,
                "time": frame.time[0],
                "cluster id": max_,
                "trajectory number": frame.traj_num,
                unique_id_col: w,
                **dict(zip(lowd_keys, frame.lowd)),
            }
        )
    # DataFrame.append was removed in pandas 2.0; build the frame from the
    # collected records in one go (the column order is fixed by `columns`).
    df = pd.DataFrame(rows, columns=columns).astype(
        dtype={
            "trajectory file": str,
            "topology file": str,
            "frame number": int,
            "time": float,
            **{k: float for k in lowd_keys},
            "cluster id": int,
            "trajectory number": int,
            unique_id_col: int,
        }
    )
    df.to_csv(csv_name, index=False)

    # save npy
    np.save(lowd_npy_name, trajs.CVs["user_selected_points"][where])
    np.save(indices_npy_name, where)
    np.save(current_clustering, trajs.CVs["user_selected_points"])

    # save full traj
    _, dummy_traj = get_cluster_frames(
        trajs,
        max_,
        nglview=False,
        shorten=False,
        stack_atoms=True,
        col="user_selected_points",
        align_string=align_string,
    )
    min_topfiles = list({i.top_file for i in trajs})
    top_traj_dict = _match_tops_and_trajs(min_topfiles, dummy_traj)
    basenames = [trajs.basename_fn(i) for i in top_traj_dict]
    with open(pdb_origin_names, "w") as f:
        for i, (top_file, frames) in enumerate(top_traj_dict.items()):
            _pdb_start_name = f"{pdb_start_stem}_traj_{i}_from_{basenames[i]}.pdb"
            _xtc_name = f"{xtc_stem}_traj_{i}.xtc"
            shutil.copyfile(top_file, _pdb_start_name)
            f.write(
                f"{_pdb_start_name} is a copy (`shutil.copyfile`) of {top_file}. The corresponding trajectory files might originate from other places. Refer to {parents_trajs} for info about xtcs.\n"
            )
            if frames:
                traj_out = frames[0]
                for frame in frames[1:]:
                    traj_out = traj_out.join(frame)
                traj_out.save_xtc(_xtc_name)

    # create an info dict
    info_dict = {
        "pdb_name": pdb_name,
        "pdb_start_name": pdb_start_name,
        "pdb_origin_names": pdb_origin_names,
        "xtc_name": xtc_name,
        "csv_name": csv_name,
        "png_name": png_name,
        "lowd_npy_name": lowd_npy_name,
        "indices_npy_name": indices_npy_name,
        "current_clustering": current_clustering,
        "selector_npy_name": selector_npy_name,
        "parents_trajs": parents_trajs,
    }
    info_dict = {k: os.path.basename(v) for k, v in info_dict.items()}
    info_dict.update(
        {
            "cluster_id": max_,
            "n_trajs": trajs.n_trajs,
            "cluster_n_points": len(where),
            "basename": trajs[0].basename,
            "cluster_abspath": main_path,
        }
    )

    # create a readme
    _create_readme(main_path, now, info_dict)

    return max_, main_path

468 

469 

470def _unpack_path_info(path): 

471 pass 

472 

473 

474############################################################################## 

475# Classes 

476############################################################################## 

477 

478 

class Props:
    """Class to contain the properties of MenuItems.

    Each MenuItem contains two copies of this class: one with the props used
    while the mouse cursor hovers over the item, and one for the rest of the
    time. Values can be read and written either as attributes or with
    dict-style item access.

    Attributes:
        labelcolor (str): The color of the text.
        labelcolor_rgb (tuple): The RGBA value of labelcolor.
        bgcolor (str): The color of the background.
        bgcolor_rgb (tuple): The RGBA value of bgcolor.
        fontsize (int): The fontsize.
        alpha (float): The alpha value of the background.
        defaults (dict): Class Variable of default values.

    Examples:
        >>> props = Props()
        >>> props.labelcolor
        'black'
        >>> props["fontsize"] = 20
        >>> props["fontsize"]
        20

    """

    defaults = {"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2}

    def __init__(self, **kwargs):
        """Instantiate the Props class.

        The class defaults are set as instance attributes first; every keyword
        argument whose name matches one of them then overwrites that value.

        Args:
            **kwargs: Values to overwrite. Unknown keys are dropped and a
                message is printed.

        """
        self._setattr(self.defaults)
        for key, value in kwargs.items():
            if key not in self.__dict__:
                print(f"Dropping unknown dict entry for {{'{key}': {value}}}")
            else:
                setattr(self, key, value)

    @property
    def bgcolor_rgb(self):
        """tuple: ``bgcolor`` converted to RGBA by matplotlib."""
        return mpl.colors.to_rgba(self.bgcolor)

    @property
    def labelcolor_rgb(self):
        """tuple: ``labelcolor`` converted to RGBA by matplotlib."""
        return mpl.colors.to_rgba(self.labelcolor)

    def _setattr(self, dictionary):
        """Set every key/value pair of ``dictionary`` as an attribute."""
        for key, value in dictionary.items():
            setattr(self, key, value)

    def __setitem__(self, key, value):
        """Support ``props[key] = value`` as an alias for attribute assignment."""
        setattr(self, key, value)

    # The method was originally misspelled, which left item assignment
    # unsupported; the old name is kept so existing explicit calls still work.
    __setitiem__ = __setitem__

    def __getitem__(self, item):
        """Support ``props[key]`` as an alias for attribute access."""
        return getattr(self, item)

    def _string_summary(self):
        return "ep.plotting.interactive.Props class containing Info for the MenuItem class."

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return f"<{self._string_summary()} Object at 0x{id(self):02x}>"

548 

549 

class BezierBuilder(object):
    """Interactive Bézier curve builder.

    Every mouse click inside the control polygon's axes appends a control
    point, recomputes the curve with ``Bezier`` and, if an ``ax_bernstein``
    axes was given, replots the Bernstein basis polynomials there.
    """

    def __init__(self, control_polygon, ax_main, ax_bernstein=None):
        """Constructor.

        Args:
            control_polygon: Line artist holding the initial control points;
                its x/y data, figure canvas and marker edge color are used.
            ax_main: Axes the Bézier curve is added to.
            ax_bernstein (optional): Axes for the Bernstein basis plot.
                Defaults to None (no basis plot).
        """
        self.control_polygon = control_polygon
        self.xp = list(control_polygon.get_xdata())
        self.yp = list(control_polygon.get_ydata())
        self.canvas = control_polygon.figure.canvas
        self.ax_main = ax_main
        self.ax_bernstein = ax_bernstein

        # The instance itself is the click handler (see __call__).
        self.cid = self.canvas.mpl_connect("button_press_event", self)

        # Start with an empty curve drawn in the polygon's marker edge color.
        curve_artist = Line2D([], [], c=control_polygon.get_markeredgecolor())
        self.bezier_curve = self.ax_main.add_line(curve_artist)

    def __call__(self, event):
        """Add the clicked point as a control point and redraw everything."""
        if event.inaxes == self.control_polygon.axes:
            self.xp.append(event.xdata)
            self.yp.append(event.ydata)
            self.control_polygon.set_data(self.xp, self.yp)

            self.bezier_curve.set_data(*self._build_bezier())
            self._update_bernstein()
            self._update_bezier()

    def _build_bezier(self):
        """Return the x and y coordinates of the curve through the control points."""
        curve = Bezier(list(zip(self.xp, self.yp)))
        return curve[:, 0], curve[:, 1]

    def _update_bezier(self):
        """Redraw the canvas."""
        self.canvas.draw()

    def _update_bernstein(self):
        """Replot the Bernstein basis for the current degree, if an axes exists."""
        ax = self.ax_bernstein
        if ax is None:
            return
        degree = len(self.xp) - 1
        t = np.linspace(0, 1, num=200)
        ax.clear()
        for k in range(degree + 1):
            ax.plot(t, Bernstein(degree, k)(t))
        ax.set_title(f"Bernstein basis, N = {degree}")
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

    @property
    def ind(self):
        """np.ndarray: The curve's points as rows of (x, y)."""
        curve_data = self.bezier_curve.get_data()
        return np.vstack(curve_data).T

    def disconnect(self):
        """Stop listening to mouse clicks."""
        self.canvas.mpl_disconnect(self.cid)

611 

612 

def Bernstein(n, k):
    """Return the Bernstein basis polynomial b_{k,n}(x) = C(n, k) * x**k * (1 - x)**(n - k).

    Args:
        n (int): Degree of the polynomial.
        k (int): Index of the basis polynomial.

    Returns:
        callable: Function of ``x`` (scalar or numpy array) evaluating the
        polynomial.
    """
    coefficient = binom(n, k)
    return lambda x: coefficient * x**k * (1 - x) ** (n - k)

621 

622 

def Bezier(points, num=200):
    """Sample the Bézier curve defined by the control ``points``.

    Args:
        points (sequence): Control points, each an (x, y) pair.
        num (int): Number of samples of the curve parameter t in [0, 1].
            Defaults to 200.

    Returns:
        np.ndarray: Array of shape (num, 2) with the sampled curve.
    """
    degree = len(points) - 1
    t = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    # Each control point is weighted by its Bernstein basis polynomial.
    for index, point in enumerate(points):
        curve += np.outer(Bernstein(degree, index)(t), point)
    return curve

631 

632 

class DummyTool:
    """Stand-in selection tool that does nothing."""

    def disconnect(self):
        """No-op counterpart of the ``disconnect`` methods of the real tools."""
        return None

636 

637 

class SelectFromCollection(object):
    """Select indices from a matplotlib collection using `LassoSelector`.

    Selected indices are saved in the `ind` attribute. This tool fades out the
    points that are not part of the selection (i.e., reduces their alpha
    values). If your collection has alpha < 1, this tool will permanently
    alter the alpha values.

    Note that this tool selects collection objects based on their *origins*
    (i.e., `offsets`).

    Args:
        ax (matplotlib.axes.Axes): Axes to interact with.
        collection (matplotlib.collections.Collection): Subclass of collection
            you want to select from.
        alpha_other (float): To highlight a selection, this tool sets all
            selected points to an alpha value of 1 and non-selected points to
            `alpha_other`. Needs to fulfill 0 <= alpha_other <= 1
        selector (type): Selector widget class, instantiated as
            ``selector(ax, onselect=self.onselect, useblit=False)``. Its
            callback must pass a single ``verts`` argument. Defaults to
            ``LassoSelector``.

    Raises:
        ValueError: If the collection has no facecolor.

    """

    def __init__(self, ax, collection, alpha_other=0.3, selector=LassoSelector):
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.Npts = len(self.xys)

        # Ensure that we have separate colors for each object
        self.fc = collection.get_facecolors()
        if len(self.fc) == 0:
            raise ValueError("Collection must have a facecolor")
        elif len(self.fc) == 1:
            self.fc = np.tile(self.fc, (self.Npts, 1))

        self.lasso = selector(ax, onselect=self.onselect, useblit=False)
        self.ind = []

    def onselect(self, verts):
        """Select the points inside the polygon ``verts`` and fade out the rest."""
        path = Path(verts)
        self.ind = np.nonzero(path.contains_points(self.xys))[0]
        self.fc[:, -1] = self.alpha_other
        self.fc[self.ind, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def disconnect(self):
        """Disconnect the selector and restore full opacity for all points."""
        self.lasso.disconnect_events()
        self.fc[:, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def _string_summary(self):
        # Fixed the "matplotlotlib's" typo in the user-facing summary.
        s = (
            f"encodermap.SelectFromCollection object. Selecting points "
            f"from a set of {self.Npts} points. The current selector tool "
            f"is matplotlib's {self.lasso} tool. The selected points are "
            f"{self.ind}"
        )
        return s

    def __str__(self):
        return self._string_summary()

    def __repr__(self):
        return self._string_summary()

705 

706 

class MenuItem(mpl.patches.Rectangle, mpl.artist.Artist):
    """A labelled rectangle acting as a clickable menu entry.

    The rectangle's face color/alpha and the label color come from
    ``self.props`` normally and from ``self.hoverprops`` while hovered (see
    ``set_hover_props``). A ``mpl.text.Text`` holding ``labelstr`` is drawn on
    top of the rectangle and centered on it (see ``set_transform``).
    """

    def __init__(
        self,
        fig,
        xy,
        width,
        height,
        labelstr,
        props={},
        hoverprops={},
        on_select=None,
        standalone=False,
    ):
        """Instantiate the MenuItem.

        Args:
            fig (matplotlib.figure.Figure): Figure whose canvas the mouse
                events are connected to.
            xy (tuple): Anchor of the rectangle, passed to ``Rectangle.__init__``.
            width (float): Width of the rectangle.
            height (float): Height of the rectangle.
            labelstr (str): Text shown on the item; also used as artist label.
            props (dict): Overrides for the normal appearance, merged over
                ``{"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2}``.
            hoverprops (dict): Overrides for the hovered appearance, merged over
                ``{"labelcolor": "C0", "bgcolor": "yellow", "fontsize": 15, "alpha": 0.2}``.
            on_select (callable, optional): Called with this item when a mouse
                button is released over it.
            standalone (bool): If True, the item connects its own
                ``motion_notify_event`` handler for hover highlighting;
                otherwise the owning Menu is expected to call ``set_hover``.
        """
        # define props and hoverprops
        # (the default dicts are only read and spread into new dicts, never
        # mutated, so sharing them between calls is harmless)
        self.labelstr = labelstr
        self.standalone = standalone
        props = {
            **{"labelcolor": "black", "bgcolor": "blue", "fontsize": 15, "alpha": 0.2},
            **props,
        }
        hoverprops = {
            **{"labelcolor": "C0", "bgcolor": "yellow", "fontsize": 15, "alpha": 0.2},
            **hoverprops,
        }
        self.props = Props(**props)
        self.hoverprops = Props(**hoverprops)

        # set the on_select method
        self.on_select = on_select

        # add a select event
        # hover event is only included if standalone is true
        # i.e. the MenuItem is not instantiated from a Menu.
        # If this is instantiated from a menu, the menu will handle the hover.
        self.cid_button = fig.canvas.mpl_connect(
            "button_release_event", self.check_select
        )
        if self.standalone:
            self.cid_move = fig.canvas.mpl_connect(
                "motion_notify_event", self.set_hover
            )

        # Instantiate text and Rectangle.
        # The Rectangle is initialized before self.text exists; the overridden
        # set_figure/set_axes/set_transform below access self.text, so this
        # order relies on Rectangle.__init__ not calling them — NOTE(review): confirm.
        mpl.patches.Rectangle.__init__(
            self, xy, width, height, label=self.labelstr, zorder=1
        )
        self.text = mpl.text.Text(
            0,
            0,
            self.labelstr,
            color=self.props.labelcolor,
            fontproperties=dict(weight="bold", size=self.props.fontsize),
            zorder=2,
            verticalalignment="center",
            horizontalalignment="center",
        )

        # final thing to do is set the props of
        # the rectangle based whether a cursor hovers
        self.set_hover_props(False)

        # after artist has been placed get it ready
        # for accepting hover events
        self.hover = False

    def set_figure(self, figure):
        """Set the figure of the rectangle and of its label text."""
        mpl.patches.Rectangle.set_figure(self, figure)
        self.text.set_figure(figure)

    def set_axes(self, axes):
        """Set the axes of the rectangle and of its label text.

        NOTE(review): recent matplotlib versions may no longer provide
        ``Artist.set_axes`` — verify before relying on this method.
        """
        mpl.patches.Rectangle.set_axes(self, axes)
        self.text.set_axes(axes)

    def set_transform(self, transform):
        """Set the rectangle's transform and center the label on the rectangle.

        The label is translated to the center of the rectangle's bbox and then
        mapped with ``self.axes.transData``, so ``self.axes`` must already be
        set when this is called.
        """
        mpl.patches.Rectangle.set_transform(self, transform)
        # set text to center of self(.rect)
        bbox = self.get_bbox()
        x = bbox.x0 + 0.5 * bbox.width
        y = bbox.y0 + 0.5 * bbox.height
        texttrans = mpl.transforms.Affine2D().translate(x, y) + self.axes.transData
        self.text.set_transform(texttrans)

    def set_data(self, x, y):
        """Forward ``set_data`` to the base class and recolor the label.

        NOTE(review): matplotlib's Rectangle does not appear to define
        ``set_data``; calling this may raise AttributeError — verify.
        """
        # if len(x):
        #     self.text.set_position((x[-1], y[-1]))
        mpl.patches.Rectangle.set_data(self, x, y)
        self.text.set_color(self.props.labelcolor_rgb)

    def check_select(self, event):
        """Handle a button release: call ``on_select`` if it happened over the item.

        Returns:
            True if ``event`` was over the item, otherwise None.
        """
        over, _ = self.contains(event)
        if not over:
            return
        if self.on_select is not None:
            self.on_select(self)
        return True

    def draw(self, renderer):
        """Draw the rectangle, then its label on top."""
        mpl.patches.Rectangle.draw(self, renderer)
        self.text.draw(renderer)

    def set_hover_props(self, check):
        """Apply ``hoverprops`` if ``check`` is truthy, else ``props``."""
        if check:
            props = self.hoverprops
        else:
            props = self.props

        self.set(facecolor=props.bgcolor_rgb, alpha=props.alpha)
        self.text.set_color(props.labelcolor_rgb)

    def set_hover(self, event):
        """Update the hover state of the MenuItem for a mouse-move ``event``.

        Returns:
            bool: Whether the hover state changed. Standalone items redraw
            the canvas themselves when it did.
        """
        check, _ = self.contains(event)
        changed = check != self.hover
        if changed:
            self.set_hover_props(check)
        self.hover = check
        if changed and self.standalone:
            self.figure.canvas.draw()
        return changed

832 

833 

class ModeButton(MenuItem):
    """A MenuItem that toggles a ``pressed`` state on every click over it.

    While pressed, the button keeps its hover appearance and ignores hover
    changes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Buttons start released.
        self.pressed = False

    def check_select(self, event, overwrite=False):
        """Toggle ``pressed`` when ``event`` happened over the button.

        Args:
            event: The matplotlib mouse event.
            overwrite (bool): If True, ignore the event entirely.

        Returns:
            The new ``pressed`` state if the event hit the button, else None.
        """
        if overwrite:
            return None
        hit, _ = self.contains(event)
        if not hit:
            return None
        if self.on_select is not None:
            self.on_select(self)
        self.pressed = not self.pressed
        if self.pressed:
            self.set_hover_props(True)
        return self.pressed

    def set_hover(self, event):
        """Track the hover state; the appearance only changes while released.

        Returns:
            bool: Whether the hover state changed.
        """
        hovering, _ = self.contains(event)
        changed = hovering != self.hover
        if changed and not self.pressed:
            self.set_hover_props(hovering)
        self.hover = hovering
        return changed

    def _on_select_rotation(self):
        """Old function to rotate labelstrings.

        Relies on ``status``, ``labelrotation`` and ``label`` attributes that
        this class never sets.
        """
        self.status = 0 if self.status == len(self.labelrotation) - 1 else self.status + 1
        self.labelstr = self.labelrotation[self.status]
        self.label.set_text(self.labelstr)

872 # print(f"You pressed {self.labelstr}") 

873 

874 

class Menu:
    """A column of MenuItems drawn into ``ax``.

    Items are stacked along y in the range [0, 1] (see ``get_coords``), each
    spanning x from 0 to 1, and are highlighted while the mouse hovers over
    them.

    Args:
        ax (matplotlib.axes.Axes): The axes the items are added to.
        items (sequence of str): Labels of the items, in stacking order.
            Defaults to ("Reset", "Write", "Set Points").
    """

    def __init__(self, ax, items=("Reset", "Write", "Set Points")):
        self.ax = ax
        self.fig = self.ax.get_figure()
        # Override the renderer's image-compositing option for this figure.
        self.fig.suppressComposite = True

        # add the menu items
        self.menuitems = {}
        self.add_items(items=items)

        # mpl_connect move to set hover stuff
        self.cid = self.fig.canvas.mpl_connect("motion_notify_event", self.on_move)

    def on_move(self, event):
        """Update the hover state of every item and redraw once if any changed."""
        # A list (not a generator) so every item is updated. The previous loop
        # stopped at the first changed item, so an item the cursor had just
        # left could stay highlighted.
        changed = [item.set_hover(event) for item in self.menuitems.values()]
        if any(changed):
            self.fig.canvas.draw()

    def add_items(self, items):
        """Create one MenuItem per label in ``items`` and add it to the axes."""
        coords = self.get_coords(len(items))

        def on_select(item):
            return True

        for s, c in zip(items, coords):
            item = MenuItem(
                self.fig, (0, c[0]), 1, c[1] - c[0], labelstr=s, on_select=on_select
            )
            self.menuitems[s] = item
            self.ax.add_artist(item)

    def get_coords(self, no_items, gap_space=0.05):
        """Split [0, 1] into ``no_items`` equal intervals separated by ``gap_space``.

        Args:
            no_items (int): Number of intervals.
            gap_space (float): Gap between neighbouring intervals. Defaults to 0.05.

        Returns:
            list: ``[start, end]`` pairs, one per item; empty if ``no_items`` < 1.
        """
        if no_items <= 0:
            # Previously a ZeroDivisionError for an empty menu.
            return []
        length_per = (1 - (no_items - 1) * gap_space) / no_items
        coords = [[0, length_per]]
        for _ in range(1, no_items):
            start = coords[-1][1] + gap_space
            coords.append([start, start + length_per])
        return coords

924 

925 

class StatusMenu(Menu):
    """Menu of mutually exclusive selection modes plus a "Mode" item.

    Every entry except "Mode" is a ModeButton; at most one of them is meant
    to be pressed, and its label is stored in ``self.status``. When no button
    is pressed, the menu falls back to "Idle".
    """

    def __init__(self, ax):
        """Create the menu in ``ax`` and start in the "Idle" mode."""
        items = [
            "Lasso",
            "Rectangle",
            "Ellipse",
            "Polygon",
            "Path",
            "Bezier",
            "Idle",
            "Mode",
        ]
        # call the parent class to use its get_coords() method
        # the add_items method is overwritten
        super(StatusMenu, self).__init__(ax, items=items)

        # set the label rotation and the current status
        self.set_idle()

        # click notify event to change self.status and switch ModeButtons off.
        # NOTE(review): the items connected their check_select handlers during
        # super().__init__, i.e. before this handler; on_click relies on those
        # having already toggled `pressed` — confirm matplotlib calls callbacks
        # in connection order.
        # NOTE: this overwrites the motion_notify_event cid stored by
        # Menu.__init__, so that connection can no longer be disconnected via
        # self.cid.
        self.cid = self.fig.canvas.mpl_connect("button_release_event", self.on_click)

    def set_idle(self):
        """Sets the idle status. Called at __init__ and
        when nothing is pressed."""
        self.status = "Idle"
        self.menuitems[self.status].set_hover_props(True)
        self.menuitems[self.status].pressed = True

    def on_click(self, event):
        """Switch ``self.status`` to a newly pressed button and release the old one.

        If no button is pressed at all (e.g. the active mode was clicked off),
        the menu returns to the "Idle" mode.
        """
        draw = False
        for key, item in self.menuitems.items():
            draw = item.pressed
            if draw and key != self.status:
                print(f"Changing Mode to {key}")
                # Release the previously active button (overwrite=True makes
                # its check_select ignore this event).
                self.menuitems[self.status].set_hover_props(False)
                self.menuitems[self.status].pressed = False
                self.menuitems[self.status].check_select(event, overwrite=True)
                self.status = key
        else:
            if all([not item.pressed for item in self.menuitems.values()]):
                print(f"Changing Mode to Idle")
                self.set_idle()

    def add_items(self, items):
        """Create the items: a static orange MenuItem for "Mode", ModeButtons for the rest."""
        coords = self.get_coords(len(items))
        for i, (s, c) in enumerate(zip(items, coords)):
            if s == "Mode":
                on_select = lambda item: None  # print(f"You pressed {item.labelstr}.")
                # overwrite props and hoverprops: identical, so the item never
                # changes appearance on hover
                props = {
                    "labelcolor": "black",
                    "bgcolor": "orange",
                    "fontsize": 15,
                    "alpha": 1,
                }
                hoverprops = {
                    "labelcolor": "black",
                    "bgcolor": "orange",
                    "fontsize": 15,
                    "alpha": 1,
                }
                item = MenuItem(
                    self.fig,
                    (0, c[0]),
                    1,
                    c[1] - c[0],
                    labelstr=s,
                    on_select=on_select,
                    props=props,
                    hoverprops=hoverprops,
                )
                # A plain MenuItem has no `pressed` attribute; on_click reads it
                # for every item.
                item.pressed = False
            else:
                # on_select = lambda item: print(f"You pressed {item.labelstr}.")
                on_select = lambda item: True
                # item = ModeButton(self.fig, self.ax, s, on_select=on_select)
                item = ModeButton(
                    self.fig, (0, c[0]), 1, c[1] - c[0], labelstr=s, on_select=on_select
                )
            self.menuitems[s] = item
            self.ax.add_artist(item)