Coverage for encodermap/misc/function_def.py: 31%

14 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/function_def.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""Wraps tensorflow's `tf.function` again to accept a debug=True or debug=False argument. 

23 

24With debug=True, the function will not be compiled. With debug=False (which is 

25the default), it will be compiled. 

26 

27""" 

28# Future Imports at the top 

29from __future__ import annotations 

30 

31# Standard Library Imports 

32from typing import Any 

33 

34# Third Party Imports 

35import tensorflow as tf 

36 

37 

def function(debug: bool = False) -> Any:
    """Encodermap's implementation of `tf.function`.

    Returns a decorator. Unless `debug` is True, the decorated callable is
    wrapped with `tf.function` exactly once, at decoration time, and that
    single compiled object is reused for every call. With `debug=True` the
    callable is invoked as plain Python and `tf.function` is never touched.

    Args:
        debug (bool): If True, the decorated function will not be compiled.
            Defaults to False.

    Returns:
        Any: A decorator that takes a callable and returns a wrapper with
            the same call signature, name and docstring.

    """
    # Local import keeps this block self-contained; `functools` is stdlib.
    import functools

    def decorator(f: Any) -> Any:
        """The decorator, that takes the function."""
        # Build the compiled callable once per decorated function. Creating
        # `tf.function(f)` inside the wrapper would yield a fresh object with
        # an empty trace cache on every call, forcing a retrace each time.
        # NOTE(review): because the compiled object is now reused, functions
        # that create `tf.Variable`s on every call fall under tf.function's
        # create-variables-only-on-first-call rule - confirm no caller
        # relies on the old re-creation behavior.
        target = f if debug else tf.function(f)

        @functools.wraps(f)  # keep the decorated function's name/docstring
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """The wrapper, that calls the function based on the debug argument."""
            return target(*args, **kwargs)

        return wrapper

    return decorator