Coverage for encodermap/loading/featurizer.py: 13%
667 statements
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 22:46 +0200
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 22:46 +0200
1# -*- coding: utf-8 -*-
2# encodermap/loading/featurizer.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""EncoderMap featurization follows the example of the now deprecated PyEMMA package.
24You can define your features in advance, inspect the expected output and then let
25the computer do the number crunching afterwards. This can be done with either
26PyEMMAs streamable featurization or **new** with dask and delayed on a dask-cluster
27of your liking. Here are the basic concepts of EncoderMap's featurization.
29"""
32################################################################################
33# Imports
34################################################################################
37# Future Imports at the top
38from __future__ import annotations
40# Standard Library Imports
41import itertools
42import numbers
43import os
44import re
45import time
46import warnings
47from pathlib import Path
49# Third Party Imports
50import numpy as np
51from optional_imports import _optional_import
53# Encodermap imports
54from encodermap.loading import features
55from encodermap.loading.delayed import build_dask_xarray
58################################################################################
59# Optional Imports
60################################################################################
63xr = _optional_import("xarray")
64md = _optional_import("mdtraj")
65rich = _optional_import("rich")
66Client = _optional_import("dask", "distributed.Client")
67dask = _optional_import("dask")
68Callback = _optional_import("dask", "callbacks.Callback")
69dot_graph = _optional_import("dask", "dot.dot_graph")
70progress = _optional_import("dask", "distributed.progress")
71HDF5TrajectoryFile = _optional_import("mdtraj", "formats.HDF5TrajectoryFile")
72_get_global_client = _optional_import("distributed", "client._get_global_client")
75################################################################################
76# Typing
77################################################################################
80# Standard Library Imports
81from collections.abc import Iterable, Sequence
82from typing import TYPE_CHECKING, Literal, Optional, Union
85if TYPE_CHECKING:
86 # Third Party Imports
87 import dask
88 import xarray as xr
89 from dask import dot_graph
90 from dask.callbacks import Callback
91 from dask.distributed import Client, progress
92 from distributed.client import _get_global_client
94 # Encodermap imports
95 from encodermap.loading.features import AnyFeature
96 from encodermap.trajinfo.info_all import TrajEnsemble
97 from encodermap.trajinfo.info_single import SingleTraj
100################################################################################
101# Import tqdm which can be either the jupyter one or the plain one
102################################################################################
def _is_notebook() -> bool:  # pragma: no cover
    """Report whether the code currently runs inside a Jupyter kernel.

    Returns:
        bool: True if IPython is importable and the config of the active
            IPython shell contains an "IPKernelApp" entry. False if IPython
            is missing, if no shell is active (`get_ipython()` returns an
            object without `config`), or if the entry is absent.

    """
    try:
        # Third Party Imports
        from IPython import get_ipython

        shell_config = get_ipython().config
        return "IPKernelApp" in shell_config
    except (ImportError, AttributeError):
        # No IPython at all, or no active shell to query.
        return False
120if _is_notebook(): # pragma: no cover
121 # Third Party Imports
122 from tqdm.notebook import tqdm
123else:
124 # Third Party Imports
125 from tqdm import tqdm
128################################################################################
129# Globals
130################################################################################
# Public names of this module.
__all__: list[str] = ["Featurizer", "DaskFeaturizer"]


# Maps snake_case feature aliases to the names of the feature classes that
# `add_list_of_feats` looks up on the `encodermap.loading.features` module.
UNDERSCORE_MAPPING: dict[str, str] = {
    "central_dihedrals": "CentralDihedrals",
    "all_cartesians": "AllCartesians",
    "all_distances": "AllBondDistances",
    "central_cartesians": "CentralCartesians",
    "central_distances": "CentralBondDistances",
    "central_angles": "CentralAngles",
    "side_cartesians": "SideChainCartesians",
    "side_distances": "SideChainBondDistances",
    "side_angles": "SideChainAngles",
    "side_dihedrals": "SideChainDihedrals",
}

# Names of `add_*` featurizer methods.
# NOTE(review): not referenced in the visible part of this file; presumably
# used to forward these methods to another class - confirm against the users.
_ADD_X_FUNCTION_NAMES: list[str] = [
    "add_all",
    "add_selection",
    "add_distances_ca",
    "add_distances",
    "add_inverse_distances",
    "add_contacts",
    "add_residue_mindist",
    "add_group_COM",
    "add_residue_COM",
    "add_angles",
    "add_dihedrals",
    "add_minrmsd_to_ref",
    "add_backbone_torsions",
    "add_sidechain_torsions",
]
167################################################################################
168# Utils
169################################################################################
class CoordsLoad:
    """Empty class without attributes or methods of its own.

    NOTE(review): not used in the visible part of this file; presumably a
    marker/placeholder type - confirm against its users.

    """
def is_iterable_of_types(l, supertype):
    """Check that `l` is iterable and all its elements are of `supertype`.

    Args:
        l: The object to inspect. It is iterated once, so a one-shot
            iterator is consumed by this check.
        supertype: A class (or tuple of classes) that the type of every
            element has to be a subclass of.

    Returns:
        bool: False if `l` is not iterable or if any element's type is not
            a subclass of `supertype`. True otherwise (also for an empty
            iterable).

    """
    if not is_iterable(l):
        return False
    # groupby collapses runs of equally-typed neighbors, so every run is
    # checked only once.
    run_types = (kind for kind, _ in itertools.groupby(l, type))
    return all(issubclass(kind, supertype) for kind in run_types)
def is_iterable_of_int(l):
    """Check that `l` is iterable and holds only integral numbers.

    Args:
        l: The object to inspect.

    Returns:
        bool: The result of `is_iterable_of_types` for `numbers.Integral`.

    """
    integral_type = numbers.Integral
    return is_iterable_of_types(l, integral_type)
def is_iterable(I):
    """Return whether `I` is an instance of `collections.abc.Iterable`."""
    return isinstance(I, Iterable)
def _atoms_in_residues(
    traj: SingleTraj,
    residue_idxs: Sequence[int],
    subset_of_atom_idxs: Optional[np.ndarray] = None,
    fallback_to_full_residue: bool = True,
) -> list[np.ndarray]:
    """Returns a list of arrays containing the atom indices in each residue of `residue_idxs`

    Args:
        traj (SingleTraj): A `SingleTraj` instance. Only `traj.top.residues`
            (and the `index`/`atoms` attributes of the residues) is used.
        residue_idxs (Sequence[int]): List or ndarray (ndim=1) of integers.
        subset_of_atom_idxs (Optional[np.ndarray]): Iterable of atom_idxs to which the
            selection has to be restricted. If None, all atoms considered.
            Defaults to None.
        fallback_to_full_residue (bool): It is possible that some
            residues don't yield any atoms with some subsets. Take all atoms in
            that case. If False, then an empty array is returned for that
            residue. Defaults to True.

    Returns:
        list[np.ndarray]: One integer array per selected residue, in the
            order in which the residues appear in `traj.top.residues`.

    Warns:
        UserWarning: If at least one selected residue had no atom in
            `subset_of_atom_idxs`.

    """
    # Sets make the per-residue and per-atom membership tests O(1) instead of
    # scanning a list/array for every single atom.
    wanted_residues = set(np.ravel(list(residue_idxs)).tolist())
    if subset_of_atom_idxs is None:
        allowed_atoms = None  # no restriction: every atom qualifies
    else:
        allowed_atoms = set(np.ravel(list(subset_of_atom_idxs)).tolist())

    atoms_in_residues = []
    special_residues = []
    for residue in traj.top.residues:
        if residue.index not in wanted_residues:
            continue
        # Materialize once; `atoms` may be a generator-producing property.
        atom_idxs = [atom.index for atom in residue.atoms]
        if allowed_atoms is None:
            selected = atom_idxs
        else:
            selected = [idx for idx in atom_idxs if idx in allowed_atoms]
        if not selected:
            special_residues.append(residue)
            if fallback_to_full_residue:
                selected = atom_idxs
        # An explicit dtype keeps empty results integer-typed as well.
        atoms_in_residues.append(np.array(selected, dtype=int))

    # Any special cases?
    if special_residues:
        msg = "the full residue" if fallback_to_full_residue else "empty lists"
        # Only the last two offending residues are listed in the message.
        warnings.warn(
            f"These residues yielded no atoms in the subset and were returned as "
            f"{msg} {special_residues[-2:]}"
        )

    return atoms_in_residues
def combinations(
    seq: Iterable,
    k: int,
) -> np.ndarray:
    """Return all length-`k` combinations of the elements of `seq` as an array.

    The combinations are written directly into one NumPy array instead of
    first building a tuple of tuples like
    `np.array(tuple(itertools.combinations(seq, k)))` would.

    Args:
        seq (Iterable): The elements to combine. Any iterable is accepted,
            including generators; the elements have to be convertible to int.
        k (int): The length of each combination.

    Returns:
        np.ndarray: Integer array of shape (n_combinations, k), in the same
            order as `itertools.combinations`. If `k` exceeds the number of
            elements, the array has shape (0, k).

    Raises:
        ValueError: If `k` is negative.

    Examples:
        >>> import numpy as np
        >>> from itertools import combinations as iter_comb
        >>> x = np.arange(3)
        >>> c1 = combinations(x, 2)
        >>> print(c1)
        [[0 1]
         [0 2]
         [1 2]]
        >>> c2 = np.array(tuple(iter_comb(x, 2)))
        >>> print(c2)
        [[0 1]
         [0 2]
         [1 2]]

    """
    # Standard Library Imports
    from itertools import chain
    from itertools import combinations as _combinations
    from math import comb

    # Materialize once so that iterables without `len` (generators) work.
    items = list(seq)
    count = comb(len(items), k)
    res = np.fromiter(
        chain.from_iterable(_combinations(items, k)), int, count=count * k
    )
    # An explicit row count keeps the reshape well-defined for k == 0.
    return res.reshape(count, k)
def product(*arrays: np.ndarray) -> np.ndarray:
    """Generate a cartesian product of input arrays.

    Args:
        arrays (np.ndarray): 1-D arrays to form the cartesian product of.
            At least one array is required.

    Returns:
        np.ndarray: 2-D array of shape (M, len(arrays)) containing cartesian
            products formed of input arrays. M is the product of the input
            lengths; the last input varies fastest. The dtype is the common
            dtype of the non-empty inputs, so mixing integer and float
            arrays does not truncate the floats.

    """
    arrays = [np.asarray(x) for x in arrays]
    shape = tuple(len(x) for x in arrays)

    # Promote over all inputs instead of forcing everything into the dtype of
    # the first one. Empty inputs carry no values and are ignored here.
    non_empty = [x for x in arrays if x.size]
    dtype = np.result_type(*non_empty) if non_empty else arrays[0].dtype

    # One row per output combination, one column per input array; each entry
    # is the position inside the respective input.
    ix = np.indices(shape).reshape(len(arrays), -1).T

    out = np.empty(ix.shape, dtype=dtype)
    for n, arr in enumerate(arrays):
        out[:, n] = arr[ix[:, n]]

    return out
def _parse_pairwise_input(
    indices1: Sequence[int],
    indices2: Optional[Sequence[int]],
) -> np.ndarray:
    """For input of a pairwise type (distances, inverse distances, contacts) checks
    the type of input the user gave and formats it so that `DistanceFeature`,
    `InverseDistanceFeature`, and `ContactFeature` can work.

    In case the input isn't already a list of distances, this function will:
        - sort the indices1 array
        - check for duplicates within the indices1 array
        - sort the indices2 array
        - check for duplicates within the indices2 array
        - check for duplicates between the indices1 and indices2 array
        - if indices2 is None, produce a list of pairs of indices in indices1, or
        - if indices2 is not None, produce a list of pairs of (i,j) where i comes from indices1, and j from indices2

    Args:
        indices1 (Sequence[int]): Either a flat iterable of integers or
            anything else (e.g. an (n, 2) array of pairs). Anything that is
            not a flat iterable of integers is returned unchanged.
        indices2 (Optional[Sequence[int]]): None, or a flat iterable of
            integers. Only used if `indices1` is a flat iterable of integers.

    Returns:
        np.ndarray: The pairs as an array with two columns, or `indices1`
            itself if it was not a flat iterable of integers.

    Raises:
        TypeError: If `indices1` is a flat iterable of integers, but
            `indices2` is neither None nor a flat iterable of integers.

    """
    if not is_iterable_of_int(indices1):
        # Not a flat list of atoms; handed back untouched.
        return indices1

    # Eliminate duplicates and sort
    indices1 = np.unique(indices1)

    # Intra-group distances
    if indices2 is None:
        return combinations(indices1, 2)

    # Previously this case fell through all branches and failed with an
    # UnboundLocalError; report the actual problem instead.
    if not is_iterable_of_int(indices2):
        raise TypeError(
            f"If `indices1` is an iterable of integers, `indices2` needs to "
            f"be None or an iterable of integers, but is of type "
            f"{type(indices2)}."
        )

    # Inter-group distances
    # Eliminate duplicates and sort
    indices2 = np.unique(indices2)

    # Eliminate duplicates between indices1 and indices2
    # (np.isin supersedes the deprecated np.in1d)
    not_in_first = np.isin(indices2, indices1, invert=True)
    return product(indices1, indices2[not_in_first])
def pairs(
    sel: np.ndarray,
    excluded_neighbors: int = 0,
) -> np.ndarray:
    """Creates all pairs between indexes. Will exclude closest neighbors up to
    `excluded_neighbors` The self-pair (i,i) is always excluded.

    Args:
        sel (np.ndarray): Array with selected atom indexes.
        excluded_neighbors (int): Number of neighbors that will be excluded
            when creating the pairs. Python and NumPy integers are accepted.
            Defaults to 0.

    Returns:
        np.ndarray: A m x 2 array with all pair indexes between
            different atoms that are at least `excluded_neighbors` indexes
            apart, i.e. if i is the index of an atom, the pairs
            [i,i-2], [i,i-1], [i,i], [i,i+1], [i,i+2], will not be in the
            output if `excluded_neighbors` = 2. Moreover, the list is
            non-redundant, i.e. if [i,j] is in the output, then [j,i] is not.
            Every pair is ordered as [smaller, larger]. If no pair
            qualifies, an empty integer array of shape (0, 2) is returned.

    """
    # numbers.Integral also admits NumPy integer scalars such as np.int64.
    assert isinstance(excluded_neighbors, numbers.Integral)

    p = []
    for first, second in itertools.combinations(sel, 2):
        # get ordered pair
        low, high = (second, first) if first > second else (first, second)
        # skip pairs that are `excluded_neighbors` or fewer indexes apart
        if high > low + excluded_neighbors:
            p.append([low, high])
    if not p:
        # Keep the documented two-column layout for an empty result, too.
        return np.empty((0, 2), dtype=int)
    return np.array(p)
class Track(Callback):
    """Callback that renders the task graph to numbered PNG files.

    The hooks `_pretask` and `_finish` both call `_plot`, which hands the
    graph and per-key colors to `dot_graph`. Keys in `state["released"]` and
    `state["finished"]` are colored blue, keys in `state["cache"]` and
    `state["running"]` are colored red.

    Attributes:
        path (str): Directory the files `part_0000`, `part_0001`, ... are
            written to.
        save_every (int): In `_pretask`, a file is only written when the
            running counter is a multiple of this number.
        n (int): Running counter, incremented on every `_pretask` and
            `_finish` call. Also used to number the output files.

    """

    def __init__(
        self,
        path: str = "/tmp.json/dasks",
        save_every: int = 1,
    ) -> None:
        """Stores the settings and creates the output directory.

        Args:
            path (str): Output directory; created if it does not exist.
                NOTE(review): the default "/tmp.json/dasks" looks unusual -
                confirm that it is intended.
            save_every (int): Write a file on every `save_every`-th call of
                `_pretask`. Defaults to 1.

        """
        self.n = 0
        self.save_every = save_every
        self.path = path
        os.makedirs(self.path, exist_ok=True)

    def _plot(
        self,
        dsk,
        state,
    ) -> None:
        """Writes the current state of the graph `dsk` as a PNG file."""
        # Later updates win: a key that is both released and cached ends up
        # red, same for a key that is both finished and running.
        data = {key: {"color": "blue"} for key in state["released"]}
        data.update({key: {"color": "red"} for key in state["cache"]})
        func = {key: {"color": "blue"} for key in state["finished"]}
        func.update({key: {"color": "red"} for key in state["running"]})

        filename = os.path.join(self.path, f"part_{self.n:0>4d}")

        dot_graph(
            dsk,
            filename=filename,
            format="png",
            data_attributes=data,
            function_attributes=func,
        )

    def _pretask(
        self,
        key,
        dsk,
        state,
    ) -> None:
        """Plots every `save_every`-th call and advances the counter."""
        should_plot = self.n % self.save_every == 0
        if should_plot:
            self._plot(dsk, state)
        self.n += 1

    def _finish(
        self,
        dsk,
        state,
        errored,
    ) -> None:
        """Always plots the final state and advances the counter."""
        self._plot(dsk, state)
        self.n += 1
445################################################################################
446# Classes
447################################################################################
450class SingleTrajFeaturizer:
def __init__(self, traj: SingleTraj, delayed: bool = False) -> None:
    """Sets up an empty featurizer for one trajectory.

    Args:
        traj (SingleTraj): The trajectory the features are defined on.
        delayed (bool): Stored as `self.delayed`. `_add_feature` only
            accepts features whose own `delayed` attribute equals this
            value. Defaults to False.

    """
    self.traj, self.delayed = traj, delayed
    # Features registered so far, in the order they were added.
    self.active_features = []
    # Bookkeeping for the automatic ids handed out by `add_custom_feature`.
    self._custom_feature_ids = []
    self._n_custom_features = 0
def add_list_of_feats(
    self,
    which: Union[Literal["all", "full"], Sequence[str]] = "all",
    deg: bool = False,
    omega: bool = True,
    check_aas: bool = True,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds one or more of EncoderMap's predefined features to the Featurizer.

    The argument `which` can be 'all', 'full', or a list of the following
    strings (the snake_case aliases from `UNDERSCORE_MAPPING`, e.g.
    'central_dihedrals', are accepted as well):
        * 'AllCartesians': Cartesian coordinates of all atoms with
            shape (n_frames, n_atoms, 3).
        * 'AllBondDistances': Bond distances of all bonds recognized by
            mdtraj. Use top = md.Topology.from_openmm() if mdtraj does not
            recognize all bonds.
        * 'CentralCartesians': Cartesians of the N, C, CA atoms in the
            backbone with shape (n_frames, n_residues * 3, 3).
        * 'CentralBondDistances': The bond distances of the N, C, CA bonds
            with shape (n_frames, n_residues * 3 - 1).
        * 'CentralAngles': The angles between the backbone bonds with shape
            (n_frames, n_residues * 3 - 2).
        * 'CentralDihedrals': The dihedrals between the backbone atoms
            (omega, phi, psi). With shape (n_frames, n_residues * 3 - 3).
        * 'SideChainCartesians': Cartesians of the sidechain-atoms.
            Starting with CB, CG, ...
        * 'SideChainBondDistances': Bond distances between the
            sidechain atoms. starting with the CA-CG bond.
        * 'SideChainAngles': Angles between sidechain atoms. Starting with
            the C-CA-CB angle.
        * 'SideChainDihedrals': Dihedrals of the sidechains (chi1, chi2, chi3).
    If 'all' is provided for `which` the CentralCartesians,
    CentralBondDistances, CentralAngles, CentralDihedrals, and
    SideChainDihedrals will be added. If 'full' is provided, the
    SideChainCartesians, SideChainAngles, AllCartesians, and
    SideChainBondDistances are added on top of that.

    Args:
        which (Union[str, list], optional). Either 'all', 'full', a single
            feature name, or a list of feature names. See above for possible
            features. Defaults to 'all'.
        deg (bool): Whether the output should be formatted in degrees.
            Only forwarded to features that use angles.
        omega (bool): Whether to include the omega angles of the backbone.
            Only forwarded to features that use omega.
        check_aas (bool): Whether to check if all residues in top are known.
            Helps with custom topology to not skip unknown
            custom/non-natural amino acids. Forwarded to every feature.
        periodic (bool): Forwarded as `periodic` to all features that
            declare `_use_periodic`. Defaults to True.
        delayed (bool): Forwarded as `delayed` to every feature.
            `_add_feature` asserts that the `delayed` attribute of the
            feature equals the `delayed` attribute of this featurizer.
            Defaults to False.

    Raises:
        Exception: If `which` contains unrecognized names, or if a feature
            class declares an unsupported combination of `_use_periodic`,
            `_use_angle`, and `_use_omega`.
        AssertionError: If `which` is a list or tuple with non-str items.

    """
    recognized_str = list(UNDERSCORE_MAPPING.keys()) + list(
        UNDERSCORE_MAPPING.values()
    )
    default_feats = [
        "CentralCartesians",
        "CentralBondDistances",
        "CentralAngles",
        "CentralDihedrals",
        "SideChainDihedrals",
    ]
    if isinstance(which, str):
        if which == "full":
            which = default_feats + [
                "SideChainCartesians",
                "SideChainAngles",
                "AllCartesians",
                "SideChainBondDistances",
            ]
        elif which == "all":
            which = default_feats
        elif which in recognized_str:
            which = [which]
        else:
            raise Exception(
                f"Recognized arguments to which are 'all' or any of the "
                f"following: {recognized_str}. The str you provided {which} "
                f"did not match any."
            )
    elif isinstance(which, (list, tuple)):
        assert all(isinstance(i, str) for i in which), (
            f"The argument `which` needs to be provided a sequence of str. "
            f"You have offending types in this argument."
        )
        diff = set(which) - set(recognized_str)
        if diff:
            raise Exception(
                f"One or more of the str in `which` are not recognized. "
                f"The argument `which` needs to be a sequence containing any "
                f"of the following: {recognized_str}. The unrecognized str are: "
                f"{diff}."
            )

    # add the features
    for cf in which:
        # Translate snake_case aliases to the class names.
        cf = UNDERSCORE_MAPPING.get(cf, cf)
        feature = getattr(features, cf)
        flags = (
            bool(feature._use_periodic),
            bool(feature._use_angle),
            bool(feature._use_omega),
        )
        # `check_aas` is now forwarded to every feature. Before, it was
        # hard-coded to True for the features without angles, which silently
        # ignored the caller's choice.
        kwargs = {"check_aas": check_aas, "delayed": delayed}
        if flags == (False, False, False):
            pass
        elif flags == (True, False, False):
            kwargs.update(periodic=periodic)
        elif flags == (True, True, False):
            kwargs.update(deg=deg, periodic=periodic)
        elif flags == (True, True, True):
            kwargs.update(deg=deg, omega=omega, periodic=periodic)
        else:
            raise Exception(
                f"Unknown combination of `_use_angle` and `_use_omega` in "
                f"class attributes of {feature=}"
            )
        self._add_feature(feature(self.traj, **kwargs))
def add_custom_feature(self, feature: AnyFeature) -> None:
    """Names a `CustomFeature` and adds it to the active features.

    If `feature` has no `name` attribute and its `id` is missing or None,
    the next free running number is assigned as its `id`. A preset `id`
    must not collide with the automatically assigned ones. The feature is
    then (re)named to `CustomFeature_{id}` and passed to `_add_feature`.

    NOTE(review): ids that were preset by the user are checked against
    `self._custom_feature_ids`, but are not recorded there - confirm that
    this is intended.
    NOTE(review): if `feature` already has a `name`, the id handling is
    skipped, but `feature.id` is still read for the new name - confirm.

    Args:
        feature (AnyFeature): An instance of `CustomFeature` or of one of
            its subclasses.

    Raises:
        AssertionError: If a preset `id` is already taken, or if `feature`
            is not a `CustomFeature`.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    if not hasattr(feature, "name"):
        lacks_id = not hasattr(feature, "id") or feature.id is None
        if lacks_id:
            # Hand out the next running number.
            new_id = self._n_custom_features
            feature.id = new_id
            self._custom_feature_ids.append(new_id)
            self._n_custom_features += 1
        else:
            assert feature.id not in self._custom_feature_ids, (
                f"A CustomFeature with the id {feature.id} already exists. "
                f"Please change the id of your CustomFeature."
            )
    assert isinstance(feature, CustomFeature) or issubclass(
        feature.__class__, CustomFeature
    )
    feature.name = f"CustomFeature_{feature.id}"
    self._add_feature(feature)
def _add_feature(self, feature: AnyFeature) -> None:
    """Appends `feature` to `self.active_features`, unless it is unusable.

    A feature with `dimension == 0` is skipped with a warning. A feature
    that already is part of `self.active_features` is not added a second
    time; a warning is issued instead.

    Args:
        feature (AnyFeature): The feature to register.

    Raises:
        AssertionError: If `feature.delayed` differs from `self.delayed`.

    """
    assert feature.delayed == self.delayed, (
        f"In-memory featurizer {self.__class__} unexpectedly got a delayed "
        f"feature {feature}. {feature.delayed=} {self.delayed=}"
    )

    if feature.dimension == 0:
        warnings.warn(
            f"Given an empty feature (eg. due to an empty/ineffective "
            f"selection). Skipping it. Feature desc: {feature.describe()}"
        )
        return

    if feature in self.active_features:
        warnings.warn(
            f"Tried to re-add the same feature {feature.__class__.__name__} to "
            f"{self.active_features=}"
        )
        return

    self.active_features.append(feature)
def add_distances_ca(
    self,
    periodic: bool = True,
    excluded_neighbors: int = 2,
    delayed: bool = False,
) -> None:
    """Adds the distances between all Ca's to the feature list.

    The CA atoms are taken from `self.select_Ca`, translated into residue
    indices, paired up with `pairs`, and translated back into the indices of
    the "CA" atoms of these residues.

    Args:
        periodic (bool): Use the minimum image convention when computing distances.
        excluded_neighbors (int): Number of exclusions when compiling the
            list of pairs. Two CA-atoms are considered neighbors if they
            belong to adjacent residues. Defaults to 2.
        delayed (bool): Forwarded to `add_distances`. Defaults to False.

    """
    ca_atoms = self.select_Ca
    top = self.traj.top
    ca_residues = [top.atom(ca).residue.index for ca in ca_atoms]
    residue_pairs = pairs(ca_residues, excluded_neighbors=excluded_neighbors)
    distance_indexes = np.array(
        [
            [top.residue(ri).atom("CA").index, top.residue(rj).atom("CA").index]
            for ri, rj in residue_pairs
        ]
    )

    self.add_distances(distance_indexes, periodic=periodic, delayed=delayed)
def add_distances(
    self,
    indices: Union[np.ndarray, Sequence[int]],
    periodic: bool = True,
    indices2: Optional[Sequence[int]] = None,
    delayed: bool = False,
) -> None:
    """Adds the distances between atoms to the feature list.

    Args:
        indices (Union[np.ndarray, Iterable[Sequence[int]]]): Can be one of two types:
            A numpy array of shape (n, 2) with the pairs of atoms between
            which the distances shall be computed. Or a sequence of integers
            which are the indices (not pairs of indices) of the atoms between
            which the distances shall be computed. In the latter case, all
            pairs within `indices` are used if `indices2` is None.
        periodic (bool): If periodic is True and the trajectory contains
            unitcell information, distances will be computed under the
            minimum image convention. Defaults to True.
        indices2 (Optional[Sequence[int]]): Only has effect if `indices` is
            a sequence of integers. Instead of the above behavior, only the
            distances between the atoms in `indices` and `indices2` will be
            computed.
        delayed (bool): Forwarded to `DistanceFeature`. Defaults to False.

    Note:
        When using the iterable of integers input, `indices` and `indices2`
        will be sorted numerically and made unique before converting them to
        a pairlist. Please look carefully at the output of `self.describe()` to
        see what features exactly have been added.

    """
    # Local Folder Imports
    from .features import DistanceFeature

    # Normalize the user input into a pairlist, then let the featurizer
    # check it (`_check_indices` is defined outside of this view).
    pairlist = self._check_indices(_parse_pairwise_input(indices, indices2))
    self._add_feature(
        DistanceFeature(self.traj, pairlist, periodic=periodic, delayed=delayed)
    )
def add_backbone_torsions(
    self,
    selstr: Optional[str] = None,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds all backbone phi/psi angles or the ones specified in `selstr` to the feature list.

    Args:
        selstr (Optional[str]): If None, all phi/psi angles will be considered.
            Otherwise, can be a string specifying the atoms of specific
            backbone torsions (see example).
        deg (bool): Whether the output should be in degrees (True) or radians
            (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of their
            cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory container
            in `traj` needs to have unitcell information. Defaults to True.
        delayed (bool): Forwarded to `BackboneTorsionFeature`.
            Defaults to False.

    Examples:
        >>> import encodermap as em
        >>> import numpy as np
        >>> from pprint import pprint
        >>> trajs = em.load_project("linear_dimers")
        >>> feat = em.Featurizer(trajs[0])
        >>> feat.add_backbone_torsions("resname PRO")
        >>> pprint(feat.describe())
        ['PHI 0 PRO 19',
         'PSI 0 PRO 19',
         'PHI 0 PRO 37',
         'PSI 0 PRO 37',
         'PHI 0 PRO 38',
         'PSI 0 PRO 38',
         'PHI 0 PRO 95',
         'PSI 0 PRO 95',
         'PHI 0 PRO 113',
         'PSI 0 PRO 113',
         'PHI 0 PRO 114',
         'PSI 0 PRO 114']
        >>> ds = feat.get_output()
        >>> da = ds.BackboneTorsionFeature
        >>> phi_indices = da.coords["BACKBONETORSIONFEATURE"].str.contains("PHI")
        >>> angles = np.rad2deg(da.sel(BACKBONETORSIONFEATURE=phi_indices).values[0])
        >>> np.min(angles)
        -103.39891
        >>> np.max(angles)
        -10.015779

    """
    # Local Folder Imports
    from .features import BackboneTorsionFeature

    f = BackboneTorsionFeature(
        self.traj,
        selstr=selstr,
        deg=deg,
        cossin=cossin,
        periodic=periodic,
        delayed=delayed,
    )
    self._add_feature(f)
def add_angles(
    self,
    indexes: np.ndarray,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds the list of angles to the feature list.

    Args:
        indexes (np.ndarray): An array with triplets of atom indices.
        deg (bool): Whether the output should be in degrees (True) or radians
            (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of their
            cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory container
            in `traj` needs to have unitcell information. Defaults to True.
        delayed (bool): Forwarded to `AngleFeature`. Defaults to False.

    """
    # Local Folder Imports
    from .features import AngleFeature

    # `_check_indices` is defined outside of this view; it is asked to
    # treat the input as groups of three.
    triplets = self._check_indices(indexes, pair_n=3)
    self._add_feature(
        AngleFeature(
            self.traj,
            triplets,
            deg=deg,
            cossin=cossin,
            periodic=periodic,
            delayed=delayed,
        )
    )
def add_all(
    self,
    reference: Optional[md.Trajectory] = None,
    atom_indices: Optional[np.ndarray] = None,
    ref_atom_indices: Optional[np.ndarray] = None,
    delayed: bool = False,
) -> None:
    """Adds all atom coordinates to the feature list.
    The coordinates are flattened as follows: [x1, y1, z1, x2, y2, z2, ...]

    This is a shortcut for `add_selection` with the indices of all
    `self.traj.n_atoms` atoms.

    Args:
        reference (Optional[md.Trajectory]). If different from None, all
            data is aligned using MDTraj's superpose. Defaults to None.
        atom_indices (Optional[np.ndarray]): The indices of atoms to superpose
            If None all atoms will be used. Defaults to None.
        ref_atom_indices (Optional[np.ndarray]): Use these atoms on the
            reference structure. If not supplied, the same atom indices
            will be used for this trajectory and the reference one.
        delayed (bool): Forwarded to `add_selection`. Defaults to False.

    """
    every_atom = list(range(self.traj.n_atoms))
    passthrough = {
        "reference": reference,
        "atom_indices": atom_indices,
        "ref_atom_indices": ref_atom_indices,
        "delayed": delayed,
    }
    self.add_selection(every_atom, **passthrough)
def add_selection(
    self,
    indexes: np.ndarray,
    reference: Optional[md.Trajectory] = None,
    atom_indices: Optional[np.ndarray] = None,
    ref_atom_indices: Optional[np.ndarray] = None,
    delayed: bool = False,
) -> None:
    """Adds the coordinates of the selected atom indexes to the feature list.
    The coordinates of the selection [1, 2, ...] are flattened as follows: [x1, y1, z1, x2, y2, z2, ...]

    Without a `reference`, a `SelectionFeature` is added. With a
    `reference`, an `AlignFeature` is added instead.

    Args:
        indexes (np.ndarray): Array with selected atom indexes.
        reference (Optional[md.Trajectory]). If different from None, all
            data is aligned using MDTraj's superpose. Defaults to None.
        atom_indices (Optional[np.ndarray]): The indices of atoms to superpose
            If None, all atoms will be used. Defaults to None.
        ref_atom_indices (Optional[np.ndarray]): Use these atoms on the
            reference structure. If not supplied, the same atom indices
            will be used for this trajectory and the reference one.
        delayed (bool): Forwarded to the feature. Defaults to False.

    Raises:
        ValueError: If `reference` is neither None nor a `md.Trajectory`.

    """
    # Local Folder Imports
    from .features import AlignFeature, SelectionFeature

    if reference is None:
        f = SelectionFeature(self.traj, indexes, delayed=delayed)
    else:
        if not isinstance(reference, md.Trajectory):
            raise ValueError(
                f"reference is not a mdtraj.Trajectory object, but {reference}"
            )
        # NOTE(review): unlike all other features created in this class,
        # AlignFeature is not handed `self.traj` here - confirm against the
        # signature of AlignFeature.
        f = AlignFeature(
            reference=reference,
            indexes=indexes,
            atom_indices=atom_indices,
            ref_atom_indices=ref_atom_indices,
            delayed=delayed,
        )
    self._add_feature(f)
def add_inverse_distances(
    self,
    indices: Union[np.ndarray, Sequence[int]],
    periodic: bool = True,
    indices2: Optional[Union[np.ndarray, Sequence[int]]] = None,
    delayed: bool = False,
) -> None:
    """Adds the inverse distances between atoms to the feature list.

    Args:
        indices (Union[np.ndarray, Sequence[int]]): A array with shape (n, 2)
            giving the pairs of atoms between which the inverse distances
            shall be computed. Can also be a sequence of integers giving
            the first atoms in the distance calculations. In that case, all
            pairs within `indices` are used if `indices2` is None.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory container
            in `traj` needs to have unitcell information. Defaults to True.
        indices2 (Optional[Union[np.ndarray, Sequence[int]]]): Only used if
            `indices` is just a sequence of int (and not a (n, 2)
            np.ndarray). Gives the second atoms of the pairs.
        delayed (bool): Forwarded to `InverseDistanceFeature`.
            Defaults to False.

    Note:
        When using the *iterable of integers* input, `indices` and `indices2`
        will be sorted numerically and made unique before converting them to
        a pairlist. Please look carefully at the output of `describe()` to
        see what features exactly have been added.

    """
    # Local Folder Imports
    from .features import InverseDistanceFeature

    # Normalize the user input into a pairlist, then let the featurizer
    # check it (`_check_indices` is defined outside of this view).
    pairlist = self._check_indices(_parse_pairwise_input(indices, indices2))
    self._add_feature(
        InverseDistanceFeature(
            self.traj, pairlist, periodic=periodic, delayed=delayed
        )
    )
def add_contacts(
    self,
    indices: Union[np.ndarray, Sequence[int]],
    indices2: Optional[Union[np.ndarray, Sequence[int]]] = None,
    threshold: float = 0.3,
    periodic: bool = True,
    count_contacts: bool = False,
    delayed: bool = False,
) -> None:
    """Adds the contacts to the feature list.

    Args:
        indices (Union[np.ndarray, Sequence[int]]): A array with shape (n, 2)
            giving the pairs of atoms between which the contacts shall be
            computed. Can also be a sequence of integers giving the first
            atoms in the distance calculations. In that case, all pairs
            within `indices` are used if `indices2` is None.
        indices2 (Optional[Union[np.ndarray, Sequence[int]]]): Only used if
            `indices` is just a sequence of int (and not a (n, 2)
            np.ndarray). Gives the second atoms of the pairs.
        threshold (float): Distance below this (in nanometer) are considered
            as contacts. The output will contain 1.0 for these contacts.
            Above this threshold, the output will contain 0.0. Defaults to 0.3.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory container
            in `traj` needs to have unitcell information. Defaults to True.
        count_contacts (bool): If set to true, this feature will return
            the number of formed contacts (and not feature values with
            either 1.0 or 0). The output of this feature will be of shape
            (Nt,1), and not (Nt, nr_of_contacts). Defaults to False.
        delayed (bool): Forwarded to `ContactFeature`. Defaults to False.

    Note:
        When using the *iterable of integers* input, `indices` and `indices2`
        will be sorted numerically and made unique before converting them
        to a pairlist. Please look carefully at the output of `describe()`
        to see what features exactly have been added.

    """
    # Local Folder Imports
    from .features import ContactFeature

    # Normalize the user input into a pairlist, then let the featurizer
    # check it (`_check_indices` is defined outside of this view).
    pairlist = self._check_indices(_parse_pairwise_input(indices, indices2))
    contact_feature = ContactFeature(
        self.traj, pairlist, threshold, periodic, count_contacts, delayed=delayed
    )
    self._add_feature(contact_feature)
def add_residue_mindist(
    self,
    residue_pairs: Union[Literal["all"], np.ndarray] = "all",
    scheme: Literal["ca", "closest", "closest-heavy"] = "closest-heavy",
    ignore_nonprotein: bool = True,
    threshold: Optional[float] = None,
    periodic: bool = True,
    count_contacts: bool = False,
    delayed: bool = False,
) -> None:
    """Registers a residue-residue minimum distance feature.

    The feature is built as a `ResidueMinDistanceFeature` from `self.traj`
    and the arguments given here and is then handed to `self._add_feature()`.

    Args:
        residue_pairs (Union[Literal["all"], np.ndarray]): Either 'all', in
            which case mindists will be calculated between all pairs of
            residues excluding first and second neighbor, or a np.ndarray
            with shape (n, 2) holding the 0-based residue indices between
            which the mindists are computed. Defaults to 'all'.
        scheme (Literal["ca", "closest", "closest-heavy"]): Determines the
            sub-group of atoms within a residue that is considered when
            computing distances. Defaults to 'closest-heavy'.
        ignore_nonprotein (bool): Whether to ignore residues that are not
            of protein type (e.g. water molecules, post-translational
            modifications, non-standard residues, etc.). Defaults to True.
        threshold (Optional[float]): Distances below this threshold (in nm)
            result in a feature value of 1.0, distances above in 0.0. If
            None, the numerical distance is returned. Defaults to None.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory
            container in `traj` needs to have unitcell information.
            Defaults to True.
        count_contacts (bool): If True, the feature returns the number of
            formed contacts with shape (Nt, 1) instead of per-contact
            values of shape (Nt, nr_of_contacts). Defaults to False.
        delayed (bool): Forwarded unchanged to the feature's constructor.
            Defaults to False.

    Note:
        Using `scheme` = 'closest' or 'closest-heavy' with
        `residue_pairs` = 'all' will compute nearly all interatomic
        distances, for every frame, before extracting the closest pairs.
        This can be very time-consuming. Those schemes are intended to be
        used with a subset of residues chosen via `residue_pairs`. A
        warning is emitted for this combination.

    """
    # Local Folder Imports
    from .features import ResidueMinDistanceFeature

    # only the string 'all' triggers the warning; arrays are never compared
    # against a string
    wants_all_pairs = isinstance(residue_pairs, str) and residue_pairs == "all"
    if wants_all_pairs and scheme != "ca":
        warnings.warn(
            "Using all residue pairs with schemes like closest or "
            "closest-heavy is very time consuming. Consider reducing "
            "the residue pairs"
        )

    self._add_feature(
        ResidueMinDistanceFeature(
            self.traj,
            residue_pairs,
            scheme,
            ignore_nonprotein,
            threshold,
            periodic,
            count_contacts=count_contacts,
            delayed=delayed,
        )
    )
def add_group_COM(
    self,
    group_definitions: Sequence[int],
    ref_geom: Optional[md.Trajectory] = None,
    image_molecules: bool = False,
    mass_weighted: bool = True,
    delayed: bool = False,
) -> None:
    """Registers a center of mass (COM) feature for custom groups of atoms.

    The COMs are given in cartesian coordinates. If the group definitions
    coincide directly with residues, use `add_residue_COM` instead.
    No periodic boundaries are taken into account.

    Args:
        group_definitions (Sequence[int]): List of the groups of atom
            indices for which the COM will be computed. The atoms are
            zero-indexed.
        ref_geom (Optional[md.Trajectory]): If a md.Trajectory is provided,
            the coordinates of the provided traj will be centered using
            this reference, before computing COM. If None is provided,
            they won't be centered. Defaults to None.
        image_molecules (bool): The method traj.image_molecules will be
            called before computing averages. The method tries to correct
            for molecules broken across periodic boundary conditions, but
            can be time consuming. See
            http://mdtraj.org/latest/api/generated/mdtraj.Trajectory.html#mdtraj.Trajectory.image_molecules
            for more details. Defaults to False.
        mass_weighted (bool): Set to False if you want the geometric center
            and not the COM. Defaults to True.
        delayed (bool): Forwarded unchanged to the feature's constructor.
            Defaults to False.

    """
    # Local Folder Imports
    from .features import GroupCOMFeature

    com_options = {
        "ref_geom": ref_geom,
        "image_molecules": image_molecules,
        "mass_weighted": mass_weighted,
        "delayed": delayed,
    }
    self._add_feature(GroupCOMFeature(self.traj, group_definitions, **com_options))
def add_residue_COM(
    self,
    residue_indices: Sequence[int],
    scheme: Literal["all", "backbone", "sidechain"] = "all",
    ref_geom: Optional[md.Trajectory] = None,
    image_molecules: bool = False,
    mass_weighted: bool = True,
    delayed: bool = False,
) -> None:
    """Registers a per-residue center of mass (COM) feature.

    The COMs are given in cartesian coordinates. No periodic boundaries are
    taken into account.

    Args:
        residue_indices (Sequence[int]): List of the residue indices
            for which the COM will be computed. The residues are
            zero-indexed.
        scheme (Literal["all", "backbone", "sidechain"]): What atoms
            contribute to the COM computation. The string is also used as
            the selection passed to `self.traj.top.select()`. If the scheme
            yields no atoms for some residue, the selection is documented
            to fall back to 'all' for that residue. Defaults to 'all'.
        ref_geom (Optional[md.Trajectory]): If a md.Trajectory is provided,
            the coordinates of the provided traj will be centered using
            this reference, before computing COM. If None is provided,
            they won't be centered. Defaults to None.
        image_molecules (bool): The method traj.image_molecules will be
            called before computing averages. The method tries to correct
            for molecules broken across periodic boundary conditions, but
            can be time consuming. See
            http://mdtraj.org/latest/api/generated/mdtraj.Trajectory.html#mdtraj.Trajectory.image_molecules
            for more details. Defaults to False.
        mass_weighted (bool): Set to False if you want the geometric center
            and not the COM. Defaults to True.
        delayed (bool): Forwarded unchanged to the feature's constructor.
            Defaults to False.

    Raises:
        AssertionError: If `scheme` is not one of 'all', 'backbone',
            'sidechain'.

    """
    # Local Folder Imports
    from .features import ResidueCOMFeature

    allowed_schemes = ["all", "backbone", "sidechain"]
    assert scheme in allowed_schemes

    # atoms matching the scheme, grouped by the requested residues
    scheme_atom_idxs = self.traj.top.select(scheme)
    residue_atoms = _atoms_in_residues(
        self.traj,
        residue_indices,
        subset_of_atom_idxs=scheme_atom_idxs,
    )

    feature = ResidueCOMFeature(
        self.traj,
        np.asarray(residue_indices),
        residue_atoms,
        scheme,
        ref_geom=ref_geom,
        image_molecules=image_molecules,
        mass_weighted=mass_weighted,
        delayed=delayed,
    )
    self._add_feature(feature)
def add_dihedrals(
    self,
    indexes: np.ndarray,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Registers a feature for the dihedrals defined by `indexes`.

    The indices are validated with `self._check_indices(..., pair_n=4)`
    before the `DihedralFeature` is created.

    Args:
        indexes (np.ndarray): An array with quadruplets of atom indices.
        deg (bool): Whether the output should be in degrees (True) or
            radians (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of
            their cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory
            container in `traj` needs to have unitcell information.
            Defaults to True.
        delayed (bool): Forwarded unchanged to the feature's constructor.
            Defaults to False.

    """
    # Local Folder Imports
    from .features import DihedralFeature

    quadruplets = self._check_indices(indexes, pair_n=4)
    dihedral_options = dict(
        deg=deg,
        cossin=cossin,
        periodic=periodic,
        delayed=delayed,
    )
    self._add_feature(DihedralFeature(self.traj, quadruplets, **dihedral_options))
def add_sidechain_torsions(
    self,
    selstr: Optional[str] = None,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    which: Union[
        Literal["all"], Sequence[Literal["chi1", "chi2", "chi3", "chi4", "chi5"]]
    ] = "all",
    delayed: bool = False,
) -> None:
    """Registers the side chain torsion angles as a feature.

    Either all side chain torsions are added, or only the ones specified
    via `selstr` and `which`.

    Args:
        selstr (Optional[str]): Selection string specifying the atom
            selection used to pick a specific set of side chain angles.
            If None (default), all chi angles found in the topology will be
            computed. Defaults to None.
        deg (bool): Whether the output should be in degrees (True) or
            radians (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of
            their cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory
            container in `traj` needs to have unitcell information.
            Defaults to True.
        which (Union[Literal["all"], Sequence[Literal["chi1", "chi2", "chi3", "chi4", "chi5"]]]):
            Which angles to consider. Can be 'all' or any combination of
            ('all', 'chi1', 'chi2', 'chi3', 'chi4', 'chi5').
        delayed (bool): Forwarded unchanged to the feature's constructor.
            Defaults to False.

    """
    # Local Folder Imports
    from .features import SideChainTorsions

    torsion_options = {
        "selstr": selstr,
        "deg": deg,
        "cossin": cossin,
        "periodic": periodic,
        "which": which,
        "delayed": delayed,
    }
    self._add_feature(SideChainTorsions(self.traj, **torsion_options))
def add_minrmsd_to_ref(
    self,
    ref: Union[md.Trajectory, SingleTraj],
    ref_frame: int = 0,
    atom_indices: Optional[np.ndarray] = None,
    precentered: bool = False,
    delayed: bool = False,
) -> None:
    """Registers the minimum root-mean-square-deviation (minrmsd) with
    respect to a reference structure as a feature.

    Args:
        ref (Union[md.Trajectory, SingleTraj]): Reference structure for
            computing the minrmsd.
        ref_frame (int): Reference frame of the structure given in `ref`.
            Defaults to 0.
        atom_indices (Optional[np.ndarray]): Atoms that will be used for:
                1. aligning the target and reference geometries.
                2. computing rmsd after the alignment.
            If left to None, all atoms of `ref` will be used.
        precentered (bool): Use this boolean at your own risk to let
            mdtraj know that the target conformations are already centered
            at the origin, i.e., their (uniformly weighted) center of
            mass lies at the origin. This will speed up the computation of
            the rmsd. Defaults to False.
        delayed (bool): Forwarded unchanged to the feature's constructor.
            Defaults to False.

    """
    # Local Folder Imports
    from .features import MinRmsdFeature

    rmsd_options = {
        "ref_frame": ref_frame,
        "atom_indices": atom_indices,
        "precentered": precentered,
        "delayed": delayed,
    }
    self._add_feature(MinRmsdFeature(self.traj, ref, **rmsd_options))
@property
def ndim(self) -> int:
    """int: Shortcut for the value returned by `self.dimension()`."""
    n_dimensions = self.dimension()
    return n_dimensions
@property
def features(self) -> list[AnyFeature]:
    """list[AnyFeature]: The very same list object as `self.active_features`
    (no copy is made)."""
    active = self.active_features
    return active
@property
def select_Ca(self) -> np.ndarray:
    """np.ndarray: Result of selecting "name CA" on the topology of
    `self.traj`."""
    topology = self.traj.top
    return topology.select("name CA")
def _check_indices(self, pair_inds: np.ndarray, pair_n: int = 2) -> np.ndarray:
    """Ensures that an index array is valid for the topology of `self.traj`.

    Args:
        pair_inds (np.ndarray): The atom indices to check. Anything that
            can be safely cast to a 2D integer array is accepted.
        pair_n (int): The number of columns `pair_inds` needs to have
            (2 for pairs, 4 for dihedral quadruplets, ...). Defaults to 2.

    Returns:
        np.ndarray: The indices as an integer array of shape (x, `pair_n`).

    Raises:
        ValueError: If the array is not 2D, if the second axis does not
            have length `pair_n`, or if an index exceeds the largest atom
            index of the topology.
        TypeError: Raised by numpy if the input can't be safely cast to int.

    """
    pair_inds = np.array(pair_inds).astype(dtype=int, casting="safe")

    if pair_inds.ndim != 2:
        raise ValueError("pair indices has to be a matrix.")

    if pair_inds.shape[1] != pair_n:
        raise ValueError(f"pair indices shape has to be (x, {pair_n}).")

    # atom indices are 0-based, so the largest valid index is n_atoms - 1.
    # Comparing with `>` would let the invalid index `n_atoms` slip through.
    n_atoms = self.traj.top.n_atoms
    largest_index = pair_inds.max()
    if largest_index >= n_atoms:
        raise ValueError(
            f"index out of bounds: {largest_index}. Maximum atom index "
            f"available: {n_atoms - 1}"
        )

    return pair_inds
def transform(self, p: Optional[tqdm] = None) -> Optional[np.ndarray]:
    """Calls the `transform()` methods of the accumulated features.

    Args:
        p (Optional[tqdm]): If an instance of tqdm is provided, it will be
            updated after every call of `feature.transform()`.
            If None is provided, no progress bar will be advanced.

    Returns:
        Optional[np.ndarray]: A numpy array with the features in
            `self.active_features`, stacked along the feature dimension.
            If no features have been added, a warning is issued and None
            is returned.

    Raises:
        ValueError: If a `CustomFeature` does not return a np.ndarray, does
            not return a 2D array, or returns a different number of
            frames than `self.traj` contains.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    # without features there is nothing to compute
    if not self.active_features:
        warnings.warn(
            "You have not selected any features. Add features and call "
            "`transform` or `get_output` again.."
        )
        return

    # otherwise, build feature vector.
    feature_vec = []

    for f in self.active_features:
        # perform sanity checks for custom feature input
        if isinstance(f, CustomFeature):
            vec = f.transform(self.traj)
            # the type needs to be checked *before* ndarray methods are
            # used, otherwise a wrong return type surfaces as an unrelated
            # AttributeError and this message can never be reached
            if not isinstance(vec, np.ndarray):
                raise ValueError(
                    f"Your custom feature {f.describe()} did not return a numpy.ndarray!"
                )
            # NOTE: casting=safe raises in numpy>=1.9
            vec = vec.astype(np.float32, casting="safe")
            if vec.shape[0] == 0:
                vec = np.empty((0, f.dimension))

            if not vec.ndim == 2:
                raise ValueError(
                    f"Your custom feature {f.describe()} did not return a "
                    f"2d array. Shape was {vec.shape}"
                )
            if not vec.shape[0] == self.traj.xyz.shape[0]:
                raise ValueError(
                    f"Your custom feature {f.describe()} did not return as "
                    f"many frames, as it received. Input was {self.traj.xyz.shape[0]}, "
                    f"output was {vec.shape[0]}"
                )
        else:
            vec = f.transform().astype(np.float32)
        feature_vec.append(vec)

        if p is not None:
            p.update()

    # a single feature is returned as is, multiple ones are stacked
    if len(feature_vec) > 1:
        res = np.hstack(feature_vec)
    else:
        res = feature_vec[0]

    return res
def get_output(self, pbar: Optional[tqdm] = None) -> xr.Dataset:
    """Computes all active features and packs them into a dataset.

    Args:
        pbar (Optional[tqdm]): A progress bar that is used as a context
            manager and advanced by `self.transform()`. If None is
            provided, a new tqdm bar with one step per active feature is
            created.

    Returns:
        xr.Dataset: The return value of `unpack_data_and_feature()` for
            this featurizer, its trajectory and the transformed data.

    """
    # Encodermap imports
    from encodermap.misc.xarray import unpack_data_and_feature

    if pbar is None:
        n_active = len(self.active_features)
        desc = f"Getting output of {n_active} features"
        # mention the trajectory only if it carries a basename
        if self.traj.basename is not None:
            desc = f"{desc} for {self.traj.basename}"
        pbar = tqdm(total=n_active, desc=desc)

    with pbar as p:
        out = self.transform(p=p)
    return unpack_data_and_feature(self, self.traj, out)
def describe(self) -> list[str]:
    """Collects the labels of all active features.

    Returns:
        list[str]: The concatenated `describe()` outputs of the features in
            `self.active_features`, in the order the features were added.

    """
    return [label for feat in self.active_features for label in feat.describe()]
def __repr__(self) -> str:
    """Returns a short summary listing at most the first ten feature labels,
    always followed by an ellipsis."""
    # Standard Library Imports
    import pprint

    first_labels = self.describe()[:10]
    formatted = pprint.pformat(first_labels)
    # swap the closing bracket for an ellipsis
    feat_str = f"{formatted[:-1]}, ...]"
    return "EncoderMap Featurizer with features:\n" + feat_str
def __len__(self) -> int:
    """Returns the number of entries in `self.active_features`."""
    n_active = len(self.active_features)
    return n_active
def dimension(self) -> int:
    """Returns the sum of the `dimension` attributes of all active features."""
    total = 0
    for feat in self.active_features:
        total = total + feat.dimension
    return total
class Featurizer:
    """Entry point that hands out the matching featurizer for the input.

    EncoderMap's featurization has drawn much inspiration from PyEMMA
    (https://github.com/markovmodel/PyEMMA).

    EncoderMap's Featurizer collects and computes collective variables (CVs).
    CVs are data that are aligned with MD trajectories on the frame/time axis.
    Trajectory data contains (besides the topology) an axis for atoms, and
    an axis for cartesian coordinate (x, y, z), so that a trajectory can be
    understood as an array with shape (n_frames, n_atoms, 3). A CV is an array
    that is aligned with the frame/time and has its own feature axis. If the
    trajectory in our example has 3 residues (MET, ALA, GLY), we can define
    6 dihedral angles along the backbone of this peptide. These angles are:

    * PSI1:   Between MET1-N  - MET1-CA - MET1-C  - ALA2-N
    * OMEGA1: Between MET1-CA - MET1-C  - ALA2-N  - ALA2-CA
    * PHI1:   Between MET1-C  - ALA2-N  - ALA2-CA - ALA2-C
    * PSI2:   Between ALA2-N  - ALA2-CA - ALA2-C  - GLY3-N
    * OMEGA2: Between ALA2-CA - ALA2-C  - GLY3-N  - GLY3-CA
    * PHI2:   Between ALA2-C  - GLY3-N  - GLY3-CA - GLY3-C

    Thus, the collective variable 'backbone-dihedrals' provides an array of
    shape (n_frames, 6) and is aligned with the frame/time axis of the trajectory.

    Instantiating this class never yields a `Featurizer` instance: a
    `SingleTraj` gives a `SingleTrajFeaturizer`, everything else is passed
    on to `EnsembleFeaturizer`.

    """

    def __new__(cls, traj: Union[SingleTraj, TrajEnsemble]):
        # Encodermap imports
        from encodermap.trajinfo.info_single import SingleTraj

        # pick the concrete featurizer based on the type of the input
        featurizer_class = (
            SingleTrajFeaturizer if isinstance(traj, SingleTraj) else EnsembleFeaturizer
        )
        return featurizer_class(traj)
class AddSingleFeatureMethodsToClass(type):
    """Metaclass that programmatically adds methods to the EnsembleFeaturizer.

    For every name in `_ADD_X_FUNCTION_NAMES`, a method of the same name is
    attached to the created class. The method forwards its arguments to the
    equally named method of one `SingleTrajFeaturizer` per topology and
    carries the docstring of the `SingleTrajFeaturizer` method.

    """

    def __new__(cls, name, bases, dct):  # pragma: no doccheck
        new_class = super().__new__(cls, name, bases, dct)

        for add_X_function_name in _ADD_X_FUNCTION_NAMES:
            # IMPORTANT: the current name is bound as a keyword default.
            # A plain closure would late-bind and every generated method
            # would use the last name of the loop.
            def add_X_func(
                self, *args, add_x_name=add_X_function_name, **kwargs
            ) -> None:
                for top, trajs in self.trajs.trajs_by_top.items():
                    # fetch the featurizer of this topology, or create it
                    # from the first trajectory with that topology
                    if top in self.feature_containers:
                        container = self.feature_containers[top]
                    else:
                        container = SingleTrajFeaturizer(
                            trajs[0], delayed=self.delayed
                        )
                        self.feature_containers[top] = container

                    # let the single featurizer build and register the feature
                    getattr(container, add_x_name)(*args, **kwargs)

                    # the feature that was just added is the last one
                    newest = container.active_features[-1]

                    # register it for the ensemble, unless already present
                    registered = self.active_features.setdefault(top, [])
                    if newest not in registered:
                        registered.append(newest)

            # also add the docstring :)
            single_method = getattr(SingleTrajFeaturizer, add_X_function_name)
            add_X_func.__doc__ = single_method.__doc__
            setattr(new_class, add_X_function_name, add_X_func)
        return new_class
class DaskFeaturizerMeta(type):
    """Metaclass that adds delayed `add_*` methods to the created class.

    For every name in `_ADD_X_FUNCTION_NAMES` and for 'add_list_of_feats',
    a method is attached that forwards the call to the equally named
    method of `self.feat` with `delayed=True`. The docstrings are taken
    from the corresponding `SingleTrajFeaturizer` methods.

    """

    def __new__(cls, name, bases, dct):
        x = super().__new__(cls, name, bases, dct)

        for add_X_function_name in _ADD_X_FUNCTION_NAMES + ["add_list_of_feats"]:
            # the name is bound as a keyword default to prevent python
            # from late-binding the loop variable
            def add_X_func(self, *args, add_x_name=add_X_function_name, **kwargs):
                # call the parents featurizer class add function
                # (a space was missing between the two message parts)
                assert self.feat.delayed, (
                    f"Programmatically added `add_X_func` got a featurizer with a "
                    f"wrong `delayed` variable: {id(self.feat)=} {self.feat.delayed=}"
                )
                getattr(self.feat, add_x_name)(*args, delayed=True, **kwargs)

            add_X_func.__doc__ = getattr(
                SingleTrajFeaturizer,
                add_X_function_name,
            ).__doc__
            setattr(x, add_X_function_name, add_X_func)
        return x
1518class EnsembleFeaturizer(metaclass=AddSingleFeatureMethodsToClass):
1519 """The EnsembleFeaturizer is a container of multiple SinlgeTrajFeaturizer.
1521 The `SingleTrajFeaturizer` are collected in a dict with the topologies
1522 of the sub-ensembles as keys.
1524 """
def __init__(self, trajs: TrajEnsemble, delayed: bool = False) -> None:
    """Instantiates the `EnsembleFeaturizer`.

    Args:
        trajs (TrajEnsemble): The `TrajEnsemble` to featurize.
        delayed (bool): Whether using dask to calculate features, or just
            do a regular featurization. Defaults to False.

    """
    self.trajs, self.delayed = trajs, delayed
    # both dicts are keyed by topology
    self.active_features = dict()
    self.feature_containers = dict()
    self.ensemble = False
    # bookkeeping for the ids handed out to custom features
    self._n_custom_features = 0
    self._custom_feature_ids = list()
def describe(self) -> dict[md.Topology, list[str]]:
    """Returns the labels of the feature output, grouped by topology.

    Returns:
        dict[md.Topology, list[str]]: A dict where the keys are the keys
            of `self.feature_containers` (the topologies in the
            `TrajEnsemble`) and the values are the `describe()` outputs
            of the respective containers.

    """
    return {
        top: container.describe()
        for top, container in self.feature_containers.items()
    }
def __len__(self) -> int:
    """Returns the number of features per topology.

    Returns:
        int: 0 if there are no feature containers, else the length shared
            by all containers.

    Raises:
        AssertionError: If the containers don't all have the same length.

    """
    lengths = [len(container) for container in self.feature_containers.values()]
    if not lengths:
        return 0
    first = lengths[0]
    assert all(
        length == first for length in lengths
    ), f"This `{self.__class__.__name__}` has uneven features per topology."
    return first
def _add_feature(
    self, f: AnyFeature, top: md.Topology, trajs: TrajEnsemble
) -> None:
    """Registers a feature instance for one topology.

    Args:
        f (AnyFeature): The feature to add. Its `delayed` attribute has to
            match `self.delayed`.
        top (md.Topology): The topology under which the feature is stored.
        trajs (TrajEnsemble): The trajectories with that topology. The
            first one is used to create a `SingleTrajFeaturizer` if no
            container exists for `top` yet.

    Raises:
        AssertionError: If `f.delayed` differs from `self.delayed`.

    """
    assert f.delayed == self.delayed, (
        f"In-memory featurizer {self.__class__} unexpectedly got a delayed "
        f"feature {f}. {f.delayed=} {self.delayed=}"
    )
    # lazily create the per-topology container
    if top not in self.feature_containers:
        container = SingleTrajFeaturizer(trajs[0], delayed=self.delayed)
        self.feature_containers[top] = container
    else:
        container = self.feature_containers[top]
    container._add_feature(f)
    self.active_features.setdefault(top, []).append(f)
def add_custom_feature(self, feature: AnyFeature) -> None:
    """Names a user-provided feature and registers it for every topology.

    If the class of `feature` is literally called 'CustomFeature' and the
    instance has no `name` attribute, an id is assigned from the running
    counter `self._n_custom_features` (if the instance has no `id` or its
    `id` is None) and the name is set to 'CustomFeature_<id>'. All other
    features are named after their class, if the attribute can be set.

    Args:
        feature (AnyFeature): The feature instance to add.

    Raises:
        AssertionError: If the feature brings its own id and that id is
            already in `self._custom_feature_ids`, or if the feature is
            neither an instance of `CustomFeature` nor carries an
            `_is_custom` attribute.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    # decide on feature's id
    if feature.__class__.__name__ == "CustomFeature":
        if not hasattr(feature, "name"):
            if not hasattr(feature, "id"):
                # no id at all: hand out the next free one
                feature.id = self._n_custom_features
                self._custom_feature_ids.append(self._n_custom_features)
                self._n_custom_features += 1
            elif feature.id is None:
                # id present but unset: hand out the next free one
                feature.id = self._n_custom_features
                self._custom_feature_ids.append(self._n_custom_features)
                self._n_custom_features += 1
            else:
                # NOTE(review): a user-provided id is only compared against
                # the known ids, it is not appended to them here
                assert feature.id not in self._custom_feature_ids, (
                    f"A CustomFeature with the id {feature.id} already exists. "
                    f"Please change the id of your CustomFeature."
                )
            assert (
                isinstance(feature, CustomFeature)
                or issubclass(feature.__class__, CustomFeature)
                or hasattr(feature, "_is_custom")
            )
            feature.name = f"CustomFeature_{feature.id}"
    else:
        # the AttributeError is ignored on purpose, so features that do not
        # allow setting `name` keep whatever they have
        try:
            feature.name = feature.__class__.__name__
        except AttributeError:
            pass

    # add
    for top, trajs in self.trajs.trajs_by_top.items():
        self._add_feature(feature, top, trajs)
def add_list_of_feats(
    self,
    which: Union[Literal["all"], Sequence[str]] = "all",
    ensemble: bool = False,
    deg: bool = False,
    omega: bool = True,
    check_aas: bool = True,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds a list of features to the Featurizer.

    The argument `which` can be either 'all', 'full' or a list of the
    following strings:
        * 'AllCartesians': Cartesian coordinates of all atoms with
            shape (n_frames, n_atoms, 3).
        * 'AllBondDistances': Bond distances of all bonds recognized by
            mdtraj. Use top = md.Topology.from_openmm() if mdtraj does not
            recognize all bonds.
        * 'CentralCartesians': Cartesians of the N, C, CA atoms in the
            backbone with shape (n_frames, n_residues * 3, 3).
        * 'CentralBondDistances': The bond distances of the N, C, CA bonds
            with shape (n_frames, n_residues * 3 - 1).
        * 'CentralAngles': The angles between the backbone bonds with shape
            (n_frames, n_residues * 3 - 2).
        * 'CentralDihedrals': The dihedrals between the backbone atoms
            (omega, phi, psi). With shape (n_frames, n_residues * 3 - 3).
        * 'SideChainCartesians': Cartesians of the sidechain-atoms.
            Starting with CB, CG, ...
        * 'SideChainBondDistances': Bond distances between the
            sidechain atoms. starting with the CA-CG bond.
        * 'SideChainAngles': Angles between sidechain atoms. Starting with
            the C-CA-CB angle.
        * 'SideChainDihedrals': Dihedrals of the sidechains (chi1, chi2, chi3).
    If 'all' is provided for `which` the CentralCartesians,
    CentralBondDistances, CentralAngles, CentralDihedrals and
    SideChainDihedrals will be added. If 'full' is provided, the
    SideChainCartesians, SideChainAngles, AllCartesians and
    SideChainBondDistances are added on top of that.

    Args:
        which (Union[Literal["all"], Sequence[str]]): Either 'all', 'full',
            the name of a single feature, or a list/tuple of feature
            names. See above for possible features. Defaults to 'all'.
        ensemble (bool): Whether the trajs in this class belong to an ensemble.
            This implies that they contain either the same topology or are
            very similar (think wt, and mutant). Setting this option True will
            try to match the CVs of the trajs onto the same dataset.
            If a VAL residue has been replaced by LYS in the mutant,
            the number of sidechain dihedrals will increase. The CVs of the
            trajs with VAL will thus contain some NaN values. Defaults to False.
        deg (bool): Whether the output should be formatted in degrees.
            Only forwarded to features that use angles. Defaults to False.
        omega (bool): Whether to include the omega angles of the backbone.
            Only forwarded to features that use omega. Defaults to True.
        check_aas (bool): Whether to check if all residues in top are known.
            Helps with custom topology to not skip unknown
            custom/non-natural amino acids. Defaults to True.
        periodic (bool): Forwarded to all features that declare the use of
            periodic boundary information. Defaults to True.
        delayed (bool): Forwarded to the constructor of every feature.
            Defaults to False.

    Raises:
        Exception: If a name in `which` is not recognized, or if a feature
            class has an unknown combination of `_use_periodic`,
            `_use_angle` and `_use_omega`.
        AssertionError: If `which` is a list/tuple containing non-str
            entries, or if the topologies end up with different numbers
            of features.

    """
    self.ensemble = ensemble
    recognized_str = list(UNDERSCORE_MAPPING.keys()) + list(
        UNDERSCORE_MAPPING.values()
    )

    # Normalize `which` to a sequence of feature names once, before looping
    # over the topologies.
    if isinstance(which, str):
        if which == "full":
            which = [
                "CentralCartesians",
                "CentralBondDistances",
                "CentralAngles",
                "CentralDihedrals",
                "SideChainDihedrals",
                "SideChainCartesians",
                "SideChainAngles",
                "AllCartesians",
                "SideChainBondDistances",
            ]
        # this needs to be an `elif`: with a plain `if`, the list built for
        # 'full' falls through to the `else` branch below and is rejected
        elif which == "all":
            which = [
                "CentralCartesians",
                "CentralBondDistances",
                "CentralAngles",
                "CentralDihedrals",
                "SideChainDihedrals",
            ]
        else:
            if which not in recognized_str:
                raise Exception(
                    f"Recognized arguments to which are 'all' or any of the "
                    f"following: {recognized_str}. The str you provided {which} "
                    f"did not match any."
                )
            which = [which]
    elif isinstance(which, (list, tuple)):
        assert all([isinstance(i, str) for i in which]), (
            f"The argument `which` needs to be provided a sequence of str. "
            f"You have offending types in this argument."
        )
        diff = set(which) - set(recognized_str)
        if diff:
            raise Exception(
                f"One or more of the str in `which` are not recognized. "
                f"The argument `which` needs to be a sequence containing any "
                f"of the following: {recognized_str}. The unrecognized str are: "
                f"{diff}."
            )

    for top, trajs in self.trajs.trajs_by_top.items():
        # add the features
        # the _use_omega and _use_angle class attrs are added by the
        # `FeatureMeta` in `features.py` by inspecting a `Feature` subclass'
        # call signature
        for cf in which:
            if cf in UNDERSCORE_MAPPING:
                cf = UNDERSCORE_MAPPING[cf]
            feature_class = getattr(features, cf)

            # keyword arguments every feature receives. `check_aas` is
            # forwarded for all feature families, not only angular ones.
            feature_kwargs = {
                "check_aas": check_aas,
                "generic_labels": ensemble,
                "delayed": delayed,
            }
            # the class attributes decide which extra arguments are accepted
            flags = (
                bool(feature_class._use_periodic),
                bool(feature_class._use_angle),
                bool(feature_class._use_omega),
            )
            if flags == (False, False, False):
                pass
            elif flags == (True, False, False):
                feature_kwargs["periodic"] = periodic
            elif flags == (True, True, False):
                feature_kwargs["deg"] = deg
                feature_kwargs["periodic"] = periodic
            elif flags == (True, True, True):
                feature_kwargs["deg"] = deg
                feature_kwargs["omega"] = omega
                feature_kwargs["periodic"] = periodic
            else:
                raise Exception(
                    f"Unknown combination of `_use_angle` and `_use_omega` in "
                    f"class attributes of feature={feature_class!r}"
                )
            feature = feature_class(trajs[0], **feature_kwargs)

            # don't register the same feature twice for a topology
            registered = self.active_features.setdefault(top, [])
            if feature in registered:
                warnings.warn(
                    f"Tried to re-add the same feature {feature.__class__.__name__} to "
                    f"{self.active_features=}"
                )
                continue
            registered.append(feature)

            # the per-topology container also needs to know the feature
            if top in self.feature_containers:
                f = self.feature_containers[top]
            else:
                f = SingleTrajFeaturizer(trajs[0], delayed=self.delayed)
                self.feature_containers[top] = f
            f._add_feature(feature)

    # after all is done, all tops should contain the same number of feats
    no_of_feats = {len(v) for v in self.active_features.values()}
    assert len(no_of_feats) == 1, (
        f"I was not able to add the same number of features to the respective "
        f"topologies:\n{self.active_features=}\n{self.feature_containers=}"
    )
@property
def features(self) -> list[AnyFeature]:
    """list[AnyFeature]: A new flat list with the features of all topologies,
    in the order of the topologies in `self.active_features`."""
    feats = []
    # iterate over the values only. Iterating over `.items()` would put
    # the topologies and the nested lists themselves into the output.
    for top_features in self.active_features.values():
        feats.extend(top_features)
    return feats
def transform(
    self,
    traj: Union[SingleTraj, md.Trajectory],
    outer_p: Optional[Union[tqdm, rich.progress.Progress]] = None,
    inner_p: Optional[Union[tqdm, rich.progress.Progress]] = None,
    inner_p_id: Optional[int] = None,
) -> np.ndarray:
    """Applies the features to the trajectory data.

    Args:
        traj (Union[SingleTraj, md.Trajectory]): The trajectory which provides
            the data. Make sure, that the topology of this traj matches the
            topology used to initialize the features.
        outer_p (Optional[Union[tqdm, rich.progress.Progress]]): An object
            that supports `.update()` to advance a progress bar. The
            `rich.progress.Progress` is special, as it needs additional code
            to advance the multi-file progress bar displayed by the
            `EnsembleFeaturizer`. The `outer_p` represents the overall
            progress. For a `rich.progress.Progress`, the task with id 0
            is advanced.
        inner_p (Optional[Union[tqdm, rich.progress.Progress]]): Same as `outer_p`,
            but the `inner_p` represents the progress per file.
        inner_p_id (Optional[int]): The id of the task to advance, if
            `inner_p` is an instance of `rich.progress.Progress`. If None
            is provided, `traj.traj_num + 1` is used.

    Returns:
        np.ndarray: The outputs of all features registered for `traj.top`,
            stacked along the feature axis.

    Raises:
        ValueError: If a `CustomFeature` does not return a np.ndarray or
            returns a different number of frames than it received.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    feature_vec = []
    active = self.active_features[traj.top]

    # the input data only depends on the trajectory, so it is collected
    # once and not once per feature
    xyz = traj.xyz
    if traj._have_unitcell:
        unitcell_vectors = traj.unitcell_vectors
        unitcell_info = np.hstack([traj.unitcell_lengths, traj.unitcell_angles])
    else:
        unitcell_vectors = None
        unitcell_info = None

    for f in active:
        if isinstance(f, CustomFeature):
            vec = f.transform(
                traj,
                xyz,
                unitcell_vectors,
                unitcell_info,
            )
            # check the type before ndarray methods are used, otherwise
            # this error could never be raised
            if not isinstance(vec, np.ndarray):
                raise ValueError(
                    "Your custom feature %s did not return"
                    " a numpy.ndarray!" % str(f.describe())
                )
            vec = vec.astype(np.float32, casting="safe")
            if vec.shape[0] == 0:
                vec = np.empty((0, f.dimension))

            # bring the output into the shape (n_frames, n_features)
            if vec.ndim == 1:
                vec = np.expand_dims(vec, -1)
            if vec.ndim == 3:
                vec = vec.reshape(xyz.shape[0], -1)
                f.atom_feature = True
            if not vec.shape[0] == xyz.shape[0]:
                raise ValueError(
                    "Your custom feature %s did not return"
                    " as many frames as it received! "
                    "Input was %i, output was %i"
                    % (str(f.describe()), xyz.shape[0], vec.shape[0])
                )
        else:
            vec = f.transform(
                xyz,
                unitcell_vectors,
                unitcell_info,
            ).astype(np.float32)
        feature_vec.append(vec)

        if outer_p is not None:
            if isinstance(outer_p, rich.progress.Progress):
                outer_p.update(0, advance=1)
            else:
                outer_p.update()

        # advance the per-file bar via `inner_p` itself. Going through
        # `outer_p` here fails if only `inner_p` is given and never
        # advances an inner bar that is a separate object.
        if inner_p is not None:
            if isinstance(inner_p, rich.progress.Progress):
                if inner_p_id is None:
                    inner_p_id = traj.traj_num + 1
                inner_p.update(inner_p_id, advance=1)
            else:
                inner_p.update()

    if len(feature_vec) > 1:
        res = np.hstack(feature_vec)
    else:
        res = feature_vec[0]

    # sleep half a second to let the progbars catch up
    if outer_p is not None or inner_p is not None:
        time.sleep(0.5)

    return res
def n_features(self) -> int:
    """Returns the number of active features shared by all topologies.

    `self.active_features` maps every topology to its list of features.
    All of these lists need to have the same length. Otherwise, the
    features can't be transformed.

    Returns:
        int: The number of features per topology. 0, if
            `self.active_features` is empty.

    Raises:
        AssertionError: If the topologies have different numbers of features.

    """
    # Collect the distinct list lengths. This compares every topology and
    # not just the first and the last one.
    counts = {len(features) for features in self.active_features.values()}
    if not counts:
        return 0
    if len(counts) > 1:
        # the report can get large, so it is only built when it is needed
        report = "\n\n".join(
            f"Top: {top}\nValue: {features}"
            for top, features in self.active_features.items()
        )
        # raised explicitly, so the check is not stripped by `python -O`
        raise AssertionError(
            f"There are different number of features per topology in "
            f"`self.active_features`. These features can't be transformed. "
            f"Here are the features by topology:\n{report}"
        )
    return counts.pop()
def get_output(
    self,
    pbar: Optional[tqdm] = None,
) -> xr.Dataset:
    """Transforms every trajectory and combines the results into one dataset.

    Each trajectory in `self.trajs` is passed to `self.transform()`. The
    returned array is converted with `unpack_data_and_feature()` and all
    resulting datasets are merged with `format_output()`.

    Args:
        pbar (Optional[tqdm]): A progress bar, that is forwarded to
            `self.transform()` as its second positional argument. If None
            is provided and `rich` can be imported, a
            `rich.progress.Progress` with one overall task and one task
            per trajectory is used instead. Defaults to None.

    Returns:
        xr.Dataset: The combined dataset. If no features have been added,
            a hint is printed and None is returned.

    """
    # Encodermap imports
    from encodermap.misc.xarray import unpack_data_and_feature

    if self.active_features == {}:
        print("First add some features before calling `get_output()`.")
        return
    per_traj_datasets = []
    n_features = self.n_features()

    # rich is an optional dependency
    try:
        # Third Party Imports
        from rich.progress import Progress
    except ModuleNotFoundError:
        have_rich = False
    else:
        have_rich = True

    # a user-provided progress bar (or a missing rich) skips the rich display
    if pbar is not None or not have_rich:
        for traj in self.trajs:
            values = self.transform(traj, pbar, None)
            per_traj_datasets.append(
                unpack_data_and_feature(
                    self.feature_containers[traj.top], traj, values
                )
            )
        return format_output(per_traj_datasets)

    with Progress() as progress:
        # the first task tracks the complete ensemble
        progress.add_task(
            description=(
                f"Getting output for an ensemble containing "
                f"{self.trajs.n_trajs} trajs"
            ),
            total=n_features * self.trajs.n_trajs,
        )
        # followed by one task per trajectory, registered before any work starts
        for traj in self.trajs:
            name = traj.basename
            label = f"trajectory {traj.traj_num}" if name == "trajs" else name
            progress.add_task(
                description=(
                    f"Getting output of {n_features} features for " f"{label}"
                ),
                total=n_features,
            )
        # the per-trajectory tasks were added after the overall task,
        # so the task id passed along starts at 1
        for task_id, traj in enumerate(self.trajs, start=1):
            values = self.transform(traj, progress, progress, inner_p_id=task_id)
            per_traj_datasets.append(
                unpack_data_and_feature(
                    self.feature_containers[traj.top], traj, values
                )
            )
    return format_output(per_traj_datasets)
def format_output(
    datasets: Sequence[xr.Dataset],
) -> xr.Dataset:
    """Concatenates multiple xr.Datasets and keeps coordinates in correct order.

    The datasets are concatenated along `traj_num` and missing values are
    filled with np.nan. Afterward, the labels along the feature axes of all
    data variables, that are named in `UNDERSCORE_MAPPING`, are collected,
    made unique, sorted, and used to reindex the concatenated dataset.

    Args:
        datasets (Sequence[xr.Dataset]): The datasets to combine. Every
            dataset needs to contain exactly one `traj_num` and the
            `traj_num` values must be unique among the datasets.

    Returns:
        xr.Dataset: The output dataset.

    """
    # Encodermap imports
    from encodermap.trajinfo.trajinfo_utils import trajs_combine_attrs

    def _ints(label):
        # all runs of digits in a label (as strings)
        return re.findall(r"\d+", label)

    def _last_int_key(label):
        return int(_ints(label)[-1])

    def _side_dihedral_key(label):
        found = _ints(label)
        return int(found[-1]), int(found[0])

    def _allatom_key(label):
        # labels ending with "c" come first, then the ints in reversed order
        return (0 if label.endswith("c") else 1, *map(int, _ints(label)[::-1]))

    def _central_dihedral_key(label):
        if "PSI" in label:
            rank = 1
        elif "OMEGA" in label:
            rank = 2
        else:
            rank = 3
        return int(_ints(label)[-1]), rank

    # side dihedrals, central dihedrals and allatom need some special sorting
    # the other generic labels can be sorted by their last int
    special_sort_keys = {
        "SIDE_DIHEDRALS": _side_dihedral_key,
        "ALLATOM": _allatom_key,
        "CENTRAL_DIHEDRALS": _central_dihedral_key,
    }

    # make sure that all traj-nums are unique
    traj_nums = [ds.traj_num.values for ds in datasets]
    assert all(num.size == 1 for num in traj_nums)
    traj_nums = np.array(traj_nums)[:, 0]
    n_unique = len(np.unique(traj_nums))
    assert len(traj_nums) == n_unique, (
        f"The sequence of datasets provided for arg `datasets` contains multiple "
        f"traj_nums: {traj_nums=}"
    )

    # create a large dataset
    out = xr.concat(
        datasets,
        data_vars="all",
        dim="traj_num",
        fill_value=np.nan,
        combine_attrs=trajs_combine_attrs,
    )

    # only EncoderMap's own data variables contribute labels
    known_names = tuple(UNDERSCORE_MAPPING.keys())
    labels_per_axis = {}
    for name, data_array in out.data_vars.items():
        if name in known_names:
            axis = data_array.attrs["feature_axis"]
            labels_per_axis.setdefault(axis, []).extend(
                data_array.coords[axis].values
            )

    ordered_labels = {
        axis: sorted(
            np.unique(labels),
            key=special_sort_keys.get(axis, _last_int_key),
        )
        for axis, labels in labels_per_axis.items()
    }
    return out.reindex(ordered_labels, fill_value=np.nan)
2071class DaskFeaturizer(metaclass=DaskFeaturizerMeta):
2072 """Container for `SingleTrajFeaturizer` and `EnsembleFeaturizer`
2073 that implements delayed transforms.
2075 The DaskFeaturizer is similar to the other two featurizer classes and
2076 mostly implements the same API. However, instead of computing the
2077 transformations using in-memory computing, it prepares a `xarray.Dataset`,
2078 which contains `dask.Arrays`. This dataset can be lazily and distributively
2079 evaluated using dask.distributed clients and clusters.
2081 """
def __init__(
    self,
    trajs: Union[SingleTraj, TrajEnsemble],
    n_workers: Union[str, int] = "cpu-2",
    client: Optional[Client] = None,
) -> None:
    """Instantiates the DaskFeaturizer.

    Args:
        trajs (Union[SingleTraj, TrajEnsemble]): The trajectory data. Objects
            with an `itertrajs` attribute are wrapped in an
            `EnsembleFeaturizer`, all others in a `SingleTrajFeaturizer`.
        n_workers (Union[str, int]): The number of workers of a newly
            created `Client`. The string "cpu-2" uses the number of CPUs
            minus two (but at least one worker), the string "max" uses
            all CPUs. Only used when no client is provided and no global
            client exists. Defaults to "cpu-2".
        client (Optional[Client]): An existing dask client. If None, the
            global client is used and, if there is none, a new one is
            created. Defaults to None.

    """
    if hasattr(trajs, "itertrajs"):
        self.feat = EnsembleFeaturizer(trajs, delayed=True)
    else:
        self.feat = SingleTrajFeaturizer(trajs, delayed=True)

    if n_workers == "cpu-2":
        # Standard Library Imports
        from multiprocessing import cpu_count

        # never request fewer than one worker on machines with <= 2 CPUs,
        # a cluster without workers would never finish a computation
        n_workers = max(1, cpu_count() - 2)
    elif n_workers == "max":
        # Standard Library Imports
        from multiprocessing import cpu_count

        n_workers = cpu_count()

    dask.config.set(scheduler="processes")

    # prefer the provided client, then the global client, then a new one
    self.client = client if client is not None else _get_global_client()
    if self.client is None:
        self.client = Client(n_workers=n_workers)
        print(
            f"Created dask scheduler. Access the dashboard via: "
            f"{self.client.dashboard_link}"
        )
    else:
        print(
            f"Using existing dask scheduler. Access the dashboard via: "
            f"{self.client.dashboard_link}"
        )
def add_custom_feature(self, feature):
    """Adds a custom feature to the wrapped featurizer.

    The feature's `delayed` attribute is set to True, if it is missing or
    falsy. If a `dataset` attribute is already present on this instance,
    a warning is issued and `build_graph()` is called again.

    Args:
        feature: The feature to add. It is forwarded to
            `self.feat.add_custom_feature()`.

    """
    # a missing and a falsy `delayed` attribute are treated the same
    if not getattr(feature, "delayed", False):
        feature.delayed = True
    self.feat.add_custom_feature(feature)

    if not hasattr(self, "dataset"):
        return
    warnings.warn(
        f"The compute graph has already been built. I will rebuild the "
        f"graph and add the feature as a transformer. Subsequent "
        f"calls to `.get_output()` will include this feature."
    )
    self.build_graph()
def build_graph(
    self,
    traj: Optional[SingleTraj] = None,
    streamable: bool = False,
    return_delayeds: bool = False,
) -> None:
    """Prepares the dask graph.

    The result of `build_dask_xarray()` is stored in `self.dataset` and
    `self.variables`. If no features have been added, a hint is printed
    and nothing is built.

    Args:
        traj (Optional[SingleTraj]): Forwarded to `build_dask_xarray()`.
            Defaults to None.
        streamable (bool): Forwarded to `build_dask_xarray()`.
            Defaults to False.
        return_delayeds (bool): Forwarded to `build_dask_xarray()`.
            `to_netcdf()` sets this to its `with_trajectories` argument,
            so this presumably requests the delayed trajectory data
            (TODO confirm against `build_dask_xarray()`). Defaults to False.

    """
    # the wrapped featurizer holds either a dict or a list of features
    if self.feat.active_features in ({}, []):
        print("First add some features before calling `get_output()`.")
        return

    graph = build_dask_xarray(
        self,
        traj=traj,
        streamable=streamable,
        return_delayeds=return_delayeds,
    )
    self.dataset, self.variables = graph
def to_netcdf(
    self,
    filename: Union[Path, str],
    overwrite: bool = False,
    with_trajectories: bool = False,
) -> str:
    """Saves the dask tasks to a NetCDF4 formatted HDF5 file.

    Args:
        filename (Union[Path, str]): The filename to be used.
        overwrite (bool): Whether to overwrite the existing filename.
        with_trajectories (bool): Also save the trajectory data. The output
            file can be read with `encodermap.load(filename)` and rebuilds
            the trajectories complete with traj_nums, common_str, custom_top,
            and all CVs, that this featurizer calculates.

    Returns:
        str: Returns the filename of the created files.

    Raises:
        Exception: If the graph has already been built, or if `filename`
            exists and `overwrite` is False.

    """
    filename = Path(filename)
    if "dataset" in self.__dict__:
        raise Exception("Graph already built.")

    # allows multiple writes to netcdf4 files
    def set_env():
        os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    self.client.run(set_env)

    if filename.is_file():
        if not overwrite:  # pragma: nocover
            raise Exception(
                f"File {filename} already exists. Set `overwrite=True` to overwrite."
            )
        filename.unlink()

    # build
    self.build_graph(return_delayeds=with_trajectories)

    # without extra variables, xarray's own writer is sufficient
    if self.variables is None:
        self.dataset.to_netcdf(
            filename,
            format="NETCDF4",
            group="CVs",
            engine="h5netcdf",
            invalid_netcdf=False,
            compute=True,
        )
        return str(filename)

    # Third Party Imports
    import h5py
    from xarray import conventions
    from xarray.backends.api import (
        _finalize_store,
        _validate_attrs,
        _validate_dataset_names,
    )
    from xarray.backends.common import ArrayWriter
    from xarray.backends.h5netcdf_ import H5NetCDFStore

    # use xarrays's to_netcdf code and add saving of delayed coordinates, etc.
    _validate_dataset_names(self.dataset)
    _validate_attrs(self.dataset, invalid_netcdf=False)
    # autoclose the store, if any variable of the dataset is chunked
    is_chunked = any(
        var.chunks is not None for var in self.dataset.variables.values()
    )
    store = H5NetCDFStore.open(
        filename=filename,
        mode="a",
        format="NETCDF4",
        autoclose=is_chunked,
    )
    writer = ArrayWriter()
    try:
        # create dicts of data to write to store
        variables, attrs = conventions.encode_dataset_coordinates(self.dataset)
        variables |= self.variables
        store.store(variables, attrs, set(), writer, None)
        store.close()
        writes = writer.sync(compute=False)
    finally:
        store.close()

    # this runs the computation and displays a progress
    finalize_task = dask.delayed(_finalize_store)(writes, store)
    finalize_task = finalize_task.persist()
    progress(finalize_task)
    finalize_task.compute()

    # afterward, we remove the unwanted groups starting with md from the .h5 file
    # they are artifacts of hijacking xarray's `to_netcdf`
    # we also move all keys that are not part of the traj coords
    md_keys = ["coordinates", "time", "cell_lengths", "cell_angles"]
    with h5py.File(filename, "a") as h5file:
        for key in [k for k in h5file.keys() if k.startswith("md")]:
            del h5file[key]
        # snapshot the keys, because moving adds the "CVs" group
        for key in list(h5file.keys()):
            if any(md_key in key for md_key in md_keys):
                continue
            h5file.move(key, f"CVs/{key}")

    # and add common_str, custom_top, etc.
    self.feat.trajs.save(fname=filename, CVs=False, only_top=True)
    return str(filename)
def get_output(
    self,
    make_trace: bool = False,
) -> xr.Dataset:
    """Computes the delayed dataset and returns the result.

    If the graph has not been built yet, `build_graph()` is called first.

    Args:
        make_trace (bool): Tracing the dask execution is currently not
            implemented. Defaults to False.

    Returns:
        xr.Dataset: The result of `self.dataset.compute()`.

    Raises:
        NotImplementedError: If `make_trace` is True.
        Exception: If the computed dataset is falsy (i.e., empty).

    """
    if "dataset" not in self.__dict__:
        self.build_graph()
    if make_trace:
        raise NotImplementedError("Currently not able to trace dask execution.")

    computed = self.dataset.compute()
    if computed:
        return computed
    raise Exception(
        f"Computed dataset is empty. Maybe a computation failed in "
        f"the dask-delayed dataset: {self.dataset}"
    )
@property
def feature_containers(self) -> dict[md.Topology, SingleTrajFeaturizer]:
    """dict[md.Topology, SingleTrajFeaturizer]: The `feature_containers`
    of the wrapped featurizer `self.feat`."""
    wrapped = self.feat
    return wrapped.feature_containers
@property
def active_features(
    self,
) -> Union[list[AnyFeature], dict[md.Topology, list[AnyFeature]]]:
    """Union[list[AnyFeature], dict[md.Topology, list[AnyFeature]]]: The
    `active_features` of the wrapped featurizer `self.feat`."""
    wrapped = self.feat
    return wrapped.active_features
def __len__(self):
    """Returns the length of the wrapped featurizer `self.feat`."""
    wrapped = self.feat
    return len(wrapped)
def transform(
    self,
    traj_or_trajs: Optional[Union[SingleTraj, TrajEnsemble]] = None,
    *args,
    **kwargs,
) -> np.ndarray:
    """Calls `transform()` of the wrapped featurizer `self.feat`.

    Args:
        traj_or_trajs (Optional[Union[SingleTraj, TrajEnsemble]]): Passed
            on as the first positional argument. Defaults to None.
        *args: Additional positional arguments, passed on unchanged.
        **kwargs: Additional keyword arguments, passed on unchanged.

    Returns:
        np.ndarray: Whatever `self.feat.transform()` returns.

    """
    delegate = self.feat.transform
    return delegate(traj_or_trajs, *args, **kwargs)
def describe(self) -> list[str]:
    """Returns the result of `describe()` of the wrapped featurizer `self.feat`."""
    wrapped = self.feat
    return wrapped.describe()
def dimension(self) -> int:
    """Returns the `dimension` attribute of the wrapped featurizer `self.feat`.

    NOTE(review): The attribute is returned without being called. This
    assumes, that `dimension` is a property (or plain attribute) on the
    wrapped featurizer - TODO confirm.

    """
    wrapped = self.feat
    return wrapped.dimension
def visualize(self) -> None:
    """Passes `self.dataset` to `dask.visualize()` and returns the result.

    `self.dataset` is only set by `build_graph()`, so the graph needs to
    be built before calling this method.

    """
    graph = self.dataset
    return dask.visualize(graph)