Coverage for encodermap/misc/saving_loading_models.py: 6%

355 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/saving_loading_models.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22""" 

23Implementation of saving and loading models. 

24 

25 

26""" 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32# Future Imports at the top 

33from __future__ import annotations 

34 

35# Standard Library Imports 

36import copy 

37import os 

38import re 

39import shutil 

40import warnings 

41from collections.abc import Callable 

42from copy import deepcopy 

43from glob import glob 

44from pathlib import Path 

45from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload 

46 

47# Third Party Imports 

48import numpy as np 

49import tensorflow as tf 

50 

51# Encodermap imports 

52from encodermap.misc.misc import _datetime_windows_and_linux_compatible, run_path 

53from encodermap.parameters.parameters import ADCParameters, Parameters 

54 

55 

56################################################################################ 

57# Typing 

58################################################################################ 

59 

60 

61if TYPE_CHECKING: 

62 # Third Party Imports 

63 import numpy as np 

64 

65 # Encodermap imports 

66 from encodermap.autoencoder.autoencoder import AutoencoderClass 

67 from encodermap.trajinfo.info_all import TrajEnsemble 

68 

69 

70################################################################################ 

71# Globals 

72################################################################################ 

73 

74 

75__all__: list[str] = ["save_model", "load_model"] 

76 

77 

78################################################################################ 

79# Utils 

80################################################################################ 

81 

82 

83def _change_setting_inform_user( 

84 p: Union[Parameters, ADCParameters], 

85 setting: str, 

86 value: Any, 

87 parameters_file: Path, 

88 compat: bool = False, 

89) -> None: 

90 """Changes a setting in a parameter file and informs the user with a print message. 

91 

92 Args: 

93 p (Union[Parameters, ADCParameters]): An instance of the Parameters class. 

94 Either `Parameters`, or `ADCParameters`. 

95 setting (str): The setting to be changed. 

96 value (Any): The new value of the setting. 

97 parameters_file (Path): The file in which to change the setting. 

98 compat (bool): When loading old .model files and some parameters are 

99 inferred from this file. 

100 

101 """ 

102 curr = getattr(p, setting) 

103 if curr != value: 

104 setattr(p, setting, value) 

105 msg = ( 

106 f"In the parameters file {parameters_file}, the parameter '" 

107 f"{setting}' is set to '{curr}', but the architecture of the model " 

108 f"being loaded requires this parameter to be {value}. This parameter " 

109 f"file might not belong to the model you're trying to load." 

110 ) 

111 print(msg) 

112 if not parameters_file.is_file(): 

113 return 

114 

115 if not compat: 

116 msg += ( 

117 f"This tensorflow model was saved in the now deprecated .model " 

118 f"format. Since moving to the new .keras files, some changes have " 

119 f"been made to how parameters are saved. Some parameters can " 

120 f"be inferred from the old .model files by setting the `compat` " 

121 f"argument to True. This will create a backup of this parameters " 

122 f"file ({parameters_file}) and try to create a new one." 

123 ) 

124 raise Exception(msg) 

125 else: 

126 msg += ( 

127 f"I will backup the {parameters_file} and set appropriate values to" 

128 f"a new parameters.json." 

129 ) 

130 return 

131 

132 

133def _model_sort_key(model_name: str) -> int: 

134 """Returns numerical values based on whether `model_name` contains substrings. 

135 

136 Args: 

137 model_name (str): The filepath to the saved model. 

138 

139 Returns: 

140 int: Returns 0 for 'encoder', 1 for 'decoder', 2 for everything else. 

141 

142 """ 

143 x = Path(model_name).name 

144 return 0 if "encoder" in x else (1 if "decoder" in x else 2) 

145 

146 

def _load_list_of_models(
    models: list[str],
    custom_objects: Optional[dict[str, Callable]] = None,
) -> list[tf.keras.Model]:
    """Load the models supplied in `models` using keras.

    Args:
        models (list[str]): The paths of the models to be loaded.
        custom_objects (Optional[dict[str, Callable]]): Passed through to
            `tf.keras.models.load_model`. Defaults to None.

    Returns:
        list[tf.keras.Model]: The loaded models, in the same order as `models`.

    """
    loaded = []
    for model_path in models:
        loaded.append(
            tf.keras.models.load_model(model_path, custom_objects=custom_objects)
        )
    return loaded

160 

161 

def filter_greater_than(step: int) -> Callable:
    """Returns a function that can used for filtering.

    Examples:
        >>> from encodermap.misc.saving_loading_models import filter_greater_than
        >>> test = ["one_1.keras", "two_2.keras", "three_3.keras", "four_4.keras"]
        >>> list(filter(filter_greater_than(3), test))
        ['three_3.keras', 'four_4.keras']

    Args:
        step (int): All files containing this step number or more will not
            be removed from the Sequence by the builtin filter function.

    """
    # Compile once so repeated calls of the closure don't pay the regex
    # cache lookup.
    number_pattern = re.compile(r"\d+")

    def closure(path: Path) -> bool:
        r"""The closure of the `filter_greater_than` function.

        Takes a `pathlib.Path` and extracts the last number using regexp (\d+).
        Returns True, if this number is equal or greater than `step`.

        Args:
            path (Path): The path to use.

        Returns:
            bool: Whether the last number in `path` is equal or greater than step.

        """
        last_number = int(number_pattern.findall(str(path))[-1])
        return last_number >= step

    return closure

194 

195 

196################################################################################ 

197# Functions 

198################################################################################ 

199 

200 

def save_model(
    model: tf.keras.Model,
    main_path: Union[str, Path],
    inp_class_name: Optional[str] = None,
    step: Optional[int] = None,
    print_message: bool = False,
) -> Path:
    """Saves a model in the portable .keras format.

    Args:
        model (tf.keras.models.Model): The keras model to save. If the
            keras model has the attribute 'encoder_model' the encoder_model
            will be saved separately. The same with the attribute 'decoder_model'.
        main_path (Union[str, Path]): Which directory to save the model to.
            If step is None, the name will be saved_model_{time}.keras, where
            time is a current ISO-8601 formatted string.
        inp_class_name (Optional[str]): Name of the autoencoder class used in
            the printed loading instructions. Defaults to None.
        step (Optional[int]): Can be None, in which case the model will be
            saved using the current time. Otherwise, the step argument will be
            used like so: saved_model_{step}.keras. Defaults to None.
        print_message (bool): Whether to print a message after saving the model.
            Defaults to False.

    Returns:
        Path: The path, where the model was saved.

    """
    main_path = Path(main_path)
    assert main_path.is_dir(), f"Please provide a directory as `main_path`."

    # Tag the filenames either with the training step or, lacking one, with
    # a filesystem-safe timestamp.
    tag = _datetime_windows_and_linux_compatible() if step is None else step
    fname = main_path / f"saved_model_{tag}.keras"
    encoder_name = main_path / f"saved_model_{tag}_encoder.keras"
    decoder_name = main_path / f"saved_model_{tag}_decoder.keras"

    if print_message:
        if inp_class_name is not None:
            print(
                f"Saving the model to {fname}. Use `em.{inp_class_name}.from_checkpoint('{main_path}')` "
                f"to load the most recent model, or `em.{inp_class_name}.from_checkpoint('{fname}')` "
                f"to load the model with specific weights.."
            )
        else:
            print(f"Saving the model to {fname}.")
    model.save(fname)

    # Subclassed encoder/decoder submodels are written next to the full model
    # so they can be reloaded independently.
    if hasattr(model, "encoder_model"):
        if print_message:
            print(
                f"This model has a subclassed encoder, which can be loaded inde"
                f"pendently. Use `tf.keras.load_model('{encoder_name}')` to load "
                f"only this model."
            )
        model.encoder_model.save(encoder_name)

    if hasattr(model, "decoder_model"):
        if print_message:
            print(
                f"This model has a subclassed decoder, which can be loaded inde"
                f"pendently. Use `tf.keras.load_model('{decoder_name}')` to load "
                f"only this model."
            )
        model.decoder_model.save(decoder_name)

    return fname

269 

270 

def sort_model_files_with_timestr(path: Path) -> int:
    """Returns -1 for all files that have an ISO time in their filename and
    other numbers for files with numbers in their names. When a file has
    multiple number in its name '/path/to20/directory5/file_200.txt', the last
    number (200 in this case) will be used.

    Returns:
        int: The sorting value.

    """
    name = str(path)
    iso_pattern = (
        r".*\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]" r"\d:[0-5]\d([+-][0-2]\d:[0-5]\d|Z).*"
    )
    # Timestamp-named files sort before all step-numbered files.
    if re.match(iso_pattern, name) is not None:
        return -1
    return int(re.findall(r"\d+", name)[-1])

289 

290 

291def _find_and_sort_files( 

292 path: Path, 

293 globstr: str, 

294 sort_criterion: Literal["creation_time", "number"] = "number", 

295) -> list[Path]: 

296 """Helper function to list possible files in `path`, using a `globstr` and 

297 a `sort_criterion`. 

298 

299 Args: 

300 path (Path): The directory to start with. 

301 globstr (str): The globstring to use. Example "*saved_model*". 

302 sort_criterion (Literal["creation_time", "number"]): Files can 

303 either be sorted by the creation time (`Path.stat().st_ctime`) or 

304 by "number" in which case the file 'saved_model_500.keras' will 

305 appear before the file 'saved_model_1000.keras'. 

306 

307 Returns: 

308 list[Path]: A list of paths with the applied sorting. 

309 

310 """ 

311 if sort_criterion == "creation_time": 

312 sortkey = lambda x: x.stat().st_ctime 

313 elif sort_criterion == "number": 

314 sortkey = sort_model_files_with_timestr 

315 else: 

316 raise ValueError( 

317 f"The argument `sort_criterion` has to be 'creation_time', or " 

318 f"'number', you supplied {sort_criterion=}." 

319 ) 

320 l = list( 

321 sorted( 

322 filter( 

323 lambda x: "encoder" not in x.stem and "decoder" not in x.stem, 

324 path.glob(globstr), 

325 ), 

326 key=sortkey, 

327 ) 

328 ) 

329 return l 

330 

331 

@overload
def load_model(
    autoencoder: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble],
    sparse: bool,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]],
    print_message: bool,
    submodel: Literal[None],
    use_previous_model: bool,
    compat: bool,
) -> "AutoencoderClass": ...  # pragma: no doccheck

344 

345 

@overload
def load_model(
    autoencoder: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble],
    sparse: bool,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]],
    print_message: bool,
    submodel: Literal["encoder", "decoder"],
    use_previous_model: bool,
    compat: bool,
) -> tf.keras.Model: ...  # pragma: no doccheck

358 

359 

def load_model(
    autoencoder: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]] = None,
    print_message: bool = False,
    submodel: Optional[Literal["encoder", "decoder"]] = None,
    use_previous_model: bool = False,
    compat: bool = False,
) -> Union["AutoencoderClass", tf.keras.Model]:
    """Reloads a model from a checkpoint path.

    An implementation of loading the .keras files produced by EncoderMap.
    The old legacy .model files can still be loaded by this function. Or use
    the `load_model_legacy` function directly.

    Args:
        autoencoder (Union[None, "AutoencoderClass"]): Kept for
            legacy reasons. The old .model files had a list of "custom_objects"
            that was created by the autoencoder classes (`AutoEncoder`,
            `EncoderMap`. `AngleDihedralCartesianEncoderMap`) and needed to
            be supplied when reloading the models from disk. The new implementations
            use the `from_config` and `get_config` implementations of serializable
            keras objects and thus, the layers and cost functions can save their
            own state. Is only needed to load legacy models and can be None if a
            new .keras model is loaded.
        checkpoint_path (Union[str, Path]): Can be either the path to a .keras
            file or to a directory with multiple .keras files in which case, the
            most recent .keras file will be loaded.
        trajs (Optional[TrajEnsemble]): A `TrajEnsemble` class for when
            a `AngleDihedralCartesianEncoderMap` is reloaded.
        sparse (bool): This argument is also only needed to load legacy .model
            files. Defaults to False.
        dataset (Optional[Union[tf.data.Dataset, np.ndarray]]): A pass-through to
            the `dataset` argument of the autoencoder classes (`AutoEncoder`,
            `EncoderMap`. `AngleDihedralCartesianEncoderMap`) which all can take
            a tf.data.Dataset. Can be None, in which case, the data will be
            sourced differently (The `EncoderMap` class uses example data from
            a 4D hypercube, the `AngleDihedralCartesianEncoderMap` uses the
            data from the provided `trajs`.)
        print_message (bool): Whether to print some debug information. Defaults to False.
        submodel (Optional[Literal["encoder", "decoder"]]): Whether to only load
            a specific submodel. In order to use this argument, a file with
            the name *encoder.keras or *decoder.keras has to be in the
            in `checkpoint_path` specified directory.
        use_previous_model (bool): Whether to load a model from an intermediate
            checkpoint step.
        compat (bool): Whether to fix a parameters.json file that has been saved
            with the legacy .model file.

    Returns:
        Union[tf.keras.models.Model, "AutoencoderClass"]: A tf.keras.models.Model
            when you specified submodel. And an appropriate "AutoencoderClass"
            otherwise.

    """
    # Refuse to implicitly load a submodel-only file: the caller must opt in
    # via the `submodel` argument because the return type changes.
    if "decoder.keras" in str(checkpoint_path) and submodel is None:
        raise Exception(
            f"The file you provided is just the decoder submodel of the complete "
            f"{autoencoder.__name__} class. Loading submodels, requires "
            f"you to explicitly set the argument `submodel='decoder'`. Note, "
            f"that loading submodels will return a `tf.keras.models.Model` instead "
            f"of an instance of {autoencoder.__name__}."
        )
    if "encoder.keras" in str(checkpoint_path) and submodel is None:
        # Bug fix: this message previously read 'emcoder' and instructed the
        # user to pass the invalid value `submodel='emcoder'`.
        raise Exception(
            f"The file you provided is just the encoder submodel of the complete "
            f"{autoencoder.__name__} class. Loading submodels, requires "
            f"you to explicitly set the argument `submodel='encoder'`. Note, "
            f"that loading submodels will return a `tf.keras.models.Model` instead "
            f"of an instance of {autoencoder.__name__}."
        )
    checkpoint_path = Path(checkpoint_path)
    if ".model" in str(checkpoint_path):
        print("Will use the legacy loader for old '*.model' file.")
        return load_model_legacy(
            autoencoder_class=autoencoder,
            checkpoint_path=str(checkpoint_path),
            trajs=trajs,
            sparse=sparse,
            dataset=dataset,
            compat=compat,
        )
    if checkpoint_path.is_dir():
        # A directory was given: pick the newest full model inside it.
        possible_models = _find_and_sort_files(checkpoint_path, "*saved_model*")
        try:
            newest_model = possible_models[-1]
        except IndexError as e:
            raise Exception(
                f"{checkpoint_path=} has no .keras files: {possible_models=}"
            ) from e
        if ".model" not in str(newest_model):
            if print_message:
                print(
                    f"Found {len(possible_models)} in {checkpoint_path}. I will reload "
                    f"{newest_model}, because this is the newest file."
                )
            model = tf.keras.models.load_model(newest_model)
            checkpoint_path = newest_model
        else:
            # Only legacy .model files present in the directory.
            possible_old_models = possible_models = _find_and_sort_files(
                checkpoint_path, "*.model"
            )
            print("Will use the legacy loader for old '*.model' file.")
            return load_model_legacy(
                autoencoder_class=autoencoder,
                checkpoint_path=str(possible_old_models[-1]),
                trajs=trajs,
                sparse=sparse,
                dataset=dataset,
                compat=compat,
            )
    else:
        if ".model" in str(checkpoint_path):
            return load_model_legacy(
                autoencoder_class=autoencoder,
                checkpoint_path=str(checkpoint_path),
                trajs=trajs,
                sparse=sparse,
                dataset=dataset,
                compat=compat,
            )
        else:
            model = tf.keras.models.load_model(checkpoint_path)

    # maybe load just encoder or decoder, if requested
    if submodel is not None:
        if submodel == "encoder":
            encoder_file = checkpoint_path.parent / checkpoint_path.name.replace(
                ".keras", "_encoder.keras"
            )
            return tf.keras.models.load_model(encoder_file)
        elif submodel == "decoder":
            decoder_file = checkpoint_path.parent / checkpoint_path.name.replace(
                ".keras", "_decoder.keras"
            )
            return tf.keras.models.load_model(decoder_file)
        else:
            # Bug fix: the message previously contained the typo 'enocer'.
            raise ValueError(
                f"Argument `submodel` can only be either 'encoder' or 'decoder'. "
                f"You supplied: {submodel=}."
            )

    # load the params in the directory
    parameter_file = checkpoint_path.parent / "parameters.json"
    if not parameter_file.is_file() and autoencoder is not None:
        warnings.warn(
            f"There was no parameters.json file in the directory. {parameter_file.parent}. "
            f"I will load the model from the keras file, but I can't build a "
            f"{autoencoder} instance without the parameters."
        )
    if parameter_file.is_file():
        assert (
            autoencoder is not None
        ), f"Please provide a class inheriting from `Autoencoder`."
        # The presence of 'cartesian' in the parameters file distinguishes
        # ADCParameters from plain Parameters.
        if "cartesian" in parameter_file.read_text():
            p = ADCParameters.from_file(parameter_file)

            # make sure parameters and current training step are the same
            current_step = re.findall(r"\d+", str(checkpoint_path))
            backup_parameters = (
                parameter_file.parent
                / f"parameters_at_{p.current_training_step}_{_datetime_windows_and_linux_compatible()}_{parameter_file.suffix}"
            )
            # One or two numbers in the path: the last one is the step.
            # Three or more (e.g. a timestamp) means no step can be inferred.
            if len(current_step) < 3 and len(current_step) >= 1:
                current_step = int(current_step[-1])
                files_to_backup = list(
                    filter(
                        filter_greater_than(current_step),
                        checkpoint_path.parent.glob("*.keras"),
                    )
                )
                backup_files = [
                    f.parent
                    / (
                        f.stem
                        + f"_backup_from_{current_step}_{_datetime_windows_and_linux_compatible()}.keras"
                    )
                    for f in files_to_backup
                ]
                if current_step != p.current_training_step:
                    if not use_previous_model:
                        raise Exception(
                            f"The model was saved at step {current_step}, but the parameters "
                            f"file has its current step at {p.current_training_step}. "
                            f"It seems like you are reloading a model at an intermediate "
                            f"step. If you set the `use_previous_model` flag to True, "
                            f"I will backup the parameters file to {backup_parameters} and "
                            f"set the new training step so, that you can use/retrain this "
                            f"model."
                        )
                    else:
                        shutil.move(parameter_file, backup_parameters)
                        for f1, f2 in zip(files_to_backup, backup_files):
                            shutil.copyfile(f1, f2)
                        # NOTE(review): the ADC branch advances the step by one,
                        # while the plain-Parameters branch below keeps the raw
                        # step — confirm this asymmetry is intentional.
                        p.current_training_step = current_step + 1
                        p.n_steps = current_step + 1
                        p.save()

            # then load and return the autoencoder
            _using_hypercube = deepcopy(p.using_hypercube)
            out = autoencoder(
                trajs,
                parameters=p,
                read_only=False,
                dataset=dataset,
                model=model,
            )
            out.p.using_hypercube = _using_hypercube
            return out
        else:
            p = Parameters.from_file(parameter_file)

            # make sure parameters and current training step are the same
            current_step = re.findall(r"\d+", str(checkpoint_path))
            backup_parameters = (
                parameter_file.parent
                / f"parameters_at_{p.current_training_step}_{_datetime_windows_and_linux_compatible()}_{parameter_file.suffix}"
            )
            if len(current_step) < 3 and len(current_step) >= 1:
                current_step = int(current_step[-1])
                files_to_backup = list(
                    filter(
                        filter_greater_than(current_step),
                        checkpoint_path.parent.glob("*.keras"),
                    )
                )
                backup_files = [
                    f.parent
                    / (
                        f.stem
                        + f"_backup_from_{current_step}_{_datetime_windows_and_linux_compatible()}.keras"
                    )
                    for f in files_to_backup
                ]
                if current_step != p.current_training_step:
                    if not use_previous_model:
                        raise Exception(
                            f"The model was saved at step {current_step}, but the parameters "
                            f"file has its current step at {p.current_training_step}. "
                            f"It seems like you are reloading a model at an intermediate "
                            f"step. If you set the `use_previous_model` flag to True, "
                            f"I will backup the parameters file to {backup_parameters} and "
                            f"set the new training step so, that you can use/retrain this "
                            f"model."
                        )
                    else:
                        shutil.move(parameter_file, backup_parameters)
                        for f1, f2 in zip(files_to_backup, backup_files):
                            shutil.copyfile(f1, f2)
                        p.current_training_step = current_step
                        p.n_steps = current_step
                        p.save()

            # then load and return the autoencoder
            _using_hypercube = deepcopy(p.using_hypercube)
            out = autoencoder(
                parameters=p,
                train_data=dataset,
                read_only=False,
                model=model,
            )
            out.p.using_hypercube = _using_hypercube
            return out
    # No parameters.json: return the bare keras model.
    return model

627 

628 

def load_model_legacy(
    autoencoder_class: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]] = None,
    compat: bool = False,
) -> "AutoencoderClass":
    """Loads legacy .model files.

    Note:
        The .model format has been deprecated. Please update your saved models
        to the .keras format. You can use this function to rebuild a new
        model from the legacy .model files.

    Args:
        autoencoder_class (Union[None, AutoencoderClass]): A class of the in
            EncoderMap implemented autoencoder classes.
        checkpoint_path (Union[str, Path]): The path to the file to load. May
            contain a '*' wildcard, in which case the matched files are treated
            as separate encoder/decoder .model files.
        trajs (Optional[TrajEnsemble]): When loading an AngleDihedralCartesianEncoderMap,
            the trajectories need to be supplied to verify the input/output shapes
            of the model.
        sparse (bool): Whether the model contains sparse inputs.
        dataset (Optional[Union[tf.data.Dataset, np.ndarray]]): Either a tf.data.Dataset
            or a np.ndarray to infer the input shape from.
        compat (bool): Whether to backup and rewrite a parameters.json whose
            settings disagree with the architecture found on disk (passed on
            to `_change_setting_inform_user`).

    Returns:
        AutoencoderClass: A new instance of `autoencoder_class` whose freshly
            generated model carries the weights of the legacy model.

    """
    # Local Folder Imports
    from ..autoencoder import AngleDihedralCartesianEncoderMap
    from ..models import gen_functional_model, gen_sequential_model

    # A wildcard path matches separate encoder/decoder files; sort them so
    # the encoder comes first (see `_model_sort_key`).
    if "*" in str(checkpoint_path):
        checkpoint_path = list(
            sorted(map(Path, glob(str(checkpoint_path))), key=_model_sort_key)
        )
        parameters_file = checkpoint_path[0].parent / "parameters.json"
        found = re.findall(r"\d+", str(checkpoint_path[0].name))
    else:
        checkpoint_path = Path(checkpoint_path)
        parameters_file = checkpoint_path.parent / "parameters.json"
        found = re.findall(r"\d+", str(checkpoint_path.name))
    read_only = False

    # Detect sparse inputs from the first batch of the dataset.
    if dataset is not None:
        d = dataset.take(1)
        if any([isinstance(i, tf.sparse.SparseTensor) for i in d]):
            sparse = True

    # The last number in the filename is interpreted as the training step.
    try:
        step = int(found[-1])
    except IndexError:
        step = None
    except ValueError as e:
        raise Exception(f"{checkpoint_path=} {found=}") from e

    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        p_class = ADCParameters
    else:
        p_class = Parameters

    if not parameters_file.is_file():
        # No parameters on disk: fall back to defaults, infer the rest from
        # the model architecture, and do not write anything back.
        parameters = p_class()
        print(
            f"Couldn't find the parameter's file at {parameters_file}. "
            f"Will use default {parameters.__class__.__name__} and will "
            f"infer architecture parameters from the model on disk. Weights "
            f"from the old model will be transferred to the new model. "
            f"From now on, you can save the model as a new .keras file."
        )
        read_only = True
    else:
        parameters = p_class.from_file(parameters_file)
        print(
            "Weights from the old model will be transferred to the new "
            "model. From now on, you can save the model as a new .keras file."
        )

    # set the current training step
    if parameters.current_training_step == 0 and step is not None:
        parameters.current_training_step = step

    # make assumptions on data based on input shape
    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        old_model = tf.keras.models.load_model(checkpoint_path, compile=False)
        input_shape = old_model.input_shape
        encoder_input_shape = old_model.encoder_model.input_shape
    else:

        class OldModel:
            # Bare attribute container mimicking the layout of the old
            # functional model (encoder/decoder submodels).
            pass

        old_model = OldModel()
        # Non-ADC legacy checkpoints come as exactly two files: encoder first
        # (guaranteed by `_model_sort_key`), then decoder.
        assert len(checkpoint_path) == 2
        old_model.encoder = tf.keras.models.load_model(
            checkpoint_path[0], compile=False
        )
        old_model.decoder = tf.keras.models.load_model(
            checkpoint_path[1], compile=False
        )
        old_model.encoder_model = old_model.encoder
        old_model.decoder_model = old_model.decoder
        input_shape = old_model.encoder.input_shape[1]
        # NOTE(review): the following data/shape checks apply to the
        # sequential (non-ADC) model path — nesting reconstructed from the
        # coverage report; confirm against VCS history.
        if dataset is not None:
            d = dataset.take(1)
            for a in d:
                break
            d = a
            if d[0].shape[1] != input_shape:
                # Twice the data width suggests the old model used the
                # sin/cos split of periodic inputs.
                if d[0].shape[1] * 2 == input_shape:
                    _change_setting_inform_user(
                        parameters,
                        "periodicity",
                        np.pi,
                        parameters_file=parameters_file,
                        compat=compat,
                    )
                else:
                    raise Exception(
                        f"The shape of the provided data {d[0].shape[1]} does not "
                        f"match the input shape {input_shape} of the network. Are you sure, you "
                        f"are loading the correct checkpoint?"
                    )
            else:
                _change_setting_inform_user(
                    parameters,
                    "periodicity",
                    float("inf"),
                    parameters_file=parameters_file,
                    compat=compat,
                )
        if not parameters_file.is_file():
            print(
                f"I will create a model with an input shape of {input_shape}. For "
                f"periodic data (e.g., angles), this network will not work. If you "
                f"are reloading a network for periodic data please manually "
                f"create a parameters file at {parameters_file} with the appropriate "
                f"periodicity."
            )

    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        # The rank of the encoder input tells which feature sets the old
        # model was trained on (sidechains and/or backbone angles).
        if len(encoder_input_shape) == 3:
            _change_setting_inform_user(
                parameters,
                "use_sidechains",
                True,
                parameters_file=parameters_file,
                compat=compat,
            )
            _change_setting_inform_user(
                parameters,
                "use_backbone_angles",
                True,
                parameters_file=parameters_file,
                compat=compat,
            )
        elif len(encoder_input_shape) == 2:
            _change_setting_inform_user(
                parameters,
                "use_sidechains",
                False,
                parameters_file=parameters_file,
                compat=compat,
            )
            _change_setting_inform_user(
                parameters,
                "use_backbone_angles",
                True,
                parameters_file=parameters_file,
                compat=compat,
            )
        else:
            _change_setting_inform_user(
                parameters,
                "use_sidechains",
                False,
                parameters_file=parameters_file,
                compat=compat,
            )
            _change_setting_inform_user(
                parameters,
                "use_backbone_angles",
                False,
                parameters_file=parameters_file,
                compat=compat,
            )

    # make some assumptions about the n_neurons and activation_function parameters
    n_neurons = []
    activation_functions = []
    for layer in old_model.encoder_model.layers:
        if hasattr(layer, "units"):
            n_neurons.append(layer.units)
            act = layer.activation.__name__
            if act == "linear":
                # EncoderMap encodes a linear activation as an empty string.
                activation_functions.append("")
            else:
                activation_functions.append(act)
    # The parameters convention repeats the last activation at the front.
    activation_functions = [activation_functions[-1]] + activation_functions

    _change_setting_inform_user(
        parameters,
        "n_neurons",
        n_neurons,
        parameters_file=parameters_file,
        compat=compat,
    )
    _change_setting_inform_user(
        parameters,
        "activation_functions",
        activation_functions,
        parameters_file=parameters_file,
        compat=compat,
    )

    # Generate a fresh model from the (possibly corrected) parameters …
    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        new_model = gen_functional_model(
            input_shapes=tuple([v[1:] for v in old_model.input_shape]),
            parameters=parameters,
            sparse=sparse,
            write_summary=not read_only,
        )
    else:
        new_model = gen_sequential_model(
            input_shape=input_shape,
            parameters=parameters,
            sparse=sparse,
        )
    # … and transfer the legacy weights into it. The AttributeError fallback
    # covers the ADC case, where `old_model` has no bare `encoder` attribute.
    try:
        new_model.encoder_model.set_weights(old_model.encoder.get_weights())
    except AttributeError as e:
        new_model.encoder_model.set_weights(old_model.encoder_model.get_weights())
    except Exception as e:
        raise Exception(
            f"{[i.shape for i in new_model.encoder_model.get_weights()]=}\n\n"
            f"{[i.shape for i in old_model.encoder_model.get_weights()]=}"
        ) from e

    try:
        new_model.decoder_model.set_weights(old_model.decoder.get_weights())
    except AttributeError as e:
        new_model.decoder_model.set_weights(old_model.decoder_model.get_weights())
    except Exception as e:
        raise Exception(
            f"{[i.shape for i in new_model.decoder_model.get_weights()]=}\n\n"
            f"{[i.shape for i in old_model.decoder_model.get_weights()]=}"
        ) from e

    # Build the autoencoder instance and persist it in the new format
    # (unless we fell back to default parameters and went read-only).
    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        new_class = autoencoder_class(
            trajs=trajs,
            parameters=parameters,
            model=new_model,
            read_only=read_only,
            dataset=dataset,
        )
        if not read_only:
            new_class.save()
        return new_class
    else:
        new_class = autoencoder_class(
            parameters=parameters,
            train_data=dataset,
            model=new_model,
            read_only=read_only,
            sparse=sparse,
        )
        if not read_only:
            new_class.save()
        return new_class

900 

901 

def load_model_legacy_dep(
    autoencoder_class: AutoencoderClass,
    checkpoint_path: Union[str, Path],
    read_only: bool = True,
    overwrite_tensorboard_bool: bool = False,
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
    dataset: Optional[tf.data.Dataset] = None,
) -> AutoencoderClass:  # pragma: no doccheck
    """Reloads a tf.keras.Model from a checkpoint path.

    For this, an AutoencoderClass is necessary to provide the corresponding
    custom objects, such as loss functions.

    Args:
        autoencoder_class (AutoencoderClass): The autoencoder class used to
            rebuild the instance (custom objects such as loss functions are
            taken from an instance of this class).
        checkpoint_path (Union[str, Path]): Path to the saved model files.
            May contain a ``*`` wildcard, which is stripped for file lookups.
        read_only (bool): If True, the reloaded class will not write new
            files. Defaults to True.
        overwrite_tensorboard_bool (bool): If True, the ``tensorboard``
            setting of the reloaded parameters is forced to False.
            Defaults to False.
        trajs (Optional[TrajEnsemble]): Trajectories; selects the
            `ADCParameters` code path together with `dataset`.
        sparse (bool): Whether the model uses sparse inputs. Defaults to False.
        dataset (Optional[tf.data.Dataset]): Optional dataset handed to the
            rebuilt autoencoder class.

    Returns:
        AutoencoderClass: The reloaded autoencoder instance with its model
            attached.

    """
    # The signature accepts pathlib.Path, but everything below is
    # string-based (`in`, `replace`, `split`), so normalize once here.
    checkpoint_path = str(checkpoint_path)
    basedir = os.path.split(checkpoint_path)[0]

    # remove wildcard
    if "*" in checkpoint_path:
        cp_path = checkpoint_path.replace("*", "")
    else:
        cp_path = checkpoint_path

    # `step` is only read from disk in the trajs/dataset branch below.
    # Default it here so `out.cartesian_loss_step = step + 1` at the end
    # can never raise an UnboundLocalError.
    step = 0

    if trajs is None and dataset is None:
        # Plain EncoderMap-style reload: rebuild from regular Parameters.
        params = Parameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        directory = run_path("/".join(checkpoint_path.split("/")[:-1]))
        if directory != params.main_path:
            print(
                f"The saved model files have been moved from {params.main_path} "
                f"to {directory}. I will overwrite the 'main_path' attribute of "
                f"these parameters."
            )
            params = deepcopy(params)
            params.main_path = run_path(directory)
        if not hasattr(params, "write_summary"):
            # Older parameter files predate the `write_summary` attribute.
            params.write_summary = params.tensorboard
        out = autoencoder_class(parameters=params, read_only=read_only)
    else:
        # AngleDihedralCartesian-style reload: uses ADCParameters and tracks
        # the cartesian-loss training step in a sidecar text file.
        params = ADCParameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        if os.path.isfile(cp_path + "_current_step.txt"):
            with open(cp_path + "_current_step.txt", "r") as f:
                step = int(f.read())
        elif not read_only:
            # Retraining without a known step count is allowed but warned
            # about; `step` keeps its default of 0.
            print(
                "Cannot find cartesian loss step. Retraining of this model might "
                "lead to unexpected results."
            )
        directory = run_path("/".join(checkpoint_path.split("/")[:-1]))
        if directory != params.main_path:
            print(
                f"The saved model files have been moved from {params.main_path} "
                f"to {directory}. I will overwrite the 'main_path' attribute of "
                f"these parameters."
            )
            params = deepcopy(params)
            params.main_path = directory
        if not hasattr(params, "write_summary"):
            params.write_summary = params.tensorboard
        out = autoencoder_class(
            trajs,
            parameters=params,
            read_only=read_only,
            dataset=dataset,
        )
    out.p = _params

    # see if there are multiple models
    # NOTE: the module imports `from glob import glob`, so the callable is
    # `glob(...)` — `glob.glob(...)` would raise an AttributeError here.
    if "*" not in checkpoint_path:
        models = glob(checkpoint_path + "*/")
    else:
        models = glob(checkpoint_path + "/")

    # three different ways of loading models
    if len(models) == 2:
        # Sequential model saved as separate encoder/decoder files.
        models.sort(key=_model_sort_key)
        custom_objects = {fn.__name__: fn for fn in out.loss}
        models = _load_list_of_models(models, custom_objects=custom_objects)
        n_inputs = models[0].inputs[0].shape[-1]
        if _params.periodicity < float("inf"):
            # Periodic inputs are unpacked into (sin, cos) pairs, doubling
            # the input width; halve to recover the true input count.
            n_inputs = int(n_inputs / 2)
        model = SequentialModel(n_inputs, out.p, models)
    elif len(models) == 3:
        # Functional model saved as encoder + decoder + full model.
        print("Loading a functional model can take up to a minute.")
        models.sort(key=_model_sort_key)
        encoder_model_name = models[0]
        custom_objects = {fn.__name__: fn for fn in out.loss}
        models = _load_list_of_models(models, custom_objects=custom_objects)
        model = models[2]
        model.encoder_model = models[0]
        model.decoder_model = models[1]

        # Sanity-check the encoder's input signature against the saved
        # parameters; a mismatch usually indicates mis-sorted filenames.
        msg = None
        if not _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seemed to have failed. I expected the "
                    f"Encoder model to take an input of shape (None, Any), but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        elif _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seemed to have failed. I expected the "
                    f"Encoder model to take an input of shape [(None, Any), (None, Any)] but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        else:
            if len(models[0].input_shape) != 3:
                msg = (
                    f"Reloading the models seemed to have failed. I expected the "
                    f"Encoder model to take an input of shape [(None, Any), (None, Any), (None, Any)] but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        if msg is not None:
            raise Exception(msg)
    else:
        # Fallback: a single saved model directory/file.
        print("Model is neither Sequential, nor functional. I try to reload it.")
        custom_objects = {fn.__name__: fn for fn in out.loss}
        model = tf.keras.models.load_model(
            checkpoint_path, custom_objects=custom_objects
        )
        if hasattr(model, "encoder_model") and not hasattr(model, "encode"):
            # Deserialization can drop convenience methods; rebuild a
            # SequentialModel wrapper from the sub-models.
            print(
                "The loaded model lost its `encode` function. I will try to rebuild it."
            )

            models = [model.encoder_model, model.decoder_model]
            n_inputs = models[0].inputs[0].shape[-1]
            if _params.periodicity < float("inf"):
                n_inputs = int(n_inputs / 2)

            if sparse:
                # Third Party Imports
                from tensorflow.keras.layers import Dense, Input

                # Small dense helper network that densifies sparse inputs
                # before they reach the encoder.
                shape = n_inputs
                _input_layer = Input(
                    shape=(int(shape),),
                    sparse=True,
                )
                x = Dense(shape)(_input_layer)
                get_dense_model = tf.keras.Model(
                    inputs=_input_layer,
                    outputs=x,
                )
                model.get_dense_model = get_dense_model
            else:
                get_dense_model = None
            model = SequentialModel(
                n_inputs, out.p, models, sparse=sparse, get_dense_model=get_dense_model
            )

    out._model = model
    # NOTE(review): this checks `<prefix>_step.txt` while the step counter
    # was read from `<prefix>_current_step.txt` above — confirm whether the
    # filenames are intentionally different.
    if os.path.isfile(cp_path + "_step.txt"):
        out.cartesian_loss_step = step + 1
    return out