Coverage for encodermap/misc/saving_loading_models.py: 6%
355 statements
# -*- coding: utf-8 -*-
# encodermap/misc/saving_loading_models.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Implementation of saving and loading models."""
################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
import copy
import os
import re
import shutil
import warnings
from collections.abc import Callable
from copy import deepcopy
from glob import glob
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload

# Third Party Imports
import numpy as np
import tensorflow as tf

# Encodermap imports
from encodermap.misc.misc import _datetime_windows_and_linux_compatible, run_path
from encodermap.parameters.parameters import ADCParameters, Parameters


################################################################################
# Typing
################################################################################


if TYPE_CHECKING:
    # Third Party Imports
    import numpy as np

    # Encodermap imports
    from encodermap.autoencoder.autoencoder import AutoencoderClass
    from encodermap.trajinfo.info_all import TrajEnsemble


################################################################################
# Globals
################################################################################


__all__: list[str] = ["save_model", "load_model"]


################################################################################
# Utils
################################################################################


def _change_setting_inform_user(
    p: Union[Parameters, ADCParameters],
    setting: str,
    value: Any,
    parameters_file: Path,
    compat: bool = False,
) -> None:
    """Changes a setting in a parameter file and informs the user with a print message.

    Args:
        p (Union[Parameters, ADCParameters]): An instance of the Parameters class.
            Either `Parameters`, or `ADCParameters`.
        setting (str): The setting to be changed.
        value (Any): The new value of the setting.
        parameters_file (Path): The file in which to change the setting.
        compat (bool): Whether to infer this setting from an old .model file
            when loading it. Defaults to False.

    """
    curr = getattr(p, setting)
    if curr != value:
        setattr(p, setting, value)
        msg = (
            f"In the parameters file {parameters_file}, the parameter '"
            f"{setting}' is set to '{curr}', but the architecture of the model "
            f"being loaded requires this parameter to be {value}. This parameter "
            f"file might not belong to the model you're trying to load."
        )
        print(msg)
        if not parameters_file.is_file():
            return
        if not compat:
            msg += (
                f" This tensorflow model was saved in the now deprecated .model "
                f"format. Since moving to the new .keras files, some changes have "
                f"been made to how parameters are saved. Some parameters can "
                f"be inferred from the old .model files by setting the `compat` "
                f"argument to True. This will create a backup of this parameters "
                f"file ({parameters_file}) and try to create a new one."
            )
            raise Exception(msg)
        else:
            msg += (
                f" I will backup the {parameters_file} and set appropriate "
                f"values in a new parameters.json."
            )
            print(msg)
        return
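

# A minimal sketch of this helper's behavior (hypothetical file path; assumes
# the new value differs from the stored one). Because the parameters file does
# not exist on disk, the helper prints its warning and returns instead of
# raising:
#
#     >>> p = Parameters()
#     >>> _change_setting_inform_user(
#     ...     p, "periodicity", float("inf"), Path("does_not_exist.json")
#     ... )  # doctest: +SKIP
#     >>> p.periodicity  # doctest: +SKIP
#     inf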


def _model_sort_key(model_name: str) -> int:
    """Returns numerical values based on whether `model_name` contains substrings.

    Args:
        model_name (str): The filepath to the saved model.

    Returns:
        int: Returns 0 for 'encoder', 1 for 'decoder', 2 for everything else.

    """
    x = Path(model_name).name
    return 0 if "encoder" in x else (1 if "decoder" in x else 2)
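

# A minimal usage sketch (hypothetical filenames), showing the intended sort
# order of encoder, decoder, and full models:
#
#     >>> files = ["m.keras", "m_decoder.keras", "m_encoder.keras"]
#     >>> sorted(files, key=_model_sort_key)
#     ['m_encoder.keras', 'm_decoder.keras', 'm.keras']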


def _load_list_of_models(
    models: list[str],
    custom_objects: Optional[dict[str, Callable]] = None,
) -> list[tf.keras.Model]:
    """Load the models supplied in `models` using keras.

    Args:
        models (list[str]): The paths of the models to be loaded.
        custom_objects (Optional[dict[str, Callable]]): Custom objects (e.g.,
            loss functions) passed through to `tf.keras.models.load_model`.

    Returns:
        list[tf.keras.Model]: The loaded models.

    """
    return [
        tf.keras.models.load_model(x, custom_objects=custom_objects) for x in models
    ]


def filter_greater_than(step: int) -> Callable:
    """Returns a function that can be used for filtering.

    Examples:
        >>> from encodermap.misc.saving_loading_models import filter_greater_than
        >>> test = ["one_1.keras", "two_2.keras", "three_3.keras", "four_4.keras"]
        >>> list(filter(filter_greater_than(3), test))
        ['three_3.keras', 'four_4.keras']

    Args:
        step (int): All files whose last number is equal to or greater than
            this step are kept by the builtin filter function; the rest are
            filtered out.

    """

    def closure(path: Path) -> bool:
        r"""The closure of the `filter_greater_than` function.

        Takes a `pathlib.Path` and extracts the last number using regexp (\d+).
        Returns True, if this number is equal to or greater than `step`.

        Args:
            path (Path): The path to use.

        Returns:
            bool: Whether the last number in `path` is equal to or greater than `step`.

        """
        current_step = int(re.findall(r"\d+", str(path))[-1])
        return current_step >= step

    return closure


################################################################################
# Functions
################################################################################


def save_model(
    model: tf.keras.Model,
    main_path: Union[str, Path],
    inp_class_name: Optional[str] = None,
    step: Optional[int] = None,
    print_message: bool = False,
) -> Path:
    """Saves a model in the portable .keras format.

    Args:
        model (tf.keras.models.Model): The keras model to save. If the
            keras model has the attribute 'encoder_model' the encoder_model
            will be saved separately. The same with the attribute 'decoder_model'.
        main_path (Union[str, Path]): Which directory to save the model to.
            If `step` is None, the name will be saved_model_{time}.keras, where
            time is a current ISO-8601 formatted string.
        inp_class_name (Optional[str]): The name of the autoencoder class that
            requested the save. Only used to print a more helpful message.
            Defaults to None.
        step (Optional[int]): Can be None, in which case the model will be saved
            using the current time. Otherwise, the step argument will be used
            like so: saved_model_{step}.keras. Defaults to None.
        print_message (bool): Whether to print a message after saving the model.
            Defaults to False.

    Returns:
        Path: The path, where the model was saved.

    """
    main_path = Path(main_path)
    assert main_path.is_dir(), "Please provide a directory as `main_path`."
    if step is None:
        time = _datetime_windows_and_linux_compatible()
        fname = main_path / f"saved_model_{time}.keras"
        encoder_name = main_path / f"saved_model_{time}_encoder.keras"
        decoder_name = main_path / f"saved_model_{time}_decoder.keras"
    else:
        fname = main_path / f"saved_model_{step}.keras"
        encoder_name = main_path / f"saved_model_{step}_encoder.keras"
        decoder_name = main_path / f"saved_model_{step}_decoder.keras"

    if print_message:
        if inp_class_name is not None:
            print(
                f"Saving the model to {fname}. Use `em.{inp_class_name}.from_checkpoint('{main_path}')` "
                f"to load the most recent model, or `em.{inp_class_name}.from_checkpoint('{fname}')` "
                f"to load the model with specific weights."
            )
        else:
            print(f"Saving the model to {fname}.")
    model.save(fname)

    if hasattr(model, "encoder_model"):
        if print_message:
            print(
                f"This model has a subclassed encoder, which can be loaded "
                f"independently. Use `tf.keras.models.load_model('{encoder_name}')` "
                f"to load only this model."
            )
        model.encoder_model.save(encoder_name)

    if hasattr(model, "decoder_model"):
        if print_message:
            print(
                f"This model has a subclassed decoder, which can be loaded "
                f"independently. Use `tf.keras.models.load_model('{decoder_name}')` "
                f"to load only this model."
            )
        model.decoder_model.save(decoder_name)

    return fname
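

# A short usage sketch (hypothetical directory and model). With `step=500`,
# the file 'saved_model_500.keras' is written, plus separate encoder/decoder
# files if the model defines the 'encoder_model'/'decoder_model' attributes:
#
#     >>> saved = save_model(model, "runs/run0", step=500)  # doctest: +SKIP
#     >>> saved.name  # doctest: +SKIP
#     'saved_model_500.keras'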


def sort_model_files_with_timestr(path: Path) -> int:
    """Returns -1 for all files that have an ISO time in their filename and
    other numbers for files with numbers in their names. When a file has
    multiple numbers in its name ('/path/to20/directory5/file_200.txt'), the
    last number (200 in this case) will be used.

    Returns:
        int: The sorting value.

    """
    m = re.match(
        r".*\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d:[0-5]\d([+-][0-2]\d:[0-5]\d|Z).*",
        str(path),
    )
    if m is not None:
        return -1
    number = int(re.findall(r"\d+", str(path))[-1])
    return number
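

# Examples of the sorting values (hypothetical filenames): names containing an
# ISO-8601 timestamp sort first (-1), everything else sorts by its last number:
#
#     >>> sort_model_files_with_timestr(Path("saved_model_2024-12-31T16:54:00+01:00.keras"))
#     -1
#     >>> sort_model_files_with_timestr(Path("saved_model_500.keras"))
#     500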


def _find_and_sort_files(
    path: Path,
    globstr: str,
    sort_criterion: Literal["creation_time", "number"] = "number",
) -> list[Path]:
    """Helper function to list possible files in `path`, using a `globstr` and
    a `sort_criterion`.

    Args:
        path (Path): The directory to start with.
        globstr (str): The globstring to use. Example: "*saved_model*".
        sort_criterion (Literal["creation_time", "number"]): Files can
            either be sorted by creation time (`Path.stat().st_ctime`) or
            by "number", in which case the file 'saved_model_500.keras' will
            appear before the file 'saved_model_1000.keras'.

    Returns:
        list[Path]: A list of paths with the applied sorting.

    """
    if sort_criterion == "creation_time":
        sortkey = lambda x: x.stat().st_ctime
    elif sort_criterion == "number":
        sortkey = sort_model_files_with_timestr
    else:
        raise ValueError(
            f"The argument `sort_criterion` has to be 'creation_time', or "
            f"'number', you supplied {sort_criterion=}."
        )
    l = list(
        sorted(
            filter(
                lambda x: "encoder" not in x.stem and "decoder" not in x.stem,
                path.glob(globstr),
            ),
            key=sortkey,
        )
    )
    return l
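

# A sketch of how this helper is used by `load_model` (hypothetical directory):
# encoder/decoder files are filtered out, and the newest full model sorts last:
#
#     >>> files = _find_and_sort_files(Path("runs/run0"), "*saved_model*")  # doctest: +SKIP
#     >>> newest_model = files[-1]  # doctest: +SKIP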


@overload
def load_model(
    autoencoder: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble],
    sparse: bool,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]],
    print_message: bool,
    submodel: Literal[None],
    use_previous_model: bool,
    compat: bool,
) -> "AutoencoderClass": ...  # pragma: no doccheck


@overload
def load_model(
    autoencoder: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble],
    sparse: bool,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]],
    print_message: bool,
    submodel: Literal["encoder", "decoder"],
    use_previous_model: bool,
    compat: bool,
) -> tf.keras.Model: ...  # pragma: no doccheck


def load_model(
    autoencoder: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]] = None,
    print_message: bool = False,
    submodel: Optional[Literal["encoder", "decoder"]] = None,
    use_previous_model: bool = False,
    compat: bool = False,
) -> Union["AutoencoderClass", tf.keras.Model]:
    """Reloads a model from a checkpoint path.

    An implementation of loading the .keras files produced by EncoderMap.
    The old legacy .model files can still be loaded by this function. Or use
    the `load_model_legacy` function directly.

    Args:
        autoencoder (Union[None, "AutoencoderClass"]): Kept for
            legacy reasons. The old .model files had a list of "custom_objects"
            that was created by the autoencoder classes (`AutoEncoder`,
            `EncoderMap`, `AngleDihedralCartesianEncoderMap`) and needed to
            be supplied when reloading the models from disk. The new implementations
            use the `from_config` and `get_config` implementations of serializable
            keras objects and thus, the layers and cost functions can save their
            own state. Is only needed to load legacy models and can be None if a
            new .keras model is loaded.
        checkpoint_path (Union[str, Path]): Can be either the path to a .keras
            file or to a directory with multiple .keras files, in which case
            the most recent .keras file will be loaded.
        trajs (Optional[TrajEnsemble]): A `TrajEnsemble` class for when
            an `AngleDihedralCartesianEncoderMap` is reloaded.
        sparse (bool): This argument is also only needed to load legacy .model
            files. Defaults to False.
        dataset (Optional[Union[tf.data.Dataset, np.ndarray]]): A pass-through to
            the `dataset` argument of the autoencoder classes (`AutoEncoder`,
            `EncoderMap`, `AngleDihedralCartesianEncoderMap`), which all can take
            a tf.data.Dataset. Can be None, in which case the data will be
            sourced differently (the `EncoderMap` class uses example data from
            a 4D hypercube, the `AngleDihedralCartesianEncoderMap` uses the
            data from the provided `trajs`).
        print_message (bool): Whether to print some debug information. Defaults to False.
        submodel (Optional[Literal["encoder", "decoder"]]): Whether to only load
            a specific submodel. In order to use this argument, a file with
            the name *encoder.keras or *decoder.keras has to be in the
            directory specified by `checkpoint_path`.
        use_previous_model (bool): Whether to load a model from an intermediate
            checkpoint step.
        compat (bool): Whether to fix a parameters.json file that has been saved
            with the legacy .model file.

    Returns:
        Union[tf.keras.models.Model, "AutoencoderClass"]: A tf.keras.models.Model
            when you specified `submodel`, and an appropriate "AutoencoderClass"
            otherwise.

    """
    if "decoder.keras" in str(checkpoint_path) and submodel is None:
        raise Exception(
            f"The file you provided is just the decoder submodel of the complete "
            f"{autoencoder.__name__} class. Loading submodels requires "
            f"you to explicitly set the argument `submodel='decoder'`. Note, "
            f"that loading submodels will return a `tf.keras.models.Model` instead "
            f"of an instance of {autoencoder.__name__}."
        )
    if "encoder.keras" in str(checkpoint_path) and submodel is None:
        raise Exception(
            f"The file you provided is just the encoder submodel of the complete "
            f"{autoencoder.__name__} class. Loading submodels requires "
            f"you to explicitly set the argument `submodel='encoder'`. Note, "
            f"that loading submodels will return a `tf.keras.models.Model` instead "
            f"of an instance of {autoencoder.__name__}."
        )
    checkpoint_path = Path(checkpoint_path)
    if ".model" in str(checkpoint_path):
        print("Will use the legacy loader for old '*.model' file.")
        return load_model_legacy(
            autoencoder_class=autoencoder,
            checkpoint_path=str(checkpoint_path),
            trajs=trajs,
            sparse=sparse,
            dataset=dataset,
            compat=compat,
        )
    if checkpoint_path.is_dir():
        possible_models = _find_and_sort_files(checkpoint_path, "*saved_model*")
        try:
            newest_model = possible_models[-1]
        except IndexError as e:
            raise Exception(
                f"{checkpoint_path=} has no .keras files: {possible_models=}"
            ) from e
        if ".model" not in str(newest_model):
            if print_message:
                print(
                    f"Found {len(possible_models)} models in {checkpoint_path}. I will "
                    f"reload {newest_model}, because this is the newest file."
                )
            model = tf.keras.models.load_model(newest_model)
            checkpoint_path = newest_model
        else:
            possible_old_models = _find_and_sort_files(checkpoint_path, "*.model")
            print("Will use the legacy loader for old '*.model' file.")
            return load_model_legacy(
                autoencoder_class=autoencoder,
                checkpoint_path=str(possible_old_models[-1]),
                trajs=trajs,
                sparse=sparse,
                dataset=dataset,
                compat=compat,
            )
    else:
        if ".model" in str(checkpoint_path):
            return load_model_legacy(
                autoencoder_class=autoencoder,
                checkpoint_path=str(checkpoint_path),
                trajs=trajs,
                sparse=sparse,
                dataset=dataset,
                compat=compat,
            )
        else:
            model = tf.keras.models.load_model(checkpoint_path)

    # maybe load just encoder or decoder, if requested
    if submodel is not None:
        if submodel == "encoder":
            encoder_file = checkpoint_path.parent / checkpoint_path.name.replace(
                ".keras", "_encoder.keras"
            )
            return tf.keras.models.load_model(encoder_file)
        elif submodel == "decoder":
            decoder_file = checkpoint_path.parent / checkpoint_path.name.replace(
                ".keras", "_decoder.keras"
            )
            return tf.keras.models.load_model(decoder_file)
        else:
            raise ValueError(
                f"Argument `submodel` can only be either 'encoder' or 'decoder'. "
                f"You supplied: {submodel=}."
            )

    # load the params in the directory
    parameter_file = checkpoint_path.parent / "parameters.json"
    if not parameter_file.is_file() and autoencoder is not None:
        warnings.warn(
            f"There was no parameters.json file in the directory {parameter_file.parent}. "
            f"I will load the model from the keras file, but I can't build a "
            f"{autoencoder} instance without the parameters."
        )
    if parameter_file.is_file():
        assert (
            autoencoder is not None
        ), "Please provide a class inheriting from `Autoencoder`."
        # ADCParameters files contain settings with 'cartesian' in their names
        if "cartesian" in parameter_file.read_text():
            p = ADCParameters.from_file(parameter_file)

            # make sure parameters and current training step are the same
            current_step = re.findall(r"\d+", str(checkpoint_path))
            backup_parameters = (
                parameter_file.parent
                / f"parameters_at_{p.current_training_step}_{_datetime_windows_and_linux_compatible()}_{parameter_file.suffix}"
            )
            if len(current_step) < 3 and len(current_step) >= 1:
                current_step = int(current_step[-1])
                files_to_backup = list(
                    filter(
                        filter_greater_than(current_step),
                        checkpoint_path.parent.glob("*.keras"),
                    )
                )
                backup_files = [
                    f.parent
                    / (
                        f.stem
                        + f"_backup_from_{current_step}_{_datetime_windows_and_linux_compatible()}.keras"
                    )
                    for f in files_to_backup
                ]
                if current_step != p.current_training_step:
                    if not use_previous_model:
                        raise Exception(
                            f"The model was saved at step {current_step}, but the parameters "
                            f"file has its current step at {p.current_training_step}. "
                            f"It seems like you are reloading a model at an intermediate "
                            f"step. If you set the `use_previous_model` flag to True, "
                            f"I will backup the parameters file to {backup_parameters} and "
                            f"set the new training step so that you can use/retrain this "
                            f"model."
                        )
                    else:
                        shutil.move(parameter_file, backup_parameters)
                        for f1, f2 in zip(files_to_backup, backup_files):
                            shutil.copyfile(f1, f2)
                        p.current_training_step = current_step + 1
                        p.n_steps = current_step + 1
                        p.save()

            # then load and return the autoencoder
            _using_hypercube = deepcopy(p.using_hypercube)
            out = autoencoder(
                trajs,
                parameters=p,
                read_only=False,
                dataset=dataset,
                model=model,
            )
            out.p.using_hypercube = _using_hypercube
            return out
        else:
            p = Parameters.from_file(parameter_file)

            # make sure parameters and current training step are the same
            current_step = re.findall(r"\d+", str(checkpoint_path))
            backup_parameters = (
                parameter_file.parent
                / f"parameters_at_{p.current_training_step}_{_datetime_windows_and_linux_compatible()}_{parameter_file.suffix}"
            )
            if len(current_step) < 3 and len(current_step) >= 1:
                current_step = int(current_step[-1])
                files_to_backup = list(
                    filter(
                        filter_greater_than(current_step),
                        checkpoint_path.parent.glob("*.keras"),
                    )
                )
                backup_files = [
                    f.parent
                    / (
                        f.stem
                        + f"_backup_from_{current_step}_{_datetime_windows_and_linux_compatible()}.keras"
                    )
                    for f in files_to_backup
                ]
                if current_step != p.current_training_step:
                    if not use_previous_model:
                        raise Exception(
                            f"The model was saved at step {current_step}, but the parameters "
                            f"file has its current step at {p.current_training_step}. "
                            f"It seems like you are reloading a model at an intermediate "
                            f"step. If you set the `use_previous_model` flag to True, "
                            f"I will backup the parameters file to {backup_parameters} and "
                            f"set the new training step so that you can use/retrain this "
                            f"model."
                        )
                    else:
                        shutil.move(parameter_file, backup_parameters)
                        for f1, f2 in zip(files_to_backup, backup_files):
                            shutil.copyfile(f1, f2)
                        p.current_training_step = current_step
                        p.n_steps = current_step
                        p.save()

            # then load and return the autoencoder
            _using_hypercube = deepcopy(p.using_hypercube)
            out = autoencoder(
                parameters=p,
                train_data=dataset,
                read_only=False,
                model=model,
            )
            out.p.using_hypercube = _using_hypercube
            return out
    return model
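

# A short usage sketch (hypothetical checkpoint directory). Pointing
# `load_model` at a directory reloads the newest .keras file and, if a
# parameters.json is present, rebuilds the full autoencoder class; passing
# `submodel` returns a plain `tf.keras.Model` instead:
#
#     >>> import encodermap as em  # doctest: +SKIP
#     >>> emap = load_model(em.EncoderMap, "runs/run0")  # doctest: +SKIP
#     >>> decoder = load_model(
#     ...     None, "runs/run0/saved_model_500.keras", submodel="decoder"
#     ... )  # doctest: +SKIP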


def load_model_legacy(
    autoencoder_class: Union[None, "AutoencoderClass"],
    checkpoint_path: Union[str, Path],
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
    dataset: Optional[Union[tf.data.Dataset, np.ndarray]] = None,
    compat: bool = False,
) -> "AutoencoderClass":
    """Loads legacy .model files.

    Note:
        The .model format has been deprecated. Please update your saved models
        to the .keras format. You can use this function to rebuild a new
        model from the legacy .model files.

    Args:
        autoencoder_class (Union[None, AutoencoderClass]): One of the
            autoencoder classes implemented in EncoderMap.
        checkpoint_path (Union[str, Path]): The path to the file to load.
        trajs (Optional[TrajEnsemble]): When loading an AngleDihedralCartesianEncoderMap,
            the trajectories need to be supplied to verify the input/output shapes
            of the model.
        sparse (bool): Whether the model contains sparse inputs.
        dataset (Optional[Union[tf.data.Dataset, np.ndarray]]): Either a tf.data.Dataset
            or a np.ndarray to infer the input shape from.
        compat (bool): Whether to infer parameters from the old .model file,
            back up an outdated parameters.json, and write a new one.
            Defaults to False.

    """
    # Local Folder Imports
    from ..autoencoder import AngleDihedralCartesianEncoderMap
    from ..models import gen_functional_model, gen_sequential_model

    if "*" in str(checkpoint_path):
        checkpoint_path = list(
            sorted(map(Path, glob(str(checkpoint_path))), key=_model_sort_key)
        )
        parameters_file = checkpoint_path[0].parent / "parameters.json"
        found = re.findall(r"\d+", str(checkpoint_path[0].name))
    else:
        checkpoint_path = Path(checkpoint_path)
        parameters_file = checkpoint_path.parent / "parameters.json"
        found = re.findall(r"\d+", str(checkpoint_path.name))
    read_only = False

    if dataset is not None:
        d = dataset.take(1)
        if any([isinstance(i, tf.sparse.SparseTensor) for i in d]):
            sparse = True

    try:
        step = int(found[-1])
    except IndexError:
        step = None
    except ValueError as e:
        raise Exception(f"{checkpoint_path=} {found=}") from e

    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        p_class = ADCParameters
    else:
        p_class = Parameters

    if not parameters_file.is_file():
        parameters = p_class()
        print(
            f"Couldn't find the parameters file at {parameters_file}. "
            f"Will use default {parameters.__class__.__name__} and will "
            f"infer architecture parameters from the model on disk. Weights "
            f"from the old model will be transferred to the new model. "
            f"From now on, you can save the model as a new .keras file."
        )
        read_only = True
    else:
        parameters = p_class.from_file(parameters_file)
        print(
            "Weights from the old model will be transferred to the new "
            "model. From now on, you can save the model as a new .keras file."
        )

    # set the current training step
    if parameters.current_training_step == 0 and step is not None:
        parameters.current_training_step = step

    # make assumptions on data based on input shape
    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        old_model = tf.keras.models.load_model(checkpoint_path, compile=False)
        input_shape = old_model.input_shape
        encoder_input_shape = old_model.encoder_model.input_shape
    else:

        class OldModel:
            pass

        old_model = OldModel()
        assert len(checkpoint_path) == 2
        old_model.encoder = tf.keras.models.load_model(
            checkpoint_path[0], compile=False
        )
        old_model.decoder = tf.keras.models.load_model(
            checkpoint_path[1], compile=False
        )
        old_model.encoder_model = old_model.encoder
        old_model.decoder_model = old_model.decoder
        input_shape = old_model.encoder.input_shape[1]
        if dataset is not None:
            d = dataset.take(1)
            for a in d:
                break
            d = a
            if d[0].shape[1] != input_shape:
                if d[0].shape[1] * 2 == input_shape:
                    _change_setting_inform_user(
                        parameters,
                        "periodicity",
                        np.pi,
                        parameters_file=parameters_file,
                        compat=compat,
                    )
                else:
                    raise Exception(
                        f"The shape of the provided data {d[0].shape[1]} does not "
                        f"match the input shape {input_shape} of the network. Are you "
                        f"sure you are loading the correct checkpoint?"
                    )
            else:
                _change_setting_inform_user(
                    parameters,
                    "periodicity",
                    float("inf"),
                    parameters_file=parameters_file,
                    compat=compat,
                )
        if not parameters_file.is_file():
            print(
                f"I will create a model with an input shape of {input_shape}. For "
                f"periodic data (e.g., angles), this network will not work. If you "
                f"are reloading a network for periodic data, please manually "
                f"create a parameters file at {parameters_file} with the appropriate "
                f"periodicity."
            )

    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        if len(encoder_input_shape) == 3:
            _change_setting_inform_user(
                parameters,
                "use_sidechains",
                True,
                parameters_file=parameters_file,
                compat=compat,
            )
            _change_setting_inform_user(
                parameters,
                "use_backbone_angles",
                True,
                parameters_file=parameters_file,
                compat=compat,
            )
        elif len(encoder_input_shape) == 2:
            _change_setting_inform_user(
                parameters,
                "use_sidechains",
                False,
                parameters_file=parameters_file,
                compat=compat,
            )
            _change_setting_inform_user(
                parameters,
                "use_backbone_angles",
                True,
                parameters_file=parameters_file,
                compat=compat,
            )
        else:
            _change_setting_inform_user(
                parameters,
                "use_sidechains",
                False,
                parameters_file=parameters_file,
                compat=compat,
            )
            _change_setting_inform_user(
                parameters,
                "use_backbone_angles",
                False,
                parameters_file=parameters_file,
                compat=compat,
            )

    # make some assumptions about the n_neurons and activation_functions parameters
    n_neurons = []
    activation_functions = []
    for layer in old_model.encoder_model.layers:
        if hasattr(layer, "units"):
            n_neurons.append(layer.units)
            act = layer.activation.__name__
            if act == "linear":
                activation_functions.append("")
            else:
                activation_functions.append(act)
    activation_functions = [activation_functions[-1]] + activation_functions

    _change_setting_inform_user(
        parameters,
        "n_neurons",
        n_neurons,
        parameters_file=parameters_file,
        compat=compat,
    )
    _change_setting_inform_user(
        parameters,
        "activation_functions",
        activation_functions,
        parameters_file=parameters_file,
        compat=compat,
    )

    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        new_model = gen_functional_model(
            input_shapes=tuple([v[1:] for v in old_model.input_shape]),
            parameters=parameters,
            sparse=sparse,
            write_summary=not read_only,
        )
    else:
        new_model = gen_sequential_model(
            input_shape=input_shape,
            parameters=parameters,
            sparse=sparse,
        )
    try:
        new_model.encoder_model.set_weights(old_model.encoder.get_weights())
    except AttributeError:
        new_model.encoder_model.set_weights(old_model.encoder_model.get_weights())
    except Exception as e:
        raise Exception(
            f"{[i.shape for i in new_model.encoder_model.get_weights()]=}\n\n"
            f"{[i.shape for i in old_model.encoder_model.get_weights()]=}"
        ) from e

    try:
        new_model.decoder_model.set_weights(old_model.decoder.get_weights())
    except AttributeError:
        new_model.decoder_model.set_weights(old_model.decoder_model.get_weights())
    except Exception as e:
        raise Exception(
            f"{[i.shape for i in new_model.decoder_model.get_weights()]=}\n\n"
            f"{[i.shape for i in old_model.decoder_model.get_weights()]=}"
        ) from e

    if autoencoder_class is AngleDihedralCartesianEncoderMap:
        new_class = autoencoder_class(
            trajs=trajs,
            parameters=parameters,
            model=new_model,
            read_only=read_only,
            dataset=dataset,
        )
        if not read_only:
            new_class.save()
        return new_class
    else:
        new_class = autoencoder_class(
            parameters=parameters,
            train_data=dataset,
            model=new_model,
            read_only=read_only,
            sparse=sparse,
        )
        if not read_only:
            new_class.save()
        return new_class
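

# A short usage sketch (hypothetical wildcard path to deprecated encoder/decoder
# checkpoints). The old weights are transferred into a freshly generated model,
# which can then be saved in the new .keras format; `compat=True` also backs up
# and rewrites an outdated parameters.json:
#
#     >>> import encodermap as em  # doctest: +SKIP
#     >>> emap = load_model_legacy(
#     ...     em.EncoderMap, "runs/old/saved_model_*.model", compat=True
#     ... )  # doctest: +SKIP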


def load_model_legacy_dep(
    autoencoder_class: AutoencoderClass,
    checkpoint_path: Union[str, Path],
    read_only: bool = True,
    overwrite_tensorboard_bool: bool = False,
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
    dataset: Optional[tf.data.Dataset] = None,
) -> AutoencoderClass:  # pragma: no doccheck
    """Reloads a tf.keras.Model from a checkpoint path.

    For this, an AutoencoderClass is necessary to provide the corresponding
    custom objects, such as loss functions.

    """
    basedir = os.path.split(checkpoint_path)[0]

    # remove wildcard
    if "*" in checkpoint_path:
        cp_path = checkpoint_path.replace("*", "")
    else:
        cp_path = checkpoint_path

    if trajs is None and dataset is None:
        params = Parameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        directory = run_path("/".join(checkpoint_path.split("/")[:-1]))
        if directory != params.main_path:
            print(
                f"The saved model files have been moved from {params.main_path} "
                f"to {directory}. I will overwrite the 'main_path' attribute of "
                f"these parameters."
            )
            params = deepcopy(params)
            params.main_path = run_path(directory)
        if not hasattr(params, "write_summary"):
            params.write_summary = params.tensorboard
        out = autoencoder_class(parameters=params, read_only=read_only)
    else:
        params = ADCParameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        if os.path.isfile(cp_path + "_current_step.txt"):
            with open(cp_path + "_current_step.txt", "r") as f:
                step = int(f.read())
        elif read_only:
            step = 0
        else:
            print(
                "Cannot find cartesian loss step. Retraining of this model might "
                "lead to unexpected results."
            )
        directory = run_path("/".join(checkpoint_path.split("/")[:-1]))
        if directory != params.main_path:
            print(
                f"The saved model files have been moved from {params.main_path} "
                f"to {directory}. I will overwrite the 'main_path' attribute of "
                f"these parameters."
            )
            params = deepcopy(params)
            params.main_path = directory
        if not hasattr(params, "write_summary"):
            params.write_summary = params.tensorboard
        out = autoencoder_class(
            trajs,
            parameters=params,
            read_only=read_only,
            dataset=dataset,
        )
    out.p = _params

    # see if there are multiple models (this module imports the `glob` function
    # directly, so the former `glob.glob` calls would have raised AttributeError)
    if "*" not in checkpoint_path:
        models = glob(checkpoint_path + "*/")
    else:
        models = glob(checkpoint_path + "/")
985 if len(models) == 2:
986 models.sort(key=_model_sort_key)
987 custom_objects = {fn.__name__: fn for fn in out.loss}
988 models = _load_list_of_models(models, custom_objects=custom_objects)
989 n_inputs = models[0].inputs[0].shape[-1]
990 if _params.periodicity < float("inf"):
991 n_inputs = int(n_inputs / 2)
992 model = SequentialModel(n_inputs, out.p, models)
    elif len(models) == 3:
        print("Loading a functional model can take up to a minute.")
        models.sort(key=_model_sort_key)
        encoder_model_name = models[0]
        custom_objects = {fn.__name__: fn for fn in out.loss}
        models = _load_list_of_models(models, custom_objects=custom_objects)
        model = models[2]
        model.encoder_model = models[0]
        model.decoder_model = models[1]

        msg = None
        if not _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seems to have failed. I expected the "
                    f"encoder model to take an input of shape (None, Any), but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        elif _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seems to have failed. I expected the "
                    f"encoder model to take an input of shape [(None, Any), (None, Any)], "
                    f"but the file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        else:
            if len(models[0].input_shape) != 3:
                msg = (
                    f"Reloading the models seems to have failed. I expected the "
                    f"encoder model to take an input of shape [(None, Any), (None, Any), "
                    f"(None, Any)], but the file at {encoder_model_name} takes an input "
                    f"shape of {models[0].input_shape}. This error can also be caused "
                    f"by bad filenames."
                )
        if msg is not None:
            raise Exception(msg)
    else:
        print("Model is neither sequential nor functional. I will try to reload it.")
        custom_objects = {fn.__name__: fn for fn in out.loss}
        model = tf.keras.models.load_model(
            checkpoint_path, custom_objects=custom_objects
        )
        if hasattr(model, "encoder_model") and not hasattr(model, "encode"):
            print(
                "The loaded model lost its `encode` function. I will try to rebuild it."
            )

            models = [model.encoder_model, model.decoder_model]
            n_inputs = models[0].inputs[0].shape[-1]
            if _params.periodicity < float("inf"):
                n_inputs = int(n_inputs / 2)

            if sparse:
                # Third Party Imports
                from tensorflow.keras.layers import Dense, Input

                shape = n_inputs
                _input_layer = Input(
                    shape=(int(shape),),
                    sparse=True,
                )
                x = Dense(shape)(_input_layer)
                get_dense_model = tf.keras.Model(
                    inputs=_input_layer,
                    outputs=x,
                )
                model.get_dense_model = get_dense_model
            else:
                get_dense_model = None
            model = SequentialModel(
                n_inputs, out.p, models, sparse=sparse, get_dense_model=get_dense_model
            )

    out._model = model
    if os.path.isfile(cp_path + "_step.txt"):
        out.cartesian_loss_step = step + 1
    return out