Coverage for encodermap/loading/featurizer.py: 13%

667 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2025-05-15 22:46 +0200

1# -*- coding: utf-8 -*- 

2# encodermap/loading/featurizer.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""EncoderMap featurization follows the example of the now deprecated PyEMMA package. 

23 

24You can define your features in advance, inspect the expected output and then let 

25the computer do the number crunching afterwards. This can be done with either 

26PyEMMA's streamable featurization or **new** with dask and delayed on a dask-cluster 

27of your liking. Here are the basic concepts of EncoderMap's featurization. 

28 

29""" 

30 

31 

32################################################################################ 

33# Imports 

34################################################################################ 

35 

36 

37# Future Imports at the top 

38from __future__ import annotations 

39 

40# Standard Library Imports 

41import itertools 

42import numbers 

43import os 

44import re 

45import time 

46import warnings 

47from pathlib import Path 

48 

49# Third Party Imports 

50import numpy as np 

51from optional_imports import _optional_import 

52 

53# Encodermap imports 

54from encodermap.loading import features 

55from encodermap.loading.delayed import build_dask_xarray 

56 

57 

58################################################################################ 

59# Optional Imports 

60################################################################################ 

61 

62 

63xr = _optional_import("xarray") 

64md = _optional_import("mdtraj") 

65rich = _optional_import("rich") 

66Client = _optional_import("dask", "distributed.Client") 

67dask = _optional_import("dask") 

68Callback = _optional_import("dask", "callbacks.Callback") 

69dot_graph = _optional_import("dask", "dot.dot_graph") 

70progress = _optional_import("dask", "distributed.progress") 

71HDF5TrajectoryFile = _optional_import("mdtraj", "formats.HDF5TrajectoryFile") 

72_get_global_client = _optional_import("distributed", "client._get_global_client") 

73 

74 

75################################################################################ 

76# Typing 

77################################################################################ 

78 

79 

80# Standard Library Imports 

81from collections.abc import Iterable, Sequence 

82from typing import TYPE_CHECKING, Literal, Optional, Union 

83 

84 

85if TYPE_CHECKING: 

86 # Third Party Imports 

87 import dask 

88 import xarray as xr 

89 from dask import dot_graph 

90 from dask.callbacks import Callback 

91 from dask.distributed import Client, progress 

92 from distributed.client import _get_global_client 

93 

94 # Encodermap imports 

95 from encodermap.loading.features import AnyFeature 

96 from encodermap.trajinfo.info_all import TrajEnsemble 

97 from encodermap.trajinfo.info_single import SingleTraj 

98 

99 

100################################################################################ 

101# Import tqdm which can be either the jupyter one or the plain one 

102################################################################################ 

103 

104 

def _is_notebook() -> bool:  # pragma: no cover
    """Tells whether the code currently runs inside an IPython kernel.

    Returns:
        bool: True if IPython can be imported and the config of the active
            IPython instance contains the key "IPKernelApp". False if
            IPython is not importable, if looking up the config raises an
            AttributeError, or if the key is missing.

    """
    try:
        # Third Party Imports
        from IPython import get_ipython

        has_kernel_app = "IPKernelApp" in get_ipython().config
    except (ImportError, AttributeError):
        return False
    return has_kernel_app

118 

119 

120if _is_notebook(): # pragma: no cover 

121 # Third Party Imports 

122 from tqdm.notebook import tqdm 

123else: 

124 # Third Party Imports 

125 from tqdm import tqdm 

126 

127 

128################################################################################ 

129# Globals 

130################################################################################ 

131 

132 

133__all__: list[str] = ["Featurizer", "DaskFeaturizer"] 

134 

135 

# Maps the snake_case feature aliases to the CamelCase names of the feature
# classes that are looked up on `encodermap.loading.features`.
UNDERSCORE_MAPPING: dict[str, str] = dict(
    central_dihedrals="CentralDihedrals",
    all_cartesians="AllCartesians",
    all_distances="AllBondDistances",
    central_cartesians="CentralCartesians",
    central_distances="CentralBondDistances",
    central_angles="CentralAngles",
    side_cartesians="SideChainCartesians",
    side_distances="SideChainBondDistances",
    side_angles="SideChainAngles",
    side_dihedrals="SideChainDihedrals",
)

# Names of the `add_*` methods of the featurizer classes.
_ADD_X_FUNCTION_NAMES: list[str] = [
    # cartesian coordinates
    "add_all",
    "add_selection",
    # distances and contacts
    "add_distances_ca",
    "add_distances",
    "add_inverse_distances",
    "add_contacts",
    "add_residue_mindist",
    # centers of mass
    "add_group_COM",
    "add_residue_COM",
    # angles and dihedrals
    "add_angles",
    "add_dihedrals",
    # rmsd
    "add_minrmsd_to_ref",
    # torsions
    "add_backbone_torsions",
    "add_sidechain_torsions",
]

165 

166 

167################################################################################ 

168# Utils 

169################################################################################ 

170 

171 

class CoordsLoad:
    """Empty placeholder class. It defines neither attributes nor methods."""

174 

175 

def is_iterable_of_types(l, supertype):
    """Checks whether `l` is iterable and every element's type derives from `supertype`.

    Args:
        l: The object to inspect.
        supertype: The class (or tuple of classes) every element type is
            tested against with `issubclass`.

    Returns:
        bool: False if `l` is not iterable or any element's type is not a
            subclass of `supertype`. True otherwise, which includes an
            empty iterable.

    """
    if not isinstance(l, Iterable):
        return False
    return all(issubclass(type(element), supertype) for element in l)

181 

182 

def is_iterable_of_int(l):
    """Checks whether `l` is an iterable that holds nothing but integral values.

    Args:
        l: The object to inspect.

    Returns:
        bool: The result of `is_iterable_of_types` with `numbers.Integral`
            as the required supertype.

    """
    integral_only = is_iterable_of_types(l, numbers.Integral)
    return integral_only

186 

187 

def is_iterable(I):
    """Checks whether `I` is an instance of `collections.abc.Iterable`.

    Args:
        I: The object to inspect.

    Returns:
        bool: True if `I` is an `Iterable`, False otherwise.

    """
    result = isinstance(I, Iterable)
    return result

190 

191 

def _atoms_in_residues(
    traj: SingleTraj,
    residue_idxs: Sequence[int],
    subset_of_atom_idxs: Optional[np.ndarray] = None,
    fallback_to_full_residue: bool = True,
) -> list[np.ndarray]:
    """Returns a list of arrays containing the atom indices in each residue of `residue_idxs`

    The residues are visited in the order of `traj.top.residues`, so the
    returned list follows the topology order and not the order of
    `residue_idxs`.

    Args:
        traj (SingleTraj): A `SingleTraj` instance.
        residue_idxs (Sequence[int]): List or ndarray (ndim=1) of integers.
        subset_of_atom_idxs (Optional[np.ndarray]): Iterable of atom_idxs to which the
            selection has to be restricted. If None, all atoms considered.
            Defaults to None.
        fallback_to_full_residue (bool): It is possible that some
            residues don't yield any atoms with some subsets. Take all atoms in
            that case. If False, then an empty array is returned for that
            residue. Defaults to True.

    Returns:
        list[np.ndarray]: The resulting list of arrays.

    Warns:
        UserWarning: If at least one residue yielded no atoms in the subset.
            Only the last two of those residues are listed in the message.

    """
    if subset_of_atom_idxs is None:
        subset_of_atom_idxs = np.arange(traj.top.n_atoms)

    # Build hash-based lookups once. Testing `x in ndarray` inside the loops
    # below would scan the whole array for every single atom and residue.
    wanted_residues = set(np.asarray(list(residue_idxs)).ravel().tolist())
    allowed_atoms = set(np.asarray(list(subset_of_atom_idxs)).ravel().tolist())

    atoms_in_residues = []
    special_residues = []
    for residue in traj.top.residues:
        if residue.index not in wanted_residues:
            continue
        residue_atom_idxs = [atom.index for atom in residue.atoms]
        toappend = np.array([idx for idx in residue_atom_idxs if idx in allowed_atoms])
        if len(toappend) == 0:
            # The subset removed every atom of this residue.
            special_residues.append(residue)
            if fallback_to_full_residue:
                toappend = np.array(residue_atom_idxs)
        atoms_in_residues.append(toappend)

    # Any special cases?
    if special_residues:
        msg = "the full residue" if fallback_to_full_residue else "emtpy lists"
        warnings.warn(
            f"These residues yielded no atoms in the subset and were returned as "
            f"{msg} {special_residues[-2:]}"
        )

    return atoms_in_residues

243 

244 

def combinations(
    seq: Iterable,
    k: int,
) -> np.ndarray:
    """Return all k-length subsequences of elements from the input iterable.

    This version uses Numpy/Scipy and should be preferred over itertools. It avoids
    the creation of all intermediate Python objects.

    Args:
        seq (Iterable): The elements to combine. The result is built with an
            integer dtype, so the elements need to be convertible to int.
            Iterables without a length (e.g. generators) are materialized
            into a tuple first.
        k (int): The length of each combination.

    Returns:
        np.ndarray: An integer array of shape (n_combinations, k). If `k` is
            larger than the number of elements, the shape is (0, k).

    Examples:
        >>> import numpy as np
        >>> from itertools import combinations as iter_comb
        >>> x = np.arange(3)
        >>> c1 = combinations(x, 2)
        >>> print(c1)
        [[0 1]
         [0 2]
         [1 2]]
        >>> c2 = np.array(tuple(iter_comb(x, 2)))
        >>> print(c2)
        [[0 1]
         [0 2]
         [1 2]]
    """
    # Standard Library Imports
    from itertools import chain
    from itertools import combinations as _combinations

    # Third Party Imports
    from scipy.special import comb

    # `len()` is needed to pre-compute the output size, which a plain
    # iterable does not offer.
    items = seq if hasattr(seq, "__len__") else tuple(seq)

    count = comb(len(items), k, exact=True)
    res = np.fromiter(
        chain.from_iterable(_combinations(items, k)), int, count=count * k
    )
    # Reshape with the explicit row count; `-1` can't be inferred for k == 0.
    return res.reshape(count, k)

279 

280 

def product(*arrays: np.ndarray) -> np.ndarray:
    """Builds the cartesian product of the given arrays.

    Args:
        arrays (np.ndarray): 1-D arrays to form the cartesian product of.

    Returns:
        np.ndarray: 2-D array of shape (M, len(arrays)) in which every row
            is one element of the cartesian product. The output takes the
            dtype of the first input array.

    """
    columns = [np.asarray(values) for values in arrays]
    n_columns = len(columns)
    out_dtype = columns[0].dtype

    # One row per product element; column n holds the index into columns[n].
    lengths = tuple(len(values) for values in columns)
    index_grid = np.indices(lengths).reshape(n_columns, -1).T

    result = np.empty_like(index_grid, dtype=out_dtype)
    for col, values in enumerate(columns):
        result[:, col] = values[index_grid[:, col]]
    return result

305 

306 

def _parse_pairwise_input(
    indices1: Sequence[int],
    indices2: Optional[Sequence[int]],
) -> np.ndarray:
    """For input of a pairwise type (distances, inverse distances, contacts) checks
    the type of input the user gave and formats it so that `DistanceFeature`,
    `InverseDistanceFeature`, and `ContactFeature` can work.

    In case the input isn't already a list of pairs, this function will:
        - sort the indices1 array and remove its duplicates
        - if indices2 is None, produce all pairs of indices within indices1, or
        - if indices2 is not None, sort it, remove its duplicates as well as
          all indices that already occur in indices1 and produce the pairs
          (i, j), where i comes from indices1, and j from indices2

    Args:
        indices1 (Sequence[int]): Either a flat iterable of integers or an
            input that already consists of pairs. The latter is returned as is.
        indices2 (Optional[Sequence[int]]): None or a flat iterable of
            integers. Only evaluated if `indices1` is a flat iterable of
            integers.

    Returns:
        np.ndarray: The pairs. If `indices1` was not a flat iterable of
            integers, it is returned unchanged.

    Raises:
        ValueError: If `indices1` is a flat iterable of integers, but `indices2`
            is neither None nor a flat iterable of integers.

    """
    if not is_iterable_of_int(indices1):
        # Not a flat list of atom indices, so treat it as a ready-made pair list.
        return indices1

    # Eliminate duplicates and sort
    indices1 = np.unique(indices1)

    # Intra-group distances
    if indices2 is None:
        return combinations(indices1, 2)

    # Previously this case fell through all branches and ended in an
    # UnboundLocalError on the return statement.
    if not is_iterable_of_int(indices2):
        raise ValueError(
            f"If `indices` is a sequence of integers, `indices2` needs to be "
            f"None or a sequence of integers, but got {type(indices2)}."
        )

    # Inter-group distances
    # Eliminate duplicates and sort
    indices2 = np.unique(indices2)

    # Eliminate duplicates between indices1 and indices2
    # `np.isin` replaces the deprecated `np.in1d`.
    not_in_indices1 = np.isin(indices2, indices1, invert=True)
    return product(indices1, indices2[not_in_indices1])

348 

349 

def pairs(
    sel: np.ndarray,
    excluded_neighbors: int = 0,
) -> np.ndarray:
    """Builds all index pairs from `sel` which are far enough apart.

    Every unordered pair of entries in `sel` is visited once and sorted so
    that the smaller index comes first. The pair is kept, if the larger
    index exceeds the smaller one by more than `excluded_neighbors`.
    The self-pair (i, i) is therefore never part of the output.

    Args:
        sel (np.ndarray): Array with selected atom indexes.
        excluded_neighbors (int): Number of neighbors that will be excluded
            when creating the pairs. Defaults to 0.

    Returns:
        np.ndarray: A m x 2 array with the kept pairs. With
            `excluded_neighbors` = 2, none of the pairs [i, i-2], [i, i-1],
            [i, i], [i, i+1], [i, i+2] are contained. The array is
            non-redundant, i.e. it contains [i, j] but never also [j, i].
            If no pair is kept, an empty array is returned.

    """
    assert isinstance(excluded_neighbors, int)

    kept = []
    for first, second in itertools.combinations(sel, 2):
        # bring the two indices into ascending order
        low, high = (second, first) if first > second else (first, second)
        # drop pairs that are too close to each other
        if high > low + excluded_neighbors:
            kept.append([low, high])
    return np.array(kept)

386 

387 

class Track(Callback):
    """Dask callback that renders the task graph to png files during a computation.

    Before a task runs (every `save_every`-th call) and once more when the
    computation finishes, the graph is handed to `dot_graph`. Data keys
    listed under "released" and function keys listed under "finished" are
    colored blue. Data keys listed under "cache" and function keys listed
    under "running" are colored red.

    Attributes:
        path (str): Directory the files are written to.
        save_every (int): A file is written before every `save_every`-th task.
        n (int): Running counter, which is also used to number the files.

    """

    def __init__(
        self,
        path: str = "/tmp.json/dasks",
        save_every: int = 1,
    ) -> None:
        """Stores the settings and creates the directory `path`, if it does not exist.

        Args:
            path (str): Directory for the output files.
            save_every (int): Write a file before every `save_every`-th task.
                Defaults to 1.

        """
        self.path, self.save_every, self.n = path, save_every, 0
        os.makedirs(path, exist_ok=True)

    def _plot(
        self,
        dsk,
        state,
    ) -> None:
        """Writes the current state of the graph `dsk` to the file number `self.n`."""
        # Later entries win, i.e. "cache" overrides "released" and "running"
        # overrides "finished" if a key shows up in both.
        data_attributes = {key: {"color": "blue"} for key in state["released"]}
        data_attributes.update({key: {"color": "red"} for key in state["cache"]})
        function_attributes = {key: {"color": "blue"} for key in state["finished"]}
        function_attributes.update({key: {"color": "red"} for key in state["running"]})

        dot_graph(
            dsk,
            filename=os.path.join(self.path, f"part_{self.n:0>4d}"),
            format="png",
            data_attributes=data_attributes,
            function_attributes=function_attributes,
        )

    def _pretask(
        self,
        key,
        dsk,
        state,
    ) -> None:
        """Plots before a task, if the counter is a multiple of `save_every`."""
        due = self.n % self.save_every == 0
        if due:
            self._plot(dsk, state)
        self.n += 1

    def _finish(
        self,
        dsk,
        state,
        errored,
    ) -> None:
        """Plots the final state of the graph."""
        self._plot(dsk, state)
        self.n += 1

443 

444 

445################################################################################ 

446# Classes 

447################################################################################ 

448 

449 

450class SingleTrajFeaturizer: 

def __init__(self, traj: SingleTraj, delayed: bool = False) -> None:
    """Sets up a featurizer without any active features.

    Args:
        traj (SingleTraj): The trajectory the features will be created for.
        delayed (bool): Stored on the instance. Features that are added later
            need to carry the same `delayed` value. Defaults to False.

    """
    self.traj, self.delayed = traj, delayed
    # book-keeping of the ids that were handed out to custom features
    self._n_custom_features = 0
    self._custom_feature_ids = []
    # the features added so far
    self.active_features = []

457 

def add_list_of_feats(
    self,
    which: Union[Literal["all", "full"], Sequence[str]] = "all",
    deg: bool = False,
    omega: bool = True,
    check_aas: bool = True,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds a list of named features to the Featurizer. The
    argument `which` can be 'all', 'full', or a list of the following strings:
        * 'AllCartesians': Cartesian coordinates of all atoms with
            shape (n_frames, n_atoms, 3).
        * 'AllBondDistances': Bond distances of all bonds recognized by
            mdtraj. Use top = md.Topology.from_openmm() if mdtraj does not
            recognize all bonds.
        * 'CentralCartesians': Cartesians of the N, C, CA atoms in the
            backbone with shape (n_frames, n_residues * 3, 3).
        * 'CentralBondDistances': The bond distances of the N, C, CA bonds
            with shape (n_frames, n_residues * 3 - 1).
        * 'CentralAngles': The angles between the backbone bonds with shape
            (n_frames, n_residues * 3 - 2).
        * 'CentralDihedrals': The dihedrals between the backbone atoms
            (omega, phi, psi). With shape (n_frames, n_residues * 3 - 3).
        * 'SideChainCartesians': Cartesians of the sidechain-atoms.
            Starting with CB, CG, ...
        * 'SideChainBondDistances': Bond distances between the
            sidechain atoms. starting with the CA-CG bond.
        * 'SideChainAngles': Angles between sidechain atoms. Starting with
            the C-CA-CB angle.
        * 'SideChainDihedrals': Dihedrals of the sidechains (chi1, chi2, chi3).
    The snake_case aliases in `UNDERSCORE_MAPPING` (e.g. 'central_dihedrals')
    are accepted as well.
    If 'all' is provided for `which` the CentralCartesians, CentralBondDistances,
    CentralAngles, CentralDihedrals, SideChainDihedrals will be added.
    If 'full' is provided, SideChainCartesians, SideChainAngles, AllCartesians,
    and SideChainBondDistances are added on top of that.

    Args:
        which (Union[str, list], optional). Either add 'all' features or
            a list of features. See Above for possible features. Defaults
            to 'all'.
        deg (bool): Whether the output should be formatted in degrees. Only
            forwarded to features with the class attribute `_use_angle`.
        omega (bool): Whether to include the omega angles of the backbone.
            Only forwarded to features with the class attribute `_use_omega`.
        check_aas (bool): Whether to check if all residues in top are known.
            Helps with custom topology to not skip unknown custom/non-natural
            amino acids. Forwarded to every feature.
        periodic (bool): Only forwarded to features with the class
            attribute `_use_periodic`. Defaults to True.
        delayed (bool): Forwarded to every feature. Defaults to False.

    Raises:
        Exception: If `which` contains a str that is not recognized or if a
            feature class has an unexpected combination of the `_use_periodic`,
            `_use_angle`, and `_use_omega` class attributes.
        AssertionError: If `which` is a list or tuple with non-str entries.

    """
    recognized_str = list(UNDERSCORE_MAPPING.keys()) + list(
        UNDERSCORE_MAPPING.values()
    )
    if isinstance(which, str):
        if which == "full":
            which = [
                "CentralCartesians",
                "CentralBondDistances",
                "CentralAngles",
                "CentralDihedrals",
                "SideChainDihedrals",
                "SideChainCartesians",
                "SideChainAngles",
                "AllCartesians",
                "SideChainBondDistances",
            ]
        elif which == "all":
            which = [
                "CentralCartesians",
                "CentralBondDistances",
                "CentralAngles",
                "CentralDihedrals",
                "SideChainDihedrals",
            ]
        else:
            if which not in recognized_str:
                raise Exception(
                    f"Recognized arguments to which are 'all' or any of the "
                    f"following: {recognized_str}. The str you provided {which} "
                    f"did not match any."
                )
            which = [which]
    elif isinstance(which, (list, tuple)):
        assert all([isinstance(i, str) for i in which]), (
            f"The argument `which` needs to be provided a sequence of str. "
            f"You have offending types in this argument."
        )
        diff = set(which) - set(recognized_str)
        if diff:
            raise Exception(
                f"One or more of the str in `which` are not recognized. "
                f"The argument `which` needs to be a sequence containing any "
                f"of the following: {recognized_str}. The unrecognized str are: "
                f"{diff}."
            )

    # The only supported (_use_periodic, _use_angle, _use_omega) combinations.
    supported_flags = {
        (False, False, False),
        (True, False, False),
        (True, True, False),
        (True, True, True),
    }

    # add the features
    for cf in which:
        cf = UNDERSCORE_MAPPING.get(cf, cf)
        feature = getattr(features, cf)
        flags = (
            bool(feature._use_periodic),
            bool(feature._use_angle),
            bool(feature._use_omega),
        )
        if flags not in supported_flags:
            raise Exception(
                f"Unknown combination of `_use_angle` and `_use_omega` in "
                f"class attributes of {feature=}"
            )
        use_periodic, use_angle, use_omega = flags

        # `check_aas` is forwarded to every feature. It used to be hard-coded
        # to True for the features without `_use_angle`, which silently
        # ignored the caller's choice.
        kwargs = {"check_aas": check_aas, "delayed": delayed}
        if use_periodic:
            kwargs["periodic"] = periodic
        if use_angle:
            kwargs["deg"] = deg
        if use_omega:
            kwargs["omega"] = omega
        self._add_feature(feature(self.traj, **kwargs))

598 

def add_custom_feature(self, feature: AnyFeature) -> None:
    """Adds a custom feature to the list of active features.

    Features without a `name` attribute are named `CustomFeature_{id}`. If
    the feature has no `id` attribute, or its `id` is None, the next free
    running number is assigned. Otherwise, the `id` must not have been used
    by a previously added custom feature.

    Args:
        feature (AnyFeature): The feature to add. If it has no `name`
            attribute, it needs to be an instance of `CustomFeature`.

    Raises:
        AssertionError: If the `id` of the feature is already taken or if an
            unnamed feature is not an instance of `CustomFeature`.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    if not hasattr(feature, "name"):
        if getattr(feature, "id", None) is None:
            feature.id = self._n_custom_features
            self._n_custom_features += 1
        else:
            assert feature.id not in self._custom_feature_ids, (
                f"A CustomFeature with the id {feature.id} already exists. "
                f"Please change the id of your CustomFeature."
            )
        # Record explicit ids as well. Otherwise, the uniqueness check above
        # can never detect a second feature with the same explicit id.
        self._custom_feature_ids.append(feature.id)
        assert isinstance(feature, CustomFeature)
        feature.name = f"CustomFeature_{feature.id}"
    self._add_feature(feature)

622 

def _add_feature(self, feature: AnyFeature) -> None:
    """Appends a feature to `self.active_features`.

    Features with a `dimension` of 0 and features that are already part of
    the active features are skipped with a warning.

    Args:
        feature (AnyFeature): The feature to add.

    Raises:
        AssertionError: If the `delayed` attribute of the feature differs
            from the `delayed` attribute of this featurizer.

    """
    assert feature.delayed == self.delayed, (
        f"In-memory featurizer {self.__class__} unexpectedly got a delayed "
        f"feature {feature}. {feature.delayed=} {self.delayed=}"
    )

    # nothing to compute for an empty feature
    if feature.dimension == 0:
        warnings.warn(
            f"Given an empty feature (eg. due to an empty/ineffective "
            f"selection). Skipping it. Feature desc: {feature.describe()}"
        )
        return

    # don't add the same feature twice
    if feature in self.active_features:
        warnings.warn(
            f"Tried to re-add the same feature {feature.__class__.__name__} to "
            f"{self.active_features=}"
        )
        return

    self.active_features.append(feature)

646 

def add_distances_ca(
    self,
    periodic: bool = True,
    excluded_neighbors: int = 2,
    delayed: bool = False,
) -> None:
    """Adds the distances between the CA atoms to the feature list.

    The residue indices of the atoms in `self.select_Ca` are paired up with
    `pairs`. For every residue pair, the indices of the two "CA" atoms are
    looked up and handed to `add_distances`.

    Args:
        periodic (bool): Use the minimum image convention when computing distances.
        excluded_neighbors (int): Number of exclusions when compiling the
            list of pairs. Two CA-atoms are considered neighbors if they
            belong to adjacent residues. Defaults to 2.
        delayed (bool): Forwarded to `add_distances`. Defaults to False.

    """
    ca_atom_idxs = self.select_Ca
    ca_residue_idxs = [
        self.traj.top.atom(atom_idx).residue.index for atom_idx in ca_atom_idxs
    ]
    residue_pairs = pairs(ca_residue_idxs, excluded_neighbors=excluded_neighbors)

    # translate the residue pairs back to pairs of CA atom indices
    ca_pairs = np.array(
        [
            [
                self.traj.top.residue(first).atom("CA").index,
                self.traj.top.residue(second).atom("CA").index,
            ]
            for first, second in residue_pairs
        ]
    )

    self.add_distances(ca_pairs, periodic=periodic, delayed=delayed)

676 

def add_distances(
    self,
    indices: Union[np.ndarray, Sequence[int]],
    periodic: bool = True,
    indices2: Optional[Sequence[int]] = None,
    delayed: bool = False,
) -> None:
    """Adds the distances between atoms to the feature list.

    Args:
        indices (Union[np.ndarray, Iterable[Sequence[int]]]): Can be one of two types:
            A numpy array of shape (n, 2) with the pairs of atoms between
            which the distances shall be computed. Or a sequence of integers
            which are the indices (not pairs of indices) of the atoms between
            which the distances shall be computed. In the latter case, all
            pairs within `indices` are used, unless `indices2` is supplied.
        periodic (bool): If periodic is True and the trajectory contains
            unitcell information, distances will be computed under the
            minimum image convention. Defaults to True.
        indices2 (Optional[Sequence[int]]): Only has effect if `indices` is
            a sequence of integers. Instead of the above behavior, only the
            distances between the atoms in `indices` and `indices2` will be
            computed.
        delayed (bool): Forwarded to the `DistanceFeature`. Defaults to False.

    Note:
        When using the iterable of integers input, `indices` and `indices2`
        will be sorted numerically and made unique before converting them to
        a pairlist. Please look carefully at the output of `self.describe()` to
        see what features exactly have been added.

    """
    # Local Folder Imports
    from .features import DistanceFeature

    checked_pairs = self._check_indices(_parse_pairwise_input(indices, indices2))
    distance_feature = DistanceFeature(
        self.traj,
        checked_pairs,
        periodic=periodic,
        delayed=delayed,
    )
    self._add_feature(distance_feature)

717 

def add_backbone_torsions(
    self,
    selstr: Optional[str] = None,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds all backbone phi/psi angles or the ones specified in `selstr` to the feature list.

    Args:
        selstr (Optional[str]): If None, all phi/psi angles will be considered.
            Otherwise, can be a string specifying the atoms of specific
            backbone torsions (see example).
        deg (bool): Whether the output should be in degrees (True) or radians
            (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of their
            cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory container
            in `traj` needs to have unitcell information. Defaults to True.
        delayed (bool): Forwarded to the `BackboneTorsionFeature`.
            Defaults to False.

    Examples:
        >>> import encodermap as em
        >>> import numpy as np
        >>> from pprint import pprint
        >>> trajs = em.load_project("linear_dimers")
        >>> feat = em.Featurizer(trajs[0])
        >>> feat.add_backbone_torsions("resname PRO")
        >>> pprint(feat.describe())
        ['PHI 0 PRO 19',
         'PSI 0 PRO 19',
         'PHI 0 PRO 37',
         'PSI 0 PRO 37',
         'PHI 0 PRO 38',
         'PSI 0 PRO 38',
         'PHI 0 PRO 95',
         'PSI 0 PRO 95',
         'PHI 0 PRO 113',
         'PSI 0 PRO 113',
         'PHI 0 PRO 114',
         'PSI 0 PRO 114']
        >>> ds = feat.get_output()
        >>> da = ds.BackboneTorsionFeature
        >>> phi_indices = da.coords["BACKBONETORSIONFEATURE"].str.contains("PHI")
        >>> angles = np.rad2deg(da.sel(BACKBONETORSIONFEATURE=phi_indices).values[0])
        >>> np.min(angles)
        -103.39891
        >>> np.max(angles)
        -10.015779

    """
    # Local Folder Imports
    from .features import BackboneTorsionFeature

    f = BackboneTorsionFeature(
        self.traj,
        selstr=selstr,
        deg=deg,
        cossin=cossin,
        periodic=periodic,
        delayed=delayed,
    )
    self._add_feature(f)

783 

def add_angles(
    self,
    indexes: np.ndarray,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds angles, defined by triplets of atoms, to the feature list.

    The triplets are passed through `self._check_indices` with `pair_n=3`
    before the `AngleFeature` is created.

    Args:
        indexes (np.ndarray): An array with triplets of atom indices.
        deg (bool): Whether the output should be in degrees (True) or radians
            (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of their
            cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory container
            in `traj` needs to have unitcell information. Defaults to True.
        delayed (bool): Forwarded to the `AngleFeature`. Defaults to False.

    """
    # Local Folder Imports
    from .features import AngleFeature

    checked_triplets = self._check_indices(indexes, pair_n=3)
    feature_kwargs = dict(deg=deg, cossin=cossin, periodic=periodic, delayed=delayed)
    self._add_feature(AngleFeature(self.traj, checked_triplets, **feature_kwargs))

819 

def add_all(
    self,
    reference: Optional[md.Trajectory] = None,
    atom_indices: Optional[np.ndarray] = None,
    ref_atom_indices: Optional[np.ndarray] = None,
    delayed: bool = False,
) -> None:
    """Adds the coordinates of all atoms to the feature list.

    This calls `add_selection` with the indices of all `self.traj.n_atoms` atoms.
    The coordinates are flattened as follows: [x1, y1, z1, x2, y2, z2, ...]

    Args:
        reference (Optional[md.Trajectory]). If different from None, all
            data is aligned using MDTraj's superpose. Defaults to None.
        atom_indices (Optional[np.ndarray]): The indices of atoms to superpose
            If None all atoms will be used. Defaults to None.
        ref_atom_indices (Optional[np.ndarray]): Use these atoms on the
            reference structure. If not supplied, the same atom indices
            will be used for this trajectory and the reference one.
        delayed (bool): Forwarded to `add_selection`. Defaults to False.

    """
    every_atom = list(range(self.traj.n_atoms))
    self.add_selection(
        every_atom,
        reference=reference,
        atom_indices=atom_indices,
        ref_atom_indices=ref_atom_indices,
        delayed=delayed,
    )

847 

def add_selection(
    self,
    indexes: np.ndarray,
    reference: Optional[md.Trajectory] = None,
    atom_indices: Optional[np.ndarray] = None,
    ref_atom_indices: Optional[np.ndarray] = None,
    delayed: bool = False,
) -> None:
    """Adds the coordinates of the selected atom indexes to the feature list.
    The coordinates of the selection [1, 2, ...] are flattened as follows: [x1, y1, z1, x2, y2, z2, ...]

    Without a `reference`, a `SelectionFeature` is added. With a `reference`,
    an `AlignFeature` is added instead.

    Args:
        indexes (np.ndarray): Array with selected atom indexes.
        reference (Optional[md.Trajectory]). If different from None, all
            data is aligned using MDTraj's superpose. Defaults to None.
        atom_indices (Optional[np.ndarray]): The indices of atoms to superpose
            If None, all atoms will be used. Defaults to None.
        ref_atom_indices (Optional[np.ndarray]): Use these atoms on the
            reference structure. If not supplied, the same atom indices
            will be used for this trajectory and the reference one.
        delayed (bool): Forwarded to the created feature. Defaults to False.

    Raises:
        ValueError: If `reference` is neither None nor a `md.Trajectory`.

    """
    # Local Folder Imports
    from .features import AlignFeature, SelectionFeature

    if reference is None:
        f = SelectionFeature(self.traj, indexes, delayed=delayed)
    else:
        if not isinstance(reference, md.Trajectory):
            raise ValueError(
                "reference is not a mdtraj.Trajectory object, but {}".format(
                    reference
                )
            )
        # NOTE(review): unlike every other feature created in this class,
        # AlignFeature is not given `self.traj` here. Confirm against the
        # signature of AlignFeature.
        f = AlignFeature(
            reference=reference,
            indexes=indexes,
            atom_indices=atom_indices,
            ref_atom_indices=ref_atom_indices,
            delayed=delayed,
        )
    self._add_feature(f)

890 

def add_inverse_distances(
    self,
    indices: Union[np.ndarray, Sequence[int]],
    periodic: bool = True,
    indices2: Optional[Union[np.ndarray, Sequence[int]]] = None,
    delayed: bool = False,
) -> None:
    """Adds the inverse distances between atoms to the feature list.

    Args:
        indices (Union[np.ndarray, Sequence[int]]): An array with shape (n, 2)
            giving the pairs of atoms between which the inverse distances
            shall be computed. Can also be a sequence of integers giving
            the first atoms in the distance calculations. In this case, all
            pairs within `indices` are used, unless `indices2` is supplied.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory container
            in `traj` needs to have unitcell information. Defaults to True.
        indices2 (Optional[Union[np.ndarray, Sequence[int]]]): Only has an
            effect if `indices` is a sequence of int (and not a (n, 2)
            np.ndarray). Gives the second atoms in the distance calculations.
        delayed (bool): Forwarded to the `InverseDistanceFeature`.
            Defaults to False.

    Note:
        When using the *iterable of integers* input, `indices` and `indices2`
        will be sorted numerically and made unique before converting them to
        a pairlist. Please look carefully at the output of `describe()` to
        see what features exactly have been added.

    """
    # Local Folder Imports
    from .features import InverseDistanceFeature

    checked_pairs = self._check_indices(_parse_pairwise_input(indices, indices2))
    inverse_feature = InverseDistanceFeature(
        self.traj,
        checked_pairs,
        periodic=periodic,
        delayed=delayed,
    )
    self._add_feature(inverse_feature)

934 

def add_contacts(
    self,
    indices: Union[np.ndarray, Sequence[int]],
    indices2: Optional[Union[np.ndarray, Sequence[int]]] = None,
    threshold: float = 0.3,
    periodic: bool = True,
    count_contacts: bool = False,
    delayed: bool = False,
) -> None:
    """Registers a contact feature for the requested atom pairs.

    Args:
        indices (Union[np.ndarray, Sequence[int]]): Either an array of shape
            (n, 2) with the atom pairs for which contacts are evaluated, or
            a sequence of integers with the first atoms of the pairs. In
            the latter case, `indices2` has to be supplied.
        indices2 (Optional[Union[np.ndarray, Sequence[int]]]): Required if
            `indices` is a plain sequence of int (and not a (n, 2)
            np.ndarray). Defaults to None.
        threshold (float): Distances below this value (in nanometer) are
            considered as contacts. The output will contain 1.0 for these
            contacts and 0.0 above the threshold. Defaults to 0.3.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory
            container in `traj` needs to have unitcell information.
            Defaults to True.
        count_contacts (bool): If set to True, this feature will return
            the number of formed contacts (and not feature values with
            either 1.0 or 0). The output of this feature will then be of
            shape (Nt, 1), and not (Nt, nr_of_contacts). Defaults to False.
        delayed (bool): Forwarded unchanged to the constructor of the
            `ContactFeature`. Defaults to False.

    Note:
        When using the *iterable of integers* input, `indices` and
        `indices2` will be sorted numerically and made unique before they
        are converted to a pairlist. Inspect the output of `describe()` to
        see which features have actually been added.

    """
    # Local Folder Imports
    from .features import ContactFeature

    # build the pairlist first, then validate it against the topology
    pairs = self._check_indices(_parse_pairwise_input(indices, indices2))
    contact_feature = ContactFeature(
        self.traj,
        pairs,
        threshold,
        periodic,
        count_contacts,
        delayed=delayed,
    )
    self._add_feature(contact_feature)

983 

def add_residue_mindist(
    self,
    residue_pairs: Union[Literal["all"], np.ndarray] = "all",
    scheme: Literal["ca", "closest", "closest-heavy"] = "closest-heavy",
    ignore_nonprotein: bool = True,
    threshold: Optional[float] = None,
    periodic: bool = True,
    count_contacts: bool = False,
    delayed: bool = False,
) -> None:
    """Registers the minimum distances between residues as a feature.

    See below how the minimum distance can be defined. If the topology
    generated out of `traj` contains information on periodic boundary
    conditions, the minimum image convention will be used when computing
    distances.

    Args:
        residue_pairs (Union[Literal["all"], np.ndarray]): Can be 'all', in
            which case mindists will be calculated between all pairs of
            residues excluding first and second neighbor. If a np.array
            with shape (n, 2) is supplied, these residue indices (0-based)
            will be used to compute the mindists. Defaults to 'all'.
        scheme (Literal["ca", "closest", "closest-heavy"]): Within a
            residue, determines the sub-group of atoms that will be
            considered when computing distances.
            Defaults to 'closest-heavy'.
        ignore_nonprotein (bool): Whether to ignore residues that are not
            of protein type (e.g. water molecules, post-translational
            modifications, non-standard residues, etc.). Defaults to True.
        threshold (Optional[float]): Distances below this threshold (in nm)
            will result in a feature 1.0, the distances above will result
            in 0.0. If left to None, the numerical value will be returned.
            Defaults to None.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory
            container in `traj` needs to have unitcell information.
            Defaults to True.
        count_contacts (bool): If set to True, this feature will return
            the number of formed contacts (and not feature values with
            either 1.0 or 0). The output of this feature will then be of
            shape (Nt, 1), and not (Nt, nr_of_contacts). Defaults to False.
        delayed (bool): Forwarded unchanged to the constructor of the
            `ResidueMinDistanceFeature`. Defaults to False.

    Note:
        Using `scheme` = 'closest' or 'closest-heavy' with
        `residue_pairs` = 'all' will compute nearly all interatomic
        distances, for every frame, before extracting the closest pairs.
        This can be very time-consuming. Those schemes are intended to be
        used with a subset of residues chosen via `residue_pairs`. A
        warning is issued for this combination.

    """
    # Local Folder Imports
    from .features import ResidueMinDistanceFeature

    # the isinstance check comes first, so arrays are never compared to a str
    wants_everything = isinstance(residue_pairs, str) and residue_pairs == "all"
    if scheme != "ca" and wants_everything:
        warnings.warn(
            "Using all residue pairs with schemes like closest or "
            "closest-heavy is very time consuming. Consider reducing "
            "the residue pairs"
        )

    mindist_feature = ResidueMinDistanceFeature(
        self.traj,
        residue_pairs,
        scheme,
        ignore_nonprotein,
        threshold,
        periodic,
        count_contacts=count_contacts,
        delayed=delayed,
    )
    self._add_feature(mindist_feature)

1056 

def add_group_COM(
    self,
    group_definitions: Sequence[int],
    ref_geom: Optional[md.Trajectory] = None,
    image_molecules: bool = False,
    mass_weighted: bool = True,
    delayed: bool = False,
) -> None:
    """Registers the centers of mass (COM) of one or more atom groups.

    The COMs are given in cartesian coordinates. If the group definitions
    coincide directly with residues, use `add_residue_COM` instead. No
    periodic boundaries are taken into account.

    Args:
        group_definitions (Sequence[int]): List of the groups of atom
            indices for which the COM will be computed. The atoms are
            zero-indexed.
        ref_geom (Optional[md.Trajectory]): If a md.Trajectory is provided,
            the coordinates of the provided traj will be centered using
            this reference, before computing COM. If None is provided,
            they won't be centered. Defaults to None.
        image_molecules (bool): The method traj.image_molecules will be
            called before computing averages. The method tries to correct
            for molecules broken across periodic boundary conditions, but
            can be time consuming. See
            http://mdtraj.org/latest/api/generated/mdtraj.Trajectory.html#mdtraj.Trajectory.image_molecules
            for more details. Defaults to False.
        mass_weighted (bool): Set to False if you want the geometric center
            and not the COM. Defaults to True.
        delayed (bool): Forwarded unchanged to the constructor of the
            `GroupCOMFeature`. Defaults to False.

    """
    # Local Folder Imports
    from .features import GroupCOMFeature

    com_feature = GroupCOMFeature(
        self.traj,
        group_definitions,
        ref_geom=ref_geom,
        image_molecules=image_molecules,
        mass_weighted=mass_weighted,
        delayed=delayed,
    )
    self._add_feature(com_feature)

1099 

def add_residue_COM(
    self,
    residue_indices: Sequence[int],
    scheme: Literal["all", "backbone", "sidechain"] = "all",
    ref_geom: Optional[md.Trajectory] = None,
    image_molecules: bool = False,
    mass_weighted: bool = True,
    delayed: bool = False,
) -> None:
    """Registers a per-residue center of mass (COM) in cartesian coordinates.

    No periodic boundaries are taken into account.

    Args:
        residue_indices (Sequence[int]): List of the residue indices
            for which the COM will be computed. The indices are
            zero-based.
        scheme (Literal["all", "backbone", "sidechain"]): What atoms
            contribute to the COM computation. If the scheme yields no
            atoms for some residue, the selection falls back to 'all' for
            that residue. Defaults to 'all'.
        ref_geom (Optional[md.Trajectory]): If a md.Trajectory is provided,
            the coordinates of the provided traj will be centered using
            this reference, before computing COM. If None is provided,
            they won't be centered. Defaults to None.
        image_molecules (bool): The method traj.image_molecules will be
            called before computing averages. The method tries to correct
            for molecules broken across periodic boundary conditions, but
            can be time consuming. See
            http://mdtraj.org/latest/api/generated/mdtraj.Trajectory.html#mdtraj.Trajectory.image_molecules
            for more details. Defaults to False.
        mass_weighted (bool): Set to False if you want the geometric center
            and not the COM. Defaults to True.
        delayed (bool): Forwarded unchanged to the constructor of the
            `ResidueCOMFeature`. Defaults to False.

    Raises:
        AssertionError: If `scheme` is not one of 'all', 'backbone',
            'sidechain'.

    """
    # Local Folder Imports
    from .features import ResidueCOMFeature

    assert scheme in ["all", "backbone", "sidechain"]

    # the scheme name doubles as the selection string for the topology
    scheme_atoms = self.traj.top.select(scheme)
    atoms_per_residue = _atoms_in_residues(
        self.traj,
        residue_indices,
        subset_of_atom_idxs=scheme_atoms,
    )

    self._add_feature(
        ResidueCOMFeature(
            self.traj,
            np.asarray(residue_indices),
            atoms_per_residue,
            scheme,
            ref_geom=ref_geom,
            image_molecules=image_molecules,
            mass_weighted=mass_weighted,
            delayed=delayed,
        )
    )

1157 

def add_dihedrals(
    self,
    indexes: np.ndarray,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Registers a dihedral feature for the given atom quadruplets.

    Args:
        indexes (np.ndarray): An array with quadruplets of atom indices,
            i.e. of shape (n, 4).
        deg (bool): Whether the output should be in degrees (True) or
            radians (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of
            their cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory
            container in `traj` needs to have unitcell information.
            Defaults to True.
        delayed (bool): Forwarded unchanged to the constructor of the
            `DihedralFeature`. Defaults to False.

    """
    # Local Folder Imports
    from .features import DihedralFeature

    # every row has to hold exactly four atom indices
    quadruplets = self._check_indices(indexes, pair_n=4)
    self._add_feature(
        DihedralFeature(
            self.traj,
            quadruplets,
            deg=deg,
            cossin=cossin,
            periodic=periodic,
            delayed=delayed,
        )
    )

1193 

def add_sidechain_torsions(
    self,
    selstr: Optional[str] = None,
    deg: bool = False,
    cossin: bool = False,
    periodic: bool = True,
    which: Union[
        Literal["all"], Sequence[Literal["chi1", "chi2", "chi3", "chi4", "chi5"]]
    ] = "all",
    delayed: bool = False,
) -> None:
    """Registers all side chain torsion angles, or the ones selected by
    `selstr`, as a feature.

    Args:
        selstr (Optional[str]): Selection string specifying the atom
            selection used to pick a specific set of angles. If None,
            all chi angles found in the topology will be computed.
            Defaults to None.
        deg (bool): Whether the output should be in degrees (True) or
            radians (False). Defaults to False.
        cossin (bool): Whether to return the angles (False) or tuples of
            their cos and sin values (True). Defaults to False.
        periodic (bool): Whether to observe the minimum image convention
            and respect proteins breaking over the periodic boundary
            condition as a whole (True). In this case, the trajectory
            container in `traj` needs to have unitcell information.
            Defaults to True.
        which (Union[Literal["all"], Sequence[Literal["chi1", "chi2", "chi3", "chi4", "chi5"]]]):
            Which angles to consider. Can be 'all' or any combination of
            ('all', 'chi1', 'chi2', 'chi3', 'chi4', 'chi5').
            Defaults to 'all'.
        delayed (bool): Forwarded unchanged to the constructor of the
            `SideChainTorsions` feature. Defaults to False.

    """
    # Local Folder Imports
    from .features import SideChainTorsions

    torsion_feature = SideChainTorsions(
        self.traj,
        selstr=selstr,
        deg=deg,
        cossin=cossin,
        periodic=periodic,
        which=which,
        delayed=delayed,
    )
    self._add_feature(torsion_feature)

1240 

def add_minrmsd_to_ref(
    self,
    ref: Union[md.Trajectory, SingleTraj],
    ref_frame: int = 0,
    atom_indices: Optional[np.ndarray] = None,
    precentered: bool = False,
    delayed: bool = False,
) -> None:
    """Registers the minimum root-mean-square-deviation (minrmsd) with
    respect to a reference structure as a feature.

    Args:
        ref (Union[md.Trajectory, SingleTraj]): Reference structure for
            computing the minrmsd.
        ref_frame (int): Reference frame of the structure specified in
            `ref`. Defaults to 0.
        atom_indices (Optional[np.ndarray]): Atoms that will be used for:
                1. aligning the target and reference geometries.
                2. computing rmsd after the alignment.
            If left to None, all atoms of `ref` will be used.
        precentered (bool): Use this boolean at your own risk to let
            mdtraj know that the target conformations are already centered
            at the origin, i.e., their (uniformly weighted) center of
            mass lies at the origin. This will speed up the computation of
            the rmsd. Defaults to False.
        delayed (bool): Forwarded unchanged to the constructor of the
            `MinRmsdFeature`. Defaults to False.

    """
    # Local Folder Imports
    from .features import MinRmsdFeature

    rmsd_feature = MinRmsdFeature(
        self.traj,
        ref,
        ref_frame=ref_frame,
        atom_indices=atom_indices,
        precentered=precentered,
        delayed=delayed,
    )
    self._add_feature(rmsd_feature)

1279 

@property
def ndim(self) -> int:
    """int: Shortcut for the return value of `self.dimension()`."""
    total_dimension = self.dimension()
    return total_dimension

1283 

@property
def features(self) -> list[AnyFeature]:
    """list[AnyFeature]: The features in `self.active_features`.

    The very same list object is returned, not a copy.

    """
    active = self.active_features
    return active

1287 

@property
def select_Ca(self) -> np.ndarray:
    """np.ndarray: Result of selecting "name CA" on the topology of `traj`."""
    topology = self.traj.top
    return topology.select("name CA")

1291 

1292 def _check_indices(self, pair_inds: np.ndarray, pair_n: int = 2) -> np.ndarray: 

1293 """Ensure pairs are valid (shapes, all atom indices available?, etc.)""" 

1294 

1295 pair_inds = np.array(pair_inds).astype(dtype=int, casting="safe") 

1296 

1297 if pair_inds.ndim != 2: 

1298 raise ValueError("pair indices has to be a matrix.") 

1299 

1300 if pair_inds.shape[1] != pair_n: 

1301 raise ValueError(f"pair indices shape has to be (x, {pair_n}).") 

1302 

1303 if pair_inds.max() > self.traj.top.n_atoms: 

1304 raise ValueError( 

1305 f"index out of bounds: {pair_inds.max()}. Maximum atom index " 

1306 f"available: {self.traj.top.n_atoms}" 

1307 ) 

1308 

1309 return pair_inds 

1310 

def transform(self, p: Optional[tqdm] = None) -> Optional[np.ndarray]:
    """Calls the `transform()` methods of the accumulated features.

    Args:
        p (Optional[tqdm]): If an instance of tqdm is provided, it will be
            updated after every call of `feature.transform()`. If None is
            provided, no progress bar will be updated.

    Returns:
        Optional[np.ndarray]: A numpy array with the features in
            `self.active_features`, stacked along the feature dimension.
            If no features have been added, a warning is issued and None
            is returned.

    Raises:
        ValueError: If a `CustomFeature` does not return a numpy array,
            does not return a 2d array, or returns a different number of
            frames than `self.traj` contains.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    # if there are no features selected, there is nothing to compute
    if not self.active_features:
        warnings.warn(
            "You have not selected any features. Add features and call "
            "`transform` or `get_output` again.."
        )
        return None

    # otherwise, build feature vector.
    feature_vec = []

    for f in self.active_features:
        # perform sanity checks for custom feature input
        if isinstance(f, CustomFeature):
            vec = f.transform(self.traj)
            # validate the type before any ndarray method is used,
            # otherwise this check can never trigger
            if not isinstance(vec, np.ndarray):
                raise ValueError(
                    f"Your custom feature {f.describe()} did not return a numpy.ndarray!"
                )
            # NOTE: casting=safe raises in numpy>=1.9
            vec = vec.astype(np.float32, casting="safe")
            # 0-d arrays have no first axis, let them fall through to the
            # dimensionality check below
            if vec.ndim > 0 and vec.shape[0] == 0:
                vec = np.empty((0, f.dimension))

            if not vec.ndim == 2:
                raise ValueError(
                    f"Your custom feature {f.describe()} did not return a "
                    f"2d array. Shape was {vec.shape}"
                )
            if not vec.shape[0] == self.traj.xyz.shape[0]:
                raise ValueError(
                    f"Your custom feature {f.describe()} did not return as "
                    f"many frames, as it received. Input was {self.traj.xyz.shape[0]}, "
                    f"output was {vec.shape[0]}"
                )
        else:
            vec = f.transform().astype(np.float32)
        feature_vec.append(vec)

        if p is not None:
            p.update()

    # a single feature is returned as is, multiple ones are stacked
    if len(feature_vec) > 1:
        res = np.hstack(feature_vec)
    else:
        res = feature_vec[0]

    return res

1374 

def get_output(self, pbar: Optional[tqdm] = None) -> xr.Dataset:
    """Computes all active features and packs them into a dataset.

    Args:
        pbar (Optional[tqdm]): A progress bar that is used as a context
            manager and handed to `transform()`. If None, a new tqdm
            progress bar with one step per active feature is created.

    Returns:
        xr.Dataset: The return value of `unpack_data_and_feature`, called
            with this featurizer, its trajectory, and the output of
            `transform()`.

    """
    # Encodermap imports
    from encodermap.misc.xarray import unpack_data_and_feature

    if pbar is None:
        n_features = len(self.active_features)
        # mention the trajectory in the description, if it has a basename
        if self.traj.basename is not None:
            desc = (
                f"Getting output of {n_features} features for "
                f"{self.traj.basename}"
            )
        else:
            desc = f"Getting output of {n_features} features"
        pbar = tqdm(total=n_features, desc=desc)

    with pbar as progress:
        transformed = self.transform(p=progress)
    return unpack_data_and_feature(self, self.traj, transformed)

1394 

def describe(self) -> list[str]:
    """Collects the labels of all active features.

    Returns:
        list[str]: The concatenated `describe()` outputs of the features
            in `self.active_features`, in the order of the features.

    """
    labels = []
    for feature in self.active_features:
        labels.extend(feature.describe())
    return labels

1400 

1401 def __repr__(self) -> str: 

1402 # Standard Library Imports 

1403 import pprint 

1404 

1405 feat_str = pprint.pformat(self.describe()[:10])[:-1] + ", ...]" 

1406 return f"EncoderMap Featurizer with features:\n{feat_str}" 

1407 

1408 def __len__(self) -> int: 

1409 return len(self.active_features) 

1410 

def dimension(self) -> int:
    """Returns the summed `dimension` attributes of all active features."""
    total = 0
    for feature in self.active_features:
        total = total + feature.dimension
    return total

1413 

1414 

class Featurizer:
    """EncoderMap's featurization has drawn much inspiration from PyEMMA
    (https://github.com/markovmodel/PyEMMA).

    EncoderMap's Featurizer collects and computes collective variables (CVs).
    CVs are data that are aligned with MD trajectories on the frame/time axis.
    Trajectory data contains (besides the topology) an axis for atoms, and
    an axis for cartesian coordinate (x, y, z), so that a trajectory can be
    understood as an array with shape (n_frames, n_atoms, 3). A CV is an array
    that is aligned with the frame/time and has its own feature axis. If the
    trajectory in our example has 3 residues (MET, ALA, GLY), we can define
    6 dihedral angles along the backbone of this peptide. These angles are:

    * PSI1:   Between MET1-N  - MET1-CA - MET1-C  - ALA2-N
    * OMEGA1: Between MET1-CA - MET1-C  - ALA2-N  - ALA2-CA
    * PHI1:   Between MET1-C  - ALA2-N  - ALA2-CA - ALA2-C
    * PSI2:   Between ALA2-N  - ALA2-CA - ALA2-C  - GLY3-N
    * OMEGA2: Between ALA2-CA - ALA2-C  - GLY3-N  - GLY3-CA
    * PHI2:   Between ALA2-C  - GLY3-N  - GLY3-CA - GLY3-C

    Thus, the collective variable 'backbone-dihedrals' provides an array of
    shape (n_frames, 6) and is aligned with the frame/time axis of the trajectory.

    Instantiating this class never yields a `Featurizer` instance. Depending
    on the type of `traj`, either a `SingleTrajFeaturizer` or an
    `EnsembleFeaturizer` is returned.

    """

    def __new__(cls, traj: Union[SingleTraj, TrajEnsemble]):
        # Encodermap imports
        from encodermap.trajinfo.info_single import SingleTraj

        # everything that is not a SingleTraj is treated as an ensemble
        if not isinstance(traj, SingleTraj):
            return EnsembleFeaturizer(traj)
        return SingleTrajFeaturizer(traj)

1448 

1449 

class AddSingleFeatureMethodsToClass(type):
    """Metaclass that programmatically adds methods to the EnsembleFeaturizer.

    For every name in `_ADD_X_FUNCTION_NAMES`, a method with that name is
    attached to the created class. The method calls the method of the same
    name on one `SingleTrajFeaturizer` per topology and carries the
    docstring of the `SingleTrajFeaturizer` method.

    """

    def __new__(cls, name, bases, dct):  # pragma: no doccheck
        created = super().__new__(cls, name, bases, dct)

        for method_name in _ADD_X_FUNCTION_NAMES:
            # IMPORTANT: the name is bound as a keyword argument default.
            # A plain closure would be late-binding and every generated
            # method would see the last name of the loop.
            def add_X_func(self, *args, add_x_name=method_name, **kwargs) -> None:
                for top, trajs in self.trajs.trajs_by_top.items():
                    # fetch the featurizer of this topology or create it
                    if top in self.feature_containers:
                        container = self.feature_containers[top]
                    else:
                        container = SingleTrajFeaturizer(
                            trajs[0], delayed=self.delayed
                        )
                        self.feature_containers[top] = container

                    # let the single featurizer add the feature
                    getattr(container, add_x_name)(*args, **kwargs)
                    # the feature of interest is now the last active one
                    newest = container.active_features[-1]

                    # register it once per topology
                    registered = self.active_features.setdefault(top, [])
                    if newest not in registered:
                        registered.append(newest)

            # also add the docstring :)
            add_X_func.__doc__ = getattr(SingleTrajFeaturizer, method_name).__doc__
            setattr(created, method_name, add_X_func)
        return created

1494 

1495 

class DaskFeaturizerMeta(type):
    """Metaclass that programmatically adds `add_*` methods to a class.

    For every name in `_ADD_X_FUNCTION_NAMES` and for 'add_list_of_feats',
    a method with that name is attached to the created class. The method
    forwards its arguments to the method of the same name on `self.feat`
    and always passes `delayed=True`. The docstrings are copied from the
    `SingleTrajFeaturizer` methods.

    """

    def __new__(cls, name, bases, dct):
        x = super().__new__(cls, name, bases, dct)

        for add_X_function_name in _ADD_X_FUNCTION_NAMES + ["add_list_of_feats"]:
            # the name is bound as a keyword argument default to prevent
            # python from late-binding the loop variable
            def add_X_func(self, *args, add_x_name=add_X_function_name, **kwargs):
                # the wrapped featurizer has to be set up for delayed work
                assert self.feat.delayed, (
                    f"Programmatically added `add_X_func` got a featurizer with a "
                    f"wrong `delayed` variable: {id(self.feat)=} {self.feat.delayed=}"
                )
                # call the wrapped featurizer's add function
                getattr(self.feat, add_x_name)(*args, delayed=True, **kwargs)

            add_X_func.__doc__ = getattr(
                SingleTrajFeaturizer,
                add_X_function_name,
            ).__doc__
            setattr(x, add_X_function_name, add_X_func)
        return x

1516 

1517 

1518class EnsembleFeaturizer(metaclass=AddSingleFeatureMethodsToClass): 

1519 """The EnsembleFeaturizer is a container of multiple SinlgeTrajFeaturizer. 

1520 

1521 The `SingleTrajFeaturizer` are collected in a dict with the topologies 

1522 of the sub-ensembles as keys. 

1523 

1524 """ 

1525 

1526 def __init__(self, trajs: TrajEnsemble, delayed: bool = False) -> None: 

1527 """Instantiates the `EnsembleFeaturizer`. 

1528 

1529 Args: 

1530 trajs (TrajEnsmble): The `TrajEnsemble` to featurizer. 

1531 delayed (bool): Whether using dask to calculate features, or just do 

1532 a regular featurization. 

1533 

1534 """ 

1535 self.trajs = trajs 

1536 self.delayed = delayed 

1537 self.active_features = {} 

1538 self.feature_containers = {} 

1539 self.ensemble = False 

1540 self._n_custom_features = 0 

1541 self._custom_feature_ids = [] 

1542 

def describe(self) -> dict[md.Topology, list[str]]:
    """Returns the labels of the feature output.

    Returns:
        dict[md.Topology, list[str]]: A dict where the keys are the keys
            of `self.feature_containers` (the topologies) and the values
            are the `describe()` outputs of the respective containers.

    """
    return {
        top: container.describe()
        for top, container in self.feature_containers.items()
    }

1556 

1557 def __len__(self) -> int: 

1558 lengths = [len(f) for f in self.feature_containers.values()] 

1559 assert all( 

1560 [lengths[0] == length for length in lengths] 

1561 ), f"This `{self.__class__.__name__}` has uneven features per topology." 

1562 if len(lengths) < 1: 

1563 return 0 

1564 return lengths[0] 

1565 

1566 def _add_feature( 

1567 self, f: AnyFeature, top: md.Topology, trajs: TrajEnsemble 

1568 ) -> None: 

1569 assert f.delayed == self.delayed, ( 

1570 f"In-memory featurizer {self.__class__} unexpectedly got a delayed " 

1571 f"feature {f}. {f.delayed=} {self.delayed=}" 

1572 ) 

1573 if top in self.feature_containers: 

1574 feat = self.feature_containers[top] 

1575 else: 

1576 feat = SingleTrajFeaturizer(trajs[0], delayed=self.delayed) 

1577 self.feature_containers[top] = feat 

1578 feat._add_feature(f) 

1579 self.active_features.setdefault(top, []).append(f) 

1580 

def add_custom_feature(self, feature: AnyFeature) -> None:
    """Adds a user-provided feature to the containers of all topologies.

    Features whose class is named 'CustomFeature' and which don't have a
    `name` attribute yet get an `id` (if they have none, or if it is None)
    and are named 'CustomFeature_{id}'. For instances of all other
    classes, the class name is assigned as `name`, if possible.
    Afterwards, the feature is added once per topology in
    `self.trajs.trajs_by_top`.

    Args:
        feature (AnyFeature): The feature to add.

    Raises:
        AssertionError: If the feature brings its own `id` and that id is
            already listed in `self._custom_feature_ids`, or if the
            feature is not recognized as a custom feature.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    # decide on feature's id
    if feature.__class__.__name__ == "CustomFeature":
        if not hasattr(feature, "name"):
            if not hasattr(feature, "id"):
                # no id at all: use the running counter
                feature.id = self._n_custom_features
                self._custom_feature_ids.append(self._n_custom_features)
                self._n_custom_features += 1
            elif feature.id is None:
                # id attribute exists but is unset: use the running counter
                feature.id = self._n_custom_features
                self._custom_feature_ids.append(self._n_custom_features)
                self._n_custom_features += 1
            else:
                # NOTE(review): a user-provided id is only checked for
                # collisions here, it is not appended to
                # `self._custom_feature_ids` - confirm that this is intended
                assert feature.id not in self._custom_feature_ids, (
                    f"A CustomFeature with the id {feature.id} already exists. "
                    f"Please change the id of your CustomFeature."
                )
            assert (
                isinstance(feature, CustomFeature)
                or issubclass(feature.__class__, CustomFeature)
                or hasattr(feature, "_is_custom")
            )
            feature.name = f"CustomFeature_{feature.id}"
    else:
        # best effort: classes that don't allow setting `name` are skipped
        try:
            feature.name = feature.__class__.__name__
        except AttributeError:
            pass

    # add
    for top, trajs in self.trajs.trajs_by_top.items():
        self._add_feature(feature, top, trajs)

1616 

def add_list_of_feats(
    self,
    which: Union[Literal["all"], Sequence[str]] = "all",
    ensemble: bool = False,
    deg: bool = False,
    omega: bool = True,
    check_aas: bool = True,
    periodic: bool = True,
    delayed: bool = False,
) -> None:
    """Adds a set of predefined features to the Featurizer.

    The argument `which` can be 'all', 'full', or a list of the following
    strings:
        * 'AllCartesians': Cartesian coordinates of all atoms with
            shape (n_frames, n_atoms, 3).
        * 'AllBondDistances': Bond distances of all bonds recognized by
            mdtraj. Use top = md.Topology.from_openmm() if mdtraj does not
            recognize all bonds.
        * 'CentralCartesians': Cartesians of the N, C, CA atoms in the
            backbone with shape (n_frames, n_residues * 3, 3).
        * 'CentralBondDistances': The bond distances of the N, C, CA bonds
            with shape (n_frames, n_residues * 3 - 1).
        * 'CentralAngles': The angles between the backbone bonds with shape
            (n_frames, n_residues * 3 - 2).
        * 'CentralDihedrals': The dihedrals between the backbone atoms
            (omega, phi, psi). With shape (n_frames, n_residues * 3 - 3).
        * 'SideChainCartesians': Cartesians of the sidechain-atoms.
            Starting with CB, CG, ...
        * 'SideChainBondDistances': Bond distances between the
            sidechain atoms. starting with the CA-CG bond.
        * 'SideChainAngles': Angles between sidechain atoms. Starting with
            the C-CA-CB angle.
        * 'SideChainDihedrals': Dihedrals of the sidechains (chi1, chi2, chi3).
    If 'all' is provided for `which`, the CentralCartesians,
    CentralBondDistances, CentralAngles, CentralDihedrals, and
    SideChainDihedrals will be added. If 'full' is provided, the
    SideChainCartesians, SideChainAngles, AllCartesians, and
    SideChainBondDistances are added on top of that.

    Args:
        which (Union[Literal["all"], Sequence[str]]): Either add 'all'
            features or a list of features. See above for possible
            features. Defaults to 'all'.
        ensemble (bool): Whether the trajs in this class belong to an ensemble.
            This implies that they contain either the same topology or are
            very similar (think wt, and mutant). Setting this option True will
            try to match the CVs of the trajs onto the same dataset.
            If a VAL residue has been replaced by LYS in the mutant,
            the number of sidechain dihedrals will increase. The CVs of the
            trajs with VAL will thus contain some NaN values. Defaults to False.
        deg (bool): Whether the output should be formatted in degrees.
            Defaults to False.
        omega (bool): Whether to include the omega angles of the backbone.
            Defaults to True.
        check_aas (bool): Whether to check if all residues in top are
            known. Helps with custom topology to not skip unknown
            custom/non-natural amino acids. Defaults to True.
        periodic (bool): Forwarded to the constructors of the features
            that make use of it. Defaults to True.
        delayed (bool): Forwarded to the constructors of the features.
            Defaults to False.

    Raises:
        Exception: If `which` contains unrecognized strings, or if a
            feature class has an unknown combination of the `_use_periodic`,
            `_use_angle`, and `_use_omega` class attributes.
        AssertionError: If `which` is a list or tuple containing non-str
            items, or if the topologies end up with different numbers
            of features.

    """
    self.ensemble = ensemble
    recognized_str = list(UNDERSCORE_MAPPING.keys()) + list(
        UNDERSCORE_MAPPING.values()
    )
    for top, trajs in self.trajs.trajs_by_top.items():
        # `which` is normalized to a list during the first iteration
        if isinstance(which, str):
            if which == "full":
                which = [
                    "CentralCartesians",
                    "CentralBondDistances",
                    "CentralAngles",
                    "CentralDihedrals",
                    "SideChainDihedrals",
                    "SideChainCartesians",
                    "SideChainAngles",
                    "AllCartesians",
                    "SideChainBondDistances",
                ]
            # this has to be an elif, otherwise the list created for 'full'
            # falls through into the else branch and is rejected
            elif which == "all":
                which = [
                    "CentralCartesians",
                    "CentralBondDistances",
                    "CentralAngles",
                    "CentralDihedrals",
                    "SideChainDihedrals",
                ]
            else:
                if which not in recognized_str:
                    raise Exception(
                        f"Recognized arguments to which are 'all' or any of the "
                        f"following: {recognized_str}. The str you provided {which} "
                        f"did not match any."
                    )
                which = [which]
        elif isinstance(which, (list, tuple)):
            assert all([isinstance(i, str) for i in which]), (
                f"The argument `which` needs to be provided a sequence of str. "
                f"You have offending types in this argument."
            )
            diff = set(which) - set(recognized_str)
            if diff:
                raise Exception(
                    f"One or more of the str in `which` are not recognized. "
                    f"The argument `which` needs to be a sequence containing any "
                    f"of the following: {recognized_str}. The unrecognized str are: "
                    f"{diff}."
                )

        # add the features
        # the _use_omega and _use_angle class attrs are added by the
        # `FeatureMeta` in `features.py` by inspecting a `Feature` subclass'
        # call signature
        for cf in which:
            if cf in UNDERSCORE_MAPPING:
                cf = UNDERSCORE_MAPPING[cf]
            feature = getattr(features, cf)
            if (
                not feature._use_periodic
                and not feature._use_angle
                and not feature._use_omega
            ):
                feature = feature(
                    trajs[0],
                    check_aas=check_aas,
                    generic_labels=ensemble,
                    delayed=delayed,
                )
            elif (
                feature._use_periodic
                and not feature._use_angle
                and not feature._use_omega
            ):
                feature = feature(
                    trajs[0],
                    check_aas=check_aas,
                    generic_labels=ensemble,
                    periodic=periodic,
                    delayed=delayed,
                )
            elif (
                feature._use_periodic
                and feature._use_angle
                and not feature._use_omega
            ):
                feature = feature(
                    trajs[0],
                    deg=deg,
                    check_aas=check_aas,
                    generic_labels=ensemble,
                    periodic=periodic,
                    delayed=delayed,
                )
            elif (
                feature._use_periodic and feature._use_angle and feature._use_omega
            ):
                feature = feature(
                    trajs[0],
                    deg=deg,
                    omega=omega,
                    check_aas=check_aas,
                    generic_labels=ensemble,
                    periodic=periodic,
                    delayed=delayed,
                )
            else:
                raise Exception(
                    f"Unknown combination of `_use_angle` and `_use_omega` in "
                    f"class attributes of {feature=}"
                )

            # don't add the same feature twice to a topology
            if top in self.active_features:
                if feature in self.active_features[top]:
                    warnings.warn(
                        f"Tried to re-add the same feature {feature.__class__.__name__} to "
                        f"{self.active_features=}"
                    )
                    continue
                else:
                    self.active_features[top].append(feature)
            else:
                self.active_features.setdefault(top, []).append(feature)

            # fetch the featurizer of this topology or create it
            if top in self.feature_containers:
                f = self.feature_containers[top]
            else:
                f = SingleTrajFeaturizer(trajs[0], delayed=self.delayed)
                self.feature_containers[top] = f
            f._add_feature(feature)

    # after all is done, all tops should contain the same number of feats
    no_of_feats = set([len(v) for v in self.active_features.values()])
    assert len(no_of_feats) == 1, (
        f"I was not able to add the same number of features to the respective "
        f"topologies:\n{self.active_features=}\n{self.feature_containers=}"
    )

1802 

@property
def features(self) -> list[AnyFeature]:
    """Returns the active features of all topologies as one flat list.

    The features are taken from the values of `self.active_features` in the
    iteration order of that dict, i.e., the features of the first topology
    come first, followed by the features of the second topology, and so on.

    Returns:
        list[AnyFeature]: A flat list of feature instances. The list is
            empty if no features have been added yet.

    """
    feats = []
    # Iterate over the values only. Iterating over `.items()` yields
    # `(topology, feature_list)` tuples and would put the topology and the
    # nested list into the result instead of the features themselves.
    for top_features in self.active_features.values():
        feats.extend(top_features)
    return feats

1809 

def transform(
    self,
    traj: Union[SingleTraj, md.Trajectory],
    outer_p: Optional[Union[tqdm, rich.progress.Progress]] = None,
    inner_p: Optional[Union[tqdm, rich.progress.Progress]] = None,
    inner_p_id: Optional[int] = None,
) -> np.ndarray:
    """Applies the features to the trajectory data.

    Args:
        traj (Union[SingleTraj, md.Trajectory]): The trajectory which provides
            the data. Make sure, that the topology of this traj matches the
            topology used to initialize the features.
        outer_p (Optional[Union[tqdm, rich.progress.Progress]]): An object
            that supports `.update()` to advance a progress bar. The
            `rich.progress.Progress` is special, as it needs additional code
            to advance the multi-file progress bar displayed by the
            `EnsembleFeaturizer`. The `outer_p` represents the overall
            progress.
        inner_p (Optional[Union[tqdm, rich.progress.Progress]]): Same as `outer_p`,
            but the `inner_p` represents the progress per file.
        inner_p_id (Optional[int]): The id of the `inner_p`, which needs to be
            provided, if `outer_p` and `inner_p` are instances of
            `rich.progress.Progress`. If it is None, `traj.traj_num + 1`
            is used.

    Returns:
        np.ndarray: The output of all features of `traj.top`, stacked
            horizontally with `np.hstack`. If only one feature is active,
            its output is returned as is.

    Raises:
        ValueError: If a `CustomFeature` does not return a `np.ndarray`, or
            returns a different number of frames than `traj` contains.

    """
    # Encodermap imports
    from encodermap.loading.features import CustomFeature

    feature_vec = []
    active = self.active_features[traj.top]

    # the trajectory data is the same for every feature, so it is read once
    xyz = traj.xyz
    if traj._have_unitcell:
        unitcell_vectors = traj.unitcell_vectors
        unitcell_info = np.hstack([traj.unitcell_lengths, traj.unitcell_angles])
    else:
        unitcell_vectors = None
        unitcell_info = None

    for f in active:
        if isinstance(f, CustomFeature):
            raw = f.transform(
                traj,
                xyz,
                unitcell_vectors,
                unitcell_info,
            )
            # check the type before calling `.astype`, otherwise a wrong
            # return type surfaces as an AttributeError and this message
            # is never shown
            if not isinstance(raw, np.ndarray):
                raise ValueError(
                    "Your custom feature %s did not return"
                    " a numpy.ndarray!" % str(f.describe())
                )
            vec = raw.astype(np.float32, casting="safe")
            if vec.shape[0] == 0:
                vec = np.empty((0, f.dimension))

            if vec.ndim == 1:
                vec = np.expand_dims(vec, -1)
            if vec.ndim == 3:
                # flatten everything but the frame axis
                vec = vec.reshape(xyz.shape[0], -1)
                f.atom_feature = True
            if not vec.shape[0] == xyz.shape[0]:
                raise ValueError(
                    "Your custom feature %s did not return"
                    " as many frames as it received! "
                    "Input was %i, output was %i"
                    % (str(f.describe()), xyz.shape[0], vec.shape[0])
                )
        else:
            vec = f.transform(
                xyz,
                unitcell_vectors,
                unitcell_info,
            ).astype(np.float32)
        feature_vec.append(vec)

        if outer_p is not None:
            if isinstance(outer_p, rich.progress.Progress):
                # task 0 is used for the overall progress
                outer_p.update(0, advance=1)
            else:
                outer_p.update()

        # advance the per-file bar via `inner_p` itself, so that it also
        # works when no `outer_p` was provided
        if inner_p is not None:
            if isinstance(inner_p, rich.progress.Progress):
                if inner_p_id is None:
                    inner_p_id = traj.traj_num + 1
                inner_p.update(inner_p_id, advance=1)
            else:
                inner_p.update()

    if len(feature_vec) > 1:
        res = np.hstack(feature_vec)
    else:
        res = feature_vec[0]

    # sleep half a second to let the progbars catch up
    if outer_p is not None or inner_p is not None:
        time.sleep(0.5)

    return res

1907 

def n_features(self) -> int:
    """Returns the number of features per topology.

    All topologies in `self.active_features` need to have the same number
    of features.

    Returns:
        int: The number of features of every topology, or 0 if
            `self.active_features` is empty.

    Raises:
        AssertionError: If the topologies have different numbers of features.

    """
    counts = [len(feats) for feats in self.active_features.values()]
    if not counts:
        return 0

    length = counts[0]
    # compare every topology against the first one
    if any(count != length for count in counts[1:]):
        _debug = "\n\n".join(
            f"Top: {key}\nValue: {val}"
            for key, val in self.active_features.items()
        )
        # raised explicitly, so the check also runs with `python -O`
        raise AssertionError(
            f"There are different number of features per topology in "
            f"`self.active_features`. These features can't be transformed. "
            f"Here are the features by topology:\n{_debug}"
        )
    return length

1923 

def get_output(
    self,
    pbar: Optional[tqdm] = None,
) -> xr.Dataset:
    """Transforms all trajectories and combines the results in one dataset.

    If no `pbar` is provided and `rich` can be imported, a
    `rich.progress.Progress` with one overall task and one task per
    trajectory is displayed. Otherwise, `pbar` (which can be None) is
    handed to `transform()` as the overall progress bar.

    Args:
        pbar (Optional[tqdm]): A progress bar, that is advanced once per
            computed feature. Defaults to None.

    Returns:
        xr.Dataset: The per-trajectory datasets combined by
            `format_output()`. If no features were added, a message is
            printed and None is returned.

    """
    # Encodermap imports
    from encodermap.misc.xarray import unpack_data_and_feature

    if self.active_features == {}:
        print(f"First add some features before calling `get_output()`.")
        return

    n_features = self.n_features()

    try:
        # Third Party Imports
        from rich.progress import Progress
    except ModuleNotFoundError:
        _rich_installed = False
    else:
        _rich_installed = True

    def _dataset_of(traj, *progress_args, **progress_kwargs):
        # transform first, then look up the container of the topology
        out = self.transform(traj, *progress_args, **progress_kwargs)
        return unpack_data_and_feature(
            self.feature_containers[traj.top], traj, out
        )

    if pbar is not None or not _rich_installed:
        DSs = [_dataset_of(traj, pbar, None) for traj in self.trajs]
        return format_output(DSs)

    with Progress() as progress:
        # the first task tracks the overall progress
        progress.add_task(
            description=(
                f"Getting output for an ensemble containing "
                f"{self.trajs.n_trajs} trajs"
            ),
            total=n_features * self.trajs.n_trajs,
        )
        # followed by one task per trajectory
        for traj in self.trajs:
            desc = traj.basename
            if traj.basename == "trajs":
                desc = f"trajectory {traj.traj_num}"
            progress.add_task(
                description=(
                    f"Getting output of {n_features} features for " f"{desc}"
                ),
                total=n_features,
            )
        DSs = [
            _dataset_of(traj, progress, progress, inner_p_id=task_id)
            for task_id, traj in enumerate(self.trajs, start=1)
        ]
    return format_output(DSs)

1982 

1983 

def format_output(
    datasets: Sequence[xr.Dataset],
) -> xr.Dataset:
    """Concatenates multiple xr.Datasets and keeps coordinates in correct order.

    The datasets are concatenated along `traj_num`, missing values are filled
    with np.nan. Afterward, the labels of the feature axes of all EncoderMap
    data variables (the ones named like the keys of `UNDERSCORE_MAPPING`) are
    collected, made unique, sorted, and used to reindex the combined dataset.

    Args:
        datasets (Sequence[xr.Dataset]): The datasets to combine. Every
            dataset needs to contain exactly one `traj_num`, and the
            `traj_num` values need to be unique among the datasets.

    Returns:
        xr.Dataset: The output dataset.

    """
    # Encodermap imports
    from encodermap.trajinfo.trajinfo_utils import trajs_combine_attrs

    def _numbers(label):
        # all integer substrings of a label, in order of appearance
        return re.findall(r"\d+", label)

    def _by_last_int(label):
        return int(_numbers(label)[-1])

    def _side_dihedral_order(label):
        found = _numbers(label)
        return int(found[-1]), int(found[0])

    def _allatom_order(label):
        leading = 0 if label.endswith("c") else 1
        return (leading, *map(int, _numbers(label)[::-1]))

    def _central_dihedral_order(label):
        if "PSI" in label:
            rank = 1
        elif "OMEGA" in label:
            rank = 2
        else:
            rank = 3
        return int(_numbers(label)[-1]), rank

    # make sure that all traj-nums are unique
    traj_nums = [ds.traj_num.values for ds in datasets]
    assert all([i.size == 1 for i in traj_nums])
    traj_nums = np.array(traj_nums)[:, 0]
    assert len(traj_nums) == len(np.unique(traj_nums)), (
        f"The sequence of datasets provided for arg `datasets` contains multiple "
        f"traj_nums: {traj_nums=}"
    )

    # create a large dataset
    out = xr.concat(
        datasets,
        data_vars="all",
        dim="traj_num",
        fill_value=np.nan,
        combine_attrs=trajs_combine_attrs,
    )

    # collect the labels of the EncoderMap dataarrays per feature axis
    encodermap_dataarrays = list(UNDERSCORE_MAPPING)
    all_labels = {}
    for name, da in out.data_vars.items():
        if name in encodermap_dataarrays:
            feature_axis = da.attrs["feature_axis"]
            all_labels.setdefault(feature_axis, []).extend(
                da.coords[feature_axis].values
            )

    # side dihedrals, central dihedrals, and allatom labels need some special
    # sorting, the other generic labels can be sorted by their last int
    special_orders = {
        "SIDE_DIHEDRALS": _side_dihedral_order,
        "ALLATOM": _allatom_order,
        "CENTRAL_DIHEDRALS": _central_dihedral_order,
    }
    ordered_labels = {
        axis: sorted(np.unique(labels), key=special_orders.get(axis, _by_last_int))
        for axis, labels in all_labels.items()
    }
    return out.reindex(ordered_labels, fill_value=np.nan)

2069 

2070 

2071class DaskFeaturizer(metaclass=DaskFeaturizerMeta): 

2072 """Container for `SingleTrajFeaturizer` and `EnsembleFeaturizer` 

2073 that implements delayed transforms. 

2074 

2075 The DaskFeaturizer is similar to the other two featurizer classes and 

2076 mostly implements the same API. However, instead of computing the 

2077 transformations using in-memory computing, it prepares a `xarray.Dataset`, 

2078 which contains `dask.Arrays`. This dataset can be lazily and distributively 

2079 evaluated using dask.distributed clients and clusters. 

2080 

2081 """ 

2082 

def __init__(
    self,
    trajs: Union[SingleTraj, TrajEnsemble],
    n_workers: Union[str, int] = "cpu-2",
    client: Optional[Client] = None,
) -> None:
    """Instantiates the wrapped featurizer and connects to a dask client.

    Args:
        trajs (Union[SingleTraj, TrajEnsemble]): The trajectory data. If the
            object has an `itertrajs` attribute, an `EnsembleFeaturizer` is
            used, otherwise a `SingleTrajFeaturizer`. Both are created with
            `delayed=True`.
        n_workers (Union[str, int]): The number of workers, that is used,
            when a new `Client` is created. Can be 'cpu-2' for the number of
            CPUs minus two (but at least one), or 'max' for the number of
            CPUs. Other values are handed to `Client` as they are.
            Defaults to 'cpu-2'.
        client (Optional[Client]): An existing client. If None, the result of
            `_get_global_client()` is used. If that is also None, a new
            `Client` is created. Defaults to None.

    """
    if not hasattr(trajs, "itertrajs"):
        self.feat = SingleTrajFeaturizer(trajs, delayed=True)
    else:
        self.feat = EnsembleFeaturizer(trajs, delayed=True)

    if n_workers == "cpu-2":
        # Standard Library Imports
        from multiprocessing import cpu_count

        # machines with one or two CPUs would otherwise request zero or a
        # negative number of workers
        n_workers = max(1, cpu_count() - 2)
    elif n_workers == "max":
        # Standard Library Imports
        from multiprocessing import cpu_count

        n_workers = cpu_count()

    dask.config.set(scheduler="processes")

    if client is None:
        self.client = _get_global_client()
    else:
        self.client = client
    if self.client is None:
        self.client = Client(n_workers=n_workers)
        print(
            f"Created dask scheduler. Access the dashboard via: "
            f"{self.client.dashboard_link}"
        )
    else:
        print(
            f"Using existing dask scheduler. Access the dashboard via: "
            f"{self.client.dashboard_link}"
        )

2124 

def add_custom_feature(self, feature):
    """Adds a custom feature to the wrapped featurizer.

    The `delayed` attribute of the feature is set to True, if it is missing
    or falsy. If the compute graph has already been built (the instance has
    a `dataset` attribute), a warning is issued and the graph is rebuilt.

    Args:
        feature: The feature, that is handed to
            `self.feat.add_custom_feature()`.

    """
    if not getattr(feature, "delayed", False):
        feature.delayed = True
    self.feat.add_custom_feature(feature)

    if not hasattr(self, "dataset"):
        return

    warnings.warn(
        f"The compute graph has already been built. I will rebuild the "
        f"graph and add the feature as a transformer. Subsequent "
        f"calls to `.get_output()` will include this feature."
    )
    self.build_graph()

2138 

def build_graph(
    self,
    traj: Optional[SingleTraj] = None,
    streamable: bool = False,
    return_delayeds: bool = False,
) -> None:
    """Prepares the dask graph.

    The result of `build_dask_xarray()` is stored in `self.dataset` and
    `self.variables`. If no features have been added, a message is printed
    and nothing is built.

    Args:
        traj (Optional[SingleTraj]): Handed to `build_dask_xarray()`.
            Defaults to None.
        streamable (bool): Handed to `build_dask_xarray()`. Defaults to False.
        return_delayeds (bool): Handed to `build_dask_xarray()`. `to_netcdf()`
            sets this to its `with_trajectories` argument. Defaults to False.

    """
    if self.feat.active_features in ({}, []):
        print(f"First add some features before calling `get_output()`.")
        return

    graph = build_dask_xarray(
        self,
        traj=traj,
        streamable=streamable,
        return_delayeds=return_delayeds,
    )
    self.dataset, self.variables = graph

2162 

def to_netcdf(
    self,
    filename: Union[Path, str],
    overwrite: bool = False,
    with_trajectories: bool = False,
) -> str:
    """Saves the dask tasks to a NetCDF4 formatted HDF5 file.

    Args:
        filename (Union[Path, str]): The filename to be used.
        overwrite (bool): Whether to overwrite the existing filename.
        with_trajectories (bool): Also save the trajectory data. The output
            file can be read with `encodermap.load(filename)` and rebuilds
            the trajectories complete with traj_nums, common_str, custom_top,
            and all CVs, that this featurizer calculates.

    Returns:
        str: Returns the filename of the created files.

    Raises:
        Exception: If the graph has already been built, or if `filename`
            exists and `overwrite` is False.

    """
    # Standard Library Imports
    from pathlib import Path

    filename = Path(filename)
    if "dataset" in vars(self):
        raise Exception(f"Graph already built.")

    # allows multiple writes to netcdf4 files
    def set_env():
        os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    self.client.run(set_env)

    if filename.is_file():
        if not overwrite:  # pragma: nocover
            raise Exception(
                f"File {filename} already exists. Set `overwrite=True` to overwrite."
            )
        filename.unlink()

    # build
    self.build_graph(return_delayeds=with_trajectories)

    if self.variables is None:
        # without delayed trajectory data, xarray can write the file itself
        self.dataset.to_netcdf(
            filename,
            format="NETCDF4",
            group="CVs",
            engine="h5netcdf",
            invalid_netcdf=False,
            compute=True,
        )
        return str(filename)

    # Third Party Imports
    import h5py
    from xarray import conventions
    from xarray.backends.api import (
        _finalize_store,
        _validate_attrs,
        _validate_dataset_names,
    )
    from xarray.backends.common import ArrayWriter
    from xarray.backends.h5netcdf_ import H5NetCDFStore

    # use xarrays's to_netcdf code and add saving of delayed coordinates, etc.
    _validate_dataset_names(self.dataset)
    _validate_attrs(self.dataset, invalid_netcdf=False)
    chunked = any(
        var.chunks is not None for var in self.dataset.variables.values()
    )
    store = H5NetCDFStore.open(
        filename=filename,
        mode="a",
        format="NETCDF4",
        autoclose=chunked,
    )
    writer = ArrayWriter()
    try:
        # create dicts of data to write to store
        variables, attrs = conventions.encode_dataset_coordinates(self.dataset)
        variables |= self.variables
        store.store(variables, attrs, set(), writer, None)
        store.close()
        writes = writer.sync(compute=False)
    finally:
        store.close()

    # this runs the computation and displays a progress
    finalize_task = dask.delayed(_finalize_store)(writes, store)
    finalize_task = finalize_task.persist()
    progress(finalize_task)
    finalize_task.compute()

    # afterward, we remove the unwanted groups starting with md from the .h5 file
    # they are artifacts of hijacking xarray's `to_netcdf`
    # we also move all keys that are not part of the traj coords
    md_keys = ["coordinates", "time", "cell_lengths", "cell_angles"]
    with h5py.File(filename, "a") as h5file:
        for name in [k for k in h5file.keys() if k.startswith("md")]:
            del h5file[name]
        for name in list(h5file.keys()):
            if all(marker not in name for marker in md_keys):
                h5file.move(name, f"CVs/{name}")

    # and add common_str, custom_top, etc.
    self.feat.trajs.save(fname=filename, CVs=False, only_top=True)
    return str(filename)

2274 

def get_output(
    self,
    make_trace: bool = False,
) -> xr.Dataset:
    """Computes the delayed dataset and returns the result.

    The graph is built via `build_graph()`, if the instance does not have a
    `dataset` attribute yet.

    Args:
        make_trace (bool): Tracing the dask execution is currently not
            implemented. Setting this to True raises an exception.
            Defaults to False.

    Returns:
        xr.Dataset: The result of `self.dataset.compute()`.

    Raises:
        NotImplementedError: If `make_trace` is True.
        Exception: If the computed dataset is empty.

    """
    if "dataset" not in vars(self):
        self.build_graph()

    if make_trace:
        raise NotImplementedError("Currently not able to trace dask execution.")

    computed = self.dataset.compute()
    if computed:
        return computed
    raise Exception(
        f"Computed dataset is empty. Maybe a computation failed in "
        f"the dask-delayed dataset: {self.dataset}"
    )

2308 

@property
def feature_containers(self) -> dict[md.Topology, SingleTrajFeaturizer]:
    """The `feature_containers` of the wrapped featurizer `self.feat`."""
    wrapped = self.feat
    return wrapped.feature_containers

2312 

@property
def active_features(
    self,
) -> Union[list[AnyFeature], dict[md.Topology, list[AnyFeature]]]:
    """The `active_features` of the wrapped featurizer `self.feat`."""
    wrapped = self.feat
    return wrapped.active_features

2318 

def __len__(self):
    """Returns the length of the wrapped featurizer `self.feat`."""
    wrapped = self.feat
    return len(wrapped)

2321 

def transform(
    self,
    traj_or_trajs: Optional[Union[SingleTraj, TrajEnsemble]] = None,
    *args,
    **kwargs,
) -> np.ndarray:
    """Forwards the call to the `transform()` of the wrapped featurizer.

    Args:
        traj_or_trajs (Optional[Union[SingleTraj, TrajEnsemble]]): Handed to
            `self.feat.transform()` as the first argument. Defaults to None.
        *args: Additional positional arguments for `self.feat.transform()`.
        **kwargs: Additional keyword arguments for `self.feat.transform()`.

    Returns:
        np.ndarray: The return value of `self.feat.transform()`.

    """
    delegate = self.feat.transform
    return delegate(traj_or_trajs, *args, **kwargs)

2329 

def describe(self) -> list[str]:
    """Returns the result of the `describe()` of the wrapped featurizer."""
    description = self.feat.describe()
    return description

2332 

def dimension(self) -> int:
    """Returns the `dimension` attribute of the wrapped featurizer.

    The attribute is returned as it is and is not called.
    NOTE(review): if `self.feat.dimension` is a method and not a property,
    this returns the bound method instead of an int - confirm.

    """
    wrapped = self.feat
    return wrapped.dimension

2335 

def visualize(self) -> None:
    """Hands `self.dataset` to `dask.visualize()` and returns its result."""
    graph = self.dataset
    return dask.visualize(graph)