Coverage for encodermap/trajinfo/info_single.py: 11%

815 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/trajinfo/info_single.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Classes to work with ensembles of trajectories. 

23 

24The statistics of a protein can be better described by an ensemble of proteins, 

25rather than a single long trajectory. Treating a protein in such a way opens great 

26possibilities and changes the way one can treat molecular dynamics data. 

27Trajectory ensembles allow: 

28 * Faster convergence via adaptive sampling. 

29 * Better anomaly detection of unique structural states. 

30 

31This subpackage contains two classes which are containers of trajectory data. 

32The SingleTraj trajectory contains information about a single trajectory. 

33The TrajEnsemble class contains information about multiple trajectories. This adds 

34a new dimension to MD data. The time and atom dimension are already established. 

35Two frames can be appended along the time axis to get a trajectory with multiple 

36frames. If they are appended along the atom axis, the new frame contains the 

37atoms of these two. The trajectory works in a similar fashion. Adding two trajectories 

38along the trajectory axis returns a trajectory ensemble, represented as a TrajEnsemble 

39class in this package. 

40 

41See also: 

42 http://statisticalbiophysicsblog.org/?p=92 

43 

44""" 

45 

46################################################################################ 

47# Imports 

48################################################################################ 

49 

50 

51# Future Imports at the top 

52from __future__ import annotations 

53 

54# Standard Library Imports 

55import errno 

56import os 

57import re 

58import sys 

59import warnings 

60from copy import deepcopy 

61from io import StringIO 

62from pathlib import Path 

63 

64# Third Party Imports 

65import numpy as np 

66import tables 

67from optional_imports import _optional_import 

68 

69# Encodermap imports 

70from encodermap._typing import CanBeIndex, CustomAAsDict 

71from encodermap.misc.misc import _TOPOLOGY_EXTS 

72from encodermap.misc.xarray_save_wrong_hdf5 import save_netcdf_alongside_mdtraj 

73from encodermap.trajinfo.load_traj import _load_pdb_from_uri, _load_traj 

74 

75 

76################################################################################ 

77# Optional Imports 

78################################################################################ 

79 

80 

81md = _optional_import("mdtraj") 

82mda = _optional_import("MDAnalysis") 

83h5 = _optional_import("h5py") 

84xr = _optional_import("xarray") 

85h5py = _optional_import("h5py") 

86pd = _optional_import("pandas") 

87 

88 

89################################################################################ 

90# Typing 

91################################################################################ 

92 

93 

94# Standard Library Imports 

95from collections.abc import Callable, Iterable, Sequence 

96from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload 

97 

98 

99if TYPE_CHECKING: # pragma: no cover 

100 # Third Party Imports 

101 import h5py as h5 

102 import MDAnalysis as mda 

103 import mdtraj as md 

104 import xarray as xr 

105 

106 # Encodermap imports 

107 from encodermap.trajinfo.info_all import TrajEnsemble 

108 from encodermap.trajinfo.trajinfo_utils import CustomTopology, SingleTrajFeatureType 

109 

110 

111################################################################################ 

112# Globals 

113################################################################################ 

114 

115 

# Public API of encodermap/trajinfo/info_single.py (what ``import *`` exports).
__all__: list[str] = ["SingleTraj"]

117 

118 

119################################################################################ 

120# Utilities 

121################################################################################ 

122 

123 

class MixedUpInputs(Exception):
    """Raised when the trajectory and topology inputs look swapped.

    This is signalled when the file given as ``traj`` carries a topology
    extension but differs from the file given as ``top``.
    """

128 

129 

class Capturing(list):
    """Context manager (and list) that collects everything printed to stdout.

    While the ``with`` block runs, ``sys.stdout`` is replaced by an in-memory
    buffer. On exit the original stream is put back, and the captured text is
    appended to this list, one entry per line.

    Examples:
        >>> # write a function
        >>> def my_func(arg='argument'):
        ...     print(arg)
        ...     return('fin')
        >>> # use capturing context manager
        >>> with Capturing() as output:
        ...     my_func('new_argument')
        >>> print(output)
        ['new_argument', "'fin'"]

    """

    def __enter__(self):
        buffer = StringIO()
        self._stdout = sys.stdout
        self._stringio = buffer
        sys.stdout = buffer
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        # drop the buffer reference to free its memory
        del self._stringio
        sys.stdout = self._stdout
        self.extend(captured_text.splitlines())

155 

156 

157def _hash_numpy_array(x: np.ndarray) -> int: 

158 hash_value = hash(x.shape) 

159 hash_value ^= hash(x.strides) 

160 hash_value ^= hash(x.data.tobytes()) 

161 return hash_value 

162 

163 

164################################################################################ 

165# Classes 

166################################################################################ 

167 

168 

class SingleTrajFsel:
    """Index a trajectory by its original frame numbers (the ``fsel[]`` accessor).

    ``traj[i]`` selects the i-th frame of the trajectory as it currently is.
    ``traj.fsel[i]`` instead selects the frame(s) whose entry in ``traj.id``
    equals ``i``.
    """

    def __init__(self, other):
        """Store the trajectory that will be indexed.

        Args:
            other: The trajectory to index. It has to provide ``traj_num``,
                ``id`` and ``_frames``, and support ``other[positions]``.

        """
        self.other = other

    def __getitem__(self, item: CanBeIndex) -> SingleTraj:
        """Select frames by their original frame number.

        Args:
            item (CanBeIndex): A single frame number (python or numpy integer)
                or a list / numpy array of frame numbers. Slices are not
                supported yet.

        Returns:
            SingleTraj: ``self.other`` indexed with the positions of all
                matching frames. Frames are returned in the order in which
                they appear in the trajectory, not in the order of ``item``.

        Raises:
            NotImplementedError: If ``item`` is a slice.
            ValueError: If ``item`` has an unsupported type, or if no frame
                matches.

        """
        # Validate the type before touching ``self.other.id``, so an invalid
        # key does not trigger any work on the trajectory.
        if isinstance(item, slice):
            raise NotImplementedError("Currently can't index frames with slice.")
        is_scalar = isinstance(item, (int, np.integer))
        if not is_scalar and not isinstance(item, (list, np.ndarray)):
            raise ValueError(
                f"The `fsel[]` method of `SingleTraj` takes {CanBeIndex} types, "
                f"but {type(item)} was provided."
            )

        # Without a traj_num, ``id`` is a 1D array of frame numbers. With a
        # traj_num, it is 2D and column 1 holds the frame numbers.
        if self.other.traj_num is None:
            frame_ids = self.other.id
        else:
            frame_ids = self.other.id[:, 1]

        if is_scalar:
            idx = np.where(frame_ids == item)[0]
        else:
            # np.isin replaces the deprecated np.in1d (same result for 1D input)
            idx = np.where(np.isin(frame_ids, np.asarray(item)))[0]

        if len(idx) == 0:
            raise ValueError(
                f"No frames with frame index {item} in trajectory {self.other} "
                f"with frames: {self.other._frames}"
            )
        return self.other[idx]

204 

205 

206class SingleTraj: 

207 """This class contains the info about a single trajectory. 

208 

209 This class contains many of the attributes and methods of 

210 `mdtraj.Trajectory <https://www.mdtraj.org/1.9.8.dev0/api/generated/mdtraj.Trajectory.html>`_. 

211 It is meant to be used as a standalone single trajectory or in an ensemble

212 defined in the :obj:`encodermap.trajinfo.info_all.TrajEnsemble` class. 

213 Unlike the standard

214 `mdtraj.Trajectory <https://www.mdtraj.org/1.9.8.dev0/api/generated/mdtraj.Trajectory.html>`_ 

215 this class loads the MD data only when needed. The location of the file(s) 

216 and other attributes like indices (int, list[int], numpy.ndarray, slice) 

217 are stored until the traj is accessed via the 

218 :attr:`SingleTraj.traj` attribute. The returned traj is a 

219 `mdtraj.Trajectory <https://www.mdtraj.org/1.9.8.dev0/api/generated/mdtraj.Trajectory.html>`_ 

220 with the correct number of frames in the correct sequence. 

221 

222 Besides MD data, this class keeps track of your collective variables. Oftentimes 

223 the raw xyz data of a trajectory is not needed for understanding the conformation 

224 and suitable CVs are selected to represent a protein via internal coordinates 

225 (torsions, pairwise distances, etc.). This class keeps track of your CVs.

226 Whether you call them ``'highd'`` or ``'torsions'``, this class keeps 

227 track of everything and returns the values when you need them. 

228 

229 :obj:`SingleTraj` supports fancy indexing, so you can extract one or more frames 

230 from a Trajectory as a separate trajectory. For example, to form a 

231 trajectory with every other frame, you can slice with ``traj[::2]``. 

232 

233 Note: 

234 SingleTraj uses the nanometer, degree & picosecond unit system. 

235 

236 Attributes: 

237 backend (str): Current state of loading. If ``backend == 'no_load'`` xyz 

238 data will be loaded from disk, if accessed. If ``backend == 'mdtraj'``, 

239 the data is already in RAM. 

240 common_str (str): Substring of :attr:`traj_file` and :attr:`top_file`. 

241 Used to group multiple trajectory and topology files. 

242 If ``traj files=['protein1_traj1.xtc', 'protein1_traj2.xtc']`` 

243 have the same topolgy stored in a file called ``'protein1.pdb'``, you 

244 can load them with ``common_str='protein1'`` together with more ``.xtc`` 

245 and ``.pdb`` files and these two ``.xtc`` files will use the correct 

246 ``.pdb`` file. 

247 index (Sequence[Union[None, int, list, numpy.ndarray, slice]]): A sequence of 

248 fancy slices of the trajectory. When file is loaded from disk, the 

249 fancy indexes will be applied one after the other. 

250 traj_num (int): Integer to identify a SingleTraj class in a TrajEnsemble class. 

251 traj_file (str): Trajectory file used to create this class. 

252 top_file (str): Topology file used to create this class. If a .h5 trajectory 

253 was used traj_file and top_file are identical. If a 

254 ``mdtraj.Trajectory`` was used to create SingleTraj, these strings are empty. 

255 

256 Examples: 

257 Load a pdb file with 14 frames from rcsb.org 

258 

259 >>> import encodermap as em 

260 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb") 

261 >>> traj # doctest: +ELLIPSIS 

262 <encodermap.SingleTraj object... 

263 >>> traj.n_frames 

264 14 

265 

266 Providing common_str sets this attribute. 

267 

268 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", common_str="1GHC") 

269 >>> traj.common_str 

270 '1GHC' 

271 

272 Indexing using integers returns a SingleTraj with only one frame. 

273 

274 >>> frame = traj[5] 

275 >>> frame.n_frames 

276 1 

277 

278 Indexing can also use lists of integers. 

279 

280 >>> subset = traj[[0, 1, 5]] 

281 >>> subset.n_frames 

282 3 

283 

284 Further indexing this ``subset`` uses the current trajectory 'as is'. Although 

285 frame 0, 1, and 5 have been extracted from ``traj``, we get frame 5 from 

286 ``subset`` by indexing with 2. 

287 

288 >>> frame = subset[2] 

289 >>> frame.id 

290 array([5]) 

291 

292 Indexing using the original frame indices from the file is done using the 

293 ``fsel[]`` accessor. 

294 

295 >>> frame = subset.fsel[5] 

296 >>> frame.id 

297 array([5]) 

298 

299 Advanced slicing 

300 

301 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")[-1:7:-2] 

302 >>> [frame.id[0] for frame in traj] 

303 [13, 11, 9] 

304 

305 The ``traj_num`` argument is mainly used in :obj:`encodermap.TrajEnsemble`, 

306 but can be provided manually. 

307 

308 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", traj_num=2) 

309 >>> traj.traj_num 

310 2 

311 

312 The argument ``basename_fn`` should be a callable, that takes a string 

313 and returns a string. 

314 

315 >>> from pathlib import Path 

316 >>> def my_basename_fn(filename): 

317 ... stem = str(Path(filename).stem) 

318 ... return "custom_" + stem 

319 >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", basename_fn=my_basename_fn) 

320 >>> traj.basename 

321 'custom_1GHC' 

322 

323 Build a trajectory ensemble from multiple :obj:`SingleTraj` objects. 

324 

325 >>> traj1 = em.SingleTraj("https://files.rcsb.org/view/1YUG.pdb") 

326 >>> traj2 = em.SingleTraj("https://files.rcsb.org/view/1YUF.pdb") 

327 >>> trajs = traj1 + traj2 

328 >>> print(trajs.n_trajs, trajs.n_frames, [traj.n_frames for traj in trajs]) 

329 2 31 [15, 16] 

330 

331 """ 

332 

    # Names of public and private attributes of ``mdtraj.Trajectory``
    # (dimensions, coordinates, unit cell, time, topology, savers, ...).
    # NOTE(review): presumably consulted elsewhere in this class to forward
    # these attribute lookups to the wrapped mdtraj.Trajectory; the consumer
    # is not visible in this chunk -- confirm before relying on it.
    _mdtraj_attr = [
        "n_frames",
        "n_atoms",
        "n_chains",
        "n_residues",
        "openmm_boxes",
        "openmm_positions",
        "time",
        "timestep",
        "xyz",
        "unitcell_vectors",
        "unitcell_lengths",
        "unitcell_angles",
        "_check_valid_unitcell",
        "_distance_unit",
        "_have_unitcell",
        "_rmsd_traces",
        "_savers",
        "_string_summary_basic",
        "_time",
        "_time_default_to_arange",
        "_topology",
        "_unitcell_angles",
        "_unitcell_lengths",
        "_xyz",
    ]

359 

360 def __init__( 

361 self, 

362 traj: Union[str, Path, md.Trajectory], 

363 top: Optional[Union[str, Path]] = None, 

364 common_str: str = "", 

365 backend: Literal["no_load", "mdtraj"] = "no_load", 

366 index: Optional[Union[int, list[int], np.ndarray, slice]] = None, 

367 traj_num: Optional[int] = None, 

368 basename_fn: Optional[Callable[[str], str]] = None, 

369 custom_top: Optional[CustomAAsDict] = None, 

370 ) -> None: 

371 """Initialize the SingleTraj object with location and reference pdb file. 

372 

373 Args: 

374 traj (Union[str, mdtraj.Trajectory]): The trajectory. This argument 

375 can either be the filename of a trajectory file 

376 (``.xtc, .dcd, .h5, .trr``) or an instance of 

377 ``mdtraj.Trajectory``. 

378 top (Union[str, mdtraj.Topology], optional): The path to the 

379 topology file. Defaults to :obj:`None`. If a ``mdtraj.Trajectory`` 

380 or a ``.h5`` file is provided in :obj:`traj`, this argument will 

381 not be used and the topology from the corresponding ``traj`` 

382 argument will be used. 

383 common_str (str, optional): A string to group traj of similar 

384 topology. If multiple :obj:`SingleTraj` are grouped in one 

385 :obj:`encodermap.trajinfo.info_all.TrajEnsemble`, the 

386 ``common_str`` is used to group them together. Defaults to '' 

387 which means this instance of :obj:`SingleTraj` won't have a common 

388 string. 

389 backend (Literal['no_load', 'mdtraj'], optional): Choose the 

390 backend to load trajectories. 

391 - 'mdtraj' uses mdtraj, which loads all trajectories into RAM. 

392 - 'no_load' creates an empty trajectory object. 

393 Defaults to 'no_load' 

394 index (Optional[Union[int, list[int], numpy.ndarray, slice]]): An 

395 integer or a Sequence of ``int``. If an integer is provided, 

396 only the frame at this position will be loaded once the internal 

397 ``mdtraj.Trajectory`` is accessed. If an array or list is provided, 

398 the corresponding frames will be used. Indices always slice the 

399 trajectory as is, meaning they don't index the original frames 

400 of the trajectory on disk (see Example section). These indices 

401 can have duplicates: ``[0, 1, 1, 2, 0, 1]``. A slice object 

402 can also be provided. Supports fancy slicing (``traj[1:50:3]``). 

403 If :obj:`None` is provided, the traj is loaded fully. 

404 Defaults to :obj:`None`. 

405 traj_num (Union[int, None], optional): If working with multiple 

406 trajs, this is the easiest unique identifier. If multiple 

407 :obj:`SingleTraj` are instantiated by 

408 :obj:`encodermap.trajinfo.info_all.TrajEnsemble` the ``traj_num`` 

409 is used as a unique identifier per traj. Defaults to :obj:`None`. 

410 basename_fn (Optional[Callable[[str], str]]): A function to apply to 

411 ``traj`` to give it another identifier. If all your trajs are called 

412 ``'traj.xtc'`` and only the directory they're in gives them a 

413 unique identifier, you can provide a function into this 

414 argument to split the path. The function has to take a :obj:`str` and 

415 return :obj:`str`. If None is provided, the basename 

416 is extracted like so: ``lambda x: x.split('/')[0].split('.')[-1]``. 

417 Defaults to None, in which case the filename without 

418 extension will be used. 

419 custom_top: Optional[:obj:`encodermap._typing.CustomAAsDict`]: An 

420 instance of the :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology` 

421 class or a dictionary that can be made into such. 

422 

423 """ 

424 # Imports 

425 # Encodermap imports 

426 from encodermap.misc.xarray import construct_xarray_from_numpy 

427 from encodermap.trajinfo.trajinfo_utils import CustomTopology 

428 

429 # defaults 

430 self.__traj = traj 

431 self.backend = backend 

432 self.common_str = common_str 

433 self.index = index if isinstance(index, tuple) else (index,) 

434 self.traj_num = traj_num 

435 self._loaded_once = False if backend == "no_load" else True 

436 self._orig_frames = np.array([]) 

437 self._CVs = xr.Dataset() 

438 

439 # custom topology to load dihedral angles 

440 self._custom_top = custom_top 

441 if self._custom_top is not None: 

442 if isinstance(self._custom_top, dict): 

443 self._custom_top = CustomTopology.from_dict(self._custom_top, traj=self) 

444 else: 

445 self._custom_top = CustomTopology(traj=self) 

446 

447 # _atom indices are for delayed atom-slicing 

448 self._atom_indices = None 

449 

450 # decide the basename 

451 if basename_fn is None: 

452 basename_fn = lambda x: os.path.basename(x).split(".")[0] 

453 self.basename_fn = basename_fn 

454 

455 # save the filename 

456 if isinstance(traj, str): 

457 if self._validate_uri(traj): 

458 self._traj_file = traj 

459 else: 

460 self._traj_file = Path(traj) 

461 elif isinstance(traj, Path): 

462 self._traj_file = traj 

463 elif isinstance(traj, md.Trajectory): 

464 self._traj_file = Path("") 

465 self._top_file = Path("") 

466 self.backend = "mdtraj" 

467 self.trajectory = traj 

468 self._loaded_once = True 

469 self._orig_frames = np.arange(traj.n_frames) 

470 else: 

471 raise ValueError( 

472 f"Argument `traj` takes either str, Path, or " 

473 f"mdtraj.Trajectory. You supplied: {type(traj)}." 

474 ) 

475 

476 if top is not None: 

477 if isinstance(top, md.Topology): 

478 if custom_top is not None: 

479 raise Exception( 

480 f"Providing an MDTraj Topology as the `top` argument interferes " 

481 f"with the argument `custom_topology`. Use one or the other. " 

482 ) 

483 self._top_file = Path("") 

484 else: 

485 if self._validate_uri(top): 

486 self._top_file = top 

487 else: 

488 self._top_file = Path(top) 

489 if isinstance(self._traj_file, Path): 

490 if ( 

491 self._traj_file.suffix in _TOPOLOGY_EXTS 

492 and self._traj_file != self._top_file 

493 ): 

494 raise MixedUpInputs( 

495 f"You probably mixed up the input. Normally you " 

496 f"want to instantiate with `SingleTraj(traj, top)`." 

497 f"Based on the files and the extensions you provided " 

498 f"(traj={self._traj_file.name} and top=" 

499 f"{self._top_file.name}), you want to change the " 

500 f"order of the arguments, or use keyword arguments." 

501 ) 

502 else: 

503 if isinstance(self._traj_file, Path): 

504 if self._traj_file.suffix in _TOPOLOGY_EXTS: 

505 if self._validate_uri(traj): 

506 self._top_file = traj 

507 else: 

508 self._top_file = self._traj_file 

509 else: 

510 self._top_file = self._traj_file 

511 

512 if self.backend == "no_load": 

513 self.trajectory = False 

514 self.topology = False 

515 else: 

516 if isinstance(self._traj_file, str) and self._validate_uri(self._traj_file): 

517 traj = md.load_pdb(str(self.traj_file)) 

518 elif self._traj_file != Path(""): 

519 try: 

520 traj = md.load(str(self._traj_file), top=str(self._top_file)) 

521 except tables.NoSuchNodeError as e: 

522 if self.traj_num is None: 

523 raise 

524 # Local Folder Imports 

525 from .info_all import HDF5GroupWrite 

526 

527 with HDF5GroupWrite(self.top_file) as h5file: 

528 traj = h5file.read_traj(self.traj_num) 

529 self.trajectory = traj 

530 self.topology = False 

531 self._loaded_once = True 

532 self.topology = self._custom_top.top 

533 self._orig_frames = np.arange(traj.n_frames) 

534 

535 # check the h5file. Maybe it contains CVs and topology 

536 CVs_in_file = False 

537 custom_top_in_file = False 

538 common_str_in_file = False 

539 if isinstance(self._traj_file, Path): 

540 if self._traj_file.suffix == ".h5": 

541 with h5.File(self.traj_file, "r") as file: 

542 h5_attrs = list(file.attrs.keys()) 

543 if "CVs" in file.keys(): 

544 CVs_in_file = True 

545 if any(["custom_top" in s for s in h5_attrs]): 

546 custom_top_in_file = True 

547 if any(["common_str" in s for s in h5_attrs]): 

548 common_str_in_file = True 

549 

550 # if group 'CVs' is in h5store, but no data 

551 if CVs_in_file: 

552 test = xr.open_dataset( 

553 self.traj_file, 

554 group="CVs", 

555 engine="h5netcdf", 

556 backend_kwargs={"phony_dims": "access"}, 

557 ) 

558 if len(test.data_vars) == 0: 

559 CVs_in_file = False 

560 

561 # inform the use about a grandfathered traj_num 

562 if custom_top_in_file: 

563 h5_top_attrs = list(filter(lambda x: x.startswith("_custom_top"), h5_attrs)) 

564 finds = [re.findall(r"\d+", s) for s in h5_top_attrs] 

565 if not all([f == [] for f in finds]): 

566 h5_top_traj_nums = [int(f[-1]) for f in finds] 

567 if self.traj_num not in h5_top_traj_nums: 

568 if len(h5_top_attrs) == 1: 

569 h5_top_traj_num = h5_top_traj_nums[0] 

570 with h5py.File(self.traj_file, "a") as file: 

571 file.attrs["_custom_top"] = file.attrs[h5_top_attrs[0]] 

572 del file.attrs[h5_top_attrs[0]] 

573 print( 

574 f"The file {self.traj_file} was previously saved with the `traj_num`=" 

575 f"{h5_top_traj_num}, but this instance uses `traj_num`={self.traj_num}. " 

576 f"I have overwritten the `traj_num` for the custom_topology in this file." 

577 ) 

578 elif len(h5_top_attrs) > 1: 

579 raise Exception( 

580 f"The h5File contains multiple '_custom_top' attrs: {h5_top_attrs} " 

581 f"of which none matches this Traj's `traj_num`: {self.traj_num}." 

582 ) 

583 

584 # inform the use about a grandfathered traj_num 

585 if common_str_in_file: 

586 common_str_key = "common_str" 

587 h5_common_str_traj_num = None 

588 h5_common_str_attrs = list( 

589 filter(lambda x: x.startswith("common_str"), h5_attrs) 

590 ) 

591 finds = [re.findall(r"\d+", s) for s in h5_common_str_attrs] 

592 if not all([f == [] for f in finds]): 

593 h5_common_str_traj_nums = [int(f[-1]) for f in finds] 

594 if self.traj_num not in h5_common_str_traj_nums: 

595 if len(h5_common_str_attrs) == 1: 

596 h5_common_str_traj_num = h5_common_str_traj_nums[0] 

597 with h5py.File(self.traj_file, "a") as file: 

598 file.attrs["common_str"] = file.attrs[ 

599 h5_common_str_attrs[0] 

600 ] 

601 del file.attrs[h5_common_str_attrs[0]] 

602 print( 

603 f"The file {self.traj_file} was previously saved with the `traj_num`=" 

604 f"{h5_common_str_traj_num}, but this instance uses `traj_num`={self.traj_num}. " 

605 f"I have overwritten the `traj_num` for the custom_topology in this file." 

606 ) 

607 elif len(h5_common_str_attrs) > 1: 

608 raise Exception( 

609 f"The h5File contains multiple '_custom_top' attrs: {h5_common_str_attrs} " 

610 f"of which none matches this Traj's `traj_num`: {self.traj_num}." 

611 ) 

612 else: 

613 h5_common_str_traj_num = self.traj_num 

614 if h5_common_str_traj_num is not None: 

615 common_str_key = f"common_str_{h5_common_str_traj_num}" 

616 

617 # maybe load CVs from h5 file 

618 if CVs_in_file: 

619 try: 

620 ds = xr.open_dataset( 

621 self.traj_file, 

622 group="CVs", 

623 engine="h5netcdf", 

624 backend_kwargs={"phony_dims": "access"}, 

625 ) 

626 if ds.sizes["traj_num"] > 1: 

627 assert self.traj_num in ds.coords["traj_num"], ( 

628 f"This trajectory with {self.traj_num=} is not in " 

629 f"the dataset with traj_nums: {ds.coords['traj_num']}." 

630 ) 

631 ds = ds.sel(traj_num=self.traj_num) 

632 ds = ds.expand_dims("traj_num") 

633 if str(ds.coords["traj_name"].values) != self.basename: 

634 ds.coords["traj_name"] = [self.basename] 

635 ds = ds.dropna("frame_num", how="all") 

636 self._CVs = ds 

637 # bad formatted h5 file 

638 except OSError: 

639 DAs = { 

640 k: construct_xarray_from_numpy(self, i[()], k) 

641 for k, i in file["CVs"].items() 

642 } 

643 DS = xr.Dataset(DAs) 

644 self._CVs.update(DS) 

645 # other exceptions probably due to formatting 

646 except Exception as e: 

647 raise Exception( 

648 f"The formatting of the data in the file " 

649 f"{self.traj_file} is off. Xarray could " 

650 f"not load the group 'CVs' and failed with {e}. " 

651 f"Some debug: {CVs_in_file=} and {file.keys()=}." 

652 ) from e 

653 

654 # get the original frame indices from the dataset 

655 # this is the only case where we want to overwrite 

656 # this variable 

657 if not self._loaded_once: 

658 self._loaded_once = True 

659 self._orig_frames = self._CVs["frame_num"].values 

660 

661 # iteratively apply index 

662 index = self._orig_frames 

663 for ind in self.index: 

664 if ind is not None: 

665 index = index[ind] 

666 

667 # set the _CVs accordingly 

668 self._CVs = self._CVs.loc[{"frame_num": index}] 

669 

670 # maybe load _custom_top_from file 

671 if custom_top_in_file: 

672 ctop = CustomTopology.from_hdf5_file( 

673 self.traj_file, 

674 traj=self, 

675 ) 

676 if custom_top is not None: 

677 if self._custom_top != ctop: 

678 print( 

679 f"The file {self.traj_file} contained a custom_topology: " 

680 f"{ctop}. The custom topology you provided {self._custom_top} " 

681 f"is different from that in the file. The provided one will " 

682 f"take precedent over the one in the file. " 

683 ) 

684 else: 

685 self._custom_top = ctop 

686 

687 # maybe load common_str form file 

688 if common_str_in_file: 

689 with h5py.File(self.traj_file, "r") as f: 

690 common_str_file = f.attrs[common_str_key] 

691 

692 if self.common_str == "": 

693 self.common_str = common_str_file 

694 else: 

695 if common_str_file != self.common_str: 

696 warnings.warn( 

697 f"The file {self.traj_file} contains a " 

698 f"`common_str`={common_str_file}, which differs from the " 

699 f"`common_str` provided in __init__ ({self.common_str}). " 

700 f"I will use the one from " 

701 f"__init__ and overwrite the one in the file." 

702 ) 

703 try: 

704 with h5py.File(self.traj_file, "a") as f: 

705 f.attrs[common_str_key] = self.common_str 

706 except OSError as e: 

707 warnings.warn( 

708 f"I could not write the common str in the file, because " 

709 f"it is already opened by another python object: {str(e)}" 

710 ) 

711 

712 @classmethod 

713 def from_pdb_id(cls, pdb_id: str, traj_num: Optional[int] = None) -> SingleTraj: 

714 """Alternate constructor for the TrajEnsemble class. 

715 

716 Builds an SingleTraj class from a pdb-id. 

717 

718 Args: 

719 pdb_id (str): The 4-letter pdb id. 

720 

721 Returns: 

722 SingleTraj: An SingleTraj class. 

723 

724 """ 

725 url = f"https://files.rcsb.org/view/{pdb_id.upper()}.pdb" 

726 return cls(url, common_str=pdb_id.upper(), traj_num=traj_num) 

727 

728 @property 

729 def featurizer(self): 

730 # Local Folder Imports 

731 from ..loading.featurizer import SingleTrajFeaturizer 

732 

733 if not hasattr(self, "_featurizer"): 

734 self._featurizer = SingleTrajFeaturizer(self) 

735 return self._featurizer 

736 

737 @property 

738 def indices_chi1(self) -> np.ndarray: 

739 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

740 atoms that take part in this dihedral angle. This index is 0-based.""" 

741 return self._custom_top.indices_chi1() 

742 

743 @property 

744 def indices_chi2(self) -> np.ndarray: 

745 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

746 atoms that take part in this dihedral angle. This index is 0-based.""" 

747 return self._custom_top.indices_chi2() 

748 

749 @property 

750 def indices_chi3(self) -> np.ndarray: 

751 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

752 atoms that take part in this dihedral angle. This index is 0-based.""" 

753 return self._custom_top.indices_chi3() 

754 

755 @property 

756 def indices_chi4(self) -> np.ndarray: 

757 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

758 atoms that take part in this dihedral angle. This index is 0-based.""" 

759 return self._custom_top.indices_chi4() 

760 

761 @property 

762 def indices_chi5(self) -> np.ndarray: 

763 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

764 atoms that take part in this dihedral angle. This index is 0-based.""" 

765 return self._custom_top.indices_chi5() 

766 

767 @property 

768 def indices_phi(self) -> np.ndarray: 

769 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

770 atoms that take part in this dihedral angle. This index is 0-based.""" 

771 return self._custom_top.indices_phi() 

772 

773 @property 

774 def indices_omega(self) -> np.ndarray: 

775 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

776 atoms that take part in this dihedral angle. This index is 0-based.""" 

777 return self._custom_top.indices_omega() 

778 

779 @property 

780 def indices_psi(self) -> np.ndarray: 

781 """numpy.ndarray: A numpy array with shape (n_dihedrals, 4) indexing the 

782 atoms that take part in this dihedral angle. This index is 0-based.""" 

783 return self._custom_top.indices_psi() 

784 

785 @property 

786 def _original_frame_indices(self) -> np.ndarray: 

787 """numpy.ndarray: If trajectory has not been loaded, it is loaded and the 

788 frames of the trajectory file on disk are put into a `np.arange()`. If 

789 the trajectory is sliced in weird ways, this array tracks the original frames. 

790 """ 

791 if self._loaded_once: 

792 return self._orig_frames 

793 else: 

794 self.load_traj() 

795 return self._orig_frames 

796 

797 @property 

798 def _frames(self) -> np.ndarray: 

799 """numpy.ndarray: Applies self.index over self._orig_frames.""" 

800 frames = self._orig_frames.copy() 

801 for ind in self.index: 

802 if ind is not None: 

803 frames = frames[ind] 

804 return np.asarray(frames) 

805 

806 def _trace(self, CV: Sequence[str]) -> np.ndarray: 

807 """Creates a low-dimensional representation of the loaded CV data by 

808 stacking all arguments in `CV` along a single axis. 

809 

810 If this `SingleTraj` has 100 frames and a CV with shape (100, 50, 3) with 

811 the name 'cartesians', then `traj._trace` will return a numpy.ndarray of shape 

812 (100, 150). 

813 

814 Args: 

815 CV (Sequence[str]): The CVs to combine in the trace. 

816 

817 Returns: 

818 numpy.ndarray: The trace. 

819 

820 """ 

821 out = [] 

822 for i in CV: 

823 v = self.CVs[i] 

824 out.append(v.reshape(v.shape[0], -1)) 

825 return np.concatenate(out) 

826 

827 @property 

828 def traj_file(self) -> str: 

829 """str: The traj file as a string (rather than a `pathlib.Path`).""" 

830 return str(self._traj_file) 

831 

@property
def top_file(self) -> str:
    """str: Path of the topology file, converted from `pathlib.Path` to str."""
    return f"{self._top_file}"

836 

@property
def traj(self) -> md.Trajectory:
    """mdtraj.Trajectory: This attribute always returns an mdtraj.Trajectory.

    With the 'no_load' backend the trajectory is loaded, captured and then
    unloaded again, so the instance keeps its 'no_load' state.

    """
    if self.backend != "no_load":
        return self.trajectory
    self.load_traj()
    loaded = self.trajectory
    self.unload()
    return loaded

850 

@property
def _traj(self):
    """Placeholder getter that makes the setter below possible.

    It deliberately returns nothing, because the setter does not store a
    value under this name either.

    """
    return None

@_traj.setter
def _traj(self, traj_file):
    """Loads ``traj_file`` via ``_load_traj`` into ``self.trajectory``.

    ``self.index``, ``self._top_file``, ``self.traj_num`` and
    ``self._atom_indices`` are forwarded to the loader. On the very first
    load, the original frame numbers returned by the loader are stored in
    ``self._orig_frames``. A previously set (truthy) topology is deep-copied
    beforehand and re-applied to the freshly loaded trajectory.

    """
    saved_topology = deepcopy(self.topology) if self.topology else False
    loaded, orig_frames = _load_traj(
        *self.index,
        traj_file=traj_file,
        top_file=self._top_file,
        traj_num=self.traj_num,
        atom_indices=self._atom_indices,
    )
    self.trajectory = loaded
    if not self._loaded_once:
        self._loaded_once = True
        self._orig_frames = orig_frames
    if saved_topology:
        self.trajectory.top = saved_topology
        self.topology = saved_topology

878 

@property
def basename(self) -> str:
    """str: The trajectory file name as produced by ``self.basename_fn``.

    ``basename_fn`` is applied to `traj_file`. None is returned when
    `traj_file` is empty or is the placeholder ``'.'``.

    """
    file_name = self.traj_file
    if not file_name or str(file_name) == ".":
        return None
    return self.basename_fn(file_name)

887 

@property
def extension(self) -> str:
    """str: Extension of the trajectory file (``self._traj_file``).

    The leading dot is included, e.g. ``'.h5'``. ``self._traj_file`` may
    be a :obj:`pathlib.Path` or a str. Both are treated identically via
    :attr:`pathlib.PurePath.suffix`, so a file without an extension yields
    ``''`` and dots in directory names are ignored.

    """
    # Splitting a str on "." mis-handled "traj" (-> ".traj") and
    # "run.v2/traj" (-> ".v2/traj"); Path gives the same answer for both types.
    return Path(str(self._traj_file)).suffix

895 

@property
def id(self) -> np.ndarray:
    """numpy.ndarray: Unique identifiers of the frames in this SingleTraj.

    The original frame numbers are reduced by every non-None entry of
    ``self.index``. Without a ``traj_num`` the result has shape
    ``(n_frames, )``. With a ``traj_num`` (e.g. when the traj belongs to a
    TrajEnsemble) it has shape ``(n_frames, 2)``: ``self.id[:, 0]`` is full
    of ``traj_num`` and ``self.id[:, 1]`` holds the frames. A single frame
    still yields a 1-D or 2-D array.

    """
    frames = self._original_frame_indices
    if not isinstance(frames, (int, np.integer)):
        for sub_index in self.index:
            if sub_index is not None:
                frames = frames[sub_index]
    traj_num = self.traj_num
    if isinstance(frames, (int, np.integer)):
        # the selection has been reduced to a single frame number
        if traj_num is None:
            return np.array([frames])
        return np.array([[traj_num, frames]])
    if traj_num is None:
        return frames
    return np.array([np.full(len(frames), traj_num), frames]).T

931 

@property
def n_frames(self) -> int:
    """int: Number of frames in traj.

    Loads the traj into memory if not in HDF5 file format. Be aware.

    In detail: HDF5 (``.h5``) and XTC (``.xtc``) files are inspected on
    disk without loading the coordinates. Every other format is loaded
    with ``self.load_traj()`` and stays loaded afterwards. The returned
    count takes ``self.index`` into account. As a side effect this property
    may set ``self._orig_frames`` and ``self._loaded_once``.

    """
    # An integer anywhere in the index reduces the trajectory to exactly one
    # frame; that integer is stored as the original frame index.
    if any([isinstance(ind, (int, np.integer)) for ind in self.index]):
        self._loaded_once = True
        ind = [i for i in self.index if isinstance(i, (int, np.integer))][0]
        self._orig_frames = ind
        return 1
    # NOTE(review): `.suffix` assumes self._traj_file is a pathlib.Path,
    # while the `extension` property also accepts a plain str — confirm.
    elif self._traj_file.suffix == ".h5":
        with h5.File(self.traj_file, "r") as file:
            # Files holding several trajectories store coordinates under
            # "coordinates_{traj_num}" instead of "coordinates".
            if self.index == (None,):
                if (
                    "coordinates" not in list(file.keys())
                    and self.traj_num is not None
                ):
                    n_frames = np.arange(
                        file[f"coordinates_{self.traj_num}"].shape[0]
                    )
                else:
                    n_frames = np.arange(file["coordinates"].shape[0])
            else:
                # Apply the entries of self.index one after another to an
                # array of frame numbers; None entries are skipped here.
                for i, ind in enumerate(self.index):
                    if i == 0:
                        if (
                            "coordinates" not in list(file.keys())
                            and self.traj_num is not None
                        ):
                            n_frames = np.arange(
                                file[f"coordinates_{self.traj_num}"].shape[0]
                            )
                        else:
                            n_frames = np.arange(file["coordinates"].shape[0])
                        if ind is not None:
                            n_frames = n_frames[ind]
                    else:
                        if ind is not None:
                            n_frames = n_frames[ind]
            # NOTE(review): unlike the .xtc branch this does not set
            # self._loaded_once, so _orig_frames is re-read on every call.
            if not self._loaded_once:
                if (
                    "coordinates" not in list(file.keys())
                    and self.traj_num is not None
                ):
                    self._orig_frames = np.arange(
                        file[f"coordinates_{self.traj_num}"].shape[0]
                    )
                else:
                    self._orig_frames = np.arange(file["coordinates"].shape[0])

        # return single int or length of array
        # NOTE(review): a scalar here is a frame *number*, not a count. Integer
        # indices are handled by the first branch, so this is presumably
        # unreachable — confirm.
        if isinstance(n_frames, (int, np.integer)):
            return n_frames
        else:
            return len(n_frames)
    elif self._traj_file.suffix == ".xtc":
        with mda.coordinates.XTC.XTCReader(self.traj_file) as reader:
            if self.index == (None,):
                n_frames = np.arange(reader.n_frames)
            else:
                for i, ind in enumerate(self.index):
                    if i == 0:
                        n_frames = np.arange(reader.n_frames)[ind]
                    else:
                        n_frames = n_frames[ind]
                    # Indexing with None adds a new leading axis; drop it
                    # again so the frame array stays 1-D.
                    if ind is None:
                        n_frames = n_frames[0]
            if not self._loaded_once:
                self._loaded_once = True
                self._orig_frames = np.arange(reader.n_frames)
        if isinstance(n_frames, (int, np.integer)):
            return n_frames
        else:
            return len(n_frames)
    else:
        # Any other format: load it with the mdtraj backend (not unloaded).
        self.load_traj()
        return self.traj.n_frames

1011 

@property
def _n_frames_base_h5_file(self) -> int:
    """int: Number of frames stored in the HDF5 file, read without loading
    the coordinates into memory.

    Files that hold several trajectories keep their coordinates under
    ``'coordinates_{traj_num}'`` rather than ``'coordinates'``. The same key
    fallback as in ``n_frames`` and ``n_atoms`` is used, so these files no
    longer raise a KeyError. Returns -1 when `traj_file` is not an ``.h5``
    file.

    """
    if self.extension != ".h5":
        return -1
    with h5.File(self.traj_file, "r") as file:
        if "coordinates" in file.keys() or self.traj_num is None:
            return file["coordinates"].shape[0]
        return file[f"coordinates_{self.traj_num}"].shape[0]

1020 

@property
def CVs_in_file(self) -> bool:
    """bool: True if `traj_file` has the extension .h5 and contains a
    top-level ``'CVs'`` entry, False otherwise."""
    if self.extension != ".h5":
        return False
    with h5.File(self.traj_file, "r") as file:
        return "CVs" in file.keys()

1029 

@property
def n_atoms(self) -> int:
    """int: Number of atoms in traj.

    Loads the traj into memory if not in HDF5 file format. Be aware.

    If atoms have been sliced away, the number of kept atom indices is
    returned. For HDF5 files the second axis of the ``'coordinates'``
    dataset is read, falling back to ``'coordinates_{traj_num}'``. All
    other formats ask the topology.

    Raises:
        Exception: When an HDF5 file lacks ``'coordinates'`` and no
            ``traj_num`` is set.

    """
    if self._atom_indices is not None:
        return len(self._atom_indices)
    if self.extension != ".h5":
        return self.top.n_atoms
    with h5.File(self.traj_file, "r") as file:
        try:
            dataset = file["coordinates"]
        except KeyError as e:
            if self.traj_num is None:
                raise Exception(
                    f"Neither key 'coordinates', nor 'coordinates_{self.traj_num}' "
                    f"exists in h5file {self.traj_file} with keys {file.keys()}."
                ) from e
            dataset = file[f"coordinates_{self.traj_num}"]
        return dataset.shape[1]

1056 

@property
def n_residues(self) -> int:
    """int: Number of residues in traj, as reported by its topology."""
    topology = self.top
    return topology.n_residues

1061 

@property
def n_chains(self) -> int:
    """int: Number of chains in traj, as reported by its topology."""
    topology = self.top
    return topology.n_chains

1066 

@property
def top(self) -> md.Topology:
    """mdtraj.Topology: The topology of this trajectory (via ``self._get_top()``).

    The structure of a Topology object is similar to that of a PDB file.
    It consists of a set of Chains (often but not always corresponding to
    polymer chains). Each Chain contains a set of Residues, and each Residue
    contains a set of Atoms. In addition, the Topology stores a list of which
    atom pairs are bonded to each other.
    Atom and residue names should follow the PDB 3.0 nomenclature for all
    molecules for which one exists.

    Attributes:
        chains (generator): Iterate over chains.
        residues (generator): Iterate over residues.
        atoms (generator): Iterate over atoms.
        bonds (generator): Iterate over bonds.

    """
    topology = self._get_top()
    return topology

1086 

def _get_raw_top(self) -> md.Topology:
    """Reads different files and loads md.Topology from them.

    This topology will *NOT* be corrected with `CustomTopology`.

    Depending on backend and file type, the topology comes from:
        * a URI in ``self.top_file`` (asserted to end in ``.pdb``),
        * ``mdtraj.load_topology(self.top_file)`` for local files,
        * ``HDF5GroupWrite(...).read_traj(self.traj_num).top`` when
          ``md.load_topology`` raises ``tables.NoSuchNodeError`` and a
          ``traj_num`` is set,
        * ``self.traj.top`` when the backend is ``'mdtraj'``.

    Returns:
        mdtraj.Topology: The raw topology.

    Raises:
        FileNotFoundError: When ``self.top_file`` is neither a URI nor an
            existing local file (the placeholder ``'.'`` is not checked).
        tables.NoSuchNodeError: When the topology can't be read from an
            HDF5 file and ``self.traj_num`` is None.

    """
    # Third Party Imports
    import tables

    # Local Folder Imports
    from .info_all import HDF5GroupWrite

    # Validate top_file: remote PDB files are loaded immediately, missing
    # local files raise.
    if self.top_file:
        if self.top_file != ".":
            if self._validate_uri(self.top_file):
                assert self.top_file.endswith(".pdb")
                top = _load_pdb_from_uri(self.top_file)
            elif not os.path.isfile(self.top_file):
                raise FileNotFoundError(
                    errno.ENOENT, os.strerror(errno.ENOENT), self.top_file
                )
    # NOTE(review): the checks below are independent `if`s, so for the
    # 'no_load' backend the topology may be read more than once; the last
    # assignment to `top` wins.
    if self.backend == "no_load" and not self.extension == ".h5" and self.traj_file:
        if self._validate_uri(self.top_file):
            top = _load_pdb_from_uri(self.top_file)
        else:
            top = md.load_topology(self.top_file)
    if self.extension == ".h5":
        try:
            top = md.load_topology(self.top_file)
        except tables.NoSuchNodeError as e:
            # Without a traj_num there is no group to fall back to.
            if self.traj_num is None:
                raise e
            with HDF5GroupWrite(self.top_file) as h5file:
                top = h5file.read_traj(self.traj_num).top
    # Fallback for 'no_load' when none of the branches above assigned `top`.
    if self.backend == "no_load" and "top" not in locals():
        try:
            top = md.load_topology(self.top_file)
        except tables.NoSuchNodeError as e:
            if self.traj_num is None:
                raise e
            with HDF5GroupWrite(self.top_file) as h5file:
                top = h5file.read_traj(self.traj_num).top
    if self.backend == "mdtraj":
        # A loaded trajectory already carries its topology.
        top = self.traj.top
    else:
        if self._validate_uri(self.top_file):
            top = _load_pdb_from_uri(self.top_file)
        else:
            try:
                top = md.load_topology(self.top_file)
            except tables.NoSuchNodeError as e:
                if self.traj_num is None:
                    raise e
                with HDF5GroupWrite(self.top_file) as h5file:
                    top = h5file.read_traj(self.traj_num).top
    return top

1146 

def _get_top(self) -> md.Topology:
    """Returns the topology of this trajectory.

    A truthy ``self.topology`` takes precedence. Otherwise the topology
    provided by the custom topology (``self._custom_top.top``) is used.

    Returns:
        mdtraj.Topology: The structure of a Topology object is similar to that of a PDB file.

    """
    if not self.topology:
        return self._custom_top.top
    return self.topology

1159 

def copy(self) -> SingleTraj:
    """Returns an independent deep copy of ``self``."""
    duplicate = deepcopy(self)
    return duplicate

1163 

def del_CVs(self) -> None:
    """Drops all loaded CVs by replacing ``_CVs`` with an empty :obj:`xarray.Dataset`."""
    delattr(self, "_CVs")
    self._CVs = xr.Dataset()

1168 

def _calc_CV(self) -> dict[str, np.ndarray]:
    """Returns the current CVs as a dictionary.

    Every data variable of ``self._CVs`` becomes one entry, with two
    exceptions. Variables whose name contains ``'feature_indices'`` hold
    integer indices rather than values: their shape is validated and they
    are left out. Variables starting with ``'MinRmsdFeature'`` are also
    skipped without validation.

    If a CV contains NaN padding, it is dropped along the dimension named
    by the ``'feature_axis'``/``'feature_axes'`` attribute. Without either
    attribute, ``key.upper()`` is used. The leading (trajectory) axis of
    length 1 is squeezed away.

    Returns:
        dict[str, numpy.ndarray]: CV name -> values. Empty when no CVs are
            loaded.

    Raises:
        AssertionError: When a ``'feature_indices'`` variable has an
            unexpected shape.
        Exception: When a CV's first axis can't be squeezed.

    """
    if not self._CVs:
        return {}
    out = {}
    for key, val in self._CVs.data_vars.items():
        if "feature_indices" in key:
            if key.startswith("MinRmsdFeature"):
                continue
            if "cartesian" in key:
                assert val.shape[0] == 1, (
                    f"The substring 'feature_indices' is special and can "
                    f"only contain a (1, n_frames) or (1, n_frames, 4) arrays. "
                    f"Your value of {key=} has the shape: {val.shape=} "
                    f"If you have manually "
                    f"loaded a feature with this substring, use a different "
                    f"one. These CVs contain integer indices and not values."
                )
            else:
                assert val.shape[-1] <= 4, (
                    f"The substring 'feature_indices' is special and can "
                    f"only contain a (1, n_frames) or (1, n_frames, <=4) arrays. "
                    f"Your value of {key=} has the shape: {val.shape=} "
                    f"If you have manually "
                    f"loaded a feature with this substring, use a different "
                    f"one. These CVs contain integer indices and not values."
                )
            continue
        # A leftover debugging `raise` for key == "central_angles_indices"
        # used to sit here and made that CV impossible to read.
        if "feature_axis" in val.attrs:
            axis_name = "feature_axis"
        elif "feature_axes" in val.attrs:
            axis_name = "feature_axes"
        else:
            axis_name = None
        if np.any(np.isnan(val)):
            # Remove NaN padding along this CV's feature dimension.
            dim = val.attrs[axis_name] if axis_name is not None else key.upper()
            val = val.dropna(dim)
        try:
            out[key] = val.values.squeeze(0)
        except ValueError as e:
            raise Exception(f"{key=} {val=}") from e
    return out

1217 

@property
def CVs(self) -> dict[str, np.ndarray]:
    """dict[str, numpy.ndarray]: A plain dict built from the more complicated
    ``self._CVs`` xarray Dataset.

    The dict is rebuilt by ``self._calc_CV()`` on every access. An empty
    dict is returned when no CVs are loaded.

    """
    cv_dict = self._calc_CV()
    return cv_dict

1228 

def _validate_uri(self, uri: Union[Path, str]) -> bool:
    """Returns whether `uri` (str or :obj:`pathlib.Path`) is a valid URI.

    The check itself is delegated to ``encodermap.misc.misc._validate_uri``,
    which receives the string form of `uri`.

    """
    # Encodermap imports
    from encodermap.misc.misc import _validate_uri as validate_uri_string

    return validate_uri_string(str(uri))

1235 

def load_traj(
    self,
    new_backend: Literal["no_load", "mdtraj"] = "mdtraj",
) -> None:
    """Switches this trajectory to `new_backend`.

    After switching to ``'mdtraj'``, ``self.trajectory`` holds an mdtraj
    Trajectory and ``self.topology`` is refreshed from ``self.top``.
    Switching to ``'no_load'`` unloads the data and resets ``self.topology``
    to False. Nothing happens if the backend is already `new_backend`.

    Args:
        new_backend (str, optional): Can either be:
            - ``'mdtraj'`` to load the trajectory using mdtraj.
            - ``'no_load'`` to not load the traj (unload).
            Defaults to ``'mdtraj'``.

    Raises:
        Exception: Whatever loading the file raises; the backend is reset
            to ``'no_load'`` before the error propagates.

    """
    current_backend = self.backend
    if current_backend == new_backend:
        return
    if (current_backend, new_backend) == ("mdtraj", "no_load"):
        self.unload()
        self.topology = False
    elif (current_backend, new_backend) == ("no_load", "mdtraj"):
        self.backend = new_backend
        # assigning to `_traj` triggers its setter, which loads the file
        try:
            self._traj = self.traj_file
        except Exception:
            self.backend = "no_load"
            raise
        self.topology = self.top

1266 

def select(
    self,
    sel_str: str = "all",
) -> np.ndarray:
    """Runs an atom selection query against this trajectory's topology.

    Args:
        sel_str (str, optional): The selection expression. Defaults to 'all'.

    Returns:
        numpy.ndarray: The indices of the selected atoms.

    See also:
        https://mdtraj.org/1.9.4/atom_selection.html

    Examples:
        >>> import encodermap as em
        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")
        >>> select = traj.top.select("name CA and resSeq 1")
        >>> select
        array([1])

        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")
        >>> select = traj.top.select("name CA and resSeq 1")
        >>> traj.top.atom(select[0])
        MET1-CA

    """
    topology = self.top
    return topology.select(sel_str)

1293 

def unload(
    self,
    CVs: bool = False,
) -> None:
    """Frees RAM by dropping the loaded trajectory data.

    Switches the backend to ``'no_load'``, removes every instance attribute
    listed in ``self._mdtraj_attr`` and resets ``trajectory`` and
    ``topology`` to False. Does nothing if the backend already is
    ``'no_load'``.

    Args:
        CVs (bool, optional): Whether to also replace the CVs with an empty
            dataset. Defaults to False.

    """
    if self.backend == "no_load":
        return
    self.backend = "no_load"
    instance_attrs = self.__dict__
    for attr in self._mdtraj_attr:
        instance_attrs.pop(attr, None)
    if CVs:
        self._CVs = xr.Dataset()
    self.trajectory = False
    self.topology = False

1317 

def _gen_ensemble(self) -> TrajEnsemble:
    """Wraps this :obj:`SingleTraj` in a new
    :obj:`encodermap.trajinfo.info_all.TrajEnsemble`.

    Needed to add two SingleTraj objects along the 'trajectory' axis with
    the method add_new_traj, and used by the __getitem__ method of the
    TrajEnsemble class. CVs, common_str, traj_num, index and the custom
    topology of ``self`` are copied onto the ensemble's only member.

    """
    # Encodermap imports
    from encodermap.trajinfo.info_all import TrajEnsemble

    ensemble = TrajEnsemble(
        trajs=[self],
        backend=self.backend,
        common_str=[],
        basename_fn=self.basename_fn,
    )
    member = ensemble.trajs[0]
    member._CVs = self._CVs
    member.common_str = self.common_str
    ensemble.common_str = [self.common_str]
    member.traj_num = self.traj_num
    member.index = self.index
    member._custom_top = self._custom_top
    member.topology = self._custom_top.top
    return ensemble

1344 

def _add_along_traj(self, y: SingleTraj) -> TrajEnsemble:
    """Combines ``self`` and `y` into one
    :obj:`encodermap.trajinfo.info_all.TrajEnsemble`.

    The trajectories are not appended along the time axis but rather along
    the 'trajectory' axis. If both share the same ``traj_num``, `y`'s
    ``traj_num`` is bumped (or set to 1 when it is None); note that this
    mutates `y`.

    Args:
        y (SingleTraj): The other :obj:`SingleTraj` trajectory.

    Returns:
        TrajEnsemble: The sum of both single-member ensembles.

    """
    if self.traj_num == y.traj_num:
        y.traj_num = 1 if y.traj_num is None else y.traj_num + 1
    return self._gen_ensemble() + y._gen_ensemble()

1364 

def get_single_frame(self, key: int) -> SingleTraj:
    """Returns a single frame from the trajectory (same as ``self[key]``).

    Args:
        key (Union[int, numpy.integer]): Index of the frame.

    Examples:

        Import EncoderMap and load :obj:`SingleTraj`.

        >>> import encodermap as em
        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")
        >>> traj.n_frames
        14

        Load the same traj and give it a ``traj_num`` for recognition in a
        set of multiple trajectories.

        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb", traj_num=5)
        >>> frame = traj.get_single_frame(2)
        >>> frame.id
        array([[5, 2]])

    """
    return self[key]

1390 

def show_traj(self, gui: bool = True) -> Any:
    """Creates an nglview widget that displays ``self.traj``.

    See Also:
        https://nglviewer.org/nglview/latest/

    Args:
        gui (bool, optional): Forwarded to ``nglview.show_mdtraj``.
            Defaults to True.

    Returns:
        view (nglview.widget): The nglview widget object.

    """
    # Third Party Imports
    from nglview import show_mdtraj

    return show_mdtraj(self.traj, gui=gui)

1406 

def dash_summary(self) -> pd.DataFrame:
    """Summarizes this instance in a two-column :obj:`pandas.DataFrame`.

    The ``field`` column names the property and the ``value`` column holds
    its string representation. The time step is shown as a single value, as
    ``'single frame'`` for one-frame trajectories, or as the array of all
    distinct steps.

    Returns:
        pd.DataFrame: The dataframe.

    """
    times = self.traj.time
    steps = np.unique(times[1:] - times[:-1])
    if len(steps) == 0:
        dt = "single frame"
    elif len(steps) == 1:
        dt = steps[0]
    else:
        dt = steps
    if self.index == (None,):
        index = "[::]"
    else:
        index = self.index[1:]
        if len(index) == 1:
            index = index[0]
    rows = [
        ("n_frames", self.n_frames),
        ("n_atoms", self.n_atoms),
        ("dt (ps)", dt),
        ("traj_file", self.traj_file),
        ("top_file", self.top_file),
        ("index", index),
        ("common_str", self.common_str),
    ]
    fields, values = zip(*rows)
    df = pd.DataFrame({"field": list(fields), "value": list(values)})
    return df.astype(str)

1449 

def load_custom_topology(
    self,
    custom_top: Optional[Union[CustomTopology, CustomAAsDict]] = None,
) -> None:
    """Installs a custom topology from a `CustomTopology` instance or a dict.

    A dict (or None) is converted with ``CustomTopology.from_dict(...,
    traj=self)``. Afterwards ``self.topology`` is replaced by the custom
    topology's ``top``.

    See Also:
        :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology`

    Args:
        custom_top: Optional[Union[CustomTopology, CustomAAsDict]]: An
            instance of :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology`
            or a dictionary that can be made into such.

    """
    # Encodermap imports
    from encodermap.trajinfo.trajinfo_utils import CustomTopology

    self._custom_top = (
        custom_top
        if isinstance(custom_top, CustomTopology)
        else CustomTopology.from_dict(custom_top, traj=self)
    )
    # the old topology is overwritten by the custom one
    self.topology = self._custom_top.top

1474 

def load_CV(
    self,
    data: SingleTrajFeatureType,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
    deg: Optional[bool] = None,
    periodic: bool = True,
    labels: Optional[list[str]] = None,
    override: bool = False,
) -> None:
    """Load CVs into traj. Many options are possible. Provide xarray,
    numpy array, em.loading.feature, em.featurizer, and even string!

    This method loads CVs into the SingleTraj instance. Many ways of doing so
    are available:
        - :obj:`numpy.ndarray`: The easiest way. Provide a np array and a name for
            the array, and the data will be saved as an instance variable,
            accessible via `SingleTraj.name`.
        - :obj:`xarray.DataArray`: You can load a multidimensional xarray as
            data into the class. Please refer to xarrays own documentation
            if you want to create one yourself.
        - :obj:`xarray.Dataset`: You can add another dataset to the existing _CVs.
        - :obj:`encodermap.loading.features.Feature`: If you provide one of the
            features from ``encodermap.loading.features`` the resulting
            features will be loaded and also be placed under the set name.
        - :obj:`encodermap.loading.featurizer.Featurizer`: If you provide a
            full featurizer, the data will be generated and be accessible
            as an attribute.
        - str: If a string is provided, the data will be loaded from a
            ``.txt``, ``.npy``, or NetCDF / HDF5 ``.nc`` file.

    Args:
        data (Union[str, numpy.ndarray, xr.DataArray,
            em.loading.features.Feature, em.loading.featurizer.Featurizer]):
            The CV to load. Either as :obj:`numpy.ndarray`,
            :obj:`xarray.DataArray`, EncoderMap feature,
            or EncoderMap Featurizer.
        attr_name (Optional[str]): The name under which the CV
            should be found in the class. Is needed, if a raw numpy array
            is passed, otherwise the name will be generated from the filename
            (if ``data == str``), the DataArray.name (if ``data == xarray.DataArray``),
            or the feature name.
        cols (Optional[list]): A list specifying the columns
            to use it for the high-dimensional data. If your highD data contains
            (x,y,z,...)-errors or has an enumeration column at ``col=0``
            this can be used to remove this unwanted data.
        deg (Optional[bool]): Whether the provided data is in radians (False)
            or degree (True). It can also be None for non-angular data.
        periodic (bool): Passed on to ``load_CVs_singletraj``. Defaults to True.
        labels (Optional[Union[list, str]]): If you want to label
            the data you provided, pass a list of str. If set to None,
            the features in this dimension will be labeled as
            ``[f"{attr_name.upper()} FEATURE {i}" for i in range(self.n_frames)]``.
            If a str is provided, the features will be labeled as
            ``[f"{attr_name.upper()} {label.upper()} {i}" for i in range(self.n_frames)]``.
            If a list of str is provided, it needs to have the same length
            as the traj has frames. Defaults to None.
        override (bool): Whether to overwrite existing CVs. The method will also
            print a message which CVs have been overwritten.

    Examples:

        Import EncoderMap and load an example Trajectory.

        >>> import encodermap as em
        >>> traj = em.SingleTraj("https://files.rcsb.org/view/1GHC.pdb")

        Load the central dihedrals using ``data='central_dihedrals'`` as shortcut.

        >>> traj.load_CV("central_dihedrals")
        >>> traj.central_dihedrals.shape
        (14, 222)


        >>> traj._CVs['central_dihedrals']['CENTRAL_DIHEDRALS'].values[:2]
        ['CENTERDIH PSI RESID MET: 1 CHAIN 0'
         'CENTERDIH OMEGA RESID MET: 1 CHAIN 0']

        Slicing the :obj:`SingleTraj` keeps all CVs in order.

        >>> import numpy as np
        >>> from pathlib import Path
        >>> traj1 = em.SingleTraj(
        ...     Path(em.__file__).parent.parent / "tests/data/1am7_corrected.xtc",
        ...     Path(em.__file__).parent.parent / "tests/data/1am7_protein.pdb",
        ... )
        >>> traj1.load_CV(traj1.xyz[..., -1], 'z_coordinate')
        ...
        >>> for i, frame in enumerate(traj1):
        ...     print(np.array_equal(frame.z_coordinate[0], frame.xyz[0, :, -1]))
        ...     if i == 3:
        ...         break
        True
        True
        True
        True

    Raises:
        FileNotFoundError: When the file given by `data` does not exist.
        IOError: When the provided filename does not have .txt, .npy or .nc extension.
        TypeError: When `data` does not match the specified input types.
        Exception: When a numpy array has been passed as `data` and no `attr_name` has been provided.
        Exception: When the provided `attr_name` is str, but cannot be a python identifier.
        Exception: When `attr_name` contains the protected substring
            'feature_indices'.
        Exception: When the new CVs span more than one trajectory, do not
            match the number of frames, or clash with existing CVs while
            `override` is False.

    """
    # Local Folder Imports
    from .trajinfo_utils import load_CVs_singletraj, trajs_combine_attrs

    if isinstance(attr_name, str):
        if "feature_indices" in attr_name:
            raise Exception(
                f"The substring 'feature_indices' is a protected attribute. "
                f"Your attribute can't contain this substring."
            )

    new_CVs = load_CVs_singletraj(
        data,
        self,
        attr_name,
        cols,
        deg=deg,
        periodic=periodic,
        labels=labels,
    )
    # Check for multi-trajectory results *before* the sanity assertion, so
    # users get the descriptive error rather than a bare AssertionError.
    if len(new_CVs.coords["traj_num"]) > 1:
        raise Exception(
            f"The provided feature resulted in a dataset with "
            f"{new_CVs.sizes['traj_num']} trajectories. A `SingleTraj` "
            f"class can't accept such a feature."
        )
    if self._CVs:
        assert (
            len(new_CVs.coords["traj_num"]) == 1
        ), f"something bad happened: {self._CVs=}"
    if self.traj_num is not None:
        assert new_CVs.coords["traj_num"] == np.array([self.traj_num]), (
            data,
            self.traj_num,
            new_CVs.coords["traj_num"],
        )

    # check the sizes
    len_CVs = new_CVs.coords["frame_num"].shape[0]
    if self._CVs:
        n_frames = self._CVs.coords["frame_num"].shape[0]
    else:
        n_frames = self.n_frames
    if n_frames != len_CVs:
        raise Exception(
            f"Loading the requested {data} CVs is not possible, as "
            f"they are not aligned with the number of frames in the "
            f"trajectory. The CVs have {len_CVs} frames, the trajectory "
            f"{self._traj_file} has {n_frames} frames."
        )

    # check the keys and whether they get overwritten
    if hasattr(new_CVs, "keys"):
        new_keys = new_CVs.keys()
    else:
        new_keys = set([new_CVs.name])
    if override:
        if overwritten_keys := self._CVs.keys() & new_keys:
            warnings.warn(
                f"Overwriting the following CVs with new values: "
                f"{overwritten_keys}."
            )
        # new_CVs comes first so that its values win on conflicts
        self._CVs = xr.merge(
            [new_CVs, self._CVs],
            combine_attrs=trajs_combine_attrs,
            compat="override",
            join="left",
        )
    else:
        try:
            CVs = xr.merge([self._CVs, new_CVs], combine_attrs=trajs_combine_attrs)
            assert len(CVs.coords["traj_num"]) == 1, (
                f"Can't merge\n\n{self._CVs=}\n\nand\n\n{new_CVs=}\n\n, "
                f"because they would stack along the traj axis."
            )
            self._CVs = CVs
        except xr.core.merge.MergeError as e:
            msg = (
                f"Could not add the CV `{attr_name}` to the CVs of the traj "
                f"likely due to it being already in the CVs "
                f"({list(self.CVs.keys())}). Set `override` to True to "
                f"overwrite these CVs. In case you are faced with "
                f"conflicting values on 'traj_name', here they are:\n\n"
                f"{self._CVs.coords['traj_name']=}\n\n{new_CVs.coords['traj_name']=}"
            )
            raise Exception(msg) from e

1664 

def _common_str_to_hdf_file(self, fname: Union[Path, str]) -> None:
    """Stores this traj's common_str in the attributes of the HDF5 file `fname`.

    The attribute is named ``common_str_{traj_num}``; a missing common_str
    is written as an empty string. The file is opened in append mode.

    """
    value = "" if self.common_str is None else self.common_str
    with h5py.File(fname, "a") as h5_file:
        h5_file.attrs[f"common_str_{self.traj_num}"] = value

1672 

def save_CV_as_numpy(
    self,
    attr_name: str,
    fname: Optional[str] = None,
    overwrite: bool = False,
) -> None:
    """Saves a specified collective variable of this traj as a ``.npy`` file.

    This got its own method for parallelization purposes.

    Args:
        attr_name (str): Name of the CV to save.
        fname (str, optional): Either a file name or an existing directory.
            If None, ``f"{self.basename}_{attr_name}.npy"`` in the current
            working directory is used. If a directory is given, that default
            file name is placed inside it. A missing ``.npy`` suffix is
            appended, as :func:`numpy.save` would do anyway. Defaults to None.
        overwrite (bool, opt): Whether to overwrite the file. Defaults to False.

    Raises:
        IOError: When the file already exists and overwrite is set to False.

    """
    if fname is None:  # pragma: no cover
        fname = f"{self.basename}_{attr_name}.npy"
    fname = os.fspath(fname)
    if os.path.isdir(fname):
        fname = os.path.join(fname, f"{self.basename}_{attr_name}.npy")
    if not fname.endswith(".npy"):
        # numpy.save appends ".npy" itself; check the file that will
        # actually be written, otherwise it is silently overwritten.
        fname += ".npy"
    if os.path.isfile(fname) and not overwrite:
        raise IOError(f"{fname} already exists. Set overwrite=True to overwrite.")
    np.save(fname, self.CVs[attr_name])

1699 

def sidechain_info(self) -> dict[int, int]:
    """Counts the side-chain dihedrals of every residue.

    The counts are used by the AngleDihedralCartesianEncoderMap class to
    allow training with multiple different sidechains. The
    ``side_dihedrals`` CV is loaded first if it is not present yet.

    Returns:
        dict[int, int]: Maps every residue number from 1 to
            ``self.top.n_residues`` to the number of that residue's
            side-chain dihedrals that are not NaN in the first frame. The
            residue number of a dihedral is the second-to-last integer in
            its ``SIDE_DIHEDRALS`` label.

    Raises:
        NotImplementedError: When a clustal_w alignment has been loaded.

    """
    # make sure no clustal w has been loaded
    if hasattr(self, "clustal_w"):
        raise NotImplementedError(
            f"This is currently not supported for TrajEsnembles with "
            f"clustal_w alignments."
        )
    if "side_dihedrals" not in self._CVs:
        self.load_CV("side_dihedrals")

    max_residues = self.top.n_residues
    if max_residues < 1:
        return {}

    side_dihedrals = self._CVs.side_dihedrals
    first_frame = self._CVs.coords["frame_num"].min()

    # Parse every label once instead of re-scanning all labels per residue.
    labels_by_residue: dict[int, list[str]] = {}
    for label in side_dihedrals.coords["SIDE_DIHEDRALS"].values:
        resid = int(re.findall(r"\d+", label)[-2])
        labels_by_residue.setdefault(resid, []).append(label)

    out = {}
    for residx in range(1, max_residues + 1):
        labels = labels_by_residue.get(residx)
        if not labels:
            # selecting no labels yields zero dihedrals
            out[residx] = 0
            continue
        out[residx] = (
            side_dihedrals.sel(frame_num=first_frame, SIDE_DIHEDRALS=labels)
            .dropna(dim="SIDE_DIHEDRALS")
            .sizes["SIDE_DIHEDRALS"]
        )
    return out

1744 

def atom_slice(
    self,
    atom_indices: np.ndarray,
    invert: bool = False,
) -> None:
    """Deletes atoms from this :obj:`SingleTraj` instance.

    Loaded CVs are dropped (with a warning), the topology is replaced by
    the corresponding subset of the custom topology, and the trajectory is
    reloaded from `traj_file` unless it is the placeholder ``'.'``.

    Args:
        atom_indices (Union[list, numpy.ndarray]): The indices of the
            atoms to keep.
        invert (bool): If False, it is assumed, that the atoms in ``atom_indices``
            are the ones to be kept. If True, the atoms in ``atom_indices``
            are the ones to be removed.

    """
    atom_indices = np.asarray(atom_indices)
    if invert:
        # A set gives O(1) membership tests; `in` on the numpy array scanned
        # it once per atom (O(n_atoms * n_indices)).
        to_remove = set(atom_indices.ravel().tolist())
        atom_indices = np.array(
            [a.index for a in self.top.atoms if a.index not in to_remove]
        )
    self._atom_indices = atom_indices
    if self._CVs:
        warnings.warn(
            "Dropping CVs from trajectory. Slicing CVs with this method is "
            "currently not possible. Raise an issue if you want to have this "
            "feature added."
        )
        self._CVs = xr.Dataset()
    # force the custom topology to be re-parsed for the reduced atom set
    self._custom_top._parsed = False
    self.topology = self._custom_top.top.subset(atom_indices)
    if self.traj_file != ".":
        # re-run the `_traj` setter so the trajectory is reloaded
        self._traj = self.traj_file

1777 

1778 def join(self, other: Union[SingleTraj, md.Trajectory]) -> md.Trajectory: 

1779 """Join two trajectories together along the time/frame axis. 

1780 

1781 Note: 

1782 Returns a ``mdtraj.Trajectory`` and thus loses CVs, filenames, etc. 

1783 

1784 """ 

1785 if isinstance(other, md.Trajectory): 

1786 return self.traj.join(other) 

1787 return self.traj.join(other.traj) 

1788 

1789 def stack(self, other: SingleTraj) -> md.Trajectory: 

1790 """Stack two trajectories along the atom axis 

1791 

1792 Note: 

1793 Returns a m``dtraj.Trajectory`` and thus loses CVs, filenames, etc. 

1794 

1795 """ 

1796 if isinstance(other, md.core.trajectory.Trajectory): 

1797 return self.traj.stack(other) 

1798 return self.traj.stack(other.traj) 

1799 

1800 def superpose( 

1801 self, 

1802 reference: Union[md.Trajectory, SingleTraj], 

1803 frame: int = 0, 

1804 atom_indices: Optional[np.ndarray] = None, 

1805 ref_atom_indices: Optional[np.ndarray] = None, 

1806 parallel: bool = True, 

1807 inherit_CVs: bool = False, 

1808 ) -> SingleTraj: 

1809 """Superpose each conformation in this trajectory upon a reference 

1810 

1811 Args: 

1812 reference (Union[mdtraj.Trajectory, SingleTraj]): The reference 

1813 frame to align to. If the reference has multiple frames and 

1814 you want to use a specific frame as reference, use the 

1815 ``frame`` argument also. 

1816 frame (int, optional): Align to this frame in reference. 

1817 Default is 0. 

1818 atom_indices (Union[np.array, None], optional): Indices in 

1819 self, used to calculate RMS values. Defaults to None which 

1820 means all atoms will be used. 

1821 ref_atom_indices (Union[np.array, None], optional): Indices 

1822 in reference, used to calculate RMS values. Defaults to None 

1823 which means all atoms will be used. 

1824 parallel (bool, optional): Use OpenMP to run the superposition 

1825 in parallel over multiple cores. 

1826 inherit_CVs (bool, optional): Whether to also inherit the CVs. 

1827 This feature is currently not implemented. It would require 

1828 additional code in all Feature classes discerning intrinsic 

1829 (distance, angle, cluster_membership, etc.) or an extrinsic 

1830 feature (absolute coordinate, COG position, etc.). Then this 

1831 extrinsic/intrinsic boolean flag also needs to accompany the 

1832 xarray Datasets, so that the intrinsic CVs can be inherited, 

1833 and the extrinsic can be dropped with a corresponding message. 

1834 

1835 Returns: 

1836 SingleTraj: A new trajectory with atoms aligned. 

1837 

1838 """ 

1839 if isinstance(reference, md.core.trajectory.Trajectory): 

1840 new = SingleTraj( 

1841 self.traj.superpose( 

1842 reference, frame, atom_indices, ref_atom_indices, parallel 

1843 ) 

1844 ) 

1845 new = SingleTraj( 

1846 self.traj.superpose( 

1847 reference.traj, frame, atom_indices, ref_atom_indices, parallel 

1848 ) 

1849 ) 

1850 if self.traj_file: 

1851 new._traj_file = self._traj_file 

1852 if self.top_file: 

1853 new._top_file = self._top_file 

1854 if self._CVs and inherit_CVs: 

1855 raise NotImplementedError("CV inheritance not implemented yet.") 

1856 return new 

1857 

1858 def save( 

1859 self, 

1860 fname: str, 

1861 CVs: Union[Literal["all"], list[str]] = "all", 

1862 overwrite: bool = False, 

1863 ) -> None: 

1864 """Save the trajectory as HDF5 file format to disk. 

1865 

1866 Args: 

1867 fname (str): The filename. 

1868 CVs (Union[List, 'all'], optional): Either provide a list of strings 

1869 of the CVs you would like to save to disk, or set to 'all' to save 

1870 all CVs. Defaults to []. 

1871 overwrite (bool, optional): Whether force overwrite an existing file. 

1872 Defaults to False. 

1873 

1874 Raises: 

1875 IOError: When the file already exists and overwrite is False. 

1876 

1877 """ 

1878 # check and drop inhomogeneous attributes 

1879 offending_keys = [] 

1880 if self._CVs: 

1881 for da in self._CVs.data_vars.values(): 

1882 for key, val in da.attrs.items(): 

1883 if isinstance(val, list): 

1884 offending_keys.append(key) 

1885 for key in offending_keys: 

1886 for da in self._CVs.data_vars.values(): 

1887 if key in da.attrs: 

1888 del da.attrs[key] 

1889 if key in self._CVs.attrs: 

1890 del self._CVs.attrs[key] 

1891 # raise exception if file already exists 

1892 if os.path.isfile(fname) and not overwrite: 

1893 raise IOError(f"{fname} already exists. Set overwrite=True to overwrite.") 

1894 else: 

1895 self.traj.save_hdf5(fname, force_overwrite=overwrite) 

1896 

1897 # save the topology as a nested dict 

1898 if self._custom_top: 

1899 self._custom_top.traj = self 

1900 self._custom_top.to_hdf_file(fname=fname) 

1901 

1902 if self.common_str: 

1903 with h5.File(fname, "a") as file: 

1904 if self.traj_num is None: 

1905 file.attrs["common_str"] = self.common_str 

1906 else: 

1907 file.attrs[f"common_str_{self.traj_num}"] = self.common_str 

1908 

1909 # save the CVs with xarray 

1910 if self._CVs and CVs == "all": 

1911 save_netcdf_alongside_mdtraj(fname, self._CVs) 

1912 return 

1913 if self._CVs and isinstance(CVs, list): 

1914 with h5.File(fname, "a") as file: 

1915 if "CVs" in list(file.keys()): 

1916 grp = file["CVs"] 

1917 else: 

1918 grp = file.create_group("CVs") 

1919 for key in CVs: 

1920 value = self._CVs[key] 

1921 assert self.n_frames == value.shape[1] 

1922 grp.create_dataset(name=key, data=value) 

1923 

1924 @overload 

1925 def iterframes( 1925 ↛ exitline 1925 didn't jump to the function exit

1926 self, 

1927 with_traj_num: bool = False, 

1928 ) -> Iterable[tuple[int, SingleTraj]]: ... 

1929 

1930 @overload 

1931 def iterframes( 1931 ↛ exitline 1931 didn't jump to the function exit

1932 self, 

1933 with_traj_num: bool = True, 

1934 ) -> Iterable[tuple[int, int, SingleTraj]]: ... 

1935 

1936 def iterframes( 

1937 self, 

1938 with_traj_num: bool = False, 

1939 ) -> Iterable[tuple[Any, ...]]: 

1940 """Iterator over the frames in this class. 

1941 

1942 Args: 

1943 with_traj_num (bool): Whether to return a three-tuple of traj_num, 

1944 frame_num, frame (True) or just traj_num, frame (False). 

1945 

1946 Yields: 

1947 tuple: A tuple containing the following: 

1948 - int: The traj_num. 

1949 - int: The frame_num. 

1950 - encodermap.SingleTraj: An SingleTraj object. 

1951 

1952 Examples: 

1953 

1954 Import EncoderMap and create :obj:`SingleTraj` instance. 

1955 

1956 >>> import encodermap as em 

1957 >>> traj = em.SingleTraj('https://files.rcsb.org/view/1YUG.pdb') 

1958 >>> traj.n_frames 

1959 15 

1960 

1961 Slicing the trajectory every 5th frame 

1962 

1963 >>> traj = traj[::5] 

1964 >>> traj.n_frames 

1965 3 

1966 

1967 Using the ``iterframes()`` iterator. 

1968 

1969 >>> for frame_num, frame in traj.iterframes(): 

1970 ... print(frame_num, frame.n_frames) 

1971 0 1 

1972 5 1 

1973 10 1 

1974 

1975 """ 

1976 if self.id.ndim == 2: 

1977 a = self.id[:, 1] 

1978 else: 

1979 a = self.id 

1980 for i, frame in zip(a, self): 

1981 if with_traj_num: 

1982 yield self.traj_num, i, frame 

1983 else: 

1984 yield i, frame 

1985 

1986 def __copy__(self): 

1987 cls = self.__class__ 

1988 result = cls.__new__(cls) 

1989 result.__dict__.update(self.__dict__) 

1990 return result 

1991 

1992 def __deepcopy__(self, memo): 

1993 # Standard Library Imports 

1994 from copy import deepcopy 

1995 

1996 cls = self.__class__ 

1997 result = cls.__new__(cls) 

1998 memo[id(self)] = result 

1999 for k, v in self.__dict__.items(): 

2000 setattr(result, k, deepcopy(v, memo)) 

2001 return result 

2002 

2003 def __hash__(self) -> int: 

2004 hash_value = hash(self.top) 

2005 # combine with hashes of arrays 

2006 hash_value ^= _hash_numpy_array(self._xyz) 

2007 hash_value ^= _hash_numpy_array(self.time) 

2008 if self._unitcell_lengths is not None: 

2009 hash_value ^= _hash_numpy_array(self._unitcell_lengths) 

2010 if self._unitcell_angles is not None: 

2011 hash_value ^= _hash_numpy_array(self._unitcell_angles) 

2012 return hash_value 

2013 

2014 def __eq__(self, other: SingleTraj) -> bool: 

2015 """Two SingleTraj objects are the same, when the trajectories are the same, 

2016 the files are the same, and the loaded CVs are the same.""" 

2017 if self._traj_file != other._traj_file: 

2018 return False 

2019 if self.__hash__() != other.__hash__(): 

2020 return False 

2021 if not self._CVs.equals(other._CVs): 

2022 return False 

2023 return True 

2024 

2025 def __reversed__(self) -> SingleTraj: 

2026 """Reverses the frame order of the traj. Same as traj[::-1]""" 

2027 return self[::-1] 

2028 

2029 def __enter__(self): 

2030 """Enters context manager. Inside the context manager, the traj stays loaded.""" 

2031 self.load_traj() 

2032 return self 

2033 

2034 def __exit__(self, type, value, traceback): 

2035 """Exits the context manager and deletes unwanted variables.""" 

2036 self.unload() 

2037 

2038 def __iter__(self): 

2039 """Iterate over frames in this class. Returns the correct 

2040 CVs along with the frame of the trajectory.""" 

2041 self._index = 0 

2042 if len(self) == 0 and self.index is None: 

2043 self.load_traj() 

2044 return self 

2045 

2046 def __next__(self): 

2047 if len(self.id) == 1: 

2048 return self 

2049 if self._index >= self.n_frames: 

2050 raise StopIteration 

2051 else: 

2052 self._index += 1 

2053 return self[self._index - 1] 

2054 

2055 def __getitem__(self, key: CanBeIndex) -> SingleTraj: 

2056 """This method returns another trajectory as a SingleTraj class. 

2057 

2058 Args: 

2059 key (Union[int, list[int], numpy.ndarray, slice]): Indexing the trajectory 

2060 can be done by int (returns a traj with 1 frame), lists of int or 

2061 numpy.ndarray (returns a new traj with len(traj) == len(key)), or 

2062 slice ([::3]), which returns a new traj with the correct number of 

2063 frames. 

2064 

2065 Returns: 

2066 Info_Single: A SingleTraj object with this frame in it. 

2067 

2068 """ 

2069 if not isinstance(key, (int, np.int_, list, np.ndarray, slice)): 

2070 raise TypeError( 

2071 f"Indexing of `SingleTraj` requires the index to " 

2072 f"be one of the following types: (int, " 

2073 f"list, np.ndarray, slice), you provided {type(key)}." 

2074 ) 

2075 

2076 if any([isinstance(i, (int, np.integer)) for i in self.index]) and key != 0: 

2077 raise IndexError( 

2078 f"SingleTraj index {key} out of range for traj with only 1 frame." 

2079 ) 

2080 

2081 if isinstance(key, (int, np.integer)): 

2082 if key > self.n_frames: 

2083 raise IndexError( 

2084 f"Index {key} out of range for traj with " 

2085 f"{self.n_frames} frames." 

2086 ) 

2087 if isinstance(key, (list, np.ndarray)): 

2088 if any([k > self.n_frames for k in key]): 

2089 raise IndexError( 

2090 f"At least one index in {key} out of range for " 

2091 f"traj with {self.n_frames} frames. Normally frames are " 

2092 f"selected by current integer index. If you are trying to " 

2093 f"access frames by their number as it is in the file {self.traj_file}, " 

2094 f"you can use the `fsel[]` locator of this class:\n\n" 

2095 f"traj = em.load('traj_file.xtc', 'top_file.xtc')\n" 

2096 f"traj.fsel[{key}]." 

2097 ) 

2098 

2099 # append the index to the list of "transformations" 

2100 new_index = (*self.index, key) 

2101 

2102 # build a new traj from that 

2103 if self.backend == "no_load": 

2104 traj_out = SingleTraj( 

2105 self.traj_file, 

2106 self.top_file, 

2107 backend=self.backend, 

2108 common_str=self.common_str, 

2109 index=new_index, 

2110 traj_num=self.traj_num, 

2111 basename_fn=self.basename_fn, 

2112 ) 

2113 else: 

2114 traj_out = SingleTraj( 

2115 self.trajectory[key], 

2116 self.top_file, 

2117 backend=self.backend, 

2118 common_str=self.common_str, 

2119 index=new_index, 

2120 traj_num=self.traj_num, 

2121 basename_fn=self.basename_fn, 

2122 ) 

2123 traj_out._traj_file = self._traj_file 

2124 traj_out._top_file = self._top_file 

2125 assert traj_out._traj_file == self._traj_file 

2126 

2127 # the original_frames 

2128 if self._atom_indices is not None: 

2129 traj_out.atom_slice(self._atom_indices) 

2130 assert self.n_atoms == traj_out.n_atoms 

2131 assert self.n_residues == traj_out.n_residues 

2132 traj_out._orig_frames = self._orig_frames 

2133 traj_out._loaded_once = self._loaded_once 

2134 

2135 # last the CVs 

2136 if self._CVs: 

2137 traj_out._CVs = self._CVs.isel(frame_num=key) 

2138 if "frame_num" not in traj_out._CVs.dims: 

2139 traj_out._CVs = traj_out._CVs.expand_dims( 

2140 { 

2141 "frame_num": [key], 

2142 }, 

2143 ) 

2144 traj_out._CVs.assign_coords(time=("frame_num", traj_out.time)) 

2145 

2146 return traj_out 

2147 

2148 @property 

2149 def fsel(self): 

2150 return SingleTrajFsel(self) 

2151 

2152 def __add__(self, y: SingleTraj) -> TrajEnsemble: 

2153 """Addition of two SingleTraj classes yields TrajEnsemble class. 

2154 

2155 Args: 

2156 y (encodermap.SingleTraj): The other traj, that will be added. 

2157 

2158 Returns: 

2159 encodermap.TrajEnsemble: The new trajs. 

2160 

2161 """ 

2162 return self._add_along_traj(y) 

2163 

2164 def __getattr__(self, attr): 

2165 """What to do when attributes cannot be obtained in a normal way?. 

2166 

2167 This method allows access to the `self.CVs` dictionary's values as 

2168 instance variables. Furthermore, if a mdtraj variable is called, 

2169 the traj is loaded, and the correct variable is returned. 

2170 

2171 """ 

2172 if attr in self._mdtraj_attr: 

2173 self.load_traj() 

2174 return getattr(self.traj, attr) 

2175 elif attr in self._CVs: 

2176 val = self._CVs[attr] # [index] 

2177 axis_name = ( 

2178 "feature_axis" 

2179 if "feature_axis" in val.attrs 

2180 else "feature_axes" if "feature_axes" in val.attrs else None 

2181 ) 

2182 if np.any(np.isnan(val)): 

2183 if axis_name is not None: 

2184 if "indices" in val.name: 

2185 assert val.dropna("ATOM_NO").size > 0, ( 

2186 f"Can't `dropna` from DataArray {val=}. This would " 

2187 f"result in an empty array. Here are the nan indices: " 

2188 f"{np.where(np.isnan(val.values))[0]=}" 

2189 ) 

2190 val = val.dropna("ATOM_NO") 

2191 else: 

2192 val = val.dropna(val.attrs[axis_name]) 

2193 else: 

2194 val = val.dropna(attr.upper()) 

2195 val_squeezed = val.values.squeeze(0) 

2196 assert val_squeezed.size > 0, ( 

2197 f"The value of {attr} is unexpectedly empty {val.size=} after dropping nans " 

2198 f"and dropping nans from he feature axis '{val.attrs[axis_name]}' and " 

2199 f"squeezing out the trajectory axis. This is the original " 

2200 f"DataArray:\n{self._CVs[attr]}\n\nAfter the dropna, the value was:\n{val}" 

2201 ) 

2202 return val_squeezed 

2203 elif attr == "traj": # pragma: no cover 

2204 self.__getattribute__(attr) 

2205 elif attr == "id": # pragma: no cover 

2206 self.__getattribute__(attr) 

2207 elif attr == "top": 

2208 return self._get_top() 

2209 else: 

2210 raise AttributeError(f"'SingleTraj' object has no attribute '{attr}'") 

2211 

2212 def _string_summary(self) -> str: # pragma: no cover 

2213 """Returns a summary about the current instance. 

2214 

2215 Number of frames, index, loaded CVs. 

2216 

2217 """ 

2218 s = f"encodermap.SingleTraj object." 

2219 if self.backend == "no_load": 

2220 s += " Currently not in memory." 

2221 else: 

2222 s += " Data currently in memory." 

2223 if self.basename: 

2224 s += f" Basename is '{self.basename}'." 

2225 if self.index != (None,): 

2226 with np.printoptions(threshold=1, edgeitems=1): 

2227 s += f" At indices {self.index[1:]}." 

2228 if self._CVs: 

2229 for key, value in self._CVs.items(): 

2230 if "feature_indices" in key: 

2231 continue 

2232 shape = value.shape 

2233 if not shape: 

2234 shape = 1 

2235 s += f" CV {key} with shape {shape} loaded." 

2236 else: 

2237 s += " Not containing any CVs." 

2238 if "n_atoms" in self.__dict__.keys(): 

2239 s += f" Containing {self.n_atoms} atoms." 

2240 if "n_frames" in self.__dict__.keys(): 

2241 s += f" Containing {self.n_frames} frames." 

2242 if self.common_str: 

2243 s += f" Common string is '{self.common_str}'." 

2244 if self._custom_top: 

2245 s += ( 

2246 f" Custom topology with residues: " 

2247 f"{[r.name for r in self._custom_top.residues]}." 

2248 ) 

2249 return s 

2250 

2251 def __len__(self): 

2252 return self.n_frames 

2253 

2254 def __str__(self): 

2255 return self._string_summary() 

2256 

2257 def __repr__(self): 

2258 return f"<{self._string_summary()} Object at 0x{id(self):02x}>"