Coverage for encodermap/callbacks/callbacks.py: 14%

240 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2025-05-15 21:06 +0200

1# -*- coding: utf-8 -*- 

2# encodermap/callbacks/callbacks.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Callbacks to strew into the Autoencoder classes.""" 

23 

24 

25################################################################################ 

26# Imports 

27################################################################################ 

28 

29 

30# Future Imports at the top 

31from __future__ import annotations 

32 

33# Standard Library Imports 

34from collections.abc import Callable 

35from pathlib import Path 

36from typing import TYPE_CHECKING, Literal, Optional, Union 

37 

38# Third Party Imports 

39import matplotlib 

40import matplotlib.pyplot as plt 

41import numpy as np 

42import tensorflow as tf 

43import tensorflow.keras.backend as K 

44from PIL import Image 

45from tqdm import tqdm 

46 

47# Encodermap imports 

48from encodermap.misc.saving_loading_models import save_model 

49from encodermap.misc.summaries import image_summary 

50from encodermap.parameters.parameters import ADCParameters, AnyParameters, Parameters 

51 

52 

53################################################################################ 

54# Globals 

55################################################################################ 

56 

57 

# Public API of this module. `np_to_sparse_tensor` (a helper) and
# `TensorboardWriteBoolAlwaysFalse` (documented as a debugging aid) are defined
# in this module but intentionally not exported here.
__all__: list[str] = [
    "ProgressBar",
    "EarlyStop",
    "CheckpointSaver",
    "TensorboardWriteBool",
    "IncreaseCartesianCost",
    "NoneInterruptCallback",
    "ImageCallback",
    "EncoderMapBaseCallback",
]

68 

69 

70################################################################################ 

71# Helpers 

72################################################################################ 

73 

74 

def np_to_sparse_tensor(a: np.ndarray) -> tf.sparse.SparseTensor:
    """Convert a NaN-padded dense array into a `tf.sparse.SparseTensor`.

    Every non-NaN entry of `a` becomes a stored value of the sparse tensor.
    NaN entries are treated as missing and are not stored.

    Args:
        a (np.ndarray): Dense array in which NaN marks absent values.

    Returns:
        tf.sparse.SparseTensor: A sparse tensor with `dense_shape == a.shape`.
            Its int64 indices are listed in row-major (C) order.

    """
    present = ~np.isnan(a)
    # `np.argwhere` yields one coordinate row per True entry, in C order. This
    # is the same order in which boolean-mask indexing returns the values.
    coords = np.argwhere(present).astype("int64")
    values = a[present].flatten()
    return tf.sparse.SparseTensor(coords, values, a.shape)

80 

81 

82################################################################################ 

83# Public Classes 

84################################################################################ 

85 

86 

class NoneInterruptCallback(tf.keras.callbacks.Callback):
    """A callback that interrupts training, when NaN is encountered in weights.

    After every training batch, all weights returned by
    `self.model.get_weights()` are checked for NaN values. As soon as one is
    found, an explanatory message is printed and `self.model.stop_training`
    is set to True.

    """

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Checks the model's weights for NaNs and stops training if any are found.

        Args:
            batch (int): Index of the batch that just finished. Provided by
                keras. It is only reported if the model has no
                `_my_train_counter` attribute.
            logs (Optional[dict]): Metric results of this batch. Unused.

        """
        # Short-circuits at the first weight array that contains a NaN.
        has_nan = any(
            isinstance(w, np.ndarray) and np.any(np.isnan(w))
            for w in self.model.get_weights()
        )
        if not has_nan:
            return
        # `_my_train_counter` is a private attribute of EncoderMap's models.
        # Fall back to keras' batch index so that a model without it does not
        # raise here, which would also prevent the training from being stopped.
        counter = getattr(self.model, "_my_train_counter", None)
        step = counter.numpy() if counter is not None else batch
        print(
            f"At batch {step}, the "
            f"model has NaNs in one of its weights. Because "
            f"multiplication with NaN yields NaN, this NaN value will now "
            f"propagate through the network until all weights are tensors of "
            f"NaNs. I stopped the training at this point, as further training is "
            f"pointless. This error might originate from your input. You "
            f"can run the training with `deterministic=True` and check "
            f"whether this problems happens at the same training step "
            f"for multiple trainings. If yes, your input contains NaNs. "
            f"If no, you can try to lower the learning rate."
        )
        self.model.stop_training = True

110 

111 

class EncoderMapBaseCallback(tf.keras.callbacks.Callback):
    """Base class for callbacks in EncoderMap.

    EncoderMap's parameters define a `summary_step` and a `checkpoint_step`.
    The `summary_step` sets how often variables and other tensors are logged
    to TensorBoard. The `checkpoint_step` is the interval for checkpoint work.

    Without this class, every callback would repeat an
    `if train_step % summary_step == 0` check. Instead, this class counts
    finished training batches. It calls `on_summary_step` and
    `on_checkpoint_step` whenever that count is a multiple of the respective
    step.

    The user-facing autoencoder classes instantiate their callbacks with the
    appropriate parameters: `Parameters` for `EncoderMap` and `ADCParameters`
    for `AngleDihedralCartesianEncoderMap`. Subclasses therefore normally do
    not need their own `__init__`. They only override `on_summary_step`
    and/or `on_checkpoint_step`.

    Examples:

        In this example, the `on_summary_step` method raises an exception.

        >>> from typing import Optional
        >>> import encodermap as em
        ...
        >>> class MyCallback(em.callbacks.EncoderMapBaseCallback):
        ...     def on_summary_step(self, step: int, logs: Optional[dict] = None) -> None:
        ...         raise Exception(f"Summary step {self.steps_counter} has been reached.")
        ...
        >>> emap = em.EncoderMap()  # doctest: +ELLIPSIS
        Output...
        >>> emap.add_callback(MyCallback)
        >>> emap.train()  # doctest: +ELLIPSIS, +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        Exception: Summary step 10 has been reached.


    Attributes:
        steps_counter (int): Number of finished training batches. Incremented
            by one in every `on_train_batch_end` call.
        p (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]):
            The parameters of this callback. Their `summary_step` and
            `checkpoint_step` decide when `on_summary_step` and
            `on_checkpoint_step` are called.

    """

    def __init__(self, parameters: Optional["AnyParameters"] = None) -> None:
        """Instantiate the EncoderMapBaseCallback class.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters, None], optional):
                Parameters whose `summary_step` and `checkpoint_step` control
                when the hook methods are called. If None is passed, a default
                `encodermap.parameters.Parameters()` instance is used.
                Defaults to None.

        """
        self.p = Parameters() if parameters is None else parameters
        super().__init__()
        self.steps_counter = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Advance the step counter and dispatch to the hook methods.

        Keras calls this after every training batch. The counter kept in
        `self.steps_counter` is used instead of keras' `batch` argument.
        Whenever it is divisible by `self.p.checkpoint_step`,
        `on_checkpoint_step` runs first. Whenever it is divisible by
        `self.p.summary_step`, `on_summary_step` runs afterwards.
        Child classes should override these hooks.

        Args:
            batch (int): The index of the current batch. Provided by keras.
            logs (Optional[dict]): A dict containing the metric results.

        """
        self.steps_counter += 1
        if not self.steps_counter % self.p.checkpoint_step:
            self.on_checkpoint_step(self.steps_counter, logs=logs)
        if not self.steps_counter % self.p.summary_step:
            self.on_summary_step(self.steps_counter, logs=logs)

    def on_summary_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Hook run when the finished batch count is a multiple of `summary_step`.

        Does nothing by default. Override in subclasses.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): A dict containing the metric results.

        """

    def on_checkpoint_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Hook run when the finished batch count is a multiple of `checkpoint_step`.

        Does nothing by default. Override in subclasses.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): A dict containing the metric results.

        """

212 

213 

214############################################################################## 

215# Public Classes 

216############################################################################## 

217 

218 

class EarlyStop(tf.keras.callbacks.Callback):
    """Stop training when the loss stops decreasing.

    After every training batch, the loss reported by keras is compared with
    the lowest loss seen so far. Each batch without improvement increases a
    counter, and any improvement resets it. Training is stopped once the
    counter reaches `patience`. This always happens after at least one
    non-improving batch. The weights recorded at the lowest loss are then
    restored, if any were recorded.

    Arguments:
        patience (int): Number of consecutive non-improving batches after
            which training stops.

    """

    def __init__(self, patience: int = 0) -> None:
        """Instantiate the `EarlyStop` class.

        Args:
            patience (int): Number of training steps to wait after the minimum
                has been hit. Training is halted after this number of steps
                without improvement.

        """
        super().__init__()
        self.patience = patience
        # Weights at which the minimum loss occurred. Stays None until a loss
        # lower than the initial `np.inf` has been seen.
        self.best_weights = None

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Resets the bookkeeping attributes at the beginning of training."""
        # Number of consecutive batches without improvement of the loss.
        self.wait = 0
        # The batch at which training was stopped.
        self.stopped_batch = 0
        # Whether this callback stopped training. Tracked separately because a
        # stop at batch 0 is a valid stop, but `stopped_batch` would then be 0.
        self._early_stopped = False
        # `np.inf`, not `np.Inf`: the capitalized alias was removed in NumPy 2.0.
        self.best = np.inf

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Compares the current loss with the best loss seen so far.

        Args:
            batch (int): Index of the batch that just finished. Provided by keras.
            logs (Optional[dict]): Metric results. The "loss" entry is used.

        """
        current = (logs or {}).get("loss")
        if current is None:
            # Keras did not report a loss for this batch; nothing to compare.
            return
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (less).
            self.best_weights = self.model.get_weights()
            return
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_batch = batch
            self._early_stopped = True
            self.model.stop_training = True
            # If no finite loss was ever recorded (e.g. NaN from the start),
            # there are no weights to restore. `set_weights(None)` would raise.
            if self.best_weights is not None:
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Prints a message after training, if an early stop occurred."""
        if getattr(self, "_early_stopped", False):
            print("Step %05d: early stopping" % (self.stopped_batch + 1))

270 

271 

class ProgressBar(EncoderMapBaseCallback):
    """Progressbar Callback. Mix in with model.fit() and make sure to set verbosity to zero.

    The bar advances by one per training batch. Whenever the parent class
    dispatches `on_summary_step`, the bar's postfix is updated with the
    current loss. For `ADCParameters`, it also shows the current cartesian
    cost scale.

    """

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Simply creates the progressbar once training starts."""
        self.pbar = tqdm(
            total=self.p.n_steps,
            initial=self.p.current_training_step,
            position=0,
            leave=True,
        )
        postfix = {"Loss after step ?": "?"}
        if isinstance(self.p, ADCParameters):
            postfix["Cartesian cost scale"] = "?"
        self.pbar.set_postfix(postfix)

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Overwrites the parent class' `on_train_batch_end` and adds a progress-bar update."""
        super().on_train_batch_end(batch, logs=logs)
        self.pbar.update()

    def on_summary_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Update the progress bar with the current loss.

        Args:
            epoch (int): The step counter passed by the parent class.
            logs (Optional[dict]): Metrics and losses passed by keras.
                `logs['loss']` is written to the progress bar. "?" is shown
                if `logs` is None, empty, or has no "loss" entry.

        """
        loss = (logs or {}).get("loss", "?")
        postfix = {f"Loss after step {epoch}": loss}
        if isinstance(self.p, ADCParameters):
            # The schedule is defined on the absolute training step, which
            # includes steps from earlier (reloaded) trainings.
            scale = self._cartesian_cost_scale(epoch + self.p.current_training_step)
            postfix["Cartesian cost scale"] = np.round(scale, 2)
        self.pbar.set_postfix(postfix)

    def _cartesian_cost_scale(self, step: int) -> float:
        """Return the soft-started cartesian cost scale at absolute training `step`.

        With `(a, b) = self.p.cartesian_cost_scale_soft_start` and
        `s = self.p.cartesian_cost_scale`:

        * `s` is None: 0.
        * `a` is None: `s` (no soft start).
        * `step < a`: 0.
        * `a <= step < b`: a linear ramp from 0 towards `s`.
        * `step >= b`: `s`.

        Returning `s` for `step >= b` without evaluating the ramp also avoids
        a division by zero when `a == b`.

        """
        full_scale = self.p.cartesian_cost_scale
        if full_scale is None:
            return 0
        soft_start = self.p.cartesian_cost_scale_soft_start
        # Also covers a list-valued `[None, None]`, e.g. after loading
        # parameters from a file, which does not compare equal to `(None, None)`.
        if soft_start[0] is None:
            return full_scale
        a, b = soft_start
        if step < a:
            return 0
        if step < b:
            return full_scale / (b - a) * (step - a)
        return full_scale

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Close the Progress Bar"""
        self.pbar.close()

333 

334 

class ImageCallback(tf.keras.callbacks.Callback):
    """Writes images to tensorboard.

    At an interval set by `image_step` and `when`, the high-dimensional data
    is passed through `self.model.encoder`. The resulting low-dimensional
    projection is handed to `encodermap.misc.summaries.image_summary`. If
    `save_dir` is given, the projection is also saved to disk, together with
    a 2D histogram image.

    """

    def __init__(
        self,
        parameters: AnyParameters,
        highd_data: np.ndarray,
        image_step: int,
        backend: Literal["matplotlib", "plotly"] = "matplotlib",
        mpl_scatter_kws: Optional[dict] = None,
        mpl_hist_kws: Optional[dict] = None,
        plotly_scatter_kws: Optional[dict] = None,
        plotly_hist_kws: Optional[dict] = None,
        additional_fns: Optional[list[Callable]] = None,
        when: Literal["batch", "epoch"] = "batch",
        save_dir: Optional[Union[str, Path]] = None,
    ):
        """Instantiate the ImageCallback.

        This class uses `encodermap.misc.summaries` to plot images to tensorboard.

        Args:
            parameters (AnyParameters): The parameters of the autoencoder. If
                they are `ADCParameters` with `track_RMSD` set and `save_dir`
                is given, the RMSD metric is saved next to the images.
            highd_data (np.ndarray): The high-dimensional data that is encoded
                and provided to the plotting functions. May also be a list or
                tuple of inputs (dense arrays and/or `tf.sparse.SparseTensor`).
            image_step (int): When to plot the images. This argument is combined
                with `when`: images are plotted at every epoch or batch whose
                number is divisible by `image_step`. An `image_step` of 0
                plots at every epoch/batch.
            backend (Literal["matplotlib", "plotly"]): Which backend to use for
                plotting. Defaults to "matplotlib".
            mpl_scatter_kws (Optional[dict]): Keyword args for
                `matplotlib.pyplot.scatter`. If None is provided, the default
                dict is {"s": 20}. Defaults to None.
            mpl_hist_kws (Optional[dict]): Keyword args for
                `matplotlib.pyplot.hist2d`. If None is provided, the default
                dict is {"bins": 50}. Defaults to None.
            plotly_scatter_kws (Optional[dict[str, Any]]): A dict with items that
                `plotly.express.scatter()` will accept. If None is provided,
                `{'size_max': 1, 'opacity': 0.2}` is used.
            plotly_hist_kws (Optional[dict[str, Any]]): A dict with items that
                `encodermap.plot.plotting._plot_free_energy()` will accept. If
                None is provided, `{'bins': 50}` is used. You can choose a
                colormap by providing `{'bins': 50, 'cmap': 'plasma'}`.
            additional_fns (Optional[list[Callable]]): None or a list of
                functions that return an image tensor with a leading batch
                dimension of 1, e.g. built from an `io.BytesIO()` buffer (see
                the example below).
            when (Literal["batch", "epoch"]): When to plot the images. Works in
                conjunction with the argument `image_step`.
            save_dir (Optional[Union[str, Path]]): When specified, images are
                saved to the specified directory during training.

        Raises:
            ValueError: If `backend` is neither "matplotlib" nor "plotly".

        Here's an example of how to use this class:
        .. code-block:: python
            def return_hist(data, hist_kws):
                plt.close("all")
                matplotlib.use("Agg")  # overwrites current backend of notebook
                plt.figure()
                plt.hist2d(*data.T, **hist_kws)
                buf = io.BytesIO()
                plt.savefig(buf, format="png")
                buf.seek(0)
                image = tf.image.decode_png(buf.getvalue(), 4)
                image = tf.expand_dims(image, 0)
                return image

        NOTE(review): `save_image_to_dir` calls every function as `fn(lowd)`,
        with a single argument, while the example above takes two. Confirm
        which signature `encodermap.misc.summaries.image_summary` expects.

        """
        super().__init__()
        self.backend = backend
        self.p = parameters
        self.highd_data = highd_data
        self.image_step = image_step
        if self.backend == "matplotlib":
            scatter_kws = {"s": 20} if mpl_scatter_kws is None else mpl_scatter_kws
            hist_kws = {"bins": 50} if mpl_hist_kws is None else mpl_hist_kws
        elif self.backend == "plotly":
            if plotly_scatter_kws is None:
                scatter_kws = {"size_max": 1, "opacity": 0.2}
            else:
                scatter_kws = plotly_scatter_kws
            hist_kws = {"bins": 50} if plotly_hist_kws is None else plotly_hist_kws
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError(
                "Argument `backend` must be either 'plotly' or 'matplotlib'."
            )
        self.scatter_kws = scatter_kws
        self.hist_kws = hist_kws
        self.additional_fns = additional_fns
        self.when = when
        self.save_dir = None if save_dir is None else Path(save_dir)

    def get_lowd(self) -> np.ndarray:
        """Encode `self.highd_data` with `self.model.encoder`.

        Chunking applies when `highd_data` is a list/tuple of inputs whose
        first element has more than 100_000 entries (`shape[0] * shape[1]`).
        The frames are then encoded in chunks of 100, and the last chunk may
        hold 101 frames. The chunk results are stacked. Sparse inputs are
        densified (NaN fill) once, sliced per chunk, and converted back to
        sparse tensors.

        Returns:
            np.ndarray: The low-dimensional representation of all frames.

        """
        data = self.highd_data
        if not isinstance(data, (list, tuple)):
            return self.model.encoder(data).numpy()
        n_frames = data[0].shape[0]
        if n_frames * data[0].shape[1] <= 100_000:
            return self.model.encoder(data).numpy()

        split_points = np.arange(100, n_frames, 100)
        # Never encode a single-frame chunk. The previous retry with a stride
        # of 101 could still produce one (e.g. for 101 frames). Dropping the
        # last split point folds that frame into the preceding chunk instead.
        if split_points.size and n_frames - split_points[-1] == 1:
            split_points = split_points[:-1]
        chunks = np.split(np.arange(n_frames), split_points)

        # Densify sparse inputs once up front instead of once per chunk.
        is_sparse = [isinstance(d, tf.sparse.SparseTensor) for d in data]
        dense_inputs = [
            tf.sparse.to_dense(d, default_value=np.nan).numpy() if sparse else d
            for d, sparse in zip(data, is_sparse)
        ]
        lowd = []
        for ind in chunks:
            chunk_inputs = [
                np_to_sparse_tensor(d[ind]) if sparse else d[ind]
                for d, sparse in zip(dense_inputs, is_sparse)
            ]
            lowd.append(self.model.encoder(chunk_inputs).numpy())
        return np.vstack(lowd)

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on epoch end."""
        if self.when == "epoch":
            self._write_images(epoch)

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on batch end."""
        if self.when == "batch":
            self._write_images(batch)

    def _write_images(self, step: int) -> None:
        """Plot, and optionally save, the current projection if `step` is due."""
        # An `image_step` of 0 disables the interval check (plot every time).
        if self.image_step != 0 and step % self.image_step != 0:
            return
        lowd = self.get_lowd()
        image_summary(
            lowd=lowd,
            step=step,
            scatter_kws=self.scatter_kws,
            hist_kws=self.hist_kws,
            additional_fns=self.additional_fns,
            backend=self.backend,
        )
        if self.save_dir is None:
            return
        self.save_image_to_dir(lowd, step)
        if isinstance(self.p, ADCParameters) and self.p.track_RMSD:
            rmsds = self.model.compiled_metrics._metrics[-1].result().numpy()
            # `self.when` is "epoch" or "batch" here, yielding
            # "rmsds_epoch_<step>.npy" / "rmsds_batch_<step>.npy".
            np.save(self.save_dir / f"rmsds_{self.when}_{step}.npy", rmsds)

    def save_image_to_dir(self, lowd: np.ndarray, step: int) -> None:
        """Saves the lowd representation to disk, so it can be looked at later.

        If `lowd` contains no NaNs, a 2D histogram is written to
        `<save_dir>/<when>_<step>.png`. Each additional function's image goes
        to `<when>_<step>_user_fn_<i>.png`. The raw array is always saved to
        `<save_dir>/<when>_<step>.npy`.

        Args:
            lowd (np.ndarray): The low-dimensional projection.
            step (int): The current epoch or batch (see `self.when`).

        """
        outfile = self.save_dir / f"{self.when}_{step}.png"
        if not np.any(np.isnan(lowd)):
            plt.close("all")
            matplotlib.use("Agg")  # overwrites current backend of notebook
            fig = plt.figure()
            plt.hist2d(*lowd.T, **self.hist_kws)
            plt.savefig(outfile, format="png")
            # Release the figure right away instead of keeping it until the next call.
            plt.close(fig)
            if self.additional_fns is not None:
                for i, fn in enumerate(self.additional_fns):
                    image = Image.fromarray(fn(lowd).numpy()[0].astype("uint8"))
                    image.save(outfile.with_stem(f"{self.when}_{step}_user_fn_{i}"))
        np.save(outfile.with_suffix(".npy"), lowd)

523 

524 

class CheckpointSaver(EncoderMapBaseCallback):
    """Callback that saves an `encodermap.models` model.

    Whenever the parent class dispatches `on_checkpoint_step`, the trained
    model is passed to `save_model` together with `self.p.main_path`.

    """

    def on_checkpoint_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Save the model with `encodermap.misc.saving_loading_models.save_model`.

        Keras attaches the model being trained to every callback as
        `self.model`, so it can be handed straight to `save_model`.

        Args:
            epoch (int): The current step counter. Forwarded as `step`.
            logs (Optional[dict]): Metric results. Unused.

        """
        model, target = self.model, self.p.main_path
        save_model(model, target, step=epoch)

536 

537 

class IncreaseCartesianCost(tf.keras.callbacks.Callback):
    """Callback for the `encodermap.autoencoder.AngleDihedralCartesianEncoderMap`.

    This callback implements the soft-start of the cartesian cost. After
    every training batch it updates the keras variable
    `current_cartesian_cost_scale` (see `calc_current_cartesian_cost_scale`
    for the schedule).

    """

    def __init__(
        self,
        parameters: Optional[ADCParameters] = None,
    ) -> None:
        """Instantiate the callback.

        Args:
            parameters (Optional[encodermap.parameters.ADCParameters]): Can be
                either None or an instance of
                `encodermap.parameters.ADCParameters`. Its
                `cartesian_cost_scale_soft_start` and `cartesian_cost_scale`
                define how the cartesian cost scaling factor changes over the
                training steps. Counting starts at its `current_training_step`,
                so a reloaded and re-trained model continues the schedule.
                If None is provided, a default `ADCParameters()` instance is
                used. Defaults to None.

        """
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        super().__init__()
        # Start and end step of the linear ramp of the cartesian cost scale.
        self.a, self.b = self.p.cartesian_cost_scale_soft_start
        self.last_compilation: bool = False
        # NOTE(review): not used inside this class. It is presumably read by
        # the training code; confirm before removing.
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.p.learning_rate, clipvalue=1.0
        )
        # Use an instance variable for the case the model is reloaded and re-trained.
        self.current_step = self.p.current_training_step
        self.current_cartesian_cost_scale = K.variable(
            0.0, dtype="float32", name="current_cartesian_cost_scale"
        )
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None):
        """Sets the value of the keras backend variable `self.current_cartesian_cost_scale`."""
        self.current_step += 1
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def calc_current_cartesian_cost_scale(self, epoch: int) -> float:
        """Calculates the cartesian cost scale for training step `epoch`.

        With `(a, b) = (self.a, self.b)` and `s = self.p.cartesian_cost_scale`:

        * `s` is None: 0.0.
        * `self.p.cartesian_cost_scale_soft_start[0]` is None: `s` (no soft start).
        * `epoch < a`: 0.0.
        * `a <= epoch < b`: the linear ramp `s / (b - a) * (epoch - a)`.
        * `epoch >= b`: `s`.

        Returning `s` for `epoch >= b` without evaluating the ramp also avoids
        a ZeroDivisionError when `a == b`.

        Args:
            epoch (int): The training step. Despite its name, this is the
                batch-based counter `self.current_step`. Python and numpy
                integers are accepted.

        Returns:
            float: The current cartesian cost scale.

        Raises:
            AssertionError: If `epoch` is not an integer.

        """
        # Numpy integers are accepted as well, e.g. a `current_training_step`
        # that was loaded from disk.
        assert isinstance(epoch, (int, np.integer))
        full_scale = self.p.cartesian_cost_scale
        if full_scale is None:
            return 0.0
        if self.p.cartesian_cost_scale_soft_start[0] is None:
            return full_scale
        if epoch < self.a:
            return 0.0
        if epoch < self.b:
            return full_scale / (self.b - self.a) * (epoch - self.a)
        return full_scale

613 

614 

class TensorboardWriteBoolAlwaysFalse(tf.keras.callbacks.Callback):
    """A tensorboard callback, that is always False. Used for debugging.

    It exposes a `log_bool` keras variable like `TensorboardWriteBool` does,
    but never sets it to True.

    """

    def __init__(self) -> None:
        """Instantiate this class."""
        # Initialize the keras `Callback` base so its attributes (e.g.
        # `model`) exist like on every other callback.
        super().__init__()
        # Created as False; no extra `set_value` (and throwaway variable) is needed.
        self.log_bool = K.variable(False, bool, "log_scalar")

622 

623 

class TensorboardWriteBool(tf.keras.callbacks.Callback):
    """This class saves the value of the keras variable `log_bool`.

    Based on this variable, stuff will be written to tensorboard, or not.
    If `parameters.tensorboard` is set, `log_bool` is True after every
    `summary_step`-th training batch and False otherwise.

    """

    def __init__(self, parameters: Optional["AnyParameters"] = None) -> None:
        """Instantiate the class.

        Args:
            parameters (Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters, None], optional):
                Parameters that decide when data should be written to
                tensorboard (`tensorboard` and `summary_step`). If None is
                passed, a default `encodermap.parameters.Parameters()`
                instance is used. Defaults to None.

        """
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        super().__init__()
        # Created as False; no extra `set_value` (and throwaway variable) is needed.
        self.log_bool = K.variable(False, bool, "log_scalar")
        self.current_training_step = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Sets the value of the keras backend variable `log_bool`.

        This method does not use the `batch` argument because the variable
        `self.current_training_step` is used.

        """
        self.current_training_step += 1
        if self.p.tensorboard:
            # Assign a plain bool. Wrapping it in a fresh `K.variable` would
            # allocate a new TF variable on every single training batch.
            K.set_value(
                self.log_bool,
                self.current_training_step % self.p.summary_step == 0,
            )