Coverage for encodermap/models/models.py: 43%
457 statements

# -*- coding: utf-8 -*-
# encodermap/models/models.py
################################################################################
# Encodermap: A python library for dimensionality reduction.
#
# Copyright 2019-2022 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""ToDo:
    * Add some nice images to the plot_model of the functional model.

"""

##############################################################################
# Imports
##############################################################################

import warnings
from math import pi

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Concatenate, Dense, Input, Lambda

from ..encodermap_tf1.backmapping import chain_in_plane, dihedrals_to_cartesian_tf
from ..misc import pairwise_dist
from ..misc.summaries import add_layer_summaries
from ..parameters.parameters import ADCParameters, Parameters
from ..trajinfo.info_all import Capturing
from .layers import (
    BackMapLayer,
    MeanAngles,
    PairwiseDistances,
    PeriodicInput,
    PeriodicOutput,
    Sparse,
    SparseReshape,
)

##############################################################################
# Globals
##############################################################################

__all__ = ["gen_sequential_model", "gen_functional_model"]

##############################################################################
# Public Functions
##############################################################################


def gen_sequential_model(input_shape, parameters=None, sparse=False):
    """Returns a tf.keras Model built with the specified input shape and the parameters in the Parameters class.

    Args:
        input_shape (int): The input shape of the returned model. In most cases that is data.shape[1] of your data.
        parameters (Union[encodermap.Parameters, encodermap.ADCParameters, None], optional): The parameters to
            use on the returned model. If None is provided, the default parameters in encodermap.Parameters.defaults
            are used. You can look at the defaults with print(em.Parameters.defaults_description()). Defaults to None.
        sparse (bool, optional): Whether sparse inputs are expected. Defaults to False.

    Returns:
        em.SequentialModel: A subclass of tf.keras.Model built with the specified parameters.
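
    Example:
        A minimal, hedged sketch (assuming the default parameters and a
        100-dimensional dataset; shapes are illustrative only):

        >>> from encodermap.models.models import gen_sequential_model
        >>> model = gen_sequential_model(input_shape=100)  # doctest: +SKIP
        >>> model.encoder_model.summary()  # doctest: +SKIP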

    """
    if parameters is None:
        parameters = Parameters()
    if isinstance(parameters, Parameters):
        return SequentialModel(input_shape, parameters, sparse=sparse)
    elif isinstance(parameters, ADCParameters):
        return ADCSequentialModel(input_shape, parameters)
    else:
        p = parameters
        raise TypeError(
            f"parameters need to be encodermap.Parameters or encodermap.ADCParameters. You supplied {type(p)}"
        )


class Sparse(tf.keras.layers.Dense):
    def call(self, inputs):
        outputs = tf.sparse.sparse_dense_matmul(inputs, self.kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        return outputs


class SparseModel(tf.keras.Model):
    def __init__(self, name, input_shape):
        super(SparseModel, self).__init__(name=name)
        self._sparse_layer = Sparse(input_shape)
        inputs = tf.keras.layers.Input(
            shape=(input_shape,), sparse=True, name="sparse_tensor"
        )
        self._set_inputs(inputs)

    def call(self, sparse_tensor):
        return self._sparse_layer(sparse_tensor)
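

# A hedged usage sketch for SparseModel (shapes illustrative only): feed a
# tf.sparse.SparseTensor and get a dense projection of the same width back.
#
#     st = tf.sparse.from_dense(tf.constant([[0.0, 1.0, 0.0, 2.0]]))
#     dense = SparseModel(name="Sparse_Model", input_shape=4)(st)
#     # dense now has shape (1, 4); gradients flow through the Sparse kernel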


def gen_functional_model(
    input_dataset, parameters=None, reload_layers=None, sparse=False
):
    """Builds a model to the specification of parameters using the functional API.

    The functional API is much more flexible than the sequential API, in that models with multiple inputs and outputs
    can be defined. Custom layers and sub-models can be intermixed. In EncoderMap's case the functional API is used to
    build the AngleDihedralCartesianAutoencoder, which takes input data in form of a tf.data.Dataset with:
        * backbone_angles (angles between the C, CA, N atoms in the backbone).
        * backbone_torsions (dihedral angles in the backbone, commonly known as omega, phi, psi).
        * cartesian_coordinates (coordinates of the C, CA, N backbone atoms. This data has ndim 3, the others have ndim 2).
        * backbone_distances (distances between the C, CA, N backbone atoms).
        * sidechain_torsions (dihedral angles in the sidechain, commonly known as chi1, chi2, chi3, chi4, chi5).
    Packing and unpacking that data in the correct manner is important. Make sure to double-check whether you are using
    angles or dihedrals. A simple print of the shape can be enough.

    In the functional model all operations are tf.keras.layers, meaning that the projection onto a unit circle that
    the `SequentialModel` does in its `call()` method needs to be a layer. The FunctionalModel consists of 5 main parts:
        * Angle Inputs: The provided dataset is unpacked and the periodic data of the angles is projected onto
          a unit circle. If the angles are in gradians, they will also be normalized into a [-pi, pi) interval.
        * Autoencoder: The trainable part of the network consists of the autoencoder built to the specifications
          in the provided parameters. Here, Dense layers are stacked. Only the angles and torsions are fed into the
          autoencoder. The distances and cartesians are used later.
        * Angle Outputs: The angles are recalculated from their unit-circle inputs.
        * Back-Mapping: The backmapping layer takes backbone_angles, backbone_dihedrals, and backbone_distances to
          calculate new cartesian coordinates.
        * Pairwise Distances: The pairwise distances of the input cartesians and the back-mapped cartesians are calculated.

    Args:
        input_dataset (tf.data.Dataset): The dataset with the data in the order given in the explanation.
        parameters (Union[em.ADCParameters, None], optional): The parameters to be used to build the network.
            If None is provided, the default parameters in encodermap.ADCParameters.defaults
            are used. You can look at the defaults with print(em.ADCParameters.defaults_description()). Defaults to None.
        reload_layers (Union[None, list], optional): List of layers that will be reloaded when reloading the model from
            disk. Defaults to None, when a new model should be built.
        sparse (bool, optional): Whether sparse inputs are expected. Defaults to False.

    Raises:
        AssertionError: AssertionErrors will be raised when the input data is not formatted correctly.
            This means, if len(distances) != len(cartesians) - 1, or len(angles) != len(cartesians) - 2.
            This can also mean the input dataset is not packed correctly. Please keep the order specified above.
            This can also mean that the provided protein is not linear (branched, circular, ...).

    Returns:
        em.FunctionalModel: A subclass of tf.keras.Model built with the specified parameters.
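
    Example:
        A minimal, hedged sketch of how the dataset is expected to be packed
        (random data for a hypothetical linear protein with 50 backbone atoms;
        the shapes follow the assertions in the code below):

        >>> import numpy as np
        >>> import tensorflow as tf
        >>> n = 50
        >>> dataset = tf.data.Dataset.from_tensor_slices((
        ...     np.random.random((256, n - 2)).astype("float32"),  # angles
        ...     np.random.random((256, n - 3)).astype("float32"),  # dihedrals
        ...     np.random.random((256, n, 3)).astype("float32"),   # cartesians
        ...     np.random.random((256, n - 1)).astype("float32"),  # distances
        ...     np.random.random((256, 100)).astype("float32"),    # side dihedrals
        ... )).batch(32)
        >>> model = gen_functional_model(dataset)  # doctest: +SKIP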

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if isinstance(p, Parameters):
        raise Exception(
            "Functional Model is currently reserved for the ADCAutoencoder, "
            "because of the way the data is packed and unpacked."
        )

    # unpack the shapes of the input
    for i, d in enumerate(input_dataset):
        angles, dihedrals, cartesians, distances, side_dihedrals = d
        break

    # These assertions need to be changed for all proteins that are not linear.
    if isinstance(cartesians, tf.sparse.SparseTensor):
        assert distances.shape[1] == cartesians.shape[1] // 3 - 1
    else:
        assert (
            distances.shape[1] == cartesians.shape[1] - 1
        ), f"{distances.shape} {cartesians.shape}"
    assert (
        angles.shape[1] == distances.shape[1] - 1
    ), f"{angles.shape} {cartesians.shape}"
    assert (
        dihedrals.shape[1] == distances.shape[1] - 2
    ), f"{dihedrals.shape} {cartesians.shape}"

    if reload_layers is not None:
        raise Exception("currently not reloadable.")

    # define the regularizer
    regularizer = tf.keras.regularizers.l2(p.l2_reg_constant)

    # central dihedrals
    if not sparse or not isinstance(dihedrals, tf.sparse.SparseTensor):
        inp_dihedrals = Input(
            shape=(dihedrals.shape[1],),
            name="input_dihedrals",
        )
        x = PeriodicInput(p, "dihedrals")(inp_dihedrals)
        get_dense_model_central_dihedrals = None
    else:
        inp_dihedrals = Input(
            shape=(dihedrals.shape[1],),
            name="input_dihedrals",
            sparse=True,
        )
        x = Dense(dihedrals.shape[1])(inp_dihedrals)
        get_dense_model_central_dihedrals = tf.keras.Model(
            inputs=inp_dihedrals,
            outputs=x,
        )
        x = PeriodicInput(p, "dihedrals")(x)

    # backbone angles
    get_dense_model_central_angles = None
    if p.use_backbone_angles:
        if not sparse or not isinstance(angles, tf.sparse.SparseTensor):
            inp_angles = Input(
                shape=(angles.shape[1],),
                name="input_angles",
            )
            y = PeriodicInput(p, "angles")(inp_angles)
        else:
            inp_angles = Input(
                shape=(angles.shape[1],),
                name="input_angles",
                sparse=True,
            )
            y = Dense(angles.shape[1])(inp_angles)
            get_dense_model_central_angles = tf.keras.Model(
                inputs=inp_angles,
                outputs=y,
            )
            y = PeriodicInput(p, "angles")(y)

    # sidechains
    get_dense_model_side_dihedrals = None
    if p.use_sidechains:
        if not sparse or not isinstance(side_dihedrals, tf.sparse.SparseTensor):
            inp_side_dihedrals = Input(
                shape=(side_dihedrals.shape[1],), name="input_side_dihedrals"
            )
            z = PeriodicInput(p, "side_dihedrals")(inp_side_dihedrals)
        else:
            inp_side_dihedrals = Input(
                shape=(side_dihedrals.shape[1],),
                name="input_side_dihedrals",
                sparse=True,
            )
            z = Dense(side_dihedrals.shape[1])(inp_side_dihedrals)
            # z = SparseModel(name="Sparse_Model", input_shape=side_dihedrals.shape[1])(inp_side_dihedrals)
            get_dense_model_side_dihedrals = tf.keras.Model(
                inputs=inp_side_dihedrals,
                outputs=z,
            )
            z = PeriodicInput(p, "side_dihedrals")(z)

    # these inputs will be passed through and will be used for backmapping
    # and RMSD metrics
    if not sparse or not isinstance(cartesians, tf.sparse.SparseTensor):
        inp_cartesians = Input(
            shape=(
                cartesians.shape[1],
                3,
            ),
            name="input_cartesians",
        )
        inp_distances = Input(
            shape=(distances.shape[1],),
            name="input_distances",
        )
        get_dense_model_cartesians = None
        get_dense_model_distances = None
    else:
        inp_cartesians = Input(
            shape=(cartesians.shape[1],),
            name="input_cartesians",
            sparse=True,
        )
        dc = Dense(cartesians.shape[1])(inp_cartesians)
        get_dense_model_cartesians = tf.keras.Model(
            inputs=inp_cartesians,
            outputs=dc,
        )
        dc = tf.keras.layers.Reshape(
            target_shape=(
                cartesians.shape[1] // 3,
                3,
            ),
            input_shape=(cartesians.shape[1],),
        )(dc)
        inp_distances = Input(
            shape=(distances.shape[1],),
            name="input_distances",
            sparse=True,
        )
        dd = Dense(distances.shape[1])(inp_distances)
        get_dense_model_distances = tf.keras.Model(
            inputs=inp_distances,
            outputs=dd,
        )

    # stack the three data sources going through the network
    if not p.use_backbone_angles and not p.use_sidechains:
        splits = None
    elif p.use_backbone_angles and not p.use_sidechains:
        splits = [x.shape[1], y.shape[1]]
        x = Concatenate(axis=1, name="Main_Inputs")([x, y])
    elif p.use_backbone_angles and p.use_sidechains:
        # ToDo: make sure the splits work here. There seem to be different outputs from tf and np.
        splits = [x.shape[1], y.shape[1], z.shape[1]]
        x = Concatenate(axis=1, name="Main_Inputs")([x, y, z])

    # save the out_shape now and use it for an output layer
    out_shape = x.shape[1]

    # rename empty strings in parameters to None
    activation_functions = list(
        map(lambda x: x if x != "" else None, p.activation_functions)
    )

    # define how layers are stacked
    layer_data = list(
        zip(
            p.n_neurons + p.n_neurons[-2::-1],
            activation_functions[1:] + activation_functions[-1::-1],
        )
    )
    # add a layer that reshapes the output
    layer_data.append([out_shape, "tanh"])

    # decide layer names
    names = []
    for i, (n_neurons, act_fun) in enumerate(layer_data):
        if i < len(p.n_neurons) - 1:
            name = f"Encoder_{i}"
        elif i > len(p.n_neurons) - 1:
            ind = i - len(p.n_neurons)
            name = f"Decoder_{ind}"
        else:
            name = "Latent"
        names.append(name)
    layer_data = list((*i, j) for i, j in zip(layer_data, names))
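
    # A worked illustration (hedged, assuming p.n_neurons = [128, 128, 2],
    # p.activation_functions = ["", "tanh", "tanh", ""] and out_shape = 100):
    # the zip above truncates to the length of the neuron list, so layer_data
    # now reads
    #   [(128, "tanh", "Encoder_0"), (128, "tanh", "Encoder_1"),
    #    (2, None, "Latent"), (128, None, "Decoder_0"),
    #    (128, "tanh", "Decoder_1"), (100, "tanh", "Decoder_2")]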

    # unpack layer data into encoder and decoder
    neurons = [i[0] for i in layer_data]
    bottleneck_index = neurons.index(min(neurons)) + 1
    encoder_layers = layer_data[:bottleneck_index]
    decoder_layers = layer_data[bottleneck_index:]

    # encoder layers
    for n_neurons, act_fun, name in encoder_layers:
        layer = Dense(
            units=n_neurons,
            activation=act_fun,
            name=name,
            kernel_initializer=tf.initializers.VarianceScaling(),
            kernel_regularizer=regularizer,
            bias_initializer=tf.initializers.RandomNormal(0.1, 0.5),
        )
        x = layer(x)

    # encoder model
    if p.use_backbone_angles and p.use_sidechains:
        encoder = tf.keras.Model(
            inputs=[inp_angles, inp_dihedrals, inp_side_dihedrals],
            outputs=[x],
            name="Encoder",
        )
    elif p.use_backbone_angles and not p.use_sidechains:
        encoder = tf.keras.Model(
            inputs=[inp_angles, inp_dihedrals],
            outputs=[x],
            name="Encoder",
        )
    else:
        encoder = tf.keras.Model(
            inputs=[inp_dihedrals, inp_side_dihedrals],
            outputs=[x],
            name="Encoder",
        )
    if p.tensorboard:
        with Capturing() as output:
            encoder.summary()
        with open(p.main_path + "/encoder_summary.txt", "w") as f:
            f.write("\n".join(output))

    # decoder input
    decoder_input = Input(shape=(encoder_layers[-1][0],), name="Decoder_Input")
    x = decoder_input

    # decoder layers
    for i, (n_neurons, act_fun, name) in enumerate(decoder_layers):
        layer = Dense(
            units=n_neurons,
            activation=act_fun,
            name=name,
            kernel_initializer=tf.initializers.VarianceScaling(),
            kernel_regularizer=regularizer,
            bias_initializer=tf.initializers.RandomNormal(0.1, 0.5),
        )
        x = layer(x)

    # split the output according to the concatenation order above
    if splits is None:
        out_dihedrals = x
        out_angles = MeanAngles(p, "Mean_Angles", out_dihedrals.shape[0])(inp_angles)
        decoder_input = [decoder_input, inp_angles]
        decoder_output = [out_angles, out_dihedrals]
    elif len(splits) == 2:
        out_dihedrals, out_angles = Lambda(
            lambda x: tf.split(x, splits, 1), name="Split_Output"
        )(x)
        out_angles = PeriodicOutput(p, "Angles")(out_angles)
        out_dihedrals = PeriodicOutput(p, "Dihedrals")(out_dihedrals)
        decoder_input = [decoder_input]
        decoder_output = [out_angles, out_dihedrals]
    elif len(splits) == 3:
        out_dihedrals, out_angles, out_side_dihedrals = Lambda(
            lambda x: tf.split(x, splits, 1), name="Split_Output"
        )(x)
        out_angles = PeriodicOutput(p, "Angles")(out_angles)
        out_dihedrals = PeriodicOutput(p, "Dihedrals")(out_dihedrals)
        out_side_dihedrals = PeriodicOutput(p, "Side_Dihedrals")(out_side_dihedrals)
        decoder_input = [decoder_input]
        decoder_output = [out_angles, out_dihedrals, out_side_dihedrals]

    # decoder model before backmapping
    decoder = tf.keras.Model(
        inputs=decoder_input, name="Decoder", outputs=decoder_output
    )
    if p.tensorboard:
        with Capturing() as output:
            decoder.summary()
        with open(p.main_path + "/decoder_summary.txt", "w") as f:
            f.write("\n".join(output))

    # backmap input
    back_inp_angles = Input(shape=(out_angles.shape[1],), name="Back_Angles_Input")
    back_inp_dihedrals = Input(
        shape=(out_dihedrals.shape[1],), name="Back_Dihedrals_Input"
    )
    if p.use_sidechains:
        pass_sidedihedrals = Input(
            shape=(out_side_dihedrals.shape[1],), name="Side_Dihedrals_Pass_Through"
        )

    # backmapping. The hardest part
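    # (conceptually, and hedged: the bond distances and angles first lay the
    # chain out in a plane, then the dihedrals rotate it into 3D; compare
    # chain_in_plane and dihedrals_to_cartesian_tf from
    # encodermap_tf1.backmapping, which ADCSequentialModel.call_and_map_back
    # below uses directly)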
    if not sparse or not isinstance(distances, tf.sparse.SparseTensor):
        back_cartesians = BackMapLayer()(
            (inp_distances, back_inp_angles, back_inp_dihedrals)
        )
    else:
        back_cartesians = BackMapLayer()((dd, back_inp_angles, back_inp_dihedrals))

    # pairwise distances are the last part
    if not sparse or not isinstance(cartesians, tf.sparse.SparseTensor):
        inp_pair = PairwiseDistances(p, "Input")(inp_cartesians)
    else:
        inp_pair = PairwiseDistances(p, "Input")(dc)
    out_pair = PairwiseDistances(p, "Backmapped")(back_cartesians)

    # backmap_model
    if p.use_sidechains:
        backmap_model = tf.keras.Model(
            name="Backmapping",
            inputs=[
                back_inp_angles,
                back_inp_dihedrals,
                pass_sidedihedrals,
                inp_distances,
                inp_cartesians,
            ],
            outputs=[
                back_inp_angles,
                back_inp_dihedrals,
                back_cartesians,
                inp_pair,
                out_pair,
                pass_sidedihedrals,
            ],
        )
    else:
        backmap_model = tf.keras.Model(
            name="Backmapping",
            inputs=[
                back_inp_angles,
                back_inp_dihedrals,
                inp_distances,
                inp_cartesians,
            ],
            outputs=[
                back_inp_angles,
                back_inp_dihedrals,
                back_cartesians,
                inp_pair,
                out_pair,
            ],
        )
    if p.tensorboard:
        with Capturing() as output:
            backmap_model.summary()
        with open(p.main_path + "/backmap_summary.txt", "w") as f:
            f.write("\n".join(output))

    # call all the models hierarchically to rebuild a complete model
    if p.use_sidechains:
        main_inputs = [
            inp_angles,
            inp_dihedrals,
            inp_cartesians,
            inp_distances,
            inp_side_dihedrals,
        ]
        main_outputs = backmap_model(
            (
                *decoder(encoder((inp_angles, inp_dihedrals, inp_side_dihedrals))),
                inp_distances,
                inp_cartesians,
            )
        )
    else:
        main_inputs = [
            inp_angles,
            inp_dihedrals,
            inp_cartesians,
            inp_distances,
        ]
        main_outputs = backmap_model(
            (
                *decoder(encoder((inp_angles, inp_dihedrals))),
                inp_distances,
                inp_cartesians,
            )
        )

    # full_model = tf.keras.Model(inputs=[inp_angles, inp_dihedrals, inp_cartesians, inp_distances, inp_side_dihedrals],
    #                             outputs=main_outputs, name="Full_Model")

    # pass inputs and outputs to FunctionalModel
    # In FunctionalModel, train_step is overwritten. train_step should unpack the
    # data and assign the inputs/outputs to the different loss functions.
    if not sparse:
        model = FunctionalModel(
            parameters=p,
            inputs=main_inputs,
            outputs=main_outputs,
            encoder=encoder,
            decoder=decoder,
        )
    else:
        model = SparseFunctionalModel(
            parameters=p,
            inputs=main_inputs,
            outputs=main_outputs,
            encoder=encoder,
            decoder=decoder,
            get_dense_model_central_dihedrals=get_dense_model_central_dihedrals,
            get_dense_model_central_angles=get_dense_model_central_angles,
            get_dense_model_side_dihedrals=get_dense_model_side_dihedrals,
            get_dense_model_cartesians=get_dense_model_cartesians,
            get_dense_model_distances=get_dense_model_distances,
        )

    return model


##############################################################################
# Public Classes
##############################################################################


class FunctionalModel(tf.keras.Model):
    def __init__(self, parameters, inputs, outputs, encoder, decoder):
        super().__init__(inputs=inputs, outputs=outputs)
        self.p = parameters
        self.encoder_model = encoder
        self.decoder_model = decoder

        # train counter
        self._train_counter = K.variable(0, "int64", name="train_counter")

    def encoder(self, x, training=False):
        return self.encoder_model(x, training=training)

    def decoder(self, x, training=False):
        return self.decoder_model(x, training=training)

    def compile(self, *args, **kwargs):
        super().compile(*args, **kwargs)
        self.unpacked_loss_fns = {fn.__name__: fn for fn in self.compiled_loss._losses}
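
    # A hedged note on compile(): get_loss() below looks the compiled losses up
    # by their __name__, so compile() is expected to receive functions named
    # dihedral_loss_func, angle_loss_func, cartesian_loss_func,
    # distance_loss_func, cartesian_distance_loss_func, center_loss_func,
    # regularization_loss_func and, with sidechains, side_dihedral_loss_func,
    # e.g. (illustrative only):
    #
    #     model.compile(optimizer=tf.keras.optimizers.Adam(), loss=[...])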

    def get_loss(self, inp):
        # unpack the inputs
        if self.p.use_sidechains:
            (
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                inp_distances,
                inp_side_dihedrals,
            ) = inp
        else:
            (
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                inp_distances,
            ) = inp
        # unpack the outputs
        out = self(inp, training=True)
        if self.p.use_sidechains:
            (
                out_angles,
                out_dihedrals,
                back_cartesians,
                inp_pair,
                out_pair,
                out_side_dihedrals,
            ) = out
        else:
            (
                out_angles,
                out_dihedrals,
                back_cartesians,
                inp_pair,
                out_pair,
            ) = out

        # define latent for the cartesian_distance_loss; the same tuple is also
        # passed to the distance and center losses below
        if self.p.use_sidechains:
            main_inputs = (inp_angles, inp_dihedrals, inp_side_dihedrals)
        else:
            main_inputs = (inp_angles, inp_dihedrals)
        latent = self.encoder_model(main_inputs, training=True)

        with tf.name_scope("Cost"):
            loss = 0.0
            # dihedral loss
            loss += self.unpacked_loss_fns["dihedral_loss_func"](
                inp_dihedrals, out_dihedrals
            )
            # angle loss
            loss += self.unpacked_loss_fns["angle_loss_func"](inp_angles, out_angles)
            # cartesian loss
            loss += self.unpacked_loss_fns["cartesian_loss_func"](inp_pair, out_pair)
            # distance loss
            loss += self.unpacked_loss_fns["distance_loss_func"](main_inputs)
            # cartesian distance cost
            loss += self.unpacked_loss_fns["cartesian_distance_loss_func"](
                inp_pair, latent
            )
            # center loss
            loss += self.unpacked_loss_fns["center_loss_func"](main_inputs)
            # reg loss
            loss += self.unpacked_loss_fns["regularization_loss_func"]()
            # side dihedral loss
            if self.p.use_sidechains:
                loss += self.unpacked_loss_fns["side_dihedral_loss_func"](
                    inp_side_dihedrals, out_side_dihedrals
                )
            tf.summary.scalar("Combined Cost", loss)

        # autoloss (computed for reference; not added to the combined loss)
        from encodermap.loss_functions.loss_functions import periodic_distance

        angle_auto_loss = 2 * tf.reduce_mean(
            tf.square(periodic_distance(inp_angles, out_angles, 2 * np.pi))
        )
        dihedral_auto_loss = 2 * tf.reduce_mean(
            tf.square(periodic_distance(inp_dihedrals, out_dihedrals, 2 * np.pi))
        )
        if self.p.use_sidechains:
            side_dihedral_auto_loss = 2 * tf.reduce_mean(
                tf.square(
                    periodic_distance(
                        inp_side_dihedrals, out_side_dihedrals, 2 * np.pi
                    )
                )
            )
        return loss

    def train_step(self, data):
        # Data will always contain all inputs; depending on p, the model will
        # return different tuples, but the input stays the same. Packing and
        # unpacking only matter when the encoder and the decoder are called
        # with the train data.
        # gradient tape to calculate the loss for backmapping
        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._train_counter)
            loss = self.get_loss(data)
            loggable_encoder_layers = [
                l for l in self.encoder_model.layers if l.__class__.__name__ == "Dense"
            ]
            loggable_decoder_layers = [
                l for l in self.decoder_model.layers if l.__class__.__name__ == "Dense"
            ]
            for l in loggable_encoder_layers + loggable_decoder_layers:
                add_layer_summaries(l, step=self._train_counter)

        # optimization happens here
        # Compute gradients
        # trainable_vars = self.trainable_variables
        trainable_vars = (
            self.encoder_model.trainable_variables
            + self.decoder_model.trainable_variables
        )
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        # self.compiled_metrics.update_state(data, self(data))
        # Return a dict mapping metric names to current values.
        # Add loss to the dict so the ProgressBar callback can pick it up.
        # return {**{m.name: m.result() for m in self.metrics}, **{'loss': loss}}
        # update the train counter because tensorflow seems to have deprecated it
        self._train_counter.assign_add(1)
        return {"loss": loss}


class SparseFunctionalModel(FunctionalModel):
    def __init__(
        self,
        parameters,
        inputs,
        outputs,
        encoder,
        decoder,
        get_dense_model_central_dihedrals,
        get_dense_model_central_angles,
        get_dense_model_side_dihedrals,
        get_dense_model_cartesians,
        get_dense_model_distances,
    ):
        super().__init__(parameters, inputs, outputs, encoder, decoder)
        self.get_dense_model_central_dihedrals = get_dense_model_central_dihedrals
        self.get_dense_model_central_angles = get_dense_model_central_angles
        self.get_dense_model_side_dihedrals = get_dense_model_side_dihedrals
        self.get_dense_model_cartesians = get_dense_model_cartesians
        self.get_dense_model_distances = get_dense_model_distances

    def get_loss(self, inp):
        # unpack the inputs
        (
            sparse_inp_angles,
            sparse_inp_dihedrals,
            sparse_inp_cartesians,
            sparse_inp_distances,
            sparse_side_dihedrals,
        ) = inp

        # make the sparse inputs dense
        if isinstance(sparse_inp_angles, tf.sparse.SparseTensor):
            inp_angles = self.get_dense_model_central_angles(
                sparse_inp_angles, training=True
            )
        else:
            inp_angles = sparse_inp_angles
        if isinstance(sparse_inp_dihedrals, tf.sparse.SparseTensor):
            inp_dihedrals = self.get_dense_model_central_dihedrals(
                sparse_inp_dihedrals, training=True
            )
        else:
            inp_dihedrals = sparse_inp_dihedrals
        if isinstance(sparse_inp_cartesians, tf.sparse.SparseTensor):
            inp_cartesians = self.get_dense_model_cartesians(
                sparse_inp_cartesians, training=True
            )
        else:
            inp_cartesians = sparse_inp_cartesians
        if isinstance(sparse_inp_distances, tf.sparse.SparseTensor):
            inp_distances = self.get_dense_model_distances(
                sparse_inp_distances, training=True
            )
        else:
            inp_distances = sparse_inp_distances

        if self.p.use_sidechains:
            if isinstance(sparse_side_dihedrals, tf.sparse.SparseTensor):
                inp_side_dihedrals = self.get_dense_model_side_dihedrals(
                    sparse_side_dihedrals, training=True
                )
            else:
                inp_side_dihedrals = sparse_side_dihedrals

        # make them into an iterable again
        if self.p.use_sidechains:
            data = (
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                inp_distances,
                inp_side_dihedrals,
            )
        else:
            data = (
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                inp_distances,
            )

        # call the loss
        return super().get_loss(data)


class SequentialModel(tf.keras.Model):
    def __init__(
        self,
        input_dim,
        parameters=None,
        reload_layers=None,
        sparse=False,
        get_dense_model=None,
    ):
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        super().__init__()
        self.sparse = sparse
        self.input_dim = input_dim
        self.get_dense_model = get_dense_model

        # tensors for using tf.cond inside self.train_step()
        self.update_step = tf.constant(self.p.summary_step, dtype="int64")
        self.debug_tensor = tf.constant(self.p.tensorboard, dtype="bool")

        # periodicity doubles the inputs and outputs
        if self.p.periodicity < float("inf"):
            self.input_dim *= 2

        # if reload_layers is provided, we are rebuilding a trained model
        if reload_layers is not None:
            if self.p.periodicity < float("inf"):
                print(
                    f"rebuilding Model with input_dim = {int(self.input_dim/2)} and periodicity = {self.p.periodicity}"
                )
            else:
                print(f"rebuilding Model with input_dim = {self.input_dim}")
            if len(reload_layers) != 2:
                raise Exception(
                    "currently only works with 2 layers. Encoder and Decoder."
                )
            assert all([isinstance(i, tf.keras.Sequential) for i in reload_layers])
            # input_layer = tf.keras.layers.InputLayer(input_shape=(self.input_dim,), dtype='float32')
            self.encoder_model = reload_layers[0]
            self.decoder_model = reload_layers[1]
            self.build(input_shape=(1, self.input_dim))
            return

        # define the regularizer
        regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)

        # rename empty strings in parameters to None
        activation_functions = list(
            map(lambda x: x if x != "" else None, self.p.activation_functions)
        )

        # define how layers are stacked
        layer_data = list(
            zip(
                self.p.n_neurons + self.p.n_neurons[-2::-1],
                activation_functions[1:] + activation_functions[-1::-1],
            )
        )
        # add a layer that reshapes the output
        layer_data.append([self.input_dim, "tanh"])

        # decide layer names
        names = []
        for i, (n_neurons, act_fun) in enumerate(layer_data):
            if i < len(self.p.n_neurons) - 1:
                name = f"Encoder_{i}"
            elif i > len(self.p.n_neurons) - 1:
                ind = i - len(self.p.n_neurons)
                name = f"Decoder_{ind}"
            else:
                name = "Latent"
            names.append(name)
        layer_data = list((*i, j) for i, j in zip(layer_data, names))

        # define encoder and decoder layers
        neurons = [i[0] for i in layer_data]
        bottleneck_index = neurons.index(min(neurons)) + 1
        self.encoder_layers = layer_data[:bottleneck_index]
        self.decoder_layers = layer_data[bottleneck_index:]

        # input
        # Instead of using InputLayer, use Dense with the input_shape kwarg;
        # this allows the model to be reloaded more reliably.
        if self.sparse:
            shape = self.input_dim
            if self.p.periodicity < float("inf"):
                shape /= 2
            _input_layer = Input(
                shape=(int(shape),),
                sparse=True,
            )
            x = Dense(int(shape))(_input_layer)
            self.get_dense_model = tf.keras.Model(
                inputs=_input_layer,
                outputs=x,
            )

        input_layer = tf.keras.layers.Dense(
            input_shape=(self.input_dim,),
            units=self.encoder_layers[0][0],
            activation=self.encoder_layers[0][1],
            name=self.encoder_layers[0][2],
            kernel_initializer=tf.initializers.VarianceScaling(),
            kernel_regularizer=regularizer,
            bias_initializer=tf.initializers.RandomNormal(0.1, 0.5),
        )

        # what model to use for the encoder

        # output
        # output_layer = tf.keras.layers.Dense(
        #     self.input_dim,
        #     name="Output",
        #     activation=None,
        #     kernel_initializer=tf.initializers.VarianceScaling(),
        #     kernel_regularizer=regularizer,
        #     bias_initializer=tf.initializers.RandomNormal(0.1, 0.05),
        #     trainable=True)

        # encoder
        self.encoder_model = tf.keras.Sequential(
            [input_layer]
            + [
                tf.keras.layers.Dense(
                    n_neurons,
                    activation=act_fun,
                    name=name,
                    kernel_initializer=tf.initializers.VarianceScaling(),
                    kernel_regularizer=regularizer,
                    bias_initializer=tf.initializers.RandomNormal(0.1, 0.5),
                )
                for n_neurons, act_fun, name in self.encoder_layers[1:]
            ],
            name="Encoder",
        )

        # decoder
        self.decoder_model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(
                    n_neurons,
                    activation=act_fun,
                    name=name,
                    kernel_initializer=tf.initializers.VarianceScaling(),
                    kernel_regularizer=regularizer,
                    bias_initializer=tf.initializers.RandomNormal(0.1, 0.5),
                )
                for n_neurons, act_fun, name in self.decoder_layers
            ],
            name="Decoder",
        )

        # build
        self.build(input_shape=(1, self.input_dim))

        # train counter
        self._train_counter = K.variable(0, "int64", name="train_counter")

    def build(self, *args, **kwargs):
        input_shape = kwargs["input_shape"]
        # Because build calls self.call and self.call calls self.encoder,
        # the input dim needs to be halved here.
        if self.p.periodicity < float("inf"):
            input_shape = (*input_shape[:-1], int(input_shape[-1] / 2))
        try:
            super().build(*args, **{**kwargs, **dict(input_shape=input_shape)})
        except Exception:
            if self.p.periodicity < float("inf"):
                print(
                    "Exception is raised because of periodicity. In general "
                    "you don't need to call the build method from outside of "
                    "this class, because it is called at the end of __init__."
                )
            raise

    def compile(self, *args, **kwargs):
        super().compile(*args, **kwargs)
        try:
            self.unpacked_loss_fns = {
                fn.__name__: fn for fn in self.compiled_loss._losses
            }
        except AttributeError:
            for i in dir(self):
                if "loss" in i:
                    print(i)
            print(self._callable_losses)
            raise

    def encoder(self, x, training=False):
        if self.sparse:
            x = self.get_dense_model(x)
        if self.p.periodicity < float("inf"):
            if self.p.periodicity != 2 * pi:
                x = x / self.p.periodicity * 2 * pi
            x = tf.concat([tf.sin(x), tf.cos(x)], 1)
        return self.encoder_model(x, training=training)

    def decoder(self, x, training=False):
        x = self.decoder_model(x, training=training)
        if self.p.periodicity < float("inf"):
            x = tf.atan2(*tf.split(x, 2, 1))
            if self.p.periodicity != 2 * pi:
                x = x / (2 * pi) * self.p.periodicity
        return x
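
    # A worked illustration of the periodic round trip above (assuming
    # p.periodicity == 2 * pi): encoder() maps every input angle a to the pair
    # (sin(a), cos(a)), doubling the feature dimension, and decoder() undoes
    # this with atan2(sin(a), cos(a)), whose result lies in (-pi, pi]. For any
    # other finite periodicity, the values are rescaled by 2 * pi / periodicity
    # before the projection and scaled back after the atan2.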

    def call(self, x, training=False):
        # encode and decode
        encoded = self.encoder(x, training=training)
        decoded = self.decoder(encoded, training=training)
        return decoded

    def train_step(self, data):
        """Overwrites the normal train_step. What is different?

        Not much. The provided data is still expected to be a tuple (x, y), as
        in classification tasks. The data is unpacked and y is discarded,
        because the autoencoder model is a regression task.

        Args:
            data (tuple): The (x, y) data of this train step.

        """
        x, _ = data
        if self.sparse:
            x = self.get_dense_model(x)

        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._train_counter)
            with tf.name_scope("Cost"):
                loss = 0.0
                for l in self.compiled_loss._losses:
                    loss += l(x, self(x, training=True))
                tf.summary.scalar("Combined Cost", loss)
            for l in self.encoder_model.layers + self.decoder_model.layers:
                add_layer_summaries(l, step=self._train_counter)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        self.compiled_metrics.update_state(x, self(x))
        # update the train counter because tensorflow seems to have deprecated it
        self._train_counter.assign_add(1)
        # Return a dict mapping metric names to current values
        return {**{m.name: m.result() for m in self.metrics}, **{"loss": loss}}
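

# A minimal, hedged usage sketch for SequentialModel (assuming 100-dimensional
# float32 data and loss functions with the (inputs, outputs) signature that
# train_step above expects; illustrative only):
#
#     model = SequentialModel(input_dim=100)
#     model.compile(optimizer=tf.keras.optimizers.Adam(), loss=[...])
#     model.fit(dataset, epochs=10)  # dataset yields (x, y) tuples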


class ADCSequentialModel(SequentialModel):
    def __init__(self, input_dim, parameters=None, reload_layers=None, sparse=False):
        warnings.warn("check split")
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        self.multiples = tf.TensorShape((self.p.batch_size, 1))
        super(ADCSequentialModel, self).__init__(input_dim, self.p, reload_layers)

    def call(self, x, training=False):
        # encode and decode
        if isinstance(x, tf.Tensor):
            pass
        elif isinstance(x, tuple):
            (
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                inp_distances,
                inp_side_dihedrals,
            ) = x
            if not self.p.use_backbone_angles and not self.p.use_sidechains:
                x = inp_dihedrals
            elif self.p.use_backbone_angles and not self.p.use_sidechains:
                x = tf.concat([inp_angles, inp_dihedrals], 1)
            elif self.p.use_backbone_angles and self.p.use_sidechains:
                x = tf.concat([inp_angles, inp_dihedrals, inp_side_dihedrals], 1)
        encoded = self.encoder(x, training=training)
        decoded = self.decoder(encoded, training=training)
        return decoded

    def call_and_map_back(
        self, x, distances, angles, dihedrals, cartesians, splits, side_dihedrals=None
    ):
        # latent = self.encoder(x, training=False)
        out = self(x, training=True)
        latent = self.encoder(x, training=True)

        # unpack out
        if splits is None:
            out_dihedrals = out
            out_angles = tf.tile(
                tf.expand_dims(tf.reduce_mean(angles, 0), 0), multiples=self.multiples
            )
        elif len(splits) == 2:
            out_angles, out_dihedrals = tf.split(out, splits, 1)
        elif len(splits) == 3:
            out_angles, out_dihedrals, out_side_dihedrals = tf.split(out, splits, 1)

        # do the back-mapping
        back_mean_lengths = tf.expand_dims(tf.reduce_mean(distances, 0), 0)
        back_chain_in_plane = chain_in_plane(back_mean_lengths, out_angles)
        back_cartesians = dihedrals_to_cartesian_tf(
            out_dihedrals + pi, back_chain_in_plane
        )

        # get pairwise distances of CA atoms
        inp_pair = pairwise_dist(
            cartesians[
                :,
                self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
            ],
            flat=True,
        )
        out_pair = pairwise_dist(
            back_cartesians[
                :,
                self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
            ],
            flat=True,
        )

        with tf.name_scope("Cost"):
            loss = 0.0
            # dihedral loss
            loss += self.unpacked_loss_fns["dihedral_loss_func"](
                dihedrals, out_dihedrals
            )
            # angle loss
            loss += self.unpacked_loss_fns["angle_loss_func"](angles, out_angles)
            # cartesian loss
            loss += self.unpacked_loss_fns["cartesian_loss_func"](inp_pair, out_pair)
            # distance loss
            loss += self.unpacked_loss_fns["distance_loss_func"](x)
            # cartesian distance cost
            loss += self.unpacked_loss_fns["cartesian_distance_loss_func"](
                inp_pair, latent
            )
            # center loss
            loss += self.unpacked_loss_fns["center_loss_func"](x)
            # reg loss
            loss += self.unpacked_loss_fns["regularization_loss_func"]()
            # side dihedral loss
            if self.p.use_sidechains:
                loss += self.unpacked_loss_fns["side_dihedral_loss_func"](
                    side_dihedrals, out_side_dihedrals
                )
            tf.summary.scalar("Combined Cost", loss)
        return loss

    def train_step(self, data):
        # unpack the data
        (
            inp_angles,
            inp_dihedrals,
            inp_cartesians,
            inp_distances,
            inp_side_dihedrals,
        ) = data
        if not self.p.use_backbone_angles and not self.p.use_sidechains:
            main_inputs = inp_dihedrals
            splits = None
        elif self.p.use_backbone_angles and not self.p.use_sidechains:
            main_inputs = tf.concat([inp_angles, inp_dihedrals], 1)
            splits = [inp_angles.shape[1], inp_dihedrals.shape[1]]
        elif self.p.use_backbone_angles and self.p.use_sidechains:
            # ToDo: make sure the splits work here. There seem to be different outputs from tf and np.
            main_inputs = tf.concat([inp_angles, inp_dihedrals, inp_side_dihedrals], 1)
            splits = [
                inp_angles.shape[1],
                inp_dihedrals.shape[1],
                inp_side_dihedrals.shape[1],
            ]

        # gradient tape to calculate the loss for backmapping
        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._train_counter)
            if self.p.use_sidechains:
                loss = self.call_and_map_back(
                    main_inputs,
                    inp_distances,
                    inp_angles,
                    inp_dihedrals,
                    inp_cartesians,
                    splits,
                    inp_side_dihedrals,
                )
            else:
                loss = self.call_and_map_back(
                    main_inputs,
                    inp_distances,
                    inp_angles,
                    inp_dihedrals,
                    inp_cartesians,
                    splits,
                )
            for l in self.encoder_model.layers + self.decoder_model.layers:
                add_layer_summaries(l, step=self._train_counter)

        # optimization happens here
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        self.compiled_metrics.update_state(data, self(data))
        # Return a dict mapping metric names to current value.
        # Add loss to the dict so the ProgressBar callback can pick it up.
        return {**{m.name: m.result() for m in self.metrics}, **{"loss": loss}}