Coverage for encodermap/encodermap_tf1/angle_dihedral_cartesian_encodermap.py: 10%

156 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# Standard Library Imports 

2from math import pi 

3 

4# Third Party Imports 

5import numpy as np 

6import tensorflow.compat.v1 as tf 

7 

8# Local Folder Imports 

9from .autoencoder import Autoencoder 

10from .backmapping import ( 

11 chain_in_plane, 

12 dihedrals_to_cartesian_tf, 

13 guess_amide_H, 

14 guess_amide_O, 

15 merge_cartesians, 

16) 

17from .misc import distance_cost, pairwise_dist, periodic_distance, variable_summaries 

18from .moldata import MolData 

19from .parameters import ADCParameters 

20 

21 

class AngleDihedralCartesianEncoderMap(Autoencoder):
    """
    This EncoderMap variant is specially designed for protein conformations.
    During the training, the cartesian conformations of the backbone chain are reconstructed from backbone angles and
    dihedrals.
    This allows for a more sophisticated comparison of input conformations and generated conformations and improves
    the accuracy of generated conformations especially for large proteins.
    We achieve this with the cartesian_cost where we compare pairwise distances between atoms in cartesian coordinates
    in the input and generated conformations.
    """

    def __init__(self, *args, **kwargs):
        """
        :param parameters: ADCParameters object as defined in :class:`.ADCParameters`

        :param train_data: the training data as a :class:`.MolData` object

        :param validation_data: not yet supported

        :param checkpoint_path: If a checkpoint path is given, values like neural network weights stored in this
            checkpoint will be restored.

        :param read_only: if True, no output is written
        """
        super().__init__(*args, **kwargs)
        assert isinstance(
            self.p, ADCParameters
        ), "parameters must be an ADCParameters instance"
        assert isinstance(
            self.train_moldata, MolData
        ), "train_data must be a MolData instance"

    def _prepare_data(self):
        """Turn the MolData training set into the tuple used as network input.

        The MolData object itself is kept in ``self.train_moldata``. ``self.train_data``
        is replaced by ``(angles, dihedrals, central_cartesians)``. ``_setup_network``
        reads ``self.inputs[0]``, ``[1]`` and ``[2]`` as angles, dihedrals and central
        cartesians respectively. This assumes the base class builds ``data_iterator``
        from this tuple in order (TODO confirm).

        :raises ValueError: if validation data was given (not supported yet).
        """
        # Fail fast, before any data is inspected or reassigned.
        if self.validation_data is not None:
            raise ValueError("validation data not supported yet")  # Todo: add support

        self.train_moldata = self.train_data
        # Check the type before accessing MolData attributes, so a wrong input
        # type gives a clear message instead of an AttributeError.
        assert isinstance(
            self.train_moldata, MolData
        ), "train_data must be a MolData instance"

        n_central = self.train_moldata.central_cartesians.shape[1]
        # Internal coordinates must be consistent with the central cartesians:
        # one fewer length, two fewer angles and three fewer dihedrals.
        assert self.train_moldata.lengths.shape[1] == n_central - 1, (
            "lengths.shape[1] must equal central_cartesians.shape[1] - 1"
        )
        assert self.train_moldata.angles.shape[1] == n_central - 2, (
            "angles.shape[1] must equal central_cartesians.shape[1] - 2"
        )
        assert self.train_moldata.dihedrals.shape[1] == n_central - 3, (
            "dihedrals.shape[1] must equal central_cartesians.shape[1] - 3"
        )

        self.train_data = (
            self.train_moldata.angles,
            self.train_moldata.dihedrals,
            self.train_moldata.central_cartesians,
        )

    def _setup_network(self):
        """Build the encoder/decoder graph and the backmapped cartesian tensors.

        The network input is the dihedrals. If ``use_backbone_angles`` is set, the
        input is angles and dihedrals concatenated along axis 1. The decoder output is
        split back into generated angles and dihedrals. Without backbone angles, the
        training-set mean angles are used for every sample.
        """
        self.inputs = self.data_iterator.get_next()
        if self.p.use_backbone_angles:
            self.main_inputs = tf.concat([self.inputs[0], self.inputs[1]], axis=1)
        else:
            self.main_inputs = self.inputs[1]
        # The default value comes from the iterator; a feed_dict may override it.
        self.main_inputs = tf.placeholder_with_default(
            self.main_inputs, self.main_inputs.shape
        )
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        # `generate` feeds latent points directly into this placeholder.
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        self.generated = self._generate(self.latent)

        if self.p.use_backbone_angles:
            # The decoder output is [angles | dihedrals] in the same order as main_inputs.
            n_angles = self.inputs[0].shape[1]
            self.generated_angles = self.generated[:, :n_angles]
            self.generated_dihedrals = self.generated[:, n_angles:]
        else:
            self.generated_dihedrals = self.generated
            # Angles are not learned: tile the training-set mean angles over the batch.
            self.generated_angles = tf.tile(
                np.expand_dims(np.mean(self.train_moldata.angles, axis=0), axis=0),
                [tf.shape(self.generated_dihedrals)[0], 1],
            )

        self._setup_cartesians()

    def _setup_cartesians(self):
        """Backmap generated angles/dihedrals to cartesians and build distance tensors.

        Reads ``self.generated_angles`` and ``self.generated_dihedrals``. Sets
        ``chain_in_plane``, ``cartesian``, the guessed amide H/O cartesians,
        ``cartesian_with_guessed_atoms``, the input/generated pairwise distances used
        by the cartesian cost, and the ``clashes`` count.
        """
        # Bond lengths are not learned: the training-set mean lengths are used.
        mean_lengths = np.expand_dims(
            np.mean(self.train_moldata.lengths, axis=0), axis=0
        )
        self.chain_in_plane = chain_in_plane(mean_lengths, self.generated_angles)
        # NOTE(review): the +pi offset presumably converts between the dihedral
        # convention of the data and that of dihedrals_to_cartesian_tf - confirm.
        self.cartesian = dihedrals_to_cartesian_tf(
            self.generated_dihedrals + pi, self.chain_in_plane
        )

        atom_names = self.train_moldata.central_atoms.names
        self.amide_H_cartesian = guess_amide_H(self.cartesian, atom_names)
        self.amide_O_cartesian = guess_amide_O(self.cartesian, atom_names)
        self.cartesian_with_guessed_atoms = merge_cartesians(
            self.cartesian,
            atom_names,
            self.amide_H_cartesian,
            self.amide_O_cartesian,
        )

        # Only the start:stop:step selection along axis 1 (presumably atoms)
        # enters the cartesian cost.
        pwd_selection = slice(
            self.p.cartesian_pwd_start,
            self.p.cartesian_pwd_stop,
            self.p.cartesian_pwd_step,
        )
        self.input_cartesian_pairwise_dist = pairwise_dist(
            self.inputs[2][:, pwd_selection], flat=True
        )
        self.gen_cartesian_pairwise_dist = pairwise_dist(
            self.cartesian[:, pwd_selection], flat=True
        )

        # Per-sample count of generated pairwise distances below 1
        # (in the units of the cartesian data).
        self.clashes = tf.count_nonzero(
            pairwise_dist(self.cartesian, flat=True) < 1, axis=1, dtype=tf.float32
        )
        tf.summary.scalar("clashes", tf.reduce_mean(self.clashes))

    def _setup_cost(self):
        """Add every enabled cost term to ``self.cost``.

        ``self.cost`` is presumably initialised by the base class. ``_center_cost`` and
        ``_l2_reg_cost`` are also inherited from it.
        """
        self._dihedral_cost()
        self._angle_cost()
        self._cartesian_cost()

        self._distance_cost()
        self._cartesian_distance_cost()
        self._center_cost()
        self._l2_reg_cost()

    @staticmethod
    def _reduce_by_variant(difference, variant, variant_name):
        """Reduce a difference tensor to a scalar cost.

        :param difference: tensor of input-minus-generated differences.
        :param variant: ``"mean_square"``, ``"mean_abs"`` or ``"mean_norm"``
            (for ``"mean_norm"`` the norm is taken along axis 1 before averaging).
        :param variant_name: parameter name used in the error message.
        :raises ValueError: if *variant* is not one of the supported variants.
        """
        if variant == "mean_square":
            return tf.reduce_mean(tf.square(difference))
        if variant == "mean_abs":
            return tf.reduce_mean(tf.abs(difference))
        if variant == "mean_norm":
            return tf.reduce_mean(tf.norm(difference, axis=1))
        raise ValueError("{} {} not available".format(variant_name, variant))

    def _dihedral_cost(self):
        """Add the periodic dihedral reconstruction cost to ``self.cost``.

        Skipped entirely if ``dihedral_cost_scale`` is None. The summary is still
        written when the scale is 0.
        """
        if self.p.dihedral_cost_scale is None:
            return
        difference = periodic_distance(
            self.inputs[1], self.generated_dihedrals, self.p.periodicity
        )
        dihedral_cost = self._reduce_by_variant(
            difference, self.p.dihedral_cost_variant, "dihedral_cost_variant"
        )
        dihedral_cost /= self.p.dihedral_cost_reference
        tf.summary.scalar("dihedral_cost", dihedral_cost)
        if self.p.dihedral_cost_scale != 0:
            self.cost += self.p.dihedral_cost_scale * dihedral_cost

    def _angle_cost(self):
        """Add the periodic angle reconstruction cost to ``self.cost``.

        Skipped entirely if ``angle_cost_scale`` is None. The summary is still written
        when the scale is 0.
        """
        if self.p.angle_cost_scale is None:
            return
        difference = periodic_distance(
            self.inputs[0], self.generated_angles, self.p.periodicity
        )
        angle_cost = self._reduce_by_variant(
            difference, self.p.angle_cost_variant, "angle_cost_variant"
        )
        angle_cost /= self.p.angle_cost_reference
        tf.summary.scalar("angle_cost", angle_cost)
        if self.p.angle_cost_scale != 0:
            self.cost += self.p.angle_cost_scale * angle_cost

    def _distance_cost(self):
        """Add ``distance_cost`` between network inputs and latent points to ``self.cost``.

        Uses ``dist_sig_parameters`` and the input periodicity.
        """
        if self.p.distance_cost_scale is None:
            return
        dist_cost = distance_cost(
            self.main_inputs,
            self.latent,
            *self.p.dist_sig_parameters,
            self.p.periodicity,
        )
        tf.summary.scalar("distance_cost", dist_cost)
        if self.p.distance_cost_scale != 0:
            self.cost += self.p.distance_cost_scale * dist_cost

    def _cartesian_distance_cost(self):
        """Add ``distance_cost`` between input cartesian pairwise distances and latent points.

        The pairwise distances are non-periodic, hence the infinite periodicity.
        """
        if self.p.cartesian_distance_cost_scale is None:
            return
        dist_cost = distance_cost(
            self.input_cartesian_pairwise_dist,
            self.latent,
            *self.p.cartesian_dist_sig_parameters,
            float("inf"),
        )
        tf.summary.scalar("cartesian_distance_cost", dist_cost)
        if self.p.cartesian_distance_cost_scale != 0:
            self.cost += self.p.cartesian_distance_cost_scale * dist_cost

    def _cartesian_cost(self):
        """Add the cost between input and generated cartesian pairwise distances.

        If ``cartesian_cost_scale_soft_start[0]`` is not None, the soft start is
        ``(a, b) = cartesian_cost_scale_soft_start[:2]``. The effective scale is then:

        - 0 while ``global_step < a``,
        - the full ``cartesian_cost_scale`` while ``global_step > b``,
        - linearly ramped between those values otherwise.
        """
        if self.p.cartesian_cost_scale is None:
            return
        difference = (
            self.input_cartesian_pairwise_dist - self.gen_cartesian_pairwise_dist
        )
        cartesian_cost = self._reduce_by_variant(
            difference, self.p.cartesian_cost_variant, "cartesian_cost_variant"
        )
        cartesian_cost /= self.p.cartesian_cost_reference
        tf.summary.scalar("cartesian_cost", cartesian_cost)
        if self.p.cartesian_cost_scale == 0:
            return

        soft_start = self.p.cartesian_cost_scale_soft_start
        if soft_start[0] is None:
            self.cost += self.p.cartesian_cost_scale * cartesian_cost
            return

        a = soft_start[0]
        b = soft_start[1]
        full_scale = self.p.cartesian_cost_scale
        # Guard a == b: the ramp would otherwise divide by zero while the graph is
        # built. At global_step == a == b the ramp value (step - a) * slope is 0 anyway.
        slope = full_scale / (b - a) if b != a else 0.0
        cost_scale = tf.case(
            [
                (
                    tf.less(self.global_step, a),
                    lambda: tf.constant(0, tf.float32),
                ),
                (
                    tf.greater(self.global_step, b),
                    lambda: tf.constant(full_scale, tf.float32),
                ),
            ],
            default=lambda: slope * (tf.cast(self.global_step, tf.float32) - a),
        )
        tf.summary.scalar("cartesian_cost_scale", cost_scale)
        self.cost += cost_scale * cartesian_cost

    def generate(self, latent, quantity=None, batch_size=2048):
        """Decode latent points in batches.

        :param latent: array of latent points, split along axis 0 into
            ``max(1, len(latent) // batch_size)`` nearly equal chunks.
        :param quantity: tensor to evaluate for each chunk. If None, the generated
            angles, dihedrals and cartesians (with guessed amide H/O) are evaluated.
        :param batch_size: controls the number of chunks (default 2048). Chunks
            contain between ``batch_size`` and ``2 * batch_size - 1`` points unless
            fewer points are given.
        :return: ``(angles, dihedrals, cartesians)`` if *quantity* is None. Otherwise
            the evaluated *quantity*. All results are concatenated along axis 0.
        """
        batches = np.array_split(latent, max(1, len(latent) // batch_size))

        if quantity is None:
            fetches = (
                self.generated_angles,
                self.generated_dihedrals,
                self.cartesian_with_guessed_atoms,
            )
            per_batch = [
                self.sess.run(fetches, feed_dict={self.latent: batch})
                for batch in batches
            ]
            # Regroup [(a, d, c), ...] into (all_a, all_d, all_c).
            return tuple(np.concatenate(parts, axis=0) for parts in zip(*per_batch))

        results = [
            self.sess.run(quantity, feed_dict={self.latent: batch}) for batch in batches
        ]
        return np.concatenate(results, axis=0)

338 

339 

class AngleDihedralCartesianEncoderMapDummy(AngleDihedralCartesianEncoderMap):
    """Variant whose backmapping ignores the decoder output.

    The encoder, latent placeholder and decoder are built as in the parent class.
    The generated dihedrals and angles used for backmapping are the training-set mean
    dihedrals and angles, tiled over the batch. NOTE(review): presumably meant as a
    baseline - confirm.
    """

    def _setup_network(self):
        """Build the graph with mean-valued dihedrals/angles in place of decoder output."""
        params = self.p
        moldata = self.train_moldata

        def tiled_mean(values):
            # One copy of the per-column training mean for every sample in the batch.
            mean_row = np.expand_dims(np.mean(values, axis=0), axis=0)
            return tf.tile(mean_row, [tf.shape(self.main_inputs)[0], 1])

        self.inputs = self.data_iterator.get_next()
        raw_main = (
            tf.concat([self.inputs[0], self.inputs[1]], axis=1)
            if params.use_backbone_angles
            else self.inputs[1]
        )
        self.main_inputs = tf.placeholder_with_default(raw_main, raw_main.shape)
        self.regularizer = tf.keras.regularizers.l2(params.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        # The decoder is still built and stored, but nothing below uses its output.
        self.generated = self._generate(self.latent)

        self.generated_dihedrals = tiled_mean(moldata.dihedrals)
        self.generated_angles = tiled_mean(moldata.angles)

        lengths_row = np.expand_dims(np.mean(moldata.lengths, axis=0), axis=0)
        self.chain_in_plane = chain_in_plane(lengths_row, self.generated_angles)
        shifted_dihedrals = self.generated_dihedrals + pi
        self.cartesian = dihedrals_to_cartesian_tf(shifted_dihedrals, self.chain_in_plane)

        atom_names = moldata.central_atoms.names
        self.amide_H_cartesian = guess_amide_H(self.cartesian, atom_names)
        self.amide_O_cartesian = guess_amide_O(self.cartesian, atom_names)
        self.cartesian_with_guessed_atoms = merge_cartesians(
            self.cartesian, atom_names, self.amide_H_cartesian, self.amide_O_cartesian
        )

        # Same start:stop:step selection along axis 1 for input and generated coordinates.
        pwd_window = slice(
            params.cartesian_pwd_start,
            params.cartesian_pwd_stop,
            params.cartesian_pwd_step,
        )
        self.input_cartesian_pairwise_dist = pairwise_dist(
            self.inputs[2][:, pwd_window], flat=True
        )
        self.gen_cartesian_pairwise_dist = pairwise_dist(
            self.cartesian[:, pwd_window], flat=True
        )

        all_pair_dists = pairwise_dist(self.cartesian, flat=True)
        self.clashes = tf.count_nonzero(all_pair_dists < 1, axis=1, dtype=tf.float32)
        tf.summary.scalar("clashes", tf.reduce_mean(self.clashes))