Coverage for encodermap/encodermap_tf1/moldata.py: 76%
154 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1import os
2import warnings
3from collections import OrderedDict
4from math import pi
6import MDAnalysis as md
7import numpy as np
8from MDAnalysis.analysis.align import AlignTraj
9from MDAnalysis.analysis.base import AnalysisBase
10from MDAnalysis.analysis.dihedrals import Dihedral
11from MDAnalysis.coordinates.memory import MemoryReader
12from MDAnalysis.lib.distances import calc_angles
13from tqdm import tqdm
15from .misc import create_dir
class Positions(AnalysisBase):
    """Collect the positions of an atom group for every analysed frame.

    After ``run()`` the stacked per-frame positions are available as the
    ``result`` attribute (a numpy array with one entry per frame).
    """

    def __init__(self, atomgroup, **kwargs):
        """
        :param atomgroup: MDAnalysis atom group whose positions are recorded
        :param kwargs: forwarded unchanged to ``AnalysisBase.__init__``
        """
        trajectory = atomgroup.universe.trajectory
        super().__init__(trajectory, **kwargs)
        self._ag = atomgroup

    def _prepare(self):
        # Start from an empty per-frame list on every run.
        self.result = []

    def _single_frame(self):
        frame_positions = self._ag.positions
        self.result.append(frame_positions)

    def _conclude(self):
        # Turn the list of per-frame arrays into a single array.
        stacked = np.asarray(self.result)
        self.result = stacked
class Angles(AnalysisBase):
    """Compute one angle per three-atom group for every analysed frame.

    After ``run()`` the values returned by ``calc_angles`` for each frame are
    stacked into the ``result`` attribute (one row per frame, one column per
    atom group).
    """

    def __init__(self, atomgroups, **kwargs):
        """
        :param atomgroups: sequence of MDAnalysis atom groups with exactly three atoms each
        :param kwargs: forwarded unchanged to ``AnalysisBase.__init__``
        :raises ValueError: if any atom group does not contain exactly three atoms
        """
        trajectory = atomgroups[0].universe.trajectory
        super().__init__(trajectory, **kwargs)
        self.atomgroups = atomgroups

        if any(len(group) != 3 for group in atomgroups):
            raise ValueError("All AtomGroups must contain 3 atoms")

        # Regroup column-wise: ag1 holds every first atom, ag2 every middle
        # atom and ag3 every last atom, so all angles are computed in one call.
        self.ag1, self.ag2, self.ag3 = (
            md.AtomGroup([group[position] for group in atomgroups])
            for position in range(3)
        )

    def _prepare(self):
        # Start from an empty per-frame list on every run.
        self.result = []

    def _single_frame(self):
        frame_angles = calc_angles(
            self.ag1.positions,
            self.ag2.positions,
            self.ag3.positions,
            box=self.ag1.dimensions,
        )
        self.result.append(frame_angles)

    def _conclude(self):
        # Turn the list of per-frame arrays into a single array.
        stacked = np.asarray(self.result)
        self.result = stacked
61class MolData:
62 """
63 MolData is designed to extract and hold conformational information from trajectories.
65 :ivar cartesians: numpy array of the trajectory atom coordinates
66 :ivar central_cartesians: cartesian coordinates of the central backbone atoms (N-CA-C-N-CA-C...)
67 :ivar dihedrals: all backbone dihederals (phi, psi, omega)
68 :ivar angles: all bond angles of the central backbone atoms
69 :ivar lengths: all bond lengths between neighbouring central atoms
70 :ivar sidedihedrals: all sidechain dihedrals
71 :ivar aminoaciddict: number of sidechain diheadrals
72 """
74 def __init__(
75 self,
76 atom_group,
77 cache_path="",
78 start=None,
79 stop=None,
80 step=None,
81 ):
82 """
83 :param atom_group: MDAnalysis atom group
84 :param cache_path: Allows to define a path where the calculated variables can be cached.
85 :param start: first frame to analyze
86 :param stop: last frame to analyze
87 :param step: step of the analyzes
88 """
89 self.universe = atom_group.universe
91 self.sorted_atoms = self.universe.atoms[
92 [atom.ix for atom in sorted(atom_group.atoms, key=self.sort_key)]
93 ]
95 self.central_atom_indices = [
96 i
97 for i, atom in enumerate(self.sorted_atoms)
98 if atom.name in ["N", "CA", "C"]
99 ]
100 self.central_atoms = self.sorted_atoms[self.central_atom_indices]
102 ######## Problems with ILE and PRO TRP
104 self.aminoaciddict = {
105 "ALA": 0,
106 "ARG": 5,
107 "ASN": 2,
108 "ASP": 2,
109 "CYS": 1,
110 "GLU": 3,
111 "GLN": 3,
112 "GLY": 0,
113 "HIS": 2,
114 "HID": 2,
115 "ILE": 1,
116 "LEU": 2,
117 "LYS": 4,
118 "MET": 3,
119 "PHE": 2,
120 "PRO": 0,
121 "SER": 1,
122 "THR": 1,
123 "TRP": 2,
124 "TYR": 2,
125 "VAL": 1,
126 "KAC": 4,
127 }
129 # Cartesians:
130 try:
131 self.cartesians = np.load(os.path.join(cache_path, "cartesians.npy"))
132 print("Loaded cartesians from {}".format(cache_path))
134 except FileNotFoundError:
135 print("Loading Cartesians...")
136 positions = Positions(self.sorted_atoms, verbose=True).run(
137 start=start, stop=stop, step=step
138 )
139 self.cartesians = positions.result.astype(np.float32)
141 if cache_path: 141 ↛ 142line 141 didn't jump to line 142, because the condition on line 141 was never true
142 np.save(
143 os.path.join(create_dir(cache_path), "cartesians.npy"),
144 self.cartesians,
145 )
147 self.central_cartesians = self.cartesians[:, self.central_atom_indices]
149 # Dihedrals:
150 try:
151 self.dihedrals = np.load(os.path.join(cache_path, "dihedrals.npy"))
152 print("Loaded dihedrals from {}".format(cache_path))
154 except FileNotFoundError:
155 print("Calculating dihedrals...")
156 dihedral_atoms = []
157 for i in OrderedDict.fromkeys(self.sorted_atoms.resnums):
158 phi_atoms = self.sorted_atoms.select_atoms(
159 "resnum {} and name C".format(i - 1)
160 ) + self.sorted_atoms.select_atoms(
161 "resnum {} and (name N or name CA or name C)".format(i)
162 )
163 if len(phi_atoms) == 4:
164 dihedral_atoms.append(phi_atoms.dihedral)
166 psi_atoms = self.sorted_atoms.select_atoms(
167 "resnum {} and (name N or name CA or name C)".format(i)
168 ) + self.sorted_atoms.select_atoms("resnum {} and name N".format(i + 1))
169 if len(psi_atoms) == 4:
170 dihedral_atoms.append(psi_atoms.dihedral)
172 omega_atoms = self.sorted_atoms.select_atoms(
173 "resnum {} and (name CA or name C)".format(i)
174 ) + self.sorted_atoms.select_atoms(
175 "resnum {} and (name N or name CA)".format(i + 1)
176 )
177 if len(psi_atoms) == 4:
178 dihedral_atoms.append(omega_atoms.dihedral)
180 dihedrals = Dihedral(dihedral_atoms, verbose=True).run(
181 start=start, stop=stop, step=step
182 )
183 self.dihedrals = dihedrals.angles.astype(np.float32)
184 self.dihedrals *= pi / 180
186 if cache_path: 186 ↛ 187line 186 didn't jump to line 187, because the condition on line 186 was never true
187 np.save(os.path.join(cache_path, "dihedrals.npy"), self.dihedrals)
189 # SideDihedrals
191 try:
192 self.sidedihedrals = np.load(os.path.join(cache_path, "sidedihedrals.npy"))
193 print("Loaded dihedrals from {}".format(cache_path))
195 except FileNotFoundError:
196 print("Calculating sidedihedrals...")
197 sidedihedral_atoms = []
199 for i in OrderedDict.fromkeys(self.sorted_atoms.resnums):
200 residue_atoms = self.sorted_atoms.select_atoms("resnum {}".format(i))
201 for n in range(
202 self.aminoaciddict[
203 self.universe.select_atoms(
204 "resnum {} and name CA".format(i)
205 ).resnames[0]
206 ]
207 ):
208 side_atoms = residue_atoms[n : int(n + 4)]
209 sidedihedral_atoms.append(side_atoms)
210 if sidedihedral_atoms == []:
211 self.sidedihedrals = np.nan
212 else:
213 warnings.showwarning(
214 "\033[1;37;40m This version of the MolData Class does not produce expected results for side-dihedrals.",
215 category=UserWarning,
216 filename="",
217 lineno=-1,
218 )
219 warnings.showwarning(
220 "\033[1;37;40m To make this class work the 'residue_atoms[n:int(n+4)]' needs to be reworked. It does not index the sidechains.",
221 category=UserWarning,
222 filename="",
223 lineno=-1,
224 )
225 sidedihedrals = Dihedral(sidedihedral_atoms, verbose=True).run(
226 start=start, stop=stop, step=step
227 )
228 self.sidedihedrals = sidedihedrals.angles.astype(np.float32)
229 self.sidedihedrals *= pi / 180
231 if cache_path: 231 ↛ 232line 231 didn't jump to line 232, because the condition on line 231 was never true
232 np.save(
233 os.path.join(cache_path, "sidedihedrals.npy"), self.sidedihedrals
234 )
236 # Angles:
237 try:
238 self.angles = np.load(os.path.join(cache_path, "angles.npy"))
239 print("Loaded angles from {}".format(cache_path))
241 except FileNotFoundError:
242 print("Calculating angles...")
243 angle_atoms = []
244 for i in range(len(self.central_atom_indices) - 2):
245 angle_atoms.append(
246 self.sorted_atoms[self.central_atom_indices[i : i + 3]]
247 )
249 angles = Angles(angle_atoms, verbose=True).run(
250 start=start, stop=stop, step=step
251 )
252 self.angles = angles.result.astype(np.float32)
254 if cache_path: 254 ↛ 255line 254 didn't jump to line 255, because the condition on line 254 was never true
255 np.save(os.path.join(create_dir(cache_path), "angles.npy"), self.angles)
257 # Lengths:
258 try:
259 self.lengths = np.load(os.path.join(cache_path, "lengths.npy"))
260 print("Loaded lengths from {}".format(cache_path))
262 except FileNotFoundError:
263 print("Calculating lengths...")
264 vecs = self.central_cartesians[:, :-1] - self.central_cartesians[:, 1:]
265 self.lengths = np.linalg.norm(vecs, axis=2)
266 if cache_path: 266 ↛ 267line 266 didn't jump to line 267, because the condition on line 266 was never true
267 np.save(
268 os.path.join(create_dir(cache_path), "lengths.npy"), self.lengths
269 )
271 assert self.lengths.shape[1] == self.central_cartesians.shape[1] - 1
272 assert self.angles.shape[1] == self.central_cartesians.shape[1] - 2
273 assert self.dihedrals.shape[1] == self.central_cartesians.shape[1] - 3
275 def __iadd__(self, other):
276 assert np.all(self.sorted_atoms.names == other.sorted_atoms.names)
277 self.cartesians = np.concatenate([self.cartesians, other.cartesians], axis=0)
278 self.central_cartesians = np.concatenate(
279 [self.central_cartesians, other.central_cartesians], axis=0
280 )
281 self.dihedrals = np.concatenate([self.dihedrals, other.dihedrals], axis=0)
282 self.sidedihedrals = np.concatenate(
283 [self.sidedihedrals, other.sidedihedrals], axis=0
284 )
285 self.angles = np.concatenate([self.angles, other.angles], axis=0)
286 self.lengths = np.concatenate([self.lengths, other.lengths], axis=0)
287 return self
289 @staticmethod
290 def sort_key(atom):
291 positions = {"N": 1, "CA": 2, "C": 5, "O": 6, "OXT": 7, "O1": 8, "O2": 9}
292 try:
293 result = positions[atom.name]
294 except KeyError:
295 result = 4
296 return atom.resnum, result
298 def write(
299 self,
300 path,
301 coordinates,
302 name="generated",
303 formats=("pdb", "xtc"),
304 only_central=False,
305 align_reference=None,
306 align_select="all",
307 ):
308 """
309 Writes a trajectory for the given coordinates.
311 :param path: directory where to save the trajectory
312 :param coordinates: numpy array of xyz coordinates (frames, atoms, xyz)
313 :param name: filename (without extension)
314 :param formats: specify which formats schould be used to write structure and trajectory. default: ("pdb", "xtc")
315 :param only_central: if True only central atom coordinates are expected (N-Ca-C...)
316 :param align_reference: Allows to allign the generated conformations according to some reference.
317 The reference should be given as MDAnalysis atomgroup
318 :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
319 default is "all". Have a look at the MDAnalysis selection syntax for more details.
320 :return:
321 """
322 coordinates = np.array(coordinates)
323 if coordinates.ndim == 2:
324 coordinates = np.expand_dims(coordinates, 0)
325 if only_central:
326 output_universe = md.Merge(self.central_atoms)
327 self.sorted_atoms[self.central_atom_indices].write(
328 os.path.join(path, "{}.{}".format(name, formats[0]))
329 )
330 else:
331 output_universe = md.Merge(self.sorted_atoms)
332 self.sorted_atoms.write(
333 os.path.join(path, "{}.{}".format(name, formats[0]))
334 )
335 output_universe.load_new(coordinates, format=MemoryReader)
337 if align_reference is not None:
338 align_traj = AlignTraj(
339 output_universe, align_reference, align_select, in_memory=True
340 )
341 align_traj.run()
343 with md.Writer(os.path.join(path, "{}.{}".format(name, formats[1]))) as w:
344 for step in output_universe.trajectory:
345 w.write(output_universe.atoms)