Coverage for encodermap/misc/misc.py: 10%
243 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/misc/misc.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Miscellaneous functions."""
24##############################################################################
25# Imports
26##############################################################################
29# Future Imports at the top
30from __future__ import annotations
32# Standard Library Imports
33import os
34from itertools import groupby
35from pathlib import Path
36from tempfile import NamedTemporaryFile
38# Third Party Imports
39import numpy as np
40import tensorflow as tf
41from optional_imports import _optional_import
44################################################################################
45# Optional Imports
46################################################################################
49nx = _optional_import("networkx")
50Image = _optional_import("PIL", "Image")
51ImageDraw = _optional_import("PIL", "ImageDraw")
52ImageFont = _optional_import("PIL", "ImageFont")
55################################################################################
56# Typing
57################################################################################
60# Standard Library Imports
61from collections.abc import Iterable, Sequence
62from typing import TYPE_CHECKING, Any, Optional, Union, overload
65if TYPE_CHECKING:
66 # Standard Library Imports
67 from pathlib import Path
69 # Local Folder Imports
70 from ..trajinfo.info_all import TrajEnsemble
73################################################################################
74# Globals
75################################################################################
# Names exported as the public API of this module.
__all__: list[str] = [
    "create_n_cube",
    "plot_model",
    "run_path",
    "get_full_common_str_and_ref",
]


# Maps the class names of EncoderMap's feature classes to the snake_case
# names under which their data is addressed (used e.g. by `_can_be_feature`
# to validate featurizer inputs by key or by value).
FEATURE_NAMES = {
    "AllCartesians": "all_cartesians",
    "AllBondDistances": "all_distances",
    "CentralCartesians": "central_cartesians",
    "CentralBondDistances": "central_distances",
    "CentralAngles": "central_angles",
    "CentralDihedrals": "central_dihedrals",
    "SideChainCartesians": "side_cartesians",
    "SideChainBondDistances": "side_distances",
    "SideChainAngles": "side_angles",
    "SideChainDihedrals": "side_dihedrals",
}


# File extensions that are treated as topology files.
# NOTE(review): presumably mirrors the topology formats loadable by MDTraj —
# confirm against the trajectory-loading code.
_TOPOLOGY_EXTS = [
    ".pdb",
    ".pdb.gz",
    ".h5",
    ".lh5",
    ".prmtop",
    ".parm7",
    ".prm7",
    ".psf",
    ".mol2",
    ".hoomdxml",
    ".gro",
    ".arc",
    ".hdf5",
    ".gsd",
]
118################################################################################
119# Utilities
120################################################################################
123def _is_notebook() -> bool:
124 """Returns 'True', if we are currently in a notebook.
126 Returns:
127 bool: True if we are in a notebook, False otherwise.
129 """
130 try:
131 # Third Party Imports
132 from IPython import get_ipython
134 if "IPKernelApp" not in get_ipython().config: # pragma: no cover
135 return False
136 except ImportError: 136 ↛ 137line 136 didn't jump to line 137, because the exception caught by line 136 didn't happen
137 return False
138 except AttributeError:
139 return False
140 return True
143def _can_be_feature(inp: Union[str, list[str]]) -> bool:
144 """Function to decide whether the input can be interpreted by the Featurizer class.
146 Outputs True, if inp == 'all' or inp is a list of strings contained in FEATURE_NAMES.
148 Args:
149 inp (Any): The input.
151 Returns:
152 bool: True, if inp can be interpreted by featurizer.
154 Example:
155 >>> from encodermap.misc.misc import _can_be_feature
156 >>> _can_be_feature('all')
157 True
158 >>> _can_be_feature('no')
159 False
160 >>> _can_be_feature(['AllCartesians', 'central_dihedrals'])
161 True
163 """
164 if isinstance(inp, str):
165 if inp == "all":
166 return True
167 if isinstance(inp, list):
168 if all([isinstance(i, str) for i in inp]):
169 if all(
170 [i in FEATURE_NAMES.keys() or i in FEATURE_NAMES.values() for i in inp]
171 ):
172 return True
173 return False
def match_files(
    trajs: Union[list[str], list[Path]],
    tops: Union[list[str], list[Path]],
    common_str: list[str],
) -> tuple[list[str], list[str]]:
    """Matches trajectory and topology files with a list of common strings.

    The idea behind this function is to use one topology file with
    multiple trajectory files, that all use the same topology.

    Args:
        trajs (Union[list[str], list[Path]]): A list of str or list of Path objects.
            These are the trajectory files we want to match the subset of
            topology files with.
        tops (Union[list[str], list[Path]]): A list of str or list of Path objects.
            These are the topology files we will assign to the trajectory files
            using common substrings.
        common_str (list[str]): A list of common substrings.

    Returns:
        tuple[list[str], list[str]]: A tuple containing two lists of str.
            The first list of str are the topology files matched to the `trajs`.
            The second list of str are the common_str matched to the `trajs`.
            Both lists have the same length as `trajs`.

    """
    tops_out = []
    common_str_out = []

    # normalize Path objects to plain strings for the substring logic below
    trajs = list(map(str, trajs))
    tops = list(map(str, tops))

    # .h5 trajectory files carry their own topology; when everything is
    # already matched one-to-one there is nothing left to do
    if all([Path(p).suffix == ".h5" for p in trajs]) and len(trajs) == len(tops) == len(
        common_str
    ):
        return tops, common_str

    if len(trajs) == len(tops):
        # find the first common_str that occurs in the first trajectory file
        try:
            first_cs = common_str[[cs in trajs[0] for cs in common_str].index(True)]
        except ValueError as e:
            raise Exception(
                f"Can't find any matching common_str in the input data:\n{trajs=}\n\n"
                f"{tops=}\n\n{common_str=}"
            ) from e
        # if trajs and tops are ordered consistently (the same common_str
        # matches both first elements), iterate them pairwise; otherwise
        # match each traj against the full list of tops below
        if first_cs not in tops[0]:
            iterator = trajs
        else:
            iterator = zip(trajs, tops)
    else:
        iterator = trajs

    for t in iterator:
        if isinstance(t, tuple):
            # pairwise case: traj and top at the same index belong together
            t, top = t
            if not any([cs in t for cs in common_str]):
                raise Exception(
                    f"The traj file {t} does not match any of the common_str you provided."
                )
            else:
                # only search for common_str in the tail of the path, i.e.
                # after the right-most occurrence of any common_str, so that
                # matches in parent directories don't interfere
                t_lcut = max([t.rfind(cs) for cs in common_str])
                t_lcut = t[t_lcut:]
                cs = common_str[[cs in t_lcut for cs in common_str].index(True)]
                if t.split(".")[-1] == "h5":
                    # .h5 trajectories are their own topology
                    tops_out.append(t)
                    assert cs in t, f"{cs=} {top=} {t=}"
                else:
                    tops_out.append(top)
                    assert cs in top and cs in t
                common_str_out.append(cs)
        else:
            if not any([cs in t for cs in common_str]):
                raise Exception(
                    f"The traj file {t} does not match any of the common_str you provided."
                )
            else:
                # same tail-of-path search as in the pairwise branch
                t_lcut = max([t.rfind(cs) for cs in common_str])
                t_lcut = t[t_lcut:]
                cs = common_str[[cs in t_lcut for cs in common_str].index(True)]
                if t.split(".")[-1] == "h5":
                    # pick the (h5) trajectory that carries this common_str
                    tops_out.append(trajs[[cs in r for r in trajs].index(True)])
                else:
                    # pick the first topology that carries this common_str
                    tops_out.append(tops[[cs in r for r in tops].index(True)])
                assert cs in t
                common_str_out.append(cs)
    return tops_out, common_str_out
def get_full_common_str_and_ref(
    trajs: Union[list[str], list[Path]],
    tops: Union[list[str], list[Path]],
    common_str: list[str],
) -> tuple[list[str], list[str], list[str]]:
    """Matches traj_files, top_files and common string and returns lists with the
    same length matching the provided common str.

    Args:
        trajs (list[str]): A list of str pointing to trajectory files.
        tops (list[str]): A list of str pointing to topology files.
        common_str (list[str]): A list of strings that can be found in
            both trajs and tops (i.e. substrings).

    Returns:
        tuple: A tuple containing the following:
            - list[str]: A list of str with the traj file names.
            - list[str]: A list of str with the top file names.
            - list[str]: A list of str with the common_str's.
            All lists have the same length.

    """
    assert isinstance(common_str, list)
    if len(trajs) != len(tops) and common_str == [] and len(tops) != 1:
        raise Exception(
            "When providing a list of trajs and a list of refs with different "
            "length you must provide a list of common_str to match them."
        )

    # if the length of all objects is the same we just return them
    # but also we check, whether the common_str appears in the
    # trajs and tops to make sure it's correctly ordered
    if len(trajs) == len(tops) == len(common_str):
        if not all([i is None for i in common_str]):
            for traj, top, cs in zip(trajs, tops, common_str):
                # a mismatch means the inputs are mis-ordered; re-match them
                if cs not in str(traj) or cs not in str(top):
                    return (trajs, *match_files(trajs, tops, common_str))
        return trajs, tops, common_str

    # if trajs and tops are the same length, they are expected to match
    elif len(trajs) == len(tops):
        if len(common_str) == 0:
            # no common_str provided: pair them by position
            return trajs, tops, [None for i in trajs]
        elif len(common_str) == 1:
            # a single common_str applies to all trajectories
            return trajs, tops, [common_str[0] for i in trajs]
        else:
            tops_out, common_str_out = match_files(trajs, tops, common_str)
            return trajs, tops_out, common_str_out

    # if only one topology is provided, we hope the user passed a correct one and fill everything else up
    elif len(trajs) > 1 and len(tops) == 1:
        tops_out = [tops[0] for t in trajs]
        if common_str == []:
            # fall back to the trajectory file stems as common strings
            common_str_out = [Path(traj).stem for traj in trajs]
        elif len(common_str) != len(trajs):
            tops_out, common_str_out = match_files(trajs, tops_out, common_str)
        elif len(common_str) == len(trajs):
            common_str_out = common_str
        return trajs, tops_out, common_str_out

    # in other cases, we need to do something similar
    else:
        if len(tops) > len(trajs):
            raise Exception(
                f"I was given more topologies {tops} than trajectories {trajs} . Something's not right."
            )
        if len(common_str) > len(trajs):
            raise Exception(
                "I was given more common strings than trajectories. Something's not right."
            )
        if common_str == []:
            common_str_out = [Path(traj).stem for traj in trajs]
        elif len(common_str) != len(trajs):
            tops_out, common_str_out = match_files(trajs, tops, common_str)
        elif len(common_str) == len(trajs):
            common_str_out = common_str
        # NOTE(review): match_files runs (again) on the final common_str_out,
        # overwriting any tops_out computed in the elif branch above
        tops_out, common_str_out = match_files(trajs, tops, common_str_out)
        return trajs, tops_out, common_str_out
def printTable(
    myDict: dict[str, dict[str, Any]],
    colList: Optional[list[str]] = None,
    sep: str = "\uFFFA",
) -> str:
    """Pretty print a list of dictionaries (myDict) as a dynamically sized table.

    If column names (colList) aren't specified, they will show in random order.
    sep: row separator. Ex: sep='\n' on Linux. Default: dummy to not split line.
    Author: Thierry Husson - Use it as you want but don't blame me.
    """
    if not colList:
        colList = list(myDict[0].keys() if myDict else [])
    # first row is the header, the rest are stringified records
    rows = [colList]
    for record in myDict:
        rows.append([str(record[col] or "") for col in colList])
    # column width = longest fragment in that column after sep-splitting
    widths = [
        max(len(part) for part in sep.join(column).split(sep))
        for column in zip(*rows)
    ]
    fmt = " | ".join("{{:<{}}}".format(w) for w in widths)
    rule = fmt.replace(" | ", "-+-").format(*("-" * w for w in widths))
    lines = []
    current = rows.pop(0)
    rule_emitted = False
    while rows:
        # advance to the next record once the current cells are exhausted
        if all(not cell for cell in current):
            current = rows.pop(0)
        # with the default dummy sep, the rule line is emitted only once
        if rule and (sep != "\uFFFA" or not rule_emitted):
            lines.append(rule)
            rule_emitted = True
        split_cells = [cell.split(sep, 1) for cell in current]
        lines.append(fmt.format(*(c[0] for c in split_cells)))
        # keep the remainder of multi-line cells for the next pass
        current = [c[1] if len(c) > 1 else "" for c in split_cells]
    return " \n".join(" " + ln for ln in lines)
389##############################################################################
390# Functions
391##############################################################################
394def _datetime_windows_and_linux_compatible() -> str:
395 """Portable way to get `now` as either a linux or windows compatible string.
397 For linux systems, strings in this manner will be returned:
398 2022-07-13T16:04:04+02:00
400 For Windows systems, strings in this manner will be returned:
401 2022-07-13_16-04-46
403 """
404 # Standard Library Imports
405 import datetime
406 from sys import platform
408 if platform == "linux" or platform == "linux2" or platform == "darwin":
409 return datetime.datetime.now().astimezone().replace(microsecond=0).isoformat()
410 elif platform == "win32":
411 return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
def all_equal(iterable: Iterable[Any]) -> bool:
    """Returns True, when all elements in a list are equal.

    Args:
        iterable (Iterable[Any]): An iterable of Any.

    Returns:
        bool: Whether all elements in `iterable` are equal. An empty
            iterable also yields True.

    """
    # groupby collapses consecutive equal elements into groups; zero or one
    # group means the iterable was empty or all elements compared equal.
    grouped = groupby(iterable)
    if next(grouped, None) is None:
        return True
    return next(grouped, None) is None
428def _validate_uri(str_: Union[Path, str]) -> bool:
429 """Checks whether the a str or Path is a valid uri.
431 Args:
432 str_ (Union[Path, str]): The str or Path to test.
434 Returns:
435 bool: True when `str_` is a valid URI, False otherwise.
437 """
438 # Standard Library Imports
439 from urllib.parse import urlparse
441 try:
442 result = urlparse(str_)
443 return all([result.scheme, result.path])
444 except:
445 return False
# The dumped source had coverage-report artifacts fused into these stub
# definitions; restored to clean @overload declarations.
@overload
def _flatten_model(
    model_nested: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
    return_model: bool = True,
) -> tf.keras.Model: ...


@overload
def _flatten_model(
    model_nested: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
    return_model: bool = False,
) -> list[tf.keras.layers.Layer]: ...
def _flatten_model(
    model_nested: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
    return_model: bool = True,
) -> Union[tf.keras.Model, list[tf.keras.layers.Layer]]:
    """Flattens a nested tensorflow.keras.models.Model.

    Can be useful if a model consists of two sequential models and needs to
    be flattened to be plotted.

    Returns:
        Union[tf.keras.Model, list[tf.keras.layers.Layer]]:
            Either a tf.keras.Model or a list of layers.

    """
    collected = []
    for sub_layer in model_nested.layers:
        # Sub-models expose their own `.layers`; plain layers do not.
        if hasattr(sub_layer, "layers"):
            collected.extend(sub_layer.layers)
        else:
            collected.append(sub_layer)
    if not return_model:
        return collected
    flat = tf.keras.models.Sequential(collected)
    flat.build(input_shape=(1, input_dim))
    return flat
def plot_model(
    model: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
) -> Optional[Image]:
    """Plots keras model using tf.keras.utils.plot_model.

    Args:
        model (tf.keras.Model): The model to plot.
        input_dim (Optional[Sequence[int]]): When `model` is an EncoderMap
            `SequentialModel` and this is provided, the model is flattened
            with this input dimension and plotted plainly. Defaults to None.

    Returns:
        Optional[Image]: A PIL image of the plotted model, or None when
            `tf.keras.utils.plot_model` raises (e.g. pydot/graphviz missing).

    """
    # Encodermap imports
    from encodermap.models.models import SequentialModel

    if isinstance(model, SequentialModel) and input_dim is not None:
        # flattened sequential models get a plain plot without decorations
        model = _flatten_model(model, input_dim)
        with NamedTemporaryFile(suffix=".png") as f:
            try:
                _ = tf.keras.utils.plot_model(
                    model,
                    to_file=f.name,
                    show_shapes=True,
                    rankdir="TB",
                    expand_nested=True,
                )
            # NOTE(review): bare except silently turns any plotting failure
            # into a None return — confirm this best-effort is intended
            except:
                return
            return Image.open(f.name)
    with NamedTemporaryFile(suffix=".png") as f:
        try:
            _ = tf.keras.utils.plot_model(
                model,
                to_file=f.name,
                show_shapes=True,
                rankdir="TB",
                expand_nested=True,
            )
        except:
            return
        img = Image.open(f.name)
        # img.putalpha(alpha=255)
        # decorate the plot with the EncoderMap logo cube
        cube_img = Image.open(Path(__file__).resolve().parent / "logo_cube_300.png")
        img.paste(cube_img, (800, 500), cube_img)
        try:
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/dejavu/DejaVuMathTeXGyre.ttf", size=40
            )
            ImageDraw.Draw(img).text(  # Image
                (700, 800),  # Coordinates
                "Made with EncoderMap",  # Text
                (0, 0, 0),  # Color
                font=font,
            )
        except:
            # best effort: the DejaVu font may not exist on this system
            pass
        return img

    # plt.close("all")
    # plt.imshow(img)
    # plt.show()
    # if mpl.get_backend() == "module://ipykernel.pylab.backend_inline":
    #     fig = plt.gcf()
    #     fig.set_size_inches(fig.get_size_inches() * 4)
    #     ax = plt.gca()
    #     ax.axis("off")
    #     plt.show()
def run_path(path: str) -> str:
    """Creates a directory at "path/run{i}" where the i is corresponding to
    the smallest not yet existing path.

    Scans "run0", "run1", ... under `path`, creates the first one that does
    not exist yet, and returns it.

    Args:
        path (str): Path to the run folder.

    Returns:
        str: The new output path.

    """
    counter = 0
    while True:
        candidate = os.path.join(path, f"run{counter}")
        if os.path.exists(candidate):
            counter += 1
            continue
        os.makedirs(candidate)
        return candidate
def create_n_cube(
    n: int = 3,
    points_along_edge: int = 500,
    sigma: float = 0.05,
    same_colored_edges: int = 3,
    seed: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Creates points along the edges of an n-dimensional unit hyper-cube.

    The cube is created using `networkx.hypercube_graph` and points are placed along
    the edges of the cube. By providing a `sigma` value, the points can be shifted
    by some Gaussian noise.

    Args:
        n (int): The dimension of the Hypercube (can also take 1 or 2).
            Defaults to 3.
        points_along_edge (int): How many points should be placed along any edge.
            By increasing the number of dimensions, the number of edges
            increases, which also increases the total number of points. Defaults to 500.
        sigma (float): The sigma value for np.random.normal which
            introduces Gaussian noise to the positions of the points. Defaults to 0.05.
        same_colored_edges (int): How many edges of the Hypercube should
            be colored with the same color. This can be used to later
            better visualize the edges of the cube. Defaults to 3.
        seed (Optional[int]): If an int is provided, this will be used as a seed
            for `np.random` and fix the random state. Defaults to None which produces
            random results every time this function is called.

    Returns:
        tuple: A tuple containing the following:
            - np.ndarray: The coordinates of the points.
            - np.ndarray: Integers that can be used for coloration.

    Example:
        >>> from encodermap.misc.misc import create_n_cube
        >>> # A sigma value of zero means no noise at all.
        >>> coords, colors = create_n_cube(2, sigma=0)
        >>> coords[0]
        array([0., 0.])

    """
    if seed is not None:
        np.random.seed(seed=seed)
    # create networkx hypercube with given dimensions
    G = nx.hypercube_graph(n)
    # vertices are not really needed
    vertices = np.array([n for n in G.nodes])
    # get edges
    edges = np.array([e for e in G.edges])
    # fill this list with values
    coordinates = []
    # iterate over edges
    for i, edge in enumerate(edges):
        # some basic analytic geometry: points on the segment A + t * (B - A)
        A, B = edge
        AB = B - A
        # place `points_along_edge` points evenly along the edge
        lin = np.linspace(0, 1, points_along_edge)
        points = A + (AB[:, None] * lin).T
        if sigma:
            # jitter the points with Gaussian noise
            points += np.random.normal(scale=sigma, size=(len(points), n))
        # append the edge index as a color label column
        points = np.hstack([points, np.full((len(points), 1), i)])
        coordinates.extend(points)

    # make big numpy array
    coordinates = np.array(coordinates)

    # color the specified number of same colored edges
    # choose a random edge
    found_edges = []
    edge_pairs = []

    # iterate over the number of same colored edges
    for _ in range(same_colored_edges):
        for i, edge in enumerate(edges):
            if i in found_edges:
                continue
            found_edges.append(i)
            # find another not-yet-used edge sharing this edge's first vertex
            vertex = edge[0]
            where = np.where(np.all(edges[:, 0] == vertex, axis=1))[0]
            for j in where:
                new_edge = edges[j]
                if j not in found_edges:
                    found_edges.append(j)
                    break
            if i != j:
                edge_pairs.append([i, j])
            break

    # replace the corresponding indices so paired edges share one color label
    for i, j in edge_pairs:
        new = coordinates[coordinates[:, -1] == i]
        new[:, -1] = np.full(points_along_edge, j)
        coordinates[coordinates[:, -1] == i] = new

    # split coordinates from the color label column
    return coordinates[:, :-1], coordinates[:, -1]