Coverage for encodermap/encodermap_tf1/autoencoder.py: 12%

217 statements  

coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

import tensorflow

try:
    from tensorflow.python.data.ops.readers import TFRecordDatasetV2 as TFRecordDataset
except:
    from tensorflow.python.data.ops.readers import TFRecordDatasetV1 as TFRecordDataset

import os
from math import pi

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.python.client import timeline
from tensorflow.python.framework import ops as tf_ops
from tqdm import tqdm

from .backmapping import dihedrals_to_cartesian_tf
from .misc import (
    add_layer_summaries,
    distance_cost,
    pairwise_dist,
    periodic_distance,
    variable_summaries,
)
from .parameters import Parameters


class Autoencoder:
    def __init__(
        self,
        parameters,
        train_data=None,
        validation_data=None,
        checkpoint_path=None,
        n_inputs=None,
        read_only=False,
        seed=None,
        debug=False,
    ):
        """
        :param parameters: Parameters object as defined in :class:`encodermap.encodermap_tf1.parameters.Parameters`

        :param train_data: 2d numpy array where each row is treated as a training point

        :param validation_data: A 2d numpy array. This data will only be used to calculate a validation error during
                                training. It will not be used for training.

        :param checkpoint_path: If a checkpoint path is given, values like neural network weights stored in this
                                checkpoint will be restored.

        :param n_inputs: If no train_data is given, for example when an already trained network is restored from a
                         checkpoint, the number of inputs needs to be given. This should be equal to the number of
                         columns of the train_data the network was originally trained with.
        :param read_only: if True, no output is written
        """
        # Parameters:
        self.p = parameters
        self.n_inputs = n_inputs
        # if a seed is provided, weights and biases are fixed to ensure reproducibility
        self.seed = seed
        tf.random.set_random_seed(self.seed)
        self.debug = debug
        if not read_only:
            self.p.save()
            print(
                "Output files are saved to {}".format(self.p.main_path),
                "as defined in 'main_path' in the parameters.",
            )

        print("Disabling eager execution")
        tf.compat.v1.disable_eager_execution()

        self.train_data = train_data
        self.validation_data = validation_data

        self._prepare_data()

        self.graph = tf.Graph()
        with self.graph.as_default():
            self.global_step = tf.train.create_global_step()

            self._setup_data_iterator()

            self._setup_network()

            with tf.name_scope("cost"):
                self.cost = 0
                self._setup_cost()
                tf.summary.scalar("combined_cost", self.cost)

            # Setup Optimizer:
            self.optimizer = tf.train.AdamOptimizer(self.p.learning_rate)
            gradients = self.optimizer.compute_gradients(self.cost)
            self.optimize = self.optimizer.apply_gradients(
                gradients, global_step=self.global_step
            )

            self.merged_summaries = tf.summary.merge_all()

            # Setup Session
            if self.p.gpu_memory_fraction == 0:
                gpu_options = tf.GPUOptions(allow_growth=True)
            else:
                gpu_options = tf.GPUOptions(
                    per_process_gpu_memory_fraction=self.p.gpu_memory_fraction
                )
            self.sess = tf.Session(
                config=tf.ConfigProto(gpu_options=gpu_options), graph=self.graph
            )
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(
                self.data_iterator.initializer,
                feed_dict={
                    p: d for p, d in zip(self.data_placeholders, self.train_data)
                },
            )
            if not read_only:
                self.train_writer = tf.summary.FileWriter(
                    os.path.join(self.p.main_path, "train"), self.sess.graph
                )
                if self.validation_data is not None:
                    self.validation_writer = tf.summary.FileWriter(
                        os.path.join(self.p.main_path, "validation"), self.sess.graph
                    )
            self.saver = tf.train.Saver(max_to_keep=100)

            # load Checkpoint
            if checkpoint_path:
                self.saver.restore(self.sess, checkpoint_path)

    def _prepare_data(self):
        if self.train_data is None:
            assert (
                self.n_inputs is not None
            ), "If no train_data is given, n_inputs needs to be given"
            self.train_data = [np.zeros((3, self.n_inputs), dtype=np.float32)]
        elif isinstance(self.train_data, np.ndarray):
            self.train_data = [self.train_data.astype(np.float32)]
        elif isinstance(self.train_data, (list, tuple)):
            self.train_data = [dat.astype(np.float32) for dat in self.train_data]
        elif isinstance(self.train_data, TFRecordDataset):
            pass
        else:
            raise ValueError(
                "{} is not supported as input type for train_data".format(
                    type(self.train_data)
                )
            )

        if self.validation_data is not None:
            self.validation_data = self.validation_data.astype(np.float32)
            # Todo: allow lists of validation data

    def _setup_network(self):
        self.inputs = self.data_iterator.get_next()
        self.main_inputs = self.inputs[0]
        self.main_inputs = tf.placeholder_with_default(
            self.main_inputs, self.main_inputs.shape
        )
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        self.generated = self._generate(self.latent)

    def _setup_data_iterator(self):
        self.data_placeholders = tuple(
            tf.placeholder(dat.dtype, dat.shape) for dat in self.train_data
        )
        self.data_set = tf.data.Dataset.from_tensor_slices(self.data_placeholders)
        self.data_set = self.data_set.shuffle(buffer_size=len(self.train_data[0]))
        self.data_set = self.data_set.repeat()
        self.data_set = self.data_set.batch(self.p.batch_size)
        self.data_iterator = self.data_set.make_initializable_iterator()
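        # Note: get_next() on this iterator yields one batch per array in
        # self.train_data, each with a leading dimension of p.batch_size, and the
        # shuffle buffer holds the complete training set.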

    def _encode(self, inputs):
        with tf.name_scope("encoder"):
            if self.p.periodicity < float("inf"):
                if self.p.periodicity != 2 * pi:
                    inputs = inputs / self.p.periodicity * 2 * pi
                self.unit_circle_inputs = tf.concat([tf.sin(inputs), tf.cos(inputs)], 1)
                current_layer = self.unit_circle_inputs
            else:
                current_layer = inputs

            assert (
                len(self.p.n_neurons) == len(self.p.activation_functions) - 1
            ), "you need one activation function more than layers given in n_neurons"
            for i, (n_neurons, act_fun) in enumerate(
                zip(self.p.n_neurons, self.p.activation_functions[1:])
            ):
                if act_fun:
                    act_fun = getattr(tf.nn, act_fun)
                else:
                    act_fun = None
                variable_summaries("activation{}".format(i), current_layer, debug=True)
                dense = tf.layers.Dense(
                    n_neurons,
                    activation=act_fun,
                    kernel_initializer=tf.variance_scaling_initializer(),
                    kernel_regularizer=self.regularizer,
                    bias_initializer=tf.random_normal_initializer(0.1, 0.05),
                )
                current_layer = dense(current_layer)
                add_layer_summaries(dense, debug=self.debug)
            return current_layer

    def _generate(self, inputs):
        with tf.name_scope("generator"):
            current_layer = inputs
            if self.p.periodicity < float("inf"):
                n_neurons_with_inputs = [
                    self.main_inputs.shape[1] * 2
                ] + self.p.n_neurons
            else:
                n_neurons_with_inputs = [self.main_inputs.shape[1]] + self.p.n_neurons
            for n_neurons, act_fun in zip(
                n_neurons_with_inputs[-2::-1], self.p.activation_functions[-2::-1]
            ):
                if act_fun:
                    act_fun = getattr(tf.nn, act_fun)
                else:
                    act_fun = None
                current_layer = tf.layers.dense(
                    current_layer,
                    n_neurons,
                    activation=act_fun,
                    kernel_initializer=tf.variance_scaling_initializer(),
                    kernel_regularizer=self.regularizer,
                    bias_initializer=tf.random_normal_initializer(0.1, 0.05),
                )
            if self.p.periodicity < float("inf"):
                split = self.main_inputs.shape[1]
                current_layer = tf.atan2(
                    current_layer[:, :split], current_layer[:, split:]
                )
                if self.p.periodicity != 2 * pi:
                    current_layer = current_layer / (2 * pi) * self.p.periodicity
            return current_layer
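    # In the periodic case the encoder feeds (sin x, cos x) pairs of each input angle x
    # into the network, and the generator maps its final layer back to angles with
    # atan2, which returns values in (-pi, pi]; inputs whose periodicity differs from
    # 2*pi are rescaled to the 2*pi range on the way in and scaled back on the way out.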

    def _setup_cost(self):
        self._auto_cost()
        self._center_cost()
        self._l2_reg_cost()

    def _auto_cost(self):
        if self.p.auto_cost_scale is not None:
            if self.p.auto_cost_variant == "mean_square":
                auto_cost = tf.reduce_mean(
                    tf.square(
                        periodic_distance(
                            self.main_inputs, self.generated, self.p.periodicity
                        )
                    )
                )
            elif self.p.auto_cost_variant == "mean_abs":
                auto_cost = tf.reduce_mean(
                    tf.abs(
                        periodic_distance(
                            self.main_inputs, self.generated, self.p.periodicity
                        )
                    )
                )
            elif self.p.auto_cost_variant == "mean_norm":
                auto_cost = tf.reduce_mean(
                    tf.norm(
                        periodic_distance(
                            self.main_inputs, self.generated, self.p.periodicity
                        ),
                        axis=1,
                    )
                )
            else:
                raise ValueError(
                    "auto_cost_variant {} not available".format(
                        self.p.auto_cost_variant
                    )
                )
            tf.summary.scalar("auto_cost", auto_cost)
            if self.p.auto_cost_scale != 0:
                self.cost += self.p.auto_cost_scale * auto_cost

    def _l2_reg_cost(self):
        if self.p.l2_reg_constant is not None:
            reg_cost = tf.losses.get_regularization_loss()
            tf.summary.scalar("reg_cost", reg_cost)
            if self.p.l2_reg_constant != 0:
                self.cost += reg_cost

    def _center_cost(self):
        if self.p.center_cost_scale is not None:
            center_cost = tf.reduce_mean(tf.square(self.latent))
            tf.summary.scalar("center_cost", center_cost)
            if self.p.center_cost_scale != 0:
                self.cost += self.p.center_cost_scale * center_cost

    def encode(self, inputs):
        """
        Projects high dimensional data to a low dimensional space using the encoder part of the autoencoder.

        :param inputs: 2d numpy array with the same number of columns as the used train_data
        :return: 2d numpy array with the points projected to the low dimensional space. The number of columns is equal
                 to the number of neurons in the bottleneck layer of the autoencoder.
        """
        latents = []
        batches = np.array_split(inputs, max(1, int(len(inputs) / 2048)))
        for batch in batches:
            latent = self.sess.run(self.latent, feed_dict={self.main_inputs: batch})
            latents.append(latent)
        latents = np.concatenate(latents, axis=0)
        return latents

    def generate(self, latent):
        """
        Generates new high-dimensional points based on given low-dimensional points using the decoder part of the
        autoencoder.

        :param latent: 2d numpy array containing points in the low-dimensional space. The number of columns must be
                       equal to the number of neurons in the bottleneck layer of the autoencoder.
        :return: 2d numpy array containing points in the high-dimensional space.
        """
        generateds = []
        batches = np.array_split(latent, max(1, int(len(latent) / 2048)))
        for batch in batches:
            generated = self.sess.run(self.generated, feed_dict={self.latent: batch})
            generateds.append(generated)
        generated = np.concatenate(generateds, axis=0)
        return generated
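    # Hypothetical round trip, sketched here for illustration: for a 2d array `data`
    # with the same number of columns as the training data,
    #     low_d = ae.encode(data)      # (len(data), n_neurons[-1])
    #     recon = ae.generate(low_d)   # (len(data), data.shape[1])
    # where `ae` and `data` are placeholder names; both methods internally split their
    # input into chunks of roughly 2048 rows before running the session.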

    def _random_batch(self, data, batch_size=None):
        batch_size = batch_size or self.p.batch_size
        batch = data[np.random.choice(len(data), batch_size, replace=False)]
        return batch

    def train(self):
        """
        Train the autoencoder as specified in the parameters object.
        """
        for i in tqdm(range(self.p.n_steps)):
            # if a fixed seed is used we need the summaries at global_step == 1
            # to run unittests on them
            if self.seed is not None:
                # run session and get value summaries
                _, summary_values = self.sess.run(
                    (self.optimize, self.merged_summaries)
                )

                # get the current global step
                x = tf.train.get_global_step(self.graph)
                x = x.eval(session=self.sess)

                # read the current Summary
                values = {}
                values["global_step"] = x
                summary_proto = tf.Summary().FromString(summary_values)
                for entry in summary_proto.value:
                    if "histogram" in entry.tag:
                        values[entry.tag] = entry.histo
                    else:
                        values[entry.tag] = entry.simple_value
                return values

            if (i + 1) % self.p.summary_step == 0:
                # _, summary_values = self.sess.run((self.optimize, self.merged_summaries))
                _, summary_values = self.sess.run(
                    (self.optimize, self.merged_summaries)
                )
                self.train_writer.add_summary(summary_values, self._step())
                if self.validation_data is not None:
                    summary_values = self.sess.run(
                        self.merged_summaries,
                        feed_dict={
                            self.main_inputs: self._random_batch(self.validation_data)
                        },
                    )
                    self.validation_writer.add_summary(summary_values, self._step())
            else:
                self.sess.run(self.optimize)

            if (self._step()) % self.p.checkpoint_step == 0:
                self.saver.save(
                    self.sess,
                    os.path.join(
                        self.p.main_path,
                        "checkpoints",
                        "step{}.ckpt".format(self._step()),
                    ),
                )
        else:
            # for/else: save a final checkpoint once the loop finishes
            self.saver.save(
                self.sess,
                os.path.join(
                    self.p.main_path, "checkpoints", "step{}.ckpt".format(self._step())
                ),
            )
        self.train_writer.flush()

    def _step(self):
        return tf.train.global_step(self.sess, self.global_step)

    def profile(self):
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        for i in range(5):
            self.sess.run(self.optimize, options=options, run_metadata=run_metadata)

            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            with open(
                os.path.join(self.p.main_path, "timeline{}.json".format(i)), "w"
            ) as f:
                f.write(chrome_trace)

    def close(self):
        """
        Close tensorflow session to free resources.
        :return:
        """
        if "sess" in self.__dict__.keys():
            self.sess.close()
        try:
            tf_ops.dismantle_graph(
                self.graph
            )  # not implemented in older versions of tensorflow
        except AttributeError:
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        self.close()
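
# Hypothetical usage sketch (not part of the module): Autoencoder is a context
# manager, so a typical workflow, assuming `data` is a 2d float32 numpy array of
# training points and the Parameters defaults are acceptable, might look like:
#
#     from encodermap.encodermap_tf1.autoencoder import Autoencoder
#     from encodermap.encodermap_tf1.parameters import Parameters
#
#     params = Parameters()            # set params.main_path etc. as needed
#     with Autoencoder(params, train_data=data) as ae:
#         ae.train()                   # runs params.n_steps optimization steps
#         low_d = ae.encode(data)      # project to the bottleneck dimension
#         high_d = ae.generate(low_d)  # decode back to the input dimension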