Coverage for encodermap/misc/misc.py: 10%

243 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/misc.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Miscellaneous functions.""" 

23 

24############################################################################## 

25# Imports 

26############################################################################## 

27 

28 

29# Future Imports at the top 

30from __future__ import annotations 

31 

32# Standard Library Imports 

33import os 

34from itertools import groupby 

35from pathlib import Path 

36from tempfile import NamedTemporaryFile 

37 

38# Third Party Imports 

39import numpy as np 

40import tensorflow as tf 

41from optional_imports import _optional_import 

42 

43 

44################################################################################ 

45# Optional Imports 

46################################################################################ 

47 

48 

49nx = _optional_import("networkx") 

50Image = _optional_import("PIL", "Image") 

51ImageDraw = _optional_import("PIL", "ImageDraw") 

52ImageFont = _optional_import("PIL", "ImageFont") 

53 

54 

55################################################################################ 

56# Typing 

57################################################################################ 

58 

59 

60# Standard Library Imports 

61from collections.abc import Iterable, Sequence 

62from typing import TYPE_CHECKING, Any, Optional, Union, overload 

63 

64 

65if TYPE_CHECKING: 

66 # Standard Library Imports 

67 from pathlib import Path 

68 

69 # Local Folder Imports 

70 from ..trajinfo.info_all import TrajEnsemble 

71 

72 

73################################################################################ 

74# Globals 

75################################################################################ 

76 

77 

# Public API of this module (the wildcard-import surface).
__all__: list[str] = [
    "create_n_cube",
    "plot_model",
    "run_path",
    "get_full_common_str_and_ref",
]


# Maps the CamelCase names of EncoderMap's feature classes to the
# snake_case names used for the corresponding datasets/attributes.
FEATURE_NAMES = {
    "AllCartesians": "all_cartesians",
    "AllBondDistances": "all_distances",
    "CentralCartesians": "central_cartesians",
    "CentralBondDistances": "central_distances",
    "CentralAngles": "central_angles",
    "CentralDihedrals": "central_dihedrals",
    "SideChainCartesians": "side_cartesians",
    "SideChainBondDistances": "side_distances",
    "SideChainAngles": "side_angles",
    "SideChainDihedrals": "side_dihedrals",
}


# File extensions that are treated as topology files.
# NOTE(review): presumably the set of formats readable by MDTraj -- confirm.
_TOPOLOGY_EXTS = [
    ".pdb",
    ".pdb.gz",
    ".h5",
    ".lh5",
    ".prmtop",
    ".parm7",
    ".prm7",
    ".psf",
    ".mol2",
    ".hoomdxml",
    ".gro",
    ".arc",
    ".hdf5",
    ".gsd",
]

116 

117 

118################################################################################ 

119# Utilities 

120################################################################################ 

121 

122 

123def _is_notebook() -> bool: 

124 """Returns 'True', if we are currently in a notebook. 

125 

126 Returns: 

127 bool: True if we are in a notebook, False otherwise. 

128 

129 """ 

130 try: 

131 # Third Party Imports 

132 from IPython import get_ipython 

133 

134 if "IPKernelApp" not in get_ipython().config: # pragma: no cover 

135 return False 

136 except ImportError: 136 ↛ 137line 136 didn't jump to line 137, because the exception caught by line 136 didn't happen

137 return False 

138 except AttributeError: 

139 return False 

140 return True 

141 

142 

def _can_be_feature(inp: Union[str, list[str]]) -> bool:
    """Decide whether `inp` can be interpreted by the Featurizer class.

    Outputs True, if inp == 'all' or inp is a list of strings contained in FEATURE_NAMES.

    Args:
        inp (Any): The input.

    Returns:
        bool: True, if inp can be interpreted by featurizer.

    Example:
        >>> from encodermap.misc.misc import _can_be_feature
        >>> _can_be_feature('all')
        True
        >>> _can_be_feature('no')
        False
        >>> _can_be_feature(['AllCartesians', 'central_dihedrals'])
        True

    """
    if isinstance(inp, str):
        # The only accepted bare string is the catch-all 'all'.
        return inp == "all"
    if isinstance(inp, list):
        # Accept either the class names (keys) or dataset names (values).
        allowed = set(FEATURE_NAMES.keys()) | set(FEATURE_NAMES.values())
        return all(isinstance(item, str) for item in inp) and all(
            item in allowed for item in inp
        )
    return False

174 

175 

def match_files(
    trajs: Union[list[str], list[Path]],
    tops: Union[list[str], list[Path]],
    common_str: list[str],
) -> tuple[list[str], list[str]]:
    """Matches trajectory and topology files with a list of common strings.

    The idea behind this function is to use one topology file with
    multiple trajectory files, that all use the same topology.

    Args:
        trajs (Union[list[str], list[Path]]): A list of str or list of Path objects.
            These are the trajectory files we want to match the subset of
            topology files with.
        tops (Union[list[str], list[Path]]): A list of str or list of Path objects.
            These are the topology files we will assign to the trajectory files
            using common substrings.
        common_str (list[str]): A list of common substrings.

    Returns:
        tuple[list[str], list[str]]: A tuple containing a two list of str.
            The first list of str are the topology files matched to the 'trajs'.
            The second list of str are the common_str matched to the 'trajs'.
            Both lists have the same length as `trajs`.

    Raises:
        Exception: If no common_str matches the first trajectory file, or if
            some trajectory file matches none of the provided common_str.

    """
    tops_out = []
    common_str_out = []

    # Work on str representations so substring tests also work for Path inputs.
    trajs = list(map(str, trajs))
    tops = list(map(str, tops))

    # Fast path: .h5 trajectories are self-describing (carry their own
    # topology); if everything already lines up one-to-one, return as-is.
    if all([Path(p).suffix == ".h5" for p in trajs]) and len(trajs) == len(tops) == len(
        common_str
    ):
        return tops, common_str

    if len(trajs) == len(tops):
        # Find the first common_str occurring in the first trajectory file.
        try:
            first_cs = common_str[[cs in trajs[0] for cs in common_str].index(True)]
        except ValueError as e:
            raise Exception(
                f"Can't find any matching common_str in the input data:\n{trajs=}\n\n"
                f"{tops=}\n\n{common_str=}"
            ) from e
        # If that substring also occurs in the first topology file, trajs and
        # tops are assumed to be ordered pairwise and are iterated together.
        if first_cs not in tops[0]:
            iterator = trajs
        else:
            iterator = zip(trajs, tops)
    else:
        iterator = trajs

    for t in iterator:
        if isinstance(t, tuple):
            # Pairwise case: the matching topology is already known.
            t, top = t
            if not any([cs in t for cs in common_str]):
                raise Exception(
                    f"The traj file {t} does not match any of the common_str you provided."
                )
            else:
                # Only consider the path from the right-most occurrence of any
                # common_str onward, so directory names higher up the path
                # can't produce spurious matches.
                t_lcut = max([t.rfind(cs) for cs in common_str])
                t_lcut = t[t_lcut:]
                cs = common_str[[cs in t_lcut for cs in common_str].index(True)]
                if t.split(".")[-1] == "h5":
                    # .h5 trajectories act as their own topology.
                    tops_out.append(t)
                    assert cs in t, f"{cs=} {top=} {t=}"
                else:
                    tops_out.append(top)
                    assert cs in top and cs in t
                common_str_out.append(cs)
        else:
            # Unpaired case: find the matching topology via the common_str.
            if not any([cs in t for cs in common_str]):
                raise Exception(
                    f"The traj file {t} does not match any of the common_str you provided."
                )
            else:
                t_lcut = max([t.rfind(cs) for cs in common_str])
                t_lcut = t[t_lcut:]
                cs = common_str[[cs in t_lcut for cs in common_str].index(True)]
                if t.split(".")[-1] == "h5":
                    # For .h5 files the matching *trajectory* file is used as
                    # the topology (h5 is self-describing).
                    tops_out.append(trajs[[cs in r for r in trajs].index(True)])
                else:
                    tops_out.append(tops[[cs in r for r in tops].index(True)])
                assert cs in t
                common_str_out.append(cs)
    return tops_out, common_str_out

262 

263 

def get_full_common_str_and_ref(
    trajs: Union[list[str], list[Path]],
    tops: Union[list[str], list[Path]],
    common_str: list[str],
) -> tuple[list[str], list[str], list[str]]:
    """Matches traj_files, top_files and common string and returns lists with the
    same length matching the provided common str.

    Args:
        trajs (list[str]): A list of str pointing to trajectory files.
        tops (list[str]): A list of str pointing to topology files.
        common_str (list[str]): A list of strings that can be found in
            both trajs and tops (i.e. substrings).

    Returns:
        tuple: A tuple containing the following:
            - list[str]: A list of str with the traj file names.
            - list[str]: A list of str with the top file names.
            - list[str]: A list of str with the common_str's.
            All lists have the same length.

    Raises:
        Exception: If trajs and tops differ in length and no common_str was
            provided to match them, or if more tops/common_str than trajs
            were provided.

    """
    assert isinstance(common_str, list)
    if len(trajs) != len(tops) and common_str == [] and len(tops) != 1:
        raise Exception(
            "When providing a list of trajs and a list of refs with different "
            "length you must provide a list of common_str to match them."
        )

    # if the length of all objects is the same we just return them
    # but also we check, whether the common_str appears in the
    # trajs and tops to make sure it's correctly ordered
    if len(trajs) == len(tops) == len(common_str):
        if not all([i is None for i in common_str]):
            for traj, top, cs in zip(trajs, tops, common_str):
                if cs not in str(traj) or cs not in str(top):
                    # The ordering is off; let match_files re-assign
                    # tops and common_str to the trajs.
                    return (trajs, *match_files(trajs, tops, common_str))
        return trajs, tops, common_str

    # if trajs and tops are the same length, they are expected to match
    elif len(trajs) == len(tops):
        if len(common_str) == 0:
            return trajs, tops, [None for i in trajs]

        elif len(common_str) == 1:
            # A single common_str applies to every trajectory.
            return trajs, tops, [common_str[0] for i in trajs]

        else:
            tops_out, common_str_out = match_files(trajs, tops, common_str)
            return trajs, tops_out, common_str_out

    # if only one topology is provided, we hope the user passed a correct one and fill everything else up
    elif len(trajs) > 1 and len(tops) == 1:
        tops_out = [tops[0] for t in trajs]

        if common_str == []:
            # No common_str given: fall back to each file's stem as its tag.
            common_str_out = [Path(traj).stem for traj in trajs]

        elif len(common_str) != len(trajs):
            tops_out, common_str_out = match_files(trajs, tops_out, common_str)

        elif len(common_str) == len(trajs):
            common_str_out = common_str

        return trajs, tops_out, common_str_out

    # in other cases, we need to do something similar
    else:
        if len(tops) > len(trajs):
            raise Exception(
                f"I was given more topologies {tops} than trajectories {trajs} . Something's not right."
            )
        if len(common_str) > len(trajs):
            raise Exception(
                "I was given more common strings than trajectories. Something's not right."
            )

        if common_str == []:
            common_str_out = [Path(traj).stem for traj in trajs]

        elif len(common_str) != len(trajs):
            tops_out, common_str_out = match_files(trajs, tops, common_str)

        elif len(common_str) == len(trajs):
            common_str_out = common_str

        # Re-run the matching with the (possibly derived) common_str_out so
        # tops_out has one entry per trajectory.
        tops_out, common_str_out = match_files(trajs, tops, common_str_out)
        return trajs, tops_out, common_str_out

352 

353 

def printTable(
    myDict: list[dict[str, Any]],
    colList: Optional[list[str]] = None,
    sep: str = "\uFFFA",
) -> str:
    """Pretty print a list of dictionaries (myDict) as a dynamically sized table.

    If column names (colList) aren't specified, they will show in random order.

    Args:
        myDict: The table rows; each item maps column name -> cell value.
        colList: Column names to show, in order. Defaults to the keys of
            the first row.
        sep: Row separator. Ex: sep='\\n' on Linux. Default: a dummy
            non-character so cells are never split over multiple lines.

    Returns:
        str: The formatted table.

    Author: Thierry Husson - Use it as you want but don't blame me.
    """
    out = []
    if not colList:
        colList = list(myDict[0].keys() if myDict else [])
    myList = [colList]  # 1st row = header
    for item in myDict:
        myList.append([str(item[col] or "") for col in colList])
    # Width of each column: the longest sep-separated fragment in that column.
    colSize = [max(map(len, (sep.join(col)).split(sep))) for col in zip(*myList)]
    formatStr = " | ".join(["{{:<{}}}".format(i) for i in colSize])
    # Horizontal rule matching the column widths.
    line = formatStr.replace(" | ", "-+-").format(*["-" * i for i in colSize])
    item = myList.pop(0)
    lineDone = False
    while myList:
        # Once the current (possibly multi-line) row is fully emitted, advance
        # to the next row; the rule is printed once, right after the header.
        if all(not i for i in item):
            item = myList.pop(0)
            if line and (sep != "\uFFFA" or not lineDone):
                out.append(line)
                lineDone = True
        # Emit the first sep-fragment of every cell and keep the remainder
        # for the next pass -- this renders multi-line cells row by row.
        row = [i.split(sep, 1) for i in item]
        out.append(formatStr.format(*[i[0] for i in row]))
        item = [i[1] if len(i) > 1 else "" for i in row]
    out = [" " + i for i in out]
    points = " \n".join(out)
    return points

387 

388 

389############################################################################## 

390# Functions 

391############################################################################## 

392 

393 

394def _datetime_windows_and_linux_compatible() -> str: 

395 """Portable way to get `now` as either a linux or windows compatible string. 

396 

397 For linux systems, strings in this manner will be returned: 

398 2022-07-13T16:04:04+02:00 

399 

400 For Windows systems, strings in this manner will be returned: 

401 2022-07-13_16-04-46 

402 

403 """ 

404 # Standard Library Imports 

405 import datetime 

406 from sys import platform 

407 

408 if platform == "linux" or platform == "linux2" or platform == "darwin": 

409 return datetime.datetime.now().astimezone().replace(microsecond=0).isoformat() 

410 elif platform == "win32": 

411 return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 

412 

413 

def all_equal(iterable: Iterable[Any]) -> bool:
    """Returns True, when all elements in a list are equal

    Args:
        iterable (Iterable[Any]): An iterable of Any.

    Returns:
        bool: Whether all elements in `iterable` are equal.

    """
    it = iter(iterable)
    try:
        first = next(it)
    except StopIteration:
        # An empty iterable vacuously has all elements equal.
        return True
    # Compare every remaining element against the first one.
    return all(element == first for element in it)

426 

427 

428def _validate_uri(str_: Union[Path, str]) -> bool: 

429 """Checks whether the a str or Path is a valid uri. 

430 

431 Args: 

432 str_ (Union[Path, str]): The str or Path to test. 

433 

434 Returns: 

435 bool: True when `str_` is a valid URI, False otherwise. 

436 

437 """ 

438 # Standard Library Imports 

439 from urllib.parse import urlparse 

440 

441 try: 

442 result = urlparse(str_) 

443 return all([result.scheme, result.path]) 

444 except: 

445 return False 

446 

447 

@overload
def _flatten_model(
    model_nested: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
    return_model: bool = True,
) -> tf.keras.Model: ...


@overload
def _flatten_model(
    model_nested: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
    return_model: bool = False,
) -> list[tf.keras.layers.Layer]: ...


def _flatten_model(
    model_nested: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
    return_model: bool = True,
) -> Union[tf.keras.Model, list[tf.keras.layers.Layer]]:
    """Flattens a nested tensorflow.keras.models.Model.

    Can be useful if a model consists of two sequential models and needs to
    be flattened to be plotted.

    Args:
        model_nested (tf.keras.Model): The possibly nested model.
        input_dim (Optional[Sequence[int]]): Used to build the flattened
            model when `return_model` is True. Defaults to None.
        return_model (bool): Whether to return a built `tf.keras.Model`
            (True) or just the flat list of layers (False). Defaults to True.

    Returns:
        Union[tf.keras.Model, list[tf.keras.layers.Layer]]:
            Either a tf.keras.Model or a list of layers.

    """
    layers_flat = []
    for layer in model_nested.layers:
        try:
            # Sub-models expose `.layers`; splice their layers in directly.
            layers_flat.extend(layer.layers)
        except AttributeError:
            # Plain layers have no `.layers` attribute.
            layers_flat.append(layer)
    if return_model:
        model_flat = tf.keras.models.Sequential(layers_flat)
        # NOTE(review): `(1, input_dim)` nests the whole sequence into the
        # shape tuple; this only builds a valid shape if callers actually
        # pass an int here despite the Sequence[int] annotation -- confirm.
        model_flat.build(input_shape=(1, input_dim))
        return model_flat
    else:
        return layers_flat

490 

491 

def plot_model(
    model: tf.keras.Model,
    input_dim: Optional[Sequence[int]] = None,
) -> Optional[Image]:
    """Plots keras model using tf.keras.utils.plot_model.

    Args:
        model (tf.keras.Model): The model to plot.
        input_dim (Optional[Sequence[int]]): When given and `model` is an
            EncoderMap `SequentialModel`, the model is flattened with this
            input dimension before plotting. Defaults to None.

    Returns:
        Optional[Image]: A PIL image of the rendered model, or None when
            `tf.keras.utils.plot_model` fails (e.g. pydot/graphviz missing).

    """
    # Encodermap imports
    from encodermap.models.models import SequentialModel

    if isinstance(model, SequentialModel) and input_dim is not None:
        model = _flatten_model(model, input_dim)
        with NamedTemporaryFile(suffix=".png") as f:
            try:
                # plot_model needs pydot/graphviz; any failure is treated as
                # "no plot available" rather than raised.
                _ = tf.keras.utils.plot_model(
                    model,
                    to_file=f.name,
                    show_shapes=True,
                    rankdir="TB",
                    expand_nested=True,
                )
            except:
                return
            # NOTE(review): re-opening a NamedTemporaryFile by name while it
            # is still open does not work on Windows -- confirm POSIX-only use.
            return Image.open(f.name)
    with NamedTemporaryFile(suffix=".png") as f:
        try:
            _ = tf.keras.utils.plot_model(
                model,
                to_file=f.name,
                show_shapes=True,
                rankdir="TB",
                expand_nested=True,
            )
        except:
            return
        img = Image.open(f.name)
        # img.putalpha(alpha=255)
        # Decorate the plot with the EncoderMap cube logo and a caption.
        cube_img = Image.open(Path(__file__).resolve().parent / "logo_cube_300.png")
        img.paste(cube_img, (800, 500), cube_img)
        try:
            # Best-effort caption: the font path only exists on some Linux
            # systems, so failures are silently ignored.
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/dejavu/DejaVuMathTeXGyre.ttf", size=40
            )
            ImageDraw.Draw(img).text(  # Image
                (700, 800),  # Coordinates
                "Made with EncoderMap",  # Text
                (0, 0, 0),  # Color
                font=font,
            )
        except:
            pass
        return img

    # plt.close("all")
    # plt.imshow(img)
    # plt.show()
    # if mpl.get_backend() == "module://ipykernel.pylab.backend_inline":
    #     fig = plt.gcf()
    #     fig.set_size_inches(fig.get_size_inches() * 4)
    #     ax = plt.gca()
    #     ax.axis("off")
    #     plt.show()

552 

553 

def run_path(path: str) -> str:
    """Creates a directory at "path/run{i}" where the i is corresponding to the
    smallest not yet existing path.

    The directory is created with `os.makedirs` inside a try/except, so the
    check-then-create race of the original implementation (two processes
    could both observe "run2" as free and collide) is avoided.

    Args:
        path (str): Path to the run folder.

    Returns:
        str: The new output path.

    Examples:
        >>> import encodermap as em
        >>> import tempfile
        >>> import os
        >>>
        >>> def sort_key(inp: str) -> int:
        ...     return int(inp[-1])
        >>>
        >>> with tempfile.TemporaryDirectory() as td:
        ...     # create some directories
        ...     os.makedirs(os.path.join(td, "run0"))
        ...     os.makedirs(os.path.join(td, "run1"))
        ...     # em.misc.run_path will automatically advance the counter to 'run2'
        ...     new_path = em.misc.run_path(td)
        ...     print(new_path)
        ...     print(sorted(os.listdir(td), key=sort_key))  # doctest: +ELLIPSIS
        /tmp/.../run2
        ['run0', 'run1', 'run2']

    """
    i = 0
    while True:
        current_path = os.path.join(path, f"run{i}")
        try:
            # EAFP: creating the directory is atomic, unlike exists()+makedirs().
            os.makedirs(current_path)
        except FileExistsError:
            i += 1
        else:
            return current_path

594 

595 

def create_n_cube(
    n: int = 3,
    points_along_edge: int = 500,
    sigma: float = 0.05,
    same_colored_edges: int = 3,
    seed: Optional[int] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Creates points along the edges of an n-dimensional unit hyper-cube.

    The cube is created using `networkx.hypercube_graph` and points are placed along
    the edges of the cube. By providing a `sigma` value, the points can be shifted
    by some Gaussian noise.

    Args:
        n (int): The dimension of the Hypercube (can also take 1 or 2).
            Defaults to 3.
        points_along_edge (int): How many points should be placed along any edge.
            By increasing the number of dimensions, the number of edges
            increases, which also increases the total number of points. Defaults to 500.
        sigma (float): The sigma value for np.random.normal which
            introduces Gaussian noise to the positions of the points. Defaults to 0.05.
        same_colored_edges (int): How many edges of the Hypercube should
            be colored with the same color. This can be used to later
            better visualize the edges of the cube. Defaults to 3.
        seed (Optional[int]): If an int is provided, this will be used as a seed
            for `np.random` and fix the random state. Defaults to None which produces
            random results every time this function is called.

    Returns:
        tuple: A tuple containing the following:
            - np.ndarray: The coordinates of the points.
            - np.ndarray: Integers that can be used for coloration.

    Example:
        >>> from encodermap.misc.misc import create_n_cube
        >>> # A sigma value of zero means no noise at all.
        >>> coords, colors = create_n_cube(2, sigma=0)
        >>> coords[0]
        array([0., 0.])

    """
    if seed is not None:
        np.random.seed(seed=seed)
    # create networkx hypercube with given dimensions
    G = nx.hypercube_graph(n)

    # vertices are not really needed
    vertices = np.array([n for n in G.nodes])
    # get edges
    edges = np.array([e for e in G.edges])

    # fill this list with values
    coordinates = []

    # iterate over edges
    for i, edge in enumerate(edges):
        # some basic analytic geometry: parametrize the segment A + t*(B - A)
        A, B = edge
        AB = B - A
        # n points along edge
        lin = np.linspace(0, 1, points_along_edge)
        points = A + (AB[:, None] * lin).T
        if sigma:
            # Gaussian jitter around the exact edge positions.
            points += np.random.normal(scale=sigma, size=(len(points), n))
        # add label for colors (edge index i in the last column)
        points = np.hstack([points, np.full((len(points), 1), i)])
        coordinates.extend(points)

    # make big numpy array
    coordinates = np.array(coordinates)

    # color the specified number of same colored edges
    # choose a random edge
    found_edges = []
    edge_pairs = []

    # iterate over the number of same colored edges; each pass pairs one
    # not-yet-processed edge with another edge sharing its start vertex.
    for _ in range(same_colored_edges):
        for i, edge in enumerate(edges):
            if i in found_edges:
                continue
            found_edges.append(i)
            vertex = edge[0]
            # All edges that start at the same vertex are pairing candidates.
            where = np.where(np.all(edges[:, 0] == vertex, axis=1))[0]
            for j in where:
                new_edge = edges[j]
                if j not in found_edges:
                    found_edges.append(j)
                    break
            if i != j:
                edge_pairs.append([i, j])
                break

    # replace the corresponding indices: points on edge i inherit the
    # color label of its paired edge j.
    for i, j in edge_pairs:
        new = coordinates[coordinates[:, -1] == i]
        new[:, -1] = np.full(points_along_edge, j)
        coordinates[coordinates[:, -1] == i] = new

    # Split off the color column from the coordinates.
    return coordinates[:, :-1], coordinates[:, -1]