Coverage for encodermap/plot/interactive_plotting.py: 9%
683 statements
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
1# -*- coding: utf-8 -*-
2# encodermap/plot/interactive_plotting.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
23################################################################################
24# Imports
25################################################################################
28# Future Imports at the top
29from __future__ import annotations
31# Standard Library Imports
32import getpass
33import os
34import platform
35import re
36import socket
37import threading
38import time
39import uuid
40from contextlib import contextmanager
41from copy import deepcopy
42from functools import cached_property
43from pathlib import Path
44from typing import TYPE_CHECKING, Any, Literal, Optional, Union
46# Third Party Imports
47import ipywidgets as widgets
48import numpy as np
49from optional_imports import _optional_import
50from pip._internal.operations.freeze import freeze
52# Encodermap imports
53from encodermap.autoencoder.autoencoder import (
54 AngleDihedralCartesianEncoderMap,
55 Autoencoder,
56 DihedralEncoderMap,
57 EncoderMap,
58)
59from encodermap.misc.misc import _datetime_windows_and_linux_compatible, _is_notebook
60from encodermap.plot.plotting import (
61 _plot_free_energy,
62 get_histogram,
63 plot_trajs_by_parameter,
64 to_density,
65 to_free_energy,
66)
67from encodermap.trajinfo.info_all import TrajEnsemble
68from encodermap.trajinfo.info_single import SingleTraj
71################################################################################
72# Optional Imports
73################################################################################
# `_optional_import` comes from the `optional_imports` package; presumably it
# defers the real import so this module stays importable when these packages
# are missing (NOTE(review): confirm failure-on-access behavior there).
sns = _optional_import("seaborn")
md = _optional_import("mdtraj")
jinja2 = _optional_import("jinja2")
make_subplots = _optional_import("plotly", "subplots.make_subplots")
px = _optional_import("plotly", "express")
go = _optional_import("plotly", "graph_objects")
Image = _optional_import("PIL", "Image")
Canvas = _optional_import("ipycanvas", "Canvas")
hold_canvas = _optional_import("ipycanvas", "hold_canvas")
nv = _optional_import("nglview")
psutil = _optional_import("psutil")
display = _optional_import("IPython", "display.display")
91################################################################################
92# Typing
93################################################################################
96if TYPE_CHECKING:
97 # Third Party Imports
98 from mdtraj import Topology, Trajectory
100 # Encodermap imports
101 from encodermap.autoencoder.autoencoder import AutoencoderClass
104################################################################################
105# Globals
106################################################################################
# Public API of this module: the names exported by a star-import.
__all__: list[str] = ["InteractivePlotting"]
# Plotly modebar button names (camelCase and lowercase spellings of each).
# The list is alphabetically ordered; keep the order unchanged.
BAD_MODEBAR_BUTTONS: list[str] = [
    "autoScale2d",
    "autoscale",
    "editInChartStudio",
    "editinchartstudio",
    "hoverCompareCartesian",
    "hovercompare",
    "lasso",
    "lasso2d",
    "orbitRotation",
    "orbitrotation",
    "pan",
    "pan2d",
    "pan3d",
    "reset",
    "resetCameraDefault3d",
    "resetCameraLastSave3d",
    "resetGeo",
    "resetSankeyGroup",
    "resetScale2d",
    "resetViewMapbox",
    "resetViews",
    "resetcameradefault",
    "resetcameralastsave",
    "resetsankeygroup",
    "resetscale",
    "resetview",
    "resetviews",
    "select",
    "select2d",
    "sendDataToCloud",
    "senddatatocloud",
    "tableRotation",
    "tablerotation",
    "toImage",
    "toggleHover",
    "toggleSpikelines",
    "togglehover",
    "togglespikelines",
    "toimage",
    "zoom",
    "zoom2d",
    "zoom3d",
    "zoomIn2d",
    "zoomInGeo",
    "zoomInMapbox",
    "zoomOut2d",
    "zoomOutGeo",
    "zoomOutMapbox",
    "zoomin",
    "zoomout",
]
# Markdown snippet explaining how to load an EncoderMap HDF5 (.h5) file.
# Contains the jinja placeholder `{{ h5_file }}`.
H5_INFO = """\
## Loading an HDF5 file (.h5) with EncoderMap

EncoderMap introduces a way of storing multiple trajectories (a `TrajEnsemble`) in a
single file. These files can be loaded via:

```python
import encodermap as em
trajs = em.TrajEnsemble.from_dataset('{{ h5_file }}')
```
"""
# Jinja template for the README written next to a saved, generated path.
# The placeholder names must stay in sync with the info dict used to render it.
PATH_TEMPLATE = """\
# README for EncoderMap.InteractivePlotting generate

You just used EncoderMap's `InteractivePlotting` and saved protein conformations generated from a path in a low-dimensional representation of a {{ ensemble_type }}. The conformations were generated using a trained neural network autoencoder (EncoderMap's {{ autoencoder_class }} class) from {{ n_points }} {{ lowd_dim }}-dimensional coordinates. The {{ ensemble_type }} contained {{ n_top }} distinct protein topologies. From these topologies, the {{ chosen_top }} was chosen to build this path. Find the topological information in the `.pdb` file in this directory. Look at EncoderMap's documentation at https://ag-peter.github.io/encodermap/ to learn more about Trajectory Ensembles.

### The complete Ensemble is also present

If you want to get more information about the path you generated, you can refer to these files:

### lowd.csv

This `.csv` file contains info about the complete ensemble this path was drawn on. The columns are as follows:

| traj_num | The number of the trajectory in the full dataset. This number is 0-based. If only one trajectory is loaded, its `trajectory number` might also be `None`. |
| --------- | ------------------------------------------------------------ |
| frame_num | The frame number. The trajectory number and frame number can be used to unmistakably identify frames in a trajectory ensemble. Frame numbers are also 0-based. |
| traj_file | Contains the trajectory data (file formats such as .xtc, .dcd, .h5). |
| top_file | Contains the topology of the file (i.e. atom types, masses, residues) (file formats such as .pdb, .gro, .h5). Some trajectory files (.h5) might also contain the topology. In that case `trajectory file` and `topology` file are identical. |
| time | The time of the frame. This can be used for time-based indexing of trajectories. EncoderMap offers the `SingleTraj.tsel[time]` accessor to distinguish it from frame-based indexing via `SingleTraj[frame]`. |
| x | The x-coordinate of the low-dimensional projection. |
| y | The y-coordinate of the low-dimensional projection. |

### path.npy

This numpy array contains the (x, y)-coordinates of the low-dimensional path that was used to generate the conformations.

### path.png

A render of the generated path.

### generated.pdb and generated.xtc

These files contain the topological (`.pdb`) and trajectory (`.xtc`) information to rebuild this path. Check out the function `encodermap.plot.plotting.interactive_path_visualization`, which can be used to display an animation of that path:

```python
import encodermap as em
import numpy as np
import pandas as pd

traj = em.load("{{ xtc_file }}", "{{ pdb_file }}")
lowd = pd.read_csv("{{ csv_file }}")
path = np.load("{{ npy_file }}")

em.plot.interactive_path_visualization(
    traj,
    lowd,
    path,
)
```

## Rendering this document

If you don't like viewing plain markdown files in a text viewer, there are many viewers available that render markdown nicely. I am currently using ghostwriter:

https://ghostwriter.kde.org/

If you want to create a pdf from this document you can try a combination of pandoc, latex and groff.

### HTML

```bash
pandoc {{filename}}.md -o {{filename}}.html
```

### Latex

```bash
pandoc {{filename}}.md -o {{filename}}.pdf
```

### Groff

```bash
pandoc {{filename}}.md -t ms -o {{filename}}.pdf
```

## Debug Info

```
encodermap.__version__ = {{encodermap_version}}
system_user = {{system_user}}
platform = {{platform}}
platform_release = {{platform_release}}
platform_version = {{platform_version}}
architecture = {{architecture}}
hostname = {{hostname}}
ip_address = {{ip_address}}
mac_address = {{mac_address}}
processor = {{processor}}
ram = {{ram}}
pip freeze = {{pip_freeze}}
```

"""
# Jinja template for the README written next to a saved cluster.
# The placeholder names must stay in sync with the info dict used to render it.
CLUSTER_TEMPLATE = """\
# README for EncoderMap.InteractivePlotting cluster

You just used EncoderMap's `InteractivePlotting` and saved a cluster. Here's some information about this cluster. The cluster was selected from a `TrajEnsemble` containing {{ n_trajs }} trajectories, {{ n_frames }} frames and {{ n_top }} unique topologies. This cluster was assigned the number {{ cluster_num }}. The file {{ h5_file }} contains only {{ n_points }} frames, chosen as representatives for this cluster. This file can be loaded with EncoderMap's `TrajEnsemble.from_dataset('{{ h5_file }}')` method. Look at EncoderMap's documentation at https://ag-peter.github.io/encodermap/ to learn more about Trajectory Ensembles.

### The complete Ensemble is also present

If you want to get more information about the clustering you carried out, you can refer to these files:

### cluster_{{ cluster_num }}.csv

This `.csv` file contains info about the complete ensemble this cluster was selected from. The columns are as follows:

| traj_num | The number of the trajectory in the full dataset. This number is 0-based. If only one trajectory is loaded, its `trajectory number` might also be `None`. |
| ---------- | ------------------------------------------------------------ |
| frame_num | The frame number. The trajectory number and frame number can be used to unmistakably identify frames in a trajectory ensemble. Frame numbers are also 0-based. |
| traj_file | Contains the trajectory data (file formats such as .xtc, .dcd, .h5). |
| top_file | Contains the topology of the file (i.e. atom types, masses, residues) (file formats such as .pdb, .gro, .h5). Some trajectory files (.h5) might also contain the topology. In that case `trajectory file` and `topology` file are identical. |
| time | The time of the frame. This can be used for time-based indexing of trajectories. EncoderMap offers the `SingleTraj.tsel[time]` accessor to distinguish it from frame-based indexing via `SingleTraj[frame]`. |
| x | The x-coordinate of the low-dimensional projection. |
| y | The y-coordinate of the low-dimensional projection. |
| cluster_id | This column contains -1, which are points not included in a cluster (outliers). Cluster 1 is denoted by a 0 in this column. If multiple clusters have been selected this column can contain multiple integer values. For every subsequent cluster, the `cluster_id` is advanced by 1. |

### cluster_{{ cluster_num }}_selector.npy

This numpy array contains the (x, y)-coordinates of the selector that was used to highlight the cluster. Be careful: this shape might not be convex, so algorithms that assume a convex polygon might not find the correct points inside it.

### cluster_{{ cluster_num }}.png

A render of the selected cluster.

{{ h5_info }}

## Rendering this document

If you don't like viewing plain markdown files in a text viewer, there are many viewers available that render markdown nicely. I am currently using ghostwriter:

https://ghostwriter.kde.org/

If you want to create a pdf from this document you can try a combination of pandoc, latex and groff.

### HTML

```bash
pandoc {{filename}}.md -o {{filename}}.html
```

### Latex

```bash
pandoc {{filename}}.md -o {{filename}}.pdf
```

### Groff

```bash
pandoc {{filename}}.md -t ms -o {{filename}}.pdf
```

## Debug Info

```
encodermap.__version__ = {{encodermap_version}}
system_user = {{system_user}}
platform = {{platform}}
platform_release = {{platform_release}}
platform_version = {{platform_version}}
architecture = {{architecture}}
hostname = {{hostname}}
ip_address = {{ip_address}}
mac_address = {{mac_address}}
processor = {{processor}}
ram = {{ram}}
pip freeze = {{pip_freeze}}

```
"""
319################################################################################
320# Utils
321################################################################################
@contextmanager
def set_env(**environ):
    """Temporarily set process environment variables.

    The complete environment is snapshotted on entry and restored on exit,
    so every variable added, changed or removed inside the ``with`` block is
    reverted, even if the block raises.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...     "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    Args:
        **environ (str): Environment variables to set. Values must be str,
            as required by `os.environ`.

    """
    old_environ = dict(os.environ)
    try:
        # The update happens inside `try` so that a failure part-way through
        # (e.g. a non-str value raising TypeError after earlier keys were
        # already written) still restores the original environment.
        os.environ.update(environ)
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)
348def _check_all_templates_defined(template, info_dict):
349 # Standard Library Imports
350 import re
352 regex = r"\{(.*?)\}"
353 matches = re.finditer(regex, template, re.MULTILINE | re.DOTALL)
354 min_matches = []
355 for matchNum, match in enumerate(matches):
356 for groupNum in range(0, len(match.groups())):
357 min_matches.append(match.group(groupNum))
358 min_matches = list(
359 set(map(lambda x: x.lstrip("{{").rstrip("}}"), [i for i in min_matches]))
360 )
361 if all(key in info_dict for key in min_matches):
362 return True
363 else:
364 missing = set(min_matches).difference(info_dict)
365 raise Exception(
366 f"Not all expressions defined in template. Missing expressions: {missing}"
367 )
def render_image(
    view: nv.NGLWidget,
    filename: str,
    timeout: Optional[float] = None,
) -> None:
    """Render an `nv.NGLWidget` to an image file.

    Calls `view.clear()`, requests an image via `view.render_image()`, and
    polls the returned object's `value` every 0.1 s until it holds data. That
    data is then written to `filename`. The polling blocks the calling thread;
    the function is meant to be run inside a thread.

    Args:
        view (nv.NGLWidget): The widget to be rendered.
        filename (str): The file to render to.
        timeout (Optional[float]): Maximum number of seconds to wait for the
            image data. None (the default) waits indefinitely.

    Raises:
        TimeoutError: If `timeout` is given and no image data arrived in time.
            Nothing is written in that case.

    """
    view.clear()
    im = view.render_image()
    deadline = None if timeout is None else time.monotonic() + timeout
    while not im.value:
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"No image data received from the NGLWidget within {timeout} s."
            )
        time.sleep(0.1)
    # `bytes()` accepts both `bytes` and buffer objects such as `memoryview`,
    # whereas `bytes` itself has no `.tobytes()` method.
    with open(filename, "wb") as fh:
        fh.write(bytes(im.value))
def plotly_freeform_to_path(path, n_points=100):
    """Resample a plotly freeform SVG path at evenly spaced arc-length positions.

    Args:
        path (str): An SVG path string of the form ``"Mx0,y0Lx1,y1Lx2,y2..."``.
            Only the leading ``M`` and the ``L`` separators are understood.
        n_points (int): Number of points in the returned path. Defaults to 100.

    Returns:
        np.ndarray: Array of shape (n_points, 2) with (x, y) coordinates spaced
            evenly along the polyline, including both end points. A path
            without extent (a single vertex) yields `n_points` copies of that
            vertex.

    """
    verts = np.array(
        [[float(c) for c in v.split(",")] for v in path.lstrip("M").split("L")]
    )
    # Drop vertices that repeat their predecessor: zero-length segments create
    # duplicate arc-length nodes, which breaks the interpolation.
    keep = np.ones(len(verts), dtype=bool)
    keep[1:] = np.any(np.diff(verts, axis=0) != 0, axis=1)
    verts = verts[keep]
    segment_lengths = np.hypot(*np.diff(verts, axis=0).T)
    distance = np.concatenate([[0.0], np.cumsum(segment_lengths)])
    if distance[-1] == 0:
        # Nothing to normalize by; every sample sits on the single vertex.
        return np.repeat(verts[:1], n_points, axis=0)
    distance = distance / distance[-1]
    alpha = np.linspace(0, 1, n_points)
    return np.column_stack(
        [
            np.interp(alpha, distance, verts[:, 0]),
            np.interp(alpha, distance, verts[:, 1]),
        ]
    )
406################################################################################
407# Interactive Plotting
408################################################################################
class ProgressWidgetTqdmCompatible:
    """A jupyter widgets `IntProgress` wrapper that understands tqdm-style calls.

    Use it as a context manager: entering shows the progress bar inside the
    given container, exiting replaces it with a placeholder widget.

    """

    def __init__(
        self,
        container: widgets.GridspecLayout,
        empty: widgets.Output,
        total: int,
        description: str,
    ) -> None:
        """Instantiate the progress bar.

        Args:
            container (widgets.GridSpecLayout): An instance of a
                widgets.GridSpecLayout class. The progress bar is placed in
                row 7 (index 6), from column 2 (index 1) to the end.
            empty (widgets.Output): Placed at the position of the progress
                bar when the bar closes, to clear it.
            total (int): The initial total to count to.
            description (str): The description of the progress bar.

        """
        self.container = container
        self.total = total
        self.empty = empty
        self.description = description
        # Per-caller bookkeeping: {function: {"update_calls": ..., "total": ...}}.
        self._calls: dict[str, dict[str, int]] = {}
        env_flag = os.getenv("ENCODERMAP_PRINT_PROG_UPDATES", "False")
        self.print = env_flag == "True"

    def _show_bar(self, start_value: int) -> None:
        """Create a fresh `IntProgress` up to `self.total` and place it in the grid."""
        self.progbar = widgets.IntProgress(
            value=start_value,
            min=0,
            max=self.total,
            step=1,
            description=self.description,
            layout={"width": "90%"},
        )
        self.container[6, 1:] = self.progbar

    def __enter__(self):
        self._show_bar(0)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.container[6, 1:] = self.empty
        if self.print:
            print(self._calls)
        del self.progbar

    def debug_print(self) -> None:
        """Print the per-caller update statistics."""
        print(f"WidgetProgbar {id(self)}")
        for name, stats in self._calls.items():
            print(
                f"{name:<15} total: {stats['total']:>3} n: {stats['update_calls']:>3}"
            )
        print("\n")

    def update(self, n: int = 1, **kwargs) -> None:
        """Advance the progress bar.

        Args:
            n (int): How far to advance. Non-int values advance by 1.
                Defaults to 1.
            **kwargs: An optional `function` keyword names the caller for
                the debug statistics.

        """
        caller = kwargs.pop("function", None)
        if caller is not None:
            self._calls.setdefault(caller, {"update_calls": 0, "total": 0})
        self.progbar.value += n if isinstance(n, int) else 1
        if caller is not None:
            self._calls[caller]["update_calls"] += 1

    def reset(self, total: int, **kwargs) -> None:
        """Grow the total of the progress bar, keeping the current value.

        Args:
            total (int): New total. Must be greater than the old total.
            **kwargs: An optional `function` keyword names the caller for
                the debug statistics.

        """
        assert total > self.total
        caller = kwargs.pop("function", None)
        added = total - self.total
        if caller is not None:
            if caller in self._calls:
                self._calls[caller]["total"] += added
            else:
                self._calls[caller] = {"update_calls": 0, "total": added}
        self.total = total
        self._show_bar(self.progbar.value)
class InteractivePlotting:
    """EncoderMap's interactive plotting for jupyter notebooks.

    Instantiating this class will display an interactive display in your notebook.
    The display will look like this::

        ┌─────────────────────┐ ┌───────────┐
        │Display              │ │Top        │
        └─────────────────────┘ └───────────┘
        ┌─────────────┐ ┌───┐ ┌─────────────┐
        │             │ │   │ │             │
        │             │ │ T │ │             │
        │    Main     │ │ R │ │  Molecular  │
        │  Plotting   │ │ A │ │  Conform.   │
        │    Area     │ │ C │ │    Area     │
        │             │ │ E │ │             │
        │             │ │   │ │             │
        └─────────────┘ └───┘ └─────────────┘
        ┌───┐ ┌─────────────────────────────┐
        │   │ │Progress Bar                 │
        └───┘ └─────────────────────────────┘
        ┌─┐ ┌─┐ ┌─┐ ┌─┐ ┌───────────────────┐
        │C│ │G│ │S│ │D│ │Slider             │
        └─┘ └─┘ └─┘ └─┘ └───────────────────┘
        ┌────────────────┐ ┌────────────────┐
        │                │ │                │
        │      Data      │ │                │
        │    Overview    │ │                │
        │                │ │                │
        │                │ │                │
        └────────────────┘ └────────────────┘

    The components do the following:
        * Display:
            This part will display debug information.
        * Top (Top selector):
            Select which topology to use when creating new
            molecular conformations from the autoencoder network.
        * Main plotting area:
            In this area, a scatter plot will be displayed. The coordinates of
            the scatter plot will be taken from the low-dimensional projection
            of the trajectories. The data for this plotting area can be
            taken from different sources. See the `_lowd_parser` docstring
            for information on how the lowd data is selected. Clicking
            on a point in the scatter plot displays the conformation of that
            point.
        * TRACE:
            Displays the high-dimensional data of selected points or clusters.
        * Molecular conformation area:
            Displays molecular conformations.
        * Progress Bar:
            Displays progress.
        * C (Cluster button):
            After selecting points in the main plotting area
            with the lasso tool, hit this button to display the molecular
            conformations of the selected cluster.
        * G (Generate Button):
            Switch to density using the density button.
            Then, you can draw a freeform path into the Main plotting area.
            Pressing the generate button will generate the appropriate molecular
            conformations. If your data has multiple conformations, you can choose
            which conformation to use for decoding with the top selector.
        * S (Save button):
            Writes either a cluster or generated path to your disk. Uses the
            main_path of the autoencoder (the same directory as the training
            data will be stored).
        * D (Density button):
            Switch the main plotting area to Density.
        * Slider:
            In scatter mode this slider defines how many structures to select
            from a cluster for representation in the molecular conformations
            window. In density mode, this slider defines how many points along
            the user-drawn path should be sampled.

    """

    # Paths longer than this many characters are abbreviated for the debug
    # display (see how `_debug_main_path` is built in `__init__`).
    _max_filepath_len: int = 50
    # NOTE(review): usage not visible in this part of the file; presumably a
    # frame stride for subsampling — confirm.
    stride: int = 10
    # NOTE(review): presumably the upper bound of the slider — confirm.
    _max_slider_len: int = 200
    # NOTE(review): presumably the column/CV name used to store the user's
    # point selection — confirm.
    _cluster_col: str = "_user_selected_points"
    # Number of histogram bins passed to `get_histogram` in `_setup_histogram`.
    _nbins: int = 50
    # NOTE(review): how consecutive cluster selections are combined
    # ("stack" or "join"); semantics not visible here — confirm.
    _cluster_method: Literal["stack", "join"] = "join"
    # Project URL, presumably shown as help link in the GUI.
    _help_url: str = "https://github.com/AG-Peter/encodermap"
605 @classmethod
606 def from_project(cls, project_name: Literal["linear_dimers"]):
607 # Encodermap imports
608 from encodermap import load_project
610 trajs, autoencoder = load_project(
611 project_name,
612 traj=-1,
613 load_autoencoder=True,
614 )
615 return cls(autoencoder=autoencoder, trajs=trajs)
617 def __init__(
618 self,
619 autoencoder: Optional[AutoencoderClass] = None,
620 trajs: Optional[Union[str, list[str], TrajEnsemble, SingleTraj]] = None,
621 lowd_data: Optional[np.ndarray] = None,
622 highd_data: Optional[np.ndarray] = None,
623 align_string: str = "name CA",
624 top: Optional[Union[str, list[str], Topology]] = None,
625 ball_and_stick: bool = False,
626 histogram_type: Union[None, Literal["free_energy", "density"]] = "free_energy",
627 superpose: bool = True,
628 ref_align_string: str = "name CA",
629 base_traj: Optional[Trajectory] = None,
630 ):
631 """Instantiate the InteractivePlotting class.
633 Note:
634 It is recommended to assign an instance of this class to a variable
635 to safe variables from garbage collection::
637 sess = em.InteractivePlotting()
639 Args:
640 autoencoder (Optional[AutoencoderClass]): An instance of any of
641 EncoderMap's autoencoder classes (`Autoencoder`, `EncoderMap`,
642 `DihedralEncoderMap`, `AngleDihedralCartesianEncoderMap`).
643 trajs (Optional[Union[str, list[str], TrajEnsemble, SingleTraj]]): The
644 trajectory data to use this session. Molecular conformations are
645 selected from these trajectories. Can be one of EncoderMap's
646 trajectory data containers (`SingleTraj`, `TrajEnsemble`). Can
647 also be a str or a list of str, that point to trajectory files
648 (.xtc, .dcd, .h5, .pdb, .gro). Can also be None. In this case
649 the `autoencoder` argument is expected to be a
650 `AngleDihedralCartesianEncoderMap`, that is expected to contain
651 the trajs. Defaults to None.
652 lowd_data (Optional[np.ndarray]): The low-dimensional data to use
653 for this session. If not provided low-dimensional data will be
654 inferred from either `trajs` or `autoencoder`. Defaults to None.
655 highd_data (Optional[np.ndarray]): The high-dimensional data to use
656 for this session. If not provided high-dimensional data will be
657 inferred from either `trajs` or `autoencoder`. Defaults to None.
658 align_string (str): The alignment string to superimpose the
659 structures of selected clusters. See
660 https://mdtraj.org/1.9.4/atom_selection.html for info on how
661 this string affects the selected atoms. Defaults to 'name CA'.
662 top (Optional[Union[str, list[str], Topology]]): If trajs is a str,
663 and a trajectory file format that does not have topological
664 information (.xtc, .dcd), this argument will be used for topology.
665 Can be a str (file) or an instance of MDTraj's Topology. Can also
666 be a list of str, that matches the list of str in `trajs` with
667 the appropriate topology files. If None is provided, the trajs
668 argument is expected to be either `SingleTraj` or `TrajEnsemble`.
669 Defaults to None.
670 ball_and_stick (bool): Whether to represent the structures in ball and
671 stick representation (True) or in cartoon representation (False).
672 Defaults to False and cartoon representation.
673 histogram_type (Union[None, Literal["free_energy", "density"]]): Decide
674 how to style your histogram. If None, a straight histogram (count
675 per bin) will be plotted. If 'density' a density will be plotted.
676 If 'free_energy', the negative natural logartihm of the density
677 will be plotted. Defaults to 'free_energy'.
678 superpose (bool): Whether to superpose the clustered structures.
679 Defaults to True.
680 ref_align_string (str): If a `base_traj` is provided, this string will
681 be used to select the atoms to align the clustering results against.
682 base_traj (Optional[Trajectory]): If not None, this traj will be
683 used to align the clustered frames against. Can be used to make
684 all clusterings be consistent in their placement in the 3d space.
686 """
687 self.total = 0
688 self.cluster_output = None
689 self.path_output = None
690 self.align_string = align_string
691 self.top = top
692 self.ball_and_stick = ball_and_stick
693 self.histogram_type = histogram_type
694 self.superpose = superpose
695 self.ref_align_string = ref_align_string
696 self.base_traj = base_traj
697 self._username = os.getlogin()
699 # set the layout
700 self.layout = go.Layout(
701 {
702 "modebar_add": ["drawline", "drawopenpath", "eraseshape"],
703 "autosize": True,
704 "margin": {
705 "l": 0,
706 "r": 0,
707 "t": 0,
708 "b": 0,
709 },
710 "shapedefaults": {"editable": False},
711 }
712 )
714 # apply nest_asyncio for saving images
715 if _is_notebook():
716 # Third Party Imports
717 import nest_asyncio
719 nest_asyncio.apply()
721 # parse the complex arrangement of args
722 self.autoencoder = autoencoder
723 self.main_path = Path(".").resolve()
724 if self.autoencoder is not None:
725 if not self.autoencoder.read_only:
726 self.main_path = Path(self.autoencoder.p.main_path)
727 self.trajs = self._trajs_parser(autoencoder, trajs, top)
728 self.highd = self._highd_parser(autoencoder, highd_data, self.trajs)
729 self.lowd = self._lowd_parser(autoencoder, lowd_data, self.trajs)
730 self.file_arr = []
731 self.frame_arr = []
732 for t in self.trajs:
733 self.file_arr.extend([t.traj_file for i in range(t.n_frames)])
734 self.frame_arr.append(t.id[:, 1])
735 self.file_arr = np.array(self.file_arr)
736 self.frame_arr = np.hstack(self.frame_arr)
738 # put the data into self.trajs
739 if self.trajs is not None:
740 if self.highd is not None:
741 if "highd" not in self.trajs._CVs:
742 self.trajs.load_CVs(self.highd, "highd")
743 if "lowd" not in self.trajs._CVs:
744 self.trajs.load_CVs(self.lowd, "lowd")
746 # debugging stuff
747 self._debug_main_path = str(self.main_path)
748 if len(self._debug_main_path) > self._max_filepath_len:
749 self._debug_main_path = (
750 "/"
751 + "/".join([i[0] for i in self.main_path.parts[1:-2]])
752 + "/"
753 + "/".join(self.main_path.parts[-2:])
754 )
756 # set up base images
757 self._setup_histogram()
758 self._setup_graph()
760 def _fake_progress(self) -> None: # pragma: no cover
761 with ProgressWidgetTqdmCompatible(
762 container=self.container,
763 empty=self.progbar_empty,
764 total=10,
765 description="Testing...",
766 ) as self.progbar:
767 for i in range(10):
768 time.sleep(0.3)
769 self.progbar.update()
771 def _setup_histogram(self):
772 (
773 self.xcenters,
774 self.ycenters,
775 self.xedges,
776 self.yedges,
777 self.H,
778 ) = get_histogram(
779 self.lowd[:, 0],
780 self.lowd[:, 1],
781 bins=self._nbins,
782 transpose=True,
783 return_edges=True,
784 )
785 self.D = to_density(self.H)
786 self.F = to_free_energy(self.D).astype(str)
788 def _trajs_parser(
789 self,
790 autoencoder: AutoencoderClass,
791 trajs: Optional[Union[str, TrajEnsemble]] = None,
792 top: Optional[Union[str, Topology]] = None,
793 ) -> TrajEnsemble:
794 """Parses the input trajs and chooses what trajs to use.
796 The order of priority follows:
797 1. The input `trajs` parameter supersedes everything. If `trajs`
798 1.1. If an `AutoencoderClass` has been provided, the trajs are
799 checked, whether they conform to the expected input shape.
800 2.2. If trajs is a str, rather than a `TrajEnsemble`, the argument
801 `top` is used to build a `TrajEnsemble` from this topology
802 and the `trajs`. Thus, `top` can be either str or md.Topology.
803 2. If trajs is None, the `top` argument is not used and the
804 `TrajEnsemble` of the provided `AngleDihedralCartesianEncoderMap`
805 is used.
807 Args:
808 autoencoder (AutoencoderClass): The autoencoder.
809 trajs (Optional[Union[str, TrajEnsemble]]): The trajs.
810 top (Optional[Union[str, Topology]]): The topology.
812 Returns:
813 TrajEnsemble: The trajectory ensemble to use in this session.
815 """
816 if isinstance(trajs, str):
817 if not Path(trajs).is_file():
818 # Standard Library Imports
819 import errno
821 raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), trajs)
822 if isinstance(top, str):
823 if not Path(top).is_file():
824 # Standard Library Imports
825 import errno
827 raise FileNotFoundError(
828 errno.ENOENT, os.strerror(errno.ENOENT), top
829 )
830 trajs = TrajEnsemble([trajs], [top])
831 elif isinstance(trajs, (list, tuple)):
832 trajs = TrajEnsemble(trajs, top)
834 if isinstance(autoencoder, AngleDihedralCartesianEncoderMap):
835 if trajs is None:
836 return autoencoder.trajs
837 else:
838 for key, d in autoencoder.inp_CV_data.items():
839 assert d.shape[1:] == trajs.CVs[key].shape[1:], (
840 f"The shape of the CV `{key}` of the provided `trajs` "
841 f"{trajs.CVs[key].shape[1:]} does not match the shape of "
842 f"the train data of the provided `autoencoder` {d.shape[1:]}."
843 )
844 else:
845 if isinstance(trajs, SingleTraj):
846 return trajs._gen_ensemble()
847 return trajs
848 assert (
849 trajs is not None
850 ), f"Please provide a `TrajEnsemble` for the argument `trajs`."
851 if isinstance(trajs, SingleTraj):
852 return trajs._gen_ensemble()
853 return trajs
def _highd_parser(
    self,
    autoencoder: AutoencoderClass,
    highd: Optional[np.ndarray] = None,
    trajs: Optional[TrajEnsemble] = None,
) -> np.ndarray:
    """Selects which source of high-dimensional data to use.

    The order of priority follows:
        1. The provided `highd` np.ndarray.
            1.1 If an autoencoder has been provided, the high-dimensional
                input data will be checked with the autoencoder's input
                shape. For an `AngleDihedralCartesianEncoderMap` this check
                is not implemented yet and providing `highd` raises.
        2. If no high-dimensional data has been provided (`highd=None`), the
            high-dimensional data from the provided `trajs` will be used
            (the "highd" CV or, for an `AngleDihedralCartesianEncoderMap`,
            the training input built from the CVs of `trajs`).
        3. As a last resort, the autoencoder's `train_data` attribute will
            be used if the other datasources are not provided.

    Args:
        autoencoder (Optional[AutoencoderClass]): The autoencoder.
        highd (Optional[np.ndarray]): The high dimensional data.
        trajs (Optional[TrajEnsemble]): The trajs.

    Returns:
        np.ndarray: The high-dimensional data to use in this session.

    Raises:
        AssertionError: If `highd` is not a np.ndarray, does not match the
            autoencoder's input shape, or no data source is available.
        Exception: If `highd` is provided together with an
            `AngleDihedralCartesianEncoderMap`.
        NotImplementedError: If the shape of a non-array `train_data` can't
            be determined.
        TypeError: If `autoencoder` is of an unknown type.

    """
    if (
        isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
        or autoencoder.__class__.__name__ == "AngleDihedralCartesianEncoderMap"
    ):
        if highd is not None:
            assert isinstance(highd, np.ndarray), (
                f"The argument `highd_data` only supports None or np.ndarray. You "
                f"supplied {type(highd)}."
            )
            raise Exception(
                f"Confirming the shape of input highd and the input shape "
                f"of the autoencoder model is currently not implemented."
            )
        if trajs is None:
            return autoencoder.train_data
        if "central_dihedrals" not in trajs._CVs:
            print(
                f"The provided `trajs`, don't have any CVs loaded. I will "
                f"try to use the input data of the provided autoencoder."
            )
        sparse, highd_data, _cv_dict = autoencoder.get_train_data_from_trajs(
            trajs,
            autoencoder.p,
        )
        if sparse:
            # Sparse training input is not used as-is (presumably because it
            # can't be plotted directly); stack the dense CVs in the order
            # angles, central dihedrals, side dihedrals instead.
            highd_data = [trajs.central_dihedrals]
            if autoencoder.p.use_backbone_angles:
                highd_data.insert(0, trajs.central_angles)
            if autoencoder.p.use_sidechains:
                highd_data.append(trajs.side_dihedrals)
            return np.hstack(highd_data)
        return highd_data
    elif isinstance(
        autoencoder, (Autoencoder, EncoderMap, DihedralEncoderMap)
    ) or autoencoder.__class__.__name__ in [
        "Autoencoder",
        "EncoderMap",
        "DihedralEncoderMap",
    ]:
        if highd is not None:
            assert isinstance(highd, np.ndarray), (
                f"The argument `highd_data` only supports None or np.ndarray. You "
                f"supplied {type(highd)}."
            )
            if isinstance(autoencoder.train_data, np.ndarray):
                shape = autoencoder.train_data.shape
            else:
                # Peek at the first element of the (presumably batched)
                # dataset; (inputs, targets) tuples are expected. An empty
                # dataset yields None and ends up in the error below.
                d = next(iter(autoencoder.train_data), None)
                if isinstance(d, tuple) and len(d) == 2:
                    # last sample of the first input batch
                    d = d[0][-1]
                else:
                    raise NotImplementedError(
                        f"Bad `train_data` shape: {autoencoder.train_data=} "
                        f"{d=}"
                    )
                shape = d.shape
            # Report `shape` instead of `train_data.shape`: a dataset has no
            # `.shape` and would hide this message behind an AttributeError.
            assert highd.shape[-1] == shape[-1] or autoencoder._using_hypercube, (
                f"The provided np.array in argument `highd_data` has shape {highd.shape}, "
                f"but the autoencoder's `train_data` has shape {shape}."
            )
            return highd
        if trajs is not None and "highd" in trajs.CVs:
            return trajs.highd
        return autoencoder.train_data
    elif autoencoder is None:
        # Without an autoencoder there is nothing to validate against. Honor
        # an explicit `highd` first and never dereference a missing `trajs`.
        if highd is None and trajs is not None and "highd" in trajs.CVs:
            return trajs.highd
        assert highd is not None, (
            f"Please provide a numpy array containing high-dimensional data "
            f"or load high-dimensional data into your trajs with `trajs.load_CVs`."
        )
        return highd
    else:
        raise TypeError(f"Unknown type for autoencoder: {type(autoencoder)}.")
def _lowd_parser(self, autoencoder, lowd, trajs):
    """Selects which source of low-dimensional data to use.

    The order of priority follows:
        1. The provided `lowd` np.ndarray.
        2. The "lowd" CV of the provided `trajs`, if it is loaded.
        3. The encoding produced by the autoencoder: of `trajs` for an
            `AngleDihedralCartesianEncoderMap`, of `self.highd` otherwise.

    Without an autoencoder, only 1. and 2. are available.

    Args:
        autoencoder (Optional[AutoencoderClass]): The autoencoder.
        lowd (Optional[np.ndarray]): The low-dimensional data.
        trajs (Optional[TrajEnsemble]): The trajs.

    Returns:
        np.ndarray: The low-dimensional data to use in this session.

    Raises:
        AssertionError: If `lowd` is not a np.ndarray (when an autoencoder is
            given), or if no autoencoder, no `lowd` and no "lowd" CV in
            `trajs` are available.
        TypeError: If `autoencoder` is of an unknown type.

    """
    if (
        isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
        or autoencoder.__class__.__name__ == "AngleDihedralCartesianEncoderMap"
    ):
        if lowd is not None:
            assert isinstance(lowd, np.ndarray), (
                f"The argument `lowd_data` only supports None or np.ndarray. You "
                f"supplied {type(lowd)}."
            )
            return lowd
        if trajs is not None and "lowd" in trajs.CVs:
            return trajs.lowd
        return autoencoder.encode(trajs)
    elif isinstance(
        autoencoder, (Autoencoder, EncoderMap, DihedralEncoderMap)
    ) or autoencoder.__class__.__name__ in [
        "Autoencoder",
        "EncoderMap",
        "DihedralEncoderMap",
    ]:
        if lowd is not None:
            assert isinstance(lowd, np.ndarray), (
                f"The argument `lowd_data` only supports None or np.ndarray. You "
                f"supplied {type(lowd)}."
            )
            return lowd
        if trajs is not None and "lowd" in trajs.CVs:
            return trajs.lowd
        return autoencoder.encode(self.highd)
    elif autoencoder is None:
        # Honor an explicit `lowd` first (like the branches above) and never
        # dereference a missing `trajs`.
        if lowd is None and trajs is not None and "lowd" in trajs.CVs:
            return trajs.lowd
        assert lowd is not None, (
            f"Please provide a numpy array containing low-dimensional data "
            f"or load low-dimensional data into your trajs with `trajs.load_CVs`."
        )
        return lowd
    else:
        raise TypeError(f"Unknown type for autoencoder: {type(autoencoder)}.")
@cached_property
def density(self) -> Any:
    """go.Contour: A contour plot of one of the 2D histograms.

    The histogram is chosen by `self.histogram_type`:
    None -> `self.H`, "density" -> `self.D`, "free_energy" -> `self.F`.

    Raises:
        TypeError: If `self.histogram_type` is none of the above.

    """
    kind = self.histogram_type
    if kind is None:
        z_values = self.H
    elif kind == "density":
        z_values = self.D
    elif kind == "free_energy":
        z_values = self.F
    else:
        raise TypeError(
            f"Argument `histogram_type` needs to be either of None, "
            f"'density' or 'free_energy'. You supplied {self.histogram_type}."
        )

    contour_kwargs = dict(
        x=self.xcenters,
        y=self.ycenters,
        z=z_values,
        name="",
        showlegend=False,
        showscale=False,
        visible=True,
        colorscale="Viridis",
    )
    return go.Contour(**contour_kwargs)
@cached_property
def scatter(self) -> Any:
    """go.Scattergl: The scatter plot using the low-dimensional data.

    Every point of `self.lowd` is colored by the value of `self.F` in the
    2D-histogram bin (`self.xedges` x `self.yedges`) it falls into, linearly
    rescaled from the finite [min, max] of `self.F` to [0, 1]. Points outside
    all bins, or in bins without a valid value, keep the default color 1.0.
    The colors are also stored in `self.marker_colors`.

    If the instance has a `trajs` attribute, the hover label of each point
    shows the trajectory basename, trajectory number and frame number.

    """
    # Third Party Imports
    from scipy.interpolate import interp1d

    if hasattr(self, "trajs"):
        basenames = np.array(
            [traj.basename for traj in self.trajs for i in range(traj.n_frames)]
        )
        traj_nums, frame_nums = self.trajs.id.T
        customdata = np.stack(
            (
                basenames,
                traj_nums,
                frame_nums,
            ),
            axis=-1,
        )
        hovertemplate = (
            "%{customdata[0]} (Traj %{customdata[1]}, "
            "Frame %{customdata[2]}): (%{x:.2f}, %{y:.2f})"
        )
    else:
        customdata = None
        hovertemplate = None

    # map the values to the same range; non-finite values are masked when
    # computing min/max and become NaN through `bounds_error=False`
    values = self.F.copy().astype("float32").T
    values_ma = np.ma.masked_invalid(values)
    interp = interp1d(
        (np.min(values_ma), np.max(values_ma)), (0.0, 1.0), bounds_error=False
    )
    values = interp(values)

    # fill an array with the default color
    self.marker_colors = np.full((len(self.lowd),), fill_value=1.0)

    # Look up the histogram bin of every point at once instead of comparing
    # every point against every bin. `np.digitize` (right=False) returns i
    # for edges[i - 1] <= v < edges[i], i.e. the same half-open bins; points
    # below the first edge, on/above the last edge or NaN get an index
    # outside the valid range and keep the default color.
    x_bin = np.digitize(self.lowd[:, 0], self.xedges) - 1
    y_bin = np.digitize(self.lowd[:, 1], self.yedges) - 1
    in_bin = np.flatnonzero(
        (x_bin >= 0)
        & (x_bin < len(self.xedges) - 1)
        & (y_bin >= 0)
        & (y_bin < len(self.yedges) - 1)
    )
    bin_values = values[x_bin[in_bin], y_bin[in_bin]]
    has_value = ~np.isnan(bin_values)
    self.marker_colors[in_bin[has_value]] = bin_values[has_value]

    return go.Scattergl(
        mode="markers",
        x=self.lowd[:, 0],
        y=self.lowd[:, 1],
        visible=True,
        marker={
            "color": self.marker_colors,
            "colorscale": "Viridis",
            "size": 1,
            "line": {
                "width": 0,
            },
        },
        name="",
        customdata=customdata,
        hovertemplate=hovertemplate,
    )
def generate(self, b):
    """Decodes the path drawn on the canvas into new molecular conformations.

    Callback of the 'Generate' button (`b` is not used). The canvas path is
    resampled to `self.slider.value` points in low-dimensional coordinates
    (`self.path`), decoded by `self.autoencoder` into `self.path_output`,
    shown in the NGL area, and a Play widget plus slider are set up to step
    through the generated frames along the path.

    """
    # clear the display
    self.display.outputs = []

    self.progbar_description = "Backmapping: "

    # some error
    if len(self.canvas_path) == 0:
        with self.display:
            print("First Draw a line onto the Density map and the hit 'Generate'.")
        return

    # clear the pandas area
    self.pandas_info_area.outputs = []

    # display a message
    n_points = self.slider.value
    with self.display:
        print(f"Generating {n_points} points. Please stand by.")

    # set up progbar
    with ProgressWidgetTqdmCompatible(
        container=self.container,
        empty=self.progbar_empty,
        total=0,
        description=self.progbar_description,
    ) as self.progbar:
        # get the path
        self.path = self._canvas_path_in_data_coords()

        # generate
        if (
            isinstance(self.autoencoder, AngleDihedralCartesianEncoderMap)
            or self.autoencoder.__class__.__name__
            == "AngleDihedralCartesianEncoderMap"
        ):
            self.path_output = self.autoencoder.generate(
                self.path, top=self.top_selector.value, progbar=self.progbar
            )
        else:
            # Encodermap imports
            from encodermap.misc.backmapping import mdtraj_backmapping

            dihedrals = self.autoencoder.generate(self.path).numpy()
            self.path_output = mdtraj_backmapping(
                top=self.trajs[0].top_file,
                dihedrals=dihedrals,
                progbar=self.progbar,
                omega=False,
            )

    self.display.outputs = []
    with self.display:
        print("Conformations generated.")

    # clear progbar
    self.progbar_description = ""

    # create the media widget; frames are indexed 0 .. len(path) - 1, so
    # both the Play widget and the linked slider must stop at the last frame
    last_frame = len(self.path) - 1
    self.media_widget = widgets.Play(
        value=0,
        min=0,
        max=last_frame,
        step=1,
        disabled=False,
    )
    self.media_slider = widgets.IntSlider(min=0, max=last_frame)
    widgets.jslink((self.media_widget, "value"), (self.media_slider, "value"))
    self.container[7, 4:] = widgets.HBox(
        [self.media_widget, self.media_slider], layout={"align-content": "center"}
    )

    # create the view
    view = nv.show_mdtraj(self.path_output)
    self.ngl_area.children = [view]
    if self.ball_and_stick:
        view.clear_representations()
        view.add_representation("ball+stick")
    self.view = view

    # switch to plotly: draw the path and put the marker on its first point
    self.path_anim_widget.data[1].x = self.path[:, 0]
    self.path_anim_widget.data[1].y = self.path[:, 1]
    self.path_anim_widget.data[2].x = [self.path[0, 0]]
    self.path_anim_widget.data[2].y = [self.path[0, 1]]
    self.container[2:6, :3] = widgets.Box(
        [self.path_anim_widget],
        layout=widgets.Layout(
            height="auto",
            width="auto",
        ),
    )

    # make the slider responsive
    self.media_slider.observe(self.advance_path, names="value")
def advance_path(self, n):
    """Moves the NGL view and the path marker to the newly selected frame.

    Args:
        n: The change dict of the observed slider; `n["new"]` is the frame.

    """
    frame = n["new"]
    self.view.frame = frame
    marker = self.path_anim_widget.data[2]
    marker.x = [self.path[frame, 0]]
    marker.y = [self.path[frame, 1]]
def cluster(self, b):
    """Turns the points selected with the Lasso/Polygon tool into a cluster.

    Callback of the 'Cluster' button (`b` is not used). The selected frames
    (`self.selected_point_ids`) get a new cluster id in the
    `self._cluster_col` CV of `self.trajs`. Unassigned frames are -1, the
    first cluster is 0, and every further cluster is the current maximum + 1.
    `self.slider.value` frames of that cluster are then requested via
    `self.trajs.cluster`, either joined or stacked depending on
    `self._cluster_method`, and shown in the NGL area. The selection is reset
    afterwards.

    """
    # clear the display
    self.display.outputs = []
    self.progbar_description = "Clustering: "

    # some error
    if self.selected_point_ids.size == 0:
        with self.display:
            print(
                "First select some points using the Lasso or Polygon tool "
                "and then click 'cluster'."
            )
        return

    # clear the pandas area
    self.pandas_info_area.outputs = []

    # instantiate the progbar
    with ProgressWidgetTqdmCompatible(
        container=self.container,
        empty=self.progbar_empty,
        total=0,
        description=self.progbar_description,
    ) as self.progbar:
        # read the slider
        n_points = self.slider.value

        # display a message
        with self.display:
            print(f"Clustering {n_points} points. Please stand by.")

        # write the selection as a new cluster id into the cluster CV
        if self._cluster_col not in self.trajs.CVs:
            cluster_ids = np.full(self.trajs.n_frames, -1)
            try:
                cluster_ids[self.selected_point_ids] = 0
            except IndexError as e:
                raise IndexError(
                    f"Can't assign the selected points {self.selected_point_ids} "
                    f"to `trajs` with {self.trajs.n_frames} frames."
                ) from e
        else:
            cluster_ids = self.trajs.CVs[self._cluster_col]
            cluster_ids[self.selected_point_ids] = cluster_ids.max() + 1
        self.trajs.load_CVs(cluster_ids, self._cluster_col, override=True)
        cluster_id = cluster_ids.max()

        # the selection has been consumed by this cluster in either branch
        self.selected_point_ids = np.array([]).astype(int)

        self.cluster_output = self.trajs.cluster(
            cluster_id=cluster_id,
            col=self._cluster_col,
            n_points=n_points,
        )

        if self._cluster_method == "join":
            self._cluster = self.cluster_output.join(
                align_string=self.align_string,
                superpose=self.superpose,
                ref_align_string=self.ref_align_string,
                base_traj=self.base_traj,
                progbar=self.progbar,
            )

            # nglview: one component per frame of every joined trajectory
            total = 0
            for i, val in enumerate(self._cluster.values()):
                val.center_coordinates()
                for j, frame in enumerate(val):
                    if i == j == 0:
                        view = nv.show_mdtraj(frame, gui=False)
                    else:
                        view.add_trajectory(frame)
                    total += 1
            self.ngl_area.children = [view]
            if self.ball_and_stick:
                view.clear_representations()
                for i in range(total):
                    view.add_representation("ball+stick", component=i)
        else:
            self._cluster = self.cluster_output.stack(
                align_string=self.align_string,
                superpose=self.superpose,
                ref_align_string=self.ref_align_string,
                base_traj=self.base_traj,
                progbar=self.progbar,
            )

            # nglview
            self._cluster.center_coordinates()
            view = nv.show_mdtraj(self._cluster, gui=False)
            self.ngl_area.children = [view]
            if self.ball_and_stick:
                view.clear_representations()
                view.add_representation("ball+stick")

    # trace
    if self.highd is not None:
        d = self.highd[self.trajs.CVs[self._cluster_col] == cluster_id]
        self.trace_widget.data[0].z = d.T

    # Render the image in a background thread (threading is complicated in
    # IPython). The file is read back later when the cluster is saved.
    filename = Path("/tmp/tmp.png")
    thread = threading.Thread(
        target=render_image,
        args=(view, filename),
        daemon=True,
    )
    thread.start()

    # clear progbar
    self.progbar_description = ""

    # clear display
    self.display.outputs = []
    with self.display:
        print("Finished clustering.")

    self.ngl_area.children = [view]
    self.view = view

    # pandas
    with self.pandas_info_area:
        display(self.cluster_output.dash_summary())
def save(self, b):
    """Saves the generated path or the current cluster to disk.

    Callback of the 'Save' button (`b` is not used). A generated path
    (`self.path_output`) takes precedence over a cluster
    (`self.cluster_output`). Saving a path additionally requires `imageio`
    and `moviepy` to be importable.

    Returns:
        Optional[Path]: The directory the files were written into, or None if
            nothing was saved.

    """
    if self.cluster_output is None and self.path_output is None:
        self.display.outputs = []
        with self.display:
            print(
                "Please select a cluster or a path and hit 'Generate' or "
                "'Cluster', before 'Save'."
            )
        return None

    if self.path_output is not None:  # path save
        try:
            # Third Party Imports
            import imageio  # noqa: F401
            import moviepy  # noqa: F401
        except (ImportError, NameError):
            # ImportError also covers ModuleNotFoundError and broken installs
            self.display.outputs = []
            with self.display:
                print("Please install moviepy, imageio and ffmpeg")
            return None
        with ProgressWidgetTqdmCompatible(
            container=self.container,
            empty=self.progbar_empty,
            total=5,
            description="Saving..",
        ) as self.progbar:
            fname = self._save_path_on_disk()
        self.display.outputs = []
        with self.display:
            print(f"Path saved at {fname}")
        return fname

    # cluster save
    with ProgressWidgetTqdmCompatible(
        container=self.container,
        empty=self.progbar_empty,
        total=5,
        description="Saving...",
    ) as self.progbar:
        fname = self._save_cluster_on_disk()
    self.display.outputs = []
    with self.display:
        print(f"Cluster saved at {fname}")
    return fname
def _save_path_on_disk(self) -> Path:
    """Saves the generated path in `self.path_output` to disk.

    The files are written into a new time-stamped directory
    `generated_paths/<now>` inside `self.main_path`:
        * generated.pdb / generated.xtc: the generated conformations.
        * path.npy: the path in low-dimensional coordinates (`self.path`).
        * lowd.csv: the low-dimensional data of `self.trajs`.
        * path.png: a free-energy plot and a scatter plot of `self.trajs`,
          plus the path as a line (added without row/col, i.e. onto the
          first subplot's axes).
        * README.md: rendered from `PATH_TEMPLATE`.

    Returns:
        Path: The directory the files were written into.

    """
    now = _datetime_windows_and_linux_compatible()
    output = self.main_path / f"generated_paths/{now}"
    output.mkdir(parents=True, exist_ok=True)

    # define some files
    # fmt: off
    xtc_file = output / "generated.xtc"  # the generated conformations
    pdb_file = output / "generated.pdb"  # the generated conformations
    npy_file = output / "path.npy"  # the path in low-dimensional coordinates
    png_file = output / "path.png"  # an overview image of the path
    csv_file = output / "lowd.csv"  # A csv file for later plotting the lowd
    md_file = output / "README.md"  # A readme filled by jinja
    # fmt: on

    # save the generated conformations
    self.path_output.save_pdb(str(pdb_file))
    self.path_output.save_xtc(str(xtc_file))
    self.progbar.update()

    # save the path
    np.save(npy_file, self.path)
    self.progbar.update()

    # TODO: save an animation of the path (e.g. with nglview's MovieMaker).
    # Not implemented yet; the progress step is kept.
    self.progbar.update()

    # save the lowd as csv
    df = self.trajs.to_dataframe(CV=["lowd"])
    df["x"] = df.pop("LOWD FEATURE 0")
    df["y"] = df.pop("LOWD FEATURE 1")
    df.to_csv(csv_file)
    self.progbar.update()

    # save a png similar to cluster
    fig = make_subplots(rows=1, cols=2)
    fig.add_trace(
        _plot_free_energy(
            x=self.trajs.lowd[:, 0],
            y=self.trajs.lowd[:, 1],
            cbar=True,
            colorbar_x=0.45,
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        plot_trajs_by_parameter(
            self.trajs,
            "traj_num",
            type="scatter",
            show=False,
        ).data[0],
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=self.path[:, 0],
            y=self.path[:, 1],
            mode="lines",
            name="",
        )
    )
    fig.data[1].marker.colorscale = px.colors.get_colorscale("Viridis")

    fig.update_layout(
        {
            "width": 800,
            "height": 300,
            "xaxis1": {"title": "x in a.u."},
            "xaxis2": {"title": "x in a.u."},
            "yaxis1": {"title": "y in a.u."},
            "yaxis2": {"title": "y in a.u."},
            "autosize": True,
            "margin": {
                "l": 0,
                "r": 0,
                "t": 0,
                "b": 0,
            },
        },
    )
    fig.write_image(png_file, engine="kaleido", width=1500, height=500, scale=2)
    self.progbar.update()

    # save a README
    # Local Folder Imports
    from .._version import get_versions

    _ensemble_type = "single traj"
    if self.trajs.__class__.__name__ == "TrajEnsemble":
        if self.trajs.n_trajs > 1:
            _ensemble_type = "trajectory ensemble"

    info_dict = {
        "platform": platform.system(),
        "system_user": getpass.getuser(),
        "platform_release": platform.release(),
        "platform_version": platform.version(),
        "architecture": platform.machine(),
        "hostname": socket.gethostname(),
        "ip_address": socket.gethostbyname(socket.gethostname()),
        "mac_address": ":".join(re.findall("..", "%012x" % uuid.getnode())),
        "processor": platform.processor(),
        "ram": str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB",
        "pip_freeze": "\n".join(list(freeze())),
        "n_frames": self.trajs.n_frames,
        "n_trajs": self.trajs.n_trajs,
        "n_top": len(self.trajs.top),
        "encodermap_version": get_versions()["version"],
        "filename": str(md_file.resolve()),
        "n_points": len(self.path),
        "ensemble_type": _ensemble_type,
        "csv_file": str(csv_file.resolve()),
        "pdb_file": str(pdb_file.resolve()),
        "xtc_file": str(xtc_file.resolve()),
        "npy_file": str(npy_file.resolve()),
        "autoencoder_class": self.autoencoder.__class__.__name__,
        "chosen_top": self.top_selector.options[self.top_selector.value],
        "lowd_dim": self.lowd.shape[1],
    }
    # assert _check_all_templates_defined(PATH_TEMPLATE, info_dict)
    template = jinja2.Template(PATH_TEMPLATE)
    readme_text = template.render(info_dict)
    md_file.write_text(readme_text)
    self.progbar.update()

    return output
def _save_cluster_on_disk(self) -> Path:
    """Saves the cluster in self.cluster_output to disk.

    Also writes a README.md and puts images into a directory. The files go
    into a new time-stamped directory `clusters/<now>` inside
    `self.main_path`:
        * cluster_<id>.h5: the cluster (`self.cluster_output.save`).
        * cluster_<id>.csv: lowd, cluster ids and, for an
          `AngleDihedralCartesianEncoderMap`, the CVs it was trained on.
        * cluster_<id>_selector.npy: the vertices of the lasso/polygon
          selector.
        * cluster_<id>.png: free energy, cluster membership and the rendered
          cluster image read from /tmp/tmp.png.
        * README.md: rendered from `CLUSTER_TEMPLATE`.

    Afterwards `self.selector` is deleted and `self.cluster_output` is reset
    to None.

    Returns:
        Path: The directory the files were written into.

    """
    now = _datetime_windows_and_linux_compatible()
    output = self.main_path / f"clusters/{now}"
    output.mkdir(parents=True, exist_ok=True)
    cluster_num = self.trajs.CVs[self._cluster_col].max()

    # define some files
    # fmt: off
    h5_file = output / f"cluster_{cluster_num}.h5"  # save the cluster as h5 ensemble
    csv_file = output / f"cluster_{cluster_num}.csv"  # the complete ensemble as a pandas array
    md_file = output / "README.md"  # A readme filled by jinja
    png_name = output / f"cluster_{cluster_num}.png"  # A render of the cluster
    npy_file = output / f"cluster_{cluster_num}_selector.npy"  # The xs and ys of the selector
    # fmt: on

    # save the cluster
    self.cluster_output.save(h5_file)
    self.progbar.update()

    # save the pandas
    CVs = ["lowd", self._cluster_col]
    if (
        isinstance(self.autoencoder, AngleDihedralCartesianEncoderMap)
        or self.autoencoder.__class__.__name__ == "AngleDihedralCartesianEncoderMap"
    ):
        CVs.append("central_dihedrals")
        if self.autoencoder.p.use_backbone_angles:
            CVs.append("central_angles")
        if self.autoencoder.p.use_sidechains:
            CVs.append("side_dihedrals")
    df = self.trajs.to_dataframe(CV=CVs)
    df["cluster_id"] = df.pop(self._cluster_col.upper() + " FEATURE")
    df["x"] = df.pop("LOWD FEATURE 0")
    df["y"] = df.pop("LOWD FEATURE 1")
    df.to_csv(csv_file)
    self.progbar.update()

    # save the selector
    verts = np.vstack([self.selector.xs, self.selector.ys]).T
    np.save(npy_file, verts)
    self.progbar.update()

    # create a png
    # the rendered cluster image was written to /tmp/tmp.png by a
    # background thread when the cluster was created
    with Image.open("/tmp/tmp.png") as im:
        im = np.array(im).copy()

    fig = make_subplots(rows=1, cols=3)
    fig.add_trace(
        _plot_free_energy(
            x=self.trajs.lowd[:, 0],
            y=self.trajs.lowd[:, 1],
        ),
        row=1,
        col=1,
    )
    with set_env(ENCODERMAP_SKIP_SCATTER_SIZE_CHECK="True"):
        fig.add_trace(
            plot_trajs_by_parameter(
                self.trajs,
                self._cluster_col,
                type="scatter",
                z_name_overwrite="cluster id",
                show=False,
            ).data[0],
            row=1,
            col=2,
        )
    fig.add_trace(
        px.imshow(im).data[0],
        row=1,
        col=3,
    )

    fig.data[1].marker.colorscale = px.colors.get_colorscale("Viridis")

    fig.update_layout(
        {
            "width": 1000,
            "height": 300,
            "xaxis1": {"title": "x in a.u."},
            "xaxis2": {"title": "x in a.u."},
            "xaxis3": {
                "showticklabels": False,
            },
            "yaxis1": {"title": "y in a.u."},
            "yaxis2": {"title": "y in a.u."},
            "yaxis3": {
                "showticklabels": False,
            },
            "coloraxis_showscale": False,
            "autosize": True,
            "margin": {
                "l": 0,
                "r": 0,
                "t": 0,
                "b": 0,
            },
        },
    )
    fig.update_traces(
        dict(
            showscale=False,
            coloraxis=None,
        ),
        selector={"type": "heatmap"},
    )
    fig.write_image(png_name, engine="kaleido", width=1500, height=500, scale=2)
    self.progbar.update()

    # save a README
    # Local Folder Imports
    from .._version import get_versions

    info_dict = {
        "platform": platform.system(),
        "system_user": getpass.getuser(),
        "platform_release": platform.release(),
        "platform_version": platform.version(),
        "architecture": platform.machine(),
        "hostname": socket.gethostname(),
        "ip_address": socket.gethostbyname(socket.gethostname()),
        "mac_address": ":".join(re.findall("..", "%012x" % uuid.getnode())),
        "processor": platform.processor(),
        "ram": str(round(psutil.virtual_memory().total / (1024.0**3))) + " GB",
        "pip_freeze": "\n".join(list(freeze())),
        "h5_file": str(h5_file.resolve()),
        "n_frames": self.trajs.n_frames,
        "n_trajs": self.trajs.n_trajs,
        "n_top": len(self.trajs.top),
        "cluster_num": cluster_num,
        "h5_info": jinja2.Template(H5_INFO).render(
            {"h5_file": str(h5_file.resolve())}
        ),
        "encodermap_version": get_versions()["version"],
        "filename": str(md_file.resolve()),
        "csv_file": str(csv_file.resolve()),
    }
    # assert _check_all_templates_defined(CLUSTER_TEMPLATE, info_dict)
    template = jinja2.Template(CLUSTER_TEMPLATE)
    readme_text = template.render(
        info_dict,
    )
    md_file.write_text(readme_text)
    self.progbar.update()

    # cleanup: reset `cluster_output` to None instead of deleting the
    # attribute, so a later `self.cluster_output is None` check still works
    del self.selector
    self.cluster_output = None
    self.progbar.update()
    return output
def scatter_on_click(self, trace, points, selector):
    """Shows the conformation behind a clicked point of the scatter plot.

    Plotly click callback of the scatter trace. It enlarges the clicked
    marker, shows its high-dimensional trace (if `self.highd` is set),
    renders the centered frame in the NGL area and prints the frame's
    summary into the pandas info area.

    """
    # reset the message and info areas
    for area in (self.display, self.pandas_info_area):
        area.outputs = []

    # index of the clicked point (first one if several were hit)
    clicked = points.point_inds[0]

    try:
        with self.display:
            print(
                f"Displaying conformation {clicked} for "
                f"{self.file_arr[clicked]} {self.frame_arr[clicked]}"
            )
    except Exception as exc:
        with self.display:
            print(
                f"Can't display point {clicked} due to error: {exc}. The "
                f"shapes of the file and frame arrays are "
                f"{self.file_arr.shape}, {self.frame_arr.shape}"
            )

    # enlarge the clicked marker
    sizes = self.base_sizes.copy()
    sizes[clicked] = 20
    with self.figure_widget.batch_update():
        self.scatter_data.marker.size = sizes

    # show the high-dimensional trace of this point as a single column
    if self.highd is not None:
        self.trace_widget.data[0].z = np.expand_dims(self.highd[clicked], 0).T

    # nglview: render a centered copy, leaving the original frame untouched
    frame = self.trajs.get_single_frame(clicked)
    centered = deepcopy(frame.traj).center_coordinates()
    view = nv.show_mdtraj(centered, gui=False)
    if self.ball_and_stick:
        view.clear_representations()
        view.add_representation("ball+stick")
    self.ngl_area.children = [view]
    self.view = view

    # pandas
    with self.pandas_info_area:
        display(frame.dash_summary())
def on_select(self, trace, points, selector):
    """Adds points selected with the lasso/polygon tool to the selection.

    The newly selected indices are merged into `self.selected_point_ids`.
    The result is kept as a sorted, duplicate-free integer array, so it can
    be used for indexing and the reported count is correct. The selected
    markers are enlarged, and a deep copy of `selector` is stored in
    `self.selector`.

    """
    self.display.outputs = []
    # `np.asarray([])` would be float64 and poison the id array for
    # indexing, so force an integer dtype. `union1d` also drops ids that
    # were selected more than once.
    new_ids = np.asarray(points.point_inds, dtype=int)
    self.selected_point_ids = np.union1d(self.selected_point_ids, new_ids).astype(
        int
    )
    self.selector = deepcopy(selector)
    sizes = self.base_sizes.copy()
    sizes[self.selected_point_ids] = 15
    with self.figure_widget.batch_update():
        self.scatter_data.marker.size = sizes
    with self.display:
        print(
            f"Selected {len(self.selected_point_ids)} points. Hit 'cluster' to view."
        )
def switch_between_density_and_scatter(self, b):
    """Toggles the main plot area between the scatter plot and the canvas.

    Switching to the density canvas sets the 'Size' slider to 100;
    switching back to the scatter plot sets it to 10. `b` is not used.

    """
    to_density = self._graph == "scatter"
    shown = self.canvas if to_density else self.figure_widget
    self.container[2:6, :3] = widgets.Box(
        [shown],
        layout=widgets.Layout(height="auto", width="auto"),
    )
    self.slider.value = 100 if to_density else 10
    self.container[7, 4:] = self.slider
    self._graph = "density" if to_density else "scatter"
def help(self, n):
    """Opens the dashboard help page in a new browser tab via Javascript.

    Args:
        n: The event/button that triggered the callback (unused).

    """
    # Third Party Imports
    from IPython.display import Javascript

    # `_help_url` may be a plain URL string (it is formatted straight into
    # an <a href=...>), so only use a `.tooltip` attribute if one exists.
    url = getattr(self._help_url, "tooltip", self._help_url)
    # Display into the dashboard's rendered output area. A freshly created
    # `widgets.Output` that is never displayed would capture the Javascript
    # without ever executing it.
    with self.display:
        display(Javascript(f'window.open("{url}");'))
def on_canvas_mouse_down(self, x, y):
    """Starts a new freehand path at the pressed canvas position (x, y)."""
    start = (x, y)
    self.canvas_drawing = True
    self.canvas_position = start
    self.canvas_path = [start]
def on_canvas_mouse_up(self, x, y):
    """Stops drawing on the canvas and tells the user what to do next.

    With more than one topology in `self.trajs`, the user is asked to pick
    one from the dropdown before generating.

    Args:
        x: Canvas x-coordinate of the mouse (unused).
        y: Canvas y-coordinate of the mouse (unused).

    """
    self.canvas_drawing = False

    self.display.outputs = []
    with self.display:
        if len(self.trajs.top) > 1:
            # the button is labelled 'Generate'; the old text said 'Gnerate'
            print(
                "Select a topology from the Dropdown menu and hit "
                "'Generate' to generate new molecular conformations."
            )
        else:
            print("Click 'Generate' to generate new molecular conformations")
def on_canvas_mouse_move(self, x, y):
    """Extends the freehand path to (x, y) while the mouse button is held.

    Draws a line segment from the previous position to (x, y) and records
    the new position in `self.canvas_path`. Does nothing while not drawing.

    """
    if self.canvas_drawing:
        with hold_canvas():
            self.canvas.stroke_line(*self.canvas_position, x, y)
        self.canvas_position = (x, y)
        self.canvas_path.append(self.canvas_position)
1810 def _canvas_path_in_data_coords(self) -> np.ndarray:
1811 """Returns the path coordinates in data coordinates.
1813 Returns:
1814 np.ndarray: An array of shape (n_points, 2) containing the
1815 data coordinates. [:, 0] are the x-coordinates and
1816 [:, 1] are the y-coordinates.
1818 """
1819 # Third Party Imports
1820 from scipy.interpolate import interp1d
1822 path = np.array(self.canvas_path)
1823 x = interp1d([0, 500], [self.lowd[:, 0].min(), self.lowd[:, 0].max()])
1824 x = x(path[:, 0])
1825 y = self.figure_widget.layout.yaxis.range
1826 y = interp1d([500, 0], [self.lowd[:, 1].min(), self.lowd[:, 1].max()])
1827 y = y(path[:, 1])
1828 verts = np.vstack([x, y]).T
1829 distance = np.cumsum(
1830 np.sqrt(
1831 np.ediff1d(verts[:, 0], to_begin=0) ** 2
1832 + np.ediff1d(verts[:, 1], to_begin=0) ** 2
1833 )
1834 )
1835 distance = distance / distance[-1]
1836 fx, fy = interp1d(distance, verts[:, 0]), interp1d(distance, verts[:, 1])
1837 alpha = np.linspace(0, 1, self.slider.value)
1838 path = np.vstack([fx(alpha), fy(alpha)]).T
1839 return path
1841 def _setup_graph(self):
1842 # text areas
1843 self._graph = "scatter"
1844 self.container = widgets.GridspecLayout(n_rows=10, n_columns=7, height="1000px")
1845 self.header = widgets.HTML(
1846 value=f"<h2>EncoderMap Dashboard for {self._username} in {self._debug_main_path}</h2>",
1847 layout=widgets.Layout(height="auto", width="auto"),
1848 )
1849 self.display = widgets.Output(
1850 layout=widgets.Layout(height="auto", width="auto")
1851 )
1852 with self.display:
1853 print(
1854 "Interact with the Scatter Plot to view molecular conformations. "
1855 "Select points with the lasso tool and click 'cluster' "
1856 "to generate a cluster. Switch to 'Density' to draw a Path and "
1857 "generate new conformations."
1858 )
1860 # the traj options
1861 options = [
1862 (f"{i + 1}: " + str(top).lstrip("<mdtraj.Topology with ").rstrip(">"), i)
1863 for i, top in enumerate(self.trajs.top)
1864 ]
1865 if len(self.trajs.top) == len(self.trajs.common_str):
1866 for top, sub_trajs in self.trajs.trajs_by_top.items():
1867 if len(sub_trajs.common_str) != 1:
1868 break
1869 else:
1870 options = [(cs, i) for i, cs in enumerate(self.trajs.common_str)]
1871 self.top_selector = widgets.Dropdown(options=options, description="Top:")
1873 # some placeholders
1874 self.ngl_area = widgets.Box(layout=widgets.Layout(height="auto", width="auto"))
1875 self.progbar_empty = widgets.Output(
1876 layout=widgets.Layout(height="auto", width="auto")
1877 )
1878 self.pandas_all_area = widgets.Output(
1879 layout=widgets.Layout(height="auto", width="auto")
1880 )
1881 self.pandas_info_area = widgets.Output(
1882 layout=widgets.Layout(height="auto", width="auto")
1883 )
1885 # slider
1886 self.slider = widgets.IntSlider(
1887 value=10,
1888 min=1,
1889 max=self._max_slider_len,
1890 description="Size",
1891 continuous_update=False,
1892 layout=widgets.Layout(height="auto", width="auto"),
1893 )
1895 # buttons
1896 self.help_button = widgets.HTML(
1897 value=(
1898 f'<a href={self._help_url}><div class="lm-Widget jupyter-widgets '
1899 f'jupyter-button widget-button mod-info" style="height: 50%; '
1900 f"width: 100%; grid-area: widget007; margin: auto; margin-top: 25px; display: "
1901 f'flex; align-items: center; justify-content: center;">'
1902 f'<i class="fa fa-info"></i>Help</div></a>'
1903 )
1904 )
1905 self.cluster_button = widgets.Button(
1906 description="Cluster",
1907 icon="th",
1908 button_style="info",
1909 layout=widgets.Layout(height="auto", width="auto"),
1910 tooltip=(
1911 "After selecting points with the Lasso Tool, this button will "
1912 "display a subset of the selected point in the display area. Use "
1913 "the 'Size' slider to choose how many representative structures of "
1914 "the selected cluster you want to have displayed."
1915 ),
1916 )
1917 self.generate_button = widgets.Button(
1918 description="Generate",
1919 icon="bezier-curve",
1920 button_style="info",
1921 tooltip=(
1922 "Use the decoder part of the autoencoder to create new molecular "
1923 "conformations from a path, that you have drawn with the 'Draw "
1924 "open freeform' Tool. The 'Size' slider will choose how many "
1925 "conformations to create along the path."
1926 ),
1927 layout=widgets.Layout(height="auto", width="auto"),
1928 )
1929 self.save_button = widgets.Button(
1930 description="Save",
1931 icon="floppy-o",
1932 button_style="info",
1933 layout=widgets.Layout(height="auto", width="auto"),
1934 )
1935 self.density_button = widgets.Button(
1936 description="Density",
1937 icon="bar-chart",
1938 button_style="info",
1939 layout=widgets.Layout(height="auto", width="auto"),
1940 tooltip=("This button toggles between a density and a scatter plot."),
1941 )
1943 # plots
1944 self.heatmap = go.Heatmap(
1945 z=[],
1946 showlegend=False,
1947 showscale=False,
1948 colorscale="Viridis",
1949 hoverinfo="skip",
1950 name="",
1951 hovertemplate="",
1952 )
1954 # this array prepares the selection
1955 self.selected_point_ids = np.array([]).astype(int)
1957 # set up the canvas for drawing
1958 img = go.Figure(
1959 data=[self.density],
1960 layout={
1961 "margin": {
1962 "t": 0,
1963 "b": 0,
1964 "l": 0,
1965 "r": 0,
1966 },
1967 "yaxis_visible": True,
1968 "xaxis_visible": True,
1969 },
1970 )
1971 stream = img.to_image(format="png", width=500, height=500)
1972 background_image = widgets.Image(
1973 value=stream,
1974 format="png",
1975 width=500,
1976 height=500,
1977 )
1978 self.canvas = Canvas(width=500, height=500)
1979 self.canvas.draw_image(background_image)
1980 self.canvas_drawing = False
1981 self.canvas_position = None
1982 self.canvas_path = []
1983 self.canvas.on_mouse_down(self.on_canvas_mouse_down)
1984 self.canvas.on_mouse_move(self.on_canvas_mouse_move)
1985 self.canvas.on_mouse_up(self.on_canvas_mouse_up)
1986 self.canvas.stroke_style = "#749cb8"
1988 # main figure widget
1989 self.figure_widget = go.FigureWidget(
1990 data=[self.scatter],
1991 layout=self.layout,
1992 )
1993 self.scatter_data = self.figure_widget.data[0]
1994 self.base_colors = self.marker_colors
1995 self.scatter_data.marker.color = self.marker_colors
1996 self.base_sizes = np.array([8] * len(self.lowd))
1997 self.scatter_data.marker.size = self.base_sizes
1999 # the animation widget
2000 self.path_anim_widget = go.FigureWidget(
2001 data=[
2002 self.density,
2003 go.Scatter(
2004 x=[0, 0],
2005 y=[0, 0],
2006 mode="lines",
2007 hovertemplate="Generation Path (%{x:.2f}, %{y:.2f})",
2008 showlegend=False,
2009 ),
2010 go.Scatter(
2011 x=[0, 0],
2012 y=[0, 0],
2013 mode="markers",
2014 marker_size=12,
2015 marker_line_width=2,
2016 hovertemplate="Current Path (%{x:.2f}, %{y:.2f})",
2017 showlegend=False,
2018 ),
2019 ],
2020 layout=self.layout,
2021 )
2023 # the trace widget
2024 if self.highd is not None:
2025 self.trace_widget = go.FigureWidget(
2026 data=[self.heatmap],
2027 layout=go.Layout(
2028 {
2029 "width": 50,
2030 "modebar_remove": BAD_MODEBAR_BUTTONS,
2031 "yaxis_visible": False,
2032 "xaxis_visible": False,
2033 "title": "Trace",
2034 "height": 500,
2035 "margin": {
2036 "t": 25,
2037 "b": 75,
2038 "l": 10,
2039 "r": 10,
2040 },
2041 }
2042 ),
2043 )
2045 # responsiveness
2046 self.scatter_data.on_click(self.scatter_on_click)
2047 self.scatter_data.on_selection(self.on_select)
2048 self.cluster_button.on_click(self.cluster)
2049 self.density_button.on_click(self.switch_between_density_and_scatter)
2050 self.save_button.on_click(self.save)
2051 self.generate_button.on_click(self.generate)
2053 # add the elements to the grid
2054 self.container[0, :] = self.header
2055 self.container[1, :-1] = self.display
2056 self.container[1, -1] = self.top_selector
2057 self.container[2:6, :3] = widgets.Box(
2058 [self.figure_widget],
2059 layout=widgets.Layout(
2060 height="auto",
2061 width="auto",
2062 ),
2063 )
2064 if self.highd is not None:
2065 self.container[2:6, 3] = widgets.Box(
2066 [self.trace_widget],
2067 layout=widgets.Layout(
2068 height="auto",
2069 width="auto",
2070 ),
2071 )
2072 self.container[2:6, 4:] = self.ngl_area
2073 self.container[6, 1:] = self.progbar_empty
2074 self.container[6, 0] = self.help_button
2075 self.container[7, 0] = self.cluster_button
2076 self.container[7, 1] = self.generate_button
2077 self.container[7, 2] = self.save_button
2078 self.container[7, 3] = self.density_button
2079 self.container[7, 4:] = self.slider
2080 self.container[8:, :3] = self.pandas_all_area
2081 self.container[8:, 4:] = self.pandas_info_area
2082 with self.pandas_all_area:
2083 display(self.trajs.dash_summary())
2085 # self.container = widgets.VBox([
2086 # self.header,
2087 # self.display,
2088 # self.figure_widget,
2089 # ])
2091 display(self.container)
2094# class InteractivePlottingDep:
2095# """Class to open up an interactive plotting window.
2096#
2097# Contains subclasses to handle user-clickable menus and selectors.
2098#
2099# Attributes:
2100# trajs (encodermap.TrajEnsemble): The trajs passed into this class.
2101# fig (matplotlib.figure): The figure plotted onto. If ax is passed when
2102# this class is instantiated, the parent figure will be fetched with
2103# self.fig = self.ax.get_figure()
2104# ax (matplotlib.axes): The axes where the lowd data of the trajs
2105# is plotted on.
2106# menu_ax (matplotlib.axes): The axes on which the normal menu is plotted.
2107# status_menu_ax (matplotlib.axes): The axes on which the status menu is plotted.
2108# pts (matplotlib.collections.Collection): The points which are plotted. Based on some
2109# other class variables, the color of this collection is adjusted.
2110# statusmenu (encodermap.plot.utils.StatusMenu): The menu containing the
2111# status buttons.
2112# menu (encodermap.plot.utils.Menu): The menu containing the remaining buttons.
2113# tool (encodermap.plot.utils.SelectFromCollection): The current active
2114# tool used to select points. This can be lasso, polygon, etc...
2115# mode (str): Current mode of the statusmenu.
2116#
2117# """
2118#
2119# def __init__(
2120# self,
2121# autoencoder,
2122# trajs=None,
2123# data=None,
2124# ax=None,
2125# align_string="name CA",
2126# top=None,
2127# hist=False,
2128# scatter_kws={"s": 5},
2129# ball_and_stick=False,
2130# top_index=0,
2131# ):
2132# """Instantiate the InteractivePlotting class.
2133#
2134# Args:
2135# trajs (encodermap.TrajEnsemble): The trajs of which the lowd info
2136# should be plotted.
2137# ax (matplotlib.axes, optional): On what axes to plot. If no axis is provided
2138# a new figure and axes will be created, defaults to None.
2139#
2140# """
2141# # the align string for the cluster dummy method
2142# self.align_string = align_string
2143# self.top = top
2144# self.hist = hist
2145# self.autoencoder = autoencoder
2146# self.ball_and_stick = ball_and_stick
2147# self.top_index = top_index
2148#
2149# # scatter kws
2150# self.scatter_kws = {**{"s": 80, "alpha": 0.5}, **scatter_kws}
2151#
2152# # close all plots
2153# plt.close("all")
2154#
2155# # decide on fate of data
2156# if data is None:
2157# if hasattr(trajs, "lowd"):
2158# print("Using the attribute `lowd` of provided `trajs`")
2159# data = trajs.lowd
2160# elif isinstance(trajs, (TrajEnsemble, SingleTraj)) and (
2161# isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
2162# or autoencoder.__class__.__name__ == "AngleDihedralCartesianEncoderMap"
2163# ):
2164# print(
2165# "Using the provided `autoencoder` and `trajs` to create a projection."
2166# )
2167# data = autoencoder.encode(trajs)
2168# elif isinstance(data, np.ndarray) and hasattr(autoencoder, "encode"):
2169# print("Using the `encode` method of `autoencoder` with provided data.")
2170# if np.any(np.isnan(data)):
2171# # Third Party Imports
2172# import tensorflow as tf
2173#
2174# indices = np.stack(np.where(~np.isnan(data))).T.astype("int64")
2175# dense_shape = data.shape
2176# values = data[~np.isnan(data)].flatten().astype("float32")
2177# data = tf.sparse.SparseTensor(indices, values, dense_shape)
2178# data = autoencoder.encode(data)
2179# elif hasattr(autoencoder, "encode"):
2180# print("Using the `train_data` attribute of `autoencoder`.")
2181# data = autoencoder.encode()
2182# else:
2183# print("Mocking data with np.random")
2184# np.random.seed(19680801)
2185# data = np.random.rand(100, 2)
2186# if data.shape[1] != 2:
2187# print("Using provided `data` to call encoder.")
2188# data = autoencoder.encode(data)
2189# self.data = data
2190#
2191# # see what traj has been provided
2192# if trajs is None:
2193# self.trajs = autoencoder.trajs
2194# else:
2195# if isinstance(trajs, str):
2196# self.trajs = SingleTraj(trajs, self.top, traj_num=0)._gen_ensemble()
2197# elif isinstance(trajs, list):
2198# self.trajs = TrajEnsemble(trajs, self.top)
2199# else:
2200# self.trajs = trajs
2201#
2202# if isinstance(trajs, SingleTraj):
2203# if "lowd" not in self.trajs.CVs:
2204# self.trajs.load_CV(self.data, attr_name="lowd")
2205# else:
2206# if "lowd" not in self.trajs.CVs:
2207# self.trajs.load_CVs(self.data, attr_name="lowd")
2208#
2209# # decide what function to use to build clusters
2210# # Decided against gen_dummy traj as get_cluster_frames works better with jinja2
2211# self.cluster_building_fn = get_cluster_frames
2212#
2213# # create fig and ax
2214# if ax is None:
2215# # create fig and ax
2216# subplot_kw = dict(xlim=(0, 1), ylim=(0, 1), autoscale_on=True)
2217# self.fig, self.ax = plt.subplots(
2218# 1, 1, figsize=(10, 8)
2219# ) # subplot_kw=subplot_kw)
2220# else:
2221# self.ax = ax
2222# self.fig = self.ax.get_figure()
2223#
2224# # add the axes to create the menus on
2225# self.fig.subplots_adjust(left=0.3)
2226# self.menu_ax = plt.axes([0.05, 0.1, 0.15, 0.35], facecolor="lightblue")
2227# self.status_menu_ax = plt.axes(
2228# [0.05, 0.49, 0.15, 0.35], facecolor="lightyellow"
2229# )
2230#
2231# # remove everything in these axes
2232# self.menu_ax.axis("off")
2233# self.status_menu_ax.axis("off")
2234# self.tool = DummyTool()
2235#
2236# # plot
2237# self.pts = self.ax.scatter(self.data[:, 0], self.data[:, 1], **self.scatter_kws)
2238#
2239# # hist
2240# if self.hist:
2241# self.ax.hist2d(*data.T, bins=400, norm=mpl.colors.LogNorm())
2242#
2243# # Check whether mouse enters drawing area
2244# # Upon entering drawing area tools are initialized based on current mode
2245# # Leave event currently serves no purpose
2246# self.cid_ax_enter = self.fig.canvas.mpl_connect(
2247# "axes_enter_event", self.on_enter_ax
2248# )
2249# self.cid_ax_leave = self.fig.canvas.mpl_connect(
2250# "axes_leave_event", self.on_leave_ax
2251# )
2252#
2253# check button presses and compare them with the status of the menu items
2254# self.cid_on_click = self.fig.canvas.mpl_connect(
2255# "button_release_event", self.on_click
2256# )
2257#
2258# # Instantiate Menu
2259# self.statusmenu = StatusMenu(self.status_menu_ax)
2260# self.menu = Menu(self.menu_ax)
2261#
2262# # Show
2263# plt.show()
2264#
2265# def on_click(self, event):
2266# """Decides whether the release event happened in the drawing area or the menu.
2267#
2268# Args:
2269# event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().
2270#
2271# """
2272# if event.inaxes == self.ax:
2273# self.on_click_tool(event)
2274# else:
2275# self.on_click_menu(event)
2276#
2277# def on_enter_ax(self, event):
2278# """Chooses the tool to use when self.ax is entered, based on the current mode.
2279#
2280# Args:
2281# event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().
2282#
2283# """
2284# # print('Axis is entered')
2285# if event.inaxes is self.ax and self.mode == "Idle":
2286# # reset point coloration
2287# self.pts.set_color("C0")
2288# self.tool.disconnect()
2289# if event.inaxes is self.ax and self.mode != "Idle":
2290# # statusmenu
2291# for key, item in self.statusmenu.menuitems.items():
2292# if self.mode == key:
2293# method = getattr(self, key.lower())
2294# method()
2295#
2296# def on_leave_ax(self, event):
2297# """Called when the mouse leaves an axes; currently a no-op (the current tool is not disconnected)."""
2298# pass
2299#
2300# def on_click_tool(self, event):
2301# """Left here for convenience if some tools need a button release event."""
2302# pass
2303#
2304#
2305# def on_click_menu(self, event):
2306# """Chooses the function to call based on what MenuItem was clicked.
2307#
2308# Args:
2309# event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().
2310#
2311# """
2312# for key, item in self.menu.menuitems.items():
2313# if item.check_select(event):
2314# method = getattr(self, key.lower().replace(" ", "_"))
2315# method()
2316#
2317# def reset(self):
2318# """Called when 'Reset' is pressed."""
2319# if "user_selected_points" in self.trajs.CVs:
2320# self.trajs._CVs.drop(labels="user_selected_points")
2321# self.__init__(
2322# self.trajs,
2323# self.autoencoder,
2324# self.data,
2325# None,
2326# self.align_string,
2327# self.top,
2328# self.hist,
2329# self.scatter_kws,
2330# self.ball_and_stick,
2331# )
2332#
2333# def write(self):
2334# """Called when 'Write' is pressed."""
2335# if self.mode == "Idle":
2336# return
2337# time = _datetime_windows_and_linux_compatible()
2338# if self.mode == "Bezier" or self.mode == "Path":
2339# os.makedirs(
2340# f"{self.autoencoder.p.main_path}/generated_paths/", exist_ok=True
2341# )
2342# fname = (
2343# f"{self.autoencoder.p.main_path}/generated_paths/generated_{time}.pdb"
2344# )
2345# with mda.Writer(fname) as w:
2346# for step in self.uni.trajectory:
2347# w.write(self.uni.atoms)
2348# self.ax.set_title(
2349# f"Generated Path with {len(generated)} points saved at {fname}"
2350# )
2351# else:
2352# if "user_selected_points" not in self.trajs.CVs:
2353# self.ax.set_title("First set the points before writing them to disk.")
2354# return
2355# max_, fname = _unpack_cluster_info(
2356# self.trajs,
2357# self.autoencoder.p.main_path,
2358# self.tool,
2359# self.dummy_traj,
2360# self.align_string,
2361# )
2362# self.ax.set_title(f"Cluster {max_} saved at {fname}")
2363#
2364# def set_points(self):
2365# """Called when 'Set Points' is pressed."""
2366# if self.mode == "Idle":
2367# return
2368# if self.mode != "Idle":
2369# if "tool" not in self.__dict__.keys():
2370# self.ax.set_title(f"Tool {self.mode} not yet implemented.")
2371# return
2372# else:
2373# indices = self.accept()
2374# if self.mode == "Bezier" or self.mode == "Path":
2375# if np.unique(self.path_points, axis=0).shape[0] != 200:
2376# self.ax.set_title(
2377# f"Tool {self.mode} returned not the requested unique points."
2378# )
2379# return
2380# self.dummy_traj = self.autoencoder.generate(
2381# self.path_points, backend="mdanalysis", top=self.top_index
2382# )
2383# self.view = ngl.show_mdanalysis(self.dummy_traj)
2384# if self.ball_and_stick:
2385# self.view.clear_representations()
2386# self.view.add_ball_and_stick()
2387#
2388# self.ax.set_title(
2389# f"Generated Path with {len(self.dummy_traj.trajectory)} points is accessible as InteractivePlotting.view."
2390# )
2391# return
2392#
2393# if indices is not None and self.mode != "Bezier" and self.mode != "Path":
2394# self.ax.set_title(
2395# f"Currently working on rendering the cluster. I'll let you know, when I'm finished."
2396# )
2397# indices = np.asarray(indices)
2398#
2399# # update user defined clustering
2400# col = "user_selected_points"
2401# if col not in self.trajs.CVs:
2402# _ = np.full(self.trajs.n_frames, -1)
2403# try:
2404# _[indices] = 0
2405# except IndexError as e:
2406# print(indices)
2407# raise SystemExit from e
2408# self.trajs.load_CVs(_, col)
2409# else:
2410# _ = self.trajs.CVs[col]
2411# max_ = _.max()
2412# _[indices] = max_ + 1
2413# self.trajs.load_CVs(_, col)
2414#
2415# # change coloration of self.pts
2416# color_palette = sns.color_palette("Paired", self.trajs.CVs[col].max() + 1)
2417# cluster_colors = [
2418# (*color_palette[x], 1) if x >= 0 else (0.5, 0.5, 0.5, 0.01)
2419# for x in self.trajs.CVs[col]
2420# ]
2421# self.pts.set_color(cluster_colors)
2422#
2423# max_ = np.max(self.trajs.CVs[col])
2424# self.view, self.dummy_traj = self.cluster_building_fn(
2425# self.trajs,
2426# max_,
2427# nglview=True,
2428# shorten=True,
2429# stack_atoms=True,
2430# col=col,
2431# align_string=self.align_string,
2432# ball_and_stick=self.ball_and_stick,
2433# )
2434# if self.ball_and_stick:
2435# for i in range(len(self.dummy_traj)):
2436# self.view.clear_representations(component=i)
2437# self.view.add_ball_and_stick(component=i)
2438# self.ax.set_title(
2439# f"Cluster {max_} is accessible as InteractivePlotting.view."
2440# )
2441#
2442# def render_move(self):
2443# pass
2444#
2445# def lasso(self):
2446# self.tool = SelectFromCollection(self.ax, self.pts)
2447#
2448# def rectangle(self):
2449# self.tool = SelectFromCollection(self.ax, self.pts, selector=RectangleSelector)
2450#
2451# def ellipse(self):
2452# print("Ellipse not yet implemented")
2453#
2454# def polygon(self):
2455# textstr = "\n".join(
2456# (
2457# "Select points in the figure by enclosing them within a polygon.",
2458# # Press the 'esc' key to start a new polygon.
2459# "Try holding the 'shift' key to move all of the vertices.",
2460# "Try holding the 'ctrl' key to move a single vertex.",
2461# )
2462# )
2463#
2464# # these are matplotlib.patch.Patch properties
2465# props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
2466#
2467# # place a text box in upper left in axes coords
2468# self.manual_text = self.ax.text(
2469# 0.05,
2470# 0.95,
2471# textstr,
2472# transform=self.ax.transAxes,
2473# fontsize=6,
2474# verticalalignment="top",
2475# bbox=props,
2476# )
2477# self.tool = SelectFromCollection(self.ax, self.pts, selector=PolygonSelector)
2478#
2479# def path(self):
2480# pass
2481#
2482# def bezier(self):
2483# line = Line2D([], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87")
2484# self.ax.add_line(line)
2485# self.tool = BezierBuilder(line, self.ax)
2486#
2487# def accept(self):
2488# if "manual_text" in self.__dict__.keys():
2489# self.manual_text.set_visible(False)
2490# del self.manual_text
2491# if self.mode == "Bezier":
2492# self.path_points = copy.deepcopy(self.tool.ind)
2493# selected_indices = self.tool.ind
2494# self.tool.disconnect()
2495# return selected_indices
2496#
2497# @property
2498# def cluster_zoomed(self):
2499# col = "user_selected_points"
2500# if not col in self.trajs.df.keys():
2501# return
2502# max_ = np.max(self.trajs.df[col])
2503# _ = plot_cluster_zoomed(self.trajs, max_, col=col)
2504# return _
2505#
2506# @property
2507# def mode(self):
2508# return self.statusmenu.status