Coverage for encodermap/plot/interactive_plotting.py: 13%
210 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/plot/interactive_plotting.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""ToDo:
23 * Check if vmd rendering works
24 * Animate the path and the torsions inside the Interactive Plotting class.
25 * ToolTip when hovering over buttons.
26 * Path Tool.
27 * Allow a path that selects the closest data points for each point on the path. Maybe do both for any given Bezier curve or path.
28 * Allow autoencoder to be None. Catch the exception that tries to get autoencoder.trajs.
29 * Superpose option with mdtraj not working
30 * Keep the lasso select.
31 * Movie along paths.
33"""
35##############################################################################
36# Imports
37##############################################################################
39import copy
40import os
41from itertools import groupby
43from matplotlib.path import Path
45from .._optional_imports import _optional_import
46from ..autoencoder.autoencoder import AngleDihedralCartesianEncoderMap
47from ..misc.clustering import gen_dummy_traj, get_cluster_frames
48from ..misc.misc import _datetime_windows_and_linux_compatible, all_equal
49from ..plot.plotting import plot_cluster
50from ..trajinfo.info_all import TrajEnsemble
51from ..trajinfo.info_single import SingleTraj
52from .utils import *
53from .utils import _unpack_cluster_info, _unpack_path_info
55##############################################################################
56# Optional Imports
57##############################################################################
60sns = _optional_import("seaborn")
62##############################################################################
63# Classes
64##############################################################################
class InteractivePlotting:
    """Class to open up an interactive plotting window.

    The window shows a scatter plot of two-dimensional (low-dimensional) data
    next to two menus: a status menu that selects the active selection tool
    (its current value is exposed as `mode`) and a menu with actions such as
    'Reset', 'Write' and 'Set Points'.

    Attributes:
        trajs (encodermap.TrajEnsemble): The trajs passed into this class.
        fig (matplotlib.figure): The figure plotted onto. If ax is passed when
            this class is instantiated the parent figure will be fetched with
            self.fig = self.ax.get_figure()
        ax (matplotlib.axes): The axes where the lowd data of the trajs
            is plotted on.
        menu_ax (matplotlib.axes): The axes where the normal menu is plotted on.
        status_menu_ax (matplotlib.axes): The axes on which the status menu
            is plotted on.
        pts (matplotlib.collections.Collection): The points which are plotted.
            The color of this collection is adjusted when points are selected.
        statusmenu (encodermap.plot.utils.StatusMenu): The menu containing the
            status buttons.
        menu (encodermap.plot.utils.Menu): The menu containing the remaining
            buttons.
        tool (encodermap.plot.utils.SelectFromCollection): The currently active
            tool used to select points. This can be lasso, polygon, etc...
        mode (str): Current mode of the statusmenu.
        view: The nglview widget created by the last 'Set Points' action.
        dummy_traj: The structures created by the last 'Set Points' action.

    Examples:
        >>> sess = InteractivePlotting(autoencoder, trajs=trajs)

    """

    # Number of unique points a path tool has to deliver before structures are
    # generated. NOTE(review): presumably matches the number of points that
    # `BezierBuilder` interpolates - confirm in `.utils`.
    _N_PATH_POINTS = 200

    def __init__(
        self,
        autoencoder,
        trajs=None,
        data=None,
        ax=None,
        align_string="name CA",
        top=None,
        hist=False,
        scatter_kws={"s": 5},
        ball_and_stick=False,
        top_index=0,
    ):
        """Instantiate the InteractivePlotting class.

        Args:
            autoencoder: The autoencoder whose `encode` method projects data
                and whose `generate` method builds structures along a path.
                Its `trajs` attribute is used when `trajs` is None.
            trajs (encodermap.TrajEnsemble, optional): The trajs of which the
                lowd info should be plotted. A str is handed to `SingleTraj`
                and a list to `TrajEnsemble` (both together with `top`).
            data (np.ndarray, optional): The data to plot. If its second axis
                does not have length 2 it is first passed to
                `autoencoder.encode`. If None, the data is taken from `trajs`
                or `autoencoder` (see `_resolve_lowd`).
            ax (matplotlib.axes, optional): On what axes to plot. If no axes
                is provided a new figure and axes will be created. Defaults
                to None.
            align_string (str): Selection string handed to the cluster
                building and cluster writing functions.
            top: Topology handed to `SingleTraj` / `TrajEnsemble` when `trajs`
                is a str or a list.
            hist (bool): Whether to additionally draw a log-normalized 2D
                histogram of the data.
            scatter_kws (dict): Keyword arguments for `ax.scatter`. They are
                merged over the defaults `{"s": 80, "alpha": 0.5}`. The
                default dict is never mutated, it is only read.
            ball_and_stick (bool): Whether structures are displayed in
                ball-and-stick representation.
            top_index (int): Passed as `top` to `autoencoder.generate` when
                structures are generated along a path.

        """
        # the align string for the cluster dummy method
        self.align_string = align_string
        self.top = top
        self.hist = hist
        self.autoencoder = autoencoder
        self.ball_and_stick = ball_and_stick
        self.top_index = top_index

        # user supplied scatter kws win over the defaults
        self.scatter_kws = {**{"s": 80, "alpha": 0.5}, **scatter_kws}

        # close all plots
        # NOTE(review): this also closes the figure of a user-provided `ax`
        # from pyplot's point of view - confirm that this is intended.
        plt.close("all")

        # decide on fate of data and see what traj has been provided
        self.data = self._resolve_lowd(autoencoder, trajs, data)
        self.trajs = self._resolve_trajs(autoencoder, trajs, top)

        # make the projection available as the CV 'lowd'
        if "lowd" not in self.trajs.CVs:
            if isinstance(trajs, SingleTraj):
                self.trajs.load_CV(self.data, attr_name="lowd")
            else:
                self.trajs.load_CVs(self.data, attr_name="lowd")

        # decide what function to use to build clusters
        # Decided against gen_dummy traj as get_cluster_frames works better with jinja2
        self.cluster_building_fn = get_cluster_frames

        # create fig and ax
        if ax is None:
            self.fig, self.ax = plt.subplots(1, 1, figsize=(10, 8))
        else:
            self.ax = ax
            self.fig = self.ax.get_figure()

        # add the axes to create the menus on. They are added to `self.fig`
        # explicitly, so they also end up in the right figure when the figure
        # of a user-provided `ax` is not pyplot's current figure.
        self.fig.subplots_adjust(left=0.3)
        self.menu_ax = self.fig.add_axes(
            [0.05, 0.1, 0.15, 0.35], facecolor="lightblue"
        )
        self.status_menu_ax = self.fig.add_axes(
            [0.05, 0.49, 0.15, 0.35], facecolor="lightyellow"
        )

        # remove everything in these axes
        self.menu_ax.axis("off")
        self.status_menu_ax.axis("off")
        self.tool = DummyTool()

        # plot
        self.pts = self.ax.scatter(self.data[:, 0], self.data[:, 1], **self.scatter_kws)

        # hist
        if self.hist:
            self.ax.hist2d(*self.data.T, bins=400, norm=mpl.colors.LogNorm())

        # Check whether mouse enters drawing area
        # Upon entering drawing area tools are initialized based on current mode
        # Leave event currently serves no purpose
        self.cid_ax_enter = self.fig.canvas.mpl_connect(
            "axes_enter_event", self.on_enter_ax
        )
        self.cid_ax_leave = self.fig.canvas.mpl_connect(
            "axes_leave_event", self.on_leave_ax
        )

        # check button presses and compare them with the status of the menuitems
        self.cid_on_click = self.fig.canvas.mpl_connect(
            "button_release_event", self.on_click
        )

        # Instantiate Menu
        self.statusmenu = StatusMenu(self.status_menu_ax)
        self.menu = Menu(self.menu_ax)

        # Show
        plt.show()

    @staticmethod
    def _resolve_lowd(autoencoder, trajs, data):
        """Return the data that will be plotted.

        If `data` is None, the first matching source is used: the `lowd`
        attribute of `trajs`, `autoencoder.encode(trajs)`,
        `autoencoder.encode()` and finally reproducible random data.
        Whatever the source, data whose second axis does not have length 2
        is passed through `autoencoder.encode`.

        Args:
            autoencoder: Object that may offer an `encode` method.
            trajs: The `trajs` argument of `__init__`.
            data: The `data` argument of `__init__`.

        Returns:
            The data to plot.

        """
        if data is None:
            if hasattr(trajs, "lowd"):
                print("Using the attribute `lowd` of provided `trajs`")
                data = trajs.lowd
            elif isinstance(trajs, (TrajEnsemble, SingleTraj)) and (
                isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
                or autoencoder.__class__.__name__ == "AngleDihedralCartesianEncoderMap"
            ):
                print(
                    "Using the provided `autoencoder` and `trajs` to create a projection."
                )
                data = autoencoder.encode(trajs)
            elif hasattr(autoencoder, "encode"):
                print("Using the `train_data` attribute of `autoencoder`.")
                data = autoencoder.encode()
            else:
                print("Mocking data with np.random")
                # fixed seed, so the mocked plot looks the same on every run
                np.random.seed(19680801)
                data = np.random.rand(100, 2)

        if data.shape[1] != 2:
            print("Using provided `data` to call encoder.")
            # The NaN handling used to sit in a branch that required `data` to
            # be None and an array at the same time, so it could never run.
            if (
                isinstance(data, np.ndarray)
                and np.issubdtype(data.dtype, np.floating)
                and np.any(np.isnan(data))
            ):
                # tensorflow is only imported when it is actually needed
                import tensorflow as tf

                # hand only the non-NaN entries to the encoder as a sparse tensor
                not_nan = ~np.isnan(data)
                indices = np.stack(np.where(not_nan)).T.astype("int64")
                values = data[not_nan].flatten().astype("float32")
                data = tf.sparse.SparseTensor(indices, values, data.shape)
            data = autoencoder.encode(data)
        return data

    @staticmethod
    def _resolve_trajs(autoencoder, trajs, top):
        """Return the trajectories belonging to the plotted data.

        Args:
            autoencoder: Its `trajs` attribute is returned if `trajs` is None.
            trajs: None, a str (handed to `SingleTraj`), a list (handed to
                `TrajEnsemble`) or an object that is returned unchanged.
            top: Topology handed to `SingleTraj` / `TrajEnsemble`.

        """
        if trajs is None:
            return autoencoder.trajs
        if isinstance(trajs, str):
            return SingleTraj(trajs, top, traj_num=0)._gen_ensemble()
        if isinstance(trajs, list):
            return TrajEnsemble(trajs, top)
        return trajs

    def on_click(self, event):
        """Decides whether the release event happened in the drawing area or the menu.

        Args:
            event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().

        """
        if event.inaxes == self.ax:
            self.on_click_tool(event)
        else:
            self.on_click_menu(event)

    def on_enter_ax(self, event):
        """Chooses the tool to use when self.ax is entered, based on current mode.

        In mode 'Idle' the point coloration is reset and the current tool is
        disconnected. In any other mode that is an entry of the status menu,
        the method named like the lower-cased mode is called (e.g. `lasso`
        for 'Lasso').

        Args:
            event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().

        """
        if event.inaxes is not self.ax:
            return
        mode = self.mode
        if mode == "Idle":
            # reset point coloration
            self.pts.set_color("C0")
            self.tool.disconnect()
            return
        if mode in self.statusmenu.menuitems:
            getattr(self, mode.lower())()

    def on_leave_ax(self, event):
        """Called when the mouse leaves an axes. Currently does nothing."""
        pass

    def on_click_tool(self, event):
        """Left here for convenience if some tools need a button release event."""
        pass

    def on_click_menu(self, event):
        """Chooses the function to call based on what MenuItem was clicked.

        The method name is the lower-cased menu key with spaces replaced by
        underscores (e.g. 'Set Points' calls `set_points`).

        Args:
            event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().

        """
        for key, item in self.menu.menuitems.items():
            if item.check_select(event):
                method = getattr(self, key.lower().replace(" ", "_"))
                method()

    def reset(self):
        """Called when 'Reset' is pressed.

        Re-runs `__init__` with the stored settings, which closes the current
        figure and opens a new one.
        """
        if "user_selected_points" in self.trajs.CVs:
            # NOTE(review): the return value is discarded. If `_CVs` is an
            # xarray Dataset this does not drop anything in place - confirm.
            self.trajs._CVs.drop(labels="user_selected_points")
        # Keywords instead of positionals: the positional call had
        # `autoencoder` and `trajs` swapped and did not forward `top_index`.
        self.__init__(
            autoencoder=self.autoencoder,
            trajs=self.trajs,
            data=self.data,
            ax=None,
            align_string=self.align_string,
            top=self.top,
            hist=self.hist,
            scatter_kws=self.scatter_kws,
            ball_and_stick=self.ball_and_stick,
            top_index=self.top_index,
        )

    def write(self):
        """Called when 'Write' is pressed.

        In the modes 'Bezier' and 'Path' the generated structures are written
        to `<main_path>/generated_paths/generated_<time>.pdb`. In all other
        modes (except 'Idle') the last selected cluster is written with
        `_unpack_cluster_info`. The axes title reports the outcome.
        """
        if self.mode == "Idle":
            return

        if self.mode == "Bezier" or self.mode == "Path":
            # `set_points` stores the generated universe as `dummy_traj`. An
            # externally assigned `uni` attribute is still preferred to stay
            # compatible with code that sets it.
            universe = getattr(self, "uni", None)
            if universe is None:
                universe = getattr(self, "dummy_traj", None)
            if universe is None:
                self.ax.set_title("First set the points before writing them to disk.")
                return
            time = _datetime_windows_and_linux_compatible()
            main_path = self.autoencoder.p.main_path
            os.makedirs(f"{main_path}/generated_paths/", exist_ok=True)
            fname = f"{main_path}/generated_paths/generated_{time}.pdb"
            with mda.Writer(fname) as w:
                # iterating the trajectory advances the universe frame by frame
                for step in universe.trajectory:
                    w.write(universe.atoms)
            self.ax.set_title(
                f"Generated Path with {len(universe.trajectory)} points saved at {fname}"
            )
            return

        if "user_selected_points" not in self.trajs.CVs:
            self.ax.set_title("First set the points before writing them to disk.")
            return
        max_, fname = _unpack_cluster_info(
            self.trajs,
            self.autoencoder.p.main_path,
            self.tool,
            self.dummy_traj,
            self.align_string,
        )
        self.ax.set_title(f"Cluster {max_} saved at {fname}")

    def set_points(self):
        """Called when 'Set Points' is pressed.

        Accepts the current selection. In the modes 'Bezier' and 'Path'
        structures are generated along the path, in all other modes the
        selected points become a new cluster. In both cases the result is
        available as `self.view` and `self.dummy_traj` afterwards.
        """
        if self.mode == "Idle":
            return
        if not hasattr(self, "tool"):
            self.ax.set_title(f"Tool {self.mode} not yet implemented.")
            return

        indices = self.accept()
        if self.mode == "Bezier" or self.mode == "Path":
            self._set_path_points()
        elif indices is not None:
            self._set_cluster_points(indices)

    def _set_path_points(self):
        """Generate and display structures along the accepted path."""
        # `accept` only stores `path_points` in mode 'Bezier', so they can be
        # missing in mode 'Path'.
        path_points = getattr(self, "path_points", None)
        if (
            path_points is None
            or np.unique(path_points, axis=0).shape[0] != self._N_PATH_POINTS
        ):
            self.ax.set_title(
                f"Tool {self.mode} returned not the requested unique points."
            )
            return
        self.dummy_traj = self.autoencoder.generate(
            path_points, backend="mdanalysis", top=self.top_index
        )
        self.view = ngl.show_mdanalysis(self.dummy_traj)
        if self.ball_and_stick:
            self.view.clear_representations()
            self.view.add_ball_and_stick()
        self.ax.set_title(
            f"Generated Path with {len(self.dummy_traj.trajectory)} points is accessible as InteractivePlotting.view."
        )

    def _set_cluster_points(self, indices):
        """Store the selected points as a new cluster, recolor and display it.

        Args:
            indices: Indices of the selected points, as returned by `accept`.

        """
        self.ax.set_title(
            "Currently working on rendering the cluster. I'll let you know, when I'm finished."
        )
        indices = np.asarray(indices)

        # update user defined clustering. -1 marks frames without a cluster.
        col = "user_selected_points"
        if col not in self.trajs.CVs:
            membership = np.full(self.trajs.n_frames, -1)
            try:
                membership[indices] = 0
            except IndexError as e:
                print(indices)
                raise SystemExit from e
        else:
            # every further selection gets the next free cluster number
            membership = self.trajs.CVs[col]
            membership[indices] = membership.max() + 1
        self.trajs.load_CVs(membership, col)

        # change coloration of self.pts. Unassigned points become nearly
        # transparent gray.
        color_palette = sns.color_palette("Paired", self.trajs.CVs[col].max() + 1)
        cluster_colors = [
            (*color_palette[x], 1) if x >= 0 else (0.5, 0.5, 0.5, 0.01)
            for x in self.trajs.CVs[col]
        ]
        self.pts.set_color(cluster_colors)

        max_ = np.max(self.trajs.CVs[col])
        self.view, self.dummy_traj = self.cluster_building_fn(
            self.trajs,
            max_,
            nglview=True,
            shorten=True,
            stack_atoms=True,
            col=col,
            align_string=self.align_string,
            ball_and_stick=self.ball_and_stick,
        )
        if self.ball_and_stick:
            # only the component index is needed here
            for i in range(len(self.dummy_traj)):
                self.view.clear_representations(component=i)
                self.view.add_ball_and_stick(component=i)
        self.ax.set_title(f"Cluster {max_} is accessible as InteractivePlotting.view.")

    def render_move(self):
        """Not implemented.

        NOTE(review): menu entries are mapped to method names in
        `on_click_menu`. If the entry is labelled 'Render Movie', this method
        would have to be called `render_movie` - confirm against `Menu`.
        """
        pass

    def lasso(self):
        """Activate the default selector of `SelectFromCollection`."""
        self.tool = SelectFromCollection(self.ax, self.pts)

    def rectangle(self):
        """Activate the rectangle selection tool."""
        self.tool = SelectFromCollection(self.ax, self.pts, selector=RectangleSelector)

    def ellipse(self):
        """Not implemented. Only prints a message."""
        print("Ellipse not yet implemented")

    def polygon(self):
        """Activate the polygon selection tool and display a short manual."""
        textstr = "\n".join(
            (
                "Select points in the figure by enclosing them within a polygon.",
                # Press the 'esc' key to start a new polygon.
                "Try holding the 'shift' key to move all of the vertices.",
                "Try holding the 'ctrl' key to move a single vertex.",
            )
        )

        # these are matplotlib.patch.Patch properties
        props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)

        # place a text box in upper left in axes coords
        self.manual_text = self.ax.text(
            0.05,
            0.95,
            textstr,
            transform=self.ax.transAxes,
            fontsize=6,
            verticalalignment="top",
            bbox=props,
        )
        self.tool = SelectFromCollection(self.ax, self.pts, selector=PolygonSelector)

    def path(self):
        """Not implemented."""
        pass

    def bezier(self):
        """Activate the bezier tool, which draws onto a new dashed line."""
        line = Line2D([], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87")
        self.ax.add_line(line)
        self.tool = BezierBuilder(line, self.ax)

    def accept(self):
        """Accept the selection of the current tool and disconnect the tool.

        Hides the manual of the polygon tool if it is displayed. In mode
        'Bezier' a copy of the selection is stored as `self.path_points`.

        Returns:
            The `ind` attribute of the current tool.

        """
        manual_text = self.__dict__.pop("manual_text", None)
        if manual_text is not None:
            manual_text.set_visible(False)
        if self.mode == "Bezier":
            self.path_points = copy.deepcopy(self.tool.ind)
        selected_indices = self.tool.ind
        self.tool.disconnect()
        return selected_indices

    @property
    def cluster_zoomed(self):
        """Result of `plot_cluster_zoomed` for the last selected cluster.

        Is None if no points have been selected yet.

        NOTE(review): `plot_cluster_zoomed` is not imported explicitly in this
        module and this property reads `trajs.df`, whereas the rest of the
        class reads `trajs.CVs` - confirm both.
        """
        col = "user_selected_points"
        if col not in self.trajs.df.keys():
            return
        max_ = np.max(self.trajs.df[col])
        return plot_cluster_zoomed(self.trajs, max_, col=col)

    @property
    def mode(self):
        """str: The current status of the status menu."""
        return self.statusmenu.status