Coverage for encodermap/autoencoder/autoencoder.py: 8%
820 statements
1# -*- coding: utf-8 -*-
2# encodermap/autoencoder/autoencoder.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Forward facing Autoencoder classes. Contains four classes:
24* Autoencoder: A simple dense, fully-connected AE architecture. Uses regularization, auto, and center losses.
25* EncoderMap: Uses the same architecture as `Autoencoder`, but adds another loss function.
26* DihedralEncoderMap: Basically the same as `EncoderMap`, but rewrites the `generate` method to use
27 an atomistic topology to rebuild a trajectory.
28* AngleDihedralCartesianEncoderMap: Uses more loss functions and tries to learn a full all-atom conformation.
30"""
33################################################################################
34# Imports
35################################################################################
38# Future Imports at the top
39from __future__ import annotations
41# Standard Library Imports
42import copy
43import os
44import warnings
45from pathlib import Path
46from tempfile import NamedTemporaryFile
48# Third Party Imports
49import numpy as np
50import tensorflow as tf
51from optional_imports import _optional_import
52from tqdm import tqdm
54# Encodermap imports
55from encodermap.callbacks.callbacks import (
56 CheckpointSaver,
57 ImageCallback,
58 IncreaseCartesianCost,
59 ProgressBar,
60 TensorboardWriteBool,
61)
62from encodermap.callbacks.metrics import ADCClashMetric, ADCRMSDMetric
63from encodermap.encodermap_tf1.backmapping import (
64 chain_in_plane,
65 dihedrals_to_cartesian_tf,
66)
67from encodermap.loss_functions.loss_functions import (
68 angle_loss,
69 auto_loss,
70 cartesian_distance_loss,
71 cartesian_loss,
72 center_loss,
73 dihedral_loss,
74 distance_loss,
75 reconstruction_loss,
76 regularization_loss,
77 side_dihedral_loss,
78)
79from encodermap.misc.backmapping import dihedral_backmapping, mdtraj_backmapping
80from encodermap.misc.distances import pairwise_dist
81from encodermap.misc.misc import create_n_cube, plot_model
82from encodermap.misc.saving_loading_models import load_model, save_model
83from encodermap.models.models import gen_functional_model, gen_sequential_model
84from encodermap.parameters.parameters import ADCParameters, Parameters
85from encodermap.trajinfo.info_all import TrajEnsemble
86from encodermap.trajinfo.info_single import Capturing, SingleTraj
89################################################################################
90# Optional Imports
91################################################################################
94md = _optional_import("mdtraj")
95mda = _optional_import("MDAnalysis")
98################################################################################
99# Typing
100################################################################################
103# Standard Library Imports
104from collections.abc import Callable, Mapping
105from typing import (
106 TYPE_CHECKING,
107 Any,
108 Literal,
109 Optional,
110 Sequence,
111 Type,
112 TypeVar,
113 Union,
114 overload,
115)
118AutoencoderType = TypeVar("AutoencoderType", bound="Parent")
119EncoderMapType = TypeVar("EncoderMapType", bound="Parent")
120DihedralEncoderMapType = TypeVar("DihedralEncoderMapType", bound="Parent")
121AngleDihedralCartesianEncoderMapType = TypeVar(
122 "AngleDihedralCartesianEncoderMapType", bound="Parent"
123)
124AutoencoderClass = Union[
125 AutoencoderType,
126 EncoderMapType,
127 DihedralEncoderMapType,
128 AngleDihedralCartesianEncoderMapType,
129]
132if TYPE_CHECKING:
133 # Third Party Imports
134 from MDAnalysis import Universe
135 from mdtraj import Topology, Trajectory
138################################################################################
139# Globals
140################################################################################
143__all__: list[str] = [
144 "Autoencoder",
145 "EncoderMap",
146 "AngleDihedralCartesianEncoderMap",
147 "DihedralEncoderMap",
148]
151################################################################################
152# Utils
153################################################################################
156def history_to_json_serializable(d: dict[str, Any]) -> dict[str, Any]:
    """Recursively converts a (possibly nested) training-history dict into a
    JSON-serializable dict by turning numpy arrays into lists."""
157 out = {}
158 for k in list(d.keys()):
159 if isinstance(d[k], Mapping):
160 out |= history_to_json_serializable(d[k])
161 elif isinstance(d[k], np.ndarray):
162 a = d[k]
163 if a.ndim <= 1:
164 out[k] = a.tolist()
165 else:
166 out[k] = [i.tolist() for i in a]
167 elif isinstance(d[k], list) and all(isinstance(i, np.ndarray) for i in d[k]):
168 out[k] = [i.tolist() for i in d[k]]
169 else:
170 out[k] = d[k]
171 return out
174def np_to_sparse_tensor(a: np.ndarray) -> tf.sparse.SparseTensor:
175 """Converts a numpy array with nans to a SparseTensor.
177 Args:
178 a (np.ndarray): The input array.
180 Returns:
181 tf.sparse.SparseTensor: The corresponding SparseTensor.
183 """
184 orig_shape = a.shape
185 indices = np.stack(np.where(~np.isnan(a))).T.astype("int64")
186 dense_shape = a.shape
187 a = a[~np.isnan(a)].flatten()
188 if np.any(np.isnan(a)):
189 raise Exception(
190 f"NaN values in array with shape {orig_shape} could not be removed "
191 f"by indexing with {indices=}. This will result in the SparseTensor "
192 f"containing NaN values."
193 )
194 return tf.sparse.SparseTensor(indices, a, dense_shape)
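# A minimal sketch of `np_to_sparse_tensor`: NaN positions are dropped from the
# SparseTensor, and `tf.sparse.to_dense(st, default_value=np.nan)` restores the
# original dense array:
#
#     >>> a = np.array([[1.0, np.nan], [2.0, 3.0]])
#     >>> st = np_to_sparse_tensor(a)
#     >>> st.values.numpy().tolist()
#     [1.0, 2.0, 3.0]
#     >>> st.dense_shape.numpy().tolist()
#     [2, 2]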
197def _add_images_to_tensorboard(
198 autoencoder: AutoencoderClass,
199 data: Optional[Union[np.ndarray, Sequence[np.ndarray]]] = None,
200 backend: Literal["matplotlib", "plotly"] = "matplotlib",
201 image_step: Optional[int] = None,
202 max_size: int = 10_000,
203 mpl_scatter_kws: Optional[dict] = None,
204 mpl_hist_kws: Optional[dict] = None,
205 plotly_scatter_kws: Optional[dict] = None,
206 plotly_hist_kws: Optional[dict] = None,
207 additional_fns: Optional[Sequence[Callable]] = None,
208 when: Literal["epoch", "batch"] = "epoch",
209 save_to_disk: bool = False,
210) -> None:
211 """Adds images to Tensorboard using the data in data and the ids in ids.
213 Args:
214 data (Optional[Union[np.ndarray, Sequence[np.ndarray]]]): The input data will
215 be passed through the encoder part of the autoencoder. If None
216 is provided, a set of 10_000 points from `self.train_data` will
217 be taken. A list[np.ndarray] is needed for the functional API of the
218 `AngleDihedralCartesianEncoderMap`, that takes a list of
219 [angles, dihedrals, side_dihedrals]. Defaults to None.
220 backend (Literal["matplotlib", "plotly"]): Which backend to use for
221 plotting. Defaults to 'matplotlib'.
222 mpl_scatter_kws (Optional[dict]): A dictionary, that `matplotlib.pyplot.scatter`
223 takes as keyword args. If None is provided, the default dict
224 is {"s": 20}. Defaults to None.
225 mpl_hist_kws (Optional[dict]): A dictionary, that `matplotlib.pyplot.histogram`
226 takes as keyword args. If None is provided, the default dict
227 is {"bins": 50}. Defaults to None.
228 plotly_scatter_kws (Optional[dict[str, Any]]): A dict with items that
229 `plotly.express.scatter()` will accept. If None is provided,
230 the default dict `{'size_max': 1, 'opacity': 0.2}` will be
231 passed to `px.scatter()`,
232 which sets an appropriate size of scatter points for the size of
233 datasets encodermap is usually used for.
234 plotly_hist_kws (Optional[dict[str, Any]]): A dict with items that
235 `encodermap.plot.plotting._plot_free_energy()` will accept. If None is provided, a
236 dict with 50 bins will be passed to
237 `encodermap.plot.plotting._plot_free_energy(**{'bins': 50})`.
238 You can choose a colormap here by providing `{'bins': 50, 'cmap':
239 'plasma'}` for this argument.
240 image_step (Optional[int]): The interval in which to plot
241 images to tensorboard. If None is provided, the `image_step`
242 will be the same as `Parameters.summary_step`. Defaults to None.
243 max_size (int): The maximum size of the high-dimensional data, that is
244 projected. Prevents excessively large datasets from being projected
245 at every `image_step`. Defaults to 10_000.
246 additional_fns (Optional[Sequence[Callable]]): A list of functions
247 that will accept the low-dimensional output of the `Autoencoder`
248 latent/bottleneck layer and return a tf.Tensor that can be logged
249 by `tf.summary.image()`. See the notebook
250 'writing_custom_images_to_tensorboard.ipynb' in
251 tutorials/notebooks_customization for more info. If None is
252 provided, no additional functions will be used to plot to
253 tensorboard. Defaults to None.
254 when (Literal["epoch", "batch"]): When to log the images can be
255 either 'batch', then the images will be logged after every step
256 during training, or 'epoch', then only after every image_step
257 epoch the images will be written. Defaults to 'epoch'.
258 save_to_disk (bool): Whether to also write the images to disk.
260 """
261 if not autoencoder.p.tensorboard:
262 warnings.warn(
263 "Nothing is written to Tensorboard for this model. "
264 "Please change parameters.tensorboard to True."
265 )
266 return
267 if any([isinstance(i, ImageCallback) for i in autoencoder.callbacks]):
268 warnings.warn(
269 f"This instance of {autoencoder.__class__.__name__} already has an "
270 f"`ImageCallback`. While it's allowed to have multiple ImageCallbacks, "
271 f"training performance suffers by adding more."
272 )
273 if image_step is None:
274 image_step = autoencoder.p.summary_step
276 if mpl_scatter_kws is None:
277 mpl_scatter_kws = {"s": 20}
278 if mpl_hist_kws is None:
279 mpl_hist_kws = {"bins": 50}
280 if plotly_scatter_kws is None:
281 plotly_scatter_kws = {"size_max": 1, "opacity": 0.2}
282 if plotly_hist_kws is None:
283 plotly_hist_kws = {"bins": 50}
284 if backend not in ["matplotlib", "plotly"]:
285 raise Exception("Argument `backend` must be either 'plotly' or 'matplotlib'.")
287 autoencoder._log_images = True
288 if data is None:
289 if hasattr(autoencoder, "train_data"):
290 data = autoencoder.train_data
291 else:
292 if hasattr(autoencoder, "trajs"):
293 data = autoencoder.get_train_data_from_trajs(
294 autoencoder.trajs, autoencoder.p, max_size=max_size
295 )[1]
296 else:
297 if (
298 not autoencoder.p.use_backbone_angles
299 and not autoencoder.p.use_sidechains
300 ):
301 data = [[]]
302 elif (
303 autoencoder.p.use_backbone_angles
304 and not autoencoder.p.use_sidechains
305 ):
306 data = [[], []]
307 elif autoencoder.p.use_sidechains and autoencoder.p.use_backbone_angles:
308 data = [[], [], []]
309 else:
310 raise Exception(
311 "Only allowed combinations are:\n"
312 " * No sidechains, no backbone angles\n"
313 " * No sidechains, yes backbone angles\n"
314 " * Yes Sidechains, yes backbone angles\n"
315 f"Your parameters are: {autoencoder.p.use_sidechains=}. {autoencoder.p.use_backbone_angles=}"
316 )
317 length = 0
318 for d in autoencoder.dataset:
319 if len(data) == 1:
320 data[0].append(d[1])
321 elif len(data) == 2:
322 data[0].append(d[0])
323 data[1].append(d[1])
324 elif len(data) == 3:
325 data[0].append(d[0])
326 data[1].append(d[1])
327 data[2].append(d[-1])
328 length += autoencoder.p.batch_size
329 if length > max_size:
330 break
331 for i, o in enumerate(data):
332 if any([isinstance(d, tf.sparse.SparseTensor) for d in o]):
333 o = [
334 tf.sparse.to_dense(_, default_value=np.nan).numpy()
335 for _ in o
336 ]
337 o = np.concatenate(o)
338 data[i] = np_to_sparse_tensor(o)
339 else:
340 data[i] = np.concatenate(o)
341 else:
342 max_size = -1
344 if isinstance(data, (np.ndarray, tf.sparse.SparseTensor)):
345 if hasattr(autoencoder, "_tensorboard_data_req_shape"):
346 assert np.array_equal(
347 tf.shape(data).numpy()[1:], autoencoder._tensorboard_data_req_shape[1:]
348 ), (
349 f"The provided `data` has the wrong shape. The provided data has "
350 f"shape {tf.shape(data).numpy()}, whereas {autoencoder._tensorboard_data_req_shape} "
351 f"was expected."
352 )
353 else:
354 for d in autoencoder.dataset:
355 break
356 if isinstance(data, (tuple, list)):
357 assert data[0].shape[1:] == d[1].shape[1:], (
358 f"The provided `data` has the wrong shape. The provided data has "
359 f"shape {data[0].shape[1:]}, whereas {d[1].shape[1:]} "
360 f"was expected."
361 )
362 else:
363 assert data.shape[1:] == d[1].shape[1:], (
364 f"The provided `data` has the wrong shape. The provided data has "
365 f"shape {data[0].shape[1:]}, whereas {d[1].shape[1:]} "
366 f"was expected."
367 )
368 if data.shape[0] > max_size and max_size >= 0:
369 idx = np.unique(
370 np.round(np.linspace(0, data.shape[0] - 1, max_size)).astype(int)
371 )
372 if isinstance(data, tf.sparse.SparseTensor):
373 data = tf.sparse.to_dense(data, default_value=np.nan).numpy()[idx]
374 data = np_to_sparse_tensor(data)
375 else:
376 data = data[idx]
377 if isinstance(data, np.ndarray):
378 if np.any(np.isnan(data)):
379 data = np_to_sparse_tensor(data)
380 elif isinstance(data, (tuple, list)):
381 for d in autoencoder.dataset:
382 break
383 if len(data) == 1:
384 assert data[0].shape[1:] == d[1].shape[1:], (
385 f"The provided `data` has the wrong shape. The provided data has "
386 f"shape {data[0].shape[1:]}, whereas {d[1].shape[1:]} "
387 f"was expected."
388 )
389 data = data[0]
390 elif len(data) == 2:
391 assert (
392 data[0].shape[1:] == d[0].shape[1:]
393 and data[1].shape[1:] == d[1].shape[1:]
394 ), (
395 f"The provided `data` has the wrong shape. The provided data has "
396 f"shape {[_.shape[1:] for _ in data]}, whereas {[d[0].shape[1:], d[1].shape[1:]]} "
397 f"was expected."
398 )
399 elif len(data) == 3:
400 assert (
401 data[0].shape[1:] == d[0].shape[1:]
402 and data[1].shape[1:] == d[1].shape[1:]
403 and data[2].shape[1:] == d[-1].shape[1:]
404 ), (
405 f"The provided `data` has the wrong shape. The provided data has "
406 f"shape {[_.shape[1:] for _ in data]}, whereas {[d[0].shape[1:], d[1].shape[1:], d[-1].shape[1:]]} "
407 f"was expected."
408 )
409 else:
410 raise TypeError(
411 f"Argument `data` should be of type None, np.ndarray, tuple, or "
412 f"list, you provided {type(data)}."
413 )
415 # add the callback
416 if save_to_disk:
417 save_dir = Path(autoencoder.p.main_path) / "train_images"
418 save_dir.mkdir(exist_ok=True)
419 else:
420 save_dir = None
422 autoencoder.callbacks.append(
423 ImageCallback(
424 parameters=autoencoder.p,
425 highd_data=data,
426 image_step=image_step,
427 backend=backend,
428 mpl_scatter_kws=mpl_scatter_kws,
429 mpl_hist_kws=mpl_hist_kws,
430 plotly_scatter_kws=plotly_scatter_kws,
431 plotly_hist_kws=plotly_hist_kws,
432 additional_fns=additional_fns,
433 when=when,
434 save_dir=save_dir,
435 )
436 )
437 autoencoder.callbacks[-1].model = autoencoder.model
438 if isinstance(data, (np.ndarray, tf.sparse.SparseTensor)):
439 print(
440 f"Logging images with {data.shape}-shaped data every "
441 f"{image_step} epochs to Tensorboard at {autoencoder.p.main_path}"
442 )
443 else:
444 print(
445 f"Logging images with {[i.shape for i in data]}-shaped data "
446 f"every {image_step} epochs to Tensorboard at {autoencoder.p.main_path}"
447 )
450def _print_save_message(autoencoder: AutoencoderClass) -> None:
451 if autoencoder.p.main_path == Path(os.getcwd()):
452 print(
453 f"Output files are saved to {autoencoder.p.main_path}, which is the "
454 f"current working trajectory."
455 )
456 else:
457 print(
458 f"Output files are saved to {autoencoder.p.main_path} as defined "
459 f"in 'main_path' in the parameters.",
460 )
463def _get_model(autoencoder: AutoencoderClass) -> tf.keras.Model:
464 """sets self.model according to `model_api` argument in self.parameters."""
465 model = autoencoder.p.model_api
466 if model == "functional":
467 assert isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
468 d = autoencoder.dataset.take(1)
469 if any(isinstance(_, tf.SparseTensorSpec) for _ in d.element_spec):
470 autoencoder.sparse = True
471 if hasattr(autoencoder.p, "reconstruct_sidechains"):
472 if autoencoder.p.reconstruct_sidechains:
473 assert len(d.element_spec) == 7
474 model = gen_functional_model(
475 autoencoder.dataset,
476 autoencoder.p,
477 sparse=autoencoder.sparse,
478 )
479 elif model == "sequential":
480 assert (
481 isinstance(autoencoder, (Autoencoder, EncoderMap, DihedralEncoderMap))
482 or autoencoder.__class__.__name__ == "EncoderMap"
483 )
484 if isinstance(autoencoder.train_data, tf.sparse.SparseTensor):
485 autoencoder.sparse = True
486 try:
487 model = gen_sequential_model(
488 autoencoder.train_data.shape[1],
489 autoencoder.p,
490 sparse=autoencoder.sparse,
491 )
492 except AttributeError:
493 if autoencoder.p.training == "custom":
494 for d in autoencoder.train_data:
495 break
496 model = gen_sequential_model(
497 d[0].get_shape().as_list()[1],
498 autoencoder.p,
499 sparse=autoencoder.sparse,
500 )
501 elif autoencoder.p.training == "auto":
502 for d, _ in autoencoder.train_data:
503 break
504 model = gen_sequential_model(
505 d.get_shape().as_list()[1],
506 autoencoder.p,
507 sparse=autoencoder.sparse,
508 )
509 else:
510 raise Exception(
511 f"Parameter `training` has to be one of 'custom', 'auto'. "
512 f"You supplied '{autoencoder.p.training}'."
513 )
515 elif model == "custom":
516 raise NotImplementedError("No custom API currently supported")
517 else:
518 raise ValueError(
519 f"API argument needs to be one of `functional`, `sequential`, "
520 f"`custom`. You provided '{model}'."
521 )
522 assert not isinstance(model, str)
523 return model
526##############################################################################
527# Function definition which allows self.p.tensorboard to be passed
528##############################################################################
531def function(f, tensorboard=False):
532 """Compiles functions with `tensorflow.function` based on a `tensorboard`
533 parameter.
536 To understand the necessity of this function, we need to have a look at how
537 tensorflow executes computations. There are two modes of execution:
538 * eager mode: In eager mode, the computations are handled by python.
539 The input types are python objects, and the output is a python object.
540 This eager execution allows you to directly execute a calculation of
541 two tensors (e.g. multiplication).
542 * graph mode: In graph mode, computations are done inside tensorflow graphs,
543 which are a collection of operations and tensors (i.e. data), that flow
544 through the operations of the graph. These graphs make tensorflow
545 computations portable and significantly increase the performance of
546 similar computations.
547 Normally, you would accelerate a simple python function in tensorflow,
548 by compiling it like so:
550 ```python
551 import tensorflow as tf
553 @tf.function
554 def multiply(a, b):
555 return a * b
557 multiply(tf.constant(2), tf.constant(3)).numpy()
558 # 6
559 ```
561 However, the basic paradigm of accelerating the computation interferes with
562 `encodermap.Parameters` `tensorboard=True` argument, as it writes a lot of
563 additional information to tensorboard. Thus, a compilation with tf.function
564 does not make sense here. That's why EncoderMap's `function` decorator
565 takes the `tensorboard` argument into account and only compiles the
    decorated function with `tf.function` when tensorboard mode is off.
567 """
569 def wrapper(*args, **kwargs):
570 """Wrapper of `encodermap.function`."""
571 tensorboard = kwargs.pop("tensorboard", False)
572 if tensorboard:
573 warnings.warn(
574 "Running in tensorboard mode writes a lot of stuff to "
575 "tensorboard. For speed boost deactivate tensorboard mode."
576 )
577 result = f(*args, **kwargs)
578 else:
579 compiled = tf.function(f)
580 result = compiled(*args, **kwargs)
581 return result
583 return wrapper
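# Usage sketch for the `function` decorator (the decorated function is
# hypothetical): without `tensorboard=True` in the kwargs, the call is
# compiled with `tf.function` (note that the wrapper re-applies `tf.function`
# on every call); with `tensorboard=True`, it runs eagerly and warns:
#
#     >>> @function
#     ... def multiply(a, b):
#     ...     return a * b
#     >>> multiply(tf.constant(2), tf.constant(3)).numpy()
#     6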
586##############################################################################
587# Public Classes
588##############################################################################
591class Autoencoder:
592 """Main Autoencoder class. Presents all high-level functions.
594 This is the main class for neural networks inside EncoderMap. The class
595 prepares the data (batching and shuffling), creates a `tf.keras.Model`
596 of layers specified by the attributes of the `encodermap.Parameters` class.
597 Depending on what Parent/Child-Class is instantiated, a combination of
598 various cost functions is set up. Callbacks to Tensorboard are also set up.
600 Attributes:
601 train_data (np.ndarray): The numpy array of the train data passed at init.
602 p (AnyParameters): An `encodermap.Parameters` class
603 containing all info needed to set up the network.
604 dataset (tensorflow.data.Dataset): The dataset that is actually used
605 in training the keras model. The dataset is a batched, shuffled,
606 infinitely-repeating dataset.
607 read_only (bool): Variable telling the class whether it is allowed to
608 write to disk (False) or not (True).
609 metrics (list[Any]): A list of metrics passed to the model when it is compiled.
610 callbacks (list[Any]): A list of tf.keras.callbacks.Callback subclasses
611 changing the behavior of the model during training.
612 Some standard callbacks are always present like:
613 * encodermap.callbacks.callbacks.ProgressBar:
614 A progress bar callback using tqdm giving the current
615 progress of training and the current loss.
616 * CheckpointSaver:
617 A callback that saves the model every
618 `parameters.checkpoint_step` steps into the main directory.
619 This callback will only be used, when `read_only` is False.
620 * TensorboardWriteBool:
621 A callback that contains a boolean Tensor that will be
622 True or False, depending on the current training step and
623 the summary_step in the parameters class. The loss functions
624 use this callback to decide whether they should write to
625 Tensorboard. This callback will only be present when
626 `read_only` is False and `parameters.tensorboard` is True.
627 You can append your own callbacks to this list before executing
628 `self.train()`.
629 encoder (tf.keras.Model): The encoder submodel of `self.model`.
630 decoder (tf.keras.Model): The decoder submodel of `self.model`.
631 loss (Sequence[Callable]): A list of loss functions passed to the model
632 when it is compiled. When the main `Autoencoder` class is used and
633 `parameters.loss` is 'emap_cost', this list comprises center_cost,
634 regularization_cost, auto_cost. When the `EncoderMap` sub-class is
635 used and `parameters.loss` is 'emap_cost', distance_cost is added to
636 the list. When `parameters.loss` is not 'emap_cost', the loss can either
637 be a string ('mse'), or a function, that both are acceptable
638 arguments for loss, when a keras model is compiled.
640 Methods:
641 from_checkpoint: Rebuild the model from a checkpoint.
642 add_images_to_tensorboard: Make tensorboard plot images.
643 train: Starts the training of the tf.keras.models.Model.
644 plot_network: Tries to plot the network. For this method to work
645 graphviz, pydot and pydotplus need to be installed.
646 encode: Takes high-dimensional data and sends it through the encoder.
647 decode: Takes low-dimensional data and sends it through the decoder.
648 generate: Same as `decode`. For the `AngleDihedralCartesianEncoderMap` class,
649 this will build a protein structure.
651 Note:
652 Performance of tensorflow is not only dependent on your system's
653 hardware and how the data is presented to the network
654 (for this check out https://www.tensorflow.org/guide/data_performance),
655 but also how you compiled tensorflow. Normal tensorflow
656 (pip install tensorflow) is built without CPU extensions so that it works on
657 many CPUs. However, Tensorflow can greatly benefit from using CPU
658 instructions like AVX2, AVX512 that bring a speed-up in linear algebra
659 computations of 300%. By building tensorflow from source,
660 you can activate these extensions. However, the speed-up of using
661 tensorflow with a GPU dwarfs the CPU speed-up. To check whether a
662 GPU is available run: `print(len(tf.config.list_physical_devices('GPU')))`.
663 Refer to these pages to install tensorflow for the best performance:
664 https://www.tensorflow.org/install/pip and
665 https://www.tensorflow.org/install/gpu
667 Examples:
668 >>> import encodermap as em
669 >>> # without providing any data, default parameters and a 4D
670 >>> # hypercube as input data will be used.
671 >>> e_map = em.EncoderMap(read_only=True)
672 >>> print(e_map.train_data.shape)
673 (16000, 4)
674 >>> print(e_map.dataset) # doctest: +SKIP
675 <BatchDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name=None), TensorSpec(shape=(None, 4), dtype=tf.float32, name=None))>
676 >>> print(e_map.encode(e_map.train_data).shape)
677 (16000, 2)
679 """
681 def __init__(
682 self,
683 parameters=None,
684 train_data: Optional[Union[np.ndarray, tf.data.Dataset]] = None,
685 model: Optional[tf.keras.Model] = None,
686 read_only: bool = False,
687 sparse: bool = False,
688 ) -> None:
689 """Instantiate the Autoencoder class.
691 Args:
692 parameters (Union[encodermap.Parameters, None], optional): The
693 parameters to be used. If None is provided default values
694 (check them with `print(em.Parameters.defaults_description())`)
695 are used. Defaults to None.
696 train_data (Union[np.ndarray, tf.data.Dataset, None], optional):
697 The train data. Can be one of the following:
698 * None: If None is provided points on the edges of a
699 4-dimensional hypercube will be used as train data.
700 * np.ndarray: If a numpy array is provided, it will be
701 transformed into a batched tf.data.Dataset by first
702 making it an infinitely repeating dataset, shuffling
703 it and then batching it with a batch size specified
704 by parameters.batch_size.
705 * tf.data.Dataset: If a dataset is provided it will be
706 used without making any adjustments. Make sure, that the
707 dataset uses `float32` as its type.
708 Defaults to None.
709 model (Union[tf.keras.models.Model, None], optional): Providing
710 a keras model to this argument will make the Autoencoder/EncoderMap
711 class use this model instead of the predefined ones. Make sure
712 the model can accept EncoderMap's loss functions. If None is
713 provided the model will be built using the specifications in
714 parameters. Defaults to None.
715 read_only (bool, optional): Whether the class is allowed to write
716 to disk (False) or not (True). Defaults to False and will allow
717 the class to write to disk.
719 """
720 # parameters
721 if parameters is None:
722 self.p = Parameters()
723 else:
724 self.p = parameters
726 if self.p.seed is not None:
727 tf.random.set_seed(self.p.seed)
728 self.read_only = read_only
730 if not self.read_only:
731 self.p.write_summary = True
732 self.p.save()
733 self._print_save_message()
735 # check whether Tensorboard and Read-Only makes Sense
736 if self.read_only and self.p.tensorboard:
737 raise NotImplementedError
739 # clear old sessions
740 tf.keras.backend.clear_session()
741 self.sparse = sparse
743 # set up train_data
744 self.set_train_data(train_data)
746 # create model based on user input
747 if model is None:
748 self.model = self._get_model()
749 else:
750 self.model = model
752 # setup callbacks for nice progress bars and saving every now and then
753 self._setup_callbacks()
755 # create loss based on user input
756 self.loss = self._get_loss()
758 # choose optimizer
759 self.optimizer = tf.keras.optimizers.Adam(
760 learning_rate=self.p.learning_rate, clipvalue=1.0
761 )
763 # compile model
764 self.model.compile(
765 optimizer=self.optimizer,
766 loss=self.loss,
767 metrics=self.metrics,
768 unpacked_loss_fns={fn.__name__: fn for fn in self.loss},
769 )
771 # do this if tensorboard is true.
772 if self.p.tensorboard:
773 self._log_images = False
774 # get the output from model summary.
775 with Capturing() as output:
776 self.model.summary()
777 with open(self.p.main_path + "/model_summary.txt", "w") as f:
778 f.write("\n".join(output))
779 self.plot_network()
780 print(
781 f"Saved a text-summary of the model and an image in {self.p.main_path},",
782 "as specified in 'main_path' in the parameters.",
783 )
785 # sets up the tb callback to plot the model
786 self.tb_callback = tf.keras.callbacks.TensorBoard(
787 self.p.main_path, write_graph=True
788 )
789 self.tb_callback.set_model(self.model)
791 def _print_save_message(self) -> None:
792 """Prints a save message to inform the user where the model is saved."""
793 _print_save_message(self)
795 def _print_missing_data(self) -> None:
796 print(
797 f"This model was reloaded from disk, but not yet provided with train "
798 f"data. Use the `set_train_data()` method to provide the train "
799 f"data and call `train()` again to train the model. Alternatively, "
800 f"you could directly provide the train data, when reloading by "
801 f"calling the `{self.__class__.__name__}.from_checkpoint()` "
802 f"constructor with the `train_data` argument. Expected shape = "
803 f"{self.model.encoder_model.input_shape[1]}, received shape = "
804 f"{self._tensorboard_data_req_shape} {self._using_hypercube=} "
805 f"{self.p.using_hypercube=} {self.dataset.element_spec=}"
806 )
807 return
809 def set_train_data(self, data: Union[np.ndarray, tf.data.Dataset]) -> None:
810 """Resets the train data for reloaded models."""
811 self._using_hypercube = False
812 if data is None:
813 self._using_hypercube = True
814 self.p.using_hypercube = True
815 self.train_data = create_n_cube(4, seed=self.p.seed)[0].astype("float32")
816 self.p.periodicity = float("inf")
817 elif isinstance(data, np.ndarray):
818 if np.any(np.isnan(data)):
819 self.sparse = True
820 print("Input contains nans. Using sparse network.")
821 self.train_data = np_to_sparse_tensor(data)
822 else:
823 self.train_data = data.astype("float32")
824 elif isinstance(data, tf.data.Dataset):
825 self.dataset = data
826 try:
827 _, __ = self.dataset.take(1)
828 except ValueError:
829 if self.p.training == "auto":
830 if self.p.model_api == "custom":
831 print(
832 f"It seems like your dataset only yields tensors and not "
833 f"tuples of tensors. TensorFlow is optimized for classification "
834 f"tasks, where datasets yield tuples of (data, classes). EncoderMap,"
835 f"however is a regression task, but uses the same code as the "
836 f"classification tasks. This requires the dataset provided "
837 f"for a tensorflow model.fit() method to return tuples "
838 f"of (data, classes). Your dataset does not do this. "
839 f"I will transform your dataset using "
840 f"the `tf.data.Dataset.zip()` function of `tf.data`. "
841 f"This might break your custom model. You can "
842 f"set the `training` parameter in the parameter class to "
843 f"'custom' to not alter your dataset."
844 )
845 self.dataset = tf.data.Dataset.zip((self.dataset, self.dataset))
846 _ = self.dataset.take(1)
847 else:
848 _ = self.dataset.take(1)
849 self.train_data = _
850 else:
851 raise TypeError(
852 f"train_data must be `None`, `np.ndarray` or `tf.data.Dataset`. You supplied {type(data)}."
853 )
855 # check data and periodicity
856 if not self.sparse and data is not None:
857 if isinstance(data, np.ndarray):
858 if np.any(data > self.p.periodicity):
859 raise Exception(
860 "There seems to be an error regarding the periodicity "
861 f"of your data. The chosen periodicity is {self.p.periodicity}, "
862 f"but there are datapoints outwards of this range: {data.max()}"
863 )
865 # prepare the data
866 if isinstance(self.train_data, (np.ndarray, tf.sparse.SparseTensor)):
867 if self.p.training == "auto":
868 dataset = tf.data.Dataset.from_tensor_slices(
869 (self.train_data, self.train_data)
870 )
871 else:
872 dataset = tf.data.Dataset.from_tensor_slices(self.train_data)
873 dataset = dataset.shuffle(
874 buffer_size=self.train_data.shape[0], reshuffle_each_iteration=True
875 )
876 dataset = dataset.repeat()
877 self.dataset = dataset.batch(self.p.batch_size)
878 self._tensorboard_data_req_shape = tf.shape(self.train_data).numpy()
879 else:
880 d = self.train_data.element_spec[0]
881 if isinstance(d, tuple):
882 self._tensorboard_data_req_shape = d[0].shape.as_list()
883 else:
884 self._tensorboard_data_req_shape = d.shape.as_list()
886 if isinstance(data, np.ndarray):
887 assert data.shape[1] == self._tensorboard_data_req_shape[1]
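# A hedged sketch of re-setting train data on a reloaded model (the checkpoint
# path and the random array are hypothetical):
#
#     >>> e_map = em.EncoderMap.from_checkpoint("runs/run0")     # doctest: +SKIP
#     >>> e_map.set_train_data(np.random.random((1000, 4)))      # doctest: +SKIP
#     >>> history = e_map.train()                                # doctest: +SKIP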
889 def _setup_callbacks(self) -> None:
890 """Sets up a list with callbacks to be passed to self.model.fit()"""
891 self.metrics = []
892 self.callbacks = []
893 self.callbacks.append(ProgressBar(parameters=self.p))
894 if not self.read_only:
895 self.callbacks.append(CheckpointSaver(self.p))
896 if self.p.tensorboard:
897 self.tensorboard_write_bool = TensorboardWriteBool(self.p)
898 self.callbacks.append(self.tensorboard_write_bool)
899 file_writer = tf.summary.create_file_writer(self.p.main_path + "/train")
900 file_writer.set_as_default()
901 tf.summary.text(
902 name=f"Parameters Summary for {self.p.main_path}",
903 data=self.p.parameters,
904 step=0,
905 )
906 else:
907 self.tensorboard_write_bool = None
909 @classmethod
910 def from_checkpoint(
911 cls: Type[AutoencoderType],
912 checkpoint_path: Union[str, Path],
913 train_data: Optional[np.ndarray] = None,
914 sparse: bool = False,
915 use_previous_model: bool = False,
916 compat: bool = False,
917 ) -> AutoencoderType:
918 """Reconstructs the class from a checkpoint.
920 Args:
921 checkpoint_path (Union[str, Path]): The path to the checkpoint. Can
922 be either a directory, in which case the most recently saved
923 model will be loaded. Or a direct .keras file, in which case, this
924 specific model will be loaded.
925 train_data (Optional[np.ndarray]): When you want to retrain this model, you
926 can provide the train data here.
927 sparse (bool): Whether the reloaded model should be sparse.
928 use_previous_model (bool): Set this flag to True, if you load a model
929 from an in-between checkpoint step (e.g., to continue training with
930 different parameters). If you have the files saved_model_0.keras,
931 saved_model_500.keras and saved_model_1000.keras, setting this to
932 True and loading the saved_model_500.keras will back up the
933 saved_model_1000.keras.
934 compat (bool): Whether to use compatibility mode when missing or wrong
935 parameter files are present. In this special case, some assumptions
936 about the network architecture are made from the model and the
937 parameters in parameters.json overwritten accordingly (a backup
938 will also be made).
940 Returns:
941 Autoencoder: An instance of EncoderMap's `Autoencoder` class.
943 """
944 return load_model(
945 cls,
946 checkpoint_path,
947 sparse=sparse,
948 dataset=train_data,
949 use_previous_model=use_previous_model,
950 compat=compat,
951 )
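# Loading sketch (the checkpoint path is hypothetical): either a run directory
# or a specific .keras file can be passed. With `use_previous_model=True`,
# later checkpoints in the same directory are backed up:
#
#     >>> e_map = em.EncoderMap.from_checkpoint(     # doctest: +SKIP
#     ...     "runs/my_run/saved_model_500.keras",
#     ...     use_previous_model=True,
#     ... )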
953 def _get_model(self) -> tf.keras.Model:
954 return _get_model(self)
956 @property
957 def encoder(self) -> tf.keras.Model:
958 """tf.keras.Model: Encoder part of the model."""
959 return self.model.encoder
961 @property
962 def decoder(self) -> tf.keras.Model:
963 """tf.keras.Model: Decoder part of the model."""
964 return self.model.decoder
966 def _get_loss(self):
967 """sets self.loss according to `loss` in self.parameters."""
968 loss = self.p.loss
969 if loss == "reconstruction_loss":
970 loss = reconstruction_loss(self.model)
971 elif loss == "emap_cost":
972 self.auto_loss = auto_loss(self.model, self.p, self.tensorboard_write_bool)
973 self.regularization_loss = regularization_loss(
974 self.model, self.p, self.tensorboard_write_bool
975 )
976 self.center_loss = center_loss(
977 self.model, self.p, self.tensorboard_write_bool
978 )
979 loss = [self.auto_loss, self.regularization_loss, self.center_loss]
980 elif loss == "mse":
981 loss = "mse"
982 else:
983 raise ValueError(
984 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'."
985 )
986 return loss
988 def train(self) -> Optional[dict[str, Any]]:
989 """Starts the training of the model.
991 Returns:
992 Optional[dict[str, Any]]: If training succeeds, the history
993 dict of the `tf.keras.callbacks.History` instance is returned.
994 If not, None is returned.
996 """
997 if self.p.current_training_step >= self.p.n_steps:
998 print(
999 f"This {self.__class__.__name__} instance has already been trained "
1000 f"for {self.p.current_training_step} steps. Increase the training "
1001 f"steps by calling `{self.__class__.__name__}.p.n_steps += new_steps` "
1002 f"and then call `{self.__class__.__name__}.train()` again."
1003 )
1004 return
1006 if self._using_hypercube == self.p.using_hypercube:
1007 if self.p.periodicity == float("inf"):
1008 if (
1009 self._tensorboard_data_req_shape[1]
1010 != self.model.encoder_model.input_shape[1]
1011 ):
1012 return self._print_missing_data()
1013 else:
1014 if (
1015 self._tensorboard_data_req_shape[1]
1016 != self.model.encoder_model.input_shape[1] // 2
1017 ):
1018 return self._print_missing_data()
1019 else:
1020 return self._print_missing_data()
1022 if self.p.training == "custom" and self.p.batched:
1023 raise NotImplementedError()
1024 elif self.p.training == "custom" and not self.p.batched:
1025 raise NotImplementedError()
1026 elif self.p.training == "auto":
1027 epochs = self.p.n_steps - self.p.current_training_step
1028 try:
1029 history = self.model.fit(
1030 self.dataset,
1031 batch_size=self.p.batch_size,
1032 epochs=epochs,
1033 steps_per_epoch=1,
1034 verbose=0,
1035 callbacks=self.callbacks,
1036 ).history
1037 except ValueError:
1038 raise Exception(
1039 f"{self.model.encoder_model.input_shape=} {self._tensorboard_data_req_shape=} "
1040 f"{self.train_data.shape=} {self.dataset.element_spec=} {self.p.using_hypercube=} {self._using_hypercube=}"
1041 )
1042 else:
1043 raise ValueError(
1044 f"training argument needs to be `auto` or `custom`. You provided '{self.training}'."
1045 )
1046 self.p.current_training_step += self.p.n_steps - self.p.current_training_step
1047 self.p.save()
1048 self.save()
1050 history = history_to_json_serializable(history)
1051 return history
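# Continuing training (sketch): `train()` refuses to run once
# `current_training_step` has reached `n_steps`, so increase `n_steps` first:
#
#     >>> e_map.p.n_steps += 1000     # doctest: +SKIP
#     >>> history = e_map.train()     # doctest: +SKIP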
1053 def add_loss(self, loss):
1054 """Adds a new loss to the existing losses."""
1055 try:
1056 self.loss.append(loss(self.model, self.p))
1057 except TypeError:
1058 self.loss.append(loss(self.model))
1060 def add_callback(self, callback):
1061 """Adds a new callback to the existing callbacks."""
1062 try:
1063 self.callbacks.append(callback(self.p))
1064 except TypeError:
1065 self.callbacks.append(callback)
1067 def add_metric(self, metric):
1068 """Adds a new metric to the existing metrics."""
1069 self.metrics.append(metric(self.p))
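# Sketch of extending the training setup before calling `train()`: an already
# instantiated callback is not callable with the parameters, so the
# `except TypeError` branch of `add_callback` appends it as-is:
#
#     >>> e_map.add_callback(tf.keras.callbacks.TerminateOnNaN())  # doctest: +SKIP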
1071 def add_images_to_tensorboard(self, *args: Any, **kwargs: Any) -> None:
1072 """Adds images of the latent space to tensorboard.
1074 Args:
1075 data (Optional[Union[np.ndarray, Sequence[np.ndarray]]]): The input data will
1076 be passed through the encoder part of the autoencoder. If None
1077 is provided, a set of 10_000 points from `self.train_data` will
1078 be taken. A list[np.ndarray] is needed for the functional API of the
1079 `AngleDihedralCartesianEncoderMap`, that takes a list of
1080 [angles, dihedrals, side_dihedrals]. Defaults to None.
1081 image_step (Optional[int]): The interval in which to plot
1082 images to tensorboard. If None is provided, the `image_step`
1083 will be the same as `Parameters.summary_step`. Defaults to None.
1084 max_size (int): The maximum size of the high-dimensional data, that is
1085 projected. Prevents excessively large datasets from being projected
1086 at every `image_step`. Defaults to 10_000.
1087 scatter_kws (Optional[dict[str, Any]]): A dict with items that
1088 `plotly.express.scatter()` will accept. If None is provided,
1089 a default dict will be passed to
1090 `px.scatter(**{'size_max': 10, 'opacity': 0.2})`,
1091 which sets an appropriate size of scatter points for the size of
1092 datasets encodermap is usually used for.
1093 hist_kws (Optional[dict[str, Any]]): A dict with items that
1094 `encodermap.plot.plotting._plot_free_energy()` will accept. If None is provided a
1095 dict with bins 50 will be passed to
1096 `encodermap.plot.plotting._plot_free_energy(**{'bins': 50})`.
1097 You can choose a colormap here by providing `{'bins': 50, 'cmap':
1098 'plasma'}` for this argument.
1099 additional_fns (Optional[Sequence[Callable]]): A list of functions
1100 that will accept the low-dimensional output of the `Autoencoder`
1101 latent/bottleneck layer and return a tf.Tensor that can be logged
1102 by `tf.summary.image()`. See the notebook
1103 'writing_custom_images_to_tensorboard.ipynb' in
1104 tutorials/notebooks_customization for more info. If None is
1105 provided, no additional functions will be used to plot to
1106 tensorboard. Defaults to None.
1107 when (Literal["epoch", "batch"]): When to log the images can be
1108 either 'batch', then the images will be logged after every step
1109 during training, or 'epoch', then only after every image_step
1110 epoch the images will be written. Defaults to 'epoch'.
1111 save_to_disk (bool): Whether to also write the images to disk.
1113 """
1114 _add_images_to_tensorboard(self, *args, **kwargs)
1116 def plot_network(self) -> None:
1117 """Tries to plot the network using pydot, pydotplus and graphviz.
1118 Doesn't raise an exception if plotting is not possible.
1120 Note:
1121 Refer to this guide to install these programs:
1122 https://stackoverflow.com/questions/47605558/importerror-failed-to-import-pydot-you-must-install-pydot-and-graphviz-for-py
1124 """
1125 try:
1126 out = plot_model(self.model, self.train_data.shape[1])
1127 if out is not None:
1128 out.save(Path(self.p.main_path) / "network.png")
1129 except Exception:
1130 pass
1132 def encode(self, data: Optional[Sequence[np.ndarray]] = None) -> np.ndarray:
1133 """Calls encoder part of `self.model`.
1135 Args:
1136 data (Optional[np.ndarray]): The data to be passed to the encoder part.
1137 It can be either a numpy ndarray or None. If None is provided,
1138 the train data provided at instantiation will be used.
1139 Defaults to None.
1141 Returns:
1142 np.ndarray: The output from the bottleneck/latent layer.
1144 """
1145 if data is None:
1146 data = self.train_data
1148 # check the shapes:
1149 if not isinstance(data, (list, tuple)):
1150 if data.shape[1] * 2 == self.model.encoder_model.input_shape[1]:
1151 out = self.model.encoder(data)
1152 elif data.shape[1] == self.model.encoder_model.input_shape[1]:
1153 out = self.model.encoder_model(data)
1154 else:
1155 msg = (
1156 f"The shape of the provided data {data.shape=} does not "
1157 f"match the expected shape {self.model.encoder_model.input_shape=}."
1158 )
1159 if self.p.periodicity < float("inf"):
1160 msg += f" Not even when considering the periodicity of {self.p.periodicity} (which doubles the input dimension)."
1161 raise Exception(msg)
1162 else:
1163 for d, in_shape in zip(data, self.model.encoder_model.input_shape):
1164 assert d.shape[1] == in_shape[1], (
1165 f"The shape of the provided data ({d.shape}) does not match "
1166 f"the expected shape {in_shape}."
1167 )
1168 out = self.model.encoder_model(data)
1170 if isinstance(out, (list, tuple)):
1171 out = [o.numpy() for o in out]
1172 else:
1173 out = out.numpy()
1174 return out
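# Shape-handling sketch for `encode`: with a finite periodicity, the encoder's
# input layer is twice as wide as the raw data (the periodic inputs are
# presumably mapped to pairs on the unit circle), which is why both of these
# hypothetical calls succeed on an `EncoderMap` trained on periodic
# (n, 4)-shaped data:
#
#     >>> lowd = e_map.encode()                   # doctest: +SKIP
#     >>> lowd = e_map.encode(e_map.train_data)   # doctest: +SKIP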
1176 def generate(self, data: np.ndarray) -> np.ndarray:
1177 """Duplication of `self.decode`.
1179 In `Autoencoder` and `EncoderMap` this method is equivalent to `decode()`.
1180 In `AngleDihedralCartesianEncoderMap` this method will be overwritten
1181 to produce output molecular conformations.
1183 Args:
1184 data (np.ndarray): The data to be passed to the decoder part of the
1185 model. Make sure that the shape of the data matches the number
1186 of neurons in the latent space.
1188 Returns:
1189 np.ndarray: Outputs from the decoder part. For
1190 `AngleDihedralCartesianEncoderMap`, this will either be a
1191 `mdtraj.Trajectory` or `MDAnalysis.Universe`.
1192 """
1193 return self.model.decoder(data)
1195 def decode(self, data: np.ndarray) -> Sequence[np.ndarray]:
1196 """Calls the decoder part of the model.
1198 `AngleDihedralCartesianEncoderMap` will, unlike the other two classes,
1199 output a list of np.ndarray.
1201 Args:
1202 data (np.ndarray): The data to be passed to the decoder part of
1203 the model. Make sure that the shape of the data matches the
1204 number of neurons in the latent space.
1206 Returns:
1207 Union[list[np.ndarray], np.ndarray]: Outputs from the decoder part.
1208 For `AngleDihedralCartesianEncoderMap`, this will be a list of
1209 np.ndarray.
1211 """
1212 out = self.model.decoder(data)
1213 if isinstance(out, (list, tuple)):
1214 out = [o.numpy() for o in out]
1215 else:
1216 out = out.numpy()
1217 return out
1219 def save(self, step: Optional[int] = None) -> None | Path:
1220 """Saves the model to the current path defined in `parameters.main_path`.
1222 Args:
1223 step (Optional[int]): Does not save the model at the given
1224 training step, but rather replaces the datetime string in the
1225 saved model's filename with the given step number.
1227 Returns:
1228 Union[None, Path]: When the model has been saved, the Path will
1229 be returned. If the model could not be saved, None will be
1230 returned.
1232 """
1233 if not self.read_only:
1234 return save_model(
1235 self.model,
1236 self.p.main_path,
1237 inp_class_name=self.__class__.__name__,
1238 step=step,
1239 print_message=True,
1240 ).parent
1241 else:
1242 print(
1243 f"This {self.__class__.__name__} is set to read_only. Set "
1244 f"`{self.__class__.__name__}.read_only=False` to save the "
1245 f"current state of the model."
1246 )
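# Saving sketch: `save()` returns the directory the model was saved to (the
# parent of the saved .keras file), or None for read-only instances:
#
#     >>> path = e_map.save()     # doctest: +SKIP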
1248 def close(self) -> None:
1249 """Clears the current keras backend and frees up resources."""
1250 # clear old sessions
1251 tf.keras.backend.clear_session()
1254class EncoderMap(Autoencoder):
1255 """Complete copy of Autoencoder class but uses additional distance cost
1256 scaled by the SketchMap sigmoid params"""
1258 @classmethod
1259 def from_checkpoint(
1260 cls: Type[EncoderMapType],
1261 checkpoint_path: Union[str, Path],
1262 train_data: Optional[np.ndarray] = None,
1263 sparse: bool = False,
1264 use_previous_model: bool = False,
1265 compat: bool = False,
1266 ) -> EncoderMapType:
1267 """Reconstructs the class from a checkpoint.
1269 Args:
1270 checkpoint_path (Union[str, Path]): The path to the checkpoint. Can
1271 be either a directory, in which case the most recently saved
1272 model will be loaded. Or a direct .keras file, in which case, this
1273 specific model will be loaded.
1274 train_data (Optional[np.ndarray]). When you want to retrain this model, you
1275 can provide the train data here.
1276 sparse (bool): Whether the reloaded model should be sparse.
1277 use_previous_model (bool): Set this flag to True, if you load a model
1278 from an in-between checkpoint step (e.g., to continue training with
1279 different parameters). If you have the files saved_model_0.keras,
1280 saved_model_500.keras and saved_model_1000.keras, setting this to
1281 True and loading the saved_model_500.keras will back up the
1282 saved_model_1000.keras.
1283 compat (bool): Whether to use compatibility mode when missing or wrong
1284 parameter files are present. In this special case, some assumptions
1285 about the network architecture are made from the model and the
1286 parameters in parameters.json overwritten accordingly (a backup
1287 will also be made).
1289 Returns:
1290 EncoderMap: An instance of the `EncoderMap` class.
1292 """
1293 return load_model(
1294 cls,
1295 checkpoint_path,
1296 sparse=sparse,
1297 dataset=train_data,
1298 use_previous_model=use_previous_model,
    compat=compat,
1299 )
1301 def _get_loss(self):
1302 loss = self.p.loss
1303 if loss == "reconstruction_loss":
1304 loss = reconstruction_loss(self.model)
1305 elif loss == "emap_cost":
1306 self.auto_loss = auto_loss(self.model, self.p, self.tensorboard_write_bool)
1307 self.regularization_loss = regularization_loss(
1308 self.model, self.p, self.tensorboard_write_bool
1309 )
1310 self.center_loss = center_loss(
1311 self.model, self.p, self.tensorboard_write_bool
1312 )
1313 # this cost is new
1314 self.distance_loss = distance_loss(
1315 self.model, self.p, self.tensorboard_write_bool
1316 )
1317 loss = [
1318 self.auto_loss,
1319 self.regularization_loss,
1320 self.center_loss,
1321 self.distance_loss,
1322 ]
1323 elif loss == "mse":
1324 loss = "mse"
1325 else:
1326 raise ValueError(
1327 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'."
1328 )
1329 return loss
1332class DihedralEncoderMap(EncoderMap):
1333 """Similar to the `EncoderMap` class, but overwrites the `generate` method.
1335 When using this class instead of the `EncoderMap` class, the `generate` method
1336 needs an additional argument: `top`, which should be a topology file. This
1337 topology will be used as a base on which the dihedrals of the `decode`
1338 method are applied.
1340 """
1342 def generate(
1343 self,
1344 data: np.ndarray,
1345 top: Union[Path, str],
1346 ) -> Universe:
1347 """Overwrites `EncoderMap`'s generate method and actually does
1348 backmapping if a list of dihedrals is provided.
1350 Args:
1351 data (np.ndarray): The low-dimensional/latent/bottleneck data.
1352 A ndim==2 numpy array with xy coordinates of points in latent space.
1353 top (str): Topology file for this run of EncoderMap (can be .pdb, .gro, .. etc.).
1355 Returns:
1356 MDAnalysis.Universe: The topology with the provided backbone torsions.
1358 Examples:
1359 >>> # get some time-resolved pdb files
1360 >>> import requests
1361 >>> import numpy as np
1362 >>> pdb_link = 'https://files.rcsb.org/view/1YUF.pdb'
1363 >>> contents = requests.get(pdb_link).text
1364 >>> print(contents.splitlines()[0]) # doctest: +SKIP
1365 HEADER GROWTH FACTOR 01-APR-96 1YUF
1366 >>> # fake a file with stringio
1367 >>> from io import StringIO
1368 >>> import MDAnalysis as mda
1369 >>> import numpy as np
1370 >>> file = StringIO(contents)
1371 >>> # pass it to MDAnalysis
1372 >>> u = mda.Universe(file, format='PDB')
1373 >>> print(u)
1374 <Universe with 720 atoms>
1375 >>> # select the atomgroups
1376 >>> ags = [*[res.psi_selection() for res in u.residues],
1377 ... *[res.omega_selection() for res in u.residues],
1378 ... *[res.phi_selection() for res in u.residues]
1379 ... ]
1380 >>> # filter Nones
1381 >>> ags = list(filter(lambda x: False if x is None else True, ags))
1382 >>> print(ags[0][0]) # doctest: +SKIP
1383 <Atom 3: C of type C of resname VAL, resid 1 and segid A and altLoc >
1384 >>> # Run dihedral Angles
1385 >>> from MDAnalysis.analysis.dihedrals import Dihedral
1386 >>> R = np.deg2rad(Dihedral(ags).run().results.angles)
1387 >>> print(R.shape)
1388 (16, 147)
1389 >>> # import EncoderMap and define parameters
1390 >>> from encodermap.autoencoder import DihedralEncoderMap
1391 >>> import encodermap as em
1392 >>> parameters = em.Parameters(
1393 ... dist_sig_parameters = (4.5, 12, 6, 1, 2, 6),
1394 ... periodicity = 2*np.pi,
1395 ... l2_reg_constant = 10.0,
1396 ... summary_step = 5,
1397 ... tensorboard = False,
1398 ... )
1399 >>> e_map = DihedralEncoderMap(parameters, R, read_only=True)
1400 >>> print(e_map.__class__.__name__)
1401 DihedralEncoderMap
1402 >>> # get some low-dimensional data
1403 >>> lowd = np.random.random((100, 2))
1404 >>> # use the generate method to get a new MDAnalysis universe
1405 >>> # but first remove the time resolution
1406 >>> file = StringIO(contents.split('MODEL 2')[0])
1407 >>> new = e_map.generate(lowd, file)
1408 >>> print(new.trajectory.coordinate_array.shape)
1409 (100, 720, 3)
1410 >>> # check whether frame 0 of u and new_u are different
1411 >>> for ts in u.trajectory:
1412 ... a1 = ts.positions
1413 ... break
1414 >>> print(np.array_equal(a1, new.trajectory.coordinate_array[0]))
1415 False
1417 """
1418 assert np.any(data)
1419 dihedrals = self.decode(data)
1420 assert np.any(dihedrals)
1421 uni = dihedral_backmapping(top, dihedrals)
1422 return uni
1425class AngleDihedralCartesianEncoderMap:
1426 """Different `__init__` method, than Autoencoder Class. Uses callbacks to tune-in cartesian cost.
1428 Overwritten methods: `_set_up_callbacks` and `generate`.
1430 Examples:
1431 >>> import encodermap as em
1432 >>> from pathlib import Path
1433 >>> # Load two trajectories
1434 >>> test_data = Path(em.__file__).parent.parent / "tests/data"
1435 >>> test_data.is_dir()
1436 True
1437 >>> xtcs = [test_data / "1am7_corrected_part1.xtc", test_data / "1am7_corrected_part2.xtc"]
1438 >>> tops = [test_data / "1am7_protein.pdb", test_data /"1am7_protein.pdb"]
1439 >>> trajs = em.load(xtcs, tops)
1440 >>> print(trajs)
1441 encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajectories. Not containing any CVs.
1442 >>> # load CVs
1443 >>> # This step can be omitted. The AngleDihedralCartesianEncoderMap class automatically loads CVs
1444 >>> trajs.load_CVs('all')
1445 >>> print(trajs.CVs['central_cartesians'].shape)
1446 (51, 474, 3)
1447 >>> print(trajs.CVs['central_dihedrals'].shape)
1448 (51, 471)
1449 >>> # create some parameters
1450 >>> p = em.ADCParameters(periodicity=360, use_backbone_angles=True, use_sidechains=True,
1451 ... cartesian_cost_scale_soft_start=(6, 12))
1452 >>> # Standard is functional model, as it offers more flexibility
1453 >>> print(p.model_api)
1454 functional
1455 >>> print(p.distance_cost_scale)
1456 None
1457 >>> # Instantiate the class
1458 >>> e_map = em.AngleDihedralCartesianEncoderMap(trajs, p, read_only=True) # doctest: +ELLIPSIS
1459 Model...
1460 >>> # dataset contains these inputs:
1461 >>> # central_angles, central_dihedrals, central_cartesians, central_distances, sidechain_dihedrals
1462 >>> print(e_map.dataset) # doctest: +SKIP
1463 <BatchDataset element_spec=(TensorSpec(shape=(None, 472), dtype=tf.float32, name=None), TensorSpec(shape=(None, 471), dtype=tf.float32, name=None), TensorSpec(shape=(None, 474, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 473), dtype=tf.float32, name=None), TensorSpec(shape=(None, 316), dtype=tf.float32, name=None))>
1464 >>> # output from the model contains the following data:
1465 >>> # out_angles, out_dihedrals, back_cartesians, pairwise_distances of inp cartesians, pairwise of back-mapped cartesians, out_side_dihedrals
1466 >>> for data in e_map.dataset.take(1):
1467 ... pass
1468 >>> out = e_map.model(data)
1469 >>> print([i.shape for i in out]) # doctest: +SKIP
1470 [TensorShape([256, 472]), TensorShape([256, 471]), TensorShape([256, 474, 3]), TensorShape([256, 112101]), TensorShape([256, 112101]), TensorShape([256, 316])]
1471 >>> # get output of latent space by providing central_angles, central_dihedrals, sidechain_dihedrals
1472 >>> latent = e_map.encoder([data[0], data[1], data[-1]])
1473 >>> print(latent.shape)
1474 (256, 2)
1475 >>> # Rebuild central_angles, central_dihedrals and side_dihedrals from latent
1476 >>> ang, dih, side_dih = e_map.decode(latent)
1477 >>> print(ang.shape, dih.shape, side_dih.shape)
1478 (256, 472) (256, 471) (256, 316)
1480 """
1482 def __init__(
1483 self,
1484 trajs: Optional[TrajEnsemble] = None,
1485 parameters: Optional[ADCParameters] = None,
1486 model: Optional[tf.keras.Model] = None,
1487 read_only: bool = False,
1488 dataset: Optional[tf.data.Dataset] = None,
1489 ensemble: bool = False,
1490 use_dataset_when_possible: bool = True,
1491 deterministic: bool = False,
1492 ) -> None:
1493 """Instantiate the `AngleDihedralCartesianEncoderMap` class.
1495 Args:
1496 trajs (Optional[TrajEnsemble]): The trajectories to be used as input.
1497 If trajs contain no CVs, correct CVs will be loaded. can be None,
1498 in which case the argument `dataset` should be provided.
1499 Defaults to None.
1500 parameters (Optional[em.ADCParameters]): The parameters for the
1501 current run. Can be set to None and the default parameters will
1502 be used. Defaults to None.
1503 model (Optional[tf.keras.models.Model]): The keras model to use. You
1504 can provide your own model with this argument. If set to None,
1505 the model will be built to the specifications of parameters using
1506 the functional API. Defaults to None.
1507 read_only (bool): Whether to write anything to disk
1508 (False) or not (True). Defaults to False.
1509 dataset (Optional[tf.data.Dataset]): The `dataset` argument takes
1510 precedent over the `trajs` argument. If None, the dataset will
1511 be constructed from the `trajs` argument (see
1512 `em.trajinfo.TrajEnsemble.tf_dataset` for more info). Defaults
1513 to None.
1514 ensemble (bool): Whether to allow non-defined features when
1515 featurizing the provided `trajs`. Only takes effect, when
1516 the `trajs` don't already have the features (central_cartesians,
1517 central_distances, central_angles, central_dihedrals, side_dihedrals)
1518 loaded. Defaults to False.
1519 use_dataset_when_possible (bool): Whether to use the `trajs` method
1520 `tf_dataset()` to get a dataset for training or construct a
1521 dataset from the `trajs` CVs numpy arrays. For large datasets the
1522 first method can be advantageous as not all data will end up in
1523 memory and the dataset can be larger than the memory allows. For
1524 small datasets the second method is faster, as all data is in
1525 memory. Defaults to True.
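Example:
    A minimal sketch (with hypothetical, randomly filled feature shapes) of
    instantiating this class from a pre-built `tf.data.Dataset` instead of
    a `TrajEnsemble`. The element order follows the dataset layout used by
    this class (angles, dihedrals, cartesians, distances, side dihedrals):

    >>> import numpy as np  # doctest: +SKIP
    >>> import tensorflow as tf  # doctest: +SKIP
    >>> import encodermap as em  # doctest: +SKIP
    >>> n_frames, n_atoms = 100, 50  # hypothetical sizes
    >>> data = (
    ...     np.random.random((n_frames, n_atoms - 2)).astype("float32"),  # central_angles
    ...     np.random.random((n_frames, n_atoms - 3)).astype("float32"),  # central_dihedrals
    ...     np.random.random((n_frames, n_atoms, 3)).astype("float32"),   # central_cartesians
    ...     np.random.random((n_frames, n_atoms - 1)).astype("float32"),  # central_distances
    ...     np.random.random((n_frames, 20)).astype("float32"),           # side_dihedrals
    ... )
    >>> p = em.ADCParameters(use_backbone_angles=True, use_sidechains=True)
    >>> dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(n_frames).repeat().batch(p.batch_size)
    >>> e_map = em.AngleDihedralCartesianEncoderMap(
    ...     parameters=p, dataset=dataset, read_only=True,
    ... )  # doctest: +SKIP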
1527 """
1528 # parameters
1529 if parameters is None:
1530 self.p = ADCParameters()
1531 else:
1532 assert isinstance(parameters, ADCParameters), (
1533 f"Please provide an instance of `ADCParameters` for the argument "
1534 f"'parameters' and not {type(parameters)}."
1535 )
1536 self.p = parameters
1538 # seed
1539 if self.p.seed is not None:
1540 tf.random.set_seed(self.p.seed)
1542 # check sidechain reconstruction settings
1543 if self.p.reconstruct_sidechains:
1544 assert (
1545 self.p.use_sidechains
1546 ), "If you want to reconstruct sidechains, you should also set `use_sidechains` to True."
1548 # read_only
1549 self.read_only = read_only
1551 # save params and create dir
1552 if not self.read_only:
1553 self.p.write_summary = True
1554 self.p.save()
1555 self._print_save_message()
1557 # check whether Tensorboard and read_only make sense
1558 if self.read_only and self.p.tensorboard:
1559 raise Exception("Can't use tensorboard, when `read_only` is set to True.")
1561 # clear old sessions
1562 tf.keras.backend.clear_session()
1564 # get the CVs:
1565 if trajs is not None:
1566 if trajs.__class__.__name__ == "SingleTraj":
1567 trajs = trajs._gen_ensemble()
1568 self.trajs = trajs
1570 # add the sidechain_info if sidechains need to be reconstructed
1571 if self.p.reconstruct_sidechains:
1572 self.p.sidechain_info = self.trajs.sidechain_info()
1574 # decide on the dataset
1575 if (
1576 all([traj._traj_file.suffix in [".h5", ".nc"] for traj in trajs])
1577 and trajs.CVs_in_file
1578 and use_dataset_when_possible
1579 ):
1580 # if all CVs are in an h5 file, we can get batches from there
1581 dataset = trajs.tf_dataset(
1582 batch_size=self.p.batch_size,
1583 sidechains=self.p.use_sidechains,
1584 reconstruct_sidechains=self.p.reconstruct_sidechains,
1585 deterministic=deterministic,
1586 )
1587 self.inp_CV_data = trajs.CVs
1589 else:
1590 # if not, we need to load them
1591 if not self.p.reconstruct_sidechains:
1592 should_be = {
1593 "central_angles",
1594 "central_cartesians",
1595 "central_dihedrals",
1596 "central_distances",
1597 "side_dihedrals",
1598 }
1599 else:
1600 raise NotImplementedError(
1601 f"Loading CVs with reconstruct_sidechains is currently not implemented."
1602 )
1603 if dataset is None:
1604 if not self.trajs.CVs:
1605 missing = list(should_be - set(trajs.CVs.keys()))
1606 if missing != []:
1607 print("loading missing values: ", missing)
1608 self.trajs.load_CVs(missing, ensemble=ensemble)
1609 else:
1610 if not should_be.issubset(set(self.trajs.CVs.keys())):
1611 self.trajs.load_CVs(list(should_be), ensemble=ensemble)
1613 if not should_be.issubset(set(self.trajs.CVs.keys())):
1614 if not ensemble:
1615 msg = (
1616 f" You can try to set `ensemble=True` to load "
1617 f"these trajectories into an ensemble, which "
1618 f"allows features with different feature length."
1619 )
1620 else:
1621 msg = ""
1622 raise Exception(
1623 f"Could not load CVs. Should be {should_be}, but "
1624 f"currently only {set(trajs.CVs.keys())} are loaded.{msg}"
1625 )
1627 # create dataset
1628 if dataset is None:
1629 (
1630 self.sparse,
1631 self.train_data,
1632 self.inp_CV_data,
1633 ) = self.get_train_data_from_trajs(self.trajs, self.p)
1634 if not self.p.reconstruct_sidechains:
1635 dataset = tf.data.Dataset.from_tensor_slices(
1636 (
1637 self.inp_CV_data["central_angles"],
1638 self.inp_CV_data["central_dihedrals"],
1639 self.inp_CV_data["central_cartesians"],
1640 self.inp_CV_data["central_distances"],
1641 self.inp_CV_data["side_dihedrals"],
1642 )
1643 )
1644 dataset = dataset.shuffle(
1645 buffer_size=self.inp_CV_data["central_cartesians"].shape[0],
1646 reshuffle_each_iteration=True,
1647 )
1648 else:
1649 dataset = tf.data.Dataset.from_tensor_slices(
1650 (
1651 self.inp_CV_data["central_angles"],
1652 self.inp_CV_data["central_dihedrals"],
1653 self.inp_CV_data["all_cartesians"],
1654 self.inp_CV_data["central_distances"],
1655 self.inp_CV_data["side_dihedrals"],
1656 self.inp_CV_data["side_angles"],
1657 self.inp_CV_data["side_dihedrals"],
1658 self.inp_CV_data["side_distances_"],
1659 )
1660 )
1661 dataset = dataset.shuffle(
1662 buffer_size=self.inp_CV_data["all_cartesians"].shape[0],
1663 reshuffle_each_iteration=True,
1664 )
1665 dataset = dataset.repeat()
1666 self.dataset = dataset.batch(self.p.batch_size)
1667 else:
1668 self.dataset = dataset
1669 self.sparse = any(
1670 [isinstance(t, tf.SparseTensorSpec) for t in self.dataset.element_spec]
1671 )
1673 # create model based on user input
1674 if model is None:
1675 self.model = self._get_model()
1676 else:
1677 self.model = model
1679 # setup callbacks
1680 self._setup_callbacks()
1682 # create loss based on user input
1683 self.loss = self._get_loss()
1685 # choose optimizer
1686 self.optimizer = tf.keras.optimizers.Adam(
1687 learning_rate=self.p.learning_rate, clipvalue=1.0
1688 )
1690 # compile model
1691 self.model.compile(
1692 optimizer=self.optimizer,
1693 loss=self.loss,
1694 metrics=self.metrics,
1695 unpacked_loss_fns={fn.__name__: fn for fn in self.loss},
1696 )
1698 # do this if tensorboard is true.
1699 if self.p.tensorboard:
1700 # print shapes
1701 print("input shapes are:")
1702 if hasattr(self, "inp_CV_data"):
1703 print({k: v.shape for k, v in self.inp_CV_data.items()})
1704 else:
1705 for d in self.dataset:
1706 break
1707 print([v.shape for v in d])
1709 # set _log_images False to fix the backend after training
1710 self._log_images = False
1711 # get the output from model summary.
1712 with Capturing() as output:
1713 self.model.summary()
1714 with open(self.p.main_path + "/model_summary.txt", "w") as f:
1715 f.write("\n".join(output))
1716 self.plot_network()
1717 print(
1718 f"Saved a text-summary of the model and an image in {self.p.main_path},",
1719 "as specified in 'main_path' in the parameters.",
1720 )
1722 # sets up the tb callback to plot the model
1723 self.tb_callback = tf.keras.callbacks.TensorBoard(
1724 self.p.main_path, write_graph=True
1725 )
1726 self.tb_callback.set_model(self.model)
1728 def _print_save_message(self) -> None:
1729 """Prints a message, to inform user where model is saved."""
1730 _print_save_message(self)
1732 def _get_model(self) -> tf.keras.Model:
1733 """Constructs a model from an instance of this class."""
1734 return _get_model(self)
1736 def add_loss(self, loss):
1737 """Adds a new loss to the existing losses."""
1738 try:
1739 self.loss.append(loss(self.model, self.p))
1740 except TypeError:
1741 self.loss.append(loss(self.model))
1743 def add_callback(self, callback):
1744 """Adds a new callback to the existing callbacks."""
1745 try:
1746 self.callbacks.append(callback(self.p))
1747 except TypeError:
1748 self.callbacks.append(callback)
1750 def add_metric(self, metric):
1751 """Adds a new metric to the existing metrics."""
1752 self.metrics.append(metric(self.p))
1754 def add_images_to_tensorboard(self, *args: Any, **kwargs: Any) -> None:
1755 """Adds images of the latent space to tensorboard.
1757 Args:
1758 data (Optional[Union[np.ndarray, Sequence[np.ndarray]]]): The input data will
1759 be passed through the encoder part of the autoencoder. If None
1760 is provided, a set of 10_000 points from `self.train_data` will
1761 be taken. A list[np.ndarray] is needed for the functional API of the
1762 `AngleDihedralCartesianEncoderMap`, that takes a list of
1763 [angles, dihedrals, side_dihedrals]. Defaults to None.
1764 image_step (Optional[int]): The interval in which to plot
1765 images to tensorboard. If None is provided, the `image_step`
1766 will be the same as `Parameters.summary_step`. Defaults to None.
1767 max_size (int): The maximum size of the high-dimensional data, that is
1768 projected. Prevents excessively large datasets from being projected
1769 at every `image_step`. Defaults to 10_000.
1770 scatter_kws (Optional[dict[str, Any]]): A dict with items that
1771 `plotly.express.scatter()` will accept. If None is provided,
1772 the dict {'size_max': 10, 'opacity': 0.2} will be passed to
1773 `px.scatter()`,
1774 which sets an appropriate size of scatter points for the size of
1775 datasets encodermap is usually used for.
1776 hist_kws (Optional[dict[str, Any]]): A dict with items that
1777 `encodermap.plot.plotting._plot_free_energy()` will accept. If None is provided,
1778 the dict {'bins': 50} will be passed to
1779 `encodermap.plot.plotting._plot_free_energy()`.
1780 You can choose a colormap here by providing `{'bins': 50, 'cmap':
1781 'plasma'}` for this argument.
1782 additional_fns (Optional[Sequence[Callable]]): A list of functions
1783 that will accept the low-dimensional output of the `Autoencoder`
1784 latent/bottleneck layer and return a tf.Tensor that can be logged
1785 by `tf.summary.image()`. See the notebook
1786 'writing_custom_images_to_tensorboard.ipynb' in
1787 tutorials/notebooks_customization for more info. If None is
1788 provided, no additional functions will be used to plot to
1789 tensorboard. Defaults to None.
1790 when (Literal["epoch", "batch"]): When to log the images. Can be
1791 either 'batch' (the images will be logged after every training
1792 step) or 'epoch' (the images will only be written every
1793 `image_step` epochs). Defaults to 'epoch'.
1794 save_to_disk (bool): Whether to also write the images to disk.
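Example:
    A minimal usage sketch, assuming a trained instance `e_map` with
    tensorboard enabled in its parameters; all keyword values below are
    illustrative:

    >>> e_map.add_images_to_tensorboard(  # doctest: +SKIP
    ...     image_step=100,
    ...     max_size=10_000,
    ...     scatter_kws={"size_max": 10, "opacity": 0.2},
    ...     hist_kws={"bins": 50, "cmap": "plasma"},
    ...     when="epoch",
    ...     save_to_disk=True,
    ... )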
1796 """
1797 _add_images_to_tensorboard(self, *args, **kwargs)
1799 def train(self) -> Optional[dict[str, Any]]:
1800 """Overwrites the parent class' `train()` method to implement references."""
1801 if all([v == 1 for k, v in self.p.__dict__.items() if "reference" in k]):
1802 self.train_for_references()
1803 else:
1804 print("References are already provided. Skipping reference training.")
1805 if self.p.current_training_step >= self.p.n_steps:
1806 print(
1807 f"This {self.__class__.__name__} instance has already been trained "
1808 f"for {self.p.current_training_step} steps. Increase the training "
1809 f"steps by calling `{self.__class__.__name__}.p.n_steps += new_steps` "
1810 f"and then call `{self.__class__.__name__}.train()` again."
1811 )
1812 return
1814 if self.p.training == "custom" and self.p.batched:
1815 raise NotImplementedError()
1816 elif self.p.training == "custom" and not self.p.batched:
1817 raise NotImplementedError()
1818 elif self.p.training == "auto":
1819 epochs = self.p.n_steps - self.p.current_training_step
1820 history = self.model.fit(
1821 self.dataset,
1822 batch_size=self.p.batch_size,
1823 epochs=epochs,
1824 steps_per_epoch=1,
1825 verbose=0,
1826 callbacks=self.callbacks,
1827 ).history
1828 else:
1829 raise ValueError(
1830 f"training argument needs to be `auto` or `custom`. You provided '{self.training}'."
1831 )
1832 self.p.current_training_step += self.p.n_steps - self.p.current_training_step
1833 if not self.read_only:
1834 self.p.save()
1835 self.save()
1837 history_to_json_serializable(history)
1838 return history
1840 def train_for_references(self, subsample: int = 100, maxiter: int = 500) -> Optional[dict[str, list[float]]]:
1841 """Calculates references for the angle, dihedral, and cartesian costs,
1842 which can be used to bring these costs to a similar magnitude.
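Example:
    A minimal sketch, assuming an instance `e_map` with an attached
    `TrajEnsemble`. `subsample` thins sparse input data by using every
    `subsample`-th frame; `maxiter` caps the number of batches used:

    >>> cost_references = e_map.train_for_references(subsample=100, maxiter=500)  # doctest: +SKIP
    >>> # the means of these references are also written to the parameters
    >>> e_map.p.cartesian_cost_reference  # doctest: +SKIP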
1845 """
1846 p = ADCParameters(
1847 cartesian_cost_scale=1,
1848 angle_cost_scale=1,
1849 dihedral_cost_scale=1,
1850 )
1851 if hasattr(self, "trajs"):
1852 nsteps = min(maxiter, max(1, int(self.trajs.n_frames / self.p.batch_size)))
1853 else:
1854 return
1855 # fmt: off
1856 costs = {
1857 "dihedral_cost": ["central_dihedrals", 1, dihedral_loss(self.model, p)],
1858 "angle_cost": ["central_angles", 0, angle_loss(self.model, p)],
1859 "cartesian_cost": ["central_cartesians", 2, cartesian_loss(self.model, parameters=p)],
1860 }
1861 # fmt: on
1862 # Local Folder Imports
1863 from ..models.models import ADCSparseFunctionalModel
1865 if isinstance(self.model, ADCSparseFunctionalModel):
1866 to_dense_models = {
1867 "dihedral_cost": self.model.get_dense_model_central_dihedrals,
1868 "angle_cost": self.model.get_dense_model_central_angles,
1869 "cartesian_cost": self.model.get_dense_model_distances,
1870 }
1872 cost_references = {key: [] for key in costs.keys()}
1873 for key, val in costs.items():
1874 if key in ["dihedral_cost", "angle_cost"]:
1875 inp = self.trajs.CVs[val[0]]
1876 if np.any(np.isnan(inp)):
1877 inp = np_to_sparse_tensor(inp[::subsample])
1878 inp = to_dense_models[key](inp).numpy()
1879 means = np.repeat(
1880 np.expand_dims(
1881 np.mean(inp, 0),
1882 axis=0,
1883 ),
1884 repeats=self.p.batch_size,
1885 axis=0,
1886 )
1887 costs[key].append(means)
1888 else:
1889 inp = self.trajs.CVs["central_distances"]
1890 if np.any(np.isnan(inp)):
1891 inp = np_to_sparse_tensor(inp[::subsample])
1892 inp = to_dense_models[key](inp).numpy()
1893 mean_lengths = np.expand_dims(np.mean(inp, axis=0), axis=0)
1894 chain = chain_in_plane(mean_lengths, costs["angle_cost"][3])
1895 gen_cartesians = dihedrals_to_cartesian_tf(
1896 costs["dihedral_cost"][3] + np.pi, chain
1897 )
1898 pd = pairwise_dist(
1899 gen_cartesians[
1900 :,
1901 self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
1902 ],
1903 flat=True,
1904 )
1905 costs[key].append(pd)
1907 with tqdm(
1908 desc="Calculating references",
1909 total=nsteps,
1910 position=0,
1911 leave=True,
1912 ) as pbar:
1913 for i, data in zip(range(nsteps), self.dataset):
1914 for d in data:
1915 if not isinstance(d, tf.sparse.SparseTensor):
1916 if np.any(np.isnan(d)) and not self.sparse:
1917 raise Exception(
1918 f"Received data containing nans from `self.dataset` ({d=}),"
1919 f"while `self.sparse` is set to True ({self.sparse=}). "
1920 f"This training won't work as nans compromise the weights "
1921 f"of the whole model. Try to explicitly set `sparse=True`, "
1922 f"when instantiating the {self.__class__.__name__} class."
1923 )
1924 for key, val in costs.items():
1925 if key in ["dihedral_cost", "angle_cost"]:
1926 if isinstance(data[val[1]], tf.sparse.SparseTensor):
1927 d = to_dense_models[key](data[val[1]]).numpy()
1928 else:
1929 d = data[val[1]]
1930 cost_references[key].append(val[2](d, val[3]).numpy())
1931 if key == "cartesian_cost":
1932 if isinstance(data[val[1]], tf.sparse.SparseTensor):
1933 d = self.model.get_dense_model_cartesians(
1934 data[val[1]]
1935 ).numpy()
1936 # un-flatten the cartesian coordinates
1937 d = d.reshape(len(d), -1, 3)
1938 else:
1939 d = data[val[1]]
1940 pd = pairwise_dist(
1941 d[
1942 :,
1943 self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
1944 ],
1945 flat=True,
1946 )
1947 c = val[2](val[3], pd).numpy()
1948 cost_references["cartesian_cost"].append(c)
1949 pbar.update()
1950 s = {k: np.mean(v) for k, v in cost_references.items()}
1951 print(f"After {i} steps setting cost references: {s} to parameters.")
1952 self.p.angle_cost_reference = float(np.mean(cost_references["angle_cost"]))
1953 self.p.dihedral_cost_reference = float(
1954 np.mean(cost_references["dihedral_cost"])
1955 )
1956 self.p.cartesian_cost_reference = float(
1957 np.mean(cost_references["cartesian_cost"])
1958 )
1959 if not self.read_only:
1960 self.p.save()
1961 return cost_references
1963 def save(self, step: Optional[int] = None) -> None | Path:
1964 """Saves the model to the current path defined in `parameters.main_path`.
1966 Args:
1967 step (Optional[int]): Does not save the model at the given
1968 training step, but rather changes the string used for naming
1969 the saved model file from a datetime format to the given step.
1971 Returns:
1972 Union[None, Path]: When the model has been saved, the Path will
1973 be returned. If the model could not be saved, None will be
1974 returned.
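Example:
    A minimal sketch, assuming a writable (`read_only=False`) instance
    `e_map`:

    >>> out_dir = e_map.save(step=500)  # doctest: +SKIP
    >>> # out_dir is the directory containing the saved .keras file,
    >>> # or None if the instance is read_only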
1976 """
1977 if not self.read_only:
1978 return save_model(
1979 self.model,
1980 self.p.main_path,
1981 inp_class_name=self.__class__.__name__,
1982 step=step,
1983 print_message=True,
1984 ).parent
1985 else:
1986 print(
1987 f"This {self.__class__.__name__} is set to read_only. Set "
1988 f"`{self.__class__.__name__}.read_only=False` to save the "
1989 f"current state of the model."
1990 )
1992 def close(self) -> None:
1993 """Clears the current keras backend and frees up resources."""
1994 # clear old sessions
1995 tf.keras.backend.clear_session()
1997 def set_train_data(self, data: TrajEnsemble) -> None:
1998 """Resets the train data for reloaded models."""
1999 (
2000 sparse,
2001 self.train_data,
2002 self.inp_CV_data,
2003 ) = self.get_train_data_from_trajs(data, self.p)
2004 self._using_hypercube = False
2005 self.p.using_hypercube = False
2006 if not self.sparse and sparse:
2007 print(
2008 f"The provided data contains nan's, but the model was trained "
2009 f"on dense input data."
2010 )
2011 return
2012 if not self.p.reconstruct_sidechains:
2013 data = [
2014 self.inp_CV_data["central_angles"],
2015 self.inp_CV_data["central_dihedrals"],
2016 self.inp_CV_data["central_cartesians"],
2017 self.inp_CV_data["central_distances"],
2018 self.inp_CV_data["side_dihedrals"],
2019 ]
2020 else:
2021 data = [
2022 self.inp_CV_data["central_angles"],
2023 self.inp_CV_data["central_dihedrals"],
2024 self.inp_CV_data["all_cartesians"],
2025 self.inp_CV_data["central_distances"],
2026 self.inp_CV_data["side_dihedrals"],
2027 self.inp_CV_data["side_angles"],
2028 self.inp_CV_data["side_dihedrals"],
2029 self.inp_CV_data["side_distances_"],
2030 ]
2031 dataset = tf.data.Dataset.from_tensor_slices(tuple(data))
2032 dataset = dataset.shuffle(
2033 buffer_size=self.inp_CV_data["central_cartesians"].shape[0],
2034 reshuffle_each_iteration=True,
2035 )
2036 dataset = dataset.repeat()
2037 self.dataset = dataset.batch(self.p.batch_size)
2039 def plot_network(self) -> None:
2040 """Tries to plot the network using pydot, pydotplus and graphviz.
2041 Doesn't raise an exception if plotting is not possible.
2043 Note:
2044 Refer to this guide to install these programs:
2045 https://stackoverflow.com/questions/47605558/importerror-failed-to-import-pydot-you-must-install-pydot-and-graphviz-for-py
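Example:
    A minimal sketch; this call is a no-op if pydot/graphviz are missing:

    >>> e_map.plot_network()  # doctest: +SKIP
    >>> # on success, writes network.png into parameters.main_path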
2047 """
2048 try:
2049 out = plot_model(self.model, None)
2050 if out is not None:
2051 out.save(Path(self.p.main_path) / "network.png")
2052 except:
2053 pass
2055 @staticmethod
2056 def get_train_data_from_trajs(
2057 trajs: Union[TrajEnsemble, SingleTraj],
2058 p: ADCParameters,
2059 attr: str = "CVs",
2060 max_size: int = -1,
2061 ) -> tuple[bool, list[np.ndarray], dict[str, np.ndarray]]:
2062 """Builds train data from a `TrajEnsemble`.
2064 Args:
2065 trajs (TrajEnsemble): A `TrajEnsemble` instance.
2066 p (encodermap.parameters.ADCParameters): An instance of `encodermap.parameters.ADCParameters`.
2067 attr (str): Which attribute to get from `TrajEnsemble`. This defaults
2068 to 'CVs', because 'CVs' is usually a dict containing the CV data.
2069 However, you can build the train data from any dict in the `TrajEnsemble`.
2070 max_size (int): Set this to the desired size when you only want a
2071 subset of the CV data. Defaults to -1 (use all data).
2073 Returns:
2074 tuple: A tuple containing the following:
2075 - bool: A bool that shows whether some 'CV' values are `np.nan` (True),
2076 which will be used to decide whether the sparse training
2077 will be used.
2078 - list[np.ndarray]: An array of features fed into the autoencoder,
2079 concatenated along the feature axis. The order of the
2080 features is: central_angles, central_dihedrals, (side_dihedrals
2081 if p.use_sidechains is True).
2082 - dict[str, np.ndarray]: The training data as a dict. Containing
2083 all values in `trajs.CVs`.
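Example:
    A minimal sketch, assuming a featurized `TrajEnsemble` `trajs` and an
    `ADCParameters` instance `p`:

    >>> sparse, train_data, inp_CV_data = (  # doctest: +SKIP
    ...     em.AngleDihedralCartesianEncoderMap.get_train_data_from_trajs(trajs, p)
    ... )
    >>> # with p.use_backbone_angles and p.use_sidechains, train_data is
    >>> # [central_angles, central_dihedrals, side_dihedrals]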
2085 """
2086 # Local Folder Imports
2087 from ..misc.misc import FEATURE_NAMES
2089 assert hasattr(trajs, attr), (
2090 f"Can't load train data from the attribute {attr}. "
2091 f"{trajs.__class__.__name__} has no attribute '{attr}'"
2092 )
2093 if not any([np.isnan(x).any() for x in getattr(trajs, attr).values()]):
2094 inp_CV_data = {
2095 key: val.astype("float32") for key, val in getattr(trajs, attr).items()
2096 }
2098 # squeeze, if xarray is provided
2099 if all([hasattr(v, "values") for v in inp_CV_data.values()]):
2100 inp_CV_data = {k: v.values.squeeze() for k, v in inp_CV_data.items()}
2101 sparse = False
2102 else:
2103 sparse = True
2105 # check whether the nans are correctly distributed
2106 for k, v in trajs.CVs.items():
2107 if k not in list(FEATURE_NAMES.values()):
2108 continue
2109 if v.ndim == 3:
2110 v = np.any(np.all(np.isnan(v), (1, 2)))
2111 else:
2112 v = np.any(np.all(np.isnan(v), 1))
2113 if v:
2114 raise Exception(
2115 f"Stacking of frames for CV `{k}` did not "
2116 f"succeed. There are frames full of nans."
2117 )
2119 # build the CV data
2120 inp_CV_data = {
2121 key: val.astype("float32") for key, val in getattr(trajs, attr).items()
2122 }
2124 # squeeze, if xarray is provided
2125 if all([hasattr(v, "values") for v in inp_CV_data.values()]):
2126 inp_CV_data = {k: v.values.squeeze() for k, v in inp_CV_data.items()}
2128 # subsample if required
2129 if max_size > -1:
2130 for k, v in inp_CV_data.items():
2131 idx = np.unique(
2132 np.round(np.linspace(0, v.shape[0] - 1, max_size)).astype(int)
2133 )
2134 inp_CV_data[k] = v[idx]
2136 for k, v in inp_CV_data.items():
2137 if np.any(np.isnan(v)):
2138 values = v
2139 if k == "central_cartesians":
2140 values = values.reshape(len(values), -1)
2141 sparse_tensor = np_to_sparse_tensor(values)
2142 inp_CV_data[k] = sparse_tensor
2144 if not p.reconstruct_sidechains:
2145 if not p.use_backbone_angles and not p.use_sidechains:
2146 train_data = inp_CV_data["central_dihedrals"]
2147 elif p.use_backbone_angles and not p.use_sidechains:
2148 train_data = [
2149 inp_CV_data["central_angles"],
2150 inp_CV_data["central_dihedrals"],
2151 ]
2152 if p.model_api == "sequential" and not sparse:
2153 train_data = np.hstack(train_data)
2154 elif p.use_backbone_angles and p.use_sidechains:
2155 train_data = [
2156 inp_CV_data["central_angles"],
2157 inp_CV_data["central_dihedrals"],
2158 inp_CV_data["side_dihedrals"],
2159 ]
2160 if p.model_api == "sequential" and not sparse:
2161 train_data = np.hstack(train_data)
2162 else:
2163 raise Exception(
2164 "Cannot train model with central dihedrals and side "
2165 "dihedrals only. Backbone angles are required."
2166 )
2167 # some checks for the length of the train data
2168 if p.model_api == "functional":
2169 if not p.use_backbone_angles and not p.use_sidechains:
2170 pass
2171 elif p.use_backbone_angles and not p.use_sidechains:
2172 assert len(train_data) == 2
2173 else:
2174 assert len(train_data) == 3
2175 else:
2176 train_data = [
2177 inp_CV_data["central_angles"],
2178 inp_CV_data["central_dihedrals"],
2179 inp_CV_data["side_angles"],
2180 inp_CV_data["side_dihedrals"],
2181 ]
2182 return sparse, train_data, inp_CV_data
2184 @property
2185 def encoder(self) -> tf.keras.Model:
2186 """tf.keras.Model: The encoder Model."""
2187 return self.model.encoder_model
2189 @property
2190 def decoder(self) -> tf.keras.Model:
2191 """tf.keras.Model: The decoder Model."""
2192 return self.model.decoder_model
2194 @classmethod
2195 def from_checkpoint(
2196 cls: Type[AngleDihedralCartesianEncoderMapType],
2197 trajs: Union[None, TrajEnsemble],
2198 checkpoint_path: Union[Path, str],
2199 dataset: Optional[tf.data.Dataset] = None,
2200 use_previous_model: bool = False,
2201 compat: bool = False,
2202 ) -> AngleDihedralCartesianEncoderMapType:
2203 """Reconstructs the model from a checkpoint.
2205 Although the model can be loaded from disk without any form of data and
2206 still yield the correct input and output shapes, it is required to either
2207 provide `trajs` or `dataset` to double-check, that the correct model will
2208 be reloaded.
2210 This is also why the `sparse` argument is not needed, as sparsity of the
2211 input data is a property of the `TrajEnsemble` provided.
2213 Args:
2214 trajs (Union[None, TrajEnsemble]): Either None (in which case, the
2215 argument `dataset` is required), or an instance of `TrajEnsemble`,
2216 which was used to instantiate the `AngleDihedralCartesianEncoderMap`,
2217 before it was saved to disk.
2218 checkpoint_path (Union[Path, str]): The path to the checkpoint. Can
2219 either be the path to a .keras file or to a directory containing
2220 .keras files, in which case the most recently created .keras
2221 file will be used.
2222 dataset (Optional[tf.data.Dataset]): If `trajs` is not provided, a
2223 dataset is required to make sure the input shapes match the model,
2224 that is stored on the disk.
2225 use_previous_model (bool): Set this flag to True, if you load a model
2226 from an in-between checkpoint step (e.g., to continue training with
2227 different parameters). If you have the files saved_model_0.keras,
2228 saved_model_500.keras and saved_model_1000.keras, setting this to
2229 True and loading the saved_model_500.keras will back up the
2230 saved_model_1000.keras.
2231 compat (bool): Whether to use compatibility mode when missing or wrong
2232 parameter files are present. In this special case, some assumptions
2233 about the network architecture are made from the model and the
2234 parameters in parameters.json overwritten accordingly (a backup
2235 will also be made).
2237 Returns:
2238 AngleDihedralCartesianEncoderMapType: An instance of `AngleDihedralCartesianEncoderMap`.
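Example:
    A minimal sketch; the checkpoint path below is a hypothetical
    placeholder:

    >>> import encodermap as em  # doctest: +SKIP
    >>> e_map = em.AngleDihedralCartesianEncoderMap.from_checkpoint(  # doctest: +SKIP
    ...     trajs=trajs,
    ...     checkpoint_path="runs/my_run/saved_model_500.keras",  # hypothetical
    ...     use_previous_model=True,
    ... )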
2240 """
2241 return load_model(
2242 cls,
2243 checkpoint_path,
2244 trajs=trajs,
2245 dataset=dataset,
2246 use_previous_model=use_previous_model,
2247 compat=compat,
2248 )
2250 def _setup_callbacks(self) -> None:
2251 """Overwrites the parent class' `_setup_callbacks` method.
2253 Due to the 'soft start' of the cartesian cost, the `cartesian_increase_callback`
2254 needs to be added to the list of callbacks.
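Example:
    A minimal sketch of the parameter that triggers this extra callback;
    the step values are illustrative:

    >>> p = em.ADCParameters(cartesian_cost_scale_soft_start=(100, 1000))  # doctest: +SKIP
    >>> # during training, the cartesian cost scale is then ramped up
    >>> # between training steps 100 and 1000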
2256 """
2257 self.metrics = []
2258 self.callbacks = []
2259 self.callbacks.append(ProgressBar(parameters=self.p))
2260 if not self.read_only:
2261 self.callbacks.append(CheckpointSaver(self.p))
2262 if self.p.tensorboard:
2263 self.tensorboard_write_bool: Union[TensorboardWriteBool, None] = (
2264 TensorboardWriteBool(self.p)
2265 )
2266 self.callbacks.append(self.tensorboard_write_bool)
2267 file_writer = tf.summary.create_file_writer(self.p.main_path + "/train")
2268 file_writer.set_as_default()
2269 tf.summary.text(
2270 name=f"Parameters Summary for {self.p.main_path}",
2271 data=self.p.parameters,
2272 step=0,
2273 )
2274 else:
2275 self.tensorboard_write_bool = None
2276 if self.p.cartesian_cost_scale_soft_start != (None, None):
2277 self.cartesian_increase_callback = IncreaseCartesianCost(self.p)
2278 self.callbacks.append(self.cartesian_increase_callback)
2279 if self.p.track_clashes:
2280 self.metrics.append(ADCClashMetric(parameters=self.p, distance_unit="nm"))
2281 if self.p.track_RMSD:
2282 self.metrics.append(ADCRMSDMetric(parameters=self.p))
2284 def _get_loss(self):
2285 loss = self.p.loss
2286 if loss == "reconstruction_loss":
2287 loss = reconstruction_loss(self.model)
2288 elif loss == "emap_cost":
2289 self.dihedral_loss = dihedral_loss(
2290 self.model, self.p, self.tensorboard_write_bool
2291 )
2292 self.angle_loss = angle_loss(
2293 self.model, self.p, self.tensorboard_write_bool
2294 )
2295 if self.p.cartesian_cost_scale_soft_start != (None, None):
2296 self.cartesian_loss = cartesian_loss(
2297 self.model,
2298 self.cartesian_increase_callback,
2299 self.p,
2300 self.tensorboard_write_bool,
2301 print_current_scale=False,
2302 )
2303 else:
2304 self.cartesian_loss = cartesian_loss(
2305 self.model, None, self.p, self.tensorboard_write_bool
2306 )
2308 self.distance_loss = distance_loss(
2309 self.model, self.p, self.tensorboard_write_bool
2310 )
2311 self.cartesian_distance_loss = cartesian_distance_loss(
2312 self.model, self.p, self.tensorboard_write_bool
2313 )
2314 self.center_loss = center_loss(
2315 self.model, self.p, self.tensorboard_write_bool
2316 )
2317 self.regularization_loss = regularization_loss(
2318 self.model, self.p, self.tensorboard_write_bool
2319 )
2320 loss = [
2321 self.dihedral_loss,
2322 self.angle_loss,
2323 self.cartesian_loss,
2324 self.distance_loss,
2325 self.cartesian_distance_loss,
2326 self.center_loss,
2327 self.regularization_loss,
2328 ]
2329 if self.p.use_sidechains:
2330 self.side_dihedral_loss = side_dihedral_loss(
2331 self.model, self.p, self.tensorboard_write_bool
2332 )
2333 loss.append(self.side_dihedral_loss)
2334 elif loss == "mse":
2335 loss = "mse"
2336 else:
2337 raise ValueError(
2338 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'."
2339 )
2340 return loss
2342 def encode(
2343 self,
2344 data: Optional[Union[TrajEnsemble, SingleTraj, Sequence[np.ndarray]]] = None,
2345 ) -> np.ndarray:
2346 """Runs the central_angles, central_dihedrals, (side_dihedrals) through the
2347 autoencoder. Make sure that `data` has the correct shape.
2349 Args:
2350 data (Sequence[np.ndarray]): Provide a sequence of central_angles and
2351 central_dihedrals. If you used side_dihedrals during training,
2352 append these to the end of the sequence.
2354 Returns:
2355 np.ndarray: The latent space representation of the provided `data`.
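Example:
    A minimal sketch, assuming a model trained with backbone angles and
    sidechain dihedrals and a featurized `TrajEnsemble` `trajs`:

    >>> lowd = e_map.encode(  # doctest: +SKIP
    ...     [
    ...         trajs.CVs["central_angles"],
    ...         trajs.CVs["central_dihedrals"],
    ...         trajs.CVs["side_dihedrals"],
    ...     ]
    ... )
    >>> # lowd has shape (n_frames, n_neurons_in_bottleneck)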
2357 """
2358 if data is None:
2359 if hasattr(self, "trajs"):
2360 data = self.trajs
2361 else:
2362 for data in self.dataset:
2363 break
2364 if not self.p.use_sidechains and not self.p.use_backbone_angles:
2365 data: np.ndarray = data[1] # type: ignore[no-redef]
2366 elif self.p.use_backbone_angles and not self.p.use_sidechains:
2367 data: list[np.ndarray] = [data[0], data[1]] # type: ignore[no-redef]
2368 else:
2369 data: list[np.ndarray] = [data[0], data[1], data[-1]] # type: ignore[no-redef]
2371 if not hasattr(data, "trajs") and hasattr(data, "_CVs") and data is not None:
2372 traj: SingleTraj = data # type: ignore[assignment]
2373 _, data, __ = self.get_train_data_from_trajs(traj, self.p, attr="_CVs")
2374 if isinstance(data, (tf.SparseTensor)):
2375 shape = data.dense_shape[0] * data.dense_shape[1]
2376 elif isinstance(data, np.ndarray):
2377 shape = data.size
2378 data = np.expand_dims(data, 0)
2379 elif isinstance(data[0], (np.ndarray, tf.Tensor)):
2380 shape = data[0].size
2381 elif isinstance(data[0], tf.sparse.SparseTensor):
2382 shape = data[0].values.shape[0]
2383 else:
2384 raise Exception(f"Unexpected datatype {data=}")
2385 if shape > 100_000:
2386 print(
2387 f"Due to the size of the provided data {shape}, I "
2388 f"need to chunk it, which takes longer. Sit back, grab a coffee..."
2389 )
2390 indices = np.split(
2391 np.arange(traj.n_frames), np.arange(100, traj.n_frames, 100)
2392 )
2393 # single frame encoding does not work, because the frame axis is dropped
2394 if len(indices[-1]) == 1:
2395 indices = np.split(
2396 np.arange(traj.n_frames), np.arange(100, traj.n_frames, 101)
2397 )
2398 lowd = []
2399 for i, ind in enumerate(indices):
2400 _, data, __ = self.get_train_data_from_trajs(
2401 traj[ind], self.p, attr="_CVs"
2402 )
2403 try:
2404 lowd.append(self.model.encoder_model(data))
2405 except IndexError as e:
2406 raise Exception(f"{i=} {ind=} {data=}") from e
2407 return np.vstack(lowd)
2408 else:
2409 return self.encode(data)
2410 elif hasattr(data, "trajs"):
2411 lowd = []
2412 for traj in data.trajs: # type: ignore[union-attr]
2413 lowd.append(self.encode(traj))
2414 return np.vstack(lowd)
2415 elif isinstance(data, Sequence):
2416 # Standard Library Imports
2417 from functools import reduce
2418 from operator import mul
2420 size = reduce(mul, data[0].shape)
2421 if size > 100_000:
2422 indices = np.split(
2423 np.arange(data[0].shape[0]),
2424 np.arange(100, data[0].shape[0], 100),
2425 )
2426 if len(indices[-1]) == 1:
2427 indices = np.split(
2428 np.arange(data[0].shape[0]),
2429 np.arange(100, data[0].shape[0], 101),
2430 )
2431 lowd = []
2432 for i, ind in enumerate(indices):
2433 datum = []
2434 for d in data:
2435 if isinstance(d, tf.sparse.SparseTensor):
2436 d = tf.sparse.to_dense(d, default_value=np.nan).numpy()[ind]
2437 datum.append(np_to_sparse_tensor(d))
2438 else:
2439 datum.append(d[ind])
2440 lowd.append(self.model.encoder_model(datum))
2441 return np.vstack(lowd)
2442 else:
2443 return self.model.encoder_model(data).numpy() # type: ignore[no-any-return]
2444 elif hasattr(data, "shape") or hasattr(data, "dense_shape"):
2445 return self.model.encoder_model(data).numpy() # type: ignore[no-any-return]
2446 else:
2447 raise TypeError(f"Wrong type for argument `data`: {type(data)=} {data=}.")
2449 def decode(self, data: np.ndarray) -> Sequence[np.ndarray]:
2450 """Calls the decoder part of the model.
2452 `AngleDihedralCartesianEncoderMap` will, unlike the other two classes,
2453 output a list of np.ndarray.
2455 Args:
2456 data (np.ndarray): The data to be passed to the decoder part of
2457 the model. Make sure that the shape of the data matches the
2458 number of neurons in the latent space.
2460 Returns:
2461 Union[list[np.ndarray], np.ndarray]: Outputs from the decoder part.
2462 For `AngleDihedralCartesianEncoderMap`, this will be a list of
2463 np.ndarray.
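Example:
    A minimal sketch, decoding hypothetical points sampled from a
    2-dimensional latent space:

    >>> import numpy as np  # doctest: +SKIP
    >>> points = np.random.normal(size=(100, 2))  # doctest: +SKIP
    >>> angles, dihedrals, side_dihedrals = e_map.decode(points)  # doctest: +SKIP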
2465 """
2466 out = self.model.decoder(data)
2467 if isinstance(out, (list, tuple)):
2468 out = [o.numpy() for o in out]
2469 else:
2470 out = out.numpy()
2471 return out # type: ignore[no-any-return]
2473 @overload
2474 def generate(
2475 self,
2476 points: np.ndarray,
2477 top: Optional[Union[str, int, Topology]],
2478 backend: Literal["mdtraj"],
2479 progbar: Optional[Any],
2480 ) -> Trajectory: ...
2482 @overload
2483 def generate(
2484 self,
2485 points: np.ndarray,
2486 top: Optional[Union[str, int, Topology]],
2487 backend: Literal["mdanalysis"],
2488 progbar: Optional[Any],
2489 ) -> Universe: ...
2491 def generate(
2492 self,
2493 points: np.ndarray,
2494 top: Optional[Union[str, int, Topology]] = None,
2495 backend: Literal["mdtraj", "mdanalysis"] = "mdtraj",
2496 progbar: Optional[Any] = None,
2497 ) -> Union[Universe, Trajectory]:
2498 """Overrides the parent class' `generate` method and builds a trajectory.
2500 Instead of just providing data to `decode` using the decoder part of the
2501 network, this method also takes a molecular topology as its `top`
2502 argument. This topology is then used to rebuild a time-resolved
2503 trajectory.
2505 Args:
2506 points (np.ndarray): The low-dimensional points from which the
2507 trajectory should be rebuilt.
2508 top (Optional[str, int, mdtraj.Topology]): The topology to be used for rebuilding the
2509 trajectory. This should be a string pointing towards a <*.pdb,
2510 *.gro, *.h5> file. Alternatively, None can be provided; in which
2511 case, the internal topology (`self.top`) of this class is used.
2512 Defaults to None.
2513 backend (str): Defines which MD python package to use to build the
2514 trajectory; it also determines the return type of this method. Needs
2515 to be one of the following:
2516 * "mdtraj"
2517 * "mdanalysis"
2519 Returns:
2520 Union[mdtraj.Trajectory, MDAnalysis.Universe]: The trajectory after
2521 applying the decoded structural information. The type of this
2522 depends on the chosen `backend` parameter.
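Example:
    A minimal sketch, generating conformations along a hypothetical
    straight path through the latent space:

    >>> import numpy as np  # doctest: +SKIP
    >>> path = np.linspace([-2.0, -2.0], [2.0, 2.0], 50)  # hypothetical latent points
    >>> traj = e_map.generate(path, top=0, backend="mdtraj")  # doctest: +SKIP
    >>> traj.save_pdb("generated.pdb")  # doctest: +SKIP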
2524 """
2525 # get the output; this can be done regardless of the chosen backend
2526 out = self.decode(points)
2527 if isinstance(out, np.ndarray):
2528 dihedrals = out
2529 sidechain_dihedrals = None
2530 elif (
2531 isinstance(out, (list, tuple))
2532 and len(out) == 2
2533 and all([isinstance(a, np.ndarray) for a in out])
2534 ):
2535 angles, dihedrals = out
2536 sidechain_dihedrals = None
2537 elif (
2538 isinstance(out, (list, tuple))
2539 and len(out) == 3
2540 and all([isinstance(a, np.ndarray) for a in out])
2541 ):
2542 angles, dihedrals, sidechain_dihedrals = out
2543 else:
2544 raise Exception(
2545 f"Unexpected length of out detected: ({len(out)}=). Maybe also "
2546 f"unexpected dtypes: ({[type(a) for a in out]=})."
2547 )
2549 assert isinstance(self.trajs, TrajEnsemble), (
2550 f"`generate()` can only work, when the Autoencoder was built with a "
2551 f"`TrajEnsemble` and not just a dataset. You can set the `TrajEnsemble` "
2552 f"of this object with `AngleDihedralCartesianEncoderMap.trajs = TrajEnsemble`."
2553 )
2555 if top is None:
2556 if len(self.trajs.top) > 1:
2557 print(
2558 f"Please specify which topology you would like to use for generating "
2559 f"conformations. You can either provide a `str` to a topology file "
2560 f"(file extension .pdb, .h5, .gro) on disk, or a `int` specifying the "
2561 f"one of the ensembles {len(self.trajs.top)} topologies "
2562 f"(see `AngleDihedralCartesianEncoderMap.trajs.top` for available "
2563 f"topologies). You can also directly supply a "
2564 f"you can also specify a `mdtraj.Topology` object."
2565 )
2566 return # type: ignore[return-value]
2567 elif isinstance(top, int):
2568 mdanalysis_traj = self.trajs[top][0].traj
2569 elif isinstance(top, str) and top not in self.trajs.common_str:
2570 mdanalysis_traj = md.load(top)
2571 elif isinstance(top, str) and top in self.trajs.common_str:
2572 mdanalysis_traj = self.trajs.trajs_by_common_str[top][0].traj
2573 elif isinstance(top, md.Topology):
2574 mdanalysis_traj = top
2575 else:
2576 raise ValueError(
2577 f"Type of argument `top` must be int, str, md.Topology. You provided {type(top)}."
2578 )
2580 # do the backmapping
2581 if backend == "mdanalysis":
2582 with NamedTemporaryFile(suffix=".pdb") as f:
2583 mdanalysis_traj.save_pdb(f.name)
2584 uni = dihedral_backmapping(
2585 f.name, dihedrals, sidechains=sidechain_dihedrals
2586 )
2587 return uni
2588 elif backend == "mdtraj":
2589 traj = mdtraj_backmapping(
2590 top=top,
2591 dihedrals=dihedrals,
2592 sidechain_dihedrals=sidechain_dihedrals,
2593 trajs=self.trajs,
2594 progbar=progbar,
2595 ) # type: ignore[call-overload]
2596 return traj
2597 else:
2598 raise TypeError(
2599 f"backend must be 'mdtraj' or 'mdanalysis', but you provided {backend}"
2600 )