Coverage for encodermap/encodermap_tf1/backmapping.py: 17%

162 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# Standard Library Imports 

2from math import cos, pi, sin 

3 

4# Third Party Imports 

5import MDAnalysis as md 

6import numpy as np 

7import tensorflow.compat.v1 as tf 

8from MDAnalysis.analysis.base import AnalysisFromFunction 

9from MDAnalysis.coordinates.memory import MemoryReader 

10 

11# Local Folder Imports 

12from .misc import rotation_matrix 

13 

14 

def _expand_universe(universe, length):
    """Replace ``universe``'s trajectory with an in-memory copy repeated ``length`` times.

    The positions of all atoms are collected for every frame of the current
    trajectory, tiled ``length`` times along the frame axis and loaded back
    into ``universe`` through a ``MemoryReader``. For a single-frame PDB this
    yields ``length`` identical frames. The universe is modified in place;
    nothing is returned.
    """

    def _snapshot(atom_group):
        # Copy so later in-place edits of positions do not alias the stored frames.
        return atom_group.positions.copy()

    analysis = AnalysisFromFunction(_snapshot, universe.atoms).run()
    frames = analysis.results["timeseries"]
    # presumably frames is (n_frames, n_atoms, 3); only axis 0 is repeated.
    repeated = np.tile(frames, (length, 1, 1))
    universe.load_new(repeated, format=MemoryReader)

23 

24 

def _set_dihedral(dihedral, atoms, angle):
    """Rotate the tail of ``atoms`` so that ``dihedral`` measures ``angle``.

    The rotation axis runs from ``dihedral[1]`` to ``dihedral[2]`` and passes
    through ``dihedral[2]``. The rotation amount is the difference between the
    requested ``angle`` and the current ``dihedral.dihedral.value()``, so both
    must use the same unit (degrees for MDAnalysis).

    NOTE(review): the moving part is ``atoms[dihedral[2].id:]``, i.e. it is
    sliced by the atom's ``id`` (topology serial), not its 0-based position in
    ``atoms``. This is only correct if the ids line up with positions; confirm.
    """
    pivot = dihedral[2].position
    bond_axis = pivot - dihedral[1].position
    delta = angle - dihedral.dihedral.value()
    moving = atoms[dihedral[2].id :]
    moving.rotateby(delta, bond_axis, pivot)

30 

31 

32def dihedral_backmapping(pdb_path, dihedral_trajectory, rough_n_points=-1): 

33 """ 

34 Takes a pdb file with a peptide and creates a trajectory based on the dihedral angles given. 

35 It simply rotates around the dihedral angle axis. In the result side-chains might overlap but the backbone should 

36 turn out quite well. 

37 

38 :param pdb_path: (str) 

39 :param dihedral_trajectory: 

40 array-like of shape (traj_length, number_of_dihedrals) 

41 :param rough_n_points: (int) a step_size to select a subset of values from dihedral_trajectory is calculated by 

42 max(1, int(len(dihedral_trajectory) / rough_n_points)) with rough_n_points = -1 all values are used. 

43 :return: (MDAnalysis.Universe) 

44 """ 

45 step_size = max(1, int(len(dihedral_trajectory) / rough_n_points)) 

46 dihedral_trajectory = dihedral_trajectory[::step_size] 

47 

48 uni = md.Universe(pdb_path) 

49 protein = uni.select_atoms("protein") 

50 

51 dihedrals = [] 

52 

53 for residue in protein.residues: 

54 phi = residue.phi_selection() 

55 if phi: 

56 dihedrals.append(phi) 

57 

58 for residue in protein.residues: 

59 psi = residue.psi_selection() 

60 if psi: 

61 dihedrals.append(psi) 

62 

63 _expand_universe(uni, len(dihedral_trajectory)) 

64 

65 for dihedral_values, step in zip(dihedral_trajectory, uni.trajectory): 

66 for dihedral, value in zip(dihedrals, dihedral_values): 

67 _set_dihedral(dihedral, protein, value / (2 * pi) * 360) 

68 return uni 

69 

70 

def straight_tetrahedral_chain(n_atoms=None, bond_lengths=None):
    """Build planar reference coordinates for a chain with ~109.37 degree bond angles.

    Consecutive bonds alternate between the +x direction and the direction
    (cos 70.63 deg, sin 70.63 deg), all in the z = 0 plane, starting at the origin.

    :param n_atoms: number of atoms. If given without ``bond_lengths``, all bonds
        have unit length.
    :param bond_lengths: sequence/array of bond lengths. The chain then has
        ``len(bond_lengths) + 1`` atoms. If ``n_atoms`` is also given, it must
        equal that number.
    :return: float32 numpy array of shape (n_atoms, 3).
    :raises ValueError: if the arguments are missing or inconsistent.
    """
    dx = cos(70.63 / 180 * pi)
    dy = sin(70.63 / 180 * pi)

    # Test emptiness explicitly: ``not bond_lengths`` is ambiguous (and raises)
    # for numpy arrays, and ``len(None)`` used to raise TypeError instead of
    # the documented ValueError.
    has_bonds = bond_lengths is not None and len(bond_lengths) > 0

    if n_atoms and not has_bonds:
        coordinates = np.zeros((n_atoms, 3), dtype=np.float32)
        # 0,0,1,1,2,2,...: the x offset grows every atom, the diagonal step every other atom.
        indices = np.repeat(np.arange(int(n_atoms / 2) + 1), 2)
        coordinates[:, 0] = indices[1 : n_atoms + 1] + dx * indices[0:n_atoms]
        coordinates[:, 1] = dy * indices[0:n_atoms]

    elif has_bonds and (not n_atoms or n_atoms == len(bond_lengths) + 1):
        n_bonds = len(bond_lengths)
        n_atoms = n_atoms or n_bonds + 1

        # Per-bond displacement: alternate (1, 0) and (dx, dy), scaled by length.
        dxs = bond_lengths * np.tile([1, dx], int(n_atoms / 2))[:n_bonds]
        dys = bond_lengths * np.tile([0, dy], int(n_atoms / 2))[:n_bonds]

        coordinates = np.zeros((n_atoms, 3), dtype=np.float32)
        coordinates[1:, 0] = np.cumsum(dxs)
        coordinates[1:, 1] = np.cumsum(dys)

    else:
        raise ValueError("input not compatible")
    return coordinates

95 

96 

def chain_in_plane(lengths, angles):
    """Reconstruct planar cartesian coordinates from bond lengths and bond angles.

    The chain starts at the origin along +x. After each bond, the heading is
    updated to ``pi - angle - previous_heading`` and the y direction flips
    sign, keeping every atom in the z = 0 plane.

    :param lengths: tensor indexed as ``lengths[:, k]``. It needs
        ``angles.shape[1] + 1`` columns (one more bond than angles).
        Presumably (batch, n_bonds); confirm.
    :param angles: tensor (batch, n_angles) of bond angles in radians; the
        second dimension must be statically known.
    :return: tensor (batch, n_angles + 2, 3) of coordinates with z = 0.
    """
    batch_size = tf.shape(angles)[0]
    n_angles = int(angles.shape[1])

    prev_angle = tf.zeros((batch_size))
    xs = [tf.zeros((batch_size))]
    ys = [tf.zeros((batch_size))]
    sign = 1

    for i in range(n_angles):
        xs.append(xs[-1] + lengths[:, i] * tf.cos(prev_angle))
        ys.append(ys[-1] + lengths[:, i] * tf.sin(prev_angle) * sign)
        prev_angle = pi - angles[:, i] - prev_angle
        sign *= -1

    # Final bond after the last angle. Index it with n_angles rather than the
    # loop variable (``i + 1``), which is unbound when there are no angles.
    xs.append(xs[-1] + lengths[:, n_angles] * tf.cos(prev_angle))
    ys.append(ys[-1] + lengths[:, n_angles] * tf.sin(prev_angle) * sign)

    xs = tf.stack(xs, axis=1)
    ys = tf.stack(ys, axis=1)
    cartesians = tf.stack([xs, ys, tf.zeros(tf.shape(xs))], axis=2)

    return cartesians

120 

121 

def dihedrals_to_cartesian_tf_old(
    dihedrals, cartesian=None, central_atom_indices=None, no_omega=False
):
    """Older variant: rotate a reference chain sequentially around each dihedral bond.

    For every dihedral ``i``, the atoms after ``cai[j + 2]`` are rotated about the
    normalized axis ``cai[j + 1] -> cai[j + 2]`` (through ``cai[j + 2]``) using
    ``rotation_matrix(axis, -dihedrals[:, i])``. The sign of the input is flipped
    once at the start.

    :param dihedrals: array-like or tensor, 1-D (single conformation) or
        2-D (batch, n_dihedrals). Presumably radians; TODO confirm against
        ``rotation_matrix``.
    :param cartesian: reference coordinates as a tf.Tensor (``get_shape`` is
        called on it). A 2-D tensor is tiled over the batch; a 3-D one is used
        as is. Defaults to ``tf.constant(straight_tetrahedral_chain(n + 3))``.
    :param central_atom_indices: indices along the atom axis of ``cartesian``
        of the atoms that define the rotation axes. Defaults to all atoms.
    :param no_omega: if True, bond index ``j`` skips 1, 4, 7, ... (every third
        bond). Presumably these are the omega bonds that have no entry in
        ``dihedrals``; confirm.
    :return: batched tensor of rotated coordinates.
        NOTE(review): a 1-D input is expanded to a batch of one, and ``one_d``
        is never used afterwards, so the result keeps that batch axis.
    """
    if not tf.is_numeric_tensor(dihedrals):
        dihedrals = tf.convert_to_tensor(dihedrals)
    if len(dihedrals.get_shape()) == 1:
        # Promote a single conformation to a batch of one.
        one_d = True
        dihedrals = tf.expand_dims(dihedrals, 0)
    else:
        one_d = False

    n = int(dihedrals.shape[-1])
    # Rotations below use the opposite sign of the input angles.
    dihedrals = -dihedrals

    if cartesian is None:
        cartesian = tf.constant(straight_tetrahedral_chain(n + 3))
    if len(cartesian.get_shape()) == 2:
        # Same reference geometry for every batch member.
        cartesian = tf.tile(
            tf.expand_dims(cartesian, axis=0), [tf.shape(dihedrals)[0], 1, 1]
        )

    if central_atom_indices is None:
        cai = list(range(cartesian.shape[1]))
    else:
        cai = central_atom_indices

    for i in range(n):
        if not no_omega:
            j = i
        else:
            # i = 0, 1, 2, 3, 4, ... -> j = 0, 2, 3, 5, 6, ... (skips 1, 4, 7, ...)
            j = i + int((i + 1) / 2)
        axis = cartesian[:, cai[j + 2]] - cartesian[:, cai[j + 1]]
        axis /= tf.norm(axis, axis=1, keepdims=True)
        # Rotate every atom after cai[j + 2] about the bond axis through
        # cai[j + 2]; atoms up to and including cai[j + 2] stay fixed.
        rotated = cartesian[:, cai[j + 2] : cai[j + 2] + 1] + tf.matmul(
            cartesian[:, cai[j + 2] + 1 :] - cartesian[:, cai[j + 2] : cai[j + 2] + 1],
            rotation_matrix(axis, dihedrals[:, i]),
        )
        cartesian = tf.concat([cartesian[:, : cai[j + 2] + 1], rotated], axis=1)

    return cartesian

162 

163 

def dihedrals_to_cartesian_tf(dihedrals, cartesian):
    """Apply dihedrals to a reference chain, working outward from its middle.

    The atom axis is split at ``split = n_atoms // 2``. Each half is rotated by
    ``dihedral_to_cartesian_tf_one_way``, which keeps the first three atoms it
    receives fixed:

    * right part: atoms ``split-1 ..`` end, with dihedrals ``split-1 ..`` end;
    * left part, reversed: atoms ``split+1 .. 0``, with dihedrals ``split-2 .. 0``.

    The three central atoms (``split-1``, ``split``, ``split+1``) are therefore
    not moved by either pass. They are taken once from the left part.
    Presumably ``n_dihedrals == n_atoms - 3``, with dihedral ``k`` defined by
    atoms ``k .. k+3``; confirm.

    :param dihedrals: array-like or tensor, 2-D (batch, n_dihedrals).
    :param cartesian: tf.Tensor of reference coordinates (``get_shape`` is
        called on it). A 2-D tensor is tiled over the batch.
    :return: tensor of the same atom count as ``cartesian``, batched.
    """
    if not tf.is_numeric_tensor(dihedrals):
        dihedrals = tf.convert_to_tensor(dihedrals)

    # NOTE(review): ``n`` is never used below.
    n = int(dihedrals.shape[-1])

    if len(cartesian.get_shape()) == 2:
        cartesian = tf.tile(
            tf.expand_dims(cartesian, axis=0), [tf.shape(dihedrals)[0], 1, 1]
        )

    split = int(int(cartesian.shape[1]) / 2)

    # Right part: atoms split-1 .. end, driven forward.
    cartesian_right = cartesian[:, split - 1 :]
    dihedrals_right = dihedrals[:, split - 1 :]

    # Left part walked backwards: atoms split+1, split, ..., 0.
    cartesian_left = cartesian[:, split + 1 :: -1]
    dihedrals_left = dihedrals[:, split - 2 :: -1]

    new_cartesian_right = dihedral_to_cartesian_tf_one_way(
        dihedrals_right, cartesian_right
    )
    new_cartesian_left = dihedral_to_cartesian_tf_one_way(
        dihedrals_left, cartesian_left
    )

    # Un-reverse the left part (atoms 0 .. split+1). Drop the first three atoms
    # of the right part (split-1 .. split+1), which the left part already holds.
    new_cartesian = tf.concat(
        [new_cartesian_left[:, ::-1], new_cartesian_right[:, 3:]], axis=1
    )

    return new_cartesian

196 

197 

def dihedral_to_cartesian_tf_one_way(dihedrals, cartesian):
    """Sequentially rotate a batched chain so it adopts the given dihedrals.

    Step ``k`` rotates every atom beyond atom ``k + 2`` about the normalized
    bond axis ``atom k+1 -> atom k+2`` (through atom ``k + 2``) using
    ``rotation_matrix(axis, -dihedrals[:, k])``. Atoms 0, 1 and 2 are never moved.
    Each atom is frozen into the output once the step that passes over it is done.

    :param dihedrals: tensor (batch, n_dihedrals) with a static last dimension.
    :param cartesian: tensor indexed as ``cartesian[:, atom]``; presumably
        (batch, n_atoms, 3) with n_atoms >= n_dihedrals + 3; confirm.
    :return: tensor with the same atom count as ``cartesian``.
    """
    n_dihedrals = int(dihedrals.shape[-1])
    negated = -dihedrals

    pieces = [cartesian[:, 0:1]]
    remaining = cartesian[:, 1:]
    for k in range(n_dihedrals):
        # The head of ``remaining`` is final from here on.
        pieces.append(remaining[:, 0:1])
        bond_axis = remaining[:, 1] - remaining[:, 0]
        bond_axis = bond_axis / tf.norm(bond_axis, axis=1, keepdims=True)
        pivot = remaining[:, 1:2]
        remaining = pivot + tf.matmul(
            remaining[:, 1:] - pivot, rotation_matrix(bond_axis, negated[:, k])
        )
    pieces.append(remaining)
    return tf.concat(pieces, axis=1)

215 

216 

217# def dihedral_to_cartesian_tf_one_way2(dihedrals, cartesian): 

218# n = int(dihedrals.shape[-1]) 

219# dihedrals = -dihedrals 

220# 

221# n_batch = tf.shape(cartesian)[0] 

222# 

223# new_cartesians = tf.Variable(np.zeros((256, int(cartesian.shape[1]), 3), dtype=np.float32), trainable=False) 

224# new_cartesians = new_cartesians[:n_batch].assign(cartesian) 

225# 

226# for i in range(n): 

227# axis = new_cartesians[:n_batch, i + 2] - new_cartesians[:n_batch, i + 1] 

228# axis /= tf.norm(axis, axis=1, keepdims=True) 

229# new_cartesians[:n_batch, i + 3:].assign(new_cartesians[:n_batch, i + 2:i + 3] + 

230# tf.matmul(new_cartesians[:n_batch, i + 3:] - new_cartesians[:n_batch, i + 2:i + 3], 

231# rotation_matrix(axis, dihedrals[:, i]))) 

232# return new_cartesians[:n_batch] 

233 

234# def dihedrals_to_cartesian_tf(dihedrals, cartesian): 

235# 

236# if not tf.is_numeric_tensor(dihedrals): 

237# dihedrals = tf.convert_to_tensor(dihedrals) 

238# 

239# n = int(dihedrals.shape[-1]) 

240# dihedrals = -dihedrals 

241# 

242# if len(cartesian.get_shape()) == 2: 

243# cartesian = tf.tile(tf.expand_dims(cartesian, axis=0), [tf.shape(dihedrals)[0], 1, 1]) 

244# 

245# for i in range(n): 

246# axis = cartesian[:, i + 2] - cartesian[:, i + 1] 

247# axis /= tf.norm(axis, axis=1, keepdims=True) 

248# rotated = cartesian[:, i + 2:i + 2 + 1] + \ 

249# tf.matmul(cartesian[:, i + 3:] - cartesian[:, i + 2:i + 3], 

250# rotation_matrix(axis, dihedrals[:, i])) 

251# cartesian = tf.concat([cartesian[:, :i + 3], rotated], axis=1) 

252# 

253# return cartesian 

254 

255 

def guess_sp2_atom(
    cartesians, atom_names, bond_partner, angle_to_previous, bond_length
):
    """Guess the position of an atom bonded to every ``bond_partner`` atom of a chain.

    For each atom ``i >= 1`` named ``bond_partner``, a new atom is placed
    ``bond_length`` away from atom ``i``. Its direction is the vector from atom
    ``i`` to atom ``i - 1``, rotated by ``angle_to_previous`` (via
    ``rotation_matrix``) about the normal of the plane spanned by the previous
    and next neighbours. For the last atom, which has no successor, the atom
    two positions back is used instead of the next one.

    :param cartesians: tensor indexed as ``cartesians[:, atom]`` whose axis 1
        must have a static size equal to ``len(atom_names)``. Presumably
        (batch, n_atoms, 3); confirm.
    :param atom_names: sequence of atom names, one per entry of axis 1.
    :param bond_partner: name of the atoms that receive a new bonded atom.
    :param angle_to_previous: angle handed to ``rotation_matrix``; the callers
        in this module pass radians.
    :param bond_length: distance of the new atom from its partner, in the
        units of ``cartesians``.
    :return: tensor of new positions stacked along axis 1, one per matching
        atom in chain order. Atom 0 is never considered.
    :raises AssertionError: if axis 1 does not match ``len(atom_names)``.
    """
    assert cartesians.shape[1] == len(
        atom_names
    ), f"{cartesians.shape=} {len(atom_names)=}"
    n_atoms = len(atom_names)
    added_cartesians = []
    for i in range(1, n_atoms):
        if atom_names[i] != bond_partner:
            continue
        prev_vec = cartesians[:, i - 1] - cartesians[:, i]
        # The last atom has no successor; span the plane with the atom two
        # positions back instead. This is checked explicitly (the assert
        # guarantees axis 1 has n_atoms entries) rather than by catching the
        # out-of-range slice, whose exception type depends on the backend
        # (TF: ValueError/InvalidArgumentError, NumPy: uncaught IndexError).
        # With only two atoms this degenerates (i - 2 == -1 is atom i itself).
        neighbour = i + 1 if i + 1 < n_atoms else i - 2
        next_vec = cartesians[:, neighbour] - cartesians[:, i]

        perpendicular_axis = tf.cross(prev_vec, next_vec)
        perpendicular_axis /= tf.norm(perpendicular_axis, axis=1, keepdims=True)
        bond_vec = tf.matmul(
            tf.expand_dims(prev_vec, 1),
            rotation_matrix(perpendicular_axis, angle_to_previous),
        )
        bond_vec = bond_vec[:, 0, :]
        # Rescale the rotated direction to the requested bond length.
        bond_vec *= bond_length / tf.norm(bond_vec, axis=1, keepdims=True)
        added_cartesians.append(cartesians[:, i] + bond_vec)
    added_cartesians = tf.stack(added_cartesians, axis=1)
    return added_cartesians

282 

283 

def guess_amide_H(cartesians, atom_names):
    """Guess amide hydrogen positions: one per ``"N"`` atom after the first.

    Delegates to ``guess_sp2_atom`` with a 123 degree angle (converted to
    radians) and a bond length of 1.10 in the units of ``cartesians``.
    """
    return guess_sp2_atom(
        cartesians,
        atom_names,
        bond_partner="N",
        angle_to_previous=123 / 180 * pi,
        bond_length=1.10,
    )

286 

287 

def guess_amide_O(cartesians, atom_names):
    """Guess carbonyl oxygen positions: one per ``"C"`` atom after the first.

    Delegates to ``guess_sp2_atom`` with a 121 degree angle (converted to
    radians) and a bond length of 1.24 in the units of ``cartesians``.
    """
    return guess_sp2_atom(
        cartesians,
        atom_names,
        bond_partner="C",
        angle_to_previous=121 / 180 * pi,
        bond_length=1.24,
    )

290 

291 

def merge_cartesians(
    central_cartesians, central_atom_names, H_cartesians, O_cartesians
):
    """Interleave backbone coordinates with the guessed H and O atoms.

    Walks the backbone atoms in order and inserts extra atoms right after
    their partner: the next unused H entry after every ``"N"`` and the next
    unused O entry after every ``"C"``. The first backbone atom never receives
    a partner, matching ``guess_sp2_atom``, which starts at index 1.

    :return: tensor stacked along axis 1 containing all backbone, H and O atoms.
    :raises AssertionError: if the merged atom count does not equal the sum of
        the three inputs' axis-1 sizes.
    """
    merged = [central_cartesians[:, 0]]
    next_h = 0
    next_o = 0
    for idx, name in enumerate(central_atom_names):
        if idx == 0:
            continue
        merged.append(central_cartesians[:, idx])
        if name == "N":
            merged.append(H_cartesians[:, next_h])
            next_h += 1
        elif name == "C":
            merged.append(O_cartesians[:, next_o])
            next_o += 1
    merged = tf.stack(merged, axis=1)
    # Every H and O entry must have been consumed exactly once.
    assert (
        merged.shape[1]
        == central_cartesians.shape[1] + H_cartesians.shape[1] + O_cartesians.shape[1]
    )
    return merged