Coverage for encodermap/callbacks/callbacks.py: 14%
240 statements
# -*- coding: utf-8 -*-
# encodermap/callbacks/callbacks.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Callbacks to strew into the Autoencoder classes."""


################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional, Union

# Third Party Imports
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from PIL import Image
from tqdm import tqdm

# Encodermap imports
from encodermap.misc.saving_loading_models import save_model
from encodermap.misc.summaries import image_summary
from encodermap.parameters.parameters import ADCParameters, AnyParameters, Parameters


################################################################################
# Globals
################################################################################


__all__: list[str] = [
    "ProgressBar",
    "EarlyStop",
    "CheckpointSaver",
    "TensorboardWriteBool",
    "IncreaseCartesianCost",
    "NoneInterruptCallback",
    "ImageCallback",
    "EncoderMapBaseCallback",
]


################################################################################
# Helpers
################################################################################


def np_to_sparse_tensor(a: np.ndarray) -> tf.sparse.SparseTensor:
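    """Converts a NaN-padded numpy array to a `tf.sparse.SparseTensor`.

    The indices of all non-NaN entries become the sparse indices; the NaN
    entries are dropped. A short round-trip sketch (the dense default value
    restores the NaNs):

    >>> a = np.array([[1.0, np.nan], [np.nan, 4.0]])
    >>> s = np_to_sparse_tensor(a)
    >>> tf.sparse.to_dense(s, default_value=np.nan).numpy()
    array([[ 1., nan],
           [nan,  4.]])

    """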
    indices = np.stack(np.where(~np.isnan(a))).T.astype("int64")
    dense_shape = a.shape
    a = a[~np.isnan(a)].flatten()
    return tf.sparse.SparseTensor(indices, a, dense_shape)


################################################################################
# Public Classes
################################################################################


class NoneInterruptCallback(tf.keras.callbacks.Callback):
    """A callback that interrupts training when NaN is encountered in the weights."""

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Checks all model weights for NaN values at the end of every batch."""
        for w in self.model.get_weights():
            if not isinstance(w, np.ndarray):
                continue
            if np.any(np.isnan(w)):
                print(
                    f"At batch {self.model._my_train_counter.numpy()}, the "
                    f"model has NaNs in one of its weights. Because "
                    f"multiplication with NaN yields NaN, this NaN value will now "
                    f"propagate through the network until all weights are tensors of "
                    f"NaNs. I stopped the training at this point, as further training is "
                    f"pointless. This error might originate from your input. You "
                    f"can run the training with `deterministic=True` and check "
                    f"whether this problem happens at the same training step "
                    f"for multiple trainings. If yes, your input contains NaNs. "
                    f"If no, you can try to lower the learning rate."
                )
                self.model.stop_training = True
                break
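

# A minimal usage sketch (the `model` and `data` below are hypothetical, not
# part of this module): attach the callback, and training halts as soon as any
# weight turns NaN, instead of continuing with a poisoned network.
#
#     model.fit(data, epochs=1, verbose=0, callbacks=[NoneInterruptCallback()])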


class EncoderMapBaseCallback(tf.keras.callbacks.Callback):
    """Base class for callbacks in EncoderMap.

    The `Parameters` class in EncoderMap has a `summary_step` variable that
    dictates when variables and other tensors are logged to TensorBoard. No
    matter what property is logged, there will always be a code section
    executing an `if train_step % summary_step == 0` code snippet. This is
    handled centrally in this class. This class is instantiated inside the
    user-facing `AutoEncoderClass` classes and is provided with the appropriate
    parameters (`Parameters` for `EncoderMap` and `ADCParameters` for
    `AngleDihedralCartesianEncoderMap`). Thus, subclassing this class does not
    require a new `__init__` method. Only the `on_summary_step` and
    `on_checkpoint_step` methods need to be implemented in subclasses of this
    class, with the code that should run when these events occur.

    Examples:

        In this example, the `on_summary_step` method raises an exception.

        >>> from typing import Optional
        >>> import encodermap as em
        ...
        >>> class MyCallback(em.callbacks.EncoderMapBaseCallback):
        ...     def on_summary_step(self, step: int, logs: Optional[dict] = None) -> None:
        ...         raise Exception(f"Summary step {self.steps_counter} has been reached.")
        ...
        >>> emap = em.EncoderMap()  # doctest: +ELLIPSIS
        Output...
        >>> emap.add_callback(MyCallback)
        >>> emap.train()  # doctest: +ELLIPSIS, +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        Exception: Summary step 10 has been reached.

    Attributes:
        steps_counter (int): The current step counter. Increases every `on_train_batch_end`.
        p (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]):
            The parameters for this callback. Based on the `summary_step` and
            `checkpoint_step` of the `encodermap.parameters.Parameters` class,
            different class methods are called.

    """

    def __init__(self, parameters: Optional["AnyParameters"] = None) -> None:
        """Instantiate the EncoderMapBaseCallback class.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters, None], optional):
                Parameters that will be used to print out the progress bar. If None is passed,
                default values (check them with `print(em.ADCParameters.defaults_description())`)
                will be used. Defaults to None.

        """
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        super().__init__()
        self.steps_counter = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Called after a batch ends. The batch number is provided by keras.

        This method is the backbone of all of EncoderMap's callbacks. After
        every batch, this method is called by keras. When the step counter
        matches either `encodermap.Parameters.summary_step` or
        `encodermap.Parameters.checkpoint_step`, the code in
        `self.on_summary_step` or `self.on_checkpoint_step` is executed.
        These methods should be overwritten by child classes.

        Args:
            batch (int): The number of the current batch. Provided by keras.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        self.steps_counter += 1
        if self.steps_counter % self.p.checkpoint_step == 0:
            self.on_checkpoint_step(self.steps_counter, logs=logs)
        if self.steps_counter % self.p.summary_step == 0:
            self.on_summary_step(self.steps_counter, logs=logs)

    def on_summary_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Executed when the currently finished batch matches `encodermap.Parameters.summary_step`.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        pass

    def on_checkpoint_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Executed when the currently finished batch matches `encodermap.Parameters.checkpoint_step`.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        pass


################################################################################
# Public Classes
################################################################################


class EarlyStop(tf.keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. when the loss stops decreasing.

    Arguments:
        patience (int): Number of steps to wait after the minimum has been hit.
            After this many steps without improvement, training stops.

    """

    def __init__(self, patience: int = 0) -> None:
        """Instantiate the `EarlyStop` class.

        Args:
            patience (int): Number of training steps to wait after the minimum
                has been hit. Training is halted after this number of steps
                without improvement.

        """
        super().__init__()
        self.patience = patience
        # best_weights stores the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Sets some attributes at the beginning of training."""
        # The number of steps waited while the loss is no longer at a minimum.
        self.wait = 0
        # The batch the training stops at.
        self.stopped_batch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Gets the current loss at the end of the batch and compares it to previous batches."""
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_batch = batch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best batch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Prints a message after training, if an early stop occurred."""
        if self.stopped_batch > 0:
            print("Step %05d: early stopping" % (self.stopped_batch + 1))
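

# A minimal usage sketch (the compiled `model` and the dataset `data` are
# hypothetical, not part of this module): training stops once the loss has
# not improved for 50 consecutive batches.
#
#     early_stop = EarlyStop(patience=50)
#     model.fit(data, epochs=1, verbose=0, callbacks=[early_stop])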


class ProgressBar(EncoderMapBaseCallback):
    """Progressbar callback. Mix in with `model.fit()` and make sure to set verbosity to zero."""

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Simply creates the progressbar once training starts."""
        self.pbar = tqdm(
            total=self.p.n_steps,
            initial=self.p.current_training_step,
            position=0,
            leave=True,
        )
        postfix = {"Loss after step ?": "?"}
        if isinstance(self.p, ADCParameters):
            postfix["Cartesian cost scale"] = "?"
        self.pbar.set_postfix(postfix)

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Overwrites the parent class' `on_train_batch_end` and adds a progress-bar update."""
        super().on_train_batch_end(batch, logs=logs)
        self.pbar.update()

    def on_summary_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Update the progress bar after a summary step with the current loss.

        Args:
            epoch (int): Current step. Will be automatically passed by tensorflow.
            logs (Optional[dict]): Also automatically passed by tensorflow.
                Contains metrics and losses. logs['loss'] will be written to
                the progress bar.

        """
        if logs:
            postfix = {f"Loss after step {epoch}": logs["loss"]}
        else:
            postfix = {f"Loss after step {epoch}": "?"}
        epoch += self.p.current_training_step
        if isinstance(self.p, ADCParameters):
            if self.p.cartesian_cost_scale_soft_start != (None, None):
                if self.p.cartesian_cost_scale is not None:
                    if (
                        self.p.cartesian_cost_scale_soft_start[0] is None
                        or epoch is None
                    ):
                        scale = self.p.cartesian_cost_scale
                    else:
                        a, b = self.p.cartesian_cost_scale_soft_start
                        if epoch < a:
                            scale = 0
                        elif a <= epoch <= b:
                            scale = self.p.cartesian_cost_scale / (b - a) * (epoch - a)
                        else:
                            scale = self.p.cartesian_cost_scale
                else:
                    scale = 0
            else:
                scale = self.p.cartesian_cost_scale
            postfix["Cartesian cost scale"] = np.round(scale, 2)
        self.pbar.set_postfix(postfix)

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Close the progress bar."""
        self.pbar.close()
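

# A minimal usage sketch (hypothetical `model`, `data`, and parameters `p`):
# keras' own progress output must be silenced with `verbose=0`, otherwise it
# interleaves with the tqdm bar drawn by this callback.
#
#     model.fit(data, epochs=1, verbose=0, callbacks=[ProgressBar(parameters=p)])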


class ImageCallback(tf.keras.callbacks.Callback):
    """Writes images to tensorboard."""

    def __init__(
        self,
        parameters: AnyParameters,
        highd_data: np.ndarray,
        image_step: int,
        backend: Literal["matplotlib", "plotly"] = "matplotlib",
        mpl_scatter_kws: Optional[dict] = None,
        mpl_hist_kws: Optional[dict] = None,
        plotly_scatter_kws: Optional[dict] = None,
        plotly_hist_kws: Optional[dict] = None,
        additional_fns: Optional[list[Callable]] = None,
        when: Literal["batch", "epoch"] = "batch",
        save_dir: Optional[Union[str, Path]] = None,
    ):
        """Instantiate the ImageCallback.

        This class uses `encodermap.misc.summaries` to plot images to tensorboard.

        Args:
            parameters (AnyParameters): An instance of `encodermap.parameters.Parameters`
                or `encodermap.parameters.ADCParameters`.
            highd_data (np.ndarray): The high-dimensional data that will be provided
                to the plotting functions.
            image_step (int): When to plot the images. This argument is combined
                with the `when` argument to either plot every "epoch" % `image_step` == 0
                or every "batch" % `image_step` == 0 steps.
            backend (Literal["matplotlib", "plotly"]): Which backend to use for
                plotting. Defaults to "matplotlib".
            mpl_scatter_kws (Optional[dict]): A dictionary that `matplotlib.pyplot.scatter`
                takes as keyword args. If None is provided, the default dict
                is {"s": 20}. Defaults to None.
            mpl_hist_kws (Optional[dict]): A dictionary that `matplotlib.pyplot.hist2d`
                takes as keyword args. If None is provided, the default dict
                is {"bins": 50}. Defaults to None.
            plotly_scatter_kws (Optional[dict[str, Any]]): A dict with items that
                `plotly.express.scatter()` will accept. If None is provided,
                the dict {'size_max': 1, 'opacity': 0.2} will be passed to
                `px.scatter()`, which sets an appropriate size of scatter points
                for the size of datasets encodermap is usually used for.
            plotly_hist_kws (Optional[dict[str, Any]]): A dict with items that
                `encodermap.plot.plotting._plot_free_energy()` will accept. If None
                is provided, a dict with 50 bins will be passed to
                `encodermap.plot.plotting._plot_free_energy(**{'bins': 50})`.
                You can choose a colormap here by providing `{'bins': 50, 'cmap':
                'plasma'}` for this argument.
            additional_fns (Optional[list[Callable]]): Can be None or a list
                of functions that return `io.BytesIO()` buffered images (see
                Example).
            when (Literal["batch", "epoch"]): When to plot the images. Works in
                conjunction with the argument `image_step`.
            save_dir (Optional[Union[str, Path]]): When specified, images are
                saved to the specified directory during training.

        Here's an example of how to use this class:

        .. code-block:: python

            import io

            import matplotlib
            import matplotlib.pyplot as plt
            import tensorflow as tf


            def return_hist(data, hist_kws):
                plt.close("all")
                matplotlib.use("Agg")  # overwrites current backend of notebook
                plt.figure()
                plt.hist2d(*data.T, **hist_kws)
                buf = io.BytesIO()
                plt.savefig(buf, format="png")
                buf.seek(0)
                image = tf.image.decode_png(buf.getvalue(), 4)
                image = tf.expand_dims(image, 0)
                return image

        """
        super().__init__()
        self.backend = backend
        self.p = parameters
        self.highd_data = highd_data
        self.image_step = image_step
        if self.backend == "matplotlib":
            scatter_kws = mpl_scatter_kws
            hist_kws = mpl_hist_kws
            if scatter_kws is None:
                scatter_kws = {"s": 20}
            if hist_kws is None:
                hist_kws = {"bins": 50}
        elif self.backend == "plotly":
            scatter_kws = plotly_scatter_kws
            hist_kws = plotly_hist_kws
            if scatter_kws is None:
                scatter_kws = {"size_max": 1, "opacity": 0.2}
            if hist_kws is None:
                hist_kws = {"bins": 50}
        else:
            raise Exception(
                "Argument `backend` must be either 'plotly' or 'matplotlib'."
            )
        self.scatter_kws = scatter_kws
        self.hist_kws = hist_kws
        self.additional_fns = additional_fns
        self.when = when
        self.save_dir = save_dir
        if self.save_dir is not None:
            self.save_dir = Path(save_dir)
    def get_lowd(self):
        """Encodes the high-dimensional data and returns the low-dimensional
        projection as a numpy array. Large inputs are encoded in chunks of
        100 frames."""
        if isinstance(self.highd_data, (list, tuple)):
            if self.highd_data[0].shape[0] * self.highd_data[0].shape[1] > 100_000:
                # Split the frame indices into chunks of 100.
                indices = np.split(
                    np.arange(self.highd_data[0].shape[0]),
                    np.arange(100, self.highd_data[0].shape[0], 100),
                )
                # Avoid a trailing chunk of length 1.
                if len(indices[-1]) == 1:
                    indices = np.split(
                        np.arange(self.highd_data[0].shape[0]),
                        np.arange(100, self.highd_data[0].shape[0], 101),
                    )
                lowd = []
                for i, ind in enumerate(indices):
                    data = []
                    for d in self.highd_data:
                        if isinstance(d, tf.sparse.SparseTensor):
                            # Slice sparse inputs densely, then re-sparsify the chunk.
                            d = tf.sparse.to_dense(d, default_value=np.nan).numpy()[ind]
                            data.append(np_to_sparse_tensor(d))
                        else:
                            data.append(d[ind])
                    lowd.append(self.model.encoder(data).numpy())
                return np.vstack(lowd)
        lowd = self.model.encoder(self.highd_data).numpy()
        return lowd
    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on epoch end."""
        if self.when == "epoch":
            if self.image_step != 0:
                if epoch % self.image_step != 0:
                    return
            lowd = self.get_lowd()
            assert lowd is not None
            image_summary(
                lowd=lowd,
                step=epoch,
                scatter_kws=self.scatter_kws,
                hist_kws=self.hist_kws,
                additional_fns=self.additional_fns,
                backend=self.backend,
            )
            if self.save_dir is not None:
                self.save_image_to_dir(lowd, epoch)
            if isinstance(self.p, ADCParameters):
                if self.p.track_RMSD and self.save_dir is not None:
                    rmsds = self.model.compiled_metrics._metrics[-1].result().numpy()
                    np.save(self.save_dir / f"rmsds_epoch_{epoch}.npy", rmsds)

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on batch end."""
        if self.when == "batch":
            if self.image_step != 0:
                if batch % self.image_step != 0:
                    return
            lowd = self.get_lowd()
            assert lowd is not None
            image_summary(
                lowd=lowd,
                step=batch,
                scatter_kws=self.scatter_kws,
                hist_kws=self.hist_kws,
                additional_fns=self.additional_fns,
                backend=self.backend,
            )
            if self.save_dir is not None:
                self.save_image_to_dir(lowd, batch)
            if isinstance(self.p, ADCParameters):
                if self.p.track_RMSD and self.save_dir is not None:
                    rmsds = self.model.compiled_metrics._metrics[-1].result().numpy()
                    np.save(self.save_dir / f"rmsds_batch_{batch}.npy", rmsds)
    def save_image_to_dir(self, lowd: np.ndarray, step: int) -> None:
        """Saves the lowd representation to disk, so it can be looked at later."""
        outfile = self.save_dir / f"{self.when}_{step}.png"
        if not np.any(np.isnan(lowd)):
            plt.close("all")
            matplotlib.use("Agg")  # overwrites current backend of notebook
            plt.figure()
            plt.hist2d(*lowd.T, **self.hist_kws)
            plt.savefig(outfile, format="png")
        if self.additional_fns is not None:
            for i, fn in enumerate(self.additional_fns):
                image = Image.fromarray(fn(lowd).numpy()[0].astype("uint8"))
                image.save(outfile.with_stem(f"{self.when}_{step}_user_fn_{i}"))
        outfile = outfile.with_suffix(".npy")
        np.save(outfile, lowd)
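

# A minimal usage sketch (hypothetical `model`, high-dimensional data `highd`,
# and parameters `p`): write latent-space images to tensorboard every 100
# batches and additionally save them as png files.
#
#     cb = ImageCallback(
#         parameters=p,
#         highd_data=highd,
#         image_step=100,
#         when="batch",
#         save_dir="training_images",
#     )
#     model.fit(highd, epochs=1, verbose=0, callbacks=[cb])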


class CheckpointSaver(EncoderMapBaseCallback):
    """Callback that saves an `encodermap.models` model."""

    def on_checkpoint_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Overwrites the parent class' `on_checkpoint_step` method.

        Uses `encodermap.misc.saving_loading_models.save_model` to save the model.
        Luckily, the keras callbacks contain the model as an attribute (`self.model`).

        """
        save_model(self.model, self.p.main_path, step=epoch)


class IncreaseCartesianCost(tf.keras.callbacks.Callback):
    """Callback for the `encodermap.autoencoder.AngleDihedralCartesianEncoderMap`.

    This callback implements the soft-start of the cartesian cost.

    """

    def __init__(
        self,
        parameters: Optional[ADCParameters] = None,
    ) -> None:
        """Instantiate the callback.

        Args:
            parameters (Optional[encodermap.parameters.ADCParameters]): Can be either None, or an instance
                of `encodermap.parameters.ADCParameters`. These parameters define the
                steps at which the cartesian cost scaling factor needs to be adjusted.
                If None is provided, the default values `(None, None)`, i.e. no
                cartesian cost, will be used. Defaults to None.

        """
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        super().__init__()
        self.a, self.b = self.p.cartesian_cost_scale_soft_start
        self.last_compilation: bool = False
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.p.learning_rate, clipvalue=1.0
        )
        # Use an instance variable for the case that the model is reloaded and
        # re-trained; `current_training_step` then reflects the step at which
        # the previous training stopped.
        self.current_step = self.p.current_training_step
        self.current_cartesian_cost_scale = K.variable(
            0.0, dtype="float32", name="current_cartesian_cost_scale"
        )
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None):
        """Sets the value of the keras backend variable `self.current_cartesian_cost_scale`."""
        self.current_step += 1
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def calc_current_cartesian_cost_scale(self, epoch):
        """Calculates the current cartesian cost scale, based on the parameters
        `self.a`, `self.b`, and `self.p.cartesian_cost_scale`.
        """
        assert isinstance(epoch, int)
        if self.p.cartesian_cost_scale is not None:
            if self.p.cartesian_cost_scale_soft_start[0] is None or epoch is None:
                scale = self.p.cartesian_cost_scale
            else:
                if epoch < self.a:
                    scale = 0.0
                elif self.a <= epoch <= self.b:
                    scale = (
                        self.p.cartesian_cost_scale
                        / (self.b - self.a)
                        * (epoch - self.a)
                    )
                else:
                    scale = self.p.cartesian_cost_scale
        else:
            scale = 0.0
        # scale = K.variable(scale, dtype='float32', name='current_cartesian_cost_scale')
        return scale
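

# A worked example of the soft start (the parameter values are hypothetical):
# with `cartesian_cost_scale=1.0` and `cartesian_cost_scale_soft_start=(100, 200)`,
# `calc_current_cartesian_cost_scale` ramps the scale linearly from 0 to 1:
#
#     step  50  ->  scale = 0.0                                  (before the ramp)
#     step 150  ->  scale = 1.0 / (200 - 100) * (150 - 100) = 0.5
#     step 250  ->  scale = 1.0                                  (ramp finished)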


class TensorboardWriteBoolAlwaysFalse(tf.keras.callbacks.Callback):
    """A tensorboard callback whose write bool is always False. Used for debugging."""

    def __init__(self) -> None:
        """Instantiate this class."""
        self.log_bool = K.variable(False, bool, "log_scalar")
        K.set_value(self.log_bool, K.variable(False, bool, "log_scalar"))


class TensorboardWriteBool(tf.keras.callbacks.Callback):
    """This class saves the value of the keras variable `log_bool`.

    Based on this variable, stuff will be written to tensorboard, or not.

    """

    def __init__(self, parameters: Optional["AnyParameters"] = None) -> None:
        """Instantiate the class.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters, None], optional):
                Parameters that will be used to check when data should be written
                to tensorboard. If None is passed, default values (check them with
                `print(em.ADCParameters.defaults_description())`) will be used.
                Defaults to None.

        """
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        super().__init__()
        self.log_bool = K.variable(False, bool, "log_scalar")
        K.set_value(self.log_bool, K.variable(False, bool, "log_scalar"))
        self.current_training_step = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Sets the value of the keras backend variable `log_bool`.

        This method does not use the `batch` argument, because the variable
        `self.current_training_step` is used instead.

        """
        self.current_training_step += 1
        if self.p.tensorboard:
            if self.current_training_step % self.p.summary_step == 0:
                K.set_value(self.log_bool, K.variable(True, bool, "log_scalar"))
            else:
                K.set_value(self.log_bool, K.variable(False, bool, "log_scalar"))
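

# A minimal sketch of how the gate can be consumed (hypothetical usage; the
# summary-writing code below is not part of this module): a summary function
# reads the backend variable and only writes on summary steps.
#
#     write_bool = TensorboardWriteBool(parameters=p)
#     ...
#     if K.get_value(write_bool.log_bool):
#         tf.summary.scalar("loss", loss_value, step=step)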