Coverage for encodermap/encodermap_tf1/plot.py: 14%

230 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1import datetime 

2import os 

3import subprocess 

4from itertools import cycle 

5 

6import matplotlib.pyplot as plt 

7import MDAnalysis as md 

8import numpy as np 

9from matplotlib.lines import Line2D 

10from matplotlib.path import Path 

11from matplotlib.widgets import Lasso 

12from scipy.interpolate import interp1d 

13 

14from .backmapping import dihedral_backmapping 

15from .misc import create_dir, periodic_distance_np, sigmoid 

16 

17 

18class ManualPath(object): 

19 """ 

20 ManualPath is a tool to manually select a path in a matplotlib graph. 

21 It supports two modes: "interpolated line", and "free draw". 

22 Press "m" to switch modes. 

23 

24 In interpolated line mode click in the graph to add an additional way point. 

25 Press "delete" to remove the last way point. 

26 Press "d" to remove all way points. 

27 Press "enter" once you have finished your path selection. 

28 

29 In free draw mode press and hold the left mouse button while you draw a path. 

30 

31 Once the path selection is completed, the use_points method is called with the points on the selected path. 

32 You can overwrite the use_points method to do what ever you want with the points on the path. 

33 """ 

34 

35 def __init__(self, axe, n_points=200): 

36 """ 

37 

38 :param axe: matplotlib axe object for example from: fig, axe = plt.subplots() 

39 :param n_points: Number of points distributed on the selected path. 

40 """ 

41 self.axe = axe 

42 self.canvas = axe.figure.canvas 

43 self.fig = axe.figure 

44 self.lasso = None 

45 self.n_interpolation_points = n_points 

46 

47 self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox) 

48 

49 control_line = Line2D( 

50 [], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87" 

51 ) 

52 self.control_line = self.axe.add_line(control_line) 

53 self.x_control = list(self.control_line.get_xdata()) 

54 self.y_control = list(self.control_line.get_ydata()) 

55 

56 interpolated_line = Line2D([], [], c=self.control_line.get_markeredgecolor()) 

57 self.interpolated_line = self.axe.add_line(interpolated_line) 

58 

59 self.modes = cycle([self._add_point_interp, self._free_draw]) 

60 self.mode = next(self.modes) 

61 self.key_cid = self.canvas.mpl_connect("key_press_event", self._on_key) 

62 self.click_cid = self.canvas.mpl_connect("button_press_event", self._on_click) 

63 self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background) 

64 

65 def _on_click(self, event): 

66 # Ignore clicks outside axes 

67 if event.inaxes != self.axe: 

68 return 

69 if self.canvas.widgetlock.locked(): 

70 return 

71 if event.inaxes is None: 

72 return 

73 if event.button == 1: 

74 self.mode(event) 

75 

76 def _on_key(self, event): 

77 if event.key == "m": 

78 self.mode = next(self.modes) 

79 return 

80 if event.key == "enter": 

81 points = np.array(self.interpolated_line.get_data()).T 

82 self._reset_lines() 

83 self.use_points(points) 

84 self._grab_background() 

85 

86 return 

87 if event.key == "d": 

88 self._reset_lines() 

89 return 

90 if event.key == "delete": 

91 del self.x_control[-1] 

92 del self.y_control[-1] 

93 self._update_interp() 

94 

95 def _free_draw(self, event): 

96 self.lasso = Lasso( 

97 event.inaxes, (event.xdata, event.ydata), self._free_draw_callback 

98 ) 

99 # acquire a lock on the widget drawing 

100 self.canvas.widgetlock(self.lasso) 

101 

102 def _free_draw_callback(self, verts): 

103 points = np.array(verts) 

104 self.use_points(points) 

105 self.canvas.draw_idle() 

106 self.canvas.widgetlock.release(self.lasso) 

107 del self.lasso 

108 

109 def _add_point_interp(self, event): 

110 self.x_control.append(event.xdata) 

111 self.y_control.append(event.ydata) 

112 self._update_interp() 

113 

114 def _update_interp(self): 

115 self.control_line.set_data(self.x_control, self.y_control) 

116 x_i, y_i = self._interpolate(self.x_control, self.y_control) 

117 x_i, y_i = self._interpolate( 

118 x_i, y_i 

119 ) # second iteration makes points more evenly spaced 

120 self.interpolated_line.set_data(x_i, y_i) 

121 self._update_lines() 

122 

123 def _interpolate(self, x, y): 

124 cumulative_distances = [0] 

125 for i in range(1, len(x)): 

126 dist = ((x[i] - x[i - 1]) ** 2 + (y[i] - y[i - 1]) ** 2) ** 0.5 

127 cumulative_distances.append(cumulative_distances[-1] + dist) 

128 interp_i = np.linspace( 

129 0, max(cumulative_distances), self.n_interpolation_points 

130 ) 

131 try: 

132 x_i = interp1d(cumulative_distances, x, kind="cubic")(interp_i) 

133 y_i = interp1d(cumulative_distances, y, kind="cubic")(interp_i) 

134 except ValueError: 

135 try: 

136 x_i = interp1d(cumulative_distances, x, kind="linear")(interp_i) 

137 y_i = interp1d(cumulative_distances, y, kind="linear")(interp_i) 

138 except ValueError: 

139 x_i = [] 

140 y_i = [] 

141 return x_i, y_i 

142 

143 def use_points(self, points): 

144 """ 

145 Overwrite this method to use the selected points in any way you like. 

146 

147 For Example: 

148 

149 >>> class MyManualPath(ManualPath): 

150 >>> def use_points(self, points): 

151 >>> print(points) 

152 

153 :param points: numpy array with points from the manual path selection 

154 :return: None 

155 """ 

156 self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".") 

157 raise NotImplementedError("has to be implemented in subclass") 

158 

159 def _grab_background(self, event=None): 

160 """ 

161 When the figure is resized, hide the points, draw everything, 

162 and update the background. 

163 """ 

164 # Thanks to: https://stackoverflow.com/questions/29277080/efficient-matplotlib-redrawing 

165 self.canvas.mpl_disconnect(self.draw_cid) 

166 self.interpolated_line.set_visible(False) 

167 self.control_line.set_visible(False) 

168 self.canvas.draw() 

169 

170 # With most backends (e.g. TkAgg), we could grab (and refresh, in 

171 # self.blit) self.ax.bbox instead of self.fig.bbox, but Qt4Agg, and 

172 # some others, requires us to update the _full_ canvas, instead. 

173 self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox) 

174 

175 self.interpolated_line.set_visible(True) 

176 self.control_line.set_visible(True) 

177 self.canvas.draw() 

178 self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background) 

179 

180 def _update_lines(self): 

181 """ 

182 Efficiently update the figure, without needing to redraw the 

183 "background" artists. 

184 """ 

185 self.fig.canvas.restore_region(self.background) 

186 self.axe.draw_artist(self.interpolated_line) 

187 self.axe.draw_artist(self.control_line) 

188 self.canvas.draw() # Todo: canvas.draw() works in jupyter notebook but canvas.update() outside of notebook 

189 self.canvas.flush_events() 

190 

191 def _reset_lines(self): 

192 self.interpolated_line.set_data([[], []]) 

193 self.control_line.set_data([[], []]) 

194 self._update_lines() 

195 

196 self.x_control = [] 

197 self.y_control = [] 

198 

199 

class PathGenerateDihedrals(ManualPath):
    """
    Generate protein conformations along a manually selected path.

    Subclass of :py:class:`encodermap.plot.ManualPath`. Every point of a path
    selected in the map is passed through the decoder of the given
    autoencoder. The decoder output is interpreted as phi/psi backbone
    dihedral angles, which are backmapped onto the structure in ``pdb_path``.

    For each selected path a new time-stamped directory is created containing
    points.npy (the points on the path), generated.npy (the decoder output)
    and generated.pdb (the reconstructed conformations).

    Note that backbone dihedrals alone do not fully determine a conformation:
    the backbone is usually reconstructed well, the side chains usually not.
    """

    def __init__(self, axe, autoencoder, pdb_path, save_path=None, n_points=200):
        """
        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :py:class:`encodermap.autoencoder.Autoencoder` trained on
            protein dihedral angles, ordered from the amino end: all phi angles
            first, then all psi angles.
        :param pdb_path: Path to a protein data bank (pdb) file of the protein
        :param save_path: Directory for the outputs. If falsy, the autoencoder's
            ``p.main_path`` is used.
        :param n_points: Number of points distributed on the selected path.
        """
        super().__init__(axe, n_points=n_points)

        self.autoencoder = autoencoder
        self.pdb_path = pdb_path
        self.save_path = save_path if save_path else autoencoder.p.main_path

    def use_points(self, points):
        """Decode ``points`` and write the points, dihedrals and PDB to disk."""
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        out_dir = create_dir(
            os.path.join(self.save_path, "generated_paths", timestamp)
        )
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(out_dir, "points"), points)

        dihedrals = self.autoencoder.generate(points)
        np.save(os.path.join(out_dir, "generated.npy"), dihedrals)

        universe = dihedral_backmapping(self.pdb_path, dihedrals)
        pdb_file = os.path.join(out_dir, "generated.pdb")
        with md.Writer(pdb_file) as writer:
            # Iterating the trajectory advances the frame; write each one.
            for _ in universe.trajectory:
                writer.write(universe.atoms)

253 

254 

class PathGenerateCartesians(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select paths in a 2d map and to generate conformations for these paths with a
    AngleDihedralCartesianEncoder.
    """

    def __init__(
        self,
        axe,
        autoencoder,
        mol_data,
        save_path=None,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """
        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :class:`.AngleDihedralCartesianEncoder`
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written. If falsy, the
            autoencoder's ``p.main_path`` is used.
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the generated conformations will be directly opened in vmd.
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.
        """
        super().__init__(axe, n_points=n_points)

        self.autoencoder = autoencoder
        self.mol_data = mol_data
        self.vmd_path = vmd_path

        self.align_reference = align_reference
        self.align_select = align_select

        if save_path:
            self.save_path = save_path
        else:
            self.save_path = autoencoder.p.main_path

    def use_points(self, points):
        """
        Generate conformations for ``points`` and write them to a new
        time-stamped directory below ``<save_path>/generated_paths``.

        :param points: array of points on the selected path; column 0 is
            plotted as x and column 1 as y.
        """
        current_save_path = create_dir(
            os.path.join(
                self.save_path,
                "generated_paths",
                datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
            )
        )
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(current_save_path, "points"), points)

        # Run the decoder only once. Depending on the model, generate returns
        # either (dihedrals, cartesians) or (angles, dihedrals, cartesians);
        # the previous try/except re-ran the whole decoder for the 3-tuple case.
        generated = tuple(self.autoencoder.generate(points))
        if len(generated) == 3:
            _angles, dihedrals, cartesians = generated
        else:
            dihedrals, cartesians = generated
        np.save(os.path.join(current_save_path, "generated_dihedrals.npy"), dihedrals)
        np.save(os.path.join(current_save_path, "generated_cartesians.npy"), cartesians)

        self.mol_data.write(
            current_save_path,
            cartesians,
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )
        if self.vmd_path:
            self._open_in_vmd(current_save_path)

    def _open_in_vmd(self, directory):
        """Launch VMD on generated.pdb/generated.xtc inside ``directory``."""
        # NOTE: vmd_path is interpolated into a shell command unquoted (so it
        # may carry extra VMD flags); paths containing spaces must be quoted
        # by the caller.
        cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
        print(cmd)
        process = subprocess.Popen(
            cmd,
            cwd=directory,
            shell=True,
            stdin=subprocess.PIPE,
            # Output is discarded rather than piped: nobody reads it, and an
            # unread PIPE can block VMD once the OS pipe buffer is full.
            stdout=subprocess.DEVNULL,
        )
        # Remove the first frame that VMD loads together with the pdb file.
        process.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
        process.stdin.flush()

337 

338 

class PathSelect(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select areas in a 2d map and to write all conformations in these areas to separate trajectories.
    """

    def __init__(
        self,
        axe,
        projected,
        mol_data,
        save_path,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """
        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param projected: points in the map (must be the same number of points as conformations in mol_data)
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the selected conformations will be directly opened in vmd.
        :param align_reference: Allows to align the selected conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.
        :raises ValueError: if ``projected`` and ``mol_data.dihedrals`` differ in length.
        """
        # Validate before ManualPath.__init__ connects canvas callbacks, so an
        # invalid call leaves nothing attached to the figure. An explicit raise
        # (instead of ``assert``) is not stripped under ``python -O``.
        if len(projected) != len(mol_data.dihedrals):
            raise ValueError(
                "projected has {} points but mol_data has {} conformations; "
                "they must match.".format(len(projected), len(mol_data.dihedrals))
            )
        super().__init__(axe, n_points=n_points)

        self.mol_data = mol_data
        self.vmd_path = vmd_path
        self.projected = projected

        self.align_reference = align_reference
        self.align_select = align_select

        self.save_path = save_path

    def use_points(self, points):
        """
        Write every conformation whose projection lies inside the selected path.

        Outputs go to a new time-stamped directory below ``<save_path>/selected``.

        :param points: (n, 2) array of vertices of the selected path.
        """
        current_save_path = create_dir(
            os.path.join(
                self.save_path,
                "selected",
                datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
            )
        )
        # Accept list-like input for ``projected``; fancy indexing below needs an array.
        projected = np.asarray(self.projected)
        # matplotlib treats the path as implicitly closed, so this selects the
        # projected points enclosed by the drawn outline.
        indices = np.nonzero(Path(points).contains_points(projected))[0]
        self.axe.scatter(projected[indices, 0], projected[indices, 1])
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(current_save_path, "points"), projected[indices, :])
        np.save(os.path.join(current_save_path, "indices"), indices)

        self.mol_data.write(
            current_save_path,
            self.mol_data.cartesians[indices],
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )
        if self.vmd_path:
            self._open_in_vmd(current_save_path)

    def _open_in_vmd(self, directory):
        """Launch VMD on generated.pdb/generated.xtc inside ``directory``."""
        # NOTE: vmd_path is interpolated into a shell command unquoted (so it
        # may carry extra VMD flags); paths containing spaces must be quoted
        # by the caller.
        cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
        print(cmd)
        process = subprocess.Popen(
            cmd,
            cwd=directory,
            shell=True,
            stdin=subprocess.PIPE,
            # Output is discarded rather than piped: nobody reads it, and an
            # unread PIPE can block VMD once the OS pipe buffer is full.
            stdout=subprocess.DEVNULL,
        )
        # Remove the first frame that VMD loads together with the pdb file.
        process.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
        process.stdin.flush()

415 

416 

417def distance_histogram( 

418 data, periodicity, sigmoid_parameters, axes=None, low_d_max=5, bins="auto" 

419): 

420 """ 

421 Plots the histogram of all pairwise distances in the data. 

422 It also shows the sigmoid function and its normalized derivative. 

423 

424 :param data: each row should contain a point in a number_of _columns dimensional space. 

425 :param periodicity: Periodicity of the data. use float("inf") for non periodic data 

426 :param sigmoid_parameters: tuple (sigma, a, b) 

427 :param axes: Array like structure with two matplotlib axe objects ore None. If None a new figure is generated. 

428 :param low_d_max: upper limit for plotting the low_d sigmoid 

429 :param bins: number of bins for histogram 

430 :return: matplotlib axe objects 

431 """ 

432 vecs = periodic_distance_np( 

433 np.expand_dims(data, axis=1), np.expand_dims(data, axis=0), periodicity 

434 ) 

435 dists = np.linalg.norm(vecs, axis=2) 

436 while True: 

437 try: 

438 dists = np.linalg.norm(dists, axis=2) 

439 except np.AxisError: 

440 break 

441 dists = dists.reshape(-1) 

442 

443 if axes is None: 

444 fig, axes = plt.subplots(2) 

445 axe2 = axes[0].twinx() 

446 counts, edges, patches = axe2.hist( 

447 dists, bins=bins, density=True, edgecolor="black" 

448 ) 

449 x = np.linspace(0, max(dists), 1000) 

450 

451 y = sigmoid(x, *sigmoid_parameters[:3]) 

452 edges_sig = sigmoid(edges, *sigmoid_parameters[:3]) 

453 dy = np.diff(y) 

454 dy_norm = dy / max(dy) 

455 axes[0].plot(x, y, color="C1", label="sigmoid") 

456 axes[0].plot(x[:-1], dy_norm, color="C2", label="diff sigmoid") 

457 

458 axes[0].legend() 

459 axes[0].set_xlabel("distance") 

460 axes[0].set_ylim((0, 1)) 

461 axes[0].set_zorder(axe2.get_zorder() + 1) 

462 axes[0].patch.set_visible(False) 

463 axes[0].set_title("high-d") 

464 

465 x = np.linspace(0, low_d_max, 1000) 

466 y = sigmoid(x, *sigmoid_parameters[3:]) 

467 dy = np.diff(y) 

468 dy_norm = dy / max(dy) 

469 idx = np.argmin( 

470 np.abs(np.expand_dims(edges_sig, axis=1) - np.expand_dims(y, axis=0)), axis=1 

471 ) 

472 edges_x = x[idx] 

473 

474 axes[1].plot(x, y, color="C1", label="sigmoid") 

475 

476 axes[1].legend() 

477 axes[1].set_xlabel("distance") 

478 axes[1].set_ylim((0, 1)) 

479 axes[1].set_title("low-d") 

480 for i in range(len(edges)): 

481 if edges_x[i] != edges_x[-1]: 

482 axes[1].annotate( 

483 "", 

484 xy=(edges[i], 0), 

485 xytext=(edges_x[i], 0), 

486 xycoords=axes[0].transData, 

487 textcoords=axes[1].transData, 

488 arrowprops=dict(facecolor="black", arrowstyle="-", clip_on=False), 

489 ) 

490 axes[0].figure.tight_layout() 

491 return axes[0], axe2, axes[1]