Coverage for encodermap/trajinfo/trajinfo_utils.py: 10%

841 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/trajinfo/trajinfo_utils.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Util functions for the `TrajEnsemble` and `SingleTraj` classes. 

23 

24""" 

25 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32# Future Imports at the top 

33from __future__ import annotations 

34 

35# Standard Library Imports 

36import functools 

37import re 

38import warnings 

39from copy import deepcopy 

40from dataclasses import dataclass, field 

41from functools import reduce 

42from operator import xor 

43from pathlib import Path 

44 

45# Third Party Imports 

46import numpy as np 

47from optional_imports import _optional_import 

48 

49# Encodermap imports 

50from encodermap._typing import CustomAAsDict, DihedralOrBondDict 

51from encodermap.loading.features import AnyFeature, Feature 

52from encodermap.loading.featurizer import ( 

53 DaskFeaturizer, 

54 EnsembleFeaturizer, 

55 SingleTrajFeaturizer, 

56) 

57from encodermap.misc.misc import FEATURE_NAMES 

58from encodermap.misc.xarray import construct_xarray_from_numpy 

59from encodermap.trajinfo.info_single import SingleTraj 

60 

61 

62############################################################################## 

63# Optional Imports 

64############################################################################## 

65 

66 

67xr = _optional_import("xarray") 

68md = _optional_import("mdtraj") 

69_atom_sequence = _optional_import("mdtraj", "geometry.dihedral._atom_sequence") 

70parse_offsets = _optional_import("mdtraj", "geometry.dihedral.parse_offsets") 

71_construct_atom_dict = _optional_import( 

72 "mdtraj", "geometry.dihedral._construct_atom_dict" 

73) 

74_strip_offsets = _optional_import("mdtraj", "geometry.dihedral._strip_offsets") 

75h5py = _optional_import("h5py") 

76yaml = _optional_import("yaml") 

77 

78 

79################################################################################ 

80# Typing 

81################################################################################ 

82 

83 

84# Standard Library Imports 

85from collections.abc import Generator, Sequence 

86from typing import TYPE_CHECKING, Any, Literal, Optional, Union 

87 

88 

# Union of all input types that may specify features/CVs for a single
# trajectory: file locations (str/Path), raw arrays, Feature instances,
# featurizers, pre-built xarray objects, the literals "all"/"full", or None.
SingleTrajFeatureType = Union[
    str,
    Path,
    np.ndarray,
    Feature,
    xr.Dataset,
    xr.DataArray,
    SingleTrajFeaturizer,
    DaskFeaturizer,
    Literal["all"],
    Literal["full"],
    AnyFeature,
    None,
]
# Same idea for a whole `TrajEnsemble`: sequences of the per-trajectory
# inputs, a combined xarray Dataset, ensemble-level featurizers, the
# literals "all"/"full", or None.
TrajEnsembleFeatureType = Union[
    Sequence[str],
    Sequence[Path],
    Sequence[np.ndarray],
    xr.Dataset,
    EnsembleFeaturizer,
    DaskFeaturizer,
    Literal["all"],
    Literal["full"],
    AnyFeature,
    None,
]

115 

116 

117if TYPE_CHECKING: # pragma: no cover 

118 # Third Party Imports 

119 import mdtraj as md 

120 import xarray as xr 

121 

122 # Encodermap imports 

123 from encodermap.trajinfo.info_all import TrajEnsemble 

124 

125 

126################################################################################ 

127# Globals 

128################################################################################ 

129 

130 

# All strings recognized as feature names: both the keys and the values of
# the `FEATURE_NAMES` mapping.
CAN_BE_FEATURE_NAME = list(FEATURE_NAMES.keys()) + list(FEATURE_NAMES.values())

132 

133 

# fmt: off
# Mapping of three-letter residue codes (standard and nonstandard) to
# one-letter codes. A value of None (e.g. the ACE and NME entries) means the
# residue has no one-letter representation; 'X' maps a residue to "unknown".
_AMINO_ACID_CODES = {'ACE': None, 'NME': None, '00C': 'C', '01W': 'X', '02K':
'A', '02L': 'N', '03Y': 'C', '07O': 'C', '08P': 'C', '0A0': 'D', '0A1': 'Y',
'0A2': 'K', '0A8': 'C', '0AA': 'V', '0AB': 'V', '0AC': 'G', '0AF': 'W', '0AG':
'L', '0AH': 'S', '0AK': 'D', '0BN': 'F', '0CS': 'A', '0E5': 'T', '0EA': 'Y',
'0FL': 'A', '0NC': 'A', '0WZ': 'Y', '0Y8': 'P', '143': 'C', '193': 'X', '1OP':
'Y', '1PA': 'F', '1PI': 'A', '1TQ': 'W', '1TY': 'Y', '1X6': 'S', '200': 'F',
'23F': 'F', '23S': 'X', '26B': 'T', '2AD': 'X', '2AG': 'A', '2AO': 'X', '2AS':
'X', '2CO': 'C', '2DO': 'X', '2FM': 'M', '2HF': 'H', '2KK': 'K', '2KP': 'K',
'2LU': 'L', '2ML': 'L', '2MR': 'R', '2MT': 'P', '2OR': 'R', '2PI': 'X', '2QZ':
'T', '2R3': 'Y', '2SI': 'X', '2TL': 'T', '2TY': 'Y', '2VA': 'V', '2XA': 'C',
'32S': 'X', '32T': 'X', '33X': 'A', '3AH': 'H', '3AR': 'X', '3CF': 'F', '3GA':
'A', '3MD': 'D', '3NF': 'Y', '3QN': 'K', '3TY': 'X', '3XH': 'G', '4BF': 'Y',
'4CF': 'F', '4CY': 'M', '4DP': 'W', '4FB': 'P', '4FW': 'W', '4HT': 'W', '4IN':
'W', '4MM': 'X', '4PH': 'F', '4U7': 'A', '56A': 'H', '5AB': 'A', '5CS': 'C',
'5CW': 'W', '5HP': 'E', '6CL': 'K', '6CW': 'W', '6GL': 'A', '6HN': 'K', '7JA':
'I', '9NE': 'E', '9NF': 'F', '9NR': 'R', '9NV': 'V', 'A5N': 'N', 'A66': 'X',
'AA3': 'A', 'AA4': 'A', 'AAR': 'R', 'AB7': 'X', 'ABA': 'A', 'ACB': 'D', 'ACL':
'R', 'ADD': 'X', 'AEA': 'X', 'AEI': 'D', 'AFA': 'N', 'AGM': 'R', 'AGT': 'C',
'AHB': 'N', 'AHH': 'X', 'AHO': 'A', 'AHP': 'A', 'AHS': 'X', 'AHT': 'X', 'AIB':
'A', 'AKL': 'D', 'AKZ': 'D', 'ALA': 'A', 'ALC': 'A', 'ALM': 'A', 'ALN': 'A',
'ALO': 'T', 'ALS': 'A', 'ALT': 'A', 'ALV': 'A', 'ALY': 'K', 'AN8': 'A', 'APE':
'X', 'APH': 'A', 'API': 'K', 'APK': 'K', 'APM': 'X', 'APP': 'X', 'AR2': 'R',
'AR4': 'E', 'AR7': 'R', 'ARG': 'R', 'ARM': 'R', 'ARO': 'R', 'ARV': 'X', 'AS2':
'D', 'AS9': 'X', 'ASA': 'D', 'ASB': 'D', 'ASI': 'D', 'ASK': 'D', 'ASL': 'D',
'ASM': 'X', 'ASN': 'N', 'ASP': 'D', 'ASQ': 'D', 'ASX': 'B', 'AVN': 'X', 'AYA':
'A', 'AZK': 'K', 'AZS': 'S', 'AZY': 'Y', 'B1F': 'F', 'B2A': 'A', 'B2F': 'F',
'B2I': 'I', 'B2V': 'V', 'B3A': 'A', 'B3D': 'D', 'B3E': 'E', 'B3K': 'K', 'B3L':
'X', 'B3M': 'X', 'B3Q': 'X', 'B3S': 'S', 'B3T': 'X', 'B3U': 'H', 'B3X': 'N',
'B3Y': 'Y', 'BB6': 'C', 'BB7': 'C', 'BB8': 'F', 'BB9': 'C', 'BBC': 'C', 'BCS':
'C', 'BE2': 'X', 'BFD': 'D', 'BG1': 'S', 'BH2': 'D', 'BHD': 'D', 'BIF': 'F',
'BIL': 'X', 'BIU': 'I', 'BJH': 'X', 'BL2': 'L', 'BLE': 'L', 'BLY': 'K', 'BMT':
'T', 'BNN': 'F', 'BNO': 'X', 'BOR': 'R', 'BPE': 'C', 'BSE': 'S', 'BTA': 'L',
'BTC': 'C', 'BTR': 'W', 'BUC': 'C', 'BUG': 'V', 'C1X': 'K', 'C22': 'A', 'C3Y':
'C', 'C4R': 'C', 'C5C': 'C', 'C66': 'X', 'C6C': 'C', 'CAF': 'C', 'CAL': 'X',
'CAS': 'C', 'CAV': 'X', 'CAY': 'C', 'CCL': 'K', 'CCS': 'C', 'CDE': 'X', 'CDV':
'X', 'CEA': 'C', 'CGA': 'E', 'CGU': 'E', 'CHF': 'X', 'CHG': 'X', 'CHP': 'G',
'CHS': 'X', 'CIR': 'R', 'CLE': 'L', 'CLG': 'K', 'CLH': 'K', 'CME': 'C', 'CMH':
'C', 'CML': 'C', 'CMT': 'C', 'CPC': 'X', 'CPI': 'X', 'CR5': 'G', 'CS0': 'C',
'CS1': 'C', 'CS3': 'C', 'CS4': 'C', 'CSA': 'C', 'CSB': 'C', 'CSD': 'C', 'CSE':
'C', 'CSJ': 'C', 'CSO': 'C', 'CSP': 'C', 'CSR': 'C', 'CSS': 'C', 'CSU': 'C',
'CSW': 'C', 'CSX': 'C', 'CSZ': 'C', 'CTE': 'W', 'CTH': 'T', 'CUC': 'X', 'CWR':
'S', 'CXM': 'M', 'CY0': 'C', 'CY1': 'C', 'CY3': 'C', 'CY4': 'C', 'CYA': 'C',
'CYD': 'C', 'CYF': 'C', 'CYG': 'C', 'CYJ': 'K', 'CYM': 'C', 'CYQ': 'C', 'CYR':
'C', 'CYS': 'C', 'CZ2': 'C', 'CZZ': 'C', 'D11': 'T', 'D3P': 'G', 'D4P': 'X',
'DA2': 'X', 'DAB': 'A', 'DAH': 'F', 'DAL': 'A', 'DAR': 'R', 'DAS': 'D', 'DBB':
'T', 'DBS': 'S', 'DBU': 'T', 'DBY': 'Y', 'DBZ': 'A', 'DC2': 'C', 'DCL': 'X',
'DCY': 'C', 'DDE': 'H', 'DFI': 'X', 'DFO': 'X', 'DGH': 'G', 'DGL': 'E', 'DGN':
'Q', 'DHA': 'S', 'DHI': 'H', 'DHL': 'X', 'DHN': 'V', 'DHP': 'X', 'DHV': 'V',
'DI7': 'Y', 'DIL': 'I', 'DIR': 'R', 'DIV': 'V', 'DLE': 'L', 'DLS': 'K', 'DLY':
'K', 'DM0': 'K', 'DMH': 'N', 'DMK': 'D', 'DMT': 'X', 'DNE': 'L', 'DNL': 'K',
'DNP': 'A', 'DNS': 'K', 'DOA': 'X', 'DOH': 'D', 'DON': 'L', 'DPL': 'P', 'DPN':
'F', 'DPP': 'A', 'DPQ': 'Y', 'DPR': 'P', 'DSE': 'S', 'DSG': 'N', 'DSN': 'S',
'DSP': 'D', 'DTH': 'T', 'DTR': 'W', 'DTY': 'Y', 'DVA': 'V', 'DYS': 'C', 'ECC':
'Q', 'EFC': 'C', 'EHP': 'F', 'ESB': 'Y', 'ESC': 'M', 'EXY': 'L', 'EYS': 'X',
'F2F': 'F', 'FAK': 'K', 'FB5': 'A', 'FB6': 'A', 'FCL': 'F', 'FGA': 'E', 'FGL':
'G', 'FGP': 'S', 'FH7': 'K', 'FHL': 'K', 'FHO': 'K', 'FLA': 'A', 'FLE': 'L',
'FLT': 'Y', 'FME': 'M', 'FOE': 'C', 'FP9': 'P', 'FRD': 'X', 'FT6': 'W', 'FTR':
'W', 'FTY': 'Y', 'FVA': 'V', 'FZN': 'K', 'GAU': 'E', 'GCM': 'X', 'GFT': 'S',
'GGL': 'E', 'GHG': 'Q', 'GHP': 'G', 'GL3': 'G', 'GLH': 'Q', 'GLJ': 'E', 'GLK':
'E', 'GLM': 'X', 'GLN': 'Q', 'GLQ': 'E', 'GLU': 'E', 'GLX': 'Z', 'GLY': 'G',
'GLZ': 'G', 'GMA': 'E', 'GND': 'X', 'GPL': 'K', 'GSC': 'G', 'GSU': 'E', 'GT9':
'C', 'GVL': 'S', 'H14': 'F', 'H5M': 'P', 'HAC': 'A', 'HAR': 'R', 'HBN': 'H',
'HCS': 'X', 'HFA': 'X', 'HGL': 'X', 'HHI': 'H', 'HIA': 'H', 'HIC': 'H', 'HIP':
'H', 'HIQ': 'H', 'HIS': 'H', 'HL2': 'L', 'HLU': 'L', 'HMR': 'R', 'HPC': 'F',
'HPE': 'F', 'HPH': 'F', 'HPQ': 'F', 'HQA': 'A', 'HRG': 'R', 'HRP': 'W', 'HS8':
'H', 'HS9': 'H', 'HSE': 'S', 'HSL': 'S', 'HSO': 'H', 'HTI': 'C', 'HTN': 'N',
'HTR': 'W', 'HV5': 'A', 'HVA': 'V', 'HY3': 'P', 'HYP': 'P', 'HZP': 'P', 'I2M':
'I', 'I58': 'K', 'IAM': 'A', 'IAR': 'R', 'IAS': 'D', 'IEL': 'K', 'IGL': 'G',
'IIL': 'I', 'ILE': 'I', 'ILG': 'E', 'ILX': 'I', 'IML': 'I', 'IOY': 'F', 'IPG':
'G', 'IT1': 'K', 'IYR': 'Y', 'IYT': 'T', 'IZO': 'M', 'JJJ': 'C', 'JJK': 'C',
'JJL': 'C', 'K1R': 'C', 'KCX': 'K', 'KGC': 'K', 'KNB': 'A', 'KOR': 'M', 'KPI':
'K', 'KST': 'K', 'KYN': 'W', 'KYQ': 'K', 'L2A': 'X', 'LA2': 'K', 'LAA': 'D',
'LAL': 'A', 'LBY': 'K', 'LCK': 'K', 'LCX': 'K', 'LCZ': 'X', 'LDH': 'K', 'LED':
'L', 'LEF': 'L', 'LEH': 'L', 'LEI': 'V', 'LEM': 'L', 'LEN': 'L', 'LET': 'K',
'LEU': 'L', 'LEX': 'L', 'LHC': 'X', 'LLP': 'K', 'LLY': 'K', 'LME': 'E', 'LMF':
'K', 'LMQ': 'Q', 'LP6': 'K', 'LPD': 'P', 'LPG': 'G', 'LPL': 'X', 'LPS': 'S',
'LSO': 'K', 'LTA': 'X', 'LTR': 'W', 'LVG': 'G', 'LVN': 'V', 'LYF': 'K', 'LYK':
'K', 'LYM': 'K', 'LYN': 'K', 'LYR': 'K', 'LYS': 'K', 'LYX': 'K', 'LYZ': 'K',
'M0H': 'C', 'M2L': 'K', 'M2S': 'M', 'M30': 'G', 'M3L': 'K', 'MA': 'A', 'MAA':
'A', 'MAI': 'R', 'MBQ': 'Y', 'MC1': 'S', 'MCG': 'X', 'MCL': 'K', 'MCS': 'C',
'MD3': 'C', 'MD6': 'G', 'MDF': 'Y', 'MDH': 'X', 'MEA': 'F', 'MED': 'M', 'MEG':
'E', 'MEN': 'N', 'MEQ': 'Q', 'MET': 'M', 'MEU': 'G', 'MF3': 'X', 'MGG': 'R',
'MGN': 'Q', 'MGY': 'G', 'MHL': 'L', 'MHO': 'M', 'MHS': 'H', 'MIS': 'S', 'MK8':
'L', 'ML3': 'K', 'MLE': 'L', 'MLL': 'L', 'MLY': 'K', 'MLZ': 'K', 'MME': 'M',
'MMO': 'R', 'MND': 'N', 'MNL': 'L', 'MNV': 'V', 'MOD': 'X', 'MP8': 'P', 'MPH':
'X', 'MPJ': 'X', 'MPQ': 'G', 'MSA': 'G', 'MSE': 'M', 'MSL': 'M', 'MSO': 'M',
'MSP': 'X', 'MT2': 'M', 'MTY': 'Y', 'MVA': 'V', 'N10': 'S', 'N2C': 'X', 'N7P':
'P', 'N80': 'P', 'N8P': 'P', 'NA8': 'A', 'NAL': 'A', 'NAM': 'A', 'NB8': 'N',
'NBQ': 'Y', 'NC1': 'S', 'NCB': 'A', 'NCY': 'X', 'NDF': 'F', 'NEM': 'H', 'NEP':
'H', 'NFA': 'F', 'NHL': 'E', 'NIY': 'Y', 'NLE': 'L', 'NLN': 'L', 'NLO': 'L',
'NLP': 'L', 'NLQ': 'Q', 'NMC': 'G', 'NMM': 'R', 'NNH': 'R', 'NPH': 'C', 'NPI':
'A', 'NSK': 'X', 'NTR': 'Y', 'NTY': 'Y', 'NVA': 'V', 'NYS': 'C', 'NZH': 'H',
'O12': 'X', 'OAR': 'R', 'OAS': 'S', 'OBF': 'X', 'OBS': 'K', 'OCS': 'C', 'OCY':
'C', 'OHI': 'H', 'OHS': 'D', 'OIC': 'X', 'OLE': 'X', 'OLT': 'T', 'OLZ': 'S',
'OMT': 'M', 'ONH': 'A', 'ONL': 'X', 'OPR': 'R', 'ORN': 'A', 'ORQ': 'R', 'OSE':
'S', 'OTB': 'X', 'OTH': 'T', 'OXX': 'D', 'P1L': 'C', 'P2Y': 'P', 'PAQ': 'Y',
'PAS': 'D', 'PAT': 'W', 'PAU': 'A', 'PBB': 'C', 'PBF': 'F', 'PCA': 'E', 'PCC':
'P', 'PCE': 'X', 'PCS': 'F', 'PDL': 'X', 'PEC': 'C', 'PF5': 'F', 'PFF': 'F',
'PFX': 'X', 'PG1': 'S', 'PG9': 'G', 'PGL': 'X', 'PGY': 'G', 'PH6': 'P', 'PHA':
'F', 'PHD': 'D', 'PHE': 'F', 'PHI': 'F', 'PHL': 'F', 'PHM': 'F', 'PIV': 'X',
'PLE': 'L', 'PM3': 'F', 'POM': 'P', 'PPN': 'F', 'PR3': 'C', 'PR9': 'P', 'PRO':
'P', 'PRS': 'P', 'PSA': 'F', 'PSH': 'H', 'PTA': 'X', 'PTH': 'Y', 'PTM': 'Y',
'PTR': 'Y', 'PVH': 'H', 'PVL': 'X', 'PYA': 'A', 'PYL': 'O', 'PYX': 'C', 'QCS':
'C', 'QMM': 'Q', 'QPA': 'C', 'QPH': 'F', 'R1A': 'C', 'R4K': 'W', 'RE0': 'W',
'RE3': 'W', 'RON': 'X', 'RVX': 'S', 'RZ4': 'S', 'S1H': 'S', 'S2C': 'C', 'S2D':
'A', 'S2P': 'A', 'SAC': 'S', 'SAH': 'C', 'SAR': 'G', 'SBL': 'S', 'SCH': 'C',
'SCS': 'C', 'SCY': 'C', 'SD2': 'X', 'SDP': 'S', 'SE7': 'A', 'SEB': 'S', 'SEC':
'U', 'SEG': 'A', 'SEL': 'S', 'SEM': 'S', 'SEN': 'S', 'SEP': 'S', 'SER': 'S',
'SET': 'S', 'SGB': 'S', 'SHC': 'C', 'SHP': 'G', 'SHR': 'K', 'SIB': 'C', 'SLR':
'P', 'SLZ': 'K', 'SMC': 'C', 'SME': 'M', 'SMF': 'F', 'SNC': 'C', 'SNN': 'N',
'SOC': 'C', 'SOY': 'S', 'SRZ': 'S', 'STY': 'Y', 'SUB': 'X', 'SUN': 'S', 'SVA':
'S', 'SVV': 'S', 'SVW': 'S', 'SVX': 'S', 'SVY': 'S', 'SVZ': 'S', 'SYS': 'C',
'T11': 'F', 'T66': 'X', 'TA4': 'X', 'TAV': 'D', 'TBG': 'V', 'TBM': 'T', 'TCQ':
'Y', 'TCR': 'W', 'TDD': 'L', 'TFQ': 'F', 'TH6': 'T', 'THC': 'T', 'THO': 'X',
'THR': 'T', 'THZ': 'R', 'TIH': 'A', 'TMB': 'T', 'TMD': 'T', 'TNB': 'C', 'TNR':
'S', 'TOQ': 'W', 'TPH': 'X', 'TPL': 'W', 'TPO': 'T', 'TPQ': 'Y', 'TQI': 'W',
'TQQ': 'W', 'TRF': 'W', 'TRG': 'K', 'TRN': 'W', 'TRO': 'W', 'TRP': 'W', 'TRQ':
'W', 'TRW': 'W', 'TRX': 'W', 'TRY': 'W', 'TST': 'X', 'TTQ': 'W', 'TTS': 'Y',
'TXY': 'Y', 'TY1': 'Y', 'TY2': 'Y', 'TY3': 'Y', 'TY5': 'Y', 'TYB': 'Y', 'TYI':
'Y', 'TYJ': 'Y', 'TYN': 'Y', 'TYO': 'Y', 'TYQ': 'Y', 'TYR': 'Y', 'TYS': 'Y',
'TYT': 'Y', 'TYW': 'Y', 'TYX': 'X', 'TYY': 'Y', 'TZB': 'X', 'TZO': 'X', 'UMA':
'A', 'UN1': 'X', 'UN2': 'X', 'UNK': 'X', 'VAD': 'V', 'VAF': 'V', 'VAL': 'V',
'VB1': 'K', 'VDL': 'X', 'VLL': 'X', 'VLM': 'X', 'VMS': 'X', 'VOL': 'X', 'WLU':
'L', 'WPA': 'F', 'WRP': 'W', 'WVL': 'V', 'X2W': 'E', 'XCN': 'C', 'XCP': 'X',
'XDT': 'T', 'XPL': 'O', 'XPR': 'P', 'XSN': 'N', 'XX1': 'K', 'YCM': 'C', 'YOF':
'Y', 'YTH': 'T', 'Z01': 'A', 'ZAL': 'A', 'ZCL': 'F', 'ZFB': 'X', 'ZU0': 'T',
'ZZJ': 'A'}
# fmt: on

263 

264 

# Atom-name templates for the backbone dihedrals. A leading "-" or "+" in an
# atom name refers to an atom of the previous or next residue, respectively
# (see the "-C"/"+N" examples in the `CustomTopology` docstring).
PHI_ATOMS = ["-C", "N", "CA", "C"]
PSI_ATOMS = ["N", "CA", "C", "+N"]
OMEGA_ATOMS = ["CA", "C", "+N", "+CA"]
# Sidechain (chi1..chi5) dihedral atom-name templates. Each CHIn lists the
# alternative atom quadruples, because the participating atom carries
# different names depending on the residue (e.g. CG vs. CG1 vs. SG vs. OG).
CHI1_ATOMS = [
    ["N", "CA", "CB", "CG"],
    ["N", "CA", "CB", "CG1"],
    ["N", "CA", "CB", "SG"],
    ["N", "CA", "CB", "OG"],
    ["N", "CA", "CB", "OG1"],
]
CHI2_ATOMS = [
    ["CA", "CB", "CG", "CD"],
    ["CA", "CB", "CG", "CD1"],
    ["CA", "CB", "CG1", "CD1"],
    ["CA", "CB", "CG", "OD1"],
    ["CA", "CB", "CG", "ND1"],
    ["CA", "CB", "CG", "SD"],
]
CHI3_ATOMS = [
    ["CB", "CG", "CD", "NE"],
    ["CB", "CG", "CD", "CE"],
    ["CB", "CG", "CD", "OE1"],
    ["CB", "CG", "SD", "CE"],
]
CHI4_ATOMS = [["CG", "CD", "NE", "CZ"], ["CG", "CD", "CE", "NZ"]]
CHI5_ATOMS = [["CD", "NE", "CZ", "NH1"]]


# Public API of this module.
__all__: list[str] = ["load_CVs_singletraj", "load_CVs_ensembletraj", "CustomTopology"]

294 

295 

296################################################################################ 

297# Classes 

298################################################################################ 

299 

300 

@dataclass
class Bond:
    """Dataclass, that contains information of an atomic bond.

    Attributes:
        resname (str): The name of the residue, this bond belongs to. Although
            bonds belong to residues, they can also have `atom1` or `atom2`
            belonging to a different residue.
        type (Literal["add", "delete", "optional", "optional_delete"]): Defines
            what should be done with this bond. 'add', adds it to the topology and
            raises an Exception if the bond was already present. 'optional' does
            the same as 'add', but without raising an Exception. 'delete' deletes
            this bond from the topology. An Exception is raised, if this bond
            wasn't even in the topology to begin with. 'optional_delete' deletes
            bonds, but doesn't raise an Exception.
        atom1 (Union[str, int]): The name of the first atom. Can be 'CA', 'N', or
            whatever (not limited to proteins). If it is int it can be any other
            atom of the topology (also belonging to a different residue).
        atom2 (Union[str, int]): The name of the second atom. Can be 'CA', 'N', or
            whatever (not limited to proteins). If it is int it can be any other
            atom of the topology (also belonging to a different residue).

    """

    resname: str
    type: Literal["add", "delete", "optional", "optional_delete"]
    atom1: Union[str, int]
    atom2: Union[str, int]

    def __hash__(self) -> int:
        """Order-sensitive hash over all four fields.

        Hashes the field tuple instead of XOR-reducing the per-field hashes.
        With XOR, equal fields cancel each other out (``x ^ x == 0``) and any
        permutation of the fields collides (e.g. swapping `atom1` and
        `atom2`), producing needless hash collisions. The tuple hash also
        matches the order-sensitive `__eq__` generated by `dataclass`.
        """
        return hash((self.resname, self.type, self.atom1, self.atom2))

338 

339 

@dataclass
class Dihedral:
    """Dataclass that stores information about a dihedral of 4 atoms.

    Attributes:
        resname (str): The name of the residue, this dihedral belongs to. Although
            dihedrals belong to residues, they can also have atoms
            belonging to a different residue.
        type (Literal["OMEGA", "PHI", "PSI", "CHI1", "CHI2", "CHI3", "CHI4", "CHI5"]):
            Defines what type of dihedral this dihedral is. Mainly used to
            discern between these types of dihedrals.
        atom1 (Union[str, int]): The name of the first atom. Can be 'CA', 'N', or
            whatever (not limited to proteins). If it is `int` it can be any other
            atom of the topology (also belonging to a different residue).
        atom2 (Union[str, int]): The name of the second atom. Can be 'CA', 'N', or
            whatever (not limited to proteins). If it is `int` it can be any other
            atom of the topology (also belonging to a different residue).
        atom3 (Union[str, int]): The name of the third atom. Can be 'CA', 'N', or
            whatever (not limited to proteins). If it is `int` it can be any other
            atom of the topology (also belonging to a different residue).
        atom4 (Union[str, int]): The name of the fourth atom. Can be 'CA', 'N', or
            whatever (not limited to proteins). If it is `int` it can be any other
            atom of the topology (also belonging to a different residue).
        delete (bool): Whether this dihedral has to be deleted or not. If
            `delete` is set to True, this dihedral won't produce output.

    """

    resname: str
    type: Literal["OMEGA", "PHI", "PSI", "CHI1", "CHI2", "CHI3", "CHI4", "CHI5"]
    atom1: Union[int, str, None] = None
    atom2: Union[int, str, None] = None
    atom3: Union[int, str, None] = None
    atom4: Union[int, str, None] = None
    delete: bool = False

    def __hash__(self) -> int:
        """Order-sensitive hash over all fields.

        Hashes the field tuple instead of XOR-reducing the per-field hashes.
        With XOR, equal fields cancel out pairwise, so the default `None`
        atoms were invisible to the hash (e.g. `Dihedral(r, t, "N", "N")`
        collided with `Dihedral(r, t)`), and any permutation of the atoms
        hashed identically. The tuple hash avoids those collisions and
        matches the order-sensitive `__eq__` generated by `dataclass`.
        """
        return hash(
            (
                self.resname,
                self.type,
                self.atom1,
                self.atom2,
                self.atom3,
                self.atom4,
                self.delete,
            )
        )

    @property
    def new_atoms_def(self) -> list[str]:
        """list[str]: A list of str, that describes the dihedral's atoms.

        Empty for dihedrals marked for deletion (`delete=True`).
        """
        if self.delete:
            return []
        atoms = [self.atom1, self.atom2, self.atom3, self.atom4]
        # Only named (str) atoms can be used to extend dihedral definitions.
        assert all(
            isinstance(a, str) for a in atoms
        ), f"Can only add to dihedral definitions if all atoms are str. {self.delete=}"
        return atoms

399 

400 

@dataclass
class NewResidue:
    """Dataclass that stores information about a new (nonstandard) residue.

    Attributes:
        name (str): The 3-letter code name of the new residue.
        idx (Union[None, int]): The 0-based unique index of the residue. The
            `idx` index is always unique (i.e., if multiple chains are present,
            this residue can only appear in one chain).
        resSeq (Union[None, int]): The 1-based non-unique index of the residue.
            resSeqs can appear multiple times, but in separate chains. Each
            residue chain can have a MET1 residue. Either resSeq or idx must
            be defined. Not both can be None.
        one_letter_code (str): The one letter code of this new residue.
            Can be set to a known one letter code, so that this new residue
            mimics that one letter code residue's behavior. Can also be ''
            (empty string), if you don't want to bother with this definition.
        ignore (bool): Whether to ignore the features of this residue.
        bonds (list[Bond]): A list of `Bond` instances.
        dihedrals (list[Dihedral]): A list of `Dihedral` instances.
        common_str (Optional[str]): The common_str of the (sub)set of
            `SingleTraj`s that this new dihedral should apply to. Only
            applies to `SingleTraj`s with the same `common_str`. Can be
            None and thus applies to all trajs in the `TrajEnsmeble`.

    """

    name: str
    idx: Union[None, int] = None
    resSeq: Union[None, int] = None
    one_letter_code: str = ""
    topology: Optional[md.Topology] = None
    ignore: bool = False
    bonds: list[Bond] = field(default_factory=list)
    dihedrals: list[Dihedral] = field(default_factory=list)
    common_str: Optional[str] = None

    def __hash__(self) -> int:
        """Order-sensitive hash over all fields, including bonds and dihedrals.

        Hashes one flat tuple instead of XOR-reducing the individual field
        hashes: XOR lets equal fields cancel each other out and is invariant
        under permutations of the fields, both of which cause avoidable
        collisions.
        """
        return hash(
            (
                self.name,
                self.idx,
                self.resSeq,
                self.one_letter_code,
                self.topology,
                self.ignore,
                tuple(self.bonds),
                tuple(self.dihedrals),
            )
        )

    def parse_bonds_and_dihedrals(
        self,
        bonds_and_dihedrals: DihedralOrBondDict,
    ) -> None:
        """Parses a dict of bonds and dihedrals. The format of this can be derived
        from the format of the `CustomTopology` input dict.

        Args:
            bonds_and_dihedrals (DihedralOrBondDict): A dict defining bonds
                and dihedrals of this NewResidue.

        Raises:
            Exception: If this instance already contains bonds or dihedrals.

        """
        if self.bonds or self.dihedrals:
            raise Exception(
                f"The method `parse_bonds_and_dihedrals` works on empty `NewResidue` "
                f"instances. If you want to add bonds or dihedrals use the "
                f"`add_bond` or `add_dihedral` methods."
            )
        for bond_or_dihe_type, atoms_or_bonds in bonds_and_dihedrals.items():
            if "bonds" in bond_or_dihe_type:
                # Strip the "bonds" suffix (and the trailing "_") to obtain the
                # `Bond.type`: "bonds" -> "add", "optional_bonds" -> "optional",
                # "delete_bonds" -> "delete", "optional_delete_bonds" ->
                # "optional_delete". A real suffix removal is used here instead
                # of the old `str.rstrip("bonds")`, which strips a *character
                # set* from the right and can eat too many characters.
                bond_type = bond_or_dihe_type[: -len("bonds")].rstrip("_")
                if bond_type == "":
                    bond_type = "add"
                for bond in atoms_or_bonds:
                    self.bonds.append(Bond(self.name, bond_type, *bond))
            else:
                # Dihedral definition. A str or None value marks the dihedral
                # for deletion (its key may carry a "not_" prefix); otherwise
                # the value is the sequence of atoms forming the dihedral.
                if isinstance(atoms_or_bonds, str) or atoms_or_bonds is None:
                    if bond_or_dihe_type.startswith("not_"):
                        bond_or_dihe_type = bond_or_dihe_type.replace("not_", "")
                    dihe_name = bond_or_dihe_type
                    atoms_or_bonds = []
                    delete = True
                else:
                    dihe_name = bond_or_dihe_type
                    delete = False
                self.dihedrals.append(
                    Dihedral(self.name, dihe_name, *atoms_or_bonds, delete=delete)
                )

    def get_dihedral_by_type(self, type: str) -> Optional[Dihedral]:
        """Returns the first dihedral with matching `type`, or None if absent."""
        for d in self.dihedrals:
            if d.type == type:
                return d
        return None

    def add_bond(self, bond: Bond) -> None:
        """Appends a single `Bond` to this residue."""
        assert isinstance(bond, Bond)
        self.bonds.append(bond)

    def add_dihedral(self, dihedral: Dihedral) -> None:
        """Appends a single `Dihedral` to this residue."""
        assert isinstance(dihedral, Dihedral)
        self.dihedrals.append(dihedral)

    def as_amino_acid_dict_entry(self) -> dict[str, Union[str, None]]:
        """Returns ``{name: one_letter_code}``; the code is None if `ignore` is set."""
        one_letter_code = None if self.ignore else self.one_letter_code
        return {self.name: one_letter_code}

    def _str_summary(self) -> str:
        """Builds the multi-line summary string used by `__str__`."""
        name = self.name
        one_letter_code = self.one_letter_code
        topology = self.topology
        ignore = self.ignore
        common_str = self.common_str
        out = f"{name=} {one_letter_code=} {topology=} {ignore=} {common_str=}"
        if self.bonds:
            out += "\nBonds:"
            for b in self.bonds:
                out += f"\n{b}"
        if self.dihedrals:
            out += "\nDihedrals:"
            for bd in self.dihedrals:
                out += f"\n{bd}"
        return out

    def __str__(self) -> str:
        return self._str_summary()

537 

538 

def _delete_bond(
    top: md.Topology,
    bond: tuple["Atom", "Atom"],
) -> md.Topology:
    """Deletes a bond from a MDTraj topology.

    MDTraj's topology has an easy implementation of `md.Topology.add_bond`.
    However, it is lacking the same functionality for removing bonds. This function
    adds it by creating a new topology from a dataframe specifying the atoms
    (which does not change) and a new np.array specifying the bonds.

    Args:
        top (md.Topology): The topology.
        bond (tuple[Atom, Atom]): A tuple of two mdtraj.core.topology.Atom objects.

    Returns:
        md.Topology: The new topology.

    Raises:
        Exception: If the bond can't be found in `top` or matches more than
            one entry of the bond table.

    """
    atoms, bonds = top.to_dataframe()
    # Search the bond table in both atom orders, since the bond may be stored
    # as (atom1, atom2) or as (atom2, atom1).
    # fmt: off
    match1 = np.where((bonds[:, 0] == bond[0].index) & (bonds[:, 1] == bond[1].index))[0]
    match2 = np.where((bonds[:, 1] == bond[0].index) & (bonds[:, 0] == bond[1].index))[0]
    # fmt: on
    if (
        (match1.size == 0 and match2.size == 0)
        or (match1.size > 1 and match2.size == 0)
        or (match1.size == 0 and match2.size > 1)
    ):
        raise Exception(f"Could not identify bond {bond} in topology.")
    elif match1.size == 1 and match2.size != 1:
        match = match1[0]
    elif match1.size != 1 and match2.size == 1:
        match = match2[0]
    else:
        # Both orders matched (or matched multiple times): the bond table is
        # ambiguous for this atom pair. The previous bare `raise Exception`
        # carried no message, which made this failure hard to diagnose.
        raise Exception(f"Bond {bond} matches multiple entries in the topology.")

    # Drop the matched row and rebuild the topology from the unchanged atoms.
    new_bonds = np.delete(bonds, match, axis=0)
    return md.Topology.from_dataframe(atoms, new_bonds)

581 

582 

583class CustomTopology: 

584 """Adds custom topology elements to a topology parsed by MDTraj. 

585 

586 Postpones parsing the custom AAs until requested. 

587 

588 The custom_aminoacids dictionary follows these styleguides: 

589 * The keys can be str or tuple[str, str] 

590 * If a key is str, it needs to be a 3-letter code (MET, ALA, GLY, ...) 

591 * If a key is a tuple[str, str], the first str of the tuple is a common_str 

592 (see the docstring for `encodermap.TrajEnsemble` to learn about common_str. 

593 This common_str can be used to apply custom topologies to an ensemble 

594 based on their common_str. For example:: 

595 

596 {("CSR_mutant", "CSR"): ...} 

597 

598* A key can also affect only a single residue (not all residues called "CSR").

599 For that, the 3-letter code of the residue needs to be postponed with 

600 a dash and the 1-based indexed resSeq of the residue:: 

601 

602 {"CSR-2": ...} 

603 

604 * The value to a key can be None, which means this residue will not be 

605 used for building a topology. Because EncoderMap raises Exceptions, 

606 when it encounters unknown residues (to make sure, you don't forget to 

607 featurize some important residues), it will also raise Exceptions when 

608 the topology contains unknown solvents/solutes. If you run a simulation 

609 in water/methanol mixtures with the residue names SOL and MOH, EncoderMap 

610 will raise an Exception upon encountering MOH, so your custom topology 

611 should contain `{"MOH": None}` to include MOH.

612 * The value of a key can also be a tuple[str, Union[dict, None]]. In this 

613 case, the first string should be the one-letter code of the residue or 

614 the residue most closely representing this residue. If you use 

615 phosphotyrosine (PTR) in your simulations and want to use it as a 

616 standard tyrosine residue, the custom topology should contain 

617 `{"PTR": ("Y", None)}` 

618 * If your residue is completely novel you need to define all possible 

619 bonds, backbone and sidechain dihedrals yourself. For that, you want 

620 to provide a tuple[str, dict[str, Union[listr[str], list[int]]] type. 

621 This second level dict allows for the following keys: 

622 * bonds: For bonds between atoms. This key can contain a list[tuple[str, str]], 

623 which defines bonds in this residue. This dict defines a bond 

624 between N and CA in phosphothreonine. 

625 {"PTR": ("Y", { 

626 "bonds": [ 

627 ("N", "CA"), 

628 ], 

629 }} 

630These strings can contain + and - signs to denote bonds to

631 previous or following residues. To connect the residues MET1 to 

632 TPO2 to ALA3, you want to have this dict: 

633 {"TPO": ("T", { 

634 "bonds": [ 

635 ("-C", "N"), # bond to MET1-C 

636 ("N", "CA"), 

637 ... 

638 ("C", "+N"), # bond to ALA2-N 

639 ], 

640 }} 

641 For exotic bonds, one of the strings can also be int to 

642 connect to any 0-based indexed atom in your topology. You can connect 

643the residues CYS2 and CYS20 with a disulfide bridge like so:

644 {"CYS-2": ("C", { 

645 "bonds": [ 

646 ("S", 321), # connect to CYS20, the 321 is a placeholder 

647 ], 

648 }, 

649 "CYS-20": ("C", { 

650 "bonds": [ 

651 (20, "S"), # connect to CYS2 

652 ], 

653 }, 

654 } 

655 * optional_bonds: This key accepts the same list[tuple] as 'bonds'. 

656 However, bonds will raise an Exception if a bond already exists. 

657 The above example with a disulfide bridge between CYS2 and CYS20 

658 will thus raise an exception. A better example is: 

659 {"CYS-2": ("C", { 

660 "optional_bonds": [ 

661 ("S", 321), # connect to CYS20, the 321 is a placeholder 

662 ], 

663 }, 

664 "CYS-20": ("C", { 

665 "optional_bonds": [ 

666 (20, "S"), # connect to CYS2 

667 ], 

668 }, 

669 } 

670 * delete_bonds: This key accepts the same list[tuple] as 'bonds', 

671 but will remove bonds. If a bond was marked for deletion, but 

672 does not exist in your topology, an Exception will be raised. 

673 To delete bonds, without raising an Exception, use: 

674 * optional_delete_bonds: This will delete bonds, if they are present 

675 and won't raise an Exception if no bond is present. 

676 * PHI, PSI, OMEGA: These keys define the backbone torsions of this 

677 residue. You can just provide a list[str] for these keys. But 

678 the str can contain + and - to use atoms in previous or following 

679 residues. Example: 

680 { 

681 "CYS-2": ( 

682 "C", 

683 { 

684 "PHI": ["-C", "N", "CA", "C"], 

685 "PSI": ["N", "CA", "C", "+N"], 

686 "OMEGA": ["CA", "C", "+N", "+CA"], 

687 }, 

688 ), 

689 } 

690 * not-PSI, not_OMEGA, not_PHI: Same as 'PHI', 'PSI", 'OMEGA', but 

691 will remove these dihedrals from consideration. The vales of 

692 these keys do not matter. Example: 

693 { 

694 "CYS-2": ( 

695 "C", 

696 { 

697 "PHI": ["-C", "N", "CA", "C"], 

698 "not_PSI": [], # value for not_* keys does not matter 

699 "not_OMEGA": [], # it just makes EncoderMap skip these dihedrals. 

700 }, 

701 ), 

702 } 

703 * CHI1, ..., CHI5: Finally, these keys define the atoms considered for 

704 the sidechain angles. If you want to add extra sidechain dihedrals 

705 for phosphothreonine, you can do: 

706 { 

707 "TPO": ( 

708 "T", 

709 { 

710 "CHI2": ["CA", "CB", "OG1", "P"], # include phosphorus in sidechain angles 

711 "CHI3": ["CB", "OG1", "P", "OXT"], # include the terminal oxygen in sidechain angles 

712 }, 

713 ) 

714 } 

715 

716 Examples: 

717 >>> # Aminoacids taken from https://www.swisssidechain.ch/ 

718 >>> # The provided .pdb file has only strange and unnatural aminoacids. 

719 >>> # Its sequence is: 

720 >>> # TPO - PTR - ORN - OAS - 2AG - CSR 

721 >>> # TPO: phosphothreonine 

722 >>> # PTR: phosphotyrosine 

723 >>> # ORN: ornithine 

724 >>> # OAS: o-acetylserine 

725 >>> # 2AG: 2-allyl-glycine 

726 >>> # CSR: selenocysteine 

727 >>> # However, someone mis-named the 2AG residue to ALL 

728 >>> # Let's fix that with EncoderMap's CustomTopology 

729 >>> import encodermap as em 

730 >>> from pathlib import Path 

731 ... 

732 >>> traj = em.load(Path(em.__file__).resolve().parent.parent / "tests/data/unnatural_aminoacids.pdb") 

733 ... 

734 >>> custom_aas = { 

735 ... "ALL": ("A", None), # makes EncoderMap treat 2-allyl-glycine as alanine 

736 ... "OAS": ( 

737 ... "S", # OAS is o-acetylserine 

738 ... { 

739 ... "CHI2": ["CA", "CB", "OG", "CD"], # this is a non-standard chi2 angle 

740 ... "CHI3": ["CB", "OG", "CD", "CE"], # this is a non-standard chi3 angle 

741 ... }, 

742 ... ), 

743 ... "CSR": ( # CSR is selenocysteine 

744 ... "S", 

745 ... { 

746 ... "bonds": [ # we can manually define bonds for selenocysteine like so: 

747 ... ("-C", "N"), # bond between previous carbon and nitrogen CSR 

748 ... ("N", "CA"), 

749 ... ("N", "H1"), 

750 ... ("CA", "C"), 

751 ... ("CA", "HA"), # this topology includes hydrogens 

752 ... ("C", "O"), 

753 ... ("C", "OXT"), # As the C-terminal residue, we don't need to put ("C", "+N") here 

754 ... ("CA", "CB"), 

755 ... ("CB", "HB1"), 

756 ... ("CB", "HB2"), 

757 ... ("CB", "SE"), 

758 ... ("SE", "HE"), 

759 ... ], 

760 ... "CHI1": ["N", "CA", "CB", "SE"], # this is a non-standard chi1 angle 

761 ... }, 

762 ... ), 

763 ... "TPO": ( # TPO is phosphothreonine 

764 ... "T", 

765 ... { 

766 ... "CHI2": ["CA", "CB", "OG1", "P"], # a non-standard chi2 angle 

767 ... "CHI3": ["CB", "OG1", "P", "OXT"], # a non-standard chi3 angle 

768 ... }, 

769 ... ), 

770 ... } 

771 ... 

772 >>> # loading this will raise an Exception, because the bonds in CSR already exist 

773 >>> traj.load_custom_topology(custom_aas) # doctest: +IGNORE_EXCEPTION_DETAIL, +ELLIPSIS, +NORMALIZE_WHITESPACE 

774 Traceback (most recent call last): 

775 ... 

776 Exception: Bond between ALL5-C and CSR6-N already exists. Consider using the key 'optional_bonds' to not raise an Exception on already existing bonds. 

777 >>> # If we rename the "bonds" section in "CSR" to "optional_bonds" it will work 

778 >>> custom_aas["CSR"][1]["optional_bonds"] = custom_aas["CSR"][1].pop("bonds") 

779 >>> traj.load_custom_topology(custom_aas) 

780 >>> sidechains = em.features.SideChainDihedrals(traj).describe() 

781 >>> "SIDECHDIH CHI2 RESID OAS: 4 CHAIN 0" in sidechains 

782 True 

783 >>> "SIDECHDIH CHI3 RESID OAS: 4 CHAIN 0" in sidechains 

784 True 

785 >>> "SIDECHDIH CHI1 RESID CSR: 6 CHAIN 0" in sidechains 

786 True 

787 >>> "SIDECHDIH CHI2 RESID TPO: 1 CHAIN 0" in sidechains 

788 True 

789 >>> "SIDECHDIH CHI3 RESID TPO: 1 CHAIN 0" in sidechains 

790 True 

791 

792 """ 

793 

794 def __init__( 

795 self, 

796 *new_residues: NewResidue, 

797 traj: Optional[SingleTraj] = None, 

798 ) -> None: 

799 """Instantiate the CustomTopology. 

800 

801 Args: 

802 *residues (NewResidue): An arbitrary amount of instances of `NewResidue`. 

803 traj (Optional[SingleTraj]): An instance of `SingleTraj` can be 

804 provided to allow fixing that traj's topology. 

805 

806 """ 

807 self.residues = set([*new_residues]) 

808 self.traj = traj 

809 self._parsed = False 

810 self.amino_acid_codes = _AMINO_ACID_CODES 

811 

812 @property 

813 def top(self) -> md.Topology: 

814 """md.Topology: The fixed topology.""" 

815 if not self._parsed: 

816 top = self.add_bonds() 

817 self.add_amino_acid_codes() 

818 self._parsed = True 

819 self._top = top 

820 return self._top 

821 

822 @property 

823 def new_residues(self) -> list[NewResidue]: 

824 """list[NewResidue]: A list of all new residues.""" 

825 return list(self.residues) 

826 

827 def add_new_residue(self, new_residue: NewResidue) -> None: 

828 """Adds an instance of `NewResidue` to the reisdues of this `CustomTopology`. 

829 

830 Args: 

831 new_residue (NewResidue): An instance of `NewResidue`. 

832 

833 """ 

834 self.residues += new_residue 

835 

836 def __hash__(self) -> int: 

837 if self.residues: 

838 return reduce(xor, [hash(r) for r in list(self.residues)]) 

839 else: 

840 return hash(None) 

841 

842 def __add__(self, other: CustomTopology) -> CustomTopology: 

843 return CustomTopology(*(self.residues | other.residues)) 

844 

    def __eq__(self, other: CustomTopology) -> bool:
        """Equality is defined by the set of custom residues only; `traj` is ignored."""
        return self.residues == other.residues

847 

848 def add_bonds(self) -> md.Topology: 

849 """Adds and deletes bonds specified in the custom topology. 

850 

851 Returns: 

852 md.Topology: The new topology. 

853 

854 """ 

855 # Encodermap imports 

856 from encodermap.misc.misc import _validate_uri 

857 

858 top = self.traj._get_raw_top() 

859 for residue in self.residues: 

860 # search for this residue in the protein 

861 if residue.idx is not None: 

862 top_residue = top.residue(residue.idx) 

863 assert top_residue.name == residue.name, ( 

864 f"There is no residue with the name {residue.name} " 

865 f"and the index {residue.idx} in the topology." 

866 f"Residue at index {residue.idx} has the name {top_residue.name}." 

867 ) 

868 top_residues = [top_residue] 

869 elif residue.resSeq is not None: 

870 top_residues = [r for r in top.residues if r.resSeq == residue.resSeq] 

871 else: 

872 top_residues = [r for r in top.residues if r.name == residue.name] 

873 

874 # add the bonds of this residue 

875 for b in residue.bonds: 

876 current_bonds = [(a1, a2) for a1, a2 in top.bonds] 

877 a1 = b.atom1 

878 a2 = b.atom2 

879 action = b.type 

880 assert isinstance(a1, (str, int)) 

881 assert isinstance(a2, (str, int)) 

882 

883 _a1 = deepcopy(a1) 

884 _a2 = deepcopy(a2) 

885 

886 # iterate over the found residues 

887 for r in top_residues: 

888 index = r.index 

889 

890 # find atom 1 by str 

891 if isinstance(_a1, str): 

892 if _a1.startswith("-"): 

893 a1 = _a1.lstrip("-") 

894 a1_r = top.residue(index - 1) 

895 a1 = [a for a in a1_r.atoms if a.name == a1] 

896 elif _a1.startswith("+"): 

897 a1 = _a1.lstrip("+") 

898 a1_r = top.residue(index + 1) 

899 a1 = [a for a in a1_r.atoms if a.name == a1] 

900 else: 

901 a1 = [a for a in r.atoms if a.name == _a1] 

902 if len(a1) == 0: 

903 if "optional" in action: 

904 continue 

905 raise Exception( 

906 f"Atom {_a1} not part of residue {r}: {a1=}, {b=}" 

907 ) 

908 elif len(a1) > 1: 

909 raise Exception( 

910 f"Multiple atoms with same name in residue {r}: {a1=}" 

911 ) 

912 a1 = a1[0] 

913 

914 # find by int 

915 elif isinstance(_a1, int): 

916 a1 = top.atom(_a1) 

917 else: 

918 raise Exception( 

919 f"Wrong type: {type(_a1)=}. Needs to be str or int." 

920 ) 

921 

922 # find atom 2 by str 

923 if isinstance(_a2, str): 

924 if _a2.startswith("-"): 

925 a2 = _a2.lstrip("-") 

926 a2_r = top.residue(index - 1) 

927 a2 = [a for a in a2_r.atoms if a.name == a2] 

928 elif _a2.startswith("+"): 

929 a2 = _a2.lstrip("+") 

930 try: 

931 a2_r = top.residue(index + 1) 

932 except IndexError: 

933 continue 

934 a2 = [a for a in a2_r.atoms if a.name == a2] 

935 else: 

936 a2 = [a for a in r.atoms if a.name == _a2] 

937 assert isinstance(a2, list) 

938 if len(a2) == 0: 

939 if "optional" in action: 

940 continue 

941 raise Exception( 

942 f"Atom {_a2} not part of residue {r}: {a2=}, {b=}" 

943 ) 

944 elif len(a2) > 1: 

945 raise Exception( 

946 f"Multiple atoms with same name in residue {r}: {a2=}" 

947 ) 

948 a2 = a2[0] 

949 

950 # find atom 2 by int 

951 elif isinstance(_a2, int): 

952 a2 = top.atom(_a2) 

953 else: 

954 raise Exception( 

955 f"Wrong type: {type(_a2)=}. Needs to be str or int." 

956 ) 

957 

958 # decide what to do with this bond 

959 if action == "add" or action == "optional": 

960 if a1.residue.chain.index != a2.residue.chain.index: 

961 self.combine_chains( 

962 a1.residue.chain.index, a2.residue.chain.index 

963 ) 

964 if (a1, a2) not in current_bonds and ( 

965 a2, 

966 a1, 

967 ) not in current_bonds: 

968 top.add_bond(a1, a2) 

969 assert (a1, a2) in [(n1, n2) for n1, n2 in top.bonds] or ( 

970 a2, 

971 a1, 

972 ) in [(n1, n2) for n1, n2 in top.bonds] 

973 else: 

974 if action != "optional": 

975 raise Exception( 

976 f"Bond between {a1} and {a2} already exists. " 

977 f"Consider using the key 'optional_bonds' to not " 

978 f"raise an Exception on already existing bonds." 

979 ) 

980 elif action == "delete" or action == "optional_delete": 

981 if (a1, a2) in current_bonds or (a2, a1) in current_bonds: 

982 top = _delete_bond(top, (a1, a2)) 

983 else: 

984 if action == "delete": 

985 raise Exception( 

986 f"Bond between {a1} and {a2} was not present in topology. " 

987 f"Consider using the key 'optional_delete_bonds' to not " 

988 f"raise an Exception on bonds that don't exist in the " 

989 f"first place." 

990 ) 

991 else: 

992 raise Exception( 

993 f"Bond action must be 'add', 'optional', 'delete', or " 

994 f"'optional_delete'. I got: {action}." 

995 ) 

996 return top 

997 

    def combine_chains(self, chain_id1: int, chain_id2: int) -> None:
        """Function to combine two chains into one.

        Called from `add_bonds` when a requested bond spans two chains.

        Args:
            chain_id1 (int): The 0-based index of chain 1.
            chain_id2 (int): The 0-based index of chain 2.

        Raises:
            NotImplementedError: Always; cross-chain bonds are not yet supported.

        """
        raise NotImplementedError(
            f"Currently not able to make a new bond across two different chains. "
            f"Should be easy though, just make the atoms to bonds and pandas dataframe "
            f"and manually combine and renumber chains. But don't have the time to do it now."
        )

1011 

1012 def _atom_dict( 

1013 self, top: Optional[md.Topology] = None 

1014 ) -> dict[int, dict[int, dict[str, int]]]: 

1015 """A dictionary to lookup indices by atom name, residue_id and chain index. 

1016 

1017 The dictionary is nested as such:: 

1018 atom_dict[chain_index][residue_index][atom_name] = atom_index 

1019 

1020 """ 

1021 if top is None: 

1022 return _construct_atom_dict(self.top) 

1023 else: 

1024 return _construct_atom_dict(top) 

1025 

1026 def get_single_residue_atom_ids( 

1027 self, 

1028 atom_names: list[str], 

1029 r: NewResidue, 

1030 key_error_ok: bool = False, 

1031 ) -> np.ndarray: 

1032 """Gives the 0-based atom ids of a single residue. 

1033 

1034 Args: 

1035 atom_names (list[str]): The names of the atoms. ie. ['N' ,'CA', 'C', '+N'] 

1036 r (NewResidue): An instance of `NewResidue`. 

1037 key_error_ok (bool): Whether a key error when querying `self._atom_dict` 

1038 raises an error or returns an empty np.ndarray. 

1039 

1040 Returns: 

1041 np.ndarray: An integer array with the ids of the requested atoms. 

1042 

1043 """ 

1044 new_defs = [] 

1045 offsets = parse_offsets(atom_names) 

1046 atom_names = _strip_offsets(atom_names) 

1047 for chain in self.top.chains: 

1048 cid = chain.index 

1049 for residue in chain.residues: 

1050 rid = residue.index 

1051 if residue.name == r.name: 

1052 if r.resSeq is not None: 

1053 if residue.resSeq != r.resSeq: 

1054 continue 

1055 if r.idx is not None: 

1056 if rid != r.idx: 

1057 continue 

1058 try: 

1059 new_def = [ 

1060 self._atom_dict()[cid][rid + offset][atom] 

1061 for atom, offset in zip(atom_names, offsets) 

1062 ] 

1063 except KeyError as e: 

1064 if key_error_ok: 

1065 return np.array([]) 

1066 else: 

1067 raise e 

1068 new_defs.append(new_def) 

1069 return np.asarray(new_defs) 

1070 

    def backbone_sequence(
        self,
        atom_names: list[str],
        type: Literal["PHI", "PSI", "OMEGA"],
    ) -> np.ndarray:
        """Searches for a sequence along the backbone.

        Starts from the standard-atom-name matches found by `_atom_sequence`
        and then applies the custom residues: warns about undefined dihedrals,
        checks custom definitions against the found ones, and removes
        dihedrals marked for deletion.

        Args:
            atom_names (list[str]): The names of the atoms. Can use +/- to
                mark atoms in previous or following residues.
            type (Literal["PHI", "PSI", "OMEGA"]): The type of the dihedral
                sequence.

        Returns:
            np.ndarray: The integer indices of the requested atoms.

        """
        top = self.top.copy()
        # respect an atom-subset selection on the attached trajectory
        if hasattr(self, "traj"):
            if self.traj._atom_indices is not None:
                top = top.subset(self.traj._atom_indices)
        default_indices, default_sequence = _atom_sequence(top, atom_names)
        for r in self.residues:
            if r.ignore:
                continue
            d = r.get_dihedral_by_type(type)
            if d is None:
                # residue known, but this dihedral type not defined: warn and
                # keep the standard-name match (if any)
                msg = (
                    f"Your custom topology for residue name={r.name} resSeq={r.resSeq} "
                    f"index={r.idx} does not define atoms for the dihedral {type}. "
                    f"If this dihedral consists of standard atom names, it "
                    f"will be considered for dihedral calculations. If this "
                    f"dihedral should not be present in your custom topology you "
                    f"need to explicitly delete it by adding "
                    f"'{type.upper()}_ATOMS': 'delete' to your custom_topology. "
                    f"If you want this dihedral to be present in your topology, "
                    f"you can ignore this warning."
                )
                warnings.warn(msg, stacklevel=2)
                continue
            if not d.delete:
                if d.new_atoms_def != atom_names:
                    # backbone dihedrals with non-standard atoms are not
                    # supported yet
                    new_defs = self.get_single_residue_atom_ids(d.new_atoms_def, r)
                    raise NotImplementedError(
                        f"Add this backbone angle with non-standard atoms to the indices"
                    )
                else:
                    # sanity check: every custom definition must already be in
                    # the default (standard-name) sequence
                    new_defs = self.get_single_residue_atom_ids(d.new_atoms_def, r)
                    assert all(
                        [
                            np.any(
                                np.all(
                                    new_def == default_sequence,
                                    axis=1,
                                )
                            )
                            for new_def in new_defs
                        ]
                    )
            if d.delete:
                # locate the standard definition of this dihedral and drop it
                if type == "PSI":
                    atoms = ["N", "CA", "C", "+N"]
                elif type == "PHI":
                    atoms = ["-C", "N", "CA", "C"]
                elif type == "OMEGA":
                    atoms = ["CA", "C", "+N", "+CA"]
                else:
                    raise Exception(
                        f"The dihedral angle type {type} is not recognized."
                    )
                delete_defs = self.get_single_residue_atom_ids(atoms, r)
                delete_idx = np.all(default_sequence == delete_defs, axis=1)
                if not delete_idx.any():
                    warnings.warn(
                        f"Your custom topology requested the dihedral {d.type} "
                        f"of the residue name={r.name} index={r.idx} resSeq={r.resSeq} "
                        f"to be deleted. This dihedral was not found when searching "
                        f"for standard dihedral names. It is not present in the "
                        f"custom topology and it would also not be present if you "
                        f"haven't specified it to be deleted. If this is unexpected "
                        f"to you, you need to reevaluate your topology."
                    )
                delete_idx = np.where(delete_idx)[0]
                assert len(delete_idx) == 1
                delete_idx = delete_idx[0]
                default_sequence = np.delete(default_sequence, delete_idx, 0)
        return default_sequence

1158 

    def sidechain_sequence(
        self,
        atom_names: list[str],
        type: Literal["CHI1", "CHI2", "CHI3", "CHI4", "CHI5"],
        top: Optional[md.Topology] = None,
    ) -> np.ndarray:
        """Searches for a sequence along the sidechains.

        Args:
            atom_names (list[str]): The names of the atoms. Can use +/- to
                mark atoms in previous or following residues.
            type (Literal["CHI1", "CHI2", "CHI3", "CHI4", "CHI5"]: The type of the dihedral
                sequence.
            top (Optional[md.Topology]): Can be used to overwrite the topology in
                self.traj.

        Returns:
            np.ndarray: The integer indices of the requested atoms.

        """
        delete_defs = None
        if top is None:
            top = self.top.copy()
            if hasattr(self, "traj"):
                if self.traj._atom_indices is not None:
                    top = top.subset(self.traj._atom_indices)
        for r in self.residues:
            d = r.get_dihedral_by_type(type)
            if d is None:
                continue
            # translate the OT alias into mdtraj's OXT terminal-oxygen name
            if "OT" in d.new_atoms_def:
                atoms_def = d.new_atoms_def.copy()
                atoms_def[atoms_def.index("OT")] = "OXT"
            else:
                atoms_def = d.new_atoms_def.copy()
            # NOTE(review): this appends to the caller-supplied `atom_names`
            # list, which may be a module-level *_ATOMS constant -- confirm
            # this in-place mutation is intended.
            if atoms_def not in atom_names:
                atom_names.append(atoms_def)
            if d.delete:
                if delete_defs is None:
                    delete_defs = []
                # an empty definition means: delete the standard definition(s)
                if (atoms_def := d.new_atoms_def) == []:
                    atoms_def = globals()[f"{d.type}_ATOMS"]
                delete = [
                    self.get_single_residue_atom_ids(atoms, r, key_error_ok=True)
                    for atoms in atoms_def
                ]
                delete = list(filter(lambda x: x.size > 0, delete))
                delete = np.vstack(delete)
                delete_defs.append(delete)
        if delete_defs is not None:
            delete_defs = np.vstack(delete_defs)
        return self._indices_chi(atom_names, top, delete=delete_defs)

1211 

    def _indices_chi(
        self,
        chi_atoms: Sequence[list[str]],
        top: Optional[md.Topology] = None,
        delete: Optional[Sequence[np.ndarray]] = None,
    ) -> np.ndarray:
        """Collects chi-dihedral atom indices for several atom-name definitions.

        Args:
            chi_atoms (Sequence[list[str]]): One 4-atom-name list per possible
                chi definition. Empty lists are skipped.
            top (Optional[md.Topology]): Topology to search. Defaults to a
                copy of `self.top`.
            delete (Optional[Sequence[np.ndarray]]): Rows (4-tuples of atom
                indices) to remove from the result.

        Returns:
            np.ndarray: An integer array of shape (n_dihedrals, 4).

        """
        chi_atoms = list(filter(lambda x: x != [], chi_atoms))
        if top is None:
            top = self.top.copy()
        rids, indices = zip(
            *(self._atom_sequence(atoms, top=top) for atoms in chi_atoms)
        )
        # order all found dihedrals by the residue they belong to
        id_sort = np.argsort(np.concatenate(rids))
        if not any(x.size for x in indices):
            return np.empty(shape=(0, 4), dtype=int)
        indices = np.vstack([x for x in indices if x.size])[id_sort]
        if delete is not None:
            delete_ids = []
            for row in delete:
                delete_ids.append(np.where((indices == row).all(1)))
            delete_ids = np.array(delete_ids)
            indices = np.delete(indices, delete_ids, 0)
        return indices

1235 

    def _atom_sequence(
        self,
        atom_names: list[str],
        residue_offsets: Optional[list[int]] = None,
        top: Optional[md.Topology] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Finds all occurrences of an atom-name sequence in the topology.

        Args:
            atom_names (list[str]): Atom names, optionally prefixed with +/-
                to reference the following/previous residue.
            residue_offsets (Optional[list[int]]): Pre-parsed per-atom residue
                offsets. If None, they are parsed from `atom_names`.
            top (Optional[md.Topology]): Topology to search. Defaults to a
                copy of `self.top`.

        Returns:
            tuple[np.ndarray, np.ndarray]: The residue indices where the
                sequence was found and the (n_found, len(atom_names)) array
                of atom indices.

        """
        if top is None:
            top = self.top.copy()
        if residue_offsets is None:
            residue_offsets = parse_offsets(atom_names)
        atom_names = _strip_offsets(atom_names)
        atom_dict = self._atom_dict(top)

        atom_indices = []
        found_residue_ids = []
        atoms_and_offsets = list(zip(atom_names, residue_offsets))
        for chain in top.chains:
            cid = chain.index
            for residue in chain.residues:
                rid = residue.index
                # Check that desired residue_IDs are in dict
                if all([rid + offset in atom_dict[cid] for offset in residue_offsets]):
                    # Check that we find all atom names in dict
                    if all(
                        [
                            atom in atom_dict[cid][rid + offset]
                            for atom, offset in atoms_and_offsets
                        ]
                    ):
                        # Lookup desired atom indices and add to list
                        atom_indices.append(
                            [
                                atom_dict[cid][rid + offset][atom]
                                for atom, offset in atoms_and_offsets
                            ]
                        )
                        found_residue_ids.append(rid)

        atom_indices = np.array(atom_indices)
        found_residue_ids = np.array(found_residue_ids)

        if len(atom_indices) == 0:
            # keep a well-defined (0, 4) shape so callers can vstack safely
            atom_indices = np.empty(shape=(0, 4), dtype=int)

        return found_residue_ids, atom_indices

1281 

1282 def atom_sequence( 

1283 self, 

1284 type: Literal["PHI", "PSI", "OMEGA", "CHI1", "CHI2", "CHI3", "CHI4", "CHI5"], 

1285 ) -> tuple[np.ndarray, np.ndarray]: 

1286 """Returns either backbone or sidechain indices in a useful order. 

1287 

1288 Args: 

1289 type (Literal["OMEGA", "PHI", "PSI", "CHI1", "CHI2", "CHI3", "CHI4", "CHI5"]): 

1290 The angle, that is looked for. 

1291 

1292 Returns: 

1293 tuple[np.ndarray, np.ndarray]: A tuple containing two numpy arrays: 

1294 

1295 

1296 """ 

1297 atom_names = globals()[f"{type}_ATOMS"] 

1298 if isinstance(atom_names[0], str): 

1299 return self.backbone_sequence(atom_names, type) 

1300 else: 

1301 return self.sidechain_sequence(atom_names, type) 

1302 

1303 def sidechain_indices_by_residue( 

1304 self, 

1305 ) -> Generator[md.core.topology.Residue, np.ndarray]: 

1306 for residue in self.top.residues: 

1307 indices = [] 

1308 for i in range(1, 6): 

1309 atoms = np.array([a.index for a in residue.atoms]) 

1310 indices.extend( 

1311 self.sidechain_sequence( 

1312 atom_names=globals()[f"CHI{i}_ATOMS"], 

1313 type=f"CHI{i}", 

1314 top=self.top.subset(atoms), 

1315 ) 

1316 ) 

1317 indices = np.sort(np.unique(indices)) + atoms.min() 

1318 yield residue, indices 

1319 

1320 def indices_chi1(self) -> np.ndarray: 

1321 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1322 return self.atom_sequence(type="CHI1") 

1323 

1324 def indices_chi2(self) -> np.ndarray: 

1325 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1326 return self.atom_sequence(type="CHI2") 

1327 

1328 def indices_chi3(self) -> np.ndarray: 

1329 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1330 return self.atom_sequence(type="CHI3") 

1331 

1332 def indices_chi4(self) -> np.ndarray: 

1333 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1334 return self.atom_sequence(type="CHI4") 

1335 

1336 def indices_chi5(self) -> np.ndarray: 

1337 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1338 return self.atom_sequence(type="CHI5") 

1339 

1340 def indices_psi(self) -> np.ndarray: 

1341 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1342 return self.atom_sequence(type="PSI") 

1343 

1344 def indices_phi(self) -> np.ndarray: 

1345 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1346 return self.atom_sequence(type="PHI") 

1347 

1348 def indices_omega(self) -> np.ndarray: 

1349 """Returns the requested indices as a (n_dihedrals, 4)-shaped numpy array.""" 

1350 return self.atom_sequence(type="OMEGA") 

1351 

1352 def add_amino_acid_codes(self) -> None: 

1353 self.amino_acid_codes |= {r.name: r.one_letter_code for r in self.new_residues} 

1354 

1355 def _str_summary(self): 

1356 out = [] 

1357 for r in self.new_residues: 

1358 out.append(str(r)) 

1359 if len(out) > 0: 

1360 return "\n\n".join(out) 

1361 return "CustomTopology without any custom residues." 

1362 

    def __str__(self):
        """str: A summary of all custom residues (delegates to `_str_summary`)."""
        return self._str_summary()

1365 

    def __bool__(self):
        """bool: True if at least one custom residue is defined."""
        return bool(self.residues)

1368 

1369 def to_json(self) -> str: 

1370 # Standard Library Imports 

1371 import json 

1372 

1373 return json.dumps(self.to_dict()) 

1374 

1375 def to_hdf_file( 

1376 self, 

1377 fname: Union[Path, str], 

1378 ) -> None: 

1379 if hasattr(self, "traj"): 

1380 key = self.traj.traj_num 

1381 else: 

1382 key = None 

1383 if key is None: 

1384 key = "_custom_top" 

1385 else: 

1386 key = f"_custom_top_{key}" 

1387 with h5py.File(fname, mode="a") as file: 

1388 file.attrs[key] = str(self.to_dict()) 

1389 

    def to_dict(self) -> CustomAAsDict:
        """Represents this CustomTopology as a dict (inverse of `from_dict`).

        Returns:
            CustomAAsDict: Keys are residue names, with resSeq appended or
                "-idx" suffixed, and wrapped into a (common_str, name) tuple
                when a common_str is set. Values are None for ignored
                residues, otherwise a (one_letter_code, definitions) tuple.

        """
        out = {}
        for r in self.new_residues:
            key = r.name
            if r.resSeq is not None:
                assert r.idx is None, f"Can't have resSeq and idx be not None."
                key = f"{key}{r.resSeq}"
            if r.idx is not None:
                assert r.resSeq is None, f"Can't have resSeq and idx be not None."
                key = f"{key}-{r.idx}"
            if r.common_str:
                key = (r.common_str, key)

            if r.ignore:
                out[key] = None
                continue

            one_letter_code = r.one_letter_code
            def_dict = {}
            for b in r.bonds:
                # plain "add" bonds serialize under the "bonds" key; all other
                # actions get their "<action>_bonds" key
                if b.type == "add":
                    btype = "bonds"
                else:
                    btype = f"{b.type}_bonds"
                def_dict.setdefault(btype, []).append((b.atom1, b.atom2))
            for d in r.dihedrals:
                if not d.delete:
                    def_dict[d.type] = [d.atom1, d.atom2, d.atom3, d.atom4]
                else:
                    def_dict["not_" + d.type] = None
            out[key] = (one_letter_code, def_dict)
        return out

1422 

1423 def to_yaml(self) -> None: 

1424 data = self.to_dict() 

1425 return yaml.dump(data) 

1426 

1427 @classmethod 

1428 def from_hdf5_file( 

1429 cls, 

1430 fname: Union[Path, str], 

1431 traj: Optional[SingleTraj] = None, 

1432 ): 

1433 # Standard Library Imports 

1434 import ast 

1435 

1436 if traj is None: 

1437 key = None 

1438 else: 

1439 key = traj.traj_num 

1440 

1441 if key is None: 

1442 key = "_custom_top" 

1443 else: 

1444 key = f"_custom_top_{key}" 

1445 with h5py.File(fname, mode="r") as file: 

1446 dic = ast.literal_eval(file.attrs[key]) 

1447 return cls.from_dict(dic, traj=traj) 

1448 

1449 @classmethod 

1450 def from_yaml(cls, path: Union[str, Path], traj: Optional["SingleTraj"] = None): 

1451 with open(path) as f: 

1452 data = yaml.load(f, Loader=yaml.FullLoader) 

1453 return cls.from_dict(data, traj) 

1454 

1455 @classmethod 

1456 def from_json(cls, json_str: str, traj: Optional["SingleTraj"] = None): 

1457 """The same as `from_dict`, but using a json str.""" 

1458 # Standard Library Imports 

1459 import json 

1460 

1461 return cls.from_dict(json.loads(json_str), traj) 

1462 

    @classmethod
    def from_dict(cls, custom_aas: CustomAAsDict, traj: Optional[SingleTraj] = None):
        """Instantiate the class from a dictionary.

        Args:
            custom_aas (CustomAAsDict):
                Custom AAs defined by a dict with the following properties:
                The keys are the residue names encountered in this traj. The
                values to the keys can be one of three types:
                * None: if a key: None pair is supplied, this just adds the
                    residue to the recognized residues. Nothing will be done
                    with it.
                * str: If a key: str pair is supplied, it is expected that the
                    string matches one of the one-letter amino-acid codes.
                    If your new residue is based on Lysine and you named it
                    LYQ, you need to supply: {"LYQ": "K"}
                * tuple[str, dict]: If your residue has nonstandard side-chain
                    angles (i.e. due to phosphorylation), you can supply a tuple
                    of the one-letter amino-acid code and a dict which defines
                    the sidechain angles like so:
                    {"THR": ("T", {"CHI2": ["CA", "CB", "CG", "P"]})}
                    In this example, the standard amino acid threonine was
                    phosphorylated. The chi2 angle was added.
                    If you want to add custom bonds you can add the "bond" key
                    to the dict and give it either atom names or atom indices of
                    other atoms like so:
                    {"LYQ": ("K", {"bonds": [("N", "CA"), ("N", "H"), ...], "CHI1": ["N", "CA", "CB", "CG"]}).
                * tuple[str, str, dict]: In this case, the first string should
                    be the name of the amino-acid, the second string should
                    be a common_str, that is in `self.common_str`. That way,
                    the different topologies in this `TrajEnsemble` can dynamically
                    use different custom_aas.
            traj (Optional[SingleTraj]): The traj whose common_str is matched
                against (common_str, name) keys.

        """
        new_residues = []
        for resname, value in custom_aas.items():
            # a tuple key carries (common_str, residue_name)
            if isinstance(resname, tuple):
                common_str = resname[0]
                resname = resname[1]
            else:
                common_str = None

            # NOTE(review): if `traj` is None while a common_str key is used,
            # `traj.common_str` raises AttributeError -- confirm callers
            # always pass a traj together with common_str keys.
            if common_str is not None and traj.common_str != "":
                if common_str != traj.common_str:
                    continue

            if value is not None:
                assert (
                    len(value) == 2
                ), f"The custom_aas dict needs a tuple[str, dict] as its values."

            # "NAME-3" selects by 0-based index, "NAME3" by resSeq,
            # plain "NAME" matches all residues of that name
            if "-" in resname:
                idx = int(re.findall(r"\d+", resname)[-1])
                resSeq = None
                resname = resname.replace(f"-{idx}", "")
            elif any(re.findall(r"\d+", resname)) and "-" not in resname:
                idx = None
                resSeq = int(re.findall(r"\d+", resname)[-1])
                resname = resname.replace(str(resSeq), "")
            else:
                idx = None
                resSeq = None

            if value is None:
                residue = NewResidue(
                    name=resname,
                    common_str=common_str,
                    resSeq=resSeq,
                    idx=idx,
                    ignore=True,
                )
                new_residues.append(residue)
                continue

            if value[1] is None:
                residue = NewResidue(
                    name=resname,
                    common_str=common_str,
                    resSeq=resSeq,
                    idx=idx,
                    ignore=True,
                )
                new_residues.append(residue)
                continue

            one_letter_code = value[0]
            value = value[1]
            residue = NewResidue(
                name=resname,
                one_letter_code=one_letter_code,
                idx=idx,
                resSeq=resSeq,
            )
            residue.parse_bonds_and_dihedrals(value)
            new_residues.append(residue)
        return cls(*new_residues, traj=traj)

1559 

1560 

1561################################################################################ 

1562# Utils 

1563################################################################################ 

1564 

1565 

def flatten(container):
    """Recursively yield all items from *container*, flattening nested lists
    and tuples (other iterables, e.g. strings, are yielded as-is)."""
    for item in container:
        if isinstance(item, (list, tuple)):
            yield from flatten(item)
        else:
            yield item

1573 

1574 

def trajs_combine_attrs(
    args: Sequence[dict[str, Union[str, Any]]],
    context: Optional[xr.Context] = None,  # noqa: U100
) -> dict[str, Any]:
    """Used for combining attributes and checking, whether CVs stay in the same unit system.

    Args:
        args (Sequence[dict[str, Any]]): A sequence of dicts to combine.
        context (Optional[xr.Context]): An xarray.Context object.
            Currently not used in the function, but xarray passes it nonetheless

    Returns:
        dict[str, Any]: The combined dict.

    """
    args = [a for a in args if a != {}]
    if len(args) == 1:
        return args[0]

    # singular attr names that get pluralized while collecting
    concat = {
        "full_path": "full_paths",
        "topology_file": "topology_files",
        "feature_axis": "feature_axes",
    }
    _inv_concat = {plural: singular for singular, plural in concat.items()}

    # collect every value under its (possibly pluralized) key
    out = {}
    for attrs in args:
        for key, value in attrs.items():
            key = concat.get(key, key)
            bucket = out.setdefault(key, [])
            if isinstance(value, list):
                bucket.extend(value)
            else:
                bucket.append(value)

    # deduplicate
    for key in out:
        out[key] = list(set(out[key]))

    if "angle_units" in out:
        assert len(out["angle_units"]) == 1, (
            f"Can't combine datasets with inhomogeneous angle types. The datasets "
            f"you tried to combine had the angle types {out['angle_units']}."
        )

    # collapse singletons (mapping pluralized keys back to singular) and
    # collapse lists whose entries are all identical
    for key, values in out.copy().items():
        if len(values) == 1 and key in concat:
            out[concat[key]] = out.pop(key)[0]
        elif len(values) == 1 and key in _inv_concat:
            out[_inv_concat[key]] = out.pop(key)[0]
        elif all(values[0] == v for v in values[1:]):
            out[key] = values[0]

    if "feature_axis" in out:
        if len(out["feature_axis"]) > 1 and isinstance(out["feature_axis"], list):
            raise Exception(
                f"Could not combine xarray Dataset attributes:\n"
                f"{out['feature_axis']}\n\n{args=}"
            )

    return out

1636 

1637 

def np_to_xr(
    data: np.ndarray,
    traj: SingleTraj,
    attr_name: Optional[str] = None,
    deg: Optional[bool] = None,
    labels: Optional[list[str]] = None,
    filename: Optional[Union[str, Path]] = None,
) -> xr.DataArray:
    """Converts a numpy.ndarray to a xarray.DataArray.

    Can use some additional labels and attributes to customize the DataArray.

    Args:
        data (np.ndarray): The data to put into the xarray.DataArray. It is
            assumed that this array is of shape (n_frames, n_features), where
            n_frames is the number of frames in `traj` and n_features can be
            any positive integer.
        traj (SingleTraj): An instance of `SingleTraj`.
        attr_name (Optional[str]): The name of the feature, that will be used
            to identify this feature (e.g. 'dihedral_angles', 'my_distance').
            Can be completely custom. If None is provided, the name is derived
            from the stem of `filename`; if `filename` is also None, an
            exception is raised. Defaults to None.
        deg (Optional[bool]): When True, the input is assumed to use degree.
            When False, the input is assumed in radians. This can be important
            if you want to combine features (that are not allowed for angle
            features with different units). If None, the input is assumed to be
            not angular (distances, absolute positions). Defaults to None.
        labels (Optional[list[str]]): A list of str, which contain labels for the
            feature. If provided needs to be of `len(labels) == data.shape[1]`.
            If a single str is provided, per-column labels of the form
            '{ATTR_NAME} {LABEL} {i}' for i in range(data.shape[1]) are
            generated. If None, labeling is left to
            `construct_xarray_from_numpy`. Defaults to None.
        filename (Optional[Union[str, Path]]): If the data is loaded from a file,
            and `attr_name` is None, then the file's stem is used as the
            attribute name.

    Returns:
        xr.DataArray: The DataArray.

    """
    # Fall back to the filename's stem when no explicit feature name is given.
    if attr_name is None:
        if filename is None:
            msg = f"Please also provide an `attr_name` under which to save the CV."
            raise Exception(msg)
        attr_name = Path(filename).stem

    if labels is not None:
        if isinstance(labels, str):
            # Expand a single label string into one label per feature column.
            labels = [
                f"{attr_name.upper()} {labels.upper()} {i}"
                for i in range(data.shape[1])
            ]
        elif all([isinstance(l, str) for l in labels]) and len(labels) == data.shape[1]:
            pass
        else:
            raise Exception(
                f"'labels' is either not a list of str or does not have the "
                f"same number of datapoints as {data.shape=}, {labels[:5]=} "
                f"{len(labels)=}."
            )
    # Prepend the length-1 trajectory axis expected by the xarray layout,
    # so the frame axis of the input becomes axis 1 below.
    data = np.expand_dims(data, axis=0)
    if np.any(np.isnan(data)):
        # if some nans are found along frame remove them
        if data.ndim == 2:
            # input was 1D per-frame data; drop frames that are all-NaN
            index = np.isnan(data).all(axis=0)
            if np.any(index):
                print(
                    f"The 2D `np.ndarray` provided for the trajectory {traj} has "
                    f"some frames ({np.count_nonzero(index)} of {len(index)} "
                    f"frames in axis 0) that are full of nans. These are "
                    f"automatically dropped from the array."
                )
                data = data[:, ~index]
        if data.ndim == 3:
            # (traj, frame, feature): drop frames whose feature row is all NaN
            idx = np.isnan(data).all(axis=2)[0]
            if np.any(idx):
                print(
                    f"The 3D `np.ndarray` provided for the trajectory {traj} has "
                    f"some frames ({np.count_nonzero(idx)} of {len(idx)} "
                    f"frames in axis 1) that are full of nans. These are "
                    f"automatically dropped from the array. For 3D arrays, axis "
                    f"0 represents the traj axis, which should always have "
                    f"length 1 (test: {data.shape[0]=})"
                )
                data = data[:, ~idx]
        if data.ndim == 4:
            # (traj, frame, atom, xyz): drop frames containing any NaN coordinate
            idx = np.isnan(data).any(axis=2)[0].any(axis=1)
            if np.any(idx):
                print(
                    f"The 4D `np.ndarray` provided for the trajectory {traj} has "
                    f"some frames ({np.count_nonzero(idx)} of {len(idx)} "
                    f"frames in axis 1) that are full of nans. These are "
                    f"automatically dropped from the array. For 4D arrays, axis "
                    f"0 represents the traj axis, which should always have "
                    f"length 1 (test: {data.shape[0]=}) and axis 3 represents the "
                    f"cartesian coordinate axis (x, y, z) so this axis should "
                    f"always have length 3 (test: {data.shape[3]=})."
                )
                data = data[:, ~idx]
    da = construct_xarray_from_numpy(
        traj, data, attr_name, deg, labels, check_n_frames=True
    )
    # every CV DataArray carries at least (traj, frame, feature) dims
    assert len(da.dims) >= 3, f"{da=}"
    return da

1742 

1743 

def load_CV_from_string_or_path(
    file_or_feature: str,
    traj: SingleTraj,
    attr_name: Optional[str] = None,
    cols: Optional[Union[int, list[int]]] = None,
    deg: Optional[bool] = None,
    labels: Optional[Union[list[str], str]] = None,
) -> xr.Dataset:
    """Loads CV data from a string. That string can either identify a feature,
    or point to a file.

    Args:
        file_or_feature (str): The file or feature to load. If 'all' is
            provided, all "standard" features are loaded. But a feature name
            like 'sidechain_angle' can also be provided. If a file with
            the .txt or .npy extension is provided, the data in that file is used.
        traj (SingleTraj): The trajectory, that is used to load the features.
        attr_name (Union[None, str], optional): The name under which the CV
            should be found in the class. Is needed, if a raw numpy array is
            passed, otherwise the name will be generated from the filename
            (if data == str), the DataArray.name (if data == xarray.DataArray),
            or the feature name.
        cols (Union[list, None], optional): A list specifying the columns to
            use for the high-dimensional data. If your highD data contains
            (x,y,z,...)-errors or has an enumeration column at col=0 this can
            be used to remove this unwanted data.
        deg (bool): Whether the provided data is in radians (False)
            or degree (True). Can also be None for non-angular data.
        labels (Union[list[str], str, None], optional): If you want to label the data
            you provided pass a list of str. If set to None, the features in this
            dimension will be labeled as
            `[f"{attr_name.upper()} FEATURE {i}" for i in range(self.n_frames)]`.
            If a str is provided, the features will be labeled as
            `[f"{attr_name.upper()} {label.upper()} {i}" for i in range(self.n_frames)]`.
            If a list[str] is provided, it needs to have the same length as
            the traj has frames. Defaults to None.

    Returns:
        xr.Dataset: An xarray dataset.

    """
    # Case 1: the string names a feature (or the 'all'/'full' shorthands).
    if (
        str(file_or_feature) == "all"
        or str(file_or_feature) == "full"
        or str(file_or_feature) in CAN_BE_FEATURE_NAME
    ):
        if file_or_feature == "all":
            traj.featurizer.add_list_of_feats(which="all", deg=deg)
        elif file_or_feature == "full":
            traj.featurizer.add_list_of_feats(which="full", deg=deg)
        else:
            traj.featurizer.add_list_of_feats(which=[file_or_feature], deg=deg)
        out = traj.featurizer.get_output()
        if traj.traj_num is not None:
            assert out.coords["traj_num"] == np.array([traj.traj_num]), print(
                traj.traj_num,
                out.coords["traj_num"].values,
                traj.traj_num,
            )
        return out
    # Case 2: the string points to an existing file.
    elif (f := Path(file_or_feature)).exists():
        if f.suffix == ".txt":
            data = np.loadtxt(f, usecols=cols)
        elif f.suffix == ".npy":
            data = np.load(f)
            if cols is not None:
                data = data[:, cols]
        elif f.suffix in [".nc", ".h5"]:
            data = xr.open_dataset(f)
            if len(data.data_vars.keys()) != 1:
                if attr_name is not None:
                    raise Exception(
                        f"The dataset in {f} has "
                        f"{len(data.data_vars.keys())} DataArrays, "
                        f"but only one `attr_name`: '{attr_name}' "
                        f"was requested. The names of the DataArrays "
                        f"are: {data.data_vars.keys()}. I can't over"
                        f"ride them all with one `attr_name`. Set "
                        f"`attr_name` to None to load the data with "
                        f"their respective names"
                    )
                return data
            else:
                if attr_name is not None:
                    d = list(data.data_vars.values())[0]
                    d.name = attr_name
                    return d
                # BUG FIX: previously this branch fell through to
                # `np_to_xr(data, ...)` with an `xr.Dataset` (which expects a
                # `np.ndarray`). A single-variable dataset without a requested
                # `attr_name` is returned unchanged.
                return data
        else:
            raise Exception(
                f"Currently only .txt, .npy, .nc, and .h5 files can "
                f"be loaded. Your file {f} does not have the "
                f"correct extension."
            )
    else:
        # Neither a known feature name nor an existing file.
        # (Message fix: a space was missing between the quoted string
        # and "is none of those".)
        raise Exception(
            f"If features are loaded via a string, the string needs "
            f"to be 'all', a feature name ('central_dihedrals'), or "
            f'an existing file. Your string "{file_or_feature}" '
            f"is none of those."
        )

    # Only the .txt/.npy branches reach this point with a numpy array.
    da = np_to_xr(data, traj, attr_name, deg, labels, file_or_feature)
    assert len(da.dims) == 3
    return da

1849 

1850 

def load_CVs_singletraj(
    data: SingleTrajFeatureType,
    traj: SingleTraj,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
    deg: Optional[bool] = None,
    periodic: bool = True,
    labels: Optional[list[str]] = None,
) -> xr.Dataset:
    """Loads CV data for a single trajectory and returns an `xr.Dataset`.

    Dispatches on the type of `data`: a str/Path (feature name or file path),
    a list of feature names, a raw array (or list of numbers/arrays), xarray
    objects, a single `Feature`, or a whole featurizer. The order of the
    `isinstance` checks below is significant (e.g. a list of str must be
    handled before the generic list/ndarray branch).

    Args:
        data (SingleTrajFeatureType): The CV source (see above).
        traj (SingleTraj): The trajectory the CVs belong to.
        attr_name (Optional[str]): Name under which the CV is stored. Must be
            a valid python identifier if given. Defaults to None.
        cols (Optional[list[int]]): Columns to select when loading from
            files. Defaults to None.
        deg (Optional[bool]): True for degrees, False for radians, None for
            non-angular data. Defaults to None.
        periodic (bool): Whether feature calculations obey the minimum image
            convention. Defaults to True.
        labels (Optional[list[str]]): Labels for the feature axis.
            Defaults to None.

    Returns:
        xr.Dataset: The dataset with the loaded CVs.

    """
    # Local Folder Imports
    from ..loading.features import Feature
    from ..loading.featurizer import (
        DaskFeaturizer,
        EnsembleFeaturizer,
        SingleTrajFeaturizer,
    )

    if isinstance(attr_name, str):
        if not attr_name.isidentifier():
            raise Exception(
                f"Provided string for `attr_name` can not be a "
                f"python identifier. Choose another attribute name."
            )
    # load a string
    if isinstance(data, (str, Path)):
        CVs = load_CV_from_string_or_path(str(data), traj, attr_name, cols, deg, labels)

    # load a list of strings from standard features
    elif isinstance(data, list) and all([isinstance(_, str) for _ in data]):
        traj.featurizer.add_list_of_feats(data, deg=deg, periodic=periodic)
        out = traj.featurizer.get_output()
        out.coords["traj_num"] = [traj.traj_num]
        return out

    # if the data is a numpy array
    elif isinstance(data, (list, np.ndarray)):
        # lists reaching this branch contain numbers/arrays (lists of str
        # were consumed by the branch above)
        assert not isinstance(labels, bool)
        CVs = np_to_xr(np.asarray(data), traj, attr_name, deg, labels).to_dataset(
            promote_attrs=True
        )

    # xarray objects are simply returned
    elif isinstance(data, xr.Dataset):
        return data

    elif isinstance(data, xr.DataArray):
        return data.to_dataset(promote_attrs=True)

    # if this is a feature
    elif issubclass(data.__class__, Feature):
        traj.featurizer.add_custom_feature(data)
        return traj.featurizer.get_output()

    # if an instance of Featurizer is provided
    elif isinstance(data, (DaskFeaturizer, SingleTrajFeaturizer, EnsembleFeaturizer)):
        # a single str attr_name is only valid for a single-feature featurizer
        if isinstance(attr_name, str):
            if len(data) != 1:
                raise TypeError(
                    f"Provided Featurizer contains {len(data)} "
                    f"features and `attr_name` is of type `str`. "
                    f"Please provide a list of str."
                )
            attr_name = [attr_name]
        if isinstance(attr_name, list):
            if len(attr_name) != len(data):
                raise IndexError(
                    f"Provided Featurizer contains {len(data)} "
                    f"features and `attr_name` contains "
                    f"{len(attr_name)} elements. Please make sure "
                    f"they contain the same amount of items."
                )
        out = data.get_output()
        assert out.sizes["traj_num"] == 1
        if attr_name is not None:
            if isinstance(attr_name, str):
                attr_name = [attr_name]
            # rename auto-generated data_var names to the user-provided
            # attr_names; features and names are matched positionally
            _renaming = {}
            for f, v in zip(data.features, attr_name):
                _feature = False
                if hasattr(f, "name"):
                    if f.name in FEATURE_NAMES:
                        k = FEATURE_NAMES[f.name]
                        _feature = True
                if not _feature:
                    # fall back to the class name for unnamed/custom features
                    k = f.__class__.__name__
                _renaming[k] = v
            out = out.rename_vars(_renaming)
        return out
    else:
        raise TypeError(
            f"`data` must be str, np.ndarray, list, xr.DataArray, xr.Dataset, "
            f"em.Featurizer or em.features.Feature. You supplied "
            f"{type(data)}."
        )

    # reached only by the str/Path and array branches above
    return CVs

1948 

1949 

def load_CVs_ensembletraj(
    trajs: TrajEnsemble,
    data: TrajEnsembleFeatureType,
    attr_name: Optional[list[str]] = None,
    cols: Optional[list[int]] = None,
    deg: Optional[bool] = None,
    periodic: bool = True,
    labels: Optional[list[str]] = None,
    directory: Optional[Union[Path, str]] = None,
    ensemble: bool = False,
    override: bool = False,
) -> None:
    """Loads CVs for a trajectory ensemble.

    CVs can be loaded from a multitude of sources. The argument `data` can be:
        * np.ndarray: Use a numpy array as a feature.
        * str | Path: You can point to .txt or .npy files and load the features
            from these files. In this case, the `cols` argument can be
            used to only use a subset of columns in these files.
            You can also point to a single directory in which case the basename
            of the trajectories will be used to look for .npy and .txt files.
        * str: Some strings like "central_dihedrals" are recognized out-of-the-box.
            You can also provide "all" to load all dihedrals used in an
            `encodermap.AngleDihedralCartesianEncoderMap`.
        * Feature: You can provide an `encodermap.loading.features` Feature. The
            CVs will be loaded by creating a featurizer, adding this feature, and
            obtaining the output.
        * Featurizer: You can also directly provide a featurizer with multiple
            features.
        * xr.DataArray: You can also provide a xarray.DataArray, which will be
            appended to the existing CVs.
        * xr.Dataset: If you provide a xarray.Dataset, you will overwrite all
            currently loaded CVs.

    Args:
        trajs (TrajEnsemble): The trajectory ensemble to load the data for.
        data (Union[str, list, np.ndarray, 'all', xr.Dataset]): The CV to
            load. When a numpy array is provided, it needs to have a shape
            matching `n_frames`. The data is distributed to the trajs.
            When a list of files is provided, `len(data)` needs to match
            `n_trajs`. The first file will be loaded by the first traj
            (based on the traj's `traj_num`) and so on. If a list of
            `np.ndarray` is provided, the first array will be assigned to
            the first traj (based on the traj's `traj_num`). If None is provided,
            the argument `directory` will be used to construct a str like:
            fname = directory + traj.basename + '_' + attr_name. If there are
            .txt or .npy files matching that string in the `directory`,
            the CVs will be loaded from these files to the corresponding
            trajs. Defaults to None.
        attr_name (Optional[str]): The name under which the CV should
            be found in the class. Choose whatever you like. `highd`, `lowd`,
            `dists`, etc. The CV can then be accessed via dot-notation:
            `trajs.attr_name`. Defaults to None, in which case, the argument
            `data` should point to existing files and the `attr_name` will
            be extracted from these files.
        cols (Optional[list[int]]): A list of integers indexing the columns
            of the data to be loaded. This is useful if a file contains
            columns which are not features (i.e. an indexer or the error of
            the features. eg::

                id   f1    f2    f1_err    f2_err
                0    1.0   2.0   0.1       0.1
                1    2.5   1.2   0.11      0.52

            In that case, you would want to supply `cols=[1, 2]` to the `cols`
            argument. If None all columns are loaded. Defaults to None.
        deg (Optional[bool]): Whether to return angular CVs using degrees.
            If None or False, CVs will be in radian. Defaults to None.
        labels (list): A list containing the labels for the dimensions of
            the data. If you provide a `np.ndarray` with shape (n_trajs,
            n_frames, n_feat), this list needs to be of len(n_feat)
            Defaults to None.
        directory (Optional[str]): If this argument is provided, the
            directory will be searched for ``.txt`` or ``.npy`` files which
            have the same names as the trajectories have basenames. The
            CVs will then be loaded from these files.
        ensemble (bool): Whether the trajs in this class belong to an ensemble.
            This implies that they contain either the same topology or are
            very similar (think wt, and mutant). Setting this option True will
            try to match the CVs of the trajs onto the same dataset.
            If a VAL residue has been replaced by LYS in the mutant,
            the number of sidechain dihedrals will increase. The CVs of the
            trajs with VAL will thus contain some NaN values. Defaults to False.
        override (bool): Whether to override CVs with the same name as `attr_name`.

    """
    # Local Folder Imports
    from ..loading.features import Feature
    from ..loading.featurizer import DaskFeaturizer, EnsembleFeaturizer

    # ------------------------------------------------------------------
    # Case 1: a single str/Path (feature name, file, or directory)
    # ------------------------------------------------------------------
    if isinstance(data, (str, Path)) and not ensemble:
        # all EncoderMap features for ML training
        if str(data) == "all":
            [
                t.load_CV("all", deg=deg, periodic=periodic, override=override)
                for t in trajs
            ]
            return
        if str(data) == "full":
            [
                t.load_CV("full", deg=deg, periodic=periodic, override=override)
                for t in trajs
            ]
            return
        path_data = Path(data)
        # Candidate per-traj filenames: "<basename>_<data>.npy", ".txt", or raw.
        # NOTE(review): these lists are only defined when at least one traj
        # has a basename — presumably that always holds when `data` is a
        # file stub; verify against callers.
        if not all([t.basename is None for t in trajs]):
            npy_files = [
                (t._traj_file.parent if directory is None else Path(directory))
                / (t.basename + f"_{data}.npy")
                for t in trajs
            ]
            txt_files = [
                (t._traj_file.parent if directory is None else Path(directory))
                / (t.basename + f"_{data}.txt")
                for t in trajs
            ]
            raw_files = [
                (t._traj_file.parent if directory is None else Path(directory))
                / (t.basename + f"_{data}")
                for t in trajs
            ]
        # a directory containing files with names identical to trajs
        if path_data.is_dir():
            return load_CVs_from_dir(
                trajs, data, attr_name=attr_name, deg=deg, cols=cols
            )
        # maybe just a single feature
        elif data in CAN_BE_FEATURE_NAME:
            [
                t.load_CV(
                    data,
                    attr_name,
                    cols,
                    deg=deg,
                    periodic=periodic,
                    labels=labels,
                    override=override,
                )
                for t in trajs
            ]
            return
        # a h5 or nc file
        elif path_data.is_file() and (
            path_data.suffix == ".h5" or path_data.suffix == ".nc"
        ):
            ds = xr.open_dataset(path_data)
            # the stored dataset must cover all traj_nums of this ensemble
            if diff := set([t.traj_num for t in trajs]) - set(ds["traj_num"].values):
                raise Exception(
                    f"The dataset you try to load and the TrajEnsemble "
                    f"have different number of trajectories: {diff}."
                )
            for t, (traj_num, sub_ds) in zip(
                trajs, ds.groupby("traj_num", squeeze=False)
            ):
                assert t.traj_num == traj_num, f"{t.traj_num=}, {traj_num=}"
                if "traj_num" in sub_ds.coords:
                    assert sub_ds.coords["traj_num"] == t.traj_num
                else:
                    # re-attach the traj_num coordinate lost by groupby
                    sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
                    sub_ds = sub_ds.expand_dims("traj_num")
                assert sub_ds.coords["traj_num"] == np.array([t.traj_num])
                t.load_CV(sub_ds)
            return
        # all numpy files
        elif all([f.is_file() for f in npy_files]):
            [
                t.load_CV(
                    f,
                    attr_name=data,
                    cols=cols,
                    deg=deg,
                    labels=labels,
                    override=override,
                )
                for t, f in zip(trajs, npy_files)
            ]
            return
        # all txt files
        elif all([f.is_file() for f in txt_files]):
            [
                t.load_CV(
                    f,
                    attr_name=data,
                    cols=cols,
                    deg=deg,
                    labels=labels,
                    override=override,
                )
                for t, f in zip(trajs, txt_files)
            ]
            return
        # all raw files without suffix
        elif all([f.is_file() for f in raw_files]):
            [
                t.load_CV(
                    f,
                    attr_name=data,
                    cols=cols,
                    deg=deg,
                    labels=labels,
                    override=override,
                )
                for t, f in zip(trajs, raw_files)
            ]
            return
        # raise ValueError
        else:
            msg = (
                f"If `data` is provided a single string, the string needs to "
                f"be either a feature ({CAN_BE_FEATURE_NAME}), a .h5/.nc file "
                f"or a npy file. The provided `data`={data} "
                f"fits none of these possibilities."
            )
            raise ValueError(msg)

    # ------------------------------------------------------------------
    # Case 2: a list (of arrays, feature names, or files)
    # ------------------------------------------------------------------
    elif isinstance(data, list) and not ensemble:
        # one array-like per traj
        if all([isinstance(i, (list, np.ndarray)) for i in data]):
            [
                t.load_CV(
                    d,
                    attr_name,
                    cols,
                    deg=deg,
                    periodic=periodic,
                    labels=labels,
                    override=override,
                )
                for t, d in zip(trajs, data)
            ]
            return
        # a list of feature names: every traj loads all of them
        elif all([i in CAN_BE_FEATURE_NAME for i in data]):
            [
                t.load_CV(
                    data=data,
                    attr_name=attr_name,
                    cols=cols,
                    deg=deg,
                    periodic=periodic,
                    labels=labels,
                    override=override,
                )
                for t in trajs
            ]
            return
        # one file per traj (all files must share the same extension)
        elif all([Path(f).is_file() for f in data]):
            suffix = set([Path(f).suffix for f in data])
            if len(suffix) != 1:
                raise Exception(
                    "Please provide a list with consistent file "
                    f"extensions and not a mish-mash, like: {suffix}"
                )
            suffix = suffix.pop()
            if suffix == ".npy":
                [
                    t.load_CV(
                        data=np.load(d),
                        attr_name=attr_name,
                        cols=cols,
                        deg=deg,
                        periodic=periodic,
                        labels=labels,
                        override=override,
                    )
                    for t, d in zip(trajs, data)
                ]
            else:
                # any other consistent extension is read as whitespace text
                [
                    t.load_CV(
                        data=np.genfromtxt(d),
                        attr_name=attr_name,
                        cols=cols,
                        deg=deg,
                        periodic=periodic,
                        labels=labels,
                        override=override,
                    )
                    for t, d in zip(trajs, data)
                ]
            return
        else:
            # build a helpful error message depending on what the list held
            if not all([isinstance(d, str) for d in data]):
                msg = (
                    f"If `data` is provided as a list, the list needs to contain "
                    f"strings that can be features ({CAN_BE_FEATURE_NAME}), or "
                    f"some combination of lists and numpy arrays."
                )
            else:
                wrong = [d for d in data if d not in CAN_BE_FEATURE_NAME]
                msg = (
                    f"The list of str you supplied, did contain some str, that is "
                    f"not recognized as a feature: {wrong}."
                )
            raise ValueError(msg)

    # ------------------------------------------------------------------
    # Case 3: a numpy array covering the whole ensemble
    # ------------------------------------------------------------------
    elif isinstance(data, np.ndarray):
        if len(data) != trajs.n_trajs and len(data) != trajs.n_frames:
            raise ValueError(
                f"The provided numpy array is misshaped. It needs "
                f"to be of shape (n_trajs={trajs.n_trajs}, "
                f"n_frames={np.unique([t.n_frames for t in trajs])[0]}, "
                f"X, (Y)), but is {data.shape}."
            )
        if len(data) == trajs.n_frames:
            # a flat per-frame array: split it back into per-traj chunks
            data = [
                data[np.where(trajs.index_arr[:, 0] == t.traj_num)[0]] for t in trajs
            ]
        assert len(data) == trajs.n_trajs
        for d, t in zip(data, trajs):
            t.load_CV(
                data=d,
                attr_name=attr_name,
                cols=cols,
                deg=deg,
                periodic=periodic,
                labels=labels,
                override=override,
            )
        # sanity check: every stored CV keeps the length-1 traj axis
        for t in trajs:
            for v in t._CVs.values():
                assert v.shape[0] == 1, f"{t.basename=}, {v=}"
        return

    # ------------------------------------------------------------------
    # Case 4: a single Feature instance
    # ------------------------------------------------------------------
    elif issubclass(data.__class__, Feature):
        for t in trajs:
            t.load_CV(
                data,
                attr_name,
                cols,
                deg=deg,
                periodic=periodic,
                labels=labels,
                override=override,
            )
        return

    # ------------------------------------------------------------------
    # Case 5: an ensemble-capable featurizer
    # ------------------------------------------------------------------
    elif (isinstance(data, (EnsembleFeaturizer, DaskFeaturizer))) or (
        data.__class__.__name__ in ["EnsembleFeaturizer", "DaskFeaturizer"]
    ):
        ds = data.get_output()
        assert (ds_traj_nums := list(ds.coords["traj_num"])) == (
            traj_traj_nums := [t.traj_num for t in trajs]
        ), (
            f"The dataset provided by '{data}' does not match the trajectories in "
            f"'{trajs}'. The dataset defines trajectories with these traj_nums: "
            f"{ds_traj_nums}, while the TrajEnsemble has these traj_nums: {traj_traj_nums}. "
            f"If you are using the DaskFeaturizer, make sure to not call its "
            f"`transform()` method with any arguments, as this will alter the "
            f"compute graph and affect the output of `get_output()`. To fix "
            f"this for `DaskFeaturizers`, you can run `del feat.dataset`."
        )
        for t, (traj_num, sub_ds) in zip(trajs, ds.groupby("traj_num", squeeze=False)):
            assert t.traj_num == traj_num, f"{t.traj_num=}, {traj_num=}"
            if "traj_num" in sub_ds.coords:
                assert sub_ds.coords["traj_num"] == t.traj_num
            else:
                sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
                sub_ds = sub_ds.expand_dims("traj_num")

            # remove frames of full nans, which can happen for weird overlaps
            # Standard Library Imports
            from functools import reduce

            to_reduce = []
            for da in sub_ds.data_vars.values():
                # index arrays don't carry per-frame values; skip them
                if da.name.endswith("feature_indices"):
                    continue
                a = (
                    (~np.isnan(da).any(dim=da.attrs["feature_axis"]))
                    .squeeze("traj_num")
                    .values
                )
                if a.ndim == 2:
                    a = np.all(a, axis=1)
                to_reduce.append(a)

            if len(to_reduce) == 1:
                idx = to_reduce[0]
            else:
                # keep only frames that are valid in every data_var
                idx = reduce(np.bitwise_and, to_reduce)
            sub_ds = sub_ds.sel(frame_num=idx)
            t.load_CV(sub_ds)
        return

    # ------------------------------------------------------------------
    # Case 6: a ready-made xr.Dataset
    # ------------------------------------------------------------------
    elif isinstance(data, xr.Dataset):
        for i, (t, (traj_num, sub_ds)) in enumerate(
            zip(trajs, data.groupby("traj_num", squeeze=False))
        ):
            assert t.traj_num == traj_num, f"{t.traj_num=}, {traj_num=}"
            if "traj_num" in sub_ds.coords:
                assert sub_ds.coords["traj_num"] == t.traj_num
            else:
                sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
                sub_ds = sub_ds.expand_dims("traj_num")
            sub_ds = sub_ds.dropna("frame_num", how="all")
            t.load_CV(sub_ds)
        return

    # ------------------------------------------------------------------
    # Fallback: delegate to the ensemble loader, or reject the input
    # ------------------------------------------------------------------
    if ensemble:
        return load_CVs_ensemble(trajs, data, periodic=periodic)

    else:
        raise TypeError(
            f"`data` must be str, np.ndarray, list, xr.Dataset, or "
            f"em.Featurizer. You supplied {type(data)=} {data.__class__.__name__=}."
        )

2355 

2356 

def load_CVs_ensemble(
    trajs: TrajEnsemble,
    data: Union[str, list[str], Literal["all", "full"]],
    periodic: bool = True,
) -> None:
    """Loads CVs for a trajectory ensemble. This time with generic feature names
    so different topologies are aligned and can be treated separately. Loading
    CVs with ensemble=True will always delete existing CVs.

    Args:
        trajs (TrajEnsemble): The trajectory ensemble to load the data for.
        data (Union[str, list[str], Literal["all", "full"]]): The feature(s)
            to load. The special keywords "all" and "full" are passed to the
            featurizer unchanged; a single feature name is wrapped into a
            one-element list; a list of feature names is used as-is.
        periodic (bool): Whether distance, angle, dihedral calculations should
            obey the minimum image convention.

    """
    if isinstance(data, str):
        # BUG FIX: previously only "all" was special-cased, so "full" (which
        # the signature and the non-ensemble code path both support) was
        # wrongly wrapped into the list ["full"] instead of being passed
        # through as a keyword.
        if data not in ("all", "full"):
            data = [data]
    trajs.featurizer.add_list_of_feats(data, ensemble=True, periodic=periodic)
    # Ensemble alignment requires a single angular unit: collect the `deg`
    # flags of all angular features and require them to be radian.
    deg_units = []
    for f in trajs.featurizer.features:
        if hasattr(f, "deg"):
            deg_units.append(f.deg)
    assert all(
        [not d for d in deg_units]
    ), "Loading an ensemble only possible if all degree units are radian."
    output = trajs.featurizer.get_output()
    for t, (traj_num, sub_ds) in zip(trajs, output.groupby("traj_num", squeeze=False)):
        assert t.traj_num == traj_num, f"{t.traj_num=}, {traj_num=}"
        try:
            # re-attach the traj_num coordinate dropped by groupby
            sub_ds = sub_ds.assign_coords(traj_num=t.traj_num)
            sub_ds = sub_ds.expand_dims("traj_num")
        except ValueError as e:
            # the coordinate may already be present as a scalar; that's fine
            if "already exists as a scalar" not in str(e):
                raise e
        # select only the frames this (possibly sliced) traj actually contains
        if t.id.ndim == 2:
            frames = t.id[:, 1]
        else:
            frames = t.id
        sub_ds = sub_ds.sel({"frame_num": frames})
        if t._CVs:
            warnings.warn(
                "Using ensemble=True will drop old CV entries from "
                "trajs, because the feature length increases."
            )
        # ensemble loading replaces (not merges) any previously loaded CVs
        t._CVs = sub_ds

2416 

2417 

def load_CVs_from_dir(
    trajs: TrajEnsemble,
    data: Path,
    attr_name: Optional[str] = None,
    cols: Optional[list[int]] = None,
    deg: Optional[bool] = None,
) -> None:
    """Loads one CV file per trajectory from a directory.

    Files are matched to trajectories by substring: a file qualifies when its
    name contains a trajectory's basename. Among qualifying files, ``.npy``
    is preferred over ``.txt``, which is preferred over anything else. The
    selection is then reordered to follow the order of ``trajs`` and each
    file is handed to the matching trajectory's `load_CV`.

    Args:
        trajs (TrajEnsemble): The ensemble whose trajectories receive the CVs.
        data (Path): The directory to search.
        attr_name (Optional[str]): Name under which the CVs are stored.
        cols (Optional[list[int]]): Columns to select from the files.
        deg (Optional[bool]): Degree (True) / radian (False) flag for angular
            data, None for non-angular data.

    """
    # keep only files whose name contains some trajectory's basename
    candidates = [
        str(p)
        for p in data.glob("*")
        if any(traj.basename in str(p) for traj in trajs)
    ]
    # rank by extension (.npy before .txt before anything else) and keep
    # one candidate per trajectory
    rank = {"npy": 1, "txt": 2}
    candidates = sorted(
        candidates,
        key=lambda name: rank.get(name.split(".")[-1], 3),
    )[: trajs.n_trajs]
    # line the files up with the order of the trajectories in the ensemble
    files = sorted(
        candidates,
        key=lambda name: [traj.basename in name for traj in trajs].index(True),
    )
    assert (
        len(files) == trajs.n_trajs
    ), f"Couldn't find the correct number of files:\n{files=}\nfor trajs:\n{trajs=}"
    for traj, f in zip(trajs, files):
        if traj.basename not in f:
            raise Exception(f"File {f} does not contain substring of traj {traj}.")
        traj.load_CV(f, attr_name=attr_name, cols=cols, deg=deg)

2447 traj.load_CV(f, attr_name=attr_name, cols=cols, deg=deg)