Coverage for encodermap/plot/dashboard.py: 16%

381 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

# encodermap/plot/dashboard.py

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""EncoderMap's dashboard. Explore and understand your MD data. 

23 

24""" 

25 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32# Future Imports at the top 

33from __future__ import annotations 

34 

35# Standard Library Imports 

36import builtins 

37import json 

38import os 

39import random 

40from io import StringIO 

41from pathlib import Path 

42from typing import TYPE_CHECKING, Any, Final, Literal, Optional, Union 

43 

44# Third Party Imports 

45import dash 

46import dash_auth 

47import dash_bio as dashbio 

48import dash_bootstrap_components as dbc 

49import dash_mantine_components as dmc 

50import flask 

51import plotly.graph_objs as go 

52from dash import Dash, Input, Output, State, callback, ctx, dcc, html 

53from dash.exceptions import PreventUpdate 

54from dash_bootstrap_templates import ThemeChangerAIO, template_from_url 

55 

56# Encodermap imports 

57from encodermap import load_project 

58from encodermap.kondata import get_assign_from_file 

59from encodermap.plot.interactive_plotting import InteractivePlotting 

60from encodermap.plot.plotting import _plot_ball_and_stick 

61 

62 

63################################################################################ 

64# Typing 

65################################################################################ 

66 

67 

68if TYPE_CHECKING: 

69 # Third Party Imports 

70 import mdtraj as md 

71 

72 

73################################################################################ 

74# Globals 

75################################################################################ 

76 

77 

78################################################################################ 

79# Helpers 

80################################################################################ 

81 

82 

83def _redefine_open() -> None: 

84 """Redefines the `open()` builtin to trick MDTraj to use `ReplaceOpenWithStringIO`. 

85 

86 MDTraj can't save to a StringIO object, because it checks the str provided 

87 in the filename argument for the extension and decides the format based on that 

88 (even using the `save_pdb()` function checks the extension). As StringIO objects 

89 can't be used like a string this function together with `_redefine_os_path_exists` 

90 and `ReplaceOpenWithStringIO` are used to trick MDtraj. 

91 

92 """ 

93 orig_func = builtins.open 

94 

95 def new_func(*args, **kwargs): 

96 if str(args[0].lower()) == "stringio.pdb": 

97 return args[0] 

98 else: 

99 return orig_func(*args, **kwargs) 

100 

101 builtins.open = new_func 

102 

103 

104def _redefine_os_path_exists() -> None: 

105 """Redefines the os.path.exists() function to trick MDTRaj to use `ReplaceOpenWithStringIO`""" 

106 orig_func = os.path.exists 

107 

108 def new_func(path): 

109 if path.lower() == "stringio.pdb": 

110 return True 

111 else: 

112 return orig_func(path) 

113 

114 os.path.exists = new_func 

115 

116 

class ReplaceOpenWithStringIO:
    """Tricks MDTraj to write the output into a StringIO object.

    Inside a context-manager, this class will redefine the builtin `open()`
    function and overwrite the `os.path.exists()` function, so that an
    instance of this class can be passed where a filename is expected. The
    instance acts as a sort-of str object (`lower()` and `str()` return
    ``"stringio.pdb"``) and as the file object returned by the patched
    `open()` (it offers `write()`, `read()` and `close()`).

    Example:
        with ReplaceOpenWithStringIO() as r:
            traj.save_pdb(r)
        pdb_text = r.read()

    """

    def __init__(self):
        self.stringio = StringIO()
        self._orig_open = builtins.open
        self._orig_path_exists = os.path.exists

    def __enter__(self):
        # Snapshot the functions as they are when the context is entered, so
        # that `__exit__` restores exactly what was replaced. This holds even
        # if they changed between construction and entering the context.
        self._orig_open = builtins.open
        self._orig_path_exists = os.path.exists
        _redefine_open()
        _redefine_os_path_exists()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        builtins.open = self._orig_open
        os.path.exists = self._orig_path_exists
        # Rewind so that `read()` returns everything written in the context.
        self.stringio.seek(0)

    def write(self, *args, **kwargs) -> None:
        """Write into the StringIO object."""
        self.stringio.write(*args, **kwargs)

    def read(self) -> str:
        """Read from the StringIO object."""
        return self.stringio.read()

    def lower(self) -> str:
        """Functions will think, this is a str's builtin `lower()` function."""
        return "stringio.pdb"

    def __str__(self):
        return "stringio.pdb"

    def close(self):
        """No-op: the buffer must stay readable after the writer closes it."""
        pass

158 

159 

def traj_to_pdb(
    traj: md.Trajectory,
) -> list[dict[str, Union[str, bool, dict[str, str]]]]:
    """Converts an MDTraj Trajectory into a dict, that can be understood by
    dashbio's NglMolViewer.

    The trajectory is written as PDB text into memory (see
    `ReplaceOpenWithStringIO`). CONECT records are then added for the bond
    between the ``C`` atom of each residue and the ``N`` atom of the next
    residue in the same chain. Pairs where either atom is missing (e.g.
    waters, ions or ligands) are skipped.

    Args:
        traj (md.Trajectory): The MDTraj trajectory.

    Returns:
        list[dict[str, Union[str, bool, dict[str, str]]]]: The json-type data for
            dashbio's NglMolViewer.

    """
    with ReplaceOpenWithStringIO() as r:
        traj.save_pdb(r)
    randname = f"{random.getrandbits(128):032x}"
    pdb_content = r.read()

    # Manually add the peptide bonds as CONECT records. They are inserted
    # before the last two lines of the PDB text. This presumably keeps them
    # ahead of the trailing end records written by MDTraj (TODO confirm).
    lines = pdb_content.splitlines()
    for chain in traj.top.chains:
        residues = list(chain.residues)
        for r1, r2 in zip(residues[:-1], residues[1:]):
            a1 = {a.name: a.index + 1 for a in r1.atoms}
            a2 = {a.name: a.index + 1 for a in r2.atoms}
            # Without a backbone C in r1 and N in r2 there is no peptide bond.
            if "C" not in a1 or "N" not in a2:
                continue
            lines.insert(-2, f"CONECT {a1['C']:>4} {a2['N']:>4}")
    pdb_content = "\n".join(lines)
    data_list = [
        {
            "filename": f"{randname}.pdb",
            "ext": "pdb",
            "selectedValue": f"{randname[:4].upper()}",
            "chain": "ALL",
            "aaRange": "ALL",
            "chosen": {"atoms": "", "residues": ""},
            "color": "sstruc",
            "config": {"type": "text/plain", "input": pdb_content},
            "resetView": True,
            "uploaded": False,
        }
    ]
    return data_list

207 

208 

209################################################################################ 

210# App 

211################################################################################ 

212 

213 

class DebugPage:
    """A development page that evaluates Python expressions typed by the user.

    The page is registered under the name ``"debug"``. It shows a warning, a
    textarea and a code element. Submitting the textarea with the Enter key
    evaluates its content with `eval` and shows `str()` of the result, or
    `str()` of the raised exception.

    NOTE(security): This evaluates arbitrary text coming from the browser
    with `eval`. It must never be reachable on a deployed server.

    """

    def __init__(self):
        self.display = html.Code(id="debug-display")
        dash.register_page("debug", layout=self.layout)

    @property
    def layout(self):
        """Page layout: heading, code input and result display."""
        page_children = [html.H1("Debug Page"), self.text_area, self.display]
        return html.Div(page_children)

    @property
    def text_area(self):
        """The warning banner together with the code input textarea."""
        warning = dbc.Alert(
            "Keep in mind, that this webpage is bad practice. It can "
            "execute arbitrary code. If you see this page on a deployed "
            "webpage, kill the server immediately. Otherwise use this "
            "page to access info about the running dash app. "
            "Enter code to evaluate. Accept with Enter-key...",
            color="danger",
        )
        code_input = dbc.Textarea(
            placeholder='print("Hello World!")',
            rows=10,
            id="debug-textarea",
        )
        return html.Div([warning, code_input])

    @staticmethod
    @callback(
        Output("debug-display", "children"),
        Input("debug-textarea", "n_submit"),
        State("debug-textarea", "value"),
        prevent_initial_call=True,
    )
    def run_code(value, state):
        # `value` is the textarea's submit counter; `state` is the typed text.
        # NOTE(security): eval of browser-supplied text -- debug use only.
        try:
            return str(eval(state))
        except Exception as err:
            return str(err)

264 

265 

class LocalUploadTraj:
    """The 'upload' page of the dashboard for locally running servers.

    The page offers three ways to provide MD data: a file upload area, a
    textarea for local file paths, and an input for an EncoderMap project
    name. It also has buttons to load the 'linear_dimers' and '1am7' example
    projects, and a reset button. What was loaded is recorded in the
    ``main-store`` data (keys ``"traj_type"`` and ``"traj"``). The actual
    loading is delegated to ``main.traj_page.parse_trajs``.

    Args:
        main: The dashboard instance. This class uses its ``app``, ``store``
            and ``traj_page`` attributes, and deletes its ``trajs`` attribute
            on reset.

    """

    _placeholder = """\
trajs = ["/path/to/traj1.xtc", "/path/to/traj2.xtc"]
tops = ["/path/to/traj1.pdb", "/path/to/traj2.pdb"]
common_str = ["traj1", "traj2"]
"""

    # Maps the ids of the example-project buttons to the project names that
    # are stored in ``main-store`` and passed on to ``parse_trajs``.
    _example_projects = {
        "linear-dimers-button": "linear_dimers",
        "1am7-button": "1am7",
    }

    def __init__(self, main):
        self.main = main

        # the display
        self.display = dbc.Container(
            [dbc.Card([], id="upload-display", class_name="align-items-center")],
            fluid=True,
        )

        # the upload area
        self.upload_card_body = dbc.CardBody(
            [
                html.H4("Upload files"),
                dcc.Upload(
                    id="upload-data-input",
                    children=html.Div(
                        ["Drag/Drop or ", html.A("Select Files")],
                        className="h-100",
                    ),
                    style={
                        "width": "100%",
                        "height": "60px",
                        "lineHeight": "60px",
                        "borderWidth": "1px",
                        "borderStyle": "dashed",
                        "borderRadius": "5px",
                        "textAlign": "center",
                        "margin": "10px",
                    },
                    # Allow multiple files to be uploaded
                    multiple=True,
                ),
            ]
        )
        self.upload_card = dbc.Col(
            [
                dbc.Card(
                    self.upload_card_body,
                    id="upload-data-input-card",
                    class_name="h-100",
                )
            ],
            width=3,
            class_name="col-sm-6 col-lg-3",
        )

        # the textarea
        self.text_area_card_body = dbc.CardBody(
            [
                html.H4("Local files"),
                dbc.Textarea(
                    placeholder=self._placeholder,
                    rows=3,
                    readonly=False,
                    id="upload-paths-input",
                    class_name="h-50",
                ),
            ],
        )
        self.text_area_card = dbc.Col(
            [
                dbc.Card(
                    self.text_area_card_body,
                    id="upload-paths-card",
                    class_name="h-100",
                ),
            ],
            width=6,
            class_name="col-sm-6 col-lg-6",
        )

        # the project area
        self.input_card_body = dbc.CardBody(
            [
                html.H4("Project"),
                dbc.Input(
                    placeholder="linear_dimers",
                    id="upload-project-input",
                    class_name="h-50",
                ),
            ],
        )
        self.input_card = dbc.Col(
            [
                dbc.Card(
                    self.input_card_body,
                    id="upload-project-card",
                    class_name="h-100",
                ),
            ],
            width=3,
            class_name="col-sm-6 col-lg-3",
        )

        # the complete container
        self.upload_container = dbc.Container(
            [
                html.Div(
                    [
                        self.upload_card,
                        self.text_area_card,
                        self.input_card,
                    ],
                    id="upload-hide",
                    className="row",
                    style={"display": "none", "height": "30%"},
                ),
                html.Div(style={"margin-top": "20px"}),
                dbc.Row(
                    [
                        dbc.Button(
                            "Upload",
                            id="upload-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Button(
                            "Linear Dimers Project",
                            id="linear-dimers-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Button(
                            "1am7 Project",
                            id="1am7-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Button(
                            "Reset",
                            id="upload-reset-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
            ],
            fluid=True,
            style={"height": "75vh"},
        )

        # define the layout
        self.layout = html.Div(
            [
                self.display,
                self.main.store,
                html.Br(),
                dmc.LoadingOverlay(
                    self.upload_container,
                ),
            ],
            style={"margin": "2%"},
        )

        # define local callbacks
        self.main.app.callback(
            Output("upload-display", "children"),
            Output("upload-hide", "style"),
            Output("main-store", "data"),
            Input("upload-button", "n_clicks"),
            Input("linear-dimers-button", "n_clicks"),
            Input("1am7-button", "n_clicks"),
            Input("upload-reset-button", "n_clicks"),
            State("main-store", "data"),
            State("upload-data-input", "contents"),
            State("upload-data-input", "filename"),
            State("upload-data-input", "last_modified"),
            State("upload-paths-input", "value"),
            State("upload-project-input", "value"),
        )(self.load_trajs)

        # register the page
        dash.register_page("upload", layout=self.layout)

    def _parse(self, store):
        """Load the data described by `store`; return an error card or None."""
        try:
            self.main.traj_page.parse_trajs(store)
        except NotImplementedError:
            return dbc.CardBody(
                f"Loading data of type '{store.get('traj_type')}' is not "
                f"implemented yet."
            )
        except Exception as e:
            # A callback is a UI boundary: report the problem instead of
            # failing silently in the browser.
            return dbc.CardBody(f"Couldn't load the data: {e}")
        return None

    def load_trajs(
        self,
        upload_n_clicks,  # Input: upload-button
        linear_dimers_n_clicks,  # Input: linear-dimers-button
        n_click_1am7,  # Input: 1am7-button
        reset_n_clicks,  # Input: upload-reset-button
        main_store,  # State
        list_of_contents,  # State
        list_of_names,  # State
        list_of_dates,  # State
        textarea_value,  # State
        project_value,  # State
    ):
        """Callback of the upload page.

        The click counters are not used. The action is chosen from
        ``ctx.triggered_id``.

        Returns:
            tuple: The children of ``upload-display``, the style of
                ``upload-hide`` (hidden once data is loaded) and the new
                ``main-store`` data. The store data is only changed when
                loading succeeded, or emptied on reset.

        """
        if main_store is None:
            main_store = {}
        triggered_id = ctx.triggered_id
        loaded_text = "Data loaded. Press 'Reset' to reset the session."
        show_upload = {"height": "30%"}
        hide_upload = {"display": "none"}

        # reset button pressed
        if triggered_id == "upload-reset-button":
            if hasattr(self.main, "trajs"):
                del self.main.trajs
            return (
                dbc.CardBody("Session was reset. Choose MD data to upload."),
                show_upload,
                {},
            )

        # one of the example-project buttons pressed
        if triggered_id in self._example_projects:
            new_store = {
                **main_store,
                "traj_type": "project",
                "traj": self._example_projects[triggered_id],
            }
            error = self._parse(new_store)
            if error is not None:
                return error, show_upload, main_store
            return dbc.CardBody(loaded_text), hide_upload, new_store

        # initial call (page load)
        if triggered_id is None:
            if isinstance(main_store, dict) and "traj_type" in main_store:
                return dbc.CardBody(loaded_text), hide_upload, main_store
            return (
                dbc.CardBody("Choose MD data to upload."),
                show_upload,
                main_store,
            )

        # upload button pressed; empty strings count as "not provided"
        if not (list_of_contents or textarea_value or project_value):
            return (
                dbc.CardBody(
                    "Place files in the upload window, or enter local files, "
                    "or a project name before pressing upload."
                ),
                show_upload,
                main_store,
            )

        if list_of_contents:
            # Parsing uploaded file contents is not implemented. Nothing is
            # written to the store, so other pages don't try to load it.
            return (
                dbc.CardBody("Uploading files currently not implemented.."),
                show_upload,
                main_store,
            )

        if textarea_value:
            new_store = {**main_store, "traj_type": "text", "traj": textarea_value}
            success_text = (
                "Loading trajectories. Go to the 'Traj' page to look at your data."
            )
        else:
            new_store = {**main_store, "traj_type": "project", "traj": project_value}
            success_text = (
                f"Loading project '{project_value}'. Go to the 'Traj' page to "
                f"look at your data."
            )
        error = self._parse(new_store)
        if error is not None:
            return error, show_upload, main_store
        return dbc.CardBody(success_text), hide_upload, new_store

564 

565 

class TopPage:
    """The 'top' page: inspect topologies and load custom amino acids.

    The page contains the following components:

    * A dropdown of all topologies in ``main.trajs.top``.
    * A ball-and-stick plot of the first trajectory that uses the selected
      topology.
    * A JSON editor whose content can be loaded as custom topology into the
      trajectories that share the selected topology.

    Args:
        main: The dashboard instance. This class uses its ``app``, ``store``,
            ``trajs``, ``traj_page`` and ``_figures`` attributes.

    """

    def __init__(self, main):
        self.main = main

        # the display
        self.display = dbc.Container(
            [dbc.Card([], id="top-display", class_name="align-items-center")],
            fluid=True,
        )

        # the topology dropdown selector
        self.topology_container = dbc.Container(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            dcc.Dropdown(
                                id="top-dynamic-dropdown",
                                style={"width": "100%"},
                                persistence=True,
                                persistence_type="session",
                            ),
                            width=3,
                            class_name="col-sm-6 col-lg-3",
                        ),
                        dbc.Col(
                            dbc.RadioItems(
                                options=[
                                    {"label": "Atoms", "value": 0},
                                    {"label": "Bonds", "value": 1},
                                    {"label": "Angles", "value": 2},
                                    {"label": "Dihedrals", "value": 3},
                                ],
                                inline=True,
                                persistence=True,
                                persistence_type="session",
                                value=0,
                                id="top-radioitems-input",
                            ),
                            width=5,
                            class_name="col-sm-6 col-lg-4",
                        ),
                        dbc.Col(
                            dbc.Row(
                                [
                                    html.P("Atom Subset"),
                                    dcc.RangeSlider(
                                        0,
                                        0,
                                        value=[0, 0],
                                        id="top-atom-indices-rangeslider",
                                        persistence=True,
                                        persistence_type="session",
                                    ),
                                ],
                            ),
                            width=4,
                            class_name="col-sm-6 col-lg-4",
                        ),
                    ],
                    style={"width": "100%"},
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        dmc.JsonInput(
                                            label="Custom Amino Acids:",
                                            value="{}",
                                            autosize=True,
                                            minRows=30,
                                            id="top-json-input",
                                        ),
                                    ],
                                ),
                                html.Br(),
                                dbc.Button(
                                    "Load",
                                    id="top-json-load",
                                    style={"width": "95%", "margin": "auto"},
                                ),
                            ],
                            id="top-custom-aas",
                            width=3,
                            class_name="col-sm-6 col-lg-3",
                        ),
                        dbc.Col(
                            [
                                html.Div(
                                    [],
                                    id="top-top-plot",
                                    style={"height": "800px"},
                                ),
                            ],
                            width=9,
                            # space-separated CSS classes (a comma is not a separator)
                            class_name="col-sm-6 col-lg-9",
                        ),
                    ],
                ),
            ],
            style={"display": "none", "height": "85%"},
            id="top-hide",
            fluid=True,
        )

        # the dummy div, that is used to run the context when the page is accessed
        self.dummy_div = html.Div([], id="top-page-dummy-div")

        # define the layout
        self.layout = html.Div(
            [
                self.display,
                self.main.store,
                html.Br(),
                dmc.LoadingOverlay(
                    self.topology_container,
                ),
                self.dummy_div,
            ],
            style={"margin": "2%"},
        )

        # decorate the interactiveness of the page
        self.main.app.callback(
            Output("top-display", "children"),
            Output("top-hide", "style"),
            Output("top-dynamic-dropdown", "options"),
            State("top-json-input", "value"),
            State("top-dynamic-dropdown", "value"),
            Input("top-json-load", "n_clicks"),
            Input("main-store", "data"),
            Input("top-page-dummy-div", "children"),
        )(self.display_top)

        # display the custom amino acids of the selected topology as json
        self.main.app.callback(
            Output("top-json-input", "value"),
            Output("top-atom-indices-rangeslider", "max"),
            Output("top-atom-indices-rangeslider", "value"),
            Input("top-dynamic-dropdown", "value"),
        )(self.display_custom_aas)

        # display the topology
        self.main.app.callback(
            Output("top-top-plot", "children"),
            State("top-dynamic-dropdown", "value"),
            Input("top-radioitems-input", "value"),
            Input("top-page-dummy-div", "children"),
            Input("top-atom-indices-rangeslider", "value"),
        )(self.display_plot)

        # register the page
        dash.register_page(
            "top",
            layout=self.layout,
        )

    def display_plot(
        self,
        top_value,
        radio_value,
        dummy,
        rangeslider_value,
    ) -> Any:
        """Plot the first trajectory with the selected topology.

        Args:
            top_value: Index into ``main.trajs.top``; None selects index 0.
            radio_value: 0/1/2/3 highlights atoms/bonds/angles/dihedrals.
                Any other value falls back to atoms.
            dummy: Unused; only triggers the callback when the page loads.
            rangeslider_value: ``[start, stop]`` of the atom subset. `stop` is
                exclusive. ``[0, 0]`` (or an empty value) means no subset.

        Returns:
            A `dcc.Graph` with the ball-and-stick figure.

        Raises:
            PreventUpdate: If no trajectories are loaded.

        """
        # `main.trajs` may be None (not loaded yet) or deleted (after reset).
        trajs = getattr(self.main, "trajs", None)
        if trajs is None:
            raise PreventUpdate
        if top_value is None:
            top_value = 0

        highlight: Literal["atoms", "bonds", "angles", "dihedrals"] = {
            1: "bonds",
            2: "angles",
            3: "dihedrals",
        }.get(radio_value, "atoms")

        if rangeslider_value and list(rangeslider_value) != [0, 0]:
            atom_indices = list(range(rangeslider_value[0], rangeslider_value[1]))
        else:
            atom_indices = None

        top = trajs.top[top_value]
        traj = trajs.trajs_by_top[top][0]

        try:
            fig = _plot_ball_and_stick(
                traj,
                highlight=highlight,
                atom_indices=atom_indices,
            )
        except Exception as e:
            raise Exception(f"{atom_indices=}") from e
        self.main._figures.append(fig)
        return dcc.Graph(
            figure=fig,
        )

    def display_custom_aas(self, top):
        """Show the custom amino acids of the selected topology as json.

        Args:
            top: Index into ``main.trajs.top`` or None.

        Returns:
            tuple: The json text (or an explanation when the trajectories of
                this topology disagree), the maximum of the atom range slider
                and the slider value spanning all atoms.

        Raises:
            PreventUpdate: If no trajectories are loaded.

        """
        trajs = getattr(self.main, "trajs", None)
        if trajs is None:
            raise PreventUpdate
        if top is None:
            return "{}", 0, [0, 0]
        topology = trajs.top[top]
        sub_trajs = trajs.trajs_by_top[topology]
        n_atoms = topology.n_atoms
        custom_aas = [t._custom_top for t in sub_trajs]
        if any(custom_aas[0] != c for c in custom_aas[1:]):
            msg = f"The trajectories contain different `_custom_aas`. I am not able to display them:"
            for t in sub_trajs:
                msg += f"\n{t.basename}:\n{t._custom_top.to_json()}"
            # The callback has three outputs; always return all of them.
            return msg, n_atoms, [0, n_atoms]
        return custom_aas[0].to_json(), n_atoms, [0, n_atoms]

    def get_options(self):
        """Dropdown options: one per topology in ``main.trajs.top``.

        The label is the topology's `str()` without the leading
        ``"<mdtraj.Topology with "`` and the trailing ``">"``.

        """
        prefix = "<mdtraj.Topology with "
        suffix = ">"
        options = []
        for i, top in enumerate(self.main.trajs.top):
            top_str = str(top)
            # Remove exact prefix/suffix; `lstrip`/`rstrip` would strip any
            # of the characters instead of the substring.
            if top_str.startswith(prefix):
                top_str = top_str[len(prefix) :]
            if top_str.endswith(suffix):
                top_str = top_str[: -len(suffix)]
            options.append({"label": top_str, "value": i})
        return options

    def fill_dropdown(self, main_store, search_value):
        """Return the dropdown options, or prevent the update if nothing is loaded."""
        if main_store is None:
            raise PreventUpdate
        if "traj_type" not in main_store:
            raise PreventUpdate
        if getattr(self.main, "trajs", None) is None:
            raise PreventUpdate
        return self.get_options()

    def display_top(self, json_values, top_value, n_clicks, main_store, dummy):
        """Main callback of the page: status text, visibility and dropdown options.

        When the 'Load' button was pressed, the json text is loaded as custom
        topology into the trajectories of the selected topology.

        Returns:
            tuple: Children of ``top-display``, style of ``top-hide`` and the
                options of ``top-dynamic-dropdown``.

        """
        blank_text = "View and modify topologies on this page after you upload them."
        empty = (dbc.CardBody(blank_text), {"display": "none"}, {})
        if main_store is None:
            return empty
        if "traj_type" not in main_store:
            return empty
        if getattr(self.main, "trajs", None) is None:
            try:
                self.main.traj_page.parse_trajs(main_store)
            except NotImplementedError:
                return (
                    dbc.CardBody(
                        f"Data of type '{main_store['traj_type']}' can't be "
                        f"displayed yet."
                    ),
                    {"display": "none"},
                    {},
                )

        def _shown(text):
            # The container is visible and the dropdown is filled in all
            # remaining cases.
            return dbc.CardBody(text), {"height": "85%"}, self.get_options()

        if ctx.triggered_id != "top-json-load":
            return _shown(f"Topologies are available: {self.main.trajs.top=}")

        if top_value is None:
            return _shown("Select a topology before loading custom amino-acids.")
        selected = self.main.trajs.trajs_by_top[self.main.trajs.top[top_value]]

        if json_values == "{}":
            return _shown(
                f"Provide custom amino-acids for the selected topology "
                f"{selected} "
                f"to load."
            )
        try:
            data = json.loads(json_values)
        except json.decoder.JSONDecodeError as e:
            return _shown(f"Couldn't parse your json. I got the error: {e}.")
        try:
            selected.load_custom_topology(data)
        except Exception as e:
            return _shown(f"Couldn't load the custom topology: {e}")
        return _shown("Custom topology loaded.")

858 

859 

class TrajPage:  # pragma: no doccheck
    """The 'traj' page listing the loaded trajectories.

    This page also owns `parse_trajs`, which the other pages use to load the
    data described by the ``main-store`` contents.

    Args:
        main: The dashboard instance. This class uses its ``app``, ``store``
            and ``trajs`` attributes and sets ``trajs`` and ``autoencoder``.

    """

    def __init__(self, main):
        self.main = main
        # Whether the dropdown-item callback has been registered already.
        self.decorated = False
        self.main_div = html.Div([], id="traj-page-div")
        self.dummy_div = html.Div([], id="traj-page-dummy-div")
        dash.register_page(
            "traj",
            layout=html.Div(
                [
                    self.main.store,
                    self.main_div,
                    self.dummy_div,
                ]
            ),
        )

        self.main.app.callback(
            Output("traj-page-div", "children"),
            Input("main-store", "data"),
            Input("traj-page-dummy-div", "children"),
        )(self.display_trajs)

    def display_trajs(
        self,
        main_store,
        dummy,
    ):
        """Fill the page: a hint when nothing is loaded, else the trajectory list."""
        if main_store is None:
            return html.P("Trajs will appear here after you upload them.")
        if "traj_type" not in main_store:
            return html.P("Trajs will appear here after you upload them.")

        # `main.trajs` may be None (not loaded yet) or deleted (after reset).
        if getattr(self.main, "trajs", None) is None:
            try:
                self.parse_trajs(main_store)
            except NotImplementedError as e:
                return html.P(str(e))

        return self.traj_loaded_layout(-1)

    def display_traj(self, *args, **kwargs):
        """Placeholder callback for the dropdown items; echoes its inputs."""
        triggered_id = ctx.triggered_id
        return html.Div(html.P(f"{args=} {kwargs=} {triggered_id=}"))

    def traj_loaded_layout(self, traj_num: int = -1):
        """Layout listing all trajectories of ``main.trajs`` in a dropdown.

        Args:
            traj_num: Only -1 (show all trajectories) is supported.

        Raises:
            NotImplementedError: For any other `traj_num`.

        """
        if traj_num != -1:
            raise NotImplementedError(
                f"Displaying a single trajectory ({traj_num=}) is not implemented."
            )
        trajs = list(self.main.trajs)
        layout = html.Div(
            [
                dbc.Row(
                    [
                        dbc.DropdownMenu(
                            label="Trajectory number",
                            id="traj-page-dropdown",
                            children=[
                                dbc.DropdownMenuItem(
                                    f"traj: {traj.traj_num} {traj.common_str}",
                                    id=f"traj-page-dropdown-item-{traj.traj_num}",
                                )
                                for traj in trajs
                            ],
                        )
                    ]
                ),
                html.P(f"Here will trajs be displayed {self.main.trajs}"),
                html.Div("This is where the custom React component should go."),
                # Target of the dropdown-item callback registered below.
                html.Div([], id="traj-page-display"),
            ]
        )

        # Add the dynamic callback to display_traj. It needs at least one
        # Input, so it is only registered when there are trajectories.
        if not self.decorated and trajs:
            self.main.app.callback(
                Output("traj-page-display", "children"),
                [
                    Input(f"traj-page-dropdown-item-{traj.traj_num}", "n_clicks")
                    for traj in trajs
                ],
                prevent_initial_call=True,
            )(self.display_traj)
            self.decorated = True

        return layout

    def parse_trajs(self, data):
        """Load the data described by a ``main-store`` dict into ``main``.

        Only ``{"traj_type": "project", "traj": <name>}`` is supported. It sets
        ``main.trajs`` and ``main.autoencoder`` from `load_project`.

        Raises:
            NotImplementedError: For any other ``traj_type``.

        """
        traj_type = data["traj_type"]
        if traj_type == "project":
            self.main.trajs, self.main.autoencoder = load_project(
                data["traj"], load_autoencoder=True
            )
        else:
            raise NotImplementedError(
                f"Loading trajectories of type '{traj_type}' is not implemented."
            )

947 

948 

class ProjectionPage(InteractivePlotting):
    """The 'projection' page: low-dimensional scatter plot and molecule viewer.

    The page registers a callback on the scatter plot. Clicking a point
    highlights it and shows the corresponding frame in a dashbio
    `NglMoleculeViewer`. Note that `InteractivePlotting.__init__` is not
    called; only its `_highd_parser` / `_lowd_parser` helpers are used.

    Args:
        main: The dashboard instance. This class uses its ``app``, ``store``,
            ``trajs``, ``autoencoder`` and ``traj_page`` attributes.

    """

    def __init__(self, main):
        self.main = main
        self.scatter = None

        # the display
        self.display = dbc.Container(
            [
                dbc.Card(
                    [html.P("Go to the 'Load' page to load data.")],
                    id="projection-page-display",
                    class_name="align-items-center",
                ),
            ],
            fluid=True,
        )

        # the projection plot, trace plot and molecule viewer
        self.figure_widget = go.FigureWidget()
        self.trace_widget = go.FigureWidget()
        self.projection_container = dbc.Container(
            [
                dbc.Row(),
                dbc.Row(
                    [
                        dbc.Col(
                            dcc.Graph(
                                figure=self.figure_widget, id="projection-page-plot"
                            ),
                            style={"height": "40vh"},
                            class_name="col-lg-6",
                        ),
                        dbc.Col(
                            dcc.Graph(
                                figure=self.trace_widget, id="projection-page-trace"
                            ),
                            style={"height": "40vh"},
                            class_name="col-lg-1",
                        ),
                        dbc.Col(
                            html.Div([], id="ngl-container", style={"height": "40vh"}),
                            style={"height": "40vh"},
                            class_name="col-lg-5",
                            id="projection-page-view",
                        ),
                    ],
                    style={"width": "100%", "height": "40%"},
                ),
            ],
            style={"height": "85%"},
            id="projection-page-hide",
            fluid=True,
        )

        # the dummy div, that is used to run the context when the page is accessed
        self.dummy_div = html.Div([], id="projection-page-dummy-div")

        # define the layout
        self.layout = html.Div(
            [
                self.display,
                self.main.store,
                html.Br(),
                dmc.LoadingOverlay(
                    self.projection_container,
                ),
                self.dummy_div,
            ],
            style={"margin": "2%"},
        )

        dash.register_page(
            "projection",
            layout=self.layout,
        )

        self.main.app.callback(
            Output("projection-page-display", "children"),
            Output("projection-page-plot", "figure"),
            Output("ngl-container", "children"),
            Input("projection-page-plot", "clickData"),
            Input("projection-page-plot", "selectedData"),
            Input("projection-page-plot", "relayoutData"),
            State("main-store", "data"),
        )(self.interact)

    @property
    def molStyles(self):
        """Styles passed to dashbio's `NglMoleculeViewer`."""
        molstyles_dict = {
            "representations": ["cartoon"],
            "chosenAtomsColor": "white",
            "chosenAtomsRadius": 1,
            "molSpacingXaxis": 100,
        }
        return molstyles_dict

    def interact(
        self, click_on_plot, select_in_plot, relayoutdata, main_store
    ) -> tuple[html.P, go.FigureWidget, Any]:
        """Interactive elements:

        * Click on Scatter Point
        * Buttons:
            * switch between cluster and scatter
            * cluster
            * generate
            * slider
            * dropdown
        * Progress

        Currently only clicks are handled. Selections and zoom/pan (relayout)
        events do not update the outputs. On the initial call the data is
        loaded (if needed) and the scatter trace is created once.

        Returns:
            tuple: Children of ``projection-page-display``, the figure of
                ``projection-page-plot`` and the children of ``ngl-container``.

        """
        result = [
            html.P("Go to the 'Load' page to load data."),
            self.figure_widget,
            [],
        ]
        triggered_id = ctx.triggered_id
        if triggered_id == "projection-page-plot":
            # All three inputs belong to the same component, so the component
            # id alone can't distinguish a click from a selection or a zoom.
            clicked = any(
                t["prop_id"] == "projection-page-plot.clickData" for t in ctx.triggered
            )
            if not clicked or click_on_plot is None:
                raise PreventUpdate
            index = click_on_plot["points"][0]["pointIndex"]
            frame = self.main.trajs.get_single_frame(index)
            n_points = len(self.lowd)
            marker = {
                "color": ["#1f77b4" if i != index else "#ff7f0e" for i in range(n_points)],
                "size": [7 if i != index else 20 for i in range(n_points)],
            }
            self.figure_widget.update_traces(marker=marker, selector={"name": "scatter"})
            viewer = dashbio.NglMoleculeViewer(
                data=traj_to_pdb(frame.traj),
                molStyles=self.molStyles,
            )
            result[2] = [viewer]
            return tuple(result)

        # `main.trajs` may be None (not loaded yet) or deleted (after reset).
        trajs = getattr(self.main, "trajs", None)
        if trajs is None and main_store and "traj_type" in main_store:
            self.main.traj_page.parse_trajs(main_store)
            trajs = getattr(self.main, "trajs", None)

        autoencoder = self.main.autoencoder
        if autoencoder is None:
            if main_store:
                if main_store.get("traj_type") == "project":
                    result[0] = html.P(
                        f"The project {main_store['traj']} has no autoencoder associated to it."
                    )
                else:
                    result[0] = html.P("The loaded trajs are not trained yet.")
        elif trajs is not None:
            self.highd = self._highd_parser(autoencoder, highd=None, trajs=trajs)
            self.lowd = self._lowd_parser(autoencoder, lowd=None, trajs=trajs)
            result[0] = html.P("Interact with the plot to view conformations.")

            # parse scatter (created only once)
            if self.scatter is None:
                self.lowd_dim = self.lowd.shape[1]
                n_points = len(self.lowd)
                self.scatter = go.Scattergl(
                    x=self.lowd[:, 0],
                    y=self.lowd[:, 1],
                    mode="markers",
                    name="scatter",
                    marker={
                        "size": [10] * n_points,
                        "color": ["#1f77b4"] * n_points,
                    },
                )
                self.figure_widget.add_trace(self.scatter)

        return tuple(result)

1133 

1134 

1135class Dashboard: 

1136 _encodermap_logo = "https://raw.githubusercontent.com/AG-Peter/encodermap/main/pic/logo_cube_300.png" 

1137 

1138 def __init__(self): 

1139 # create a dir for cache 

1140 if self.local: 

1141 self.cache_dir = Path("/tmp/encodermap_dash_cache") 

1142 self.cache_dir.mkdir(parents=True, exist_ok=True) 

1143 else: 

1144 raise NotImplementedError("No cachedir in non-local mode.") 

1145 

1146 # create the app and register the main page 

1147 self.server = flask.Flask(__name__) 

1148 self.store = dcc.Store(id="main-store", storage_type="session") 

1149 self.trajs = None 

1150 self.autoencoder = None 

1151 self.app = Dash( 

1152 server=self.server, 

1153 external_stylesheets=[ 

1154 dbc.themes.BOOTSTRAP, 

1155 dbc.icons.FONT_AWESOME, 

1156 ], 

1157 use_pages=True, 

1158 pages_folder="", 

1159 ) 

1160 vault_file = Path(__file__).resolve().parent.parent.parent / "act.vault" 

1161 if vault_file.is_file(): 

1162 username = get_assign_from_file(vault_file, "DASH_USER") 

1163 password = get_assign_from_file(vault_file, "DASH_PASSWORD") 

1164 auth = dash_auth.BasicAuth(self.app, {username: password}) 

1165 self.app.title = "EncoderMap Dashboard" 

1166 self.app._favicon = "favicon.ico" 

1167 

1168 # theme changer 

1169 self.theme_changer = ThemeChangerAIO( 

1170 aio_id="theme-change", 

1171 button_props={ 

1172 "color": "secondary", 

1173 "class_name": "me-1", 

1174 "outline": False, 

1175 "style": {"margin-top": "5px"}, 

1176 }, 

1177 radio_props={ 

1178 "persistence": True, 

1179 "persistence_type": "session", 

1180 }, 

1181 ) 

1182 

1183 # create other pages 

1184 if self.debug: 

1185 self.debug_page = DebugPage() 

1186 if self.local: 

1187 self.upload_traj_page = LocalUploadTraj(self) 

1188 else: 

1189 raise Exception("Write a non-local upload page") 

1190 self.traj_page = TrajPage(self) 

1191 self.top_page = TopPage(self) 

1192 self.projection_page = ProjectionPage(self) 

1193 

1194 # collect instance attributes 

1195 if self.local: 

1196 self._greeting = f"EncoderMap Dashboard for {os.getlogin()}" 

1197 else: 

1198 self._greeting = "EncoderMap Dashboard" 

1199 self._figures = [] 

1200 

1201 # init methods are divided into distinct methods to make them more legible 

1202 dash.register_page("home", path="/", layout=self.layout) 

1203 self.app.layout = self.app_layout 

1204 

1205 # decorate the class callbacks 

1206 callback( 

1207 Output("placeholder", "children"), 

1208 Input(ThemeChangerAIO.ids.radio("theme"), "value"), 

1209 prevent_initial_call=True, 

1210 )(self.update_theme) 

1211 

1212 @property 

1213 def app_layout(self): 

1214 return html.Div( 

1215 [ 

1216 self.navbar, 

1217 dash.page_container, 

1218 self.store, 

1219 self.placeholder, 

1220 ] 

1221 ) 

1222 

1223 @property 

1224 def layout(self): 

1225 layout = html.Div( 

1226 dbc.Container( 

1227 [ 

1228 html.H1("EncoderMap dashboard", className="display-3"), 

1229 html.P( 

1230 "Use the 'Upload' page to upload your MD data.", 

1231 className="lead", 

1232 ), 

1233 html.Hr(className="my-2"), 

1234 html.P( 

1235 [ 

1236 "Check out EncoderMap's GitHub page: ", 

1237 html.A( 

1238 "https://github.com/AG-Peter/encodermap", 

1239 href="https://github.com/AG-Peter/encodermap", 

1240 ), 

1241 ] 

1242 ), 

1243 html.P( 

1244 [ 

1245 "Read EncoderMap's documentation: ", 

1246 html.A( 

1247 "https://ag-peter.github.io/encodermap/", 

1248 href="https://ag-peter.github.io/encodermap/", 

1249 ), 

1250 ] 

1251 ), 

1252 html.P( 

1253 [ 

1254 "Give credit to the authors: ", 

1255 html.Ul( 

1256 [ 

1257 html.Li( 

1258 html.A( 

1259 "https://pubs.acs.org/doi/abs/10.1021/acs.jctc.8b00975", 

1260 href="https://pubs.acs.org/doi/abs/10.1021/acs.jctc.8b00975", 

1261 ), 

1262 ), 

1263 html.Li( 

1264 html.A( 

1265 "https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00675", 

1266 href="https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00675", 

1267 ), 

1268 ), 

1269 ] 

1270 ), 

1271 ] 

1272 ), 

1273 ], 

1274 fluid=True, 

1275 className="py-3", 

1276 ), 

1277 className="p-3 bg-light rounded-3", 

1278 ) 

1279 return layout 

1280 

1281 @property 

1282 def navbar(self): 

1283 navbar = dbc.Navbar( 

1284 dbc.Container( 

1285 [ 

1286 html.A( 

1287 # Use row and col to control vertical alignment of logo / brand 

1288 dbc.Row( 

1289 [ 

1290 dbc.Col( 

1291 html.Img(src=self._encodermap_logo, height="30px") 

1292 ), 

1293 dbc.Col( 

1294 dbc.NavbarBrand(self._greeting, className="ms-2") 

1295 ), 

1296 dbc.Col( 

1297 [ 

1298 dbc.Nav( 

1299 [ 

1300 dbc.NavItem( 

1301 dbc.NavLink( 

1302 page["name"], 

1303 href=page["relative_path"], 

1304 ) 

1305 ) 

1306 for page in dash.page_registry.values() 

1307 if page["name"] != "Home" 

1308 ], 

1309 navbar=True, 

1310 ), 

1311 ], 

1312 width={"size": "auto"}, 

1313 ), 

1314 ], 

1315 align="center", 

1316 className="g-0", 

1317 ), 

1318 href="/", 

1319 style={"textDecoration": "none"}, 

1320 ), 

1321 dbc.Row( 

1322 [ 

1323 dbc.Col( 

1324 [ 

1325 dbc.Nav( 

1326 [ 

1327 dbc.NavItem( 

1328 dbc.NavLink( 

1329 html.I( 

1330 className="fa-solid fa-heart", 

1331 style={"font-size": "1.5em"}, 

1332 ), 

1333 href="https://www.chemie.uni-konstanz.de/ag-peter/", 

1334 ) 

1335 ), 

1336 dbc.NavItem( 

1337 dbc.NavLink( 

1338 html.I( 

1339 className="fa-solid fa-book-open", 

1340 style={"font-size": "1.5em"}, 

1341 ), 

1342 href="https://ag-peter.github.io/encodermap/", 

1343 ) 

1344 ), 

1345 dbc.NavItem( 

1346 dbc.NavLink( 

1347 html.I( 

1348 className="fa-brands fa-square-github", 

1349 style={"font-size": "1.5em"}, 

1350 ), 

1351 href="https://github.com/AG-Peter/encodermap", 

1352 ) 

1353 ), 

1354 dbc.NavItem(self.theme_changer), 

1355 ] 

1356 ), 

1357 ] 

1358 ), 

1359 ], 

1360 align="center", 

1361 ), 

1362 ], 

1363 fluid=True, 

1364 ), 

1365 color="primary", 

1366 dark=True, 

1367 style={"justify-content": "left"}, 

1368 ) 

1369 return navbar 

1370 

1371 @property 

1372 def placeholder(self): 

1373 return html.P(id="placeholder") 

1374 

1375 def update_theme(self, theme): 

1376 for fig in self._figures: 

1377 template = template_from_url(theme) 

1378 fig.template = template 

1379 

1380 @property 

1381 def debug(self): 

1382 return os.getenv("ENCODERMAP_DASH_DEBUG", "False") == "True" 

1383 

1384 @property 

1385 def local(self): 

1386 return os.getenv("ENCODERMAP_DASH_RUN_LOCAL", "False") == "True" 

1387 

1388 def run(self, *args, **kwargs): 

1389 self.app.run(*args, host="0.0.0.0", debug=self.debug, **kwargs) 

1390 

1391 

1392################################################################################ 

1393# Main 

1394################################################################################ 

1395 

1396 

1397if __name__ == "__main__": 

1398 dashboard = Dashboard() 

1399 

1400 dashboard.run()