Coverage for encodermap/trajinfo/trajinfo_utils.py: 82%

227 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/trajinfo/trajinfo_utils.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Util functions for the `TrajEnsemble` and `SingleTraj` classes. 

23 

24""" 

25 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32from __future__ import annotations 

33 

34import warnings 

35from pathlib import Path 

36 

37import numpy as np 

38 

39from .._optional_imports import _optional_import 

40from ..loading.featurizer import PyEMMAFeaturizer as Featurizer 

41from ..misc.misc import FEATURE_NAMES 

42from ..misc.xarray import construct_xarray_from_numpy 

43 

44############################################################################## 

45# Optional Imports 

46############################################################################## 

47 

48 

49xr = _optional_import("xarray") 

50Feature = _optional_import("pyemma", "coordinates.data.featurization._base.Feature") 

51md = _optional_import("mdtraj") 

52 

53 

54################################################################################ 

55# Typing 

56################################################################################ 

57 

58 

59from typing import TYPE_CHECKING, Literal, Optional, Union 

60 

61if TYPE_CHECKING: 61 ↛ 62line 61 didn't jump to line 62, because the condition on line 61 was never true

62 import mdtraj as md 

63 import xarray as xr 

64 from pyemma.coordinates.data.featurization._base import Feature 

65 

66 from ..loading.featurizer import Featurizer 

67 from .info_single import SingleTraj, TrajEnsemble 

68 

69 SingleTrajFeatureType = Union[ 

70 str, Path, np.ndarray, Feature, xr.Dataset, xr.DataArray, Featurizer 

71 ] 

72 TrajEnsembleFeatureType = Union[ 

73 list[str], list[Path], list[np.ndarray], xr.Dataset, Featurizer, Literal["all"] 

74 ] 

75 Index = Optional[ 

76 Union[tuple[int, list, np.ndarray, slice]], int, list, np.ndarray, slice 

77 ] 

78 

79 

80################################################################################ 

81# Globals 

82################################################################################ 

83 

84 

85CAN_BE_FEATURE_NAME = list(FEATURE_NAMES.keys()) + list(FEATURE_NAMES.values()) 

86__all__ = ["load_CVs_singletraj", "load_CVs_ensembletraj"] 

87 

88 

89################################################################################ 

90# Utils 

91################################################################################ 

92 

93 

94def np_to_xr( 

95 data: np.ndarray, 

96 traj: SingleTraj, 

97 attr_name: Optional[str] = None, 

98 labels: Optional[list[str]] = None, 

99 filename: Optional[Union[str, Path]] = None, 

100) -> xr.DataArray: 

101 if attr_name is None: 

102 if filename is None: 

103 msg = f"Please also provide an `attr_name` under which to save the CV." 

104 raise Exception(msg) 

105 attr_name = Path(filename).stem 

106 

107 if labels is not None: 107 ↛ 108line 107 didn't jump to line 108, because the condition on line 107 was never true

108 if isinstance(labels, str): 

109 labels = [ 

110 f"{attr_name.upper()} {labels.upper()} {i}" 

111 for i in range(data.shape[1]) 

112 ] 

113 elif ( 

114 not all([isinstance(l, str) for l in labels]) 

115 and len(labels) == data.shape[1] 

116 ): 

117 pass 

118 else: 

119 raise Exception( 

120 f"'labels' is either not a list of str or does not have the " 

121 f"same number of datapoints as self.n_frames={data.n_frames}: {labels=}" 

122 ) 

123 data = np.expand_dims(data, axis=0) 

124 if np.any(np.isnan(data)): 

125 # if some nans are found along frame remove them 

126 if data.ndim == 2: 

127 data = data[:, ~np.isnan(data).any(axis=0)] 

128 if data.ndim == 3: 

129 data = data[:, ~np.isnan(data).any(axis=2)[0]] 

130 if data.ndim == 4: 

131 data = data[:, ~np.isnan(data).any(axis=2)[0].any(axis=1)] 

132 da = construct_xarray_from_numpy(traj, data, attr_name, labels, check_n_frames=True) 

133 return da 

134 

135 

136def load_CV_from_string_or_path( 

137 file_or_feature: str, 

138 traj: SingleTraj, 

139 attr_name: Optional[str] = None, 

140 cols: Optional[Union[int, list[int]]] = None, 

141 labels: Optional[list[str]] = None, 

142) -> xr.Dataset: 

143 """Loads CV data from a string. That string can either identify a features, 

144 or point to a file. 

145 

146 Args: 

147 file_or_feature (str): The file or feature to load. If 'all' is 

148 provided, all "standard" features are loaded. But a feature name 

149 like 'sidechain_angle' can alsop be provided. If a file with 

150 the .txt or .npy extension is provided, the data in that file is used. 

151 traj (SingleTraj): The trajectory, that is used to load the features. 

152 attr_name (Union[None, str], optional): The name under which the CV should be found in the class. 

153 Is needed, if a raw numpy array is passed, otherwise the name will be generated from the filename 

154 (if data == str), the DataArray.name (if data == xarray.DataArray), or the feature name. 

155 cols (Union[list, None], optional): A list specifying the columns to use for the highD data. 

156 If your highD data contains (x,y,z,...)-errors or has an enumeration 

157 column at col=0 this can be used to remove this unwanted data. 

158 labels (Union[list, str, None], optional): If you want to label the data you provided pass a list of str. 

159 If set to None, the features in this dimension will be labelled as 

160 [f"{attr_name.upper()} FEATURE {i}" for i in range(self.n_frames)]. If a str is provided, the features 

161 will be labelled as [f"{attr_name.upper()} {label.upper()} {i}" for i in range(self.n_frames)]. If a list of str 

162 is provided it needs to have the same length as the traj has frames. Defaults to None. 

163 

164 Returns: 

165 xr.Dataset: An xarray dataset. 

166 

167 """ 

168 if str(file_or_feature) == "all" or str(file_or_feature) in CAN_BE_FEATURE_NAME: 

169 feat = Featurizer(traj) 

170 if file_or_feature == "all": 

171 feat.add_list_of_feats("all") 

172 else: 

173 feat.add_list_of_feats([file_or_feature]) 

174 out = feat.get_output() 

175 if traj.traj_num is not None: 

176 assert out.coords["traj_num"] == np.array([traj.traj_num]), print( 

177 traj.traj_num, 

178 out.coords["traj_num"].values, 

179 feat.trajs.trajs[0].traj_num, 

180 ) 

181 return out 

182 elif (f := Path(file_or_feature)).exists(): 

183 if f.suffix == ".txt": 

184 data = np.loadtxt(f, usecols=cols) 

185 elif f.suffix == ".npy": 

186 data = np.load(f) 

187 if cols is not None: 187 ↛ 188line 187 didn't jump to line 188, because the condition on line 187 was never true

188 data = data[:, cols] 

189 elif f.suffix in [".nc", ".h5"]: 189 ↛ 210line 189 didn't jump to line 210, because the condition on line 189 was never false

190 data = xr.open_dataset(f) 

191 if len(data.data_vars.keys()) != 1: 

192 if attr_name is not None: 192 ↛ 203line 192 didn't jump to line 203, because the condition on line 192 was never false

193 raise Exception( 

194 f"The dataset in {f} has " 

195 f"{len(data.data_vars.keys())} dataarrays, " 

196 f"but only one `attr_name`: '{attr_name}' " 

197 f"was requested. The names of the dataarrays " 

198 f"are: {data.data_vars.keys()}. I can't over" 

199 f"ride them all with one `attr_name`. Set " 

200 f"`attr_name` to None to load the data with " 

201 f"their respective names" 

202 ) 

203 return data 

204 else: 

205 if attr_name is not None: 205 ↛ 208line 205 didn't jump to line 208, because the condition on line 205 was never false

206 d = list(data.data_vars.values())[0] 

207 d.name = attr_name 

208 return d 

209 else: 

210 raise Exception( 

211 f"Currently only .txt, .npy, .nc, and .h5 files can " 

212 f"be loaded. Your file {f} does not have the " 

213 f"correct extension." 

214 ) 

215 else: 

216 raise Exception( 

217 f"If features are loaded via a string, the string needs " 

218 f"to be 'all', a features name ('central_dihedrals') or " 

219 f'an existing file. Your string "{file_or_feature}"' 

220 f"is none of those" 

221 ) 

222 

223 return np_to_xr(data, traj, attr_name, labels, file_or_feature) 

224 

225 

226def load_CVs_singletraj( 

227 data: SingleTrajFeatureType, 

228 traj: SingleTraj, 

229 attr_name: Optional[str] = None, 

230 cols: Optional[list[int]] = None, 

231 labels: Optional[list[str]] = None, 

232) -> xr.Dataset: 

233 if isinstance(attr_name, str): 

234 if not attr_name.isidentifier(): 234 ↛ 235line 234 didn't jump to line 235, because the condition on line 234 was never true

235 raise Exception( 

236 f"Provided string for `attr_name` can not be a " 

237 f"python identifier. Choose another attribute name." 

238 ) 

239 # load a string 

240 if isinstance(data, (str, Path)): 

241 CVs = load_CV_from_string_or_path(str(data), traj, attr_name, cols, labels) 

242 

243 # load a list of strings from standard features 

244 elif isinstance(data, list) and all([isinstance(_, str) for _ in data]): 

245 feat = Featurizer(traj) 

246 feat.add_list_of_feats(data) 

247 return feat.get_output() 

248 

249 # if the data is a numpy array 

250 elif isinstance(data, (list, np.ndarray)): 

251 CVs = np_to_xr(np.asarray(data), traj, attr_name, labels).to_dataset() 

252 

253 # xarray objects are simply returned 

254 elif isinstance(data, xr.Dataset): 

255 return data 

256 

257 elif isinstance(data, xr.DataArray): 

258 return data.to_dataset() 

259 

260 # if this is a feature 

261 elif issubclass(data.__class__, Feature): 

262 feat = Featurizer(traj) 

263 feat.add_custom_feature(data) 

264 return feat.get_output() 

265 

266 # if an instance of featurizer is provided 

267 elif isinstance(data, Featurizer): 267 ↛ 301line 267 didn't jump to line 301, because the condition on line 267 was never false

268 if isinstance(attr_name, str): 

269 if len(data) != 1: 

270 raise TypeError( 

271 f"Provided Featurizer contains {len(data)} " 

272 f"features and `attr_name` is of type `str`. " 

273 f"Please provide a list of str." 

274 ) 

275 attr_name = [attr_name] 

276 if isinstance(attr_name, list): 

277 if len(attr_name) != len(data): 277 ↛ 278line 277 didn't jump to line 278, because the condition on line 277 was never true

278 raise IndexError( 

279 f"Provided Featurizer contains {len(data)} " 

280 f"features and `attr_name` contains " 

281 f"{len(attr_name)} elements. Please make sure " 

282 f"they contain the same amount of items." 

283 ) 

284 out = data.get_output() 

285 if attr_name is not None: 

286 if isinstance(attr_name, str): 286 ↛ 287line 286 didn't jump to line 287, because the condition on line 286 was never true

287 attr_name = [attr_name] 

288 _renaming = {} 

289 for f, v in zip(data.features, attr_name): 

290 _feature = False 

291 if hasattr(f, "name"): 291 ↛ 295line 291 didn't jump to line 295, because the condition on line 291 was never false

292 if f.name in FEATURE_NAMES: 

293 k = FEATURE_NAMES[f.name] 

294 _feature = True 

295 if not _feature: 

296 k = f.__class__.__name__ 

297 _renaming[k] = v 

298 out = out.rename_vars(_renaming) 

299 return out 

300 else: 

301 raise TypeError( 

302 f"`data` must be str, np.ndarray, list, xr.DataArray, xr.Dataset, " 

303 f"em.Featurizer or em.features.Feature. You supplied " 

304 f"{type(data)}." 

305 ) 

306 

307 return CVs 

308 

309 

310def load_CVs_ensembletraj( 

311 trajs: TrajEnsemble, 

312 data: TrajEnsembleFeatureType, 

313 attr_name: Optional[list[str]] = None, 

314 cols: Optional[list[int]] = None, 

315 labels: Optional[list[str]] = None, 

316 directory: Optional[Union[Path, str]] = None, 

317 ensemble: bool = False, 

318) -> None: 

319 if isinstance(data, (str, Path)) and not ensemble: 

320 path_data = Path(data) 

321 npy_files = [ 

322 (t._traj_file.parent if directory is None else Path(directory)) 

323 / (t.basename + f"_{data}.npy") 

324 for t in trajs 

325 ] 

326 txt_files = [ 

327 (t._traj_file.parent if directory is None else Path(directory)) 

328 / (t.basename + f"_{data}.txt") 

329 for t in trajs 

330 ] 

331 raw_files = [ 

332 (t._traj_file.parent if directory is None else Path(directory)) 

333 / (t.basename + f"_{data}") 

334 for t in trajs 

335 ] 

336 if str(data) == "all": 

337 [t.load_CV("all") for t in trajs] 

338 return 

339 if path_data.is_dir(): 

340 return load_CVs_from_dir(trajs, data, attr_name=attr_name, cols=cols) 

341 elif data in CAN_BE_FEATURE_NAME: 341 ↛ 342line 341 didn't jump to line 342, because the condition on line 341 was never true

342 [t.load_CV(data, attr_name, cols, labels) for t in trajs] 

343 return 

344 elif path_data.is_file() and ( 

345 path_data.suffix == ".h5" or path_data.suffix == ".nc" 

346 ): 

347 ds = xr.open_dataset(path_data) 

348 if diff := set([t.traj_num for t in trajs]) - set(ds["traj_num"].values): 348 ↛ 349line 348 didn't jump to line 349, because the condition on line 348 was never true

349 raise Exception( 

350 f"The dataset you try to load and the TrajEnsemble " 

351 f"have different number of trajectories: {diff}." 

352 ) 

353 for t, (traj_num, sub_ds) in zip(trajs, ds.groupby("traj_num")): 

354 assert t.traj_num == traj_num 

355 sub_ds = sub_ds.assign_coords(traj_num=t.traj_num) 

356 sub_ds = sub_ds.expand_dims("traj_num") 

357 assert sub_ds.coords["traj_num"] == np.array([t.traj_num]) 

358 t.load_CV(sub_ds) 

359 return 

360 elif all([f.is_file() for f in npy_files]): 

361 [ 

362 t.load_CV(f, attr_name=data, cols=cols, labels=labels) 

363 for t, f in zip(trajs, npy_files) 

364 ] 

365 return 

366 elif all([f.is_file() for f in txt_files]): 366 ↛ 372line 366 didn't jump to line 372, because the condition on line 366 was never false

367 [ 

368 t.load_CV(f, attr_name=data, cols=cols, labels=labels) 

369 for t, f in zip(trajs, txt_files) 

370 ] 

371 return 

372 elif all([f.is_file() for f in raw_files]): 

373 [ 

374 t.load_CV(f, attr_name=data, cols=cols, labels=labels) 

375 for t, f in zip(trajs, raw_files) 

376 ] 

377 return 

378 else: 

379 msg = ( 

380 f"If `data` is provided a single string, the string needs to " 

381 f"be either a feature ({CAN_BE_FEATURE_NAME}), a .h5/.nc file " 

382 f"({file}), or a list of npy/txt files ({npy_files}, " 

383 f"{txt_files}). The provided `data` fits none of " 

384 f"these possibilities." 

385 ) 

386 raise ValueError(msg) 

387 

388 elif isinstance(data, list) and not ensemble: 

389 if all([i in CAN_BE_FEATURE_NAME for i in data]): 

390 [t.load_CV(data, attr_name, cols, labels) for t in trajs] 

391 return 

392 elif all([isinstance(i, (list, np.ndarray)) for i in data]): 

393 [t.load_CV(d, attr_name, cols, labels) for t, d in zip(trajs, data)] 

394 return 

395 elif all([Path(f).is_file() for f in data]): 

396 suffix = set([Path(f).suffix for f in data]) 

397 if len(suffix) != 1: 397 ↛ 398line 397 didn't jump to line 398, because the condition on line 397 was never true

398 raise Exception( 

399 "Please provide a list with consistent file " 

400 f"extensions and not a mish-mash, like: {suffix}" 

401 ) 

402 suffix = suffix.pop() 

403 if suffix == ".npy": 

404 [ 

405 t.load_CV(np.load(d), attr_name, cols, labels) 

406 for t, d in zip(trajs, data) 

407 ] 

408 else: 

409 [ 

410 t.load_CV(np.genfromtxt(d), attr_name, cols, labels) 

411 for t, d in zip(trajs, data) 

412 ] 

413 return 

414 else: 

415 msg = ( 

416 f"If `data` is provided as a list, the list needs to contain " 

417 f"strings that can be features ({CAN_BE_FEATURE_NAME}), or " 

418 f"some combination of lists and numpy arrays." 

419 ) 

420 raise ValueError(msg) 

421 

422 elif isinstance(data, np.ndarray): 

423 if len(data) != trajs.n_trajs and len(data) != trajs.n_frames: 

424 raise ValueError( 

425 f"The provided numpy array is misshaped. It needs " 

426 f"to be of shape (n_trajs={trajs.n_trajs}, " 

427 f"n_frames={np.unique([t.n_frames for t in trajs])[0]}, " 

428 f"X, (Y)), but is {data.shape}." 

429 ) 

430 if len(data) == trajs.n_frames: 

431 data = [data[t.id[:, 1]] for t in trajs] 

432 [t.load_CV(d, attr_name, cols, labels) for t, d in zip(trajs, data)] 

433 for t in trajs: 

434 for v in t._CVs.values(): 

435 assert v.shape[0] == 1, print(t.basename, v) 

436 return 

437 

438 elif isinstance(data, Featurizer): 

439 ds = data.get_output() 

440 for t, (traj_num, sub_ds) in zip(trajs, ds.groupby("traj_num")): 

441 assert t.traj_num == traj_num 

442 sub_ds = sub_ds.assign_coords(traj_num=t.traj_num) 

443 sub_ds = sub_ds.expand_dims("traj_num") 

444 t.load_CV(sub_ds) 

445 return 

446 

447 elif isinstance(data, xr.Dataset): 447 ↛ 448line 447 didn't jump to line 448, because the condition on line 447 was never true

448 for t, (traj_num, sub_ds) in zip(trajs, data.groupby("traj_num")): 

449 assert t.traj_num == traj_num 

450 sub_ds = sub_ds.assign_coords(traj_num=t.traj_num) 

451 sub_ds = sub_ds.expand_dims("traj_num") 

452 t.load_CV(sub_ds) 

453 return 

454 

455 if ensemble: 455 ↛ 459line 455 didn't jump to line 459, because the condition on line 455 was never false

456 return load_CVs_ensemble(trajs, data) 

457 

458 else: 

459 raise TypeError( 

460 f"`data` must be str, np.ndarray, list, xr.Dataset" 

461 f"em.Featurizer or. You supplied {type(data)}." 

462 ) 

463 

464 

465def load_CVs_ensemble( 

466 trajs: TrajEnsemble, 

467 data: Union[str, list[str]], 

468) -> None: 

469 if isinstance(data, str): 469 ↛ 472line 469 didn't jump to line 472, because the condition on line 469 was never false

470 if data != "all": 470 ↛ 471line 470 didn't jump to line 471, because the condition on line 470 was never true

471 data = [data] 

472 feat = Featurizer(trajs) 

473 feat.add_list_of_feats(data) 

474 for t, (traj_num, sub_ds) in zip(trajs, feat.get_output().groupby("traj_num")): 

475 assert t.traj_num == traj_num 

476 sub_ds = sub_ds.assign_coords(traj_num=t.traj_num) 

477 sub_ds = sub_ds.expand_dims("traj_num") 

478 if t._CVs: 478 ↛ 479line 478 didn't jump to line 479, because the condition on line 478 was never true

479 warnings.warn( 

480 "Using ensemble=True will drop old CV entries from " 

481 "trajs, because the ferature length increases." 

482 ) 

483 t._CVs = sub_ds 

484 

485 

486def load_CVs_from_dir( 

487 trajs: TrajEnsemble, 

488 data: Path, 

489 attr_name: Optional[str] = None, 

490 cols: Optional[list[int]] = None, 

491) -> None: 

492 files = map(str, data.glob("*")) 

493 files = list( 

494 filter( 

495 lambda x: True if any([traj.basename in x for traj in trajs]) else False, 

496 files, 

497 ) 

498 ) 

499 key = {"npy": 1, "txt": 2} 

500 files = sorted( 

501 files, 

502 key=lambda x: key[x.split(".")[-1]] if x.split(".")[-1] in key else 3, 

503 )[: trajs.n_trajs] 

504 files = sorted( 

505 files, 

506 key=lambda x: [traj.basename in x for traj in trajs].index(True), 

507 ) 

508 for traj, f in zip(trajs, files): 

509 if traj.basename not in f: 509 ↛ 510line 509 didn't jump to line 510, because the condition on line 509 was never true

510 raise Exception(f"File {f} does not contain substring of traj {traj}.") 

511 traj.load_CV(f, attr_name=attr_name, cols=cols) 

512 return