Coverage for encodermap/trajinfo/load_traj.py: 13%

118 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/trajinfo/load_traj.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Util functions for the `TrajEnsemble` and `SingleTraj` classes. 

23 

24""" 

25 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32# Future Imports at the top 

33from __future__ import annotations 

34 

35# Standard Library Imports 

36import errno 

37import os 

38import sys 

39import tempfile 

40import urllib 

41import warnings 

42from pathlib import Path 

43 

44# Third Party Imports 

45import numpy as np 

46import requests 

47import tables 

48 

49# Encodermap imports 

50from encodermap._typing import CanBeIndex 

51from encodermap.misc.misc import _validate_uri 

52 

53 

# Silence the warning about `top=` kwargs being ignored by file parsers that
# do not support it; `md.load` is called with `top=` throughout this module.
# This mutates the global warnings filter list at import time.
warnings.filterwarnings(
    action="ignore",
    message=".*top= kwargs ignored since this file parser does not support it.*",
)

58 

59 

60############################################################################## 

61# Optional Imports 

62############################################################################## 

63 

64 

65# Third Party Imports 

66from optional_imports import _optional_import 

67 

68 

69md = _optional_import("mdtraj") 

70h5 = _optional_import("h5py") 

71 

72 

73################################################################################ 

74# Typing 

75################################################################################ 

76 

77 

78# Standard Library Imports 

79from typing import TYPE_CHECKING, Optional, Union 

80 

81 

82if TYPE_CHECKING: 

83 # Third Party Imports 

84 import h5py as h5 

85 import mdtraj as md 

86 

87 

88################################################################################ 

89# Globals 

90################################################################################ 

91 

92 

__all__: list[str] = []

# Module-level flag: once True, the HDF5 thank-you message is not printed again.
PRINTED_HDF_ANNOTATION = False

# Handle to this very module, so functions can flip the flag above through
# attribute access (``this.PRINTED_HDF_ANNOTATION = True``) instead of a
# ``global`` statement.
this = sys.modules[__name__]

96 

97 

98################################################################################ 

99# Utils 

100################################################################################ 

101 

102 

def _load_pdb_from_uri(
    uri: str,
) -> md.Topology:
    """Loads the topology of a PDB file given by a URL.

    MDTraj is first asked to load ``uri`` directly. If that raises a
    ``urllib.error.URLError``, the file is downloaded with ``requests``,
    written to a temporary directory and loaded from there.

    Args:
        uri (str): The URL of the PDB file, e.g.
            https://files.rcsb.org/view/1GHC.pdb.

    Returns:
        md.Topology: The topology of the PDB file.

    Raises:
        ValueError: If ``uri`` is rejected by ``_validate_uri``.
        requests.HTTPError: If the fallback download answers with an HTTP
            error status.

    """
    # Standard Library Imports
    # `import urllib` alone does not guarantee that the `urllib.error`
    # submodule is loaded, so import it explicitly here.
    import urllib.error

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not _validate_uri(uri):
        raise ValueError(f"The string {uri} is not a valid URI.")
    try:
        return md.load_pdb(uri).top
    except urllib.error.URLError:
        response = requests.get(uri)
        # Don't hand an HTML error page to the PDB parser.
        response.raise_for_status()
        # A plain file inside a temporary directory is used instead of an
        # open NamedTemporaryFile: the latter defaults to binary mode (writing
        # `str` raises TypeError), must be flushed before its name is re-read,
        # and cannot be reopened by name on Windows while it is open.
        with tempfile.TemporaryDirectory() as tmpdir:
            pdb_file = Path(tmpdir) / "downloaded.pdb"
            pdb_file.write_text(response.text, encoding="utf-8")
            return md.load_pdb(str(pdb_file)).top

116 

117 

def _load_traj_and_top(
    traj_file: Path,
    top_file: Path,
    traj_num: Union[int, None],
    index: Optional[Union[int, list[int], np.ndarray, slice]] = None,
    atom_index: Optional[np.ndarray] = None,
) -> md.Trajectory:
    """Loads a traj and top file and raises FileNotFoundError, if they do not exist.

    Args:
        traj_file (Path): The pathlib.Path to the traj_file.
        top_file (Path): The pathlib.Path to the top_file.
        traj_num (Union[int, None]): Only used when ``index`` is None and
            MDTraj raises ``tables.NoSuchNodeError`` for ``traj_file``. In
            that case the trajectory with this number is read through
            ``HDF5GroupWrite.read_traj``. If None, the error is re-raised.
        index (Optional[Union[int, list[int], np.ndarray, slice]]): The index
            to load the traj at. If an int is provided, the ``md.load_frame``
            method is used. Lists, arrays and slices load the whole file
            with ``md.load`` and index the result. Defaults to None, which
            loads all frames.
        atom_index (Optional[np.ndarray]): If not None, the loaded trajectory
            is reduced to these atoms with ``atom_slice``. Defaults to None.

    Returns:
        md.Trajectory: The trajectory.

    Raises:
        FileNotFoundError: If any of the files are not real.
        TypeError: If ``index`` is not None, int, list, np.ndarray or slice.
        Exception: If MDTraj raises a RuntimeError while loading the full
            file ("broken" file), or a ValueError whose message contains
            "must contain" (traj and top have different numbers of atoms).
            The original MDTraj error is chained as ``__cause__``.

    """
    # Local Folder Imports
    from .info_all import HDF5GroupWrite

    if not traj_file.is_file():
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), traj_file)
    if not top_file.is_file():
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), top_file)

    if index is not None:
        if isinstance(index, (int, np.integer)):
            # Reads only the requested frame.
            t = md.load_frame(str(traj_file), top=str(top_file), index=index)
        elif isinstance(index, (list, np.ndarray, slice)):
            t = md.load(str(traj_file), top=str(top_file))[index]
        else:
            raise TypeError(
                "Argument `index` must be int, list, np.ndarray or "
                f"slice. You supplied: {index.__class__.__name__}"
            )
    else:
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", r".*kwargs\signored.*", UserWarning)
                t = md.load(str(traj_file), top=str(top_file))
        except tables.NoSuchNodeError:
            # MDTraj could not find its expected node in the HDF5 file; fall
            # back to reading trajectory `traj_num` via HDF5GroupWrite.
            if traj_num is None:
                raise
            with HDF5GroupWrite(traj_file) as h5file:
                t = h5file.read_traj(traj_num)
        except RuntimeError as e:
            # Chain the MDTraj error so the actual reason stays visible.
            raise Exception(f"The file {traj_file} is broken.") from e
        except ValueError as e:
            if "must contain" in str(e):
                raise Exception(
                    f"The files {str(traj_file)} and {str(top_file)} contain "
                    "different number of atoms."
                ) from e
            raise

    if atom_index is not None:
        t = t.atom_slice(atom_index)
    return t

182 

183 

184def _load_traj( 

185 *index: CanBeIndex, 

186 traj_file: Union[str, Path], 

187 top_file: Union[str, Path], 

188 traj_num: Union[int, None], 

189 atom_indices: Optional[np.ndarray] = None, 

190) -> tuple[md.Trajectory, np.ndarray]: 

191 """Loads a trajectory from disc and applies the indices from *index. 

192 

193 Args: 

194 *index (Unpack[Ts]): Variable length indices of which all need to be 

195 one of these datatypes: None, int, list[int], slice, np.ndarray. 

196 These indices are applied to the traj in order. So for a traj with 

197 100 frames, the indices (`slice (None, None, 5)`, [0, 2, 4, 6]) would 

198 yield the frames 0, 10, 20, 30, 40. A None will not slice the traj at all. 

199 traj_file (Union[str, Path]): The pathlib.Path to the traj_file. A string 

200 can also be supplied. This also allows passing a URL, like e.g: 

201 https://files.rcsb.org/view/1GHC.pdb. 

202 top_file (Union[str, Path]): The pathlib.Path to the top_file. Can also 

203 be str. 

204 

205 Returns: 

206 tuple[md.Trajectory, np.ndarray]: The trajectory and a numpy array, which 

207 is the result of np.arange() of the unadulterated trajectory. Can 

208 be useful for continued slicing and indexing to keep track of 

209 everything. 

210 

211 """ 

212 # check, whether traj_file is string and can be uri. 

213 if isinstance(traj_file, str): 

214 if _validate_uri(traj_file): 

215 is_uri = True 

216 else: 

217 is_uri = False 

218 traj_file = Path(traj_file) 

219 else: 

220 is_uri = False 

221 

222 top_file = Path(top_file) 

223 

224 for i, ind in enumerate(index): 

225 if i == 0: 

226 if ind is None: 

227 if is_uri: 

228 traj = md.load_pdb(str(traj_file)) 

229 else: 

230 traj = _load_traj_and_top(traj_file, top_file, traj_num=traj_num) 

231 _original_frame_indices = np.arange(traj.n_frames) 

232 elif isinstance(ind, (int, np.integer)): 

233 if traj_file.suffix == ".h5": 

234 if not this.PRINTED_HDF_ANNOTATION: 

235 print( 

236 "╰(◕ᗜ◕)╯ Thank you for using the HDF5 format to " 

237 "accelerate loading of single frames." 

238 ) 

239 this.PRINTED_HDF_ANNOTATION = True 

240 if is_uri: 

241 traj = md.load_pdb(str(traj_file)) 

242 _original_frame_indices = np.arange(traj.n_frames)[ind] 

243 traj = traj[ind] 

244 else: 

245 traj = _load_traj_and_top( 

246 traj_file, top_file, index=ind, traj_num=traj_num 

247 ) 

248 _original_frame_indices = np.array([ind]) 

249 elif isinstance(ind, slice): 

250 if Path(traj_file).suffix == ".h5": 

251 with h5.File(traj_file, "r") as file: 

252 n_frames = file["coordinates"].shape[0] 

253 if not this.PRINTED_HDF_ANNOTATION: 

254 print( 

255 "╰(◕ᗜ◕)╯ Thank you for using the HDF5 format to " 

256 "accelerate loading of single frames." 

257 ) 

258 this.PRINTED_HDF_ANNOTATION = True 

259 keys = np.arange(n_frames) 

260 for j, ind in enumerate(keys): 

261 if j == 0: 

262 traj = _load_traj_and_top( 

263 traj_file, top_file, index=ind, traj_num=traj_num 

264 ) 

265 else: 

266 traj = traj.join( 

267 _load_traj_and_top( 

268 traj_file, top_file, index=ind, traj_num=traj_num 

269 ) 

270 ) 

271 _original_frame_indices = keys 

272 else: 

273 if is_uri: 

274 traj = md.load_pdb(str(traj_file)) 

275 else: 

276 traj = _load_traj_and_top( 

277 traj_file, top_file, traj_num=traj_num 

278 ) 

279 _original_frame_indices = np.arange(traj.n_frames)[ind] 

280 traj = traj[ind] 

281 elif isinstance(ind, (list, np.ndarray)): 

282 if is_uri: 

283 traj = md.load_pdb(str(traj_file)) 

284 else: 

285 traj = _load_traj_and_top(traj_file, top_file, traj_num=traj_num) 

286 _original_frame_indices = np.arange(traj.n_frames)[ind] 

287 traj = traj[ind] 

288 else: 

289 msg = ( 

290 f"For indexing/slicing only int, slice, list, np.ndarray " 

291 f"can be used. You supplied: {ind.__class__.__name__}" 

292 ) 

293 raise TypeError(msg) 

294 else: 

295 if ind is not None: 

296 traj = traj[ind] 

297 

298 if atom_indices is not None: 

299 traj = traj.atom_slice(atom_indices) 

300 

301 return traj, _original_frame_indices