Coverage for encodermap/trajinfo/load_traj.py: 55%

91 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/trajinfo/load_traj.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

"""Util functions for the `TrajEnsemble` and `SingleTraj` classes.

Loading helpers: `_load_traj_and_top` reads a trajectory/topology file pair
with MDTraj (optionally selecting frames), and `_load_traj` additionally
accepts PDB URIs and applies a sequence of frame indices to the result.

"""

25 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32from __future__ import annotations 

33 

34import errno 

35import os 

36import sys 

37import warnings 

38from pathlib import Path 

39 

40import numpy as np 

41 

42from ..misc.misc import _validate_uri 

43 

# Suppress the "top= kwargs ignored ..." warning (presumably emitted by MDTraj
# loaders when `top=` is passed for a format that ignores it). Note that this
# installs a process-wide filter as a side effect of importing this module.
warnings.filterwarnings(
    action="ignore",
    message=(
        ".*top= kwargs ignored since this file parser "
        "does not support it.*"
    ),
)

48 

49 

50############################################################################## 

51# Optional Imports 

52############################################################################## 

53 

54 

55from .._optional_imports import _optional_import 

56 

# MDTraj and h5py are resolved through the project's `_optional_import`
# helper (its fallback behavior for missing packages lives in
# `.._optional_imports` — not visible here). Evaluated left to right.
md, h5 = _optional_import("mdtraj"), _optional_import("h5py")

59 

60 

61################################################################################ 

62# Typing 

63################################################################################ 

64 

65 

66from typing import TYPE_CHECKING, Optional, Union 

67 

if TYPE_CHECKING:
    # Only evaluated by static type checkers. At runtime `md` and `h5` are the
    # objects returned by `_optional_import`, and the names defined here
    # (Ts, Unpack, Index) do not exist, so they must only appear inside
    # annotations, which `from __future__ import annotations` keeps
    # unevaluated.
    import h5py as h5
    import mdtraj as md
    from typing_extensions import TypeVarTuple, Unpack

    Ts = TypeVarTuple("Ts")

    # `Optional[...]` accepts exactly one type argument, so the `Union` has to
    # enclose every alternative.
    Index = Optional[
        Union[tuple[int, list, np.ndarray, slice], int, list, np.ndarray, slice]
    ]

78 

79 

80################################################################################ 

81# Globals 

82################################################################################ 

83 

84 

# Nothing in this module is exported by `from ... import *`.
__all__ = []

# Handle on this very module object, so functions can update module-level
# flags by attribute assignment instead of a `global` statement.
this = sys.modules[__name__]
# Whether the HDF5 thank-you message has already been printed.
setattr(this, "PRINTED_HDF_ANNOTATION", False)

88 

89 

90################################################################################ 

91# Utils 

92################################################################################ 

93 

94 

95def _load_traj_and_top( 

96 traj_file: Path, 

97 top_file: Path, 

98 index: Optional[Union[int, list[int], np.ndarray, slice]] = None, 

99) -> md.Trajectory: 

100 """Loads a traj and top file and raises FileNotFoundError, if they do not exist. 

101 

102 Args: 

103 traj_file (Path): The pathlib.Path to the traj_file. 

104 top_file (Path): The pathlib.Path to the top_file. 

105 index (Optional[Union[int, list[int], np.ndarray, slice]]): The index 

106 to load the traj at. If ints are provided, the load_frame 

107 method is used. 

108 

109 Returns: 

110 md.Trajectory: The trajectory. 

111 

112 Raises: 

113 FileNotFoundError: If any of the files are not real. 

114 

115 """ 

116 if not traj_file.is_file(): 

117 raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), traj_file) 

118 if not top_file.is_file(): 

119 raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), top_file) 

120 

121 if index is not None: 121 ↛ 122line 121 didn't jump to line 122, because the condition on line 121 was never true

122 if isinstance(index, (int, np.integer)): 

123 return md.load_frame(str(traj_file), top=str(top_file), index=index) 

124 elif isinstance(index, (list, np.ndarray, slice)): 

125 return md.load(str(traj_file), top=str(top_file))[index] 

126 else: 

127 raise TypeError( 

128 f"Argument `index` must be int, list, np.ndarray or " 

129 f"slice. You supplied: {index.__class__.__name__}" 

130 ) 

131 else: 

132 return md.load(str(traj_file), top=str(top_file)) 

133 

134 

135def _load_traj( 

136 *index: Unpack(Ts), 

137 traj_file: Union[str, Path], 

138 top_file: Union[str, Path], 

139) -> tuple[md.Trajectory, np.ndarray]: 

140 """Loads a trajectory from disc and applies the indices from *index. 

141 

142 Args: 

143 *index (Unpack[Ts]): Variable length indices of which all need to be 

144 one of these datatypes: None, int, np.int, list[int], slice, np.ndarray. 

145 These indices are applied to the traj in order. So for a traj with 

146 100 frames, the indices (slice(None, None, 5), [0, 2, 4, 6]) would 

147 yield the frames 0, 10, 20, 30, 40. A None will not slice the traj at all. 

148 traj_file (Union[str, Path]): The pathlib.Path to the traj_file. A string 

149 can also be supplied. This also allows to pass a URL, like e.g: 

150 https://files.rcsb.org/view/1GHC.pdb. 

151 top_file (Union[str, Path]): The pathlib.Path to the top_file. Can also 

152 be str. 

153 

154 Returns: 

155 tuple[md.Trajectory, np.ndarray]: The trajectory and a numpy array, which 

156 is the result of np.arange() of the unadulterated trajectory. Can 

157 be useful for continued slicing and indexing to keep track of 

158 everyhting. 

159 

160 """ 

161 # check, whether traj_file is string and can be uri. 

162 if isinstance(traj_file, str): 162 ↛ 169line 162 didn't jump to line 169, because the condition on line 162 was never false

163 if _validate_uri(traj_file): 

164 is_uri = True 

165 else: 

166 is_uri = False 

167 traj_file = Path(traj_file) 

168 else: 

169 is_uri = False 

170 

171 top_file = Path(top_file) 

172 

173 for i, ind in enumerate(index): 

174 if i == 0: 

175 if ind is None: 

176 if is_uri: 

177 traj = md.load_pdb(str(traj_file)) 

178 else: 

179 traj = _load_traj_and_top(traj_file, top_file) 

180 _original_frame_indices = np.arange(traj.n_frames) 

181 elif isinstance(ind, (int, np.integer)): 181 ↛ 182line 181 didn't jump to line 182, because the condition on line 181 was never true

182 print("here") 

183 raise Exception 

184 if traj_file.suffix == ".h5": 

185 if not this.PRINTED_HDF_ANNOTATION: 

186 print( 

187 "╰(◕ᗜ◕)╯ Thank you for using the HDF5 format to " 

188 "accelerate loading of single frames." 

189 ) 

190 this.PRINTED_HDF_ANNOTATION = True 

191 if is_uri: 

192 traj = md.load_pdb(str(traj_file)) 

193 _original_frame_indices = np.arange(traj.n_frames)[ind] 

194 traj = traj[ind] 

195 else: 

196 traj = _load_traj_and_top(traj_file, top_file, index=ind) 

197 _original_frame_indices = np.array([ind]) 

198 elif isinstance(ind, slice): 

199 if Path(traj_file).suffix == ".h5": 199 ↛ 200line 199 didn't jump to line 200, because the condition on line 199 was never true

200 with h5.File(traj_file, "r") as file: 

201 n_frames = file["coordinates"].shape[0] 

202 if not this.PRINTED_HDF_ANNOTATION: 

203 print( 

204 "╰(◕ᗜ◕)╯ Thank you for using the HDF5 format to " 

205 "accelerate loading of single frames." 

206 ) 

207 this.PRINTED_HDF_ANNOTATION = True 

208 keys = np.arange(n_frames) 

209 for j, ind in enumerate(keys): 

210 if j == 0: 

211 traj = _load_traj_and_top(traj_file, top_file, index=ind) 

212 else: 

213 traj = traj.join( 

214 _load_traj_and_top(traj_file, top_file, index=ind) 

215 ) 

216 _original_frame_indices = keys 

217 else: 

218 if is_uri: 218 ↛ 221line 218 didn't jump to line 221, because the condition on line 218 was never false

219 traj = md.load_pdb(str(traj_file)) 

220 else: 

221 traj = _load_traj_and_top(traj_file, top_file) 

222 _original_frame_indices = np.arange(traj.n_frames)[ind] 

223 traj = traj[ind] 

224 elif isinstance(ind, (list, np.ndarray)): 224 ↛ 232line 224 didn't jump to line 232

225 if is_uri: 225 ↛ 226line 225 didn't jump to line 226, because the condition on line 225 was never true

226 traj = md.load_pdb(str(traj_file)) 

227 else: 

228 traj = _load_traj_and_top(traj_file, top_file) 

229 _original_frame_indices = np.arange(traj.n_frames)[ind] 

230 traj = traj[ind] 

231 else: 

232 msg = ( 

233 f"For indexing/slicing only int, slice, list, np.ndarray " 

234 f"can be used. You supplied: {ind.__class__.__name__}" 

235 ) 

236 raise TypeError(msg) 

237 else: 

238 if ind is not None: 238 ↛ 173line 238 didn't jump to line 173, because the condition on line 238 was never false

239 traj = traj[ind] 

240 return traj, _original_frame_indices