Coverage for encodermap/encodermap_tf1/angle_dihedral_cartesian_encodermap.py: 9%
171 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1from math import pi
3import numpy as np
4import tensorflow.compat.v1 as tf
6from .autoencoder import Autoencoder
7from .backmapping import (
8 chain_in_plane,
9 dihedrals_to_cartesian_tf,
10 guess_amide_H,
11 guess_amide_O,
12 merge_cartesians,
13)
14from .misc import distance_cost, pairwise_dist, periodic_distance, variable_summaries
15from .moldata import MolData
16from .parameters import ADCParameters
class AngleDihedralCartesianEncoderMap(Autoencoder):
    """
    This EncoderMap variant is specially designed for protein conformations.
    During the training, the cartesian conformations of the backbone chain are reconstructed from backbone angles and
    dihedrals.
    This allows for a more sophisticated comparison of input conformations and generated conformations and improves
    the accuracy of generated conformations especially for large proteins.
    We achieve this with the cartesian_cost where we compare pairwise distances between atoms in cartesian coordinates
    in the input and generated conformations.
    """

    def __init__(self, *args, **kwargs):
        """
        :param parameters: ADCParameters object as defined in :class:`encodermap.encodermap_tf1.parameters.ADCParameters`
        :param train_data: the training data as a :class:`.MolData` object
        :param validation_data: not yet supported
        :param checkpoint_path: If a checkpoint path is given, values like neural network weights stored in this
            checkpoint will be restored.
        :param read_only: if True, no output is written
        """
        super(AngleDihedralCartesianEncoderMap, self).__init__(*args, **kwargs)
        # Sanity checks after the base class has run; presumably the base __init__ calls
        # _prepare_data (which sets self.train_moldata) — TODO confirm in Autoencoder.
        assert isinstance(self.p, ADCParameters)
        assert isinstance(self.train_moldata, MolData)

    def _prepare_data(self):
        """Validate the MolData training set and turn it into the (angles, dihedrals, cartesians) tuple.

        The original MolData object is kept as ``self.train_moldata``; ``self.train_data`` is replaced by
        the tuple ``(angles, dihedrals, central_cartesians)`` that the network reads as ``self.inputs``.

        :raises AssertionError: if the numbers of bond lengths, angles and dihedrals are not consistent
            with the number of central (backbone) atoms (n - 1, n - 2 and n - 3 respectively).
        :raises ValueError: if validation data was given (not supported yet).
        """
        self.train_moldata = self.train_data
        n_atoms = self.train_moldata.central_cartesians.shape[1]
        assert self.train_moldata.lengths.shape[1] == n_atoms - 1
        assert self.train_moldata.angles.shape[1] == n_atoms - 2
        assert self.train_moldata.dihedrals.shape[1] == n_atoms - 3
        self.train_data = (
            self.train_moldata.angles,
            self.train_moldata.dihedrals,
            self.train_moldata.central_cartesians,
        )

        if self.validation_data is not None:
            raise ValueError("validation data not supported yet")  # Todo: add support

    def _strided_pairwise_dist(self, cartesians):
        """Return flat pairwise distances of the atoms selected by the ``cartesian_pwd_*`` slice parameters.

        :param cartesians: cartesian coordinates indexed as ``[batch, atom, ...]`` (atom axis is axis 1).
        """
        return pairwise_dist(
            cartesians[
                :,
                self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
            ],
            flat=True,
        )

    def _setup_network(self):
        """Build the graph: encode angles/dihedrals, decode them and rebuild backbone cartesians.

        Sets, among others, ``main_inputs``, ``latent``, ``generated``, ``generated_angles``,
        ``generated_dihedrals``, ``cartesian``, ``cartesian_with_guessed_atoms`` (read by
        :meth:`generate`), the input/generated pairwise distances used by the cartesian costs, and
        ``clashes``.
        """
        # self.inputs follows the tuple built in _prepare_data: (angles, dihedrals, central_cartesians).
        self.inputs = self.data_iterator.get_next()
        if self.p.use_backbone_angles:
            self.main_inputs = tf.concat([self.inputs[0], self.inputs[1]], axis=1)
        else:
            self.main_inputs = self.inputs[1]
        # Placeholders with defaults allow feeding arbitrary inputs / latent points at run time.
        self.main_inputs = tf.placeholder_with_default(
            self.main_inputs, self.main_inputs.shape
        )
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        self.generated = self._generate(self.latent)

        if self.p.use_backbone_angles:
            # The decoder output is [angles | dihedrals], mirroring the concatenated main_inputs.
            n_angles = self.inputs[0].shape[1]
            self.generated_angles = self.generated[:, :n_angles]
            self.generated_dihedrals = self.generated[:, n_angles:]
        else:
            # Angles are not learned: use the training-set mean angles for every sample.
            self.generated_dihedrals = self.generated
            self.generated_angles = tf.tile(
                np.expand_dims(np.mean(self.train_moldata.angles, axis=0), axis=0),
                [tf.shape(self.generated_dihedrals)[0], 1],
            )

        # Bond lengths are not learned either; the mean training-set lengths are used.
        mean_lengths = np.expand_dims(
            np.mean(self.train_moldata.lengths, axis=0), axis=0
        )
        self.chain_in_plane = chain_in_plane(mean_lengths, self.generated_angles)
        # NOTE(review): the +pi offset presumably converts to the dihedral convention expected by
        # dihedrals_to_cartesian_tf — TODO confirm in backmapping.
        self.cartesian = dihedrals_to_cartesian_tf(
            self.generated_dihedrals + pi, self.chain_in_plane
        )

        # Add guessed amide H and O atoms; generate() returns these merged coordinates.
        self.amide_H_cartesian = guess_amide_H(
            self.cartesian, self.train_moldata.central_atoms.names
        )
        self.amide_O_cartesian = guess_amide_O(
            self.cartesian, self.train_moldata.central_atoms.names
        )
        self.cartesian_with_guessed_atoms = merge_cartesians(
            self.cartesian,
            self.train_moldata.central_atoms.names,
            self.amide_H_cartesian,
            self.amide_O_cartesian,
        )

        self.input_cartesian_pairwise_dist = self._strided_pairwise_dist(self.inputs[2])
        self.gen_cartesian_pairwise_dist = self._strided_pairwise_dist(self.cartesian)

        # Per sample: number of generated atom pairs closer than 1 (in the data's coordinate units).
        self.clashes = tf.count_nonzero(
            pairwise_dist(self.cartesian, flat=True) < 1, axis=1, dtype=tf.float32
        )
        tf.summary.scalar("clashes", tf.reduce_mean(self.clashes))

    def _setup_cost(self):
        """Add all cost terms to ``self.cost``.

        ``_center_cost`` and ``_l2_reg_cost`` are not defined here; presumably inherited from
        Autoencoder, which is also assumed to initialise ``self.cost`` — TODO confirm.
        """
        self._dihedral_cost()
        self._angle_cost()
        self._cartesian_cost()

        self._distance_cost()
        self._cartesian_distance_cost()
        self._center_cost()
        self._l2_reg_cost()

    @staticmethod
    def _reduce_difference(difference, variant, variant_name):
        """Reduce a difference tensor to a scalar cost according to ``variant``.

        :param difference: element-wise difference between input and generated values; assumed 2-D
            ``(batch, features)`` since "mean_norm" takes the norm over axis 1.
        :param variant: one of ``"mean_square"``, ``"mean_abs"`` or ``"mean_norm"``.
        :param variant_name: name of the parameter holding ``variant``; only used in the error message.
        :return: scalar cost tensor.
        :raises ValueError: if ``variant`` is not supported.
        """
        if variant == "mean_square":
            return tf.reduce_mean(tf.square(difference))
        if variant == "mean_abs":
            return tf.reduce_mean(tf.abs(difference))
        if variant == "mean_norm":
            return tf.reduce_mean(tf.norm(difference, axis=1))
        raise ValueError("{} {} not available".format(variant_name, variant))

    def _dihedral_cost(self):
        """Add the periodic dihedral reconstruction cost to ``self.cost``.

        Does nothing if ``dihedral_cost_scale`` is None. The cost is divided by
        ``dihedral_cost_reference`` and logged as a summary; it is added to ``self.cost`` only when
        the scale is non-zero.

        :raises ValueError: if ``dihedral_cost_variant`` is not supported.
        """
        if self.p.dihedral_cost_scale is None:
            return
        dihedral_cost = self._reduce_difference(
            periodic_distance(
                self.inputs[1], self.generated_dihedrals, self.p.periodicity
            ),
            self.p.dihedral_cost_variant,
            "dihedral_cost_variant",
        )
        dihedral_cost /= self.p.dihedral_cost_reference
        tf.summary.scalar("dihedral_cost", dihedral_cost)
        if self.p.dihedral_cost_scale != 0:
            self.cost += self.p.dihedral_cost_scale * dihedral_cost

    def _angle_cost(self):
        """Add the periodic backbone-angle reconstruction cost to ``self.cost``.

        Does nothing if ``angle_cost_scale`` is None. The cost is divided by ``angle_cost_reference``
        and logged as a summary; it is added to ``self.cost`` only when the scale is non-zero.

        :raises ValueError: if ``angle_cost_variant`` is not supported.
        """
        if self.p.angle_cost_scale is None:
            return
        angle_cost = self._reduce_difference(
            periodic_distance(self.inputs[0], self.generated_angles, self.p.periodicity),
            self.p.angle_cost_variant,
            "angle_cost_variant",
        )
        angle_cost /= self.p.angle_cost_reference
        tf.summary.scalar("angle_cost", angle_cost)
        if self.p.angle_cost_scale != 0:
            self.cost += self.p.angle_cost_scale * angle_cost

    def _distance_cost(self):
        """Add the sigmoid distance cost between ``main_inputs`` and ``latent`` to ``self.cost``.

        Does nothing if ``distance_cost_scale`` is None; logged as ``sigmoid_loss`` and added only
        when the scale is non-zero.
        """
        if self.p.distance_cost_scale is not None:
            dist_cost = distance_cost(
                self.main_inputs,
                self.latent,
                *self.p.dist_sig_parameters,
                self.p.periodicity,
            )
            tf.summary.scalar("sigmoid_loss", dist_cost)
            if self.p.distance_cost_scale != 0:
                self.cost += self.p.distance_cost_scale * dist_cost

    def _cartesian_distance_cost(self):
        """Add the sigmoid distance cost between input pairwise cartesian distances and ``latent``.

        Pairwise distances are not periodic, hence a periodicity of ``inf``. Does nothing if
        ``cartesian_distance_cost_scale`` is None; added only when the scale is non-zero.
        """
        if self.p.cartesian_distance_cost_scale is not None:
            dist_cost = distance_cost(
                self.input_cartesian_pairwise_dist,
                self.latent,
                *self.p.cartesian_dist_sig_parameters,
                float("inf"),
            )
            tf.summary.scalar("cartesian_distance_cost", dist_cost)
            if self.p.cartesian_distance_cost_scale != 0:
                self.cost += self.p.cartesian_distance_cost_scale * dist_cost

    def _cartesian_cost(self):
        """Add the cost comparing input and generated pairwise cartesian distances to ``self.cost``.

        Does nothing if ``cartesian_cost_scale`` is None. The cost is divided by
        ``cartesian_cost_reference`` and logged. If ``cartesian_cost_scale_soft_start[0]`` is None the
        scale is applied directly. Otherwise, with ``a, b = cartesian_cost_scale_soft_start``, the
        effective scale is 0 before global step ``a``, ``cartesian_cost_scale`` after step ``b``, and
        ramps linearly between them.

        :raises ValueError: if ``cartesian_cost_variant`` is not supported.
        """
        if self.p.cartesian_cost_scale is None:
            return
        cartesian_cost = self._reduce_difference(
            self.input_cartesian_pairwise_dist - self.gen_cartesian_pairwise_dist,
            self.p.cartesian_cost_variant,
            "cartesian_cost_variant",
        )
        cartesian_cost /= self.p.cartesian_cost_reference
        tf.summary.scalar("cartesian_cost", cartesian_cost)
        if self.p.cartesian_cost_scale == 0:
            return

        if self.p.cartesian_cost_scale_soft_start[0] is None:
            self.cost += self.p.cartesian_cost_scale * cartesian_cost
            return

        a = self.p.cartesian_cost_scale_soft_start[0]
        b = self.p.cartesian_cost_scale_soft_start[1]
        cost_scale = tf.case(
            [
                (
                    tf.less(self.global_step, a),
                    lambda: tf.constant(0, tf.float32),
                ),  # if global step is less than a
                (
                    tf.greater(self.global_step, b),
                    lambda: tf.constant(self.p.cartesian_cost_scale, tf.float32),
                ),  # if global step is greater than b
            ],
            # a <= global step <= b: linear ramp from 0 to cartesian_cost_scale.
            default=lambda: self.p.cartesian_cost_scale
            / (b - a)
            * (tf.cast(self.global_step, tf.float32) - a),
        )
        tf.summary.scalar("cartesian_cost_scale", cost_scale)
        self.cost += cost_scale * cartesian_cost

    def generate(self, latent, quantity=None):
        """Decode latent-space points, feeding them to ``self.latent`` in batches.

        :param latent: array of latent points; split along axis 0 into ``max(1, int(len(latent) / 2048))``
            batches.
        :param quantity: optional tensor to evaluate instead of the default outputs.
        :return: if ``quantity`` is None, a tuple ``(angles, dihedrals, cartesians)`` where ``cartesians``
            includes the guessed amide H and O atoms; otherwise the values of ``quantity``. Per-batch
            results are concatenated along axis 0.
        """
        batches = np.array_split(latent, max(1, int(len(latent) / 2048)))

        if quantity is not None:
            results = [
                self.sess.run(quantity, feed_dict={self.latent: batch})
                for batch in batches
            ]
            return np.concatenate(results, axis=0)

        all_angles = []
        all_dihedrals = []
        all_cartesians = []
        for batch in batches:
            angles, dihedrals, cartesians = self.sess.run(
                (
                    self.generated_angles,
                    self.generated_dihedrals,
                    self.cartesian_with_guessed_atoms,
                ),
                feed_dict={self.latent: batch},
            )
            all_angles.append(angles)
            all_dihedrals.append(dihedrals)
            all_cartesians.append(cartesians)
        return (
            np.concatenate(all_angles, axis=0),
            np.concatenate(all_dihedrals, axis=0),
            np.concatenate(all_cartesians, axis=0),
        )
class AngleDihedralCartesianEncoderMapDummy(AngleDihedralCartesianEncoderMap):
    """Baseline variant whose generated conformations do not depend on the latent space.

    The encoder and decoder are still built (``self.latent`` and ``self.generated`` exist), but the
    generated angles and dihedrals are the training-set means tiled over the batch. Every generated
    conformation is therefore the same mean-geometry structure.
    """

    def _setup_network(self):
        """Build the graph with mean angles/dihedrals instead of decoded ones.

        Sets the same attributes as the parent implementation (``main_inputs``, ``latent``,
        ``generated``, ``generated_angles``, ``generated_dihedrals``, ``cartesian``,
        ``cartesian_with_guessed_atoms``, pairwise distances and ``clashes``).
        """
        # self.inputs is the tuple (angles, dihedrals, central_cartesians) built in _prepare_data.
        self.inputs = self.data_iterator.get_next()
        if self.p.use_backbone_angles:
            self.main_inputs = tf.concat([self.inputs[0], self.inputs[1]], axis=1)
        else:
            self.main_inputs = self.inputs[1]
        self.main_inputs = tf.placeholder_with_default(
            self.main_inputs, self.main_inputs.shape
        )
        # tf.contrib is not available through tensorflow.compat.v1 on TF 2.x; use the same
        # Keras L2 regularizer as the parent class.
        self.regularizer = tf.keras.regularizers.l2(self.p.l2_reg_constant)
        encoded = self._encode(self.main_inputs)
        self.latent = tf.placeholder_with_default(encoded, encoded.shape)
        variable_summaries("latent", self.latent)
        # Built so the decoder exists, but not used for the generated structures below.
        self.generated = self._generate(self.latent)

        batch_size = tf.shape(self.main_inputs)[0]
        self.generated_dihedrals = tf.tile(
            np.expand_dims(np.mean(self.train_moldata.dihedrals, axis=0), axis=0),
            [batch_size, 1],
        )
        self.generated_angles = tf.tile(
            np.expand_dims(np.mean(self.train_moldata.angles, axis=0), axis=0),
            [batch_size, 1],
        )

        mean_lengths = np.expand_dims(
            np.mean(self.train_moldata.lengths, axis=0), axis=0
        )
        self.chain_in_plane = chain_in_plane(mean_lengths, self.generated_angles)
        # NOTE(review): +pi offset mirrors the parent class; presumably the dihedral convention
        # expected by dihedrals_to_cartesian_tf — TODO confirm.
        self.cartesian = dihedrals_to_cartesian_tf(
            self.generated_dihedrals + pi, self.chain_in_plane
        )

        self.amide_H_cartesian = guess_amide_H(
            self.cartesian, self.train_moldata.central_atoms.names
        )
        self.amide_O_cartesian = guess_amide_O(
            self.cartesian, self.train_moldata.central_atoms.names
        )
        self.cartesian_with_guessed_atoms = merge_cartesians(
            self.cartesian,
            self.train_moldata.central_atoms.names,
            self.amide_H_cartesian,
            self.amide_O_cartesian,
        )

        self.input_cartesian_pairwise_dist = pairwise_dist(
            self.inputs[2][
                :,
                self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
            ],
            flat=True,
        )
        self.gen_cartesian_pairwise_dist = pairwise_dist(
            self.cartesian[
                :,
                self.p.cartesian_pwd_start : self.p.cartesian_pwd_stop : self.p.cartesian_pwd_step,
            ],
            flat=True,
        )

        # Per sample: number of generated atom pairs closer than 1 (in the data's coordinate units).
        self.clashes = tf.count_nonzero(
            pairwise_dist(self.cartesian, flat=True) < 1, axis=1, dtype=tf.float32
        )
        tf.summary.scalar("clashes", tf.reduce_mean(self.clashes))