Coverage for encodermap/trajinfo/info_all.py: 8%
1225 statements
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
1# -*- coding: utf-8 -*-
2# encodermap/trajinfo/info_all.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Classes to work with ensembles of trajectories.
24The statistics of a protein can be better described by an ensemble of proteins,
25rather than a single long trajectory. Treating a protein in such a way opens great
26possibilities and changes the way one can treat molecular dynamics data.
27Trajectory ensembles allow:
28 * Faster convergence via adaptive sampling.
29 * Better grasp of equilibrium and off-equilibrium dynamics.
32This subpackage contains two classes which are containers of trajectory data.
33The SingleTraj trajectory contains information about a single trajectory.
34The TrajEnsemble class contains information about multiple trajectories. This adds
35a new dimension to MD data. The time and atom dimension are already established.
36Two frames can be appended along the time axis to get a trajectory with multiple
37frames. If they are appended along the atom axis, the new frame contains the
38atoms of these two. The trajectory works in a similar fashion. Adding two trajectories
39along the trajectory axis returns a trajectory ensemble, represented as a `TrajEnsemble`
40class in this package.
42See also:
43 http://statisticalbiophysicsblog.org/?p=92
45"""
47################################################################################
48# Imports
49################################################################################
52# Future Imports at the top
53from __future__ import annotations
55# Standard Library Imports
56import contextlib
57import copy
58import json
59import operator
60import os
61import re
62import warnings
63from collections.abc import Callable, Iterator, KeysView, Sequence
64from copy import deepcopy
65from pathlib import Path
66from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload
68# Third Party Imports
69import numpy as np
70import tables
71from optional_imports import _optional_import
72from tqdm import tqdm as normal_tqdm_
73from tqdm.notebook import tqdm as notebook_tqdm_
75# Encodermap imports
76from encodermap._typing import CanBeIndex, CustomAAsDict
77from encodermap.misc.misc import (
78 _TOPOLOGY_EXTS,
79 _is_notebook,
80 get_full_common_str_and_ref,
81)
82from encodermap.misc.xarray_save_wrong_hdf5 import save_netcdf_alongside_mdtraj
83from encodermap.trajinfo.trajinfo_utils import CustomTopology, TrajEnsembleFeatureType
86################################################################################
87# Typing
88################################################################################
91if TYPE_CHECKING: # pragma: no cover
92 # Third Party Imports
93 import mdtraj as md
94 import pandas as pd
95 import tensorflow as tf
96 import xarray as xr
98 # Encodermap imports
99 from encodermap.trajinfo.info_single import SingleTraj
102string_types = (str,)
105################################################################################
106# Optional Imports
107################################################################################
110md = _optional_import("mdtraj")
111pd = _optional_import("pandas")
112xr = _optional_import("xarray")
113HDF5TrajectoryFile = _optional_import("mdtraj", "formats.hdf5")
114h5py = _optional_import("h5py")
116################################################################################
117# Globals
118################################################################################
121__all__: list[str] = ["TrajEnsemble"]
124################################################################################
125# Utils
126################################################################################
class notebook_tqdm(notebook_tqdm_):
    """Notebook progress bar that accepts an extra ``function`` keyword.

    The ``function`` keyword is accepted by the constructor and by
    :meth:`update`, but it is discarded and never forwarded to the
    underlying notebook tqdm class.

    """

    def __init__(self, *args, **kwargs):
        """Create the bar; a ``function`` keyword, if given, is dropped."""
        forwarded = {key: value for key, value in kwargs.items() if key != "function"}
        super().__init__(*args, **forwarded)

    def reset(self, total=None, **kwargs):
        """Store ``total`` as the new total and redraw the bar.

        Args:
            total: The new total of the bar.
            **kwargs: Accepted and ignored.

        """
        self.total = total
        self.refresh()

    def update(self, n=1, **kwargs):
        """Advance the bar by ``n`` steps.

        Args:
            n: The increment.
            **kwargs: Accepted and ignored (including ``function``).

        """
        # Only ``n`` reaches the parent class; all keywords are swallowed.
        kwargs.pop("function", None)
        super().update(n=n)
class normal_tqdm(normal_tqdm_):
    """Terminal progress bar that records which caller updated or reset it.

    Calls to :meth:`update` and :meth:`reset` can pass a ``function`` keyword.
    For every such name, the number of update calls and the amount of total
    added by that name are kept in ``self._calls``. If the environment
    variable ``ENCODERMAP_PRINT_PROG_UPDATES`` equals ``"True"``, the
    bookkeeping is printed before and after every update and reset.

    """

    def __init__(self, *args, **kwargs):
        """Create the bar.

        Args:
            *args: Forwarded to the parent tqdm class.
            **kwargs: Forwarded to the parent tqdm class, except for the
                optional ``function`` keyword, which names the caller that
                created the bar.

        """
        self._calls = {}
        # ``function`` is optional here, just like in ``update``/``reset`` and
        # in ``notebook_tqdm``. It must never reach the tqdm constructor.
        function = kwargs.pop("function", None)
        super().__init__(*args, **kwargs)
        if function is not None:
            self._calls[function] = {
                "update_calls": 0,
                "total": self.total,
            }
        self.print = os.getenv("ENCODERMAP_PRINT_PROG_UPDATES", "False") == "True"
        if self.print:
            print("INSTANTIATION")
            self.debug_print()

    def debug_print(self):
        """Print the per-caller bookkeeping of this bar."""
        print(f"Progbar {id(self)}")
        for function, data in self._calls.items():
            # ``str()`` keeps the column layout for strings and integers, but
            # also works for values without format-spec support, e.g. a
            # ``None`` total of a bar that was created without a total.
            print(
                f"{str(function):<15} total: {str(data['total']):>3} "
                f"n: {str(data['update_calls']):>3}"
            )
        print("\n")

    def update(self, n=1, **kwargs):
        """Advance the bar by ``n`` and count the call for ``function``.

        Args:
            n: The increment forwarded to the parent class.
            **kwargs: Only the optional ``function`` keyword is used.

        """
        function = kwargs.pop("function", None)
        if function is not None and function not in self._calls:
            self._calls[function] = {
                "update_calls": 0,
                "total": 0,
            }
        if self.print:
            print(f"BEFORE UPDATE ({function})")
            self.debug_print()
        super().update(n)
        if function is not None:
            self._calls[function]["update_calls"] += 1
            if self.print:
                print(f"AFTER UPDATE ({function})")
                self.debug_print()

    def reset(self, total=None, **kwargs):
        """Grow the total of the bar to ``total`` and redraw it.

        Args:
            total: The new total. It has to be larger than the current one.
            **kwargs: Only the optional ``function`` keyword is used. The
                difference between the new and the old total is credited
                to that name.

        """
        assert total > self.total
        function = kwargs.pop("function", None)
        if function is not None:
            added = total - self.total
            if function not in self._calls:
                self._calls[function] = {
                    "update_calls": 0,
                    "total": added,
                }
            else:
                self._calls[function]["total"] += added
        if self.print:
            print(f"BEFORE RESET ({function})")
            self.debug_print()
        self.total = total
        self.refresh()
        if self.print and function is not None:
            print(f"AFTER RESET ({function})")
            self.debug_print()
@contextlib.contextmanager
def temp_seed(seed: int) -> Iterator[None]:
    """Context manager that seeds numpy's global RNG for the enclosed block.

    The state of ``np.random`` is captured on entry and put back on exit,
    also when the enclosed block raises.

    Args:
        seed (int): The seed.

    Examples:
        >>> from encodermap.trajinfo.info_all import temp_seed
        >>> import numpy as np
        >>> with temp_seed(123456789):
        ...     print(np.random.randint(low=0, high=10, size=(5, )))
        [8 2 9 7 4]

    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # put the generator back to where it was before entering
        np.random.set_state(saved_state)
228def _check_mode(m, modes):
229 if m not in modes:
230 raise ValueError(
231 "This operation is only available when a file " 'is open in mode="%s".' % m
232 )
235################################################################################
236# Classes
237################################################################################
240class HDF5GroupWrite(md.formats.HDF5TrajectoryFile):
241 def keys(self) -> KeysView:
242 if not self._open:
243 raise Exception(f"Can't view keys of closed HDF5 file.")
244 nodes = [n.name for n in self._handle.list_nodes("/")]
245 return KeysView(nodes)
247 def _initialize_headers(
248 self,
249 group_id: str,
250 n_atoms: int,
251 set_coordinates: bool,
252 set_time: bool,
253 set_cell: bool,
254 set_velocities: bool,
255 set_kineticEnergy: bool,
256 set_potentialEnergy: bool,
257 set_temperature: bool,
258 set_alchemicalLambda: bool,
259 ) -> None:
260 # Local Folder Imports
261 from .._version import get_versions
263 version = get_versions()["version"]
264 self._n_atoms = n_atoms
266 self._handle.root._v_attrs.conventions = "Pande"
267 self._handle.root._v_attrs.conventionVersion = "1.1"
268 self._handle.root._v_attrs.program = "MDTraj"
269 self._handle.root._v_attrs.programVersion = version
270 self._handle.root._v_attrs.title = "title"
272 # if the client has not the title attribute themselves, we'll
273 # set it to MDTraj as a default option.
274 if not hasattr(self._handle.root._v_attrs, "application"):
275 self._handle.root._v_attrs.application = "MDTraj"
277 # create arrays that store frame level informat
278 if set_coordinates:
279 self._create_earray(
280 where="/",
281 name=f"coordinates_{group_id}",
282 atom=self.tables.Float32Atom(),
283 shape=(0, self._n_atoms, 3),
284 )
285 getattr(self._handle.root, f"coordinates_{group_id}").attrs[
286 "units"
287 ] = "nanometers"
289 if set_time:
290 self._create_earray(
291 where="/",
292 name=f"time_{group_id}",
293 atom=self.tables.Float32Atom(),
294 shape=(0,),
295 )
296 getattr(self._handle.root, f"time_{group_id}").attrs[
297 "units"
298 ] = "picoseconds"
300 if set_cell:
301 self._create_earray(
302 where="/",
303 name=f"cell_lengths_{group_id}",
304 atom=self.tables.Float32Atom(),
305 shape=(0, 3),
306 )
307 self._create_earray(
308 where="/",
309 name=f"cell_angles_{group_id}",
310 atom=self.tables.Float32Atom(),
311 shape=(0, 3),
312 )
313 getattr(self._handle.root, f"cell_lengths_{group_id}").attrs[
314 "units"
315 ] = "nanometers"
316 getattr(self._handle.root, f"cell_angles_{group_id}").attrs[
317 "units"
318 ] = "degrees"
320 if set_velocities:
321 self._create_earray(
322 where="/",
323 name=f"velocities_{group_id}",
324 atom=self.tables.Float32Atom(),
325 shape=(0, self._n_atoms, 3),
326 )
327 getattr(self._handle.root, f"velocities_{group_id}").attrs[
328 "units"
329 ] = "nanometers/picosecond"
331 if set_kineticEnergy:
332 self._create_earray(
333 where="/",
334 name=f"kineticEnergy_{group_id}",
335 atom=self.tables.Float32Atom(),
336 shape=(0,),
337 )
338 getattr(self._handle.root, f"kineticEnergy_{group_id}").attrs[
339 "units"
340 ] = "kilojoules_per_mole"
342 if set_potentialEnergy:
343 self._create_earray(
344 where="/",
345 name=f"potentialEnergy_{group_id}",
346 atom=self.tables.Float32Atom(),
347 shape=(0,),
348 )
349 getattr(self._handle.root, f"potentialEnergy_{group_id}").attrs[
350 "units"
351 ] = "kilojoules_per_mole"
353 if set_temperature:
354 self._create_earray(
355 where="/",
356 name=f"temperature_{group_id}",
357 atom=self.tables.Float32Atom(),
358 shape=(0,),
359 )
360 getattr(self._handle.root, f"temperature_{group_id}").attrs[
361 "units"
362 ] = "kelvin"
364 if set_alchemicalLambda:
365 self._create_earray(
366 where="/",
367 name=f"lambda_{group_id}",
368 atom=self.tables.Float32Atom(),
369 shape=(0,),
370 )
371 self._get_node("/", name=f"lambda_{group_id}").attrs["units"] = (
372 "dimensionless"
373 )
375 def write_into_group(
376 self,
377 group_id: str,
378 coordinates: np.ndarray,
379 time: np.ndarray,
380 cell_lengths: np.ndarray,
381 cell_angles: np.ndarray,
382 topology: md.Topology,
383 ) -> None:
384 # Third Party Imports
385 from mdtraj.utils import ensure_type, in_units_of
387 _check_mode(self.mode, ("w", "a"))
389 if self.mode == "a":
390 try:
391 self._frame_index = len(
392 getattr(self._handle.root, f"coordinates_{group_id}")
393 )
394 self._needs_initialization = False
395 except self.tables.NoSuchNodeError:
396 self._frame_index = 0
397 self._needs_initialization = True
399 # these must be either both present or both absent. since
400 # we're going to throw an error if one is present w/o the other,
401 # lets do it now.
402 if cell_lengths is None and cell_angles is not None:
403 raise ValueError("cell_lengths were given, but no cell_angles")
404 if cell_lengths is not None and cell_angles is None:
405 raise ValueError("cell_angles were given, but no cell_lengths")
407 # if the input arrays are openmm.unit.Quantities, convert them
408 # into md units. Note that this acts as a no-op if the user doesn't
409 # have openmm.unit installed (e.g. they didn't install OpenMM)
410 coordinates = in_units_of(coordinates, None, "nanometers")
411 time = in_units_of(time, None, "picoseconds")
412 cell_lengths = in_units_of(cell_lengths, None, "nanometers")
413 cell_angles = in_units_of(cell_angles, None, "degrees")
415 # do typechecking and shapechecking on the arrays
416 # this ensure_type method has a lot of options, but basically it lets
417 # us validate most aspects of the array. Also, we can upconvert
418 # on defficent ndim, which means that if the user sends in a single
419 # frame of data (i.e. coordinates is shape=(n_atoms, 3)), we can
420 # realize that. obviously the default mode is that they want to
421 # write multiple frames at a time, so the coordinate shape is
422 # (n_frames, n_atoms, 3)
423 coordinates = ensure_type(
424 coordinates,
425 dtype=np.float32,
426 ndim=3,
427 name="coordinates",
428 shape=(None, None, 3),
429 can_be_none=False,
430 warn_on_cast=False,
431 add_newaxis_on_deficient_ndim=True,
432 )
433 (
434 n_frames,
435 n_atoms,
436 ) = coordinates.shape[0:2]
437 time = ensure_type(
438 time,
439 dtype=np.float32,
440 ndim=1,
441 name="time",
442 shape=(n_frames,),
443 can_be_none=True,
444 warn_on_cast=False,
445 add_newaxis_on_deficient_ndim=True,
446 )
447 cell_lengths = ensure_type(
448 cell_lengths,
449 dtype=np.float32,
450 ndim=2,
451 name="cell_lengths",
452 shape=(n_frames, 3),
453 can_be_none=True,
454 warn_on_cast=False,
455 add_newaxis_on_deficient_ndim=True,
456 )
457 cell_angles = ensure_type(
458 cell_angles,
459 dtype=np.float32,
460 ndim=2,
461 name="cell_angles",
462 shape=(n_frames, 3),
463 can_be_none=True,
464 warn_on_cast=False,
465 add_newaxis_on_deficient_ndim=True,
466 )
468 # if this is our first call to write(), we need to create the headers
469 # and the arrays in the underlying HDF5 file
470 if self._needs_initialization:
471 self._initialize_headers(
472 group_id=group_id,
473 n_atoms=n_atoms,
474 set_coordinates=True,
475 set_time=(time is not None),
476 set_cell=(cell_lengths is not None or cell_angles is not None),
477 set_velocities=False,
478 set_kineticEnergy=False,
479 set_potentialEnergy=False,
480 set_temperature=False,
481 set_alchemicalLambda=False,
482 )
483 self._needs_initialization = False
485 # we need to check that that the entries that the user is trying
486 # to save are actually fields in OUR file
488 try:
489 # try to get the nodes for all the fields that we have
490 # which are not None
491 names = [
492 f"coordinates_{group_id}",
493 f"time_{group_id}",
494 f"cell_angles_{group_id}",
495 f"cell_lengths_{group_id}",
496 ]
497 for name in names:
498 contents = locals()[name.replace(f"_{group_id}", "")]
499 if contents is not None:
500 self._get_node(where="/", name=name).append(contents)
501 if contents is None:
502 # for each attribute that they're not saving, we want
503 # to make sure the file doesn't explect it
504 try:
505 self._get_node(where="/", name=name)
506 raise AssertionError()
507 except self.tables.NoSuchNodeError:
508 pass
509 except self.tables.NoSuchNodeError:
510 raise ValueError(
511 "The file that you're trying to save to doesn't "
512 "contain the field %s. You can always save a new trajectory "
513 "and have it contain this information, but I don't allow 'ragged' "
514 "arrays. If one frame is going to have %s information, then I expect "
515 "all of them to. So I can't save it for just these frames. Sorry "
516 "about that :)" % (name, name)
517 )
518 except AssertionError:
519 raise ValueError(
520 "The file that you're saving to expects each frame "
521 "to contain %s information, but you did not supply it."
522 "I don't allow 'ragged' arrays. If one frame is going "
523 "to have %s information, then I expect all of them to. " % (name, name)
524 )
526 self._frame_index += n_frames
527 self.flush()
528 self.write_topology(group_id, topology)
530 def write_topology(
531 self,
532 group_id: str,
533 topology_object: md.Topology,
534 ) -> None:
535 """Writes the topology into the group_id.
537 Args:
538 group_id (str): The name of the group. Normally 'topology' is
539 used for single traj HDF5 files. Can also be 'topology_<traj_num>',
540 where <traj_num> is the traj_num of a trajectory.
541 topology_object (md.Topology): The topology to put into the group.
543 """
544 _check_mode(self.mode, ("w", "a"))
546 try:
547 node = self._handle.get_node("/", name=f"topology_{group_id}")
548 except tables.NoSuchNodeError:
549 pass
550 else:
551 if self.mode != "a":
552 raise Exception(
553 f"File already exists and has trajectory information. "
554 f"Set `overwrite` to True to overwrite."
555 )
556 self._handle.remove_node("/", name=f"topology_{group_id}")
558 # we want to be able to handle the openmm Topology object
559 # here too, so if it's not an mdtraj topology we'll just guess
560 # that it's probably an openmm topology and convert
561 if not isinstance(topology_object, md.Topology):
562 topology_object = md.Topology.from_openmm(topology_object)
564 try:
565 topology_dict = {"chains": [], "bonds": []}
567 for chain in topology_object.chains:
568 chain_dict = {"residues": [], "index": int(chain.index)}
569 for residue in chain.residues:
570 residue_dict = {
571 "index": int(residue.index),
572 "name": str(residue.name),
573 "atoms": [],
574 "resSeq": int(residue.resSeq),
575 "segmentID": str(residue.segment_id),
576 }
578 for atom in residue.atoms:
579 try:
580 element_symbol_string = str(atom.element.symbol)
581 except AttributeError:
582 element_symbol_string = ""
584 residue_dict["atoms"].append(
585 {
586 "index": int(atom.index),
587 "name": str(atom.name),
588 "element": element_symbol_string,
589 }
590 )
591 chain_dict["residues"].append(residue_dict)
592 topology_dict["chains"].append(chain_dict)
594 for atom1, atom2 in topology_object.bonds:
595 topology_dict["bonds"].append([int(atom1.index), int(atom2.index)])
597 except AttributeError as e:
598 raise AttributeError(
599 "topology_object fails to implement the"
600 "chains() -> residue() -> atoms() and bond() protocol. "
601 "Specifically, we encountered the following %s" % e
602 )
604 # actually set the tables
605 try:
606 self._remove_node(where="/", name="topology")
607 except self.tables.NoSuchNodeError:
608 pass
610 data = json.dumps(topology_dict)
611 if not isinstance(data, bytes):
612 data = data.encode("ascii")
614 if self.tables.__version__ >= "3.0.0":
615 self._handle.create_array(
616 where="/", name=f"topology_{group_id}", obj=[data]
617 )
618 else:
619 self._handle.createArray(
620 where="/", name=f"topology_{group_id}", object=[data]
621 )
623 def read(self, traj_num: int):
624 # Third Party Imports
625 from mdtraj.utils import in_units_of
627 def get_field(name, slice, out_units, can_be_none=True):
628 try:
629 node = self._get_node(where="/", name=name + f"_{traj_num}")
630 data = node.__getitem__(slice)
631 in_units = node.attrs.units
632 if not isinstance(in_units, string_types):
633 in_units = in_units.decode()
634 data = in_units_of(data, in_units, out_units)
635 return data
636 except self.tables.NoSuchNodeError:
637 if can_be_none:
638 return None
639 raise
641 out = {
642 "coordinates": get_field(
643 "coordinates",
644 (slice(None), slice(None), slice(None)),
645 out_units="nanometers",
646 can_be_none=False,
647 ),
648 "time": get_field(
649 "time", slice(None), out_units="picoseconds", can_be_none=False
650 ),
651 "cell_lengths": get_field(
652 "cell_lengths",
653 (slice(None), slice(None)),
654 out_units="nanometers",
655 can_be_none=False,
656 ),
657 "cell_angles": get_field(
658 "cell_angles",
659 (slice(None), slice(None)),
660 out_units="degrees",
661 can_be_none=False,
662 ),
663 }
664 return out
666 def read_topology(
667 self,
668 group_id: str,
669 ) -> md.Topology:
670 # Third Party Imports
671 import mdtraj.core.element as elem
673 try:
674 raw = self._get_node("/", name=group_id)[0]
675 if not isinstance(raw, string_types):
676 raw = raw.decode()
677 topology_dict = json.loads(raw)
678 except self.tables.NoSuchNodeError:
679 return None
681 topology = md.Topology()
683 for chain_dict in sorted(
684 topology_dict["chains"], key=operator.itemgetter("index")
685 ):
686 chain = topology.add_chain()
687 for residue_dict in sorted(
688 chain_dict["residues"], key=operator.itemgetter("index")
689 ):
690 try:
691 resSeq = residue_dict["resSeq"]
692 except KeyError:
693 resSeq = None
694 warnings.warn(
695 "No resSeq information found in HDF file, defaulting to zero-based indices"
696 )
697 try:
698 segment_id = residue_dict["segmentID"]
699 except KeyError:
700 segment_id = ""
701 residue = topology.add_residue(
702 residue_dict["name"], chain, resSeq=resSeq, segment_id=segment_id
703 )
704 for atom_dict in sorted(
705 residue_dict["atoms"], key=operator.itemgetter("index")
706 ):
707 try:
708 element = elem.get_by_symbol(atom_dict["element"])
709 except KeyError:
710 element = elem.virtual
711 topology.add_atom(atom_dict["name"], element, residue)
713 atoms = list(topology.atoms)
714 for index1, index2 in topology_dict["bonds"]:
715 topology.add_bond(atoms[index1], atoms[index2])
717 return topology
719 def read_trajs(self) -> Sequence[md.Trajectory]:
720 # Third Party Imports
721 from mdtraj.core.trajectory import Trajectory
723 nodes = [n.name for n in self._handle.list_nodes("/") if n.name != "CVs"]
724 traj_nums = []
725 trajs = {}
726 for node in nodes:
727 traj_nums.extend(re.findall(r"\d+", node))
728 traj_nums = list(sorted(map(int, set(traj_nums))))
729 for traj_num in traj_nums:
730 topology = self.read_topology(f"topology_{traj_num}")
731 data = self.read(traj_num)
732 trajs[traj_num] = Trajectory(
733 xyz=data["coordinates"],
734 topology=topology,
735 time=data["time"],
736 unitcell_lengths=data["cell_lengths"],
737 unitcell_angles=data["cell_angles"],
738 )
739 return trajs
741 def read_traj(self, traj_num: int) -> md.Trajectory:
742 # Third Party Imports
743 from mdtraj.core.trajectory import Trajectory
745 topology = self.read_topology(f"topology_{traj_num}")
746 data = self.read(traj_num)
747 traj = Trajectory(
748 xyz=data["coordinates"],
749 topology=topology,
750 time=data["time"],
751 unitcell_lengths=data["cell_lengths"],
752 unitcell_angles=data["cell_angles"],
753 )
754 return traj
class TrajEnsembleTsel:
    """Accessor that selects trajectories of an ensemble by their ``traj_num``.

    The wrapped object has to provide the mapping ``trajs_by_traj_num``, the
    method ``_pyemma_indexing_tsel``, and indexing with an array of positions.

    """

    def __init__(self, other):
        """Store the ensemble that this accessor selects from."""
        self.other = other

    def __getitem__(self, item: CanBeIndex) -> Union[TrajEnsemble, SingleTraj]:
        """Select by ``traj_num``.

        Args:
            item (CanBeIndex): Either a single integer (builtin or numpy),
                a one-dimensional list or array of ``traj_nums``, or a
                higher-dimensional array, which is handed over to the
                ``_pyemma_indexing_tsel`` method of the ensemble.

        Returns:
            Union[TrajEnsemble, SingleTraj]: The entry of
                ``trajs_by_traj_num`` for a single integer, otherwise the
                result of indexing the ensemble.

        Raises:
            ValueError: If none of the requested ``traj_nums`` is in the
                ensemble or if ``item`` has an unsupported type.
            NotImplementedError: If ``item`` is a slice.

        """
        items = np.array(list(self.other.trajs_by_traj_num.keys()))

        # np.integer covers all numpy integer types, not only np.int64
        if isinstance(item, (int, np.integer)):
            if item not in items:
                raise ValueError(
                    f"No trajectories with traj_num {item} in TrajEnsemble {self.other} "
                    f"with trajectories: {items}"
                )
            return self.other.trajs_by_traj_num[item]

        if isinstance(item, slice):
            raise NotImplementedError("Currently can't index trajs with slice.")

        if not isinstance(item, (list, np.ndarray)):
            raise ValueError(
                f"The `tsel[]` method of `TrajEnsemble` takes {CanBeIndex} types, "
                f"but {type(item)} was provided."
            )

        requested = np.asarray(item)
        if requested.ndim != 1:
            return self.other._pyemma_indexing_tsel(item)

        # positions of the requested traj_nums, in the order of the ensemble
        idx = np.where(np.isin(items, requested))[0]
        if len(idx) == 0:
            raise ValueError(
                f"No trajs with traj_nums {item} in TrajEnsemble {self.other} "
                f"with trajectories: {items}"
            )
        return self.other[idx]
790class TrajEnsemble:
791 """A fancy list of single trajectories. Topologies can be different across trajs.
793 Check out http://statisticalbiophysicsblog.org/?p=92 for why trajectory ensembles are awesome.
795 This class is a fancy list of :obj:`encodermap.trajinfo.info_single.SingleTraj`.
796 Trajectories can have different topologies and will be grouped by
797 the ``common_str`` argument. Each trajectory has its own unique ``traj_num``,
798 which identifies it in the ensemble - even when the ensemble is sliced or
799 subsampled.
801 Examples:
802 >>> import encodermap as em
803 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG")
804 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF")
806 Addition of two :obj:`encodermap.trajinfo.info_single.SingleTraj` also creates an ensemble.
808 >>> trajs = traj1 + traj2
809 >>> trajs # doctest: +ELLIPSIS
810 <encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajectories. Common str is ['1YUG', '1YUF']. Not containing any CVs...>
812 Indexing a :obj:`TrajEnsemble` returns a :obj:`encodermap.trajinfo.info_single.SingleTraj`
813 based on its 0-based index. Think of the :obj:`TrajEnsemble` as a list of
814 :obj:`encodermap.trajinfo.info_single.SingleTraj`. But trajectories can also
815 have ``traj_nums``, which do not have to adhere to ``[0, 1, 2, ...]``. This
816 is similar to how a :obj:`pandas.DataFrame` offers indexing via ``.loc[]``
817 and ``.iloc[]`` (https://pandas.pydata.org/docs/user_guide/indexing.html#different-choices-for-indexing).
818 For indexing trajs based on their ``traj_num``, you can use the ``.tsel[]``
819 accessor of the :obj:`TrajEnsemble`
821 Examples:
822 >>> import encodermap as em
823 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG")
824 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF")
826 Addition of two `SingleTraj` also creates an ensemble.
828 >>> trajs = traj1 + traj2
829 >>> trajs.traj_nums
830 [0, 1]
832 Change the ``traj_num`` of ``traj2``
834 >>> trajs[1].traj_num = 4
835 >>> trajs.traj_nums
836 [0, 4]
837 >>> trajs[1] # doctest: +ELLIPSIS
838 <encodermap.SingleTraj object. Currently not in memory. Basename is '1YUF'. Not containing any CVs. Common string is '1YUF'. Object at ...>
839 >>> trajs.tsel[4] # doctest: +ELLIPSIS
840 <encodermap.SingleTraj object. Currently not in memory. Basename is '1YUF'. Not containing any CVs. Common string is '1YUF'. Object at ...>
842 :obj:`TrajEnsemble` supports fancy indexing. You can slice to your liking
843 (``trajs[::5]`` returns a :obj:`TrajEnsemble` object that only considers
844 every fifth frame). Besides indexing by slices and integers, you can pass a
845 2-dimensional :obj:`numpy.ndarray`. ``np.array([[0, 5], [1, 10], [5, 20]])``
846 will return a :obj:`TrajEnsemble` object with frame 5 of trajectory 0, frame
847 10 of trajectory 1 and frame 20 of trajectory 5.
849 Examples:
850 >>> import encodermap as em
851 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG")
852 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF")
853 >>> trajs = traj1 + traj2
854 >>> sel = trajs[[[0, 0], [0, 1], [0, 2], [1, 10]]]
855 >>> sel # doctest: +ELLIPSIS
856 <encodermap.TrajEnsemble object. Current backend is no_load. Containing 4 frames and 2 trajectories. Common str is...>
859 The :obj:`TrajEnsemble` class also is an iterator to iterate over trajectories.
860 Besides plain iteration, the :obj:`TrajEnsemble` also offers alternate iterators.
861 The ``itertrajs()`` iterator returns a two-tuple of ``traj_num`` and ``traj``.
862 The ``iterframes()`` iterator returns a three-tuple of ``traj_num``,
863 ``frame_num``, and ``traj``.
865 Examples:
866 >>> import encodermap as em
867 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG")
868 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF")
869 >>> trajs = traj1 + traj2
870 >>> trajs[1].traj_num = 4
871 >>> for traj_num, traj in trajs.itertrajs():
872 ... print(traj_num, traj.n_frames)
873 0 15
874 4 16
875 >>> for traj_num, frame_num ,traj in trajs.subsample(10).iterframes():
876 ... print(traj_num, frame_num, traj.n_frames)
877 0 0 1
878 0 10 1
879 4 0 1
880 4 10 1
882 The :obj:`TrajEnsemble` has multiple alternative constructors. The
883 :meth:`with_overwrite_trajnums` constructor fixes inhomogeneous sequences of
884 :obj:`encodermap.trajinfo.info_single.SingleTraj` and :obj:`TrajEnsemble`.
886 Examples:
887 >>> import encodermap as em
888 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG", traj_num=0)
889 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF", traj_num=0)
890 >>> trajs = em.TrajEnsemble([traj1, traj2]) # doctest: +IGNORE_EXCEPTION_DETAIL, +ELLIPSIS, +NORMALIZE_WHITESPACE
891 Traceback (most recent call last):
892 ...
893 Exception: The `traj_num` attributes of the provided 2 `SingleTraj`s is not unique, the `traj_num` 0 occurs 2 times. This can happen, if you use `SingleTraj`s, that are already part of a `TrajEnsemble`. To create copies of the `SingleTraj`s and overwrite their `traj_num`s, use the `with_overwrite_trajnums()` constructor.
894 >>> trajs = em.TrajEnsemble.with_overwrite_trajnums(traj1, traj2)
895 >>> trajs # doctest: +ELLIPSIS
896 <encodermap.TrajEnsemble...>
898 The :meth:`from_dataset` constructor can be used to load an ensemble from
899 an ``.h5`` file
901 Examples:
902 >>> import encodermap as em
903 >>> from tempfile import TemporaryDirectory
904 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG")
905 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF")
906 >>> trajs = em.TrajEnsemble([traj1, traj2])
907 >>> with TemporaryDirectory() as td:
908 ... trajs.save(td + "/trajs.h5")
909 ... new = em.TrajEnsemble.from_dataset(td + "/trajs.h5")
910 ... print(new) # doctest: +ELLIPSIS
911 encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajectories. Common str is...Not containing any CVs.
913 Attributes:
914 CVs (dict[str, np.ndarray]): The collective variables of the ``SingleTraj``
915 classes. Only CVs with matching names in all ``SingleTraj`` classes
916 are returned. The data is stacked along a hypothetical time axis
917 along the trajs.
918 _CVs (xarray.Dataset): The same data as in CVs but with labels.
919 Additionally, the xarray is not stacked along the time axis.
920 It contains an extra dimension for trajectories.
921 n_trajs (int): Number of individual trajectories in this class.
922 n_frames (int): Number of frames, sum over all trajectories.
923 locations (list[str]): A list with the locations of the trajectories.
924 top (list[mdtraj.Topology]): A list with the reference pdb for each trajectory.
925 basenames (list[str]): A list with the names of the trajectories.
926 The leading path and the file extension is omitted.
927 name_arr (np.ndarray): An array with ``len(name_arr) == n_frames``.
928 This array keeps track of each frame in this object by identifying each
929 frame with a filename. This can be useful, when frames are mixed inside
930 a :obj:`TrajEnsemble` class.
932 """
def __init__(
    self,
    trajs: Union[
        Sequence[str],
        Sequence[Path],
        Sequence[md.Trajectory],
        Sequence[SingleTraj],
    ],
    tops: Union[None, Sequence[str], Sequence[Path]] = None,
    backend: Literal["mdtraj", "no_load"] = "no_load",
    common_str: Optional[Sequence[str]] = None,
    basename_fn: Optional[Callable[[str], str]] = None,
    traj_nums: Optional[Sequence[int]] = None,
    custom_top: Optional[CustomAAsDict] = None,
) -> None:
    """Instantiate the :obj:`TrajEnsemble` class with two lists of files.

    Args:
        trajs (Union[Sequence[str], Sequence[md.Trajectory],
            Sequence[SingleTraj], Sequence[Path]]): List of strings with
            paths to trajectories. Can also be a list of md.Trajectory or
            em.SingleTraj.
        tops (Optional[list[str]]): List of strings with paths to reference pdbs.
            An empty list raises an Exception; pass None instead.
        backend (str, optional): Choose the backend to load trajectories:
            - 'mdtraj' uses mdtraj, which loads all trajectories into RAM.
            - 'no_load' creates an empty trajectory object.
            Defaults to 'no_load', which makes the instantiation of large
            ensembles fast and RAM efficient.
        common_str (list[str], optional): If you want to include trajectories with
            different topology. The common string is used to pair traj-files
            (``.xtc, .dcd, .lammpstrj, ...``) with their topology
            (``.pdb, .gro, ...``). The common-string should be a substring
            of matching traj and topology files. A single string is wrapped
            into a list.
        basename_fn (Union[None, Callable[[str], str]], optional): A function
            to apply to the trajectory file path string to return the basename
            of the trajectory. If None is provided, the filename without
            extension will be used. When all files are named the same and
            the folder they're in defines the name of the trajectory, you
            can supply ``lambda x: x.split('/')[-2]`` as this argument.
            Defaults to None.
        traj_nums (Optional[Sequence[int]]): Explicit ``traj_num`` values, one
            per entry in ``trajs``. When ``trajs`` are :obj:`SingleTraj`
            instances, they are deep-copied before the numbers are assigned.
            Defaults to None.
        custom_top: Optional[CustomAAsDict]: An instance of the
            :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology` or a
            dictionary that can be made into such.

    Raises:
        Exception: If ``tops`` is an empty list, if ``traj_nums`` and ``trajs``
            differ in length, or if the provided :obj:`SingleTraj` instances
            carry mixed (None and int) or duplicated ``traj_num`` values.

    """
    # defaults
    # Local Folder Imports
    from .info_single import SingleTraj

    # check if h5file might be a complete dataset
    # NOTE(review): Python raises a TypeError when ``__init__`` returns
    # anything other than None, so this branch presumably cannot hand the
    # dataset-built ensemble back to the caller -- confirm whether a
    # ``__new__`` or a factory elsewhere covers the single-``.h5`` case.
    if isinstance(trajs, (str, Path)):
        if Path(trajs).suffix == ".h5":
            return TrajEnsemble.from_dataset(trajs, basename_fn=basename_fn)

    # an explicitly empty topology list is treated as a user error
    if tops == []:
        raise Exception(
            f"Your list of topology files is empty: {tops=}. Pass None, if "
            f"your trajectories are all .pdb/.gro files."
        )

    self.backend = backend

    # custom topology to load dihedral angles
    self._custom_top = custom_top

    # set the trajnums
    if traj_nums is not None:
        # Standard Library Imports
        from copy import deepcopy

        if not len(traj_nums) == len(trajs):
            raise Exception(
                f"Uneven length of `traj_nums` ({len(traj_nums)} "
                f"and `trajs` ({len(trajs)}) provided."
            )
        # copy the SingleTrajs, so the caller's objects keep their traj_num
        if all([isinstance(t, SingleTraj) for t in trajs]):
            trajs_ = []
            for n, t in zip(traj_nums, trajs):
                t = deepcopy(t)
                t.traj_num = n
                trajs_.append(t)
            trajs = trajs_

    # make sure, that traj_nums are not duplicated
    elif all([isinstance(t, SingleTraj) for t in trajs]) and isinstance(
        trajs, (list, tuple)
    ):
        # a mixture of None and int traj_nums can't be ordered
        if any([t.traj_num is None for t in trajs]) and any(
            [isinstance(t.traj_num, int) for t in trajs]
        ):
            raise Exception(
                f"The `SingleTraj`s you provided have bad `traj_num`s "
                f"one has `None`, the others have int: {[t.traj_num for t in trajs]}"
            )
        if not all([(i.traj_num is None) for i in trajs]):
            uniques, counts = np.unique(
                np.asarray([t.traj_num for t in trajs]), return_counts=True
            )
            if np.any(counts > 1):
                ex_num = uniques[np.argmax(counts)]
                raise Exception(
                    f"The `traj_num` attributes of the provided {len(trajs)} `SingleTraj`s are "
                    f"not unique, the `traj_num` {ex_num} occurs {np.max(counts)} times. "
                    f"This can happen, if you use `SingleTraj`s, that are already part of "
                    f"a `TrajEnsemble`. To create copies of the `SingleTraj`s and over"
                    f"write their `traj_num`s, use the `with_overwrite_trajnums()` constructor."
                )
            # keep the ensemble ordered by traj_num
            trajs = list(sorted(trajs, key=lambda x: x.traj_num))

    # basename function
    if basename_fn is None:
        basename_fn = lambda x: os.path.basename(x).split(".")[0]
    self.basename_fn = basename_fn

    # common string
    if common_str is None:
        common_str = []
    if isinstance(common_str, str):
        self.common_str = [common_str]
    else:
        self.common_str = common_str

    # loading with setters
    if tops is None:
        tops = []
    self._top_files = tops
    # without explicit tops, trajectory files with a topology extension
    # serve as their own topology
    if all([isinstance(traj, str) for traj in trajs]):
        if self._top_files == [] and all(
            ["." + top.split(".")[-1] in _TOPOLOGY_EXTS for top in trajs]
        ):
            self._top_files = trajs
    if all([isinstance(traj, Path) for traj in trajs]) and not tops:
        self._top_files = [str(t) for t in trajs]
    if isinstance(tops, str):
        self._top_files = [tops]
    # the ``traj_files`` setter builds ``self.trajs``
    self.traj_files = trajs

    # set the traj nums afterwards
    if traj_nums is not None:
        for i, traj in zip(traj_nums, self.trajs):
            traj.traj_num = i
@classmethod
def with_overwrite_trajnums(
    cls,
    *trajs: Union[TrajEnsemble, "SingleTraj"],
) -> TrajEnsemble:
    """Build a :obj:`TrajEnsemble` from renumbered copies of the inputs.

    Every provided :obj:`encodermap.trajinfo.info_single.SingleTraj` (and
    every trajectory yielded by a provided :obj:`TrajEnsemble`) is
    deep-copied and the copies get consecutive ``traj_num`` values
    ``0, 1, 2, ...`` in the order in which they were passed.

    Args:
        trajs (Union[TrajEnsemble, SingleTraj]): The trajectories and/or
            ensembles to combine.

    Returns:
        TrajEnsemble: A :obj:`TrajEnsemble` instance.

    Raises:
        TypeError: If an argument is neither a ``SingleTraj`` nor a
            ``TrajEnsemble``.

    """
    # Standard Library Imports
    from copy import deepcopy

    # Local Folder Imports
    from .info_single import SingleTraj

    # flatten the arguments into one sequence of single trajectories
    flat = []
    for t in trajs:
        if isinstance(t, SingleTraj):
            flat.append(t)
        elif isinstance(t, TrajEnsemble):
            flat.extend(t)
        else:
            raise TypeError(
                f"Classmethod `with_overwrite_trajnums` can only accept `SingleTraj` "
                f"and `TrajEnsemble`, but {t} is not an instance of either."
            )

    # copy and renumber
    renumbered = []
    for new_num, original in enumerate(flat):
        clone = deepcopy(original)
        clone.traj_num = new_num
        renumbered.append(clone)
    return cls(renumbered)
@classmethod
def from_textfile(
    cls,
    fname: Union[str, Path],
    basename_fn: Optional[Callable[[str], str]] = None,
) -> TrajEnsemble:
    """Creates a :obj:`TrajEnsemble` object from a textfile.

    The textfile needs to be whitespace-separated with two to four columns:
        - Column 1:
            The trajectory file.
        - Column 2:
            The corresponding topology file (If you are using ``.h5`` trajs,
            column 1 and 2 will be identical, but column 2 needs to be there
            nonetheless).
        - Column 3:
            The ``traj_num`` of the trajectory as an integer. If this column
            is missing or can't be parsed as an integer, ``None`` is used.
        - Column 4:
            The common string of the trajectory. This column can be left
            out, which results in an empty common string for that row.

    Blank lines are skipped.

    NOTE(review): Earlier documentation described column 3 as the common
    string. The parser reads column 3 as ``traj_num`` and column 4 as the
    common string -- confirm against the code that writes these files.

    Args:
        fname (Union[str, Path]): File to be read.
        basename_fn (Union[None, Callable[[str], str]], optional): A function
            to apply to the ``traj_file`` string to return the basename of
            the trajectory. If None is provided, the filename without
            extension will be used. When all files are named the same and
            the folder they're in defines the name of the trajectory, you
            can supply ``lambda x: x.split('/')[-2]`` as this argument.
            Defaults to None.

    Returns:
        TrajEnsemble: A :obj:`TrajEnsemble` instance.

    Raises:
        IndexError: If a non-blank row has fewer than two columns.

    """
    # Local Folder Imports
    from ..trajinfo import info_single

    # parse the file first, so it is closed before trajectories are built
    records = []
    with open(fname, "r") as f:
        for row in f:
            columns = row.split()
            if not columns:
                # tolerate blank (e.g. trailing) lines
                continue
            if len(columns) < 2:
                raise IndexError(
                    f"Every row in {fname} needs at least a trajectory and a "
                    f"topology column, but got: {row!r}"
                )
            try:
                traj_num = int(columns[2])
            except (IndexError, ValueError):
                # column missing or not an integer
                traj_num = None
            cs = columns[3] if len(columns) > 3 else ""
            records.append((columns[0], columns[1], cs, traj_num))

    trajs = [
        info_single.SingleTraj(traj_file, top_file, cs, traj_num)
        for traj_file, top_file, cs, traj_num in records
    ]
    common_str = [record[2] for record in records]

    return cls(
        trajs, common_str=np.unique(common_str).tolist(), basename_fn=basename_fn
    )
@classmethod
def from_dataset(
    cls,
    fname: Union[str, Path],
    basename_fn: Optional[Callable[[str], str]] = None,
) -> TrajEnsemble:
    """Creates a :obj:`TrajEnsemble` from a single HDF5 file.

    The top-level keys of the file (except ``"CVs"``) are scanned for digit
    groups. Every distinct digit group is interpreted as a ``traj_num`` and
    one ``SingleTraj`` is created per number, using ``fname`` as both
    trajectory and topology file.

    Args:
        fname (Union[str, Path]): The HDF5 file to read.
        basename_fn (Union[None, Callable[[str], str]], optional): Forwarded
            to the ``SingleTraj`` instances and the new ensemble.
            Defaults to None.

    Returns:
        TrajEnsemble: A :obj:`TrajEnsemble` instance.

    """
    # Local Folder Imports
    from .info_single import SingleTraj

    # collect all digit groups in the key names
    digit_groups = []
    with h5py.File(fname) as h5file:
        for key in h5file.keys():
            if key != "CVs":
                digit_groups.extend(re.findall(r"\d+", key))
    ordered_nums = sorted(int(num) for num in set(digit_groups))

    members = [
        SingleTraj(
            traj=fname,
            top=fname,
            traj_num=number,
            basename_fn=basename_fn,
        )
        for number in ordered_nums
    ]
    shared_strings = list({member.common_str for member in members})
    return cls(trajs=members, common_str=shared_strings, basename_fn=basename_fn)
def load_custom_topology(
    self,
    custom_top: Optional[Union[CustomTopology, CustomAAsDict]] = None,
) -> None:
    """Forwards a custom topology to every trajectory in this ensemble.

    See Also:
        `CustomTopology`

    Args:
        custom_top: Optional[Union[CustomTopology, CustomAAsDict]]: An instance
            of the `CustomTopology` class or a dictionary that can be made
            into such. It is passed unchanged to the ``load_custom_topology``
            method of each trajectory.

    """
    members = self.trajs
    for member in members:
        member.load_custom_topology(custom_top)
@property
def tsel(self):
    """TrajEnsembleTsel: A new ``TrajEnsembleTsel`` wrapping this ensemble."""
    selector = TrajEnsembleTsel(self)
    return selector
def del_featurizer(self) -> None:
    """Removes the cached ``_featurizer`` attribute, if there is one."""
    if not hasattr(self, "_featurizer"):
        return
    del self._featurizer
@property
def featurizer(self):
    """Featurizer: The featurizer of this ensemble.

    The instance is created on first access and cached as
    ``self._featurizer``.

    """
    # Local Folder Imports
    from ..loading.featurizer import Featurizer

    if hasattr(self, "_featurizer"):
        return self._featurizer
    self._featurizer = Featurizer(self)
    return self._featurizer
@property
def traj_files(self) -> list[str]:
    """list: The trajectory files stored in ``self._traj_files``."""
    stored_files = self._traj_files
    return stored_files
@property
def top_files(self) -> list[str]:
    """list: Returns the topology files without duplicates.

    The order of first occurrence is kept. If you want a list of top files
    with the same length as self.trajs, use self._top_files and
    self._traj_files.

    """
    first_occurrences = dict.fromkeys(self._top_files)
    return [top_file for top_file in first_occurrences]
@traj_files.setter
def traj_files(self, trajs):
    """Builds ``self.trajs`` from the provided trajectories.

    Depending on the type of the items in ``trajs``:
        - ``Path``: Converted to ``str`` and then treated like strings.
        - ``mdtraj.Trajectory``: The backend is set to ``"mdtraj"`` and each
          trajectory is wrapped in a ``SingleTraj``.
        - ``SingleTraj``: Used as they are. ``_top_files`` and
          ``_traj_files`` are taken from them. Instances without a
          ``traj_num`` get their position in ``trajs`` as ``traj_num``.
        - ``str`` with topology files: Files are paired with
          ``get_full_common_str_and_ref`` and loaded with ``self.backend``.
        - ``str`` without topology files: Each file is loaded on its own.

    Args:
        trajs: A sequence of ``str``, ``Path``, ``mdtraj.Trajectory`` or
            ``SingleTraj`` instances. All items need to be of the same type.

    Raises:
        TypeError: If the items in ``trajs`` are of an unsupported type or
            of mixed types.

    """
    # Local Folder Imports
    from ..trajinfo import info_single

    traj_nums = np.arange(len(trajs))
    # fill this lists
    self.trajs = []

    if all([isinstance(traj, Path) for traj in trajs]):
        trajs = [str(traj) for traj in trajs]

    if all([isinstance(i, md.Trajectory) for i in trajs]):
        self.backend = "mdtraj"
        self.trajs = [
            info_single.SingleTraj(
                traj,
                traj_num=i,
                basename_fn=self.basename_fn,
                custom_top=self._custom_top,
            )
            for i, traj in zip(traj_nums, trajs)
        ]
    elif all([i.__class__.__name__ == "SingleTraj" for i in trajs]):
        self.trajs = trajs
        self._top_files = [traj.top_file for traj in self.trajs]
        self._traj_files = [traj.traj_file for traj in self.trajs]
        # check backends and common str
        if (
            not all([traj.backend == "no_load" for traj in trajs])
            or self.backend == "mdtraj"
        ):
            # This used to be a generator expression that was never consumed,
            # so ``load_traj`` was never called. Only trajectories that are
            # not loaded yet are loaded here.
            for traj in trajs:
                if traj.backend == "no_load":
                    traj.load_traj()
        for i, traj in zip(traj_nums, trajs):
            if traj.traj_num is None:
                traj.traj_num = i
                # keep the traj_num coordinate of the CVs in sync
                if traj._CVs:
                    traj._CVs = traj._CVs.assign_coords(traj_num=[i])
    elif all([isinstance(i, str) for i in trajs]) and self.top_files:
        # find common_str matches in top_files and traj_files
        (
            self._traj_files,
            self._top_files,
            self._common_str,
        ) = get_full_common_str_and_ref(trajs, self._top_files, self.common_str)
        for i, t, top, cs in zip(
            traj_nums, self._traj_files, self._top_files, self._common_str
        ):
            t = info_single.SingleTraj(
                traj=t,
                top=top,
                backend=self.backend,
                common_str=cs,
                traj_num=i,
                basename_fn=self.basename_fn,
                custom_top=self._custom_top,
            )
            self.trajs.append(t)
    elif all([isinstance(i, str) for i in trajs]) and not self.top_files:
        for i, traj_file in zip(traj_nums, trajs):
            self.trajs.append(
                info_single.SingleTraj(
                    traj=traj_file,
                    basename_fn=self.basename_fn,
                    traj_num=i,
                    custom_top=self._custom_top,
                )
            )
    else:
        raise TypeError(
            "The objects in the list are not of the correct type or inconsistent. "
            f"You provided {[c.__class__.__name__ for c in trajs]}. "
            "Please provide a list of `str`, list of `mdtraj.Trajectory` or list of `SingleTraj`."
        )
@property
def top(self) -> list[md.Topology]:
    """list: Returns the distinct topologies of the trajectories.

    Topologies are compared by equality and listed in order of first
    occurrence. If all trajectories share the same topology, a list
    with len 1 will be returned.

    """
    distinct = []
    for member in self.trajs:
        try:
            topology = member.top
            if topology in distinct:
                continue
            distinct.append(topology)
        except IOError as e:
            # missing files are reported as they are
            if "no such file" not in str(e).lower():
                raise Exception(
                    f"I have no idea how this IOError can occur. {self.trajs=}."
                ) from e
            raise e
    return distinct
@property
def trajs_by_top(self) -> dict[md.Topology, TrajEnsemble]:
    """dict[md.Topology, TrajEnsemble]: Returns the trajs in `self` grouped by top.

    Each value is a new ``TrajEnsemble`` whose ``common_str`` is set to the
    distinct common strings of its trajectories. If all trajectories share
    the same topology, a dict with one key will be returned.

    """
    members_by_top = {}
    for member in self.trajs:
        members_by_top.setdefault(member.top, []).append(member)

    ensembles = {
        topology: TrajEnsemble(members)
        for topology, members in members_by_top.items()
    }
    for ensemble in ensembles.values():
        ensemble.common_str = list({t.common_str for t in ensemble})
    return ensembles
@property
def trajs_by_common_str(self) -> dict[Union[None, str], TrajEnsemble]:
    """dict[str, TrajEnsemble]: Returns the trajs in `self` grouped by common_str.

    Each value is a new ``TrajEnsemble``. If all trajectories share the
    same common_str, a dict with one key will be returned. As the
    common_str can be None, None can also occur as a key in this dict.

    """
    members_by_cs = {}
    for member in self.trajs:
        members_by_cs.setdefault(member.common_str, []).append(member)

    ensembles = {}
    for cs, members in members_by_cs.items():
        ensembles[cs] = TrajEnsemble(members)
    return ensembles
def sidechain_info(self) -> dict[int, dict[int, Sequence[int]]]:
    """Indices used for the AngleDihedralCartesianEncoderMap class to
    allow training with multiple different sidechains.

    Returns:
        dict[int, dict[int, int]]: A nested dict. The key ``-1`` holds, for
            every residue index (starting at 1), the number of
            ``SIDE_DIHEDRALS`` labels of the combined CVs that end with
            that residue index (the hypothetical hull of all feature
            spaces). The other keys are the ``traj_num`` of each trajectory
            and hold, per residue index, the number of non-nan sidechain
            dihedrals in the first frame of that trajectory.
            NOTE(review): the return annotation declares ``Sequence[int]``
            as innermost values, but the code stores integer counts.

    Raises:
        NotImplementedError: If any trajectory has a ``clustal_w`` attribute.
        Exception: If required CVs are missing, or if the CV labels contain
            residue names (i.e., the CVs were not loaded with
            ``ensemble=True``).
        AssertionError: If two trajectories share the same ``traj_num``.

    """
    # make sure that no clustal w alignment has been loaded
    if any([hasattr(t, "clustal_w") for t in self.trajs]):
        raise NotImplementedError(
            f"This is currently not supported for TrajEsnembles with "
            f"clustal_w alignments."
        )
    else:
        max_residues = max([t.n_residues for t in self])

    # make sure CVs are loaded and contain the appropriate values
    should_be = {
        "central_cartesians",
        "central_dihedrals",
        "central_distances",
        "central_angles",
        "side_dihedrals",
        "side_cartesians",
        "side_distances",
        "side_angles",
    }
    diff = should_be - set(self._CVs.data_vars.keys())
    if len(diff) > 0:
        raise Exception(
            f"The TrajEnsemble misses these CVs to calculate the sidechain_info: "
            f"{list(diff)}. Please load them with `trajs.load_CVs({list(diff)})`."
        )

    # make sure we are using an ensemble with generic indices
    # i.e., no residue name of any topology occurs in the feature labels
    forbidden_names = set(
        [residue.name for traj in self for residue in traj.top.residues]
    )
    if (
        len(
            (
                offending := [
                    label
                    for label in self._CVs.central_distances.coords[
                        "CENTRAL_DISTANCES"
                    ].values
                    if any(f in label for f in forbidden_names)
                ]
            )
        )
        > 0
    ):
        raise Exception(
            f"The CVs in this TrajEnsemble were not loaded with the `ensemble=True` "
            f"keyword in `trajs.load_CVs()`. Finding the sidechain_info in "
            f"such a set of CVs is not possible. The offending labels {offending} "
            f"contain residue names, which should not occur if CVs were loaded "
            f"with `ensemble=True`."
        )

    # the key -1 is the feature hull, telling tensorflow
    # how to create the branched chain of backbone and sidechains
    # i.e. how many sidechains there are per residue max
    # the other keys correspond to which sidechain atoms are non nan and
    # can be used when calculating the distance matrices per different atom
    out = {-1: {}}

    # the feature hull
    max_sidechains = self._CVs.side_dihedrals.coords["SIDE_DIHEDRALS"].values
    for residx in range(1, max_residues + 1):
        # labels are matched by their trailing residue index
        labels = [l for l in max_sidechains if l.endswith(" " + str(residx))]
        out[-1][residx] = len(labels)

    # for every traj
    for traj in self.trajs:
        assert (
            traj.traj_num not in out
        ), f"This traj has the same `traj_num` as another traj."
        out[traj.traj_num] = {}
        for residx in range(1, max_residues + 1):
            labels = [
                l
                for l in traj._CVs.side_dihedrals.coords["SIDE_DIHEDRALS"].values
                if l.endswith(" " + str(residx))
            ]
            # count the non-nan sidechain dihedrals in the first frame
            first_frame = traj._CVs.coords["frame_num"].min()
            length = (
                traj._CVs.side_dihedrals.sel(
                    frame_num=first_frame, SIDE_DIHEDRALS=labels
                )
                .dropna(dim="SIDE_DIHEDRALS")
                .sizes["SIDE_DIHEDRALS"]
            )
            out[traj.traj_num][residx] = length
    return out
@property
def trajs_by_traj_num(self) -> dict[int, SingleTraj]:
    """dict[int, SingleTraj]: Maps each ``traj_num`` to its trajectory.

    If several trajectories share a ``traj_num``, the last one wins.

    """
    return {member.traj_num: member for member in self}
@property
def id(self) -> np.ndarray:
    """np.ndarray: Alias of ``self.index_arr``."""
    index = self.index_arr
    return index
@property
def n_residues(self) -> list[int]:
    """list: The number of residues of each trajectory in ``self.trajs``.

    The list has one entry per trajectory and follows the order of
    ``self.trajs``.

    """
    return [traj.n_residues for traj in self.trajs]
@property
def basenames(self) -> list[str]:
    """list: The ``basename`` of each trajectory, in the order of ``self.trajs``."""
    names = [member.basename for member in self.trajs]
    return names
@property
def traj_nums(self) -> list[int]:
    """list: The ``traj_num`` of each trajectory, in the order of ``self.trajs``."""
    numbers = [member.traj_num for member in self.trajs]
    return numbers
@property
def n_trajs(self) -> int:
    """int: The length of ``self.trajs``."""
    count = len(self.trajs)
    return count
def to_alignment_query(self) -> str:
    """A string, that can be put into sequence alignment software.

    For every common string in ``self.trajs_by_common_str``, the sequences
    returned by ``to_fasta()`` of the (single) topology of that group are
    written as FASTA records: a ``>{common_str}`` header line followed by
    the sequence line.

    Returns:
        str: The FASTA-formatted query. Empty, if there are no trajectories.

    Raises:
        AssertionError: If the trajectories of one common string have more
            than one topology.

    """
    groups = self.trajs_by_common_str
    for cs, trajs in groups.items():
        assert len(trajs.top) == 1, (
            f"Can't provide a query for a `TrajEnsemble`, where a common string "
            f"has multiple topologies. In this case, the common string '{cs}' "
            f"has these topologies: {trajs.top}. When you are unhappy with how "
            f"EncoderMap automatically applies common strings to trajectory "
            f"filenames, keep in mind that you can always generate a `TrajEnsemble` "
            f"from multiple `TrajEnsembles`. You can do: \n"
            f"trajs1 = em.load(\n"
            f"    traj_files1,\n"
            f"    top_files1,\n"
            f"    common_str=['trajs1'],\n"
            f")\n"
            f"trajs2 = em.load(\n"
            f"    traj_files2,\n"
            f"    top_files2,\n"
            f"    common_str=['trajs2'],\n"
            f")\n"
            f"trajs = trajs1 + trajs2\n"
            f"to force a `TrajEnsemble` to adhere to your common strings."
        )
    # The records used to be built but never added to the output, so the
    # method always returned an empty string.
    records = []
    for cs, trajs in groups.items():
        for sequence in trajs.top[0].to_fasta():
            records.append(f">{cs}\n{sequence}\n")
    return "".join(records)
def parse_clustal_w_alignment(self, aln: str) -> None:
    """Parse an alignment in ClustalW format and add the info to the trajectories.

    The raw alignment is stored as ``self.clustal_w``. For every trajectory,
    the last whitespace-separated field of all lines starting with that
    trajectory's identifier is concatenated and stored as
    ``traj.clustal_w``. If the alignment contains ``"|"``, the identifier is
    ``{common_str}|TrajNum_{traj_num}Topology_{i}Chain_{j}``, otherwise it
    is the common string itself.

    Args:
        aln (str): The alignment in ClustalW format.

    Raises:
        AssertionError: If the number of non-empty lines does not fit the
            number of trajectories (or common strings), or if the number
            of non-gap characters of a parsed sequence differs from the
            number of residues of the trajectory.

    """
    self.clustal_w = aln
    # remove empty lines
    aln = "\n".join(
        list(filter(lambda x: not re.match(r"^\s*$", x), aln.splitlines()))
    )

    # a block is expected to consist of one line per trajectory (or per
    # common string) plus one line with the score characters
    lines = aln.splitlines()
    n_lines = len(lines)
    assert (
        n_lines % (self.n_trajs + 1) == 0
        or n_lines % (len(self.common_str) + 1) == 0
    ), (
        f"The CLUSTAL W aln string, that you provided has the wrong number of "
        f"lines. I've expected to receive a multiple of {self.n_trajs + 1} ("
        f"which is the number of trajs ({self.n_trajs}) plus one for the score-"
        f"characters ' ', '.', ':', '*'), but the number of provided lines was "
        f"{n_lines}."
    )

    if "|" in aln:
        # identifiers carry traj_num, topology index and chain index
        for i, (_, sub_trajs) in enumerate(self.trajs_by_top.items()):
            for cs, trajs in sub_trajs.trajs_by_common_str.items():
                for traj in trajs:
                    for j, chain in enumerate(traj.top.chains):
                        search = (
                            f"{cs}|TrajNum_{traj.traj_num}Topology_{i}Chain_{j}"
                        )
                        # stitch the sequence together from all blocks
                        data = ""
                        for line in lines:
                            if line.startswith(search):
                                data += line.split()[-1]
                        # gaps ("-") don't count as residues
                        assert (
                            test := len(data.replace("-", ""))
                        ) == traj.n_residues, (
                            f"The CLUSTAL W sequence {data} with {test} one-letter "
                            f"residues has not the same number of residues as trajectory "
                            f"{traj}, which has {traj.n_residues}"
                        )
                        traj.clustal_w = data
    else:
        # identifiers are the common strings
        for cs, trajs in self.trajs_by_common_str.items():
            search = cs
            data = ""
            for line in lines:
                if line.startswith(search):
                    data += line.split()[-1]
            # only the first trajectory of the group is checked
            assert (test := len(data.replace("-", ""))) == trajs[0].n_residues, (
                f"The CLUSTAL W sequence {data} with {test} one-letter "
                f"residues has not the same number of residues as trajectory "
                f"{trajs[0]}, which has {trajs[0].n_residues}"
            )
            for traj in trajs:
                traj.clustal_w = data
def del_CVs(self, CVs: Optional[Sequence[str]] = None) -> None:
    """Deletes CVs in all trajs. Does not affect the files.

    Args:
        CVs (Optional[Sequence[str]]): The names of the CVs to delete. A
            single name can be provided as a string. For every name, the
            matching ``{name}_feature_indices`` variable is dropped as
            well. Names that are not present are ignored. If None, all CVs
            of all trajs are deleted. Defaults to None.

    """
    if CVs is None:
        for traj in self.trajs:
            traj.del_CVs()
        return

    if not isinstance(CVs, (list, tuple)):
        CVs = [CVs]
    # Build a new list, so tuples work (they have no ``append``) and the
    # caller's sequence is not modified.
    remove = list(CVs) + [f"{CV}_feature_indices" for CV in CVs]
    for traj in self.trajs:
        traj._CVs = traj._CVs.drop_vars(remove, errors="ignore")
@property
def _CVs(self) -> xr.Dataset:
    """xarray.Dataset: Returns x-array Dataset of matching CVs. stacked
    along the trajectory-axis.

    The datasets of the individual trajectories are combined along
    ``traj_num`` with an outer join and nan as fill value. If any non
    ``feature_indices`` variable of any trajectory contains nans, the
    combined dataset is returned as is. Otherwise, the labels of every
    feature axis that has a matching ``{name}_feature_indices`` variable
    are reordered by the most frequent index value of each label.

    Raises:
        Exception: If a ``feature_indices`` variable has no ``feature_axis``
            attribute, or if the labels of a feature can't be ordered.

    """
    # Local Folder Imports
    from .trajinfo_utils import trajs_combine_attrs

    ds = xr.combine_nested(
        [traj._CVs for traj in self.trajs],
        concat_dim="traj_num",
        compat="broadcast_equals",
        fill_value=np.nan,
        coords="all",
        join="outer",
        combine_attrs=trajs_combine_attrs,
    )

    # if ensemble we don't need to reorder labels. That was already done
    # by the Featurizer. We know if we have an ensemble if the trajs
    # have features with nans (except the "feature_indices" dataarrays
    if any(
        [
            np.any(np.isnan(v.values))
            for traj in self
            for n, v in traj._CVs.data_vars.items()
            if "feature_indices" not in n
        ]
    ):
        return ds

    # sort the combined arrays
    new_label_order = {}
    non_indices_data_vars = [k for k in ds.keys() if "feature_indices" not in k]
    for k in non_indices_data_vars:
        # features without index information can't be sorted
        if (ind_k := f"{k}_feature_indices") not in ds:
            continue
        argsort = []
        da = ds[ind_k]

        # ResidueMinDistanceFeature
        if "RES_NO" in da.coords:
            continue

        try:
            feature_axis_name = da.attrs["feature_axis"]
        except KeyError as e:
            raise Exception(f"There is no feature_axis attribute in {da=}") from e
        labels = da.coords[feature_axis_name].values
        # bring the indices into a form, where iterating yields one entry
        # per feature label
        if "ATOM_NO" not in da.coords and "frame_num" not in da.coords:
            iterable = da.values.T
        elif "ATOM_NO" not in da.coords and "frame_num" in da.coords:
            iterable = (
                da.stack({"frame": ("traj_num", "frame_num")})
                .transpose("frame", ...)
                .dropna("frame", how="all")
            )
            iterable = iterable.transpose(feature_axis_name, "frame").values
        elif "ATOM_NO" in da.coords and "frame_num" not in da.coords:
            iterable = da.transpose(feature_axis_name, "traj_num", "ATOM_NO").values
        else:
            iterable = (
                da.stack({"frame": ("traj_num", "frame_num")})
                .transpose("frame", ...)
                .dropna("frame", how="all")
            )
            iterable = iterable.transpose(feature_axis_name, "frame", "ATOM_NO")
        for i, (row, label) in enumerate(zip(iterable, labels)):
            if (
                "ATOM_NO" not in da.coords and "RES_NO" not in da.coords
            ):  # selection feature
                # the most frequent non-nan index is the sort key
                row = row[~np.isnan(row)]
                uniques, counts = np.unique(row, return_counts=True)
                ind = uniques[np.argmax(counts)]
                argsort.append(ind)
            else:
                # collapse the atom indices into one number per entry;
                # the shift by one keeps index 0 apart from nan entries
                row = np.nan_to_num(row + 1, 0).sum(1)
                row = row[np.nonzero(row)]
                uniques, counts = np.unique(row, return_counts=True)
                # NOTE(review): this raises for every feature axis with
                # "side" in its name; looks like a leftover debugging
                # guard -- confirm.
                if "side" in feature_axis_name:
                    raise Exception(f"{new_label_order=}")
                try:
                    ind = uniques[np.argmax(counts)]
                except ValueError as e:
                    if "attempt to get argmax of an empty" in str(e):
                        raise Exception(
                            f"Can't order the data_var {k}, as the {row=} "
                            f"has become empty ({row=})."
                        )
                    raise e
                argsort.append(ind)
        labels = labels[np.argsort(argsort)]
        new_label_order[feature_axis_name] = labels

    ds = ds.reindex(new_label_order, fill_value=np.nan)
    return ds
def _calc_CV(self) -> dict[str, np.ndarray]:
    """Collects the CVs that are present in all trajectories into a dict.

    For every CV name that occurs in all trajectories, the arrays of the
    individual trajectories are stacked along the frame axis. CVs whose
    trailing dimensions (``shape[2:]``) differ between trajectories are
    skipped with a printed message. A trailing axis of length 1 is
    squeezed.

    Returns:
        dict[str, np.ndarray]: The stacked CVs. Empty, if there are no
            trajectories or if any trajectory has no CVs.

    Raises:
        Exception: If the arrays of a CV can't be concatenated.

    """
    per_traj_CVs = [traj.CVs for traj in self.trajs]
    if not per_traj_CVs or not all(per_traj_CVs):
        return {}

    CVs = {}
    key_sets = [set(cvs.keys()) for cvs in per_traj_CVs]
    matching_keys = list(set.intersection(*key_sets))
    # CVs that are present in some, but not all trajectories. This used to
    # be computed from the intersection and was therefore always empty.
    dropping_keys = set.union(*key_sets).difference(matching_keys)
    if dropping_keys:
        print(
            f"The CVs {dropping_keys} will not be in the `CVs` dictionary, "
            f"as they are only present in some, but not all of the {len(self.trajs)} "
            f"trajectories. You can access them with "
            f"`TrajEnsemble([t for t in trajs if any([cv in {dropping_keys} for cv in t.CVs.keys()])])`"
        )

    for key in matching_keys:
        data = [traj._CVs[key].values for traj in self.trajs]
        # check if all shapes are the same
        shapes = [d.shape[2:] for d in data]
        if not len(set(shapes)) == 1:
            print(
                f"I am not returning the CVs for the feature '{key}' "
                f"because some trajectories have different "
                f"shapes for these CVs. The shapes are {set(shapes)}. "
                f"If you want to access these CVs, use the `_CVs` "
                f"xarray dataset of `TrajEnsemble` instead."
            )
            continue
        if np.all(
            [
                any([isinstance(ind, int) for ind in traj.index])
                for traj in self.trajs
            ]
        ):
            # trajectories indexed with integers are stacked and flattened
            data = np.vstack(data)
            if data.ndim <= 3:
                data = data.reshape(-1, data.shape[-1])
            else:
                data = data.reshape(-1, *data.shape[-2:])
        else:
            try:
                data = np.concatenate([d.squeeze(axis=0) for d in data], axis=0)
            except ValueError as e:
                if "zero-dimensional" in str(e):
                    # Scalars can't be concatenated, but stacked. This has
                    # to be part of the same if/elif chain, otherwise the
                    # error is re-raised after the recovery.
                    data = np.hstack([d.squeeze(axis=0) for d in data])
                elif "all the input arrays must have the same" in str(e):
                    err_shapes = "\n".join(
                        [
                            f"Traj: {self.trajs[i].traj_num:<3} CV '{key}' shape: {d.shape}"
                            for i, d in enumerate(data)
                        ]
                    )
                    raise Exception(
                        f"Can't concatenate the data of the CV '{key}'."
                        f"The shapes of this CV for the individual "
                        f"trajectories are:\n"
                        f"{err_shapes}"
                    ) from e
                else:
                    raise e
        if data.shape[-1] != 1:
            CVs[key] = data
        else:
            CVs[key] = data.squeeze(-1)
    return CVs
@property
def CVs(self) -> dict[str, np.ndarray]:
    """dict: The result of ``self._calc_CV()``.

    Only CVs with the same names in all SingleTraj classes are contained.

    """
    collected = self._calc_CV()
    return collected
@property
def locations(self) -> list[str]:
    """list: The ``traj_file`` attribute of every trajectory in ``self.trajs``.

    Unlike ``self.traj_files``, this does not rely on ``self._traj_files``
    but asks the trajectories themselves.

    """
    files = [member.traj_file for member in self.trajs]
    return files
@property
def index_arr(self) -> np.ndarray:
    """np.ndarray: The ``id`` arrays of all trajectories, stacked vertically.

    According to the original documentation, column 0 holds the traj_num
    and column 1 the frame_num of every loaded frame. Can be used to create
    an unspecified subset of frames and can be useful when used with
    clustering.

    """
    # can also be made to use the SingleTraj.index_arr attribute,
    # but doing it this way the traj is loaded.
    # which might slow down thing significantly
    per_traj_ids = [member.id for member in self.trajs]
    return np.vstack(per_traj_ids)
@property
def name_arr(self) -> np.ndarray:
    """np.ndarray: The basename of the trajectory of every frame.

    Every trajectory contributes its ``basename`` ``n_frames`` times. If
    any trajectory has a falsy ``n_frames``, an empty array is returned.

    """
    frame_counts = [member.n_frames for member in self.trajs]
    if not np.all(frame_counts):
        return np.array([])
    per_frame_names = [
        member.basename for member in self.trajs for _ in range(member.n_frames)
    ]
    return np.array(per_frame_names)
@property
def n_frames(self) -> int:
    """int: The sum of ``n_frames`` over all trajectories in ``self.trajs``."""
    total = sum(member.n_frames for member in self.trajs)
    return total
@property
def frames(self) -> list[int]:
    """list: The ``n_frames`` of each trajectory, in the order of ``self.trajs``."""
    counts = [member.n_frames for member in self.trajs]
    return counts
@property
def CVs_in_file(self) -> bool:
    """bool: True, if ``CVs_in_file`` is truthy for every trajectory.

    Also True for an ensemble without trajectories. Can be used to build a
    data generator from.

    """
    flags = [member.CVs_in_file for member in self.trajs]
    return all(flags)
def _traj_joined(self, progbar: Optional[Any] = None) -> md.Trajectory:
    """Joins all frames of this ensemble into one ``mdtraj.Trajectory``.

    A dummy trajectory with ``self.n_frames`` frames is built from copies
    of the first trajectory. Afterwards, the coordinates of every frame are
    overwritten with the coordinates of the frames yielded by
    ``self.iterframes()``.

    Args:
        progbar (Optional[Any]): If None, a new progress bar is created. If
            a bool is provided, no progress bar is used. Any other object
            is treated as an existing progress bar, which is reset to its
            previous total plus the steps of this method. Defaults to None.

    Returns:
        md.Trajectory: The joined trajectory, using the topology of the
            first trajectory in ``self.trajs``.

    Raises:
        AssertionError: If the dummy trajectory does not have
            ``self.n_frames`` frames.

    """
    # use traj[0] of the trajs list as the traj from which the topology will be used
    parent_traj = self.trajs[0].traj

    # join the correct number of trajs
    # by use of the `divmod` method, the frames parent_traj traj will be
    # appended for a certain amount, until the remainder of the division
    # is met by that time, the parent traj will be sliced to fill the correct number of frames
    no_of_iters, rest = divmod(self.n_frames, parent_traj.n_frames)

    # one step per join iteration plus one step per frame
    total = self.n_frames + no_of_iters + 1
    if progbar is None:
        if _is_notebook():
            progbar = notebook_tqdm(
                total=total,
                leave=False,
                position=0,
                desc="Clustering...",
            )
        else:
            progbar = normal_tqdm(
                total=total,
                leave=False,
                position=0,
                desc="Clustering...",
                function="_traj_joined",
            )
    else:
        if not isinstance(progbar, bool):
            progbar.reset(progbar.total + total, function="_traj_joined")
        else:
            # a bool disables the progress bar
            progbar = None

    for i in range(no_of_iters + 1):
        if i == 0:
            dummy_traj = copy.deepcopy(parent_traj)
        elif i == no_of_iters:
            # the last iteration only adds the remaining frames
            if rest != 0:
                dummy_traj = dummy_traj.join(copy.deepcopy(parent_traj)[:rest])
        else:
            dummy_traj = dummy_traj.join(copy.deepcopy(parent_traj))
        if progbar is not None:
            progbar.update(function="_traj_joined")

    # some checks
    assert self.n_frames == dummy_traj.n_frames
    # assert self.n_frames == len(self.trajs), f"{self.n_frames=}, {len(self.trajs)=}"

    # change the xyz coordinates of dummy_traj according to the frames in joined trajs
    for i, (_, __, traj) in enumerate(self.iterframes()):
        try:
            dummy_traj.xyz[i] = traj.xyz
        except ValueError as e:
            if "broadcast" not in str(e):
                raise e
            # the frame has a different number of atoms than the dummy traj
            warnings.warn(
                f"This`TrajEnsemble` has {len(self.top)} unique topologies. "
                f"I will use the topology with {self.top[0].n_atoms} for joining "
                f"and discard atoms in the other trajectories."
            )
            dummy_traj.xyz[i] = traj.xyz[0, : dummy_traj.n_atoms]
        if progbar is not None:
            progbar.update(function="_traj_joined")

    return dummy_traj
@property
def traj_joined(self) -> md.Trajectory:
    """mdtraj.Trajectory: All frames of this ensemble in one trajectory.

    This is the result of ``self._traj_joined()`` with its default
    arguments. According to the original documentation, this also works
    for different topologies with the same number of atoms: The first
    trajectory provides the topology and the xyz coordinates of the other
    frames are used to position its atoms.

    Examples:
        >>> import encodermap as em
        >>> trajs = em.load_project("pASP_pGLU")
        >>> subsample = trajs[0][:20] + trajs[1][:20]
        >>> subsample.split_into_frames().traj_joined  # doctest: +ELLIPSIS
        <mdtraj.Trajectory with 40 frames, 69 atoms, 6 residues, and unitcells at ...>

    """
    joined = self._traj_joined()
    return joined
@property
def xyz(self) -> np.ndarray:
    """np.ndarray: xyz coordinates of all trajs, stacked along the traj-time axis.

    The per-trajectory ``xyz`` arrays are combined with :func:`numpy.vstack`.
    With a single topology, any error of the stacking propagates unchanged.
    With several topologies the stacking is still attempted; if it fails,
    the error is re-raised as a ``NotImplementedError``.

    Raises:
        NotImplementedError: If more than one topology is present and the
            coordinate arrays can't be stacked.

    """
    if len(self.top) == 1:
        return np.vstack([traj.xyz for traj in self.trajs])

    try:
        stacked = np.vstack([traj.xyz for traj in self.trajs])
    except Exception as err:
        raise NotImplementedError(
            "Non consistent topologies don't allow to return a "
            "common xyz. This could be achived by implementing a "
            "high-dimensional masked numpy array with nans at "
            "non-defined positions."
        ) from err
    return stacked
def copy(self):
    """Returns an independent deep copy of this object (made with ``deepcopy``)."""
    duplicate = deepcopy(self)
    return duplicate
def split_into_frames(self, inplace: bool = False) -> Optional[TrajEnsemble]:
    """Splits self into separate frames.

    Every frame yielded by ``self.iterframes()`` is collected and a new
    ensemble is built from these frames with
    ``TrajEnsemble.with_overwrite_trajnums``.

    Args:
        inplace (bool): Whether to do the split inplace or not.
            Defaults to False and thus, returns a new `TrajEnsemble` class.
            If True, the state of ``self`` is replaced with the state of the
            newly built ensemble and None is returned.

    Returns:
        Optional[TrajEnsemble]: The per-frame ensemble, or None if
            ``inplace`` is True.

    """
    frames = [frame for _, _, frame in self.iterframes()]
    out = TrajEnsemble.with_overwrite_trajnums(*frames)
    assert out.trajs != []
    if not inplace:
        return out

    # Assigning to the local name ``self`` (the previous implementation) only
    # rebinds that name inside this function; the caller's object was never
    # changed. Adopt the new ensemble's state instead.
    # NOTE(review): assumes TrajEnsemble keeps its state in the instance
    # ``__dict__`` (no ``__slots__``) -- confirm against the class definition.
    self.__dict__.clear()
    self.__dict__.update(out.__dict__)
    return None
def save_CVs(self, path: Union[str, Path]) -> None:
    """Saves the CVs of this ensemble to a NETCDF file.

    The work is done by the ``to_netcdf`` method of ``self._CVs``, which is
    called with the ``NETCDF4`` format and the ``h5netcdf`` engine.

    Args:
        path (Union[str, Path]): Where to write the file.

    """
    netcdf_options = {"format": "NETCDF4", "engine": "h5netcdf"}
    self._CVs.to_netcdf(path, **netcdf_options)
def cluster(
    self,
    cluster_id: int,
    col: str = "cluster_membership",
    memberships: Optional[np.ndarray] = None,
    n_points: int = -1,
    overwrite: bool = True,
) -> TrajEnsemble:
    """Extracts the frames of this :obj:`TrajEnsemble` that belong to the
    cluster ``cluster_id`` of the CV ``col``.

    'Clustering' here means selecting a subset of frames by membership.
    Consider an ensemble of two trajectories with 3 frames each and an
    end-to-end distance CV with the values
    ``[0.8, 1.3, 1.2, 1.9, 0.2, 1.3]``. From these values a binary CV named
    ``'end_to_end_binary'`` can be derived that is 1 if the distance is
    greater than 1.0 and 0 otherwise: ``[0, 1, 1, 1, 0, 1]``. Then:

        - ``cluster = trajs.cluster(cluster_id=0, col='end_to_end_binary')``
          gives a :obj:`TrajEnsemble` with 2 frames.
        - ``cluster = trajs.cluster(cluster_id=1, col='end_to_end_binary')``
          gives a :obj:`TrajEnsemble` with 4 frames.

    To save a cluster in a format that graphical programs can render
    (``.xtc, .pdb``), use the :meth:`join` or :meth:`stack` method of the
    resulting :obj:`TrajEnsemble`. :meth:`stack` gives a single
    `mdtraj.Trajectory` stacked along the atom axis, :meth:`join` joins
    along the time axis.

    Note:
        :meth:`join` returns a ``dict`` that maps topologies to joined
        trajectories. For inhomogeneous topologies, this dict can be used
        to save multiple (``.xtc, .pdb``) files and visualize the cluster
        in external programs.

    The ``col`` parameter takes any CV name that is per-frame and integer.

    Args:
        cluster_id (int): The cluster id to use. Needs to be an integer
            that is present in the CV ``col``.
        col (str): Which 'column' of the collective variables to use.
            Needs to be a key that can be found in ``trajs.CVs.keys()``.
        memberships (Optional[np.ndarray]): If an array is provided here,
            it is first loaded as the CV ``col`` (via :meth:`load_CVs`) and
            then used for the selection.
        n_points (int): How many points the resulting cluster should
            contain. The points with ``col == cluster_id`` are subsampled
            evenly and without repeat. If not positive (default -1), all
            points are used.
        overwrite (bool): Only relevant together with ``memberships``. If
            False, an already existing CV named ``col`` triggers an
            AssertionError. If True (default), that CV is overridden.
            Helpful for iteratively conducting multiple clusterings.

    Returns:
        TrajEnsemble: The frames belonging to the requested cluster.

    Examples:

        Import EncoderMap and NumPy.

        >>> import encodermap as em
        >>> import numpy as np

        Load an example project.

        >>> trajs = em.load_project("pASP_pGLU", load_autoencoder=False)

        Create an array full of ``-1``'s. These are the 'outliers'.

        >>> cluster_membership = np.ones(shape=(trajs.n_frames, )) * -1

        Select the first 5 frames of every traj to be in cluster 0.

        >>> cluster_membership[trajs.id[:, 1] < 5] = 0

        Select all frames between 50 and 55 to be cluster 1.

        >>> cluster_membership[(50 <= trajs.id[:, 1]) & (trajs.id[:, 1] <= 55)] = 1
        >>> np.unique(cluster_membership)
        array([-1.,  0.,  1.])

        Load this array as a CV called ``'clu_mem'``.

        >>> trajs.load_CVs(cluster_membership, attr_name='clu_mem')

        Extract all of cluster 0 with ``n_points=-1``.

        >>> clu0 = trajs.cluster(0, "clu_mem")
        >>> clu0.n_frames
        35

        Extract an evenly spaced subset of cluster 1 with 10 total points.

        >>> clu1 = trajs.cluster(1, "clu_mem", n_points=10)
        >>> clu1.n_frames
        10

        Clusters with inhomogeneous topologies can be stacked along the atom axis.

        >>> [t.n_atoms for t in trajs]
        [69, 83, 103, 91, 80, 63, 73]
        >>> stacked = clu1.stack()
        >>> stacked.n_atoms
        795

        But joining the trajectories returns a ``dict[top, traj]`` if the
        topologies are inhomogeneous.

        >>> joined = clu1.join()
        >>> type(joined)
        <class 'dict'>

    """
    if memberships is not None:
        # an existing column may only be replaced if the caller allows it
        assert overwrite or col not in self._CVs, (
            f"Can't load {memberships} as new CVs. "
            f"The CV {col} containing cluster memberships already exists. "
            f"Choose a different name for the argument `col`."
        )
        self.load_CVs(memberships, col, override=overwrite)

    assert (
        col in self._CVs
    ), f"To use the CV '{col}' for clustering, add it to the CVs with `load_CVs`."

    # boolean mask over all frames, plus the matching (traj, frame) index
    # pairs and the running frame positions
    member_mask = (self.CVs[col] == cluster_id).squeeze()
    selected = self.index_arr[member_mask]
    positions = np.arange(self.n_frames)[member_mask]
    assert selected.size > 0, (
        f"The `cluster_id` {cluster_id} is not present in the `col` {col}: "
        f"{np.unique(self.CVs[col])=}"
    )

    if n_points > 0:
        # evenly spaced picks; np.unique drops repeats caused by rounding
        evenly_spaced = np.linspace(0, len(selected) - 1, n_points)
        keep = np.unique(np.round(evenly_spaced).astype(int))
        selected, positions = selected[keep], positions[keep]

    try:
        out = self[selected]
    except IndexError:
        # fall back to a frame-by-frame selection
        out = self._return_frames_by_index(positions)

    if hasattr(self, "clustal_w"):
        out.parse_clustal_w_alignment(self.clustal_w)
    return out
# join: groups the trajectories of this ensemble (by topology for
# dict_keys="top", by common string for dict_keys="cs"), joins every group
# along the time axis with `_traj_joined`, optionally superposes the result
# and returns a dict that maps the group key to the joined trajectory.
# Raises an AssertionError if several topologies should be superposed with
# selection strings other than 'name CA', and a TypeError for an unknown
# `dict_keys` value.
2145 def join(
2146 self,
2147 align_string: str = "name CA",
2148 superpose: bool = True,
2149 ref_align_string: str = "name CA",
2150 base_traj: Optional[md.Trajectory] = None,
2151 progbar: Optional[Any] = None,
2152 dict_keys: Literal["top", "cs"] = "top",
2153 ) -> dict[Union[md.Topology, str], md.Trajectory]:
# with more than one topology, superposing is restricted to the default
# 'name CA' selections
2154 if len(self.top) > 1 and superpose:
2155 assert align_string == ref_align_string == "name CA", (
2156 f"Aligning different topologies only possible, when the `align"
2157 f"_string` and `ref_align_string` both are 'name CA'."
2158 )
# progress bar handling: None creates a new bar, a bool disables the bar
# (progbar becomes None), any other object is treated as an existing bar
# whose total is extended by the number of topologies
2159 if progbar is None:
2160 if _is_notebook():
2161 progbar = notebook_tqdm(
2162 total=len(self.top),
2163 leave=False,
2164 position=0,
2165 desc="Joining...",
2166 )
2167 else:
2168 progbar = normal_tqdm(
2169 total=len(self.top),
2170 leave=False,
2171 position=0,
2172 desc="Joining...",
2173 function="join",
2174 )
2175 else:
2176 if not isinstance(progbar, bool):
2177 progbar.reset(progbar.total + len(self.top), function="join")
2178 else:
2179 progbar = None
# NOTE(review): `all_trajs` is appended to below, but never read in this function
2181 all_trajs = []
2182 out_by_top = {}
# choose the grouping that provides the keys of the returned dict
2183 if dict_keys == "top":
2184 iterator = enumerate(self.trajs_by_top.items())
2185 elif dict_keys == "cs":
2186 iterator = enumerate(self.trajs_by_common_str.items())
2187 else:
2188 raise TypeError(
2189 f"The argument `dict_keys` '{dict_keys}' is not recognized."
2190 )
2191 for i, (top, traj) in iterator:
# join all frames of this group along the time axis
2192 traj = traj._traj_joined(progbar=progbar)
2193 if superpose:
2194 if base_traj is not None:
# superpose onto the reference: select the atoms in both trajectories
2195 CAs_traj = traj.top.select(align_string)
2196 CAs_ref = base_traj.top.select(ref_align_string)
2197 if hasattr(self, "clustal_w"):
# walk both alignment strings in lockstep; only columns in which neither
# sequence has a gap ('-') contribute a pair of atom indices, the
# counters i_t / i_r track the position in the respective selection
2198 new_CAs_traj = []
2199 new_CAs_ref = []
2200 i_t = 0
2201 i_r = 0
2202 for aln_t, aln_r in zip(
2203 self.trajs_by_top[traj.top][0].clustal_w,
2204 self.trajs_by_top[base_traj.top][0].clustal_w,
2205 ):
2206 if aln_t == aln_r == "-":
2207 pass
2208 elif aln_t != "-" and aln_r == "-":
2209 i_t += 1
2210 elif aln_t == "-" and aln_r != "-":
2211 i_r += 1
2212 else:
2213 new_CAs_traj.append(CAs_traj[i_t])
2214 new_CAs_ref.append(CAs_ref[i_r])
2215 i_t += 1
2216 i_r += 1
2217 new_CAs_traj = np.array(new_CAs_traj)
2218 new_CAs_ref = np.array(new_CAs_ref)
2219 CAs_traj = new_CAs_traj.copy()
2220 CAs_ref = new_CAs_ref.copy()
2221 else:
# without an alignment, both selections are truncated to the shorter one
2222 length = min(len(CAs_traj), len(CAs_ref))
2223 CAs_traj = CAs_traj[:length]
2224 CAs_ref = CAs_ref[:length]
2225 traj = traj.superpose(
2226 base_traj,
2227 atom_indices=CAs_traj,
2228 ref_atom_indices=CAs_ref,
2229 )
2230 else:
# no reference available: superpose the joined trajectory onto itself
2231 traj = traj.superpose(
2232 traj,
2233 atom_indices=traj.top.select(align_string),
2234 )
# the first frame of the first group becomes the reference for later groups
# NOTE(review): the nesting level of this `if` (inside `if superpose:` or
# directly in the loop body) can't be told from this indentation-less
# listing -- confirm against the original file before restructuring
2235 if i == 0:
2236 base_traj = traj[0]
2237 all_trajs.append(traj)
2238 out_by_top[top] = traj
2239 if progbar is not None:
2240 progbar.update(function="join")
2242 # return
2243 return out_by_top
def stack(
    self,
    align_string: str = "name CA",
    superpose: bool = True,
    ref_align_string: str = "name CA",
    base_traj: Optional[md.Trajectory] = None,
    progbar: Optional[Any] = None,
) -> md.Trajectory:
    """Stacks all frames of this ensemble into one object along the atom axis.

    The trajectories are first joined with :meth:`join` (the arguments
    ``align_string``, ``superpose``, ``ref_align_string``, and ``base_traj``
    are forwarded unchanged). Afterwards, every frame of every joined
    trajectory is combined with the ``stack`` method of the frames.

    Args:
        align_string (str): Forwarded to :meth:`join`.
        superpose (bool): Forwarded to :meth:`join`.
        ref_align_string (str): Forwarded to :meth:`join`.
        base_traj (Optional[md.Trajectory]): Forwarded to :meth:`join`.
        progbar (Optional[Any]): If None, a new progress bar is created.
            If a bool is given, no progress bar is used. Any other object
            is treated as an existing progress bar, whose total is
            extended by ``self.n_frames``.

    Returns:
        md.Trajectory: The stacked frames. The number of atoms equals the
            sum of the atoms of all stacked frames (checked with an assert).

    """
    if progbar is None:
        in_notebook = _is_notebook()
        make_bar = notebook_tqdm if in_notebook else normal_tqdm
        # only the non-notebook bar receives the ``function`` keyword
        extra = {} if in_notebook else {"function": "stack"}
        progbar = make_bar(
            total=self.n_frames,
            leave=False,
            position=0,
            desc="Stacking...",
            **extra,
        )
    elif isinstance(progbar, bool):
        progbar = None
    else:
        progbar.reset(progbar.total + self.n_frames, function="stack")

    joined = self.join(
        align_string,
        superpose,
        ref_align_string,
        base_traj,
        progbar=progbar,
    )

    atoms = 0
    for traj_pos, traj in enumerate(joined.values()):
        for frame_pos, frame in enumerate(traj):
            atoms += frame.n_atoms
            # the very first frame seeds the result, all others are stacked onto it
            is_first = traj_pos == 0 and frame_pos == 0
            stacked = deepcopy(frame) if is_first else stacked.stack(frame)
            if progbar is not None:
                progbar.update(function="stack")
    assert stacked.n_atoms == atoms
    return stacked
2296 def _trace(self, CV: Sequence[str]) -> np.ndarray:
2297 """Creates a low-dimensional represnetation of the loaded CV data by
2298 stacking all arguments in `CV` along a single axis.
2300 If this `TrajEnsemble` has 10 trajectories with 100 frames each
2301 and a CV with shape (100, 50, 3) in each of them with the name 'cartesians'
2302 then `trajs._trace` will return a np.ndarray of shape
2303 (1000, 150).
2305 Args:
2306 CV (Sequence[str]): The CVs to combine in the trace.
2308 Returns:
2309 np.ndarray: The trace.
2311 """
2312 out = []
2313 for i in CV:
2314 v = self.CVs[i]
2315 out.append(v.reshape(v.shape[0], -1))
2316 return np.concatenate(out)
def to_dataframe(self, CV: Union[str, Sequence[str]]) -> pd.DataFrame:
    """Combines per-frame metadata and CV data into one :obj:`pandas.DataFrame`.

    The metadata part has the columns ``traj_file``, ``top_file``, and
    ``time`` and is indexed by ``(traj_num, frame_num)``. The requested CVs
    are taken from ``self._CVs``, converted to pandas objects, and
    concatenated to the metadata column-wise.

    Args:
        CV (Union[str, Sequence[str]]): A single CV name or a sequence of
            CV names. For a sequence, CVs that result in a single column
            get that column renamed from ``0`` to the upper-cased CV name.

    Returns:
        pd.DataFrame: The metadata columns followed by the CV columns.

    """

    def _cv_to_pandas(name):
        # merge the traj and frame dimensions into one, drop all-NaN frames
        return (
            self._CVs[name]
            .stack({"frame": ("traj_num", "frame_num")})
            .transpose("frame", ...)
            .dropna("frame", how="all")
            .to_pandas()
        )

    # frame nums can be inhomogeneous (2D ids carry them in the 2nd column)
    frame_num = []
    for traj in self:
        ids = traj.id
        frame_num.extend(ids[:, 1] if ids.ndim == 2 else ids)
    time = [t for traj in self for t in traj.time]

    # the CV data can be directly extracted from xarray
    if isinstance(CV, str):
        data = [_cv_to_pandas(CV)]
    else:
        data = []
        for cv in CV:
            cv_frame = _cv_to_pandas(cv)
            if len(cv_frame.columns) == 1:
                cv_frame = cv_frame.rename(columns={0: cv.upper()})
            data.append(cv_frame)

    # per-trajectory attributes are repeated once per frame
    meta = {}
    for attr in ("traj_file", "top_file", "traj_num"):
        meta[attr] = [
            getattr(traj, attr) for traj in self for _ in range(traj.n_frames)
        ]
    meta["frame_num"] = frame_num
    meta["time"] = time

    df = pd.DataFrame(meta).set_index(["traj_num", "frame_num"])
    return pd.concat([df, *data], axis=1)
def dash_summary(self) -> pd.DataFrame:
    """A :obj:`pandas.DataFrame` that summarizes this ensemble.

    For an ensemble with exactly one trajectory, the ``dash_summary()`` of
    that trajectory is returned. Otherwise, a two-column (``field``,
    ``value``) frame is built with the rows ``n_trajs``, ``n_frames``,
    ``n_atoms``, ``dt (ps)``, ``trajs``, ``multiple tops``, and
    ``common_str``. All entries are converted to str.

    Returns:
        pd.DataFrame: The DataFrame.

    """
    if self.n_trajs == 1:
        return self.trajs[0].dash_summary()

    # atoms: reduce to a scalar, if all trajs agree
    n_atoms = np.unique([t.n_atoms for t in self])
    if len(n_atoms) == 1:
        n_atoms = n_atoms[0]

    # dt: unique time differences of consecutive frames over all trajs
    steps = []
    for t in self:
        steps.extend(np.unique(t.time[1:] - t.time[:-1]))
    dt = np.unique(np.asarray(steps))
    if len(dt) == 1:
        dt = dt[0]

    rows = [
        ("n_trajs", self.n_trajs),
        ("n_frames", self.n_frames),
        ("n_atoms", n_atoms),
        ("dt (ps)", dt),
        ("trajs", [t.basename for t in self]),
        ("multiple tops", len(self.top) != 1),
        ("common_str", list(set(self.common_str))),
    ]
    summary = pd.DataFrame(
        {
            "field": [name for name, _ in rows],
            "value": [value for _, value in rows],
        }
    )
    return summary.astype(str)
def load_CVs(
    self,
    data: TrajEnsembleFeatureType = None,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
    deg: Optional[bool] = None,
    periodic: bool = True,
    labels: Optional[list[str]] = None,
    directory: Optional[Union[str, Path]] = None,
    ensemble: bool = False,
    override: bool = False,
    custom_aas: Optional[CustomAAsDict] = None,
    alignment: Optional[str] = None,
) -> None:
    """Loads CVs in various ways. The simplest one is to provide a single
    :obj:`numpy.ndarray` together with a name for that array.

    Besides arrays, files (``.txt`` and ``.npy``), Features or Featurizers,
    and :obj:`xarray.Dataset` objects can be provided. A str can either be
    the name of one of EncoderMap's features (`encodermap.features`) or
    'all', which loads all features required for EncoderMap's
    :obj:`encodermap.autoencoder.autoencoder.AngleDihedralCartesianEncoderMap`.

    Before the data is handed over to
    ``encodermap.trajinfo.trajinfo_utils.load_CVs_ensembletraj``, this
    method loads a custom topology (if ``custom_aas`` is given), parses the
    alignment (if ``alignment`` is given and ``ensemble`` is True), and
    calls ``load_traj()`` on every traj, if the trajs use different backends.

    Args:
        data (Optional[TrajEnsembleFeatureType]): The CV to load. A
            :obj:`numpy.ndarray` needs to have a shape matching
            ``n_frames`` and is distributed to the trajs. For a list of
            files, ``len(data)`` needs to match ``n_trajs``; the first file
            is loaded by the first traj (based on the traj's ``traj_num``)
            and so on. The same holds for a list of :obj:`numpy.ndarray`.
            If None (default), ``self.featurizer`` is used as data.
        attr_name (Optional[str]): The name under which the CV should be
            found in the class (``'highd'``, ``'lowd'``, ``'dists'``, ...).
            The CV can then be accessed via dot-notation:
            ``trajs.attr_name``. Defaults to None, in which case ``data``
            should point to existing files, from which the name is extracted.
        cols (Optional[list[int]]): Integers indexing the columns of the
            data to be loaded. Useful if a file contains columns that are
            not features (i.e. an indexer or the error of the features), eg::

                id   f1    f2    f1_err    f2_err
                0    1.0   2.0   0.1       0.1
                1    2.5   1.2   0.11      0.52

            Here, ``cols=[1, 2]`` would select the features. If None
            (default), all columns are loaded.
        deg (Optional[bool]): Whether to return angular CVs using degrees.
            If None or False, CVs will be in radian. Defaults to None.
        periodic (bool): Whether to use the minimum image convention to
            calculate distances/angles/dihedrals. Generally recommended,
            when trajectories are not cleaned up and proteins break over
            the periodic boundary conditions. For large proteins, however,
            the distance between two sites might be shorter through the
            periodic boundary, which can lead to wrong distances.
        labels (Optional[list[str]]): Labels for the feature dimension of
            the data. For a :obj:`numpy.ndarray` with shape
            ``(n_trajs, n_frames, n_feat)`` this list needs ``n_feat``
            entries, otherwise an exception is raised. If None (default),
            the labels are generated automatically.
        directory (Optional[Union[str, Path]]): A directory that is
            searched for ``.txt`` or ``.npy`` files named
            ``traj.basename + '_' + attr_name``, from which the CVs are
            loaded into the corresponding trajs.
            NOTE(review): a ``data`` of None is replaced by
            ``self.featurizer`` before the loader is called -- confirm how
            the directory lookup is reached.
        ensemble (bool): Whether the trajs in this class belong to an
            ensemble, i.e. they have the same or very similar topologies
            (think wt and mutant). If True, the CVs of the trajs are
            matched onto the same dataset. If a VAL residue has been
            replaced by LYS in the mutant, the number of sidechain
            dihedrals increases and the CVs of the trajs with VAL will
            contain some NaN values. Defaults to False.
        override (bool): Whether to override CVs with the same name as
            ``attr_name``.
        custom_aas (Optional[CustomAAsDict]): Non-standard residue
            definitions. See
            :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology` for
            details. If None (default), only standard residue names are
            assumed.
        alignment (Optional[str]): A CLUSTAL W alignment for proteins with
            similar but different sequences. The featurization aligns the
            features accordingly. Only parsed when ``ensemble`` is True,
            otherwise a hint is printed.

    Raises:
        TypeError: When wrong Type has been provided for data.

    """
    # Local Folder Imports
    from .trajinfo_utils import load_CVs_ensembletraj

    if data is None:
        data = self.featurizer

    if custom_aas is not None:
        self.load_custom_topology(custom_aas)

    if alignment is not None and ensemble:
        self.parse_clustal_w_alignment(alignment)
    elif alignment is not None:
        print(
            "Providing a CLUSTAL W alignment for featurization of ensembles "
            "of protein families, makes only sense, when `ensemble` is also "
            "set to True. This makes EncoderMap align the features based "
            "on their sequence alignment."
        )

    # if some trajs are missing time (mixed backends), load all of them
    backends = np.unique(np.asarray([t.backend for t in self.trajs]))
    if len(backends) > 1:
        for traj in self.trajs:
            traj.load_traj()

    load_CVs_ensembletraj(
        self,
        data,
        attr_name,
        cols,
        deg,
        periodic,
        labels,
        directory,
        ensemble,
        override,
    )
def save(
    self,
    fname: Union[str, Path],
    CVs: Union[Literal["all"], list[str], Literal[False]] = "all",
    overwrite: bool = False,
    only_top: bool = False,
) -> None:
    """Saves this TrajEnsemble into a single ``.h5`` file.

    If any traj carries a ``clustal_w`` attribute, only a warning is issued
    and nothing is written.

    Args:
        fname (Union[str, Path]): Where to save the file. Needs to have the
            ``.h5`` suffix (checked with an assert).
        CVs (Union[Literal["all"], list[str], Literal[False]]): Which CVs
            to also store in the file. If set to ``'all'``, all CVs will
            be saved. Otherwise, a list[str] can be provided to only save
            specific CVs. If set to False, no CVs are stored in the file.
        overwrite (bool): If the file exists, it is overwritten.
        only_top (bool): Only writes the trajectories' topologies into the file.

    Raises:
        IOError: If the file already exists and neither ``overwrite`` nor
            ``only_top`` is True.

    """
    # Third Party Imports
    from mdtraj.utils import in_units_of

    alignment_flags = [hasattr(traj, "clustal_w") for traj in self]
    if any(alignment_flags):
        warnings.warn(
            "Can't currently save a `TrajEnsemble` with a clustal w alignment"
        )
        return

    fname = Path(fname)
    assert (
        fname.suffix == ".h5"
    ), "We recommend the .h5 file extension for these files."
    already_there = fname.is_file()
    if already_there and not (overwrite or only_top):
        raise IOError(
            f"File {fname} already exists. Set `overwrite` to True to overwrite."
        )
    if already_there and overwrite:
        fname.unlink()

    # first pass: coordinates (or just the topologies) of every traj
    for group, traj in self.itertrajs():
        with HDF5GroupWrite(fname, "a", force_overwrite=overwrite) as h5:
            if only_top:
                h5.write_topology(traj.traj_num, traj.top)
                continue
            h5.write_into_group(
                group_id=str(group),
                coordinates=in_units_of(
                    traj.xyz, md.Trajectory._distance_unit, h5.distance_unit
                ),
                time=traj.time,
                cell_lengths=in_units_of(
                    traj.unitcell_lengths,
                    md.Trajectory._distance_unit,
                    h5.distance_unit,
                ),
                cell_angles=traj.unitcell_angles,
                topology=traj.top,
            )

    # second pass: custom topologies and common strings
    for _, traj in self.itertrajs():
        traj._custom_top.traj = traj
        traj._custom_top.to_hdf_file(fname)
        traj._common_str_to_hdf_file(fname)

    if CVs == "all":
        save_netcdf_alongside_mdtraj(fname, self._CVs)
        return

    if self._CVs and CVs:
        with h5py.File(fname, "a") as file:
            has_group = "CVs" in list(file.keys())
            grp = file["CVs"] if has_group else file.create_group("CVs")
            for key in CVs:
                value = self._CVs[key]
                assert self.n_frames == value.shape[1]
                grp.create_dataset(name=key, data=value)
2640 def _return_trajs_by_index(self, index: Sequence[int]) -> TrajEnsemble:
2641 """Creates a :obj:`TrajEnsemble` object with the trajs specified by ``index``.
2643 This is a sub-method of the ``trajs[]`` indexer.
2645 """
2646 new_common_str = []
2647 for i, traj in enumerate(self.trajs):
2648 if i not in index:
2649 continue
2650 new_common_str.append(traj.common_str)
2651 new_common_str = list(set(new_common_str))
2652 for i, ind in enumerate(index):
2653 if i == 0:
2654 trajs_subset = self.trajs[ind]._gen_ensemble()
2655 else:
2656 new_traj = self.trajs[ind]._gen_ensemble()
2657 trajs_subset += new_traj
2658 trajs_subset.common_str = new_common_str
2659 trajs_subset.basename_fn = self.basename_fn
2660 return trajs_subset
2662 def _return_frames_by_index(self, index: Sequence[int]) -> TrajEnsemble:
2663 """Creates a :obj:`TrajEnsemble` object with the frames specified by ``index``."""
2664 new_common_str = []
2665 frames = []
2666 for traj_num, frame_num, frame in self.iterframes():
2667 if frame_num not in index:
2668 continue
2669 frames.append(frame)
2670 new_common_str.append(frame.common_str)
2671 assert len(frames) > 0, (
2672 f"Can't return frames with index {index} from trajectory with "
2673 f"frames {self.id[:, -1]}."
2674 )
2675 new_common_str = list(set(new_common_str))
2676 for i, frame in enumerate(frames):
2677 if i == 0:
2678 trajs_subset = frame._gen_ensemble()
2679 else:
2680 new_traj = frame._gen_ensemble()
2681 trajs_subset = trajs_subset + new_traj
2682 trajs_subset.common_str = new_common_str
2683 trajs_subset.basename_fn = self.basename_fn
2684 return trajs_subset
def _pyemma_indexing_tsel(self, key: np.ndarray) -> TrajEnsemble:
    """Returns a new :obj:`TrajEnsemble` by giving the indices of traj and frame.

    The first column of ``key`` holds trajectory numbers, which are looked
    up with ``self.tsel``. The second column holds the frames, which are
    looked up with ``fsel`` on the selected trajectory. A 1D ``key`` is
    treated as a single row.

    """
    if key.ndim == 1:
        # turn the flat key into a single row
        key = key.reshape(len(key), 1).T

    traj_column = key[:, 0]
    selected = [
        self.tsel[num].fsel[key[traj_column == num, 1]]
        for num in np.unique(traj_column)
    ]
    return TrajEnsemble(
        selected, basename_fn=self.basename_fn, common_str=self.common_str
    )
def _pyemma_indexing_no_tsel(self, key: np.ndarray) -> TrajEnsemble:
    """Returns a new :obj:`TrajEnsemble` by giving the indices of traj and frame.

    The first column of ``key`` holds positions in ``self.trajs``, the
    second column holds the frames used to index the respective trajectory.
    A 1D ``key`` is treated as a single row.

    Raises:
        AssertionError: If a trajectory position is not smaller than
            ``self.n_trajs``.

    """
    if key.ndim == 1:
        # turn the flat key into a single row
        key = key.reshape(len(key), 1).T

    traj_column = key[:, 0]
    selected = []
    for num in np.unique(traj_column):
        assert num < self.n_trajs, (
            f"Can't identify trajectory with number {num} in an ensemble "
            f"with {self.n_trajs} trajectories."
        )
        wanted_frames = key[traj_column == num, 1]
        selected.append(self.trajs[num][wanted_frames])
    return TrajEnsemble(
        selected, basename_fn=self.basename_fn, common_str=self.common_str
    )
def subsample(
    self,
    stride: Optional[int] = None,
    total: Optional[int] = None,
) -> Optional[TrajEnsemble]:
    """Returns a subset of this :obj:`TrajEnsemble` given the provided stride
    or total.

    With HDF5 trajs, this is a faster alternative to
    ``trajs[trajs.index_arr[::1000]]``, because the slicing information is
    saved in the respective
    :obj:`encodermap.trajinfo.info_single.SingleTraj` and loading of single
    frames is faster in HDF5 formatted trajs.

    Args:
        stride (Optional[int]): Return a frame every ``stride`` frames.
            Every trajectory is sliced on its own.
        total (Optional[int]): Return a total of evenly sampled frames.

    Returns:
        Optional[TrajEnsemble]: A trajectory ensemble. If neither or both
            of ``stride`` and ``total`` are provided, a hint is printed
            and None is returned.

    Note:
        The result from ``subsample(1000)`` is different from
        ``trajs[trajs.index_arr[::1000]]``. With subsample every trajectory
        is sub-sampled independently. Consider a :obj:`TrajEnsemble` with two
        :obj:`encodermap.trajinfo.info_single.SingleTraj` trajectories with
        18 frames each. ``subsampled = trajs.subsample(5)`` would return a
        :obj:`TrajEnsemble` with two trajs with 3 frames each
        (``subsampled.n_frames == 6``). Whereas,
        ``subsampled = trajs[trajs.index_arr[::5]]`` would return a
        :obj:`TrajEnsemble` with 7 SingleTrajs with 1 frame each
        (``subsampled.n_frames == 7``). Because the time and frame numbers
        are saved all the time, this should not be too much of a problem.

    """
    only_total = stride is None and total is not None
    only_stride = total is None and stride is not None

    if only_total:
        # evenly spaced positions over all frames, rounding repeats removed
        spaced = np.linspace(0, self.n_frames - 1, total)
        positions = np.unique(np.round(spaced).astype(int))
        return self[self.id[positions]]

    if only_stride:
        strided = [traj[slice(None, None, stride)] for traj in self.trajs]
        return TrajEnsemble(
            strided, common_str=self.common_str, basename_fn=self.basename_fn
        )

    print("Provide either stride or total.")
    return None
def get_single_frame(self, key: int) -> "SingleTraj":
    """Returns a single frame from all loaded trajectories.

    Consider a :obj:`TrajEnsemble` class with two trajectories. One has 10
    frames, the other 5 (``trajs.n_frames`` is 15). Calling
    ``trajs.get_single_frame(12)`` is equal to calling ``trajs[1][2]``.
    Calling ``trajs.get_single_frame(16)`` will error, and
    ``trajs.get_single_frame(1)`` is the same as ``trajs[0][1]``.

    Args:
        key (int): The frame to return. Python ints and NumPy integers
            are accepted.

    Returns:
        encodermap.trajinfo.info_single.SingleTraj: The frame.

    Raises:
        Exception: If this ensemble reports 0 frames.
        IndexError: If ``key`` is not an integer or if it is not smaller
            than ``self.n_frames``.

    """
    # some input checks
    if self.n_frames == 0:
        raise Exception(
            "Indexing a no_load backend does not work. I need some "
            "information about the frames in each trajectory. Please "
            "load either highd or lowd."
        )
    # The type is checked before the comparison with n_frames, so that a
    # non-integer key gets the intended IndexError and not a TypeError
    # from the ``>=`` comparison.
    if not isinstance(key, (int, np.integer)):
        raise IndexError(
            "if you want a single frame, please provide an integer. "
            "If you want multiple frames use ep.TrajEnsemble[]"
        )
    if key >= self.n_frames:
        raise IndexError(
            "index {} is out of bounds for trajectory with {} frames".format(
                key, self.n_frames
            )
        )

    if len(self.trajs) == 1:
        return self.trajs[0][key]

    # one row per frame of the ensemble: (traj_num, frame position in traj)
    lookup = np.hstack(
        [
            np.array([np.full(t.n_frames, t.traj_num), np.arange(t.n_frames)])
            for t in self.trajs
        ]
    ).T
    num, frame = lookup[key]
    # The trajectory is always selected by its traj_num via ``tsel``. The
    # former alternative ``self.trajs[num]`` was guarded by a comparison
    # against ``np.arange(len(traj_nums) + 1)``, which can never be equal
    # (the lengths differ), so that branch was unreachable.
    return self.tsel[num][frame]
def unload(self) -> None:
    """Unloads all trajs in self and sets the backend to ``'no_load'``."""
    for traj in self:
        traj.unload()
    self.backend = "no_load"
def load_trajs(self) -> None:
    """Loads all trajs in self and sets the backend to ``'mdtraj'``."""
    for traj in self:
        traj.load_traj()
    self.backend = "mdtraj"
# Typing-only overload (the body is `...`): `CV_names` annotated as
# `tuple[str]`, `yield_index` defaulting to True; annotated to return an
# iterator over 2-tuples of arrays.
# NOTE(review): `yield_index` is typed as plain `bool` here, so only its
# default value sets this signature apart -- consider `Literal[True]`.
2827 @overload
2828 def batch_iterator( 2828 ↛ exitline 2828 didn't jump to the function exit
2829 self,
2830 batch_size: int,
2831 replace: bool = False,
2832 CV_names: tuple[str] = ("",),
2833 deterministic: bool = True,
2834 yield_index: bool = True,
2835 start: int = 1,
2836 ) -> Iterator[
2837 tuple[
2838 np.ndarray,
2839 np.ndarray,
2840 ]
2841 ]: ...
# Typing-only overload (the body is `...`): `CV_names` annotated as
# `tuple[str]`, `yield_index` defaulting to False; annotated to return an
# iterator over single arrays.
# NOTE(review): `yield_index` is typed as plain `bool` here, so only its
# default value sets this signature apart -- consider `Literal[False]`.
2843 @overload
2844 def batch_iterator( 2844 ↛ exitline 2844 didn't jump to the function exit
2845 self,
2846 batch_size: int,
2847 replace: bool = False,
2848 CV_names: tuple[str] = ("",),
2849 deterministic: bool = True,
2850 yield_index: bool = False,
2851 start: int = 1,
2852 ) -> Iterator[np.ndarray]: ...
# Typing-only overload (the body is `...`): `CV_names` annotated as
# `Optional[Sequence[str]]`, `yield_index` defaulting to True; annotated to
# return an iterator over (array, 5-tuple of arrays) pairs.
# NOTE(review): `yield_index` is typed as plain `bool` here, so only its
# default value sets this signature apart -- consider `Literal[True]`.
2854 @overload
2855 def batch_iterator( 2855 ↛ exitline 2855 didn't jump to the function exit
2856 self,
2857 batch_size: int,
2858 replace: bool = False,
2859 CV_names: Optional[Sequence[str]] = None,
2860 deterministic: bool = True,
2861 yield_index: bool = True,
2862 start: int = 1,
2863 ) -> Iterator[
2864 tuple[
2865 np.ndarray,
2866 tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
2867 ]
2868 ]: ...
# Typing-only overload (the body is `...`): `CV_names` annotated as
# `Optional[Sequence[str]]`, `yield_index` defaulting to False; annotated to
# return an iterator over 5-tuples of arrays.
# NOTE(review): `yield_index` is typed as plain `bool` here, so only its
# default value sets this signature apart -- consider `Literal[False]`.
2870 @overload
2871 def batch_iterator( 2871 ↛ exitline 2871 didn't jump to the function exit
2872 self,
2873 batch_size: int,
2874 replace: bool = False,
2875 CV_names: Optional[Sequence[str]] = None,
2876 deterministic: bool = True,
2877 yield_index: bool = False,
2878 start: int = 1,
2879 ) -> Iterator[
2880 tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
2881 ]: ...
def batch_iterator(
    self,
    batch_size: int,
    replace: bool = False,
    CV_names: Optional[Sequence[str]] = None,
    deterministic: bool = False,
    yield_index: bool = False,
    start: int = 1,
) -> Iterator[Any]:
    """Lazy batched iterator of CV data.

    This iterator extracts batches of CV data from the ensemble. If the
    ensemble is a large HDF5 dataset, this provides the ability to use all
    data without loading it all into memory.

    Examples:

        Import EncoderMap and load some example trajectories.

        >>> import encodermap as em
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         'https://files.rcsb.org/view/1YUG.pdb',
        ...         'https://files.rcsb.org/view/1YUF.pdb'
        ...     ]
        ... )

        This iterator will yield new samples forever. The batch is a tuple
        of :obj:`numpy.ndarray`.

        >>> for batch in trajs.batch_iterator(batch_size=2):
        ...     print([b.shape for b in batch])
        ...     break
        [(2, 148), (2, 147), (2, 150, 3), (2, 149), (2, 82)]

        Use it with Python's builtin ``next()`` function. The ``deterministic``
        flag returns deterministic batches. The ``yield_index`` flag also
        provides the index of the extracted batch. In this example, the
        first batch is extracted from the 1YUG trajectory (``traj_num==0``).
        Every row of the index is a ``(traj_num, frame_num)`` pair.

        >>> iterator = trajs.batch_iterator(deterministic=True, batch_size=2, yield_index=True)
        >>> index, batch = next(iterator)
        >>> index
        [[0 5]
        [0 8]]
        >>> index, batch = next(iterator)
        >>> index.shape
        (2, 2)

        If a single string is requested for ``CV_names``, the batch will
        be a single :obj:`numpy.ndarray`, rather than a tuple thereof.

        >>> iterator = trajs.batch_iterator(batch_size=2, CV_names=["central_dihedrals"])
        >>> batch = next(iterator)
        >>> batch.shape
        (2, 147)

    Args:
        batch_size (int): The size of the batch.
        replace (bool): Whether inside a single batch a sample can occur
            more than once. Set to False (default) to only allow unique
            samples in a batch.
        CV_names (Sequence[str]): The names of the CVs to be used in the
            iterator. If a list/tuple with a single string (or a bare
            string) is provided, the batch will be a :obj:`numpy.ndarray`,
            rather than a tuple thereof. Defaults to the five standard
            CVs (central angles, dihedrals, cartesians, distances and
            side dihedrals).
        deterministic (bool): Whether the samples should be deterministic.
            If True, the global NumPy random state is re-seeded with
            ``batch_counter + try_counter`` before every draw.
        yield_index (bool): Whether to also yield the index of the extracted
            samples.
        start (int): A start integer, which can be used together with
            ``deterministic=True`` to get different deterministic datasets.

    Returns:
        Iterator[Any]: Different iterators based on chosen arguments.

    Raises:
        AssertionError: If the stacked CV data and the ensemble's ``id``
            array differ in length, or if a selected distance CV contains
            a value that is not greater than zero.
        Exception: If a batch could not be filled within
            ``100 * batch_size`` tries.

    """
    # Encodermap imports
    from encodermap.autoencoder.autoencoder import np_to_sparse_tensor

    # the standard CV_names
    if CV_names is None:
        CV_names = [
            "central_angles",
            "central_dihedrals",
            "central_cartesians",
            "central_distances",
            "side_dihedrals",
        ]
    elif isinstance(CV_names, str):
        # a bare string would otherwise be iterated character by character
        CV_names = [CV_names]
    else:
        # xarray only selects several variables when indexed with a list;
        # a tuple would be looked up as a single (missing) variable name
        CV_names = list(CV_names)

    if self.CVs_in_file and all([t.index == (None,) for t in self.trajs]):
        ds = xr.open_dataset(
            self.trajs[0].traj_file, engine="h5netcdf", group="CVs"
        )
        ds = ds[CV_names]
    else:
        ds = self._CVs[CV_names]
    # number of (traj, frame) pairs that carry data for at least one CV
    total_len = (
        ds.stack({"frame": ("traj_num", "frame_num")})
        .transpose("frame", ...)
        .dropna("frame", how="all")
        .sizes["frame"]
    )
    traj_nums_and_frames = self.id.copy()
    if self.CVs_in_file:
        assert len(traj_nums_and_frames) == total_len, (
            f"The CVs of the trajs are not aligned with the frames. The "
            f"CVs stacked along the traj/frame axis have a shape of {total_len}, "
            f"while the id array of the trajs has a shape of {self.id.shape}. "
            f"The frames of the trajs are reported as {self.n_frames}. The "
            f"CV data was extracted from the .h5 file {self.trajs[0].traj_file}, "
            f"by stacking the traj/frame axis into a combined axis and dropping "
            f"the frames full of NaNs for the CVs {CV_names}."
        )
    else:
        assert len(traj_nums_and_frames) == total_len, (
            f"The CVs of the trajs are not aligned with the frames. The "
            f"CVs stacked along the traj/frame axis have a shape of {total_len}, "
            f"while the id array of the trajs has a shape of {self.id.shape}. "
            f"The frames of the trajs are reported as {self.n_frames}. The CV "
            f"data was obtained from combining the CVs {CV_names} of the trajectories "
            f"in this ensemble along a traj axis."
        )

    # detect sparse or sidechain only sparse
    sparse = set()
    sub_ds = ds.sel(frame_num=0)
    for o in CV_names:
        datum = sub_ds[o].values
        if np.isnan(datum).any(1).any(None):
            sparse.add(o)

    # start the loop
    # i is the counter for the batch
    # j is the counter for how many tries were needed to select either
    # unique indices (if replace is False) or indices of frames for which
    # none of the requested CVs is all-NaN
    # NOTE: with ``deterministic=True`` the seed of every draw is ``i + j``,
    # so consecutive batches draw from overlapping seed ranges.
    i = start
    while True:
        index = []
        out = [[] for _ in CV_names]
        j = 0
        while len(index) < batch_size:
            if j > 100 * batch_size:
                raise Exception(
                    f"Can't find unique indices after 100 iterations. "
                    f"Current index is {index=}."
                )
            if deterministic:
                np.random.seed(i + j)
            idx = tuple(traj_nums_and_frames[np.random.randint(0, total_len, 1)[0]])
            if idx in index and not replace:
                j += 1
                continue
            data = ds.sel(traj_num=idx[0], frame_num=idx[1])
            # check if any values in CV_names are all nans
            # this can happen for ensembles with different length trajectories
            # we append to `out_`, because if not all nans for this frame
            # we can append `out_` to `out`
            out_ = []
            for o in CV_names:
                v = data[o].values
                if "dist" in o:
                    assert np.all(np.nan_to_num(v, copy=True, nan=1.0) > 0.0), (
                        f"Distances for the selection traj_num={idx[0]} frame={idx[1]} "
                        f"contained a 0. This will result in problems with the cartesian "
                        f"cost."
                    )
                if np.all(np.isnan(v)):
                    # if all nans break
                    out_ = []
                    break
                else:
                    out_.append(v)
            # and continue
            if out_ == []:
                j += 1
                continue
            # if not, we can append
            index.append(idx)
            for k, o in enumerate(out_):
                out[k].append(o)
            j += 1
        # stack
        out = [np.stack(o, 0) for o in out]
        index = np.array(index)

        # make sparse tensors
        # The loop variable must not be called ``i``: that would overwrite
        # the batch counter and pin the deterministic seed of every
        # following batch to ``len(CV_names)``.
        for pos, name in enumerate(CV_names):
            if name in sparse:
                if out[pos].ndim > 2:
                    out[pos] = out[pos].reshape(batch_size, -1)
                out[pos] = np_to_sparse_tensor(out[pos])

        i += 1
        # and yield
        if len(CV_names) > 1:
            out = tuple(out)
        else:
            out = out[0]
        if yield_index:
            yield index, out
        else:
            yield out
def tf_dataset(
    self,
    batch_size: int,
    replace: bool = False,
    sidechains: bool = False,
    reconstruct_sidechains: bool = False,
    CV_names: Optional[list[str]] = None,
    deterministic: bool = False,
    prefetch: bool = True,
    start: int = 1,
) -> tf.data.Dataset:
    """Wrap :meth:`batch_iterator` into a ``tf.data.Dataset``.

    One batch is drawn up front to derive the output signature
    (``tf.TensorSpec`` for dense arrays, ``tf.SparseTensorSpec`` for sparse
    tensors); the dataset itself is built with
    ``tf.data.Dataset.from_generator``.

    Args:
        batch_size (int): The size of the batches.
        replace (bool): Forwarded to :meth:`batch_iterator`.
        sidechains (bool): Only used when ``CV_names`` is None. Adds
            ``"side_dihedrals"`` to the four central CVs.
        reconstruct_sidechains (bool): Only used when ``CV_names`` is None.
            Selects ``"all_cartesians"`` instead of ``"central_cartesians"``
            and adds the side angles, dihedrals and distances. Takes
            precedence over ``sidechains``.
        CV_names (Optional[list[str]]): Explicit CV names. When given,
            ``sidechains`` and ``reconstruct_sidechains`` are ignored.
        deterministic (bool): Forwarded to :meth:`batch_iterator`; also
            sets ``tf.data.Options().deterministic`` on the dataset.
        prefetch (bool): Whether to call ``prefetch(batch_size * 4)`` on
            the dataset.
        start (int): Forwarded to :meth:`batch_iterator`.

    Returns:
        tf.data.Dataset: A dataset yielding the batches of
        :meth:`batch_iterator`.

    Raises:
        AssertionError: If a requested CV is not contained in ``self._CVs``.

    """
    # Third Party Imports
    import tensorflow as tf

    # resolve the default CVs first, so everything below sees the same names
    if CV_names is None:
        if reconstruct_sidechains:
            CV_names = [
                "central_angles",
                "central_dihedrals",
                "all_cartesians",
                "central_distances",
                "side_angles",
                "side_dihedrals",
                "side_distances",
            ]
        elif sidechains:
            CV_names = [
                "central_angles",
                "central_dihedrals",
                "central_cartesians",
                "central_distances",
                "side_dihedrals",
            ]
        else:
            CV_names = [
                "central_angles",
                "central_dihedrals",
                "central_cartesians",
                "central_distances",
            ]
    for o in CV_names:
        assert o in self._CVs, f"The CV '{o}' is not loaded in this ensemble."

    def gen():
        # a fresh iterator every time tensorflow (re-)starts the generator
        return self.batch_iterator(
            batch_size, replace, CV_names, deterministic, start=start
        )

    def spec_for(element):
        # sparse batches need a SparseTensorSpec, everything else is dense
        if isinstance(element, tf.sparse.SparseTensor):
            return tf.SparseTensorSpec(shape=element.dense_shape, dtype="float32")
        return tf.TensorSpec(shape=element.shape, dtype="float32")

    # define the TensorSpecs
    sample = next(
        self.batch_iterator(
            batch_size=batch_size,
            replace=replace,
            CV_names=CV_names,
            deterministic=deterministic,
            start=start,
        )
    )
    if isinstance(sample, tuple):
        tensor_specs = tuple(spec_for(s) for s in sample)
    else:
        # a single CV can be sparse as well, so it goes through the same
        # dispatch as the elements of a tuple
        tensor_specs = spec_for(sample)
    dataset = tf.data.Dataset.from_generator(gen, output_signature=tensor_specs)
    if prefetch:
        dataset = dataset.prefetch(batch_size * 4)
    if deterministic:
        options = tf.data.Options()
        options.deterministic = True
        dataset = dataset.with_options(options)
        assert dataset.options().deterministic
    return dataset
def itertrajs(self) -> Iterator[tuple[int, "SingleTraj"]]:
    """Iterate over the trajectories together with their ``traj_num``.

    Yields:
        tuple: A tuple containing the following:
            - int: The ``traj_num`` attribute of the trajectory.
            - encodermap.SingleTraj: The trajectory itself.

    Examples:
        >>> import encodermap as em
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         'https://files.rcsb.org/view/1YUG.pdb',
        ...         'https://files.rcsb.org/view/1YUF.pdb'
        ...     ]
        ... )
        >>> for i, traj in trajs.itertrajs():
        ...     print(traj.basename)
        1YUG
        1YUF

    """
    yield from ((single.traj_num, single) for single in self)
def iterframes(self) -> Iterator[tuple[int, int, "SingleTraj"]]:
    """Iterate over every frame of every trajectory in this instance.

    Each trajectory's own ``iterframes(with_traj_num=True)`` generator is
    exhausted in turn and its items are passed on unchanged.

    Yields:
        tuple: A tuple containing the following:
            - int: The traj_num
            - int: The frame_num
            - encodermap.SingleTraj: An SingleTraj object.

    Examples:

        Import EncoderMap and load an example :obj:`TrajEnsemble`.

        >>> import encodermap as em
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         'https://files.rcsb.org/view/1YUG.pdb',
        ...         'https://files.rcsb.org/view/1YUF.pdb',
        ...     ],
        ... )
        >>> print(trajs.n_frames)
        31

        Subsample every tenth frame.

        >>> trajs = trajs.subsample(10)
        >>> trajs.n_frames
        4

        Call the :meth:`iterframes` method.

        >>> for traj_num, frame_num, frame in trajs.iterframes():
        ...     print(traj_num, frame_num, frame.n_frames)
        0 0 1
        0 10 1
        1 0 1
        1 10 1

    """
    for single in self:
        for frame_item in single.iterframes(with_traj_num=True):
            yield frame_item
def __copy__(self):
    """Return a shallow copy: a new instance sharing all attribute values.

    The new object is created without calling ``__init__``; its instance
    dictionary is then filled from this instance's dictionary.
    """
    clone = self.__class__.__new__(self.__class__)
    vars(clone).update(vars(self))
    return clone
def __deepcopy__(self, memo):
    """Return a deep copy of this instance.

    Args:
        memo (dict): The memo dictionary of :func:`copy.deepcopy`. The new
            instance is registered in it before the attributes are copied,
            so self-references resolve to the copy.

    Returns:
        A new instance of the same class whose attributes are deep copies.
    """
    # Standard Library Imports
    import copy

    duplicate = self.__class__.__new__(self.__class__)
    memo[id(self)] = duplicate
    for name, value in self.__dict__.items():
        setattr(duplicate, name, copy.deepcopy(value, memo))
    return duplicate
def __getitem__(self, key: CanBeIndex) -> TrajEnsemble:
    """Index the ensemble by trajectory or by (trajectory, frame) pairs.

    Args:
        key: One of the following:
            - An integer (Python ``int`` or any NumPy integer scalar):
              returns ``self.trajs[key]``, i.e. a single trajectory.
            - A flat list or a 1D array: forwarded to
              ``self._return_trajs_by_index``.
            - A list of lists or a 2D array: forwarded to
              ``self._pyemma_indexing_no_tsel``.
            - A slice: resolved against ``self.n_trajs`` and then handled
              like a flat list.

    Returns:
        The selected trajectory or a new ensemble, depending on ``key``.

    Raises:
        IndexError: If ``key`` is an array with a dimensionality other than
            1 or 2, an empty list, or of an unsupported type.
    """
    # np.integer covers every NumPy integer width, not only int32/int64
    if isinstance(key, (int, np.integer)):
        return self.trajs[key]
    if isinstance(key, slice):
        start, stop, step = key.indices(self.n_trajs)
        return self[list(range(start, stop, step))]
    if isinstance(key, np.ndarray):
        if key.ndim == 1:
            return self._return_trajs_by_index(key)
        if key.ndim == 2:
            return self._pyemma_indexing_no_tsel(key)
        raise IndexError(
            f"Passing a key with more than 2 dimensions makes no sense. "
            f"One dim for trajs, one for frames. Your key has "
            f"{key.ndim} dimensions."
        )
    if isinstance(key, list):
        # key[0] raises IndexError for an empty list, as before
        if not isinstance(key[0], list):
            return self._return_trajs_by_index(key)
        if all(isinstance(k, list) for k in key):
            return self[np.asarray(key)]
    raise IndexError(f"Invalid argument for slicing: {key=}")
def __reversed__(self):
    """Reverse iteration is not supported; always raises NotImplementedError."""
    raise NotImplementedError
def __eq__(self, other):
    """Compare two ensembles by trajectory files, frame ids and CVs.

    Two ensembles are equal when they have the same length, the same
    number of trajectories, pairwise identical ``traj_file`` and ``id``
    attributes, and when ``self._CVs.equals(other._CVs)`` holds.

    Args:
        other: The object to compare with.

    Returns:
        bool: Whether both ensembles are equal. ``NotImplemented`` is
        returned when ``other`` is not an instance of this class, so that
        Python can fall back to its default comparison instead of raising.
    """
    if not isinstance(other, self.__class__):
        return NotImplemented
    # check if traj_files and ids are the same
    if len(self) != len(other):
        return False
    # zip() stops at the shorter input, so differing trajectory counts
    # have to be ruled out explicitly
    if len(self.trajs) != len(other.trajs):
        return False
    pairs = list(zip(self.trajs, other.trajs))
    if not all(traj1.traj_file == traj2.traj_file for traj1, traj2 in pairs):
        return False
    if not all(np.array_equal(traj1.id, traj2.id) for traj1, traj2 in pairs):
        return False
    # the CV comparison is the most expensive check and therefore runs last
    return self._CVs.equals(other._CVs)
def __iter__(self):
    """Return an independent iterator over the trajectories.

    A fresh iterator over ``self.trajs`` is returned for every call, so
    nested or interleaved loops over the same ensemble do not disturb each
    other. ``self._index`` is still reset to 0, which keeps direct
    ``next(ensemble)`` calls after ``iter(ensemble)`` starting at the first
    trajectory.
    """
    self._index = 0
    return iter(self.trajs)
def __next__(self):
    """Return the trajectory at the current position and advance by one.

    Raises:
        StopIteration: Once ``self._index`` has reached ``self.n_trajs``.
    """
    position = self._index
    if position >= self.n_trajs:
        raise StopIteration
    self._index = position + 1
    return self.trajs[position]
def __radd__(self, y):
    """Reflected addition, needed for the builtin ``sum()``.

    ``sum()`` starts with the integer 0; any integer on the left side is
    therefore ignored and this instance is returned unchanged. Everything
    else is delegated to ``__add__``.
    """
    return self if isinstance(y, int) else self.__add__(y)
def __add__(self, y):
    """Join two ensembles along the traj axis into a new ensemble.

    If the backends differ, the trajectories of both operands are loaded
    first. The ``trajs`` of the result are the concatenated ``trajs`` of
    both operands.

    Args:
        y: The ensemble on the right-hand side of the addition.

    Returns:
        A new instance of this class containing the trajectories of both.

    Raises:
        Exception: If ``y`` holds no trajectories, or if both operands
            share at least one ``traj_num``.
    """
    # the right-hand side must contribute at least one trajectory
    if len(y.trajs) == 0:  # pragma: nocover
        raise Exception(
            f"The `TrajEnsemble` {y} does not contain any trajs and can't "
            f"be used in addition."
        )

    # settle on a common backend
    if self.backend != y.backend:
        print("Mismatch between the backends. Using 'mdtraj'.")
        y.load_trajs()
        self.load_trajs()

    shared_nums = set(self.traj_nums) & set(y.traj_nums)
    if shared_nums:
        raise Exception(
            f"Can't add two `TrajEnsemble` with overlapping traj_nums: "
            f"left side: {self.traj_nums}\n"
            f"right side: {y.traj_nums}"
        )

    merged_trajs = self.trajs + y.trajs
    merged_tops = self._top_files + y._top_files
    if merged_tops:
        # collect the common_str entries of both sides, duplicates removed
        # while keeping their order
        left_common = get_full_common_str_and_ref(
            self.traj_files, self._top_files, self.common_str
        )[2]
        right_common = get_full_common_str_and_ref(
            y.traj_files, y._top_files, y.common_str
        )[2]
        common_str_ = list(dict.fromkeys(left_common + right_common))
        joined = self.__class__(
            self.traj_files + y.traj_files,
            merged_tops,
            traj_nums=self.traj_nums + y.traj_nums,
            backend=self.backend,
            common_str=common_str_,
        )
        # assign the trajs directly, so the frames of the SingleTraj
        # objects are preserved
        joined.trajs = merged_trajs
    else:
        # neither side carries topology files: build from the trajs
        joined = self.__class__(merged_trajs, backend=self.backend)

    return joined
3379 def __getattr__(self, attr: str):
3380 if attr in self.CVs:
3381 return self.CVs[attr]
3382 else:
3383 return self.__getattribute__(attr)
def _string_summary(self) -> str:
    """Build the one-line description used by ``__str__`` and ``__repr__``.

    The frame count is only reported when the ``trajectory`` attribute of
    every trajectory is truthy. The common string (if set) and one entry
    per CV with its shape follow.
    """
    all_loaded = all([single.trajectory for single in self.trajs])
    pieces = [f"encodermap.TrajEnsemble object. Current backend is {self.backend}."]
    if all_loaded:
        pieces.append(
            f" Containing {self.n_frames} frames and {self.n_trajs} trajectories."
        )
    else:
        pieces.append(f" Containing {self.n_trajs} trajectories.")
    if self.common_str:
        pieces.append(f" Common str is {self.common_str}.")
    if self.CVs:
        pieces.extend(
            f" CV {key} with shape {value.shape} loaded."
            for key, value in self.CVs.items()
        )
    else:
        pieces.append(" Not containing any CVs.")
    return "".join(pieces)
def __len__(self) -> int:
    """Return ``self.n_frames`` (the length counts frames, not trajectories)."""
    frame_count = self.n_frames
    return frame_count
def __str__(self) -> str:
    """Return the summary produced by ``_string_summary``."""
    summary = self._string_summary()
    return summary
def __repr__(self) -> str:
    """Return the summary wrapped in ``<...>`` together with the object's id in hex."""
    summary = self._string_summary()
    address = format(id(self), "02x")
    return "<" + summary + " Object at 0x" + address + ">"