Coverage for encodermap/models/layers.py: 83%

72 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/models/layers.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Module that implements custom layers. Mainly needed for handling periodicity, 

23backmapping or sparsity.""" 

24 

25 

26################################################################################ 

27# Imports 

28################################################################################ 

29 

30 

31from __future__ import annotations 

32 

33from math import pi 

34 

35import tensorflow as tf 

36from tensorflow.keras.layers import Concatenate, Dense, Lambda, Layer 

37 

38from ..encodermap_tf1.backmapping import chain_in_plane 

39from ..misc import pairwise_dist 

40from ..misc.backmapping import ( 

41 dihedrals_to_cartesian_tf_layers, 

42 split_and_reverse_cartesians, 

43 split_and_reverse_dihedrals, 

44) 

45 

46################################################################################ 

47# Globals 

48################################################################################ 

49 

50 

# Public API of this module.
__all__ = [
    "PeriodicInput",
    "PeriodicOutput",
    "MeanAngles",
    "BackMapLayer",
    "Sparse",
]

52 

53 

54################################################################################ 

55# Layers 

56################################################################################ 

57 

58 

class Sparse(Dense):
    """Dense layer whose forward pass accepts a ``tf.SparseTensor`` input.

    Subclass of :class:`tf.keras.layers.Dense` that replaces the dense matrix
    product with :func:`tf.sparse.sparse_dense_matmul`. Kernel, bias and
    activation are configured exactly as for ``Dense``.
    """

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: A ``tf.SparseTensor`` multiplied with ``self.kernel``.
                ``sparse_dense_matmul`` expects a rank-2 sparse operand
                (TODO confirm callers never pass higher ranks).

        Returns:
            ``activation(inputs @ kernel + bias)``. The bias is only added
            when ``use_bias`` is set; the activation is only applied when
            ``self.activation`` is not None, mirroring ``Dense.call``.
        """
        outputs = tf.sparse.sparse_dense_matmul(inputs, self.kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        # Dense.call applies the configured activation. Without this step an
        # activation passed to the constructor would be silently ignored.
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs

68 

69 

class SparseReshape(tf.keras.layers.Reshape):
    """Layer that can reshape a sparse Tensor.

    Behaves like :class:`tf.keras.layers.Reshape`. The first (batch)
    dimension is kept and the remaining dimensions are reshaped to
    ``self.target_shape``, the tuple stored by ``Reshape.__init__``.
    """

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: A ``tf.SparseTensor`` or a dense tensor whose first axis
                is the batch axis.

        Returns:
            The reshaped tensor. It is sparse if ``inputs`` was sparse and
            has shape ``(batch,) + self.target_shape``.
        """
        if isinstance(inputs, tf.SparseTensor):
            # tf.reshape cannot consume SparseTensors. tf.sparse.reshape needs
            # a 1-D int64 shape, so build it from the runtime batch size.
            new_shape = tf.concat(
                [
                    inputs.dense_shape[:1],
                    tf.constant(self.target_shape, dtype=tf.int64),
                ],
                axis=0,
            )
            return tf.sparse.reshape(inputs, new_shape)
        # Dense fallback: identical to tf.keras.layers.Reshape.call.
        return tf.reshape(inputs, (tf.shape(inputs)[0],) + self.target_shape)

75 

76 

class PeriodicInput(Layer):
    """Map periodic input features onto the unit circle.

    Needed when angles (or any other periodic quantity) are used as input.
    If ``parameters.periodicity`` differs from 2*pi, every value is first
    rescaled to a period of 2*pi. It is then replaced by its sine and cosine.
    The sines and cosines are concatenated along axis 1, so the result has
    twice as many columns as the input.
    """

    def __init__(self, parameters, print_name, trainable=False):
        """Instantiate the layer.

        Args:
            parameters: Parameter object. Only its ``periodicity`` attribute
                is read by this layer. Angles are the usual case, but any
                hyper-toroidal manifold can be handled through it.
            print_name: Prefix for this layer's name and its sub-layer names.
            trainable: Forwarded positionally to ``Layer.__init__``.
        """
        super().__init__(trainable)
        self.p = parameters
        self.print_name = print_name
        self._name = self.print_name + "_Periodic_Input"

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: Tensor of periodic values (concatenation happens on axis 1).

        Returns:
            ``[sin(x), cos(x)]`` concatenated along axis 1. Here ``x`` is
            ``inputs`` rescaled to a period of 2*pi.
        """
        prefix = self.print_name
        # The concatenation layer is created first, as in the original nesting.
        concat = Concatenate(axis=1, name=f"{prefix}_Concat")

        angles = inputs
        if self.p.periodicity != 2 * pi:
            rescale = Lambda(
                lambda x: x / self.p.periodicity * 2 * pi,
                name=f"{prefix}_Periodicity_to_2_pi",
            )
            angles = rescale(angles)

        sines = Lambda(lambda x: tf.sin(x), name=f"{prefix}_sin")(angles)
        cosines = Lambda(lambda x: tf.cos(x), name=f"{prefix}_cos")(angles)
        return concat([sines, cosines])

109 

110 

class BackMapLayer(Layer):
    """Backmap (distances, angles, dihedrals) to Euclidean coordinates.

    The layer runs three steps:
    1. Average the distances over axis 0.
    2. Build a planar chain from the averaged lengths and the angles.
    3. Apply the pi-shifted dihedrals to that chain.
    """

    def __init__(self):
        """Instantiate the layer under the fixed name ``Backmap_Layer``."""
        super().__init__()
        self._name = "Backmap_Layer"

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: A 3-sequence that unpacks as
                ``distances, angles, dihedrals = inputs``.

        Returns:
            The tensor produced by ``dihedrals_to_cartesian_tf_layers``.
        """
        distances, angles, dihedrals = inputs

        # Mean over axis 0 (presumably the batch axis), re-expanded so the
        # result keeps a leading axis of length 1.
        mean_lengths = Lambda(
            lambda d: tf.expand_dims(tf.reduce_mean(d, 0), 0),
            name="Back_Mean_Lengths",
        )(distances)

        # Lay out the chain in a plane from the mean lengths and the angles.
        planar_chain = Lambda(
            lambda pair: chain_in_plane(pair[0], pair[1]),
            name="Back_Chain_in_Plane",
        )((mean_lengths, angles))

        # Shift the dihedrals by pi, then rotate them into the planar chain.
        shifted_dihedrals = Lambda(lambda t: tf.add(t, pi), name="Added_Pi")(
            dihedrals
        )
        cartesians = Lambda(
            lambda pair: dihedrals_to_cartesian_tf_layers(pair[0], pair[1]),
            name="Back_Cartesians",
        )((shifted_dihedrals, planar_chain))
        return cartesians

144 

145 

class PeriodicOutput(Layer):
    """Inverse of the periodic input layer: turn (sin, cos) halves back into angles."""

    def __init__(self, parameters, print_name, trainable=False):
        """Instantiate the layer.

        Args:
            parameters: Parameter object. Its ``periodicity`` attribute sets
                the period of the returned values.
            print_name: Prefix for this layer's name and its sub-layer names.
            trainable: Forwarded positionally to ``Layer.__init__``.
        """
        super().__init__(trainable)
        self.p = parameters
        self.print_name = print_name
        self._name = self.print_name + "_Periodic_Output"

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: Tensor split into two equal halves along axis 1. The first
                half is used as the sines, the second as the cosines.

        Returns:
            ``atan2(sin, cos)``, rescaled from a period of 2*pi to
            ``self.p.periodicity`` whenever those differ.
        """
        # NOTE(review): "Fom" (sic) is kept because it is a runtime layer name.
        # Renaming it may break anything that looks layers up by name.
        from_unit_circle = Lambda(
            lambda x: tf.atan2(*tf.split(x, 2, 1)),
            name=f"{self.print_name}_Fom_Unitcircle",
        )
        angles = from_unit_circle(inputs)

        if self.p.periodicity != 2 * pi:
            to_periodicity = Lambda(
                lambda x: x / (2 * pi) * self.p.periodicity,
                name=f"{self.print_name}_Periodicity_from_2_pi",
            )
            angles = to_periodicity(angles)
        return angles

169 

170 

class MeanAngles(Layer):
    """Replace the input by its mean over axis 0, tiled ``multiples_shape`` times.

    NOTE(review): this is the plain arithmetic mean, not a circular mean. It
    is only meaningful when the angles fed in do not wrap around the period
    boundary (TODO confirm for all callers).
    """

    def __init__(self, parameters, print_name, multiples_shape):
        """Instantiate the layer.

        Args:
            parameters: Parameter object, stored as ``self.p``.
            print_name: Name of this layer and of its internal Lambda layer.
            multiples_shape: How often the mean row is repeated along axis 0.
        """
        super().__init__()
        self.p = parameters
        self.print_name = print_name
        self.multiples_shape = multiples_shape
        self._name = self.print_name

    def call(self, inputs):
        """Call the layer.

        Returns:
            A tensor in which every row along axis 0 equals the mean of
            ``inputs`` over axis 0. There are ``self.multiples_shape`` rows.
            The tile multiples ``(n, 1)`` imply ``inputs`` is rank 2.
        """

        def _tiled_mean(x):
            mean_row = tf.expand_dims(tf.math.reduce_mean(x, 0), 0)
            return tf.tile(mean_row, multiples=(self.multiples_shape, 1))

        return Lambda(_tiled_mean, name=self.print_name)(inputs)

193 

194 

class PairwiseDistances(Layer):
    """Compute flattened pairwise distances of a strided selection along axis 1."""

    def __init__(self, parameters, print_name):
        """Instantiate the layer.

        Args:
            parameters: Parameter object providing ``cartesian_pwd_start``,
                ``cartesian_pwd_stop`` and ``cartesian_pwd_step``.
            print_name: Prefix for this layer's name and its sub-layer name.
        """
        super().__init__()
        self.p = parameters
        self.print_name = print_name
        self._name = self.print_name + "_Pairwise"

    def call(self, inputs):
        """Call the layer.

        Args:
            inputs: Tensor sliced along axis 1 with
                ``start:stop:step`` taken from ``self.p``.

        Returns:
            ``pairwise_dist(selection, flat=True)`` of the sliced tensor.
        """
        window = slice(
            self.p.cartesian_pwd_start,
            self.p.cartesian_pwd_stop,
            self.p.cartesian_pwd_step,
        )
        selected = inputs[:, window]
        to_distances = Lambda(
            lambda x: pairwise_dist(x, flat=True),
            name=f"{self.print_name}_Pairwise_Distances",
        )
        return to_distances(selected)