Coverage for encodermap/plot/plotting.py: 7%
590 statements
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
1# -*- coding: utf-8 -*-
2# encodermap/plot/plotting.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Convenience functions for Plotting.
24"""
26##############################################################################
27# Imports
28##############################################################################
31# Future Imports at the top
32from __future__ import annotations
34# Standard Library Imports
35import os
36import shutil
37import subprocess
38import time
39from collections.abc import Sequence
40from functools import partial
41from typing import TYPE_CHECKING, Literal, Optional, Union, overload
43# Third Party Imports
44import matplotlib as mpl
45import matplotlib.pyplot as plt
46import numpy as np
47from ipywidgets import widgets
49# Encodermap imports
50from encodermap.encodermap_tf1.misc import periodic_distance_np, sigmoid
51from encodermap.misc.rotate import _dihedral
52from encodermap.parameters.parameters import AnyParameters
53from encodermap.trajinfo.info_all import TrajEnsemble
54from encodermap.trajinfo.info_single import SingleTraj
57################################################################################
58# Typing
59################################################################################
62if TYPE_CHECKING:
63 # Third Party Imports
64 import plotly.express as px
65 import plotly.graph_objs as go
68################################################################################
69# Optional Imports
70################################################################################
73# Third Party Imports
74from optional_imports import _optional_import
77md = _optional_import("mdtraj")
78nv = _optional_import("nglview")
79mda = _optional_import("MDAnalysis")
80pd = _optional_import("pandas")
81go = _optional_import("plotly", "graph_objects")
82px = _optional_import("plotly", "express")
83make_subplots = _optional_import("plotly", "subplots.make_subplots")
86################################################################################
87# Globals
88################################################################################
# Public API of this module: the names exported by `from ... import *`.
__all__: list[str] = [
    "distance_histogram",
    "distance_histogram_interactive",
    "plot_raw_data",
    "interactive_path_visualization",
    "plot_ramachandran",
    "plot_dssp",
    "plot_end2end",
    "plot_ball_and_stick",
    "plot_trajs_by_parameter",
    "plot_free_energy",
    "animate_lowd_trajectory",
]


# Module-wide plotly layout overrides. `_plot_ball_and_stick` merges a
# "scene" entry (if present) into its 3D scene settings and forwards every
# other key to `fig.update_layout`. Empty by default; presumably populated
# at runtime by callers — NOTE(review): confirm where it is filled in.
GLOBAL_LAYOUT = {}
109################################################################################
110# Utilities
111################################################################################
@overload
def get_histogram(
    x: np.ndarray,
    y: np.ndarray,
    bins: int = ...,
    weights: Optional[np.ndarray] = ...,
    avoid_zero_count: bool = ...,
    transpose: bool = ...,
    return_edges: Literal[False] = ...,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ...


@overload
def get_histogram(
    x: np.ndarray,
    y: np.ndarray,
    bins: int,
    weights: Optional[np.ndarray],
    avoid_zero_count: bool,
    transpose: bool,
    return_edges: Literal[True],
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...


@overload
def get_histogram(
    x: np.ndarray,
    y: np.ndarray,
    bins: int = ...,
    weights: Optional[np.ndarray] = ...,
    avoid_zero_count: bool = ...,
    transpose: bool = ...,
    *,
    return_edges: Literal[True],
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...


def get_histogram(
    x: np.ndarray,
    y: np.ndarray,
    bins: int = 100,
    weights: Optional[np.ndarray] = None,
    avoid_zero_count: bool = False,
    transpose: bool = False,
    return_edges: bool = False,
) -> Union[
    tuple[np.ndarray, np.ndarray, np.ndarray],
    tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
    """Construct a 2D histogram.

    Args:
        x (np.ndarray): The x coordinates of the data.
        y (np.ndarray): The y coordinates of the data.
        bins (int): The number of bins passed to np.histogram2d.
        weights (np.ndarray): The weights passed to np.histogram2d.
        avoid_zero_count (bool): Avoid zero counts by lifting all
            histogram elements to the smallest non-zero count before
            computing the free energy. If False, zero histogram counts would
            yield infinity in the free energy. If the histogram contains no
            non-zero count at all, it is returned unchanged.
        transpose (bool): Whether to transpose the output histogram.
        return_edges (bool): Whether to additionally return the bin edges.

    Returns:
        Union[tuple[np.ndarray, np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
            If `return_edges` is False: (xcenters, ycenters, H).
            If `return_edges` is True: (xcenters, ycenters, xedges, yedges, H).

    Examples:
        >>> from encodermap.plot.plotting import get_histogram
        >>> x, y = np.random.uniform(size=(2, 500))
        >>> xcenters, ycenters, H = get_histogram(x, y)
        >>> xcenters.shape
        (100,)
        >>> H.shape
        (100, 100)
        >>> np.min(H)
        0.0
        >>> xcenters, ycenters, H = get_histogram(x, y, avoid_zero_count=True)
        >>> np.min(H)
        1.0

    """
    H, xedges, yedges = np.histogram2d(x, y, bins=bins, weights=weights)
    # Bin centers are the midpoints between consecutive edges.
    xcenters = np.mean(np.vstack([xedges[0:-1], xedges[1:]]), axis=0)
    ycenters = np.mean(np.vstack([yedges[0:-1], yedges[1:]]), axis=0)
    # Guard with `H.any()`: an all-zero histogram (e.g. zero weights) has no
    # non-zero minimum, and np.min of an empty selection would raise.
    if avoid_zero_count and H.any():
        H = np.maximum(H, np.min(H[H.nonzero()]))
    if transpose:
        H = H.T
    if not return_edges:
        return xcenters, ycenters, H
    return xcenters, ycenters, xedges, yedges, H
def get_density(
    x: np.ndarray,
    y: np.ndarray,
    bins: int = 100,
    weights: Optional[np.ndarray] = None,
    avoid_zero_count: bool = False,
    transpose: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a 2D histogram and normalize it to a probability density.

    The counts come from `get_histogram` and are divided by their total via
    `to_density`, so the returned histogram sums to one.

    Args:
        x (np.ndarray): The x coordinates of the data.
        y (np.ndarray): The y coordinates of the data.
        bins (int): The number of bins passed to np.histogram2d.
        weights (np.ndarray): The weights passed to np.histogram2d.
        avoid_zero_count (bool): Avoid zero counts by lifting all
            histogram elements to the minimum value before computing the free
            energy. If False, zero histogram counts would yield infinity
            in the free energy.
        transpose (bool): Whether to transpose the output.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
            xcenters, ycenters, and the normalized histogram.

    """
    x_mid, y_mid, counts = get_histogram(
        x,
        y,
        bins=bins,
        weights=weights,
        avoid_zero_count=avoid_zero_count,
        transpose=transpose,
    )
    density = to_density(counts)
    return x_mid, y_mid, density
def to_density(H: np.ndarray) -> np.ndarray:
    """Scale histogram counts so that they sum to one.

    Args:
        H (np.ndarray): The histogram counts.

    Returns:
        np.ndarray: `H` divided by its total (as a float).

    """
    total = float(H.sum())
    return H / total
def to_free_energy(
    H: np.ndarray,
    kT: float = 1.0,
    minener_zero: bool = False,
) -> np.ndarray:
    """Compute free energies from histogram counts.

    Computes ``F = -kT * log(H)`` element-wise. Empty bins (``H == 0``) get
    ``F = inf``.

    Args:
        H (np.ndarray): The density histogram to get the free energy from.
        kT (float): The value of kT in the desired energy unit. By default,
            energies are computed in kT (setting 1.0). If you want to
            measure the energy in kJ/mol at 298 K, use kT=2.479 and
            change the cbar_label accordingly. Defaults to 1.0.
        minener_zero (bool): Shifts the energy minimum to zero. Has no effect
            if `H` contains no non-zero entry. Defaults to False.

    Returns:
        np.ndarray: The free energy values, in the unit implied by `kT`.

    """
    F = np.full(H.shape, np.inf)
    occupied = H != 0
    F[occupied] = -np.log(H[occupied])
    # Only shift when at least one bin is occupied; np.min of an empty
    # selection would raise.
    if minener_zero and occupied.any():
        F[occupied] -= np.min(F[occupied])
    return F * kT
def get_free_energy(
    x: np.ndarray,
    y: np.ndarray,
    bins: int = 100,
    weights: Optional[np.ndarray] = None,
    kT: float = 1.0,
    avoid_zero_count: bool = False,
    minener_zero: bool = False,
    transpose: bool = True,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a 2D free-energy surface from scattered (x, y) samples.

    The samples are binned and normalized with `get_density`, and the
    resulting density is converted with `to_free_energy`.

    Args:
        x (np.ndarray): The x coordinates of the data.
        y (np.ndarray): The y coordinates of the data.
        bins (int): The number of bins passed to np.histogram2d.
        weights (np.ndarray): The weights passed to np.histogram2d.
        kT (float): The value of kT in the desired energy unit. By default,
            energies are computed in kT (setting 1.0). If you want to
            measure the energy in kJ/mol at 298 K, use kT=2.479 and
            change the cbar_label accordingly. Defaults to 1.0.
        avoid_zero_count (bool): Avoid zero counts by lifting all
            histogram elements to the minimum value before computing the free
            energy. If False, zero histogram counts would yield infinity
            in the free energy.
        minener_zero (bool): Shifts the energy minimum to zero. Defaults to False.
        transpose (bool): Whether to transpose the output. Defaults to True
            here, unlike `get_density`.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
            xcenters, ycenters, and the free-energy histogram.

    """
    centers_x, centers_y, density = get_density(
        x,
        y,
        bins=bins,
        weights=weights,
        avoid_zero_count=avoid_zero_count,
        transpose=transpose,
    )
    energy = to_free_energy(density, kT=kT, minener_zero=minener_zero)
    return centers_x, centers_y, energy
def hex_to_rgba(h: str, alpha: float = 0.8) -> str:
    """Convert a hex color string into a CSS/plotly ``rgba(...)`` string.

    Args:
        h (str): The color as ``"#rrggbb"`` or ``"rrggbb"``. The 3-digit
            shorthand ``"#rgb"`` is expanded to ``"#rrggbb"``. Characters
            after the sixth hex digit are ignored.
        alpha (float): The opacity written into the output. Defaults to 0.8.

    Returns:
        str: A string like ``"rgba(255, 0, 128, 0.8)"``.

    Raises:
        ValueError: If `h` has fewer than six hex digits (after shorthand
            expansion) or contains non-hex characters.

    Examples:
        >>> hex_to_rgba("#ff0080")
        'rgba(255, 0, 128, 0.8)'
        >>> hex_to_rgba("#fff", alpha=1)
        'rgba(255, 255, 255, 1)'

    """
    digits = h.lstrip("#")
    if len(digits) == 3:
        # CSS shorthand: each digit is doubled ("f0a" -> "ff00aa").
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"Can't convert {h!r} to rgba: expected 6 hex digits.")
    r, g, b = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r}, {g}, {b}, {alpha})"
317################################################################################
318# Private Functions
319################################################################################
322# @functools.cache
323def _get_squiggly_arrow(n: int = 1, n_frames: int = 200) -> pd.DataFrame:
324 if n == 1:
325 x = np.linspace(0.2, 2.5, n_frames)
326 y = np.sin(x * 2) / 0.5
327 xy = np.stack([x, y]).T
328 positions = np.full((n_frames, n_frames, 2), fill_value=np.nan)
329 time = []
330 for i, row in enumerate(xy):
331 positions[i:, i] = row
332 time.append(np.full(shape=(n_frames,), fill_value=i))
333 time = np.concatenate(time)
334 positions = positions.reshape(-1, 2)
335 assert len(time) == len(positions)
336 df = pd.DataFrame({"time": time, "x": positions[:, 0], "y": positions[:, 1]})
337 return df
338 else:
339 df = pd.DataFrame({"time": [], "x": [], "y": [], "trajectory": []})
340 for i in range(n):
341 theta = np.random.random() * 2 * np.pi - np.pi
342 rotmat = np.array(
343 [
344 [np.cos(theta), -np.sin(theta)],
345 [np.sin(theta), np.cos(theta)],
346 ]
347 )
348 x = np.linspace(0.0, 2.5, n_frames)
349 y = np.sin(x * 2) / 4
350 x -= 1.25
351 xy = rotmat @ np.stack([x, y])
352 xy[0] += np.random.random((1,))[0]
353 xy[1] += np.random.random((1,))[0]
354 xy = xy.T
355 positions = np.full((n_frames, n_frames, 2), fill_value=np.nan)
356 time = []
357 for j, row in enumerate(xy):
358 positions[j:, j] = row
359 time.append(np.full(shape=(n_frames,), fill_value=j))
360 time = np.concatenate(time)
361 positions = positions.reshape(-1, 2)
362 assert len(time) == len(positions)
363 sub_df = pd.DataFrame(
364 {
365 "time": time,
366 "x": positions[:, 0],
367 "y": positions[:, 1],
368 "trajectory": np.full((len(time),), fill_value=str(i + 1)),
369 }
370 )
371 df = pd.concat([df, sub_df])
372 return df
375def _project_onto_plane(x: np.ndarray, n: np.ndarray) -> np.ndarray:
376 assert np.isclose(np.linalg.norm(n), 1)
377 d = np.dot(x, n)
378 p = d * n
379 return x - p
def _angle_arc(
    points: np.ndarray,
    name: str,
    value: float,
    radius: float = 0.05,
    n_points: int = 100,
) -> go.Scatter3d:
    """Create a dashed `go.Scatter3d` arc visualizing the angle defined by `points`.

    The arc is centered on the vertex ``points[1]``. It lies in the plane of
    the three points and spans from the direction of ``points[0]`` to the
    direction of ``points[2]``.

    Args:
        points (np.ndarray): Coordinates of the three atoms forming the
            angle, shape (3, 3). ``points[1]`` is the vertex.
        name (str): Name of the angle. Currently not shown; the hover label
            displays the angle value.
        value (float): The value of the angle in radians. It sets the arc
            extent and the hover label (converted to degrees).
        radius (float): The radius of the arc, in the units of `points`.
            Defaults to 0.05.
        n_points (int): The number of points used to plot this arc. More
            points might slow down the rendering. Defaults to 100.

    Returns:
        go.Scatter3d: The plotly trace.

    """
    center = points[1]
    # Unit vector pointing from the first atom to the vertex. Seen from the
    # vertex, the first atom therefore lies in direction -a (rho = pi).
    bond = points[1] - points[0]
    a = bond / np.linalg.norm(bond)
    # Normal of the plane spanned by the three points.
    face_normal = np.cross(bond, points[2] - points[0])
    # In-plane unit vector perpendicular to `a`, oriented so that the third
    # atom lies at a positive angle (between 0 and pi) measured from `a`.
    b = np.cross(face_normal, a)
    b = b / np.linalg.norm(b)

    # The third atom sits at rho = pi - value, the first at rho = pi, so this
    # range covers exactly the angle between the two bonds.
    rho = np.linspace(np.pi - value, np.pi, num=n_points)
    hovertemplate = "%{meta[0]:.2f} deg"
    meta = [np.rad2deg(value)]
    out = (
        center
        + radius * a * np.cos(rho)[:, np.newaxis]
        + radius * b * np.sin(rho)[:, np.newaxis]
    )
    return go.Scatter3d(
        x=out[:, 0],
        y=out[:, 1],
        z=out[:, 2],
        name="",
        line={
            "color": "black",
            "width": 5,
            "dash": "dash",
        },
        mode="lines",
        hovertemplate=hovertemplate,
        meta=meta,
    )
def _dihedral_arc(
    points: np.ndarray,
    name: str,
    radius: float = 0.05,
    n_points: int = 100,
    initial_points: Literal["random", "select"] = "select",
    true_to_value: bool = True,
) -> go.Scatter3d:
    """Create a dashed `go.Scatter3d` arc visualizing the dihedral of four points.

    The arc lies in the plane perpendicular to the central bond
    ``points[1] -> points[2]`` and is centered on that bond's midpoint.

    Args:
        points (np.ndarray): Coordinates of the four atoms, shape (4, 3).
        name (str): Label of the dihedral. For ``initial_points="random"`` it
            is shown verbatim on hover. For ``"select"`` its second
            whitespace-separated token is shown, so it must contain at least
            two tokens.
        radius (float): Arc radius in the units of `points`. Defaults to 0.05.
        n_points (int): Number of points along the arc. Defaults to 100.
        initial_points (Literal["random", "select"]): How the arc is laid out.
            ``"select"`` starts at the in-plane projection of
            ``points[0] - points[1]``. ``"random"`` starts at an arbitrary
            (deterministic) direction perpendicular to the central bond and
            draws a full circle. Defaults to "select".
        true_to_value (bool): Only used with ``"select"``. If True, the arc
            spans from 0 to the dihedral value. If False, it is a half circle.
            Defaults to True.

    Returns:
        go.Scatter3d: The plotly trace.

    Raises:
        ValueError: If `initial_points` is neither "random" nor "select".

    """
    center = np.mean(points[1:3], axis=0)
    face_normal = points[2] - points[1]
    face_normal_unit = face_normal / np.linalg.norm(face_normal)

    if initial_points == "random":
        # Build a vector perpendicular to the bond by swapping the two
        # largest-magnitude components and negating one. The dot product is
        # then exactly zero. Sorting by magnitude (not by signed value)
        # guarantees the result is non-zero.
        order = np.argsort(np.abs(face_normal))
        ind_largest, ind_2nd_largest = order[-1], order[-2]
        perpendicular = np.zeros((3,))
        perpendicular[ind_2nd_largest] = -face_normal[ind_largest]
        perpendicular[ind_largest] = face_normal[ind_2nd_largest]
        a = perpendicular / np.linalg.norm(perpendicular)
        b = np.cross(face_normal, perpendicular)
        b = b / np.linalg.norm(b)
        # There is no meaningful start angle here, so draw a full circle
        # around the bond.
        rho = np.linspace(0, 2 * np.pi, num=n_points)
        hovertemplate = "%{meta[0]}"
        meta = [name]
    elif initial_points == "select":
        a = _project_onto_plane(points[0] - points[1], face_normal_unit)
        b = np.cross(face_normal, a)
        b = b / np.linalg.norm(b)
        a = a / np.linalg.norm(a)

        dihedral_value = _dihedral(points, [0, 1, 2, 3])[0, 0]
        if true_to_value:
            # Cover [0, value] for positive and [value, 0] for negative values.
            rho = np.linspace(
                min(0.0, dihedral_value),
                max(0.0, dihedral_value),
                num=n_points,
            )
        else:
            rho = np.linspace(0, np.pi, num=n_points)
        hovertemplate = "%{meta[0]} %{meta[1]:.2f} deg"
        meta = [name.split()[1], np.rad2deg(dihedral_value)]
    else:
        raise ValueError(
            f"Argument `initial_points` must be 'random' or 'select', not {initial_points}."
        )

    out = (
        center
        + radius * a * np.cos(rho)[:, np.newaxis]
        + radius * b * np.sin(rho)[:, np.newaxis]
    )
    return go.Scatter3d(
        x=out[:, 0],
        y=out[:, 1],
        z=out[:, 2],
        name="",
        line={
            "color": "black",
            "width": 5,
            "dash": "dash",
        },
        mode="lines",
        hovertemplate=hovertemplate,
        meta=meta,
    )
def _flatten_coords(traj: "SingleTraj") -> np.ndarray:
    """Flattens coordinates, so it is easier to render them as images.

    Work in progress: the function shows an intermediate plotly figure and
    then unconditionally raises an ``Exception`` reporting ``xyz.shape``. It
    never returns.

    Steps performed on the first frame of `traj`:
        1. Rotate the central and side-chain dihedrals via `mdtraj_rotate`.
        2. Fit a least-squares plane through the atoms (SVD) and project
           all atoms onto it.
        3. Translate atoms on one side of each bond (the selection comes
           from `_get_near_and_far_networkx`).
        4. Rotate the far side of each central angle toward 120 degrees
           about the plane normal.

    Args:
        traj (SingleTraj): The trajectory. Every atom must take part in at
            least one bond (checked with an assert).

    Raises:
        AssertionError: If some atom is not part of any bond.
        Exception: Always, at the end (debugging placeholder).

    """
    # Third Party Imports
    import networkx as nx
    from mdtraj.geometry.angle import _angle
    from networkx import connected_components
    from transformations import affine_matrix_from_points, rotation_matrix

    # Local Folder Imports
    from ..loading.features import CentralAngles, CentralDihedrals, SideChainDihedrals
    from ..misc.rotate import _get_near_and_far_networkx, mdtraj_rotate

    indices = []
    indices.append(CentralDihedrals(traj).indexes)
    indices.append(SideChainDihedrals(traj).indexes)
    indices = np.vstack(indices)
    angles = np.full((1, indices.shape[0]), 0)
    # NOTE(review): `angles` has shape (1, N), so `[::2]` selects row 0 and
    # sets EVERY angle to 180. `angles[:, ::2]` may have been intended — confirm.
    angles[::2] = 180
    xyz = (
        mdtraj_rotate(
            traj.traj,
            angles=angles,
            indices=indices,
            deg=True,
        )
        .xyz[0]
        .copy()
    )

    # get best surface using 3d least squares
    centroid = xyz.mean(axis=0)
    xyzT = np.transpose(xyz)
    xyzR = xyz - centroid
    xyzRT = np.transpose(xyzR)
    u, sigma, v = np.linalg.svd(xyzRT)
    # NOTE(review): the best-fit plane normal is the left singular vector of
    # the smallest singular value, i.e. the COLUMN `u[:, 2]`. `u[2]` is a
    # row — confirm which was intended.
    normal = u[2]
    normal = normal / np.linalg.norm(normal)

    # project points
    # Plane equation a*x + b*y + c*z + d = 0 through the centroid. Because
    # `normal` is unit length, normal.dot(normal) == 1 and the expression
    # below reduces to p - (p.n + d) * n.
    a, b, c = normal
    d = -a * centroid[0] - b * centroid[1] - c * centroid[2]
    projected_points = []
    for p in xyz:
        projected_points.append(p - (p.dot(normal) + d / normal.dot(normal)) * normal)
    xyz = np.array(projected_points)

    # fix distances
    # Target bond lengths come from the ORIGINAL (unrotated, unprojected)
    # coordinates of frame 0.
    edges = []
    edge_lengths = []
    atoms_in_bonds_is = set()
    atoms_in_bonds_should_be = xyz.shape[0]
    for a, b in traj.top.bonds:
        atoms_in_bonds_is.add(a.index)
        atoms_in_bonds_is.add(b.index)
        edges.append([a.index, b.index])
        edge_lengths.append(np.linalg.norm(traj.xyz[0, b.index] - traj.xyz[0, a.index]))
    assert (
        len(atoms_in_bonds_is) == atoms_in_bonds_should_be
    ), f"Can't flatten topology: {traj.top}. There are atoms which are not part of bonds."
    bondgraph = traj.top.to_bondgraph()
    edges = np.asarray(edges)
    edge_lengths = np.asarray(edge_lengths)
    near_and_far_networkx = _get_near_and_far_networkx(
        bondgraph,
        edges,
        traj.top,
        parallel=True,
    )[0]
    for edge, indices, length_should_be in zip(
        edges, near_and_far_networkx, edge_lengths
    ):
        # presumably `indices` is a boolean mask of the "near" side of the
        # bond, so `~indices` selects the side that is moved — TODO confirm.
        vec = xyz[edge[1]] - xyz[edge[0]]
        length_is = np.linalg.norm(vec)
        vec /= np.linalg.norm(vec)
        # NOTE(review): this shifts by the full target length rather than by
        # (length_should_be - length_is); `length_is` is otherwise unused.
        trans = vec * length_should_be
        xyz[~indices] += trans

    # fix angles
    angle_should_be = 2 * np.pi / 3
    angle_indices = CentralAngles(traj).angle_indexes
    for i, (a, b, c) in enumerate(angle_indices):
        angle_center = xyz[b]
        # NOTE(review): `a`, `b`, `c` are atom INDICES here, so `ba`/`bc` are
        # integer differences, not coordinate vectors (xyz[a] - xyz[b] etc.)
        # — confirm.
        ba = a - b
        bc = c - b
        angle_value = np.arccos(
            np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
        )
        diff = angle_should_be - angle_value
        # Cut bond b-c and rotate the connected component containing `c`
        # (the "far" side) about the plane normal through the vertex.
        G = nx.convert_node_labels_to_integers(bondgraph).copy()
        G.remove_edge(b, c)
        components = [*connected_components(G)]
        if c in components[0] and b in components[1]:
            components = components[::-1]
        subgraph = G.subgraph(components[1]).copy()
        far = np.asarray(subgraph.nodes)
        M = rotation_matrix(
            diff,
            normal,
            angle_center,
        )
        # Homogeneous coordinates (append 1) for the 4x4 rotation matrix.
        padded = np.pad(
            xyz[far].copy(), ((0, 0), (0, 1)), mode="constant", constant_values=1
        )
        xyz[far] = M.dot(padded.T).T[:, :3]

    # Debug output: show the flattened coordinates.
    data = [
        go.Scatter3d(
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
        )
    ]
    fig = go.Figure(
        data=data,
        layout={
            "width": 800,
            "height": 800,
        },
    )
    fig.show()

    # Unfinished: always raises instead of returning `xyz`.
    raise Exception(f"{xyz.shape=}")
def _bond_segments(
    points: np.ndarray,
    bonds: Sequence[tuple[int, int]],
) -> tuple[list, list, list]:
    """Return x/y/z lists that draw every bond as a separate line segment.

    Each bond contributes its two endpoint coordinates followed by ``None``.
    Plotly treats ``None`` as a gap, so all bonds fit into a single trace.

    Args:
        points (np.ndarray): Atom coordinates of one frame, shape (n_atoms, 3).
        bonds (Sequence[tuple[int, int]]): Pairs of atom indices.

    Returns:
        tuple[list, list, list]: The x, y and z lists, each of length
            ``3 * len(bonds)``.

    """
    x_lines, y_lines, z_lines = [], [], []
    for i, j in bonds:
        x_lines.extend((points[i, 0], points[j, 0], None))
        y_lines.extend((points[i, 1], points[j, 1], None))
        z_lines.extend((points[i, 2], points[j, 2], None))
    return x_lines, y_lines, z_lines


def _plot_ball_and_stick(
    traj: Union["SingleTraj", md.Trajectory],
    frame_subsample: Union[int, slice] = slice(None, None, 100),
    highlight: Literal[
        "atoms", "bonds", "angles", "dihedrals", "side_dihedrals", "central_dihedrals"
    ] = "atoms",
    atom_indices: Optional[Sequence[int]] = None,
    custom_colors: Optional[dict[int, str]] = None,
    add_angle_arcs: bool = True,
    angle_arcs_true_to_value: bool = True,
    animation: bool = False,
    persistent_hover: bool = False,
    flatten: bool = False,
) -> go.Figure:  # pragma: no cover
    """Render a 3D ball-and-stick plotly figure of the first frame of `traj`.

    Args:
        traj (Union[SingleTraj, md.Trajectory]): The structure. It is copied
            (if it has a ``copy`` method) or wrapped in a `SingleTraj`.
        frame_subsample (Union[int, slice]): Index applied to the coordinates
            after the trajectory has been reduced to its first frame.
        highlight (str): What to make hoverable: "atoms", "bonds", "angles",
            "dihedrals", "central_dihedrals" or "side_dihedrals".
        atom_indices (Optional[Sequence[int]]): Atoms to keep. Defaults to all.
        custom_colors (Optional[dict[int, str]]): Maps atom index to a plotly
            color. Unlisted atoms are grey. If None, atoms are colored by
            element.
        add_angle_arcs (bool): Draw arcs for angles/dihedrals. Defaults to True.
        angle_arcs_true_to_value (bool): Passed as ``true_to_value`` to
            `_dihedral_arc`. Defaults to True.
        animation (bool): Not implemented yet; True raises.
        persistent_hover (bool): Add permanent text annotations: atom names
            (hydrogens skipped) for "atoms", feature labels otherwise.
        flatten (bool): Experimental; uses `_flatten_coords`. Requires a
            topology without hydrogens.

    Returns:
        go.Figure: The 900x900 figure without legend.

    Raises:
        NotImplementedError: If `animation` is True.
        Exception: If `traj` can't be sliced with `atom_indices`.
        AssertionError: If `flatten` is True and `traj` contains hydrogens.
        TypeError: If `highlight` is not one of the supported values.

    """
    if hasattr(traj, "copy"):
        traj = traj.copy()
    else:
        traj = SingleTraj(traj)

    if atom_indices is None:
        atom_indices = np.arange(traj.n_atoms)
    if animation:
        raise NotImplementedError(
            "Animation of ball and stick plot not yet implemented."
        )
    # data for plotting and annotation
    try:
        # NOTE(review): the return value is discarded, so this relies on
        # `atom_slice` modifying `traj` in place — confirm.
        traj.atom_slice(atom_indices)
    except Exception as e:
        raise Exception(
            f"Can't index {traj=} with {np.asarray(atom_indices)=}"
        ) from e
    if not animation:
        traj = traj[0]
    xyz = traj.xyz[frame_subsample]
    if flatten:
        assert not any(a.element.symbol == "H" for a in traj.top.atoms), (
            "Can only create a flattened representation for trajs without hydrogen. "
            "Use the `atom_indices` argument to remove the hydrogen."
        )
        xyz = _flatten_coords(traj)
    if xyz.ndim == 2:
        # An integer index drops the frame axis; restore it so that `xyz[0]`
        # below always addresses a whole frame.
        xyz = xyz[np.newaxis]
    atom_names = np.array([str(a) for a in traj.top.atoms])
    bonds = [(a.index, b.index) for a, b in traj.top.bonds]
    sizes = np.array(
        [24 if a.element.symbol != "H" else 10 for a in traj.top.atoms]
    ).astype(float)
    if highlight in [
        "bonds",
        "angles",
        "dihedrals",
        "central_dihedrals",
        "side_dihedrals",
    ]:
        # Smaller atoms leave room for the hoverable feature markers.
        sizes /= 1.3
    elements = np.array([a.element.number for a in traj.top.atoms])
    coords = [f"x: {i:.3f}<br>y: {j:.3f}<br>z: {k:.3f}" for i, j, k in xyz[0]]
    assert len(coords) == len(atom_names), f"{len(coords)=} {len(atom_names)=}"
    colormap = {
        1: "rgb(200, 200, 200)",  # hydrogen
        6: "rgb(80, 80, 80)",  # carbon
        7: "rgb(0, 0, 255)",  # nitrogen
        8: "rgb(255, 0, 0)",  # oxygen
        15: "rgb(160, 32, 240)",  # phosphorus
        16: "rgb(255, 255, 0)",  # sulfur
        34: "rgb(170, 74, 68)",  # selenium
    }
    if custom_colors is None:
        color = [colormap.get(i, "rgb(255, 0, 124)") for i in elements]
    else:
        # A plain list instead of a fixed-width numpy string array, so longer
        # colors such as "rgba(255, 0, 0, 0.5)" are not truncated.
        color = ["rgb(126, 126, 126)"] * len(elements)
        for atom, c in custom_colors.items():
            color[atom] = c

    # for circle arcs
    circles = []

    # set customdata and hovertemplate
    if highlight == "atoms":
        customdata = np.stack(
            (
                atom_names,
                coords,
            ),
            axis=-1,
        )
        hovertemplate = "%{customdata[0]}:<br>%{customdata[1]}"
        hoverinfo = None
    elif highlight == "angles":
        # Local Folder Imports
        from ..loading.features import CentralAngles, SideChainAngles

        x_centers = []
        y_centers = []
        z_centers = []
        angle_names = []
        annotations_text = []

        for feature_cls in (CentralAngles, SideChainAngles):
            f = feature_cls(traj=traj)
            for p, name, value in zip(f.indexes, f.describe(), f.transform()[0]):
                # The hover marker sits on the vertex atom.
                x_centers.append(xyz[0, p[1], 0])
                y_centers.append(xyz[0, p[1], 1])
                z_centers.append(xyz[0, p[1], 2])
                angle_names.append(
                    f"Angle: {traj.top.atom(p[0])} - {traj.top.atom(p[1])} - {traj.top.atom(p[2])}"
                )
                annotations_text.append(
                    f"{traj.top.atom(p[0])} - {traj.top.atom(p[1])} - {traj.top.atom(p[2])}"
                )
                if add_angle_arcs:
                    circles.append(
                        _angle_arc(
                            xyz[0, p],
                            name=name,
                            value=value,
                        )
                    )

        customdata = None
        hovertemplate = None
        hoverinfo = "skip"
        center_customdata = angle_names
        center_hovertemplate = "%{customdata}"

    elif highlight == "bonds":
        # Local Folder Imports
        from ..loading.features import AllBondDistances

        f = AllBondDistances(traj=traj)

        x_centers = []
        y_centers = []
        z_centers = []
        bond_names = []
        annotations_text = []

        for p, name in zip(f.indexes, f.describe()):
            # The hover marker sits on the bond midpoint.
            x_centers.append(np.mean(xyz[0, p, 0]))
            y_centers.append(np.mean(xyz[0, p, 1]))
            z_centers.append(np.mean(xyz[0, p, 2]))
            bond_names.append(
                f"Bond between {traj.top.atom(p[0])} and {traj.top.atom(p[1])}"
            )
            annotations_text.append(f"{traj.top.atom(p[0])} - {traj.top.atom(p[1])}")

        customdata = None
        hovertemplate = None
        hoverinfo = "skip"
        center_customdata = bond_names
        center_hovertemplate = "%{customdata}"
    elif highlight in ["dihedrals", "side_dihedrals", "central_dihedrals"]:
        # Local Folder Imports
        from ..loading.features import CentralDihedrals, SideChainDihedrals

        x_centers = []
        y_centers = []
        z_centers = []
        dihedral_names = []
        annotations_text = []

        feature_classes = []
        if highlight in ["dihedrals", "central_dihedrals"]:
            feature_classes.append(CentralDihedrals)
        if highlight in ["dihedrals", "side_dihedrals"]:
            feature_classes.append(SideChainDihedrals)

        for feature_cls in feature_classes:
            f = feature_cls(traj=traj)
            for p, name in zip(f.indexes, f.describe()):
                # The hover marker sits on the midpoint of the central bond.
                x_centers.append(np.mean(xyz[0, p[1:3], 0]))
                y_centers.append(np.mean(xyz[0, p[1:3], 1]))
                z_centers.append(np.mean(xyz[0, p[1:3], 2]))
                dihedral_names.append(name)
                annotations_text.append(name.split()[1])
                if add_angle_arcs:
                    circles.append(
                        _dihedral_arc(
                            xyz[0, p], name=name, true_to_value=angle_arcs_true_to_value
                        )
                    )

        customdata = None
        hovertemplate = None
        hoverinfo = "skip"
        center_customdata = dihedral_names
        center_hovertemplate = "%{customdata}"
    else:
        raise TypeError(
            f"The argument `highlight` must be one of the following: "
            f"'atoms', 'bonds', 'angles', 'dihedrals', 'central_dihedrals', "
            f"'side_dihedrals'. You supplied {highlight}"
        )

    # create scatter trace
    scatter = go.Scatter3d(
        x=xyz[0, :, 0],
        y=xyz[0, :, 1],
        z=xyz[0, :, 2],
        customdata=customdata,
        mode="markers",
        hovertemplate=hovertemplate,
        name="Atoms",
        marker=dict(
            size=sizes,
            color=color,
            opacity=1.0,
        ),
        hoverinfo=hoverinfo,
    )

    # create line trace
    x_lines, y_lines, z_lines = _bond_segments(xyz[0], bonds)
    lines = go.Scatter3d(
        x=x_lines,
        y=y_lines,
        z=z_lines,
        mode="lines",
        name="",
        line=dict(
            color="black",
            width=(
                6
                if highlight
                in ["bonds", "dihedrals", "central_dihedrals", "side_dihedrals"]
                else 1
            ),
        ),
        hoverinfo="skip",
    )

    # create figure
    if highlight == "atoms":
        data = [scatter, lines]
    else:
        # Invisible, large markers that carry the hover text of the features.
        centers = go.Scatter3d(
            x=x_centers,
            y=y_centers,
            z=z_centers,
            mode="markers",
            marker=dict(
                size=30,
                color="rgba(0, 0, 0, 0)",
                opacity=0.0,
            ),
            name=f"{highlight}".capitalize(),
            customdata=center_customdata,
            hovertemplate=center_hovertemplate,
        )
        data = [centers, scatter, lines]

    # Arcs exist only for angles and (all kinds of) dihedrals; the list is
    # empty otherwise.
    data.extend(circles)

    fig = go.Figure(
        data=data,
    )

    annotations = []
    if persistent_hover:
        if highlight == "atoms":
            zipped = zip(xyz[0, :, 0], xyz[0, :, 1], xyz[0, :, 2], traj.top.atoms)
            for x, y, z, a in zipped:
                if a.element.symbol == "H":
                    continue
                annotations.append(
                    {
                        "x": x,
                        "y": y,
                        "z": z,
                        "text": str(a),
                    }
                )
        else:
            zipped = zip(x_centers, y_centers, z_centers, annotations_text)
            for x, y, z, text in zipped:
                annotations.append(
                    {
                        "x": x,
                        "y": y,
                        "z": z,
                        "text": text,
                    }
                )

    scene = {
        "xaxis_gridcolor": "rgb(102, 102, 102)",
        "yaxis_gridcolor": "rgb(102, 102, 102)",
        "zaxis_gridcolor": "rgb(102, 102, 102)",
        "annotations": annotations,
    }
    if "scene" in GLOBAL_LAYOUT:
        scene |= GLOBAL_LAYOUT["scene"]
        global_layout = {k: v for k, v in GLOBAL_LAYOUT.items() if k != "scene"}
    else:
        global_layout = GLOBAL_LAYOUT.copy()

    fig.update_layout(
        height=900,
        width=900,
        showlegend=False,
        **global_layout,
        scene=scene,
    )

    # create frames
    # NOTE: currently unreachable, because `animation=True` raises
    # NotImplementedError above. Kept for the future implementation.
    if animation:
        frames = [go.Frame(data=[scatter, lines])]
        for points in xyz:
            frame_x, frame_y, frame_z = _bond_segments(points, bonds)
            frame = go.Frame(
                data=[
                    go.Scatter3d(
                        x=points[:, 0],
                        y=points[:, 1],
                        z=points[:, 2],
                        customdata=atom_names,
                        mode="markers",
                        hovertemplate="%{customdata}",
                        name="",
                        marker=dict(
                            size=sizes,
                            color=color,
                            opacity=1.0,
                        ),
                    ),
                    go.Scatter3d(
                        x=frame_x,
                        y=frame_y,
                        z=frame_z,
                        mode="lines",
                        name="",
                        line=dict(
                            color="black",
                        ),
                    ),
                ],
            )
            frames.append(frame)
        fig.update(frames=frames)
        fig.update_layout(
            updatemenus=[
                dict(
                    type="buttons",
                    buttons=[
                        dict(
                            label="Play",
                            method="animate",
                            args=[
                                None,
                                dict(
                                    frame=dict(
                                        redraw=True, fromcurrent=True, mode="immediate"
                                    )
                                ),
                            ],
                        )
                    ],
                )
            ],
            scene=dict(
                xaxis=dict(range=[np.min(xyz[..., 0]), np.max(xyz[..., 0])]),
                yaxis=dict(range=[np.min(xyz[..., 1]), np.max(xyz[..., 1])]),
                zaxis=dict(range=[np.min(xyz[..., 2]), np.max(xyz[..., 2])]),
            ),
        )
    return fig
1098################################################################################
1099# Plotting Functions
1100################################################################################
def _show_animated_line(df: pd.DataFrame, color: Optional[str] = None) -> None:
    """Show an animated plotly line plot of the columns "x"/"y" over "time".

    The axis ranges are padded by 1 around the data (NaNs ignored). The
    transition tween between frames is disabled so the line grows instantly.

    Args:
        df (pd.DataFrame): Data with the columns "x", "y", "time" and,
            optionally, the column named by `color`.
        color (Optional[str]): Column used to color separate lines.

    """
    x_min = np.nanmin(df["x"].values) - 1
    x_max = np.nanmax(df["x"].values) + 1
    y_min = np.nanmin(df["y"].values) - 1
    y_max = np.nanmax(df["y"].values) + 1
    fig = px.line(
        data_frame=df,
        x="x",
        y="y",
        animation_frame="time",
        color=color,
        range_x=(x_min, x_max),
        range_y=(y_min, y_max),
        height=500,
        width=800,
    )
    fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 0
    fig.show()


def animate_lowd_trajectory(
    n: int = 1,
    potential: bool = False,
    n_frames: int = 200,
) -> None:
    """Show an animated toy trajectory in a 2D (low-dimensional) space.

    Args:
        n (int): Number of trajectories. Only used when `potential` is True;
            for ``n != 1`` the lines are colored by trajectory.
        potential (bool): If False, animate a single straight line starting
            at a random point in [0, 10)^2 with a small random constant
            velocity. If True, animate sine-shaped paths from
            `_get_squiggly_arrow`.
        n_frames (int): Number of animation frames / path points.
            Defaults to 200.

    Raises:
        TypeError: If `potential` is not a bool.

    """
    if not potential:
        start = np.random.random((1, 2)) * 10
        velocity = np.random.random((1, 2)) * 0.05 - 0.025
        path = start + np.arange(n_frames)[:, np.newaxis] * velocity
        # Frame t shows the path points 0..t; later points are NaN.
        visible = np.tri(n_frames, dtype=bool)
        positions = np.where(
            visible[..., np.newaxis], path[np.newaxis, :, :], np.nan
        ).reshape(-1, 2)
        frame_index = np.repeat(np.arange(n_frames), n_frames)
        df = pd.DataFrame(
            {"time": frame_index, "x": positions[:, 0], "y": positions[:, 1]}
        )
        _show_animated_line(df)
        return

    if not isinstance(potential, (bool, np.bool_)):
        raise TypeError(f"Argument `potential` must be a bool, got {potential!r}.")

    if n == 1:
        _show_animated_line(_get_squiggly_arrow(n_frames=n_frames))
    else:
        _show_animated_line(
            _get_squiggly_arrow(n=n, n_frames=n_frames), color="trajectory"
        )
def plot_trajs_by_parameter(
    trajs: Union[SingleTraj, TrajEnsemble],
    parameter: Union[
        Literal[
            "common_str",
            "frame",
            "encoded_frame",
            "traj_num",
            "topology",
            "free_energy",
        ],
        str,
    ] = "common_str",
    type: Literal["scatter", "heatmap"] = "scatter",
    x: Optional[np.ndarray] = None,
    y: Optional[np.ndarray] = None,
    contourtype: Literal["contour", "contourf"] = "contour",
    col: str = "lowd",
    nbins: int = 100,
    alpha: float = 0.8,
    z_name_overwrite: str = "",
    show: bool = True,
    cbar: bool = True,
) -> Optional[go.Figure]:
    """Plots 2D coordinates of `trajs`, colored/grouped by a parameter.

    Args:
        trajs (Union[SingleTraj, TrajEnsemble]): The trajectories. They are
            iterated over and must provide `CVs`, `_CVs`, `n_frames`,
            `n_trajs` and `top`.
        parameter (str): What to color by. Handled values are:
            'common_str' (the `common_str` of each frame's trajectory),
            'encoded_frame' (`pair(traj_num, frame_index)`; always plotted as
            scatter), 'traj_num', 'free_energy' (a free-energy contour of
            `trajs.lowd`; `type` is ignored), or the name of any
            1-dimensional CV of `trajs`. 'frame' and 'topology' appear in the
            type hint but are not handled and raise a ValueError.
            Defaults to 'common_str'.
        type (Literal["scatter", "heatmap"]): 'scatter' draws one marker per
            frame. 'heatmap' draws one outlined contour per unique parameter
            value. Defaults to 'scatter'.
        x (Optional[np.ndarray]): x coordinates. If `x` and `y` are both None,
            they are taken from `trajs.CVs[col]`.
        y (Optional[np.ndarray]): y coordinates. See `x`.
        contourtype (Literal["contour", "contourf"]): Currently unused.
        col (str): Name of the CV holding the coordinates when `x` and `y`
            are None. Defaults to 'lowd'.
        nbins (int): Number of bins per axis for 'heatmap'. Defaults to 100.
        alpha (float): Opacity of markers / contour fills. Defaults to 0.8.
        z_name_overwrite (str): If not empty, used instead of `parameter` in
            the title and the color label.
        show (bool): Whether to call `fig.show()`. Defaults to True.
        cbar (bool): Whether to show the color scale. Defaults to True.

    Returns:
        Optional[go.Figure]: The figure, or None if a scatter plot was
            aborted because it would contain more than 25,000 points.

    Raises:
        AssertionError: If only one of `x` and `y` is given, or `col` is not
            a CV of `trajs`.
        ValueError: If `parameter` or `type` is not supported.

    """
    if x is None:
        assert y is None, "Must provide either x and y or both None."
        assert col in trajs._CVs, (
            f"The CV `col`={col} cannot be found in the `trajs` with CVs: "
            f"{list(trajs.CVs.keys())}. Please use `load_CVs` to load the "
            f"low-dimensional coordinates for the `trajs`."
        )
        x, y = trajs.CVs[col].T
    else:
        assert y is not None, "Must provide either x and y or both None."

    # The free energy plot does not scatter points, so it is handled before
    # the scatter size check.
    if parameter == "free_energy":
        fig = go.Figure(
            data=[_plot_free_energy(*trajs.lowd.T, cbar=cbar)],
            layout={
                "autosize": True,
                "height": 800,
                "width": 800,
                "title": "Free Energy",
                "xaxis_title": "x in a.u.",
                "yaxis_title": "y in a.u.",
            },
        )
        if show:
            fig.show()
        return fig

    # encoded frames are always plotted as scatter; set this before the size
    # check so the check also applies to them
    if parameter == "encoded_frame":
        type = "scatter"

    if (
        type == "scatter"
        and x.size > 25_000
        and os.getenv("ENCODERMAP_SKIP_SCATTER_SIZE_CHECK", "False") != "True"
    ):
        print(
            f"The number of points is very large ({x.size}). Using scatter "
            f"with this number of points might crash your browser and maybe "
            f"even your system. Set the environment variable "
            f"'ENCODERMAP_SKIP_SCATTER_SIZE_CHECK' to 'True' to skip this check"
        )
        return None

    data = None
    if parameter == "common_str":
        data = []
        for traj in trajs:
            data.extend([traj.common_str] * traj.n_frames)
    elif parameter == "encoded_frame":
        # Encodermap imports
        from encodermap.loading.features import pair

        data = []
        for traj in trajs:
            data.extend([pair(traj.traj_num, i) for i in range(traj.n_frames)])
    elif parameter == "traj_num":
        data = []
        for traj in trajs:
            data.extend([traj.traj_num] * traj.n_frames)
    elif parameter in trajs.CVs:
        cv = trajs.CVs[parameter]
        if cv.ndim == 1:
            data = cv

    if data is None:
        raise ValueError(
            f"Argument `parameter` must be one of 'common_str', 'frame', "
            f"'encoded_frame', 'traj_num', 'topology', 'free_energy', or any "
            f"of the `TrajEnsemble` 1-dimensional CVs. "
            f"You provided {parameter}."
        )

    # this is the same no matter what datasource we use
    df = pd.DataFrame({"x": x, "y": y, "data": data})
    if z_name_overwrite:
        parameter = z_name_overwrite
    title = parameter.replace("_", " ").title()
    title = (
        f"{title} for Trajectories with {trajs.n_frames} frames, "
        f"{trajs.n_trajs} trajs and {len(trajs.top)} unique topologies."
    )
    if type == "scatter":
        fig = px.scatter(
            df,
            x="x",
            y="y",
            color="data",
            color_continuous_scale="Viridis",
            render_mode="webgl",
            labels={
                "x": "x in a.u.",
                "y": "y in a.u.",
                "data": parameter,
            },
            opacity=alpha,
        )
        if not cbar:
            fig.update_coloraxes(showscale=False)
        fig.update_layout(
            {
                "autosize": True,
                "height": 800,
                "width": 800,
                "title": title,
            }
            | GLOBAL_LAYOUT
        )
    elif type == "heatmap":
        if len(np.unique(df["data"])) > 10:
            colors = px.colors.qualitative.Alphabet
        else:
            colors = px.colors.qualitative.Plotly

        traces = []
        bins = [
            np.linspace(np.min(df["x"]), np.max(df["x"]), nbins + 1, endpoint=True),
            np.linspace(np.min(df["y"]), np.max(df["y"]), nbins + 1, endpoint=True),
        ]
        xcenters = np.mean(np.vstack([bins[0][:-1], bins[0][1:]]), axis=0)
        ycenters = np.mean(np.vstack([bins[1][:-1], bins[1][1:]]), axis=0)
        # group by the column name, not by the raw values (a list of values
        # could be mistaken for column labels by pandas)
        for i, (datapoint, sub_df) in enumerate(df.groupby("data")):
            # cycle through the palette when there are more groups than colors
            color = colors[i % len(colors)]
            H, _, __ = np.histogram2d(*sub_df[["x", "y"]].values.T, bins=bins)
            traces.append(
                go.Contour(
                    x=xcenters,
                    y=ycenters,
                    z=H.T,
                    contours_type="constraint",
                    contours_operation="<",
                    contours_value=0,
                    contours_coloring="none",
                    fillcolor=hex_to_rgba(color, alpha=alpha),
                    line_color=color,
                    name=str(datapoint),
                    visible=True,
                ),
            )

        fig = go.Figure(
            data=traces,
            layout={
                "autosize": True,
                "width": 800,
                "height": 800,
                "title": title,
            }
            | GLOBAL_LAYOUT,
        )
    else:
        raise ValueError(
            f"Argument `type` must be either 'scatter' or 'heatmap'. You provided {type}."
        )
    if show:
        fig.show()
    return fig
def _plot_free_energy(
    x: np.ndarray,
    y: np.ndarray,
    bins: int = 100,
    weights: Optional[np.ndarray] = None,
    kT: float = 1.0,
    avoid_zero_count: bool = False,
    minener_zero: bool = True,
    transpose: bool = True,
    cbar: bool = False,
    cbar_label: str = "free energy / kT",
    colorbar_x: Optional[float] = None,
) -> go.Contour:
    """Creates a plotly contour trace of the free energy of 2D data.

    The free energy grid is computed by `get_free_energy` and wrapped into a
    `go.Contour` trace (Viridis colorscale, no hover info) that can be added
    to any figure.

    Args:
        x (np.ndarray): The x coordinates of the data.
        y (np.ndarray): The y coordinates of the data.
        bins (int): The number of bins passed to np.histogram2d.
            Defaults to 100.
        weights (Optional[np.ndarray]): The weights passed to np.histogram2d.
            Defaults to None.
        kT (float): The value of kT in the desired energy unit. By default,
            energies are computed in kT (setting 1.0). If you want to
            measure the energy in kJ/mol at 298 K, use kT=2.479 and
            change the cbar_label accordingly. Defaults to 1.0.
        avoid_zero_count (bool): Avoid zero counts by lifting all
            histogram elements to the minimum value before computing the free
            energy. If False, zero histogram counts would yield infinity
            in the free energy. Defaults to False.
        minener_zero (bool): Shifts the energy minimum to zero.
            Defaults to True.
        transpose (bool): Whether to transpose the output. Defaults to True.
        cbar (bool): Whether to display a colorbar. Defaults to False.
        cbar_label (str): The label of the colorbar.
            Defaults to 'free energy / kT'.
        colorbar_x (Optional[float]): Sets the x position with respect to xref
            of the color bar (in plot fraction). When xref is "paper", None
            becomes 1.02 when orientation is "v" and 0.5 when orientation is
            "h". When xref is "container", None becomes 1 when orientation is
            "v" and 0.5 when orientation is "h". Must be between 0 and 1 if
            xref is "container" and between -2 and 3 if xref is "paper".
            Defaults to None.

    Returns:
        go.Contour: The contour plot.

    Examples:
        >>> import numpy as np
        >>> import plotly.graph_objects as go
        >>> from encodermap.plot.plotting import _plot_free_energy
        ...
        >>> x, y = np.random.normal(size=(2, 1000))
        >>> trace = _plot_free_energy(x, y, bins=10)
        >>> fig = go.Figure(data=[trace])
        >>> np.any(fig.data[0].z == float("inf"))
        True

    """
    X, Y, Z = get_free_energy(
        x=x,
        y=y,
        weights=weights,
        bins=bins,
        kT=kT,
        avoid_zero_count=avoid_zero_count,
        minener_zero=minener_zero,
        transpose=transpose,
    )
    trace = go.Contour(
        x=X,
        y=Y,
        z=Z,
        name="Lowd projection",
        showscale=cbar,
        hoverinfo="none",
        colorscale="Viridis",
        colorbar_title=cbar_label,
        # histfunc="count",
        colorbar_x=colorbar_x,
    )
    return trace
def plot_free_energy(
    x: np.ndarray,
    y: np.ndarray,
    bins: int = 100,
    weights: Optional[np.ndarray] = None,
    kT: float = 1.0,
    avoid_zero_count: bool = False,
    minener_zero: bool = True,
    transpose: bool = True,
    cbar: bool = False,
    cbar_label: str = "free energy / kT",
    colorbar_x: Optional[float] = None,
) -> None:
    """Plots the free energy of 2D data using plotly and shows the figure.

    The trace is created by `_plot_free_energy` (all arguments are forwarded)
    and placed in a 500x500 figure using `GLOBAL_LAYOUT`. The figure is shown
    immediately; nothing is returned.

    Args:
        x (np.ndarray): The x coordinates of the data.
        y (np.ndarray): The y coordinates of the data.
        bins (int): The number of bins passed to np.histogram2d.
            Defaults to 100.
        weights (Optional[np.ndarray]): The weights passed to np.histogram2d.
            Defaults to None.
        kT (float): The value of kT in the desired energy unit. By default,
            energies are computed in kT (setting 1.0). If you want to
            measure the energy in kJ/mol at 298 K, use kT=2.479 and
            change the cbar_label accordingly. Defaults to 1.0.
        avoid_zero_count (bool): Avoid zero counts by lifting all
            histogram elements to the minimum value before computing the free
            energy. If False, zero histogram counts would yield infinity
            in the free energy. Defaults to False.
        minener_zero (bool): Shifts the energy minimum to zero.
            Defaults to True.
        transpose (bool): Whether to transpose the output. Defaults to True.
        cbar (bool): Whether to display a colorbar. Defaults to False.
        cbar_label (str): The label of the colorbar.
            Defaults to 'free energy / kT'.
        colorbar_x (Optional[float]): Sets the x position with respect to xref
            of the color bar (in plot fraction). When xref is "paper", None
            becomes 1.02 when orientation is "v" and 0.5 when orientation is
            "h". When xref is "container", None becomes 1 when orientation is
            "v" and 0.5 when orientation is "h". Must be between 0 and 1 if
            xref is "container" and between -2 and 3 if xref is "paper".
            Defaults to None.

    """
    fig = go.Figure(
        data=[
            _plot_free_energy(
                x=x,
                y=y,
                bins=bins,
                weights=weights,
                kT=kT,
                avoid_zero_count=avoid_zero_count,
                minener_zero=minener_zero,
                transpose=transpose,
                cbar=cbar,
                cbar_label=cbar_label,
                colorbar_x=colorbar_x,
            ),
        ],
        layout={
            "width": 500,
            "height": 500,
        }
        | GLOBAL_LAYOUT,
    )
    fig.show()
def interactive_path_visualization(
    traj: SingleTraj,
    lowd: Union[np.ndarray, pd.DataFrame],
    path: np.ndarray,
) -> widgets.GridBox:
    """Links a free-energy plot with a path marker to an nglview structure view.

    The returned grid contains:

    * a plotly `FigureWidget` with the free energy of `lowd`, the `path` as a
      line, and a marker at the current path position;
    * an nglview widget showing `traj`;
    * a play button and a slider.

    Moving the slider to index `i` sets the nglview frame to `i` and moves
    the marker to `path[i]`.

    Args:
        traj (SingleTraj): The trajectory to show. Must have exactly one
            frame per path point.
        lowd (Union[np.ndarray, pd.DataFrame]): Low-dimensional data used for
            the free-energy background. A DataFrame must have 'x' and 'y'
            columns. An array is presumably of shape (n_points, 2); this is
            not checked here.
        path (np.ndarray): The path. Columns 0 and 1 are used as x and y.

    Returns:
        widgets.GridBox: The combined widget.

    Raises:
        AssertionError: If `traj` and `path` have different lengths.

    """
    assert len(traj) == len(
        path
    ), f"Path has {len(path)} points, Trajectory has {len(traj)} frames."

    # define the traces
    if isinstance(lowd, pd.DataFrame):
        lowd = lowd[["x", "y"]].values
    trace1 = _plot_free_energy(*lowd.T, transpose=True)
    trace2 = go.Scatter(
        mode="lines",
        x=path[:, 0],
        y=path[:, 1],
        name="Path",
    )
    trace3 = go.Scatter(
        mode="markers",
        marker={"size": 10},
        x=[path[0, 0]],
        y=[path[0, 1]],
        name="Current path pos.",
    )

    # create a figure widget
    g = go.FigureWidget(
        data=[trace1, trace2, trace3],
        layout=go.Layout(
            {
                "height": 500,
                "width": 500,
                "showlegend": False,
                "margin": {
                    "t": 0,
                    "b": 0,
                    "l": 0,
                    "r": 0,
                },
            }
        ),
    )

    # create the nglview widget
    nglview = nv.show_mdtraj(traj.traj)

    # create the media slider; the last valid index is len(path) - 1
    last_index = len(path) - 1
    media_widget = widgets.Play(
        value=0,
        min=0,
        max=last_index,
        step=1,
        disabled=False,
    )
    media_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=last_index,
    )
    widgets.jslink((media_widget, "value"), (media_slider, "value"))

    box1 = widgets.Box(
        children=[g],
        layout=widgets.Layout(
            width="auto",
            height="auto",
            grid_area="main",
        ),
        style=widgets.Style(
            margin="0 0 0 0",
            pad="0 0 0 0",
        ),
    )

    box2 = widgets.Box(
        children=[nglview],
        layout=widgets.Layout(
            width="auto",
            height="auto",
            grid_area="sidebar",
        ),
        style=widgets.Style(
            margin="0 0 0 0",
            pad="0 0 0 0",
        ),
    )

    box3 = widgets.HBox(
        children=[media_widget, media_slider],
        layout=widgets.Layout(
            width="auto",
            height="auto",
            grid_area="footer",
            align_content="center",
        ),
        style=widgets.Style(
            margin="0 0 0 0",
            pad="0 0 0 0",
        ),
    )

    container = widgets.GridBox(
        children=[
            box1,
            box2,
            box3,
        ],
        layout=widgets.Layout(
            width="100%",
            grid_template_columns="auto auto",
            grid_template_rows="1000 px",
            grid_gap="5px",
            grid_template_areas="""
            "main sidebar sidebar sidebar"
            "footer footer footer footer"
            """,
        ),
    )

    def advance_path(change: dict) -> None:
        """Syncs the nglview frame and the path marker to the slider value."""
        index = change["new"]
        nglview.frame = index
        g.data[2].x = [path[index, 0]]
        g.data[2].y = [path[index, 1]]

    media_slider.observe(advance_path, names="value")

    return container
def distance_histogram_interactive(
    data: Union[np.ndarray, pd.DataFrame],
    periodicity: float,
    low_d_max: float = 5.0,
    n_values: int = 1000,
    bins: Union[Literal["auto"], int] = "auto",
    initial_guess: Optional[tuple[float, ...]] = None,
    renderer: Optional[Literal["colab", "plotly_mimetype+notebook"]] = None,
    parameters: Optional["AnyParameters"] = None,
) -> widgets.VBox:  # pragma: no cover
    """Interactive version of `distance_histogram`.

    Builds a three-row plotly `FigureWidget` plus six sliders:

    * Row 1: the normalized histogram of all pairwise (periodic) high-d
      distances, the high-d sigmoid, and its normalized derivative.
    * Row 2: lines connecting each histogram bin edge to the low-d distance
      it maps to through the two sigmoids.
    * Row 3: the low-d sigmoid.

    Moving a slider redraws the sigmoid traces and the connection lines. If
    `parameters` is given, the slider callback also writes the current six
    values to it.

    Args:
        data (Union[np.ndarray, pd.DataFrame]): 2-dimensional data. Columns
            should iterate over the dimensions of the datapoints, i.e. the
            dimensionality of the data. The rows should iterate over
            datapoints. Must not contain NaNs.
        periodicity (float): Periodicity of the data. Use `float("inf")`
            for non-periodic data.
        low_d_max (float): Upper limit for plotting the low_d sigmoid.
            Defaults to 5.0.
        n_values (int): The number of x-values to use for the plotting
            of the sigmoid functions. Used in `np.linspace(min, max, n_values)`.
            Defaults to 1000.
        bins (Union[Literal["auto"], int]): Number of bins for histogram.
            Use 'auto' to let numpy decide how many bins to use. Defaults to 'auto'.
        initial_guess (Optional[tuple[float, ...]]): Tuple of sketchmap
            sigmoid parameters in shape (highd_sigma, highd_a, highd_b,
            lowd_sigma, lowd_a, lowd_b). If None is provided, the default
            values: (4.5, 12, 6, 1, 2, 6) are chosen. Defaults to None.
        renderer (Optional[str]): If None, it is resolved to 'colab' when
            `google.colab` can be imported and to 'plotly_mimetype+notebook'
            otherwise. NOTE: the resolved value is currently not used any
            further in this function.
        parameters (Optional[AnyParameters]): An instance of
            `encodermap.Parameters` or `encodermap.ADCParameters`. On every
            slider change, the six sigmoid parameters are written to its
            `cartesian_dist_sig_parameters` attribute if it has one, and to
            `dist_sig_parameters` otherwise. Defaults to None.

    Returns:
        widgets.VBox: A box with the low-d sliders, the high-d sliders and the
            figure widget.

    Raises:
        AssertionError: If `data` contains NaNs.

    """
    # decide the renderer
    if renderer is None:
        try:
            # Third Party Imports
            from google.colab import data_table  # noqa: F401

            renderer = "colab"
        except (ModuleNotFoundError, NameError):
            renderer = "plotly_mimetype+notebook"

    assert not np.any(np.isnan(data)), "You provided some nans."

    # some helper functions
    def my_ceil(a, precision=0):
        return np.round(a + 0.5 * 10 ** (-precision), precision)

    def sigmoid(r, sigma=1, a=1, b=1):
        return 1 - (1 + (2 ** (a / b) - 1) * (r / sigma) ** a) ** (-b / a)

    def map_edges_to_lowd(edges_, highd_params, x_lowd, y_lowd):
        """For each high-d edge, the low-d x whose sigmoid value is closest."""
        edges_sig = sigmoid(edges_, **highd_params)
        idx = np.argmin(
            np.abs(np.expand_dims(edges_sig, axis=1) - np.expand_dims(y_lowd, axis=0)),
            axis=1,
        )
        return x_lowd[idx]

    def get_connection_traces(highd, lowd, lowd_max, highd_max):
        for i, (h, l) in enumerate(zip(highd, lowd)):
            yield go.Scatter(
                x=[l / lowd_max, h / highd_max],
                y=[0, 1],
                mode="lines",
                name=f"connection_{i}",
                showlegend=False,
                line_width=0.8,
                line_color="black",
                hovertemplate=f"{h:.2f} in highd maps to {l:.2f} in lowd",
            )

    def make_slider(value, max_value, description):
        return widgets.FloatSlider(
            value=value,
            min=0.1,
            max=max_value,
            step=0.1,
            description=description,
            continuous_update=True,
        )

    # get the distances while accounting for periodicity
    vecs = periodic_distance_np(
        np.expand_dims(data, axis=1), np.expand_dims(data, axis=0), periodicity
    )
    dists = np.linalg.norm(vecs, axis=2)
    # collapse remaining feature axes until only the (n, n) pair matrix is left;
    # checking ndim works on every NumPy version (no AxisError import needed)
    while dists.ndim > 2:
        dists = np.linalg.norm(dists, axis=2)
    dists = dists.reshape(-1)
    # np.max instead of builtin max: dists has n**2 entries
    high_d_max = np.max(dists)

    # use the initial guess or default values
    if initial_guess is None:
        initial_guess = (4.5, 12, 6, 1, 2, 6)

    # instantiate the sliders
    lowd_sigma_slider = make_slider(initial_guess[3], my_ceil(low_d_max, 1), "lowd sigma")
    lowd_a_slider = make_slider(initial_guess[4], 12.0, "lowd a")
    lowd_b_slider = make_slider(initial_guess[5], 12.0, "lowd b")
    highd_sigma_slider = make_slider(initial_guess[0], my_ceil(high_d_max, 1), "highd sigma")
    highd_a_slider = make_slider(initial_guess[1], 12.0, "highd a")
    highd_b_slider = make_slider(initial_guess[2], 12.0, "highd b")

    # histogram
    H, edges = np.histogram(dists, bins=bins, density=True)
    H *= 1 / max(H)
    centers = np.mean(np.vstack([edges[0:-1], edges[1:]]), axis=0)

    # highd sigmoid
    x_h = np.linspace(0, high_d_max, n_values)
    highd_data = {
        "sigma": highd_sigma_slider.value,
        "a": highd_a_slider.value,
        "b": highd_b_slider.value,
    }
    y_h = sigmoid(x_h, **highd_data)

    # diff and norm; np.diff yields n_values - 1 points, plotted against x_h[:-1]
    dy = np.diff(y_h)
    dy_norm = dy / max(dy)

    # lowd sigmoid
    x_l = np.linspace(0, low_d_max, n_values)
    lowd_data = {
        "sigma": lowd_sigma_slider.value,
        "a": lowd_a_slider.value,
        "b": lowd_b_slider.value,
    }
    y_l = sigmoid(x_l, **lowd_data)
    edges_l = map_edges_to_lowd(edges, highd_data, x_l, y_l)

    # initial subplot with two traces
    fig = make_subplots(rows=3, cols=1, subplot_titles=["highd", "scaling", "lowd"])

    # add the bar
    fig.add_trace(
        go.Bar(
            x=centers,
            y=H,
            name="highd dists",
            xaxis="x1",
            yaxis="y1",
            marker_color="blue",
            marker_opacity=0.5,
        )
    )

    # add the lowd sigmoid
    fig.add_trace(
        go.Scatter(
            x=x_l,
            y=y_l,
            mode="lines",
            name="lowd sigmoid",
            line_color="orange",
            showlegend=False,
        ),
        row=3,
        col=1,
    )

    # add the title
    fig.update_layout(
        height=800,
        width=600,
        title={
            "text": "Interact with the plot to select sigmoid parameters",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "middle",
        },
    )

    # add the highd sigmoids to a second axis
    fig.add_trace(
        go.Scatter(
            x=x_h,
            y=y_h,
            name="sigmoid",
            line_color="orange",
            mode="lines",
            xaxis="x4",
            yaxis="y4",
        ),
        row=1,
        col=1,
    )

    fig.add_trace(
        go.Scatter(
            x=x_h[:-1],
            y=dy_norm,
            name="diff sigmoid",
            line_color="green",
            mode="lines",
            xaxis="x4",
            yaxis="y4",
        ),
        row=1,
        col=1,
    )

    # some adjustments for the axes
    fig.update_layout(
        xaxis1=dict(
            title="highd distance",
            showgrid=True,
        ),
        xaxis3=dict(
            title="lowd distance",
        ),
        yaxis2=dict(
            showticklabels=False,
        ),
        bargap=0,
    )

    # add connection lines (row 2), x positions normalized to [0, 1]
    connection_names = []
    _lowd_max = np.max(x_l).copy()
    _highd_max = np.max(x_h)
    for trace in get_connection_traces(edges, edges_l, _lowd_max, _highd_max):
        fig.add_trace(trace, row=2, col=1)
        connection_names.append(trace.name)

    # create a figure widget and look up trace indices by name
    g = go.FigureWidget(fig)
    trace_name_list = [trace["name"] for trace in g["data"]]
    lowd_sigmoid_trace_index = trace_name_list.index("lowd sigmoid")
    highd_sigmoid_trace_index = trace_name_list.index("sigmoid")
    diff_sigmoid_trace_index = trace_name_list.index("diff sigmoid")
    connection_indices = np.where(
        np.isin(np.asarray(trace_name_list), np.asarray(connection_names))
    )[0]

    # slider description -> sigmoid keyword it controls
    slider_keyword = {
        "lowd sigma": "sigma",
        "lowd a": "a",
        "lowd b": "b",
        "highd sigma": "sigma",
        "highd a": "a",
        "highd b": "b",
    }

    # define the response function
    def response(change):
        nonlocal highd_data
        nonlocal lowd_data
        key = change["owner"].description
        kwarg = slider_keyword[key]
        if "highd" in key:
            highd_data |= {kwarg: change["new"]}
        else:
            lowd_data |= {kwarg: change["new"]}
        new_y_h = sigmoid(x_h, **highd_data)
        new_y_l = sigmoid(x_l, **lowd_data)
        new_dy = np.diff(new_y_h)
        new_dy_norm = new_dy / max(new_dy)
        new_edges_l = map_edges_to_lowd(edges, highd_data, x_l, new_y_l)

        # update the parameters
        if parameters is not None:
            if hasattr(parameters, "cartesian_dist_sig_parameters"):
                attr_name = "cartesian_dist_sig_parameters"
            else:
                attr_name = "dist_sig_parameters"
            payload = (
                highd_data["sigma"],
                highd_data["a"],
                highd_data["b"],
                lowd_data["sigma"],
                lowd_data["a"],
                lowd_data["b"],
            )
            setattr(parameters, attr_name, payload)

        # update the fig
        with g.batch_update():
            g.data[highd_sigmoid_trace_index].y = new_y_h
            g.data[diff_sigmoid_trace_index].y = new_dy_norm
            g.data[lowd_sigmoid_trace_index].y = new_y_l
            for j, l, h in zip(connection_indices, new_edges_l, edges):
                g.data[j].x = [l / _lowd_max, h / _highd_max]
                g.data[j].hovertemplate = f"{h:.2f} in highd maps to {l:.2f} in lowd"

    # observe the widgets
    for slider in (
        lowd_sigma_slider,
        lowd_a_slider,
        lowd_b_slider,
        highd_sigma_slider,
        highd_a_slider,
        highd_b_slider,
    ):
        slider.observe(response, names="value")

    # create containers
    lowd_container = widgets.HBox(
        children=[lowd_sigma_slider, lowd_a_slider, lowd_b_slider]
    )
    highd_container = widgets.HBox(
        children=[highd_sigma_slider, highd_a_slider, highd_b_slider]
    )

    # start the app
    return widgets.VBox([lowd_container, highd_container, g])
def distance_histogram(
    data: np.ndarray,
    periodicity: float,
    sigmoid_parameters: tuple[float, float, float, float, float, float],
    axes: Optional[plt.Axes] = None,
    low_d_max: int = 5,
    bins: Union[Literal["auto"], int] = "auto",
) -> tuple[plt.Axes, plt.Axes, plt.Axes]:  # pragma: no cover
    """Plots the histogram of all pairwise distances in the data.

    It also shows the sigmoid function and its normalized derivative.
    Annotation lines connect each histogram bin edge to the low-d distance
    whose low-d sigmoid value is closest to the high-d sigmoid value of that
    edge.

    Args:
        data (np.ndarray): 2-dimensional numpy array. Columns should iterate
            over the dimensions of the datapoints, i.e. the dimensionality
            of the data. The rows should iterate over datapoints.
        periodicity (float): Periodicity of the data. Use float("inf")
            for non-periodic data.
        sigmoid_parameters (tuple): Tuple of sketchmap sigmoid parameters
            in shape (Sigma, A, B, sigma, a, b).
        axes (Optional[Sequence[plt.Axes]]): A sequence of two matplotlib
            axes or None. If None is provided, the axes will be created.
            Defaults to None.
        low_d_max (int, optional): Upper limit for plotting the low_d sigmoid.
            Defaults to 5.
        bins (Union[str, int], optional): Number of bins for histogram.
            Use 'auto' to let matplotlib decide how many bins to use. Defaults to 'auto'.

    Returns:
        tuple: A tuple containing the following:
            - plt.axes: A matplotlib.pyplot axis used to plot the high-d distance
                sigmoid.
            - plt.axes: A matplotlib.pyplot axis used to plot the high-d distance
                histogram (a twinx of the first axis).
            - plt.axes: A matplotlib.pyplot axis used to plot the lowd sigmoid.

    """
    vecs = periodic_distance_np(
        np.expand_dims(data, axis=1), np.expand_dims(data, axis=0), periodicity
    )
    dists = np.linalg.norm(vecs, axis=2)
    # Collapse remaining feature axes until only the (n, n) pair matrix is
    # left. Checking ndim avoids `np.AxisError`, which was removed from the
    # main namespace in NumPy 2.0.
    while dists.ndim > 2:
        dists = np.linalg.norm(dists, axis=2)
    dists = dists.reshape(-1)

    if axes is None:
        _, axes = plt.subplots(2)
    axe2 = axes[0].twinx()
    counts, edges, patches = axe2.hist(
        dists, bins=bins, density=True, edgecolor="black"
    )
    # np.max instead of builtin max: dists has n**2 entries
    x = np.linspace(0, np.max(dists), 1000)
    y = sigmoid(x, *sigmoid_parameters[:3])
    edges_sig = sigmoid(edges, *sigmoid_parameters[:3])
    dy = np.diff(y)
    dy_norm = dy / max(dy)
    axes[0].plot(x, y, color="C1", label="sigmoid")
    axes[0].plot(x[:-1], dy_norm, color="C2", label="diff sigmoid")

    axes[0].legend()
    axes[0].set_xlabel("distance")
    axes[0].set_ylim((0, 1))
    # draw the sigmoid axis above the histogram twin axis
    axes[0].set_zorder(axe2.get_zorder() + 1)
    axes[0].patch.set_visible(False)
    axes[0].set_title("high-d")

    x = np.linspace(0, low_d_max, 1000)
    y = sigmoid(x, *sigmoid_parameters[3:])
    # for every high-d edge, the low-d x with the closest sigmoid value
    idx = np.argmin(
        np.abs(np.expand_dims(edges_sig, axis=1) - np.expand_dims(y, axis=0)), axis=1
    )
    edges_x = x[idx]

    axes[1].plot(x, y, color="C1", label="sigmoid")

    axes[1].legend()
    axes[1].set_xlabel("distance")
    axes[1].set_ylim((0, 1))
    axes[1].set_title("low-d")
    for edge, edge_x in zip(edges, edges_x):
        # skip edges that map onto the same low-d point as the last edge
        if edge_x != edges_x[-1]:
            axes[1].annotate(
                "",
                xy=(edge, 0),
                xytext=(edge_x, 0),
                xycoords=axes[0].transData,
                textcoords=axes[1].transData,
                arrowprops=dict(facecolor="black", arrowstyle="-", clip_on=False),
            )
    axes[0].figure.tight_layout()
    return axes[0], axe2, axes[1]
def plot_raw_data(
    xyz: Union[np.ndarray, "SingleTraj"],
    frame_slice: slice = slice(0, 5),
    atom_slice: slice = slice(0, 50, 5),
) -> go.Figure:  # pragma: no cover
    """Plots the raw data of a trajectory as xyz slices in a 3D plot.

    Every selected frame becomes one surface colored by the coordinate
    values. All surfaces share one color axis spanning the minimum and
    maximum of the plotted values.

    Conventions:
        * x: The cartesian coordinates.
        * y: The atom.
        * z: The frame.

    Args:
        xyz (Union[np.ndarray, "SingleTraj"]): Can be either a numpy array with
            shape (n_frames, n_atoms, 3) or a SingleTraj object.
        frame_slice (slice): A slice to select the frames you want.
        atom_slice (slice): A slice to select the atoms you want.

    Returns:
        go.Figure: The figure (it is not shown).

    """
    if not isinstance(xyz, np.ndarray):
        data = xyz.xyz
        frame_extend, atom_extend = data.shape[:2]
        atoms_ind = []
        for i in np.array(xyz.top.select("all"))[atom_slice]:
            atoms_ind.append(str(xyz.top.atom(i)))
    else:
        data = xyz
        frame_extend, atom_extend = data.shape[:2]
        atoms_ind = np.arange(atom_extend)[atom_slice]
    frames_ind = np.arange(frame_extend)[frame_slice]

    # create the surfaces
    surfaces = []
    cmin = float("inf")
    cmax = -float("inf")
    for frame, xyz_slice in zip(frames_ind, data[frame_slice, atom_slice]):
        x = np.arange(3)
        y = np.arange(len(atoms_ind))
        x, y = np.meshgrid(x, y)
        z = np.full(x.shape, frame)
        # running min / max over all slices for the shared color axis
        cmin = min(cmin, float(xyz_slice.min()))
        cmax = max(cmax, float(xyz_slice.max()))
        customdata = np.stack(
            (
                np.full(xyz_slice.T.shape, fill_value=frame),
                np.tile(atoms_ind, [3, 1]),
                np.tile(["x", "y", "z"], [len(atoms_ind), 1]).T,
                xyz_slice.T,
            ),
            axis=-1,
        )
        # ":d" is a valid d3 integer format (":.d" lacks a precision digit
        # and is rejected)
        text = (
            "Cartesian coordinate %{customdata[2]}<br>of atom %{customdata[1]}<br>at "
            "frame %{customdata[0]:d}: %{customdata[3]:.3f}"
        )
        surfaces.append(
            go.Surface(
                x=x,
                y=y,
                z=z,
                surfacecolor=xyz_slice,
                customdata=customdata,
                coloraxis="coloraxis",
                hovertemplate=text,
                name="",
            )
        )

    # create the figure
    fig = go.Figure(data=surfaces)
    fig.update_layout(
        title_text="Raw data plot",
        title_x=0.5,
        scene=dict(
            xaxis_title="xyz",
            yaxis_title="Atom",
            zaxis_title="Frame no.",
            xaxis=dict(
                tickmode="array",
                tickvals=[0, 1, 2],
                ticktext=["x", "y", "z"],
            ),
            yaxis=dict(
                tickmode="array",
                tickvals=np.arange(len(atoms_ind)),
                ticktext=atoms_ind,
            ),
            zaxis=dict(
                tickmode="array",
                tickvals=frames_ind,
                ticktext=frames_ind,
            ),
        ),
        legend_title="Cartesian coordinate value",
        width=700,
        height=700,
        coloraxis=dict(
            colorscale="viridis",
            colorbar_thickness=25,
            colorbar_len=0.75,
            cmin=cmin,
            cmax=cmax,
        ),
        coloraxis_colorbar=dict(
            title="value of coordinate",
        ),
    )
    return fig
def plot_ball_and_stick(
    traj: "SingleTraj",
    frame_subsample: Union[int, slice] = slice(None, None, 100),
    highlight: Literal["atoms", "bonds", "angles", "dihedrals"] = "atoms",
    atom_indices: Optional[Sequence[int]] = None,
    custom_colors: Optional[dict[int, str]] = None,
    add_angle_arcs: bool = True,
    animation: bool = False,
    persistent_hover: bool = False,
    flatten: bool = False,
) -> None:  # pragma: no cover
    """Builds a ball-and-stick figure of `traj` and shows it right away.

    This is a convenience wrapper. Every argument is forwarded unchanged to
    `_plot_ball_and_stick`, which creates the figure; this function only
    calls `.show()` on the result and returns nothing.

    Args:
        traj (SingleTraj): The trajectory to draw.
        frame_subsample (Union[int, slice]): Forwarded as-is.
        highlight (Literal["atoms", "bonds", "angles", "dihedrals"]):
            Forwarded as-is.
        atom_indices (Optional[Sequence[int]]): Forwarded as-is.
        custom_colors (Optional[dict[int, str]]): Forwarded as-is.
        add_angle_arcs (bool): Forwarded as-is.
        animation (bool): Forwarded as-is.
        persistent_hover (bool): Forwarded as-is.
        flatten (bool): Forwarded as-is.

    """
    forwarded = dict(
        traj=traj,
        frame_subsample=frame_subsample,
        highlight=highlight,
        atom_indices=atom_indices,
        custom_colors=custom_colors,
        add_angle_arcs=add_angle_arcs,
        animation=animation,
        persistent_hover=persistent_hover,
        flatten=flatten,
    )
    _plot_ball_and_stick(**forwarded).show()
def plot_ramachandran(
    angles: Union[tuple[np.ndarray, np.ndarray], np.ndarray, "SingleTraj"],
    subsample: Optional[Union[int, slice, np.ndarray]] = None,
) -> None:  # pragma: no cover
    """Plots a Ramachandran plot using plotly.

    The plot is a filled density contour of phi (x) vs. psi (y) with violin
    marginals. It is shown immediately with ``fig.show()``.

    Args:
        angles (Union[tuple[np.ndarray, np.ndarray], np.ndarray, "SingleTraj"]):
            One of the following:
            * A tuple of np.ndarray, assumed to be ordered like (psi, phi).
            * An array of shape (2, n_frames, n_angles), which is unpacked
              into psi and phi angles.
            * Any other array, which is split as ``angles[::2]`` (psi) and
              ``angles[1::2]`` (phi).
            * A SingleTraj. Its "central_dihedrals" CV is loaded if it is
              missing, and the dihedrals whose names contain "psi" / "phi"
              are used.
        subsample (Optional[Union[int, slice, np.ndarray]]): Any way to
            subsample the data along the first (time) axis, applied before
            the angles are flattened. Accepted forms:
            * An int is used as a step, i.e. every n-th frame. This matches
              ``plot_dssp`` and ``plot_end2end``.
            * A slice selects frames by start, stop, step.
            * A np.ndarray selects frames by their integer index.
            Defaults to None, which uses all frames.

    Notes:
        The data is treated as radians when every |psi| and |phi| is below 4
        (radians never exceed pi in magnitude). Otherwise it is treated as
        degrees. Tick labels are always written in degrees.

    """
    if isinstance(angles, tuple):
        psi, phi = angles
    elif isinstance(angles, np.ndarray):
        if angles.ndim == 3:
            psi, phi = angles
        else:
            # NOTE(review): psi/phi are taken alternately along the FIRST axis;
            # confirm this matches the layout callers pass in.
            psi, phi = angles[::2], angles[1::2]
    elif angles.__class__.__name__ == "SingleTraj":
        if "central_dihedrals" not in angles._CVs:
            angles.load_CV("central_dihedrals")
        _angles = angles._CVs.central_dihedrals
        psi = _angles[
            0, ..., _angles.CENTRAL_DIHEDRALS.str.lower().str.contains("psi")
        ].values
        phi = _angles[
            0, ..., _angles.CENTRAL_DIHEDRALS.str.lower().str.contains("phi")
        ].values
    else:
        raise ValueError("Wrong type for arg `angles`.")

    psi = np.asarray(psi)
    phi = np.asarray(phi)

    # Subsample along the time axis *before* flattening. Otherwise a step
    # would mix frames and dihedrals of multi-residue data.
    if subsample is not None:
        if isinstance(subsample, (int, np.integer)):
            subsample = slice(None, None, subsample)
        psi = psi[subsample]
        phi = phi[subsample]

    psi = psi.flatten()
    phi = phi.flatten()

    # Use |x| and check both angles, so that degree data where all psi values
    # happen to be negative (e.g. helices) is not mistaken for radians.
    if np.all(np.abs(psi) < 4) and np.all(np.abs(phi) < 4):
        ranges = [-np.pi, np.pi]
    else:
        ranges = [-180, 180]
    tickrange = np.linspace(ranges[0], ranges[1], 5)
    # Defined for both units. Previously only set for degrees, which made
    # radian input fail with UnboundLocalError.
    ticklabels = [-180, -90, 0, 90, 180]

    fig = px.density_contour(
        x=phi,
        y=psi,
        marginal_x="violin",
        marginal_y="violin",
        labels={"x": "phi", "y": "psi"},
        range_x=ranges,
        range_y=ranges,
    )

    fig.data[0]["contours"].coloring = "fill"

    fig.update_layout(
        width=700,
        height=700,
        title_text="Ramachandran plot",
        xaxis=dict(
            tickmode="array",
            tickvals=tickrange,
            ticktext=ticklabels,
        ),
        yaxis=dict(
            tickmode="array",
            tickvals=tickrange,
            ticktext=ticklabels,
        ),
    )

    fig.show()
def plot_dssp(
    traj: SingleTraj,
    simplified: bool = True,
    subsample: Optional[Union[int, slice, np.ndarray]] = None,
    residue_subsample: int = 25,
) -> go.Figure:  # pragma: no cover
    """Plot the DSSP secondary-structure assignment of ``traj`` over time.

    The assignment is computed with ``mdtraj.compute_dssp``. It is drawn as
    an RGB image with time on the x-axis and residues on the y-axis, using
    the colors of ``dssp_to_rgb``. Hovering a pixel shows the residue name
    and the ``dssp_to_text`` label. A legend lists every DSSP class of the
    chosen scheme.

    Args:
        traj (SingleTraj): The trajectory to analyse.
        simplified (bool): Use the simplified (C/E/H) DSSP scheme instead of
            the full eight-letter scheme. Defaults to True.
        subsample (Optional[Union[int, slice, np.ndarray]]): Frame selection.
            An int is used as a step, a slice or an integer index array is
            applied directly. Defaults to None, which uses all frames.
        residue_subsample (int): Only every n-th residue gets a y-axis tick
            label. Values <= 0 label every residue. Defaults to 25.

    Returns:
        go.Figure: The plotly figure.

    """
    # Third Party Imports
    import mdtraj as md

    dssp = md.compute_dssp(traj, simplified=simplified)

    # the yticks and yticklabels are created here
    residues = np.arange(traj.top.n_residues)
    residue_names = np.array([str(r) for r in traj.top.residues])

    # subsample the dssp array along the frame axis
    if subsample is not None:
        if isinstance(subsample, (int, np.integer)):
            subsample = slice(None, None, subsample)
        dssp = dssp[subsample]

    # vectorize the dssp-str -> rgb-value function and apply; the dstack gives
    # (n_frames, n_residues, 3), swapped so that residues become image rows
    to_rgb = np.vectorize(partial(dssp_to_rgb, simplified=simplified))
    dssp_color = np.swapaxes(np.dstack(to_rgb(dssp)), 0, 1)
    to_text = np.vectorize(partial(dssp_to_text, simplified=simplified))
    dssp_text = to_text(dssp)

    # create fig; customdata is (n_residues, n_frames, 2) to match the image
    fig = px.imshow(dssp_color)
    customdata_res_names = np.tile(residue_names, (dssp.shape[0], 1))
    customdata = np.stack((customdata_res_names.T, dssp_text.T), axis=-1)
    fig.update(
        data=[
            {
                "customdata": customdata,
                "hovertemplate": "Time: %{x}<br>Residue: %{customdata[0]}<br>DSSP: %{customdata[1]}",
                "name": "",
            },
        ],
    )

    # subsample the residues for label purposes (a step of 0 would crash)
    if residue_subsample > 0:
        residues = residues[::residue_subsample]
        residue_names = residue_names[::residue_subsample]

    fig.update_layout(
        width=1000,
        height=700,
        title="DSSP plot",
        xaxis=dict(title="time"),
        yaxis=dict(
            title="residue",
            tickmode="array",
            tickvals=residues,
            ticktext=residue_names,
        ),
        legend=dict(
            title="DSSP",
            orientation="h",
        ),
    )

    # Add the legend. Names and colors come from the same helpers used for
    # the image, so legend and image cannot drift apart.
    codes = ["C", "E", "H"] if simplified else [" ", "B", "E", "G", "H", "I", "S", "T"]
    if np.any(dssp == "NA"):
        # mdtraj labels non-protein residues "NA"
        codes.append("NA")
    for code in codes:
        red, green, blue = dssp_to_rgb(code, simplified=simplified)
        trace = go.Bar(
            x=[0],
            y=[0],
            name=dssp_to_text(code, simplified=simplified),
            legend="legend1",
            showlegend=True,
            # plotly/CSS rgb() channels range from 0 to 255, not 0 to 1
            marker_color=(
                f"rgb({round(red * 255)}, {round(green * 255)}, "
                f"{round(blue * 255)})"
            ),
        )
        fig.add_trace(trace)
    return fig
def dssp_to_text(
    val: str,
    simplified: bool = False,
) -> str:  # pragma: no cover
    """Translate a single DSSP code into a human-readable label.

    Args:
        val (str): One DSSP code as returned by ``mdtraj.compute_dssp``. The
            code ``"NA"``, which mdtraj assigns to non-protein residues such
            as water, is accepted in both schemes.
        simplified (bool): Whether ``val`` belongs to the simplified
            (C/E/H) scheme instead of the full DSSP scheme.
            Defaults to False.

    Returns:
        str: The label, e.g. "Coil" or "Alpha-helix".

    Raises:
        KeyError: If ``val`` is not a code of the selected scheme.

    """
    simplified_dssp = {
        "C": "Coil",
        "E": "Extended",
        "H": "Helical",
        "NA": "Non-protein residue",
    }
    dssp = {
        " ": "Coil",
        "B": "Isolated beta-bridge",
        "E": "Extended beta-ladder",
        "G": "3/10-helix",
        "H": "Alpha-helix",
        "I": "Pi-helix",
        "S": "Bend",
        "T": "Hydrogen bonded Turn",
        "NA": "Non-protein residue",
    }
    if simplified:
        return simplified_dssp[val]
    return dssp[val]
def dssp_to_rgb(
    val: str,
    simplified: bool = False,
) -> tuple[float, float, float]:  # pragma: no cover
    """Map a single DSSP code to an RGB color with channels in [0, 1].

    Here are the values returned for simplified:

    * "C": coil, white, rgb(1, 1, 1)
    * "E": extended, red, rgb(1, 0, 0)
    * "H": helix, blue, rgb(0, 0, 1)

    And here for the full DSSP assignment:

    * " ": coil, white, rgb(1, 1, 1)
    * "B": b-bridge, black, rgb(0, 0, 0)
    * "E": b-sheet, red, rgb(1, 0, 0)
    * "G": 3_10 helix, grey, rgb(0.5, 0.5, 0.5)
    * "H": A-helix, blue, rgb(0, 0, 1)
    * "I": pi-helix, cyan, rgb(0, 1, 1)
    * "S": bend, green, rgb(0, 1, 0)
    * "T": turn, yellow, rgb(1, 1, 0)

    In both schemes "NA" (mdtraj's code for non-protein residues) maps to
    light grey, rgb(0.75, 0.75, 0.75).

    Args:
        val (str): The dssp value.
        simplified (bool): Whether to use the simplified scheme.
            Defaults to False.

    Returns:
        tuple[float, float, float]: The (r, g, b) color.

    Raises:
        KeyError: If ``val`` is not a code of the selected scheme.

    """
    simplified_dssp = {
        "C": (1.0, 1.0, 1.0),
        "E": (1.0, 0.0, 0.0),
        "H": (0.0, 0.0, 1.0),
        "NA": (0.75, 0.75, 0.75),
    }
    dssp = {
        " ": (1.0, 1.0, 1.0),
        "B": (0.0, 0.0, 0.0),
        "E": (1.0, 0.0, 0.0),
        "G": (0.5, 0.5, 0.5),
        "H": (0.0, 0.0, 1.0),
        "I": (0.0, 1.0, 1.0),
        "S": (0.0, 1.0, 0.0),
        "T": (1.0, 1.0, 0.0),
        "NA": (0.75, 0.75, 0.75),
    }
    if simplified:
        return simplified_dssp[val]
    return dssp[val]
def plot_end2end(
    traj: SingleTraj,
    selstr: str = "name CA",
    subsample: Optional[Union[int, slice, np.ndarray]] = None,
    rolling_avg_window: int = 5,
) -> go.Figure:  # pragma: no cover
    """Plot the end-to-end distance of ``traj`` over time.

    The distance is measured between the first and the last atom matched by
    ``selstr``. Every frame is drawn as a semi-transparent scatter point.
    A rolling-average trendline and a violin marginal of the distances are
    added.

    Args:
        traj (SingleTraj): The trajectory to analyse.
        selstr (str): Atom selection string passed to ``traj.top.select``.
            Its first and last match define the two ends. Defaults to
            "name CA".
        subsample (Optional[Union[int, slice, np.ndarray]]): Frame selection.
            An int is used as a step, a slice or an integer index array is
            applied directly. Defaults to None, which uses all frames.
        rolling_avg_window (int): Window size of the rolling-average
            trendline. Defaults to 5.

    Returns:
        go.Figure: The plotly figure.

    Raises:
        ValueError: If ``selstr`` matches fewer than two atoms.

    """
    selected = traj.top.select(selstr)
    if len(selected) < 2:
        raise ValueError(
            f"The selection {selstr!r} matched {len(selected)} atom(s), but "
            "at least two are needed to compute an end-to-end distance."
        )
    atoms = selected[[0, -1]]
    dists = md.compute_distances(traj, [atoms])[:, 0]
    # named ``times`` so the module-level ``time`` import is not shadowed
    times = traj.time
    if subsample is not None:
        if isinstance(subsample, (int, np.integer)):
            subsample = slice(None, None, subsample)
        dists = dists[subsample]
        times = times[subsample]

    fig = px.scatter(
        x=times,
        y=dists,
        labels=dict(
            x="time in ps",
            y="dist in nm",
        ),
        opacity=0.2,
        trendline="rolling",
        trendline_options=dict(
            window=rolling_avg_window,
        ),
        title="end to end distance",
        marginal_y="violin",
    )
    return fig
def _zoomingBoxManual(
    ax1, ax2, color="red", linewidth=2, roiKwargs=None, arrowKwargs=None
):
    """Fakes a zoom effect between two mpl.axes.Axes.

    Uses mpl.patches.ConnectionPatch and mpl.patches.Rectangle to make it
    seem like ax2 is a zoomed-in version of ax1. The zooming rectangle is
    not given explicitly. It is taken from the current axes limits of ax2.

    Two connector lines join the left corners of ax2's view to the right
    corners of the rectangle drawn in ax1. The effect therefore reads best
    when ax2 is placed to the right of ax1.

    Args:
        ax1 (plt.axes): The axes with the zoomed-out data.
        ax2 (plt.axes): The second axes with the zoomed-in data.
        color (str): The color of the rectangle and connector lines. It is
            passed into mpl, thus can be str, or tuple, ... Defaults to 'red'.
        linewidth (int): The linewidth of the rectangle and connector lines.
            Defaults to 2.
        roiKwargs (Optional[dict]): Keyworded arguments for the rectangle.
            They override the defaults above. Defaults to None (no extras).
        arrowKwargs (Optional[dict]): Keyworded arguments for the connector
            lines. They override the defaults above. Defaults to None
            (no extras).

    """
    # the region of interest is exactly the current view of the zoomed-in axes
    (x0, x1), (y0, y1) = ax2.get_xlim(), ax2.get_ylim()

    # user-supplied kwargs take precedence over the defaults
    roi_kwargs = {
        "fill": False,
        "linestyle": "dashed",
        "color": color,
        "linewidth": linewidth,
        **(roiKwargs or {}),
    }
    ax1.add_patch(mpl.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, **roi_kwargs))

    # previously built but never applied (connectors were hard-coded black)
    arrow_kwargs = {
        "arrowstyle": "-",
        "color": color,
        "linewidth": linewidth,
        **(arrowKwargs or {}),
    }
    corner_pairs = (
        ((x0, y0), (x1, y0)),  # bottom-left of ax2 -> bottom-right of ROI
        ((x0, y1), (x1, y1)),  # top-left of ax2 -> top-right of ROI
    )
    for corner_in_ax2, corner_in_ax1 in corner_pairs:
        connection = mpl.patches.ConnectionPatch(
            xyA=corner_in_ax2,
            xyB=corner_in_ax1,
            coordsA="data",
            coordsB="data",
            axesA=ax2,
            axesB=ax1,
            **arrow_kwargs,
        )
        ax2.add_artist(connection)
def _run_render_command(args, debug=False):
    """Run one external command of the vmd/tachyon/convert pipeline.

    Every argument is shell-quoted before the command line is assembled, so
    paths with spaces or shell metacharacters reach the program verbatim.
    The command is still run through the shell, as it was before.

    Args:
        args (list): The program followed by its arguments.
        debug (bool): Print the command line, stdout and stderr.

    Returns:
        tuple[bytes, bytes]: The captured ``(stdout, stderr)``.

    """
    # Standard Library Imports
    import shlex

    cmd = " ".join(shlex.quote(str(arg)) for arg in args)
    if debug:
        print(cmd)
    proc = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True,
    )
    stdout, stderr = proc.communicate()
    if debug:
        print(stdout.decode("utf-8"))
        print("\n")
        print(stderr.decode("utf-8"))
    return stdout, stderr


def render_vmd(
    filepath,
    rotation=None,
    scale=1,
    script_location="auto",
    image_location="auto",
    debug=False,
    image_name="",
    drawframes=False,
    ssupdate=True,
    renderer="tachyon",
    additional_spheres=None,
    additional_lines=None,
    surf=None,
    custom_script=None,
):
    """Render pdb file with a combination of vmd, tachyon and image magick.

    This function creates a standardised vmd tcl/tk script and writes it to
    disk. vmd is then called with the subprocess package to create a tachyon
    input file. Tachyon renders the image with ambient occlusion and soft
    lighting. The output is a high quality targa (.tga) image, which is
    converted to png using image magick.

    Args:
        filepath (Union[str, list[str]]): Location of the pdb file which
            should be rendered, or a list of such files.
        rotation ([x_rot, y_rot, z_rot], optional): List of rotation values.
            Defaults to None, which means [0, 0, 0].
        scale (float, optional): By how much the structure should be scaled.
            Defaults to 1.
        script_location (str, optional): Where to save the script. Defaults
            to 'auto' and writes vmd_script.tcl into cwd.
        image_location (str, optional): Where to render the image files to.
            They are deleted nonetheless. Don't give an extension for this.
            Defaults to 'auto' and writes into cwd.
        debug (bool, optional): Print debug info. Defaults to False.
        image_name (str, optional): After rendering and conversion, the
            result is copied to this name. The copy is not deleted. Defaults
            to ''.
        drawframes (bool, optional): If a trajectory is loaded, this will
            render all frames in it. Defaults to False.
        ssupdate (bool, optional): Updates the secondary structure for every
            frame. Normally vmd uses the secondary structure of the first
            frame. Setting this to True calcs the sec struct for every frame.
            Defaults to True.
        renderer (str, optional): Which renderer to use.
            * 'tachyon' uses the external Tachyon renderer. So vmd -> .dat
              -> .tga -> .png.
            * 'snapshot' uses the vmd internal snapshot renderer.
            * 'STL' exports an .stl mesh and returns it as a trimesh object.
            * 'Wavefront' exports .obj/.mtl files and returns None.
            Defaults to 'tachyon'.
        additional_spheres (list, optional): Draw spheres around two
            subunits to make them visually more distinct. Takes a list of
            lists. Each inner list should contain 4 values [x, y, z, r]
            (r for radius). Only the first two spheres are drawn (grey,
            iceblue). Defaults to None (no spheres).
        additional_lines (list, optional): A list of additional lines that
            should be added to the script. Please refer to the vmd manual
            for further info. Defaults to None (no lines).
        surf (Union[str, None], optional): A string defining the surface
            renderer. Can either be 'quicksurf' or 'surf'. If None is
            provided, the surface won't be rendered (falls back to cartoon
            representation). Defaults to None.
        custom_script (Union[str, None], optional): Provide a completely
            custom script as this option. It replaces the generated script,
            but the render commands are still appended to it. If None is
            provided, the default script will be used.

    See also:
        See this nice webpage about rendering publication worthy images with vmd.
        https://www.ks.uiuc.edu/Research/vmd/minitutorials/tachyonao/

    Returns:
        image (np.ndarray): For 'tachyon' and 'snapshot', this array contains
            the raw pixel data. It can be used with matplotlib to have a
            quick view of the image. 'STL' returns a trimesh mesh and
            'Wavefront' returns None.

    Raises:
        Exception: If ``image_location`` has a file extension.
        NotImplementedError: If ``renderer`` is not supported.
        FileNotFoundError: If ``ssupdate`` is set and sscache.tcl is missing.
        AssertionError: If an expected output file was not produced.

    """
    # None defaults replace the former shared mutable list defaults
    if rotation is None:
        rotation = [0, 0, 0]
    if additional_spheres is None:
        additional_spheres = []
    if additional_lines is None:
        additional_lines = []

    # only the file *name* must be extension-free; dotted directories are fine
    if "." in os.path.basename(image_location):
        raise Exception(
            "The argument image_location does not take a file extension, because the name is used for a .dat, .tga and .png file."
        )

    # fail before any work is done if the renderer is unknown
    supported_renderers = ("tachyon", "snapshot", "STL", "Wavefront")
    if renderer not in supported_renderers:
        raise NotImplementedError(
            f"Renderer {renderer!r} is not supported. Choose one of: "
            f"{', '.join(supported_renderers)}."
        )

    # print debug hello world
    script = 'puts "Hello World"\n'

    # if a list of files is provided we iterate over them
    if isinstance(filepath, list):
        for i, file in enumerate(filepath):
            # load molecule and change representation
            script += f"mol new {file}\n"
            if surf is None:
                script += f"mol modstyle 0 {i} newcartoon 0.3 50\n"
                script += f"mol modcolor 0 {i} structure\n"
            elif surf == "quicksurf":
                script += f"mol modstyle 0 {i} quicksurf 0.6 0.7 0.7 Medium\n"
            else:
                script += f"mol modstyle 0 {i} {surf}\n"
            script += f"mol modmaterial 0 {i} AOChalky\n"
            if drawframes and md.load(file).n_frames > 1:
                if renderer == "STL":
                    # Standard Library Imports
                    import warnings

                    warnings.warn(
                        "Rendering multiple frames with STL may lead to "
                        "undesired results. Instead of yielding the union of "
                        "all single-frame surfaces, you will get a mishmash of "
                        "all surfaces with intersection faces etc."
                    )
                script += f"mol drawframes 0 {i} 0:1:999\n"
    else:
        # load molecule and change representation
        script += f"mol new {filepath}\n"
        if surf is None:
            script += "mol modstyle 0 0 newcartoon 0.3 50\n"
            script += "mol modcolor 0 0 structure\n"
        elif surf == "quicksurf":
            script += "mol modstyle 0 0 quicksurf 0.6 0.7 0.7 Medium\n"
        else:
            script += f"mol modstyle 0 0 {surf}\n"
        script += "mol modmaterial 0 0 AOChalky\n"
        if drawframes:
            script += "mol drawframes 0 0 0:1:999\n"

    if ssupdate:
        print(
            "\033[93m"
            + "For the ssupdate function to work encodermap/vmd/sscache.tcl will be sourced within vmd. If no Error is thrown the file is present."
            + "\033[0m"
        )
        sscache_location = (
            os.path.split(os.path.split(os.path.split(__file__)[0])[0])[0]
            + "/vmd/sscache.tcl"
        )
        if not os.path.isfile(sscache_location):
            raise FileNotFoundError(
                f"The sscache.tcl script is not here: {sscache_location}. Please put it there."
            )
        script += f"source {sscache_location}\n"
        script += "start_sscache 0\n"

    # change some parameters to make a nice image
    script += "color Display Background white\n"
    script += "color Axes Labels black\n"
    script += "display depthcue off\n"
    script += "display ambientocclusion on\n"
    script += "display aoambient 1.0\n"
    script += "display aodirect 0.3\n"
    script += "display antialias on\n"
    script += "axes location off\n"

    # scale and rotate
    script += f"rotate x by {rotation[0]}\n"
    script += f"rotate y by {rotation[1]}\n"
    script += f"rotate z by {rotation[2]}\n"
    script += f"scale by {scale}\n"

    # define image location
    if image_location == "auto":
        image_location = os.getcwd() + "/vmdscene"

    # add spheres (only the first two are drawn, one per color)
    if np.any(additional_spheres):
        for sphere, color in zip(additional_spheres, ["grey", "iceblue"]):
            x, y, z, r = np.round(sphere, 2)
            script += f"draw color {color}\n"
            script += f"draw sphere {{ {x} {y} {z} }} radius {r} resolution 25\n"
            script += "draw material Transparent\n"

    # add additional lines
    for line in additional_lines:
        script += line + "\n"

    if custom_script is not None:
        # make sure the appended render command starts on its own line
        script = custom_script if custom_script.endswith("\n") else custom_script + "\n"

    # render command; the external Tachyon gives the best images
    if renderer == "tachyon":
        script += f"render Tachyon {image_location}.dat\n"
    elif renderer == "snapshot":
        script += "render aasamples TachyonInternal 6\n"
        script += f"render TachyonInternal {image_location}.tga\n"
    elif renderer == "STL":
        script += "axes location off\n"
        script += f"render STL {image_location}.stl\n"
    else:  # "Wavefront"
        script += "axes location off\n"
        script += f"render Wavefront {image_location}.obj\n"

    # list molecules and quit
    script += "mol list\n"
    script += "quit"

    if debug:
        print(script)

    # write the script
    if script_location == "auto":
        script_location = os.getcwd() + "/vmd_script.tcl"
    with open(script_location, "w") as f:
        f.write(script)

    # call vmd -e script
    stdout, stderr = _run_render_command(
        ["vmd", "-e", script_location, "-dispdev", "none"], debug
    )

    # Check the file vmd was asked to write. Previously STL and Wavefront were
    # checked against a .tga file and therefore always failed.
    vmd_output = {
        "tachyon": f"{image_location}.dat",
        "snapshot": f"{image_location}.tga",
        "STL": f"{image_location}.stl",
        "Wavefront": f"{image_location}.obj",
    }[renderer]
    if not os.path.isfile(vmd_output):
        # give a slow file system a moment before declaring failure
        time.sleep(2)
    assert os.path.isfile(vmd_output), (
        f"vmd did not create {vmd_output} (renderer={renderer!r}). "
        f"{stderr.decode()} {stdout.decode()}\n"
        f"Here's the script:\n\n{script}\n\n"
    )

    if renderer == "tachyon":
        # call Tachyon and render
        _run_render_command(
            [
                "/usr/bin/tachyon",
                "-aasamples",
                "12",
                f"{image_location}.dat",
                "-res",
                "2000",
                "2000",
                "-fullshade",
                "-format",
                "TARGA",
                "-o",
                f"{image_location}.tga",
            ],
            debug,
        )
        # check if image has been written
        assert os.path.isfile(
            f"{image_location}.tga"
        ), "Tachyon renderer did not render image"

    if renderer == "STL":
        if image_name:
            shutil.copyfile(f"{image_location}.stl", image_name)
        # Third Party Imports
        import trimesh

        mesh = trimesh.load(f"{image_location}.stl")
        os.remove(f"{image_location}.stl")
        return mesh

    if renderer == "Wavefront":
        if image_name:
            shutil.copyfile(f"{image_location}.obj", image_name)
            shutil.copyfile(f"{image_location}.mtl", image_name.replace(".obj", ".mtl"))
            print(
                f"Find the rendered images at {image_name} and {image_name.replace('.obj', '.mtl')}."
            )
        return None

    # convert to png
    _run_render_command(
        ["/usr/bin/convert", f"{image_location}.tga", f"{image_location}.png"],
        debug,
    )

    # read image
    image = plt.imread(f"{image_location}.png")

    # write image if name has been provided (relative names go into cwd)
    if image_name:
        if os.path.isabs(image_name):
            shutil.copyfile(f"{image_location}.png", image_name)
        else:
            shutil.copyfile(f"{image_location}.png", os.getcwd() + f"/{image_name}")

    # remove temporary files
    if renderer == "tachyon":
        os.remove(f"{image_location}.dat")
    os.remove(f"{image_location}.tga")
    os.remove(f"{image_location}.png")

    # return matplotlib image object
    return image
def plot_cluster(
    trajs, pdb_path, png_path, cluster_no=None, col="_user_selected_points"
):
    """Save a 2x2 overview figure of one cluster in the low-dimensional data.

    The four panels are:
    * Top left: all points in grey, with the cluster points on top.
    * Top right: a contour plot of the 2D point density (counts per bin).
    * Bottom left: a vmd/tachyon rendering of ``pdb_path``, which is
      best-effort. A note is shown if rendering fails.
    * Bottom right: the cluster points with marginal histograms, their RMS
      distance to the geometric center, and that center itself.

    Args:
        trajs: Object providing ``lowd`` (low-dimensional coordinates; the
            first two columns are plotted) and ``CVs[col]`` (one cluster
            label per row of ``lowd``). Presumably an encodermap
            TrajEnsemble. NOTE(review): confirm the expected type.
        pdb_path (Union[str, list[str]]): Passed to ``render_vmd`` for the
            cluster image.
        png_path (str): Where the figure is saved.
        cluster_no (Optional[int]): The cluster label to show. Defaults to
            None, which uses the maximum label in ``trajs.CVs[col]``.
        col (str): Name of the CV holding the cluster labels. Defaults to
            "_user_selected_points".

    Raises:
        ValueError: If no point carries the label ``cluster_no``.

    """
    # Third Party Imports
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    if cluster_no is None:
        cluster_no = trajs.CVs[col].max()

    # some data management
    data = trajs.lowd
    x = data[:, 0]
    y = data[:, 1]
    in_cluster = np.asarray(trajs.CVs[col]) == cluster_no
    if not np.any(in_cluster):
        raise ValueError(f"No points belong to cluster {cluster_no} in CV {col!r}.")
    cluster_x = x[in_cluster]
    cluster_y = y[in_cluster]

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
    fig.set_size_inches(20, 20)

    # prepare ax4 for the two marginal histograms of the cluster
    divider = make_axes_locatable(ax4)
    axHistx = divider.append_axes("top", size=1.2, pad=0.1)
    axHisty = divider.append_axes("right", size=1.2, pad=0.1)

    # everything in grey first, so the cluster drawn afterwards is not hidden
    ax1.scatter(x[~in_cluster], y[~in_cluster], c="grey", s=5)
    ax1.scatter(cluster_x, cluster_y, c="C0")
    ax1.set_xlabel("x in a.u.")
    ax1.set_ylabel("y in a.u.")
    ax1.set_title("Scatter of low-dimensional data")

    # ax2 gets a contour plot of the 2D histogram (bin centers as grid)
    bin_density = 46
    H, xedges, yedges = np.histogram2d(x=x, y=y, bins=bin_density)
    xcenters = (xedges[:-1] + xedges[1:]) / 2
    ycenters = (yedges[:-1] + yedges[1:]) / 2
    X, Y = np.meshgrid(xcenters, ycenters)
    # empty bins fall below the lowest level and are drawn white
    cmap = plt.get_cmap("turbo").with_extremes(under="w")
    mappable = ax2.contourf(
        X, Y, H.T, cmap=cmap, levels=np.linspace(0.001, H.max(), 20)
    )
    ax2.set_xlabel("x in a.u.")
    ax2.set_ylabel("y in a.u.")
    # the contour shows linear counts, not their logarithm
    ax2.set_title("Density of points")

    # colorbar for ax2 in its own axes next to it
    cax = make_axes_locatable(ax2).append_axes("right", size="7%", pad="2%")
    fig.colorbar(mappable, cax=cax)
    cax.set_ylabel("Number of points")

    # marginal histograms of the cluster; hide the right and top spines
    for hist_ax in (axHistx, axHisty):
        spines = list(hist_ax.spines.values())
        spines[1].set_linewidth(0)
        spines[3].set_linewidth(0)
    axHistx.set_xticks([])
    axHistx.hist(cluster_x, bins=50)
    axHistx.set_ylabel("count")
    axHistx.set_title("Scatter of Cluster")
    axHisty.set_yticks([])
    axHisty.hist(cluster_y, bins=50, orientation="horizontal")
    axHisty.set_xlabel("count")

    # scatter the cluster on ax4
    ax4.scatter(x=cluster_x, y=cluster_y)
    spines = list(ax4.spines.values())
    spines[3].set_linewidth(0)
    spines[1].set_linewidth(0)
    ax4.set_xlabel("x in a.u.")
    ax4.set_ylabel("y in a.u.")

    # RMS distance of the cluster points to their geometric center
    centroid = [np.mean(cluster_x), np.mean(cluster_y)]
    rms = np.sqrt(
        np.mean((cluster_x - centroid[0]) ** 2 + (cluster_y - centroid[1]) ** 2)
    )
    text = f"RMS = {np.round(rms, decimals=5)}"
    ax4.text(0.05, 0.95, text, transform=ax4.transAxes)

    # annotate geometric center
    ax4.scatter(*centroid, s=50, c="C1")
    ax4.annotate(
        "geom. center",
        xy=centroid,
        xycoords="data",
        xytext=(0.95, 0.95),
        textcoords="axes fraction",
        arrowprops=dict(facecolor="black", shrink=0.05, fc="w", ec="k", lw=2),
        horizontalalignment="right",
        verticalalignment="top",
        color="C1",
    )

    # make vmd snapshot; best-effort because vmd/tachyon/convert may be missing
    try:
        image = render_vmd(
            pdb_path, drawframes=True, renderer="tachyon", debug=False, scale=1.5
        )
    except Exception:
        ax3.annotate("VMD Rendering not possible", (0.5, 0.5), xycoords="axes fraction")
    else:
        ax3.imshow(image)
        for spine in ax3.spines.values():
            spine.set_linewidth(0)
        ax3.set_xticks([])
        ax3.set_yticks([])
        ax3.set_title("Image of cluster")

    fig.suptitle(f"Cluster {cluster_no}")
    fig.savefig(png_path, transparent=False)
    plt.close(fig)