Coverage for encodermap/misc/backmapping.py: 35%

322 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/misc/backmapping.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Backmapping functions largely based upon encodermap_tf1's backmapping and martini-tools backwards.py 

23 

24 ToDo: 

25 * Using Quaternions in Tensorflow rotation matrices could be accelerated? 

26 * Multi Top. 

27 

28""" 

29 

30############################################################################## 

31# Imports 

32############################################################################## 

33 

34import copy 

35from math import pi 

36 

37import numpy as np 

38import tensorflow as tf 

39 

40from .._optional_imports import _optional_import 

41from ..misc import transformations as trans 

42from ..misc.errors import BadError 

43 

44############################################################################## 

45# Optional Imports 

46############################################################################## 

47 

48 

49md = _optional_import("mdtraj") 

50mda = _optional_import("MDAnalysis") 

51AnalysisFromFunction = _optional_import( 

52 "MDAnalysis", "analysis.base.AnalysisFromFunction" 

53) 

54MemoryReader = _optional_import("MDAnalysis", "coordinates.memory.MemoryReader") 

55 

56 

57############################################################################## 

58# Globals 

59############################################################################## 

60 

61 

62__all__ = ["backbone_hydrogen_oxygen_crossproduct"] 

63 

64 

65############################################################################## 

66# Public Functions 

67############################################################################## 

68 

69 

def split_and_reverse_dihedrals(x):
    """Splits dihedrals in BackMapping model into left (reversed) and right part.

    These dihedrals are then used to bring the chain_in_plane into 3D.

    The split point and the even/odd decision only depend on the static size
    of axis 1, so the same plain slicing is used for eager tensors, symbolic
    tensors and NumPy arrays. No `tf.cond` is needed, which keeps the static
    shapes of both halves intact in graph execution.

    Args:
        x (Union[tf.Tensor, np.ndarray]): The dihedrals with
            shape (None, n_residues * 3 - 3). The size of axis 1 must be
            statically known.

    Returns:
        tuple: A tuple containing the following:
            left: The first half of the dihedrals in reversed order. For an
                odd number of dihedrals the left part holds one more entry.
            right: The remaining dihedrals in original order.

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_dihedrals
        >>> import numpy as np

        >>> # create dihedrals for protein with 3 residues, i.e. 3*3 - 3 = 6 central dihedral angles
        >>> # single sample will be used -> shape = (1, 6)
        >>> np.random.seed(20)
        >>> dihedrals = np.random.random((1, 6)) * 2 * np.pi
        >>> print(dihedrals)
        [[3.69533481 5.64050171 5.60165278 5.12605805 0.22550092 4.34644107]]

        >>> dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)
        >>> print(dihedrals_left, dihedrals_right)
        [[5.60165278 5.64050171 3.69533481]] [[5.12605805 0.22550092 4.34644107]]

    """
    n_dihedrals = x.shape[1]
    middle = n_dihedrals // 2
    if n_dihedrals % 2 == 0:
        return x[:, middle - 1 :: -1], x[:, middle:]
    # odd number of dihedrals: the central one goes to the left part
    return x[:, middle::-1], x[:, middle + 1 :]

107 

108 

def split_and_reverse_cartesians(x):
    """Splits cartesians into a reversed left part and a right part.

    Both parts overlap by three atoms, because a dihedral is defined by
    four atoms. For the returned parts this holds true:
        left[0] = right[2]
        left[1] = right[1]
        left[2] = right[0]

    Args:
        x (tf.Tensor): The cartesians with shape (None, n_residues * 3, 3).

    Returns:
        tuple: A tuple containing the following:
            left: Atoms `middle + 1` down to 0 (reversed order), where
                `middle` is half the size of axis 1.
            right: Atoms `middle - 1` up to the last atom.

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_cartesians
        >>> import numpy as np

        >>> # create cartesians for protein with 3 residues, i.e. 9 atoms
        >>> # single sample will be used -> shape = (1, 9, 3)
        >>> np.random.seed(20)
        >>> cartesians = np.random.random((1, 9, 3)) * 10

        >>> cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)

        >>> print(cartesians_left.shape, cartesians_right.shape)
        (1, 6, 3) (1, 6, 3)

        >>> print(cartesians_left[:,0] == cartesians_right[:,2])
        [[ True  True  True]]

        >>> print(cartesians_left[:,1] == cartesians_right[:,1])
        [[ True  True  True]]

        >>> print(cartesians_left[:,2] == cartesians_right[:,0])
        [[ True  True  True]]

    """
    half = int(x.shape[1] / 2)
    reversed_left = x[:, half + 1 :: -1]
    right = x[:, half - 1 :]
    return reversed_left, right

149 

150 

def dihedrals_to_cartesian_tf_layers(dihedrals, cartesians):
    """Calculates dihedrals to cartesians in Graph/Layer execution.

    The chain is split in the middle. Both halves are rotated outwards
    starting from the center, so that the center of the molecule stays in
    the 2D plane and the two ends curl into the third dimension.

    Args:
        dihedrals (tf.Tensor): The dihedrals of shape (None, n_residues * 3 - 3).
        cartesians (tf.Tensor): The cartesians of shape (None, n_residues * 3, 3).
            If a rank 2 tensor is passed, it is treated as a single structure
            and repeated along a new first axis to match the number of
            dihedral samples.

    Returns:
        tf.Tensor: The rotated cartesians with the same shape as the
            (possibly repeated) input cartesians.

    """
    if len(cartesians.get_shape()) == 2:
        # if a single line of cartesians is passed it is repeated to match the number of dihedrals
        # `Lambda` was referenced without an import; use the keras layer via `tf`.
        cartesians = tf.keras.layers.Lambda(
            lambda x: tf.tile(tf.expand_dims(x[0], axis=0), [tf.shape(x[1])[0], 1, 1])
        )((cartesians, dihedrals))

    # split and reverse so that the center of the molecule stays in the 2D plane
    # and the left and right ends curl into the 3rd dimension
    cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)
    dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)

    # Each side must apply all of its own dihedrals. For an odd number of
    # dihedrals the left part holds one more than the right part, so a
    # shared `int(n / 2)` would silently skip the outermost left dihedral.
    n_dihedrals = dihedrals.shape[1]
    n_right = n_dihedrals // 2
    n_left = n_dihedrals - n_right

    new_cartesians_left = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals_left, cartesians_left, n_left
    )
    new_cartesians_right = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals_right, cartesians_right, n_right
    )

    # the first three atoms of the right part are already contained in the left part
    new_cartesians = tf.concat(
        [new_cartesians_left[:, ::-1], new_cartesians_right[:, 3:]], axis=1
    )

    return new_cartesians

183 

184 

def mdtraj_backmapping(
    top,
    dihedrals,
    sidechain_dihedrals=None,
    trajs=None,
    fake_dihedrals=False,
    verify_every_rotation=True,
    angle_type="radian",
):
    """Uses MDTraj and Christoph Gohlke's transformations.py to rotate the bonds in the provided topology.

    General procedure:
        * Decide on which topology to use (if different topologies are in the TrajEnsemble class the `dihedrals` and
            `sidechain_dihedrals` arrays need to be altered so that the correct dihedrals are used. Because EncoderMap
            is trained on a full input, `dihedrals` and `sidechain_dihedrals` contain the dihedrals for the topology
            in `TrajEnsemble` with the most of such angles. Some SingleTraj classes in TrajEnsemble might not contain
            all these angles, if for example an amino acid has been modified and the mutant contains more sidechain
            dihedrals than the wt. So the correct sidechain dihedrals for the wildtype need to be selected.
        * Get the indices of the far sides of the rotations. The bond graph is split at the rotated bond and the
            component containing the second bond atom is the far side that gets rotated.
        * Extend the trajectory. The lengths of dihedrals and sidechain_dihedrals should match. The frame given by top
            will be duplicated len(dihedrals)-times.
        * Get the current angles. We know what the final angles should be, but not how far to rotate the bonds. This
            can be done by getting the difference between current and target angle.
        * Rotate the bonds. Using Christoph Gohlke's transformations.py, the 4x4 rotation matrix is constructed and
            the coordinates are padded with ones (homogeneous coordinates) before the matrix is applied.

    Args:
        top (str): The topology file to use.
        dihedrals (np.ndarray): The dihedrals to put onto the trajectory. len(dihedrals) is number of frames of
            output trajectory. dihedrals.shape[1] needs to be the same as the number of dihedrals in the topology.
        sidechain_dihedrals (Union[np.ndarray, None], optional): The sidechain dihedrals to put onto the trajectory.
            If None is provided, the sidechains are kept like they were in the topology. Defaults to None.
        trajs (Union[em.TrajEnsemble, None], optional): Encodermap TrajEnsemble class. Can accelerate loading of
            current dihedral angles. Also checks if the provided topology is part of trajs. Defaults to None.
        fake_dihedrals (bool, optional): Whether to fake dihedrals. For debugging. The faked values sweep
            from -170 to 170 degrees over the frames. Defaults to False.
        verify_every_rotation (bool, optional): Whether to check after every backbone rotation that the target
            angle has been reached. Defaults to True.
        angle_type (str, optional): Either 'radian' or 'degree'. Input in degrees is converted to radians.
            Defaults to 'radian'.

    Raises:
        Exception: If `angle_type` is neither 'radian' nor 'degree'.
        Exception: If top is not part of the TrajEnsemble class provided in argument `trajs`.
        BadError: If the topology is not connected or a verified rotation did not reach its target angle.
        NotImplementedError: If the sidechain dihedral indices can not be loaded.

    Returns:
        mdtraj.Trajectory: An MDTraj trajectory with the correct dihedrals/side-dihedrals.

    """
    import networkx as nx

    # change the angles
    if angle_type == "radian":
        pass
    elif angle_type == "degree":
        dihedrals = np.deg2rad(dihedrals)
        # np.deg2rad(None) raises, so only convert sidechains that were provided
        if sidechain_dihedrals is not None:
            sidechain_dihedrals = np.deg2rad(sidechain_dihedrals)
    else:
        raise Exception("Argument `angle_type` must be either 'radian' or 'degree'.")

    # make sure the input has the same shape along the "frame" axis, that will be created.
    if sidechain_dihedrals is not None:
        assert len(dihedrals) == len(sidechain_dihedrals)

    # decide on what to do with trajs. If it was supplied it can
    # either be TrajEnsemble or SingleTraj
    if trajs is not None:
        # TrajEnsemble
        if hasattr(trajs, "n_trajs"):
            # The topology in `top` should also be present in traj
            if top not in trajs._top_files:
                raise Exception(
                    f"Provided topology is not part of TrajEnsemble object. Possible tops are {trajs._top_files}"
                )

            # Deciding which traj to use
            ind = [top == i for i in trajs._top_files].index(True)
            traj = trajs[ind]
            print(
                f"Using trajectory {traj} as parent for backmapping. Because its "
                f"topoloy file ({traj.top_file}) matches the file provided as "
                f"argument `top` ({top})."
            )
        # SingleTraj
        else:
            traj = trajs

        # load the CVs to use the indices
        traj.load_CV("all")
        inp_traj = md.load(traj.top_file)
    else:
        # either build info Single and load the CVs, which is currently broken for pdbs
        from ..trajinfo import info_single

        try:
            traj = info_single.SingleTraj(top)
            traj.load_CV("all")
            inp_traj = md.load(top)
        except OSError as e:
            # this is a weird PyEMMA error that keeps happening
            if "REMARK" not in str(e):
                raise
            # fall back to MDTraj and interleave the psi, omega, phi atom
            # indices in the order they appear along the backbone
            traj = md.load(top)
            inp_traj = md.load(top)
            angles = ["psi", "omega", "phi"]
            dihedrals_ = [getattr(md, f"compute_{a}")(traj)[0].tolist() for a in angles]
            results = [None] * (
                len(dihedrals_[0]) + len(dihedrals_[1]) + len(dihedrals_[2])
            )
            results[::3] = dihedrals_[0]
            results[1::3] = dihedrals_[1]
            results[2::3] = dihedrals_[2]
            dih_indices = np.array(results)

    # get indices of atoms for rotations
    g = inp_traj.top.to_bondgraph()
    if not nx.is_connected(g):
        raise BadError(
            f"MDTraj parsed the topology at {top} and found it disconnected. Changing dihedrals in multiple "
            f"chains is currently not possible. If you are sure your protein is just one chain you can try "
            f"the MDAnalysis backmapping backend or provide a topology of the file with manually fixed bonds."
        )

    # get near and far sides
    # dih indices are four atoms
    # bond indices are the two atoms in the middle giving the axis of rotation
    if hasattr(traj, "_CVs") and "central_dihedrals" in traj._CVs.attrs:
        dih_indices = np.asarray(traj._CVs.attrs["central_dihedrals"])

    # at this point dih_indices has been defined.
    # either via PyEMMA featurizer or the compute_phi/omega/psi methods of mdtraj
    dih_bond_indices = dih_indices[:, 1:3]

    # filter out the Proline angles (rotation around N-CA is part of the ring)
    indices = _non_proline_bond_positions(traj.top, dih_bond_indices, [("N", "CA")])
    dih_indices = dih_indices[indices]
    dih_bond_indices = dih_bond_indices[indices]
    dihedrals = dihedrals[:, indices]
    dih_near_sides, dih_far_sides = _get_far_and_near_networkx(
        g, dih_bond_indices, inp_traj.top
    )

    if sidechain_dihedrals is not None:
        if "side_dihedrals" not in traj._CVs.attrs:
            try:
                traj.load_CV("all")
            except OSError as e:
                raise NotImplementedError(
                    f"This traj produces some error with PyEMMA: {e} "
                ) from e

        side_indices = np.asarray(traj._CVs.attrs["side_dihedrals"])
        side_bond_indices = side_indices[:, 1:3]
        # filter out the Proline angles (CA-CB and CB-CG are part of the ring)
        indices = _non_proline_bond_positions(
            traj.top, side_bond_indices, [("CA", "CB"), ("CB", "CG")]
        )
        side_indices = side_indices[indices]
        side_bond_indices = side_bond_indices[indices]
        sidechain_dihedrals = sidechain_dihedrals[:, indices]

        side_near_sides, side_far_sides = _get_far_and_near_networkx(
            g, side_bond_indices, inp_traj.top
        )

    # extend the traj
    for i in range(len(dihedrals)):
        if i == 0:
            out_traj = copy.deepcopy(inp_traj)
        else:
            out_traj = out_traj.join(inp_traj)

    if fake_dihedrals:
        print("Faking dihedrals for testing purposes.")
        # All rotations below work in radians, so the -170..170 degree sweep
        # is converted. np.tile keeps the (n_frames, n_angles) shape even for
        # a single frame or a single angle.
        ramp = np.deg2rad(np.linspace(-170, 170, len(dihedrals)))
        dihedrals = np.tile(ramp[:, np.newaxis], (1, dihedrals.shape[1]))
        if sidechain_dihedrals is not None:
            ramp = np.deg2rad(np.linspace(-170, 170, len(sidechain_dihedrals)))
            sidechain_dihedrals = np.tile(
                ramp[:, np.newaxis], (1, sidechain_dihedrals.shape[1])
            )

    # adjust the torsions
    new_xyz = copy.deepcopy(out_traj.xyz)
    for i in range(dihedrals.shape[0]):
        for j in range(dihedrals.shape[1]):
            # central_dihedrals
            near_side = dih_near_sides[j]
            far_side = dih_far_sides[j]
            dihedral = dih_indices[j]
            bond = dih_bond_indices[j]

            # define inputs
            target_angle = dihedrals[i, j]
            current_angle = _dihedral(new_xyz[i], dihedral)[0][0]
            angle = target_angle - current_angle

            # perform rotation
            _rotate_far_side(new_xyz[i], bond, far_side, angle)

            if i == 0 and j == 0 and verify_every_rotation:
                # `dihedral` is the proline-filtered entry that belongs to the
                # near/far sides printed here
                s = f"Near and far side for dihedral {[str(traj.top.atom(x)) for x in dihedral]} are:"
                s += (
                    f"\nNear: {[str(traj.top.atom(x)) for x in near_side]}, {near_side}"
                )
                s += f"\nFar: {[str(traj.top.atom(x)) for x in dih_far_sides[j][:12]]}..., {dih_far_sides[j][:12]}..."
                s += f"\nRotation around bond {[str(traj.top.atom(x)) for x in bond]}, {bond}."
                s += f"\nPositions of near side before rotation are\n{out_traj.xyz[i][near_side]}."
                s += f"\nPositions of near side after rotation aren\n{new_xyz[i][near_side]}"
                print(s)

            # verify
            if verify_every_rotation:
                _ = _dihedral(new_xyz[i], dihedral)[0][0]
                # compare on the circle, so that e.g. -pi and pi or targets
                # outside of [-pi, pi] are recognized as the same angle
                deviation = (_ - target_angle + np.pi) % (2 * np.pi) - np.pi
                if not np.isclose(deviation, 0.0, atol=1e-3):
                    s = (
                        f"Adjusting dihedral angle for atoms {[str(traj.top.atom(x)) for x in dihedral]} failed with a tolerance of 1e-3."
                        f"\nTarget angle was {target_angle} {angle_type}, but rotation yieled angle with {_} {angle_type}."
                        f"\nCurrent angle was {current_angle}. To reach target angle is a rotation of {angle} {angle_type} was carried out."
                        f"\nRotation axis was vector from {traj.top.atom(bond[0])} to {traj.top.atom(bond[1])}"
                        f"\nOnly these atoms should have been affected by rotation: {far_side}"
                        "\nBut somehow this method still crashed. Maybe these prints will help."
                    )
                    raise BadError(s)

        if sidechain_dihedrals is not None:
            for j in range(sidechain_dihedrals.shape[1]):
                # side_dihedrals: the four atoms and the central bond of the
                # sidechain dihedral (not of the backbone dihedral `j`)
                far_side = side_far_sides[j]
                dihedral = side_indices[j]
                bond = side_bond_indices[j]

                # define inputs (both angles in radians)
                target_angle = sidechain_dihedrals[i, j]
                current_angle = _dihedral(new_xyz[i], dihedral)[0][0]
                angle = target_angle - current_angle

                # perform rotation
                _rotate_far_side(new_xyz[i], bond, far_side, angle)

    # overwrite traj and return
    out_traj.xyz = new_xyz
    return out_traj


def _non_proline_bond_positions(top, bond_indices, atom_name_pairs):
    """Returns the positions of all bonds that are not one of the given bonds inside a proline.

    Args:
        top (mdtraj.Topology): Topology used to translate atom indices into strings.
        bond_indices (np.ndarray): Array of shape (n_bonds, 2) with atom indices.
        atom_name_pairs (list of tuple): Pairs of atom names. A bond is dropped, if both
            atom strings contain "PRO" and the first/second atom string contains the
            first/second name of any pair.

    Returns:
        list: Positions along axis 0 of `bond_indices` that are kept.

    """
    keep = []
    for position, (first, second) in enumerate(bond_indices):
        first, second = str(top.atom(first)), str(top.atom(second))
        in_proline = "PRO" in first and "PRO" in second
        if in_proline and any(a in first and b in second for a, b in atom_name_pairs):
            continue
        keep.append(position)
    return keep


def _rotate_far_side(frame_xyz, bond, far_side, angle):
    """Rotates the atoms `far_side` of a single frame in place around the axis defined by `bond`.

    Args:
        frame_xyz (np.ndarray): Coordinates of one frame with shape (n_atoms, 3). Modified in place.
        bond (np.ndarray): The two atom indices defining the rotation axis. The first atom is the pivot.
        far_side (np.ndarray): Indices of the atoms that are rotated.
        angle (float): The rotation angle in radians.

    """
    direction = np.diff(frame_xyz[bond], axis=0).flatten()
    pivot_point = frame_xyz[bond[0]]
    rotmat = trans.rotation_matrix(angle, direction, pivot_point)
    # homogeneous coordinates, so that the 4x4 matrix can also translate
    padded = np.pad(
        frame_xyz[far_side],
        ((0, 0), (0, 1)),
        mode="constant",
        constant_values=1,
    )
    frame_xyz[far_side] = rotmat.dot(padded.T).T[:, :3]

489 

490 

def _get_far_and_near_networkx(bondgraph, edge_indices, top=None):
    """Returns near and far sides for a list of edges giving the indices of the two atoms at which the structure is broken.

    For every edge the bond is removed from the graph. The connected
    component containing the first atom of the edge is the near side, the
    component containing the second atom is the far side.

    Args:
        bondgraph (networkx.classes.graph.Graph): The bondgraph describing the protein.
        edge_indices (np.ndarray): The edges the graph will be broken at.
        top (Union[mdtraj.Topology, None], optional): If provided, it is used to name the
            atoms of an edge that can not be used for splitting. Defaults to None.

    Raises:
        Exception: If removing an edge does not split the graph into exactly two components.
        AssertionError: If the two atoms of an edge do not end up in different components.

    Returns:
        tuple: A tuple containing the following:
            near_sides (list of np.ndarray): List of integer arrays giving the near sides. len(near_sides) == len(edge_indices).
            far_sides (list of np.ndarray): Same as near sides, but this time the far sides.

    """
    import networkx as nx
    from networkx.algorithms.components.connected import connected_components

    # relabelling the nodes does not depend on the edge, so it is done only once
    integer_graph = nx.convert_node_labels_to_integers(bondgraph)

    near_sides = []
    far_sides = []
    for edge in edge_indices:
        G = integer_graph.copy()
        G.remove_edge(*edge)
        components = [*connected_components(G)]
        if len(components) != 2:
            if top is None:
                raise Exception(
                    f"Splitting the topology of the trajectory at the edge "
                    f"{edge} does not work. Provide a topology to see, "
                    "which atoms are affected"
                )
            else:
                raise Exception(
                    f"Splitting at edge {edge} does not work. Here are the "
                    f"atoms: {top.atom(edge[0])} and {top.atom(edge[1])}."
                )

        # order the components as (near, far)
        if edge[1] in components[0] and edge[0] in components[1]:
            components = components[::-1]
        if not (edge[0] in components[0] and edge[1] in components[1]):
            # raised explicitly, so that the check survives `python -O` and
            # carries a real message instead of the return value of print()
            raise AssertionError("Finding near and far sides failed.")
        subgraph = G.subgraph(components[-1]).copy()
        far_sides.append(np.asarray(subgraph.nodes))
        subgraph = G.subgraph(components[0]).copy()
        near_sides.append(np.asarray(subgraph.nodes))
    return near_sides, far_sides

539 

540 

def _dihedral(xyz, indices):
    """Returns current dihedral angle between positions.

    Adapted from MDTraj.

    Args:
        xyz (np.ndarray): The xyz array of a single frame. A leading frame
            axis is added via np.expand_dims(), so the computation follows
            the multi-frame layout used by mdtraj.
        indices (Union[np.ndarray, list]): List of 4 ints describing the dihedral.

    Returns:
        np.ndarray: Array of shape (1, 1) holding the dihedral angle in radians.

    """
    quadruplet = np.expand_dims(np.asarray(indices), 0)
    frames = np.expand_dims(xyz, 0)

    # bond vectors between consecutive atoms of the dihedral
    b1, b2, b3 = (
        frames[:, quadruplet[:, k + 1]] - frames[:, quadruplet[:, k]]
        for k in range(3)
    )

    # normals of the two planes spanned by the bond vectors
    normal_23 = np.cross(b2, b3)
    normal_12 = np.cross(b1, b2)

    sin_term = (b1 * normal_23).sum(-1)
    sin_term *= (b2 * b2).sum(-1) ** 0.5
    cos_term = (normal_23 * normal_12).sum(-1)

    return np.arctan2(sin_term, cos_term)

570 

571 

def _displacement(xyz, pairs):
    """Displacement vector between pairs of points in each frame.

    Args:
        xyz (np.ndarray): Coordinates with shape (n_frames, n_atoms, 3).
        pairs (np.ndarray): Atom index pairs with shape (n_pairs, 2).

    Returns:
        np.ndarray: Array of shape (n_frames, n_pairs, 3) holding the vector
            from the first to the second atom of every pair.

    """
    gathered = xyz[:, pairs]
    value = np.diff(gathered, axis=2)[:, :, 0]
    expected_shape = (xyz.shape[0], pairs.shape[0], 3)
    assert value.shape == expected_shape, "v.shape %s, xyz.shape %s, pairs.shape %s" % (
        str(value.shape),
        str(xyz.shape),
        str(pairs.shape),
    )
    return value

585 

586 

def dihedral_to_cartesian_tf_one_way_layers(dihedrals, cartesian, n):
    """Applies `n` consecutive dihedral rotations along a chain of atoms.

    In every step the first remaining atom is finalized and all atoms after
    the second remaining atom are rotated around the bond between the first
    and the second remaining atom.

    Args:
        dihedrals (tf.Tensor): The dihedral angles. Column `i` is used in
            step `i`. The angles are negated before they are applied.
        cartesian (tf.Tensor): The cartesians of the chain, indexed as
            (sample, atom, coordinate).
        n (int): The number of rotation steps.

    Returns:
        tf.Tensor: The rotated cartesians, concatenated along axis 1.

    """
    negated = -dihedrals

    pieces = [cartesian[:, 0:1]]
    remaining = cartesian[:, 1:]
    for step in range(n):
        # the first remaining atom is not moved by any later rotation
        pieces.append(remaining[:, 0:1])
        pivot = remaining[:, 1:2]
        bond_axis = remaining[:, 1] - remaining[:, 0]
        bond_axis = bond_axis / tf.norm(bond_axis, axis=1, keepdims=True)
        rotation = rotation_matrix(bond_axis, negated[:, step])
        remaining = pivot + tf.matmul(remaining[:, 1:] - pivot, rotation)
    pieces.append(remaining)
    return tf.concat(pieces, axis=1)

603 

604 

def backbone_hydrogen_oxygen_crossproduct(backbone_positions):
    """Placeholder that only validates the shape of the backbone positions.

    Args:
        backbone_positions: Array-like with a `shape` attribute. The size of
            axis 2 needs to be a multiple of three.

    Returns:
        None: No computation is implemented yet.

    """
    # C, CA, N atoms, multiple of three
    size = backbone_positions.shape[2]
    assert size % 3 == 0
    return None

608 

609 

def guess_sp2_atom(cartesians, indices, angle_to_previous, bond_length):
    """Guesses the position of an atom bound to each of the given chain atoms.

    The bond to the previous chain atom is rotated by `angle_to_previous`
    around the normal of the plane spanned by the bonds to the previous and
    the next chain atom. The result is scaled to `bond_length` and added to
    the position of the chain atom.

    Args:
        cartesians (tf.Tensor): The cartesians, indexed as (sample, atom, coordinate).
        indices (Iterable[int]): Indices of the chain atoms that get a new atom.
        angle_to_previous (float): Rotation angle passed to `rotation_matrix`.
        bond_length (float): Length of the bond to the new atom.

    Returns:
        tf.Tensor: The guessed positions, stacked along axis 1.

    """

    def _place(i):
        center = cartesians[:, i]
        to_previous = cartesians[:, i - 1] - center
        try:
            to_next = cartesians[:, i + 1] - center
        except tf.errors.InvalidArgumentError:
            # no following atom available, use the atom before the previous one
            to_next = cartesians[:, i - 2] - center

        normal = tf.linalg.cross(to_previous, to_next)
        normal = normal / tf.norm(normal, axis=1, keepdims=True)
        rotated = tf.matmul(
            tf.expand_dims(to_previous, 1),
            rotation_matrix(normal, angle_to_previous),
        )[:, 0, :]
        rotated = rotated * (bond_length / tf.norm(rotated, axis=1, keepdims=True))
        return center + rotated

    return tf.stack([_place(i) for i in indices], axis=1)

630 

631 

def guess_amide_H(cartesians, N_indices):
    """Guesses the amide hydrogen positions for all but the first nitrogen.

    Args:
        cartesians (tf.Tensor): The cartesians passed on to `guess_sp2_atom`.
        N_indices (Sequence[int]): Indices of the nitrogen atoms. The first
            entry is skipped.

    Returns:
        tf.Tensor: The guessed hydrogen positions.

    """
    angle = 123 / 180 * pi
    bond_length = 1.10
    return guess_sp2_atom(cartesians, N_indices[1::], angle, bond_length)

634 

635 

def guess_amide_O(cartesians, C_indices):
    """Guesses the carbonyl oxygen positions for the given carbon atoms.

    Args:
        cartesians (tf.Tensor): The cartesians passed on to `guess_sp2_atom`.
        C_indices (Sequence[int]): Indices of the carbon atoms.

    Returns:
        tf.Tensor: The guessed oxygen positions.

    """
    angle = 121 / 180 * pi
    bond_length = 1.24
    return guess_sp2_atom(cartesians, C_indices, angle, bond_length)

638 

639 

def rotation_matrix(axis_unit_vec, angle):
    """Builds rotation matrices from unit axes and angles (Rodrigues' formula).

    Args:
        axis_unit_vec (tf.Tensor): The rotation axes, indexed as (sample, coordinate).
            The name suggests unit length; no normalization is done here.
        angle (Union[tf.Tensor, float]): The rotation angle(s). Two trailing
            axes are appended for broadcasting against the matrices.

    Returns:
        tf.Tensor: The rotation matrices, indexed as (sample, row, column).

    """
    angle = tf.expand_dims(tf.expand_dims(angle, axis=-1), axis=-1)
    cos_a = tf.cos(angle)
    sin_a = tf.sin(angle)

    identity = tf.expand_dims(tf.eye(3), 0)

    ax = axis_unit_vec[:, 0]
    ay = axis_unit_vec[:, 1]
    az = axis_unit_vec[:, 2]
    zeros = tf.zeros(tf.shape(axis_unit_vec)[0])
    # cross product matrix of the axis, one matrix per sample
    cross_prod_matrix = tf.stack(
        [
            tf.stack([zeros, -az, ay], axis=-1),
            tf.stack([az, zeros, -ax], axis=-1),
            tf.stack([-ay, ax, zeros], axis=-1),
        ],
        axis=1,
    )

    column = tf.expand_dims(axis_unit_vec, 2)
    outer = tf.matmul(column, tf.transpose(column, [0, 2, 1]))

    return cos_a * identity + sin_a * cross_prod_matrix + (1 - cos_a) * outer

659 

660 

def merge_cartesians(
    central_cartesians, N_indices, O_indices, H_cartesians, O_cartesians
):
    """Interleaves the guessed H and O atoms with the central backbone atoms.

    After every central atom (except the first) whose index is in
    `N_indices[1:]` the next hydrogen is inserted. Otherwise, if its index is
    in `O_indices`, the next oxygen is inserted.

    Args:
        central_cartesians (tf.Tensor): Central atoms, indexed as (sample, atom, coordinate).
        N_indices (Sequence[int]): Indices of the nitrogen atoms. The first entry is skipped.
        O_indices (Sequence[int]): Indices of the central atoms followed by an oxygen.
        H_cartesians (tf.Tensor): The hydrogen positions, indexed like `central_cartesians`.
        O_cartesians (tf.Tensor): The oxygen positions, indexed like `central_cartesians`.

    Returns:
        tf.Tensor: The merged cartesians, stacked along axis 1.

    """
    amide_nitrogens = N_indices[1::]
    merged = [central_cartesians[:, 0]]
    next_h = 0
    next_o = 0
    for atom in range(1, central_cartesians.shape[1]):
        merged.append(central_cartesians[:, atom])
        if atom in amide_nitrogens:
            merged.append(H_cartesians[:, next_h])
            next_h += 1
        elif atom in O_indices:
            merged.append(O_cartesians[:, next_o])
            next_o += 1
    merged = tf.stack(merged, axis=1)
    expected_atoms = (
        central_cartesians.shape[1] + H_cartesians.shape[1] + O_cartesians.shape[1]
    )
    assert merged.shape[1] == expected_atoms
    return merged

681 

682 

def dihedral_backmapping(
    pdb_path, dihedral_trajectory, rough_n_points=-1, sidechains=None
):
    """
    Takes a pdb file with a peptide and creates a trajectory based on the dihedral angles given.
    It simply rotates around the dihedral angle axis. In the result side-chains might overlap but the backbone should
    turn out quite well.

    :param pdb_path: (str)
    :param dihedral_trajectory:
        array-like of shape (traj_length, number_of_dihedrals). The values are converted from
        radians to degrees before they are applied. The columns are consumed in the order
        all psi, all omega, all phi.
    :param rough_n_points: (int) a step_size to select a subset of values from dihedral_trajectory is calculated by
        max(1, int(len(dihedral_trajectory) / rough_n_points)) with rough_n_points = -1 all values are used.
    :param sidechains: (optional) array of sidechain dihedrals, subsampled with the same step_size. If the array
        has exactly twice as many columns as sidechain dihedrals were found, only every second column is used.
    :return: (MDAnalysis.Universe)
    """
    step_size = max(1, int(len(dihedral_trajectory) / rough_n_points))
    dihedral_trajectory = dihedral_trajectory[::step_size]
    with_sidechains = sidechains is not None
    if with_sidechains:
        sidechain_dihedral_trajectory = sidechains[::step_size]

    uni = mda.Universe(pdb_path, format="PDB")
    protein = uni.select_atoms("protein")

    # backbone selections: all psi, then all omega, then all phi
    dihedrals = []
    for angle_name in ("psi", "omega", "phi"):
        for residue in protein.residues:
            selection = getattr(residue, f"{angle_name}_selection")()
            if selection:
                dihedrals.append(selection)

    sidechain_dihedrals = []
    if with_sidechains:
        for residue in protein.residues:
            chi1 = residue.chi1_selection()
            if chi1:
                sidechain_dihedrals.append(chi1)

        # higher chi angles are only collected for residues listing them
        for chi_name in ("chi2", "chi3", "chi4", "chi5"):
            for residue in protein.residues:
                if chi_name in residue.__dir__():
                    sidechain_dihedrals.append(
                        getattr(residue, f"{chi_name}_selection")()
                    )

        if sidechain_dihedral_trajectory.shape[1] == len(sidechain_dihedrals) * 2:
            sidechain_dihedral_trajectory = sidechain_dihedral_trajectory[:, ::2]

    _expand_universe(uni, len(dihedral_trajectory))

    def _apply(selections, values):
        # input angles are in radians, `_set_dihedral` is called with degrees
        for selection, value in zip(selections, values):
            _set_dihedral(selection, protein, value / (2 * pi) * 360)

    # iterating over uni.trajectory selects the frame that is modified
    if sidechains is None:
        for dihedral_values, _step in zip(dihedral_trajectory, uni.trajectory):
            _apply(dihedrals, dihedral_values)
    else:
        for dihedral_values, sidechain_dihedral_values, _step in zip(
            dihedral_trajectory, sidechain_dihedral_trajectory, uni.trajectory
        ):
            _apply(dihedrals, dihedral_values)
            _apply(sidechain_dihedrals, sidechain_dihedral_values)
    return uni

765 

766 

def _set_dihedral(dihedral, atoms, angle):
    """Rotates the tail of `atoms` so that `dihedral` takes the value `angle`.

    Args:
        dihedral: The four-atom selection of the dihedral. It needs to
            provide `.dihedral.value()` and atoms with `.id` and `.position`.
        atoms: The atoms of the molecule. Everything from position
            `dihedral[2].id` onward is rotated.
            NOTE(review): this uses the atom id as a positional offset into
            `atoms` - presumably ids are consecutive and 1-based; confirm.
        angle (float): The target angle, in the unit returned by
            `dihedral.dihedral.value()`.

    """
    delta = angle - dihedral.dihedral.value()
    third_atom = dihedral[2]
    moving = atoms[third_atom.id :]
    rotation_axis = dihedral[2].position - dihedral[1].position
    moving.rotateby(delta, rotation_axis, dihedral[2].position)

772 

773 

def _expand_universe(universe, length):
    """Replaces the trajectory of `universe` by `length` repetitions of its current coordinates.

    Args:
        universe (MDAnalysis.Universe): The universe to expand. It is modified in place.
        length (int): How often the collected coordinates are repeated along the frame axis.

    """
    collector = AnalysisFromFunction(lambda ag: ag.positions.copy(), universe.atoms)
    collected = collector.run().results["timeseries"]
    repeated = np.tile(collected, (length, 1, 1))
    # load the repeated coordinates as an in-memory trajectory
    universe.load_new(repeated, format=MemoryReader)