Coverage for encodermap/encodermap_tf1/angle_dihedral_cartesian_encodermap.py: 9%

171 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1from math import pi 

2 

3import numpy as np 

4import tensorflow.compat.v1 as tf 

5 

6from .autoencoder import Autoencoder 

7from .backmapping import ( 

8 chain_in_plane, 

9 dihedrals_to_cartesian_tf, 

10 guess_amide_H, 

11 guess_amide_O, 

12 merge_cartesians, 

13) 

14from .misc import distance_cost, pairwise_dist, periodic_distance, variable_summaries 

15from .moldata import MolData 

16from .parameters import ADCParameters 

17 

18 

class AngleDihedralCartesianEncoderMap(Autoencoder):
    """
    This EncoderMap variant is specially designed for protein conformations.
    During the training, the cartesian conformations of the backbone chain are reconstructed from backbone angles and
    dihedrals.
    This allows for a more sophisticated comparison of input conformations and generated conformations and improves
    the accuracy of generated conformations especially for large proteins.
    We achieve this with the cartesian_cost where we compare pairwise distances between atoms in cartesian coordinates
    in the input and generated conformations.
    """

    def __init__(self, *args, **kwargs):
        """
        :param parameters: ADCParameters object as defined in :class:`encodermap.encodermap_tf1.parameters.ADCParameters`

        :param train_data: the training data as a :class:`.MolData` object

        :param validation_data: not yet supported

        :param checkpoint_path: If a checkpoint path is given, values like neural network weights stored in this
            checkpoint will be restored.

        :param read_only: if True, no output is written
        """
        super().__init__(*args, **kwargs)
        # Sanity checks on what the base-class constructor stored.
        assert isinstance(self.p, ADCParameters)
        assert isinstance(self.train_moldata, MolData)

    def _prepare_data(self):
        """
        Keep the MolData object as ``self.train_moldata`` and replace ``self.train_data`` with the tuple
        ``(angles, dihedrals, central_cartesians)``.

        :raises AssertionError: if the second axis of lengths, angles or dihedrals is not 1, 2 or 3 entries shorter
            than the second axis of ``central_cartesians``, respectively
        :raises ValueError: if validation data was supplied (not supported yet)
        """
        self.train_moldata = self.train_data
        # Second axis of central_cartesians; presumably the atom axis of the backbone chain.
        n_central = self.train_moldata.central_cartesians.shape[1]
        assert self.train_moldata.lengths.shape[1] == n_central - 1
        assert self.train_moldata.angles.shape[1] == n_central - 2
        assert self.train_moldata.dihedrals.shape[1] == n_central - 3
        self.train_data = (
            self.train_moldata.angles,
            self.train_moldata.dihedrals,
            self.train_moldata.central_cartesians,
        )

        if self.validation_data is not None:
            raise ValueError("validation data not supported yet")  # Todo: add support

    def _setup_network(self):
        """
        Build the graph from the input batch to the reconstructed cartesian coordinates.

        Creates ``inputs``, ``main_inputs``, ``regularizer``, ``latent``, ``generated``, ``generated_angles``,
        ``generated_dihedrals``, ``chain_in_plane``, ``cartesian``, ``amide_H_cartesian``, ``amide_O_cartesian``,
        ``cartesian_with_guessed_atoms``, ``input_cartesian_pairwise_dist``, ``gen_cartesian_pairwise_dist`` and
        ``clashes`` on the instance.
        """
        # Presumably batches of the (angles, dihedrals, central_cartesians) tuple assembled in _prepare_data.
        self.inputs = self.data_iterator.get_next()
        if self.p.use_backbone_angles:
            self.main_inputs = tf.concat([self.inputs[0], self.inputs[1]], axis=1)
        else:
            self.main_inputs = self.inputs[1]
        # placeholder_with_default lets callers feed their own values instead of the iterator output.
        self.main_inputs = tf.placeholder_with_default(
            self.main_inputs, self.main_inputs.shape
        )
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        # generate() feeds self.latent directly, hence the placeholder around the encoder output.
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        self.generated = self._generate(self.latent)

        if self.p.use_backbone_angles:
            # The decoder output is [angles | dihedrals], mirroring the concat order of main_inputs.
            n_angles = self.inputs[0].shape[1]
            self.generated_angles = self.generated[:, :n_angles]
            self.generated_dihedrals = self.generated[:, n_angles:]
        else:
            # Angles are not learned: repeat the training-set mean angles for every generated sample.
            self.generated_dihedrals = self.generated
            self.generated_angles = tf.tile(
                np.expand_dims(np.mean(self.train_moldata.angles, axis=0), axis=0),
                [tf.shape(self.generated_dihedrals)[0], 1],
            )

        # Bond lengths are always taken as the training-set mean.
        mean_lengths = np.expand_dims(
            np.mean(self.train_moldata.lengths, axis=0), axis=0
        )
        self.chain_in_plane = chain_in_plane(mean_lengths, self.generated_angles)
        # NOTE(review): the pi offset presumably matches the dihedral convention of
        # dihedrals_to_cartesian_tf - confirm against the backmapping module.
        self.cartesian = dihedrals_to_cartesian_tf(
            self.generated_dihedrals + pi, self.chain_in_plane
        )

        # generate() fetches cartesian_with_guessed_atoms, so the guessed amide atoms must be part of the graph.
        central_atom_names = self.train_moldata.central_atoms.names
        self.amide_H_cartesian = guess_amide_H(self.cartesian, central_atom_names)
        self.amide_O_cartesian = guess_amide_O(self.cartesian, central_atom_names)
        self.cartesian_with_guessed_atoms = merge_cartesians(
            self.cartesian,
            central_atom_names,
            self.amide_H_cartesian,
            self.amide_O_cartesian,
        )

        # Same atom subset for input and generated structures so the pairwise distances are comparable.
        pwd_atoms = slice(
            self.p.cartesian_pwd_start,
            self.p.cartesian_pwd_stop,
            self.p.cartesian_pwd_step,
        )
        self.input_cartesian_pairwise_dist = pairwise_dist(
            self.inputs[2][:, pwd_atoms], flat=True
        )
        self.gen_cartesian_pairwise_dist = pairwise_dist(
            self.cartesian[:, pwd_atoms], flat=True
        )

        # Per-sample count of pairwise distances below 1 among the reconstructed central atoms.
        self.clashes = tf.count_nonzero(
            pairwise_dist(self.cartesian, flat=True) < 1, axis=1, dtype=tf.float32
        )
        tf.summary.scalar("clashes", tf.reduce_mean(self.clashes))

    def _setup_cost(self):
        """
        Add every cost term to ``self.cost``.

        ``_center_cost`` and ``_l2_reg_cost`` are not defined in this class; they are presumably inherited from
        :class:`Autoencoder`.
        """
        self._dihedral_cost()
        self._angle_cost()
        self._cartesian_cost()

        self._distance_cost()
        self._cartesian_distance_cost()
        self._center_cost()
        self._l2_reg_cost()

    @staticmethod
    def _reduce_by_variant(difference, variant, variant_parameter_name):
        """
        Reduce a per-sample difference tensor to a scalar cost.

        :param difference: tensor of deviations between input and generated values
        :param variant: one of ``"mean_square"``, ``"mean_abs"`` or ``"mean_norm"``
        :param variant_parameter_name: name of the parameter holding ``variant``; only used in the error message
        :return: scalar tensor
        :raises ValueError: if ``variant`` is none of the supported names
        """
        if variant == "mean_square":
            return tf.reduce_mean(tf.square(difference))
        if variant == "mean_abs":
            return tf.reduce_mean(tf.abs(difference))
        if variant == "mean_norm":
            return tf.reduce_mean(tf.norm(difference, axis=1))
        raise ValueError(
            "{} {} not available".format(variant_parameter_name, variant)
        )

    def _dihedral_cost(self):
        """
        Add the periodic deviation between input and generated dihedrals to ``self.cost``.

        Nothing happens if ``dihedral_cost_scale`` is None. With a scale of 0 the cost is only written to the
        summary.

        :raises ValueError: if ``dihedral_cost_variant`` is not supported
        """
        if self.p.dihedral_cost_scale is None:
            return
        dihedral_cost = self._reduce_by_variant(
            periodic_distance(
                self.inputs[1], self.generated_dihedrals, self.p.periodicity
            ),
            self.p.dihedral_cost_variant,
            "dihedral_cost_variant",
        )
        dihedral_cost /= self.p.dihedral_cost_reference
        tf.summary.scalar("dihedral_cost", dihedral_cost)
        if self.p.dihedral_cost_scale != 0:
            self.cost += self.p.dihedral_cost_scale * dihedral_cost

    def _angle_cost(self):
        """
        Add the periodic deviation between input and generated angles to ``self.cost``.

        Nothing happens if ``angle_cost_scale`` is None. With a scale of 0 the cost is only written to the summary.

        :raises ValueError: if ``angle_cost_variant`` is not supported
        """
        if self.p.angle_cost_scale is None:
            return
        angle_cost = self._reduce_by_variant(
            periodic_distance(
                self.inputs[0], self.generated_angles, self.p.periodicity
            ),
            self.p.angle_cost_variant,
            "angle_cost_variant",
        )
        angle_cost /= self.p.angle_cost_reference
        tf.summary.scalar("angle_cost", angle_cost)
        if self.p.angle_cost_scale != 0:
            self.cost += self.p.angle_cost_scale * angle_cost

    def _distance_cost(self):
        """
        Add the distance cost between ``main_inputs`` and the latent space to ``self.cost``.

        Nothing happens if ``distance_cost_scale`` is None. With a scale of 0 the cost is only written to the
        summary (under the name ``sigmoid_loss``).
        """
        if self.p.distance_cost_scale is None:
            return
        dist_cost = distance_cost(
            self.main_inputs,
            self.latent,
            *self.p.dist_sig_parameters,
            self.p.periodicity,
        )
        tf.summary.scalar("sigmoid_loss", dist_cost)
        if self.p.distance_cost_scale != 0:
            self.cost += self.p.distance_cost_scale * dist_cost

    def _cartesian_distance_cost(self):
        """
        Add the distance cost between the input cartesian pairwise distances and the latent space to ``self.cost``.

        Nothing happens if ``cartesian_distance_cost_scale`` is None. With a scale of 0 the cost is only written to
        the summary.
        """
        if self.p.cartesian_distance_cost_scale is None:
            return
        dist_cost = distance_cost(
            self.input_cartesian_pairwise_dist,
            self.latent,
            *self.p.cartesian_dist_sig_parameters,
            float("inf"),  # passed where _distance_cost passes the periodicity
        )
        tf.summary.scalar("cartesian_distance_cost", dist_cost)
        if self.p.cartesian_distance_cost_scale != 0:
            self.cost += self.p.cartesian_distance_cost_scale * dist_cost

    def _cartesian_cost(self):
        """
        Add the deviation between input and generated cartesian pairwise distances to ``self.cost``.

        Nothing happens if ``cartesian_cost_scale`` is None. With a scale of 0 the cost is only written to the
        summary. If ``cartesian_cost_scale_soft_start`` is ``(a, b)`` with ``a`` not None, the scale is 0 while the
        global step is below ``a``, the full ``cartesian_cost_scale`` once it is above ``b``, and linearly
        interpolated in between.

        :raises ValueError: if ``cartesian_cost_variant`` is not supported
        """
        if self.p.cartesian_cost_scale is None:
            return
        cartesian_cost = self._reduce_by_variant(
            self.input_cartesian_pairwise_dist - self.gen_cartesian_pairwise_dist,
            self.p.cartesian_cost_variant,
            "cartesian_cost_variant",
        )
        cartesian_cost /= self.p.cartesian_cost_reference
        tf.summary.scalar("cartesian_cost", cartesian_cost)
        if self.p.cartesian_cost_scale == 0:
            return

        if self.p.cartesian_cost_scale_soft_start[0] is None:
            self.cost += self.p.cartesian_cost_scale * cartesian_cost
            return

        a = self.p.cartesian_cost_scale_soft_start[0]
        b = self.p.cartesian_cost_scale_soft_start[1]
        cost_scale = tf.case(
            [
                (
                    tf.less(self.global_step, a),
                    lambda: tf.constant(0, tf.float32),
                ),  # if global step is less than a
                (
                    tf.greater(self.global_step, b),
                    lambda: tf.constant(self.p.cartesian_cost_scale, tf.float32),
                ),  # if global step is greater than b
            ],
            # otherwise: linear ramp from 0 at step a to the full scale at step b
            default=lambda: self.p.cartesian_cost_scale
            / (b - a)
            * (tf.cast(self.global_step, tf.float32) - a),
        )
        tf.summary.scalar("cartesian_cost_scale", cost_scale)
        self.cost += cost_scale * cartesian_cost

    def generate(self, latent, quantity=None):
        """
        Evaluate the decoder part of the network for the given latent coordinates.

        The latent points are processed in ``max(1, int(len(latent) / 2048))`` batches.

        :param latent: array of latent space coordinates, one point per row
        :param quantity: optional tensor (or other fetch accepted by ``sess.run``) to evaluate instead of the default
            outputs
        :return: if ``quantity`` is None the tuple ``(angles, dihedrals, cartesians)`` where ``cartesians`` is
            ``cartesian_with_guessed_atoms``; otherwise the values of ``quantity`` concatenated along axis 0
        """
        batches = np.array_split(latent, max(1, int(len(latent) / 2048)))

        if quantity is not None:
            results = [
                self.sess.run(quantity, feed_dict={self.latent: batch})
                for batch in batches
            ]
            return np.concatenate(results, axis=0)

        fetches = (
            self.generated_angles,
            self.generated_dihedrals,
            self.cartesian_with_guessed_atoms,
        )
        per_batch = [
            self.sess.run(fetches, feed_dict={self.latent: batch})
            for batch in batches
        ]
        # Transpose [(angles, dihedrals, cartesians), ...] into three per-quantity lists and join each.
        all_angles, all_dihedrals, all_cartesians = (
            np.concatenate(parts, axis=0) for parts in zip(*per_batch)
        )
        return all_angles, all_dihedrals, all_cartesians

346 

347 

class AngleDihedralCartesianEncoderMapDummy(AngleDihedralCartesianEncoderMap):
    """
    Variant whose generated angles and dihedrals are always the training-set means.

    The encoder and decoder are still built (``latent`` and ``generated`` exist), but ``generated_angles`` and
    ``generated_dihedrals`` do not depend on them.
    """

    def _setup_network(self):
        """
        Build the graph with constant mean angles and dihedrals in place of the decoder output.

        Creates ``inputs``, ``main_inputs``, ``regularizer``, ``latent``, ``generated``, ``generated_angles``,
        ``generated_dihedrals``, ``chain_in_plane``, ``cartesian``, ``amide_H_cartesian``, ``amide_O_cartesian``,
        ``cartesian_with_guessed_atoms``, ``input_cartesian_pairwise_dist``, ``gen_cartesian_pairwise_dist`` and
        ``clashes`` on the instance.
        """
        self.inputs = self.data_iterator.get_next()
        if self.p.use_backbone_angles:
            self.main_inputs = tf.concat([self.inputs[0], self.inputs[1]], axis=1)
        else:
            self.main_inputs = self.inputs[1]
        # placeholder_with_default lets callers feed their own values instead of the iterator output.
        self.main_inputs = tf.placeholder_with_default(
            self.main_inputs, self.main_inputs.shape
        )
        # Same regularizer as the parent class. tf.contrib (used here before) no longer exists in TensorFlow 2.
        # NOTE(review): tf.contrib.layers.l2_regularizer was based on tf.nn.l2_loss (sum(w ** 2) / 2), whereas
        # keras l2 uses sum(w ** 2); the penalty for a given l2_reg_constant therefore doubles - confirm this
        # is acceptable.
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        self.generated = self._generate(self.latent)

        batch_size = tf.shape(self.main_inputs)[0]

        def tiled_training_mean(values):
            # Mean over the training set, repeated once per sample of the current batch.
            return tf.tile(
                np.expand_dims(np.mean(values, axis=0), axis=0), [batch_size, 1]
            )

        self.generated_dihedrals = tiled_training_mean(self.train_moldata.dihedrals)
        self.generated_angles = tiled_training_mean(self.train_moldata.angles)

        mean_lengths = np.expand_dims(
            np.mean(self.train_moldata.lengths, axis=0), axis=0
        )
        self.chain_in_plane = chain_in_plane(mean_lengths, self.generated_angles)
        # NOTE(review): the pi offset presumably matches the dihedral convention of
        # dihedrals_to_cartesian_tf - confirm against the backmapping module.
        self.cartesian = dihedrals_to_cartesian_tf(
            self.generated_dihedrals + pi, self.chain_in_plane
        )

        central_atom_names = self.train_moldata.central_atoms.names
        self.amide_H_cartesian = guess_amide_H(self.cartesian, central_atom_names)
        self.amide_O_cartesian = guess_amide_O(self.cartesian, central_atom_names)
        self.cartesian_with_guessed_atoms = merge_cartesians(
            self.cartesian,
            central_atom_names,
            self.amide_H_cartesian,
            self.amide_O_cartesian,
        )

        # Same atom subset for input and generated structures so the pairwise distances are comparable.
        pwd_atoms = slice(
            self.p.cartesian_pwd_start,
            self.p.cartesian_pwd_stop,
            self.p.cartesian_pwd_step,
        )
        self.input_cartesian_pairwise_dist = pairwise_dist(
            self.inputs[2][:, pwd_atoms], flat=True
        )
        self.gen_cartesian_pairwise_dist = pairwise_dist(
            self.cartesian[:, pwd_atoms], flat=True
        )

        # Per-sample count of pairwise distances below 1 among the reconstructed central atoms.
        self.clashes = tf.count_nonzero(
            pairwise_dist(self.cartesian, flat=True) < 1, axis=1, dtype=tf.float32
        )
        tf.summary.scalar("clashes", tf.reduce_mean(self.clashes))