Coverage for encodermap/loading/delayed.py: 10%

243 statements  

coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/loading/delayed.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Functions to use with the DaskFeaturizer class. 

23 

24""" 

25 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32# Future Imports at the top 

33from __future__ import annotations 

34 

35# Standard Library Imports 

36from pathlib import Path 

37 

38# Third Party Imports 

39import numpy as np 

40from MDAnalysis.coordinates.XTC import XTCReader 

41from optional_imports import _optional_import 

42 

43 

44################################################################################ 

45# Optional Imports 

46################################################################################ 

47 

48 

49dask = _optional_import("dask") 

50da = _optional_import("dask", "array") 

51dd = _optional_import("dask", "dataframe") 

52box_vectors_to_lengths_and_angles = _optional_import( 

53 "mdtraj", "utils.unitcell.box_vectors_to_lengths_and_angles" 

54) 

55_dist_mic = _optional_import("mdtraj", "geometry._geometry._dist_mic") 

56_dist = _optional_import("mdtraj", "geometry._geometry._dist") 

57_dihedral_mic = _optional_import("mdtraj", "geometry._geometry._dihedral_mic") 

58_dihedral = _optional_import("mdtraj", "geometry._geometry._dihedral") 

59_angle_mic = _optional_import("mdtraj", "geometry._geometry._angle_mic") 

60_angle = _optional_import("mdtraj", "geometry._geometry._angle") 

61jit = _optional_import("numba", "jit") 

62prange = _optional_import("numba", "prange") 

63xr = _optional_import("xarray") 

64md = _optional_import("mdtraj") 

65h5py = _optional_import("h5py") 

66 

67 
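The `_optional_import` helper above lets this module be imported even when optional dependencies such as dask, numba, or h5py are missing; the ImportError is deferred until the returned object is actually used. A minimal sketch of this general pattern (an illustration only, not the actual implementation in the `optional_imports` package) could look like this:

    # Sketch of a deferred optional import; the real helper may differ in details.
    import importlib

    def _optional_import_sketch(module, attribute=None):
        try:
            obj = importlib.import_module(module)
            if attribute is not None:
                # resolve dotted attribute paths like "utils.unitcell.box_vectors_to_lengths_and_angles"
                for part in attribute.split("."):
                    obj = getattr(obj, part)
            return obj
        except (ImportError, AttributeError):
            class _Missing:
                def __getattr__(self, name):
                    raise ImportError(f"This functionality requires the optional dependency '{module}'.")
                def __call__(self, *args, **kwargs):
                    raise ImportError(f"This functionality requires the optional dependency '{module}'.")
            return _Missing()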

68################################################################################ 

69# Typing 

70################################################################################ 

71 

72 

73# Standard Library Imports 

74from typing import TYPE_CHECKING, Literal, Optional, Union, overload 

75 

76 

77if TYPE_CHECKING: 

78 # Third Party Imports 

79 from dask.delayed import Delayed 

80 

81 # Encodermap imports 

82 from encodermap.loading.featurizer import DaskFeaturizer 

83 from encodermap.trajinfo.info_single import SingleTraj 

84 

85 

86################################################################################ 

87# Utils 

88################################################################################ 

89 

90 

91# @jit(parallel=True, nopython=True) 

92def calc_bravais_box(box_info: np.ndarray) -> np.ndarray: 

93 """Calculates the Bravais vectors from lengths and angles (in degrees). 

94 

95 Note: 

96 This code is adapted from gyroid, which is licensed under the BSD 

97 http://pythonhosted.org/gyroid/_modules/gyroid/unitcell.html 

98 

99 Args: 

100 box_info (np.ndarray): The box info, where the columns are ordered as 

101 follows: a, b, c, alpha, beta, gamma. The angles are given in degrees. 

102 

103 Returns: 

104 np.ndarray: The bravais vectors as a shape (n_frames, 3, 3) array. 

105 

106 """ 

107 a_length, b_length, c_length = box_info[:, :3].T 

108 alpha, beta, gamma = box_info[:, 3:].T 

109 

110 alpha = alpha * np.pi / 180 

111 beta = beta * np.pi / 180 

112 gamma = gamma * np.pi / 180 

113 

114 a = np.zeros((3, len(a_length)), dtype="float32") 

115 a[0] = a_length 

116 b = np.zeros((3, len(b_length)), dtype="float32") 

117 b[0] = b_length * np.cos(gamma) 

118 b[1] = b_length * np.sin(gamma) 

119 cx = c_length * np.cos(beta) 

120 cy = c_length * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma) 

121 cz = np.sqrt(c_length * c_length - cx * cx - cy * cy) 

122 c = np.empty((3, len(c_length)), dtype="float32") 

123 c[0] = cx 

124 c[1] = cy 

125 c[2] = cz 

126 

127 if not a.shape == b.shape == c.shape: 

128 raise TypeError("Bravais vectors a, b, and c must have the same shape.") 

129 

130 # Make sure that all vector components that are _almost_ 0 are set exactly 

131 # to 0 

132 tol = 1e-6 

133 for i in range(a.shape[0]): 

134 for j in range(a.shape[1]): 

135 if a[i, j] > -tol and a[i, j] < tol: 

136 a[i, j] = 0.0 

137 for i in range(b.shape[0]): 

138 for j in range(b.shape[1]): 

139 if b[i, j] > -tol and b[i, j] < tol: 

140 b[i, j] = 0.0 

141 for i in range(c.shape[0]): 

142 for j in range(c.shape[1]): 

143 if c[i, j] > -tol and c[i, j] < tol: 

144 c[i, j] = 0.0 

145 

146 unitcell_vectors = np.ascontiguousarray( 

147 np.swapaxes(np.dstack((a.T, b.T, c.T)), 1, 2) 

148 ) 

149 

150 return unitcell_vectors 

151 
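As a quick sanity check of `calc_bravais_box` (a usage sketch, not part of this module): for an orthorhombic box with 4 nm edges and 90 degree angles, the Bravais vectors reduce to a diagonal matrix.

    # Usage sketch: one frame of an orthorhombic 4 nm box.
    import numpy as np
    from encodermap.loading.delayed import calc_bravais_box

    # columns are a, b, c, alpha, beta, gamma
    box_info = np.array([[4.0, 4.0, 4.0, 90.0, 90.0, 90.0]], dtype="float32")
    vectors = calc_bravais_box(box_info)
    print(vectors.shape)  # (1, 3, 3)
    print(vectors[0])     # approximately np.diag([4.0, 4.0, 4.0])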

152 

153@dask.delayed(nout=4) 

154def _load_xyz(traj, frame_indices): 

155 """Distances in nm. Angles in degree.""" 

156 positions = np.empty( 

157 shape=(len(frame_indices), traj.n_atoms, 3), dtype="float32", order="C" 

158 ) 

159 time = np.empty(shape=(len(frame_indices)), dtype="float32", order="C") 

160 unitcell_info = np.empty(shape=(len(frame_indices), 6), dtype="float32", order="C") 

161 for i, ts in enumerate(traj[frame_indices]): 

162 positions[i] = ts.positions 

163 time[i] = ts.time 

164 unitcell_info[i] = ts._unitcell 

165 positions /= 10 # for some heretical reason, MDAnalysis uses angstrom 

166 unitcell_info[:, :3] /= 10 

167 unitcell_vectors = calc_bravais_box(unitcell_info) 

168 

169 return positions, time, unitcell_vectors, unitcell_info 

170 

171 

172def load_xyz( 

173 traj_file: str, 

174 frame_indices: np.ndarray, 

175 traj_num: Optional[int] = None, 

176) -> tuple[da.array, da.array, da.array, da.array]: 

177 if Path(traj_file).suffix == ".h5": 

178 return load_xyz_from_h5(traj_file, frame_indices, traj_num) 

179 if Path(traj_file).suffix != ".xtc": 

180 raise Exception( 

181 f"Currently only .xtc and .h5 trajectory files are supported. " 

182 f"But adding more formats is easy. Raise an issue, if you want " 

183 f"to have them added." 

184 ) 

185 traj = XTCReader(traj_file) 

186 n_atoms = traj.n_atoms 

187 n_frames = len(frame_indices) 

188 p, t, uv, ui = _load_xyz(traj, frame_indices) 

189 p = da.from_delayed( 

190 p, 

191 shape=(n_frames, n_atoms, 3), 

192 dtype="float32", 

193 ) 

194 t = da.from_delayed( 

195 t, 

196 shape=(n_frames,), 

197 dtype="float32", 

198 ) 

199 uv = da.from_delayed( 

200 uv, 

201 shape=(n_frames, 3, 3), 

202 dtype="float32", 

203 ) 

204 ui = da.from_delayed( 

205 ui, 

206 shape=(n_frames, 6), 

207 dtype="float32", 

208 ) 

209 return p, t, uv, ui 

210 

211 
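A usage sketch for `load_xyz` (the trajectory path below is a placeholder): the returned dask arrays are lazy, so nothing is read from the .xtc file until `.compute()` is called.

    # Usage sketch; "traj.xtc" stands in for an existing XTC trajectory.
    import numpy as np
    from encodermap.loading.delayed import load_xyz

    frame_indices = np.arange(50)
    positions, time, unitcell_vectors, unitcell_info = load_xyz("traj.xtc", frame_indices)
    print(positions)           # lazy dask array of shape (50, n_atoms, 3)
    xyz = positions.compute()  # only now is the trajectory actually read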

212@dask.delayed(nout=4) 

213def _load_xyz_from_h5( 

214 traj_file: str, 

215 frame_indices: np.ndarray, 

216 traj_num: int, 

217) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 

218 """Distances in nm. Angles in degree. 

219 

220 Args: 

221 traj_file (str): The file to load. 

222 frame_indices (np.ndarray): An int array giving the positions to load. 

223 traj_num (int): The trajectory number used to select the suffixed datasets in the file. 

224 

225 Returns: 

226 tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: A four-tuple of np 

227 arrays. The order of these arrays is: 

228 * positions: Shape (len(frame_indices), n_atoms, 3): The xyz coordinates in nm. 

229 * time: shape (len(frame_indices), ): The time in ps. 

230 * unitcell_vectors: Shape (len(frame_indices), 3, 3): The unitcell vectors. 

231 * unitcell_info: Shape (len(frame_indices), 6), where [:, :3] are 

232 the unitcell lengths in nm and [:, 3:] are the unitcell angles 

233 in degrees. 

234 

235 """ 

236 keys = ["coordinates", "time", "cell_lengths", "cell_angles"] 

237 if traj_num is not None: 

238 keys_with_num = [f"{k}_{traj_num}" for k in keys] 

239 else: 

240 keys_with_num = keys 

241 data = {} 

242 with h5py.File(traj_file, "r") as f: 

243 for k, out in zip(keys_with_num, keys): 

244 if k not in f and out not in f: 

245 data[out] = None 

246 elif k in f: 

247 data[out] = f[k][frame_indices] 

248 elif out in f: 

249 data[out] = f[out][frame_indices] 

250 unitcell_info = np.empty(shape=(len(frame_indices), 6), dtype="float32", order="C") 

251 unitcell_info[:, :3] = data["cell_lengths"] 

252 unitcell_info[:, 3:] = data["cell_angles"] 

253 unitcell_vectors = calc_bravais_box(unitcell_info) 

254 return data["coordinates"], data["time"], unitcell_vectors, unitcell_info 

255 

256 

257def load_xyz_from_h5( 

258 traj_file: str, 

259 frame_indices: np.ndarray, 

260 traj_num: Optional[int] = None, 

261) -> tuple[da.array, da.array, da.array, da.array]: 

262 """Loads xyz coordinates and unitcell info from a block in a .h5 file. 

263 

264 Standard MDTraj h5 keys are: 

265 ['cell_angles', 'cell_lengths', 'coordinates', 'time', 'topology'] 

266 

267 Args: 

268 traj_file (str): The file to load. 

269 frame_indices (np.ndarray): An int array giving the positions to load. 

270 traj_num (int): The trajectory number used to select the suffixed datasets in the file. 

271 

272 Returns: 

273 tuple[da.array, da.array, da.array, da.array]: A four-tuple of dask 

274 arrays that contain dask delayeds. The order of these arrays is: 

275 * positions: Shape (len(frame_indices), n_atoms, 3): The xyz coordinates in nm. 

276 * time: shape (len(frame_indices), ): The time in ps. 

277 * unitcell_vectors: Shape (len(frame_indices), 3, 3): The unitcell vectors. 

278 * unitcell_info: Shape (len(frame_indices), 6), where [:, :3] are 

279 the unitcell lengths in nm and [:, 3:] are the unitcell angles 

280 in degrees. 

281 

282 """ 

283 # Encodermap imports 

284 from encodermap.trajinfo.info_all import HDF5GroupWrite 

285 

286 n_frames = len(frame_indices) 

287 with HDF5GroupWrite(traj_file) as f: 

288 if "topology" not in f.keys() and traj_num is not None: 

289 top = f.read_topology(f"topology_{traj_num}") 

290 else: 

291 top = f.read_topology("topology") 

292 n_atoms = top.n_atoms 

293 p, t, uv, ui = _load_xyz_from_h5(traj_file, frame_indices, traj_num) 

294 p = da.from_delayed(p, shape=(n_frames, n_atoms, 3), dtype="float32") 

295 t = da.from_delayed(t, shape=(n_frames,), dtype="float32") 

296 uv = da.from_delayed(uv, shape=(n_frames, 3, 3), dtype="float32") 

297 ui = da.from_delayed(ui, shape=(n_frames, 6), dtype="float32") 

298 return p, t, uv, ui 

299 
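The key lookup in `_load_xyz_from_h5` prefers the per-trajectory, suffixed datasets and falls back to the standard MDTraj keys. A sketch of the equivalent lookup with plain h5py (file name and traj_num are placeholders):

    # Sketch of the key resolution; "trajs.h5" and traj_num are placeholders.
    import h5py

    traj_num = 2
    with h5py.File("trajs.h5", "r") as f:
        for key in ["coordinates", "time", "cell_lengths", "cell_angles"]:
            suffixed = f"{key}_{traj_num}"
            if suffixed in f:
                print(key, f[suffixed].shape)  # per-trajectory dataset
            elif key in f:
                print(key, f[key].shape)       # standard MDTraj layout
            else:
                print(key, "not present")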

300 

301################################################################################ 

302# Dask graph creation 

303################################################################################ 

304 

305 

306@overload 

307def build_dask_xarray( 

308 featurizer: DaskFeaturizer, 

309 traj: Optional[SingleTraj], 

310 streamable: bool, 

311 return_delayeds: Literal[True], 

312) -> tuple[xr.Dataset, dict[str, xr.Variable]]: ... 

313 

314 

315@overload 

316def build_dask_xarray( 

317 featurizer: DaskFeaturizer, 

318 traj: Optional[SingleTraj], 

319 streamable: bool, 

320 return_delayeds: Literal[False], 

321) -> tuple[xr.Dataset, None]: ... 

322 

323 

324def build_dask_xarray( 

325 featurizer: DaskFeaturizer, 

326 traj: Optional[SingleTraj] = None, 

327 streamable: bool = False, 

328 return_delayeds: bool = False, 

329) -> tuple[xr.Dataset, Union[None, dict[str, xr.Variable]]]: 

330 """Builds a large dask xarray, which will be distributively evaluated. 

331 

332 This function takes a `DaskFeaturizer` instance, which contains a list of features. 

333 Every feature in this list contains enough information for the delayed functions 

334 to calculate the requested quantities when provided with the xyz coordinates of 

335 the atoms, the unitcell vectors (Bravais matrices), and the unitcell info. 

336 

337 Args: 

338 featurizer (DaskFeaturizer): An instance of the DaskFeaturizer. 

339 return_delayeds (bool): Whether to additionally return the coordinate 

340 information (all_xyz, all_time, all_cell_lengths, all_cell_angles) 

341 as xr.Variables. Defaults to False. 

342 streamable (bool): Whether to divide the calculations into one-frame 

343 blocks, which can then only be calculated when requested. 

344 

345 Returns: 

346 tuple[xr.Dataset, Union[None, dict[str, xr.Variable]]]: 

347 When `return_delayeds` is False, the second element of the returned 

348 tuple is None. Otherwise, it is a dict of xr.Variables holding the 

349 delayed coordinate data. 

350 

351 

352 """ 

353 # Imports 

354 # Encodermap imports 

355 from encodermap.loading.features import CustomFeature 

356 from encodermap.misc.xarray import ( 

357 FEATURE_NAMES, 

358 make_dataarray, 

359 make_frame_CV_dataarray, 

360 make_position_dataarray, 

361 ) 

362 from encodermap.trajinfo.trajinfo_utils import trajs_combine_attrs 

363 

364 # definitions 

365 coordinates = { 

366 "coordinates": ["md_frame", "md_atom", "md_cart"], 

367 "time": ["md_frame"], 

368 "cell_lengths": ["md_frame", "md_length"], 

369 "cell_angles": ["md_frame", "md_angle"], 

370 } 

371 

372 # pre-define blocks from the trajectories 

373 n_blocks = 10 

374 

375 # append delayeds here: 

376 if return_delayeds: 

377 delayeds = {} 

378 else: 

379 delayeds = None 

380 

381 assert len(featurizer.feat.active_features) > 0 

382 

383 # collect the Datasets in this list 

384 DSs = [] 

385 

386 # if the dask featurizer contains an `EnsembleFeaturizer`, we can use `itertrajs()` 

387 if traj is None: 

388 if hasattr(featurizer.feat, "trajs"): 

389 iterable = featurizer.feat.trajs.itertrajs() 

390 else: 

391 iterable = enumerate([featurizer.feat.traj]) 

392 else: 

393 iterable = enumerate([traj]) 

394 

395 # iter over trajs or just the one 

396 for i, traj in iterable: 

397 n_frames = len(traj.id) 

398 if not streamable: 

399 n_frames_per_block = n_frames // n_blocks 

400 blocks = [ 

401 np.arange(i * n_frames_per_block, (i + 1) * n_frames_per_block) 

402 for i in range(n_blocks - 1) 

403 ] 

404 blocks.append(np.arange((n_blocks - 1) * n_frames_per_block, n_frames)) 

405 # remove empty blocks 

406 blocks = list(filter(lambda x: x.size > 0, blocks)) 

407 else: 

408 n_frames_per_block = 1 

409 blocks = [[i] for i in range(n_frames)] 

410 

411 # collect multiple DataArrays here 

412 DAs = {} 

413 indexes = {} 

414 

415 if delayeds is not None: 

416 xyz_traj = [] 

417 time_traj = [] 

418 lengths_traj = [] 

419 angles_traj = [] 

420 

421 # distribute the loading to multiple workers 

422 for j, block in enumerate(blocks): 

423 # get the actual frame indices if the traj was sliced 

424 if traj.id.ndim == 2: 

425 frame_indices = traj.id[block, 1] 

426 else: 

427 frame_indices = traj.id[block] 

428 

429 assert len(frame_indices) > 0, f"{frame_indices=}" 

430 xyz, time, unitcell_vector, unitcell_info = load_xyz( 

431 traj.traj_file, frame_indices, traj.traj_num 

432 ) 

433 

434 if delayeds is not None: 

435 unitcell_lengths = unitcell_info[:, :3] 

436 unitcell_angles = unitcell_info[:, 3:] 

437 xyz_traj.append(xyz) 

438 time_traj.append(time) 

439 lengths_traj.append(unitcell_lengths) 

440 angles_traj.append(unitcell_angles) 

441 

442 # iterate over the features and let them use the traj information 

443 if hasattr(featurizer.feat, "trajs"): 

444 features = featurizer.feat.active_features[traj.top] 

445 else: 

446 features = featurizer.feat.active_features 

447 

448 for k, feat in enumerate(features): 

449 # the name of the feature will be used for traceability 

450 if not isinstance(feat, CustomFeature) or not issubclass( 

451 feat.__class__, CustomFeature 

452 ): 

453 assert hasattr(feat, "dask_indices") and hasattr( 

454 feat, "dask_transform" 

455 ), ( 

456 f"For `feature.transform()` to be acceptable as delayed, " 

457 f"the feature needs to implement the `dask_indices` property " 

458 f"and `dask_transform` staticmethod. The feature {feat} has " 

459 f"this these methods and attributes " 

460 f"{[a for a in feat.__dir__() if not a.startswith('_')]}" 

461 ) 

462 assert feat.delayed, ( 

463 f"The feature {feat} was not altered to return a delayed " 

464 f"transform. Please read up in `encodermap.DaskFeaturizer` how " 

465 f"to make features work with dask delayed." 

466 ) 

467 

468 # decide on the name 

469 try: 

470 name = FEATURE_NAMES[feat.name] 

471 except (KeyError, AttributeError): 

472 if hasattr(feat, "name"): 

473 if isinstance(feat.name, str): 

474 name = feat.name 

475 if "mdtraj.trajectory" in feat.name.lower(): 

476 feat.name = feat.__class__.__name__ 

477 name = feat.__class__.__name__ 

478 else: 

479 name = feat.__class__.__name__ 

480 feat.name = name 

481 else: 

482 name = feat.__class__.__name__ 

483 if name == "CustomFeature": 

484 name = feat.describe()[0].split()[0] 

485 feat.name = name 

486 

487 # the feature length is given by the describe() of the feature 

488 if callable(feat.describe()): 

489 feat_length = len([i for i in feat.describe()(traj.top)]) 

490 else: 

491 feat_length = len(feat.describe()) 

492 

493 # dynamically populate kwargs with feature settings 

494 kwargs = {"indexes": getattr(feat, feat.dask_indices)} 

495 if feat._use_periodic: 

496 kwargs["periodic"] = feat.periodic 

497 if feat._use_angle: 

498 kwargs["deg"] = feat.deg 

499 kwargs["cossin"] = feat.cossin 

500 # if feat._use_omega: 

501 # kwargs["omega"] = feat.omega 

502 if hasattr(feat, "_nonstandard_transform_args"): 

503 for k in feat._nonstandard_transform_args: 

504 if not hasattr(feat, k): 

505 kwargs[k] = None 

506 else: 

507 kwargs[k] = getattr(feat, k) 

508 a = da.from_delayed( 

509 feat.dask_transform( 

510 **kwargs, 

511 xyz=xyz, 

512 unitcell_vectors=unitcell_vector, 

513 unitcell_info=unitcell_info, 

514 ), 

515 shape=(len(frame_indices), feat_length), 

516 dtype="float32", 

517 ) 

518 

519 if hasattr(feat, "deg"): 

520 deg = feat.deg 

521 else: 

522 deg = None 

523 

524 if ( 

525 feat.name 

526 in ["AllCartesians", "CentralCartesians", "SideChainCartesians"] 

527 or feat.atom_feature 

528 ): 

529 a = da.reshape(a, (len(frame_indices), -1, 3)) 

530 a = da.expand_dims(a, axis=0) 

531 if hasattr(featurizer, "indices_by_top"): 

532 feat.indexes = featurizer.indices_by_top[traj.top][feat.name] 

533 dataarray, ind_dataarray = make_position_dataarray( 

534 feat.describe(), 

535 traj[block], 

536 name, 

537 a, 

538 deg=deg, 

539 feat=feat, 

540 ) 

541 else: 

542 a = da.expand_dims(a, axis=0) 

543 if feat._dim == 1: 

544 dataarray, ind_dataarray = make_frame_CV_dataarray( 

545 feat.describe(), 

546 traj[block], 

547 name, 

548 a, 

549 deg=deg, 

550 feat=feat, 

551 ) 

552 else: 

553 if hasattr(featurizer, "indices_by_top"): 

554 feat.indexes = featurizer.indices_by_top[traj.top][ 

555 feat.name 

556 ] 

557 dataarray, ind_dataarray = make_dataarray( 

558 feat.describe(), 

559 traj[block], 

560 name, 

561 a, 

562 deg=deg, 

563 feat=feat, 

564 ) 

565 

566 assert dataarray.size > 0, ( 

567 f"Dataarray created for feature {feat} provided with " 

568 f"traj {traj} at frame indices {block} did not contain " 

569 f"any data." 

570 ) 

571 

572 # append the DataArray to the DAs dictionary 

573 DAs.setdefault(name, []).append(dataarray) 

574 if ind_dataarray is not None: 

575 indexes.setdefault(name + "_feature_indices", []).append( 

576 ind_dataarray 

577 ) 

578 else: 

579 indexes[name + "_feature_indices"] = [None] 

580 

581 # after every traj, we combine the datasets 

582 for key, value in DAs.items(): 

583 DAs[key] = xr.concat( 

584 DAs[key], 

585 "frame_num", 

586 combine_attrs=trajs_combine_attrs, 

587 ) 

588 # we only need one element from the indexes, but we make sure that 

589 # they are homogeneous. Every block of a traj should return 

590 # the same index array, as they don't depend on frame data 

591 if indexes[key + "_feature_indices"][0] is None: 

592 assert all( 

593 [i is None for i in indexes[key + "_feature_indices"][1:]] 

594 ), ( 

595 f"Got an inhomogeneous result for indexes for feature {feat=} " 

596 f"at {frame_indices=} {indexes=}" 

597 ) 

598 indexes[key + "_feature_indices"] = indexes.pop(key + "_feature_indices")[0] 

599 

600 # combine data per traj 

601 DAs_and_indexes = DAs | indexes 

602 DAs_and_indexes = {k: v for k, v in DAs_and_indexes.items() if v is not None} 

603 try: 

604 ds = xr.Dataset( 

605 DAs_and_indexes, 

606 attrs=trajs_combine_attrs( 

607 [v.attrs if v.size > 0 else {} for v in DAs_and_indexes.values()] 

608 ), 

609 ) 

610 except xr.core.merge.MergeError as e: 

611 raise Exception(f"{indexes=}") from e 

612 DSs.append(ds) 

613 

614 # and add to the delayeds if needed 

615 if delayeds is not None: 

616 for (coord, dims), data, unit in zip( 

617 coordinates.items(), 

618 [xyz_traj, time_traj, lengths_traj, angles_traj], 

619 ["nanometers", "picoseconds", "nanometers", "degrees"], 

620 ): 

621 name = f"{coord}_{traj.traj_num}" 

622 delayeds[name] = xr.Variable( 

623 dims=[f"{d}_{traj.traj_num}" for d in dims], 

624 data=da.concatenate(data), 

625 attrs={"units": unit.encode("utf-8")}, 

626 ) 

627 

628 # make a large dataset out of this 

629 ds = xr.concat( 

630 DSs, 

631 data_vars="all", 

632 # compat="broadcast_equals", 

633 # coords="all", 

634 # join="outer", 

635 dim="traj_num", 

636 fill_value=np.nan, 

637 combine_attrs=trajs_combine_attrs, 

638 ) 

639 assert ds, ( 

640 f"Concatenation of chunked datasets yielded empty dataset.\n" 

641 f"{DSs=}\n\n{DAs_and_indexes=}" 

642 ) 

643 

644 return ds, delayeds
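A usage sketch for `build_dask_xarray`, assuming `feat` is an already constructed `DaskFeaturizer` with at least one active feature (all names below are placeholders):

    # Usage sketch; `feat` stands in for a DaskFeaturizer with active features.
    from encodermap.loading.delayed import build_dask_xarray

    ds, delayeds = build_dask_xarray(feat, return_delayeds=True)
    print(ds)                     # lazy xr.Dataset backed by dask arrays
    print(list(delayeds.keys()))  # e.g. ['coordinates_0', 'time_0', ...]
    ds = ds.compute()             # evaluates all delayed feature transforms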