Coverage for encodermap/misc/clustering.py: 32%

163 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/misc/clustering.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Functions for building clusters.""" 

23 

24 

25################################################################################ 

26# Imports 

27################################################################################ 

28 

29 

30import copy 

31import warnings 

32 

33import matplotlib as mpl 

34import numpy as np 

35 

36from .._optional_imports import _optional_import 

37from .errors import BadError 

38 

39############################################################################## 

40# Optional Imports 

41############################################################################## 

42 

43 

44md = _optional_import("mdtraj") 

45ngl = _optional_import("nglview") 

46 

47 

48################################################################################ 

49# Globals 

50################################################################################ 

51 

52 

53__all__ = ["gen_dummy_traj", "get_cluster_frames"] 

54 

55 

56################################################################################ 

57# Public Functions 

58################################################################################ 

59 

60 

def _get_joined_trajs(
    trajs, cluster_no, shorten=True, max_frames=-1, col="cluster_membership"
):
    """Collect the single-frame trajectories that belong to one cluster.

    Args:
        trajs: Object providing ``CVs[col]`` (compared element-wise against
            `cluster_no`) and ``get_single_frame(index)``, whose result
            exposes a ``traj`` attribute.
        cluster_no (Union[int, np.integer, list, np.ndarray]): Either a cluster
            id, in which case the frame indices are looked up in
            ``trajs.CVs[col]``, or an explicit sequence of frame indices.
        shorten (bool, optional): If True, pick 10 evenly spaced entries from
            the frame indices (entries repeat when fewer than 10 are
            available). Defaults to True.
        max_frames (int, optional): Upper bound for the number of frames.
            The frames are picked evenly spaced. -1 disables the limit.
            Defaults to -1.
        col (str, optional): Key into ``trajs.CVs`` holding the cluster
            memberships. Defaults to 'cluster_membership'.

    Returns:
        tuple: A tuple containing:

            joined_trajs (list): One ``traj`` object per selected frame.
            where (np.ndarray): The frame indices that were used.

    Raises:
        IndexError: If ``trajs.get_single_frame`` can not resolve one of the
            frame indices, or if `shorten` is requested for an empty cluster.

    """
    # np.integer covers every numpy integer width, not only the platform default
    if isinstance(cluster_no, (int, np.integer)):
        where = np.where(trajs.CVs[col] == cluster_no)[0]
    else:
        # explicit frame indices; lists need to become arrays to allow the
        # fancy indexing below
        where = np.asarray(cluster_no)

    # stride the frame indices to make calculations faster; never request more
    # frames than available, otherwise frames would be duplicated
    if max_frames != -1:
        n_keep = min(max_frames, len(where))
        idx = np.round(np.linspace(0, len(where) - 1, n_keep)).astype(int)
        where = where[idx]

    # if shorten, reduce the frame indices to 10 evenly spaced entries
    if shorten:
        idx = np.round(np.linspace(0, len(where) - 1, 10)).astype(int)
        where = where[idx]

    joined_trajs = []
    for point in where:
        try:
            frame = trajs.get_single_frame(point)
        except IndexError as err:
            # keep the exception type, but tell the caller which index failed
            raise IndexError(
                f"Could not get the frame with index {point} from `trajs`."
            ) from err
        joined_trajs.append(frame.traj)
    return joined_trajs, where

90 

91 

def get_cluster_frames(
    trajs,
    cluster_no,
    align_string="name CA",
    nglview=False,
    stack_atoms=False,
    shorten=False,
    max_frames=-1,
    superpose=True,
    col="cluster_membership",
    subunit="",
    ball_and_stick=False,
    cmap="viridis",
):
    """Returns the frames of one cluster as a list of single-frame trajs.

    In contrast to `gen_dummy_traj`, the frames are not merged into one
    trajectory. They are returned as a list and (optionally) added as separate
    components to an nglview widget.

    Args:
        trajs: The clustered trajectories. Needs to provide what
            `_get_joined_trajs` uses and a ``common_str`` attribute.
        cluster_no (Union[int, np.integer]): The cluster to collect.
        align_string (str, optional): Atom selection string used to pick the
            atoms for the superposition. Defaults to 'name CA'.
        nglview (bool, optional): Whether to build an nglview widget.
            Defaults to False.
        stack_atoms (bool, optional): Has to be True. Only stacking is
            possible here. Defaults to False.
        shorten (bool, optional): Forwarded to `_get_joined_trajs`.
            Defaults to False.
        max_frames (int, optional): Forwarded to `_get_joined_trajs`.
            Defaults to -1.
        superpose (Union[bool, object], optional): If True, all frames are
            superposed onto a deep copy of the first frame. Any other truthy
            object is used as the reference itself. Defaults to True.
        col (str, optional): Forwarded to `_get_joined_trajs`.
            Defaults to 'cluster_membership'.
        subunit (str, optional): Not implemented, has to be empty.
            Defaults to ''.
        ball_and_stick (bool, optional): Whether to render the backbone with
            ``add_hyperball`` instead of ``add_ribbon``. Defaults to False.
        cmap (Union[str, object, None], optional): Name of a matplotlib
            colormap or a colormap object used to color the components.
            Defaults to 'viridis'.

    Returns:
        tuple: A tuple containing:

            view: The nglview widget if nglview == True and the cluster
                contains frames, None otherwise.
            joined_trajs (list): The (superposed) frames of the cluster.

    Raises:
        NotImplementedError: If `cluster_no` is not an integer, or if `subunit`
            is provided.
        Exception: If `stack_atoms` is False.
        BadError: If the superposition fails with an AttributeError.

    """
    # validate everything up front, so unsupported calls don't load and
    # superpose all frames before they fail
    if not isinstance(cluster_no, (int, np.integer)):
        raise NotImplementedError()
    if subunit:
        raise NotImplementedError()
    # only stacking possible here
    if not stack_atoms:
        raise Exception(
            "Cannot build a time resolved traj from topologies with differing atom count."
        )

    joined_trajs, where = _get_joined_trajs(trajs, cluster_no, shorten, max_frames, col)

    # preset nglview and only compute it if nglview = True
    view = None

    # if the trajs differ in atom count, restrict the alignment to the
    # residues that every traj has
    if len({t.n_atoms for t in joined_trajs}) > 1:
        smallest_n_residues = min(t.n_residues for t in joined_trajs)
        selection = f"{align_string} and resid < {smallest_n_residues}"
    else:
        selection = f"{align_string}"
    aligns = [t.top.select(selection) for t in joined_trajs]

    # if superpose superpose the trajs (nothing to do for an empty cluster)
    if superpose and joined_trajs:
        if isinstance(superpose, bool):
            ref_frame = copy.deepcopy(joined_trajs[0])
        else:
            ref_frame = superpose

        try:
            for i, traj in enumerate(joined_trajs):
                joined_trajs[i] = traj.superpose(
                    ref_frame,
                    atom_indices=aligns[i],
                    ref_atom_indices=aligns[0],
                )
        except AttributeError as e:
            raise BadError(
                f"You provided some wrong datatype or a misformatted string into the argument align_string. Here's the original error: {e}"
            ) from e

    if nglview:
        # `mpl.cm.get_cmap` is gone in recent matplotlib releases, prefer the
        # colormap registry and fall back for old installations
        if cmap is None:
            cmap = mpl.rcParams["image.cmap"]
        if isinstance(cmap, str):
            try:
                colormap = mpl.colormaps[cmap]
            except AttributeError:
                colormap = mpl.cm.get_cmap(cmap)
        else:
            colormap = cmap

        # one color per common_str if available, else one per frame
        if trajs.common_str:
            n_colors = len(trajs.common_str)
        else:
            n_colors = len(joined_trajs)
        colors = colormap(np.linspace(0, 1, n_colors))

        # component indices in the view follow the order of the frames
        for i, (frame, w) in enumerate(zip(joined_trajs, where)):
            if trajs.common_str:
                c = colors[trajs.common_str.index(trajs.get_single_frame(w).common_str)]
            else:
                c = colors[i]
            c = mpl.colors.rgb2hex(c)
            if i == 0:
                view = ngl.show_mdtraj(frame)
            else:
                view.add_component(frame)
            view.clear_representations(component=i)
            if ball_and_stick:
                view.add_hyperball(selection="backbone", component=i, color=c)
            else:
                view.add_ribbon(selection="backbone", component=i, color=c)

    return view, joined_trajs

179 

180 

def gen_dummy_traj(
    trajs,
    cluster_no,
    align_string="name CA",
    nglview=False,
    stack_atoms=False,
    shorten=False,
    max_frames=-1,
    superpose=True,
    col="cluster_membership",
    subunit="",
    ref_align_string="name CA",
    base_traj=None,
):
    """Makes a dummy traj from an encodermap trajectory which contains
    trajectories with different topology.

    This function takes an encodermap.TrajEnsemble object and returns mdtraj
    trajectories for clustered data. This function can concatenate trajs even
    if the topology of trajectories in the TrajEnsemble class is different. The
    topology of this dummy traj will be wrong, but the atomic positions are
    correct.

    This function constructs a traj of length cluster_membership.count(cluster_no)
    with the topology of the first frame of this cluster
    (trajs.get_single_frame(cluster_membership.index(cluster_no))) and changes the
    atomic coordinates of this traj based on the other frames in this cluster.

    Note:
        If you have loaded the encodermap functions with the 'no_load'
        backend a second call to this function with the same parameters will
        be faster, because the trajectory frames have been loaded to memory.

    Args:
        trajs (encodermap.TrajEnsemble): The trajs which were clustered.
        cluster_no (Union[int, np.integer, np.ndarray, list]): The cluster_no of the cluster to make the dummy traj from.
            Can be:
            * int or numpy integer: The cluster will be found by using the trajs own cluster_membership in the trajs pandas dataframe.
            * np.array or list: If list or np.array is provided multiple clusters are returned. This requires
                `stack_atoms` = True and an empty `subunit`.
        align_string (str, optional): Use this mdtraj atom selection string to align the frames
            of the dummy traj. Defaults to 'name CA'.
        nglview (bool, optional): Whether to return a tuple of an nglview.view object and the traj
            or not. Defaults to False.
        stack_atoms (bool, optional): Whether to stack all frames into a single frame with
            multiple structures in it. This option is useful, if you want to
            generate a picture of interpenetrating structures. Defaults to False.
        shorten (bool, optional): Whether to return all structures or just a subset of
            roughly ten structures. Defaults to False.
        max_frames (int, optional): Only return so many frames. If set to -1 all frames will
            be returned. Defaults to -1.
        superpose (Union(bool, mdtraj.Trajectory), optional): Whether the frames of the returned traj should be superposed
            to frame 0 of the traj. If an mdtraj Trajectory is provided this trajectory is used to superpose. Defaults to True.
        subunit (str, optional): When you want to only visualize an ensemble of certain parts of your protein but keep some
            part stationary (see `align_string`), you can provide a mdtraj selection string. This part of the
            protein will only be rendered from the first frame. The other parts will be rendered as an ensemble
            of structures (either along atom (`stack_atoms` = True) or time (`stack_atoms` = False)). Defaults to ''.
        ref_align_string (str, optional): When the type of `superpose` is mdtraj.Trajectory with a different topology
            than `trajs`, you can give a different align string into this argument. Defaults to 'name CA'.
        base_traj (Union[None, mdtraj.Trajectory], optional): An mdtraj.Trajectory that will be set to the coordinates from
            trajs, instead of trajs[0]. Normally, the first traj in `trajs` (trajs[0]) will be used as a base traj.
            It will be extended into the time-direction until it has the desired number of frames (shorten=True; 10,
            max_frames=N, N; etc.). If you don't want to use this traj but something else, you can feed this option
            an mdtraj.Trajectory object. Defaults to None.

    Returns:
        tuple: A tuple containing:

            view (nglview.view): The nglview.view object if nglview == True,
                is None otherwise.
            dummy_traj (mdtraj.Trajectory): The mdtraj trajectory with wrong
                topology but correct atomic positions.

        If `cluster_no` is a list or np.ndarray, both entries of the tuple are
        lists with one item per cluster. In that case, the stacked structures
        of all clusters are additionally added to the first view (if nglview == True).

    Raises:
        TypeError: If `cluster_no` is neither an integer nor a list or np.ndarray.
        NotImplementedError: If multiple clusters are requested without
            `stack_atoms` or together with `subunit`.

    See also:
        See the render_vmd function in this document
        to render an image of the returned traj.

    """
    # everything except trajs and the cluster number is passed on unchanged
    shared_args = (
        align_string,
        nglview,
        stack_atoms,
        shorten,
        max_frames,
        superpose,
        col,
        subunit,
        ref_align_string,
        base_traj,
    )

    # np.integer covers every numpy integer width, not only the platform default
    if isinstance(cluster_no, (int, np.integer)):
        return _gen_dummy_traj_single(trajs, cluster_no, *shared_args)

    if not isinstance(cluster_no, (list, np.ndarray)):
        raise TypeError(
            f"`cluster_no` must be int or list. You supplied {type(cluster_no)}."
        )

    # reject unsupported combinations before the expensive per-cluster work
    if not stack_atoms:
        raise NotImplementedError(
            "Joining along time axes with multiple cluster's doesn't seem to make sense. Make a proposal how to handle this!"
        )
    if subunit:
        raise NotImplementedError("Not yet Implemented")

    dummy_trajs, views = [], []
    for i in cluster_no:
        v, dt = _gen_dummy_traj_single(trajs, i, *shared_args)
        dummy_trajs.append(dt)
        views.append(v)

    # nothing to stack for an empty sequence of clusters
    if not dummy_trajs:
        return views, dummy_trajs

    dummy_traj = dummy_trajs[0]
    for frame in dummy_trajs[1:]:
        dummy_traj = dummy_traj.stack(frame)

    # the views are None if nglview is False, so there's nothing to add to
    view = views[0]
    if view is not None:
        view.add_trajectory(dummy_traj)
    return views, dummy_trajs

310 

311 

def _drop_subunit_atoms(frame, subunit):
    """Returns `frame` without the atoms matched by the selection `subunit`."""
    sel_all = frame.top.select("all")
    sel_subunit = frame.top.select(subunit)
    return frame.atom_slice(np.setdiff1d(sel_all, sel_subunit))


def _gen_dummy_traj_single(
    trajs,
    cluster_no,
    align_string="name CA",
    nglview=False,
    stack_atoms=False,
    shorten=False,
    max_frames=-1,
    superpose=True,
    col="cluster_membership",
    subunit="",
    ref_align_string="name CA",
    base_traj=None,
):
    """Called when only one cluster is needed.

    The arguments are the same as in `gen_dummy_traj`, with `cluster_no` being
    restricted to what `_get_joined_trajs` accepts.

    Returns:
        tuple: A tuple containing:

            view: The nglview widget if nglview == True, None otherwise.
            dummy_traj: The trajectory carrying the coordinates of the
                cluster's frames (stacked along the atoms if
                `stack_atoms` = True).

    Raises:
        Exception: If no dummy trajectory with the needed number of frames can
            be built (e.g. empty cluster or empty parent trajectory).
        BadError: If the superposition fails with an AttributeError.

    """
    joined_trajs, where = _get_joined_trajs(trajs, cluster_no, shorten, max_frames, col)

    # preset nglview and only compute it if nglview = True
    view = None

    # use traj[0] of the trajs list as the traj from which the topology will be used
    # or use base_traj, if provided
    if base_traj is None:
        first_index = trajs[0].index
        if isinstance(first_index, slice):
            parent_traj = md.load(
                trajs.locations[0], top=trajs.top_files[0], stride=first_index.step
            )[: len(joined_trajs)]
        else:
            parent_traj = md.load(trajs.locations[0], top=trajs.top_files[0])[
                : len(joined_trajs)
            ]
    else:
        parent_traj = base_traj

    # print some info
    if align_string:
        sel = parent_traj.top.select(align_string)
        if len(sel) > 0:
            print(
                f"Provided alignment string results in {len(sel)} atoms. First atom is {parent_traj.top.atom(sel[0])}. Last atom is {parent_traj.top.atom(sel[-1])}."
            )
        else:
            # an empty selection has no first/last atom to report
            print("Provided alignment string results in 0 atoms.")

    # join the correct number of trajs
    # the parent traj is repeated `no_of_iters` times and a slice with `rest`
    # frames is appended, so that the result has exactly len(where) frames
    try:
        no_of_iters, rest = divmod(len(where), parent_traj.n_frames)
    except Exception as e:
        raise Exception(
            f"Can not buid a dummy trajectory. Maybe you selected the wronmg cluster num. Here's the original Error: {e}"
        ) from e
    pieces = [parent_traj] * no_of_iters
    if rest:
        # also covers a parent traj with more frames than needed (no_of_iters == 0)
        pieces.append(parent_traj[:rest])
    if not pieces:
        raise Exception(
            "Can not buid a dummy trajectory. Maybe you selected the wronmg cluster num. Here's the original Error: The cluster contains no frames."
        )
    # deepcopy, so that the coordinates of parent_traj/base_traj stay untouched
    dummy_traj = copy.deepcopy(pieces[0])
    for piece in pieces[1:]:
        dummy_traj = dummy_traj.join(piece)

    # some checks
    assert len(where) == dummy_traj.n_frames
    assert len(where) == len(joined_trajs)

    # change the xyz coordinates of dummy_traj according to the frames in joined trajs
    for i, traj in enumerate(joined_trajs):
        dummy_traj.xyz[i] = traj.xyz

    # if superpose superpose the trajs
    if superpose:
        if isinstance(superpose, bool):
            ref_frame = copy.deepcopy(dummy_traj[0])
        else:
            ref_frame = superpose
        try:
            _ = dummy_traj.superpose(
                ref_frame,
                atom_indices=dummy_traj.top.select(align_string),
                ref_atom_indices=ref_frame.top.select(ref_align_string),
            )
        except AttributeError as e:
            raise BadError(
                f"You provided some wrong datatype or a misformatted string into the argument align_string. Here's the original error: {e}"
            ) from e

    # keep the time resolved frames, the subunit view below needs them
    # regardless of stack_atoms
    time_resolved = dummy_traj

    # if stack_atoms is true overwrite dummy_traj
    if stack_atoms:
        for i, frame in enumerate(time_resolved):
            if i == 0:
                dummy_traj = copy.deepcopy(frame)
            else:
                if subunit:
                    frame = _drop_subunit_atoms(frame, subunit)
                dummy_traj = dummy_traj.stack(frame)

    # make nglview
    if nglview and not subunit:
        view = ngl.show_mdtraj(dummy_traj)
    elif nglview:
        # silence warnings only while the view is built and restore the
        # caller's warning filters afterwards
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            view = ngl.show_mdtraj(time_resolved[0])
            view.clear_representations()
            view.add_representation("cartoon", color="red")
            for frame in time_resolved[1:]:
                frame = _drop_subunit_atoms(frame, subunit)
                try:
                    _ = view.add_trajectory(frame)
                except KeyError as e:
                    # for some reason it doesn't want some atoms. Let's remove them.
                    # str(KeyError) is the repr of the key (with quotes), so
                    # the raw key is compared as well
                    offender_names = {str(e)}
                    if e.args:
                        offender_names.add(str(e.args[0]))
                    offender_index = next(
                        (
                            a.index
                            for a in frame.top.atoms
                            if str(a) in offender_names
                        ),
                        None,
                    )
                    if offender_index is None:
                        # the KeyError is not about an atom of this frame
                        raise
                    without_offender = np.setdiff1d(
                        frame.top.select("all"), [offender_index]
                    )
                    frame = frame.atom_slice(without_offender)
                    _ = view.add_trajectory(frame)

    # return
    return view, dummy_traj

441 

442 

def rmsd_centroid_of_cluster(traj, parallel=True, atom_indices=None):
    """Computes the rmsd centroid of a trajectory.

    Args:
        traj (mdtraj.Trajectory): The trajectory.
        parallel (bool, optional): Passed on to ``md.rmsd``. Defaults to True.
        atom_indices (Union[None, list, np.ndarray], optional): The indices of
            the atoms used for the RMSD calculation. If None or empty, all
            atoms whose element symbol is not 'H' are used. Defaults to None.

    Returns:
        tuple: A tuple containing:

            index (int): The index of the centroid.
            distances (np.ndarray): The RMSD distance matrix with shape
                traj.n_frames x traj.n_frames
            centroid (mdtraj.Trajectory): The traj of the centroid.

    References:
        Uses the algorithm found on http://mdtraj.org/latest/examples/centroids.html

    """
    # test for "not provided" explicitly, a truthiness test would also
    # discard a valid selection like [0]
    if atom_indices is None or len(atom_indices) == 0:
        atom_indices = [a.index for a in traj.topology.atoms if a.element.symbol != "H"]

    distances = np.empty((traj.n_frames, traj.n_frames))
    for i in range(traj.n_frames):
        distances[i] = md.rmsd(
            traj, traj, i, atom_indices=atom_indices, parallel=parallel
        )

    beta = 1
    spread = distances.std()
    if spread == 0:
        # all frames are identical (or there is only one), every frame is
        # a centroid; avoids a division by zero
        index = 0
    else:
        # the centroid is the frame with the highest summed similarity
        index = np.exp(-beta * distances / spread).sum(axis=1).argmax()
    centroid = traj[index]
    return index, distances, centroid