Coverage for encodermap/callbacks/callbacks.py: 33%
155 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/callbacks/callbacks.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Callbacks to strew into the Autoencoder classes."""
25################################################################################
26# Imports
27################################################################################
30from __future__ import annotations
32from typing import TYPE_CHECKING, Callable, Literal, Optional
34import numpy as np
35import tensorflow as tf
36import tensorflow.keras.backend as K
37from tqdm import tqdm
39from ..misc.saving_loading_models import save_model
40from ..misc.summaries import image_summary
41from ..parameters.parameters import ADCParameters, Parameters
43################################################################################
44# Typing
45################################################################################
48if TYPE_CHECKING:
49 from encodermap._typing import AnyParameters
52################################################################################
53# Globals
54################################################################################
# Names exported via `from encodermap.callbacks.callbacks import *`.
# NOTE(review): `EncoderMapBaseCallback`, `ImageCallback`, and
# `TensorboardWriteBoolAlwaysFalse` are not exported here — confirm this
# is intentional.
__all__ = [
    "ProgressBar",
    "EarlyStop",
    "CheckpointSaver",
    "TensorboardWriteBool",
    "IncreaseCartesianCost",
]
66################################################################################
67# Public Classes
68################################################################################
class EncoderMapBaseCallback(tf.keras.callbacks.Callback):
    """Base class for multiple callbacks.

    Can be used to implement new callbacks that can also use
    `encodermap.Parameters` classes. A counter is increased after every
    finished train batch. Based on the two attributes `summary_step` and
    `checkpoint_step` in the `encodermap.Parameters` classes the
    corresponding methods are called.

    Attributes:
        steps_counter (int): The current step counter. Increases every
            `on_train_batch_end`.
        p (Union[encodermap.Parameters, encodermap.ADCParameters]): The
            parameters for this callback. Based on the `summary_step` and
            `checkpoint_step` of this parameters class different
            class-methods are called.

    """

    def __init__(self, parameters: Optional[AnyParameters] = None) -> None:
        """Instantiate the EncoderMapBaseCallback class.

        Args:
            parameters (Union[encodermap.Parameters, encodermap.ADCParameters, None], optional):
                Parameters that will be used to print out the progress bar. If None is passed
                default values (check them with print(em.ADCParameters.defaults_description()))
                will be used. Defaults to None.

        """
        super().__init__()
        # Fall back to default parameters when none were provided.
        self.p = Parameters() if parameters is None else parameters
        self.steps_counter = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Called by keras after a batch ends.

        This method is the backbone of all of encodermap's callbacks. It
        increments the internal step counter and, whenever the counter is a
        multiple of `encodermap.Parameters.checkpoint_step` or
        `encodermap.Parameters.summary_step`, dispatches to
        `self.on_checkpoint_step` or `self.on_summary_step`, respectively.
        Those two methods should be overwritten by child classes.

        Args:
            batch (int): The number of the current batch. Provided by keras.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        self.steps_counter += 1
        step = self.steps_counter
        if not step % self.p.checkpoint_step:
            self.on_checkpoint_step(step, logs=logs)
        if not step % self.p.summary_step:
            self.on_summary_step(step, logs=logs)

    def on_summary_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Hook executed when the step counter hits a multiple of `summary_step`.

        Does nothing here; child classes overwrite it.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        pass

    def on_checkpoint_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Hook executed when the step counter hits a multiple of `checkpoint_step`.

        Does nothing here; child classes overwrite it.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        pass
145##############################################################################
146# Public Classes
147##############################################################################
class EarlyStop(tf.keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience (int): Number of epochs to wait after min has been hit. After this
            number of no improvement, training stops.

    """

    def __init__(self, patience: int = 0) -> None:
        """Instantiate the `EarlyStop` class.

        Args:
            patience (int): Number of training steps to wait after min has been hit.
                Training is halted after this number of steps without improvement.

        """
        super().__init__()
        self.patience = patience
        # best_weights stores the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Sets some attributes at the beginning of training."""
        # The number of batches it has waited since the loss was last at a minimum.
        self.wait = 0
        # The batch the training stops at (0 means no early stop occurred).
        self.stopped_batch = 0
        # Initialize the best as infinity. `np.inf` instead of `np.Inf`,
        # because the `np.Inf` alias was removed in NumPy 2.0.
        self.best = np.inf

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Gets the current loss at the end of the batch and compares it to previous batches."""
        # Guard against keras calling this without metrics: `logs` can be
        # None or lack the "loss" key; comparing None against a float raises.
        current = (logs or {}).get("loss")
        if current is None:
            return
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_batch = batch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Prints a message after training, if an early stop occurred."""
        if self.stopped_batch > 0:
            print("Step %05d: early stopping" % (self.stopped_batch + 1))
class ProgressBar(EncoderMapBaseCallback):
    """Progressbar Callback. Mix in with model.fit() and make sure to set verbosity to zero."""

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Simply creates the progressbar once training starts."""
        self.pbar = tqdm(total=self.p.n_steps, position=0, leave=True)
        # The loss is unknown until the first summary step has been reached.
        postfix = {"Loss after step ?": "?"}
        if isinstance(self.p, ADCParameters):
            postfix["Cartesian cost Scale"] = "?"
        self.pbar.set_postfix(postfix)

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Overwrites the parent class' `on_train_batch_end` and adds a progress-bar update."""
        super().on_train_batch_end(batch, logs=logs)
        self.pbar.update()

    def _cartesian_cost_scale(self, epoch: int) -> float:
        """Computes the current cartesian cost scale for display.

        Mirrors `IncreaseCartesianCost.calc_current_cartesian_cost_scale`:
        0 before the soft start, a linear ramp between the two steps in
        `cartesian_cost_scale_soft_start`, and the full scale afterwards.
        Returns 0 when `cartesian_cost_scale` is None, so the progress bar
        never tries to round None.

        Args:
            epoch (int): The current training step.

        Returns:
            float: The scale to display in the progress bar.

        """
        if self.p.cartesian_cost_scale is None:
            return 0
        if self.p.cartesian_cost_scale_soft_start[0] is None or epoch is None:
            return self.p.cartesian_cost_scale
        a, b = self.p.cartesian_cost_scale_soft_start
        if epoch < a:
            return 0
        if epoch <= b:
            # Linear ramp from 0 to the full scale between steps a and b.
            return self.p.cartesian_cost_scale / (b - a) * (epoch - a)
        return self.p.cartesian_cost_scale

    def on_summary_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Update the progress bar after an epoch with the current loss.

        Args:
            epoch (int): Current epoch. Will be automatically passed by tensorflow.
            logs (Optional[dict]): Also automatically passed by tensorflow.
                Contains metrics and losses. logs['loss'] will be written to the progress bar.

        """
        # `if logs:` also skips the None case, which `logs != {}` did not.
        if logs:
            postfix = {f"Loss after step {epoch}": logs["loss"]}
            if isinstance(self.p, ADCParameters):
                postfix["Cartesian cost Scale"] = np.round(
                    self._cartesian_cost_scale(epoch), 2
                )
            self.pbar.set_postfix(postfix)

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Close the Progress Bar."""
        self.pbar.close()
class ImageCallback(tf.keras.callbacks.Callback):
    """Writes images to tensorboard."""

    def __init__(
        self,
        highd_data: np.ndarray,
        image_step: int,
        scatter_kws: Optional[dict] = None,
        hist_kws: Optional[dict] = None,
        additional_fns: Optional[list[Callable]] = None,
        when: Literal["batch", "epoch"] = "batch",
    ):
        """Instantiate the ImageCallback.

        This class uses `encodermap.misc.summaries` to plot images to tensorboard.

        Args:
            highd_data (np.ndarray): The high-dimensional data, that will be provided
                to the plotting functions.
            image_step (int): When to plot the images. This argument is combined
                with the `when` argument to either plot every "epoch" % `image_step` == 0
                or every "batch" % `image_step` == 0 steps.
            scatter_kws (Optional[dict]): A dictionary, that `matplotlib.pyplot.scatter`
                takes as keyword args. If None is provided, the default dict
                is {"s": 20}. Defaults to None.
            hist_kws (Optional[dict]): A dictionary, that `matplotlib.pyplot.histogram`
                takes as keyword args. If None is provided, the default dict
                is {"bins": 50}. Defaults to None.
            additional_fns (Optional[list[Callable]]): Can be None or a list
                of functions, that return `io.BytesIO()` buffered images (see
                Example).
            when (Literal["batch", "epoch"]): When to plot the images. Works in
                conjunction with the argument `image_step`.

        Here's an example how to use this class

        .. code-block:: python

            def return_hist(data, hist_kws):
                plt.close("all")
                matplotlib.use("Agg")  # overwrites current backend of notebook
                plt.figure()
                plt.hist2d(*data.T, **hist_kws)
                buf = io.BytesIO()
                plt.savefig(buf, format="png")
                buf.seek(0)
                image = tf.image.decode_png(buf.getvalue(), 4)
                image = tf.expand_dims(image, 0)
                return image

        """
        super().__init__()
        self.highd_data = highd_data
        self.image_step = image_step
        # Only fall back to the defaults on None; never mutate caller dicts.
        self.scatter_kws = {"s": 20} if scatter_kws is None else scatter_kws
        self.hist_kws = {"bins": 50} if hist_kws is None else hist_kws
        self.additional_fns = additional_fns
        self.when = when

    def _write_image_summary(self, step: int) -> None:
        """Projects `self.highd_data` with the model's encoder and plots it.

        Shared implementation for `on_epoch_end` and `on_train_batch_end`,
        which previously duplicated this code.
        """
        lowd = self.model.encoder(self.highd_data).numpy()
        assert lowd is not None
        image_summary(
            lowd, step, self.scatter_kws, self.hist_kws, self.additional_fns
        )

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on epoch end."""
        if self.when == "epoch" and epoch % self.image_step == 0:
            self._write_image_summary(epoch)

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on batch end."""
        if self.when == "batch" and batch % self.image_step == 0:
            self._write_image_summary(batch)
class CheckpointSaver(EncoderMapBaseCallback):
    """Callback, that saves an `encodermap.models` model."""

    def on_checkpoint_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Overwrites parent class' `on_checkpoint_step` method.

        Uses `encodermap.misc.saving_loading_models.save_model` to save the model.
        Luckily, the keras callbacks contain the model as an attribute (`self.model`).

        Args:
            epoch (int): The current training step; forwarded to `save_model`
                as the `step` keyword.
            logs (Optional[dict]): Metrics dict from keras. Unused here.

        """
        save_model(self.model, self.p.main_path, "", step=epoch)
class IncreaseCartesianCost(tf.keras.callbacks.Callback):
    """Callback for the `encodermap.autoencoder.AngleDihedralCartesianEncoderMap`.

    This callback implements the soft-start of the cartesian cost.

    """

    def __init__(
        self, parameters: Optional[ADCParameters] = None, start_step: int = 0
    ) -> None:
        """Instantiate the callback.

        Args:
            parameters (Optional[ADCParameters]): Can be either None, or an instance
                of `encodermap.parameters.ADCParameters`. These parameters define the
                steps at which the cartesian cost scaling factor needs to be adjusted.
                If None is provided, the default values `(None, None)`, i.e. no
                cartesian cost, will be used. Defaults to None.
            start_step (int): The current step of the training. This argument
                is important if training is stopped while using the scaling
                cartesian cost. This argument will usually be loaded from a
                file in the saved model.

        """
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        super().__init__()
        self.a, self.b = self.p.cartesian_cost_scale_soft_start
        self.last_compilation: bool = False
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.p.learning_rate)
        # Use an instance variable for the case the model is reloaded and re-trained.
        self.current_step = start_step
        self.current_cartesian_cost_scale = K.variable(
            0.0, dtype="float32", name="current_cartesian_cost_scale"
        )
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None):
        """Sets the value of the keras backend variable `self.current_cartesian_cost_scale`."""
        self.current_step += 1
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def calc_current_cartesian_cost_scale(self, epoch: int) -> float:
        """Calculates the current cartesian cost scale, based on the parameters
        `self.a`, `self.b` and `self.p.cartesian_cost_scale`.

        Args:
            epoch (int): The training step to compute the scale for.

        Returns:
            float: 0.0 before the soft start begins (and when
            `cartesian_cost_scale` is None), a linear ramp between steps
            `self.a` and `self.b`, and the full `cartesian_cost_scale`
            afterwards.

        Raises:
            TypeError: If `epoch` is not an int. Raised explicitly instead of
                via `assert`, because asserts are stripped under `python -O`.

        """
        if not isinstance(epoch, int):
            raise TypeError(f"Argument `epoch` must be an int, got {type(epoch)}.")
        if self.p.cartesian_cost_scale is None:
            return 0.0
        if self.a is None:
            # No soft start configured: apply the full scale immediately.
            return self.p.cartesian_cost_scale
        if epoch < self.a:
            return 0.0
        if epoch <= self.b:
            # Linear ramp from 0 to `cartesian_cost_scale` between a and b.
            return self.p.cartesian_cost_scale / (self.b - self.a) * (epoch - self.a)
        return self.p.cartesian_cost_scale
class TensorboardWriteBoolAlwaysFalse(tf.keras.callbacks.Callback):
    """A tensorboard callback, that is always False. Used for debugging."""

    def __init__(self) -> None:
        """Instantiate this class."""
        # Bug fix: the parent class' __init__ was previously never called,
        # unlike in every other callback in this module.
        super().__init__()
        self.log_bool = K.variable(False, bool, "log_scalar")
        # `K.set_value` accepts a plain Python value; the previous code
        # allocated a second throwaway `K.variable` just to pass it here.
        K.set_value(self.log_bool, False)
class TensorboardWriteBool(tf.keras.callbacks.Callback):
    """This class saves the value of the keras variable `log_bool`.

    Based on this variable, stuff will be written to tensorboard, or not.

    """

    def __init__(self, parameters: Optional[AnyParameters] = None) -> None:
        """Instantiate the class.

        Args:
            parameters (Union[encodermap.Parameters, encodermap.ADCParameters, None], optional):
                Parameters that will be used to check when data should be written to
                tensorboard. If None is passed default values (check them with
                print(em.ADCParameters.defaults_description())) will be used.
                Defaults to None.

        """
        super().__init__()
        self.p = Parameters() if parameters is None else parameters
        # Keras backend variable read elsewhere to decide whether to log.
        self.log_bool = K.variable(False, bool, "log_scalar")
        # `K.set_value` accepts a plain Python value; allocating an extra
        # throwaway `K.variable` (as before) is unnecessary.
        K.set_value(self.log_bool, False)
        self.current_training_step = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Sets the value of the keras backend variable `log_bool`.

        This method does not use the `batch` argument, because the variable
        `self.current_training_step` is used instead.

        Args:
            batch (int): The batch number provided by keras. Unused.
            logs (Optional[dict]): Metrics dict from keras. Unused.

        """
        self.current_training_step += 1
        if self.p.tensorboard:
            # True only on every `summary_step`-th training step.
            K.set_value(
                self.log_bool, self.current_training_step % self.p.summary_step == 0
            )