Coverage for encodermap/misc/misc.py: 31%
226 statements
coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

# -*- coding: utf-8 -*-
# encodermap/misc/misc.py
################################################################################
# Encodermap: A python library for dimensionality reduction.
#
# Copyright 2019-2022 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Miscellaneous functions."""


##############################################################################
# Imports
##############################################################################


import os
from itertools import groupby
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from .._optional_imports import _optional_import
from .errors import BadError


################################################################################
# Optional Imports
################################################################################


nx = _optional_import("networkx")


################################################################################
# Globals
################################################################################


__all__ = ["create_n_cube", "plot_model", "run_path", "get_full_common_str_and_ref"]


FEATURE_NAMES = {
    "AllCartesians": "all_cartesians",
    "AllBondDistances": "all_distances",
    "CentralCartesians": "central_cartesians",
    "CentralBondDistances": "central_distances",
    "CentralAngles": "central_angles",
    "CentralDihedrals": "central_dihedrals",
    "SideChainCartesians": "side_cartesians",
    "SideChainBondDistances": "side_distances",
    "SideChainAngles": "side_angles",
    "SideChainDihedrals": "side_dihedrals",
}


_TOPOLOGY_EXTS = [
    ".pdb",
    ".pdb.gz",
    ".h5",
    ".lh5",
    ".prmtop",
    ".parm7",
    ".prm7",
    ".psf",
    ".mol2",
    ".hoomdxml",
    ".gro",
    ".arc",
    ".hdf5",
    ".gsd",
]


################################################################################
# Utilities
################################################################################


def scale_projs(trajs, boundary, cols=None, debug=False):
    """Scales the projections and moves outliers next to their closest inside points.

    A while loop makes sure that no new point is placed where a point already exists.

    """
    for traj in trajs:
        data = traj.lowd
        data_scaled = []

        outside = []
        inside = []

        # sort the points into inside and outside the boundary
        for line in data:
            if (abs(line[0]) > boundary) or (abs(line[1]) > boundary):
                # if outside boundary find closest point
                outside.append(line)
            else:
                inside.append(line)
                data_scaled.append(line)

        for line in outside:
            # find the closest inside point for this outlier
            dist_min = np.inf
            comp_min = [0, 0]
            for comp in inside:
                dist = np.linalg.norm(line[:3] - comp[:3])
                if dist < dist_min:
                    dist_min = dist
                    comp_min = comp
            if debug:
                print(f"scaling outlier point at {line[0]}, {line[1]}")
            # if cols is provided, only scale these points
            addition = 0.01 * np.random.rand(trajs.dim)
            new_point = comp_min
            while np.any(np.isin(new_point, data_scaled)):
                if cols is None:
                    new_point = [comp_min[i] + addition[i] for i in range(len(line))]
                else:
                    new_point = [comp_min[i] + addition[i] for i in range(len(cols))]
                addition += 0.01 * np.random.rand(trajs.dim)
            data_scaled.append(new_point)

        if not len(data) == len(data_scaled):
            raise Exception(
                f"Scaling failed: started with {len(data)} points, "
                f"but produced {len(data_scaled)}."
            )
        traj.lowd = np.vstack(data_scaled)
        try:
            _ = traj.lowd.shape[1]  # raises IndexError if lowd is not 2D
        except IndexError:
            print(traj.basename)
            print(traj.lowd)
            print(data_scaled)
            raise
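

# Illustrative sketch (not part of the original module): `scale_projs` expects
# an ensemble whose trajectories carry a `.lowd` array and which itself has a
# `.dim` attribute. The stand-in classes below are hypothetical minimal
# substitutes for EncoderMap's trajectory ensemble objects.
class _FakeTraj:
    def __init__(self, lowd):
        self.lowd = lowd
        self.basename = "fake_traj"


class _FakeEnsemble(list):
    dim = 2


def _example_scale_projs():
    # the last point lies outside the boundary and counts as an outlier
    lowd = np.array([[0.1, 0.2], [0.3, -0.1], [25.0, 0.0]])
    trajs = _FakeEnsemble([_FakeTraj(lowd)])
    scale_projs(trajs, boundary=10.0)
    print(trajs[0].lowd)  # the outlier now sits next to its closest inside point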


def _can_be_feature(inp):
    """Function to decide whether the input can be interpreted by the Featurizer class.

    Returns True if inp == 'all' or if inp is a list of strings that are keys
    or values of FEATURE_NAMES.

    Args:
        inp (Any): The input.

    Returns:
        bool: True, if inp can be interpreted by featurizer.

    Example:
        >>> from encodermap.misc.misc import _can_be_feature
        >>> _can_be_feature('all')
        True
        >>> _can_be_feature('no')
        False
        >>> _can_be_feature(['AllCartesians', 'central_dihedrals'])
        True

    """
    if isinstance(inp, str):
        if inp == "all":
            return True
    if isinstance(inp, list):
        if all([isinstance(i, str) for i in inp]):
            if all(
                [i in FEATURE_NAMES.keys() or i in FEATURE_NAMES.values() for i in inp]
            ):
                return True
    return False


def match_files(trajs, tops, common_str):
    """Matches trajectory and topology files via the provided common substrings.

    For every trajectory file, the matching common_str is looked up and the
    topology file containing the same common_str is paired with it. For .h5
    trajectories, the trajectory file itself carries the topology and is used
    as its own topology file.

    """
    tops_out = []
    common_str_out = []

    trajs = list(map(str, trajs))
    tops = list(map(str, tops))

    for t in trajs:
        if not any([cs in t for cs in common_str]):
            raise BadError(
                f"The traj file {t} does not match any of the common_str you provided."
            )
        else:
            t_lcut = max([t.rfind(cs) for cs in common_str])
            t_lcut = t[t_lcut:]
            cs = common_str[[cs in t_lcut for cs in common_str].index(True)]
            if t.split(".")[-1] == "h5":
                # .h5 files contain their own topology
                tops_out.append(trajs[[cs in r for r in trajs].index(True)])
            else:
                tops_out.append(tops[[cs in r for r in tops].index(True)])
            common_str_out.append(cs)
    return tops_out, common_str_out
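

# Illustrative sketch (not part of the original module, file names are
# hypothetical): pairing two trajectories with their topologies via common
# substrings.
def _example_match_files():
    trajs = ["/data/protein_A/run1.xtc", "/data/protein_B/run1.xtc"]
    tops = ["/data/protein_A/top.pdb", "/data/protein_B/top.pdb"]
    tops_out, cs_out = match_files(trajs, tops, ["protein_A", "protein_B"])
    # tops_out: ['/data/protein_A/top.pdb', '/data/protein_B/top.pdb']
    # cs_out:   ['protein_A', 'protein_B']
    print(tops_out, cs_out)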


def get_full_common_str_and_ref(trajs, tops, common_str):
    """Matches traj_files, top_files and common string and returns lists with the
    same length matching the provided common str.

    Args:
        trajs (list[str]): A list of str pointing to trajectory files.
        tops (list[str]): A list of str pointing to topology files.
        common_str (list[str]): A list of strings that can be found in
            both trajs and tops (i.e. substrings).

    Returns:
        tuple: Three lists of equal length: the trajectory files, the matched
            topology files, and the matched common strings.

    """
    if len(trajs) != len(tops) and common_str == [] and len(tops) != 1:
        raise BadError(
            "When providing lists of trajs and refs with different lengths, you "
            "must provide a list of common_str to match them."
        )

    # if the length of all objects is the same we just return them
    if len(trajs) == len(tops) == len(common_str):
        return trajs, tops, common_str

    # if trajs and tops have the same length they are expected to match
    elif len(trajs) == len(tops):
        return trajs, tops, [None for i in trajs]

    # if only one topology is provided, we hope the user passed a correct one
    # and fill everything else up
    elif len(trajs) > 1 and len(tops) == 1:
        tops_out = [tops[0] for t in trajs]

        if common_str == []:
            common_str_out = [Path(traj).stem for traj in trajs]

        elif len(common_str) != len(trajs):
            tops_out, common_str_out = match_files(trajs, tops_out, common_str)

        elif len(common_str) == len(trajs):
            common_str_out = common_str

        return trajs, tops_out, common_str_out

    # in the other cases we need to do something similar
    else:
        if len(tops) > len(trajs):
            raise Exception(
                f"I was given more topologies ({tops}) than trajectories ({trajs}). "
                "Something's not right."
            )
        if len(common_str) > len(trajs):
            raise Exception(
                "I was given more common strings than trajectories. Something's not right."
            )

        if common_str == []:
            common_str_out = [Path(traj).stem for traj in trajs]

        elif len(common_str) != len(trajs):
            tops_out, common_str_out = match_files(trajs, tops, common_str)

        elif len(common_str) == len(trajs):
            common_str_out = common_str

        tops_out, common_str_out = match_files(trajs, tops, common_str_out)
        return trajs, tops_out, common_str_out
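

# Illustrative sketch (not part of the original module, file names are
# hypothetical): with a single topology and no common_str, the topology is
# repeated for every trajectory and the trajectory stems become the common
# strings.
def _example_get_full_common_str_and_ref():
    out = get_full_common_str_and_ref(
        trajs=["run1.xtc", "run2.xtc"], tops=["top.pdb"], common_str=[]
    )
    # (['run1.xtc', 'run2.xtc'], ['top.pdb', 'top.pdb'], ['run1', 'run2'])
    print(out)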


def printTable(myDict, colList=None, sep="\uFFFA"):
    """Pretty print a list of dictionaries (myDict) as a dynamically sized table.

    If column names (colList) aren't specified, they are taken from the keys of
    the first dictionary.
    sep: row separator. Ex: sep='\\n' on Linux. Default: dummy character to not split lines.
    Author: Thierry Husson - Use it as you want but don't blame me.

    """
    out = []
    if not colList:
        colList = list(myDict[0].keys() if myDict else [])
    myList = [colList]  # 1st row = header
    for item in myDict:
        myList.append([str(item[col] or "") for col in colList])
    colSize = [max(map(len, (sep.join(col)).split(sep))) for col in zip(*myList)]
    formatStr = " | ".join(["{{:<{}}}".format(i) for i in colSize])
    line = formatStr.replace(" | ", "-+-").format(*["-" * i for i in colSize])
    item = myList.pop(0)
    lineDone = False
    while myList:
        if all(not i for i in item):
            item = myList.pop(0)
            if line and (sep != "\uFFFA" or not lineDone):
                out.append(line)
                lineDone = True
        row = [i.split(sep, 1) for i in item]
        out.append(formatStr.format(*[i[0] for i in row]))
        item = [i[1] if len(i) > 1 else "" for i in row]
    out = [" " + i for i in out]
    points = " \n".join(out)
    return points
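

# Illustrative sketch (not part of the original module, the rows are made up):
# print two dictionaries as an aligned table with a header separator.
def _example_printTable():
    rows = [
        {"traj": "run1.xtc", "frames": 1000},
        {"traj": "run2.xtc", "frames": 500},
    ]
    print(printTable(rows, colList=["traj", "frames"]))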


##############################################################################
# Functions
##############################################################################


def _datetime_windows_and_linux_compatible():
    """Portable way to get `now` as either a linux or windows compatible string.

    For linux systems strings in this manner will be returned:
        2022-07-13T16:04:04+02:00

    For windows systems strings in this manner will be returned:
        2022-07-13_16-04-46

    """
    import datetime
    from sys import platform

    if platform == "linux" or platform == "linux2" or platform == "darwin":
        return datetime.datetime.now().astimezone().replace(microsecond=0).isoformat()
    elif platform == "win32":
        return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def all_equal(iterable):
    """Returns True when all elements in the iterable are equal."""
    g = groupby(iterable)
    return next(g, True) and not next(g, False)
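

# Illustrative sketch (not part of the original module): `groupby` collapses
# runs of equal elements, so an iterable with at most one group is "all equal".
def _example_all_equal():
    assert all_equal([1, 1, 1])
    assert not all_equal([1, 2])
    assert all_equal([])  # vacuously True for empty iterables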


def _validate_uri(str_):
    """Checks whether `str_` is a valid URI (i.e. has both a scheme and a path)."""
    from urllib.parse import urlparse

    try:
        result = urlparse(str_)
        return all([result.scheme, result.path])
    except Exception:
        return False
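

# Illustrative sketch (not part of the original module, the paths are made up):
def _example_validate_uri():
    assert _validate_uri("file:///tmp/some_model")  # has scheme and path
    assert not _validate_uri("just a plain string")  # no scheme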


def _flatten_model(model_nested, input_dim=None, return_model=True):
    """Flattens a nested tensorflow.keras.models.Model.

    Can be useful if a model consists of two sequential models and needs to
    be flattened to be plotted.

    """
    layers_flat = []
    for layer in model_nested.layers:
        try:
            layers_flat.extend(layer.layers)
        except AttributeError:
            layers_flat.append(layer)
    if return_model:
        model_flat = tf.keras.models.Sequential(layers_flat)
        model_flat.build(input_shape=(1, input_dim))
        return model_flat
    else:
        return layers_flat
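

# Illustrative sketch (not part of the original module, the layer sizes are
# hypothetical): flatten a Sequential model that nests an encoder and a
# decoder sub-model into a single flat Sequential.
def _example_flatten_model():
    encoder = tf.keras.models.Sequential([tf.keras.layers.Dense(2, input_shape=(10,))])
    decoder = tf.keras.models.Sequential([tf.keras.layers.Dense(10, input_shape=(2,))])
    nested = tf.keras.models.Sequential([encoder, decoder])
    flat = _flatten_model(nested, input_dim=10)
    print([layer.name for layer in flat.layers])  # two Dense layers, no nesting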


def plot_model(model, input_dim):
    """Plots a keras model using tf.keras.utils.plot_model."""
    model = _flatten_model(model, input_dim)
    try:
        _ = tf.keras.utils.plot_model(
            model, to_file="tmp.png", show_shapes=True, rankdir="LR", expand_nested=True
        )
    except Exception:
        # tf.keras.utils.plot_model needs pydot and graphviz; if plotting
        # fails, there is no tmp.png to read, so return instead of crashing.
        return
    img = plt.imread("tmp.png")
    os.remove("tmp.png")
    plt.close("all")
    plt.imshow(img)
    if mpl.get_backend() == "module://ipykernel.pylab.backend_inline":
        fig = plt.gcf()
        fig.set_size_inches(fig.get_size_inches() * 4)
    ax = plt.gca()
    ax.axis("off")
    plt.show()


def run_path(path):
    """Creates a directory at "path/run{i}", where i is the smallest index not yet in use.

    Args:
        path (str): Path to the parent folder of the run directories.

    Returns:
        str: The new output path.

    Examples:
        >>> import encodermap as em
        >>> em.misc.run_path('runs')
        'runs/run0'
        >>> em.misc.run_path('runs')  # 'run0' exists now, so 'run1' is created
        'runs/run1'

    """
    i = 0
    while True:
        current_path = os.path.join(path, "run{}".format(i))
        if not os.path.exists(current_path):
            os.makedirs(current_path)
            output_path = current_path
            break
        else:
            i += 1
    return output_path


def create_n_cube(
    n=3, points_along_edge=500, sigma=0.05, same_colored_edges=3, seed=None
):
    """Creates points along the edges of an n-dimensional unit hyper-cube.

    The cube is created using networkx.hypercube_graph and points are placed along
    the edges of the cube. By providing a sigma value the points can be shifted
    by some Gaussian noise.

    Args:
        n (int, optional): The dimension of the Hypercube (can also take 1 or 2).
            Defaults to 3.
        points_along_edge (int, optional): How many points should be placed along any edge.
            By increasing the number of dimensions, the number of edges
            increases, which also increases the total number of points. Defaults to 500.
        sigma (float, optional): The sigma value for np.random.normal which
            introduces Gaussian noise to the positions of the points. Defaults to 0.05.
        same_colored_edges (int, optional): How many edges of the Hypercube should
            be colored with the same color. This can be used to later
            better visualize the edges of the cube. Defaults to 3.
        seed (int, optional): If an int is provided this will be used as a seed
            for np.random and fix the random state. Defaults to None which produces
            random results every time this function is called.

    Returns:
        tuple: A tuple containing the following:
            coordinates (np.ndarray): The coordinates of the points.
            colors (np.ndarray): Integers that can be used for coloration.

    Example:
        >>> # A sigma value of zero means no noise at all.
        >>> coords, colors = create_n_cube(2, sigma=0)
        >>> coords[0]
        array([0., 1.])

    """
    if seed is not None:
        np.random.seed(seed=seed)
    # create networkx hypercube with given dimensions
    G = nx.hypercube_graph(n)

    # vertices is not really needed
    vertices = np.array([v for v in G.nodes])
    # get edges
    edges = np.array([e for e in G.edges])

    # fill this list with values
    coordinates = []

    # iterate over edges
    for i, edge in enumerate(edges):
        # some basic analytic geometry
        A, B = edge
        AB = B - A
        # n points along edge
        lin = np.linspace(0, 1, points_along_edge)
        points = A + (AB[:, None] * lin).T
        if sigma:
            points += np.random.normal(scale=sigma, size=(len(points), n))
        # add label for colors
        points = np.hstack([points, np.full((len(points), 1), i)])
        coordinates.extend(points)

    # make big numpy array
    coordinates = np.array(coordinates)

    # color the specified number of same-colored edges
    # choose a random edge
    found_edges = []
    edge_pairs = []

    # iterate over the number of same-colored edges
    for _ in range(same_colored_edges):
        for i, edge in enumerate(edges):
            if i in found_edges:
                continue
            found_edges.append(i)
            vertex = edge[0]
            where = np.where(np.all(edges[:, 0] == vertex, axis=1))[0]
            for j in where:
                new_edge = edges[j]
                if j not in found_edges:
                    found_edges.append(j)
                    break
            if i != j:
                edge_pairs.append([i, j])
                break

    # replace the corresponding indices (the label column is the last one)
    for i, j in edge_pairs:
        new = coordinates[coordinates[:, -1] == i]
        new[:, -1] = np.full(points_along_edge, j)
        coordinates[coordinates[:, -1] == i] = new

    return coordinates[:, :-1], coordinates[:, -1]
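

# Illustrative sketch (not part of the original module): generate a noisy
# 3-cube and check the shape of the output.
def _example_create_n_cube():
    coords, colors = create_n_cube(n=3, points_along_edge=100, seed=42)
    print(coords.shape)  # (1200, 3): the 3-cube has 12 edges with 100 points each
    print(colors.shape)  # (1200,): one integer label per point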