Coverage for encodermap/misc/xarray_save_wrong_hdf5.py: 36%
58 statements
coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
# -*- coding: utf-8 -*-
# encodermap/misc/xarray_save_wrong_hdf5.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
22"""Allows the combined storing of CVs and trajectories in single HDF5/NetCDF4 files.
24These files represent collated and completed trajectory ensembles, which can be
25lazy-loaded (memory efficient) and used as training input for EncoderMap's NNs.
27"""
################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
import os
import re
from io import BytesIO
from numbers import Number
from pathlib import Path

# Third Party Imports
import numpy as np
from optional_imports import _optional_import


################################################################################
# Optional Imports
################################################################################


xr = _optional_import("xarray")
AbstractDataStore = _optional_import("xarray", "backends.common.AbstractDataStore")
ArrayWriter = _optional_import("xarray", "backends.common.ArrayWriter")
Dataset = _optional_import("xarray", "core.dataset.Dataset")
backends = _optional_import("xarray", "backends")
_get_scheduler = _optional_import("xarray", "backends.locks._get_scheduler")
conventions = _optional_import("xarray", "conventions")
_get_default_engine = _optional_import("xarray", "backends.api._get_default_engine")
# `_finalize_store` is referenced in the ``compute=False`` branch of
# `_to_netcdf` below but was never imported; pull it in here as well.
_finalize_store = _optional_import("xarray", "backends.api._finalize_store")


try:
    # Third Party Imports
    from dask.delayed import Delayed
except ImportError:
    Delayed = None


################################################################################
# Typing
################################################################################


# Standard Library Imports
from collections.abc import Callable, Hashable, Iterable, Mapping
from typing import TYPE_CHECKING, Optional, Union


WritableStoresType = dict[str, Callable]

try:
    WRITEABLE_STORES: dict[str, Callable] = {
        "netcdf4": backends.NetCDF4DataStore.open,
        "scipy": backends.ScipyDataStore,
        "h5netcdf": backends.H5NetCDFStore.open,
    }
except (ImportError, ValueError, AttributeError):
    WRITEABLE_STORES: dict[str, Callable] = {}

if TYPE_CHECKING:
    # Third Party Imports
    from dask.delayed import Delayed
    from xarray import Dataset, backends, conventions
    from xarray.backends.api import _get_default_engine
    from xarray.backends.common import AbstractDataStore, ArrayWriter
    from xarray.backends.locks import _get_scheduler


################################################################################
# Globals
################################################################################


__all__: list[str] = ["save_netcdf_alongside_mdtraj"]


################################################################################
# Public functions
################################################################################


def save_netcdf_alongside_mdtraj(fname: str, dataset: Dataset) -> None:
    """Saves a netcdf dataset alongside an MDTraj trajectory.

    Args:
        fname (str): The trajectory .h5 file to append the dataset to.
        dataset (Dataset): The xarray Dataset with the CVs, written into the
            file's "CVs" group.

    """
    _to_netcdf(
        dataset,
        fname,
        mode="a",
        format="NETCDF4",
        group="CVs",
        engine="h5netcdf",
        invalid_netcdf=True,
    )
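# A minimal usage sketch (not part of the original module): appending a CV
# Dataset to a trajectory .h5 file written by MDTraj. The file name and the
# Dataset construction below are hypothetical placeholders.
#
#     import numpy as np
#     import xarray as xr
#
#     cvs = xr.Dataset(
#         {"central_dihedrals": (("frame", "dihedral"), np.random.rand(100, 10))}
#     )
#     save_netcdf_alongside_mdtraj("traj.h5", cvs)
#
# Because the dataset is written with ``mode="a"`` into the "CVs" group, the
# trajectory data that MDTraj stores in the file's root group is left untouched.

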
################################################################################
# xarray duplication to allow saving dataset alongside mdtraj
################################################################################


def dump_to_store(
    dataset: Dataset,
    store: WritableStoresType = WRITEABLE_STORES,
    writer: Optional[Callable] = None,
    encoder: Optional[Callable] = None,
    encoding: Optional[str] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
) -> None:  # pragma: no cover, no doccheck
    """Store dataset contents to a backends.*DataStore object."""
    if writer is None:
        writer = ArrayWriter()

    if encoding is None:
        encoding = {}

    variables, attrs = conventions.encode_dataset_coordinates(dataset)

    check_encoding = set()
    for k, enc in encoding.items():
        # no need to shallow copy the variable again; that already happened
        # in encode_dataset_coordinates
        variables[k].encoding = enc
        check_encoding.add(k)

    if encoder:
        variables, attrs = encoder(variables, attrs)

    store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)


def _normalize_path(path: str) -> str:
    """Normalize a path.

    See Also:
        https://docs.python.org/3.10/library/os.path.html#os.path.abspath

    Args:
        path (str): The input path.

    Returns:
        str: The absolute, user-expanded path, or the unchanged input if it
            is a remote URI.

    """
    if is_remote_uri(path):
        return path
    else:
        return os.path.abspath(os.path.expanduser(path))
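# Illustrative behaviour of `_normalize_path` (the expanded home directory and
# the absolute path naturally depend on the machine this runs on):
#
#     _normalize_path("~/project/traj.h5")            # -> "/home/<user>/project/traj.h5"
#     _normalize_path("https://example.com/traj.nc")  # remote URI, returned unchanged

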
def is_remote_uri(path: str) -> bool:
    """Checks whether a path is a remote URI.

    Args:
        path (str): The path to check.

    Returns:
        bool: Whether the path is a remote (http/https) URI.

    """
    return bool(re.search(r"^https?\://", path))
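# Illustrative behaviour of `is_remote_uri`; only http(s) URLs count as remote:
#
#     is_remote_uri("https://example.com/data.nc")  # True
#     is_remote_uri("/local/path/data.nc")          # False
#     is_remote_uri("ftp://example.com/data.nc")    # False, the regex only matches http(s)

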
def _validate_dataset_names(dataset: Dataset) -> None:
    """DataArray.name and Dataset keys must be a string or None."""

    def check_name(name):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for DataArray or Dataset key: "
                    "string must be length 1 or greater for "
                    "serialization to netCDF files"
                )
        elif name is not None:
            raise TypeError(
                "DataArray.name or Dataset key must be either a "
                "string or None for serialization to netCDF files"
            )

    for k in dataset.variables:
        check_name(k)


def _validate_attrs(dataset: Dataset) -> None:
    """`attrs` must have a string key and a value which is either: a number,
    a string, an ndarray or a list/tuple of numbers/strings.
    """

    def check_attr(name, value):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for attr: string must be "
                    "length 1 or greater for serialization to "
                    "netCDF files"
                )
        else:
            raise TypeError(
                "Invalid name for attr: {} must be a string for "
                "serialization to netCDF files".format(name)
            )

        if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)):
            raise TypeError(
                "Invalid value for attr: {} must be a number, "
                "a string, an ndarray or a list/tuple of "
                "numbers/strings for serialization to netCDF "
                "files".format(value)
            )

    # Check attrs on the dataset itself
    for k, v in dataset.attrs.items():
        check_attr(k, v)

    # Check attrs on each variable within the dataset
    for variable in dataset.variables.values():
        for k, v in variable.attrs.items():
            check_attr(k, v)
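# A short sketch of what the two validators above accept and reject; the
# Dataset is a hypothetical example, not taken from EncoderMap's data model:
#
#     import xarray as xr
#
#     ds = xr.Dataset(attrs={"note": "ok", "nested": {"not": "allowed"}})
#     _validate_dataset_names(ds)  # passes (no variables to check here)
#     _validate_attrs(ds)          # raises TypeError: dict values cannot be
#                                  # serialized to netCDF attrs

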
def _to_netcdf(
    dataset: Dataset,
    path_or_file: Optional[str] = None,
    mode: Optional[str] = "w",
    format: Optional[str] = None,
    group: Optional[str] = None,
    engine: Optional[str] = None,
    encoding: Optional[Mapping] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
    compute: bool = True,
    multifile: bool = False,
    invalid_netcdf: bool = False,
) -> Optional[Delayed]:  # pragma: no cover, no doccheck
    """This function creates an appropriate datastore for writing a dataset to
    disk as an HDF5 file.

    See `Dataset.to_netcdf` for full API docs.

    The `multifile` argument is only for the private use of `save_mfdataset`.

    """
    if isinstance(path_or_file, Path):
        path_or_file = str(path_or_file)

    if encoding is None:
        encoding = {}

    if path_or_file is None:
        if engine is None:
            engine = "scipy"
        elif engine != "scipy":
            raise ValueError(
                "invalid engine for creating bytes with "
                "to_netcdf: %r. Only the default engine "
                "or engine='scipy' is supported" % engine
            )
        if not compute:
            raise NotImplementedError(
                "to_netcdf() with compute=False is not yet implemented when "
                "returning bytes"
            )
    elif isinstance(path_or_file, str):
        if engine is None:
            engine = _get_default_engine(path_or_file)
        path_or_file = _normalize_path(path_or_file)
    else:  # file-like object
        engine = "scipy"

    # validate Dataset keys, DataArray names, and attr keys/values
    _validate_dataset_names(dataset)
    _validate_attrs(dataset)

    try:
        store_open = WRITEABLE_STORES[engine]
    except KeyError:
        raise ValueError("unrecognized engine for to_netcdf: %r" % engine)

    if format is not None:
        format = format.upper()

    # handle scheduler specific logic
    scheduler = _get_scheduler()
    have_chunks = any(v.chunks for v in dataset.variables.values())

    autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
    if autoclose and engine == "scipy":
        raise NotImplementedError(
            "Writing netCDF files with the %s backend "
            "is not currently supported with dask's %s "
            "scheduler" % (engine, scheduler)
        )

    target = path_or_file if path_or_file is not None else BytesIO()
    kwargs = dict(autoclose=True) if autoclose else {}
    # added phony dims support
    if engine == "h5netcdf":
        kwargs.update(dict(phony_dims="access"))
    if invalid_netcdf:
        if engine == "h5netcdf":
            kwargs["invalid_netcdf"] = invalid_netcdf
        else:
            raise ValueError(
                "unrecognized option 'invalid_netcdf' for engine %s" % engine
            )
    store = store_open(target, mode, format, group, **kwargs)

    if unlimited_dims is None:
        unlimited_dims = dataset.encoding.get("unlimited_dims", None)
    if unlimited_dims is not None:
        if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable):
            unlimited_dims = [unlimited_dims]
        else:
            unlimited_dims = list(unlimited_dims)

    writer = ArrayWriter()

    # TODO: figure out how to refactor this logic (here and in save_mfdataset)
    # to avoid this mess of conditionals
    try:
        # TODO: allow this work (setting up the file for writing array data)
        # to be parallelized with dask
        dump_to_store(
            dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
        )
        if autoclose:
            store.close()

        if multifile:
            return writer, store

        writes = writer.sync(compute=compute)

        if path_or_file is None:
            store.sync()
            return target.getvalue()
    finally:
        if not multifile and compute:
            store.close()

    if not compute:
        # Third Party Imports
        import dask

        return dask.delayed(_finalize_store)(writes, store)
    return None
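# Two further ways `_to_netcdf` can be driven, sketched here under the
# assumption that the respective optional backends (scipy, dask, h5netcdf)
# are installed; the file name "out.h5" is a placeholder:
#
#     # 1) No target given: serialize to bytes in memory with the scipy
#     #    backend (plain netCDF3, so no groups or invalid_netcdf here).
#     raw = _to_netcdf(dataset, None)
#
#     # 2) Delayed write of a dask-chunked dataset; nothing is written to
#     #    "out.h5" until .compute() is called on the returned object.
#     delayed = _to_netcdf(dataset, "out.h5", engine="h5netcdf", compute=False)
#     delayed.compute()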