Coverage for encodermap/trajinfo/load_traj.py: 13%
118 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/trajinfo/load_traj.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Util functions for the `TrajEnsemble` and `SingleTraj` classes.
24"""
27################################################################################
28# Imports
29################################################################################
32# Future Imports at the top
33from __future__ import annotations
35# Standard Library Imports
36import errno
37import os
38import sys
39import tempfile
40import urllib
41import warnings
42from pathlib import Path
44# Third Party Imports
45import numpy as np
46import requests
47import tables
49# Encodermap imports
50from encodermap._typing import CanBeIndex
51from encodermap.misc.misc import _validate_uri
# This warning is raised when a ``top=`` keyword reaches a file parser that
# does not accept it. The loaders below always pass ``top=`` to
# ``md.load``/``md.load_frame``, so the warning is expected noise and is
# silenced for the whole process.
warnings.filterwarnings(
    action="ignore",
    message=(
        ".*top= kwargs ignored since this file parser "
        "does not support it.*"
    ),
)
60##############################################################################
61# Optional Imports
62##############################################################################
65# Third Party Imports
66from optional_imports import _optional_import
69md = _optional_import("mdtraj")
70h5 = _optional_import("h5py")
73################################################################################
74# Typing
75################################################################################
78# Standard Library Imports
79from typing import TYPE_CHECKING, Optional, Union
82if TYPE_CHECKING:
83 # Third Party Imports
84 import h5py as h5
85 import mdtraj as md
88################################################################################
89# Globals
90################################################################################
# Nothing in this module is part of the public API, so
# ``from ... import *`` imports no names from it.
__all__: list[str] = []

# Reference to this very module object. Functions below flip flags on it via
# attribute assignment (``this.FLAG = ...``) instead of ``global`` statements.
this = sys.modules[__name__]
# Becomes True after the HDF5 thank-you note has been printed, so the note is
# shown at most once while this module stays loaded.
this.PRINTED_HDF_ANNOTATION = False
98################################################################################
99# Utils
100################################################################################
def _load_pdb_from_uri(
    uri: str,
) -> md.Topology:
    """Load the topology of a PDB file that is served at a URI.

    MDTraj is first asked to load the URI directly. If that fails with a
    ``urllib.error.URLError``, the file is downloaded with ``requests`` into a
    temporary directory and loaded from there.

    Args:
        uri (str): The URI of the PDB file, e.g.
            https://files.rcsb.org/view/1GHC.pdb.

    Returns:
        md.Topology: The topology of the PDB file.

    Raises:
        ValueError: If ``uri`` is not a valid URI.
        requests.HTTPError: If the fallback download answers with an error
            status code.

    """
    # ``import urllib`` alone does not guarantee that the ``urllib.error``
    # submodule is loaded, so import it explicitly before using it below.
    import urllib.error

    # An explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if not _validate_uri(uri):
        raise ValueError(f"Not a valid URI: {uri!r}")
    try:
        return md.load_pdb(uri).top
    except urllib.error.URLError:
        response = requests.get(uri, timeout=60)
        response.raise_for_status()
        # Use a temporary *directory* instead of a NamedTemporaryFile: the
        # latter opens in binary mode (writing ``str`` fails), would need an
        # explicit flush before MDTraj reads it, and cannot be reopened by
        # name on Windows while it is still open.
        with tempfile.TemporaryDirectory() as tmpdir:
            pdb_file = Path(tmpdir) / "structure.pdb"
            pdb_file.write_text(response.text, encoding="utf-8")
            return md.load_pdb(str(pdb_file)).top
118def _load_traj_and_top(
119 traj_file: Path,
120 top_file: Path,
121 traj_num: Union[int, None],
122 index: Optional[Union[int, list[int], np.ndarray, slice]] = None,
123 atom_index: Optional[np.ndarray] = None,
124) -> md.Trajectory:
125 """Loads a traj and top file and raises FileNotFoundError, if they do not exist.
127 Args:
128 traj_file (Path): The pathlib.Path to the traj_file.
129 top_file (Path): The pathlib.Path to the top_file.
130 index (Optional[Union[int, list[int], np.ndarray, slice]]): The index
131 to load the traj at. If ints are provided, the load_frame
132 method is used.
134 Returns:
135 md.Trajectory: The trajectory.
137 Raises:
138 FileNotFoundError: If any of the files are not real.
140 """
141 # Local Folder Imports
142 from .info_all import HDF5GroupWrite
144 if not traj_file.is_file():
145 raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), traj_file)
146 if not top_file.is_file():
147 raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), top_file)
149 if index is not None:
150 if isinstance(index, (int, np.integer)):
151 t = md.load_frame(str(traj_file), top=str(top_file), index=index)
152 elif isinstance(index, (list, np.ndarray, slice)):
153 t = md.load(str(traj_file), top=str(top_file))[index]
154 else:
155 raise TypeError(
156 f"Argument `index` must be int, list, np.ndarray or "
157 f"slice. You supplied: {index.__class__.__name__}"
158 )
159 else:
160 try:
161 with warnings.catch_warnings():
162 warnings.filterwarnings("ignore", r".*kwargs\signored.*", UserWarning)
163 t = md.load(str(traj_file), top=str(top_file))
164 except tables.NoSuchNodeError as e:
165 if traj_num is None:
166 raise e
167 with HDF5GroupWrite(traj_file) as h5file:
168 t = h5file.read_traj(traj_num)
169 except RuntimeError as e:
170 raise Exception(f"The file {traj_file} is broken.")
171 except ValueError as e:
172 if "must contain" in str(e):
173 raise Exception(
174 f"The files {str(traj_file)} and {str(top_file)} contain "
175 f"different number of atoms."
176 ) from e
177 raise e
179 if atom_index is not None:
180 t = t.atom_slice(atom_index)
181 return t
184def _load_traj(
185 *index: CanBeIndex,
186 traj_file: Union[str, Path],
187 top_file: Union[str, Path],
188 traj_num: Union[int, None],
189 atom_indices: Optional[np.ndarray] = None,
190) -> tuple[md.Trajectory, np.ndarray]:
191 """Loads a trajectory from disc and applies the indices from *index.
193 Args:
194 *index (Unpack[Ts]): Variable length indices of which all need to be
195 one of these datatypes: None, int, list[int], slice, np.ndarray.
196 These indices are applied to the traj in order. So for a traj with
197 100 frames, the indices (`slice (None, None, 5)`, [0, 2, 4, 6]) would
198 yield the frames 0, 10, 20, 30, 40. A None will not slice the traj at all.
199 traj_file (Union[str, Path]): The pathlib.Path to the traj_file. A string
200 can also be supplied. This also allows passing a URL, like e.g:
201 https://files.rcsb.org/view/1GHC.pdb.
202 top_file (Union[str, Path]): The pathlib.Path to the top_file. Can also
203 be str.
205 Returns:
206 tuple[md.Trajectory, np.ndarray]: The trajectory and a numpy array, which
207 is the result of np.arange() of the unadulterated trajectory. Can
208 be useful for continued slicing and indexing to keep track of
209 everything.
211 """
212 # check, whether traj_file is string and can be uri.
213 if isinstance(traj_file, str):
214 if _validate_uri(traj_file):
215 is_uri = True
216 else:
217 is_uri = False
218 traj_file = Path(traj_file)
219 else:
220 is_uri = False
222 top_file = Path(top_file)
224 for i, ind in enumerate(index):
225 if i == 0:
226 if ind is None:
227 if is_uri:
228 traj = md.load_pdb(str(traj_file))
229 else:
230 traj = _load_traj_and_top(traj_file, top_file, traj_num=traj_num)
231 _original_frame_indices = np.arange(traj.n_frames)
232 elif isinstance(ind, (int, np.integer)):
233 if traj_file.suffix == ".h5":
234 if not this.PRINTED_HDF_ANNOTATION:
235 print(
236 "╰(◕ᗜ◕)╯ Thank you for using the HDF5 format to "
237 "accelerate loading of single frames."
238 )
239 this.PRINTED_HDF_ANNOTATION = True
240 if is_uri:
241 traj = md.load_pdb(str(traj_file))
242 _original_frame_indices = np.arange(traj.n_frames)[ind]
243 traj = traj[ind]
244 else:
245 traj = _load_traj_and_top(
246 traj_file, top_file, index=ind, traj_num=traj_num
247 )
248 _original_frame_indices = np.array([ind])
249 elif isinstance(ind, slice):
250 if Path(traj_file).suffix == ".h5":
251 with h5.File(traj_file, "r") as file:
252 n_frames = file["coordinates"].shape[0]
253 if not this.PRINTED_HDF_ANNOTATION:
254 print(
255 "╰(◕ᗜ◕)╯ Thank you for using the HDF5 format to "
256 "accelerate loading of single frames."
257 )
258 this.PRINTED_HDF_ANNOTATION = True
259 keys = np.arange(n_frames)
260 for j, ind in enumerate(keys):
261 if j == 0:
262 traj = _load_traj_and_top(
263 traj_file, top_file, index=ind, traj_num=traj_num
264 )
265 else:
266 traj = traj.join(
267 _load_traj_and_top(
268 traj_file, top_file, index=ind, traj_num=traj_num
269 )
270 )
271 _original_frame_indices = keys
272 else:
273 if is_uri:
274 traj = md.load_pdb(str(traj_file))
275 else:
276 traj = _load_traj_and_top(
277 traj_file, top_file, traj_num=traj_num
278 )
279 _original_frame_indices = np.arange(traj.n_frames)[ind]
280 traj = traj[ind]
281 elif isinstance(ind, (list, np.ndarray)):
282 if is_uri:
283 traj = md.load_pdb(str(traj_file))
284 else:
285 traj = _load_traj_and_top(traj_file, top_file, traj_num=traj_num)
286 _original_frame_indices = np.arange(traj.n_frames)[ind]
287 traj = traj[ind]
288 else:
289 msg = (
290 f"For indexing/slicing only int, slice, list, np.ndarray "
291 f"can be used. You supplied: {ind.__class__.__name__}"
292 )
293 raise TypeError(msg)
294 else:
295 if ind is not None:
296 traj = traj[ind]
298 if atom_indices is not None:
299 traj = traj.atom_slice(atom_indices)
301 return traj, _original_frame_indices