Coverage for encodermap/misc/backmapping.py: 7%

743 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/backmapping.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Backmapping functions to create new atomistic conformations from intrinsic 

23coordinates. 

24 

25""" 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32# Future Imports at the top 

33from __future__ import annotations 

34 

35# Standard Library Imports 

36import os.path 

37from contextlib import contextmanager 

38from copy import deepcopy 

39from math import pi 

40from pathlib import Path 

41 

42# Third Party Imports 

43import numpy as np 

44import tensorflow as tf 

45import transformations 

46from optional_imports import _optional_import 

47from tqdm import tqdm as normal_tqdm 

48from tqdm.notebook import tqdm as notebook_tqdm 

49from transformations import rotation_matrix as transformations_rotation_matrix 

50 

51# Encodermap imports 

52from encodermap.loading import features 

53from encodermap.misc.misc import _is_notebook 

54from encodermap.misc.rotate import _dihedral, _get_near_and_far_networkx 

55from encodermap.trajinfo.info_all import TrajEnsemble 

56from encodermap.trajinfo.info_single import SingleTraj 

57 

58 

59# numba to accelerate 

60try: 

61 # Third Party Imports 

62 from numba import jit 

63 

64 _NUMBA_AVAILABLE: bool = True 

65except ImportError: 

66 _NUMBA_AVAILABLE: bool = False 

67 

68 

69################################################################################ 

70# Optional Imports 

71################################################################################ 

72 

73 

74md = _optional_import("mdtraj") 

75xr = _optional_import("xarray") 

76mda = _optional_import("MDAnalysis") 

77AnalysisFromFunction = _optional_import( 

78 "MDAnalysis", "analysis.base.AnalysisFromFunction" 

79) 

80MemoryReader = _optional_import("MDAnalysis", "coordinates.memory.MemoryReader") 

81jit = _optional_import("numba", "jit") 

82nb = _optional_import("numba") 

83nx = _optional_import("networkx") 

84 

85 

86################################################################################ 

87# Typing 

88################################################################################ 

89 

90 

91# Standard Library Imports 

92from collections.abc import Iterator, Sequence 

93from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload 

94 

95 

96if TYPE_CHECKING: 

97 # Standard Library Imports 

98 from io import BytesIO 

99 

100 # Third Party Imports 

101 import MDAnalysis as mda 

102 import mdtraj as md 

103 import networkx as nx 

104 from MDAnalysis.analysis.base import AnalysisFromFunction 

105 from MDAnalysis.coordinates.memory import MemoryReader 

106 

107 # Local Folder Imports 

108 from ..trajinfo import SingleTraj, TrajEnsemble 

109 

110 

111################################################################################ 

112# Globals 

113################################################################################ 

114 

115 

116__all__: list[str] = ["backbone_hydrogen_oxygen_crossproduct", "mdtraj_backmapping"] 

117 

118 

119################################################################################ 

120# Helpers 

121################################################################################ 

122 

123 

@contextmanager
def temp_seed(seed):
    """Temporarily seed NumPy's global random number generator.

    The current global random state is saved on entry, the generator is
    seeded with ``seed`` for the duration of the ``with`` block, and the saved
    state is restored on exit, also when the block raises.

    Args:
        seed: The value handed to ``np.random.seed``.

    """
    saved_rng_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Always put the previous global state back, so code after the
        # context continues the random sequence it had before.
        np.random.set_state(saved_rng_state)

133 

134 

135def _raise_components_exception( 

136 components: Sequence[nx.Graph], 

137 trajs: TrajEnsemble, 

138 top: Optional[md.Topology] = None, 

139 remove_component_size: int = 0, 

140) -> None: 

141 """Raises a descriptive error message for the components upon breaking the 

142 connection graph of a protein topology.""" 

143 smallest_component = sorted(components, key=lambda x: len(x))[0] 

144 largest_component = sorted(components, key=lambda x: len(x))[1] 

145 if top is None: 

146 _str = f"from the provided {trajs.__class__.__name__}" 

147 elif isinstance(top, (int, np.int64)): 

148 _str = f"from `SingleTraj` number {top} of the provided `TrajEnsemble`." 

149 else: 

150 _str = f"from the provided {top} file" 

151 msg = ( 

152 f"The protein {_str} is disconnected. Changing dihedrals " 

153 f"in multiple disconnected chains is currently not possible. You can also " 

154 f"use `em.plot.plot_ball_and_stick(traj highlight='bonds')` to see which " 

155 f"bonds are present in the topology. If you are sure " 

156 f"your protein is just one chain you can try to load a custom topology " 

157 f"or provide a topology with manually fixed bonds. I got {len(components)} " 

158 f"disconnected components. The smallest component contains these atoms: " 

159 f"{smallest_component=} {largest_component=}. You can try and add " 

160 f"these bonds with the `custom_aas` keyword to this function call. If " 

161 f"these components contain unwanted residues like solvent or salt-ions, " 

162 f"you can set `remove_component_size` to a number representing the " 

163 f"sizes of these components (i.e. 3 for three atom water) to remove " 

164 f"these components from the trajectory." 

165 ) 

166 if remove_component_size > 0: 

167 msg += ( 

168 f" Your currently chosen `remove_component_size`={remove_component_size} " 

169 f"is not large enough to reduce the system to only one connected component." 

170 ) 

171 raise Exception(msg) 

172 

173 

174################################################################################ 

175# Public Functions 

176################################################################################ 

177 

178 

def split_and_reverse_dihedrals(x: tf.Tensor) -> tf.Tensor:
    """Cuts the dihedrals of the BackMapping model into two halves along axis 1.

    The left half is returned in reversed order, the right half in its
    original order. These dihedrals are then used to bring the
    chain_in_plane into 3D.

    Args:
        x (tf.Tensor): The dihedrals with shape (None, n_residues * 3 - 3).

    Returns:
        The pair ``(left, right)``: the reversed left part and the right part
        of ``x``. For an odd number of columns, the middle column ends up in
        the left part.

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_dihedrals
        >>> import numpy as np

        >>> # create dihedrals for protein with 3 residues, i.e. 3*3 - 3 = 6 central dihedral angles
        >>> # single sample will be used -> shape = (1, 6)
        >>> np.random.seed(20)
        >>> dihedrals = np.random.random((1, 6)) * 2 * np.pi
        >>> print(dihedrals)
        [[3.69533481 5.64050171 5.60165278 5.12605805 0.22550092 4.34644107]]

        >>> dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)
        >>> print(dihedrals_left, dihedrals_right)
        [[5.60165278 5.64050171 3.69533481]] [[5.12605805 0.22550092 4.34644107]]

    """
    half = int(int(x.shape[1]) / 2)

    def _even_split():
        # even number of columns: both parts get exactly `half` columns
        return x[:, half - 1 :: -1], x[:, half:]

    def _odd_split():
        # odd number of columns: the middle column goes to the left part
        return x[:, half::-1], x[:, half + 1 :]

    is_even = tf.math.equal(tf.math.mod(x.shape[1], 2), 0)
    return tf.cond(is_even, true_fn=_even_split, false_fn=_odd_split)

215 

216 

def split_and_reverse_cartesians(x):
    """Cuts the cartesians into an overlapping left (reversed) and right part.

    Because dihedrals are made up from 4 atoms, three atoms are
    identical in the left and right part of the list. This holds true:
    left[0] = right[2]
    left[1] = right[1]
    left[2] = right[0]

    Args:
        x (tf.Tensor): The cartesians with shape (None, n_residues * 3, 3).

    Returns:
        The pair ``(left, right)``. ``left`` runs from index ``split + 1``
        backwards to index 0, ``right`` runs from index ``split - 1`` to the
        end, where ``split`` is half of the length of axis 1 (rounded down).

    Examples:
        >>> from encodermap.misc.backmapping import split_and_reverse_cartesians
        >>> import numpy as np

        >>> # create cartesians for protein with 3 residues, i.e. 9
        >>> # single sample will be used -> shape = (1, 9, 3)
        >>> np.random.seed(20)
        >>> cartesians = np.random.random((1, 9, 3)) * 10

        >>> cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)

        >>> print(cartesians_left.shape, cartesians_right.shape)
        (1, 6, 3) (1, 6, 3)

        >>> print(cartesians_left[:,0] == cartesians_right[:,2])
        [[ True  True  True]]

        >>> print(cartesians_left[:,1] == cartesians_right[:,1])
        [[ True  True  True]]

        >>> print(cartesians_left[:,2] == cartesians_right[:,0])
        [[ True  True  True]]

    """
    center = int(int(x.shape[1]) / 2)
    # both slices include the atoms center - 1, center, center + 1
    left_reversed = x[:, center + 1 :: -1]
    right = x[:, center - 1 :]
    return left_reversed, right

257 

258 

def dihedrals_to_cartesian_tf_layers(
    dihedrals: tf.Tensor,
    cartesians: tf.Tensor,
    left_iteration_counter: int,
    right_iteration_counter: int,
) -> tf.Tensor:
    """Calculates dihedrals to cartesians in Graph/Layer execution.

    Args:
        dihedrals (tf.Tensor): The dihedrals of shape (None, n_residues * 3 - 3)
        cartesians (tf.Tensor): The cartesians of shape (None, n_residues * 3, 3).
            If a tensor of rank 2 is provided (a single chain), it is
            repeated along a new first axis to match the number of dihedrals.
        left_iteration_counter (int): The range(left_iteration_counter) for
            iteration over the left split of the cartesians chain. Needs to
            be supplied, because the shape of the SymbolicTensor is not known
            to tensorflow.
        right_iteration_counter (int): The range(right_iteration_counter) for
            iteration over the right split of the cartesians chain. Needs to
            be supplied, because the shape of the SymbolicTensor is not known
            to tensorflow.

    Returns:
        tf.Tensor: The finished 3d chain.

    """
    if len(cartesians.get_shape()) == 2:
        # if a single line of cartesians is passed, it is repeated to match
        # the number of dihedrals. The multiples are [batch_size, 1, 1],
        # where batch_size is the dynamic size of the first dihedral axis.
        cartesians = tf.tile(
            tf.expand_dims(cartesians, axis=0),
            [tf.shape(dihedrals)[0], 1, 1],
        )

    # split and reverse so that the center of the molecule stays on the 2D plane
    # and the left and right ends curl into the 3rd dimension
    cartesians_left, cartesians_right = split_and_reverse_cartesians(cartesians)
    dihedrals_left, dihedrals_right = split_and_reverse_dihedrals(dihedrals)

    new_cartesians_left = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals=dihedrals_left,
        cartesian=cartesians_left,
        n=left_iteration_counter,
    )
    new_cartesians_right = dihedral_to_cartesian_tf_one_way_layers(
        dihedrals=dihedrals_right,
        cartesian=cartesians_right,
        n=right_iteration_counter,
    )

    # The left part is flipped back into chain order. The first three atoms
    # of the right part are dropped before joining both parts.
    new_cartesians = tf.concat(
        [new_cartesians_left[:, ::-1], new_cartesians_right[:, 3:]], axis=1
    )

    return new_cartesians

310 

311 

@jit(nopython=True)
def _displacement_jit(xyz: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Faster nopython numba-jit compiled version of calculating a displacement.

    Args:
        xyz (np.ndarray): Numpy array with the positions. It is indexed along
            its first axis with the entries of ``index``.
        index (np.ndarray): Integer array, of which the entries ``index[0]``
            and ``index[1]`` select the start and the end of the displacement.

    Returns:
        np.ndarray: ``xyz[index[1]] - xyz[index[0]]``, i.e., the vector
            pointing from the first indexed position to the second.

    """
    origin = xyz[index[0]]
    target = xyz[index[1]]
    return target - origin

326 

327 

@jit(nopython=True)
def _dihedral_jit(
    xyz: np.ndarray,
    indices: np.ndarray,
) -> np.ndarray:
    """Faster nopython numba-jit compiled implementation of calculating a dihedral.

    Args:
        xyz (np.ndarray): Numpy array with the positions of the atoms, indexed
            along its first axis by the entries of ``indices``.
        indices (np.ndarray): Integer array with the four atom indices
            that define the dihedral. The three consecutive pairs
            ``indices[0:2]``, ``indices[1:3]``, and ``indices[2:4]`` define
            the three bond vectors.

    Returns:
        np.ndarray: The dihedral angle obtained from ``np.arctan2``.

    """
    # the three consecutive bond vectors along the four atoms
    bond_a = _displacement_jit(xyz, indices[0:2])
    bond_b = _displacement_jit(xyz, indices[1:3])
    bond_c = _displacement_jit(xyz, indices[2:4])

    # normals of the two planes spanned by neighboring bonds
    normal_bc = np.cross(bond_b, bond_c)
    normal_ab = np.cross(bond_a, bond_b)

    # first argument of arctan2: (bond_a . normal_bc), scaled by |bond_b|
    y_component = (bond_a * normal_bc).sum(-1)
    y_component *= (bond_b * bond_b).sum(-1) ** 0.5
    # second argument of arctan2: (normal_bc . normal_ab)
    x_component = (normal_bc * normal_ab).sum(-1)
    return np.arctan2(y_component, x_component)

353 

354 

@jit(nopython=True)
def _rotmat_jit(
    angle: np.float32,
    direction: np.ndarray,
    pivot_point: np.ndarray,
) -> np.ndarray:
    """Same as :func:`rotation_matrix`, but jit compiled.

    Args:
        angle (np.float32): The rotation angle, which is passed to
            ``np.sin`` and ``np.cos``.
        direction (np.ndarray): The rotation axis with three components.
            It does not need to be normalized; this is done here.
        pivot_point (np.ndarray): A point with three components, that is
            mapped onto itself by the returned transformation.

    Returns:
        np.ndarray: A float32 array of shape (4, 4). The upper left (3, 3)
            block is the rotation, the first three entries of the last
            column are the translation.

    """
    sin_angle = np.sin(angle)
    cos_angle = np.cos(angle)
    axis = direction / (direction**2).sum() ** 0.5

    # rotation = cos * identity + (1 - cos) * outer(axis, axis) + sin * cross-matrix
    rotation = np.identity(3, dtype="float32")
    rotation *= cos_angle
    rotation += np.outer(axis, axis) * (1.0 - cos_angle)
    # axis is a fresh array from the division above, so scaling it in place
    # does not alter the caller's `direction`
    axis *= sin_angle
    rotation += np.array(
        [
            [0.0, -axis[2], axis[1]],
            [axis[2], 0.0, -axis[0]],
            [-axis[1], axis[0], 0.0],
        ],
        dtype="float32",
    )

    # embed into a 4x4 matrix; the translation keeps pivot_point fixed
    transform = np.identity(4, dtype="float32")
    transform[:3, :3] = rotation
    transform[:3, 3] = pivot_point - np.dot(rotation, pivot_point)
    return transform

381 

382 

@jit(nopython=True, parallel=True)
def parallel_rotation_application(
    xyz: np.ndarray,
    dihedral_indices: np.ndarray,
    dihedrals: np.ndarray,
    new_and_far_sides: np.ndarray,
) -> None:
    """Rotates parts of every frame in ``xyz`` in place to set the dihedrals.

    For every dihedral ``j`` and every frame ``i``, the current dihedral angle
    is measured, and the atoms for which ``new_and_far_sides[j]`` is False are
    rotated by the difference to ``dihedrals[i, j]``. The rotation axis is the
    bond between the second and the third atom of the dihedral.

    Args:
        xyz (np.ndarray): The positions, indexed as ``xyz[frame, atom]``.
            This array is modified in place.
        dihedral_indices (np.ndarray): Integer array, for which
            ``dihedral_indices[j]`` contains the four atom indices of
            dihedral ``j``.
        dihedrals (np.ndarray): The target angles, indexed as
            ``dihedrals[frame, dihedral]``.
        new_and_far_sides (np.ndarray): Boolean array, for which
            ``new_and_far_sides[j]`` is a mask over the atoms. Atoms with
            False are rotated for dihedral ``j``; atoms with True stay put.

    """
    for j in range(
        dihedrals.shape[1]
    ):  # cannot be parallelized because the later angles depend on the previous
        # The definition of the dihedral and the set of moving atoms depend on
        # the dihedral `j`, not on the frame `i`.
        dihedral_index = dihedral_indices[j]
        movable = ~new_and_far_sides[j]
        for i in nb.prange(
            dihedrals.shape[0]
        ):  # can be parallelized because every frame can be treated separately
            target_angle = dihedrals[i, j]
            current_angle = _dihedral_jit(xyz[i], dihedral_index)
            angle = target_angle - current_angle
            direction = xyz[i, dihedral_index[2]] - xyz[i, dihedral_index[1]]
            # The pivot point needs to lie on the rotation axis. It is the
            # first atom of the bond, i.e., the second atom of the four-tuple.
            pivot_point = xyz[i, dihedral_index[1]]
            M = _rotmat_jit(angle, direction, pivot_point)
            # pad the moving positions with ones to apply the 4x4 matrix
            padded = np.ones((len(xyz[i][movable]), 4), dtype="float32")
            padded[:, :3] = xyz[i][movable]
            xyz[i][movable] = M.dot(padded.T).T[:, :3]

406 

407 

408@overload 

409def _full_backmapping_np( 409 ↛ exitline 409 didn't jump to the function exit

410 feature_description: dict[int, dict[int, Sequence[int]]], 

411 *inputs: np.ndarray, 

412 return_indices: Literal[False], 

413) -> tuple[np.ndarray, BytesIO]: ... 

414 

415 

416@overload 

417def _full_backmapping_np( 417 ↛ exitline 417 didn't jump to the function exit

418 feature_description: dict[int, dict[int, Sequence[int]]], 

419 *inputs: np.ndarray, 

420 return_indices: Literal[True], 

421) -> tuple[np.ndarray, BytesIO, dict[str, np.ndarray]]: ... 

422 

423 

424def _full_backmapping_np( 

425 feature_description: dict[int, dict[int, Sequence[int]]], 

426 *inputs: np.ndarray, 

427 return_indices: bool = False, 

428) -> Union[ 

429 tuple[np.ndarray, BytesIO], tuple[np.ndarray, BytesIO, dict[str, np.ndarray]] 

430]: 

431 """Also BackMaps sidechains. For that, we need a way to know which 

432 distances, angles, dihedrals belong to the backbone, and which belong to 

433 a sidechain. In contrast to how cartesians are normally represented in 

434 MD trajectories, the cartesians in this class are ordered by first all 

435 central cartesians, then all sidechain cartesians. Consider a theoretical 

436 protein with three residues: MET-LYS-GLY. The protein's backbone has 

437 9 cartesian coordiantes, 8 distances, 7 angles, and 6 dihedrals. 

438 Methionine has 3 sidechain angles (ch1, chi2, chi3), Lysine has 4 sidechain 

439 angles (chi1, ..., chi4), and Glycine has no sidechain angles. The 

440 feature_description dict for this protein should be: 

441 {-1: {1: 3, 2: 4, 3: 0}. 

442 

443 * Cartesians: 

444 The cartesian coordinates of the sidechain atoms is an array with shape 

445 `(sum([v + 1 for v in feature_description[-1].values() if v > 0])`. So for 

446 this theoretical protein, we have 4 sidechain coordinates for MET and 5 for LYS. 

447 The combined cartesians array first contains all backbones, and then follows 

448 up with the sidechain positions as defined in the feature_description dict. 

449 * Distances: 

450 The distances between the sidechain atoms include the CA of the backbone, 

451 thus we have `(sum([v + 1 for v in feature_description[-1].values() if v > 0])` 

452 sidechain distances. MET has 4 sidechain distances, LYS has 5. 

453 * Angles: 

454 The angles between the sidechain atoms include the CA and N of the backbone. 

455 We have `(sum([v + 1 for v in feature_description[-1].values() if v > 0])` 

456 angles. MET has 4 sidechain angles (N-CA-CB, CA-CB-CG, CB-CG-SD, CG-SD-CE) 

457 and LYS has 5 sidechain angles. 

458 * Dihedrals: 

459 The dihedrals between the sidechain atoms include the CA and N of the 

460 backbone, so that we have `(sum(list(feature_description[-1].values()))` 

461 sidechain dihedrals. 

462 

463 Indexing these features is done via boolean arrays. True values are kept 

464 stationary. False values are allowed to move. Indexing the central distances 

465 can be done with a triangular matrix with 8 rows, corresponding to the 

466 8 distances in our example protein: :: 

467 

468 METN, METCA, METC, LYSN, LYSCA, LYSC, GLYN, GLYCA, GLYC 

469 0 True, False, False, False, False, False, False, False, False 

470 1 True, True, False, False, False, False, False, False, False 

471 2 True, True, True, False, False, False, False, False, False 

472 3 True, True, True, True, False, False, False, False, False 

473 4 True, True, True, True, True, False, False, False, False 

474 5 True, True, True, True, True, True, False, False, False 

475 6 True, True, True, True, True, True, True, False, False 

476 7 True, True, True, True, True, True, True, True, False 

477 

478 Indexing the sidechain distances can be done with a matrix with small 

479 triangular matrices like this. For our example protein we have 9 such 

480 indices for the 9 sidechain distances :: 

481 

482 METCA, METCB, METCG, METSD, METCE, LYSCA, LYSCB, LYSCG, LYSCD, LYSCE, LYSNZ 

483 0 True, False, False, False, False, False, False, False, False, False, False 

484 1 True, True, False, False, False, False, False, False, False, False, False 

485 2 True, True, True, False, False, False, False, False, False, False, False 

486 3 True, True, True, False, False, False, False, False, False, False, False 

487 4 True, True, True, True, False, False, False, False, False, False, False 

488 5 False, False, False, False, False, True, False, False, False, False, False 

489 6 False, False, False, False, False, True, True, False, False, False, False 

490 7 False, False, False, False, False, True, True, True, False, False, False 

491 8 False, False, False, False, False, True, True, True, True, False, False 

492 9 False, False, False, False, False, True, True, True, True, True, False 

493 10 False, False, False, False, False, True, True, True, True, True, True 

494 

495 However, to keep the CA-atoms from appearing twice, the resulting array needs to 

496 ditch the CA columns of the side_distances and add True, where needed. The algorithm 

497 for the right side of the central_distances is row 0: all false, every three rows 

498 add True for the number of sidechain atoms. The last row is just True. The left 

499 side for the sidechain distances (without CA) are just True. We keep the 

500 chain immovable for adjusting these lengths. For 

501 our protein the resulting array looks like this: :: 

502 

503 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 

504 METN, METCA, METC, LYSN, LYSCA, LYSC, GLYN, GLYCA, GLYC, METCB, METCG, METSD, METCE, LYSCB, LYSCG, LYSCD, LYSCE, LYSNZ 

505 0 True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False 

506 1 True, True, False, False, False, False, False, False, False, True, True, True, True, False, False, False, False, False 

507 2 True, True, True, False, False, False, False, False, False, True, True, True, True, False, False, False, False, False 

508 3 True, True, True, True, False, False, False, False, False, True, True, True, True, False, False, False, False, False 

509 4 True, True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True, True 

510 5 True, True, True, True, True, True, False, False, False, True, True, True, True, True, True, True, True, True 

511 6 True, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True, True, True 

512 7 True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, True 

513 ----------------------- 

514 8 True, True, True, True, True, True, True, True, True, False, False, False, False, True, True, True, True, True 

515 9 True, True, True, True, True, True, True, True, True, True, False, False, False, True, True, True, True, True 

516 10 True, True, True, True, True, True, True, True, True, True, True, False, False, True, True, True, True, True 

517 11 True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True 

518 12 True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False 

519 13 True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False 

520 14 True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False 

521 15 True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False 

522 16 True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False 

523 

524 Indexing the angles is similar to the distances. For an angle, we need 

525 to index the left side (including pivot point) and the right side 

526 (excluding the pivot point). For the central chain, this can be done by 

527 omitting the first row of the distance indices. The sidechain indices can 

528 be left as is. They contain the correct indices for rotation. Take the angle 

529 METN-METCA-METCB as an example. We need the whole central chain to stay still, 

530 while METCB, METCG, METSD, METCE are moving. That is exactly row 8 of 

531 the array above. The rotation axis for this chain, which is currently just 

532 in 2D is the z-axis: (0, 0, 1). 

533 

534 Indexing the dihedrals is similar to the angles. For a dihedral, we need to 

535 define a pivot point: The first atom of the bond (the second atom of the four-tuple 

536 defining the dihedral). Also, we need a rotation axis, which is the vector 

537 connecting the two atoms. Whether the second atom is part of the movable 

538 group is trivial, as it won't be rotated because it lies within the rotation axis. 

539 The indices for the central dihedrals are just the same indices as the 

540 distances omitting the first and last rows of the distance indices. This 

541 coincides with the indices for the sidechain atoms. The first dihedral 

542 angle of any protein is the psi1 angle between the N-terminal nitrogen 

543 N-CA-C-N. In this case, the rotational axis is between CA-C and thus, the 

544 sidechain (CB, etc.) needs to stay still for this dihedral. This is exactly 

545 what row 1 of our index array describes. The next dihedrals, omega1, and phi1, 

546 keep these indices of the sidechains, while the indices increase in the 

547 central part of the indices (rows 2 and 3). The next dihedral, psi2 will keep 

548 the sidechain dihedrals of the next residue stationary, as now the CA is on 

549 the far side of the rotation. The sidechain dihedral rotations are similar, 

550 however, the last row for every residue's sidechain must be omitted. 

551 

552 Args: 

553 feature_description (dict[int, dict[int, Sequence[int]]]): A dictrionary 

554 that defines the topological hull of the proteins. 

555 *inputs (np.ndarray): The inputs in the order central_distances, 

556 central_angles, central_dihedrals, side_distances, side_angles, 

557 side_dihedrals. 

558 return_indices (bool): Whether to also return a dict of np.ndarray, to be 

559 used as indices. 

560 

561 Returns: 

562 Union[tuple[np.ndarray, BytesIO], tuple[np.ndarray, BytesIO, dict[str, np.ndarray]]]: 

563 Either a tuple of the reconstructed coordinates in a numpy array and 

564 a BytesIO object containing png data. Or a tuple of the former two, 

565 with a index dictionary. 

566 

567 """ 

568 ############################################################################ 

569 # Init 

570 ############################################################################ 

571 

572 # Imports 

573 # Standard Library Imports 

574 import itertools 

575 from io import BytesIO 

576 

577 # Third Party Imports 

578 import matplotlib.pyplot as plt 

579 from scipy.linalg import block_diag 

580 from transformations import rotation_matrix 

581 

582 # Encodermap imports 

583 from encodermap.misc.rotate import _dihedral 

584 

585 # Definitions and Tests 

586 n_residues: int = max(list(feature_description[-1].keys())) 

587 assert np.array_equal( 

588 np.arange(1, n_residues + 1), 

589 np.sort(np.asarray(list(feature_description[-1].keys()))), 

590 ), ( 

591 f"Currently the `feature_indices[-1]` dict needs to contain monotonous " 

592 f"increasing keys. Starting from 1 {feature_description[-1].keys()=}" 

593 ) 

594 n_sidechains: int = sum([v + 1 for v in feature_description[-1].values() if v > 0]) 

595 sum_sidechains = sum(list(feature_description[-1].values())) 

596 

597 # this can be defined beforehand and then stacked as often, as a batch needs it 

598 init_xyz: np.ndarray = np.zeros( 

599 shape=(1, n_residues * 3 + n_sidechains, 3), 

600 dtype="float32", 

601 ) 

602 

603 # first we create the central_distance indices 

604 central_distance_indices = np.tri( 

605 N=n_residues * 3 - 1, 

606 M=n_residues * 3, 

607 k=0, 

608 ).astype(bool) 

609 right_side_central_distance_indices = [ 

610 np.full(shape=(1, n_sidechains), fill_value=False, dtype=bool) 

611 ] 

612 count = 0 # starts at the first atom of the central chan 

613 count2 = n_residues * 3 + 1 # starts at the first atom of the sidechain 

614 sidechain_cartesians_ind = [] 

615 sidechain_positions_indices = [] 

616 central_distance_index_duplets = np.vstack( 

617 [ 

618 np.arange(0, n_residues * 3)[:-1], 

619 np.arange(0, n_residues * 3)[1:], 

620 ] 

621 ).T.tolist() 

622 sidechain_distance_index_duplets = [] 

623 central_angle_index_triplets = np.vstack( 

624 [ 

625 np.arange(0, n_residues * 3)[:-2], 

626 np.arange(0, n_residues * 3)[1:-1], 

627 np.arange(0, n_residues * 3)[2:], 

628 ] 

629 ).T.tolist() 

630 sidechain_angle_index_triplets = [] 

631 central_dihedral_index_quadruplets = np.vstack( 

632 [ 

633 np.arange(0, n_residues * 3)[:-3], 

634 np.arange(0, n_residues * 3)[1:-2], 

635 np.arange(0, n_residues * 3)[2:-1], 

636 np.arange(0, n_residues * 3)[3:], 

637 ] 

638 ).T.tolist() 

639 sidechain_dihedral_index_quadruplets = [] 

640 

641 # iterate over feature_description[-1] to get all indices and the right side 

642 # of the central cartesians 

643 for i, (residue, n_sidechains_in_residue) in zip( 

644 itertools.count(1, 3), feature_description[-1].items() 

645 ): 

646 if n_sidechains_in_residue == 0: 

647 if residue == 1 or residue == n_residues: 

648 continue 

649 else: 

650 right_side_central_distance_indices.append(t) 

651 else: 

652 sidechain_cartesians_ind.append( 

653 np.arange(count, count + n_sidechains_in_residue) 

654 ) 

655 sidechain_positions_indices.append( 

656 [i] + np.arange(count2 - 1, count2 + n_sidechains_in_residue).tolist() 

657 ) 

658 for sidechain_i in range(n_sidechains_in_residue + 1): 

659 if sidechain_i == 0: 

660 # adds CA-CB 

661 sidechain_distance_index_duplets.append( 

662 [(residue - 1) * 3 + 1, count2 - 1] 

663 ) 

664 # adds N-CA-CB 

665 sidechain_angle_index_triplets.append( 

666 [(residue - 1) * 3, (residue - 1) * 3 + 1, count2 - 1] 

667 ) 

668 # adds N-CA-CB-CG 

669 sidechain_dihedral_index_quadruplets.append( 

670 [(residue - 1) * 3, (residue - 1) * 3 + 1, count2 - 1, count2] 

671 ) 

672 elif sidechain_i == 1: 

673 # adds CB-CG 

674 sidechain_distance_index_duplets.append([count2 - 1, count2]) 

675 # adds CA-CB-CG 

676 sidechain_angle_index_triplets.append( 

677 [(residue - 1) * 3 + 1, count2 - 1, count2] 

678 ) 

679 # adds CA-CB-CG-CD 

680 if sidechain_i < n_sidechains_in_residue: 

681 sidechain_dihedral_index_quadruplets.append( 

682 [(residue - 1) * 3 + 1, count2 - 1, count2, count2 + 1] 

683 ) 

684 else: 

685 # adds CG-CD and so on 

686 sidechain_distance_index_duplets.append( 

687 [count2 + sidechain_i - 2, count2 + sidechain_i - 1] 

688 ) 

689 # adds CB-CG-CD and so on 

690 sidechain_angle_index_triplets.append( 

691 [ 

692 count2 + sidechain_i - 3, 

693 count2 + sidechain_i - 2, 

694 count2 + sidechain_i - 1, 

695 ] 

696 ) 

697 if sidechain_i < n_sidechains_in_residue: 

698 sidechain_dihedral_index_quadruplets.append( 

699 [ 

700 count2 + sidechain_i - 3, 

701 count2 + sidechain_i - 2, 

702 count2 + sidechain_i - 1, 

703 count2 + sidechain_i, 

704 ] 

705 ) 

706 count += n_sidechains_in_residue + 1 

707 count2 += n_sidechains_in_residue + 1 

708 t = np.zeros( 

709 shape=(3, n_sidechains), 

710 dtype=bool, 

711 ) 

712 t[:, :count] = True 

713 right_side_central_distance_indices.append(t) 

714 assert len(sidechain_angle_index_triplets) == n_sidechains 

715 assert len(sidechain_dihedral_index_quadruplets) == sum_sidechains, ( 

716 f"I could not reconstruct the correct number of sidechain dihedral " 

717 f"quadruplets. The number of sidechain dihedrals requires the list " 

718 f"to have length {sum_sidechains}, but I created a list with " 

719 f"{len(sidechain_dihedral_index_quadruplets)}. The input has shape " 

720 f"{inputs[5].shape}." 

721 ) 

722 right_side_central_distance_indices.append( 

723 np.full(shape=(1, n_sidechains), fill_value=True, dtype=bool) 

724 ) 

725 right_side_central_distance_indices = np.vstack(right_side_central_distance_indices) 

726 angle_index_triplets = np.vstack( 

727 central_angle_index_triplets + sidechain_angle_index_triplets 

728 ) 

729 dihedral_index_quadruplets = np.vstack( 

730 central_dihedral_index_quadruplets + sidechain_dihedral_index_quadruplets 

731 ) 

732 if sidechain_cartesians_ind != []: # if sidechains 

733 _use_sidechains = True 

734 sidechain_cartesians_ind = np.concatenate(sidechain_cartesians_ind) 

735 central_distance_indices = np.hstack( 

736 [central_distance_indices, right_side_central_distance_indices] 

737 ) 

738 side_distance_indices = [ 

739 (np.tri(N=i + 1, M=i + 2, k=0) + 1)[:, 1:] 

740 for i in feature_description[-1].values() 

741 if i > 0 

742 ] 

743 side_distance_indices = (block_diag(*side_distance_indices) % 2) == 0 

744 left_side_side_distance_indices = ( 

745 np.full( # all atoms in the central chain are True 

746 shape=(len(side_distance_indices), n_residues * 3), 

747 fill_value=True, 

748 dtype=bool, 

749 ) 

750 ) 

751 side_distance_indices = np.hstack( 

752 [left_side_side_distance_indices, side_distance_indices] 

753 ) 

754 distance_indices = np.vstack([central_distance_indices, side_distance_indices]) 

755 else: # if no sidechains 

756 _use_sidechains = False 

757 distance_indices = central_distance_indices 

758 assert distance_indices.shape == ( 

759 n_residues * 3 - 1 + n_sidechains, 

760 init_xyz.shape[1], 

761 ), ( 

762 f"The shape of the distance index after stacking is unexpected.\n" 

763 f"Expected: {(n_residues * 3 - 1 + n_sidechains, init_xyz.shape[1])}\n" 

764 f"Actual: {distance_indices.shape}" 

765 ) 

766 

767 # now the angles 

768 central_angle_indices = central_distance_indices[1:] 

769 if _use_sidechains: # if sidechains 

770 angle_indices = np.vstack([central_distance_indices[1:], side_distance_indices]) 

771 side_angle_indices = side_distance_indices 

772 else: # no sidechains 

773 angle_indices = central_distance_indices[1:] 

774 assert len(angle_indices) == len(distance_indices) - 1 

775 

776 # and the dihedrals 

777 if _use_sidechains: # if sidechains 

778 dihedral_indices = np.vstack( 

779 [ 

780 central_distance_indices[1:-1], 

781 side_distance_indices[sidechain_cartesians_ind], 

782 ] 

783 ) 

784 corrector = np.count_nonzero( 

785 list(feature_description[-1].values()) 

786 ) # per residue with sidechain dihedrals one less 

787 else: 

788 dihedral_indices = central_distance_indices[1:-1] 

789 corrector = 0 

790 assert len(dihedral_indices) == len(distance_indices) - 2 - corrector 

791 assert angle_index_triplets.shape[0] == angle_indices.shape[0] 

792 assert dihedral_index_quadruplets.shape[0] == dihedral_indices.shape[0], ( 

793 f"The number of dihedral indices ({len(distance_indices)}) and quadruplets " 

794 f"does not match ({len(dihedral_index_quadruplets)}). I get " 

795 f"{inputs[2].shape[1] + inputs[5].shape[1]} in inputs." 

796 ) 

797 

798 ############################################################################ 

799 # Call 

800 ############################################################################ 

801 

802 ( 

803 central_distances, 

804 central_angles, 

805 central_dihedrals, 

806 side_distances, 

807 side_angles, 

808 side_dihedrals, 

809 ) = inputs 

810 

811 assert all( 

812 len(i) == len(inputs[0]) for i in inputs[1:] 

813 ), f"Inhomogeneous input lengths: {[len(i) for i in inputs]}" 

814 distances = np.hstack([central_distances, side_distances]) 

815 angles = np.hstack([central_angles, side_angles]) 

816 dihedrals = np.hstack([central_dihedrals, side_dihedrals]) 

817 assert distance_indices.shape[0] == distances.shape[1] 

818 assert angle_indices.shape[0] == angles.shape[1] 

819 assert dihedral_indices.shape[0] == dihedrals.shape[1] 

820 

821 # for debug 

822 fig, (ax1, ax2, ax3) = plt.subplots( 

823 nrows=1, ncols=3, subplot_kw={"projection": "3d"}, figsize=(20, 8) 

824 ) 

825 buf = BytesIO() 

826 

827 # copy the predefined array and make an array of quarternions 

828 if len(angles) > 1: 

829 xyz_out = np.repeat(init_xyz, len(angles), axis=0) 

830 else: 

831 xyz_out = init_xyz.copy() 

832 xyz_out = np.pad( 

833 xyz_out, ((0, 0), (0, 0), (0, 1)), mode="constant", constant_values=1 

834 ) 

835 

836 # distances in more tensor-flow-friendly implementation 

837 # by creating a list and concatenating we can forego assignments which are 

838 # not supported for symbolic tensors 

839 xs_central = [np.zeros((len(angles),))] 

840 ys_central = [np.zeros((len(angles),))] 

841 xs_side = [] 

842 ys_side = [] 

843 

844 residue = 0 # residue here is again, 0-based 

845 idx = 0 

846 j = 0 

847 n_sidechains_in_residue = np.array( 

848 [feature_description[-1][k] for k in sorted(feature_description[-1].keys())] 

849 ).astype(np.int32) 

850 for i in range(len(central_distance_indices)): 

851 assert np.all(central_distances[:, i] > 0) 

852 xs_central.append(xs_central[-1] + central_distances[:, i]) 

853 ys_central.append(np.zeros((len(angles),))) 

854 if idx == 0 and _use_sidechains: 

855 n_sidechains = n_sidechains_in_residue[residue] 

856 if n_sidechains > 0: 

857 for n in range(n_sidechains + 1): 

858 xs_side.append(xs_central[-1]) 

859 dists = side_distances[:, j - n : j + 1] 

860 assert np.all(dists > 0), ( 

861 f"Side distances at (0-based) residue {residue} are smaller than 0. " 

862 f"This is the {n} sidechain distance of a total of " 

863 f"{n_sidechains + 1} sidechain distances. For that, I index " 

864 f"the sidechain dists array for all frames from {j-n=} to " 

865 f"{j+1=}, which gives {side_distances[:, j-n:j+1]} for the " 

866 f"first 5 frames.." 

867 ) 

868 _ = np.sum(dists, axis=1) 

869 if np.any(np.isnan(_)): 

870 raise Exception( 

871 f"At index {i=} of the central distances and {j=} of " 

872 f"the sidechain distances, the value of the y coordinate " 

873 f"became NaN. The shapes of the arrays are:" 

874 f"{central_distances.shape=} {side_distances.shape=} " 

875 f"{np.any(np.isnan(central_distances))=} " 

876 f"{np.any(np.isnan(side_distances))=}" 

877 ) 

878 ys_side.append(_) 

879 j += 1 

880 idx += 1 

881 if idx >= 3: 

882 residue += 1 

883 idx = 0 

884 xs = np.stack(xs_central + xs_side, axis=1) 

885 ys = np.stack(ys_central + ys_side, axis=1) 

886 xyz_out = np.stack([xs, ys, np.zeros(xs.shape), np.ones(xs.shape)], axis=2) 

887 assert not np.any( 

888 np.isnan(xyz_out) 

889 ), f"After fixing dists, some values in `xyz_out` are NaN." 

890 

891 for frame, ( 

892 central_frame_dists, 

893 side_frame_dists, 

894 central_frame_angles, 

895 side_frame_angles, 

896 frame_dihedrals, 

897 ) in enumerate( 

898 zip(central_distances, side_distances, central_angles, side_angles, dihedrals) 

899 ): 

900 # plot 

901 if frame == 0: 

902 ax1.plot(*xyz_out[0, : n_residues * 3, :3].T, "bo-") 

903 for ind in sidechain_positions_indices: 

904 ax1.plot(*xyz_out[0, ind, :3].T, "bo-") 

905 

906 # angles 

907 for i, (ang, ind, angle_index) in enumerate( 

908 zip( 

909 central_frame_angles, 

910 central_angle_indices, 

911 central_angle_index_triplets, 

912 ) 

913 ): 

914 direction = np.array([0, 0, 1]).astype("float32") 

915 pivot_point = xyz_out[frame, angle_index[1], :3] 

916 a, b, c = xyz_out[frame, angle_index, :3] 

917 ba = a - b 

918 bc = c - b 

919 prod = np.linalg.norm(ba) * np.linalg.norm(bc) 

920 cosine_angle = np.clip(np.dot(ba, bc) / prod, -1, 1) 

921 current_angle = np.arccos(cosine_angle) 

922 angle = np.abs(ang - current_angle) 

923 rotmat = rotation_matrix( 

924 angle=angle, direction=direction, point=pivot_point 

925 ) 

926 rotated = rotmat.dot(xyz_out[frame, ~ind].T).T[:, :3] 

927 xyz_out[frame, ~ind, :3] = rotated 

928 

929 if _use_sidechains: 

930 for i, (ang, ind, angle_index) in enumerate( 

931 zip( 

932 side_frame_angles, 

933 side_angle_indices, 

934 sidechain_angle_index_triplets, 

935 ) 

936 ): 

937 direction = np.array([0, 0, -1]).astype("float32") 

938 pivot_point = xyz_out[frame, angle_index[1], :3] 

939 a, b, c = xyz_out[frame, angle_index, :3] 

940 ba = a - b 

941 bc = c - b 

942 cosine_angle = np.clip( 

943 np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc)), -1, 1 

944 ) 

945 current_angle = np.arccos(cosine_angle) 

946 angle = np.abs(ang - current_angle) 

947 rotmat = rotation_matrix( 

948 angle=angle, direction=direction, point=pivot_point 

949 ) 

950 rotated = rotmat.dot(xyz_out[frame, ~ind].T).T[:, :3] 

951 xyz_out[frame, ~ind, :3] = rotated 

952 

953 # plot 

954 if frame == 0: 

955 ax2.plot(*xyz_out[0, : n_residues * 3, :3].T, "bo-") 

956 for ind in sidechain_positions_indices: 

957 ax2.plot(*xyz_out[0, ind, :3].T, "bo-") 

958 

959 # dihedrals 

960 for i, (d_ang, ind, dihedral_index) in enumerate( 

961 zip(frame_dihedrals, dihedral_indices, dihedral_index_quadruplets) 

962 ): 

963 bond = dihedral_index[[1, 2]] 

964 a, b, c, d = xyz_out[frame, dihedral_index, :3] 

965 direction = np.diff(xyz_out[frame, bond, :3], axis=0).flatten() 

966 pivot_point = xyz_out[frame, bond[0], :3] 

967 current_dihedral = _dihedral(xyz_out[frame, :, :3], dihedral_index)[0][0] 

968 if np.any(np.isnan(current_dihedral)): 

969 current_dihedral = 0.0 

970 dihedral = d_ang - current_dihedral 

971 rotmat = rotation_matrix( 

972 angle=dihedral, direction=direction, point=pivot_point 

973 ) 

974 xyz_out[frame, ~ind, :3] = rotmat.dot(xyz_out[frame, ~ind].T).T[:, :3] 

975 

976 # plot 

977 if frame == 0: 

978 ax3.plot(*xyz_out[0, : n_residues * 3, :3].T, "bo-") 

979 for ind in sidechain_positions_indices: 

980 ax3.plot(*xyz_out[0, ind, :3].T, "bo-") 

981 

982 fig.savefig(buf, format="png") 

983 buf.seek(0) 

984 if not return_indices: 

985 return xyz_out[..., :3], buf 

986 else: 

987 return ( 

988 xyz_out[..., :3], 

989 buf, 

990 { 

991 "central_distance_indices": np.asarray(central_distance_index_duplets), 

992 "side_distance_indices": np.asarray(sidechain_distance_index_duplets), 

993 "central_angles_indices": np.asarray(central_angle_index_triplets), 

994 "side_angles_indices": np.asarray(sidechain_angle_index_triplets), 

995 "central_dihedrals_indices": np.asarray( 

996 central_dihedral_index_quadruplets 

997 ), 

998 "side_dihedrals_indices": np.asarray( 

999 sidechain_dihedral_index_quadruplets 

1000 ), 

1001 }, 

1002 ) 

1003 

1004 

def _residue_number() -> Iterator[int]:
    """Yield an endless sequence of 0-based counters, each repeated three times.

    The produced sequence is ``0, 0, 0, 1, 1, 1, 2, 2, 2, ...``.
    NOTE(review): presumably one value per backbone atom with three atoms
    per residue -- confirm against the callers.

    Yields:
        int: The current counter value.

    """
    residue = 0
    while True:
        # emit the same number three times before moving on to the next one
        for _ in range(3):
            yield residue
        residue += 1

1014 

1015 

def _alternating() -> Iterator[int]:
    """Yield an endless sequence of alternating signs, starting with ``1``.

    The produced sequence is ``1, -1, 1, -1, ...``.

    Yields:
        int: Either ``1`` or ``-1``.

    """
    sign = 1
    while True:
        yield sign
        # flip the sign for the next element
        sign = -sign

1024 

1025 

# Typing overload: with `return_indices=False` only the back-mapped
# trajectory is returned. The parameter names mirror the implementation
# (`guess_sp2_atoms`), so that keyword calls pass static type checking.
@overload
def mdtraj_backmapping(
    top: Optional[Union[Path, str, int, md.Topology]],
    dihedrals: Optional[np.ndarray],
    sidechain_dihedrals: Optional[np.ndarray],
    trajs: Optional[Union[TrajEnsemble, SingleTraj]],
    remove_component_size: int,
    verify_every_rotation: bool,
    angle_type: Literal["degree", "radian"],
    omega: bool,
    guess_sp2_atoms: bool,
    return_indices: Literal[False],
    parallel: bool,
    progbar: Optional[Any],
) -> md.Trajectory: ...

1041 

1042 

# Typing overload: with `return_indices=True` the back-mapped trajectory is
# returned together with a dict of label arrays. The parameter names mirror
# the implementation (`guess_sp2_atoms`), so that keyword calls pass static
# type checking.
@overload
def mdtraj_backmapping(
    top: Optional[Union[Path, str, int, md.Topology]],
    dihedrals: Optional[np.ndarray],
    sidechain_dihedrals: Optional[np.ndarray],
    trajs: Optional[Union[TrajEnsemble, SingleTraj]],
    remove_component_size: int,
    verify_every_rotation: bool,
    angle_type: Literal["degree", "radian"],
    omega: bool,
    guess_sp2_atoms: bool,
    return_indices: Literal[True],
    parallel: bool,
    progbar: Optional[Any],
) -> tuple[md.Trajectory, dict[str, np.ndarray]]: ...

1058 

1059 

1060def mdtraj_backmapping( 

1061 top: Optional[Union[Path, str, int, md.Topology]] = None, 

1062 dihedrals: Optional[np.ndarray] = None, 

1063 sidechain_dihedrals: Optional[np.ndarray] = None, 

1064 trajs: Optional[Union[TrajEnsemble, SingleTraj]] = None, 

1065 remove_component_size: int = 0, 

1066 verify_every_rotation: bool = False, 

1067 angle_type: Literal["degree", "radian"] = "radian", 

1068 omega: bool = True, 

1069 guess_sp2_atoms: bool = True, 

1070 return_indices: bool = False, 

1071 parallel: bool = False, 

1072 progbar: Optional[Any] = None, 

1073) -> Union[md.Trajectory, tuple[md.Trajectory, dict[str, np.ndarray]]]: 

1074 """Uses MDTraj and Christoph Gohlke's transformations.py to rotate the 

1075 bonds in the provided topology. 

1076 

1077 Todo: 

1078 * Make this faster. Maybe write a C or FORTRAN implementation. 

1079 

1080 General procedure: 

1081 * Decide on which topology to use (if different topologies are in the 

1082 `TrajEnsemble` class, the `dihedrals` and `sidechain_dihedrals` arrays 

1083 need to be altered so that the correct dihedrals are used. 

1084 Because EncoderMap is trained on a full input `dihedrals` and 

1085 `sidechain_dihedrals` contain the dihedrals for the topology 

1086 in `TrajEnsemble` with most of such angles. Some SingleTraj 

1087 classes in TrajEnsemble might not contain all these angles if, for 

1088 example, an amino acid has been modified the mutant contains more 

1089 sidechain dihedrals than the wt. So the correct sidechain dihedrals 

1090 for the wildtype need to be selected. 

1091 * Get the indices of the far sides of the rotations. The graph is 

1092 gradually broken apart and the longer sub-graphs are kept. 

1093 * Extend the trajectory. The lengths of dihedrals and sidechain_dihedrals 

1094 should match. The frame given by top will be duplicated 

1095 len(dihedrals)-times. 

1096 * Get the current angles. We know what the final angles should be, 

1097 but not how far to rotate the bonds. This can be done by getting 

1098 the difference between current and target angle. 

1099 * Rotate the bonds. Using Christoph Gohlke's transformations.py, 

1100 the rotation matrix is constructed and the array is padded 

1101 with zeros to resemble an array of quaternions. 

1102 

1103 Args: 

1104 top (Optional[str]): The topology file to use. 

1105 dihedrals (Optional[np.ndarray]): The dihedrals to put onto the trajectory. 

1106 `len(dihedrals)` is number of frames of output trajectory. 

1107 `dihedrals.shape[1]` needs to be the same as the number of dihedrals 

1108 in the topology. Can be None, in which case dihedrals and 

1109 sidechain dihedrals will be faked. 

1110 sidechain_dihedrals (Optional[np.ndarray]): 

1111 The sidechain dihedrals to put onto the trajectory. 

1112 If None is provided, the sidechains are kept like they were in 

1113 the topology. Defaults to None. 

1114 trajs (Optional[em.TrajEnsemble, em.SingleTraj]): Encodermap TrajEnsemble 

1115 class. It Can accelerate the loading of current dihedral angles. 

1116 Checks if provided topology is part of trajs. Defaults to None. 

1117 verify_every_rotation (bool): Whether the rotation succeeded. 

1118 angle_type (Literal["degree", "radian"]): Whether the input is given in degrees or radians. Input in degrees will be 

1119 converted to radians. Defaults to "radian". 

1120 omega (bool): Whether your input backbone dihedrals contain the omega angle. 

1121 return_indices (bool): Whether to not only return the back-mapped 

1122 trajectory, but also a dict of labels. This dict contains the keys: 

1123 * 'dihedrals_labels' 

1124 * 'generic_dihedrals_labels' 

1125 * 'side_dihedrals_labels' 

1126 * 'generic_side_dihedrals_labels' 

1127 Which matches the indices of the returned dihedrals with the input 

1128 MD structures in `top` and/or `trajs`. This can be useful to make 

1129 sure that input dihedrals match output dihedrals. Why? Because there 

1130 are some proline dihedrals that cannot be adjusted. They are filtered 

1131 out before doing backmapping, and the indices give the names of all 

1132 dihedrals that were adjusted. See the Example below. 

1133 

1134 Examples: 

1135 >>> from pathlib import Path 

1136 >>> import numpy as np 

1137 >>> import encodermap as em 

1138 >>> from pprint import pprint 

1139 >>> output_dir = Path( 

1140 ... em.get_from_kondata( 

1141 ... "OTU11", 

1142 ... mk_parentdir=True, 

1143 ... silence_overwrite_message=True, 

1144 ... ), 

1145 ... ) 

1146 >>> # assign how many backbone angles we need 

1147 >>> traj = em.load(output_dir / "OTU11_wt_only_prot.pdb") 

1148 >>> traj.load_CV("central_dihedrals") 

1149 >>> n_angles = traj.central_dihedrals.shape[-1] 

1150 >>> n_angles 

1151 732 

1152 >>> # create some fake dihedrals with a uniform distribution between -pi and pi 

1153 >>> dihedrals = np.random.uniform(low=-np.pi, high=np.pi, size=(5, n_angles)) 

1154 >>> out, index = em.misc.backmapping.mdtraj_backmapping( 

1155 ... top=output_dir / "OTU11_wt_only_prot.pdb", 

1156 ... dihedrals=dihedrals, 

1157 ... remove_component_size=10, 

1158 ... return_indices=True, 

1159 ... ) 

1160 >>> out = em.SingleTraj(out) 

1161 >>> out.load_CV("central_dihedrals") 

1162 >>> # Here you will see, what indices were automatically dropped during backmapping 

1163 >>> # They will be proline phi angles, as these angles can not be 

1164 >>> # freely rotated 

1165 >>> all_coords = set(out._CVs.coords["CENTRAL_DIHEDRALS"].values) 

1166 >>> indexed_coords = set(index['dihedrals_labels']) 

1167 >>> pprint(all_coords - indexed_coords) 

1168 {'CENTERDIH PHI RESID PRO: 8 CHAIN 0', 

1169 'CENTERDIH PHI RESID PRO: 70 CHAIN 0', 

1170 'CENTERDIH PHI RESID PRO: 73 CHAIN 0', 

1171 'CENTERDIH PHI RESID PRO: 80 CHAIN 0', 

1172 'CENTERDIH PHI RESID PRO: 151 CHAIN 0', 

1173 'CENTERDIH PHI RESID PRO: 200 CHAIN 0', 

1174 'CENTERDIH PHI RESID PRO: 205 CHAIN 0', 

1175 'CENTERDIH PHI RESID PRO: 231 CHAIN 0', 

1176 'CENTERDIH PHI RESID PRO: 234 CHAIN 0', 

1177 'CENTERDIH PHI RESID PRO: 238 CHAIN 0'} 

1178 

1179 

1180 Raises: 

1181 Exception: If the input seems like it is in degrees. 

1182 Exception: If top is not part of the TrajEnsemble class provided in argument `trajs`. 

1183 

1184 Returns: 

1185 mdtraj.Trajectory: An MDTraj trajectory with the correct dihedrals/side-dihedrals. 

1186 

1187 """ 

1188 # Third Party Imports 

1189 import networkx as nx 

1190 

1191 # if `dihedrals` is None we sample them from a random uniform distribution 

1192 if dihedrals is None and sidechain_dihedrals is None: 

1193 if trajs is not None: 

1194 if top is None: 

1195 top = 0 

1196 if not hasattr(trajs, "central_dihedrals"): 

1197 trajs.load_CV("central_dihedrals") 

1198 if not hasattr(trajs, "side_dihedrals"): 

1199 trajs.load_CV("side_dihedrals") 

1200 with temp_seed(1): 

1201 dihedrals = np.random.uniform( 

1202 low=-np.pi, 

1203 high=np.pi, 

1204 size=(10, trajs[top].central_dihedrals.shape[-1]), 

1205 ) 

1206 sidechain_dihedrals = np.random.uniform( 

1207 low=-np.pi, 

1208 high=np.pi, 

1209 size=(10, trajs[top].side_dihedrals.shape[-1]), 

1210 ) 

1211 elif top is not None and trajs is None: 

1212 assert not isinstance(top, int) or isinstance(top, md.Topology), ( 

1213 f"When providing no `dihedrals` to sample fake dihedrals from " 

1214 f"a random uniform distribution, you can't provide int or " 

1215 f"md.Topology for `top`. Please provide a str or Path." 

1216 ) 

1217 trajs = TrajEnsemble([top]) 

1218 top = 0 

1219 trajs.load_CVs(["central_dihedrals", "side_dihedrals"]) 

1220 dihedrals = np.random.uniform( 

1221 low=-np.pi, 

1222 high=np.pi, 

1223 size=(10, trajs[top].central_dihedrals.shape[-1]), 

1224 ) 

1225 sidechain_dihedrals = np.random.uniform( 

1226 low=-np.pi, high=np.pi, size=(10, trajs[top].side_dihedrals.shape[-1]) 

1227 ) 

1228 else: 

1229 raise Exception(f"Please provide either a `top` or `trajs` argument.") 

1230 

1231 # change and check the angles 

1232 if angle_type == "radian": 

1233 if np.any(dihedrals > np.pi): 

1234 raise Exception( 

1235 f"The argument `angle_type` is meant to specify, what angles " 

1236 f"(radian or degree) are provided for the argument `dihedrals`. " 

1237 f"This allows you to provide either to this function by just " 

1238 f"specifying this argument. You specified {angle_type} but some " 

1239 f"of your dihedrals are greater than pi." 

1240 ) 

1241 if sidechain_dihedrals is not None: 

1242 if np.any(sidechain_dihedrals > np.pi): 

1243 raise Exception( 

1244 f"The argument `angle_type` is meant to specify, what angles " 

1245 f"(radian or degree) are provided for the argument `sidechain_dihedrals`. " 

1246 f"This allows you to provide either to this function by just " 

1247 f"specifying this argument. You specified {angle_type} but some " 

1248 f"of your sidechain dihedrals are greater than pi." 

1249 ) 

1250 elif angle_type == "degree": 

1251 if np.all(dihedrals <= np.pi): 

1252 raise Exception( 

1253 f"The argument `angle_type` is meant to specify, what angles " 

1254 f"(radian or degree) are provided for the argument `dihedrals`. " 

1255 f"This allows you to provide either to this function by just " 

1256 f"specifying this argument. You specified {angle_type} but none " 

1257 f"of your dihedrals were greater than pi: {dihedrals}" 

1258 ) 

1259 dihedrals = np.deg2rad(dihedrals) 

1260 if sidechain_dihedrals is not None: 

1261 if np.all(sidechain_dihedrals <= np.pi): 

1262 raise Exception( 

1263 f"The argument `angle_type` is meant to specify, what angles " 

1264 f"(radian or degree) are provided for the argument `sidechain_dihedrals`. " 

1265 f"This allows you to provide either to this function by just " 

1266 f"specifying this argument. You specified {angle_type} but none " 

1267 f"of your sidechain dihedrals were greater than pi." 

1268 ) 

1269 sidechain_dihedrals = np.deg2rad(sidechain_dihedrals) 

1270 else: 

1271 raise Exception( 

1272 f"Argument `angle_type` must be either 'radian' or 'degree', " 

1273 f"you supplied: {angle_type}" 

1274 ) 

1275 

1276 # make sure the input has the same shape along the "frame" axis, that will be created. 

1277 if sidechain_dihedrals is not None: 

1278 assert len(dihedrals) == len(sidechain_dihedrals), ( 

1279 f"The number of provided dihedrals ({len(dihedrals)}) and " 

1280 f"sidechain dihedrals ({len(sidechain_dihedrals)}) must be the same." 

1281 ) 

1282 

1283 # either top or trajs has to be not None 

1284 if trajs is None: # pragma: no cover 

1285 if top is None: 

1286 raise Exception( 

1287 f"Please provide the path to a topology file" 

1288 f"(.pdb, .gro) to use for backmapping." 

1289 ) 

1290 elif isinstance(top, (str, Path)): 

1291 inp_trajs = TrajEnsemble([top]) 

1292 inp_trajs.load_CVs(["central_dihedrals", "side_dihedrals"]) 

1293 elif isinstance(top, int): 

1294 raise Exception( 

1295 f"When providing an int for `top`, pleas also provide a `em.TrajEnsemble` " 

1296 f"for argument `trajs.`" 

1297 ) 

1298 else: 

1299 raise ValueError( 

1300 f"Argument `top` must be of type str, int, or None, " 

1301 f"you provided: {type(top)}." 

1302 ) 

1303 elif isinstance(trajs, SingleTraj) or trajs.__class__.__name__ == "SingleTraj": 

1304 if isinstance(top, (int, Path, str)): 

1305 print( 

1306 "When providing `em.SingleTraj` for argument `trajs`, the argument " 

1307 "`top` will be ignored." 

1308 ) 

1309 assert "central_dihedrals" in trajs._CVs, ( 

1310 f"The provided traj, doesn't have the collective variable 'central_dihedrals' " 

1311 f"loaded. Please load them by calling: `traj.load_CVs('all')." 

1312 ) 

1313 if sidechain_dihedrals is not None: 

1314 assert "side_dihedrals" in trajs._CVs, ( 

1315 f"The provided traj, doesn't have the collective variable 'central_dihedrals' " 

1316 f"loaded. Please load them by calling: `traj.load_CVs('all')." 

1317 ) 

1318 inp_trajs = trajs._gen_ensemble() 

1319 elif isinstance(trajs, TrajEnsemble) or trajs.__class__.__name__ == "TrajEnsemble": 

1320 assert "central_dihedrals" in trajs._CVs, ( 

1321 f"The provided traj, doesn't have the collective variable 'central_dihedrals' " 

1322 f"loaded. Please load them by calling: `traj.load_CVs('all')." 

1323 ) 

1324 if sidechain_dihedrals is not None: 

1325 assert "side_dihedrals" in trajs._CVs, ( 

1326 f"The provided traj, doesn't have the collective variable 'central_dihedrals' " 

1327 f"loaded. Please load them by calling: `traj.load_CVs('all')." 

1328 ) 

1329 if isinstance(top, (str, Path)) and os.path.isfile(top): 

1330 print( 

1331 "When providing `em.TrajEnsemble` for argument `trajs`, the argument " 

1332 "`top` will be ignored if it points to a file." 

1333 ) 

1334 elif isinstance(top, str) and top in trajs.common_str: 

1335 _trajs_index = trajs.trajs_by_common_str[top][0].traj_num 

1336 elif top is None: 

1337 _trajs_index = 0 

1338 elif isinstance(top, (int, np.int64)): 

1339 # _trajs_index = top 

1340 _trajs_index = trajs.trajs_by_top[trajs.top[top]][0].traj_num 

1341 else: 

1342 raise ValueError( 

1343 f"Argument `top` must be of type str, int, md.Topology or None, " 

1344 f"you provided: {type(top)}." 

1345 ) 

1346 inp_trajs = trajs.tsel[_trajs_index]._gen_ensemble() 

1347 else: 

1348 raise ValueError( 

1349 f"Argument `trajs` must be of type `em.SingleTraj`, `em.TrajEnsemble`, or None, " 

1350 f"you provided: {type(trajs)}." 

1351 ) 

1352 

1353 if trajs is None: 

1354 trajs = inp_trajs 

1355 

1356 # now we match the names of the featurizer 

1357 all_central_indices = trajs._CVs.central_dihedrals.coords[ 

1358 trajs._CVs.central_dihedrals.attrs["feature_axis"] 

1359 ] 

1360 central_indices = all_central_indices[ 

1361 np.all(~np.isnan(inp_trajs._CVs.central_dihedrals.values[0]), axis=0) 

1362 ] 

1363 if sidechain_dihedrals is not None: 

1364 all_side_indices = trajs._CVs.side_dihedrals.coords[ 

1365 trajs._CVs.side_dihedrals.attrs["feature_axis"] 

1366 ] 

1367 side_indices = all_side_indices[ 

1368 np.all(~np.isnan(inp_trajs._CVs.side_dihedrals.values[0]), axis=0) 

1369 ] 

1370 if not omega: 

1371 central_indices = central_indices[ 

1372 ~central_indices.str.lower().str.contains("omega") 

1373 ] 

1374 if sidechain_dihedrals is not None: 

1375 side_indices = side_indices[~side_indices.str.lower().str.contains("omega")] 

1376 

1377 generic_labels = not any( 

1378 central_indices.coords["CENTRAL_DIHEDRALS"].str.contains("RESID") 

1379 ) 

1380 

1381 if generic_labels: 

1382 # we have generic labels and add non-generic ones 

1383 _back_labels = { 

1384 "generic_dihedrals_labels": central_indices.values, 

1385 "dihedrals_labels": np.asarray( 

1386 features.CentralDihedrals(inp_trajs[0], omega=omega).describe() 

1387 ), 

1388 } 

1389 if sidechain_dihedrals is not None: 

1390 _back_labels |= { 

1391 "generic_side_dihedrals_labels": side_indices.values, 

1392 "side_dihedrals_labels": np.asarray( 

1393 features.SideChainDihedrals(inp_trajs[0]).describe() 

1394 ), 

1395 } 

1396 else: 

1397 # we have non-generic labels and build generic ones 

1398 _back_labels = { 

1399 "dihedrals_labels": central_indices.values, 

1400 "generic_dihedrals_labels": np.asarray( 

1401 features.CentralDihedrals( 

1402 inp_trajs[0], omega=omega, generic_labels=True 

1403 ).describe() 

1404 ), 

1405 } 

1406 if sidechain_dihedrals is not None: 

1407 _back_labels |= { 

1408 "side_dihedrals_labels": side_indices.values, 

1409 "generic_side_dihedrals_labels": np.asarray( 

1410 features.SideChainDihedrals( 

1411 inp_trajs[0], generic_labels=True 

1412 ).describe() 

1413 ), 

1414 } 

1415 

1416 # check that all indices are present 

1417 if len(all_central_indices) >= len(central_indices): 

1418 if dihedrals.shape[1] == len(all_central_indices): 

1419 dih_indices = np.arange(len(all_central_indices))[ 

1420 np.in1d(all_central_indices, central_indices) 

1421 ] 

1422 dihedrals = dihedrals[:, dih_indices] 

1423 elif dihedrals.shape[1] == len(central_indices): 

1424 dih_indices = np.arange(dihedrals.shape[1]) 

1425 dihedrals = dihedrals[:, dih_indices] 

1426 else: 

1427 raise Exception( 

1428 f"The shape of the provided `dihedrals` is wrong, either provide " 

1429 f"an array with shape[1] = {len(central_indices)}, or " 

1430 f"{len(all_central_indices)}, " 

1431 f"your array has the shape {dihedrals.shape[1]}." 

1432 ) 

1433 else: 

1434 raise NotImplementedError 

1435 

1436 if sidechain_dihedrals is not None: 

1437 if len(all_side_indices) >= len(side_indices): 

1438 if sidechain_dihedrals.shape[1] == len(all_side_indices): 

1439 _side_indices_out = side_indices.copy() 

1440 side_indices = np.in1d(all_side_indices, side_indices) 

1441 msg = ( 

1442 f"Your supplied `sidechain_dihedrals` are misshaped. They are" 

1443 f"expected to have either shape (n, {len(side_indices)}, which" 

1444 f"matches the number of sidechain dihedrals in the specified" 

1445 f"topology: {inp_trajs[0].top} or a shape of (n, " 

1446 f"{len(all_side_indices)}) which matches the total number of " 

1447 f"possible sidechain angles in the provided `TrajEnsemble` with " 

1448 f"{trajs.top} different toplogies." 

1449 ) 

1450 assert side_indices.shape[0] == sidechain_dihedrals.shape[-1], msg 

1451 sidechain_dihedrals = sidechain_dihedrals[:, side_indices] 

1452 elif sidechain_dihedrals.shape[1] == len(side_indices): 

1453 _side_indices_out = side_indices.copy() 

1454 side_indices = np.arange(sidechain_dihedrals.shape[1]) 

1455 sidechain_dihedrals = sidechain_dihedrals[:, side_indices] 

1456 else: 

1457 raise Exception( 

1458 f"The shape of the provided `dihedrals` is wrong, either provide " 

1459 f"an array with shape[1] = {len(side_indices)}, or {len(all_side_indices)}, " 

1460 f"your array has the shape {sidechain_dihedrals.shape[1]}." 

1461 ) 

1462 else: 

1463 raise NotImplementedError 

1464 

1465 dih_indices = inp_trajs[0]._CVs.central_dihedrals_feature_indices.values[0] 

1466 if omega: 

1467 idx = ~np.all(np.isnan(dih_indices), axis=1) 

1468 else: 

1469 idx = ( 

1470 ~np.all(np.isnan(dih_indices), axis=1) 

1471 & ~all_central_indices.str.lower().str.contains("omega") 

1472 ).values 

1473 dih_indices = dih_indices[idx] 

1474 dih_indices = dih_indices.astype(int) 

1475 _dih_indices = deepcopy(dih_indices) 

1476 assert dih_indices.ndim == 2, f"Problem when calculating dihedrals {inp_trajs=}" 

1477 side_indices = inp_trajs[0]._CVs.side_dihedrals_feature_indices.values[0] 

1478 side_indices = side_indices[~np.all(np.isnan(side_indices), axis=1)] 

1479 side_indices = side_indices.astype(int) 

1480 _side_indices = deepcopy(side_indices) 

1481 assert isinstance(inp_trajs, TrajEnsemble) 

1482 

1483 # get indices of atoms for rotations 

1484 g = inp_trajs.top[0].to_bondgraph() 

1485 

1486 # can be used to visualize topology 

1487 # nx.draw(g, pos=nx.spring_layout(g)) 

1488 if not nx.is_connected(g): 

1489 # Third Party Imports 

1490 from networkx import connected_components 

1491 

1492 components = [*connected_components(g)] 

1493 if remove_component_size > 0: 

1494 component_sizes = sorted([len(c) for c in components]) 

1495 if any([i > remove_component_size for i in component_sizes[:-1]]): 

1496 _raise_components_exception( 

1497 components, trajs, top, remove_component_size 

1498 ) 

1499 offending_components = [] 

1500 for c in components: 

1501 if len(c) <= remove_component_size: 

1502 offending_components.extend([a.index for a in list(c)]) 

1503 inp_trajs[0].atom_slice(offending_components, invert=True) 

1504 g = inp_trajs.top[0].to_bondgraph() 

1505 else: 

1506 _raise_components_exception(components, trajs, top, remove_component_size) 

1507 

1508 # at this point dih_bond_indices has been defined. 

1509 dih_bond_indices = dih_indices[:, 1:3] 

1510 assert ( 

1511 dih_bond_indices.shape[0] == _dih_indices.shape[0] == dihedrals.shape[1] 

1512 ), f"{dih_bond_indices.shape=}, {_dih_indices.shape=}, {dihedrals.shape=} {omega=}" 

1513 

1514 # filter out the proline angles 

1515 dih_bond_atoms = np.dstack( 

1516 [ 

1517 [inp_trajs[0].top.atom(a).__str__() for a in dih_bond_indices[:, 0]], 

1518 [inp_trajs[0].top.atom(a).__str__() for a in dih_bond_indices[:, 1]], 

1519 ] 

1520 )[0] 

1521 indices = np.full(dihedrals.shape[1], 1) 

1522 assert indices.shape[0] == dihedrals.shape[1] 

1523 assert ( 

1524 dihedrals[:, indices].shape == dihedrals.shape 

1525 ), f"{dihedrals[:, indices].shape=} {dihedrals.shape=}" 

1526 for i, bond in enumerate(dih_bond_atoms): 

1527 if "PRO" in bond[0] and "PRO" in bond[1] and "N" in bond[0] and "CA" in bond[1]: 

1528 indices[i] = 0 

1529 indices = indices.astype(bool) 

1530 

1531 # get rid of the proline dihedrals 

1532 _back_labels["dihedrals_labels"] = _back_labels["dihedrals_labels"][indices] 

1533 _back_labels["generic_dihedrals_labels"] = _back_labels["generic_dihedrals_labels"][ 

1534 indices 

1535 ] 

1536 dih_indices = dih_indices[indices] 

1537 dih_bond_indices = dih_bond_indices[indices] 

1538 dihedrals = dihedrals[:, indices] 

1539 assert dihedrals.shape[1] == dih_indices.shape[0] == dih_bond_indices.shape[0] 

1540 dih_near_sides, dih_far_sides = _get_near_and_far_networkx( 

1541 g, 

1542 dih_bond_indices, 

1543 inp_trajs[0].top, 

1544 parallel=parallel, 

1545 ) 

1546 

1547 if sidechain_dihedrals is not None: 

1548 side_bond_indices = side_indices[:, 1:3] 

1549 assert ( 

1550 side_bond_indices.shape[0] 

1551 == _side_indices.shape[0] 

1552 == sidechain_dihedrals.shape[1] 

1553 ), ( 

1554 f"{side_bond_indices.shape=}, {_side_indices.shape=}, " 

1555 f"{sidechain_dihedrals.shape=}" 

1556 ) 

1557 # filter out the proline angles 

1558 side_bond_atoms = np.dstack( 

1559 [ 

1560 [inp_trajs[0].top.atom(a).__str__() for a in side_bond_indices[:, 0]], 

1561 [inp_trajs[0].top.atom(a).__str__() for a in side_bond_indices[:, 1]], 

1562 ] 

1563 )[0] 

1564 indices = np.full(sidechain_dihedrals.shape[1], 1) 

1565 assert indices.shape[0] == sidechain_dihedrals.shape[1] 

1566 assert ( 

1567 sidechain_dihedrals[:, indices].shape == sidechain_dihedrals.shape 

1568 ), f"{sidechain_dihedrals[:, indices].shape=} {sidechain_dihedrals.shape=}" 

1569 for i, bond in enumerate(side_bond_atoms): 

1570 if ( 

1571 "PRO" in bond[0] 

1572 and "PRO" in bond[1] 

1573 and "CA" in bond[0] 

1574 and "CB" in bond[1] 

1575 ): 

1576 indices[i] = 0 

1577 if ( 

1578 "PRO" in bond[0] 

1579 and "PRO" in bond[1] 

1580 and "CB" in bond[0] 

1581 and "CG" in bond[1] 

1582 ): 

1583 indices[i] = 0 

1584 indices = indices.astype(bool) 

1585 _back_labels["side_dihedrals_labels"] = _back_labels["side_dihedrals_labels"][ 

1586 indices 

1587 ] 

1588 _back_labels["generic_side_dihedrals_labels"] = _back_labels[ 

1589 "generic_side_dihedrals_labels" 

1590 ][indices] 

1591 _side_indices_out = _side_indices_out[indices] 

1592 side_indices = side_indices[indices] 

1593 side_bond_indices = side_bond_indices[indices] 

1594 sidechain_dihedrals = sidechain_dihedrals[:, indices] 

1595 

1596 side_near_sides, side_far_sides = _get_near_and_far_networkx( 

1597 g, 

1598 side_bond_indices, 

1599 inp_trajs[0].top, 

1600 parallel=parallel, 

1601 ) 

1602 

1603 # assert that the dihedrals and _back_labels have the correct shape 

1604 # that way we can be sure to use the labels to index the correct dihedrals 

1605 # after obtaining the finished trajectory 

1606 assert ( 

1607 _back_labels["dihedrals_labels"].shape 

1608 == _back_labels["generic_dihedrals_labels"].shape 

1609 ) 

1610 assert _back_labels["dihedrals_labels"].shape[0] == dihedrals.shape[-1] 

1611 if sidechain_dihedrals is not None: 

1612 assert ( 

1613 _back_labels["side_dihedrals_labels"].shape 

1614 == _back_labels["generic_side_dihedrals_labels"].shape 

1615 ) 

1616 assert ( 

1617 _back_labels["side_dihedrals_labels"].shape[0] 

1618 == sidechain_dihedrals.shape[-1] 

1619 ) 

1620 

1621 # extend the traj 

1622 for i in range(len(dihedrals)): 

1623 if i == 0: 

1624 out_traj = deepcopy(inp_trajs[0][0].traj) 

1625 else: 

1626 out_traj = out_traj.join(inp_trajs[0][0].traj) 

1627 out_traj.top = inp_trajs[0].top 

1628 

1629 # adjust the torsions 

1630 new_xyz = np.ascontiguousarray(out_traj.xyz.copy().astype("float32")) 

1631 new_xyz = np.pad( 

1632 new_xyz, ((0, 0), (0, 0), (0, 1)), mode="constant", constant_values=1 

1633 ) 

1634 dihedrals = dihedrals.astype("float32") 

1635 dih_indices = dih_indices.astype("int32") 

1636 total_counts = dihedrals.shape[0] * dihedrals.shape[1] 

1637 if sidechain_dihedrals is not None: 

1638 total_counts += dihedrals.shape[0] * sidechain_dihedrals.shape[1] 

1639 sidechain_dihedrals = sidechain_dihedrals.astype("float32") 

1640 if guess_sp2_atoms: 

1641 total_counts += _countprog(out_traj) 

1642 if parallel: 

1643 raise Exception(f"Parallel has not yet been tested.") 

1644 parallel_rotation_application( 

1645 new_xyz, 

1646 dih_indices, 

1647 dihedrals, 

1648 dih_near_sides, 

1649 ) 

1650 else: 

1651 pbar = progbar 

1652 if pbar is None: 

1653 if _is_notebook(): 

1654 pbar = notebook_tqdm( 

1655 total=total_counts, 

1656 leave=False, 

1657 position=0, 

1658 desc="Backmapping...", 

1659 ) 

1660 else: 

1661 pbar = normal_tqdm( 

1662 total=total_counts, 

1663 leave=False, 

1664 position=0, 

1665 desc="Backmapping...", 

1666 ) 

1667 else: 

1668 if not isinstance(pbar, bool): 

1669 pbar.reset(pbar.total + total_counts - 1) 

1670 else: 

1671 pbar = None 

1672 

1673 for i in range(dihedrals.shape[0]): 

1674 for j in range(dihedrals.shape[1]): 

1675 # central_dihedrals 

1676 near_side = dih_near_sides[j] 

1677 far_side = dih_far_sides[j] 

1678 dihedral = dih_indices[j] 

1679 bond = dih_bond_indices[j] 

1680 

1681 # define inputs 

1682 target_angle = dihedrals[i, j] 

1683 current_angle = _dihedral(new_xyz[i, :, :3], dihedral)[0][0] 

1684 angle = target_angle - current_angle 

1685 direction = np.diff(new_xyz[i, bond, :3], axis=0).flatten() 

1686 pivot_point = new_xyz[i, bond[0], :3] 

1687 

1688 # perform rotation 

1689 rotmat = transformations_rotation_matrix(angle, direction, pivot_point) 

1690 new_xyz[i, far_side, :3] = rotmat.dot(new_xyz[i, far_side].T).T[:, :3] 

1691 

1692 # verify 

1693 if verify_every_rotation: 

1694 _ = _dihedral(new_xyz[i, :, :3], dihedral)[0][0] 

1695 if not np.isclose(_, target_angle, atol=1e-3): 

1696 _resids = [ 

1697 str(inp_trajs[0].top.atom(x).residue.index) 

1698 for x in dihedral 

1699 ] 

1700 s = ( 

1701 f"Adjusting dihedral angle for atoms " 

1702 f"{[str(inp_trajs[0].top.atom(x)) for x in dihedral]} " 

1703 f"failed with an absolute tolerance of 1e-3. " 

1704 f"Residue indices are: " 

1705 f"{_resids}" 

1706 f"\nTarget angle was {target_angle} {angle_type}, " 

1707 f"but rotation yielded angle with {_} {angle_type}." 

1708 f"\nCurrent angle was {current_angle}. To reach " 

1709 f"target angle is a rotation of {angle} {angle_type} " 

1710 f"was carried out." 

1711 f"\nRotation axis was vector from " 

1712 f"{inp_trajs[0].top.atom(bond[0])} to " 

1713 f"{inp_trajs[0].top.atom(bond[1])}" 

1714 f"\nOnly these atoms should have been affected by " 

1715 f"rotation: {far_side}" 

1716 "\nBut somehow this method still crashed. Maybe " 

1717 "these prints will help." 

1718 ) 

1719 raise Exception(s) 

1720 if pbar is not None: 

1721 pbar.update() 

1722 

1723 if sidechain_dihedrals is not None: 

1724 for j in range(sidechain_dihedrals.shape[1]): 

1725 # central_dihedrals 

1726 near_side = side_near_sides[j] 

1727 far_side = side_far_sides[j] 

1728 dihedral = side_indices[j] 

1729 bond = side_bond_indices[j] 

1730 

1731 # define inputs 

1732 target_angle = sidechain_dihedrals[i, j] 

1733 current_angle = _dihedral(new_xyz[i, :, :3], dihedral) 

1734 angle = target_angle - current_angle 

1735 direction = np.diff(new_xyz[i, bond, :3], axis=0).flatten() 

1736 pivot_point = new_xyz[i, bond[0], :3] 

1737 

1738 # perform rotation 

1739 rotmat = transformations_rotation_matrix( 

1740 angle, direction, pivot_point 

1741 ) 

1742 rotmat = transformations_rotation_matrix( 

1743 angle, direction, pivot_point 

1744 ) 

1745 new_xyz[i, far_side, :3] = rotmat.dot(new_xyz[i, far_side].T).T[ 

1746 :, :3 

1747 ] 

1748 

1749 if verify_every_rotation: 

1750 _ = _dihedral(new_xyz[i, :, :3], dihedral)[0][0] 

1751 if not np.isclose(_, target_angle, atol=1e-3): 

1752 _resids = [ 

1753 str(inp_trajs[0].top.atom(x).residue.index) 

1754 for x in dihedral 

1755 ] 

1756 s = ( 

1757 f"Adjusting dihedral angle for atoms " 

1758 f"{[str(inp_trajs[0].top.atom(x)) for x in dihedral]} " 

1759 f"failed with an absolute tolerance of 1e-3. " 

1760 f"Residue indices are: " 

1761 f"{_resids}" 

1762 f"\nTarget angle was {target_angle} {angle_type}, " 

1763 f"but rotation yielded angle with {_} {angle_type}." 

1764 f"\nCurrent angle was {current_angle}. To reach target " 

1765 f"angle is a rotation of {angle} {angle_type} was " 

1766 f"carried out.\nRotation axis was vector from " 

1767 f"{inp_trajs[0].top.atom(bond[0])} to " 

1768 f"{inp_trajs[0].top.atom(bond[1])}" 

1769 f"\nOnly these atoms should have been affected by " 

1770 f"rotation: {far_side}\nBut somehow this method " 

1771 f"still crashed. Maybe these prints will help." 

1772 ) 

1773 raise Exception(s) 

1774 if pbar is not None: 

1775 pbar.update() 

1776 

1777 # overwrite traj and return 

1778 out_traj.xyz = new_xyz[..., :3] 

1779 

1780 # fix the amide atoms 

1781 if guess_sp2_atoms: 

1782 _guess_sp2_atoms(out_traj, pbar) 

1783 

1784 if not return_indices: 

1785 return out_traj 

1786 if return_indices: 

1787 return out_traj, _back_labels 

1788 

1789 

def _countprog(traj):
    """Counts the progress-bar updates `_guess_sp2_atoms` emits for `traj`.

    `_guess_sp2_atoms` skips proline residues and, for every other residue
    and every frame, updates the progress bar once when it places the amide
    hydrogen (residue index > 0) and once when it places the carbonyl oxygen
    (residue index < n_residues - 2). This function applies the same
    conditions, so that the total handed to the progress bar matches the
    number of updates that follow.

    Args:
        traj: Object exposing `top.residues` (residues with a `name`
            attribute), `n_residues` and `n_frames`.

    Returns:
        int: The number of expected progress-bar updates.

    """
    n_residues = traj.n_residues
    updates_per_frame = 0
    for i, residue in enumerate(traj.top.residues):
        # Mirror the skip condition of `_guess_sp2_atoms`. Testing the bare
        # residue object for truthiness would skip every residue.
        if residue.name == "PRO":
            continue
        if i > 0:
            updates_per_frame += 1
        if i < n_residues - 2:
            updates_per_frame += 1
    # The per-residue count does not depend on the frame, so scale it once.
    return updates_per_frame * traj.n_frames

1801 

1802 

def _guess_sp2_atoms(
    traj: md.Trajectory,
    pbar: Optional[Any] = None,
) -> None:
    """Re-positions the backbone amide H and carbonyl O atoms of `traj` in place.

    Proline residues are skipped. For every other residue and every frame:

    * if the residue is not the first one, the atom named "H" is placed at a
      distance of 0.11 from the residue's N, at an angle of 123 degrees to
      the N-CA bond, in the plane defined by CA, N and the previous
      residue's C.
    * if the residue index is smaller than `traj.n_residues - 2`, the atom
      named "O" is placed at a distance of 0.124 from the residue's C, at an
      angle of 121 degrees to the C-CA bond, in the plane defined by CA, C
      and the next residue's N.

    Distances are in the unit of `traj.xyz` (presumably nm - confirm).

    Args:
        traj: The trajectory whose coordinates are overwritten.
        pbar: Optional progress bar. Its `update()` method is called once for
            every atom that is placed.

    Raises:
        AssertionError: If the number of atoms named N, CA or C differs from
            the number of residues.

    """
    # Third Party Imports
    from scipy.spatial.transform import Rotation

    def _rotated_bond(center, partner, other, angle_deg, bond_length):
        """Rotates the center->partner bond about the normal of the plane
        (partner, center, other) and rescales it to `bond_length`."""
        bond = partner - center
        normal = np.cross(bond, other - center)
        normal /= np.linalg.norm(normal)
        normal *= angle_deg / 180 * np.pi
        position = bond @ Rotation.from_rotvec(normal).as_matrix()
        position /= np.linalg.norm(position)
        position *= bond_length
        position += center
        return position

    Ns = traj.top.select("name N")
    CAs = traj.top.select("name CA")
    Cs = traj.top.select("name C")
    assert len(Ns) == len(CAs) == len(Cs) == traj.n_residues, (
        f"I could not determine the correct number of backbone atoms for this "
        f"protein of {traj.n_residues} residues. I expected {traj.n_residues} "
        f"nitrogen atoms, but got {len(Ns)}. I expected {traj.n_residues} "
        f"alpha carbon atoms, but got {len(CAs)}. I expected {traj.n_residues} "
        f"carboxylic carbons, but got {len(Cs)}. Maybe your protein contains "
        f"non-standard residues."
    )

    backbone = zip(traj.top.residues, Ns, CAs, Cs)
    for res_idx, (residue, n_idx, ca_idx, c_idx) in enumerate(backbone):
        # Prolines are left untouched in every frame.
        if residue.name == "PRO":
            continue
        for frame_idx, frame in enumerate(traj):
            coords = frame.xyz[0]
            if res_idx > 0:
                h_idx = next(residue.atoms_by_name("H")).index
                traj.xyz[frame_idx, h_idx] = _rotated_bond(
                    coords[n_idx],
                    coords[ca_idx],
                    coords[Cs[res_idx - 1]],
                    123,
                    0.11,
                )
                if pbar is not None:
                    pbar.update()
            if res_idx < traj.n_residues - 2:
                o_idx = next(residue.atoms_by_name("O")).index
                traj.xyz[frame_idx, o_idx] = _rotated_bond(
                    coords[c_idx],
                    coords[ca_idx],
                    coords[Ns[res_idx + 1]],
                    121,
                    0.124,
                )
                if pbar is not None:
                    pbar.update()

1871 

1872 

def dihedral_to_cartesian_tf_one_way_layers(
    dihedrals: tf.Tensor,
    cartesian: tf.Tensor,
    n: int,
) -> tf.Tensor:
    """Takes one of the cartesian sides (left, or right) and turns them into 3D.

    The left or right sides of the cartesian chain can have different number
    of atoms in them, depending on whether the protein has an even or odd
    number of central atoms.

    In each of the `n` steps the current first point is collected, and all
    points after it are rotated about the axis through the first two points
    (with the second point as the pivot) by the negated dihedral of that step.

    Args:
        dihedrals: tf.Tensor: The dihedrals with shape (batch, None).
            Must be of rank 2.
        cartesian: tf.Tensor: The cartesians with shape (batch, None, 3).
        n: int: The range variable for iteration. Must be the same as
            dihedrals.shape[-1].

    Returns:
        tf.Tensor: The collected points, concatenated along axis 1.

    """
    tf.debugging.assert_rank(dihedrals, 2)
    dihedrals = -dihedrals

    rotated = cartesian[:, 1:]
    collected_cartesians = [cartesian[:, :1]]
    for i in range(n):
        collected_cartesians.append(rotated[:, 0:1])
        axis = rotated[:, 1] - rotated[:, 0]
        # The axis is normalized by hand instead of with tf.norm, because
        # tf.norm causes numeric problems here, see:
        # https://datascience.stackexchange.com/q/80898
        axis /= tf.expand_dims(tf.sqrt(tf.reduce_sum(tf.square(axis), axis=1)), axis=-1)
        offset = rotated[:, 1:2]
        rotated = offset + tf.matmul(
            rotated[:, 1:] - offset, rotation_matrix(axis, dihedrals[:, i])
        )
    # Whatever remains after the last rotation closes the chain.
    collected_cartesians.append(rotated)
    return tf.concat(collected_cartesians, axis=1)

1913 

1914 

def backbone_hydrogen_oxygen_crossproduct(backbone_positions):
    """Placeholder that only validates the shape of `backbone_positions`.

    Args:
        backbone_positions: Array-like with a `shape` attribute of at least
            three entries.

    Returns:
        None: Nothing is computed yet.

    Raises:
        AssertionError: If `backbone_positions.shape[2]` is not a multiple
            of three.

    """
    # C, CA, N atoms, multiple of three
    assert backbone_positions.shape[2] % 3 == 0
    return None

1918 

1919 

def guess_sp2_atom(cartesians, indices, angle_to_previous, bond_length):
    """Guesses one extra atom position for each of the given chain atoms.

    For every index in `indices`, the vector to the preceding atom is rotated
    by `angle_to_previous` about the normal of the plane spanned by the
    vectors to the preceding and the following atom, scaled to `bond_length`
    and added to the position of the atom itself. If indexing the following
    atom raises `tf.errors.InvalidArgumentError`, the atom two positions
    before is used in its place.

    Args:
        cartesians: Tensor indexed as `cartesians[:, atom]`; presumably of
            shape (batch, n_atoms, 3) - confirm against the caller.
        indices: Iterable of integer atom indices.
        angle_to_previous: Rotation angle handed to `rotation_matrix`.
        bond_length: Length of the vector from the atom to the guessed atom.

    Returns:
        tf.Tensor: The guessed positions, stacked along axis 1.

    """
    guessed = []
    for idx in indices:
        center = cartesians[:, idx]
        to_previous = cartesians[:, idx - 1] - center
        try:
            neighbor = cartesians[:, idx + 1]
        except tf.errors.InvalidArgumentError:
            neighbor = cartesians[:, idx - 2]
        to_next = neighbor - center

        normal = tf.linalg.cross(to_previous, to_next)
        normal = normal / tf.norm(normal, axis=1, keepdims=True)
        rotated = tf.matmul(
            tf.expand_dims(to_previous, 1),
            rotation_matrix(normal, angle_to_previous),
        )[:, 0, :]
        rotated = rotated * (bond_length / tf.norm(rotated, axis=1, keepdims=True))
        guessed.append(center + rotated)
    return tf.stack(guessed, axis=1)

1940 

1941 

def guess_amide_H(cartesians, N_indices):
    """Guesses amide hydrogen positions with `guess_sp2_atom`.

    The first entry of `N_indices` is left out. The hydrogens are placed with
    an angle of 123 degrees and a bond length of 1.10.

    Args:
        cartesians: The positions handed on to `guess_sp2_atom`.
        N_indices: The indices of the nitrogen atoms in `cartesians`.

    Returns:
        The result of `guess_sp2_atom` for all but the first nitrogen.

    """
    nitrogens_with_hydrogen = N_indices[1::]
    return guess_sp2_atom(
        cartesians,
        nitrogens_with_hydrogen,
        angle_to_previous=123 / 180 * pi,
        bond_length=1.10,
    )

1944 

1945 

def guess_amide_O(cartesians, C_indices):
    """Guesses carbonyl oxygen positions with `guess_sp2_atom`.

    The oxygens are placed with an angle of 121 degrees and a bond length
    of 1.24.

    Args:
        cartesians: The positions handed on to `guess_sp2_atom`.
        C_indices: The indices of the carbon atoms in `cartesians`.

    Returns:
        The result of `guess_sp2_atom` for the given carbons.

    """
    return guess_sp2_atom(
        cartesians,
        C_indices,
        angle_to_previous=121 / 180 * pi,
        bond_length=1.24,
    )

1948 

1949 

def rotation_matrix(axis_unit_vec, angle):
    """Builds one 3x3 rotation matrix per batch entry.

    The matrix is assembled as
    ``cos(angle) * I + sin(angle) * K + (1 - cos(angle)) * u u^T``,
    where ``u`` is the axis and ``K`` its cross-product matrix.

    Args:
        axis_unit_vec: Tensor indexed as `axis_unit_vec[:, 0..2]`, one axis
            per batch entry. As the name says, the axes are expected to be
            normalized; this is not checked here.
        angle: Tensor with the rotation angles; two trailing axes are added
            so that it broadcasts against the 3x3 matrices.

    Returns:
        tf.Tensor: The rotation matrices.

    """
    angle = tf.expand_dims(tf.expand_dims(angle, axis=-1), axis=-1)
    cos_angle = tf.cos(angle)
    sin_angle = tf.sin(angle)

    x = axis_unit_vec[:, 0]
    y = axis_unit_vec[:, 1]
    z = axis_unit_vec[:, 2]
    zeros = tf.zeros(tf.shape(axis_unit_vec)[0])
    # Assembled as (3, 3, batch) and then moved to (batch, 3, 3).
    skew = tf.convert_to_tensor(
        [
            [zeros, -z, y],
            [z, zeros, -x],
            [-y, x, zeros],
        ]
    )
    skew = tf.transpose(skew, [2, 0, 1])

    identity = tf.expand_dims(tf.eye(3), 0)
    axis_column = tf.expand_dims(axis_unit_vec, 2)
    outer = tf.matmul(axis_column, tf.transpose(axis_column, [0, 2, 1]))

    return cos_angle * identity + sin_angle * skew + (1 - cos_angle) * outer

1969 

1970 

def merge_cartesians(
    central_cartesians, N_indices, O_indices, H_cartesians, O_cartesians
):
    """Interleaves guessed H and O positions with the central chain.

    The central atoms are walked in order. After every central atom whose
    index is in `N_indices[1::]`, the next position from `H_cartesians` is
    inserted; otherwise, after every central atom whose index is in
    `O_indices`, the next position from `O_cartesians` is inserted. The first
    central atom is never followed by an inserted atom.

    Args:
        central_cartesians: Positions indexed as `central_cartesians[:, atom]`.
        N_indices: Indices of the central atoms that are followed by an H
            (except for the first entry).
        O_indices: Indices of the central atoms that are followed by an O.
        H_cartesians: Positions indexed as `H_cartesians[:, k]`.
        O_cartesians: Positions indexed as `O_cartesians[:, k]`.

    Returns:
        tf.Tensor: All positions stacked along axis 1.

    Raises:
        AssertionError: If the number of merged atoms is not the sum of the
            central, H and O atoms.

    """
    hydrogen_sites = N_indices[1::]
    merged = [central_cartesians[:, 0]]
    used_h = 0
    used_o = 0
    for atom in range(1, central_cartesians.shape[1]):
        merged.append(central_cartesians[:, atom])
        if atom in hydrogen_sites:
            merged.append(H_cartesians[:, used_h])
            used_h += 1
        elif atom in O_indices:
            merged.append(O_cartesians[:, used_o])
            used_o += 1
    merged = tf.stack(merged, axis=1)
    expected_atoms = (
        central_cartesians.shape[1] + H_cartesians.shape[1] + O_cartesians.shape[1]
    )
    assert merged.shape[1] == expected_atoms
    return merged

1991 

1992 

def dihedral_backmapping(
    pdb_path, dihedral_trajectory, rough_n_points=-1, sidechains=None
):
    """Creates a trajectory from a pdb file and a set of dihedral angles.

    Takes a pdb file with a peptide and simply rotates around the dihedral
    angle axes. In the result side-chains might overlap, but the backbone
    should turn out quite well.

    The backbone dihedrals are applied in the order: all psi, then all omega,
    then all phi angles. The angles are converted from radians to degrees
    before they are set.

    Args:
        pdb_path (str): Path to the pdb file.
        dihedral_trajectory: Array-like of shape
            (traj_length, number_of_dihedrals).
        rough_n_points (int): A step_size to select a subset of values from
            `dihedral_trajectory` is calculated by
            max(1, int(len(dihedral_trajectory) / rough_n_points)). With
            rough_n_points = -1 all values are used.
        sidechains: Optional sidechain dihedrals, subsampled with the same
            step_size. If it has twice as many columns as sidechain
            dihedrals were found, only every second column is used.

    Returns:
        MDAnalysis.Universe: The universe holding the new trajectory.

    """
    step_size = max(1, int(len(dihedral_trajectory) / rough_n_points))
    dihedral_trajectory = dihedral_trajectory[::step_size]
    with_sidechains = sidechains is not None
    if with_sidechains:
        sidechain_dihedral_trajectory = sidechains[::step_size]

    uni = mda.Universe(pdb_path, format="PDB")
    protein = uni.select_atoms("protein")

    # backbone: one pass over all residues per dihedral type
    dihedrals = []
    for selector in ("psi_selection", "omega_selection", "phi_selection"):
        for residue in protein.residues:
            selection = getattr(residue, selector)()
            if selection:
                dihedrals.append(selection)

    # sidechains: again one pass over all residues per dihedral type
    sidechain_dihedrals = []
    if with_sidechains:
        for residue in protein.residues:
            chi1 = residue.chi1_selection()
            if chi1:
                sidechain_dihedrals.append(chi1)

        for chi in ("chi2", "chi3", "chi4", "chi5"):
            for residue in protein.residues:
                if chi in residue.__dir__():
                    selection = getattr(residue, f"{chi}_selection")()
                    sidechain_dihedrals.append(selection)

        n_columns = sidechain_dihedral_trajectory.shape[1]
        if n_columns == len(sidechain_dihedrals) * 2:
            sidechain_dihedral_trajectory = sidechain_dihedral_trajectory[:, ::2]

    _expand_universe(uni, len(dihedral_trajectory))

    def _apply(selections, values):
        """Sets every dihedral in `selections` to its value (given in rad)."""
        for dihedral, value in zip(selections, values):
            _set_dihedral(dihedral, protein, value / (2 * pi) * 360)

    # Iterating over uni.trajectory moves the universe to the next frame, so
    # every row of angles is written to its own frame.
    if not with_sidechains:
        for dihedral_values, _frame in zip(dihedral_trajectory, uni.trajectory):
            _apply(dihedrals, dihedral_values)
        return uni

    frames = zip(dihedral_trajectory, sidechain_dihedral_trajectory, uni.trajectory)
    for dihedral_values, sidechain_dihedral_values, _frame in frames:
        _apply(dihedrals, dihedral_values)
        _apply(sidechain_dihedrals, sidechain_dihedral_values)
    return uni

2075 

2076 

def _set_dihedral(dihedral, atoms, angle):
    """Rotates a part of `atoms` so that `dihedral` takes the value `angle`.

    All atoms in `atoms[dihedral[2].id:]` are rotated by the difference
    between `angle` and the current value of the dihedral. The rotation axis
    points from the second to the third atom of the dihedral and passes
    through the position of the third atom.

    Args:
        dihedral: The four atoms of the dihedral. Must offer
            `dihedral.dihedral.value()` and atoms with `id` and `position`.
        atoms: Sliceable group of atoms with a `rotateby` method.
        angle: The target angle, in the unit of `dihedral.dihedral.value()`
            (the caller in this file passes degrees).

    """
    delta = angle - dihedral.dihedral.value()
    pivot = dihedral[2]
    axis = pivot.position - dihedral[1].position
    moving = atoms[pivot.id :]
    moving.rotateby(delta, axis, pivot.position)

2082 

2083 

def _expand_universe(universe, length):
    """Replaces the trajectory of `universe` by tiled copies of its frames.

    The positions of all atoms are collected over the current trajectory,
    repeated `length` times along the frame axis and loaded back into the
    universe as an in-memory trajectory.

    Args:
        universe: The universe to expand; it is modified in place.
        length: How often the collected frames are repeated.

    """
    collector = AnalysisFromFunction(
        lambda ag: ag.positions.copy(), universe.atoms
    )
    frames = collector.run().results["timeseries"]
    tiled = np.tile(frames, (length, 1, 1))
    universe.load_new(tiled, format=MemoryReader)