Coverage for encodermap/encodermap_tf1/autoencoder.py: 12%
217 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# Third Party Imports
2import tensorflow
# Resolve the TFRecord dataset class under one name, ``TFRecordDataset``,
# regardless of which reader classes the installed TensorFlow provides.
try:
    # Third Party Imports
    from tensorflow.python.data.ops.readers import (
        TFRecordDatasetV2 as TFRecordDataset,
    )
except ImportError:
    # Fall back to the V1 reader class when the V2 class cannot be imported.
    from tensorflow.python.data.ops.readers import TFRecordDatasetV1 as TFRecordDataset
11# Standard Library Imports
12import os
13from math import pi
15# Third Party Imports
16import numpy as np
17import tensorflow.compat.v1 as tf
18from tensorflow.python.client import timeline
19from tensorflow.python.framework import ops as tf_ops
20from tqdm import tqdm
22# Local Folder Imports
23from .backmapping import dihedrals_to_cartesian_tf
24from .misc import (
25 add_layer_summaries,
26 distance_cost,
27 pairwise_dist,
28 periodic_distance,
29 variable_summaries,
30)
31from .parameters import Parameters
class Autoencoder:
    """
    Dense autoencoder built with the TensorFlow 1 graph API.

    Constructing an instance builds a private ``tf.Graph`` (data iterator,
    encoder, decoder, cost terms and Adam optimizer) and opens a ``tf.Session``
    on it. Instances can be used as context managers; leaving the ``with``
    block closes the session.
    """

    def __init__(
        self,
        parameters,
        train_data=None,
        validation_data=None,
        checkpoint_path=None,
        n_inputs=None,
        read_only=False,
        seed=None,
        debug=False,
    ):
        """
        :param parameters: Parameters object as defined in
            :class:`encodermap.encodermap_tf1.parameters.Parameters`
        :param train_data: 2d numpy array where each row is treated as a
            training point
        :param validation_data: A 2d numpy array. This data will only be used
            to calculate a validation error during training. It will not be
            used for training.
        :param checkpoint_path: If a checkpoint path is given, values like
            neural network weights stored in this checkpoint will be restored.
        :param n_inputs: If no train_data is given, for example when an already
            trained network is restored from a checkpoint, the number of inputs
            needs to be given. This should be equal to the number of columns of
            the train_data the network was originally trained with.
        :param read_only: If True, the parameters are not saved and no summary
            writers are created.
        :param seed: If given, it is used as graph-level random seed so that
            weights and biases are reproducible. Note that :meth:`train`
            returns after a single step when a seed is set.
        :param debug: Forwarded as ``debug`` to the summary helpers.
        """
        self.p = parameters
        self.n_inputs = n_inputs
        self.seed = seed
        self.debug = debug
        if not read_only:
            self.p.save()
            print(
                "Output files are saved to {}".format(self.p.main_path),
                "as defined in 'main_path' in the parameters.",
            )

        self.train_data = train_data
        self.validation_data = validation_data
        self._prepare_data()

        self.graph = tf.Graph()
        with self.graph.as_default():
            # The graph-level seed belongs to the default graph at call time,
            # so it must be set while self.graph is the default graph and
            # before any random op (initializers, shuffle) is created.
            if self.seed is not None:
                tf.random.set_random_seed(self.seed)

            self.global_step = tf.train.create_global_step()
            self._setup_data_iterator()
            self._setup_network()

            with tf.name_scope("cost"):
                self.cost = 0
                self._setup_cost()
                tf.summary.scalar("combined_cost", self.cost)

            # Setup Optimizer:
            self.optimizer = tf.train.AdamOptimizer(self.p.learning_rate)
            self.gradients = self.optimizer.compute_gradients(self.cost)
            self.optimize = self.optimizer.apply_gradients(
                self.gradients, global_step=self.global_step
            )

            self.merged_summaries = tf.summary.merge_all()

            # Setup Session: a fraction of 0 means "grow GPU memory on demand".
            if self.p.gpu_memory_fraction == 0:
                gpu_options = tf.GPUOptions(allow_growth=True)
            else:
                gpu_options = tf.GPUOptions(
                    per_process_gpu_memory_fraction=self.p.gpu_memory_fraction
                )
            self.sess = tf.Session(
                config=tf.ConfigProto(gpu_options=gpu_options), graph=self.graph
            )
            self.sess.run(tf.global_variables_initializer())
            # Feed the actual arrays into the placeholders the dataset was
            # built from.
            self.sess.run(
                self.data_iterator.initializer,
                feed_dict=dict(zip(self.data_placeholders, self.train_data)),
            )

            if not read_only:
                self.train_writer = tf.summary.FileWriter(
                    os.path.join(self.p.main_path, "train"), self.sess.graph
                )
                if self.validation_data is not None:
                    self.validation_writer = tf.summary.FileWriter(
                        os.path.join(self.p.main_path, "validation"),
                        self.sess.graph,
                    )
            self.saver = tf.train.Saver(max_to_keep=100)

            # load Checkpoint
            if checkpoint_path:
                self.saver.restore(self.sess, checkpoint_path)

    def _prepare_data(self):
        """
        Normalize ``self.train_data`` and ``self.validation_data``.

        Numpy input is converted to float32 and ``self.train_data`` becomes a
        list of arrays. Without train data, a zero array with 3 rows and
        ``n_inputs`` columns is used as a stand-in.

        :raises ValueError: if train_data has an unsupported type.
        """
        if self.train_data is None:
            assert (
                self.n_inputs is not None
            ), "If no train_data is given, n_inputs needs to be given"
            self.train_data = [np.zeros((3, self.n_inputs), dtype=np.float32)]
        elif isinstance(self.train_data, np.ndarray):
            self.train_data = [self.train_data.astype(np.float32)]
        elif isinstance(self.train_data, (list, tuple)):
            self.train_data = [dat.astype(np.float32) for dat in self.train_data]
        elif isinstance(self.train_data, TFRecordDataset):
            pass
        else:
            raise ValueError(
                "{} is not supported as input type for train_data".format(
                    type(self.train_data)
                )
            )

        if self.validation_data is not None:
            self.validation_data = self.validation_data.astype(np.float32)
            # Todo: allow lists of validation data

    def _setup_network(self):
        """
        Build encoder and decoder on top of the data iterator.

        ``self.main_inputs`` and ``self.latent`` are created with
        ``tf.placeholder_with_default`` so that :meth:`encode`,
        :meth:`generate` and the validation pass can feed their own values.
        """
        self.inputs = self.data_iterator.get_next()
        first_input = self.inputs[0]
        self.main_inputs = tf.placeholder_with_default(first_input, first_input.shape)
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        self.generated = self._generate(self.latent)

    def _setup_data_iterator(self):
        """
        Create a shuffled, repeating, batched dataset over the train data.

        The dataset is built from placeholders (one per array in
        ``self.train_data``); the arrays themselves are fed when the iterator
        is initialized in ``__init__``.
        """
        self.data_placeholders = tuple(
            tf.placeholder(dat.dtype, dat.shape) for dat in self.train_data
        )
        data_set = tf.data.Dataset.from_tensor_slices(self.data_placeholders)
        # Buffer as large as the first array's row count.
        data_set = data_set.shuffle(buffer_size=len(self.train_data[0]))
        data_set = data_set.repeat()
        self.data_set = data_set.batch(self.p.batch_size)
        if not tf.executing_eagerly():
            self.data_iterator = self.data_set.make_initializable_iterator()
        else:
            self.data_iterator = self.data_set

    @staticmethod
    def _activation(name):
        """Return ``tf.nn.<name>``, or None when ``name`` is empty or None."""
        return getattr(tf.nn, name) if name else None

    def _encode(self, inputs):
        """
        Build the encoder layers.

        With a finite periodicity the inputs are rescaled to a period of 2*pi
        and replaced by their sine and cosine, concatenated along axis 1.

        :param inputs: tensor holding the network input
        :return: output tensor of the last encoder layer
        """
        with tf.name_scope("encoder"):
            if self.p.periodicity < float("inf"):
                if self.p.periodicity != 2 * pi:
                    inputs = inputs / self.p.periodicity * 2 * pi
                self.unit_circle_inputs = tf.concat(
                    [tf.sin(inputs), tf.cos(inputs)], 1
                )
                current_layer = self.unit_circle_inputs
            else:
                current_layer = inputs

            assert (
                len(self.p.n_neurons) == len(self.p.activation_functions) - 1
            ), "you need one activation function more then layers given in n_neurons"
            # The first entry of activation_functions is skipped here.
            layer_specs = zip(self.p.n_neurons, self.p.activation_functions[1:])
            for i, (n_neurons, act_name) in enumerate(layer_specs):
                variable_summaries(
                    "activation{}".format(i), current_layer, debug=self.debug
                )
                dense = tf.layers.Dense(
                    n_neurons,
                    activation=self._activation(act_name),
                    kernel_initializer=tf.variance_scaling_initializer(),
                    kernel_regularizer=self.regularizer,
                    bias_initializer=tf.random_normal_initializer(0.1, 0.05),
                )
                current_layer = dense(current_layer)
                add_layer_summaries(dense, debug=self.debug)
            return current_layer

    def _generate(self, inputs):
        """
        Build the decoder layers.

        Layer sizes and activation functions are those of the encoder taken in
        reverse order (each list without its last entry). With a finite
        periodicity the output layer is twice as wide as the input and its two
        halves are combined with ``atan2`` before rescaling to the
        configured periodicity.

        :param inputs: tensor holding the latent representation
        :return: tensor holding the reconstructed input
        """
        with tf.name_scope("generator"):
            periodic = self.p.periodicity < float("inf")
            n_columns = self.main_inputs.shape[1]
            output_width = n_columns * 2 if periodic else n_columns
            n_neurons_with_inputs = [output_width] + self.p.n_neurons

            current_layer = inputs
            for n_neurons, act_name in zip(
                n_neurons_with_inputs[-2::-1], self.p.activation_functions[-2::-1]
            ):
                current_layer = tf.layers.dense(
                    current_layer,
                    n_neurons,
                    activation=self._activation(act_name),
                    kernel_initializer=tf.variance_scaling_initializer(),
                    kernel_regularizer=self.regularizer,
                    bias_initializer=tf.random_normal_initializer(0.1, 0.05),
                )

            if periodic:
                # First half is the atan2 "y" argument, second half the "x".
                current_layer = tf.atan2(
                    current_layer[:, :n_columns], current_layer[:, n_columns:]
                )
                if self.p.periodicity != 2 * pi:
                    current_layer = current_layer / (2 * pi) * self.p.periodicity
            return current_layer

    def _setup_cost(self):
        """Add all cost contributions to ``self.cost``."""
        self._auto_cost()
        self._center_cost()
        self._l2_reg_cost()

    def _auto_cost(self):
        """
        Add the reconstruction cost between input and generated output.

        The summary is recorded whenever ``auto_cost_scale`` is not None; the
        cost only contributes to ``self.cost`` for a non-zero scale.

        :raises ValueError: if ``auto_cost_variant`` is unknown.
        """
        if self.p.auto_cost_scale is None:
            return
        variant = self.p.auto_cost_variant
        if variant not in ("mean_square", "mean_abs", "mean_norm"):
            raise ValueError("auto_cost_variant {} not available".format(variant))

        distance = periodic_distance(
            self.main_inputs, self.generated, self.p.periodicity
        )
        if variant == "mean_square":
            auto_cost = tf.reduce_mean(tf.square(distance))
        elif variant == "mean_abs":
            auto_cost = tf.reduce_mean(tf.abs(distance))
        else:
            auto_cost = tf.reduce_mean(tf.norm(distance, axis=1))

        tf.summary.scalar("auto_cost", auto_cost)
        if self.p.auto_cost_scale != 0:
            self.cost += self.p.auto_cost_scale * auto_cost

    def _l2_reg_cost(self):
        """
        Add the regularization loss collected from the layers.

        The summary is recorded whenever ``l2_reg_constant`` is not None; the
        loss only contributes to ``self.cost`` for a non-zero constant.
        """
        if self.p.l2_reg_constant is None:
            return
        reg_cost = tf.losses.get_regularization_loss()
        tf.summary.scalar("l2_reg_cost", reg_cost)
        if self.p.l2_reg_constant != 0:
            self.cost += reg_cost

    def _center_cost(self):
        """
        Add the mean squared value of the latent tensor as a cost term.

        The summary is recorded whenever ``center_cost_scale`` is not None; the
        cost only contributes to ``self.cost`` for a non-zero scale.
        """
        if self.p.center_cost_scale is None:
            return
        center_cost = tf.reduce_mean(tf.square(self.latent))
        tf.summary.scalar("center_cost", center_cost)
        if self.p.center_cost_scale != 0:
            self.cost += self.p.center_cost_scale * center_cost

    def _run_in_batches(self, fetch, placeholder, data, batch_size=2048):
        """
        Evaluate ``fetch`` while feeding ``data`` to ``placeholder`` in chunks.

        The data is split into ``max(1, int(len(data) / batch_size))`` parts and
        the per-part results are concatenated along axis 0.
        """
        n_batches = max(1, int(len(data) / batch_size))
        results = [
            self.sess.run(fetch, feed_dict={placeholder: batch})
            for batch in np.array_split(data, n_batches)
        ]
        return np.concatenate(results, axis=0)

    def encode(self, inputs):
        """
        Projects high dimensional data to a low dimensional space using the
        encoder part of the autoencoder.

        :param inputs: 2d numpy array with the same number of columns as the
            used train_data
        :return: 2d numpy array with the points projected to the low
            dimensional space. The number of columns is equal to the number of
            neurons in the bottleneck layer of the autoencoder.
        """
        return self._run_in_batches(self.latent, self.main_inputs, inputs)

    def generate(self, latent):
        """
        Generates new high-dimensional points based on given low-dimensional
        points using the decoder part of the autoencoder.

        :param latent: 2d numpy array containing points in the low-dimensional
            space. The number of columns must be equal to the number of neurons
            in the bottleneck layer of the autoencoder.
        :return: 2d numpy array containing points in the high-dimensional space.
        """
        return self._run_in_batches(self.generated, self.latent, latent)

    def _random_batch(self, data, batch_size=None):
        """
        Draw rows of ``data`` at random without replacement.

        :param data: array to sample rows from
        :param batch_size: number of rows; defaults to ``self.p.batch_size``.
            It is limited to ``len(data)`` because sampling without replacement
            cannot return more rows than there are.
        :return: array holding the selected rows
        """
        batch_size = batch_size or self.p.batch_size
        batch_size = min(batch_size, len(data))
        return data[np.random.choice(len(data), batch_size, replace=False)]

    def _checkpoint_file(self, step):
        """Return the checkpoint path used for the given global step."""
        return os.path.join(
            self.p.main_path, "checkpoints", "step{}.ckpt".format(step)
        )

    def train(self):
        """
        Train the autoencoder as specified in the parameters object.

        Summaries are written with the writers that are only created when the
        instance was constructed with ``read_only=False``.

        :return: None after a regular training run. If a seed was given, a
            single optimization step is run and a dict with the global step and
            the summary values of that step is returned instead.
        """
        for i in tqdm(range(self.p.n_steps)):
            # if fixed seed we need the summaries at global_step == 1
            # to run unittests on them
            if self.seed is not None:
                _, summary_values = self.sess.run(
                    (self.optimize, self.merged_summaries)
                )
                values = {"global_step": self.sess.run(self.global_step)}
                summary_proto = tf.Summary.FromString(summary_values)
                for entry in summary_proto.value:
                    if "histogram" in entry.tag:
                        values[entry.tag] = entry.histo
                    else:
                        values[entry.tag] = entry.simple_value
                return values

            if (i + 1) % self.p.summary_step == 0:
                _, summary_values = self.sess.run(
                    (self.optimize, self.merged_summaries)
                )
                step = self._step()
                self.train_writer.add_summary(summary_values, step)
                if self.validation_data is not None:
                    # Only the summaries are evaluated here, not the optimizer.
                    validation_summary = self.sess.run(
                        self.merged_summaries,
                        feed_dict={
                            self.main_inputs: self._random_batch(
                                self.validation_data
                            )
                        },
                    )
                    self.validation_writer.add_summary(validation_summary, step)
            else:
                self.sess.run(self.optimize)
                step = self._step()

            if step % self.p.checkpoint_step == 0:
                self.saver.save(self.sess, self._checkpoint_file(step))

        # Always store the final state once all steps are done.
        self.saver.save(self.sess, self._checkpoint_file(self._step()))
        self.train_writer.flush()

    def _step(self):
        """Return the current value of the global step."""
        return tf.train.global_step(self.sess, self.global_step)

    def profile(self):
        """
        Run five traced optimization steps and write one Chrome trace file
        (``timeline<i>.json`` in ``main_path``) for each of them.
        """
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        for i in range(5):
            self.sess.run(self.optimize, options=options, run_metadata=run_metadata)

            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            trace_file = os.path.join(self.p.main_path, "timeline{}.json".format(i))
            with open(trace_file, "w") as f:
                f.write(chrome_trace)

    def close(self):
        """
        Close tensorflow session to free resources.

        Safe to call on an instance whose construction did not get as far as
        creating the session or the graph.

        :return: None
        """
        sess = self.__dict__.get("sess")
        if sess is not None:
            sess.close()
        graph = self.__dict__.get("graph")
        if graph is None:
            return
        try:
            tf_ops.dismantle_graph(graph)
        except AttributeError:
            # not implemented in older versions of tensorflow
            pass

    def __enter__(self):
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the session when the ``with`` block is left."""
        self.close()

    def __del__(self):
        """Close the session when the instance is garbage collected."""
        self.close()