Coverage for encodermap/plot/interactive_plotting.py: 13%

210 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/plot/interactive_plotting.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""ToDo: 

23 * Check if vmd rendering works 

24 * Animate the path and the torsions inside the Interactive Plotting class. 

25 * ToolTip when hovering over buttons. 

26 * Path Tool. 

27 * Allow a path tool that selects the points closest to the drawn path. Maybe support this for both Bezier and freehand paths. 

28 * Allow autoencoder to be None. Catch the exception that tries to get autoencoder.trajs. 

29 * Superpose option with mdtraj not working 

30 * Keep the lasso select. 

31 * Movie along paths. 

32 

33""" 

34 

35############################################################################## 

36# Imports 

37############################################################################## 

38 

39import copy 

40import os 

41from itertools import groupby 

42 

43from matplotlib.path import Path 

44 

45from .._optional_imports import _optional_import 

46from ..autoencoder.autoencoder import AngleDihedralCartesianEncoderMap 

47from ..misc.clustering import gen_dummy_traj, get_cluster_frames 

48from ..misc.misc import _datetime_windows_and_linux_compatible, all_equal 

49from ..plot.plotting import plot_cluster 

50from ..trajinfo.info_all import TrajEnsemble 

51from ..trajinfo.info_single import SingleTraj 

52from .utils import * 

53from .utils import _unpack_cluster_info, _unpack_path_info 

54 

55############################################################################## 

56# Optional Imports 

57############################################################################## 

58 

59 

60sns = _optional_import("seaborn") 

61 

62############################################################################## 

63# Classes 

64############################################################################## 

65 

66 

class InteractivePlotting:
    """Class to open up an interactive plotting window.

    Contains sub-classes to handle user-clickable menus and selectors.

    Attributes:
        autoencoder: The autoencoder passed into this class. Its ``encode``
            method is used to project data and its ``generate`` method is
            used to create structures along paths.
        trajs (encodermap.TrajEnsemble): The trajs passed into this class.
        data: The low-dimensional data whose first two columns are scattered
            onto ``ax``.
        fig (matplotlib.figure): The figure plotted onto. If ax is passed when
            this class is instantiated the parent figure will be fetched with
            self.fig = self.ax.get_figure()
        ax (matplotlib.axes): The axes where the lowd data of the trajs
            is plotted on.
        menu_ax (matplotlib.axes): The axes where the normal menu is plotted on.
        status_menu_ax (matplotlib.axes): The axes on which the status menu is
            plotted on.
        pts (matplotlib.collections.Collection): The points which are plotted.
            Based on some other class variables the color of this collection
            is adjusted.
        statusmenu (encodermap.plot.utils.StatusMenu): The menu containing the
            status buttons.
        menu (encodermap.plot.utils.Menu): The menu containing the remaining
            buttons.
        tool (encodermap.plot.utils.SelectFromCollection): The currently active
            tool used to select points. This can be lasso, polygon, etc...
        mode (str): Current mode of the statusmenu.
        dummy_traj: Structures of the most recently set cluster or path.
            Only present after 'Set Points' succeeded.
        view: The nglview widget showing ``dummy_traj``.

    Examples:
        >>> sess = ep.InteractivePlotting(autoencoder, trajs)

    """

    def __init__(
        self,
        autoencoder,
        trajs=None,
        data=None,
        ax=None,
        align_string="name CA",
        top=None,
        hist=False,
        scatter_kws=None,
        ball_and_stick=False,
        top_index=0,
    ):
        """Instantiate the InteractivePlotting class.

        Args:
            autoencoder: The autoencoder used to project data (``encode``) and
                to generate paths (``generate``). If ``trajs`` is None,
                ``autoencoder.trajs`` is used.
            trajs (Union[encodermap.TrajEnsemble, encodermap.SingleTraj, str, list], optional):
                The trajs of which the lowd info should be plotted. A str is
                loaded as a single trajectory and converted to an ensemble,
                a list is loaded as a ``TrajEnsemble``. Defaults to None.
            data (np.ndarray, optional): The data to plot. If it does not have
                exactly two columns it is projected with
                ``autoencoder.encode`` (NaNs are passed as a sparse tensor).
                If None, the data is taken from ``trajs.lowd``, computed by
                the autoencoder, or mocked with random numbers.
                Defaults to None.
            ax (matplotlib.axes, optional): On what axes to plot. If no axes is
                provided, all open figures are closed and a new figure and
                axes will be created. Defaults to None.
            align_string (str, optional): Selection string forwarded to the
                cluster-building and cluster-writing functions.
                Defaults to "name CA".
            top (str, optional): Topology used when ``trajs`` is a str or a
                list. Defaults to None.
            hist (bool, optional): Whether to additionally draw a 2D histogram
                of the data. Defaults to False.
            scatter_kws (dict, optional): Keyword arguments for
                ``ax.scatter``, merged over ``{"s": 80, "alpha": 0.5}``.
                None is equivalent to ``{"s": 5}``. Defaults to None.
            ball_and_stick (bool, optional): Whether generated structures are
                shown in ball-and-stick representation. Defaults to False.
            top_index (int, optional): Passed as ``top`` to
                ``autoencoder.generate`` when generating paths. Defaults to 0.

        """
        # None instead of a dict literal avoids a shared mutable default.
        if scatter_kws is None:
            scatter_kws = {"s": 5}

        # the align string for the cluster dummy method
        self.align_string = align_string
        self.top = top
        self.hist = hist
        self.autoencoder = autoencoder
        self.ball_and_stick = ball_and_stick
        self.top_index = top_index

        # user-provided scatter kws take precedence over these defaults
        self.scatter_kws = {**{"s": 80, "alpha": 0.5}, **scatter_kws}

        # decide on fate of data
        if data is None:
            if hasattr(trajs, "lowd"):
                print("Using the attribute `lowd` of provided `trajs`")
                data = trajs.lowd
            elif isinstance(trajs, (TrajEnsemble, SingleTraj)) and (
                isinstance(autoencoder, AngleDihedralCartesianEncoderMap)
                or autoencoder.__class__.__name__ == "AngleDihedralCartesianEncoderMap"
            ):
                print(
                    "Using the provided `autoencoder` and `trajs` to create a projection."
                )
                data = autoencoder.encode(trajs)
            elif hasattr(autoencoder, "encode"):
                print("Using the `train_data` attribute of `autoencoder`.")
                data = autoencoder.encode()
            else:
                print("Mocking data with np.random")
                np.random.seed(19680801)
                data = np.random.rand(100, 2)
        if data.shape[1] != 2:
            # High-dimensional input still has to be projected. The NaN ->
            # sparse conversion used to live in an unreachable branch under
            # ``if data is None``; it belongs here, where data is provided.
            print("Using provided `data` to call encoder.")
            data = autoencoder.encode(self._nan_to_sparse(data))
        self.data = data

        # see what traj has been provided
        if trajs is None:
            self.trajs = autoencoder.trajs
        elif isinstance(trajs, str):
            self.trajs = SingleTraj(trajs, self.top, traj_num=0)._gen_ensemble()
        elif isinstance(trajs, list):
            self.trajs = TrajEnsemble(trajs, self.top)
        else:
            self.trajs = trajs

        # Store the projection as the CV "lowd". The resolved ``self.trajs``
        # (not the raw argument) decides which loader is available.
        if "lowd" not in self.trajs.CVs:
            if isinstance(self.trajs, SingleTraj):
                self.trajs.load_CV(self.data, attr_name="lowd")
            else:
                self.trajs.load_CVs(self.data, attr_name="lowd")

        # decide what function to use to build clusters
        # Decided against gen_dummy traj as get_cluster_frames works better with jinja2
        self.cluster_building_fn = get_cluster_frames

        # create fig and ax
        if ax is None:
            # Close all plots only when this class owns the figure; closing
            # them with a user-supplied ``ax`` would close that axes' figure.
            plt.close("all")
            self.fig, self.ax = plt.subplots(1, 1, figsize=(10, 8))
        else:
            self.ax = ax
            self.fig = self.ax.get_figure()

        # Add the axes to create the menus on. ``self.fig.add_axes`` (instead
        # of ``plt.axes``, which targets the *current* figure) keeps the menus
        # on the figure that holds ``self.ax``.
        self.fig.subplots_adjust(left=0.3)
        self.menu_ax = self.fig.add_axes(
            [0.05, 0.1, 0.15, 0.35], facecolor="lightblue"
        )
        self.status_menu_ax = self.fig.add_axes(
            [0.05, 0.49, 0.15, 0.35], facecolor="lightyellow"
        )

        # remove everything in these axes
        self.menu_ax.axis("off")
        self.status_menu_ax.axis("off")
        self.tool = DummyTool()

        # plot
        self.pts = self.ax.scatter(self.data[:, 0], self.data[:, 1], **self.scatter_kws)

        # hist
        if self.hist:
            self.ax.hist2d(*self.data.T, bins=400, norm=mpl.colors.LogNorm())

        # Check whether mouse enters drawing area
        # Upon entering drawing area tools are initialized based on current mode
        # Leave event currently serves no purpose
        self.cid_ax_enter = self.fig.canvas.mpl_connect(
            "axes_enter_event", self.on_enter_ax
        )
        self.cid_ax_leave = self.fig.canvas.mpl_connect(
            "axes_leave_event", self.on_leave_ax
        )

        # check button presses and compare them with the status of the menuitems
        self.cid_on_click = self.fig.canvas.mpl_connect(
            "button_release_event", self.on_click
        )

        # Instantiate Menu
        self.statusmenu = StatusMenu(self.status_menu_ax)
        self.menu = Menu(self.menu_ax)

        # Show
        plt.show()

    @staticmethod
    def _nan_to_sparse(data):
        """Return ``data`` as a ``tf.sparse.SparseTensor`` if it contains NaNs.

        NaN entries are treated as missing values and left out of the sparse
        tensor. Any other input is returned unchanged.

        Args:
            data: The data that is about to be passed to ``encode``.

        Returns:
            Either ``data`` itself or a sparse tensor with ``data.shape``.

        """
        if not (
            isinstance(data, np.ndarray)
            and np.issubdtype(data.dtype, np.floating)
            and np.any(np.isnan(data))
        ):
            return data
        import tensorflow as tf

        present = ~np.isnan(data)
        indices = np.stack(np.where(present)).T.astype("int64")
        values = data[present].flatten().astype("float32")
        return tf.sparse.SparseTensor(indices, values, data.shape)

    def on_click(self, event):
        """Decides whether the release event happened in the drawing area or the menu.

        Args:
            event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().

        """
        if event.inaxes == self.ax:
            self.on_click_tool(event)
        else:
            self.on_click_menu(event)

    def on_enter_ax(self, event):
        """Chooses the tool to use when self.ax is entered, based on current mode.

        In 'Idle' mode the point coloration is reset and the current tool is
        disconnected. In any other mode listed in the status menu, the method
        named after the lower-cased mode (e.g. ``lasso``) is called.

        Args:
            event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().

        """
        if event.inaxes is not self.ax:
            return
        if self.mode == "Idle":
            # reset point coloration
            self.pts.set_color("C0")
            self.tool.disconnect()
        elif self.mode in self.statusmenu.menuitems:
            getattr(self, self.mode.lower())()

    def on_leave_ax(self, event):
        """Hook for the 'axes_leave_event'. Currently does nothing."""
        pass

    def on_click_tool(self, event):
        """Left here for convenience if some tools need a button release event."""
        pass

    def on_click_menu(self, event):
        """Chooses the function to call based on what MenuItem was clicked.

        The method name is the item's key, lower-cased, with spaces replaced
        by underscores (e.g. 'Set Points' -> ``set_points``).

        Args:
            event (matplotlib.backend_bases.Event): The event provided by figure.canvas.connect().

        """
        for key, item in self.menu.menuitems.items():
            if item.check_select(event):
                method = getattr(self, key.lower().replace(" ", "_"))
                method()

    def reset(self):
        """Called when 'Reset' is pressed.

        Removes the user-selected clusters and re-initializes the instance
        with its current settings, which creates a fresh figure.
        """
        col = "user_selected_points"
        if col in self.trajs.CVs:
            # ``drop`` returns a new object instead of modifying in place, so
            # the result has to be stored back.
            self.trajs._CVs = self.trajs._CVs.drop(labels=col)
        # Keyword arguments guard against the positional-order mix-up
        # (``autoencoder`` is the first parameter of ``__init__``).
        self.__init__(
            autoencoder=self.autoencoder,
            trajs=self.trajs,
            data=self.data,
            ax=None,
            align_string=self.align_string,
            top=self.top,
            hist=self.hist,
            scatter_kws=self.scatter_kws,
            ball_and_stick=self.ball_and_stick,
            top_index=self.top_index,
        )

    def write(self):
        """Called when 'Write' is pressed.

        In 'Bezier' and 'Path' mode the generated path (``self.dummy_traj``)
        is written frame by frame to
        ``<autoencoder.p.main_path>/generated_paths/generated_<time>.pdb``.
        In the other modes the latest user-selected cluster is saved with
        ``_unpack_cluster_info``. Nothing happens in 'Idle' mode.
        """
        if self.mode == "Idle":
            return
        time = _datetime_windows_and_linux_compatible()
        if self.mode in ("Bezier", "Path"):
            # ``dummy_traj`` only exists after 'Set Points' generated a path.
            universe = getattr(self, "dummy_traj", None)
            if universe is None:
                self.ax.set_title("First set the points before writing them to disk.")
                return
            path_dir = os.path.join(self.autoencoder.p.main_path, "generated_paths")
            os.makedirs(path_dir, exist_ok=True)
            fname = os.path.join(path_dir, f"generated_{time}.pdb")
            with mda.Writer(fname) as w:
                for _ in universe.trajectory:
                    w.write(universe.atoms)
            self.ax.set_title(
                f"Generated Path with {len(universe.trajectory)} points saved at {fname}"
            )
        else:
            if "user_selected_points" not in self.trajs.CVs:
                self.ax.set_title("First set the points before writing them to disk.")
                return
            max_, fname = _unpack_cluster_info(
                self.trajs,
                self.autoencoder.p.main_path,
                self.tool,
                self.dummy_traj,
                self.align_string,
            )
            self.ax.set_title(f"Cluster {max_} saved at {fname}")

    def set_points(self):
        """Called when 'Set Points' is pressed.

        Accepts the selection of the current tool. In 'Bezier'/'Path' mode
        structures are generated along the selected path; in the other modes
        the selected points become a new user-defined cluster.
        """
        if self.mode == "Idle":
            return
        # ``self.tool`` is always set (to a ``DummyTool`` in ``__init__``), so
        # checking for the attribute can never fail. The placeholder means
        # no implemented tool was set up for this mode.
        if isinstance(self.tool, DummyTool):
            self.ax.set_title(f"Tool {self.mode} not yet implemented.")
            return
        indices = self.accept()
        if self.mode in ("Bezier", "Path"):
            self._show_generated_path()
        elif indices is not None:
            self._show_user_cluster(indices)

    def _show_generated_path(self):
        """Generate structures along ``self.path_points`` and show them with nglview."""
        # ``path_points`` is only set by ``accept`` in 'Bezier' mode.
        path_points = getattr(self, "path_points", None)
        # The path tool is expected to provide 200 unique points.
        if path_points is None or np.unique(path_points, axis=0).shape[0] != 200:
            self.ax.set_title(
                f"Tool {self.mode} returned not the requested unique points."
            )
            return
        self.dummy_traj = self.autoencoder.generate(
            path_points, backend="mdanalysis", top=self.top_index
        )
        self.view = ngl.show_mdanalysis(self.dummy_traj)
        if self.ball_and_stick:
            self.view.clear_representations()
            self.view.add_ball_and_stick()

        self.ax.set_title(
            f"Generated Path with {len(self.dummy_traj.trajectory)} points is accessible as InteractivePlotting.view."
        )

    def _show_user_cluster(self, indices):
        """Store ``indices`` as a new user cluster, recolor the points and render it.

        Args:
            indices: Indices of the selected points (frames).

        Raises:
            IndexError: If the indices do not fit into the frames of ``trajs``.

        """
        self.ax.set_title(
            "Currently working on rendering the cluster. I'll let you know, when I'm finished."
        )
        indices = np.asarray(indices)

        # Update the user-defined clustering: -1 marks unselected frames, the
        # first selection gets id 0 and each further one the next free id.
        col = "user_selected_points"
        if col not in self.trajs.CVs:
            labels = np.full(self.trajs.n_frames, -1)
            try:
                labels[indices] = 0
            except IndexError as e:
                # Raised instead of SystemExit, which would tear down the
                # interpreter from inside a GUI callback.
                raise IndexError(
                    f"Selected indices do not fit into the {self.trajs.n_frames} "
                    f"frames of `trajs`."
                ) from e
        else:
            labels = self.trajs.CVs[col]
            labels[indices] = labels.max() + 1
        self.trajs.load_CVs(labels, col)

        # change coloration of self.pts; unselected points become faint grey
        cluster_ids = self.trajs.CVs[col]
        color_palette = sns.color_palette("Paired", cluster_ids.max() + 1)
        cluster_colors = [
            (*color_palette[x], 1) if x >= 0 else (0.5, 0.5, 0.5, 0.01)
            for x in cluster_ids
        ]
        self.pts.set_color(cluster_colors)

        max_ = np.max(cluster_ids)
        self.view, self.dummy_traj = self.cluster_building_fn(
            self.trajs,
            max_,
            nglview=True,
            shorten=True,
            stack_atoms=True,
            col=col,
            align_string=self.align_string,
            ball_and_stick=self.ball_and_stick,
        )
        if self.ball_and_stick:
            for i in range(len(self.dummy_traj)):
                self.view.clear_representations(component=i)
                self.view.add_ball_and_stick(component=i)
        self.ax.set_title(
            f"Cluster {max_} is accessible as InteractivePlotting.view."
        )

    def render_move(self):
        """Placeholder; does nothing yet."""
        pass

    def lasso(self):
        """Activate the lasso selection tool on the scattered points."""
        self.tool = SelectFromCollection(self.ax, self.pts)

    def rectangle(self):
        """Activate the rectangle selection tool on the scattered points."""
        self.tool = SelectFromCollection(self.ax, self.pts, selector=RectangleSelector)

    def ellipse(self):
        """Placeholder; the ellipse tool is not implemented."""
        print("Ellipse not yet implemented")

    def polygon(self):
        """Show usage hints and activate the polygon selection tool."""
        textstr = "\n".join(
            (
                "Select points in the figure by enclosing them within a polygon.",
                # Press the 'esc' key to start a new polygon.
                "Try holding the 'shift' key to move all of the vertices.",
                "Try holding the 'ctrl' key to move a single vertex.",
            )
        )

        # these are matplotlib.patch.Patch properties
        props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)

        # place a text box in upper left in axes coords; removed in ``accept``
        self.manual_text = self.ax.text(
            0.05,
            0.95,
            textstr,
            transform=self.ax.transAxes,
            fontsize=6,
            verticalalignment="top",
            bbox=props,
        )
        self.tool = SelectFromCollection(self.ax, self.pts, selector=PolygonSelector)

    def path(self):
        """Placeholder; the path tool is not implemented."""
        pass

    def bezier(self):
        """Draw a helper line and activate the Bezier path builder on it."""
        line = Line2D([], [], ls="--", c="#666666", marker="x", mew=2, mec="#204a87")
        self.ax.add_line(line)
        self.tool = BezierBuilder(line, self.ax)

    def accept(self):
        """Finish the current selection and return the tool's ``ind``.

        Removes the polygon help text if present, stores a deep copy of the
        tool's ``ind`` as ``self.path_points`` in 'Bezier' mode and
        disconnects the tool.

        Returns:
            The ``ind`` attribute of the current tool.

        """
        if "manual_text" in vars(self):
            self.manual_text.set_visible(False)
            del self.manual_text
        if self.mode == "Bezier":
            self.path_points = copy.deepcopy(self.tool.ind)
        selected_indices = self.tool.ind
        self.tool.disconnect()
        return selected_indices

    @property
    def cluster_zoomed(self):
        """Result of ``plot_cluster_zoomed`` for the latest user cluster, or None."""
        col = "user_selected_points"
        # NOTE(review): the other methods look this column up in
        # ``self.trajs.CVs``; confirm ``self.trajs.df`` exists and that
        # ``plot_cluster_zoomed`` is in scope (only ``plot_cluster`` is
        # imported explicitly).
        if col not in self.trajs.df.keys():
            return None
        max_ = np.max(self.trajs.df[col])
        return plot_cluster_zoomed(self.trajs, max_, col=col)

    @property
    def mode(self):
        """The current status of the status menu (e.g. 'Idle', 'Lasso', 'Bezier')."""
        return self.statusmenu.status