Coverage for encodermap/loss_functions/loss_classes.py: 36%

105 statements  

coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

# -*- coding: utf-8 -*-
# encodermap/loss_functions/loss_classes.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
22"""Losses for encodermap. 

23 

24All losses in EncoderMap inherit from `tf.keras.losses.Loss` and thus can be 

25easily paired with other models. 
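
Example:
    A minimal sketch of calling one of these losses directly (the zero tensors
    are hypothetical placeholders; the environment variable is required
    because these classes are gated by the `testing` decorator below):

    >>> import os
    >>> os.environ["ENCODERMAP_TESTING"] = "True"
    >>> import tensorflow as tf
    >>> from encodermap.parameters.parameters import ADCParameters
    >>> loss = DihedralLoss(parameters=ADCParameters())
    >>> y = (tf.zeros((2, 5)), tf.zeros((2, 5)))  # (angles, dihedrals)
    >>> float(loss.call(y, y))
    0.0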

"""
################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
import functools
import inspect
import os
from collections.abc import Sequence
from typing import Any, Optional, Union

# Third Party Imports
import tensorflow as tf
import tensorflow.keras.backend as K

# Encodermap imports
from encodermap.misc.distances import (
    pairwise_dist,
    pairwise_dist_periodic,
    periodic_distance,
    sigmoid,
)
from encodermap.parameters.parameters import ADCParameters, Parameters


################################################################################
# Typing
################################################################################


################################################################################
# Globals
################################################################################


__all__: list[str] = ["DihedralLoss", "AngleLoss"]


################################################################################
# Utils
################################################################################


def testing(cls_or_func):
    """Gates actively developed classes and functions behind the
    `ENCODERMAP_TESTING` environment variable.

    Decorated classes raise at instantiation and decorated functions raise
    when called, unless the variable is set to 'True'.

    """
    if inspect.isclass(cls_or_func):
        orig_init = cls_or_func.__init__

        @functools.wraps(cls_or_func)
        def __init__(self, *args, **kwargs):
            if os.getenv("ENCODERMAP_TESTING", "False") != "True":
                raise Exception(
                    f"You are instantiating an em.testing class ({cls_or_func.__name__}). "
                    f"These classes are actively developed and not stable. If you "
                    f"know what you are doing, set the environment variable "
                    f"'ENCODERMAP_TESTING' to 'True'."
                )
            return orig_init(self, *args, **kwargs)

        cls_or_func.__init__ = __init__
        return cls_or_func
    else:

        @functools.wraps(cls_or_func)
        def newfunc(*args, **kwargs):
            if os.getenv("ENCODERMAP_TESTING", "False") != "True":
                raise Exception(
                    f"You are calling an em.testing function ({cls_or_func.__name__}). "
                    f"These functions are actively developed and not stable. If you "
                    f"know what you are doing, set the environment variable "
                    f"'ENCODERMAP_TESTING' to 'True'."
                )
            return cls_or_func(*args, **kwargs)

        return newfunc
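
# A minimal gating sketch for the `testing` decorator above (the
# `Experimental` class is hypothetical):
#
#     import os
#     os.environ["ENCODERMAP_TESTING"] = "True"
#
#     @testing
#     class Experimental:
#         pass
#
#     Experimental()  # succeeds; without the variable set, this raises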


def _do_nothing(*args) -> None:
    """This function does nothing. One of the functions provided to `tf.cond`."""
    pass


def _summary_cost(
    name: str,
    cost: tf.Tensor,
) -> None:
    """This function logs a scalar under a name. One of the functions provided to `tf.cond`.

    Args:
        name (str): The name to log the scalar as.
        cost (tf.Tensor): The scalar tensor to log.

    """
    tf.summary.scalar(name, cost)


################################################################################
# Public Classes
################################################################################


@testing
@tf.keras.utils.register_keras_serializable()
class EncoderMapBaseLoss(tf.keras.losses.Loss):
135 """EncoderMap's base loss. Serializes parameters and `self._train_counter`. 

136 

137 It Can be subclassed to implement custom loss functions that have access to 

138 EncoderMap's parameter classes. 
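
    Example:
        A minimal subclassing sketch (the `MeanSquareLoss` below is a
        hypothetical illustration, not part of EncoderMap):

        >>> import os
        >>> os.environ["ENCODERMAP_TESTING"] = "True"  # this class is gated
        >>> import tensorflow as tf
        >>> class MeanSquareLoss(EncoderMapBaseLoss):
        ...     def call(self, y_true, y_pred):
        ...         cost = tf.reduce_mean(tf.square(y_true - y_pred))
        ...         return super().call("Mean Square Cost", cost)
        >>> loss = MeanSquareLoss()
        >>> float(loss.call(tf.ones((2, 2)), tf.ones((2, 2))))
        0.0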

    """

    def __init__(
        self,
        parameters: Optional[Union[Parameters, ADCParameters]] = None,
        write_bool_cb: Optional[tf.keras.callbacks.Callback] = None,
        **kwargs,
    ) -> None:
        """Instantiate the loss class.

        Most subclassed losses don't need to override this `__init__()`.

        Args:
            parameters (Optional[Union[encodermap.parameters.Parameters, encodermap.parameters.ADCParameters]]):
                The parameters this class uses to decide how to compute losses.
                If None, a default `Parameters()` instance is used.
            write_bool_cb (Optional[tf.keras.callbacks.Callback]): A callback
                whose `log_bool` attribute decides whether the loss is logged
                to Tensorboard. If None, logging is disabled.

        """
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters

        if write_bool_cb is None:
            self.cb = None
            self.write_bool = K.constant(False, "bool", name="log_bool")
        else:
            self.cb = write_bool_cb
            self.write_bool = write_bool_cb.log_bool

        super().__init__()

    def call(self, loss_name, current_loss) -> float:
        """Use `super().call(loss_name, current_loss)` to log the current loss
        to Tensorboard and advance the train counter.

        Args:
            loss_name (str): The name of the loss, as it should appear in Tensorboard.
            current_loss (float): The current value of the loss.

        Returns:
            float: The current loss.

        """
        tf.cond(
            self.write_bool,
            true_fn=lambda: _summary_cost(loss_name, current_loss),
            false_fn=lambda: _do_nothing(),
            name="Cost",
        )
        return current_loss

    @classmethod
    def from_config(cls, config):
        # ADC parameter dicts are recognized by their ADC-only keys
        # (e.g. "cartesian_pwd_start").
        p = config.pop("p")
        if "cartesian_pwd_start" in p:
            p = ADCParameters(**p)
        else:
            p = Parameters(**p)
        write_bool_cb = tf.keras.saving.deserialize_keras_object(
            config.pop("write_bool_cb")
        )
        return cls(parameters=p, write_bool_cb=write_bool_cb, **config)

    def get_config(self) -> dict[Any, Any]:
        config = super().get_config().copy()
        config.update(
            {
                "p": self.p.to_dict(),
                "write_bool_cb": tf.keras.saving.serialize_keras_object(self.cb),
            }
        )
        return config
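
# A serialization round-trip sketch for the base loss (requires the
# ENCODERMAP_TESTING environment variable, see the `testing` decorator above):
#
#     loss = EncoderMapBaseLoss()
#     restored = EncoderMapBaseLoss.from_config(loss.get_config())
#     assert isinstance(restored.p, Parameters)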


@testing
@tf.keras.utils.register_keras_serializable()
class ADCBaseLoss(EncoderMapBaseLoss):
    """Base class for all Losses of the `AngleDihedralCartesianEncoderMap`.

    Replaces the default `Parameters()` with `ADCParameters()`.

    """

    def __init__(
        self,
        parameters: Optional[ADCParameters] = None,
        write_bool_cb: Optional[tf.keras.callbacks.Callback] = None,
        **kwargs,
    ) -> None:
        super().__init__(parameters, write_bool_cb)
        if parameters is None:
            self.p = ADCParameters()


@testing
@tf.keras.utils.register_keras_serializable()
class DihedralLoss(ADCBaseLoss):
    """EncoderMap's `DihedralLoss` for the `AngleDihedralCartesianEncoderMap`.

    Uses the periodicity in `self.p` to compare the distances of input and
    output dihedrals. Both `inp_dihedrals` and `out_dihedrals` are tensors of
    shape (batch_size, n_dihedrals). The distances between two dihedrals are
    calculated with `d = tf.abs(inp_dihedrals - out_dihedrals)`, which also
    has the shape (batch_size, n_dihedrals). Because dihedrals lie in a
    periodic space with periodicity (-pi, pi], this array needs to be adjusted
    to account for this with `tf.minimum(d, periodicity - d)`. The resulting
    array of shape (batch_size, n_dihedrals) is then reduced according to the
    `dihedral_cost_variant`, which can be 'mean_square', 'mean_abs', or
    'mean_norm'.
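
    Example:
        A numeric sketch of the periodic adjustment in plain Python (the
        values are hypothetical):

        >>> import math
        >>> periodicity = 2 * math.pi
        >>> inp, out = 3.0, -3.0  # close neighbors across the -pi/pi boundary
        >>> d = abs(inp - out)  # the naive distance: 6.0
        >>> round(min(d, periodicity - d), 3)  # the periodic distance
        0.283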

    """

    name = "DihedralLoss"

    def call(self, y_true: Sequence[tf.Tensor], y_pred: Sequence[tf.Tensor]) -> float:
        """Computes the dihedral cost from the dihedral entries of `y_true` and `y_pred`."""
        # the second entries of the input sequences hold the dihedrals
        inp_dihedrals = y_pred[1]
        out_dihedrals = y_true[1]

        if self.p.dihedral_cost_scale is not None:
            if self.p.dihedral_cost_variant == "mean_square":
                dihedral_cost = tf.reduce_mean(
                    tf.square(
                        periodic_distance(
                            inp_dihedrals, out_dihedrals, self.p.periodicity
                        )
                    )
                )
            elif self.p.dihedral_cost_variant == "mean_abs":
                dihedral_cost = tf.reduce_mean(
                    tf.abs(
                        periodic_distance(
                            inp_dihedrals, out_dihedrals, self.p.periodicity
                        )
                    )
                )
            elif self.p.dihedral_cost_variant == "mean_norm":
                dihedral_cost = tf.reduce_mean(
                    tf.norm(
                        periodic_distance(
                            inp_dihedrals, out_dihedrals, self.p.periodicity
                        ),
                        axis=1,
                    )
                )
            else:
                raise ValueError(
                    f"dihedral_cost_variant {self.p.dihedral_cost_variant} not available"
                )
            # normalize by the reference and apply the user-defined scale
            dihedral_cost /= self.p.dihedral_cost_reference
            if self.p.dihedral_cost_scale != 0:
                dihedral_cost *= self.p.dihedral_cost_scale
        else:
            dihedral_cost = 0

        return super().call("Dihedral Cost", dihedral_cost)


@testing
@tf.keras.utils.register_keras_serializable()
class AngleLoss(ADCBaseLoss):
    """EncoderMap's `AngleLoss` for the `AngleDihedralCartesianEncoderMap`.

    Uses the periodicity in `self.p` to compare the distances of input and
    output angles. Both `inp_angles` and `out_angles` are tensors of shape
    (batch_size, n_angles). The distances between two angles are calculated
    with `d = tf.abs(inp_angles - out_angles)`, which also has the shape
    (batch_size, n_angles). Because angles lie in a periodic space with
    periodicity (-pi, pi], this array needs to be adjusted to account for
    this with `tf.minimum(d, periodicity - d)`. The resulting array of shape
    (batch_size, n_angles) is then reduced according to the
    `angle_cost_variant`, which can be 'mean_square', 'mean_abs', or
    'mean_norm'.
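
    Example:
        A sketch of how the three variants reduce a toy (batch_size, n_angles)
        array of periodic distances (values are hypothetical):

        >>> import numpy as np
        >>> d = np.array([[0.1, 0.4], [0.2, 0.3]])
        >>> round(float(np.mean(d**2)), 3)  # 'mean_square'
        0.075
        >>> round(float(np.mean(np.abs(d))), 2)  # 'mean_abs'
        0.25
        >>> round(float(np.mean(np.linalg.norm(d, axis=1))), 3)  # 'mean_norm'
        0.386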

    """

    def call(self, y_true: Sequence[tf.Tensor], y_pred: Sequence[tf.Tensor]) -> float:
        """Computes the angle cost from the angle entries of `y_true` and `y_pred`."""
        # the first entries of the input sequences hold the angles
        inp_angles = y_pred[0]
        out_angles = y_true[0]

        if self.p.angle_cost_scale is not None:
            if self.p.angle_cost_variant == "mean_square":
                angle_cost = tf.reduce_mean(
                    tf.square(
                        periodic_distance(inp_angles, out_angles, self.p.periodicity)
                    )
                )
            elif self.p.angle_cost_variant == "mean_abs":
                angle_cost = tf.reduce_mean(
                    tf.abs(
                        periodic_distance(inp_angles, out_angles, self.p.periodicity)
                    )
                )
            elif self.p.angle_cost_variant == "mean_norm":
                angle_cost = tf.reduce_mean(
                    tf.norm(
                        periodic_distance(inp_angles, out_angles, self.p.periodicity),
                        axis=1,
                    )
                )
            else:
                raise ValueError(
                    f"angle_cost_variant {self.p.angle_cost_variant} not available"
                )
            # normalize by the reference and apply the user-defined scale
            angle_cost /= self.p.angle_cost_reference
            if self.p.angle_cost_scale != 0:
                angle_cost *= self.p.angle_cost_scale
        else:
            angle_cost = 0

        return super().call("Angle Cost", angle_cost)
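
# A combined-usage sketch (hypothetical zero tensors; requires the
# ENCODERMAP_TESTING environment variable, see the `testing` decorator above):
#
#     p = ADCParameters()
#     angle_loss, dihedral_loss = AngleLoss(p), DihedralLoss(p)
#     y = (tf.zeros((4, 10)), tf.zeros((4, 20)))  # (angles, dihedrals)
#     total = angle_loss.call(y, y) + dihedral_loss.call(y, y)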