Coverage for encodermap/misc/xarray_save_wrong_hdf5.py: 14%

136 statements  

coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

# -*- coding: utf-8 -*-
# encodermap/misc/xarray_save_wrong_hdf5.py
################################################################################
# Encodermap: A python library for dimensionality reduction.
#
# Copyright 2019-2022 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade, Tobias Lemke
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Allows the combined storing of CVs and trajectories in single HDF5/NetCDF4 files.

These files represent collated and completed trajectory ensembles, which can be
lazy-loaded (memory efficient) and used as training input for encodermap's NNs.

"""


################################################################################
# Imports
################################################################################


from __future__ import annotations

import os
import re
from io import BytesIO  # used by _to_netcdf() when no target path is given
from math import pi
from numbers import Number
from pathlib import Path

import numpy as np

from .._optional_imports import _optional_import
from .errors import BadError

################################################################################
# Optional Imports
################################################################################


xr = _optional_import("xarray")
AbstractDataStore = _optional_import("xarray", "backends.common.AbstractDataStore")
ArrayWriter = _optional_import("xarray", "backends.common.ArrayWriter")
Dataset = _optional_import("xarray", "core.dataset.Dataset")
backends = _optional_import("xarray", "backends")
_get_scheduler = _optional_import("xarray", "backends.locks._get_scheduler")
conventions = _optional_import("xarray", "conventions")
_get_default_engine = _optional_import("xarray", "backends.api._get_default_engine")


try:
    from dask.delayed import Delayed
except ImportError:
    Delayed = None


################################################################################
# Typing
################################################################################


from typing import (
    TYPE_CHECKING,
    Callable,
    Dict,
    Hashable,
    Iterable,
    Mapping,
    Optional,
    Tuple,
    Union,
)

try:
    WRITEABLE_STORES: Dict[str, Callable] = {
        "netcdf4": backends.NetCDF4DataStore.open,
        "scipy": backends.ScipyDataStore,
        "h5netcdf": backends.H5NetCDFStore.open,
    }
except (ImportError, ValueError, AttributeError):
    WRITEABLE_STORES = {}

if TYPE_CHECKING:
    from dask.delayed import Delayed
    from xarray import Dataset, backends, conventions
    from xarray.backends.api import _get_default_engine
    from xarray.backends.common import AbstractDataStore, ArrayWriter
    from xarray.backends.locks import _get_scheduler


################################################################################
# Globals
################################################################################


__all__ = ["save_netcdf_alongside_mdtraj"]


################################################################################
# Public functions
################################################################################


def save_netcdf_alongside_mdtraj(fname: str, dataset: Dataset) -> None:
    """Appends the CVs in `dataset` to the HDF5 file `fname`.

    The dataset is written into the file's "CVs" group with the h5netcdf
    engine, so the collective variables live alongside the MDTraj-written
    trajectory in a single file.
    """
    _to_netcdf(
        dataset,
        fname,
        mode="a",
        format="NETCDF4",
        group="CVs",
        engine="h5netcdf",
        invalid_netcdf=True,
    )

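# Usage sketch (illustrative only, not part of this module): assuming a
# trajectory has already been written to "traj.h5" with MDTraj's HDF5 writer,
# and "lowd" is a hypothetical low-dimensional CV array, the CVs can be
# appended to the same file like this:
#
#     import mdtraj as md
#     import numpy as np
#     import xarray as xr
#
#     traj = md.load("traj.xtc", top="traj.pdb")
#     traj.save_hdf5("traj.h5")
#     lowd = np.random.random((traj.n_frames, 2)).astype("float32")
#     ds = xr.Dataset({"lowd": (("frame", "lowd_dim"), lowd)})
#     save_netcdf_alongside_mdtraj("traj.h5", ds)
#
# Afterwards "traj.h5" holds both the trajectory and a "CVs" group that can be
# opened lazily via xr.open_dataset("traj.h5", group="CVs", engine="h5netcdf").
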

##############################################################################
# xarray duplication to allow saving dataset alongside mdtraj
##############################################################################


def dump_to_store(
    dataset: Dataset,
    store: AbstractDataStore,
    writer: Optional[ArrayWriter] = None,
    encoder: Optional[Callable] = None,
    encoding: Optional[Mapping] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
):
    """Store dataset contents to a backends.*DataStore object."""
    if writer is None:
        writer = ArrayWriter()

    if encoding is None:
        encoding = {}

    variables, attrs = conventions.encode_dataset_coordinates(dataset)

    check_encoding = set()
    for k, enc in encoding.items():
        # no need to shallow copy the variable again; that already happened
        # in encode_dataset_coordinates
        variables[k].encoding = enc
        check_encoding.add(k)

    if encoder:
        variables, attrs = encoder(variables, attrs)

    store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)

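# Illustrative sketch ("out.h5" and the dataset `ds` are placeholders):
# dump_to_store() is the low-level path _to_netcdf() takes after opening a
# store. Opening a store manually mirrors the store_open() call inside
# _to_netcdf():
#
#     store = backends.H5NetCDFStore.open("out.h5", "w", "NETCDF4", None)
#     try:
#         dump_to_store(ds, store)
#     finally:
#         store.close()
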

def _normalize_path(path: str) -> str:
    if is_remote_uri(path):
        return path
    else:
        return os.path.abspath(os.path.expanduser(path))


def is_remote_uri(path: str) -> bool:
    return bool(re.search(r"^https?\://", path))

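# For example (paths are illustrative):
#
#     is_remote_uri("https://example.org/data.nc")  # -> True, returned untouched
#     is_remote_uri("trajs/traj.h5")                # -> False
#     _normalize_path("~/trajs/traj.h5")            # -> absolute, user-expanded path
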

def _validate_dataset_names(dataset: Dataset) -> None:
    """DataArray.name and Dataset keys must be a string or None."""

    def check_name(name):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for DataArray or Dataset key: "
                    "string must be length 1 or greater for "
                    "serialization to netCDF files"
                )
        elif name is not None:
            raise TypeError(
                "DataArray.name or Dataset key must be either a "
                "string or None for serialization to netCDF files"
            )

    for k in dataset.variables:
        check_name(k)


def _validate_attrs(dataset: Dataset) -> None:
    """`attrs` must have a string key and a value which is either: a number,
    a string, an ndarray or a list/tuple of numbers/strings.
    """

    def check_attr(name, value):
        if isinstance(name, str):
            if not name:
                raise ValueError(
                    "Invalid name for attr: string must be "
                    "length 1 or greater for serialization to "
                    "netCDF files"
                )
        else:
            raise TypeError(
                "Invalid name for attr: {} must be a string for "
                "serialization to netCDF files".format(name)
            )

        if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)):
            raise TypeError(
                "Invalid value for attr: {} must be a number, "
                "a string, an ndarray or a list/tuple of "
                "numbers/strings for serialization to netCDF "
                "files".format(value)
            )

    # Check attrs on the dataset itself
    for k, v in dataset.attrs.items():
        check_attr(k, v)

    # Check attrs on each variable within the dataset
    for variable in dataset.variables.values():
        for k, v in variable.attrs.items():
            check_attr(k, v)

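# For example, these attrs would pass or fail _validate_attrs on a hypothetical
# dataset `ds`:
#
#     ds.attrs["angle_units"] = "radians"      # passes: string value
#     ds.attrs["periodicity"] = 2 * pi         # passes: number
#     ds.attrs["feature_indices"] = [0, 1, 4]  # passes: list of numbers
#     ds.attrs[42] = "periodic"                # fails: attr name must be a string
#     ds.attrs["meta"] = {"key": "value"}      # fails: dict values are not allowed
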

def _to_netcdf(
    dataset: Dataset,
    path_or_file: Optional[str] = None,
    mode: Optional[str] = "w",
    format: Optional[str] = None,
    group: Optional[str] = None,
    engine: Optional[str] = None,
    encoding: Optional[Mapping] = None,
    unlimited_dims: Optional[Iterable[Hashable]] = None,
    compute: bool = True,
    multifile: bool = False,
    invalid_netcdf: bool = False,
) -> Union[None, Delayed]:
    """This function creates an appropriate datastore for writing a dataset to
    disk as a netCDF file.

    See `Dataset.to_netcdf` for full API docs.

    The ``multifile`` argument is only for the private use of save_mfdataset.
    """
    if isinstance(path_or_file, Path):
        path_or_file = str(path_or_file)

    if encoding is None:
        encoding = {}

    if path_or_file is None:
        if engine is None:
            engine = "scipy"
        elif engine != "scipy":
            raise ValueError(
                "invalid engine for creating bytes with "
                "to_netcdf: %r. Only the default engine "
                "or engine='scipy' is supported" % engine
            )
        if not compute:
            raise NotImplementedError(
                "to_netcdf() with compute=False is not yet implemented when "
                "returning bytes"
            )
    elif isinstance(path_or_file, str):
        if engine is None:
            engine = _get_default_engine(path_or_file)
        path_or_file = _normalize_path(path_or_file)
    else:  # file-like object
        engine = "scipy"

    # validate Dataset keys, DataArray names, and attr keys/values
    _validate_dataset_names(dataset)
    _validate_attrs(dataset)

    try:
        store_open = WRITEABLE_STORES[engine]
    except KeyError:
        raise ValueError("unrecognized engine for to_netcdf: %r" % engine)

    if format is not None:
        format = format.upper()

    # handle scheduler specific logic
    scheduler = _get_scheduler()
    have_chunks = any(v.chunks for v in dataset.variables.values())

    autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
    if autoclose and engine == "scipy":
        raise NotImplementedError(
            "Writing netCDF files with the %s backend "
            "is not currently supported with dask's %s "
            "scheduler" % (engine, scheduler)
        )

    target = path_or_file if path_or_file is not None else BytesIO()
    kwargs = dict(autoclose=True) if autoclose else {}
    # added phony dims support
    if engine == "h5netcdf":
        kwargs.update(dict(phony_dims="access"))
    if invalid_netcdf:
        if engine == "h5netcdf":
            kwargs["invalid_netcdf"] = invalid_netcdf
        else:
            raise ValueError(
                "unrecognized option 'invalid_netcdf' for engine %s" % engine
            )
    store = store_open(target, mode, format, group, **kwargs)

    if unlimited_dims is None:
        unlimited_dims = dataset.encoding.get("unlimited_dims", None)
    if unlimited_dims is not None:
        if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable):
            unlimited_dims = [unlimited_dims]
        else:
            unlimited_dims = list(unlimited_dims)

    writer = ArrayWriter()

    # TODO: figure out how to refactor this logic (here and in save_mfdataset)
    # to avoid this mess of conditionals
    try:
        # TODO: allow this work (setting up the file for writing array data)
        # to be parallelized with dask
        dump_to_store(
            dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
        )
        if autoclose:
            store.close()

        if multifile:
            return writer, store

        writes = writer.sync(compute=compute)

        if path_or_file is None:
            store.sync()
            return target.getvalue()
    finally:
        if not multifile and compute:
            store.close()

    if not compute:
        import dask

        return dask.delayed(_finalize_store)(writes, store)
    return None
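
# Note: the compute=False branch above hands the pending writes and the open
# store to a helper named _finalize_store, which is not part of this coverage
# listing. In xarray it is a small function that waits for the delayed writes
# and then closes the store; a minimal sketch, assuming the same semantics:
#
#     def _finalize_store(write, store):
#         """Finalize this store by explicitly syncing and closing."""
#         del write  # ensure writing is done first
#         store.close()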