Coverage for encodermap/encodermap_tf1/moldata.py: 71%
172 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# Standard Library Imports
2import os
3import shutil
4import tempfile
5import warnings
6from collections import OrderedDict
7from math import pi
9# Third Party Imports
10import MDAnalysis as mda
11import mdtraj as md
12import numpy as np
13from MDAnalysis.analysis.align import AlignTraj
14from MDAnalysis.analysis.base import AnalysisBase
15from MDAnalysis.analysis.dihedrals import Dihedral
16from MDAnalysis.coordinates.memory import MemoryReader
17from MDAnalysis.lib.distances import calc_angles
18from mdtraj.geometry import dihedral as md_dihedral
19from tqdm import tqdm
21# Local Folder Imports
22from .misc import create_dir
class Positions(AnalysisBase):
    """Record the coordinates of an AtomGroup for every analysed frame.

    After ``run()`` finishes, ``result`` is a numpy array with one entry per
    frame, each entry being that frame's ``atomgroup.positions``.
    """

    def __init__(self, atomgroup, **kwargs):
        trajectory = atomgroup.universe.trajectory
        super().__init__(trajectory, **kwargs)
        self._ag = atomgroup

    def _prepare(self):
        # Per-frame snapshots, stacked into one array in _conclude.
        self.result = list()

    def _single_frame(self):
        frame_positions = self._ag.positions
        self.result.append(frame_positions)

    def _conclude(self):
        stacked = np.asarray(self.result)
        self.result = stacked
class Angles(AnalysisBase):
    """Compute one bond angle per 3-atom AtomGroup for every analysed frame.

    For each group the angle is taken at the second atom, between the first
    and third atoms, via ``calc_angles`` using the current frame's box
    dimensions. After ``run()``, ``result`` is a numpy array with one row per
    frame and one column per group.

    :raises ValueError: if any group does not contain exactly 3 atoms.
    """

    def __init__(self, atomgroups, **kwargs):
        universe = atomgroups[0].universe
        super().__init__(universe.trajectory, **kwargs)
        self.atomgroups = atomgroups

        if not all(len(group) == 3 for group in atomgroups):
            raise ValueError("All AtomGroups must contain 3 atoms")

        # Collect the k-th atom of every group so that each frame needs a
        # single vectorised calc_angles call.
        self.ag1, self.ag2, self.ag3 = (
            mda.AtomGroup([group[k] for group in atomgroups]) for k in range(3)
        )

    def _prepare(self):
        self.result = list()

    def _single_frame(self):
        vertices = (self.ag1.positions, self.ag2.positions, self.ag3.positions)
        frame_angles = calc_angles(*vertices, box=self.ag1.dimensions)
        self.result.append(frame_angles)

    def _conclude(self):
        stacked = np.asarray(self.result)
        self.result = stacked
class MolData:
    """
    MolData is designed to extract and hold conformational information from trajectories.

    :ivar cartesians: numpy array of the trajectory atom coordinates (frames, atoms, xyz)
    :ivar central_cartesians: cartesian coordinates of the central backbone atoms (N-CA-C-N-CA-C...)
    :ivar dihedrals: all backbone dihedrals (phi, psi, omega) in radians, ordered along the backbone
    :ivar angles: all bond angles of the central backbone atoms
        (as returned by MDAnalysis ``calc_angles``)
    :ivar lengths: all bond lengths between neighbouring central atoms
    :ivar sidedihedrals: all sidechain dihedrals in radians, or ``np.nan`` if the
        system has none
    :ivar aminoaciddict: number of sidechain dihedrals per residue name
    """

    def __init__(
        self,
        atom_group,
        cache_path="",
        start=None,
        stop=None,
        step=None,
    ):
        """
        :param atom_group: MDAnalysis atom group
        :param cache_path: Allows to define a path where the calculated variables can be cached.
            If empty (the default), nothing is read from or written to disk.
        :param start: first frame to analyze
        :param stop: frame at which the analysis stops (forwarded to MDAnalysis ``run``)
        :param step: step of the analysis
        """
        self.universe = atom_group.universe

        # Reorder atoms per residue as N, CA, <other atoms>, C, O, OXT, ... (see sort_key).
        self.sorted_atoms = self.universe.atoms[
            [atom.ix for atom in sorted(atom_group.atoms, key=self.sort_key)]
        ]

        self.central_atom_indices = [
            i
            for i, atom in enumerate(self.sorted_atoms)
            if atom.name in ["N", "CA", "C"]
        ]
        self.central_atoms = self.sorted_atoms[self.central_atom_indices]

        # NOTE(original author): problems with ILE and PRO TRP -- the counts for
        # these residues may not match every chi-angle convention; verify before use.
        self.aminoaciddict = {
            "ALA": 0,
            "ARG": 5,
            "ASN": 2,
            "ASP": 2,
            "CYS": 1,
            "GLU": 3,
            "GLN": 3,
            "GLY": 0,
            "HIS": 2,
            "HID": 2,
            "ILE": 1,
            "LEU": 2,
            "LYS": 4,
            "MET": 3,
            "PHE": 2,
            "PRO": 0,
            "SER": 1,
            "THR": 1,
            "TRP": 2,
            "TYR": 2,
            "VAL": 1,
            "KAC": 4,
        }

        # Cartesians:
        cartesians = self._load_cached(cache_path, "cartesians.npy")
        if cartesians is not None:
            self.cartesians = cartesians
            print("Loaded cartesians from {}".format(cache_path))
        else:
            print("Loading Cartesians...")
            positions = Positions(self.sorted_atoms, verbose=True).run(
                start=start, stop=stop, step=step
            )
            self.cartesians = positions.result.astype(np.float32)
            self._save_cached(cache_path, "cartesians.npy", self.cartesians)

        self.central_cartesians = self.cartesians[:, self.central_atom_indices]

        # Dihedrals:
        dihedrals = self._load_cached(cache_path, "dihedrals.npy")
        if dihedrals is not None:
            self.dihedrals = dihedrals
            print("Loaded dihedrals from {}".format(cache_path))
        else:
            print("Calculating dihedrals...")
            self.dihedrals = self._calc_dihedrals(start, stop, step)
            self._save_cached(cache_path, "dihedrals.npy", self.dihedrals)

        # SideDihedrals: an all-NaN cache file is treated as missing and recomputed.
        sidedihedrals = self._load_cached(cache_path, "sidedihedrals.npy")
        if sidedihedrals is not None and not np.all(np.isnan(sidedihedrals)):
            self.sidedihedrals = sidedihedrals
            print("Loaded sidedihedrals from {}".format(cache_path))
        else:
            print("Calculating sidedihedrals...")
            self.sidedihedrals = self._calc_sidedihedrals(start, stop, step)
            self._save_cached(cache_path, "sidedihedrals.npy", self.sidedihedrals)

        # Angles:
        angles = self._load_cached(cache_path, "angles.npy")
        if angles is not None:
            self.angles = angles
            print("Loaded angles from {}".format(cache_path))
        else:
            print("Calculating angles...")
            self.angles = self._calc_angles(start, stop, step)
            self._save_cached(cache_path, "angles.npy", self.angles)

        # Lengths:
        lengths = self._load_cached(cache_path, "lengths.npy")
        if lengths is not None:
            self.lengths = lengths
            print("Loaded lengths from {}".format(cache_path))
        else:
            print("Calculating lengths...")
            vecs = self.central_cartesians[:, :-1] - self.central_cartesians[:, 1:]
            self.lengths = np.linalg.norm(vecs, axis=2)
            self._save_cached(cache_path, "lengths.npy", self.lengths)

        # A chain of k central atoms has k-1 bonds, k-2 angles and k-3 dihedrals.
        assert self.lengths.shape[1] == self.central_cartesians.shape[1] - 1
        assert self.angles.shape[1] == self.central_cartesians.shape[1] - 2
        assert self.dihedrals.shape[1] == self.central_cartesians.shape[1] - 3

    @staticmethod
    def _load_cached(cache_path, filename):
        """Return the array stored at ``cache_path/filename``, or None.

        Nothing is loaded when ``cache_path`` is empty, so a file with the same
        name in the current working directory is never picked up by accident.
        """
        if not cache_path:
            return None
        try:
            return np.load(os.path.join(cache_path, filename))
        except FileNotFoundError:
            return None

    @staticmethod
    def _save_cached(cache_path, filename, array):
        """Save ``array`` to ``cache_path/filename`` when a cache path is set."""
        if cache_path:
            np.save(os.path.join(create_dir(cache_path), filename), array)

    def _calc_dihedrals(self, start, stop, step):
        """Compute the backbone phi/psi/omega dihedrals (float32 radians)."""
        dihedral_atoms = []
        for i in OrderedDict.fromkeys(self.sorted_atoms.resnums):
            phi_atoms = self.sorted_atoms.select_atoms(
                "resnum {} and name C".format(i - 1)
            ) + self.sorted_atoms.select_atoms(
                "resnum {} and (name N or name CA or name C)".format(i)
            )
            if len(phi_atoms) == 4:
                dihedral_atoms.append(phi_atoms.dihedral)

            psi_atoms = self.sorted_atoms.select_atoms(
                "resnum {} and (name N or name CA or name C)".format(i)
            ) + self.sorted_atoms.select_atoms("resnum {} and name N".format(i + 1))
            if len(psi_atoms) == 4:
                dihedral_atoms.append(psi_atoms.dihedral)

            omega_atoms = self.sorted_atoms.select_atoms(
                "resnum {} and (name CA or name C)".format(i)
            ) + self.sorted_atoms.select_atoms(
                "resnum {} and (name N or name CA)".format(i + 1)
            )
            # Check omega's own atoms: the next residue may provide N (so psi
            # exists) but lack CA, in which case .dihedral would raise.
            if len(omega_atoms) == 4:
                dihedral_atoms.append(omega_atoms.dihedral)

        dihedrals = Dihedral(dihedral_atoms, verbose=True).run(
            start=start, stop=stop, step=step
        )
        result = dihedrals.angles.astype(np.float32)
        result *= pi / 180  # MDAnalysis Dihedral reports degrees
        return result

    def _calc_sidedihedrals(self, start, stop, step):
        """Compute chi1-chi5 side chain dihedrals (float32 radians), or np.nan if none exist."""
        # mdtraj identifies the chi dihedrals from a topology, obtained by
        # round-tripping the universe through a PDB file. A temporary directory
        # is used (instead of NamedTemporaryFile) so the file can be reopened
        # by name on every platform, including Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            pdb_path = os.path.join(tmp_dir, "topology.pdb")
            with mda.Writer(pdb_path) as writer:
                writer.write(self.universe)
            top = md.load_pdb(pdb_path).top

        sidedihedral_atoms_inds = np.vstack(
            [getattr(md_dihedral, f"indices_chi{k}")(top) for k in range(1, 6)]
        )

        sidedihedral_atoms = []
        for ind in sidedihedral_atoms_inds:
            # The PDB was written from self.universe in atom order, so mdtraj
            # indices are universe atom indices.
            ag = self.universe.atoms[ind]
            assert len(ag) == 4
            sidedihedral_atoms.append(ag.dihedral)

        if not sidedihedral_atoms:
            return np.nan

        sidedihedrals = Dihedral(sidedihedral_atoms, verbose=True).run(
            start=start, stop=stop, step=step
        )
        result = sidedihedrals.angles.astype(np.float32)
        result *= pi / 180  # MDAnalysis Dihedral reports degrees
        return result

    def _calc_angles(self, start, stop, step):
        """Compute the bond angles between consecutive central atom triplets (float32)."""
        angle_atoms = [
            self.sorted_atoms[self.central_atom_indices[i : i + 3]]
            for i in range(len(self.central_atom_indices) - 2)
        ]
        angles = Angles(angle_atoms, verbose=True).run(
            start=start, stop=stop, step=step
        )
        return angles.result.astype(np.float32)

    def __iadd__(self, other):
        """Append the frames of ``other`` to this MolData in place.

        Both objects must describe the same atoms (compared by name, in order).

        :raises AssertionError: if the atom names differ (including in length).
        """
        # array_equal also handles arrays of different length (returns False).
        assert np.array_equal(self.sorted_atoms.names, other.sorted_atoms.names)
        self.cartesians = np.concatenate([self.cartesians, other.cartesians], axis=0)
        self.central_cartesians = np.concatenate(
            [self.central_cartesians, other.central_cartesians], axis=0
        )
        self.dihedrals = np.concatenate([self.dihedrals, other.dihedrals], axis=0)
        # Systems without side chain dihedrals store a scalar NaN placeholder,
        # which np.concatenate rejects; keep the placeholder in that case.
        if np.ndim(self.sidedihedrals) or np.ndim(other.sidedihedrals):
            self.sidedihedrals = np.concatenate(
                [self.sidedihedrals, other.sidedihedrals], axis=0
            )
        self.angles = np.concatenate([self.angles, other.angles], axis=0)
        self.lengths = np.concatenate([self.lengths, other.lengths], axis=0)
        return self

    @staticmethod
    def sort_key(atom):
        """Sort key ordering atoms by residue, then N, CA, other atoms, C, O, OXT, O1, O2."""
        positions = {"N": 1, "CA": 2, "C": 5, "O": 6, "OXT": 7, "O1": 8, "O2": 9}
        try:
            result = positions[atom.name]
        except KeyError:
            result = 4  # any other atom sits between CA and C
        return atom.resnum, result

    def write(
        self,
        path,
        coordinates,
        name="generated",
        formats=("pdb", "xtc"),
        only_central=False,
        align_reference=None,
        align_select="all",
    ):
        """
        Writes a trajectory for the given coordinates.

        :param path: directory where to save the trajectory
        :param coordinates: numpy array of xyz coordinates (frames, atoms, xyz); a single
            frame (atoms, xyz) is also accepted
        :param name: filename (without extension)
        :param formats: specify which formats should be used to write structure and trajectory.
            default: ("pdb", "xtc")
        :param only_central: if True only central atom coordinates are expected (N-Ca-C...)
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.
        :return:
        """
        coordinates = np.array(coordinates)
        if coordinates.ndim == 2:
            coordinates = np.expand_dims(coordinates, 0)

        atoms = self.central_atoms if only_central else self.sorted_atoms
        output_universe = mda.Merge(atoms)
        # The structure file carries the topology (with the source universe's
        # current coordinates); the generated coordinates go to the trajectory.
        atoms.write(os.path.join(path, "{}.{}".format(name, formats[0])))
        output_universe.load_new(coordinates, format=MemoryReader)

        if align_reference is not None:
            align_traj = AlignTraj(
                output_universe, align_reference, align_select, in_memory=True
            )
            align_traj.run()

        # Some trajectory writers (e.g. DCD) require n_atoms up front.
        with mda.Writer(
            os.path.join(path, "{}.{}".format(name, formats[1])),
            n_atoms=output_universe.atoms.n_atoms,
        ) as w:
            for _ in output_universe.trajectory:
                w.write(output_universe.atoms)