Coverage for encodermap/callbacks/metrics.py: 18%

206 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2025-05-15 21:06 +0200

1# -*- coding: utf-8 -*- 

2# encodermap/callbacks/metrics.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Metrics are meta-variables that can be computed to observe the training but 

23are not directly linked to loss/cost and gradients. 

24 

25""" 

26 

27 

28################################################################################ 

29# Imports 

30################################################################################ 

31 

32 

33# Future Imports at the top 

34from __future__ import annotations 

35 

36# Standard Library Imports 

37from typing import TYPE_CHECKING, Any, Literal, Optional, Union 

38 

39# Third Party Imports 

40import numpy as np 

41import tensorflow as tf 

42import tensorflow.keras.backend as K 

43 

44# Encodermap imports 

45from encodermap.misc.distances import pairwise_dist 

46 

47 

48################################################################################ 

49# Typing 

50################################################################################ 

51 

52 

53if TYPE_CHECKING: 

54 # Encodermap imports 

55 from encodermap import ADCParameters, Parameters 

56 

57 

58################################################################################ 

59# Constants 

60################################################################################ 

61 

62 

# Standard atomic masses (in u) used to mass-weight backbone atoms, in the
# repeating order N, CA, C. Carbon is 12.011 u; the previously used 24.305 u
# is the atomic mass of magnesium and over-weighted both carbon atoms.
WEIGHTS: list[float] = [14.0067, 12.011, 12.011]

64 

65 

66################################################################################ 

67# Metrics 

68################################################################################ 

69 

70 

def kabsch_weighted(
    P: np.ndarray, Q: np.ndarray, W: Optional[np.ndarray] = None
) -> np.float32:
    """Compute the weighted RMSD of two paired point sets after optimal superposition.

    Adapted from Jimmy C. Kromann's RMSD package (https://github.com/charnley/rmsd).
    Uses the Kabsch algorithm on two sets of paired points P and Q. Each set
    is an (n_points, 3) array. An optional vector of per-point weights W may
    be provided; without it, every point is weighted equally.

    P and Q do not need to be centered beforehand. The weighted centroids
    are removed analytically inside the covariance and sum-of-squares terms.

    Only the RMSD is returned. The rotation matrix and translation vector of
    the superposition are not computed.

    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm

    Args:
        P (np.ndarray): Points A with shape (n_points, 3).
        Q (np.ndarray): Points B with shape (n_points, 3).
        W (Optional[np.ndarray]): Weights with shape (n_points, ). Defaults
            to uniform weights ``1 / n_points``.

    Returns:
        np.float32: The RMSD value in the same units as the input points.

    """
    P = np.asarray(P)
    Q = np.asarray(Q)
    if W is None:
        W = np.ones(len(P)).astype("float32") / len(P)
    # Repeat the per-point weights for each of the three coordinates.
    W = np.array([W, W, W]).T.astype("float32")
    # W.sum() equals 3 * sum(weights), so iw == 1 / sum(weights).
    iw = 3.0 / W.sum()

    # Weighted covariance matrix: C[i, k] = sum_j P[j, i] * W[j, i] * Q[j, k].
    # Vectorized replacement for the former triple Python loop.
    C = (P * W).T @ Q
    CMP = (P * W).sum(axis=0)
    CMQ = (Q * W).sum(axis=0)
    # Weighted sums of squares about the weighted centroids.
    PSQ = (P * P * W).sum() - (CMP * CMP).sum() * iw
    QSQ = (Q * Q * W).sum() - (CMQ * CMQ).sum() * iw
    # Remove the centroid contribution and normalize by the total weight.
    C = (C - np.outer(CMP, CMQ) * iw) * iw

    # Singular value decomposition of the covariance matrix. If the product
    # of the determinants of the singular vector matrices is negative, the
    # optimal transform would be a reflection. Flipping the sign of the
    # smallest singular value restricts the result to a proper rotation.
    U_svd, S, Vt_svd = np.linalg.svd(C)
    if np.linalg.det(U_svd) * np.linalg.det(Vt_svd) < 0.0:
        S[-1] = -S[-1]

    msd = (PSQ + QSQ) * iw - 2.0 * S.sum()
    # Clamp small negative values caused by floating-point cancellation.
    if msd < 0.0:
        msd = 0.0
    return np.sqrt(msd).astype("float32")

145 

146 

def rmsd(a, b, translate):
    """Compute the mass-weighted RMSD between two coordinate sets.

    Args:
        a (np.ndarray): Coordinates with shape (n_atoms, 3). The per-atom
            weights tile the module-level ``WEIGHTS`` every three atoms, so
            atoms are presumably ordered as repeating backbone triplets.
        b (np.ndarray): Coordinates with shape (n_atoms, 3), paired with `a`.
        translate (bool): Whether to subtract each set's centroid (mean over
            atoms) before the Kabsch step. ``kabsch_weighted`` removes the
            weighted centroids itself, so this only affects floating-point
            round-off.

    Returns:
        np.float32: The RMSD in the units of the input coordinates.

    """
    weights = np.tile(WEIGHTS, a.shape[0] // 3).astype("float32")
    if translate:
        # Center over the atom axis (not over x/y/z of each atom). Build new
        # arrays instead of subtracting in place: the inputs are often views
        # into the caller's trajectory (see `rmsd_numpy`).
        a = a - np.mean(a, axis=-2, keepdims=True)
        b = b - np.mean(b, axis=-2, keepdims=True)
    return kabsch_weighted(a, b, weights)

153 

154 

def rmsd_numpy(a: np.ndarray, b: np.ndarray, translate: bool = True) -> np.ndarray:
    """Frame-wise RMSD between two stacks of coordinates (Kabsch-Umeyama).

    The two stacks are walked in lockstep along their first axis. Each pair
    of frames ``(a[i], b[i])`` is handed to :func:`rmsd`, which superimposes
    them with the Kabsch algorithm before measuring the deviation. Iteration
    stops at the shorter of the two stacks.

    Args:
        a (np.ndarray): First stack of frames.
        b (np.ndarray): Second stack of frames, paired with `a`.
        translate (bool): Passed through to :func:`rmsd`. Defaults to True.

    Returns:
        np.ndarray: One RMSD value per frame pair.

    References:
        @article{kabsch1976solution,
          title={A solution for the best rotation to relate two sets of vectors},
          author={Kabsch, Wolfgang},
          journal={Acta Crystallographica Section A: Crystal Physics, Diffraction, Theoretical and General Crystallography},
          volume={32},
          number={5},
          pages={922--923},
          year={1976},
          publisher={International Union of Crystallography}
        }

    """
    per_frame = []
    for frame_a, frame_b in zip(a, b):
        per_frame.append(rmsd(frame_a, frame_b, translate=translate))
    return np.asarray(per_frame)

173 

174 

def kabsch_tf(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
    """Compute the mass-weighted RMSD between two coordinate sets in TensorFlow.

    This is the TensorFlow counterpart of ``kabsch_weighted``. Per-atom
    weights tile the module-level ``WEIGHTS`` every three atoms. The RMSD is
    evaluated after the optimal (proper) rotation found via SVD.

    Args:
        a (tf.Tensor): Coordinates with shape (n_atoms, 3). Presumably
            float32 to match the weights; TODO confirm at call sites.
        b (tf.Tensor): Coordinates with shape (n_atoms, 3), paired with `a`.

    Returns:
        tf.Tensor: Scalar RMSD in the units of the input coordinates.

    """
    # weights repeat N-CA-C
    weights = tf.tile(WEIGHTS, [tf.shape(a)[0] // 3])
    # Repeat the per-atom weights for each coordinate: shape (n_atoms, 3).
    W = tf.tile(tf.expand_dims(weights, -1), [1, 3])
    # reduce_sum(W) == 3 * sum(weights), so iw == 1 / sum(weights).
    iw = 3.0 / tf.reduce_sum(W)

    # Center coordinates over the atom axis. The weighted centroids are also
    # removed analytically below, so this only improves numerical conditioning.
    a = a - tf.reduce_mean(a, axis=-2, keepdims=True)
    b = b - tf.reduce_mean(b, axis=-2, keepdims=True)

    # Weighted covariance matrix: C[i, k] = sum_j a[j, i] * W[j, i] * b[j, k].
    # A single matmul replaces the former per-element scatter loop.
    C = tf.matmul(a * W, b, transpose_a=True)
    CMP = tf.reduce_sum(a * W, axis=0)
    CMQ = tf.reduce_sum(b * W, axis=0)
    PSQ = tf.reduce_sum(tf.square(a) * W) - tf.reduce_sum(tf.square(CMP)) * iw
    QSQ = tf.reduce_sum(tf.square(b) * W) - tf.reduce_sum(tf.square(CMQ)) * iw
    C = (C - tf.tensordot(CMP, CMQ, axes=0) * iw) * iw

    # tf.linalg.svd returns (singular values, u, v). A negative product of
    # determinants means the optimal transform is a reflection. Flipping the
    # sign of the smallest singular value restricts it to a proper rotation.
    # tf.where keeps this branch graph-compatible.
    S, U_svd, V_svd = tf.linalg.svd(C, full_matrices=False)
    is_reflection = (tf.linalg.det(U_svd) * tf.linalg.det(V_svd)) < 0.0
    S_mul = tf.convert_to_tensor([1, 1, -1], dtype=S.dtype)
    S = tf.where(is_reflection, S * S_mul, S)

    msd = (PSQ + QSQ) * iw - 2.0 * tf.reduce_sum(S)
    # Clamp small negative values caused by floating-point cancellation.
    return tf.sqrt(tf.maximum(msd, 0.0))

227 

228 

def rmsd_tf(
    a: tf.Tensor,
    b: tf.Tensor,
    p: ADCParameters,
) -> tf.Tensor:
    """Per-sample RMSD for the first ``p.batch_size`` entries of two batches.

    Args:
        a (tf.Tensor): First batch. Sample ``i`` is ``tf.gather_nd(a, [i])``.
        b (tf.Tensor): Second batch, paired with `a` sample by sample.
        p (ADCParameters): Only ``p.batch_size`` is read. It sets how many
            samples are compared.

    Returns:
        tf.Tensor: The stacked ``kabsch_tf`` results, one per sample.

    """
    return tf.stack(
        [
            kabsch_tf(tf.gather_nd(a, [index]), tf.gather_nd(b, [index]))
            for index in range(p.batch_size)
        ]
    )

242 

243 

244################################################################################ 

245# Metric Classes 

246################################################################################ 

247 

248 

@tf.keras.utils.register_keras_serializable()
class EncoderMapBaseMetric(tf.keras.metrics.Metric):
    """Base class for metrics in EncoderMap.

    Metrics are used to judge the performance of ML models. They are similar
    to loss functions as they can (but don't have to) be computed at every
    iteration of the training. Oftentimes, metrics implement more expensive
    calculations. Metrics are also automatically added to a model's training
    history, accessible via `history = emap.train()`.

    Subclasses provide an ``update(y_true, y_pred, sample_weight=None)``
    method that returns a scalar. :meth:`update_state` stores that scalar as
    the metric's result and increments an internal training-step counter.

    Examples:

        In this example, the `update` method always returns zero.

        >>> import encodermap as em
        >>> import numpy as np
        ...
        >>> class MyMetric(em.callbacks.EncoderMapBaseMetric):
        ...     def update(self, y_true, y_pred, sample_weight=None):
        ...         return 0.0
        ...
        >>> emap = em.EncoderMap()  # doctest: +ELLIPSIS
        Output...
        >>> emap.add_metric(MyMetric)
        >>> history = emap.train()  # doctest: +ELLIPSIS
        Saving...
        >>> np.mean(history["MyMetric Metric"])
        0.0
        >>> len(history["MyMetric Metric"]) == emap.p.n_steps
        True

    """

    custom_update_state = True

    def __init__(
        self,
        parameters: Optional[Parameters],
        name: str | None = None,
        current_training_step: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Instantiate the metric.

        Args:
            parameters (Optional[Parameters]): Parameters of the EncoderMap
                model. If None, a default ``Parameters()`` instance is used.
            name (str | None): Name of the metric. Defaults to
                ``"<ClassName> Metric"``.
            current_training_step (Optional[int]): Start value of the
                internal training-step counter. If None,
                ``self.p.current_training_step`` is used.
            **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

        Raises:
            Exception: If both `parameters` and `current_training_step` are
                given and their training steps differ.
            Exception: If ``custom_update_state`` is False and the instance
                has no ``update`` method.

        """
        # Encodermap imports
        # `Parameters` is only imported under TYPE_CHECKING at module level,
        # so it must be imported here before it can be instantiated.
        from encodermap import Parameters

        if name is None:
            name = f"{self.__class__.__name__} Metric"
        super(EncoderMapBaseMetric, self).__init__(name=name, **kwargs)
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        if current_training_step is None:
            self._my_training_step = K.variable(
                self.p.current_training_step, "int64", name="train_counter"
            )
        else:
            if parameters is not None:
                if current_training_step != parameters.current_training_step:
                    raise Exception(
                        f"Instantiation of {self.__class__.__name__} got different "
                        f"values for current training steps. In parameters, the "
                        f"training step is {parameters.current_training_step}, in "
                        f"the arguments, I got {current_training_step}"
                    )
            self._my_training_step = K.variable(
                current_training_step, "int64", name="train_counter"
            )
        self.custom_metric_scalar = self.add_weight(
            name=f"custom_metric_vector_{self.__class__.__name__}",
            initializer="zeros",
            dtype=tf.float32,
        )
        if not self.custom_update_state and not hasattr(self, "update"):
            raise Exception(
                "Please implement an `update` method, that returns a scalar, when "
                "subclassing this metric."
            )

    @classmethod
    def from_config(
        cls,
        config: dict[str, Any],
        custom_objects: Optional[dict[Any, Any]] = None,
    ):
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[str, Any]): A dictionary, as produced by
                :meth:`get_config`. Its ``"parameters"`` entry is turned back
                into a ``Parameters`` instance.
            custom_objects (Optional[dict[str, Any]]): Not needed here, but see
                https://keras.io/guides/serialization_and_saving/ for yourself.

        """
        # Encodermap imports
        # Runtime import; the module-level import exists only for type checking.
        from encodermap import Parameters

        parameters = config.pop("parameters")
        parameters = Parameters(**parameters)
        return cls(
            name=config.pop("name"),
            parameters=parameters,
            current_training_step=None,
            **config,
        )

    def get_config(self) -> dict[str, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[str, Any]: The base Keras metric config plus the parameters
                serialized via ``self.p.to_dict()``.

        """
        config = super().get_config().copy()
        config["parameters"] = self.p.to_dict()
        return config

    def update_state(
        self,
        y_true: tf.Tensor,
        y_pred: tf.Tensor,
        sample_weight=None,
    ) -> None:
        """Stores the scalar returned by ``self.update`` and advances the step counter.

        Args:
            y_true (tf.Tensor): The train goal.
            y_pred (tf.Tensor): Current model output.
            sample_weight: Accepted for Keras compatibility; not forwarded to
                ``update``.

        """
        self.custom_metric_scalar.assign(self.update(y_true, y_pred))
        self._my_training_step.assign_add(1)

    def result(self):
        """Returns the scalar stored by the last :meth:`update_state` call."""
        return self.custom_metric_scalar

371 

372 

@tf.keras.utils.register_keras_serializable()
class AngleDihedralCartesianEncoderMapBaseMetric(tf.keras.metrics.Metric):
    """Base class for metrics of the AngleDihedralCartesianEncoderMap.

    Holds the ``ADCParameters`` (``self.p``) and an int64 training-step
    counter (``self._my_training_step``). With ``custom_update_state = False``
    (the default here), an instance must provide an ``update`` method.
    Subclasses that set ``custom_update_state = True`` skip that check and
    implement ``update_state`` themselves.

    """

    custom_update_state = False

    def __init__(
        self,
        parameters: Optional[ADCParameters],
        name: str | None = None,
        current_training_step: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Instantiate the metric.

        Args:
            parameters (Optional[ADCParameters]): Parameters of the model. If
                None, a default ``ADCParameters()`` instance is used.
            name (str | None): Name of the metric. Defaults to
                ``"<ClassName> Metric"``.
            current_training_step (Optional[int]): Start value of the
                internal training-step counter. If None,
                ``self.p.current_training_step`` is used.
            **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

        Raises:
            Exception: If both `parameters` and `current_training_step` are
                given and their training steps differ.
            Exception: If ``custom_update_state`` is False and the instance
                has no ``update`` method.

        """
        # Encodermap imports
        from encodermap import ADCParameters

        if name is None:
            name = f"{self.__class__.__name__} Metric"
        super(AngleDihedralCartesianEncoderMapBaseMetric, self).__init__(
            name=name, **kwargs
        )
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        if current_training_step is None:
            self._my_training_step = K.variable(
                self.p.current_training_step, "int64", name="train_counter"
            )
        else:
            if parameters is not None:
                if current_training_step != parameters.current_training_step:
                    raise Exception(
                        f"Instantiation of {self.__class__.__name__} got different "
                        f"values for current training steps. In parameters, the "
                        f"training step is {parameters.current_training_step}, in "
                        f"the arguments, I got {current_training_step}"
                    )
            self._my_training_step = K.variable(
                current_training_step, "int64", name="train_counter"
            )
        if not self.custom_update_state and not hasattr(self, "update"):
            raise Exception(
                "Please implement an `update` method, that returns a scalar, when "
                "subclassing this metric."
            )

    @classmethod
    def from_config(
        cls,
        config: dict[str, Any],
        custom_objects: Optional[dict[Any, Any]] = None,
    ):
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[str, Any]): A dictionary, as produced by
                :meth:`get_config`. Its ``"parameters"`` entry is turned back
                into an ``ADCParameters`` instance.
            custom_objects (Optional[dict[str, Any]]): Not needed here, but see
                https://keras.io/guides/serialization_and_saving/ for yourself.

        """
        # Encodermap imports
        from encodermap import ADCParameters

        parameters = config.pop("parameters")
        parameters = ADCParameters(**parameters)
        return cls(
            name=config.pop("name"),
            parameters=parameters,
            current_training_step=None,
            **config,
        )

    def get_config(self) -> dict[str, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[str, Any]: The base Keras metric config plus the parameters
                serialized via ``self.p.to_dict()``.

        """
        config = super().get_config().copy()
        config["parameters"] = self.p.to_dict()
        return config

464 

465 

class OmegaAngleBaseMetric(AngleDihedralCartesianEncoderMapBaseMetric):
    """Base class named for omega-angle metrics of the ADC EncoderMap.

    Setting ``custom_update_state`` to True skips the parent's check for an
    ``update`` method. Subclasses presumably implement ``update_state``
    themselves (no subclass is visible in this module).

    """

    custom_update_state: bool = True

468 

469 

class SidechainVsBackboneFrequencyBaseMetric(
    AngleDihedralCartesianEncoderMapBaseMetric,
):
    """Base class named for sidechain-vs-backbone frequency metrics.

    Setting ``custom_update_state`` to True skips the parent's check for an
    ``update`` method. Subclasses presumably implement ``update_state``
    themselves (no subclass is visible in this module).

    """

    custom_update_state: bool = True

474 

475 

class ADCClashMetric(AngleDihedralCartesianEncoderMapBaseMetric):
    """Metric that computes clashes between atoms in the reconstructed backbone.

    A clash is counted for every pairwise distance below ``clash_distance``.
    That threshold is 0.1 for ``distance_unit="nm"`` and 1 for
    ``distance_unit="ang"``. Please choose the correct distance unit.

    """

    custom_update_state = True

    def __init__(
        self,
        distance_unit: Literal["nm", "ang"],
        name: str = "ADCClashMetric",
        parameters: Optional[ADCParameters] = None,
        **kwargs,
    ):
        """Instantiate the clash metric.

        Args:
            distance_unit (Literal["nm", "ang"]): Unit of the reconstructed
                coordinates. Selects the clash threshold.
            name (str): Name of the metric. Defaults to 'ADCClashMetric'.
            parameters (Optional[ADCParameters]): Forwarded to the base
                class. Defaults to None (default ``ADCParameters``).
            **kwargs: Forwarded to the base class.

        Raises:
            Exception: If `distance_unit` is neither 'nm' nor 'ang'.

        """
        super().__init__(name=name, parameters=parameters, **kwargs)
        self.clashes = self.add_weight(
            name="clashes",
            shape=(),
            initializer="zeros",
            dtype=tf.int64,
        )
        self.distance_unit = distance_unit
        if self.distance_unit == "nm":
            self.clash_distance = 0.1
        elif self.distance_unit == "ang":
            self.clash_distance = 1
        else:
            raise Exception(
                f"Argument `distance_unit` must be either 'nm' or 'ang'. You "
                f"provided {self.distance_unit=}."
            )

    @classmethod
    def from_config(
        cls,
        config: dict[str, Any],
        custom_objects: Optional[dict[Any, Any]] = None,
    ):
        """Reconstructs this keras serializable from a dict.

        Args:
            config (dict[str, Any]): A dictionary, as produced by
                :meth:`get_config`. Must contain ``"parameters"``,
                ``"distance_unit"`` and ``"name"``.
            custom_objects (Optional[dict[str, Any]]): Not needed here, but see
                https://keras.io/guides/serialization_and_saving/ for yourself.

        """
        # Encodermap imports
        from encodermap import ADCParameters

        parameters = config.pop("parameters")
        distance_unit = config.pop("distance_unit")
        parameters = ADCParameters(**parameters)
        return cls(
            name=config.pop("name"),
            parameters=parameters,
            distance_unit=distance_unit,
            current_training_step=None,
            **config,
        )

    def get_config(self) -> dict[str, Any]:
        """Serializes this keras serializable.

        Returns:
            dict[str, Any]: The base config plus the serialized parameters
                and the ``distance_unit``.

        """
        config = super().get_config().copy()
        config["parameters"] = self.p.to_dict()
        config["distance_unit"] = self.distance_unit
        return config

    def update_state(
        self,
        y_true: tf.Tensor,
        y_pred: tf.Tensor,
        sample_weight=None,
    ) -> None:
        """Updates this metric.

        Args:
            y_true (tf.Tensor): The train goal. Not used.
            y_pred (tf.Tensor): Current model output. If a list/tuple, its
                third entry is used (presumably the cartesian coordinates;
                TODO confirm against the model outputs).
            sample_weight: Accepted for Keras compatibility; not used.

        """
        if isinstance(y_pred, (list, tuple)):
            y_pred = y_pred[2]
        # Count, per sample along axis 1, the pairwise distances below the
        # threshold. Presumably pairwise_dist(flat=True) yields
        # (batch, n_pairs) — verify in encodermap.misc.distances.
        clashes = tf.math.count_nonzero(
            pairwise_dist(y_pred, flat=True) < self.clash_distance,
            axis=1,
            dtype=tf.int64,
        )
        with tf.name_scope("Metrics"):
            tf.summary.scalar(
                "Mean Clashes", tf.reduce_mean(clashes), step=self._my_training_step
            )
            tf.summary.histogram("Batch Clashes", clashes, step=self._my_training_step)
        # reduce_mean of the int64 counts is an integer (truncated) mean,
        # matching the int64 dtype of `self.clashes`.
        self.clashes.assign(tf.reduce_mean(clashes))
        self._my_training_step.assign_add(1)

    def result(self):
        """Returns the batch-mean clash count from the last update."""
        return self.clashes

578 

579 

class ADCRMSDMetric(AngleDihedralCartesianEncoderMapBaseMetric):
    """Per-sample RMSD between target and reconstructed coordinates.

    The RMSD is computed by ``rmsd_tf``. It superimposes each sample with a
    mass-weighted Kabsch algorithm, using the module-level ``WEIGHTS``
    repeated every three atoms.

    """

    custom_update_state = True

    def __init__(
        self,
        name: str = "ADCRMSDMetric",
        parameters: Optional[Union[Parameters, ADCParameters]] = None,
        **kwargs: Any,
    ) -> None:
        """Instantiate the RMSD metric.

        Note:
            The output is in the units of the input coordinates (presumably
            nm for EncoderMap models; TODO confirm).

        Args:
            name (str): Name of the metric. Defaults to 'ADCRMSDMetric'.
            parameters (Optional[Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]]):
                An instance of a parameter class. Its ``batch_size`` sets the
                length of the per-sample RMSD vector. Defaults to None, in
                which case the base class creates a default ``ADCParameters``.
            **kwargs: Forwarded to the base class.

        """
        super(ADCRMSDMetric, self).__init__(
            name=name,
            parameters=parameters,
            **kwargs,
        )
        # The base class already stored `parameters` (or a default
        # ADCParameters instance) as `self.p`, so no second default is needed.
        self.rmsd = self.add_weight(
            name="rmsd",
            shape=(self.p.batch_size,),
            initializer="zeros",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Compute the per-sample RMSD and write TensorBoard summaries.

        Args:
            y_true: The train goal. If a list/tuple, its third entry is used
                (presumably the cartesian coordinates; TODO confirm).
            y_pred: Current model output, handled like `y_true`.
            sample_weight: Accepted for Keras compatibility; not used.

        """
        if isinstance(y_pred, (list, tuple)):
            y_pred = y_pred[2]
        if isinstance(y_true, (list, tuple)):
            y_true = y_true[2]
        rmsd = rmsd_tf(y_true, y_pred, p=self.p)
        self.rmsd.assign(rmsd)
        with tf.name_scope("Metrics"):
            # Pass an explicit step, like ADCClashMetric does. Without a step,
            # tf.summary raises when a writer is active and no default step is set.
            tf.summary.scalar(
                "Mean RMSDs", tf.reduce_mean(self.rmsd), step=self._my_training_step
            )
            tf.summary.histogram(
                "Batch RMSDs", self.rmsd, step=self._my_training_step
            )
        self._my_training_step.assign_add(1)

    def result(self):
        """Returns the per-sample RMSD vector from the last update."""
        return self.rmsd