Coverage for encodermap/plot/interactive_plotting.py: 9%

683 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2025-05-15 21:06 +0200

1# -*- coding: utf-8 -*- 

2# encodermap/plot/interactive_plotting.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22 

23################################################################################ 

24# Imports 

25################################################################################ 

26 

27 

28# Future Imports at the top 

29from __future__ import annotations 

30 

31# Standard Library Imports 

32import getpass 

33import os 

34import platform 

35import re 

36import socket 

37import threading 

38import time 

39import uuid 

40from contextlib import contextmanager 

41from copy import deepcopy 

42from functools import cached_property 

43from pathlib import Path 

44from typing import TYPE_CHECKING, Any, Literal, Optional, Union 

45 

46# Third Party Imports 

47import ipywidgets as widgets 

48import numpy as np 

49from optional_imports import _optional_import 

50from pip._internal.operations.freeze import freeze 

51 

52# Encodermap imports 

53from encodermap.autoencoder.autoencoder import ( 

54 AngleDihedralCartesianEncoderMap, 

55 Autoencoder, 

56 DihedralEncoderMap, 

57 EncoderMap, 

58) 

59from encodermap.misc.misc import _datetime_windows_and_linux_compatible, _is_notebook 

60from encodermap.plot.plotting import ( 

61 _plot_free_energy, 

62 get_histogram, 

63 plot_trajs_by_parameter, 

64 to_density, 

65 to_free_energy, 

66) 

67from encodermap.trajinfo.info_all import TrajEnsemble 

68from encodermap.trajinfo.info_single import SingleTraj 

69 

70 

71################################################################################ 

72# Optional Imports 

73################################################################################ 

74 

75 

76sns = _optional_import("seaborn") 

77md = _optional_import("mdtraj") 

78jinja2 = _optional_import("jinja2") 

79make_subplots = _optional_import("plotly", "subplots.make_subplots") 

80px = _optional_import("plotly", "express") 

81go = _optional_import("plotly", "graph_objects") 

82Image = _optional_import("PIL", "Image") 

83Canvas = _optional_import("ipycanvas", "Canvas") 

84hold_canvas = _optional_import("ipycanvas", "hold_canvas") 

85nv = _optional_import("nglview") 

86psutil = _optional_import("psutil") 

87display = _optional_import("IPython", "display.display") 

88Image = _optional_import("PIL", "Image") 

89 

90 

91################################################################################ 

92# Typing 

93################################################################################ 

94 

95 

96if TYPE_CHECKING: 

97 # Third Party Imports 

98 from mdtraj import Topology, Trajectory 

99 

100 # Encodermap imports 

101 from encodermap.autoencoder.autoencoder import AutoencoderClass 

102 

103 

104################################################################################ 

105# Globals 

106################################################################################ 

107 

108 

109__all__: list[str] = ["InteractivePlotting"] 

110 

111 

112# fmt: off 

113BAD_MODEBAR_BUTTONS = [ 

114 "autoScale2d", "autoscale", "editInChartStudio", "editinchartstudio", 

115 "hoverCompareCartesian", "hovercompare", "lasso", "lasso2d", "orbitRotation", 

116 "orbitrotation", "pan", "pan2d", "pan3d", "reset", "resetCameraDefault3d", 

117 "resetCameraLastSave3d", "resetGeo", "resetSankeyGroup", "resetScale2d", 

118 "resetViewMapbox", "resetViews", "resetcameradefault", "resetcameralastsave", 

119 "resetsankeygroup", "resetscale", "resetview", "resetviews", "select", 

120 "select2d", "sendDataToCloud", "senddatatocloud", "tableRotation", 

121 "tablerotation", "toImage", "toggleHover", "toggleSpikelines", "togglehover", 

122 "togglespikelines", "toimage", "zoom", "zoom2d", "zoom3d", "zoomIn2d", 

123 "zoomInGeo", "zoomInMapbox", "zoomOut2d", "zoomOutGeo", "zoomOutMapbox", 

124 "zoomin", "zoomout", 

125] 

126# fmt: on 

127 

128 

# Markdown snippet (Jinja2 syntax) describing how to load a saved .h5 file.
# Expects the template variable `h5_file`.
H5_INFO = """\
## Loading a HDF5 file (.h5) with EncoderMap

EncoderMap introduces a way of storing multiple trajectories (a `TrajEnsemble`) in a
single file. These files can be loaded via:

```python
import encodermap as em
trajs = em.TrajEnsemble.from_dataset('{{ h5_file }}')
```
"""

140 

141 

# Markdown README template (Jinja2 syntax) written next to a generated path.
# The set of `{{ ... }}` placeholders is part of the contract with the code
# that renders this template and must not change.
PATH_TEMPLATE = """\
# README for EncoderMap.InteractivePlotting generate

You just used EncoderMap's `InteractivePlotting` and saved protein conformations generated from a path in a low-dimensional representation of a {{ ensemble_type}}. The conformations were generated using a trained neural network autoencoder (EncoderMap's {{ autoencoder_class }} class) from {{ n_points }} {{ lowd_dim }}-dimensional coordinates. The {{ ensemble_type }} contained {{ n_top }} distinct protein topologies. From these topologies, the {{ chosen_top }} was chosen to build this cluster. Find the topological information in the `.pdb` file in this directory. Look at EncoderMap's documentation at https://ag-peter.github.io/encodermap/ to learn more about Trajectory Ensembles.

### The complete Ensemble is also present

If you want to get more information about the clustering you carried out, you can refer to these files:

### lowd.csv

This `.csv` file contains info about the complete ensemble this cluster was selected from. The columns are as follows:

| traj_num | The number of the trajectory in the full dataset. This number is 0-based. If only one trajectory is loaded, its `trajectory number` might also be `None`. |
| --------- | ------------------------------------------------------------ |
| frame_num | The frame number. The trajectory number and frame number can be used to unmistakably identify frames in a trajectory ensemble. Frame numbers are also 0-based. |
| traj_file | Contains the trajectory data (file formats such as .xtc, .dcd, .h5). |
| top_file | Contains the topology of the file (i.e. atom types, masses, residues) (file formats such as .pdb, .gro, .h5). Some trajectory files (.h5) might also contain the topology. In that case `trajectory file` and `topology` file are identical. |
| time | The time of the frame. This can be used for time-based indexing of trajectories. EncoderMap offers the `SingleTraj.tsel[time]` accessor to distinguish it from frame-based indexing via `SingleTraj[frame]`. |
| x | The x coordinate of the low-dimensional projection. |
| y | The y-coordinate of the low-dimensional projection. |

### path.npy

This numpy array contains the (x, y)-coordinates of the low-dimensional path, that was used to generate the conformations.

### path.png

A nice render of the selected path.

### generated.pdb and generated.xtc

These files contain the topological (`.pdb`) and trajectory (`.xtc`) information to rebuild this path. Check out the function `encodermap.plot.interactive_path_visualization`, which can be used to display an animation of that path:

```python
import encodermap as em
import numpy as np
import pandas as pd

traj = em.load("{{ xtc_file }}", "{{ pdb_file }}")
lowd = pd.read_csv("{{ csv_file }}")
path = np.load("{{ npy_file }}")

em.plot.interactive_path_visualization(
    traj,
    lowd,
    path,
)
```



## Rendering this document

If you don't like to view plain markdown files with a text-viewer there are many viewers available, that are able to render markdown nicely. I am currently using ghostwriter:

https://ghostwriter.kde.org/

If you want to create a pdf from this document you can try a combination of pandoc, latex and groff.

### HTML

```bash
pandoc {{filename}}.md -o {{filename}}.html
```

### Latex

```bash
pandoc {{filename}}.md -o {{filename}}.pdf
```

### Groff

```bash
pandoc {{filename}}.md -t ms -o {{filename}}.pdf
```

## Debug Info

```
encodermap.__version__ = {{encodermap_version}}
system_user = {{system_user}}
platform = {{platform}}
platform_release = {{platform_release}}
platform_version = {{platform_version}}
architecture = {{architecture}}
hostname = {{hostname}}
ip_address = {{ip_address}}
mac_address = {{mac_address}}
processor = {{processor}}
ram = {{ram}}
pip freeze = {{pip_freeze}}
```

"""

238 

239 

# Markdown README template (Jinja2 syntax) written next to a saved cluster.
# The set of `{{ ... }}` placeholders is part of the contract with the code
# that renders this template and must not change.
CLUSTER_TEMPLATE = """\
# README for EncoderMap.InteractivePlotting cluster

You just used EncoderMap's `InteractivePlotting` and saved a cluster. Here's some information about this cluster. The cluster was selected from a `TrajEnsemble` containing {{ n_trajs }} trajectories, {{ n_frames }} frames and {{ n_top }} unique topologies. This cluster was assigned the number {{ cluster_num }}. The file {{ h5_file }} contains only {{ n_points }} frames, chosen as representatives for this cluster. This file can be loaded with EncoderMap's `TrajEnsemble.from_dataset('{{ h5_file }}')` method. Look at EncoderMap's documentation at https://ag-peter.github.io/encodermap/ to learn more about Trajectory Ensembles.

### The complete Ensemble is also present

If you want to get more information about the clustering you carried out, you can refer to these files:

### cluster_{{ cluster_num }}.csv

This `.csv` file contains info about the complete ensemble this cluster was selected from. The columns are as follows:

| traj_num | The number of the trajectory in the full dataset. This number is 0-based. If only one trajectory is loaded, its `trajectory number` might also be `None`. |
| ---------- | ------------------------------------------------------------ |
| frame_num | The frame number. The trajectory number and frame number can be used to unmistakably identify frames in a trajectory ensemble. Frame numbers are also 0-based. |
| traj_file | Contains the trajectory data (file formats such as .xtc, .dcd, .h5). |
| top_file | Contains the topology of the file (i.e. atom types, masses, residues) (file formats such as .pdb, .gro, .h5). Some trajectory files (.h5) might also contain the topology. In that case `trajectory file` and `topology` file are identical. |
| time | The time of the frame. This can be used for time-based indexing of trajectories. EncoderMap offers the `SingleTraj.tsel[time]` accessor to distinguish it from frame-based indexing via `SingleTraj[frame]`. |
| x | The x coordinate of the low-dimensional projection. |
| y | The y-coordinate of the low-dimensional projection. |
| cluster_id | This column contains -1, which are points not included in a cluster (outliers). Cluster 1 is denoted by a 0 in this column. If multiple clusters have been selected this column can contain multiple integer values. For every subsequent cluster, the `cluster_id` is advanced by 1. |

### cluster_{{ cluster_num }}_selector.npy

This numpy array contains the (x, y)-coordinates of the selector, that was used to highlight the cluster. Be careful, this shape might not be convex, so using convex algorithms to find points inside this Polygon might not work.

### cluster_{{ cluster_num }}.png

A nice render of the selected cluster.

{{ h5_info }}

## Rendering this document

If you don't like to view plain markdown files with a text-viewer there are many viewers available, that are able to render markdown nicely. I am currently using ghostwriter:

https://ghostwriter.kde.org/

If you want to create a pdf from this document you can try a combination of pandoc, latex and groff.

### HTML

```bash
pandoc {{filename}}.md -o {{filename}}.html
```

### Latex

```bash
pandoc {{filename}}.md -o {{filename}}.pdf
```

### Groff

```bash
pandoc {{filename}}.md -t ms -o {{filename}}.pdf
```

## Debug Info

```
encodermap.__version__ = {{encodermap_version}}
system_user = {{system_user}}
platform = {{platform}}
platform_release = {{platform_release}}
platform_version = {{platform_version}}
architecture = {{architecture}}
hostname = {{hostname}}
ip_address = {{ip_address}}
mac_address = {{mac_address}}
processor = {{processor}}
ram = {{ram}}
pip freeze = {{pip_freeze}}

```
"""

317 

318 

319################################################################################ 

320# Utils 

321################################################################################ 

322 

323 

@contextmanager
def set_env(**environ):
    """
    Temporarily set the process environment variables.

    The complete environment is snapshotted on entry and restored on exit, so
    variables that are added or changed inside the ``with`` body do not leak
    out of it. The snapshot is also restored when applying ``environ`` itself
    fails part-way (e.g. because one of the values is not a ``str``).

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...     "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, str]
    :param environ: Environment variables to set
    """
    old_environ = dict(os.environ)
    try:
        # Inside the try-block on purpose: os.environ.update() assigns the
        # keys one by one and raises on the first invalid value, which would
        # otherwise leave the already-assigned keys behind.
        os.environ.update(environ)
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)

346 

347 

348def _check_all_templates_defined(template, info_dict): 

349 # Standard Library Imports 

350 import re 

351 

352 regex = r"\{(.*?)\}" 

353 matches = re.finditer(regex, template, re.MULTILINE | re.DOTALL) 

354 min_matches = [] 

355 for matchNum, match in enumerate(matches): 

356 for groupNum in range(0, len(match.groups())): 

357 min_matches.append(match.group(groupNum)) 

358 min_matches = list( 

359 set(map(lambda x: x.lstrip("{{").rstrip("}}"), [i for i in min_matches])) 

360 ) 

361 if all(key in info_dict for key in min_matches): 

362 return True 

363 else: 

364 missing = set(min_matches).difference(info_dict) 

365 raise Exception( 

366 f"Not all expressions defined in template. Missing expressions: {missing}" 

367 ) 

368 

369 

def render_image(view: nv.NGLWidget, filename: str) -> None:
    """Writes a rendered image of a nv.NGLWidget to disk.

    Meant to be run inside a thread, because the function blocks (polling
    every 0.1 seconds) until the requested image holds data.

    Args:
        view (nv.NGLWidget): The widget to be rendered. It is cleared before
            the image is requested.
        filename (str): The file to render to.

    """
    view.clear()
    image = view.render_image()
    while True:
        # an empty value means that the image data is not there yet
        if image.value:
            break
        time.sleep(0.1)
    with open(filename, "wb") as out_file:
        out_file.write(image.value.tobytes())

384 

385 

def plotly_freeform_to_path(path, n_points=100):
    """Resamples a freeform path string to points evenly spaced along it.

    The vertices are parsed from `path`, the normalized cumulative distance
    between consecutive vertices is used as the curve parameter and the x-
    and y-coordinates are interpolated linearly at `n_points` equidistant
    parameter values between 0 and 1.

    Args:
        path (str): The path as a string of the form ``"Mx0,y0Lx1,y1L..."``.
        n_points (int): The number of points to sample. Defaults to 100.

    Returns:
        np.ndarray: An array of shape (n_points, 2) with the sampled
            (x, y)-coordinates.

    """
    # Third Party Imports
    from scipy.interpolate import interp1d

    vertices = np.array(
        [
            [float(coordinate) for coordinate in vertex.split(",")]
            for vertex in path.lstrip("M").split("L")
        ]
    )
    xs, ys = vertices[:, 0], vertices[:, 1]

    # length of every segment, with a leading 0 for the first vertex
    dx = np.diff(xs, prepend=xs[0])
    dy = np.diff(ys, prepend=ys[0])
    arc_length = np.cumsum(np.sqrt(dx**2 + dy**2))
    arc_length = arc_length / arc_length[-1]

    alpha = np.linspace(0, 1, n_points)
    resampled = [interp1d(arc_length, coords)(alpha) for coords in (xs, ys)]
    return np.vstack(resampled).T

404 

405 

406################################################################################ 

407# Interactive Plotting 

408################################################################################ 

409 

410 

class ProgressWidgetTqdmCompatible:
    """An ipywidgets `IntProgress` wrapper, that accepts tqdm-style calls.

    The progress bar is opened and closed via the contextmanager protocol:
    Entering places a fresh `IntProgress` into the container, leaving puts
    the `empty` placeholder back into its place.

    """

    def __init__(
        self,
        container: widgets.GridspecLayout,
        empty: widgets.Output,
        total: int,
        description: str,
    ) -> None:
        """Instantiate the progress bar.

        Args:
            container (widgets.GridspecLayout): The grid, that hosts the
                progress bar. The bar is assigned to ``container[6, 1:]``,
                i.e. row 7 (index 6) from column 2 (index 1) to the end.
            empty (widgets.Output): The object, that is assigned to the same
                grid position, when the progress bar closes.
            total (int): The initial total to count to.
            description (str): The description of the progress bar.

        """
        self.container = container
        self.empty = empty
        self.total = total
        self.description = description
        # bookkeeping per caller, keyed by the `function` keyword argument
        # of `update` and `reset`
        self._calls: dict[str, dict[str, int]] = {}
        # only the exact string "True" switches the printing on
        self.print = os.getenv("ENCODERMAP_PRINT_PROG_UPDATES", "False") == "True"

    def _place_progbar(self, value: int) -> None:
        """Creates an `IntProgress` counting to `self.total` and displays it.

        Args:
            value (int): The value the new progress bar starts with.

        """
        self.progbar = widgets.IntProgress(
            value=value,
            min=0,
            max=self.total,
            step=1,
            description=self.description,
            layout={"width": "90%"},
        )
        self.container[6, 1:] = self.progbar

    def __enter__(self):
        self._place_progbar(0)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.container[6, 1:] = self.empty
        if self.print:
            print(self._calls)
        del self.progbar

    def debug_print(self) -> None:
        """Prints the recorded totals and update calls of every caller."""
        print(f"WidgetProgbar {id(self)}")
        for function in self._calls:
            data = self._calls[function]
            print(
                f"{function:<15} total: {data['total']:>3} n: {data['update_calls']:>3}"
            )
        print("\n")

    def update(self, n: int = 1, **kwargs) -> None:
        """Advances the progress bar by n.

        Args:
            n (int): How far to advance. If `n` is not an int, the progress
                bar is advanced by 1. Defaults to 1.

        Keyword Args:
            function (Optional[str]): The name of the caller. If provided,
                the call is counted in the bookkeeping of this caller.

        """
        function = kwargs.pop("function", None)
        tracked = function is not None
        if tracked:
            self._calls.setdefault(function, {"update_calls": 0, "total": 0})
        self.progbar.value += n if isinstance(n, int) else 1
        if tracked:
            self._calls[function]["update_calls"] += 1

    def reset(self, total: int, **kwargs) -> None:
        """Resets the progress bar with a new total.

        The current value of the progress bar is carried over to the new one.

        Args:
            total (int): New total. It should be greater than old total.

        Keyword Args:
            function (Optional[str]): The name of the caller. If provided,
                the increase of the total is added to the bookkeeping of
                this caller.

        """
        assert total > self.total
        function = kwargs.pop("function", None)
        if function is not None:
            record = self._calls.setdefault(function, {"update_calls": 0, "total": 0})
            record["total"] += total - self.total
        self.total = total
        self._place_progbar(self.progbar.value)

519 

520 

521class InteractivePlotting: 

522 """EncoderMap's interactive plotting for jupyter notebooks. 

523 

524 Instantiating this class will display an interactive display in your notebook. 

525 The display will look like this:: 

526 

527 ┌─────────────────────┐ ┌───────────┐ 

528 │Display │ │Top │ 

529 └─────────────────────┘ └───────────┘ 

530 ┌─────────────┐ ┌───┐ ┌─────────────┐ 

531 │ │ │ │ │ │ 

532 │ │ │ T │ │ │ 

533 │ Main │ │ R │ │ Molecular │ 

534 │ Plotting │ │ A │ │ Conform. │ 

535 │ Area │ │ C │ │ Area │ 

536 │ │ │ E │ │ │ 

537 │ │ │ │ │ │ 

538 └─────────────┘ └───┘ └─────────────┘ 

539 ┌───┐ ┌─────────────────────────────┐ 

540 │ │ │Progress Bar │ 

541 └───┘ └─────────────────────────────┘ 

542 ┌─┐ ┌─┐ ┌─┐ ┌─┐ ┌───────────────────┐ 

543 │C│ │G│ │S│ │D│ │Slider │ 

544 └─┘ └─┘ └─┘ └─┘ └───────────────────┘ 

545 ┌────────────────┐ ┌───────────────┐ 

546 │ │ │ │ 

547 │ Data │ │ │ 

548 │ Overview │ │ │ 

549 │ │ │ │ 

550 │ │ │ │ 

551 └────────────────┘ └───────────────┘ 

552 

553 The components do the following: 

554 * Display: 

555 This part will display debug information. 

556 * Top (Top selector): 

557 Select which topology to use when creating new 

558 molecular conformations from the autoencoder network. 

559 * Main plotting area: 

560 In this area, a scatter plot will be displayed. The coordinates of 

561 the scatter plot will be taken from the low-dimensional projection 

562 of the trajectories. The data for this plotting area can be 

563 taken from different sources. See the `_lowd_parser` docstring 

564 for information on how the lowd data is selected. Clicking 

565 on a point in the scatter plot displays the conformation of that 

566 point. 

567 * TRACE: 

568 Displays the high-dimensional data of selected points or clusters. 

569 * Molecular conformation area: 

570 Displays molecular conformations. 

571 * Progress Bar: 

572 Displays progress. 

573 * C (Cluster button): 

574 After selecting points in the main plotting area 

575 with the lasso tool, hit this button to display the molecular 

576 conformations of the selected cluster. 

577 * G (Generate Button): 

578 Switch to density using the density button. 

579 Then, you can draw a freeform path into the Main plotting area. 

580 Pressing the generate button will generate the appropriate molecular 

581 conformations. If your data has multiple conformations, you can choose 

582 which conformation to use for decoding with the top selector. 

583 * S (Save button): 

584 Writes either a cluster or generated path to your disk. Uses the 

585 main_path of the autoencoder (the same directory as the training 

586 data will be stored). 

587 * D (Density button): 

588 Switch the main plotting area to Density. 

589 * Slider: 

590 In scatter mode this slider defines how many structures to select 

591 from a cluster for representation in the molecular conformations 

592 window. In density mode, this slider defines how many points along 

593 the user-drawn path should be sampled. 

594 

595 """ 

596 

597 _max_filepath_len: int = 50 

598 stride: int = 10 

599 _max_slider_len: int = 200 

600 _cluster_col: str = "_user_selected_points" 

601 _nbins: int = 50 

602 _cluster_method: Literal["stack", "join"] = "join" 

603 _help_url: str = "https://github.com/AG-Peter/encodermap" 

604 

@classmethod
def from_project(cls, project_name: Literal["linear_dimers"]):
    """Alternative constructor, that builds a session from a named project.

    The project is loaded with `encodermap.load_project` (called with
    `traj=-1` and `load_autoencoder=True`), and the resulting trajectories
    and autoencoder are forwarded to the regular constructor.

    Args:
        project_name (Literal["linear_dimers"]): The name of the project.

    Returns:
        InteractivePlotting: An instance of this class.

    """
    # Encodermap imports
    from encodermap import load_project

    loaded = load_project(project_name, traj=-1, load_autoencoder=True)
    project_trajs, project_autoencoder = loaded
    return cls(trajs=project_trajs, autoencoder=project_autoencoder)

616 

def __init__(
    self,
    autoencoder: Optional[AutoencoderClass] = None,
    trajs: Optional[Union[str, list[str], TrajEnsemble, SingleTraj]] = None,
    lowd_data: Optional[np.ndarray] = None,
    highd_data: Optional[np.ndarray] = None,
    align_string: str = "name CA",
    top: Optional[Union[str, list[str], Topology]] = None,
    ball_and_stick: bool = False,
    histogram_type: Union[None, Literal["free_energy", "density"]] = "free_energy",
    superpose: bool = True,
    ref_align_string: str = "name CA",
    base_traj: Optional[Trajectory] = None,
):
    """Instantiate the InteractivePlotting class.

    Note:
        It is recommended to assign an instance of this class to a variable
        to save variables from garbage collection::

            sess = em.InteractivePlotting()

    Args:
        autoencoder (Optional[AutoencoderClass]): An instance of any of
            EncoderMap's autoencoder classes (`Autoencoder`, `EncoderMap`,
            `DihedralEncoderMap`, `AngleDihedralCartesianEncoderMap`).
        trajs (Optional[Union[str, list[str], TrajEnsemble, SingleTraj]]): The
            trajectory data to use this session. Molecular conformations are
            selected from these trajectories. Can be one of EncoderMap's
            trajectory data containers (`SingleTraj`, `TrajEnsemble`). Can
            also be a str or a list of str, that point to trajectory files
            (.xtc, .dcd, .h5, .pdb, .gro). Can also be None. In this case
            the `autoencoder` argument is expected to be a
            `AngleDihedralCartesianEncoderMap`, that is expected to contain
            the trajs. Defaults to None.
        lowd_data (Optional[np.ndarray]): The low-dimensional data to use
            for this session. If not provided low-dimensional data will be
            inferred from either `trajs` or `autoencoder`. Defaults to None.
        highd_data (Optional[np.ndarray]): The high-dimensional data to use
            for this session. If not provided high-dimensional data will be
            inferred from either `trajs` or `autoencoder`. Defaults to None.
        align_string (str): The alignment string to superimpose the
            structures of selected clusters. See
            https://mdtraj.org/1.9.4/atom_selection.html for info on how
            this string affects the selected atoms. Defaults to 'name CA'.
        top (Optional[Union[str, list[str], Topology]]): If trajs is a str,
            and a trajectory file format that does not have topological
            information (.xtc, .dcd), this argument will be used for topology.
            Can be a str (file) or an instance of MDTraj's Topology. Can also
            be a list of str, that matches the list of str in `trajs` with
            the appropriate topology files. If None is provided, the trajs
            argument is expected to be either `SingleTraj` or `TrajEnsemble`.
            Defaults to None.
        ball_and_stick (bool): Whether to represent the structures in ball and
            stick representation (True) or in cartoon representation (False).
            Defaults to False and cartoon representation.
        histogram_type (Union[None, Literal["free_energy", "density"]]): Decide
            how to style your histogram. If None, a straight histogram (count
            per bin) will be plotted. If 'density' a density will be plotted.
            If 'free_energy', the negative natural logarithm of the density
            will be plotted. Defaults to 'free_energy'.
        superpose (bool): Whether to superpose the clustered structures.
            Defaults to True.
        ref_align_string (str): If a `base_traj` is provided, this string will
            be used to select the atoms to align the clustering results against.
            Defaults to 'name CA'.
        base_traj (Optional[Trajectory]): If not None, this traj will be
            used to align the clustered frames against. Can be used to make
            all clusterings be consistent in their placement in the 3d space.
            Defaults to None.

    """
    self.total = 0
    self.cluster_output = None
    self.path_output = None
    self.align_string = align_string
    self.top = top
    self.ball_and_stick = ball_and_stick
    self.histogram_type = histogram_type
    self.superpose = superpose
    self.ref_align_string = ref_align_string
    self.base_traj = base_traj
    # os.getlogin() needs a controlling terminal and raises an OSError
    # without one (which is the usual situation for a notebook kernel).
    # getpass.getuser() consults the environment and the password database
    # instead and thus also works in these sessions.
    self._username = getpass.getuser()

    # set the layout
    self.layout = go.Layout(
        {
            "modebar_add": ["drawline", "drawopenpath", "eraseshape"],
            "autosize": True,
            "margin": {
                "l": 0,
                "r": 0,
                "t": 0,
                "b": 0,
            },
            "shapedefaults": {"editable": False},
        }
    )

    # apply nest_asyncio for saving images
    if _is_notebook():
        # Third Party Imports
        import nest_asyncio

        nest_asyncio.apply()

    # parse the complex arrangement of args
    self.autoencoder = autoencoder
    # the current working directory is the fallback, when there is no
    # autoencoder, or the autoencoder is read-only
    self.main_path = Path(".").resolve()
    if self.autoencoder is not None and not self.autoencoder.read_only:
        self.main_path = Path(self.autoencoder.p.main_path)
    self.trajs = self._trajs_parser(autoencoder, trajs, top)
    self.highd = self._highd_parser(autoencoder, highd_data, self.trajs)
    self.lowd = self._lowd_parser(autoencoder, lowd_data, self.trajs)

    # one entry per frame: the file the frame stems from and the second
    # column of the `id` array of its trajectory
    self.file_arr = []
    self.frame_arr = []
    for t in self.trajs:
        self.file_arr.extend([t.traj_file for _ in range(t.n_frames)])
        self.frame_arr.append(t.id[:, 1])
    self.file_arr = np.array(self.file_arr)
    self.frame_arr = np.hstack(self.frame_arr)

    # put the data into self.trajs
    if self.trajs is not None:
        if self.highd is not None:
            if "highd" not in self.trajs._CVs:
                self.trajs.load_CVs(self.highd, "highd")
        if "lowd" not in self.trajs._CVs:
            self.trajs.load_CVs(self.lowd, "lowd")

    # debugging stuff
    self._debug_main_path = str(self.main_path)
    if len(self._debug_main_path) > self._max_filepath_len:
        # Abbreviate long paths: All directories but the last two path
        # components are reduced to their first character.
        self._debug_main_path = (
            "/"
            + "/".join([i[0] for i in self.main_path.parts[1:-2]])
            + "/"
            + "/".join(self.main_path.parts[-2:])
        )

    # set up base images
    self._setup_histogram()
    self._setup_graph()

759 

def _fake_progress(self) -> None:  # pragma: no cover
    """Displays a test progress bar, that advances 10 times every 0.3 seconds."""
    fake_bar = ProgressWidgetTqdmCompatible(
        container=self.container,
        empty=self.progbar_empty,
        total=10,
        description="Testing...",
    )
    with fake_bar as self.progbar:
        for _ in range(10):
            time.sleep(0.3)
            self.progbar.update()

770 

def _setup_histogram(self):
    """Computes the histogram of the low-dimensional data and derived arrays.

    The first two columns of `self.lowd` are histogrammed with `self._nbins`
    bins. The five return values of `get_histogram` are stored as
    `self.xcenters`, `self.ycenters`, `self.xedges`, `self.yedges` and
    `self.H`. Additionally, `self.D` (from `to_density`) and `self.F`
    (from `to_free_energy`, converted to str) are set.

    """
    lowd_x = self.lowd[:, 0]
    lowd_y = self.lowd[:, 1]
    histogram = get_histogram(
        lowd_x,
        lowd_y,
        bins=self._nbins,
        transpose=True,
        return_edges=True,
    )
    self.xcenters, self.ycenters, self.xedges, self.yedges, self.H = histogram
    self.D = to_density(self.H)
    free_energy = to_free_energy(self.D)
    # the free energy is kept as an array of strings
    self.F = free_energy.astype(str)

787 

def _trajs_parser(
    self,
    autoencoder: AutoencoderClass,
    trajs: Optional[Union[str, TrajEnsemble]] = None,
    top: Optional[Union[str, Topology]] = None,
) -> TrajEnsemble:
    """Parses the input trajs and chooses what trajs to use.

    The order of priority follows:
        1. The input `trajs` parameter supersedes everything.
            1.1. If an `AngleDihedralCartesianEncoderMap` has been provided,
                the CVs of the trajs are checked, whether they conform to
                the shape of the autoencoder's input data.
            1.2. If trajs is a str, rather than a `TrajEnsemble`, the argument
                `top` is used to build a `TrajEnsemble` from this topology
                and the `trajs`. Thus, `top` can be either str or md.Topology.
                A list or tuple of str is handled alike.
        2. If trajs is None, the `top` argument is not used and the
            `TrajEnsemble` of the provided `AngleDihedralCartesianEncoderMap`
            is used.

    Args:
        autoencoder (AutoencoderClass): The autoencoder.
        trajs (Optional[Union[str, TrajEnsemble]]): The trajs.
        top (Optional[Union[str, Topology]]): The topology.

    Returns:
        TrajEnsemble: The trajectory ensemble to use in this session.

    Raises:
        FileNotFoundError: If `trajs` (or `top`, when `trajs` is a str) is a
            str, that does not point to an existing file.
        AssertionError: If the CV shapes of `trajs` do not match the input
            data of the `AngleDihedralCartesianEncoderMap`, or if no trajs
            can be obtained, because `trajs` is None and the autoencoder is
            not a `AngleDihedralCartesianEncoderMap`.

    """
    # Standard Library Imports
    import errno

    def _assert_is_file(file: str) -> None:
        """Raises a FileNotFoundError, if `file` is not an existing file."""
        if not Path(file).is_file():
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)

    # build a TrajEnsemble from filenames
    if isinstance(trajs, str):
        _assert_is_file(trajs)
        if isinstance(top, str):
            _assert_is_file(top)
        trajs = TrajEnsemble([trajs], [top])
    elif isinstance(trajs, (list, tuple)):
        trajs = TrajEnsemble(trajs, top)

    if isinstance(autoencoder, AngleDihedralCartesianEncoderMap):
        if trajs is None:
            return autoencoder.trajs
        # user-provided trajs need to fit the data the autoencoder uses
        for key, d in autoencoder.inp_CV_data.items():
            assert d.shape[1:] == trajs.CVs[key].shape[1:], (
                f"The shape of the CV `{key}` of the provided `trajs` "
                f"{trajs.CVs[key].shape[1:]} does not match the shape of "
                f"the train data of the provided `autoencoder` {d.shape[1:]}."
            )

    # only an `AngleDihedralCartesianEncoderMap` can stand in for missing trajs
    assert (
        trajs is not None
    ), "Please provide a `TrajEnsemble` for the argument `trajs`."
    if isinstance(trajs, SingleTraj):
        return trajs._gen_ensemble()
    return trajs

854 

def _highd_parser(
    self,
    autoencoder: AutoencoderClass,
    highd: Optional[np.ndarray] = None,
    trajs: Optional[TrajEnsemble] = None,
) -> np.ndarray:
    """Selects which source of high-dimensional data to use.

    The order of priority follows:
        1. The provided `highd` np.ndarray.
            1.1 For `Autoencoder`, `EncoderMap` and `DihedralEncoderMap`,
                the last dimension of `highd` is compared with the last
                dimension of the autoencoder's `train_data`, unless the
                autoencoder's `_using_hypercube` is truthy.
            1.2 For `AngleDihedralCartesianEncoderMap`, providing `highd`
                is currently not implemented and raises.
        2. If no high-dimensional data has been provided (`highd=None`), the
            high-dimensional data from the provided `trajs` will be used.
        3. As a last resort, the autoencoder's `train_data` attribute will
            be used if the other datasources are not provided.

    Args:
        autoencoder (AutoencoderClass): The autoencoder. Can be None, in
            which case the data has to come from `trajs` or `highd`.
        highd (Optional[np.ndarray]): The high dimensional data.
        trajs (Optional[TrajEnsemble]): The trajs.

    Returns:
        np.ndarray: The high-dimensional data to use in this session.

    Raises:
        AssertionError: If `highd` is not a np.ndarray, if its last
            dimension does not match the train data, or if no data source
            is available when `autoencoder` is None.
        NotImplementedError: If the first element of a non-array
            `train_data` is not a tuple of length 2.
        TypeError: If the type of `autoencoder` is not supported.
        Exception: If `highd` is provided together with an
            `AngleDihedralCartesianEncoderMap`.

    """
    # the class name is checked in addition to isinstance
    class_name = autoencoder.__class__.__name__

    if (
        isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
        or class_name == "AngleDihedralCartesianEncoderMap"
    ):
        if highd is not None:
            assert isinstance(highd, np.ndarray), (
                f"The argument `highd_data` only supports None or np.ndarray. You "
                f"supplied {type(highd)}."
            )
            raise Exception(
                f"Confirming the shape of input highd and the input shape "
                f"of the autoencoder model is currently not implemented."
            )
        if trajs is None:
            return autoencoder.train_data
        if "central_dihedrals" not in trajs._CVs:
            print(
                f"The provided `trajs`, don't have any CVs loaded. I will "
                f"try to use the input data of the provided autoencoder."
            )
        sparse, highd_data, _ = autoencoder.get_train_data_from_trajs(
            trajs,
            autoencoder.p,
        )
        if sparse:
            # rebuild the input from the CVs of the trajs in the order
            # central_angles, central_dihedrals, side_dihedrals
            highd_data = [trajs.central_dihedrals]
            if autoencoder.p.use_backbone_angles:
                highd_data.insert(0, trajs.central_angles)
            if autoencoder.p.use_sidechains:
                highd_data.append(trajs.side_dihedrals)
            return np.hstack(highd_data)
        return highd_data

    if isinstance(
        autoencoder, (Autoencoder, EncoderMap, DihedralEncoderMap)
    ) or class_name in [
        "Autoencoder",
        "EncoderMap",
        "DihedralEncoderMap",
    ]:
        if highd is None:
            if trajs is not None and "highd" in trajs.CVs:
                return trajs.highd
            return autoencoder.train_data

        assert isinstance(highd, np.ndarray), (
            f"The argument `highd_data` only supports None or np.ndarray. You "
            f"supplied {type(highd)}."
        )
        if isinstance(autoencoder.train_data, np.ndarray):
            shape = autoencoder.train_data.shape
        else:
            # only the first element of an iterable `train_data` is inspected;
            # an empty iterable yields None and ends in the error below
            d = next(iter(autoencoder.train_data), None)
            if isinstance(d, tuple) and len(d) == 2:
                d = d[0][-1]
            else:
                raise NotImplementedError(
                    f"Bad `train_data` shape: {autoencoder.train_data=} "
                    f"{d=}"
                )
            shape = d.shape
        # report the `shape` derived above, because a non-array `train_data`
        # may not have a `.shape` attribute to put into the message
        assert highd.shape[-1] == shape[-1] or autoencoder._using_hypercube, (
            f"The provided np.array in argument `highd_data` has shape {highd.shape}, "
            f"but the autoencoder's `train_data` has shape {shape}."
        )
        return highd

    if autoencoder is None:
        # without trajs, fall through to the explanatory assertion
        if trajs is not None and "highd" in trajs.CVs:
            return trajs.highd
        assert highd is not None, (
            f"Please provide a numpy array containing high-dimensional data "
            f"or load high-dimensional data into your trajs with `trajs.load_CVs`."
        )
        return highd

    raise TypeError(f"Unknown type for autoencoder: {type(autoencoder)}.")

963 

def _lowd_parser(
    self,
    autoencoder: Optional[AutoencoderClass],
    lowd: Optional[np.ndarray],
    trajs: Optional[TrajEnsemble],
) -> np.ndarray:
    """Selects which source of low-dimensional data to use.

    If an autoencoder is provided, the order of priority follows:
        1. The provided `lowd` np.ndarray.
        2. The `lowd` CV of the provided `trajs`, if it is loaded.
        3. The output of `autoencoder.encode`. An
            `AngleDihedralCartesianEncoderMap` is called with `trajs`,
            the other autoencoder classes are called with `self.highd`.

    If `autoencoder` is None, the `lowd` CV of the `trajs` is preferred
    over the `lowd` argument.

    Args:
        autoencoder (Optional[AutoencoderClass]): The autoencoder.
        lowd (Optional[np.ndarray]): The low-dimensional data.
        trajs (Optional[TrajEnsemble]): The trajs.

    Returns:
        np.ndarray: The low-dimensional data to use in this session.

    Raises:
        AssertionError: If `lowd` is provided, but is not a np.ndarray, or
            if no low-dimensional data is available when `autoencoder`
            is None.
        TypeError: If the type of `autoencoder` is not supported.

    """
    # the class name is checked in addition to isinstance
    class_name = autoencoder.__class__.__name__
    is_adc_encodermap = (
        isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
        or class_name == "AngleDihedralCartesianEncoderMap"
    )
    is_array_autoencoder = isinstance(
        autoencoder, (Autoencoder, EncoderMap, DihedralEncoderMap)
    ) or class_name in [
        "Autoencoder",
        "EncoderMap",
        "DihedralEncoderMap",
    ]

    if is_adc_encodermap or is_array_autoencoder:
        if lowd is not None:
            assert isinstance(lowd, np.ndarray), (
                f"The argument `lowd_data` only supports None or np.ndarray. You "
                f"supplied {type(lowd)}."
            )
            return lowd
        if trajs is not None and "lowd" in trajs.CVs:
            return trajs.lowd
        if is_adc_encodermap:
            return autoencoder.encode(trajs)
        return autoencoder.encode(self.highd)

    if autoencoder is None:
        # without trajs, fall through to the explanatory assertion
        if trajs is not None and "lowd" in trajs.CVs:
            return trajs.lowd
        assert lowd is not None, (
            f"Please provide a numpy array containing low-dimensional data "
            f"or load low-dimensional data into your trajs with `trajs.load_CVs`."
        )
        return lowd

    raise TypeError(f"Unknown type for autoencoder: {type(autoencoder)}.")

1008 

@cached_property
def density(self) -> Any:
    """go.Contour: The contour plot of the binned low-dimensional data.

    The z-values depend on `self.histogram_type`: None uses `self.H`,
    'density' uses `self.D` and 'free_energy' uses `self.F`.

    Raises:
        TypeError: If `self.histogram_type` is none of the above.

    """
    kind = self.histogram_type
    if kind is None:
        z_values = self.H
    elif kind == "density":
        z_values = self.D
    elif kind == "free_energy":
        z_values = self.F
    else:
        raise TypeError(
            f"Argument `histogram_type` needs to be either of None, "
            f"'density' or 'free_energy'. You supplied {self.histogram_type}."
        )

    contour_kwargs = {
        "x": self.xcenters,
        "y": self.ycenters,
        "z": z_values,
        "name": "",
        "showlegend": False,
        "showscale": False,
        "visible": True,
        "colorscale": "Viridis",
    }
    return go.Contour(**contour_kwargs)

1033 

@cached_property
def scatter(self) -> Any:
    """go.Scattergl: The scatter plot using the low-dimensional data.

    Every point is colored by the free energy (`self.F`) of the histogram
    bin it falls into, rescaled linearly to the range [0, 1]. The colors are
    also stored in `self.marker_colors`. Points outside all bins and points
    in bins without a finite value keep the default color 1.0.

    If the instance has a `trajs` attribute, the basename, traj number and
    frame number of every point are attached as hover information.

    """
    # Third Party Imports
    from scipy.interpolate import interp1d

    if hasattr(self, "trajs"):
        basenames = np.array(
            [traj.basename for traj in self.trajs for _ in range(traj.n_frames)]
        )
        traj_nums, frame_nums = self.trajs.id.T
        customdata = np.stack(
            (
                basenames,
                traj_nums,
                frame_nums,
            ),
            axis=-1,
        )
        hovertemplate = (
            "%{customdata[0]} (Traj %{customdata[1]}, "
            "Frame %{customdata[2]}): (%{x:.2f}, %{y:.2f})"
        )
    else:
        customdata = None
        hovertemplate = None

    # map the values to the same range
    # after the transpose, `values` is indexed as [x_bin, y_bin]
    values = self.F.copy().astype("float32").T
    values_ma = np.ma.masked_invalid(values)
    interp = interp1d(
        (np.min(values_ma), np.max(values_ma)), (0.0, 1.0), bounds_error=False
    )
    # non-finite entries lie outside the bounds and are mapped to NaN
    values = interp(values)

    # fill an array with the default color
    self.marker_colors = np.full((len(self.lowd),), fill_value=1.0)

    # look up the bin of every point at once instead of looping over all bins.
    # `side="right"` reproduces the half-open bins `edge[i] <= value < edge[i+1]`,
    # so points on the outermost upper edge are not part of any bin.
    x_bins = np.searchsorted(self.xedges, self.lowd[:, 0], side="right") - 1
    y_bins = np.searchsorted(self.yedges, self.lowd[:, 1], side="right") - 1
    n_xbins = len(self.xedges) - 1
    n_ybins = len(self.yedges) - 1
    in_histogram = (
        (x_bins >= 0) & (x_bins < n_xbins) & (y_bins >= 0) & (y_bins < n_ybins)
    )
    point_ids = np.flatnonzero(in_histogram)
    bin_values = values[x_bins[point_ids], y_bins[point_ids]]

    # set the marker colors into the marker_colors array, skipping NaN bins
    has_value = ~np.isnan(bin_values)
    self.marker_colors[point_ids[has_value]] = bin_values[has_value]

    return go.Scattergl(
        mode="markers",
        x=self.lowd[:, 0],
        y=self.lowd[:, 1],
        visible=True,
        marker={
            "color": self.marker_colors,
            "colorscale": "Viridis",
            "size": 1,
            "line": {
                "width": 0,
            },
        },
        name="",
        customdata=customdata,
        hovertemplate=hovertemplate,
    )

1104 

def generate(self, b):
    """Generates conformations along the path drawn on the density canvas.

    The canvas path is converted to data coordinates, decoded by the
    autoencoder and shown in the NGL area. A play widget and a slider are
    added to step through the generated frames, and the density canvas is
    replaced by the animated path plot.

    Args:
        b: The argument passed by the widget callback. Not used.

    """
    # clear the display
    self.display.outputs = []

    self.progbar_description = "Backmapping: "

    # without a drawn path, there is nothing to generate
    if len(self.canvas_path) == 0:
        with self.display:
            print(f"First Draw a line onto the Density map and then hit 'Generate'.")
        return

    # clear the pandas area
    self.pandas_info_area.outputs = []

    # display a message
    n_points = self.slider.value
    with self.display:
        print(f"Generating {n_points} points. Please stand by.")

    # set up progbar
    with ProgressWidgetTqdmCompatible(
        container=self.container,
        empty=self.progbar_empty,
        total=0,
        description=self.progbar_description,
    ) as self.progbar:
        # get the path
        self.path = self._canvas_path_in_data_coords()

        # generate
        if (
            isinstance(self.autoencoder, AngleDihedralCartesianEncoderMap)
            or self.autoencoder.__class__.__name__
            == "AngleDihedralCartesianEncoderMap"
        ):
            self.path_output = self.autoencoder.generate(
                self.path, top=self.top_selector.value, progbar=self.progbar
            )
        else:
            # Encodermap imports
            from encodermap.misc.backmapping import mdtraj_backmapping

            dihedrals = self.autoencoder.generate(self.path).numpy()
            self.path_output = mdtraj_backmapping(
                top=self.trajs[0].top_file,
                dihedrals=dihedrals,
                progbar=self.progbar,
                omega=False,
            )

    self.display.outputs = []
    with self.display:
        print(f"Conformations generated.")

    # clear progbar
    self.progbar_description = ""

    # create the media widget
    # the frames are 0-indexed, so the largest valid value is len(path) - 1.
    # Larger values would raise an IndexError in `advance_path`.
    last_frame = max(len(self.path) - 1, 0)
    self.media_widget = widgets.Play(
        value=0,
        min=0,
        max=last_frame,
        step=1,
        disabled=False,
    )
    self.media_slider = widgets.IntSlider(value=0, min=0, max=last_frame)
    widgets.jslink((self.media_widget, "value"), (self.media_slider, "value"))
    self.container[7, 4:] = widgets.HBox(
        [self.media_widget, self.media_slider], layout={"align-content": "center"}
    )

    # create the view
    view = nv.show_mdtraj(self.path_output)
    self.ngl_area.children = [view]
    if self.ball_and_stick:
        view.clear_representations()
        view.add_representation("ball+stick")
    self.view = view

    # switch to plotly
    self.path_anim_widget.data[1].x = self.path[:, 0]
    self.path_anim_widget.data[1].y = self.path[:, 1]
    # the marker starts on the first point of the path
    self.path_anim_widget.data[2].x = [self.path[0, 0]]
    self.path_anim_widget.data[2].y = [self.path[0, 1]]
    self.container[2:6, :3] = widgets.Box(
        [self.path_anim_widget],
        layout=widgets.Layout(
            height="auto",
            width="auto",
        ),
    )

    # make the slider responsive
    self.media_slider.observe(self.advance_path, names="value")

1202 

def advance_path(self, n):
    """Moves the NGL view and the path marker to a new frame.

    Args:
        n: The change passed by the observed widget. Its "new" item is
            used as the frame index.

    """
    frame_index = n["new"]
    self.view.frame = frame_index
    marker = self.path_anim_widget.data[2]
    marker.x = [self.path[frame_index, 0]]
    marker.y = [self.path[frame_index, 1]]

1208 

def cluster(self, b):
    """Clusters the currently selected points and displays the result.

    The selected points get a new cluster id in the CV `self._cluster_col`
    of `self.trajs`. The resulting cluster is joined or stacked, depending on
    `self._cluster_method`, and shown in the NGL area. A summary is printed
    into the pandas area and an image of the view is rendered in a thread.

    Args:
        b: The argument passed by the widget callback. Not used.

    """
    # clear the display
    self.display.outputs = []
    self.progbar_description = "Clustering: "

    # without selected points, there is nothing to cluster
    if self.selected_point_ids.size == 0:
        with self.display:
            print(
                f"First select some points using the Lasso or Polygon tool "
                f"and then click 'cluster'."
            )
        return

    # clear the pandas area
    self.pandas_info_area.outputs = []

    # instantiate the progbar
    with ProgressWidgetTqdmCompatible(
        container=self.container,
        empty=self.progbar_empty,
        total=0,
        description=self.progbar_description,
    ) as self.progbar:
        # read the slider
        n_points = self.slider.value

        # display a message
        with self.display:
            print(f"Clustering {n_points} points. Please stand by.")

        # the first cluster gets the id 0, every further cluster the
        # currently largest id plus one. Unassigned frames have the id -1.
        if self._cluster_col in self.trajs.CVs:
            cluster_ids = self.trajs.CVs[self._cluster_col]
            cluster_ids[self.selected_point_ids] = cluster_ids.max() + 1
        else:
            cluster_ids = np.full(self.trajs.n_frames, -1)
            try:
                cluster_ids[self.selected_point_ids] = 0
            except IndexError as e:
                raise SystemExit(f"{self.selected_point_ids=}") from e
        self.trajs.load_CVs(cluster_ids, self._cluster_col, override=True)
        self.selected_point_ids = np.array([]).astype(int)

        self.cluster_output = self.trajs.cluster(
            cluster_id=max(cluster_ids),
            col=self._cluster_col,
            n_points=n_points,
        )

        alignment = {
            "align_string": self.align_string,
            "superpose": self.superpose,
            "ref_align_string": self.ref_align_string,
            "base_traj": self.base_traj,
            "progbar": self.progbar,
        }
        if self._cluster_method == "join":
            self._cluster = self.cluster_output.join(**alignment)

            # nglview: every frame becomes its own component
            n_components = 0
            for i, joined in enumerate(self._cluster.values()):
                joined.center_coordinates()
                for j, frame in enumerate(joined):
                    if i == j == 0:
                        view = nv.show_mdtraj(frame, gui=False)
                    else:
                        view.add_trajectory(frame)
                    n_components += 1
            self.ngl_area.children = [view]
            if self.ball_and_stick:
                view.clear_representations()
                for component in range(n_components):
                    view.add_representation("ball+stick", component=component)
        else:
            self._cluster = self.cluster_output.stack(**alignment)

            # nglview
            self._cluster.center_coordinates()
            view = nv.show_mdtraj(self._cluster, gui=False)
            self.ngl_area.children = [view]
            if self.ball_and_stick:
                view.clear_representations()
                view.add_representation("ball+stick")

    # trace
    if self.highd is not None:
        in_cluster = self.trajs.CVs[self._cluster_col] == max(cluster_ids)
        self.trace_widget.data[0].z = self.highd[in_cluster].T

    # save the image, because threading is complicated in IPython
    filename = Path("/tmp/tmp.png")
    lock = threading.Lock()
    with lock:
        thread = threading.Thread(
            target=render_image,
            args=(view, filename),
            daemon=True,
        )
        thread.start()

    # clear progbar
    self.progbar_description = ""

    # clear display
    self.display.outputs = []
    with self.display:
        print(f"Finished clustering.")

    self.ngl_area.children = [view]
    self.view = view

    # pandas
    with self.pandas_info_area:
        display(self.cluster_output.dash_summary())

1332 

def save(self, b):
    """Saves the generated path or the current cluster to disk.

    A generated path (`self.path_output`) takes precedence over a cluster
    (`self.cluster_output`). Saving a path requires `imageio` and `moviepy`
    to be importable.

    Args:
        b: The argument passed by the widget callback. Not used.

    Returns:
        The return value of `_save_path_on_disk` or `_save_cluster_on_disk`.
            None, if nothing was saved.

    """
    if self.cluster_output is None and self.path_output is None:
        self.display.outputs = []
        with self.display:
            print(
                "Please select a cluster or a path and hit 'Generate' or "
                "'Cluster', before 'Save'."
            )
        return None

    if self.path_output is not None:  # path save
        try:
            # Third Party Imports
            import imageio
            import moviepy
        except (ModuleNotFoundError, NameError):
            self.display.outputs = []
            with self.display:
                print("Please install moviepy, imageio and ffmpeg")
            return None
        description, write_to_disk, kind = (
            "Saving..",
            self._save_path_on_disk,
            "Path",
        )
    else:  # cluster save
        description, write_to_disk, kind = (
            "Saving...",
            self._save_cluster_on_disk,
            "Cluster",
        )

    with ProgressWidgetTqdmCompatible(
        container=self.container,
        empty=self.progbar_empty,
        total=5,
        description=description,
    ) as self.progbar:
        fname = write_to_disk()

    self.display.outputs = []
    with self.display:
        print(f"{kind} saved at {fname}")
    return fname

1375 

def _save_path_on_disk(self) -> Path:
    """Saves the generated path in self.path_output to disk.

    A new directory `generated_paths/<datetime>` is created in
    `self.main_path`. It contains the generated conformations (pdb and
    xtc), the path in data coordinates (npy), the low-dimensional data of
    `self.trajs` (csv), an overview image (png) and a README.md rendered
    from `PATH_TEMPLATE`. `self.progbar` is updated after every step.

    An animation of the path is currently not written.

    Returns:
        Path: The directory the files were written to.

    """
    now = _datetime_windows_and_linux_compatible()
    output = self.main_path / f"generated_paths/{now}"
    output.mkdir(parents=True, exist_ok=True)

    # define some files
    # fmt: off
    xtc_file = output / f"generated.xtc"  # the generated conformations
    pdb_file = output / f"generated.pdb"  # the generated conformations as pdb
    npy_file = output / f"path.npy"       # the path in data coordinates
    png_file = output / f"path.png"       # an overview image
    csv_file = output / "lowd.csv"        # A csv file for later plotting the lowd
    md_file = output / "README.md"        # A readme filled by jinja
    # fmt: on

    # save the generated conformations
    self.path_output.save_pdb(str(pdb_file))
    self.path_output.save_xtc(str(xtc_file))
    self.progbar.update()

    # save the path
    np.save(npy_file, self.path)
    self.progbar.update()

    # the animation step is not implemented, but keeps its progbar tick
    self.progbar.update()

    # save the lowd as csv
    df = self.trajs.to_dataframe(CV=["lowd"])
    df["x"] = df.pop("LOWD FEATURE 0")
    df["y"] = df.pop("LOWD FEATURE 1")
    df.to_csv(csv_file)
    self.progbar.update()

    # save a png similar to cluster
    fig = make_subplots(rows=1, cols=2)
    fig.add_trace(
        _plot_free_energy(
            x=self.trajs.lowd[:, 0],
            y=self.trajs.lowd[:, 1],
            cbar=True,
            colorbar_x=0.45,
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        plot_trajs_by_parameter(
            self.trajs,
            "traj_num",
            type="scatter",
            show=False,
        ).data[0],
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=self.path[:, 0],
            y=self.path[:, 1],
            mode="lines",
            name="",
        )
    )
    fig.data[1].marker.colorscale = px.colors.get_colorscale("Viridis")

    fig.update_layout(
        {
            "width": 800,
            "height": 300,
            "xaxis1": {"title": "x in a.u."},
            "xaxis2": {"title": "x in a.u."},
            "yaxis1": {"title": "y in a.u."},
            "yaxis2": {"title": "y in a.u."},
            "autosize": True,
            "margin": {
                "l": 0,
                "r": 0,
                "t": 0,
                "b": 0,
            },
        },
    )
    fig.write_image(png_file, engine="kaleido", width=1500, height=500, scale=2)
    self.progbar.update()

    # save a README
    # Local Folder Imports
    from .._version import get_versions

    _ensemble_type = "single traj"
    if self.trajs.__class__.__name__ == "TrajEnsemble":
        if self.trajs.n_trajs > 1:
            _ensemble_type = "trajectory ensemble"

    info_dict = {
        "platform": platform.system(),
        "system_user": getpass.getuser(),
        "platform_release": platform.release(),
        "platform_version": platform.version(),
        "architecture": platform.machine(),
        "hostname": socket.gethostname(),
        "ip_address": socket.gethostbyname(socket.gethostname()),
        "mac_address": ":".join(re.findall("..", "%012x" % uuid.getnode())),
        "processor": platform.processor(),
        "ram": str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB",
        "pip_freeze": "\n".join(list(freeze())),
        "n_frames": self.trajs.n_frames,
        "n_trajs": self.trajs.n_trajs,
        "n_top": len(self.trajs.top),
        "encodermap_version": get_versions()["version"],
        "filename": str(md_file.resolve()),
        "n_points": len(self.path),
        "ensemble_type": _ensemble_type,
        "csv_file": str(csv_file.resolve()),
        "pdb_file": str(pdb_file.resolve()),
        "xtc_file": str(xtc_file.resolve()),
        "npy_file": str(npy_file.resolve()),
        "autoencoder_class": self.autoencoder.__class__.__name__,
        "chosen_top": self.top_selector.options[self.top_selector.value],
        "lowd_dim": self.lowd.shape[1],
    }
    template = jinja2.Template(PATH_TEMPLATE)
    readme_text = template.render(info_dict)
    md_file.write_text(readme_text)
    self.progbar.update()

    return output

1519 

def _save_cluster_on_disk(self) -> Path:
    """Saves the cluster in self.cluster_output to disk.

    A new directory `clusters/<datetime>` is created in `self.main_path`.
    It contains the cluster (h5), a table of `self.trajs` with the cluster
    ids (csv), the vertices of the selector (npy), an overview image (png)
    and a README.md rendered from `CLUSTER_TEMPLATE`. Afterwards, the
    attributes `selector` and `cluster_output` are deleted from the instance.

    Returns:
        Path: The directory the files were written to.

    """
    timestamp = _datetime_windows_and_linux_compatible()
    output = self.main_path / f"clusters/{timestamp}"
    output.mkdir(parents=True, exist_ok=True)
    cluster_num = self.trajs.CVs[self._cluster_col].max()

    # define some files
    stem = f"cluster_{cluster_num}"
    h5_file = output / f"{stem}.h5"  # the cluster as h5 ensemble
    csv_file = output / f"{stem}.csv"  # the complete ensemble as a table
    md_file = output / "README.md"  # a readme filled by jinja
    png_name = output / f"{stem}.png"  # a render of the cluster
    npy_file = output / f"{stem}_selector.npy"  # the xs and ys of the selector

    # save the cluster
    self.cluster_output.save(h5_file)
    self.progbar.update()

    # save the pandas
    cv_names = ["lowd", self._cluster_col]
    if isinstance(self.autoencoder, AngleDihedralCartesianEncoderMap):
        parameters = self.autoencoder.p
        cv_names.append("central_dihedrals")
        if parameters.use_backbone_angles:
            cv_names.append("central_angles")
        if parameters.use_sidechains:
            cv_names.append("side_dihedrals")
    df = self.trajs.to_dataframe(CV=cv_names)
    df["cluster_id"] = df.pop(self._cluster_col.upper() + " FEATURE")
    df["x"] = df.pop("LOWD FEATURE 0")
    df["y"] = df.pop("LOWD FEATURE 1")
    df.to_csv(csv_file)
    self.progbar.update()

    # save the selector
    selector_vertices = np.vstack([self.selector.xs, self.selector.ys]).T
    np.save(npy_file, selector_vertices)
    self.progbar.update()

    # create a png
    # the render of the cluster is read from /tmp/tmp.png
    with Image.open("/tmp/tmp.png") as render:
        render_pixels = np.array(render).copy()

    fig = make_subplots(rows=1, cols=3)
    free_energy_trace = _plot_free_energy(
        x=self.trajs.lowd[:, 0],
        y=self.trajs.lowd[:, 1],
    )
    fig.add_trace(free_energy_trace, row=1, col=1)
    with set_env(ENCODERMAP_SKIP_SCATTER_SIZE_CHECK="True"):
        cluster_figure = plot_trajs_by_parameter(
            self.trajs,
            self._cluster_col,
            type="scatter",
            z_name_overwrite="cluster id",
            show=False,
        )
        fig.add_trace(cluster_figure.data[0], row=1, col=2)
    fig.add_trace(px.imshow(render_pixels).data[0], row=1, col=3)

    fig.data[1].marker.colorscale = px.colors.get_colorscale("Viridis")

    axis_layout = {
        "xaxis1": {"title": "x in a.u."},
        "xaxis2": {"title": "x in a.u."},
        "xaxis3": {"showticklabels": False},
        "yaxis1": {"title": "y in a.u."},
        "yaxis2": {"title": "y in a.u."},
        "yaxis3": {"showticklabels": False},
    }
    fig.update_layout(
        {
            "width": 1000,
            "height": 300,
            **axis_layout,
            "coloraxis_showscale": False,
            "autosize": True,
            "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
        },
    )
    fig.update_traces(
        {"showscale": False, "coloraxis": None},
        selector={"type": "heatmap"},
    )
    fig.write_image(png_name, engine="kaleido", width=1500, height=500, scale=2)
    self.progbar.update()

    # save a README
    # Local Folder Imports
    from .._version import get_versions

    hostname = socket.gethostname()
    system_info = {
        "platform": platform.system(),
        "system_user": getpass.getuser(),
        "platform_release": platform.release(),
        "platform_version": platform.version(),
        "architecture": platform.machine(),
        "hostname": hostname,
        "ip_address": socket.gethostbyname(socket.gethostname()),
        "mac_address": ":".join(re.findall("..", "%012x" % uuid.getnode())),
        "processor": platform.processor(),
        "ram": str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB",
        "pip_freeze": "\n".join(list(freeze())),
    }
    h5_path = str(h5_file.resolve())
    cluster_info = {
        "h5_file": h5_path,
        "n_frames": self.trajs.n_frames,
        "n_trajs": self.trajs.n_trajs,
        "n_top": len(self.trajs.top),
        "cluster_num": cluster_num,
        "h5_info": jinja2.Template(H5_INFO).render(
            {"h5_file": str(h5_file.resolve())}
        ),
        "encodermap_version": get_versions()["version"],
        "filename": str(md_file.resolve()),
        "csv_file": str(csv_file.resolve()),
    }
    info_dict = {**system_info, **cluster_info}
    readme_text = jinja2.Template(CLUSTER_TEMPLATE).render(
        info_dict,
    )
    md_file.write_text(readme_text)
    self.progbar.update()

    # cleanup
    del self.selector
    del self.cluster_output
    self.progbar.update()
    return output

1673 

def scatter_on_click(self, trace, points, selector):
    """Displays the conformation belonging to a clicked point.

    The clicked point is enlarged in the scatter plot, its high-dimensional
    data is shown in the trace widget (if `self.highd` is available), the
    corresponding frame is shown in the NGL area and a summary of the frame
    is printed into the pandas area.

    Args:
        trace: The first argument of the click callback. Not used.
        points: The clicked points. Only the first entry of
            `points.point_inds` is used.
        selector: The third argument of the click callback. Not used.

    """
    # clear the display
    self.display.outputs = []
    self.pandas_info_area.outputs = []

    # get the id of the clicked point
    clicked = points.point_inds[0]

    # display a message; a failure here is reported, but does not abort
    try:
        with self.display:
            print(
                f"Displaying conformation {clicked} for "
                f"{self.file_arr[clicked]} {self.frame_arr[clicked]}"
            )
    except Exception as e:
        with self.display:
            print(
                f"Can't display point {clicked} due to error: {e}. The "
                f"shapes of the file and frame arrays are "
                f"{self.file_arr.shape}, {self.frame_arr.shape}"
            )

    # enlarge the clicked point in the main plot
    sizes = self.base_sizes.copy()
    sizes[clicked] = 20
    with self.figure_widget.batch_update():
        self.scatter_data.marker.size = sizes

    # plot the trace
    if self.highd is not None:
        single_row = np.expand_dims(self.highd[clicked], 0)
        self.trace_widget.data[0].z = single_row.T

    # nglview, working on a copy of the frame's trajectory
    frame = self.trajs.get_single_frame(clicked)
    centered = deepcopy(frame.traj).center_coordinates()
    view = nv.show_mdtraj(centered, gui=False)
    if self.ball_and_stick:
        view.clear_representations()
        view.add_representation("ball+stick")
    self.ngl_area.children = [view]
    self.view = view

    # pandas
    with self.pandas_info_area:
        display(frame.dash_summary())

1725 

def on_select(self, trace, points, selector):
    """Adds the selected points to the current selection and enlarges them.

    Args:
        trace: The first argument of the selection callback. Not used.
        points: The selected points. Their `point_inds` are appended to
            `self.selected_point_ids`.
        selector: The selector of the selection. A deepcopy is stored in
            `self.selector`.

    """
    self.display.outputs = []

    # force an integer dtype. An empty selection would otherwise be converted
    # to a float array, which can't be used as an index afterwards.
    new_point_ids = np.asarray(points.point_inds, dtype=int)
    self.selected_point_ids = np.concatenate(
        [self.selected_point_ids, new_point_ids]
    ).astype(int, copy=False)
    self.selector = deepcopy(selector)

    # enlarge all selected points
    sizes = self.base_sizes.copy()
    sizes[self.selected_point_ids] = 15
    with self.figure_widget.batch_update():
        self.scatter_data.marker.size = sizes
    with self.display:
        print(
            f"Selected {len(self.selected_point_ids)} points. Hit 'cluster' to view."
        )

1743 

def switch_between_density_and_scatter(self, b):
    """Toggles the main plot between the scatter plot and the density canvas.

    When the scatter plot is shown, the canvas is put into the container
    and the slider is set to 100. Otherwise, the figure widget is put into
    the container and the slider is set to 10.

    Args:
        b: The argument passed by the widget callback. Not used.

    """
    if self._graph == "scatter":
        shown_widget, slider_value, new_graph = self.canvas, 100, "density"
    else:
        shown_widget, slider_value, new_graph = self.figure_widget, 10, "scatter"

    self.container[2:6, :3] = widgets.Box(
        [shown_widget],
        layout=widgets.Layout(height="auto", width="auto"),
    )
    self.slider.value = slider_value
    self.container[7, 4:] = self.slider
    self._graph = new_graph

1771 

def help(self, n):
    """Runs a JavaScript snippet, that opens the help URL in a new window.

    The URL is read from the tooltip of `self._help_url`.

    Args:
        n: The argument passed by the widget callback. Not used.

    """
    # Third Party Imports
    from IPython.display import Javascript

    # NOTE(review): the output widget is created here and not added to the
    # container in this method.
    js_output = widgets.Output()
    with js_output:
        open_help = Javascript(f'window.open("{self._help_url.tooltip}");')
        display(open_help)

1779 

def on_canvas_mouse_down(self, x, y):
    """Starts a new path on the canvas at the position of the mouse.

    Args:
        x: The x-coordinate of the mouse event.
        y: The y-coordinate of the mouse event.

    """
    start = (x, y)
    self.canvas_drawing = True
    self.canvas_position = start
    # a previously drawn path is discarded
    self.canvas_path = [start]

1784 

def on_canvas_mouse_up(self, x, y):
    """Stops drawing on the canvas and tells the user how to continue.

    Args:
        x: The x-coordinate of the mouse event. Not used.
        y: The y-coordinate of the mouse event. Not used.

    """
    self.canvas_drawing = False

    self.display.outputs = []
    with self.display:
        # with more than one topology, the user has to choose one first
        if len(self.trajs.top) > 1:
            print(
                f"Select a topology from the Dropdown menu and hit "
                f"'Generate' to generate new molecular conformations."
            )
        else:
            print("Click 'Generate' to generate new molecular conformations")

1797 

def on_canvas_mouse_move(self, x, y):
    """Extends the drawn path to the position of the mouse.

    Nothing happens, as long as `self.canvas_drawing` is falsy.

    Args:
        x: The x-coordinate of the mouse event.
        y: The y-coordinate of the mouse event.

    """
    if self.canvas_drawing:
        with hold_canvas():
            # connect the previous position with the new position
            previous_x, previous_y = (
                self.canvas_position[0],
                self.canvas_position[1],
            )
            self.canvas.stroke_line(previous_x, previous_y, x, y)
            self.canvas_position = (x, y)

        self.canvas_path.append(self.canvas_position)

1809 

1810 def _canvas_path_in_data_coords(self) -> np.ndarray: 

1811 """Returns the path coordinates in data coordinates. 

1812 

1813 Returns: 

1814 np.ndarray: An array of shape (n_points, 2) containing the 

1815 data coordinates. [:, 0] are the x-coordinates and 

1816 [:, 1] are the y-coordinates. 

1817 

1818 """ 

1819 # Third Party Imports 

1820 from scipy.interpolate import interp1d 

1821 

1822 path = np.array(self.canvas_path) 

1823 x = interp1d([0, 500], [self.lowd[:, 0].min(), self.lowd[:, 0].max()]) 

1824 x = x(path[:, 0]) 

1825 y = self.figure_widget.layout.yaxis.range 

1826 y = interp1d([500, 0], [self.lowd[:, 1].min(), self.lowd[:, 1].max()]) 

1827 y = y(path[:, 1]) 

1828 verts = np.vstack([x, y]).T 

1829 distance = np.cumsum( 

1830 np.sqrt( 

1831 np.ediff1d(verts[:, 0], to_begin=0) ** 2 

1832 + np.ediff1d(verts[:, 1], to_begin=0) ** 2 

1833 ) 

1834 ) 

1835 distance = distance / distance[-1] 

1836 fx, fy = interp1d(distance, verts[:, 0]), interp1d(distance, verts[:, 1]) 

1837 alpha = np.linspace(0, 1, self.slider.value) 

1838 path = np.vstack([fx(alpha), fy(alpha)]).T 

1839 return path 

1840 

1841 def _setup_graph(self): 
# Build all widgets of the dashboard, connect the callbacks, place the
# widgets in a grid and display that grid.
#
# Reads from the instance: _username, _debug_main_path, trajs,
# _max_slider_len, _help_url, density, scatter, layout, marker_colors,
# lowd and highd.
# Creates on the instance (among others): container, header, display,
# top_selector, ngl_area, progbar_empty, pandas_all_area,
# pandas_info_area, slider, the buttons, heatmap, canvas, figure_widget,
# path_anim_widget and - only if highd is not None - trace_widget.
# Side effect: calls display(self.container) as its last statement.

1842 # text areas 

# "scatter" is the initial value of the graph-mode flag.
1843 self._graph = "scatter" 

# Grid of 10 rows x 7 columns that hosts every widget created below.
1844 self.container = widgets.GridspecLayout(n_rows=10, n_columns=7, height="1000px") 

1845 self.header = widgets.HTML( 

1846 value=f"<h2>EncoderMap Dashboard for {self._username} in {self._debug_main_path}</h2>", 

1847 layout=widgets.Layout(height="auto", width="auto"), 

1848 ) 

# Output area for the textual hints printed by this class.
1849 self.display = widgets.Output( 

1850 layout=widgets.Layout(height="auto", width="auto") 

1851 ) 

1852 with self.display: 

1853 print( 

1854 "Interact with the Scatter Plot to view molecular conformations. " 

1855 "Select points with the lasso tool and click 'cluster' " 

1856 "to generate a cluster. Switch to 'Density' to draw a Path and " 

1857 "generate new conformations." 

1858 ) 

1859 

1860 # the traj options 

# Default dropdown entries: (label, index) per topology in self.trajs.top.
# The label is the 1-based index plus the topology's string form.
# NOTE(review): str.lstrip/rstrip remove a *set of characters*, not a
# prefix/suffix; this only yields the intended label as long as the text
# after "<mdtraj.Topology with " starts with a character outside that set.
1861 options = [ 

1862 (f"{i + 1}: " + str(top).lstrip("<mdtraj.Topology with ").rstrip(">"), i) 

1863 for i, top in enumerate(self.trajs.top) 

1864 ] 

# NOTE(review): the `else` below presumably pairs with the `for` loop
# (for/else), i.e. the common_str labels replace the defaults only when
# the loop finishes without hitting `break` - confirm.
1865 if len(self.trajs.top) == len(self.trajs.common_str): 

1866 for top, sub_trajs in self.trajs.trajs_by_top.items(): 

1867 if len(sub_trajs.common_str) != 1: 

1868 break 

1869 else: 

1870 options = [(cs, i) for i, cs in enumerate(self.trajs.common_str)] 

1871 self.top_selector = widgets.Dropdown(options=options, description="Top:") 

1872 

1873 # some placeholders 

# These start out empty; this method only fills pandas_all_area (at the
# end of the method).
1874 self.ngl_area = widgets.Box(layout=widgets.Layout(height="auto", width="auto")) 

1875 self.progbar_empty = widgets.Output( 

1876 layout=widgets.Layout(height="auto", width="auto") 

1877 ) 

1878 self.pandas_all_area = widgets.Output( 

1879 layout=widgets.Layout(height="auto", width="auto") 

1880 ) 

1881 self.pandas_info_area = widgets.Output( 

1882 layout=widgets.Layout(height="auto", width="auto") 

1883 ) 

1884 

1885 # slider 

# Integer "Size" slider; the button tooltips below describe it as the
# number of structures to show for a cluster / to create along a path.
1886 self.slider = widgets.IntSlider( 

1887 value=10, 

1888 min=1, 

1889 max=self._max_slider_len, 

1890 description="Size", 

1891 continuous_update=False, 

1892 layout=widgets.Layout(height="auto", width="auto"), 

1893 ) 

1894 

1895 # buttons 

# The help "button" is a plain HTML anchor styled with Jupyter button
# classes, not a widgets.Button; no on_click handler is registered for
# it in this method.
# NOTE(review): the href value is inserted without quotes.
1896 self.help_button = widgets.HTML( 

1897 value=( 

1898 f'<a href={self._help_url}><div class="lm-Widget jupyter-widgets ' 

1899 f'jupyter-button widget-button mod-info" style="height: 50%; ' 

1900 f"width: 100%; grid-area: widget007; margin: auto; margin-top: 25px; display: " 

1901 f'flex; align-items: center; justify-content: center;">' 

1902 f'<i class="fa fa-info"></i>Help</div></a>' 

1903 ) 

1904 ) 

1905 self.cluster_button = widgets.Button( 

1906 description="Cluster", 

1907 icon="th", 

1908 button_style="info", 

1909 layout=widgets.Layout(height="auto", width="auto"), 

1910 tooltip=( 

1911 "After selecting points with the Lasso Tool, this button will " 

1912 "display a subset of the selected point in the display area. Use " 

1913 "the 'Size' slider to choose how many representative structures of " 

1914 "the selected cluster you want to have displayed." 

1915 ), 

1916 ) 

1917 self.generate_button = widgets.Button( 

1918 description="Generate", 

1919 icon="bezier-curve", 

1920 button_style="info", 

1921 tooltip=( 

1922 "Use the decoder part of the autoencoder to create new molecular " 

1923 "conformations from a path, that you have drawn with the 'Draw " 

1924 "open freeform' Tool. The 'Size' slider will choose how many " 

1925 "conformations to create along the path." 

1926 ), 

1927 layout=widgets.Layout(height="auto", width="auto"), 

1928 ) 

1929 self.save_button = widgets.Button( 

1930 description="Save", 

1931 icon="floppy-o", 

1932 button_style="info", 

1933 layout=widgets.Layout(height="auto", width="auto"), 

1934 ) 

1935 self.density_button = widgets.Button( 

1936 description="Density", 

1937 icon="bar-chart", 

1938 button_style="info", 

1939 layout=widgets.Layout(height="auto", width="auto"), 

1940 tooltip=("This button toggles between a density and a scatter plot."), 

1941 ) 

1942 

1943 # plots 

# Heatmap trace without data (z=[]); it is the only trace of the trace
# widget created further down.
1944 self.heatmap = go.Heatmap( 

1945 z=[], 

1946 showlegend=False, 

1947 showscale=False, 

1948 colorscale="Viridis", 

1949 hoverinfo="skip", 

1950 name="", 

1951 hovertemplate="", 

1952 ) 

1953 

1954 # this array prepares the selection 

1955 self.selected_point_ids = np.array([]).astype(int) 

1956 

1957 # set up the canvas for drawing 

# The density trace is rendered without margins to a 500 x 500 PNG which
# is drawn onto a canvas of the same pixel size.
1958 img = go.Figure( 

1959 data=[self.density], 

1960 layout={ 

1961 "margin": { 

1962 "t": 0, 

1963 "b": 0, 

1964 "l": 0, 

1965 "r": 0, 

1966 }, 

1967 "yaxis_visible": True, 

1968 "xaxis_visible": True, 

1969 }, 

1970 ) 

1971 stream = img.to_image(format="png", width=500, height=500) 

1972 background_image = widgets.Image( 

1973 value=stream, 

1974 format="png", 

1975 width=500, 

1976 height=500, 

1977 ) 

1978 self.canvas = Canvas(width=500, height=500) 

1979 self.canvas.draw_image(background_image) 

# Drawing state used by the on_canvas_mouse_* callbacks registered below.
1980 self.canvas_drawing = False 

1981 self.canvas_position = None 

1982 self.canvas_path = [] 

1983 self.canvas.on_mouse_down(self.on_canvas_mouse_down) 

1984 self.canvas.on_mouse_move(self.on_canvas_mouse_move) 

1985 self.canvas.on_mouse_up(self.on_canvas_mouse_up) 

1986 self.canvas.stroke_style = "#749cb8" 

1987 

1988 # main figure widget 

1989 self.figure_widget = go.FigureWidget( 

1990 data=[self.scatter], 

1991 layout=self.layout, 

1992 ) 

# Handle to the scatter trace owned by the figure widget; colors, sizes
# and callbacks are set on this handle.
1993 self.scatter_data = self.figure_widget.data[0] 

1994 self.base_colors = self.marker_colors 

1995 self.scatter_data.marker.color = self.marker_colors 

# Every point of lowd starts with marker size 8.
1996 self.base_sizes = np.array([8] * len(self.lowd)) 

1997 self.scatter_data.marker.size = self.base_sizes 

1998 

1999 # the animation widget 

# Three traces: the density, a line trace and a marker trace. The two
# scatter traces are created with dummy coordinates ([0, 0]).
2000 self.path_anim_widget = go.FigureWidget( 

2001 data=[ 

2002 self.density, 

2003 go.Scatter( 

2004 x=[0, 0], 

2005 y=[0, 0], 

2006 mode="lines", 

2007 hovertemplate="Generation Path (%{x:.2f}, %{y:.2f})", 

2008 showlegend=False, 

2009 ), 

2010 go.Scatter( 

2011 x=[0, 0], 

2012 y=[0, 0], 

2013 mode="markers", 

2014 marker_size=12, 

2015 marker_line_width=2, 

2016 hovertemplate="Current Path (%{x:.2f}, %{y:.2f})", 

2017 showlegend=False, 

2018 ), 

2019 ], 

2020 layout=self.layout, 

2021 ) 

2022 

2023 # the trace widget 

# Only created when high-dimensional data is present (highd is not None).
2024 if self.highd is not None: 

2025 self.trace_widget = go.FigureWidget( 

2026 data=[self.heatmap], 

2027 layout=go.Layout( 

2028 { 

2029 "width": 50, 

2030 "modebar_remove": BAD_MODEBAR_BUTTONS, 

2031 "yaxis_visible": False, 

2032 "xaxis_visible": False, 

2033 "title": "Trace", 

2034 "height": 500, 

2035 "margin": { 

2036 "t": 25, 

2037 "b": 75, 

2038 "l": 10, 

2039 "r": 10, 

2040 }, 

2041 } 

2042 ), 

2043 ) 

2044 

2045 # responsiveness 

# Connect plot events and button clicks to the handler methods.
2046 self.scatter_data.on_click(self.scatter_on_click) 

2047 self.scatter_data.on_selection(self.on_select) 

2048 self.cluster_button.on_click(self.cluster) 

2049 self.density_button.on_click(self.switch_between_density_and_scatter) 

2050 self.save_button.on_click(self.save) 

2051 self.generate_button.on_click(self.generate) 

2052 

2053 # add the elements to the grid 

# Row 0: header. Row 1: hint display plus topology dropdown (last column).
# Rows 2-5: figure (columns 0-2), trace (column 3, only with highd) and
# ngl_area (columns 4-6). Row 6: help button and progbar_empty.
# Row 7: buttons and slider. Rows 8-9: pandas_all_area / pandas_info_area.
2054 self.container[0, :] = self.header 

2055 self.container[1, :-1] = self.display 

2056 self.container[1, -1] = self.top_selector 

2057 self.container[2:6, :3] = widgets.Box( 

2058 [self.figure_widget], 

2059 layout=widgets.Layout( 

2060 height="auto", 

2061 width="auto", 

2062 ), 

2063 ) 

# Without highd, column 3 of rows 2-5 is not assigned.
2064 if self.highd is not None: 

2065 self.container[2:6, 3] = widgets.Box( 

2066 [self.trace_widget], 

2067 layout=widgets.Layout( 

2068 height="auto", 

2069 width="auto", 

2070 ), 

2071 ) 

2072 self.container[2:6, 4:] = self.ngl_area 

2073 self.container[6, 1:] = self.progbar_empty 

2074 self.container[6, 0] = self.help_button 

2075 self.container[7, 0] = self.cluster_button 

2076 self.container[7, 1] = self.generate_button 

2077 self.container[7, 2] = self.save_button 

2078 self.container[7, 3] = self.density_button 

2079 self.container[7, 4:] = self.slider 

# Column 3 of rows 8-9 is not assigned.
2080 self.container[8:, :3] = self.pandas_all_area 

2081 self.container[8:, 4:] = self.pandas_info_area 

# Show the summary returned by trajs.dash_summary() in the lower left area.
2082 with self.pandas_all_area: 

2083 display(self.trajs.dash_summary()) 

2084 

2085 # self.container = widgets.VBox([ 

2086 # self.header, 

2087 # self.display, 

2088 # self.figure_widget, 

2089 # ]) 

2090 

# Render the assembled grid.
2091 display(self.container) 

2092 

2093 

2094# class InteractivePlottingDep: 

2095# """Class to open up an interactive plotting window. 

2096# 

2097# Contains subclasses to handle user-clickable menus and selectors. 

2098# 

2099# Attributes: 

2100# trajs (encodermap.TrajEnsemble): The trajs passed into this class. 

2101# fig (matplotlib.figure): The figure plotted onto. If ax is passed when 

2102# this class is instantiated, the parent figure will be fetched with 

2103# self.fig = self.ax.get_figure() 

2104# ax (matplotlib.axes): The axes where the lowd data of the trajs 

2105# is plotted on. 

2106# menu_ax (matplotlib.axes): The axes where the normal menu is plotted on. 

2107# status_menu_ax (matplotlib.axes): The axes on which the status menu is plotted on. 

2108# pts (matplotlib.collections.Collection): The points which are plotted. Based on some 

2109# other class variables, the color of this collection is adjusted. 

2110# statusmenu (encodermap.plot.utils.StatusMenu): The menu containing the 

2111# status buttons. 

2112# menu (encodermap.plot.utils.Menu): The menu containing the remaining buttons. 

2113# tool (encodermap.plot.utils.SelectFromCollection): The current active 

2114# tool used to select points. This can be lasso, polygon, etc... 

2115# mode (str): Current mode of the statusmenu. 

2116# 

2117# """ 

2118# 

2119# def __init__( 

2120# self, 

2121# autoencoder, 

2122# trajs=None, 

2123# data=None, 

2124# ax=None, 

2125# align_string="name CA", 

2126# top=None, 

2127# hist=False, 

2128# scatter_kws={"s": 5}, 

2129# ball_and_stick=False, 

2130# top_index=0, 

2131# ): 

2132# """Instantiate the InteractivePlotting class. 

2133# 

2134# Args: 

2135# trajs (encodermap.TrajEnsemble): The trajs of which the lowd info 

2136# should be plotted. 

2137# ax (matplotlib.axes, optional): On what axes to plot. If no axis is provided 

2138# a new figure and axes will be created, defaults to None. 

2139# 

2140# """ 

2141# # the align string for the cluster dummy method 

2142# self.align_string = align_string 

2143# self.top = top 

2144# self.hist = hist 

2145# self.autoencoder = autoencoder 

2146# self.ball_and_stick = ball_and_stick 

2147# self.top_index = top_index 

2148# 

2149# # scatter kws 

2150# self.scatter_kws = {**{"s": 80, "alpha": 0.5}, **scatter_kws} 

2151# 

2152# # close all plots 

2153# plt.close("all") 

2154# 

2155# # decide on fate of data 

2156# if data is None: 

2157# if hasattr(trajs, "lowd"): 

2158# print("Using the attribute `lowd` of provided `trajs`") 

2159# data = trajs.lowd 

2160# elif isinstance(trajs, (TrajEnsemble, SingleTraj)) and ( 

2161# isinstance(autoencoder, AngleDihedralCartesianEncoderMap) 

2162# or autoencoder.__class__.__name__ == "AngleDihedralCartesianEncoderMap" 

2163# ): 

2164# print( 

2165# "Using the provided `autoencoder` and `trajs` to create a projection." 

2166# ) 

2167# data = autoencoder.encode(trajs) 

2168# elif isinstance(data, np.ndarray) and hasattr(autoencoder, "encode"): 

2169# print("Using the `encode` method of `autoencoder` with provided data.") 

2170# if np.any(np.isnan(data)): 

2171# # Third Party Imports 

2172# import tensorflow as tf 

2173# 

2174# indices = np.stack(np.where(~np.isnan(data))).T.astype("int64") 

2175# dense_shape = data.shape 

2176# values = data[~np.isnan(data)].flatten().astype("float32") 

2177# data = tf.sparse.SparseTensor(indices, values, dense_shape) 

2178# data = autoencoder.encode(data) 

2179# elif hasattr(autoencoder, "encode"): 

2180# print("Using the `train_data` attribute of `autoencoder`.") 

2181# data = autoencoder.encode() 

2182# else: 

2183# print("Mocking data with np.random") 

2184# np.random.seed(19680801) 

2185# data = np.random.rand(100, 2) 

2186# if data.shape[1] != 2: 

2187# print("Using provided `data` to call encoder.") 

2188# data = autoencoder.encode(data) 

2189# self.data = data 

2190# 

2191# # see what traj has been provided 

2192# if trajs is None: 

2193# self.trajs = autoencoder.trajs 

2194# else: 

2195# if isinstance(trajs, str): 

2196# self.trajs = SingleTraj(trajs, self.top, traj_num=0)._gen_ensemble() 

2197# elif isinstance(trajs, list): 

2198# self.trajs = TrajEnsemble(trajs, self.top) 

2199# else: 

2200# self.trajs = trajs 

2201# 

2202# if isinstance(trajs, SingleTraj): 

2203# if "lowd" not in self.trajs.CVs: 

2204# self.trajs.load_CV(self.data, attr_name="lowd") 

2205# else: 

2206# if "lowd" not in self.trajs.CVs: 

2207# self.trajs.load_CVs(self.data, attr_name="lowd") 

2208# 

2209# # decide what function to use to build clusters 

2210# # Decided against gen_dummy traj as get_cluster_frames works better with jinja2 

2211# self.cluster_building_fn = get_cluster_frames 

2212# 

2213# # create fig and ax 

2214# if ax is None: 

2215# # create fig and ax 

2216# subplot_kw = dict(xlim=(0, 1), ylim=(0, 1), autoscale_on=True) 

2217# self.fig, self.ax = plt.subplots( 

2218# 1, 1, figsize=(10, 8) 

2219# ) # subplot_kw=subplot_kw) 

2220# else: 

2221# self.ax = ax 

2222# self.fig = self.ax.get_figure() 

2223# 

2224# # add the axes to create the menus on 

2225# self.fig.subplots_adjust(left=0.3) 

2226# self.menu_ax = plt.axes([0.05, 0.1, 0.15, 0.35], facecolor="lightblue") 

2227# self.status_menu_ax = plt.axes( 

2228# [0.05, 0.49, 0.15, 0.35], facecolor="lightyellow" 

2229# ) 

2230# 

2231# # remove everything in these axes 

2232# self.menu_ax.axis("off") 

2233# self.status_menu_ax.axis("off") 

2234# self.tool = DummyTool() 

2235# 

2236# # plot 

2237# self.pts = self.ax.scatter(self.data[:, 0], self.data[:, 1], **self.scatter_kws) 

2238# 

2239# # hist 

2240# if self.hist: 

2241# self.ax.hist2d(*data.T, bins=400, norm=mpl.colors.LogNorm()) 

2242# 

2243# # Check whether mouse enters drawing area 

2244# # Upon entering drawing area tools are initialized based on current mode 

2245# # Leave event currently serves no purpose 

2246# self.cid_ax_enter = self.fig.canvas.mpl_connect( 

2247# "axes_enter_event", self.on_enter_ax 

2248# ) 

2249# self.cid_ax_leave = self.fig.canvas.mpl_connect( 

2250# "axes_leave_event", self.on_leave_ax 

2251# ) 

2252# 

2253# # check button presses and compare them with the status of the menuitems 

2254# self.cid_on_click = self.fig.canvas.mpl_connect( 

2255# "button_release_event", self.on_click 

2256# ) 

2257# 

2258# # Instantiate Menu 

2259# self.statusmenu = StatusMenu(self.status_menu_ax) 

2260# self.menu = Menu(self.menu_ax) 

2261# 

2262# # Show 

2263# plt.show() 

2264# 

2265# def on_click(self, event): 

2266# """Decides whether the release event happened in the drawing area or the menu. 

2267# 

2268# Args: 

2269# event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect(). 

2270# 

2271# """ 

2272# if event.inaxes == self.ax: 

2273# self.on_click_tool(event) 

2274# else: 

2275# self.on_click_menu(event) 

2276# 

2277# def on_enter_ax(self, event): 

2278# """Chooses the tool to use when self.ax is entered, based on current mode. 

2279# 

2280# Args: 

2281# event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect(). 

2282# 

2283# """ 

2284# # print('Axis is entered') 

2285# if event.inaxes is self.ax and self.mode == "Idle": 

2286# # reset point coloration 

2287# self.pts.set_color("C0") 

2288# self.tool.disconnect() 

2289# if event.inaxes is self.ax and self.mode != "Idle": 

2290# # statusmenu 

2291# for key, item in self.statusmenu.menuitems.items(): 

2292# if self.mode == key: 

2293# method = getattr(self, key.lower()) 

2294# method() 

2295# 

2296# def on_leave_ax(self, event): 

2297# """Disconnect the current tool.""" 

2298# pass 

2299# 

2300# def on_click_tool(self, event): 

2301# """Left here for convenience if some tools need a button release event.""" 

2302# pass 

2303# 

2304# 

2305# def on_click_menu(self, event): 

2306# """Chooses the function to call based on what MenuItem was clicked. 

2307# 

2308# Args: 

2309# event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect(). 

2310# 

2311# """ 

2312# for key, item in self.menu.menuitems.items(): 

2313# if item.check_select(event): 

2314# method = getattr(self, key.lower().replace(" ", "_")) 

2315# method() 

2316# 

2317# def reset(self): 

2318# """Called when 'Reset' is pressed.""" 

2319# if "user_selected_points" in self.trajs.CVs: 

2320# self.trajs._CVs.drop(labels="user_selected_points") 

2321# self.__init__( 

2322# self.trajs, 

2323# self.autoencoder, 

2324# self.data, 

2325# None, 

2326# self.align_string, 

2327# self.top, 

2328# self.hist, 

2329# self.scatter_kws, 

2330# self.ball_and_stick, 

2331# ) 

2332# 

2333# def write(self): 

2334# """Called when 'Write' is pressed.""" 

2335# if self.mode == "Idle": 

2336# return 

2337# time = _datetime_windows_and_linux_compatible() 

2338# if self.mode == "Bezier" or self.mode == "Path": 

2339# os.makedirs( 

2340# f"{self.autoencoder.p.main_path}/generated_paths/", exist_ok=True 

2341# ) 

2342# fname = ( 

2343# f"{self.autoencoder.p.main_path}/generated_paths/generated_{time}.pdb" 

2344# ) 

2345# with mda.Writer(fname) as w: 

2346# for step in self.uni.trajectory: 

2347# w.write(self.uni.atoms) 

2348# self.ax.set_title( 

2349# f"Generated Path with {len(generated)} points saved at {fname}" 

2350# ) 

2351# else: 

2352# if "user_selected_points" not in self.trajs.CVs: 

2353# self.ax.set_title("First set the points before writing them to disk.") 

2354# return 

2355# max_, fname = _unpack_cluster_info( 

2356# self.trajs, 

2357# self.autoencoder.p.main_path, 

2358# self.tool, 

2359# self.dummy_traj, 

2360# self.align_string, 

2361# ) 

2362# self.ax.set_title(f"Cluster {max_} saved at {fname}") 

2363# 

2364# def set_points(self): 

2365# """Called when 'Set Points' is pressed.""" 

2366# if self.mode == "Idle": 

2367# return 

2368# if self.mode != "Idle": 

2369# if "tool" not in self.__dict__.keys(): 

2370# self.ax.set_title(f"Tool {self.mode} not yet implemented.") 

2371# return 

2372# else: 

2373# indices = self.accept() 

2374# if self.mode == "Bezier" or self.mode == "Path": 

2375# if np.unique(self.path_points, axis=0).shape[0] != 200: 

2376# self.ax.set_title( 

2377# f"Tool {self.mode} returned not the requested unique points." 

2378# ) 

2379# return 

2380# self.dummy_traj = self.autoencoder.generate( 

2381# self.path_points, backend="mdanalysis", top=self.top_index 

2382# ) 

2383# self.view = ngl.show_mdanalysis(self.dummy_traj) 

2384# if self.ball_and_stick: 

2385# self.view.clear_representations() 

2386# self.view.add_ball_and_stick() 

2387# 

2388# self.ax.set_title( 

2389# f"Generated Path with {len(self.dummy_traj.trajectory)} points is accessible as InteractivePlotting.view." 

2390# ) 

2391# return 

2392# 

2393# if indices is not None and self.mode != "Bezier" and self.mode != "Path": 

2394# self.ax.set_title( 

2395# f"Currently working on rendering the cluster. I'll let you know, when I'm finished." 

2396# ) 

2397# indices = np.asarray(indices) 

2398# 

2399# # update user defined clustering 

2400# col = "user_selected_points" 

2401# if col not in self.trajs.CVs: 

2402# _ = np.full(self.trajs.n_frames, -1) 

2403# try: 

2404# _[indices] = 0 

2405# except IndexError as e: 

2406# print(indices) 

2407# raise SystemExit from e 

2408# self.trajs.load_CVs(_, col) 

2409# else: 

2410# _ = self.trajs.CVs[col] 

2411# max_ = _.max() 

2412# _[indices] = max_ + 1 

2413# self.trajs.load_CVs(_, col) 

2414# 

2415# # change coloration of self.pts 

2416# color_palette = sns.color_palette("Paired", self.trajs.CVs[col].max() + 1) 

2417# cluster_colors = [ 

2418# (*color_palette[x], 1) if x >= 0 else (0.5, 0.5, 0.5, 0.01) 

2419# for x in self.trajs.CVs[col] 

2420# ] 

2421# self.pts.set_color(cluster_colors) 

2422# 

2423# max_ = np.max(self.trajs.CVs[col]) 

2424# self.view, self.dummy_traj = self.cluster_building_fn( 

2425# self.trajs, 

2426# max_, 

2427# nglview=True, 

2428# shorten=True, 

2429# stack_atoms=True, 

2430# col=col, 

2431# align_string=self.align_string, 

2432# ball_and_stick=self.ball_and_stick, 

2433# ) 

2434# if self.ball_and_stick: 

2435# for i in range(len(self.dummy_traj)): 

2436# self.view.clear_representations(component=i) 

2437# self.view.add_ball_and_stick(component=i) 

2438# self.ax.set_title( 

2439# f"Cluster {max_} is accessible as InteractivePlotting.view." 

2440# ) 

2441# 

2442# def render_move(self): 

2443# pass 

2444# 

2445# def lasso(self): 

2446# self.tool = SelectFromCollection(self.ax, self.pts) 

2447# 

2448# def rectangle(self): 

2449# self.tool = SelectFromCollection(self.ax, self.pts, selector=RectangleSelector) 

2450# 

2451# def ellipse(self): 

2452# print("Ellipse not yet implemented") 

2453# 

2454# def polygon(self): 

2455# textstr = "\n".join( 

2456# ( 

2457# "Select points in the figure by enclosing them within a polygon.", 

2458# # Press the 'esc' key to start a new polygon. 

2459# "Try holding the 'shift' key to move all of the vertices.", 

2460# "Try holding the 'ctrl' key to move a single vertex.", 

2461# ) 

2462# ) 

2463# 

2464# # these are matplotlib.patch.Patch properties 

2465# props = dict(boxstyle="round", facecolor="wheat", alpha=0.5) 

2466# 

2467# # place a text box in upper left in axes coords 

2468# self.manual_text = self.ax.text( 

2469# 0.05, 

2470# 0.95, 

2471# textstr, 

2472# transform=self.ax.transAxes, 

2473# fontsize=6, 

2474# verticalalignment="top", 

2475# bbox=props, 

2476# ) 

2477# self.tool = SelectFromCollection(self.ax, self.pts, selector=PolygonSelector) 

2478# 

2479# def path(self): 

2480# pass 

2481# 

2482# def bezier(self): 

2483# line = Line2D([], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87") 

2484# self.ax.add_line(line) 

2485# self.tool = BezierBuilder(line, self.ax) 

2486# 

2487# def accept(self): 

2488# if "manual_text" in self.__dict__.keys(): 

2489# self.manual_text.set_visible(False) 

2490# del self.manual_text 

2491# if self.mode == "Bezier": 

2492# self.path_points = copy.deepcopy(self.tool.ind) 

2493# selected_indices = self.tool.ind 

2494# self.tool.disconnect() 

2495# return selected_indices 

2496# 

2497# @property 

2498# def cluster_zoomed(self): 

2499# col = "user_selected_points" 

2500# if not col in self.trajs.df.keys(): 

2501# return 

2502# max_ = np.max(self.trajs.df[col]) 

2503# _ = plot_cluster_zoomed(self.trajs, max_, col=col) 

2504# return _ 

2505# 

2506# @property 

2507# def mode(self): 

2508# return self.statusmenu.status