Coverage for encodermap/trajinfo/info_all.py: 8%

1225 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2025-05-15 21:06 +0200

1# -*- coding: utf-8 -*- 

2# encodermap/trajinfo/info_all.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Classes to work with ensembles of trajectories. 

23 

24The statistics of a protein can be better described by an ensemble of proteins, 

25rather than a single long trajectory. Treating a protein in such a way opens great 

26possibilities and changes the way one can treat molecular dynamics data. 

27Trajectory ensembles allow: 

28 * Faster convergence via adaptive sampling. 

29 * Better grasp of equilibrium and off-equilibrium dynamics. 

30 

31 

32This subpackage contains two classes which are containers of trajectory data. 

33The SingleTraj trajectory contains information about a single trajectory. 

34The TrajEnsemble class contains information about multiple trajectories. This adds 

35a new dimension to MD data. The time and atom dimension are already established. 

36Two frames can be appended along the time axis to get a trajectory with multiple 

37frames. If they are appended along the atom axis, the new frame contains the 

38atoms of these two. The trajectory works in a similar fashion. Adding two trajectories 

39along the trajectory axis returns a trajectory ensemble, represented as a `TrajEnsemble` 

40class in this package. 

41 

42See also: 

43 http://statisticalbiophysicsblog.org/?p=92 

44 

45""" 

46 

47################################################################################ 

48# Imports 

49################################################################################ 

50 

51 

52# Future Imports at the top 

53from __future__ import annotations 

54 

55# Standard Library Imports 

56import contextlib 

57import copy 

58import json 

59import operator 

60import os 

61import re 

62import warnings 

63from collections.abc import Callable, Iterator, KeysView, Sequence 

64from copy import deepcopy 

65from pathlib import Path 

66from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload 

67 

68# Third Party Imports 

69import numpy as np 

70import tables 

71from optional_imports import _optional_import 

72from tqdm import tqdm as normal_tqdm_ 

73from tqdm.notebook import tqdm as notebook_tqdm_ 

74 

75# Encodermap imports 

76from encodermap._typing import CanBeIndex, CustomAAsDict 

77from encodermap.misc.misc import ( 

78 _TOPOLOGY_EXTS, 

79 _is_notebook, 

80 get_full_common_str_and_ref, 

81) 

82from encodermap.misc.xarray_save_wrong_hdf5 import save_netcdf_alongside_mdtraj 

83from encodermap.trajinfo.trajinfo_utils import CustomTopology, TrajEnsembleFeatureType 

84 

85 

86################################################################################ 

87# Typing 

88################################################################################ 

89 

90 

91if TYPE_CHECKING: # pragma: no cover 

92 # Third Party Imports 

93 import mdtraj as md 

94 import pandas as pd 

95 import tensorflow as tf 

96 import xarray as xr 

97 

98 # Encodermap imports 

99 from encodermap.trajinfo.info_single import SingleTraj 

100 

101 

102string_types = (str,) 

103 

104 

105################################################################################ 

106# Optional Imports 

107################################################################################ 

108 

109 

110md = _optional_import("mdtraj") 

111pd = _optional_import("pandas") 

112xr = _optional_import("xarray") 

113HDF5TrajectoryFile = _optional_import("mdtraj", "formats.hdf5") 

114h5py = _optional_import("h5py") 

115 

116################################################################################ 

117# Globals 

118################################################################################ 

119 

120 

121__all__: list[str] = ["TrajEnsemble"] 

122 

123 

124################################################################################ 

125# Utils 

126################################################################################ 

127 

128 

class notebook_tqdm(notebook_tqdm_):
    """Notebook progress bar that accepts and discards a ``function`` keyword.

    Callers may pass ``function=...`` to the constructor, to ``update()`` and
    to ``reset()``. This class ignores that keyword, so it never reaches the
    underlying notebook progress bar.

    """

    def __init__(self, *args, **kwargs):
        """Create the bar, forwarding every argument except ``function``."""
        forwarded = {key: value for key, value in kwargs.items() if key != "function"}
        super().__init__(*args, **forwarded)

    def reset(self, total=None, **kwargs):
        """Set a new ``total`` and redraw the bar; extra keywords are ignored."""
        self.total = total
        self.refresh()

    def update(self, n=1, **kwargs):
        """Advance the bar by ``n``; extra keywords are ignored."""
        super().update(n=n)

141 

142 

class normal_tqdm(normal_tqdm_):
    """Progress bar that keeps per-caller bookkeeping of updates and totals.

    Callers can tag ``__init__``, ``update()`` and ``reset()`` with a
    ``function`` keyword. For every tag the bar records how many times
    ``update()`` was called and how much that tag added to the total. The
    records live in ``self._calls`` and are printed after each operation when
    the environment variable ``ENCODERMAP_PRINT_PROG_UPDATES`` equals
    ``"True"``.

    """

    def __init__(self, *args, **kwargs):
        """Create the bar.

        Args:
            *args: Passed on to the parent progress bar.
            **kwargs: Passed on to the parent progress bar, except for the
                optional ``function`` keyword, which names the caller that
                created the bar.

        """
        self._calls = {}
        # ``function`` is optional here, just like in ``update()`` and
        # ``reset()``; requiring it made a plain construction raise KeyError.
        function = kwargs.pop("function", None)
        super().__init__(*args, **kwargs)
        if function is not None:
            self._calls[function] = {
                "update_calls": 0,
                "total": self.total,
            }
        self.print = os.getenv("ENCODERMAP_PRINT_PROG_UPDATES", "False") == "True"
        if self.print:
            print("INSTANTIATION")
            self.debug_print()

    def debug_print(self):
        """Print the bookkeeping table of this bar to stdout."""
        print(f"Progbar {id(self)}")
        for function, data in self._calls.items():
            # ``str()`` keeps the alignment for ints and also lets a bar
            # without a total (``None``) be printed instead of raising.
            print(
                f"{str(function):<15} total: {str(data['total']):>3} "
                f"n: {data['update_calls']:>3}"
            )
        print("\n")

    def update(self, n=1, **kwargs):
        """Advance the bar by ``n`` and count the call for ``function``.

        Args:
            n (int): Increment that is passed to the parent ``update()``.
            **kwargs: Only the optional ``function`` keyword is used.

        """
        function = kwargs.pop("function", None)
        if function is not None:
            self._calls.setdefault(function, {"update_calls": 0, "total": 0})
        if self.print:
            print(f"BEFORE UPDATE ({function})")
            self.debug_print()
        super().update(n)
        if function is not None:
            self._calls[function]["update_calls"] += 1
            if self.print:
                print(f"AFTER UPDATE ({function})")
                self.debug_print()

    def reset(self, total=None, **kwargs):
        """Grow the total of the bar to ``total`` and redraw it.

        Args:
            total (int): The new total. It has to be larger than the current
                total of the bar.
            **kwargs: Only the optional ``function`` keyword is used. The
                difference between the new and the old total is credited to
                that caller.

        Raises:
            AssertionError: If ``total`` is not larger than the current total.

        """
        # A bar that was created without a total starts counting from zero.
        previous_total = 0 if self.total is None else self.total
        assert total > previous_total, (
            f"The new total ({total}) has to be larger than the current "
            f"total ({previous_total})."
        )
        added = total - previous_total
        function = kwargs.pop("function", None)
        if function is not None:
            if function not in self._calls:
                self._calls[function] = {
                    "update_calls": 0,
                    "total": added,
                }
            else:
                self._calls[function]["total"] += added
        if self.print:
            print(f"BEFORE RESET ({function})")
            self.debug_print()
        self.total = total
        self.refresh()
        if self.print and function is not None:
            print(f"AFTER RESET ({function})")
            self.debug_print()

203 

204 

@contextlib.contextmanager
def temp_seed(seed: int) -> Iterator[None]:
    """Seed numpy's global random generator for the duration of a ``with`` block.

    The state of ``np.random`` is saved on entry and put back on exit, also
    when the body of the block raises.

    Args:
        seed (int): The seed.

    Examples:
        >>> from encodermap.trajinfo.info_all import temp_seed
        >>> import numpy as np
        >>> with temp_seed(123456789):
        ...     print(np.random.randint(low=0, high=10, size=(5, )))
        [8 2 9 7 4]

    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield None
    finally:
        np.random.set_state(saved_state)

226 

227 

228def _check_mode(m, modes): 

229 if m not in modes: 

230 raise ValueError( 

231 "This operation is only available when a file " 'is open in mode="%s".' % m 

232 ) 

233 

234 

235################################################################################ 

236# Classes 

237################################################################################ 

238 

239 

class HDF5GroupWrite(md.formats.HDF5TrajectoryFile):
    """HDF5 trajectory file that holds several trajectories side by side.

    All arrays and topologies are stored directly below the root of the file.
    Their node names carry a suffix that identifies the trajectory they belong
    to: ``coordinates_<group_id>``, ``time_<group_id>``,
    ``cell_lengths_<group_id>``, ``cell_angles_<group_id>`` and
    ``topology_<group_id>``.

    """

    def keys(self) -> KeysView:
        """Return the names of all nodes directly below the root of the file.

        Raises:
            Exception: If the file is closed.

        """
        if not self._open:
            raise Exception("Can't view keys of closed HDF5 file.")
        nodes = [n.name for n in self._handle.list_nodes("/")]
        return KeysView(nodes)

    def _initialize_headers(
        self,
        group_id: str,
        n_atoms: int,
        set_coordinates: bool,
        set_time: bool,
        set_cell: bool,
        set_velocities: bool,
        set_kineticEnergy: bool,
        set_potentialEnergy: bool,
        set_temperature: bool,
        set_alchemicalLambda: bool,
    ) -> None:
        """Write the file-level attributes and create the arrays of one group.

        Args:
            group_id (str): Suffix that is appended to the name of every array.
            n_atoms (int): Number of atoms of the per-atom arrays.
            set_coordinates (bool): Create ``coordinates_<group_id>``.
            set_time (bool): Create ``time_<group_id>``.
            set_cell (bool): Create ``cell_lengths_<group_id>`` and
                ``cell_angles_<group_id>``.
            set_velocities (bool): Create ``velocities_<group_id>``.
            set_kineticEnergy (bool): Create ``kineticEnergy_<group_id>``.
            set_potentialEnergy (bool): Create ``potentialEnergy_<group_id>``.
            set_temperature (bool): Create ``temperature_<group_id>``.
            set_alchemicalLambda (bool): Create ``lambda_<group_id>``.

        """
        # Local Folder Imports
        from .._version import get_versions

        version = get_versions()["version"]
        self._n_atoms = n_atoms

        root_attrs = self._handle.root._v_attrs
        root_attrs.conventions = "Pande"
        root_attrs.conventionVersion = "1.1"
        root_attrs.program = "MDTraj"
        root_attrs.programVersion = version
        root_attrs.title = "title"

        # if the client has not set the application attribute themselves,
        # we'll set it to MDTraj as a default option.
        if not hasattr(root_attrs, "application"):
            root_attrs.application = "MDTraj"

        # arrays that store frame level information:
        # (requested, base name, shape of an empty array, units)
        per_frame = (0,)
        per_atom = (0, self._n_atoms, 3)
        frame_arrays = (
            (set_coordinates, "coordinates", per_atom, "nanometers"),
            (set_time, "time", per_frame, "picoseconds"),
            (set_cell, "cell_lengths", (0, 3), "nanometers"),
            (set_cell, "cell_angles", (0, 3), "degrees"),
            (set_velocities, "velocities", per_atom, "nanometers/picosecond"),
            (set_kineticEnergy, "kineticEnergy", per_frame, "kilojoules_per_mole"),
            (
                set_potentialEnergy,
                "potentialEnergy",
                per_frame,
                "kilojoules_per_mole",
            ),
            (set_temperature, "temperature", per_frame, "kelvin"),
            (set_alchemicalLambda, "lambda", per_frame, "dimensionless"),
        )
        for requested, base_name, shape, units in frame_arrays:
            if not requested:
                continue
            node_name = f"{base_name}_{group_id}"
            self._create_earray(
                where="/",
                name=node_name,
                atom=self.tables.Float32Atom(),
                shape=shape,
            )
            self._get_node("/", name=node_name).attrs["units"] = units

    def write_into_group(
        self,
        group_id: str,
        coordinates: np.ndarray,
        time: np.ndarray,
        cell_lengths: np.ndarray,
        cell_angles: np.ndarray,
        topology: md.Topology,
    ) -> None:
        """Append frames to the arrays of one group and write its topology.

        The arrays of the group are created on the first write. Whether they
        already exist is checked for every call, so several groups can be
        written into the same file in mode ``"w"`` as well as in mode ``"a"``.

        Args:
            group_id (str): Suffix of the nodes that are written.
            coordinates (np.ndarray): Coordinates in nanometers. A single
                frame without a leading frame axis is accepted as well.
            time (np.ndarray): Time of each frame in picoseconds. Can be None.
            cell_lengths (np.ndarray): Unit cell lengths of each frame in
                nanometers. Can be None, but only together with ``cell_angles``.
            cell_angles (np.ndarray): Unit cell angles of each frame in
                degrees. Can be None, but only together with ``cell_lengths``.
            topology (md.Topology): Topology that is written to
                ``topology_<group_id>``.

        Raises:
            ValueError: If the file is not open in mode ``"w"`` or ``"a"``,
                if only one of ``cell_lengths`` and ``cell_angles`` is given,
                or if the provided fields do not match the fields that the
                group already contains.

        """
        # Third Party Imports
        from mdtraj.utils import ensure_type, in_units_of

        _check_mode(self.mode, ("w", "a"))

        # Decide per group whether the arrays still have to be created. Tying
        # this to mode "a" only left every group after the first one
        # uninitialized in mode "w".
        try:
            existing = getattr(self._handle.root, f"coordinates_{group_id}")
        except self.tables.NoSuchNodeError:
            self._frame_index = 0
            self._needs_initialization = True
        else:
            self._frame_index = len(existing)
            self._needs_initialization = False

        # these must be either both present or both absent. since
        # we're going to throw an error if one is present w/o the other,
        # lets do it now.
        if cell_lengths is None and cell_angles is not None:
            raise ValueError("cell_angles were given, but no cell_lengths")
        if cell_lengths is not None and cell_angles is None:
            raise ValueError("cell_lengths were given, but no cell_angles")

        # if the input arrays are openmm.unit.Quantities, convert them
        # into md units. Note that this acts as a no-op if the user doesn't
        # have openmm.unit installed (e.g. they didn't install OpenMM)
        coordinates = in_units_of(coordinates, None, "nanometers")
        time = in_units_of(time, None, "picoseconds")
        cell_lengths = in_units_of(cell_lengths, None, "nanometers")
        cell_angles = in_units_of(cell_angles, None, "degrees")

        def as_float32(array, name, ndim, shape, can_be_none=True):
            # Type- and shape-check one array. Arrays that lack the leading
            # frame axis (a single frame) get that axis added.
            return ensure_type(
                array,
                dtype=np.float32,
                ndim=ndim,
                name=name,
                shape=shape,
                can_be_none=can_be_none,
                warn_on_cast=False,
                add_newaxis_on_deficient_ndim=True,
            )

        coordinates = as_float32(
            coordinates, "coordinates", 3, (None, None, 3), can_be_none=False
        )
        n_frames, n_atoms = coordinates.shape[0:2]
        time = as_float32(time, "time", 1, (n_frames,))
        cell_lengths = as_float32(cell_lengths, "cell_lengths", 2, (n_frames, 3))
        cell_angles = as_float32(cell_angles, "cell_angles", 2, (n_frames, 3))

        # if this is our first write to this group, we need to create the
        # headers and the arrays in the underlying HDF5 file
        if self._needs_initialization:
            self._initialize_headers(
                group_id=group_id,
                n_atoms=n_atoms,
                set_coordinates=True,
                set_time=(time is not None),
                set_cell=(cell_lengths is not None or cell_angles is not None),
                set_velocities=False,
                set_kineticEnergy=False,
                set_potentialEnergy=False,
                set_temperature=False,
                set_alchemicalLambda=False,
            )
            self._needs_initialization = False

        # we need to check that the entries that the user is trying
        # to save are actually fields in OUR file. The explicit mapping
        # replaces a lookup in ``locals()`` that derived the variable name
        # from the node name and broke for group ids that are part of a
        # field name.
        fields = {
            f"coordinates_{group_id}": coordinates,
            f"time_{group_id}": time,
            f"cell_angles_{group_id}": cell_angles,
            f"cell_lengths_{group_id}": cell_lengths,
        }
        for name, contents in fields.items():
            try:
                node = self._get_node(where="/", name=name)
            except self.tables.NoSuchNodeError as err:
                if contents is None:
                    # not supplied and not expected by the file
                    continue
                raise ValueError(
                    f"The file that you're trying to save to doesn't "
                    f"contain the field {name}. You can always save a new trajectory "
                    f"and have it contain this information, but I don't allow 'ragged' "
                    f"arrays. If one frame is going to have {name} information, then I expect "
                    f"all of them to. So I can't save it for just these frames. Sorry "
                    f"about that :)"
                ) from err
            if contents is None:
                # the file expects a field that was not supplied
                raise ValueError(
                    f"The file that you're saving to expects each frame "
                    f"to contain {name} information, but you did not supply it. "
                    f"I don't allow 'ragged' arrays. If one frame is going "
                    f"to have {name} information, then I expect all of them to. "
                )
            node.append(contents)

        self._frame_index += n_frames
        self.flush()
        self.write_topology(group_id, topology)

    def write_topology(
        self,
        group_id: str,
        topology_object: md.Topology,
    ) -> None:
        """Writes the topology into the node ``topology_<group_id>``.

        Args:
            group_id (str): The suffix of the topology node. The topology is
                stored under the name ``topology_<group_id>``.
            topology_object (md.Topology): The topology to put into the node.
                Objects that are not an ``md.Topology`` are converted with
                ``md.Topology.from_openmm()``.

        Raises:
            ValueError: If the file is not open in mode ``"w"`` or ``"a"``.
            Exception: If the node already exists and the file is not open
                in mode ``"a"``.
            AttributeError: If the topology can't be iterated over its chains,
                residues, atoms and bonds.

        """
        _check_mode(self.mode, ("w", "a"))

        node_name = f"topology_{group_id}"
        try:
            self._handle.get_node("/", name=node_name)
        except self.tables.NoSuchNodeError:
            pass
        else:
            if self.mode != "a":
                raise Exception(
                    "File already exists and has trajectory information. "
                    "Set `overwrite` to True to overwrite."
                )
            self._handle.remove_node("/", name=node_name)

        # we want to be able to handle the openmm Topology object
        # here too, so if it's not an mdtraj topology we'll just guess
        # that it's probably an openmm topology and convert
        if not isinstance(topology_object, md.Topology):
            topology_object = md.Topology.from_openmm(topology_object)

        try:
            topology_dict = {"chains": [], "bonds": []}

            for chain in topology_object.chains:
                chain_dict = {"residues": [], "index": int(chain.index)}
                for residue in chain.residues:
                    residue_dict = {
                        "index": int(residue.index),
                        "name": str(residue.name),
                        "atoms": [],
                        "resSeq": int(residue.resSeq),
                        "segmentID": str(residue.segment_id),
                    }

                    for atom in residue.atoms:
                        try:
                            element_symbol_string = str(atom.element.symbol)
                        except AttributeError:
                            # atoms without an element get an empty symbol
                            element_symbol_string = ""

                        residue_dict["atoms"].append(
                            {
                                "index": int(atom.index),
                                "name": str(atom.name),
                                "element": element_symbol_string,
                            }
                        )
                    chain_dict["residues"].append(residue_dict)
                topology_dict["chains"].append(chain_dict)

            for atom1, atom2 in topology_object.bonds:
                topology_dict["bonds"].append([int(atom1.index), int(atom2.index)])

        except AttributeError as e:
            raise AttributeError(
                "topology_object fails to implement the "
                "chains() -> residue() -> atoms() and bond() protocol. "
                "Specifically, we encountered the following %s" % e
            ) from e

        # actually set the tables
        # NOTE(review): this removes a node that is called plain "topology",
        # not "topology_<group_id>". It looks like a leftover of the single
        # trajectory writer; kept unchanged, confirm that it is intended.
        try:
            self._remove_node(where="/", name="topology")
        except self.tables.NoSuchNodeError:
            pass

        # json.dumps() escapes all non-ASCII characters by default.
        data = json.dumps(topology_dict).encode("ascii")

        # Pick the method by its presence. Comparing version strings sorts
        # lexicographically ("10.0.0" < "3.0.0").
        if hasattr(self._handle, "create_array"):
            self._handle.create_array(where="/", name=node_name, obj=[data])
        else:
            self._handle.createArray(where="/", name=node_name, object=[data])

    def read(self, traj_num: int):
        """Read the arrays of one trajectory.

        Args:
            traj_num (int): The suffix of the nodes that are read.

        Returns:
            dict: A dict with the keys 'coordinates' (nanometers), 'time'
                (picoseconds), 'cell_lengths' (nanometers), and 'cell_angles'
                (degrees).

        Raises:
            tables.NoSuchNodeError: If one of the four nodes does not exist.

        """
        # Third Party Imports
        from mdtraj.utils import in_units_of

        def get_field(name, selection, out_units, can_be_none=True):
            # Read one node and convert it from the units that are stored in
            # its attributes to ``out_units``.
            try:
                node = self._get_node(where="/", name=name + f"_{traj_num}")
                data = node[selection]
                in_units = node.attrs.units
                if not isinstance(in_units, string_types):
                    in_units = in_units.decode()
                data = in_units_of(data, in_units, out_units)
                return data
            except self.tables.NoSuchNodeError:
                if can_be_none:
                    return None
                raise

        # NOTE(review): all four fields are required here, although
        # ``write_into_group`` accepts frames without time and unit cell.
        # Confirm that this asymmetry is intended.
        all_frames = slice(None)
        out = {
            "coordinates": get_field(
                "coordinates",
                (all_frames, all_frames, all_frames),
                out_units="nanometers",
                can_be_none=False,
            ),
            "time": get_field(
                "time", all_frames, out_units="picoseconds", can_be_none=False
            ),
            "cell_lengths": get_field(
                "cell_lengths",
                (all_frames, all_frames),
                out_units="nanometers",
                can_be_none=False,
            ),
            "cell_angles": get_field(
                "cell_angles",
                (all_frames, all_frames),
                out_units="degrees",
                can_be_none=False,
            ),
        }
        return out

    def read_topology(
        self,
        group_id: str,
    ) -> Optional[md.Topology]:
        """Read a topology node and rebuild the topology from its JSON content.

        Args:
            group_id (str): The complete name of the node, e.g.
                ``topology_<traj_num>``. In contrast to ``write_topology()``,
                no prefix is added to this name.

        Returns:
            Optional[md.Topology]: The topology, or None if the file has no
                node with that name.

        """
        # Third Party Imports
        import mdtraj.core.element as elem

        try:
            raw = self._get_node("/", name=group_id)[0]
            if not isinstance(raw, string_types):
                raw = raw.decode()
            topology_dict = json.loads(raw)
        except self.tables.NoSuchNodeError:
            return None

        topology = md.Topology()
        by_index = operator.itemgetter("index")

        for chain_dict in sorted(topology_dict["chains"], key=by_index):
            chain = topology.add_chain()
            for residue_dict in sorted(chain_dict["residues"], key=by_index):
                try:
                    resSeq = residue_dict["resSeq"]
                except KeyError:
                    resSeq = None
                    warnings.warn(
                        "No resSeq information found in HDF file, defaulting to zero-based indices"
                    )
                segment_id = residue_dict.get("segmentID", "")
                residue = topology.add_residue(
                    residue_dict["name"], chain, resSeq=resSeq, segment_id=segment_id
                )
                for atom_dict in sorted(residue_dict["atoms"], key=by_index):
                    try:
                        element = elem.get_by_symbol(atom_dict["element"])
                    except KeyError:
                        element = elem.virtual
                    topology.add_atom(atom_dict["name"], element, residue)

        atoms = list(topology.atoms)
        for index1, index2 in topology_dict["bonds"]:
            topology.add_bond(atoms[index1], atoms[index2])

        return topology

    def read_trajs(self) -> dict[int, md.Trajectory]:
        """Read all trajectories of the file.

        The trajectory numbers are collected from the digits in the names of
        all nodes below the root, except for the node ``CVs``.

        Returns:
            dict[int, md.Trajectory]: The trajectories, keyed by their number
                in ascending order.

        """
        nodes = [n.name for n in self._handle.list_nodes("/") if n.name != "CVs"]
        traj_nums = sorted(
            {int(number) for node in nodes for number in re.findall(r"\d+", node)}
        )
        return {traj_num: self.read_traj(traj_num) for traj_num in traj_nums}

    def read_traj(self, traj_num: int) -> md.Trajectory:
        """Read one trajectory together with its topology.

        Args:
            traj_num (int): The suffix of the nodes of the trajectory.

        Returns:
            md.Trajectory: The trajectory.

        """
        # Third Party Imports
        from mdtraj.core.trajectory import Trajectory

        topology = self.read_topology(f"topology_{traj_num}")
        data = self.read(traj_num)
        traj = Trajectory(
            xyz=data["coordinates"],
            topology=topology,
            time=data["time"],
            unitcell_lengths=data["cell_lengths"],
            unitcell_angles=data["cell_angles"],
        )
        return traj

755 

756 

class TrajEnsembleTsel:
    """Accessor that selects trajectories of an ensemble by their ``traj_num``.

    The wrapped object has to provide a ``trajs_by_traj_num`` mapping, integer
    array indexing via ``[]``, and a ``_pyemma_indexing_tsel()`` method.

    """

    def __init__(self, other):
        """Store the ensemble that is indexed.

        Args:
            other: The ensemble this accessor belongs to.

        """
        self.other = other

    def __getitem__(self, item: "CanBeIndex") -> "Union[TrajEnsemble, SingleTraj]":
        """Select by ``traj_num``.

        Args:
            item (CanBeIndex): One of the following.
                - An integer (python or any numpy integer type) returns the
                  entry of ``trajs_by_traj_num`` with that ``traj_num``.
                - A one-dimensional list or array of ``traj_nums`` returns
                  the wrapped object indexed at the matching positions.
                - A list or array with any other number of dimensions is
                  handed to ``_pyemma_indexing_tsel()`` of the wrapped object.

        Raises:
            ValueError: If none of the requested ``traj_nums`` is present, or
                if ``item`` has an unsupported type.
            NotImplementedError: If ``item`` is a slice.

        """
        items = np.array(list(self.other.trajs_by_traj_num.keys()))

        # np.integer covers every numpy integer width, not only np.int64.
        if isinstance(item, (int, np.integer)):
            if item not in items:
                raise ValueError(
                    f"No trajectories with traj_num {item} in TrajEnsemble {self.other} "
                    f"with trajectories: {items}"
                )
            return self.other.trajs_by_traj_num[item]

        if isinstance(item, slice):
            raise NotImplementedError("Currently can't index trajs with slice.")

        if not isinstance(item, (list, np.ndarray)):
            raise ValueError(
                f"The `tsel[]` method of `TrajEnsmeble` takes {CanBeIndex} types, "
                f"but {type(item)} was provided."
            )

        requested = np.asarray(item)
        if requested.ndim != 1:
            return self.other._pyemma_indexing_tsel(item)

        # positions of the requested traj_nums inside the ensemble
        idx = np.where(np.isin(items, requested))[0]
        if len(idx) == 0:
            raise ValueError(
                f"No trajs with traj_nums {item} in TrajEnsmble {self.other} "
                f"with trajectories: {items}"
            )
        return self.other[idx]

788 

789 

790class TrajEnsemble: 

791 """A fancy list of single trajectories. Topologies can be different across trajs. 

792 

793 Check out http://statisticalbiophysicsblog.org/?p=92 for why trajectory ensembles are awesome. 

794 

795 This class is a fancy list of :obj:`encodermap.trajinfo.info_single.SingleTraj``. 

796 Trajectories can have different topologies and will be grouped by 

797 the ``common_str`` argument. Each trajectory has its own unique ``traj_num``, 

798 which identifies it in the ensemble - even when the ensemble is sliced or 

799 subsampled. 

800 

801 Examples: 

802 >>> import encodermap as em 

803 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG") 

804 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF") 

805 

806 Addition of two :obj:`encodermap.trajinfo.info_single.SingleTraj` also creates an ensemble. 

807 

808 >>> trajs = traj1 + traj2 

809 >>> trajs # doctest: +ELLIPSIS 

810 <encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajectories. Common str is ['1YUG', '1YUF']. Not containing any CVs...> 

811 

812 Indexing a :obj:`TrajEnsemble` returns a :obj:`encodermap.trajinfo.info_single.SingleTraj` 

813 based on its 0-based index. Think of the :obj:`TrajEnsmeble` as a list of 

814 :obj:`encodermap.trajinfo.info_single.SingleTraj`. But trajectories can also 

815 have ``traj_nums``, which do not have to adhere to ``[0, 1, 2, ...]``. This 

816 is similar to how a :obj:`pandas.DataFrame` offers indexing via ``.loc[]`` 

817 and ``.iloc[]`` (https://pandas.pydata.org/docs/user_guide/indexing.html#different-choices-for-indexing). 

818 For indexing trajs based on their ``traj_num``, you can use the ``.tsel[]`` 

819 accessor of the :obj:`TrajEnsmeble` 

820 

821 Examples: 

822 >>> import encodermap as em 

823 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG") 

824 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF") 

825 

826 Addition of two `SingleTraj` also creates an ensemble. 

827 

828 >>> trajs = traj1 + traj2 

829 >>> trajs.traj_nums 

830 [0, 1] 

831 

832 Change the ``traj_num`` of ``traj2`` 

833 

834 >>> trajs[1].traj_num = 4 

835 >>> trajs.traj_nums 

836 [0, 4] 

837 >>> trajs[1] # doctest: +ELLIPSIS 

838 <encodermap.SingleTraj object. Currently not in memory. Basename is '1YUF'. Not containing any CVs. Common string is '1YUF'. Object at ...> 

839 >>> trajs.tsel[4] # doctest: +ELLIPSIS 

840 <encodermap.SingleTraj object. Currently not in memory. Basename is '1YUF'. Not containing any CVs. Common string is '1YUF'. Object at ...> 

841 

842 :obj:`TrajEnsemble` supports fancy indexing. You can slice to your liking 

843 (``trajs[::5]`` returns a :obj:`TrajEnsemble` object that only consideres 

844 every fifth frame). Besides indexing by slices and integers, you can pass a 

845 2-dimensional :obj:`numpy.ndarray`. ``np.array([[0, 5], [1, 10], [5, 20]])`` 

846 will return a :obj:`TrajEnsemble` object with frame 5 of trajectory 0, frame 

847 10 of trajectory 1 and frame 20 of trajectory 5. 

848 

849 Examples: 

850 >>> import encodermap as em 

851 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG") 

852 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF") 

853 >>> trajs = traj1 + traj2 

854 >>> sel = trajs[[[0, 0], [0, 1], [0, 2], [1, 10]]] 

855 >>> sel # doctest: +ELLIPSIS 

856 <encodermap.TrajEnsemble object. Current backend is no_load. Containing 4 frames and 2 trajectories. Common str is...> 

857 

858 

859 The :obj:`TrajEnsemble` class also is an iterator to iterate over trajectores. 

860 Besides plain iteration, the :obj:`TrajEnsmeble` also offers alternate iterators. 

861 The ``itertrajs()`` iterator returns a two-tuple of ``traj_num`` and ``traj``. 

862 The ``iterframes()`` iterator returns a three-tuple of ``traj_num``, 

863 ``frame_num``, and ``traj``. 

864 

865 Examples: 

866 >>> import encodermap as em 

867 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG") 

868 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF") 

869 >>> trajs = traj1 + traj2 

870 >>> trajs[1].traj_num = 4 

871 >>> for traj_num, traj in trajs.itertrajs(): 

872 ... print(traj_num, traj.n_frames) 

873 0 15 

874 4 16 

875 >>> for traj_num, frame_num ,traj in trajs.subsample(10).iterframes(): 

876 ... print(traj_num, frame_num, traj.n_frames) 

877 0 0 1 

878 0 10 1 

879 4 0 1 

880 4 10 1 

881 

882 The :obj:`TrajEnsemble` has multiple alternative constructors. The 

883 :meth:`with_overwrite_trajnums` constructor fixes inhomogeneous sequences of 

884 :obj:`encodermap.trajinfo.info_single.SingleTraj` and :obj:`TrajEnsemble`. 

885 

886 Examples: 

887 >>> import encodermap as em 

888 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG", traj_num=0) 

889 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF", traj_num=0) 

890 >>> trajs = em.TrajEnsemble([traj1, traj2]) # doctest: +IGNORE_EXCEPTION_DETAIL, +ELLIPSIS, +NORMALIZE_WHITESPACE 

891 Traceback (most recent call last): 

892 ... 

893 Exception: The `traj_num` attributes of the provided 2 `SingleTraj`s is not unique, the `traj_num` 0 occurs 2 times. This can happen, if you use `SingleTraj`s, that are already part of a `TrajEnsemble`. To create copies of the `SingleTraj`s and overwrite their `traj_num`s, use the `with_overwrite_trajnums()` constructor. 

894 >>> trajs = em.TrajEnsemble.with_overwrite_trajnums(traj1, traj2) 

895 >>> trajs # doctest: +ELLIPSIS 

896 <encodermap.TrajEnsemble...> 

897 

898 The :meth:`from_dataset` constructor can be used to load an ensemble from 

899 an ``.h5`` file 

900 

901 Examples: 

902 >>> import encodermap as em 

903 >>> from tempfile import TemporaryDirectory 

904 >>> traj1 = em.SingleTraj.from_pdb_id("1YUG") 

905 >>> traj2 = em.SingleTraj.from_pdb_id("1YUF") 

906 >>> trajs = em.TrajEnsemble([traj1, traj2]) 

907 >>> with TemporaryDirectory() as td: 

908 ... trajs.save(td + "/trajs.h5") 

909 ... new = em.TrajEnsemble.from_dataset(td + "/trajs.h5") 

910 ... print(new) # doctest: +ELLIPSIS 

911 encodermap.TrajEnsemble object. Current backend is no_load. Containing 2 trajectories. Common str is...Not containing any CVs. 

912 

913 Attributes: 

914 CVs (dict[str, np.ndarray]): The collective variables of the ``SingleTraj`` 

915 classes. Only CVs with matching names in all ``SingleTraj`` classes 

916 are returned. The data is stacked along a hypothetical time axis 

917 along the trajs. 

918 _CVs (xarray.Dataset): The same data as in CVs but with labels. 

919 Additionally, the xarray is not stacked along the time axis. 

920 It contains an extra dimension for trajectories. 

921 n_trajs (int): Number of individual trajectories in this class. 

922 n_frames (int): Number of frames, sum over all trajectories. 

923 locations (list[str]): A list with the locations of the trajectories. 

924 top (list[mdtraj.Topology]): A list with the reference pdb for each trajectory. 

925 basenames (list[str]): A list with the names of the trajectories. 

926 The leading path and the file extension is omitted. 

927 name_arr (np.ndarray): An array with ``len(name_arr) == n_frames``. 

928 This array keeps track of each frame in this object by identifying each 

929 frame with a filename. This can be useful, when frames are mixed inside 

930 a :obj:`TrajEnsemble` class. 

931 

932 """ 

933 

def __init__(
    self,
    trajs: Union[
        Sequence[str],
        Sequence[Path],
        Sequence[md.Trajectory],
        Sequence[SingleTraj],
    ],
    tops: Union[None, Sequence[str], Sequence[Path]] = None,
    backend: Literal["mdtraj", "no_load"] = "no_load",
    common_str: Optional[Sequence[str]] = None,
    basename_fn: Optional[Callable[[str], str]] = None,
    traj_nums: Optional[Sequence[int]] = None,
    custom_top: Optional[CustomAAsDict] = None,
) -> None:
    """Instantiate the :obj:`TrajEnsemble` class with two lists of files.

    Args:
        trajs (Union[Sequence[str], Sequence[md.Trajectory],
            Sequence[SingleTraj], Sequence[Path]]): List of strings with
            paths to trajectories. Can also be a list of md.Trajectory or
            em.SingleTraj. A single path to an ``.h5`` file is forwarded to
            :meth:`from_dataset` and the loaded state is adopted.
        tops (Optional[list[str]]): List of strings with paths to reference pdbs.
        backend (str, optional): Choose the backend to load trajectories:
            - 'mdtraj' uses mdtraj, which loads all trajectories into RAM.
            - 'no_load' creates an empty trajectory object.
            Defaults to 'no_load', which makes the instantiation of large
            ensembles fast and RAM efficient.
        common_str (list[str], optional): If you want to include trajectories with
            different topology. The common string is used to pair traj-files
            (``.xtc, .dcd, .lammpstrj, ...``) with their topology
            (``.pdb, .gro, ...``). The common-string should be a substring
            of matching traj and topology files.
        basename_fn (Union[None, Callable[[str], str], optional): A function
            to apply to the trajectory file path string to return the basename
            of the trajectory. If None is provided, the filename without
            extension will be used. When all files are named the same and
            the folder they're in defines the name of the trajectory, you
            can supply ``lambda x: x.split('/')[-2]`` as this argument.
            Defaults to None.
        traj_nums (Optional[Sequence[int]]): One ``traj_num`` per entry of
            ``trajs``. When ``trajs`` are ``SingleTraj`` instances, they are
            deep-copied before the numbers are assigned. Defaults to None.
        custom_top: Optional[CustomAAsDict]: An instance of the
            :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology` or a
            dictionary that can be made into such.

    Raises:
        Exception: If ``tops`` is an empty list, if ``traj_nums`` and
            ``trajs`` differ in length, or if the provided ``SingleTraj``
            instances carry a mix of ``None``/int or duplicated ``traj_num``.

    """
    # Local Folder Imports
    from .info_single import SingleTraj

    # check if h5file might be a complete dataset
    if isinstance(trajs, (str, Path)):
        if Path(trajs).suffix == ".h5":
            # ``__init__`` must return None (returning an object raises a
            # TypeError), so the state of the loaded ensemble is copied
            # onto this instance instead of returning it.
            loaded = TrajEnsemble.from_dataset(trajs, basename_fn=basename_fn)
            self.__dict__.update(loaded.__dict__)
            return

    if tops == []:
        raise Exception(
            f"Your list of topology files is empty: {tops=}. Pass None, if "
            f"your trajectories are all .pdb/.gro files."
        )

    self.backend = backend

    # custom topology to load dihedral angles
    self._custom_top = custom_top

    # set the trajnums
    if traj_nums is not None:
        # Standard Library Imports
        from copy import deepcopy

        if not len(traj_nums) == len(trajs):
            raise Exception(
                f"Uneven length of `traj_nums` ({len(traj_nums)}) "
                f"and `trajs` ({len(trajs)}) provided."
            )
        if all([isinstance(t, SingleTraj) for t in trajs]):
            # work on copies, so the caller's SingleTrajs keep their numbers
            trajs_ = []
            for n, t in zip(traj_nums, trajs):
                t = deepcopy(t)
                t.traj_num = n
                trajs_.append(t)
            trajs = trajs_

    # make sure, that traj_nums are not duplicated
    elif all([isinstance(t, SingleTraj) for t in trajs]) and isinstance(
        trajs, (list, tuple)
    ):
        if any([t.traj_num is None for t in trajs]) and any(
            [isinstance(t.traj_num, int) for t in trajs]
        ):
            raise Exception(
                f"The `SingleTraj`s you provided have bad `traj_num`s "
                f"one has `None`, the others have int: {[t.traj_num for t in trajs]}"
            )
        if not all([(i.traj_num is None) for i in trajs]):
            uniques, counts = np.unique(
                np.asarray([t.traj_num for t in trajs]), return_counts=True
            )
            if np.any(counts > 1):
                ex_num = uniques[np.argmax(counts)]
                raise Exception(
                    f"The `traj_num` attributes of the provided {len(trajs)} `SingleTraj`s are "
                    f"not unique, the `traj_num` {ex_num} occurs {np.max(counts)} times. "
                    f"This can happen, if you use `SingleTraj`s, that are already part of "
                    f"a `TrajEnsemble`. To create copies of the `SingleTraj`s and over"
                    f"write their `traj_num`s, use the `with_overwrite_trajnums()` constructor."
                )
            # only sortable, when every traj_num is an int
            trajs = list(sorted(trajs, key=lambda x: x.traj_num))

    # basename function
    if basename_fn is None:
        basename_fn = lambda x: os.path.basename(x).split(".")[0]
    self.basename_fn = basename_fn

    # common string
    if common_str is None:
        common_str = []
    if isinstance(common_str, str):
        self.common_str = [common_str]
    else:
        self.common_str = common_str

    # loading with setters
    if tops is None:
        tops = []
    self._top_files = tops
    if all([isinstance(traj, str) for traj in trajs]):
        # trajectory files that are themselves topologies serve as their own tops
        if self._top_files == [] and all(
            ["." + top.split(".")[-1] in _TOPOLOGY_EXTS for top in trajs]
        ):
            self._top_files = trajs
    if all([isinstance(traj, Path) for traj in trajs]) and not tops:
        self._top_files = [str(t) for t in trajs]
    if isinstance(tops, str):
        self._top_files = [tops]
    # the ``traj_files`` setter builds ``self.trajs``
    self.traj_files = trajs

    # set the traj nums afterwards
    if traj_nums is not None:
        for i, traj in zip(traj_nums, self.trajs):
            traj.traj_num = i

1075 

@classmethod
def with_overwrite_trajnums(
    cls,
    *trajs: Union[TrajEnsemble, "SingleTraj"],
) -> TrajEnsemble:
    """Build a :obj:`TrajEnsemble` from renumbered copies of the inputs.

    Every :obj:`encodermap.trajinfo.info_single.SingleTraj` (given directly
    or contained in a provided :obj:`TrajEnsemble`) is deep-copied and its
    ``traj_num`` is set to its running position ``0, 1, 2, ...``.

    Args:
        trajs (Sequence[SingleTraj]): ``SingleTraj`` and/or ``TrajEnsemble``
            instances, in the order they should be numbered.

    Returns:
        TrajEnsemble: A :obj:`TrajEnsemble` instance.

    Raises:
        TypeError: If an argument is neither ``SingleTraj`` nor ``TrajEnsemble``.

    """
    # Standard Library Imports
    from copy import deepcopy

    # Local Folder Imports
    from .info_single import SingleTraj

    renumbered = []
    for item in trajs:
        if isinstance(item, SingleTraj):
            members = [item]
        elif isinstance(item, TrajEnsemble):
            members = item
        else:
            raise TypeError(
                f"Classmethod `with_overwrite_trajnums` can only accept `SingleTraj` "
                f"and `TrajEnsemble`, but {item} is not an instance of either."
            )
        for member in members:
            clone = deepcopy(member)
            # the next free number equals the count of trajs collected so far
            clone.traj_num = len(renumbered)
            renumbered.append(clone)
    return cls(renumbered)

1118 

@classmethod
def from_textfile(
    cls,
    fname: Union[str, Path],
    basename_fn: Optional[Callable[[str], str]] = None,
) -> TrajEnsemble:
    """Creates a :obj:`TrajEnsemble` object from a textfile.

    The textfile needs to be whitespace-separated with two to four columns:
        - Column 1:
            The trajectory file.
        - Column 2:
            The corresponding topology file (If you are using ``.h5`` trajs,
            column 1 and 2 will be identical, but column 2 needs to be there
            nonetheless).
        - Column 3:
            The integer ``traj_num`` of the trajectory. If this column is
            the last one and is not an integer, it is used as the common
            string instead (three-column layout without ``traj_num``).
        - Column 4:
            The common string of the trajectory. This column can be left
            out, which will result in an empty common string.

    Blank lines are skipped.

    Args:
        fname (Union[str, Path]): File to be read.
        basename_fn (Union[None, Callable[[str], str]], optional): A function
            to apply to the ``traj_file`` string to return the basename of
            the trajectory. If None is provided, the filename without
            extension will be used. When all files are named the same and
            the folder they're in defines the name of the trajectory, you
            can supply ``lambda x: x.split('/')[-2]`` as this argument.
            Defaults to None.

    Returns:
        TrajEnsemble: A :obj:`TrajEnsemble` instance.

    """
    # Local Folder Imports
    from ..trajinfo import info_single

    trajs = []
    common_str = []

    with open(fname, "r") as f:
        for row in f:
            fields = row.split()
            if not fields:
                continue
            traj_file, top_file = fields[0], fields[1]
            extra = fields[2:]

            traj_num = None
            cs = ""
            if extra:
                try:
                    traj_num = int(extra[0])
                except ValueError:
                    # a lone non-integer third column is the common string
                    if len(extra) == 1:
                        cs = extra[0]
                if len(extra) > 1:
                    cs = extra[1]

            common_str.append(cs)
            # keywords, as in the other SingleTraj call sites of this class,
            # so nothing depends on the positional order of SingleTraj
            trajs.append(
                info_single.SingleTraj(
                    traj=traj_file,
                    top=top_file,
                    common_str=cs,
                    traj_num=traj_num,
                    basename_fn=basename_fn,
                )
            )

    return cls(
        trajs, common_str=np.unique(common_str).tolist(), basename_fn=basename_fn
    )

1183 

@classmethod
def from_dataset(
    cls,
    fname: Union[str, Path],
    basename_fn: Optional[Callable[[str], str]] = None,
) -> TrajEnsemble:
    """Creates a :obj:`TrajEnsemble` from a single ``.h5`` file.

    All digit groups found in the top-level keys of the file (except the
    key ``"CVs"``) are collected as ``traj_num`` values. For every unique
    number, one ``SingleTraj`` is created that uses ``fname`` as both its
    trajectory and its topology file.

    Args:
        fname (Union[str, Path]): The ``.h5`` file to read.
        basename_fn (Union[None, Callable[[str], str]], optional): Passed on
            to every ``SingleTraj`` and to the ensemble. Defaults to None.

    Returns:
        TrajEnsemble: A :obj:`TrajEnsemble` instance.

    """
    # Local Folder Imports
    from .info_single import SingleTraj

    found_numbers = set()
    with h5py.File(fname) as h5file:
        for key in h5file.keys():
            if key != "CVs":
                found_numbers.update(re.findall(r"\d+", key))
    ordered_nums = sorted(int(number) for number in found_numbers)

    trajs = [
        SingleTraj(
            traj=fname,
            top=fname,
            traj_num=number,
            basename_fn=basename_fn,
        )
        for number in ordered_nums
    ]
    unique_common_str = list({t.common_str for t in trajs})
    return cls(trajs=trajs, common_str=unique_common_str, basename_fn=basename_fn)

1214 

def load_custom_topology(
    self,
    custom_top: Optional[Union[CustomTopology, CustomAAsDict]] = None,
) -> None:
    """Forwards a custom topology to every trajectory of the ensemble.

    See Also:
        `CustomTopology`

    Args:
        custom_top: Optional[Union[CustomTopology, CustomAAsDict]]: An instance of the
            `CustomTopology` class or a dictionary that can be made into such.
            It is handed unchanged to ``load_custom_topology`` of each traj.

    """
    members = self.trajs
    for member in members:
        member.load_custom_topology(custom_top)

1231 

@property
def tsel(self):
    """TrajEnsembleTsel: A new ``TrajEnsembleTsel`` wrapping this ensemble."""
    selector = TrajEnsembleTsel(self)
    return selector

1235 

def del_featurizer(self) -> None:
    """Removes the cached ``_featurizer`` attribute, if there is one."""
    if not hasattr(self, "_featurizer"):
        return
    del self._featurizer

1240 

@property
def featurizer(self):
    """Featurizer: The featurizer of this ensemble.

    It is created on first access and cached on ``self._featurizer``.

    """
    # Local Folder Imports
    from ..loading.featurizer import Featurizer

    if hasattr(self, "_featurizer"):
        return self._featurizer
    self._featurizer = Featurizer(self)
    return self._featurizer

1249 

@property
def traj_files(self) -> list[str]:
    """list: The trajectory files stored in ``self._traj_files``."""
    stored_files = self._traj_files
    return stored_files

1254 

@property
def top_files(self) -> list[str]:
    """list: The topology files without duplicates, in first-seen order.

    If you want a list of top files with the same
    length as self.trajs use self._top_files and
    self._traj_files.

    """
    # dict keys are unique and keep their insertion order
    deduplicated = dict.fromkeys(self._top_files)
    return list(deduplicated)

1265 

@traj_files.setter
def traj_files(self, trajs):
    """Builds ``self.trajs`` from the provided trajectories.

    Depending on the element type of ``trajs``, one of four routes is taken:

        - ``md.Trajectory``: The backend is switched to ``"mdtraj"`` and
          every trajectory is wrapped in a ``SingleTraj``.
        - ``SingleTraj``: The instances are used as they are. They are
          loaded, if any of them (or the ensemble) is not ``"no_load"``.
          Missing ``traj_num`` values are filled with the running index.
        - ``str`` with topology files: Trajectories and topologies are
          paired by ``get_full_common_str_and_ref``.
        - ``str`` without topology files: Every file becomes a
          ``SingleTraj`` without an explicit topology.

    ``Path`` objects are converted to ``str`` first.

    Args:
        trajs: The sequence of trajectories (see above).

    Raises:
        TypeError: If the elements of ``trajs`` match none of the routes.

    """
    # Local Folder Imports
    from ..trajinfo import info_single

    traj_nums = np.arange(len(trajs))
    # fill this lists
    self.trajs = []

    if all([isinstance(traj, Path) for traj in trajs]):
        trajs = [str(traj) for traj in trajs]

    if all([isinstance(i, md.Trajectory) for i in trajs]):
        self.backend = "mdtraj"
        self.trajs = [
            info_single.SingleTraj(
                traj,
                traj_num=i,
                basename_fn=self.basename_fn,
                custom_top=self._custom_top,
            )
            for i, traj in zip(traj_nums, trajs)
        ]
    elif all([i.__class__.__name__ == "SingleTraj" for i in trajs]):
        self.trajs = trajs
        self._top_files = [traj.top_file for traj in self.trajs]
        self._traj_files = [traj.traj_file for traj in self.trajs]
        # check backends and common str
        if (
            not all([traj.backend == "no_load" for traj in trajs])
            or self.backend == "mdtraj"
        ):
            # an explicit loop: a bare generator expression would never
            # run and leave the trajectories unloaded
            for traj in trajs:
                traj.load_traj()
        for i, traj in zip(traj_nums, trajs):
            if traj.traj_num is None:
                traj.traj_num = i
                # keep the traj_num coordinate of loaded CVs in sync
                if traj._CVs:
                    traj._CVs = traj._CVs.assign_coords(traj_num=[i])
    elif all([isinstance(i, str) for i in trajs]) and self.top_files:
        # find common_str matches in top_files and traj_files
        (
            self._traj_files,
            self._top_files,
            self._common_str,
        ) = get_full_common_str_and_ref(trajs, self._top_files, self.common_str)
        for i, t, top, cs in zip(
            traj_nums, self._traj_files, self._top_files, self._common_str
        ):
            t = info_single.SingleTraj(
                traj=t,
                top=top,
                backend=self.backend,
                common_str=cs,
                traj_num=i,
                basename_fn=self.basename_fn,
                custom_top=self._custom_top,
            )
            self.trajs.append(t)
    elif all([isinstance(i, str) for i in trajs]) and not self.top_files:
        # remember the files, so the ``traj_files`` getter works here too
        self._traj_files = list(trajs)
        for i, traj_file in zip(traj_nums, trajs):
            self.trajs.append(
                info_single.SingleTraj(
                    traj=traj_file,
                    basename_fn=self.basename_fn,
                    traj_num=i,
                    custom_top=self._custom_top,
                )
            )
    else:
        raise TypeError(
            "The objects in the list are not of the correct type or inconsistent. "
            f"You provided {[c.__class__.__name__ for c in trajs]}. "
            "Please provide a list of `str`, list of `mdtraj.Trajectory` or list of `SingleTraj`."
        )

1340 

@property
def top(self) -> list[md.Topology]:
    """list: The distinct topologies of the trajs, in first-seen order.

    If all trajectories share the same topology a list
    with len 1 will be returned.

    Raises:
        IOError: Re-raised unchanged, if its message contains "no such file".
        Exception: For every other ``IOError`` while reading a topology.

    """
    distinct = []
    for traj in self.trajs:
        try:
            topology = traj.top
            if topology in distinct:
                continue
            distinct.append(topology)
        except IOError as e:
            if "no such file" not in str(e).lower():
                raise Exception(
                    f"I have no idea how this IOError can occur. {self.trajs=}."
                ) from e
            raise e
    return distinct

1361 

@property
def trajs_by_top(self) -> dict[md.Topology, TrajEnsemble]:
    """dict[md.Topology, TrajEnsemble]: The trajs in `self` grouped by topology.

    Every value is a new ``TrajEnsemble`` whose ``common_str`` holds the
    unique common strings of its members. If all trajectories share the
    same topology, a dict with one key will be returned.

    """
    grouped = {}
    for traj in self.trajs:
        topology = traj.top
        if topology not in grouped:
            grouped[topology] = []
        grouped[topology].append(traj)

    ensembles = {}
    for topology, members in grouped.items():
        ensembles[topology] = TrajEnsemble(members)
    for ensemble in ensembles.values():
        ensemble.common_str = list(set([t.common_str for t in ensemble]))
    return ensembles

1377 

@property
def trajs_by_common_str(self) -> dict[Union[None, str], TrajEnsemble]:
    """dict[str, TrajEnsemble]: The trajs in `self` grouped by common_str.

    Every value is a new ``TrajEnsemble``. If all trajectories share the
    same common_str, a dict with one key will be returned. As the
    common_str can be None, None can also occur as a key in this dict.

    """
    grouped = {}
    for traj in self.trajs:
        key = traj.common_str
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(traj)
    return {key: TrajEnsemble(members) for key, members in grouped.items()}

1392 

def sidechain_info(self) -> dict[int, dict[int, int]]:
    """Indices used for the AngleDihedralCartesianEncoderMap class to
    allow training with multiple different sidechains.

    Returns:
        dict[int, dict[int, int]]: A nested dict. The outer key ``-1`` is
            used for the hypothetical convex hull of all feature spaces
            (the output of the tensorflow model). The other outer keys are
            the ``traj_num`` of the trajs. The inner dicts map the
            one-based residue index to a number of sidechain dihedrals:
            For ``-1`` the number of ``SIDE_DIHEDRALS`` labels of the
            ensemble ending in that residue index, for a traj the number
            of those labels that are not nan in its first frame.

    Raises:
        NotImplementedError: If any traj has a ``clustal_w`` attribute.
        Exception: If one of the required CVs is missing, or if the
            ``CENTRAL_DISTANCES`` labels contain residue names.
        AssertionError: If two trajs share the same ``traj_num``.

    """
    # alignments are not supported here: bail out, if any traj carries one
    if any([hasattr(t, "clustal_w") for t in self.trajs]):
        raise NotImplementedError(
            f"This is currently not supported for TrajEsnembles with "
            f"clustal_w alignments."
        )
    else:
        # the largest residue count bounds the residue indices looked at below
        max_residues = max([t.n_residues for t in self])

    # make sure CVs are loaded and contain the appropriate values
    should_be = {
        "central_cartesians",
        "central_dihedrals",
        "central_distances",
        "central_angles",
        "side_dihedrals",
        "side_cartesians",
        "side_distances",
        "side_angles",
    }
    diff = should_be - set(self._CVs.data_vars.keys())
    if len(diff) > 0:
        raise Exception(
            f"The TrajEnsemble misses these CVs to calculate the sidechain_info: "
            f"{list(diff)}. Please load them with `trajs.load_CVs({list(diff)})`."
        )

    # make sure we are using an ensemble with generic indices
    # (labels must not contain any residue name of any topology)
    forbidden_names = set(
        [residue.name for traj in self for residue in traj.top.residues]
    )
    if (
        len(
            (
                offending := [
                    label
                    for label in self._CVs.central_distances.coords[
                        "CENTRAL_DISTANCES"
                    ].values
                    if any(f in label for f in forbidden_names)
                ]
            )
        )
        > 0
    ):
        raise Exception(
            f"The CVs in this TrajEnsemble were not loaded with the `ensemble=True` "
            f"keyword in `trajs.load_CVs()`. Finding the sidechain_info in "
            f"such a set of CVs is not possible. The offending labels {offending} "
            f"contain residue names, which should not occur if CVs were loaded "
            f"with `ensemble=True`."
        )

    # the key -1 is the feature hull, telling tensorflow
    # how to create the branched chain of backbone and sidechains
    # i.e. how many sidechains there are per residue max
    # the other keys correspond to which sidechain atoms are non nan and
    # can be used when calculating the distance matrices per different atom
    out = {-1: {}}

    # the feature hull
    max_sidechains = self._CVs.side_dihedrals.coords["SIDE_DIHEDRALS"].values
    for residx in range(1, max_residues + 1):
        # a label belongs to a residue, if it ends in " <residue index>"
        labels = [l for l in max_sidechains if l.endswith(" " + str(residx))]
        out[-1][residx] = len(labels)

    # for every traj
    for traj in self.trajs:
        assert (
            traj.traj_num not in out
        ), f"This traj has the same `traj_num` as another traj."
        out[traj.traj_num] = {}
        for residx in range(1, max_residues + 1):
            labels = [
                l
                for l in traj._CVs.side_dihedrals.coords["SIDE_DIHEDRALS"].values
                if l.endswith(" " + str(residx))
            ]
            # only the frame with the smallest frame_num is inspected
            first_frame = traj._CVs.coords["frame_num"].min()
            # count the labels of this residue that are not nan in that frame
            length = (
                traj._CVs.side_dihedrals.sel(
                    frame_num=first_frame, SIDE_DIHEDRALS=labels
                )
                .dropna(dim="SIDE_DIHEDRALS")
                .sizes["SIDE_DIHEDRALS"]
            )
            out[traj.traj_num][residx] = length
    return out

1497 

@property
def trajs_by_traj_num(self) -> dict[int, SingleTraj]:
    """dict[int, SingleTraj]: The trajs in `self`, keyed by their ``traj_num``.

    If several trajs share a ``traj_num``, the last one wins.

    """
    return {traj.traj_num: traj for traj in self}

1504 

@property
def id(self) -> np.ndarray:
    """np.ndarray: Alias of ``self.index_arr``."""
    index_array = self.index_arr
    return index_array

1509 

@property
def n_residues(self) -> list[int]:
    """list[int]: The number of residues of every traj, in the order of ``self.trajs``."""
    return [traj.n_residues for traj in self.trajs]

1514 

@property
def basenames(self) -> list[str]:
    """list: The ``basename`` of every traj, in the order of ``self.trajs``."""
    names = []
    for traj in self.trajs:
        names.append(traj.basename)
    return names

1519 

@property
def traj_nums(self) -> list[int]:
    """list: The ``traj_num`` of every traj, in the order of ``self.trajs``."""
    numbers = []
    for traj in self.trajs:
        numbers.append(traj.traj_num)
    return numbers

1524 

@property
def n_trajs(self) -> int:
    """int: How many trajectories ``self.trajs`` holds."""
    members = self.trajs
    return len(members)

1529 

def to_alignment_query(self) -> str:
    """A string, that can be put into sequence alignment software.

    For every common string, the sequences returned by ``to_fasta()`` of
    its single topology are written as FASTA records: A header line
    ``>{common_str}`` followed by the sequence line.

    Returns:
        str: The FASTA records. An empty string, if there are no trajs.

    Raises:
        AssertionError: If a common string has more than one topology.

    """
    # group once; the property builds new ensembles on every access
    grouped = self.trajs_by_common_str
    for cs, trajs in grouped.items():
        assert len(trajs.top) == 1, (
            f"Can't provide a query for a `TrajEnsemble`, where a common string "
            f"has multiple topologies. In this case, the common string '{cs}' "
            f"has these topologies: {trajs.top}. When you are unhappy with how "
            f"EncoderMap automatically applies common strings to trajectory "
            f"filenames, keep in mind that you can always generate a `TrajEnsemble` "
            f"from multiple `TrajEnsembles`. You can do: \n"
            f"trajs1 = em.load(\n"
            f"    traj_files1,\n"
            f"    top_files1,\n"
            f"    common_str=['trajs1'],\n"
            f")\n"
            f"trajs2 = em.load(\n"
            f"    traj_files2,\n"
            f"    top_files2,\n"
            f"    common_str=['trajs2'],\n"
            f")\n"
            f"trajs = trajs1 + trajs2\n"
            f"to force a `TrajEnsemble` to adhere to your common strings."
        )
    records = []
    for cs, trajs in grouped.items():
        for sequence in trajs.top[0].to_fasta():
            # every record has to be collected, otherwise the query stays empty
            records.append(f">{cs}\n{sequence}\n")
    return "".join(records)

1559 

def parse_clustal_w_alignment(self, aln: str) -> None:
    """Parse an alignment in ClustalW format and add the info to the trajectories.

    The unmodified ``aln`` is stored as ``self.clustal_w``. For every
    sequence name, the last whitespace-separated token of all lines
    starting with that name is concatenated and stored as ``clustal_w``
    attribute of the matching trajectories.

    Args:
        aln (str): The alignment in ClustalW format.

    Raises:
        AssertionError: If the number of non-empty lines is neither a
            multiple of ``n_trajs + 1`` nor of ``len(common_str) + 1``, or
            if a collected sequence (gaps ``-`` removed) does not have as
            many characters as the trajectory has residues.

    """
    self.clustal_w = aln
    # remove empty lines
    aln = "\n".join(
        list(filter(lambda x: not re.match(r"^\s*$", x), aln.splitlines()))
    )

    # every block of the alignment holds one line per sequence plus one
    # line with the score characters
    lines = aln.splitlines()
    n_lines = len(lines)
    assert (
        n_lines % (self.n_trajs + 1) == 0
        or n_lines % (len(self.common_str) + 1) == 0
    ), (
        f"The CLUSTAL W aln string, that you provided has the wrong number of "
        f"lines. I've expected to receive a multiple of {self.n_trajs + 1} ("
        f"which is the number of trajs ({self.n_trajs}) plus one for the score-"
        f"characters ' ', '.', ':', '*'), but the number of provided lines was "
        f"{n_lines}."
    )

    if "|" in aln:
        # sequence names of the form
        # ``{common_str}|TrajNum_{n}Topology_{i}Chain_{j}``
        for i, (_, sub_trajs) in enumerate(self.trajs_by_top.items()):
            for cs, trajs in sub_trajs.trajs_by_common_str.items():
                for traj in trajs:
                    for j, chain in enumerate(traj.top.chains):
                        search = (
                            f"{cs}|TrajNum_{traj.traj_num}Topology_{i}Chain_{j}"
                        )
                        data = ""
                        for line in lines:
                            if line.startswith(search):
                                data += line.split()[-1]
                        # NOTE(review): the per-chain sequence is compared
                        # with the residue count of the whole traj and
                        # ``traj.clustal_w`` is overwritten for every
                        # chain - confirm this for multi-chain topologies.
                        assert (
                            test := len(data.replace("-", ""))
                        ) == traj.n_residues, (
                            f"The CLUSTAL W sequence {data} with {test} one-letter "
                            f"residues has not the same number of residues as trajectory "
                            f"{traj}, which has {traj.n_residues}"
                        )
                        traj.clustal_w = data
    else:
        # sequence names are the plain common strings
        for cs, trajs in self.trajs_by_common_str.items():
            search = cs
            data = ""
            for line in lines:
                # NOTE(review): prefix match - a common string that is a
                # prefix of another one also collects the other's lines.
                if line.startswith(search):
                    data += line.split()[-1]
            assert (test := len(data.replace("-", ""))) == trajs[0].n_residues, (
                f"The CLUSTAL W sequence {data} with {test} one-letter "
                f"residues has not the same number of residues as trajectory "
                f"{trajs[0]}, which has {trajs[0].n_residues}"
            )
            # all trajs of a common string share the same alignment
            for traj in trajs:
                traj.clustal_w = data

1621 

def del_CVs(self, CVs: Optional[Sequence[str]] = None) -> None:
    """Deletes CVs in all trajs. Does not affect the files.

    Args:
        CVs (Optional[Sequence[str]]): The names of the CVs to delete. A
            single name can be given without a list. For every name, the
            companion variable ``{name}_feature_indices`` is deleted too.
            Names that a traj does not have are ignored. If None, all CVs
            of all trajs are deleted. Defaults to None.

    """
    if CVs is None:
        for traj in self.trajs:
            traj.del_CVs()
        return

    if not isinstance(CVs, (list, tuple)):
        CVs = [CVs]
    # a fresh list, so that tuples work and the caller's input is untouched
    remove = list(CVs)
    remove.extend(f"{CV}_feature_indices" for CV in CVs)
    for traj in self.trajs:
        traj._CVs = traj._CVs.drop_vars(remove, errors="ignore")

1635 

@property
def _CVs(self) -> xr.Dataset:
    """xarray.Dataset: The ``_CVs`` datasets of all trajs, combined along
    the ``traj_num`` dimension.

    The datasets are outer-joined and missing values are filled with nan.
    If no data variable of any traj (``*feature_indices*`` excluded)
    contains nan, the labels of the feature axes are additionally
    reordered, based on the values in the ``{name}_feature_indices``
    variables.

    Raises:
        Exception: If a ``feature_indices`` variable has no
            ``feature_axis`` attribute, or if no ordering index can be
            derived for one of its labels.

    """
    # Local Folder Imports
    from .trajinfo_utils import trajs_combine_attrs

    ds = xr.combine_nested(
        [traj._CVs for traj in self.trajs],
        concat_dim="traj_num",
        compat="broadcast_equals",
        fill_value=np.nan,
        coords="all",
        join="outer",
        combine_attrs=trajs_combine_attrs,
    )

    # if ensemble we don't need to reorder labels. That was already done
    # by the Featurizer. We know if we have an ensemble if the trajs
    # have features with nans (except the "feature_indices" dataarrays
    if any(
        [
            np.any(np.isnan(v.values))
            for traj in self
            for n, v in traj._CVs.data_vars.items()
            if "feature_indices" not in n
        ]
    ):
        return ds

    # sort the combined arrays
    # maps the name of a feature axis to its labels in the new order
    new_label_order = {}
    non_indices_data_vars = [k for k in ds.keys() if "feature_indices" not in k]
    for k in non_indices_data_vars:
        # variables without companion indices keep their order
        if (ind_k := f"{k}_feature_indices") not in ds:
            continue
        # one sort key per label of the feature axis
        argsort = []
        da = ds[ind_k]

        # ResidueMinDistanceFeature
        if "RES_NO" in da.coords:
            continue

        try:
            feature_axis_name = da.attrs["feature_axis"]
        except KeyError as e:
            raise Exception(f"There is no feature_axis attribute in {da=}") from e
        labels = da.coords[feature_axis_name].values
        # bring the indices into a form, where iterating yields one entry
        # per label of the feature axis
        if "ATOM_NO" not in da.coords and "frame_num" not in da.coords:
            iterable = da.values.T
        elif "ATOM_NO" not in da.coords and "frame_num" in da.coords:
            # merge traj_num and frame_num into one "frame" dimension and
            # drop the frames that are nan everywhere
            iterable = (
                da.stack({"frame": ("traj_num", "frame_num")})
                .transpose("frame", ...)
                .dropna("frame", how="all")
            )
            iterable = iterable.transpose(feature_axis_name, "frame").values
        elif "ATOM_NO" in da.coords and "frame_num" not in da.coords:
            iterable = da.transpose(feature_axis_name, "traj_num", "ATOM_NO").values
        else:
            iterable = (
                da.stack({"frame": ("traj_num", "frame_num")})
                .transpose("frame", ...)
                .dropna("frame", how="all")
            )
            iterable = iterable.transpose(feature_axis_name, "frame", "ATOM_NO")
        for i, (row, label) in enumerate(zip(iterable, labels)):
            if (
                "ATOM_NO" not in da.coords and "RES_NO" not in da.coords
            ):  # selection feature
                # the most frequent non-nan index becomes the sort key
                row = row[~np.isnan(row)]
                uniques, counts = np.unique(row, return_counts=True)
                ind = uniques[np.argmax(counts)]
                argsort.append(ind)
            else:
                # NOTE(review): the second positional argument of
                # ``np.nan_to_num`` is ``copy``, not the nan replacement -
                # confirm that ``0`` is not meant as ``nan=0``.
                row = np.nan_to_num(row + 1, 0).sum(1)
                row = row[np.nonzero(row)]
                uniques, counts = np.unique(row, return_counts=True)
                # NOTE(review): this raises for every feature axis with
                # "side" in its name and only reports ``new_label_order``;
                # looks like a debugging leftover - confirm.
                if "side" in feature_axis_name:
                    raise Exception(f"{new_label_order=}")
                try:
                    ind = uniques[np.argmax(counts)]
                except ValueError as e:
                    if "attempt to get argmax of an empty" in str(e):
                        raise Exception(
                            f"Can't order the data_var {k}, as the {row=} "
                            f"has become empty ({row=})."
                        )
                    raise e
                argsort.append(ind)
        # order the labels by their sort keys
        labels = labels[np.argsort(argsort)]
        new_label_order[feature_axis_name] = labels

    ds = ds.reindex(new_label_order, fill_value=np.nan)
    return ds

1731 

1732 def _calc_CV(self) -> dict[str, np.ndarray]: 

1733 if ( 

1734 not all([traj.CVs for traj in self.trajs]) 

1735 or [traj.CVs for traj in self.trajs] == [] 

1736 ): 

1737 return {} 

1738 else: 

1739 CVs = {} 

1740 matching_keys = list( 

1741 set.intersection(*[set(traj.CVs.keys()) for traj in self.trajs]) 

1742 ) 

1743 dropping_keys = set(matching_keys).difference( 

1744 *[set(traj.CVs.keys()) for traj in self.trajs] 

1745 ) 

1746 if dropping_keys: 

1747 print( 

1748 f"The CVs {dropping_keys} will not be in the `CVs` dictionary, " 

1749 f"as they are only present in some, but not all of the {len(self.trajs)} " 

1750 f"trajectories. You can access them with " 

1751 f"`TrajEnsemble([t for t in trajs if any([cv in {dropping_keys} for cv in t.CVs.keys()])])`" 

1752 ) 

1753 if matching_keys != []: 

1754 for key in matching_keys: 

1755 data = [] 

1756 for traj in self.trajs: 

1757 data.append(traj._CVs[key].values) 

1758 # check if all shapes are the same 

1759 shapes = [d.shape[2:] for d in data] 

1760 if not len(set(shapes)) == 1: 

1761 print( 

1762 f"I am not returning the CVs for the feature '{key}' " 

1763 f"because some trajectories have different " 

1764 f"shapes for these CVs. The shapes are {set(shapes)}. " 

1765 f"If you want to access these CVs, use the `_CVs` " 

1766 f"xarray dataset of `TrajEnsemble` instead." 

1767 ) 

1768 continue 

1769 if np.all( 

1770 [ 

1771 any([isinstance(ind, int) for ind in traj.index]) 

1772 for traj in self.trajs 

1773 ] 

1774 ): 

1775 data = np.vstack([d for d in data]) 

1776 if data.ndim <= 3: 

1777 data = data.reshape(-1, data.shape[-1]) 

1778 else: 

1779 data = data.reshape(-1, *data.shape[-2:]) 

1780 else: 

1781 try: 

1782 data = np.concatenate( 

1783 [d.squeeze(axis=0) for d in data], axis=0 

1784 ) 

1785 except ValueError as e: 

1786 if "zero-dimensional" in str(e): 

1787 data = np.hstack([d.squeeze(axis=0) for d in data]) 

1788 if "all the input arrays must have the same" in str(e): 

1789 err_shapes = "\n".join( 

1790 [ 

1791 f"Traj: {self.trajs[i].traj_num:<3} CV '{key}' shape: {d.shape}" 

1792 for i, d in enumerate(data) 

1793 ] 

1794 ) 

1795 raise Exception( 

1796 f"Can't concatenate the data of the CV '{key}'." 

1797 f"The shapes of this CV for the individual " 

1798 f"trajectories are:\n" 

1799 f"{err_shapes}" 

1800 ) 

1801 else: 

1802 raise e 

1803 if data.shape[-1] != 1: 

1804 CVs[key] = data 

1805 else: 

1806 CVs[key] = data.squeeze(-1) 

1807 return CVs 

1808 

@property
def CVs(self) -> dict[str, np.ndarray]:
    """dict: The result of ``self._calc_CV()``, i.e. the CVs of the
    SingleTraj classes. Only CVs with the same names in all SingleTraj
    classes are contained.

    """
    combined = self._calc_CV()
    return combined

1816 

@property
def locations(self) -> list[str]:
    """list: The ``traj_file`` of every traj, read from the trajs themselves
    and not from ``self.traj_files``.
    """
    files = []
    for traj in self.trajs:
        files.append(traj.traj_file)
    return files

1823 

@property
def index_arr(self) -> np.ndarray:
    """np.ndarray: The ``id`` arrays of all trajs, stacked vertically.

    According to the original documentation, the result has ndim = 2 and
    identifies every loaded frame by traj_num (self.index_arr[:,0]) and
    frame_num (self.index_arr[:,1]). Can be used to create an unspecified
    subset of frames and can be useful when used with clustering.

    """
    # the ``id`` attribute of the trajs is used here, which - according
    # to the original comment - loads the trajs and might be slow
    per_traj_ids = [traj.id for traj in self.trajs]
    return np.vstack(per_traj_ids)

1836 

@property
def name_arr(self) -> np.ndarray:
    """np.ndarray: Trajectory names with the same length as self.n_frames.

    Each trajectory contributes its ``basename`` once per frame. If any
    trajectory reports a falsy ``n_frames``, an empty array is returned.

    """
    frame_counts = [single_traj.n_frames for single_traj in self.trajs]
    if not np.all(frame_counts):
        return np.array([])

    names = []
    for single_traj in self.trajs:
        names.extend(single_traj.basename for _ in range(single_traj.n_frames))
    return np.array(names)

1848 

@property
def n_frames(self) -> int:
    """int: Sum of the loaded frames over all trajectories in ``self.trajs``."""
    total = 0
    for single_traj in self.trajs:
        total += single_traj.n_frames
    return total

1853 

@property
def frames(self) -> list[int]:
    """list[int]: The ``n_frames`` of the individual trajectories, in order."""
    counts = []
    for single_traj in self.trajs:
        counts.append(single_traj.n_frames)
    return counts

1858 

@property
def CVs_in_file(self) -> bool:
    """bool: True if every trajectory reports a truthy ``CVs_in_file``.

    Can be used to decide whether a data generator can be built from the files.

    """
    available = True
    # Deliberately no early exit: the flag of every trajectory is read.
    for single_traj in self.trajs:
        if not single_traj.CVs_in_file:
            available = False
    return available

1864 

def _traj_joined(self, progbar: Optional[Any] = None) -> md.Trajectory:
    """Joins all frames of this ensemble into a single trajectory.

    The trajectory of ``self.trajs[0]`` serves as a template: it is repeated
    (and finally sliced) until the template has ``self.n_frames`` frames.
    Afterwards, the coordinates of every frame yielded by
    ``self.iterframes()`` are written into the template frame by frame.

    Args:
        progbar (Optional[Any]): If None, a new progress bar is created. If a
            bool is provided, progress reporting is switched off. Any other
            object is treated as an existing progress bar, whose total is
            extended by the number of steps of this method.

    Returns:
        md.Trajectory: The joined trajectory, carrying the topology of the
            first trajectory in ``self.trajs``.

    """
    # use traj[0] of the trajs list as the traj from which the topology will be used
    parent_traj = self.trajs[0].traj

    # `divmod` tells how often the parent traj fits completely into the
    # requested number of frames (`no_of_iters`) and how many frames of a
    # final, sliced copy are needed to fill the remainder (`rest`).
    no_of_iters, rest = divmod(self.n_frames, parent_traj.n_frames)

    # one progress step per template chunk plus one per frame
    total = self.n_frames + no_of_iters + 1
    if progbar is None:
        if _is_notebook():
            progbar = notebook_tqdm(
                total=total,
                leave=False,
                position=0,
                desc="Joining...",
            )
        else:
            progbar = normal_tqdm(
                total=total,
                leave=False,
                position=0,
                desc="Joining...",
                function="_traj_joined",
            )
    else:
        if not isinstance(progbar, bool):
            progbar.reset(progbar.total + total, function="_traj_joined")
        else:
            progbar = None

    for i in range(no_of_iters + 1):
        if i == 0:
            dummy_traj = copy.deepcopy(parent_traj)
        elif i == no_of_iters:
            # last chunk: only the remaining frames (if any) are appended
            if rest != 0:
                dummy_traj = dummy_traj.join(copy.deepcopy(parent_traj)[:rest])
        else:
            dummy_traj = dummy_traj.join(copy.deepcopy(parent_traj))
        if progbar is not None:
            progbar.update(function="_traj_joined")

    # some checks
    assert self.n_frames == dummy_traj.n_frames

    # change the xyz coordinates of dummy_traj according to the frames in joined trajs
    warned = False
    for i, (_, __, traj) in enumerate(self.iterframes()):
        try:
            dummy_traj.xyz[i] = traj.xyz
        except ValueError as e:
            if "broadcast" not in str(e):
                raise
            # The atom counts differ. Warn a single time instead of once per
            # frame and keep only as many atoms as the template offers.
            if not warned:
                warnings.warn(
                    f"This `TrajEnsemble` has {len(self.top)} unique topologies. "
                    f"I will use the topology with {self.top[0].n_atoms} atoms "
                    f"for joining and discard atoms in the other trajectories."
                )
                warned = True
            dummy_traj.xyz[i] = traj.xyz[0, : dummy_traj.n_atoms]
        if progbar is not None:
            progbar.update(function="_traj_joined")

    return dummy_traj

1930 

@property
def traj_joined(self) -> md.Trajectory:
    """mdtraj.Trajectory: All frames of this class appended along the time axis.

    Thin wrapper around ``self._traj_joined()`` called without arguments.

    Can also work if different topologies (with the same number of atoms) are
    loaded. In that case, the first frame in self will be used as topology
    parent and the remaining frames' xyz coordinates are used to position the
    parents' atoms accordingly.

    Examples:
        >>> import encodermap as em
        >>> trajs = em.load_project("pASP_pGLU")
        >>> subsample = trajs[0][:20] + trajs[1][:20]
        >>> subsample.split_into_frames().traj_joined  # doctest: +ELLIPSIS
        <mdtraj.Trajectory with 40 frames, 69 atoms, 6 residues, and unitcells at ...>

    """
    joined = self._traj_joined()
    return joined

1949 

@property
def xyz(self) -> np.ndarray:
    """np.ndarray: xyz coordinates of all atoms stacked along the traj-time axis.

    Only works if all trajs share the same topology. With a single entry in
    ``self.top``, a failure while stacking propagates unchanged. With several
    topologies, a failure is re-raised as a ``NotImplementedError`` chained to
    the original exception.

    """
    homogeneous = len(self.top) == 1
    try:
        return np.vstack([single_traj.xyz for single_traj in self.trajs])
    except Exception as err:
        if homogeneous:
            raise
        raise NotImplementedError(
            "Non consistent topologies don't allow to return a "
            "common xyz. This could be achived by implementing a "
            "high-dimensional masked numpy array with nans at "
            "non-defined positions."
        ) from err

1973 

def copy(self):
    """Returns a deep copy of this instance, created with ``deepcopy``."""
    duplicate = deepcopy(self)
    return duplicate

1976 

def split_into_frames(self, inplace: bool = False) -> Optional[TrajEnsemble]:
    """Splits self into separate frames.

    Every frame yielded by ``self.iterframes()`` is handed to
    ``TrajEnsemble.with_overwrite_trajnums`` to build a new ensemble.

    Args:
        inplace (bool): Whether to do the split inplace or not.
            Defaults to False and thus, returns a new `TrajEnsemble` class.

    Returns:
        Optional[TrajEnsemble]: The new ensemble if ``inplace`` is False,
            otherwise None.

    """
    frames = [frame for _, __, frame in self.iterframes()]
    out = TrajEnsemble.with_overwrite_trajnums(*frames)
    assert out.trajs != []
    if not inplace:
        return out

    # Rebinding the local name `self` (as done previously) never reaches the
    # caller's object, which made `inplace=True` a silent no-op. Instead, this
    # instance takes over the complete state of the freshly built ensemble.
    # NOTE(review): assumes the instance state lives in `__dict__` (no
    # `__slots__`) - confirm against the class definition.
    self.__dict__.clear()
    self.__dict__.update(out.__dict__)
    return None

1994 

def save_CVs(self, path: Union[str, Path]) -> None:
    """Saves the CVs to a NETCDF file using xarray.

    Args:
        path (Union[str, Path]): Forwarded as target to ``self._CVs.to_netcdf``.

    """
    dataset = self._CVs
    dataset.to_netcdf(path, format="NETCDF4", engine="h5netcdf")

1998 

def cluster(
    self,
    cluster_id: int,
    col: str = "cluster_membership",
    memberships: Optional[np.ndarray] = None,
    n_points: int = -1,
    overwrite: bool = True,
) -> TrajEnsemble:
    """Extracts the frames of this :obj:`TrajEnsemble` that belong to
    ``cluster_id`` in the CV ``col``.

    'Clustering' here means extracting a subset given a certain membership.
    Take two trajectories with 3 frames each as an ensemble. Let's say we
    calculate the end-to-end distance of the trajectories and use it as
    a collective variable of the system. The values are
    ``[0.8, 1.3, 1.2, 1.9, 0.2, 1.3]``. Based on these values, we define a
    boolean CV (using 0 as False and 1 as True) which says whether the
    end-to-end distance is smaller or greater than 1.0. We give this CV the
    name ``'end_to_end_binary'`` and the values are ``[0, 1, 1, 1, 0, 1]``.
    We can use this CV to 'cluster' the :obj:`TrajEnsemble` via:
        - ``cluster = trajs.cluster(cluster_id=0, col='end_to_end_binary')``:
            This gives a :obj:`TrajEnsemble` with 2 frames.
        - ``cluster = trajs.cluster(cluster_id=1, col='end_to_end_binary')``:
            This gives a :obj:`TrajEnsemble` with 4 frames.
    If you want to save such a cluster in a format that can be rendered
    by graphical programs (``.xtc, .pdb``), you can use either the :meth:`join` or
    :meth:`stack` method of the resulting :obj:`TrajEnsemble` to get a
    `mdtraj.Trajectory`, which is either stacked along the atom axis or
    joined along the time axis.

    Note:
        If the resulting :obj:`TrajEnsemble` has inhomogeneous topologies, the
        :meth:`join` method will return a dict[md.Topology, md.Trajectory]
        instead. This dict can be used to save multiple (``.xtc, .pdb``) files
        and visualize your cluster in external programs.

    The ``col`` parameter takes any CV name, that is per-frame and integer.

    Args:
        cluster_id (int): The cluster id to use. Needs to be an integer,
            that is present in the ``col`` parameter.
        col (str): Which 'column' of the collective variables to use.
            Needs to be a key, that can be found in ``trajs.CVs.keys()``.
        memberships (Optional[np.ndarray]): If a :obj:`numpy.ndarray` is
            provided here, it is first loaded as a CV under the name ``col``
            (via ``load_CVs``) and then used as membership.
        n_points (int): How many points the resulting cluster should contain.
            Subsamples the points in ``col == cluster_id`` evenly and without
            repeat. Values that are not greater than 0 (default -1) select
            all points.
        overwrite (bool): When the ``memberships`` argument is used, but the
            :obj:`TrajEnsemble` already has a CV under the name specified by
            ``col``, you can set this to True to overwrite this column. Can
            be helpful, when you iteratively conduct multiple clusterings.

    Returns:
        TrajEnsemble: The selected frames. If ``self`` has a ``clustal_w``
            attribute, the alignment is parsed into the result as well.

    Examples:

        Import EncoderMap and NumPy.

        >>> import encodermap as em
        >>> import numpy as np

        Load an example project.

        >>> trajs = em.load_project("pASP_pGLU", load_autoencoder=False)

        Create an array full of ``-1``'s. These are the 'outliers'.

        >>> cluster_membership = np.ones(shape=(trajs.n_frames, )) * -1

        Select the first 5 frames of every traj to be in cluster 0.

        >>> cluster_membership[trajs.id[:, 1] < 5] = 0

        Select all frames between 50 and 55 to be cluster 1.

        >>> cluster_membership[(50 <= trajs.id[:, 1]) & (trajs.id[:, 1] <= 55)] = 1
        >>> np.unique(cluster_membership)
        array([-1.,  0.,  1.])

        Load this array as a CV called ``'clu_mem'``.

        >>> trajs.load_CVs(cluster_membership, attr_name='clu_mem')

        Extract all of cluster 0 with ``n_points=-1``.

        >>> clu0 = trajs.cluster(0, "clu_mem")
        >>> clu0.n_frames
        35

        Extract an evenly spaced subset of cluster 1 with 10 total points.

        >>> clu1 = trajs.cluster(1, "clu_mem", n_points=10)
        >>> clu1.n_frames
        10

        Clusters with inhomogeneous topologies can be stacked along the atom axis.

        >>> [t.n_atoms for t in trajs]
        [69, 83, 103, 91, 80, 63, 73]
        >>> stacked = clu1.stack()
        >>> stacked.n_atoms
        795

        But joining the trajectories returns a ``dict[top, traj]`` if the
        topologies are inhomogeneous.

        >>> joined = clu1.join()
        >>> type(joined)
        <class 'dict'>

    """
    if memberships is not None:
        if not overwrite:
            assert col not in self._CVs, (
                f"Can't load {memberships} as new CVs. "
                f"The CV {col} containing cluster memberships already exists. "
                f"Choose a different name for the argument `col`."
            )
        self.load_CVs(memberships, col, override=overwrite)

    assert (
        col in self._CVs
    ), f"To use the CV '{col}' for clustering, add it to the CVs with `load_CVs`."

    # boolean selection of all frames belonging to the requested cluster
    mask = (self.CVs[col] == cluster_id).squeeze()
    member_ids = self.index_arr[mask]
    positions = np.arange(self.n_frames)[mask]
    assert member_ids.size > 0, (
        f"The `cluster_id` {cluster_id} is not present in the `col` {col}: "
        f"{np.unique(self.CVs[col])=}"
    )

    if n_points > 0:
        # evenly spaced picks; `np.unique` drops picks that round to the same row
        keep = np.linspace(0, len(member_ids) - 1, n_points)
        keep = np.unique(np.round(keep).astype(int))
        member_ids = member_ids[keep]
        positions = positions[keep]

    try:
        out = self[member_ids]
    except IndexError:
        # fall back to a selection by the positions of the frames
        out = self._return_frames_by_index(positions)

    if hasattr(self, "clustal_w"):
        out.parse_clustal_w_alignment(self.clustal_w)
    return out

2144 

def join(
    self,
    align_string: str = "name CA",
    superpose: bool = True,
    ref_align_string: str = "name CA",
    base_traj: Optional[md.Trajectory] = None,
    progbar: Optional[Any] = None,
    dict_keys: Literal["top", "cs"] = "top",
) -> dict[Union[md.Topology, str], md.Trajectory]:
    """Joins the trajectories of this ensemble, grouped by topology or common string.

    Every group of ``self.trajs_by_top`` (or ``self.trajs_by_common_str``) is
    joined via its ``_traj_joined`` method and, if requested, superposed.

    Args:
        align_string (str): Selection string for the atoms of the joined
            trajectory that are used for the superposition.
        superpose (bool): Whether to superpose the joined trajectories.
        ref_align_string (str): Selection string for the atoms of the
            reference (``base_traj``) used for the superposition.
        base_traj (Optional[md.Trajectory]): The reference to superpose onto.
            If None, the first group is superposed onto itself and its first
            frame becomes the reference for the following groups.
        progbar (Optional[Any]): If None, a new progress bar is created. If a
            bool is provided, progress reporting is switched off. Any other
            object is treated as an existing progress bar, whose total is
            extended by ``len(self.top)``.
        dict_keys (Literal["top", "cs"]): ``"top"`` groups by
            ``self.trajs_by_top``, ``"cs"`` by ``self.trajs_by_common_str``.

    Returns:
        dict[Union[md.Topology, str], md.Trajectory]: The joined trajectory
            of every group under the key of that group.

    Raises:
        AssertionError: If more than one topology is present, ``superpose``
            is True and the align strings are not both ``'name CA'``.
        TypeError: If ``dict_keys`` is neither ``"top"`` nor ``"cs"``.

    """
    if len(self.top) > 1 and superpose:
        assert align_string == ref_align_string == "name CA", (
            f"Aligning different topologies only possible, when the `align"
            f"_string` and `ref_align_string` both are 'name CA'."
        )
    if progbar is None:
        if _is_notebook():
            progbar = notebook_tqdm(
                total=len(self.top),
                leave=False,
                position=0,
                desc="Joining...",
            )
        else:
            progbar = normal_tqdm(
                total=len(self.top),
                leave=False,
                position=0,
                desc="Joining...",
                function="join",
            )
    else:
        if not isinstance(progbar, bool):
            # an existing progress bar is extended by the steps of this method
            progbar.reset(progbar.total + len(self.top), function="join")
        else:
            # a bool switches progress reporting off
            progbar = None

    # NOTE(review): `all_trajs` is filled below but not part of the return value.
    all_trajs = []
    out_by_top = {}
    if dict_keys == "top":
        iterator = enumerate(self.trajs_by_top.items())
    elif dict_keys == "cs":
        iterator = enumerate(self.trajs_by_common_str.items())
    else:
        raise TypeError(
            f"The argument `dict_keys` '{dict_keys}' is not recognized."
        )
    for i, (top, traj) in iterator:
        # from here on `traj` is the joined trajectory of the current group
        traj = traj._traj_joined(progbar=progbar)
        if superpose:
            if base_traj is not None:
                CAs_traj = traj.top.select(align_string)
                CAs_ref = base_traj.top.select(ref_align_string)
                if hasattr(self, "clustal_w"):
                    # Walk both alignment strings in parallel and keep only the
                    # atom pairs of positions, where neither of them has a gap ("-").
                    new_CAs_traj = []
                    new_CAs_ref = []
                    i_t = 0
                    i_r = 0
                    for aln_t, aln_r in zip(
                        self.trajs_by_top[traj.top][0].clustal_w,
                        self.trajs_by_top[base_traj.top][0].clustal_w,
                    ):
                        if aln_t == aln_r == "-":
                            pass
                        elif aln_t != "-" and aln_r == "-":
                            i_t += 1
                        elif aln_t == "-" and aln_r != "-":
                            i_r += 1
                        else:
                            new_CAs_traj.append(CAs_traj[i_t])
                            new_CAs_ref.append(CAs_ref[i_r])
                            i_t += 1
                            i_r += 1
                    new_CAs_traj = np.array(new_CAs_traj)
                    new_CAs_ref = np.array(new_CAs_ref)
                    CAs_traj = new_CAs_traj.copy()
                    CAs_ref = new_CAs_ref.copy()
                else:
                    # without an alignment, both selections are cut to the same length
                    length = min(len(CAs_traj), len(CAs_ref))
                    CAs_traj = CAs_traj[:length]
                    CAs_ref = CAs_ref[:length]
                traj = traj.superpose(
                    base_traj,
                    atom_indices=CAs_traj,
                    ref_atom_indices=CAs_ref,
                )
            else:
                # no reference yet: superpose the trajectory onto itself
                traj = traj.superpose(
                    traj,
                    atom_indices=traj.top.select(align_string),
                )
            if i == 0:
                # the first frame of the first group is the reference afterwards
                base_traj = traj[0]
        all_trajs.append(traj)
        out_by_top[top] = traj
        if progbar is not None:
            progbar.update(function="join")

    # return
    return out_by_top

2244 

def stack(
    self,
    align_string: str = "name CA",
    superpose: bool = True,
    ref_align_string: str = "name CA",
    base_traj: Optional[md.Trajectory] = None,
    progbar: Optional[Any] = None,
) -> md.Trajectory:
    """Stacks all frames of this ensemble along the atom axis.

    The ensemble is first joined with :meth:`join`. Then every frame of every
    joined trajectory is stacked onto a deep copy of the very first frame.

    Args:
        align_string (str): Forwarded to :meth:`join`.
        superpose (bool): Forwarded to :meth:`join`.
        ref_align_string (str): Forwarded to :meth:`join`.
        base_traj (Optional[md.Trajectory]): Forwarded to :meth:`join`.
        progbar (Optional[Any]): If None, a new progress bar is created. If a
            bool is provided, progress reporting is switched off. Any other
            object is treated as an existing progress bar, whose total is
            extended by ``self.n_frames``.

    Returns:
        md.Trajectory: A single-frame trajectory containing the atoms of all
            frames.

    """
    if isinstance(progbar, bool):
        progbar = None
    elif progbar is not None:
        progbar.reset(progbar.total + self.n_frames, function="stack")
    elif _is_notebook():
        progbar = notebook_tqdm(
            total=self.n_frames,
            leave=False,
            position=0,
            desc="Stacking...",
        )
    else:
        progbar = normal_tqdm(
            total=self.n_frames,
            leave=False,
            position=0,
            desc="Stacking...",
            function="stack",
        )

    joined = self.join(
        align_string,
        superpose,
        ref_align_string,
        base_traj,
        progbar=progbar,
    )

    # stack
    atom_count = 0
    for traj_pos, traj in enumerate(joined.values()):
        for frame_pos, frame in enumerate(traj):
            atom_count += frame.n_atoms
            if traj_pos or frame_pos:
                stacked = stacked.stack(frame)
            else:
                # the very first frame seeds the stacked trajectory
                stacked = deepcopy(frame)
            if progbar is not None:
                progbar.update(function="stack")
    assert stacked.n_atoms == atom_count
    return stacked

2295 

def _trace(self, CV: Sequence[str]) -> np.ndarray:
    """Creates a low-dimensional representation of the loaded CV data by
    stacking all arguments in `CV` along the feature axis.

    If this `TrajEnsemble` has 10 trajectories with 100 frames each
    and a CV with shape (100, 50, 3) in each of them with the name 'cartesians'
    then `trajs._trace(['cartesians'])` will return a np.ndarray of shape
    (1000, 150). Requesting further CVs appends their flattened features to
    the columns, so the number of rows always stays the number of frames.

    Args:
        CV (Sequence[str]): The CVs to combine in the trace.

    Returns:
        np.ndarray: The trace with shape (n_frames, sum of flattened features).

    Raises:
        ValueError: If `CV` is empty (raised by `np.concatenate`).

    """
    # `self.CVs` is looked up a single time instead of once per requested CV
    cvs = self.CVs
    flattened = []
    for name in CV:
        values = cvs[name]
        flattened.append(values.reshape(values.shape[0], -1))
    # axis=1 keeps one row per frame; the default axis=0 would append the rows
    # of the different CVs below each other (or fail for differing feature counts)
    return np.concatenate(flattened, axis=1)

2317 

def to_dataframe(self, CV: Union[str, Sequence[str]]) -> pd.DataFrame:
    """Combines per-frame metadata and the requested CVs in one DataFrame.

    Args:
        CV (Union[str, Sequence[str]]): A single CV name or a sequence of CV
            names, which need to be present in ``self._CVs``. For a sequence,
            CVs with a single column get that column renamed from ``0`` to
            the upper-case CV name.

    Returns:
        pd.DataFrame: Indexed by ``("traj_num", "frame_num")`` with the
            columns ``traj_file``, ``top_file`` and ``time``, followed by the
            CV columns.

    """

    def _cv_as_pandas(name):
        # the CV data can be directly extracted from xarray
        return (
            self._CVs[name]
            .stack({"frame": ("traj_num", "frame_num")})
            .transpose("frame", ...)
            .dropna("frame", how="all")
            .to_pandas()
        )

    # frame nums can be inhomogeneous
    frame_num = []
    for traj in self:
        ids = traj.id
        if ids.ndim == 2:
            frame_num.extend(ids[:, 1])
        else:
            frame_num.extend(ids)

    time = []
    for traj in self:
        time.extend(traj.time)

    if isinstance(CV, str):
        data = [_cv_as_pandas(CV)]
    else:
        data = []
        for cv in CV:
            cv_frame = _cv_as_pandas(cv)
            if len(cv_frame.columns) == 1:
                cv_frame = cv_frame.rename(columns={0: cv.upper()})
            data.append(cv_frame)

    meta = pd.DataFrame(
        {
            "traj_file": [t.traj_file for t in self for _ in range(t.n_frames)],
            "top_file": [t.top_file for t in self for _ in range(t.n_frames)],
            "traj_num": [t.traj_num for t in self for _ in range(t.n_frames)],
            "frame_num": frame_num,
            "time": time,
        },
    )
    meta = meta.set_index(["traj_num", "frame_num"])
    return pd.concat([meta, *data], axis=1)

2370 

def dash_summary(self) -> pd.DataFrame:
    """A :obj:`pandas.DataFrame` that summarizes this ensemble.

    With exactly one trajectory, the summary of that trajectory is returned.
    Otherwise, the frame has the columns ``field`` and ``value`` with one row
    each for the number of trajectories, frames and atoms, the timestep(s),
    the trajectory basenames, whether multiple topologies are present and the
    unique common strings. All entries are converted to str.

    Returns:
        pd.DataFrame: The DataFrame.

    """
    if self.n_trajs == 1:
        return self.trajs[0].dash_summary()

    # atoms: collapse to a scalar if all trajs have the same number of atoms
    atom_counts = np.unique([t.n_atoms for t in self])
    if len(atom_counts) == 1:
        atom_counts = atom_counts[0]

    # dt: unique differences between consecutive time values of all trajs
    timesteps = []
    for t in self:
        timesteps.extend(np.unique(t.time[1:] - t.time[:-1]))
    timesteps = np.unique(np.asarray(timesteps))
    if len(timesteps) == 1:
        timesteps = timesteps[0]

    # traj_files
    basenames = [t.basename for t in self]

    # topologies
    has_multiple_tops = len(self.top) != 1

    rows = [
        ("n_trajs", self.n_trajs),
        ("n_frames", self.n_frames),
        ("n_atoms", atom_counts),
        ("dt (ps)", timesteps),
        ("trajs", basenames),
        ("multiple tops", has_multiple_tops),
        ("common_str", list(set(self.common_str))),
    ]
    summary = pd.DataFrame(
        {
            "field": [field for field, _ in rows],
            "value": [value for _, value in rows],
        }
    )
    return summary.astype(str)

2422 

def load_CVs(
    self,
    data: TrajEnsembleFeatureType = None,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
    deg: Optional[bool] = None,
    periodic: bool = True,
    labels: Optional[list[str]] = None,
    directory: Optional[Union[str, Path]] = None,
    ensemble: bool = False,
    override: bool = False,
    custom_aas: Optional[CustomAAsDict] = None,
    alignment: Optional[str] = None,
) -> None:
    """Loads CVs in various ways. The easiest way is to provide a single
    :obj:`numpy.ndarray` and a name for that array.

    Besides np.ndarray, files (``.txt and .npy``) can be loaded. Features
    or Featurizers can be provided. A :obj:`xarray.Dataset` can be provided.
    A str can be provided which either is the name of one of EncoderMap's
    features (`encodermap.features`) or the string can be 'all',
    which loads all features required for EncoderMap's
    :obj:`encodermap.autoencoder.autoencoder.AngleDihedralCartesianEncoderMap`.

    The actual loading is delegated to
    ``encodermap.trajinfo.trajinfo_utils.load_CVs_ensembletraj``.

    Args:
        data (Optional[TrajEnsembleFeatureType]): The CV to
            load. When a :obj:`numpy.ndarray` is provided, it needs to have
            a shape matching ``n_frames`` and the data will be distributed
            to the trajs. When a list of files is provided, ``len(data)``
            (the files) needs to match ``n_trajs``. The first file will be
            loaded by the first traj (based on the traj's ``traj_num``) and
            so on. If a list of :obj:`numpy.ndarray` is provided, the first
            array will be assigned to the first traj (based on the traj's
            ``traj_num``). If None is provided (default), this method
            passes ``self.featurizer`` on as data. The previous
            documentation states that for None the argument ``directory``
            is used to construct a str using the expression
            ``fname = directory + traj.basename + '_' + attr_name`` and
            that matching ``.txt`` or ``.npy`` files are loaded to the
            corresponding trajs.
            NOTE(review): confirm how this interacts with the featurizer
            substitution.
        attr_name (Optional[str]): The name under which the CV should
            be found in the class. Choose whatever you like. ``'highd'``,
            ``'lowd'``, ``'dists'``, etc. The CV can then be accessed via
            dot-notation: ``trajs.attr_name``. Defaults to None, in which
            case, the argument ``data`` should point to existing files.
            The ``attr_name`` will be extracted from these files.
        cols (Optional[list[int]]): A list of integers indexing the columns
            of the data to be loaded. This is useful if a file contains
            columns which are not features (i.e. an indexer or the error of
            the features), e.g.::

                id   f1    f2    f1_err    f2_err
                0    1.0   2.0   0.1       0.1
                1    2.5   1.2   0.11      0.52

            In that case, you would want to supply ``cols=[1, 2]`` to the
            ``cols`` argument. If None is provided all columns are loaded.
            Defaults to None.
        deg (Optional[bool]): Whether to return angular CVs using degrees.
            If None or False, CVs will be in radian. Defaults to None.
        periodic (bool): Whether to use the minimum image convention to
            calculate distances/angles/dihedrals. This is generally recommended,
            when you don't clean up your trajectories and the proteins break
            over the periodic boundary conditions. However, when the protein is
            large, the distance between one site and another might be shorter
            through the periodic boundary. This can lead to wrong results
            in your distance calculations.
        labels (Optional[list[str]]): A list containing the labels for the
            dimensions of the data. If you provide a :obj:`numpy.ndarray`
            with shape ``(n_trajs, n_frames, n_feat)``, this list needs to
            have ``n_feat`` entries. An exception will be raised otherwise.
            If None is provided, the labels will be automatically generated.
            Defaults to None.
        directory (Optional[Union[str, Path]]): If this argument is provided,
            the directory will be searched for ``.txt`` or ``.npy`` files which
            have the same names as the trajectories have basenames. The
            CVs will then be loaded from these files.
        ensemble (bool): Whether the trajs in this class belong to an ensemble.
            This implies that they contain either the same topology or are
            very similar (think wt, and mutant). Setting this option True will
            try to match the CVs of the trajs onto the same dataset.
            If a VAL residue has been replaced by LYS in the mutant,
            the number of sidechain dihedrals will increase. The CVs of the
            trajs with VAL will thus contain some NaN values. Defaults to False.
        override (bool): Whether to override CVs with the same name as ``attr_name``.
        custom_aas (Optional[CustomAAsDict]): You can provide non-standard
            residue definitions in this argument. See
            :obj:`encodermap.trajinfo.trajinfo_utils.CustomTopology` for
            information how to use the custom_aas argument. If set to None
            (default), only standard residue names are assumed.
        alignment (Optional[str]): If your proteins have similar but different
            sequences, you can provide a CLUSTAL W alignment as this argument
            and the featurization will align the features accordingly. Only
            parsed if ``ensemble`` is True; otherwise a hint is printed.

    Raises:
        TypeError: When wrong Type has been provided for data.

    """
    # Local Folder Imports
    from .trajinfo_utils import load_CVs_ensembletraj

    if data is None:
        data = self.featurizer

    if custom_aas is not None:
        self.load_custom_topology(custom_aas)

    if alignment is not None and ensemble:
        self.parse_clustal_w_alignment(alignment)
    elif alignment is not None:
        print(
            "Providing a CLUSTAL W alignment for featurization of ensembles "
            "of protein families, makes only sense, when `ensemble` is also "
            "set to True. This makes EncoderMap align the features based "
            "on their sequence alignment."
        )

    # if some trajs are missing time: with mixed backends, load all trajs
    backends = np.unique(np.asarray([t.backend for t in self.trajs]))
    if len(backends) > 1:
        for single_traj in self.trajs:
            single_traj.load_traj()

    load_CVs_ensembletraj(
        self,
        data,
        attr_name,
        cols,
        deg,
        periodic,
        labels,
        directory,
        ensemble,
        override,
    )

2559 

def save(
    self,
    fname: Union[str, Path],
    CVs: Union[Literal["all"], list[str], Literal[False]] = "all",
    overwrite: bool = False,
    only_top: bool = False,
) -> None:
    """Saves this TrajEnsemble into a single ``.h5`` file.

    If any trajectory has a ``clustal_w`` attribute, a warning is issued and
    nothing is written.

    Args:
        fname (Union[str, Path]): Where to save the file. Needs to have the
            ``.h5`` suffix.
        CVs (Union[Literal["all"], list[str], Literal[False]]): Which CVs
            to also store in the file. If set to ``'all'``, all CVs will
            be saved. Otherwise, a list[str] can be provided to only save
            specific CVs. Can also be set to False, no CVs are stored in the
            file.
        overwrite (bool): If the file exists, it is overwritten.
        only_top (bool): Only writes the trajectories' topologies into the file.

    Raises:
        IOError: If file already exists and neither ``overwrite`` nor
            ``only_top`` is True.
        AssertionError: If ``fname`` does not end with ``.h5`` or a CV given
            in ``CVs`` does not have ``self.n_frames`` entries along its
            second axis.

    """
    # Third Party Imports
    from mdtraj.utils import in_units_of

    if any([hasattr(traj, "clustal_w") for traj in self]):
        warnings.warn(
            "Can't currently save a `TrajEnsemble` with a clustal w alignment"
        )
        return

    fname = Path(fname)
    assert (
        fname.suffix == ".h5"
    ), "We recommend the .h5 file extension for these files."
    if fname.is_file() and not overwrite and not only_top:
        raise IOError(
            f"File {fname} already exists. Set `overwrite` to True to overwrite."
        )
    if fname.is_file() and overwrite:
        fname.unlink()

    # first pass: coordinates (or only the topology) of every trajectory
    for i, traj in self.itertrajs():
        with HDF5GroupWrite(fname, "a", force_overwrite=overwrite) as f:
            if not only_top:
                f.write_into_group(
                    group_id=str(i),
                    coordinates=in_units_of(
                        traj.xyz, md.Trajectory._distance_unit, f.distance_unit
                    ),
                    time=traj.time,
                    cell_lengths=in_units_of(
                        traj.unitcell_lengths,
                        md.Trajectory._distance_unit,
                        f.distance_unit,
                    ),
                    cell_angles=traj.unitcell_angles,
                    topology=traj.top,
                )
            else:
                f.write_topology(traj.traj_num, traj.top)
    # second pass: custom topology and common string of every trajectory
    for i, traj in self.itertrajs():
        traj._custom_top.traj = traj
        traj._custom_top.to_hdf_file(fname)
        traj._common_str_to_hdf_file(fname)
    if CVs == "all":
        save_netcdf_alongside_mdtraj(fname, self._CVs)
        return
    # only selected CVs are written, each as a dataset of the "CVs" group
    if self._CVs and CVs:
        with h5py.File(fname, "a") as file:
            if "CVs" in list(file.keys()):
                grp = file["CVs"]
            else:
                grp = file.create_group("CVs")
            for key in CVs:
                value = self._CVs[key]
                assert self.n_frames == value.shape[1]
                grp.create_dataset(name=key, data=value)

2639 

def _return_trajs_by_index(self, index: Sequence[int]) -> TrajEnsemble:
    """Creates a :obj:`TrajEnsemble` object with the trajs specified by ``index``.

    This is a sub-method of the ``trajs[]`` indexer.

    Args:
        index (Sequence[int]): Positions in ``self.trajs``. Negative positions
            are resolved like in regular list indexing. The sequence must not
            be empty.

    Returns:
        TrajEnsemble: The selected trajectories in the order of ``index``.
            Its ``common_str`` contains the unique common strings of the
            selected trajectories and its ``basename_fn`` is taken from self.

    """
    # The common strings are gathered from the very trajectories that are
    # selected. Matching enumerate positions against `index` (as done before)
    # missed negative positions and scanned `index` once per trajectory.
    common_strs = set()
    for position, ind in enumerate(index):
        selected = self.trajs[ind]
        common_strs.add(selected.common_str)
        if position == 0:
            trajs_subset = selected._gen_ensemble()
        else:
            trajs_subset += selected._gen_ensemble()
    trajs_subset.common_str = list(common_strs)
    trajs_subset.basename_fn = self.basename_fn
    return trajs_subset

2661 

def _return_frames_by_index(self, index: Sequence[int]) -> TrajEnsemble:
    """Creates a :obj:`TrajEnsemble` object with the frames specified by ``index``.

    A frame is selected, if the frame number yielded alongside it by
    ``self.iterframes()`` is contained in ``index``.

    Args:
        index (Sequence[int]): The frame numbers to keep.

    Returns:
        TrajEnsemble: The selected frames in iteration order, with the unique
            common strings of these frames and the ``basename_fn`` of self.

    """
    selected = [
        frame for _, frame_num, frame in self.iterframes() if frame_num in index
    ]
    assert len(selected) > 0, (
        f"Can't return frames with index {index} from trajectory with "
        f"frames {self.id[:, -1]}."
    )
    unique_common_str = list({frame.common_str for frame in selected})

    first, *remaining = selected
    trajs_subset = first._gen_ensemble()
    for frame in remaining:
        trajs_subset = trajs_subset + frame._gen_ensemble()
    trajs_subset.common_str = unique_common_str
    trajs_subset.basename_fn = self.basename_fn
    return trajs_subset

2685 

def _pyemma_indexing_tsel(self, key: np.ndarray) -> TrajEnsemble:
    """Returns a new :obj:`TrajEnsemble` by giving the indices of traj and frame.

    The first column of ``key`` is used with ``self.tsel`` and the second
    column with the ``fsel`` accessor of the selected trajectory. A 1D ``key``
    is treated as a single row.

    """
    if key.ndim == 1:
        key = key.reshape(1, -1)
    traj_column = key[:, 0]
    selected = []
    for num in np.unique(traj_column):
        frame_nums = key[traj_column == num, 1]
        selected.append(self.tsel[num].fsel[frame_nums])
    return TrajEnsemble(
        selected, basename_fn=self.basename_fn, common_str=self.common_str
    )

2697 

def _pyemma_indexing_no_tsel(self, key: np.ndarray) -> TrajEnsemble:
    """Returns a new :obj:`TrajEnsemble` by giving the indices of traj and frame.

    The first column of ``key`` is used as position in ``self.trajs`` and the
    second column indexes the frames of that trajectory. A 1D ``key`` is
    treated as a single row.

    """
    if key.ndim == 1:
        key = key.reshape(1, -1)
    traj_column = key[:, 0]
    selected = []
    for num in np.unique(traj_column):
        assert num < self.n_trajs, (
            f"Can't identify trajectory with number {num} in an ensemble "
            f"with {self.n_trajs} trajectories."
        )
        frame_nums = key[traj_column == num, 1]
        selected.append(self.trajs[num][frame_nums])
    return TrajEnsemble(
        selected, basename_fn=self.basename_fn, common_str=self.common_str
    )

2713 

def subsample(
    self,
    stride: Optional[int] = None,
    total: Optional[int] = None,
) -> Optional[TrajEnsemble]:
    """Returns a subset of this :obj:`TrajEnsemble`, chosen by stride or total.

    Exactly one of ``stride`` and ``total`` has to be provided.

    With ``stride``, every trajectory in ``self.trajs`` is sliced on its own
    with ``traj[::stride]``. According to the original author this is faster
    than ``trajs[trajs.index_arr[::stride]]`` for HDF5 trajectories, because
    the slicing information is stored in the respective
    :obj:`encodermap.trajinfo.info_single.SingleTraj`.

    With ``total``, ``total`` evenly spaced positions between the first and
    the last frame of the ensemble are rounded to integers and de-duplicated.
    The matching rows of ``self.id`` are then used to index the ensemble.

    Args:
        stride (Optional[int]): Return a frame every ``stride`` frames.
        total (Optional[int]): Return a total of evenly sampled frames.

    Returns:
        Optional[TrajEnsemble]: A trajectory ensemble. ``None`` is returned
            (after printing a message) when neither or both arguments
            are provided.

    Note:
        ``subsample(stride)`` strides every trajectory independently, so its
        result can differ from striding the whole ensemble at once. Consider
        two trajectories with 18 frames each and a stride of 5:
        ``trajs.subsample(5)`` requests the frames 0, 5, 10, and 15 of each
        trajectory. Striding a consecutive 36-frame index instead (presumably
        what ``trajs[trajs.index_arr[::5]]`` does) reaches the second
        trajectory at its frames 2, 7, 12, and 17.

    """
    only_total = stride is None and total is not None
    only_stride = total is None and stride is not None

    if only_total:
        positions = np.linspace(0, self.n_frames - 1, total)
        # rounding can map two positions onto the same frame
        rows = np.unique(np.round(positions).astype(int))
        return self[self.id[rows]]

    if only_stride:
        strided = [traj[slice(None, None, stride)] for traj in self.trajs]
        return TrajEnsemble(
            strided, common_str=self.common_str, basename_fn=self.basename_fn
        )

    print("Provide either stride or total.")
    return None

2765 

def get_single_frame(self, key: int) -> "SingleTraj":
    """Returns a single frame, addressed by its position in the ensemble.

    The frames are counted consecutively over all trajectories in the order
    of ``self.trajs``. Consider a :obj:`TrajEnsemble` with two trajectories.
    One has 10 frames, the other 5 (``trajs.n_frames`` is 15). Calling
    ``trajs.get_single_frame(12)`` is equal to calling ``trajs[1][2]``,
    ``trajs.get_single_frame(1)`` is the same as ``trajs[0][1]``, and
    ``trajs.get_single_frame(15)`` raises an ``IndexError``.

    Args:
        key (int): The frame to return. In ensembles with more than one
            trajectory, negative values count from the end of the ensemble.

    Returns:
        encodermap.trajinfo.info_single.SingleTraj: The frame.

    Raises:
        Exception: If the ensemble reports zero frames.
        IndexError: If ``key`` is not an integer or is out of bounds.

    """
    # some input checks
    if self.n_frames == 0:
        raise Exception(
            "Indexing a no_load backend does not work. I need some "
            "information about the frames in each trajectory. Please "
            "load either highd or lowd."
        )
    # Check the type before comparing, otherwise a list or slice would fail
    # with a TypeError in the bounds check and never reach this message.
    if not isinstance(key, (int, np.integer)):
        raise IndexError(
            "if you want a single frame, please provide an integer. "
            "If you want multiple frames use ep.TrajEnsemble[]"
        )
    if key >= self.n_frames:
        raise IndexError(
            "index {} is out of bounds for trajectory with {} frames".format(
                key, self.n_frames
            )
        )

    if len(self.trajs) == 1:
        return self.trajs[0][key]

    # Walk the trajectories and subtract their lengths until the remaining
    # offset falls into one of them. This replaces building a full
    # (n_frames, 2) lookup table just to read a single row.
    frame = key + self.n_frames if key < 0 else key
    if frame >= 0:
        for traj in self.trajs:
            if frame < traj.n_frames:
                # Select by traj_num. The previous positional shortcut was
                # guarded by a comparison against an array that was one
                # element too long and therefore never ran.
                return self.tsel[traj.traj_num][frame]
            frame -= traj.n_frames
    raise IndexError(
        "index {} is out of bounds for trajectory with {} frames".format(
            key, self.n_frames
        )
    )

2816 

def unload(self) -> None:
    """Calls ``unload()`` on every traj in self and sets the backend to 'no_load'."""
    for traj in self:
        traj.unload()
    self.backend = "no_load"

2821 

def load_trajs(self) -> None:
    """Calls ``load_traj()`` on every traj in self and sets the backend to 'mdtraj'."""
    for traj in self:
        traj.load_traj()
    self.backend = "mdtraj"

2826 

@overload
def batch_iterator(
    self,
    batch_size: int,
    replace: bool = False,
    CV_names: tuple[str] = ("",),
    deterministic: bool = True,
    yield_index: bool = True,
    start: int = 1,
) -> Iterator[
    tuple[
        np.ndarray,
        np.ndarray,
    ]
]: ...

@overload
def batch_iterator(
    self,
    batch_size: int,
    replace: bool = False,
    CV_names: tuple[str] = ("",),
    deterministic: bool = True,
    yield_index: bool = False,
    start: int = 1,
) -> Iterator[np.ndarray]: ...

@overload
def batch_iterator(
    self,
    batch_size: int,
    replace: bool = False,
    CV_names: Optional[Sequence[str]] = None,
    deterministic: bool = True,
    yield_index: bool = True,
    start: int = 1,
) -> Iterator[
    tuple[
        np.ndarray,
        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    ]
]: ...

@overload
def batch_iterator(
    self,
    batch_size: int,
    replace: bool = False,
    CV_names: Optional[Sequence[str]] = None,
    deterministic: bool = True,
    yield_index: bool = False,
    start: int = 1,
) -> Iterator[
    tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
]: ...

def batch_iterator(
    self,
    batch_size: int,
    replace: bool = False,
    CV_names: Optional[Sequence[str]] = None,
    deterministic: bool = False,
    yield_index: bool = False,
    start: int = 1,
) -> Iterator[Any]:
    """Lazy batched iterator of CV data.

    This iterator extracts batches of CV data from the ensemble. If the
    ensemble is a large HDF5 dataset, this provides the ability to use all
    data without loading it all into memory.

    Examples:

        Import EncoderMap and load some example trajectories.

        >>> import encodermap as em
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         'https://files.rcsb.org/view/1YUG.pdb',
        ...         'https://files.rcsb.org/view/1YUF.pdb'
        ...     ]
        ... )

        This iterator will yield new samples forever. The batch is a tuple
        of :obj:`numpy.ndarray`.

        >>> for batch in trajs.batch_iterator(batch_size=2):
        ...     print([b.shape for b in batch])
        ...     break
        [(2, 148), (2, 147), (2, 150, 3), (2, 149), (2, 82)]

        Use it with Python's builtin ``next()`` function. The ``deterministic``
        flag returns deterministic batches. The ``yield_index`` flag also
        provides the index of the extracted batch. In this example, both batches
        are extracted from the 1YUG trajectory (``traj_num==0``).

        >>> iterator = trajs.batch_iterator(deterministic=True, batch_size=2, yield_index=True)
        >>> index, batch = next(iterator)
        >>> index
        [[0 5]
         [0 8]]
        >>> index, batch = next(iterator)
        >>> index
        [[ 0  3]
         [ 0 10]]

        If a single string is requested for ``CV_names``, the batch will
        be a single :obj:`numpy.ndarray`, rather than a tuple thereof.

        >>> iterator = trajs.batch_iterator(batch_size=2, CV_names=["central_dihedrals"])
        >>> batch = next(iterator)
        >>> batch.shape
        (2, 147)

    Args:
        batch_size (int): The size of the batch.
        replace (bool): Whether inside a single batch a sample can occur
            more than once. Set to False (default) to only allow unique
            samples in a batch.
        CV_names (Sequence[str]): The names of the CVs to be used in the
            iterator. If a list/tuple with a single string is provided, the
            batch will be a :obj:`numpy.ndarray`, rather than a tuple
            thereof. A plain string is treated like a one-element list.
            Defaults to the central angles, dihedrals, cartesians and
            distances plus the side dihedrals.
        deterministic (bool): Whether the samples should be deterministic.
        yield_index (bool): Whether to also yield the index of the extracted
            samples.
        start (int): A start integer, which can be used together with
            ``deterministic=True`` to get different deterministic datasets.

    Returns:
        Iterator[Any]: Different iterators based on chosen arguments.

    Raises:
        Exception: If no valid sample could be drawn within
            ``100 * batch_size`` tries for a batch.
        AssertionError: If the number of non-NaN frames of the CV data does
            not match the number of rows of ``self.id``, or if a distance CV
            contains a value that is not greater than zero.

    Note:
        With ``deterministic=True``, NumPy's global random number generator
        is re-seeded before every draw.

    """
    # Encodermap imports
    from encodermap.autoencoder.autoencoder import np_to_sparse_tensor

    # the standard CV_names
    if CV_names is None:
        CV_names = [
            "central_angles",
            "central_dihedrals",
            "central_cartesians",
            "central_distances",
            "side_dihedrals",
        ]
    elif isinstance(CV_names, str):
        CV_names = [CV_names]
    else:
        # xarray treats a tuple as ONE (hashable) variable name, so a tuple
        # of CV names has to be turned into a list before indexing the
        # dataset with it.
        CV_names = list(CV_names)

    if self.CVs_in_file and all([t.index == (None,) for t in self.trajs]):
        ds = xr.open_dataset(
            self.trajs[0].traj_file, engine="h5netcdf", group="CVs"
        )
        ds = ds[CV_names]
    else:
        ds = self._CVs[CV_names]
    # number of (traj, frame) combinations that are not completely NaN
    total_len = (
        ds.stack({"frame": ("traj_num", "frame_num")})
        .transpose("frame", ...)
        .dropna("frame", how="all")
        .sizes["frame"]
    )
    traj_nums_and_frames = self.id.copy()
    if self.CVs_in_file:
        assert len(traj_nums_and_frames) == total_len, (
            f"The CVs of the trajs are not aligned with the frames. The "
            f"CVs stacked along the traj/frame axis have a shape of {total_len}, "
            f"while the id array of the trajs has a shape of {self.id.shape}. "
            f"The frames of the trajs are reported as {self.n_frames}. The "
            f"CV data was extracted from the .h5 file {self.trajs[0].traj_file}, "
            f"by stacking the traj/frame axis into a combined axis and dropping "
            f"the frames full of NaNs for the CVs {CV_names}."
        )
    else:
        assert len(traj_nums_and_frames) == total_len, (
            f"The CVs of the trajs are not aligned with the frames. The "
            f"CVs stacked along the traj/frame axis have a shape of {total_len}, "
            f"while the id array of the trajs has a shape of {self.id.shape}. "
            f"The frames of the trajs are reported as {self.n_frames}. The CV "
            f"data was obtained from combining the CVs {CV_names} of the trajectories "
            f"in this ensemble along a traj axis."
        )

    # detect sparse or sidechain only sparse
    # a CV is treated as sparse if its first frame contains NaNs
    sparse = set()
    sub_ds = ds.sel(frame_num=0)
    for o in CV_names:
        datum = sub_ds[o].values
        if np.isnan(datum).any(1).any(None):
            sparse.add(o)

    # start the loop
    # i is the counter for the sample (batch)
    # j is the counter for how many tries were needed to select either
    # unique indices (if replace is False) or indices where not all
    # values of a CV are NaN
    i = start
    while True:
        index = []
        out = [[] for o in CV_names]
        j = 0
        while len(index) < batch_size:
            if j > 100 * batch_size:
                raise Exception(
                    f"Can't find unique indices after 100 iterations. "
                    f"Current index is {index=}."
                )
            if deterministic:
                np.random.seed(i + j)
            idx = tuple(traj_nums_and_frames[np.random.randint(0, total_len, 1)[0]])
            if idx in index and not replace:
                j += 1
                continue
            data = ds.sel(traj_num=idx[0], frame_num=idx[1])
            # check if any values in CV_names are all nans
            # this can happen for ensembles with different length trajectories
            # we append to `out_`, because if not all nans for this frame
            # we can append `out_` to `out`
            out_ = []
            for o in CV_names:
                v = data[o].values
                if "dist" in o:
                    assert np.all(np.nan_to_num(v, copy=True, nan=1.0) > 0.0), (
                        f"Distances for the selection traj_num={idx[0]} frame={idx[1]} "
                        f"contained a 0. This will result in problems with the cartesian "
                        f"cost."
                    )
                if np.all(np.isnan(v)):
                    # if all nans break
                    out_ = []
                    break
                else:
                    out_.append(v)
            # and continue
            if not out_:
                j += 1
                continue
            # if not, we can append
            index.append(idx)
            for k, o in enumerate(out_):
                out[k].append(o)
            j += 1
        # stack
        out = [np.stack(o, 0) for o in out]
        index = np.array(index)

        # make sparse tensors
        # NOTE: this loop must not reuse `i`. `i` is the batch counter that
        # seeds the deterministic draws; overwriting it here made every
        # batch after the first one start from the same seed.
        for cv_pos, o in enumerate(CV_names):
            if o in sparse:
                if out[cv_pos].ndim > 2:
                    out[cv_pos] = out[cv_pos].reshape(batch_size, -1)
                out[cv_pos] = np_to_sparse_tensor(out[cv_pos])

        i += 1
        # and yield
        if len(CV_names) > 1:
            out = tuple(out)
        else:
            out = out[0]
        if yield_index:
            yield index, out
        else:
            yield out

3092 

def tf_dataset(
    self,
    batch_size: int,
    replace: bool = False,
    sidechains: bool = False,
    reconstruct_sidechains: bool = False,
    CV_names: Optional[list[str]] = None,
    deterministic: bool = False,
    prefetch: bool = True,
    start: int = 1,
) -> tf.data.Dataset:
    """Wraps :meth:`batch_iterator` in a ``tf.data.Dataset``.

    One batch is drawn from :meth:`batch_iterator` up front to derive the
    output signature (shapes, dense vs. sparse) of the dataset.

    Args:
        batch_size (int): The size of the batches.
        replace (bool): Forwarded to :meth:`batch_iterator`.
        sidechains (bool): If ``CV_names`` is None, adds 'side_dihedrals'
            to the default backbone CVs.
        reconstruct_sidechains (bool): If ``CV_names`` is None, selects the
            CV set with 'all_cartesians' and the side angles, dihedrals and
            distances. Takes precedence over ``sidechains``.
        CV_names (Optional[list[str]]): Explicit CV names. Overrides the
            ``sidechains`` and ``reconstruct_sidechains`` flags.
        deterministic (bool): Forwarded to :meth:`batch_iterator`. Also sets
            the ``deterministic`` option of the dataset.
        prefetch (bool): Whether to call ``prefetch(batch_size * 4)`` on
            the dataset.
        start (int): Forwarded to the :meth:`batch_iterator` that feeds the
            dataset.

    Returns:
        tf.data.Dataset: The dataset built with ``from_generator``.

    Raises:
        AssertionError: If one of the requested CVs is not in ``self._CVs``.

    """
    # Third Party Imports
    import tensorflow as tf

    if CV_names is None:
        if reconstruct_sidechains:
            CV_names = [
                "central_angles",
                "central_dihedrals",
                "all_cartesians",
                "central_distances",
                "side_angles",
                "side_dihedrals",
                "side_distances",
            ]
        elif sidechains:
            CV_names = [
                "central_angles",
                "central_dihedrals",
                "central_cartesians",
                "central_distances",
                "side_dihedrals",
            ]
        else:
            CV_names = [
                "central_angles",
                "central_dihedrals",
                "central_cartesians",
                "central_distances",
            ]
    for name in CV_names:
        assert name in self._CVs, f"The CV '{name}' is not loaded in this ensemble."

    def make_iterator():
        # called by tensorflow (without arguments) to get a fresh generator
        return self.batch_iterator(
            batch_size, replace, CV_names, deterministic, start=start
        )

    def spec_of(item):
        # sparse batches need a SparseTensorSpec, everything else is dense
        if isinstance(item, tf.sparse.SparseTensor):
            return tf.SparseTensorSpec(shape=item.dense_shape, dtype="float32")
        return tf.TensorSpec(shape=item.shape, dtype="float32")

    # define the TensorSpecs from one example batch
    sample = next(
        self.batch_iterator(
            batch_size=batch_size,
            replace=replace,
            CV_names=CV_names,
            deterministic=deterministic,
        )
    )
    if isinstance(sample, tuple):
        tensor_specs = tuple(spec_of(item) for item in sample)
    else:
        tensor_specs = tf.TensorSpec(shape=sample.shape, dtype="float32")

    dataset = tf.data.Dataset.from_generator(
        make_iterator, output_signature=tensor_specs
    )
    if prefetch:
        dataset = dataset.prefetch(batch_size * 4)
    if deterministic:
        options = tf.data.Options()
        options.deterministic = True
        dataset = dataset.with_options(options)
        assert dataset.options().deterministic
    return dataset

3168 

def itertrajs(self) -> Iterator[tuple[int, "SingleTraj"]]:
    """Generator over the SingleTraj classes.

    Yields:
        tuple: A tuple containing the following:
            - int: The ``traj_num`` of the trajectory.
            - encodermap.SingleTraj: A SingleTraj object.

    Examples:
        >>> import encodermap as em
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         'https://files.rcsb.org/view/1YUG.pdb',
        ...         'https://files.rcsb.org/view/1YUF.pdb'
        ...     ]
        ... )
        >>> for i, traj in trajs.itertrajs():
        ...     print(traj.basename)
        1YUG
        1YUF

    """
    for single_traj in self:
        number = single_traj.traj_num
        yield number, single_traj

3193 

def iterframes(self) -> Iterator[tuple[int, int, "SingleTraj"]]:
    """Generator over the frames in this instance.

    Delegates to ``iterframes(with_traj_num=True)`` of every trajectory,
    one trajectory after the other.

    Yields:
        tuple: A tuple containing the following:
            - int: The traj_num
            - int: The frame_num
            - encodermap.SingleTraj: A SingleTraj object.

    Examples:

        Import EncoderMap and load an example :obj:`TrajEnsemble`.

        >>> import encodermap as em
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         'https://files.rcsb.org/view/1YUG.pdb',
        ...         'https://files.rcsb.org/view/1YUF.pdb',
        ...     ],
        ... )
        >>> print(trajs.n_frames)
        31

        Subsample every tenth frame.

        >>> trajs = trajs.subsample(10)
        >>> trajs.n_frames
        4

        Call the :meth:`iterframes` method.

        >>> for traj_num, frame_num, frame in trajs.iterframes():
        ...     print(traj_num, frame_num, frame.n_frames)
        0 0 1
        0 10 1
        1 0 1
        1 10 1

    """
    for single_traj in self:
        frames_of_traj = single_traj.iterframes(with_traj_num=True)
        yield from frames_of_traj

3236 

def __copy__(self):
    """Returns a shallow copy: a new instance sharing all attribute values.

    The new instance is created with ``__new__``, so ``__init__`` is not run.

    """
    klass = self.__class__
    duplicate = klass.__new__(klass)
    vars(duplicate).update(vars(self))
    return duplicate

3242 

def __deepcopy__(self, memo):
    """Returns a deep copy: a new instance with deep-copied attribute values.

    The new instance is created with ``__new__`` and registered in ``memo``
    before the attributes are copied, so references back to ``self`` resolve
    to the copy.

    Args:
        memo (dict): The memo dictionary of :func:`copy.deepcopy`.

    """
    # Standard Library Imports
    from copy import deepcopy

    klass = self.__class__
    clone = klass.__new__(klass)
    memo[id(self)] = clone
    for name, value in vars(self).items():
        setattr(clone, name, deepcopy(value, memo))
    return clone

3253 

def __getitem__(self, key: CanBeIndex) -> TrajEnsemble:
    """Indexes the ensemble along the traj axis (and optionally frames).

    Args:
        key (CanBeIndex): The supported keys are:
            - an integer (Python or NumPy): returns ``self.trajs[key]``.
            - a slice: converted to a list of positions using
              ``self.n_trajs`` and indexed again.
            - a flat list or a 1D array: forwarded to
              ``self._return_trajs_by_index``.
            - a list of lists or a 2D array: forwarded to
              ``self._pyemma_indexing_no_tsel``.

    Returns:
        TrajEnsemble: The selected sub-ensemble. For an integer key, the
            element of ``self.trajs`` at that position is returned instead.

    Raises:
        IndexError: If the key is an empty list (this includes slices that
            select nothing), an array with a dimensionality other than 1
            or 2, or of an unsupported type.

    """
    if isinstance(key, (int, np.integer)):
        return self.trajs[key]

    if isinstance(key, slice):
        start, stop, step = key.indices(self.n_trajs)
        return self[list(range(start, stop, step))]

    if isinstance(key, np.ndarray):
        if key.ndim == 1:
            return self._return_trajs_by_index(key)
        if key.ndim == 2:
            return self._pyemma_indexing_no_tsel(key)
        raise IndexError(
            f"Passing a key with more than 2 dimensions makes no sense. "
            f"One dim for trajs, one for frames. Your key has "
            f"{key.ndim} dimensions."
        )

    if isinstance(key, list):
        if not key:
            # Fail explicitly instead of with the accidental
            # "list index out of range" of the key[0] lookup below.
            raise IndexError("Can't index a TrajEnsemble with an empty list.")
        if not isinstance(key[0], list):
            return self._return_trajs_by_index(key)
        if all(isinstance(k, list) for k in key):
            return self[np.asarray(key)]

    raise IndexError(f"Invalid argument for slicing: {key=}")

3281 

def __reversed__(self):
    """Reverse iteration is not supported; always raises ``NotImplementedError``."""
    raise NotImplementedError

3284 

def __eq__(self, other):
    """Compares two ensembles by their traj files, their ids and their CVs.

    Two ensembles are equal if they contain the same number of frames and
    trajectories, if the trajectories pairwise share the same ``traj_file``
    and ``id`` array, and if ``self._CVs.equals(other._CVs)`` holds.

    Args:
        other (Any): The object to compare with.

    Returns:
        bool: Whether the ensembles are equal. ``NotImplemented`` is returned
            for objects that are not instances of this class, so Python
            can fall back to its default comparison instead of failing
            inside this method.

    """
    if not isinstance(other, self.__class__):
        return NotImplemented
    # zip() below would silently drop surplus trajectories, so the number
    # of trajectories is compared up front, too.
    if len(self) != len(other) or len(self.trajs) != len(other.trajs):
        return False

    pairs = list(zip(self.trajs, other.trajs))
    same_strings = all(a.traj_file == b.traj_file for a, b in pairs)
    same_ids = all(np.array_equal(a.id, b.id) for a, b in pairs)
    # the CV comparison is only evaluated when the cheap checks passed
    return bool(same_strings and same_ids and self._CVs.equals(other._CVs))

3310 

def __iter__(self):
    """Returns a fresh iterator over the trajectories in ``self.trajs``.

    Every call returns an independent iterator. Returning ``self`` (with a
    single shared ``_index`` cursor) meant that a nested or interleaved
    iteration over the same ensemble reset or exhausted the outer loop.

    Returns:
        Iterator: An iterator over the elements of ``self.trajs``.

    """
    # Still reset the cursor used by ``__next__``, so code that calls
    # ``next(ensemble)`` after ``iter(ensemble)`` keeps working.
    self._index = 0
    return iter(self.trajs)

3314 

def __next__(self):
    """Returns the trajectory at the ``_index`` cursor and advances the cursor.

    Raises:
        StopIteration: If the cursor has reached ``self.n_trajs``.

    """
    position = self._index
    if position >= self.n_trajs:
        raise StopIteration
    self._index = position + 1
    return self.trajs[position]

3321 

def __radd__(self, y):
    """Reverse addition to make sum() work.

    ``sum()`` starts with the integer 0, so an integer on the left-hand side
    is ignored and ``self`` is returned unchanged. Everything else is
    forwarded to :meth:`__add__`.

    """
    return self if isinstance(y, int) else self.__add__(y)

3327 

def __add__(self, y):
    """Addition of two TrajEnsemble objects returns new TrajEnsemble with
    trajectories joined along the traj axis.

    Args:
        y (TrajEnsemble): The ensemble on the right-hand side.

    Returns:
        TrajEnsemble: A new ensemble holding the trajs of ``self`` followed
            by the trajs of ``y``.

    Raises:
        Exception: If ``y`` does not contain any trajs, or if the traj_nums
            of both ensembles overlap.

    Note:
        If the backends of the two ensembles differ, ``load_trajs()`` is
        called on both operands, i.e. the operands themselves are modified.

    """
    # assert the other contains trajs
    if len(y.trajs) == 0:  # pragma: nocover
        raise Exception(
            f"The `TrajEnsemble` {y} does not contain any trajs and can't "
            f"be used in addition."
        )
    # decide on the new backend
    if self.backend != y.backend:
        print("Mismatch between the backends. Using 'mdtraj'.")
        y.load_trajs()
        self.load_trajs()

    if not set(self.traj_nums).isdisjoint(set(y.traj_nums)):
        raise Exception(
            f"Can't add two `TrajEnsemble` with overlapping traj_nums: "
            f"left side: {self.traj_nums}\n"
            f"right side: {y.traj_nums}"
        )

    # build a common_str_ array with the correct number of entries
    # use this to create a new class
    # if there are no references in self or y. One of them was created from mdtraj.Trajectories
    # NOTE: any() has to see the individual entries. Wrapping the
    # concatenated list in another list only tested whether it was
    # non-empty, so entries that are all empty/None were treated as
    # references.
    if not any(self._top_files + y._top_files):
        new_class = self.__class__(self.trajs + y.trajs, backend=self.backend)
    else:
        common_str_ = (
            get_full_common_str_and_ref(
                self.traj_files, self._top_files, self.common_str
            )[2]
            + get_full_common_str_and_ref(y.traj_files, y._top_files, y.common_str)[
                2
            ]
        )
        # drop duplicates but keep the order
        common_str_ = list(dict.fromkeys(common_str_))
        new_class = self.__class__(
            self.traj_files + y.traj_files,
            self._top_files + y._top_files,
            traj_nums=self.traj_nums + y.traj_nums,
            backend=self.backend,
            common_str=common_str_,
        )
        # put the trajs directly in the new class. This way the frames of the SingleTraj classes are preserved
        new_class.trajs = self.trajs + y.trajs

    return new_class

3378 

3379 def __getattr__(self, attr: str): 

3380 if attr in self.CVs: 

3381 return self.CVs[attr] 

3382 else: 

3383 return self.__getattribute__(attr) 

3384 

def _string_summary(self) -> str:
    """Builds the human-readable description used by ``__str__`` and ``__repr__``.

    The number of frames is only reported if the ``trajectory`` attribute of
    every traj is truthy. The summary also lists the ``common_str`` (if set)
    and name and shape of every entry in ``self.CVs``.

    Returns:
        str: The summary.

    """
    if not all([traj.trajectory for traj in self.trajs]):
        headline = (
            f"encodermap.TrajEnsemble object. Current backend is "
            f"{self.backend}. Containing {self.n_trajs} trajectories."
        )
    else:
        headline = (
            f"encodermap.TrajEnsemble object. Current backend is "
            f"{self.backend}. Containing {self.n_frames} frames and "
            f"{self.n_trajs} trajectories."
        )
    pieces = [headline]
    if self.common_str:
        pieces.append(f" Common str is {self.common_str}.")
    if self.CVs:
        pieces.extend(
            f" CV {key} with shape {value.shape} loaded."
            for key, value in self.CVs.items()
        )
    else:
        pieces.append(" Not containing any CVs.")
    return "".join(pieces)

3405 

def __len__(self) -> int:
    """Returns the number of frames (not trajectories) in this ensemble."""
    frame_count = self.n_frames
    return frame_count

3408 

def __str__(self) -> str:
    """Returns the text produced by :meth:`_string_summary`."""
    summary = self._string_summary()
    return summary

3411 

def __repr__(self) -> str:
    """Returns the summary of :meth:`_string_summary` plus the object's id in hex."""
    summary = self._string_summary()
    address = format(id(self), "02x")
    return f"<{summary} Object at 0x{address}>"