Coverage for encodermap/loading/features.py: 67%

467 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/loading/features.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Patricia Schwarz 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

"""Classes to be used as custom features with pyemma's ``add_custom_feature``.

Besides the feature classes, this module contains helpers that add the
bonds of acetylated lysine (residue name ``KAC``) to an mdtraj topology.

ToDo:
    * Write tests.
    * Put the describe_last_feats function into utils.
    * Add Nan feature.
    * Write Examples.

"""

31 

32############################################################################## 

33# Imports 

34############################################################################## 

35 

36 

37from __future__ import annotations 

38 

39import copy 

40import itertools 

41from typing import TYPE_CHECKING, Callable 

42 

43import numpy as np 

44 

45import encodermap 

46 

47from .._optional_imports import _optional_import 

48 

49############################################################################## 

50# Typing 

51############################################################################## 

52 

53 

54if TYPE_CHECKING: 54 ↛ 55line 54 didn't jump to line 55, because the condition on line 54 was never true

55 from mdtraj.core.residue_names import _AMINO_ACID_CODES 

56 from pyemma.coordinates.data.featurization.angles import ( 

57 AngleFeature, 

58 DihedralFeature, 

59 ) 

60 from pyemma.coordinates.data.featurization.distances import DistanceFeature 

61 from pyemma.coordinates.data.featurization.misc import ( 

62 CustomFeature, 

63 SelectionFeature, 

64 ) 

65 

66 from encodermap._typing import AnyFeature 

67 

68 

69############################################################################## 

70# Optional Imports 

71############################################################################## 

72 

73 

74CustomFeature: CustomFeature = _optional_import( 

75 "pyemma", "coordinates.data.featurization.misc.CustomFeature" 

76) 

77SelectionFeature: SelectionFeature = _optional_import( 

78 "pyemma", "coordinates.data.featurization.misc.SelectionFeature" 

79) 

80DihedralFeature: DihedralFeature = _optional_import( 

81 "pyemma", "coordinates.data.featurization.angles.DihedralFeature" 

82) 

83AngleFeature: AngleFeature = _optional_import( 

84 "pyemma", "coordinates.data.featurization.angles.AngleFeature" 

85) 

86DistanceFeature: DistanceFeature = _optional_import( 

87 "pyemma", "coordinates.data.featurization.distances.DistanceFeature" 

88) 

89_AMINO_ACID_CODES: dict = _optional_import( 

90 "mdtraj", "core.residue_names._AMINO_ACID_CODES" 

91) 

92indices_phi: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_phi") 

93indices_psi: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_psi") 

94indices_omega: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_omega") 

95indices_chi1: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_chi1") 

96indices_chi2: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_chi2") 

97indices_chi3: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_chi3") 

98indices_chi4: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_chi4") 

99indices_chi5: Callable = _optional_import("mdtraj", "geometry.dihedral.indices_chi5") 

100 

101 

102############################################################################## 

103# Globals 

104############################################################################## 

105 

# Feature classes exported by ``from encodermap.loading.features import *``.
__all__ = [
    "AllCartesians",
    "AllBondDistances",
    "CentralCartesians",
    "CentralBondDistances",
    "CentralAngles",
    "CentralDihedrals",
    "SideChainCartesians",
    "SideChainBondDistances",
    "SideChainAngles",
    "SideChainDihedrals",
]

118 

119############################################################################## 

120# Functions 

121############################################################################## 

122 

123 

def describe_last_feats(feat: "AnyFeature", n: int = 5) -> None:
    """Print the descriptions of the last `n` features of a featurizer.

    Args:
        feat (encodermap.Featurizer): An instance of a featurizer. Only its
            `describe()` method is used; it must return a sliceable sequence
            of labels.
        n (int): The number of trailing features to describe. Defaults to 5.
            Values of 0 or less print nothing.

    """
    if n <= 0:
        # ``labels[-0:]`` is the *whole* list, so n == 0 needs an explicit guard.
        return
    for lbl in feat.describe()[-n:]:
        print(lbl)

134 

135 

def add_KAC_backbone_bonds(top):
    """Add the backbone bonds of an acetylated lysine (KAC) to a topology.

    The KAC residue is the residue of the first atom matched by
    ``"resname KAC"``. Three bonds are added, in this order:
    C of the preceding residue to N(KAC), N(KAC) to CA(KAC), and
    CA(KAC) to C(KAC).

    Args:
        top (mdtraj.Topology): The topology to be extended. Bonds are added
            to this object directly.

    Returns:
        mdtraj.Topology: The same topology object, now with the added bonds.

    Note:
        The bonds are currently not at the correct index, i.e. they are
        at the very end of top.bonds and not at the correct position.

    """
    kac = top.atom(top.select("resname KAC")[0]).residue.index

    # Each entry is ((atom name, residue index), (atom name, residue index)).
    pairs = (
        (("C", kac - 1), ("N", kac)),  # peptide bond into the KAC residue
        (("N", kac), ("CA", kac)),
        (("CA", kac), ("C", kac)),
    )
    for partners in pairs:
        atoms = [
            top.atom(top.select(f"name {name} and resid {resid}")[0])
            for name, resid in partners
        ]
        top.add_bond(*atoms)
    return top

176 

177 

def add_KAC_sidechain_bonds(top):
    """Add the side chain bonds of an acetylated lysine (KAC) to a topology.

    The KAC residue is the residue of the first atom matched by
    ``"resname KAC"``. The following bonds, all within that residue, are
    added in order: CA-CB, CB-CG, CG-CD, CD-CE, CE-NZ.

    Example atom layout of a KAC residue (name, atom index)::

        KAC11-N 102
        KAC11-H 103
        KAC11-CA 104
        KAC11-CB 105
        KAC11-CG 106
        KAC11-CD 107
        KAC11-CE 108
        KAC11-NZ 109
        KAC11-HZ 110
        KAC11-CH 111
        KAC11-OI2 112
        KAC11-CI1 113
        KAC11-C 114
        KAC11-O 115

    Args:
        top (mdtraj.Topology): The topology to be extended. Bonds are added
            to this object directly.

    Returns:
        mdtraj.Topology: The same topology object, now with the added bonds.

    Raises:
        IndexError: If there is no KAC residue, or if one of the atoms named
            above is missing from it.

    Note:
        The bonds are currently not at the correct index, i.e. they are
        at the very end of top.bonds and not at the correct position.
        Bonds of the acetyl group (NZ-CH, ...) are not added.

    """
    kac = top.atom(top.select("resname KAC")[0]).residue.index

    # Consecutive atoms of the lysine side chain, starting at CA of the
    # *same* residue. CD lies between CG and CE and must not be skipped.
    chain = ["CA", "CB", "CG", "CD", "CE", "NZ"]
    for first, second in zip(chain[:-1], chain[1:]):
        atoms = [
            top.atom(top.select(f"name {name} and resid {kac}")[0])
            for name in (first, second)
        ]
        top.add_bond(*atoms)
    return top

241 

242 

243############################################################################## 

244# Classes 

245############################################################################## 

246 

247 

class CentralDihedrals(DihedralFeature):
    """Feature that collects all dihedrals in the backbone of a topology.

    The dihedrals are grouped as (psi, omega, phi), giving the order
    ``(psi_1, omega_1, phi_1, ..., psi_n, omega_n, phi_n)``, or
    ``(psi_1, phi_1, ..., psi_n, phi_n)`` when ``omega=False``. The groups are
    formed with ``zip``, so surplus angles of a longer list are dropped.

    Attributes:
        top (mdtraj.Topology): Topology of this feature.
        indexes (np.ndarray): A (n_angles, 4) shaped int32 array with the atom
            indices of the dihedral angles.

    """

    __serialize_version = 0
    __serialize_fields = ("selstr", "_phi_inds", "_psi_inds", "_omega_inds")

    def __init__(
        self,
        topology,
        selstr=None,
        deg=False,
        cossin=False,
        periodic=True,
        omega=True,
        generic_labels=False,
    ):
        """Instantiate this feature class.

        Args:
            topology (mdtraj.Topology): A topology to build features from.
            selstr (Optional[str]): A string that limits the selection of
                dihedral angles. Only dihedral angles whose second atom is
                part of the selection are considered. This selection string
                follows MDTraj's atom selection language:
                https://mdtraj.org/1.9.3/atom_selection.html. Can also be
                None (or empty), in which case all backbone dihedrals are
                considered. Defaults to None.
            deg (bool): Whether to return the result in degrees (`deg=True`)
                or in radians (`deg=False`). Defaults to radians.
            cossin (bool): If True, each angle is returned as a pair of its
                cosine and sine. This is useful if you calculate the mean
                (e.g. TICA/PCA, clustering) in that space. Defaults to False.
            periodic (bool): Whether to recognize periodic boundary conditions
                and work under the minimum image convention. Defaults to True.
            omega (bool): Whether to include the omega angles. Defaults to True.
            generic_labels (bool): If True, `describe` is replaced by
                `generic_describe`, whose labels contain running numbers
                instead of residue names. Defaults to False.

        """
        self.top = topology
        self.selstr = selstr
        self.omega = omega

        self._psi_inds = self._restrict_to_selection(indices_psi(self.top), selstr)
        if self.omega:
            self._omega_inds = self._restrict_to_selection(
                indices_omega(self.top), selstr
            )
        self._phi_inds = self._restrict_to_selection(indices_phi(self.top), selstr)

        # Interleave the angle types (psi, omega, phi, psi, omega, phi, ...)
        # and flatten to one row of four atom indices per dihedral.
        index_groups = [inds for _, inds in self._angle_groups()]
        dih_indexes = np.array(list(zip(*index_groups))).reshape(-1, 4)

        # set generic_labels for xarray
        if generic_labels:
            self.describe = self.generic_describe

        super(CentralDihedrals, self).__init__(
            self.top, dih_indexes, deg=deg, cossin=cossin, periodic=periodic
        )

    def _restrict_to_selection(self, indices, selstr):
        """Keep the rows of `indices` whose second atom is in `selstr`.

        Args:
            indices (np.ndarray): A (n, 4) array of dihedral atom indices.
            selstr (Optional[str]): MDTraj selection string. Falsy values
                disable the filtering.

        Returns:
            np.ndarray: The (possibly filtered) indices.

        """
        if not selstr:
            return indices
        mask = np.isin(indices[:, 1], self.top.select(selstr), assume_unique=True)
        return indices[mask]

    def _angle_groups(self):
        """Return ``(kind, indices)`` pairs in the order used for the features."""
        groups = [("PSI", self._psi_inds)]
        if self.omega:
            groups.append(("OMEGA", self._omega_inds))
        groups.append(("PHI", self._phi_inds))
        return groups

    def _labels(self, suffix):
        """Build the labels in the same order as `indexes`.

        Args:
            suffix (Callable[[int, np.ndarray], str]): Called with the running
                number of an angle within its kind and its four atom indices;
                returns the text identifying that angle.

        Returns:
            list[str]: One label per angle, or a COS/SIN pair per angle when
                `self.cossin` is True.

        """
        per_kind = []
        for kind, inds in self._angle_groups():
            kind_labels = []
            for i, ires in enumerate(inds):
                tail = suffix(i, ires)
                if self.cossin:
                    kind_labels.append((f"COS({kind} {tail})", f"SIN({kind} {tail})"))
                else:
                    kind_labels.append((f"CENTERDIH {kind} {tail}",))
            per_kind.append(kind_labels)
        # zip groups the kinds exactly like the indices in __init__ are grouped.
        return [
            label
            for group in zip(*per_kind)
            for labels in group
            for label in labels
        ]

    @property
    def name(self):
        """str: The name of the class: "CentralDihedrals"."""
        return "CentralDihedrals"

    @property
    def indexes(self):
        """np.ndarray: A (n_angles, 4) shaped numpy array giving the atom indices
        of the dihedral angles to be calculated."""
        return self.angle_indexes.astype("int32")

    def generic_describe(self):
        """Returns a list of generic labels, not containing residue names.
        These can be used to stack tops of different topology.

        Returns:
            list[str]: A list of labels, e.g. "CENTERDIH PSI 0" or, with
                `cossin`, "COS(PSI 0)" and "SIN(PSI 0)".

        """
        return self._labels(lambda i, ires: str(i))

    def describe(self):
        """Returns a list of labels, that can be used to unambiguously define
        the dihedral angles in the protein topology.

        Returns:
            list[str]: A list of labels, one per feature (two per angle when
                `cossin` is True). Each label names the angle type and the
                residue of the angle's second atom.

        """
        top = self.top

        def residue_label(atom):
            return (
                f"RESID {atom.residue.name}:{atom.residue.resSeq:>4} "
                f"CHAIN {atom.residue.chain.index}"
            )

        return self._labels(lambda i, ires: residue_label(top.atom(ires[1])))

    @property
    def dask_transform(self):
        """str: Always "dihedral"; presumably used to pick the matching
        dask transformation elsewhere in encodermap — confirm at call site."""
        return "dihedral"

480 

481 

class SideChainDihedrals(DihedralFeature):
    """Feature that collects all side chain dihedrals (chi1 to chi5) of a topology.

    Instead of grouping by angle type (all chi1, then all chi2, ...), the
    dihedrals are ordered per residue.

    Attributes:
        top (mdtraj.Topology): Topology of this feature.
        indexes (np.ndarray): A (n_angles, 4) shaped array with the atom
            indices of the dihedral angles.
        options (list[str]): A list of possible sidechain angles ['chi1' to 'chi5'].
        per_res_dict (dict[str, np.ndarray]): Maps ``str(residue)`` to the
            atom indices of that residue's side chain dihedrals.

    """

    __serialize_version: int = 0
    __serialize_fields: tuple[str] = ("_prefix_label_lengths",)
    options: list[str] = ["chi1", "chi2", "chi3", "chi4", "chi5"]

    def __init__(
        self,
        top,
        selstr=None,
        deg=False,
        cossin=False,
        periodic=True,
        generic_labels=False,
    ):
        """Instantiate this feature class.

        Args:
            top (mdtraj.Topology): A topology to build features from.
            selstr (Optional[str]): MDTraj selection string. If given, only
                dihedrals whose second atom is selected are kept.
                Defaults to None.
            deg (bool): Return degrees instead of radians. Defaults to False.
            cossin (bool): Return each angle as a cosine/sine pair.
                Defaults to False.
            periodic (bool): Whether to use the minimum image convention.
                Defaults to True.
            generic_labels (bool): If True, `describe` is replaced by
                `generic_describe`. Defaults to False.

        Raises:
            ValueError: If no side chain dihedral can be found.

        """
        from mdtraj.geometry import dihedral

        # all chi index arrays, keyed by 'chi1' ... 'chi5'
        indices_dict = {k: getattr(dihedral, f"indices_{k}")(top) for k in self.options}
        if selstr:
            selection = top.select(selstr)
            indices_dict = {
                k: inds[np.isin(inds[:, 1], selection, assume_unique=True)]
                for k, inds in indices_dict.items()
            }

        # proteins may lack some chi angles entirely; drop those
        valid = {k: inds for k, inds in indices_dict.items() if inds.size > 0}
        if not valid:
            raise ValueError(
                "Could not determine any side chain dihedrals for your topology!"
            )

        # change the sorting to be per-residue and not all chi1 and then all chi2 angles
        self.per_res_dict = {}
        for residue in top.residues:
            atom_ids = np.array([a.index for a in residue.atoms])
            rows = []
            for inds in valid.values():
                flat = inds.flatten()
                in_residue = np.isin(flat, atom_ids)
                if in_residue.any():
                    rows.append(flat[in_residue])
            if rows:
                self.per_res_dict[str(residue)] = np.vstack(rows)

        # number of angles per chi type (0 for absent types); kept for serialization
        self._prefix_label_lengths = np.array(
            [len(indices_dict[k]) for k in self.options]
        )
        indices = np.vstack(list(self.per_res_dict.values()))

        super(SideChainDihedrals, self).__init__(
            top=top, dih_indexes=indices, deg=deg, cossin=cossin, periodic=periodic
        )

        if generic_labels:
            self.describe = self.generic_describe

    @property
    def name(self):
        """str: The name of the class: "SideChainDihedrals"."""
        return "SideChainDihedrals"

    @property
    def indexes(self):
        """np.ndarray: A (n_angles, 4) shaped numpy array giving the atom indices
        of the dihedral angles to be calculated."""
        return self.angle_indexes

    def _angle_prefixes(self):
        """Return one angle name ('CHI1', 'CHI2', ...) per dihedral, in index order.

        NOTE(review): assumes a residue with ``k`` side chain dihedrals has
        exactly chi1 ... chik (no gaps), as the original code did.
        """
        return [
            opt.upper()
            for rows in self.per_res_dict.values()
            for opt in self.options[: rows.shape[0]]
        ]

    def _labels(self):
        """Build the full (residue-specific) labels in the order of `indexes`.

        With `cossin`, every angle yields a "COS(...)" and a "SIN(...)" label
        that both carry the prefix of that same angle.
        """
        top = self.top

        def residue_label(atom):
            return (
                f"RESID {atom.residue.name}:{atom.residue.resSeq:>4} "
                f"CHAIN {atom.residue.chain.index}"
            )

        prefixes = self._angle_prefixes()
        labels = []
        for i, ires in enumerate(self.angle_indexes):
            dih = prefixes[i]
            res = residue_label(top.atom(ires[1]))
            if self.cossin:
                labels.append(f"COS({dih} {res})")
                labels.append(f"SIN({dih} {res})")
            else:
                labels.append(f"SIDECHDIH {dih} {res}")
        return labels

    def generic_describe(self):
        """Returns labels without residue names (non-cossin case).

        Returns:
            list[str]: One label per feature.

        """
        labels = self._labels()
        if self.cossin:
            # NOTE(review): cossin labels are returned unshortened and still
            # contain residue names; a generic cossin format is not defined yet.
            return labels
        # Keeps "SIDECHDIH CHIn" plus a 4-character slice of the residue part.
        # NOTE(review): for 3-letter residue names this slice holds only the
        # last two digits of resSeq, so e.g. resSeq 23 and 123 may collide.
        return [lbl[:14] + lbl[27:31] for lbl in labels]

    def describe(self):
        """Returns a list of labels, that can be used to unambiguously define
        the side chain dihedrals in the protein topology.

        Returns:
            list[str]: One label per dihedral (two per dihedral when `cossin`
                is True).

        """
        return self._labels()

661 

662 

class AllCartesians(SelectionFeature):
    """Feature collecting the cartesian positions of every atom of a topology.

    Attributes:
        top (mdtraj.Topology): Topology of this feature.
        indexes (np.ndarray): The atom indices returned by ``top.select('all')``.
        prefix_label (str): Prefix of every label; here 'POSITION '.

    """

    __serialize_version = 0
    __serialize_fields = ("indexes",)
    prefix_label = "POSITION "

    def __init__(self, top):
        """Create the feature for all atoms of `top`.

        Args:
            top (mdtraj.Topology): A mdtraj topology.

        """
        self.top = top
        self.indexes = top.select("all")
        super().__init__(top, self.indexes)

    @property
    def name(self):
        """str: The name of this class: 'AllCartesians'"""
        return "AllCartesians"

    def describe(self):
        """Build one label per coordinate of every selected atom.

        Returns:
            list[str]: Three labels (X, Y, Z) per entry of `self.indexes`.

        """

        def atom_label(atom):
            residue = atom.residue
            return (
                f"ATOM {atom.name:>4}:{atom.index:5} "
                f"{residue.name}:{residue.resSeq:>4} CHAIN {residue.chain.index}"
            )

        return [
            f"{self.prefix_label} {axis} {atom_label(self.top.atom(idx))}"
            for idx in self.indexes
            for axis in ("X", "Y", "Z")
        ]

711 

712 

class CentralCartesians(AllCartesians):
    """Feature collecting the cartesian positions of the backbone atoms.

    Backbone atoms are those matched by ``"name CA or name C or name N"``.

    Attributes:
        top (mdtraj.Topology): Topology of this feature.
        indexes (np.ndarray): Indices of the selected backbone atoms.
        central_indexes (np.ndarray): The same array as `indexes`.
        prefix_label (str): Prefix of every label; here 'CENTERPOS'.

    """

    __serialize_version = 0
    __serialize_fields = ("indexes",)
    prefix_label = "CENTERPOS"

    def __init__(self, top, generic_labels=False):
        """Create the feature.

        Args:
            top (mdtraj.Topology): A mdtraj topology.
            generic_labels (bool): If True, `describe` is replaced by
                `generic_describe`. Defaults to False.

        """
        self.top = top
        super().__init__(self.top)
        self.central_indexes = self.top.select("name CA or name C or name N")
        # Sanity check: at least one atom must lie outside the CA/C/N selection.
        assert len(self.central_indexes) < len(self.indexes)
        self.indexes = self.central_indexes
        self.dimension = 3 * len(self.indexes)

        if generic_labels:
            self.describe = self.generic_describe

    def generic_describe(self):
        """Labels that only contain the running atom number, no residue names.

        Returns:
            list[str]: Three labels (X, Y, Z) per backbone atom.

        """
        return [
            f"{self.prefix_label} {axis} {n}"
            for n in range(len(self.central_indexes))
            for axis in ("X", "Y", "Z")
        ]

    def describe(self):
        """Build one label per coordinate of every backbone atom.

        Returns:
            list[str]: Three labels (X, Y, Z) per entry of `central_indexes`.

        """

        def atom_label(atom):
            residue = atom.residue
            return (
                f"ATOM {atom.name:>4}:{atom.index:5} "
                f"{residue.name}:{residue.resSeq:>4} CHAIN {residue.chain.index}"
            )

        labels = []
        for idx in self.central_indexes:
            text = atom_label(self.top.atom(idx))
            labels.extend(f"{self.prefix_label} {axis} {text}" for axis in "XYZ")
        return labels

    @property
    def name(self):
        """str: The name of the class: "CentralCartesians"."""
        return "CentralCartesians"

772 

773 

class SideChainCartesians(AllCartesians):
    """Feature collecting the cartesian positions of all non-backbone atoms.

    The atoms are chosen with the MDTraj selection ``"not backbone"``; the
    labels come from the inherited `describe`.

    Attributes:
        top (mdtraj.Topology): Topology of this feature.
        indexes (np.ndarray): Indices of the selected non-backbone atoms.
        prefix_label (str): Prefix of every label; here 'SIDECHPOS'.

    """

    __serialize_version = 0
    __serialize_fields = ("indexes",)
    prefix_label = "SIDECHPOS"

    def __init__(self, top):
        """Create the feature.

        Args:
            top (mdtraj.Topology): A mdtraj topology.

        """
        self.top = top
        super().__init__(self.top)
        n_all_atoms = len(self.indexes)
        sidechain = self.top.select("not backbone")
        # Sanity check: the topology must contain at least one backbone atom.
        assert len(sidechain) < n_all_atoms
        self.indexes = sidechain
        self.dimension = 3 * len(sidechain)

    @property
    def name(self):
        """str: The name of the class: "SideChainCartesians"."""
        return "SideChainCartesians"

800 

801 

class AllBondDistances(DistanceFeature):
    """Feature that collects the distances of all bonds in a topology.

    Attributes:
        top (mdtraj.Topology): Topology of this feature.
        indexes (np.ndarray): A (n_bonds, 2) shaped array with the atom
            indices of the atom pairs.
        prefix_label (str): A prefix for the labels. In this case it is 'DISTANCE'.

    """

    __serialize_version = 0
    __serialize_fields = ("distance_indexes", "periodic")
    prefix_label = "DISTANCE "

    def __init__(self, top, distance_indexes=None, periodic=True, check_aas=True):
        """Instantiate the feature.

        Args:
            top (mdtraj.Topology): The topology to build features from.
            distance_indexes (Optional[np.ndarray]): A (n, 2) sequence of atom
                index pairs. If None, one pair per bond in ``top.bonds`` is
                used. Defaults to None.
            periodic (bool): Passed on to pyemma's DistanceFeature.
                Defaults to True.
            check_aas (bool): Whether to raise if `top` contains a residue
                whose name is not in mdtraj's amino acid codes.
                Defaults to True.

        Raises:
            Exception: If `check_aas` is True and an unknown residue name
                is found.

        """
        # Bind the topology unconditionally: both code paths below read it.
        self.top = top
        self.distance_indexes = distance_indexes
        if check_aas:
            unknown = sorted(
                {r.name for r in top.residues if r.name not in _AMINO_ACID_CODES}
            )
            if unknown:
                raise Exception(
                    f"Unknown amino acid(s) in top: {', '.join(unknown)}."
                )
        if self.distance_indexes is None:
            self.distance_indexes = np.vstack(
                [[b[0].index, b[1].index] for b in self.top.bonds]
            )
        super().__init__(self.top, self.distance_indexes, periodic)

    def generic_describe(self):
        """Labels that only contain the running number of each distance.

        Returns:
            list[str]: One label per atom pair.

        """
        return [f"{self.prefix_label}{i}" for i in range(len(self.distance_indexes))]

    def describe(self):
        """Returns a list of labels, that can be used to unambiguously define
        the atom pairs in the protein topology.

        Returns:
            list[str]: One label per atom pair. The chain part is the common
                chain index, or "a/b" for a pair spanning two chains.

        """

        def atom_label(atom):
            return (
                f"ATOM {atom.name:>4}:{atom.index:5} "
                f"{atom.residue.name}:{atom.residue.resSeq:>4}"
            )

        labels = []
        for i, j in self.distance_indexes:
            atom_i, atom_j = self.top.atom(i), self.top.atom(j)
            # int(np.unique(...)) raised for inter-chain pairs; join instead.
            chains = sorted({atom_i.residue.chain.index, atom_j.residue.chain.index})
            chain = "/".join(str(c) for c in chains)
            labels.append(
                f"{self.prefix_label}{atom_label(atom_i)} DIST "
                f"{atom_label(atom_j)} CHAIN {chain}"
            )
        return labels

    @property
    def name(self):
        """str: The name of the class: "AllBondDistances"."""
        return "AllBondDistances"

    @property
    def indexes(self):
        """np.ndarray: A (n_distances, 2) shaped numpy array giving the atom indices
        of the distances to be calculated."""
        return self.distance_indexes

866 

867 

class CentralBondDistances(AllBondDistances):
    """Feature that collects the distances of all bonds in the backbone of a topology.

    Only bonds whose two atoms both match ``"name CA or name C or name N"``
    are used.

    Attributes:
        top (mdtraj.Topology): A deep copy of the topology of this feature.
        indexes (np.ndarray): A (n_bonds, 2) shaped array with the atom
            indices of the atom pairs, sorted by first and then second atom.
        prefix_label (str): A prefix for the labels. In this case it is 'CENTERDISTANCE'.

    """

    __serialize_version = 0
    __serialize_fields = ("distance_indexes", "periodic")
    prefix_label = "CENTERDISTANCE "

    def __init__(
        self,
        top,
        distance_indexes=None,
        periodic=True,
        check_aas=True,
        generic_labels=False,
    ):
        """Instantiate the feature.

        Args:
            top (mdtraj.Topology): The topology. A deep copy is stored; if it
                contains a residue named KAC, the KAC backbone bonds are added
                to the copy and `check_aas` is disabled.
            distance_indexes (Optional[Sequence]): Additional atom index pairs.
                The backbone bonds are appended to a copy of them.
                Defaults to None.
            periodic (bool): Passed on to the parent class. Defaults to True.
            check_aas (bool): Passed on to the parent class. Defaults to True.
            generic_labels (bool): If True, `describe` is replaced by
                `generic_describe`. Defaults to False.

        """
        self.top = copy.deepcopy(top)
        if any(r.name == "KAC" for r in top.residues):
            self.top = add_KAC_backbone_bonds(self.top)
            check_aas = False
        backbone = set(self.top.select("name CA or name C or name N").tolist())

        # NOTE(review): explicitly passed pairs are kept and the backbone bonds
        # appended, mirroring the original behavior. Copy them so the caller's
        # list is not mutated.
        pairs = [] if distance_indexes is None else [list(p) for p in distance_indexes]
        for bond in self.top.bonds:
            if all(atom.index in backbone for atom in bond):
                pairs.append([atom.index for atom in bond])

        # Sort whole rows (first atom, then second). np.sort(..., axis=0)
        # sorts each column independently and can pair unbonded atoms.
        pairs = np.asarray(pairs, dtype=int).reshape(-1, 2)
        distance_indexes = pairs[np.lexsort((pairs[:, 1], pairs[:, 0]))]

        if generic_labels:
            self.describe = self.generic_describe

        super().__init__(
            self.top, distance_indexes, periodic, check_aas=check_aas
        )

    @property
    def name(self):
        """str: The name of the class: "CentralBondDistances"."""
        return "CentralBondDistances"

    @property
    def indexes(self):
        """np.ndarray: A (n_distances, 2) shaped numpy array giving the atom indices
        of the distances to be calculated."""
        return self.distance_indexes

929 

930 

931class SideChainBondDistances(AllBondDistances): 

932 """Feature that collects all bonds not in the backbone of a topology. 

933 

934 Attributes: 

935 top (mdtraj.Topology): Topology of this feature. 

936 indexes (np.ndarray): The numpy array returned from `top.select('all')`. 

937 prefix_label (str): A prefix for the labels. In this case it is 'SIDECHDISTANCE'. 

938 

939 """ 

940 

941 __serialize_version = 0 

942 __serialize_fields = ("distance_indexes", "periodic") 

943 prefix_label = "SIDECHDISTANCE " 

944 

945 def __init__(self, top, periodic=True): 

946 self.top = top 

947 from mdtraj.geometry import dihedral 

948 

949 which = ["chi1", "chi2", "chi3", "chi4", "chi5"] 

950 indices_dict = {k: getattr(dihedral, "indices_%s" % k)(top) for k in which} 

951 flat_list = [ 

952 item 

953 for sublist in indices_dict.values() 

954 for item in sublist.flatten().tolist() 

955 ] 

956 atoms_in_sidechain_dihedrals = set(flat_list) 

957 

958 distance_indexes = [] 

959 for angle, indices in indices_dict.items(): 

960 for index in indices: 

961 if angle == "chi1": 

962 distance_indexes.append([index[1], index[2]]) 

963 distance_indexes.append([index[2], index[3]]) 

964 else: 

965 distance_indexes.append([index[2], index[3]]) 

966 distance_indexes = np.sort(distance_indexes, axis=0) 

967 super().__init__(self.top, distance_indexes, periodic) 

968 

969 @property 

970 def name(self): 

971 """str: The name of the class: "SideChainBondDistances".""" 

972 return "SideChainBondDistances" 

973 

974 @property 

975 def indexes(self): 

976 """np.ndarray: A (n_angles, 2) shaped numpy array giving the atom indices 

977 of the distances to be calculated.""" 

978 return self.distance_indexes 

979 

980 

981class CentralAngles(AngleFeature): 

982 """Feature that collects all angles in the backbone of a topology. 

983 

984 Attributes: 

985 top (mdtraj.Topology): Topology of this feature. 

986 indexes (np.ndarray): The numpy array returned from `top.select('all')`. 

987 prefix_label (str): A prefix for the labels. In this case it is 'CENTERANGLE'. 

988 

989 """ 

990 

991 __serialize_version = 0 

992 __serialize_fields = ("angle_indexes", "deg", "cossin", "periodic") 

993 prefix_label = "CENTERANGLE " 

994 

995 def __init__( 

996 self, top, deg=False, cossin=False, periodic=True, generic_labels=False 

997 ): 

998 self.top = copy.deepcopy(top) 

999 select = self.top.select("name CA or name C or name N") 

1000 # add 4 bonds in KAC 

1001 if any([r.name == "KAC" for r in top.residues]): 1001 ↛ 1002line 1001 didn't jump to line 1002, because the condition on line 1001 was never true

1002 self.top = add_KAC_backbone_bonds(self.top) 

1003 bonds = np.vstack([[x.index for x in b] for b in self.top.bonds]) 

1004 angle_indexes = [] 

1005 for a in select: 

1006 where = np.where(bonds == a) 

1007 possible_bonds = bonds[where[0], :] 

1008 where = np.isin(possible_bonds, select) 

1009 hits = np.count_nonzero(np.all(where, axis=1)) 

1010 if hits <= 1: 

1011 continue 

1012 elif hits == 2: 1012 ↛ 1018line 1012 didn't jump to line 1018, because the condition on line 1012 was never false

1013 where = np.all(where, axis=1) 

1014 these = np.unique( 

1015 [self.top.atom(i).index for i in possible_bonds[where, :].flatten()] 

1016 ) 

1017 angle_indexes.append(these) 

1018 elif hits == 3: 

1019 raise Exception( 

1020 f"Can't deal with these angles. One atom is part of three possible angles" 

1021 ) 

1022 elif hits == 4: 

1023 raise Exception( 

1024 f"Can't deal with these angles. One atom is part of four possible angles" 

1025 ) 

1026 else: 

1027 raise Exception( 

1028 f"Can't deal with these angles. One atom is part of three possible angles" 

1029 ) 

1030 

1031 angle_indexes = np.vstack(angle_indexes) 

1032 angle_indexes = np.unique(angle_indexes, axis=0) 

1033 if generic_labels: 

1034 self.describe = self.generic_describe 

1035 super().__init__(self.top, angle_indexes, deg, cossin, periodic) 

1036 

1037 def generic_describe(self): 

1038 labels = [] 

1039 for i in range(len(self.angle_indexes)): 

1040 labels.append(f"{self.prefix_label}{i}") 

1041 return labels 

1042 

1043 def describe(self): 

1044 """Returns a list of labels, that can be used to unambiguously define 

1045 atoms in the protein topology. 

1046 

1047 Returns: 

1048 list[str]: A list of labels. This list has as many entries as atoms in `self.top`. 

1049 

1050 """ 

1051 getlbl = ( 

1052 lambda at: f"ATOM {at.name:>4}:{at.index:5} {at.residue.name}:{at.residue.resSeq:>4}" 

1053 ) 

1054 labels = [] 

1055 for i, j, k in self.angle_indexes: 

1056 i, j, k = self.top.atom(i), self.top.atom(j), self.top.atom(k) 

1057 labels.append( 

1058 f"{self.prefix_label}{getlbl(i)} ANGLE {getlbl(j)} ANGLE {getlbl(k)} CHAIN {int(np.unique([a.residue.chain.index for a in [i, j, k]]))}" 

1059 ) 

1060 return labels 

1061 

1062 @property 

1063 def name(self): 

1064 """str: The name of the class: "CentralAngles".""" 

1065 return "CentralAngles" 

1066 

1067 @property 

1068 def indexes(self): 

1069 """np.ndarray: A (n_angles, 3) shaped numpy array giving the atom indices 

1070 of the angles to be calculated.""" 

1071 return self.angle_indexes 

1072 

1073 

1074class SideChainAngles(AngleFeature): 

1075 """Feature that collects all angles not in the backbone of a topology. 

1076 

1077 Attributes: 

1078 top (mdtraj.Topology): Topology of this feature. 

1079 indexes (np.ndarray): The numpy array returned from `top.select('all')`. 

1080 prefix_label (str): A prefix for the labels. In this case it is 'SIDECHANGLE'. 

1081 

1082 """ 

1083 

1084 __serialize_version = 0 

1085 __serialize_fields = ("angle_indexes", "deg", "cossin", "periodic") 

1086 prefix_label = "SIDECHANGLE " 

1087 

1088 def __init__(self, top, deg=False, cossin=False, periodic=True): 

1089 self.top = copy.deepcopy(top) 

1090 select = self.top.select( 

1091 "not backbone and (type C or type N or type S or type O) and not type H" 

1092 ) 

1093 # add 4 bonds in KAC 

1094 if any([r.name == "KAC" for r in top.residues]): 

1095 self.top = add_KAC_sidechain_bonds(self.top) 

1096 bonds = np.vstack([[x.index for x in b] for b in self.top.bonds]) 

1097 angle_indexes = [] 

1098 for a in select: 

1099 where = np.where(bonds == a) 

1100 possible_bonds = bonds[where[0], :] 

1101 where = np.isin(possible_bonds, select) 

1102 possible_bonds = possible_bonds[ 

1103 np.where(np.all(where, axis=1))[0], : 

1104 ] # remove atoms not in selection (like hydrogen) 

1105 where = where[ 

1106 np.where(np.all(where, axis=1))[0], : 

1107 ] # remove atoms not in selection (like hydrogen) 

1108 hits = np.count_nonzero(np.all(where, axis=1)) 

1109 if hits <= 1: 

1110 continue 

1111 elif hits == 2: 

1112 where_ax = np.all(where, axis=1) 

1113 angle_atoms = np.unique( 

1114 [ 

1115 self.top.atom(i).index 

1116 for i in possible_bonds[where_ax, :].flatten() 

1117 ] 

1118 ) 

1119 assert len(angle_atoms) == 3, print( 

1120 [(i, self.top.atom(i)) for i in angle_atoms] 

1121 ) 

1122 angle_indexes.append(angle_atoms) 

1123 elif hits == 3: 

1124 where_ax = np.vstack([where[:-1], [False, False]]) 

1125 where_ax = np.all(where_ax, axis=1) 

1126 angle_atoms = np.unique( 

1127 [ 

1128 self.top.atom(i).index 

1129 for i in possible_bonds[where_ax, :].flatten() 

1130 ] 

1131 ) 

1132 assert len(angle_atoms) == 3, print( 

1133 [(i, self.top.atom(i)) for i in angle_atoms] 

1134 ) 

1135 angle_indexes.append(angle_atoms) 

1136 where_ax = np.vstack([[False, False], where[1:]]) 

1137 where_ax = np.all(where_ax, axis=1) 

1138 angle_atoms = np.unique( 

1139 [ 

1140 self.top.atom(i).index 

1141 for i in possible_bonds[where_ax, :].flatten() 

1142 ] 

1143 ) 

1144 assert len(angle_atoms) == 3, print( 

1145 [(i, self.top.atom(i)) for i in angle_atoms] 

1146 ) 

1147 angle_indexes.append(angle_atoms) 

1148 elif hits == 4: 

1149 raise Exception( 

1150 f"Can't deal with these angles. One atom is part of four possible angles" 

1151 ) 

1152 else: 

1153 raise Exception( 

1154 f"Can't deal with these angles. One atom is part of three possible angles" 

1155 ) 

1156 angle_indexes = np.vstack(angle_indexes) 

1157 super().__init__(self.top, angle_indexes, deg, cossin, periodic) 

1158 

1159 def describe(self): 

1160 """Returns a list of labels, that can be used to unambiguously define 

1161 atoms in the protein topology. 

1162 

1163 Retruns: 

1164 list[str]: A list of labels. This list has as many entries as atoms in `self.top`. 

1165 

1166 """ 

1167 getlbl = ( 

1168 lambda at: f"ATOM {at.name:>4}:{at.index:5} {at.residue.name}:{at.residue.resSeq:>4}" 

1169 ) 

1170 labels = [] 

1171 for i, j, k in self.angle_indexes: 

1172 i, j, k = self.top.atom(i), self.top.atom(j), self.top.atom(k) 

1173 labels.append( 

1174 f"{self.prefix_label}{getlbl(i)} ANGLE {getlbl(j)} ANGLE {getlbl(k)} CHAIN {int(np.unique([a.residue.chain.index for a in [i, j, k]]))}" 

1175 ) 

1176 return labels 

1177 

1178 @property 

1179 def name(self): 

1180 """str: The name of the class: "SideChainAngles".""" 

1181 return "SideChainAngles" 

1182 

1183 @property 

1184 def indexes(self): 

1185 """np.ndarray: A (n_angles, 3) shaped numpy array giving the atom indices 

1186 of the angles to be calculated.""" 

1187 return self.angle_indexes