Coverage for encodermap/misc/rotate.py: 9%
189 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/misc/rotate.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Helpers to apply rotations to molecular coordinates.
24"""
26################################################################################
27# Imports
28################################################################################
31# Future Imports at the top
32from __future__ import annotations
34# Standard Library Imports
35import copy
36import warnings
37from copy import deepcopy
38from typing import TYPE_CHECKING, Optional, Union, overload
40# Third Party Imports
41import numpy as np
42import transformations as trans
43from optional_imports import _optional_import
45# Encodermap imports
46from encodermap.trajinfo.trajinfo_utils import _delete_bond
49################################################################################
50# Optional Imports
51################################################################################
54md = _optional_import("mdtraj")
57################################################################################
58# Typing
59################################################################################
62if TYPE_CHECKING:
63 # Third Party Imports
64 import networkx as nx
65 from mdtraj.core.topology import Atom
68################################################################################
69# Globals
70################################################################################
73__all__: list[str] = ["mdtraj_rotate"]
76################################################################################
77# Functions
78################################################################################
81def arbitrary_dihedral(
82 pos: np.ndarray,
83 out: Optional[np.ndarray] = None,
84) -> np.ndarray:
85 """Computes the dihedral angles of a position array with shape (n_frames, 4).
87 Args:
88 pos (np.ndarray): The positions between which to calculate the dihedrals.
89 out (np.ndarray): A location into which the result is stored. If provided,
90 it must have a shape that the inputs broadcast to. If not provided
91 or None, a freshly-allocated array is returned. A tuple (possible
92 only as a keyword argument) must have length equal to the number of outputs.
94 Returns:
95 np.ndarray: The dihedral angles in radians.
97 """
98 p0 = pos[:, 0]
99 p1 = pos[:, 1]
100 p2 = pos[:, 2]
101 p3 = pos[:, 3]
103 b1 = -1.0 * (p1 - p0)
104 b2 = p2 - p1
105 b3 = p3 - p2
107 c1 = np.cross(b2, b3)
108 c2 = np.cross(b1, b2)
110 p1 = (b1 * c1).sum(-1)
111 p1 *= (b2 * b2).sum(-1) ** 0.5
112 p2 = (c1 * c2).sum(-1)
114 return np.arctan2(p1, p2, out)
117def mdtraj_rotate(
118 traj: md.Trajectory,
119 angles: np.ndarray,
120 indices: np.ndarray,
121 deg: bool = False,
122 check_cyclic_backbone: bool = True,
123 verify_every_rotation: bool = False,
124 drop_proline_angles: bool = False,
125 delete_sulfide_bridges: bool = True,
126) -> md.Trajectory:
127 """Uses MDTraj and Christoph Gohlke's transformations.py to set bond
128 rotations provided traj.
130 Input can be in radian (set `deg` to False) or degree (set `deg` to True).
132 General procedure:
133 * Carry out some checks. Shapes of input need to be correct. `traj`
134 needs to have a single frame and not be of a cyclic protein nor
135 contain multiple chains.
136 * Get the indices of the near and far side of the rotations. Every
137 dihedral angle is indexed by 4 atoms. The rotational axis is located
138 between the central two atoms (dihedral[1:3]).
139 * Extend the trajectory. The lengths of dihedrals and sidechain_dihedrals
140 should match. The frame given by top will be duplicated
141 len(dihedrals)-times.
142 * Get the current angles. We know what the final angles should be, but
143 now how far to rotate the bonds. This can be done by getting the
144 difference between current and target angle.
145 * Rotate the bonds. Using Christoph Gohlke's transformations.py,
146 the rotation matrix is constructed and the array is padded with
147 zeros to resemble an array of quaternions.
149 Args:
150 traj (mdtraj.Trajectory): The trajectory to use. Needs to have only one frame.
151 angles (list[list[float]], np.ndarray): The values the angles should
152 assume after the rotations. This arg can either be a nested list
153 with floats or (better) a numpy array with the shape angles.shape =
154 (n_new_frames, n_indexed_dihedrals). Here, angles.shape[0] defines
155 how many frames the output trajectory is going to have and angles.shape[1]
156 should be similar to the number of dihedrals you want to rotate around.
157 A shape of (4, 2) would indicate that two dihedrals are going to be
158 used for rotation and the output trajectory is going to have 4 frames.
159 indices (list[list[int]], np.ndarray): A list of ints indexing the
160 dihedrals to be rotated around. Naturally indices.shape[1] needs
161 to be 4. Additionally indices.shape[0] needs to be the same as
162 angles.shape[1]. indices indexes the angles along axis 1 and angles
163 sets the values of those angles along axis 0.
164 deg (bool, optional): Whether argument `angles` is in deg.
165 Defaults to False.
166 check_cyclic_backbone (bool): Whether the backbone should be
167 checked for being cyclic. Rotating around a backbone angle for a
168 cyclic protein is not possible and thus an Exception is raised.
169 However, rotation around sidechain dihedrals is still possible.
170 If you are sure you want to rotate sidechain dihedrals set this
171 argument to False to prevent the cyclic backbone check.
172 Defaults to True.
173 verify_every_rotation (bool): Whether the rotation succeeded.
174 drop_proline_angles (bool): Whether to automatically drop proline
175 angles and indices.
176 delete_sulfide_bridges (bool): Whether to automatically remove bonds from
177 between cysteine residues.
179 Raises:
180 Exception: If the input seems like it is in degrees, but `deg` is False.
181 Exception: If `traj` contains more than 1 frame.
182 Exception: If traj is not fully connected.
183 Exception: If shapes of `angles` and `indices` mismatches.
184 Exception: If shape[1] of `indices` is not 4.
185 Exception: If backbone is cyclic and check_cyclic_backbone is True.
186 Exception: If the first rotation does not reach a tolerance of 1e-3.
188 Returns:
189 mdtraj.Trajectory: An MDTraj trajectory with applied rotations.
191 Examples:
192 >>> import mdtraj as md
193 >>> import numpy as np
195 >>> # load an arbitrary protein from the pdb
196 >>> traj = md.load_pdb('https://files.rcsb.org/view/1GHC.pdb')
197 >>> print(traj.n_frames)
198 14
200 >>> # traj has multiple frames so we remove all but one
201 >>> traj = traj[0]
203 >>> # Get indices of psi_angles
204 >>> psi_indices, old_values = md.compute_psi(traj)
206 >>> # set every psi angle to be either 0 or 180 deg
207 >>> angles = np.full((len(psi_indices), 2), [0, 180]).T
209 >>> # create the new traj with the desired rotations
210 >>> out_traj = mdtraj_rotate(traj, angles, psi_indices, deg=True)
211 >>> print(out_traj.n_frames)
212 2
214 >>> # check values
215 >>> _, new_values = md.compute_psi(out_traj)
216 >>> print(np.abs(np.rad2deg(new_values[0, :2]).round(0))) # prevent rounding inconsistencies
217 [0. 0.]
218 >>> print(np.abs(np.rad2deg(new_values[1, :2]).round(0))) # prevent rounding inconsistencies
219 [180. 180.]
221 """
222 # Third Party Imports
223 import networkx as nx
225 if deg:
226 angles = np.deg2rad(angles)
227 else:
228 if np.any(angles > 2 * np.pi):
229 print(
230 "Some of your input for `angles` is larger than 2pi. "
231 "This suggests, that you provided `angles` in deg not in rad. "
232 "Please set the argument `deg` to True to transpose the values "
233 "of `angles`."
234 )
236 # make np arrays
237 indices = np.asarray(indices).astype(int)
238 angles = np.asarray(angles)
240 # if only one rotation/index is given expand dims
241 if angles.ndim == 1:
242 angles = np.expand_dims(angles, -1)
243 if indices.ndim == 1:
244 indices = np.expand_dims(indices, 0)
246 # check whether traj has only one frame:
247 if traj.n_frames > 1:
248 raise Exception(
249 f"The provided `traj` has {traj.n_frames}. "
250 f"Please provide a traj with only 1 frame."
251 )
252 traj = deepcopy(traj)
254 # check if shape of indices and dihedrals is consistent
255 if angles.shape[1] != indices.shape[0]:
256 raise Exception(
257 f"Shapes of `angles` and `indices` mismatch. Shape[1] of `angles` is "
258 f"{angles.shape[1]}, which indicates that you also want to rotate "
259 f"around {angles.shape[1]} dihedral(s), but indices indexes "
260 f"{indices.shape[0]} dihedral angle(s). The shapes of the inputs "
261 f"need to match. indices.shape[0] == angles.shape[1]"
262 )
264 # check whether 4 atoms are indexed
265 if indices.shape[1] != 4:
266 raise Exception(
267 f"The shape of `indices` needs to be (n_dihedrals, 4), meaning, that "
268 f"a dihedral angle is defined by 4 atoms. Your `indices` argument` "
269 f"has shape {indices.shape} which is not allowed."
270 )
272 # check whether structure is whole
273 g = traj.top.to_bondgraph()
274 if not nx.is_connected(g):
275 raise Exception(
276 "Structure is disjoint and not fully connected. This can be caused by"
277 "multiple problems: Topology was not parsed correctly "
278 "(nonstandard residues), traj contains multiple proteins (chains), "
279 "or water molecules. You can add the bonds with "
280 "`mdtraj.Topology.add_bond()`, if you know what bonds exactly are missing."
281 )
283 # check whether is cyclic
284 backbone = traj.atom_slice(traj.top.select("backbone"))
285 try:
286 edges = nx.algorithms.cycles.find_cycle(backbone.top.to_bondgraph())
287 except nx.NetworkXNoCycle:
288 pass
289 else:
290 if check_cyclic_backbone:
291 raise Exception(
292 "The Backbone of your protein is cyclic. You can not rotate "
293 "around the backbone dihedral (Ramachandran) angles. "
294 "You can still rotate around sidechain angles. For that set the "
295 "`check_cyclic_backbone` argument of this function to "
296 "True to silence this warning."
297 )
299 # remove proline angles from angles and indices
300 if drop_proline_angles:
301 offending_atoms = set(traj.top.select("resname PRO and (name CA or name N)"))
302 slice = []
303 for i, ind in enumerate(indices):
304 ind = ind[1:3]
305 if len(offending_atoms.intersection(set(ind))) == 2:
306 slice.append(i)
307 slice = np.array(slice)
308 angles = np.delete(angles, slice, axis=1)
309 indices = np.delete(indices, slice, axis=0)
310 assert not offending_atoms.issubset(set(indices[:, 1:3].flatten()))
312 # delete sulfide bridges from the topology
313 if delete_sulfide_bridges:
314 for i, (a, b) in enumerate(traj.top.bonds):
315 if (
316 a.element.symbol == "S"
317 and b.element.symbol == "S"
318 and a.residue.index != b.residue.index
319 ):
320 traj.top = _delete_bond(traj.top, (a, b))
321 g = traj.top.to_bondgraph()
323 # get near and far sides
324 dih_near_sides, dih_far_sides = _get_near_and_far_networkx(
325 g, indices[:, 1:3], top=traj.top
326 )
328 # extend the traj
329 for i in range(len(angles)):
330 if i == 0:
331 out_traj = copy.deepcopy(traj)
332 else:
333 out_traj = out_traj.join(traj)
335 # adjust the torsions
336 new_xyz = copy.deepcopy(out_traj.xyz)
337 for i in range(angles.shape[0]):
338 for j in range(angles.shape[1]):
339 # central_dihedrals
340 near_side = dih_near_sides[j]
341 far_side = dih_far_sides[j]
342 dihedral = indices[j]
343 bond = dihedral[1:3]
345 # define inputs
346 target_angle = angles[i, j]
347 current_angle = _dihedral(new_xyz[i], dihedral)[0][0]
348 angle = target_angle - current_angle
349 direction = np.diff(new_xyz[i, bond], axis=0).flatten()
350 pivot_point = new_xyz[i, bond[0]]
352 # perform rotation
353 rotmat = trans.rotation_matrix(angle, direction, pivot_point)
354 padded = np.pad(
355 new_xyz[i][far_side],
356 ((0, 0), (0, 1)),
357 mode="constant",
358 constant_values=1,
359 )
360 new_xyz[i][far_side] = rotmat.dot(padded.T).T[:, :3]
362 if i == 0 and j == 0 and verify_every_rotation:
363 dih_indexes = traj._CVs.attrs["central_dihedrals"][j]
364 s = f"Near and far side for dihedral {[str(traj.top.atom(x)) for x in dih_indexes]} are:"
365 s += (
366 f"\nNear: {[str(traj.top.atom(x)) for x in near_side]}, {near_side}"
367 )
368 s += f"\nFar: {[str(traj.top.atom(x)) for x in dih_far_sides[j][:12]]}..., {dih_far_sides[j][:12]}..."
369 s += f"\nRotation around bond {[str(traj.top.atom(x)) for x in bond]}, {bond}."
370 s += f"\nPositions of near side before rotation are\n{out_traj.xyz[i][near_side]}."
371 s += f"\nPositions of near side after rotation aren\n{new_xyz[i][near_side]}"
372 print(s)
374 # verify
375 if verify_every_rotation:
376 _ = np.rad2deg(_dihedral(new_xyz[i], dihedral))[0][0]
377 if not np.isclose(_, target_angle, atol=1e-3):
378 s = f"Adjusting dihedral angle for atoms {[str(traj.top.atom(x)) for x in dihedral]} failed with a tolerance of 1e-4."
379 s += f"\nTarget angle was {target_angle} deg, but rotation yieled angle with {_} deg."
380 s += f"\nCurrent angle was {current_dih[j]}. To reach target angle is a rotation of {angle} degrees was carried out."
381 s += f"\nRotation axis was vector from {traj.top.atom(bond[0])} to {traj.top.atom(bond[1])}"
382 s += f"\nOnly these atoms should have been affected by rotation: {far_side}"
383 s += "\nBut somehow this method still crashed. Maybe these prints will help."
384 raise Exception(s)
386 # overwrite traj and return
387 out_traj.xyz = new_xyz
388 return out_traj
391@overload
392def _get_near_and_far_networkx( 392 ↛ exitline 392 didn't jump to the function exit
393 bondgraph: nx.Graph,
394 edge_indices: np.ndarray,
395 top: Optional[md.Topology] = None,
396 parallel: bool = True,
397) -> tuple[np.ndarray, None]: ...
400@overload
401def _get_near_and_far_networkx( 401 ↛ exitline 401 didn't jump to the function exit
402 bondgraph: nx.Graph,
403 edge_indices: np.ndarray,
404 top: Optional[md.Topology] = None,
405 parallel: bool = False,
406) -> tuple[list[np.ndarray], list[np.ndarray]]: ...
409def _get_near_and_far_networkx(
410 bondgraph: nx.Graph,
411 edge_indices: np.ndarray,
412 top: Optional[md.Topology] = None,
413 parallel: bool = False,
414) -> Union[tuple[list[np.ndarray], list[np.ndarray]], tuple[np.ndarray, None]]:
415 """Returns near and far sides for a list of edges giving the indices of the
416 two atoms at which the structure is broken.
418 Args:
419 bondgraph (networkx.classes.graph.Graph): The bondgraph describing the protein.
420 edge_indices (np.ndarray): The edges where the graph will be broken at.
421 top (Optional[md.Topology]): Used for printing helpful messages in exceptions.
422 parallel (bool): Whether to return a dense array and None.
424 Returns:
425 tuple[list[np.ndarray], list[np.ndarray]]: A tuple containing the following:
426 - near_sides (list[np.ndarray]): List of integer arrays giving the near
427 - sides. len(near_sides) == len(edge_indices).
428 - far_sides (list[np.ndarray]): Same as near sides, but this time the far sides.
430 """
431 # Third Party Imports
432 import networkx as nx
433 from networkx.algorithms.components.connected import connected_components
435 assert edge_indices.shape[1] == 2, (
436 f"Can only take `edge_indices` as a numpy array, with shape[1] = 2, but "
437 f"the provided `edge_indices` has shape {edge_indices.shape[1]=}."
438 )
440 if parallel:
441 out = np.zeros(shape=(len(edge_indices), len(bondgraph))).astype(bool)
443 near_sides = []
444 far_sides = []
445 for i, edge in enumerate(edge_indices):
446 G = nx.convert_node_labels_to_integers(bondgraph).copy()
447 try:
448 G.remove_edge(*edge)
449 except nx.NetworkXError as e:
450 if top:
451 raise Exception(
452 f"Seems like the edge {[top.atom(a) for a in edge]=} "
453 f"{[top.atom(a).index for a in edge]=}is not "
454 f"part of the graph. This might originate from a bond, that has "
455 f"been deleted, but which atoms are still considered to be part of "
456 f"a dihedral. Maybe you want supplied this topology to the "
457 f"mdtraj_backmapping method with wrong featurization."
458 ) from e
459 else:
460 raise Exception(
461 f"Please provide arg `top` to learn more about this Exception"
462 ) from e
463 except TypeError as e:
464 raise Exception(
465 f"Could not remove the edge {edge=}, {edge_indices.shape=}."
466 ) from e
467 components = [*connected_components(G)]
468 if len(components) != 2:
469 if top is None:
470 raise Exception(
471 f"Splitting the topology of the trajectory at the edge "
472 f"{edge} does not work. Provide a topology to see, "
473 "which atoms are affected"
474 )
475 else:
476 path = nx.shortest_path(G, *edge)
477 assert isinstance(path, list)
478 path_atoms = [top.atom(i) for i in path]
479 a1 = top.atom(edge[0])
480 a2 = top.atom(edge[1])
481 raise Exception(
482 f"Splitting at edge {edge} does not work. Here are the "
483 f"atoms: {a1} and {a2}. Removing "
484 f"this edge resulted in {len(components)} disconnected components. "
485 f"The resSeqs of the residues are {a1.residue.resSeq} and "
486 f"{a2.residue.resSeq}. The indices are {a1.residue.index} and "
487 f"{a2.residue.index}. These two atoms are still connected via "
488 f"this path: {path_atoms}."
489 )
491 if edge[1] in components[0] and edge[0] in components[1]:
492 components = components[::-1]
493 assert len(components) == 2, print(
494 f"Protein might be cyclic or contain more than 1 chain. {len(components)}"
495 )
496 assert edge[0] in components[0] and edge[1] in components[1], print(
497 "Finding near and far sides failed."
498 )
499 subgraph = G.subgraph(components[-1]).copy()
500 far = np.asarray(subgraph.nodes)
501 subgraph = G.subgraph(components[0]).copy()
502 near = np.asarray(subgraph.nodes)
503 if parallel:
504 out[i][near] = True
505 else:
506 far_sides.append(far)
507 near_sides.append(near)
508 if not parallel:
509 return near_sides, far_sides
510 else:
511 return out, None
514def _angle(
515 xyz: np.ndarray,
516 indices: np.ndarray,
517) -> float:
518 """Returns current angle between positions.
520 Adapted from MDTraj.
522 Args:
523 xyz (np.ndarray). This function only takes a xyz array of a single frame and uses np.expand_dims()
524 to make that fame work with the `_displacement` function from mdtraj.
525 indices (Union[np.ndarray, list]): List of 3 ints describing the dihedral.
527 Returns:
528 np.ndarray: The angle.
530 """
531 indices = np.expand_dims(np.asarray(indices), 0)
532 xyz = np.expand_dims(xyz, 0)
533 ix01 = indices[:, [1, 0]]
534 ix21 = indices[:, [1, 2]]
536 u_prime = _displacement(xyz, ix01)
537 v_prime = _displacement(xyz, ix21)
538 u_norm = np.sqrt((u_prime**2).sum(-1))
539 v_norm = np.sqrt((v_prime**2).sum(-1))
541 u = u_prime / (u_norm[..., np.newaxis])
542 v = v_prime / (v_norm[..., np.newaxis])
544 return np.arccos((u * v).sum(-1))
547def _dihedral(
548 xyz: np.ndarray,
549 indices: np.ndarray,
550) -> float:
551 """Returns current dihedral angle between positions.
553 Adapted from MDTraj.
555 Args:
556 xyz (np.ndarray). This function only takes a xyz array of a single frame and uses np.expand_dims()
557 to make that fame work with the `_displacement` function from mdtraj.
558 indices (Union[np.ndarray, list]): List of 4 ints describing the dihedral.
560 Returns:
561 np.ndarray: The dihedral.
563 """
564 indices = np.expand_dims(np.asarray(indices), 0)
565 xyz = np.expand_dims(xyz, 0)
566 ix10 = indices[:, [0, 1]]
567 ix21 = indices[:, [1, 2]]
568 ix32 = indices[:, [2, 3]]
570 b1 = _displacement(xyz, ix10)
571 b2 = _displacement(xyz, ix21)
572 b3 = _displacement(xyz, ix32)
574 c1 = np.cross(b2, b3)
575 c2 = np.cross(b1, b2)
577 p1 = (b1 * c1).sum(-1)
578 p1 *= (b2 * b2).sum(-1) ** 0.5
579 p2 = (c1 * c2).sum(-1)
581 return np.arctan2(p1, p2, None)
584def _displacement(xyz: np.ndarray, pairs: np.ndarray) -> np.ndarray:
585 """Displacement vector between pairs of points in each frame
587 Args:
588 xyz (np.ndarray): The coordinates of the atoms.
589 pairs (np.ndarray): An array with integers and shape (n_pairs, 2),
590 defining the atom paris between which the displacement will
591 be calculated.
593 Returns:
594 np.ndarray: An array with shape (n_pairs, ).
596 """
597 value = np.diff(xyz[:, pairs], axis=2)[:, :, 0]
598 assert value.shape == (
599 xyz.shape[0],
600 pairs.shape[0],
601 3,
602 ), "v.shape %s, xyz.shape %s, pairs.shape %s" % (
603 str(value.shape),
604 str(xyz.shape),
605 str(pairs.shape),
606 )
607 return value