Coverage for encodermap/loading/delayed.py: 10%
243 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/loading/delayed.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Functions to use with the DaskFeaturizer class.
24"""
27################################################################################
28# Imports
29################################################################################
32# Future Imports at the top
33from __future__ import annotations
35# Standard Library Imports
36from pathlib import Path
38# Third Party Imports
39import numpy as np
40from MDAnalysis.coordinates.XTC import XTCReader
41from optional_imports import _optional_import
44################################################################################
45# Optional Imports
46################################################################################
49dask = _optional_import("dask")
50da = _optional_import("dask", "array")
51dd = _optional_import("dask", "dataframe")
52box_vectors_to_lengths_and_angles = _optional_import(
53 "mdtraj", "utils.unitcell.box_vectors_to_lengths_and_angles"
54)
55_dist_mic = _optional_import("mdtraj", "geometry._geometry._dist_mic")
56_dist = _optional_import("mdtraj", "geometry._geometry._dist")
57_dihedral_mic = _optional_import("mdtraj", "geometry._geometry._dihedral_mic")
58_dihedral = _optional_import("mdtraj", "geometry._geometry._dihedral")
59_angle_mic = _optional_import("mdtraj", "geometry._geometry._angle_mic")
60_angle = _optional_import("mdtraj", "geometry._geometry._angle")
61jit = _optional_import("numba", "jit")
62prange = _optional_import("numba", "prange")
63xr = _optional_import("xarray")
64md = _optional_import("mdtraj")
65h5py = _optional_import("h5py")
68################################################################################
69# Typing
70################################################################################
73# Standard Library Imports
74from typing import TYPE_CHECKING, Literal, Optional, Union, overload
77if TYPE_CHECKING:
78 # Third Party Imports
79 from dask.delayed import Delayed
81 # Encodermap imports
82 from encodermap.loading.featurizer import DaskFeaturizer
83 from encodermap.trajinfo.info_single import SingleTraj
86################################################################################
87# Utils
88################################################################################
def calc_bravais_box(box_info: np.ndarray) -> np.ndarray:
    """Calculates the Bravais vectors from lengths and angles (in degrees).

    Note:
        This code is adapted from gyroid, which is licensed under the BSD
        http://pythonhosted.org/gyroid/_modules/gyroid/unitcell.html

    Args:
        box_info (np.ndarray): The box info as a 2D array with one row per
            frame, where the columns are ordered as follows: a, b, c, alpha,
            beta, gamma. The angles are in degree.

    Returns:
        np.ndarray: The Bravais vectors as a C-contiguous float32 array of
            shape (n_frames, 3, 3). ``out[f, 0]``, ``out[f, 1]``, and
            ``out[f, 2]`` are the a, b, and c vectors of frame ``f``.

    Raises:
        TypeError: If the three intermediate vector arrays end up with
            different shapes.

    """
    a_length, b_length, c_length = box_info[:, :3].T
    alpha, beta, gamma = box_info[:, 3:].T

    alpha = alpha * np.pi / 180
    beta = beta * np.pi / 180
    gamma = gamma * np.pi / 180

    # Component-major layout: row i holds the i-th Cartesian component of the
    # respective vector for every frame.
    a = np.zeros((3, len(a_length)), dtype="float32")
    a[0] = a_length
    b = np.zeros((3, len(b_length)), dtype="float32")
    b[0] = b_length * np.cos(gamma)
    b[1] = b_length * np.sin(gamma)
    cx = c_length * np.cos(beta)
    cy = c_length * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma)
    cz = np.sqrt(c_length * c_length - cx * cx - cy * cy)
    c = np.empty((3, len(c_length)), dtype="float32")
    c[0] = cx
    c[1] = cy
    c[2] = cz

    if not a.shape == b.shape == c.shape:
        raise TypeError("Shape is messed up.")

    # Make sure that all vector components that are _almost_ 0 are set exactly
    # to 0. This function is plain NumPy (not numba-compiled), so one boolean
    # mask per vector replaces the former element-wise Python double loops.
    # The comparison is identical to the loops (NaN stays NaN).
    tol = 1e-6
    for vec in (a, b, c):
        vec[(vec > -tol) & (vec < tol)] = 0.0

    # Each (3, n_frames) array is transposed to (n_frames, 3), stacked to
    # (n_frames, component, vector) and swapped to (n_frames, vector, component).
    unitcell_vectors = np.ascontiguousarray(
        np.swapaxes(np.dstack((a.T, b.T, c.T)), 1, 2)
    )

    return unitcell_vectors
@dask.delayed(nout=4)
def _load_xyz(traj, frame_indices):
    """Reads a block of frames from an MDAnalysis reader as a delayed task.

    Distances in nm. Angles in degree.

    Args:
        traj: An MDAnalysis trajectory reader (``load_xyz`` passes an
            ``XTCReader``) that can be indexed with ``frame_indices``.
        frame_indices: The frames to read.

    Returns:
        tuple: Four float32 arrays, with ``n = len(frame_indices)``:
            positions (n, n_atoms, 3) in nm, time (n,), unitcell_vectors
            (n, 3, 3), and unitcell_info (n, 6) with the box lengths in nm
            and the box angles in degree.

    """
    n_requested = len(frame_indices)
    shape_xyz = (n_requested, traj.n_atoms, 3)
    xyz = np.empty(shape=shape_xyz, dtype="float32", order="C")
    times = np.empty(shape=(n_requested,), dtype="float32", order="C")
    box = np.empty(shape=(n_requested, 6), dtype="float32", order="C")

    for row, timestep in enumerate(traj[frame_indices]):
        xyz[row] = timestep.positions
        times[row] = timestep.time
        # NOTE(review): reads the private ``_unitcell`` attribute of the
        # timestep; presumably [a, b, c, alpha, beta, gamma] with lengths in
        # Angstrom, given the conversion below — confirm.
        box[row] = timestep._unitcell

    # MDAnalysis works in Angstrom; convert positions and box lengths to nm.
    # The box angles (columns 3-5) are left untouched.
    xyz /= 10
    box[:, :3] /= 10
    return xyz, times, calc_bravais_box(box), box
def load_xyz(
    traj_file: str,
    frame_indices: np.ndarray,
    traj_num: Optional[int] = None,
) -> tuple[da.array, da.array, da.array, da.array]:
    """Lazily loads xyz coordinates, time, and unitcell info of a frame block.

    ``.h5`` files are delegated to ``load_xyz_from_h5``. ``.xtc`` files are
    opened with MDAnalysis' ``XTCReader`` and read inside a single
    ``dask.delayed`` task.

    Args:
        traj_file (str): The trajectory file to load from.
        frame_indices (np.ndarray): An int array of the frames to load.
        traj_num (Optional[int]): Which traj num the output belongs to. Only
            forwarded for .h5 files. Defaults to None.

    Returns:
        tuple[da.array, da.array, da.array, da.array]: Dask arrays backed by
            delayed computations, in this order:
            * positions: shape (n_frames, n_atoms, 3), in nm.
            * time: shape (n_frames,).
            * unitcell_vectors: shape (n_frames, 3, 3).
            * unitcell_info: shape (n_frames, 6); [:, :3] are lengths in nm,
              [:, 3:] are angles in degree.

    Raises:
        ValueError: If ``traj_file`` is neither a .xtc nor a .h5 file.
            ValueError is a subclass of Exception, so existing
            ``except Exception`` handlers keep working.

    """
    suffix = Path(traj_file).suffix
    if suffix == ".h5":
        return load_xyz_from_h5(traj_file, frame_indices, traj_num)
    if suffix != ".xtc":
        raise ValueError(
            f"Currently only .xtc and .h5 trajectory files are supported, "
            f"got {suffix!r} ({traj_file}). But adding more formats is easy. "
            f"Raise an issue, if you want to have them added."
        )
    # The reader is handed to the delayed task and only read when that task
    # is computed, so it must not be closed here.
    traj = XTCReader(traj_file)
    n_frames = len(frame_indices)
    shapes = (
        (n_frames, traj.n_atoms, 3),
        (n_frames,),
        (n_frames, 3, 3),
        (n_frames, 6),
    )
    positions, time, unitcell_vectors, unitcell_info = _load_xyz(
        traj, frame_indices
    )
    return tuple(
        da.from_delayed(delayed, shape=shape, dtype="float32")
        for delayed, shape in zip(
            (positions, time, unitcell_vectors, unitcell_info), shapes
        )
    )
@dask.delayed(nout=4)
def _load_xyz_from_h5(
    traj_file: str,
    frame_indices: np.ndarray,
    traj_num: Optional[int],
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Distances in nm. Angles in degree.

    Reads one block of frames from an .h5 trajectory file. If ``traj_num`` is
    not None, the traj-num-suffixed key (e.g. ``"coordinates_3"``) is
    preferred for every quantity. If ``traj_num`` is None, or that key is
    absent, the plain key (e.g. ``"coordinates"``) is used.

    Args:
        traj_file (str): The file to load.
        frame_indices (np.ndarray): An int array giving the positions to load.
        traj_num (Optional[int]): Which traj num the output should be put to.
            None means only the plain keys are looked up.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: A four-tuple of
        np arrays. The order of these arrays is:
            * positions: Shape (len(frame_indices), n_atoms, 3): The xyz
              coordinates in nm. None if neither key exists in the file.
            * time: shape (len(frame_indices), ): The time in ps. None if
              neither key exists in the file.
            * unitcell_vectors: Shape (len(frame_indices), 3, 3): The
              unitcell vectors.
            * unitcell_info: Shape (len(frame_indices), 6), where [:, :3] are
              the unitcell lengths in nm and [:, 3:] are the unitcell angles
              in degree. Columns whose dataset is missing are NaN.

    """
    keys = ["coordinates", "time", "cell_lengths", "cell_angles"]
    data = {}
    with h5py.File(traj_file, "r") as f:
        for key in keys:
            # lookup order: "<key>_<traj_num>" (only if traj_num is given),
            # then the plain "<key>"
            candidates = [key] if traj_num is None else [f"{key}_{traj_num}", key]
            found = next((c for c in candidates if c in f), None)
            data[key] = None if found is None else f[found][frame_indices]

    unitcell_info = np.empty(shape=(len(frame_indices), 6), dtype="float32", order="C")
    # A missing cell dataset is written as explicit NaN instead of assigning
    # None into the float32 array.
    unitcell_info[:, :3] = (
        np.nan if data["cell_lengths"] is None else data["cell_lengths"]
    )
    unitcell_info[:, 3:] = (
        np.nan if data["cell_angles"] is None else data["cell_angles"]
    )
    unitcell_vectors = calc_bravais_box(unitcell_info)
    return data["coordinates"], data["time"], unitcell_vectors, unitcell_info
def load_xyz_from_h5(
    traj_file: str,
    frame_indices: np.ndarray,
    traj_num: Optional[int] = None,
) -> tuple[da.array, da.array, da.array, da.array]:
    """Loads xyz coordinates and unitcell info from a block in a .h5 file.

    Standard MDTraj h5 keys are:
    ['cell_angles', 'cell_lengths', 'coordinates', 'time', 'topology']

    Only the topology is read eagerly, to learn the number of atoms. The
    frame data itself is read lazily by a single ``dask.delayed`` task.

    Args:
        traj_file (str): The file to load.
        frame_indices (np.ndarray): An int array giving the positions to load.
        traj_num (Optional[int]): Which traj num the output should be put to.
            If given and the file has no plain ``"topology"`` entry, the
            topology is read from ``f"topology_{traj_num}"``. Defaults to None.

    Returns:
        tuple[da.array, da.array, da.array, da.array]: A four-tuple of dask
        arrays that contain dask delayeds. The order of these arrays is:
            * positions: Shape (len(frame_indices), n_atoms, 3): The xyz
              coordinates in nm.
            * time: shape (len(frame_indices), ): The time in ps.
            * unitcell_vectors: Shape (len(frame_indices), 3, 3): The
              unitcell vectors.
            * unitcell_info: Shape (len(frame_indices), 6), where [:, :3] are
              the unitcell lengths in nm and [:, 3:] are the unitcell angles
              in degree.

    """
    # Encodermap imports
    from encodermap.trajinfo.info_all import HDF5GroupWrite

    n_frames = len(frame_indices)
    with HDF5GroupWrite(traj_file) as f:
        if "topology" not in f.keys() and traj_num is not None:
            top = f.read_topology(f"topology_{traj_num}")
        else:
            top = f.read_topology("topology")
        n_atoms = top.n_atoms
    p, t, uv, ui = _load_xyz_from_h5(traj_file, frame_indices, traj_num)
    p = da.from_delayed(p, shape=(n_frames, n_atoms, 3), dtype="float32")
    t = da.from_delayed(t, shape=(n_frames,), dtype="float32")
    uv = da.from_delayed(uv, shape=(n_frames, 3, 3), dtype="float32")
    ui = da.from_delayed(ui, shape=(n_frames, 6), dtype="float32")
    return p, t, uv, ui
301################################################################################
302# Dask graph creation
303################################################################################
# ``return_delayeds=True``: the second return value is the dict of
# per-trajectory coordinate variables. Two stubs are needed. A required
# ``Literal[True]`` parameter cannot follow the defaulted ``traj`` and
# ``streamable`` positionally, so the second stub covers keyword-style calls
# such as ``build_dask_xarray(featurizer, return_delayeds=True)``.
@overload
def build_dask_xarray(
    featurizer: DaskFeaturizer,
    traj: Optional[SingleTraj],
    streamable: bool,
    return_delayeds: Literal[True],
) -> tuple[xr.Dataset, dict[str, xr.Variable]]: ...


@overload
def build_dask_xarray(
    featurizer: DaskFeaturizer,
    traj: Optional[SingleTraj] = None,
    streamable: bool = False,
    *,
    return_delayeds: Literal[True],
) -> tuple[xr.Dataset, dict[str, xr.Variable]]: ...
# ``return_delayeds=False`` (the default): the second return value is None.
# The defaults mirror the implementation, so ``build_dask_xarray(featurizer)``
# resolves to this overload for type checkers.
@overload
def build_dask_xarray(
    featurizer: DaskFeaturizer,
    traj: Optional[SingleTraj] = None,
    streamable: bool = False,
    return_delayeds: Literal[False] = False,
) -> tuple[xr.Dataset, None]: ...
324def build_dask_xarray(
325 featurizer: DaskFeaturizer,
326 traj: Optional[SingleTraj] = None,
327 streamable: bool = False,
328 return_delayeds: bool = False,
329) -> tuple[xr.Dataset, Union[None, dict[str, xr.Variable]]]:
330 """Builds a large dask xarray, which will be distributively evaluated.

 This function takes a `DaskFeaturizer`, which contains a list of features.
 Every feature in this list contains enough information for the delayed
 functions to calculate the requested quantities when provided the xyz
 coordinates of the atoms, the unitcell vectors, and the unitcell infos.

 Args:
     featurizer (DaskFeaturizer): An instance of the DaskFeaturizer.
     traj (Optional[SingleTraj]): A single trajectory to featurize. If None,
         the trajectories come from ``featurizer.feat``: via
         ``featurizer.feat.trajs.itertrajs()`` if it has a ``trajs``
         attribute, otherwise ``featurizer.feat.traj``. Defaults to None.
     streamable (bool): Whether to divide the calculations into one-frame
         blocks, which can then only be calculated when requested. If False,
         every trajectory is split into at most 10 frame blocks. Defaults
         to False.
     return_delayeds (bool): Whether to also collect, per trajectory, the
         coordinates, time, cell lengths, and cell angles as dask-backed
         ``xr.Variable`` objects. Defaults to False.

 Returns:
     tuple[xr.Dataset, Optional[dict[str, xr.Variable]]]: A tuple is returned
         in both cases. Its first element is the dataset concatenated along
         ``traj_num``. Its second element is None if ``return_delayeds`` is
         False. Otherwise it is a dict with the keys
         ``"coordinates_{traj_num}"``, ``"time_{traj_num}"``,
         ``"cell_lengths_{traj_num}"`` and ``"cell_angles_{traj_num}"``,
         whose ``units`` attrs are nanometers, picoseconds, nanometers and
         degrees.

 Raises:
     AssertionError: If the featurizer has no active features, if a feature
         fails the ``dask_indices``/``dask_transform``/``delayed`` checks, if
         a block has no frames or yields no data, if the feature indexes are
         inhomogeneous, or if the final dataset is empty.
     Exception: Re-raised from an xarray ``MergeError`` while building a
         trajectory's dataset. Errors from ``load_xyz`` (e.g. for
         unsupported trajectory file formats) propagate as well.

 """
353 # Imports
354 # Encodermap imports
355 from encodermap.loading.features import CustomFeature
356 from encodermap.misc.xarray import (
357 FEATURE_NAMES,
358 make_dataarray,
359 make_frame_CV_dataarray,
360 make_position_dataarray,
361 )
362 from encodermap.trajinfo.trajinfo_utils import trajs_combine_attrs
364 # definitions
365 coordinates = {
366 "coordinates": ["md_frame", "md_atom", "md_cart"],
367 "time": ["md_frame"],
368 "cell_lengths": ["md_frame", "md_length"],
369 "cell_angles": ["md_frame", "md_angle"],
370 }
372 # pre-define blocks from the trajectories
373 n_blocks = 10
 # Without `streamable`, every trajectory is split into at most `n_blocks`
 # frame blocks (empty blocks are dropped). Every block becomes one
 # `load_xyz` call, i.e. one delayed load task per block.
375 # append delayeds here:
376 if return_delayeds:
377 delayeds = {}
378 else:
379 delayeds = None
381 assert len(featurizer.feat.active_features) > 0
383 # collect the Datasets in this list
384 DSs = []
386 # if the dask featurizer contains an `EnsembleFeaturizer`, we can use `itertrajs()`
387 if traj is None:
388 if hasattr(featurizer.feat, "trajs"):
389 iterable = featurizer.feat.trajs.itertrajs()
390 else:
391 iterable = enumerate([featurizer.feat.traj])
392 else:
393 iterable = enumerate([traj])
395 # iter over trajs or just the one
 # NOTE(review): the loop variable `traj` shadows the `traj` argument. The
 # `i` of the list comprehensions below lives in the comprehension scope and
 # does not affect the outer `i`.
396 for i, traj in iterable:
397 n_frames = len(traj.id)
398 if not streamable:
399 n_frames_per_block = n_frames // n_blocks
 # n_blocks - 1 blocks of n_frames // n_blocks frames; the last block
 # takes the remainder. If n_frames < n_blocks, all blocks except the
 # last one are empty and are filtered out below.
400 blocks = [
401 np.arange(i * n_frames_per_block, (i + 1) * n_frames_per_block)
402 for i in range(n_blocks - 1)
403 ]
404 blocks.append(np.arange((n_blocks - 1) * n_frames_per_block, n_frames))
405 # remove empty blocks
406 blocks = list(filter(lambda x: x.size > 0, blocks))
407 else:
408 n_frames_per_block = 1
409 blocks = [[i] for i in range(n_frames)]
411 # collect multiple DataArrays here
412 DAs = {}
413 indexes = {}
415 if delayeds is not None:
416 xyz_traj = []
417 time_traj = []
418 lengths_traj = []
419 angles_traj = []
421 # distribute the loading to multiple workers
422 for j, block in enumerate(blocks):
423 # get the actual frame indices if the traj was sliced
 # NOTE(review): a 2D `traj.id` is indexed with column 1 as the frame
 # index; presumably column 0 is the traj_num — confirm against SingleTraj.
424 if traj.id.ndim == 2:
425 frame_indices = traj.id[block, 1]
426 else:
427 frame_indices = traj.id[block]
429 assert len(frame_indices) > 0, f"{frame_indices=}"
430 xyz, time, unitcell_vector, unitcell_info = load_xyz(
431 traj.traj_file, frame_indices, traj.traj_num
432 )
434 if delayeds is not None:
435 unitcell_lengths = unitcell_info[:, :3]
436 unitcell_angles = unitcell_info[:, 3:]
437 xyz_traj.append(xyz)
438 time_traj.append(time)
439 lengths_traj.append(unitcell_lengths)
440 angles_traj.append(unitcell_angles)
442 # iterate over the features and let them use the traj information
 # NOTE(review): `features` does not depend on `block`, so this lookup
 # could be hoisted out of the block loop.
443 if hasattr(featurizer.feat, "trajs"):
444 features = featurizer.feat.active_features[traj.top]
445 else:
446 features = featurizer.feat.active_features
448 for k, feat in enumerate(features):
449 # the name of the feature will be used for traceability
 # NOTE(review): isinstance() already covers subclasses, so the
 # issubclass() half of this condition is redundant. These checks are
 # `assert`s and are skipped under `python -O`. The message text
 # "this these" is a typo in the runtime string.
450 if not isinstance(feat, CustomFeature) or not issubclass(
451 feat.__class__, CustomFeature
452 ):
453 assert hasattr(feat, "dask_indices") and hasattr(
454 feat, "dask_transform"
455 ), (
456 f"For `feature.transform()` to be acceptable as delayed, "
457 f"the feature needs to implement the `dask_indices` property "
458 f"and `dask_transform` staticmethod. The feature {feat} has "
459 f"this these methods and attributes "
460 f"{[a for a in feat.__dir__() if not a.startswith('_')]}"
461 )
462 assert feat.delayed, (
463 f"The feature {feat} was not altered to return a delayed "
464 f"transform. Please read up in `encodermap.DaskFeaturizer` how "
465 f"to make features work with dask delayed."
466 )
468 # decide on the name
 # Name resolution: FEATURE_NAMES[feat.name] first. On KeyError or
 # AttributeError, fall back to feat.name or the class name; a
 # CustomFeature may use the first word of its description. This
 # fallback may also rewrite `feat.name`.
469 try:
470 name = FEATURE_NAMES[feat.name]
471 except (KeyError, AttributeError):
472 if hasattr(feat, "name"):
473 if isinstance(feat.name, str):
474 name = feat.name
475 if "mdtraj.trajectory" in feat.name.lower():
476 feat.name = feat.__class__.__name__
477 name = feat.__class__.__name__
478 else:
479 name = feat.__class__.__name__
480 feat.name = name
481 else:
482 name = feat.__class__.__name__
483 if name == "CustomFeature":
484 name = feat.describe()[0].split()[0]
485 feat.name = name
487 # the feature length is given by the describe() of the feature
488 if callable(feat.describe()):
489 feat_length = len([i for i in feat.describe()(traj.top)])
490 else:
491 feat_length = len(feat.describe())
493 # dynamically populate kwargs with feature settings
494 kwargs = {"indexes": getattr(feat, feat.dask_indices)}
495 if feat._use_periodic:
496 kwargs["periodic"] = feat.periodic
497 if feat._use_angle:
498 kwargs["deg"] = feat.deg
499 kwargs["cossin"] = feat.cossin
500 # if feat._use_omega:
501 # kwargs["omega"] = feat.omega
502 if hasattr(feat, "_nonstandard_transform_args"):
503 for k in feat._nonstandard_transform_args:
504 if not hasattr(feat, k):
505 kwargs[k] = None
506 else:
507 kwargs[k] = getattr(feat, k)
 # One delayed transform per feature and block. Its output is declared
 # as a (len(frame_indices), feat_length) float32 dask array.
508 a = da.from_delayed(
509 feat.dask_transform(
510 **kwargs,
511 xyz=xyz,
512 unitcell_vectors=unitcell_vector,
513 unitcell_info=unitcell_info,
514 ),
515 shape=(len(frame_indices), feat_length),
516 dtype="float32",
517 )
519 if hasattr(feat, "deg"):
520 deg = feat.deg
521 else:
522 deg = None
 # Cartesian/atom features are reshaped to (frames, atoms, 3). In every
 # branch a leading length-1 axis is added via expand_dims(axis=0).
524 if (
525 feat.name
526 in ["AllCartesians", "CentralCartesians", "SideChainCartesians"]
527 or feat.atom_feature
528 ):
529 a = da.reshape(a, (len(frame_indices), -1, 3))
530 a = da.expand_dims(a, axis=0)
531 if hasattr(featurizer, "indices_by_top"):
532 feat.indexes = featurizer.indices_by_top[traj.top][feat.name]
533 dataarray, ind_dataarray = make_position_dataarray(
534 feat.describe(),
535 traj[block],
536 name,
537 a,
538 deg=deg,
539 feat=feat,
540 )
541 else:
542 a = da.expand_dims(a, axis=0)
543 if feat._dim == 1:
544 dataarray, ind_dataarray = make_frame_CV_dataarray(
545 feat.describe(),
546 traj[block],
547 name,
548 a,
549 deg=deg,
550 feat=feat,
551 )
552 else:
553 if hasattr(featurizer, "indices_by_top"):
554 feat.indexes = featurizer.indices_by_top[traj.top][
555 feat.name
556 ]
557 dataarray, ind_dataarray = make_dataarray(
558 feat.describe(),
559 traj[block],
560 name,
561 a,
562 deg=deg,
563 feat=feat,
564 )
566 assert dataarray.size > 0, (
567 f"Dataarray created for feature {feat} provided with "
568 f"traj {traj} at frame indices {block} did not contain "
569 f"any data."
570 )
572 # append the DataArray to the DAs dictionary
573 DAs.setdefault(name, []).append(dataarray)
 # NOTE(review): a block returning None replaces earlier index arrays
 # with [None]. Inhomogeneous indexes are therefore only detected below
 # when the first entry is None.
574 if ind_dataarray is not None:
575 indexes.setdefault(name + "_feature_indices", []).append(
576 ind_dataarray
577 )
578 else:
579 indexes[name + "_feature_indices"] = [None]
581 # after every traj, we combine the datasets
 # NOTE(review): the assertion message below reports the last `feat` and
 # `frame_indices` from the loops above, not necessarily those of `key`.
582 for key, value in DAs.items():
583 DAs[key] = xr.concat(
584 DAs[key],
585 "frame_num",
586 combine_attrs=trajs_combine_attrs,
587 )
588 # we only need any component from the indexes but make sure that
589 # they are homogeneous. Every block of a traj should return
590 # the same index array, as they don't depend on frame data
591 if indexes[key + "_feature_indices"][0] is None:
592 assert all(
593 [i is None for i in indexes[key + "_feature_indices"][1:]]
594 ), (
595 f"Got an inhomogeneous result for indexes for feature {feat=} "
596 f"at {frame_indices=} {indexes=}"
597 )
598 indexes[key + "_feature_indices"] = indexes.pop(key + "_feature_indices")[0]
600 # combine data per traj
601 DAs_and_indexes = DAs | indexes
602 DAs_and_indexes = {k: v for k, v in DAs_and_indexes.items() if v is not None}
 # An xarray MergeError is re-raised as a generic Exception showing `indexes`.
603 try:
604 ds = xr.Dataset(
605 DAs_and_indexes,
606 attrs=trajs_combine_attrs(
607 [v.attrs if v.size > 0 else {} for v in DAs_and_indexes.values()]
608 ),
609 )
610 except xr.core.merge.MergeError as e:
611 raise Exception(f"{indexes=}") from e
612 DSs.append(ds)
614 # and add to the delayeds if needed
 # Keys and dims get a f"_{traj_num}" suffix; the unit strings are stored
 # utf-8 encoded in the "units" attr.
615 if delayeds is not None:
616 for (coord, dims), data, unit in zip(
617 coordinates.items(),
618 [xyz_traj, time_traj, lengths_traj, angles_traj],
619 ["nanometers", "picoseconds", "nanometers", "degrees"],
620 ):
621 name = f"{coord}_{traj.traj_num}"
622 delayeds[name] = xr.Variable(
623 dims=[f"{d}_{traj.traj_num}" for d in dims],
624 data=da.concatenate(data),
625 attrs={"units": unit.encode("utf-8")},
626 )
628 # make a large dataset out of this
 # Datasets are concatenated along `traj_num`; missing entries are filled
 # with NaN (`fill_value=np.nan`).
629 ds = xr.concat(
630 DSs,
631 data_vars="all",
632 # compat="broadcast_equals",
633 # coords="all",
634 # join="outer",
635 dim="traj_num",
636 fill_value=np.nan,
637 combine_attrs=trajs_combine_attrs,
638 )
 # NOTE(review): `assert` is skipped under `python -O`. If no trajectory was
 # iterated, xr.concat presumably rejects the empty `DSs` before this point,
 # which would keep `DAs_and_indexes` in the message bound — confirm.
639 assert ds, (
640 f"Concatenation of chunked datasets yielded empty dataset.\n"
641 f"{DSs=}\n\n{DAs_and_indexes=}"
642 )
644 return ds, delayeds