Coverage for encodermap/autoencoder/autoencoder.py: 8%

820 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2025-05-15 21:06 +0200

1# -*- coding: utf-8 -*- 

2# encodermap/autoencoder/autoencoder.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Forward facing Autoencoder classes. Contains four classes: 

23 

24* Autoencoder: Simple NN dense, fully connected AE architecture. Reg loss, auto loss and center loss 

25* EncoderMap: Uses the same architecture as `Autoencoder`, but adds another loss function. 

26* DihedralEncoderMap: Basically the same as `EncoderMap`, but rewrites the `generate` method to use 

27 an atomistic topology to rebuild a trajectory. 

28* AngleDihedralCartesianEncoderMap: Uses more loss functions and tries to learn a full all atom conformation. 

29 

30""" 

31 

32 

33################################################################################ 

34# Imports 

35################################################################################ 

36 

37 

38# Future Imports at the top 

39from __future__ import annotations 

40 

41# Standard Library Imports 

42import copy 

43import os 

44import warnings 

45from pathlib import Path 

46from tempfile import NamedTemporaryFile 

47 

48# Third Party Imports 

49import numpy as np 

50import tensorflow as tf 

51from optional_imports import _optional_import 

52from tqdm import tqdm 

53 

54# Encodermap imports 

55from encodermap.callbacks.callbacks import ( 

56 CheckpointSaver, 

57 ImageCallback, 

58 IncreaseCartesianCost, 

59 ProgressBar, 

60 TensorboardWriteBool, 

61) 

62from encodermap.callbacks.metrics import ADCClashMetric, ADCRMSDMetric 

63from encodermap.encodermap_tf1.backmapping import ( 

64 chain_in_plane, 

65 dihedrals_to_cartesian_tf, 

66) 

67from encodermap.loss_functions.loss_functions import ( 

68 angle_loss, 

69 auto_loss, 

70 cartesian_distance_loss, 

71 cartesian_loss, 

72 center_loss, 

73 dihedral_loss, 

74 distance_loss, 

75 reconstruction_loss, 

76 regularization_loss, 

77 side_dihedral_loss, 

78) 

79from encodermap.misc.backmapping import dihedral_backmapping, mdtraj_backmapping 

80from encodermap.misc.distances import pairwise_dist 

81from encodermap.misc.misc import create_n_cube, plot_model 

82from encodermap.misc.saving_loading_models import load_model, save_model 

83from encodermap.models.models import gen_functional_model, gen_sequential_model 

84from encodermap.parameters.parameters import ADCParameters, Parameters 

85from encodermap.trajinfo.info_all import TrajEnsemble 

86from encodermap.trajinfo.info_single import Capturing, SingleTraj 

87 

88 

89################################################################################ 

90# Optional Imports 

91################################################################################ 

92 

93 

94md = _optional_import("mdtraj") 

95mda = _optional_import("MDAnalysis") 

96 

97 

98################################################################################ 

99# Typing 

100################################################################################ 

101 

102 

103# Standard Library Imports 

104from collections.abc import Callable, Mapping 

105from typing import ( 

106 TYPE_CHECKING, 

107 Any, 

108 Literal, 

109 Optional, 

110 Sequence, 

111 Type, 

112 TypeVar, 

113 Union, 

114 overload, 

115) 

116 

117 

# Type variables used to annotate the four public autoencoder classes of this
# module in helper-function signatures.
# NOTE(review): every TypeVar is bound to the forward reference "Parent",
# which is not defined in the visible part of this file -- confirm that this
# name resolves for static type checkers.
AutoencoderType = TypeVar("AutoencoderType", bound="Parent")
EncoderMapType = TypeVar("EncoderMapType", bound="Parent")
DihedralEncoderMapType = TypeVar("DihedralEncoderMapType", bound="Parent")
AngleDihedralCartesianEncoderMapType = TypeVar(
    "AngleDihedralCartesianEncoderMapType", bound="Parent"
)
# Union alias accepted by the module-level helpers that take an
# `autoencoder` argument.
AutoencoderClass = Union[
    AutoencoderType,
    EncoderMapType,
    DihedralEncoderMapType,
    AngleDihedralCartesianEncoderMapType,
]

130 

131 

132if TYPE_CHECKING: 

133 # Third Party Imports 

134 from MDAnalysis import Universe 

135 from mdtraj import Topology, Trajectory 

136 

137 

138################################################################################ 

139# Globals 

140################################################################################ 

141 

142 

# Public names of this module (what `from ... import *` exports).
__all__: list[str] = [
    "Autoencoder",
    "EncoderMap",
    "AngleDihedralCartesianEncoderMap",
    "DihedralEncoderMap",
]

149 

150 

151################################################################################ 

152# Utils 

153################################################################################ 

154 

155 

def history_to_json_serializable(d: dict[str, Any]) -> dict[str, Any]:
    """Converts the numpy arrays in a history dict to plain python lists.

    Nested mappings are flattened: their (converted) items are merged into
    the returned top-level dict, later keys overwriting earlier ones.

    Args:
        d (dict[str, Any]): The dictionary to convert. Values can be
            np.ndarrays, lists of np.ndarrays, nested mappings, or any other
            object, which is passed through unchanged.

    Returns:
        dict[str, Any]: A new, flat dictionary in which np.ndarrays and
            lists of np.ndarrays are replaced by (nested) python lists.

    """
    out: dict[str, Any] = {}
    for key, value in d.items():
        # the *value* decides how an item is converted; keys are never mappings
        if isinstance(value, Mapping):
            out |= history_to_json_serializable(value)
        elif isinstance(value, np.ndarray):
            # tolist() converts arrays of any dimensionality recursively
            out[key] = value.tolist()
        elif isinstance(value, list) and all(
            isinstance(i, np.ndarray) for i in value
        ):
            out[key] = [i.tolist() for i in value]
        else:
            out[key] = value
    return out

172 

173 

def np_to_sparse_tensor(a: np.ndarray) -> tf.sparse.SparseTensor:
    """Builds a SparseTensor from the non-NaN entries of a numpy array.

    The positions of all entries that are not NaN become the indices of the
    SparseTensor, the entries themselves become its values. The dense shape
    of the result is the shape of the input array.

    Args:
        a (np.ndarray): The input array, possibly containing NaNs.

    Returns:
        tf.sparse.SparseTensor: The corresponding SparseTensor.

    Raises:
        Exception: If NaN values remain after masking.

    """
    orig_shape = a.shape
    not_nan = ~np.isnan(a)
    # one row of int64 coordinates per retained value
    indices = np.stack(np.where(not_nan)).T.astype("int64")
    values = a[not_nan].flatten()
    if np.any(np.isnan(values)):
        raise Exception(
            f"NaN values in array with shape {orig_shape} could not be removed "
            f"by indexing with {indices=}. This will result in the SparseTensor "
            f"containing NaN values."
        )
    return tf.sparse.SparseTensor(indices, values, orig_shape)

195 

196 

197def _add_images_to_tensorboard( 

198 autoencoder: AutoencoderClass, 

199 data: Optional[Union[np.ndarray, Sequence[np.ndarray]]] = None, 

200 backend: Literal["matplotlib", "plotly"] = "matplotlib", 

201 image_step: Optional[int] = None, 

202 max_size: int = 10_000, 

203 mpl_scatter_kws: Optional[dict] = None, 

204 mpl_hist_kws: Optional[dict] = None, 

205 plotly_scatter_kws: Optional[dict] = None, 

206 plotly_hist_kws: Optional[dict] = None, 

207 additional_fns: Optional[Sequence[Callable]] = None, 

208 when: Literal["epoch", "batch"] = "epoch", 

209 save_to_disk: bool = False, 

210) -> None: 

211 """Adds images to Tensorboard using the data in data and the ids in ids. 

212 

213 Args: 

214 data (Optional[Union[np.ndarray, Sequence[np.ndarray]]): The input-data will 

215 be passed through the encoder part of the autoencoder. If None 

216 is provided, a set of 10_000 points from `self.train_data` will 

217 be taken. A list[np.ndarray] is needed for the functional API of the 

218 `AngleDihedralCartesianEncoderMap`, that takes a list of 

219 [angles, dihedrals, side_dihedrals]. Defaults to None. 

220 backend (Literal["matplotlib", "plotly"]: Which backend to use for 

221 plotting. Defaults to 'matplotlib'. 

222 mpl_scatter_kws (Optional[dict]): A dictionary, that `matplotlib.pyplot.scatter` 

223 takes as keyword args. If None is provided, the default dict 

224 is {"s": 20}. Defaults to None. 

225 mpl_hist_kws (Optional[dict]): A dictionary, that `matplotlib.pyplot.histogram` 

226 takes as keyword args. If None is provided, the default dict 

227 is {"bins": 50}. Defaults to None. 

228 plotly_scatter_kws (Optional[dict[str, Any]]): A dict with items that 

229 `plotly.express.scatter()` will accept. If None is provided, 

230 a dict with size 20 will be passed to 

231 `px.scatter(**{'size_max': 10, 'opacity': 0.2})`, 

232 which sets an appropriate size of scatter points for the size of 

233 datasets encodermap is usually used for. 

234 plotly_hist_kws (Optional[dict[str, Any]]): A dict with items that 

235 `encodermap.plot.plotting._plot_free_energy()` will accept. If None is provided a 

236 dict with bins 50 will be passed to 

237 `encodermap.plot.plotting._plot_free_energy(**{'bins': 50})`. 

238 You can choose a colormap here by providing `{'bins': 50, 'cmap': 

239 'plasma'}` for this argument. 

240 image_step (Optional[int]): The interval in which to plot 

241 images to tensorboard. If None is provided, the `image_step` 

242 will be the same as `Parameters.summary_step`. Defaults to None. 

243 max_size (int): The maximum size of the high-dimensional data, that is 

244 projected. Prevents excessively large-datasets from being projected 

245 at every `image_step`. Defaults to 10_000. 

246 additional_fns (Optional[Sequence[Callable]]): A list of functions 

247 that will accept the low-dimensional output of the `Autoencoder` 

248 latent/bottleneck layer and return a tf.Tensor that can be logged 

249 by `tf.summary.image()`. See the notebook 

250 'writing_custom_images_to_tensorboard.ipynb' in 

251 tutorials/notebooks_customization for more info. If None is 

252 provided, no additional functions will be used to plot to 

253 tensorboard. Defaults to None. 

254 when (Literal["epoch", "batch"]): When to log the images can be 

255 either 'batch', then the images will be logged after every step 

256 during training, or 'epoch', then only after every image_step 

257 epoch the images will be written. Defaults to 'epoch'. 

258 save_to_disk (bool): Whether to also write the images to disk. 

259 

260 """ 

261 if not autoencoder.p.tensorboard: 

262 warnings.warn( 

263 "Nothing is written to Tensorboard for this model. " 

264 "Please change parameters.tensorboard to True." 

265 ) 

266 return 

267 if any([isinstance(i, ImageCallback) for i in autoencoder.callbacks]): 

268 warnings.warn( 

269 f"This instance of {autoencoder.__class__.__name__} already has an " 

270 f"`ImageCallback`. While it's allowed to have multiple ImageCallbacks, " 

271 f"training performance suffers by adding more." 

272 ) 

273 if image_step is None: 

274 image_step = autoencoder.p.summary_step 

275 

276 if mpl_scatter_kws is None: 

277 mpl_scatter_kws = {"s": 20} 

278 if mpl_hist_kws is None: 

279 mpl_hist_kws = {"bins": 50} 

280 if plotly_scatter_kws is None: 

281 plotly_scatter_kws = {"size_max": 1, "opacity": 0.2} 

282 if plotly_hist_kws is None: 

283 plotly_hist_kws = {"bins": 50} 

284 if backend not in ["matplotlib", "plotly"]: 

285 raise Exception(f"Argument `backend` must be either 'plotly' or 'matplotlib'.") 

286 

287 autoencoder._log_images = True 

288 if data is None: 

289 if hasattr(autoencoder, "train_data"): 

290 data = autoencoder.train_data 

291 else: 

292 if hasattr(autoencoder, "trajs"): 

293 data = autoencoder.get_train_data_from_trajs( 

294 autoencoder.trajs, autoencoder.p, max_size=max_size 

295 )[1] 

296 else: 

297 if ( 

298 not autoencoder.p.use_backbone_angles 

299 and not autoencoder.p.use_sidechains 

300 ): 

301 data = [[]] 

302 elif ( 

303 autoencoder.p.use_backbone_angles 

304 and not autoencoder.p.use_sidechains 

305 ): 

306 data = [[], []] 

307 elif autoencoder.p.use_sidechains and autoencoder.p.use_backbone_angles: 

308 data = [[], [], []] 

309 else: 

310 raise Exception( 

311 "Only allowed combinations are:\n" 

312 " * No sidechains, no backbone angles\n" 

313 " * No sidechains, yes backbone angles\n" 

314 " * Yes Sidechains, yes backbone angles\n" 

315 f"Your parameters are: {autoencoder.p.use_sidechains=}. {autoencoder.p.use_backbone_angles=}" 

316 ) 

317 length = 0 

318 for d in autoencoder.dataset: 

319 if len(data) == 1: 

320 data[0].append(d[1]) 

321 elif len(data) == 2: 

322 data[0].append(d[0]) 

323 data[1].append(d[1]) 

324 elif len(data) == 3: 

325 data[0].append(d[0]) 

326 data[1].append(d[1]) 

327 data[2].append(d[-1]) 

328 length += autoencoder.p.batch_size 

329 if length > max_size: 

330 break 

331 for i, o in enumerate(data): 

332 if any([isinstance(d, tf.sparse.SparseTensor) for d in o]): 

333 o = [ 

334 tf.sparse.to_dense(_, default_value=np.nan).numpy() 

335 for _ in o 

336 ] 

337 o = np.concatenate(o) 

338 data[i] = np_to_sparse_tensor(o) 

339 else: 

340 data[i] = np.concatenate(o) 

341 else: 

342 max_size = -1 

343 

344 if isinstance(data, (np.ndarray, tf.sparse.SparseTensor)): 

345 if hasattr(autoencoder, "_tensorboard_data_req_shape"): 

346 assert np.array_equal( 

347 tf.shape(data).numpy()[1:], autoencoder._tensorboard_data_req_shape[1:] 

348 ), ( 

349 f"The provided `data` has the wrong shape. The provided data has " 

350 f"shape {tf.shape(data).numpy()}, whereas {autoencoder._tensorboard_data_req_shape} " 

351 f"was expected." 

352 ) 

353 else: 

354 for d in autoencoder.dataset: 

355 break 

356 if isinstance(data, (tuple, list)): 

357 assert data[0].shape[1:] == d[1].shape[1:], ( 

358 f"The provided `data` has the wrong shape. The provided data has " 

359 f"shape {data[0].shape[1:]}, whereas {d[1].shape[1:]} " 

360 f"was expected." 

361 ) 

362 else: 

363 assert data.shape[1:] == d[1].shape[1:], ( 

364 f"The provided `data` has the wrong shape. The provided data has " 

365 f"shape {data[0].shape[1:]}, whereas {d[1].shape[1:]} " 

366 f"was expected." 

367 ) 

368 if data.shape[0] > max_size and max_size >= 0: 

369 idx = np.unique( 

370 np.round(np.linspace(0, data.shape[0] - 1, max_size)).astype(int) 

371 ) 

372 if isinstance(data, tf.sparse.SparseTensor): 

373 data = tf.sparse.to_dense(data, default_value=np.nan).numpy()[idx] 

374 data = np_to_sparse_tensor(data) 

375 else: 

376 data = data[idx] 

377 if isinstance(data, np.ndarray): 

378 if np.any(np.isnan(data)): 

379 data = np_to_sparse_tensor(data) 

380 elif isinstance(data, (tuple, list)): 

381 for d in autoencoder.dataset: 

382 break 

383 if len(data) == 1: 

384 assert data[0].shape[1:] == d[1].shape[1:], ( 

385 f"The provided `data` has the wrong shape. The provided data has " 

386 f"shape {data[0].shape[1:]}, whereas {d[1].shape[1:]} " 

387 f"was expected." 

388 ) 

389 data = data[0] 

390 elif len(data) == 2: 

391 assert ( 

392 data[0].shape[1:] == d[0].shape[1:] 

393 and data[1].shape[1:] == d[1].shape[1:] 

394 ), ( 

395 f"The provided `data` has the wrong shape. The provided data has " 

396 f"shape {[_.shape[1:] for _ in data]}, whereas {[d[0].shape[1:], d[1].shape[1:]]} " 

397 f"was expected." 

398 ) 

399 elif len(data) == 3: 

400 assert ( 

401 data[0].shape[1:] == d[0].shape[1:] 

402 and data[1].shape[1:] == d[1].shape[1:] 

403 and data[2].shape[1:] == d[-1].shape[1:] 

404 ), ( 

405 f"The provided `data` has the wrong shape. The provided data has " 

406 f"shape {[_.shape[1:] for _ in data]}, whereas {[d[0].shape[1:], d[1].shape[1:], d[-1].shape[1:]]} " 

407 f"was expected." 

408 ) 

409 else: 

410 raise TypeError( 

411 f"Argument `data` should be of type None, np.ndarray, tuple, or " 

412 f"list, you provided {type(data)}." 

413 ) 

414 

415 # add the callback 

416 if save_to_disk: 

417 save_dir = Path(autoencoder.p.main_path) / "train_images" 

418 save_dir.mkdir(exist_ok=True) 

419 else: 

420 save_dir = None 

421 

422 autoencoder.callbacks.append( 

423 ImageCallback( 

424 parameters=autoencoder.p, 

425 highd_data=data, 

426 image_step=image_step, 

427 backend=backend, 

428 mpl_scatter_kws=mpl_scatter_kws, 

429 mpl_hist_kws=mpl_hist_kws, 

430 plotly_scatter_kws=plotly_scatter_kws, 

431 plotly_hist_kws=plotly_hist_kws, 

432 additional_fns=additional_fns, 

433 when=when, 

434 save_dir=save_dir, 

435 ) 

436 ) 

437 autoencoder.callbacks[-1].model = autoencoder.model 

438 if isinstance(data, (np.ndarray, tf.sparse.SparseTensor)): 

439 print( 

440 f"Logging images with {data.shape}-shaped data every " 

441 f"{image_step} epochs to Tensorboard at {autoencoder.p.main_path}" 

442 ) 

443 else: 

444 print( 

445 f"Logging images with {[i.shape for i in data]}-shaped data " 

446 f"every {image_step} epochs to Tensorboard at {autoencoder.p.main_path}" 

447 ) 

448 

449 

450def _print_save_message(autoencoder: AutoencoderClass) -> None: 

451 if autoencoder.p.main_path == Path(os.getcwd()): 

452 print( 

453 f"Output files are saved to {autoencoder.p.main_path}, which is the " 

454 f"current working trajectory." 

455 ) 

456 else: 

457 print( 

458 f"Output files are saved to {autoencoder.p.main_path} as defined " 

459 f"in 'main_path' in the parameters.", 

460 ) 

461 

462 

463def _get_model(autoencoder: AutoencoderClass) -> tf.keras.Model: 

464 """sets self.model according to `model_api` argument in self.parameters.""" 

465 model = autoencoder.p.model_api 

466 if model == "functional": 

467 assert isinstance(autoencoder, AngleDihedralCartesianEncoderMap) 

468 d = autoencoder.dataset.take(1) 

469 if any(isinstance(_, tf.SparseTensorSpec) for _ in d.element_spec): 

470 autoencoder.sparse = True 

471 if hasattr(autoencoder.p, "reconstruct_sidechains"): 

472 if autoencoder.p.reconstruct_sidechains: 

473 assert len(d.element_spec) == 7 

474 model = gen_functional_model( 

475 autoencoder.dataset, 

476 autoencoder.p, 

477 sparse=autoencoder.sparse, 

478 ) 

479 elif model == "sequential": 

480 assert ( 

481 isinstance(autoencoder, (Autoencoder, EncoderMap, DihedralEncoderMap)) 

482 or autoencoder.__class__.__name__ == "EncoderMap" 

483 ) 

484 if isinstance(autoencoder.train_data, tf.sparse.SparseTensor): 

485 autoencoder.sparse = True 

486 try: 

487 model = gen_sequential_model( 

488 autoencoder.train_data.shape[1], 

489 autoencoder.p, 

490 sparse=autoencoder.sparse, 

491 ) 

492 except AttributeError: 

493 if autoencoder.p.training == "custom": 

494 for d in autoencoder.train_data: 

495 break 

496 model = gen_sequential_model( 

497 d[0].get_shape().as_list()[1], 

498 autoencoder.p, 

499 sparse=autoencoder.sparse, 

500 ) 

501 elif autoencoder.p.training == "auto": 

502 for d, _ in autoencoder.train_data: 

503 break 

504 model = gen_sequential_model( 

505 d.get_shape().as_list()[1], 

506 autoencoder.p, 

507 sparse=autoencoder.sparse, 

508 ) 

509 else: 

510 raise Exception( 

511 f"Parameter `training` has to be one of 'custom', 'auto'. " 

512 f"You supplied '{autoencoder.p.training}'." 

513 ) 

514 

515 elif model == "custom": 

516 raise NotImplementedError("No custom API currently supported") 

517 else: 

518 raise ValueError( 

519 f"API argument needs to be one of `functional`, `sequential`, " 

520 f"`custom`. You provided '{model}'." 

521 ) 

522 assert not isinstance(model, str) 

523 return model 

524 

525 

526############################################################################## 

527# Function definition which allows self.p.tensorboard to be passed 

528############################################################################## 

529 

530 

def function(f, tensorboard=False):
    """Compiles functions with `tensorflow.function` based on a `tensorboard`
    parameter.


    To understand the necessity of this function, we need to have a look at how
    tensorflow executes computations. There are two modes of execution:
        * eager mode: In eager mode, the computations are handled by python.
            The input types are python objects, and the output is a python object.
            This eager execution allows you to directly execute a calculation of
            two tensors (e.g. multiplication).
        * graph mode: In graph mode, computations are done inside tensorflow graphs,
            which are a collection of operations and tensors (i.e. data), that flow
            through the operations of the graph. These graphs, make tensorflow
            computations portable and significantly increase the performance of
            similar computations.
    Normally, you would accelerate a simple python function in tensorflow,
    by compiling it like so:

    ```python
    import tensorflow as tf

    @tf.function
    def multiply(a, b):
        return a * b

    multiply(tf.constant(2), tf.constant(3)).numpy()
    # 6
    ```

    However, the basic paradigm of accelerating the computation interferes with
    `encodermap.Parameters` `tensorboard=True` argument, as it writes a lot of
    additional information to tensorboard. Thus, a compilation with tf.function
    does not make sense here. That's why EncoderMap's `function` decorator
    takes an additional argument: `tensorboard`. When it is True, `f` is
    executed eagerly (uncompiled), otherwise `f` is compiled with
    `tf.function`.

    Args:
        f (Callable): The function to wrap.
        tensorboard (bool): The default mode of the returned wrapper. It can
            be overridden per call by passing `tensorboard=...` as a keyword
            argument to the wrapper; that keyword is consumed and not
            forwarded to `f`. Defaults to False.

    Returns:
        Callable: A wrapper around `f`, that accepts the additional
            `tensorboard` keyword argument.

    """
    # Local import, so that this function does not depend on module-level
    # imports that the file does not have.
    import functools

    # The compiled version of `f` is created on first use and then reused.
    # Creating a new `tf.function` on every call would retrace `f` every time.
    compiled = None

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal compiled
        # a per-call keyword takes precedence over the decoration-time default
        use_tensorboard = kwargs.pop("tensorboard", tensorboard)
        if use_tensorboard:
            warnings.warn(
                "Running in tensorboard mode writes a lot of stuff to "
                "tensorboard. For speed boost deactivate tensorboard mode."
            )
            return f(*args, **kwargs)
        if compiled is None:
            compiled = tf.function(f)
        return compiled(*args, **kwargs)

    return wrapper

584 

585 

586############################################################################## 

587# Public Classes 

588############################################################################## 

589 

590 

591class Autoencoder: 

592 """Main Autoencoder class. Presents all high-level functions. 

593 

594 This is the main class for neural networks inside EncoderMap. The class 

595 prepares the data (batching and shuffling), creates a `tf.keras.Model` 

596 of layers specified by the attributes of the `encodermap.Parameters` class. 

597 Depending on what Parent/Child-Class is instantiated, a combination of 

598 various cost functions is set up. Callbacks to Tensorboard are also set up. 

599 

600 Attributes: 

601 train_data (np.ndarray): The numpy array of the train data passed at init. 

602 p (AnyParameters): An `encodermap.Parameters` class 

603 containing all info needed to set up the network. 

604 dataset (tensorflow.data.Dataset): The dataset that is actually used 

605 in training the keras model. The dataset is a batched, shuffled, 

606 infinitely-repeating dataset. 

607 read_only (bool): Variable telling the class whether it is allowed to 

608 write to disk (False) or not (True). 

609 metrics (list[Any]): A list of metrics passed to the model when it is compiled. 

610 callbacks (list[Any]): A list of tf.keras.callbacks.Callback subclasses 

611 changing the behavior of the model during training. 

612 Some standard callbacks are always present like: 

613 * encodermap.callbacks.callbacks.ProgressBar: 

614 A progress bar callback using tqdm giving the current 

615 progress of training and the current loss. 

616 * CheckPointSaver: 

617 A callback that saves the model every 

618 `parameters.checkpoint_step` steps into the main directory. 

619 This callback will only be used, when `read_only` is False. 

620 * TensorboardWriteBool: 

621 A callback that contains a boolean Tensor that will be 

622 True or False, depending on the current training step and 

623 the summary_step in the parameters class. The loss functions 

624 use this callback to decide whether they should write to 

625 Tensorboard. This callback will only be present when 

626 `read_only` is False and `parameters.tensorboard` is True. 

627 You can append your own callbacks to this list before executing 

628 `self.train()`. 

629 encoder (tf.keras.Model): The encoder submodel of `self.model`. 

630 decoder (tf.keras.Model): The decoder submodel of `self.model`. 

631 loss (Sequence[Callable]): A list of loss functions passed to the model 

632 when it is compiled. When the main `Autoencoder` class is used and 

633 `parameters.loss` is 'emap_cost', this list comprises center_cost, 

634 regularization_cost, auto_cost. When the `EncoderMap` sub-class is 

635 used and `parameters.loss` is 'emap_cost', distance_cost is added to 

636 the list. When `parameters.loss` is not 'emap_cost', the loss can either 

637 be a string ('mse'), or a function, that both are acceptable 

638 arguments for loss, when a keras model is compiled. 

639 

640 Methods: 

641 from_checkpoint: Rebuild the model from a checkpoint. 

642 add_images_to_tensorboard: Make tensorboard plot images. 

643 train: Starts the training of the tf.keras.models.Model. 

644 plot_network: Tries to plot the network. For this method to work 

645 graphviz, pydot and pydotplus need to be installed. 

646 encode: Takes high-dimensional data and sends it through the encoder. 

647 decode: Takes low-dimensional data and sends it through the decoder. 

648 generate: Same as `decode`. For AngleDihedralCartesianAutoencoder classes, 

649 this will build a protein structure. 

650 

651 Note: 

652 Performance of tensorflow is not only dependent on your system's 

653 hardware and how the data is presented to the network 

654 (for this check out https://www.tensorflow.org/guide/data_performance), 

655 but also how you compiled tensorflow. Normal tensorflow 

656 (pip install tensorflow) is build without CPU extensions to work on 

657 many CPUs. However, Tensorflow can greatly benefit from using CPU 

658 instructions like AVX2, AVX512 that bring a speed-up in linear algebra 

659 computations of 300%. By building tensorflow from source, 

660 you can activate these extensions. However, the speed-up of using 

661 tensorflow with a GPU dwarfs the CPU speed-up. To check whether a 

662 GPU is available run: `print(len(tf.config.list_physical_devices('GPU')))`. 

663 Refer to these pages to install tensorflow for the best performance: 

664 https://www.tensorflow.org/install/pip and 

665 https://www.tensorflow.org/install/gpu 

666 

667 Examples: 

668 >>> import encodermap as em 

669 >>> # without providing any data, default parameters and a 4D 

670 >>> # hypercube as input data will be used. 

671 >>> e_map = em.EncoderMap(read_only=True) 

672 >>> print(e_map.train_data.shape) 

673 (16000, 4) 

674 >>> print(e_map.dataset) # doctest: +SKIP 

675 <BatchDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name=None), TensorSpec(shape=(None, 4), dtype=tf.float32, name=None))> 

676 >>> print(e_map.encode(e_map.train_data).shape) 

677 (16000, 2) 

678 

679 """ 

680 

681 def __init__( 

682 self, 

683 parameters=None, 

684 train_data: Optional[Union[np.ndarray, tf.data.Dataset]] = None, 

685 model: Optional[tf.keras.Model] = None, 

686 read_only: bool = False, 

687 sparse: bool = False, 

688 ) -> None: 

689 """Instantiate the Autoencoder class. 

690 

691 Args: 

692 parameters (Union[encodermap.Parameters, None], optional): The 

693 parameters to be used. If None is provided default values 

694 (check them with `print(em.Parameters.defaults_description()`)) 

695 are used. Defaults to None. 

696 train_data (Union[np.ndarray, tf.data.Dataset, None], optional): 

697 The train data. Can be one of the following: 

698 * None: If None is provided points on the edges of a 

699 4-dimensional hypercube will be used as train data. 

700 * np.ndarray: If a numpy array is provided, it will be 

701 transformed into a batched tf.data.Dataset by first 

702 making it an infinitely repeating dataset, shuffling 

703 it and the batching it with a batch size specified 

704 by parameters.batch_size. 

705 * tf.data.Dataset: If a dataset is provided it will be 

706 used without making any adjustments. Make sure, that the 

707 dataset uses `float32` as its type. 

708 Defaults to None. 

709 model (Union[tf.keras.models.Model, None], optional): Providing 

710 a keras model to this argument will make the Autoencoder/EncoderMap 

711 class use this model instead of the predefined ones. Make sure 

712 the model can accept EncoderMap's loss functions. If None is 

713 provided the model will be built using the specifications in 

714 parameters. Defaults to None. 

715 read_only (bool, optional): Whether the class is allowed to write 

716 to disk (False) or not (True). Defaults to False and will allow 

717 the class to write to disk. 

718 

719 """ 

720 # parameters 

721 if parameters is None: 

722 self.p = Parameters() 

723 else: 

724 self.p = parameters 

725 

726 if self.p.seed is not None: 

727 tf.random.set_seed(self.p.seed) 

728 self.read_only = read_only 

729 

730 if not self.read_only: 

731 self.p.write_summary = True 

732 self.p.save() 

733 self._print_save_message() 

734 

735 # check whether Tensorboard and Read-Only makes Sense 

736 if self.read_only and self.p.tensorboard: 

737 raise NotImplementedError 

738 

739 # clear old sessions 

740 tf.keras.backend.clear_session() 

741 self.sparse = sparse 

742 

743 # set up train_data 

744 self.set_train_data(train_data) 

745 

746 # create model based on user input 

747 if model is None: 

748 self.model = self._get_model() 

749 else: 

750 self.model = model 

751 

752 # setup callbacks for nice progress bars and saving every now and then 

753 self._setup_callbacks() 

754 

755 # create loss based on user input 

756 self.loss = self._get_loss() 

757 

758 # choose optimizer 

759 self.optimizer = tf.keras.optimizers.Adam( 

760 learning_rate=self.p.learning_rate, clipvalue=1.0 

761 ) 

762 

763 # compile model 

764 self.model.compile( 

765 optimizer=self.optimizer, 

766 loss=self.loss, 

767 metrics=self.metrics, 

768 unpacked_loss_fns={fn.__name__: fn for fn in self.loss}, 

769 ) 

770 

771 # do this if tensorboard is true. 

772 if self.p.tensorboard: 

773 self._log_images = False 

774 # get the output from model summary. 

775 with Capturing() as output: 

776 self.model.summary() 

777 with open(self.p.main_path + "/model_summary.txt", "w") as f: 

778 f.write("\n".join(output)) 

779 self.plot_network() 

780 print( 

781 f"Saved a text-summary of the model and an image in {self.p.main_path},", 

782 "as specified in 'main_path' in the parameters.", 

783 ) 

784 

785 # sets up the tb callback to plot the model 

786 self.tb_callback = tf.keras.callbacks.TensorBoard( 

787 self.p.main_path, write_graph=True 

788 ) 

789 self.tb_callback.set_model(self.model) 

790 

791 def _print_save_message(self) -> None: 

792 """Prints a save message to inform the user where the model is saved.""" 

793 _print_save_message(self) 

794 

795 def _print_missing_data(self) -> None: 

796 print( 

797 f"This model was reloaded from disk, but not yet provided with train " 

798 f"data. Use the `set_train_data()` method to provide the train " 

799 f"data and call `train()` again to train the model. Alternatively, " 

800 f"you could directly provide the train data, when reloading by " 

801 f"calling the `{self.__class__.__name__}.from_checkpoint()` " 

802 f"constructor with the `train_data` argument. Expected shape = " 

803 f"{self.model.encoder_model.input_shape[1]}, received shape = " 

804 f"{self._tensorboard_data_req_shape} {self._using_hypercube=} " 

805 f"{self.p.using_hypercube=} {self.dataset.element_spec=}" 

806 ) 

807 return 

808 

def set_train_data(
    self, data: Optional[Union[np.ndarray, tf.data.Dataset]]
) -> None:
    """Resets the train data for reloaded models.

    Args:
        data (Optional[Union[np.ndarray, tf.data.Dataset]]): The new train data.
            - None: The first item returned by `create_n_cube(4, ...)` is
              used as train data, both hypercube flags are set to True
              and `p.periodicity` is set to infinity.
            - np.ndarray: Arrays containing NaNs are converted with
              `np_to_sparse_tensor()` and `self.sparse` is set to True.
              All other arrays are cast to float32.
            - tf.data.Dataset: Used as `self.dataset`. With
              `p.training == 'auto'`, datasets that do not yield 2-tuples
              are zipped with themselves, so that they yield
              (inputs, targets) pairs.

    Raises:
        TypeError: If `data` is none of the types listed above.
        Exception: If a dense array contains values larger than
            `p.periodicity`.
        AssertionError: If the feature axis of a provided array does not
            match the stored `_tensorboard_data_req_shape`.

    """
    self._using_hypercube = False
    if data is None:
        self._using_hypercube = True
        self.p.using_hypercube = True
        self.train_data = create_n_cube(4, seed=self.p.seed)[0].astype("float32")
        self.p.periodicity = float("inf")
    elif isinstance(data, np.ndarray):
        if np.any(np.isnan(data)):
            self.sparse = True
            print("Input contains nans. Using sparse network.")
            self.train_data = np_to_sparse_tensor(data)
        else:
            self.train_data = data.astype("float32")
    elif isinstance(data, tf.data.Dataset):
        self.dataset = data
        # The element structure is inspected via `element_spec`. The former
        # probe `_, __ = self.dataset.take(1)` unpacked a dataset holding a
        # single element into two names and thus raised a ValueError for
        # every dataset, also for those already yielding pairs.
        spec = self.dataset.element_spec
        yields_pairs = isinstance(spec, tuple) and len(spec) == 2
        if not yields_pairs and self.p.training == "auto":
            if self.p.model_api == "custom":
                print(
                    f"It seems like your dataset only yields tensors and not "
                    f"tuples of tensors. TensorFlow is optimized for classification "
                    f"tasks, where datasets yield tuples of (data, classes). EncoderMap,"
                    f"however is a regression task, but uses the same code as the "
                    f"classification tasks. This requires the dataset provided "
                    f"for a tensorflow model.fit() method to return tuples "
                    f"of (data, classes). Your dataset does not do this. "
                    f"I will transform your dataset using "
                    f"the `tf.data.Dataset.zip()` function of `tf.data`. "
                    f"This might break your custom model. You can "
                    f"set the `training` parameter in the parameter class to "
                    f"'custom' to not alter your dataset."
                )
            self.dataset = tf.data.Dataset.zip((self.dataset, self.dataset))
        # A one-element view of the dataset is kept. Only its `element_spec`
        # is read further down in this method.
        self.train_data = self.dataset.take(1)
    else:
        raise TypeError(
            f"train_data must be `None`, `np.ndarray` or `tf.data.Dataset`. You supplied {type(data)}."
        )

    # check data and periodicity (only dense numpy input is checked)
    if not self.sparse and isinstance(data, np.ndarray):
        if np.any(data > self.p.periodicity):
            raise Exception(
                "There seems to be an error regarding the periodicity "
                f"of your data. The chosen periodicity is {self.p.periodicity}, "
                f"but there are datapoints outwards of this range: {data.max()}"
            )

    # prepare the data
    if isinstance(self.train_data, (np.ndarray, tf.sparse.SparseTensor)):
        if self.p.training == "auto":
            # `Model.fit()` is fed with (inputs, targets); both are the train data
            dataset = tf.data.Dataset.from_tensor_slices(
                (self.train_data, self.train_data)
            )
        else:
            dataset = tf.data.Dataset.from_tensor_slices(self.train_data)
        dataset = dataset.shuffle(
            buffer_size=self.train_data.shape[0], reshuffle_each_iteration=True
        )
        dataset = dataset.repeat()
        self.dataset = dataset.batch(self.p.batch_size)
        self._tensorboard_data_req_shape = tf.shape(self.train_data).numpy()
    else:
        # NOTE(review): assumes that the element spec is a tuple at this point,
        # which holds for zipped datasets and datasets yielding pairs.
        d = self.train_data.element_spec[0]
        if isinstance(d, tuple):
            self._tensorboard_data_req_shape = d[0].shape.as_list()
        else:
            self._tensorboard_data_req_shape = d.shape.as_list()

    if isinstance(data, np.ndarray):
        assert data.shape[1] == self._tensorboard_data_req_shape[1], (
            f"The feature axis of the provided data ({data.shape[1]}) does not "
            f"match the stored shape ({self._tensorboard_data_req_shape[1]})."
        )

888 

def _setup_callbacks(self) -> None:
    """Creates `self.metrics` and `self.callbacks` for `self.model.fit()`.

    A `ProgressBar` is always registered. A `CheckpointSaver` is added, when
    the instance is not read-only. With `p.tensorboard` enabled, a
    `TensorboardWriteBool` callback is stored as
    `self.tensorboard_write_bool` and registered, a summary file writer for
    `<main_path>/train` is made the default writer, and the parameters are
    written as a text summary at step 0. Without tensorboard,
    `self.tensorboard_write_bool` is None.

    """
    self.metrics = []
    callbacks = self.callbacks = []
    callbacks.append(ProgressBar(parameters=self.p))
    if not self.read_only:
        callbacks.append(CheckpointSaver(self.p))

    if not self.p.tensorboard:
        self.tensorboard_write_bool = None
        return

    self.tensorboard_write_bool = TensorboardWriteBool(self.p)
    callbacks.append(self.tensorboard_write_bool)
    train_writer = tf.summary.create_file_writer(self.p.main_path + "/train")
    train_writer.set_as_default()
    tf.summary.text(
        name=f"Parameters Summary for {self.p.main_path}",
        data=self.p.parameters,
        step=0,
    )

908 

@classmethod
def from_checkpoint(
    cls: Type[AutoencoderType],
    checkpoint_path: Union[str, Path],
    train_data: Optional[np.ndarray] = None,
    sparse: bool = False,
    use_previous_model: bool = False,
    compat: bool = False,
) -> AutoencoderType:
    """Rebuilds an instance of this class from a checkpoint.

    All arguments are forwarded to `load_model()`; `train_data` is passed on
    as its `dataset` argument.

    Args:
        checkpoint_path (Union[str, Path]): Where the checkpoint lives.
            For a directory, the most recently saved model is loaded. For
            a path to a .keras file, exactly this model is loaded.
        train_data (Optional[np.ndarray]): Train data for the reloaded
            model. Provide it, when you want to retrain the model.
        sparse (bool): Whether the reloaded model should be sparse.
        use_previous_model (bool): Set this to True, when loading a model
            from an in-between checkpoint step (e.g., to continue training
            with different parameters). With the files saved_model_0.keras,
            saved_model_500.keras and saved_model_1000.keras on disk,
            loading saved_model_500.keras with this flag set will back up
            saved_model_1000.keras.
        compat (bool): Whether to use the compatibility mode for missing or
            wrong parameter files. In this special case, some assumptions
            about the network architecture are derived from the model and
            the parameters in parameters.json are overwritten accordingly
            (a backup will also be made).

    Returns:
        Autoencoder: Encodermap `Autoencoder` class.

    """
    load_options = {
        "sparse": sparse,
        "dataset": train_data,
        "use_previous_model": use_previous_model,
        "compat": compat,
    }
    return load_model(cls, checkpoint_path, **load_options)

952 

def _get_model(self) -> tf.keras.Model:
    """Returns whatever the `_get_model()` helper defined outside of this
    class produces for this instance."""
    model = _get_model(self)
    return model

955 

@property
def encoder(self) -> tf.keras.Model:
    """tf.keras.Model: The `encoder` attribute of `self.model`."""
    wrapped_model = self.model
    return wrapped_model.encoder

960 

@property
def decoder(self) -> tf.keras.Model:
    """tf.keras.Model: The `decoder` attribute of `self.model`."""
    wrapped_model = self.model
    return wrapped_model.decoder

965 

def _get_loss(self):
    """Builds the loss selected by `loss` in the parameters.

    Returns:
        - 'mse': The string "mse".
        - 'reconstruction_loss': The result of `reconstruction_loss(self.model)`.
        - 'emap_cost': A list with the auto, regularization and center loss,
          which are also stored as `self.auto_loss`,
          `self.regularization_loss` and `self.center_loss`.

    Raises:
        ValueError: For every other value of `p.loss`.

    """
    requested = self.p.loss
    if requested == "mse":
        return "mse"
    if requested == "reconstruction_loss":
        return reconstruction_loss(self.model)
    if requested != "emap_cost":
        raise ValueError(
            f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{requested}'."
        )

    # all three partial losses are created from the same arguments
    loss_args = (self.model, self.p, self.tensorboard_write_bool)
    self.auto_loss = auto_loss(*loss_args)
    self.regularization_loss = regularization_loss(*loss_args)
    self.center_loss = center_loss(*loss_args)
    return [self.auto_loss, self.regularization_loss, self.center_loss]

987 

def train(self) -> Optional[dict[str, Any]]:
    """Starts the training of the model.

    The model is fitted for `p.n_steps - p.current_training_step` epochs
    with one step per epoch. Afterwards, `p.current_training_step` is
    advanced, and the parameters as well as the model are saved.

    Returns:
        Optional[dict[str, Any]]: The `history` attribute of the object
            returned by `self.model.fit()`. None is returned (after
            printing a hint) without any training, when the model has
            already been trained for `p.n_steps` steps, or when the
            currently stored train data does not fit the model.

    Raises:
        NotImplementedError: If `p.training` is 'custom'.
        ValueError: If `p.training` is neither 'auto' nor 'custom'.
        Exception: If `self.model.fit()` raises a ValueError. The message
            contains shape diagnostics and the original error is chained.

    """
    if self.p.current_training_step >= self.p.n_steps:
        print(
            f"This {self.__class__.__name__} instance has already been trained "
            f"for {self.p.current_training_step} steps. Increase the training "
            f"steps by calling `{self.__class__.__name__}.p.n_steps += new_steps` "
            f"and then call `{self.__class__.__name__}.train()` again."
        )
        return None

    # The hypercube flag of the instance and the one in the parameters need to
    # agree, otherwise the train data does not belong to these parameters.
    if self._using_hypercube != self.p.using_hypercube:
        return self._print_missing_data()

    # For a finite periodicity the encoder input is twice as wide as the data
    # (presumably, because periodic inputs are expanded before the first
    # layer -- verify against the model definition).
    expected_features = self.model.encoder_model.input_shape[1]
    if self.p.periodicity != float("inf"):
        expected_features //= 2
    if self._tensorboard_data_req_shape[1] != expected_features:
        return self._print_missing_data()

    if self.p.training == "custom":
        raise NotImplementedError()
    if self.p.training != "auto":
        # The message formerly read `self.training`, an attribute that is not
        # set anywhere in the visible code, and thus failed with an
        # AttributeError instead of this ValueError.
        raise ValueError(
            f"training argument needs to be `auto` or `custom`. You provided '{self.p.training}'."
        )

    epochs = self.p.n_steps - self.p.current_training_step
    try:
        history = self.model.fit(
            self.dataset,
            batch_size=self.p.batch_size,
            epochs=epochs,
            steps_per_epoch=1,
            verbose=0,
            callbacks=self.callbacks,
        ).history
    except ValueError as err:
        # `train_data` can be a dataset without a `shape` attribute, and the
        # diagnostic message must not fail on its own.
        train_shape = getattr(self.train_data, "shape", None)
        raise Exception(
            f"{self.model.encoder_model.input_shape=} {self._tensorboard_data_req_shape=} "
            f"self.train_data.shape={train_shape!r} {self.dataset.element_spec=} "
            f"{self.p.using_hypercube=} {self._using_hypercube=}"
        ) from err

    # Recomputed on purpose and not replaced by `epochs`: the step counter lives
    # in the parameters that are also handed to the callbacks.
    self.p.current_training_step += self.p.n_steps - self.p.current_training_step
    self.p.save()
    self.save()

    history_to_json_serializable(history)
    return history

1052 

def add_loss(self, loss):
    """Appends a new loss to `self.loss`.

    Args:
        loss (Callable): A factory, which is first called with the model and
            the parameters. If that call raises a TypeError, the factory
            is called again with the model only. The returned object is
            appended to `self.loss`.

    """
    register = self.loss.append
    try:
        register(loss(self.model, self.p))
    except TypeError:
        # fall back to factories that only accept the model
        register(loss(self.model))

1059 

def add_callback(self, callback):
    """Appends a new callback to `self.callbacks`.

    Args:
        callback (Any): Is called with the parameters and the result is
            appended. If that call raises a TypeError (e.g., because an
            already instantiated callback was provided), `callback` itself
            is appended instead.

    """
    register = self.callbacks.append
    try:
        register(callback(self.p))
    except TypeError:
        register(callback)

1066 

def add_metric(self, metric):
    """Appends a new metric to `self.metrics`.

    Args:
        metric (Callable): Is called with the parameters. The returned
            object is appended to `self.metrics`.

    """
    new_metric = metric(self.p)
    self.metrics.append(new_metric)

1070 

def add_images_to_tensorboard(self, *args: Any, **kwargs: Any) -> None:
    """Adds images of the latent space to tensorboard.

    All positional and keyword arguments are handed to the
    `_add_images_to_tensorboard()` helper together with this instance.

    Args:
        data (Optional[Union[np.ndarray, Sequence[np.ndarray]]]): The
            input-data that is passed through the encoder part of the
            autoencoder. If None is provided, a set of 10_000 points from
            `self.train_data` will be taken. The functional API of the
            `AngleDihedralCartesianEncoderMap` needs a list[np.ndarray]
            with [angles, dihedrals, side_dihedrals]. Defaults to None.
        image_step (Optional[int]): The interval in which images are
            plotted to tensorboard. If None is provided, `image_step` will
            be the same as `Parameters.summary_step`. Defaults to None.
        max_size (int): The maximum size of the high-dimensional data that
            is projected. Prevents excessively large datasets from being
            projected at every `image_step`. Defaults to 10_000.
        scatter_kws (Optional[dict[str, Any]]): A dict with items that
            `plotly.express.scatter()` accepts. If None is provided,
            `px.scatter(**{'size_max': 10, 'opacity': 0.2})` is used, which
            sets an appropriate size of scatter points for the size of
            datasets encodermap is usually used for.
        hist_kws (Optional[dict[str, Any]]): A dict with items that
            `encodermap.plot.plotting._plot_free_energy()` accepts. If None
            is provided, `_plot_free_energy(**{'bins': 50})` is used. A
            colormap can be chosen by providing
            `{'bins': 50, 'cmap': 'plasma'}` for this argument.
        additional_fns (Optional[Sequence[Callable]]): A list of functions
            that accept the low-dimensional output of the `Autoencoder`
            latent/bottleneck layer and return a tf.Tensor that can be
            logged by `tf.summary.image()`. See the notebook
            'writing_custom_images_to_tensorboard.ipynb' in
            tutorials/notebooks_customization for more info. If None is
            provided, no additional functions are used. Defaults to None.
        when (Literal["epoch", "batch"]): When to log the images. With
            'batch', the images are logged after every step during
            training. With 'epoch', they are only written after every
            `image_step` epoch. Defaults to 'epoch'.
        save_to_disk (bool): Whether to also write the images to disk.

    """
    _add_images_to_tensorboard(self, *args, **kwargs)

1115 

def plot_network(self) -> None:
    """Tries to plot the network using pydot, pydotplus and graphviz.

    The image returned by `plot_model()` (if any) is saved as `network.png`
    in `p.main_path`. Plotting is a best-effort convenience: errors that
    occur while plotting or saving are ignored on purpose. Only subclasses
    of `Exception` are ignored, so that `KeyboardInterrupt` and
    `SystemExit` are still propagated.

    Note:
        Refer to this guide to install these programs:
        https://stackoverflow.com/questions/47605558/importerror-failed-to-import-pydot-you-must-install-pydot-and-graphviz-for-py

    """
    try:
        out = plot_model(self.model, self.train_data.shape[1])
        if out is not None:
            out.save(Path(self.p.main_path) / "network.png")
    except Exception:
        # A missing plotting backend must not prevent the model from being
        # built. The former bare `except:` also swallowed Ctrl-C.
        pass

1131 

def encode(self, data: Optional[Sequence[np.ndarray]] = None) -> np.ndarray:
    """Calls the encoder part of `self.model`.

    Args:
        data (Optional[Union[np.ndarray, Sequence[np.ndarray]]]): The data
            to be passed to the encoder part. If None is provided,
            `self.train_data` is used. A single array is passed to
            `self.model.encoder`, when the encoder input is twice as wide
            as the array, and to `self.model.encoder_model`, when the
            widths are equal. A list or tuple of arrays is always passed
            to `self.model.encoder_model`. Defaults to None.

    Returns:
        np.ndarray: The output from the bottleneck/latent layer. If the
            model returns a list or tuple, a list of np.ndarray is returned.

    Raises:
        Exception: If the width of a single array fits neither of the two
            cases described above.
        AssertionError: If one of the arrays in a list or tuple does not
            match the corresponding input shape of the encoder.

    """
    if data is None:
        data = self.train_data

    if isinstance(data, (list, tuple)):
        # one array per model input; every array needs to fit its input
        for d, in_shape in zip(data, self.model.encoder_model.input_shape):
            assert d.shape[1] == in_shape[1], (
                f"The shape of the provided data ({d.shape}) does not match "
                f"the expected shape {in_shape}."
            )
        out = self.model.encoder_model(data)
    elif data.shape[1] * 2 == self.model.encoder_model.input_shape[1]:
        out = self.model.encoder(data)
    elif data.shape[1] == self.model.encoder_model.input_shape[1]:
        out = self.model.encoder_model(data)
    else:
        msg = (
            f"The shape of the provided data {data.shape=} does not "
            f"match the expected shape {self.model.encoder_model.input_shape=}."
        )
        if self.p.periodicity < float("inf"):
            msg += f" Not even considering the periodicity of {self.p.periodicity}."
        raise Exception(msg)

    if isinstance(out, (list, tuple)):
        return [o.numpy() for o in out]
    return out.numpy()

1175 

def generate(self, data: np.ndarray) -> np.ndarray:
    """Passes `data` through the decoder part of the model.

    In `Autoencoder` and `EncoderMap` this method calls the same decoder as
    `decode()`, but hands back the decoder output as is, i.e., without
    the conversion via `.numpy()` that `decode()` applies. In
    `AngleDihedralCartesianEncoderMap` this method will be overwritten
    to produce output molecular conformations.

    Args:
        data (np.ndarray): The data to be passed to the decoder part of the
            model. Make sure that the shape of the data matches the number
            of neurons in the latent space.

    Returns:
        np.ndarray: Outputs from the decoder part. For
            `AngleDihedralCartesianEncoderMap`, this will either be a
            `mdtraj.Trajectory` or `MDAnalysis.Universe`.
    """
    decoder = self.model.decoder
    return decoder(data)

1194 

def decode(self, data: np.ndarray) -> Sequence[np.ndarray]:
    """Calls the decoder part of the model and converts the result to numpy.

    Args:
        data (np.ndarray): The data to be passed to the decoder part of
            the model. Make sure that the shape of the data matches the
            number of neurons in the latent space.

    Returns:
        Union[list[np.ndarray], np.ndarray]: Outputs from the decoder part.
            If the decoder returns a list or tuple (as is the case for
            `AngleDihedralCartesianEncoderMap`), a list of np.ndarray is
            returned, otherwise a single np.ndarray.

    """
    decoded = self.model.decoder(data)
    if not isinstance(decoded, (list, tuple)):
        return decoded.numpy()
    return [part.numpy() for part in decoded]

1218 

1219 def save(self, step: Optional[int] = None) -> None | Path: 

1220 """Saves the model to the current path defined in `parameters.main_path`. 

1221 

1222 Args: 

1223 step (Optional[int]): Does not save the model at the given 

1224 training step, but rather changes the string used for saving 

1225 the model from a datetime format to another. 

1226 

1227 Returns: 

1228 Union[None, Path]: When the model has been saved, the Path will 

1229 be returned. If the model could not be saved. None will be 

1230 returned. 

1231 

1232 """ 

1233 if not self.read_only: 

1234 return save_model( 

1235 self.model, 

1236 self.p.main_path, 

1237 inp_class_name=self.__class__.__name__, 

1238 step=step, 

1239 print_message=True, 

1240 ).parent 

1241 else: 

1242 print( 

1243 f"This {self.__class__.__name__} is set to read_only. Set " 

1244 f"`{self.__class__.__name__}.read_only=False` to save the " 

1245 f"current state of the model." 

1246 ) 

1247 

def close(self) -> None:
    """Clears the current keras backend session to free up resources."""
    keras_backend = tf.keras.backend
    # get rid of old sessions
    keras_backend.clear_session()

1252 

1253 

class EncoderMap(Autoencoder):
    """Complete copy of Autoencoder class but uses additional distance cost
    scaled by the SketchMap sigmoid params"""

    @classmethod
    def from_checkpoint(
        cls: Type[EncoderMapType],
        checkpoint_path: Union[str, Path],
        train_data: Optional[np.ndarray] = None,
        sparse: bool = False,
        use_previous_model: bool = False,
        compat: bool = False,
    ) -> EncoderMapType:
        """Reconstructs the class from a checkpoint.

        Args:
            checkpoint_path (Union[str, Path]): The path to the checkpoint. Can
                be either a directory, in which case the most recently saved
                model will be loaded. Or a direct .keras file, in which case, this
                specific model will be loaded.
            train_data (Optional[np.ndarray]): When you want to retrain this model, you
                can provide the train data here.
            sparse (bool): Whether the reloaded model should be sparse.
            use_previous_model (bool): Set this flag to True, if you load a model
                from an in-between checkpoint step (e.g., to continue training with
                different parameters). If you have the files saved_model_0.keras,
                saved_model_500.keras and saved_model_1000.keras, setting this to
                True and loading the saved_model_500.keras will back up the
                saved_model_1000.keras.
            compat (bool): Whether to use compatibility mode when missing or wrong
                parameter files are present. In this special case, some assumptions
                about the network architecture are made from the model and the
                parameters in parameters.json overwritten accordingly (a backup
                will also be made).

        Returns:
            EncoderMap: EncoderMap `EncoderMap` class.

        """
        # `compat` was accepted and documented, but silently dropped here, while
        # `Autoencoder.from_checkpoint` forwards it to `load_model()`.
        return load_model(
            cls,
            checkpoint_path,
            sparse=sparse,
            dataset=train_data,
            use_previous_model=use_previous_model,
            compat=compat,
        )

    def _get_loss(self):
        """Builds the loss selected by `loss` in the parameters.

        The selection is done by `Autoencoder._get_loss()`. For 'emap_cost'
        the distance loss is created on top, stored as `self.distance_loss`
        and appended to the list of losses returned by the parent class.

        Raises:
            ValueError: If `p.loss` is not one of `reconstruction_loss`,
                `mse` or `emap_cost` (raised by the parent implementation).

        """
        loss = super()._get_loss()
        if self.p.loss == "emap_cost":
            # this cost is new
            self.distance_loss = distance_loss(
                self.model, self.p, self.tensorboard_write_bool
            )
            loss.append(self.distance_loss)
        return loss

1330 

1331 

class DihedralEncoderMap(EncoderMap):
    """`EncoderMap` subclass with a `generate` method that does backmapping.

    When this class is used instead of the `EncoderMap` class, the `generate`
    method needs an additional argument: `top`, which should be a topology
    file. The dihedrals obtained from the `decode` method are applied to
    this topology.

    """

    def generate(
        self,
        data: np.ndarray,
        top: Union[Path, str],
    ) -> Universe:
        """Decodes latent space points and applies the resulting dihedrals
        to a topology via `dihedral_backmapping()`.

        Args:
            data (np.ndarray): The low-dimensional/latent/bottleneck data.
                A ndim==2 numpy array with xy coordinates of points in latent space.
            top (str): Topology file for this run of EncoderMap (can be .pdb, .gro, .. etc.).

        Returns:
            MDAnalysis.Universe: The topology with the provided backbone torsions.

        Raises:
            AssertionError: If `data` or the decoded dihedrals do not contain
                any truthy (nonzero) value.

        Examples:
            >>> # get some time-resolved pdb files
            >>> import requests
            >>> import numpy as np
            >>> pdb_link = 'https://files.rcsb.org/view/1YUF.pdb'
            >>> contents = requests.get(pdb_link).text
            >>> print(contents.splitlines()[0])  # doctest: +SKIP
            HEADER GROWTH FACTOR 01-APR-96 1YUF
            >>> # fake a file with stringio
            >>> from io import StringIO
            >>> import MDAnalysis as mda
            >>> import numpy as np
            >>> file = StringIO(contents)
            >>> # pass it to MDAnalysis
            >>> u = mda.Universe(file, format='PDB')
            >>> print(u)
            <Universe with 720 atoms>
            >>> # select the atomgroups
            >>> ags = [*[res.psi_selection() for res in u.residues],
            ...        *[res.omega_selection() for res in u.residues],
            ...        *[res.phi_selection() for res in u.residues]
            ...        ]
            >>> # filter Nones
            >>> ags = list(filter(lambda x: False if x is None else True, ags))
            >>> print(ags[0][0])  # doctest: +SKIP
            <Atom 3: C of type C of resname VAL, resid 1 and segid A and altLoc >
            >>> # Run dihedral Angles
            >>> from MDAnalysis.analysis.dihedrals import Dihedral
            >>> R = np.deg2rad(Dihedral(ags).run().results.angles)
            >>> print(R.shape)
            (16, 147)
            >>> # import EncoderMap and define parameters
            >>> from encodermap.autoencoder import DihedralEncoderMap
            >>> import encodermap as em
            >>> parameters = em.Parameters(
            ...     dist_sig_parameters = (4.5, 12, 6, 1, 2, 6),
            ...     periodicity = 2*np.pi,
            ...     l2_reg_constant = 10.0,
            ...     summary_step = 5,
            ...     tensorboard = False,
            ... )
            >>> e_map = DihedralEncoderMap(parameters, R, read_only=True)
            >>> print(e_map.__class__.__name__)
            DihedralEncoderMap
            >>> # get some low-dimensional data
            >>> lowd = np.random.random((100, 2))
            >>> # use the generate method to get a new MDAnalysis universe
            >>> # but first remove the time resolution
            >>> file = StringIO(contents.split('MODEL 2')[0])
            >>> new = e_map.generate(lowd, file)
            >>> print(new.trajectory.coordinate_array.shape)
            (100, 720, 3)
            >>> # check whether frame 0 of u and new_u are different
            >>> for ts in u.trajectory:
            ...     a1 = ts.positions
            ...     break
            >>> print(np.array_equal(a1, new.trajectory.coordinate_array[0]))
            False

        """
        assert np.any(data)
        decoded_dihedrals = self.decode(data)
        assert np.any(decoded_dihedrals)
        return dihedral_backmapping(top, decoded_dihedrals)

1423 

1424 

1425class AngleDihedralCartesianEncoderMap: 

1426 """Different `__init__` method, than Autoencoder Class. Uses callbacks to tune-in cartesian cost. 

1427 

1428 Overwritten methods: `_set_up_callbacks` and `generate`. 

1429 

1430 Examples: 

1431 >>> import encodermap as em 

1432 >>> from pathlib import Path 

1433 >>> # Load two trajectories 

1434 >>> test_data = Path(em.__file__).parent.parent / "tests/data" 

1435 >>> test_data.is_dir() 

1436 True 

1437 >>> xtcs = [test_data / "1am7_corrected_part1.xtc", test_data / "1am7_corrected_part2.xtc"] 

1438 >>> tops = [test_data / "1am7_protein.pdb", test_data /"1am7_protein.pdb"] 

1439 >>> trajs = em.load(xtcs, tops) 

1440 >>> print(trajs) 

1441 encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajectories. Not containing any CVs. 

1442 >>> # load CVs 

1443 >>> # This step can be omitted. The AngleDihedralCartesianEncoderMap class automatically loads CVs 

1444 >>> trajs.load_CVs('all') 

1445 >>> print(trajs.CVs['central_cartesians'].shape) 

1446 (51, 474, 3) 

1447 >>> print(trajs.CVs['central_dihedrals'].shape) 

1448 (51, 471) 

1449 >>> # create some parameters 

1450 >>> p = em.ADCParameters(periodicity=360, use_backbone_angles=True, use_sidechains=True, 

1451 ... cartesian_cost_scale_soft_start=(6, 12)) 

1452 >>> # Standard is functional model, as it offers more flexibility 

1453 >>> print(p.model_api) 

1454 functional 

1455 >>> print(p.distance_cost_scale) 

1456 None 

1457 >>> # Instantiate the class 

1458 >>> e_map = em.AngleDihedralCartesianEncoderMap(trajs, p, read_only=True) # doctest: +ELLIPSIS 

1459 Model... 

1460 >>> # dataset contains these inputs: 

1461 >>> # central_angles, central_dihedrals, central_cartesians, central_distances, sidechain_dihedrals 

1462 >>> print(e_map.dataset) # doctest: +SKIP 

1463 <BatchDataset element_spec=(TensorSpec(shape=(None, 472), dtype=tf.float32, name=None), TensorSpec(shape=(None, 471), dtype=tf.float32, name=None), TensorSpec(shape=(None, 474, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 473), dtype=tf.float32, name=None), TensorSpec(shape=(None, 316), dtype=tf.float32, name=None))> 

1464 >>> # output from the model contains the following data: 

1465 >>> # out_angles, out_dihedrals, back_cartesians, pairwise_distances of inp cartesians, pairwise of back-mapped cartesians, out_side_dihedrals 

1466 >>> for data in e_map.dataset.take(1): 

1467 ... pass 

1468 >>> out = e_map.model(data) 

1469 >>> print([i.shape for i in out]) # doctest: +SKIP 

1470 [TensorShape([256, 472]), TensorShape([256, 471]), TensorShape([256, 474, 3]), TensorShape([256, 112101]), TensorShape([256, 112101]), TensorShape([256, 316])] 

1471 >>> # get output of latent space by providing central_angles, central_dihedrals, sidehcain_dihedrals 

1472 >>> latent = e_map.encoder([data[0], data[1], data[-1]]) 

1473 >>> print(latent.shape) 

1474 (256, 2) 

1475 >>> # Rebuild central_angles, central_dihedrals and sidechain_angles from latent 

1476 >>> dih, ang, side_dih = e_map.decode(latent) 

1477 >>> print(dih.shape, ang.shape, side_dih.shape) 

1478 (256, 472) (256, 471) (256, 316) 

1479 

1480 """ 

1481 

1482 def __init__( 

1483 self, 

1484 trajs: Optional[TrajEnsemble] = None, 

1485 parameters: Optional[ADCParameters] = None, 

1486 model: Optional[tf.keras.Model] = None, 

1487 read_only: bool = False, 

1488 dataset: Optional[tf.data.Dataset] = None, 

1489 ensemble: bool = False, 

1490 use_dataset_when_possible: bool = True, 

1491 deterministic: bool = False, 

1492 ) -> None: 

1493 """Instantiate the `AngleDihedralCartesianEncoderMap` class. 

1494 

1495 Args: 

1496 trajs (Optional[TrajEnsemble]): The trajectories to be used as input. 

1497 If trajs contain no CVs, correct CVs will be loaded. can be None, 

1498 in which case the argument `dataset` should be provided. 

1499 Defaults to None. 

1500 parameters (Optional[em.ADCParameters]): The parameters for the 

1501 current run. Can be set to None and the default parameters will 

1502 be used. Defaults to None. 

1503 model (Optional[tf.keras.models.Model]): The keras model to use. You 

1504 can provide your own model with this argument. If set to None, 

1505 the model will be built to the specifications of parameters using 

1506 either the functional API. Defaults to None, 

1507 read_only (bool): Whether to write anything to disk 

1508 (False) or not (True). Defaults to False. 

1509 dataset (Optional[tf.data.Dataset]): The `dataset` argument takes 

1510 precedent over the `trajs` argument. If None, the dataset will 

1511 be constructed from the `trajs` argument (see 

1512 `em.trajinfo.TrajEnsemble.tf_dataset` for more info). Defaults 

1513 to None. 

1514 ensemble (bool): Whether to allow non-defined features when 

1515 featurizing the provided `trajs`. Only takes effect, when 

1516 the `trajs` don't already have the features (central_cartesians, 

1517 central_distances, central_angles, central_dihedrals, side_dihedrals) 

1518 loaded. Defaults to False. 

1519 use_dataset_when_possible (bool): Whether to use the `trajs` method 

1520 `tf_dataset()` to get a dataset for training or constructy a 

1521 dataset from the `trajs` CVs numpy arrays. For large datasets the 

1522 first method can be advantageous as not all data will end up in 

1523 memory and the dataset can be larger than the memory allows. For 

1524 small datasets the second method is faster, as all data is in 

1525 memory. Defaults to True. 

1526 

1527 """ 

1528 # parameters 

1529 if parameters is None: 

1530 self.p = ADCParameters() 

1531 else: 

1532 assert isinstance(parameters, ADCParameters), ( 

1533 f"Please provide an instance of `ADCParameters` for the argument " 

1534 f"'parameters' and not {type(parameters)}." 

1535 ) 

1536 self.p = parameters 

1537 

1538 # seed 

1539 if self.p.seed is not None: 

1540 tf.random.set_seed(self.p.seed) 

1541 

1542 # check some sidechain 

1543 if self.p.reconstruct_sidechains: 

1544 assert ( 

1545 self.p.use_sidechains 

1546 ), "If you want to reconstruct sidechains, you should also set `use_sidechains` to True." 

1547 

1548 # read_only 

1549 self.read_only = read_only 

1550 

1551 # save params and create dir 

1552 if not self.read_only: 

1553 self.p.write_summary = True 

1554 self.p.save() 

1555 self._print_save_message() 

1556 

1557 # check whether Tensorboard and read_only make sense 

1558 if self.read_only and self.p.tensorboard: 

1559 raise Exception("Can't use tensorboard, when `read_only` is set to True.") 

1560 

1561 # clear old sessions 

1562 tf.keras.backend.clear_session() 

1563 

1564 # get the CVs: 

1565 if trajs is not None: 

1566 if trajs.__class__.__name__ == "SingleTraj": 

1567 trajs = trajs._gen_ensemble() 

1568 self.trajs = trajs 

1569 

1570 # add the sidechain_info if sidechains need to be reconstructed 

1571 if self.p.reconstruct_sidechains: 

1572 self.p.sidechain_info = self.trajs.sidechain_info() 

1573 

1574 # decide on the dataset 

1575 if ( 

1576 all([traj._traj_file.suffix in [".h5", ".nc"] for traj in trajs]) 

1577 and trajs.CVs_in_file 

1578 and use_dataset_when_possible 

1579 ): 

1580 # if all CVs in a h5 file, we can load get batches from there 

1581 dataset = trajs.tf_dataset( 

1582 batch_size=self.p.batch_size, 

1583 sidechains=self.p.use_sidechains, 

1584 reconstruct_sidechains=self.p.reconstruct_sidechains, 

1585 deterministic=deterministic, 

1586 ) 

1587 self.inp_CV_data = trajs.CVs 

1588 

1589 else: 

1590 # if not, we need to load them 

1591 if not self.p.reconstruct_sidechains: 

1592 should_be = { 

1593 "central_angles", 

1594 "central_cartesians", 

1595 "central_dihedrals", 

1596 "central_distances", 

1597 "side_dihedrals", 

1598 } 

1599 else: 

1600 raise NotImplementedError( 

1601 f"Loading CVs with reconstruct_sidechains is currently not implemented." 

1602 ) 

1603 if dataset is None: 

1604 if not self.trajs.CVs: 

1605 missing = list(should_be - set(trajs.CVs.keys())) 

1606 if missing != []: 

1607 print("loading missing values: ", missing) 

1608 self.trajs.load_CVs(missing, ensemble=ensemble) 

1609 else: 

1610 if not should_be.issubset(set(self.trajs.CVs.keys())): 

1611 self.trajs.load_CVs(list(should_be), ensemble=ensemble) 

1612 

1613 if not should_be.issubset(set(self.trajs.CVs.keys())): 

1614 if not ensemble: 

1615 msg = ( 

1616 f" You can try to set `ensemble=True` to load " 

1617 f"these trajectories into an ensemble, which " 

1618 f"allows features with different feature length." 

1619 ) 

1620 else: 

1621 msg = "" 

1622 raise Exception( 

1623 f"Could not load CVs. Should be {should_be}, but " 

1624 f"currently only {set(trajs.CVs.keys())} are loaded.{msg}" 

1625 ) 

1626 

1627 # create dataset 

1628 if dataset is None: 

1629 ( 

1630 self.sparse, 

1631 self.train_data, 

1632 self.inp_CV_data, 

1633 ) = self.get_train_data_from_trajs(self.trajs, self.p) 

1634 if not self.p.reconstruct_sidechains: 

1635 dataset = tf.data.Dataset.from_tensor_slices( 

1636 ( 

1637 self.inp_CV_data["central_angles"], 

1638 self.inp_CV_data["central_dihedrals"], 

1639 self.inp_CV_data["central_cartesians"], 

1640 self.inp_CV_data["central_distances"], 

1641 self.inp_CV_data["side_dihedrals"], 

1642 ) 

1643 ) 

1644 dataset = dataset.shuffle( 

1645 buffer_size=self.inp_CV_data["central_cartesians"].shape[0], 

1646 reshuffle_each_iteration=True, 

1647 ) 

1648 else: 

1649 dataset = tf.data.Dataset.from_tensor_slices( 

1650 ( 

1651 self.inp_CV_data["central_angles"], 

1652 self.inp_CV_data["central_dihedrals"], 

1653 self.inp_CV_data["all_cartesians"], 

1654 self.inp_CV_data["central_distances"], 

1655 self.inp_CV_data["side_dihedrals"], 

1656 self.inp_CV_data["side_angles"], 

1657 self.inp_CV_data["side_dihedrals"], 

1658 self.inp_CV_data["side_distances_"], 

1659 ) 

1660 ) 

1661 dataset = dataset.shuffle( 

1662 buffer_size=self.inp_CV_data["all_cartesians"].shape[0], 

1663 reshuffle_each_iteration=True, 

1664 ) 

1665 dataset = dataset.repeat() 

1666 self.dataset = dataset.batch(self.p.batch_size) 

1667 else: 

1668 self.dataset = dataset 

1669 self.sparse = any( 

1670 [isinstance(t, tf.SparseTensorSpec) for t in self.dataset.element_spec] 

1671 ) 

1672 

1673 # create model based on user input 

1674 if model is None: 

1675 self.model = self._get_model() 

1676 else: 

1677 self.model = model 

1678 

1679 # setup callbacks 

1680 self._setup_callbacks() 

1681 

1682 # create loss based on user input 

1683 self.loss = self._get_loss() 

1684 

1685 # choose optimizer 

1686 self.optimizer = tf.keras.optimizers.Adam( 

1687 learning_rate=self.p.learning_rate, clipvalue=1.0 

1688 ) 

1689 

1690 # compile model 

1691 self.model.compile( 

1692 optimizer=self.optimizer, 

1693 loss=self.loss, 

1694 metrics=self.metrics, 

1695 unpacked_loss_fns={fn.__name__: fn for fn in self.loss}, 

1696 ) 

1697 

1698 # do this if tensorboard is true. 

1699 if self.p.tensorboard: 

1700 # print shapes 

1701 print("input shapes are:") 

1702 if hasattr(self, "inp_CV_data"): 

1703 print({k: v.shape for k, v in self.inp_CV_data.items()}) 

1704 else: 

1705 for d in self.dataset: 

1706 break 

1707 print([v.shape for v in d]) 

1708 

1709 # set _log_images False to fix the backend after training 

1710 self._log_images = False 

1711 # get the output from model summary. 

1712 with Capturing() as output: 

1713 self.model.summary() 

1714 with open(self.p.main_path + "/model_summary.txt", "w") as f: 

1715 f.write("\n".join(output)) 

1716 self.plot_network() 

1717 print( 

1718 f"Saved a text-summary of the model and an image in {self.p.main_path},", 

1719 "as specified in 'main_path' in the parameters.", 

1720 ) 

1721 

1722 # sets up the tb callback to plot the model 

1723 self.tb_callback = tf.keras.callbacks.TensorBoard( 

1724 self.p.main_path, write_graph=True 

1725 ) 

1726 self.tb_callback.set_model(self.model) 

1727 

def _print_save_message(self) -> None:
    """Prints a message, to inform user where model is saved.

    Thin wrapper that forwards this instance to the module-level
    `_print_save_message()` helper.

    """
    _print_save_message(self)

1731 

def _get_model(self) -> tf.keras.Model:
    """Constructs a model from an instance of this class.

    Thin wrapper that forwards this instance to the module-level
    `_get_model()` helper and returns its result.

    Returns:
        tf.keras.Model: The model built by the helper.

    """
    return _get_model(self)

1735 

def add_loss(self, loss):
    """Builds a loss from `loss` and appends it to `self.loss`.

    Args:
        loss (Callable): A factory that is first called as
            `loss(self.model, self.p)`. If that call raises a `TypeError`,
            it is called again as `loss(self.model)`.

    """
    register = self.loss.append
    try:
        built = loss(self.model, self.p)
    except TypeError:
        # Fall back to factories that only take the model.
        register = self.loss.append
        built = loss(self.model)
    register(built)

1742 

def add_callback(self, callback):
    """Appends a callback to `self.callbacks`.

    Args:
        callback: Either a factory that is called as `callback(self.p)`
            and whose result is stored, or, when that call raises a
            `TypeError`, an object that is stored unchanged.

    """
    register = self.callbacks.append
    try:
        entry = callback(self.p)
    except TypeError:
        # Not callable with the parameters: store the object itself.
        register = self.callbacks.append
        entry = callback
    register(entry)

1749 

def add_metric(self, metric):
    """Instantiates `metric` with `self.p` and appends it to `self.metrics`.

    Args:
        metric (Callable): A factory that is called as `metric(self.p)`.

    """
    register = self.metrics.append
    built = metric(self.p)
    register(built)

1753 

def add_images_to_tensorboard(self, *args: Any, **kwargs: Any) -> None:
    """Adds images of the latent space to tensorboard.

    All positional and keyword arguments are forwarded unchanged to the
    module-level `_add_images_to_tensorboard()` helper, together with this
    instance. The arguments below are the ones that helper is documented
    to accept.

    Args:
        data (Optional[Union[np.ndarray, Sequence[np.ndarray]]): The input-data will
            be passed through the encoder part of the autoencoder. If None
            is provided, a set of 10_000 points from `self.train_data` will
            be taken. A list[np.ndarray] is needed for the functional API of the
            `AngleDihedralCartesianEncoderMap`, that takes a list of
            [angles, dihedrals, side_dihedrals]. Defaults to None.
        image_step (Optional[int]): The interval in which to plot
            images to tensorboard. If None is provided, the `image_step`
            will be the same as `Parameters.summary_step`. Defaults to None.
        max_size (int): The maximum size of the high-dimensional data, that is
            projected. Prevents excessively large-datasets from being projected
            at every `image_step`. Defaults to 10_000.
        scatter_kws (Optional[dict[str, Any]]): A dict with items that
            `plotly.express.scatter()` will accept. If None is provided,
            `{'size_max': 10, 'opacity': 0.2}` will be passed to
            `px.scatter()`, which sets an appropriate size of scatter
            points for the size of datasets encodermap is usually used for.
        hist_kws (Optional[dict[str, Any]]): A dict with items that
            `encodermap.plot.plotting._plot_free_energy()` will accept. If None is provided a
            dict with bins 50 will be passed to
            `encodermap.plot.plotting._plot_free_energy(**{'bins': 50})`.
            You can choose a colormap here by providing `{'bins': 50, 'cmap':
            'plasma'}` for this argument.
        additional_fns (Optional[Sequence[Callable]]): A list of functions
            that will accept the low-dimensional output of the `Autoencoder`
            latent/bottleneck layer and return a tf.Tensor that can be logged
            by `tf.summary.image()`. See the notebook
            'writing_custom_images_to_tensorboard.ipynb' in
            tutorials/notebooks_customization for more info. If None is
            provided, no additional functions will be used to plot to
            tensorboard. Defaults to None.
        when (Literal["epoch", "batch"]): When to log the images. Can be
            either 'batch', then the images will be logged after every step
            during training, or 'epoch', then the images will only be
            written after every `image_step` epochs. Defaults to 'epoch'.
        save_to_disk (bool): Whether to also write the images to disk.

    """
    _add_images_to_tensorboard(self, *args, **kwargs)

1798 

def train(self) -> Optional[dict[str, Any]]:
    """Overwrites the parent class' `train()` method to implement references.

    If every parameter whose name contains 'reference' still equals 1,
    `train_for_references()` is run first. Afterwards the model is fitted
    for the remaining `n_steps - current_training_step` steps (one step
    per keras epoch) and, unless the instance is `read_only`, the parameters
    and the model are saved.

    Returns:
        Optional[dict[str, Any]]: The `history` dict of `model.fit()`, or
            None, if the instance has already been trained for
            `self.p.n_steps` steps.

    Raises:
        NotImplementedError: If `self.p.training` is 'custom'.
        ValueError: If `self.p.training` is neither 'auto' nor 'custom'.

    """
    if all([v == 1 for k, v in self.p.__dict__.items() if "reference" in k]):
        self.train_for_references()
    else:
        print("References are already provided. Skipping reference training.")
    if self.p.current_training_step >= self.p.n_steps:
        print(
            f"This {self.__class__.__name__} instance has already been trained "
            f"for {self.p.current_training_step} steps. Increase the training "
            f"steps by calling `{self.__class__.__name__}.p.n_steps += new_steps` "
            f"and then call `{self.__class__.__name__}.train()` again."
        )
        return None

    if self.p.training == "custom":
        # Neither the batched nor the un-batched custom loop exists.
        raise NotImplementedError()
    elif self.p.training == "auto":
        # Every keras "epoch" is exactly one training step.
        epochs = self.p.n_steps - self.p.current_training_step
        history = self.model.fit(
            self.dataset,
            batch_size=self.p.batch_size,
            epochs=epochs,
            steps_per_epoch=1,
            verbose=0,
            callbacks=self.callbacks,
        ).history
    else:
        # The training mode lives on the parameters, not on the instance.
        raise ValueError(
            f"training argument needs to be `auto` or `custom`. You provided '{self.p.training}'."
        )
    self.p.current_training_step = self.p.n_steps
    if not self.read_only:
        self.p.save()
        self.save()

    # NOTE(review): the return value is discarded; presumably the helper
    # converts `history` in place -- confirm.
    history_to_json_serializable(history)
    return history

1839 

def train_for_references(
    self, subsample: int = 100, maxiter: int = 500
) -> Optional[dict[str, list[Any]]]:
    """Calculates the angle, dihedral, and cartesian costs to so-called
    references, which can be used to bring these costs to a similar
    magnitude.

    The references are the costs between batches drawn from `self.dataset`
    and a mean geometry built from `self.trajs.CVs`. Their means are written
    to `self.p.angle_cost_reference`, `self.p.dihedral_cost_reference` and
    `self.p.cartesian_cost_reference`. Unless the instance is `read_only`,
    the parameters are saved afterwards.

    Args:
        subsample (int): Stride that is applied to CVs containing nans,
            before they are converted to a sparse tensor. Defaults to 100.
        maxiter (int): Upper bound for the number of batches taken from
            `self.dataset`. Defaults to 500.

    Returns:
        Optional[dict[str, list[Any]]]: The per-batch costs keyed by
            'dihedral_cost', 'angle_cost', and 'cartesian_cost', or None, if
            this instance has no `trajs` attribute.

    Raises:
        Exception: If a dense batch contains nans, while `self.sparse` is
            False.

    """
    # Without trajectories, there is no mean geometry to compare against.
    if not hasattr(self, "trajs"):
        return None
    nsteps = min(maxiter, max(1, int(self.trajs.n_frames / self.p.batch_size)))

    # Unscaled parameters, so the raw magnitude of every cost is measured.
    p = ADCParameters(
        cartesian_cost_scale=1,
        angle_cost_scale=1,
        dihedral_cost_scale=1,
    )
    # Per cost: [CV name, index into a dataset element, loss function].
    # A fourth item (the reference values) is appended below.
    # fmt: off
    costs = {
        "dihedral_cost": ["central_dihedrals", 1, dihedral_loss(self.model, p)],
        "angle_cost": ["central_angles", 0, angle_loss(self.model, p)],
        "cartesian_cost": ["central_cartesians", 2, cartesian_loss(self.model, parameters=p)],
    }
    # fmt: on
    # Local Folder Imports
    from ..models.models import ADCSparseFunctionalModel

    # NOTE(review): `to_dense_models` only exists for sparse models. Data
    # with nans fed to any other model type raises a NameError below.
    if isinstance(self.model, ADCSparseFunctionalModel):
        to_dense_models = {
            "dihedral_cost": self.model.get_dense_model_central_dihedrals,
            "angle_cost": self.model.get_dense_model_central_angles,
            "cartesian_cost": self.model.get_dense_model_distances,
        }

    cost_references = {key: [] for key in costs.keys()}
    # The cartesian branch reads the references of the two other costs, so
    # it relies on the insertion order of `costs`.
    for key, val in costs.items():
        if key in ["dihedral_cost", "angle_cost"]:
            inp = self.trajs.CVs[val[0]]
            if np.any(np.isnan(inp)):
                inp = np_to_sparse_tensor(inp[::subsample])
                inp = to_dense_models[key](inp).numpy()
            # The mean over all frames, repeated to the size of one batch.
            means = np.repeat(
                np.expand_dims(
                    np.mean(inp, 0),
                    axis=0,
                ),
                repeats=self.p.batch_size,
                axis=0,
            )
            costs[key].append(means)
        else:
            inp = self.trajs.CVs["central_distances"]
            if np.any(np.isnan(inp)):
                inp = np_to_sparse_tensor(inp[::subsample])
                inp = to_dense_models[key](inp).numpy()
            mean_lengths = np.expand_dims(np.mean(inp, axis=0), axis=0)
            chain = chain_in_plane(mean_lengths, costs["angle_cost"][3])
            gen_cartesians = dihedrals_to_cartesian_tf(
                costs["dihedral_cost"][3] + np.pi, chain
            )
            pd = pairwise_dist(
                gen_cartesians[
                    :,
                    self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
                ],
                flat=True,
            )
            costs[key].append(pd)

    steps_done = 0
    with tqdm(
        desc="Calculating references",
        total=nsteps,
        position=0,
        leave=True,
    ) as pbar:
        for i, data in zip(range(nsteps), self.dataset):
            for d in data:
                if not isinstance(d, tf.sparse.SparseTensor):
                    if np.any(np.isnan(d)) and not self.sparse:
                        raise Exception(
                            f"Received data containing nans from `self.dataset` ({d=}), "
                            f"while `self.sparse` is set to False ({self.sparse=}). "
                            f"This training won't work as nans compromise the weights "
                            f"of the whole model. Try to explicitly set `sparse=True`, "
                            f"when instantiating the {self.__class__.__name__} class."
                        )
            for key, val in costs.items():
                if key in ["dihedral_cost", "angle_cost"]:
                    if isinstance(data[val[1]], tf.sparse.SparseTensor):
                        d = to_dense_models[key](data[val[1]]).numpy()
                    else:
                        d = data[val[1]]
                    cost_references[key].append(val[2](d, val[3]).numpy())
                if key == "cartesian_cost":
                    if isinstance(data[val[1]], tf.sparse.SparseTensor):
                        d = self.model.get_dense_model_cartesians(
                            data[val[1]]
                        ).numpy()
                        # un-flatten the cartesian coordinates
                        d = d.reshape(len(d), -1, 3)
                    else:
                        d = data[val[1]]
                    pd = pairwise_dist(
                        d[
                            :,
                            self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
                        ],
                        flat=True,
                    )
                    c = val[2](val[3], pd).numpy()
                    cost_references["cartesian_cost"].append(c)
            pbar.update()
            # `i` is zero-based, the number of finished steps is one more.
            steps_done = i + 1
    s = {k: np.mean(v) for k, v in cost_references.items()}
    print(f"After {steps_done} steps setting cost references: {s} to parameters.")
    self.p.angle_cost_reference = float(np.mean(cost_references["angle_cost"]))
    self.p.dihedral_cost_reference = float(
        np.mean(cost_references["dihedral_cost"])
    )
    self.p.cartesian_cost_reference = float(
        np.mean(cost_references["cartesian_cost"])
    )
    if not self.read_only:
        self.p.save()
    return cost_references

1962 

def save(self, step: Optional[int] = None) -> None | Path:
    """Writes the model to the directory given by `parameters.main_path`.

    Args:
        step (Optional[int]): Does not save the model at the given
            training step, but rather changes the string used for saving
            the model from a datetime format to another.

    Returns:
        Union[None, Path]: The parent directory of the path returned by
            `save_model()`, or None, if this instance is `read_only` and
            nothing was written.

    """
    if self.read_only:
        cls_name = self.__class__.__name__
        print(
            f"This {cls_name} is set to read_only. Set "
            f"`{cls_name}.read_only=False` to save the "
            f"current state of the model."
        )
        return None

    written = save_model(
        self.model,
        self.p.main_path,
        inp_class_name=self.__class__.__name__,
        step=step,
        print_message=True,
    )
    return written.parent

1991 

def close(self) -> None:
    """Clears the current keras backend and frees up resources.

    Only calls `tf.keras.backend.clear_session()`. No attribute of this
    instance is changed.

    """
    # clear old sessions
    tf.keras.backend.clear_session()

1996 

def set_train_data(self, data: TrajEnsemble) -> None:
    """Resets the train data for reloaded models.

    Rebuilds `self.train_data` and `self.inp_CV_data` from `data` and
    replaces `self.dataset` with a shuffled, repeating, and batched
    `tf.data.Dataset`. The hypercube flags on the instance and on the
    parameters are reset to False.

    If `data` contains nans, but the model was set up for dense input,
    a message is printed and `self.dataset` is left untouched.

    Args:
        data (TrajEnsemble): The trajectories that provide the new CVs.

    """
    (
        sparse,
        self.train_data,
        self.inp_CV_data,
    ) = self.get_train_data_from_trajs(data, self.p)
    self._using_hypercube = False
    self.p.using_hypercube = False
    if not self.sparse and sparse:
        print(
            f"The provided data contains nan's, but the model was trained "
            f"on dense input data."
        )
        return
    if not self.p.reconstruct_sidechains:
        cartesians_key = "central_cartesians"
        data = [
            self.inp_CV_data["central_angles"],
            self.inp_CV_data["central_dihedrals"],
            self.inp_CV_data["central_cartesians"],
            self.inp_CV_data["central_distances"],
            self.inp_CV_data["side_dihedrals"],
        ]
    else:
        # With sidechain reconstruction, the dataset holds all cartesians,
        # so the shuffle buffer is sized from them as well.
        cartesians_key = "all_cartesians"
        data = [
            self.inp_CV_data["central_angles"],
            self.inp_CV_data["central_dihedrals"],
            self.inp_CV_data["all_cartesians"],
            self.inp_CV_data["central_distances"],
            self.inp_CV_data["side_dihedrals"],
            self.inp_CV_data["side_angles"],
            self.inp_CV_data["side_dihedrals"],
            self.inp_CV_data["side_distances_"],
        ]
    dataset = tf.data.Dataset.from_tensor_slices(tuple(data))
    # A buffer as large as the data gives a full shuffle.
    dataset = dataset.shuffle(
        buffer_size=self.inp_CV_data[cartesians_key].shape[0],
        reshuffle_each_iteration=True,
    )
    dataset = dataset.repeat()
    self.dataset = dataset.batch(self.p.batch_size)

2038 

def plot_network(self) -> None:
    """Tries to plot the network using pydot, pydotplus and graphviz.
    Doesn't raise an exception if plotting is not possible.

    If `plot_model()` returns an image, it is saved as 'network.png' in
    `self.p.main_path`.

    Note:
        Refer to this guide to install these programs:
        https://stackoverflow.com/questions/47605558/importerror-failed-to-import-pydot-you-must-install-pydot-and-graphviz-for-py

    """
    try:
        out = plot_model(self.model, None)
        if out is not None:
            out.save(Path(self.p.main_path) / "network.png")
    except Exception:
        # Plotting is best-effort. Only regular errors are swallowed, so
        # KeyboardInterrupt and SystemExit still reach the caller.
        pass

2054 

@staticmethod
def get_train_data_from_trajs(
    trajs: Union[TrajEnsemble, SingleTraj],
    p: ADCParameters,
    attr: str = "CVs",
    max_size: int = -1,
) -> tuple[bool, list[np.ndarray], dict[str, np.ndarray]]:
    """Assembles the training input from the CVs of `trajs`.

    Args:
        trajs (TrajEnsemble): A `TrajEnsemble` instance.
        p (encodermap.parameters.ADCParameters): The parameters, that
            decide which features end up in the train data.
        attr (str): The name of the dict-like attribute of `trajs` to read
            the features from. Defaults to 'CVs'.
        max_size (int): If greater than -1, every feature is subsampled
            to at most this many evenly spaced frames. Defaults to -1.

    Returns:
        tuple: A tuple containing the following:
            - bool: True, if any feature contains `np.nan`, which selects
                the sparse training.
            - list[np.ndarray]: The features fed into the autoencoder in
                the order central_angles, central_dihedrals,
                (side_dihedrals if `p.use_sidechains` is True). Only the
                central_dihedrals, if neither `p.use_backbone_angles` nor
                `p.use_sidechains` is set. For a dense, sequential model,
                the features are stacked along the feature axis. With
                `p.reconstruct_sidechains`, the list is central_angles,
                central_dihedrals, side_angles, side_dihedrals.
            - dict[str, np.ndarray]: All features as float32. Features
                containing nans are replaced by sparse tensors.

    Raises:
        Exception: If a known feature has frames that consist only of nans,
            or if sidechains are requested without backbone angles.

    """
    # Local Folder Imports
    from ..misc.misc import FEATURE_NAMES

    assert hasattr(trajs, attr), (
        f"Can't load train data from the attribute {attr}. "
        f"{trajs.__class__.__name__} has no attribute '{attr}'"
    )
    source = getattr(trajs, attr)
    sparse = any([np.isnan(feature).any() for feature in source.values()])

    if sparse:
        # A frame made up only of nans means the stacking went wrong.
        known_names = list(FEATURE_NAMES.values())
        for name, feature in trajs.CVs.items():
            if name not in known_names:
                continue
            frame_axes = (1, 2) if feature.ndim == 3 else 1
            if np.any(np.all(np.isnan(feature), frame_axes)):
                raise Exception(
                    f"Stacking of frames for CV `{name}` did not "
                    f"succeed. There are frames full of nans."
                )

    inp_CV_data = {
        name: feature.astype("float32")
        for name, feature in getattr(trajs, attr).items()
    }
    # xarray objects expose `.values`: unwrap them and drop length-one axes.
    if all([hasattr(feature, "values") for feature in inp_CV_data.values()]):
        inp_CV_data = {
            name: feature.values.squeeze() for name, feature in inp_CV_data.items()
        }

    if max_size > -1:
        for name, feature in inp_CV_data.items():
            evenly_spaced = np.linspace(0, feature.shape[0] - 1, max_size)
            keep = np.unique(np.round(evenly_spaced).astype(int))
            inp_CV_data[name] = feature[keep]

    for name, feature in inp_CV_data.items():
        if not np.any(np.isnan(feature)):
            continue
        if name == "central_cartesians":
            # Flatten the coordinates to one row per frame.
            feature = feature.reshape(len(feature), -1)
        inp_CV_data[name] = np_to_sparse_tensor(feature)

    if p.reconstruct_sidechains:
        train_data = [
            inp_CV_data["central_angles"],
            inp_CV_data["central_dihedrals"],
            inp_CV_data["side_angles"],
            inp_CV_data["side_dihedrals"],
        ]
        return sparse, train_data, inp_CV_data

    if not p.use_backbone_angles:
        if p.use_sidechains:
            raise Exception(
                "Cannot train model with central dihedrals and side "
                "dihedrals only. Backbone angles are required."
            )
        train_data = inp_CV_data["central_dihedrals"]
    else:
        train_data = [
            inp_CV_data["central_angles"],
            inp_CV_data["central_dihedrals"],
        ]
        if p.use_sidechains:
            train_data.append(inp_CV_data["side_dihedrals"])
        if p.model_api == "sequential" and not sparse:
            train_data = np.hstack(train_data)

    # The functional API needs one input per feature.
    if p.model_api == "functional" and p.use_backbone_angles:
        if not p.use_sidechains:
            assert len(train_data) == 2
        else:
            assert len(train_data) == 3
    return sparse, train_data, inp_CV_data

2183 

@property
def encoder(self) -> tf.keras.Model:
    """tf.keras.Model: The encoder Model.

    Returns the `encoder_model` attribute of `self.model`.

    """
    return self.model.encoder_model

2188 

@property
def decoder(self) -> tf.keras.Model:
    """tf.keras.Model: The decoder Model.

    Returns the `decoder_model` attribute of `self.model`.

    """
    return self.model.decoder_model

2193 

@classmethod
def from_checkpoint(
    cls: Type[AngleDihedralCartesianEncoderMapType],
    trajs: Union[None, TrajEnsemble],
    checkpoint_path: Union[Path, str],
    dataset: Optional[tf.data.Dataset] = None,
    use_previous_model: bool = False,
    compat: bool = False,
) -> AngleDihedralCartesianEncoderMapType:
    """Reconstructs the model from a checkpoint.

    Although the model can be loaded from disk without any form of data and
    still yield the correct input and output shapes, it is required to either
    provide `trajs` or `dataset` to double-check, that the correct model will
    be reloaded.

    This is also why the `sparse` argument is not needed, as sparsity of the
    input data is a property of the `TrajEnsemble` provided.

    All arguments are forwarded to `load_model()`, which does the actual
    work.

    Args:
        trajs (Union[None, TrajEnsemble]): Either None (in which case, the
            argument `dataset` is required), or an instance of `TrajEnsemble`,
            which was used to instantiate the `AngleDihedralCartesianEncoderMap`,
            before it was saved to disk.
        checkpoint_path (Union[Path, str]): The path to the checkpoint. Can
            either be the path to a .keras file or to a directory containing
            .keras files, in which case the most recently created .keras
            file will be used.
        dataset (Optional[tf.data.Dataset]): If `trajs` is not provided, a
            dataset is required to make sure the input shapes match the model,
            that is stored on the disk.
        use_previous_model (bool): Set this flag to True, if you load a model
            from an in-between checkpoint step (e.g., to continue training with
            different parameters). If you have the files saved_model_0.keras,
            saved_model_500.keras and saved_model_1000.keras, setting this to
            True and loading the saved_model_500.keras will back up the
            saved_model_1000.keras.
        compat (bool): Whether to use compatibility mode when missing or wrong
            parameter files are present. In this special case, some assumptions
            about the network architecture are made from the model and the
            parameters in parameters.json overwritten accordingly (a backup
            will also be made).

    Returns:
        AngleDihedralCartesianEncoderMapType: An instance of `AngleDihedralCartesianEncoderMap`.

    """
    return load_model(
        cls,
        checkpoint_path,
        trajs=trajs,
        dataset=dataset,
        use_previous_model=use_previous_model,
        compat=compat,
    )

2249 

def _setup_callbacks(self) -> None:
    """Overwrites the parent class' `_setup_callbacks` method.

    Fills `self.callbacks` and `self.metrics` from the parameters:

    * A `ProgressBar` is always added.
    * A `CheckpointSaver` is added, unless the instance is `read_only`.
    * With `p.tensorboard`, a `TensorboardWriteBool` is added, a default
      summary file writer is created in '<main_path>/train', and the
      parameters are written as a text summary. Otherwise,
      `self.tensorboard_write_bool` is None.
    * With a 'soft start' of the cartesian cost, the
      `cartesian_increase_callback` is added.
    * `p.track_clashes` and `p.track_RMSD` add the respective metrics.

    """
    self.metrics = []
    self.callbacks = []
    params = self.p
    callbacks = self.callbacks
    metrics = self.metrics

    callbacks.append(ProgressBar(parameters=params))
    if not self.read_only:
        callbacks.append(CheckpointSaver(params))

    if not params.tensorboard:
        self.tensorboard_write_bool = None
    else:
        self.tensorboard_write_bool: Union[TensorboardWriteBool, None] = (
            TensorboardWriteBool(params)
        )
        callbacks.append(self.tensorboard_write_bool)
        train_log_dir = params.main_path + "/train"
        writer = tf.summary.create_file_writer(train_log_dir)
        writer.set_as_default()
        tf.summary.text(
            name=f"Parameters Summary for {params.main_path}",
            data=params.parameters,
            step=0,
        )

    # (None, None) means, that no soft start was requested.
    if params.cartesian_cost_scale_soft_start != (None, None):
        self.cartesian_increase_callback = IncreaseCartesianCost(params)
        callbacks.append(self.cartesian_increase_callback)

    if params.track_clashes:
        metrics.append(ADCClashMetric(parameters=params, distance_unit="nm"))
    if params.track_RMSD:
        metrics.append(ADCRMSDMetric(parameters=params))

2283 

2284 def _get_loss(self): 

2285 loss = self.p.loss 

2286 if loss == "reconstruction_loss": 

2287 loss = reconstruction_loss(self.model) 

2288 elif loss == "emap_cost": 

2289 self.dihedral_loss = dihedral_loss( 

2290 self.model, self.p, self.tensorboard_write_bool 

2291 ) 

2292 self.angle_loss = angle_loss( 

2293 self.model, self.p, self.tensorboard_write_bool 

2294 ) 

2295 if self.p.cartesian_cost_scale_soft_start != (None, None): 

2296 self.cartesian_loss = cartesian_loss( 

2297 self.model, 

2298 self.cartesian_increase_callback, 

2299 self.p, 

2300 self.tensorboard_write_bool, 

2301 print_current_scale=False, 

2302 ) 

2303 else: 

2304 self.cartesian_loss = cartesian_loss( 

2305 self.model, None, self.p, self.tensorboard_write_bool 

2306 ) 

2307 

2308 self.distance_loss = distance_loss( 

2309 self.model, self.p, self.tensorboard_write_bool 

2310 ) 

2311 self.cartesian_distance_loss = cartesian_distance_loss( 

2312 self.model, self.p, self.tensorboard_write_bool 

2313 ) 

2314 self.center_loss = center_loss( 

2315 self.model, self.p, self.tensorboard_write_bool 

2316 ) 

2317 self.regularization_loss = regularization_loss( 

2318 self.model, self.p, self.tensorboard_write_bool 

2319 ) 

2320 loss = [ 

2321 self.dihedral_loss, 

2322 self.angle_loss, 

2323 self.cartesian_loss, 

2324 self.distance_loss, 

2325 self.cartesian_distance_loss, 

2326 self.center_loss, 

2327 self.regularization_loss, 

2328 ] 

2329 if self.p.use_sidechains: 

2330 self.side_dihedral_loss = side_dihedral_loss( 

2331 self.model, self.p, self.tensorboard_write_bool 

2332 ) 

2333 loss.append(self.side_dihedral_loss) 

2334 elif loss == "mse": 

2335 loss = "mse" 

2336 else: 

2337 raise ValueError( 

2338 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'." 

2339 ) 

2340 return loss 

2341 

def encode(
    self,
    data: Optional[Union[TrajEnsemble, SingleTraj, Sequence[np.ndarray]]] = None,
) -> np.ndarray:
    """Runs the central_angles, central_dihedrals, (side_dihedrals) through the
    autoencoder. Make sure that `data` has the correct shape.

    The input is dispatched on its type:

    * None: `self.trajs` is used, if present. Otherwise, the first element
      of `self.dataset` is taken and reduced to the features selected by
      `p.use_backbone_angles` and `p.use_sidechains`.
    * An object with a `_CVs` attribute, but without a `trajs` attribute:
      the train data is built with `get_train_data_from_trajs()`. Inputs
      larger than 100_000 values are encoded in chunks of frames.
    * An object with a `trajs` attribute: every trajectory is encoded on its
      own and the results are stacked.
    * A sequence of arrays: passed to the encoder directly, or in chunks
      of frames, if the first array has more than 100_000 values.
    * Anything with a `shape` or `dense_shape`: passed to the encoder as is.

    Args:
        data (Optional[Union[TrajEnsemble, SingleTraj, Sequence[np.ndarray]]]):
            Provide a sequence of angles, and central_dihedrals, if you used
            sidechain_dihedrals during training append these to the end of
            the sequence. Defaults to None.

    Returns:
        np.ndarray: The latent space representation of the provided `data`.

    Raises:
        TypeError: If `data` matches none of the supported input types.
        Exception: If the train data built from a trajectory has an
            unexpected type, or if the encoder raises an `IndexError` for
            a chunk.

    """
    if data is None:
        if hasattr(self, "trajs"):
            data = self.trajs
        else:
            # Take the first element of the dataset.
            for data in self.dataset:
                break
            # Pick the encoder inputs from that element: index 0 and 1
            # are used as angles and dihedrals, the last one as the
            # sidechain input.
            if not self.p.use_sidechains and not self.p.use_backbone_angles:
                data: np.ndarray = data[1]  # type: ignore[no-redef]
            elif self.p.use_backbone_angles and not self.p.use_sidechains:
                data: list[np.ndarray] = [data[0], data[1]]  # type: ignore[no-redef]
            else:
                data: list[np.ndarray] = [data[0], data[1], data[-1]]  # type: ignore[no-redef]

    if not hasattr(data, "trajs") and hasattr(data, "_CVs") and data is not None:
        # Has `_CVs`, but no `trajs`: presumably a `SingleTraj`.
        traj: SingleTraj = data  # type: ignore[assignment]
        _, data, __ = self.get_train_data_from_trajs(traj, self.p, attr="_CVs")
        # Estimate the input size to decide, whether chunking is needed.
        if isinstance(data, (tf.SparseTensor)):
            shape = data.dense_shape[0] * data.dense_shape[1]
        elif isinstance(data, np.ndarray):
            shape = data.size
            data = np.expand_dims(data, 0)
        elif isinstance(data[0], (np.ndarray, tf.Tensor)):
            # NOTE(review): relies on `.size`, which is a numpy attribute;
            # confirm it is available for `tf.Tensor` inputs.
            shape = data[0].size
        elif isinstance(data[0], tf.sparse.SparseTensor):
            shape = data[0].values.shape[0]
        else:
            raise Exception(f"Unexpected datatype {data=}")
        if shape > 100_000:
            print(
                f"Due to the size of the provided data {shape}, I "
                f"need to chunk it, which takes longer. Sit back, grab a coffee..."
            )
            # Split the frame indices at every multiple of 100.
            indices = np.split(
                np.arange(traj.n_frames), np.arange(100, traj.n_frames, 100)
            )
            # single frame encoding does not work, because the frame axis is dropped
            if len(indices[-1]) == 1:
                indices = np.split(
                    np.arange(traj.n_frames), np.arange(100, traj.n_frames, 101)
                )
            lowd = []
            for i, ind in enumerate(indices):
                _, data, __ = self.get_train_data_from_trajs(
                    traj[ind], self.p, attr="_CVs"
                )
                try:
                    lowd.append(self.model.encoder_model(data))
                except IndexError as e:
                    raise Exception(f"{i=} {ind=} {data=}") from e
            return np.vstack(lowd)
        else:
            # Small enough: encode the extracted features in one go.
            return self.encode(data)
    elif hasattr(data, "trajs"):
        # Has `trajs`: encode every contained trajectory on its own.
        lowd = []
        for traj in data.trajs:  # type: ignore[union-attr]
            lowd.append(self.encode(traj))
        return np.vstack(lowd)
    elif isinstance(data, Sequence):
        # Standard Library Imports
        from functools import reduce
        from operator import mul

        # The number of values in the first input decides about chunking.
        size = reduce(mul, data[0].shape)
        if size > 100_000:
            indices = np.split(
                np.arange(data[0].shape[0]),
                np.arange(100, data[0].shape[0], 100),
            )
            if len(indices[-1]) == 1:
                indices = np.split(
                    np.arange(data[0].shape[0]),
                    np.arange(100, data[0].shape[0], 101),
                )
            lowd = []
            for i, ind in enumerate(indices):
                datum = []
                for d in data:
                    if isinstance(d, tf.sparse.SparseTensor):
                        # Sparse tensors are densified with nan as fill
                        # value, indexed, and made sparse again.
                        d = tf.sparse.to_dense(d, default_value=np.nan).numpy()[ind]
                        datum.append(np_to_sparse_tensor(d))
                    else:
                        datum.append(d[ind])
                lowd.append(self.model.encoder_model(datum))
            return np.vstack(lowd)
        else:
            return self.model.encoder_model(data).numpy()  # type: ignore[no-any-return]
    elif hasattr(data, "shape") or hasattr(data, "dense_shape"):
        return self.model.encoder_model(data).numpy()  # type: ignore[no-any-return]
    else:
        raise TypeError(f"Wrong type for argument `data`: {type(data)=} {data=}.")

2448 

def decode(self, data: np.ndarray) -> Sequence[np.ndarray]:
    """Runs `data` through the decoder and converts the result to numpy.

    The input is handed to `self.model.decoder`. Whatever the decoder
    returns is converted by calling `.numpy()` on it (single output) or on
    each of its elements (list or tuple of outputs).

    Args:
        data (np.ndarray): The data to be passed to the decoder part of
            the model. Make sure that the shape of the data matches the
            number of neurons in the latent space.

    Returns:
        Union[list[np.ndarray], np.ndarray]: A list with one array per
            decoder output, if the decoder returned a list or a tuple.
            Otherwise, a single array. Note that the annotated return type
            only covers the list case.

    """
    decoded = self.model.decoder(data)
    # Single-output decoders give back one object, which is converted directly.
    if not isinstance(decoded, (list, tuple)):
        return decoded.numpy()  # type: ignore[no-any-return]
    # Multi-output decoders: tuples are normalized to a list as well.
    return [part.numpy() for part in decoded]  # type: ignore[no-any-return]

2472 

@overload
def generate(
    self,
    points: np.ndarray,
    top: Optional[Union[str, int, Topology]] = ...,
    backend: Literal["mdtraj"] = ...,
    progbar: Optional[Any] = ...,
) -> Trajectory: ...


@overload
def generate(
    self,
    points: np.ndarray,
    top: Optional[Union[str, int, Topology]],
    backend: Literal["mdanalysis"],
    progbar: Optional[Any] = ...,
) -> Universe: ...


def generate(
    self,
    points: np.ndarray,
    top: Optional[Union[str, int, Topology]] = None,
    backend: Literal["mdtraj", "mdanalysis"] = "mdtraj",
    progbar: Optional[Any] = None,
) -> Union[Universe, Trajectory]:
    """Overrides the parent class' `generate` method and builds a trajectory.

    Instead of just providing data to `decode` using the decoder part of the
    network, this method also takes a molecular topology as its `top`
    argument. This topology is then used to rebuild a time-resolved
    trajectory.

    Args:
        points (np.ndarray): The low-dimensional points from which the
            trajectory should be rebuilt.
        top (Optional[Union[str, int, mdtraj.Topology]]): The topology to be
            used for rebuilding the trajectory. Can be a string (either one
            of `self.trajs.common_str` or a file that `mdtraj.load` can
            read, e.g., <*.pdb, *.gro, *.h5>), an int indexing into
            `self.trajs`, or a `mdtraj.Topology`. Alternatively, None can
            be provided; in which case, the internal topology of this class
            is used. If None is provided, but `self.trajs` contains more
            than one topology, a hint is printed and None is returned.
            Defaults to None.
        backend (str): Defines what MD python package is to use, to build the
            trajectory and also what type this method returns, needs to be
            one of the following:
                * "mdtraj"
                * "mdanalysis"
        progbar (Optional[Any]): Passed on unchanged to `mdtraj_backmapping`.
            Only used with the "mdtraj" backend. Defaults to None.

    Returns:
        Union[mdtraj.Trajectory, MDAnalysis.universe]: The trajectory after
            applying the decoded structural information. The type of this
            depends on the chosen `backend` parameter.

    Raises:
        Exception: If the decoder output is neither a single array nor a
            list/tuple of two or three arrays.
        AssertionError: If `self.trajs` is not a `TrajEnsemble`.
        ValueError: If `top` is not None, int, str, or `mdtraj.Topology`.
        TypeError: If `backend` is neither "mdtraj" nor "mdanalysis".

    """
    # get the output this can be done regardless
    out = self.decode(points)
    if isinstance(out, np.ndarray):
        dihedrals = out
        sidechain_dihedrals = None
    elif (
        isinstance(out, (list, tuple))
        and len(out) in (2, 3)
        and all(isinstance(a, np.ndarray) for a in out)
    ):
        # Two outputs are (angles, dihedrals), three outputs additionally
        # carry the sidechain dihedrals. The angles are not needed here.
        dihedrals = out[1]
        sidechain_dihedrals = out[2] if len(out) == 3 else None
    else:
        raise Exception(
            f"Unexpected length of out detected: ({len(out)=}). Maybe also "
            f"unexpected dtypes: ({[type(a) for a in out]=})."
        )

    assert isinstance(self.trajs, TrajEnsemble), (
        "`generate()` can only work, when the Autoencoder was built with a "
        "`TrajEnsemble` and not just a dataset. You can set the `TrajEnsemble` "
        "of this object with `AngleDihedralCartesianEncoderMap.trajs = TrajEnsemble`."
    )

    # Template structure for the "mdanalysis" backend. None means
    # "resolve later from the ensemble's only topology".
    mdanalysis_traj = None
    if top is None:
        if len(self.trajs.top) > 1:
            print(
                f"Please specify which topology you would like to use for generating "
                f"conformations. You can either provide a `str` to a topology file "
                f"(file extension .pdb, .h5, .gro) on disk, or an `int` specifying "
                f"one of the ensemble's {len(self.trajs.top)} topologies "
                f"(see `AngleDihedralCartesianEncoderMap.trajs.top` for available "
                f"topologies). You can also directly supply a "
                f"`mdtraj.Topology` object."
            )
            return  # type: ignore[return-value]
    elif isinstance(top, int):
        mdanalysis_traj = self.trajs[top][0].traj
    elif isinstance(top, str) and top not in self.trajs.common_str:
        mdanalysis_traj = md.load(top)
    elif isinstance(top, str) and top in self.trajs.common_str:
        mdanalysis_traj = self.trajs.trajs_by_common_str[top][0].traj
    elif isinstance(top, md.Topology):
        # NOTE(review): the "mdanalysis" backend calls `.save_pdb()` on this
        # object, which `md.Topology` does not appear to provide - confirm
        # whether this combination is expected to work.
        mdanalysis_traj = top
    else:
        raise ValueError(
            f"Type of argument `top` must be int, str, md.Topology. You provided {type(top)}."
        )

    # do the backmapping
    if backend == "mdanalysis":
        if mdanalysis_traj is None:
            # `top` was None and the ensemble has a single topology: use the
            # first frame of the first trajectory, like the `int` branch
            # does for `top=0`. Resolved lazily, so that the "mdtraj"
            # backend is not affected.
            mdanalysis_traj = self.trajs[0][0].traj
        with NamedTemporaryFile(suffix=".pdb") as f:
            mdanalysis_traj.save_pdb(f.name)
            uni = dihedral_backmapping(
                f.name, dihedrals, sidechains=sidechain_dihedrals
            )
        return uni
    elif backend == "mdtraj":
        traj = mdtraj_backmapping(
            top=top,
            dihedrals=dihedrals,
            sidechain_dihedrals=sidechain_dihedrals,
            trajs=self.trajs,
            progbar=progbar,
        )  # type: ignore[call-overload]
        return traj
    else:
        raise TypeError(
            f"backend must be 'mdtraj' or 'mdanalysis', but you provided {backend}"
        )