Coverage for encodermap/loss_functions/loss_functions.py: 6%

275 statements  

coverage.py v7.4.1, created at 2025-05-15 21:06 +0200

# -*- coding: utf-8 -*-
# encodermap/loss_functions/loss_functions.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################

"""Loss functions for encodermap

"""
################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional

# Third Party Imports
import tensorflow as tf
import tensorflow.keras.backend as K

# Encodermap imports
from encodermap.encodermap_tf1.misc import distance_cost
from encodermap.misc.distances import (
    pairwise_dist,
    pairwise_dist_periodic,
    periodic_distance,
    sigmoid,
)
from encodermap.parameters.parameters import (
    ADCParameters,
    AnyParameters,
    Parameters,
    ParametersFramework,
)


################################################################################
# Globals
################################################################################


__all__: list[str] = [
    "reconstruction_loss",
    "auto_loss",
    "center_loss",
    "regularization_loss",
    "loss_combinator",
    "distance_loss",
    "cartesian_loss",
    "cartesian_distance_loss",
    "angle_loss",
    "dihedral_loss",
]


################################################################################
# Functions for tf.cond
# Don't know if this is really faster than logging every step to tensorboard
################################################################################


def _do_nothing(*args: Any, **kwargs: Any) -> None:
    """This function does nothing. One of the functions provided to tf.cond."""
    pass


def _summary_cost(name: str, cost: tf.Tensor) -> None:
    """This function logs a scalar under a name. One of the functions provided to tf.cond."""
    tf.summary.scalar(name, cost)


##############################################################################
# Legacy Code to make some tests
##############################################################################


def old_distance_loss(model, parameters=None):  # pragma: no doccheck
    # choose parameters
    if parameters is None:
        p = Parameters()
    else:
        p = parameters
    # check Layers
    if len(model.layers) == 2:
        # sequential API
        latent = model.encoder
    else:
        # functional API
        latent = model.encoder

    # closure
    def loss(y_true, y_pred=None, step=None):
        loss.name = "distance_loss"
        y_pred = latent(y_true, training=True)
        if p.distance_cost_scale is not None:
            dist_cost = distance_cost(
                y_true, y_pred, *p.dist_sig_parameters, p.periodicity
            )
            dist_cost *= p.distance_cost_scale
        else:
            dist_cost = 0.0
        tf.summary.scalar("Distance Cost", dist_cost)
        return dist_cost

    return loss


##############################################################################
# Public Functions
##############################################################################


def basic_loss_combinator(*losses: Callable) -> Callable:
    """Calculates the sum of a list of losses and returns a combined loss.

    The basic loss combinator does not write to the tensorboard summary. It can
    be used for debugging.
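
    Example:
        A small, illustrative sketch (the two lambdas are arbitrary stand-ins
        for real loss closures):

        >>> from encodermap.loss_functions.loss_functions import basic_loss_combinator
        >>> l1 = lambda y_true, y_pred=None: 1.0
        >>> l2 = lambda y_true, y_pred=None: 2.0
        >>> combined = basic_loss_combinator(l1, l2)
        >>> combined(None)
        3.0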

    """

    def loss(y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None) -> tf.Tensor:
        return sum([loss(y_true, y_pred) for loss in losses])

    return loss


def loss_combinator(*losses: Callable) -> Callable:
    """Calculates the sum of a list of losses and returns a combined loss.

    Args:
        *losses: Variable length argument list of loss functions.

    Returns:
        Callable: A combined loss function that can be used in custom training or with model.fit()

    Example:
        >>> import encodermap as em
        >>> from encodermap import loss_functions
        >>> import tensorflow as tf
        >>> import numpy as np
        >>> tf.random.set_seed(1)  # fix random state to pass doctest :)
        ...
        >>> model = tf.keras.Sequential([
        ...     tf.keras.layers.Dense(100, kernel_regularizer=tf.keras.regularizers.l2(), activation='relu'),
        ...     tf.keras.layers.Dense(2, kernel_regularizer=tf.keras.regularizers.l2(), activation='relu'),
        ...     tf.keras.layers.Dense(100, kernel_regularizer=tf.keras.regularizers.l2(), activation='relu')
        ... ])
        ...
        >>> # Set up losses and bundle them using the loss combinator
        >>> auto_loss = loss_functions.auto_loss(model)
        >>> reg_loss = loss_functions.regularization_loss(model)
        >>> loss = loss_functions.loss_combinator(auto_loss, reg_loss)
        ...
        >>> # Compile the model. model.fit() usually takes a tuple of (data, classes), but in
        >>> # regression learning the data needs to be provided twice. That's why we use fit(data, data).
        >>> model.compile(tf.keras.optimizers.Adam(), loss=loss)
        >>> data = np.random.random((100, 100))
        >>> history = model.fit(x=data, y=data, verbose=0)
        >>> tf.random.set_seed(None)  # reset seed
        ...
        >>> # This weird contraption is there to make the output predictable and pass tests.
        >>> # Somehow the tf.random.seed(1) does not work here. :(
        >>> loss = history.history['loss'][0]
        >>> print(loss)  # doctest: +SKIP
        2.6
        >>> print(type(loss))
        <class 'float'>

    """

    def combined_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        cost = sum([loss(y_true, y_pred) for loss in losses])
        tf.summary.scalar("Combined Cost", cost)
        return cost

    return combined_loss_func


def distance_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap distance_loss

    Transforms the space using the sigmoid function first proposed by sketch-map.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Note:
        If the model contains two layers, the first layer will be assumed to be the decoder.
        If the model contains more layers, one layer needs to be named 'latent' (case insensitive).

    Raises:
        Exception: When no bottleneck/latent layer can be found in the model.

    Returns:
        Callable: A loss function.
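
    Example:
        A minimal usage sketch (illustrative only, not a verified doctest). Any
        object exposing ``layers`` and a callable ``encoder`` works here;
        ``ToyModel`` below is such a hypothetical stand-in:

        >>> import tensorflow as tf
        >>> from encodermap import loss_functions
        >>> class ToyModel:
        ...     def __init__(self):
        ...         self.encoder = tf.keras.Sequential([tf.keras.layers.Dense(2)])
        ...         self.layers = [self.encoder, self.encoder]
        >>> loss = loss_functions.distance_loss(ToyModel())
        >>> data = tf.random.normal((16, 10))
        >>> cost = loss(data)  # doctest: +SKIP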

    References::

        @article{ceriotti2011simplifying,
            title={Simplifying the representation of complex free-energy landscapes using sketch-map},
            author={Ceriotti, Michele and Tribello, Gareth A and Parrinello, Michele},
            journal={Proceedings of the National Academy of Sciences},
            volume={108},
            number={32},
            pages={13023--13028},
            year={2011},
            publisher={National Acad Sciences}
        }

    """
    # choose parameters
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    # check Layers
    if len(model.layers) == 2:
        # sequential API
        latent = model.encoder
    else:
        # functional API
        latent = model.encoder

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # define dist loss
    dist_loss = sigmoid_loss(p)

    # closure
    def distance_loss_func(
        y_true: tf.Tensor,
        y_pred: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """y_true can be whatever input you like: dihedrals, angles, pairwise
        distances, contact maps. It will be transformed with Sketch-map's sigmoid
        function, as will the output of the latent layer of the autoencoder.
        The difference between these two yields the loss.

        """
        distance_loss_func.name = "distance_loss"
        y_pred = latent(y_true, training=True)
        # functional model gives a tuple
        if isinstance(y_true, tuple):
            y_true = tf.concat(y_true[:3], axis=1)
        if p.distance_cost_scale is not None:
            dist_cost = dist_loss(y_true, y_pred)
            dist_cost *= p.distance_cost_scale
        else:
            dist_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Distance Cost", dist_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            dist_cost, message=f"Dist cost became infinite or NaN."
        )
        return dist_cost

    return distance_loss_func


def sigmoid_loss(
    parameters: Optional[AnyParameters] = None,
    periodicity_overwrite: Optional[float] = None,
    dist_dig_parameters_overwrite: Optional[tuple[float, ...]] = None,
) -> Callable:
    """Sigmoid loss closure for use in distance cost and cartesian distance cost.

    The outer function prepares the callable sigmoid. The sigmoid can then be
    called with just y_true and y_pred.

    Args:
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        periodicity_overwrite (Optional[float]): Cartesian distance cost is
            always non-periodic. To make sure no periodicity is applied to the
            data, set periodicity_overwrite to `float('inf')`. If None is provided,
            the periodicity of the parameters class (default 2*pi) will be used.
            Defaults to None.
        dist_dig_parameters_overwrite (Optional[tuple[float, ...]]): Distance costs
            for the `AngleDihedralCartesianEncoderMap` class come in two flavors.
            The regular distance cost compares the encoder inputs to the latent
            and uses Sketch-map's sigmoid function to weight these data accordingly.
            The cartesian distance cost, on the other hand, compares the latent
            and the pairwise distances of the input CA coordinates. This cost
            function uses different sigmoid parameters (because the CA distances
            don't lie in a periodic space). The tuple of 6 floats provided for
            `dist_dig_parameters_overwrite` will supersede the `dist_sig_parameters`
            in the `parameters` argument. Defaults to None.

    Returns:
        Callable: A function that takes `y_true` and `y_pred`.
            Both need to have the same number of samples along the first axis.
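
    Example:
        A minimal, illustrative sketch (not a verified doctest); the tensors
        and their shapes are arbitrary placeholders:

        >>> import tensorflow as tf
        >>> from encodermap.loss_functions.loss_functions import sigmoid_loss
        >>> loss = sigmoid_loss()  # default Parameters, periodic input space
        >>> highd = tf.random.uniform((16, 10), -3.14, 3.14)
        >>> lowd = tf.random.normal((16, 2))
        >>> cost = loss(highd, lowd)  # doctest: +SKIP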

    """
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    if periodicity_overwrite is not None:
        periodicity = periodicity_overwrite
    else:
        periodicity = p.periodicity

    if dist_dig_parameters_overwrite is not None:
        dist_sig_parameters = dist_dig_parameters_overwrite
    else:
        dist_sig_parameters = p.dist_sig_parameters

    # @tf.autograph.experimental.do_not_convert
    def sigmoid_loss_func(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        r_h = y_true
        r_l = y_pred
        if periodicity == float("inf"):
            dist_h = pairwise_dist(r_h)
        else:
            dist_h = pairwise_dist_periodic(r_h, periodicity)
        dist_l = pairwise_dist(r_l)

        sig_h = sigmoid(*dist_sig_parameters[:3])(dist_h)
        sig_l = sigmoid(*dist_sig_parameters[3:])(dist_l)

        cost = tf.reduce_mean(tf.square(sig_h - sig_l))
        tf.debugging.assert_all_finite(
            cost, message=f"Sigmoid cost became infinite or NaN."
        )
        return cost

    return sigmoid_loss_func


def center_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap center_loss

    Use in custom training loops or in model.fit() training.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Note:
        If the model contains two layers, the first layer will be assumed to be
        the decoder. If the model contains more layers, one layer needs to
        be named 'latent' (case-insensitive).

    Raises:
        Exception: When no bottleneck/latent layer can be found in the model.

    Returns:
        Callable: A loss function.

    """
    # choose parameters
    if parameters is None:
        p = Parameters()
    else:
        p = parameters
    # check Layers
    if len(model.layers) == 2:
        # sequential API
        latent = model.encoder
    else:
        # functional API
        latent = model.encoder

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def center_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true will not be used in this loss function. y_pred can be supplied,
        but if it is None, it will be taken from the latent layer. This loss
        function tries to center the points in the latent layer.

        """
        center_loss_func.name = "center_loss"
        y_pred = latent(y_true, training=True)
        # functional model gives a tuple
        if isinstance(y_true, tuple):
            y_true = tf.concat(y_true[:3], axis=1)
        if p.center_cost_scale is not None:
            center_cost = tf.reduce_mean(tf.square(y_pred))
            center_cost *= p.center_cost_scale
        else:
            center_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Center Cost", center_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            center_cost, message=f"Center cost became infinite or NaN."
        )
        return center_cost

    return center_loss_func


def regularization_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Regularization loss of an arbitrary tf.keras.Model.

    Use in custom training loops or in model.fit() training.
    The loss is obtained as tf.math.add_n(model.losses).

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
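
    Example:
        A minimal, illustrative sketch (not a verified doctest); the tiny
        Sequential model below is an arbitrary placeholder:

        >>> import tensorflow as tf
        >>> from encodermap import loss_functions
        >>> model = tf.keras.Sequential([
        ...     tf.keras.layers.Dense(2, kernel_regularizer=tf.keras.regularizers.l2()),
        ... ])
        >>> model.build(input_shape=(None, 5))
        >>> loss = loss_functions.regularization_loss(model)
        >>> cost = loss()  # doctest: +SKIP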

    """
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    def regularization_loss_func(
        y_true: Optional[tf.Tensor] = None, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true and y_pred will not be considered here, because the
        regularization loss is accessed via model.losses.

        """
        regularization_loss_func.name = "regularization_loss"
        reg_loss = tf.math.add_n(model.losses)
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Regularization Cost", reg_loss),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            reg_loss, message=f"Regularization cost became infinite or NaN."
        )
        return reg_loss

    return regularization_loss_func


def reconstruction_loss(model: tf.keras.models.Model) -> Callable:
    """Simple autoencoder reconstruction loss.

    Use in custom training loops or in model.fit training.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.

    Returns:
        Callable: A loss function to be used in custom training or model.fit.
            Function takes the following arguments:
            y_true (tf.Tensor): The true tensor.
            y_pred (tf.Tensor, optional): The output tensor. If not supplied
                the model will be called to get this tensor. Defaults to None.
            step (int): A step for tensorboard callbacks. Defaults to None.

    Examples:
        >>> import tensorflow as tf
        >>> import encodermap as em
        >>> from encodermap import loss_functions
        >>> model = tf.keras.Model()
        >>> loss = loss_functions.reconstruction_loss(model)
        >>> x = tf.random.normal(shape=(10, 10))
        >>> loss(x, x).numpy()
        0.0

    """

    def reconstruction_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        # if y_pred is None, this function is used in custom training
        # and should use model to get the output
        if y_pred is None:
            y_pred = model(y_true)
        # calculate error
        reconstruction_error = tf.reduce_mean(tf.square(tf.subtract(y_pred, y_true)))
        return reconstruction_error

    return reconstruction_loss_func


def auto_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap auto_loss.

    Use in custom training loops or in model.fit() training.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
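
    Example:
        A minimal, illustrative sketch (not a verified doctest). With identical
        input and output the cost is zero; shapes are arbitrary placeholders:

        >>> import tensorflow as tf
        >>> from encodermap import loss_functions
        >>> model = tf.keras.Model()
        >>> loss = loss_functions.auto_loss(model)
        >>> x = tf.random.uniform((10, 10), -3.14, 3.14)
        >>> float(loss(x, x))  # doctest: +SKIP
        0.0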

    """
    if parameters is None:
        p = Parameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    def auto_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true is the complete model input, y_pred is the complete model output.
        Because the output is not meant to be unpacked into dihedrals and angles
        here, y_pred can be None and will then be taken directly from the model.
        """
        auto_loss_func.name = "auto_loss"

        if y_pred is None:
            y_pred = model(y_true)
        if p.auto_cost_scale is not None:
            if p.auto_cost_variant == "mean_square":
                auto_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.auto_cost_variant == "mean_abs":
                auto_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.auto_cost_variant == "mean_norm":
                auto_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "auto_cost_variant {} not available".format(p.auto_cost_variant)
                )
            auto_cost *= p.auto_cost_scale
        else:
            auto_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Auto Cost", auto_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            auto_cost, message=f"Auto cost became infinite or NaN."
        )
        return auto_cost

    return auto_loss_func


def dihedral_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap dihedral loss.

    Calculates distances between true and predicted dihedral angles. Respects
    periodicity in a [-a, a] interval if the provided parameters have a
    periodicity of 2 * a.

    Note:
        The interval should be (-a, a], but due to floating point precision we can't make this
        distinction here.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
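
    Example:
        A minimal, illustrative sketch (not a verified doctest); with identical
        true and predicted dihedrals the cost is zero:

        >>> import tensorflow as tf
        >>> from encodermap import loss_functions
        >>> model = tf.keras.Model()
        >>> loss = loss_functions.dihedral_loss(model)
        >>> dihedrals = tf.random.uniform((8, 10), -3.14, 3.14)
        >>> float(loss(dihedrals, dihedrals))  # doctest: +SKIP
        0.0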

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def dihedral_loss_func(
        y_pred: tf.Tensor, y_true: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_pred should be model input dihedrals, y_true should be model output dihedrals."""
        dihedral_loss_func.name = "dihedral_loss"
        if p.dihedral_cost_scale is not None:
            if p.dihedral_cost_variant == "mean_square":
                dihedral_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.dihedral_cost_variant == "mean_abs":
                dihedral_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.dihedral_cost_variant == "mean_norm":
                dihedral_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "dihedral_cost_variant {} not available".format(
                        p.dihedral_cost_variant
                    )
                )
            dihedral_cost /= p.dihedral_cost_reference
            dihedral_cost *= p.dihedral_cost_scale
        else:
            dihedral_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Dihedral Cost", dihedral_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            dihedral_cost, message=f"Dihedral cost became infinite or NaN."
        )
        return dihedral_cost

    return dihedral_loss_func


def side_dihedral_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap sidechain dihedral loss.

    Calculates distances between true and predicted sidechain dihedral angles.
    Respects periodicity in a [-a, a] interval if the provided parameters have
    a periodicity of 2 * a.

    Note:
        The interval should be (-a, a], but due to floating point precision we can't make this
        distinction here.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def side_dihedral_loss_func(
        y_pred: tf.Tensor, y_true: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_pred should be model input side dihedrals, y_true should be model output side dihedrals."""
        side_dihedral_loss_func.name = "side_dihedral_loss"
        if p.side_dihedral_cost_scale is not None:
            if p.side_dihedral_cost_variant == "mean_square":
                side_dihedral_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.side_dihedral_cost_variant == "mean_abs":
                side_dihedral_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.side_dihedral_cost_variant == "mean_norm":
                side_dihedral_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "side_dihedral_cost_variant {} not available".format(
                        p.side_dihedral_cost_variant
                    )
                )
            side_dihedral_cost /= p.side_dihedral_cost_reference
            side_dihedral_cost *= p.side_dihedral_cost_scale
        else:
            side_dihedral_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost(
                "Sidechain Dihedral Cost", side_dihedral_cost
            ),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            side_dihedral_cost, message=f"Side dihedral cost became infinite or NaN."
        )
        return side_dihedral_cost

    return side_dihedral_loss_func


def angle_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap angle loss.

    Calculates distances between true and predicted angles. Respects periodicity
    in a [-a, a] interval if the provided parameters have a periodicity of 2 * a.

    Note:
        The interval should be (-a, a], but due to floating point precision we can't make this
        distinction here.

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    # closure
    def angle_loss_func(
        y_pred: tf.Tensor, y_true: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true should be input angles. y_pred should be output angles (either from mean input angles or, when
        ADCParameters.use_backbone_angles == True, directly from model output)."""
        angle_loss_func.name = "angle_loss"
        if p.angle_cost_scale is not None:
            if p.angle_cost_variant == "mean_square":
                angle_cost = tf.reduce_mean(
                    tf.square(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.angle_cost_variant == "mean_abs":
                angle_cost = tf.reduce_mean(
                    tf.abs(periodic_distance(y_true, y_pred, p.periodicity))
                )
            elif p.angle_cost_variant == "mean_norm":
                angle_cost = tf.reduce_mean(
                    tf.norm(periodic_distance(y_true, y_pred, p.periodicity), axis=1)
                )
            else:
                raise ValueError(
                    "angle_cost_variant {} not available".format(p.angle_cost_variant)
                )
            angle_cost /= p.angle_cost_reference
            angle_cost *= p.angle_cost_scale
        else:
            angle_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Angle Cost", angle_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            angle_cost, message=f"Angle cost became infinite or NaN."
        )
        return angle_cost

    return angle_loss_func


def cartesian_distance_loss(
    model: tf.keras.models.Model,
    parameters: Optional[AnyParameters] = None,
    callback: Optional[tf.keras.callbacks.Callback] = None,
) -> Callable:
    """Encodermap cartesian distance loss.

    Calculates sigmoid-weighted distances between pairwise cartesians and latent.
    Uses sketch-map's sigmoid function to transform the high-dimensional space
    of the input and the low-dimensional space of latent.

    Note:
        Make sure to provide the pairwise cartesian distances. The output of
        the latent will be compared to the input.

    Note:
        If the model contains two layers, the first layer will be assumed to be
        the decoder. If the model contains more layers, one layer needs to
        be named 'latent' (case-insensitive).

    Args:
        model (tf.keras.Model): A model you want to use the loss function on.
        parameters (Optional[AnyParameters]): The parameters. If None is
            provided default values (check them with `print(em.Parameters.defaults_description())`)
            are used. Defaults to None.
        callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. This saves disk-space, as costs don't need
            to be logged every training step.

    Returns:
        Callable: A loss function.
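
    Example:
        A minimal, illustrative sketch (not a verified doctest); the shapes are
        arbitrary placeholders for pairwise CA distances and latent points:

        >>> import tensorflow as tf
        >>> from encodermap import loss_functions
        >>> model = tf.keras.Model()
        >>> loss = loss_functions.cartesian_distance_loss(model)
        >>> pairwise_ca_distances = tf.random.uniform((8, 45), 0.0, 2.0)
        >>> latent = tf.random.normal((8, 2))
        >>> cost = loss(pairwise_ca_distances, latent)  # doctest: +SKIP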

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = callback.log_bool

    dist_loss = sigmoid_loss(
        p,
        periodicity_overwrite=float("inf"),
        dist_dig_parameters_overwrite=p.cartesian_dist_sig_parameters,
    )

    def cartesian_distance_loss_func(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """y_true can be whatever input you like: dihedrals, angles, pairwise distances, contact maps. It will be
        transformed with Sketch-map's sigmoid function, as will the output of the latent layer of the autoencoder.
        The difference between these two yields the loss."""
        cartesian_distance_loss_func.name = "cartesian_distance_loss"
        if p.cartesian_distance_cost_scale is not None:
            dist_cost = dist_loss(y_true, y_pred)
            dist_cost *= p.cartesian_distance_cost_scale
        else:
            dist_cost = 0.0
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Cartesian Distance Cost", dist_cost),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            dist_cost, message=f"Cartesian distance cost became infinite or NaN."
        )
        return dist_cost

    return cartesian_distance_loss_func


def cartesian_loss(
    model: tf.keras.models.Model,
    scale_callback: Optional[tf.keras.callbacks.Callback] = None,
    parameters: Optional["AnyParameters"] = None,
    log_callback: Optional[tf.keras.callbacks.Callback] = None,
    print_current_scale: bool = False,
) -> Callable:
    """Encodermap cartesian loss.

    Calculates the difference between input and output pairwise distances.
    Adjustments to this cost function via a `soft_start` value need to be made
    via a callback that re-compiles the model during training. The `soft_start`
    value must be either 0 or 1, indexing the 1st or 2nd element of the
    `cartesian_cost_scale_soft_start` tuple. The callback should also be provided
    when `model.fit()` is executed.

    Three cases are possible:
        * Case 1: step < cartesian_cost_scale_soft_start[0]: cost_scale = 0
        * Case 2: cartesian_cost_scale_soft_start[0] <= step <= cartesian_cost_scale_soft_start[1]:
            cost_scale = p.cartesian_cost_scale / (cartesian_cost_scale_soft_start[1] - cartesian_cost_scale_soft_start[0]) * step
        * Case 3: cartesian_cost_scale_soft_start[1] < step: cost_scale = p.cartesian_cost_scale

    Note:
        Make sure to provide the pairwise cartesian distances. This function will be
        adjusted as training increases via a callback. See
        `encodermap.callbacks.callbacks.IncreaseCartesianCost` for more info.

    Args:
        model (tf.keras.Model): The model to use the loss function on.
        scale_callback (Optional[encodermap.callbacks.IncreaseCartesianCost]): A
            callback that provides the current cartesian cost scale. The
            `encodermap.parameters.ADCParameters` class contains a two-tuple of
            integers (`cartesian_cost_scale_soft_start`) used to scale this loss
            function. If `soft_start` is 0, the first value of
            `ADCParameters.cartesian_cost_scale_soft_start` will be used. If it is 1,
            the second. If it is None, or both values of
            `ADCParameters.cartesian_cost_scale_soft_start` are None, the cost will
            not be scaled. If no callback is provided,
            `parameters.cartesian_cost_scale` is used directly. Defaults to None.
        parameters (Optional[AnyParameters]): The parameters. If None is provided,
            default values (check them with
            `print(em.ADCParameters.defaults_description())`) are used. Defaults
            to None.
        log_callback (Optional[tf.keras.callbacks.Callback]): A `write_bool` callback,
            that prevents a tensorboard write when `parameters.summary_step` is
            set to greater values. Defaults to None.
        print_current_scale (bool, optional): Whether to print the current scale.
            Is used in testing. Defaults to False.

    Raises:
        Exception: When no bottleneck/latent layer can be found in the model.
        Exception: When soft_start is greater than 1 and can't index the two-tuple.

    Returns:
        Callable: A loss function. Can be used in either custom training or model.fit().
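
    Example:
        A minimal, illustrative sketch (not a verified doctest); with identical
        input and output pairwise distances the cost is zero:

        >>> import tensorflow as tf
        >>> from encodermap import loss_functions
        >>> model = tf.keras.Model()
        >>> loss = loss_functions.cartesian_loss(model)
        >>> pairwise = tf.random.uniform((8, 45), 0.0, 2.0)
        >>> float(loss(pairwise, pairwise))  # doctest: +SKIP
        0.0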

    """
    if parameters is None:
        p = ADCParameters()
    else:
        p = parameters

    if scale_callback is not None:
        current_scale_callback = scale_callback.current_cartesian_cost_scale
    else:
        current_scale_callback = K.constant(
            p.cartesian_cost_scale, dtype="float32", name="current_cartesian_cost_scale"
        )

    if print_current_scale:
        print(current_scale_callback)

    if log_callback is None:
        write_bool = K.constant(False, "bool", name="log_bool")
    else:
        write_bool = log_callback.log_bool

    def cartesian_loss_func(
        y_true: tf.Tensor, y_pred: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """y_true should be pairwise distances of input cartesians,
        y_pred should be pairwise distances of back-mapped output cartesians."""
        scale = current_scale_callback
        if p.cartesian_cost_variant == "mean_square":
            cartesian_cost = tf.reduce_mean(tf.square(y_true - y_pred))
        elif p.cartesian_cost_variant == "mean_abs":
            cartesian_cost = tf.reduce_mean(tf.abs(y_true - y_pred))
        elif p.cartesian_cost_variant == "mean_norm":
            cartesian_cost = tf.reduce_mean(tf.norm(y_true - y_pred, axis=1))
        else:
            raise ValueError(
                "cartesian_cost_variant {} not available".format(
                    p.cartesian_cost_variant
                )
            )
        cartesian_cost /= p.cartesian_cost_reference
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost(
                "Cartesian Cost before scaling", cartesian_cost
            ),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost("Cartesian Cost current scaling", scale),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        cartesian_cost *= scale
        tf.cond(
            write_bool,
            true_fn=lambda: _summary_cost(
                "Cartesian Cost after scaling", cartesian_cost
            ),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        tf.debugging.assert_all_finite(
            cartesian_cost, message=f"Cartesian cost became infinite or NaN."
        )
        return cartesian_cost

    return cartesian_loss_func