Coverage for encodermap/misc/clustering.py: 8%
86 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/misc/clustering.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Functions for building clusters."""
25################################################################################
26# Imports
27################################################################################
30# Future Imports at the top
31from __future__ import annotations
33# Standard Library Imports
34import copy
36# Third Party Imports
37import numpy as np
38from optional_imports import _optional_import
41################################################################################
42# Optional Imports
43################################################################################
46md = _optional_import("mdtraj")
47ngl = _optional_import("nglview")
50################################################################################
51# Typing
52################################################################################
55# Standard Library Imports
56from typing import TYPE_CHECKING, Optional, TypedDict
58# Third Party Imports
59from typing_extensions import NotRequired
if TYPE_CHECKING:
    # These names are only needed for annotations, so they are imported (and
    # `ClusterDict` is declared) exclusively for static type checkers.

    # Third Party Imports
    from mdtraj import Topology, Trajectory
    from pandas import DataFrame, Series

    # Local Folder Imports
    from ..trajinfo import SingleTraj, TrajEnsemble

    class ClusterDict(TypedDict):
        """The output of the `cluster_to_dict()` function."""

        # All frames of the cluster stacked along the atom axis.
        stacked: Trajectory
        # Superposed trajectories, keyed by the topology they belong to.
        joined_per_top: dict[Topology, Trajectory]
        # The `TrajEnsemble` that was passed into `cluster_to_dict()`.
        ensemble: TrajEnsemble
        # The cluster-membership CV as a pandas Series.
        series: Series
        # Only present when all trajectories have the same number of atoms.
        joined: NotRequired[Trajectory]
80################################################################################
81# Globals
82################################################################################
85__all__: list[str] = []
88################################################################################
89# Public Functions
90################################################################################
def rmsd_centroid_of_cluster(
    traj: Trajectory,
    parallel: bool = True,
    atom_indices: Optional[np.ndarray] = None,
) -> tuple[int, np.ndarray, Trajectory]:
    """Computes the rmsd centroid of a trajectory.

    Args:
        traj (mdtraj.Trajectory): The trajectory.
        parallel (bool): Use OpenMP to calculate each of the RMSDs in
            parallel over multiple cores.
        atom_indices (Optional[np.ndarray]): The indices of the atoms that are
            used for the RMSD calculation. If None (or empty), all
            non-hydrogen atoms of the topology are used.

    Returns:
        tuple: A tuple containing:
            - index (int): The index of the centroid.
            - distances (np.ndarray): The RMSD distance matrix with shape
                traj.n_frames x traj.n_frames
            - centroid (mdtraj.Trajectory): The traj of the centroid.

    References:
        Uses the algorithm found on http://mdtraj.org/latest/examples/centroids.html

    """
    # Test explicitly for "nothing provided". A truthiness test via `np.any`
    # would wrongly discard valid selections such as `[0]` (only atom 0).
    if atom_indices is None or np.size(atom_indices) == 0:
        atom_indices = [a.index for a in traj.topology.atoms if a.element.symbol != "H"]

    # Pairwise RMSD matrix: row i holds the RMSD of every frame to frame i.
    distances = np.empty((traj.n_frames, traj.n_frames))
    for i in range(traj.n_frames):
        distances[i] = md.rmsd(
            traj, traj, i, atom_indices=atom_indices, parallel=parallel
        )

    beta = 1
    spread = distances.std()
    if spread == 0:
        # All frames are equidistant (e.g. a single-frame trajectory). Dividing
        # by zero would only produce NaNs, for which argmax yields 0 anyway.
        index = 0
    else:
        # The centroid is the frame with the largest summed similarity
        # exp(-beta * d_ij / std(d)) to all other frames.
        index = int(np.exp(-beta * distances / spread).sum(axis=1).argmax())
    centroid = traj[index]
    return index, distances, centroid
def cluster_to_dict(
    trajs: TrajEnsemble,
    align_string: str = "name CA",
    ref_align_string: str = "name CA",
    base_traj: Optional[Trajectory] = None,
) -> ClusterDict:
    """Creates a dictionary with joined/stacked trajectory frames.

    Examples:
        >>> import encodermap as em
        >>> import numpy as np
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         em.SingleTraj.from_pdb_id("1YUG"),
        ...         em.SingleTraj.from_pdb_id("1YUF"),
        ...     ]
        ... )
        >>> cluster_membership = np.full(shape=(trajs.n_frames, ), fill_value=-1, dtype=int)
        >>> cluster_membership[::2] = 0
        >>> trajs.load_CVs(cluster_membership, "cluster_membership")
        >>> cluster = trajs.cluster(cluster_id=0)
        >>> cluster_dict = em.misc.clustering.cluster_to_dict(cluster)

        Only when all trajectories have the same number of atoms can we
        join them along the time axis. And only then does the
        `'joined'` key appear in the output.

        >>> trajs.top  # doctest: +ELLIPSIS
        [<mdtraj.Topology with 1 chains, 50 residues, 720 atoms, 734 bonds at...>]
        >>> cluster_dict.keys()
        dict_keys(['stacked', 'joined_per_top', 'ensemble', 'series', 'joined'])

        The `'stacked'` key contains all frames stacked along the atom axis.

        >>> cluster_dict["stacked"]  # doctest: +ELLIPSIS
        <mdtraj.Trajectory with 1 frames, 11520 atoms, 800 residues, without unitcells at...>

    Args:
        trajs (TrajEnsemble): A TrajEnsemble of a cluster. See the documentation
            of :func:`encodermap.TrajEnsemble.cluster` for more info.
        align_string (str): The align string of the parent traj.
        ref_align_string (str): The align string for the reference.
        base_traj (Trajectory, optional): A parent trajectory for when all
            trajs in `trajs` have the same number of atoms. In that case,
            the atomic coordinates can be applied to the base_traj.

    Returns:
        ClusterDict: A dictionary with joined/stacked MDTraj trajectories.

    Raises:
        AssertionError: If the selected CV does not map to exactly one
            cluster id / one column.
        Exception: If no CV with a unique integer cluster id can be found,
            if the dummy trajectory for the `'joined'` key can not be built,
            or if `trajs` does not contain any frames.

    """
    ds = trajs._CVs
    ambiguous_ids = None
    if len(ds) == 1:
        col = list(ds.keys())[0]
        cluster_id = ds[col].values
        cluster_id = cluster_id[~np.isnan(cluster_id)]
        cluster_id = np.unique(cluster_id.astype(int))
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if len(cluster_id) != 1:
            raise AssertionError(
                f"The CV '{col}' has ambiguous cluster_ids: {cluster_id}."
            )
    else:
        # Look for the first CV that is integer-valued AND carries exactly
        # one unique value, i.e. a valid cluster membership for one cluster.
        for name, data_var in ds.items():
            values = data_var.values
            values = values[~np.isnan(values)]
            if not np.all(np.mod(values, 1) == 0):
                continue
            # Uniqueness has to be checked on the values themselves, not on
            # their remainder modulo 1 (which is always 0 at this point).
            unique_ids = np.unique(values)
            col = name
            if len(unique_ids) == 1:
                break
            ambiguous_ids = unique_ids
        else:
            if ambiguous_ids is None:
                raise Exception(
                    f"Could not find a CV with integer values that defines a cluster "
                    f"membership. Make sure to `trajs.load_CVs()` a numpy array with "
                    f"cluster memberships."
                )
            else:
                raise Exception(
                    f"The CV '{col}' has ambiguous cluster_ids: {ambiguous_ids}."
                )

    series: DataFrame = (
        ds.stack({"frame": ("traj_num", "frame_num")})
        .transpose("frame", ...)
        .dropna("frame", how="all")[col]
        .to_pandas()
    )
    # `keys` must not be bound inside an `assert`: under `python -O` the
    # assert is stripped and the name would be undefined afterwards.
    keys = list(series.keys())
    if len(keys) != 1:
        raise AssertionError(
            f"Expected exactly one column for the CV '{col}', got {len(keys)}."
        )
    series = series[keys[0]]

    # if all trajs have the same number of atoms, we can join them
    joined = None
    joined_per_top = {}
    all_trajs: list[Trajectory] = []
    if all(t.n_atoms == trajs[0].n_atoms for t in trajs):
        # superpose all
        for traj in trajs:
            # Align onto the first frame already collected for this topology,
            # or onto the traj's own first frame if the topology is new.
            if traj.top in joined_per_top:
                ref = joined_per_top[traj.top].get_single_frame(0)
            else:
                ref = traj.get_single_frame(0)
            superposed = traj.superpose(
                reference=ref,
                frame=0,
                atom_indices=traj.top.select(align_string),
                ref_atom_indices=ref.top.select(ref_align_string),
            )
            if traj.top in joined_per_top:
                joined_per_top[traj.top] += superposed
            else:
                joined_per_top[traj.top] = superposed
            all_trajs.append(superposed)

        parent_traj = base_traj
        if parent_traj is None:
            parent_traj = all_trajs[0].traj

        # How many copies of the parent are needed to hold all cluster frames?
        n_total = sum(t.n_frames for t in all_trajs)
        try:
            n_copies, rest = divmod(n_total, parent_traj.n_frames)
        except Exception as e:
            raise Exception(
                f"Can not build a dummy trajectory. Maybe you selected the "
                f"wrong cluster num. Here's the original Error: {e}"
            ) from e
        if rest:
            n_copies += 1

        # Join full copies of the parent and trim to the exact frame count.
        # This also covers the cases where the cluster has fewer frames than
        # the parent, or an exact multiple of the parent's frames.
        dummy_traj = copy.deepcopy(parent_traj)
        for _ in range(n_copies - 1):
            dummy_traj = dummy_traj.join(copy.deepcopy(parent_traj))
        dummy_traj = dummy_traj[:n_total]

        # set the xyz: frame k of the dummy gets the coordinates of the k-th
        # cluster frame. Assigning to `dummy_traj[0].xyz` would only modify a
        # temporary slice copy and leave `dummy_traj` untouched.
        dummy_traj.xyz = np.concatenate(
            [frame.xyz for traj in all_trajs for frame in traj], axis=0
        )

        joined = dummy_traj

    # stack all frames along the atom axis
    stacked = None
    for traj in trajs:
        for frame in traj:
            if stacked is None:
                stacked = copy.deepcopy(frame.traj)
            else:
                stacked = stacked.stack(frame.traj)
    if stacked is None:
        raise Exception("Can not stack frames: `trajs` does not contain any frames.")

    out: ClusterDict = {
        "stacked": stacked,
        "joined_per_top": joined_per_top,
        "ensemble": trajs,
        "series": series,
    }
    if joined is not None:
        out["joined"] = joined

    return out