Coverage for encodermap/trajinfo/trajinfo_utils.py: 82%
227 statements
coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
# -*- coding: utf-8 -*-
# encodermap/trajinfo/trajinfo_utils.py
################################################################################
# Encodermap: A python library for dimensionality reduction.
#
# Copyright 2019-2022 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Util functions for the `TrajEnsemble` and `SingleTraj` classes.

"""


################################################################################
# Imports
################################################################################


from __future__ import annotations

import warnings
from pathlib import Path

import numpy as np

from .._optional_imports import _optional_import
from ..loading.featurizer import PyEMMAFeaturizer as Featurizer
from ..misc.misc import FEATURE_NAMES
from ..misc.xarray import construct_xarray_from_numpy


##############################################################################
# Optional Imports
##############################################################################


xr = _optional_import("xarray")
Feature = _optional_import("pyemma", "coordinates.data.featurization._base.Feature")
md = _optional_import("mdtraj")


################################################################################
# Typing
################################################################################


from typing import TYPE_CHECKING, Literal, Optional, Union

if TYPE_CHECKING:
    import mdtraj as md
    import xarray as xr
    from pyemma.coordinates.data.featurization._base import Feature

    from ..loading.featurizer import Featurizer
    from .info_single import SingleTraj, TrajEnsemble

    SingleTrajFeatureType = Union[
        str, Path, np.ndarray, Feature, xr.Dataset, xr.DataArray, Featurizer
    ]
    TrajEnsembleFeatureType = Union[
        list[str], list[Path], list[np.ndarray], xr.Dataset, Featurizer, Literal["all"]
    ]
    Index = Optional[
        Union[tuple[int, list, np.ndarray, slice], int, list, np.ndarray, slice]
    ]


################################################################################
# Globals
################################################################################


CAN_BE_FEATURE_NAME = list(FEATURE_NAMES.keys()) + list(FEATURE_NAMES.values())
__all__ = ["load_CVs_singletraj", "load_CVs_ensembletraj"]


################################################################################
# Utils
################################################################################


def np_to_xr(
    data: np.ndarray,
    traj: SingleTraj,
    attr_name: Optional[str] = None,
    labels: Optional[list[str]] = None,
    filename: Optional[Union[str, Path]] = None,
) -> xr.DataArray:
    if attr_name is None:
        if filename is None:
            msg = "Please also provide an `attr_name` under which to save the CV."
            raise Exception(msg)
        attr_name = Path(filename).stem

    if labels is not None:
        if isinstance(labels, str):
            labels = [
                f"{attr_name.upper()} {labels.upper()} {i}"
                for i in range(data.shape[1])
            ]
        elif (
            all([isinstance(l, str) for l in labels])
            and len(labels) == data.shape[1]
        ):
            pass
        else:
            raise Exception(
                f"'labels' is either not a list of str or does not have the "
                f"same number of entries as the feature dimension of the "
                f"data ({data.shape[1]}): {labels=}"
            )
    data = np.expand_dims(data, axis=0)
    if np.any(np.isnan(data)):
        # if some nans are found along the frame axis, remove them
        if data.ndim == 2:
            data = data[:, ~np.isnan(data).any(axis=0)]
        if data.ndim == 3:
            data = data[:, ~np.isnan(data).any(axis=2)[0]]
        if data.ndim == 4:
            data = data[:, ~np.isnan(data).any(axis=2)[0].any(axis=1)]
    da = construct_xarray_from_numpy(traj, data, attr_name, labels, check_n_frames=True)
    return da
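

# Illustrative usage sketch (not part of the public API; defined only as an
# example and never called): it assumes the caller already has a `SingleTraj`
# instance, and the attribute name "example_cv" is hypothetical.
def _example_np_to_xr(traj: SingleTraj) -> xr.DataArray:
    # one row per frame, three feature columns
    data = np.random.random((traj.n_frames, 3))
    # `np_to_xr` prepends a trajectory axis and wraps the array in a DataArray
    # that is aligned with the frames of `traj`.
    return np_to_xr(data, traj, attr_name="example_cv")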


def load_CV_from_string_or_path(
    file_or_feature: str,
    traj: SingleTraj,
    attr_name: Optional[str] = None,
    cols: Optional[Union[int, list[int]]] = None,
    labels: Optional[list[str]] = None,
) -> xr.Dataset:
    """Loads CV data from a string. That string can either identify a feature
    or point to a file.

    Args:
        file_or_feature (str): The file or feature to load. If 'all' is
            provided, all "standard" features are loaded. A feature name
            like 'sidechain_angle' can also be provided. If a file with
            the .txt or .npy extension is provided, the data in that file is used.
        traj (SingleTraj): The trajectory that is used to load the features.
        attr_name (Union[None, str], optional): The name under which the CV should be found in the class.
            Is needed if a raw numpy array is passed; otherwise the name will be generated from the filename
            (if data == str), the DataArray.name (if data == xarray.DataArray), or the feature name.
        cols (Union[list, None], optional): A list specifying the columns to use for the highD data.
            If your highD data contains (x,y,z,...)-errors or has an enumeration
            column at col=0, this can be used to remove this unwanted data.
        labels (Union[list, str, None], optional): If you want to label the data you provided, pass a list of str.
            If set to None, the features in this dimension will be labelled as
            [f"{attr_name.upper()} FEATURE {i}" for i in range(self.n_frames)]. If a str is provided, the features
            will be labelled as [f"{attr_name.upper()} {label.upper()} {i}" for i in range(self.n_frames)]. If a list of str
            is provided, it needs to have the same length as the traj has frames. Defaults to None.

    Returns:
        xr.Dataset: An xarray dataset.

    """
    if str(file_or_feature) == "all" or str(file_or_feature) in CAN_BE_FEATURE_NAME:
        feat = Featurizer(traj)
        if file_or_feature == "all":
            feat.add_list_of_feats("all")
        else:
            feat.add_list_of_feats([file_or_feature])
        out = feat.get_output()
        if traj.traj_num is not None:
            assert out.coords["traj_num"] == np.array([traj.traj_num]), print(
                traj.traj_num,
                out.coords["traj_num"].values,
                feat.trajs.trajs[0].traj_num,
            )
        return out
    elif (f := Path(file_or_feature)).exists():
        if f.suffix == ".txt":
            data = np.loadtxt(f, usecols=cols)
        elif f.suffix == ".npy":
            data = np.load(f)
            if cols is not None:
                data = data[:, cols]
        elif f.suffix in [".nc", ".h5"]:
            data = xr.open_dataset(f)
            if len(data.data_vars.keys()) != 1:
                if attr_name is not None:
                    raise Exception(
                        f"The dataset in {f} has "
                        f"{len(data.data_vars.keys())} dataarrays, "
                        f"but only one `attr_name`: '{attr_name}' "
                        f"was requested. The names of the dataarrays "
                        f"are: {data.data_vars.keys()}. I can't over"
                        f"ride them all with one `attr_name`. Set "
                        f"`attr_name` to None to load the data with "
                        f"their respective names."
                    )
                return data
            else:
                d = list(data.data_vars.values())[0]
                if attr_name is not None:
                    d.name = attr_name
                return d
        else:
            raise Exception(
                f"Currently only .txt, .npy, .nc, and .h5 files can "
                f"be loaded. Your file {f} does not have the "
                f"correct extension."
            )
    else:
        raise Exception(
            f"If features are loaded via a string, the string needs "
            f"to be 'all', a feature name ('central_dihedrals'), or "
            f"an existing file. Your string {file_or_feature!r} "
            f"is none of those."
        )

    return np_to_xr(data, traj, attr_name, labels, file_or_feature)
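

# Illustrative usage sketch (not part of the public API; never called): the three
# ways a string is resolved by `load_CV_from_string_or_path`. The file names
# "my_cv.npy" and "my_cvs.nc" are hypothetical and assumed to exist on disk.
def _example_load_CV_from_string_or_path(traj: SingleTraj) -> None:
    # a feature name from CAN_BE_FEATURE_NAME is computed with the Featurizer
    ds_feat = load_CV_from_string_or_path("central_dihedrals", traj)
    # a .npy file is loaded with np.load; the CV is named after the file stem
    da_npy = load_CV_from_string_or_path("my_cv.npy", traj, cols=[0, 1])
    # a .nc/.h5 file is opened as an xarray Dataset/DataArray
    ds_nc = load_CV_from_string_or_path("my_cvs.nc", traj)
    del ds_feat, da_npy, ds_nc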


def load_CVs_singletraj(
    data: SingleTrajFeatureType,
    traj: SingleTraj,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
    labels: Optional[list[str]] = None,
) -> xr.Dataset:
    if isinstance(attr_name, str):
        if not attr_name.isidentifier():
            raise Exception(
                f"Provided string for `attr_name` is not a valid Python "
                f"identifier. Choose another attribute name."
            )
    # load a string
    if isinstance(data, (str, Path)):
        CVs = load_CV_from_string_or_path(str(data), traj, attr_name, cols, labels)

    # load a list of strings from standard features
    elif isinstance(data, list) and all([isinstance(_, str) for _ in data]):
        feat = Featurizer(traj)
        feat.add_list_of_feats(data)
        return feat.get_output()

    # if the data is a numpy array
    elif isinstance(data, (list, np.ndarray)):
        CVs = np_to_xr(np.asarray(data), traj, attr_name, labels).to_dataset()

    # xarray objects are simply returned
    elif isinstance(data, xr.Dataset):
        return data

    elif isinstance(data, xr.DataArray):
        return data.to_dataset()

    # if this is a feature
    elif issubclass(data.__class__, Feature):
        feat = Featurizer(traj)
        feat.add_custom_feature(data)
        return feat.get_output()

    # if an instance of featurizer is provided
    elif isinstance(data, Featurizer):
        if isinstance(attr_name, str):
            if len(data) != 1:
                raise TypeError(
                    f"Provided Featurizer contains {len(data)} "
                    f"features and `attr_name` is of type `str`. "
                    f"Please provide a list of str."
                )
            attr_name = [attr_name]
        if isinstance(attr_name, list):
            if len(attr_name) != len(data):
                raise IndexError(
                    f"Provided Featurizer contains {len(data)} "
                    f"features and `attr_name` contains "
                    f"{len(attr_name)} elements. Please make sure "
                    f"they contain the same amount of items."
                )
        out = data.get_output()
        if attr_name is not None:
            if isinstance(attr_name, str):
                attr_name = [attr_name]
            _renaming = {}
            for f, v in zip(data.features, attr_name):
                _feature = False
                if hasattr(f, "name"):
                    if f.name in FEATURE_NAMES:
                        k = FEATURE_NAMES[f.name]
                        _feature = True
                if not _feature:
                    k = f.__class__.__name__
                _renaming[k] = v
            out = out.rename_vars(_renaming)
        return out
    else:
        raise TypeError(
            f"`data` must be str, np.ndarray, list, xr.DataArray, xr.Dataset, "
            f"em.Featurizer or em.features.Feature. You supplied "
            f"{type(data)}."
        )

    return CVs
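

# Illustrative usage sketch (not part of the public API; never called):
# `load_CVs_singletraj` dispatches the different input types, presumably on
# behalf of `SingleTraj.load_CV`. A raw numpy array needs an explicit
# `attr_name`; the name "my_cv" is hypothetical.
def _example_load_CVs_singletraj(traj: SingleTraj) -> xr.Dataset:
    values = np.zeros((traj.n_frames, 2))
    # returns an xr.Dataset with a single data variable called "my_cv"
    return load_CVs_singletraj(values, traj, attr_name="my_cv")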


def load_CVs_ensembletraj(
    trajs: TrajEnsemble,
    data: TrajEnsembleFeatureType,
    attr_name: Optional[list[str]] = None,
    cols: Optional[list[int]] = None,
    labels: Optional[list[str]] = None,
    directory: Optional[Union[Path, str]] = None,
    ensemble: bool = False,
) -> None:
    if isinstance(data, (str, Path)) and not ensemble:
        path_data = Path(data)
        npy_files = [
            (t._traj_file.parent if directory is None else Path(directory))
            / (t.basename + f"_{data}.npy")
            for t in trajs
        ]
        txt_files = [
            (t._traj_file.parent if directory is None else Path(directory))
            / (t.basename + f"_{data}.txt")
            for t in trajs
        ]
        raw_files = [
            (t._traj_file.parent if directory is None else Path(directory))
            / (t.basename + f"_{data}")
            for t in trajs
        ]
        if str(data) == "all":
            [t.load_CV("all") for t in trajs]
            return
        if path_data.is_dir():
            return load_CVs_from_dir(trajs, path_data, attr_name=attr_name, cols=cols)
        elif data in CAN_BE_FEATURE_NAME:
            [t.load_CV(data, attr_name, cols, labels) for t in trajs]
            return
        elif path_data.is_file() and (
            path_data.suffix == ".h5" or path_data.suffix == ".nc"
        ):
            ds = xr.open_dataset(path_data)
            if diff := set([t.traj_num for t in trajs]) - set(ds["traj_num"].values):
                raise Exception(
                    f"The dataset you try to load and the TrajEnsemble "
                    f"contain different trajectory numbers: {diff}."
                )
            for t, (traj_num, sub_ds) in zip(trajs, ds.groupby("traj_num")):
                assert t.traj_num == traj_num
                sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
                sub_ds = sub_ds.expand_dims("traj_num")
                assert sub_ds.coords["traj_num"] == np.array([t.traj_num])
                t.load_CV(sub_ds)
            return
        elif all([f.is_file() for f in npy_files]):
            [
                t.load_CV(f, attr_name=data, cols=cols, labels=labels)
                for t, f in zip(trajs, npy_files)
            ]
            return
        elif all([f.is_file() for f in txt_files]):
            [
                t.load_CV(f, attr_name=data, cols=cols, labels=labels)
                for t, f in zip(trajs, txt_files)
            ]
            return
        elif all([f.is_file() for f in raw_files]):
            [
                t.load_CV(f, attr_name=data, cols=cols, labels=labels)
                for t, f in zip(trajs, raw_files)
            ]
            return
        else:
            msg = (
                f"If `data` is provided as a single string, the string needs to "
                f"be either a feature ({CAN_BE_FEATURE_NAME}), a .h5/.nc file "
                f"({path_data}), or a list of npy/txt files ({npy_files}, "
                f"{txt_files}). The provided `data` fits none of "
                f"these possibilities."
            )
            raise ValueError(msg)

    elif isinstance(data, list) and not ensemble:
        if all([i in CAN_BE_FEATURE_NAME for i in data]):
            [t.load_CV(data, attr_name, cols, labels) for t in trajs]
            return
        elif all([isinstance(i, (list, np.ndarray)) for i in data]):
            [t.load_CV(d, attr_name, cols, labels) for t, d in zip(trajs, data)]
            return
        elif all([Path(f).is_file() for f in data]):
            suffix = set([Path(f).suffix for f in data])
            if len(suffix) != 1:
                raise Exception(
                    "Please provide a list with consistent file "
                    f"extensions and not a mish-mash, like: {suffix}"
                )
            suffix = suffix.pop()
            if suffix == ".npy":
                [
                    t.load_CV(np.load(d), attr_name, cols, labels)
                    for t, d in zip(trajs, data)
                ]
            else:
                [
                    t.load_CV(np.genfromtxt(d), attr_name, cols, labels)
                    for t, d in zip(trajs, data)
                ]
            return
        else:
            msg = (
                f"If `data` is provided as a list, the list needs to contain "
                f"strings that can be features ({CAN_BE_FEATURE_NAME}), or "
                f"some combination of lists and numpy arrays."
            )
            raise ValueError(msg)

    elif isinstance(data, np.ndarray):
        if len(data) != trajs.n_trajs and len(data) != trajs.n_frames:
            raise ValueError(
                f"The provided numpy array is misshaped. It needs "
                f"to be of shape (n_trajs={trajs.n_trajs}, "
                f"n_frames={np.unique([t.n_frames for t in trajs])[0]}, "
                f"X, (Y)), but is {data.shape}."
            )
        if len(data) == trajs.n_frames:
            data = [data[t.id[:, 1]] for t in trajs]
        [t.load_CV(d, attr_name, cols, labels) for t, d in zip(trajs, data)]
        for t in trajs:
            for v in t._CVs.values():
                assert v.shape[0] == 1, print(t.basename, v)
        return

    elif isinstance(data, Featurizer):
        ds = data.get_output()
        for t, (traj_num, sub_ds) in zip(trajs, ds.groupby("traj_num")):
            assert t.traj_num == traj_num
            sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
            sub_ds = sub_ds.expand_dims("traj_num")
            t.load_CV(sub_ds)
        return

    elif isinstance(data, xr.Dataset):
        for t, (traj_num, sub_ds) in zip(trajs, data.groupby("traj_num")):
            assert t.traj_num == traj_num
            sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
            sub_ds = sub_ds.expand_dims("traj_num")
            t.load_CV(sub_ds)
        return

    if ensemble:
        return load_CVs_ensemble(trajs, data)
    else:
        raise TypeError(
            f"`data` must be str, np.ndarray, list, xr.Dataset, or "
            f"em.Featurizer. You supplied {type(data)}."
        )
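

# Illustrative usage sketch (not part of the public API; never called): two common
# inputs to `load_CVs_ensembletraj`. The attribute names are hypothetical, and the
# array example assumes all trajectories share the same number of frames.
def _example_load_CVs_ensembletraj(trajs: TrajEnsemble) -> None:
    # string input: for every trajectory `t`, a file named `{t.basename}_my_cv.npy`
    # (or .txt) next to the trajectory file is expected and loaded as "my_cv"
    load_CVs_ensembletraj(trajs, "my_cv")
    # array input: shape (n_trajs, n_frames, n_features) is split along axis 0
    stacked = np.zeros((trajs.n_trajs, trajs.trajs[0].n_frames, 3))
    load_CVs_ensembletraj(trajs, stacked, attr_name="my_stacked_cv")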


def load_CVs_ensemble(
    trajs: TrajEnsemble,
    data: Union[str, list[str]],
) -> None:
    if isinstance(data, str):
        if data != "all":
            data = [data]
    feat = Featurizer(trajs)
    feat.add_list_of_feats(data)
    for t, (traj_num, sub_ds) in zip(trajs, feat.get_output().groupby("traj_num")):
        assert t.traj_num == traj_num
        sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
        sub_ds = sub_ds.expand_dims("traj_num")
        if t._CVs:
            warnings.warn(
                "Using ensemble=True will drop old CV entries from "
                "trajs, because the feature length increases."
            )
        t._CVs = sub_ds
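

# Illustrative usage sketch (not part of the public API; never called): with
# `ensemble=True`, `load_CVs_ensembletraj` delegates here, the requested features
# are computed over the whole ensemble at once, and any CVs already attached to
# the trajectories are replaced. "central_dihedrals" is a feature name taken from
# the error messages above.
def _example_load_CVs_ensemble(trajs: TrajEnsemble) -> None:
    load_CVs_ensemble(trajs, "central_dihedrals")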


def load_CVs_from_dir(
    trajs: TrajEnsemble,
    data: Path,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
) -> None:
    files = map(str, data.glob("*"))
    files = list(
        filter(
            lambda x: any([traj.basename in x for traj in trajs]),
            files,
        )
    )
    key = {"npy": 1, "txt": 2}
    files = sorted(
        files,
        key=lambda x: key[x.split(".")[-1]] if x.split(".")[-1] in key else 3,
    )[: trajs.n_trajs]
    files = sorted(
        files,
        key=lambda x: [traj.basename in x for traj in trajs].index(True),
    )
    for traj, f in zip(trajs, files):
        if traj.basename not in f:
            raise Exception(f"File {f} does not contain substring of traj {traj}.")
        traj.load_CV(f, attr_name=attr_name, cols=cols)
    return
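

# Illustrative usage sketch (not part of the public API; never called):
# `load_CVs_from_dir` matches the files in a directory to the trajectories by
# basename, preferring .npy over .txt files. The directory name is hypothetical.
def _example_load_CVs_from_dir(trajs: TrajEnsemble) -> None:
    # expects one file per trajectory whose name contains that trajectory's
    # basename, e.g. "<basename>_my_cv.npy"; each is loaded under `attr_name`
    load_CVs_from_dir(trajs, Path("precomputed_cvs"), attr_name="my_cv")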