Coverage for encodermap/loading/featurizer.py: 53%

395 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/loading/featurizer.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Classes to be used as custom features with pyemma add_custom_feature 

23 

24ToDo: 

25 * Write Docstrings. 

26 * Write Examples. 

27 * Sidechain angles, distances not working correctly. 

28""" 

29 

30 

31################################################################################ 

32# Imports 

33################################################################################ 

34 

35 

36from __future__ import annotations 

37 

38import numpy as np 

39import pandas as pd 

40 

41from .._optional_imports import _optional_import 

42from ..loading import features 

43from ..misc.misc import FEATURE_NAMES, _validate_uri 

44from ..misc.xarray import get_indices_by_feature_dim, unpack_data_and_feature 

45from ..trajinfo.info_all import TrajEnsemble 

46from ..trajinfo.info_single import SingleTraj 

47 

48################################################################################ 

49# Optional Imports 

50################################################################################ 

51 

52 

53featurizer = _optional_import("pyemma", "coordinates.featurizer") 

54source = _optional_import("pyemma", "coordinates.source") 

55xr = _optional_import("xarray") 

56CHI1_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI1_ATOMS") 

57CHI2_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI2_ATOMS") 

58CHI3_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI3_ATOMS") 

59CHI4_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI4_ATOMS") 

60CHI5_ATOMS = _optional_import("mdtraj", "geometry.dihedral.CHI5_ATOMS") 

61Client = _optional_import("dask", "distributed.Client") 

62 

63 

64################################################################################ 

65# Typing 

66################################################################################ 

67 

68 

69from typing import TYPE_CHECKING 

70 

71if TYPE_CHECKING: 

72 import xarray as xr 

73 from dask.distributed import Client 

74 from mdtraj.geometry.dihedral import ( 

75 CHI1_ATOMS, 

76 CHI2_ATOMS, 

77 CHI3_ATOMS, 

78 CHI4_ATOMS, 

79 CHI5_ATOMS, 

80 ) 

81 from pyemma.coordinates import featurizer, source 

82 from pyemma.coordinates.data._base.datasource import DataSource 

83 from pyemma.coordinates.data.featurization._base import Feature 

84 

85 

86################################################################################ 

87# Globals 

88################################################################################ 

89 

90 

# Public API of this module: only the `Featurizer` factory is exported on
# star-imports.
__all__ = ["Featurizer"]

92 

93 

# Translates the snake_case feature identifiers accepted by
# `PyEMMAFeaturizer.add_list_of_feats` (and used as alignment keys by
# `Topologies`) into the class names of the corresponding features in
# `encodermap.loading.features`.
# NOTE: the misspelled constant name is kept, because other code refers to it.
UNDERSOCRE_MAPPING = dict(
    central_dihedrals="CentralDihedrals",
    all_cartesians="AllCartesians",
    all_distances="AllBondDistances",
    central_cartesians="CentralCartesians",
    central_distances="CentralBondDistances",
    central_angles="CentralAngles",
    side_cartesians="SideChainCartesians",
    side_distances="SideChainBondDistances",
    side_angles="SideChainAngles",
    side_dihedrals="SideChainDihedrals",
)

106 

107 

108############################################################################## 

109# Utils 

110############################################################################## 

111 

112############################################################################## 

113# Classes 

114############################################################################## 

115 

116 

class Featurizer(type):
    """Factory choosing the featurizer implementation for `trajs`.

    Although this class derives from `type`, it is only used as a factory:
    `__new__` returns an instance of a different class, so calling
    `Featurizer(trajs)` never produces a `Featurizer` instance (and
    `isinstance(obj, Featurizer)` is therefore always False).

    Args:
        trajs: The trajectories to featurize. They are handed unchanged to
            the constructor of the chosen class.
        in_memory (bool): When True (the default), a `PyEMMAFeaturizer` is
            returned. Otherwise a `DaskFeaturizer` from the sibling module
            `dask_featurizer` is imported on demand and returned.

    """

    def __new__(cls, trajs, in_memory=True):
        if not in_memory:
            # imported on demand, presumably so that dask is only required
            # for out-of-memory featurization
            from .dask_featurizer import DaskFeaturizer

            return DaskFeaturizer(trajs)
        return PyEMMAFeaturizer(trajs)

126 

127 

def format_output(
    inps: list[DataSource],
    feats: list[Feature],
    trajs: list[TrajEnsemble],
) -> tuple[list[np.ndarray], list[PyEMMAFeaturizer], list[TrajEnsemble]]:
    """Formats the output of multiple topologies.

    Iterates over the features in `feats` and looks for the feature
    with the greatest dimension, i.e. the longest returned describe. This
    feature yields the column names. Columns that a topology does not define
    are filled with np.nan.

    Args:
        inps (list[DataSource]): The list of inputs, that return the values
            of the feats, when `get_output()` is called. One input per
            topology, in the same order as `feats` and `trajs`.
        feats (list[encodermap.loading.Featurizer]): These featurizers collect
            the features and will be used to determine the highest length of
            feats. All active features need to use generic labels.
        trajs (list[encodermap.trajinfo.TrajEnsemble]): List of trajs with
            identical topologies.

    Returns:
        list[tuple]: One `(data, feat_out, traj)` tuple per trajectory.
            `data` is the NaN-padded np.ndarray of that trajectory (one row
            per row of the trajectory's `get_output()` array, one column per
            consolidated label). `feat_out` is an object whose `features`
            attribute lists the consolidated `EmptyFeature` objects (the same
            object for all tuples). `traj` is the trajectory the data belongs
            to. This is the data that `TrajEnsemble` can work with.

    Raises:
        Exception: If the features of the topologies can not be consolidated,
            if a feature does not use generic labels, or if the resulting
            arrays differ in their number of columns.

    """

    class Featurizer_out:
        # Minimal stand-in for a featurizer. Only `features` is set on it.
        pass

    # append to this
    all_out = []

    feat_out = Featurizer_out()
    feat_out.features = []
    # Position of each feature name in `feat_out.features`. Tracked by name,
    # so that featurizers listing their features in different orders still
    # update the correct slot.
    feature_positions = {}
    max_feat_lengths = {}
    labels = {}
    for feat in feats:
        for f in feat.feat.active_features:
            name = f.__class__.__name__
            description = f.describe()

            if name not in max_feat_lengths:
                max_feat_lengths[name] = 0
                feature_positions[name] = len(feat_out.features)
                feat_out.features.append(
                    EmptyFeature(name, len(description), description, f.indexes)
                )

            if name == "SideChainDihedrals":
                # sidechain labels differ between topologies: collect all of them
                labels.setdefault(name, []).extend(description)
            elif max_feat_lengths[name] < len(description):
                # other features: the longest description defines the columns
                max_feat_lengths[name] = len(description)
                labels[name] = description
                feat_out.features[feature_positions[name]] = EmptyFeature(
                    name, len(description), description, f.indexes
                )

    # rejig the sidechain labels
    side_key = "SideChainDihedrals"
    if side_key in labels:
        # Deduplicate and sort by residue, then by chi number. The key assumes
        # generic labels ending in a 3-character residue number with the chi
        # number at position 13 (e.g. 'SIDECHDIH CHI1' + residue) - TODO confirm.
        unique_labels = np.unique(labels[side_key])
        labels[side_key] = sorted(
            unique_labels, key=lambda x: (int(x[-3:]), int(x[13]))
        )
        feat_out.features[feature_positions[side_key]] = EmptyFeature(
            side_key,
            len(labels[side_key]),
            labels[side_key],
            None,
        )

    # `labels` and `feat_out.features` were filled in the same order
    for (k, v), f in zip(labels.items(), feat_out.features):
        if not len(v) == len(f.describe()) == f._dim:
            raise Exception(
                f"Could not consolidate the features of the {f.name} "
                f"feature. The `labels` dict, which dictates the size "
                f"of the resulting array with np.nan's defines a shape "
                f"of {len(v)}, but the feature defines a shape of {len(f.describe())} "
                f"(or `f._dim = {f._dim}`). The labels dict gives these labels:\n\n{v}"
                f"\n\n, the feature labels gives these labels:\n\n{f.describe()}."
            )

    # Flatten the labels. These will be the columns of a pandas DataFrame,
    # which starts out full of np.nan and is then partially filled with the
    # values of `inp.get_output()`, matched via the feature labels.
    flat_labels = [item for sublist in labels.values() for item in sublist]
    if not len(flat_labels) == sum(f._dim for f in feat_out.features):
        raise Exception(
            f"The length of the generic CV labels ({len(flat_labels)} "
            f"does not match the length of the labels of the generic features "
            f"({[f._dim for f in feat_out.features]})."
        )

    # iterate over the sorted trajs, inps, and feats
    for inp, feat, sub_trajs in zip(inps, feats, trajs):
        # make a flat list for this specific feature space
        describe_this_feature = []
        for f in feat.feat.active_features:
            # make sure generic labels are used
            if f.describe.__func__.__name__ != "generic_describe":
                raise Exception(
                    f"It seems like this feature: {f.__class__} does not return generic "
                    f"feature names but topology-specifc ones (generic: 'SIDECHDIH CHI1 1', "
                    f"topology specific: 'SIDECHDIH CHI1 ASP1'). Normally, encodermap's "
                    f"features can be instantiated with a `generic_labels=True` flag to "
                    f"overwrite the features `describe()` method with a `generic_describe()` "
                    f"method. This changes the `.__func__.__name__` of the `describe()` method "
                    f"to 'generic_describe'. However the func name for this feature is "
                    f"{f.describe.__func__.__name__}."
                )
            describe_this_feature.extend(f.describe())
        # `get_output()` yields one array per trajectory of this topology.
        # Every trajectory gets its own DataFrame with only its own frames.
        for data, traj in zip(inp.get_output(), sub_trajs):
            df = pd.DataFrame(np.nan, index=range(len(data)), columns=flat_labels)
            df = df.assign(**{k: v for k, v in zip(describe_this_feature, data.T)})
            all_out.append((df.to_numpy(), feat_out, traj))

    # make sure the shapes of all arrays match
    shapes = {o[0].shape[1] for o in all_out}
    if len(shapes) > 1:
        raise Exception(
            f"Alignment was not possible. Some values exhibit different shapes: "
            f"{list(shapes)}. All shapes:\n\n{[o[0].shape[1] for o in all_out]}"
        )
    return all_out

259 

260 

class PyEMMAFeaturizer:
    """In-memory featurizer built on PyEMMA's `featurizer` and `source`.

    The instance works in one of two modes, stored in `self.mode` and set by
    the `trajs` setter:

    * `"single_top"`: All trajectories share one topology. `self.feat` is a
      single PyEMMA featurizer and `self.inp` a single PyEMMA source.
    * `"multiple_top"`: The ensemble contains several topologies. The
      trajectories are grouped by topology into `self.sorted_trajs`,
      `self.feat` is a list with one `Featurizer` per group and `self.inp`
      is a list with one PyEMMA source per group. PyEMMA's `add_x` methods
      can not be used in this mode; use `add_list_of_feats` instead.

    Args:
        trajs (Union[SingleTraj, TrajEnsemble]): The trajectories to featurize.

    """

    # PyEMMA `add_x` methods whose docstrings are copied onto the wrapper
    # methods of this class.
    _PYEMMA_ADD_METHODS = (
        "add_all",
        "add_selection",
        "add_distances",
        "add_distances_ca",
        "add_inverse_distances",
        "add_contacts",
        "add_residue_mindist",
        "add_group_COM",
        "add_residue_COM",
        "add_group_mindist",
        "add_angles",
        "add_dihedrals",
        "add_backbone_torsions",
        "add_chi1_torsions",
        "add_sidechain_torsions",
        "add_minrmsd_to_ref",
    )

    def __init__(self, trajs):
        # the `trajs` setter creates `top`, `feat`, `inp` and `mode`
        self.trajs = trajs
        self._copy_docstrings_from_pyemma()

    def _copy_docstrings_from_pyemma(self):
        """Copies the docstrings of PyEMMA's `add_x` methods onto this class.

        The docstrings are written to the underlying functions (`__func__`),
        so they apply to the class and thereby to all instances.

        """
        feat_ = self.feat[0] if isinstance(self.feat, list) else self.feat
        for method_name in self._PYEMMA_ADD_METHODS:
            getattr(self, method_name).__func__.__doc__ = getattr(
                feat_, method_name
            ).__doc__

    def _raise_if_multiple_top(self):
        """Raises an Exception when PyEMMA's `add_x` methods can't be used."""
        if self.mode == "multiple_top":
            raise Exception(
                "Using PyEMMA's `add_x` functions is not possible when TrajEnsemble contains multiple topologies."
            )

    def get_output(self) -> xr.Dataset:
        """Computes the added features and returns them as an xarray Dataset.

        In `single_top` mode, one dataset per trajectory is built with
        `unpack_data_and_feature`. A single dataset is returned as is;
        several are combined along the `traj_num` dimension. In
        `multiple_top` mode, the outputs of the per-topology featurizers are
        first NaN-padded by `format_output`, then concatenated along
        `traj_num`.

        Returns:
            xr.Dataset: The featurized data. `None` is returned (and a message
                printed) if no features have been added. In `multiple_top`
                mode only the first featurizer is checked for this.

        Raises:
            Exception: If combining the single-topology datasets yields an
                unexpected number of trajectories.

        """
        if self.mode == "single_top" and len(self.feat.active_features) == 0:
            print("No features loaded. No output will be returned")
            return

        if self.mode == "multiple_top" and len(self.feat[0].features) == 0:
            print("No features loaded. No output will be returned")
            return

        if self.mode == "single_top":
            datasets = [
                unpack_data_and_feature(self, traj, data)
                for traj, data in zip(self.trajs, self.inp.get_output())
            ]

            if len(datasets) == 1:
                assert datasets[0].coords["traj_num"] == np.array(
                    [self.trajs[0].traj_num]
                )
                return datasets[0]

            out = xr.combine_nested(datasets, concat_dim="traj_num")
            n_combined = len(out.coords["traj_num"])
            if n_combined != len(self.trajs) and n_combined != self.trajs.n_trajs:
                raise Exception(
                    f"The combine_nested xarray method returned "
                    f"a bad dataset, which has {n_combined} "
                    f"trajectories, but the featurizer has {self.trajs.n_trajs} "
                    f"trajectories."
                )
        else:
            formatted = format_output(self.inp, self.feat, self.sorted_trajs)
            datasets = [
                unpack_data_and_feature(feat_out, traj, data)
                for data, feat_out, traj in formatted
            ]
            out = xr.concat(datasets, dim="traj_num")

        return out

    def add_list_of_feats(self, which="all"):
        """Adds features to the Featurizer to be loaded either in-memory or out-of-memory.

        `which` can be either 'all' or a list of the following strings. The
        snake_case keys of `UNDERSOCRE_MAPPING` (e.g. 'central_dihedrals') are
        accepted as well. 'all' adds 'CentralCartesians',
        'CentralBondDistances', 'CentralAngles', 'CentralDihedrals' and
        'SideChainDihedrals'.

            * 'AllCartesians': Cartesian coordinates of all atoms with shape (n_frames, n_atoms, 3).
            * 'AllBondDistances': Bond distances of all bonds recognized by mdtraj. Use top = md.Topology.from_openmm()
                if mdtraj does not recognize all bonds.
            * 'CentralCartesians': Cartesians of the N, C, CA atoms in the backbone with shape (n_frames, n_residues * 3, 3).
            * 'CentralBondDistances': The bond distances of the N, C, CA bonds with shape (n_frames, n_residues * 3 - 1).
            * 'CentralAngles': The angles between the backbone bonds with shape (n_frames, n_residues * 3 - 2).
            * 'CentralDihedrals': The dihedrals between the backbone atoms (omega, phi, psi). With shape (n_frames,
                n_residues * 3 - 3).
            * 'SideChainCartesians': Cartesians of the sidechain-atoms. Starting with CB, CG, ...
            * 'SideChainBondDistances': Bond distances between the sidechain atoms. starting with the CA-CG bond.
            * 'SideChainAngles': Angles between sidechain atoms. Starting with the C-CA-CB angle.
            * 'SideChainDihedrals': Dihedrals of the sidechains (chi1, chi2, chi3).

        In `multiple_top` mode, the features are created with
        `generic_labels=True` for every topology.

        Args:
            which (Union[str, list], optional). Either add 'all' features or a list of features. See Above for
                possible features. Defaults to 'all'.

        """
        if isinstance(which, str) and which == "all":
            which = [
                "CentralCartesians",
                "CentralBondDistances",
                "CentralAngles",
                "CentralDihedrals",
                "SideChainDihedrals",
            ]
        if not isinstance(which, list):
            which = [which]

        if self.mode == "single_top":
            for cf in which:
                cf = UNDERSOCRE_MAPPING.get(cf, cf)
                self.feat.add_custom_feature(getattr(features, cf)(self.top))
        else:
            for cf in which:
                cf = UNDERSOCRE_MAPPING.get(cf, cf)
                for top, feat in zip(self.top, self.feat):
                    feat.add_custom_feature(
                        getattr(features, cf)(top, generic_labels=True)
                    )

    # The following wrappers forward to PyEMMA's featurizer. Their docstrings
    # are replaced with PyEMMA's in `_copy_docstrings_from_pyemma`.

    def add_all(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_all(*args, **kwargs)

    def add_selection(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_selection(*args, **kwargs)

    def add_distances(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_distances(*args, **kwargs)

    def add_distances_ca(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_distances_ca(*args, **kwargs)

    def add_inverse_distances(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_inverse_distances(*args, **kwargs)

    def add_contacts(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_contacts(*args, **kwargs)

    def add_residue_mindist(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_residue_mindist(*args, **kwargs)

    def add_group_COM(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_group_COM(*args, **kwargs)

    def add_residue_COM(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_residue_COM(*args, **kwargs)

    def add_group_mindist(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_group_mindist(*args, **kwargs)

    def add_angles(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_angles(*args, **kwargs)

    def add_dihedrals(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_dihedrals(*args, **kwargs)

    def add_backbone_torsions(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_backbone_torsions(*args, **kwargs)

    def add_chi1_torsions(self, *args, **kwargs):
        self._raise_if_multiple_top()
        # implemented via the sidechain torsions, restricted to chi1
        self.feat.add_sidechain_torsions(which=["chi1"], *args, **kwargs)

    def add_sidechain_torsions(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_sidechain_torsions(*args, **kwargs)

    def add_minrmsd_to_ref(self, *args, **kwargs):
        self._raise_if_multiple_top()
        self.feat.add_minrmsd_to_ref(*args, **kwargs)

    def add_custom_feature(self, feature):
        """Adds an already instantiated feature to the PyEMMA featurizer.

        Args:
            feature: The feature to add. It is bound to a single topology.

        Raises:
            Exception: In `multiple_top` mode, where one feature can't serve
                several topologies.

        """
        if self.mode == "multiple_top":
            raise Exception(
                "Adding a custom feature is not possible when TrajEnsemble contains "
                "multiple topologies. Use `add_list_of_feats` instead."
            )
        self.feat.add_custom_feature(feature)

    @property
    def features(self):
        """list: The active features. In `multiple_top` mode, one list per topology."""
        if self.mode == "single_top":
            return self.feat.active_features
        return [f.features for f in self.feat]

    @property
    def sorted_info_single(self):
        """list: All trajectories, ordered by topology group (`multiple_top` only)."""
        if self.mode == "single_top":
            raise Exception(
                "Attribute is only accessible, when working with mutliple topologies."
            )
        return [traj for info_all in self.sorted_trajs for traj in info_all]

    @property
    def sorted_featurizers(self):
        """list: The featurizer of every trajectory, aligned with `sorted_info_single` (`multiple_top` only)."""
        if self.mode == "single_top":
            raise Exception(
                "Attribute is only accessible, when working with mutliple topologies."
            )
        out = []
        for feat, info_all in zip(self.feat, self.sorted_trajs):
            out.extend([feat] * info_all.n_trajs)
        return out

    def describe(self):
        """Returns the feature labels; one list per topology in `multiple_top` mode."""
        if self.mode == "single_top":
            return self.feat.describe()
        return [f.describe() for f in self.feat]

    @property
    def trajs(self):
        """The featurized trajectories. A `SingleTraj` is stored as an ensemble."""
        return self._trajs

    @trajs.setter
    def trajs(self, trajs):
        if isinstance(trajs, SingleTraj) or trajs.__class__.__name__ == "SingleTraj":
            self._trajs = trajs._gen_ensemble()
            self.top = trajs.top
            self.feat = featurizer(self.top)
            if _validate_uri(trajs.traj_file):
                # the trajectory is not a local file: feed the coordinates
                self.inp = source([trajs.xyz], features=self.feat)
            else:
                try:
                    self.inp = source([trajs.traj_file], features=self.feat)
                except Exception:
                    # print the offending files before re-raising
                    print(trajs.traj_file)
                    print(trajs.top_file)
                    raise
            self.mode = "single_top"
        elif (
            isinstance(trajs, TrajEnsemble)
            or trajs.__class__.__name__ == "TrajEnsemble"
        ):
            if len(trajs.top) > 1:
                self._trajs = trajs
                self.top = trajs.top
                # one sub-ensemble (and featurizer/source) per topology
                self.sorted_trajs = [
                    TrajEnsemble([t for t in trajs if t.top == top])
                    for top in trajs.top
                ]
                self.feat = [Featurizer(t) for t in self.sorted_trajs]
                self.inp = [
                    source([t.traj_file for t in t_subset], features=feat.feat)
                    for t_subset, feat in zip(self.sorted_trajs, self.feat)
                ]
                self.mode = "multiple_top"
            else:
                self._trajs = trajs
                self.top = trajs.top[0]
                self.feat = featurizer(self.top)
                if all(_validate_uri(traj.traj_file) for traj in trajs):
                    self.inp = source(trajs.xtc, features=self.feat)
                else:
                    self.inp = source(
                        [traj.traj_file for traj in trajs], features=self.feat
                    )
                self.mode = "single_top"
        else:
            raise TypeError(
                f"trajs must be {SingleTraj} or {TrajEnsemble}, you provided {trajs.__class__}"
            )

    def __len__(self):
        # single_top: number of active features;
        # multiple_top: number of per-topology featurizers
        if self.mode == "single_top":
            return len(self.feat.active_features)
        return len(self.feat)

    def __str__(self):
        if self.mode == "single_top":
            return self.feat.__str__()
        return ", ".join(str(f) for f in self.feat)

    def __repr__(self):
        if self.mode == "single_top":
            return self.feat.__repr__()
        return ", ".join(repr(f) for f in self.feat)

603 

604 

class EmptyFeature:
    """Plain attribute container that stands in for a PyEMMA feature.

    Used in multiple_top mode, where the consolidated (possibly NaN-padded)
    features are described by their name, dimension, labels and indexes
    instead of by a real feature object. The attributes are read by
    encodermap.xarray.

    """

    def __init__(self, name, _dim, description, indexes):
        """Stores the given values as attributes of the same names.

        Args:
            name (str): The name of the feature.
            _dim (int): The number of values per frame of the feature,
                i.e. the second entry of shape=(n_frames, feature).
            description (list of str): One label for every feature value.
            indexes: The atom indexes of the feature; stored without
                validation (None is used for consolidated sidechain dihedrals).

        """
        self.name, self._dim = name, _dim
        self.description, self.indexes = description, indexes

    def describe(self):
        """Returns the stored labels, mirroring PyEMMA's `describe()`."""
        return self.description

630 

631 

class Topologies:
    """Aligns the features of trajectories with different topologies.

    For every requested alignment, `self.alignments[alignment]` maps
    `"residue_{i}"` to a boolean array with one row per aligned feature
    value and one column per topology, telling which topology possesses that
    value. Only sidechain dihedrals carry alignment information so far.
    Central (backbone) features need none, because all topologies must have
    the same number of residues.

    Args:
        tops (list): The topologies to compare. Presumably mdtraj topologies
            (`n_residues`, `residues` and `atoms` are used) - TODO confirm.
        alignments (Optional[list[str]]): Keys of `UNDERSOCRE_MAPPING` naming
            the features to align. Defaults to the sidechain dihedrals and the
            four central features.

    Raises:
        Exception: If an alignment string is not allowed, or the topologies
            contain different numbers of residues.
        NotImplementedError: If sidechain cartesians, distances, or angles
            should be aligned.

    """

    def __init__(self, tops, alignments=None):
        self.tops = tops
        if alignments is None:
            alignments = [
                "side_dihedrals",
                "central_cartesians",
                "central_distances",
                "central_angles",
                "central_dihedrals",
            ]
        # Validate before doing any work. Central features are allowed (the
        # default contains them): with equal residue counts their size is the
        # same for every topology, so they simply need no padding.
        allowed_strings = [
            k for k in UNDERSOCRE_MAPPING if "side" in k or "central" in k
        ]
        if not all(a in allowed_strings for a in alignments):
            raise Exception(
                f"Invalid alignment string in `alignments`. Allowed strings are {allowed_strings}"
            )
        self.alignments = {k: {} for k in alignments}
        self.compare_tops()

    def compare_tops(self):
        """Determines, residue by residue, which chi dihedrals each topology has.

        Fills `self.alignments["side_dihedrals"]` (if requested) with arrays of
        shape (5, n_topologies), rows being chi1 to chi5, and finally drops the
        rows that no topology possesses.

        Raises:
            Exception: If the topologies have different numbers of residues.
            NotImplementedError: If sidechain cartesians, distances, or angles
                should be aligned.

        """
        n_residues = self.tops[0].n_residues
        if not all(t.n_residues == n_residues for t in self.tops):
            raise Exception(
                "Using Different Topologies currenlty only works if all contain the same number of residues."
            )

        # Only sidechain dihedrals can be aligned yet: fail before any work.
        for alignment, kind in (
            ("side_cartesians", "Cartesians"),
            ("side_distances", "Distances"),
            ("side_angles", "Angles"),
        ):
            if alignment in self.alignments:
                raise NotImplementedError(
                    f"{kind} between different topologies can currently not be aligned."
                )

        if "side_dihedrals" in self.alignments:
            chi_definitions = (
                CHI1_ATOMS,
                CHI2_ATOMS,
                CHI3_ATOMS,
                CHI4_ATOMS,
                CHI5_ATOMS,
            )
            # `residues` is advanced with next(), i.e. it must be an iterator
            # (as mdtraj's `Topology.residues` is) - TODO confirm for others.
            generators = [t.residues for t in self.tops]
            for i in range(n_residues):
                residues = [next(g) for g in generators]
                atom_names = [{a.name for a in r.atoms} for r in residues]
                # chi[k, j]: residue i of topology j contains all atoms of
                # one of the definitions of chi(k + 1)
                chi = np.array(
                    [
                        [
                            any(set(atoms).issubset(names) for atoms in chi_atoms)
                            for names in atom_names
                        ]
                        for chi_atoms in chi_definitions
                    ]
                )
                self.alignments["side_dihedrals"][f"residue_{i}"] = chi

        self.drop_double_false()

    def drop_double_false(self):
        """Drops features that None of the topologies have.

        For example: Asp and Glu. Asp has a chi1 and chi2 torsion. Glu has chi1, chi2 and chi3. Both
        don't have chi4 or chi5. In self.compare_tops these dihedrals are still considered. In this
        method they will be removed.

        """
        for alignment, value in self.alignments.items():
            for residue, array in value.items():
                # keep the rows where at least one topology is True
                where = np.where(np.any(array, axis=1))[0]
                self.alignments[alignment][residue] = array[where]

    def get_max_length(self, alignment):
        """Number of columns a feature needs after alignment.

        After `drop_double_false`, every row of the stacked alignment arrays
        belongs to at least one topology, so the aligned feature needs one
        column per row, i.e. the union over all topologies. (The largest
        per-topology count is too small whenever no single topology has all
        rows, which would break the boolean masks in `format_output`.)

        Args:
            alignment (str): The key for `self.alignments`.

        Returns:
            int: The number of aligned columns; 0 if there is no alignment
                information for `alignment`.

        """
        alignment_dict = self.alignments[alignment]
        if not alignment_dict:
            return 0
        return np.vstack(list(alignment_dict.values())).shape[0]

    def format_output(self, inputs, feats, sorted_trajs):
        """Formats the output of an em.Featurizer object using the alignment info.

        Values of aligned features are written into NaN-filled arrays of the
        aligned width. The affected active features are replaced by
        `EmptyFeature` objects with NaN-padded labels and indexes.

        Args:
            inputs (list): List of pyemma.coordinates.data.feature_reader.FeatureReader objects.
            feats (list): List of encodermap.Featurizer objects.
            sorted_trajs (list): List of em.TrajEnsemble objects sorted in the same way as `self.tops`.

        Returns:
            list: One `[values, feat, trajs]` entry per topology, where `values`
                is the horizontally stacked array of all features.

        """
        out = []
        for i, (inp, _top, feat, trajs) in enumerate(
            zip(inputs, self.tops, feats, sorted_trajs)
        ):
            value_dict = {}
            for traj_ind, (data, traj) in enumerate(zip(inp.get_output(), trajs)):
                if any(isinstance(o, EmptyFeature) for o in feat.feat.active_features):
                    from ..misc.xarray import add_one_by_one

                    # already padded: count only the non-NaN labels
                    ffunc = lambda x: "NaN" not in x
                    indices = [0] + add_one_by_one(
                        [len(list(filter(ffunc, f.describe()))) for f in feat.features]
                    )
                else:
                    indices = get_indices_by_feature_dim(feat, traj, data.shape)

                # divide the values returned by PyEMMA
                for f, ind in zip(feat.features, indices):
                    try:
                        name = FEATURE_NAMES[f.name]
                    except KeyError:
                        name = f.__class__.__name__
                        f.name = name
                    except AttributeError:
                        name = f.__class__.__name__
                        f.name = name
                    if traj_ind == 0:
                        value_dict[name] = []
                    value_dict[name].append(data[:, ind])

            # stack along the frame axis, just like pyemma would
            value_dict = {k: np.vstack(v) for k, v in value_dict.items()}

            # put nans in all features specified by alignment
            for alignment, alignment_dict in self.alignments.items():
                # nothing to pad: feature not computed, or no alignment info
                if alignment not in value_dict or not alignment_dict:
                    continue
                max_length = self.get_max_length(alignment)
                new_values = np.full(
                    shape=(value_dict[alignment].shape[0], max_length),
                    fill_value=np.nan,
                )
                # boolean mask: the aligned columns topology `i` possesses
                where = np.vstack(list(alignment_dict.values()))[:, i]
                new_values[:, where] = value_dict[alignment]
                value_dict[alignment] = new_values

                # find the index of the feature in feat.feat.active_features
                names = np.array(
                    [f.__class__.__name__ for f in feat.feat.active_features]
                )
                index = np.where([n in FEATURE_NAMES for n in names])[0]
                index = index[
                    np.where([FEATURE_NAMES[n] == alignment for n in names[index]])[0]
                ]

                # get the old description and change it around
                assert len(index) == 1
                index = index[0]
                if not isinstance(feat.feat.active_features[index], EmptyFeature):
                    old_desc = list(feat.feat.active_features[index].describe())
                    # object dtype, so labels longer than the NaN placeholder
                    # are not truncated by a fixed-width unicode dtype
                    new_desc = np.array(
                        [
                            f"NaN due to ensemble with other topologies {j}"
                            for j in range(max_length)
                        ],
                        dtype=object,
                    )
                    new_desc[where] = old_desc
                    new_desc = [str(d) for d in new_desc]

                    # get the old indexes and add the NaNs
                    old_indexes = feat.feat.active_features[index].indexes
                    new_indexes = np.full(
                        shape=(max_length, old_indexes.shape[1]), fill_value=np.nan
                    )
                    new_indexes[where] = old_indexes

                    # create empty feature
                    feat.feat.active_features[index] = EmptyFeature(
                        alignment, max_length, new_desc, new_indexes
                    )

            new_values = np.hstack(list(value_dict.values()))
            out.append([new_values, feat, trajs])
        return out

    def format_output_generator(self, inputs, feats, sorted_trajs):
        raise NotImplementedError("Will be implemented once tfRecords are implemented.")

    def __iter__(self):
        # iterating yields the topologies in order
        self._index = 0
        return self

    def __next__(self):
        if self._index >= len(self.tops):
            raise StopIteration
        self._index += 1
        return self.tops[self._index - 1]