Coverage for encodermap/parameters/parameters.py: 13%

233 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/parameters/parameters.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Parameter Classes for Encodermap. 

23 

24This module contains parameter classes which are used to hold information for 

25the encodermap autoencoder. Parameters can be set from keyword arguments, by 

26overwriting the class attributes or by reading them from .json, .yaml or ASCII files. 

27 

28Features: 

29 * Setting and saving Parameters with the Parameter class. 

30 * Loading parameters from disk and continue where you left off. 

 * The Parameter and ADCParameters classes already contain good default values.

32 

33""" 

34 

35 

36################################################################################ 

37# Imports 

38################################################################################ 

39 

40 

41# Future Imports at the top 

42from __future__ import annotations 

43 

44# Standard Library Imports 

45import copy 

46import json 

47import os 

48from math import pi 

49from textwrap import wrap 

50 

51# Third Party Imports 

52from optional_imports import _optional_import 

53 

54# Encodermap imports 

55from encodermap.misc.misc import _datetime_windows_and_linux_compatible, printTable 

56 

57 

58################################################################################ 

59# Optional Imports 

60################################################################################ 

61 

62 

63yaml = _optional_import("yaml") 

64 

65 

66################################################################################ 

67# Typing 

68################################################################################ 

69 

70 

71# Standard Library Imports 

72from pathlib import Path 

73from typing import TYPE_CHECKING, Optional, TypeVar, Union 

74 

75 

# Type aliases for the values a parameter class can hold and for the
# dict representation produced by `ParametersFramework.to_dict`.
ParametersData = Union[
    float, int, str, bool, list[int], list[str], list[float], tuple[int, None], None
]
ParametersDict = dict[str, ParametersData]
# Fix: both TypeVars were named "Parameters" (the TypeVar name must match the
# variable it is assigned to) and were bound to a nonexistent "Parent" class.
ParametersType = TypeVar("ParametersType", bound="ParametersFramework")
ADCParametersType = TypeVar("ADCParametersType", bound="ParametersFramework")
AnyParameters = Union[ParametersType, ADCParametersType]

83 

84 

85################################################################################ 

86# Globals 

87################################################################################ 

88 

89 

# Public API of this module: only the two concrete parameter classes are
# exported; `ParametersFramework` is the shared base and not part of `__all__`.
__all__: list[str] = ["Parameters", "ADCParameters"]

91 

92 

93################################################################################ 

94# Functions 

95################################################################################ 

96 

97 

def search_and_replace(
    file_path: Union[str, Path],
    search_pattern: str,
    replacement: str,
    out_path: Optional[Union[str, Path]] = None,
    backup: bool = True,
) -> None:
    """Searches for a pattern in a text file and replaces it with the replacement.

    Args:
        file_path (Union[str, Path]): File path of the file to replace the text pattern in.
        search_pattern (str): Pattern to search for.
        replacement (str): What to replace `search_pattern` with.
        out_path (Union[str, Path], optional): Path where to write the output file.
            If None is given, the original file will be overwritten. Defaults to None.
        backup (bool, optional): If backup is True and the original file is
            overwritten in place (`out_path` is None), the original file is
            renamed to filename.bak before it is overwritten. Defaults to True.

    Examples:
        >>> import tempfile
        >>> from encodermap.parameters.parameters import search_and_replace
        >>> from pathlib import Path
        >>> with tempfile.TemporaryDirectory() as td:
        ...     td = Path(td)
        ...     with open(td / "file.txt", mode="w+") as f:
        ...         f.write("This is a Test file.")
        ...         f.seek(0)
        ...         print(f.read())
        ...     search_and_replace(td / "file.txt", "Test", "new Test")
        ...     with open(td / "file.txt", mode="r") as f:
        ...         print(f.read())  # doctest: +SKIP
        This is a Test file.
        This is a new Test file.

    """
    file_path = Path(file_path)
    with open(file_path, "r") as f:
        file_data = f.read()

    file_data = file_data.replace(search_pattern, replacement)

    if out_path is None:
        # Overwrite the input file in place, optionally keeping the original
        # content around as a `.bak` sibling file.
        out_path = file_path
        if backup:
            backup_name = out_path.parent / (out_path.name + ".bak")
            out_path.rename(backup_name)
    else:
        # Fix: accept `out_path` as either str or Path (the annotation
        # previously allowed str, but only Path worked consistently).
        out_path = Path(out_path)

    with open(out_path, "w") as file:
        file.write(file_data)

147 

148 

149################################################################################ 

150# Classes 

151################################################################################ 

152 

153 

class ParametersFramework:
    """Class to work with Parameters in the form of dict or attributes.

    Parameters can be set via keyword args in init, set as
    instance attributes or read from disk. Can write parameters
    to disk in .yaml or .json format.

    Attributes:
        main_path (str): The main path of the parameter class.
        defaults (dict): The defaults passed into the Parent Class by the child
            classes `Parameters` and `ADCParameters`.

    """

    # Child classes overwrite this with their actual default values.
    _defaults = {}

    def __init__(self, **kwargs: ParametersData) -> None:
        """Instantiate the ParametersFramework class.

        This class is not meant to be used alone, but as a parent class for
        different parameters.

        Args:
            **kwargs: Arbitrary keyword arguments. If these arguments are not
                keys of the `defaults` dictionary, they will be ignored.
                Otherwise, they will overwrite the keys in the defaults dict.

        Raises:
            Exception: If the deprecated `n_epochs` keyword is used, or if
                the lengths of `n_neurons` and `activation_functions` are
                inconsistent.

        """
        self.main_path = os.getcwd()

        # overwrite class defaults with user input **kwargs
        for key, value in kwargs.items():
            if key not in self.defaults:
                if key == "n_epochs":
                    raise Exception(
                        "Parameter `n_epochs` and `n_steps_per_epoch` is "
                        "deprecated. Use `n_steps` instead."
                    )
                if key == "main_path":
                    setattr(self, key, value)
                    continue
                print(f"Dropping unknown dict entry for {{'{key}': {value}}}")
            else:
                setattr(self, key, value)
        # Fix: only validate the architecture when both keys were provided.
        # Previously, instantiating without `n_neurons`/`activation_functions`
        # raised an unhelpful KeyError here.
        if "n_neurons" in kwargs and "activation_functions" in kwargs:
            if len(kwargs["n_neurons"]) != len(kwargs["activation_functions"]) - 1:
                raise Exception(
                    f"Length of `n_neurons` and `activation_functions` (-1) does "
                    f"not match: {self.n_neurons}, {self.activation_functions}"
                )

    def to_dict(self) -> ParametersDict:
        """Represents parameters as a dictionary. Can be undone by `.from_dict()`.

        Returns:
            ParametersDict: The dict.

        """
        out = {}
        for k in self.defaults:
            out[k] = getattr(self, k)
        return out | {"main_path": self.main_path}

    def save(self, path: Optional[Union[str, Path]] = None) -> str:
        """Save parameters in json format or yaml format.

        Args:
            path (str, optional): Path where parameters should be saved. Possible
                extensions are '.json' and '.yaml'. If no path is given,
                main_path/parameters.json is used. Defaults to None.

        Returns:
            str: The path where the parameters were saved.

        """
        if path is not None:
            path = str(path)
        else:
            path = os.path.join(self.main_path, "parameters.json")
        fmt = "json"
        if os.path.isfile(path):
            # don't clobber an existing file; move it to a timestamped backup
            filename, extension = os.path.splitext(path)
            time = _datetime_windows_and_linux_compatible()
            os.rename(path, filename + "_back_" + time + extension)
            fmt = path.split(".")[-1]
        if fmt not in ["json", "yaml"]:
            raise OSError(
                f"Unrecognized extension .{fmt} in path {path}. "
                f"Please provide either '.json' or '.yaml'"
            )
        with open(path, "w") as f:
            if fmt == "json":
                json.dump(self.__dict__, f, indent=4, sort_keys=True)
            else:
                yaml.dump(self.__dict__, f, default_flow_style=True)
        return path

    @property
    def parameters(self) -> str:
        """str: A string that contains tabulated parameter values."""
        # The descriptions are parsed out of the class docstrings of
        # `Parameters` and `ADCParameters`; attribute lines are recognized
        # by the "):" marker of the Google-style "name (type):" convention.
        doc_p = Parameters.__doc__.split("Attributes:")[1].split("Examples:")[0]
        doc_p = (
            "\n".join(map(lambda x: x.lstrip(" "), doc_p.splitlines()))
            .lstrip("\n")
            .rstrip("\n\n")
            .splitlines()
        )
        doc = ADCParameters.__doc__.split("Attributes:")[1].split("Examples:")[0]
        doc = (
            "\n".join(map(lambda x: x.lstrip(" "), doc.splitlines()))
            .lstrip("\n")
            .rstrip("\n\n")
            .splitlines()
        )
        doc = doc_p + doc
        descr_dict = {}
        key = doc[0].split("):")[0].split()[0]
        descr = doc[0].split("): ")[-1]
        for line in doc[1:]:
            if "):" not in line:
                descr = descr + " " + line
            else:
                descr_dict[key] = descr
                key = line.split("):")[0].split()[0]
                descr = line.split("): ")[-1]
        else:
            descr_dict[key] = descr
        out = []
        for key, value in self.__dict__.items():
            if key in self.defaults:
                try:
                    out.append(
                        {
                            "Parameter": key,
                            "Value": value,
                            "Description": "\n".join(wrap(descr_dict[key], width=50)),
                        }
                    )
                except KeyError as e:
                    raise Exception(
                        f"There is no documentation about the parameter {key} in "
                        f"this class' docstring. Please fix."
                    ) from e
        return printTable(out, sep="\n")

    @classmethod
    def from_dict(
        cls, params: ParametersDict
    ) -> Union[ParametersType, ADCParametersType]:
        """Constructs a parameters class from a dictionary of values."""
        return cls(**params)

    @classmethod
    def from_file(
        cls, path: Union[str, Path]
    ) -> Union[ParametersType, ADCParametersType]:
        """Alternative constructor for ParameterFramework classes.

        Reads a file and sets the attributes based on that. If the parameter
        file was moved, the stored `main_path` is updated and the file is
        re-saved.

        Args:
            path (str): Path to the parameters.json or parameters.yaml file

        Returns:
            ParametersFramework: A new ParametersFramework class.

        """
        path = Path(path)
        with open(path, "r") as f:
            if path.suffix == ".json":
                params = json.load(f)
            elif path.suffix == ".yaml":
                params = yaml.load(f, Loader=yaml.FullLoader)
            else:
                # Fix: the original formatted `path.split('.')[1]` here, but
                # `pathlib.Path` has no `split` method; use `path.suffix`.
                raise ValueError(
                    f"The extension of the provided file should be `.json`, or "
                    f"`.yaml`. You provided {path.suffix}"
                )

        if "n_epochs" in params:
            print(
                "Detected old definition `n_epochs` and `n_steps_per_epoch`. "
                "I will change that to `n_steps` = `n_epochs` * `n_steps_per_epoch`."
            )
            params["n_steps"] = params["n_epochs"] * params["n_steps_per_epoch"]

            # also check soft start; its range was given in epochs
            if "cartesian_cost_scale_soft_start" in params:
                if params["cartesian_cost_scale_soft_start"] != (None, None) and params[
                    "cartesian_cost_scale_soft_start"
                ] != [None, None]:
                    a, b = params["cartesian_cost_scale_soft_start"]
                    a *= params["n_steps_per_epoch"]
                    b *= params["n_steps_per_epoch"]
                    params["cartesian_cost_scale_soft_start"] = (a, b)

            # fix the summary step and checkpoint_step
            params["summary_step"] *= params["n_steps_per_epoch"]
            params["checkpoint_step"] *= params["n_steps_per_epoch"]

            del params["n_epochs"]
            del params["n_steps_per_epoch"]

        # NOTE(review): this compares the *parents* of both paths, while
        # `load()` compares `main_path` with the file's directory directly —
        # confirm which behavior is intended.
        if Path(params["main_path"]).parent != path.parent:
            print(
                "Seems like the parameter file was moved to another directory. "
                "Parameter file is updated ..."
            )
            params["main_path"] = str(path.parent)

        newclass = cls(**params)
        newclass.save()
        return newclass

    @classmethod
    def load(cls, path: Union[str, Path]) -> Union[ParametersType, ADCParametersType]:
        """Loads the parameters saved in a .json or .yaml file into a new Parameter object.

        Args:
            path (str): Path to the parameters.json or parameters.yaml file

        Returns:
            ParametersFramework: A new ParametersFramework class.

        """
        path = Path(path)
        with open(path, "r") as f:
            if path.suffix == ".json":
                params = json.load(f)
            elif path.suffix == ".yaml":
                params = yaml.load(f, Loader=yaml.FullLoader)
            else:
                raise ValueError(
                    f"The extension of the provided file should be `.json`, or "
                    f"`.yaml`. You provided {path.suffix}"
                )

        if "n_epochs" in params:
            print(
                "Detected old definition `n_epochs` and `n_steps_per_epoch`. "
                "I will change that to `n_steps` = `n_epochs` * `n_steps_per_epoch`."
            )
            params["n_steps"] = params["n_epochs"] * params["n_steps_per_epoch"]
            del params["n_epochs"]
            del params["n_steps_per_epoch"]

        # check whether the parameters file has been moved and update it accordingly
        if params["main_path"] != os.path.dirname(path):
            # Fix: the message started with a stray "s=".
            print(
                "Seems like the parameter file was moved to another directory. "
                "Parameter file is updated ..."
            )
            search_and_replace(path, params["main_path"], os.path.dirname(path))
            # Fix: the original read from the already-closed handle `f`
            # instead of the freshly opened `file`, and called `path.split`,
            # which does not exist on `pathlib.Path`.
            with open(path, "r") as file:
                if path.suffix == ".json":
                    params = json.load(file)
                elif path.suffix == ".yaml":
                    params = yaml.load(file, Loader=yaml.FullLoader)
                else:
                    raise ValueError(
                        f"The extension of the provided file should be `.json`, or `.yaml`. You provided {path.suffix}"
                    )

        return cls(**params)

    def update(self, **kwargs: ParametersData) -> None:
        """Updates the values of `self`.

        Args:
            **kwargs: Arbitrary keyword arguments. If these arguments are not
                keys of the `self.defaults` dictionary, they will be ignored.
                Otherwise, they will overwrite the keys in the defaults dict.

        """
        for key, value in kwargs.items():
            if key not in self.__dict__.keys():
                print(f"Dropping unknown dict entry for {{'{key}': {value}}}")
            else:
                setattr(self, key, value)

    @property
    def defaults(self) -> ParametersDict:
        """ParametersDict: A deepcopy of the class-level default values."""
        return copy.deepcopy(self._defaults)

    def _setattr(self, dictionary: ParametersDict) -> None:
        """Sets all key-value pairs of `dictionary` as instance attributes.

        Args:
            dictionary (dict): The parameter values to set.

        """
        if "cartesian_cost_scale_soft_start" in dictionary:
            soft_start = dictionary["cartesian_cost_scale_soft_start"]
            # Fix: the original chained these checks with `or`, which made the
            # branch unconditionally true and raised TypeError (len(None))
            # when the value was None.
            if soft_start is not None and soft_start != (None, None):
                if len(soft_start) != 2:
                    raise Exception(
                        "Parameter cartesian_cost_scale_soft_start only takes a tuple of len 2."
                    )
        for key, value in dictionary.items():
            setattr(self, key, value)

    def __setitem__(self, key: str, value: ParametersData) -> None:
        """Implements the setitem method. Values can be set like so:

        Examples:
            >>> from encodermap import Parameters
            >>> p = Parameters()
            >>> p["center_cost_scale"] = 2.5
            >>> p["center_cost_scale"]
            2.5

        """
        if key == "cartesian_cost_scale_soft_start":
            # Fix: `or` → `and`; with `or`, setting this key to None raised
            # TypeError because len(None) was evaluated.
            if value is not None and value != (None, None):
                if len(value) != 2:
                    raise Exception(
                        "Parameter cartesian_cost_scale_soft_start only takes a tuple of len 2."
                    )
        setattr(self, key, value)

    def __getitem__(self, item: str) -> ParametersData:
        """Implements the getitem method. Get items with instance[key]."""
        return getattr(self, item)

    def _string_summary(self) -> str:
        """Creates a short summary of a parameter class. Additionally, adds info about non-standard values."""
        check_defaults = self.defaults
        diff_keys = list(
            filter(
                lambda x: not self.__dict__[x] == check_defaults[x],
                check_defaults.keys(),
            )
        )
        s = f"{self.__class__.__name__} class with 'main_path' at {self.main_path}."
        for d in diff_keys:
            s += f"\nNon-standard value of {d}: {self.__dict__[d]} (standard is {self.defaults[d]})"
        if diff_keys == []:
            s += " All parameters are set to default values."
        return s

    def __str__(self) -> str:
        return self._string_summary()

    def __repr__(self) -> str:
        return f"<{self._string_summary()} Object at 0x{id(self):02x}>"

503 

504 

class Parameters(ParametersFramework):
    """Class to hold Parameters for the Autoencoder

    Parameters can be set via keyword args while instantiating the class, set as
    instance attributes or read from disk. This class can write parameters
    to disk in .yaml or .json format.

    Attributes:
        defaults (dict): Classvariable dict that holds the defaults
            even when the current values might have changed.
        main_path (str): Defines a main path where the parameters and other things might be stored.
        n_neurons (list of int): List containing number of neurons for each layer up to the bottleneck layer.
            For example [128, 128, 2] stands for an autoencoder with the following architecture
            {i, 128, 128, 2, 128, 128, i} where i is the number of dimensions of the input data.
            These are Input/Output Layers that are not trained.
        activation_functions (list of str): List of activation function names as implemented in TensorFlow.
            For example: "relu", "tanh", "sigmoid" or "" to use no activation function.
            The encoder part of the network takes the activation functions
            from the list starting with the second element. The decoder part of
            the network takes the activation functions in reversed order starting with
            the second element from the back. For example ["", "relu", "tanh", ""] would
            result in an autoencoder with {"relu", "tanh", "", "tanh", "relu", ""} as
            sequence of activation functions.
        periodicity (float): Defines the distance between periodic walls for the inputs.
            For example 2pi for angular values in radians.
            All periodic data processed by EncoderMap must be wrapped to one periodic window.
            E.g. data with 2pi periodicity may contain values from -pi to pi or from 0 to 2pi.
            Set the periodicity to float("inf") for non-periodic inputs.
        learning_rate (float): Learning rate used by the optimizer.
        n_steps (int): Number of training steps.
        batch_size (int): Number of training points used in each training step
        summary_step (int): A summary for TensorBoard is written every summary_step steps.
        checkpoint_step (int): A checkpoint is written every checkpoint_step steps.
        dist_sig_parameters (tuple of floats): Parameters for the sigmoid
            functions applied to the high- and low-dimensional distances
            in the following order (sig_h, a_h, b_h, sig_l, a_l, b_l)
        distance_cost_scale (int): Adjusts how much the distance based metric is weighted in the cost function.
        auto_cost_scale (int): Adjusts how much the autoencoding cost is weighted in the cost function.
        auto_cost_variant (str): defines how the auto cost is calculated. Must be one of:
            * `mean_square`
            * `mean_abs`
            * `mean_norm`
        center_cost_scale (float): Adjusts how much the centering cost is weighted in the cost function.
        l2_reg_constant (float): Adjusts how much the L2 regularisation is weighted in the cost function.
        gpu_memory_fraction (float): Specifies the fraction of gpu memory blocked.
            If set to 0, memory is allocated as needed.
        analysis_path (str): A path that can be used to store analysis
        id (str): Can be any name for the run. Might be useful for example for
            specific analysis for different data sets.
        model_api (str): A string defining the API to be used to build the keras model.
            Defaults to `sequential`. Possible strings are:
            * `functional` will use keras' functional API.
            * `sequential` will define a keras Model, containing two other models with the Sequential API.
                These two models are encoder and decoder.
            * `custom` will create a custom Model where even the layers are custom.
        loss (str): A string defining the loss function.
            Defaults to `emap_cost`. Possible losses are:
            * `reconstruction_loss` will try to train output == input
            * `mse`: Returns a mean squared error loss.
            * `emap_cost` is the EncoderMap loss function. Depending on the class `Autoencoder`,
                `Encodermap, `ADCAutoencoder`, different contributions are used for a combined loss.
                Autoencoder uses auto_cost, reg_cost, center_cost.
                EncoderMap class adds sigmoid_loss.
        batched (bool): Whether the dataset is batched or not.
        training (str): A string defining what kind of training is performed when autoencoder.train() is called.
            * `auto` does a regular model.compile() and model.fit() procedure.
            * `custom` uses gradient tape and calculates losses and gradients manually.
        tensorboard (bool): Whether to print tensorboard information. Defaults to False.
        seed (Union[int, None]): Fixes the state of all operations using random numbers. Defaults to None.
        current_training_step (int): The current training step. Aids in reloading of models.
        write_summary (bool): If True writes a summary.txt of the models into main_path
            if `tensorboard` is True, summaries will also be written.
        trainable_dense_to_sparse (bool): When using different topologies to train
            the AngleDihedralCartesianEncoderMap, some inputs might be sparse,
            which means, they have missing values. Creating a dense input is done
            by first passing these sparse tensors through `tf.keras.layers.Dense`
            layers. These layers have trainable weights, and if this parameter
            is True, these weights will be changed by the optimizer.
        using_hypercube (bool): This parameter is not meant to be set by the user.
            It allows us to print better error messages when re-loading and
            re-training a model. It contains a boolean whether a model has been
            trained on the hypercube example data. If your data is 4-dimensional
            and you reload a model and forget to provide your data, the model
            will happily train with the hypercube (and not your) data. This variable
            implements a check.

    Examples:
        >>> import encodermap as em
        >>> import tempfile
        >>> from pathlib import Path
        ...
        >>> with tempfile.TemporaryDirectory() as td:
        ...     td = Path(td)
        ...     p = em.Parameters()
        ...     print(p.auto_cost_variant)
        ...     savepath = p.save(td / "parameters.json")
        ...     print(savepath)
        ...     new_params = em.Parameters.from_file(td / "parameters.json")
        ...     print(new_params.main_path)  # doctest: +SKIP
        mean_abs
        /tmp...parameters.json
        seems like the parameter file was moved to another directory. Parameter file is updated ...
        /home...

    """

    _defaults = dict(
        n_neurons=[128, 128, 2],
        activation_functions=["", "tanh", "tanh", ""],
        periodicity=2 * pi,
        learning_rate=0.001,
        n_steps=1000,
        batch_size=256,
        summary_step=10,
        checkpoint_step=5000,
        dist_sig_parameters=(4.5, 12, 6, 1, 2, 6),
        distance_cost_scale=500,
        auto_cost_scale=1,
        auto_cost_variant="mean_abs",
        center_cost_scale=0.0001,
        l2_reg_constant=0.001,
        gpu_memory_fraction=0,
        analysis_path="",
        id="",
        model_api="sequential",
        loss="emap_cost",
        training="auto",
        batched=True,
        tensorboard=False,
        seed=None,
        current_training_step=0,
        write_summary=False,
        trainable_dense_to_sparse=False,
        using_hypercube=False,
    )

    def __init__(self, **kwargs: ParametersData) -> None:
        """Instantiate the Parameters class

        Takes a dict as input and overwrites the class defaults. The dict is directly
        stored as an attribute and can be accessed via instance attributes.

        Args:
            **kwargs (dict): Dict containing values. If unknown keys are passed they will be dropped.

        """
        # A stray `defaults` key (e.g. from a saved file) must not reach the
        # parent constructor; the instance property provides it instead.
        kwargs.pop("defaults", None)
        super().__init__(**{**self.defaults, **kwargs})

    @classmethod
    def defaults_description(cls) -> str:
        """str: A string that contains tabulated default parameter values."""
        doc = cls.__doc__.split("Attributes:")[1].split("Examples:")[0]
        doc = (
            "\n".join(map(lambda x: x.lstrip(" "), doc.splitlines()))
            .lstrip("\n")
            .rstrip("\n\n")
            .splitlines()
        )
        descr_dict = {}
        key = doc[0].split("):")[0].split()[0]
        descr = doc[0].split("): ")[-1]
        for line in doc[1:]:
            if "):" not in line:
                descr = descr + " " + line
            else:
                descr_dict[key] = descr
                key = line.split("):")[0].split()[0]
                descr = line.split("): ")[-1]
        else:
            descr_dict[key] = descr

        out = []
        # Fix: `defaults` is an instance-level property, so `cls.defaults`
        # yields the property object itself ('property' objects have no
        # `.items()`); the class-level `_defaults` dict holds the values.
        for key, value in cls._defaults.items():
            out.append(
                {
                    "Parameter": key,
                    "Default Value": value,
                    "Description": "\n".join(wrap(descr_dict[key], width=50)),
                }
            )
        return printTable(out, sep="\n")

689 

690 

class ADCParameters(ParametersFramework):
    """This is the parameter object for the AngleDihedralCartesianEncoder.
    It holds all the parameters that the Parameters object includes, plus the following attributes:

    Attributes:
        track_clashes (bool): Whether to track the number of clashes during
            training. The average number of clashes is the average number of
            distances in the reconstructed cartesian coordinates with a distance
            smaller than 1 (nm). Defaults to False.
        track_RMSD (bool): Whether to track the RMSD of the input and reconstructed
            cartesians during training. The RMSDs are computed along the batch
            by minimizing the .. math::
                \\text{RMSD}(\\mathbf{x}, \\mathbf{x}^{\\text{ref}}) = \\min_{\\mathsf{R}, \\mathbf{t}} %
                \\sqrt{\\frac{1}{N} \\sum_{i=1}^{N} \\left[ %
                (\\mathsf{R}\\cdot\\mathbf{x}_{i}(t) + \\mathbf{t}) - \\mathbf{x}_{i}^{\\text{ref}} \\right]^{2}}
            This results in n RMSD values, where n is the size of the batch.
            A mean RMSD of this batch and the values for this batch will be logged
            to tensorboard.
        cartesian_pwd_start (int): Index of the first atom to use for the pairwise
            distance calculation.
        cartesian_pwd_stop (int): Index of the last atom to use for the pairwise
            distance calculation.
        cartesian_pwd_step (int): Step for the calculation of pairwise
            distances. E.g. for a chain of atoms N-C_a-C-N-C_a-C...
            cartesian_pwd_start=1 and cartesian_pwd_step=3 will result
            in using all C-alpha atoms for the pairwise distance calculation.
        use_backbone_angles (bool): Allows to define whether backbone bond
            angles should be learned (True) or if instead mean
            values should be used to generate conformations (False).
        use_sidechains (bool): Whether sidechain dihedrals should be passed
            through the autoencoder.
        angle_cost_scale (int): Adjusts how much the angle cost is weighted
            in the cost function.
        angle_cost_variant (str): Defines how the angle cost is calculated.
            Must be one of:
            * "mean_square"
            * "mean_abs"
            * "mean_norm".
        angle_cost_reference (int): Can be used to normalize the angle cost with
            the cost of same reference model (dummy).
        dihedral_cost_scale (int): Adjusts how much the dihedral cost is weighted
            in the cost function.
        dihedral_cost_variant (str): Defines how the dihedral cost is calculated.
            Must be one of:
            * "mean_square"
            * "mean_abs"
            * "mean_norm".
        dihedral_cost_reference (int): Can be used to normalize the dihedral
            cost with the cost of same reference model (dummy).
        side_dihedral_cost_scale (int): Adjusts how much the side dihedral cost
            is weighted in the cost function.
        side_dihedral_cost_variant (str): Defines how the side dihedral cost
            is calculated. Must be one of:
            * "mean_square"
            * "mean_abs"
            * "mean_norm".
        side_dihedral_cost_reference (int): Can be used to normalize the side
            dihedral cost with the cost of same reference model (dummy).
        cartesian_cost_scale (int): Adjusts how much the cartesian cost is
            weighted in the cost function.
        cartesian_cost_scale_soft_start (tuple): Allows to slowly turn on the
            cartesian cost. Must be a tuple with
            (start, end) or (None, None) If begin and end are given,
            cartesian_cost_scale will be increased linearly in the
            given range.
        cartesian_cost_variant (str): Defines how the cartesian cost is calculated.
            Must be one of:
            * "mean_square"
            * "mean_abs"
            * "mean_norm".
        cartesian_cost_reference (int): Can be used to normalize the cartesian
            cost with the cost of same reference model (dummy).
        cartesian_dist_sig_parameters (tuple of floats): Parameters for the
            sigmoid functions applied to the high- and low-dimensional
            distances in the following order (sig_h, a_h, b_h, sig_l, a_l, b_l).
        cartesian_distance_cost_scale (int): Adjusts how much the cartesian
            distance cost is weighted in the cost function.
        multimer_training (Any): Experimental feature.
        multimer_topology_classes (Any): Experimental feature.
        multimer_connection_bridges (Any): Experimental feature.
        multimer_lengths (Any): Experimental feature.
        reconstruct_sidechains (bool): Whether to also reconstruct sidechains.

    Examples:
        >>> import encodermap as em
        >>> import tempfile
        >>> from pathlib import Path
        ...
        >>> with tempfile.TemporaryDirectory() as td:
        ...     td = Path(td)
        ...     p = em.Parameters()
        ...     print(p.auto_cost_variant)
        ...     savepath = p.save(td / "parameters.json")
        ...     print(savepath)
        ...     new_params = em.Parameters.from_file(td / "parameters.json")
        ...     print(new_params.main_path)  # doctest: +SKIP
        mean_abs
        /tmp...parameters.json
        seems like the parameter file was moved to another directory. Parameter file is updated ...
        /home...

    """

    _defaults = dict(
        Parameters._defaults,
        **dict(
            track_clashes=False,
            track_RMSD=False,
            model_api="functional",  # overwrite the main class. Functional allows multiple in and outputs.
            cartesian_pwd_start=None,
            cartesian_pwd_stop=None,
            cartesian_pwd_step=None,
            use_backbone_angles=False,
            use_sidechains=False,
            angle_cost_scale=0,
            angle_cost_variant="mean_abs",
            angle_cost_reference=1,
            dihedral_cost_scale=1,
            dihedral_cost_variant="mean_abs",
            dihedral_cost_reference=1,
            side_dihedral_cost_scale=0.5,
            side_dihedral_cost_variant="mean_abs",
            side_dihedral_cost_reference=1,
            cartesian_cost_scale=1,
            cartesian_cost_scale_soft_start=(None, None),  # begin, end
            cartesian_cost_variant="mean_abs",
            cartesian_cost_reference=1,
            cartesian_dist_sig_parameters=Parameters._defaults["dist_sig_parameters"],
            cartesian_distance_cost_scale=1,
            auto_cost_scale=None,
            distance_cost_scale=None,
            multimer_training=None,
            multimer_topology_classes=None,
            multimer_connection_bridges=None,
            multimer_lengths=None,
            reconstruct_sidechains=False,
        ),
    )

    def __init__(self, **kwargs: ParametersData) -> None:
        """Instantiate the ADCParameters class

        Takes a dict as input and overwrites the class defaults. The dict is directly
        stored as an attribute and can be accessed via instance attributes.

        Args:
            **kwargs (dict): Dict containing values. If unknown values are passed they will be dropped.

        Raises:
            Exception: If `cartesian_cost_scale_soft_start` is neither
                None, (None, None), nor a 2-tuple.

        """
        if "cartesian_cost_scale_soft_start" in kwargs:
            soft_start = kwargs["cartesian_cost_scale_soft_start"]
            # Fix: the original chained these checks with `or`, which made
            # the branch unconditionally true and raised TypeError
            # (len(None)) when the value was None.
            if soft_start is not None and soft_start != (None, None):
                if len(soft_start) != 2:
                    raise Exception(
                        "Parameter cartesian_cost_scale_soft_start only takes a tuple of len 2."
                    )
        # A stray `defaults` key (e.g. from a saved file) must not reach the
        # parent constructor; the instance property provides it instead.
        kwargs.pop("defaults", None)
        super().__init__(**{**self.defaults, **kwargs})

    @classmethod
    def defaults_description(cls) -> str:
        """str: A string that contains tabulated default parameter values."""
        # Combine the attribute documentation of the base `Parameters` class
        # with this class' own additions.
        doc_p = Parameters.__doc__.split("Attributes:")[1].split("Examples:")[0]
        doc_p = (
            "\n".join(map(lambda x: x.lstrip(" "), doc_p.splitlines()))
            .lstrip("\n")
            .rstrip("\n\n")
            .splitlines()
        )
        doc = cls.__doc__.split("Attributes:")[1].split("Examples:")[0]
        doc = (
            "\n".join(map(lambda x: x.lstrip(" "), doc.splitlines()))
            .lstrip("\n")
            .rstrip("\n\n")
            .splitlines()
        )
        doc = doc_p + doc
        descr_dict = {}
        key = doc[0].split("):")[0].split()[0]
        descr = doc[0].split("): ")[-1]
        for line in doc[1:]:
            if "):" not in line:
                descr = descr + " " + line
            else:
                descr_dict[key] = descr
                key = line.split("):")[0].split()[0]
                descr = line.split("): ")[-1]
        else:
            descr_dict[key] = descr

        out = []
        # Fix: `defaults` is an instance-level property, so `cls.defaults`
        # yields the property object itself ('property' objects have no
        # `.items()`); the class-level `_defaults` dict holds the values.
        for key, value in cls._defaults.items():
            out.append(
                {
                    "Parameter": key,
                    "Default Value": value,
                    "Description": "\n".join(wrap(descr_dict[key], width=50)),
                }
            )
        return printTable(out, sep="\n")