Coverage for encodermap/misc/function_def.py: 31%
14 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/misc/function_def.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
22"""Wraps tensorflow's `tf.function` again to accept a debug=True or debug=False argument.
24With debug=True, the function will not be compiled. With debug=False (which is
25the default), it will be compiled.
27"""
28# Future Imports at the top
29from __future__ import annotations
31# Standard Library Imports
32from typing import Any
34# Third Party Imports
35import tensorflow as tf
def function(debug: bool = False) -> Any:
    """Encodermap's implementation of `tf.function`.

    Returns a decorator that either leaves the decorated function untouched
    (``debug=True``) or compiles it with `tf.function` (``debug=False``).

    Args:
        debug (bool): If True, the decorated function will not be compiled
            and runs eagerly as plain Python. Defaults to False.

    Returns:
        Any: A decorator. Applied to a callable ``f`` it returns ``f`` itself
        when ``debug`` is True, otherwise ``tf.function(f)``.

    """

    def decorator(f: Any) -> Any:
        """Wrap ``f`` according to the ``debug`` flag captured above.

        Args:
            f (Any): The callable to (optionally) compile.

        Returns:
            Any: ``f`` unchanged in debug mode, else ``tf.function(f)``.

        """
        # ``debug`` is fixed once ``function(...)`` has been called, so the
        # choice can be made here at decoration time instead of on every call.
        if debug:
            # Returning ``f`` itself keeps its name, docstring and signature.
            return f
        # Build the tf.function exactly ONCE. Creating a new tf.function
        # per call (the previous behavior) discarded all cached traces and
        # forced a retrace on every invocation. Returning the tf.function
        # object directly also lets TensorFlow handle method binding itself
        # (tf.function objects implement ``__get__``).
        return tf.function(f)

    return decorator