Coverage for encodermap/encodermap_tf1/autoencoder.py: 12%

217 statements  

coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

# Third Party Imports
import tensorflow

try:
    # Third Party Imports
    from tensorflow.python.data.ops.readers import TFRecordDatasetV2 as TFRecordDataset
except ImportError:
    from tensorflow.python.data.ops.readers import TFRecordDatasetV1 as TFRecordDataset
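# Both import paths are bound to the same name, ``TFRecordDataset``, so the
# isinstance() check in ``_prepare_data`` works regardless of which reader
# class the installed TensorFlow version provides.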

# Standard Library Imports
import os
from math import pi

# Third Party Imports
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.python.client import timeline
from tensorflow.python.framework import ops as tf_ops
from tqdm import tqdm

# Local Folder Imports
from .backmapping import dihedrals_to_cartesian_tf
from .misc import (
    add_layer_summaries,
    distance_cost,
    pairwise_dist,
    periodic_distance,
    variable_summaries,
)
from .parameters import Parameters


class Autoencoder:
    def __init__(
        self,
        parameters,
        train_data=None,
        validation_data=None,
        checkpoint_path=None,
        n_inputs=None,
        read_only=False,
        seed=None,
        debug=False,
    ):
46 """ 

47 :param parameters: Parameters object as defined in :class:`encodermap.encodermap_tf1.parameters.Parameters` 

48 

49 :param train_data: 2d numpy array where each row is treated as a training point 

50 

51 :param validation_data: A 2d numpy array. This data will only be used to calculate a validation error during 

52 training. It will not be used for training. 

53 

54 :param checkpoint_path: If a checkpoint path is given, values like neural network weights stored in this 

55 checkpoint will be restored. 

56 

57 :param n_inputs: If no train_data is given, for example when an already trained network is restored from a 

58 checkpoint, the number of of inputs needs to be given. This should be equal to the number of 

59 columns of the train_data the network was originally trained with. 

60 :param read_only: if True, no output is writen 

61 """ 

        # Parameters:
        self.p = parameters
        self.n_inputs = n_inputs
        # if seed is provided, weights and biases are fixed to ensure reproducibility
        self.seed = seed
        tf.random.set_random_seed(self.seed)
        self.debug = debug
        if not read_only:
            self.p.save()
            print(
                "Output files are saved to {}".format(self.p.main_path),
                "as defined in 'main_path' in the parameters.",
            )

        # print("Disabling eager execution")
        # tf.compat.v1.disable_eager_execution()

        self.train_data = train_data
        self.validation_data = validation_data

        self._prepare_data()

        self.graph = tf.Graph()
        with self.graph.as_default():
            self.global_step = tf.train.create_global_step()

            self._setup_data_iterator()

            self._setup_network()

            with tf.name_scope("cost"):
                self.cost = 0
                self._setup_cost()
                tf.summary.scalar("combined_cost", self.cost)

            # Setup Optimizer:
            self.optimizer = tf.train.AdamOptimizer(self.p.learning_rate)
            self.gradients = self.optimizer.compute_gradients(self.cost)
            self.optimize = self.optimizer.apply_gradients(
                self.gradients, global_step=self.global_step
            )

            self.merged_summaries = tf.summary.merge_all()

            # Setup Session
            if self.p.gpu_memory_fraction == 0:
                gpu_options = tf.GPUOptions(allow_growth=True)
            else:
                gpu_options = tf.GPUOptions(
                    per_process_gpu_memory_fraction=self.p.gpu_memory_fraction
                )
            self.sess = tf.Session(
                config=tf.ConfigProto(gpu_options=gpu_options), graph=self.graph
            )
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(
                self.data_iterator.initializer,
                feed_dict={
                    p: d for p, d in zip(self.data_placeholders, self.train_data)
                },
            )
            if not read_only:
                self.train_writer = tf.summary.FileWriter(
                    os.path.join(self.p.main_path, "train"), self.sess.graph
                )
                if self.validation_data is not None:
                    self.validation_writer = tf.summary.FileWriter(
                        os.path.join(self.p.main_path, "validation"), self.sess.graph
                    )
            self.saver = tf.train.Saver(max_to_keep=100)

            # load Checkpoint
            if checkpoint_path:
                self.saver.restore(self.sess, checkpoint_path)
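
    # A minimal usage sketch (hypothetical names; assumes ``data`` is a 2d
    # float32 numpy array of training points and that the defaults of
    # ``Parameters`` are acceptable apart from the output directory):
    #
    #     params = Parameters()
    #     params.main_path = "runs/run0"  # hypothetical output path
    #     ae = Autoencoder(params, train_data=data)
    #     ae.train()
    #     low_d = ae.encode(data)
    #     ae.close()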

    def _prepare_data(self):
        if self.train_data is None:
            assert (
                self.n_inputs is not None
            ), "If no train_data is given, n_inputs needs to be given"
            self.train_data = [np.zeros((3, self.n_inputs), dtype=np.float32)]
        elif isinstance(self.train_data, np.ndarray):
            self.train_data = [self.train_data.astype(np.float32)]
        elif isinstance(self.train_data, (list, tuple)):
            self.train_data = [dat.astype(np.float32) for dat in self.train_data]
        elif isinstance(self.train_data, TFRecordDataset):
            pass
        else:
            raise ValueError(
                "{} is not supported as input type for train_data".format(
                    type(self.train_data)
                )
            )

        if self.validation_data is not None:
            self.validation_data = self.validation_data.astype(np.float32)
            # Todo: allow lists of validation data

    def _setup_network(self):
        self.inputs = self.data_iterator.get_next()
        self.main_inputs = self.inputs[0]
        self.main_inputs = tf.placeholder_with_default(
            self.main_inputs, self.main_inputs.shape
        )
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        self.generated = self._generate(self.latent)
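
    # Note on ``_setup_network``: wrapping ``main_inputs`` and ``latent`` in
    # ``tf.placeholder_with_default`` means the graph normally consumes batches
    # from the data iterator, while ``encode()`` and ``generate()`` can
    # override these tensors through ``feed_dict`` to project arbitrary data.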

    def _setup_data_iterator(self):
        self.data_placeholders = tuple(
            tf.placeholder(dat.dtype, dat.shape) for dat in self.train_data
        )
        self.data_set = tf.data.Dataset.from_tensor_slices(self.data_placeholders)
        self.data_set = self.data_set.shuffle(buffer_size=len(self.train_data[0]))
        self.data_set = self.data_set.repeat()
        self.data_set = self.data_set.batch(self.p.batch_size)
        if not tf.executing_eagerly():
            self.data_iterator = self.data_set.make_initializable_iterator()
        else:
            self.data_iterator = self.data_set
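
    # Note on ``_setup_data_iterator``: the dataset is built from placeholders
    # that are fed once with the full training arrays when the iterator is
    # initialized in ``__init__``, instead of embedding the arrays as constants
    # in the graph. ``shuffle`` uses the full data length as buffer size and
    # ``repeat()`` cycles through the data indefinitely.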

    def _encode(self, inputs):
        with tf.name_scope("encoder"):
            if self.p.periodicity < float("inf"):
                if self.p.periodicity != 2 * pi:
                    inputs = inputs / self.p.periodicity * 2 * pi
                self.unit_circle_inputs = tf.concat([tf.sin(inputs), tf.cos(inputs)], 1)
                current_layer = self.unit_circle_inputs
            else:
                current_layer = inputs

            assert (
                len(self.p.n_neurons) == len(self.p.activation_functions) - 1
            ), "you need one activation function more than layers given in n_neurons"
            for i, (n_neurons, act_fun) in enumerate(
                zip(self.p.n_neurons, self.p.activation_functions[1:])
            ):
                if act_fun:
                    act_fun = getattr(tf.nn, act_fun)
                else:
                    act_fun = None
                variable_summaries("activation{}".format(i), current_layer, debug=True)
                dense = tf.layers.Dense(
                    n_neurons,
                    activation=act_fun,
                    kernel_initializer=tf.variance_scaling_initializer(),
                    kernel_regularizer=self.regularizer,
                    bias_initializer=tf.random_normal_initializer(0.1, 0.05),
                )
                current_layer = dense(current_layer)
                add_layer_summaries(dense, debug=self.debug)
            return current_layer
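
    # Note on ``_encode``: for periodic inputs (e.g. dihedral angles) each
    # value x is rescaled to radians if necessary and then represented by the
    # pair (sin x, cos x), so values that are close on the torus are also close
    # in the network's input space. Non-periodic data (periodicity == inf) is
    # passed through unchanged.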

    def _generate(self, inputs):
        with tf.name_scope("generator"):
            current_layer = inputs
            if self.p.periodicity < float("inf"):
                n_neurons_with_inputs = [
                    self.main_inputs.shape[1] * 2
                ] + self.p.n_neurons
            else:
                n_neurons_with_inputs = [self.main_inputs.shape[1]] + self.p.n_neurons
            for n_neurons, act_fun in zip(
                n_neurons_with_inputs[-2::-1], self.p.activation_functions[-2::-1]
            ):
                if act_fun:
                    act_fun = getattr(tf.nn, act_fun)
                else:
                    act_fun = None
                current_layer = tf.layers.dense(
                    current_layer,
                    n_neurons,
                    activation=act_fun,
                    kernel_initializer=tf.variance_scaling_initializer(),
                    kernel_regularizer=self.regularizer,
                    bias_initializer=tf.random_normal_initializer(0.1, 0.05),
                )
            if self.p.periodicity < float("inf"):
                split = self.main_inputs.shape[1]
                current_layer = tf.atan2(
                    current_layer[:, :split], current_layer[:, split:]
                )
                if self.p.periodicity != 2 * pi:
                    current_layer = current_layer / (2 * pi) * self.p.periodicity
            return current_layer
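
    # Note on ``_generate``: the decoder mirrors the encoder layer sizes in
    # reverse. In the periodic case its output has twice as many units as
    # there are input angles; ``tf.atan2`` folds each (sin, cos) pair back into
    # a single angle, which is then rescaled from radians to the original
    # periodicity if needed.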

    def _setup_cost(self):
        self._auto_cost()
        self._center_cost()
        self._l2_reg_cost()

    def _auto_cost(self):
        if self.p.auto_cost_scale is not None:
            if self.p.auto_cost_variant == "mean_square":
                auto_cost = tf.reduce_mean(
                    tf.square(
                        periodic_distance(
                            self.main_inputs, self.generated, self.p.periodicity
                        )
                    )
                )
            elif self.p.auto_cost_variant == "mean_abs":
                auto_cost = tf.reduce_mean(
                    tf.abs(
                        periodic_distance(
                            self.main_inputs, self.generated, self.p.periodicity
                        )
                    )
                )
            elif self.p.auto_cost_variant == "mean_norm":
                auto_cost = tf.reduce_mean(
                    tf.norm(
                        periodic_distance(
                            self.main_inputs, self.generated, self.p.periodicity
                        ),
                        axis=1,
                    )
                )
            else:
                raise ValueError(
                    "auto_cost_variant {} not available".format(
                        self.p.auto_cost_variant
                    )
                )
            tf.summary.scalar("auto_cost", auto_cost)
            if self.p.auto_cost_scale != 0:
                self.cost += self.p.auto_cost_scale * auto_cost

    def _l2_reg_cost(self):
        if self.p.l2_reg_constant is not None:
            reg_cost = tf.losses.get_regularization_loss()
            tf.summary.scalar("l2_reg_cost", reg_cost)
            if self.p.l2_reg_constant != 0:
                self.cost += reg_cost

    def _center_cost(self):
        if self.p.center_cost_scale is not None:
            center_cost = tf.reduce_mean(tf.square(self.latent))
            tf.summary.scalar("center_cost", center_cost)
            if self.p.center_cost_scale != 0:
                self.cost += self.p.center_cost_scale * center_cost
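
    # Note on the cost terms: the combined cost is the sum of up to three
    # contributions: the reconstruction (auto) cost scaled by
    # ``auto_cost_scale``, the center cost (mean squared norm of the latent
    # points) scaled by ``center_cost_scale``, and the L2 regularization loss
    # collected from the kernel regularizers. A scale of 0 keeps the summary
    # but excludes the term from the cost; a value of None skips the term
    # entirely.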

    def encode(self, inputs):
        """
        Projects high-dimensional data to a low-dimensional space using the encoder part of the autoencoder.

        :param inputs: 2d numpy array with the same number of columns as the used train_data
        :return: 2d numpy array with the points projected to the low-dimensional space. The number of columns is equal
            to the number of neurons in the bottleneck layer of the autoencoder.
        """
        latents = []
        batches = np.array_split(inputs, max(1, int(len(inputs) / 2048)))
        for batch in batches:
            latent = self.sess.run(self.latent, feed_dict={self.main_inputs: batch})
            latents.append(latent)
        latents = np.concatenate(latents, axis=0)
        return latents
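
    # Example for ``encode`` (hypothetical shapes): with a bottleneck of two
    # neurons and ``highd`` of shape (n_points, n_inputs),
    # ``low_d = autoencoder.encode(highd)`` returns an array of shape
    # (n_points, 2). Inputs are split into batches of roughly 2048 rows.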

    def generate(self, latent):
        """
        Generates new high-dimensional points based on given low-dimensional points using the decoder part of the
        autoencoder.

        :param latent: 2d numpy array containing points in the low-dimensional space. The number of columns must be
            equal to the number of neurons in the bottleneck layer of the autoencoder.
        :return: 2d numpy array containing points in the high-dimensional space.
        """
        generateds = []
        batches = np.array_split(latent, max(1, int(len(latent) / 2048)))
        for batch in batches:
            generated = self.sess.run(self.generated, feed_dict={self.latent: batch})
            generateds.append(generated)
        generated = np.concatenate(generateds, axis=0)
        return generated
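
    # ``generate`` runs only the decoder, because feeding ``self.latent``
    # overrides the encoder output, so, e.g.,
    # ``autoencoder.generate(autoencoder.encode(highd))`` (hypothetical name)
    # returns an approximate reconstruction of ``highd``.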

    def _random_batch(self, data, batch_size=None):
        batch_size = batch_size or self.p.batch_size
        batch = data[np.random.choice(len(data), batch_size, replace=False)]
        return batch

    def train(self):
        """
        Train the autoencoder as specified in the parameters object.
        """
        for i in tqdm(range(self.p.n_steps)):
            # if a fixed seed is used we need the summaries at global_step == 1
            # to run unittests on them
            if self.seed is not None:
                # run session and get value summaries
                _, summary_values = self.sess.run(
                    (self.optimize, self.merged_summaries)
                )

                # get the current global step
                x = tf.train.get_global_step(self.graph)
                x = x.eval(session=self.sess)

                # read the current Summary
                values = {}
                values["global_step"] = x
                summary_proto = tf.Summary().FromString(summary_values)
                for entry in summary_proto.value:
                    if "histogram" in entry.tag:
                        values[entry.tag] = entry.histo
                    else:
                        values[entry.tag] = entry.simple_value
                return values

            if (i + 1) % self.p.summary_step == 0:
                _, summary_values = self.sess.run(
                    (self.optimize, self.merged_summaries)
                )
                self.train_writer.add_summary(summary_values, self._step())
                if self.validation_data is not None:
                    summary_values = self.sess.run(
                        self.merged_summaries,
                        feed_dict={
                            self.main_inputs: self._random_batch(self.validation_data)
                        },
                    )
                    self.validation_writer.add_summary(summary_values, self._step())
            else:
                self.sess.run(self.optimize)

            if (self._step()) % self.p.checkpoint_step == 0:
                self.saver.save(
                    self.sess,
                    os.path.join(
                        self.p.main_path,
                        "checkpoints",
                        "step{}.ckpt".format(self._step()),
                    ),
                )
        else:
            self.saver.save(
                self.sess,
                os.path.join(
                    self.p.main_path, "checkpoints", "step{}.ckpt".format(self._step())
                ),
            )
            self.train_writer.flush()
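
    # Note on ``train``: summaries are written to ``main_path/train`` (and to
    # ``main_path/validation`` if validation data was given) every
    # ``summary_step`` steps, checkpoints are saved to ``main_path/checkpoints``
    # every ``checkpoint_step`` steps, and the ``else`` clause of the for loop
    # writes a final checkpoint and flushes the summary writer after the last
    # step.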

    def _step(self):
        return tf.train.global_step(self.sess, self.global_step)

    def profile(self):
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        for i in range(5):
            self.sess.run(self.optimize, options=options, run_metadata=run_metadata)

            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            with open(
                os.path.join(self.p.main_path, "timeline{}.json".format(i)), "w"
            ) as f:
                f.write(chrome_trace)
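
    # The ``timeline{i}.json`` files written by ``profile`` are in the Chrome
    # trace format and can be inspected by loading them at ``chrome://tracing``.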

    def close(self):
        """
        Close the TensorFlow session to free resources.
        """
        if "sess" in self.__dict__.keys():
            self.sess.close()
            try:
                tf_ops.dismantle_graph(
                    self.graph
                )  # not implemented in older versions of tensorflow
            except AttributeError:
                pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        self.close()
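
# Because ``__enter__`` and ``__exit__`` are defined, the class can also be
# used as a context manager, which closes the session automatically, e.g.
# (hypothetical names):
#
#     with Autoencoder(params, train_data=data) as ae:
#         ae.train()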