Coverage for encodermap/misc/backmapping.py: 35%
322 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/misc/backmapping.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
"""Backmapping functions largely based upon encodermap_tf1's backmapping and martini-tools' backwards.py.

ToDo:
    * Could the rotation matrices in TensorFlow be accelerated by using quaternions?
    * Support multiple topologies.

"""
30##############################################################################
31# Imports
32##############################################################################
34import copy
35from math import pi
37import numpy as np
38import tensorflow as tf
40from .._optional_imports import _optional_import
41from ..misc import transformations as trans
42from ..misc.errors import BadError
44##############################################################################
45# Optional Imports
46##############################################################################
49md = _optional_import("mdtraj")
50mda = _optional_import("MDAnalysis")
51AnalysisFromFunction = _optional_import(
52 "MDAnalysis", "analysis.base.AnalysisFromFunction"
53)
54MemoryReader = _optional_import("MDAnalysis", "coordinates.memory.MemoryReader")
57##############################################################################
58# Globals
59##############################################################################
62__all__ = ["backbone_hydrogen_oxygen_crossproduct"]
65##############################################################################
66# Public Functions
67##############################################################################
def split_and_reverse_dihedrals(x):
    """Splits dihedrals in BackMapping model into left (reversed) and right part.

    These dihedrals are then used to bring the chain_in_plane into 3D. For an
    even number of dihedrals both parts have the same length; for an odd
    number the left part holds the central dihedral as well.

    The split point only depends on the static width ``x.shape[1]``, so it is
    computed with plain integer arithmetic. This works identically in eager
    and graph mode and keeps the static shapes of both parts.

    Args:
        x (Union[tf.Tensor, np.ndarray]): The dihedrals with shape (None, n_residues * 3 - 3).

    Returns:
        tuple: ``(left, right)`` where ``left`` is the first part in reversed
        order and ``right`` is the remaining part in original order.

    Raises:
        ValueError: If the number of dihedrals (axis 1) is not statically known.

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_dihedrals
        >>> import numpy as np

        >>> # create dihedrals for protein with 3 residues, i.e. 3*3 - 3 = 6 central dihedral angles
        >>> # single sample will be used -> shape = (1, 6)
        >>> np.random.seed(20)
        >>> dihedrals = np.random.random((1, 6)) * 2 * np.pi
        >>> print(dihedrals)
        [[3.69533481 5.64050171 5.60165278 5.12605805 0.22550092 4.34644107]]

        >>> dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)
        >>> print(dihedrals_left, dihedrals_right)
        [[5.60165278 5.64050171 3.69533481]] [[5.12605805 0.22550092 4.34644107]]

    """
    n_dihedrals = x.shape[1]
    if n_dihedrals is None:
        raise ValueError(
            "The number of dihedrals (axis 1) must be known statically to split them."
        )
    middle = n_dihedrals // 2
    if n_dihedrals % 2 == 0:
        return x[:, middle - 1 :: -1], x[:, middle:]
    # odd count: the central dihedral goes to the (reversed) left part
    return x[:, middle::-1], x[:, middle + 1 :]
def split_and_reverse_cartesians(x):
    """Split cartesians into a reversed left part and a right part.

    A dihedral is defined by four atoms, so the two parts share three atoms
    around the middle of the chain. The following holds:
        left[0] = right[2]
        left[1] = right[1]
        left[2] = right[0]

    Args:
        x (tf.Tensor): The cartesians with shape (None, n_residues * 3, 3).

    Returns:
        tuple: ``(left, right)``; ``left`` runs from just past the middle back
        to the first atom, ``right`` runs from just before the middle to the end.

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_cartesians
        >>> import numpy as np

        >>> # create cartesians for protein with 3 residues, i.e. 9
        >>> # single sample will be used -> shape = (1, 9, 3)
        >>> np.random.seed(20)
        >>> cartesians = np.random.random((1, 9, 3)) * 10

        >>> cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)

        >>> print(cartesians_left.shape, cartesians_right.shape)
        (1, 6, 3) (1, 6, 3)

        >>> print(cartesians_left[:,0] == cartesians_right[:,2])
        [[ True  True  True]]

        >>> print(cartesians_left[:,1] == cartesians_right[:,1])
        [[ True  True  True]]

        >>> print(cartesians_left[:,2] == cartesians_right[:,0])
        [[ True  True  True]]

    """
    half = x.shape[1] // 2
    left = x[:, half + 1 :: -1]
    right = x[:, half - 1 :]
    return left, right
def dihedrals_to_cartesian_tf_layers(dihedrals, cartesians):
    """Calculates dihedrals to cartesians in Graph/Layer execution.

    The chain is split at its centre. The central atoms keep their positions
    from `cartesians`, while the left and right halves are rotated outward,
    one dihedral at a time.

    Args:
        dihedrals (tf.Tensor): The dihedrals of shape (None, n_residues * 3 - 3).
        cartesians (tf.Tensor): The cartesians of shape (None, n_residues * 3, 3).
            A single conformation of shape (n_residues * 3, 3) is accepted as
            well; it is repeated for every row of `dihedrals`.

    Returns:
        tf.Tensor: The new cartesians with the same number of atoms as `cartesians`.

    """
    if len(cartesians.shape) == 2:
        # if a single line of cartesians is passed it is repeated to match the number of dihedrals
        cartesians = tf.tile(
            tf.expand_dims(cartesians, axis=0), [tf.shape(dihedrals)[0], 1, 1]
        )

    # split and reverse so that the center of the molecule stays in the 2D plane
    # and the left and right ends curl into the 3rd dimension
    cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)
    dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)

    # For an odd number of dihedrals the left half holds one more dihedral than
    # the right half. Every dihedral of each half has to be applied; using
    # n // 2 for both halves would silently skip the first dihedral.
    n_dihedrals = dihedrals.shape[1]
    n_right = n_dihedrals // 2
    n_left = n_dihedrals - n_right

    new_cartesians_left = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals_left, cartesians_left, n_left
    )
    new_cartesians_right = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals_right, cartesians_right, n_right
    )

    # the first three atoms of the right part are the shared central atoms,
    # which are already contained in the (re-reversed) left part
    new_cartesians = tf.concat(
        [new_cartesians_left[:, ::-1], new_cartesians_right[:, 3:]], axis=1
    )

    return new_cartesians
def mdtraj_backmapping(
    top,
    dihedrals,
    sidechain_dihedrals=None,
    trajs=None,
    fake_dihedrals=False,
    verify_every_rotation=True,
    angle_type="radian",
):
    """Uses MDTraj and Christoph Gohlke's transformations.py to rotate the bonds in the provided topology.

    Input currently only in angles.

    General procedure:
        * Decide on which topology to use. If `trajs` is a TrajEnsemble, the
          trajectory whose topology file equals `top` is used as the parent.
          NOTE(review): `dihedrals` and `sidechain_dihedrals` are not adapted to
          that trajectory here. Their columns are expected to match its
          dihedrals, apart from the proline columns dropped below.
        * Drop dihedrals whose rotation bond is a proline ring bond (N-CA for
          the central dihedrals, CA-CB and CB-CG for the sidechains). The
          corresponding input columns are dropped as well.
        * Get the near and far sides of every rotation bond by splitting the
          bond graph at that bond. The far side is the part that contains the
          second atom of the bond, and only the far side is moved.
        * Extend the trajectory. The frame given by `top` is duplicated
          len(dihedrals) times.
        * Get the current angles. The target angles are known, and the rotation
          to apply is the difference between the target and the current angle.
        * Rotate the bonds. A 4x4 rotation matrix about the bond is built with
          transformations.py. It is applied to the far-side coordinates after
          they are padded with a column of ones (homogeneous coordinates).

    Args:
        top (str): The topology file to use.
        dihedrals (np.ndarray): The central dihedrals to put onto the trajectory.
            len(dihedrals) is the number of frames of the output trajectory.
            dihedrals.shape[1] needs to be the same as the number of central
            dihedrals in the topology.
        sidechain_dihedrals (Union[np.ndarray, None], optional): The sidechain dihedrals to put onto the
            trajectory. Must have the same number of frames as `dihedrals`. If
            None is provided, the sidechain dihedrals are not set explicitly.
            Defaults to None.
        trajs (Union[em.TrajEnsemble, em.SingleTraj, None], optional): Encodermap trajectory class.
            Can accelerate loading of current dihedral angles. For a
            TrajEnsemble, it is also checked that `top` is one of its topology
            files. Defaults to None.
        fake_dihedrals (bool, optional): Replace all input angles with a linear
            sweep from -170 to 170 degrees (converted to radians). For
            debugging. Defaults to False.
        verify_every_rotation (bool, optional): After every rotation, check that
            the dihedral reached its target angle (within 1e-3 rad, modulo
            2 pi). Also print details about the first rotation. Defaults to True.
        angle_type (str, optional): Unit of the input angles, either "radian" or
            "degree". Degrees are converted to radians. Defaults to "radian".

    Raises:
        Exception: If `angle_type` is neither "radian" nor "degree".
        Exception: If top is not part of the TrajEnsemble class provided in argument `trajs`.
        ValueError: If `dihedrals` contains no frames.
        BadError: If the topology is not connected, if the central dihedral
            indices cannot be determined, or (with `verify_every_rotation`) if
            a rotation did not reach its target angle.
        NotImplementedError: If sidechain dihedrals are requested but cannot be
            obtained for the topology.

    Returns:
        mdtraj.Trajectory: An MDTraj trajectory with the correct dihedrals/side-dihedrals.

    """
    import networkx as nx

    if angle_type not in ("radian", "degree"):
        raise Exception("Argument `angle_type` must be either 'radian' or 'degree'.")

    # work on arrays; everything below is in radians
    dihedrals = np.asarray(dihedrals)
    if sidechain_dihedrals is not None:
        sidechain_dihedrals = np.asarray(sidechain_dihedrals)
    if angle_type == "degree":
        dihedrals = np.deg2rad(dihedrals)
        if sidechain_dihedrals is not None:
            sidechain_dihedrals = np.deg2rad(sidechain_dihedrals)

    # make sure the input has the same length along the "frame" axis that will be created
    if sidechain_dihedrals is not None:
        assert len(dihedrals) == len(sidechain_dihedrals), (
            "`dihedrals` and `sidechain_dihedrals` must contain the same number of frames."
        )
    if len(dihedrals) == 0:
        raise ValueError("`dihedrals` must contain at least one frame.")

    # decide on what to do with trajs. If it was supplied it can
    # either be TrajEnsemble or SingleTraj
    dih_indices = None
    if trajs is not None:
        if hasattr(trajs, "n_trajs"):
            # TrajEnsemble: the topology in `top` should also be present in trajs
            if top not in trajs._top_files:
                raise Exception(
                    f"Provided topology is not part of TrajEnsemble object. Possible tops are {trajs._top_files}"
                )
            # NOTE(review): the position in `_top_files` is used to index `trajs`,
            # which assumes one trajectory per topology file in that order.
            ind = [top == i for i in trajs._top_files].index(True)
            traj = trajs[ind]
            print(
                f"Using trajectory {traj} as parent for backmapping. Because its "
                f"topoloy file ({traj.top_file}) matches the file provided as "
                f"argument `top` ({top})."
            )
        else:
            # SingleTraj
            traj = trajs
        # load the CVs to use the indices
        traj.load_CV("all")
        inp_traj = md.load(traj.top_file)
    else:
        # either build info Single and load the CVs, which is currently broken for pdbs
        from ..trajinfo import info_single

        try:
            traj = info_single.SingleTraj(top)
            traj.load_CV("all")
            inp_traj = md.load(top)
        except OSError as e:
            # this is a weird PyEMMA error that keeps happening; fall back to plain MDTraj
            if "REMARK" not in str(e):
                raise
            traj = md.load(top)
            inp_traj = traj
            dih_indices = _backbone_dihedral_indices_mdtraj(traj)

    # get indices of atoms for rotations
    g = inp_traj.top.to_bondgraph()
    if not nx.is_connected(g):
        raise BadError(
            f"MDTraj parsed the topology at {top} and found it disconnected. Changing dihedrals in multiple "
            f"chains is currently not possible. If you are sure your protein is just one chain you can try "
            f"the MDAnalysis backmapping backend or provide a topology of the file with manually fixed bonds."
        )

    # dih indices are four atoms, bond indices are the two atoms in the middle
    # giving the axis of rotation
    if hasattr(traj, "_CVs") and "central_dihedrals" in traj._CVs.attrs:
        dih_indices = np.asarray(traj._CVs.attrs["central_dihedrals"])
    if dih_indices is None:
        raise BadError(
            f"Could not determine the central dihedral indices for the topology {top}."
        )
    dih_bond_indices = dih_indices[:, 1:3]

    # filter out the proline N-CA angles (ring bond, cannot be rotated)
    keep = _keep_non_proline_ring_bonds(traj.top, dih_bond_indices, [("N", "CA")])
    dih_indices = dih_indices[keep]
    dih_bond_indices = dih_bond_indices[keep]
    dihedrals = dihedrals[:, keep]
    dih_near_sides, dih_far_sides = _get_far_and_near_networkx(
        g, dih_bond_indices, inp_traj.top
    )

    if sidechain_dihedrals is not None:
        if not hasattr(traj, "_CVs"):
            raise NotImplementedError(
                f"Backmapping sidechain dihedrals needs the topology {top} to be "
                f"loadable as an encodermap SingleTraj with CVs."
            )
        if "side_dihedrals" not in traj._CVs.attrs:
            try:
                traj.load_CV("all")
            except OSError as e:
                raise NotImplementedError(
                    f"This traj produces some error with PyEMMA: {e} "
                ) from e

        side_indices = np.asarray(traj._CVs.attrs["side_dihedrals"])
        side_bond_indices = side_indices[:, 1:3]
        # filter out the proline CA-CB and CB-CG angles (ring bonds)
        keep = _keep_non_proline_ring_bonds(
            traj.top, side_bond_indices, [("CA", "CB"), ("CB", "CG")]
        )
        side_indices = side_indices[keep]
        side_bond_indices = side_bond_indices[keep]
        sidechain_dihedrals = sidechain_dihedrals[:, keep]
        _, side_far_sides = _get_far_and_near_networkx(
            g, side_bond_indices, inp_traj.top
        )

    # extend the traj: one copy of the input frame per row of `dihedrals`
    out_traj = copy.deepcopy(inp_traj)
    for _ in range(len(dihedrals) - 1):
        out_traj = out_traj.join(inp_traj)

    if fake_dihedrals:
        print("Faking dihedrals for testing purposes.")
        dihedrals = _linear_dihedral_sweep(*dihedrals.shape)
        if sidechain_dihedrals is not None:
            sidechain_dihedrals = _linear_dihedral_sweep(*sidechain_dihedrals.shape)

    # adjust the torsions
    tolerance = 1e-3
    new_xyz = copy.deepcopy(out_traj.xyz)
    for i in range(dihedrals.shape[0]):
        frame_xyz = new_xyz[i]  # view: in-place edits end up in new_xyz
        for j in range(dihedrals.shape[1]):
            # central_dihedrals
            near_side = dih_near_sides[j]
            far_side = dih_far_sides[j]
            dihedral = dih_indices[j]
            bond = dih_bond_indices[j]
            target_angle = dihedrals[i, j]
            current_angle, angle = _rotate_far_side(
                frame_xyz, dihedral, bond, far_side, target_angle
            )

            if i == 0 and j == 0 and verify_every_rotation:
                s = f"Near and far side for dihedral {[str(traj.top.atom(x)) for x in dihedral]} are:"
                s += f"\nNear: {[str(traj.top.atom(x)) for x in near_side]}, {near_side}"
                s += f"\nFar: {[str(traj.top.atom(x)) for x in far_side[:12]]}..., {far_side[:12]}..."
                s += f"\nRotation around bond {[str(traj.top.atom(x)) for x in bond]}, {bond}."
                s += f"\nPositions of near side before rotation are\n{out_traj.xyz[i][near_side]}."
                s += f"\nPositions of near side after rotation aren\n{new_xyz[i][near_side]}"
                print(s)

            # verify
            if verify_every_rotation:
                measured = _dihedral(frame_xyz, dihedral)[0][0]
                if _angle_mismatch(measured, target_angle) > tolerance:
                    raise BadError(
                        f"Adjusting dihedral angle for atoms {[str(traj.top.atom(x)) for x in dihedral]} failed with a tolerance of {tolerance}."
                        f"\nTarget angle was {target_angle} rad, but rotation yielded angle with {measured} rad."
                        f"\nCurrent angle was {current_angle}. To reach target angle is a rotation of {angle} rad was carried out."
                        f"\nRotation axis was vector from {traj.top.atom(bond[0])} to {traj.top.atom(bond[1])}"
                        f"\nOnly these atoms should have been affected by rotation: {far_side}"
                        "\nBut somehow this method still crashed. Maybe these prints will help."
                    )

        if sidechain_dihedrals is not None:
            for j in range(sidechain_dihedrals.shape[1]):
                dihedral = side_indices[j]
                bond = side_bond_indices[j]
                target_angle = sidechain_dihedrals[i, j]
                _rotate_far_side(
                    frame_xyz, dihedral, bond, side_far_sides[j], target_angle
                )
                if verify_every_rotation:
                    measured = _dihedral(frame_xyz, dihedral)[0][0]
                    if _angle_mismatch(measured, target_angle) > tolerance:
                        raise BadError(
                            f"Adjusting sidechain dihedral angle for atoms {[str(traj.top.atom(x)) for x in dihedral]} "
                            f"failed with a tolerance of {tolerance}. Target angle was {target_angle} rad, "
                            f"but rotation yielded angle with {measured} rad."
                        )

    # overwrite traj and return
    out_traj.xyz = new_xyz
    return out_traj


def _backbone_dihedral_indices_mdtraj(traj):
    """Return psi/omega/phi atom quadruplets of `traj`, interleaved as psi, omega, phi, psi, ...

    Used when the topology cannot be loaded as an encodermap SingleTraj.
    """
    psi, omega, phi = (
        getattr(md, f"compute_{name}")(traj)[0].tolist()
        for name in ("psi", "omega", "phi")
    )
    results = [None] * (len(psi) + len(omega) + len(phi))
    results[::3] = psi
    results[1::3] = omega
    results[2::3] = phi
    return np.array(results)


def _keep_non_proline_ring_bonds(top, bond_indices, ring_bonds):
    """Return the positions of the bonds in `bond_indices` that are not proline ring bonds.

    A bond (a, b) is dropped when both atom labels (``str(top.atom(x))``,
    presumably like "PRO3-N") contain "PRO". In addition, for one of the
    ``(name_a, name_b)`` pairs in `ring_bonds`, the label of a must contain
    name_a and the label of b must contain name_b. All checks are substring
    checks.
    """
    keep = []
    for position, (a, b) in enumerate(bond_indices):
        label_a = str(top.atom(a))
        label_b = str(top.atom(b))
        in_ring = (
            "PRO" in label_a
            and "PRO" in label_b
            and any(name_a in label_a and name_b in label_b for name_a, name_b in ring_bonds)
        )
        if not in_ring:
            keep.append(position)
    return keep


def _linear_dihedral_sweep(n_frames, n_angles):
    """Return an (n_frames, n_angles) array sweeping every angle from -170 to 170 degrees, in radians."""
    sweep = np.deg2rad(np.linspace(-170, 170, n_frames))
    return np.tile(sweep[:, np.newaxis], (1, n_angles))


def _rotate_far_side(frame_xyz, dihedral, bond, far_side, target_angle):
    """Rotate the `far_side` atoms of one frame about `bond` so that `dihedral` becomes `target_angle`.

    `frame_xyz` (n_atoms, 3) is modified in place.

    Returns:
        tuple: ``(current_angle, rotation_angle)`` in radians.
    """
    current_angle = _dihedral(frame_xyz, dihedral)[0][0]
    angle = target_angle - current_angle
    direction = np.diff(frame_xyz[bond], axis=0).flatten()
    pivot_point = frame_xyz[bond[0]]
    rotmat = trans.rotation_matrix(angle, direction, pivot_point)
    # homogeneous coordinates: append a column of ones so the 4x4 matrix also
    # applies the translation that makes the rotation act about `pivot_point`
    padded = np.pad(
        frame_xyz[far_side],
        ((0, 0), (0, 1)),
        mode="constant",
        constant_values=1,
    )
    frame_xyz[far_side] = rotmat.dot(padded.T).T[:, :3]
    return current_angle, angle


def _angle_mismatch(measured, target):
    """Absolute difference between two angles in radians, wrapped into [0, pi]."""
    return abs((measured - target + np.pi) % (2 * np.pi) - np.pi)
def _get_far_and_near_networkx(bondgraph, edge_indices, top=None):
    """Returns near and far sides for a list of edges giving the indices of the two atoms at which the structure is broken.

    For every edge ``(a, b)``, the edge is removed from a copy of the bond
    graph. The graph must then fall apart into exactly two connected
    components. The component containing ``a`` is the near side, and the
    component containing ``b`` is the far side.

    Args:
        bondgraph (networkx.classes.graph.Graph): The bondgraph describing the protein.
            Its nodes are relabelled to consecutive integers in node order, and
            `edge_indices` refer to these integer labels.
        edge_indices (np.ndarray): The edges the graph will be broken at.
        top (Union[mdtraj.Topology, None], optional): Only used to name the atoms
            in error messages. Defaults to None.

    Returns:
        tuple: A tuple containing the following:
            near_sides (list of np.ndarray): List of sorted integer arrays giving the near sides.
                len(near_sides) == len(edge_indices).
            far_sides (list of np.ndarray): Same as near sides, but this time the far sides.

    Raises:
        Exception: If removing an edge does not split the graph into exactly two
            components (e.g. the edge is part of a ring), or if its two atoms do
            not end up on different sides.

    """
    import networkx as nx
    from networkx.algorithms.components.connected import connected_components

    # relabel once; every edge only needs its own copy to remove the edge from
    int_graph = nx.convert_node_labels_to_integers(bondgraph)

    near_sides = []
    far_sides = []
    for edge in edge_indices:
        G = int_graph.copy()
        G.remove_edge(*edge)
        components = [*connected_components(G)]
        if len(components) != 2:
            if top is None:
                raise Exception(
                    f"Splitting the topology of the trajectory at the edge "
                    f"{edge} does not work. Provide a topology to see, "
                    "which atoms are affected"
                )
            raise Exception(
                f"Splitting at edge {edge} does not work. Here are the "
                f"atoms: {top.atom(edge[0])} and {top.atom(edge[1])}."
            )

        # orient the components: near side contains edge[0], far side edge[1]
        if edge[0] not in components[0]:
            components.reverse()
        if edge[0] not in components[0] or edge[1] not in components[1]:
            raise Exception(
                f"Finding near and far sides failed for edge {edge}. "
                f"Protein might be cyclic or contain more than 1 chain."
            )
        near_sides.append(np.array(sorted(components[0])))
        far_sides.append(np.array(sorted(components[1])))
    return near_sides, far_sides
def _dihedral(xyz, indices):
    """Compute the dihedral angle spanned by four atoms of a single frame.

    Adapted from MDTraj.

    Args:
        xyz (np.ndarray): Coordinates of one frame, shape (n_atoms, 3). A leading
            frame axis is added so that `_displacement` can be used.
        indices (Union[np.ndarray, list]): The four atom indices describing the dihedral.

    Returns:
        np.ndarray: Array of shape (1, 1) holding the angle in radians, in [-pi, pi].

    """
    quad = np.expand_dims(np.asarray(indices), 0)
    frames = np.expand_dims(xyz, 0)

    # bond vectors 0->1, 1->2 and 2->3
    bond_a = _displacement(frames, quad[:, [0, 1]])
    bond_b = _displacement(frames, quad[:, [1, 2]])
    bond_c = _displacement(frames, quad[:, [2, 3]])

    # normals of the two planes of the dihedral
    normal_bc = np.cross(bond_b, bond_c)
    normal_ab = np.cross(bond_a, bond_b)

    y = (bond_a * normal_bc).sum(-1)
    y *= (bond_b * bond_b).sum(-1) ** 0.5
    x = (normal_bc * normal_ab).sum(-1)

    return np.arctan2(y, x, None)
572def _displacement(xyz, pairs):
573 "Displacement vector between pairs of points in each frame"
574 value = np.diff(xyz[:, pairs], axis=2)[:, :, 0]
575 assert value.shape == (
576 xyz.shape[0],
577 pairs.shape[0],
578 3,
579 ), "v.shape %s, xyz.shape %s, pairs.shape %s" % (
580 str(value.shape),
581 str(xyz.shape),
582 str(pairs.shape),
583 )
584 return value
def dihedral_to_cartesian_tf_one_way_layers(dihedrals, cartesian, n):
    """Rebuild one half of a chain by successively applying `n` dihedral rotations.

    In step ``i``, all atoms after the current bond are rotated about that bond
    (pivot: its second atom) by ``dihedrals[:, i]``. The first atom of the
    remaining chain is then fixed and the chain is advanced by one atom.

    Args:
        dihedrals (tf.Tensor): Dihedral angles, read as ``dihedrals[:, i]`` for i in range(n).
        cartesian (tf.Tensor): Starting coordinates, indexed ``cartesian[:, atom]``,
            presumably (batch, n_atoms, 3) - TODO confirm.
        n (int): Number of dihedrals (from the start of `dihedrals`) to apply.

    Returns:
        tf.Tensor: The coordinates concatenated along axis 1; same number of atoms as `cartesian`.

    """
    # `rotation_matrix` is applied to row vectors from the right, which rotates
    # by the negative angle; presumably the sign flip compensates for that.
    negated = -dihedrals

    remaining = cartesian[:, 1:]
    pieces = [cartesian[:, 0:1]]
    for step in range(n):
        pieces.append(remaining[:, 0:1])
        axis = remaining[:, 1] - remaining[:, 0]
        axis /= tf.norm(axis, axis=1, keepdims=True)
        pivot = remaining[:, 1:2]
        rot = rotation_matrix(axis, negated[:, step])
        remaining = pivot + tf.matmul(remaining[:, 1:] - pivot, rot)
    pieces.append(remaining)
    return tf.concat(pieces, axis=1)
def backbone_hydrogen_oxygen_crossproduct(backbone_positions):
    """Placeholder for building backbone H/O positions via cross products. Not implemented yet.

    Args:
        backbone_positions: Backbone coordinates, presumably
            (batch, n_backbone_atoms, 3) made of N, CA, C triplets. This matches
            the (None, n_residues * 3, 3) layout used elsewhere in this module
            (TODO confirm).

    Raises:
        ValueError: If the number of backbone atoms (axis 1) is not a multiple of three.
        NotImplementedError: Always, after validating the input.

    """
    # The atoms live on axis 1; axis 2 holds xyz and is always 3, so checking it
    # (as before) could never fail.
    n_atoms = backbone_positions.shape[1]
    if n_atoms is not None and n_atoms % 3 != 0:
        raise ValueError(
            f"Expected a multiple of three backbone atoms (C, CA, N) on axis 1, got {n_atoms}."
        )
    raise NotImplementedError(
        "backbone_hydrogen_oxygen_crossproduct is not implemented yet."
    )
def guess_sp2_atom(cartesians, indices, angle_to_previous, bond_length):
    """Place one additional atom at each of the given (sp2) atoms.

    For atom ``i``, the bond vector towards the previous atom (``i - 1``) is
    rotated by `angle_to_previous` about the normal of the plane spanned by the
    bonds to the previous and the next atom. The next atom is ``i + 1``, or
    ``i - 2`` for the last atom of the chain. The result is rescaled to
    `bond_length` and added to the position of atom ``i``.

    Args:
        cartesians (tf.Tensor): Chain coordinates indexed ``cartesians[:, atom]``,
            presumably (batch, n_atoms, 3) - TODO confirm.
        indices (Iterable[int]): Positions (axis 1) of the atoms that get a new neighbour.
        angle_to_previous (float): Rotation angle in radians.
        bond_length (float): Length of the new bond, in the units of `cartesians`.

    Returns:
        tf.Tensor: The new positions stacked along axis 1, one per entry of `indices`.

    """
    # Decide statically whether atom i has a successor. Catching the
    # out-of-range error only worked for eager tensors; in graph mode (and for
    # numpy inputs) a different exception type is raised.
    n_atoms = cartesians.shape[1]
    added_cartesians = []
    for i in indices:
        prev_vec = cartesians[:, i - 1] - cartesians[:, i]
        if n_atoms is None or i + 1 < n_atoms:
            next_vec = cartesians[:, i + 1] - cartesians[:, i]
        else:
            next_vec = cartesians[:, i - 2] - cartesians[:, i]

        perpendicular_axis = tf.linalg.cross(prev_vec, next_vec)
        perpendicular_axis /= tf.norm(perpendicular_axis, axis=1, keepdims=True)
        bond_vec = tf.matmul(
            tf.expand_dims(prev_vec, 1),
            rotation_matrix(perpendicular_axis, angle_to_previous),
        )
        bond_vec = bond_vec[:, 0, :]
        bond_vec *= bond_length / tf.norm(bond_vec, axis=1, keepdims=True)
        added_cartesians.append(cartesians[:, i] + bond_vec)
    added_cartesians = tf.stack(added_cartesians, axis=1)
    return added_cartesians
def guess_amide_H(cartesians, N_indices):
    """Guess the amide hydrogen positions for every backbone nitrogen except the first.

    Args:
        cartesians (tf.Tensor): Backbone coordinates passed on to `guess_sp2_atom`.
        N_indices (Sequence[int]): Positions of the backbone nitrogens; the first entry is skipped.

    Returns:
        tf.Tensor: The hydrogen positions computed by `guess_sp2_atom`.

    """
    # 123 degrees expressed in radians; bond length in the units of `cartesians`
    angle = 123 / 180 * pi
    bond_length = 1.10
    return guess_sp2_atom(cartesians, N_indices[1:], angle, bond_length)
def guess_amide_O(cartesians, C_indices):
    """Guess the carbonyl oxygen positions for all given backbone carbons.

    Args:
        cartesians (tf.Tensor): Backbone coordinates passed on to `guess_sp2_atom`.
        C_indices (Sequence[int]): Positions of the backbone carbons.

    Returns:
        tf.Tensor: The oxygen positions computed by `guess_sp2_atom`.

    """
    # 121 degrees expressed in radians; bond length in the units of `cartesians`
    angle = 121 / 180 * pi
    bond_length = 1.24
    return guess_sp2_atom(cartesians, C_indices, angle, bond_length)
def rotation_matrix(axis_unit_vec, angle):
    """Build a batch of 3x3 rotation matrices with Rodrigues' formula.

    R = cos(angle) * I + sin(angle) * K + (1 - cos(angle)) * u u^T,
    where K is the cross-product (skew-symmetric) matrix of the unit axis u.
    Applied to column vectors (R @ v), R rotates by +angle about u. Multiplying
    row vectors from the right (v @ R) rotates by -angle.

    Args:
        axis_unit_vec (tf.Tensor): Unit rotation axes, shape (batch, 3).
        angle (tf.Tensor): Rotation angles in radians, presumably shape (batch,).

    Returns:
        tf.Tensor: Rotation matrices of shape (batch, 3, 3).

    """
    ux = axis_unit_vec[:, 0]
    uy = axis_unit_vec[:, 1]
    uz = axis_unit_vec[:, 2]
    zeros = tf.zeros(tf.shape(axis_unit_vec)[0])
    skew = tf.stack(
        [
            tf.stack([zeros, -uz, uy], axis=-1),
            tf.stack([uz, zeros, -ux], axis=-1),
            tf.stack([-uy, ux, zeros], axis=-1),
        ],
        axis=-2,
    )
    theta = tf.expand_dims(tf.expand_dims(angle, axis=-1), axis=-1)
    identity = tf.expand_dims(tf.eye(3), 0)
    outer = tf.matmul(
        tf.expand_dims(axis_unit_vec, 2), tf.expand_dims(axis_unit_vec, 1)
    )
    return (
        tf.cos(theta) * identity
        + tf.sin(theta) * skew
        + (1 - tf.cos(theta)) * outer
    )
def merge_cartesians(
    central_cartesians, N_indices, O_indices, H_cartesians, O_cartesians
):
    """Interleave guessed H and O positions into the central (backbone) cartesians.

    After every central atom whose position is in ``N_indices[1:]``, the next
    hydrogen is inserted. After every central atom whose position is in
    `O_indices`, the next oxygen is inserted.

    Args:
        central_cartesians (tf.Tensor): Central atom coordinates, indexed ``[:, atom]``.
        N_indices (Sequence[int]): Nitrogen positions; the first one receives no hydrogen.
        O_indices (Sequence[int]): Positions of the atoms that receive an oxygen.
        H_cartesians (tf.Tensor): Hydrogen coordinates, consumed in order.
        O_cartesians (tf.Tensor): Oxygen coordinates, consumed in order.

    Returns:
        tf.Tensor: All positions stacked along axis 1.

    """
    h_carriers = N_indices[1::]
    merged = [central_cartesians[:, 0]]
    next_h = 0
    next_o = 0
    for atom in range(1, central_cartesians.shape[1]):
        merged.append(central_cartesians[:, atom])
        if atom in h_carriers:
            merged.append(H_cartesians[:, next_h])
            next_h += 1
        elif atom in O_indices:
            merged.append(O_cartesians[:, next_o])
            next_o += 1
    merged = tf.stack(merged, axis=1)
    expected_atoms = (
        central_cartesians.shape[1] + H_cartesians.shape[1] + O_cartesians.shape[1]
    )
    assert merged.shape[1] == expected_atoms
    return merged
def dihedral_backmapping(
    pdb_path, dihedral_trajectory, rough_n_points=-1, sidechains=None
):
    """Takes a pdb file with a peptide and creates a trajectory based on the dihedral angles given.

    It simply rotates around the dihedral angle axes. In the result, side-chains
    might overlap, but the backbone should turn out quite well.

    Args:
        pdb_path (str): Path to the PDB file containing the peptide.
        dihedral_trajectory (np.ndarray): Backbone dihedrals in radians, array-like
            of shape (traj_length, number_of_dihedrals). The columns are assigned
            to all psi angles first, then all omega angles, then all phi angles
            (each in residue order).
        rough_n_points (int, optional): Approximate number of frames to keep.
            Every ``max(1, int(len(dihedral_trajectory) / rough_n_points))``-th
            row is used. Values <= 0 keep all rows. Defaults to -1.
        sidechains (Union[np.ndarray, None], optional): Sidechain dihedrals in
            radians, one row per row of `dihedral_trajectory`. They are assigned
            to all chi1 angles first, then to chi2..chi5 where available. If the
            array has twice as many columns as sidechain dihedrals found, only
            every second column is used. Defaults to None, which does not set
            sidechain dihedrals explicitly.

    Returns:
        MDAnalysis.Universe: The universe read from `pdb_path`. Its coordinates
        are repeated once per selected row and the dihedrals are set frame by frame.

    """
    # rough_n_points <= 0 means "use every frame" (and avoids a division by zero)
    if rough_n_points > 0:
        step_size = max(1, int(len(dihedral_trajectory) / rough_n_points))
    else:
        step_size = 1
    dihedral_trajectory = dihedral_trajectory[::step_size]
    if sidechains is not None:
        sidechain_dihedral_trajectory = sidechains[::step_size]

    uni = mda.Universe(pdb_path, format="PDB")
    protein = uni.select_atoms("protein")

    # backbone: all psi, then all omega, then all phi selections
    dihedrals = []
    for selector in ("psi_selection", "omega_selection", "phi_selection"):
        for residue in protein.residues:
            selection = getattr(residue, selector)()
            if selection:
                dihedrals.append(selection)

    sidechain_dihedrals = []
    if sidechains is not None:
        for residue in protein.residues:
            chi1 = residue.chi1_selection()
            if chi1:
                sidechain_dihedrals.append(chi1)
        for k in range(2, 6):
            for residue in protein.residues:
                if f"chi{k}" in dir(residue):
                    sidechain_dihedrals.append(getattr(residue, f"chi{k}_selection")())

        if sidechain_dihedral_trajectory.shape[1] == len(sidechain_dihedrals) * 2:
            sidechain_dihedral_trajectory = sidechain_dihedral_trajectory[:, ::2]

    _expand_universe(uni, len(dihedral_trajectory))

    # iterating uni.trajectory advances the current frame that _set_dihedral edits
    if sidechains is None:
        for dihedral_values, _ts in zip(dihedral_trajectory, uni.trajectory):
            for dihedral, value in zip(dihedrals, dihedral_values):
                _set_dihedral(dihedral, protein, value / (2 * pi) * 360)
    else:
        for dihedral_values, sidechain_dihedral_values, _ts in zip(
            dihedral_trajectory, sidechain_dihedral_trajectory, uni.trajectory
        ):
            for dihedral, value in zip(dihedrals, dihedral_values):
                _set_dihedral(dihedral, protein, value / (2 * pi) * 360)
            for dihedral, value in zip(sidechain_dihedrals, sidechain_dihedral_values):
                _set_dihedral(dihedral, protein, value / (2 * pi) * 360)
    return uni
def _set_dihedral(dihedral, atoms, angle):
    """Rotate part of `atoms` so that `dihedral` takes the value `angle` (degrees).

    The rotation axis points from the second to the third atom of `dihedral`
    and passes through the third atom. Every atom of `atoms` with a larger atom
    index than that third atom is rotated.

    NOTE(review): the selection is purely index based. For sidechain dihedrals
    this also moves all atoms that follow in the file.

    Args:
        dihedral (MDAnalysis.core.groups.AtomGroup): The four atoms defining the dihedral.
        atoms (MDAnalysis.core.groups.AtomGroup): The atoms that may be rotated.
        angle (float): Target dihedral angle in degrees.

    """
    current_angle = dihedral.dihedral.value()
    # Select by 0-based atom index instead of slicing `atoms` with the PDB serial
    # (`.id`). The slice was only correct when serials were 1-based, contiguous
    # and `atoms` started at the first atom of the file.
    head = atoms[atoms.indices > dihedral[2].index]
    vec = dihedral[2].position - dihedral[1].position
    head.rotateby(angle - current_angle, vec, dihedral[2].position)
def _expand_universe(universe, length):
    """Replace the trajectory of `universe` with `length` copies of all of its frames, held in memory.

    Args:
        universe (MDAnalysis.Universe): The universe to modify in place.
        length (int): Number of repetitions.

    """
    collector = AnalysisFromFunction(
        lambda group: group.positions.copy(), universe.atoms
    )
    frames = collector.run().results["timeseries"]
    repeated = np.tile(frames, (length, 1, 1))
    universe.load_new(repeated, format=MemoryReader)