Coverage for encodermap/encodermap_tf1/plot.py: 13%
231 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# Standard Library Imports
2import datetime
3import os
4import subprocess
5from itertools import cycle
7# Third Party Imports
8import matplotlib.pyplot as plt
9import MDAnalysis as md
10import numpy as np
11from matplotlib.lines import Line2D
12from matplotlib.path import Path
13from matplotlib.widgets import Lasso
14from scipy.interpolate import interp1d
16# Local Folder Imports
17from .backmapping import dihedral_backmapping
18from .misc import create_dir, periodic_distance_np, sigmoid
class ManualPath(object):
    """
    ManualPath is a tool to manually select a path in a matplotlib graph.
    It supports two modes: "interpolated line", and "free draw".
    Press "m" to switch modes.

    In interpolated line mode click in the graph to add an additional way point.
    Press "delete" to remove the last way point.
    Press "d" to remove all way points.
    Press "enter" once you have finished your path selection.

    In free draw mode press and hold the left mouse button while you draw a path.

    Once the path selection is completed, the use_points method is called with the points on the selected path.
    You can overwrite the use_points method to do what ever you want with the points on the path.
    """

    def __init__(self, axe, n_points=200):
        """
        Attach the path-selection tool to ``axe`` and connect its key, click and draw handlers.

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param n_points: Number of points distributed on the selected path.
        """
        self.axe = axe
        self.canvas = axe.figure.canvas
        self.fig = axe.figure
        self.lasso = None
        self.n_interpolation_points = n_points

        # Snapshot of the figure that _update_lines restores before redrawing the two lines.
        self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        # Dashed line through the clicked way points.
        control_line = Line2D(
            [], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87"
        )
        self.control_line = self.axe.add_line(control_line)
        self.x_control = list(self.control_line.get_xdata())
        self.y_control = list(self.control_line.get_ydata())

        # Line interpolated through the way points; its data is what "enter" passes to use_points.
        interpolated_line = Line2D([], [], c=self.control_line.get_markeredgecolor())
        self.interpolated_line = self.axe.add_line(interpolated_line)

        # The "m" key cycles through these click handlers.
        self.modes = cycle([self._add_point_interp, self._free_draw])
        self.mode = next(self.modes)
        self.key_cid = self.canvas.mpl_connect("key_press_event", self._on_key)
        self.click_cid = self.canvas.mpl_connect("button_press_event", self._on_click)
        self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background)

    def _on_click(self, event):
        """Forward left-button clicks inside ``self.axe`` to the active mode handler."""
        # Ignore clicks outside our axes; this also covers clicks outside any axes (inaxes is None).
        if event.inaxes != self.axe:
            return
        # While a lasso holds the widget lock, no new action is started.
        if self.canvas.widgetlock.locked():
            return
        if event.button == 1:
            self.mode(event)

    def _on_key(self, event):
        """Handle the keyboard shortcuts "m", "enter", "d" and "delete" (see class docstring)."""
        if event.key == "m":
            self.mode = next(self.modes)
            return
        if event.key == "enter":
            points = np.array(self.interpolated_line.get_data()).T
            if len(points) == 0:
                # Nothing has been selected yet: do not hand an empty path to use_points.
                return
            self._reset_lines()
            self.use_points(points)
            self._grab_background()
            return
        if event.key == "d":
            self._reset_lines()
            return
        if event.key == "delete":
            if not self.x_control:
                # No way point left to remove.
                return
            del self.x_control[-1]
            del self.y_control[-1]
            self._update_interp()

    def _free_draw(self, event):
        """Start a lasso at the clicked position; its vertices go to _free_draw_callback."""
        self.lasso = Lasso(
            event.inaxes, (event.xdata, event.ydata), self._free_draw_callback
        )
        # acquire a lock on the widget drawing
        self.canvas.widgetlock(self.lasso)

    def _free_draw_callback(self, verts):
        """Pass the lasso vertices to use_points, then release the widget lock."""
        try:
            self.use_points(np.array(verts))
            self.canvas.draw_idle()
        finally:
            # Always release the lock: if use_points raised and the lock stayed held,
            # _on_click would silently ignore every further click.
            self.canvas.widgetlock.release(self.lasso)
            self.lasso = None

    def _add_point_interp(self, event):
        """Append the clicked position as a new way point and refresh the interpolated line."""
        self.x_control.append(event.xdata)
        self.y_control.append(event.ydata)
        self._update_interp()

    def _update_interp(self):
        """Redraw the control line and the line interpolated through the current way points."""
        self.control_line.set_data(self.x_control, self.y_control)
        x_i, y_i = self._interpolate(self.x_control, self.y_control)
        x_i, y_i = self._interpolate(
            x_i, y_i
        )  # second iteration makes points more evenly spaced
        self.interpolated_line.set_data(x_i, y_i)
        self._update_lines()

    def _interpolate(self, x, y):
        """
        Resample the polyline (x, y) at ``self.n_interpolation_points`` points.

        The samples are evenly spaced in cumulative distance along the polyline.
        Cubic interpolation is tried first. If scipy raises a ValueError, linear
        interpolation is tried. If that also fails, two empty lists are returned.

        :param x: x coordinates of the points to interpolate through
        :param y: y coordinates of the points to interpolate through
        :return: tuple (x_i, y_i) with the interpolated coordinates
        """
        # Parameterise the curve by the accumulated distance along it.
        cumulative_distances = [0]
        for x_prev, y_prev, x_next, y_next in zip(x[:-1], y[:-1], x[1:], y[1:]):
            dist = ((x_next - x_prev) ** 2 + (y_next - y_prev) ** 2) ** 0.5
            cumulative_distances.append(cumulative_distances[-1] + dist)
        interp_i = np.linspace(
            0, max(cumulative_distances), self.n_interpolation_points
        )
        try:
            x_i = interp1d(cumulative_distances, x, kind="cubic")(interp_i)
            y_i = interp1d(cumulative_distances, y, kind="cubic")(interp_i)
        except ValueError:
            try:
                x_i = interp1d(cumulative_distances, x, kind="linear")(interp_i)
                y_i = interp1d(cumulative_distances, y, kind="linear")(interp_i)
            except ValueError:
                x_i = []
                y_i = []
        return x_i, y_i

    def use_points(self, points):
        """
        Overwrite this method to use the selected points in any way you like.

        For Example:

        >>> class MyManualPath(ManualPath):
        >>>     def use_points(self, points):
        >>>         print(points)

        The base implementation plots the points and then raises NotImplementedError.

        :param points: numpy array with points from the manual path selection
        :return: None
        """
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        raise NotImplementedError("has to be implemented in subclass")

    def _grab_background(self, event=None):
        """
        When the figure is resized, hide the points, draw everything,
        and update the background.
        """
        # Thanks to: https://stackoverflow.com/questions/29277080/efficient-matplotlib-redrawing
        # Disconnect first: canvas.draw() below would otherwise re-trigger this handler.
        self.canvas.mpl_disconnect(self.draw_cid)
        self.interpolated_line.set_visible(False)
        self.control_line.set_visible(False)
        self.canvas.draw()

        # With most backends (e.g. TkAgg), we could grab (and refresh, in
        # self.blit) self.ax.bbox instead of self.fig.bbox, but Qt4Agg, and
        # some others, requires us to update the _full_ canvas, instead.
        self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        self.interpolated_line.set_visible(True)
        self.control_line.set_visible(True)
        self.canvas.draw()
        self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background)

    def _update_lines(self):
        """
        Efficiently update the figure, without needing to redraw the
        "background" artists.
        """
        self.fig.canvas.restore_region(self.background)
        self.axe.draw_artist(self.interpolated_line)
        self.axe.draw_artist(self.control_line)
        self.canvas.draw()  # Todo: canvas.draw() works in jupyter notebook but canvas.update() outside of notebook
        self.canvas.flush_events()

    def _reset_lines(self):
        """Clear both lines, redraw them and forget all way points."""
        self.interpolated_line.set_data([[], []])
        self.control_line.set_data([[], []])
        self._update_lines()

        self.x_control = []
        self.y_control = []
class PathGenerateDihedrals(ManualPath):
    """
    Manual path selector that turns a selected path into protein conformations.

    This class inherits from :py:class:`encodermap.plot.ManualPath`.
    Every point on a manually selected path is fed into the decoder part of the given
    autoencoder. The autoencoder output is used as phi/psi dihedral angles to rebuild
    conformations of the protein given by ``pdb_path``.

    For each selected path three files are written:

    * ``points.npy``    - the points on the selected path,
    * ``generated.npy`` - the generated output of the autoencoder,
    * ``generated.pdb`` - the generated protein conformations.

    Keep in mind that backbone dihedrals are not sufficient to describe a protein conformation completely.
    Usually the backbone is reconstructed well but all side chains are messed up.
    """

    def __init__(self, axe, autoencoder, pdb_path, save_path=None, n_points=200):
        """
        Set up the path selector and remember where and how to write generated structures.

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :py:class:`encodermap.autoencoder.Autoencoder` which was trained on protein dihedral
            angles. The dihedrals have to be ordered starting from the amino end.
            First all phi angles then all psi angles.
        :param pdb_path: Path to a protein data bank (pdb) file of the protein
        :param save_path: Path where outputs should be written. Falls back to
            ``autoencoder.p.main_path`` when not given (or falsy).
        :param n_points: Number of points distributed on the selected path.
        """
        super().__init__(axe, n_points=n_points)

        self.autoencoder = autoencoder
        self.pdb_path = pdb_path
        self.save_path = save_path if save_path else autoencoder.p.main_path

    def use_points(self, points):
        """
        Decode ``points`` and write points, decoder output and a pdb of the rebuilt structures.

        The files go into a new time-stamped directory below ``<save_path>/generated_paths``.

        :param points: numpy array whose first two columns are the x/y coordinates of the path
        """
        stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        out_dir = create_dir(os.path.join(self.save_path, "generated_paths", stamp))

        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(out_dir, "points"), points)

        decoded = self.autoencoder.generate(points)
        np.save(os.path.join(out_dir, "generated.npy"), decoded)

        universe = dihedral_backmapping(self.pdb_path, decoded)
        pdb_file = os.path.join(out_dir, "generated.pdb")
        with md.Writer(pdb_file) as writer:
            # The loop variable is unused. Iterating universe.trajectory is relied on
            # to move universe.atoms to each frame in turn (MDAnalysis behaviour, presumably one frame per write).
            for _frame in universe.trajectory:
                writer.write(universe.atoms)
class PathGenerateCartesians(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select paths in a 2d map and to generate conformations for these paths with a
    AngleDihedralCartesianEncoder.
    """

    def __init__(
        self,
        axe,
        autoencoder,
        mol_data,
        save_path=None,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """
        Set up the path selector and remember how generated conformations are written and shown.

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :class:`.AngleDihedralCartesianEncoder`
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written. Falls back to
            ``autoencoder.p.main_path`` when not given (or falsy).
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the generated conformations will be directly opened in vmd.
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.
        """
        super().__init__(axe, n_points=n_points)

        self.autoencoder = autoencoder
        self.mol_data = mol_data
        self.vmd_path = vmd_path

        self.align_reference = align_reference
        self.align_select = align_select

        if save_path:
            self.save_path = save_path
        else:
            self.save_path = autoencoder.p.main_path

    def use_points(self, points):
        """
        Generate conformations for ``points`` and write them to a new time-stamped directory.

        The directory is created below ``<save_path>/generated_paths``. It receives
        ``points.npy``, ``generated_dihedrals.npy``, ``generated_cartesians.npy`` and
        whatever ``mol_data.write`` produces. If ``vmd_path`` is set, VMD is started
        in that directory on ``generated.pdb`` / ``generated.xtc``.

        :param points: numpy array whose first two columns are the x/y coordinates of the path
        """
        current_save_path = create_dir(
            os.path.join(
                self.save_path,
                "generated_paths",
                datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
            )
        )
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(current_save_path, "points"), points)

        # generate() returns either (dihedrals, cartesians) or
        # (angles, dihedrals, cartesians). Call it only once and unpack by length,
        # instead of re-running the decoder after a failed unpack.
        generated = self.autoencoder.generate(points)
        if len(generated) == 3:
            _angles, dihedrals, cartesians = generated
        else:
            dihedrals, cartesians = generated
        np.save(os.path.join(current_save_path, "generated_dihedrals.npy"), dihedrals)
        np.save(os.path.join(current_save_path, "generated_cartesians.npy"), cartesians)

        self.mol_data.write(
            current_save_path,
            cartesians,
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )
        if self.vmd_path:
            cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
            print(cmd)
            # stdout is deliberately not piped. Nothing ever reads it, so a PIPE would
            # eventually fill the OS buffer and block VMD. stdin stays a pipe for the command below.
            process = subprocess.Popen(
                cmd,
                cwd=current_save_path,
                shell=True,
                stdin=subprocess.PIPE,
            )
            process.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
            process.stdin.flush()
class PathSelect(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select areas in a 2d map and to write all conformations in these areas to separate trajectories.
    """

    def __init__(
        self,
        axe,
        projected,
        mol_data,
        save_path,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """
        Set up the area selector for the projected conformations in ``mol_data``.

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param projected: points in the map (must be the same number of points as conformations in mol_data).
            Converted with ``np.asarray`` so array-likes such as lists are accepted.
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the selected conformations will be directly opened in vmd.
        :param align_reference: Allows to align the selected conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.
        """
        super().__init__(axe, n_points=n_points)

        self.mol_data = mol_data
        self.vmd_path = vmd_path
        # use_points indexes this as projected[indices, 0], which requires an ndarray.
        self.projected = np.asarray(projected)
        assert len(self.projected) == len(
            mol_data.dihedrals
        ), "projected must contain exactly one point per conformation in mol_data"

        self.align_reference = align_reference
        self.align_select = align_select

        self.save_path = save_path

    def use_points(self, points):
        """
        Write every conformation whose projection lies inside the selected area.

        ``points`` is used as the vertices of a matplotlib Path. The projected
        points it contains are highlighted. Their coordinates (``points.npy``),
        their indices (``indices.npy``) and the matching cartesians (via
        ``mol_data.write``) go into a new time-stamped directory below
        ``<save_path>/selected``. If ``vmd_path`` is set, VMD is started there on
        ``generated.pdb`` / ``generated.xtc``.

        :param points: numpy array whose first two columns are the x/y coordinates of the outline
        """
        current_save_path = create_dir(
            os.path.join(
                self.save_path,
                "selected",
                datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
            )
        )
        indices = np.nonzero(Path(points).contains_points(self.projected))[0]
        self.axe.scatter(self.projected[indices, 0], self.projected[indices, 1])
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(current_save_path, "points"), self.projected[indices, :])
        np.save(os.path.join(current_save_path, "indices"), indices)

        self.mol_data.write(
            current_save_path,
            self.mol_data.cartesians[indices],
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )
        if self.vmd_path:
            cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
            print(cmd)
            # stdout is deliberately not piped. Nothing ever reads it, so a PIPE would
            # eventually fill the OS buffer and block VMD. stdin stays a pipe for the command below.
            process = subprocess.Popen(
                cmd,
                cwd=current_save_path,
                shell=True,
                stdin=subprocess.PIPE,
            )
            process.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
            process.stdin.flush()
def distance_histogram(
    data, periodicity, sigmoid_parameters, axes=None, low_d_max=5, bins="auto"
):
    """
    Plots the histogram of all pairwise distances in the data.
    It also shows the high-d sigmoid function and its normalized derivative, the low-d
    sigmoid, and lines connecting each histogram bin edge to the low-d distance with the
    same sigmoid value.

    :param data: each row should contain a point in a number_of _columns dimensional space.
        1D data (one value per point) is accepted as well.
    :param periodicity: Periodicity of the data. use float("inf") for non periodic data
    :param sigmoid_parameters: tuple (sigma, a, b); the first three entries parameterize the
        high-d sigmoid, the remaining entries the low-d sigmoid.
    :param axes: Array like structure with two matplotlib axe objects or None. If None a new figure is generated.
    :param low_d_max: upper limit for plotting the low_d sigmoid
    :param bins: number of bins for histogram
    :return: tuple of matplotlib axe objects (high-d axe, twin axe holding the histogram, low-d axe)
    """
    vecs = np.asarray(
        periodic_distance_np(
            np.expand_dims(data, axis=1), np.expand_dims(data, axis=0), periodicity
        )
    )
    # Merge all per-pair component axes into one, so a single norm gives the Euclidean
    # distance for data of any rank. The previous loop relied on np.AxisError, which is
    # no longer in NumPy 2's main namespace, and it crashed on 1D data.
    vecs = vecs.reshape(vecs.shape[0], vecs.shape[1], -1)
    dists = np.linalg.norm(vecs, axis=2).reshape(-1)

    if axes is None:
        _, axes = plt.subplots(2)
    axe2 = axes[0].twinx()
    _, edges, _ = axe2.hist(dists, bins=bins, density=True, edgecolor="black")
    x = np.linspace(0, np.max(dists), 1000)

    y = sigmoid(x, *sigmoid_parameters[:3])
    edges_sig = sigmoid(edges, *sigmoid_parameters[:3])
    dy = np.diff(y)
    dy_norm = dy / np.max(dy)
    axes[0].plot(x, y, color="C1", label="sigmoid")
    axes[0].plot(x[:-1], dy_norm, color="C2", label="diff sigmoid")

    axes[0].legend()
    axes[0].set_xlabel("distance")
    axes[0].set_ylim((0, 1))
    # Draw the sigmoid axe above the histogram twin axe and make its background transparent.
    axes[0].set_zorder(axe2.get_zorder() + 1)
    axes[0].patch.set_visible(False)
    axes[0].set_title("high-d")

    x = np.linspace(0, low_d_max, 1000)
    y = sigmoid(x, *sigmoid_parameters[3:])
    # For every high-d bin edge, find the low-d distance whose sigmoid value is closest.
    idx = np.argmin(
        np.abs(np.expand_dims(edges_sig, axis=1) - np.expand_dims(y, axis=0)), axis=1
    )
    edges_x = x[idx]

    axes[1].plot(x, y, color="C1", label="sigmoid")

    axes[1].legend()
    axes[1].set_xlabel("distance")
    axes[1].set_ylim((0, 1))
    axes[1].set_title("low-d")
    for edge, edge_x in zip(edges, edges_x):
        # Edges mapped to the same low-d position as the last edge are not drawn.
        if edge_x != edges_x[-1]:
            axes[1].annotate(
                "",
                xy=(edge, 0),
                xytext=(edge_x, 0),
                xycoords=axes[0].transData,
                textcoords=axes[1].transData,
                arrowprops=dict(facecolor="black", arrowstyle="-", clip_on=False),
            )
    axes[0].figure.tight_layout()
    return axes[0], axe2, axes[1]