Coverage for encodermap/encodermap_tf1/plot.py: 14%
230 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1import datetime
2import os
3import subprocess
4from itertools import cycle
6import matplotlib.pyplot as plt
7import MDAnalysis as md
8import numpy as np
9from matplotlib.lines import Line2D
10from matplotlib.path import Path
11from matplotlib.widgets import Lasso
12from scipy.interpolate import interp1d
14from .backmapping import dihedral_backmapping
15from .misc import create_dir, periodic_distance_np, sigmoid
class ManualPath(object):
    """
    ManualPath is a tool to manually select a path in a matplotlib graph.
    It supports two modes: "interpolated line", and "free draw".
    Press "m" to switch modes.

    In interpolated line mode click in the graph to add an additional way point.
    Press "delete" to remove the last way point.
    Press "d" to remove all way points.
    Press "enter" once you have finished your path selection.

    In free draw mode press and hold the left mouse button while you draw a path.

    Once the path selection is completed, the use_points method is called with the points on the selected path.
    You can overwrite the use_points method to do what ever you want with the points on the path.
    """

    def __init__(self, axe, n_points=200):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param n_points: Number of points distributed on the selected path.
        """
        self.axe = axe
        self.canvas = axe.figure.canvas
        self.fig = axe.figure
        self.lasso = None
        self.n_interpolation_points = n_points

        self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        # Dashed line through the clicked way points.
        control_line = Line2D(
            [], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87"
        )
        self.control_line = self.axe.add_line(control_line)
        self.x_control = list(self.control_line.get_xdata())
        self.y_control = list(self.control_line.get_ydata())

        # Smooth line through the way points; shares the marker edge colour.
        interpolated_line = Line2D([], [], c=self.control_line.get_markeredgecolor())
        self.interpolated_line = self.axe.add_line(interpolated_line)

        # "m" advances this cycle, toggling between the two click handlers.
        self.modes = cycle([self._add_point_interp, self._free_draw])
        self.mode = next(self.modes)
        self.key_cid = self.canvas.mpl_connect("key_press_event", self._on_key)
        self.click_cid = self.canvas.mpl_connect("button_press_event", self._on_click)
        self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background)

    def _on_click(self, event):
        """Dispatch a left click inside the watched axes to the active mode."""
        # Ignore clicks outside our axes; this also covers ``inaxes is None``.
        if event.inaxes != self.axe:
            return
        # Another widget (e.g. an active lasso) currently owns the canvas.
        if self.canvas.widgetlock.locked():
            return
        if event.button == 1:
            self.mode(event)

    def _on_key(self, event):
        """Handle the keyboard shortcuts described in the class docstring."""
        if event.key == "m":
            self.mode = next(self.modes)
            return
        if event.key == "enter":
            points = np.array(self.interpolated_line.get_data()).T
            self._reset_lines()
            self.use_points(points)
            self._grab_background()
            return
        if event.key == "d":
            self._reset_lines()
            return
        if event.key == "delete":
            # Nothing to remove: without this guard ``del ...[-1]`` raises
            # IndexError when "delete" is pressed before any click.
            if not self.x_control:
                return
            del self.x_control[-1]
            del self.y_control[-1]
            self._update_interp()

    def _free_draw(self, event):
        """Start a lasso selection at the click position."""
        self.lasso = Lasso(
            event.inaxes, (event.xdata, event.ydata), self._free_draw_callback
        )
        # acquire a lock on the widget drawing
        self.canvas.widgetlock(self.lasso)

    def _free_draw_callback(self, verts):
        """Forward the drawn vertices to use_points and release the lasso."""
        points = np.array(verts)
        try:
            self.use_points(points)
            self.canvas.draw_idle()
        finally:
            # Always give the lock back, even if use_points raises; otherwise
            # _on_click would ignore every subsequent click.
            self.canvas.widgetlock.release(self.lasso)
            self.lasso = None

    def _add_point_interp(self, event):
        """Append the clicked position as a new way point."""
        self.x_control.append(event.xdata)
        self.y_control.append(event.ydata)
        self._update_interp()

    def _update_interp(self):
        """Recompute both lines from the current way points and redraw them."""
        self.control_line.set_data(self.x_control, self.y_control)
        x_i, y_i = self._interpolate(self.x_control, self.y_control)
        x_i, y_i = self._interpolate(
            x_i, y_i
        )  # second iteration makes points more evenly spaced
        self.interpolated_line.set_data(x_i, y_i)
        self._update_lines()

    def _interpolate(self, x, y):
        """
        Distribute ``self.n_interpolation_points`` points along the polyline (x, y).

        The polyline is parameterised by its cumulative segment length. A cubic
        interpolation is tried first, then a linear one; if both raise
        ValueError (e.g. fewer than two points) two empty lists are returned.

        :param x: sequence of x coordinates
        :param y: sequence of y coordinates
        :return: tuple (x_i, y_i) with the interpolated coordinates
        """
        segment_lengths = np.hypot(np.diff(x), np.diff(y))
        cumulative_distances = np.concatenate(([0.0], np.cumsum(segment_lengths)))
        # The cumulative distances never decrease, so the last entry is the maximum.
        interp_i = np.linspace(
            0, cumulative_distances[-1], self.n_interpolation_points
        )
        try:
            x_i = interp1d(cumulative_distances, x, kind="cubic")(interp_i)
            y_i = interp1d(cumulative_distances, y, kind="cubic")(interp_i)
        except ValueError:
            try:
                x_i = interp1d(cumulative_distances, x, kind="linear")(interp_i)
                y_i = interp1d(cumulative_distances, y, kind="linear")(interp_i)
            except ValueError:
                x_i = []
                y_i = []
        return x_i, y_i

    def use_points(self, points):
        """
        Overwrite this method to use the selected points in any way you like.

        For Example:

        >>> class MyManualPath(ManualPath):
        >>>     def use_points(self, points):
        >>>         print(points)

        :param points: numpy array with points from the manual path selection
        :return: None
        :raises NotImplementedError: always, after plotting the points, unless overwritten
        """
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        raise NotImplementedError("has to be implemented in subclass")

    def _grab_background(self, event=None):
        """
        When the figure is resized, hide the points, draw everything,
        and update the background.
        """
        # Thanks to: https://stackoverflow.com/questions/29277080/efficient-matplotlib-redrawing
        # Disconnect first so the draw() calls below do not re-trigger this handler.
        self.canvas.mpl_disconnect(self.draw_cid)
        self.interpolated_line.set_visible(False)
        self.control_line.set_visible(False)
        self.canvas.draw()

        # With most backends (e.g. TkAgg), we could grab (and refresh, in
        # self.blit) self.ax.bbox instead of self.fig.bbox, but Qt4Agg, and
        # some others, requires us to update the _full_ canvas, instead.
        self.background = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        self.interpolated_line.set_visible(True)
        self.control_line.set_visible(True)
        self.canvas.draw()
        self.draw_cid = self.canvas.mpl_connect("draw_event", self._grab_background)

    def _update_lines(self):
        """
        Efficiently update the figure, without needing to redraw the
        "background" artists.
        """
        self.fig.canvas.restore_region(self.background)
        self.axe.draw_artist(self.interpolated_line)
        self.axe.draw_artist(self.control_line)
        self.canvas.draw()  # Todo: canvas.draw() works in jupyter notebook but canvas.update() outside of notebook
        self.canvas.flush_events()

    def _reset_lines(self):
        """Clear both lines on screen and forget all way points."""
        self.interpolated_line.set_data([[], []])
        self.control_line.set_data([[], []])
        self._update_lines()

        self.x_control = []
        self.y_control = []
class PathGenerateDihedrals(ManualPath):
    """
    This class inherits from :py:class:`encodermap.plot.ManualPath`.
    The points from a manually selected path are fed into the decoder part of a given autoencoder.
    The output of the autoencoder is used as phi psi dihedral angles to reconstruct protein conformations
    based on the protein structure given with pdb_path.
    Three output files are written for each selected path:
    points.npy, generated.npy and generated.pdb which contain:
    the points on the selected path, the generated output of
    the autoencoder, and the generated protein conformations respectively.
    Keep in mind that backbone dihedrals are not sufficient to describe a protein conformation completely.
    Usually the backbone is reconstructed well but all side chains are messed up.
    """

    def __init__(self, axe, autoencoder, pdb_path, save_path=None, n_points=200):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :py:class:`encodermap.autoencoder.Autoencoder` which was trained on protein dihedral
            angles. The dihedrals have to be ordered starting from the amino end.
            First all phi angles then all psi angles.
        :param pdb_path: Path to a protein data bank (pdb) file of the protein
        :param save_path: Path where outputs should be written.
            Falls back to ``autoencoder.p.main_path`` when not given.
        :param n_points: Number of points distributed on the selected path.
        """
        super().__init__(axe, n_points=n_points)

        self.autoencoder = autoencoder
        self.pdb_path = pdb_path
        # The fallback attribute is only looked up when no save_path was passed.
        self.save_path = save_path if save_path else autoencoder.p.main_path

    def use_points(self, points):
        """
        Decode the selected path and write points, decoder output and a pdb file.

        :param points: numpy array with points from the manual path selection
        :return: None
        """
        # One output directory per selection, named after the current time.
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        out_dir = create_dir(
            os.path.join(self.save_path, "generated_paths", timestamp)
        )

        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(out_dir, "points"), points)

        decoded = self.autoencoder.generate(points)
        np.save(os.path.join(out_dir, "generated.npy"), decoded)

        universe = dihedral_backmapping(self.pdb_path, decoded)
        pdb_file = os.path.join(out_dir, "generated.pdb")
        with md.Writer(pdb_file) as writer:
            # Stepping through the trajectory selects each frame before it is written.
            for _frame in universe.trajectory:
                writer.write(universe.atoms)
class PathGenerateCartesians(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select paths in a 2d map and to generate conformations for these paths with an
    AngleDihedralCartesianEncoder.
    """

    def __init__(
        self,
        axe,
        autoencoder,
        mol_data,
        save_path=None,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param autoencoder: :class:`.AngleDihedralCartesianEncoder`
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written.
            Falls back to ``autoencoder.p.main_path`` when not given.
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the generated conformations will be directly opened in vmd.
        :param align_reference: Allows to align the generated conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.

        """
        super().__init__(axe, n_points=n_points)

        self.autoencoder = autoencoder
        self.mol_data = mol_data
        self.vmd_path = vmd_path

        self.align_reference = align_reference
        self.align_select = align_select

        if save_path:
            self.save_path = save_path
        else:
            self.save_path = autoencoder.p.main_path

    def use_points(self, points):
        """
        Decode the selected path, save the results and optionally open them in vmd.

        Writes ``points.npy``, ``generated_dihedrals.npy`` and
        ``generated_cartesians.npy`` into a new time-stamped directory and then
        hands the cartesians to ``mol_data.write``.

        :param points: numpy array with points from the manual path selection
        :return: None
        :raises ValueError: if the autoencoder output has neither two nor three elements
        """
        current_save_path = create_dir(
            os.path.join(
                self.save_path,
                "generated_paths",
                datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
            )
        )
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(current_save_path, "points"), points)

        # Run the decoder exactly once. The output is either
        # (dihedrals, cartesians) or (angles, dihedrals, cartesians).
        generated = self.autoencoder.generate(points)
        if len(generated) == 2:
            dihedrals, cartesians = generated
        else:
            # Any length other than three raises ValueError here.
            _angles, dihedrals, cartesians = generated
        np.save(os.path.join(current_save_path, "generated_dihedrals.npy"), dihedrals)
        np.save(os.path.join(current_save_path, "generated_cartesians.npy"), cartesians)

        self.mol_data.write(
            current_save_path,
            cartesians,
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )
        if self.vmd_path:
            cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
            print(cmd)
            # NOTE(review): the command is run through the shell and stdout is
            # piped but never read in this method.
            process = subprocess.Popen(
                cmd,
                cwd=current_save_path,
                shell=True,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
            )
            # Sent to vmd on stdin once it has started.
            process.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
            process.stdin.flush()
class PathSelect(ManualPath):
    """
    This class inherits from :class:`encodermap.plot.ManualPath`.
    It is used to select areas in a 2d map and to write all conformations in these areas to separate trajectories.
    """

    def __init__(
        self,
        axe,
        projected,
        mol_data,
        save_path,
        n_points=200,
        vmd_path="",
        align_reference=None,
        align_select="all",
    ):
        """

        :param axe: matplotlib axe object for example from: fig, axe = plt.subplots()
        :param projected: points in the map (must be the same number of points as conformations in mol_data)
        :param mol_data: :class:`.MolData`
        :param save_path: Path where outputs should be written
        :param n_points: Number of points distributed on the selected path.
        :param vmd_path: If a path to vmd is given, the selected conformations will be directly opened in vmd.
        :param align_reference: Allows to align the written conformations according to some reference.
            The reference should be given as MDAnalysis atomgroup
        :param align_select: Allows to select which atoms should be used for the alignment. e.g. "resid 5:60"
            default is "all". Have a look at the MDAnalysis selection syntax for more details.

        """
        super().__init__(axe, n_points=n_points)

        self.mol_data = mol_data
        self.vmd_path = vmd_path
        self.projected = projected
        # One projected point is expected per entry of mol_data.dihedrals.
        assert len(projected) == len(mol_data.dihedrals)

        self.align_reference = align_reference
        self.align_select = align_select

        self.save_path = save_path

    def use_points(self, points):
        """
        Write all conformations whose projection lies inside the selected outline.

        Saves ``points.npy`` (the enclosed projected points) and ``indices.npy``
        into a new time-stamped directory, then hands the matching cartesians to
        ``mol_data.write``.

        :param points: numpy array with the vertices of the selected outline
        :return: None
        """
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        out_dir = create_dir(os.path.join(self.save_path, "selected", timestamp))

        # Indices of the projected points enclosed by the drawn outline.
        inside_mask = Path(points).contains_points(self.projected)
        indices = np.nonzero(inside_mask)[0]

        self.axe.scatter(self.projected[indices, 0], self.projected[indices, 1])
        self.axe.plot(points[:, 0], points[:, 1], linestyle="", marker=".")
        np.save(os.path.join(out_dir, "points"), self.projected[indices, :])
        np.save(os.path.join(out_dir, "indices"), indices)

        self.mol_data.write(
            out_dir,
            self.mol_data.cartesians[indices],
            only_central=False,
            align_reference=self.align_reference,
            align_select=self.align_select,
        )

        if not self.vmd_path:
            return

        cmd = "{} {} {}".format(self.vmd_path, "generated.pdb", "generated.xtc")
        print(cmd)
        vmd = subprocess.Popen(
            cmd,
            cwd=out_dir,
            shell=True,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )
        # Sent to vmd on stdin once it has started.
        vmd.stdin.write(b"animate delete beg 0 end 0 skip 0 0\n")
        vmd.stdin.flush()
def distance_histogram(
    data, periodicity, sigmoid_parameters, axes=None, low_d_max=5, bins="auto"
):
    """
    Plots the histogram of all pairwise distances in the data.
    It also shows the sigmoid function and its normalized derivative.

    :param data: each row should contain a point in a number_of_columns dimensional space.
    :param periodicity: Periodicity of the data. use float("inf") for non periodic data
    :param sigmoid_parameters: sequence of sigmoid parameters. The first three entries are
        passed to the high-d sigmoid, all remaining entries to the low-d sigmoid.
    :param axes: Array like structure with two matplotlib axe objects or None. If None a new figure is generated.
    :param low_d_max: upper limit for plotting the low_d sigmoid
    :param bins: number of bins for histogram
    :return: matplotlib axe objects (high-d axe, histogram twin axe, low-d axe)
    """
    vecs = periodic_distance_np(
        np.expand_dims(data, axis=1), np.expand_dims(data, axis=0), periodicity
    )
    dists = np.linalg.norm(vecs, axis=2)
    # Keep reducing until only the two pairwise axes are left. Checking ndim
    # avoids relying on the AxisError exception, whose location differs
    # between NumPy versions.
    while dists.ndim > 2:
        dists = np.linalg.norm(dists, axis=2)
    dists = dists.reshape(-1)

    if axes is None:
        _, axes = plt.subplots(2)
    axe2 = axes[0].twinx()
    _, edges, _ = axe2.hist(dists, bins=bins, density=True, edgecolor="black")

    # High-d panel: sigmoid and its normalized numerical derivative.
    x = np.linspace(0, dists.max(), 1000)
    y = sigmoid(x, *sigmoid_parameters[:3])
    edges_sig = sigmoid(edges, *sigmoid_parameters[:3])
    dy = np.diff(y)
    dy_norm = dy / dy.max()
    axes[0].plot(x, y, color="C1", label="sigmoid")
    axes[0].plot(x[:-1], dy_norm, color="C2", label="diff sigmoid")

    axes[0].legend()
    axes[0].set_xlabel("distance")
    axes[0].set_ylim((0, 1))
    # Draw the curves above the histogram of the twin axe.
    axes[0].set_zorder(axe2.get_zorder() + 1)
    axes[0].patch.set_visible(False)
    axes[0].set_title("high-d")

    # Low-d panel: sigmoid only.
    x = np.linspace(0, low_d_max, 1000)
    y = sigmoid(x, *sigmoid_parameters[3:])
    # For every histogram edge find the low-d distance whose sigmoid value is
    # closest to the high-d sigmoid value at that edge.
    idx = np.argmin(
        np.abs(np.expand_dims(edges_sig, axis=1) - np.expand_dims(y, axis=0)), axis=1
    )
    edges_x = x[idx]

    axes[1].plot(x, y, color="C1", label="sigmoid")

    axes[1].legend()
    axes[1].set_xlabel("distance")
    axes[1].set_ylim((0, 1))
    axes[1].set_title("low-d")

    # Connect each high-d edge with its low-d counterpart; edges mapped to the
    # same low-d position as the last edge are skipped.
    last_edge_x = edges_x[-1]
    for edge, edge_x in zip(edges, edges_x):
        if edge_x != last_edge_x:
            axes[1].annotate(
                "",
                xy=(edge, 0),
                xytext=(edge_x, 0),
                xycoords=axes[0].transData,
                textcoords=axes[1].transData,
                arrowprops=dict(facecolor="black", arrowstyle="-", clip_on=False),
            )
    axes[0].figure.tight_layout()
    return axes[0], axe2, axes[1]