Coverage for encodermap/misc/saving_loading_models.py: 10%
121 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/misc/saving_loading_models.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
23"""
24ToDo:
25 * This is in a desperate need of rework.
27"""
28from __future__ import annotations
30import typing
31from typing import Callable, Optional, Union
33if typing.TYPE_CHECKING:
34 from .._typing import AutoencoderClass
36import copy
37import glob
38import os
39from pathlib import Path
41import tensorflow as tf
43from ..misc.misc import _datetime_windows_and_linux_compatible
44from ..models.models import ADCSequentialModel, FunctionalModel, SequentialModel
45from ..parameters import ADCParameters, Parameters
46from .errors import BadError
def save_model(model, main_path, inp_class_name, step=None, current_step=None):
    """Saves a tf.keras.Model to disk, choosing a strategy based on its type.

    Three cases are handled:
        * A two-layer model (encoder + decoder submodels): each submodel is
          saved separately with a ``_<layername>`` suffix.
        * A subclass of `FunctionalModel`: the full model plus its
          ``encoder_model`` and ``decoder_model`` are saved.
        * Anything else: a best-effort plain ``save_model`` call.

    Args:
        model: The tf.keras.Model (or subclass) to save.
        main_path: Directory in which the model file(s) are created.
        inp_class_name (str): Name of the Autoencoder class, used only to
            print a helpful reload instruction.
        step (Optional[int]): If None, the filename is timestamped and
            informational messages are printed. Otherwise the step number
            becomes part of the filename and printing is suppressed.
        current_step (Optional[int]): The current cartesian loss step. If not
            None, it is written to ``<fname>_current_step.txt`` so that
            training can be resumed from a reloaded checkpoint.
    """
    if step is None:
        time = _datetime_windows_and_linux_compatible()
        fname = f"{main_path}/saved_model_{time}.model"
    else:
        fname = f"{main_path}/saved_model_{step}.model"
    if len(model.layers) == 2:
        # Sequential-style model: persist encoder and decoder separately.
        names = ["_" + l.name.lower() for l in model.layers]
        for i, submodel in enumerate(model.layers):
            tf.keras.models.save_model(submodel, fname + names[i])
        if step is None:
            # was: print(f"...") — f-string had no placeholders
            print("Saved current state of model.")
            print(
                f"Use em.{inp_class_name}.from_checkpoint('{fname}*') to reload the current state of the two submodels."
            )
    elif issubclass(model.__class__, FunctionalModel):
        if step is None:
            # fixed typo: was "ACD Functional Model"
            print("Saving the ADC Functional Model can take up to a minute.")
        tf.keras.models.save_model(model, fname)
        tf.keras.models.save_model(model.encoder_model, fname + "_encoder")
        tf.keras.models.save_model(model.decoder_model, fname + "_decoder")
        if step is None:
            print(
                f"Saved current state of functional model at the end of step "
                f"{current_step}. Use em.{inp_class_name}.from_checkpoint('{fname}*') "
                f"to reload the current state."
            )
    else:
        # fixed typo: was "ACDSequentialModel" (the class is ADCSequentialModel)
        print(
            f"Current model is not a subclass of Union[SequentialModel, "
            f"ADCSequentialModel, FunctionalModel]. I will try to save it at "
            f"{fname}, but can't guarantee that you can reload it."
        )
        tf.keras.models.save_model(model, fname)

    if current_step is not None:
        # Persist the training step so `load_model` can resume training.
        with open(fname + "_current_step.txt", "w") as f:
            f.write(str(current_step))
def model_sort_key(model_name: str) -> int:
    """Sort key placing encoder models first, decoders second, others last.

    Args:
        model_name (str): The filepath to the saved model.

    Returns:
        int: Returns 0 for 'encoder', 1 for 'decoder', 2 for everything else.

    """
    basename = Path(model_name).name
    if "encoder" in basename:
        return 0
    if "decoder" in basename:
        return 1
    return 2
def load_list_of_models(
    models: list[str],
    custom_objects: Optional[dict[str, Callable]] = None,
) -> list[tf.keras.Model]:
    """Load every model path in `models` using keras.

    Args:
        models (list[str]): The paths of the models to be loaded.
        custom_objects (Optional[dict[str, Callable]]): Forwarded to
            `tf.keras.models.load_model` so custom losses/layers resolve.
            Defaults to None.

    Returns:
        list[tf.keras.Model]: The loaded models, in input order.

    """
    loaded = []
    for path in models:
        loaded.append(tf.keras.models.load_model(path, custom_objects=custom_objects))
    return loaded
def load_model(
    autoencoder_class: AutoencoderClass,
    checkpoint_path: str,
    read_only: bool = True,
    overwrite_tensorboard_bool: bool = False,
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
) -> AutoencoderClass:
    """Reloads a tf.keras.Model from a checkpoint path.

    For this, an AutoencoderClass is necessary, to provide the corresponding
    custom objects, such as loss functions.

    Args:
        autoencoder_class (AutoencoderClass): The EncoderMap autoencoder class
            used to instantiate the output object and supply custom losses.
        checkpoint_path (str): Path to the saved model; may contain a ``*``
            wildcard matching the encoder/decoder submodel files.
        read_only (bool): If True, the reloaded model is not trainable and a
            missing cartesian-loss-step file is tolerated. Defaults to True.
        overwrite_tensorboard_bool (bool): If True, tensorboard is switched
            off in the reloaded parameters. Defaults to False.
        trajs (Optional[TrajEnsemble]): If provided, `ADCParameters` are read
            and the angle-dihedral-cartesian autoencoder path is taken;
            otherwise plain `Parameters` are used. Defaults to None.
        sparse (bool): Whether to rebuild a sparse-input dense adapter when
            reconstructing a `SequentialModel`. Defaults to False.

    Returns:
        AutoencoderClass: The instantiated autoencoder with `_model` set.
    """
    basedir = os.path.split(checkpoint_path)[0]

    # remove wildcard
    if "*" in checkpoint_path:
        cp_path = checkpoint_path.replace("*", "")
    else:
        cp_path = checkpoint_path

    # `step` stays None for the plain-Parameters path; it is only read from
    # disk for the ADC path, where training resumption needs it.
    step = None
    if trajs is None:
        params = Parameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        out = autoencoder_class(parameters=params, read_only=read_only)
    else:
        params = ADCParameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        if os.path.isfile(cp_path + "_current_step.txt"):
            with open(cp_path + "_current_step.txt", "r") as f:
                step = int(f.read())
        elif read_only:
            step = 0
        else:
            raise BadError(
                "Cannot find cartesian loss step. Model will not be trainable without knowing the step the model was saved at."
            )
        out = autoencoder_class(
            trajs, parameters=params, read_only=read_only, cartesian_loss_step=step
        )
    out.p = _params

    # see if there are multiple models
    if "*" not in checkpoint_path:
        models = glob.glob(checkpoint_path + "*/")
    else:
        models = glob.glob(checkpoint_path + "/")

    # three different ways of loading models
    if len(models) == 2:
        # encoder + decoder submodels: rebuild a SequentialModel around them
        models.sort(key=model_sort_key)
        custom_objects = {fn.__name__: fn for fn in out.loss}
        models = load_list_of_models(models, custom_objects=custom_objects)
        n_inputs = models[0].inputs[0].shape[-1]
        if _params.periodicity < float("inf"):
            # periodic inputs are unit-circle encoded (sin/cos), doubling width
            n_inputs = int(n_inputs / 2)
        model = SequentialModel(n_inputs, out.p, models)
    elif len(models) == 3:
        # full functional model + encoder + decoder
        print("Loading a functional model can take up to a minute.")
        models.sort(key=model_sort_key)
        encoder_model_name = models[0]
        custom_objects = {fn.__name__: fn for fn in out.loss}
        models = load_list_of_models(models, custom_objects=custom_objects)
        model = models[2]
        model.encoder_model = models[0]
        model.decoder_model = models[1]

        # sanity-check the encoder's input rank against the parameters
        msg = None
        if not _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seemed to have failed. I expected the "
                    f"Encoder model to take an input of shape (None, Any), but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        elif _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seemed to have failed. I expected the "
                    f"Encoder model to take an input of shape [(None, Any), (None, Any)] but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        else:
            if len(models[0].input_shape) != 3:
                msg = (
                    f"Reloading the models seemed to have failed. I expected the "
                    f"Encoder model to take an input of shape [(None, Any), (None, Any), (None, Any)] but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        if msg is not None:
            raise Exception(msg)
    else:
        print("Model is neither Sequential, nor functional. I try to reload it.")
        custom_objects = {fn.__name__: fn for fn in out.loss}
        model = tf.keras.models.load_model(
            checkpoint_path, custom_objects=custom_objects
        )
        if hasattr(model, "encoder_model") and not hasattr(model, "encode"):
            print(
                "The loaded model lost its `encode` function. I will try to rebuild it."
            )

            models = [model.encoder_model, model.decoder_model]
            n_inputs = models[0].inputs[0].shape[-1]
            if _params.periodicity < float("inf"):
                n_inputs = int(n_inputs / 2)

            if sparse:
                from tensorflow.keras.layers import Dense, Input

                # small dense adapter that densifies the sparse input tensor
                shape = n_inputs
                _input_layer = Input(
                    shape=(int(shape),),
                    sparse=True,
                )
                x = Dense(shape)(_input_layer)
                get_dense_model = tf.keras.Model(
                    inputs=_input_layer,
                    outputs=x,
                )
                model.get_dense_model = get_dense_model
            else:
                get_dense_model = None
            model = SequentialModel(
                n_inputs, out.p, models, sparse=sparse, get_dense_model=get_dense_model
            )

    out._model = model
    # BUG FIX: this previously tested `cp_path + "_step.txt"`, a filename that
    # `save_model` never writes (it writes "_current_step.txt"), so the
    # cartesian loss step was never advanced on reload. The `step is not None`
    # guard also prevents a NameError on the plain-Parameters path.
    if step is not None and os.path.isfile(cp_path + "_current_step.txt"):
        out.cartesian_loss_step = step + 1
    return out