# -*- coding: utf-8 -*-
# encodermap/kondata.py
################################################################################
# EncoderMap: A python library for dimensionality reduction.
#
# Copyright 2019-2024 University of Konstanz and the Authors
#
# Authors:
# Kevin Sawade
#
# Encodermap is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
# This package is distributed in the hope that it will be useful to other
# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# See <http://www.gnu.org/licenses/>.
################################################################################
22"""Functions for interfacing with the University of Konstanz's repository service KonDATA.
24"""
27################################################################################
28# Imports
29################################################################################
32# Future Imports at the top
33from __future__ import annotations
35# Standard Library Imports
36import getpass
37import importlib.util
38import shutil
39import tarfile
40import time
41import warnings
42from collections.abc import Generator
43from pathlib import Path
44from typing import Any, Optional, Union
46# Third Party Imports
47import requests
48from optional_imports import _optional_import
49from requests.auth import HTTPBasicAuth
50from rich.status import Status


################################################################################
# Optional Imports
################################################################################


BeautifulSoup = _optional_import("bs4", "BeautifulSoup")


################################################################################
# Globals
################################################################################


__all__: list[str] = ["get_from_kondata"]
DATASET_URL_MAPPING = {
    "test": "https://dx.doi.org/10.48606/108",
    "H1Ub": "https://dx.doi.org/10.48606/99",
    "linear_dimers": "https://dx.doi.org/10.48606/rtda48r76y1e71cc",
    "Ub_K11_mutants": "https://dx.doi.org/10.48606/rx7md621k9zawfz7",
}


################################################################################
# Functions
################################################################################


def untar(
    tar_file: Path,
    doi_tar: str,
    output: Path,
    force_overwrite: bool = False,
    silence_overwrite: bool = False,
) -> list[str]:
87 """Untars files in a tar archive downloaded from KonDATA.
89 As the files are found under {doi_tar}/data/dataset/, only this directory is considered.
90 The other directories in the tar file contain RADAR metadata, that is not important.
92 Args:
93 tar_file (Path): The pathlib.Path to the tar file.
94 doi_tar (str): The doi_tar of the tar file. This arg is needed to find the
95 datafiles in {doi_tar}/data/dataset. The doi_tar can be obtained from
96 the doi of the dataset by removing 'https://dx.doi.org/' and replacing
97 '/' in the doi with '-'.
98 output (Path): The directory, where to put the files of the archive.
99 force_overwrite (bool): Whether to overwrite files that are already there.
100 Will print 'file already exists' if file exists and `force_overwrite`
101 is set to False. Defaults to False.
102 silence_overwrite (bool): Will not print information about already
103 existing files. Useful for when this function is used in a script.
105 Returns:
106 list[str]: A list of the new files.
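
    Example:
        A sketch of the doi -> doi_tar conversion described above (the DOI is
        the one mapped to the 'test' dataset in `DATASET_URL_MAPPING`)::

            doi = "https://dx.doi.org/10.48606/108"
            doi_tar = doi.removeprefix("https://dx.doi.org/").replace("/", "-")
            assert doi_tar == "10.48606-108"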
108 """
109 untarred_files: list[str] = []
110 if not silence_overwrite:
111 print(f"{tar_file} already exists. Set `force_overwrite` to True to overwrite.")
112 with tarfile.open(tar_file) as tar:
113 for member in tar.getmembers():
114 if (
115 member.path.startswith(f"{doi_tar}/data/dataset")
116 and member.path != f"{doi_tar}/data/dataset"
117 ):
                # str.lstrip strips a *set of characters*, not a prefix, so
                # str.removeprefix is used to get the path inside the dataset dir
                filename = member.path.removeprefix(f"{doi_tar}/data/dataset/")
                extract_to = output / filename
                if extract_to.exists() and not force_overwrite:
                    if not silence_overwrite:
                        print(
                            f"{extract_to} already exists. Set "
                            f"`force_overwrite` to True to overwrite."
                        )
                else:
                    if member.isdir():
                        extract_to.mkdir(parents=True, exist_ok=True)
                    elif member.isfile():
                        untarred_files.append(str(extract_to))
                        # TarFile.makefile writes the member's contents to the target path
                        tar.makefile(member, extract_to)
                    else:
                        print(f"Unknown TarInfo type: {member}")
    return untarred_files


def get_from_kondata(
    dataset_name: str,
    output: Optional[Union[str, Path]] = None,
    force_overwrite: bool = False,
    mk_parentdir: bool = False,
    silence_overwrite_message: bool = False,
    tqdm_class: Optional[Any] = None,
    download_extra_data: bool = False,
    download_checkpoints: bool = False,
    download_h5: bool = True,
) -> str:
148 """Get dataset from the University of Konstanz's data repository KONData.
150 Args:
151 dataset_name (str): The name of the dataset. Refer to `DATASET_URL_MAPPING`
152 to get a list of the available datasets.
153 output (Union[str, Path]): The output directory.
154 force_overwrite (bool): Whether to overwrite existing files. Defaults to False.
155 mk_parentdir (bool): Whether to create the `output` directory if it does
156 not already exist. Defaults to False.
157 silence_overwrite_message (bool): Whether to silence the 'file already exists'
158 warning. Can be useful in scripts. Defaults to False.
159 tqdm_class (Optional[Any]): A class that is similar to tqdm.tqdm. This
160 is mainly useful if this function is used inside a `rich.status.Status`
161 context manager, as the normal tqdm does not work inside this context.
162 If None is provided, the default tqdm will be used.
163 download_extra_data (bool): Whether to download extra data. It Is only used
164 if the dataset is not available on KonDATA. Defaults to False.
165 download_checkpoints (bool): Whether to download pretrained checkpoints.
166 It is only used if the dataset is not available on KonDATA.
167 Defaults to False.
168 download_h5 (bool): Whether to also download an h5 file of the
169 ensemble. Defaults to True.
171 Returns:
172 str: The output directory.
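
    Example:
        A minimal sketch (the dataset name is taken from `DATASET_URL_MAPPING`;
        the output path here is hypothetical)::

            output_dir = get_from_kondata(
                "H1Ub",
                output="/tmp/H1Ub",
                mk_parentdir=True,
            )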
174 """
175 # Local Folder Imports
176 from .misc.misc import _is_notebook
178 if dataset_name not in DATASET_URL_MAPPING:
179 return get_from_url(
180 f"https://encodermap.site/encodermap_data/{dataset_name}",
181 output=output,
182 force_overwrite=force_overwrite,
183 mk_parentdir=mk_parentdir,
184 silence_overwrite_message=silence_overwrite_message,
185 tqdm_class=tqdm_class,
186 download_extra_data=download_extra_data,
187 download_checkpoints=download_checkpoints,
188 download_h5=download_h5,
189 )
    if output is None:
        # pkgutil.get_loader is deprecated, so locate the package via importlib
        package = importlib.util.find_spec("encodermap")
        if package is None:
            output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
        else:
            emfile = package.origin
            output = Path(emfile).parent.parent / "tests"
            if not output.is_dir():
                output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
            else:
                output /= f"data/{dataset_name}"
        output.parent.mkdir(exist_ok=True)
        if not output.parent.is_dir():  # pragma: nocover
            question = input(
                f"I will create the directory {output.parent} and download "
                f"the dataset {dataset_name} to it. Continue? [y/N] "
            )
            if question.lower() not in ["y", "ye", "yes"]:
                raise Exception(f"User has answered to not overwrite {output.parent}.")
        output.mkdir(parents=True, exist_ok=True)

    if dataset_name == "trained_networks":
        warnings.warn("Implement")
        return str(output)

    # in all other cases, make sure `output` is a Path
    output = Path(output)

    if tqdm_class is None:
        if _is_notebook():
            # Third Party Imports
            from tqdm.notebook import tqdm as tqdm_class  # type: ignore[no-redef]
        else:
            # Third Party Imports
            from tqdm import tqdm as tqdm_class  # type: ignore[no-redef]
    assert tqdm_class is not None

    if not (output := Path(output)).is_dir():
        if not mk_parentdir:
            raise Exception(
                f"Output directory {output} does not exist. Create it "
                f"or set `mk_parentdir` to True."
            )
        else:
            output.mkdir()

    # clear partial downloads
    partial_downloads = output.glob("*.crdownload")
    for partial_download in partial_downloads:
        partial_download.unlink()

    # define stuff for the tarfile and maybe unpack the files if not already done
    # (str.removeprefix is used because str.lstrip would strip characters, not a prefix)
    doi_tar = (
        DATASET_URL_MAPPING[dataset_name]
        .removeprefix("https://dx.doi.org/")
        .replace("/", "-")
    )
    tar_file = output / f"{doi_tar}.tar"
    if tar_file.is_file() and not force_overwrite:
        untarred_files = untar(
            tar_file, doi_tar, output, force_overwrite, silence_overwrite_message
        )
        # force_overwrite is known to be False in this branch
        return str(output)

    # instantiate the status
    s = Status(f"Downloading {dataset_name}", spinner="material")
    s.start()

    # use selenium to download the archive
    url = DATASET_URL_MAPPING[dataset_name]
    try:
        # Third Party Imports
        from selenium import webdriver
        from selenium.webdriver.common.by import By
    except ImportError as e:
        raise Exception(
            f"Programmatically downloading from KonDATA requires selenium. Beware "
            f"it uses Google code to interact with web pages. Either "
            f"install it with `pip install selenium`, or manually download the "
            f"files from {url} and un-tar them to {output}"
        ) from e
    # configure a headless Chrome that downloads straight into `output`
    prefs = {"download.default_directory": str(output)}
    options = webdriver.ChromeOptions()
    options.add_experimental_option("prefs", prefs)
    options.add_argument("--headless=new")
    driver = webdriver.Chrome(options=options)
    driver.get(url)
    # click the download link and accept KonDATA's terms
    download_element = driver.find_element(
        By.XPATH, '//a[starts-with(@class, "dataset-download")]'
    )
    download_element.click()
    accept_element = driver.find_element(By.XPATH, '//button[text()="Accept"]')
    accept_element.click()

    # play an animation while downloading
    i = 0
    while True:
        time.sleep(0.2)
        i += 1
        if tar_file.is_file():
            s.stop()
            break

    # Chrome appends ' (1)' if a file with the same name was downloaded before
    copied_tar = output / f"{doi_tar} (1).tar"
    if copied_tar.is_file():
        shutil.move(copied_tar, tar_file)
    untarred_files = untar(
        tar_file, doi_tar, output, force_overwrite, silence_overwrite=True
    )

    if dataset_name == "H1Ub" and download_h5:
        trajs_file = output / "trajs.h5"
        if trajs_file.is_file() and not force_overwrite:
            print("trajs.h5 file for H1Ub already present.")
        else:
            print("The H1Ub dataset does not contain a trajs.h5 file.")

    return str(output)


def get_assign_from_file(file: Path, assign: str) -> str:
    """Reads a file and extracts lines with assignments.

    Can be used for reading simple secret files which look like::

        PASSWORD=my_sekret_password
        USERNAME=oll_korrect_username

    Args:
        file (Path): The file.
        assign (str): The string to look for (e.g. PASSWORD).

    Returns:
        str: The assignment after the equal (=) sign.
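
    Example:
        A sketch with a hypothetical secrets file::

            # given a file 'act.vault' that contains the line
            # ENCODERMAP_DATA_USER=alice
            username = get_assign_from_file(Path("act.vault"), "ENCODERMAP_DATA_USER")
            assert username == "alice"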
326 """
327 content = file.read_text().splitlines()
328 content_line = content[[c.startswith(assign) for c in content].index(True)]
329 return content_line.split("=")[-1]


def is_directory(url: str) -> bool:
    """Returns whether a string ends with '/', which makes it the url of a dir.

    Args:
        url (str): The url.

    Returns:
        bool: Whether that string ends with '/'.

    """
    return url.endswith("/")


def find_links(url: str, auth: Optional[HTTPBasicAuth] = None) -> Generator:
    """Recurses through an html page with beautifulsoup and extracts links.

    Can be used to mimic `wget -r` with python.

    Args:
        url (str): The url to recurse through.
        auth (Optional[HTTPBasicAuth]): The authentication to use. Can be None
            for unprotected urls. Can be an instance of
            `requests.auth.HTTPBasicAuth`.

    Yields:
        str: The complete link to a file.
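
    Example:
        A sketch against a hypothetical, unprotected directory listing::

            for link in find_links("https://example.com/data/"):
                print(link)  # e.g. https://example.com/data/sub/file.txt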
364 """
365 organization_names = ["?C=N;O=D", "?C=M;O=A", "?C=S;O=A", "?C=D;O=A"]
366 content = requests.get(url, auth=auth).content
367 soup = BeautifulSoup(content, "html.parser")
368 maybe_directories = soup.find_all("a", href=True)
369 for link in maybe_directories:
370 if is_directory(link["href"]) and "Parent Directory" not in link.text:
371 if not url.endswith("/"):
372 new_url = url + "/" + link["href"]
373 else:
374 new_url = url + link["href"]
375 yield from find_links(new_url, auth)
376 else:
377 if link["href"] not in organization_names:
378 filename = url + "/" + link["href"]
379 if not filename.endswith("/"):
380 yield filename


def get_from_url(
    url: str,
    output: Optional[Union[str, Path]] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    force_overwrite: bool = False,
    mk_parentdir: bool = False,
    silence_overwrite_message: bool = False,
    tqdm_class: Optional[Any] = None,
    download_extra_data: bool = False,
    download_checkpoints: bool = False,
    download_h5: bool = True,
    combine_progbars: bool = False,
) -> str:
397 """Recourses through `url` and downloads all strings into `output`.
399 Args:
400 url (str): The url to visit.
401 output (Optional[Union[str, Path]]): Where to put the files.
402 username (Optional[str]): The username for protected sites. If the site
403 is protected and this arg is None, the `input()` builtin will be used
404 to get the username.
405 password (Optional[str]): The password for protected sites. If the site
406 is protected and this arg is None, the `getpass()` method will be
407 used to get the password.
408 force_overwrite (bool): Whether to overwrite existing files.
409 mk_parentdir (bool): Whether to create the `output` directory in case it
410 is missing.
411 tqdm_class (Optional[Any]): A class implementing a tqdm feature.
412 download_extra_data (bool): Whether to download (potentially) large
413 extra data. Mainly useful for unittests. Defaults to False.
414 download_checkpoints (bool): Whether to download checkpoints.
415 Good for skipping long training.
416 combine_progbars (bool): Whether to make the download print one long
417 progression bar.
419 Returns:
420 str: The output directory.
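
    Example:
        A minimal sketch (the url follows the pattern used in `get_from_kondata`
        above; the output path here is hypothetical)::

            output_dir = get_from_url(
                "https://encodermap.site/encodermap_data/linear_dimers",
                output="/tmp/linear_dimers",
                mk_parentdir=True,
            )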
422 """
423 # Local Folder Imports
424 from .misc.misc import _is_notebook
426 if "encodermap.site" in url:
427 dataset_name = url.replace("https://encodermap.site/encodermap_data/", "")
429 if not url.startswith("https"):
430 raise Exception(f"Not downloading from {url=}. Missing https.")

    if output is None:
        package = importlib.util.find_spec("encodermap")
        if package is None:
            output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
        else:
            emfile = package.origin
            output = Path(emfile).parent.parent / "tests"
            if not output.is_dir():
                output = Path("~").expanduser() / f".encodermap_data/{dataset_name}"
            else:
                output /= f"data/{dataset_name}"
        output.parent.mkdir(exist_ok=True)
        if not output.parent.is_dir():  # pragma: nocover
            question = input(
                f"I will create the directory {output.parent} and download "
                f"the dataset {dataset_name} to it. Continue? [y/N] "
            )
            if question.lower() not in ["y", "ye", "yes"]:
                raise Exception(f"User has answered to not overwrite {output.parent}.")
        output.mkdir(parents=True, exist_ok=True)

    if tqdm_class is None:
        if _is_notebook():
            # Third Party Imports
            from tqdm.notebook import tqdm as tqdm_class  # type: ignore[no-redef]
        else:
            # Third Party Imports
            from tqdm import tqdm as tqdm_class  # type: ignore[no-redef]
    assert tqdm_class is not None

    downloaded_files: list[str] = []
    if not (output := Path(output)).is_dir():
        if not mk_parentdir:
            raise Exception(
                f"Output directory {output} does not exist. Create it "
                f"or set `mk_parentdir` to True."
            )
        else:
            output.mkdir()

    # probe the url; a 401 means the site is password-protected, in which case
    # credentials are looked up in the act.vault file
    try:
        status_code = requests.get(url, timeout=3).status_code
    except requests.exceptions.Timeout:
        warnings.warn(f"EncoderMap's repository at {url} timed out.")
        return str(output)
    except requests.exceptions.ConnectionError:
        warnings.warn(f"EncoderMap's repository at {url} is unreachable.")
        return str(output)
    if status_code == 401:
        vault_file = Path(__file__).resolve().parent.parent / "act.vault"
        if vault_file.is_file() and username is None:
            username = get_assign_from_file(vault_file, "ENCODERMAP_DATA_USER")
        elif username is None:
            username = input("Please enter the username: ")
        if vault_file.is_file() and password is None:
            password = get_assign_from_file(vault_file, "ENCODERMAP_DATA_PASSWORD")
        elif password is None:
            password = getpass.getpass("Please enter the password: ")
        auth = HTTPBasicAuth(username, password)
    elif status_code == 200:
        auth = None
    else:
        raise Exception(f"Url {url} returned error: {status_code}")

    # try whether the password works
    status_code_with_auth = requests.get(url, auth=auth).status_code
    if status_code_with_auth == 401:
        raise Exception("Wrong username/password.")

    # map each remote file to its local destination under `output`
    files = list(find_links(url, auth))
    in_files = []
    out_files = []
    for f1 in files:
        in_files.append(f1)
        o = f1.replace(url, "").replace("//", "/").lstrip("/")
        f2 = output / o
        out_files.append(f2)

    if combine_progbars:
        raise NotImplementedError

    for in_file, out_file in zip(in_files, out_files):
        if out_file.is_file() and not force_overwrite:
            if not silence_overwrite_message:
                print(
                    f"{out_file} already exists. Set `force_overwrite` to True to overwrite."
                )
            continue
        if "extra_data" in str(in_file) and not download_extra_data:
            continue
        if "checkpoints" in str(in_file) and not download_checkpoints:
            continue
        if str(in_file).split(".")[-1] == "h5" and not download_h5:
            print(f"Skipping {in_file}")
            continue
        out_file.parent.mkdir(parents=True, exist_ok=True)
        # stream the download in 1 KiB chunks and track progress with tqdm
        response = requests.get(in_file, auth=auth, stream=True)
        total_length = int(response.headers.get("content-length", 0))
        with (
            open(out_file, "wb") as file,
            tqdm_class(
                desc=str(out_file),
                total=total_length,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar,
        ):
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                bar.update(size)
        downloaded_files.append(str(out_file))
    return str(output)