Coverage for encodermap/misc/saving_loading_models.py: 10%

121 statements  

coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

# -*- coding: utf-8 -*-
# encodermap/misc/saving_loading_models.py
################################################################################
# Encodermap: A python library for dimensionality reduction.
#
# Copyright 2019-2022 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################

"""
ToDo:
    * This is in desperate need of rework.

"""
from __future__ import annotations

import typing
from typing import Callable, Optional, Union

if typing.TYPE_CHECKING:
    from .._typing import AutoencoderClass

    # Only needed for the type annotations below; the import location of
    # TrajEnsemble is an assumption.
    from ..trajinfo import TrajEnsemble

import copy
import glob
import os
from pathlib import Path

import tensorflow as tf

from ..misc.misc import _datetime_windows_and_linux_compatible
from ..models.models import ADCSequentialModel, FunctionalModel, SequentialModel
from ..parameters import ADCParameters, Parameters
from .errors import BadError


def save_model(model, main_path, inp_class_name, step=None, current_step=None):
    """Saves a model to disk with ``tf.keras.models.save_model``.

    Depending on the model, either the two submodels (encoder and decoder)
    are saved separately, a functional model is saved together with its
    encoder and decoder submodels, or the model is saved as-is. If
    `current_step` is provided, it is additionally written to a
    ``*_current_step.txt`` file next to the saved model.

    Args:
        model: The model to save.
        main_path: The directory the model is written to.
        inp_class_name: The name of the calling autoencoder class; only used
            in the printed reload instructions.
        step (Optional[int]): If given, used in the filename instead of the
            current date and time. Defaults to None.
        current_step (Optional[int]): The training step at which the model
            was saved. Defaults to None.
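
    Example:
        A minimal, self-contained sketch with a plain Keras model; the output
        directory and class name below are assumptions, and a plain
        ``Sequential`` model ends up in the generic fallback branch::

            import os

            import tensorflow as tf

            os.makedirs("/tmp/em_run", exist_ok=True)
            model = tf.keras.Sequential(
                [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)]
            )
            # Writes /tmp/em_run/saved_model_<datetime>.model and, because
            # current_step is given, an accompanying *_current_step.txt file.
            save_model(
                model, "/tmp/em_run", inp_class_name="EncoderMap", current_step=10
            )

    """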

    if step is None:
        time = _datetime_windows_and_linux_compatible()
        fname = f"{main_path}/saved_model_{time}.model"
    else:
        fname = f"{main_path}/saved_model_{step}.model"
    if len(model.layers) == 2:
        names = ["_" + l.name.lower() for l in model.layers]
        for i, submodel in enumerate(model.layers):
            tf.keras.models.save_model(submodel, fname + names[i])
        if step is None:
            print("Saved current state of model.")
            print(
                f"Use em.{inp_class_name}.from_checkpoint('{fname}*') to reload the current state of the two submodels."
            )
    elif issubclass(model.__class__, FunctionalModel):
        if step is None:
            print("Saving the ADC functional model can take up to a minute.")
        tf.keras.models.save_model(model, fname)
        tf.keras.models.save_model(model.encoder_model, fname + "_encoder")
        tf.keras.models.save_model(model.decoder_model, fname + "_decoder")
        if step is None:
            print(
                f"Saved current state of functional model at the end of step "
                f"{current_step}. Use em.{inp_class_name}.from_checkpoint('{fname}*') "
                f"to reload the current state."
            )
    else:
        print(
            f"Current model is not a subclass of Union[SequentialModel, "
            f"ADCSequentialModel, FunctionalModel]. I will try to save it at "
            f"{fname}, but can't guarantee that you can reload it."
        )
        tf.keras.models.save_model(model, fname)

    if current_step is not None:
        with open(fname + "_current_step.txt", "w") as f:
            f.write(str(current_step))


def model_sort_key(model_name: str) -> int:
    """Returns a numerical value based on whether `model_name` contains certain substrings.

    Args:
        model_name (str): The filepath to the saved model.

    Returns:
        int: Returns 0 for 'encoder', 1 for 'decoder', 2 for everything else.
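
    Example:
        A quick illustration with hypothetical file names:

        >>> model_sort_key("/path/to/saved_model_2023.model_encoder")
        0
        >>> model_sort_key("/path/to/saved_model_2023.model_decoder")
        1
        >>> model_sort_key("/path/to/saved_model_2023.model")
        2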

    """
    x = Path(model_name).name
    return 0 if "encoder" in x else (1 if "decoder" in x else 2)


def load_list_of_models(
    models: list[str],
    custom_objects: Optional[dict[str, Callable]] = None,
) -> list[tf.keras.Model]:
    """Load the models supplied in `models` using keras.

    Args:
        models (list[str]): The paths of the models to be loaded.
        custom_objects (Optional[dict[str, Callable]]): Custom objects
            (e.g. loss functions) passed to ``tf.keras.models.load_model``.
            Defaults to None.

    Returns:
        list[tf.keras.Model]: The loaded models, in the same order as `models`.
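
    Example:
        A sketch of how this pairs with `model_sort_key`; the directory and
        glob pattern are assumptions::

            import glob

            paths = sorted(
                glob.glob("/tmp/em_run/saved_model_*.model_*"),
                key=model_sort_key,
            )
            encoder, decoder = load_list_of_models(paths)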

    """
    return [
        tf.keras.models.load_model(x, custom_objects=custom_objects) for x in models
    ]


def load_model(
    autoencoder_class: AutoencoderClass,
    checkpoint_path: str,
    read_only: bool = True,
    overwrite_tensorboard_bool: bool = False,
    trajs: Optional[TrajEnsemble] = None,
    sparse: bool = False,
) -> AutoencoderClass:
    """Reloads a saved model from a checkpoint path and wraps it in an
    autoencoder class.

    The `autoencoder_class` is needed to provide the corresponding custom
    objects, such as loss functions. It is instantiated from the
    ``parameters.json`` file found next to the checkpoint.

    Args:
        autoencoder_class (AutoencoderClass): The autoencoder class used to
            rebuild an instance around the loaded Keras model(s).
        checkpoint_path (str): The path to the saved model. Can contain a
            trailing ``*`` wildcard if encoder and decoder were saved as
            separate submodels.
        read_only (bool): Whether the returned instance should be read-only.
            Defaults to True.
        overwrite_tensorboard_bool (bool): If True, tensorboard is switched
            off in the reloaded parameters. Defaults to False.
        trajs (Optional[TrajEnsemble]): If given, the checkpoint is treated
            as an angle-dihedral-cartesian model and `ADCParameters` are
            loaded instead of `Parameters`. Defaults to None.
        sparse (bool): Whether the model was trained on sparse input and the
            dense input model needs to be rebuilt. Defaults to False.

    Returns:
        AutoencoderClass: An instance of `autoencoder_class` with the
        reloaded model attached.
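
    Example:
        A sketch of a direct call (the checkpoint path is an assumption).
        Normally this function is reached through
        ``em.EncoderMap.from_checkpoint``, as suggested by the messages
        printed in `save_model`::

            import encodermap as em

            e_map = load_model(em.EncoderMap, "/path/to/run/saved_model_*")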

    """
    basedir = os.path.split(checkpoint_path)[0]

    # remove wildcard
    if "*" in checkpoint_path:
        cp_path = checkpoint_path.replace("*", "")
    else:
        cp_path = checkpoint_path

    if trajs is None:
        params = Parameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        out = autoencoder_class(parameters=params, read_only=read_only)
    else:
        params = ADCParameters.from_file(basedir + "/parameters.json")
        _params = copy.deepcopy(params)
        if overwrite_tensorboard_bool:
            params.tensorboard = False
        if os.path.isfile(cp_path + "_current_step.txt"):
            with open(cp_path + "_current_step.txt", "r") as f:
                step = int(f.read())
        elif read_only:
            step = 0
        else:
            raise BadError(
                "Cannot find cartesian loss step. Model will not be trainable without knowing the step the model was saved at."
            )
        out = autoencoder_class(
            trajs, parameters=params, read_only=read_only, cartesian_loss_step=step
        )
    out.p = _params

    # see if there are multiple models
    if "*" not in checkpoint_path:
        models = glob.glob(checkpoint_path + "*/")
    else:
        models = glob.glob(checkpoint_path + "/")

    # three different ways of loading models
    if len(models) == 2:
        models.sort(key=model_sort_key)
        custom_objects = {fn.__name__: fn for fn in out.loss}
        models = load_list_of_models(models, custom_objects=custom_objects)
        n_inputs = models[0].inputs[0].shape[-1]
        if _params.periodicity < float("inf"):
            n_inputs = int(n_inputs / 2)
        model = SequentialModel(n_inputs, out.p, models)
    elif len(models) == 3:
        print("Loading a functional model can take up to a minute.")
        models.sort(key=model_sort_key)
        encoder_model_name = models[0]
        custom_objects = {fn.__name__: fn for fn in out.loss}
        models = load_list_of_models(models, custom_objects=custom_objects)
        model = models[2]
        model.encoder_model = models[0]
        model.decoder_model = models[1]

        msg = None
        if not _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seems to have failed. I expected the "
                    f"Encoder model to take an input of shape (None, Any), but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        elif _params.use_backbone_angles and not _params.use_sidechains:
            if len(models[0].input_shape) != 2:
                msg = (
                    f"Reloading the models seems to have failed. I expected the "
                    f"Encoder model to take an input of shape [(None, Any), (None, Any)], but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        else:
            if len(models[0].input_shape) != 3:
                msg = (
                    f"Reloading the models seems to have failed. I expected the "
                    f"Encoder model to take an input of shape [(None, Any), (None, Any), (None, Any)], but the "
                    f"file at {encoder_model_name} takes an input shape of "
                    f"{models[0].input_shape}. This error can also be caused by bad "
                    f"filenames."
                )
        if msg is not None:
            raise Exception(msg)

    else:
        print("Model is neither Sequential nor functional. I will try to reload it.")
        custom_objects = {fn.__name__: fn for fn in out.loss}
        model = tf.keras.models.load_model(
            checkpoint_path, custom_objects=custom_objects
        )
        if hasattr(model, "encoder_model") and not hasattr(model, "encode"):
            print(
                "The loaded model lost its `encode` function. I will try to rebuild it."
            )

            models = [model.encoder_model, model.decoder_model]
            n_inputs = models[0].inputs[0].shape[-1]
            if _params.periodicity < float("inf"):
                n_inputs = int(n_inputs / 2)

            if sparse:
                from tensorflow.keras.layers import Dense, Input

                shape = n_inputs
                _input_layer = Input(
                    shape=(int(shape),),
                    sparse=True,
                )
                x = Dense(shape)(_input_layer)
                get_dense_model = tf.keras.Model(
                    inputs=_input_layer,
                    outputs=x,
                )
                model.get_dense_model = get_dense_model
            else:
                get_dense_model = None
            model = SequentialModel(
                n_inputs, out.p, models, sparse=sparse, get_dense_model=get_dense_model
            )

    out._model = model
    # The step file is written by `save_model` as ``*_current_step.txt`` and
    # only applies to angle-dihedral-cartesian models (where `step` is set).
    if trajs is not None and os.path.isfile(cp_path + "_current_step.txt"):
        out.cartesian_loss_step = step + 1
    return out