Coverage for encodermap/models/models.py: 8%
976 statements
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
« prev ^ index » next coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
1# -*- coding: utf-8 -*-
2# encodermap/models/models.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""
24"""
26################################################################################
27# Imports
28################################################################################
31# Future Imports at the top
32from __future__ import annotations
34# Standard Library Imports
35import os
36import warnings
37from collections.abc import Iterable, Sequence
38from math import pi
39from typing import TYPE_CHECKING, Any, Literal, Optional, Type, TypeVar, Union, overload
41# Third Party Imports
42import numpy as np
43import scipy
44import tensorflow as tf
45import tensorflow.keras.backend as K
46from tensorflow.keras.layers import Concatenate, Dense, Input, Lambda
48# Encodermap imports
49from encodermap.encodermap_tf1.backmapping import (
50 chain_in_plane,
51 dihedrals_to_cartesian_tf,
52)
53from encodermap.loss_functions.loss_classes import testing
54from encodermap.misc.distances import pairwise_dist
55from encodermap.misc.summaries import add_layer_summaries
56from encodermap.models.layers import (
57 BackMapLayer,
58 BackMapLayerTransformations,
59 BackMapLayerWithSidechains,
60 MeanAngles,
61 PairwiseDistances,
62 PeriodicInput,
63 PeriodicOutput,
64)
65from encodermap.parameters.parameters import ADCParameters, AnyParameters, Parameters
66from encodermap.trajinfo.info_single import Capturing
69################################################################################
70# Typing
71################################################################################
# Type variables for the model classes of this module. Every bound is given
# as a string (forward reference), so the classes may be defined further down.
SequentialModelType = TypeVar("SequentialModelType", bound="SequentialModel")
ADCSparseFunctionalModelType = TypeVar(
    "ADCSparseFunctionalModelType", bound="ADCSparseFunctionalModel"
)
ADCFunctionalModelTestingType = TypeVar(
    "ADCFunctionalModelTestingType", bound="ADCFunctionalModelTesting"
)
ADCFunctionalModelType = TypeVar("ADCFunctionalModelType", bound="ADCFunctionalModel")
ADCFunctionalModelSidechainReconstructionType = TypeVar(
    "ADCFunctionalModelSidechainReconstructionType",
    bound="ADCFunctionalModelSidechainReconstruction",
)
# Type alias for a tuple of either four or five tensors.
# NOTE(review): presumably the inputs of the ADC functional models in the order
# angles, central dihedrals, cartesians, distances and (optionally) side
# dihedrals - confirm against the model classes that use this alias.
ADCFunctionalModelInputType = Union[
    tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
    tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
]
100################################################################################
101# Globals
102################################################################################
# Public names of this module: only the two model factory functions are
# exported by a star-import.
__all__: list[str] = ["gen_sequential_model", "gen_functional_model"]
108################################################################################
109# Helper Classes
110################################################################################
class MyKernelInitializer(tf.keras.initializers.Initializer):
    """Custom Kernel initializer to make weights deterministic.

    Gets a numpy array called weights. When called, it checks whether the requested
    shape matches the shape of the numpy array and then returns the array.
    For example, see the documentation of `MyBiasInitializer`.

    """

    def __init__(
        self,
        weights: np.ndarray,
    ) -> None:
        """Initialize the KernelInitializer.

        Args:
            weights (np.ndarray): The weights which will be returned when the initializer
                is called.

        """
        self.weights = weights

    def __call__(
        self,
        shape: Sequence[int],
        dtype: Optional[Any] = None,
        **kwargs,
    ) -> np.ndarray:
        """Returns the weight as a float32 numpy array.

        Args:
            shape (Sequence[int]): The shape requested by the layer. It has to
                match the shape of the pre-set weights.
            dtype (Optional[Any]): Accepted for compatibility with the
                initializer call signature. It is not used; the weights are
                always cast to float32.
            **kwargs: Accepted for compatibility and ignored.

        Returns:
            np.ndarray: The weight.

        Raises:
            AssertionError: If the requested shape differs from the shape of
                the pre-set weights.

        """
        requested_shape = tuple(shape)
        # Raise explicitly instead of using an `assert` statement, so the shape
        # check is not stripped when python runs with the -O flag.
        if requested_shape != self.weights.shape:
            raise AssertionError(
                f"Can't initialize Kernel. Requested shape: {requested_shape} shape "
                f"of pre-set weights: {self.weights.shape}"
            )
        return self.weights.astype("float32")
class MyBiasInitializer(tf.keras.initializers.Initializer):
    """Custom Bias initializer to make bias deterministic.

    Gets a numpy array called bias. When called, it checks whether the requested
    shape matches the shape of the numpy array and then returns the array.

    Examples:
        >>> # Imports
        >>> from encodermap.models.models import MyBiasInitializer
        >>> import numpy as np
        >>> import tensorflow as tf
        >>> from tensorflow import keras
        >>> from tensorflow.keras import layers
        ...
        >>> # Create a model with the bias initializer
        >>> model = tf.keras.models.Sequential(
        ...     [
        ...         layers.Dense(
        ...             2,
        ...             activation="relu",
        ...             name="layer1",
        ...             bias_initializer=MyBiasInitializer(np.array([1.0, 0.5])),
        ...         ),
        ...         layers.Dense(
        ...             3,
        ...             activation="relu",
        ...             name="layer2",
        ...             bias_initializer=MyBiasInitializer(np.array([0.1, 0.2, 0.3])),
        ...         ),
        ...         layers.Dense(4, name="layer3"),
        ...     ]
        ... )
        ...
        >>> model.build(input_shape=(10, 2))
        >>> for layer in model.layers:
        ...     print(layer.get_weights()[1])
        [1.  0.5]
        [0.1 0.2 0.3]
        [0. 0. 0. 0.]
        >>> # This example fails with an AssertionError, because the
        >>> # bias shape of the second layer is wrong:
        >>> model = tf.keras.models.Sequential(
        ...     [
        ...         layers.Dense(
        ...             2,
        ...             activation="relu",
        ...             name="layer1",
        ...             bias_initializer=MyBiasInitializer(np.array([1.0, 0.5])),
        ...         ),
        ...         layers.Dense(
        ...             3,
        ...             activation="relu",
        ...             name="layer2",
        ...             bias_initializer=MyBiasInitializer(np.array([0.1, 0.2])),
        ...         ),
        ...         layers.Dense(4, name="layer3"),
        ...     ]
        ... )
        ...
        >>> model.build(input_shape=(10, 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        AssertionError: Can't initialize Bias. Requested shape: (3,) shape of pre-set bias: (2,)

    """

    def __init__(
        self,
        bias: np.ndarray,
    ) -> None:
        """Initialize the BiasInitializer.

        Args:
            bias (np.ndarray): The bias which will be returned when the initializer
                is called.

        """
        self.bias = bias

    def __call__(
        self,
        shape: Sequence[int],
        dtype: Optional[Any] = None,
        **kwargs,
    ) -> np.ndarray:
        """Returns the bias as a float32 numpy array.

        Args:
            shape (Sequence[int]): The shape requested by the layer. It has to
                match the shape of the pre-set bias.
            dtype (Optional[Any]): Accepted for compatibility with the
                initializer call signature. It is not used; the bias is
                always cast to float32.
            **kwargs: Accepted for compatibility and ignored.

        Returns:
            np.ndarray: The bias.

        Raises:
            AssertionError: If the requested shape differs from the shape of
                the pre-set bias.

        """
        requested_shape = tuple(shape)
        # Raise explicitly instead of using an `assert` statement, so the shape
        # check is not stripped when python runs with the -O flag.
        if requested_shape != self.bias.shape:
            raise AssertionError(
                f"Can't initialize Bias. Requested shape: {requested_shape} shape "
                f"of pre-set bias: {self.bias.shape}"
            )
        return self.bias.astype("float32")
251################################################################################
252# Public Functions
253################################################################################
def gen_sequential_model(
    input_shape: int,
    parameters: Optional[AnyParameters] = None,
    sparse: bool = False,
) -> SequentialModel:
    """Returns a tf.keras model with the specified input shape and the
    parameters in the Parameters class.

    Args:
        input_shape (int): The input shape of the returned model. In most cases
            that is data.shape[1] of your data.
        parameters (Optional[AnyParameters]): The parameters to use on the
            returned model. If None is provided the default parameters in
            encodermap.Parameters.defaults is used. You can look at the defaults
            with print(em.Parameters.defaults_description()). Defaults to None.
        sparse (bool): Whether sparse inputs are expected. Only forwarded to
            `SequentialModel`; it is not passed on when `ADCParameters` are
            supplied. Defaults to False.

    Returns:
        em.SequentialModel: A subclass of tf.keras.Model build with specified parameters.

    Raises:
        TypeError: If `parameters` is neither None, nor an instance of
            `Parameters` or `ADCParameters`.

    """
    if parameters is None:
        # Fall through to the regular dispatch below, so that the default
        # parameters actually produce a model instead of an implicit None.
        parameters = Parameters()

    if isinstance(parameters, Parameters):
        return SequentialModel(input_shape, parameters, sparse=sparse)
    if isinstance(parameters, ADCParameters):
        return ADCSequentialModel(input_shape, parameters)
    raise TypeError(
        f"parameters need to be encodermap.Parameters or encodermap.ADCParameters. You supplied {type(parameters)}"
    )
def _get_deterministic_random_normal(
    mean: float = 0.1,
    stddev: float = 0.05,
    seed: Optional[int] = None,
) -> tf.compat.v1.random_normal_initializer:
    """Returns a seedable random_normal_initializer from the tensorflow1 API.

    The tf2 way of getting deterministic values is to inject pre-set arrays,
    look into `MyBiasInitializer` and `MyKernelInitializer` for that.
    Moving from tf1 to tf2, the seeding method has changed, so that the same
    seed can't be used to get the same random data in tf1 and tf2.

    Args:
        mean (float): Forwarded as the mean of the initializer. Defaults to 0.1.
        stddev (float): Forwarded as the standard deviation of the initializer.
            Defaults to 0.05.
        seed (Optional[int]): Forwarded as the seed of the initializer.
            Defaults to None.

    Returns:
        tf.compat.v1.random_normal_initializer: The initializer instance.

    """
    # Third Party Imports
    import tensorflow.compat.v1 as tf_v1

    initializer = tf_v1.random_normal_initializer(mean=mean, stddev=stddev, seed=seed)
    return initializer
def _get_deterministic_variance_scaling(
    seed: Optional[int] = None,
) -> tf.compat.v1.variance_scaling_initializer:
    """Returns a seedable variance_scaling_initializer from the tensorflow1 API.

    The tf2 way of getting deterministic values is to inject pre-set arrays,
    look into `MyKernelInitializer` and `MyBiasInitializer` for that.
    Moving from tf1 to tf2, the seeding method has changed, so that the same
    seed can't be used to get the same random data in tf1 and tf2.

    Args:
        seed (Optional[int]): Forwarded as the seed of the initializer.
            Defaults to None.

    Returns:
        tf.compat.v1.variance_scaling_initializer: The initializer instance.

    """
    # Third Party Imports
    import tensorflow.compat.v1 as tf_v1

    initializer = tf_v1.variance_scaling_initializer(seed=seed)
    return initializer
# Typing-only overload: this stub declares `ADCFunctionalModelTesting` as the
# return type and lists `use_experimental_model` with a default of True.
# NOTE(review): `use_experimental_model` is annotated as a plain `bool` in all
# overloads of this function, so a static type checker presumably can't pick an
# overload based on it - confirm whether Literal[True]/Literal[False] was meant.
@overload
def gen_functional_model(
    input_shapes: Union[
        tf.data.Dataset,
        tuple[tuple[int], tuple[int], tuple[int, int], tuple[int], tuple[int]],
    ],
    parameters: Optional[ADCParameters] = None,
    sparse: bool = False,
    sidechain_only_sparse: bool = False,
    kernel_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "VarianceScaling", "deterministic"]
    ] = "VarianceScaling",
    bias_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "RandomNormal", "deterministic"]
    ] = "RandomNormal",
    write_summary: bool = True,
    use_experimental_model: bool = True,
) -> ADCFunctionalModelTesting: ...
# Typing-only overload: this stub declares `ADCFunctionalModel` as the return
# type. The shape of the cartesians (third entry of `input_shapes`) is
# annotated as a 2-tuple and `use_experimental_model` defaults to False.
# NOTE(review): the `tf.data.Dataset` alternative of `input_shapes` and the
# plain `bool` flags are shared with the other overloads, so a static type
# checker presumably can't tell them apart in every call - confirm.
@overload
def gen_functional_model(
    input_shapes: Union[
        tf.data.Dataset,
        tuple[tuple[int], tuple[int], tuple[int, int], tuple[int], tuple[int]],
    ],
    parameters: Optional[ADCParameters] = None,
    sparse: bool = False,
    sidechain_only_sparse: bool = False,
    kernel_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "VarianceScaling", "deterministic"]
    ] = "VarianceScaling",
    bias_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "RandomNormal", "deterministic"]
    ] = "RandomNormal",
    write_summary: bool = True,
    use_experimental_model: bool = False,
) -> ADCFunctionalModel: ...
# Typing-only overload: this stub declares `ADCSparseFunctionalModel` as the
# return type. In contrast to the other overloads, the shape of the cartesians
# (third entry of `input_shapes`) is annotated as a 1-tuple.
# NOTE(review): presumably this is the flattened cartesian input of the sparse
# case - confirm against the implementation. The `sparse` flag itself is a
# plain `bool` here, so it does not discriminate this overload for type checkers.
@overload
def gen_functional_model(
    input_shapes: Union[
        tf.data.Dataset,
        tuple[tuple[int], tuple[int], tuple[int], tuple[int], tuple[int]],
    ],
    parameters: Optional[ADCParameters] = None,
    sparse: bool = False,
    sidechain_only_sparse: bool = False,
    kernel_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "VarianceScaling", "deterministic"]
    ] = "VarianceScaling",
    bias_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "RandomNormal", "deterministic"]
    ] = "RandomNormal",
    write_summary: bool = True,
    use_experimental_model: bool = False,
) -> ADCSparseFunctionalModel: ...
385def gen_functional_model(
386 input_shapes: Union[
387 tf.data.Dataset,
388 tuple[
389 tuple[int],
390 tuple[int],
391 Union[tuple[int, int], tuple[int]],
392 tuple[int],
393 tuple[int],
394 ],
395 ],
396 parameters: Optional[ADCParameters] = None,
397 sparse: bool = False,
398 sidechain_only_sparse: bool = False,
399 kernel_initializer: Union[
400 dict[str, np.ndarray], Literal["ones", "VarianceScaling", "deterministic"]
401 ] = "VarianceScaling",
402 bias_initializer: Union[
403 dict[str, np.ndarray], Literal["ones", "RandomNormal", "deterministic"]
404 ] = "RandomNormal",
405 write_summary: bool = True,
406 use_experimental_model: bool = False,
407) -> Union[ADCSparseFunctionalModel, ADCFunctionalModel, ADCFunctionalModelTesting]:
408 """New implementation of the functional model API for AngleCartesianDihedralEncoderMap
410 The functional API is much more flexible than the sequential API, in that
411 models with multiple inputs and outputs can be defined. Custom layers and
412 submodels can be intermixed. In EncoderMap's case, the functional API is used to
413 build the AngleDihedralCartesianAutoencoder, which takes input data in form
414 of a tf.data.Dataset with:
415 * backbone_angles (angles between C, CA, N - atoms in the backbone).
416 * backbone_torsions (dihedral angles in the backbone,
417 commonly known as omega, phi, psi).
418 * cartesian_coordinates (coordinates of the C, CA, N backbone
419 atoms. This data has ndim 3, the other have ndim 2).
420 * backbone_distances (distances between the C, CA, N backbone atoms).
421 * sidechain_torsions (dihedral angles in the sidechain,
422 commonly known as chi1, chi2, chi3, chi4, chi5).
423 Packing and unpacking that data in the correct order is important.
424 Make sure to double-check whether you are using angles or dihedrals.
425 A simple print of the shape can be enough.
427 Args:
428 input_shapes(Union[tf.data.Dataset, tuple[int, int, int, int, int]]):
429 The input shapes, that will be used in the construction of the model.
430 parameters (Optional[encodermap.parameters.ADCParameters]): An instance
431 of `encodermap.parameters.ADCParameters`,
432 which holds further parameters in network construction. If None
433 is provided, a new instance with default parameters will be
434 created. Defaults to None.
435 sparse (bool): Whether sparse inputs are expected. Defaults to False.
436 sidechain_only_sparse (bool): A special case, when the proteins have
437 the same number of residues, but different numbers of sidechain
438 dihedrals. In that case only the sidechain dihedrals are considered
439 to be sparse. Defaults to False.
440 kernel_initializer (Union[dict[str, np.ndarray],
441 Literal["ones", "VarianceScaling", "deterministic"]]): How to initialize
442 the weights. If "ones" is provided, the weights will be initialized
443 with `tf.keras.initializers.Constant(1)`. If "VarianceScaling" is
444 provided, the weights will be initialized with `tf.keras.initializers.
445 VarianceScaling()`. Defaults to "VarianceScaling". If "deterministic"
446 is provided, a seed will be used with VarianceScaling. If a dict with
447 weight matrices is supplied, the keys should follow this naming con-
448 vention: ["dense/kernel", "dense_1/kernel", "dense_2/kernel", etc.]
449 This is tensorflow's naming convention for unnamed dense layers.
450 bias_initializer (Union[dict[str, np.ndarray],
451 Literal["ones", "RandomNormal", "deterministic"]]): How to initialize
452 the weights. If "ones" is provided, the weights will be initialized
453 with `tf.keras.initializers.Constant(1)`. If "RandomNormal" is
454 provided, the weights will be initialized with `tf.keras.initializers.
455 RandomNormal(0.1, 0.05)`. Defaults to "RandomNormal". If "deterministic"
456 is provided, a seed will be used with RandomNormal. If a dict with
457 bias matrices is supplied, the keys should follow this naming con-
458 vention: ["dense/bias", "dense_1/bias", "dense_2/bias", etc.]
459 This is tensorflow's naming convention for unnamed dense layers.
460 write_summary (bool): Whether to print a summary. If p.tensorboard is True
461 a file will be generated. at the main_path.
463 Returns:
464 tf.keras.models.Model: The model.
467 Here's a scheme of the generated network::
469 ┌───────────────────────────────────────────────────────────────────────────────────────┐
470 │A linear protein with N standard residues has N*3 backbone atoms (..C-N-CA-C-N..) │
471 │it has N*3 - 1 distances between these atoms │
472 │it has N*3 - 2 angles between three atoms │
473 │it has N*3 - 3 dihedrals between 4 atoms │
474 │it has S sidechain dihedrals based on the sequence │
475 └───────┬─────────────────┬─────────────────┬─────────────────┬─────────────────┬───────┘
476 │ │ │ │ │
477 │ │ │ │ │
478 │ │ │ │ │
479 ┌───────┴───────┐ ┌───────┴───────┐ ┌───────┴───────┐ ┌───────┴───────┐ ┌───────┴───────┐
480 │cartesians │ │distances │ │angles │ │dihedrals │ │side dihedrals │
481 │(batch, N*3, 3)│ │(batch, N*3-1) │ │(batch, N*3-2) │ │(batch, N*3-3) │ │(batch, S) ├───────┐
482 └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ └────────┬──────┘ │Every type
483 │ │ │ │ │ │of angular
484 │ │ ┌─────────┼─────────────────┼──────────────────┤ │input has
485 │ │ │ │ │ │ │its own
486 ┌───────┴───────┐ │ │ ┌───────┴───────┐ ┌───────┴───────┐ ┌────────┴──────┐ │cost contri
487 │pair cartesians│ │ ┌─────┼─┤unitcircle ang │ │unitcircle dih │ │unitcircle sdih│ │bution
488 │(batch, batch) │ │ │if no│ │(b, (N*3-2)*2) │ │(b, (N*3-3)*2) │ │(b, S*2) │ │which
489 └───────┬───────┘ │ │angles └───────┬───────┘ └───────┬───────┘ └────────┬──────┘ │compares
490 │compare the pair │ │are │ │ │ │ │input
491 │wise distances of│ │fed │ if│use_backbone_angles if│use_sidechains│and
492 │the input cartesi│ │through │ │ │ │output
493 │ans with the gene│ │the ┼ ┌───────┴─────────────────┴──────────────────┴──────┐ │->
494 │rated cartesians │ │network│concatenate the angle-inputs. Based on parameters. │ │angle_cost
495 │-> cartesian loss│ │use │ │(batch, sum(angle_shapes) │ │dihedral_cost
496 │ │ │mean │ └─────────────────────────┬─────────────────────────┘ │side_dihedral
497 │ │ │angles │ │_cost
498 │ │ │ │ │ │
499 │ │ │ │ │ │
500 │ │ │ │ ┌─────────────┴──────────────┐ │
501 │ │ │ │ │Encoder layers │ │
502 │ │ │ │ │(batch, n_neurons) │ │
503 │ │ │ │ └─────────────┬──────────────┘ │
504 │ │ │ │ │ │
505 │ │ │ │ │ │
506 │ │ │ │add a sigmoid-weighted │ add a loss function │
507 │ compare the│ │ │loss function that┌────────┴────────┐ to center the points │
508 │ ┌──────────┼─┼─────┴──────────────────┤Bottleneck,Latent├──────────────────── │
509 │ │generated │ │ compares the pair-│ (batch, 2) │ around the origin │
510 │ │cartesians│ │ wise distances of └────────┬────────┘ -> center loss │
511 │ │with the │ │ input and latent │ │
512 │ │pairwise │ │ samples │ │
513 │ │distances │ │ -> distance loss │ │
514 │ │of the │ │ ┌─────────────┴──────────────┐ │
515 │ │bottleneck│ │ │Decoder layers │ │
516 │ │use a 2nd │ │ │(batch, n_neurons) │ │
517 │ │sigmoid │ │ └─────────────┬──────────────┘ │
518 │ │function │ │ │ │
519 │ │for this │ │ │ │
520 │ │-> │ │ │ │
521 │ │cartesian │ │ ┌─────────────────────────┴─────────────────────────┐ │
522 │ │distance │ │ │split the output of the decoder to get angles back │ │
523 │ │loss │ │ │(batch, sum(angle_shapes) │ │
524 │ │ │ │ └───────┬─────────────────┬─────────────────┬───────┘ │
525 │ │ │ │ │ │ │ │
526 │ │ │ │ │ │ │ │
527 │ │ │ │ │ │ │ │
528 │ │ │ │ ┌───────┴───────┐ ┌───────┴───────┐ ┌───────┴───────┐ │
529 │ │ │ │ │unitcircle ang │ │unitcircle dih │ │unitcircle sdih│ │
530 │ │ │ │ │(b, (N*3-2)*2) │ │(b, (N*3-3)*2) │ │(b, S*2) │ │
531 │ │ │ │ └───────┬───────┘ └───────┬───────┘ └────────┬──────┘ │
532 │ │ │ │ │ │ │ │
533 │ │ │ │ if│use_backbone_angles if│use_sidechains│
534 │ │ │ │ │ │ │ │
535 │ │ │ │ ┌───────┴───────┐ ┌───────┴───────┐ ┌────────┴──────┐ │
536 │ │ │ └───────┤(mean) angles │ │dihedrals │ │side dihedrals │ │
537 │ │ │ │(batch,3N*3-2) │ │(batch,3N*3-3) │ │(batch, S) ├───────┘
538 │ │ │ └───────┬───────┘ └───────┬───────┘ └───────────────┘
539 │ │ │ │ │
540 │ │ │ │ │
541 │ │ │ │ │
542 │ │ ┌───────┴─────────────────┴─────────────────┴──────┐
543 │ │ │create new cartesians with chain-in-plane and │
544 │ │ │rotation matrices (batch, 3*N, 3) │
545 │ │ └───────┬──────────────────────────────────────────┘
546 │ │ │
547 │ │ │
548 │ │ │
549 │ │ ┌───────┴───────┐
550 │ └──┤gen pair cartes│
551 │ │(batch,batch) │
552 └─────────┴───────────────┘
554 """
555 if isinstance(input_shapes, tuple):
556 assert isinstance(input_shapes[2], tuple), (
557 f"Please provide a tuple for the shape of the cartesians, so a model "
558 f"with the correct sparse inputs can be created."
559 )
560 # if parameters are None, create a new instance of `ADCParameters`
561 p = ADCParameters() if parameters is None else parameters
563 # inform the user about tensorflow_graphics
564 if p.multimer_training is not None:
565 if p.multimer_training == "homogeneous_transform":
566 try:
567 # Third Party Imports
568 import tensorflow_graphics as tfg
569 except ModuleNotFoundError as e:
570 raise Exception(
571 f"To use the 'homogeneous_transform' multimer training, please "
572 f"install the 'tensorflow_graphics' package:\n"
573 f"`pip install tensorflow_graphics`"
574 ) from e
576 # it is important to keep track of the inputs.
577 # the inputs will always be provided in the order:
578 # angles, central_dihedrals, cartesians, distances, side_dihedrals
579 # these values will always be provided. They might not go through the
580 # network (especially the side_dihedrals), but the shape will be provided
581 # nonetheless.
582 if not p.reconstruct_sidechains:
583 (
584 angles_input_shape,
585 central_dihedrals_input_shape,
586 cartesians_input_shape,
587 distances_input_shape,
588 side_dihedrals_input_shape,
589 sparse,
590 sidechain_only_sparse,
591 ) = _unpack_and_assert_input_shapes(
592 input_shapes,
593 p,
594 sparse,
595 sidechain_only_sparse,
596 )
597 else:
598 (
599 central_angles_input_shape,
600 central_dihedrals_input_shape,
601 all_cartesians_input_shape,
602 central_distances_input_shape,
603 side_angles_input_shape,
604 side_dihedrals_input_shape,
605 side_distances_input_shape,
606 sparse,
607 sidechain_only_sparse,
608 ) = _unpack_and_assert_input_shapes_w_sidechains(
609 input_shapes,
610 p,
611 sparse,
612 sidechain_only_sparse,
613 )
615 # define the regularizer, that will be used from here on out
616 # the L2 regularizer adds a loss with 1/2 * sum(w ** 2) to
617 # each layer
618 regularizer = tf.keras.regularizers.l2(p.l2_reg_constant)
620 # define the inputs
621 (
622 input_central_dihedrals_placeholder,
623 input_central_dihedrals_unit_circle,
624 input_central_dihedrals_dense_model,
625 ) = _create_inputs_periodic_maybe_sparse(
626 central_dihedrals_input_shape,
627 p,
628 name="central_dihedrals",
629 sparse=sparse,
630 )
631 assert (
632 input_central_dihedrals_placeholder.shape[1] * 2
633 == input_central_dihedrals_unit_circle.shape[1]
634 )
636 # The split left and split right are needed for the `dihedrals_to_cartesian_tf_layers`
637 # function. In this function is a for loop, that needs to iterate over a
638 # conditional. That conditional needs to be set here, otherwise tensorflow
639 # produces a symbolic tensor with shape [None, None], which can't be used for
640 # iteration
641 # The chain of atoms is split so that 3 atoms remain in on the xy plain
642 # the left and right tails of these three atoms rotate into the z-axis to
643 # create a 3D structure. However, the number of central cartesians can
644 # be even or uneven, depending on the number of amino acids.
645 # Case 1: 3 Amino Acids (9 cartesians, 6 dihedrals):
646 # Here, the N-CA-C atoms of the 2nd amino acid remain on the xy plane.
647 # The left and right cartesians index the atoms of the first two and last
648 # two residues, respectively. Thus, they assume the shape 6 and 6.
649 # The 6 dihedrals are split evenly into 3 and 3.
650 # Case 2: 4 Amino acids (12 cartesians, 9 dihedrals):
651 # Here, the left cartesians contain 8 atoms, the right contain 7 atoms.
652 # Thus, the dihedrals are split unevenly into 5 dihedrals left, 4 right
653 # Case 3: M1-connected diUbi with 152 residues (456 cartesians, 453 dihedrals):
654 # Here, the dihedrals are split into 227 for left and 226 for right
655 # if sparse:
656 # _cartesians_input_shape = cartesians_input_shape // 3
657 # else:
658 # _cartesians_input_shape = cartesians_input_shape
659 # _split = int(int(_cartesians_input_shape) / 2)
660 # _cartesian_right = np.arange(_cartesians_input_shape)[_split - 1:]
661 # _dihedrals_right = np.arange(central_dihedrals_input_shape)[_split - 1:]
662 # _cartesian_left = np.arange(_cartesians_input_shape)[_split + 1:: -1]
663 # _dihedrals_left = np.arange(central_dihedrals_input_shape)[_split - 2:: -1]
664 # _n_left = int(_dihedrals_left.shape[-1])
665 # _n_right = int(_dihedrals_right.shape[-1])
666 if not p.reconstruct_sidechains:
667 if sparse:
668 left_split = cartesians_input_shape // 3 // 2 - 1
669 else:
670 left_split = cartesians_input_shape // 2 - 1
671 right_split = central_dihedrals_input_shape // 2
672 # assert _n_left == left_split, f"{_n_left=} {left_split=} {cartesians_input_shape=} {sparse=}"
673 # assert _n_right == right_split
675 # this input list is provided as inputs to the encoder
676 # that way if the user trains with backbone angles and side-didherals
677 # the encoder can be provided with a list/tuple of three numpy arrays
678 # with the respective values,
679 # otherwise the user would need to stack these arrays along the
680 # feature axis before feeding it to the encoder
681 encoder_input_list = [input_central_dihedrals_placeholder]
683 # backbone angles
684 # For the case of sparse and not using backbone angles, the angles will be
685 # treated as a non-periodic input.
686 if p.use_backbone_angles and not p.reconstruct_sidechains:
687 (
688 input_angles_placeholder,
689 input_angles_unit_circle,
690 input_angles_dense_model,
691 ) = _create_inputs_periodic_maybe_sparse(
692 angles_input_shape,
693 p,
694 name="angles",
695 sparse=sparse,
696 )
697 assert (
698 input_angles_placeholder.shape[1] * 2 == input_angles_unit_circle.shape[1]
699 )
700 encoder_input_list = [
701 input_angles_placeholder,
702 input_central_dihedrals_placeholder,
703 ]
704 elif p.use_backbone_angles and p.reconstruct_sidechains:
705 (
706 input_central_angles_placeholder,
707 input_central_angles_unit_circle,
708 input_central_angles_dense_model,
709 ) = _create_inputs_periodic_maybe_sparse(
710 central_angles_input_shape,
711 p,
712 name="central_angles",
713 sparse=sparse,
714 )
715 assert (
716 input_central_angles_placeholder.shape[1] * 2
717 == input_central_angles_unit_circle.shape[1]
718 )
719 encoder_input_list = [
720 input_central_angles_placeholder,
721 input_central_dihedrals_placeholder,
722 ]
723 else:
724 (
725 input_angles_placeholder,
726 input_angles_unit_circle,
727 input_angles_dense_model,
728 ) = _create_inputs_non_periodic_maybe_sparse(
729 shape=(angles_input_shape,),
730 p=p,
731 name=f"angles",
732 sparse=sparse,
733 )
735 # sidechain dihedrals
736 if p.use_sidechains:
737 # define the inputs
738 (
739 input_side_dihedrals_placeholder,
740 input_side_dihedrals_unit_circle,
741 input_side_dihedrals_dense_model,
742 ) = _create_inputs_periodic_maybe_sparse(
743 side_dihedrals_input_shape,
744 p,
745 name="side_dihedrals",
746 sparse=sparse or sidechain_only_sparse,
747 )
748 assert (
749 input_side_dihedrals_placeholder.shape[1] * 2
750 == input_side_dihedrals_unit_circle.shape[1]
751 )
752 encoder_input_list.append(input_side_dihedrals_placeholder)
753 else:
754 input_side_dihedrals_placeholder = None
755 input_side_dihedrals_unit_circle = None
756 input_side_dihedrals_dense_model = None
758 # create more input placeholders for the sidechain angles
759 if p.reconstruct_sidechains:
760 (
761 input_side_angles_placeholder,
762 input_side_angles_unit_circle,
763 input_side_angles_dense_model,
764 ) = _create_inputs_periodic_maybe_sparse(
765 side_angles_input_shape,
766 p,
767 name="side_angles",
768 sparse=sparse or sidechain_only_sparse,
769 )
770 assert (
771 input_side_dihedrals_placeholder.shape[1] * 2
772 == input_side_dihedrals_unit_circle.shape[1]
773 )
774 encoder_input_list = [
775 input_central_angles_placeholder,
776 input_central_dihedrals_placeholder,
777 input_side_angles_placeholder,
778 input_side_dihedrals_placeholder,
779 ]
781 # create input placeholders for the cartesians
782 if not p.reconstruct_sidechains:
783 if not sparse:
784 maybe_sparse_cartesian_input_shape = (cartesians_input_shape, 3)
785 else:
786 maybe_sparse_cartesian_input_shape = (cartesians_input_shape,)
787 else:
788 if not sparse:
789 maybe_sparse_cartesian_input_shape = (all_cartesians_input_shape, 3)
790 else:
791 maybe_sparse_cartesian_input_shape = (all_cartesians_input_shape,)
792 (
793 input_cartesians_placeholder,
794 input_dense_cartesians_placeholder,
795 input_cartesians_dense_model,
796 ) = _create_inputs_non_periodic_maybe_sparse(
797 shape=maybe_sparse_cartesian_input_shape,
798 p=p,
799 name="cartesians",
800 sparse=sparse,
801 reshape=3,
802 )
804 if p.reconstruct_sidechains:
805 (
806 input_central_distances_placeholder,
807 input_dense_central_distances_placeholder,
808 input_central_distances_dense_model,
809 ) = _create_inputs_non_periodic_maybe_sparse(
810 shape=(central_distances_input_shape,),
811 p=p,
812 name="central_distances",
813 sparse=sparse,
814 )
815 (
816 input_side_distances_placeholder,
817 input_dense_side_distances_placeholder,
818 input_side_distances_dense_model,
819 ) = _create_inputs_non_periodic_maybe_sparse(
820 shape=(side_distances_input_shape,),
821 p=p,
822 name="side_distances",
823 sparse=sparse,
824 )
825 else:
826 # create input placeholders for the distances
827 (
828 input_distances_placeholder,
829 input_dense_distances_placeholder,
830 input_distances_dense_model,
831 ) = _create_inputs_non_periodic_maybe_sparse(
832 shape=(distances_input_shape,), p=p, name="distances", sparse=sparse
833 )
835 # we can now create the input pairwise distances, which can be used in
836 # the case of multimer homogeneous_transformation matrices
837 input_cartesians_pairwise = PairwiseDistances(p, "input")(
838 input_dense_cartesians_placeholder
839 )
841 # flatten the input cartesians pairwise distances to pass them through the network
842 # the pairwise distances are, just like angles
843 # rotationally and translationally invariant and can thus be used for
844 # training
845 input_cartesians_pairwise_defined_shape = None
846 multimer_lengths = []
847 if p.multimer_training is not None:
848 if p.multimer_topology_classes is not None:
849 key1 = list(p.multimer_lengths.keys())[0]
850 for key, val in p.multimer_lengths.items():
851 for i, v in enumerate(val):
852 assert v == p.multimer_lengths[key1][i], (
853 f"The current model for using multiple topologies with "
854 f"multimers only supports multimers with the same number "
855 f"of residues per multiimer in all topology classes."
856 )
857 multimer_lengths = p.multimer_lengths[key1]
858 else:
859 multimer_lengths = p.multimer_lengths
860 flattened_shape = int(scipy.special.binom(sum(multimer_lengths), 2))
861 if p.multimer_training == "homogeneous_transformation":
862 input_cartesians_pairwise_defined_shape = tf.reshape(
863 input_cartesians_pairwise,
864 shape=(tf.shape(input_cartesians_pairwise)[0], flattened_shape),
865 )
866 encoder_input_list.append(input_dense_cartesians_placeholder)
868 # define the splits for the decoder output
869 # because the angular inputs are concatenated for the decoder, we want
870 # to keep track how to split them afterward
871 if not p.reconstruct_sidechains:
872 splits, encoder_input_placeholder = _concatenate_inputs(
873 p,
874 input_angles_unit_circle,
875 input_central_dihedrals_unit_circle,
876 input_side_dihedrals_unit_circle,
877 input_cartesians_pairwise_defined_shape,
878 )
879 else:
880 input_angles_unit_circle = input_central_angles_unit_circle
881 splits, encoder_input_placeholder = _concatenate_inputs_reconstruct_sidechains(
882 p,
883 input_central_angles_unit_circle,
884 input_central_dihedrals_unit_circle,
885 input_side_angles_unit_circle,
886 input_side_dihedrals_unit_circle,
887 )
889 assert encoder_input_placeholder is not None
890 assert encoder_input_list is not None
891 assert all([i is not None for i in encoder_input_list])
893 # build the encoder provide it with the encoder_input_placeholder
894 encoder_model, encoder_output_placeholder = _get_encoder_model(
895 encoder_input_placeholder,
896 p,
897 input_list=encoder_input_list,
898 kernel_initializer=kernel_initializer,
899 kernel_regularizer=regularizer,
900 bias_initializer=bias_initializer,
901 write_summary=write_summary,
902 )
904 # build the decoder to the required shape
905 (
906 decoder_model,
907 output_angles_placeholder,
908 output_central_dihedrals_placeholder,
909 output_side_dihedrals_placeholder,
910 extra_output_placeholder,
911 ) = _get_adc_decoder(
912 p,
913 splits,
914 input_angles_placeholder=input_angles_unit_circle,
915 kernel_initializer=kernel_initializer,
916 kernel_regularizer=regularizer,
917 bias_initializer=bias_initializer,
918 write_summary=write_summary,
919 input_placeholder=encoder_output_placeholder,
920 n_proteins=len(multimer_lengths),
921 )
923 # Provide the backmap layer with all it needs
924 if p.multimer_training is None and not p.reconstruct_sidechains:
925 back_cartesians = BackMapLayer(
926 left_split=left_split,
927 right_split=right_split,
928 )(
929 (
930 input_dense_distances_placeholder,
931 output_angles_placeholder,
932 output_central_dihedrals_placeholder,
933 )
934 )
935 elif p.multimer_training is None and p.reconstruct_sidechains:
936 _inputs = (
937 input_dense_central_distances_placeholder,
938 output_angles_placeholder,
939 output_central_dihedrals_placeholder,
940 input_dense_side_distances_placeholder,
941 extra_output_placeholder,
942 output_side_dihedrals_placeholder,
943 )
944 back_cartesians = BackMapLayerWithSidechains(p.sidechain_info)(_inputs)
945 else:
946 back_cartesians = BackMapLayerTransformations(multimer_lengths)(
947 (
948 input_dense_distances_placeholder,
949 output_angles_placeholder,
950 output_central_dihedrals_placeholder,
951 extra_output_placeholder,
952 )
953 )
955 # create the pairwise distances of the input cartesians and the output
956 # back-mapped cartesians
957 output_cartesians_pairwise = PairwiseDistances(p, "output")(back_cartesians)
959 # create a functional model from the inputs and outputs
960 # define the inputs
961 if not p.reconstruct_sidechains:
962 inputs = [
963 input_angles_placeholder,
964 input_central_dihedrals_placeholder,
965 input_cartesians_placeholder,
966 input_distances_placeholder,
967 ]
969 # the outputs depend on the parameters used
970 # for use_backbone_angles, the decoder output will be a list of
971 # tensors and thus needs to be unpacked
972 # else, the output is a single tensor and can't be unpacked
973 if p.use_backbone_angles:
974 outputs = [
975 *decoder_model(encoder_model(encoder_input_list)),
976 back_cartesians,
977 input_cartesians_pairwise,
978 output_cartesians_pairwise,
979 ]
980 else:
981 outputs = [
982 output_angles_placeholder,
983 decoder_model(encoder_model(encoder_input_list)),
984 back_cartesians,
985 input_cartesians_pairwise,
986 output_cartesians_pairwise,
987 ]
988 if len(splits) >= 3:
989 inputs.append(input_side_dihedrals_placeholder)
991 else:
992 inputs = [
993 input_central_angles_placeholder,
994 input_central_dihedrals_placeholder,
995 input_cartesians_placeholder,
996 input_central_distances_placeholder,
997 input_side_angles_placeholder,
998 input_side_dihedrals_placeholder,
999 input_side_distances_placeholder,
1000 ]
1001 outputs = [
1002 *decoder_model(encoder_model(encoder_input_list)),
1003 back_cartesians,
1004 input_cartesians_pairwise,
1005 output_cartesians_pairwise,
1006 ]
1008 model = ADCFunctionalModelSidechainReconstruction(
1009 parameters=p,
1010 inputs=inputs,
1011 outputs=outputs,
1012 encoder=encoder_model,
1013 decoder=decoder_model,
1014 get_dense_model_central_dihedrals=input_central_dihedrals_dense_model,
1015 get_dense_model_central_angles=input_central_angles_dense_model,
1016 get_dense_model_side_dihedrals=input_side_dihedrals_dense_model,
1017 get_dense_model_cartesians=input_cartesians_dense_model,
1018 get_dense_model_central_distances=input_central_distances_dense_model,
1019 get_dense_model_side_distances=input_side_distances_dense_model,
1020 get_dense_model_side_angles=input_side_angles_dense_model,
1021 )
1022 return model
1024 # create the final model
1025 if not sparse and not sidechain_only_sparse:
1026 if use_experimental_model:
1027 ModelClass = ADCFunctionalModelTesting
1028 else:
1029 ModelClass = ADCFunctionalModel
1030 model = ModelClass(
1031 parameters=p,
1032 inputs=inputs,
1033 outputs=outputs,
1034 encoder=encoder_model,
1035 decoder=decoder_model,
1036 )
1037 else:
1038 model = ADCSparseFunctionalModel(
1039 parameters=p,
1040 inputs=inputs,
1041 outputs=outputs,
1042 encoder=encoder_model,
1043 decoder=decoder_model,
1044 get_dense_model_central_dihedrals=input_central_dihedrals_dense_model,
1045 get_dense_model_central_angles=input_angles_dense_model,
1046 get_dense_model_side_dihedrals=input_side_dihedrals_dense_model,
1047 get_dense_model_cartesians=input_cartesians_dense_model,
1048 get_dense_model_distances=input_distances_dense_model,
1049 )
1051 # write a summary
1052 if write_summary:
1053 if p.tensorboard or p.write_summary:
1054 with Capturing() as output:
1055 model.summary()
1056 with open(p.main_path + "/complete_model_summary.txt", "w") as f:
1057 f.write("\n".join(output))
1058 else:
1059 model.summary()
1060 return model
def _unpack_and_assert_input_shapes_w_sidechains(
    input_shapes: Union[
        tf.data.Dataset,
        tuple[
            tuple[int],
            tuple[int],
            Union[tuple[int, int], tuple[int]],
            tuple[int],
            tuple[int],
            tuple[int],
            tuple[int],
        ],
    ],
    p: ADCParameters,
    input_sparse: bool = False,
    input_sidechain_only_sparse: bool = False,
) -> Union[
    tuple[int, int, int, int, int, int, int, bool, bool],
    tuple[int, int, int, int, int, bool, bool],
]:
    """Unpacks and asserts the input_shapes for the full sidechain reconstruction.

    In contrast to `_unpack_and_assert_input_shapes`, seven inputs are
    expected, because a full sidechain reconstruction will be executed.

    Args:
        input_shapes (Union[tf.data.Dataset, Sequence[tuple[int, ...]]]):
            Either a tuple/list with exactly seven shapes in the order
            (central angles, central dihedrals, all cartesians,
            central distances, side angles, side dihedrals, side distances),
            of which the first entry of every shape is used, or a batched
            `tf.data.Dataset`, whose `element_spec` holds these seven
            entries in the same order.
        p (encodermap.parameters.ADCParameters): The parameters. Only
            `p.multimer_training` is read in this function.
        input_sparse (bool): Whether sparse inputs are expected. Only used,
            when `input_shapes` is a tuple or list. For a dataset, this flag
            is derived from the dataset's `element_spec`. Defaults to False.
        input_sidechain_only_sparse (bool): Whether only the sidechain
            dihedrals are sparse. Only used, when `input_shapes` is a tuple
            or list. If `input_sparse` is True and the cartesian shape has
            two entries, this flag is overridden with True.
            Defaults to False.

    Returns:
        tuple: A tuple containing the following:
            - int: The input shape for the central angles.
            - int: The input shape for the central dihedrals.
            - int: The input shape for all cartesians.
            - int: The input shape for the central distances.
            - int: The input shape for the side angles.
            - int: The input shape for the side dihedrals.
            - int: The input shape for the side distances.
            - bool: Whether the inputs are treated as sparse.
            - bool: Whether only the sidechain inputs are treated as sparse.

        If `p.multimer_training` is not None, the 7-tuple of
        `_unpack_and_assert_input_shapes_multimers` is returned instead.

    Raises:
        AssertionError: If the shapes are inconsistent with each other, if
            a tuple/list with a length other than seven is provided, or if
            multimer training is requested together with sparse inputs.
        Exception: If a dataset is provided, that is not batched or that
            has fewer than seven elements.

    """
    if p.multimer_training is not None:
        assert (
            not input_sparse
        ), f"Using multimers currently not possible with sparse and/or full sidechain reconstruction."
        return _unpack_and_assert_input_shapes_multimers(input_shapes, p)

    if isinstance(input_shapes, (tuple, list)):
        assert len(input_shapes) == 7
        (
            central_angles_input_shape,
            central_dihedrals_input_shape,
            all_cartesians_input_shape,
            central_distances_input_shape,
            side_angles_input_shape,
            side_dihedrals_input_shape,
            side_distances_input_shape,
        ) = [i[0] for i in input_shapes]
        # a rank 2 cartesian shape together with sparse inputs marks the
        # sidechain-only-sparse case
        if input_sparse and len(input_shapes[2]) == 2:
            sidechain_only_sparse = True
        else:
            sidechain_only_sparse = input_sidechain_only_sparse
        sparse = input_sparse
    else:
        d = input_shapes.element_spec

        # all dense
        if not any([isinstance(i, tf.SparseTensorSpec) for i in d]):
            sparse = False
            sidechain_only_sparse = False
        # only the last element is sparse
        elif all(
            [not isinstance(i, tf.SparseTensorSpec) for i in d[:-1]]
        ) and isinstance(d[-1], tf.SparseTensorSpec):
            sparse = False
            sidechain_only_sparse = True
        # other stuff sparse
        else:
            sparse = True
            sidechain_only_sparse = False

        # check if dataset is batched
        try:
            central_angles_input_shape = d[0].shape[1]
        except IndexError as e:
            raise Exception(
                f"You probably provided a tf.data.Dataset, that is not batched "
                f"and thus an index error was raised."
            ) from e

        # define shapes
        central_dihedrals_input_shape = d[1].shape[1]
        all_cartesians_input_shape = d[2].shape[1]
        central_distances_input_shape = d[3].shape[1]
        # all three sidechain entries are read inside the try block, so that
        # a dataset without any sidechain entries also gets the informative
        # error and not a bare IndexError
        try:
            side_angles_input_shape = d[4].shape[1]
            side_dihedrals_input_shape = d[5].shape[1]
            side_distances_input_shape = d[6].shape[1]
        except IndexError as e:
            raise Exception(
                f"Not enough items in tuple for sidechain reconstruction."
            ) from e

    # make sure that the inputs have had the correct order
    # because a protein with N residues has N*3 - 1 central distances,
    # N*3 - 2 central angles, and N*3 - 3 central dihedrals
    # N stays a float here; a central distance shape that is not of the
    # form N*3 - 1 makes the comparisons below fail
    N = (central_distances_input_shape + 1) / 3
    if not sparse or sidechain_only_sparse:
        assert all_cartesians_input_shape == N * 3 + side_distances_input_shape
    # sparse tensors have to be rank 2, so the sparse cartesians need to be
    # flattened, and then stacked back, once they are dense again
    # as of tf >= 2.16 sparse tensors can have a higher rank
    # maybe this is worth updating
    else:
        assert all_cartesians_input_shape // 3 == N * 3 + side_distances_input_shape
    assert central_angles_input_shape == N * 3 - 2
    assert central_dihedrals_input_shape == central_angles_input_shape - 1
    assert side_dihedrals_input_shape < side_angles_input_shape

    return (
        central_angles_input_shape,
        central_dihedrals_input_shape,
        all_cartesians_input_shape,
        central_distances_input_shape,
        side_angles_input_shape,
        side_dihedrals_input_shape,
        side_distances_input_shape,
        sparse,
        sidechain_only_sparse,
    )
@testing
def _unpack_and_assert_input_shapes_multimers(
    input_shapes: Union[
        tf.data.Dataset,
        tuple[
            tuple[int],
            tuple[int],
            tuple[int, int],
            tuple[int],
            tuple[int],
        ],
    ],
    p: ADCParameters,
) -> tuple[int, int, int, int, int, bool, bool]:
    """Unpacks and asserts the input_shapes for the multimer case.

    Args:
        input_shapes (Union[tf.data.Dataset, Sequence[tuple[int, ...]]]):
            Either five shapes in the order (angles, central dihedrals,
            cartesians, distances, side dihedrals), of which the first entry
            of every shape is used, or a `tf.data.Dataset`. For a dataset,
            the first element is fetched and the shapes of its tensors
            (without the leading batch axis) are used.
        p (encodermap.parameters.ADCParameters): The parameters. This
            function reads `use_backbone_angles`, `use_sidechains`,
            `multimer_topology_classes`, and `multimer_lengths`.

    Returns:
        tuple: A tuple containing the following:
            - int: The input shape for the training angles.
            - int: The input shape for the training dihedrals.
            - int: The input shape for the cartesians.
            - int: The input shape for the distances.
            - int: The input shape for the training sidechain dihedrals.
            - bool: Always False (sparse).
            - bool: Always False (sidechain_only_sparse).

    Raises:
        Exception: If `p.use_backbone_angles` or `p.use_sidechains` is not
            set, or if the provided dataset does not yield any element.
        AssertionError: If the topology classes contain differing numbers
            of proteins, or if the shapes are inconsistent with each other.

    """
    if not p.use_backbone_angles:
        raise Exception(
            f"Training with multimers currently only possible with backbone_angles"
        )
    if not p.use_sidechains:
        raise Exception(
            f"Training with multimers currently only possible with use_sidechains"
        )

    if p.multimer_topology_classes is not None:
        lengths = [len(p.multimer_lengths[top]) for top in p.multimer_topology_classes]
        assert len(set(lengths)) == 1, (
            f"Can only use topology-class multimer training with a consistent "
            f"number of proteins per multimer. Got "
            f"{[len(p.multimer_lengths[top]) for top in p.multimer_topology_classes]=}"
        )
        n_proteins = lengths[0]
    else:
        n_proteins = len(p.multimer_lengths)

    if isinstance(input_shapes, tf.data.Dataset):
        # only the first element is needed to read the shapes
        for first_element in input_shapes:
            break
        else:
            # an empty dataset would otherwise leave the loop variable unbound
            raise Exception(
                "Can't infer the input shapes for multimer training from an "
                "empty tf.data.Dataset."
            )
        # drop the leading (batch) axis of every tensor
        input_shapes = tuple(i.shape[1:] for i in first_element)

    (
        angles_input_shape,
        central_dihedrals_input_shape,
        cartesians_input_shape,
        distances_input_shape,
        side_dihedrals_input_shape,
    ) = [i[0] for i in input_shapes]

    # every protein is expected to contribute N*3 cartesians, N*3 - 1
    # distances, N*3 - 2 angles, and N*3 - 3 dihedrals
    N = cartesians_input_shape // 3 // n_proteins
    assert distances_input_shape == n_proteins * (N * 3 - 1)
    assert angles_input_shape == n_proteins * (N * 3 - 2)
    assert central_dihedrals_input_shape == n_proteins * (N * 3 - 3)

    return (
        angles_input_shape,
        central_dihedrals_input_shape,
        cartesians_input_shape,
        distances_input_shape,
        side_dihedrals_input_shape,
        False,
        False,
    )
def _unpack_and_assert_input_shapes(
    input_shapes: Union[
        tf.data.Dataset,
        tuple[
            tuple[int],
            tuple[int],
            Union[tuple[int, int], tuple[int]],
            tuple[int],
            tuple[int],
        ],
    ],
    p: ADCParameters,
    input_sparse: bool = False,
    input_sidechain_only_sparse: bool = False,
) -> tuple[int, int, int, int, Union[int, None], bool, bool]:
    """Unpacks and asserts the input_shapes for the regular protein case.

    Args:
        input_shapes (Union[tf.data.Dataset, Sequence[tuple[int, ...]]]):
            Either a tuple/list of shapes in the order (angles,
            central dihedrals, cartesians, distances[, side dihedrals]), of
            which the first entry of every shape is used, or a batched
            `tf.data.Dataset`, whose `element_spec` holds the entries in
            the same order.
        p (encodermap.parameters.ADCParameters): The parameters. Only
            `p.multimer_training` is read in this function.
        input_sparse (bool): Whether sparse inputs are expected. Only used,
            when `input_shapes` is a tuple or list. For a dataset, the flag
            is derived from the dataset's `element_spec`. Defaults to False.
        input_sidechain_only_sparse (bool): Whether only the sidechain
            dihedrals are sparse. Only used, when `input_shapes` is a tuple
            or list. If `input_sparse` is True and the cartesian shape has
            two entries, this flag is overridden with True.
            Defaults to False.

    Returns:
        tuple: A tuple containing the following:
            - int: The input shape for the training angles.
            - int: The input shape for the training dihedrals.
            - int: The input shape for the cartesians.
            - int: The input shape for the distances.
            - Union[int, None]: The input shape for the training sidechain
                dihedrals. None, if no fifth entry was provided.
            - bool: Whether the inputs are treated as sparse.
            - bool: Whether only the sidechain dihedrals are treated
                as sparse.

        If `p.multimer_training` is not None, the result of
        `_unpack_and_assert_input_shapes_multimers` is returned instead.

    """
    # multimers are delegated to their own function
    if p.multimer_training is not None:
        assert not input_sparse, f"Using multimers currently not possible with sparse."
        return _unpack_and_assert_input_shapes_multimers(input_shapes, p)

    if isinstance(input_shapes, (tuple, list)):
        leading_dims = [shape[0] for shape in input_shapes]
        if len(input_shapes) == 5:
            n_angles, n_dihedrals, n_cartesians, n_distances, n_side_dihedrals = (
                leading_dims
            )
        else:
            n_angles, n_dihedrals, n_cartesians, n_distances = leading_dims
            n_side_dihedrals = None

        sparse = input_sparse
        # sparse inputs with rank 2 cartesians mark sidechain-only sparsity
        cartesians_are_rank_2 = input_sparse and len(input_shapes[2]) == 2
        sidechain_only_sparse = (
            True if cartesians_are_rank_2 else input_sidechain_only_sparse
        )
    else:
        d = input_shapes.element_spec
        is_sparse_spec = [isinstance(spec, tf.SparseTensorSpec) for spec in d]

        if not any(is_sparse_spec):
            # all dense
            sparse, sidechain_only_sparse = False, False
        elif not any(is_sparse_spec[:-1]) and is_sparse_spec[-1]:
            # only the last element is sparse
            sparse, sidechain_only_sparse = False, True
        else:
            # other stuff sparse
            sparse, sidechain_only_sparse = True, False

        # an unbatched dataset has no second axis to index into
        try:
            n_angles = d[0].shape[1]
        except IndexError as e:
            raise Exception(
                f"You probably provided a tf.data.Dataset, that is not batched "
                f"and thus an index error was raised."
            ) from e

        n_dihedrals = d[1].shape[1]
        try:
            n_cartesians = d[2].shape[1]
        except IndexError as e:
            raise Exception(
                f"Could not decide on a cartesian input shape for the requested "
                f"model using the provided dataset with {d=}. Normally, "
                f"it is expected for index 2 of this dataset to provide the "
                f"input shape of the cartesian coordinates. However, an "
                f"IndexError was raised, trying to access this index. "
            ) from e
        n_distances = d[3].shape[1]
        n_side_dihedrals = d[4].shape[1] if len(d) > 4 else None

    # verify the order of the inputs: a protein with N residues has
    # N*3 cartesians, N*3 - 1 distances, N*3 - 2 angles, and N*3 - 3 dihedrals
    N = n_cartesians // 3
    if sparse and not sidechain_only_sparse:
        # sparse tensors have to be rank 2, so the sparse cartesians arrive
        # flattened and carry three values per atom
        N = N // 3
    assert (
        n_distances == N * 3 - 1
    ), f"{N=} {sparse=} {sidechain_only_sparse=}"
    assert n_angles == N * 3 - 2
    assert n_dihedrals == N * 3 - 3

    return (
        n_angles,
        n_dihedrals,
        n_cartesians,
        n_distances,
        n_side_dihedrals,
        sparse,
        sidechain_only_sparse,
    )
def _get_adc_decoder(
    p: ADCParameters,
    splits: list[int],
    input_angles_placeholder: Optional[tf.Tensor] = None,
    kernel_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "VarianceScaling", "deterministic"]
    ] = "VarianceScaling",
    kernel_regularizer: tf.keras.regularizers.Regularizer = tf.keras.regularizers.l2(
        0.001
    ),
    bias_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "RandomNormal", "deterministic"]
    ] = "RandomNormal",
    write_summary: bool = False,
    input_placeholder: Optional[tf.Tensor] = None,
    n_proteins: Optional[int] = None,
) -> tuple[
    tf.keras.models.Model,
    tf.Tensor,
    tf.Tensor,
    Union[None, tf.Tensor],
    Union[None, tf.Tensor],
]:
    """Special function to run a decoder and unpack the outputs.

    This function calls `_get_decoder_model` to get a standard decoder and then
    splits the output according to the provided `splits` and the `p`.

    Args:
        p (encodermap.parameters.ADCParameters): The parameters.
        splits (list[int]): A list of ints giving the splits of the decoder
            outputs. It is expected that the splits follow the logic of
            angles-dihedrals-sidedihedrals. If only dihedrals are used for
            training, `splits` is expected to be a list of len 1. Note, that
            for multimer training, the last entry of this list is
            overwritten in place with `(n_proteins - 1) * 4 * 4`.
        input_angles_placeholder (Optional[tf.Tensor]): When only using dihedrals
            for training, this placeholder should be provided to create a
            set of mean angles. Can also be None, in case len(splits) >= 2.
        kernel_initializer (Union[dict[str, np.ndarray],
            Literal["ones", "VarianceScaling", "deterministic"]]): How to initialize
            the weights. If "ones" is provided, the weights will be initialized
            with `tf.keras.initializers.Constant(1)`. If "VarianceScaling" is
            provided, the weights will be initialized with `tf.keras.initializers.
            VarianceScaling()`. Defaults to "VarianceScaling". If "deterministic"
            is provided, a seed will be used with VarianceScaling. If a dict with
            weight matrices is supplied, the keys should follow this naming con-
            vention: ["dense/kernel", "dense_1/kernel", "dense_2/kernel", etc.]
            This is tensorflow's naming convention for unnamed dense layers.
        kernel_regularizer (tf.keras.regularizers.Regularizer): The regularizer
            for the kernel (i.e. the layer weights). Standard in EncoderMap is
            to use the l2 regularizer with a regularization constant of 0.001.
        bias_initializer (Union[dict[str, np.ndarray],
            Literal["ones", "RandomNormal", "deterministic"]]): How to initialize
            the biases. If "ones" is provided, the biases will be initialized
            with `tf.keras.initializers.Constant(1)`. If "RandomNormal" is
            provided, the biases will be initialized with `tf.keras.initializers.
            RandomNormal(0.1, 0.05)`. Defaults to "RandomNormal". If "deterministic"
            is provided, a seed will be used with RandomNormal. If a dict with
            bias matrices is supplied, the keys should follow this naming con-
            vention: ["dense/bias", "dense_1/bias", "dense_2/bias", etc.]
            This is tensorflow's naming convention for unnamed dense layers.
        write_summary (bool): Whether to print a summary. If p.tensorboard
            or p.write_summary is True, the summary is written to the file
            "decoder_summary.txt" at the main_path instead.
        input_placeholder (Optional[tf.Tensor]): The tensor, that is forwarded
            to `_get_decoder_model` as the decoder input.
        n_proteins (Optional[int]): If not None, number of proteins that
            constitute the multimer group that is trained. Needs to be
            provided, if `p.multimer_training` is not None.

    Returns:
        tuple: A tuple containing the following:
            - tf.keras.models.Model: The decoder model.
            - tf.Tensor: The angles (either mean, or learned angles).
            - tf.Tensor: The dihedrals.
            - Union[None, tf.Tensor]: The sidechain dihedrals. None, if
                `splits` has fewer than three entries.
            - Union[None, tf.Tensor]: An extra tensor. For multimer training,
                these are the decoder outputs reshaped to
                (batch, n_proteins - 1, 4, 4). For full sidechain
                reconstruction, these are the sidechain angles. None, if
                `splits` has fewer than four entries.

    Raises:
        Exception: If `splits` does not have 1 to 4 entries, if four splits
            are provided without multimer training or sidechain
            reconstruction, if multimer training is combined with sidechain
            reconstruction, if `n_proteins` is missing for multimer training,
            or if the keras model can't be built from the outputs.
        AssertionError: If `splits` and `p` do not coincide.

    """
    if len(splits) == 2:
        assert splits[0] - 2 == splits[1], (
            f"Order of splits is wrong. It is expected, that the splits for "
            f"angles at splits[0], is splits two larger than the central dihedrals "
            f"at splits[1]. However, {splits=}"
        )
    # get the standard decoder and its inputs and outputs.
    if p.multimer_training is not None:
        if n_proteins is None:
            raise Exception(
                f"Multimer training requires `n_proteins`, but got {n_proteins=}."
            )
        # every protein but the first gets a 4x4 matrix
        # this overwrites the last entry of the caller's list in place
        splits[-1] = (n_proteins - 1) * 4 * 4
    decoder, output_placeholder, input_placeholder = _get_decoder_model(
        p=p,
        out_shape=sum(splits),
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_initializer=bias_initializer,
        write_summary=False,
        input_placeholder=input_placeholder,
    )

    # fmt: off
    splits_side_dihedrals = None
    extra_tensor = None
    if len(splits) == 1:
        assert not p.use_backbone_angles and not p.use_sidechains, f"Parameters and splits do not coincide: {p=}, {splits=}"
        splits_central_dihedrals = PeriodicOutput(p, "dihedrals_from_unit_circle")(output_placeholder)
        # the angles are not learned in this case
        splits_angles = MeanAngles(p, "mean_angles")(input_angles_placeholder)
        decoder_output = splits_central_dihedrals
    elif len(splits) == 2:
        assert p.use_backbone_angles and not p.use_sidechains, f"Parameters and splits do not coincide: {p=}, {splits=}"
        out_angles, out_dihedrals = tf.keras.layers.Lambda(lambda tensor: tf.split(tensor, splits, 1))(output_placeholder)
        splits_angles = PeriodicOutput(p, "angles_from_unit_circle")(out_angles)
        splits_central_dihedrals = PeriodicOutput(p, "dihedrals_from_unit_circle")(out_dihedrals)
        decoder_output = (splits_angles, splits_central_dihedrals)
    elif len(splits) == 3:
        assert p.use_backbone_angles and p.use_sidechains, f"Parameters and splits do not coincide: {p=}, {splits=}"
        out_angles, out_dihedrals, out_side_dihedrals = tf.keras.layers.Lambda(lambda tensor: tf.split(tensor, splits, 1))(output_placeholder)
        splits_angles = PeriodicOutput(p, "angles_from_unit_circle")(out_angles)
        splits_central_dihedrals = PeriodicOutput(p, "dihedrals_from_unit_circle")(out_dihedrals)
        splits_side_dihedrals = PeriodicOutput(p, "side_dihedrals_from_unit_circle")(out_side_dihedrals)
        decoder_output = (splits_angles, splits_central_dihedrals, splits_side_dihedrals)
    elif len(splits) == 4:
        if p.multimer_training is None and not p.reconstruct_sidechains:
            raise Exception(f"Got wrong splits: {splits=}")
        if p.multimer_training is not None and p.reconstruct_sidechains:
            # without this guard, none of the branches below would define
            # the outputs and an UnboundLocalError would be raised
            raise Exception(
                f"Multimer training can't be combined with full sidechain "
                f"reconstruction. Got {p.multimer_training=} and "
                f"{p.reconstruct_sidechains=}."
            )
        if p.multimer_training is not None:
            out_angles, out_dihedrals, out_side_dihedrals, out_transformation_matrices = tf.keras.layers.Lambda(lambda tensor: tf.split(tensor, splits, 1))(output_placeholder)
            splits_angles = PeriodicOutput(p, "angles_from_unit_circle")(out_angles)
            splits_central_dihedrals = PeriodicOutput(p, "dihedrals_from_unit_circle")(out_dihedrals)
            splits_side_dihedrals = PeriodicOutput(p, "side_dihedrals_from_unit_circle")(out_side_dihedrals)
            extra_tensor = tf.reshape(
                out_transformation_matrices,
                shape=(tf.shape(splits_angles)[0], n_proteins - 1, 4, 4)
            )
        else:
            out_angles, out_dihedrals, out_side_angles, out_side_dihedrals = tf.keras.layers.Lambda(lambda tensor: tf.split(tensor, splits, 1))(output_placeholder)
            splits_angles = PeriodicOutput(p, "central_angles_from_unit_circle")(out_angles)
            splits_central_dihedrals = PeriodicOutput(p, "central_dihedrals_from_unit_circle")(out_dihedrals)
            extra_tensor = PeriodicOutput(p, "side_angles_from_unit_circle")(out_side_angles)
            splits_side_dihedrals = PeriodicOutput(p, "side_dihedrals_from_unit_circle")(out_side_dihedrals)
        # NOTE(review): for multimer training, this places the reshaped
        # matrices at index 2 of the model outputs, while the sidechain
        # assertion further down checks index 2 against splits[2] - confirm
        decoder_output = (splits_angles, splits_central_dihedrals, extra_tensor, splits_side_dihedrals)
    else:
        raise Exception(f"Got wrong splits: {splits=}")
    # fmt: on

    # create the model
    try:
        model = tf.keras.models.Model(
            inputs=input_placeholder,
            outputs=decoder_output,
            name="Decoder",
        )
    except ValueError as e:
        raise Exception(f"{splits=}, {decoder_output=}") from e

    # assert that the sequence of outputs is correct
    assert model.input_shape[1] == p.n_neurons[-1]
    if p.use_backbone_angles:
        # with backbone angles, the output shape is a sequence of shapes and
        # the splits are twice as large as the respective outputs
        assert model.output_shape[0][1] * 2 == splits[0]
        assert model.output_shape[1][1] * 2 == splits[1]
    else:
        # with only dihedrals, the model has a single output, and the first
        # entry of its shape is the batch axis, which is not known here
        assert model.output_shape[0] is None
        # NOTE(review): the previous code evaluated
        # `model.output_shape[-1] == splits[0]` at this point and discarded
        # the result. It is not promoted to an assert here, because the
        # checks above compare twice the output shape with the split.
    if p.use_sidechains:
        assert model.output_shape[2][1] * 2 == splits[2]
    else:
        assert len(model.output_shape) == 2

    # write a summary
    if write_summary:
        if p.tensorboard or p.write_summary:
            with Capturing() as output:
                model.summary()
            with open(p.main_path + "/decoder_summary.txt", "w") as f:
                f.write("\n".join(output))
        else:
            model.summary()

    return (
        model,
        splits_angles,
        splits_central_dihedrals,
        splits_side_dihedrals,
        extra_tensor,
    )
1595def _get_decoder_model(
1596 p: ADCParameters,
1597 out_shape: int,
1598 kernel_initializer: Union[
1599 dict[str, np.ndarray], Literal["ones", "VarianceScaling", "deterministic"]
1600 ] = "VarianceScaling",
1601 kernel_regularizer: tf.keras.regularizers.Regularizer = tf.keras.regularizers.l2(
1602 0.001
1603 ),
1604 bias_initializer: Union[
1605 dict[str, np.ndarray], Literal["ones", "RandomNormal", "deterministic"]
1606 ] = "RandomNormal",
1607 write_summary: bool = False,
1608 input_placeholder: Optional[tf.Tensor] = None,
1609) -> tuple[tf.keras.models.Model, tf.Tensor, tf.Tensor]:
1610 """Create a decoder to the requested specs.
1612 Contrary to the `_get_encoder_model` function, this function doesn't require
1613 an input placeholder. The input placeholder is created in the function body.
1614 Thus, a combined autoencoder model can be built by stacking the encoder and
1615 decoder like so: `output = decoder(encoder(input))`.
1617 Args:
1618 p (encodermap.parameters.ADCParameters): The parameters.
1619 out_shape (int): The output shape of the decoder. Make sure to match it
1620 with the input shape of the encoder.
1621 kernel_initializer (Union[dict[str, np.ndarray],
1622 Literal["ones", "VarianceScaling", "deterministic"]]): How to initialize
1623 the weights. If "ones" is provided, the weights will be initialized
1624 with `tf.keras.initializers.Constant(1)`. If "VarianceScaling" is
1625 provided, the weights will be initialized with `tf.keras.initializers.
1626 VarianceScaling()`. Defaults to "VarianceScaling". If "deterministic"
1627 is provided, a seed will be used with VarianceScaling. If a dict with
1628 weight matrices is supplied, the keys should follow this naming con-
1629 vention: ["dense/kernel", "dense_1/kernel", "dense_2/kernel", etc.]
1630 This is tensorflow's naming convention for unnamed dense layers.
1631 kernel_regularizer (tf.keras.regularizers.Regularizer): The regularizer
1632 for the kernel (i.e. the layer weights). Standard in EncoderMap is
1633 to use the l2 regularizer with a regularization constant of 0.001.
1634 bias_initializer (Union[dict[str, np.ndarray],
1635 Literal["ones", "RandomNormal", "deterministic"]]): How to initialize
1636 the weights. If "ones" is provided, the weights will be initialized
1637 with `tf.keras.initializers.Constant(1)`. If "RandomNormal" is
1638 provided, the weights will be initialized with `tf.keras.initializers.
1639 RandomNormal(0.1, 0.05)`. Defaults to "RandomNormal". If "deterministic"
1640 is provided, a seed will be used with RandomNormal. If a dict with
1641 bias matrices is supplied, the keys should follow this naming con-
1642 vention: ["dense/bias", "dense_1/bias", "dense_2/bias", etc.]
1643 This is tensorflow's naming convention for unnamed dense layers.
1644 write_summary (bool): Whether to print a summary. If p.tensorboard is True
1645 a file will be generated. at the main_path.
1647 Returns:
1648 tuple: A tuple containing the following:
1649 - tf.keras.models.Model: The decoder model.
1650 - tf.Tensor: The output tensor with shape `out_shape`.
1651 - tf.Tensor: The input placeholder tensor with shape `p.n_neurons`.
1653 """
1654 n_neurons_with_inputs = [out_shape] + p.n_neurons
1656 # generate a new placeholder
1657 # this way, the decoder is can be created as a detached model, if no
1658 # input placeholder is provided
1659 if input_placeholder is None:
1660 inp = Input(shape=(p.n_neurons[-1],), name="decoder_input")
1661 else:
1662 inp = input_placeholder
1663 assert inp.shape[1] == p.n_neurons[-1], (
1664 f"The input shape of the decoder does not match the requested input "
1665 f"shape. I got an input shape of {inp.shape[1]=}, while parameters "
1666 f"requested {p.n_neurons[-1]=}."
1667 )
1669 out = inp
1670 for i, (n_neurons, act_fun) in enumerate(
1671 zip(n_neurons_with_inputs[-2::-1], p.activation_functions[-2::-1])
1672 ):
1673 if act_fun:
1674 act_fun = getattr(tf.nn, act_fun)
1675 else:
1676 act_fun = None
1677 if isinstance(kernel_initializer, str):
1678 if kernel_initializer == "VarianceScaling":
1679 _kernel_initializer = tf.keras.initializers.VarianceScaling()
1680 elif kernel_initializer == "deterministic":
1681 seed = 121110987654321 + i
1682 _kernel_initializer = _get_deterministic_variance_scaling(seed=seed)
1683 elif kernel_initializer == "ones":
1684 _kernel_initializer = tf.keras.initializers.Constant(1)
1685 else:
1686 raise Exception(
1687 f"Keyword `kernel_initializer` only supports 'VarianceScaling' "
1688 f", 'ones', or 'deterministic'. Got {kernel_initializer=}"
1689 )
1690 elif isinstance(kernel_initializer, dict):
1691 kernel_name = f"dense_{i + len(p.n_neurons)}/kernel"
1692 _kernel_initializer = MyKernelInitializer(kernel_initializer[kernel_name])
1693 else:
1694 raise TypeError(
1695 f"Arg `kernel_initializer` must be of type str or dict, "
1696 f"you supplied {type(kernel_initializer)=}."
1697 )
1698 if isinstance(bias_initializer, str):
1699 if bias_initializer == "RandomNormal":
1700 _bias_initializer = tf.keras.initializers.RandomNormal(0.1, 0.05)
1701 elif bias_initializer == "deterministic":
1702 seed = 121110987654321 + i
1703 _bias_initializer = _get_deterministic_random_normal(
1704 0.1, 0.05, seed=seed
1705 )
1706 elif bias_initializer == "ones":
1707 _bias_initializer = tf.keras.initializers.Constant(1)
1708 else:
1709 raise Exception(
1710 f"Keyword `bias_initializer` only supports 'RandomNormal' "
1711 f", 'ones' or 'deterministic'. Got {bias_initializer=}"
1712 )
1713 elif isinstance(bias_initializer, dict):
1714 bias_name = f"dense_{i + len(p.n_neurons)}/bias"
1715 _bias_initializer = MyBiasInitializer(bias_initializer[bias_name])
1716 else:
1717 raise TypeError(
1718 f"Arg `bias_initializer` must be of type str or dict, "
1719 f"you supplied {type(bias_initializer)=}."
1720 )
1721 out = tf.keras.layers.Dense(
1722 units=n_neurons,
1723 activation=act_fun,
1724 name=f"Decoder_{i}",
1725 kernel_initializer=_kernel_initializer,
1726 kernel_regularizer=kernel_regularizer,
1727 bias_initializer=_bias_initializer,
1728 )(out)
1730 # create the model
1731 model = tf.keras.models.Model(
1732 inputs=inp,
1733 outputs=out,
1734 name="Decoder",
1735 )
1737 # check the model input and output
1738 model.input_shape[1] == p.n_neurons[-1]
1739 model.output_shape[1] == out_shape
1741 if write_summary:
1742 if p.tensorboard or p.write_summary:
1743 with Capturing() as output:
1744 model.summary()
1745 with open(p.main_path + "/decoder_summary.txt", "w") as f:
1746 f.write("\n".join(output))
1747 else:
1748 model.summary()
1750 return model, out, inp
def _get_encoder_model(
    inp: tf.Tensor,
    p: "AnyParameters",
    input_list: list[tf.Tensor],
    kernel_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "VarianceScaling", "deterministic"]
    ] = "VarianceScaling",
    kernel_regularizer: tf.keras.regularizers.Regularizer = tf.keras.regularizers.l2(
        0.001
    ),
    bias_initializer: Union[
        dict[str, np.ndarray], Literal["ones", "RandomNormal", "deterministic"]
    ] = "RandomNormal",
    write_summary: bool = False,
) -> tuple[tf.keras.models.Model, tf.Tensor]:
    """Create an encoder model and feed the inp through it.

    One `Dense` layer named ``Encoder_{i}`` is stacked per entry in
    `p.n_neurons`. The activation of layer ``i`` is looked up on `tf.nn` by
    the name stored in `p.activation_functions[i + 1]`; a falsy entry means
    no activation.

    Args:
        inp (tf.Tensor): The input tensor of the encoder.
        p (encodermap.parameters.ADCParameters): The parameters.
        input_list (list[tf.Tensor]): This list contains the input placeholders
            for the encoder. Make sure that these input tensors point to the
            `inp` tensor in some way.
        kernel_initializer (Union[dict[str, np.ndarray],
            Literal["ones", "VarianceScaling", "deterministic"]]): How to initialize
            the weights. If "ones" is provided, the weights will be initialized
            with `tf.keras.initializers.Constant(1)`. If "VarianceScaling" is
            provided, the weights will be initialized with `tf.keras.initializers.
            VarianceScaling()`. Defaults to "VarianceScaling". If "deterministic"
            is provided, a seed will be used with VarianceScaling. If a dict with
            weight matrices is supplied, the keys should follow this naming con-
            vention: ["dense/kernel", "dense_1/kernel", "dense_2/kernel", etc.]
            This is tensorflow's naming convention for unnamed dense layers.
        kernel_regularizer (tf.keras.regularizers.Regularizer): The regularizer
            for the kernel (i.e. the layer weights). Standard in EncoderMap is
            to use the l2 regularizer with a regularization constant of 0.001.
        bias_initializer (Union[dict[str, np.ndarray],
            Literal["ones", "RandomNormal", "deterministic"]]): How to initialize
            the biases. If "ones" is provided, the biases will be initialized
            with `tf.keras.initializers.Constant(1)`. If "RandomNormal" is
            provided, the biases will be initialized with `tf.keras.initializers.
            RandomNormal(0.1, 0.05)`. Defaults to "RandomNormal". If "deterministic"
            is provided, a seed will be used with RandomNormal. If a dict with
            bias matrices is supplied, the keys should follow this naming con-
            vention: ["dense/bias", "dense_1/bias", "dense_2/bias", etc.]
            This is tensorflow's naming convention for unnamed dense layers.
        write_summary (bool): Whether to print a summary. If `p.tensorboard` or
            `p.write_summary` is True, the summary is written to
            `p.main_path + "/encoder_summary.txt"` instead of being printed.

    Returns:
        tuple: A tuple containing:
            - tf.keras.models.Model: The encoder model.
            - tf.Tensor: The output of the model.

    Raises:
        Exception: If `kernel_initializer` or `bias_initializer` is a str
            that is not one of the supported keywords.
        TypeError: If `kernel_initializer` or `bias_initializer` is neither
            a str nor a dict.
        AssertionError: If the shapes of the built model do not match the
            shapes requested via `p` and `input_list`.

    """
    out = inp
    for i, (n_neurons, act_fun) in enumerate(
        zip(p.n_neurons, p.activation_functions[1:])
    ):
        # define the activation function for this dense layer
        act_fun = getattr(tf.nn, act_fun) if act_fun else None

        # get the kernel initializer for that layer
        if isinstance(kernel_initializer, str):
            if kernel_initializer == "VarianceScaling":
                _kernel_initializer = tf.keras.initializers.VarianceScaling()
            elif kernel_initializer == "deterministic":
                # every layer gets its own, reproducible seed
                seed = 123456789101112 + i
                _kernel_initializer = _get_deterministic_variance_scaling(seed=seed)
            elif kernel_initializer == "ones":
                _kernel_initializer = tf.keras.initializers.Constant(1)
            else:
                raise Exception(
                    f"Keyword `kernel_initializer` only supports 'VarianceScaling' "
                    f", 'ones', or 'deterministic'. Got {kernel_initializer=}"
                )
        elif isinstance(kernel_initializer, dict):
            # tensorflow does not add a numeric suffix to the first unnamed layer
            kernel_name = "dense/kernel" if i == 0 else f"dense_{i}/kernel"
            _kernel_initializer = MyKernelInitializer(kernel_initializer[kernel_name])
        else:
            raise TypeError(
                f"Arg `kernel_initializer` must be of type str or dict, "
                f"you supplied {type(kernel_initializer)=}."
            )

        # get the bias initializer for that layer
        if isinstance(bias_initializer, str):
            if bias_initializer == "RandomNormal":
                _bias_initializer = tf.keras.initializers.RandomNormal(0.1, 0.05)
            elif bias_initializer == "deterministic":
                seed = 123456789101112 + i
                _bias_initializer = _get_deterministic_random_normal(
                    mean=0.1, stddev=0.05, seed=seed
                )
            elif bias_initializer == "ones":
                _bias_initializer = tf.keras.initializers.Constant(1)
            else:
                raise Exception(
                    f"Keyword `bias_initializer` only supports 'RandomNormal' "
                    f", 'ones' or 'deterministic'. Got {bias_initializer=}"
                )
        elif isinstance(bias_initializer, dict):
            bias_name = "dense/bias" if i == 0 else f"dense_{i}/bias"
            _bias_initializer = MyBiasInitializer(bias_initializer[bias_name])
        else:
            raise TypeError(
                f"Arg `bias_initializer` must be of type str or dict, "
                f"you supplied {type(bias_initializer)=}."
            )

        # define the layer and directly call it
        layer = tf.keras.layers.Dense(
            units=n_neurons,
            activation=act_fun,
            name=f"Encoder_{i}",
            kernel_initializer=_kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_initializer=_bias_initializer,
        )
        out = layer(out)

    # check the shape of the output
    # (this comparison used to be a bare expression that was never asserted)
    assert out.shape[1] == p.n_neurons[-1], (
        f"The output shape of the encoder does not match the requested "
        f"bottleneck. I got {out.shape[1]=}, while parameters requested "
        f"{p.n_neurons[-1]=}."
    )

    # create the model here
    model = tf.keras.models.Model(
        inputs=input_list,
        outputs=out,
        name="Encoder",
    )

    # assert the input of the model
    # for use_backbone_angles
    if isinstance(model.input_shape, list):
        assert len(model.input_shape) == len(input_list)
    # for only central dihedrals
    else:
        assert model.input_shape[1] == input_list[0].shape[1], (
            f"The input shape of the encoder does not match the provided "
            f"input. I got {model.input_shape[1]=}, while the input has "
            f"{input_list[0].shape[1]=}."
        )
    if p.use_backbone_angles and p.multimer_training is None:
        assert model.input_shape[0][1] - 1 == model.input_shape[1][1]

    # print/write info
    if write_summary:
        if p.tensorboard or p.write_summary:
            with Capturing() as output:
                model.summary()
            with open(p.main_path + "/encoder_summary.txt", "w") as f:
                f.write("\n".join(output))
        else:
            model.summary()

    return model, out
def _concatenate_inputs_reconstruct_sidechains(
    p: ADCParameters,
    central_angles_unit_circle: tf.Tensor,
    central_dihedrals_unit_circle: tf.Tensor,
    side_angles_unit_circle: tf.Tensor,
    side_dihedrals_unit_circle: tf.Tensor,
) -> tuple[list[int], tf.Tensor]:  # pragma: no doccheck
    """Joins the four angular inputs used for sidechain reconstruction.

    The tensors are concatenated along axis 1 in the order central angles,
    central dihedrals, side angles, side dihedrals. `p` is accepted but not
    read by this function.

    Returns:
        tuple: A tuple containing the following:
            - list[int]: The `shape[1]` of every input tensor, in
                concatenation order.
            - tf.Tensor: The concatenated tensor.

    """
    tensors = (
        central_angles_unit_circle,
        central_dihedrals_unit_circle,
        side_angles_unit_circle,
        side_dihedrals_unit_circle,
    )
    # remember the individual widths before the tensors are merged
    splits = [tensor.shape[1] for tensor in tensors]
    concatenate = Concatenate(axis=1, name="concatenate_angular_inputs")
    return splits, concatenate(tensors)
def _concatenate_inputs(
    p: ADCParameters,
    angles_unit_circle: Union[tf.Tensor, None],
    central_dihedrals_unit_circle: tf.Tensor,
    side_dihedrals_unit_circle: Optional[tf.Tensor] = None,
    input_cartesians_pairwise_defined_shape: Optional[tf.Tensor] = None,
) -> tuple[list[int], tf.Tensor]:
    """Builds the encoder input of the AngleDihedralCartesianEncoderMap.

    Depending on `p.use_backbone_angles`, `p.use_sidechains` and
    `p.multimer_training`, the encoder consumes one of these combinations of
    the (unit-circle projected) inputs, always in this order:

        * central dihedrals
        * central angles, central dihedrals
        * central angles, central dihedrals, side dihedrals
        * central angles, central dihedrals, side dihedrals, pairwise
          distances (only if `p.multimer_training` is not None)

    Args:
        p (encodermap.parameters.ADCParameters): A parameter instance.
        angles_unit_circle (Union[tf.Tensor, None]): The central angles. Is
            not read when only the central dihedrals are used.
        central_dihedrals_unit_circle (tf.Tensor): The unit circle projected
            central dihedrals.
        side_dihedrals_unit_circle (Optional[tf.Tensor]): The side dihedrals.
            Is only read when `p.use_sidechains` is set.
        input_cartesians_pairwise_defined_shape (Optional[tf.Tensor]): The
            pairwise distances of the input cartesians. Is only read when
            sidechains are used and `p.multimer_training` is not None.

    Returns:
        tuple: A tuple containing the following:
            - list[int]: The `shape[1]` of every tensor that went into the
                result, in concatenation order. Has a single entry if only
                the central dihedrals are used.
            - tf.Tensor: The concatenated inputs. If only the central
                dihedrals are used, they are returned as they are.

    Raises:
        Exception: If sidechains are requested without backbone angles.

    """
    with_angles = p.use_backbone_angles
    with_sidechains = p.use_sidechains

    # sidechains without backbone angles is the one unsupported combination
    if with_sidechains and not with_angles:
        raise Exception(
            "Only allowed combinations are:\n"
            "   * No sidechains, no backbone angles\n"
            "   * No sidechains, yes backbone angles\n"
            "   * Yes Sidechains, yes backbone angles\n"
            f"Your parameters are: {p.use_sidechains=}. {p.use_backbone_angles=}"
        )

    # only dihedrals: nothing to concatenate
    if not with_angles:
        return [central_dihedrals_unit_circle.shape[1]], central_dihedrals_unit_circle

    parts = [angles_unit_circle, central_dihedrals_unit_circle]
    layer_name = "concatenate_angular_inputs"
    if with_sidechains:
        parts.append(side_dihedrals_unit_circle)
        if p.multimer_training is not None:
            parts.append(input_cartesians_pairwise_defined_shape)
            layer_name = "input_cartesians_pairwise_defined_shape"

    splits = [part.shape[1] for part in parts]
    out = Concatenate(axis=1, name=layer_name)(tuple(parts))
    return splits, out
def _create_inputs_non_periodic_maybe_sparse(
    shape: Union[tuple[int], tuple[int, int]],
    p: ADCParameters,
    name: str,
    sparse: bool,
    reshape: Optional[int] = None,
) -> tuple[tf.Tensor, tf.Tensor, Optional[tf.keras.Model]]:
    """Creates an input Tensor.

    Args:
        shape (Union[tuple[int], tuple[int, int]]): The shape can be either a
            tuple with one int (in case of the central distances) or a tuple
            of two ints (in case of central cartesians), in which case, the
            2nd is checked to be 3 (for the xyz coordinates).
        p (encodermap.parameters.ADCParameters): The parameters. Only read
            when `sparse` is True: `p.trainable_dense_to_sparse` decides
            whether the sparse->dense layer is trainable, and
            `p.l2_reg_constant` is used for its l2 kernel regularizer.
        name (str): The name of this input tensor. Will be preceded with
            'input_' (or 'sparse_input_' if `sparse` is True).
        sparse (bool): Whether a sparse->dense model should be returned.
        reshape (Optional[int]): Whether the input will be in flattened cartesians
            and thus reshaped to (shape // reshape, reshape). Thus, only the
            reshape 3 is currently used in EncoderMap. If None is specified, the
            output will not be reshaped. Defaults to None.
            NOTE(review): the reshape is applied only in the sparse branch
            here; confirm against the upstream file, whose indentation is
            not recoverable from this view.

    Returns:
        tuple: A tuple containing the following:
            - tf.Tensor: The placeholder tensor for the input.
            - tf.Tensor: The tensor to continue building the model with. If
                sparse is False, this is the placeholder itself. If sparse
                is True, this is the placeholder after it has been fed
                through a Dense layer (and the optional reshape).
            - Union[tf.keras.Model, None]: The model to get from sparse to dense.
                If sparse is False, None will be returned here.

    Raises:
        AssertionError: If a two-int `shape` does not end in 3, or if
            `sparse` is True and `shape` is not a tuple with a single int.

    """
    if len(shape) == 2:
        assert (
            shape[1] == 3
        ), "Provided tuple of two ints is not of cartesian xyz coordinates!."
    if not sparse:
        placeholder = Input(shape=shape, name=f"input_{name}")
        dense_model = None
        output_placeholder = placeholder
    else:
        assert len(shape) == 1, (
            f"Sparse tensors can only be of rank 2. The Input that accepts these "
            f"tensors can also just be a tuple with a single int in it. You requested "
            f"an input with rank {len(shape)} ({shape=}). Please reevaluate how you use "
            f"`gen_functional_model` and try to reshape the input."
        )
        placeholder = Input(shape=shape, name=f"sparse_input_{name}", sparse=True)

        kernel_initializer = tf.keras.initializers.VarianceScaling()
        bias_initializer = tf.keras.initializers.RandomNormal(0.1, 0.05)
        # the environment variable swaps the random initializers for constant ones
        if os.getenv("CONSTANT_SPARSE_TO_DENSE", "False") == "True":
            warnings.warn("Using constant for to_dense initializers.")
            kernel_initializer = tf.keras.initializers.Constant(1)
            bias_initializer = tf.keras.initializers.Constant(1)

        # a Dense layer with as many units as inputs turns the sparse input dense
        output_placeholder = Dense(
            units=shape[0],
            trainable=p.trainable_dense_to_sparse,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=tf.keras.regularizers.l2(p.l2_reg_constant),
        )(placeholder)
        # the standalone sparse->dense model is built before the optional reshape
        dense_model = tf.keras.Model(
            inputs=placeholder,
            outputs=output_placeholder,
        )
        if reshape is not None:
            output_placeholder = tf.keras.layers.Reshape(
                target_shape=(shape[0] // reshape, reshape),
                input_shape=shape,
                name=f"reshape_sparse_to_dense_{name}",
            )(output_placeholder)
    return placeholder, output_placeholder, dense_model
def _create_inputs_periodic_maybe_sparse(
    shape: int,
    p: ADCParameters,
    name: str,
    sparse: bool,
) -> tuple[tf.Tensor, tf.Tensor, Union[tf.keras.Model, None]]:
    """Creates an input Tensor and feeds it through a `PeriodicInput` layer.

    The input itself is created by `_create_inputs_non_periodic_maybe_sparse`
    with the one-element shape `(shape,)`.

    Args:
        shape (int): The number of features of the input.
        p (encodermap.parameters.ADCParameters): An instance of ADCParameters.
            It is forwarded to the input creation and to `PeriodicInput`.
        name (str): The name of this input tensor. The `PeriodicInput` layer
            will be called 'input_{name}_to_unit_circle'.
        sparse (bool): Whether a sparse->dense model should be returned.

    Returns:
        tuple: A tuple containing the following:
            - tf.Tensor: The placeholder tensor for the input.
            - tf.Tensor: The output of the `PeriodicInput` layer.
            - Union[tf.keras.Model, None]: The sparse->dense model returned
                by `_create_inputs_non_periodic_maybe_sparse`.

    """
    created = _create_inputs_non_periodic_maybe_sparse(
        shape=(shape,),
        p=p,
        name=name,
        sparse=sparse,
    )
    placeholder, dense_placeholder, dense_model = created
    to_unit_circle = PeriodicInput(p, f"input_{name}_to_unit_circle")
    return placeholder, to_unit_circle(dense_placeholder), dense_model
2151################################################################################
2152# Public Classes
2153################################################################################
class ADCFunctionalModel(tf.keras.Model):
    """A subclass of tf.keras.Model, that implements the logic for the
    AngleDihedralCartesianEncoderMap.

    """

    def __init__(
        self,
        parameters: ADCParameters,
        inputs: Iterable[tf.Tensor],
        outputs: Iterable[tf.Tensor],
        encoder: tf.keras.Model,
        decoder: tf.keras.Model,
    ) -> None:
        """Initialize the Model.

        Args:
            parameters (encodermap.parameters.ADCParameters): An instance of the ADCParameters class.
            inputs (Iterable[tf.Tensor]): The inputs of the model.
            outputs (Iterable[tf.Tensor]): The outputs of the model.
            encoder (tf.keras.Model): The encoder as its own model.
            decoder (tf.keras.Model): The decoder as its own model.

        """
        self.p = parameters
        super().__init__(inputs=inputs, outputs=outputs, name="ADCFunctionalModel")
        self.encoder_model = encoder
        self.decoder_model = decoder

        # train counter
        self._my_train_counter = K.variable(0, "int64", name="train_counter")

    def get_config(self) -> dict[str, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[str, Any]: A dict with the serializable objects.

        """
        config = super().get_config().copy()
        config.update(
            {
                "parameters": self.p.to_dict(),
                "inputs": [i.shape for i in self.inputs],
                "outputs": [o.shape for o in self.outputs],
                "encoder": tf.keras.saving.serialize_keras_object(self.encoder_model),
                "decoder": tf.keras.saving.serialize_keras_object(self.decoder_model),
            }
        )
        return config

    @classmethod
    def from_config(
        cls: Type[ADCFunctionalModelType],
        config: dict[str, Any],
        custom_objects: Optional[dict[Any, Any]] = None,
    ) -> ADCFunctionalModelType:
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[str, Any]): A dictionary.
            custom_objects (Optional[dict[str, Any]]): Not needed here, but see
                https://keras.io/guides/serialization_and_saving/ for yourself.

        Returns:
            ADCFunctionalModelType: An instance of the ADCFunctionalModel.

        """
        inputs_config = config.pop("inputs")
        inputs = tf.keras.saving.deserialize_keras_object(inputs_config)
        outputs_config = config.pop("outputs")
        outputs = tf.keras.saving.deserialize_keras_object(outputs_config)
        encoder_config = config.pop("encoder")
        encoder = tf.keras.saving.deserialize_keras_object(encoder_config)
        decoder_config = config.pop("decoder")
        decoder = tf.keras.saving.deserialize_keras_object(decoder_config)
        parameters = config.pop("parameters")

        # the key "cartesian_pwd_start" decides which parameter class is rebuilt
        if "cartesian_pwd_start" in parameters:
            parameters = ADCParameters(**parameters)
        else:
            parameters = Parameters(**parameters)

        # create a new functional model and apply the weights from the encoder and decoder
        # the leading entry of every stored shape is dropped
        input_shapes = tuple([tuple(i[1:]) for i in inputs])
        new_model = gen_functional_model(input_shapes, parameters, write_summary=False)
        new_model.encoder_model.set_weights(encoder.get_weights())
        new_model.decoder_model.set_weights(decoder.get_weights())
        return cls(
            parameters,
            new_model.inputs,
            new_model.outputs,
            new_model.encoder_model,
            new_model.decoder_model,
        )

    @property
    def encoder(self) -> tf.keras.Model:
        """tf.keras.Model: The encoder submodel."""
        return self.encoder_model

    @property
    def decoder(self) -> tf.keras.Model:
        """tf.keras.Model: The decoder submodel."""
        return self.decoder_model

    def compile(self, *args, **kwargs) -> None:
        """Stores the loss functions and compiles the model.

        The keyword argument `unpacked_loss_fns` is required. It is removed
        from `kwargs` and stored on the instance; `get_loss` looks the
        individual loss functions up in it by name. All other arguments are
        forwarded to `tf.keras.Model.compile`.

        Raises:
            KeyError: If `unpacked_loss_fns` is not provided.

        """
        self.unpacked_loss_fns = kwargs.pop("unpacked_loss_fns")
        super().compile(*args, **kwargs)

    def get_loss(
        self, inp: ADCFunctionalModelInputType
    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Calls the model on `inp` and sums up all loss contributions.

        Args:
            inp (ADCFunctionalModelInputType): The inputs. Without
                `p.reconstruct_sidechains` a four-tuple (angles, dihedrals,
                cartesians, distances) or a five-tuple (the same plus side
                dihedrals). With `p.reconstruct_sidechains` a seven-tuple
                (angles, dihedrals, cartesians, distances, side angles,
                side dihedrals, side distances).

        Returns:
            tuple: A tuple containing the following:
                - tf.Tensor: The combined loss.
                - tf.Tensor: The input cartesians.
                - tf.Tensor: The back-mapped cartesians from the model output.

        Raises:
            ValueError: If `inp` is neither a four- nor a five-tuple while
                `p.reconstruct_sidechains` and `p.use_sidechains` are off.
            NotImplementedError: If `p.multimer_training` is not None.

        """
        # unpack the inputs
        if not self.p.reconstruct_sidechains:
            if self.p.use_sidechains or len(inp) == 5:
                (
                    inp_angles,
                    inp_dihedrals,
                    inp_cartesians,
                    inp_distances,
                    inp_side_dihedrals,
                ) = inp
            elif len(inp) == 4:
                (
                    inp_angles,
                    inp_dihedrals,
                    inp_cartesians,
                    inp_distances,
                ) = inp
            else:
                # fail here instead of with an unbound local further down
                raise ValueError(
                    f"Expected a tuple with 4 or 5 inputs, but got {len(inp)}."
                )
            # call the model
            if not self.p.use_sidechains:
                # the model takes only four inputs, even if five were provided
                out = self(
                    (
                        inp_angles,
                        inp_dihedrals,
                        inp_cartesians,
                        inp_distances,
                    ),
                    training=True,
                )
            else:
                out = self(inp, training=True)
        else:
            (
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                inp_distances,
                inp_side_angles,
                inp_side_dihedrals,
                inp_side_distances,
            ) = inp
            out = self(inp, training=True)

        # unpack the outputs
        if self.p.multimer_training is None:
            if self.p.reconstruct_sidechains:
                (
                    out_angles,
                    out_dihedrals,
                    out_side_angles,
                    out_side_dihedrals,
                    back_cartesians,
                    inp_pair,
                    out_pair,
                ) = out
            elif self.p.use_sidechains and not self.p.reconstruct_sidechains:
                (
                    out_angles,
                    out_dihedrals,
                    out_side_dihedrals,
                    back_cartesians,
                    inp_pair,
                    out_pair,
                ) = out
            else:
                (
                    out_angles,
                    out_dihedrals,
                    back_cartesians,
                    inp_pair,
                    out_pair,
                ) = out
        else:
            raise NotImplementedError

        # select what the encoder consumes
        # the latent, the distance loss and the center loss all use this input
        if self.p.multimer_training is None:
            if self.p.reconstruct_sidechains:
                encoder_inp = (
                    inp_angles,
                    inp_dihedrals,
                    inp_side_angles,
                    inp_side_dihedrals,
                )
            elif self.p.use_sidechains:
                encoder_inp = (inp_angles, inp_dihedrals, inp_side_dihedrals)
            elif self.p.use_backbone_angles:
                encoder_inp = (inp_angles, inp_dihedrals)
            else:
                encoder_inp = inp_dihedrals
        elif self.p.multimer_training == "homogeneous_transformation":
            encoder_inp = (
                inp_angles,
                inp_dihedrals,
                inp_side_dihedrals,
                inp_cartesians,
            )
        else:
            raise NotImplementedError

        # get the latent
        latent = self.encoder_model(encoder_inp, training=True)

        with tf.name_scope("Cost"):
            loss = 0.0
            # dihedral loss
            loss += self.unpacked_loss_fns["dihedral_loss_func"](
                inp_dihedrals, out_dihedrals
            )

            # angle loss
            # either uses trained angles or mean angles
            loss += self.unpacked_loss_fns["angle_loss_func"](inp_angles, out_angles)

            if self.p.reconstruct_sidechains:
                loss += self.unpacked_loss_fns["angle_loss_func"](
                    inp_side_angles, out_side_angles
                )

            # cartesian loss
            # compares the pairwise distances of the input cartesians
            # and the output cartesians
            # this cost function will slowly be added via a soft-start
            loss += self.unpacked_loss_fns["cartesian_loss_func"](inp_pair, out_pair)

            # distance loss
            # compares the input and the latent, thus needs to be adjusted
            # based on whether the encoder takes angles+dihedrals+side dihedrals,
            # angles+dihedrals, or just dihedrals.
            loss += self.unpacked_loss_fns["distance_loss_func"](encoder_inp)

            # cartesian distance cost
            # Compares the input pairwise distances with the latent using a
            # second sigmoid function
            loss += self.unpacked_loss_fns["cartesian_distance_loss_func"](
                inp_pair, latent
            )

            # center loss
            # makes sure, that the latent is in the center and thus depends on
            # the input of the encoder
            loss += self.unpacked_loss_fns["center_loss_func"](encoder_inp)

            # reg loss
            # just add the squared weights of all trainable layers
            loss += self.unpacked_loss_fns["regularization_loss_func"]()

            # side dihedral loss
            if self.p.use_sidechains:
                loss += self.unpacked_loss_fns["side_dihedral_loss_func"](
                    inp_side_dihedrals, out_side_dihedrals
                )
            tf.summary.scalar("Combined Cost", loss)

        return loss, inp_cartesians, back_cartesians

    def train_step(self, data: ADCFunctionalModelInputType) -> dict[str, Any]:
        """Can receive three types of data.

        * use_backbone_angles = False, use_sidechains = False:
            Will receive a four-tuple in the order: angles, dihedrals, cartesians,
            distances. The angles will be used to construct mean angles.
        * use_backbone_angles = True, use_sidechains = False:
            Will receive the same four-tuple as above, but the angles will be
            fed through the autoencoder.
        * use_backbone_angles = True, use_sidechains = True:
            Will receive a five-tuple in the order: angles, dihedrals, cartesians,
            distances, side dihedrals. The angles, central dihedrals and side
            dihedrals will be fed through the autoencoder.

        Returns:
            dict[str, Any]: The results of the compiled metrics by name, plus
                the combined loss under the key "loss".

        """

        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._my_train_counter)
            loss, inp_cartesians, out_cartesians = self.get_loss(data)
            loggable_encoder_layers = [
                l for l in self.encoder_model.layers if l.__class__.__name__ == "Dense"
            ]
            loggable_decoder_layers = [
                l for l in self.decoder_model.layers if l.__class__.__name__ == "Dense"
            ]
            for l in loggable_encoder_layers + loggable_decoder_layers:
                add_layer_summaries(l, step=self._my_train_counter)

        # Compute Gradients
        # the variables outside of encoder and decoder are only updated if
        # trainable_dense_to_sparse is set
        if not self.p.trainable_dense_to_sparse:
            trainable_vars = (
                self.encoder_model.trainable_variables
                + self.decoder_model.trainable_variables
            )
        else:
            trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        for metric in self.compiled_metrics._metrics:
            metric.update_state(data, self(data))

        # update train counter because tensorflow seems to have deprecated it
        self._my_train_counter.assign_add(1)
        return {m.name: m.result() for m in self.compiled_metrics._metrics} | {
            "loss": loss,
        }
2518@testing
2519class ADCFunctionalModelTesting(tf.keras.Model):
2520 """A subclass of tf.keras.Model, that implements the logic for the
2521 AngleDihedralCartesianEncoderMap.
2523 """
2525 def __init__(
2526 self,
2527 parameters: ADCParameters,
2528 inputs: Iterable[tf.Tensor],
2529 outputs: Iterable[tf.Tensor],
2530 encoder: tf.keras.Model,
2531 decoder: tf.keras.Model,
2532 ) -> None:
2533 """Initialize the Model.
2535 Args:
2536 parameters (encodermap.parameters.ADCParameters): An instance of the ADCParameters class.
2537 inputs (Iterable[tf.Tensor]): The inputs of the model.
2538 outputs (Iterable[tf.Tensor]): The outputs of the model.
2539 encoder (tf.keras.Model): The encoder as its own model.
2540 decoder (tf.keras.Model): The decoder as its own model.
2542 """
2544 self.p = parameters
2545 super().__init__(inputs=inputs, outputs=outputs, name="ADCFunctionalModel")
2546 self.encoder_model = encoder
2547 self.decoder_model = decoder
2549 # loggable layers
2550 self.loggable_encoder_layers = [
2551 l for l in self.encoder_model.layers if l.__class__.__name__ == "Dense"
2552 ]
2553 self.loggable_decoder_layers = [
2554 l for l in self.decoder_model.layers if l.__class__.__name__ == "Dense"
2555 ]
2557 def get_config(self) -> dict[str, Any]:
2558 """Serializes this keras serializable.
2560 Returns:
2561 dict[str, Any]: A dict with the serializable objects.
2563 """
2564 config = super().get_config().copy()
2565 config.update(
2566 {
2567 "parameters": self.p.to_dict(),
2568 "inputs": [i.shape for i in self.inputs],
2569 "outputs": [o.shape for o in self.outputs],
2570 "encoder": tf.keras.saving.serialize_keras_object(self.encoder_model),
2571 "decoder": tf.keras.saving.serialize_keras_object(self.decoder_model),
2572 }
2573 )
2574 return config
def train_step(self, data: ADCFunctionalModelInputType) -> Any:
    """Run one optimization step on a single batch.

    The forward pass and the loss are recorded on a gradient tape, the
    gradients with respect to the encoder and decoder variables are computed
    and handed to the optimizer.

    Args:
        data (ADCFunctionalModelInputType): The batch. When
            ``self.p.use_sidechains`` is False, the last element is left out
            of the forward pass, but the loss still receives the full batch.

    Returns:
        Any: A dict with the single key ``"loss"``.

    """
    step = self._my_train_counter

    with tf.GradientTape() as tape:
        tf.summary.experimental.set_step(step)

        # without sidechains the trailing input is not fed to the network
        model_input = data if self.p.use_sidechains else data[:-1]
        y_pred = self(model_input, training=True)

        loss = self.compute_loss(y=data, y_pred=y_pred)

        loggable_layers = self.loggable_encoder_layers + self.loggable_decoder_layers
        for layer in loggable_layers:
            add_layer_summaries(layer, step=step)

    # only the encoder and decoder variables are optimized
    trainable_vars = (
        self.encoder_model.trainable_variables
        + self.decoder_model.trainable_variables
    )
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    return {"loss": loss}
@classmethod
def from_config(
    cls: Type[ADCFunctionalModelTestingType],
    config: dict[str, Any],
    custom_objects: Optional[dict[Any, Any]] = None,
) -> ADCFunctionalModelTestingType:
    """Rebuild a model from the dict produced by ``get_config``.

    A fresh functional model is generated from the stored input shapes and
    parameters. The weights of the deserialized encoder and decoder are then
    copied into it. The keys ``inputs``, ``outputs``, ``encoder``,
    ``decoder`` and ``parameters`` are popped from ``config``.

    Args:
        config (dict[str, Any]): A dictionary.
        custom_objects (Optional[dict[str, Any]]): Not used here, see
            https://keras.io/guides/serialization_and_saving/ for details.

    Returns:
        ADCFunctionalModelTestingType: A new instance of this class.

    """
    deserialize = tf.keras.saving.deserialize_keras_object

    inputs = deserialize(config.pop("inputs"))
    # The stored outputs are deserialized like all other entries, but the
    # outputs of the returned instance come from the regenerated model.
    deserialize(config.pop("outputs"))
    encoder = deserialize(config.pop("encoder"))
    decoder = deserialize(config.pop("decoder"))

    # the presence of an ADC-only key decides which parameter class is used
    raw_parameters = config.pop("parameters")
    if "cartesian_pwd_start" in raw_parameters:
        parameter_class = ADCParameters
    else:
        parameter_class = Parameters
    parameters = parameter_class(**raw_parameters)

    # regenerate the network without the leading (batch) dimension of the
    # stored shapes and transfer the trained weights
    input_shapes = tuple(tuple(shape[1:]) for shape in inputs)
    rebuilt = gen_functional_model(input_shapes, parameters, write_summary=False)
    rebuilt.encoder_model.set_weights(encoder.get_weights())
    rebuilt.decoder_model.set_weights(decoder.get_weights())

    return cls(
        parameters,
        rebuilt.inputs,
        rebuilt.outputs,
        rebuilt.encoder_model,
        rebuilt.decoder_model,
    )
@property
def encoder(self) -> tf.keras.Model:
    """tf.keras.Model: The model stored in ``self.encoder_model``."""
    model = self.encoder_model
    return model
@property
def decoder(self) -> tf.keras.Model:
    """tf.keras.Model: The model stored in ``self.decoder_model``."""
    model = self.decoder_model
    return model
class ADCSparseFunctionalModel(ADCFunctionalModel):
    """ADC functional model whose inputs may be sparse tensors.

    Every input that arrives as a ``tf.sparse.SparseTensor`` is sent through
    its own "get dense" model before the loss of the parent class is
    evaluated. Inputs that are not sparse tensors are used unchanged.

    """

    # Attribute names of the densifying models. The same strings are the keys
    # used in the config dict, and their order is the positional order of the
    # matching constructor arguments.
    _DENSE_MODEL_NAMES = (
        "get_dense_model_central_angles",
        "get_dense_model_central_dihedrals",
        "get_dense_model_cartesians",
        "get_dense_model_distances",
        "get_dense_model_side_dihedrals",
    )

    def __init__(
        self,
        parameters: ADCParameters,
        inputs: Iterable[tf.Tensor],
        outputs: Iterable[tf.Tensor],
        encoder: tf.keras.Model,
        decoder: tf.keras.Model,
        get_dense_model_central_angles: tf.keras.Model,
        get_dense_model_central_dihedrals: tf.keras.Model,
        get_dense_model_cartesians: tf.keras.Model,
        get_dense_model_distances: tf.keras.Model,
        get_dense_model_side_dihedrals: Union[tf.keras.Model, None],
    ) -> None:
        """Instantiate the Model.

        Args:
            parameters (encodermap.parameters.ADCParameters): An instance of
                the ADCParameters class.
            inputs (Iterable[tf.Tensor]): The inputs of the model. The
                element at index 2 is indexed to size the internal reshape
                layer, so an indexable sequence is required.
            outputs (Iterable[tf.Tensor]): The outputs of the model.
            encoder (tf.keras.Model): The encoder as its own model.
            decoder (tf.keras.Model): The decoder as its own model.
            get_dense_model_central_angles (tf.keras.Model): The model that
                turns the sparse central angles into a dense tensor.
            get_dense_model_central_dihedrals (tf.keras.Model): The model
                that turns the sparse central dihedrals into a dense tensor.
            get_dense_model_cartesians (tf.keras.Model): The model that
                turns the sparse cartesians into a dense tensor. If it is
                None, no reshape layer is created.
            get_dense_model_distances (tf.keras.Model): The model that turns
                the sparse distances into a dense tensor.
            get_dense_model_side_dihedrals (Union[tf.keras.Model, None]):
                The model that turns the sparse side dihedrals into a dense
                tensor. Can be None, for when these angular inputs are not
                used for training.

        """
        super().__init__(parameters, inputs, outputs, encoder, decoder)
        self.get_dense_model_central_angles = get_dense_model_central_angles
        self.get_dense_model_central_dihedrals = get_dense_model_central_dihedrals
        self.get_dense_model_cartesians = get_dense_model_cartesians
        self.get_dense_model_distances = get_dense_model_distances
        self.get_dense_model_side_dihedrals = get_dense_model_side_dihedrals

        if self.get_dense_model_cartesians is None:
            self.reshape_layer = None
            return

        # The cartesian input is flat; this layer folds it into rows of
        # three values each.
        flat_size = inputs[2].shape[1]
        self.reshape_layer = tf.keras.layers.Reshape(
            target_shape=(flat_size // 3, 3),
            input_shape=(flat_size,),
            name="reshape_sparse_to_dense_internally",
        )
        self.reshape_layer.build(input_shape=(self.p.batch_size, flat_size))

    def get_config(self) -> dict[str, Any]:
        """Serializes this keras serializable.

        The dict of the parent class is copied and extended with one
        serialized entry per densifying model.

        Returns:
            dict[str, Any]: A dict with the serializable objects.

        """
        serialize = tf.keras.saving.serialize_keras_object
        config = super().get_config().copy()
        for name in self._DENSE_MODEL_NAMES:
            config[name] = serialize(getattr(self, name))
        return config

    @classmethod
    def from_config(
        cls: Type[ADCSparseFunctionalModelType],
        config: dict[str, Any],
        custom_objects: Optional[dict[Any, Any]] = None,
    ) -> ADCSparseFunctionalModelType:
        """Reconstructs this keras serializable from a dict.

        A new sparse functional model is generated and the weights of the
        deserialized encoder, decoder and densifying models are copied into
        it. All keys that are read are popped from ``config``.

        Args:
            config (dict[str, Any]): A dictionary.
            custom_objects (Optional[dict[str, Any]]): Not used here, see
                https://keras.io/guides/serialization_and_saving/ for
                details.

        Returns:
            ADCSparseFunctionalModelType: An instance of the
                ADCSparseFunctionalModel.

        """
        deserialize = tf.keras.saving.deserialize_keras_object

        inputs = deserialize(config.pop("inputs"))
        # deserialized like every other entry; the returned instance takes
        # its outputs from the regenerated model
        deserialize(config.pop("outputs"))
        encoder = deserialize(config.pop("encoder"))
        decoder = deserialize(config.pop("decoder"))

        # the densifying models, keyed by their attribute name
        stored_dense_models = {
            name: deserialize(config.pop(name)) for name in cls._DENSE_MODEL_NAMES
        }

        # the presence of an ADC-only key decides which parameter class is used
        raw_parameters = config.pop("parameters")
        if "cartesian_pwd_start" in raw_parameters:
            parameter_class = ADCParameters
        else:
            parameter_class = Parameters
        parameters = parameter_class(**raw_parameters)

        # create a new functional model and apply the stored weights to it
        input_shapes = tuple(tuple(shape[1:]) for shape in inputs)
        rebuilt = gen_functional_model(
            input_shapes,
            parameters,
            write_summary=False,
            sparse=True,
        )
        stored_count = len(encoder.get_weights())
        if stored_count != len(rebuilt.encoder_model.get_weights()):
            # here, we can assume that the model was trained with
            # only sparse sidechains
            rebuilt = gen_functional_model(
                input_shapes,
                parameters,
                write_summary=False,
                sparse=False,
                sidechain_only_sparse=True,
            )

        rebuilt.encoder_model.set_weights(encoder.get_weights())
        rebuilt.decoder_model.set_weights(decoder.get_weights())

        # only the densifying models present on the new model get weights
        for name, stored_model in stored_dense_models.items():
            target_model = getattr(rebuilt, name)
            if target_model is not None:
                target_model.set_weights(stored_model.get_weights())

        return cls(
            parameters,
            rebuilt.inputs,
            rebuilt.outputs,
            rebuilt.encoder_model,
            rebuilt.decoder_model,
            rebuilt.get_dense_model_central_angles,
            rebuilt.get_dense_model_central_dihedrals,
            rebuilt.get_dense_model_cartesians,
            rebuilt.get_dense_model_distances,
            rebuilt.get_dense_model_side_dihedrals,
        )

    def get_loss(self, inp):
        """Densify sparse inputs and evaluate the loss of the parent class.

        Args:
            inp: The inputs. Five elements are unpacked when
                ``self.p.use_sidechains`` is set and ``inp`` has length 5,
                otherwise four elements are unpacked.

        Returns:
            tuple: The loss, the input cartesians (passed through the
                reshape layer when a cartesian densifying model exists) and
                the output cartesians returned by the parent class.

        """

        def _as_dense(tensor, dense_model):
            # sparse tensors go through their model, anything else is kept
            if isinstance(tensor, tf.sparse.SparseTensor):
                return dense_model(tensor, training=True)
            return tensor

        # unpack the inputs
        if self.p.use_sidechains and len(inp) == 5:
            (
                raw_angles,
                raw_dihedrals,
                raw_cartesians,
                raw_distances,
                raw_side_dihedrals,
            ) = inp
        else:
            raw_angles, raw_dihedrals, raw_cartesians, raw_distances = inp

        inp_angles = _as_dense(raw_angles, self.get_dense_model_central_angles)
        inp_dihedrals = _as_dense(
            raw_dihedrals, self.get_dense_model_central_dihedrals
        )
        inp_cartesians = _as_dense(raw_cartesians, self.get_dense_model_cartesians)
        inp_distances = _as_dense(raw_distances, self.get_dense_model_distances)

        # make them into an Iterable again
        data = (inp_angles, inp_dihedrals, inp_cartesians, inp_distances)
        if self.p.use_sidechains:
            inp_side_dihedrals = _as_dense(
                raw_side_dihedrals, self.get_dense_model_side_dihedrals
            )
            data = (*data, inp_side_dihedrals)

        # The second element returned by the parent is dropped. The input
        # cartesians handed back to the caller are taken from the densified
        # input and reshaped here instead.
        resulting_loss, _, out_cartesians = super().get_loss(data)
        if self.get_dense_model_cartesians is not None:
            inp_cartesians = self.reshape_layer(inp_cartesians)
        return resulting_loss, inp_cartesians, out_cartesians
class ADCFunctionalModelSidechainReconstruction(ADCSparseFunctionalModel):
    """Sparse ADC model with additional sidechain angle and distance inputs.

    Compared to its parent class, two more densifying models are stored
    (for the side angles and the side distances) and ``get_loss`` unpacks
    seven inputs.

    """

    def __init__(
        self,
        parameters: ADCParameters,
        inputs: Iterable[tf.Tensor],
        outputs: Iterable[tf.Tensor],
        encoder: tf.keras.Model,
        decoder: tf.keras.Model,
        **kwargs: Any,
    ) -> None:
        """Instantiate the Model.

        Args:
            parameters (encodermap.parameters.ADCParameters): An instance of
                the ADCParameters class.
            inputs (Iterable[tf.Tensor]): The inputs of the model.
            outputs (Iterable[tf.Tensor]): The outputs of the model.
            encoder (tf.keras.Model): The encoder as its own model.
            decoder (tf.keras.Model): The decoder as its own model.
            **kwargs (Any): Must contain the densifying models under the keys
                ``get_dense_model_central_angles``,
                ``get_dense_model_central_dihedrals``,
                ``get_dense_model_cartesians``,
                ``get_dense_model_central_distances``,
                ``get_dense_model_side_dihedrals``,
                ``get_dense_model_side_angles`` and
                ``get_dense_model_side_distances``.

        Raises:
            KeyError: If one of the required keyword arguments is missing.

        """
        super().__init__(
            parameters=parameters,
            inputs=inputs,
            outputs=outputs,
            encoder=encoder,
            decoder=decoder,
            get_dense_model_central_angles=kwargs["get_dense_model_central_angles"],
            get_dense_model_central_dihedrals=kwargs[
                "get_dense_model_central_dihedrals"
            ],
            get_dense_model_cartesians=kwargs["get_dense_model_cartesians"],
            # the parent calls this argument `get_dense_model_distances`
            get_dense_model_distances=kwargs["get_dense_model_central_distances"],
            get_dense_model_side_dihedrals=kwargs["get_dense_model_side_dihedrals"],
        )
        self.get_dense_model_side_angles = kwargs["get_dense_model_side_angles"]
        self.get_dense_model_side_distances = kwargs["get_dense_model_side_distances"]

    def get_loss(self, inp: tuple[tf.Tensor, ...]):
        """Densify sparse inputs and evaluate ``ADCFunctionalModel.get_loss``.

        Args:
            inp (tuple[tf.Tensor, ...]): Exactly seven inputs in the order
                central angles, central dihedrals, all cartesians, central
                distances, side angles, side dihedrals, side distances.

        Returns:
            tuple: The loss, the input cartesians (passed through the
                reshape layer when a cartesian densifying model exists) and
                the output cartesians returned by the loss function.

        """

        def _as_dense(tensor, dense_model):
            # sparse tensors go through their model, anything else is kept
            if isinstance(tensor, tf.sparse.SparseTensor):
                return dense_model(tensor, training=True)
            return tensor

        # unpack the inputs
        (
            sparse_inp_central_angles,
            sparse_inp_central_dihedrals,
            sparse_inp_all_cartesians,
            sparse_inp_central_distances,
            sparse_inp_side_angles,
            sparse_inp_side_dihedrals,
            sparse_inp_side_distances,
        ) = inp

        inp_central_angles = _as_dense(
            sparse_inp_central_angles, self.get_dense_model_central_angles
        )
        inp_central_dihedrals = _as_dense(
            sparse_inp_central_dihedrals, self.get_dense_model_central_dihedrals
        )
        inp_all_cartesians = _as_dense(
            sparse_inp_all_cartesians, self.get_dense_model_cartesians
        )
        inp_central_distances = _as_dense(
            sparse_inp_central_distances, self.get_dense_model_distances
        )
        inp_side_angles = _as_dense(
            sparse_inp_side_angles, self.get_dense_model_side_angles
        )
        inp_side_dihedrals = _as_dense(
            sparse_inp_side_dihedrals, self.get_dense_model_side_dihedrals
        )
        inp_side_distances = _as_dense(
            sparse_inp_side_distances, self.get_dense_model_side_distances
        )

        data = (
            inp_central_angles,
            inp_central_dihedrals,
            inp_all_cartesians,
            inp_central_distances,
            inp_side_angles,
            inp_side_dihedrals,
            inp_side_distances,
        )

        # The loss of ADCFunctionalModel is called directly, which skips the
        # four/five-input unpacking of the direct parent class.
        resulting_loss, _, out_cartesians = ADCFunctionalModel.get_loss(self, data)
        if self.get_dense_model_cartesians is not None:
            inp_all_cartesians = self.reshape_layer(inp_all_cartesians)
        return resulting_loss, inp_all_cartesians, out_cartesians

    def get_config(self) -> dict[str, Any]:
        """Serializes this keras serializable.

        The dict of the parent class is copied and extended with
        ``self.p.sidechain_info`` and the two serialized sidechain
        densifying models.

        Returns:
            dict[str, Any]: A dict with the serializable objects.

        """
        sidechain_info = self.p.sidechain_info
        config = super().get_config().copy()
        config.update(
            {
                "sidechain_info": sidechain_info,
                "get_dense_model_side_angles": tf.keras.saving.serialize_keras_object(
                    self.get_dense_model_side_angles
                ),
                "get_dense_model_side_distances": tf.keras.saving.serialize_keras_object(
                    self.get_dense_model_side_distances
                ),
            }
        )
        return config

    @classmethod
    def from_config(
        cls: Type[ADCFunctionalModelSidechainReconstructionType],
        config: dict[Any, Any],
        custom_objects: Optional[dict[Any, Any]] = None,
    ) -> ADCFunctionalModelSidechainReconstructionType:
        """Reconstructing this model from a config is not implemented yet.

        Args:
            config (dict[Any, Any]): A dictionary.
            custom_objects (Optional[dict[Any, Any]]): Not used here. Accepted
                so that the signature matches the ``from_config`` methods of
                the parent classes.

        Raises:
            NotImplementedError: Always. The message is a reminder of what
                the implementation still has to do.

        """
        # NotImplementedError is a subclass of Exception, so callers that
        # caught the previously raised bare Exception keep working.
        raise NotImplementedError(
            "Also put the sidechain_indices back into the parameters"
        )
class SequentialModel(tf.keras.Model):
    """Autoencoder made of two ``tf.keras.Sequential`` stacks of Dense layers.

    The encoder and decoder stacks are stored as ``encoder_model`` and
    ``decoder_model``. For a finite ``periodicity`` parameter, the inputs are
    expanded into sine and cosine before encoding and recombined with
    ``atan2`` after decoding.

    """

    def __init__(
        self,
        input_dim: int,
        parameters: Optional[Parameters] = None,
        sparse: bool = False,
        get_dense_model: Optional[tf.keras.Model] = None,
    ) -> None:
        """Build the encoder and decoder stacks.

        Args:
            input_dim (int): The number of input features. It is doubled
                internally when the periodicity parameter is finite.
            parameters (Optional[Parameters]): The parameters. If None is
                provided, a default ``Parameters()`` instance is used.
            sparse (bool): When True, an additional model with a sparse
                input and a single Dense layer is created and stored as
                ``self.get_dense_model``, replacing whatever was passed as
                ``get_dense_model``.
            get_dense_model (Optional[tf.keras.Model]): A model that is
                stored as ``self.get_dense_model`` when ``sparse`` is False.

        """
        # the parameters are attached before the keras constructor runs
        self.p = Parameters() if parameters is None else parameters
        super().__init__()
        self.sparse = sparse
        self.input_dim = input_dim
        self.get_dense_model = get_dense_model

        # tensors for using tf.cond inside self.train_step()
        self.update_step = tf.constant(self.p.summary_step, dtype="int64")
        self.debug_tensor = tf.constant(self.p.tensorboard, dtype="bool")

        # periodicity doubles the inputs and outputs
        periodic = self.p.periodicity < float("inf")
        if periodic:
            self.input_dim *= 2

        # one regularizer instance is shared by all Dense layers
        regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)

        # an empty string in the parameters stands for "no activation"
        activation_functions = [
            act if act != "" else None for act in self.p.activation_functions
        ]

        # The stack is mirrored around its last entry: encoder sizes, then
        # the same sizes in reverse without repeating the last one.
        n_hidden = len(self.p.n_neurons)
        widths = self.p.n_neurons + self.p.n_neurons[-2::-1]
        activations = activation_functions[1:] + activation_functions[-2::-1]
        layer_data = list(zip(widths, activations))
        # the final layer maps back onto the (possibly doubled) input size
        layer_data.append([self.input_dim, None])

        def _layer_name(position: int) -> str:
            # positions before the last hidden size belong to the encoder,
            # positions after it to the decoder
            if position < n_hidden - 1:
                return f"Encoder_{position}"
            if position > n_hidden - 1:
                return f"Decoder_{position - n_hidden}"
            return "Latent"

        layer_data = [
            (*spec, _layer_name(position)) for position, spec in enumerate(layer_data)
        ]

        # everything up to and including the first narrowest layer is encoder
        neurons = [spec[0] for spec in layer_data]
        bottleneck_index = neurons.index(min(neurons)) + 1
        self.encoder_layers = layer_data[:bottleneck_index]
        self.decoder_layers = layer_data[bottleneck_index:]

        if self.sparse:
            # the sparse input has the un-doubled size
            shape = self.input_dim
            if periodic:
                shape /= 2
            _input_layer = Input(
                shape=(int(shape),),
                sparse=True,
            )
            x = Dense(shape)(_input_layer)
            self.get_dense_model = tf.keras.Model(
                inputs=_input_layer,
                outputs=x,
            )

        def _dense(units, activation, name, **extra):
            # all Dense layers share initializers and the regularizer
            return tf.keras.layers.Dense(
                units=units,
                activation=activation,
                name=name,
                kernel_initializer=tf.initializers.VarianceScaling(),
                kernel_regularizer=regularizer,
                bias_initializer=tf.initializers.RandomNormal(0.1, 0.05),
                **extra,
            )

        # Instead of an InputLayer, the first Dense layer carries the
        # `input_shape`; the original author found this reloads more reliably.
        first_units, first_activation, first_name = self.encoder_layers[0]
        input_layer = _dense(
            first_units,
            first_activation,
            first_name,
            input_shape=(self.input_dim,),
        )

        # encoder
        remaining_encoder_layers = [
            _dense(n_neurons, act_fun, name)
            for n_neurons, act_fun, name in self.encoder_layers[1:]
        ]
        self.encoder_model = tf.keras.Sequential(
            [input_layer] + remaining_encoder_layers,
            name="Encoder",
        )

        # decoder
        self.decoder_model = tf.keras.Sequential(
            [
                _dense(n_neurons, act_fun, name)
                for n_neurons, act_fun, name in self.decoder_layers
            ],
            name="Decoder",
        )

        # build
        self.build(input_shape=(1, self.input_dim))

        # train counter
        self._my_train_counter = K.variable(0, "int64", name="train_counter")

    @classmethod
    def from_config(
        cls: Type[SequentialModelType],
        config: dict[str, Any],
        custom_objects: Optional[dict[Any, Any]] = None,
    ) -> SequentialModelType:
        """Reconstructs this keras serializable from a dict.

        A new instance is created from the stored input dimension and
        parameters and receives the weights of the deserialized encoder and
        decoder. All keys that are read are popped from ``config``.

        Args:
            config (dict[str, Any]): A dictionary.
            custom_objects (Optional[dict[str, Any]]): Not used here, see
                https://keras.io/guides/serialization_and_saving/ for
                details.

        Returns:
            SequentialModelType: An instance of the SequentialModel.

        """
        deserialize = tf.keras.saving.deserialize_keras_object

        input_dim = config.pop("input_dim")
        sparse = config.pop("sparse")

        # the presence of an ADC-only key decides which parameter class is used
        raw_parameters = config.pop("parameters")
        if "cartesian_pwd_start" in raw_parameters:
            parameter_class = ADCParameters
        else:
            parameter_class = Parameters
        parameters = parameter_class(**raw_parameters)

        # the stored dimension was doubled for periodic data; the
        # constructor doubles it again, so it is halved here
        if parameters.periodicity < float("inf"):
            input_dim //= 2

        encoder = deserialize(config.pop("encoder"))
        decoder = deserialize(config.pop("decoder"))
        get_dense_model = config.pop("get_dense_model")
        if get_dense_model is not None:
            get_dense_model = deserialize(get_dense_model)

        new_class = cls(
            input_dim=input_dim,
            parameters=parameters,
            sparse=sparse,
            get_dense_model=get_dense_model,
        )

        new_class.encoder_model.set_weights(encoder.get_weights())
        new_class.decoder_model.set_weights(decoder.get_weights())

        # the same call is made for periodic and non-periodic data
        new_class.compute_output_shape(input_shape=(1, input_dim))

        return new_class

    def get_config(self) -> dict[str, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[str, Any]: A dict with the serializable objects. The key
                ``get_dense_model`` holds None when no such model is set.

        """
        serialize = tf.keras.saving.serialize_keras_object

        config = super().get_config()
        config["input_dim"] = self.input_dim
        config["parameters"] = self.p.to_dict()
        config["sparse"] = self.sparse
        config["encoder"] = serialize(self.encoder_model)
        config["decoder"] = serialize(self.decoder_model)
        if self.get_dense_model is None:
            config["get_dense_model"] = None
        else:
            config["get_dense_model"] = serialize(self.get_dense_model)
        return config

    def build(self, input_shape):
        """Build the model from the input shape of the encoder.

        Args:
            input_shape: Ignored. The shape is always read from
                ``self.encoder_model.input_shape``.

        """
        feature_dim = self.encoder_model.input_shape
        if isinstance(feature_dim, tuple):
            feature_dim = feature_dim[1]

        # for periodic data the encoder sees twice as many features as the
        # model itself receives
        if self.p.periodicity < float("inf"):
            feature_dim = feature_dim // 2
        super().build(input_shape=(1, feature_dim))

    def compile(self, *args, **kwargs):
        """Compile the model and remember the unpacked loss functions.

        Args:
            *args: Forwarded to the compile method of the parent class.
            **kwargs: Forwarded to the compile method of the parent class,
                except ``unpacked_loss_fns``, which is stored on the instance.

        Raises:
            KeyError: If ``unpacked_loss_fns`` is not given as a keyword.

        """
        self.unpacked_loss_fns = kwargs.pop("unpacked_loss_fns")
        super().compile(*args, **kwargs)

    def encoder(self, x, training=False):
        """In the sequential model, the encoder is a method (as opposed to a model).

        This method prepares the input when the periodicity of the input data
        is finite, i.e. less than float('inf'): the data is rescaled to a
        2*pi period if necessary, and its sine and cosine are concatenated.

        Args:
            x (Union[np.ndarray, tf.Tensor]): The input.
            training (bool): Forwarded to the encoder model.

        Returns:
            Union[np.ndarray, tf.Tensor]: The output of the encoder.

        """
        if self.sparse:
            x = self.get_dense_model(x)

        periodicity = self.p.periodicity
        if periodicity < float("inf"):
            if periodicity != 2 * pi:
                x = x / periodicity * 2 * pi
            x = tf.concat([tf.sin(x), tf.cos(x)], 1)
        return self.encoder_model(x, training=training)

    def decoder(self, x, training=False):
        """Decode ``x`` and undo the sine/cosine expansion for periodic data.

        Args:
            x (Union[np.ndarray, tf.Tensor]): The input of the decoder.
            training (bool): Forwarded to the decoder model.

        Returns:
            Union[np.ndarray, tf.Tensor]: The output of the decoder. For a
                finite periodicity, the two halves of the decoder output are
                combined with ``atan2`` and rescaled to the periodicity.

        """
        x = self.decoder_model(x, training=training)

        periodicity = self.p.periodicity
        if periodicity < float("inf"):
            first_half, second_half = tf.split(x, 2, 1)
            x = tf.atan2(first_half, second_half)
            if periodicity != 2 * pi:
                x = x / (2 * pi) * periodicity
        return x

    def call(self, x, training=False):
        """Encode and decode ``x``.

        Args:
            x (Union[np.ndarray, tf.Tensor]): The input.
            training (bool): Forwarded to the encoder and the decoder.

        Returns:
            Union[np.ndarray, tf.Tensor]: The decoded output.

        """
        latent = self.encoder(x, training=training)
        return self.decoder(latent, training=training)

    def train_step(self, data):
        """Overwrites the normal train_step.

        The provided data is expected to be a tuple of (x, y), like in
        classification tasks. The data is unpacked and y is discarded,
        because the autoencoder is compared against its own input.

        Args:
            data (tuple): The (x, y) data of this train step.

        Returns:
            dict: The results of all metrics by name, plus the key ``"loss"``.

        """
        x, _ = data
        # NOTE(review): with `self.sparse` set, x is densified here and
        # `self.encoder` applies `get_dense_model` again inside `self(x)`;
        # confirm that this double application is intended.
        if self.sparse:
            x = self.get_dense_model(x)

        step = self._my_train_counter
        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(step)
            with tf.name_scope("Cost"):
                loss = 0.0
                # a separate forward pass is made for every loss function
                for loss_fn in self.compiled_loss._losses:
                    loss += loss_fn(x, self(x, training=True))
                tf.summary.scalar("Combined Cost", loss)
            for layer in self.encoder_model.layers + self.decoder_model.layers:
                add_layer_summaries(layer, step=step)

        # Compute Gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        self.compiled_metrics.update_state(x, self(x))
        # the train counter is maintained by hand
        self._my_train_counter.assign_add(1)

        # Return a dict mapping metric names to current value
        results = {metric.name: metric.result() for metric in self.metrics}
        results["loss"] = loss
        return results
@testing
class ADCSequentialModel(SequentialModel):
    """Sequential model variant for angle, dihedral and cartesian inputs.

    The network itself is inherited from ``SequentialModel``. This class
    concatenates the angular inputs, maps the network output back to
    cartesian coordinates, and sums the loss functions stored in
    ``self.unpacked_loss_fns``.

    """

    def __init__(self, input_dim, parameters=None, reload_layers=None, sparse=False):
        """Instantiate the model.

        Args:
            input_dim (int): Forwarded to ``SequentialModel``.
            parameters (Optional[ADCParameters]): The parameters. If None is
                provided, a default ``ADCParameters()`` instance is used.
            reload_layers: Unused. Kept in the signature so that existing
                positional calls keep working.
            sparse (bool): Forwarded to ``SequentialModel``.

        """
        warnings.warn("check split")
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        # used to tile the mean angles over one batch in call_and_map_back
        self.multiples = tf.TensorShape((self.p.batch_size, 1))
        # The third positional parameter of SequentialModel is `sparse`.
        # `reload_layers` used to be passed in that position, which dropped
        # the `sparse` argument; it is now forwarded by keyword.
        super().__init__(input_dim, self.p, sparse=sparse)

    def call(self, x, training=False):
        """Encode and decode ``x``.

        Args:
            x (Union[tf.Tensor, tuple]): Either a single tensor, or a tuple
                of (angles, dihedrals, cartesians, distances, side
                dihedrals), from which the network input is concatenated
                according to the parameters. No branch covers sidechains
                without backbone angles; in that case the tuple itself is
                handed to the encoder.
            training (bool): Forwarded to the encoder and the decoder.

        Returns:
            tf.Tensor: The decoded output.

        """
        if isinstance(x, tf.Tensor):
            pass
        elif isinstance(x, tuple):
            (
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                inp_distances,
                inp_side_dihedrals,
            ) = x
            if not self.p.use_backbone_angles and not self.p.use_sidechains:
                x = inp_dihedrals
            elif self.p.use_backbone_angles and not self.p.use_sidechains:
                x = tf.concat([inp_angles, inp_dihedrals], 1)
            elif self.p.use_backbone_angles and self.p.use_sidechains:
                x = tf.concat([inp_angles, inp_dihedrals, inp_side_dihedrals], 1)
        encoded = self.encoder(x, training=training)
        decoded = self.decoder(encoded, training=training)
        return decoded

    def call_and_map_back(
        self, x, distances, angles, dihedrals, cartesians, splits, side_dihedrals=None
    ):
        """Run the network, back-map to cartesians and sum up all losses.

        Args:
            x: The network input.
            distances: The input distances. Their mean over axis 0 is used
                for the back-mapping.
            angles: The input angles.
            dihedrals: The input dihedrals.
            cartesians: The input cartesians.
            splits: None, or a list of two or three sizes with which the
                network output is split along axis 1 into angles, dihedrals
                and (for three sizes) side dihedrals. For None, the output
                is taken as dihedrals and the mean input angles are tiled.
            side_dihedrals: The input side dihedrals. Only used when
                ``self.p.use_sidechains`` is set.

        Returns:
            The combined loss.

        """
        out = self(x, training=True)
        latent = self.encoder(x, training=True)

        # unpack out
        if splits is None:
            out_dihedrals = out
            out_angles = tf.tile(
                tf.expand_dims(tf.reduce_mean(angles, 0), 0), multiples=self.multiples
            )
        elif len(splits) == 2:
            out_angles, out_dihedrals = tf.split(out, splits, 1)
        elif len(splits) == 3:
            out_angles, out_dihedrals, out_side_dihedrals = tf.split(out, splits, 1)

        # do back-mapping
        back_mean_lengths = tf.expand_dims(tf.reduce_mean(distances, 0), 0)
        back_chain_in_plane = chain_in_plane(back_mean_lengths, out_angles)
        back_cartesians = dihedrals_to_cartesian_tf(
            out_dihedrals + pi, back_chain_in_plane
        )

        # pairwise distances of the atoms selected by the cartesian_pwd_*
        # parameters, for input and back-mapped output alike
        pwd_selection = slice(
            self.p.cartesian_pwd_start,
            self.p.cartesian_pwd_stop,
            self.p.cartesian_pwd_step,
        )
        inp_pair = pairwise_dist(cartesians[:, pwd_selection], flat=True)
        out_pair = pairwise_dist(back_cartesians[:, pwd_selection], flat=True)

        with tf.name_scope("Cost"):
            loss = 0.0
            # dihedral loss
            loss += self.unpacked_loss_fns["dihedral_loss_func"](
                dihedrals, out_dihedrals
            )
            # angle loss
            loss += self.unpacked_loss_fns["angle_loss_func"](angles, out_angles)
            # cartesian loss
            loss += self.unpacked_loss_fns["cartesian_loss_func"](inp_pair, out_pair)
            # distance loss
            loss += self.unpacked_loss_fns["distance_loss_func"](x)
            # cartesian distance cost
            loss += self.unpacked_loss_fns["cartesian_distance_loss_func"](
                inp_pair, latent
            )
            # center loss
            loss += self.unpacked_loss_fns["center_loss_func"](x)
            # reg loss
            loss += self.unpacked_loss_fns["regularization_loss_func"]()
            if self.p.use_sidechains:
                loss += self.unpacked_loss_fns["side_dihedral_loss_func"](
                    side_dihedrals, out_side_dihedrals
                )
            tf.summary.scalar("Combined Cost", loss)
        return loss

    def train_step(self, data):
        """Run one optimization step on a single batch.

        Args:
            data (tuple): A tuple of (angles, dihedrals, cartesians,
                distances, side dihedrals). No branch covers sidechains
                without backbone angles; in that case ``main_inputs`` and
                ``splits`` are never assigned.

        Returns:
            dict: The results of all metrics by name, plus the key ``"loss"``.

        """
        # unpack the data
        (
            inp_angles,
            inp_dihedrals,
            inp_cartesians,
            inp_distances,
            inp_side_dihedrals,
        ) = data
        if not self.p.use_backbone_angles and not self.p.use_sidechains:
            main_inputs = inp_dihedrals
            splits = None
        elif self.p.use_backbone_angles and not self.p.use_sidechains:
            main_inputs = tf.concat([inp_angles, inp_dihedrals], 1)
            splits = [inp_angles.shape[1], inp_dihedrals.shape[1]]
        elif self.p.use_backbone_angles and self.p.use_sidechains:
            # ToDo: make sure the splits work here. There seems to be different outputs from tf and np.
            main_inputs = tf.concat([inp_angles, inp_dihedrals, inp_side_dihedrals], 1)
            splits = [
                inp_angles.shape[1],
                inp_dihedrals.shape[1],
                inp_side_dihedrals.shape[1],
            ]

        # gradient tape to calculate loss for backmapping
        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._my_train_counter)
            # the side dihedrals are only handed over when they are used;
            # None is the default of call_and_map_back
            loss = self.call_and_map_back(
                main_inputs,
                inp_distances,
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                splits,
                inp_side_dihedrals if self.p.use_sidechains else None,
            )
            for layer in self.encoder_model.layers + self.decoder_model.layers:
                add_layer_summaries(layer, step=self._my_train_counter)

        # optimization happens here
        # Compute Gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        self.compiled_metrics.update_state(data, self(data))
        # Return a dict mapping metric names to current value
        # Add loss to the dict so the ProgressBar callback can pick it up
        return {**{m.name: m.result() for m in self.metrics}, **{"loss": loss}}