Coverage for encodermap/models/layers.py: 16%

430 statements  

coverage.py v7.4.1, created at 2025-05-15 21:06 +0200

1# -*- coding: utf-8 -*- 

2# encodermap/models/layers.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Module that implements custom layers. Mainly needed for handling periodicity, 

23backmapping or sparsity.""" 

24 

25 

26################################################################################ 

27# Imports 

28################################################################################ 

29 

30 

31# Future Imports at the top 

32from __future__ import annotations 

33 

34# Standard Library Imports 

35import itertools 

36from collections.abc import Sequence 

37from math import pi 

38from typing import Any, Optional, Type, TypeVar, Union 

39 

40# Third Party Imports 

41import numpy as np 

42import tensorflow as tf 

43from scipy.linalg import block_diag 

44from tensorflow.keras.layers import Concatenate, Dense, Lambda, Layer 

45 

46# Encodermap imports 

47from encodermap.encodermap_tf1.backmapping import chain_in_plane 

48from encodermap.loss_functions.loss_classes import testing 

49from encodermap.misc.backmapping import ( 

50 dihedrals_to_cartesian_tf_layers, 

51 split_and_reverse_cartesians, 

52 split_and_reverse_dihedrals, 

53) 

54from encodermap.misc.distances import pairwise_dist 

55from encodermap.parameters.parameters import ADCParameters, Parameters 

56 

57 

58################################################################################ 

59# Typing 

60################################################################################ 

61 

62 

63BackMapLayerTransformationsType = TypeVar( 

64 "BackMapLayerTransformationsType", 

65 bound="BackMapLayerTransformations", 

66) 

67PeriodicOutputType = TypeVar( 

68 "PeriodicOutputType", 

69 bound="PeriodicOutput", 

70) 

71MeanAnglesType = TypeVar( 

72 "MeanAnglesType", 

73 bound="MeanAngles", 

74) 

75EncoderMapBaseLayerType = TypeVar( 

76 "EncoderMapBaseLayerType", 

77 bound="EncoderMapBaseLayer", 

78) 

79BackMapLayerType = TypeVar( 

80 "BackMapLayerType", 

81 bound="BackMapLayer", 

82) 

83BackMapLayerWithSidechainsType = TypeVar( 

84 "BackMapLayerWithSidechainsType", 

85 bound="BackMapLayerWithSidechains", 

86) 

87 

88 

89################################################################################ 

90# Globals 

91################################################################################ 

92 

93 

94__all__: list[str] = ["PeriodicInput", "PeriodicOutput", "MeanAngles", "BackMapLayer"] 

95 

96 

97################################################################################ 

98# Layers 

99################################################################################ 

100 

101 

102@tf.keras.utils.register_keras_serializable() 

103class EncoderMapBaseLayer(Layer): 

104 """EncoderMap's Base Layer, that implements saving and loading parameters. 

105 

106 Classes that inherit from `EncoderMapBaseLayer` automatically receive 

107 parameters when deserialized. 

108 

109 """ 

110 

111 def __init__( 

112 self, 

113 parameters: Union[Parameters, ADCParameters], 

114 print_name: str, 

115 trainable: bool = False, 

116 ) -> None: 

117 """Instantiate the layer. 

118 

119 Args: 

120 parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of 

121 encodermap's parameters. 

122 print_name (str): The name of this layer, as it should appear 

123 in summaries. 

124 trainable (bool): Whether this layer is trainable. As this layer 

125 has no kernel and/or bias, this argument has no influence. 

126 Defaults to False. 

127 

128 """ 

129 super().__init__() 

130 self.p = parameters 

131 self.print_name = print_name 

132 self._name = print_name 

133 self.trainable = trainable 

134 

135 @classmethod 

136 def from_config( 

137 cls: Type[EncoderMapBaseLayerType], 

138 config: dict[Any, Any], 

139 ) -> EncoderMapBaseLayerType: 

140 """Reconstructs this keras serializable from a dict. 

141 

142 Args: 

143 config (dict[Any, Any]): A dictionary. 

144 

145 Returns: 

146 EncoderMapBaseLayerType: An instance of the EncoderMapBaseLayer. 

147 

148 """ 

149 p = config.pop("p") 

150 if "cartesian_pwd_start" in p: 

151 p = ADCParameters(**p) 

152 else: 

153 p = Parameters(**p) 

154 return cls(parameters=p, **config) 

155 

156 def get_config(self) -> dict[Any, Any]: 

157 """Serializes this keras serializable. 

158 

159 Returns: 

160 dict[Any, Any]: A dict with the serializable objects. 

161 

162 """ 

163 config = super().get_config().copy() 

164 config.update( 

165 { 

166 "print_name": self.print_name, 

167 "p": self.p.to_dict(), 

168 } 

169 ) 

170 return config 

171 

172 
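# --- Example (not part of the library) ---------------------------------------
# Minimal sketch of the get_config/from_config round trip implemented above.
# It assumes that `Parameters()` can be default-constructed and that it
# round-trips through `to_dict()`. The subclass `PeriodicInput` (defined below)
# is used for the demonstration, since the base layer defines no `call`.
def _example_base_layer_roundtrip() -> None:
    layer = PeriodicInput(Parameters(), print_name="periodic_in")
    config = layer.get_config()
    # `from_config` rebuilds the Parameters/ADCParameters instance from the
    # "p" entry before instantiating the layer again.
    restored = PeriodicInput.from_config(config)
    assert restored.print_name == "periodic_in"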

173@tf.keras.utils.register_keras_serializable() 

174class PeriodicInput(EncoderMapBaseLayer): 

175 """Layer that handles periodic input. Needed, if angles are treated. 

176 Input angles will be split into sin and cos components, 

177 and a tensor with shape[0] = 2 * inp_shape[0] will be returned 

178 """ 

179 

180 def __init__( 

181 self, 

182 parameters: Union[Parameters, ADCParameters], 

183 print_name: str, 

184 trainable: bool = False, 

185 **kwargs, 

186 ) -> None: 

187 """Instantiate the layer. Need parameters to get the info about the 

188 periodicity. Although angles are most often used, who knows what hyper-toroidal 

189 manifold your data lies in. 

190 

191 Args: 

192 parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of 

193 encodermap's parameters. 

194 print_name (str): The name of this layer, as it should appear 

195 in summaries. 

196 trainable (bool): Whether this layer is trainable. As this layer 

197 has no kernel and/or bias, this argument has no influence. 

198 Defaults to False. 

199 

200 """ 

201 super().__init__(parameters, print_name, trainable) 

202 

203 def call(self, inputs: tf.Tensor) -> tf.Tensor: 

204 """Call the layer.""" 

205 outputs = inputs 

206 if self.p.periodicity != 2 * pi: 

207 outputs = outputs / self.p.periodicity * 2 * pi 

208 outputs = Concatenate(axis=1, name=f"{self.print_name}_Concat")( 

209 [ 

210 tf.sin(outputs), 

211 tf.cos(outputs), 

212 ] 

213 ) 

214 return outputs 

215 

216 
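# --- Example (not part of the library) ---------------------------------------
# Sketch of what the layer above does to a batch of angles, assuming
# `Parameters()` default-constructs with periodicity = 2 * pi: the feature
# axis doubles because sin and cos are concatenated along axis 1.
def _example_periodic_input() -> None:
    layer = PeriodicInput(Parameters(), print_name="periodic_in")
    angles = tf.constant([[0.0, pi / 2, pi]], dtype=tf.float32)  # shape (1, 3)
    out = layer(angles)                                          # shape (1, 6)
    assert out.shape == (1, 6)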

217@tf.keras.utils.register_keras_serializable() 

218class BackMapLayerWithSidechains(Layer): 

219 """Also backmaps sidechains. For that, we need a way to know which 

220 distances, angles, dihedrals belong to the backbone, and which belong to 

221 a sidechain. See the docstring of `encodermap.misc.backmapping._full_backmapping_np` 

222 for details. 

223 

224 """ 

225 

226 def __init__( 

227 self, 

228 feature_description: Any, 

229 ) -> None: 

230 super().__init__() 

231 self.feature_description: dict[Any, Any] = feature_description 

232 

233 # Definitions and Tests 

234 n_residues: int = max(list(feature_description[-1].keys())) 

235 assert np.array_equal( 

236 np.arange(1, n_residues + 1), 

237 np.sort(np.asarray(list(feature_description[-1].keys()))), 

238 ), ( 

239 f"Currently the `feature_indices[-1]` dict needs to contain monotonous " 

240 f"increasing keys. Starting from 1 {feature_description[-1].keys()=}" 

241 ) 

242 n_sidechains: int = sum( 

243 [v + 1 for v in feature_description[-1].values() if v > 0] 

244 ) 

245 sum_sidechains = sum(list(feature_description[-1].values())) 

246 

247 # this can be defined beforehand and then stacked as often as a batch needs it 

248 self.init_xyz: tf.Tensor = tf.zeros( 

249 shape=(1, n_residues * 3 + n_sidechains, 3), 

250 dtype=tf.float32, 

251 ) 

252 

253 # first we create the central_distance indices 

254 central_distance_indices = np.tri( 

255 N=n_residues * 3 - 1, 

256 M=n_residues * 3, 

257 k=0, 

258 ).astype(bool) 

259 right_side_central_distance_indices = [ 

260 np.full(shape=(1, n_sidechains), fill_value=False, dtype=bool) 

261 ] 

262 count = 0 # starts at the first atom of the central chain 

263 count2 = n_residues * 3 + 1 # starts at the first atom of the sidechain 

264 sidechain_cartesians_ind = [] 

265 sidechain_positions_indices = [] 

266 central_angle_index_triplets = np.vstack( 

267 [ 

268 np.arange(0, n_residues * 3)[:-2], 

269 np.arange(0, n_residues * 3)[1:-1], 

270 np.arange(0, n_residues * 3)[2:], 

271 ] 

272 ).T.tolist() 

273 sidechain_angle_index_triplets = [] 

274 central_dihedral_index_quadruplets = np.vstack( 

275 [ 

276 np.arange(0, n_residues * 3)[:-3], 

277 np.arange(0, n_residues * 3)[1:-2], 

278 np.arange(0, n_residues * 3)[2:-1], 

279 np.arange(0, n_residues * 3)[3:], 

280 ] 

281 ).T.tolist() 

282 sidechain_dihedral_index_quadruplets = [] 

283 

284 # iterate over feature_description[-1] to get all indices and the right side 

285 # of the central cartesians 

286 for i, (residue, n_sidechains_in_residue) in zip( 

287 itertools.count(1, 3), feature_description[-1].items() 

288 ): 

289 if n_sidechains_in_residue == 0: 

290 if residue == 1 or residue == n_residues: 

291 continue 

292 else: 

293 right_side_central_distance_indices.append(t) 

294 else: 

295 sidechain_cartesians_ind.append( 

296 np.arange(count, count + n_sidechains_in_residue) 

297 ) 

298 sidechain_positions_indices.append( 

299 [i] 

300 + np.arange(count2 - 1, count2 + n_sidechains_in_residue).tolist() 

301 ) 

302 for sidechain_i in range(n_sidechains_in_residue + 1): 

303 if sidechain_i == 0: 

304 # adds N-CA-CB 

305 sidechain_angle_index_triplets.append( 

306 [(residue - 1) * 3, (residue - 1) * 3 + 1, count2 - 1] 

307 ) 

308 # adds N-CA-CB-CG 

309 sidechain_dihedral_index_quadruplets.append( 

310 [ 

311 (residue - 1) * 3, 

312 (residue - 1) * 3 + 1, 

313 count2 - 1, 

314 count2, 

315 ] 

316 ) 

317 elif sidechain_i == 1: 

318 # adds CA-CB-CG 

319 sidechain_angle_index_triplets.append( 

320 [(residue - 1) * 3 + 1, count2 - 1, count2] 

321 ) 

322 # adds CA-CB-CG-CD 

323 if sidechain_i < n_sidechains_in_residue: 

324 sidechain_dihedral_index_quadruplets.append( 

325 [(residue - 1) * 3 + 1, count2 - 1, count2, count2 + 1] 

326 ) 

327 else: 

328 # adds CB-CG-CD and so on 

329 sidechain_angle_index_triplets.append( 

330 [ 

331 count2 + sidechain_i - 3, 

332 count2 + sidechain_i - 2, 

333 count2 + sidechain_i - 1, 

334 ] 

335 ) 

336 if sidechain_i < n_sidechains_in_residue: 

337 sidechain_dihedral_index_quadruplets.append( 

338 [ 

339 count2 + sidechain_i - 3, 

340 count2 + sidechain_i - 2, 

341 count2 + sidechain_i - 1, 

342 count2 + sidechain_i, 

343 ] 

344 ) 

345 count += n_sidechains_in_residue + 1 

346 count2 += n_sidechains_in_residue + 1 

347 t = np.zeros( 

348 shape=(3, n_sidechains), 

349 dtype=bool, 

350 ) 

351 t[:, :count] = True 

352 right_side_central_distance_indices.append(t) 

353 assert len(sidechain_angle_index_triplets) == n_sidechains 

354 assert len(sidechain_dihedral_index_quadruplets) == sum_sidechains, ( 

355 f"I could not reconstruct the correct number of sidechain dihedral " 

356 f"quadruplets. The number of sidechain dihedrals requires the list " 

357 f"to have length {sum_sidechains}, but I created a list with " 

358 f"{len(sidechain_dihedral_index_quadruplets)}." 

359 ) 

360 right_side_central_distance_indices.append( 

361 np.full(shape=(1, n_sidechains), fill_value=True, dtype=bool) 

362 ) 

363 right_side_central_distance_indices = np.vstack( 

364 right_side_central_distance_indices 

365 ) 

366 angle_index_triplets = np.vstack( 

367 central_angle_index_triplets + sidechain_angle_index_triplets 

368 ) 

369 dihedral_index_quadruplets = np.vstack( 

370 central_dihedral_index_quadruplets + sidechain_dihedral_index_quadruplets 

371 ) 

372 if sidechain_cartesians_ind != []: # if sidechains 

373 _use_sidechains = True 

374 sidechain_cartesians_ind = np.concatenate(sidechain_cartesians_ind) 

375 central_distance_indices = np.hstack( 

376 [central_distance_indices, right_side_central_distance_indices] 

377 ) 

378 side_distance_indices = [ 

379 (np.tri(N=i + 1, M=i + 2, k=0) + 1)[:, 1:] 

380 for i in feature_description[-1].values() 

381 if i > 0 

382 ] 

383 side_distance_indices = (block_diag(*side_distance_indices) % 2) == 0 

384 left_side_side_distance_indices = ( 

385 np.full( # all atoms in the central chain are True 

386 shape=(len(side_distance_indices), n_residues * 3), 

387 fill_value=True, 

388 dtype=bool, 

389 ) 

390 ) 

391 side_distance_indices = np.hstack( 

392 [left_side_side_distance_indices, side_distance_indices] 

393 ) 

394 distance_indices = np.vstack( 

395 [central_distance_indices, side_distance_indices] 

396 ) 

397 else: # if no sidechains 

398 _use_sidechains = False 

399 distance_indices = central_distance_indices 

400 assert distance_indices.shape == ( 

401 n_residues * 3 - 1 + n_sidechains, 

402 self.init_xyz.shape[1], 

403 ), ( 

404 f"The shape of the distance index after stacking is unexpected.\n" 

405 f"Expected: {(n_residues * 3 - 1 + n_sidechains, self.init_xyz.shape[1])}\n" 

406 f"Actual: {distance_indices.shape}" 

407 ) 

408 

409 # now the angles 

410 central_angle_indices = central_distance_indices[1:] 

411 if _use_sidechains: # if sidechains 

412 angle_indices = np.vstack( 

413 [central_distance_indices[1:], side_distance_indices] 

414 ) 

415 side_angle_indices = side_distance_indices 

416 else: # no sidechains 

417 angle_indices = central_distance_indices[1:] 

418 assert len(angle_indices) == len(distance_indices) - 1 

419 

420 # and the dihedrals 

421 if _use_sidechains: # if sidechains 

422 dihedral_indices = np.vstack( 

423 [ 

424 central_distance_indices[1:-1], 

425 side_distance_indices[sidechain_cartesians_ind], 

426 ] 

427 ) 

428 corrector = np.count_nonzero( 

429 list(feature_description[-1].values()) 

430 ) # one fewer dihedral per residue that has a sidechain 

431 else: 

432 dihedral_indices = central_distance_indices[1:-1] 

433 corrector = 0 

434 assert len(dihedral_indices) == len(distance_indices) - 2 - corrector 

435 assert angle_index_triplets.shape[0] == angle_indices.shape[0] 

436 assert dihedral_index_quadruplets.shape[0] == dihedral_indices.shape[0], ( 

437 f"The number of dihedral indices ({len(distance_indices)}) and quadruplets " 

438 f"does not match ({len(dihedral_index_quadruplets)})." 

439 ) 

440 

441 # create instance attributes with tf 

442 self._use_sidechains = tf.constant(_use_sidechains, dtype=tf.bool) 

443 self.n_sidechains_in_residue = tf.constant( 

444 np.array( 

445 [ 

446 self.feature_description[-1][k] 

447 for k in sorted(self.feature_description[-1].keys()) 

448 ] 

449 ).astype(np.int32), 

450 dtype=tf.int32, 

451 ) 

452 

453 # general 

454 self.up = tf.constant([[0.0, 0.0, 1.0]], dtype=tf.float32) 

455 self.down = tf.constant([[0.0, 0.0, -1.0]], dtype=tf.float32) 

456 

457 # distances 

458 self.no_of_central_distances = central_distance_indices.shape[0] 

459 self.n_sidechains = n_sidechains 

460 self.central_distance_indices = tf.constant( 

461 central_distance_indices, 

462 shape=central_distance_indices.shape, 

463 dtype=tf.bool, 

464 ) 

465 self.n_atoms = self.central_distance_indices.shape[1] 

466 

467 # angles 

468 self.no_of_central_angles = central_angle_indices.shape[0] 

469 self.no_of_side_angles = side_angle_indices.shape[0] 

470 self.central_angle_indices = tf.constant( 

471 central_angle_indices, shape=central_angle_indices.shape, dtype=tf.bool 

472 ) 

473 self.side_angle_indices = tf.constant( 

474 side_angle_indices, shape=side_angle_indices.shape, dtype=tf.bool 

475 ) 

476 self.central_angle_index_triplets = tf.constant( 

477 np.asarray(central_angle_index_triplets), 

478 shape=(len(central_angle_index_triplets), 3), 

479 dtype=tf.int32, 

480 ) 

481 self.sidechain_angle_index_triplets = tf.constant( 

482 np.asarray(sidechain_angle_index_triplets), 

483 shape=(len(sidechain_angle_index_triplets), 3), 

484 dtype=tf.int32, 

485 ) 

486 

487 # dihedrals 

488 self.no_of_dihedrals = dihedral_indices.shape[0] 

489 self.dihedral_indices = tf.constant( 

490 dihedral_indices, shape=dihedral_indices.shape, dtype=tf.bool 

491 ) 

492 self.dihedral_index_quadruplets = tf.constant( 

493 np.asarray(dihedral_index_quadruplets), 

494 shape=(len(dihedral_index_quadruplets), 4), 

495 dtype=tf.int32, 

496 ) 

497 

498 def get_config(self) -> dict[Any, Any]: 

499 """Serializes this keras serializable. 

500 

501 Returns: 

502 dict[Any, Any]: A dict with the serializable objects. 

503 

504 """ 

505 config = super().get_config().copy() 

506 config.update( 

507 { 

508 "feature_description": self.feature_description, 

509 } 

510 ) 

511 return config 

512 

513 @classmethod 

514 def from_config( 

515 cls: Type[BackMapLayerWithSidechainsType], 

516 config: dict[Any, Any], 

517 ) -> BackMapLayerWithSidechainsType: 

518 """Reconstructs this keras serializable from a dict. 

519 

520 Args: 

521 config (dict[Any, Any]): A dictionary. 

522 

523 Returns: 

524 BackMapLayerWithSidechainsType: An instance of the BackMapLayerWithSidechains. 

525 

526 """ 

527 feature_description = config.pop("feature_description") 

528 out = {int(k): v for k, v in feature_description.items()} 

529 for k, v in out.items(): 

530 out[k] = {int(kv): vv for kv, vv in v.items()} 

531 return cls(feature_description=out) 

532 

533 def call(self, inputs: tuple[tf.Tensor, ...]) -> tf.Tensor: 

534 # Unpack inputs 

535 ( 

536 central_distances, 

537 central_angles, 

538 central_dihedrals, 

539 side_distances, 

540 side_angles, 

541 side_dihedrals, 

542 ) = inputs 

543 

544 # concatenate the dihedrals 

545 dihedrals = tf.concat( 

546 [ 

547 central_dihedrals, 

548 side_dihedrals, 

549 ], 

550 axis=1, 

551 ) 

552 

553 # distances 

554 xs_central = tf.TensorArray( 

555 dtype=tf.float32, 

556 size=self.no_of_central_distances + 1, 

557 clear_after_read=False, 

558 ) 

559 ys_central = tf.TensorArray( 

560 dtype=tf.float32, 

561 size=self.no_of_central_distances + 1, 

562 clear_after_read=False, 

563 ) 

564 xs_side = tf.TensorArray( 

565 dtype=tf.float32, 

566 size=self.n_sidechains, 

567 clear_after_read=False, 

568 ) 

569 ys_side = tf.TensorArray( 

570 dtype=tf.float32, 

571 size=self.n_sidechains, 

572 clear_after_read=False, 

573 ) 

574 xs_central = xs_central.write( 

575 0, tf.zeros((tf.shape(central_angles)[0],), dtype=tf.float32) 

576 ) 

577 ys_central = ys_central.write( 

578 0, tf.zeros((tf.shape(central_angles)[0],), dtype=tf.float32) 

579 ) 

580 # xs_central = [tf.zeros((tf.shape(central_angles)[0], ), dtype=tf.float32)] 

581 # ys_central = [tf.zeros((tf.shape(central_angles)[0], ), dtype=tf.float32)] 

582 # xs_side = [] 

583 # ys_side = [] 

584 

585 residue = 0 

586 idx = 0 

587 j = 0 

588 for i in range(self.no_of_central_distances): 

589 # xs_central.append(xs[-1] + central_distances[:, 1]) 

590 xs_central = xs_central.write( 

591 i + 1, xs_central.read(i) + central_distances[:, i] 

592 ) 

593 # ys_central.append(tf.zeros((tf.shape(central_angles)[0], ), dtype=tf.float32)) 

594 ys_central = ys_central.write( 

595 i + 1, tf.zeros((tf.shape(central_angles)[0],)) 

596 ) 

597 if idx == 0 and self._use_sidechains: 

598 n_sidechains = self.n_sidechains_in_residue[residue] 

599 if n_sidechains > 0: 

600 for n in range(n_sidechains + 1): 

601 # xs_side.append(xs_central.read(i)) 

602 xs_side = xs_side.write(j, xs_central.read(i + 1)) 

603 # ys_side.append( 

604 # tf.reduce_sum( 

605 # side_distances[:, j - n : j + 1], 

606 # axis=1, 

607 # ) 

608 # ) 

609 ys_side = ys_side.write( 

610 j, 

611 tf.reduce_sum( 

612 side_distances[:, j - n : j + 1], 

613 axis=1, 

614 ), 

615 ) 

616 j += 1 

617 idx += 1 

618 if idx >= 3: 

619 residue += 1 

620 idx = 0 

621 xs_central = tf.transpose(xs_central.stack(), perm=[1, 0]) 

622 ys_central = tf.transpose(ys_central.stack(), perm=[1, 0]) 

623 xs_side = tf.transpose(xs_side.stack(), perm=[1, 0]) 

624 ys_side = tf.transpose(ys_side.stack(), perm=[1, 0]) 

625 xs_side.set_shape((xs_central.shape[0], self.n_sidechains)) 

626 ys_side.set_shape((xs_central.shape[0], self.n_sidechains)) 

627 xs = tf.concat([xs_central, xs_side], axis=1) 

628 ys = tf.concat([ys_central, ys_side], axis=1) 

629 xyz_out = tf.stack( 

630 [ 

631 xs, 

632 ys, 

633 ], 

634 axis=2, 

635 ) 

636 xyz_out = tf.pad( 

637 tf.pad( 

638 xyz_out, 

639 ((0, 0), (0, 0), (0, 1)), 

640 constant_values=0, 

641 ), 

642 paddings=((0, 0), (0, 0), (0, 1)), 

643 constant_values=1, 

644 ) 

645 

646 # angles 

647 # Can't parallelize over angles (just over batch dimension) 

648 # because xyz_out is updated constantly and thus 

649 # xyz_out[..., -1] changes during iteration 

650 for i in range(self.no_of_central_angles): 

651 ind = self.central_angle_indices[i] 

652 angle_index = self.central_angle_index_triplets[i] 

653 ang = central_angles[:, i] 

654 direction = tf.repeat( 

655 self.up, 

656 repeats=tf.shape(ang)[0], 

657 axis=0, 

658 ) 

659 abc = tf.transpose( 

660 tf.gather( 

661 params=xyz_out, 

662 indices=angle_index, 

663 axis=1, 

664 batch_dims=0, 

665 )[..., :3], 

666 perm=[1, 0, 2], 

667 ) 

668 ba = abc[0] - abc[1] 

669 bc = abc[2] - abc[1] 

670 dot = tf.keras.backend.batch_dot( 

671 ba, 

672 bc, 

673 ) 

674 prod = tf.expand_dims(_batch_fro(ba) * _batch_fro(bc), axis=1) 

675 t = tf.clip_by_value(dot / prod, clip_value_min=-1, clip_value_max=1) 

676 current_angle = tf.squeeze(tf.acos(t)) 

677 angle = tf.abs(ang - current_angle) 

678 rotmat = _rotation_matrices( 

679 angle=angle, 

680 direction=direction, 

681 point=abc[1], 

682 ) 

683 dynamic = tf.transpose( 

684 tf.gather( 

685 params=xyz_out, indices=tf.where(~ind)[:, 0], axis=1, batch_dims=0 

686 ), 

687 perm=[0, 2, 1], 

688 ) 

689 rotated = tf.transpose( 

690 tf.keras.backend.batch_dot(rotmat, dynamic), 

691 perm=[0, 2, 1], 

692 ) 

693 static = tf.gather( 

694 params=xyz_out, indices=tf.where(ind)[:, 0], axis=1, batch_dims=0 

695 ) 

696 new = tf.TensorArray( 

697 dtype=tf.float32, 

698 size=self.no_of_central_distances + 1 + self.n_sidechains, 

699 clear_after_read=False, 

700 ) 

701 d = 0 

702 s = 0 

703 c = 0 

704 for j in ind: 

705 if j: 

706 new = new.write(c, static[:, s]) 

707 s += 1 

708 else: 

709 new = new.write(c, rotated[:, d]) 

710 d += 1 

711 c += 1 

712 xyz_out = tf.transpose( 

713 new.stack(), 

714 perm=[1, 0, 2], 

715 ) 

716 

717 # sidechains 

718 for i in range(self.no_of_side_angles): 

719 ind = self.side_angle_indices[i] 

720 angle_index = self.sidechain_angle_index_triplets[i] 

721 ang = side_angles[:, i] 

722 direction = tf.repeat( 

723 self.down, 

724 repeats=tf.shape(ang)[0], 

725 axis=0, 

726 ) 

727 abc = tf.transpose( 

728 tf.gather( 

729 params=xyz_out, 

730 indices=angle_index, 

731 axis=1, 

732 batch_dims=0, 

733 )[..., :3], 

734 perm=[1, 0, 2], 

735 ) 

736 ba = abc[0] - abc[1] 

737 bc = abc[2] - abc[1] 

738 dot = tf.keras.backend.batch_dot( 

739 ba, 

740 bc, 

741 ) 

742 prod = tf.expand_dims(_batch_fro(ba) * _batch_fro(bc), axis=1) 

743 t = tf.clip_by_value(dot / prod, clip_value_min=-1, clip_value_max=1) 

744 current_angle = tf.squeeze(tf.acos(t)) 

745 angle = tf.abs(ang - current_angle) 

746 rotmat = _rotation_matrices( 

747 angle=angle, 

748 direction=direction, 

749 point=abc[1], 

750 ) 

751 dynamic = tf.transpose( 

752 tf.gather( 

753 params=xyz_out, indices=tf.where(~ind)[:, 0], axis=1, batch_dims=0 

754 ), 

755 perm=[0, 2, 1], 

756 ) 

757 rotated = tf.transpose( 

758 tf.keras.backend.batch_dot(rotmat, dynamic), 

759 perm=[0, 2, 1], 

760 ) 

761 static = tf.gather( 

762 params=xyz_out, indices=tf.where(ind)[:, 0], axis=1, batch_dims=0 

763 ) 

764 new = tf.TensorArray( 

765 dtype=tf.float32, 

766 size=self.no_of_central_distances + 1 + self.n_sidechains, 

767 clear_after_read=False, 

768 ) 

769 d = 0 

770 s = 0 

771 c = 0 

772 for j in ind: 

773 if j: 

774 new = new.write(c, static[:, s]) 

775 s += 1 

776 else: 

777 new = new.write(c, rotated[:, d]) 

778 d += 1 

779 c += 1 

780 xyz_out = tf.transpose( 

781 new.stack(), 

782 perm=[1, 0, 2], 

783 ) 

784 

785 # dihedrals 

786 for i in range(self.no_of_dihedrals): 

787 ind = self.dihedral_indices[i] 

788 dihedral_index = self.dihedral_index_quadruplets[i] 

789 ang = dihedrals[:, i] 

790 abcd = tf.transpose( 

791 tf.gather( 

792 params=xyz_out, 

793 indices=dihedral_index, 

794 axis=1, 

795 batch_dims=0, 

796 )[..., :3], 

797 perm=[1, 0, 2], 

798 ) 

799 direction = abcd[2] - abcd[1] 

800 b1 = abcd[1] - abcd[0] 

801 b2 = abcd[2] - abcd[1] 

802 b3 = abcd[3] - abcd[2] 

803 c1 = tf.linalg.cross(b2, b3) 

804 c2 = tf.linalg.cross(b1, b2) 

805 p1 = tf.reduce_sum((b1 * c1), axis=1) 

806 p1 *= tf.sqrt(tf.reduce_sum((b2 * b2), axis=1)) 

807 p2 = tf.reduce_sum((c1 * c2), axis=1) 

808 current_angle = tf.atan2(p1, p2) 

809 angle = ang - current_angle 

810 rotmat = _rotation_matrices( 

811 angle=angle, 

812 direction=direction, 

813 point=abcd[1], 

814 ) 

815 dynamic = tf.transpose( 

816 tf.gather( 

817 params=xyz_out, indices=tf.where(~ind)[:, 0], axis=1, batch_dims=0 

818 ), 

819 perm=[0, 2, 1], 

820 ) 

821 rotated = tf.transpose( 

822 tf.keras.backend.batch_dot(rotmat, dynamic), 

823 perm=[0, 2, 1], 

824 ) 

825 static = tf.gather( 

826 params=xyz_out, indices=tf.where(ind)[:, 0], axis=1, batch_dims=0 

827 ) 

828 new = tf.TensorArray( 

829 dtype=tf.float32, 

830 size=self.no_of_central_distances + 1 + self.n_sidechains, 

831 clear_after_read=False, 

832 ) 

833 d = 0 

834 s = 0 

835 c = 0 

836 for j in ind: 

837 if j: 

838 new = new.write(c, static[:, s]) 

839 s += 1 

840 else: 

841 new = new.write(c, rotated[:, d]) 

842 d += 1 

843 c += 1 

844 xyz_out = tf.transpose( 

845 new.stack(), 

846 perm=[1, 0, 2], 

847 ) 

848 

849 return xyz_out[..., :3] 

850 

851 
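# --- Example (not part of the library) ---------------------------------------
# Hypothetical sketch of the `feature_description` mapping consumed by
# `BackMapLayerWithSidechains` above (see
# `encodermap.misc.backmapping._full_backmapping_np` for the authoritative
# description): key -1 maps residue numbers (starting at 1) to the number of
# sidechain dihedrals in that residue. The values used here are made up.
def _example_feature_description() -> None:
    feature_description = {-1: {1: 0, 2: 2, 3: 0}}
    # The same quantities the layer derives in its __init__:
    n_residues = max(feature_description[-1].keys())                              # 3
    n_sidechains = sum(v + 1 for v in feature_description[-1].values() if v > 0)  # 3
    # The layer pre-allocates one xyz row per backbone atom (3 per residue)
    # plus one per sidechain atom: n_residues * 3 + n_sidechains rows.
    assert (n_residues * 3 + n_sidechains) == 12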

852@tf.function 

853def _batch_fro(a: tf.Tensor) -> tf.Tensor: 

854 """Batch-wise Frobert norm, a.k.a. length of a vector.""" 

855 return tf.sqrt(tf.reduce_sum(a**2, axis=1)) 

856 

857 
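# --- Example (not part of the library) ---------------------------------------
# Illustrative check for `_batch_fro` above: the batch-wise vector norm matches
# numpy's Euclidean norm taken along the last axis.
def _example_batch_fro() -> None:
    a = tf.constant([[3.0, 4.0], [1.0, 0.0]], dtype=tf.float32)
    expected = np.linalg.norm(a.numpy(), axis=1)  # [5.0, 1.0]
    assert np.allclose(_batch_fro(a).numpy(), expected)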

858@tf.function 

859def _rotation_matrices(angle, direction, point) -> tf.Tensor: 

860 """Adapted from C. Gohlke's transformations.py. 

861 

862 Batch-wise 4x4 rotation matrices. 

863 

864 """ 

865 sina = tf.sin(angle) 

866 cosa = tf.cos(angle) 

867 direction_u = _unit_vector(direction) 

868 

869 # rotation matrix around unit vector 

870 R = tf.linalg.diag(tf.transpose([cosa, cosa, cosa]), k=0) 

871 R += tf.einsum("ki,kj->kij", direction_u, direction_u) * tf.expand_dims( 

872 tf.expand_dims(1.0 - cosa, -1), -1 

873 ) 

874 direction_u *= tf.expand_dims(sina, -1) 

875 

876 R_add = tf.TensorArray( 

877 dtype=tf.float32, 

878 size=tf.shape(angle)[0], 

879 clear_after_read=False, 

880 ) 

881 for i in range(tf.shape(angle)[0]): 

882 d = direction_u[i] 

883 R_add = R_add.write( 

884 i, [[0.0, -d[2], d[1]], [d[2], 0.0, -d[0]], [-d[1], d[0], 0.0]] 

885 ) 

886 R_add = R_add.stack() 

887 R += R_add 

888 R.set_shape((angle.shape[0], 3, 3)) 

889 

890 # rotation around origin 

891 test = tf.expand_dims( 

892 tf.pad( 

893 point - tf.keras.backend.batch_dot(R, point), 

894 paddings=((0, 0), (0, 1)), 

895 constant_values=1, 

896 ), 

897 axis=-1, 

898 ) 

899 R = tf.pad(R, ((0, 0), (0, 1), (0, 0))) 

900 M = tf.concat([R, test], axis=2) 

901 return M 

902 

903 
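# --- Example (not part of the library) ---------------------------------------
# Illustrative use of `_rotation_matrices` above, mirroring how `call` applies
# the returned homogeneous matrices with batch_dot: rotating the point
# (1, 0, 0) by pi/2 around the z-axis through the origin should give
# approximately (0, 1, 0).
def _example_rotation_matrices() -> None:
    angle = tf.constant([pi / 2], dtype=tf.float32)
    direction = tf.constant([[0.0, 0.0, 1.0]], dtype=tf.float32)
    point = tf.constant([[0.0, 0.0, 0.0]], dtype=tf.float32)
    rotmat = _rotation_matrices(angle=angle, direction=direction, point=point)
    # homogeneous column vector for the point (1, 0, 0)
    xyz1 = tf.constant([[[1.0], [0.0], [0.0], [1.0]]], dtype=tf.float32)
    rotated = tf.keras.backend.batch_dot(rotmat, xyz1)
    assert np.allclose(rotated.numpy()[0, :3, 0], [0.0, 1.0, 0.0], atol=1e-6)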

904@tf.function 

905def _unit_vector(vector: tf.Tensor) -> tf.Tensor: 

906 """Adapted from C. Gohlke's transformations.py""" 

907 length = tf.sqrt(tf.reduce_sum(vector**2, axis=1)) 

908 return vector / tf.expand_dims(length, 1) 

909 

910 

911@tf.keras.utils.register_keras_serializable() 

912class BackMapLayer(Layer): 

913 """Layer that implements backmapping from torsions-angles-distances to Euclidean coordinates.""" 

914 

915 def __init__(self, left_split: int, right_split: int) -> None: 

916 """Instantiate the layer.""" 

917 super().__init__() 

918 self._name = "BackmapLayer" 

919 self.left_split = left_split 

920 self.right_split = right_split 

921 

922 @classmethod 

923 def from_config( 

924 cls: Type[BackMapLayerType], 

925 config: dict[Any, Any], 

926 ) -> BackMapLayerType: 

927 """Reconstructs this keras serializable from a dict. 

928 

929 Args: 

930 config (dict[Any, Any]): A dictionary. 

931 

932 Returns: 

933 BackMapLayerType: An instance of the BackMapLayer. 

934 

935 """ 

936 left_split = config.pop("left_split") 

937 right_split = config.pop("right_split") 

938 return cls(left_split=left_split, right_split=right_split) 

939 

940 def get_config(self) -> dict[Any, Any]: 

941 """Serializes this keras serializable. 

942 

943 Returns: 

944 dict[Any, Any]: A dict with the serializable objects. 

945 

946 """ 

947 config = super().get_config().copy() 

948 config.update( 

949 { 

950 "left_split": self.left_split, 

951 "right_split": self.right_split, 

952 } 

953 ) 

954 return config 

955 

956 def call(self, inputs: tuple[tf.Tensor, tf.Tensor, tf.Tensor]) -> tf.Tensor: 

957 """Call the layers, inputs should be a tuple shaped, so that it can be split into 

958 distances, angles, dihedrals = inputs 

959 """ 

960 distances, angles, dihedrals = inputs 

961 # I don't know how negative distances can arrive at this step 

962 # but we replace them with the mean 

963 fixed_distances = tf.where( 

964 distances < 0.00001, 

965 tf.ones_like(distances) * tf.reduce_mean(distances), 

966 distances, 

967 ) 

968 # mean lengths 

969 # back_mean_lengths = tf.expand_dims(tf.reduce_mean(inp_distances, 0), 0) 

970 out = tf.expand_dims(tf.reduce_mean(distances, 0), 0) 

971 

972 # chain in plane 

973 # back_chain_in_plane = chain_in_plane(back_mean_lengths, out_angles) 

974 out = chain_in_plane(out, angles) 

975 

976 # dihedrals to cartesian 

977 # back_cartesians = dihedrals_to_cartesian_tf(out_dihedrals + pi, back_chain_in_plane) 

978 out_dihedrals = tf.add(dihedrals, pi) 

979 out = dihedrals_to_cartesian_tf_layers( 

980 out_dihedrals, 

981 out, 

982 left_iteration_counter=self.left_split, 

983 right_iteration_counter=self.right_split, 

984 ) 

985 return out 

986 

987 
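# --- Example (not part of the library) ---------------------------------------
# Sketch of the get_config/from_config round trip of `BackMapLayer` above.
# The split values are placeholders; in practice they come from the topology
# that is backmapped.
def _example_backmap_layer_roundtrip() -> None:
    layer = BackMapLayer(left_split=3, right_split=4)
    restored = BackMapLayer.from_config(layer.get_config())
    assert (restored.left_split, restored.right_split) == (3, 4)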

988@tf.keras.utils.register_keras_serializable() 

989class BackMapLayerTransformations(Layer): 

990 """Experimental layer for using multimers with the ADCEMap.""" 

991 

992 def __init__(self, protein_lengths: Sequence[int]) -> None: 

993 """Instantiate the layer. 

994 

995 Args: 

996 protein_lengths (Sequence[int]): The lengths of the proteins in the 

997 multimers. Based on this information, the input to `self.call` 

998 will be split. 

999 

1000 """ 

1001 self.protein_lengths = protein_lengths 

1002 super().__init__() 

1003 self._name = "BackmapLayerTransformations" 

1004 

1005 @classmethod 

1006 def from_config( 

1007 cls: Type[BackMapLayerTransformationsType], 

1008 config: dict[Any, Any], 

1009 ) -> BackMapLayerTransformationsType: 

1010 """Reconstructs this keras serializable from a dict. 

1011 

1012 Args: 

1013 config (dict[Any, Any]): A dictionary. 

1014 

1015 Returns: 

1016 BackMapLayerTransformationsType: An instance of the BackMapLayerTransformations. 

1017 

1018 """ 

1019 protein_lengths = config.pop("protein_lengths") 

1020 return cls(protein_lengths=protein_lengths, **config) 

1021 

1022 def get_config(self) -> dict[Any, Any]: 

1023 """Serializes this keras serializable. 

1024 

1025 Returns: 

1026 dict[Any, Any]: A dict with the serializable objects. 

1027 

1028 """ 

1029 config = super().get_config().copy() 

1030 config.update( 

1031 { 

1032 "protein_lengths": self.protein_lengths, 

1033 } 

1034 ) 

1035 return config 

1036 

1037 def call(self, inputs): 

1038 """Call the layers, inputs should be a tuple shaped, so that it can be split into 

1039 distances, angles, dihedrals, matrices = inputs 

1040 """ 

1041 # Third Party Imports 

1042 from tensorflow_graphics.rendering.utils import transform_homogeneous 

1043 

1044 inp_distances, out_angles, out_dihedrals, matrices = inputs 

1045 

1046 out_cartesians = [] 

1047 current_length = 0 

1048 for i, protein_length in enumerate(self.protein_lengths): 

1049 if current_length == 0: 

1050 distance_ind = slice(0, protein_length * 3 - 1) 

1051 angle_ind = slice(0, protein_length * 3 - 2) 

1052 dihe_ind = slice(0, protein_length * 3 - 3) 

1053 else: 

1054 distance_ind = slice( 

1055 current_length * 3 - i, 

1056 current_length * 3 + protein_length * 3 - (i + 1), 

1057 ) 

1058 angle_ind = slice( 

1059 current_length * 3 - (i + 1), 

1060 current_length * 3 + protein_length * 3 - (i + 2), 

1061 ) 

1062 dihe_ind = slice( 

1063 current_length * 3 - (i + 2), 

1064 current_length * 3 + protein_length * 3 - (i + 3), 

1065 ) 

1066 current_length += protein_length 

1067 

1068 # index 

1069 current_lengths = inp_distances[:, distance_ind] 

1070 current_lengths = tf.expand_dims(tf.reduce_mean(current_lengths, 0), 0) 

1071 current_angles = out_angles[:, angle_ind] 

1072 current_dihedrals = out_dihedrals[:, dihe_ind] 

1073 current_dihedrals = tf.add(current_dihedrals, pi) 

1074 

1075 c = chain_in_plane(current_lengths, current_angles) 

1076 c = dihedrals_to_cartesian_tf_layers(current_dihedrals, c) 

1077 

1078 # for all other proteins apply homogeneous transformation matrices 

1079 if i != 0: 

1080 m = matrices[:, i - 1] 

1081 c = transform_homogeneous(m, c)[..., :3] 

1082 

1083 out_cartesians.append(c) 

1084 

1085 out_cartesians = tf.concat( 

1086 out_cartesians, 

1087 axis=1, 

1088 ) 

1089 return out_cartesians 

1090 

1091 
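# --- Example (not part of the library) ---------------------------------------
# Sketch of how `BackMapLayerTransformations.call` above partitions the flat
# distance features for two hypothetical chains of 4 and 5 residues: the first
# chain uses 3*4 - 1 = 11 distances, the second one the next 3*5 - 1 = 14,
# offset as in the loop above.
def _example_multimer_distance_slices() -> None:
    protein_lengths = [4, 5]
    current_length, distance_slices = 0, []
    for i, protein_length in enumerate(protein_lengths):
        if current_length == 0:
            distance_slices.append(slice(0, protein_length * 3 - 1))
        else:
            distance_slices.append(
                slice(
                    current_length * 3 - i,
                    current_length * 3 + protein_length * 3 - (i + 1),
                )
            )
        current_length += protein_length
    assert distance_slices == [slice(0, 11), slice(11, 25)]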

1092@tf.keras.utils.register_keras_serializable() 

1093class PeriodicOutput(EncoderMapBaseLayer): 

1094 """Layer that reverses the PeriodicInputLayer.""" 

1095 

1096 def __init__( 

1097 self, 

1098 parameters: Union[Parameters, ADCParameters], 

1099 print_name: str, 

1100 trainable: bool = False, 

1101 **kwargs, 

1102 ) -> None: 

1103 """Instantiate the layer, We also need to know here what periodicity is needed. 

1104 

1105 Args: 

1106 parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of 

1107 encodermap's parameters. 

1108 print_name (str): The name of this layer, as it should appear 

1109 in summaries. 

1110 trainable (bool): Whether this layer is trainable. As this layer 

1111 has no kernel and/or bias, this argument has no influence. 

1112 Defaults to False. 

1113 

1114 """ 

1115 super().__init__(parameters, print_name, trainable) 

1116 

1117 def call(self, inputs): 

1118 """Calls the layer. Inputs should be a tuple of (sin, cos) of the same angles""" 

1119 outputs = inputs 

1120 outputs = tf.atan2(*tf.split(outputs, 2, 1)) 

1121 

1122 if self.p.periodicity != 2 * pi: 

1123 outputs = outputs / (2 * pi) * self.p.periodicity 

1124 return outputs 

1125 

1126 
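# --- Example (not part of the library) ---------------------------------------
# Sketch (assuming `Parameters()` default-constructs): encoding a batch of
# angles with `PeriodicInput` and decoding with `PeriodicOutput` above
# recovers the original values as long as they lie within half a period.
def _example_periodic_roundtrip() -> None:
    p = Parameters()
    angles = tf.constant([[0.5, -1.2, 3.0]], dtype=tf.float32)
    encoded = PeriodicInput(p, print_name="periodic_in")(angles)    # sin/cos features
    decoded = PeriodicOutput(p, print_name="periodic_out")(encoded)  # back to radians
    assert np.allclose(decoded.numpy(), angles.numpy(), atol=1e-5)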

1127@tf.keras.utils.register_keras_serializable() 

1128class MeanAngles(Layer): 

1129 """Layer that implements the mean of periodic angles.""" 

1130 

1131 def __init__( 

1132 self, 

1133 parameters: Union[Parameters, ADCParameters], 

1134 print_name: str, 

1135 trainable: bool = False, 

1136 **kwargs, 

1137 ) -> None: 

1138 """Instantiate the layer. 

1139 

1140 Args: 

1141 parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of 

1142 encodermap's parameters. 

1143 print_name (str): The name of this layer, as it should appear 

1144 in summaries. 

1145 trainable (bool): Whether this layer is trainable. As this layer 

1146 has no kernel and/or bias, this argument has no influence. 

1147 Defaults to False. 

1148 

1149 """ 

1150 super().__init__() 

1151 self.trainable = trainable 

1152 

1153 def call(self, inputs): 

1154 """Call the layer""" 

1155 return tf.tile( 

1156 tf.expand_dims( 

1157 tf.math.reduce_mean(inputs, 0), 

1158 0, 

1159 ), 

1160 multiples=(tf.shape(inputs)[0], 1), 

1161 ) 

1162 

1163 
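# --- Example (not part of the library) ---------------------------------------
# Sketch of what `MeanAngles` above does, assuming `Parameters()`
# default-constructs: the per-feature mean over the batch is tiled back to the
# batch size, so every sample receives the same values.
def _example_mean_angles() -> None:
    angles = tf.constant([[0.0, 1.0], [2.0, 3.0]], dtype=tf.float32)
    out = MeanAngles(Parameters(), print_name="mean_angles")(angles)
    assert np.allclose(out.numpy(), [[1.0, 2.0], [1.0, 2.0]])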

1164@tf.keras.utils.register_keras_serializable() 

1165class PairwiseDistances(EncoderMapBaseLayer): 

1166 """Layer that implements pairwise distances for both cases, with and without sidechain reconstruction""" 

1167 

1168 def __init__( 

1169 self, 

1170 parameters: Union[Parameters, ADCParameters], 

1171 print_name: str, 

1172 trainable: bool = False, 

1173 **kwargs, 

1174 ) -> None: 

1175 """Instantiate the layer. 

1176 

1177 Args: 

1178 parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of 

1179 encodermap's parameters. 

1180 print_name (str): The name of this layer, as it should appear 

1181 in summaries. 

1182 trainable (bool): Whether this layer is trainable. As this layer 

1183 has no kernel and/or bias, this argument has no influence. 

1184 Defaults to False. 

1185 

1186 """ 

1187 super().__init__(parameters, print_name, trainable) 

1188 if self.p.reconstruct_sidechains: 

1189 assert hasattr(self.p, "sidechain_info"), ( 

1190 "The provided parameters ask for sidechains to be reconstructed, " 

1191 "but don't contain a 'sidechain_info' attribute." 

1192 ) 

1193 self.indices = None 

1194 n_residues = max(list(self.p.sidechain_info[-1].keys())) 

1195 self.indices = np.arange(n_residues * 3)[ 

1196 self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step 

1197 ] 

1198 atom = n_residues * 3 + 1 

1199 indices = [] 

1200 for residue, n_sidechains_in_residue in self.p.sidechain_info[-1].items(): 

1201 if n_sidechains_in_residue == 0: 

1202 continue 

1203 if residue == 1: 

1204 atom += n_sidechains_in_residue 

1205 else: 

1206 atom += n_sidechains_in_residue 

1207 indices.append(atom) 

1208 self.indices = np.concatenate([self.indices, indices]) 

1209 

1210 def get_config(self) -> dict[Any, Any]: 

1211 """Serializes this keras serializable. 

1212 

1213 Returns: 

1214 dict[Any, Any]: A dict with the serializable objects. 

1215 

1216 """ 

1217 sidechain_info = self.p.sidechain_info 

1218 config = super().get_config().copy() 

1219 config.update( 

1220 { 

1221 "sidechain_info": sidechain_info, 

1222 } 

1223 ) 

1224 return config 

1225 

1226 @classmethod 

1227 def from_config( 

1228 cls: Type[BackMapLayerWithSidechainsType], 

1229 config: dict[Any, Any], 

1230 ) -> BackMapLayerWithSidechainsType: 

1231 """Reconstructs this keras serializable from a dict. 

1232 

1233 Args: 

1234 config (dict[Any, Any]): A dictionary. 

1235 

1236 Returns: 

1237 BackMapLayerType: An instance of the BackMapLayer. 

1238 

1239 """ 

1240 p = config.pop("p") 

1241 if "cartesian_pwd_start" in p: 

1242 p = ADCParameters(**p) 

1243 else: 

1244 p = Parameters(**p) 

1245 sidechain_info = config.pop("sidechain_info") 

1246 out = {int(k): v for k, v in sidechain_info.items()} 

1247 for k, v in out.items(): 

1248 out[k] = {int(kv): vv for kv, vv in v.items()} 

1249 p.sidechain_info = out 

1250 return cls(parameters=p, **config) 

1251 

1252 def call(self, inputs): 

1253 """Call the layer""" 

1254 if not self.p.reconstruct_sidechains: 

1255 out = inputs[ 

1256 :, 

1257 self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step, 

1258 ] 

1259 else: 

1260 out = tf.gather( 

1261 params=inputs, 

1262 indices=self.indices, 

1263 axis=1, 

1264 batch_dims=0, 

1265 ) 

1266 out = pairwise_dist(out, flat=True) 

1267 return out
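# --- Example (not part of the library) ---------------------------------------
# Illustrative sketch of the atom selection performed above when sidechains are
# not reconstructed. The parameter values are hypothetical: with
# cartesian_pwd_start=1, cartesian_pwd_stop=None and cartesian_pwd_step=3,
# every third backbone atom (e.g. the CA atoms) would be selected before the
# flattened pairwise distances are computed, giving N * (N - 1) / 2 distances
# for N selected atoms.
def _example_cartesian_pwd_selection() -> None:
    n_backbone_atoms = 9  # three residues with three backbone atoms each
    selected = np.arange(n_backbone_atoms)[1:None:3]
    assert selected.tolist() == [1, 4, 7]
    n_pairwise = len(selected) * (len(selected) - 1) // 2
    assert n_pairwise == 3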