Coverage for encodermap/models/models.py: 43%

457 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/models/models.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""ToDo: 

23 * Add some nice images to the plot_model of the functional model. 

24 

25""" 

26 

27############################################################################## 

28# Imports 

29############################################################################## 

30 

31 

32import warnings 

33from math import pi 

34 

35import numpy as np 

36import tensorflow as tf 

37import tensorflow.keras.backend as K 

38from tensorflow.keras.layers import Concatenate, Dense, Input, Lambda 

39 

40from ..encodermap_tf1.backmapping import chain_in_plane, dihedrals_to_cartesian_tf 

41from ..misc import pairwise_dist 

42from ..misc.summaries import add_layer_summaries 

43from ..parameters.parameters import ADCParameters, Parameters 

44from ..trajinfo.info_all import Capturing 

45from .layers import ( 

46 BackMapLayer, 

47 MeanAngles, 

48 PairwiseDistances, 

49 PeriodicInput, 

50 PeriodicOutput, 

51 Sparse, 

52 SparseReshape, 

53) 

54 

55############################################################################## 

56# Globals 

57############################################################################## 

58 

59 

60__all__ = ["gen_sequential_model", "gen_functional_model"] 

61 

62 

63############################################################################## 

64# Public Functions 

65############################################################################## 

66 

67 

def gen_sequential_model(input_shape, parameters=None, sparse=False):
    """Returns a sequential tf.keras Model matching the type of the supplied parameters.

    Args:
        input_shape (int): The input shape of the returned model. In most cases that is data.shape[1] of your data.
        parameters (Union[encodermap.Parameters, encodermap.ADCParameters, None], optional): The parameters to
            use on the returned model. If None is provided, a default-constructed encodermap.Parameters
            instance is used. Defaults to None.
        sparse (bool, optional): Forwarded as the `sparse` argument of SequentialModel. It is not
            forwarded to ADCSequentialModel. Defaults to False.

    Returns:
        tf.keras.Model: A SequentialModel when `parameters` is an encodermap.Parameters instance (or None),
            an ADCSequentialModel when it is an encodermap.ADCParameters instance.

    Raises:
        TypeError: If `parameters` is neither None, nor Parameters, nor ADCParameters.

    """
    if parameters is None:
        # Do not return here: the default parameters go through the same
        # dispatch as user-supplied ones, so a model is always returned.
        parameters = Parameters()

    if isinstance(parameters, Parameters):
        return SequentialModel(input_shape, parameters, sparse=sparse)
    if isinstance(parameters, ADCParameters):
        return ADCSequentialModel(input_shape, parameters)
    raise TypeError(
        "parameters need to be encodermap.Parameters or encodermap.ADCParameters. "
        f"You supplied {type(parameters)}"
    )

93 

94 

class Sparse(tf.keras.layers.Dense):
    """Dense layer whose forward pass accepts a sparse input tensor.

    The kernel, bias and activation are the ones created by
    `tf.keras.layers.Dense`; only `call` is replaced so that the matrix product
    is computed with `tf.sparse.sparse_dense_matmul`.

    NOTE(review): this definition rebinds the name `Sparse` that this module
    also imports from `.layers`. From here on `Sparse` refers to this class.

    """

    def call(self, inputs):
        """Compute `activation(inputs @ kernel + bias)` for a sparse `inputs`.

        Args:
            inputs (tf.sparse.SparseTensor): The sparse input batch.

        Returns:
            tf.Tensor: The dense layer output.

        """
        outputs = tf.sparse.sparse_dense_matmul(inputs, self.kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        # Dense.call applies the configured activation; mirror that here so an
        # `activation=` argument passed to the constructor is not ignored.
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs

101 

102 

class SparseModel(tf.keras.Model):
    """Single-layer model that feeds a sparse input through a `Sparse` layer."""

    def __init__(self, name, input_shape):
        """Create the model.

        Args:
            name (str): Name of the model.
            input_shape (int): Number of input features. The same number is
                used as the number of units of the internal `Sparse` layer.

        """
        # `name` has to be passed as a keyword; tf.keras.Model does not take
        # the model name as a positional argument.
        super().__init__(name=name)
        self._sparse_layer = Sparse(input_shape)
        inputs = tf.keras.layers.Input(
            shape=(input_shape,), sparse=True, name="sparse_tensor"
        )
        # NOTE(review): `_set_inputs` is a private Keras method - confirm it
        # exists in the TensorFlow versions this package supports.
        self._set_inputs(inputs)

    def call(self, sparse_tensor):
        """Return the output of the internal `Sparse` layer for `sparse_tensor`."""
        return self._sparse_layer(sparse_tensor)

114 

115 

def _periodic_input_branch(values, input_name, periodic_name, p, sparse):
    """Create the Keras input for one angular feature and pass it through PeriodicInput.

    Returns a tuple `(input, periodic, densifier)`. `densifier` is a
    tf.keras.Model turning the sparse input into a dense tensor, or None when
    the input is dense.
    """
    n_features = values.shape[1]
    if not sparse or not isinstance(values, tf.sparse.SparseTensor):
        inp = Input(shape=(n_features,), name=input_name)
        return inp, PeriodicInput(p, periodic_name)(inp), None

    inp = Input(shape=(n_features,), name=input_name, sparse=True)
    dense = Dense(n_features)(inp)
    densifier = tf.keras.Model(inputs=inp, outputs=dense)
    return inp, PeriodicInput(p, periodic_name)(dense), densifier


def _stack_dense_layers(x, layer_specs, regularizer):
    """Apply one Dense layer per `(units, activation, name)` entry of `layer_specs` to `x`."""
    for n_neurons, act_fun, name in layer_specs:
        x = Dense(
            units=n_neurons,
            activation=act_fun,
            name=name,
            kernel_initializer=tf.initializers.VarianceScaling(),
            kernel_regularizer=regularizer,
            bias_initializer=tf.initializers.RandomNormal(0.1, 0.5),
        )(x)
    return x


def _write_model_summary(model, path):
    """Capture the output of `model.summary()` and write it to the text file `path`."""
    with Capturing() as output:
        model.summary()
    with open(path, "w") as f:
        f.write("\n".join(output))


def gen_functional_model(
    input_dataset, parameters=None, reload_layers=None, sparse=False
):
    """Builds a model to specification of parameters using the functional API.

    The functional API is used to build the AngleDihedralCartesianAutoencoder,
    which takes input data in form of a tf.data.Dataset yielding, in this order:
        * backbone_angles (angles between C, CA, N - atoms in the backbone).
        * backbone_torsions (dihedral angles in the backbone, commonly known as omega, phi, psi).
        * cartesian_coordinates (coordinates of the C, CA, N backbone atoms. This data has ndim 3, the other have ndim 2).
        * backbone_distances (distances between the C, CA, N backbone atoms).
        * sidechain_torsions (dihedral angles in the sidechain, commonly known as chi1, chi2, chi3, chi4, chi5).
    Packing and unpacking that data in the correct manner is important. Make sure to double check whether you are using
    angles or dihedrals. A simple print of the shape can be enough.

    The returned model is assembled from three sub-models that are called one after another:
        * Encoder: The angular inputs are passed through PeriodicInput layers, concatenated in the order
            dihedrals, angles, side dihedrals, and fed through the stacked Dense layers down to the latent layer.
        * Decoder: Stacked Dense layers, whose output is split again in the same order and passed through
            PeriodicOutput layers.
        * Backmapping: A BackMapLayer takes the distances, the decoded angles and the decoded dihedrals to
            calculate new cartesian coordinates. PairwiseDistances layers are applied to the input
            cartesians and to the back-mapped cartesians.

    Args:
        input_dataset (tf.data.Dataset): The dataset with the data in the order given in the explanation.
            Only its first element is read, to obtain the shapes of the inputs.
        parameters (Union[em.ADCParameters, None], optional): The parameters to be used to build the network.
            If None is provided, a default-constructed encodermap.ADCParameters instance is used.
            Defaults to None.
        reload_layers (Union[None, list], optional): Reloading is currently not supported, anything but None
            raises an Exception. Defaults to None.
        sparse (bool, optional): Whether inputs that are tf.sparse.SparseTensor instances should get sparse
            Keras inputs followed by a densifying Dense layer. If True, a SparseFunctionalModel is
            returned. Defaults to False.

    Raises:
        Exception: If `parameters` is an instance of encodermap.Parameters, or if `reload_layers` is not None.
        ValueError: If `input_dataset` does not yield any element.
        AssertionError: If the shapes of the input data do not fit together. With N being the number
            of atoms in the cartesians, the distances need N - 1, the angles N - 2 and the dihedrals N - 3
            entries along axis 1. A failure can also mean that the dataset is not packed in the order
            specified above, or that the provided protein is not linear (branched, circular, ...).
        NotImplementedError: If `parameters.use_backbone_angles` is False.

    Returns:
        em.FunctionalModel: A subclass of tf.keras.Model build with specified parameters.

    """
    # The functional model only works with ADCParameters, so that is the
    # default. A default of `Parameters()` would be rejected right below.
    p = ADCParameters() if parameters is None else parameters

    if isinstance(p, Parameters):
        raise Exception(
            "Functional Model is currently reserved for the ADCAutoencoder, "
            "because of the way the data is packed and unpacked."
        )

    # unpack the shapes of the input from the first element of the dataset
    try:
        first_element = next(iter(input_dataset))
    except StopIteration:
        raise ValueError(
            "input_dataset did not yield any element. At least one element is "
            "needed to determine the input shapes."
        ) from None
    angles, dihedrals, cartesians, distances, side_dihedrals = first_element

    # These checks need to be changed for all proteins that are not-linear.
    # Sparse cartesians are flattened, so three entries make up one atom.
    if isinstance(cartesians, tf.sparse.SparseTensor):
        n_atoms = cartesians.shape[1] // 3
    else:
        n_atoms = cartesians.shape[1]
    n_distances = distances.shape[1]
    # AssertionError is raised explicitly, so that the checks carry a message
    # and are not removed when python runs with -O.
    if n_distances != n_atoms - 1:
        raise AssertionError(
            f"Shape of distances {distances.shape} does not fit the shape of "
            f"cartesians {cartesians.shape}."
        )
    if angles.shape[1] != n_distances - 1:
        raise AssertionError(
            f"Shape of angles {angles.shape} does not fit the shape of "
            f"distances {distances.shape}."
        )
    if dihedrals.shape[1] != n_distances - 2:
        raise AssertionError(
            f"Shape of dihedrals {dihedrals.shape} does not fit the shape of "
            f"distances {distances.shape}."
        )

    if reload_layers is not None:
        raise Exception("currently not reloadable.")

    if not p.use_backbone_angles:
        # The encoder, the decoder and the combined model below all take the
        # backbone angles as an input. Without them the model can't be built.
        raise NotImplementedError(
            "Building the functional model with use_backbone_angles=False is "
            "currently not supported."
        )

    # define regularizer
    regularizer = tf.keras.regularizers.l2(p.l2_reg_constant)

    # central dihedrals
    inp_dihedrals, x, get_dense_model_central_dihedrals = _periodic_input_branch(
        dihedrals, "input_dihedrals", "dihedrals", p, sparse
    )

    # backbone angles
    inp_angles, y, get_dense_model_central_angles = _periodic_input_branch(
        angles, "input_angles", "angles", p, sparse
    )

    # sidechains
    get_dense_model_side_dihedrals = None
    if p.use_sidechains:
        inp_side_dihedrals, z, get_dense_model_side_dihedrals = _periodic_input_branch(
            side_dihedrals, "input_side_dihedrals", "side_dihedrals", p, sparse
        )

    # these inputs will be passed through and will be used for backmapping
    # and RMSD metrics. `dense_cartesians` and `dense_distances` are the
    # tensors the later layers work with.
    if not sparse or not isinstance(cartesians, tf.sparse.SparseTensor):
        inp_cartesians = Input(
            shape=(cartesians.shape[1], 3),
            name="input_cartesians",
        )
        dense_cartesians = inp_cartesians
        get_dense_model_cartesians = None
    else:
        inp_cartesians = Input(
            shape=(cartesians.shape[1],),
            name="input_cartesians",
            sparse=True,
        )
        dense_cartesians = Dense(cartesians.shape[1])(inp_cartesians)
        # the densifying model returns the flat cartesians, the reshape to
        # (n_atoms, 3) is only part of the main model
        get_dense_model_cartesians = tf.keras.Model(
            inputs=inp_cartesians,
            outputs=dense_cartesians,
        )
        dense_cartesians = tf.keras.layers.Reshape(
            target_shape=(cartesians.shape[1] // 3, 3),
            input_shape=(cartesians.shape[1],),
        )(dense_cartesians)

    if not sparse or not isinstance(distances, tf.sparse.SparseTensor):
        inp_distances = Input(
            shape=(distances.shape[1],),
            name="input_distances",
        )
        dense_distances = inp_distances
        get_dense_model_distances = None
    else:
        inp_distances = Input(
            shape=(distances.shape[1],),
            name="input_distances",
            sparse=True,
        )
        dense_distances = Dense(distances.shape[1])(inp_distances)
        get_dense_model_distances = tf.keras.Model(
            inputs=inp_distances,
            outputs=dense_distances,
        )

    # stack the datasources going through the network. The order used here
    # (dihedrals, angles, side dihedrals) is the order the decoder output is
    # split in further below.
    if p.use_sidechains:
        # ToDo: make sure the splits work here. There seems to be different outputs from tf and np.
        splits = [x.shape[1], y.shape[1], z.shape[1]]
        x = Concatenate(axis=1, name="Main_Inputs")([x, y, z])
    else:
        splits = [x.shape[1], y.shape[1]]
        x = Concatenate(axis=1, name="Main_Inputs")([x, y])

    # save the out_shape now and use it for an output layer
    out_shape = x.shape[1]

    # rename empty string in parameters to None
    activation_functions = [
        act_fun if act_fun != "" else None for act_fun in p.activation_functions
    ]

    # define how layers are stacked
    layer_data = list(
        zip(
            p.n_neurons + p.n_neurons[-2::-1],
            activation_functions[1:] + activation_functions[-1::-1],
        )
    )
    # add a layer that reshapes the output
    layer_data.append([out_shape, "tanh"])

    # decide layer names
    n_hidden = len(p.n_neurons)
    layer_specs = []
    for i, (n_neurons, act_fun) in enumerate(layer_data):
        if i < n_hidden - 1:
            name = f"Encoder_{i}"
        elif i > n_hidden - 1:
            name = f"Decoder_{i - n_hidden}"
        else:
            name = "Latent"
        layer_specs.append((n_neurons, act_fun, name))

    # unpack layer data into encoder and decoder. The encoder ends with the
    # first layer that has the smallest number of neurons.
    neurons = [spec[0] for spec in layer_specs]
    bottleneck_index = neurons.index(min(neurons)) + 1
    encoder_layers = layer_specs[:bottleneck_index]
    decoder_layers = layer_specs[bottleneck_index:]

    # encoder layers
    x = _stack_dense_layers(x, encoder_layers, regularizer)

    # encoder model
    if p.use_sidechains:
        encoder_inputs = [inp_angles, inp_dihedrals, inp_side_dihedrals]
    else:
        encoder_inputs = [inp_angles, inp_dihedrals]
    encoder = tf.keras.Model(
        inputs=encoder_inputs,
        outputs=[x],
        name="Encoder",
    )
    if p.tensorboard:
        _write_model_summary(encoder, p.main_path + "/encoder_summary.txt")

    # decoder input
    decoder_input = Input(shape=(encoder_layers[-1][0],), name="Decoder_Input")

    # decoder layers
    x = _stack_dense_layers(decoder_input, decoder_layers, regularizer)

    # split output accordingly. The pieces come back in the order they were
    # concatenated in: dihedrals, angles, (side dihedrals).
    split_output = Lambda(lambda t: tf.split(t, splits, 1), name="Split_Output")(x)
    if p.use_sidechains:
        out_dihedrals, out_angles, out_side_dihedrals = split_output
    else:
        out_dihedrals, out_angles = split_output
    out_angles = PeriodicOutput(p, "Angles")(out_angles)
    out_dihedrals = PeriodicOutput(p, "Dihedrals")(out_dihedrals)
    decoder_output = [out_angles, out_dihedrals]
    if p.use_sidechains:
        out_side_dihedrals = PeriodicOutput(p, "Side_Dihedrals")(out_side_dihedrals)
        decoder_output.append(out_side_dihedrals)

    # decoder model before backmapping
    decoder = tf.keras.Model(
        inputs=[decoder_input], name="Decoder", outputs=decoder_output
    )
    if p.tensorboard:
        _write_model_summary(decoder, p.main_path + "/decoder_summary.txt")

    # backmap input
    back_inp_angles = Input(shape=(out_angles.shape[1],), name="Back_Angles_Input")
    back_inp_dihedrals = Input(
        shape=(out_dihedrals.shape[1],), name="Back_Dihedrals_Input"
    )
    if p.use_sidechains:
        pass_sidedihedrals = Input(
            shape=(out_side_dihedrals.shape[1],), name="Side_Dihedrals_Pass_Through"
        )

    # backmapping. The hardest part
    back_cartesians = BackMapLayer()(
        (dense_distances, back_inp_angles, back_inp_dihedrals)
    )

    # pairwise distances is the last part
    inp_pair = PairwiseDistances(p, "Input")(dense_cartesians)
    out_pair = PairwiseDistances(p, "Backmapped")(back_cartesians)

    # backmap_model. The side dihedrals are only passed through.
    backmap_inputs = [back_inp_angles, back_inp_dihedrals]
    backmap_outputs = [
        back_inp_angles,
        back_inp_dihedrals,
        back_cartesians,
        inp_pair,
        out_pair,
    ]
    if p.use_sidechains:
        backmap_inputs.append(pass_sidedihedrals)
        backmap_outputs.append(pass_sidedihedrals)
    backmap_inputs += [inp_distances, inp_cartesians]
    backmap_model = tf.keras.Model(
        name="Backmapping",
        inputs=backmap_inputs,
        outputs=backmap_outputs,
    )
    if p.tensorboard:
        _write_model_summary(backmap_model, p.main_path + "/backmap_summary.txt")

    # call all the models hierarchically to rebuild a complete model
    main_inputs = [
        inp_angles,
        inp_dihedrals,
        inp_cartesians,
        inp_distances,
    ]
    if p.use_sidechains:
        main_inputs.append(inp_side_dihedrals)
    main_outputs = backmap_model(
        (
            *decoder(encoder(tuple(encoder_inputs))),
            inp_distances,
            inp_cartesians,
        )
    )

    # pass input and outputs to FunctionalModel
    # In FunctionalModel train_step is overwritten. Train_step should unpack the data and assign the inputs/outputs to the
    # different loss functions.
    if not sparse:
        return FunctionalModel(
            parameters=p,
            inputs=main_inputs,
            outputs=main_outputs,
            encoder=encoder,
            decoder=decoder,
        )
    return SparseFunctionalModel(
        parameters=p,
        inputs=main_inputs,
        outputs=main_outputs,
        encoder=encoder,
        decoder=decoder,
        get_dense_model_central_dihedrals=get_dense_model_central_dihedrals,
        get_dense_model_central_angles=get_dense_model_central_angles,
        get_dense_model_side_dihedrals=get_dense_model_side_dihedrals,
        get_dense_model_cartesians=get_dense_model_cartesians,
        get_dense_model_distances=get_dense_model_distances,
    )

571 

572 

573############################################################################## 

574# Public Classes 

575############################################################################## 

576 

577 

class FunctionalModel(tf.keras.Model):
    """Functional tf.keras.Model with a custom loss evaluation and train step.

    The model is built from `inputs` and `outputs` like any functional Keras
    model. Additionally the encoder and decoder sub-models are kept, so that
    they can be called on their own and so that `train_step` can restrict the
    optimization to their trainable variables.

    """

    def __init__(self, parameters, inputs, outputs, encoder, decoder):
        """Create the model.

        Args:
            parameters: The parameters object. This class reads its
                `use_sidechains` attribute.
            inputs: The input tensors of the complete model.
            outputs: The output tensors of the complete model.
            encoder (tf.keras.Model): The encoder sub-model.
            decoder (tf.keras.Model): The decoder sub-model.

        """
        super().__init__(inputs=inputs, outputs=outputs)
        self.p = parameters
        self.encoder_model = encoder
        self.decoder_model = decoder

        # train counter
        self._train_counter = K.variable(0, "int64", name="train_counter")

    def encoder(self, x, training=False):
        """Call the encoder sub-model on `x`."""
        return self.encoder_model(x, training=training)

    def decoder(self, x, training=False):
        """Call the decoder sub-model on `x`."""
        return self.decoder_model(x, training=training)

    def compile(self, *args, **kwargs):
        """Compile the model and index the supplied loss functions by their `__name__`.

        `get_loss` looks the loss functions up by name, so every function it
        uses needs to be part of the losses passed to this method.
        """
        super().compile(*args, **kwargs)
        # NOTE(review): `_losses` is a private attribute of the compiled loss
        # container - confirm it exists in the supported TensorFlow versions.
        self.unpacked_loss_fns = {fn.__name__: fn for fn in self.compiled_loss._losses}

    def get_loss(self, inp):
        """Call the model on `inp` and sum up all loss contributions.

        Args:
            inp: An iterable with angles, dihedrals, cartesians and distances
                (in that order), followed by the side dihedrals. The side
                dihedrals are required if `self.p.use_sidechains` is True,
                otherwise they may be left out and are ignored if present.

        Returns:
            The combined loss.

        """
        # unpack the inputs
        inp_angles, inp_dihedrals, inp_cartesians, inp_distances, *optional = inp
        if self.p.use_sidechains:
            (inp_side_dihedrals,) = optional
            model_inputs = inp
            encoder_inputs = (inp_angles, inp_dihedrals, inp_side_dihedrals)
        else:
            # without sidechains the model has only four inputs
            model_inputs = (inp_angles, inp_dihedrals, inp_cartesians, inp_distances)
            encoder_inputs = (inp_angles, inp_dihedrals)

        # unpack the outputs
        out = self(model_inputs, training=True)
        if self.p.use_sidechains:
            (
                out_angles,
                out_dihedrals,
                back_cartesians,
                inp_pair,
                out_pair,
                out_side_dihedrals,
            ) = out
        else:
            (
                out_angles,
                out_dihedrals,
                back_cartesians,
                inp_pair,
                out_pair,
            ) = out

        # define latent for cartesian_distance_loss
        latent = self.encoder_model(encoder_inputs, training=True)

        loss_fns = self.unpacked_loss_fns
        with tf.name_scope("Cost"):
            loss = 0.0
            # dihedral loss
            loss += loss_fns["dihedral_loss_func"](inp_dihedrals, out_dihedrals)
            # angle loss
            loss += loss_fns["angle_loss_func"](inp_angles, out_angles)
            # cartesian loss
            loss += loss_fns["cartesian_loss_func"](inp_pair, out_pair)
            # distance loss. It gets the same inputs as the encoder, so that
            # it also works for models without sidechains.
            loss += loss_fns["distance_loss_func"](encoder_inputs)
            # cartesian distance cost
            loss += loss_fns["cartesian_distance_loss_func"](inp_pair, latent)
            # center loss
            loss += loss_fns["center_loss_func"](encoder_inputs)
            # reg loss
            loss += loss_fns["regularization_loss_func"]()
            # side dihedral loss
            if self.p.use_sidechains:
                loss += loss_fns["side_dihedral_loss_func"](
                    inp_side_dihedrals, out_side_dihedrals
                )
            tf.summary.scalar("Combined Cost", loss)

        return loss

    def train_step(self, data):
        """Run one optimization step on `data` and return `{"loss": loss}`.

        Only the trainable variables of the encoder and decoder sub-models are
        updated.
        """
        # Data will always contain all inputs, depending on p, the model will return different tuples
        # but the input will be the same, only when the encoder is called with the train data, and the
        # decoder is called does packing and unpacking matter.
        # gradient tape to calculate loss for backmapping
        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._train_counter)
            loss = self.get_loss(data)
            # Only layers whose class is exactly named "Dense" are logged.
            loggable_layers = [
                layer
                for layer in self.encoder_model.layers
                if layer.__class__.__name__ == "Dense"
            ] + [
                layer
                for layer in self.decoder_model.layers
                if layer.__class__.__name__ == "Dense"
            ]
            for layer in loggable_layers:
                add_layer_summaries(layer, step=self._train_counter)

        # optimization happens here
        # Compute Gradients
        trainable_vars = (
            self.encoder_model.trainable_variables
            + self.decoder_model.trainable_variables
        )
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # update train counter because tensorflow seems to have deprecated it
        self._train_counter.assign_add(1)
        # Return the loss in a dict so the ProgressBar callback can pick it up
        return {"loss": loss}

731 

732 

class SparseFunctionalModel(FunctionalModel):
    """FunctionalModel that turns sparse inputs into dense tensors before the loss is calculated.

    For every input a model can be supplied that converts the sparse tensor
    into a dense one. These models are only called for inputs that are
    tf.sparse.SparseTensor instances.

    """

    def __init__(
        self,
        parameters,
        inputs,
        outputs,
        encoder,
        decoder,
        get_dense_model_central_dihedrals,
        get_dense_model_central_angles,
        get_dense_model_side_dihedrals,
        get_dense_model_cartesians,
        get_dense_model_distances,
    ):
        """Create the model.

        The first five arguments are handed to FunctionalModel. Each of the
        `get_dense_model_*` arguments is stored on the instance under its own
        name and is used by `get_loss` to densify the respective input.
        """
        super().__init__(parameters, inputs, outputs, encoder, decoder)
        self.get_dense_model_central_dihedrals = get_dense_model_central_dihedrals
        self.get_dense_model_central_angles = get_dense_model_central_angles
        self.get_dense_model_side_dihedrals = get_dense_model_side_dihedrals
        self.get_dense_model_cartesians = get_dense_model_cartesians
        self.get_dense_model_distances = get_dense_model_distances

    def get_loss(self, inp):
        """Densify the sparse members of `inp` and return the loss of the parent class.

        Args:
            inp: An iterable with exactly five elements: angles, dihedrals,
                cartesians, distances and side dihedrals. The side dihedrals
                are only used if `self.p.use_sidechains` is True.

        Returns:
            The loss as returned by `FunctionalModel.get_loss`.

        """

        def to_dense(tensor, dense_model):
            # dense tensors are handed on unchanged
            if isinstance(tensor, tf.sparse.SparseTensor):
                return dense_model(tensor, training=True)
            return tensor

        # unpack the inputs
        raw_angles, raw_dihedrals, raw_cartesians, raw_distances, raw_side_dihedrals = inp

        dense_inputs = [
            to_dense(raw_angles, self.get_dense_model_central_angles),
            to_dense(raw_dihedrals, self.get_dense_model_central_dihedrals),
            to_dense(raw_cartesians, self.get_dense_model_cartesians),
            to_dense(raw_distances, self.get_dense_model_distances),
        ]
        # the side dihedrals are left out, when the model does not use them
        if self.p.use_sidechains:
            dense_inputs.append(
                to_dense(raw_side_dihedrals, self.get_dense_model_side_dihedrals)
            )

        # call the loss
        return super().get_loss(tuple(dense_inputs))

817 

818 

class SequentialModel(tf.keras.Model):
    """Autoencoder made of an encoder and a decoder ``tf.keras.Sequential``.

    The layer layout is derived from ``parameters.n_neurons`` and
    ``parameters.activation_functions``. When ``parameters.periodicity`` is
    finite, inputs are mapped to ``(sin, cos)`` pairs in :meth:`encoder`
    (doubling the width seen by the first dense layer) and mapped back with
    ``atan2`` in :meth:`decoder`.

    Attributes:
        p: The parameter object used to configure the model.
        sparse: Whether the model was created for sparse inputs.
        input_dim: Width of the tensor fed to the first dense layer. This is
            twice the ``input_dim`` passed to ``__init__`` for periodic data.
        get_dense_model: Callable/model turning sparse input into a dense
            tensor. Replaced by an internally created model when
            ``sparse=True`` and no ``reload_layers`` are given.
        encoder_model: ``tf.keras.Sequential`` holding the encoder layers.
        decoder_model: ``tf.keras.Sequential`` holding the decoder layers.

    """

    def __init__(
        self,
        input_dim,
        parameters=None,
        reload_layers=None,
        sparse=False,
        get_dense_model=None,
    ):
        """Instantiate the model.

        Args:
            input_dim (int): Number of input features (before the periodic
                sin/cos doubling).
            parameters: Parameter object. If None, ``Parameters()`` is used.
            reload_layers: If given, a sequence of exactly two
                ``tf.keras.Sequential`` instances ``(encoder, decoder)`` that
                are reused instead of creating new layers.
            sparse (bool): If True, a small model with a sparse ``Input`` and
                one ``Dense`` layer is created and stored in
                ``self.get_dense_model``.
            get_dense_model: Optional callable stored as
                ``self.get_dense_model``; overwritten when ``sparse`` is True
                and the layers are built from scratch.

        Raises:
            ValueError: If ``reload_layers`` does not contain two entries.
            AssertionError: If an entry of ``reload_layers`` is not a
                ``tf.keras.Sequential``.

        """
        # self.p is assigned before super().__init__() because subclasses
        # (see ADCSequentialModel) do the same before delegating here.
        if parameters is None:
            self.p = Parameters()
        else:
            self.p = parameters
        super().__init__()
        self.sparse = sparse
        self.input_dim = input_dim
        self.get_dense_model = get_dense_model

        # tensors for using tf.cond inside self.train_step()
        self.update_step = tf.constant(self.p.summary_step, dtype="int64")
        self.debug_tensor = tf.constant(self.p.tensorboard, dtype="bool")

        # periodicity doubles the inputs and outputs
        periodic = self.p.periodicity < float("inf")
        if periodic:
            self.input_dim *= 2

        # if layers are provided we are rebuilding a trained model
        if reload_layers is not None:
            if periodic:
                print(
                    f"rebuilding Model with input_dim = {int(self.input_dim/2)} and periodicity = {self.p.periodicity}"
                )
            else:
                print(f"rebuilding Model with input_dim = {self.input_dim}")
            if len(reload_layers) != 2:
                raise ValueError(
                    "currently only works with 2 layers. Encoder and Decoder."
                )
            assert all(isinstance(i, tf.keras.Sequential) for i in reload_layers)
            self.encoder_model = reload_layers[0]
            self.decoder_model = reload_layers[1]
            self.build(input_shape=(1, self.input_dim))
            return

        # one regularizer instance is shared by all dense layers
        regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)

        def make_dense(units, activation, name, **extra):
            # Every layer gets its own initializer instances.
            return tf.keras.layers.Dense(
                units=units,
                activation=activation,
                name=name,
                kernel_initializer=tf.initializers.VarianceScaling(),
                kernel_regularizer=regularizer,
                bias_initializer=tf.initializers.RandomNormal(0.1, 0.5),
                **extra,
            )

        # rename empty string in parameters to None
        activation_functions = [
            act if act != "" else None for act in self.p.activation_functions
        ]

        # define how layers are stacked: n_neurons followed by its mirror
        # image (without repeating the last entry)
        layer_data = list(
            zip(
                self.p.n_neurons + self.p.n_neurons[-2::-1],
                activation_functions[1:] + activation_functions[-1::-1],
            )
        )
        # add a layer that reshapes the output
        layer_data.append([self.input_dim, "tanh"])

        # decide layer names
        latent_index = len(self.p.n_neurons) - 1
        names = []
        for i, _ in enumerate(layer_data):
            if i < latent_index:
                names.append(f"Encoder_{i}")
            elif i > latent_index:
                names.append(f"Decoder_{i - len(self.p.n_neurons)}")
            else:
                names.append("Latent")
        layer_data = [(*spec, name) for spec, name in zip(layer_data, names)]

        # everything up to and including the narrowest layer is the encoder
        neurons = [spec[0] for spec in layer_data]
        bottleneck_index = neurons.index(min(neurons)) + 1
        self.encoder_layers = layer_data[:bottleneck_index]
        self.decoder_layers = layer_data[bottleneck_index:]

        if self.sparse:
            # The sparse input carries the raw features, i.e. the width
            # before the sin/cos doubling. Integer division keeps the unit
            # count an int.
            n_features = self.input_dim // 2 if periodic else self.input_dim
            _input_layer = Input(shape=(int(n_features),), sparse=True)
            x = Dense(int(n_features))(_input_layer)
            self.get_dense_model = tf.keras.Model(
                inputs=_input_layer,
                outputs=x,
            )

        # Instead of using an InputLayer, the first Dense layer gets the
        # input_shape kwarg; the original author found that reloading the
        # model works better this way.
        first_units, first_activation, first_name = self.encoder_layers[0]
        input_layer = make_dense(
            first_units,
            first_activation,
            first_name,
            input_shape=(self.input_dim,),
        )

        # encoder
        self.encoder_model = tf.keras.Sequential(
            [input_layer]
            + [
                make_dense(n_neurons, act_fun, name)
                for n_neurons, act_fun, name in self.encoder_layers[1:]
            ],
            name="Encoder",
        )

        # decoder
        self.decoder_model = tf.keras.Sequential(
            [
                make_dense(n_neurons, act_fun, name)
                for n_neurons, act_fun, name in self.decoder_layers
            ],
            name="Decoder",
        )

        # build
        self.build(input_shape=(1, self.input_dim))

        # train counter
        self._train_counter = K.variable(0, "int64", name="train_counter")

    def build(self, *args, **kwargs):
        """Build the model, halving the last input dimension for periodic data.

        ``input_shape`` may be given as the first positional argument or as a
        keyword argument. The shape passed in is expected to carry the
        doubled (sin/cos) width; because building runs ``self.call``, which
        doubles the width again in :meth:`encoder`, the last dimension is
        halved here first.

        Raises:
            KeyError: If ``input_shape`` is given neither positionally nor as
                a keyword.

        """
        if args:
            input_shape, *rest = args
        else:
            rest = []
            input_shape = kwargs.pop("input_shape")
        periodic = self.p.periodicity < float("inf")
        if periodic:
            input_shape = (*input_shape[:-1], int(input_shape[-1]) // 2)
        try:
            super().build(input_shape, *rest, **kwargs)
        except Exception:
            if periodic:
                print(
                    "Exception is raised because of periodicity. In general "
                    "you don't need to call the build method from outside of "
                    "this class, because it is called at the end of __init__."
                )
            raise

    def compile(self, *args, **kwargs):
        """Compile the model and index the loss functions by their ``__name__``.

        After the regular compile, ``self.unpacked_loss_fns`` maps each loss
        function's ``__name__`` to the function itself. If the compiled loss
        container lacks the expected attribute, the loss-related attribute
        names of the model are printed before the error is re-raised.

        """
        super().compile(*args, **kwargs)
        try:
            self.unpacked_loss_fns = {
                fn.__name__: fn for fn in self.compiled_loss._losses
            }
        except AttributeError:
            for attr_name in dir(self):
                if "loss" in attr_name:
                    print(attr_name)
            print(self._callable_losses)
            raise

    def encoder(self, x, training=False):
        """Map inputs to the latent space.

        Sparse tensors are densified with ``self.get_dense_model`` when the
        model was created with ``sparse=True``; tensors that are already dense
        (e.g. the ones handed over by :meth:`train_step`) are used as they
        are, so the densifying layer is applied at most once. Periodic inputs
        are rescaled to a ``2 * pi`` period and replaced by the concatenation
        of their sines and cosines.

        """
        if self.sparse and isinstance(x, tf.sparse.SparseTensor):
            x = self.get_dense_model(x)
        if self.p.periodicity < float("inf"):
            if self.p.periodicity != 2 * pi:
                x = x / self.p.periodicity * 2 * pi
            x = tf.concat([tf.sin(x), tf.cos(x)], 1)
        return self.encoder_model(x, training=training)

    def decoder(self, x, training=False):
        """Map latent points back to the input space.

        For periodic data the decoder output is split into two halves along
        axis 1 that are combined with ``atan2`` and rescaled to the
        configured periodicity.

        """
        x = self.decoder_model(x, training=training)
        if self.p.periodicity < float("inf"):
            x = tf.atan2(*tf.split(x, 2, 1))
            if self.p.periodicity != 2 * pi:
                x = x / (2 * pi) * self.p.periodicity
        return x

    def call(self, x, training=False):
        """Encode and then decode ``x``."""
        encoded = self.encoder(x, training=training)
        decoded = self.decoder(encoded, training=training)
        return decoded

    def train_step(self, data):
        """Overwrites the normal train_step. What is different?

        Not much. Even the provided data is expected to be a tuple of
        (data, classes) (x, y) in classification tasks. The data is unpacked
        and y is discarded, because the Autoencoder Model is a regression
        task. All compiled loss functions are summed up to one combined cost.

        Args:
            data (tuple): The (x, y) data of this train step.

        Returns:
            dict: Metric names mapped to their current values plus the
                combined loss under the key ``"loss"``.

        """
        x, _ = data
        if self.sparse and isinstance(x, tf.sparse.SparseTensor):
            # the losses compare against a dense tensor
            x = self.get_dense_model(x)

        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._train_counter)
            with tf.name_scope("Cost"):
                # the reconstruction does not depend on the loss function,
                # so it is computed once for all of them
                reconstruction = self(x, training=True)
                loss = 0.0
                for loss_fn in self.compiled_loss._losses:
                    loss += loss_fn(x, reconstruction)
                tf.summary.scalar("Combined Cost", loss)
            for layer in self.encoder_model.layers + self.decoder_model.layers:
                add_layer_summaries(layer, step=self._train_counter)

        # Compute Gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        self.compiled_metrics.update_state(x, self(x))
        # update train counter because tensorflow seems to have deprecated it
        self._train_counter.assign_add(1)
        # Return a dict mapping metric names to current value
        return {**{m.name: m.result() for m in self.metrics}, "loss": loss}

1067 

1068 

class ADCSequentialModel(SequentialModel):
    """Sequential autoencoder trained on angles, dihedrals and cartesians.

    The network input is assembled from backbone dihedrals and, depending on
    ``p.use_backbone_angles`` / ``p.use_sidechains``, backbone angles and
    side-chain dihedrals. The decoded output is mapped back to cartesian
    coordinates to evaluate the cartesian losses.

    """

    def __init__(self, input_dim, parameters=None, reload_layers=None, sparse=False):
        """Instantiate the model.

        Args:
            input_dim (int): Number of input features.
            parameters: Parameter object. If None, ``ADCParameters()`` is used.
            reload_layers: Forwarded to ``SequentialModel.__init__``.
            sparse (bool): Accepted for signature compatibility.

        """
        warnings.warn("check split")
        if parameters is None:
            self.p = ADCParameters()
        else:
            self.p = parameters
        # multiples used to tile the mean angles over a batch of
        # p.batch_size samples in call_and_map_back
        self.multiples = tf.TensorShape((self.p.batch_size, 1))
        # NOTE(review): `sparse` is not forwarded to the parent class, so the
        # model is always built as a dense model - confirm this is intended.
        super().__init__(input_dim, self.p, reload_layers)

    def _assemble_main_inputs(self, inp_angles, inp_dihedrals, inp_side_dihedrals=None):
        """Return ``(main_inputs, splits)`` for the configured feature set.

        ``splits`` holds the widths of the concatenated parts (None when only
        dihedrals are used).

        Raises:
            ValueError: If side chains are requested without backbone angles,
                or if side chains are requested but no side dihedrals were
                provided.

        """
        use_angles = self.p.use_backbone_angles
        use_sidechains = self.p.use_sidechains
        if not use_sidechains:
            if not use_angles:
                return inp_dihedrals, None
            main_inputs = tf.concat([inp_angles, inp_dihedrals], 1)
            return main_inputs, [inp_angles.shape[1], inp_dihedrals.shape[1]]
        if not use_angles:
            raise ValueError(
                "`use_sidechains` is only supported together with "
                "`use_backbone_angles`."
            )
        if inp_side_dihedrals is None:
            raise ValueError(
                "`use_sidechains` is set, but no side dihedrals were provided."
            )
        # ToDo: make sure the splits work here. There seems to be different outputs from tf and np.
        main_inputs = tf.concat([inp_angles, inp_dihedrals, inp_side_dihedrals], 1)
        splits = [
            inp_angles.shape[1],
            inp_dihedrals.shape[1],
            inp_side_dihedrals.shape[1],
        ]
        return main_inputs, splits

    def call(self, x, training=False):
        """Encode and decode ``x``.

        Args:
            x: Either the already assembled network input, or a tuple
                ``(angles, dihedrals, cartesians, distances[, side_dihedrals])``
                from which the network input is assembled.
            training (bool): Forwarded to the encoder and decoder.

        """
        if isinstance(x, tuple):
            inp_angles, inp_dihedrals, _, _, *optional = x
            inp_side_dihedrals = optional[0] if optional else None
            x, _ = self._assemble_main_inputs(
                inp_angles, inp_dihedrals, inp_side_dihedrals
            )
        encoded = self.encoder(x, training=training)
        decoded = self.decoder(encoded, training=training)
        return decoded

    def call_and_map_back(
        self, x, distances, angles, dihedrals, cartesians, splits, side_dihedrals=None
    ):
        """Run the model, back-map the output to cartesians and sum all losses.

        Args:
            x: The assembled network input.
            distances: Input distances; their mean over axis 0 is used as the
                lengths for the back-mapping.
            angles: Input angles.
            dihedrals: Input dihedrals.
            cartesians: Input cartesian coordinates.
            splits: None if the output consists of dihedrals only, otherwise
                the widths used to split the output into
                (angles, dihedrals[, side_dihedrals]).
            side_dihedrals: Input side-chain dihedrals; only used when
                ``p.use_sidechains`` is set.

        Returns:
            The combined loss.

        Raises:
            ValueError: If ``splits`` has an unsupported length, or side
                chains are requested but the output was not split into three
                parts.

        """
        out = self(x, training=True)
        latent = self.encoder(x, training=True)

        # unpack out
        out_side_dihedrals = None
        if splits is None:
            out_dihedrals = out
            # without angle outputs the mean input angles are used for every
            # sample of the batch
            out_angles = tf.tile(
                tf.expand_dims(tf.reduce_mean(angles, 0), 0), multiples=self.multiples
            )
        elif len(splits) == 2:
            out_angles, out_dihedrals = tf.split(out, splits, 1)
        elif len(splits) == 3:
            out_angles, out_dihedrals, out_side_dihedrals = tf.split(out, splits, 1)
        else:
            raise ValueError(
                f"`splits` must be None or contain 2 or 3 entries, got {len(splits)}."
            )
        if self.p.use_sidechains and out_side_dihedrals is None:
            raise ValueError(
                "`use_sidechains` is set, but the output was not split into "
                "angles, dihedrals and side dihedrals."
            )

        # do back-mapping
        back_mean_lengths = tf.expand_dims(tf.reduce_mean(distances, 0), 0)
        back_chain_in_plane = chain_in_plane(back_mean_lengths, out_angles)
        back_cartesians = dihedrals_to_cartesian_tf(
            out_dihedrals + pi, back_chain_in_plane
        )

        # get pairwise distances of the selected atoms
        pwd_selection = slice(
            self.p.cartesian_pwd_start,
            self.p.cartesian_pwd_stop,
            self.p.cartesian_pwd_step,
        )
        inp_pair = pairwise_dist(cartesians[:, pwd_selection], flat=True)
        out_pair = pairwise_dist(back_cartesians[:, pwd_selection], flat=True)

        loss_fns = self.unpacked_loss_fns
        with tf.name_scope("Cost"):
            loss = 0.0
            # dihedral loss
            loss += loss_fns["dihedral_loss_func"](dihedrals, out_dihedrals)
            # angle loss
            loss += loss_fns["angle_loss_func"](angles, out_angles)
            # cartesian loss
            loss += loss_fns["cartesian_loss_func"](inp_pair, out_pair)
            # distance loss
            loss += loss_fns["distance_loss_func"](x)
            # cartesian distance cost
            loss += loss_fns["cartesian_distance_loss_func"](inp_pair, latent)
            # center loss
            loss += loss_fns["center_loss_func"](x)
            # reg loss
            loss += loss_fns["regularization_loss_func"]()
            if self.p.use_sidechains:
                loss += loss_fns["side_dihedral_loss_func"](
                    side_dihedrals, out_side_dihedrals
                )
            tf.summary.scalar("Combined Cost", loss)
        return loss

    def train_step(self, data):
        """Run one optimization step.

        Args:
            data (tuple): ``(angles, dihedrals, cartesians, distances)``,
                optionally followed by the side-chain dihedrals.

        Returns:
            dict: Metric names mapped to their current values plus the
                combined loss under the key ``"loss"``.

        """
        # unpack the data; the side dihedrals are optional
        inp_angles, inp_dihedrals, inp_cartesians, inp_distances, *optional = data
        inp_side_dihedrals = optional[0] if optional else None
        main_inputs, splits = self._assemble_main_inputs(
            inp_angles, inp_dihedrals, inp_side_dihedrals
        )

        # gradient tape to calculate loss for backmapping
        with tf.GradientTape() as tape:
            tf.summary.experimental.set_step(self._train_counter)
            # side dihedrals are ignored by call_and_map_back unless
            # p.use_sidechains is set
            loss = self.call_and_map_back(
                main_inputs,
                inp_distances,
                inp_angles,
                inp_dihedrals,
                inp_cartesians,
                splits,
                inp_side_dihedrals,
            )
            for layer in self.encoder_model.layers + self.decoder_model.layers:
                add_layer_summaries(layer, step=self._train_counter)

        # optimization happens here
        # Compute Gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        self.compiled_metrics.update_state(data, self(data))
        # Return a dict mapping metric names to current value
        # Add loss to the dict so the ProgressBar callback can pick it up
        return {**{m.name: m.result() for m in self.metrics}, "loss": loss}