Coverage for encodermap/misc/clustering.py: 8%

86 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/clustering.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Functions for building clusters.""" 

23 

24 

25################################################################################ 

26# Imports 

27################################################################################ 

28 

29 

30# Future Imports at the top 

31from __future__ import annotations 

32 

33# Standard Library Imports 

34import copy 

35 

36# Third Party Imports 

37import numpy as np 

38from optional_imports import _optional_import 

39 

40 

41################################################################################ 

42# Optional Imports 

43################################################################################ 

44 

45 

46md = _optional_import("mdtraj") 

47ngl = _optional_import("nglview") 

48 

49 

50################################################################################ 

51# Typing 

52################################################################################ 

53 

54 

55# Standard Library Imports 

56from typing import TYPE_CHECKING, Optional, TypedDict 

57 

58# Third Party Imports 

59from typing_extensions import NotRequired 

60 

61 

# Everything in this branch exists for static type checkers only: the imports
# and the `ClusterDict` class are not available at runtime.
if TYPE_CHECKING:
    # Third Party Imports
    from mdtraj import Topology, Trajectory
    from pandas import DataFrame, Series

    # Local Folder Imports
    from ..trajinfo import SingleTraj, TrajEnsemble

    class ClusterDict(TypedDict):
        """The output of the `cluster_to_dict()` function."""

        # All frames of the cluster stacked into a single trajectory.
        stacked: Trajectory
        # Superposed trajectories accumulated per topology. Stays empty when
        # the trajectories of the cluster differ in their number of atoms.
        joined_per_top: dict[Topology, Trajectory]
        # The ensemble that was passed into `cluster_to_dict()`.
        ensemble: TrajEnsemble
        # The cluster-membership CV converted to pandas.
        series: Series
        # Only present when all trajectories have the same number of atoms.
        joined: NotRequired[Trajectory]

78 

79 

80################################################################################ 

81# Globals 

82################################################################################ 

83 

84 

# Deliberately empty: a star-import of this module exports nothing, so the
# functions defined below have to be imported by name.
__all__: list[str] = []

86 

87 

88################################################################################ 

89# Public Functions 

90################################################################################ 

91 

92 

def rmsd_centroid_of_cluster(
    traj: Trajectory,
    parallel: bool = True,
    atom_indices: Optional[np.ndarray] = None,
) -> tuple[int, np.ndarray, Trajectory]:
    """Computes the rmsd centroid of a trajectory.

    The centroid is the frame that maximizes the summed similarity
    ``exp(-beta * d / std(d))`` (with ``beta = 1``) to all other frames, where
    ``d`` is the pairwise RMSD matrix of the trajectory.

    Args:
        traj (mdtraj.Trajectory): The trajectory.
        parallel (bool): Use OpenMP to calculate each of the RMSDs in
            parallel over multiple cores.
        atom_indices (Optional[np.ndarray]): The indices of the atoms that
            are used for the RMSD calculation. If None or empty, all
            non-hydrogen atoms of the topology of `traj` are used.

    Returns:
        tuple: A tuple containing:
            - index (int): The index of the centroid.
            - distances (np.ndarray): The RMSD distance matrix with shape
                traj.n_frames x traj.n_frames
            - centroid (mdtraj.Trajectory): The traj of the centroid.

    References:
        Uses the algorithm found on http://mdtraj.org/latest/examples/centroids.html

    """
    # Only fall back to the heavy atoms when no selection was provided.
    # Testing the truthiness of the *values* (``np.any``) would discard a
    # valid selection that consists only of atom 0.
    if atom_indices is None or np.size(atom_indices) == 0:
        atom_indices = [a.index for a in traj.topology.atoms if a.element.symbol != "H"]

    n_frames = traj.n_frames
    distances = np.empty((n_frames, n_frames))
    for i in range(n_frames):
        # Row i holds the RMSD of every frame to frame i.
        distances[i] = md.rmsd(
            traj, traj, i, atom_indices=atom_indices, parallel=parallel
        )

    beta = 1
    spread = distances.std()
    if spread == 0:
        # A single frame or identical frames: every frame is an equally good
        # centroid. Pick the first one instead of dividing by zero.
        index = 0
    else:
        index = int(np.exp(-beta * distances / spread).sum(axis=1).argmax())
    centroid = traj[index]
    return index, distances, centroid

128 

129 

def _find_cluster_membership_column(ds) -> str:
    """Returns the name of the CV in `ds` that defines the cluster membership."""
    if len(ds) == 1:
        col = list(ds.keys())[0]
        cluster_id = ds[col].values
        cluster_id = cluster_id[~np.isnan(cluster_id)]
        cluster_id = np.unique(cluster_id.astype(int))
        assert (
            len(cluster_id) == 1
        ), f"The CV '{col}' has ambiguous cluster_ids: {cluster_id}."
        return col

    # With multiple CVs, look for one that is integer-valued and holds exactly
    # one distinct value. Remember the last integer-valued candidate to be
    # able to produce a helpful error message.
    candidate = None
    candidate_values = None
    for name, data_var in ds.items():
        values = data_var.values
        values = values[~np.isnan(values)]
        if not np.all(np.mod(values, 1) == 0):
            continue
        candidate, candidate_values = name, values
        # Uniqueness has to be checked on the values themselves, not on their
        # fractional parts (which are all zero for any integer-valued CV).
        if len(np.unique(values)) == 1:
            return name

    if candidate is None:
        raise Exception(
            f"Could not find a CV with integer values that defines a cluster "
            f"membership. Make sure to `trajs.load_CVs()` a numpy array with "
            f"cluster memberships."
        )
    raise Exception(
        f"The CV '{candidate}' has ambiguous cluster_ids: "
        f"{np.unique(candidate_values)}."
    )


def _join_frames_along_time(all_trajs, parent_traj):
    """Returns a copy of `parent_traj` holding the coordinates of all frames in `all_trajs`."""
    try:
        n_total = sum(t.n_frames for t in all_trajs)
        no_of_iters, rest = divmod(n_total, parent_traj.n_frames)
    except Exception as e:
        raise Exception(
            f"Can not build a dummy trajectory. Maybe you selected the "
            f"wrong cluster num. Here's the original Error: {e}"
        ) from e

    # Repeat the parent trajectory until it has exactly `n_total` frames.
    dummy_traj = copy.deepcopy(parent_traj)
    for _ in range(1, no_of_iters):
        dummy_traj = dummy_traj.join(copy.deepcopy(parent_traj))
    if no_of_iters == 0:
        # The parent has more frames than the cluster: trim it.
        dummy_traj = dummy_traj[:rest]
    elif rest > 0:
        dummy_traj = dummy_traj.join(copy.deepcopy(parent_traj)[:rest])

    # Set the xyz of the whole dummy trajectory at once. Assigning to
    # `dummy_traj[i].xyz` would only modify a temporary slice.
    coords = []
    for traj in all_trajs:
        for frame in traj:
            coords.append(np.reshape(frame.xyz, (-1, 3)))
    dummy_traj.xyz = np.stack(coords)
    return dummy_traj


def cluster_to_dict(
    trajs: TrajEnsemble,
    align_string: str = "name CA",
    ref_align_string: str = "name CA",
    base_traj: Optional[Trajectory] = None,
) -> ClusterDict:
    """Creates a dictionary with joined/stacked trajectory frames.

    Examples:
        >>> import encodermap as em
        >>> import numpy as np
        >>> trajs = em.TrajEnsemble(
        ...     [
        ...         em.SingleTraj.from_pdb_id("1YUG"),
        ...         em.SingleTraj.from_pdb_id("1YUF"),
        ...     ]
        ... )
        >>> cluster_membership = np.full(shape=(trajs.n_frames, ), fill_value=-1, dtype=int)
        >>> cluster_membership[::2] = 0
        >>> trajs.load_CVs(cluster_membership, "cluster_membership")
        >>> cluster = trajs.cluster(cluster_id=0)
        >>> cluster_dict = em.misc.clustering.cluster_to_dict(cluster)

        Only when all trajectories have the same number of atoms can
        we join them along the time axis. And only then does the
        `'joined'` key appear in the output.

        >>> trajs.top  # doctest: +ELLIPSIS
        [<mdtraj.Topology with 1 chains, 50 residues, 720 atoms, 734 bonds at...>]
        >>> cluster_dict.keys()
        dict_keys(['stacked', 'joined_per_top', 'ensemble', 'series', 'joined'])

        The `'stacked'` key contains all frames stacked along the atom axis.

        >>> cluster_dict["stacked"]  # doctest: +ELLIPSIS
        <mdtraj.Trajectory with 1 frames, 11520 atoms, 800 residues, without unitcells at...>

    Args:
        trajs (TrajEnsemble): A TrajEnsemble of a cluster. See the documentation
            of :func:`encodermap.TrajEnsemble.cluster` for more info.
        align_string (str): The align string of the parent traj.
        ref_align_string (str): The align string for the reference.
        base_traj (Trajectory, optional): A parent trajectory for when all
            trajs in `trajs` have the same number of atoms. In that case,
            the atomic coordinates can be applied to the base_traj.

    Returns:
        ClusterDict: A dictionary with joined/stacked MDTraj trajectories.

    Raises:
        AssertionError: If the only CV in `trajs` contains more than one
            cluster id, or if the selected CV does not convert to a single
            pandas column.
        Exception: If no integer-valued CV with a single cluster id can be
            found, or if the dummy trajectory for the `'joined'` key can not
            be built.
        ValueError: If `trajs` does not contain any frames.

    """
    ds = trajs._CVs
    col = _find_cluster_membership_column(ds)

    series: DataFrame = (
        ds.stack({"frame": ("traj_num", "frame_num")})
        .transpose("frame", ...)
        .dropna("frame", how="all")[col]
        .to_pandas()
    )
    assert len(keys := list(series.keys())) == 1
    series = series[keys[0]]

    # if frames have the same xyz, we can join them
    joined = None
    joined_per_top = {}
    all_trajs: list[Trajectory] = []
    if all(t.n_atoms == trajs[0].n_atoms for t in trajs):
        # superpose all onto the first frame seen for the respective topology
        for traj in trajs:
            known_top = traj.top in joined_per_top
            if known_top:
                ref = joined_per_top[traj.top].get_single_frame(0)
            else:
                ref = traj.get_single_frame(0)
            superposed = traj.superpose(
                reference=ref,
                frame=0,
                atom_indices=traj.top.select(align_string),
                ref_atom_indices=ref.top.select(ref_align_string),
            )
            if known_top:
                joined_per_top[traj.top] += superposed
            else:
                joined_per_top[traj.top] = superposed
            all_trajs.append(superposed)

        parent_traj = base_traj
        if parent_traj is None:
            parent_traj = all_trajs[0].traj

        joined = _join_frames_along_time(all_trajs, parent_traj)

    # stack
    stacked = None
    for traj in trajs:
        for frame in traj:
            if stacked is None:
                stacked = copy.deepcopy(frame.traj)
            else:
                stacked = stacked.stack(frame.traj)
    if stacked is None:
        raise ValueError("Can not stack the frames of an empty cluster.")

    out: ClusterDict = {
        "stacked": stacked,
        "joined_per_top": joined_per_top,
        "ensemble": trajs,
        "series": series,
    }
    if joined is not None:
        out["joined"] = joined

    return out