Coverage for encodermap/models/layers.py: 16%
430 statements

# -*- coding: utf-8 -*-
# encodermap/models/layers.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Module that implements custom layers. Mainly needed for handling periodicity,
backmapping or sparsity."""


################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
import itertools
from collections.abc import Sequence
from math import pi
from typing import Any, Optional, Type, TypeVar, Union

# Third Party Imports
import numpy as np
import tensorflow as tf
from scipy.linalg import block_diag
from tensorflow.keras.layers import Concatenate, Dense, Lambda, Layer

# Encodermap imports
from encodermap.encodermap_tf1.backmapping import chain_in_plane
from encodermap.loss_functions.loss_classes import testing
from encodermap.misc.backmapping import (
    dihedrals_to_cartesian_tf_layers,
    split_and_reverse_cartesians,
    split_and_reverse_dihedrals,
)
from encodermap.misc.distances import pairwise_dist
from encodermap.parameters.parameters import ADCParameters, Parameters


################################################################################
# Typing
################################################################################


BackMapLayerTransformationsType = TypeVar(
    "BackMapLayerTransformationsType",
    bound="BackMapLayerTransformations",
)
PeriodicOutputType = TypeVar(
    "PeriodicOutputType",
    bound="PeriodicOutput",
)
MeanAnglesType = TypeVar(
    "MeanAnglesType",
    bound="MeanAngles",
)
EncoderMapBaseLayerType = TypeVar(
    "EncoderMapBaseLayerType",
    bound="EncoderMapBaseLayer",
)
BackMapLayerType = TypeVar(
    "BackMapLayerType",
    bound="BackMapLayer",
)
BackMapLayerWithSidechainsType = TypeVar(
    "BackMapLayerWithSidechainsType",
    bound="BackMapLayerWithSidechains",
)


################################################################################
# Globals
################################################################################


__all__: list[str] = ["PeriodicInput", "PeriodicOutput", "MeanAngles", "BackMapLayer"]


################################################################################
# Layers
################################################################################


@tf.keras.utils.register_keras_serializable()
class EncoderMapBaseLayer(Layer):
    """EncoderMap's base layer, which implements saving and loading of parameters.

    Classes that inherit from `EncoderMapBaseLayer` automatically receive
    parameters when deserialized.

    """

    def __init__(
        self,
        parameters: Union[Parameters, ADCParameters],
        print_name: str,
        trainable: bool = False,
    ) -> None:
        """Instantiate the layer.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of
                encodermap's parameters.
            print_name (str): The name of this layer, as it should appear
                in summaries.
            trainable (bool): Whether this layer is trainable. As this layer
                has no kernel and/or bias, this argument has no influence.
                Defaults to False.

        """
        super().__init__()
        self.p = parameters
        self.print_name = print_name
        self._name = print_name
        self.trainable = trainable

    @classmethod
    def from_config(
        cls: Type[EncoderMapBaseLayerType],
        config: dict[Any, Any],
    ) -> EncoderMapBaseLayerType:
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[Any, Any]): A dictionary.

        Returns:
            EncoderMapBaseLayerType: An instance of the EncoderMapBaseLayer.

        """
        p = config.pop("p")
        if "cartesian_pwd_start" in p:
            p = ADCParameters(**p)
        else:
            p = Parameters(**p)
        return cls(parameters=p, **config)

    def get_config(self) -> dict[Any, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[Any, Any]: A dict with the serializable objects.

        """
        config = super().get_config().copy()
        config.update(
            {
                "print_name": self.print_name,
                "p": self.p.to_dict(),
            }
        )
        return config
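

# Note on deserialization (a sketch of how the round trip can be exercised, not
# an additional API): `from_config` decides which parameter class to rebuild by
# looking for the ADC-only key "cartesian_pwd_start" in the serialized
# parameter dict, e.g.
#
#     config = {"p": Parameters().to_dict(), "print_name": "base"}
#     layer = EncoderMapBaseLayer.from_config(config)   # rebuilds a `Parameters`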


@tf.keras.utils.register_keras_serializable()
class PeriodicInput(EncoderMapBaseLayer):
    """Layer that handles periodic input. Needed if angles are treated.

    Input angles are split into their sin and cos components and concatenated
    along the feature axis, so the returned tensor has twice as many features
    (shape[1] = 2 * inp_shape[1]).

    """

    def __init__(
        self,
        parameters: Union[Parameters, ADCParameters],
        print_name: str,
        trainable: bool = False,
        **kwargs,
    ) -> None:
        """Instantiate the layer. Needs parameters to get the info about the
        periodicity. Although angles are most often used, who knows what hyper-toroidal
        manifold your data lies in.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of
                encodermap's parameters.
            print_name (str): The name of this layer, as it should appear
                in summaries.
            trainable (bool): Whether this layer is trainable. As this layer
                has no kernel and/or bias, this argument has no influence.
                Defaults to False.

        """
        super().__init__(parameters, print_name, trainable)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Call the layer."""
        outputs = inputs
        if self.p.periodicity != 2 * pi:
            outputs = outputs / self.p.periodicity * 2 * pi
        outputs = Concatenate(axis=1, name=f"{self.print_name}_Concat")(
            [
                tf.sin(outputs),
                tf.cos(outputs),
            ]
        )
        return outputs
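

# A minimal usage sketch (assuming the default `Parameters()` periodicity of
# 2 * pi; the batch of 10 dihedrals is made up for illustration):
#
#     angles = tf.random.uniform((32, 10), -pi, pi)
#     layer = PeriodicInput(Parameters(), print_name="dihedrals")
#     out = layer(angles)   # shape (32, 20): [sin(angles), cos(angles)]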


@tf.keras.utils.register_keras_serializable()
class BackMapLayerWithSidechains(Layer):
    """Also backmaps sidechains. For that, we need a way to know which
    distances, angles, and dihedrals belong to the backbone, and which belong to
    a sidechain. See the docstring of `encodermap.misc.backmapping._full_backmapping_np`
    for details.

    """

    def __init__(
        self,
        feature_description: Any,
    ) -> None:
        super().__init__()
        self.feature_description: dict[Any, Any] = feature_description

        # Definitions and Tests
        n_residues: int = max(list(feature_description[-1].keys()))
        assert np.array_equal(
            np.arange(1, n_residues + 1),
            np.sort(np.asarray(list(feature_description[-1].keys()))),
        ), (
            f"Currently the `feature_description[-1]` dict needs to contain "
            f"monotonically increasing keys, starting from 1. "
            f"{feature_description[-1].keys()=}"
        )
        n_sidechains: int = sum(
            [v + 1 for v in feature_description[-1].values() if v > 0]
        )
        sum_sidechains = sum(list(feature_description[-1].values()))
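
        # The masks and index lists built below are what `call` consumes:
        # boolean masks whose columns correspond to the atoms of the final
        # coordinate tensor (first `n_residues * 3` backbone atoms N, CA, C,
        # then `n_sidechains` sidechain atoms). For every angle and dihedral,
        # a mask row marks the atoms that stay fixed (True) versus the atoms
        # that get rotated (False) when that internal coordinate is set, and
        # the accompanying triplets/quadruplets name the atoms that define the
        # coordinate itself.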

        # this can be defined beforehand and then stacked as often as a batch needs it
        self.init_xyz: tf.Tensor = tf.zeros(
            shape=(1, n_residues * 3 + n_sidechains, 3),
            dtype=tf.float32,
        )

        # first we create the central_distance indices
        central_distance_indices = np.tri(
            N=n_residues * 3 - 1,
            M=n_residues * 3,
            k=0,
        ).astype(bool)
        right_side_central_distance_indices = [
            np.full(shape=(1, n_sidechains), fill_value=False, dtype=bool)
        ]
        count = 0  # starts at the first atom of the central chain
        count2 = n_residues * 3 + 1  # starts at the first atom of the sidechain
        sidechain_cartesians_ind = []
        sidechain_positions_indices = []
        central_angle_index_triplets = np.vstack(
            [
                np.arange(0, n_residues * 3)[:-2],
                np.arange(0, n_residues * 3)[1:-1],
                np.arange(0, n_residues * 3)[2:],
            ]
        ).T.tolist()
        sidechain_angle_index_triplets = []
        central_dihedral_index_quadruplets = np.vstack(
            [
                np.arange(0, n_residues * 3)[:-3],
                np.arange(0, n_residues * 3)[1:-2],
                np.arange(0, n_residues * 3)[2:-1],
                np.arange(0, n_residues * 3)[3:],
            ]
        ).T.tolist()
        sidechain_dihedral_index_quadruplets = []

        # iterate over feature_description[-1] to get all indices and the right side
        # of the central cartesians
        for i, (residue, n_sidechains_in_residue) in zip(
            itertools.count(1, 3), feature_description[-1].items()
        ):
            if n_sidechains_in_residue == 0:
                if residue == 1 or residue == n_residues:
                    continue
                else:
                    right_side_central_distance_indices.append(t)
            else:
                sidechain_cartesians_ind.append(
                    np.arange(count, count + n_sidechains_in_residue)
                )
                sidechain_positions_indices.append(
                    [i]
                    + np.arange(count2 - 1, count2 + n_sidechains_in_residue).tolist()
                )
                for sidechain_i in range(n_sidechains_in_residue + 1):
                    if sidechain_i == 0:
                        # adds N-CA-CB
                        sidechain_angle_index_triplets.append(
                            [(residue - 1) * 3, (residue - 1) * 3 + 1, count2 - 1]
                        )
                        # adds N-CA-CB-CG
                        sidechain_dihedral_index_quadruplets.append(
                            [
                                (residue - 1) * 3,
                                (residue - 1) * 3 + 1,
                                count2 - 1,
                                count2,
                            ]
                        )
                    elif sidechain_i == 1:
                        # adds CA-CB-CG
                        sidechain_angle_index_triplets.append(
                            [(residue - 1) * 3 + 1, count2 - 1, count2]
                        )
                        # adds CA-CB-CG-CD
                        if sidechain_i < n_sidechains_in_residue:
                            sidechain_dihedral_index_quadruplets.append(
                                [(residue - 1) * 3 + 1, count2 - 1, count2, count2 + 1]
                            )
                    else:
                        # adds CB-CG-CD and so on
                        sidechain_angle_index_triplets.append(
                            [
                                count2 + sidechain_i - 3,
                                count2 + sidechain_i - 2,
                                count2 + sidechain_i - 1,
                            ]
                        )
                        if sidechain_i < n_sidechains_in_residue:
                            sidechain_dihedral_index_quadruplets.append(
                                [
                                    count2 + sidechain_i - 3,
                                    count2 + sidechain_i - 2,
                                    count2 + sidechain_i - 1,
                                    count2 + sidechain_i,
                                ]
                            )
                count += n_sidechains_in_residue + 1
                count2 += n_sidechains_in_residue + 1
                t = np.zeros(
                    shape=(3, n_sidechains),
                    dtype=bool,
                )
                t[:, :count] = True
                right_side_central_distance_indices.append(t)
        assert len(sidechain_angle_index_triplets) == n_sidechains
        assert len(sidechain_dihedral_index_quadruplets) == sum_sidechains, (
            f"I could not reconstruct the correct number of sidechain dihedral "
            f"quadruplets. The number of sidechain dihedrals requires the list "
            f"to have length {sum_sidechains}, but I created a list with "
            f"{len(sidechain_dihedral_index_quadruplets)}."
        )
        right_side_central_distance_indices.append(
            np.full(shape=(1, n_sidechains), fill_value=True, dtype=bool)
        )
        right_side_central_distance_indices = np.vstack(
            right_side_central_distance_indices
        )
        angle_index_triplets = np.vstack(
            central_angle_index_triplets + sidechain_angle_index_triplets
        )
        dihedral_index_quadruplets = np.vstack(
            central_dihedral_index_quadruplets + sidechain_dihedral_index_quadruplets
        )
        if sidechain_cartesians_ind != []:  # if sidechains
            _use_sidechains = True
            sidechain_cartesians_ind = np.concatenate(sidechain_cartesians_ind)
            central_distance_indices = np.hstack(
                [central_distance_indices, right_side_central_distance_indices]
            )
            side_distance_indices = [
                (np.tri(N=i + 1, M=i + 2, k=0) + 1)[:, 1:]
                for i in feature_description[-1].values()
                if i > 0
            ]
            side_distance_indices = (block_diag(*side_distance_indices) % 2) == 0
            left_side_side_distance_indices = (
                np.full(  # all atoms in the central chain are True
                    shape=(len(side_distance_indices), n_residues * 3),
                    fill_value=True,
                    dtype=bool,
                )
            )
            side_distance_indices = np.hstack(
                [left_side_side_distance_indices, side_distance_indices]
            )
            distance_indices = np.vstack(
                [central_distance_indices, side_distance_indices]
            )
        else:  # if no sidechains
            _use_sidechains = False
            distance_indices = central_distance_indices
        assert distance_indices.shape == (
            n_residues * 3 - 1 + n_sidechains,
            self.init_xyz.shape[1],
        ), (
            f"The shape of the distance index after stacking is unexpected.\n"
            f"Expected: {(n_residues * 3 - 1 + n_sidechains, self.init_xyz.shape[1])}\n"
            f"Actual: {distance_indices.shape}"
        )

        # now the angles
        central_angle_indices = central_distance_indices[1:]
        if _use_sidechains:  # if sidechains
            angle_indices = np.vstack(
                [central_distance_indices[1:], side_distance_indices]
            )
            side_angle_indices = side_distance_indices
        else:  # no sidechains
            angle_indices = central_distance_indices[1:]
        assert len(angle_indices) == len(distance_indices) - 1

        # and the dihedrals
        if _use_sidechains:  # if sidechains
            dihedral_indices = np.vstack(
                [
                    central_distance_indices[1:-1],
                    side_distance_indices[sidechain_cartesians_ind],
                ]
            )
            corrector = np.count_nonzero(
                list(feature_description[-1].values())
            )  # per residue with sidechains there is one dihedral less
        else:
            dihedral_indices = central_distance_indices[1:-1]
            corrector = 0
        assert len(dihedral_indices) == len(distance_indices) - 2 - corrector
        assert angle_index_triplets.shape[0] == angle_indices.shape[0]
        assert dihedral_index_quadruplets.shape[0] == dihedral_indices.shape[0], (
            f"The number of dihedral indices ({len(dihedral_indices)}) and the "
            f"number of quadruplets ({len(dihedral_index_quadruplets)}) do not match."
        )

        # create instance attributes with tf
        self._use_sidechains = tf.constant(_use_sidechains, dtype=tf.bool)
        self.n_sidechains_in_residue = tf.constant(
            np.array(
                [
                    self.feature_description[-1][k]
                    for k in sorted(self.feature_description[-1].keys())
                ]
            ).astype(np.int32),
            dtype=tf.int32,
        )

        # general
        self.up = tf.constant([[0.0, 0.0, 1.0]], dtype=tf.float32)
        self.down = tf.constant([[0.0, 0.0, -1.0]], dtype=tf.float32)

        # distances
        self.no_of_central_distances = central_distance_indices.shape[0]
        self.n_sidechains = n_sidechains
        self.central_distance_indices = tf.constant(
            central_distance_indices,
            shape=central_distance_indices.shape,
            dtype=tf.bool,
        )
        self.n_atoms = self.central_distance_indices.shape[1]

        # angles
        self.no_of_central_angles = central_angle_indices.shape[0]
        self.no_of_side_angles = side_angle_indices.shape[0]
        self.central_angle_indices = tf.constant(
            central_angle_indices, shape=central_angle_indices.shape, dtype=tf.bool
        )
        self.side_angle_indices = tf.constant(
            side_angle_indices, shape=side_angle_indices.shape, dtype=tf.bool
        )
        self.central_angle_index_triplets = tf.constant(
            np.asarray(central_angle_index_triplets),
            shape=(len(central_angle_index_triplets), 3),
            dtype=tf.int32,
        )
        self.sidechain_angle_index_triplets = tf.constant(
            np.asarray(sidechain_angle_index_triplets),
            shape=(len(sidechain_angle_index_triplets), 3),
            dtype=tf.int32,
        )

        # dihedrals
        self.no_of_dihedrals = dihedral_indices.shape[0]
        self.dihedral_indices = tf.constant(
            dihedral_indices, shape=dihedral_indices.shape, dtype=tf.bool
        )
        self.dihedral_index_quadruplets = tf.constant(
            np.asarray(dihedral_index_quadruplets),
            shape=(len(dihedral_index_quadruplets), 4),
            dtype=tf.int32,
        )

    def get_config(self) -> dict[Any, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[Any, Any]: A dict with the serializable objects.

        """
        config = super().get_config().copy()
        config.update(
            {
                "feature_description": self.feature_description,
            }
        )
        return config

    @classmethod
    def from_config(
        cls: Type[BackMapLayerWithSidechainsType],
        config: dict[Any, Any],
    ) -> BackMapLayerWithSidechainsType:
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[Any, Any]): A dictionary.

        Returns:
            BackMapLayerWithSidechainsType: An instance of the BackMapLayerWithSidechains.

        """
        feature_description = config.pop("feature_description")
        out = {int(k): v for k, v in feature_description.items()}
        for k, v in out.items():
            out[k] = {int(kv): vv for kv, vv in v.items()}
        return cls(feature_description=out)

    def call(self, inputs: tuple[tf.Tensor, ...]) -> tf.Tensor:
        # Unpack inputs
        (
            central_distances,
            central_angles,
            central_dihedrals,
            side_distances,
            side_angles,
            side_dihedrals,
        ) = inputs
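
        # The reconstruction below proceeds in three passes:
        #   1. distances: cumulative sums of the central distances give the
        #      x-coordinates of the backbone, sidechain atoms are offset along
        #      y by their cumulative sidechain distances,
        #   2. angles: for every backbone/sidechain angle, the atoms marked
        #      False in the corresponding boolean mask are rotated around the
        #      angle's middle atom until the target angle is met,
        #   3. dihedrals: the same rotate-the-dynamic-part scheme is applied
        #      around the central bond of every dihedral quadruplet.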

        # concatenate the dihedrals
        dihedrals = tf.concat(
            [
                central_dihedrals,
                side_dihedrals,
            ],
            axis=1,
        )

        # distances
        xs_central = tf.TensorArray(
            dtype=tf.float32,
            size=self.no_of_central_distances + 1,
            clear_after_read=False,
        )
        ys_central = tf.TensorArray(
            dtype=tf.float32,
            size=self.no_of_central_distances + 1,
            clear_after_read=False,
        )
        xs_side = tf.TensorArray(
            dtype=tf.float32,
            size=self.n_sidechains,
            clear_after_read=False,
        )
        ys_side = tf.TensorArray(
            dtype=tf.float32,
            size=self.n_sidechains,
            clear_after_read=False,
        )
        xs_central = xs_central.write(
            0, tf.zeros((tf.shape(central_angles)[0],), dtype=tf.float32)
        )
        ys_central = ys_central.write(
            0, tf.zeros((tf.shape(central_angles)[0],), dtype=tf.float32)
        )
        # xs_central = [tf.zeros((tf.shape(central_angles)[0], ), dtype=tf.float32)]
        # ys_central = [tf.zeros((tf.shape(central_angles)[0], ), dtype=tf.float32)]
        # xs_side = []
        # ys_side = []

        residue = 0
        idx = 0
        j = 0
        for i in range(self.no_of_central_distances):
            # xs_central.append(xs[-1] + central_distances[:, 1])
            xs_central = xs_central.write(
                i + 1, xs_central.read(i) + central_distances[:, i]
            )
            # ys_central.append(tf.zeros((tf.shape(central_angles)[0], ), dtype=tf.float32))
            ys_central = ys_central.write(
                i + 1, tf.zeros((tf.shape(central_angles)[0],))
            )
            if idx == 0 and self._use_sidechains:
                n_sidechains = self.n_sidechains_in_residue[residue]
                if n_sidechains > 0:
                    for n in range(n_sidechains + 1):
                        # xs_side.append(xs_central.read(i))
                        xs_side = xs_side.write(j, xs_central.read(i + 1))
                        # ys_side.append(
                        #     tf.reduce_sum(
                        #         side_distances[:, j - n : j + 1],
                        #         axis=1,
                        #     )
                        # )
                        ys_side = ys_side.write(
                            j,
                            tf.reduce_sum(
                                side_distances[:, j - n : j + 1],
                                axis=1,
                            ),
                        )
                        j += 1
            idx += 1
            if idx >= 3:
                residue += 1
                idx = 0
        xs_central = tf.transpose(xs_central.stack(), perm=[1, 0])
        ys_central = tf.transpose(ys_central.stack(), perm=[1, 0])
        xs_side = tf.transpose(xs_side.stack(), perm=[1, 0])
        ys_side = tf.transpose(ys_side.stack(), perm=[1, 0])
        xs_side.set_shape((xs_central.shape[0], self.n_sidechains))
        ys_side.set_shape((xs_central.shape[0], self.n_sidechains))
        xs = tf.concat([xs_central, xs_side], axis=1)
        ys = tf.concat([ys_central, ys_side], axis=1)
        xyz_out = tf.stack(
            [
                xs,
                ys,
            ],
            axis=2,
        )
        xyz_out = tf.pad(
            tf.pad(
                xyz_out,
                ((0, 0), (0, 0), (0, 1)),
                constant_values=0,
            ),
            paddings=((0, 0), (0, 0), (0, 1)),
            constant_values=1,
        )

        # angles
        # Can't parallelize over angles (just over the batch dimension),
        # because xyz_out is updated constantly and thus
        # xyz_out[..., -1] changes during iteration
        for i in range(self.no_of_central_angles):
            ind = self.central_angle_indices[i]
            angle_index = self.central_angle_index_triplets[i]
            ang = central_angles[:, i]
            direction = tf.repeat(
                self.up,
                repeats=tf.shape(ang)[0],
                axis=0,
            )
            abc = tf.transpose(
                tf.gather(
                    params=xyz_out,
                    indices=angle_index,
                    axis=1,
                    batch_dims=0,
                )[..., :3],
                perm=[1, 0, 2],
            )
            ba = abc[0] - abc[1]
            bc = abc[2] - abc[1]
            dot = tf.keras.backend.batch_dot(
                ba,
                bc,
            )
            prod = tf.expand_dims(_batch_fro(ba) * _batch_fro(bc), axis=1)
            t = tf.clip_by_value(dot / prod, clip_value_min=-1, clip_value_max=1)
            current_angle = tf.squeeze(tf.acos(t))
            angle = tf.abs(ang - current_angle)
            rotmat = _rotation_matrices(
                angle=angle,
                direction=direction,
                point=abc[1],
            )
            dynamic = tf.transpose(
                tf.gather(
                    params=xyz_out, indices=tf.where(~ind)[:, 0], axis=1, batch_dims=0
                ),
                perm=[0, 2, 1],
            )
            rotated = tf.transpose(
                tf.keras.backend.batch_dot(rotmat, dynamic),
                perm=[0, 2, 1],
            )
            static = tf.gather(
                params=xyz_out, indices=tf.where(ind)[:, 0], axis=1, batch_dims=0
            )
            new = tf.TensorArray(
                dtype=tf.float32,
                size=self.no_of_central_distances + 1 + self.n_sidechains,
                clear_after_read=False,
            )
            d = 0
            s = 0
            c = 0
            for j in ind:
                if j:
                    new = new.write(c, static[:, s])
                    s += 1
                else:
                    new = new.write(c, rotated[:, d])
                    d += 1
                c += 1
            xyz_out = tf.transpose(
                new.stack(),
                perm=[1, 0, 2],
            )

        # sidechains
        for i in range(self.no_of_side_angles):
            ind = self.side_angle_indices[i]
            angle_index = self.sidechain_angle_index_triplets[i]
            ang = side_angles[:, i]
            direction = tf.repeat(
                self.down,
                repeats=tf.shape(ang)[0],
                axis=0,
            )
            abc = tf.transpose(
                tf.gather(
                    params=xyz_out,
                    indices=angle_index,
                    axis=1,
                    batch_dims=0,
                )[..., :3],
                perm=[1, 0, 2],
            )
            ba = abc[0] - abc[1]
            bc = abc[2] - abc[1]
            dot = tf.keras.backend.batch_dot(
                ba,
                bc,
            )
            prod = tf.expand_dims(_batch_fro(ba) * _batch_fro(bc), axis=1)
            t = tf.clip_by_value(dot / prod, clip_value_min=-1, clip_value_max=1)
            current_angle = tf.squeeze(tf.acos(t))
            angle = tf.abs(ang - current_angle)
            rotmat = _rotation_matrices(
                angle=angle,
                direction=direction,
                point=abc[1],
            )
            dynamic = tf.transpose(
                tf.gather(
                    params=xyz_out, indices=tf.where(~ind)[:, 0], axis=1, batch_dims=0
                ),
                perm=[0, 2, 1],
            )
            rotated = tf.transpose(
                tf.keras.backend.batch_dot(rotmat, dynamic),
                perm=[0, 2, 1],
            )
            static = tf.gather(
                params=xyz_out, indices=tf.where(ind)[:, 0], axis=1, batch_dims=0
            )
            new = tf.TensorArray(
                dtype=tf.float32,
                size=self.no_of_central_distances + 1 + self.n_sidechains,
                clear_after_read=False,
            )
            d = 0
            s = 0
            c = 0
            for j in ind:
                if j:
                    new = new.write(c, static[:, s])
                    s += 1
                else:
                    new = new.write(c, rotated[:, d])
                    d += 1
                c += 1
            xyz_out = tf.transpose(
                new.stack(),
                perm=[1, 0, 2],
            )

        # dihedrals
        for i in range(self.no_of_dihedrals):
            ind = self.dihedral_indices[i]
            dihedral_index = self.dihedral_index_quadruplets[i]
            ang = dihedrals[:, i]
            abcd = tf.transpose(
                tf.gather(
                    params=xyz_out,
                    indices=dihedral_index,
                    axis=1,
                    batch_dims=0,
                )[..., :3],
                perm=[1, 0, 2],
            )
            direction = abcd[2] - abcd[1]
            b1 = abcd[1] - abcd[0]
            b2 = abcd[2] - abcd[1]
            b3 = abcd[3] - abcd[2]
            c1 = tf.linalg.cross(b2, b3)
            c2 = tf.linalg.cross(b1, b2)
            p1 = tf.reduce_sum((b1 * c1), axis=1)
            p1 *= tf.sqrt(tf.reduce_sum((b2 * b2), axis=1))
            p2 = tf.reduce_sum((c1 * c2), axis=1)
            current_angle = tf.atan2(p1, p2)
            angle = ang - current_angle
            rotmat = _rotation_matrices(
                angle=angle,
                direction=direction,
                point=abcd[1],
            )
            dynamic = tf.transpose(
                tf.gather(
                    params=xyz_out, indices=tf.where(~ind)[:, 0], axis=1, batch_dims=0
                ),
                perm=[0, 2, 1],
            )
            rotated = tf.transpose(
                tf.keras.backend.batch_dot(rotmat, dynamic),
                perm=[0, 2, 1],
            )
            static = tf.gather(
                params=xyz_out, indices=tf.where(ind)[:, 0], axis=1, batch_dims=0
            )
            new = tf.TensorArray(
                dtype=tf.float32,
                size=self.no_of_central_distances + 1 + self.n_sidechains,
                clear_after_read=False,
            )
            d = 0
            s = 0
            c = 0
            for j in ind:
                if j:
                    new = new.write(c, static[:, s])
                    s += 1
                else:
                    new = new.write(c, rotated[:, d])
                    d += 1
                c += 1
            xyz_out = tf.transpose(
                new.stack(),
                perm=[1, 0, 2],
            )

        return xyz_out[..., :3]
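

# Shape sketch for `BackMapLayerWithSidechains.call` (derived from the index
# bookkeeping in `__init__`, with N = n_residues and S = n_sidechains): the
# input tuple is expected as
#
#     (central_distances (batch, 3N - 1), central_angles (batch, 3N - 2),
#      central_dihedrals (batch, 3N - 3), side_distances (batch, S),
#      side_angles (batch, S), side_dihedrals (batch, sum of sidechain dihedrals))
#
# and the returned tensor holds Euclidean coordinates of shape (batch, 3N + S, 3).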


@tf.function
def _batch_fro(a: tf.Tensor) -> tf.Tensor:
    """Batch-wise Frobenius norm, a.k.a. the length of a vector."""
    return tf.sqrt(tf.reduce_sum(a**2, axis=1))


@tf.function
def _rotation_matrices(angle, direction, point) -> tf.Tensor:
    """Adapted from C. Gohlke's transformations.py.

    Batch-wise 4x4 rotation matrices.

    """
    sina = tf.sin(angle)
    cosa = tf.cos(angle)
    direction_u = _unit_vector(direction)

    # rotation matrix around unit vector
    R = tf.linalg.diag(tf.transpose([cosa, cosa, cosa]), k=0)
    R += tf.einsum("ki,kj->kij", direction_u, direction_u) * tf.expand_dims(
        tf.expand_dims(1.0 - cosa, -1), -1
    )
    direction_u *= tf.expand_dims(sina, -1)

    R_add = tf.TensorArray(
        dtype=tf.float32,
        size=tf.shape(angle)[0],
        clear_after_read=False,
    )
    for i in range(tf.shape(angle)[0]):
        d = direction_u[i]
        R_add = R_add.write(
            i, [[0.0, -d[2], d[1]], [d[2], 0.0, -d[0]], [-d[1], d[0], 0.0]]
        )
    R_add = R_add.stack()
    R += R_add
    R.set_shape((angle.shape[0], 3, 3))

    # rotation around origin
    test = tf.expand_dims(
        tf.pad(
            point - tf.keras.backend.batch_dot(R, point),
            paddings=((0, 0), (0, 1)),
            constant_values=1,
        ),
        axis=-1,
    )
    R = tf.pad(R, ((0, 0), (0, 1), (0, 0)))
    M = tf.concat([R, test], axis=2)
    return M
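

# Shape note (derived from the code above): `_rotation_matrices` takes `angle`
# of shape (batch,), `direction` and `point` of shape (batch, 3), and returns
# homogeneous transformation matrices of shape (batch, 4, 4) that rotate points
# by `angle` around the axis `direction` passing through `point`.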


@tf.function
def _unit_vector(vector: tf.Tensor) -> tf.Tensor:
    """Adapted from C. Gohlke's transformations.py"""
    length = tf.sqrt(tf.reduce_sum(vector**2, axis=1))
    return vector / tf.expand_dims(length, 1)


@tf.keras.utils.register_keras_serializable()
class BackMapLayer(Layer):
    """Layer that implements backmapping from torsions-angles-distances to Euclidean coordinates."""

    def __init__(self, left_split: int, right_split: int) -> None:
        """Instantiate the layer."""
        super().__init__()
        self._name = "BackmapLayer"
        self.left_split = left_split
        self.right_split = right_split

    @classmethod
    def from_config(
        cls: Type[BackMapLayerType],
        config: dict[Any, Any],
    ) -> BackMapLayerType:
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[Any, Any]): A dictionary.

        Returns:
            BackMapLayerType: An instance of the BackMapLayer.

        """
        left_split = config.pop("left_split")
        right_split = config.pop("right_split")
        return cls(left_split=left_split, right_split=right_split)

    def get_config(self) -> dict[Any, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[Any, Any]: A dict with the serializable objects.

        """
        config = super().get_config().copy()
        config.update(
            {
                "left_split": self.left_split,
                "right_split": self.right_split,
            }
        )
        return config

    def call(self, inputs: tuple[tf.Tensor, tf.Tensor, tf.Tensor]) -> tf.Tensor:
        """Call the layer. `inputs` should be a tuple that can be unpacked via
        `distances, angles, dihedrals = inputs`.
        """
        distances, angles, dihedrals = inputs
        # I don't know how negative distances can arrive at this step,
        # but we replace them with the mean
        fixed_distances = tf.where(
            distances < 0.00001,
            tf.ones_like(distances) * tf.reduce_mean(distances),
            distances,
        )

        # mean lengths
        # back_mean_lengths = tf.expand_dims(tf.reduce_mean(inp_distances, 0), 0)
        out = tf.expand_dims(tf.reduce_mean(distances, 0), 0)

        # chain in plane
        # back_chain_in_plane = chain_in_plane(back_mean_lengths, out_angles)
        out = chain_in_plane(out, angles)

        # dihedrals to cartesian
        # back_cartesians = dihedrals_to_cartesian_tf(out_dihedrals + pi, back_chain_in_plane)
        out_dihedrals = tf.add(dihedrals, pi)
        out = dihedrals_to_cartesian_tf_layers(
            out_dihedrals,
            out,
            left_iteration_counter=self.left_split,
            right_iteration_counter=self.right_split,
        )
        return out
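

# A minimal usage sketch (the `left_split`/`right_split` values are made up;
# for a backbone of n atoms the internal coordinates follow the usual
# n-1 / n-2 / n-3 layout):
#
#     layer = BackMapLayer(left_split=10, right_split=11)
#     # distances: (batch, n - 1), angles: (batch, n - 2), dihedrals: (batch, n - 3)
#     cartesians = layer((distances, angles, dihedrals))   # roughly (batch, n, 3)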


@tf.keras.utils.register_keras_serializable()
class BackMapLayerTransformations(Layer):
    """Experimental layer for using multimers with the ADCEMap."""

    def __init__(self, protein_lengths: Sequence[int]) -> None:
        """Instantiate the layer.

        Args:
            protein_lengths (Sequence[int]): The lengths of the proteins in the
                multimers. Based on this information, the input to `self.call`
                will be split.

        """
        self.protein_lengths = protein_lengths
        super().__init__()
        self._name = "BackmapLayerTransformations"

    @classmethod
    def from_config(
        cls: Type[BackMapLayerTransformationsType],
        config: dict[Any, Any],
    ) -> BackMapLayerTransformationsType:
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[Any, Any]): A dictionary.

        Returns:
            BackMapLayerTransformationsType: An instance of the BackMapLayerTransformations.

        """
        protein_lengths = config.pop("protein_lengths")
        return cls(protein_lengths=protein_lengths, **config)

    def get_config(self) -> dict[Any, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[Any, Any]: A dict with the serializable objects.

        """
        config = super().get_config().copy()
        config.update(
            {
                "protein_lengths": self.protein_lengths,
            }
        )
        return config

    def call(self, inputs):
        """Call the layer. `inputs` should be a tuple that can be unpacked via
        `distances, angles, dihedrals, matrices = inputs`.
        """
        # Third Party Imports
        from tensorflow_graphics.rendering.utils import transform_homogeneous

        inp_distances, out_angles, out_dihedrals, matrices = inputs

        out_cartesians = []
        current_length = 0
        for i, protein_length in enumerate(self.protein_lengths):
            if current_length == 0:
                distance_ind = slice(0, protein_length * 3 - 1)
                angle_ind = slice(0, protein_length * 3 - 2)
                dihe_ind = slice(0, protein_length * 3 - 3)
            else:
                distance_ind = slice(
                    current_length * 3 - i,
                    current_length * 3 + protein_length * 3 - (i + 1),
                )
                angle_ind = slice(
                    current_length * 3 - (i + 1),
                    current_length * 3 + protein_length * 3 - (i + 2),
                )
                dihe_ind = slice(
                    current_length * 3 - (i + 2),
                    current_length * 3 + protein_length * 3 - (i + 3),
                )
            current_length += protein_length

            # index
            current_lengths = inp_distances[:, distance_ind]
            current_lengths = tf.expand_dims(tf.reduce_mean(current_lengths, 0), 0)
            current_angles = out_angles[:, angle_ind]
            current_dihedrals = out_dihedrals[:, dihe_ind]
            current_dihedrals = tf.add(current_dihedrals, pi)

            c = chain_in_plane(current_lengths, current_angles)
            c = dihedrals_to_cartesian_tf_layers(current_dihedrals, c)

            # for all other proteins apply homogeneous transformation matrices
            if i != 0:
                m = matrices[:, i - 1]
                c = transform_homogeneous(m, c)[..., :3]

            out_cartesians.append(c)

        out_cartesians = tf.concat(
            out_cartesians,
            axis=1,
        )
        return out_cartesians
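

# Rough sketch of what `BackMapLayerTransformations.call` does (read from the
# code above): every protein in the multimer is backmapped on its own via
# `chain_in_plane` + `dihedrals_to_cartesian_tf_layers`, and every protein
# except the first is then placed by applying its homogeneous transformation
# matrix (`matrices[:, i - 1]`) with `transform_homogeneous`, before all
# per-protein coordinates are concatenated along the atom axis.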


@tf.keras.utils.register_keras_serializable()
class PeriodicOutput(EncoderMapBaseLayer):
    """Layer that reverses the `PeriodicInput` layer."""

    def __init__(
        self,
        parameters: Union[Parameters, ADCParameters],
        print_name: str,
        trainable: bool = False,
        **kwargs,
    ) -> None:
        """Instantiate the layer. We also need to know here what periodicity is needed.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of
                encodermap's parameters.
            print_name (str): The name of this layer, as it should appear
                in summaries.
            trainable (bool): Whether this layer is trainable. As this layer
                has no kernel and/or bias, this argument has no influence.
                Defaults to False.

        """
        super().__init__(parameters, print_name, trainable)

    def call(self, inputs):
        """Calls the layer. Inputs should be the concatenated (sin, cos) of the same angles."""
        outputs = inputs
        outputs = tf.atan2(*tf.split(outputs, 2, 1))

        if self.p.periodicity != 2 * pi:
            outputs = outputs / (2 * pi) * self.p.periodicity
        return outputs
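

# A minimal round-trip sketch (assumption: the default `Parameters()` has a
# periodicity of 2 * pi, so no rescaling happens; the layer names are arbitrary):
#
#     angles = tf.random.uniform((4, 5), -pi, pi)
#     p = Parameters()
#     encoded = PeriodicInput(p, "in")(angles)      # (4, 10) = [sin, cos]
#     decoded = PeriodicOutput(p, "out")(encoded)   # (4, 5), equal to `angles`
#                                                   # up to float precision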


@tf.keras.utils.register_keras_serializable()
class MeanAngles(Layer):
    """Layer that implements the mean of periodic angles."""

    def __init__(
        self,
        parameters: Union[Parameters, ADCParameters],
        print_name: str,
        trainable: bool = False,
        **kwargs,
    ) -> None:
        """Instantiate the layer.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of
                encodermap's parameters.
            print_name (str): The name of this layer, as it should appear
                in summaries.
            trainable (bool): Whether this layer is trainable. As this layer
                has no kernel and/or bias, this argument has no influence.
                Defaults to False.

        """
        super().__init__()
        self.trainable = trainable

    def call(self, inputs):
        """Call the layer"""
        return tf.tile(
            tf.expand_dims(
                tf.math.reduce_mean(inputs, 0),
                0,
            ),
            multiples=(tf.shape(inputs)[0], 1),
        )
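

# In other words (a descriptive note, not new behavior): `MeanAngles` reduces
# the batch to its per-column mean and tiles that mean back to the batch size,
# so the output has the same shape as the input but every sample is identical:
#
#     angles = tf.constant([[0.0, 1.0], [2.0, 3.0]])
#     MeanAngles(Parameters(), "mean_angles")(angles)
#     # -> [[1.0, 2.0], [1.0, 2.0]]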


@tf.keras.utils.register_keras_serializable()
class PairwiseDistances(EncoderMapBaseLayer):
    """Layer that implements pairwise distances for both cases: with and without sidechain reconstruction."""

    def __init__(
        self,
        parameters: Union[Parameters, ADCParameters],
        print_name: str,
        trainable: bool = False,
        **kwargs,
    ) -> None:
        """Instantiate the layer.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]): An instance of
                encodermap's parameters.
            print_name (str): The name of this layer, as it should appear
                in summaries.
            trainable (bool): Whether this layer is trainable. As this layer
                has no kernel and/or bias, this argument has no influence.
                Defaults to False.

        """
        super().__init__(parameters, print_name, trainable)
        if self.p.reconstruct_sidechains:
            assert hasattr(self.p, "sidechain_info"), (
                "The provided parameters ask for sidechains to be reconstructed, "
                "but don't contain a 'sidechain_info' attribute."
            )
            self.indices = None
            n_residues = max(list(self.p.sidechain_info[-1].keys()))
            self.indices = np.arange(n_residues * 3)[
                self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step
            ]
            atom = n_residues * 3 + 1
            indices = []
            for residue, n_sidechains_in_residue in self.p.sidechain_info[-1].items():
                if n_sidechains_in_residue == 0:
                    continue
                if residue == 1:
                    atom += n_sidechains_in_residue
                else:
                    atom += n_sidechains_in_residue
                indices.append(atom)
            self.indices = np.concatenate([self.indices, indices])

    def get_config(self) -> dict[Any, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[Any, Any]: A dict with the serializable objects.

        """
        sidechain_info = self.p.sidechain_info
        config = super().get_config().copy()
        config.update(
            {
                "sidechain_info": sidechain_info,
            }
        )
        return config

    @classmethod
    def from_config(
        cls: Type[BackMapLayerWithSidechainsType],
        config: dict[Any, Any],
    ) -> BackMapLayerWithSidechainsType:
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[Any, Any]): A dictionary.

        Returns:
            PairwiseDistances: An instance of the PairwiseDistances layer.

        """
        p = config.pop("p")
        if "cartesian_pwd_start" in p:
            p = ADCParameters(**p)
        else:
            p = Parameters(**p)
        sidechain_info = config.pop("sidechain_info")
        out = {int(k): v for k, v in sidechain_info.items()}
        for k, v in out.items():
            out[k] = {int(kv): vv for kv, vv in v.items()}
        p.sidechain_info = out
        return cls(parameters=p, **config)

    def call(self, inputs):
        """Call the layer"""
        if not self.p.reconstruct_sidechains:
            out = inputs[
                :,
                self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
            ]
        else:
            out = tf.gather(
                params=inputs,
                indices=self.indices,
                axis=1,
                batch_dims=0,
            )
        out = pairwise_dist(out, flat=True)
        return out
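

# A minimal usage sketch (assuming `ADCParameters` with its default
# `cartesian_pwd_*` slicing and `reconstruct_sidechains=False`; the cartesian
# tensor below is made up for illustration):
#
#     p = ADCParameters()
#     layer = PairwiseDistances(p, print_name="pwd")
#     cartesians = tf.random.normal((8, 30, 3))   # (batch, n_atoms, 3)
#     pwd = layer(cartesians)                     # flat pairwise distances,
#                                                 # presumably (batch, k * (k - 1) / 2),
#                                                 # where k is the number of atoms
#                                                 # selected by the `cartesian_pwd_*` slice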