Coverage for encodermap/loss_functions/loss_functions.py: 6%
275 statements
coverage.py v7.4.1, created at 2025-05-15 21:06 +0200
# -*- coding: utf-8 -*-
# encodermap/loss_functions/loss_functions.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Loss functions for encodermap

"""
################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional

# Third Party Imports
import tensorflow as tf
import tensorflow.keras.backend as K

# Encodermap imports
from encodermap.encodermap_tf1.misc import distance_cost
from encodermap.misc.distances import (
    pairwise_dist,
    pairwise_dist_periodic,
    periodic_distance,
    sigmoid,
)
from encodermap.parameters.parameters import (
    ADCParameters,
    AnyParameters,
    Parameters,
    ParametersFramework,
)


################################################################################
# Globals
################################################################################


__all__: list[str] = [
    "reconstruction_loss",
    "auto_loss",
    "center_loss",
    "regularization_loss",
    "loss_combinator",
    "distance_loss",
    "cartesian_loss",
    "cartesian_distance_loss",
    "angle_loss",
    "dihedral_loss",
]


################################################################################
# Functions for tf.cond
# Don't know if this is really faster than logging every step to tensorboard
################################################################################


def _do_nothing(*args: Any, **kwargs: Any) -> None:
    """This function does nothing. One of the functions provided to tf.cond."""
    pass


def _summary_cost(name: str, cost: tf.Tensor) -> None:
    """This function logs a scalar under a name. One of the functions provided to tf.cond."""
    tf.summary.scalar(name, cost)


##############################################################################
# Legacy Code to make some tests
##############################################################################


def old_distance_loss(model, parameters=None):  # pragma: no doccheck
    # choose parameters
    if parameters is None:
        p = Parameters()
    else:
        p = parameters
    # check Layers
    if len(model.layers) == 2:
        # sequential API
        latent = model.encoder
    else:
        # functional API
        latent = model.encoder

    # closure
    def loss(y_true, y_pred=None, step=None):
        loss.name = "distance_loss"
        y_pred = latent(y_true, training=True)
        if p.distance_cost_scale is not None:
            dist_cost = distance_cost(
                y_true, y_pred, *p.dist_sig_parameters, p.periodicity
            )
            dist_cost *= p.distance_cost_scale
        else:
            dist_cost = 0.0
        tf.summary.scalar("Distance Cost", dist_cost)
        return dist_cost

    return loss


##############################################################################
# Public Functions
##############################################################################


def basic_loss_combinator(*losses: Callable) -> Callable:
    """Calculates the sum of a list of losses and returns a combined loss.

    The basic loss combinator does not write to summary. Can be used for debugging.

    """

    def loss(y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None) -> tf.Tensor:
        return sum([loss(y_true, y_pred) for loss in losses])

    return loss


def loss_combinator(*losses: Callable) -> Callable:
    """Calculates the sum of a list of losses and returns a combined loss.

    Args:
        *losses: Variable length argument list of loss functions.

    Returns:
        Callable: A combined loss function that can be used in custom training or with model.fit()

    Example:
        >>> import encodermap as em
        >>> from encodermap import loss_functions
        >>> import tensorflow as tf
        >>> import numpy as np
        >>> tf.random.set_seed(1)  # fix random state to pass doctest :)
        ...
        >>> model = tf.keras.Sequential([
        ...     tf.keras.layers.Dense(100, kernel_regularizer=tf.keras.regularizers.l2(), activation='relu'),
        ...     tf.keras.layers.Dense(2, kernel_regularizer=tf.keras.regularizers.l2(), activation='relu'),
        ...     tf.keras.layers.Dense(100, kernel_regularizer=tf.keras.regularizers.l2(), activation='relu')
        ... ])
        ...
        >>> # Set up losses and bundle them using the loss combinator
        >>> auto_loss = loss_functions.auto_loss(model)
        >>> reg_loss = loss_functions.regularization_loss(model)
        >>> loss = loss_functions.loss_combinator(auto_loss, reg_loss)
        ...
        >>> # Compile the model. model.fit() usually takes a tuple of (data, classes), but in
        >>> # regression learning the data needs to be provided twice. That's why we use fit(data, data).
        >>> model.compile(tf.keras.optimizers.Adam(), loss=loss)
        >>> data = np.random.random((100, 100))
        >>> history = model.fit(x=data, y=data, verbose=0)
        >>> tf.random.set_seed(None)  # reset seed
        ...
        >>> # This weird contraption is also there to make the output predictable and pass tests
        >>> # Somehow the tf.random.seed(1) does not work here. :(
        >>> loss = float(history.history['loss'][0])
        >>> print(loss)  # doctest: +SKIP
        2.6
        >>> print(type(loss))
        <class 'float'>

    """

    def combined_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        cost = sum([loss(y_true, y_pred) for loss in losses])
        tf.summary.scalar("Combined Cost", cost)
        return cost

    return combined_loss_func


def distance_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap distance_loss

    Transforms the high- and low-dimensional spaces with the sigmoid function
    first proposed by sketch-map and compares the resulting pairwise distances.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Note:
        If the model contains two layers, the first layer will be assumed to be
        the encoder. If the model contains more layers, one layer needs to be
        named 'latent' (case-insensitive).

    Raises:
        Exception: When no bottleneck/latent layer can be found in the model.

    Returns:
        Callable: A loss function.

    References::

        @article{ceriotti2011simplifying,
            title={Simplifying the representation of complex free-energy landscapes using sketch-map},
            author={Ceriotti, Michele and Tribello, Gareth A and Parrinello, Michele},
            journal={Proceedings of the National Academy of Sciences},
            volume={108},
            number={32},
            pages={13023--13028},
            year={2011},
            publisher={National Acad Sciences}
        }
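
    Example:
        A minimal sketch (not part of the original docstring). A toy subclassed
        model with an `encoder` attribute stands in for a real EncoderMap
        autoencoder; the name `ToyAutoencoder` is purely illustrative:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import distance_loss
        >>> class ToyAutoencoder(tf.keras.Model):  # illustrative stand-in
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.encoder = tf.keras.Sequential([tf.keras.layers.Dense(2)])
        ...         self.decoder = tf.keras.Sequential([tf.keras.layers.Dense(10)])
        ...     def call(self, inputs):
        ...         return self.decoder(self.encoder(inputs))
        >>> loss = distance_loss(ToyAutoencoder())
        >>> data = tf.random.uniform((32, 10))  # stands in for real collective variables
        >>> cost = loss(data)  # sigmoid-weighted distance cost between input and latent
        >>> float(cost) >= 0.0
        True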

    """
    # choose parameters
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    # check Layers
    if len(model.layers) == 2:
        # sequential API
        latent = model.encoder
    else:
        # functional API
        latent = model.encoder

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # define dist loss
    dist_loss = sigmoid_loss(p)

    # closure
    def distance_loss_func(
        y_true: tf.Tensor,
        y_pred: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """y_true can be whatever input you like: dihedrals, angles, pairwise
        distances, contact maps. It will be transformed with sketch-map's sigmoid
        function, as will the output of the latent layer of the autoencoder.
        The difference of these two gives the loss value.

        """
        distance_loss_func.name = "distance_loss"
        y_pred = latent(y_true, training=True)
        # functional model gives a tuple
        if isinstance(y_true, tuple):
            y_true = tf.concat(y_true[:3], axis=1)
        if p.distance_cost_scale is not None:
            dist_cost = dist_loss(y_true, y_pred)
            dist_cost *= p.distance_cost_scale
        else:
            dist_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Distance Cost", dist_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            dist_cost, message="Dist cost became infinite or NaN."
        )
        return dist_cost

    return distance_loss_func


def sigmoid_loss(
    parameters: Optional[AnyParameters] = None,
    periodicity_overwrite: Optional[float] = None,
    dist_dig_parameters_overwrite: Optional[tuple[float, ...]] = None,
) -> Callable:
    """Sigmoid loss closure for use in distance cost and cartesian distance cost.

    The outer function prepares a callable sigmoid loss. The returned function
    can then be called with just `y_true` and `y_pred`.

    Args:
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        periodicity_overwrite (Optional[float]): The cartesian distance cost is
            always non-periodic. To make sure no periodicity is applied to the
            data, set periodicity_overwrite to `float('inf')`. If None is provided,
            the periodicity of the parameters class (default 2*pi) will be used.
            Defaults to None.
        dist_dig_parameters_overwrite (Optional[tuple[float, ...]]): Distance costs
            for the `AngleDihedralCartesianEncoderMap` class come in two flavors.
            The regular distance cost compares the encoder inputs to the latent
            and uses sketch-map's sigmoid function to weigh these data accordingly.
            The cartesian distance cost, on the other hand, compares the latent
            and the pairwise distances of the input CA coordinates. This cost
            function uses different sigmoid parameters (because the CA distances
            don't lie in a periodic space). The tuple of 6 floats provided for
            `dist_dig_parameters_overwrite` will supersede the `dist_sig_parameters`
            in the `parameters` argument. Defaults to None.

    Returns:
        Callable: A function that takes `y_true` and `y_pred`.
            Both need to be of the same shape.
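
    Example:
        A minimal sketch (not part of the original docstring), using random
        tensors in place of real high- and low-dimensional data:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import sigmoid_loss
        >>> loss = sigmoid_loss()  # default Parameters, periodicity 2*pi
        >>> highd = tf.random.uniform((64, 10))  # e.g. dihedrals of 64 frames
        >>> lowd = tf.random.normal((64, 2))     # e.g. latent coordinates
        >>> cost = loss(highd, lowd)
        >>> float(cost) >= 0.0
        True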

    """
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    if periodicity_overwrite is not None:
        periodicity = periodicity_overwrite
    else:
        periodicity = p.periodicity

    if dist_dig_parameters_overwrite is not None:
        dist_sig_parameters = dist_dig_parameters_overwrite
    else:
        dist_sig_parameters = p.dist_sig_parameters

    # @tf.autograph.experimental.do_not_convert
    def sigmoid_loss_func(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        r_h = y_true
        r_l = y_pred
        if periodicity == float("inf"):
            dist_h = pairwise_dist(r_h)
        else:
            dist_h = pairwise_dist_periodic(r_h, periodicity)
        dist_l = pairwise_dist(r_l)

        sig_h = sigmoid(*dist_sig_parameters[:3])(dist_h)
        sig_l = sigmoid(*dist_sig_parameters[3:])(dist_l)

        cost = tf.reduce_mean(tf.square(sig_h - sig_l))
        tf.debugging.assert_all_finite(
            cost, message="Sigmoid cost became infinite or NaN."
        )
        return cost

    return sigmoid_loss_func


def center_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap center_loss

    Use in custom training loops or in model.fit() training.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Note:
        If the model contains two layers, the first layer will be assumed to be
        the encoder. If the model contains more layers, one layer needs to
        be named 'latent' (case-insensitive).

    Raises:
        Exception: When no bottleneck/latent layer can be found in the model.

    Returns:
        Callable: A loss function.
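
    Example:
        A minimal sketch (not part of the original docstring). As in the
        `distance_loss` example, a toy subclassed model with an `encoder`
        attribute (purely illustrative) stands in for a real EncoderMap
        autoencoder:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import center_loss
        >>> class ToyAutoencoder(tf.keras.Model):  # illustrative stand-in
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.encoder = tf.keras.Sequential([tf.keras.layers.Dense(2)])
        ...         self.decoder = tf.keras.Sequential([tf.keras.layers.Dense(10)])
        ...     def call(self, inputs):
        ...         return self.decoder(self.encoder(inputs))
        >>> loss = center_loss(ToyAutoencoder())
        >>> cost = loss(tf.random.uniform((32, 10)))  # scaled mean of squared latent coordinates
        >>> float(cost) >= 0.0
        True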

    """
    # choose parameters
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    # check Layers
    if len(model.layers) == 2:
        # sequential API
        latent = model.encoder
    else:
        # functional API
        latent = model.encoder

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def center_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true will not be used in this loss function. y_pred can be supplied,
        but if None it will be taken from the latent layer. This loss function tries
        to center the points in the latent space.

        """
        center_loss_func.name = "center_loss"
        y_pred = latent(y_true, training=True)
        # functional model gives a tuple
        if isinstance(y_true, tuple):
            y_true = tf.concat(y_true[:3], axis=1)
        if p.center_cost_scale is not None:
            center_cost = tf.reduce_mean(tf.square(y_pred))
            center_cost *= p.center_cost_scale
        else:
            center_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Center Cost", center_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            center_cost, message="Center cost became infinite or NaN."
        )
        return center_cost

    return center_loss_func


def regularization_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Regularization loss of an arbitrary tf.keras.Model.

    Use in custom training loops or in model.fit() training.
    The loss is obtained as tf.math.add_n(model.losses).

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
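
    Example:
        A minimal sketch (not part of the original docstring). The Dense layer
        carries an l2 kernel regularizer, so `model.losses` is non-empty once
        the model has been called:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import regularization_loss
        >>> model = tf.keras.Sequential(
        ...     [tf.keras.layers.Dense(10, kernel_regularizer=tf.keras.regularizers.l2())]
        ... )
        >>> _ = model(tf.random.uniform((4, 10)))  # build the model so model.losses is populated
        >>> loss = regularization_loss(model)
        >>> cost = loss()  # sums model.losses
        >>> float(cost) >= 0.0
        True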

    """
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    def regularization_loss_func(
        y_true: Optional[tf.Tensor] = None, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true and y_pred will not be considered here, because the
        regularization loss is accessed via model.losses.

        """
        regularization_loss_func.name = "regularization_loss"
        reg_loss = tf.math.add_n(model.losses)
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Regularization Cost", reg_loss),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            reg_loss, message="Regularization cost became infinite or NaN."
        )
        return reg_loss

    return regularization_loss_func


def reconstruction_loss(model: tf.keras.models.Model) -> Callable:
    """Simple autoencoder reconstruction loss.

    Use in custom training loops or in model.fit() training.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.

    Returns:
        Callable: A loss function to be used in custom training or model.fit().
            The function takes the following arguments:
                y_true (tf.Tensor): The true tensor.
                y_pred (tf.Tensor, optional): The output tensor. If not supplied
                    the model will be called to get this tensor. Defaults to None.
                step (int): A step for tensorboard callbacks. Defaults to None.

    Examples:
        >>> import tensorflow as tf
        >>> import encodermap as em
        >>> from encodermap import loss_functions
        >>> model = tf.keras.Model()
        >>> loss = loss_functions.reconstruction_loss(model)
        >>> x = tf.random.normal(shape=(10, 10))
        >>> loss(x, x).numpy()
        0.0

    """

    def reconstruction_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        # if y_pred is None, this function is used in custom training
        # and should use model to get the output
        if y_pred is None:
            y_pred = model(y_true)
        # calculate error
        reconstruction_error = tf.reduce_mean(tf.square(tf.subtract(y_pred, y_true)))
        return reconstruction_error

    return reconstruction_loss_func


def auto_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap auto_loss.

    Use in custom training loops or in model.fit() training.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
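
    Example:
        A minimal sketch (not part of the original docstring). Supplying both
        `y_true` and `y_pred` avoids calling the model, so a placeholder model
        can be used here:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import auto_loss
        >>> loss = auto_loss(tf.keras.Model())
        >>> y_true = tf.random.uniform((16, 10))
        >>> y_pred = tf.random.uniform((16, 10))
        >>> cost = loss(y_true, y_pred)  # periodic deviation between input and output
        >>> float(cost) >= 0.0
        True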

    """
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    def auto_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true is the complete model input, y_pred is the complete model output.
        Because here it is not intended to unpack the output into dihedrals
        and angles, y_pred can be None and will be directly taken from the model.
        """
        auto_loss_func.name = "auto_loss"

        if y_pred is None:
            y_pred = model(y_true)
        if p.auto_cost_scale is not None:
            if p.auto_cost_variant == "mean_square":
                auto_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.auto_cost_variant == "mean_abs":
                auto_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.auto_cost_variant == "mean_norm":
                auto_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "auto_cost_variant {} not available".format(p.auto_cost_variant)
                )
            auto_cost *= p.auto_cost_scale
        else:
            auto_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Auto Cost", auto_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            auto_cost, message="Auto cost became infinite or NaN."
        )
        return auto_cost

    return auto_loss_func


def dihedral_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap dihedral loss.

    Calculates distances between true and predicted dihedral angles. Respects
    periodicity in a [-a, a] interval if the provided parameters have a
    periodicity of 2 * a.

    Note:
        The interval should be (-a, a], but due to floating point precision we can't make this
        distinction here.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
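
    Example:
        A minimal sketch (not part of the original docstring). The returned
        closure only uses the parameters, so a placeholder model is enough here:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import dihedral_loss
        >>> from encodermap.parameters.parameters import ADCParameters
        >>> loss = dihedral_loss(tf.keras.Model(), ADCParameters())
        >>> input_dihedrals = tf.random.uniform((16, 100), -3.14, 3.14)
        >>> output_dihedrals = tf.random.uniform((16, 100), -3.14, 3.14)
        >>> cost = loss(input_dihedrals, output_dihedrals)
        >>> float(cost) >= 0.0
        True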

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def dihedral_loss_func(
        y_pred: tf.Tensor, y_true: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_pred should be model input dihedrals, y_true should be model output dihedrals."""
        dihedral_loss_func.name = "dihedral_loss"
        if p.dihedral_cost_scale is not None:
            if p.dihedral_cost_variant == "mean_square":
                dihedral_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.dihedral_cost_variant == "mean_abs":
                dihedral_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.dihedral_cost_variant == "mean_norm":
                dihedral_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "dihedral_cost_variant {} not available".format(
                        p.dihedral_cost_variant
                    )
                )
            dihedral_cost /= p.dihedral_cost_reference
            dihedral_cost *= p.dihedral_cost_scale
        else:
            dihedral_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Dihedral Cost", dihedral_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            dihedral_cost, message="Dihedral cost became infinite or NaN."
        )
        return dihedral_cost

    return dihedral_loss_func


def side_dihedral_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap sidechain dihedral loss.

    Calculates distances between true and predicted sidechain dihedral angles.
    Respects periodicity in a [-a, a] interval if the provided parameters have
    a periodicity of 2 * a.

    Note:
        The interval should be (-a, a], but due to floating point precision we can't make this
        distinction here.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def side_dihedral_loss_func(
        y_pred: tf.Tensor, y_true: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_pred should be model input side dihedrals, y_true should be model output side dihedrals."""
        side_dihedral_loss_func.name = "side_dihedral_loss"
        if p.side_dihedral_cost_scale is not None:
            if p.side_dihedral_cost_variant == "mean_square":
                side_dihedral_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.side_dihedral_cost_variant == "mean_abs":
                side_dihedral_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.side_dihedral_cost_variant == "mean_norm":
                side_dihedral_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "side_dihedral_cost_variant {} not available".format(
                        p.side_dihedral_cost_variant
                    )
                )
            side_dihedral_cost /= p.side_dihedral_cost_reference
            side_dihedral_cost *= p.side_dihedral_cost_scale
        else:
            side_dihedral_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost(
                "Sidechain Dihedral Cost", side_dihedral_cost
            ),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            side_dihedral_cost, message="Side dihedral cost became infinite or NaN."
        )
        return side_dihedral_cost

    return side_dihedral_loss_func


def angle_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap angle loss.

    Calculates distances between true and predicted angles. Respects periodicity
    in a [-a, a] interval if the provided parameters have a periodicity of 2 * a.

    Note:
        The interval should be (-a, a], but due to floating point precision we can't make this
        distinction here.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def angle_loss_func(
        y_pred: tf.Tensor, y_true: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true should be input angles. y_pred should be output angles (either from mean input angles or, when
        ADCParameters.use_backbone_angles == True, directly from model output)."""
        angle_loss_func.name = "angle_loss"
        if p.angle_cost_scale is not None:
            if p.angle_cost_variant == "mean_square":
                angle_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.angle_cost_variant == "mean_abs":
                angle_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.angle_cost_variant == "mean_norm":
                angle_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "angle_cost_variant {} not available".format(p.angle_cost_variant)
                )
            angle_cost /= p.angle_cost_reference
            angle_cost *= p.angle_cost_scale
        else:
            angle_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Angle Cost", angle_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            angle_cost, message="Angle cost became infinite or NaN."
        )
        return angle_cost

    return angle_loss_func


def cartesian_distance_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap cartesian distance loss.

    Calculates sigmoid-weighted distances between pairwise cartesians and latent.
    Uses sketch-map's sigmoid function to transform the high-dimensional space
    of the input and the low-dimensional space of the latent.

    Note:
        Make sure to provide the pairwise cartesian distances. The output of
        the latent will be compared to the input.

    Note:
        If the model contains two layers, the first layer will be assumed to be
        the encoder. If the model contains more layers, one layer needs to
        be named 'latent' (case-insensitive).

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
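
    Example:
        A minimal sketch (not part of the original docstring). Random tensors
        stand in for the pairwise CA distances and the latent projection; the
        returned closure does not call the model, so a placeholder model
        suffices:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import cartesian_distance_loss
        >>> from encodermap.parameters.parameters import ADCParameters
        >>> loss = cartesian_distance_loss(tf.keras.Model(), ADCParameters())
        >>> pairwise_ca_distances = tf.random.uniform((8, 1225))  # 8 frames, 50*49/2 CA pairs
        >>> latent = tf.random.normal((8, 2))
        >>> cost = loss(pairwise_ca_distances, latent)
        >>> float(cost) >= 0.0
        True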

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    dist_loss = sigmoid_loss(
        p,
        periodicity_overwrite=float("inf"),
        dist_dig_parameters_overwrite=p.cartesian_dist_sig_parameters,
    )

    def cartesian_distance_loss_func(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """y_true should be the pairwise distances of the input CA cartesians,
        y_pred the output of the latent layer. Both are transformed with
        sketch-map's sigmoid function; the difference of the two gives the loss value."""
        cartesian_distance_loss_func.name = "cartesian_distance_loss"
        if p.cartesian_distance_cost_scale is not None:
            dist_cost = dist_loss(y_true, y_pred)
            dist_cost *= p.cartesian_distance_cost_scale
        else:
            dist_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Cartesian Distance Cost", dist_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            dist_cost, message="Cartesian distance cost became infinite or NaN."
        )
        return dist_cost

    return cartesian_distance_loss_func


def cartesian_loss(
    model: tf.keras.models.Model,
    scale_callback: Optional[tf.keras.callbacks.Callback] = None,
    parameters: Optional["AnyParameters"] = None,
    log_callback: Optional[tf.keras.callbacks.Callback] = None,
    print_current_scale: bool = False,
) -> Callable:
    """Encodermap cartesian loss.

    Calculates the difference between input and output pairwise distances.
    The scaling of this cost is adjusted during training via a callback that
    re-compiles the model (see `scale_callback` and
    `encodermap.callbacks.callbacks.IncreaseCartesianCost`). The callback should
    also be provided when `model.fit()` is executed.

    Three cases are possible:
        * Case 1: step < cartesian_cost_scale_soft_start[0]:
            cost_scale = 0
        * Case 2: cartesian_cost_scale_soft_start[0] <= step <= cartesian_cost_scale_soft_start[1]:
            cost_scale = p.cartesian_cost_scale / (cartesian_cost_scale_soft_start[1] - cartesian_cost_scale_soft_start[0]) * (step - cartesian_cost_scale_soft_start[0])
        * Case 3: cartesian_cost_scale_soft_start[1] < step:
            cost_scale = p.cartesian_cost_scale

    Note:
        Make sure to provide the pairwise cartesian distances. This function will be
        adjusted as training proceeds via a callback. See
        `encodermap.callbacks.callbacks.IncreaseCartesianCost` for more info.

    Args:
        model (tf.keras.Model): The model to use the loss function on.
        scale_callback (Optional[encodermap.callbacks.IncreaseCartesianCost]): A
            callback that provides the current cartesian cost scale via its
            `current_cartesian_cost_scale` attribute. If None, the scale is fixed
            at `parameters.cartesian_cost_scale`. Defaults to None.
        parameters (Optional[AnyParameters]): The parameters. If None is provided,
            default values (check them with
            `print(em.ADCParameters.defaults_description())`) are used. Defaults
            to None.
        log_callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. Defaults to None.
        print_current_scale (bool, optional): Whether to print the current scale.
            Is used in testing. Defaults to False.

    Raises:
        Exception: When no bottleneck/latent layer can be found in the model.

    Returns:
        Callable: A loss function. Can be used in either custom training or model.fit().
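
    Example:
        A minimal sketch (not part of the original docstring). Without a
        `scale_callback` the scale stays at `cartesian_cost_scale`; the tensors
        below stand in for pairwise distances of input and back-mapped output
        cartesians. With `cartesian_cost_scale_soft_start=(100, 200)` and
        `cartesian_cost_scale=6`, the soft-start schedule sketched above would
        give a scale of 0 at step 50, 3 at step 150, and 6 at step 250.

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import cartesian_loss
        >>> from encodermap.parameters.parameters import ADCParameters
        >>> loss = cartesian_loss(tf.keras.Model(), parameters=ADCParameters())
        >>> true_pairwise = tf.random.uniform((8, 1225))
        >>> predicted_pairwise = tf.random.uniform((8, 1225))
        >>> cost = loss(true_pairwise, predicted_pairwise)
        >>> float(cost) >= 0.0
        True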

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if scale_callback is not None:
        current_scale_callback = scale_callback.current_cartesian_cost_scale
    else:
        current_scale_callback = K.constant(
            p.cartesian_cost_scale, dtype="float32", name="current_cartesian_cost_scale"
        )

    if print_current_scale:
        print(current_scale_callback)

    if log_callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = log_callback.log_bool

    def cartesian_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true should be pairwise distances of input cartesians,
        y_pred should be pairwise distances of back-mapped output cartesians."""
        scale = current_scale_callback
        if p.cartesian_cost_variant == "mean_square":
            cartesian_cost = tf.reduce_mean(tf.square(y_true - y_pred))
        elif p.cartesian_cost_variant == "mean_abs":
            cartesian_cost = tf.reduce_mean(tf.abs(y_true - y_pred))
        elif p.cartesian_cost_variant == "mean_norm":
            cartesian_cost = tf.reduce_mean(tf.norm(y_true - y_pred, axis=1))
        else:
            raise ValueError(
                "cartesian_cost_variant {} not available".format(
                    p.cartesian_cost_variant
                )
            )
        cartesian_cost /= p.cartesian_cost_reference
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost(
                "Cartesian Cost before scaling", cartesian_cost
            ),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Cartesian Cost current scaling", scale),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        cartesian_cost *= scale
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost(
                "Cartesian Cost after scaling", cartesian_cost
            ),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            cartesian_cost, message="Cartesian cost became infinite or NaN."
        )
        return cartesian_cost

    return cartesian_loss_func