Coverage for encodermap/encodermap_tf1/moldata.py: 76%

154 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1import os 

2import warnings 

3from collections import OrderedDict 

4from math import pi 

5 

6import MDAnalysis as md 

7import numpy as np 

8from MDAnalysis.analysis.align import AlignTraj 

9from MDAnalysis.analysis.base import AnalysisBase 

10from MDAnalysis.analysis.dihedrals import Dihedral 

11from MDAnalysis.coordinates.memory import MemoryReader 

12from MDAnalysis.lib.distances import calc_angles 

13from tqdm import tqdm 

14 

15from .misc import create_dir 

16 

17 

class Positions(AnalysisBase):
    """Collect the coordinates of an AtomGroup for every analysed frame.

    After ``run()`` the ``result`` attribute holds a numpy array made by
    stacking the per-frame ``positions`` arrays along a new leading axis.
    """

    def __init__(self, atomgroup, **kwargs):
        trajectory = atomgroup.universe.trajectory
        super().__init__(trajectory, **kwargs)
        self._ag = atomgroup

    def _prepare(self):
        # Per-frame coordinate arrays are gathered here and stacked in _conclude.
        self.result = []

    def _single_frame(self):
        frame_positions = self._ag.positions
        self.result.append(frame_positions)

    def _conclude(self):
        self.result = np.asarray(self.result)

31 

32 

class Angles(AnalysisBase):
    """Compute, for each frame, the angle spanned by every atom triplet.

    :param atomgroups: sequence of AtomGroups with exactly three atoms each.
        The angle is evaluated with ``calc_angles`` on (first, second, third)
        atom of each triplet.
    :raises ValueError: if any AtomGroup does not contain exactly 3 atoms.

    After ``run()`` the ``result`` attribute is a numpy array with one row per
    analysed frame and one column per triplet.
    """

    def __init__(self, atomgroups, **kwargs):
        universe = atomgroups[0].universe
        super().__init__(universe.trajectory, **kwargs)
        self.atomgroups = atomgroups

        if not all(len(group) == 3 for group in atomgroups):
            raise ValueError("All AtomGroups must contain 3 atoms")

        # Regroup the triplets column-wise (all first atoms, all second atoms,
        # all third atoms) so each frame needs a single vectorised calc_angles call.
        self.ag1, self.ag2, self.ag3 = (
            md.AtomGroup([group[slot] for group in atomgroups]) for slot in range(3)
        )

    def _prepare(self):
        # One array of angles per frame; stacked in _conclude.
        self.result = []

    def _single_frame(self):
        first, second, third = (
            self.ag1.positions,
            self.ag2.positions,
            self.ag3.positions,
        )
        frame_angles = calc_angles(first, second, third, box=self.ag1.dimensions)
        self.result.append(frame_angles)

    def _conclude(self):
        self.result = np.asarray(self.result)

59 

60 

class MolData:
    """
    MolData is designed to extract and hold conformational information from trajectories.

    :ivar universe: the MDAnalysis Universe the analysed atom group belongs to
    :ivar sorted_atoms: the atom group's atoms ordered by ``sort_key``
        (residue number first, then N, CA, other atoms, C, O, OXT, O1, O2)
    :ivar central_atom_indices: indices into ``sorted_atoms`` of the N, CA and C atoms
    :ivar central_atoms: the central backbone atoms (N-CA-C-N-CA-C...)
    :ivar cartesians: numpy array of the trajectory atom coordinates (frames, atoms, xyz)
    :ivar central_cartesians: cartesian coordinates of the central backbone atoms
    :ivar dihedrals: all backbone dihedrals (phi, psi, omega), in radians
    :ivar angles: all bond angles of the central backbone atoms
    :ivar lengths: all bond lengths between neighbouring central atoms
    :ivar sidedihedrals: all sidechain dihedrals in radians, or ``np.nan`` if there are none
    :ivar aminoaciddict: mapping of residue name to its number of sidechain dihedrals
    """

    # Ordering of atoms inside a residue used by sort_key; unlisted names get 4,
    # i.e. they are placed between CA and C.
    _SORT_ORDER = {"N": 1, "CA": 2, "C": 5, "O": 6, "OXT": 7, "O1": 8, "O2": 9}

    def __init__(
        self,
        atom_group,
        cache_path="",
        start=None,
        stop=None,
        step=None,
    ):
        """
        :param atom_group: MDAnalysis atom group
        :param cache_path: Allows to define a path where the calculated variables can be cached.
            Arrays already present there are loaded instead of recomputed, and freshly
            computed arrays are saved there. If empty (the default), nothing is loaded
            or saved.
        :param start: first frame to analyze
        :param stop: last frame to analyze
        :param step: step of the analyzes
        """
        self.universe = atom_group.universe

        self.sorted_atoms = self.universe.atoms[
            [atom.ix for atom in sorted(atom_group.atoms, key=self.sort_key)]
        ]

        self.central_atom_indices = [
            i
            for i, atom in enumerate(self.sorted_atoms)
            if atom.name in ("N", "CA", "C")
        ]
        self.central_atoms = self.sorted_atoms[self.central_atom_indices]

        # Known issue carried over from the original code: the entries for
        # ILE, PRO and TRP are problematic.
        self.aminoaciddict = {
            "ALA": 0,
            "ARG": 5,
            "ASN": 2,
            "ASP": 2,
            "CYS": 1,
            "GLU": 3,
            "GLN": 3,
            "GLY": 0,
            "HIS": 2,
            "HID": 2,
            "ILE": 1,
            "LEU": 2,
            "LYS": 4,
            "MET": 3,
            "PHE": 2,
            "PRO": 0,
            "SER": 1,
            "THR": 1,
            "TRP": 2,
            "TYR": 2,
            "VAL": 1,
            "KAC": 4,
        }

        frames = {"start": start, "stop": stop, "step": step}

        self.cartesians = self._load_or_compute(
            cache_path,
            "cartesians.npy",
            "cartesians",
            "Loading Cartesians...",
            lambda: self._compute_cartesians(**frames),
        )
        self.central_cartesians = self.cartesians[:, self.central_atom_indices]

        self.dihedrals = self._load_or_compute(
            cache_path,
            "dihedrals.npy",
            "dihedrals",
            "Calculating dihedrals...",
            lambda: self._compute_dihedrals(**frames),
        )

        self.sidedihedrals = self._load_or_compute(
            cache_path,
            "sidedihedrals.npy",
            "sidedihedrals",
            "Calculating sidedihedrals...",
            lambda: self._compute_sidedihedrals(**frames),
        )

        self.angles = self._load_or_compute(
            cache_path,
            "angles.npy",
            "angles",
            "Calculating angles...",
            lambda: self._compute_angles(**frames),
        )

        # Lengths are derived from central_cartesians, so they need no frame range.
        self.lengths = self._load_or_compute(
            cache_path,
            "lengths.npy",
            "lengths",
            "Calculating lengths...",
            self._compute_lengths,
        )

        # Internal consistency: N central atoms give N-1 bonds, N-2 angles
        # and N-3 backbone dihedrals.
        assert self.lengths.shape[1] == self.central_cartesians.shape[1] - 1
        assert self.angles.shape[1] == self.central_cartesians.shape[1] - 2
        assert self.dihedrals.shape[1] == self.central_cartesians.shape[1] - 3

    @staticmethod
    def _load_or_compute(cache_path, filename, label, compute_message, compute):
        """
        Return the cached array ``filename`` from ``cache_path``, or compute it.

        :param cache_path: cache directory; if empty the cache is neither read nor written
        :param filename: name of the ``.npy`` file inside ``cache_path``
        :param label: name used in the "Loaded <label> from <cache_path>" message
        :param compute_message: message printed before computing the value
        :param compute: zero-argument callable producing the value
        :return: the loaded or freshly computed value
        """
        if cache_path:
            try:
                cached = np.load(os.path.join(cache_path, filename))
            except FileNotFoundError:
                pass
            else:
                print("Loaded {} from {}".format(label, cache_path))
                return cached

        print(compute_message)
        value = compute()
        if cache_path:
            # create_dir (from .misc) is expected to ensure the directory exists
            # and return its path.
            np.save(os.path.join(create_dir(cache_path), filename), value)
        return value

    def _compute_cartesians(self, start, stop, step):
        """Return the coordinates of ``sorted_atoms`` for every analysed frame as float32."""
        positions = Positions(self.sorted_atoms, verbose=True).run(
            start=start, stop=stop, step=step
        )
        return positions.result.astype(np.float32)

    def _compute_dihedrals(self, start, stop, step):
        """Return the backbone phi/psi/omega dihedrals (radians, float32) per frame."""
        dihedral_atoms = []
        for i in OrderedDict.fromkeys(self.sorted_atoms.resnums):
            phi_atoms = self.sorted_atoms.select_atoms(
                "resnum {} and name C".format(i - 1)
            ) + self.sorted_atoms.select_atoms(
                "resnum {} and (name N or name CA or name C)".format(i)
            )
            if len(phi_atoms) == 4:
                dihedral_atoms.append(phi_atoms.dihedral)

            psi_atoms = self.sorted_atoms.select_atoms(
                "resnum {} and (name N or name CA or name C)".format(i)
            ) + self.sorted_atoms.select_atoms("resnum {} and name N".format(i + 1))
            if len(psi_atoms) == 4:
                dihedral_atoms.append(psi_atoms.dihedral)

            omega_atoms = self.sorted_atoms.select_atoms(
                "resnum {} and (name CA or name C)".format(i)
            ) + self.sorted_atoms.select_atoms(
                "resnum {} and (name N or name CA)".format(i + 1)
            )
            # Check the omega group itself; an incomplete group cannot form a dihedral.
            if len(omega_atoms) == 4:
                dihedral_atoms.append(omega_atoms.dihedral)

        dihedrals = Dihedral(dihedral_atoms, verbose=True).run(
            start=start, stop=stop, step=step
        )
        result = dihedrals.angles.astype(np.float32)
        result *= pi / 180  # Dihedral reports degrees; convert to radians
        return result

    def _compute_sidedihedrals(self, start, stop, step):
        """
        Return sidechain dihedrals (radians, float32) per frame, or ``np.nan`` if none.

        NOTE: the atom quadruplets are consecutive slices of each residue's atoms,
        which does not follow the sidechain topology (see the emitted warnings).
        """
        sidedihedral_atoms = []
        for i in OrderedDict.fromkeys(self.sorted_atoms.resnums):
            residue_atoms = self.sorted_atoms.select_atoms("resnum {}".format(i))
            resname = self.universe.select_atoms(
                "resnum {} and name CA".format(i)
            ).resnames[0]
            for n in range(self.aminoaciddict[resname]):
                sidedihedral_atoms.append(residue_atoms[n : n + 4])

        if not sidedihedral_atoms:
            return np.nan

        warnings.showwarning(
            "\033[1;37;40m This version of the MolData Class does not produce expected results for side-dihedrals.",
            category=UserWarning,
            filename="",
            lineno=-1,
        )
        warnings.showwarning(
            "\033[1;37;40m To make this class work the 'residue_atoms[n:int(n+4)]' needs to be reworked. It does not index the sidechains.",
            category=UserWarning,
            filename="",
            lineno=-1,
        )
        sidedihedrals = Dihedral(sidedihedral_atoms, verbose=True).run(
            start=start, stop=stop, step=step
        )
        result = sidedihedrals.angles.astype(np.float32)
        result *= pi / 180  # Dihedral reports degrees; convert to radians
        return result

    def _compute_angles(self, start, stop, step):
        """Return the angles of all consecutive central-atom triplets per frame as float32."""
        angle_atoms = [
            self.sorted_atoms[self.central_atom_indices[i : i + 3]]
            for i in range(len(self.central_atom_indices) - 2)
        ]
        angles = Angles(angle_atoms, verbose=True).run(
            start=start, stop=stop, step=step
        )
        return angles.result.astype(np.float32)

    def _compute_lengths(self):
        """Return the distances between neighbouring central atoms per frame."""
        vecs = self.central_cartesians[:, :-1] - self.central_cartesians[:, 1:]
        return np.linalg.norm(vecs, axis=2)

    def __iadd__(self, other):
        """
        Append the frames of ``other`` to this instance in place.

        Both instances must describe the same atoms (checked via atom names).
        If neither instance has sidechain dihedrals (``np.nan`` placeholder),
        the result keeps ``np.nan``.

        :param other: another MolData built for the same atom selection
        :return: self, with every per-frame array extended along axis 0
        """
        assert np.all(self.sorted_atoms.names == other.sorted_atoms.names)
        self.cartesians = np.concatenate([self.cartesians, other.cartesians], axis=0)
        self.central_cartesians = np.concatenate(
            [self.central_cartesians, other.central_cartesians], axis=0
        )
        self.dihedrals = np.concatenate([self.dihedrals, other.dihedrals], axis=0)
        if np.ndim(self.sidedihedrals) == 0 and np.ndim(other.sidedihedrals) == 0:
            # Both are the 0-d NaN placeholder; np.concatenate cannot join 0-d values.
            self.sidedihedrals = np.nan
        else:
            self.sidedihedrals = np.concatenate(
                [self.sidedihedrals, other.sidedihedrals], axis=0
            )
        self.angles = np.concatenate([self.angles, other.angles], axis=0)
        self.lengths = np.concatenate([self.lengths, other.lengths], axis=0)
        return self

    @staticmethod
    def sort_key(atom):
        """
        Sort key placing atoms by residue number, then by backbone position.

        :param atom: object with ``name`` and ``resnum`` attributes (an MDAnalysis Atom)
        :return: tuple ``(resnum, rank)``; unknown atom names get rank 4
        """
        return atom.resnum, MolData._SORT_ORDER.get(atom.name, 4)

    def write(
        self,
        path,
        coordinates,
        name="generated",
        formats=("pdb", "xtc"),
        only_central=False,
        align_reference=None,
        align_select="all",
    ):
        """
        Writes a trajectory for the given coordinates.

        :param path: directory where to save the trajectory (must already exist)
        :param coordinates: numpy array of xyz coordinates (frames, atoms, xyz); a single
            frame (atoms, xyz) is also accepted
        :param name: filename (without extension)
        :param formats: specify which formats should be used to write structure and trajectory. default: ("pdb", "xtc")
        :param only_central: if True only central atom coordinates are expected (N-Ca-C...)
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.
        :return: None
        """
        coordinates = np.array(coordinates)
        if coordinates.ndim == 2:
            coordinates = np.expand_dims(coordinates, 0)

        atoms = self.central_atoms if only_central else self.sorted_atoms
        output_universe = md.Merge(atoms)
        # The structure file uses the atoms' current (original trajectory) coordinates.
        atoms.write(os.path.join(path, "{}.{}".format(name, formats[0])))
        output_universe.load_new(coordinates, format=MemoryReader)

        if align_reference is not None:
            # in_memory=True aligns output_universe's trajectory in place.
            align_traj = AlignTraj(
                output_universe, align_reference, align_select, in_memory=True
            )
            align_traj.run()

        trajectory_file = os.path.join(path, "{}.{}".format(name, formats[1]))
        # Multi-frame writers need the atom count up front.
        with md.Writer(trajectory_file, n_atoms=output_universe.atoms.n_atoms) as w:
            for _ in output_universe.trajectory:
                w.write(output_universe.atoms)