Coverage for encodermap/misc/rotate.py: 9%

189 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/rotate.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Helpers to apply rotations to molecular coordinates. 

23 

24""" 

25 

26################################################################################ 

27# Imports 

28################################################################################ 

29 

30 

31# Future Imports at the top 

32from __future__ import annotations 

33 

34# Standard Library Imports 

35import copy 

36import warnings 

37from copy import deepcopy 

38from typing import TYPE_CHECKING, Optional, Union, overload 

39 

40# Third Party Imports 

41import numpy as np 

42import transformations as trans 

43from optional_imports import _optional_import 

44 

45# Encodermap imports 

46from encodermap.trajinfo.trajinfo_utils import _delete_bond 

47 

48 

49################################################################################ 

50# Optional Imports 

51################################################################################ 

52 

53 

54md = _optional_import("mdtraj") 

55 

56 

57################################################################################ 

58# Typing 

59################################################################################ 

60 

61 

62if TYPE_CHECKING: 

63 # Third Party Imports 

64 import networkx as nx 

65 from mdtraj.core.topology import Atom 

66 

67 

68################################################################################ 

69# Globals 

70################################################################################ 

71 

72 

73__all__: list[str] = ["mdtraj_rotate"] 

74 

75 

76################################################################################ 

77# Functions 

78################################################################################ 

79 

80 

def arbitrary_dihedral(
    pos: np.ndarray,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Computes one dihedral angle per frame from four points per frame.

    Args:
        pos (np.ndarray): The positions between which to calculate the
            dihedrals. Axis 1 is indexed at 0, 1, 2 and 3 to obtain the four
            points, and the cross products are taken over the last axis, so
            the expected shape is (n_frames, 4, 3).
        out (np.ndarray): Optional array handed to ``np.arctan2`` as its
            output location. If None, a freshly-allocated array is returned.

    Returns:
        np.ndarray: The dihedral angles in radians, one per frame.

    """
    first, second, third, fourth = (pos[:, k] for k in range(4))

    # The first bond vector is deliberately reversed (it points from the
    # second point back to the first one).
    bond_a = -1.0 * (second - first)
    bond_b = third - second
    bond_c = fourth - third

    normal_bc = np.cross(bond_b, bond_c)
    normal_ab = np.cross(bond_a, bond_b)

    # sine-like term, scaled by the length of the central bond
    sin_term = (bond_a * normal_bc).sum(-1)
    sin_term *= (bond_b * bond_b).sum(-1) ** 0.5
    # cosine-like term
    cos_term = (normal_bc * normal_ab).sum(-1)

    return np.arctan2(sin_term, cos_term, out)

115 

116 

def mdtraj_rotate(
    traj: md.Trajectory,
    angles: np.ndarray,
    indices: np.ndarray,
    deg: bool = False,
    check_cyclic_backbone: bool = True,
    verify_every_rotation: bool = False,
    drop_proline_angles: bool = False,
    delete_sulfide_bridges: bool = True,
) -> md.Trajectory:
    """Uses MDTraj and Christoph Gohlke's transformations.py to set bond
    rotations provided traj.

    Input can be in radian (set `deg` to False) or degree (set `deg` to True).

    General procedure:
        * Carry out some checks. Shapes of input need to be correct. `traj`
            needs to have a single frame and not be of a cyclic protein nor
            contain multiple chains.
        * Get the indices of the near and far side of the rotations. Every
            dihedral angle is indexed by 4 atoms. The rotational axis is located
            between the central two atoms (dihedral[1:3]).
        * Extend the trajectory. The single frame of `traj` is duplicated
            angles.shape[0]-times.
        * Get the current angles. We know what the final angles should be, but
            not how far to rotate the bonds. This can be done by getting the
            difference between current and target angle.
        * Rotate the bonds. Using Christoph Gohlke's transformations.py,
            a 4x4 rotation matrix is constructed and the coordinates of the
            far side are padded with ones (homogeneous coordinates) before
            the matrix is applied.

    Args:
        traj (mdtraj.Trajectory): The trajectory to use. Needs to have only one frame.
        angles (list[list[float]], np.ndarray): The values the angles should
            assume after the rotations. This arg can either be a nested list
            with floats or (better) a numpy array with the shape angles.shape =
            (n_new_frames, n_indexed_dihedrals). Here, angles.shape[0] defines
            how many frames the output trajectory is going to have and angles.shape[1]
            should be similar to the number of dihedrals you want to rotate around.
            A shape of (4, 2) would indicate that two dihedrals are going to be
            used for rotation and the output trajectory is going to have 4 frames.
        indices (list[list[int]], np.ndarray): A list of ints indexing the
            dihedrals to be rotated around. Naturally indices.shape[1] needs
            to be 4. Additionally indices.shape[0] needs to be the same as
            angles.shape[1]. indices indexes the angles along axis 1 and angles
            sets the values of those angles along axis 0.
        deg (bool, optional): Whether argument `angles` is in deg.
            Defaults to False.
        check_cyclic_backbone (bool): Whether the backbone should be
            checked for being cyclic. Rotating around a backbone angle for a
            cyclic protein is not possible and thus an Exception is raised.
            However, rotation around sidechain dihedrals is still possible.
            If you are sure you want to rotate sidechain dihedrals set this
            argument to False to prevent the cyclic backbone check.
            Defaults to True.
        verify_every_rotation (bool): Whether to re-measure every dihedral
            after its rotation and raise if the target was not reached.
            Also prints debug information for the very first rotation.
        drop_proline_angles (bool): Whether to automatically drop proline
            angles and indices.
        delete_sulfide_bridges (bool): Whether to automatically remove bonds from
            between cysteine residues.

    Raises:
        Exception: If `traj` contains more than 1 frame.
        Exception: If `angles` contains no frames.
        Exception: If traj is not fully connected.
        Exception: If shapes of `angles` and `indices` mismatches.
        Exception: If shape[1] of `indices` is not 4.
        Exception: If backbone is cyclic and check_cyclic_backbone is True.
        Exception: If `verify_every_rotation` is True and a rotation misses
            its target by more than 1e-3 rad.

    Note:
        If `deg` is False and some values of `angles` exceed 2 pi, only a
        hint is printed; no exception is raised.

    Returns:
        mdtraj.Trajectory: An MDTraj trajectory with applied rotations.

    Examples:
        >>> import mdtraj as md
        >>> import numpy as np

        >>> # load an arbitrary protein from the pdb
        >>> traj = md.load_pdb('https://files.rcsb.org/view/1GHC.pdb')
        >>> print(traj.n_frames)
        14

        >>> # traj has multiple frames so we remove all but one
        >>> traj = traj[0]

        >>> # Get indices of psi_angles
        >>> psi_indices, old_values = md.compute_psi(traj)

        >>> # set every psi angle to be either 0 or 180 deg
        >>> angles = np.full((len(psi_indices), 2), [0, 180]).T

        >>> # create the new traj with the desired rotations
        >>> out_traj = mdtraj_rotate(traj, angles, psi_indices, deg=True)
        >>> print(out_traj.n_frames)
        2

        >>> # check values
        >>> _, new_values = md.compute_psi(out_traj)
        >>> print(np.abs(np.rad2deg(new_values[0, :2]).round(0))) # prevent rounding inconsistencies
        [0. 0.]
        >>> print(np.abs(np.rad2deg(new_values[1, :2]).round(0))) # prevent rounding inconsistencies
        [180. 180.]

    """
    # Third Party Imports
    import networkx as nx

    # Convert to arrays first, so that nested lists (which are documented as
    # valid input) survive the elementwise comparison below.
    indices = np.asarray(indices).astype(int)
    angles = np.asarray(angles)

    if deg:
        angles = np.deg2rad(angles)
    elif np.any(angles > 2 * np.pi):
        print(
            "Some of your input for `angles` is larger than 2pi. "
            "This suggests, that you provided `angles` in deg not in rad. "
            "Please set the argument `deg` to True to transpose the values "
            "of `angles`."
        )

    # if only one rotation/index is given expand dims
    if angles.ndim == 1:
        angles = np.expand_dims(angles, -1)
    if indices.ndim == 1:
        indices = np.expand_dims(indices, 0)

    # check whether traj has only one frame:
    if traj.n_frames > 1:
        raise Exception(
            f"The provided `traj` has {traj.n_frames} frames. "
            f"Please provide a traj with only 1 frame."
        )
    # work on a copy, so the caller's trajectory/topology stays untouched
    traj = deepcopy(traj)

    # check if shape of indices and dihedrals is consistent
    if angles.shape[1] != indices.shape[0]:
        raise Exception(
            f"Shapes of `angles` and `indices` mismatch. Shape[1] of `angles` is "
            f"{angles.shape[1]}, which indicates that you also want to rotate "
            f"around {angles.shape[1]} dihedral(s), but indices indexes "
            f"{indices.shape[0]} dihedral angle(s). The shapes of the inputs "
            f"need to match. indices.shape[0] == angles.shape[1]"
        )

    # check whether 4 atoms are indexed
    if indices.shape[1] != 4:
        raise Exception(
            f"The shape of `indices` needs to be (n_dihedrals, 4), meaning, that "
            f"a dihedral angle is defined by 4 atoms. Your `indices` argument` "
            f"has shape {indices.shape} which is not allowed."
        )

    # at least one output frame is needed
    if angles.shape[0] == 0:
        raise Exception(
            "The provided `angles` contain no frames (angles.shape[0] == 0)."
        )

    # check whether structure is whole
    g = traj.top.to_bondgraph()
    if not nx.is_connected(g):
        raise Exception(
            "Structure is disjoint and not fully connected. This can be caused by "
            "multiple problems: Topology was not parsed correctly "
            "(nonstandard residues), traj contains multiple proteins (chains), "
            "or water molecules. You can add the bonds with "
            "`mdtraj.Topology.add_bond()`, if you know what bonds exactly are missing."
        )

    # check whether is cyclic
    if check_cyclic_backbone:
        backbone = traj.atom_slice(traj.top.select("backbone"))
        try:
            nx.algorithms.cycles.find_cycle(backbone.top.to_bondgraph())
        except nx.NetworkXNoCycle:
            pass
        else:
            raise Exception(
                "The Backbone of your protein is cyclic. You can not rotate "
                "around the backbone dihedral (Ramachandran) angles. "
                "You can still rotate around sidechain angles. For that set the "
                "`check_cyclic_backbone` argument of this function to "
                "False to silence this error."
            )

    # remove proline angles from angles and indices
    if drop_proline_angles:
        offending_atoms = set(traj.top.select("resname PRO and (name CA or name N)"))
        # A dihedral is dropped if both atoms of its central bond are proline
        # N/CA atoms. The explicit int dtype keeps `np.delete` working when
        # nothing needs to be dropped (an empty default array is float).
        drop = np.array(
            [
                i
                for i, ind in enumerate(indices)
                if len(offending_atoms.intersection(ind[1:3].tolist())) == 2
            ],
            dtype=int,
        )
        angles = np.delete(angles, drop, axis=1)
        indices = np.delete(indices, drop, axis=0)

    # delete sulfide bridges from the topology
    if delete_sulfide_bridges:
        # iterate over a snapshot, because the topology is replaced in the loop
        for a, b in list(traj.top.bonds):
            if (
                a.element.symbol == "S"
                and b.element.symbol == "S"
                and a.residue.index != b.residue.index
            ):
                traj.top = _delete_bond(traj.top, (a, b))
        g = traj.top.to_bondgraph()

    # get near and far sides
    dih_near_sides, dih_far_sides = _get_near_and_far_networkx(
        g, indices[:, 1:3], top=traj.top
    )

    # extend the traj to one frame per row of `angles`
    out_traj = copy.deepcopy(traj)
    for _ in range(1, len(angles)):
        out_traj = out_traj.join(traj)

    # adjust the torsions
    new_xyz = copy.deepcopy(out_traj.xyz)
    for i in range(angles.shape[0]):
        for j in range(angles.shape[1]):
            # central_dihedrals
            near_side = dih_near_sides[j]
            far_side = dih_far_sides[j]
            dihedral = indices[j]
            bond = dihedral[1:3]

            # define inputs
            target_angle = angles[i, j]
            current_angle = _dihedral(new_xyz[i], dihedral)[0][0]
            angle = target_angle - current_angle
            direction = np.diff(new_xyz[i, bond], axis=0).flatten()
            pivot_point = new_xyz[i, bond[0]]

            # perform rotation (homogeneous coordinates: pad xyz with a 1)
            rotmat = trans.rotation_matrix(angle, direction, pivot_point)
            padded = np.pad(
                new_xyz[i][far_side],
                ((0, 0), (0, 1)),
                mode="constant",
                constant_values=1,
            )
            new_xyz[i][far_side] = rotmat.dot(padded.T).T[:, :3]

            if i == 0 and j == 0 and verify_every_rotation:
                s = f"Near and far side for dihedral {[str(traj.top.atom(x)) for x in dihedral]} are:"
                s += (
                    f"\nNear: {[str(traj.top.atom(x)) for x in near_side]}, {near_side}"
                )
                s += f"\nFar: {[str(traj.top.atom(x)) for x in far_side[:12]]}..., {far_side[:12]}..."
                s += f"\nRotation around bond {[str(traj.top.atom(x)) for x in bond]}, {bond}."
                s += f"\nPositions of near side before rotation are\n{out_traj.xyz[i][near_side]}."
                s += f"\nPositions of near side after rotation are\n{new_xyz[i][near_side]}"
                print(s)

            # verify
            if verify_every_rotation:
                reached_angle = _dihedral(new_xyz[i], dihedral)[0][0]
                # compare in radians and modulo 2 pi, so that e.g. -pi and
                # +pi are recognized as the same angle
                deviation = (reached_angle - target_angle + np.pi) % (2 * np.pi) - np.pi
                if not np.isclose(deviation, 0.0, atol=1e-3):
                    s = f"Adjusting dihedral angle for atoms {[str(traj.top.atom(x)) for x in dihedral]} failed with a tolerance of 1e-3."
                    s += f"\nTarget angle was {np.rad2deg(target_angle)} deg, but rotation yieled angle with {np.rad2deg(reached_angle)} deg."
                    s += f"\nCurrent angle was {np.rad2deg(current_angle)}. To reach target angle is a rotation of {np.rad2deg(angle)} degrees was carried out."
                    s += f"\nRotation axis was vector from {traj.top.atom(bond[0])} to {traj.top.atom(bond[1])}"
                    s += f"\nOnly these atoms should have been affected by rotation: {far_side}"
                    s += "\nBut somehow this method still crashed. Maybe these prints will help."
                    raise Exception(s)

    # overwrite traj and return
    out_traj.xyz = new_xyz
    return out_traj

389 

390 

@overload
def _get_near_and_far_networkx(
    bondgraph: nx.Graph,
    edge_indices: np.ndarray,
    top: Optional[md.Topology] = None,
    parallel: bool = True,
) -> tuple[np.ndarray, None]: ...


@overload
def _get_near_and_far_networkx(
    bondgraph: nx.Graph,
    edge_indices: np.ndarray,
    top: Optional[md.Topology] = None,
    parallel: bool = False,
) -> tuple[list[np.ndarray], list[np.ndarray]]: ...


def _get_near_and_far_networkx(
    bondgraph: nx.Graph,
    edge_indices: np.ndarray,
    top: Optional[md.Topology] = None,
    parallel: bool = False,
) -> Union[tuple[list[np.ndarray], list[np.ndarray]], tuple[np.ndarray, None]]:
    """Returns near and far sides for a list of edges giving the indices of the
    two atoms at which the structure is broken.

    For every edge, the edge is removed from a copy of the (integer-relabeled)
    bondgraph. The connected component containing ``edge[0]`` is the near side,
    the one containing ``edge[1]`` is the far side.

    Args:
        bondgraph (networkx.classes.graph.Graph): The bondgraph describing the protein.
        edge_indices (np.ndarray): The edges where the graph will be broken at.
            Needs to have shape (n_edges, 2).
        top (Optional[md.Topology]): Used for printing helpful messages in exceptions.
        parallel (bool): Whether to return a dense array and None.

    Returns:
        Union[tuple[list[np.ndarray], list[np.ndarray]], tuple[np.ndarray, None]]:
            If `parallel` is False, a tuple containing the following:
                - near_sides (list[np.ndarray]): List of integer arrays giving the near
                    sides. len(near_sides) == len(edge_indices).
                - far_sides (list[np.ndarray]): Same as near sides, but this time the far sides.
            If `parallel` is True, a tuple of a boolean array with shape
            (len(edge_indices), len(bondgraph)), which is True for the atoms
            on the near side of the respective edge, and None.

    Raises:
        AssertionError: If `edge_indices` does not have shape[1] == 2 or if
            the two atoms of an edge can't be assigned to the two components.
        Exception: If an edge is not part of the graph, or if removing an
            edge does not split the graph into exactly two components.

    """
    # Third Party Imports
    import networkx as nx
    from networkx.algorithms.components.connected import connected_components

    assert edge_indices.shape[1] == 2, (
        f"Can only take `edge_indices` as a numpy array, with shape[1] = 2, but "
        f"the provided `edge_indices` has shape {edge_indices.shape[1]=}."
    )

    if parallel:
        out = np.zeros(shape=(len(edge_indices), len(bondgraph))).astype(bool)

    # The relabeling does not depend on the edge; do it once and only copy
    # the result for every edge.
    relabeled = nx.convert_node_labels_to_integers(bondgraph)

    near_sides = []
    far_sides = []
    for i, edge in enumerate(edge_indices):
        G = relabeled.copy()
        try:
            G.remove_edge(*edge)
        except nx.NetworkXError as e:
            if top is not None:
                raise Exception(
                    f"Seems like the edge {[top.atom(a) for a in edge]=} "
                    f"{[top.atom(a).index for a in edge]=} is not "
                    f"part of the graph. This might originate from a bond, that has "
                    f"been deleted, but which atoms are still considered to be part of "
                    f"a dihedral. Maybe you supplied this topology to the "
                    f"mdtraj_backmapping method with wrong featurization."
                ) from e
            else:
                raise Exception(
                    f"Please provide arg `top` to learn more about this Exception"
                ) from e
        except TypeError as e:
            raise Exception(
                f"Could not remove the edge {edge=}, {edge_indices.shape=}."
            ) from e
        components = [*connected_components(G)]
        if len(components) != 2:
            if top is None:
                raise Exception(
                    f"Splitting the topology of the trajectory at the edge "
                    f"{edge} does not work. Provide a topology to see, "
                    "which atoms are affected"
                )
            else:
                # The path is only used for the message. If the graph was
                # disjoint to begin with, there might be no path at all and
                # that must not mask the actual exception.
                try:
                    path = nx.shortest_path(G, *edge)
                except nx.NetworkXNoPath:
                    path = []
                assert isinstance(path, list)
                path_atoms = [top.atom(i) for i in path]
                a1 = top.atom(edge[0])
                a2 = top.atom(edge[1])
                raise Exception(
                    f"Splitting at edge {edge} does not work. Here are the "
                    f"atoms: {a1} and {a2}. Removing "
                    f"this edge resulted in {len(components)} disconnected components. "
                    f"The resSeqs of the residues are {a1.residue.resSeq} and "
                    f"{a2.residue.resSeq}. The indices are {a1.residue.index} and "
                    f"{a2.residue.index}. These two atoms are still connected via "
                    f"this path: {path_atoms}."
                )

        # make sure, that components[0] is the side of edge[0]
        if edge[1] in components[0] and edge[0] in components[1]:
            components = components[::-1]
        assert (
            edge[0] in components[0] and edge[1] in components[1]
        ), "Finding near and far sides failed."
        subgraph = G.subgraph(components[-1]).copy()
        far = np.asarray(subgraph.nodes)
        subgraph = G.subgraph(components[0]).copy()
        near = np.asarray(subgraph.nodes)
        if parallel:
            out[i][near] = True
        else:
            far_sides.append(far)
            near_sides.append(near)
    if not parallel:
        return near_sides, far_sides
    else:
        return out, None

512 

513 

def _angle(
    xyz: np.ndarray,
    indices: np.ndarray,
) -> np.ndarray:
    """Returns current angle between positions.

    Adapted from MDTraj.

    Args:
        xyz (np.ndarray): The coordinates of a single frame. The array is
            expanded with np.expand_dims() to make that frame work with the
            `_displacement` function.
        indices (Union[np.ndarray, list]): List of 3 ints describing the
            angle. The angle is measured at the second (central) atom.

    Returns:
        np.ndarray: The angle in radians. As a frame axis and an angle axis
            are added to the inputs, the returned array has shape (1, 1).

    """
    indices = np.expand_dims(np.asarray(indices), 0)
    xyz = np.expand_dims(xyz, 0)
    ix01 = indices[:, [1, 0]]
    ix21 = indices[:, [1, 2]]

    # both vectors start at the central atom
    u_prime = _displacement(xyz, ix01)
    v_prime = _displacement(xyz, ix21)
    u_norm = np.sqrt((u_prime**2).sum(-1))
    v_norm = np.sqrt((v_prime**2).sum(-1))

    u = u_prime / (u_norm[..., np.newaxis])
    v = v_prime / (v_norm[..., np.newaxis])

    # Round-off can push the cosine of (anti)parallel vectors slightly
    # outside of [-1, 1], which would make arccos return NaN.
    cosine = np.clip((u * v).sum(-1), -1.0, 1.0)
    return np.arccos(cosine)

545 

546 

def _dihedral(
    xyz: np.ndarray,
    indices: np.ndarray,
) -> np.ndarray:
    """Returns current dihedral angle between positions.

    Adapted from MDTraj.

    Args:
        xyz (np.ndarray): The coordinates of a single frame. The array is
            expanded with np.expand_dims() to make that frame work with the
            `_displacement` function.
        indices (Union[np.ndarray, list]): List of 4 ints describing the dihedral.

    Returns:
        np.ndarray: The dihedral in radians. As a frame axis and a dihedral
            axis are added to the inputs, the returned array has shape (1, 1)
            and the scalar value is obtained with ``[0][0]``.

    """
    indices = np.expand_dims(np.asarray(indices), 0)
    xyz = np.expand_dims(xyz, 0)
    ix10 = indices[:, [0, 1]]
    ix21 = indices[:, [1, 2]]
    ix32 = indices[:, [2, 3]]

    # the three consecutive bond vectors
    b1 = _displacement(xyz, ix10)
    b2 = _displacement(xyz, ix21)
    b3 = _displacement(xyz, ix32)

    c1 = np.cross(b2, b3)
    c2 = np.cross(b1, b2)

    p1 = (b1 * c1).sum(-1)
    p1 *= (b2 * b2).sum(-1) ** 0.5
    p2 = (c1 * c2).sum(-1)

    return np.arctan2(p1, p2)

582 

583 

584def _displacement(xyz: np.ndarray, pairs: np.ndarray) -> np.ndarray: 

585 """Displacement vector between pairs of points in each frame 

586 

587 Args: 

588 xyz (np.ndarray): The coordinates of the atoms. 

589 pairs (np.ndarray): An array with integers and shape (n_pairs, 2), 

590 defining the atom paris between which the displacement will 

591 be calculated. 

592 

593 Returns: 

594 np.ndarray: An array with shape (n_pairs, ). 

595 

596 """ 

597 value = np.diff(xyz[:, pairs], axis=2)[:, :, 0] 

598 assert value.shape == ( 

599 xyz.shape[0], 

600 pairs.shape[0], 

601 3, 

602 ), "v.shape %s, xyz.shape %s, pairs.shape %s" % ( 

603 str(value.shape), 

604 str(xyz.shape), 

605 str(pairs.shape), 

606 ) 

607 return value