Coverage for encodermap/misc/xarray_save_wrong_hdf5.py: 14%
136 statements
coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
# -*- coding: utf-8 -*-
# encodermap/misc/xarray_save_wrong_hdf5.py
################################################################################
# Encodermap: A python library for dimensionality reduction.
#
# Copyright 2019-2022 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
22"""Allows the combined storing of CVs and trajectories in single HDF5/NetCDF4 files.
24These files represent collated and completed trajectory ensembles, which can be
25lazy-loaded (memory efficient) and used as training input for encodermap's NNs.
27"""
################################################################################
# Imports
################################################################################


from __future__ import annotations

import os
import re
from io import BytesIO
from math import pi
from numbers import Number
from pathlib import Path

import numpy as np

from .._optional_imports import _optional_import
from .errors import BadError
################################################################################
# Optional Imports
################################################################################


xr = _optional_import("xarray")
AbstractDataStore = _optional_import("xarray", "backends.common.AbstractDataStore")
ArrayWriter = _optional_import("xarray", "backends.common.ArrayWriter")
Dataset = _optional_import("xarray", "core.dataset.Dataset")
backends = _optional_import("xarray", "backends")
_get_scheduler = _optional_import("xarray", "backends.locks._get_scheduler")
conventions = _optional_import("xarray", "conventions")
_get_default_engine = _optional_import("xarray", "backends.api._get_default_engine")
# needed by the delayed-write (compute=False) branch of _to_netcdf below
_finalize_store = _optional_import("xarray", "backends.api._finalize_store")

try:
    from dask.delayed import Delayed
except ImportError:
    Delayed = None
################################################################################
# Typing
################################################################################


from typing import (
    TYPE_CHECKING,
    Callable,
    Dict,
    Hashable,
    Iterable,
    Mapping,
    Optional,
    Tuple,
    Union,
)

try:
    WRITEABLE_STORES: Dict[str, Callable] = {
        "netcdf4": backends.NetCDF4DataStore.open,
        "scipy": backends.ScipyDataStore,
        "h5netcdf": backends.H5NetCDFStore.open,
    }
except (ImportError, ValueError, AttributeError):
    WRITEABLE_STORES = {}

if TYPE_CHECKING:
    from dask.delayed import Delayed
    from xarray import Dataset, backends, conventions
    from xarray.backends.api import _get_default_engine
    from xarray.backends.common import AbstractDataStore, ArrayWriter
    from xarray.backends.locks import _get_scheduler
################################################################################
# Globals
################################################################################


__all__ = ["save_netcdf_alongside_mdtraj"]


################################################################################
# Public functions
################################################################################
def save_netcdf_alongside_mdtraj(fname: str, dataset: Dataset) -> None:
    """Appends a CV Dataset to an existing HDF5 trajectory file.

    The Dataset is written into the 'CVs' group of `fname` with the h5netcdf
    engine and ``invalid_netcdf=True``, so it can sit alongside the trajectory
    data that MDTraj has already written to the same file.
    """
    _to_netcdf(
        dataset,
        fname,
        mode="a",
        format="NETCDF4",
        group="CVs",
        engine="h5netcdf",
        invalid_netcdf=True,
    )
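

# A minimal usage sketch, kept as a comment so that importing this module has
# no side effects. It assumes xarray and h5netcdf are installed and that an
# MDTraj HDF5 trajectory already exists at the hypothetical path "traj.h5":
#
#     import numpy as np
#     import xarray as xr
#
#     cvs = xr.Dataset({"my_cv": (("frame", "feature"), np.random.rand(100, 2))})
#     save_netcdf_alongside_mdtraj("traj.h5", cvs)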
##############################################################################
# xarray duplication to allow saving dataset alongside mdtraj
##############################################################################
def dump_to_store(
    dataset: Dataset,
    store: AbstractDataStore,
    writer: Optional[ArrayWriter] = None,
    encoder: Optional[Callable] = None,
    encoding: Optional[Mapping] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
):
    """Store dataset contents to a backends.*DataStore object."""
    if writer is None:
        writer = ArrayWriter()

    if encoding is None:
        encoding = {}

    variables, attrs = conventions.encode_dataset_coordinates(dataset)

    check_encoding = set()
    for k, enc in encoding.items():
        # no need to shallow copy the variable again; that already happened
        # in encode_dataset_coordinates
        variables[k].encoding = enc
        check_encoding.add(k)

    if encoder:
        variables, attrs = encoder(variables, attrs)

    store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
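

# Sketch (commented out) of how dump_to_store could be driven by hand, assuming
# h5netcdf is installed; "example.h5" is a hypothetical output path. Normally
# _to_netcdf below opens the store, calls dump_to_store, and syncs the writer:
#
#     ds = xr.Dataset({"x": ("frame", np.arange(3.0))})
#     store = backends.H5NetCDFStore.open("example.h5", mode="w")
#     try:
#         dump_to_store(ds, store)
#     finally:
#         store.close()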
def _normalize_path(path: str) -> str:
    """Expands and absolutizes local paths; remote URIs are returned unchanged."""
    if is_remote_uri(path):
        return path
    else:
        return os.path.abspath(os.path.expanduser(path))


def is_remote_uri(path: str) -> bool:
    """Returns True if `path` starts with 'http://' or 'https://'."""
    return bool(re.search(r"^https?\://", path))
def _validate_dataset_names(dataset: Dataset) -> None:
    """DataArray.name and Dataset keys must be a string or None."""

    def check_name(name):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for DataArray or Dataset key: "
                    "string must be length 1 or greater for "
                    "serialization to netCDF files"
                )
        elif name is not None:
            raise TypeError(
                "DataArray.name or Dataset key must be either a "
                "string or None for serialization to netCDF files"
            )

    for k in dataset.variables:
        check_name(k)
def _validate_attrs(dataset: Dataset) -> None:
    """`attrs` must have a string key and a value which is either: a number,
    a string, an ndarray, or a list/tuple of numbers/strings.
    """

    def check_attr(name, value):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for attr: string must be "
                    "length 1 or greater for serialization to "
                    "netCDF files"
                )
        else:
            raise TypeError(
                "Invalid name for attr: {} must be a string for "
                "serialization to netCDF files".format(name)
            )

        if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)):
            raise TypeError(
                "Invalid value for attr: {} must be a number, "
                "a string, an ndarray or a list/tuple of "
                "numbers/strings for serialization to netCDF "
                "files".format(value)
            )

    # Check attrs on the dataset itself
    for k, v in dataset.attrs.items():
        check_attr(k, v)

    # Check attrs on each variable within the dataset
    for variable in dataset.variables.values():
        for k, v in variable.attrs.items():
            check_attr(k, v)
def _to_netcdf(
    dataset: Dataset,
    path_or_file: Optional[str] = None,
    mode: Optional[str] = "w",
    format: Optional[str] = None,
    group: Optional[str] = None,
    engine: Optional[str] = None,
    encoding: Optional[Mapping] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
    compute: bool = True,
    multifile: bool = False,
    invalid_netcdf: bool = False,
) -> Union[None, Delayed]:
    """This function creates an appropriate datastore for writing a dataset to
    disk as a netCDF file.

    See `Dataset.to_netcdf` for full API docs.

    The ``multifile`` argument is only for the private use of save_mfdataset.
    """
    if isinstance(path_or_file, Path):
        path_or_file = str(path_or_file)

    if encoding is None:
        encoding = {}

    if path_or_file is None:
        if engine is None:
            engine = "scipy"
        elif engine != "scipy":
            raise ValueError(
                "invalid engine for creating bytes with "
                "to_netcdf: %r. Only the default engine "
                "or engine='scipy' is supported" % engine
            )
        if not compute:
            raise NotImplementedError(
                "to_netcdf() with compute=False is not yet implemented when "
                "returning bytes"
            )
    elif isinstance(path_or_file, str):
        if engine is None:
            engine = _get_default_engine(path_or_file)
        path_or_file = _normalize_path(path_or_file)
    else:  # file-like object
        engine = "scipy"

    # validate Dataset keys, DataArray names, and attr keys/values
    _validate_dataset_names(dataset)
    _validate_attrs(dataset)

    try:
        store_open = WRITEABLE_STORES[engine]
    except KeyError:
        raise ValueError("unrecognized engine for to_netcdf: %r" % engine)

    if format is not None:
        format = format.upper()

    # handle scheduler specific logic
    scheduler = _get_scheduler()
    have_chunks = any(v.chunks for v in dataset.variables.values())

    autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
    if autoclose and engine == "scipy":
        raise NotImplementedError(
            "Writing netCDF files with the %s backend "
            "is not currently supported with dask's %s "
            "scheduler" % (engine, scheduler)
        )

    target = path_or_file if path_or_file is not None else BytesIO()
    kwargs = dict(autoclose=True) if autoclose else {}
    # added phony dims support
    if engine == "h5netcdf":
        kwargs.update(dict(phony_dims="access"))
    if invalid_netcdf:
        if engine == "h5netcdf":
            kwargs["invalid_netcdf"] = invalid_netcdf
        else:
            raise ValueError(
                "unrecognized option 'invalid_netcdf' for engine %s" % engine
            )
    store = store_open(target, mode, format, group, **kwargs)

    if unlimited_dims is None:
        unlimited_dims = dataset.encoding.get("unlimited_dims", None)
    if unlimited_dims is not None:
        if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable):
            unlimited_dims = [unlimited_dims]
        else:
            unlimited_dims = list(unlimited_dims)

    writer = ArrayWriter()

    # TODO: figure out how to refactor this logic (here and in save_mfdataset)
    # to avoid this mess of conditionals
    try:
        # TODO: allow this work (setting up the file for writing array data)
        # to be parallelized with dask
        dump_to_store(
            dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
        )
        if autoclose:
            store.close()

        if multifile:
            return writer, store

        writes = writer.sync(compute=compute)

        if path_or_file is None:
            store.sync()
            return target.getvalue()
    finally:
        if not multifile and compute:
            store.close()

    if not compute:
        import dask

        return dask.delayed(_finalize_store)(writes, store)
    return None
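

# Rough sketch (commented out) of the h5netcdf code path that
# save_netcdf_alongside_mdtraj above exercises, assuming xarray and h5netcdf
# are installed; "out.h5" is a hypothetical path. One visible difference from
# stock xarray.Dataset.to_netcdf is that phony_dims="access" is always passed
# to the h5netcdf backend here:
#
#     ds = xr.Dataset({"x": ("frame", np.arange(5.0))})
#     _to_netcdf(
#         ds,
#         "out.h5",
#         mode="w",
#         format="NETCDF4",
#         group="CVs",
#         engine="h5netcdf",
#         invalid_netcdf=True,
#     )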