Coverage for encodermap/misc/summaries.py: 12%

89 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/misc/summaries.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22 

23""" 

24Functions that write summaries to Tensorboard. Mainly used by the image callbacks. 

25""" 

26 

27import io 

28 

29import matplotlib 

30import matplotlib.pyplot as plt 

31import numpy as np 

32import tensorflow as tf 

33 

# Render with the non-interactive Agg backend; the plotting helpers in this
# module only ever write figures into in-memory PNG buffers.
matplotlib.use(backend="Agg")

__all__ = [
    "add_layer_summaries",
    "image_summary",
]

37 

38 

def add_layer_summaries(layer, step=None):
    """Adds summaries for a layer's variables to Tensorboard.

    The layer's first variable is logged as ``<layer.name>/weights``. If the
    layer has a second variable, it is logged as ``<layer.name>/biases``.
    Layers with only a single variable (presumably bias-free layers) are
    supported and only get their weights logged.

    The name scope is chosen from the layer name (case-insensitive):
    "Encoder", "Decoder" or "Latent" if that word appears in it, otherwise
    "InputOutputLayers".

    Args:
        layer (tf.keras.layers.Layer): The layer. Must have at least one
            variable.
        step (Union[tf.Tensor, int, None], optional): The current step.
            Can be either a Tensor or None. Defaults to None.

    Raises:
        IndexError: If the layer has no variables at all.

    """
    # Lowercase once instead of once per branch.
    lowered_name = layer.name.lower()
    if "encoder" in lowered_name:
        namescope = "Encoder"
    elif "decoder" in lowered_name:
        namescope = "Decoder"
    elif "latent" in lowered_name:
        namescope = "Latent"
    else:
        namescope = "InputOutputLayers"

    variables = layer.variables
    weights = variables[0]
    variable_summaries(namescope, layer.name + "/weights", weights, step)
    # Previously variables[1] was indexed unconditionally, which raised an
    # IndexError for layers without a bias. Only log biases when present.
    if len(variables) > 1:
        variable_summaries(namescope, layer.name + "/biases", variables[1], step)

60 

61 

def variable_summaries(namescope, name, variables, step=None):
    """Attach several summaries to a Tensor for TensorBoard visualization.

    For each tensor, the following are written under the name scope
    ``<namescope>/<name>`` (with an index suffix when several tensors are
    given):
    - scalars: mean, stddev (square root of the mean squared deviation), max, min
    - a histogram

    Args:
        namescope (str): The string to prepend to the layer names.
            Makes it easier to group the layers.
        name (str): The name of the layer.
        variables (Union[tf.Tensor, list]): The variables (weights, biases) of
            the layer. A non-list value is treated as a single tensor.
        step (Union[tf.Tensor, int, None], optional): The current step.
            Can be either a Tensor or None. Defaults to None.

    """
    if not isinstance(variables, list):
        variables = [variables]

    # Only disambiguate with an index when there is more than one tensor.
    # (Loop-invariant, so computed once.)
    add_index = len(variables) > 1

    for i, var in enumerate(variables):
        # Build the per-tensor name from the ORIGINAL `name`. Rebinding `name`
        # here used to accumulate suffixes (name0, name01, name012, ...).
        var_name = f"{name}{i}" if add_index else name
        with tf.name_scope(namescope + "/" + var_name):
            mean = tf.reduce_mean(var)
            tf.summary.scalar("mean", mean, step=step)
            with tf.name_scope("stddev"):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar("stddev", stddev, step=step)
            tf.summary.scalar("max", tf.reduce_max(var), step=step)
            tf.summary.scalar("min", tf.reduce_min(var), step=step)
            tf.summary.histogram("histogram", var, step=step)

94 

95 

def image_summary(
    lowd, step=None, scatter_kws=None, hist_kws=None, additional_fns=None
):
    """Writes images of the latent space to Tensorboard.

    Writes a scatter plot (under "Latent Scatter") and a 2D histogram (under
    "Latent Density") of the first two columns of `lowd`. It then writes one
    image per function in `additional_fns`. If `lowd` contains any NaN, only
    a placeholder image with a debug message is written under
    "Latent Scatter", and nothing else is plotted.

    Args:
        lowd (Union[np.ndarray, tf.Tensor]): The data to plot. Usually that
            will be the output of the latent space of the Autoencoder. This
            array has to be of dimensionality 2 (rows and columns). The first
            two columns are used as xy coordinates.
        step (Union[int, None], optional): The training step under which you
            can find the image in tensorboard. Defaults to None.
        scatter_kws (Union[dict, None], optional): Keyword arguments passed to
            matplotlib.pyplot.scatter(). None means {'s': 20}.
            Defaults to None.
        hist_kws (Union[dict, None], optional): Keyword arguments passed to
            matplotlib.pyplot.hist2d(). None means {'bins': 50}.
            Defaults to None.
        additional_fns (Union[None, list], optional): A list of functions that
            take the full (not column-sliced) latent data and return a
            tf.Tensor that can be logged with tf.summary.image().
            Defaults to None.

    """
    # Resolve defaults per call rather than sharing mutable dict defaults
    # across calls.
    if scatter_kws is None:
        scatter_kws = {"s": 20}
    if hist_kws is None:
        hist_kws = {"bins": 50}

    if np.any(np.isnan(lowd)):
        image = _gen_nan_image()
        with tf.name_scope("Latent Scatter"):
            tf.summary.image(f"Latent at step {step}", image, step=step)
        return

    xy = lowd[:, :2]
    scatter_image = _gen_scatter(xy, scatter_kws)
    hist_image = _gen_hist(xy, hist_kws)
    with tf.name_scope("Latent Scatter"):
        tf.summary.image(f"Latent at step {step}", scatter_image, step=step)
    with tf.name_scope("Latent Density"):
        tf.summary.image(f"Latent at step {step}", hist_image, step=step)

    if additional_fns is not None:
        with tf.name_scope("User Provided Plotting Functions"):
            for i, fn in enumerate(additional_fns):
                tf.summary.image(
                    f"User Plotting {i} at step {step}", fn(lowd), step=step
                )

136 

137 

def _gen_hist(data, hist_kws):
    """Creates a matplotlib 2D histogram and returns it as an image Tensor.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used.
            data.ndim should be 2: the 1st dimension holds the datapoints,
            the 2nd dimension holds x and y.
        hist_kws (dict): Additional keywords to be passed to
            matplotlib.pyplot.hist2d().

    Returns:
        tf.Tensor: An RGBA image tensor with a leading batch axis of size 1,
            ready for tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    # The docstring accepts tf.Tensor, but an eager Tensor has no `.T`.
    # Convert anything that is not already an ndarray.
    if not isinstance(data, np.ndarray):
        data = np.asarray(data)
    fig = plt.figure()
    plt.hist2d(*data.T, **hist_kws)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Release the figure right away instead of leaving it open until the
    # next call.
    plt.close(fig)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    return tf.expand_dims(image, 0)

160 

161 

def _gen_nan_image():
    """Creates a placeholder image saying that the latent data contains NaN.

    Returns:
        tf.Tensor: An RGBA image tensor with a leading batch axis of size 1,
            ready for tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    fig, ax = plt.subplots()
    ax.text(
        0.5,
        0.5,
        "Some data of lowd is nan",
        ha="center",
        va="center",
        transform=ax.transAxes,
    )
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Release the figure right away instead of leaving it open until the
    # next call.
    plt.close(fig)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    return tf.expand_dims(image, 0)

186 

187 

def _gen_scatter(data, scatter_kws):
    """Creates a matplotlib scatter plot and returns it as an image Tensor.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used.
            data.ndim should be 2: the 1st dimension holds the datapoints,
            the 2nd dimension holds x and y. Non-ndarray input is converted
            via its `.numpy()` method.
        scatter_kws (dict): Additional keywords to be passed to
            matplotlib.pyplot.scatter().

    Returns:
        tf.Tensor: An RGBA image tensor with a leading batch axis of size 1,
            ready for tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    fig = plt.figure()
    if isinstance(data, np.ndarray):
        plt.scatter(*data.T, **scatter_kws)
    else:
        plt.scatter(*data.numpy().T, **scatter_kws)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Release the figure right away instead of leaving it open until the
    # next call.
    plt.close(fig)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    return tf.expand_dims(image, 0)