Coverage for encodermap/misc/xarray_save_wrong_hdf5.py: 36%

58 statements  

coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

# -*- coding: utf-8 -*-
# encodermap/misc/xarray_save_wrong_hdf5.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Allows the combined storing of CVs and trajectories in single HDF5/NetCDF4 files.

These files represent collated and completed trajectory ensembles, which can be
lazy-loaded (memory efficient) and used as training input for EncoderMap's NNs.

"""
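
# A minimal usage sketch, assuming a trajectory has already been written to
# ``traj.h5`` by MDTraj and ``ds`` is an ``xarray.Dataset`` holding collective
# variables (the variable name below is only illustrative). The CVs are then
# appended to the same HDF5 file under a separate "CVs" group:
#
#     import numpy as np
#     import xarray as xr
#
#     from encodermap.misc.xarray_save_wrong_hdf5 import save_netcdf_alongside_mdtraj
#
#     ds = xr.Dataset(
#         {"central_dihedrals": (("frame", "dihedral"), np.random.rand(100, 10))}
#     )
#     save_netcdf_alongside_mdtraj("traj.h5", ds)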


################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
import os
import re
from io import BytesIO
from numbers import Number
from pathlib import Path

# Third Party Imports
import numpy as np
from optional_imports import _optional_import


################################################################################
# Optional Imports
################################################################################


xr = _optional_import("xarray")
AbstractDataStore = _optional_import("xarray", "backends.common.AbstractDataStore")
ArrayWriter = _optional_import("xarray", "backends.common.ArrayWriter")
Dataset = _optional_import("xarray", "core.dataset.Dataset")
backends = _optional_import("xarray", "backends")
_get_scheduler = _optional_import("xarray", "backends.locks._get_scheduler")
conventions = _optional_import("xarray", "conventions")
_get_default_engine = _optional_import("xarray", "backends.api._get_default_engine")

try:
    # Third Party Imports
    from dask.delayed import Delayed
except ImportError:
    Delayed = None


################################################################################
# Typing
################################################################################


# Standard Library Imports
from collections.abc import Callable, Hashable, Iterable, Mapping
from typing import TYPE_CHECKING, Optional, Union


WritableStoresType = dict[str, Callable]

try:
    WRITEABLE_STORES: dict[str, Callable] = {
        "netcdf4": backends.NetCDF4DataStore.open,
        "scipy": backends.ScipyDataStore,
        "h5netcdf": backends.H5NetCDFStore.open,
    }
except (ImportError, ValueError, AttributeError):
    WRITEABLE_STORES: dict[str, Callable] = {}

if TYPE_CHECKING:
    # Third Party Imports
    from dask.delayed import Delayed
    from xarray import Dataset, backends, conventions
    from xarray.backends.api import _get_default_engine
    from xarray.backends.common import AbstractDataStore, ArrayWriter
    from xarray.backends.locks import _get_scheduler


################################################################################
# Globals
################################################################################


__all__: list[str] = ["save_netcdf_alongside_mdtraj"]


################################################################################
# Public functions
################################################################################


def save_netcdf_alongside_mdtraj(fname: str, dataset: Dataset) -> None:
    """Saves a netcdf dataset alongside an MDTraj trajectory."""
    _to_netcdf(
        dataset,
        fname,
        mode="a",
        format="NETCDF4",
        group="CVs",
        engine="h5netcdf",
        invalid_netcdf=True,
    )
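
# A sketch of reading the CVs back, assuming the group name "CVs" used above and
# xarray's h5netcdf backend (depending on the installed xarray/h5netcdf versions,
# ``phony_dims="access"`` may additionally be required):
#
#     import xarray as xr
#
#     cvs = xr.open_dataset("traj.h5", group="CVs", engine="h5netcdf")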



################################################################################
# xarray duplication to allow saving dataset alongside mdtraj
################################################################################


def dump_to_store(
    dataset: Dataset,
    store: WritableStoresType = WRITEABLE_STORES,
    writer: Optional[Callable] = None,
    encoder: Optional[Callable] = None,
    encoding: Optional[Mapping] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
) -> None:  # pragma: no cover, no doccheck
    """Store dataset contents to a backends.*DataStore object."""
    if writer is None:
        writer = ArrayWriter()

    if encoding is None:
        encoding = {}

    variables, attrs = conventions.encode_dataset_coordinates(dataset)

    check_encoding = set()
    for k, enc in encoding.items():
        # no need to shallow copy the variable again; that already happened
        # in encode_dataset_coordinates
        variables[k].encoding = enc
        check_encoding.add(k)

    if encoder:
        variables, attrs = encoder(variables, attrs)

    store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)



def _normalize_path(path: str) -> str:
    """Normalize a path.

    See Also:
        https://docs.python.org/3.10/library/os.path.html#os.path.abspath

    Args:
        path (str): The input path.

    Returns:
        str: The absolute, user-expanded path, or the unchanged input if it
            is a remote URI.

    """
    if is_remote_uri(path):
        return path
    else:
        return os.path.abspath(os.path.expanduser(path))


def is_remote_uri(path: str) -> bool:
    """Checks whether a path is a remote URI.

    Args:
        path (str): The path to check.

    Returns:
        bool: Whether the path is a remote URI.

    """
    return bool(re.search(r"^https?\://", path))
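
# Illustrative behaviour of the check above: only explicit http(s) URLs count as
# remote, every other string is treated as a local path. For example:
#
#     is_remote_uri("https://example.org/data/traj.nc")  # True
#     is_remote_uri("/home/user/traj.h5")                # False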



def _validate_dataset_names(dataset: Dataset) -> None:
    """DataArray.name and Dataset keys must be a string or None"""

    def check_name(name):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for DataArray or Dataset key: "
                    "string must be length 1 or greater for "
                    "serialization to netCDF files"
                )
        elif name is not None:
            raise TypeError(
                "DataArray.name or Dataset key must be either a "
                "string or None for serialization to netCDF files"
            )

    for k in dataset.variables:
        check_name(k)


def _validate_attrs(dataset: Dataset) -> None:
    """`attrs` must have a string key and a value which is either: a number,
    a string, an ndarray or a list/tuple of numbers/strings.
    """

    def check_attr(name, value):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for attr: string must be "
                    "length 1 or greater for serialization to "
                    "netCDF files"
                )
        else:
            raise TypeError(
                "Invalid name for attr: {} must be a string for "
                "serialization to netCDF files".format(name)
            )

        if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)):
            raise TypeError(
                "Invalid value for attr: {} must be a number, "
                "a string, an ndarray or a list/tuple of "
                "numbers/strings for serialization to netCDF "
                "files".format(value)
            )

    # Check attrs on the dataset itself
    for k, v in dataset.attrs.items():
        check_attr(k, v)

    # Check attrs on each variable within the dataset
    for variable in dataset.variables.values():
        for k, v in variable.attrs.items():
            check_attr(k, v)
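
# A small illustration of the validation above, assuming a hypothetical dataset
# with a nested dict attribute: such values cannot be serialized to netCDF
# attrs, so ``_validate_attrs`` raises before anything is written:
#
#     import xarray as xr
#
#     ds = xr.Dataset(attrs={"meta": {"temperature": 300}})
#     _validate_attrs(ds)  # raises TypeError ("Invalid value for attr ...")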



def _to_netcdf(
    dataset: Dataset,
    path_or_file: Optional[str] = None,
    mode: Optional[str] = "w",
    format: Optional[str] = None,
    group: Optional[str] = None,
    engine: Optional[str] = None,
    encoding: Optional[Mapping] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
    compute: bool = True,
    multifile: bool = False,
    invalid_netcdf: bool = False,
) -> Optional[Delayed]:  # pragma: no cover, no doccheck
    """This function creates an appropriate datastore for writing a dataset to
    disk as an HDF5 file.

    See `Dataset.to_netcdf` for full API docs.

    The `multifile` argument is only for the private use of `save_mfdataset`.

    """
    if isinstance(path_or_file, Path):
        path_or_file = str(path_or_file)

    if encoding is None:
        encoding = {}

    if path_or_file is None:
        if engine is None:
            engine = "scipy"
        elif engine != "scipy":
            raise ValueError(
                "invalid engine for creating bytes with "
                "to_netcdf: %r. Only the default engine "
                "or engine='scipy' is supported" % engine
            )
        if not compute:
            raise NotImplementedError(
                "to_netcdf() with compute=False is not yet implemented when "
                "returning bytes"
            )
    elif isinstance(path_or_file, str):
        if engine is None:
            engine = _get_default_engine(path_or_file)
        path_or_file = _normalize_path(path_or_file)
    else:  # file-like object
        engine = "scipy"

    # validate Dataset keys, DataArray names, and attr keys/values
    _validate_dataset_names(dataset)
    _validate_attrs(dataset)

    try:
        store_open = WRITEABLE_STORES[engine]
    except KeyError:
        raise ValueError("unrecognized engine for to_netcdf: %r" % engine)

    if format is not None:
        format = format.upper()

    # handle scheduler specific logic
    scheduler = _get_scheduler()
    have_chunks = any(v.chunks for v in dataset.variables.values())

    autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
    if autoclose and engine == "scipy":
        raise NotImplementedError(
            "Writing netCDF files with the %s backend "
            "is not currently supported with dask's %s "
            "scheduler" % (engine, scheduler)
        )

    target = path_or_file if path_or_file is not None else BytesIO()
    kwargs = dict(autoclose=True) if autoclose else {}
    # added phony dims support
    if engine == "h5netcdf":
        kwargs.update(dict(phony_dims="access"))
    if invalid_netcdf:
        if engine == "h5netcdf":
            kwargs["invalid_netcdf"] = invalid_netcdf
        else:
            raise ValueError(
                "unrecognized option 'invalid_netcdf' for engine %s" % engine
            )
    store = store_open(target, mode, format, group, **kwargs)

    if unlimited_dims is None:
        unlimited_dims = dataset.encoding.get("unlimited_dims", None)
    if unlimited_dims is not None:
        if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable):
            unlimited_dims = [unlimited_dims]
        else:
            unlimited_dims = list(unlimited_dims)

    writer = ArrayWriter()

    # TODO: figure out how to refactor this logic (here and in save_mfdataset)
    # to avoid this mess of conditionals
    try:
        # TODO: allow this work (setting up the file for writing array data)
        # to be parallelized with dask
        dump_to_store(
            dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
        )
        if autoclose:
            store.close()

        if multifile:
            return writer, store

        writes = writer.sync(compute=compute)

        if path_or_file is None:
            store.sync()
            return target.getvalue()
    finally:
        if not multifile and compute:
            store.close()

    if not compute:
        # Third Party Imports
        import dask

        return dask.delayed(_finalize_store)(writes, store)
    return None
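
# A hedged sketch of the in-memory branch of ``_to_netcdf``: when ``path_or_file``
# is None, the function falls back to the "scipy" engine, writes into a BytesIO
# buffer and returns the raw netCDF bytes (scipy must be installed):
#
#     import numpy as np
#     import xarray as xr
#
#     ds = xr.Dataset({"lowd": (("frame", "dim"), np.random.rand(10, 2))})
#     raw = _to_netcdf(ds)  # bytes, roughly equivalent to ds.to_netcdf() without a path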