Coverage for encodermap/models/layers.py: 83%
72 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/models/layers.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Module that implements custom layers. Mainly needed for handling periodicity,
23backmapping or sparsity."""
26################################################################################
27# Imports
28################################################################################
31from __future__ import annotations
33from math import pi
35import tensorflow as tf
36from tensorflow.keras.layers import Concatenate, Dense, Lambda, Layer
38from ..encodermap_tf1.backmapping import chain_in_plane
39from ..misc import pairwise_dist
40from ..misc.backmapping import (
41 dihedrals_to_cartesian_tf_layers,
42 split_and_reverse_cartesians,
43 split_and_reverse_dihedrals,
44)
46################################################################################
47# Globals
48################################################################################
# Names exported by ``from encodermap.models.layers import *``.
# SparseReshape and PairwiseDistances are not listed here and must be
# imported explicitly.
__all__ = [
    "PeriodicInput",
    "PeriodicOutput",
    "MeanAngles",
    "BackMapLayer",
    "Sparse",
]
54################################################################################
55# Layers
56################################################################################
class Sparse(Dense):
    """Dense layer whose forward pass accepts a sparse input tensor.

    Subclass of :class:`tf.keras.layers.Dense` that computes the matrix
    product with :func:`tf.sparse.sparse_dense_matmul` instead of a dense
    matmul, so a ``tf.SparseTensor`` input can be multiplied with the dense
    kernel. Bias and activation are applied in the same order as in
    :class:`Dense`.
    """

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: The layer input, presumably a rank-2 ``tf.SparseTensor``
                of shape (batch, input_dim) -- TODO confirm with callers.

        Returns:
            tf.Tensor: ``activation(inputs @ kernel + bias)``; the bias is
            only added when ``use_bias`` is True.
        """
        outputs = tf.sparse.sparse_dense_matmul(inputs, self.kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        # Dense.call finishes by applying the activation. Without this step an
        # ``activation=...`` argument given to the layer would be silently
        # ignored. The default (linear) activation is the identity.
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs
class SparseReshape(tf.keras.layers.Reshape):
    """Layer that can reshape a sparse tensor.

    As with :class:`tf.keras.layers.Reshape`, ``target_shape`` is the
    per-sample shape. The leading (batch) dimension of the input is kept.
    """

    def call(self, inputs):
        """Reshape ``inputs`` to ``(batch,) + target_shape``.

        ``tf.reshape`` cannot consume a ``tf.SparseTensor``, so sparse inputs
        go through ``tf.sparse.reshape``. Dense inputs are delegated to
        :meth:`tf.keras.layers.Reshape.call`.

        Args:
            inputs: A ``tf.SparseTensor`` or a dense tensor.

        Returns:
            The reshaped tensor, of the same kind (sparse/dense) as ``inputs``.
        """
        if isinstance(inputs, tf.SparseTensor):
            # ``dense_shape`` is an int64 vector; ``[:1]`` keeps the batch size.
            # ``target_shape`` is stored as a tuple by Reshape, so it must not
            # be wrapped in another tuple (that yields an invalid rank-2 shape).
            new_shape = tf.concat(
                [
                    inputs.dense_shape[:1],
                    tf.constant(self.target_shape, dtype=tf.int64),
                ],
                axis=0,
            )
            return tf.sparse.reshape(inputs, new_shape)
        return super().call(inputs)
class PeriodicInput(Layer):
    """Layer that handles periodic input, e.g. angles.

    The input features are first rescaled to a period of ``2 * pi`` (only if
    ``parameters.periodicity`` differs from ``2 * pi``). They are then mapped
    onto the unit circle: the output is ``[sin(x), cos(x)]`` concatenated
    along axis 1. For an input of shape ``(batch, n)`` the output therefore
    has shape ``(batch, 2 * n)``, with the sin block first and the cos block
    second.
    """

    def __init__(self, parameters, print_name, trainable=False):
        """Instantiate the layer.

        Args:
            parameters: Parameter object. Its ``periodicity`` attribute gives
                the period of the input features. Angles (period ``2 * pi``)
                are the common case, but any hyper-toroidal manifold works.
            print_name (str): Prefix for this layer's name and for the names
                of the sub-layers created in :meth:`call`.
            trainable (bool): Forwarded to :class:`tf.keras.layers.Layer`.
                Defaults to False.
        """
        super().__init__(trainable=trainable)
        self.p = parameters
        self.print_name = print_name
        self._name = self.print_name + "_Periodic_Input"

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: Tensor of periodic features; the features are expected
                along axis 1.

        Returns:
            tf.Tensor: ``concat([sin(x), cos(x)], axis=1)`` of the (rescaled)
            inputs.
        """
        scaled = inputs
        if self.p.periodicity != 2 * pi:
            # Map one period of the input onto [0, 2*pi).
            scaled = Lambda(
                lambda x: x / self.p.periodicity * 2 * pi,
                name=f"{self.print_name}_Periodicity_to_2_pi",
            )(scaled)
        sin_part = Lambda(lambda x: tf.sin(x), name=f"{self.print_name}_sin")(scaled)
        cos_part = Lambda(lambda x: tf.cos(x), name=f"{self.print_name}_cos")(scaled)
        return Concatenate(axis=1, name=f"{self.print_name}_Concat")(
            [sin_part, cos_part]
        )
class BackMapLayer(Layer):
    """Convert (distances, angles, dihedrals) into Cartesian coordinates.

    The pipeline has three steps:

    1. Average the distances over axis 0 (keeping a leading axis of size 1).
    2. Pass those mean lengths and the angles to ``chain_in_plane``.
    3. Pass the dihedrals, shifted by ``pi``, together with the result of
       step 2 to ``dihedrals_to_cartesian_tf_layers``.
    """

    def __init__(self):
        """Create the layer; its name is fixed to ``"Backmap_Layer"``."""
        super().__init__()
        self._name = "Backmap_Layer"

    def call(self, inputs):
        """Run the back-mapping.

        Args:
            inputs: A 3-sequence that unpacks as
                ``distances, angles, dihedrals``.

        Returns:
            The output of ``dihedrals_to_cartesian_tf_layers``.
        """
        distances, angles, dihedrals = inputs

        def _batch_mean(x):
            # Mean over axis 0, re-expanded so a leading axis of size 1 remains.
            return tf.expand_dims(tf.reduce_mean(x, 0), 0)

        mean_lengths = Lambda(_batch_mean, name="Back_Mean_Lengths")(distances)

        planar_chain = Lambda(
            lambda pair: chain_in_plane(pair[0], pair[1]),
            name="Back_Chain_in_Plane",
        )((mean_lengths, angles))

        shifted_dihedrals = Lambda(lambda d: tf.add(d, pi), name="Added_Pi")(
            dihedrals
        )

        to_cartesians = Lambda(
            lambda pair: dihedrals_to_cartesian_tf_layers(pair[0], pair[1]),
            name="Back_Cartesians",
        )
        return to_cartesians((shifted_dihedrals, planar_chain))
class PeriodicOutput(Layer):
    """Layer that reverses :class:`PeriodicInput`.

    Takes a tensor whose axis 1 holds sin values in its first half and cos
    values in its second half. It recovers the angles with ``atan2`` and,
    if needed, rescales them from a period of ``2 * pi`` to
    ``parameters.periodicity``.
    """

    def __init__(self, parameters, print_name, trainable=False):
        """Instantiate the layer.

        Args:
            parameters: Parameter object. Its ``periodicity`` attribute is the
                period the output angles are scaled to.
            print_name (str): Prefix for this layer's name and for the names
                of the sub-layers created in :meth:`call`.
            trainable (bool): Forwarded to :class:`tf.keras.layers.Layer`.
                Defaults to False.
        """
        super().__init__(trainable=trainable)
        self.p = parameters
        self.print_name = print_name
        self._name = self.print_name + "_Periodic_Output"

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: A single tensor (not a tuple) of the form
                ``concat([sin, cos], axis=1)``. The size of axis 1 must be
                even so it can be split into two equal halves.

        Returns:
            tf.Tensor: The angles. ``atan2`` yields values in ``[-pi, pi]``;
            these are rescaled to ``[-periodicity / 2, periodicity / 2]``
            when ``periodicity != 2 * pi``.
        """
        # atan2(y=sin_half, x=cos_half): split order matches the sin-then-cos
        # layout produced by PeriodicInput.
        outputs = Lambda(
            lambda x: tf.atan2(*tf.split(x, 2, 1)),
            name=f"{self.print_name}_Fom_Unitcircle",
        )(inputs)
        if self.p.periodicity != 2 * pi:
            outputs = Lambda(
                lambda x: x / (2 * pi) * self.p.periodicity,
                name=f"{self.print_name}_Periodicity_from_2_pi",
            )(outputs)
        return outputs
class MeanAngles(Layer):
    """Replace every sample by the batch mean of the angles.

    The mean is the plain arithmetic mean over axis 0
    (``tf.math.reduce_mean``), not a circular mean. It is therefore only
    meaningful for angles that do not wrap around the period boundary. The
    mean row is tiled ``multiples_shape`` times along axis 0.

    NOTE(review): ``parameters`` is stored as ``self.p`` but is not used by
    this class.
    """

    def __init__(self, parameters, print_name, multiples_shape):
        """Instantiate the layer.

        Args:
            parameters: Parameter object (stored, currently unused).
            print_name (str): Name of this layer and of its inner Lambda layer.
            multiples_shape: Number of copies of the mean row along axis 0.
        """
        super().__init__()
        self.p = parameters
        self.print_name = print_name
        self.multiples_shape = multiples_shape
        self._name = self.print_name

    def call(self, inputs):
        """Return the tiled batch mean of ``inputs``.

        ``tile`` is called with two multiples, so ``inputs`` is expected
        to be rank 2.
        """

        def _tiled_mean(x):
            mean_row = tf.expand_dims(tf.math.reduce_mean(x, 0), 0)
            return tf.tile(mean_row, multiples=(self.multiples_shape, 1))

        mean_layer = Lambda(_tiled_mean, name=self.print_name)
        return mean_layer(inputs)
class PairwiseDistances(Layer):
    """Compute flattened pairwise distances of a strided subset of the input."""

    def __init__(self, parameters, print_name):
        """Instantiate the layer.

        Args:
            parameters: Parameter object providing ``cartesian_pwd_start``,
                ``cartesian_pwd_stop`` and ``cartesian_pwd_step``. These
                select entries along axis 1 of the input.
            print_name (str): Prefix for this layer's name and its inner
                Lambda layer's name.
        """
        super().__init__()
        self.p = parameters
        self.print_name = print_name
        self._name = self.print_name + "_Pairwise"

    def call(self, inputs):
        """Select ``inputs[:, start:stop:step]`` and return its pairwise distances.

        Presumably ``inputs`` holds Cartesian coordinates with points along
        axis 1 -- TODO confirm. The result is whatever
        ``pairwise_dist(..., flat=True)`` returns.
        """
        params = self.p
        columns = slice(
            params.cartesian_pwd_start,
            params.cartesian_pwd_stop,
            params.cartesian_pwd_step,
        )
        selected = inputs[:, columns]
        distance_layer = Lambda(
            lambda coords: pairwise_dist(coords, flat=True),
            name=f"{self.print_name}_Pairwise_Distances",
        )
        return distance_layer(selected)