Coverage for encodermap/loading/featurizer.py: 53%
395 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/loading/featurizer.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Classes to be used as custom features with pyemma add_custom_feature
24ToDo:
25 * Write Docstrings.
26 * Write Examples.
27 * Sidechain angles, distances not working correctly.
28"""
31################################################################################
32# Imports
33################################################################################
36from __future__ import annotations
38import numpy as np
39import pandas as pd
41from .._optional_imports import _optional_import
42from ..loading import features
43from ..misc.misc import FEATURE_NAMES, _validate_uri
44from ..misc.xarray import get_indices_by_feature_dim, unpack_data_and_feature
45from ..trajinfo.info_all import TrajEnsemble
46from ..trajinfo.info_single import SingleTraj
48################################################################################
49# Optional Imports
50################################################################################
53featurizer = _optional_import("pyemma", "coordinates.featurizer")
54source = _optional_import("pyemma", "coordinates.source")
55xr = _optional_import("xarray")
56CHI1_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI1_ATOMS")
57CHI2_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI2_ATOMS")
58CHI3_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI3_ATOMS")
59CHI4_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI4_ATOMS")
60CHI5_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI5_ATOMS")
61Client = _optional_import("dask", "distributed.Client")
64################################################################################
65# Typing
66################################################################################
69from typing import TYPE_CHECKING
71if TYPE_CHECKING:
72 import xarray as xr
73 from dask.distributed import Client
74 from mdtraj.geometry.dihedral import (
75 CHI1_ATOMS,
76 CHI2_ATOMS,
77 CHI3_ATOMS,
78 CHI4_ATOMS,
79 CHI5_ATOMS,
80 )
81 from pyemma.coordinates import featurizer, source
82 from pyemma.coordinates.data._base.datasource import DataSource
83 from pyemma.coordinates.data.featurization._base import Feature
86################################################################################
87# Globals
88################################################################################
__all__ = ["Featurizer"]


# Maps the snake_case aliases accepted by `add_list_of_feats` to the
# attribute names that are looked up on the `features` module.
# NOTE: the identifier is misspelled ("UNDERSOCRE"). It is kept unchanged,
# because other code in this module refers to it by exactly this name.
UNDERSOCRE_MAPPING = {
    "central_dihedrals": "CentralDihedrals",
    "all_cartesians": "AllCartesians",
    "all_distances": "AllBondDistances",
    "central_cartesians": "CentralCartesians",
    "central_distances": "CentralBondDistances",
    "central_angles": "CentralAngles",
    "side_cartesians": "SideChainCartesians",
    "side_distances": "SideChainBondDistances",
    "side_angles": "SideChainAngles",
    "side_dihedrals": "SideChainDihedrals",
}
108##############################################################################
109# Utils
110##############################################################################
112##############################################################################
113# Classes
114##############################################################################
class Featurizer(type):
    """Factory that hands out the featurizer matching the requested backend.

    Calling ``Featurizer(trajs)`` does not create an instance of this class.
    ``__new__`` returns a `PyEMMAFeaturizer` (``in_memory=True``) or a
    `DaskFeaturizer` (``in_memory=False``) built from `trajs` instead.

    """

    def __new__(cls, trajs, in_memory=True):
        """Build and return the backend featurizer for `trajs`.

        Args:
            trajs: Passed unchanged to the chosen featurizer class.
            in_memory (bool): Truthy selects `PyEMMAFeaturizer`, falsy
                imports and selects `DaskFeaturizer`. Defaults to True.

        Returns:
            The object created by calling the chosen class with `trajs`.

        """
        if not in_memory:
            # Imported inside the branch, so the dask-based module is only
            # loaded when it is actually requested.
            from .dask_featurizer import DaskFeaturizer

            return DaskFeaturizer(trajs)
        return PyEMMAFeaturizer(trajs)
def format_output(
    inps: list[DataSource],
    feats: list[Feature],
    trajs: list[TrajEnsemble],
) -> list[tuple[np.ndarray, object, object]]:
    """Formats the output of multiple topologies.

    Iterates over the features in `feats` and looks for the feature
    with the greatest dimension, i.e. the longest returned describe. This
    feature yields the column names, the non-defined values are np.nan.
    The labels of all 'SideChainDihedrals' features are merged instead
    (unique labels, sorted by the integers parsed from the label text).

    Args:
        inps (list[DataSource]): The list of inputs, that
            return the values of the feats, when `get_output()` is called.
        feats (list[encodermap.loading.Featurizer]): These featurizers collect
            the features (read via `feat.feat.active_features`) and will be
            used to determine the highest length of feats.
        trajs (list[encodermap.trajinfo.TrajEnsemble]): List of trajs with
            identical topologies. Entry `n` belongs to `inps[n]` and
            `feats[n]`.

    Returns:
        list[tuple]: One ``(data, feat_out, traj)`` tuple per trajectory.
            `data` is a 2D array with one row per frame of that trajectory
            and one column per consolidated label. Columns, that this
            trajectory's featurizer does not describe are np.nan.
            `feat_out` is a single shared object whose `features` attribute
            holds one `EmptyFeature` per feature class.

    Raises:
        Exception: If the consolidated labels and the collected features
            disagree in length, if a feature does not use generic labels,
            or if the resulting arrays differ in their number of columns.

    """

    class Featurizer_out:
        pass

    # append to this
    all_out = []

    feat_out = Featurizer_out()
    feat_out.features = []
    max_feat_lengths = {}
    labels = {}
    for feat in feats:
        for i, f in enumerate(feat.feat.active_features):
            name = f.__class__.__name__
            description = f.describe()

            if name not in max_feat_lengths:
                max_feat_lengths[name] = 0
                feat_out.features.append(
                    EmptyFeature(name, len(description), description, f.indexes)
                )

            if name == "SideChainDihedrals":
                # sidechain labels of all topologies are collected and
                # de-duplicated further below
                labels.setdefault(name, []).extend(description)
            elif max_feat_lengths[name] < len(description):
                # the longest description of a feature class wins
                max_feat_lengths[name] = len(description)
                labels[name] = description
                feat_out.features[i] = EmptyFeature(
                    name, len(description), description, f.indexes
                )

    # rejig the sidechain labels
    side_key = "SideChainDihedrals"
    if side_key in labels:
        labels[side_key] = np.unique(labels[side_key])
        labels[side_key] = sorted(
            labels[side_key], key=lambda x: (int(x[-3:]), int(x[13]))
        )
        index_of_sidechain_dihedral_features = [
            f.name == side_key for f in feat_out.features
        ].index(True)
        new_empty_feat = EmptyFeature(
            side_key,
            len(labels[side_key]),
            labels[side_key],
            None,
        )
        feat_out.features[index_of_sidechain_dihedral_features] = new_empty_feat

    for v, f in zip(labels.values(), feat_out.features):
        if not len(v) == len(f.describe()) == f._dim:
            raise Exception(
                f"Could not consolidate the features of the {f.name} "
                f"feature. The `labels` dict, which dictates the size "
                f"of the resulting array with np.nan's defines a shape "
                f"of {len(v)}, but the feature defines a shape of {len(f.describe())} "
                f"(or `f._dim = {f._dim}`). The labels dict gives these labels:\n\n{v}"
                f"\n\n, the feature labels gives these labels:\n\n{f.describe()}."
            )

    # flatten the labels. These will be the columns for a pandas dataframe.
    # At the start the dataframe will be full of np.nan.
    # The values of inp.get_output() will then be used in conjunction with
    # The labels of the features to fill this dataframe partially
    flat_labels = [item for sublist in labels.values() for item in sublist]
    if not len(flat_labels) == sum([f._dim for f in feat_out.features]):
        raise Exception(
            f"The length of the generic CV labels ({len(flat_labels)} "
            f"does not match the length of the labels of the generic features "
            f"({[f._dim for f in feat_out.features]})."
        )

    # iterate over the sorted trajs, inps, and feats
    for inp, feat, sub_trajs in zip(inps, feats, trajs):
        # make a flat list for this specific feature space
        describe_this_feature = []
        for f in feat.feat.active_features:
            # make sure generic labels are used
            if f.describe.__func__.__name__ != "generic_describe":
                raise Exception(
                    f"It seems like this feature: {f.__class__} does not return generic "
                    f"feature names but topology-specifc ones (generic: 'SIDECHDIH CHI1 1', "
                    f"topology specific: 'SIDECHDIH CHI1 ASP1'). Normally, encodermap's "
                    f"features can be instantiated with a `generic_labels=True` flag to "
                    f"overwrite the features `describe()` method with a `generic_describe()` "
                    f"method. This changes the `.__func__.__name__` of the `describe()` method "
                    f"to 'generic_describe'. However the func name for this feature is "
                    f"{f.describe.__func__.__name__}."
                )
            describe_this_feature.extend(f.describe())

        # Use the output to fill a pandas dataframe with all labels.
        # `get_output()` yields one array per trajectory. Every trajectory
        # gets its own frames only. Stacking all arrays first would hand the
        # frames of the whole sub-ensemble to each of its trajectories.
        for traj_out, traj in zip(inp.get_output(), sub_trajs):
            traj_out = np.asarray(traj_out)
            df = pd.DataFrame(
                np.nan, index=range(len(traj_out)), columns=flat_labels
            )
            df = df.assign(
                **{k: v for k, v in zip(describe_this_feature, traj_out.T)}
            )
            all_out.append((df.to_numpy(), feat_out, traj))

    # make sure the shapes of all df matches
    shapes = [o[0].shape[1] for o in all_out]
    if not len(set(shapes)) == 1:
        raise Exception(
            f"Alignment was not possible. Some values exhibit different shapes: "
            f"{list(set(shapes))}. All shapes:\n\n{shapes}"
        )
    return all_out
class PyEMMAFeaturizer:
    """In-memory featurizer that wraps PyEMMA's featurizer and data source.

    Assigning `trajs` (done in `__init__`) selects one of two modes:

    * ``"single_top"``: `self.feat` is one PyEMMA featurizer and
      `self.inp` one PyEMMA source.
    * ``"multiple_top"``: The trajectories are grouped by topology.
      `self.sorted_trajs`, `self.feat` and `self.inp` hold one entry per
      topology in `self.top`.

    """

    # Names of the PyEMMA methods whose docstrings are copied onto the
    # same-named wrapper methods of this class.
    _PYEMMA_ADD_METHODS = (
        "add_all",
        "add_selection",
        "add_distances",
        "add_distances_ca",
        "add_inverse_distances",
        "add_contacts",
        "add_residue_mindist",
        "add_group_COM",
        "add_residue_COM",
        "add_group_mindist",
        "add_angles",
        "add_dihedrals",
        "add_backbone_torsions",
        "add_chi1_torsions",
        "add_sidechain_torsions",
        "add_minrmsd_to_ref",
    )

    def __init__(self, trajs):
        """Instantiate the featurizer.

        Args:
            trajs: A `SingleTraj` or a `TrajEnsemble`. See the `trajs` setter.

        """
        self.trajs = trajs
        self._copy_docstrings_from_pyemma()

    def _copy_docstrings_from_pyemma(self):
        """Copy the docstrings of the wrapped featurizer onto the wrappers."""
        # in multiple_top mode `self.feat` is a list, the first entry is used
        feat_ = self.feat[0] if isinstance(self.feat, list) else self.feat
        for method_name in self._PYEMMA_ADD_METHODS:
            wrapper = getattr(self, method_name)
            # `__doc__` of a bound method is read-only, so the underlying
            # function object is written to
            wrapper.__func__.__doc__ = getattr(feat_, method_name).__doc__

    def _raise_if_multiple_top(self):
        """Raise if the PyEMMA `add_x` wrappers are used in multiple_top mode."""
        if self.mode == "multiple_top":
            raise Exception(
                "Using PyEMMA's `add_x` functions is not possible when TrajEnsemble contains multiple topologies."
            )

    def get_output(self) -> xr.Dataset | None:
        """Compute the loaded features and return them as an xarray Dataset.

        Returns:
            xr.Dataset | None: The features of all trajectories along the
                'traj_num' dimension. None is returned (after printing a
                notice) when no features have been added yet.

        Raises:
            Exception: If the combined dataset of a single-topology ensemble
                does not contain as many 'traj_num' entries as trajectories.

        """
        if self.mode == "single_top":
            if len(self.feat.active_features) == 0:
                print("No features loaded. No output will be returned")
                return

        if self.mode == "multiple_top":
            if len(self.feat[0].features) == 0:
                print("No features loaded. No output will be returned")
                return

        if self.mode == "single_top":
            datasets = [
                unpack_data_and_feature(self, traj, traj_out)
                for traj, traj_out in zip(self.trajs, self.inp.get_output())
            ]

            if len(datasets) == 1:
                assert datasets[0].coords["traj_num"] == np.array(
                    [self.trajs[0].traj_num]
                )
                return datasets[0]

            out = xr.combine_nested(datasets, concat_dim="traj_num")
            if (
                len(out.coords["traj_num"]) != len(self.trajs)
                and len(out.coords["traj_num"]) != self.trajs.n_trajs
            ):
                raise Exception(
                    f"The combineNnested xarray method returned "
                    f"a bad dataset, which has {out.coords['traj_num']} "
                    f"trajectories, but the featurizer has {self.trajs} "
                    f"trajectories."
                )
        else:
            out = format_output(self.inp, self.feat, self.sorted_trajs)
            datasets = [unpack_data_and_feature(o[1], o[2], o[0]) for o in out]
            out = xr.concat(datasets, dim="traj_num")

        return out

    def add_list_of_feats(self, which="all"):
        """Adds features to the Featurizer to be loaded either in-memory or out-of-memory.

        `which` can be either 'all' or a list of the following strings. The
        snake_case aliases defined in `UNDERSOCRE_MAPPING` are accepted as well.
        'all' adds 'CentralCartesians', 'CentralBondDistances', 'CentralAngles',
        'CentralDihedrals' and 'SideChainDihedrals':

            * 'AllCartesians': Cartesian coordinates of all atoms with shape (n_frames, n_atoms, 3).
            * 'AllBondDistances': Bond distances of all bonds recognized by mdtraj. Use top = md.Topology.from_openmm()
                if mdtraj does not recognize all bonds.
            * 'CentralCartesians': Cartesians of the N, C, CA atoms in the backbone with shape (n_frames, n_residues * 3, 3).
            * 'CentralBondDistances': The bond distances of the N, C, CA bonds with shape (n_frames, n_residues * 3 - 1).
            * 'CentralAngles': The angles between the backbone bonds with shape (n_frames, n_residues * 3 - 2).
            * 'CentralDihedrals': The dihedrals between the backbone atoms (omega, phi, psi). With shape (n_frames,
                n_residues * 3 - 3).
            * 'SideChainCartesians': Cartesians of the sidechain-atoms. Starting with CB, CG, ...
            * 'SideChainBondDistances': Bond distances between the sidechain atoms. starting with the CA-CG bond.
            * 'SideChainAngles': Angles between sidechain atoms. Starting with the C-CA-CB angle.
            * 'SideChainDihedrals': Dihedrals of the sidechains (chi1, chi2, chi3).

        Args:
            which (Union[str, list], optional). Either add 'all' features or a list of features. See Above for
                possible features. Defaults to 'all'.

        """
        if isinstance(which, str):
            if which == "all":
                which = [
                    "CentralCartesians",
                    "CentralBondDistances",
                    "CentralAngles",
                    "CentralDihedrals",
                    "SideChainDihedrals",
                ]
        if not isinstance(which, list):
            which = [which]
        if self.mode == "single_top":
            for cf in which:
                if cf in UNDERSOCRE_MAPPING:
                    cf = UNDERSOCRE_MAPPING[cf]
                feature = getattr(features, cf)(self.top)
                self.feat.add_custom_feature(feature)
        else:
            for cf in which:
                if cf in UNDERSOCRE_MAPPING:
                    cf = UNDERSOCRE_MAPPING[cf]
                # one feature instance per topology, using generic labels
                for top, feat in zip(self.top, self.feat):
                    feature = getattr(features, cf)(top, generic_labels=True)
                    feat.add_custom_feature(feature)

    # The `add_*` wrappers below forward to the PyEMMA featurizer and raise
    # in multiple_top mode. Their docstrings are filled in at instantiation
    # by `_copy_docstrings_from_pyemma`.

    def add_all(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_all(*args, **kwargs)

    def add_selection(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_selection(*args, **kwargs)

    def add_distances(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_distances(*args, **kwargs)

    def add_distances_ca(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_distances_ca(*args, **kwargs)

    def add_inverse_distances(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_inverse_distances(*args, **kwargs)

    def add_contacts(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_contacts(*args, **kwargs)

    def add_residue_mindist(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_residue_mindist(*args, **kwargs)

    def add_group_COM(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_group_COM(*args, **kwargs)

    def add_residue_COM(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_residue_COM(*args, **kwargs)

    def add_group_mindist(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_group_mindist(*args, **kwargs)

    def add_angles(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_angles(*args, **kwargs)

    def add_dihedrals(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_dihedrals(*args, **kwargs)

    def add_backbone_torsions(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_backbone_torsions(*args, **kwargs)

    def add_chi1_torsions(self, *args, **kwargs):
        self._raise_if_multiple_top()
        # forwarded to the sidechain torsions, restricted to chi1
        self.feat.add_sidechain_torsions(which=["chi1"], *args, **kwargs)

    def add_sidechain_torsions(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_sidechain_torsions(*args, **kwargs)

    def add_minrmsd_to_ref(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_minrmsd_to_ref(*args, **kwargs)

    def add_custom_feature(self, feature):
        """Forward `feature` to `self.feat.add_custom_feature`."""
        self.feat.add_custom_feature(feature)

    @property
    def features(self):
        """The active features (single_top) or one entry per featurizer (multiple_top)."""
        if self.mode == "single_top":
            return self.feat.active_features
        else:
            return [f.features for f in self.feat]

    @property
    def sorted_info_single(self):
        """list: All trajectories of `self.sorted_trajs` in one flat list.

        Raises:
            Exception: In single_top mode.

        """
        if self.mode == "single_top":
            raise Exception(
                "Attribute is only accessible, when working with mutliple topologies."
            )
        return [traj for info_all in self.sorted_trajs for traj in info_all]

    @property
    def sorted_featurizers(self):
        """list: The featurizer of every sub-ensemble, repeated `n_trajs` times.

        Raises:
            Exception: In single_top mode.

        """
        if self.mode == "single_top":
            raise Exception(
                "Attribute is only accessible, when working with mutliple topologies."
            )
        out = []
        for feat, info_all in zip(self.feat, self.sorted_trajs):
            out.extend([feat] * info_all.n_trajs)
        return out

    @property
    def trajs(self):
        """The trajectories as stored by the `trajs` setter."""
        return self._trajs

    @trajs.setter
    def trajs(self, trajs):
        """Store `trajs` and build the matching featurizer(s) and source(s).

        Sets `self._trajs`, `self.top`, `self.feat`, `self.inp` and
        `self.mode` (plus `self.sorted_trajs` for multiple topologies).

        Raises:
            TypeError: If `trajs` is neither a `SingleTraj` nor a
                `TrajEnsemble`.

        """
        if isinstance(trajs, SingleTraj) or trajs.__class__.__name__ == "SingleTraj":
            self._trajs = trajs._gen_ensemble()
            self.top = trajs.top
            self.feat = featurizer(self.top)
            if _validate_uri(trajs.traj_file):
                self.inp = source([trajs.xyz], features=self.feat)
            else:
                try:
                    self.inp = source([trajs.traj_file], features=self.feat)
                except Exception:
                    # show the offending files before re-raising
                    print(trajs.traj_file)
                    print(trajs.top_file)
                    raise
            self.mode = "single_top"
        elif (
            isinstance(trajs, TrajEnsemble)
            or trajs.__class__.__name__ == "TrajEnsemble"
        ):
            if len(trajs.top) > 1:
                self._trajs = trajs
                # self.top = Topologies(trajs.top)
                self.top = trajs.top
                # one sub-ensemble per topology, in the order of `trajs.top`
                self.sorted_trajs = [
                    TrajEnsemble([traj for traj in trajs if traj.top == top])
                    for top in trajs.top
                ]
                self.feat = [Featurizer(t) for t in self.sorted_trajs]
                self.inp = [
                    source([t.traj_file for t in t_subset], features=feat.feat)
                    for t_subset, feat in zip(self.sorted_trajs, self.feat)
                ]
                self.mode = "multiple_top"
            else:
                self._trajs = trajs
                self.top = trajs.top[0]
                self.feat = featurizer(self.top)
                if all([_validate_uri(traj.traj_file) for traj in trajs]):
                    # NOTE(review): the SingleTraj branch passes `xyz` to
                    # `source`. Confirm that `xtc` is the intended attribute.
                    self.inp = source(trajs.xtc, features=self.feat)
                else:
                    self.inp = source(
                        [traj.traj_file for traj in trajs], features=self.feat
                    )
                self.mode = "single_top"
        else:
            # Interpolate the classes themselves. `SingleTraj.__class__` is
            # the metaclass and would not name the expected types.
            raise TypeError(
                f"trajs must be {SingleTraj} or {TrajEnsemble}, you provided {trajs.__class__}"
            )

    def describe(self):
        """Return `self.feat.describe()`."""
        return self.feat.describe()

    def __len__(self):
        """Number of active features (single_top) or of featurizers (multiple_top)."""
        if self.mode == "single_top":
            return len(self.feat.active_features)
        else:
            return len([f.features for f in self.feat])

    def __str__(self):
        if self.mode == "single_top":
            return self.feat.__str__()
        else:
            return ", ".join([f.__str__() for f in self.feat])

    def __repr__(self):
        if self.mode == "single_top":
            return self.feat.__repr__()
        else:
            return ", ".join([f.__repr__() for f in self.feat])
class EmptyFeature:
    """Plain container that stands in for a feature.

    Instances only carry the attributes `name`, `_dim`, `description` and
    `indexes` and offer a `describe()` method. They are used in
    multiple_top mode to describe features whose columns can contain NaN
    values.

    """

    def __init__(self, name, _dim, description, indexes):
        """Store the given values on the instance.

        Args:
            name (str): The name of the feature.
            _dim (int): The feature length, i.e. the size of the feature
                axis in shape=(n_frames, feature).
            description (list of str): The description for every feature.
            indexes: Stored unchanged as `self.indexes`.

        """
        self.name, self._dim, self.description, self.indexes = (
            name,
            _dim,
            description,
            indexes,
        )

    def describe(self):
        """Return the stored `description`."""
        return self.description
class Topologies:
    """Compares several topologies residue by residue and aligns features.

    `self.alignments` maps an alignment key (e.g. 'side_dihedrals') to a
    dict, that is filled per residue by `compare_tops`. Iterating over an
    instance yields the entries of `self.tops`.

    """

    def __init__(self, tops, alignments=None):
        """Store `tops`, compare them and validate `alignments`.

        Args:
            tops: Sequence of topologies. Stored as `self.tops`.
            alignments (list of str, optional): Keys for `self.alignments`.
                If None, a default list of five keys is used.

        Raises:
            Exception: If an entry of `alignments` is not one of the
                `UNDERSOCRE_MAPPING` keys containing "side".

        """
        self.tops = tops
        if alignments is None:
            alignments = [
                "side_dihedrals",
                "central_cartesians",
                "central_distances",
                "central_angles",
                "central_dihedrals",
            ]
        self.alignments = {k: {} for k in alignments}
        # NOTE(review): the comparison runs before `alignments` is validated.
        self.compare_tops()
        # only the keys containing "side" are accepted
        allowed_strings = list(
            filter(
                lambda x: True if "side" in x else False,
                (k for k in UNDERSOCRE_MAPPING.keys()),
            )
        )
        # NOTE(review): the default `alignments` above lists "central_*" keys,
        # which are not part of `allowed_strings`, so the default fails this
        # check. Confirm the intended set of allowed keys.
        if not all([i in allowed_strings for i in alignments]):
            raise Exception(
                f"Invalid alignment string in `alignments`. Allowed strings are {allowed_strings}"
            )

    def compare_tops(self):
        """Fill `self.alignments['side_dihedrals']` for every residue index.

        For residue `i` a boolean array with one row per chi angle (chi1 to
        chi5) and one column per topology is stored under f"residue_{i}".
        An entry is True, if the atom names of that residue contain all atom
        names of at least one entry of the respective `CHI*_ATOMS` constant.

        Raises:
            Exception: If the topologies differ in their number of residues.
            NotImplementedError: If 'side_cartesians', 'side_distances' or
                'side_angles' is a key of `self.alignments`.

        """
        if not all([t.n_residues == self.tops[0].n_residues for t in self.tops]):
            raise Exception(
                "Using Different Topologies currenlty only works if all contain the same number of residues."
            )
        # one residue iterator per topology, advanced in lockstep below
        generators = [t.residues for t in self.tops]
        sidechains = [t.select("sidechain") for t in self.tops]
        all_bonds = [
            list(map(lambda x: (x[0].index, x[1].index), t.bonds)) for t in self.tops
        ]

        # iterate over residues of the sequences
        n_res_max = max([t.n_residues for t in self.tops])
        for i in range(n_res_max):
            # get some info
            residues = [next(g) for g in generators]
            all_atoms = [[a.name for a in r.atoms] for r in residues]
            # atoms of the sidechain selection without "H" or "OXT" in the name
            atoms = [
                list(
                    filter(
                        lambda x: True
                        if x.index in sel and "H" not in x.name and "OXT" not in x.name
                        else False,
                        r.atoms,
                    )
                )
                for r, sel in zip(residues, sidechains)
            ]
            atoms_indices = [[a.index for a in atoms_] for atoms_ in atoms]
            # bonds, that involve at least one of these atoms
            bonds = [
                list(
                    filter(
                        lambda bond: True if any([b in ai for b in bond]) else False, ab
                    )
                )
                for ai, ab in zip(atoms_indices, all_bonds)
            ]

            # reduce the integers of atoms_indices and bonds, so that N is 0. That way, we can compare them, even, when
            # two amino acids in the chains are different
            N_indices = [
                list(filter(lambda x: True if x.name == "N" else False, r.atoms))[
                    0
                ].index
                for r in residues
            ]

            # align to respective N
            # NOTE: `atoms_indices` and `bonds` are computed here, but not
            # used by the remaining statements of this method.
            atoms_indices = [
                [x - N for x in y] for y, N in zip(atoms_indices, N_indices)
            ]
            bonds = [
                [(x[0] - N, x[1] - N) for x in y] for y, N in zip(bonds, N_indices)
            ]

            # per topology: does the residue contain all atoms of one of the
            # atom-name lists of the respective chi angle?
            chi1 = [any(set(l).issubset(set(a)) for l in CHI1_ATOMS) for a in all_atoms]
            chi2 = [any(set(l).issubset(set(a)) for l in CHI2_ATOMS) for a in all_atoms]
            chi3 = [any(set(l).issubset(set(a)) for l in CHI3_ATOMS) for a in all_atoms]
            chi4 = [any(set(l).issubset(set(a)) for l in CHI4_ATOMS) for a in all_atoms]
            chi5 = [any(set(l).issubset(set(a)) for l in CHI5_ATOMS) for a in all_atoms]
            chi = np.array([chi1, chi2, chi3, chi4, chi5])

            self.alignments["side_dihedrals"][f"residue_{i}"] = chi

            if "side_cartesians" in self.alignments:
                raise NotImplementedError(
                    "Cartesians between different topologies can currently not be aligned."
                )

            if "side_distances" in self.alignments:
                raise NotImplementedError(
                    "Distances between different topologies can currently not be aligned."
                )

            if "side_angles" in self.alignments:
                raise NotImplementedError(
                    "Angles between different topologies can currently not be aligned."
                )

        self.drop_double_false()

    def drop_double_false(self):
        """Drops features that None of the topologies have.

        For example: Asp and Glu. Asp has a chi1 and chi2 torsion. Glu has chi1, chi2 and chi3. Both
        don't have chi4 or chi5. In self.compare_tops these dihedrals are still considered. In this
        method they will be removed.

        """
        for alignment, value in self.alignments.items():
            for residue, array in value.items():
                # keep the rows, in which at least one topology is True
                where = np.where(np.any(array, axis=1))[0]
                self.alignments[alignment][residue] = array[where]

    def get_max_length(self, alignment):
        """Maximum length that a feature should have given a certain axis.

        The arrays of all residues are stacked and the True entries are
        counted per column. The greatest count is returned.

        Args:
            alignment (str): The key for `self.alignments`.

        Returns:
            The maximum of the per-column counts.

        """
        alignment_dict = self.alignments[alignment]
        stacked = np.vstack([v for v in alignment_dict.values()])
        counts = np.count_nonzero(stacked, axis=0)  # False is 0
        return np.max(counts)

    def format_output(self, inputs, feats, sorted_trajs):
        """Formats the output of an em.Featurizer object using the alignment info.

        Args:
            inputs (list): List of pyemma.coordinates.data.feature_reader.FeatureReader objects.
            feats (list): List of encodermap.Featurizer objects.
            sorted_trajs (list): List of em.TrajEnsemble objects sorted in the same way as `self.tops`.

        Returns:
            list: One ``[new_values, feat, trajs]`` list per topology, where
                `new_values` holds the horizontally stacked values of all
                features.

        """
        out = []
        for i, (inp, top, feat, trajs) in enumerate(
            zip(inputs, self.tops, feats, sorted_trajs)
        ):
            value_dict = {}
            for traj_ind, (data, traj) in enumerate(zip(inp.get_output(), trajs)):
                if any(
                    [isinstance(o, EmptyFeature) for o in feat.feat.active_features]
                ):
                    from ..misc.xarray import add_one_by_one

                    # labels containing "NaN" are the placeholder labels,
                    # that this method writes further below
                    ffunc = lambda x: True if "NaN" not in x else False
                    indices = [0] + add_one_by_one(
                        [len(list(filter(ffunc, f.describe()))) for f in feat.features]
                    )
                else:
                    indices = get_indices_by_feature_dim(feat, traj, data.shape)

                # divide the values returned by PyEMMA
                for f, ind in zip(feat.features, indices):
                    try:
                        name = FEATURE_NAMES[f.name]
                    except KeyError:
                        name = f.__class__.__name__
                        f.name = name
                    except AttributeError:
                        name = f.__class__.__name__
                        f.name = name
                    # start a fresh list with the first trajectory
                    if traj_ind == 0:
                        value_dict[name] = []
                    value_dict[name].append(data[:, ind])

            # stack along the frame axis, just like pyemma would
            value_dict = {k: np.vstack(v) for k, v in value_dict.items()}

            # put nans in all features specified by alignment
            for alignment, alignment_dict in self.alignments.items():
                if alignment not in value_dict:
                    continue
                max_length = self.get_max_length(alignment)
                new_values = np.full(
                    shape=(value_dict[alignment].shape[0], max_length),
                    fill_value=np.nan,
                )
                # column `i` of the stacked arrays belongs to this topology
                where = np.vstack([v for v in alignment_dict.values()])[:, i]
                new_values[:, where] = value_dict[alignment]
                value_dict[alignment] = new_values

                # find the index of the feature in feat.feat.active_features
                names = np.array(
                    [f.__class__.__name__ for f in feat.feat.active_features]
                )
                index = np.where([n in FEATURE_NAMES for n in names])[0]
                index = index[
                    np.where([FEATURE_NAMES[n] == alignment for n in names[index]])
                ]

                # get the old description and change it around
                assert len(index) == 1
                index = index[0]
                if not isinstance(feat.feat.active_features[index], EmptyFeature):
                    old_desc = np.array(
                        [i for i in feat.feat.active_features[index].describe()]
                    )
                    new_desc = np.array(
                        [
                            f"NaN due to ensemble with other topologies {i}"
                            for i in range(max_length)
                        ]
                    )
                    new_desc[where] = old_desc
                    new_desc = new_desc.tolist()

                    # get the old indexes and add the NaNs
                    old_indexes = feat.feat.active_features[index].indexes
                    new_indexes = np.full(
                        shape=(max_length, old_indexes.shape[1]), fill_value=np.nan
                    )
                    new_indexes[where] = old_indexes

                    # create empty feature
                    new_class = EmptyFeature(
                        alignment, max_length, new_desc, new_indexes
                    )
                    feat.feat.active_features[index] = new_class

            new_values = np.hstack([v for v in value_dict.values()])
            out.append([new_values, feat, trajs])
        return out

    def format_output_generator(self, inputs, feats, sorted_trajs):
        """Not implemented. Always raises NotImplementedError."""
        raise NotImplementedError("Will be implemented once tfRecords are implemented.")

    def __iter__(self):
        """Reset the iteration index and return the instance itself."""
        self._index = 0
        return self

    def __next__(self):
        """Return the next entry of `self.tops`, or raise StopIteration."""
        if self._index >= len(self.tops):
            raise StopIteration
        else:
            self._index += 1
            return self.tops[self._index - 1]