Coverage for encodermap/autoencoder/autoencoder.py: 40%
452 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/autoencoder/autoencoder.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Forward facing Autoencoder classes. Contains four classes:
24* Autoencoder: Simple NN dense, fully connected AE architecture. Reg loss, auto loss and center loss
25* EncoderMap: Uses the same architecture as `Autoencoder`, but adds another loss function.
26* DihedralEncoderMap: Basically the same as `EncoderMap`, but rewrites the `generate` method to use
27 an atomistic topology to rebuild a trajectory.
28* AngleDihedralCartesianEncoderMap: Uses more loss functions and tries to learn a full all atom conformation.
30"""
33################################################################################
34# Imports
35################################################################################
38from __future__ import annotations
40import typing
41import warnings
42from copy import deepcopy
43from typing import Literal, Optional, Union
45import matplotlib
46import numpy as np
47import tensorflow as tf
48import tensorflow.keras
50import encodermap
52from ..callbacks.callbacks import (
53 CheckpointSaver,
54 ImageCallback,
55 IncreaseCartesianCost,
56 ProgressBar,
57 TensorboardWriteBool,
58)
59from ..loss_functions.loss_functions import (
60 angle_loss,
61 auto_loss,
62 cartesian_distance_loss,
63 cartesian_loss,
64 center_loss,
65 dihedral_loss,
66 distance_loss,
67 reconstruction_loss,
68 regularization_loss,
69 side_dihedral_loss,
70)
71from ..misc.backmapping import dihedral_backmapping, mdtraj_backmapping
72from ..misc.misc import BadError, create_n_cube, plot_model
73from ..misc.saving_loading_models import load_model, save_model
74from ..models.models import gen_functional_model, gen_sequential_model
75from ..parameters.parameters import ADCParameters, Parameters
76from ..trajinfo.info_all import Capturing, TrajEnsemble
78################################################################################
79# Typing
80################################################################################
83if typing.TYPE_CHECKING:
84 import MDAnalysis as mda
85 import mdtraj as md
88################################################################################
89# Globals
90################################################################################
93__all__ = [
94 "Autoencoder",
95 "EncoderMap",
96 "AngleDihedralCartesianEncoderMap",
97 "DihedralEncoderMap",
98]
101##############################################################################
102# Function definition which allows self.p.tensorboard to be passed
103# @function(self.p.tensorboard)
104# def train(self):
105# # some training
106##############################################################################
109def function(f, tensorboard=False):
110 """Compiles functions with `tensorflow.function` based on a `tensorboard`
111 parameter.
114 To understand the neccessity of this function, we need to have a look how
115 tensorflow executes computations. There are two modes of execution:
116 * eager mode: In eager mode, the computations are handles by python.
117 The input types are python objects, and the output is a python object.
118 This eager execution allows you to directly execute a calculation of
119 two tensors (e.g. multiplication).
120 * graph mode: In graph mode, computations are done inside tensorflow graphs,
121 which are a collection of operations and tensors (i.e. data), that flow
122 through the operations of the graph. These graphs, make tensorflow
123 computations portable and significantly increase the performance of
124 similar computations.
125 Normally, you would accelerate a simple python function in tensorflow,
126 by compiling it like so:
128 ```python
129 import tensorflow as tf
131 @tf.function
132 def multiply(a, b):
133 return a * b
135 multiply(tf.constant(2), tf.constant(3)).numpy()
136 # 6
137 ```
139 However, the basic paradigm of accelerating the computation interferes with
140 `encodermap.Parameters` `tensorboard=True` argument, as it writes a lot of
141 additional information to tensorboard. Thus, a compilation with tf.function
142 does not make sense here. That's why encodermap's `function` decorator
143 takes an additional argument:
145 """
147 def wrapper(*args, **kwargs):
148 tensorboard = kwargs.pop("tensorboard", False)
149 """Wrapper of `encodermap.function`."""
150 if tensorboard:
151 warnings.warn(
152 "Running in tensorboard mode writes a lot of stuff to tensorboard. For speed boost deactivate tensorboard mode."
153 )
154 result = f(*args, **kwargs)
155 else:
156 compiled = tf.function(f)
157 result = compiled(*args, **kwargs)
158 return result
160 return wrapper
163##############################################################################
164# Public Classes
165##############################################################################
168class Autoencoder:
169 """Main Autoencoder class preparing data, setting up the neural network and implementing training.
171 This is the main class for neural networks inside EncoderMap. The class prepares the data
172 (batching and shuffling), creates a `tf.keras.Model` of layers specified by the attributes of
173 the `encodermap.Parameters` class. Depending on what Parent/Child-Class is instantiated
174 a combination of cost functions is set up. Callbacks to Tensorboard are also set up.
176 Attributes:
177 train_data (np.ndarray): The numpy array of the train data passed at init.
178 p (encodermap.Parameters): An `encodermap.Parameters()` class containing all info needed to set
179 up the network.
180 dataset (tensorflow.data.Dataset): The dataset that is actually used in training the keras model. The dataset
181 is a batched, shuffled, infinitely-repeating dataset.
182 read_only (bool): Variable telling the class whether it is allowed to write to disk (False) or not (True).
183 optimizer (tf.keras.optimizers.Adam): Instance of the Adam optimizer with learning rate specified by
184 the Parameters class.
185 metrics (list): A list of metrics passed to the model when it is compiled.
186 callbacks (list): A list of tf.keras.callbacks.Callback Sub-classes changing the behavior of the model during
187 training. Some standard callbacks are always present like:
188 * encodermap.callbacks.callbacks.ProgressBar:
189 A progress bar callback using tqdm giving the current progress of training and the
190 current loss.
191 * CheckPointSaver:
192 A callback that saves the model every parameters.checkpoint_step steps into
193 the main directory. This callback will only be used, when `read_only` is False.
194 * TensorboardWriteBool:
195 A callback that contains a boolean Tensor that will be True or False,
196 depending on the current training step and the summary_step in the parameters class. The loss
197 functions use this callback to decide whether they should write to Tensorboard. This callback
198 will only be present, when `read_only` is False and `parameters.tensorboard` is True.
199 You can append your own callbacks to this list before executing Autoencoder.train().
200 encoder (tf.keras.models.Model): The encoder (sub)model of `model`.
201 decoder (tf.keras.models.Model): The decoder (sub)model of `model`.
203 Methods:
204 from_checkpoint: Rebuild the model from a checkpoint.
205 add_images_to_tensorboard: Make tensorboard plot images.
206 train: Starts the training of the tf.keras.models.Model.
207 plot_network: Tries to plot the network. For this method to work graphviz, pydot and pydotplus needs to be installed.
208 encode: Takes high-dimensional data and sends it through the encoder.
209 decode: Takes low-dimensional data and sends it through the encoder.
210 generate: Same as decode. For AngleDihedralCartesianAutoencoder classes this will build a protein strutcure.
212 Note:
213 Performance of tensorflow is not only dependant on your system's hardware and how the data is presented to
214 the network (for this check out https://www.tensorflow.org/guide/data_performance), but also how you compiled
215 tensorflow. Normal tensorflow (pip install tensorflow) is build without CPU extensions to work on many CPUs.
216 However, Tensorflow can greatly benefit from using CPU instructions like AVX2, AVX512 that bring a speed-up
217 in linear algebra computations of 300%. By building tensorflow from source you can activate these extensions.
218 However, the CPU speed-up is dwarfed by the speed-up when you allow tensorflow to run on your GPU (grapohics
219 card). To check whether a GPU is available run:
220 `print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))`.
221 Refer to these pages to install tensorflow for best performance:
222 https://www.tensorflow.org/install/pip, https://www.tensorflow.org/install/gpu
224 Examples:
225 >>> import encodermap as em
226 >>> # without providing any data, default parameters and a 4D hypercube as input data will be used.
227 >>> e_map = em.EncoderMap(read_only=True)
228 >>> print(e_map.train_data.shape)
229 (16000, 4)
230 >>> print(e_map.dataset)
231 <BatchDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name=None), TensorSpec(shape=(None, 4), dtype=tf.float32, name=None))>
232 >>> print(e_map.encode(e_map.train_data).shape)
233 (16000, 2)
235 """
237 def __init__(
238 self,
239 parameters=None,
240 train_data: Optional[Union[np.ndarray, tf.Dataset]] = None,
241 model=None,
242 read_only=False,
243 sparse=False,
244 ):
245 """Instantiate the Autoencoder class.
247 Args:
248 parameters (Union[encodermap.Parameters, None], optional): The parameters to be used. If None is
249 provided default values (check them with print(em.Parameters.defaults_description()))
250 are used. Defaults to None.
251 train_data (Union[np.ndarray, tf.data.Dataset, None], optional): The train data. Can be one of the following:
252 * None: If None is provided points on the edges of a 4-dimensional hypercube will be used as train data.
253 * np.ndarray: If a numpy array is provided, it will be transformed into a batched tf.data.Dataset by
254 first making it an infinitely repeating dataset, shuffling it and the batching it with a batch
255 size specified by parameters.batch_size.
256 * tf.data.Dataset: If a dataset is provided it will be used without making any adjustments. Make
257 sure, that the dataset uses `float32` as its type.
258 Defaults to None.
259 model (Union[tf.keras.models.Model, None], optional): Providing a keras model to this argument will make
260 the Autoencoder/EncoderMap class use this model instead of the predefined ones. Make sure the model
261 can accept EncoderMap's loss functions. If None is provided the model will be built using
262 the specifications in parameters. Defaults to None.
263 read_only (bool, optional): Whether the class is allowed to write to disk (False) or not (True). Defaults
264 to False and will allow the class to write to disk.
266 Raises:
267 BadError: When read_only is `True` and `parameters.tensorboard` is `True`, this Exception will be raised,
268 because they are mutually exclusive.
270 """
271 # parameters
272 if parameters is None:
273 self.p = Parameters()
274 else:
275 self.p = parameters
277 if self.p.seed is not None: 277 ↛ 278line 277 didn't jump to line 278, because the condition on line 277 was never true
278 tf.random.set_seed(self.p.seed)
279 self.read_only = read_only
281 if not self.read_only: 281 ↛ 282line 281 didn't jump to line 282, because the condition on line 281 was never true
282 self.p.save()
283 print(
284 "Output files are saved to {}".format(self.p.main_path),
285 "as defined in 'main_path' in the parameters.",
286 )
288 # check whether Tensorboard and Read-Only makes Sense
289 if self.read_only and self.p.tensorboard: 289 ↛ 290line 289 didn't jump to line 290, because the condition on line 289 was never true
290 raise BadError(
291 "Setting tensorboard and read_only True is not possible. Tensorboard will always write to disk."
292 " If you received this Error while loading a trained model, pass read_only=False as an argument"
293 f" or set overwrite_tensorboard_bool True to overwrite the tensorboard parameter."
294 )
296 # clear old sessions
297 tf.keras.backend.clear_session()
298 self.sparse = sparse
300 # set up train_data
301 if train_data is None:
302 self.train_data = create_n_cube(4, seed=self.p.seed)[0].astype("float32")
303 self.p.periodicity = float("inf")
304 elif isinstance(train_data, np.ndarray): 304 ↛ 315line 304 didn't jump to line 315, because the condition on line 304 was never false
305 if np.any(np.isnan(train_data)): 305 ↛ 306line 305 didn't jump to line 306, because the condition on line 305 was never true
306 self.sparse = True
307 print("Input contains nans. Using sparse network.")
308 indices = np.stack(np.where(~np.isnan(train_data))).T.astype("int64")
309 dense_shape = train_data.shape
310 values = train_data[~np.isnan(train_data)].flatten().astype("float32")
311 sparse_tensor = tf.sparse.SparseTensor(indices, values, dense_shape)
312 self.train_data = sparse_tensor
313 else:
314 self.train_data = train_data.astype("float32")
315 elif isinstance(train_data, tf.data.Dataset):
316 self.dataset = train_data
317 try:
318 for _, __ in self.dataset:
319 break
320 except ValueError:
321 if self.p.training == "auto":
322 print(
323 f"It seems like your dataset only yields tensors and not "
324 f"tuples of tensors. Tensorlfow is optimized for classification "
325 f"tasks, where datasets yield tuples of (data, classes). EncoderMap,"
326 f"however is a regression task, but uses the same code as the "
327 f"classification tasks. I will transform your dataset using "
328 f"the `tf.data.Dataset.zip()` function of `tf.data`. You can "
329 f"set the `training` parameter in the parameter class to "
330 f"'custom' to not alter your dataset."
331 )
332 self.dataset = tf.data.Dataset.zip((self.dataset, self.dataset))
333 for _, __ in self.dataset:
334 break
335 else:
336 for _ in self.dataset:
337 break
338 self.train_data = _
339 else:
340 raise TypeError(
341 f"train_data must be `None`, `np.ndarray` or `tf.data.Dataset`. You supplied {type(train_data)}."
342 )
344 # check data and periodicity
345 if not self.sparse and not train_data is None:
346 if np.any(train_data > self.p.periodicity): 346 ↛ 347line 346 didn't jump to line 347, because the condition on line 346 was never true
347 raise Exception(
348 "There seems to be an error regarding the periodicity "
349 f"of your data. The chosen periodicity is {self.p.periodicity}, "
350 f"but there are datapoints outwards of this range: {train_data.max()}"
351 )
353 # prepare the data
354 if isinstance(self.train_data, (np.ndarray, tf.sparse.SparseTensor)): 354 ↛ 383line 354 didn't jump to line 383, because the condition on line 354 was never false
355 if self.p.training == "auto": 355 ↛ 360line 355 didn't jump to line 360, because the condition on line 355 was never false
356 dataset = tf.data.Dataset.from_tensor_slices(
357 (self.train_data, self.train_data)
358 )
359 else:
360 dataset = tf.data.Dataset.from_tensor_slices(self.train_data)
361 dataset = dataset.shuffle(
362 buffer_size=self.train_data.shape[0], reshuffle_each_iteration=True
363 )
364 dataset = dataset.repeat()
365 self.dataset = dataset.batch(self.p.batch_size)
366 else:
367 pass
369 # ToDo: Make training faster with Autotune, XLA (jit) compilation, DataRecords
370 # self.dataset = self.dataset.prefetch(self.p.batch_size * 4)
371 # self.dataset = self.dataset.interleave(num_parallel_calls=tf.data.experimental.AUTOTUNE)
372 # num_batches = 20
373 # num_threads = 8
374 # if self.p.training == 'auto':
375 # ds = tf.data.Dataset.from_tensors((self.train_data, self.train_data)).repeat(self.p.batch_size * num_batches)
376 # else:
377 # ds = tf.data.Dataset.from_tensors(self.train_data).repeat(self.p.batch_size * num_batches)
378 # # ds = tf.data.Dataset.from_tensors(self.train_data).repeat(self.p.batch_size * num_batches)
379 # ds = ds.batch(self.p.batch_size)
380 # self.dataset = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads)
382 # create model based on user input
383 if model is None: 383 ↛ 386line 383 didn't jump to line 386, because the condition on line 383 was never false
384 self.model = self.p.model_api
385 else:
386 self._model = model
388 # setup callbacks for nice progress bars and saving every now and then
389 self._setup_callbacks()
391 # create loss based on user input
392 self.loss = self.p.loss
394 # choose optimizer
395 self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.p.learning_rate)
397 # compile model
398 self.model.compile(
399 optimizer=self.optimizer, loss=self.loss, metrics=self.metrics
400 )
402 # do this if tensorboard is true.
403 if self.p.tensorboard: 403 ↛ 404line 403 didn't jump to line 404, because the condition on line 403 was never true
404 self._log_images = False
405 # get the output from model summary.
406 with Capturing() as output:
407 self.model.summary()
408 with open(self.p.main_path + "/model_summary.txt", "w") as f:
409 f.write("\n".join(output))
410 tf.keras.utils.plot_model(
411 self.model,
412 to_file=self.p.main_path + "/model_summary.png",
413 show_shapes=True,
414 rankdir="LR",
415 expand_nested=True,
416 )
417 print(
418 f"Saved a text-summary of the model and an image in {self.p.main_path},",
419 "as specified in 'main_path' in the parameters.",
420 )
422 # sets up the tb callback to plot the model
423 self.tb_callback = tf.keras.callbacks.TensorBoard(
424 self.p.main_path, write_graph=True
425 )
426 self.tb_callback.set_model(self.model)
428 def _setup_callbacks(self):
429 """Sets up a list with callbacks to be passed to self.model.fit()"""
430 self.metrics = []
431 self.callbacks = []
432 self.callbacks.append(ProgressBar(parameters=self.p))
433 if not self.read_only:
434 self.callbacks.append(CheckpointSaver(self.p))
435 if self.p.tensorboard: 435 ↛ 436line 435 didn't jump to line 436, because the condition on line 435 was never true
436 self.tensorboard_write_bool = TensorboardWriteBool(self.p)
437 self.callbacks.append(self.tensorboard_write_bool)
438 file_writer = tf.summary.create_file_writer(self.p.main_path + "/train")
439 file_writer.set_as_default()
440 tf.summary.text(
441 name=f"Parameters Summary for {self.p.main_path}",
442 data=self.p.parameters,
443 step=0,
444 )
445 # callbacks.append(self.tb_callback)
446 else:
447 self.tensorboard_write_bool = None
449 @classmethod
450 def from_checkpoint(
451 cls,
452 checkpoint_path,
453 read_only=True,
454 overwrite_tensorboard_bool=False,
455 sparse=False,
456 ):
457 """Reconstructs the class from a checkpoint.
459 Args:
460 Checkpoint path (str): The path to the checkpoint. Most models are saved in parts (encoder, decoder)
461 and thus the provided path often needs a wildcard (*). The `save()` method of this class prints
462 a string with which the model can be reloaded.
463 read_only (bool, optional): Whether to reload the model in read_only mode (True) or allow the `Autoencoder`
464 class to write to disk (False). This option might collide with the tensorboard Parameter in the
465 respective parameters.json file in the maith_path. Defaults to True.
466 overwrite_tensorboard_bool (bool, optional): Whether to overwrite the tensorboard Parameter while reloading
467 the class. This can be set to True to set the tensorboard parameter False and allow read_only.
468 Defaults to False.
470 Raises:
471 BadError: When read_only is True, overwrite_tensorboard_bool is False and the reloaded parameters
472 have tensorboard set to True.
474 Returns:
475 Autoencoder: Encodermap `Autoencoder` class.
477 """
478 return load_model(
479 cls, checkpoint_path, read_only, overwrite_tensorboard_bool, sparse=sparse
480 )
482 @property
483 def model(self):
484 """tf.keras.models.Model: The tf.keras.Model model used for training."""
485 return self._model
487 @model.setter
488 def model(self, model):
489 """sets self.model according to `model_api` argument in self.parameters."""
490 if model == "functional":
491 for d in self.dataset: 491 ↛ 493line 491 didn't jump to line 493, because the loop on line 491 didn't complete
492 break
493 if any([isinstance(_, tf.sparse.SparseTensor) for _ in d]): 493 ↛ 494line 493 didn't jump to line 494, because the condition on line 493 was never true
494 self.sparse = True
495 self._model = gen_functional_model(self.dataset, self.p, sparse=self.sparse)
496 elif model == "sequential": 496 ↛ 502line 496 didn't jump to line 502, because the condition on line 496 was never false
497 if isinstance(self.train_data, tf.sparse.SparseTensor): 497 ↛ 498line 497 didn't jump to line 498, because the condition on line 497 was never true
498 self.sparse = True
499 self._model = gen_sequential_model(
500 self.train_data.shape[1], self.p, sparse=self.sparse
501 )
502 elif model == "custom":
503 raise NotImplementedError("No custom API currently supported")
504 else:
505 raise ValueError(
506 f"API argument needs to be one of `functional`, `sequential`, `custom`. You provided '{model}'."
507 )
509 @property
510 def encoder(self):
511 """tf.keras.models.Model: Encoder part of the model."""
512 return self._model.encoder_model
514 @property
515 def decoder(self):
516 """tf.keras.models.Model: Decoder part of the model."""
517 return self._model.decoder_model
519 @property
520 def loss(self):
521 """(Union[list, string, function]): A list of loss functions passed to the model when it is compiled.
522 When the main Autoencoder class is used and parameters.loss is 'emap_cost' this list is comprised of
523 center_cost, regularization_cost, auto_cost. When the EncoderMap sub-class is used and parameters.loss is
524 'emap_cost' distance_cost is added to the list. When parameters.loss is not 'emap_cost', the loss can either
525 be a string ('mse'), or a function, that both are acceptable arguments for loss, when a keras model
526 is compiled.
528 """
529 return self._loss
531 @loss.setter
532 def loss(self, loss):
533 """sets self.loss according to `loss` in self.parameters."""
534 if loss == "reconstruction_loss":
535 self._loss = reconstruction_loss(self.model)
536 elif loss == "emap_cost":
537 self.auto_loss = auto_loss(self.model, self.p, self.tensorboard_write_bool)
538 self.regularization_loss = regularization_loss(
539 self.model, self.p, self.tensorboard_write_bool
540 )
541 self.center_loss = center_loss(
542 self.model, self.p, self.tensorboard_write_bool
543 )
544 self._loss = [self.auto_loss, self.regularization_loss, self.center_loss]
545 elif loss == "mse":
546 self._loss = "mse"
547 else:
548 raise ValueError(
549 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'."
550 )
552 def train(self):
553 """Starts the training of the model."""
554 if self.p.training == "custom" and self.p.batched:
555 raise NotImplementedError()
556 elif self.p.training == "custom" and not self.p.batched:
557 raise NotImplementedError()
558 elif self.p.training == "auto":
559 if self.p.tensorboard and self._log_images:
560 # get the old backend because the Tensorboard Images callback will set 'Agg'
561 old_backend = matplotlib.get_backend()
562 # start_time = time.perf_counter()
563 self.history = self.model.fit(
564 self.dataset,
565 batch_size=self.p.batch_size,
566 epochs=self.p.n_steps,
567 steps_per_epoch=1,
568 verbose=0,
569 callbacks=self.callbacks,
570 )
571 # print("Execution time:", time.perf_counter() - start_time)
572 else:
573 raise ValueError(
574 f"training argument needs to be `auto` or `custom`. You provided '{self.training}'."
575 )
576 self.save(step=self.p.n_steps)
577 # reset the backend.
578 if self.p.tensorboard and self._log_images:
579 matplotlib.use(old_backend)
581 def add_images_to_tensorboard(
582 self,
583 data=None,
584 image_step=None,
585 scatter_kws={"s": 20},
586 hist_kws={"bins": 50},
587 additional_fns=None,
588 when="epoch",
589 ):
590 """Adds images to Tensorboard using the data in data and the ids in ids.
592 Args:
593 data (Union[np.ndarray, list, None], optional): The input-data will be passed through the encoder
594 part of the autoencoder. If None is provided a set of 10000 points from the provided
595 train data will be taken. A list is needed for the functional API of the ADCAutoencoder, that takes
596 a list of [angles, dihedrals, side_dihedrals]. Defaults to None.
597 image_step (Union[int, None], optional): The interval in which to plot images to tensorboard.
598 If None is provided, the update step will be the same as parameters.summary_step. Defaults to None.
599 scatter_kws (dict, optional): A dict with items that matplotlib.pyplot.scatter() will accept. Defaults to
600 {'s': 20}, which sets an appropriate size of scatter points for the size of datasets encodermap is
601 usually used for.
602 hist_kws (dict, optional): A dict with items that matplotlib.pyplot.scatter() will accept. You can
603 choose a colorbar here. Defaults to {'bins': 50} which sets an appropriate bin count for the
604 size of datasets encodermap is usually used for.
605 additional_fns (Union[list, None], optional): A list of functions that will accept the low-dimensional
606 output of the autoencoder's latent/bottleneck layer and return a tf.Tensor that can be logged
607 by `tf.summary.image()`. See the notebook 'writing_custom_images_to_tensorboard.ipynb' in
608 tutorials/notebooks_customization for more info. If None is provided no additional functions will be
609 used to plot to tensorboard. Defaults to None.
610 when (str, optional): When to log the images can be either 'batch', then the images will be logged after
611 every step during training, or 'epoch', then only after every image_step epoch the images will be
612 written. Defaults to 'epoch'.
614 """
615 if not self.p.tensorboard:
616 print(
617 "Nothing is written to Tensorboard for this Model. Please change parameters.tensorboard to True."
618 )
619 return
620 if image_step is None:
621 image_step = self.p.summary_step
623 self._log_images = True
625 # make a dataset for images
626 if data is None:
627 if isinstance(self.train_data, np.ndarray):
628 data = self.train_data
629 elif isinstance(self.train_data, list) or self.sparse:
630 data = self.train_data
631 else:
632 data = list(self.dataset.take(int(10000 / self.p.batch_size)))
633 data = np.stack(data)[:, 0, :].reshape(-1, self.train_data.shape[1])
634 else:
635 if type(data) != type(self.train_data):
636 raise Exception(
637 f"Provided data has wrong type. Train data in this class is {type(self.train_data)}, provided data is {type(data)}"
638 )
640 self.callbacks.append(
641 ImageCallback(
642 data,
643 image_step,
644 scatter_kws=scatter_kws,
645 hist_kws=hist_kws,
646 additional_fns=additional_fns,
647 when=when,
648 )
649 )
650 if isinstance(data, (np.ndarray, tf.sparse.SparseTensor)):
651 print(
652 f"Logging images with {data.shape}-shaped data every {image_step} epochs to Tensorboard at {self.p.main_path}"
653 )
654 else:
655 print(
656 f"Logging images with {[i.shape for i in data]}-shaped data every {image_step} epochs to Tensorboard at {self.p.main_path}"
657 )
659 def plot_network(self):
660 """Tries to plot the network using pydot, pydotplus and graphviz. Doesn't raise an exception if plotting is
661 not possible.
663 Note:
664 Refer to this guide to install these programs:
665 https://stackoverflow.com/questions/47605558/importerror-failed-to-import-pydot-you-must-install-pydot-and-graphviz-for-py
667 """
668 try:
669 plot_model(self.model, self.train_data.shape[1])
670 except:
671 pass
673 def encode(self, data=None):
674 """Calls encoder part of model.
676 Args:
677 data (Union[np.ndarray, None], optional): The data to be passed top the encoder part.
678 Can be either numpy ndarray or None. If None is provided a set of 10000 points from the provided
679 train data will be taken. Defaults to None.
681 Returns:
682 np.ndarray: The output from the bottlenack/latent layer.
684 """
685 if data is None:
686 data = self.train_data
687 if hasattr(self.model, "encoder"): 687 ↛ 689line 687 didn't jump to line 689, because the condition on line 687 was never false
688 out = self.model.encoder(data)
689 elif hasattr(self.model, "encoder_model"):
690 out = self.model.encoder_model(data)
691 if isinstance(out, list): 691 ↛ 692line 691 didn't jump to line 692, because the condition on line 691 was never true
692 out = [o.numpy() for o in out]
693 else:
694 out = out.numpy()
695 return out
697 def generate(self, data):
698 """Duplication of decode.
700 In Autoencoder and EncoderMap this method is equivalent to `decode()`. In AngleDihedralCartesianAutoencoder
701 this method will be overwritten to produce output molecular conformations.
703 Args:
704 data (np.ndarray): The data to be passed to the decoder part of the model. Make sure that the
705 shape of the data matches the number of neurons in the latent space.
707 Returns:
708 np.ndarray: Oue output from the decoder part.
710 """
711 return self.model.decoder(data)
713 def decode(self, data):
714 """Calls the decoder part of the model.
716 AngleDihedralCartesianAutoencoder will, like the other two classes' output a tuple of data.
718 Args:
719 data (np.ndarray): The data to be passed to the decoder part of the model. Make sure that the
720 shape of the data matches the number of neurons in the latent space.
722 Returns:
723 np.ndarray: Oue output from the decoder part.
724 """
725 out = self.decoder(data)
726 if isinstance(out, list):
727 out = [o.numpy() for o in out]
728 else:
729 out = out.numpy()
730 return out
732 def save(self, step=None):
733 """Saves the model to the current path defined in `parameters.main_path`.
735 Args:
736 step (Union[int, None], optional): Does not actually save the model at the given training step, but rather
737 changes the string used for saving the model from an datetime format to another.
739 """
740 if not self.read_only:
741 save_model(self.model, self.p.main_path, self.__class__.__name__, step=step)
743 def close(self):
744 """Clears the current keras backend and frees up resources."""
745 # clear old sessions
746 tf.keras.backend.clear_session()
749class EncoderMap(Autoencoder):
750 """Complete copy of Autoencoder class but uses additional distance cost
751 scaled by the SketchMap sigmoid params"""
753 @classmethod
754 def from_checkpoint(
755 cls,
756 checkpoint_path,
757 read_only=True,
758 overwrite_tensorboard_bool=False,
759 sparse=False,
760 ):
761 """Reconstructs the model from a checkpoint."""
762 # Is this classmethod necessary? We need to make sure the class knows all losses.
763 # And I don't know if the parent class calls the correct loss.setter
764 return load_model(
765 cls, checkpoint_path, read_only, overwrite_tensorboard_bool, sparse=sparse
766 )
768 @Autoencoder.loss.setter
769 def loss(self, loss):
770 if loss == "reconstruction_loss": 770 ↛ 771line 770 didn't jump to line 771, because the condition on line 770 was never true
771 self._loss = reconstruction_loss(self.model)
772 elif loss == "emap_cost": 772 ↛ 790line 772 didn't jump to line 790, because the condition on line 772 was never false
773 self.auto_loss = auto_loss(self.model, self.p, self.tensorboard_write_bool)
774 self.regularization_loss = regularization_loss(
775 self.model, self.p, self.tensorboard_write_bool
776 )
777 self.center_loss = center_loss(
778 self.model, self.p, self.tensorboard_write_bool
779 )
780 # this cost is new
781 self.distance_loss = distance_loss(
782 self.model, self.p, self.tensorboard_write_bool
783 )
784 self._loss = [
785 self.auto_loss,
786 self.regularization_loss,
787 self.center_loss,
788 self.distance_loss,
789 ]
790 elif loss == "mse":
791 self._loss = "mse"
792 else:
793 raise ValueError(
794 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'."
795 )
798class DihedralEncoderMap(EncoderMap):
799 """Similar to the `EncoderMap` class, but overwrites the `generate` method.
801 Using this class, instead of tbe `EncoderMap` class, the `generate` method,
802 needs an additional argument: `top`, which should be a topology file. This
803 topology will be used as a base on which the dihedrals of the `decode`
804 method are applied.
806 """
808 def generate(self, data: np.ndarray, top: str) -> MDAnalysis.Universe:
809 """Overwrites `EncoderMap`'s generate method and actually does backmapping if a list of dihedrals is
810 provided.
812 Args:
813 data (np.ndarray): The low-dimensional/latent/bottleneck data. A ndim==2 numpy array with xy coordinates
814 of points in latent space.
815 top (str): Topology file for this run of EncoderMap (can be .pdb, .gro, .. etc.).
817 Returns:
818 MDAnalysis.Universe: The topology with the provided backbone torsions.
820 Examples:
821 >>> # get some time-resolved pdb files
822 >>> import requests
823 >>> import numpy as np
824 >>> pdb_link = 'https://files.rcsb.org/view/1YUF.pdb'
825 >>> contents = requests.get(pdb_link).text
826 >>> print(contents.splitlines()[0]) # doctest: +SKIP
827 HEADER GROWTH FACTOR 01-APR-96 1YUF
828 >>> # fake a file with stringio
829 >>> from io import StringIO
830 >>> import MDAnalysis as mda
831 >>> import numpy as np
832 >>> file = StringIO(contents)
833 >>> # pass it to MDAnalysis
834 >>> u = mda.Universe(file, format='PDB')
835 >>> print(u)
836 <Universe with 720 atoms>
837 >>> # select the atomgroups
838 >>> ags = [*[res.psi_selection() for res in u.residues],
839 ... *[res.omega_selection() for res in u.residues],
840 ... *[res.phi_selection() for res in u.residues]
841 ... ]
842 >>> # filter Nones
843 >>> ags = list(filter(lambda x: False if x is None else True, ags))
844 >>> print(ags[0][0]) # doctest: +SKIP
845 <Atom 3: C of type C of resname VAL, resid 1 and segid A and altLoc >
846 >>> # Run dihedral Angles
847 >>> from MDAnalysis.analysis.dihedrals import Dihedral
848 >>> R = np.deg2rad(Dihedral(ags).run().results.angles)
849 >>> print(R.shape)
850 (16, 147)
851 >>> # import EncoderMap and define parameters
852 >>> from encodermap.autoencoder import DihedralEncoderMap
853 >>> import encodermap as em
854 >>> parameters = em.Parameters(
855 ... dist_sig_parameters = (4.5, 12, 6, 1, 2, 6),
856 ... periodicity = 2*np.pi,
857 ... l2_reg_constant = 10.0,
858 ... summary_step = 5,
859 ... tensorboard = False,
860 ... )
861 >>> e_map = DihedralEncoderMap(parameters, R, read_only=True)
862 >>> print(e_map.__class__.__name__)
863 DihedralEncoderMap
864 >>> # get some low-dimensional data
865 >>> lowd = np.random.random((100, 2))
866 >>> # use the generate method to get a new MDAnalysis universe
867 >>> # but first remove the time resolution
868 >>> file = StringIO(contents.split('MODEL 2')[0])
869 >>> new = e_map.generate(lowd, file)
870 >>> print(new.trajectory.coordinate_array.shape)
871 (100, 720, 3)
872 >>> # check whether frame 0 of u and new_u are different
873 >>> for ts in u.trajectory:
874 ... a1 = ts.positions
875 ... break
876 >>> print(np.array_equal(a1, new.trajectory.coordinate_array[0]))
877 False
879 """
880 assert np.any(data)
881 dihedrals = self.decode(data)
882 assert np.any(dihedrals)
883 uni = dihedral_backmapping(top, dihedrals)
884 return uni
887class AngleDihedralCartesianEncoderMap(Autoencoder):
888 """Different `__init__` method, than Autoencoder Class. Uses callbacks to tune-in cartesian cost.
890 Overwritten methods: `_set_up_callbacks` and `generate`.
892 Examples:
893 >>> import encodermap as em
894 >>> # Load two trajectories
895 >>> xtcs = ["tests/data/1am7_corrected_part1.xtc", "tests/data/1am7_corrected_part2.xtc"]
896 >>> tops = ["tests/data/1am7_protein.pdb", "tests/data/1am7_protein.pdb"]
897 >>> trajs = em.load(xtcs, tops)
898 >>> print(trajs)
899 encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajs. Not containing any CVs.
900 >>> # load CVs
901 >>> # This step can be omitted. The AngleDihedralCartesianEncoderMap class automatically loads CVs
902 >>> trajs.load_CVs('all')
903 >>> print(trajs.CVs['central_cartesians'].shape)
904 (51, 474, 3)
905 >>> print(trajs.CVs['central_dihedrals'].shape)
906 (51, 471)
907 >>> # create some parameters
908 >>> p = em.ADCParameters(periodicity=360, use_backbone_angles=True, use_sidechains=True,
909 ... cartesian_cost_scale_soft_start=(6, 12))
910 >>> # Standard is functional model, as it offers more flexibility
911 >>> print(p.model_api)
912 functional
913 >>> print(p.distance_cost_scale)
914 None
915 >>> # Instantiate the class
916 >>> e_map = em.AngleDihedralCartesianEncoderMap(trajs, p, read_only=True)
917 >>> # dataset contains these inputs:
918 >>> # central_angles, central_dihedrals, central_cartesians, central_distances, sidechain_dihedrals
919 >>> print(e_map.dataset)
920 <BatchDataset element_spec=(TensorSpec(shape=(None, 472), dtype=tf.float32, name=None), TensorSpec(shape=(None, 471), dtype=tf.float32, name=None), TensorSpec(shape=(None, 474, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 473), dtype=tf.float32, name=None), TensorSpec(shape=(None, 316), dtype=tf.float32, name=None))>
921 >>> # output from the model contains the following data:
922 >>> # out_angles, out_dihedrals, back_cartesians, pairwise_distances of inp cartesians, pairwise of back-mapped cartesians, out_side_dihedrals
923 >>> for data in e_map.dataset.take(1):
924 ... pass
925 >>> out = e_map.model(data)
926 >>> print([i.shape for i in out])
927 [TensorShape([256, 472]), TensorShape([256, 471]), TensorShape([256, 474, 3]), TensorShape([256, 112101]), TensorShape([256, 112101]), TensorShape([256, 316])]
928 >>> # get output of latent space by providing central_angles, central_dihedrals, sidehcain_dihedrals
929 >>> latent = e_map.encoder([data[0], data[1], data[-1]])
930 >>> print(latent.shape)
931 (256, 2)
932 >>> # Rebuild central_angles, central_dihedrals and sidechain_angles from latent
933 >>> dih, ang, side_dih = e_map.decode(latent)
934 >>> print(dih.shape, ang.shape, side_dih.shape)
935 (256, 472) (256, 471) (256, 316)
937 """
939 def __init__(
940 self,
941 trajs: encodermap.TrajEnsemble,
942 parameters: Optional[encodermap.ADCParameters] = None,
943 model: Optional[tensorflow.keras.Model] = None,
944 read_only: bool = False,
945 cartesian_loss_step: int = 0,
946 top: Optional[mdtraj.Topology] = None,
947 ) -> None:
948 """Instantiate the `AngleDihedralCartesianEncoderMap` class.
950 Args:
951 trajs (em.TrajEnsemble): The trajectories to be used as input. If trajs contain no CVs, correct CVs will be loaded.
952 parameters (Optional[em.ACDParameters]): The parameters for the current run. Can be set to None and the
953 default parameters will be used. Defaults to None.
954 model (Optional[tf.keras.models.Model]): The keras model to use. You can provide your own model
955 with this argument. If set to None, the model will be built to the specifications of parameters using
956 either the functional or sequential API. Defaults to None
957 read_only (bool): Whether to write anything to disk (False) or not (True). Defaults to False.
958 cartesian_loss_step (int, optional): For loading and re-training the model. The cartesian_distance_loss
959 is tuned in step-wise. For this the start step of the training needs to be accounted for. If the
960 scale of the cartesian loss should increase from epoch 6 to epoch 12 and the model is saved at
961 epoch 9, this argument should also be set to 9, to continue training with the correct scaling
962 factor. Defaults to 0.
964 """
965 # parameters
966 if parameters is None: 966 ↛ 967line 966 didn't jump to line 967, because the condition on line 966 was never true
967 self.p = ADCParameters()
968 else:
969 self.p = parameters
971 # seed
972 if self.p.seed is not None: 972 ↛ 973line 972 didn't jump to line 973, because the condition on line 972 was never true
973 tf.random.set_seed(self.p.seed)
975 # read_only
976 self.read_only = read_only
978 # will be saved and overwritten when loading.
979 self.cartesian_loss_step = cartesian_loss_step
981 # save params and create dir
982 if not self.read_only:
983 self.p.save()
984 print(
985 "Output files are saved to {}".format(self.p.main_path),
986 "as defined in 'main_path' in the parameters.",
987 )
989 # check whether Tensorboard and Read-Only makes Sense
990 if self.read_only and self.p.tensorboard: 990 ↛ 991line 990 didn't jump to line 991, because the condition on line 990 was never true
991 raise BadError(
992 "Setting tensorboard and read_only True is not possible. Tensorboard will always write to disk."
993 " If you received this Error while loading a trained model, pass read_only=False as an argument"
994 f" or set overwrite_tensorboard_bool True to overwrite the tensorboard parameter."
995 )
997 # clear old sessions
998 tf.keras.backend.clear_session()
1000 # get the CVs:
1001 if isinstance(trajs, str): 1001 ↛ 1002line 1001 didn't jump to line 1002, because the condition on line 1001 was never true
1002 self.trajs = TrajEnsemble([trajs], [top])
1003 else:
1004 self.trajs = trajs
1006 # load missing values
1007 should_be = set(
1008 [
1009 "central_angles",
1010 "central_cartesians",
1011 "central_dihedrals",
1012 "central_distances",
1013 "side_dihedrals",
1014 ]
1015 )
1017 if self.trajs.CVs_in_file:
1018 raise NotImplementedError(
1019 "Write a tf.data.Dataset.from_generator function in enocdermap.data using the data from the netCDF files"
1020 )
1021 elif self.trajs.CVs: 1021 ↛ 1027line 1021 didn't jump to line 1027, because the condition on line 1021 was never false
1022 missing = list(should_be - set(trajs.CVs.keys()))
1023 if missing != []: 1023 ↛ 1024line 1023 didn't jump to line 1024, because the condition on line 1023 was never true
1024 print("loading missing values: ", missing)
1025 self.trajs.load_CVs(missing, ensemble=False)
1026 else:
1027 self.trajs.load_CVs(list(should_be), ensemble=False)
1029 if not should_be - set(self.trajs.CVs.keys()) == set(): 1029 ↛ 1030line 1029 didn't jump to line 1030, because the condition on line 1029 was never true
1030 raise BadError(
1031 f"Could not load CVs. Should be {should_be}, but currenlty only {set(trajs.CVs.keys())} are loaded"
1032 )
1034 # define inputs
1035 self.sparse, self.train_data, self.inp_CV_data = self.get_train_data_from_trajs(
1036 self.trajs, self.p
1037 )
1039 # create dataset
1040 dataset = tf.data.Dataset.from_tensor_slices(
1041 (
1042 self.inp_CV_data["central_angles"],
1043 self.inp_CV_data["central_dihedrals"],
1044 self.inp_CV_data["central_cartesians"],
1045 self.inp_CV_data["central_distances"],
1046 self.inp_CV_data["side_dihedrals"],
1047 )
1048 )
1049 dataset = dataset.shuffle(
1050 buffer_size=self.inp_CV_data["central_cartesians"].shape[0],
1051 reshuffle_each_iteration=True,
1052 )
1053 dataset = dataset.repeat()
1054 self.dataset = dataset.batch(self.p.batch_size)
1056 # ToDo: Make training faster with Autotune, XLA (jit) compilation, DataRecords
1057 # self.dataset = self.dataset.prefetch(self.p.batch_size * 4)
1058 # self.dataset = self.dataset.interleave(num_parallel_calls=tf.data.experimental.AUTOTUNE)
1059 # num_batches = 20
1060 # num_threads = 8
1061 # ds = tf.data.Dataset.from_tensors(self.train_data).repeat(self.p.batch_size * num_batches)
1062 # ds = ds.batch(self.p.batch_size)
1063 # self.dataset = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads)
1065 # create model based on user input
1066 if model is None: 1066 ↛ 1069line 1066 didn't jump to line 1069, because the condition on line 1066 was never false
1067 self.model = self.p.model_api
1068 else:
1069 self._model = model
1071 # setup callbacks
1072 self._setup_callbacks()
1074 # create loss based on user input
1075 self.loss = self.p.loss
1077 # choose optimizer
1078 self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.p.learning_rate)
1080 # compile model
1081 self.model.compile(
1082 optimizer=self.optimizer, loss=self.loss, metrics=self.metrics
1083 )
1085 # do this if tensorboard is true.
1086 if self.p.tensorboard: 1086 ↛ 1088line 1086 didn't jump to line 1088, because the condition on line 1086 was never true
1087 # print shapes
1088 print("input shapes are:")
1089 print({k: v.shape for k, v in self.inp_CV_data.items()})
1090 # set _log_images to False to fix the backend after training
1091 self._log_images = False
1092 # get the output from model summary.
1093 with Capturing() as output:
1094 self.model.summary()
1095 with open(self.p.main_path + "/model_summary.txt", "w") as f:
1096 f.write("\n".join(output))
1097 try:
1098 tf.keras.utils.plot_model(
1099 self.model,
1100 to_file=self.p.main_path + "/model_summary.png",
1101 show_shapes=True,
1102 rankdir="TB",
1103 expand_nested=True,
1104 )
1105 except Exception as e:
1106 print(f"saving image gave error: {e}")
1107 # todo: add image of cat
1108 # from ..parameters import parameters as _p
1109 # cat_image = os.path.split(os.path.split(os.path.split(_p.__file__)[0])[0])[0] + '/pic/priscilla-du-preez-8NXmaXg5xL0-unsplash.jpg'
1110 # image = plt.imread(cat_image)
1111 # plt.imshow(image)
1112 # print(cat_image)
1113 print(
1114 f"Saved a text-summary of the model and an image in {self.p.main_path},",
1115 "as specified in 'main_path' in the parameters.",
1116 )
1118 # sets up the tb callback to plot the model
1119 self.tb_callback = tf.keras.callbacks.TensorBoard(
1120 self.p.main_path, write_graph=True
1121 )
1122 self.tb_callback.set_model(self.model)
1124 @staticmethod
1125 def get_train_data_from_trajs(trajs, p, attr="CVs"):
1126 if not any([np.isnan(x).any() for x in getattr(trajs, attr).values()]): 1126 ↛ 1132line 1126 didn't jump to line 1132, because the condition on line 1126 was never false
1127 inp_CV_data = {
1128 key: val.astype("float32") for key, val in getattr(trajs, attr).items()
1129 }
1130 sparse = False
1131 else:
1132 sparse = True
1133 print("Input contains nans. Using sparse network.")
1134 inp_CV_data = {
1135 key: val.astype("float32") for key, val in getattr(trajs, attr).items()
1136 }
1138 # squeeze, if xarray is provided
1139 if all([hasattr(v, "values") for v in inp_CV_data.values()]):
1140 inp_CV_data = {k: v.values.squeeze() for k, v in inp_CV_data.items()}
1142 for k, v in inp_CV_data.items():
1143 if np.any(np.isnan(v)):
1144 values = v
1145 if k == "central_cartesians":
1146 values = values.reshape(len(values), -1)
1147 indices = np.stack(np.where(~np.isnan(values))).T.astype("int64")
1148 dense_shape = values.shape
1149 values = values[~np.isnan(values)].flatten()
1150 sparse_tensor = tf.sparse.SparseTensor(indices, values, dense_shape)
1151 inp_CV_data[k] = sparse_tensor
1153 if not p.use_backbone_angles and not p.use_sidechains: 1153 ↛ 1154line 1153 didn't jump to line 1154, because the condition on line 1153 was never true
1154 train_data = inp_CV_data["central_dihedrals"]
1155 elif p.use_backbone_angles and not p.use_sidechains: 1155 ↛ 1156line 1155 didn't jump to line 1156
1156 train_data = [
1157 inp_CV_data["central_angles"],
1158 inp_CV_data["central_dihedrals"],
1159 ]
1160 if p.model_api == "sequential" and not sparse:
1161 train_data = np.hstack(train_data)
1162 elif p.use_backbone_angles and p.use_sidechains: 1162 ↛ 1171line 1162 didn't jump to line 1171, because the condition on line 1162 was never false
1163 train_data = [
1164 inp_CV_data["central_angles"],
1165 inp_CV_data["central_dihedrals"],
1166 inp_CV_data["side_dihedrals"],
1167 ]
1168 if p.model_api == "sequential" and not sparse: 1168 ↛ 1169line 1168 didn't jump to line 1169, because the condition on line 1168 was never true
1169 train_data = np.hstack(train_data)
1170 else:
1171 raise Exception(
1172 "Cannot train model with central dihedrals and side dihedrals only. Backbone angles are required."
1173 )
1175 # some checks for the length of the train data
1176 if p.model_api == "functional": 1176 ↛ 1184line 1176 didn't jump to line 1184, because the condition on line 1176 was never false
1177 if not p.use_backbone_angles and not p.use_sidechains: 1177 ↛ 1178line 1177 didn't jump to line 1178, because the condition on line 1177 was never true
1178 assert isinstance(train_data, tf.sparse.SparseTensor)
1179 elif p.use_backbone_angles and not p.use_sidechains: 1179 ↛ 1180line 1179 didn't jump to line 1180, because the condition on line 1179 was never true
1180 assert len(train_data) == 2
1181 else:
1182 assert len(train_data) == 3
1184 return sparse, train_data, inp_CV_data
1186 @classmethod
1187 def from_checkpoint(
1188 cls, trajs, checkpoint_path, read_only=True, overwrite_tensorboard_bool=False
1189 ):
1190 """Reconstructs the model from a checkpoint."""
1191 # Is this classmethod necessary? We need to make sure the class knows all losses.
1192 # And I don't know if the parent class calls the correct loss.setter
1193 return load_model(
1194 cls, checkpoint_path, read_only, overwrite_tensorboard_bool, trajs
1195 )
1197 def _setup_callbacks(self) -> None:
1198 """Overwrites the parent class' `_setup_callbacks` method.
1200 Due to the 'soft start' of the cartesian cost, the `cartesiand_increase_callback`
1201 needs to be added to the list of callbacks.
1203 """
1204 super(self.__class__, self)._setup_callbacks()
1205 if self.p.cartesian_cost_scale_soft_start != (None, None):
1206 self.cartesian_increase_callback = IncreaseCartesianCost(
1207 self.p, start_step=self.cartesian_loss_step
1208 )
1209 self.callbacks.append(self.cartesian_increase_callback)
1211 def save(self, step: Optional[int] = None) -> None:
1212 """Saves the model to the current path defined in `parameters.main_path`.
1214 Args:
1215 step (Optional[int]): Does not actually save the model at the given training step, but rather
1216 changes the string used for saving the model from an datetime format to another.
1218 """
1219 if not self.read_only:
1220 save_model(
1221 self.model,
1222 self.p.main_path,
1223 self.__class__.__name__,
1224 step=step,
1225 current_step=self.cartesian_loss_step,
1226 )
1228 @Autoencoder.loss.setter
1229 def loss(self, loss):
1230 if loss == "reconstruction_loss": 1230 ↛ 1231line 1230 didn't jump to line 1231, because the condition on line 1230 was never true
1231 self._loss = reconstruction_loss(self.model)
1232 elif loss == "emap_cost": 1232 ↛ 1277line 1232 didn't jump to line 1277, because the condition on line 1232 was never false
1233 self.dihedral_loss = dihedral_loss(
1234 self.model, self.p, self.tensorboard_write_bool
1235 )
1236 self.angle_loss = angle_loss(
1237 self.model, self.p, self.tensorboard_write_bool
1238 )
1239 if self.p.cartesian_cost_scale_soft_start != (None, None):
1240 self.cartesian_loss = cartesian_loss(
1241 self.model,
1242 self.cartesian_increase_callback,
1243 self.p,
1244 self.tensorboard_write_bool,
1245 )
1246 else:
1247 self.cartesian_loss = cartesian_loss(
1248 self.model, None, self.p, self.tensorboard_write_bool
1249 )
1251 self.distance_loss = distance_loss(
1252 self.model, self.p, self.tensorboard_write_bool
1253 )
1254 self.cartesian_distance_loss = cartesian_distance_loss(
1255 self.model, self.p, self.tensorboard_write_bool
1256 )
1257 self.center_loss = center_loss(
1258 self.model, self.p, self.tensorboard_write_bool
1259 )
1260 self.regularization_loss = regularization_loss(
1261 self.model, self.p, self.tensorboard_write_bool
1262 )
1263 self._loss = [
1264 self.dihedral_loss,
1265 self.angle_loss,
1266 self.cartesian_loss,
1267 self.distance_loss,
1268 self.cartesian_distance_loss,
1269 self.center_loss,
1270 self.regularization_loss,
1271 ]
1272 if self.p.use_sidechains: 1272 ↛ exitline 1272 didn't return from function 'loss', because the condition on line 1272 was never false
1273 self.side_dihedral_loss = side_dihedral_loss(
1274 self.model, self.p, self.tensorboard_write_bool
1275 )
1276 self._loss.append(self.side_dihedral_loss)
1277 elif loss == "mse":
1278 self._loss = "mse"
1279 else:
1280 raise ValueError(
1281 f"loss argument needs to be `reconstruction_loss`, `mse` or `emap_cost`. You provided '{loss}'."
1282 )
1284 def train(self) -> None:
1285 """Overrides the parent class' `train` method.
1287 After the training is finished, an additional file is written to disk,
1288 which saves the current epoch. In the event that training will continue,
1289 the current state of the soft-start cartesian cost is read from that file.
1291 """
1292 super(self.__class__, self).train()
1293 self.cartesian_loss_step += self.p.n_steps
1294 fname = f"{self.p.main_path}/saved_model_{self.p.n_steps}.model"
1295 with open(fname + "_current_step.txt", "w") as f:
1296 f.write(str(self.cartesian_loss_step))
1298 def encode(self, data=None):
1299 if hasattr(data, "_traj_file"): 1299 ↛ 1300line 1299 didn't jump to line 1300, because the condition on line 1299 was never true
1300 _, data, __ = self.get_train_data_from_trajs(data, self.p, attr="_CVs")
1301 elif hasattr(data, "traj_files"): 1301 ↛ 1302line 1301 didn't jump to line 1302, because the condition on line 1301 was never true
1302 _, data, __ = self.get_train_data_from_trajs(data, self.p)
1303 return super().encode(data)
1305 def generate(
1306 self,
1307 points: np.ndarray,
1308 top: Optional[str, int, mdtraj.Topology] = None,
1309 backend: Literal["mdtraj", "mdanalysis"] = "mdtraj",
1310 ) -> Union[MDAnalysis.Universe, mdtraj.Trajectory]:
1311 """Overrides the parent class' `generate` method and builds a trajectory.
1313 Instead of just providing data to `decode` using the decoder part of the
1314 network, this method also takes a molecular topology as its `top`
1315 argument. This topology is then used to rebuild a time-resolved
1316 trajectory.
1318 Args:
1319 points (np.ndarray): The low-dimensional points from which the
1320 trajectory should be rebuilt.
1321 top (Optional[str, int, mdtraj.Topology]): The topology to be used for rebuilding the
1322 trajectory. This should be a string pointing towards a <*.pdb,
1323 *.gro, *.h5> file. Alternatively, None can be provided, in which
1324 case, the internal topology (`self.top`) of this class is used.
1325 Defaults to None.
1326 backend (str): Defines what MD python package to use, to build the
1327 trajectory and also what type this method returns, needs to be
1328 one of the following:
1329 * "mdtraj"
1330 * "mdanalysis"
1332 Returns:
1333 Union[mdtraj.Trajectory, MDAnalysis.universe]: The trajectory after
1334 applying the decoded structural information. The type of this
1335 depends on the chosen `backend` parameter.
1337 """
1338 # get the output this can be done regardless
1339 out = self.decode(points)
1341 if top is None:
1342 top = self.trajs.top_files
1343 if len(top) > 1:
1344 print(
1345 f"Please specify which topology you would like to use for generating "
1346 f"conformations. You can either provide a `str` to a topology file "
1347 f"(file extension .pdb, .h5, .gro) on disk, or a `int` specifying the "
1348 f"`SingleTraj` object in this class' {self.trajs.n_trajs} trajs, or "
1349 f"you can also specify a `mdtraj.Topology` object."
1350 )
1351 return
1352 else:
1353 top = top[0]
1354 trajs = self.trajs
1355 if top not in self.trajs.top_files:
1356 raise Exception(
1357 "Provided topology was not used to train Encodermap."
1358 )
1360 # get the output
1361 if not self.p.use_backbone_angles and not self.p.use_sidechains:
1362 dihedrals = self.decode(points)
1363 elif self.p.use_backbone_angles and not self.p.use_sidechains:
1364 splits = [trajs.CVs["central_angles"].shape[1]]
1365 out = self.decode(points)
1366 if isinstance(out, np.ndarray):
1367 angles, dihedrals = np.split(out, splits, axis=1)
1368 elif self.p.use_backbone_angles and self.p.use_sidechains:
1369 splits = [
1370 trajs.CVs["central_angles"].shape[1],
1371 trajs.CVs["central_angles"].shape[1]
1372 + trajs.CVs["central_dihedrals"].shape[1],
1373 ]
1375 if isinstance(out, np.ndarray):
1376 angles, dihedrals, sidechain_dihedrals = np.array_split(
1377 out, splits, axis=1
1378 )
1379 else:
1380 angles, dihedrals, sidechain_dihedrals = out
1382 # in this case we can just use any traj from self.trajs
1383 traj = self.trajs
1385 else:
1386 if len(self.trajs.top_files) == 1:
1387 trajs = self.trajs
1388 if top not in self.trajs.top_files:
1389 raise Exception(
1390 "Provided topology was not used to train Encodermap."
1391 )
1392 else:
1393 if isinstance(top, str):
1394 pass
1395 elif isinstance(top, int):
1396 top_ = self.trajs[top].traj[0]
1397 top_.save_pdb("/tmp/tmp.pdb")
1398 top = "/tmp/tmp.pdb"
1399 elif isinstance(top, mdtraj.Topology):
1400 top.save_pdb("/tmp/tmp.pdb")
1401 top = "/tmp/tmp.pdb"
1402 else:
1403 raise TypeError(
1404 f"Provided type for `top` must be `str`, `int`, or `mdtraj.Topology`, "
1405 f"you provided {type(top)}."
1406 )
1408 # align the topology with the trajs in self.trajs
1409 from ..loading import features
1410 from ..loading.featurizer import UNDERSOCRE_MAPPING
1412 UNDERSOCRE_MAPPING = {v: k for k, v in UNDERSOCRE_MAPPING.items()}
1413 labels = {}
1414 feature_names = [
1415 "CentralCartesians",
1416 "CentralBondDistances",
1417 "CentralAngles",
1418 "CentralDihedrals",
1419 "SideChainDihedrals",
1420 ]
1422 for feature in feature_names:
1423 feature = getattr(features, feature)(top_.top, generic_labels=True)
1424 labels[UNDERSOCRE_MAPPING[feature.name]] = feature.describe()
1426 return_values = [
1427 "central_dihedrals",
1428 "central_angles",
1429 "side_dihedrals",
1430 ]
1431 splits = {}
1432 for i, k in enumerate(return_values):
1433 split = np.isin(
1434 self.trajs[0]._CVs.coords[k.upper()].values, labels[k]
1435 )
1436 splits[k] = split
1438 # split the output
1439 if not self.p.use_backbone_angles and not self.p.use_sidechains:
1440 dihedrals = out[:, splits["central_dihedrals"]]
1441 elif self.p.use_backbone_angles and not self.p.use_sidechains:
1442 dihedrals = out[1][:, splits["central_dihedrals"]]
1443 angles = out[2][:, splits["central_angles"]]
1444 elif self.p.use_backbone_angles and self.p.use_sidechains:
1445 dihedrals = out[1][:, splits["central_dihedrals"]]
1446 angles = out[0][:, splits["central_angles"]]
1447 sidechain_dihedrals = out[2][:, splits["side_dihedrals"]]
1449 # if the backend is mdanalysis we need to save the pdb
1450 if backend == "mdanalysis":
1451 top_.save_pdb("/tmp/tmp.pdb")
1452 top = "/tmp/tmp.pdb"
1453 else:
1454 # in this case we need to use a traj, which topolgy matches top
1455 for i, traj in self.trajs.itertrajs():
1456 if traj.top == top_.top:
1457 break
1458 else:
1459 raise Exception(
1460 "Could not find a trajectory in self.trajs, "
1461 "that matches the topology provided as `top`."
1462 )
1463 traj = deepcopy(traj)
1465 # do the backmapping
1466 if backend == "mdanalysis":
1467 uni = dihedral_backmapping(top, dihedrals, sidechains=sidechain_dihedrals)
1468 return uni
1469 elif backend == "mdtraj":
1470 traj = mdtraj_backmapping(top, dihedrals, sidechain_dihedrals, traj)
1471 return traj
1472 else:
1473 raise TypeError(
1474 f"backend must be 'mdtraj' or 'mdanalysis', but you provided {backend}"
1475 )