Coverage for encodermap/loss_functions/loss_classes.py: 36%
105 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/loss_functions/loss_classes.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Losses for encodermap.
24All losses in EncoderMap inherit from `tf.keras.losses.Loss` and thus can be
25easily paired with other models.
27"""
28################################################################################
29# Imports
30################################################################################
33# Future Imports at the top
34from __future__ import annotations
36# Standard Library Imports
37import functools
38import inspect
39import os
40from collections.abc import Sequence
41from typing import Any, Optional, Union
43# Third Party Imports
44import tensorflow as tf
45import tensorflow.keras.backend as K
47# Encodermap imports
48from encodermap.misc.distances import (
49 pairwise_dist,
50 pairwise_dist_periodic,
51 periodic_distance,
52 sigmoid,
53)
54from encodermap.parameters.parameters import ADCParameters, Parameters
57################################################################################
58# Typing
59################################################################################
62################################################################################
63# Globals
64################################################################################
67__all__: list[str] = ["DihedralLoss", "AngleLoss"]
70################################################################################
71# Utils
72################################################################################
75def testing(cls_or_func):
76 if inspect.isclass(cls_or_func):
77 orig_init = cls_or_func.__init__
79 @functools.wraps(cls_or_func)
80 def __init__(self, *args, **kwargs):
81 if os.getenv("ENCODERMAP_TESTING", "False") != "True":
82 raise Exception(
83 f"You are instantiating a em.testing class ({cls_or_func.__name__}). "
84 f"These classes are actively developed and not stable. If you "
85 f"know what you are doing, set the environment variable "
86 f"'EM_TESTING' to 'True'."
87 )
88 return orig_init(self, *args, **kwargs)
90 cls_or_func.__init__ = __init__
91 return cls_or_func
92 else:
94 @functools.wraps(cls_or_func)
95 def newfunc(*args, **kwargs):
96 if os.getenv("ENCODERMAP_TESTING", "False") != "True":
97 raise Exception(
98 f"You are calling an em.testing function: ({cls_or_func.__name__}). "
99 f"These functions are actively developed and not stable. If you "
100 f"know what you are doing, set the environment variable "
101 f"'EM_TESTING' to 'True'."
102 )
103 return cls_or_func(*args, **kwargs)
105 return newfunc
108def _do_nothing(*args) -> None:
109 """This function does nothing. One of the functions provided to tf.cond."""
110 pass
def _summary_cost(
    name: str,
    cost: tf.Tensor,
) -> None:
    """Write a scalar summary. Used as the logging branch given to tf.cond.

    Args:
        name (str): The tag under which the scalar is written.
        cost (tf.Tensor): The scalar tensor that is written.

    """
    tf.summary.scalar(name, data=cost)
127################################################################################
128# PublicClasses
129################################################################################
@testing
@tf.keras.utils.register_keras_serializable()
class EncoderMapBaseLoss(tf.keras.losses.Loss):
    """EncoderMap's base loss. Serializes the parameters and the logging callback.

    It can be subclassed to implement custom loss functions that have access to
    EncoderMap's parameter classes via `self.p` and that write their value as
    a scalar summary whenever `self.write_bool` evaluates to true.

    """

    def __init__(
        self,
        parameters: Optional[Union[Parameters, ADCParameters]] = None,
        write_bool_cb: Optional[tf.keras.callbacks.Callback] = None,
        **kwargs,
    ) -> None:
        """Instantiate the Loss class.

        Most subclassed losses don't need to overwrite this `__init__()`.

        Args:
            parameters (Optional[Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]]): The parameters
                this class will use to decide how to compute losses. If None
                is provided, a default `Parameters()` instance is used.
            write_bool_cb (Optional[tf.keras.callbacks.Callback]): A callback
                whose `log_bool` attribute is used as the predicate that
                switches summary writing on and off. If None is provided,
                a constant `False` is used and nothing is ever logged.
            **kwargs: Forwarded to `tf.keras.losses.Loss.__init__()`. This
                carries the entries that `get_config()` of the parent class
                puts into the config (e.g. `name`, `reduction`), so they
                survive a `get_config()`/`from_config()` round trip.

        """
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters

        if write_bool_cb is None:
            self.cb = None
            self.write_bool = K.constant(False, "bool", name="log_bool")
        else:
            self.cb = write_bool_cb
            self.write_bool = write_bool_cb.log_bool

        # Previously the keyword arguments were silently dropped here.
        super().__init__(**kwargs)

    def call(
        self,
        loss_name: str,
        current_loss: Union[float, tf.Tensor],
    ) -> Union[float, tf.Tensor]:
        """Use super().call(loss_name, current_loss) to log the current loss to
        tensorboard.

        Args:
            loss_name (str): The name of the loss, as it should appear in Tensorboard.
            current_loss (Union[float, tf.Tensor]): The current value of the loss.

        Returns:
            Union[float, tf.Tensor]: The unchanged `current_loss`.

        """
        tf.cond(
            self.write_bool,
            true_fn=lambda: _summary_cost(loss_name, current_loss),
            false_fn=_do_nothing,
            name="Cost",
        )
        return current_loss

    @classmethod
    def from_config(cls, config):
        """Rebuild a loss from the dictionary produced by `get_config()`.

        Args:
            config: The config dictionary. It has to contain the keys `"p"`
                and `"write_bool_cb"`; all remaining entries are passed to
                the constructor as keyword arguments.

        Returns:
            A new instance of `cls`.

        """
        # Work on a shallow copy, so the caller's dictionary keeps its keys.
        config = dict(config)
        p = config.pop("p")
        # "cartesian_pwd_start" is the key used to tell ADCParameters apart
        # from plain Parameters.
        if "cartesian_pwd_start" in p:
            p = ADCParameters(**p)
        else:
            p = Parameters(**p)
        write_bool_cb = tf.keras.saving.deserialize_keras_object(
            config.pop("write_bool_cb")
        )
        return cls(parameters=p, write_bool_cb=write_bool_cb, **config)

    def get_config(self) -> dict[Any, Any]:
        """Extend the parent's config with the parameters and the callback.

        Returns:
            dict[Any, Any]: A copy of the parent's config with the additional
                keys `"p"` (the parameters as a dict) and `"write_bool_cb"`
                (the serialized callback).

        """
        config = super().get_config().copy()
        config.update(
            {
                "p": self.p.to_dict(),
                "write_bool_cb": tf.keras.saving.serialize_keras_object(self.cb),
            }
        )
        return config
@testing
@tf.keras.utils.register_keras_serializable()
class ADCBaseLoss(EncoderMapBaseLoss):
    """Base class for all Losses of the `AngleDihedralCartesianEncoderMap`.

    Replaces the default `Parameters()` with `ADCParameters()`.

    """

    def __init__(
        self,
        parameters: Optional[ADCParameters] = None,
        write_bool_cb: Optional[tf.keras.callbacks.Callback] = None,
        **kwargs,
    ) -> None:
        """Instantiate the loss with `ADCParameters` as the default parameters.

        Args:
            parameters (Optional[ADCParameters]): The parameters this loss
                reads its settings from. If None is provided, a default
                `ADCParameters()` instance is used.
            write_bool_cb (Optional[tf.keras.callbacks.Callback]): Handed
                to the parent class unchanged.
            **kwargs: Handed to the parent class unchanged, so that they are
                not lost on the way to the underlying keras loss.

        """
        # Pick the default before delegating, so the parent does not build a
        # plain `Parameters()` instance that would be discarded right away.
        if parameters is None:
            parameters = ADCParameters()
        super().__init__(parameters, write_bool_cb, **kwargs)
@testing
@tf.keras.utils.register_keras_serializable()
class DihedralLoss(ADCBaseLoss):
    """EncoderMap's `DihedralLoss` for `AngleDihedralCartesianEncoderMap`.

    Compares the dihedrals found at index 1 of `y_true` and `y_pred`. The
    element-wise distance is computed by `periodic_distance` using the
    periodicity in `self.p`, so that angles on both sides of the periodic
    boundary count as close to each other. The distances are then reduced to a
    scalar according to `self.p.dihedral_cost_variant`:

        * 'mean_square': Mean of the squared distances.
        * 'mean_abs': Mean of the absolute distances.
        * 'mean_norm': Mean of the norms taken along axis 1.

    The scalar is divided by `self.p.dihedral_cost_reference` and, if
    `self.p.dihedral_cost_scale` is not 0, multiplied by that scale. A scale
    of exactly 0 leaves the cost unscaled. A scale of None short-circuits the
    computation and yields a cost of 0.

    """

    name = "DihedralLoss"

    def call(self, y_true: Sequence[tf.Tensor], y_pred: Sequence[tf.Tensor]) -> float:
        """Compute the dihedral cost and pass it to the parent for logging.

        Args:
            y_true (Sequence[tf.Tensor]): The reference values. Only the
                element at index 1 is used.
            y_pred (Sequence[tf.Tensor]): The predicted values. Only the
                element at index 1 is used.

        Returns:
            float: The dihedral cost.

        Raises:
            ValueError: If `self.p.dihedral_cost_variant` is not one of
                'mean_square', 'mean_abs', or 'mean_norm'.

        """
        predicted = y_pred[1]
        reference = y_true[1]

        scale = self.p.dihedral_cost_scale
        if scale is None:
            return super().call("Dihedral Cost", 0)

        variant = self.p.dihedral_cost_variant
        if variant not in ("mean_square", "mean_abs", "mean_norm"):
            raise ValueError(f"dihedral_cost_variant {variant} not available")

        # All variants start from the same periodic distances.
        distances = periodic_distance(predicted, reference, self.p.periodicity)
        if variant == "mean_square":
            cost = tf.reduce_mean(tf.square(distances))
        elif variant == "mean_abs":
            cost = tf.reduce_mean(tf.abs(distances))
        else:
            cost = tf.reduce_mean(tf.norm(distances, axis=1))

        cost = cost / self.p.dihedral_cost_reference
        # A scale of 0 is skipped on purpose here; the cost stays unscaled.
        if scale != 0:
            cost = cost * scale

        return super().call("Dihedral Cost", cost)
@testing
@tf.keras.utils.register_keras_serializable()
class AngleLoss(ADCBaseLoss):
    """EncoderMap's `AngleLoss` for `AngleDihedralCartesianEncoderMap`.

    Compares the angles found at index 0 of `y_true` and `y_pred`. The
    element-wise distance is computed by `periodic_distance` using the
    periodicity in `self.p`, so that angles on both sides of the periodic
    boundary count as close to each other. The distances are then reduced to a
    scalar according to `self.p.angle_cost_variant`:

        * 'mean_square': Mean of the squared distances.
        * 'mean_abs': Mean of the absolute distances.
        * 'mean_norm': Mean of the norms taken along axis 1.

    The scalar is divided by `self.p.angle_cost_reference` and, if
    `self.p.angle_cost_scale` is not 0, multiplied by that scale. A scale
    of exactly 0 leaves the cost unscaled. A scale of None short-circuits the
    computation and yields a cost of 0.

    """

    def call(self, y_true: Sequence[tf.Tensor], y_pred: Sequence[tf.Tensor]) -> float:
        """Compute the angle cost and pass it to the parent for logging.

        Args:
            y_true (Sequence[tf.Tensor]): The reference values. Only the
                element at index 0 is used.
            y_pred (Sequence[tf.Tensor]): The predicted values. Only the
                element at index 0 is used.

        Returns:
            float: The angle cost.

        Raises:
            ValueError: If `self.p.angle_cost_variant` is not one of
                'mean_square', 'mean_abs', or 'mean_norm'.

        """
        predicted = y_pred[0]
        reference = y_true[0]

        scale = self.p.angle_cost_scale
        if scale is None:
            return super().call("Angle Cost", 0)

        variant = self.p.angle_cost_variant
        if variant not in ("mean_square", "mean_abs", "mean_norm"):
            raise ValueError(f"angle_cost_variant {variant} not available")

        # All variants start from the same periodic distances.
        distances = periodic_distance(predicted, reference, self.p.periodicity)
        if variant == "mean_square":
            cost = tf.reduce_mean(tf.square(distances))
        elif variant == "mean_abs":
            cost = tf.reduce_mean(tf.abs(distances))
        else:
            cost = tf.reduce_mean(tf.norm(distances, axis=1))

        cost = cost / self.p.angle_cost_reference
        # A scale of 0 is skipped on purpose here; the cost stays unscaled.
        if scale != 0:
            cost = cost * scale

        return super().call("Angle Cost", cost)