Coverage for encodermap/autoencoder/autoencoder.py: 40%

452 statements  

coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/autoencoder/autoencoder.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Forward facing Autoencoder classes. Contains four classes: 

23 

24* Autoencoder: Simple dense, fully-connected autoencoder architecture with regularization, auto, and center losses. 

25* EncoderMap: Uses the same architecture as `Autoencoder`, but adds another loss function. 

26* DihedralEncoderMap: Basically the same as `EncoderMap`, but rewrites the `generate` method to use 

27 an atomistic topology to rebuild a trajectory. 

28* AngleDihedralCartesianEncoderMap: Uses more loss functions and tries to learn a full all-atom conformation. 

29 

30""" 

31 

32 

33################################################################################ 

34# Imports 

35################################################################################ 

36 

37 

38from __future__ import annotations 

39 

40import typing 

41import warnings 

42from copy import deepcopy 

43from typing import Literal, Optional, Union 

44 

45import matplotlib 

46import numpy as np 

47import tensorflow as tf 

48import tensorflow.keras 

49 

50import encodermap 

51 

52from ..callbacks.callbacks import ( 

53 CheckpointSaver, 

54 ImageCallback, 

55 IncreaseCartesianCost, 

56 ProgressBar, 

57 TensorboardWriteBool, 

58) 

59from ..loss_functions.loss_functions import ( 

60 angle_loss, 

61 auto_loss, 

62 cartesian_distance_loss, 

63 cartesian_loss, 

64 center_loss, 

65 dihedral_loss, 

66 distance_loss, 

67 reconstruction_loss, 

68 regularization_loss, 

69 side_dihedral_loss, 

70) 

71from ..misc.backmapping import dihedral_backmapping, mdtraj_backmapping 

72from ..misc.misc import BadError, create_n_cube, plot_model 

73from ..misc.saving_loading_models import load_model, save_model 

74from ..models.models import gen_functional_model, gen_sequential_model 

75from ..parameters.parameters import ADCParameters, Parameters 

76from ..trajinfo.info_all import Capturing, TrajEnsemble 

77 

78################################################################################ 

79# Typing 

80################################################################################ 

81 

82 

83if typing.TYPE_CHECKING: 

84 import MDAnalysis as mda 

85 import mdtraj as md 

86 

87 

88################################################################################ 

89# Globals 

90################################################################################ 

91 

92 

93__all__ = [ 

94 "Autoencoder", 

95 "EncoderMap", 

96 "AngleDihedralCartesianEncoderMap", 

97 "DihedralEncoderMap", 

98] 

99 

100 

101############################################################################## 

102# Function definition which allows self.p.tensorboard to be passed 

103# @function(self.p.tensorboard) 

104# def train(self): 

105# # some training 

106############################################################################## 

107 

108 

109def function(f, tensorboard=False): 

110 """Compiles functions with `tensorflow.function` based on a `tensorboard` 

111 parameter. 

112 

113 

114 To understand the necessity of this function, we need to have a look at how 

115 tensorflow executes computations. There are two modes of execution: 

116 * eager mode: In eager mode, the computations are handled by python. 

117 The input types are python objects, and the output is a python object. 

118 This eager execution allows you to directly execute a calculation of 

119 two tensors (e.g. multiplication). 

120 * graph mode: In graph mode, computations are done inside tensorflow graphs, 

121 which are a collection of operations and tensors (i.e. data), that flow 

122 through the operations of the graph. These graphs make tensorflow 

123 computations portable and significantly increase the performance of 

124 similar computations. 

125 Normally, you would accelerate a simple python function in tensorflow, 

126 by compiling it like so: 

127 

128 ```python 

129 import tensorflow as tf 

130 

131 @tf.function 

132 def multiply(a, b): 

133 return a * b 

134 

135 multiply(tf.constant(2), tf.constant(3)).numpy() 

136 # 6 

137 ``` 

138 

139 However, the basic paradigm of accelerating the computation interferes with 

140 the `tensorboard=True` argument of `encodermap.Parameters`, as tensorboard mode writes a lot of 

141 additional information to tensorboard. Thus, a compilation with tf.function 

142 does not make sense here. That's why encodermap's `function` decorator 

143 takes an additional keyword argument, `tensorboard`: if `tensorboard=True` is passed 

144 at call time, the function runs eagerly (with a speed warning); otherwise it is compiled with `tf.function`. 

145 """ 

146 

147 def wrapper(*args, **kwargs): 

148 """Wrapper of `encodermap.function`.""" 

149 tensorboard = kwargs.pop("tensorboard", False) 

150 if tensorboard: 

151 warnings.warn( 

152 "Running in tensorboard mode writes a lot of stuff to tensorboard. For speed boost deactivate tensorboard mode." 

153 ) 

154 result = f(*args, **kwargs) 

155 else: 

156 compiled = tf.function(f) 

157 result = compiled(*args, **kwargs) 

158 return result 

159 

160 return wrapper 
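# A minimal usage sketch of the decorator (the function name `scaled_sum` is
# hypothetical, not part of encodermap). The popped `tensorboard` keyword
# switches between eager and compiled execution at call time:
#
#     @function
#     def scaled_sum(a, b):
#         return tf.reduce_sum(a * b)
#
#     scaled_sum(tf.ones((4,)), tf.ones((4,)))                    # compiled with tf.function
#     scaled_sum(tf.ones((4,)), tf.ones((4,)), tensorboard=True)  # eager, emits a warning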

161 

162 

163############################################################################## 

164# Public Classes 

165############################################################################## 

166 

167 

168class Autoencoder: 

169 """Main Autoencoder class preparing data, setting up the neural network and implementing training. 

170 

171 This is the main class for neural networks inside EncoderMap. The class prepares the data 

172 (batching and shuffling), creates a `tf.keras.Model` of layers specified by the attributes of 

173 the `encodermap.Parameters` class. Depending on what Parent/Child-Class is instantiated 

174 a combination of cost functions is set up. Callbacks to Tensorboard are also set up. 

175 

176 Attributes: 

177 train_data (np.ndarray): The numpy array of the train data passed at init. 

178 p (encodermap.Parameters): An `encodermap.Parameters()` class containing all info needed to set 

179 up the network. 

180 dataset (tensorflow.data.Dataset): The dataset that is actually used in training the keras model. The dataset 

181 is a batched, shuffled, infinitely-repeating dataset. 

182 read_only (bool): Variable telling the class whether it is allowed to write to disk (False) or not (True). 

183 optimizer (tf.keras.optimizers.Adam): Instance of the Adam optimizer with learning rate specified by 

184 the Parameters class. 

185 metrics (list): A list of metrics passed to the model when it is compiled. 

186 callbacks (list): A list of tf.keras.callbacks.Callback Sub-classes changing the behavior of the model during 

187 training. Some standard callbacks are always present like: 

188 * encodermap.callbacks.callbacks.ProgressBar: 

189 A progress bar callback using tqdm giving the current progress of training and the 

190 current loss. 

191 * CheckpointSaver: 

192 A callback that saves the model every parameters.checkpoint_step steps into 

193 the main directory. This callback will only be used, when `read_only` is False. 

194 * TensorboardWriteBool: 

195 A callback that contains a boolean Tensor that will be True or False, 

196 depending on the current training step and the summary_step in the parameters class. The loss 

197 functions use this callback to decide whether they should write to Tensorboard. This callback 

198 will only be present, when `read_only` is False and `parameters.tensorboard` is True. 

199 You can append your own callbacks to this list before executing Autoencoder.train(). 

200 encoder (tf.keras.models.Model): The encoder (sub)model of `model`. 

201 decoder (tf.keras.models.Model): The decoder (sub)model of `model`. 

202 

203 Methods: 

204 from_checkpoint: Rebuild the model from a checkpoint. 

205 add_images_to_tensorboard: Make tensorboard plot images. 

206 train: Starts the training of the tf.keras.models.Model. 

207 plot_network: Tries to plot the network. For this method to work graphviz, pydot and pydotplus needs to be installed. 

208 encode: Takes high-dimensional data and sends it through the encoder. 

209 decode: Takes low-dimensional data and sends it through the decoder. 

210 generate: Same as decode. For the AngleDihedralCartesianEncoderMap class this will build a protein structure. 

211 

212 Note: 

213 Performance of tensorflow is not only dependent on your system's hardware and how the data is presented to 

214 the network (for this check out https://www.tensorflow.org/guide/data_performance), but also how you compiled 

215 tensorflow. Normal tensorflow (pip install tensorflow) is built without CPU extensions to work on many CPUs. 

216 However, Tensorflow can greatly benefit from using CPU instructions like AVX2, AVX512 that bring a speed-up 

217 in linear algebra computations of 300%. By building tensorflow from source you can activate these extensions. 

218 However, the CPU speed-up is dwarfed by the speed-up when you allow tensorflow to run on your GPU (graphics 

219 card). To check whether a GPU is available run: 

220 `print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))`. 

221 Refer to these pages to install tensorflow for best performance: 

222 https://www.tensorflow.org/install/pip, https://www.tensorflow.org/install/gpu 

223 

224 Examples: 

225 >>> import encodermap as em 

226 >>> # without providing any data, default parameters and a 4D hypercube as input data will be used. 

227 >>> e_map = em.EncoderMap(read_only=True) 

228 >>> print(e_map.train_data.shape) 

229 (16000, 4) 

230 >>> print(e_map.dataset) 

231 <BatchDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name=None), TensorSpec(shape=(None, 4), dtype=tf.float32, name=None))> 

232 >>> print(e_map.encode(e_map.train_data).shape) 

233 (16000, 2) 

234 

235 """ 

236 

237 def __init__( 

238 self, 

239 parameters=None, 

240 train_data: Optional[Union[np.ndarray, tf.data.Dataset]] = None, 

241 model=None, 

242 read_only=False, 

243 sparse=False, 

244 ): 

245 """Instantiate the Autoencoder class. 

246 

247 Args: 

248 parameters (Union[encodermap.Parameters, None], optional): The parameters to be used. If None is 

249 provided default values (check them with print(em.Parameters.defaults_description())) 

250 are used. Defaults to None. 

251 train_data (Union[np.ndarray, tf.data.Dataset, None], optional): The train data. Can be one of the following: 

252 * None: If None is provided points on the edges of a 4-dimensional hypercube will be used as train data. 

253 * np.ndarray: If a numpy array is provided, it will be transformed into a batched tf.data.Dataset by 

254 first making it an infinitely repeating dataset, shuffling it and then batching it with a batch 

255 size specified by parameters.batch_size. 

256 * tf.data.Dataset: If a dataset is provided it will be used without making any adjustments. Make 

257 sure that the dataset uses `float32` as its type. 

258 Defaults to None. 

259 model (Union[tf.keras.models.Model, None], optional): Providing a keras model to this argument will make 

260 the Autoencoder/EncoderMap class use this model instead of the predefined ones. Make sure the model 

261 can accept EncoderMap's loss functions. If None is provided the model will be built using 

262 the specifications in parameters. Defaults to None. 

263 read_only (bool, optional): Whether the class is allowed to write to disk (False) or not (True). Defaults 

264 to False and will allow the class to write to disk. 

sparse (bool, optional): Whether to use sparse inputs. If nans are found in a provided `np.ndarray`, this is set to True automatically. Defaults to False.

265 

266 Raises: 

267 BadError: When read_only is `True` and `parameters.tensorboard` is `True`, this Exception will be raised, 

268 because they are mutually exclusive. 

269 

270 """ 

271 # parameters 

272 if parameters is None: 

273 self.p = Parameters() 

274 else: 

275 self.p = parameters 

276 

277 if self.p.seed is not None: 

278 tf.random.set_seed(self.p.seed) 

279 self.read_only = read_only 

280 

281 if not self.read_only: 

282 self.p.save() 

283 print( 

284 "Output files are saved to {}".format(self.p.main_path), 

285 "as defined in 'main_path' in the parameters.", 

286 ) 

287 

288 # check whether Tensorboard and Read-Only makes Sense 

289 if self.read_only and self.p.tensorboard: 

290 raise BadError( 

291 "Setting tensorboard and read_only True is not possible. Tensorboard will always write to disk." 

292 " If you received this Error while loading a trained model, pass read_only=False as an argument" 

293 f" or set overwrite_tensorboard_bool True to overwrite the tensorboard parameter." 

294 ) 

295 

296 # clear old sessions 

297 tf.keras.backend.clear_session() 

298 self.sparse = sparse 

299 

300 # set up train_data 

301 if train_data is None: 

302 self.train_data = create_n_cube(4, seed=self.p.seed)[0].astype("float32") 

303 self.p.periodicity = float("inf") 

304 elif isinstance(train_data, np.ndarray): 

305 if np.any(np.isnan(train_data)): 

306 self.sparse = True 

307 print("Input contains nans. Using sparse network.") 

308 indices = np.stack(np.where(~np.isnan(train_data))).T.astype("int64") 

309 dense_shape = train_data.shape 

310 values = train_data[~np.isnan(train_data)].flatten().astype("float32") 

311 sparse_tensor = tf.sparse.SparseTensor(indices, values, dense_shape) 

312 self.train_data = sparse_tensor 

313 else: 

314 self.train_data = train_data.astype("float32") 

315 elif isinstance(train_data, tf.data.Dataset): 

316 self.dataset = train_data 

317 try: 

318 for _, __ in self.dataset: 

319 break 

320 except ValueError: 

321 if self.p.training == "auto": 

322 print( 

323 f"It seems like your dataset only yields tensors and not " 

324 f"tuples of tensors. Tensorlfow is optimized for classification " 

325 f"tasks, where datasets yield tuples of (data, classes). EncoderMap," 

326 f"however is a regression task, but uses the same code as the " 

327 f"classification tasks. I will transform your dataset using " 

328 f"the `tf.data.Dataset.zip()` function of `tf.data`. You can " 

329 f"set the `training` parameter in the parameter class to " 

330 f"'custom' to not alter your dataset." 

331 ) 

332 self.dataset = tf.data.Dataset.zip((self.dataset, self.dataset)) 

333 for _, __ in self.dataset: 

334 break 

335 else: 

336 for _ in self.dataset: 

337 break 

338 self.train_data = _ 

339 else: 

340 raise TypeError( 

341 f"train_data must be `None`, `np.ndarray` or `tf.data.Dataset`. You supplied {type(train_data)}." 

342 ) 

343 

344 # check data and periodicity 

345 if not self.sparse and train_data is not None: 

346 if np.any(train_data > self.p.periodicity): 

347 raise Exception( 

348 "There seems to be an error regarding the periodicity " 

349 f"of your data. The chosen periodicity is {self.p.periodicity}, " 

350 f"but there are datapoints outwards of this range: {train_data.max()}" 

351 ) 

352 

353 # prepare the data 

354 if isinstance(self.train_data, (np.ndarray, tf.sparse.SparseTensor)): 

355 if self.p.training == "auto": 

356 dataset = tf.data.Dataset.from_tensor_slices( 

357 (self.train_data, self.train_data) 

358 ) 

359 else: 

360 dataset = tf.data.Dataset.from_tensor_slices(self.train_data) 

361 dataset = dataset.shuffle( 

362 buffer_size=self.train_data.shape[0], reshuffle_each_iteration=True 

363 ) 

364 dataset = dataset.repeat() 

365 self.dataset = dataset.batch(self.p.batch_size) 

366 else: 

367 pass 

368 

369 # ToDo: Make training faster with Autotune, XLA (jit) compilation, DataRecords 

370 # self.dataset = self.dataset.prefetch(self.p.batch_size * 4) 

371 # self.dataset = self.dataset.interleave(num_parallel_calls=tf.data.experimental.AUTOTUNE) 

372 # num_batches = 20 

373 # num_threads = 8 

374 # if self.p.training == 'auto': 

375 # ds = tf.data.Dataset.from_tensors((self.train_data, self.train_data)).repeat(self.p.batch_size * num_batches) 

376 # else: 

377 # ds = tf.data.Dataset.from_tensors(self.train_data).repeat(self.p.batch_size * num_batches) 

378 # # ds = tf.data.Dataset.from_tensors(self.train_data).repeat(self.p.batch_size * num_batches) 

379 # ds = ds.batch(self.p.batch_size) 

380 # self.dataset = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads) 

381 

382 # create model based on user input 

383 if model is None: 

384 self.model = self.p.model_api 

385 else: 

386 self._model = model 

387 

388 # setup callbacks for nice progress bars and saving every now and then 

389 self._setup_callbacks() 

390 

391 # create loss based on user input 

392 self.loss = self.p.loss 

393 

394 # choose optimizer 

395 self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.p.learning_rate) 

396 

397 # compile model 

398 self.model.compile( 

399 optimizer=self.optimizer, loss=self.loss, metrics=self.metrics 

400 ) 

401 

402 # do this if tensorboard is true. 

403 if self.p.tensorboard: 

404 self._log_images = False 

405 # get the output from model summary. 

406 with Capturing() as output: 

407 self.model.summary() 

408 with open(self.p.main_path + "/model_summary.txt", "w") as f: 

409 f.write("\n".join(output)) 

410 tf.keras.utils.plot_model( 

411 self.model, 

412 to_file=self.p.main_path + "/model_summary.png", 

413 show_shapes=True, 

414 rankdir="LR", 

415 expand_nested=True, 

416 ) 

417 print( 

418 f"Saved a text-summary of the model and an image in {self.p.main_path},", 

419 "as specified in 'main_path' in the parameters.", 

420 ) 

421 

422 # sets up the tb callback to plot the model 

423 self.tb_callback = tf.keras.callbacks.TensorBoard( 

424 self.p.main_path, write_graph=True 

425 ) 

426 self.tb_callback.set_model(self.model) 

427 

428 def _setup_callbacks(self): 

429 """Sets up a list with callbacks to be passed to self.model.fit()""" 

430 self.metrics = [] 

431 self.callbacks = [] 

432 self.callbacks.append(ProgressBar(parameters=self.p)) 

433 if not self.read_only: 

434 self.callbacks.append(CheckpointSaver(self.p)) 

435 if self.p.tensorboard: 

436 self.tensorboard_write_bool = TensorboardWriteBool(self.p) 

437 self.callbacks.append(self.tensorboard_write_bool) 

438 file_writer = tf.summary.create_file_writer(self.p.main_path + "/train") 

439 file_writer.set_as_default() 

440 tf.summary.text( 

441 name=f"Parameters Summary for {self.p.main_path}", 

442 data=self.p.parameters, 

443 step=0, 

444 ) 

445 # callbacks.append(self.tb_callback) 

446 else: 

447 self.tensorboard_write_bool = None 

448 

449 @classmethod 

450 def from_checkpoint( 

451 cls, 

452 checkpoint_path, 

453 read_only=True, 

454 overwrite_tensorboard_bool=False, 

455 sparse=False, 

456 ): 

457 """Reconstructs the class from a checkpoint. 

458 

459 Args: 

460 checkpoint_path (str): The path to the checkpoint. Most models are saved in parts (encoder, decoder) 

461 and thus the provided path often needs a wildcard (*). The `save()` method of this class prints 

462 a string with which the model can be reloaded. 

463 read_only (bool, optional): Whether to reload the model in read_only mode (True) or allow the `Autoencoder` 

464 class to write to disk (False). This option might collide with the tensorboard Parameter in the 

465 respective parameters.json file in the main_path. Defaults to True. 

466 overwrite_tensorboard_bool (bool, optional): Whether to overwrite the tensorboard Parameter while reloading 

467 the class. This can be set to True to set the tensorboard parameter False and allow read_only. 

468 Defaults to False. 

469 

470 Raises: 

471 BadError: When read_only is True, overwrite_tensorboard_bool is False and the reloaded parameters 

472 have tensorboard set to True. 

473 

474 Returns: 

475 Autoencoder: Encodermap `Autoencoder` class. 

476 

477 """ 

478 return load_model( 

479 cls, checkpoint_path, read_only, overwrite_tensorboard_bool, sparse=sparse 

480 ) 
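# A minimal reload sketch (the wildcard path below is hypothetical; the exact
# string to use is printed by this class' `save()` method):
#
#     loaded = Autoencoder.from_checkpoint("runs/run0/saved_model_*.model")
#     lowd = loaded.encode(loaded.train_data)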

481 

482 @property 

483 def model(self): 

484 """tf.keras.models.Model: The tf.keras.Model model used for training.""" 

485 return self._model 

486 

487 @model.setter 

488 def model(self, model): 

489 """sets self.model according to `model_api` argument in self.parameters.""" 

490 if model == "functional": 

491 for d in self.dataset: 

492 break 

493 if any([isinstance(_, tf.sparse.SparseTensor) for _ in d]): 

494 self.sparse = True 

495 self._model = gen_functional_model(self.dataset, self.p, sparse=self.sparse) 

496 elif model == "sequential": 

497 if isinstance(self.train_data, tf.sparse.SparseTensor): 

498 self.sparse = True 

499 self._model = gen_sequential_model( 

500 self.train_data.shape[1], self.p, sparse=self.sparse 

501 ) 

502 elif model == "custom": 

503 raise NotImplementedError("No custom API currently supported") 

504 else: 

505 raise ValueError( 

506 f"API argument needs to be one of `functional`, `sequential`, `custom`. You provided '{model}'." 

507 ) 

508 

509 @property 

510 def encoder(self): 

511 """tf.keras.models.Model: Encoder part of the model.""" 

512 return self._model.encoder_model 

513 

514 @property 

515 def decoder(self): 

516 """tf.keras.models.Model: Decoder part of the model.""" 

517 return self._model.decoder_model 

518 

519 @property 

520 def loss(self): 

521 """(Union[list, string, function]): A list of loss functions passed to the model when it is compiled. 

522 When the main Autoencoder class is used and parameters.loss is 'emap_cost', this list consists of 

523 center_cost, regularization_cost, auto_cost. When the EncoderMap sub-class is used and parameters.loss is 

524 'emap_cost' distance_cost is added to the list. When parameters.loss is not 'emap_cost', the loss can either 

525 be a string ('mse') or a function; both are acceptable `loss` arguments when a keras model 

526 is compiled. 

527 

528 """ 

529 return self._loss 

530 

531 @loss.setter 

532 def loss(self, loss): 

533 """sets self.loss according to `loss` in self.parameters.""" 

534 if loss == "reconstruction_loss": 

535 self._loss = reconstruction_loss(self.model) 

536 elif loss == "emap_cost": 

537 self.auto_loss = auto_loss(self.model, self.p, self.tensorboard_write_bool) 

538 self.regularization_loss = regularization_loss( 

539 self.model, self.p, self.tensorboard_write_bool 

540 ) 

541 self.center_loss = center_loss( 

542 self.model, self.p, self.tensorboard_write_bool 

543 ) 

544 self._loss = [self.auto_loss, self.regularization_loss, self.center_loss] 

545 elif loss == "mse": 

546 self._loss = "mse" 

547 else: 

548 raise ValueError( 

549 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'." 

550 ) 

551 

552 def train(self): 

553 """Starts the training of the model.""" 

554 if self.p.training == "custom" and self.p.batched: 

555 raise NotImplementedError() 

556 elif self.p.training == "custom" and not self.p.batched: 

557 raise NotImplementedError() 

558 elif self.p.training == "auto": 

559 if self.p.tensorboard and self._log_images: 

560 # get the old backend because the Tensorboard Images callback will set 'Agg' 

561 old_backend = matplotlib.get_backend() 

562 # start_time = time.perf_counter() 

563 self.history = self.model.fit( 

564 self.dataset, 

565 batch_size=self.p.batch_size, 

566 epochs=self.p.n_steps, 

567 steps_per_epoch=1, 

568 verbose=0, 

569 callbacks=self.callbacks, 

570 ) 

571 # print("Execution time:", time.perf_counter() - start_time) 

572 else: 

573 raise ValueError( 

574 f"training argument needs to be `auto` or `custom`. You provided '{self.training}'." 

575 ) 

576 self.save(step=self.p.n_steps) 

577 # reset the backend. 

578 if self.p.tensorboard and self._log_images: 

579 matplotlib.use(old_backend) 

580 

581 def add_images_to_tensorboard( 

582 self, 

583 data=None, 

584 image_step=None, 

585 scatter_kws={"s": 20}, 

586 hist_kws={"bins": 50}, 

587 additional_fns=None, 

588 when="epoch", 

589 ): 

590 """Adds images to Tensorboard using the data in data and the ids in ids. 

591 

592 Args: 

593 data (Union[np.ndarray, list, None], optional): The input-data will be passed through the encoder 

594 part of the autoencoder. If None is provided a set of 10000 points from the provided 

595 train data will be taken. A list is needed for the functional API of the ADCAutoencoder, that takes 

596 a list of [angles, dihedrals, side_dihedrals]. Defaults to None. 

597 image_step (Union[int, None], optional): The interval in which to plot images to tensorboard. 

598 If None is provided, the update step will be the same as parameters.summary_step. Defaults to None. 

599 scatter_kws (dict, optional): A dict with items that matplotlib.pyplot.scatter() will accept. Defaults to 

600 {'s': 20}, which sets an appropriate size of scatter points for the size of datasets encodermap is 

601 usually used for. 

602 hist_kws (dict, optional): A dict with items that matplotlib.pyplot.hist2d() will accept. You can 

603 choose a colorbar here. Defaults to {'bins': 50} which sets an appropriate bin count for the 

604 size of datasets encodermap is usually used for. 

605 additional_fns (Union[list, None], optional): A list of functions that will accept the low-dimensional 

606 output of the autoencoder's latent/bottleneck layer and return a tf.Tensor that can be logged 

607 by `tf.summary.image()`. See the notebook 'writing_custom_images_to_tensorboard.ipynb' in 

608 tutorials/notebooks_customization for more info. If None is provided no additional functions will be 

609 used to plot to tensorboard. Defaults to None. 

610 when (str, optional): When to log the images can be either 'batch', then the images will be logged after 

611 every step during training, or 'epoch', then only after every image_step epoch the images will be 

612 written. Defaults to 'epoch'. 

613 

614 """ 

615 if not self.p.tensorboard: 

616 print( 

617 "Nothing is written to Tensorboard for this Model. Please change parameters.tensorboard to True." 

618 ) 

619 return 

620 if image_step is None: 

621 image_step = self.p.summary_step 

622 

623 self._log_images = True 

624 

625 # make a dataset for images 

626 if data is None: 

627 if isinstance(self.train_data, np.ndarray): 

628 data = self.train_data 

629 elif isinstance(self.train_data, list) or self.sparse: 

630 data = self.train_data 

631 else: 

632 data = list(self.dataset.take(int(10000 / self.p.batch_size))) 

633 data = np.stack(data)[:, 0, :].reshape(-1, self.train_data.shape[1]) 

634 else: 

635 if type(data) != type(self.train_data): 

636 raise Exception( 

637 f"Provided data has wrong type. Train data in this class is {type(self.train_data)}, provided data is {type(data)}" 

638 ) 

639 

640 self.callbacks.append( 

641 ImageCallback( 

642 data, 

643 image_step, 

644 scatter_kws=scatter_kws, 

645 hist_kws=hist_kws, 

646 additional_fns=additional_fns, 

647 when=when, 

648 ) 

649 ) 

650 if isinstance(data, (np.ndarray, tf.sparse.SparseTensor)): 

651 print( 

652 f"Logging images with {data.shape}-shaped data every {image_step} epochs to Tensorboard at {self.p.main_path}" 

653 ) 

654 else: 

655 print( 

656 f"Logging images with {[i.shape for i in data]}-shaped data every {image_step} epochs to Tensorboard at {self.p.main_path}" 

657 ) 

658 
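# A minimal usage sketch, assuming `parameters.tensorboard` was set to True
# before instantiation:
#
#     e_map = em.EncoderMap(em.Parameters(tensorboard=True), train_data)
#     e_map.add_images_to_tensorboard(image_step=10, scatter_kws={"s": 5})
#     e_map.train()  # latent-space images are now logged along with the losses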

659 def plot_network(self): 

660 """Tries to plot the network using pydot, pydotplus and graphviz. Doesn't raise an exception if plotting is 

661 not possible. 

662 

663 Note: 

664 Refer to this guide to install these programs: 

665 https://stackoverflow.com/questions/47605558/importerror-failed-to-import-pydot-you-must-install-pydot-and-graphviz-for-py 

666 

667 """ 

668 try: 

669 plot_model(self.model, self.train_data.shape[1]) 

670 except Exception: 

671 pass 

672 

673 def encode(self, data=None): 

674 """Calls encoder part of model. 

675 

676 Args: 

677 data (Union[np.ndarray, None], optional): The data to be passed to the encoder part. 

678 Can be either numpy ndarray or None. If None is provided a set of 10000 points from the provided 

679 train data will be taken. Defaults to None. 

680 

681 Returns: 

682 np.ndarray: The output from the bottleneck/latent layer. 

683 

684 """ 

685 if data is None: 

686 data = self.train_data 

687 if hasattr(self.model, "encoder"): 

688 out = self.model.encoder(data) 

689 elif hasattr(self.model, "encoder_model"): 

690 out = self.model.encoder_model(data) 

691 if isinstance(out, list): 

692 out = [o.numpy() for o in out] 

693 else: 

694 out = out.numpy() 

695 return out 

696 

697 def generate(self, data): 

698 """Duplication of decode. 

699 

700 In Autoencoder and EncoderMap this method is equivalent to `decode()`. In AngleDihedralCartesianEncoderMap 

701 this method will be overwritten to produce output molecular conformations. 

702 

703 Args: 

704 data (np.ndarray): The data to be passed to the decoder part of the model. Make sure that the 

705 shape of the data matches the number of neurons in the latent space. 

706 

707 Returns: 

708 np.ndarray: The output from the decoder part. 

709 

710 """ 

711 return self.model.decoder(data) 

712 

713 def decode(self, data): 

714 """Calls the decoder part of the model. 

715 

716 AngleDihedralCartesianEncoderMap will, unlike the other two classes, output a tuple of data. 

717 

718 Args: 

719 data (np.ndarray): The data to be passed to the decoder part of the model. Make sure that the 

720 shape of the data matches the number of neurons in the latent space. 

721 

722 Returns: 

723 np.ndarray: The output from the decoder part. 

724 """ 

725 out = self.decoder(data) 

726 if isinstance(out, list): 

727 out = [o.numpy() for o in out] 

728 else: 

729 out = out.numpy() 

730 return out 

731 

732 def save(self, step=None): 

733 """Saves the model to the current path defined in `parameters.main_path`. 

734 

735 Args: 

736 step (Union[int, None], optional): Does not actually save the model at the given training step, but rather 

737 changes the string used for saving the model from a datetime format to another. 

738 

739 """ 

740 if not self.read_only: 

741 save_model(self.model, self.p.main_path, self.__class__.__name__, step=step) 

742 

743 def close(self): 

744 """Clears the current keras backend and frees up resources.""" 

745 # clear old sessions 

746 tf.keras.backend.clear_session() 

747 

748 

749class EncoderMap(Autoencoder): 

750 """Complete copy of Autoencoder class but uses additional distance cost 

751 scaled by the SketchMap sigmoid params""" 

752 

753 @classmethod 

754 def from_checkpoint( 

755 cls, 

756 checkpoint_path, 

757 read_only=True, 

758 overwrite_tensorboard_bool=False, 

759 sparse=False, 

760 ): 

761 """Reconstructs the model from a checkpoint.""" 

762 # Is this classmethod necessary? We need to make sure the class knows all losses. 

763 # And I don't know if the parent class calls the correct loss.setter 

764 return load_model( 

765 cls, checkpoint_path, read_only, overwrite_tensorboard_bool, sparse=sparse 

766 ) 

767 

768 @Autoencoder.loss.setter 

769 def loss(self, loss): 

770 if loss == "reconstruction_loss": 

771 self._loss = reconstruction_loss(self.model) 

772 elif loss == "emap_cost": 

773 self.auto_loss = auto_loss(self.model, self.p, self.tensorboard_write_bool) 

774 self.regularization_loss = regularization_loss( 

775 self.model, self.p, self.tensorboard_write_bool 

776 ) 

777 self.center_loss = center_loss( 

778 self.model, self.p, self.tensorboard_write_bool 

779 ) 

780 # this cost is new 

781 self.distance_loss = distance_loss( 

782 self.model, self.p, self.tensorboard_write_bool 

783 ) 

784 self._loss = [ 

785 self.auto_loss, 

786 self.regularization_loss, 

787 self.center_loss, 

788 self.distance_loss, 

789 ] 

790 elif loss == "mse": 

791 self._loss = "mse" 

792 else: 

793 raise ValueError( 

794 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'." 

795 ) 

796 

797 

798class DihedralEncoderMap(EncoderMap): 

799 """Similar to the `EncoderMap` class, but overwrites the `generate` method. 

800 

801 Using this class instead of the `EncoderMap` class, the `generate` method 

802 needs an additional argument: `top`, which should be a topology file. This 

803 topology will be used as a base on which the dihedrals of the `decode` 

804 method are applied. 

805 

806 """ 

807 

808 def generate(self, data: np.ndarray, top: str) -> mda.Universe: 

809 """Overwrites `EncoderMap`'s generate method and actually does backmapping if a list of dihedrals is 

810 provided. 

811 

812 Args: 

813 data (np.ndarray): The low-dimensional/latent/bottleneck data. A ndim==2 numpy array with xy coordinates 

814 of points in latent space. 

815 top (str): Topology file for this run of EncoderMap (can be .pdb, .gro, .. etc.). 

816 

817 Returns: 

818 MDAnalysis.Universe: A universe of the provided topology with the generated backbone torsions applied. 

819 

820 Examples: 

821 >>> # get some time-resolved pdb files 

822 >>> import requests 

823 >>> import numpy as np 

824 >>> pdb_link = 'https://files.rcsb.org/view/1YUF.pdb' 

825 >>> contents = requests.get(pdb_link).text 

826 >>> print(contents.splitlines()[0]) # doctest: +SKIP 

827 HEADER GROWTH FACTOR 01-APR-96 1YUF 

828 >>> # fake a file with stringio 

829 >>> from io import StringIO 

830 >>> import MDAnalysis as mda 

831 >>> import numpy as np 

832 >>> file = StringIO(contents) 

833 >>> # pass it to MDAnalysis 

834 >>> u = mda.Universe(file, format='PDB') 

835 >>> print(u) 

836 <Universe with 720 atoms> 

837 >>> # select the atomgroups 

838 >>> ags = [*[res.psi_selection() for res in u.residues], 

839 ... *[res.omega_selection() for res in u.residues], 

840 ... *[res.phi_selection() for res in u.residues] 

841 ... ] 

842 >>> # filter Nones 

843 >>> ags = list(filter(lambda x: False if x is None else True, ags)) 

844 >>> print(ags[0][0]) # doctest: +SKIP 

845 <Atom 3: C of type C of resname VAL, resid 1 and segid A and altLoc > 

846 >>> # Run dihedral Angles 

847 >>> from MDAnalysis.analysis.dihedrals import Dihedral 

848 >>> R = np.deg2rad(Dihedral(ags).run().results.angles) 

849 >>> print(R.shape) 

850 (16, 147) 

851 >>> # import EncoderMap and define parameters 

852 >>> from encodermap.autoencoder import DihedralEncoderMap 

853 >>> import encodermap as em 

854 >>> parameters = em.Parameters( 

855 ... dist_sig_parameters = (4.5, 12, 6, 1, 2, 6), 

856 ... periodicity = 2*np.pi, 

857 ... l2_reg_constant = 10.0, 

858 ... summary_step = 5, 

859 ... tensorboard = False, 

860 ... ) 

861 >>> e_map = DihedralEncoderMap(parameters, R, read_only=True) 

862 >>> print(e_map.__class__.__name__) 

863 DihedralEncoderMap 

864 >>> # get some low-dimensional data 

865 >>> lowd = np.random.random((100, 2)) 

866 >>> # use the generate method to get a new MDAnalysis universe 

867 >>> # but first remove the time resolution 

868 >>> file = StringIO(contents.split('MODEL 2')[0]) 

869 >>> new = e_map.generate(lowd, file) 

870 >>> print(new.trajectory.coordinate_array.shape) 

871 (100, 720, 3) 

872 >>> # check whether frame 0 of u and new_u are different 

873 >>> for ts in u.trajectory: 

874 ... a1 = ts.positions 

875 ... break 

876 >>> print(np.array_equal(a1, new.trajectory.coordinate_array[0])) 

877 False 

878 

879 """ 

880 assert np.any(data) 

881 dihedrals = self.decode(data) 

882 assert np.any(dihedrals) 

883 uni = dihedral_backmapping(top, dihedrals) 

884 return uni 

885 

886 

887class AngleDihedralCartesianEncoderMap(Autoencoder): 

888 """Different `__init__` method, than Autoencoder Class. Uses callbacks to tune-in cartesian cost. 

889 

890 Overwritten methods: `_setup_callbacks` and `generate`. 

891 

892 Examples: 

893 >>> import encodermap as em 

894 >>> # Load two trajectories 

895 >>> xtcs = ["tests/data/1am7_corrected_part1.xtc", "tests/data/1am7_corrected_part2.xtc"] 

896 >>> tops = ["tests/data/1am7_protein.pdb", "tests/data/1am7_protein.pdb"] 

897 >>> trajs = em.load(xtcs, tops) 

898 >>> print(trajs) 

899 encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajs. Not containing any CVs. 

900 >>> # load CVs 

901 >>> # This step can be omitted. The AngleDihedralCartesianEncoderMap class automatically loads CVs 

902 >>> trajs.load_CVs('all') 

903 >>> print(trajs.CVs['central_cartesians'].shape) 

904 (51, 474, 3) 

905 >>> print(trajs.CVs['central_dihedrals'].shape) 

906 (51, 471) 

907 >>> # create some parameters 

908 >>> p = em.ADCParameters(periodicity=360, use_backbone_angles=True, use_sidechains=True, 

909 ... cartesian_cost_scale_soft_start=(6, 12)) 

910 >>> # Standard is functional model, as it offers more flexibility 

911 >>> print(p.model_api) 

912 functional 

913 >>> print(p.distance_cost_scale) 

914 None 

915 >>> # Instantiate the class 

916 >>> e_map = em.AngleDihedralCartesianEncoderMap(trajs, p, read_only=True) 

917 >>> # dataset contains these inputs: 

918 >>> # central_angles, central_dihedrals, central_cartesians, central_distances, sidechain_dihedrals 

919 >>> print(e_map.dataset) 

920 <BatchDataset element_spec=(TensorSpec(shape=(None, 472), dtype=tf.float32, name=None), TensorSpec(shape=(None, 471), dtype=tf.float32, name=None), TensorSpec(shape=(None, 474, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 473), dtype=tf.float32, name=None), TensorSpec(shape=(None, 316), dtype=tf.float32, name=None))> 

921 >>> # output from the model contains the following data: 

922 >>> # out_angles, out_dihedrals, back_cartesians, pairwise_distances of inp cartesians, pairwise of back-mapped cartesians, out_side_dihedrals 

923 >>> for data in e_map.dataset.take(1): 

924 ... pass 

925 >>> out = e_map.model(data) 

926 >>> print([i.shape for i in out]) 

927 [TensorShape([256, 472]), TensorShape([256, 471]), TensorShape([256, 474, 3]), TensorShape([256, 112101]), TensorShape([256, 112101]), TensorShape([256, 316])] 

928 >>> # get output of latent space by providing central_angles, central_dihedrals, side_dihedrals 

929 >>> latent = e_map.encoder([data[0], data[1], data[-1]]) 

930 >>> print(latent.shape) 

931 (256, 2) 

932 >>> # Rebuild central_angles, central_dihedrals and side_dihedrals from latent 

933 >>> ang, dih, side_dih = e_map.decode(latent) 

934 >>> print(ang.shape, dih.shape, side_dih.shape) 

935 (256, 472) (256, 471) (256, 316) 

936 

937 """ 

938 

939 def __init__( 

940 self, 

941 trajs: encodermap.TrajEnsemble, 

942 parameters: Optional[encodermap.ADCParameters] = None, 

943 model: Optional[tensorflow.keras.Model] = None, 

944 read_only: bool = False, 

945 cartesian_loss_step: int = 0, 

946 top: Optional[md.Topology] = None, 

947 ) -> None: 

948 """Instantiate the `AngleDihedralCartesianEncoderMap` class. 

949 

950 Args: 

951 trajs (em.TrajEnsemble): The trajectories to be used as input. If trajs contain no CVs, correct CVs will be loaded. 

952 parameters (Optional[em.ADCParameters]): The parameters for the current run. Can be set to None and the 

953 default parameters will be used. Defaults to None. 

954 model (Optional[tf.keras.models.Model]): The keras model to use. You can provide your own model 

955 with this argument. If set to None, the model will be built to the specifications of parameters using 

956 either the functional or sequential API. Defaults to None 

957 read_only (bool): Whether to write anything to disk (False) or not (True). Defaults to False. 

958 cartesian_loss_step (int, optional): For loading and re-training the model. The cartesian_distance_loss 

959 is tuned in step-wise. For this the start step of the training needs to be accounted for. If the 

960 scale of the cartesian loss should increase from epoch 6 to epoch 12 and the model is saved at 

961 epoch 9, this argument should also be set to 9, to continue training with the correct scaling 

962 factor. Defaults to 0. 

top (Optional[md.Topology]): The topology to build a `TrajEnsemble` from, when `trajs` is a single trajectory file path. Defaults to None.

963 

964 """ 

965 # parameters 

966 if parameters is None: 

967 self.p = ADCParameters() 

968 else: 

969 self.p = parameters 

970 

971 # seed 

972 if self.p.seed is not None: 

973 tf.random.set_seed(self.p.seed) 

974 

975 # read_only 

976 self.read_only = read_only 

977 

978 # will be saved and overwritten when loading. 

979 self.cartesian_loss_step = cartesian_loss_step 

980 

981 # save params and create dir 

982 if not self.read_only: 

983 self.p.save() 

984 print( 

985 "Output files are saved to {}".format(self.p.main_path), 

986 "as defined in 'main_path' in the parameters.", 

987 ) 

988 

989 # check whether Tensorboard and Read-Only makes Sense 

990 if self.read_only and self.p.tensorboard: 

991 raise BadError( 

992 "Setting tensorboard and read_only True is not possible. Tensorboard will always write to disk." 

993 " If you received this Error while loading a trained model, pass read_only=False as an argument" 

994 f" or set overwrite_tensorboard_bool True to overwrite the tensorboard parameter." 

995 ) 

996 

997 # clear old sessions 

998 tf.keras.backend.clear_session() 

999 

1000 # get the CVs: 

1001 if isinstance(trajs, str): 

1002 self.trajs = TrajEnsemble([trajs], [top]) 

1003 else: 

1004 self.trajs = trajs 

1005 

1006 # load missing values 

1007 should_be = set( 

1008 [ 

1009 "central_angles", 

1010 "central_cartesians", 

1011 "central_dihedrals", 

1012 "central_distances", 

1013 "side_dihedrals", 

1014 ] 

1015 ) 

1016 

1017 if self.trajs.CVs_in_file: 

1018 raise NotImplementedError( 

1019 "Write a tf.data.Dataset.from_generator function in enocdermap.data using the data from the netCDF files" 

1020 ) 

1021 elif self.trajs.CVs: 

1022 missing = list(should_be - set(trajs.CVs.keys())) 

1023 if missing != []: 

1024 print("loading missing values: ", missing) 

1025 self.trajs.load_CVs(missing, ensemble=False) 

1026 else: 

1027 self.trajs.load_CVs(list(should_be), ensemble=False) 

1028 

1029 if not should_be - set(self.trajs.CVs.keys()) == set(): 

1030 raise BadError( 

1031 f"Could not load CVs. Should be {should_be}, but currenlty only {set(trajs.CVs.keys())} are loaded" 

1032 ) 

1033 

1034 # define inputs 

1035 self.sparse, self.train_data, self.inp_CV_data = self.get_train_data_from_trajs( 

1036 self.trajs, self.p 

1037 ) 

1038 

1039 # create dataset 

1040 dataset = tf.data.Dataset.from_tensor_slices( 

1041 ( 

1042 self.inp_CV_data["central_angles"], 

1043 self.inp_CV_data["central_dihedrals"], 

1044 self.inp_CV_data["central_cartesians"], 

1045 self.inp_CV_data["central_distances"], 

1046 self.inp_CV_data["side_dihedrals"], 

1047 ) 

1048 ) 

1049 dataset = dataset.shuffle( 

1050 buffer_size=self.inp_CV_data["central_cartesians"].shape[0], 

1051 reshuffle_each_iteration=True, 

1052 ) 

1053 dataset = dataset.repeat() 

1054 self.dataset = dataset.batch(self.p.batch_size) 

1055 

1056 # ToDo: Make training faster with Autotune, XLA (jit) compilation, DataRecords 

1057 # self.dataset = self.dataset.prefetch(self.p.batch_size * 4) 

1058 # self.dataset = self.dataset.interleave(num_parallel_calls=tf.data.experimental.AUTOTUNE) 

1059 # num_batches = 20 

1060 # num_threads = 8 

1061 # ds = tf.data.Dataset.from_tensors(self.train_data).repeat(self.p.batch_size * num_batches) 

1062 # ds = ds.batch(self.p.batch_size) 

1063 # self.dataset = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads) 

1064 

1065 # create model based on user input 

1066 if model is None: 

1067 self.model = self.p.model_api 

1068 else: 

1069 self._model = model 

1070 

1071 # setup callbacks 

1072 self._setup_callbacks() 

1073 

1074 # create loss based on user input 

1075 self.loss = self.p.loss 

1076 

1077 # choose optimizer 

1078 self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.p.learning_rate) 

1079 

1080 # compile model 

1081 self.model.compile( 

1082 optimizer=self.optimizer, loss=self.loss, metrics=self.metrics 

1083 ) 

1084 

1085 # do this if tensorboard is true. 

1086 if self.p.tensorboard: 

1087 # print shapes 

1088 print("input shapes are:") 

1089 print({k: v.shape for k, v in self.inp_CV_data.items()}) 

1090 # set _log_images to False to fix the backend after training 

1091 self._log_images = False 

1092 # get the output from model summary. 

1093 with Capturing() as output: 

1094 self.model.summary() 

1095 with open(self.p.main_path + "/model_summary.txt", "w") as f: 

1096 f.write("\n".join(output)) 

1097 try: 

1098 tf.keras.utils.plot_model( 

1099 self.model, 

1100 to_file=self.p.main_path + "/model_summary.png", 

1101 show_shapes=True, 

1102 rankdir="TB", 

1103 expand_nested=True, 

1104 ) 

1105 except Exception as e: 

1106 print(f"saving image gave error: {e}") 

1107 # todo: add image of cat 

1108 # from ..parameters import parameters as _p 

1109 # cat_image = os.path.split(os.path.split(os.path.split(_p.__file__)[0])[0])[0] + '/pic/priscilla-du-preez-8NXmaXg5xL0-unsplash.jpg' 

1110 # image = plt.imread(cat_image) 

1111 # plt.imshow(image) 

1112 # print(cat_image) 

1113 print( 

1114 f"Saved a text-summary of the model and an image in {self.p.main_path},", 

1115 "as specified in 'main_path' in the parameters.", 

1116 ) 

1117 

1118 # sets up the tb callback to plot the model 

1119 self.tb_callback = tf.keras.callbacks.TensorBoard( 

1120 self.p.main_path, write_graph=True 

1121 ) 

1122 self.tb_callback.set_model(self.model) 

1123 

1124 @staticmethod 

1125 def get_train_data_from_trajs(trajs, p, attr="CVs"): 
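"""Builds the training data for the network from a `TrajEnsemble`.

Returns:
    tuple: A tuple of (sparse, train_data, inp_CV_data), where `sparse`
        indicates whether nan-containing inputs forced the use of sparse
        tensors, `train_data` is the network input (a single tensor or a
        list, depending on `p.use_backbone_angles` and `p.use_sidechains`),
        and `inp_CV_data` is the dict of all collective variables cast
        to float32.

"""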

1126 if not any([np.isnan(x).any() for x in getattr(trajs, attr).values()]): 

1127 inp_CV_data = { 

1128 key: val.astype("float32") for key, val in getattr(trajs, attr).items() 

1129 } 

1130 sparse = False 

1131 else: 

1132 sparse = True 

1133 print("Input contains nans. Using sparse network.") 

1134 inp_CV_data = { 

1135 key: val.astype("float32") for key, val in getattr(trajs, attr).items() 

1136 } 

1137 

1138 # squeeze, if xarray is provided 

1139 if all([hasattr(v, "values") for v in inp_CV_data.values()]): 

1140 inp_CV_data = {k: v.values.squeeze() for k, v in inp_CV_data.items()} 

1141 

1142 for k, v in inp_CV_data.items(): 

1143 if np.any(np.isnan(v)): 

1144 values = v 

1145 if k == "central_cartesians": 

1146 values = values.reshape(len(values), -1) 

1147 indices = np.stack(np.where(~np.isnan(values))).T.astype("int64") 

1148 dense_shape = values.shape 

1149 values = values[~np.isnan(values)].flatten() 

1150 sparse_tensor = tf.sparse.SparseTensor(indices, values, dense_shape) 

1151 inp_CV_data[k] = sparse_tensor 

1152 

1153 if not p.use_backbone_angles and not p.use_sidechains: 

1154 train_data = inp_CV_data["central_dihedrals"] 

1155 elif p.use_backbone_angles and not p.use_sidechains: 

1156 train_data = [ 

1157 inp_CV_data["central_angles"], 

1158 inp_CV_data["central_dihedrals"], 

1159 ] 

1160 if p.model_api == "sequential" and not sparse: 

1161 train_data = np.hstack(train_data) 

1162 elif p.use_backbone_angles and p.use_sidechains: 

1163 train_data = [ 

1164 inp_CV_data["central_angles"], 

1165 inp_CV_data["central_dihedrals"], 

1166 inp_CV_data["side_dihedrals"], 

1167 ] 

1168 if p.model_api == "sequential" and not sparse: 

1169 train_data = np.hstack(train_data) 

1170 else: 

1171 raise Exception( 

1172 "Cannot train model with central dihedrals and side dihedrals only. Backbone angles are required." 

1173 ) 

1174 

1175 # some checks for the length of the train data 

1176 if p.model_api == "functional": 

1177 if not p.use_backbone_angles and not p.use_sidechains: 

1178 assert isinstance(train_data, tf.sparse.SparseTensor) 

1179 elif p.use_backbone_angles and not p.use_sidechains: 

1180 assert len(train_data) == 2 

1181 else: 

1182 assert len(train_data) == 3 

1183 

1184 return sparse, train_data, inp_CV_data 

1185 

1186 @classmethod 

1187 def from_checkpoint( 

1188 cls, trajs, checkpoint_path, read_only=True, overwrite_tensorboard_bool=False 

1189 ): 

1190 """Reconstructs the model from a checkpoint.""" 

1191 # Is this classmethod necessary? We need to make sure the class knows all losses. 

1192 # And I don't know if the parent class calls the correct loss.setter 

1193 return load_model( 

1194 cls, checkpoint_path, read_only, overwrite_tensorboard_bool, trajs 

1195 ) 

1196 

1197 def _setup_callbacks(self) -> None: 

1198 """Overwrites the parent class' `_setup_callbacks` method. 

1199 

1200 Due to the 'soft start' of the cartesian cost, the `cartesian_increase_callback` 

1201 needs to be added to the list of callbacks. 

1202 

1203 """ 

1204 super()._setup_callbacks() 

1205 if self.p.cartesian_cost_scale_soft_start != (None, None): 

1206 self.cartesian_increase_callback = IncreaseCartesianCost( 

1207 self.p, start_step=self.cartesian_loss_step 

1208 ) 

1209 self.callbacks.append(self.cartesian_increase_callback) 
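# A sketch of the soft-start behavior (the exact ramp is defined by
# `IncreaseCartesianCost`; a gradual increase between the two steps is
# assumed here): with cartesian_cost_scale_soft_start=(6, 12), the
# cartesian cost scale is 0 before step 6, increases between steps 6
# and 12, and stays at `p.cartesian_cost_scale` afterwards.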

1210 

1211 def save(self, step: Optional[int] = None) -> None: 

1212 """Saves the model to the current path defined in `parameters.main_path`. 

1213 

1214 Args: 

1215 step (Optional[int]): Does not actually save the model at the given training step, but rather 

1216 changes the string used for saving the model from a datetime format to another. 

1217 

1218 """ 

1219 if not self.read_only: 

1220 save_model( 

1221 self.model, 

1222 self.p.main_path, 

1223 self.__class__.__name__, 

1224 step=step, 

1225 current_step=self.cartesian_loss_step, 

1226 ) 

1227 

1228 @Autoencoder.loss.setter 

1229 def loss(self, loss): 

1230 if loss == "reconstruction_loss": 

1231 self._loss = reconstruction_loss(self.model) 

1232 elif loss == "emap_cost": 

1233 self.dihedral_loss = dihedral_loss( 

1234 self.model, self.p, self.tensorboard_write_bool 

1235 ) 

1236 self.angle_loss = angle_loss( 

1237 self.model, self.p, self.tensorboard_write_bool 

1238 ) 

1239 if self.p.cartesian_cost_scale_soft_start != (None, None): 

1240 self.cartesian_loss = cartesian_loss( 

1241 self.model, 

1242 self.cartesian_increase_callback, 

1243 self.p, 

1244 self.tensorboard_write_bool, 

1245 ) 

1246 else: 

1247 self.cartesian_loss = cartesian_loss( 

1248 self.model, None, self.p, self.tensorboard_write_bool 

1249 ) 

1250 

1251 self.distance_loss = distance_loss( 

1252 self.model, self.p, self.tensorboard_write_bool 

1253 ) 

1254 self.cartesian_distance_loss = cartesian_distance_loss( 

1255 self.model, self.p, self.tensorboard_write_bool 

1256 ) 

1257 self.center_loss = center_loss( 

1258 self.model, self.p, self.tensorboard_write_bool 

1259 ) 

1260 self.regularization_loss = regularization_loss( 

1261 self.model, self.p, self.tensorboard_write_bool 

1262 ) 

1263 self._loss = [ 

1264 self.dihedral_loss, 

1265 self.angle_loss, 

1266 self.cartesian_loss, 

1267 self.distance_loss, 

1268 self.cartesian_distance_loss, 

1269 self.center_loss, 

1270 self.regularization_loss, 

1271 ] 

1272 if self.p.use_sidechains: 

1273 self.side_dihedral_loss = side_dihedral_loss( 

1274 self.model, self.p, self.tensorboard_write_bool 

1275 ) 

1276 self._loss.append(self.side_dihedral_loss) 

1277 elif loss == "mse": 

1278 self._loss = "mse" 

1279 else: 

1280 raise ValueError( 

1281 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'." 

1282 ) 

1283 

1284 def train(self) -> None: 

1285 """Overrides the parent class' `train` method. 

1286 

1287 After the training is finished, an additional file is written to disk, 

1288 which saves the current epoch. In the event that training will continue, 

1289 the current state of the soft-start cartesian cost is read from that file. 

1290 

1291 """ 

1292 super().train() 

1293 self.cartesian_loss_step += self.p.n_steps 

1294 fname = f"{self.p.main_path}/saved_model_{self.p.n_steps}.model" 

1295 with open(fname + "_current_step.txt", "w") as f: 

1296 f.write(str(self.cartesian_loss_step)) 

1297 

1298 def encode(self, data=None): 

1299 if hasattr(data, "_traj_file"): 

1300 _, data, __ = self.get_train_data_from_trajs(data, self.p, attr="_CVs") 

1301 elif hasattr(data, "traj_files"): 

1302 _, data, __ = self.get_train_data_from_trajs(data, self.p) 

1303 return super().encode(data) 

1304 

1305 def generate( 

1306 self, 

1307 points: np.ndarray, 

1308 top: Optional[Union[str, int, md.Topology]] = None, 

1309 backend: Literal["mdtraj", "mdanalysis"] = "mdtraj", 

1310 ) -> Union[mda.Universe, md.Trajectory]: 

1311 """Overrides the parent class' `generate` method and builds a trajectory. 

1312 

1313 Instead of just providing data to `decode` using the decoder part of the 

1314 network, this method also takes a molecular topology as its `top` 

1315 argument. This topology is then used to rebuild a time-resolved 

1316 trajectory. 

1317 

1318 Args: 

1319 points (np.ndarray): The low-dimensional points from which the 

1320 trajectory should be rebuilt. 

1321 top (Optional[Union[str, int, md.Topology]]): The topology to be used for rebuilding the 

1322 trajectory. This should be a string pointing towards a <*.pdb, 

1323 *.gro, *.h5> file. Alternatively, None can be provided, in which 

1324 case, the topology of this class' `trajs` attribute is used. 

1325 Defaults to None. 

1326 backend (str): Defines which MD python package to use to build the 

1327 trajectory (and thus what type this method returns); needs to be 

1328 one of the following: 

1329 * "mdtraj" 

1330 * "mdanalysis" 

1331 

1332 Returns: 

1333 Union[md.Trajectory, mda.Universe]: The trajectory after 

1334 applying the decoded structural information. The type of this 

1335 depends on the chosen `backend` parameter. 

1336 

1337 """ 

1338 # get the output; this can be done regardless of the topology 

1339 out = self.decode(points) 

1340 

1341 if top is None: 

1342 top = self.trajs.top_files 

1343 if len(top) > 1: 

1344 print( 

1345 f"Please specify which topology you would like to use for generating " 

1346 f"conformations. You can either provide a `str` to a topology file " 

1347 f"(file extension .pdb, .h5, .gro) on disk, or a `int` specifying the " 

1348 f"`SingleTraj` object in this class' {self.trajs.n_trajs} trajs, or " 

1349 f"you can also specify a `mdtraj.Topology` object." 

1350 ) 

1351 return 

1352 else: 

1353 top = top[0] 

1354 trajs = self.trajs 

1355 if top not in self.trajs.top_files: 

1356 raise Exception( 

1357 "Provided topology was not used to train Encodermap." 

1358 ) 

1359 

1360 # get the output 

1361 if not self.p.use_backbone_angles and not self.p.use_sidechains: 

1362 dihedrals = self.decode(points) 

1363 elif self.p.use_backbone_angles and not self.p.use_sidechains: 

1364 splits = [trajs.CVs["central_angles"].shape[1]] 

1365 out = self.decode(points) 

1366 if isinstance(out, np.ndarray): 

1367 angles, dihedrals = np.split(out, splits, axis=1) 

1368 elif self.p.use_backbone_angles and self.p.use_sidechains: 

1369 splits = [ 

1370 trajs.CVs["central_angles"].shape[1], 

1371 trajs.CVs["central_angles"].shape[1] 

1372 + trajs.CVs["central_dihedrals"].shape[1], 

1373 ] 

1374 

1375 if isinstance(out, np.ndarray): 

1376 angles, dihedrals, sidechain_dihedrals = np.array_split( 

1377 out, splits, axis=1 

1378 ) 

1379 else: 

1380 angles, dihedrals, sidechain_dihedrals = out 

1381 

1382 # in this case we can just use any traj from self.trajs 

1383 traj = self.trajs 

1384 

1385 else: 

1386 if len(self.trajs.top_files) == 1: 

1387 trajs = self.trajs 

1388 if top not in self.trajs.top_files: 

1389 raise Exception( 

1390 "Provided topology was not used to train Encodermap." 

1391 ) 

1392 else: 

1393 if isinstance(top, str): 

1394 pass 

1395 elif isinstance(top, int): 

1396 top_ = self.trajs[top].traj[0] 

1397 top_.save_pdb("/tmp/tmp.pdb") 

1398 top = "/tmp/tmp.pdb" 

1399 elif isinstance(top, mdtraj.Topology): 

1400 top.save_pdb("/tmp/tmp.pdb") 

1401 top = "/tmp/tmp.pdb" 

1402 else: 

1403 raise TypeError( 

1404 f"Provided type for `top` must be `str`, `int`, or `mdtraj.Topology`, " 

1405 f"you provided {type(top)}." 

1406 ) 

1407 

1408 # align the topology with the trajs in self.trajs 

1409 from ..loading import features 

1410 from ..loading.featurizer import UNDERSOCRE_MAPPING 

1411 

1412 UNDERSOCRE_MAPPING = {v: k for k, v in UNDERSOCRE_MAPPING.items()} 

1413 labels = {} 

1414 feature_names = [ 

1415 "CentralCartesians", 

1416 "CentralBondDistances", 

1417 "CentralAngles", 

1418 "CentralDihedrals", 

1419 "SideChainDihedrals", 

1420 ] 

1421 

1422 for feature in feature_names: 

1423 feature = getattr(features, feature)(top_.top, generic_labels=True) 

1424 labels[UNDERSOCRE_MAPPING[feature.name]] = feature.describe() 

1425 

1426 return_values = [ 

1427 "central_dihedrals", 

1428 "central_angles", 

1429 "side_dihedrals", 

1430 ] 

1431 splits = {} 

1432 for i, k in enumerate(return_values): 

1433 split = np.isin( 

1434 self.trajs[0]._CVs.coords[k.upper()].values, labels[k] 

1435 ) 

1436 splits[k] = split 

1437 

1438 # split the output 

1439 if not self.p.use_backbone_angles and not self.p.use_sidechains: 

1440 dihedrals = out[:, splits["central_dihedrals"]] 

1441 elif self.p.use_backbone_angles and not self.p.use_sidechains: 

1442 dihedrals = out[1][:, splits["central_dihedrals"]] 

1443 angles = out[2][:, splits["central_angles"]] 

1444 elif self.p.use_backbone_angles and self.p.use_sidechains: 

1445 dihedrals = out[1][:, splits["central_dihedrals"]] 

1446 angles = out[0][:, splits["central_angles"]] 

1447 sidechain_dihedrals = out[2][:, splits["side_dihedrals"]] 

1448 

1449 # if the backend is mdanalysis we need to save the pdb 

1450 if backend == "mdanalysis": 

1451 top_.save_pdb("/tmp/tmp.pdb") 

1452 top = "/tmp/tmp.pdb" 

1453 else: 

1454 # in this case we need to use a traj whose topology matches top 

1455 for i, traj in self.trajs.itertrajs(): 

1456 if traj.top == top_.top: 

1457 break 

1458 else: 

1459 raise Exception( 

1460 "Could not find a trajectory in self.trajs, " 

1461 "that matches the topology provided as `top`." 

1462 ) 

1463 traj = deepcopy(traj) 

1464 

1465 # do the backmapping 

1466 if backend == "mdanalysis": 

1467 uni = dihedral_backmapping(top, dihedrals, sidechains=sidechain_dihedrals) 

1468 return uni 

1469 elif backend == "mdtraj": 

1470 traj = mdtraj_backmapping(top, dihedrals, sidechain_dihedrals, traj) 

1471 return traj 

1472 else: 

1473 raise TypeError( 

1474 f"backend must be 'mdtraj' or 'mdanalysis', but you provided {backend}" 

1475 )
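# A minimal usage sketch of `generate` (the latent points are hypothetical;
# `e_map` is a trained AngleDihedralCartesianEncoderMap with a single topology):
#
#     import numpy as np
#     lowd = np.random.uniform(low=-5, high=5, size=(100, 2))
#     traj = e_map.generate(lowd, backend="mdtraj")  # returns an mdtraj.Trajectory
#     traj.save_xtc("generated.xtc")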