# -*- coding: utf-8 -*-
# encodermap/callbacks/callbacks.py
################################################################################
# Encodermap: A python library for dimensionality reduction.
#
# Copyright 2019-2022 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Callbacks to mix into the Autoencoder classes."""


################################################################################
# Imports
################################################################################


from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Literal, Optional

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tqdm import tqdm

from ..misc.saving_loading_models import save_model
from ..misc.summaries import image_summary
from ..parameters.parameters import ADCParameters, Parameters

################################################################################
# Typing
################################################################################


if TYPE_CHECKING:
    from encodermap._typing import AnyParameters


################################################################################
# Globals
################################################################################


__all__ = [
    "ProgressBar",
    "EarlyStop",
    "CheckpointSaver",
    "TensorboardWriteBool",
    "IncreaseCartesianCost",
]


################################################################################
# Base Classes
################################################################################


class EncoderMapBaseCallback(tf.keras.callbacks.Callback):
    """Base class for multiple callbacks.

    Can be used to implement new callbacks that can also use `encodermap.Parameters`
    classes. A counter is increased after every train batch is finished. Based on
    the two attributes `summary_step` and `checkpoint_step` in the
    `encodermap.Parameters` classes, the corresponding methods are called.
    Two attributes are important:

    Attributes:
        steps_counter (int): The current step counter. Increases every `on_train_batch_end`.
        p (Union[encodermap.Parameters, encodermap.ADCParameters]): The parameters for this callback.
            Based on the `summary_step` and `checkpoint_step` of this parameters class,
            different class methods are called.

    """

    def __init__(self, parameters: Optional[AnyParameters] = None) -> None:
        """Instantiate the EncoderMapBaseCallback class.

        Args:
            parameters (Union[encodermap.Parameters, encodermap.ADCParameters, None], optional):
                Parameters that will be used to print out the progress bar. If None is passed,
                default values (check them with `print(em.ADCParameters.defaults_description())`)
                will be used. Defaults to None.

        """
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        super().__init__()
        self.steps_counter = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Called after a batch ends. The number of the batch is provided by keras.

        This method is the backbone of all of encodermap's callbacks. After
        every batch, this method is called by keras. When the step counter
        matches either `encodermap.Parameters.summary_step` or
        `encodermap.Parameters.checkpoint_step`, the code in `self.on_summary_step`
        or `self.on_checkpoint_step` is executed. These methods should be
        overwritten by child classes.

        Args:
            batch (int): The number of the current batch. Provided by keras.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        self.steps_counter += 1
        if self.steps_counter % self.p.checkpoint_step == 0:
            self.on_checkpoint_step(self.steps_counter, logs=logs)
        if self.steps_counter % self.p.summary_step == 0:
            self.on_summary_step(self.steps_counter, logs=logs)

    def on_summary_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Executed when the currently finished batch matches `encodermap.Parameters.summary_step`.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        pass

    def on_checkpoint_step(self, step: int, logs: Optional[dict] = None) -> None:
        """Executed when the currently finished batch matches `encodermap.Parameters.checkpoint_step`.

        Args:
            step (int): The number of the current step.
            logs (Optional[dict]): `logs` is a dict containing the metrics results.

        """
        pass

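# A minimal sketch of how a child class hooks into this base class. The
# `LossPrinter` name is hypothetical and not part of encodermap:
#
#     class LossPrinter(EncoderMapBaseCallback):
#         def on_summary_step(self, step, logs=None):
#             print(f"Step {step}: loss={logs['loss']:.4f}")
#
# Assuming `Parameters` accepts `summary_step` as a keyword, instantiating
# with `LossPrinter(Parameters(summary_step=100))` makes keras call
# `on_train_batch_end` after every batch, and the base class then dispatches
# to `on_summary_step` every 100 batches.
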

##############################################################################
# Public Classes
##############################################################################


class EarlyStop(tf.keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. when the loss stops decreasing.

    Arguments:
        patience (int): Number of steps to wait after the min has been hit. After
            this number of steps without improvement, training stops.

    """

    def __init__(self, patience: int = 0) -> None:
        """Instantiate the `EarlyStop` class.

        Args:
            patience (int): Number of training steps to wait after the min has been hit.
                Training is halted after this number of steps without improvement.

        """
        super().__init__()
        self.patience = patience
        # best_weights stores the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Sets some attributes at the beginning of training."""
        # The number of batches waited since the loss was last at its minimum.
        self.wait = 0
        # The batch the training stops at.
        self.stopped_batch = 0
        # Initialize the best as infinity.
        self.best = np.Inf

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Gets the current loss at the end of the batch and compares it to previous batches."""
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_batch = batch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best batch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Prints a message after training, if an early stop occurred."""
        if self.stopped_batch > 0:
            print("Step %05d: early stopping" % (self.stopped_batch + 1))

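# A usage sketch (not from the library): because this callback reads
# `logs["loss"]` per batch, `patience` counts training batches, not epochs.
# Assumes a compiled `tf.keras` model `model` and a `tf.data.Dataset` `dataset`:
#
#     model.fit(dataset, epochs=1, callbacks=[EarlyStop(patience=50)], verbose=0)
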

class ProgressBar(EncoderMapBaseCallback):
    """Progress bar callback. Mix in with `model.fit()` and make sure to set verbosity to zero."""

    def on_train_begin(self, logs: Optional[dict] = None) -> None:
        """Simply creates the progress bar once training starts."""
        self.pbar = tqdm(total=self.p.n_steps, position=0, leave=True)
        postfix = {"Loss after step ?": "?"}
        if isinstance(self.p, ADCParameters):
            postfix["Cartesian cost Scale"] = "?"
        self.pbar.set_postfix(postfix)

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Overwrites the parent class' `on_train_batch_end` and adds a progress-bar update."""
        super().on_train_batch_end(batch, logs=logs)
        self.pbar.update()

    def on_summary_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Updates the progress bar on every summary step with the current loss.

        Args:
            epoch (int): The current step. Automatically passed by the base class.
            logs (Optional[dict]): Also automatically passed by tensorflow.
                Contains metrics and losses. logs['loss'] will be written to the progress bar.

        """
        if logs != {}:
            postfix = {f"Loss after step {epoch}": logs["loss"]}
            if isinstance(self.p, ADCParameters):
                if self.p.cartesian_cost_scale_soft_start != (None, None):
                    if self.p.cartesian_cost_scale is not None:
                        if (
                            self.p.cartesian_cost_scale_soft_start[0] is None
                            or epoch is None
                        ):
                            scale = self.p.cartesian_cost_scale
                        else:
                            a, b = self.p.cartesian_cost_scale_soft_start
                            if epoch < a:
                                scale = 0
                            elif a <= epoch <= b:
                                scale = self.p.cartesian_cost_scale / (b - a) * (epoch - a)
                            else:
                                scale = self.p.cartesian_cost_scale
                    else:
                        scale = 0
                else:
                    scale = self.p.cartesian_cost_scale
                postfix["Cartesian cost Scale"] = np.round(scale, 2)
            self.pbar.set_postfix(postfix)

    def on_train_end(self, logs: Optional[dict] = None) -> None:
        """Closes the progress bar."""
        self.pbar.close()

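# A usage sketch, assuming `p` is an `encodermap.Parameters` instance: the bar
# length is taken from `p.n_steps`, and `verbose=0` keeps keras' native bar
# from interleaving with the tqdm bar:
#
#     model.fit(dataset, epochs=1, callbacks=[ProgressBar(p)], verbose=0)
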

class ImageCallback(tf.keras.callbacks.Callback):
    """Writes images to tensorboard."""

    def __init__(
        self,
        highd_data: np.ndarray,
        image_step: int,
        scatter_kws: Optional[dict] = None,
        hist_kws: Optional[dict] = None,
        additional_fns: Optional[list[Callable]] = None,
        when: Literal["batch", "epoch"] = "batch",
    ):
        """Instantiate the ImageCallback.

        This class uses `encodermap.misc.summaries` to plot images to tensorboard.

        Args:
            highd_data (np.ndarray): The high-dimensional data that will be provided
                to the plotting functions.
            image_step (int): When to plot the images. This argument is combined
                with the `when` argument to plot either every time
                `epoch % image_step == 0` or every time `batch % image_step == 0`.
            scatter_kws (Optional[dict]): A dictionary that `matplotlib.pyplot.scatter`
                takes as keyword args. If None is provided, the default dict
                is {"s": 20}. Defaults to None.
            hist_kws (Optional[dict]): A dictionary that `matplotlib.pyplot.hist2d`
                takes as keyword args. If None is provided, the default dict
                is {"bins": 50}. Defaults to None.
            additional_fns (Optional[list[Callable]]): Can be None or a list
                of functions that return `io.BytesIO()` buffered images (see
                Example).
            when (Literal["batch", "epoch"]): When to plot the images. Works in
                conjunction with the argument `image_step`.

        Here's an example of how to use this class:
        .. code-block:: python

            def return_hist(data, hist_kws):
                plt.close("all")
                matplotlib.use("Agg")  # overwrites current backend of notebook
                plt.figure()
                plt.hist2d(*data.T, **hist_kws)
                buf = io.BytesIO()
                plt.savefig(buf, format="png")
                buf.seek(0)
                image = tf.image.decode_png(buf.getvalue(), 4)
                image = tf.expand_dims(image, 0)
                return image

        """
        super().__init__()
        self.highd_data = highd_data
        self.image_step = image_step
        if scatter_kws is None:
            scatter_kws = {"s": 20}
        self.scatter_kws = scatter_kws
        if hist_kws is None:
            hist_kws = {"bins": 50}
        self.hist_kws = hist_kws
        self.additional_fns = additional_fns
        self.when = when

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on epoch end."""
        if epoch % self.image_step == 0 and self.when == "epoch":
            lowd = self.model.encoder(self.highd_data).numpy()
            assert lowd is not None
            image_summary(
                lowd, epoch, self.scatter_kws, self.hist_kws, self.additional_fns
            )

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Calls `encodermap.misc.summaries.image_summary` on batch end."""
        if batch % self.image_step == 0 and self.when == "batch":
            lowd = self.model.encoder(self.highd_data).numpy()
            assert lowd is not None
            image_summary(
                lowd, batch, self.scatter_kws, self.hist_kws, self.additional_fns
            )

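# A usage sketch, assuming `highd` is an (n_samples, n_dims) np.ndarray of
# training data: every 1000 batches the model's encoder projects `highd` to
# low-dimensional space and the scatter/histogram images go to tensorboard:
#
#     cb = ImageCallback(highd, image_step=1000, when="batch")
#     model.fit(dataset, epochs=1, callbacks=[cb], verbose=0)
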

class CheckpointSaver(EncoderMapBaseCallback):
    """Callback that saves an `encodermap.models` model."""

    def on_checkpoint_step(self, epoch: int, logs: Optional[dict] = None) -> None:
        """Overwrites the parent class' `on_checkpoint_step` method.

        Uses `encodermap.misc.saving_loading_models.save_model` to save the model.
        Luckily, the keras callbacks contain the model as an attribute (`self.model`).

        """
        save_model(self.model, self.p.main_path, "", step=epoch)
        # tf.keras.models.save_model(self.model, f'{self.p.main_path}/keras_model_{epoch}_epochs.model')

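# Sketch: adding `CheckpointSaver(p)` to `model.fit(..., callbacks=[...])`
# writes a checkpoint to `p.main_path` every `p.checkpoint_step` batches via
# the base class' dispatch to `on_checkpoint_step`.
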

class IncreaseCartesianCost(tf.keras.callbacks.Callback):
    """Callback for the `encodermap.autoencoder.AngleDihedralCartesianEncoderMap`.

    This callback implements the soft start of the cartesian cost.

    """

    def __init__(
        self, parameters: Optional[ADCParameters] = None, start_step: int = 0
    ) -> None:
        """Instantiate the callback.

        Args:
            parameters (Optional[ADCParameters]): Can be either None or an instance
                of `encodermap.parameters.ADCParameters`. These parameters define the
                steps at which the cartesian cost scaling factor needs to be adjusted.
                If None is provided, the default values `(None, None)`, i.e. no
                cartesian cost, will be used. Defaults to None.
            start_step (int): The current step of the training. This argument
                is important if training is stopped and resumed while the cartesian
                cost is still being scaled. It will usually be loaded from a file in
                the saved model.

        """
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        super().__init__()
        self.a, self.b = self.p.cartesian_cost_scale_soft_start
        self.last_compilation: bool = False
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.p.learning_rate)
        # Use an instance variable for the case the model is reloaded and re-trained.
        self.current_step = start_step
        self.current_cartesian_cost_scale = K.variable(
            0.0, dtype="float32", name="current_cartesian_cost_scale"
        )
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None):
        """Sets the value of the keras backend variable `self.current_cartesian_cost_scale`."""
        self.current_step += 1
        K.set_value(
            self.current_cartesian_cost_scale,
            self.calc_current_cartesian_cost_scale(self.current_step),
        )

    def calc_current_cartesian_cost_scale(self, epoch):
        """Calculates the current cartesian cost scale, based on the parameters
        `self.a`, `self.b`, and `self.p.cartesian_cost_scale`.
        """
        assert isinstance(epoch, int)
        if self.p.cartesian_cost_scale is not None:
            if self.p.cartesian_cost_scale_soft_start[0] is None or epoch is None:
                scale = self.p.cartesian_cost_scale
            else:
                if epoch < self.a:
                    scale = 0.0
                elif self.a <= epoch <= self.b:
                    scale = (
                        self.p.cartesian_cost_scale
                        / (self.b - self.a)
                        * (epoch - self.a)
                    )
                else:
                    scale = self.p.cartesian_cost_scale
        else:
            scale = 0.0
        # scale = K.variable(scale, dtype='float32', name='current_cartesian_cost_scale')
        return scale

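# A worked example of the soft-start ramp (values chosen for illustration):
# with `cartesian_cost_scale=1.0` and `cartesian_cost_scale_soft_start=(100, 1000)`,
# the scale is 0.0 before step 100, rises linearly as
# 1.0 / (1000 - 100) * (step - 100), e.g. 0.5 at step 550, and stays at 1.0
# from step 1000 onwards.
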

class TensorboardWriteBoolAlwaysFalse(tf.keras.callbacks.Callback):
    """A tensorboard write bool that is always False. Used for debugging."""

    def __init__(self) -> None:
        """Instantiate this class."""
        super().__init__()
        self.log_bool = K.variable(False, bool, "log_scalar")
        K.set_value(self.log_bool, K.variable(False, bool, "log_scalar"))


class TensorboardWriteBool(tf.keras.callbacks.Callback):
    """This class saves the value of the keras variable `log_bool`.

    Based on this variable, summaries are either written to tensorboard or skipped.

    """

    def __init__(self, parameters: Optional[AnyParameters] = None) -> None:
        """Instantiate the class.

        Args:
            parameters (Union[encodermap.Parameters, encodermap.ADCParameters, None], optional):
                Parameters that will be used to check when data should be written to
                tensorboard. If None is passed, default values (check them with
                `print(em.ADCParameters.defaults_description())`) will be used.
                Defaults to None.

        """
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        super().__init__()
        self.log_bool = K.variable(False, bool, "log_scalar")
        K.set_value(self.log_bool, K.variable(False, bool, "log_scalar"))
        self.current_training_step = 0

    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None) -> None:
        """Sets the value of the keras backend variable `log_bool`.

        This method does not use the `batch` argument, because the variable
        `self.current_training_step` is used instead.

        """
        self.current_training_step += 1
        if self.p.tensorboard:
            if self.current_training_step % self.p.summary_step == 0:
                K.set_value(self.log_bool, K.variable(True, bool, "log_scalar"))
            else:
                K.set_value(self.log_bool, K.variable(False, bool, "log_scalar"))
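

# A sketch of how such a write gate could be consumed (hypothetical consumer
# code, not from this module; assumes `writer` is a `tf.summary` file writer
# and `p` an `encodermap.Parameters` instance):
#
#     write_bool = TensorboardWriteBool(p)
#     model.fit(dataset, epochs=1, callbacks=[write_bool], verbose=0)
#     ...
#     if K.get_value(write_bool.log_bool):
#         with writer.as_default():
#             tf.summary.scalar("loss", loss, step=step)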