Coverage for encodermap/trajinfo/info_single.py: 11%
815 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/trajinfo/info_single.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Classes to work with ensembles of trajectories.
24The statistics of a protein can be better described by an ensemble of proteins,
25rather than a single long trajectory. Treating a protein in such a way opens great
26possibilities and changes the way one can treat molecular dynamics data.
27Trajectory ensembles allow:
28 * Faster convergence via adaptive sampling.
29 * Better anomaly detection of unique structural states.
31This subpackage contains two classes which are containers of trajectory data.
32The SingleTraj trajectory contains information about a single trajectory.
33The TrajEnsemble class contains information about multiple trajectories. This adds
34a new dimension to MD data. The time and atom dimension are already established.
35Two frames can be appended along the time axis to get a trajectory with multiple
36frames. If they are appended along the atom axis, the new frame contains the
37atoms of these two. The trajectory works in a similar fashion. Adding two trajectories
38along the trajectory axis returns a trajectory ensemble, represented as a TrajEnsemble
39class in this package.
41See also:
42 http://statisticalbiophysicsblog.org/?p=92
44"""
46################################################################################
47# Imports
48################################################################################
51# Future Imports at the top
52from __future__ import annotations
54# Standard Library Imports
55import errno
56import os
57import re
58import sys
59import warnings
60from copy import deepcopy
61from io import StringIO
62from pathlib import Path
64# Third Party Imports
65import numpy as np
66import tables
67from optional_imports import _optional_import
69# Encodermap imports
70from encodermap._typing import CanBeIndex, CustomAAsDict
71from encodermap.misc.misc import _TOPOLOGY_EXTS
72from encodermap.misc.xarray_save_wrong_hdf5 import save_netcdf_alongside_mdtraj
73from encodermap.trajinfo.load_traj import _load_pdb_from_uri, _load_traj
76################################################################################
77# Optional Imports
78################################################################################
81md = _optional_import("mdtraj")
82mda = _optional_import("MDAnalysis")
83h5 = _optional_import("h5py")
84xr = _optional_import("xarray")
85h5py = _optional_import("h5py")
86pd = _optional_import("pandas")
89################################################################################
90# Typing
91################################################################################
94# Standard Library Imports
95from collections.abc import Callable, Iterable, Sequence
96from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload
99if TYPE_CHECKING: # pragma: no cover
100 # Third Party Imports
101 import h5py as h5
102 import MDAnalysis as mda
103 import mdtraj as md
104 import xarray as xr
106 # Encodermap imports
107 from encodermap.trajinfo.info_all import TrajEnsemble
108 from encodermap.trajinfo.trajinfo_utils import CustomTopology, SingleTrajFeatureType
111################################################################################
112# Globals
113################################################################################
116__all__: list[str] = ["SingleTraj"]
119################################################################################
120# Utilities
121################################################################################
class MixedUpInputs(Exception):
    """Raised when trajectory and topology arguments appear to be swapped.

    Used when a file with a topology extension is passed as the trajectory
    while a different file is passed as the topology.
    """
class Capturing(list):
    """List that collects everything printed while used as a context manager.

    On entry ``sys.stdout`` is swapped for an in-memory buffer. On exit the
    buffered text is split into lines, the lines are appended to this list,
    and the stream that was active on entry is put back.

    Examples:
        >>> # write a function
        >>> def my_func(arg='argument'):
        ...     print(arg)
        ...     return('fin')
        >>> # use capturing context manager
        >>> with Capturing() as output:
        ...     my_func('new_argument')
        >>> print(output)
        ['new_argument', "'fin'"]

    """

    def __enter__(self):
        # Remember the active stream so that __exit__ can restore it.
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        del self._stringio  # the lines live in the list now, drop the buffer
        sys.stdout = self._stdout
def _hash_numpy_array(x: np.ndarray) -> int:
    """Hash an array by XOR-ing the hashes of its shape, strides and bytes.

    Args:
        x (np.ndarray): The array to hash.

    Returns:
        int: ``hash(x.shape) ^ hash(x.strides) ^ hash(x.data.tobytes())``.

    Note:
        Python salts the hash of ``bytes`` per interpreter process (unless
        ``PYTHONHASHSEED`` is fixed), so the value is only comparable
        within one process.

    """
    components = (x.shape, x.strides, x.data.tobytes())
    digest = 0
    for component in components:
        # XOR with the initial 0 leaves the first hash unchanged.
        digest ^= hash(component)
    return digest
164################################################################################
165# Classes
166################################################################################
class SingleTrajFsel:
    """Accessor that selects frames by the values stored in ``other.id``.

    ``SingleTrajFsel(traj)[item]`` looks ``item`` up among the frame ids of
    ``traj`` and returns ``traj[idx]``, where ``idx`` holds the positions of
    the matching frames.

    """

    def __init__(self, other):
        """Store the object whose frames are selected.

        Args:
            other: The object to index. This class reads its ``traj_num``,
                ``id`` and ``_frames`` attributes and calls ``other[idx]``.

        """
        self.other = other

    def __getitem__(self, item: CanBeIndex) -> SingleTraj:
        """Return the frames whose frame id equals or is contained in ``item``.

        Args:
            item (CanBeIndex): A single integer (Python or any NumPy integer
                type), a list of integers, or a numpy array of integers.

        Returns:
            SingleTraj: The result of ``self.other[idx]``.

        Raises:
            NotImplementedError: If ``item`` is a slice.
            ValueError: If ``item`` has an unsupported type or no frame id
                matches it.

        """
        if isinstance(item, slice):
            raise NotImplementedError("Currently can't index frames with slice.")
        if not isinstance(item, (int, np.integer, list, np.ndarray)):
            raise ValueError(
                f"The `fsel[]` method of `SingleTraj` takes {CanBeIndex} types, "
                f"but {type(item)} was provided."
            )

        # Without a traj_num, ``id`` is used as is; with one, the frame ids
        # are read from the second column of ``id``.
        if self.other.traj_num is None:
            frame_ids = self.other.id
        else:
            frame_ids = self.other.id[:, 1]

        if isinstance(item, (int, np.integer)):
            idx = np.where(frame_ids == item)[0]
        else:
            # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
            idx = np.where(np.isin(frame_ids, np.asarray(item)))[0]

        if len(idx) == 0:
            raise ValueError(
                f"No frames with frame index {item} in trajectory {self.other} "
                f"with frames: {self.other._frames}"
            )
        return self.other[idx]
206class SingleTraj:
207 """This class contains the info about a single trajectory.
209 This class contains many of the attributes and methods of
210 `mdtraj.Trajectory <https://www.mdtraj.org/1.9.8.dev0/api/generated/mdtraj.Trajectory.html>`_.
211 It is meant to be used as a standalone single trajectory or in an ensemble
212 defined in the :obj:`encodermap.trajinfo.info_all.TrajEnsemble` class.
213 Other than the standard
214 `mdtraj.Trajectory <https://www.mdtraj.org/1.9.8.dev0/api/generated/mdtraj.Trajectory.html>`_
215 this class loads the MD data only when needed. The location of the file(s)
216 and other attributes like indices (int, list[int], numpy.ndarray, slice)
217 are stored until the traj is accessed via the
218 :attr:`SingleTraj.traj` attribute. The returned traj is a
219 `mdtraj.Trajectory <https://www.mdtraj.org/1.9.8.dev0/api/generated/mdtraj.Trajectory.html>`_
220 with the correct number of frames in the correct sequence.
222 Besides MD data, this class keeps track of your collective variables. Oftentimes
223 the raw xyz data of a trajectory is not needed for understanding the conformation
224 and suitable CVs are selected to represent a protein via internal coordinates
225 (torsions, pairwise distances, etc.). This class keeps track of your CVs.
226 Whether you call them ``'highd'`` or ``'torsions'``, this class keeps
227 track of everything and returns the values when you need them.
229 :obj:`SingleTraj` supports fancy indexing, so you can extract one or more frames
230 from a Trajectory as a separate trajectory. For example, to form a
231 trajectory with every other frame, you can slice with ``traj[::2]``.
233 Note:
234 SingleTraj uses the nanometer, degree & picosecond unit system.
236 Attributes:
237 backend (str): Current state of loading. If ``backend == 'no_load'`` xyz
238 data will be loaded from disk, if accessed. If ``backend == 'mdtraj'``,
239 the data is already in RAM.
240 common_str (str): Substring of :attr:`traj_file` and :attr:`top_file`.
241 Used to group multiple trajectory and topology files.
242 If ``traj files=['protein1_traj1.xtc', 'protein1_traj2.xtc']``
243 have the same topology stored in a file called ``'protein1.pdb'``, you
244 can load them with ``common_str='protein1'`` together with more ``.xtc``
245 and ``.pdb`` files and these two ``.xtc`` files will use the correct
246 ``.pdb`` file.
247 index (Sequence[Union[None, int, list, numpy.ndarray, slice]]): A sequence of
248 fancy slices of the trajectory. When file is loaded from disk, the
249 fancy indexes will be applied one after the other.
250 traj_num (int): Integer to identify a SingleTraj class in a TrajEnsemble class.
251 traj_file (str): Trajectory file used to create this class.
252 top_file (str): Topology file used to create this class. If a .h5 trajectory
253 was used traj_file and top_file are identical. If a
254 ``mdtraj.Trajectory`` was used to create SingleTraj, these strings are empty.
256 Examples:
257 Load a pdb file with 14 frames from rcsb.org
259 >>> import encodermap as em
260 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")
261 >>> traj # doctest: +ELLIPSIS
262 <encodermap.SingleTraj object...
263 >>> traj.n_frames
264 14
266 Providing common_str sets this attribute.
268 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", common_str="1GHC")
269 >>> traj.common_str
270 '1GHC'
272 Indexing using integers returns a SingleTraj with only one frame.
274 >>> frame = traj[5]
275 >>> frame.n_frames
276 1
278 Indexing can also use lists of integers.
280 >>> subset = traj[[0, 1, 5]]
281 >>> subset.n_frames
282 3
284 Further indexing this ``subset`` uses the current trajectory 'as is'. Although
285 frame 0, 1, and 5 have been extracted from ``traj``, we get frame 5 from
286 ``subset`` by indexing with 2.
288 >>> frame = subset[2]
289 >>> frame.id
290 array([5])
292 Indexing using the original frame indices from the file is done using the
293 ``fsel[]`` accessor.
295 >>> frame = subset.fsel[5]
296 >>> frame.id
297 array([5])
299 Advanced slicing
301 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")[-1:7:-2]
302 >>> [frame.id[0] for frame in traj]
303 [13, 11, 9]
305 The ``traj_num`` argument is mainly used in :obj:`encodermap.TrajEnsemble`,
306 but can be provided manually.
308 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", traj_num=2)
309 >>> traj.traj_num
310 2
312 The argument ``basename_fn`` should be a callable, that takes a string
313 and returns a string.
315 >>> from pathlib import Path
316 >>> def my_basename_fn(filename):
317 ... stem = str(Path(filename).stem)
318 ... return "custom_" + stem
319 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", basename_fn=my_basename_fn)
320 >>> traj.basename
321 'custom_1GHC'
323 Build a trajectory ensemble from multiple :obj:`SingleTraj` objects.
325 >>> traj1 = em.SingleTraj("https://files.rcsb.org/view/1YUG.pdb")
326 >>> traj2 = em.SingleTraj("https://files.rcsb.org/view/1YUF.pdb")
327 >>> trajs = traj1 + traj2
328 >>> print(trajs.n_trajs, trajs.n_frames, [traj.n_frames for traj in trajs])
329 2 31 [15, 16]
331 """
333 _mdtraj_attr = [
334 "n_frames",
335 "n_atoms",
336 "n_chains",
337 "n_residues",
338 "openmm_boxes",
339 "openmm_positions",
340 "time",
341 "timestep",
342 "xyz",
343 "unitcell_vectors",
344 "unitcell_lengths",
345 "unitcell_angles",
346 "_check_valid_unitcell",
347 "_distance_unit",
348 "_have_unitcell",
349 "_rmsd_traces",
350 "_savers",
351 "_string_summary_basic",
352 "_time",
353 "_time_default_to_arange",
354 "_topology",
355 "_unitcell_angles",
356 "_unitcell_lengths",
357 "_xyz",
358 ]
360 def __init__(
361 self,
362 traj: Union[str, Path, md.Trajectory],
363 top: Optional[Union[str, Path]] = None,
364 common_str: str = "",
365 backend: Literal["no_load", "mdtraj"] = "no_load",
366 index: Optional[Union[int, list[int], np.ndarray, slice]] = None,
367 traj_num: Optional[int] = None,
368 basename_fn: Optional[Callable[[str], str]] = None,
369 custom_top: Optional[CustomAAsDict] = None,
370 ) -> None:
371 """Initialize the SingleTraj object with location and reference pdb file.
373 Args:
374 traj (Union[str, mdtraj.Trajectory]): The trajectory. This argument
375 can either be the filename of a trajectory file
376 (``.xtc, .dcd, .h5, .trr``) or an instance of
377 ``mdtraj.Trajectory``.
378 top (Union[str, mdtraj.Topology], optional): The path to the
379 topology file. Defaults to :obj:`None`. If a ``mdtraj.Trajectory``
380 or a ``.h5`` file is provided in :obj:`traj`, this argument will
381 not be used and the topology from the corresponding ``traj``
382 argument will be used.
383 common_str (str, optional): A string to group traj of similar
384 topology. If multiple :obj:`SingleTraj` are grouped in one
385 :obj:`encodermap.trajinfo.info_all.TrajEnsemble`, the
386 ``common_str`` is used to group them together. Defaults to ''
387 which means this instance of :obj:`SingleTraj` won't have a common
388 string.
389 backend (Literal['no_load', 'mdtraj'], optional): Choose the
390 backend to load trajectories.
391 - 'mdtraj' uses mdtraj, which loads all trajectories into RAM.
392 - 'no_load' creates an empty trajectory object.
393 Defaults to 'no_load'
394 index (Optional[Union[int, list[int], numpy.ndarray, slice]]): An
395 integer or a Sequence of ``int``. If an integer is provided,
396 only the frame at this position will be loaded once the internal
397 ``mdtraj.Trajectory`` is accessed. If an array or list is provided,
398 the corresponding frames will be used. Indices always slice the
399 trajectory as is, meaning they don't index the original frames
400 of the trajectory on disk (see Example section). These indices
401 can have duplicates: ``[0, 1, 1, 2, 0, 1]``. A slice object
402 can also be provided. Supports fancy slicing (``traj[1:50:3]``).
403 If :obj:`None` is provided, the traj is loaded fully.
404 Defaults to :obj:`None`.
405 traj_num (Union[int, None], optional): If working with multiple
406 trajs, this is the easiest unique identifier. If multiple
407 :obj:`SingleTraj` are instantiated by
408 :obj:`encodermap.trajinfo.info_all.TrajEnsemble` the ``traj_num``
409 is used as a unique identifier per traj. Defaults to :obj:`None`.
410 basename_fn (Optional[Callable[[str], str]]): A function to apply to
411 ``traj`` to give it another identifier. If all your trajs are called
412 ``'traj.xtc'`` and only the directory they're in gives them a
413 unique identifier, you can provide a function into this
414 argument to split the path. The function has to take a :obj:`str` and
415 return :obj:`str`. If None is provided, the basename
416 is extracted like so: ``lambda x: os.path.basename(x).split('.')[0]``.
417 Defaults to None, in which case the filename without
418 extension will be used.
419 custom_top: Optional[:obj:`encodermap._typing.CustomAAsDict`]: An
420 instance of the :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology`
421 class or a dictionary that can be made into such.
423 """
424 # Imports
425 # Encodermap imports
426 from encodermap.misc.xarray import construct_xarray_from_numpy
427 from encodermap.trajinfo.trajinfo_utils import CustomTopology
429 # defaults
430 self.__traj = traj
431 self.backend = backend
432 self.common_str = common_str
433 self.index = index if isinstance(index, tuple) else (index,)
434 self.traj_num = traj_num
435 self._loaded_once = False if backend == "no_load" else True
436 self._orig_frames = np.array([])
437 self._CVs = xr.Dataset()
439 # custom topology to load dihedral angles
440 self._custom_top = custom_top
441 if self._custom_top is not None:
442 if isinstance(self._custom_top, dict):
443 self._custom_top = CustomTopology.from_dict(self._custom_top, traj=self)
444 else:
445 self._custom_top = CustomTopology(traj=self)
447 # _atom indices are for delayed atom-slicing
448 self._atom_indices = None
450 # decide the basename
451 if basename_fn is None:
452 basename_fn = lambda x: os.path.basename(x).split(".")[0]
453 self.basename_fn = basename_fn
455 # save the filename
456 if isinstance(traj, str):
457 if self._validate_uri(traj):
458 self._traj_file = traj
459 else:
460 self._traj_file = Path(traj)
461 elif isinstance(traj, Path):
462 self._traj_file = traj
463 elif isinstance(traj, md.Trajectory):
464 self._traj_file = Path("")
465 self._top_file = Path("")
466 self.backend = "mdtraj"
467 self.trajectory = traj
468 self._loaded_once = True
469 self._orig_frames = np.arange(traj.n_frames)
470 else:
471 raise ValueError(
472 f"Argument `traj` takes either str, Path, or "
473 f"mdtraj.Trajectory. You supplied: {type(traj)}."
474 )
476 if top is not None:
477 if isinstance(top, md.Topology):
478 if custom_top is not None:
479 raise Exception(
480 f"Providing an MDTraj Topology as the `top` argument interferes "
481 f"with the argument `custom_topology`. Use one or the other. "
482 )
483 self._top_file = Path("")
484 else:
485 if self._validate_uri(top):
486 self._top_file = top
487 else:
488 self._top_file = Path(top)
489 if isinstance(self._traj_file, Path):
490 if (
491 self._traj_file.suffix in _TOPOLOGY_EXTS
492 and self._traj_file != self._top_file
493 ):
494 raise MixedUpInputs(
495 f"You probably mixed up the input. Normally you "
496 f"want to instantiate with `SingleTraj(traj, top)`."
497 f"Based on the files and the extensions you provided "
498 f"(traj={self._traj_file.name} and top="
499 f"{self._top_file.name}), you want to change the "
500 f"order of the arguments, or use keyword arguments."
501 )
502 else:
503 if isinstance(self._traj_file, Path):
504 if self._traj_file.suffix in _TOPOLOGY_EXTS:
505 if self._validate_uri(traj):
506 self._top_file = traj
507 else:
508 self._top_file = self._traj_file
509 else:
510 self._top_file = self._traj_file
512 if self.backend == "no_load":
513 self.trajectory = False
514 self.topology = False
515 else:
516 if isinstance(self._traj_file, str) and self._validate_uri(self._traj_file):
517 traj = md.load_pdb(str(self.traj_file))
518 elif self._traj_file != Path(""):
519 try:
520 traj = md.load(str(self._traj_file), top=str(self._top_file))
521 except tables.NoSuchNodeError as e:
522 if self.traj_num is None:
523 raise
524 # Local Folder Imports
525 from .info_all import HDF5GroupWrite
527 with HDF5GroupWrite(self.top_file) as h5file:
528 traj = h5file.read_traj(self.traj_num)
529 self.trajectory = traj
530 self.topology = False
531 self._loaded_once = True
532 self.topology = self._custom_top.top
533 self._orig_frames = np.arange(traj.n_frames)
535 # check the h5file. Maybe it contains CVs and topology
536 CVs_in_file = False
537 custom_top_in_file = False
538 common_str_in_file = False
539 if isinstance(self._traj_file, Path):
540 if self._traj_file.suffix == ".h5":
541 with h5.File(self.traj_file, "r") as file:
542 h5_attrs = list(file.attrs.keys())
543 if "CVs" in file.keys():
544 CVs_in_file = True
545 if any(["custom_top" in s for s in h5_attrs]):
546 custom_top_in_file = True
547 if any(["common_str" in s for s in h5_attrs]):
548 common_str_in_file = True
550 # if group 'CVs' is in h5store, but no data
551 if CVs_in_file:
552 test = xr.open_dataset(
553 self.traj_file,
554 group="CVs",
555 engine="h5netcdf",
556 backend_kwargs={"phony_dims": "access"},
557 )
558 if len(test.data_vars) == 0:
559 CVs_in_file = False
561 # inform the use about a grandfathered traj_num
562 if custom_top_in_file:
563 h5_top_attrs = list(filter(lambda x: x.startswith("_custom_top"), h5_attrs))
564 finds = [re.findall(r"\d+", s) for s in h5_top_attrs]
565 if not all([f == [] for f in finds]):
566 h5_top_traj_nums = [int(f[-1]) for f in finds]
567 if self.traj_num not in h5_top_traj_nums:
568 if len(h5_top_attrs) == 1:
569 h5_top_traj_num = h5_top_traj_nums[0]
570 with h5py.File(self.traj_file, "a") as file:
571 file.attrs["_custom_top"] = file.attrs[h5_top_attrs[0]]
572 del file.attrs[h5_top_attrs[0]]
573 print(
574 f"The file {self.traj_file} was previously saved with the `traj_num`="
575 f"{h5_top_traj_num}, but this instance uses `traj_num`={self.traj_num}. "
576 f"I have overwritten the `traj_num` for the custom_topology in this file."
577 )
578 elif len(h5_top_attrs) > 1:
579 raise Exception(
580 f"The h5File contains multiple '_custom_top' attrs: {h5_top_attrs} "
581 f"of which none matches this Traj's `traj_num`: {self.traj_num}."
582 )
584 # inform the use about a grandfathered traj_num
585 if common_str_in_file:
586 common_str_key = "common_str"
587 h5_common_str_traj_num = None
588 h5_common_str_attrs = list(
589 filter(lambda x: x.startswith("common_str"), h5_attrs)
590 )
591 finds = [re.findall(r"\d+", s) for s in h5_common_str_attrs]
592 if not all([f == [] for f in finds]):
593 h5_common_str_traj_nums = [int(f[-1]) for f in finds]
594 if self.traj_num not in h5_common_str_traj_nums:
595 if len(h5_common_str_attrs) == 1:
596 h5_common_str_traj_num = h5_common_str_traj_nums[0]
597 with h5py.File(self.traj_file, "a") as file:
598 file.attrs["common_str"] = file.attrs[
599 h5_common_str_attrs[0]
600 ]
601 del file.attrs[h5_common_str_attrs[0]]
602 print(
603 f"The file {self.traj_file} was previously saved with the `traj_num`="
604 f"{h5_common_str_traj_num}, but this instance uses `traj_num`={self.traj_num}. "
605 f"I have overwritten the `traj_num` for the custom_topology in this file."
606 )
607 elif len(h5_common_str_attrs) > 1:
608 raise Exception(
609 f"The h5File contains multiple '_custom_top' attrs: {h5_common_str_attrs} "
610 f"of which none matches this Traj's `traj_num`: {self.traj_num}."
611 )
612 else:
613 h5_common_str_traj_num = self.traj_num
614 if h5_common_str_traj_num is not None:
615 common_str_key = f"common_str_{h5_common_str_traj_num}"
617 # maybe load CVs from h5 file
618 if CVs_in_file:
619 try:
620 ds = xr.open_dataset(
621 self.traj_file,
622 group="CVs",
623 engine="h5netcdf",
624 backend_kwargs={"phony_dims": "access"},
625 )
626 if ds.sizes["traj_num"] > 1:
627 assert self.traj_num in ds.coords["traj_num"], (
628 f"This trajectory with {self.traj_num=} is not in "
629 f"the dataset with traj_nums: {ds.coords['traj_num']}."
630 )
631 ds = ds.sel(traj_num=self.traj_num)
632 ds = ds.expand_dims("traj_num")
633 if str(ds.coords["traj_name"].values) != self.basename:
634 ds.coords["traj_name"] = [self.basename]
635 ds = ds.dropna("frame_num", how="all")
636 self._CVs = ds
637 # bad formatted h5 file
638 except OSError:
639 DAs = {
640 k: construct_xarray_from_numpy(self, i[()], k)
641 for k, i in file["CVs"].items()
642 }
643 DS = xr.Dataset(DAs)
644 self._CVs.update(DS)
645 # other exceptions probably due to formatting
646 except Exception as e:
647 raise Exception(
648 f"The formatting of the data in the file "
649 f"{self.traj_file} is off. Xarray could "
650 f"not load the group 'CVs' and failed with {e}. "
651 f"Some debug: {CVs_in_file=} and {file.keys()=}."
652 ) from e
654 # get the original frame indices from the dataset
655 # this is the only case where we want to overwrite
656 # this variable
657 if not self._loaded_once:
658 self._loaded_once = True
659 self._orig_frames = self._CVs["frame_num"].values
661 # iteratively apply index
662 index = self._orig_frames
663 for ind in self.index:
664 if ind is not None:
665 index = index[ind]
667 # set the _CVs accordingly
668 self._CVs = self._CVs.loc[{"frame_num": index}]
670 # maybe load _custom_top_from file
671 if custom_top_in_file:
672 ctop = CustomTopology.from_hdf5_file(
673 self.traj_file,
674 traj=self,
675 )
676 if custom_top is not None:
677 if self._custom_top != ctop:
678 print(
679 f"The file {self.traj_file} contained a custom_topology: "
680 f"{ctop}. The custom topology you provided {self._custom_top} "
681 f"is different from that in the file. The provided one will "
682 f"take precedent over the one in the file. "
683 )
684 else:
685 self._custom_top = ctop
687 # maybe load common_str form file
688 if common_str_in_file:
689 with h5py.File(self.traj_file, "r") as f:
690 common_str_file = f.attrs[common_str_key]
692 if self.common_str == "":
693 self.common_str = common_str_file
694 else:
695 if common_str_file != self.common_str:
696 warnings.warn(
697 f"The file {self.traj_file} contains a "
698 f"`common_str`={common_str_file}, which differs from the "
699 f"`common_str` provided in __init__ ({self.common_str}). "
700 f"I will use the one from "
701 f"__init__ and overwrite the one in the file."
702 )
703 try:
704 with h5py.File(self.traj_file, "a") as f:
705 f.attrs[common_str_key] = self.common_str
706 except OSError as e:
707 warnings.warn(
708 f"I could not write the common str in the file, because "
709 f"it is already opened by another python object: {str(e)}"
710 )
@classmethod
def from_pdb_id(cls, pdb_id: str, traj_num: Optional[int] = None) -> SingleTraj:
    """Alternate constructor that builds an instance from a pdb-id.

    The id is upper-cased, inserted into a ``files.rcsb.org/view`` URL and
    also passed on as ``common_str``.

    Args:
        pdb_id (str): The 4-letter pdb id.
        traj_num (Optional[int]): Passed through unchanged to the
            constructor. Defaults to None.

    Returns:
        SingleTraj: The newly created instance of ``cls``.

    """
    code = pdb_id.upper()
    url = f"https://files.rcsb.org/view/{code}.pdb"
    return cls(url, common_str=code, traj_num=traj_num)
@property
def featurizer(self):
    """The ``SingleTrajFeaturizer`` attached to this object.

    It is created with ``SingleTrajFeaturizer(self)`` on first access and
    stored in ``self._featurizer``; later accesses return the stored one.
    """
    # Local Folder Imports
    from ..loading.featurizer import SingleTrajFeaturizer

    try:
        return self._featurizer
    except AttributeError:
        # first access: build the featurizer and keep it for later
        self._featurizer = SingleTrajFeaturizer(self)
        return self._featurizer
@property
def indices_chi1(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the chi1 dihedral angles.

    Obtained from ``indices_chi1()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_chi1()
@property
def indices_chi2(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the chi2 dihedral angles.

    Obtained from ``indices_chi2()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_chi2()
@property
def indices_chi3(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the chi3 dihedral angles.

    Obtained from ``indices_chi3()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_chi3()
@property
def indices_chi4(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the chi4 dihedral angles.

    Obtained from ``indices_chi4()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_chi4()
@property
def indices_chi5(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the chi5 dihedral angles.

    Obtained from ``indices_chi5()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_chi5()
@property
def indices_phi(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the phi dihedral angles.

    Obtained from ``indices_phi()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_phi()
@property
def indices_omega(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the omega dihedral angles.

    Obtained from ``indices_omega()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_omega()
@property
def indices_psi(self) -> np.ndarray:
    """numpy.ndarray: Atom indices of the psi dihedral angles.

    Obtained from ``indices_psi()`` of the custom topology. The array has
    shape (n_dihedrals, 4); each row holds the 0-based indices of the atoms
    that take part in one dihedral angle.
    """
    custom_top = self._custom_top
    return custom_top.indices_psi()
@property
def _original_frame_indices(self) -> np.ndarray:
    """numpy.ndarray: The frame indices of the trajectory file on disk.

    If the trajectory has never been loaded (``_loaded_once`` is false),
    ``load_traj()`` is called first. The returned ``_orig_frames`` tracks
    the original frames even when the trajectory is sliced later on.
    """
    if not self._loaded_once:
        self.load_traj()
    return self._orig_frames
@property
def _frames(self) -> np.ndarray:
    """numpy.ndarray: Applies ``self.index`` over ``self._orig_frames``.

    Every entry of ``self.index`` that is not None is applied, in order, to
    a copy of ``self._orig_frames``. This property does not trigger loading
    of the trajectory.

    ``_orig_frames`` can also hold a single integer (``n_frames`` stores the
    integer index there). A plain ``int`` has no ``.copy()`` and cannot be
    indexed again, so that case is returned directly as a 0-d array.
    """
    frames = self._orig_frames
    if isinstance(frames, (int, np.integer)):
        # already reduced to one frame, nothing left to index
        return np.asarray(frames)
    frames = frames.copy()
    for ind in self.index:
        if ind is not None:
            frames = frames[ind]
    return np.asarray(frames)
def _trace(self, CV: Sequence[str]) -> np.ndarray:
    """Creates a low-dimensional representation of the loaded CV data by
    stacking all arguments in `CV` along the feature axis.

    Every requested CV is flattened to shape (n_frames, -1) and the results
    are joined along axis 1. If this `SingleTraj` has 100 frames and a CV
    with shape (100, 50, 3) with the name 'cartesians', then
    `traj._trace(['cartesians'])` will return a numpy.ndarray of shape
    (100, 150). Adding a second CV of shape (100, 10) gives (100, 160).

    Args:
        CV (Sequence[str]): The CVs to combine in the trace.

    Returns:
        numpy.ndarray: The trace.

    Raises:
        ValueError: Raised by ``numpy.concatenate`` if `CV` is empty.

    """
    flattened = []
    for name in CV:
        values = self.CVs[name]
        flattened.append(values.reshape(values.shape[0], -1))
    # axis=1 keeps one row per frame; the default axis=0 would stack the
    # CVs along the frame axis and fail for differing feature counts.
    return np.concatenate(flattened, axis=1)
@property
def traj_file(self) -> str:
    """str: ``self._traj_file`` converted to a string, so that callers get
    a ``str`` even when a `pathlib.Path` is stored."""
    location = self._traj_file
    return str(location)
@property
def top_file(self) -> str:
    """str: ``self._top_file`` converted to a string, so that callers get
    a ``str`` even when a `pathlib.Path` is stored."""
    location = self._top_file
    return str(location)
@property
def traj(self) -> md.Trajectory:
    """mdtraj.Trajectory: The trajectory of this object.

    For every backend other than 'no_load', ``self.trajectory`` is returned
    directly. With the 'no_load' backend, the trajectory is loaded with
    ``load_traj()``, a reference to it is kept, ``unload()`` is called and
    the kept reference is returned.

    """
    if self.backend != "no_load":
        return self.trajectory

    self.load_traj()
    loaded = self.trajectory
    self.unload()
    return loaded
@property
def _traj(self):
    """Getter that only exists so that the setter can be defined.

    It returns nothing, mirroring the setter, which returns nothing either."""
    return None

@_traj.setter
def _traj(self, traj_file):
    """Loads ``traj_file`` via ``_load_traj`` and stores the result.

    ``self.index``, ``self._top_file``, ``self.traj_num`` and
    ``self._atom_indices`` are forwarded to ``_load_traj``. The first
    returned value becomes ``self.trajectory``. On the very first load the
    second returned value is stored as ``self._orig_frames``. A truthy
    ``self.topology`` is deep-copied before loading and written back to
    the loaded trajectory and to ``self.topology`` afterwards.
    """
    kept_top = deepcopy(self.topology) if self.topology else False

    loaded, orig_frames = _load_traj(
        *self.index,
        traj_file=traj_file,
        top_file=self._top_file,
        traj_num=self.traj_num,
        atom_indices=self._atom_indices,
    )
    self.trajectory = loaded

    if not self._loaded_once:
        self._loaded_once = True
        self._orig_frames = orig_frames

    if kept_top:
        # put the previously set topology back onto the fresh trajectory
        self.trajectory.top = kept_top
        self.topology = kept_top
@property
def basename(self) -> str:
    """str: The result of ``self.basename_fn`` applied to `traj_file`.

    By default that is the filename without path and without extension.
    None is returned when `traj_file` is empty or equal to ``"."``.
    """
    filename = self.traj_file
    if not filename or str(filename) == ".":
        return None
    return self.basename_fn(filename)
@property
def extension(self) -> str:
    """str: The file extension of the trajectory file, including the dot.

    A `pathlib.Path` reports its ``suffix``. Any other value is treated as
    a string, and everything after its last ``"."`` is used.
    """
    location = self._traj_file
    if not isinstance(location, Path):
        return "." + location.rsplit(".", 1)[-1]
    return location.suffix
@property
def id(self) -> np.ndarray:
    """numpy.ndarray: Unique identifiers of the frames in this object.

    The original frame indices are taken and every entry of ``self.index``
    that is not None is applied to them in order. Without a ``traj_num``
    the result has shape (n_frames, ). With a ``traj_num`` it has shape
    (n_frames, 2), where ``id[:, 0]`` is filled with ``traj_num`` and
    ``id[:, 1]`` holds the frames. A single frame gives shape (1, ) or
    (1, 2) respectively.

    """

    def _single_frame(frame):
        # one frame: 1-d without a traj_num, one (traj_num, frame) row with it
        if self.traj_num is None:
            return np.array([frame])
        return np.array([[self.traj_num, frame]])

    frames = self._original_frame_indices
    if isinstance(frames, (int, np.integer)):
        return _single_frame(frames)

    for selection in self.index:
        if selection is not None:
            frames = frames[selection]

    # an integer in the index chain reduces the frames to a scalar
    if isinstance(frames, (int, np.integer)):
        return _single_frame(frames)

    if self.traj_num is None:
        return frames
    return np.array([np.full(len(frames), self.traj_num), frames]).T
@property
def n_frames(self) -> int:
    """int: Number of frames in traj.

    For files with the suffix ``.h5`` or ``.xtc`` the number of frames is
    derived from the file and the stored index. For all other files the
    traj is loaded into memory. Be aware.

    """
    int_types = (int, np.integer)

    # an integer in the index always selects exactly one frame
    integer_indices = [ind for ind in self.index if isinstance(ind, int_types)]
    if integer_indices:
        self._loaded_once = True
        self._orig_frames = integer_indices[0]
        return 1

    if self._traj_file.suffix == ".h5":
        with h5.File(self.traj_file, "r") as file:
            # without a plain 'coordinates' dataset, fall back to the
            # dataset that carries the traj_num in its name
            if "coordinates" not in list(file.keys()) and self.traj_num is not None:
                dataset_name = f"coordinates_{self.traj_num}"
            else:
                dataset_name = "coordinates"

            if self.index == (None,):
                n_frames = np.arange(file[dataset_name].shape[0])
            else:
                for position, ind in enumerate(self.index):
                    if position == 0:
                        n_frames = np.arange(file[dataset_name].shape[0])
                    if ind is not None:
                        n_frames = n_frames[ind]

            if not self._loaded_once:
                self._orig_frames = np.arange(file[dataset_name].shape[0])

        # return single int or length of array
        return n_frames if isinstance(n_frames, int_types) else len(n_frames)

    if self._traj_file.suffix == ".xtc":
        with mda.coordinates.XTC.XTCReader(self.traj_file) as reader:
            if self.index == (None,):
                n_frames = np.arange(reader.n_frames)
            else:
                for position, ind in enumerate(self.index):
                    if position == 0:
                        n_frames = np.arange(reader.n_frames)[ind]
                    else:
                        n_frames = n_frames[ind]
                    if ind is None:
                        # indexing with None prepends an axis, strip it again
                        n_frames = n_frames[0]

            if not self._loaded_once:
                self._loaded_once = True
                self._orig_frames = np.arange(reader.n_frames)

        return n_frames if isinstance(n_frames, int_types) else len(n_frames)

    self.load_traj()
    return self.traj.n_frames
@property
def _n_frames_base_h5_file(self) -> int:
    """int: Can be used to get n_frames without loading an HDF5 into memory.

    Reads the first dimension of the ``'coordinates'`` dataset of
    ``self.traj_file``. ``-1`` is returned when the extension is not ``.h5``.

    """
    if self.extension != ".h5":
        return -1
    with h5.File(self.traj_file, "r") as file:
        return file["coordinates"].shape[0]
@property
def CVs_in_file(self) -> bool:
    """bool: Is True, if `traj_file` has extension .h5 and contains the key ``'CVs'``."""
    if self.extension != ".h5":
        return False
    with h5.File(self.traj_file, "r") as file:
        return "CVs" in file.keys()
@property
def n_atoms(self) -> int:
    """int: Number of atoms in traj.

    The sources are checked in this order: the length of
    ``self._atom_indices`` (if set), the second dimension of the coordinates
    dataset of an HDF5 file, and finally ``self.top.n_atoms``.

    Raises:
        Exception: When the HDF5 file has no ``'coordinates'`` dataset and
            ``traj_num`` is None.

    """
    if self._atom_indices is not None:
        return len(self._atom_indices)

    if self.extension != ".h5":
        return self.top.n_atoms

    with h5.File(self.traj_file, "r") as file:
        try:
            return file["coordinates"].shape[1]
        except KeyError as e:
            if self.traj_num is None:
                raise Exception(
                    f"Neither key 'coordinates', nor 'coordinates_{self.traj_num}' "
                    f"exists in h5file {self.traj_file} with keys {file.keys()}."
                ) from e
            # fall back to the dataset that carries the traj_num in its name
            return file[f"coordinates_{self.traj_num}"].shape[1]
@property
def n_residues(self) -> int:
    """int: Number of residues in traj, as reported by ``self.top``."""
    topology = self.top
    return topology.n_residues
@property
def n_chains(self) -> int:
    """int: Number of chains in traj, as reported by ``self.top``."""
    topology = self.top
    return topology.n_chains
@property
def top(self) -> md.Topology:
    """mdtraj.Topology: The structure of a Topology object is similar to that of a PDB file.

    It consists of a set of Chains (often but not always corresponding to
    polymer chains). Each Chain contains a set of Residues, and each Residue
    contains a set of Atoms. In addition, the Topology stores a list of which
    atom pairs are bonded to each other.
    Atom and residue names should follow the PDB 3.0 nomenclature for all
    molecules for which one exists.

    The value is obtained from ``self._get_top()``.

    Attributes:
        chains (generator): Iterate over chains.
        residues (generator): Iterate over residues.
        atoms (generator): Iterate over atoms.
        bonds (generator): Iterate over bonds.

    """
    topology = self._get_top()
    return topology
def _get_raw_top(self) -> md.Topology:
    """Reads different files and loads md.Topology from them.

    This topology will *NOT* be corrected with `CustomTopology`.

    The source of the topology depends on ``self.top_file``,
    ``self.backend``, ``self.extension``, and ``self.traj_file``.

    NOTE(review): The checks below are independent ``if`` statements and not
    an ``if``/``elif`` chain. For every backend other than ``'mdtraj'`` the
    final ``else`` branch loads the topology once more, and that result is
    the one returned. Confirm whether the earlier loads are still needed.

    Returns:
        mdtraj.Topology: The raw topology.

    Raises:
        FileNotFoundError: When ``top_file`` is set, is not ``"."``, is not
            a valid URI, and does not exist as a file.

    """
    # Third Party Imports
    import tables

    # Local Folder Imports
    from .info_all import HDF5GroupWrite

    # validate the topology file: URIs need to point to a .pdb file and are
    # loaded right away, local files need to exist
    if self.top_file:
        if self.top_file != ".":
            if self._validate_uri(self.top_file):
                assert self.top_file.endswith(".pdb")
                top = _load_pdb_from_uri(self.top_file)
            elif not os.path.isfile(self.top_file):
                raise FileNotFoundError(
                    errno.ENOENT, os.strerror(errno.ENOENT), self.top_file
                )
    # unloaded, non-HDF5 trajectory: read the topology from top_file
    if self.backend == "no_load" and not self.extension == ".h5" and self.traj_file:
        if self._validate_uri(self.top_file):
            top = _load_pdb_from_uri(self.top_file)
        else:
            top = md.load_topology(self.top_file)
    # HDF5 files: try the regular mdtraj loader first and fall back to the
    # HDF5GroupWrite reader with this traj's traj_num
    if self.extension == ".h5":
        try:
            top = md.load_topology(self.top_file)
        except tables.NoSuchNodeError as e:
            if self.traj_num is None:
                raise e
            with HDF5GroupWrite(self.top_file) as h5file:
                top = h5file.read_traj(self.traj_num).top
    # only runs when none of the branches above has assigned `top`
    if self.backend == "no_load" and "top" not in locals():
        try:
            top = md.load_topology(self.top_file)
        except tables.NoSuchNodeError as e:
            if self.traj_num is None:
                raise e
            with HDF5GroupWrite(self.top_file) as h5file:
                top = h5file.read_traj(self.traj_num).top
    # a loaded trajectory provides its own topology, everything else is
    # (re-)read from top_file
    if self.backend == "mdtraj":
        top = self.traj.top
    else:
        if self._validate_uri(self.top_file):
            top = _load_pdb_from_uri(self.top_file)
        else:
            try:
                top = md.load_topology(self.top_file)
            except tables.NoSuchNodeError as e:
                if self.traj_num is None:
                    raise e
                with HDF5GroupWrite(self.top_file) as h5file:
                    top = h5file.read_traj(self.traj_num).top
    return top
def _get_top(self) -> md.Topology:
    """Returns the topology of this trajectory.

    A truthy ``self.topology`` takes precedence. Otherwise, the topology of
    ``self._custom_top`` is used.

    Returns:
        mdtraj.Topology: The structure of a Topology object is similar to that of a PDB file.

    """
    return self.topology if self.topology else self._custom_top.top
def copy(self) -> SingleTraj:
    """Returns a deep copy of ``self``, created with :func:`copy.deepcopy`."""
    duplicate = deepcopy(self)
    return duplicate
def del_CVs(self) -> None:
    """Resets the ``_CVs`` attribute to an empty :obj:`xarray.Dataset`.

    The old attribute is deleted first, before the empty dataset is assigned.

    """
    del self._CVs
    empty_dataset = xr.Dataset()
    self._CVs = empty_dataset
def _calc_CV(self) -> dict[str, np.ndarray]:
    """Returns the current CVs as a dictionary.

    Data variables with the substring ``'feature_indices'`` in their name
    are only checked for their shape and are not part of the output. For all
    other variables NaNs are dropped along the feature axis, and the leading
    axis is squeezed away.

    Returns:
        dict[str, np.ndarray]: The CVs, keyed by name. An empty dict, if
            ``self._CVs`` is falsy.

    Raises:
        Exception: For the key ``'central_angles_indices'``, or when the
            leading axis of a variable cannot be squeezed.

    """
    if not self._CVs:
        return {}

    out = {}
    for key, val in self._CVs.data_vars.items():
        if "feature_indices" in key:
            # index arrays are validated, but never returned
            if not key.startswith("MinRmsdFeature"):
                if "cartesian" in key:
                    assert val.shape[0] == 1, (
                        f"The substring 'feature_indices' is special and can "
                        f"only contain a (1, n_frames) or (1, n_frames, 4) arrays. "
                        f"Your value of {key=} has the shape: {val.shape=} "
                        f"If you have manually "
                        f"loaded a feature with this substring, use a different "
                        f"one. These CVs contain integer indices and not values."
                    )
                else:
                    assert val.shape[-1] <= 4, (
                        f"The substring 'feature_indices' is special and can "
                        f"only contain a (1, n_frames) or (1, n_frames, <=4) arrays. "
                        f"Your value of {key=} has the shape: {val.shape=} "
                        f"If you have manually "
                        f"loaded a feature with this substring, use a different "
                        f"one. These CVs contain integer indices and not values."
                    )
            continue

        # the attribute that names the feature dimension comes in two spellings
        if "feature_axis" in val.attrs:
            axis_name = "feature_axis"
        elif "feature_axes" in val.attrs:
            axis_name = "feature_axes"
        else:
            axis_name = None

        if key == "central_angles_indices":
            raise Exception(
                f"{val.shape=}. {axis_name=} {val.attrs[axis_name]=}"
            )

        if np.any(np.isnan(val)):
            # without an axis attribute, the upper-cased key is the dimension
            if axis_name is None:
                val = val.dropna(key.upper())
            else:
                val = val.dropna(val.attrs[axis_name])

        try:
            out[key] = val.values.squeeze(0)
        except ValueError as e:
            raise Exception(f"{key=} {val=}") from e
    return out
@property
def CVs(self) -> dict[str, np.ndarray]:
    """dict[str, numpy.ndarray]: Returns a simple dict from the more complicated self._CVs xarray Dataset.

    The conversion is delegated to ``self._calc_CV()``, which returns an
    empty dict when ``self._CVs`` is empty.

    """
    cv_dict = self._calc_CV()
    return cv_dict
def _validate_uri(self, uri: Union[Path, str]) -> bool:
    """Checks whether `uri` is a valid uri.

    ``uri`` is converted to a str and handed to
    ``encodermap.misc.misc._validate_uri``.

    """
    # Encodermap imports
    from encodermap.misc.misc import _validate_uri as validate

    return validate(str(uri))
def load_traj(
    self,
    new_backend: Literal["no_load", "mdtraj"] = "mdtraj",
) -> None:
    """Loads the trajectory, with a new specified backend.

    After this is called with ``'mdtraj'`` the instance variable
    self.trajectory will contain a mdtraj Trajectory object. Nothing happens
    if the instance already uses ``new_backend``.

    Args:
        new_backend (str, optional): Can either be:
            - ``'mdtraj'`` to load the trajectory using mdtraj.
            - ``'no_load'`` to not load the traj (unload).
            Defaults to ``'mdtraj'``.

    """
    current = self.backend
    if current == new_backend:
        return

    if current == "mdtraj" and new_backend == "no_load":
        self.unload()
        self.topology = False
    elif current == "no_load" and new_backend == "mdtraj":
        self.backend = new_backend
        # call the setter again, and roll the backend back if loading fails
        try:
            self._traj = self.traj_file
        except Exception:
            self.backend = "no_load"
            raise
        self.topology = self.top
def select(
    self,
    sel_str: str = "all",
) -> np.ndarray:
    """Execute a selection against the topology.

    The selection string is passed on to the ``select`` method of ``self.top``.

    Args:
        sel_str (str, optional): What to select. Defaults to 'all'.

    See also:
        https://mdtraj.org/1.9.4/atom_selection.html

    Examples:
        >>> import encodermap as em
        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")
        >>> select = traj.top.select("name CA and resSeq 1")
        >>> select
        array([1])

        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")
        >>> select = traj.top.select("name CA and resSeq 1")
        >>> traj.top.atom(select[0])
        MET1-CA

    """
    topology = self.top
    return topology.select(sel_str)
def unload(
    self,
    CVs: bool = False,
) -> None:
    """Clears up RAM by deleting the trajectory info and, optionally, the CV data.

    Nothing happens if the backend already is ``'no_load'``. Otherwise the
    backend is set to ``'no_load'``, all attributes listed in
    ``self._mdtraj_attr`` are removed from the instance, and ``trajectory``
    and ``topology`` are set to False.

    Args:
        CVs (bool, optional): Whether to also delete CVs, defaults to False.

    """
    if self.backend == "no_load":
        return

    self.backend = "no_load"
    # remove the mdtraj attributes from the instance, absent ones are skipped
    instance_dict = self.__dict__
    for key in self._mdtraj_attr:
        instance_dict.pop(key, None)
    if CVs:
        self._CVs = xr.Dataset()
    self.trajectory = False
    self.topology = False
def _gen_ensemble(self) -> TrajEnsemble:
    """Creates a :obj:`encodermap.trajinfo.info_all.TrajEnsemble` instance
    with this :obj:`SingleTraj` in it.

    This method is needed to add two SingleTraj objects
    along the 'trajectory' axis with the method add_new_traj.
    This method is also called by the __getitem__ method of the TrajEnsemble class.

    After construction, the CVs, ``common_str``, ``traj_num``, ``index``,
    custom topology, and topology of this instance are copied onto the first
    traj of the new ensemble.

    """
    # Encodermap imports
    from encodermap.trajinfo.info_all import TrajEnsemble

    ensemble = TrajEnsemble(
        trajs=[self],
        backend=self.backend,
        common_str=[],
        basename_fn=self.basename_fn,
    )
    ensemble.trajs[0]._CVs = self._CVs
    ensemble.trajs[0].common_str = self.common_str
    ensemble.common_str = [self.common_str]
    # the remaining attributes are copied over one-to-one
    for attr in ("traj_num", "index", "_custom_top"):
        setattr(ensemble.trajs[0], attr, getattr(self, attr))
    ensemble.trajs[0].topology = self._custom_top.top
    return ensemble
def _add_along_traj(self, y: SingleTraj) -> TrajEnsemble:
    """Puts self and y into a :obj:`encodermap.trajinfo.info_all.TrajEnsemble`.

    This way the trajectories are not appended along the time
    axis but rather along the 'trajectory' axis. If both trajectories have
    the same ``traj_num``, the ``traj_num`` of ``y`` is changed in place
    (incremented by one, or set to 1 if it was None).

    Args:
        y (SingleTraj): The other :obj:`SingleTraj` trajectory.

    """
    if self.traj_num == y.traj_num:
        # both trajs need distinct traj_nums inside the ensemble
        y.traj_num = 1 if y.traj_num is None else y.traj_num + 1
    own_ensemble = self._gen_ensemble()
    other_ensemble = y._gen_ensemble()
    return own_ensemble + other_ensemble
def get_single_frame(self, key: int) -> SingleTraj:
    """Returns a single frame from the trajectory.

    ``key`` is passed on unchanged to ``self.__getitem__``.

    Args:
        key (Union[int, np.integer]): Index of the frame.

    Examples:

        Import EncoderMap and load :obj:`SingleTraj`.

        >>> import encodermap as em
        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")
        >>> traj.n_frames
        14

        Load the same traj and give it a ``traj_num`` for recognition in a
        set of multiple trajectories.

        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", traj_num=5)
        >>> frame = traj.get_single_frame(2)
        >>> frame.id
        array([[5, 2]])

    """
    frame = self.__getitem__(key)
    return frame
def show_traj(self, gui: bool = True) -> Any:
    """Returns an nglview view object.

    Args:
        gui (bool): Passed on as the ``gui`` argument of
            ``nglview.show_mdtraj``. Defaults to True.

    See Also:
        https://nglviewer.org/nglview/latest/

    Returns:
        view (nglview.widget): The nglview widget object.

    """
    # Third Party Imports
    import nglview

    return nglview.show_mdtraj(self.traj, gui=gui)
def dash_summary(self) -> pd.DataFrame:
    """Returns a :obj:`pandas.DataFrame` with useful information about this instance.

    The dataframe has the columns ``'field'`` and ``'value'`` and one row
    for each of: n_frames, n_atoms, dt (ps), traj_file, top_file, index,
    and common_str. All entries are converted to str.

    Returns:
        pd.DataFrame: The dataframe.

    """
    # unique differences between consecutive time values
    times = self.traj.time
    dt = np.unique(times[1:] - times[:-1])
    n_unique = len(dt)
    if n_unique == 1:
        dt = dt[0]
    elif n_unique == 0:
        dt = "single frame"

    if self.index == (None,):
        index = "[::]"
    else:
        # the first entry of the index is not displayed
        index = self.index[1:]
        if len(index) == 1:
            index = index[0]

    rows = [
        ("n_frames", self.n_frames),
        ("n_atoms", self.n_atoms),
        ("dt (ps)", dt),
        ("traj_file", self.traj_file),
        ("top_file", self.top_file),
        ("index", index),
        ("common_str", self.common_str),
    ]
    df = pd.DataFrame(
        {
            "field": [name for name, _ in rows],
            "value": [value for _, value in rows],
        }
    )
    return df.astype(str)
def load_custom_topology(
    self,
    custom_top: Optional[Union[CustomTopology, CustomAAsDict]] = None,
) -> None:
    """Loads a custom_topology from a `CustomTopology` class or a dict.

    Anything that is not a ``CustomTopology`` instance is converted with
    ``CustomTopology.from_dict``. Afterwards, ``self.topology`` is replaced
    by the topology of the custom topology.

    See Also:
        :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology`

    Args:
        custom_top: Optional[Union[CustomTopology, CustomAAsDict]]: An
            instance of :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology`
            or a dictionary that can be made into such.

    """
    # Encodermap imports
    from encodermap.trajinfo.trajinfo_utils import CustomTopology

    if not isinstance(custom_top, CustomTopology):
        custom_top = CustomTopology.from_dict(custom_top, traj=self)
    self._custom_top = custom_top

    # overwrite the old topology
    self.topology = self._custom_top.top
def load_CV(
    self,
    data: SingleTrajFeatureType,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
    deg: Optional[bool] = None,
    periodic: bool = True,
    labels: Optional[list[str]] = None,
    override: bool = False,
) -> None:
    """Load CVs into traj. Many options are possible. Provide xarray,
    numpy array, em.loading.feature, em.featurizer, and even string!

    This method loads CVs into the SingleTraj instance. Many ways of doing so
    are available:
        - :obj:`numpy.ndarray`: The easiest way. Provide a np array and a name for
            the array, and the data will be saved as an instance variable,
            accessible via `SingleTraj.name`.
        - :obj:`xarray.DataArray`: You can load a multidimensional xarray as
            data into the class. Please refer to xarrays own documentation
            if you want to create one yourself.
        - :obj:`xarray.Dataset`: You can add another dataset to the existing _CVs.
        - :obj:`encodermap.loading.features.Feature`: If you provide one of the
            features from ``encodermap.loading.features`` the resulting
            features will be loaded and also be placed under the set name.
        - :obj:`encodermap.loading.featurizer.Featurizer`: If you provide a
            full featurizer, the data will be generated and be accessible
            as an attribute.
        - str: If a string is provided, the data will be loaded from a
            ``.txt``, ``.npy``, or NetCDF / HDF5 ``.nc`` file.

    Args:
        data (Union[str, numpy.ndarray, xr.DataArray,
            em.loading.features.Feature, em.loading.featurizer.Featurizer]):
            The CV to load. Either as :obj:`numpy.ndarray`,
            :obj:`xarray.DataArray`, EncoderMap feature,
            or EncoderMap Featurizer.
        attr_name (Optional[str]): The name under which the CV
            should be found in the class. Is needed, if a raw numpy array
            is passed, otherwise the name will be generated from the filename
            (if ``data == str``), the DataArray.name (if ``data == xarray.DataArray``),
            or the feature name.
        cols (Optional[list]): A list specifying the columns
            to use it for the high-dimensional data. If your highD data contains
            (x,y,z,...)-errors or has an enumeration column at ``col=0``
            this can be used to remove this unwanted data.
        deg (Optional[bool]): Whether the provided data is in radians (False)
            or degree (True). It can also be None for non-angular data.
        periodic (bool): Passed on unchanged to ``load_CVs_singletraj``.
            Defaults to True.
        labels (Optional[Union[list, str]]): If you want to label
            the data you provided, pass a list of str. If set to None,
            the features in this dimension will be labeled as
            ``[f"{attr_name.upper()} FEATURE {i}" for i in range(self.n_frames)]``.
            If a str is provided, the features will be labeled as
            ``[f"{attr_name.upper()} {label.upper()} {i}" for i in range(self.n_frames)]``.
            If a list of str is provided, it needs to have the same length
            as the traj has frames. Defaults to None.
        override (bool): Whether to overwrite existing CVs. The method will also
            warn which CVs have been overwritten.

    Examples:

        Import EncoderMap and load an example Trajectory.

        >>> import encodermap as em
        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")

        Load the central dihedrals using ``data='central_dihedrals'`` as shortcut.

        >>> traj.load_CV("central_dihedrals")
        >>> traj.central_dihedrals.shape
        (14, 222)


        >>> traj._CVs['central_dihedrals']['CENTRAL_DIHEDRALS'].values[:2]
        ['CENTERDIH PSI   RESID  MET:   1 CHAIN 0'
         'CENTERDIH OMEGA RESID  MET:   1 CHAIN 0']

        Slicing the :obj:`SingleTraj` keeps all CVs in order.

        >>> import numpy as np
        >>> from pathlib import Path
        >>> traj1 = em.SingleTraj(
        ...     Path(em.__file__).parent.parent / "tests/data/1am7_corrected.xtc",
        ...     Path(em.__file__).parent.parent / "tests/data/1am7_protein.pdb",
        ... )
        >>> traj1.load_CV(traj1.xyz[..., -1], 'z_coordinate')
        ...
        >>> for i, frame in enumerate(traj1):
        ...     print(np.array_equal(frame.z_coordinate[0], frame.xyz[0, :, -1]))
        ...     if i == 3:
        ...         break
        True
        True
        True
        True

    Raises:
        FileNotFoundError: When the file given by `data` does not exist.
        IOError: When the provided filename does not have .txt, .npy or .nc extension.
        TypeError: When `data` does not match the specified input types.
        Exception: When a numpy array has been passed as `data` and no `attr_name` has been provided.
        Exception: When the provided `attr_name` is str, but cannot be a python identifier.
        Exception: When `attr_name` contains the substring 'feature_indices',
            when the new CVs contain more than one trajectory, when the
            number of frames does not match, or when merging fails.

    """
    # Local Folder Imports
    from .trajinfo_utils import load_CVs_singletraj, trajs_combine_attrs

    if isinstance(attr_name, str) and "feature_indices" in attr_name:
        raise Exception(
            f"The substring 'feature_indices' is a protected attribute. "
            f"Your attribute can't contain this substring."
        )

    new_CVs = load_CVs_singletraj(
        data,
        self,
        attr_name,
        cols,
        deg=deg,
        periodic=periodic,
        labels=labels,
    )

    # a SingleTraj can only hold the CVs of exactly one trajectory
    n_new_trajs = len(new_CVs.coords["traj_num"])
    if self._CVs:
        assert n_new_trajs == 1, f"something bad happened: {self._CVs=}"
    if n_new_trajs > 1:
        raise Exception(
            f"The provided feature resulted in a dataset with "
            f"{new_CVs.sizes['traj_num']} trajectories. A `SingleTraj` "
            f"class can't accept such a feature."
        )
    if self.traj_num is not None:
        assert new_CVs.coords["traj_num"] == np.array([self.traj_num]), (
            data,
            self.traj_num,
            new_CVs.coords["traj_num"],
        )

    # check the sizes, already loaded CVs take precedence over the traj itself
    len_CVs = new_CVs.coords["frame_num"].shape[0]
    n_frames = (
        self._CVs.coords["frame_num"].shape[0] if self._CVs else self.n_frames
    )
    if n_frames != len_CVs:
        raise Exception(
            f"Loading the requested {data} CVs is not possible, as "
            f"they are not aligned with the number of frames in the "
            f"trajectory. The CVs have {len_CVs} frames, the trajectory "
            f"{self._traj_file} has {n_frames} frames."
        )

    # check the keys and whether they get overwritten
    new_keys = new_CVs.keys() if hasattr(new_CVs, "keys") else set([new_CVs.name])

    if override:
        if overwritten_keys := self._CVs.keys() & new_keys:
            warnings.warn(
                f"Overwriting the following CVs with new values: "
                f"{overwritten_keys}."
            )
        # the new CVs come first, so their values win
        self._CVs = xr.merge(
            [new_CVs, self._CVs],
            combine_attrs=trajs_combine_attrs,
            compat="override",
            join="left",
        )
        return

    try:
        CVs = xr.merge([self._CVs, new_CVs], combine_attrs=trajs_combine_attrs)
        assert len(CVs.coords["traj_num"]) == 1, (
            f"Can't merge\n\n{self._CVs=}\n\nand\n\n{new_CVs=}\n\n, "
            f"because they would stack along the traj axis."
        )
        self._CVs = CVs
    except xr.core.merge.MergeError as e:
        msg = (
            f"Could not add the CV `{attr_name}` to the CVs of the traj "
            f"likely due to it being already in the CVs "
            f"({list(self.CVs.keys())}). Set `override` to True to "
            f"overwrite these CVs. In case you are faced with "
            f"conflicting values on 'traj_name', here they are:\n\n"
            f"{self._CVs.coords['traj_name']=}\n\n{new_CVs.coords['traj_name']=}"
        )
        raise Exception(msg) from e
def _common_str_to_hdf_file(self, fname: Union[Path, str]) -> None:
    """Saves this traj's common_str as an attribute to a hdf5 file.

    The file is opened in append mode and the attribute
    ``common_str_<traj_num>`` is written. A ``common_str`` of None is
    stored as an empty string.

    Args:
        fname (Union[Path, str]): The HDF5 file to write to.

    """
    common_str = self.common_str
    if common_str is None:
        common_str = ""
    # use the `h5` alias, which all other HDF5 accesses in this class use
    with h5.File(fname, "a") as f:
        # NOTE(review): with traj_num None, this writes 'common_str_None',
        # whereas `save` writes the plain key 'common_str' in that case.
        # Confirm which key the reading side expects.
        f.attrs[f"common_str_{self.traj_num}"] = common_str
def save_CV_as_numpy(
    self,
    attr_name: str,
    fname: Optional[str] = None,
    overwrite: bool = False,
) -> None:
    """Saves a specified collective variable of this traj as a ``.npy`` file.

    This got its own method for parallelization purposes.

    Args:
        attr_name (str): Name of the CV to save.
        fname (str, optional): Can be either
            - None, in which case the file is named
              ``{basename}_{attr_name}.npy``.
            - An existing directory, in which case a file with that name
              is created inside that directory.
            - Any other value, which is used as the filename as is.
        overwrite (bool, opt): Whether to overwrite the file. Defaults to False.

    Raises:
        IOError: When the file already exists and overwrite is set to False.

    """
    if fname is None:  # pragma: no cover
        target = f"{self.basename}_{attr_name}.npy"
    else:
        target = fname
    if os.path.isdir(target):
        target = os.path.join(target, f"{self.basename}_{attr_name}.npy")
    if not overwrite and os.path.isfile(target):
        raise IOError(f"{target} already exists. Set overwrite=True to overwrite.")
    np.save(target, self.CVs[attr_name])
def sidechain_info(self) -> dict[int, int]:
    """Counts the side dihedrals of every residue of this trajectory.

    These numbers are used for the AngleDihedralCartesianEncoderMap class to
    allow training with multiple different sidechains. If the CV
    ``'side_dihedrals'`` has not been loaded yet, it is loaded first.

    Returns:
        dict[int, int]: Maps the residue numbers ``1`` to
            ``self.top.n_residues`` to the number of side dihedrals that
            are not NaN in the first frame. A label of the
            ``SIDE_DIHEDRALS`` coordinate is assigned to a residue via
            the second-to-last integer in the label.

    Raises:
        NotImplementedError: When the instance has an attribute ``clustal_w``.

    """
    # make sure no clustal w has been loaded
    if hasattr(self, "clustal_w"):
        raise NotImplementedError(
            f"This is currently not supported for TrajEsnembles with "
            f"clustal_w alignments."
        )
    if "side_dihedrals" not in self._CVs:
        self.load_CV("side_dihedrals")

    side_dihedrals = self._CVs.side_dihedrals
    first_frame = self._CVs.coords["frame_num"].min()

    # parse every label only once and group the labels by residue number
    labels_by_residue = {}
    for label in side_dihedrals.coords["SIDE_DIHEDRALS"].values:
        residx = int(re.findall(r"\d+", label)[-2])
        labels_by_residue.setdefault(residx, []).append(label)

    out = {}
    for residx in range(1, self.top.n_residues + 1):
        labels = labels_by_residue.get(residx, [])
        out[residx] = (
            side_dihedrals.sel(frame_num=first_frame, SIDE_DIHEDRALS=labels)
            .dropna(dim="SIDE_DIHEDRALS")
            .sizes["SIDE_DIHEDRALS"]
        )
    return out
def atom_slice(
    self,
    atom_indices: np.ndarray,
    invert: bool = False,
) -> None:
    """Deletes atoms from this :obj:`SingleTraj` instance.

    The selected indices are stored in ``self._atom_indices``, the topology
    is replaced by the corresponding subset of the custom topology, and,
    unless ``traj_file`` is ``"."``, the trajectory is set again from
    ``traj_file``. Loaded CVs are dropped with a warning.

    Args:
        atom_indices (Union[list, numpy.ndarray]): The indices of the
            atoms to keep.
        invert (bool): If False, it is assumed, that the atoms in ``atom_indices``
            are the ones to be kept. If True, the atoms in ``atom_indices``
            are the ones to be removed.

    """
    atom_indices = np.asarray(atom_indices)
    if invert:
        # build the lookup once, instead of scanning the array for every atom
        excluded = set(atom_indices.ravel().tolist())
        atom_indices = np.array(
            [a.index for a in self.top.atoms if a.index not in excluded]
        )
    self._atom_indices = atom_indices
    if self._CVs:
        warnings.warn(
            "Dropping CVs from trajectory. Slicing CVs with this method is "
            "currently not possible. Raise an issue if you want to have this "
            "feature added."
        )
        self._CVs = xr.Dataset()
    # mark the custom topology as unparsed before requesting its topology
    self._custom_top._parsed = False
    self.topology = self._custom_top.top.subset(atom_indices)
    if self.traj_file != ".":
        self._traj = self.traj_file
def join(self, other: Union[SingleTraj, md.Trajectory]) -> md.Trajectory:
    """Join two trajectories together along the time/frame axis.

    Args:
        other (Union[SingleTraj, md.Trajectory]): The trajectory to join
            with. For anything that is not an ``md.Trajectory``, the
            ``traj`` attribute is used.

    Note:
        Returns a ``mdtraj.Trajectory`` and thus loses CVs, filenames, etc.

    """
    if not isinstance(other, md.Trajectory):
        other = other.traj
    return self.traj.join(other)
def stack(self, other: SingleTraj) -> md.Trajectory:
    """Stack two trajectories along the atom axis

    Args:
        other (SingleTraj): The trajectory to stack with. An
            ``md.Trajectory`` is also accepted and used directly. For
            everything else, the ``traj`` attribute is used.

    Note:
        Returns a ``mdtraj.Trajectory`` and thus loses CVs, filenames, etc.

    """
    if not isinstance(other, md.core.trajectory.Trajectory):
        other = other.traj
    return self.traj.stack(other)
def superpose(
    self,
    reference: Union[md.Trajectory, SingleTraj],
    frame: int = 0,
    atom_indices: Optional[np.ndarray] = None,
    ref_atom_indices: Optional[np.ndarray] = None,
    parallel: bool = True,
    inherit_CVs: bool = False,
) -> SingleTraj:
    """Superpose each conformation in this trajectory upon a reference

    Args:
        reference (Union[mdtraj.Trajectory, SingleTraj]): The reference
            frame to align to. If the reference has multiple frames and
            you want to use a specific frame as reference, use the
            ``frame`` argument also.
        frame (int, optional): Align to this frame in reference.
            Default is 0.
        atom_indices (Union[np.array, None], optional): Indices in
            self, used to calculate RMS values. Defaults to None which
            means all atoms will be used.
        ref_atom_indices (Union[np.array, None], optional): Indices
            in reference, used to calculate RMS values. Defaults to None
            which means all atoms will be used.
        parallel (bool, optional): Use OpenMP to run the superposition
            in parallel over multiple cores.
        inherit_CVs (bool, optional): Whether to also inherit the CVs.
            This feature is currently not implemented. It would require
            additional code in all Feature classes discerning intrinsic
            (distance, angle, cluster_membership, etc.) or an extrinsic
            feature (absolute coordinate, COG position, etc.). Then this
            extrinsic/intrinsic boolean flag also needs to accompany the
            xarray Datasets, so that the intrinsic CVs can be inherited,
            and the extrinsic can be dropped with a corresponding message.

    Returns:
        SingleTraj: A new trajectory with atoms aligned.

    Raises:
        NotImplementedError: When CVs are loaded and ``inherit_CVs`` is True.

    """
    # fail before any work is done, as CV inheritance is not available
    if self._CVs and inherit_CVs:
        raise NotImplementedError("CV inheritance not implemented yet.")

    # a plain mdtraj trajectory is used directly, it has no `traj` attribute
    if isinstance(reference, md.core.trajectory.Trajectory):
        ref_traj = reference
    else:
        ref_traj = reference.traj

    new = SingleTraj(
        self.traj.superpose(
            ref_traj, frame, atom_indices, ref_atom_indices, parallel
        )
    )
    if self.traj_file:
        new._traj_file = self._traj_file
    if self.top_file:
        new._top_file = self._top_file
    return new
def save(
    self,
    fname: str,
    CVs: Union[Literal["all"], list[str]] = "all",
    overwrite: bool = False,
) -> None:
    """Save the trajectory as HDF5 file format to disk.

    Besides the trajectory, the custom topology, the ``common_str``, and the
    requested CVs are written to the same file. Attributes of the CVs whose
    value is a list are removed from ``self._CVs`` in place before saving.

    Args:
        fname (str): The filename.
        CVs (Union[List, 'all'], optional): Either provide a list of strings
            of the CVs you would like to save to disk, or set to 'all' to save
            all CVs. Defaults to 'all'.
        overwrite (bool, optional): Whether force overwrite an existing file.
            Defaults to False.

    Raises:
        IOError: When the file already exists and overwrite is False.

    """
    # check and drop inhomogeneous attributes
    offending_keys = []
    if self._CVs:
        offending_keys = [
            key
            for da in self._CVs.data_vars.values()
            for key, val in da.attrs.items()
            if isinstance(val, list)
        ]
    for key in offending_keys:
        for da in self._CVs.data_vars.values():
            da.attrs.pop(key, None)
        self._CVs.attrs.pop(key, None)

    # raise exception if file already exists
    if os.path.isfile(fname) and not overwrite:
        raise IOError(f"{fname} already exists. Set overwrite=True to overwrite.")
    self.traj.save_hdf5(fname, force_overwrite=overwrite)

    # save the topology as a nested dict
    if self._custom_top:
        self._custom_top.traj = self
        self._custom_top.to_hdf_file(fname=fname)

    if self.common_str:
        # the attribute name carries the traj_num, if there is one
        if self.traj_num is None:
            attr_key = "common_str"
        else:
            attr_key = f"common_str_{self.traj_num}"
        with h5.File(fname, "a") as file:
            file.attrs[attr_key] = self.common_str

    # save the CVs with xarray
    if not self._CVs:
        return
    if CVs == "all":
        save_netcdf_alongside_mdtraj(fname, self._CVs)
        return
    if isinstance(CVs, list):
        with h5.File(fname, "a") as file:
            grp = file["CVs"] if "CVs" in list(file.keys()) else file.create_group("CVs")
            for key in CVs:
                value = self._CVs[key]
                assert self.n_frames == value.shape[1]
                grp.create_dataset(name=key, data=value)
@overload
def iterframes(
    self,
    with_traj_num: Literal[False] = False,
) -> Iterable[tuple[int, SingleTraj]]: ...

@overload
def iterframes(
    self,
    with_traj_num: Literal[True],
) -> Iterable[tuple[int, int, SingleTraj]]: ...

def iterframes(
    self,
    with_traj_num: bool = False,
) -> Iterable[tuple[Any, ...]]:
    """Iterator over the frames in this class.

    Args:
        with_traj_num (bool): Whether to yield a three-tuple of traj_num,
            frame_num, frame (True) or just a two-tuple of frame_num,
            frame (False). Defaults to False.

    Yields:
        tuple: A tuple containing the following:
            - int: The traj_num (only when `with_traj_num` is True).
            - int: The frame_num.
            - encodermap.SingleTraj: A SingleTraj object.

    Examples:

        Import EncoderMap and create :obj:`SingleTraj` instance.

        >>> import encodermap as em
        >>> traj = em.SingleTraj('https://files.rcsb.org/view/1YUG.pdb')
        >>> traj.n_frames
        15

        Slicing the trajectory every 5th frame

        >>> traj = traj[::5]
        >>> traj.n_frames
        3

        Using the ``iterframes()`` iterator.

        >>> for frame_num, frame in traj.iterframes():
        ...     print(frame_num, frame.n_frames)
        0 1
        5 1
        10 1

    """
    # a 2D id array carries the frame numbers in its second column
    frame_nums = self.id[:, 1] if self.id.ndim == 2 else self.id
    for frame_num, frame in zip(frame_nums, self):
        if with_traj_num:
            yield self.traj_num, frame_num, frame
        else:
            yield frame_num, frame
def __copy__(self):
    """Return a shallow copy: a new instance sharing all attribute values."""
    klass = self.__class__
    # bypass __init__, then mirror the instance dictionary
    clone = klass.__new__(klass)
    vars(clone).update(vars(self))
    return clone
def __deepcopy__(self, memo):
    """Return a deep copy in which every attribute is deep-copied as well.

    Args:
        memo (dict): The memo dictionary handed over by :func:`copy.deepcopy`.

    """
    # Standard Library Imports
    import copy

    klass = self.__class__
    clone = klass.__new__(klass)
    # register the clone first, so that self-references resolve to it
    memo[id(self)] = clone
    for name, value in self.__dict__.items():
        setattr(clone, name, copy.deepcopy(value, memo))
    return clone
def __hash__(self) -> int:
    """Hash built from the topology and the numeric arrays of the trajectory.

    The hash of the topology is XOR-combined with the hashes of the
    coordinates, the time and, when they are not None, the unitcell
    lengths and unitcell angles.

    """
    digest = hash(self.top)
    # combine with hashes of arrays
    digest ^= _hash_numpy_array(self._xyz)
    digest ^= _hash_numpy_array(self.time)
    lengths = self._unitcell_lengths
    if lengths is not None:
        digest ^= _hash_numpy_array(lengths)
    angles = self._unitcell_angles
    if angles is not None:
        digest ^= _hash_numpy_array(angles)
    return digest
def __eq__(self, other: SingleTraj) -> bool:
    """Two SingleTraj objects are the same, when the trajectories are the same,
    the files are the same, and the loaded CVs are the same.

    Args:
        other (SingleTraj): The object to compare with.

    Returns:
        bool: True, when the traj files, the hashes and the CVs of both
            objects match. `NotImplemented` is returned for objects that
            are not instances of this class, so that Python falls back
            to its default comparison instead of raising an AttributeError.

    """
    if not isinstance(other, type(self)):
        return NotImplemented
    if self._traj_file != other._traj_file:
        return False
    if self.__hash__() != other.__hash__():
        return False
    if not self._CVs.equals(other._CVs):
        return False
    return True
def __reversed__(self) -> SingleTraj:
    """Return the traj with its frames in reversed order, like traj[::-1]."""
    backwards = slice(None, None, -1)
    return self[backwards]
def __enter__(self):
    """Enter the context manager: load the traj and hand back this instance.

    Inside the context manager, the traj stays loaded.

    """
    self.load_traj()
    return self
def __exit__(self, type, value, traceback):
    """Leave the context manager and unload the traj again.

    The exception information passed in is not inspected, and None is
    returned, so exceptions raised inside the `with` block propagate.

    """
    self.unload()
def __iter__(self):
    """Start an iteration over the frames of this class.

    The frame counter is reset, and the instance itself is returned as
    the iterator. The frames yielded by ``__next__`` carry the correct
    CVs along with the frame of the trajectory.

    """
    self._index = 0
    needs_loading = len(self) == 0 and self.index is None
    if needs_loading:
        self.load_traj()
    return self
def __next__(self):
    """Return the next frame of the running iteration.

    Returns:
        SingleTraj: The frame at the current position. A traj with only
            a single frame returns itself instead of a new object.

    Raises:
        StopIteration: When all frames have been returned.

    """
    # The exhaustion check has to come first. Otherwise, a traj with a single
    # frame would return itself forever and `for frame in traj` never ends.
    if self._index >= self.n_frames:
        raise StopIteration
    self._index += 1
    if len(self.id) == 1:
        return self
    return self[self._index - 1]
def __getitem__(self, key: CanBeIndex) -> SingleTraj:
    """This method returns another trajectory as a SingleTraj class.

    Args:
        key (Union[int, list[int], numpy.ndarray, slice]): Indexing the trajectory
            can be done by int (returns a traj with 1 frame), lists of int or
            numpy.ndarray (returns a new traj with len(traj) == len(key)), or
            slice ([::3]), which returns a new traj with the correct number of
            frames.

    Returns:
        SingleTraj: A SingleTraj object with the selected frame(s) in it.

    Raises:
        TypeError: When `key` is not an int, list, numpy.ndarray or slice.
        IndexError: When `key` addresses a frame at or beyond `n_frames`,
            or when a traj that already was indexed with an integer is
            indexed with anything other than 0.

    """
    # np.integer (not only np.int_) matches the integer checks further down
    if not isinstance(key, (int, np.integer, list, np.ndarray, slice)):
        raise TypeError(
            f"Indexing of `SingleTraj` requires the index to "
            f"be one of the following types: (int, "
            f"list, np.ndarray, slice), you provided {type(key)}."
        )

    if any([isinstance(i, (int, np.integer)) for i in self.index]) and key != 0:
        raise IndexError(
            f"SingleTraj index {key} out of range for traj with only 1 frame."
        )

    # valid positions are 0 ... n_frames - 1, so n_frames itself is out of range
    if isinstance(key, (int, np.integer)):
        if key >= self.n_frames:
            raise IndexError(
                f"Index {key} out of range for traj with "
                f"{self.n_frames} frames."
            )
    if isinstance(key, (list, np.ndarray)):
        # booleans are skipped, they are mask entries and not frame positions
        if any(
            [
                k >= self.n_frames
                for k in key
                if not isinstance(k, (bool, np.bool_))
            ]
        ):
            raise IndexError(
                f"At least one index in {key} out of range for "
                f"traj with {self.n_frames} frames. Normally frames are "
                f"selected by current integer index. If you are trying to "
                f"access frames by their number as it is in the file {self.traj_file}, "
                f"you can use the `fsel[]` locator of this class:\n\n"
                f"traj = em.load('traj_file.xtc', 'top_file.xtc')\n"
                f"traj.fsel[{key}]."
            )

    # append the index to the list of "transformations"
    new_index = (*self.index, key)

    # build a new traj from that
    if self.backend == "no_load":
        traj_out = SingleTraj(
            self.traj_file,
            self.top_file,
            backend=self.backend,
            common_str=self.common_str,
            index=new_index,
            traj_num=self.traj_num,
            basename_fn=self.basename_fn,
        )
    else:
        traj_out = SingleTraj(
            self.trajectory[key],
            self.top_file,
            backend=self.backend,
            common_str=self.common_str,
            index=new_index,
            traj_num=self.traj_num,
            basename_fn=self.basename_fn,
        )
        # the new instance was built from data, so the files are set by hand
        traj_out._traj_file = self._traj_file
        traj_out._top_file = self._top_file
    assert traj_out._traj_file == self._traj_file

    # the original_frames
    if self._atom_indices is not None:
        traj_out.atom_slice(self._atom_indices)
        assert self.n_atoms == traj_out.n_atoms
        assert self.n_residues == traj_out.n_residues
    traj_out._orig_frames = self._orig_frames
    traj_out._loaded_once = self._loaded_once

    # last the CVs
    if self._CVs:
        traj_out._CVs = self._CVs.isel(frame_num=key)
        # an integer key drops the frame_num dimension, it is restored here
        if "frame_num" not in traj_out._CVs.dims:
            traj_out._CVs = traj_out._CVs.expand_dims(
                {
                    "frame_num": [key],
                },
            )
            # NOTE(review): `assign_coords` returns a new object, and the
            # result is discarded here, so no `time` coordinate is stored.
            # Kept as is, because storing it would change the CVs that
            # get saved and compared. Confirm the intent before assigning.
            traj_out._CVs.assign_coords(time=("frame_num", traj_out.time))

    return traj_out
@property
def fsel(self):
    """SingleTrajFsel: A `SingleTrajFsel` object wrapping this instance."""
    locator = SingleTrajFsel(self)
    return locator
def __add__(self, y: SingleTraj) -> TrajEnsemble:
    """Adding two SingleTraj objects gives a TrajEnsemble.

    Args:
        y (encodermap.SingleTraj): The traj on the right-hand side of the `+`.

    Returns:
        encodermap.TrajEnsemble: The result of `self._add_along_traj(y)`.

    """
    ensemble = self._add_along_traj(y)
    return ensemble
def __getattr__(self, attr):
    """What to do when attributes cannot be obtained in a normal way?.

    This method allows access to the `self.CVs` dictionary's values as
    instance variables. Furthermore, if a mdtraj variable is called,
    the traj is loaded, and the correct variable is returned.

    Args:
        attr (str): The name of the attribute, that was not found.

    Returns:
        Any: The attribute of the loaded traj, the values of the CV with
            nans dropped and the first axis squeezed out, or the topology.

    Raises:
        AttributeError: When `attr` is none of the above.

    """
    # Both names are looked up on `self` below. If one of them is missing
    # (e.g., on an instance created via `cls.__new__(cls)`), looking it up
    # would call this method again and again until a RecursionError.
    if attr in ("_mdtraj_attr", "_CVs"):
        raise AttributeError(f"'SingleTraj' object has no attribute '{attr}'")

    if attr in self._mdtraj_attr:
        self.load_traj()
        return getattr(self.traj, attr)
    elif attr in self._CVs:
        val = self._CVs[attr]  # [index]
        axis_name = (
            "feature_axis"
            if "feature_axis" in val.attrs
            else "feature_axes" if "feature_axes" in val.attrs else None
        )
        if np.any(np.isnan(val)):
            if axis_name is not None:
                if "indices" in val.name:
                    assert val.dropna("ATOM_NO").size > 0, (
                        f"Can't `dropna` from DataArray {val=}. This would "
                        f"result in an empty array. Here are the nan indices: "
                        f"{np.where(np.isnan(val.values))[0]=}"
                    )
                    val = val.dropna("ATOM_NO")
                else:
                    val = val.dropna(val.attrs[axis_name])
            else:
                val = val.dropna(attr.upper())
        val_squeezed = val.values.squeeze(0)
        # `.get()` keeps the message buildable when there is no feature axis
        assert val_squeezed.size > 0, (
            f"The value of {attr} is unexpectedly empty {val.size=} after dropping nans "
            f"and dropping nans from he feature axis '{val.attrs.get(axis_name)}' and "
            f"squeezing out the trajectory axis. This is the original "
            f"DataArray:\n{self._CVs[attr]}\n\nAfter the dropna, the value was:\n{val}"
        )
        return val_squeezed
    elif attr in ("traj", "id"):  # pragma: no cover
        # retry the regular lookup and hand its result back to the caller
        return self.__getattribute__(attr)
    elif attr == "top":
        return self._get_top()
    else:
        raise AttributeError(f"'SingleTraj' object has no attribute '{attr}'")
def _string_summary(self) -> str:  # pragma: no cover
    """Build a one-line, human-readable description of this instance.

    The description covers whether data is in memory, the basename, the
    index, the loaded CVs, atom and frame counts (if present in the
    instance dictionary), the common string and the custom topology.

    Returns:
        str: The summary.

    """
    parts = ["encodermap.SingleTraj object."]
    if self.backend == "no_load":
        parts.append(" Currently not in memory.")
    else:
        parts.append(" Data currently in memory.")
    if self.basename:
        parts.append(f" Basename is '{self.basename}'.")
    if self.index != (None,):
        # abbreviate long index arrays while they are formatted
        with np.printoptions(threshold=1, edgeitems=1):
            parts.append(f" At indices {self.index[1:]}.")
    if self._CVs:
        for cv_name, cv_value in self._CVs.items():
            if "feature_indices" in cv_name:
                continue
            cv_shape = cv_value.shape
            if not cv_shape:
                cv_shape = 1
            parts.append(f" CV {cv_name} with shape {cv_shape} loaded.")
    else:
        parts.append(" Not containing any CVs.")
    # only report counts, that are stored on the instance itself
    if "n_atoms" in self.__dict__.keys():
        parts.append(f" Containing {self.n_atoms} atoms.")
    if "n_frames" in self.__dict__.keys():
        parts.append(f" Containing {self.n_frames} frames.")
    if self.common_str:
        parts.append(f" Common string is '{self.common_str}'.")
    if self._custom_top:
        residue_names = [r.name for r in self._custom_top.residues]
        parts.append(f" Custom topology with residues: {residue_names}.")
    return "".join(parts)
def __len__(self):
    """Return the number of frames (`self.n_frames`)."""
    frame_count = self.n_frames
    return frame_count
def __str__(self):
    """Return the text produced by `self._string_summary()`."""
    summary = self._string_summary()
    return summary
def __repr__(self):
    """Return the string summary together with the hex `id()` of the object."""
    summary = self._string_summary()
    return f"<{summary} Object at 0x{id(self):02x}>"