Coverage for encodermap/encodermap_tf1/moldata.py: 71%

172 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# Standard Library Imports 

2import os 

3import shutil 

4import tempfile 

5import warnings 

6from collections import OrderedDict 

7from math import pi 

8 

9# Third Party Imports 

10import MDAnalysis as mda 

11import mdtraj as md 

12import numpy as np 

13from MDAnalysis.analysis.align import AlignTraj 

14from MDAnalysis.analysis.base import AnalysisBase 

15from MDAnalysis.analysis.dihedrals import Dihedral 

16from MDAnalysis.coordinates.memory import MemoryReader 

17from MDAnalysis.lib.distances import calc_angles 

18from mdtraj.geometry import dihedral as md_dihedral 

19from tqdm import tqdm 

20 

21# Local Folder Imports 

22from .misc import create_dir 

23 

24 

class Positions(AnalysisBase):
    """Analysis that collects the positions of one atom group frame by frame.

    After ``run()`` the collected coordinates are available as ``result``,
    a single numpy array with one leading entry per processed frame.
    """

    def __init__(self, atomgroup, **kwargs):
        """
        :param atomgroup: atom group whose positions are recorded
        :param kwargs: passed on unchanged to ``AnalysisBase.__init__``
        """
        trajectory = atomgroup.universe.trajectory
        super().__init__(trajectory, **kwargs)
        self._ag = atomgroup

    def _prepare(self):
        """Reset the per-frame accumulator."""
        self.result = list()

    def _single_frame(self):
        """Store the atom group's positions for the current frame."""
        frame_positions = self._ag.positions
        self.result.append(frame_positions)

    def _conclude(self):
        """Turn the list of per-frame arrays into one numpy array."""
        self.result = np.array(self.result, copy=False) if False else np.asarray(self.result)

38 

39 

class Angles(AnalysisBase):
    """Analysis that evaluates one angle per three-atom group for each frame.

    After ``run()`` the values returned by ``calc_angles`` are available as
    ``result``, a numpy array with one leading entry per processed frame.
    """

    def __init__(self, atomgroups, **kwargs):
        """
        :param atomgroups: sequence of atom groups, each holding exactly 3 atoms
        :param kwargs: passed on unchanged to ``AnalysisBase.__init__``
        :raises ValueError: if any of the groups does not contain 3 atoms
        """
        trajectory = atomgroups[0].universe.trajectory
        super().__init__(trajectory, **kwargs)
        self.atomgroups = atomgroups

        if any(len(group) != 3 for group in atomgroups):
            raise ValueError("All AtomGroups must contain 3 atoms")

        def nth_atoms(position):
            # Gather the atom at ``position`` of every triplet into one group
            # so all angles can be evaluated with a single calc_angles call.
            return mda.AtomGroup([group[position] for group in atomgroups])

        self.ag1 = nth_atoms(0)
        self.ag2 = nth_atoms(1)
        self.ag3 = nth_atoms(2)

    def _prepare(self):
        """Reset the per-frame accumulator."""
        self.result = list()

    def _single_frame(self):
        """Compute all angles for the current frame and store them."""
        first, second, third = self.ag1, self.ag2, self.ag3
        frame_angles = calc_angles(
            first.positions,
            second.positions,
            third.positions,
            box=first.dimensions,
        )
        self.result.append(frame_angles)

    def _conclude(self):
        """Turn the list of per-frame arrays into one numpy array."""
        collected = self.result
        self.result = np.asarray(collected)

66 

67 

class MolData:
    """
    MolData is designed to extract and hold conformational information from trajectories.

    :ivar cartesians: numpy array of the trajectory atom coordinates
    :ivar central_cartesians: cartesian coordinates of the central backbone atoms (N-CA-C-N-CA-C...)
    :ivar dihedrals: all backbone dihedrals (phi, psi, omega), converted from degrees to radians
    :ivar angles: all bond angles of the central backbone atoms
    :ivar lengths: all bond lengths between neighbouring central atoms
    :ivar sidedihedrals: all sidechain dihedrals, converted from degrees to radians;
        the scalar ``np.nan`` if no sidechain dihedrals were found
    :ivar aminoaciddict: maps residue names to a number of sidechain dihedrals
    """

    def __init__(
        self,
        atom_group,
        cache_path="",
        start=None,
        stop=None,
        step=None,
    ):
        """
        Computes all quantities from the trajectory, or loads them from ``cache_path``.

        For every quantity a ``<name>.npy`` file is looked up in ``cache_path`` first;
        only if it is missing the quantity is calculated (and saved if ``cache_path`` is set).

        NOTE: cached arrays are loaded as they are; they are not checked against
        ``start``/``stop``/``step``. With the default ``cache_path=""`` the files are
        looked up relative to the current working directory.

        :param atom_group: MDAnalysis atom group
        :param cache_path: Allows to define a path where the calculated variables can be cached.
        :param start: first frame to analyze
        :param stop: last frame to analyze
        :param step: step of the analysis
        """
        self.universe = atom_group.universe

        # Reorder the atoms per residue according to sort_key.
        self.sorted_atoms = self.universe.atoms[
            [atom.ix for atom in sorted(atom_group.atoms, key=self.sort_key)]
        ]

        # Positions (within sorted_atoms) of the backbone atoms N, CA and C.
        self.central_atom_indices = [
            i
            for i, atom in enumerate(self.sorted_atoms)
            if atom.name in ["N", "CA", "C"]
        ]
        self.central_atoms = self.sorted_atoms[self.central_atom_indices]

        # Original note: problems with ILE, PRO and TRP.
        # NOTE(review): these entries may not match the actual number of
        # chi angles of those residues - confirm before relying on them.
        self.aminoaciddict = {
            "ALA": 0,
            "ARG": 5,
            "ASN": 2,
            "ASP": 2,
            "CYS": 1,
            "GLU": 3,
            "GLN": 3,
            "GLY": 0,
            "HIS": 2,
            "HID": 2,
            "ILE": 1,
            "LEU": 2,
            "LYS": 4,
            "MET": 3,
            "PHE": 2,
            "PRO": 0,
            "SER": 1,
            "THR": 1,
            "TRP": 2,
            "TYR": 2,
            "VAL": 1,
            "KAC": 4,
        }

        # Cartesians:
        try:
            self.cartesians = np.load(os.path.join(cache_path, "cartesians.npy"))
            print("Loaded cartesians from {}".format(cache_path))

        except FileNotFoundError:
            print("Loading Cartesians...")
            positions = Positions(self.sorted_atoms, verbose=True).run(
                start=start, stop=stop, step=step
            )
            self.cartesians = positions.result.astype(np.float32)

            if cache_path:
                np.save(
                    os.path.join(create_dir(cache_path), "cartesians.npy"),
                    self.cartesians,
                )

        self.central_cartesians = self.cartesians[:, self.central_atom_indices]

        # Dihedrals:
        try:
            self.dihedrals = np.load(os.path.join(cache_path, "dihedrals.npy"))
            print("Loaded dihedrals from {}".format(cache_path))

        except FileNotFoundError:
            print("Calculating dihedrals...")
            dihedral_atoms = []
            # OrderedDict.fromkeys removes duplicate residue numbers while
            # keeping their order of first appearance.
            for i in OrderedDict.fromkeys(self.sorted_atoms.resnums):
                backbone = self.sorted_atoms.select_atoms(
                    "resnum {} and (name N or name CA or name C)".format(i)
                )

                # phi: C of the previous residue + N, CA, C of this residue
                phi_atoms = (
                    self.sorted_atoms.select_atoms(
                        "resnum {} and name C".format(i - 1)
                    )
                    + backbone
                )
                if len(phi_atoms) == 4:
                    dihedral_atoms.append(phi_atoms.dihedral)

                # psi: N, CA, C of this residue + N of the next residue
                psi_atoms = backbone + self.sorted_atoms.select_atoms(
                    "resnum {} and name N".format(i + 1)
                )
                if len(psi_atoms) == 4:
                    dihedral_atoms.append(psi_atoms.dihedral)

                # omega: CA, C of this residue + N, CA of the next residue
                omega_atoms = self.sorted_atoms.select_atoms(
                    "resnum {} and (name CA or name C)".format(i)
                ) + self.sorted_atoms.select_atoms(
                    "resnum {} and (name N or name CA)".format(i + 1)
                )
                # Check the omega selection itself (the psi selection was
                # checked here before, which was a copy-paste error).
                if len(omega_atoms) == 4:
                    dihedral_atoms.append(omega_atoms.dihedral)

            dihedrals = Dihedral(dihedral_atoms, verbose=True).run(
                start=start, stop=stop, step=step
            )
            self.dihedrals = dihedrals.angles.astype(np.float32)
            self.dihedrals *= pi / 180  # degrees -> radians

            if cache_path:
                np.save(os.path.join(cache_path, "dihedrals.npy"), self.dihedrals)

        # SideDihedrals
        try:
            self.sidedihedrals = np.load(os.path.join(cache_path, "sidedihedrals.npy"))
            # A cache that only holds NaN (e.g. the saved placeholder) is
            # treated like a missing file and recalculated.
            if np.all(np.isnan(self.sidedihedrals)):
                raise FileNotFoundError
            print("Loaded sidedihedrals from {}".format(cache_path))

        except FileNotFoundError:
            print("Calculating sidedihedrals...")
            # The universe is written to a temporary PDB so that mdtraj can
            # build a topology from it; the file only exists inside this block.
            with tempfile.NamedTemporaryFile(suffix=".pdb") as fp:
                with mda.Writer(fp.name) as PDB:
                    PDB.write(self.universe)
                top = md.load_pdb(fp.name).top

            # Atom index quadruples of chi1 ... chi5 as found by mdtraj.
            sidedihedral_atoms_inds = np.vstack(
                [getattr(md_dihedral, f"indices_chi{k}")(top) for k in range(1, 6)]
            )

            sidedihedral_atoms = []
            for ind in sidedihedral_atoms_inds:
                atoms = []
                for atom_index in ind:
                    atom = self.universe.select_atoms(f"index {atom_index}")
                    assert len(atom) == 1
                    atoms.append(atom[0])
                ag = atoms[0] + atoms[1] + atoms[2] + atoms[3]
                sidedihedral_atoms.append(ag.dihedral)
            assert all(len(ag.atoms) == 4 for ag in sidedihedral_atoms)

            if not sidedihedral_atoms:
                # Placeholder for "no sidechain dihedrals found".
                self.sidedihedrals = np.nan
            else:
                sidedihedrals = Dihedral(sidedihedral_atoms, verbose=True).run(
                    start=start, stop=stop, step=step
                )
                self.sidedihedrals = sidedihedrals.angles.astype(np.float32)
                self.sidedihedrals *= pi / 180  # degrees -> radians

            if cache_path:
                np.save(
                    os.path.join(cache_path, "sidedihedrals.npy"), self.sidedihedrals
                )

        # Angles:
        try:
            self.angles = np.load(os.path.join(cache_path, "angles.npy"))
            print("Loaded angles from {}".format(cache_path))

        except FileNotFoundError:
            print("Calculating angles...")
            # One group per three consecutive central atoms.
            angle_atoms = [
                self.sorted_atoms[self.central_atom_indices[i : i + 3]]
                for i in range(len(self.central_atom_indices) - 2)
            ]

            angles = Angles(angle_atoms, verbose=True).run(
                start=start, stop=stop, step=step
            )
            self.angles = angles.result.astype(np.float32)

            if cache_path:
                np.save(os.path.join(create_dir(cache_path), "angles.npy"), self.angles)

        # Lengths:
        try:
            self.lengths = np.load(os.path.join(cache_path, "lengths.npy"))
            print("Loaded lengths from {}".format(cache_path))

        except FileNotFoundError:
            print("Calculating lengths...")
            # Difference vectors between neighbouring central atoms.
            vecs = self.central_cartesians[:, :-1] - self.central_cartesians[:, 1:]
            self.lengths = np.linalg.norm(vecs, axis=2)
            if cache_path:
                np.save(
                    os.path.join(create_dir(cache_path), "lengths.npy"), self.lengths
                )

        # n central atoms give n-1 lengths, n-2 angles and n-3 dihedrals.
        assert self.lengths.shape[1] == self.central_cartesians.shape[1] - 1
        assert self.angles.shape[1] == self.central_cartesians.shape[1] - 2
        assert self.dihedrals.shape[1] == self.central_cartesians.shape[1] - 3

    def __iadd__(self, other):
        """
        Appends the frames of another MolData object to this one (``self += other``).

        All per-frame arrays are concatenated along the first axis. Both objects
        have to hold atoms with identical names in identical order.

        :param other: MolData object whose frames are appended
        :return: self
        """
        assert np.all(self.sorted_atoms.names == other.sorted_atoms.names)
        self.cartesians = np.concatenate([self.cartesians, other.cartesians], axis=0)
        self.central_cartesians = np.concatenate(
            [self.central_cartesians, other.central_cartesians], axis=0
        )
        self.dihedrals = np.concatenate([self.dihedrals, other.dihedrals], axis=0)
        # Without sidechain dihedrals __init__ stores the scalar np.nan, which
        # np.concatenate rejects. If both objects hold that placeholder there
        # is nothing to append and the placeholder is kept.
        both_placeholders = (
            np.ndim(self.sidedihedrals) == 0 and np.ndim(other.sidedihedrals) == 0
        )
        if not both_placeholders:
            self.sidedihedrals = np.concatenate(
                [self.sidedihedrals, other.sidedihedrals], axis=0
            )
        self.angles = np.concatenate([self.angles, other.angles], axis=0)
        self.lengths = np.concatenate([self.lengths, other.lengths], axis=0)
        return self

    @staticmethod
    def sort_key(atom):
        """
        Sort key that orders atoms by residue number and, within a residue,
        as N, CA, all other atoms, C, O, OXT, O1, O2.

        :param atom: object with the attributes ``name`` and ``resnum``
        :return: tuple ``(resnum, rank of the atom name)``
        """
        positions = {"N": 1, "CA": 2, "C": 5, "O": 6, "OXT": 7, "O1": 8, "O2": 9}
        # Names that are not listed get rank 4, i.e. between CA and C.
        return atom.resnum, positions.get(atom.name, 4)

    def write(
        self,
        path,
        coordinates,
        name="generated",
        formats=("pdb", "xtc"),
        only_central=False,
        align_reference=None,
        align_select="all",
    ):
        """
        Writes a trajectory for the given coordinates.

        Two files are written to ``path``: a structure file ``name.formats[0]`` of the
        stored atoms and a trajectory file ``name.formats[1]`` holding ``coordinates``.

        :param path: directory where to save the trajectory
        :param coordinates: numpy array of xyz coordinates (frames, atoms, xyz).
            A 2-dimensional array is treated as a single frame.
        :param name: filename (without extension)
        :param formats: specify which formats should be used to write structure and trajectory. default: ("pdb", "xtc")
        :param only_central: if True only central atom coordinates are expected (N-Ca-C...)
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.
        :return: None
        """
        coordinates = np.array(coordinates)
        if coordinates.ndim == 2:
            # Single conformation: add the frame axis.
            coordinates = np.expand_dims(coordinates, 0)

        # central_atoms is sorted_atoms[central_atom_indices], see __init__.
        atoms = self.central_atoms if only_central else self.sorted_atoms
        output_universe = mda.Merge(atoms)
        atoms.write(os.path.join(path, "{}.{}".format(name, formats[0])))
        output_universe.load_new(coordinates, format=MemoryReader)

        if align_reference is not None:
            align_traj = AlignTraj(
                output_universe, align_reference, align_select, in_memory=True
            )
            align_traj.run()

        with mda.Writer(os.path.join(path, "{}.{}".format(name, formats[1]))) as w:
            # Iterating the trajectory advances the universe frame by frame.
            for _ in output_universe.trajectory:
                w.write(output_universe.atoms)