Coverage for encodermap/misc/clustering.py: 32%
163 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/misc/clustering.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Functions for building clusters."""
25################################################################################
26# Imports
27################################################################################
30import copy
31import warnings
33import matplotlib as mpl
34import numpy as np
36from .._optional_imports import _optional_import
37from .errors import BadError
39##############################################################################
40# Optional Imports
41##############################################################################
44md = _optional_import("mdtraj")
45ngl = _optional_import("nglview")
48################################################################################
49# Globals
50################################################################################
53__all__ = ["gen_dummy_traj", "get_cluster_frames"]
56################################################################################
57# Public Functions
58################################################################################
def _get_joined_trajs(
    trajs, cluster_no, shorten=True, max_frames=-1, col="cluster_membership"
):
    """Collect the single-frame trajectories that make up one cluster.

    Args:
        trajs: Object that provides ``CVs[col]`` (compared element-wise with
            ``cluster_no``) and ``get_single_frame(index)``, whose return
            value carries the frame in its ``traj`` attribute.
        cluster_no (Union[int, np.integer, list, np.ndarray]): Either a
            cluster id, in which case the frame indices are looked up via
            ``trajs.CVs[col] == cluster_no``, or an explicit sequence of
            frame indices that is used as-is.
        shorten (bool, optional): If True, the frame indices are resampled
            to exactly 10 evenly spaced entries (entries repeat when fewer
            than 10 frames are available). Defaults to True.
        max_frames (int, optional): If not -1, the frame indices are
            resampled to ``max_frames`` evenly spaced entries before the
            ``shorten`` step. Defaults to -1.
        col (str, optional): Key into ``trajs.CVs`` that holds the cluster
            membership. Defaults to 'cluster_membership'.

    Returns:
        tuple: A tuple containing:

            joined_trajs (list): One entry per selected frame index, in order.
            where (np.ndarray): The selected frame indices.

    Raises:
        IndexError: If a frame index can not be resolved by
            ``trajs.get_single_frame`` or if resampling is requested on an
            empty selection.

    """
    n_shortened = 10

    # np.integer covers every NumPy integer width (np.int32, np.int64, ...),
    # not only the platform default that np.int_ stands for.
    if isinstance(cluster_no, (int, np.integer)):
        where = np.where(trajs.CVs[col] == cluster_no)[0]
    else:
        # explicit frame indices; asarray makes plain lists fancy-indexable
        where = np.asarray(cluster_no)

    # stride the selection to make later calculations faster
    if max_frames != -1:
        idx = np.round(np.linspace(0, len(where) - 1, max_frames)).astype(int)
        where = where[idx]

    # reduce the selection to a fixed number of evenly spaced structures
    if shorten:
        idx = np.round(np.linspace(0, len(where) - 1, n_shortened)).astype(int)
        where = where[idx]

    joined_trajs = []
    for point in where:
        try:
            joined_trajs.append(trajs.get_single_frame(point).traj)
        except IndexError as err:
            # keep the exception type, but name the offending frame index
            raise IndexError(
                f"Could not get the single frame with index {point}: {err}"
            ) from err
    return joined_trajs, where
def get_cluster_frames(
    trajs,
    cluster_no,
    align_string="name CA",
    nglview=False,
    stack_atoms=False,
    shorten=False,
    max_frames=-1,
    superpose=True,
    col="cluster_membership",
    subunit="",
    ball_and_stick=False,
    cmap="viridis",
):
    """Returns the (optionally superposed) frames of a single cluster.

    Only the combination ``stack_atoms=True`` and ``subunit=''`` with an
    integer ``cluster_no`` is implemented. All other combinations are rejected
    before any frame is loaded, because the outcome depends on the arguments
    alone.

    Args:
        trajs: The clustered trajectories. Passed on to `_get_joined_trajs`;
            for ``nglview=True`` its ``common_str`` attribute and
            ``get_single_frame`` method are used to pick the colors.
        cluster_no (Union[int, np.integer]): The cluster to collect.
        align_string (str, optional): Atom selection string handed to
            ``top.select`` of every frame. The selected atoms are used for the
            superposition. Defaults to 'name CA'.
        nglview (bool, optional): Whether to build an nglview widget with one
            component per frame. Defaults to False.
        stack_atoms (bool, optional): Has to be True, otherwise an Exception
            is raised. Defaults to False.
        shorten (bool, optional): Passed on to `_get_joined_trajs`.
            Defaults to False.
        max_frames (int, optional): Passed on to `_get_joined_trajs`.
            Defaults to -1.
        superpose (Union[bool, mdtraj.Trajectory], optional): If True, all
            frames are superposed onto a copy of the first frame. If a
            non-bool object is given, it is used as the reference frame.
            Defaults to True.
        col (str, optional): Passed on to `_get_joined_trajs`.
            Defaults to 'cluster_membership'.
        subunit (str, optional): Not implemented. Any non-empty value raises
            NotImplementedError. Defaults to ''.
        ball_and_stick (bool, optional): For ``nglview=True``: draw the
            backbone as hyperball instead of ribbon. Defaults to False.
        cmap (str, optional): Name of the matplotlib colormap used to color
            the nglview components. Defaults to 'viridis'.

    Returns:
        tuple: A tuple containing:

            view: The nglview widget if ``nglview`` is True and at least one
                frame was collected, None otherwise.
            joined_trajs (list): The collected frames.

    Raises:
        NotImplementedError: If ``cluster_no`` is not an integer or if
            ``subunit`` is given.
        Exception: If ``stack_atoms`` is False.
        BadError: If the superposition fails with an AttributeError.

    """
    if not isinstance(cluster_no, (int, np.integer)):
        raise NotImplementedError()

    # Both rejections depend only on the arguments, so fail before the
    # (potentially expensive) frame loading and superposition.
    if subunit:
        raise NotImplementedError()
    # only stacking possible here
    if not stack_atoms:
        raise Exception(
            "Cannot build a time resolved traj from topologies with differing atom count."
        )

    joined_trajs, where = _get_joined_trajs(trajs, cluster_no, shorten, max_frames, col)

    # preset view and only compute it if nglview = True
    view = None

    # if the frames differ in atom count, restrict the alignment selection to
    # the residues that all of them have
    if len({t.n_atoms for t in joined_trajs}) > 1:
        smallest_n_residues = min(t.n_residues for t in joined_trajs)
        aligns = [
            t.top.select(f"{align_string} and resid < {smallest_n_residues}")
            for t in joined_trajs
        ]
    else:
        aligns = [t.top.select(f"{align_string}") for t in joined_trajs]

    if superpose:
        if isinstance(superpose, bool):
            ref_frame = copy.deepcopy(joined_trajs[0])
        else:
            ref_frame = superpose

        try:
            for i, traj in enumerate(joined_trajs):
                joined_trajs[i] = traj.superpose(
                    ref_frame,
                    atom_indices=aligns[i],
                    ref_atom_indices=aligns[0],
                )
        except AttributeError as e:
            raise BadError(
                f"You provided some wrong datatype or a misformatted string into the argument align_string. Here's the original error: {e}"
            ) from e

    if nglview:
        # mpl.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9.
        # Prefer the colormap registry and fall back for old versions.
        registry = getattr(mpl, "colormaps", None)
        if registry is None:
            colormap = mpl.cm.get_cmap(cmap).copy()
        elif isinstance(cmap, mpl.colors.Colormap):
            colormap = cmap
        else:
            colormap = registry[cmap]

        common_str = trajs.common_str
        n_colors = len(common_str) if common_str else len(joined_trajs)
        colors = colormap(np.linspace(0, 1, n_colors))

        for component, (frame, w) in enumerate(zip(joined_trajs, where)):
            if common_str:
                rgba = colors[common_str.index(trajs.get_single_frame(w).common_str)]
            else:
                rgba = colors[component]
            color = mpl.colors.rgb2hex(rgba)
            if component == 0:
                view = ngl.show_mdtraj(frame)
            else:
                view.add_component(frame)
            view.clear_representations(component=component)
            if ball_and_stick:
                view.add_hyperball(
                    selection="backbone", component=component, color=color
                )
            else:
                view.add_ribbon(
                    selection="backbone", component=component, color=color
                )

    return view, joined_trajs
def gen_dummy_traj(
    trajs,
    cluster_no,
    align_string="name CA",
    nglview=False,
    stack_atoms=False,
    shorten=False,
    max_frames=-1,
    superpose=True,
    col="cluster_membership",
    subunit="",
    ref_align_string="name CA",
    base_traj=None,
):
    """Makes a dummy traj from an encodermap trajectory which contains
    trajectories with different topology.

    This function takes an encodermap.TrajEnsemble object and returns mdtraj
    trajectories for clustered data. This function can concatenate trajs even
    if the topology of trajectories in the TrajEnsemble class is different. The
    topology of this dummy traj will be wrong, but the atomic positions are
    correct.

    This function constructs a traj of length cluster_membership.count(cluster_no)
    with the topology of the first frame of this cluster
    (trajs.get_single_frame(cluster_membership.index(cluster_no))) and changes the
    atomic coordinates of this traj based on the other frames in this cluster.

    Note:
        If you have loaded the encodermap functions with the 'no_load'
        backend a second call to this function with the same parameters will
        be faster, because the trajectory frames have been loaded to memory.

    Args:
        trajs (encodermap.TrajEnsemble): The trajs which were clustered.
        cluster_no (Union[int, np.integer, np.ndarray, list]): The cluster_no of the cluster to make the dummy traj from.
            Can be:
            * int or np.integer: The cluster will be found by using the trajs own cluster_membership in the trajs pandas dataframe.
            * np.ndarray or list: One dummy traj (and view) is built per entry. This is only implemented
                for `stack_atoms` = True and an empty `subunit`.
        align_string (str, optional): Use this mdtraj atom selection string to align the frames
            of the dummy traj. Defaults to 'name CA'.
        nglview (bool, optional): Whether to return a tuple of an nglview.view object and the traj
            or not. Defaults to False.
        stack_atoms (bool, optional): Whether to stack all frames into a single frame with
            multiple structures in it. This option is useful, if you want to
            generate a picture of interpenetrating structures. Defaults to False.
        shorten (bool, optional): Whether to return all structures or just a subset of
            roughly ten structures. Defaults to False.
        max_frames (int, optional): Only return so many frames. If set to -1 all frames will
            be returned. Defaults to -1.
        superpose (Union[bool, mdtraj.Trajectory], optional): Whether the frames of the returned traj should be superposed
            to frame 0 of the traj. If an mdtraj Trajectory is provided this trajectory is used to superpose. Defaults to True.
        col (str, optional): The column in which the cluster membership is looked up. Defaults to 'cluster_membership'.
        subunit (str, optional): When you want to only visualize an ensemble of certain parts of your protein but keep some
            part stationary (see `align_string`), you can provide a mdtraj selection string. This part of the
            protein will only be rendered from the first frame. The other parts will be rendered as an ensemble
            of structures (either along atom (`stack_atoms` = True) or time (`stack_atoms` = False)). Defaults to ''.
        ref_align_string (str, optional): When the type of `superpose` is mdtraj.Trajectory with a different topology
            than `trajs`, you can give a different align string into this argument. Defaults to 'name CA'.
        base_traj (Union[None, mdtraj.Trajectory], optional): An mdtraj.Trajectory that will be set to the coordinates from
            trajs, instead of trajs[0]. Normally, the first traj in `trajs` (trajs[0]) will be used as a base traj.
            It will be extended into the time-direction until it has the desired number of frames (shorten=True; 10,
            max_frames=N, N; etc.). If you don't want to use this traj but something else, you can feed this option
            an mdtraj.Trajectory object. Defaults to None.

    Returns:
        tuple: A tuple containing:

            view (nglview.view): The nglview.view object if nglview == True,
                is None otherwise.
            dummy_traj (mdtraj.Trajectory): The mdtraj trajectory with wrong
                topology but correct atomic positions.

        If `cluster_no` is a list or np.ndarray, both entries of the tuple are
        lists with one element per requested cluster. With nglview == True the
        stacked structures of all clusters are additionally added to the first view.

    Raises:
        NotImplementedError: If `cluster_no` is a list or np.ndarray and
            `stack_atoms` is False or `subunit` is given.
        TypeError: If `cluster_no` is neither an integer nor a list or np.ndarray.

    See also:
        See the render_vmd function in this document
        to render an image of the returned traj.

    """
    # everything except cluster_no is handed through unchanged
    shared_args = (
        align_string,
        nglview,
        stack_atoms,
        shorten,
        max_frames,
        superpose,
        col,
        subunit,
        ref_align_string,
        base_traj,
    )

    # np.integer covers all NumPy integer widths, not just the platform default
    if isinstance(cluster_no, (int, np.integer)):
        return _gen_dummy_traj_single(trajs, cluster_no, *shared_args)

    if not isinstance(cluster_no, (list, np.ndarray)):
        raise TypeError(
            f"`cluster_no` must be int or list. You supplied {type(cluster_no)}."
        )

    # These combinations are not implemented. Reject them before any cluster
    # is processed instead of after all the work has been done.
    if not stack_atoms:
        raise NotImplementedError(
            "Joining along time axes with multiple cluster's doesn't seem to make sense. Make a proposal how to handle this!"
        )
    if subunit:
        raise NotImplementedError("Not yet Implemented")

    dummy_trajs, views = [], []
    for single_cluster in cluster_no:
        single_view, single_traj = _gen_dummy_traj_single(
            trajs, single_cluster, *shared_args
        )
        dummy_trajs.append(single_traj)
        views.append(single_view)

    # The stacked traj is only consumed by the first view. Without nglview the
    # views are None and there is nothing to add it to.
    view = views[0]
    if view is not None:
        stacked = dummy_trajs[0]
        for other in dummy_trajs[1:]:
            stacked = stacked.stack(other)
        view.add_trajectory(stacked)
    return views, dummy_trajs
def _drop_subunit(frame, subunit):
    """Returns `frame` without the atoms matched by the selection `subunit`."""
    keep = np.setdiff1d(frame.top.select("all"), frame.top.select(subunit))
    return frame.atom_slice(keep)


def _gen_dummy_traj_single(
    trajs,
    cluster_no,
    align_string="name CA",
    nglview=False,
    stack_atoms=False,
    shorten=False,
    max_frames=-1,
    superpose=True,
    col="cluster_membership",
    subunit="",
    ref_align_string="name CA",
    base_traj=None,
):
    """Called when only one cluster is needed.

    Builds a trajectory with the topology of a parent traj (`base_traj` or the
    first traj of `trajs` loaded from disk) and overwrites its coordinates
    with the coordinates of the frames in the requested cluster.

    Args:
        trajs: The clustered trajectories. ``trajs[0].index``,
            ``trajs.locations[0]`` and ``trajs.top_files[0]`` are used to load
            the parent traj when `base_traj` is None.
        cluster_no: Passed on to `_get_joined_trajs`.
        align_string (str, optional): Atom selection of the dummy traj used
            for the superposition. Defaults to 'name CA'.
        nglview (bool, optional): Whether to build an nglview widget.
            Defaults to False.
        stack_atoms (bool, optional): Whether to stack all frames into a
            single frame. Defaults to False.
        shorten (bool, optional): Passed on to `_get_joined_trajs`.
            Defaults to False.
        max_frames (int, optional): Passed on to `_get_joined_trajs`.
            Defaults to -1.
        superpose (Union[bool, mdtraj.Trajectory], optional): True to
            superpose onto frame 0 of the dummy traj, or a reference
            trajectory to superpose onto. Defaults to True.
        col (str, optional): Passed on to `_get_joined_trajs`.
            Defaults to 'cluster_membership'.
        subunit (str, optional): Atom selection that is only kept from the
            first frame when stacking or rendering. Defaults to ''.
        ref_align_string (str, optional): Atom selection of the reference
            frame used for the superposition. Defaults to 'name CA'.
        base_traj (Union[None, mdtraj.Trajectory], optional): Parent traj to
            use instead of loading the first traj of `trajs`. Defaults to None.

    Returns:
        tuple: A tuple containing:

            view: The nglview widget if `nglview` is True, None otherwise.
            dummy_traj: The trajectory carrying the cluster's coordinates.

    Raises:
        Exception: If no dummy trajectory can be built (e.g. the cluster is
            empty or the parent traj has no frames).
        BadError: If the superposition fails with an AttributeError.

    """
    joined_trajs, where = _get_joined_trajs(trajs, cluster_no, shorten, max_frames, col)

    # preset view and only compute it if nglview = True
    view = None

    # use traj[0] of the trajs list as the traj from which the topology will be used
    # or use base_traj, if provided
    if base_traj is None:
        first_index = trajs[0].index
        if isinstance(first_index, slice):
            parent_traj = md.load(
                trajs.locations[0], top=trajs.top_files[0], stride=first_index.step
            )
        else:
            parent_traj = md.load(trajs.locations[0], top=trajs.top_files[0])
        parent_traj = parent_traj[: len(joined_trajs)]
    else:
        parent_traj = base_traj

    # print some info
    if align_string:
        sel = parent_traj.top.select(align_string)
        print(
            f"Provided alignment string results in {len(sel)} atoms. First atom is {parent_traj.top.atom(sel[0])}. Last atom is {parent_traj.top.atom(sel[-1])}."
        )

    # Repeat the parent traj as often as it fits into the requested number of
    # frames and fill the remainder with a slice of it. A parent traj that is
    # longer than the cluster therefore contributes only its first frames.
    n_needed = len(where)
    try:
        if n_needed == 0:
            raise ValueError("The selected cluster does not contain any frames.")
        n_repeats, rest = divmod(n_needed, parent_traj.n_frames)
    except Exception as e:
        raise Exception(
            f"Can not buid a dummy trajectory. Maybe you selected the wronmg cluster num. Here's the original Error: {e}"
        ) from e
    pieces = [parent_traj] * n_repeats
    if rest:
        pieces.append(parent_traj[:rest])
    # deepcopy, because the coordinates are overwritten in place below
    dummy_traj = copy.deepcopy(pieces[0])
    for piece in pieces[1:]:
        dummy_traj = dummy_traj.join(piece)

    # change the xyz coordinates of dummy_traj according to the frames in joined trajs
    for i, traj in enumerate(joined_trajs):
        dummy_traj.xyz[i] = traj.xyz

    # if superpose superpose the trajs
    if superpose:
        if isinstance(superpose, bool):
            ref_frame = copy.deepcopy(dummy_traj[0])
        else:
            ref_frame = superpose
        try:
            dummy_traj.superpose(
                ref_frame,
                atom_indices=dummy_traj.top.select(align_string),
                ref_atom_indices=ref_frame.top.select(ref_align_string),
            )
        except AttributeError as e:
            raise BadError(
                f"You provided some wrong datatype or a misformatted string into the argument align_string. Here's the original error: {e}"
            ) from e

    # Keep a handle on the time-resolved traj. The subunit rendering below
    # needs the individual frames regardless of whether they get stacked.
    unstacked = dummy_traj

    # if stack_atoms is true overwrite dummy_traj
    if stack_atoms:
        for i in range(unstacked.n_frames):
            frame = unstacked[i]
            if i == 0:
                dummy_traj = copy.deepcopy(frame)
            else:
                if subunit:
                    frame = _drop_subunit(frame, subunit)
                dummy_traj = dummy_traj.stack(frame)

    # make nglview
    if nglview and not subunit:
        view = ngl.show_mdtraj(dummy_traj)
    elif nglview:
        # catch_warnings restores the caller's warning filters afterwards
        # instead of overwriting them globally
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            view = ngl.show_mdtraj(unstacked[0])
            view.clear_representations()
            view.add_representation("cartoon", color="red")
            for i in range(1, unstacked.n_frames):
                frame = _drop_subunit(unstacked[i], subunit)
                try:
                    view.add_trajectory(frame)
                except KeyError as e:
                    # for some reason it doesn't want some atoms. Let's remove them.
                    # str(KeyError) wraps the key in quotes, so also compare
                    # against the bare key.
                    wanted = {str(e)}
                    if e.args:
                        wanted.add(str(e.args[0]))
                    offenders = [
                        a.index for a in frame.top.atoms if str(a) in wanted
                    ]
                    if not offenders:
                        # nothing to remove, so the error can not be fixed here
                        raise
                    without_offender = np.setdiff1d(
                        frame.top.select("all"), offenders[:1]
                    )
                    frame = frame.atom_slice(without_offender)
                    view.add_trajectory(frame)

    return view, dummy_traj
def rmsd_centroid_of_cluster(traj, parallel=True, atom_indices=None):
    """Computes the rmsd centroid of a trajectory.

    Args:
        traj (mdtraj.Trajectory): The trajectory.
        parallel (bool, optional): Passed on to ``md.rmsd``. Defaults to True.
        atom_indices (Union[None, list, np.ndarray], optional): The atoms that
            are handed to ``md.rmsd``. If None or empty, all atoms whose
            element symbol is not 'H' are used. Defaults to None.

    Returns:
        tuple: A tuple containing:

            index (int): The index of the centroid.
            distances (np.ndarray): The RMSD distance matrix with shape
                traj.n_frames x traj.n_frames
            centroid (mdtraj.Trajectory): The traj of the centroid.

    References:
        Uses the algorithm found on http://mdtraj.org/latest/examples/centroids.html

    """
    # Test for "not given" explicitly. A truthiness test would also discard
    # valid selections such as [0].
    if atom_indices is None or np.size(atom_indices) == 0:
        atom_indices = [a.index for a in traj.topology.atoms if a.element.symbol != "H"]

    n_frames = traj.n_frames
    distances = np.empty((n_frames, n_frames))
    for i in range(n_frames):
        distances[i] = md.rmsd(
            traj, traj, i, atom_indices=atom_indices, parallel=parallel
        )

    beta = 1
    spread = distances.std()
    if spread == 0:
        # all distances are identical (e.g. a single frame), so every frame
        # scores the same; avoid the division by zero and pick the first frame
        scores = np.zeros(n_frames)
    else:
        scores = np.exp(-beta * distances / spread).sum(axis=1)
    index = scores.argmax()
    centroid = traj[index]
    return index, distances, centroid