Coverage for encodermap/misc/backmapping.py: 7%
743 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/misc/backmapping.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Backmapping functions to create new atomistic conformations from intrinsic
23coordinates.
25"""
27################################################################################
28# Imports
29################################################################################
32# Future Imports at the top
33from __future__ import annotations
35# Standard Library Imports
36import os.path
37from contextlib import contextmanager
38from copy import deepcopy
39from math import pi
40from pathlib import Path
42# Third Party Imports
43import numpy as np
44import tensorflow as tf
45import transformations
46from optional_imports import _optional_import
47from tqdm import tqdm as normal_tqdm
48from tqdm.notebook import tqdm as notebook_tqdm
49from transformations import rotation_matrix as transformations_rotation_matrix
51# Encodermap imports
52from encodermap.loading import features
53from encodermap.misc.misc import _is_notebook
54from encodermap.misc.rotate import _dihedral, _get_near_and_far_networkx
55from encodermap.trajinfo.info_all import TrajEnsemble
56from encodermap.trajinfo.info_single import SingleTraj
59# numba to accelerate
60try:
61 # Third Party Imports
62 from numba import jit
64 _NUMBA_AVAILABLE: bool = True
65except ImportError:
66 _NUMBA_AVAILABLE: bool = False
69################################################################################
70# Optional Imports
71################################################################################
74md = _optional_import("mdtraj")
75xr = _optional_import("xarray")
76mda = _optional_import("MDAnalysis")
77AnalysisFromFunction = _optional_import(
78 "MDAnalysis", "analysis.base.AnalysisFromFunction"
79)
80MemoryReader = _optional_import("MDAnalysis", "coordinates.memory.MemoryReader")
81jit = _optional_import("numba", "jit")
82nb = _optional_import("numba")
83nx = _optional_import("networkx")
86################################################################################
87# Typing
88################################################################################
91# Standard Library Imports
92from collections.abc import Iterator, Sequence
93from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload
96if TYPE_CHECKING:
97 # Standard Library Imports
98 from io import BytesIO
100 # Third Party Imports
101 import MDAnalysis as mda
102 import mdtraj as md
103 import networkx as nx
104 from MDAnalysis.analysis.base import AnalysisFromFunction
105 from MDAnalysis.coordinates.memory import MemoryReader
107 # Local Folder Imports
108 from ..trajinfo import SingleTraj, TrajEnsemble
111################################################################################
112# Globals
113################################################################################
# Public API of this module: only these names are exported by
# ``from encodermap.misc.backmapping import *``. The other functions defined
# here can still be imported explicitly by name.
__all__: list[str] = ["backbone_hydrogen_oxygen_crossproduct", "mdtraj_backmapping"]
119################################################################################
120# Helpers
121################################################################################
@contextmanager
def temp_seed(seed):
    """Temporarily fix the global NumPy random state.

    The global state of ``np.random`` is captured when the ``with`` block is
    entered, the generator is then seeded with ``seed``, and the captured
    state is restored when the block is left, also if it raises.

    Args:
        seed: Any value accepted by :func:`numpy.random.seed`.

    Yields:
        None

    """
    previous_state = np.random.get_state()
    try:
        np.random.seed(seed)
        yield
    finally:
        # Restore whatever state was active before entering the block.
        np.random.set_state(previous_state)
135def _raise_components_exception(
136 components: Sequence[nx.Graph],
137 trajs: TrajEnsemble,
138 top: Optional[md.Topology] = None,
139 remove_component_size: int = 0,
140) -> None:
141 """Raises a descriptive error message for the components upon breaking the
142 connection graph of a protein topology."""
143 smallest_component = sorted(components, key=lambda x: len(x))[0]
144 largest_component = sorted(components, key=lambda x: len(x))[1]
145 if top is None:
146 _str = f"from the provided {trajs.__class__.__name__}"
147 elif isinstance(top, (int, np.int64)):
148 _str = f"from `SingleTraj` number {top} of the provided `TrajEnsemble`."
149 else:
150 _str = f"from the provided {top} file"
151 msg = (
152 f"The protein {_str} is disconnected. Changing dihedrals "
153 f"in multiple disconnected chains is currently not possible. You can also "
154 f"use `em.plot.plot_ball_and_stick(traj highlight='bonds')` to see which "
155 f"bonds are present in the topology. If you are sure "
156 f"your protein is just one chain you can try to load a custom topology "
157 f"or provide a topology with manually fixed bonds. I got {len(components)} "
158 f"disconnected components. The smallest component contains these atoms: "
159 f"{smallest_component=} {largest_component=}. You can try and add "
160 f"these bonds with the `custom_aas` keyword to this function call. If "
161 f"these components contain unwanted residues like solvent or salt-ions, "
162 f"you can set `remove_component_size` to a number representing the "
163 f"sizes of these components (i.e. 3 for three atom water) to remove "
164 f"these components from the trajectory."
165 )
166 if remove_component_size > 0:
167 msg += (
168 f" Your currently chosen `remove_component_size`={remove_component_size} "
169 f"is not large enough to reduce the system to only one connected component."
170 )
171 raise Exception(msg)
174################################################################################
175# Public Functions
176################################################################################
def split_and_reverse_dihedrals(x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
    """Splits dihedrals in BackMapping model into left (reversed) and right part.
    These dihedrals are then used to bring the chain_in_plane into 3D.

    For an even number of dihedrals both parts have the same length. For an
    odd number of dihedrals the left part contains the middle dihedral and is
    one element longer than the right part.

    Args:
        x (tf.Tensor): The dihedrals with shape (None, n_residues * 3 - 3).
            The size of the second dimension has to be known statically.

    Returns:
        tuple[tf.Tensor, tf.Tensor]: The left part (reversed along axis 1)
        and the right part.

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_dihedrals
        >>> import numpy as np

        >>> # create dihedrals for protein with 3 residues, i.e. 3*3 - 3 = 6 central dihedral angles
        >>> # single sample will be used -> shape = (1, 6)
        >>> np.random.seed(20)
        >>> dihedrals = np.random.random((1, 6)) * 2 * np.pi
        >>> print(dihedrals)
        [[3.69533481 5.64050171 5.60165278 5.12605805 0.22550092 4.34644107]]

        >>> dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)
        >>> print(dihedrals_left, dihedrals_right)
        [[5.60165278 5.64050171 3.69533481]] [[5.12605805 0.22550092 4.34644107]]

    """
    n_dihedrals = int(x.shape[1])
    middle = n_dihedrals // 2
    # The parity depends only on the static shape, so a plain Python branch
    # is sufficient. In graph mode this traces only the branch that is needed.
    if n_dihedrals % 2 == 0:
        return x[:, middle - 1 :: -1], x[:, middle:]
    return x[:, middle::-1], x[:, middle + 1 :]
def split_and_reverse_cartesians(x):
    """Split the cartesians at the middle into a reversed left and a right part.

    Both parts overlap by three atoms, because a dihedral is defined by four
    atoms and both halves need the atoms around the split point. The following
    identities hold:
        left[0] = right[2]
        left[1] = right[1]
        left[2] = right[0]

    Args:
        x (tf.Tensor): The cartesians with shape (None, n_residues * 3, 3).

    Returns:
        tuple: The left part (atoms ``0 .. n // 2 + 1`` in reversed order) and
        the right part (atoms ``n // 2 - 1 .. n - 1``), where ``n`` is
        ``x.shape[1]``.

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_cartesians
        >>> import numpy as np

        >>> # create cartesians for protein with 3 residues, i.e. 9
        >>> # single sample will be used -> shape = (1, 9, 3)
        >>> np.random.seed(20)
        >>> cartesians = np.random.random((1, 9, 3)) * 10

        >>> cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)

        >>> print(cartesians_left.shape, cartesians_right.shape)
        (1, 6, 3) (1, 6, 3)

        >>> print(cartesians_left[:,0] == cartesians_right[:,2])
        [[ True True True]]

        >>> print(cartesians_left[:,1] == cartesians_right[:,1])
        [[ True True True]]

        >>> print(cartesians_left[:,2] == cartesians_right[:,0])
        [[ True True True]]

    """
    half = int(x.shape[1]) // 2
    reversed_left = x[:, half + 1 :: -1]
    right = x[:, half - 1 :]
    return reversed_left, right
def dihedrals_to_cartesian_tf_layers(
    dihedrals: tf.Tensor,
    cartesians: tf.Tensor,
    left_iteration_counter: int,
    right_iteration_counter: int,
) -> tf.Tensor:
    """Calculates dihedrals to cartesians in Graph/Layer execution.

    The chain is split in the middle. The left half (reversed) and the right
    half are brought into 3D separately with their respective dihedrals, so
    that the center of the molecule stays on the 2D plane and both ends curl
    into the third dimension.

    Args:
        dihedrals (tf.Tensor): The dihedrals of shape (None, n_residues * 3 - 3).
        cartesians (tf.Tensor): The cartesians of shape (None, n_residues * 3, 3).
            A rank-2 tensor of shape (n_residues * 3, 3) is also accepted. It
            is repeated once for every row of ``dihedrals``.
        left_iteration_counter (int): The range(left_iteration_counter) for
            iteration over the left split of the cartesians chain. Needs to
            be supplied, because the shape of the SymbolicTensor is not known
            to tensorflow.
        right_iteration_counter (int): The range(right_iteration_counter) for
            iteration over the right split of the cartesians chain. Needs to
            be supplied, because the shape of the SymbolicTensor is not known
            to tensorflow.

    Returns:
        tf.Tensor: The finished 3d chain.

    """
    if len(cartesians.get_shape()) == 2:
        # A single chain of cartesians was passed: repeat it once per sample
        # (row) of the dihedrals.
        n_samples = tf.shape(dihedrals)[0]
        cartesians = tf.tile(tf.expand_dims(cartesians, axis=0), [n_samples, 1, 1])

    # split and reverse so that the center of the molecule stays on the 2D plane
    # and the left and right ends curl into the 3rd dimension
    cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)
    dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)

    new_cartesians_left = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals=dihedrals_left,
        cartesian=cartesians_left,
        n=left_iteration_counter,
    )
    new_cartesians_right = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals=dihedrals_right,
        cartesian=cartesians_right,
        n=right_iteration_counter,
    )

    # Undo the reversal of the left half. The first three atoms of the right
    # half correspond to the last three atoms of the re-reversed left half
    # (the two halves overlap by three atoms), so they are dropped to avoid
    # duplicates.
    new_cartesians = tf.concat(
        [new_cartesians_left[:, ::-1], new_cartesians_right[:, 3:]], axis=1
    )

    return new_cartesians
@jit(nopython=True)
def _displacement_jit(xyz: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Faster nopython numba-jit compiled version of calculating a displacement.

    Args:
        xyz (np.ndarray): Numpy array with shape (n_atoms, 3).
        index (np.ndarray): Integer array whose first two entries are the
            indices of the start atom (``index[0]``) and the end atom
            (``index[1]``). In this module it is a 2-element slice of a
            dihedral's atom indices.

    Returns:
        np.ndarray: The displacement ``xyz[index[1]] - xyz[index[0]]``. For a
        1D ``index`` this is a single vector of shape (3,).

    """
    return xyz[index[1]] - xyz[index[0]]
@jit(nopython=True)
def _dihedral_jit(
    xyz: np.ndarray,
    indices: np.ndarray,
) -> float:
    """Faster nopython numba-jit compiled implementation of calculating a dihedral.

    With the bond vectors b1, b2, b3 between the four atoms, the angle is
    ``arctan2(|b2| * b1 . (b2 x b3), (b2 x b3) . (b1 x b2))``.

    Args:
        xyz (np.ndarray): Numpy array with shape (n_atoms, 3).
        indices (np.ndarray): Integer array of length 4 with the indices of the
            four atoms defining a single dihedral.

    Returns:
        float: The dihedral angle in radians, in the range [-pi, pi]
        (result of ``np.arctan2``).

    """
    b1 = _displacement_jit(xyz, indices[0:2])
    b2 = _displacement_jit(xyz, indices[1:3])
    b3 = _displacement_jit(xyz, indices[2:4])
    c1 = np.cross(b2, b3)
    c2 = np.cross(b1, b2)
    p1 = (b1 * c1).sum(-1)
    # scale by |b2| so that both arctan2 arguments have the same units
    p1 *= (b2 * b2).sum(-1) ** 0.5
    p2 = (c1 * c2).sum(-1)
    return np.arctan2(p1, p2)
@jit(nopython=True)
def _rotmat_jit(
    angle: np.float32,
    direction: np.ndarray,
    pivot_point: np.ndarray,
) -> np.ndarray:
    """Same as :func:`rotation_matrix`, but jit compiled.

    Builds a 4x4 homogeneous matrix that rotates by ``angle`` (passed to
    ``np.sin``/``np.cos``, i.e. radians) about an axis parallel to
    ``direction`` that passes through ``pivot_point``. The 3x3 rotation is
    assembled with Rodrigues' formula
    ``R = cos(a) * I + (1 - cos(a)) * u u^T + sin(a) * [u]_x``.

    Args:
        angle (np.float32): The rotation angle.
        direction (np.ndarray): The rotation axis of shape (3,). It is
            normalized internally.
        pivot_point (np.ndarray): A point on the rotation axis of shape (3,).

    Returns:
        np.ndarray: A float32 array of shape (4, 4).

    """
    sin_angle = np.sin(angle)
    cos_angle = np.cos(angle)
    axis = direction / (direction**2).sum() ** 0.5

    rot = np.identity(3, dtype="float32")
    rot *= cos_angle
    rot += np.outer(axis, axis) * (1.0 - cos_angle)
    axis *= sina_placeholder if False else sin_angle
    cross_product_matrix = np.array(
        [
            [0.0, -axis[2], axis[1]],
            [axis[2], 0.0, -axis[0]],
            [-axis[1], axis[0], 0.0],
        ],
        dtype="float32",
    )
    rot += cross_product_matrix

    homogeneous = np.identity(4, dtype="float32")
    homogeneous[:3, :3] = rot
    # translation p - R p keeps the pivot point fixed
    homogeneous[:3, 3] = pivot_point - np.dot(rot, pivot_point)
    return homogeneous
@jit(nopython=True, parallel=True)
def parallel_rotation_application(
    xyz: np.ndarray,
    dihedral_indices: np.ndarray,
    dihedrals: np.ndarray,
    new_and_far_sides: np.ndarray,
) -> None:
    """Rotate the atoms of every frame in place so the dihedrals reach targets.

    For every dihedral ``j`` and every frame ``i``, the current dihedral is
    measured and the movable atoms are rotated by the difference to
    ``dihedrals[i, j]`` about the central bond of the dihedral (atom 1 ->
    atom 2). The pivot point is the second atom of the quadruplet, so that
    it lies on the rotation axis.

    Args:
        xyz (np.ndarray): Coordinates of shape (n_frames, n_atoms, 3). Modified
            in place.
        dihedral_indices (np.ndarray): Integer array of shape (n_dihedrals, 4)
            with the four atom indices of every dihedral.
        dihedrals (np.ndarray): Target dihedral angles (radians) of shape
            (n_frames, n_dihedrals).
        new_and_far_sides (np.ndarray): Boolean array of shape
            (n_dihedrals, n_atoms). For dihedral ``j``, atoms marked True stay
            in place, atoms marked False are rotated.

    """
    n_frames = dihedrals.shape[0]
    n_dihedrals = dihedrals.shape[1]
    # Cannot be parallelized, because the later angles depend on the previous
    # rotations (rotating one bond moves the atoms of the following dihedrals).
    for j in range(n_dihedrals):
        dihedral_index = dihedral_indices[j]
        moving = ~new_and_far_sides[j]
        # Can be parallelized, because every frame can be treated separately.
        for i in nb.prange(n_frames):
            current_angle = _dihedral_jit(xyz[i], dihedral_index)
            angle = dihedrals[i, j] - current_angle
            direction = xyz[i, dihedral_index[2]] - xyz[i, dihedral_index[1]]
            pivot_point = xyz[i, dihedral_index[1]]
            M = _rotmat_jit(angle, direction, pivot_point)
            frame = xyz[i]
            moving_xyz = frame[moving]
            # homogeneous coordinates so the 4x4 matrix also applies the translation
            padded = np.ones((moving_xyz.shape[0], 4), dtype="float32")
            padded[:, :3] = moving_xyz
            frame[moving] = M.dot(padded.T).T[:, :3]
# Typing overload: with ``return_indices=False`` the function returns the
# coordinates as a numpy array together with a BytesIO buffer.
@overload
def _full_backmapping_np(
    feature_description: dict[int, dict[int, Sequence[int]]],
    *inputs: np.ndarray,
    return_indices: Literal[False],
) -> tuple[np.ndarray, BytesIO]: ...


# Typing overload: with ``return_indices=True`` the function additionally
# returns a dict mapping names to numpy index arrays.
@overload
def _full_backmapping_np(
    feature_description: dict[int, dict[int, Sequence[int]]],
    *inputs: np.ndarray,
    return_indices: Literal[True],
) -> tuple[np.ndarray, BytesIO, dict[str, np.ndarray]]: ...
424def _full_backmapping_np(
425 feature_description: dict[int, dict[int, Sequence[int]]],
426 *inputs: np.ndarray,
427 return_indices: bool = False,
428) -> Union[
429 tuple[np.ndarray, BytesIO], tuple[np.ndarray, BytesIO, dict[str, np.ndarray]]
430]:
431 """Also BackMaps sidechains. For that, we need a way to know which
432 distances, angles, dihedrals belong to the backbone, and which belong to
433 a sidechain. In contrast to how cartesians are normally represented in
434 MD trajectories, the cartesians in this class are ordered by first all
435 central cartesians, then all sidechain cartesians. Consider a theoretical
436 protein with three residues: MET-LYS-GLY. The protein's backbone has
437 9 cartesian coordiantes, 8 distances, 7 angles, and 6 dihedrals.
438 Methionine has 3 sidechain angles (ch1, chi2, chi3), Lysine has 4 sidechain
439 angles (chi1, ..., chi4), and Glycine has no sidechain angles. The
440 feature_description dict for this protein should be:
441 {-1: {1: 3, 2: 4, 3: 0}.
443 * Cartesians:
444 The cartesian coordinates of the sidechain atoms is an array with shape
445 `(sum([v + 1 for v in feature_description[-1].values() if v > 0])`. So for
446 this theoretical protein, we have 4 sidechain coordinates for MET and 5 for LYS.
447 The combined cartesians array first contains all backbones, and then follows
448 up with the sidechain positions as defined in the feature_description dict.
449 * Distances:
450 The distances between the sidechain atoms include the CA of the backbone,
451 thus we have `(sum([v + 1 for v in feature_description[-1].values() if v > 0])`
452 sidechain distances. MET has 4 sidechain distances, LYS has 5.
453 * Angles:
454 The angles between the sidechain atoms include the CA and N of the backbone.
455 We have `(sum([v + 1 for v in feature_description[-1].values() if v > 0])`
456 angles. MET has 4 sidechain angles (N-CA-CB, CA-CB-CG, CB-CG-SD, CG-SD-CE)
457 and LYS has 5 sidechain angles.
458 * Dihedrals:
459 The dihedrals between the sidechain atoms include the CA and N of the
460 backbone, so that we have `(sum(list(feature_description[-1].values()))`
461 sidechain dihedrals.
463 Indexing these features is done via boolean arrays. True values are kept
464 stationary. False values are allowed to move. Indexing the central distances
465 can be done with a triangular matrix with 8 rows, corresponding to the
466 8 distances in our example protein: ::
468 METN, METCA, METC, LYSN, LYSCA, LYSC, GLYN, GLYCA, GLYC
469 0 True, False, False, False, False, False, False, False, False
470 1 True, True, False, False, False, False, False, False, False
471 2 True, True, True, False, False, False, False, False, False
472 3 True, True, True, True, False, False, False, False, False
473 4 True, True, True, True, True, False, False, False, False
474 5 True, True, True, True, True, True, False, False, False
475 6 True, True, True, True, True, True, True, False, False
476 7 True, True, True, True, True, True, True, True, False
478 Indexing the sidechain distances can be done with a matrix with small
479 triangular matrices like this. For our example protein we have 9 such
480 indices for the 9 sidechain distances ::
482 METCA, METCB, METCG, METSD, METCE, LYSCA, LYSCB, LYSCG, LYSCD, LYSCE, LYSNZ
483 0 True, False, False, False, False, False, False, False, False, False, False
484 1 True, True, False, False, False, False, False, False, False, False, False
485 2 True, True, True, False, False, False, False, False, False, False, False
486 3 True, True, True, False, False, False, False, False, False, False, False
487 4 True, True, True, True, False, False, False, False, False, False, False
488 5 False, False, False, False, False, True, False, False, False, False, False
489 6 False, False, False, False, False, True, True, False, False, False, False
490 7 False, False, False, False, False, True, True, True, False, False, False
491 8 False, False, False, False, False, True, True, True, True, False, False
492 9 False, False, False, False, False, True, True, True, True, True, False
493 10 False, False, False, False, False, True, True, True, True, True, True
495 However, to keep the CA-atoms from appearing twice, the resulting array needs to
496 ditch the CA columns of the side_distances and add True, where needed. The algorithm
497 for the right side of the central_distances is row 0: all false, every three rows
498 add True for the number of sidechain atoms. The last row is just True. The left
499 side for the sidechain distances (without CA) are just True. We keep the
500 chain immovable for adjusting these lengths. For
501 our protein the resulting array looks like this: ::
503 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17
504 METN, METCA, METC, LYSN, LYSCA, LYSC, GLYN, GLYCA, GLYC, METCB, METCG, METSD, METCE, LYSCB, LYSCG, LYSCD, LYSCE, LYSNZ
505 0 True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False
506 1 True, True, False, False, False, False, False, False, False, True, True, True, True, False, False, False, False, False
507 2 True, True, True, False, False, False, False, False, False, True, True, True, True, False, False, False, False, False
508 3 True, True, True, True, False, False, False, False, False, True, True, True, True, False, False, False, False, False
509 4 True, True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True, True
510 5 True, True, True, True, True, True, False, False, False, True, True, True, True, True, True, True, True, True
511 6 True, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True, True, True
512 7 True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, True
513 -----------------------
514 8 True, True, True, True, True, True, True, True, True, False, False, False, False, True, True, True, True, True
515 9 True, True, True, True, True, True, True, True, True, True, False, False, False, True, True, True, True, True
516 10 True, True, True, True, True, True, True, True, True, True, True, False, False, True, True, True, True, True
517 11 True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True
518 12 True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False
519 13 True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False
520 14 True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False
521 15 True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False
522 16 True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False
524 Indexing the angles is similar to the distances. For an angle, we need
525 to index the left side (including pivot point) and the right side
526 (excluding the pivot point). For the central chain, this can be done by
527 omitting the first row of the distance indices. The sidechain indices can
528 be left as is. They contain the correct indices for rotation. Take the angle
529 METN-METCA-METCB as an example. We need the whole central chain to stay still,
530 while METCB, METCG, METSD, METCE are moving. That is exactly row 8 of
531 the array above. The rotation axis for this chain, which is currently just
532 in 2D is the z-axis: (0, 0, 1).
534 Indexing the dihedrals is similar to the angles. For a dihedral, we need to
535 define a pivot point: The first atom of the bond (the second atom of the four-tuple
536 defining the dihedral). Also, we need a rotation axis, which is the vector
537 connecting the two atoms. Whether the second atom is part of the movable
538 group is trivial, as it won't be rotated because it lies within the rotation axis.
539 The indices for the central dihedrals are just the same indices as the
540 distances omitting the first and last rows of the distance indices. This
541 coincides with the indices for the sidechain atoms. The first dihedral
542 angle of any protein is the psi1 angle between the N-terminal nitrogen
543 N-CA-C-N. In this case, the rotational axis is between CA-C and thus, the
544 sidechain (CB, etc.) needs to stay still for this dihedral. This is exactly
545 what row 1 of our index array describes. The next dihedrals, omega1, and phi1,
546 keep these indices of the sidechains, while the indices increase in the
547 central part of the indices (rows 2 and 3). The next dihedral, psi2 will keep
548 the sidechain dihedrals of the next residue stationary, as now the CA is on
549 the far side of the rotation. The sidechain dihedral rotations are similar,
550 however, the last row for every residue's sidechain must be omitted.
552 Args:
553 feature_description (dict[int, dict[int, Sequence[int]]]): A dictrionary
554 that defines the topological hull of the proteins.
555 *inputs (np.ndarray): The inputs in the order central_distances,
556 central_angles, central_dihedrals, side_distances, side_angles,
557 side_dihedrals.
558 return_indices (bool): Whether to also return a dict of np.ndarray, to be
559 used as indices.
561 Returns:
562 Union[tuple[np.ndarray, BytesIO], tuple[np.ndarray, BytesIO, dict[str, np.ndarray]]]:
563 Either a tuple of the reconstructed coordinates in a numpy array and
564 a BytesIO object containing png data. Or a tuple of the former two,
565 with a index dictionary.
567 """
568 ############################################################################
569 # Init
570 ############################################################################
572 # Imports
573 # Standard Library Imports
574 import itertools
575 from io import BytesIO
577 # Third Party Imports
578 import matplotlib.pyplot as plt
579 from scipy.linalg import block_diag
580 from transformations import rotation_matrix
582 # Encodermap imports
583 from encodermap.misc.rotate import _dihedral
585 # Definitions and Tests
586 n_residues: int = max(list(feature_description[-1].keys()))
587 assert np.array_equal(
588 np.arange(1, n_residues + 1),
589 np.sort(np.asarray(list(feature_description[-1].keys()))),
590 ), (
591 f"Currently the `feature_indices[-1]` dict needs to contain monotonous "
592 f"increasing keys. Starting from 1 {feature_description[-1].keys()=}"
593 )
594 n_sidechains: int = sum([v + 1 for v in feature_description[-1].values() if v > 0])
595 sum_sidechains = sum(list(feature_description[-1].values()))
597 # this can be defined beforehand and then stacked as often, as a batch needs it
598 init_xyz: np.ndarray = np.zeros(
599 shape=(1, n_residues * 3 + n_sidechains, 3),
600 dtype="float32",
601 )
603 # first we create the central_distance indices
604 central_distance_indices = np.tri(
605 N=n_residues * 3 - 1,
606 M=n_residues * 3,
607 k=0,
608 ).astype(bool)
609 right_side_central_distance_indices = [
610 np.full(shape=(1, n_sidechains), fill_value=False, dtype=bool)
611 ]
612 count = 0 # starts at the first atom of the central chan
613 count2 = n_residues * 3 + 1 # starts at the first atom of the sidechain
614 sidechain_cartesians_ind = []
615 sidechain_positions_indices = []
616 central_distance_index_duplets = np.vstack(
617 [
618 np.arange(0, n_residues * 3)[:-1],
619 np.arange(0, n_residues * 3)[1:],
620 ]
621 ).T.tolist()
622 sidechain_distance_index_duplets = []
623 central_angle_index_triplets = np.vstack(
624 [
625 np.arange(0, n_residues * 3)[:-2],
626 np.arange(0, n_residues * 3)[1:-1],
627 np.arange(0, n_residues * 3)[2:],
628 ]
629 ).T.tolist()
630 sidechain_angle_index_triplets = []
631 central_dihedral_index_quadruplets = np.vstack(
632 [
633 np.arange(0, n_residues * 3)[:-3],
634 np.arange(0, n_residues * 3)[1:-2],
635 np.arange(0, n_residues * 3)[2:-1],
636 np.arange(0, n_residues * 3)[3:],
637 ]
638 ).T.tolist()
639 sidechain_dihedral_index_quadruplets = []
641 # iterate over feature_description[-1] to get all indices and the right side
642 # of the central cartesians
643 for i, (residue, n_sidechains_in_residue) in zip(
644 itertools.count(1, 3), feature_description[-1].items()
645 ):
646 if n_sidechains_in_residue == 0:
647 if residue == 1 or residue == n_residues:
648 continue
649 else:
650 right_side_central_distance_indices.append(t)
651 else:
652 sidechain_cartesians_ind.append(
653 np.arange(count, count + n_sidechains_in_residue)
654 )
655 sidechain_positions_indices.append(
656 [i] + np.arange(count2 - 1, count2 + n_sidechains_in_residue).tolist()
657 )
658 for sidechain_i in range(n_sidechains_in_residue + 1):
659 if sidechain_i == 0:
660 # adds CA-CB
661 sidechain_distance_index_duplets.append(
662 [(residue - 1) * 3 + 1, count2 - 1]
663 )
664 # adds N-CA-CB
665 sidechain_angle_index_triplets.append(
666 [(residue - 1) * 3, (residue - 1) * 3 + 1, count2 - 1]
667 )
668 # adds N-CA-CB-CG
669 sidechain_dihedral_index_quadruplets.append(
670 [(residue - 1) * 3, (residue - 1) * 3 + 1, count2 - 1, count2]
671 )
672 elif sidechain_i == 1:
673 # adds CB-CG
674 sidechain_distance_index_duplets.append([count2 - 1, count2])
675 # adds CA-CB-CG
676 sidechain_angle_index_triplets.append(
677 [(residue - 1) * 3 + 1, count2 - 1, count2]
678 )
679 # adds CA-CB-CG-CD
680 if sidechain_i < n_sidechains_in_residue:
681 sidechain_dihedral_index_quadruplets.append(
682 [(residue - 1) * 3 + 1, count2 - 1, count2, count2 + 1]
683 )
684 else:
685 # adds CG-CD and so on
686 sidechain_distance_index_duplets.append(
687 [count2 + sidechain_i - 2, count2 + sidechain_i - 1]
688 )
689 # adds CB-CG-CD and so on
690 sidechain_angle_index_triplets.append(
691 [
692 count2 + sidechain_i - 3,
693 count2 + sidechain_i - 2,
694 count2 + sidechain_i - 1,
695 ]
696 )
697 if sidechain_i < n_sidechains_in_residue:
698 sidechain_dihedral_index_quadruplets.append(
699 [
700 count2 + sidechain_i - 3,
701 count2 + sidechain_i - 2,
702 count2 + sidechain_i - 1,
703 count2 + sidechain_i,
704 ]
705 )
706 count += n_sidechains_in_residue + 1
707 count2 += n_sidechains_in_residue + 1
708 t = np.zeros(
709 shape=(3, n_sidechains),
710 dtype=bool,
711 )
712 t[:, :count] = True
713 right_side_central_distance_indices.append(t)
714 assert len(sidechain_angle_index_triplets) == n_sidechains
715 assert len(sidechain_dihedral_index_quadruplets) == sum_sidechains, (
716 f"I could not reconstruct the correct number of sidechain dihedral "
717 f"quadruplets. The number of sidechain dihedrals requires the list "
718 f"to have length {sum_sidechains}, but I created a list with "
719 f"{len(sidechain_dihedral_index_quadruplets)}. The input has shape "
720 f"{inputs[5].shape}."
721 )
722 right_side_central_distance_indices.append(
723 np.full(shape=(1, n_sidechains), fill_value=True, dtype=bool)
724 )
725 right_side_central_distance_indices = np.vstack(right_side_central_distance_indices)
726 angle_index_triplets = np.vstack(
727 central_angle_index_triplets + sidechain_angle_index_triplets
728 )
729 dihedral_index_quadruplets = np.vstack(
730 central_dihedral_index_quadruplets + sidechain_dihedral_index_quadruplets
731 )
732 if sidechain_cartesians_ind != []: # if sidechains
733 _use_sidechains = True
734 sidechain_cartesians_ind = np.concatenate(sidechain_cartesians_ind)
735 central_distance_indices = np.hstack(
736 [central_distance_indices, right_side_central_distance_indices]
737 )
738 side_distance_indices = [
739 (np.tri(N=i + 1, M=i + 2, k=0) + 1)[:, 1:]
740 for i in feature_description[-1].values()
741 if i > 0
742 ]
743 side_distance_indices = (block_diag(*side_distance_indices) % 2) == 0
744 left_side_side_distance_indices = (
745 np.full( # all atoms in the central chain are True
746 shape=(len(side_distance_indices), n_residues * 3),
747 fill_value=True,
748 dtype=bool,
749 )
750 )
751 side_distance_indices = np.hstack(
752 [left_side_side_distance_indices, side_distance_indices]
753 )
754 distance_indices = np.vstack([central_distance_indices, side_distance_indices])
755 else: # if no sidechains
756 _use_sidechains = False
757 distance_indices = central_distance_indices
758 assert distance_indices.shape == (
759 n_residues * 3 - 1 + n_sidechains,
760 init_xyz.shape[1],
761 ), (
762 f"The shape of the distance index after stacking is unexpected.\n"
763 f"Expected: {(n_residues * 3 - 1 + n_sidechains, init_xyz.shape[1])}\n"
764 f"Actual: {distance_indices.shape}"
765 )
767 # now the angles
768 central_angle_indices = central_distance_indices[1:]
769 if _use_sidechains: # if sidechains
770 angle_indices = np.vstack([central_distance_indices[1:], side_distance_indices])
771 side_angle_indices = side_distance_indices
772 else: # no sidechains
773 angle_indices = central_distance_indices[1:]
774 assert len(angle_indices) == len(distance_indices) - 1
776 # and the dihedrals
777 if _use_sidechains: # if sidechains
778 dihedral_indices = np.vstack(
779 [
780 central_distance_indices[1:-1],
781 side_distance_indices[sidechain_cartesians_ind],
782 ]
783 )
784 corrector = np.count_nonzero(
785 list(feature_description[-1].values())
786 ) # per reisude with sidechain dihedrals one less
787 else:
788 dihedral_indices = central_distance_indices[1:-1]
789 corrector = 0
790 assert len(dihedral_indices) == len(distance_indices) - 2 - corrector
791 assert angle_index_triplets.shape[0] == angle_indices.shape[0]
792 assert dihedral_index_quadruplets.shape[0] == dihedral_indices.shape[0], (
793 f"The number of dihedral indices ({len(distance_indices)}) and quadruplets "
794 f"does not match ({len(dihedral_index_quadruplets)}). I get "
795 f"{inputs[2].shape[1] + inputs[5].shape[1]} in inputs."
796 )
798 ############################################################################
799 # Call
800 ############################################################################
802 (
803 central_distances,
804 central_angles,
805 central_dihedrals,
806 side_distances,
807 side_angles,
808 side_dihedrals,
809 ) = inputs
811 assert all(
812 len(i) == len(inputs[0]) for i in inputs[1:]
813 ), f"Inhomogeneous input lengths: {[len(i) for i in inputs]}"
814 distances = np.hstack([central_distances, side_distances])
815 angles = np.hstack([central_angles, side_angles])
816 dihedrals = np.hstack([central_dihedrals, side_dihedrals])
817 assert distance_indices.shape[0] == distances.shape[1]
818 assert angle_indices.shape[0] == angles.shape[1]
819 assert dihedral_indices.shape[0] == dihedrals.shape[1]
821 # for debug
822 fig, (ax1, ax2, ax3) = plt.subplots(
823 nrows=1, ncols=3, subplot_kw={"projection": "3d"}, figsize=(20, 8)
824 )
825 buf = BytesIO()
827 # copy the predefined array and make an array of quarternions
828 if len(angles) > 1:
829 xyz_out = np.repeat(init_xyz, len(angles), axis=0)
830 else:
831 xyz_out = init_xyz.copy()
832 xyz_out = np.pad(
833 xyz_out, ((0, 0), (0, 0), (0, 1)), mode="constant", constant_values=1
834 )
836 # distances in more tensor-flow-friendly implementation
837 # by creating a list and concatenating we can forego assignments whicha are
838 # not supported for symbolic tensors
839 xs_central = [np.zeros((len(angles),))]
840 ys_central = [np.zeros((len(angles),))]
841 xs_side = []
842 ys_side = []
844 residue = 0 # residue here is again, 0-based
845 idx = 0
846 j = 0
847 n_sidechains_in_residue = np.array(
848 [feature_description[-1][k] for k in sorted(feature_description[-1].keys())]
849 ).astype(np.int32)
850 for i in range(len(central_distance_indices)):
851 assert np.all(central_distances[:, i] > 0)
852 xs_central.append(xs_central[-1] + central_distances[:, i])
853 ys_central.append(np.zeros((len(angles),)))
854 if idx == 0 and _use_sidechains:
855 n_sidechains = n_sidechains_in_residue[residue]
856 if n_sidechains > 0:
857 for n in range(n_sidechains + 1):
858 xs_side.append(xs_central[-1])
859 dists = side_distances[:, j - n : j + 1]
860 assert np.all(dists > 0), (
861 f"Side distances at (0-based) residue {residue} are smaller than 0. "
862 f"This is the {n} sidechain distance of a total of "
863 f"{n_sidechains + 1} sidechain distances. For that, I index "
864 f"the sidechain dists array for all frames from {j-n=} to "
865 f"{j+1=}, which gives {side_distances[:, j-n:j+1]} for the "
866 f"first 5 frames.."
867 )
868 _ = np.sum(dists, axis=1)
869 if np.any(np.isnan(_)):
870 raise Exception(
871 f"At index {i=} of the central distances and {j=} of "
872 f"the sidechain distances, the value of the y coordinate "
873 f"became NaN. The shapes of the arrays are:"
874 f"{central_distances.shape=} {side_distances.shape=} "
875 f"{np.any(np.isnan(central_distances))=} "
876 f"{np.any(np.isnan(side_distances))=}"
877 )
878 ys_side.append(_)
879 j += 1
880 idx += 1
881 if idx >= 3:
882 residue += 1
883 idx = 0
884 xs = np.stack(xs_central + xs_side, axis=1)
885 ys = np.stack(ys_central + ys_side, axis=1)
886 xyz_out = np.stack([xs, ys, np.zeros(xs.shape), np.ones(xs.shape)], axis=2)
887 assert not np.any(
888 np.isnan(xyz_out)
889 ), f"After fixing dists, some values in `xyz_out` are NaN."
891 for frame, (
892 central_frame_dists,
893 side_frame_dists,
894 central_frame_angles,
895 side_frame_angles,
896 frame_dihedrals,
897 ) in enumerate(
898 zip(central_distances, side_distances, central_angles, side_angles, dihedrals)
899 ):
900 # plot
901 if frame == 0:
902 ax1.plot(*xyz_out[0, : n_residues * 3, :3].T, "bo-")
903 for ind in sidechain_positions_indices:
904 ax1.plot(*xyz_out[0, ind, :3].T, "bo-")
906 # angles
907 for i, (ang, ind, angle_index) in enumerate(
908 zip(
909 central_frame_angles,
910 central_angle_indices,
911 central_angle_index_triplets,
912 )
913 ):
914 direction = np.array([0, 0, 1]).astype("float32")
915 pivot_point = xyz_out[frame, angle_index[1], :3]
916 a, b, c = xyz_out[frame, angle_index, :3]
917 ba = a - b
918 bc = c - b
919 prod = np.linalg.norm(ba) * np.linalg.norm(bc)
920 cosine_angle = np.clip(np.dot(ba, bc) / prod, -1, 1)
921 current_angle = np.arccos(cosine_angle)
922 angle = np.abs(ang - current_angle)
923 rotmat = rotation_matrix(
924 angle=angle, direction=direction, point=pivot_point
925 )
926 rotated = rotmat.dot(xyz_out[frame, ~ind].T).T[:, :3]
927 xyz_out[frame, ~ind, :3] = rotated
929 if _use_sidechains:
930 for i, (ang, ind, angle_index) in enumerate(
931 zip(
932 side_frame_angles,
933 side_angle_indices,
934 sidechain_angle_index_triplets,
935 )
936 ):
937 direction = np.array([0, 0, -1]).astype("float32")
938 pivot_point = xyz_out[frame, angle_index[1], :3]
939 a, b, c = xyz_out[frame, angle_index, :3]
940 ba = a - b
941 bc = c - b
942 cosine_angle = np.clip(
943 np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc)), -1, 1
944 )
945 current_angle = np.arccos(cosine_angle)
946 angle = np.abs(ang - current_angle)
947 rotmat = rotation_matrix(
948 angle=angle, direction=direction, point=pivot_point
949 )
950 rotated = rotmat.dot(xyz_out[frame, ~ind].T).T[:, :3]
951 xyz_out[frame, ~ind, :3] = rotated
953 # plot
954 if frame == 0:
955 ax2.plot(*xyz_out[0, : n_residues * 3, :3].T, "bo-")
956 for ind in sidechain_positions_indices:
957 ax2.plot(*xyz_out[0, ind, :3].T, "bo-")
959 # dihedrals
960 for i, (d_ang, ind, dihedral_index) in enumerate(
961 zip(frame_dihedrals, dihedral_indices, dihedral_index_quadruplets)
962 ):
963 bond = dihedral_index[[1, 2]]
964 a, b, c, d = xyz_out[frame, dihedral_index, :3]
965 direction = np.diff(xyz_out[frame, bond, :3], axis=0).flatten()
966 pivot_point = xyz_out[frame, bond[0], :3]
967 current_dihedral = _dihedral(xyz_out[frame, :, :3], dihedral_index)[0][0]
968 if np.any(np.isnan(current_dihedral)):
969 current_dihedral = 0.0
970 dihedral = d_ang - current_dihedral
971 rotmat = rotation_matrix(
972 angle=dihedral, direction=direction, point=pivot_point
973 )
974 xyz_out[frame, ~ind, :3] = rotmat.dot(xyz_out[frame, ~ind].T).T[:, :3]
976 # plot
977 if frame == 0:
978 ax3.plot(*xyz_out[0, : n_residues * 3, :3].T, "bo-")
979 for ind in sidechain_positions_indices:
980 ax3.plot(*xyz_out[0, ind, :3].T, "bo-")
982 fig.savefig(buf, format="png")
983 buf.seek(0)
984 if not return_indices:
985 return xyz_out[..., :3], buf
986 else:
987 return (
988 xyz_out[..., :3],
989 buf,
990 {
991 "central_distance_indices": np.asarray(central_distance_index_duplets),
992 "side_distance_indices": np.asarray(sidechain_distance_index_duplets),
993 "central_angles_indices": np.asarray(central_angle_index_triplets),
994 "side_angles_indices": np.asarray(sidechain_angle_index_triplets),
995 "central_dihedrals_indices": np.asarray(
996 central_dihedral_index_quadruplets
997 ),
998 "side_dihedrals_indices": np.asarray(
999 sidechain_dihedral_index_quadruplets
1000 ),
1001 },
1002 )
1005def _residue_number() -> Iterator[int]:
1006 i = 0
1007 atoms = 0
1008 while True:
1009 yield i
1010 atoms += 1
1011 if atoms > 2:
1012 atoms = 0
1013 i += 1
1016def _alternating() -> Iterator[int]:
1017 i = 0
1018 while True:
1019 if i % 2 == 0:
1020 yield 1
1021 else:
1022 yield -1
1023 i += 1
# Typing-only overload for ``return_indices=False`` (the implementation's
# default): ``mdtraj_backmapping`` then returns only the back-mapped
# ``md.Trajectory``.
# Parameter names and order must mirror the implementation exactly (the
# implementation's parameter is ``guess_sp2_atoms``). Every parameter has a
# default (``...``), so calls that omit optional arguments still match.
@overload
def mdtraj_backmapping(
    top: Optional[Union[Path, str, int, md.Topology]] = ...,
    dihedrals: Optional[np.ndarray] = ...,
    sidechain_dihedrals: Optional[np.ndarray] = ...,
    trajs: Optional[Union[TrajEnsemble, SingleTraj]] = ...,
    remove_component_size: int = ...,
    verify_every_rotation: bool = ...,
    angle_type: Literal["degree", "radian"] = ...,
    omega: bool = ...,
    guess_sp2_atoms: bool = ...,
    return_indices: Literal[False] = ...,
    parallel: bool = ...,
    progbar: Optional[Any] = ...,
) -> md.Trajectory: ...
# Typing-only overload for ``return_indices=True``: ``mdtraj_backmapping``
# then returns a ``(trajectory, labels)`` tuple, where ``labels`` is a
# ``dict[str, np.ndarray]``.
# Parameter names mirror the implementation (``guess_sp2_atoms``). Every
# other parameter has a default (``...``). ``return_indices`` must be given
# explicitly, so it is keyword-only here: a required parameter cannot
# follow defaulted positional ones. The docstring example passes it as a
# keyword (``return_indices=True``).
@overload
def mdtraj_backmapping(
    top: Optional[Union[Path, str, int, md.Topology]] = ...,
    dihedrals: Optional[np.ndarray] = ...,
    sidechain_dihedrals: Optional[np.ndarray] = ...,
    trajs: Optional[Union[TrajEnsemble, SingleTraj]] = ...,
    remove_component_size: int = ...,
    verify_every_rotation: bool = ...,
    angle_type: Literal["degree", "radian"] = ...,
    omega: bool = ...,
    guess_sp2_atoms: bool = ...,
    *,
    return_indices: Literal[True],
    parallel: bool = ...,
    progbar: Optional[Any] = ...,
) -> tuple[md.Trajectory, dict[str, np.ndarray]]: ...
1060def mdtraj_backmapping(
1061 top: Optional[Union[Path, str, int, md.Topology]] = None,
1062 dihedrals: Optional[np.ndarray] = None,
1063 sidechain_dihedrals: Optional[np.ndarray] = None,
1064 trajs: Optional[Union[TrajEnsemble, SingleTraj]] = None,
1065 remove_component_size: int = 0,
1066 verify_every_rotation: bool = False,
1067 angle_type: Literal["degree", "radian"] = "radian",
1068 omega: bool = True,
1069 guess_sp2_atoms: bool = True,
1070 return_indices: bool = False,
1071 parallel: bool = False,
1072 progbar: Optional[Any] = None,
1073) -> Union[md.Trajectory, tuple[md.Trajectory, dict[str, np.ndarray]]]:
1074 """Uses MDTraj and Christoph Gohlke's transformations.py to rotate the
1075 bonds in the provided topology.
1077 Todo:
1078 * Make this faster. Maybe write a C or FORTRAN implementation.
1080 General procedure:
1081 * Decide on which topology to use (if different topologies are in the
1082 `TrajEnsemble` class, the `dihedrals` and `sidechain_dihedrals` arrays
1083 need to be altered so that the correct dihedrals are used.
1084 Because EncoderMap is trained on a full input `dihedrals` and
1085 `sidechain_dihedrals` contain the dihedrals for the topology
1086 in `TrajEnsemble` with most of such angles. Some SingleTraj
1087 classes in TrajEnsemble might not contain all these angles if, for
1088 example, an amino acid has been modified the mutant contains more
1089 sidechain dihedrals than the wt. So the correct sidechain dihedrals
1090 for the wildtype need to be selected.
1091 * Get the indices of the far sides of the rotations. The graph is
1092 gradually broken apart and the longer sub-graphs are kept.
1093 * Extend the trajectory. The lengths of dihedrals and sidechain_dihedrals
1094 should match. The frame given by top will be duplicated
1095 len(dihedrals)-times.
1096 * Get the current angles. We know what the final angles should be,
1097 but now how far to rotate the bonds. This can be done by getting
1098 the difference between current and target angle.
1099 * Rotate the bonds. Using Christoph Gohlke's transformations.py,
1100 the rotation matrix is constructed and the array is padded
1101 with zeros to resemble an array of quaternions.
1103 Args:
1104 top (Optional[str]): The topology file to use.
1105 dihedrals (Optional[np.ndarray]): The dihedrals to put onto the trajectory.
1106 `len(dihedrals)` is number of frames of output trajectory.
1107 `dihedrals.shape[1]` needs to be the same as the number of dihedrals
1108 in the topology. Can be None, in which case dihedrals and
1109 sidechain dihedrals will be faked.
1110 sidechain_dihedrals (Optional[np.ndarray]):
1111 The sidechain dihedrals to put onto the trajectory.
1112 If None is provided, the sidechains are kept like they were in
1113 the topology. Defaults to None.
1114 trajs (Optional[em.TrajEnsemble, em.SingleTraj]): Encodermap TrajEnsemble
1115 class. It Can accelerate the loading of current dihedral angles.
1116 Checks if provided topology is part of trajs. Defaults to None.
1117 verify_every_rotation (bool): Whether the rotation succeeded.
1118 angle_type (Literal["degree", "radians"]): Whether input is in degrees. Input will be
1119 converted to radians. Defaults to False.
1120 omega (bool): Whether your input backbone dihedrals contain the omega angle.
1121 return_indices (bool): Whether to not only return the back-mapped
1122 trajectory, but also a dict of labels. This dict contains the keys:
1123 * 'dihedrals_labels'
1124 * 'generic_dihedrals_labels'
1125 * 'side_dihedrals_labels'
1126 * 'generic_side_dihedrals_labels'
1127 Which matches the indices of the returned dihedrals with the input
1128 MD structures in `top` and/or `trajs`. This can be useful to make
1129 sure that input dihedrals match output dihedrals. Why? Because there
1130 are some proline dihedrals that cannot be adjusted. They are filtered
1131 out before doing backmapping, and the indices give the names of all
1132 dihedrals that were adjusted. See the Example below.
1134 Examples:
1135 >>> from pathlib import Path
1136 >>> import numpy as np
1137 >>> import encodermap as em
1138 >>> from pprint import pprint
1139 >>> output_dir = Path(
1140 ... em.get_from_kondata(
1141 ... "OTU11",
1142 ... mk_parentdir=True,
1143 ... silence_overwrite_message=True,
1144 ... ),
1145 ... )
1146 >>> # assign how many backbone angles we need
1147 >>> traj = em.load(output_dir / "OTU11_wt_only_prot.pdb")
1148 >>> traj.load_CV("central_dihedrals")
1149 >>> n_angles = traj.central_dihedrals.shape[-1]
1150 >>> n_angles
1151 732
1152 >>> # create some fake dihedrals with a uniform distribution between -pi and pi
1153 >>> dihedrals = np.random.uniform(low=-np.pi, high=np.pi, size=(5, n_angles))
1154 >>> out, index = em.misc.backmapping.mdtraj_backmapping(
1155 ... top=output_dir / "OTU11_wt_only_prot.pdb",
1156 ... dihedrals=dihedrals,
1157 ... remove_component_size=10,
1158 ... return_indices=True,
1159 ... )
1160 >>> out = em.SingleTraj(out)
1161 >>> out.load_CV("central_dihedrals")
1162 >>> # Here you will see, what indicies were automatically dropped during backmapping
1163 >>> # They will be proline phi angles, as these angles can not be
1164 >>> # freely rotated
1165 >>> all_coords = set(out._CVs.coords["CENTRAL_DIHEDRALS"].values)
1166 >>> indexed_coords = set(index['dihedrals_labels'])
1167 >>> pprint(all_coords - indexed_coords)
1168 {'CENTERDIH PHI RESID PRO: 8 CHAIN 0',
1169 'CENTERDIH PHI RESID PRO: 70 CHAIN 0',
1170 'CENTERDIH PHI RESID PRO: 73 CHAIN 0',
1171 'CENTERDIH PHI RESID PRO: 80 CHAIN 0',
1172 'CENTERDIH PHI RESID PRO: 151 CHAIN 0',
1173 'CENTERDIH PHI RESID PRO: 200 CHAIN 0',
1174 'CENTERDIH PHI RESID PRO: 205 CHAIN 0',
1175 'CENTERDIH PHI RESID PRO: 231 CHAIN 0',
1176 'CENTERDIH PHI RESID PRO: 234 CHAIN 0',
1177 'CENTERDIH PHI RESID PRO: 238 CHAIN 0'}
1180 Raises:
1181 Exception: If the input seems like it is in degrees.
1182 Exception: If top is not part of the TrajEnsemble class provided in argument `trajs`.
1184 Returns:
1185 mdtraj.Trajectory: An MDTraj trajectory with the correct dihedrals/side-dihedrals.
1187 """
1188 # Third Party Imports
1189 import networkx as nx
1191 # if `dihedrals` is None we sample them from a random uniform distribution
1192 if dihedrals is None and sidechain_dihedrals is None:
1193 if trajs is not None:
1194 if top is None:
1195 top = 0
1196 if not hasattr(trajs, "central_dihedrals"):
1197 trajs.load_CV("central_dihedrals")
1198 if not hasattr(trajs, "side_dihedrals"):
1199 trajs.load_CV("side_dihedrals")
1200 with temp_seed(1):
1201 dihedrals = np.random.uniform(
1202 low=-np.pi,
1203 high=np.pi,
1204 size=(10, trajs[top].central_dihedrals.shape[-1]),
1205 )
1206 sidechain_dihedrals = np.random.uniform(
1207 low=-np.pi,
1208 high=np.pi,
1209 size=(10, trajs[top].side_dihedrals.shape[-1]),
1210 )
1211 elif top is not None and trajs is None:
1212 assert not isinstance(top, int) or isinstance(top, md.Topology), (
1213 f"When providing no `dihedrals` to sample fake dihedrals from "
1214 f"a random uniform distribution, you can't provide int or "
1215 f"md.Topology for `top`. Please provide a str or Path."
1216 )
1217 trajs = TrajEnsemble([top])
1218 top = 0
1219 trajs.load_CVs(["central_dihedrals", "side_dihedrals"])
1220 dihedrals = np.random.uniform(
1221 low=-np.pi,
1222 high=np.pi,
1223 size=(10, trajs[top].central_dihedrals.shape[-1]),
1224 )
1225 sidechain_dihedrals = np.random.uniform(
1226 low=-np.pi, high=np.pi, size=(10, trajs[top].side_dihedrals.shape[-1])
1227 )
1228 else:
1229 raise Exception(f"Please provide either a `top` or `trajs` argument.")
1231 # change and check the angles
1232 if angle_type == "radian":
1233 if np.any(dihedrals > np.pi):
1234 raise Exception(
1235 f"The argument `angle_type` is meant to specify, what angles "
1236 f"(radian or degree) are provided for the argument `dihedrals`. "
1237 f"This allows you to provide either to this function by just "
1238 f"specifying this argument. You specified {angle_type} but some "
1239 f"of your dihedrals are greater than pi."
1240 )
1241 if sidechain_dihedrals is not None:
1242 if np.any(sidechain_dihedrals > np.pi):
1243 raise Exception(
1244 f"The argument `angle_type` is meant to specify, what angles "
1245 f"(radian or degree) are provided for the argument `sidechain_dihedrals`. "
1246 f"This allows you to provide either to this function by just "
1247 f"specifying this argument. You specified {angle_type} but some "
1248 f"of your sidechain dihedrals are greater than pi."
1249 )
1250 elif angle_type == "degree":
1251 if np.all(dihedrals <= np.pi):
1252 raise Exception(
1253 f"The argument `angle_type` is meant to specify, what angles "
1254 f"(radian or degree) are provided for the argument `dihedrals`. "
1255 f"This allows you to provide either to this function by just "
1256 f"specifying this argument. You specified {angle_type} but none "
1257 f"of your dihedrals were greater than pi: {dihedrals}"
1258 )
1259 dihedrals = np.deg2rad(dihedrals)
1260 if sidechain_dihedrals is not None:
1261 if np.all(sidechain_dihedrals <= np.pi):
1262 raise Exception(
1263 f"The argument `angle_type` is meant to specify, what angles "
1264 f"(radian or degree) are provided for the argument `sidechain_dihedrals`. "
1265 f"This allows you to provide either to this function by just "
1266 f"specifying this argument. You specified {angle_type} but none "
1267 f"of your sidechain dihedrals were greater than pi."
1268 )
1269 sidechain_dihedrals = np.deg2rad(sidechain_dihedrals)
1270 else:
1271 raise Exception(
1272 f"Argument `angle_type` must be either 'radian' or 'degree', "
1273 f"you supplied: {angle_type}"
1274 )
1276 # make sure the input has the same shape along the "frame" axis, that will be created.
1277 if sidechain_dihedrals is not None:
1278 assert len(dihedrals) == len(sidechain_dihedrals), (
1279 f"The number of provided dihedrals ({len(dihedrals)}) and "
1280 f"sidechain dihedrals ({len(sidechain_dihedrals)}) must be the same."
1281 )
1283 # either top or trajs has to be not None
1284 if trajs is None: # pragma: no cover
1285 if top is None:
1286 raise Exception(
1287 f"Please provide the path to a topology file"
1288 f"(.pdb, .gro) to use for backmapping."
1289 )
1290 elif isinstance(top, (str, Path)):
1291 inp_trajs = TrajEnsemble([top])
1292 inp_trajs.load_CVs(["central_dihedrals", "side_dihedrals"])
1293 elif isinstance(top, int):
1294 raise Exception(
1295 f"When providing an int for `top`, pleas also provide a `em.TrajEnsemble` "
1296 f"for argument `trajs.`"
1297 )
1298 else:
1299 raise ValueError(
1300 f"Argument `top` must be of type str, int, or None, "
1301 f"you provided: {type(top)}."
1302 )
1303 elif isinstance(trajs, SingleTraj) or trajs.__class__.__name__ == "SingleTraj":
1304 if isinstance(top, (int, Path, str)):
1305 print(
1306 "When providing `em.SingleTraj` for argument `trajs`, the argument "
1307 "`top` will be ignored."
1308 )
1309 assert "central_dihedrals" in trajs._CVs, (
1310 f"The provided traj, doesn't have the collective variable 'central_dihedrals' "
1311 f"loaded. Please load them by calling: `traj.load_CVs('all')."
1312 )
1313 if sidechain_dihedrals is not None:
1314 assert "side_dihedrals" in trajs._CVs, (
1315 f"The provided traj, doesn't have the collective variable 'central_dihedrals' "
1316 f"loaded. Please load them by calling: `traj.load_CVs('all')."
1317 )
1318 inp_trajs = trajs._gen_ensemble()
1319 elif isinstance(trajs, TrajEnsemble) or trajs.__class__.__name__ == "TrajEnsemble":
1320 assert "central_dihedrals" in trajs._CVs, (
1321 f"The provided traj, doesn't have the collective variable 'central_dihedrals' "
1322 f"loaded. Please load them by calling: `traj.load_CVs('all')."
1323 )
1324 if sidechain_dihedrals is not None:
1325 assert "side_dihedrals" in trajs._CVs, (
1326 f"The provided traj, doesn't have the collective variable 'central_dihedrals' "
1327 f"loaded. Please load them by calling: `traj.load_CVs('all')."
1328 )
1329 if isinstance(top, (str, Path)) and os.path.isfile(top):
1330 print(
1331 "When providing `em.TrajEnsemble` for argument `trajs`, the argument "
1332 "`top` will be ignored if it points to a file."
1333 )
1334 elif isinstance(top, str) and top in trajs.common_str:
1335 _trajs_index = trajs.trajs_by_common_str[top][0].traj_num
1336 elif top is None:
1337 _trajs_index = 0
1338 elif isinstance(top, (int, np.int64)):
1339 # _trajs_index = top
1340 _trajs_index = trajs.trajs_by_top[trajs.top[top]][0].traj_num
1341 else:
1342 raise ValueError(
1343 f"Argument `top` must be of type str, int, md.Topology or None, "
1344 f"you provided: {type(top)}."
1345 )
1346 inp_trajs = trajs.tsel[_trajs_index]._gen_ensemble()
1347 else:
1348 raise ValueError(
1349 f"Argument `trajs` must be of type `em.SingleTraj`, `em.TrajEnsemble`, or None, "
1350 f"you provided: {type(trajs)}."
1351 )
1353 if trajs is None:
1354 trajs = inp_trajs
1356 # now we match the names of the featurizer
1357 all_central_indices = trajs._CVs.central_dihedrals.coords[
1358 trajs._CVs.central_dihedrals.attrs["feature_axis"]
1359 ]
1360 central_indices = all_central_indices[
1361 np.all(~np.isnan(inp_trajs._CVs.central_dihedrals.values[0]), axis=0)
1362 ]
1363 if sidechain_dihedrals is not None:
1364 all_side_indices = trajs._CVs.side_dihedrals.coords[
1365 trajs._CVs.side_dihedrals.attrs["feature_axis"]
1366 ]
1367 side_indices = all_side_indices[
1368 np.all(~np.isnan(inp_trajs._CVs.side_dihedrals.values[0]), axis=0)
1369 ]
1370 if not omega:
1371 central_indices = central_indices[
1372 ~central_indices.str.lower().str.contains("omega")
1373 ]
1374 if sidechain_dihedrals is not None:
1375 side_indices = side_indices[~side_indices.str.lower().str.contains("omega")]
1377 generic_labels = not any(
1378 central_indices.coords["CENTRAL_DIHEDRALS"].str.contains("RESID")
1379 )
1381 if generic_labels:
1382 # we have generic labels and add non-generic ones
1383 _back_labels = {
1384 "generic_dihedrals_labels": central_indices.values,
1385 "dihedrals_labels": np.asarray(
1386 features.CentralDihedrals(inp_trajs[0], omega=omega).describe()
1387 ),
1388 }
1389 if sidechain_dihedrals is not None:
1390 _back_labels |= {
1391 "generic_side_dihedrals_labels": side_indices.values,
1392 "side_dihedrals_labels": np.asarray(
1393 features.SideChainDihedrals(inp_trajs[0]).describe()
1394 ),
1395 }
1396 else:
1397 # we have non-generic labels and build generic ones
1398 _back_labels = {
1399 "dihedrals_labels": central_indices.values,
1400 "generic_dihedrals_labels": np.asarray(
1401 features.CentralDihedrals(
1402 inp_trajs[0], omega=omega, generic_labels=True
1403 ).describe()
1404 ),
1405 }
1406 if sidechain_dihedrals is not None:
1407 _back_labels |= {
1408 "side_dihedrals_labels": side_indices.values,
1409 "generic_side_dihedrals_labels": np.asarray(
1410 features.SideChainDihedrals(
1411 inp_trajs[0], generic_labels=True
1412 ).describe()
1413 ),
1414 }
1416 # check that all indices are present
1417 if len(all_central_indices) >= len(central_indices):
1418 if dihedrals.shape[1] == len(all_central_indices):
1419 dih_indices = np.arange(len(all_central_indices))[
1420 np.in1d(all_central_indices, central_indices)
1421 ]
1422 dihedrals = dihedrals[:, dih_indices]
1423 elif dihedrals.shape[1] == len(central_indices):
1424 dih_indices = np.arange(dihedrals.shape[1])
1425 dihedrals = dihedrals[:, dih_indices]
1426 else:
1427 raise Exception(
1428 f"The shape of the provided `dihedrals` is wrong, either provide "
1429 f"an array with shape[1] = {len(central_indices)}, or "
1430 f"{len(all_central_indices)}, "
1431 f"your array has the shape {dihedrals.shape[1]}."
1432 )
1433 else:
1434 raise NotImplementedError
1436 if sidechain_dihedrals is not None:
1437 if len(all_side_indices) >= len(side_indices):
1438 if sidechain_dihedrals.shape[1] == len(all_side_indices):
1439 _side_indices_out = side_indices.copy()
1440 side_indices = np.in1d(all_side_indices, side_indices)
1441 msg = (
1442 f"Your supplied `sidechain_dihedrals` are misshaped. They are"
1443 f"expected to have either shape (n, {len(side_indices)}, which"
1444 f"matches the number of sidechain dihedrals in the specified"
1445 f"topology: {inp_trajs[0].top} or a shape of (n, "
1446 f"{len(all_side_indices)}) which matches the total number of "
1447 f"possible sidechain angles in the provided `TrajEnsemble` with "
1448 f"{trajs.top} different toplogies."
1449 )
1450 assert side_indices.shape[0] == sidechain_dihedrals.shape[-1], msg
1451 sidechain_dihedrals = sidechain_dihedrals[:, side_indices]
1452 elif sidechain_dihedrals.shape[1] == len(side_indices):
1453 _side_indices_out = side_indices.copy()
1454 side_indices = np.arange(sidechain_dihedrals.shape[1])
1455 sidechain_dihedrals = sidechain_dihedrals[:, side_indices]
1456 else:
1457 raise Exception(
1458 f"The shape of the provided `dihedrals` is wrong, either provide "
1459 f"an array with shape[1] = {len(side_indices)}, or {len(all_side_indices)}, "
1460 f"your array has the shape {sidechain_dihedrals.shape[1]}."
1461 )
1462 else:
1463 raise NotImplementedError
1465 dih_indices = inp_trajs[0]._CVs.central_dihedrals_feature_indices.values[0]
1466 if omega:
1467 idx = ~np.all(np.isnan(dih_indices), axis=1)
1468 else:
1469 idx = (
1470 ~np.all(np.isnan(dih_indices), axis=1)
1471 & ~all_central_indices.str.lower().str.contains("omega")
1472 ).values
1473 dih_indices = dih_indices[idx]
1474 dih_indices = dih_indices.astype(int)
1475 _dih_indices = deepcopy(dih_indices)
1476 assert dih_indices.ndim == 2, f"Problem when calculating dihedrals {inp_trajs=}"
1477 side_indices = inp_trajs[0]._CVs.side_dihedrals_feature_indices.values[0]
1478 side_indices = side_indices[~np.all(np.isnan(side_indices), axis=1)]
1479 side_indices = side_indices.astype(int)
1480 _side_indices = deepcopy(side_indices)
1481 assert isinstance(inp_trajs, TrajEnsemble)
1483 # get indices of atoms for rotations
1484 g = inp_trajs.top[0].to_bondgraph()
1486 # can be used to visualize topology
1487 # nx.draw(g, pos=nx.spring_layout(g))
1488 if not nx.is_connected(g):
1489 # Third Party Imports
1490 from networkx import connected_components
1492 components = [*connected_components(g)]
1493 if remove_component_size > 0:
1494 component_sizes = sorted([len(c) for c in components])
1495 if any([i > remove_component_size for i in component_sizes[:-1]]):
1496 _raise_components_exception(
1497 components, trajs, top, remove_component_size
1498 )
1499 offending_components = []
1500 for c in components:
1501 if len(c) <= remove_component_size:
1502 offending_components.extend([a.index for a in list(c)])
1503 inp_trajs[0].atom_slice(offending_components, invert=True)
1504 g = inp_trajs.top[0].to_bondgraph()
1505 else:
1506 _raise_components_exception(components, trajs, top, remove_component_size)
1508 # at this point dih_bond_indices has been defined.
1509 dih_bond_indices = dih_indices[:, 1:3]
1510 assert (
1511 dih_bond_indices.shape[0] == _dih_indices.shape[0] == dihedrals.shape[1]
1512 ), f"{dih_bond_indices.shape=}, {_dih_indices.shape=}, {dihedrals.shape=} {omega=}"
1514 # filter out the proline angles
1515 dih_bond_atoms = np.dstack(
1516 [
1517 [inp_trajs[0].top.atom(a).__str__() for a in dih_bond_indices[:, 0]],
1518 [inp_trajs[0].top.atom(a).__str__() for a in dih_bond_indices[:, 1]],
1519 ]
1520 )[0]
1521 indices = np.full(dihedrals.shape[1], 1)
1522 assert indices.shape[0] == dihedrals.shape[1]
1523 assert (
1524 dihedrals[:, indices].shape == dihedrals.shape
1525 ), f"{dihedrals[:, indices].shape=} {dihedrals.shape=}"
1526 for i, bond in enumerate(dih_bond_atoms):
1527 if "PRO" in bond[0] and "PRO" in bond[1] and "N" in bond[0] and "CA" in bond[1]:
1528 indices[i] = 0
1529 indices = indices.astype(bool)
1531 # get rid of the proline dihedrals
1532 _back_labels["dihedrals_labels"] = _back_labels["dihedrals_labels"][indices]
1533 _back_labels["generic_dihedrals_labels"] = _back_labels["generic_dihedrals_labels"][
1534 indices
1535 ]
1536 dih_indices = dih_indices[indices]
1537 dih_bond_indices = dih_bond_indices[indices]
1538 dihedrals = dihedrals[:, indices]
1539 assert dihedrals.shape[1] == dih_indices.shape[0] == dih_bond_indices.shape[0]
1540 dih_near_sides, dih_far_sides = _get_near_and_far_networkx(
1541 g,
1542 dih_bond_indices,
1543 inp_trajs[0].top,
1544 parallel=parallel,
1545 )
1547 if sidechain_dihedrals is not None:
1548 side_bond_indices = side_indices[:, 1:3]
1549 assert (
1550 side_bond_indices.shape[0]
1551 == _side_indices.shape[0]
1552 == sidechain_dihedrals.shape[1]
1553 ), (
1554 f"{side_bond_indices.shape=}, {_side_indices.shape=}, "
1555 f"{sidechain_dihedrals.shape=}"
1556 )
1557 # filter out the proline angles
1558 side_bond_atoms = np.dstack(
1559 [
1560 [inp_trajs[0].top.atom(a).__str__() for a in side_bond_indices[:, 0]],
1561 [inp_trajs[0].top.atom(a).__str__() for a in side_bond_indices[:, 1]],
1562 ]
1563 )[0]
1564 indices = np.full(sidechain_dihedrals.shape[1], 1)
1565 assert indices.shape[0] == sidechain_dihedrals.shape[1]
1566 assert (
1567 sidechain_dihedrals[:, indices].shape == sidechain_dihedrals.shape
1568 ), f"{sidechain_dihedrals[:, indices].shape=} {sidechain_dihedrals.shape=}"
1569 for i, bond in enumerate(side_bond_atoms):
1570 if (
1571 "PRO" in bond[0]
1572 and "PRO" in bond[1]
1573 and "CA" in bond[0]
1574 and "CB" in bond[1]
1575 ):
1576 indices[i] = 0
1577 if (
1578 "PRO" in bond[0]
1579 and "PRO" in bond[1]
1580 and "CB" in bond[0]
1581 and "CG" in bond[1]
1582 ):
1583 indices[i] = 0
1584 indices = indices.astype(bool)
1585 _back_labels["side_dihedrals_labels"] = _back_labels["side_dihedrals_labels"][
1586 indices
1587 ]
1588 _back_labels["generic_side_dihedrals_labels"] = _back_labels[
1589 "generic_side_dihedrals_labels"
1590 ][indices]
1591 _side_indices_out = _side_indices_out[indices]
1592 side_indices = side_indices[indices]
1593 side_bond_indices = side_bond_indices[indices]
1594 sidechain_dihedrals = sidechain_dihedrals[:, indices]
1596 side_near_sides, side_far_sides = _get_near_and_far_networkx(
1597 g,
1598 side_bond_indices,
1599 inp_trajs[0].top,
1600 parallel=parallel,
1601 )
1603 # assert that the dihedrals and _back_labels have the correct shape
1604 # that way we can be sure to use the labels to index the correct dihedrals
1605 # after obtaining the finished trajectory
1606 assert (
1607 _back_labels["dihedrals_labels"].shape
1608 == _back_labels["generic_dihedrals_labels"].shape
1609 )
1610 assert _back_labels["dihedrals_labels"].shape[0] == dihedrals.shape[-1]
1611 if sidechain_dihedrals is not None:
1612 assert (
1613 _back_labels["side_dihedrals_labels"].shape
1614 == _back_labels["generic_side_dihedrals_labels"].shape
1615 )
1616 assert (
1617 _back_labels["side_dihedrals_labels"].shape[0]
1618 == sidechain_dihedrals.shape[-1]
1619 )
1621 # extend the traj
1622 for i in range(len(dihedrals)):
1623 if i == 0:
1624 out_traj = deepcopy(inp_trajs[0][0].traj)
1625 else:
1626 out_traj = out_traj.join(inp_trajs[0][0].traj)
1627 out_traj.top = inp_trajs[0].top
1629 # adjust the torsions
1630 new_xyz = np.ascontiguousarray(out_traj.xyz.copy().astype("float32"))
1631 new_xyz = np.pad(
1632 new_xyz, ((0, 0), (0, 0), (0, 1)), mode="constant", constant_values=1
1633 )
1634 dihedrals = dihedrals.astype("float32")
1635 dih_indices = dih_indices.astype("int32")
1636 total_counts = dihedrals.shape[0] * dihedrals.shape[1]
1637 if sidechain_dihedrals is not None:
1638 total_counts += dihedrals.shape[0] * sidechain_dihedrals.shape[1]
1639 sidechain_dihedrals = sidechain_dihedrals.astype("float32")
1640 if guess_sp2_atoms:
1641 total_counts += _countprog(out_traj)
1642 if parallel:
1643 raise Exception(f"Parallel has not yet been tested.")
1644 parallel_rotation_application(
1645 new_xyz,
1646 dih_indices,
1647 dihedrals,
1648 dih_near_sides,
1649 )
1650 else:
1651 pbar = progbar
1652 if pbar is None:
1653 if _is_notebook():
1654 pbar = notebook_tqdm(
1655 total=total_counts,
1656 leave=False,
1657 position=0,
1658 desc="Backmapping...",
1659 )
1660 else:
1661 pbar = normal_tqdm(
1662 total=total_counts,
1663 leave=False,
1664 position=0,
1665 desc="Backmapping...",
1666 )
1667 else:
1668 if not isinstance(pbar, bool):
1669 pbar.reset(pbar.total + total_counts - 1)
1670 else:
1671 pbar = None
1673 for i in range(dihedrals.shape[0]):
1674 for j in range(dihedrals.shape[1]):
1675 # central_dihedrals
1676 near_side = dih_near_sides[j]
1677 far_side = dih_far_sides[j]
1678 dihedral = dih_indices[j]
1679 bond = dih_bond_indices[j]
1681 # define inputs
1682 target_angle = dihedrals[i, j]
1683 current_angle = _dihedral(new_xyz[i, :, :3], dihedral)[0][0]
1684 angle = target_angle - current_angle
1685 direction = np.diff(new_xyz[i, bond, :3], axis=0).flatten()
1686 pivot_point = new_xyz[i, bond[0], :3]
1688 # perform rotation
1689 rotmat = transformations_rotation_matrix(angle, direction, pivot_point)
1690 new_xyz[i, far_side, :3] = rotmat.dot(new_xyz[i, far_side].T).T[:, :3]
1692 # verify
1693 if verify_every_rotation:
1694 _ = _dihedral(new_xyz[i, :, :3], dihedral)[0][0]
1695 if not np.isclose(_, target_angle, atol=1e-3):
1696 _resids = [
1697 str(inp_trajs[0].top.atom(x).residue.index)
1698 for x in dihedral
1699 ]
1700 s = (
1701 f"Adjusting dihedral angle for atoms "
1702 f"{[str(inp_trajs[0].top.atom(x)) for x in dihedral]} "
1703 f"failed with an absolute tolerance of 1e-3. "
1704 f"Residue indices are: "
1705 f"{_resids}"
1706 f"\nTarget angle was {target_angle} {angle_type}, "
1707 f"but rotation yielded angle with {_} {angle_type}."
1708 f"\nCurrent angle was {current_angle}. To reach "
1709 f"target angle is a rotation of {angle} {angle_type} "
1710 f"was carried out."
1711 f"\nRotation axis was vector from "
1712 f"{inp_trajs[0].top.atom(bond[0])} to "
1713 f"{inp_trajs[0].top.atom(bond[1])}"
1714 f"\nOnly these atoms should have been affected by "
1715 f"rotation: {far_side}"
1716 "\nBut somehow this method still crashed. Maybe "
1717 "these prints will help."
1718 )
1719 raise Exception(s)
1720 if pbar is not None:
1721 pbar.update()
1723 if sidechain_dihedrals is not None:
1724 for j in range(sidechain_dihedrals.shape[1]):
1725 # central_dihedrals
1726 near_side = side_near_sides[j]
1727 far_side = side_far_sides[j]
1728 dihedral = side_indices[j]
1729 bond = side_bond_indices[j]
1731 # define inputs
1732 target_angle = sidechain_dihedrals[i, j]
1733 current_angle = _dihedral(new_xyz[i, :, :3], dihedral)
1734 angle = target_angle - current_angle
1735 direction = np.diff(new_xyz[i, bond, :3], axis=0).flatten()
1736 pivot_point = new_xyz[i, bond[0], :3]
1738 # perform rotation
1739 rotmat = transformations_rotation_matrix(
1740 angle, direction, pivot_point
1741 )
1742 rotmat = transformations_rotation_matrix(
1743 angle, direction, pivot_point
1744 )
1745 new_xyz[i, far_side, :3] = rotmat.dot(new_xyz[i, far_side].T).T[
1746 :, :3
1747 ]
1749 if verify_every_rotation:
1750 _ = _dihedral(new_xyz[i, :, :3], dihedral)[0][0]
1751 if not np.isclose(_, target_angle, atol=1e-3):
1752 _resids = [
1753 str(inp_trajs[0].top.atom(x).residue.index)
1754 for x in dihedral
1755 ]
1756 s = (
1757 f"Adjusting dihedral angle for atoms "
1758 f"{[str(inp_trajs[0].top.atom(x)) for x in dihedral]} "
1759 f"failed with an absolute tolerance of 1e-3. "
1760 f"Residue indices are: "
1761 f"{_resids}"
1762 f"\nTarget angle was {target_angle} {angle_type}, "
1763 f"but rotation yielded angle with {_} {angle_type}."
1764 f"\nCurrent angle was {current_angle}. To reach target "
1765 f"angle is a rotation of {angle} {angle_type} was "
1766 f"carried out.\nRotation axis was vector from "
1767 f"{inp_trajs[0].top.atom(bond[0])} to "
1768 f"{inp_trajs[0].top.atom(bond[1])}"
1769 f"\nOnly these atoms should have been affected by "
1770 f"rotation: {far_side}\nBut somehow this method "
1771 f"still crashed. Maybe these prints will help."
1772 )
1773 raise Exception(s)
1774 if pbar is not None:
1775 pbar.update()
1777 # overwrite traj and return
1778 out_traj.xyz = new_xyz[..., :3]
1780 # fix the amide atoms
1781 if guess_sp2_atoms:
1782 _guess_sp2_atoms(out_traj, pbar)
1784 if not return_indices:
1785 return out_traj
1786 if return_indices:
1787 return out_traj, _back_labels
def _countprog(traj) -> int:
    """Count the progress-bar steps that ``_guess_sp2_atoms`` will perform.

    ``_guess_sp2_atoms`` skips residues named ``"PRO"``. For every other
    residue it updates the progress bar once per frame for the hydrogen
    (residues with index > 0) and once per frame for the oxygen (residues
    with index < ``traj.n_residues - 2``). This function mirrors that logic so
    the progress bar total matches the number of updates.

    Args:
        traj: Trajectory-like object exposing ``top.residues`` (items with a
            ``name`` attribute), ``n_residues`` and ``n_frames``.

    Returns:
        int: The number of expected ``pbar.update()`` steps.

    """
    last_oxygen_bound = traj.n_residues - 2
    steps_per_frame = 0
    for i, residue in enumerate(traj.top.residues):
        # The former check ``if r:`` was always true (residue objects are
        # truthy), which skipped every residue and always returned 0.
        if residue.name == "PRO":
            continue
        if i > 0:
            steps_per_frame += 1
        if i < last_oxygen_bound:
            steps_per_frame += 1
    return steps_per_frame * traj.n_frames
def _guess_sp2_positions(
    center: np.ndarray,
    bonded: np.ndarray,
    other: np.ndarray,
    angle: float,
    bond_length: float,
) -> np.ndarray:
    """Place one planar substituent on ``center`` for all frames at once.

    The vector ``bonded - center`` is rotated by ``angle`` (radians) about the
    unit normal ``cross(bonded - center, other - center)``. The result is
    rescaled to ``bond_length`` and translated to ``center``. The rotation is
    applied as a row-vector product (``v @ M``, i.e. with the transposed
    matrix), which keeps the rotation sense of the former per-frame code.

    Args:
        center: Positions of the central atom, shape (n_frames, 3).
        bonded: Positions of the neighbor whose direction is rotated,
            shape (n_frames, 3).
        other: Positions of the second neighbor that defines the plane,
            shape (n_frames, 3).
        angle: Rotation angle in radians.
        bond_length: Distance of the new atom from ``center``.

    Returns:
        np.ndarray: New positions with shape (n_frames, 3).

    """
    # Third Party Imports
    from scipy.spatial.transform import Rotation as R

    v1 = bonded - center
    normal = np.cross(v1, other - center)
    normal /= np.linalg.norm(normal, axis=-1, keepdims=True)
    normal *= angle
    matrices = R.from_rotvec(normal).as_matrix()  # (n_frames, 3, 3)
    # per-frame row vector @ matrix, identical to ``v1[f] @ matrices[f]``
    new_pos = np.einsum("fi,fij->fj", v1, matrices)
    new_pos /= np.linalg.norm(new_pos, axis=-1, keepdims=True)
    new_pos *= bond_length
    new_pos += center
    return new_pos


def _guess_sp2_atoms(
    traj: md.Trajectory,
    pbar: Optional[Any] = None,
) -> None:
    """Re-place backbone H and O atoms of ``traj`` in place.

    For every residue not named ``"PRO"``:

    * if it is not the first residue, its atom named ``"H"`` is placed
      0.11 units from N, at 123 degrees from the N->CA direction, in the plane
      of CA, N and the previous residue's C;
    * if its index is smaller than ``traj.n_residues - 2``, its atom named
      ``"O"`` is placed 0.124 units from C, at 121 degrees from the C->CA
      direction, in the plane of CA, C and the next residue's N.

    Distances are in the units of ``traj.xyz`` (presumably nm for mdtraj —
    confirm). All frames are processed at once.

    Args:
        traj: The trajectory whose ``xyz`` is modified in place.
        pbar: Optional progress bar. ``pbar.update(traj.n_frames)`` is called
            once per placed atom type and residue, so the total number of
            steps matches ``_countprog(traj)``.

    Raises:
        AssertionError: If the number of N, CA or C atoms differs from the
            number of residues.

    """
    Ns = traj.top.select("name N")
    CAs = traj.top.select("name CA")
    Cs = traj.top.select("name C")
    assert len(Ns) == len(CAs) == len(Cs) == traj.n_residues, (
        f"I could not determine the correct number of backbone atoms for this "
        f"protein of {traj.n_residues} residues. I expected {traj.n_residues} "
        f"nitrogen atoms, but got {len(Ns)}. I expected {traj.n_residues} "
        f"alpha carbon atoms, but got {len(CAs)}. I expected {traj.n_residues} "
        f"carboxylic carbons, but got {len(Cs)}. Maybe your protein contains "
        f"non-standard residues."
    )
    n_frames = traj.n_frames
    if n_frames == 0:
        return
    for i, (r, N, CA, C) in enumerate(zip(traj.top.residues, Ns, CAs, Cs)):
        if r.name == "PRO":
            continue
        if i > 0:
            H = next(r.atoms_by_name("H")).index
            traj.xyz[:, H] = _guess_sp2_positions(
                center=traj.xyz[:, N],
                bonded=traj.xyz[:, CA],
                other=traj.xyz[:, Cs[i - 1]],
                angle=123 / 180 * np.pi,
                bond_length=0.11,
            )
            if pbar is not None:
                pbar.update(n_frames)
        # NOTE(review): the last two residues are skipped here (not only the
        # last one, which has no following N) — confirm this is intended.
        if i < traj.n_residues - 2:
            O = next(r.atoms_by_name("O")).index
            traj.xyz[:, O] = _guess_sp2_positions(
                center=traj.xyz[:, C],
                bonded=traj.xyz[:, CA],
                other=traj.xyz[:, Ns[i + 1]],
                angle=121 / 180 * np.pi,
                bond_length=0.124,
            )
            if pbar is not None:
                pbar.update(n_frames)
def dihedral_to_cartesian_tf_one_way_layers(
    dihedrals: tf.Tensor,
    cartesian: tf.Tensor,
    n: int,
) -> tf.Tensor:
    """Takes one of the cartesian sides (left, or right) and turns them into 3D.

    The left or right sides of the cartesian chain can have different number
    of atoms in them, depending on whether the protein has an even or odd
    number of central atoms.

    The first atom is kept fixed. In iteration ``i`` the first atom of the
    remaining chain is frozen. Every atom from the second one onward is then
    rotated by ``rotation_matrix(axis, -dihedrals[:, i])`` about the
    normalized axis from the first to the second remaining atom, with the
    second atom as pivot. The dihedrals are negated before use.

    Args:
        dihedrals (tf.Tensor): The dihedrals with shape (batch, None). Must be
            rank 2 (checked with ``tf.debugging.assert_rank``).
        cartesian (tf.Tensor): The cartesians with shape (batch, None, 3).
        n (int): The range variable for iteration. Must be the same as
            ``dihedrals.shape[-1]``.

    Returns:
        tf.Tensor: The concatenation (along axis 1) of the fixed first atom,
            the atoms frozen in each iteration and the remaining rotated
            atoms, i.e. as many atoms as ``cartesian``.

    """
    tf.debugging.assert_rank(dihedrals, 2)
    dihedrals = -dihedrals

    rotated = cartesian[:, 1:]
    collected_cartesians = [cartesian[:, :1]]
    for i in range(n):
        collected_cartesians.append(rotated[:, 0:1])
        axis = rotated[:, 1] - rotated[:, 0]
        # Normalize manually instead of using tf.norm, which has numeric
        # problems, see here:
        # https://datascience.stackexchange.com/q/80898
        axis /= tf.expand_dims(tf.sqrt(tf.reduce_sum(tf.square(axis), axis=1)), axis=-1)
        offset = rotated[:, 1:2]
        rotated = offset + tf.matmul(
            rotated[:, 1:] - offset, rotation_matrix(axis, dihedrals[:, i])
        )
    collected_cartesians.append(rotated)
    collected_cartesians = tf.concat(collected_cartesians, axis=1)
    return collected_cartesians
def backbone_hydrogen_oxygen_crossproduct(backbone_positions):
    """Unfinished placeholder: validate the input shape and return ``None``.

    Only checks that ``backbone_positions.shape[2]`` is a multiple of three
    (the original note refers to C, CA and N atoms). No computation is
    performed yet.

    """
    third_dim = backbone_positions.shape[2]
    assert third_dim % 3 == 0  # C, CA, N atoms, multiple of three
    return None
def guess_sp2_atom(cartesians, indices, angle_to_previous, bond_length):
    """Place one planar (sp2) substituent on every chain atom in ``indices``.

    For each index ``i``:

    1. The vector to the previous chain atom (``i - 1``) is taken.
    2. It is rotated by ``angle_to_previous`` about the unit normal of the
       plane it forms with the vector to the next chain atom (``i + 1``).
       The rotation is applied as the row-vector product
       ``prev_vec @ rotation_matrix(...)``.
    3. It is rescaled to ``bond_length``.

    For the last chain atom, which has no ``i + 1``, the atom at ``i - 2`` is
    used to define the plane instead.

    Args:
        cartesians: Chain coordinates, indexed as ``cartesians[:, atom]`` and
            yielding (batch, 3) vectors.
        indices: Iterable of integer chain indices to decorate.
            NOTE(review): ``i == 0`` would use ``i - 1 == -1`` (the last atom)
            because negative indices wrap.
        angle_to_previous: Rotation angle in radians passed to
            ``rotation_matrix``.
        bond_length: Length of the new bond, in the units of ``cartesians``.

    Returns:
        tf.Tensor: New positions stacked along axis 1, i.e.
            (batch, len(indices), 3).

    """
    added_cartesians = []
    for i in indices:
        prev_vec = cartesians[:, i - 1] - cartesians[:, i]
        try:
            next_vec = cartesians[:, i + 1] - cartesians[:, i]
        except (tf.errors.InvalidArgumentError, ValueError, IndexError):
            # ``i + 1`` is out of range. Eager TF raises InvalidArgumentError,
            # tracing inside tf.function raises ValueError and NumPy inputs
            # raise IndexError. Previously only the eager case was handled.
            next_vec = cartesians[:, i - 2] - cartesians[:, i]

        perpendicular_axis = tf.linalg.cross(prev_vec, next_vec)
        perpendicular_axis /= tf.norm(perpendicular_axis, axis=1, keepdims=True)
        bond_vec = tf.matmul(
            tf.expand_dims(prev_vec, 1),
            rotation_matrix(perpendicular_axis, angle_to_previous),
        )
        bond_vec = bond_vec[:, 0, :]
        bond_vec *= bond_length / tf.norm(bond_vec, axis=1, keepdims=True)
        added_cartesians.append(cartesians[:, i] + bond_vec)
    added_cartesians = tf.stack(added_cartesians, axis=1)
    return added_cartesians
def guess_amide_H(cartesians, N_indices):
    """Guess hydrogen positions for the nitrogens in ``N_indices``.

    The first entry of ``N_indices`` is skipped. The remaining nitrogens are
    decorated via ``guess_sp2_atom`` with an angle of 123 degrees (in radians)
    and a bond length of 1.10 (units of ``cartesians``).

    """
    decorated_nitrogens = N_indices[1::]
    hydrogen_angle = 123 / 180 * pi
    return guess_sp2_atom(
        cartesians,
        decorated_nitrogens,
        hydrogen_angle,
        1.10,
    )
def guess_amide_O(cartesians, C_indices):
    """Guess oxygen positions for every carbon in ``C_indices``.

    Delegates to ``guess_sp2_atom`` with an angle of 121 degrees (in radians)
    and a bond length of 1.24 (units of ``cartesians``).

    """
    oxygen_angle = 121 / 180 * pi
    return guess_sp2_atom(
        cartesians,
        C_indices,
        oxygen_angle,
        1.24,
    )
def rotation_matrix(axis_unit_vec, angle):
    """Build batched rotation matrices with Rodrigues' formula.

    ``R = cos(a) * I + sin(a) * K + (1 - cos(a)) * u u^T``, where ``K`` is the
    skew-symmetric cross-product matrix of the axis ``u``.

    Args:
        axis_unit_vec: Rotation axes, indexed as ``axis_unit_vec[:, k]`` for
            k in 0..2, i.e. (batch, 3). They are assumed to be normalized;
            this is not checked here.
        angle: Rotation angle(s) in radians. Two trailing axes are added so
            it broadcasts against (batch, 3, 3). Callers in this module pass
            either a Python float or a (batch,) tensor.

    Returns:
        tf.Tensor: Rotation matrices of shape (batch, 3, 3).

    """
    angle = tf.expand_dims(tf.expand_dims(angle, axis=-1), axis=-1)
    cos_a = tf.cos(angle)
    sin_a = tf.sin(angle)
    identity = tf.expand_dims(tf.eye(3), 0)

    ux = axis_unit_vec[:, 0]
    uy = axis_unit_vec[:, 1]
    uz = axis_unit_vec[:, 2]
    zero = tf.zeros(tf.shape(axis_unit_vec)[0])
    # built as (3, 3, batch) and moved to (batch, 3, 3)
    skew = tf.convert_to_tensor(
        [
            [zero, -uz, uy],
            [uz, zero, -ux],
            [-uy, ux, zero],
        ]
    )
    skew = tf.transpose(skew, [2, 0, 1])

    column = tf.expand_dims(axis_unit_vec, 2)
    outer = tf.matmul(column, tf.transpose(column, [0, 2, 1]))

    rot = cos_a * identity
    rot += sin_a * skew
    rot += (1 - cos_a) * outer
    return rot
def merge_cartesians(
    central_cartesians, N_indices, O_indices, H_cartesians, O_cartesians
):
    """Interleave central chain atoms with guessed H and O atoms.

    Atom 0 of the central chain is copied as-is. Each later central atom
    ``k`` is followed by the next hydrogen if ``k`` is in ``N_indices[1::]``,
    otherwise by the next oxygen if ``k`` is in ``O_indices``.

    Args:
        central_cartesians: Central chain, indexed ``[:, atom]``.
        N_indices: Nitrogen indices; the first entry is ignored.
        O_indices: Indices of the central atoms that receive an oxygen.
        H_cartesians: Hydrogen positions, consumed in order along axis 1.
        O_cartesians: Oxygen positions, consumed in order along axis 1.

    Returns:
        tf.Tensor: All atoms stacked along axis 1.

    Raises:
        AssertionError: If not every H and O position was consumed.

    """
    hydrogen_carriers = N_indices[1::]
    merged = [central_cartesians[:, 0]]
    next_h = 0
    next_o = 0
    for atom in range(1, central_cartesians.shape[1]):
        merged.append(central_cartesians[:, atom])
        if atom in hydrogen_carriers:
            merged.append(H_cartesians[:, next_h])
            next_h += 1
            continue
        if atom in O_indices:
            merged.append(O_cartesians[:, next_o])
            next_o += 1
    merged = tf.stack(merged, axis=1)
    expected_atoms = (
        central_cartesians.shape[1] + H_cartesians.shape[1] + O_cartesians.shape[1]
    )
    assert merged.shape[1] == expected_atoms
    return merged
def dihedral_backmapping(
    pdb_path, dihedral_trajectory, rough_n_points=-1, sidechains=None
):
    """
    Takes a pdb file with a peptide and creates a trajectory based on the dihedral angles given.
    It simply rotates around the dihedral angle axis. In the result side-chains might overlap but the backbone should
    turn out quite well.

    Backbone dihedral columns are matched, in order, to all psi, then all
    omega, then all phi dihedrals of the protein residues. Values are treated
    as radians and converted to degrees (``value / (2 * pi) * 360``) before
    they are applied.

    :param pdb_path: (str) Path to the PDB file, loaded with MDAnalysis.
    :param dihedral_trajectory:
        array-like of shape (traj_length, number_of_dihedrals)
    :param rough_n_points: (int) a step_size to select a subset of values from dihedral_trajectory is calculated by
        max(1, int(len(dihedral_trajectory) / rough_n_points)). Any value <= 0 (e.g. the default -1) uses all values.
    :param sidechains: Optional array of sidechain dihedrals (radians) with one
        row per frame of ``dihedral_trajectory``. If it has twice as many
        columns as sidechain dihedrals were found, only every second column
        is used.
    :return: (MDAnalysis.Universe)
    """
    # Non-positive values mean "use every frame". 0 previously raised a
    # ZeroDivisionError.
    if rough_n_points > 0:
        step_size = max(1, int(len(dihedral_trajectory) / rough_n_points))
    else:
        step_size = 1
    dihedral_trajectory = dihedral_trajectory[::step_size]
    if sidechains is not None:
        sidechain_dihedral_trajectory = sidechains[::step_size]

    uni = mda.Universe(pdb_path, format="PDB")
    protein = uni.select_atoms("protein")

    # Order matters: all psi, then all omega, then all phi.
    dihedrals = []
    for selector in ("psi_selection", "omega_selection", "phi_selection"):
        for residue in protein.residues:
            selection = getattr(residue, selector)()
            if selection:
                dihedrals.append(selection)

    sidechain_dihedrals = []
    if sidechains is not None:
        for residue in protein.residues:
            chi1 = residue.chi1_selection()
            if chi1:
                sidechain_dihedrals.append(chi1)
        # NOTE(review): this tests for an attribute literally named "chi2",
        # "chi3", ... and then calls "<name>_selection". Verify that the
        # residue class in use provides these; otherwise only chi1 is used.
        for chi in ("chi2", "chi3", "chi4", "chi5"):
            for residue in protein.residues:
                if chi in residue.__dir__():
                    sidechain_dihedrals.append(getattr(residue, f"{chi}_selection")())

        if sidechain_dihedral_trajectory.shape[1] == len(sidechain_dihedrals) * 2:
            sidechain_dihedral_trajectory = sidechain_dihedral_trajectory[:, ::2]

    _expand_universe(uni, len(dihedral_trajectory))

    # Iterating ``uni.trajectory`` advances the active frame that
    # ``_set_dihedral`` modifies.
    if sidechains is None:
        for dihedral_values, _ts in zip(dihedral_trajectory, uni.trajectory):
            for dihedral, value in zip(dihedrals, dihedral_values):
                _set_dihedral(dihedral, protein, value / (2 * pi) * 360)
    else:
        for dihedral_values, sidechain_dihedral_values, _ts in zip(
            dihedral_trajectory, sidechain_dihedral_trajectory, uni.trajectory
        ):
            for dihedral, value in zip(dihedrals, dihedral_values):
                _set_dihedral(dihedral, protein, value / (2 * pi) * 360)
            for dihedral, value in zip(sidechain_dihedrals, sidechain_dihedral_values):
                _set_dihedral(dihedral, protein, value / (2 * pi) * 360)
    return uni
def _set_dihedral(dihedral, atoms, angle):
    """Rotate part of ``atoms`` so that ``dihedral`` reaches ``angle``.

    All atoms of ``atoms`` with a universe index greater than that of the
    third dihedral atom are rotated by ``angle - current`` about the axis from
    the second to the third dihedral atom, pivoting at the third atom.

    Args:
        dihedral: Four-atom group exposing ``.dihedral.value()`` and
            per-atom ``.position`` / ``.ix``.
        atoms: Atom group containing the atoms that may be moved.
        angle: Target angle, in the unit returned by
            ``dihedral.dihedral.value()``. The caller in this module converts
            radians to degrees first.

    """
    current_angle = dihedral.dihedral.value()
    # Select by 0-based universe index instead of slicing ``atoms`` with the
    # PDB serial ``id``. The serial is only a valid position inside ``atoms``
    # when ids are contiguous, 1-based and ``atoms`` starts at the first atom
    # of the universe; in that case both selections are identical.
    head = atoms[atoms.ix > dihedral[2].ix]
    vec = dihedral[2].position - dihedral[1].position
    head.rotateby(angle - current_angle, vec, dihedral[2].position)
def _expand_universe(universe, length):
    """Replace the trajectory of ``universe`` with ``length`` tiled copies.

    The current coordinates of all atoms are collected with
    ``AnalysisFromFunction`` and repeated ``length`` times along the frame
    axis. The result is loaded back into ``universe`` as an in-memory
    trajectory (``MemoryReader``), which can then be modified frame by frame.

    """
    collector = AnalysisFromFunction(
        lambda group: group.positions.copy(), universe.atoms
    )
    collector.run()
    frames = collector.results["timeseries"]
    universe.load_new(np.tile(frames, (length, 1, 1)), format=MemoryReader)