Coverage for encodermap/plot/dashboard.py: 16%
381 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/plot/dashboard.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""EncoderMap's dashboard. Explore and understand your MD data.
24"""
27################################################################################
28# Imports
29################################################################################
32# Future Imports at the top
33from __future__ import annotations
35# Standard Library Imports
36import builtins
37import json
38import os
39import random
40from io import StringIO
41from pathlib import Path
42from typing import TYPE_CHECKING, Any, Final, Literal, Optional, Union
44# Third Party Imports
45import dash
46import dash_auth
47import dash_bio as dashbio
48import dash_bootstrap_components as dbc
49import dash_mantine_components as dmc
50import flask
51import plotly.graph_objs as go
52from dash import Dash, Input, Output, State, callback, ctx, dcc, html
53from dash.exceptions import PreventUpdate
54from dash_bootstrap_templates import ThemeChangerAIO, template_from_url
56# Encodermap imports
57from encodermap import load_project
58from encodermap.kondata import get_assign_from_file
59from encodermap.plot.interactive_plotting import InteractivePlotting
60from encodermap.plot.plotting import _plot_ball_and_stick
63################################################################################
64# Typing
65################################################################################
68if TYPE_CHECKING:
69 # Third Party Imports
70 import mdtraj as md
73################################################################################
74# Globals
75################################################################################
78################################################################################
79# Helpers
80################################################################################
83def _redefine_open() -> None:
84 """Redefines the `open()` builtin to trick MDTraj to use `ReplaceOpenWithStringIO`.
86 MDTraj can't save to a StringIO object, because it checks the str provided
87 in the filename argument for the extension and decides the format based on that
88 (even using the `save_pdb()` function checks the extension). As StringIO objects
89 can't be used like a string this function together with `_redefine_os_path_exists`
90 and `ReplaceOpenWithStringIO` are used to trick MDtraj.
92 """
93 orig_func = builtins.open
95 def new_func(*args, **kwargs):
96 if str(args[0].lower()) == "stringio.pdb":
97 return args[0]
98 else:
99 return orig_func(*args, **kwargs)
101 builtins.open = new_func
104def _redefine_os_path_exists() -> None:
105 """Redefines the os.path.exists() function to trick MDTRaj to use `ReplaceOpenWithStringIO`"""
106 orig_func = os.path.exists
108 def new_func(path):
109 if path.lower() == "stringio.pdb":
110 return True
111 else:
112 return orig_func(path)
114 os.path.exists = new_func
class ReplaceOpenWithStringIO:
    """Tricks MDTraj to write the output into a StringIO object.

    Inside a context-manager, this class will redefine the builtin `open()`
    function and overwrite the `os.path.exists()` function to use it as a
    sort-of str object: its `str()` and `lower()` both return
    ``"stringio.pdb"``, which the patched functions recognize. Everything
    passed to `write()` ends up in an in-memory `StringIO`. On exit the patched
    functions are restored and the buffer is rewound, so `read()` returns what
    was written.

    """

    def __init__(self):
        self.stringio = StringIO()
        # Re-captured in `__enter__`; set here as well so that `__exit__` is
        # safe even when called without a preceding `__enter__`.
        self._orig_open = builtins.open
        self._orig_path_exists = os.path.exists

    def __enter__(self):
        # Capture the functions that are active *now*, so exiting restores
        # exactly what was in place when the context was entered (even if the
        # instance was created at another time).
        self._orig_open = builtins.open
        self._orig_path_exists = os.path.exists
        try:
            _redefine_open()
            _redefine_os_path_exists()
        except BaseException:
            # `__exit__` is not called when `__enter__` raises; undo any
            # partial patching before propagating the error.
            builtins.open = self._orig_open
            os.path.exists = self._orig_path_exists
            raise
        return self

    def __exit__(self, type, value, traceback):
        builtins.open = self._orig_open
        os.path.exists = self._orig_path_exists
        # Rewind, so that `read()` starts at the beginning of the written text.
        self.stringio.seek(0)

    def write(self, *args, **kwargs) -> None:
        """Write into the StringIO object."""
        self.stringio.write(*args, **kwargs)

    def read(self) -> str:
        """Read from the StringIO object."""
        return self.stringio.read()

    def lower(self) -> str:
        """Functions will think, this is a str's builtin `lower()` function."""
        return "stringio.pdb"

    def __str__(self):
        return "stringio.pdb"

    def close(self):
        """No-op, so the buffer stays readable after the writer 'closes' it."""
        pass
def traj_to_pdb(
    traj: md.Trajectory,
) -> list[dict[str, Union[str, bool, dict[str, str]]]]:
    """Converts an MDTraj Trajectory into a dict, that can be understood by
    dashbio's NglMolViewer.

    The trajectory is written as PDB text into memory (via
    `ReplaceOpenWithStringIO`). Then, for every chain, a ``CONECT`` record is
    added between the ``C`` atom of each residue and the ``N`` atom of the
    following residue, so the viewer draws these peptide bonds. Neighbouring
    residues where either atom is missing (e.g. water, ions or ligands) are
    skipped instead of raising a `KeyError`.

    Args:
        traj (md.Trajectory): The MDTraj trajectory.

    Returns:
        list[dict[str, Union[str, bool, dict[str, str]]]]: The json-type data for
            dashbio's NglMolViewer.

    """
    with ReplaceOpenWithStringIO() as r:
        traj.save_pdb(r)
    # non-cryptographic random name for the structure shown in the viewer
    randname = f"{random.getrandbits(128):032x}"
    # read after the context exited: `__exit__` rewinds the buffer
    pdb_content = r.read()

    # manually fix by adding peptide bonds
    lines = pdb_content.splitlines()
    for chain in traj.top.chains:
        residues = list(chain.residues)
        for r1, r2 in zip(residues[:-1], residues[1:]):
            a1 = {a.name: a.index + 1 for a in r1.atoms}
            a2 = {a.name: a.index + 1 for a in r2.atoms}
            if "C" not in a1 or "N" not in a2:
                # no peptide bond between these residues
                continue
            # Inserted before the last two lines of the PDB text.
            # NOTE(review): presumably MDTraj's trailing END-type records; verify.
            lines.insert(-2, f"CONECT {a1['C']:>4} {a2['N']:>4}")
    pdb_content = "\n".join(lines)

    data_list = [
        {
            "filename": f"{randname}.pdb",
            "ext": "pdb",
            "selectedValue": f"{randname[:4].upper()}",
            "chain": "ALL",
            "aaRange": "ALL",
            "chosen": {"atoms": "", "residues": ""},
            "color": "sstruc",
            "config": {"type": "text/plain", "input": pdb_content},
            "resetView": True,
            "uploaded": False,
        }
    ]
    return data_list
209################################################################################
210# App
211################################################################################
class DebugPage:
    """A page that evaluates Python expressions typed into a textarea.

    WARNING: `run_code` passes user input to `eval()`, i.e. it executes
    arbitrary code on the server. Its callback is therefore registered only
    when this page is instantiated, not at import time. Registering it at
    import time would make the endpoint reachable on every deployment, even
    when the page itself is not shown.

    """

    def __init__(self):
        self.display = html.Code(id="debug-display")
        dash.register_page("debug", layout=self.layout)
        callback(
            Output("debug-display", "children"),
            Input("debug-textarea", "n_submit"),
            State("debug-textarea", "value"),
            prevent_initial_call=True,
        )(self.run_code)

    @property
    def layout(self):
        """The page layout: heading, warning + textarea, and the output area."""
        layout = html.Div(
            [
                html.H1("Debug Page"),
                self.text_area,
                self.display,
            ]
        )
        return layout

    @property
    def text_area(self):
        """A danger alert followed by the code-input textarea."""
        text_input = html.Div(
            [
                dbc.Alert(
                    "Keep in mind, that this webpage is bad practice. It can "
                    "execute arbitrary code. If you see this page on a deployed "
                    "webpage, kill the server immediately. Otherwise use this "
                    "page to access info about the running dash app. "
                    "Enter code to evaluate. Accept with Enter-key...",
                    color="danger",
                ),
                dbc.Textarea(
                    placeholder='print("Hello World!")',
                    rows=10,
                    id="debug-textarea",
                ),
            ]
        )
        return text_input

    @staticmethod
    def run_code(value, state):
        """Evaluate the textarea's content and return the result as a string.

        Args:
            value: The textarea's ``n_submit`` counter (only acts as trigger).
            state: The expression entered into the textarea.

        Returns:
            str: ``str()`` of the evaluated result, or the exception's message.

        """
        # SECURITY: `eval` on user-supplied input executes arbitrary code.
        # Never expose this page on a deployed server.
        try:
            out = str(eval(state))
        except Exception as e:
            out = str(e)
        return out
class LocalUploadTraj:
    """The 'upload' page: choose MD data for the dashboard.

    The page shows a status card, three input cards (file upload, a textarea
    for local paths, a project name) and buttons to upload, to load the
    example projects 'linear_dimers' and '1am7', and to reset the session. The
    chosen data is described in the app-wide ``main-store`` as
    ``{"traj_type": ..., "traj": ...}`` and parsed by the main app's traj page.

    """

    _placeholder = """\
trajs = ["/path/to/traj1.xtc", "/path/to/traj2.xtc"]
tops = ["/path/to/traj1.pdb", "/path/to/traj2.pdb"]
common_str = ["traj1", "traj2"]
"""

    def __init__(self, main):
        self.main = main

        # the display
        self.display = dbc.Container(
            [dbc.Card([], id="upload-display", class_name="align-items-center")],
            fluid=True,
        )

        # the upload area
        self.upload_card_body = dbc.CardBody(
            [
                html.H4("Upload files"),
                dcc.Upload(
                    id="upload-data-input",
                    children=html.Div(
                        ["Drag/Drop or ", html.A("Select Files")],
                        className="h-100",
                    ),
                    style={
                        "width": "100%",
                        "height": "60px",
                        "lineHeight": "60px",
                        "borderWidth": "1px",
                        "borderStyle": "dashed",
                        "borderRadius": "5px",
                        "textAlign": "center",
                        "margin": "10px",
                    },
                    # Allow multiple files to be uploaded
                    multiple=True,
                ),
            ]
        )
        self.upload_card = dbc.Col(
            [
                dbc.Card(
                    self.upload_card_body,
                    id="upload-data-input-card",
                    class_name="h-100",
                )
            ],
            width=3,
            class_name="col-sm-6 col-lg-3",
        )

        # the textarea
        self.text_area_card_body = dbc.CardBody(
            [
                html.H4("Local files"),
                dbc.Textarea(
                    placeholder=self._placeholder,
                    rows=3,
                    readonly=False,
                    id="upload-paths-input",
                    class_name="h-50",
                ),
            ],
        )
        self.text_area_card = dbc.Col(
            [
                dbc.Card(
                    self.text_area_card_body,
                    id="upload-paths-card",
                    class_name="h-100",
                ),
            ],
            width=6,
            class_name="col-sm-6 col-lg-6",
        )

        # the project area
        self.input_card_body = dbc.CardBody(
            [
                html.H4("Project"),
                dbc.Input(
                    placeholder="linear_dimers",
                    id="upload-project-input",
                    class_name="h-50",
                ),
            ],
        )
        self.input_card = dbc.Col(
            [
                dbc.Card(
                    self.input_card_body,
                    id="upload-project-card",
                    class_name="h-100",
                ),
            ],
            width=3,
            class_name="col-sm-6 col-lg-3",
        )

        # the complete container
        self.upload_container = dbc.Container(
            [
                html.Div(
                    [
                        self.upload_card,
                        self.text_area_card,
                        self.input_card,
                    ],
                    id="upload-hide",
                    className="row",
                    style={"display": "none", "height": "30%"},
                ),
                html.Div(style={"margin-top": "20px"}),
                dbc.Row(
                    [
                        dbc.Button(
                            "Upload",
                            id="upload-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Button(
                            "Linear Dimers Project",
                            id="linear-dimers-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Button(
                            "1am7 Project",
                            id="1am7-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Button(
                            "Reset",
                            id="upload-reset-button",
                            style={"width": "95%", "margin": "auto"},
                        )
                    ]
                ),
            ],
            fluid=True,
            style={"height": "75vh"},
        )

        # define the layout
        self.layout = html.Div(
            [
                self.display,
                self.main.store,
                html.Br(),
                dmc.LoadingOverlay(
                    self.upload_container,
                ),
            ],
            style={"margin": "2%"},
        )

        # define local callbacks; the Inputs are listed in the same order as
        # the corresponding parameters of `load_trajs`
        self.main.app.callback(
            Output("upload-display", "children"),
            Output("upload-hide", "style"),
            Output("main-store", "data"),
            Input("upload-button", "n_clicks"),
            Input("upload-reset-button", "n_clicks"),
            Input("linear-dimers-button", "n_clicks"),
            Input("1am7-button", "n_clicks"),
            State("main-store", "data"),
            State("upload-data-input", "contents"),
            State("upload-data-input", "filename"),
            State("upload-data-input", "last_modified"),
            State("upload-paths-input", "value"),
            State("upload-project-input", "value"),
        )(self.load_trajs)

        # register the page
        dash.register_page("upload", layout=self.layout)

    def _clear_trajs(self):
        """Forget the trajectories loaded on the main app, if there are any."""
        if hasattr(self.main, "trajs"):
            del self.main.trajs

    def load_trajs(
        self,
        upload_n_clicks,  # Input: "upload-button"
        reset_n_clicks,  # Input: "upload-reset-button"
        linear_dimers_n_clicks,  # Input: "linear-dimers-button"
        n_click_1am7,  # Input: "1am7-button"
        main_store,  # State
        list_of_contents,  # State
        list_of_names,  # State
        list_of_dates,  # State
        textarea_value,  # State
        project_value,  # State
    ):
        """Handle the page's buttons and update the status card and store.

        Only `ctx.triggered_id` decides which button was pressed; the click
        counters themselves are not used.

        Returns:
            tuple: (status card body, style of the input cards, new store data).

        """
        if main_store is None:
            main_store = {}
        triggered_id = ctx.triggered_id
        loaded_text = "Data loaded. Press 'Reset' to reset the session."
        shown = {"height": "30%"}
        hidden = {"display": "none"}

        # reset button pressed
        if triggered_id == "upload-reset-button":
            self._clear_trajs()
            return (
                dbc.CardBody("Session was reset. Choose MD data to upload."),
                shown,
                {},
            )

        # one of the example projects was chosen
        example_projects = {
            "linear-dimers-button": "linear_dimers",
            "1am7-button": "1am7",
        }
        if triggered_id in example_projects:
            self._clear_trajs()
            main_store["traj_type"] = "project"
            main_store["traj"] = example_projects[triggered_id]
            self.main.traj_page.parse_trajs(main_store)
            return dbc.CardBody(loaded_text), hidden, main_store

        # initial call (e.g. the page is visited)
        if triggered_id is None:
            if isinstance(main_store, dict) and "traj_type" in main_store:
                return dbc.CardBody(loaded_text), hidden, main_store
            return dbc.CardBody("Choose MD data to upload."), shown, main_store

        # upload button pressed; empty/whitespace-only inputs count as missing
        has_contents = bool(list_of_contents)
        has_text = bool(textarea_value and textarea_value.strip())
        has_project = bool(project_value and project_value.strip())
        if triggered_id == "upload-button" and not (
            has_contents or has_text or has_project
        ):
            return (
                dbc.CardBody(
                    "Place files in the upload window, or enter local files, "
                    "or a project name before pressing upload."
                ),
                shown,
                main_store,
            )

        if has_contents:
            # Not implemented yet. Don't put the raw contents into the store,
            # otherwise other pages would try (and fail) to parse them.
            return (
                dbc.CardBody("Uploading files currently not implemented.."),
                shown,
                main_store,
            )

        if has_text:
            new_store = {**main_store, "traj_type": "text", "traj": textarea_value}
            try:
                self.main.traj_page.parse_trajs(new_store)
            except NotImplementedError:
                return (
                    dbc.CardBody(
                        "Loading trajectories from local paths is currently "
                        "not implemented."
                    ),
                    shown,
                    main_store,
                )
            return (
                dbc.CardBody(
                    "Loading trajectories. Go to the 'Traj' page to look at your data."
                ),
                hidden,
                new_store,
            )

        if has_project:
            new_store = {**main_store, "traj_type": "project", "traj": project_value}
            self.main.traj_page.parse_trajs(new_store)
            return (
                dbc.CardBody(
                    f"Loading project '{project_value}'. Go to the 'Traj' page to look at your data."
                ),
                hidden,
                new_store,
            )

        return dbc.CardBody(loaded_text), hidden, main_store
class TopPage:
    """The 'top' page: inspect topologies and load custom amino acids.

    Offers a dropdown of the loaded topologies, a selector for what to
    highlight (atoms, bonds, angles, dihedrals), an atom-range slider, a ball
    and stick plot, and a JSON editor for the selected topology's custom
    amino acids.

    """

    def __init__(self, main):
        self.main = main

        # the display
        self.display = dbc.Container(
            [dbc.Card([], id="top-display", class_name="align-items-center")],
            fluid=True,
        )

        # the topology dropdown selector
        self.topology_container = dbc.Container(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            dcc.Dropdown(
                                id="top-dynamic-dropdown",
                                style={"width": "100%"},
                                persistence=True,
                                persistence_type="session",
                            ),
                            width=3,
                            class_name="col-sm-6 col-lg-3",
                        ),
                        dbc.Col(
                            dbc.RadioItems(
                                options=[
                                    {"label": "Atoms", "value": 0},
                                    {"label": "Bonds", "value": 1},
                                    {"label": "Angles", "value": 2},
                                    {"label": "Dihedrals", "value": 3},
                                ],
                                inline=True,
                                persistence=True,
                                persistence_type="session",
                                value=0,
                                id="top-radioitems-input",
                            ),
                            width=5,
                            class_name="col-sm-6 col-lg-4",
                        ),
                        dbc.Col(
                            dbc.Row(
                                [
                                    html.P("Atom Subset"),
                                    dcc.RangeSlider(
                                        0,
                                        0,
                                        value=[0, 0],
                                        id="top-atom-indices-rangeslider",
                                        persistence=True,
                                        persistence_type="session",
                                    ),
                                ],
                            ),
                            width=4,
                            class_name="col-sm-6 col-lg-4",
                        ),
                    ],
                    style={"width": "100%"},
                ),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        dmc.JsonInput(
                                            label="Custom Amino Acids:",
                                            value="{}",
                                            autosize=True,
                                            minRows=30,
                                            id="top-json-input",
                                        ),
                                    ],
                                ),
                                html.Br(),
                                dbc.Button(
                                    "Load",
                                    id="top-json-load",
                                    style={"width": "95%", "margin": "auto"},
                                ),
                            ],
                            id="top-custom-aas",
                            width=3,
                            class_name="col-sm-6 col-lg-3",
                        ),
                        dbc.Col(
                            [
                                html.Div(
                                    [],
                                    id="top-top-plot",
                                    style={"height": "800px"},
                                ),
                            ],
                            width=9,
                            # CSS classes are space-separated (a comma would
                            # become part of the class name)
                            class_name="col-sm-6 col-lg-9",
                        ),
                    ],
                ),
            ],
            style={"display": "none", "height": "85%"},
            id="top-hide",
            fluid=True,
        )

        # the dummy div, that is used to run the context when the page is accessed
        self.dummy_div = html.Div([], id="top-page-dummy-div")

        # define the layout
        self.layout = html.Div(
            [
                self.display,
                self.main.store,
                html.Br(),
                dmc.LoadingOverlay(
                    self.topology_container,
                ),
                self.dummy_div,
            ],
            style={"margin": "2%"},
        )

        # decorate the interactiveness of the page
        self.main.app.callback(
            Output("top-display", "children"),
            Output("top-hide", "style"),
            Output("top-dynamic-dropdown", "options"),
            State("top-json-input", "value"),
            State("top-dynamic-dropdown", "value"),
            Input("top-json-load", "n_clicks"),
            Input("main-store", "data"),
            Input("top-page-dummy-div", "children"),
        )(self.display_top)

        # display the custom amino acids of the selected topology as json
        self.main.app.callback(
            Output("top-json-input", "value"),
            Output("top-atom-indices-rangeslider", "max"),
            Output("top-atom-indices-rangeslider", "value"),
            Input("top-dynamic-dropdown", "value"),
        )(self.display_custom_aas)

        # display the topology
        self.main.app.callback(
            Output("top-top-plot", "children"),
            State("top-dynamic-dropdown", "value"),
            Input("top-radioitems-input", "value"),
            Input("top-page-dummy-div", "children"),
            Input("top-atom-indices-rangeslider", "value"),
        )(self.display_plot)

        # register the page
        dash.register_page(
            "top",
            layout=self.layout,
        )

    def _trajs_loaded(self) -> bool:
        """Whether the main app currently holds trajectories.

        The main app starts with ``trajs = None`` and the upload page's reset
        deletes the attribute, so both cases mean "nothing loaded".
        """
        return getattr(self.main, "trajs", None) is not None

    def display_plot(
        self,
        top_value,
        radio_value,
        dummy,
        rangeslider_value,
    ) -> Any:
        """Plot the first trajectory of the selected topology as ball and stick.

        Args:
            top_value: Index of the selected topology (``None`` -> 0).
            radio_value: 0-3 for atoms, bonds, angles, dihedrals.
            dummy: Unused; triggers the callback when the page is accessed.
            rangeslider_value: ``[start, stop)`` atom indices; ``[0, 0]``
                means all atoms.

        """
        if top_value is None:
            top_value = 0

        if not self._trajs_loaded():
            raise PreventUpdate

        highlight = {1: "bonds", 2: "angles", 3: "dihedrals"}.get(
            radio_value, "atoms"
        )

        if rangeslider_value != [0, 0]:
            atom_indices = list(range(rangeslider_value[0], rangeslider_value[1]))
        else:
            atom_indices = None

        top = self.main.trajs.top[top_value]
        traj = self.main.trajs.trajs_by_top[top][0]

        try:
            fig = _plot_ball_and_stick(
                traj,
                highlight=highlight,
                atom_indices=atom_indices,
            )
        except Exception as e:
            raise Exception(
                f"Plotting topology {top_value} failed with {highlight=} and {atom_indices=}"
            ) from e
        # kept, so that the theme changer can update the figure's template
        self.main._figures.append(fig)
        return dcc.Graph(
            figure=fig,
        )

    def display_custom_aas(self, top):
        """Show the custom amino acids of topology `top` and size the slider.

        Returns:
            tuple: (JSON text or a message, slider max, slider value).

        """
        if not self._trajs_loaded():
            raise PreventUpdate
        if top is None:
            return "{}", 0, [0, 0]
        topology = self.main.trajs.top[top]
        trajs = self.main.trajs.trajs_by_top[topology]
        n_atoms = topology.n_atoms
        custom_aas = [t._custom_top for t in trajs]
        if any(custom_aas[0] != c for c in custom_aas[1:]):
            msg = "The trajectories contain different `_custom_aas`. I am not able to display them:"
            for t in trajs:
                msg += f"\n{t.basename}:\n{t._custom_top.to_json()}"
            # the callback has three outputs, so all three are returned here too
            return msg, n_atoms, [0, n_atoms]
        return custom_aas[0].to_json(), n_atoms, [0, n_atoms]

    def get_options(self):
        """Dropdown options: one ``{"label", "value"}`` dict per topology."""
        prefix = "<mdtraj.Topology with "
        options = []
        for i, top in enumerate(self.main.trajs.top):
            top_str = str(top)
            # remove the prefix as a whole (`str.lstrip` strips a character set)
            if top_str.startswith(prefix):
                top_str = top_str[len(prefix) :]
            top_str = top_str.rstrip(">")
            options.append({"label": top_str, "value": i})
        return options

    def fill_dropdown(self, main_store, search_value):
        """Return the dropdown options, or prevent the update if nothing is loaded."""
        if main_store is None:
            raise PreventUpdate
        if "traj_type" not in main_store:
            raise PreventUpdate
        if not self._trajs_loaded():
            raise PreventUpdate
        return self.get_options()

    def display_top(self, json_values, top_value, n_clicks, main_store, dummy):
        """Show the topology section and handle loading custom amino acids.

        Returns:
            tuple: (status card body, container style, dropdown options).

        """
        blank_text = "View and modify topologies on this page after you upload them."
        empty = (dbc.CardBody(blank_text), {"display": "none"}, {})
        if main_store is None:
            return empty
        if "traj_type" not in main_store:
            return empty
        if not self._trajs_loaded():
            try:
                self.main.traj_page.parse_trajs(main_store)
            except NotImplementedError:
                return empty

        shown = {"height": "85%"}
        triggered_id = ctx.triggered_id
        if triggered_id == "top-json-load":
            if top_value is None:
                return (
                    dbc.CardBody(
                        "Select a topology from the dropdown before loading "
                        "custom amino-acids."
                    ),
                    shown,
                    self.get_options(),
                )
            if json_values == "{}":
                return (
                    dbc.CardBody(
                        f"Provide custom amino-acids for the selected topology "
                        f"{self.main.trajs.trajs_by_top[self.main.trajs.top[top_value]]} "
                        f"to load."
                    ),
                    shown,
                    self.get_options(),
                )
            try:
                data = json.loads(json_values)
            except json.decoder.JSONDecodeError as e:
                return (
                    dbc.CardBody(f"Couldn't parse your json. I got the error: {e}."),
                    shown,
                    self.get_options(),
                )
            try:
                self.main.trajs.trajs_by_top[
                    self.main.trajs.top[top_value]
                ].load_custom_topology(data)
            except Exception as e:
                return (
                    dbc.CardBody(f"Couldn't load the custom topology: {e}"),
                    shown,
                    self.get_options(),
                )
            return (
                dbc.CardBody("Custom topology loaded."),
                shown,
                self.get_options(),
            )

        return (
            dbc.CardBody(f"Topologies are available: {self.main.trajs.top=}"),
            shown,
            self.get_options(),
        )
class TrajPage:  # pragma: no doccheck
    """The 'traj' page: lists the loaded trajectories and parses store data."""

    def __init__(self, main):
        # Third Party Imports
        from dash import ALL

        self.main = main
        # Kept for backwards compatibility. The dropdown callback is now
        # registered once, below, with a pattern-matching id.
        self.decorated = True
        self.main_div = html.Div([], id="traj-page-div")
        self.dummy_div = html.Div([], id="traj-page-dummy-div")
        # static target for the dropdown callback's output
        self.display_div = html.Div([], id="traj-page-display")
        dash.register_page(
            "traj",
            layout=html.Div(
                [
                    self.main.store,
                    self.main_div,
                    self.display_div,
                    self.dummy_div,
                ]
            ),
        )

        self.main.app.callback(
            Output("traj-page-div", "children"),
            Input("main-store", "data"),
            Input("traj-page-dummy-div", "children"),
        )(self.display_trajs)

        # One callback for all dropdown items, however many trajectories get
        # loaded later. Registering per-item ids after the app has started
        # would not reach already-loaded clients and misses trajectories
        # loaded afterwards.
        self.main.app.callback(
            Output("traj-page-display", "children"),
            Input({"type": "traj-page-dropdown-item", "index": ALL}, "n_clicks"),
            prevent_initial_call=True,
        )(self.display_traj)

    def display_trajs(
        self,
        main_store,
        dummy,
    ):
        """Render the trajectory list, parsing the store data if necessary."""
        if main_store is None or "traj_type" not in main_store:
            return html.P("Trajs will appear here after you upload them.")

        # `trajs` is None at startup and deleted by the upload page's reset
        if getattr(self.main, "trajs", None) is None:
            self.parse_trajs(main_store)

        return self.traj_loaded_layout(-1)

    def display_traj(self, *args, **kwargs):
        """Placeholder output for a clicked dropdown item."""
        triggered_id = ctx.triggered_id
        return html.Div(html.P(f"{args=} {kwargs=} {triggered_id=}"))

    def traj_loaded_layout(self, traj_num: int = -1):
        """Layout listing all trajectories (only ``traj_num == -1`` is supported).

        Raises:
            NotImplementedError: For any other `traj_num`.

        """
        if traj_num != -1:
            raise NotImplementedError(
                f"A layout for a single trajectory ({traj_num=}) is not implemented."
            )
        return html.Div(
            [
                dbc.Row(
                    [
                        dbc.DropdownMenu(
                            label="Trajectory number",
                            id="traj-page-dropdown",
                            children=[
                                dbc.DropdownMenuItem(
                                    f"traj: {traj.traj_num} {traj.common_str}",
                                    id={
                                        "type": "traj-page-dropdown-item",
                                        "index": traj.traj_num,
                                    },
                                )
                                for traj in self.main.trajs
                            ],
                        )
                    ]
                ),
                html.P(f"Here will trajs be displayed {self.main.trajs}"),
                html.Div("This is where the custom React component should go."),
            ]
        )

    def parse_trajs(self, data):
        """Load the data described by a store dict onto the main app.

        Only ``traj_type == "project"`` is supported. It sets both
        ``main.trajs`` and ``main.autoencoder`` via `load_project`.

        Raises:
            NotImplementedError: For any other ``traj_type``.

        """
        if data["traj_type"] == "project":
            self.main.trajs, self.main.autoencoder = load_project(
                data["traj"], load_autoencoder=True
            )
        else:
            raise NotImplementedError(
                f"Loading trajectories of type {data['traj_type']!r} is not implemented."
            )
class ProjectionPage(InteractivePlotting):
    """The 'projection' page: scatter plot of the low-dimensional projection.

    Clicking a point highlights it and shows the corresponding conformation
    in an NGL molecule viewer. Low- and high-dimensional data come from the
    `_lowd_parser` / `_highd_parser` methods inherited from
    `InteractivePlotting`, using the main app's autoencoder.

    """

    def __init__(self, main):
        self.main = main
        self.scatter = None

        # the display
        self.display = dbc.Container(
            [
                dbc.Card(
                    [html.P("Go to the 'Load' page to load data.")],
                    id="projection-page-display",
                    class_name="align-items-center",
                ),
            ],
            fluid=True,
        )

        # the plots and the molecule viewer
        self.figure_widget = go.FigureWidget()
        self.trace_widget = go.FigureWidget()
        self.projection_container = dbc.Container(
            [
                dbc.Row(),
                dbc.Row(
                    [
                        dbc.Col(
                            dcc.Graph(
                                figure=self.figure_widget, id="projection-page-plot"
                            ),
                            style={"height": "40vh"},
                            class_name="col-lg-6",
                        ),
                        dbc.Col(
                            dcc.Graph(
                                figure=self.trace_widget, id="projection-page-trace"
                            ),
                            style={"height": "40vh"},
                            class_name="col-lg-1",
                        ),
                        dbc.Col(
                            html.Div([], id="ngl-container", style={"height": "40vh"}),
                            style={"height": "40vh"},
                            class_name="col-lg-5",
                            id="projection-page-view",
                        ),
                    ],
                    style={"width": "100%", "height": "40%"},
                ),
            ],
            style={"height": "85%"},
            id="projection-page-hide",
            fluid=True,
        )

        # the dummy div, that is used to run the context when the page is accessed
        self.dummy_div = html.Div([], id="projection-page-dummy-div")

        # define the layout
        self.layout = html.Div(
            [
                self.display,
                self.main.store,
                html.Br(),
                dmc.LoadingOverlay(
                    self.projection_container,
                ),
                self.dummy_div,
            ],
            style={"margin": "2%"},
        )

        dash.register_page(
            "projection",
            layout=self.layout,
        )

        self.main.app.callback(
            Output("projection-page-display", "children"),
            Output("projection-page-plot", "figure"),
            Output("ngl-container", "children"),
            Input("projection-page-plot", "clickData"),
            Input("projection-page-plot", "selectedData"),
            Input("projection-page-plot", "relayoutData"),
            State("main-store", "data"),
        )(self.interact)

    @property
    def molStyles(self):
        """Styles passed to `dashbio.NglMoleculeViewer` (cartoon representation)."""
        molstyles_dict = {
            "representations": ["cartoon"],
            "chosenAtomsColor": "white",
            "chosenAtomsRadius": 1,
            "molSpacingXaxis": 100,
        }
        return molstyles_dict

    def interact(
        self, click_on_plot, select_in_plot, relayoutdata, main_store
    ) -> tuple[html.P, go.FigureWidget, Any]:
        """Interactive elements:

        * Click on Scatter Point
        * Buttons:
            * switch between cluster and scatter
            * cluster
            * generate
            * slider
            * dropdown
        * Progress

        Returns:
            tuple: (status paragraph, scatter figure, molecule viewer children).

        """
        empty = [
            html.P("Go to the 'Load' page to load data."),
            self.figure_widget,
            [],
        ]
        triggered_id = ctx.triggered_id
        if triggered_id == "projection-page-plot":
            if click_on_plot is None or not click_on_plot.get("points"):
                raise PreventUpdate
            index = click_on_plot["points"][0]["pointIndex"]
            frame = self.main.trajs.get_single_frame(index)
            # highlight the clicked point
            marker = {
                "color": [
                    "#1f77b4" if i != index else "#ff7f0e"
                    for i in range(len(self.lowd))
                ],
                "size": [7 if i != index else 20 for i in range(len(self.lowd))],
            }
            self.figure_widget.update_traces(
                marker=marker, selector=({"name": "scatter"})
            )
            viewer = dashbio.NglMoleculeViewer(
                data=traj_to_pdb(frame.traj),
                molStyles=self.molStyles,
            )
            empty[2] = [viewer]
            print(f"{click_on_plot=} {empty[2]=}")
            return tuple(empty)

        # `trajs` is None at startup and deleted by the upload page's reset
        if getattr(self.main, "trajs", None) is None and main_store:
            self.main.traj_page.parse_trajs(main_store)
        # read after parsing, which may set both attributes
        trajs = getattr(self.main, "trajs", None)
        autoencoder = getattr(self.main, "autoencoder", None)

        if autoencoder is None:
            if main_store:
                if main_store.get("traj_type") == "project":
                    empty[0] = html.P(
                        f"The project {main_store['traj']} has no autoencoder associated to it."
                    )
                else:
                    empty[0] = html.P("The loaded trajs are not trained yet.")
        else:
            self.highd = self._highd_parser(autoencoder, highd=None, trajs=trajs)
            self.lowd = self._lowd_parser(autoencoder, lowd=None, trajs=trajs)
            empty[0] = html.P("Interact with the plot to view conformations.")

        # parse scatter (only once)
        if self.scatter is None and autoencoder is not None:
            self.lowd_dim = self.lowd.shape[1]
            self.scatter = go.Scattergl(
                x=self.lowd[:, 0],
                y=self.lowd[:, 1],
                mode="markers",
                name="scatter",
                marker={
                    "size": [10 for _ in range(len(self.lowd))],
                    "color": ["#1f77b4" for _ in range(len(self.lowd))],
                },
            )
            self.figure_widget.add_trace(self.scatter)

        print(
            f"Interacting {triggered_id=} {click_on_plot=} {select_in_plot=} {relayoutdata=} {trajs=} {main_store=} {self.scatter=}"
        )

        return tuple(empty)
class Dashboard:
    """EncoderMap's Dash app: navbar, home page and the data/analysis pages.

    Behavior is controlled by environment variables:
    ``ENCODERMAP_DASH_RUN_LOCAL=True`` (required, non-local mode is not
    implemented) and ``ENCODERMAP_DASH_DEBUG=True`` (enables the debug page
    and Dash's debug mode).

    """

    _encodermap_logo = "https://raw.githubusercontent.com/AG-Peter/encodermap/main/pic/logo_cube_300.png"

    def __init__(self):
        # create a dir for cache
        if self.local:
            self.cache_dir = Path("/tmp/encodermap_dash_cache")
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        else:
            raise NotImplementedError("No cachedir in non-local mode.")

        # create the app and register the main page
        self.server = flask.Flask(__name__)
        self.store = dcc.Store(id="main-store", storage_type="session")
        self.trajs = None
        self.autoencoder = None
        self.app = Dash(
            server=self.server,
            external_stylesheets=[
                dbc.themes.BOOTSTRAP,
                dbc.icons.FONT_AWESOME,
            ],
            use_pages=True,
            pages_folder="",
        )
        vault_file = Path(__file__).resolve().parent.parent.parent / "act.vault"
        if vault_file.is_file():
            username = get_assign_from_file(vault_file, "DASH_USER")
            password = get_assign_from_file(vault_file, "DASH_PASSWORD")
            # BasicAuth hooks itself into the app; the instance isn't needed
            dash_auth.BasicAuth(self.app, {username: password})
        self.app.title = "EncoderMap Dashboard"
        self.app._favicon = "favicon.ico"

        # theme changer
        self.theme_changer = ThemeChangerAIO(
            aio_id="theme-change",
            button_props={
                "color": "secondary",
                "class_name": "me-1",
                "outline": False,
                "style": {"margin-top": "5px"},
            },
            radio_props={
                "persistence": True,
                "persistence_type": "session",
            },
        )

        # create other pages
        if self.debug:
            self.debug_page = DebugPage()
        if self.local:
            self.upload_traj_page = LocalUploadTraj(self)
        else:
            raise Exception("Write a non-local upload page")
        self.traj_page = TrajPage(self)
        self.top_page = TopPage(self)
        self.projection_page = ProjectionPage(self)

        # collect instance attributes
        if self.local:
            self._greeting = f"EncoderMap Dashboard for {self._get_username()}"
        else:
            self._greeting = "EncoderMap Dashboard"
        self._figures = []

        # init methods are divided into distinct methods to make them more legible
        dash.register_page("home", path="/", layout=self.layout)
        self.app.layout = self.app_layout

        # decorate the class callbacks
        callback(
            Output("placeholder", "children"),
            Input(ThemeChangerAIO.ids.radio("theme"), "value"),
            prevent_initial_call=True,
        )(self.update_theme)

    @staticmethod
    def _get_username() -> str:
        """Name of the user running the dashboard, for the navbar greeting.

        `os.getlogin()` raises `OSError` when the process has no controlling
        terminal (e.g. started by a service manager or inside a container).
        In that case `getpass.getuser()` is tried. If that fails as well,
        a generic name is returned.
        """
        # Standard Library Imports
        import getpass

        try:
            return os.getlogin()
        except OSError:
            pass
        try:
            return getpass.getuser()
        except Exception:  # the raised type differs between Python versions
            return "unknown user"

    @property
    def app_layout(self):
        """The app-wide layout: navbar, the current page, store, placeholder."""
        return html.Div(
            [
                self.navbar,
                dash.page_container,
                self.store,
                self.placeholder,
            ]
        )

    @property
    def layout(self):
        """The home page layout with links to GitHub, docs and publications."""
        layout = html.Div(
            dbc.Container(
                [
                    html.H1("EncoderMap dashboard", className="display-3"),
                    html.P(
                        "Use the 'Upload' page to upload your MD data.",
                        className="lead",
                    ),
                    html.Hr(className="my-2"),
                    html.P(
                        [
                            "Check out EncoderMap's GitHub page: ",
                            html.A(
                                "https://github.com/AG-Peter/encodermap",
                                href="https://github.com/AG-Peter/encodermap",
                            ),
                        ]
                    ),
                    html.P(
                        [
                            "Read EncoderMap's documentation: ",
                            html.A(
                                "https://ag-peter.github.io/encodermap/",
                                href="https://ag-peter.github.io/encodermap/",
                            ),
                        ]
                    ),
                    html.P(
                        [
                            "Give credit to the authors: ",
                            html.Ul(
                                [
                                    html.Li(
                                        html.A(
                                            "https://pubs.acs.org/doi/abs/10.1021/acs.jctc.8b00975",
                                            href="https://pubs.acs.org/doi/abs/10.1021/acs.jctc.8b00975",
                                        ),
                                    ),
                                    html.Li(
                                        html.A(
                                            "https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00675",
                                            href="https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00675",
                                        ),
                                    ),
                                ]
                            ),
                        ]
                    ),
                ],
                fluid=True,
                className="py-3",
            ),
            className="p-3 bg-light rounded-3",
        )
        return layout

    @property
    def navbar(self):
        """The navbar: logo, greeting, page links, external links, theme changer."""
        navbar = dbc.Navbar(
            dbc.Container(
                [
                    html.A(
                        # Use row and col to control vertical alignment of logo / brand
                        dbc.Row(
                            [
                                dbc.Col(
                                    html.Img(src=self._encodermap_logo, height="30px")
                                ),
                                dbc.Col(
                                    dbc.NavbarBrand(self._greeting, className="ms-2")
                                ),
                                dbc.Col(
                                    [
                                        dbc.Nav(
                                            [
                                                dbc.NavItem(
                                                    dbc.NavLink(
                                                        page["name"],
                                                        href=page["relative_path"],
                                                    )
                                                )
                                                for page in dash.page_registry.values()
                                                if page["name"] != "Home"
                                            ],
                                            navbar=True,
                                        ),
                                    ],
                                    width={"size": "auto"},
                                ),
                            ],
                            align="center",
                            className="g-0",
                        ),
                        href="/",
                        style={"textDecoration": "none"},
                    ),
                    dbc.Row(
                        [
                            dbc.Col(
                                [
                                    dbc.Nav(
                                        [
                                            dbc.NavItem(
                                                dbc.NavLink(
                                                    html.I(
                                                        className="fa-solid fa-heart",
                                                        style={"font-size": "1.5em"},
                                                    ),
                                                    href="https://www.chemie.uni-konstanz.de/ag-peter/",
                                                )
                                            ),
                                            dbc.NavItem(
                                                dbc.NavLink(
                                                    html.I(
                                                        className="fa-solid fa-book-open",
                                                        style={"font-size": "1.5em"},
                                                    ),
                                                    href="https://ag-peter.github.io/encodermap/",
                                                )
                                            ),
                                            dbc.NavItem(
                                                dbc.NavLink(
                                                    html.I(
                                                        className="fa-brands fa-square-github",
                                                        style={"font-size": "1.5em"},
                                                    ),
                                                    href="https://github.com/AG-Peter/encodermap",
                                                )
                                            ),
                                            dbc.NavItem(self.theme_changer),
                                        ]
                                    ),
                                ]
                            ),
                        ],
                        align="center",
                    ),
                ],
                fluid=True,
            ),
            color="primary",
            dark=True,
            style={"justify-content": "left"},
        )
        return navbar

    @property
    def placeholder(self):
        """An empty paragraph that serves as the theme callback's output."""
        return html.P(id="placeholder")

    def update_theme(self, theme):
        """Apply the template for the chosen theme to all collected figures."""
        if not self._figures:
            return None
        # the template is the same for all figures; compute it once
        template = template_from_url(theme)
        for fig in self._figures:
            fig.template = template

    @property
    def debug(self):
        """True if the env variable ``ENCODERMAP_DASH_DEBUG`` is ``"True"``."""
        return os.getenv("ENCODERMAP_DASH_DEBUG", "False") == "True"

    @property
    def local(self):
        """True if the env variable ``ENCODERMAP_DASH_RUN_LOCAL`` is ``"True"``."""
        return os.getenv("ENCODERMAP_DASH_RUN_LOCAL", "False") == "True"

    def run(self, *args, **kwargs):
        """Start the Dash server.

        Defaults to ``host="0.0.0.0"`` (all interfaces) and ``debug=self.debug``.
        Both can now be overridden via keyword arguments. Passing them
        previously raised a `TypeError`.
        """
        kwargs.setdefault("host", "0.0.0.0")
        kwargs.setdefault("debug", self.debug)
        self.app.run(*args, **kwargs)
1392################################################################################
1393# Main
1394################################################################################
if __name__ == "__main__":
    # Build the app (pages, callbacks, layout) and serve it right away.
    (dashboard := Dashboard()).run()