Coverage for encodermap/kondata.py: 7%

223 statements  

coverage.py v7.4.1, created at 2025-05-15 22:46 +0200

# -*- coding: utf-8 -*-
# encodermap/kondata.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
"""Functions for interfacing with the University of Konstanz's repository service KonDATA."""


################################################################################
# Imports
################################################################################


# Future Imports at the top
from __future__ import annotations

# Standard Library Imports
import getpass
import importlib.util
import shutil
import tarfile
import time
import warnings
from collections.abc import Generator
from pathlib import Path
from typing import Any, Optional, Union

# Third Party Imports
import requests
from optional_imports import _optional_import
from requests.auth import HTTPBasicAuth
from rich.status import Status


################################################################################
# Optional Imports
################################################################################


BeautifulSoup = _optional_import("bs4", "BeautifulSoup")


################################################################################
# Globals
################################################################################


__all__: list[str] = ["get_from_kondata"]
DATASET_URL_MAPPING = {
    "test": "https://dx.doi.org/10.48606/108",
    "H1Ub": "https://dx.doi.org/10.48606/99",
    "linear_dimers": "https://dx.doi.org/10.48606/rtda48r76y1e71cc",
    "Ub_K11_mutants": "https://dx.doi.org/10.48606/rx7md621k9zawfz7",
}
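# Example: the `doi_tar` identifier used by `untar` below derives from these
# dois by stripping 'https://dx.doi.org/' and replacing '/' with '-':
#
#     >>> DATASET_URL_MAPPING["test"].removeprefix("https://dx.doi.org/").replace("/", "-")
#     '10.48606-108'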


################################################################################
# Functions
################################################################################


def untar(
    tar_file: Path,
    doi_tar: str,
    output: Path,
    force_overwrite: bool = False,
    silence_overwrite: bool = False,
) -> list[str]:
    """Untars files in a tar archive downloaded from KonDATA.

    As the files are found under {doi_tar}/data/dataset/, only this directory
    is considered. The other directories in the tar file contain RADAR
    metadata that is not important.

    Args:
        tar_file (Path): The pathlib.Path to the tar file.
        doi_tar (str): The doi_tar of the tar file. This arg is needed to find the
            datafiles in {doi_tar}/data/dataset. The doi_tar can be obtained from
            the doi of the dataset by removing 'https://dx.doi.org/' and replacing
            '/' in the doi with '-'.
        output (Path): The directory where to put the files of the archive.
        force_overwrite (bool): Whether to overwrite files that are already there.
            Will print 'file already exists' if a file exists and `force_overwrite`
            is set to False. Defaults to False.
        silence_overwrite (bool): Will not print information about already
            existing files. Useful for when this function is used in a script.
            Defaults to False.

    Returns:
        list[str]: A list of the new files.

    """
    untarred_files: list[str] = []
    # callers invoke this function only when the tar file is already on disk
    if not silence_overwrite:
        print(f"{tar_file} already exists. Set `force_overwrite` to True to overwrite.")
    with tarfile.open(tar_file) as tar:
        for member in tar.getmembers():
            if (
                member.path.startswith(f"{doi_tar}/data/dataset")
                and member.path != f"{doi_tar}/data/dataset"
            ):
                # str.lstrip would strip a *character set*, not a prefix, so
                # use str.removeprefix to cut the leading directories
                filename = member.path.removeprefix(f"{doi_tar}/data/dataset/")
                extract_to = output / filename
                if extract_to.exists() and not force_overwrite:
                    if not silence_overwrite:
                        print(
                            f"{extract_to} already exists. Set "
                            f"`force_overwrite` to True to overwrite."
                        )
                else:
                    if member.isdir():
                        extract_to.mkdir(parents=True, exist_ok=True)
                    elif member.isfile():
                        untarred_files.append(str(extract_to))
                        tar.makefile(member, extract_to)
                    else:
                        print(f"Unknown TarInfo type: {member}")
    return untarred_files
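# A minimal usage sketch (the paths are illustrative, not shipped files):
#
#     >>> from pathlib import Path
#     >>> new_files = untar(  # doctest: +SKIP
#     ...     tar_file=Path("~/.encodermap_data/test/10.48606-108.tar").expanduser(),
#     ...     doi_tar="10.48606-108",
#     ...     output=Path("~/.encodermap_data/test").expanduser(),
#     ... )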


def get_from_kondata(
    dataset_name: str,
    output: Optional[Union[str, Path]] = None,
    force_overwrite: bool = False,
    mk_parentdir: bool = False,
    silence_overwrite_message: bool = False,
    tqdm_class: Optional[Any] = None,
    download_extra_data: bool = False,
    download_checkpoints: bool = False,
    download_h5: bool = True,
) -> str:
    """Get dataset from the University of Konstanz's data repository KonDATA.

    Args:
        dataset_name (str): The name of the dataset. Refer to `DATASET_URL_MAPPING`
            to get a list of the available datasets.
        output (Optional[Union[str, Path]]): The output directory. If None is
            provided, a default directory is chosen. Defaults to None.
        force_overwrite (bool): Whether to overwrite existing files. Defaults to False.
        mk_parentdir (bool): Whether to create the `output` directory if it does
            not already exist. Defaults to False.
        silence_overwrite_message (bool): Whether to silence the 'file already exists'
            warning. Can be useful in scripts. Defaults to False.
        tqdm_class (Optional[Any]): A class that is similar to tqdm.tqdm. This
            is mainly useful if this function is used inside a `rich.status.Status`
            context manager, as the normal tqdm does not work inside this context.
            If None is provided, the default tqdm will be used.
        download_extra_data (bool): Whether to download extra data. It is only used
            if the dataset is not available on KonDATA. Defaults to False.
        download_checkpoints (bool): Whether to download pretrained checkpoints.
            It is only used if the dataset is not available on KonDATA.
            Defaults to False.
        download_h5 (bool): Whether to also download an h5 file of the
            ensemble. Defaults to True.

    Returns:
        str: The output directory.

    """

    # Local Folder Imports
    from .misc.misc import _is_notebook

    if dataset_name not in DATASET_URL_MAPPING:
        return get_from_url(
            f"https://encodermap.site/encodermap_data/{dataset_name}",
            output=output,
            force_overwrite=force_overwrite,
            mk_parentdir=mk_parentdir,
            silence_overwrite_message=silence_overwrite_message,
            tqdm_class=tqdm_class,
            download_extra_data=download_extra_data,
            download_checkpoints=download_checkpoints,
            download_h5=download_h5,
        )
    if output is None:
        package = importlib.util.find_spec("encodermap")
        if package is None:
            output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
        else:
            emfile = package.origin
            output = Path(emfile).parent.parent / "tests"
            if not output.is_dir():
                output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
            else:
                output /= f"data/{dataset_name}"
        output.parent.mkdir(exist_ok=True)
        if not output.parent.is_dir():  # pragma: nocover
            question = input(
                f"I will create the directory {output.parent} and download "
                f"the dataset {dataset_name} to it. Continue? [y/N] "
            )
            if question.lower() not in ["y", "ye", "yes"]:
                raise Exception(f"User has answered to not overwrite {output.parent}.")
        output.mkdir(parents=True, exist_ok=True)

    if dataset_name == "trained_networks":
        warnings.warn("Implement")
        return str(output)

    # in all other cases, make sure output is a Path
    output = Path(output)

    if tqdm_class is None:
        if _is_notebook():
            # Third Party Imports
            from tqdm.notebook import tqdm as tqdm_class  # type: ignore[no-redef]
        else:
            # Third Party Imports
            from tqdm import tqdm as tqdm_class  # type: ignore[no-redef]
    assert tqdm_class is not None

    if not (output := Path(output)).is_dir():
        if not mk_parentdir:
            raise Exception(
                f"Output directory {output} does not exist. Create it "
                f"or set `mk_parentdir` to True."
            )
        else:
            output.mkdir()

    # clear partial downloads
    partial_downloads = output.glob("*.crdownload")
    for partial_download in partial_downloads:
        partial_download.unlink()

    # define stuff for the tarfile and maybe unpack the files if not already done so
    doi_tar = (
        DATASET_URL_MAPPING[dataset_name]
        .removeprefix("https://dx.doi.org/")
        .replace("/", "-")
    )
    tar_file = output / f"{doi_tar}.tar"
    if tar_file.is_file() and not force_overwrite:
        untarred_files = untar(
            tar_file, doi_tar, output, force_overwrite, silence_overwrite_message
        )
        return str(output)

    # instantiate the status
    s = Status(f"Downloading {dataset_name}", spinner="material")
    s.start()

    # use selenium to download the archive
    url = DATASET_URL_MAPPING[dataset_name]
    try:
        # Third Party Imports
        from selenium import webdriver
        from selenium.webdriver.common.by import By
    except ImportError as e:
        raise Exception(
            f"Programmatically downloading from KonDATA requires selenium. Beware "
            f"it uses Google code to interact with web pages. Either "
            f"install it with `pip install selenium`, or manually download the "
            f"files from {url} and un-tar them to {output}"
        ) from e
    prefs = {"download.default_directory": str(output)}
    options = webdriver.ChromeOptions()
    options.add_experimental_option("prefs", prefs)
    options.add_argument("--headless=new")
    driver = webdriver.Chrome(options=options)
    driver.get(url)
    download_element = driver.find_element(
        By.XPATH, '//a[starts-with(@class, "dataset-download")]'
    )
    download_element.click()
    accept_element = driver.find_element(By.XPATH, '//button[text()="Accept"]')
    accept_element.click()

    # keep the status spinner running while polling for the finished download
    while True:
        time.sleep(0.2)
        if tar_file.is_file():
            s.stop()
            break
    copied_tar = output / f"{doi_tar} (1).tar"
    if copied_tar.is_file():
        shutil.move(copied_tar, tar_file)
    untarred_files = untar(
        tar_file, doi_tar, output, force_overwrite, silence_overwrite=True
    )

    if dataset_name == "H1Ub" and download_h5:
        trajs_file = output / "trajs.h5"
        if trajs_file.is_file() and not force_overwrite:
            print("trajs.h5 file for H1Ub already present.")
        else:
            print("The H1Ub dataset does not contain a trajs.h5 file.")

    return str(output)
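# A minimal usage sketch (the output directory is an arbitrary choice):
#
#     >>> out_dir = get_from_kondata(  # doctest: +SKIP
#     ...     "test",
#     ...     output="datasets/test",
#     ...     mk_parentdir=True,
#     ... )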


def get_assign_from_file(file: Path, assign: str) -> str:
    """Reads a file and extracts lines with assignments. Can be
    used for reading simple secret files which look like::

        PASSWORD=my_sekret_password
        USERNAME=oll_korrect_username

    Args:
        file (Path): The file.
        assign (str): The string to look for (e.g. PASSWORD).

    Returns:
        str: The assignment after the equal (=) sign.

    """
    content = file.read_text().splitlines()
    content_line = content[[c.startswith(assign) for c in content].index(True)]
    return content_line.split("=")[-1]
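# A minimal usage sketch with a throwaway secret file (names are made up):
#
#     >>> from pathlib import Path
#     >>> secret = Path("act.vault")  # doctest: +SKIP
#     >>> _ = secret.write_text("PASSWORD=my_sekret_password\n")  # doctest: +SKIP
#     >>> get_assign_from_file(secret, "PASSWORD")  # doctest: +SKIP
#     'my_sekret_password'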


def is_directory(url: str) -> bool:
    """Returns whether a url string ends with '/', which marks it as a directory url.

    Args:
        url (str): The url.

    Returns:
        bool: Whether the string ends with '/'.

    """
    return url.endswith("/")
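# For example:
#
#     >>> is_directory("https://example.com/data/")
#     True
#     >>> is_directory("https://example.com/data/file.txt")
#     False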


def find_links(
    url: str, auth: Optional[HTTPBasicAuth] = None
) -> Generator[str, None, None]:
    """Recurses through the html content of `url` with beautifulsoup and extracts links.

    Can be used to mimic `wget -r` with python.

    Args:
        url (str): The url to recurse through.
        auth (Optional[HTTPBasicAuth]): The authentication to use. Can be None
            for unprotected urls. Can be an instance of `requests.auth.HTTPBasicAuth`.

    Yields:
        str: The complete link to a file. Directory links are recursed into
            and not yielded themselves.

    """
    organization_names = ["?C=N;O=D", "?C=M;O=A", "?C=S;O=A", "?C=D;O=A"]
    content = requests.get(url, auth=auth).content
    soup = BeautifulSoup(content, "html.parser")
    maybe_directories = soup.find_all("a", href=True)
    for link in maybe_directories:
        if is_directory(link["href"]) and "Parent Directory" not in link.text:
            if not url.endswith("/"):
                new_url = url + "/" + link["href"]
            else:
                new_url = url + link["href"]
            yield from find_links(new_url, auth)
        else:
            if link["href"] not in organization_names:
                filename = url + "/" + link["href"]
                if not filename.endswith("/"):
                    yield filename
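# A minimal usage sketch (the url is illustrative): collect every file link
# below an auto-indexed directory listing.
#
#     >>> links = list(  # doctest: +SKIP
#     ...     find_links("https://encodermap.site/encodermap_data/test")
#     ... )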


def get_from_url(
    url: str,
    output: Optional[Union[str, Path]] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    force_overwrite: bool = False,
    mk_parentdir: bool = False,
    silence_overwrite_message: bool = False,
    tqdm_class: Optional[Any] = None,
    download_extra_data: bool = False,
    download_checkpoints: bool = False,
    download_h5: bool = True,
    combine_progbars: bool = False,
) -> str:
    """Recurses through `url` and downloads all files into `output`.

    Args:
        url (str): The url to visit.
        output (Optional[Union[str, Path]]): Where to put the files.
        username (Optional[str]): The username for protected sites. If the site
            is protected and this arg is None, the `input()` builtin will be used
            to get the username.
        password (Optional[str]): The password for protected sites. If the site
            is protected and this arg is None, the `getpass.getpass()` method
            will be used to get the password.
        force_overwrite (bool): Whether to overwrite existing files.
        mk_parentdir (bool): Whether to create the `output` directory in case it
            is missing.
        silence_overwrite_message (bool): Whether to silence the 'file already
            exists' message. Defaults to False.
        tqdm_class (Optional[Any]): A class implementing a tqdm feature.
        download_extra_data (bool): Whether to download (potentially) large
            extra data. Mainly useful for unittests. Defaults to False.
        download_checkpoints (bool): Whether to download checkpoints.
            Good for skipping long training.
        download_h5 (bool): Whether to also download h5 files. Defaults to True.
        combine_progbars (bool): Whether to make the download print one long
            progression bar.

    Returns:
        str: The output directory.

    """

    # Local Folder Imports
    from .misc.misc import _is_notebook

    if "encodermap.site" in url:
        dataset_name = url.replace("https://encodermap.site/encodermap_data/", "")
    else:
        # fall back to the last url component, so that `dataset_name` is
        # always defined when a default `output` needs to be constructed
        dataset_name = url.rstrip("/").split("/")[-1]

    if not url.startswith("https"):
        raise Exception(f"Not downloading from {url=}. Missing https.")

    if output is None:
        package = importlib.util.find_spec("encodermap")
        if package is None:
            output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
        else:
            emfile = package.origin
            output = Path(emfile).parent.parent / "tests"
            if not output.is_dir():
                output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
            else:
                output /= f"data/{dataset_name}"
        output.parent.mkdir(exist_ok=True)
        if not output.parent.is_dir():  # pragma: nocover
            question = input(
                f"I will create the directory {output.parent} and download "
                f"the dataset {dataset_name} to it. Continue? [y/N] "
            )
            if question.lower() not in ["y", "ye", "yes"]:
                raise Exception(f"User has answered to not overwrite {output.parent}.")
        output.mkdir(parents=True, exist_ok=True)

    if tqdm_class is None:
        if _is_notebook():
            # Third Party Imports
            from tqdm.notebook import tqdm as tqdm_class  # type: ignore[no-redef]
        else:
            # Third Party Imports
            from tqdm import tqdm as tqdm_class  # type: ignore[no-redef]
    assert tqdm_class is not None

    downloaded_files: list[str] = []
    if not (output := Path(output)).is_dir():
        if not mk_parentdir:
            raise Exception(
                f"Output directory {output} does not exist. Create it "
                f"or set `mk_parentdir` to True."
            )
        else:
            output.mkdir()

    # check whether the repository is reachable at all
    try:
        status_code = requests.get(url, timeout=3).status_code
    except requests.exceptions.Timeout:
        warnings.warn(f"EncoderMap's repository at {url} timed out.")
        return str(output)
    except requests.exceptions.ConnectionError:
        warnings.warn(f"EncoderMap's repository at {url} is unreachable.")
        return str(output)

    # for protected sites, get credentials from the act.vault file or ask the user
    if status_code == 401:
        vault_file = Path(__file__).resolve().parent.parent / "act.vault"
        if vault_file.is_file() and username is None:
            username = get_assign_from_file(vault_file, "ENCODERMAP_DATA_USER")
        elif username is None:
            username = input("Please enter the username: ")
        if vault_file.is_file() and password is None:
            password = get_assign_from_file(vault_file, "ENCODERMAP_DATA_PASSWORD")
        elif password is None:
            password = getpass.getpass("Please enter the password: ")
        auth = HTTPBasicAuth(username, password)
    elif status_code == 200:
        auth = None
    else:
        raise Exception(f"Url {url} returned error: {status_code}")

    # try whether the password works
    status_code_with_auth = requests.get(url, auth=auth).status_code
    if status_code_with_auth == 401:
        raise Exception("Wrong username/password.")

    files = list(find_links(url, auth))
    in_files = []
    out_files = []
    for f1 in files:
        in_files.append(f1)
        o = f1.replace(url, "").replace("//", "/").lstrip("/")
        f2 = output / o
        out_files.append(f2)

    if combine_progbars:
        raise NotImplementedError

    for in_file, out_file in zip(in_files, out_files):
        if out_file.is_file() and not force_overwrite:
            if not silence_overwrite_message:
                print(
                    f"{out_file} already exists. Set `force_overwrite` to True to overwrite."
                )
            continue
        if "extra_data" in str(in_file) and not download_extra_data:
            continue
        if "checkpoints" in str(in_file) and not download_checkpoints:
            continue
        if str(in_file).endswith(".h5") and not download_h5:
            print(f"Skipping {in_file}")
            continue
        out_file.parent.mkdir(parents=True, exist_ok=True)
        response = requests.get(in_file, auth=auth, stream=True)
        total_length = int(response.headers.get("content-length", 0))
        with (
            open(out_file, "wb") as file,
            tqdm_class(
                desc=str(out_file),
                total=total_length,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar,
        ):
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                bar.update(size)
        downloaded_files.append(str(out_file))
    return str(output)
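# A minimal usage sketch (dataset name and output directory are assumptions):
#
#     >>> out_dir = get_from_url(  # doctest: +SKIP
#     ...     "https://encodermap.site/encodermap_data/test",
#     ...     output="datasets/test",
#     ...     mk_parentdir=True,
#     ...     download_h5=False,
#     ... )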