Coverage for encodermap/encodermap_tf1/backmapping.py: 74%

162 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1from math import cos, pi, sin 

2 

3import MDAnalysis as md 

4import numpy as np 

5import tensorflow.compat.v1 as tf 

6from MDAnalysis.analysis.base import AnalysisFromFunction 

7from MDAnalysis.coordinates.memory import MemoryReader 

8 

9from .misc import rotation_matrix 

10 

11 

def _expand_universe(universe, length):
    """Replace the trajectory of ``universe`` by ``length`` tiled copies of it.

    The positions of all atoms are collected for every frame with
    ``AnalysisFromFunction``, tiled ``length`` times along the frame axis and
    loaded back into ``universe`` as an in-memory trajectory
    (``MemoryReader``). Each resulting frame can then be modified on its own.

    :param universe: (MDAnalysis.Universe) universe to expand; modified in place.
    :param length: (int) number of repetitions of the existing trajectory.
    """
    analysis = AnalysisFromFunction(
        lambda ag: ag.positions.copy(), universe.atoms
    ).run()
    # MDAnalysis < 2.0 stores the stacked per-frame output directly in
    # ``results``; MDAnalysis >= 2.0 wraps it in a ``Results`` object and
    # exposes the array as ``results.timeseries``. Support both.
    results = analysis.results
    coordinates = getattr(results, "timeseries", results)
    coordinates = np.tile(coordinates, (length, 1, 1))
    universe.load_new(coordinates, format=MemoryReader)

20 

21 

def _set_dihedral(dihedral, atoms, angle):
    """Rotate part of ``atoms`` so that ``dihedral`` takes the value ``angle``.

    Every atom of ``atoms`` that comes after the dihedral's third atom (in the
    order of ``atoms``) is rotated by ``angle - current_angle``. The rotation
    axis is the vector from the dihedral's second to its third atom, through
    the third atom's position.

    :param dihedral: (MDAnalysis.AtomGroup) the four atoms defining the dihedral.
    :param atoms: (MDAnalysis.AtomGroup) group containing the dihedral; the part
        after the dihedral's third atom is rotated in place.
    :param angle: (float) target dihedral angle in degrees (same unit as
        ``dihedral.dihedral.value()``).
    :raises ValueError: if the dihedral's third atom is not part of ``atoms``.
    """
    current_angle = dihedral.dihedral.value()
    pivot = dihedral[2]
    # Locate the pivot by its universe index rather than slicing with its
    # ``id`` (the topology/PDB serial). Serials are commonly 1-based and need
    # not match positions inside ``atoms``, e.g. when non-protein atoms precede
    # the protein or the serials are not contiguous.
    position = np.flatnonzero(atoms.indices == pivot.index)
    if position.size == 0:
        raise ValueError(f"Atom {pivot} is not part of the given atom group.")
    head = atoms[position[0] + 1 :]
    axis = pivot.position - dihedral[1].position
    head.rotateby(angle - current_angle, axis, pivot.position)

27 

28 

def dihedral_backmapping(pdb_path, dihedral_trajectory, rough_n_points=-1):
    """
    Takes a pdb file with a peptide and creates a trajectory based on the dihedral angles given.
    It simply rotates around the dihedral angle axes. In the result side-chains might overlap,
    but the backbone should turn out quite well.

    The columns of ``dihedral_trajectory`` are matched, in order, to all phi
    dihedrals of the protein residues followed by all psi dihedrals. Residues
    without a phi/psi selection are skipped. Values are expected in radians.

    :param pdb_path: (str) path to the structure file read by ``MDAnalysis.Universe``.
    :param dihedral_trajectory:
        array-like of shape (traj_length, number_of_dihedrals)
    :param rough_n_points: (int) a step size to select a subset of frames from dihedral_trajectory is
        calculated as max(1, int(len(dihedral_trajectory) / rough_n_points)). With rough_n_points <= 0
        (default -1) or None, all frames are used.
    :return: (MDAnalysis.Universe) universe with one frame per selected row of dihedral_trajectory.
    """
    # Guard against division by zero; non-positive values mean "use every frame".
    if rough_n_points is None or rough_n_points <= 0:
        step_size = 1
    else:
        step_size = max(1, int(len(dihedral_trajectory) / rough_n_points))
    dihedral_trajectory = dihedral_trajectory[::step_size]

    uni = md.Universe(pdb_path)
    protein = uni.select_atoms("protein")

    # Order matters: all phi dihedrals first, then all psi dihedrals.
    phis = [residue.phi_selection() for residue in protein.residues]
    psis = [residue.psi_selection() for residue in protein.residues]
    dihedrals = [selection for selection in phis + psis if selection]

    _expand_universe(uni, len(dihedral_trajectory))

    # Iterating ``uni.trajectory`` advances the active frame, so each row of
    # dihedral values is applied to its own copy of the input structure.
    for dihedral_values, _ in zip(dihedral_trajectory, uni.trajectory):
        for dihedral, value in zip(dihedrals, dihedral_values):
            _set_dihedral(dihedral, protein, np.degrees(value))
    return uni

66 

67 

def straight_tetrahedral_chain(n_atoms=None, bond_lengths=None):
    """Build a planar reference chain with bond angles of 109.37 degrees.

    Consecutive bonds alternate between the direction ``(1, 0)`` and
    ``(cos 70.63°, sin 70.63°)``, so every bond angle is
    ``180 - 70.63 = 109.37`` degrees (close to tetrahedral). All atoms lie in
    the z = 0 plane and the first atom sits at the origin.

    :param n_atoms: (int, optional) number of atoms. If given without
        ``bond_lengths``, every bond has length 1.
    :param bond_lengths: (array-like, optional) length of every bond. If given
        without ``n_atoms``, ``n_atoms = len(bond_lengths) + 1``. If both are
        given, they must agree. Lists and numpy arrays are accepted.
    :return: (np.ndarray) float32 array of shape ``(n_atoms, 3)``.
    :raises ValueError: if neither argument is given, or if ``n_atoms`` and
        ``bond_lengths`` are inconsistent.
    """
    dx = cos(70.63 / 180 * pi)
    dy = sin(70.63 / 180 * pi)

    # Explicit None/length checks: ``not bond_lengths`` is ambiguous for numpy
    # arrays, and ``len(None)`` would raise a TypeError instead of ValueError.
    if bond_lengths is None or len(bond_lengths) == 0:
        if not n_atoms:
            raise ValueError("input not compatible")
        coordinates = np.zeros((n_atoms, 3), dtype=np.float32)
        indices = np.repeat(np.arange(int(n_atoms / 2) + 1), 2)
        coordinates[:, 0] = indices[1 : n_atoms + 1] + dx * indices[0:n_atoms]
        coordinates[:, 1] = dy * indices[0:n_atoms]
        return coordinates

    n_bonds = len(bond_lengths)
    if n_atoms and n_atoms != n_bonds + 1:
        raise ValueError("input not compatible")
    n_atoms = n_bonds + 1
    bond_lengths = np.asarray(bond_lengths)

    # Per-bond displacement: alternate between the (1, 0) and (dx, dy) directions.
    dxs = bond_lengths * np.tile([1, dx], int(n_atoms / 2))[:n_bonds]
    dys = bond_lengths * np.tile([0, dy], int(n_atoms / 2))[:n_bonds]

    coordinates = np.zeros((n_atoms, 3), dtype=np.float32)
    coordinates[1:, 0] = np.cumsum(dxs)
    coordinates[1:, 1] = np.cumsum(dys)
    return coordinates

92 

93 

def chain_in_plane(lengths, angles):
    """Reconstruct planar cartesian coordinates from bond lengths and angles.

    The chain starts at the origin and the first bond points along +x. After
    each angle, the running direction is updated to
    ``pi - angle - previous_direction`` and the sign of the y-increment flips.
    All z coordinates are zero.

    :param lengths: (tf.Tensor) bond lengths. ``lengths[:, i]`` is read for
        ``i`` in ``0 .. n_angles``, so at least ``n_angles + 1`` columns are needed.
    :param angles: (tf.Tensor) batch of angles in radians with a statically
        known second dimension ``n_angles``.
    :return: (tf.Tensor) cartesians of shape ``(batch, n_angles + 2, 3)``.
    """
    batch_size = tf.shape(angles)[0]
    n_angles = int(angles.shape[1])

    prev_angle = tf.zeros((batch_size))
    xs = [tf.zeros((batch_size))]
    ys = [tf.zeros((batch_size))]
    sign = 1

    for i in range(n_angles):
        xs.append(xs[-1] + lengths[:, i] * tf.cos(prev_angle))
        ys.append(ys[-1] + lengths[:, i] * tf.sin(prev_angle) * sign)
        prev_angle = pi - angles[:, i] - prev_angle
        sign *= -1

    # Final bond after the last angle. Index it with ``n_angles`` rather than
    # the loop variable, which is undefined when there are no angles.
    xs.append(xs[-1] + lengths[:, n_angles] * tf.cos(prev_angle))
    ys.append(ys[-1] + lengths[:, n_angles] * tf.sin(prev_angle) * sign)

    xs = tf.stack(xs, axis=1)
    ys = tf.stack(ys, axis=1)
    cartesians = tf.stack([xs, ys, tf.zeros(tf.shape(xs))], axis=2)

    return cartesians

117 

118 

def dihedrals_to_cartesian_tf_old(
    dihedrals, cartesian=None, central_atom_indices=None, no_omega=False
):
    """Legacy reconstruction: rotate a chain sequentially around each dihedral axis.

    For dihedral ``i``, the atoms after ``cai[j + 2]`` are rotated about the
    axis ``cartesian[:, cai[j + 2]] - cartesian[:, cai[j + 1]]`` by
    ``-dihedrals[:, i]``. Here ``cai`` are the central atom indices and
    ``j = i``, or, with ``no_omega``, ``j = i + (i + 1) // 2``. The latter
    skips the axes ``j = 1, 4, 7, ...`` (presumably the omega bonds — confirm).

    :param dihedrals: (array-like or tf.Tensor) dihedrals of shape
        ``(batch, n)`` or ``(n,)``. A 1-D input is treated as a batch of one;
        the output is still batched.
    :param cartesian: (tf.Tensor or array-like, optional) reference coordinates
        of shape ``(n_atoms, 3)`` or ``(batch, n_atoms, 3)``. Defaults to
        ``straight_tetrahedral_chain(n + 3)``. Non-tensor inputs are converted
        and cast to the dihedrals' dtype.
    :param central_atom_indices: (list of int, optional) atom indices forming the
        rotation axes. Defaults to all atoms.
    :param no_omega: (bool) see above.
    :return: (tf.Tensor) rotated cartesians of shape ``(batch, n_atoms, 3)``.
    """
    if not tf.is_numeric_tensor(dihedrals):
        dihedrals = tf.convert_to_tensor(dihedrals)
    if len(dihedrals.get_shape()) == 1:
        dihedrals = tf.expand_dims(dihedrals, 0)

    n = int(dihedrals.shape[-1])
    dihedrals = -dihedrals

    if cartesian is None:
        cartesian = tf.constant(straight_tetrahedral_chain(n + 3))
    elif not hasattr(cartesian, "get_shape"):
        # numpy arrays / lists have no ``get_shape``; convert them to a tensor.
        cartesian = tf.cast(cartesian, dihedrals.dtype)
    if len(cartesian.get_shape()) == 2:
        cartesian = tf.tile(
            tf.expand_dims(cartesian, axis=0), [tf.shape(dihedrals)[0], 1, 1]
        )

    if central_atom_indices is None:
        cai = list(range(cartesian.shape[1]))
    else:
        cai = central_atom_indices

    for i in range(n):
        j = i + (i + 1) // 2 if no_omega else i
        axis = cartesian[:, cai[j + 2]] - cartesian[:, cai[j + 1]]
        axis /= tf.norm(axis, axis=1, keepdims=True)
        pivot = cartesian[:, cai[j + 2] : cai[j + 2] + 1]
        rotated = pivot + tf.matmul(
            cartesian[:, cai[j + 2] + 1 :] - pivot,
            rotation_matrix(axis, dihedrals[:, i]),
        )
        cartesian = tf.concat([cartesian[:, : cai[j + 2] + 1], rotated], axis=1)

    return cartesian

160 

161 

def dihedrals_to_cartesian_tf(dihedrals, cartesian):
    """Apply dihedrals to a reference chain, working outward from its middle.

    The chain is split at ``split = n_atoms // 2``. The right part (atoms from
    ``split - 1``) is rebuilt with the dihedrals from index ``split - 1`` on.
    The left part (atoms ``split + 1`` down to ``0``) is rebuilt, in reversed
    order, with dihedrals ``split - 2`` down to ``0``. The two halves are then
    joined.

    :param dihedrals: (array-like or tf.Tensor) dihedrals of shape
        ``(batch, n_atoms - 3)``.
    :param cartesian: (tf.Tensor or array-like) reference coordinates of shape
        ``(n_atoms, 3)`` or ``(batch, n_atoms, 3)``. Non-tensor inputs are
        converted and cast to the dihedrals' dtype.
    :return: (tf.Tensor) cartesians of shape ``(batch, n_atoms, 3)``.
    """
    if not tf.is_numeric_tensor(dihedrals):
        dihedrals = tf.convert_to_tensor(dihedrals)
    if not hasattr(cartesian, "get_shape"):
        # numpy arrays / lists have no ``get_shape``; convert them to a tensor.
        cartesian = tf.cast(cartesian, dihedrals.dtype)

    if len(cartesian.get_shape()) == 2:
        cartesian = tf.tile(
            tf.expand_dims(cartesian, axis=0), [tf.shape(dihedrals)[0], 1, 1]
        )

    split = int(cartesian.shape[1]) // 2

    cartesian_right = cartesian[:, split - 1 :]
    dihedrals_right = dihedrals[:, split - 1 :]

    # Reversed views so the left half can be propagated with the same routine.
    cartesian_left = cartesian[:, split + 1 :: -1]
    dihedrals_left = dihedrals[:, split - 2 :: -1]

    new_cartesian_right = dihedral_to_cartesian_tf_one_way(
        dihedrals_right, cartesian_right
    )
    new_cartesian_left = dihedral_to_cartesian_tf_one_way(
        dihedrals_left, cartesian_left
    )

    # The first three atoms of the right half overlap the left half.
    new_cartesian = tf.concat(
        [new_cartesian_left[:, ::-1], new_cartesian_right[:, 3:]], axis=1
    )

    return new_cartesian

194 

195 

def dihedral_to_cartesian_tf_one_way(dihedrals, cartesian):
    """Apply dihedrals sequentially from the start of a batched chain.

    For the ``k``-th dihedral, the atoms after atom ``k + 2`` are rotated about
    the axis from atom ``k + 1`` to atom ``k + 2``, through atom ``k + 2``.
    The rotation matrix comes from ``rotation_matrix(axis, -dihedrals[:, k])``;
    its sign convention is defined in ``.misc``. Atoms that can no longer move
    are collected as the chain is walked.

    :param dihedrals: (tf.Tensor) dihedrals with ``n`` columns (static).
    :param cartesian: (tf.Tensor) batched coordinates indexed as
        ``cartesian[:, atom]``; presumably ``(batch, n_atoms, 3)`` — confirm.
    :return: (tf.Tensor) coordinates with the same number of atoms as ``cartesian``.
    """
    negated = -dihedrals
    n_dihedrals = int(negated.shape[-1])

    fixed_parts = [cartesian[:, 0:1]]
    tail = cartesian[:, 1:]
    for k in range(n_dihedrals):
        # The first atom of the tail is not moved by this or later rotations.
        fixed_parts.append(tail[:, 0:1])
        bond_axis = tail[:, 1] - tail[:, 0]
        bond_axis = bond_axis / tf.norm(bond_axis, axis=1, keepdims=True)
        origin = tail[:, 1:2]
        tail = origin + tf.matmul(
            tail[:, 1:] - origin, rotation_matrix(bond_axis, negated[:, k])
        )
    fixed_parts.append(tail)
    return tf.concat(fixed_parts, axis=1)

213 

214 

215# def dihedral_to_cartesian_tf_one_way2(dihedrals, cartesian): 

216# n = int(dihedrals.shape[-1]) 

217# dihedrals = -dihedrals 

218# 

219# n_batch = tf.shape(cartesian)[0] 

220# 

221# new_cartesians = tf.Variable(np.zeros((256, int(cartesian.shape[1]), 3), dtype=np.float32), trainable=False) 

222# new_cartesians = new_cartesians[:n_batch].assign(cartesian) 

223# 

224# for i in range(n): 

225# axis = new_cartesians[:n_batch, i + 2] - new_cartesians[:n_batch, i + 1] 

226# axis /= tf.norm(axis, axis=1, keepdims=True) 

227# new_cartesians[:n_batch, i + 3:].assign(new_cartesians[:n_batch, i + 2:i + 3] + 

228# tf.matmul(new_cartesians[:n_batch, i + 3:] - new_cartesians[:n_batch, i + 2:i + 3], 

229# rotation_matrix(axis, dihedrals[:, i]))) 

230# return new_cartesians[:n_batch] 

231 

232# def dihedrals_to_cartesian_tf(dihedrals, cartesian): 

233# 

234# if not tf.is_numeric_tensor(dihedrals): 

235# dihedrals = tf.convert_to_tensor(dihedrals) 

236# 

237# n = int(dihedrals.shape[-1]) 

238# dihedrals = -dihedrals 

239# 

240# if len(cartesian.get_shape()) == 2: 

241# cartesian = tf.tile(tf.expand_dims(cartesian, axis=0), [tf.shape(dihedrals)[0], 1, 1]) 

242# 

243# for i in range(n): 

244# axis = cartesian[:, i + 2] - cartesian[:, i + 1] 

245# axis /= tf.norm(axis, axis=1, keepdims=True) 

246# rotated = cartesian[:, i + 2:i + 2 + 1] + \ 

247# tf.matmul(cartesian[:, i + 3:] - cartesian[:, i + 2:i + 3], 

248# rotation_matrix(axis, dihedrals[:, i])) 

249# cartesian = tf.concat([cartesian[:, :i + 3], rotated], axis=1) 

250# 

251# return cartesian 

252 

253 

def guess_sp2_atom(
    cartesians, atom_names, bond_partner, angle_to_previous, bond_length
):
    """Estimate positions of atoms bonded to every ``bond_partner`` in a chain.

    For each atom after the first whose name equals ``bond_partner``:

    - take the vector to the previous atom;
    - rotate it by ``angle_to_previous`` about the normal of the plane spanned
      by that vector and the vector to the next atom (or to the atom two
      positions back, for the last atom);
    - scale it to ``bond_length`` and add it to the atom's position.

    :param cartesians: (tf.Tensor) batched coordinates indexed as
        ``cartesians[:, atom]``; the second dimension must equal
        ``len(atom_names)``.
    :param atom_names: (sequence of str) name of every atom in ``cartesians``.
    :param bond_partner: (str) atom name to attach the new atom to.
    :param angle_to_previous: (float) rotation angle in radians.
    :param bond_length: (float) distance of the new atom, in the units of
        ``cartesians``.
    :return: (tf.Tensor) new positions stacked along axis 1, one per matching atom.
    """
    assert cartesians.shape[1] == len(atom_names)
    n_atoms = len(atom_names)
    added_cartesians = []
    for i in range(1, n_atoms):
        if atom_names[i] != bond_partner:
            continue
        prev_vec = cartesians[:, i - 1] - cartesians[:, i]
        # Explicit bounds check instead of catching an exception. In graph
        # mode an out-of-range static slice fails at graph construction with
        # ValueError, not tf.errors.InvalidArgumentError.
        if i + 1 < n_atoms:
            next_vec = cartesians[:, i + 1] - cartesians[:, i]
        else:
            next_vec = cartesians[:, i - 2] - cartesians[:, i]

        perpendicular_axis = tf.cross(prev_vec, next_vec)
        perpendicular_axis /= tf.norm(perpendicular_axis, axis=1, keepdims=True)
        bond_vec = tf.matmul(
            tf.expand_dims(prev_vec, 1),
            rotation_matrix(perpendicular_axis, angle_to_previous),
        )
        bond_vec = bond_vec[:, 0, :]
        bond_vec *= bond_length / tf.norm(bond_vec, axis=1, keepdims=True)
        added_cartesians.append(cartesians[:, i] + bond_vec)
    added_cartesians = tf.stack(added_cartesians, axis=1)
    return added_cartesians

278 

279 

def guess_amide_H(cartesians, atom_names):
    """Estimate amide hydrogen positions for every ``"N"`` atom after the first.

    The hydrogens are placed 1.10 length units from their nitrogen, at 123
    degrees relative to the bond to the preceding atom.
    NOTE(review): a typical amide N-H bond is about 1.01 Å; 1.10 is closer to
    a C-H bond. Confirm the intended value.
    """
    return guess_sp2_atom(
        cartesians,
        atom_names,
        bond_partner="N",
        angle_to_previous=123 / 180 * pi,
        bond_length=1.10,
    )

282 

283 

def guess_amide_O(cartesians, atom_names):
    """Estimate carbonyl oxygen positions for every ``"C"`` atom after the first.

    The oxygens are placed 1.24 length units from their carbon, at 121 degrees
    relative to the bond to the preceding atom.
    """
    return guess_sp2_atom(
        cartesians,
        atom_names,
        bond_partner="C",
        angle_to_previous=121 / 180 * pi,
        bond_length=1.24,
    )

286 

287 

def merge_cartesians(
    central_cartesians, central_atom_names, H_cartesians, O_cartesians
):
    """Interleave backbone coordinates with guessed H and O positions.

    Atoms are emitted in backbone order. Directly after every central atom
    named ``"N"`` (except the first atom) the next H position is inserted;
    after every ``"C"`` the next O position is inserted.

    :param central_cartesians: (tf.Tensor) backbone coordinates indexed as
        ``[:, atom]``; presumably ``(batch, n_atoms, 3)`` — confirm.
    :param central_atom_names: (sequence of str) name of every backbone atom.
    :param H_cartesians: (tf.Tensor) H positions, one per inserted ``"N"``.
    :param O_cartesians: (tf.Tensor) O positions, one per inserted ``"C"``.
    :return: (tf.Tensor) merged coordinates stacked along axis 1.
    """
    extras = {"N": H_cartesians, "C": O_cartesians}
    used = {"N": 0, "C": 0}

    merged = [central_cartesians[:, 0]]
    for idx, name in enumerate(central_atom_names[1:], start=1):
        merged.append(central_cartesians[:, idx])
        if name in extras:
            merged.append(extras[name][:, used[name]])
            used[name] += 1

    merged = tf.stack(merged, axis=1)
    expected_atoms = (
        central_cartesians.shape[1] + H_cartesians.shape[1] + O_cartesians.shape[1]
    )
    assert merged.shape[1] == expected_atoms
    return merged