Coverage for encodermap/encodermap_tf1/plot.py: 13%

231 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# Standard Library Imports 

2import datetime 

3import os 

4import subprocess 

5from itertools import cycle 

6 

7# Third Party Imports 

8import matplotlib.pyplot as plt 

9import MDAnalysis as md 

10import numpy as np 

11from matplotlib.lines import Line2D 

12from matplotlib.path import Path 

13from matplotlib.widgets import Lasso 

14from scipy.interpolate import interp1d 

15 

16# Local Folder Imports 

17from .backmapping import dihedral_backmapping 

18from .misc import create_dir, periodic_distance_np, sigmoid 

19 

20 

class ManualPath(object):
    """
    ManualPath is a tool to manually select a path in a matplotlib graph.
    It supports two modes: "interpolated line", and "free draw".
    Press "m" to switch modes.

    In interpolated line mode click in the graph to add an additional way point.
    Press "delete" to remove the last way point.
    Press "d" to remove all way points.
    Press "enter" once you have finished your path selection.

    In free draw mode press and hold the left mouse button while you draw a path.

    Once the path selection is completed, the use_points method is called with the points on the selected path.
    You can overwrite the use_points method to do whatever you want with the points on the path.
    """

    def __init__(self, axe, n_points=200):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param n_points: Number of points distributed on the selected path.
        """
        self.axe = axe
        self.canvas = axe.figure.canvas
        self.fig = axe.figure
        self.lasso = None
        self.n_interpolation_points = n_points

        # Pixel buffer of the figure without the path lines; restored in
        # _update_lines before the lines are redrawn on top of it.
        self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        # Dashed line through the way points clicked by the user.
        control_line = Line2D(
            [], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87"
        )
        self.control_line = self.axe.add_line(control_line)
        self.x_control = list(self.control_line.get_xdata())
        self.y_control = list(self.control_line.get_ydata())

        # Smooth line interpolated through the way points.
        interpolated_line = Line2D([], [], c=self.control_line.get_markeredgecolor())
        self.interpolated_line = self.axe.add_line(interpolated_line)

        # "m" cycles through these click handlers; self.mode is the active one.
        self.modes = cycle([self._add_point_interp, self._free_draw])
        self.mode = next(self.modes)
        self.key_cid = self.canvas.mpl_connect("key_press_event", self._on_key)
        self.click_cid = self.canvas.mpl_connect("button_press_event", self._on_click)
        self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background)

    def _on_click(self, event):
        """Forward left clicks inside ``self.axe`` to the active mode handler."""
        # Ignore clicks outside our axes (this also covers event.inaxes being None).
        if event.inaxes is not self.axe:
            return
        # Another widget (e.g. an active lasso) currently owns the canvas.
        if self.canvas.widgetlock.locked():
            return
        if event.button == 1:
            self.mode(event)

    def _on_key(self, event):
        """Handle the keyboard shortcuts described in the class docstring."""
        if event.key == "m":
            self.mode = next(self.modes)
            return
        if event.key == "enter":
            points = np.array(self.interpolated_line.get_data()).T
            self._reset_lines()
            # No interpolated path yet (e.g. fewer than two way points):
            # there is nothing to hand to use_points.
            if len(points) == 0:
                return
            self.use_points(points)
            self._grab_background()
            return
        if event.key == "d":
            self._reset_lines()
            return
        if event.key == "delete":
            # Without this guard, "delete" on an empty path raised IndexError.
            if not self.x_control:
                return
            del self.x_control[-1]
            del self.y_control[-1]
            self._update_interp()

    def _free_draw(self, event):
        """Start a lasso at the click position; its vertices go to the callback."""
        self.lasso = Lasso(
            event.inaxes, (event.xdata, event.ydata), self._free_draw_callback
        )
        # acquire a lock on the widget drawing
        self.canvas.widgetlock(self.lasso)

    def _free_draw_callback(self, verts):
        """Pass the lasso vertices to use_points and release the widget lock."""
        try:
            self.use_points(np.array(verts))
        finally:
            # Always release the lock; otherwise an exception in use_points
            # leaves the canvas locked and every later click is ignored.
            self.canvas.draw_idle()
            self.canvas.widgetlock.release(self.lasso)
            self.lasso = None

    def _add_point_interp(self, event):
        """Append the clicked position as a new way point."""
        self.x_control.append(event.xdata)
        self.y_control.append(event.ydata)
        self._update_interp()

    def _update_interp(self):
        """Refresh the control line and the interpolated line from the way points."""
        self.control_line.set_data(self.x_control, self.y_control)
        x_i, y_i = self._interpolate(self.x_control, self.y_control)
        # second iteration makes points more evenly spaced
        x_i, y_i = self._interpolate(x_i, y_i)
        self.interpolated_line.set_data(x_i, y_i)
        self._update_lines()

    def _interpolate(self, x, y):
        """
        Resample the polyline (x, y) at ``self.n_interpolation_points`` positions
        equally spaced in cumulative straight-segment distance.

        Cubic interpolation is tried first. If scipy rejects it (e.g. too few
        points or repeated positions), linear interpolation is tried. If that
        also fails, two empty lists are returned.

        :param x: x coordinates of the polyline vertices
        :param y: y coordinates of the polyline vertices
        :return: tuple (x_i, y_i) of resampled coordinates
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        segment_lengths = np.hypot(np.diff(x), np.diff(y))
        cumulative_distances = np.concatenate(([0.0], np.cumsum(segment_lengths)))
        interp_i = np.linspace(
            0, cumulative_distances[-1], self.n_interpolation_points
        )
        try:
            x_i = interp1d(cumulative_distances, x, kind="cubic")(interp_i)
            y_i = interp1d(cumulative_distances, y, kind="cubic")(interp_i)
        except ValueError:
            try:
                x_i = interp1d(cumulative_distances, x, kind="linear")(interp_i)
                y_i = interp1d(cumulative_distances, y, kind="linear")(interp_i)
            except ValueError:
                x_i = []
                y_i = []
        return x_i, y_i

    def use_points(self, points):
        """
        Overwrite this method to use the selected points in any way you like.

        For Example:

        >>> class MyManualPath(ManualPath):
        >>>     def use_points(self, points):
        >>>         print(points)

        :param points: numpy array with points from the manual path selection
        :return: None
        """
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        raise NotImplementedError("has to be implemented in subclass")

    def _grab_background(self, event=None):
        """
        When the figure is resized, hide the points, draw everything,
        and update the background.
        """
        # Thanks to: https://stackoverflow.com/questions/29277080/efficient-matplotlib-redrawing
        # Disconnect temporarily so the draws below do not re-trigger this handler.
        self.canvas.mpl_disconnect(self.draw_cid)
        self.interpolated_line.set_visible(False)
        self.control_line.set_visible(False)
        self.canvas.draw()

        # With most backends (e.g. TkAgg), we could grab (and refresh, in
        # self.blit) self.ax.bbox instead of self.fig.bbox, but Qt4Agg, and
        # some others, requires us to update the _full_ canvas, instead.
        self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        self.interpolated_line.set_visible(True)
        self.control_line.set_visible(True)
        self.canvas.draw()
        self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background)

    def _update_lines(self):
        """
        Efficiently update the figure, without needing to redraw the
        "background" artists.
        """
        self.fig.canvas.restore_region(self.background)
        self.axe.draw_artist(self.interpolated_line)
        self.axe.draw_artist(self.control_line)
        # TODO: canvas.draw() works in jupyter notebook but canvas.update() outside of notebook
        self.canvas.draw()
        self.canvas.flush_events()

    def _reset_lines(self):
        """Clear both lines and forget all way points."""
        self.interpolated_line.set_data([[], []])
        self.control_line.set_data([[], []])
        self._update_lines()

        self.x_control = []
        self.y_control = []

201 

202 

class PathGenerateDihedrals(ManualPath):
    """
    This class inherits from :py:class:`encodermap.plot.ManualPath`.
    The points of a manually selected path are fed into the decoder part of a given autoencoder.
    The decoder output is interpreted as phi/psi backbone dihedral angles, which are used to
    reconstruct protein conformations based on the protein structure given with pdb_path.
    Three output files are written for each selected path:
    points.npy, generated.npy and generated.pdb, containing respectively
    the points on the selected path, the generated output of
    the autoencoder, and the generated protein conformations.
    Keep in mind that backbone dihedrals are not sufficient to describe a protein conformation completely.
    Usually the backbone is reconstructed well but all side chains are messed up.
    """

    def __init__(self, axe, autoencoder, pdb_path, save_path=None, n_points=200):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :py:class:`encodermap.autoencoder.Autoencoder` which was trained on protein dihedral
            angles. The dihedrals have to be ordered starting from the amino end.
            First all phi angles then all psi angles.
        :param pdb_path: Path to a protein data bank (pdb) file of the protein
        :param save_path: Path where outputs should be written. If falsy,
            ``autoencoder.p.main_path`` is used.
        :param n_points: Number of points distributed on the selected path.
        """
        super().__init__(axe, n_points=n_points)
        self.autoencoder = autoencoder
        self.pdb_path = pdb_path
        # Fall back to the autoencoder's main output directory.
        self.save_path = save_path if save_path else autoencoder.p.main_path

    def use_points(self, points):
        """
        Decode the path points and write points, decoder output and the
        back-mapped structures into a new time-stamped directory.

        :param points: numpy array with the points on the selected path
        """
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_dir = create_dir(
            os.path.join(self.save_path, "generated_paths", timestamp)
        )
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(output_dir, "points"), points)

        decoded = self.autoencoder.generate(points)
        np.save(os.path.join(output_dir, "generated.npy"), decoded)

        universe = dihedral_backmapping(self.pdb_path, decoded)
        pdb_file = os.path.join(output_dir, "generated.pdb")
        with md.Writer(pdb_file) as writer:
            # Iterating the trajectory advances the current frame; write each one.
            for _ in universe.trajectory:
                writer.write(universe.atoms)

256 

257 

class PathGenerateCartesians(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select paths in a 2d map and to generate conformations for these paths with a
    AngleDihedralCartesianEncoder.
    """

    def __init__(
        self,
        axe,
        autoencoder,
        mol_data,
        save_path=None,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :class:`.AngleDihedralCartesianEncoder`
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written. If falsy,
            ``autoencoder.p.main_path`` is used.
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the generated conformations will be directly opened in vmd.
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.

        """
        super().__init__(axe, n_points=n_points)

        self.autoencoder = autoencoder
        self.mol_data = mol_data
        self.vmd_path = vmd_path

        self.align_reference = align_reference
        self.align_select = align_select

        if save_path:
            self.save_path = save_path
        else:
            self.save_path = autoencoder.p.main_path

    def use_points(self, points):
        """
        Generate conformations for the path points and write them, together
        with the points and the generated dihedrals/cartesians, into a new
        time-stamped directory. Optionally open the result in VMD.

        :param points: numpy array with the points on the selected path
        """
        current_save_path = create_dir(
            os.path.join(
                self.save_path,
                "generated_paths",
                datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
            )
        )
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(current_save_path, "points"), points)

        # generate() returns either (dihedrals, cartesians) or
        # (angles, dihedrals, cartesians). Call it once and dispatch on the
        # number of outputs instead of re-running the decoder after a failed
        # unpack. Any other count raises ValueError as before.
        generated = tuple(self.autoencoder.generate(points))
        if len(generated) == 3:
            _angles, dihedrals, cartesians = generated
        else:
            dihedrals, cartesians = generated
        np.save(os.path.join(current_save_path, "generated_dihedrals.npy"), dihedrals)
        np.save(os.path.join(current_save_path, "generated_cartesians.npy"), cartesians)

        self.mol_data.write(
            current_save_path,
            cartesians,
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )
        if self.vmd_path:
            # NOTE(review): assumes mol_data.write creates generated.pdb and
            # generated.xtc in current_save_path - confirm against MolData.
            cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
            print(cmd)
            # shell=True keeps supporting vmd_path values that contain extra
            # arguments. stdout goes to DEVNULL: a PIPE that is never read can
            # fill up and block VMD.
            process = subprocess.Popen(
                cmd,
                cwd=current_save_path,
                shell=True,
                stdin=subprocess.PIPE,
                stdout=subprocess.DEVNULL,
            )
            # Remove the first frame (the topology frame from generated.pdb).
            process.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
            process.stdin.flush()

340 

341 

class PathSelect(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select areas in a 2d map and to write all conformations in these areas to separate trajectories.
    """

    def __init__(
        self,
        axe,
        projected,
        mol_data,
        save_path,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param projected: points in the map (must be the same number of points as conformations in mol_data).
            Any array-like of shape (n, 2) is accepted; it is converted to a numpy array.
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the generated conformations will be directly opened in vmd.
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.

        """
        super().__init__(axe, n_points=n_points)

        self.mol_data = mol_data
        self.vmd_path = vmd_path
        # use_points indexes with [indices, 0], which requires a numpy array.
        self.projected = np.asarray(projected)
        assert len(self.projected) == len(mol_data.dihedrals), (
            "projected has {} points but mol_data has {} conformations".format(
                len(self.projected), len(mol_data.dihedrals)
            )
        )

        self.align_reference = align_reference
        self.align_select = align_select

        self.save_path = save_path

    def use_points(self, points):
        """
        Write all conformations whose projected point lies inside the area
        outlined by ``points`` into a new time-stamped directory. Optionally
        open them in VMD.

        :param points: numpy array with the vertices of the selected outline
        """
        current_save_path = create_dir(
            os.path.join(
                self.save_path,
                "selected",
                datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
            )
        )
        indices = np.nonzero(Path(points).contains_points(self.projected))[0]
        self.axe.scatter(self.projected[indices, 0], self.projected[indices, 1])
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(current_save_path, "points"), self.projected[indices, :])
        np.save(os.path.join(current_save_path, "indices"), indices)

        self.mol_data.write(
            current_save_path,
            self.mol_data.cartesians[indices],
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )
        if self.vmd_path:
            # NOTE(review): assumes mol_data.write creates generated.pdb and
            # generated.xtc in current_save_path - confirm against MolData.
            cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
            print(cmd)
            # shell=True keeps supporting vmd_path values that contain extra
            # arguments. stdout goes to DEVNULL: a PIPE that is never read can
            # fill up and block VMD.
            process = subprocess.Popen(
                cmd,
                cwd=current_save_path,
                shell=True,
                stdin=subprocess.PIPE,
                stdout=subprocess.DEVNULL,
            )
            # Remove the first frame (the topology frame from generated.pdb).
            process.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
            process.stdin.flush()

418 

419 

def distance_histogram(
    data, periodicity, sigmoid_parameters, axes=None, low_d_max=5, bins="auto"
):
    """
    Plots the histogram of all pairwise distances in the data.
    It also shows the sigmoid function and its normalized derivative.

    :param data: each row should contain a point in a number_of_columns dimensional space.
        Rows may also be multi-dimensional (all trailing axes are treated as coordinates),
        and 1-D data is treated as one scalar coordinate per point.
    :param periodicity: Periodicity of the data. use float("inf") for non periodic data
    :param sigmoid_parameters: sequence of sigmoid parameters. The first three values
        (sigma, a, b) are used for the high-d sigmoid, the remaining values for the
        low-d sigmoid.
    :param axes: Array like structure with two matplotlib axe objects or None. If None a new figure is generated.
    :param low_d_max: upper limit for plotting the low_d sigmoid
    :param bins: number of bins for histogram (passed to ``Axes.hist``)
    :return: tuple (high-d axe, twin of the high-d axe holding the histogram, low-d axe)
    """
    # Pairwise periodic difference vectors; presumably shaped
    # (n, n, *point_shape) by broadcasting - assumes periodic_distance_np
    # preserves the broadcast shape.
    vecs = np.asarray(
        periodic_distance_np(
            np.expand_dims(data, axis=1), np.expand_dims(data, axis=0), periodicity
        )
    )
    # Euclidean norm over all coordinate axes at once. This equals nesting
    # np.linalg.norm over axis 2 until only two axes remain. With 1-D data
    # there are no coordinate axes, and the result is simply |vecs|.
    coordinate_axes = tuple(range(2, vecs.ndim))
    dists = np.sqrt(np.sum(np.square(vecs), axis=coordinate_axes)).reshape(-1)

    if axes is None:
        _, axes = plt.subplots(2)
    axe2 = axes[0].twinx()
    _, edges, _ = axe2.hist(dists, bins=bins, density=True, edgecolor="black")

    # High-d panel: sigmoid over the observed distance range plus its
    # derivative scaled to a maximum of 1.
    x = np.linspace(0, dists.max(), 1000)
    y = sigmoid(x, *sigmoid_parameters[:3])
    edges_sig = sigmoid(edges, *sigmoid_parameters[:3])
    dy = np.diff(y)
    dy_norm = dy / dy.max()
    axes[0].plot(x, y, color="C1", label="sigmoid")
    axes[0].plot(x[:-1], dy_norm, color="C2", label="diff sigmoid")

    axes[0].legend()
    axes[0].set_xlabel("distance")
    axes[0].set_ylim((0, 1))
    # Draw the sigmoid axes above the histogram twin, with a transparent
    # background so the histogram stays visible.
    axes[0].set_zorder(axe2.get_zorder() + 1)
    axes[0].patch.set_visible(False)
    axes[0].set_title("high-d")

    # Low-d panel: sigmoid from 0 to low_d_max.
    x = np.linspace(0, low_d_max, 1000)
    y = sigmoid(x, *sigmoid_parameters[3:])
    # For each histogram edge, pick the low-d distance whose low-d sigmoid
    # value is closest to the high-d sigmoid value at that edge.
    idx = np.argmin(
        np.abs(np.expand_dims(edges_sig, axis=1) - np.expand_dims(y, axis=0)), axis=1
    )
    edges_x = x[idx]

    axes[1].plot(x, y, color="C1", label="sigmoid")

    axes[1].legend()
    axes[1].set_xlabel("distance")
    axes[1].set_ylim((0, 1))
    axes[1].set_title("low-d")
    # Connect each high-d edge to its matched low-d distance with a line that
    # spans both axes. Edges mapped to the same x as the last edge are skipped.
    for edge, edge_x in zip(edges, edges_x):
        if edge_x != edges_x[-1]:
            axes[1].annotate(
                "",
                xy=(edge, 0),
                xytext=(edge_x, 0),
                xycoords=axes[0].transData,
                textcoords=axes[1].transData,
                arrowprops=dict(facecolor="black", arrowstyle="-", clip_on=False),
            )
    axes[0].figure.tight_layout()
    return axes[0], axe2, axes[1]