Coverage for encodermap/misc/summaries.py: 12%
89 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-07 11:05 +0000
1# -*- coding: utf-8 -*-
2# encodermap/misc/summaries.py
3################################################################################
4# Encodermap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2022 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
23"""
24Functions that write stuff to tensorboard. Mainly used for the image callbacks.
25"""
27import io
29import matplotlib
30import matplotlib.pyplot as plt
31import numpy as np
32import tensorflow as tf
34matplotlib.use("Agg")
36__all__ = ["add_layer_summaries", "image_summary"]
def add_layer_summaries(layer, step=None):
    """Adds summaries for a layer to Tensorboard.

    The first entry of ``layer.variables`` is logged as the layer's weights
    and the second entry as its biases. Entries that do not exist (e.g. a
    layer without a bias term, or a layer without any variables) are skipped
    instead of raising an ``IndexError``.

    The summaries are grouped under a name scope that is chosen from the
    layer's name (case-insensitive): "Encoder", "Decoder", "Latent" or,
    if none of these words occur, "InputOutputLayers".

    Args:
        layer (tf.keras.layers.Layer): The layer.
        step (Union[tf.Tensor, int, None], optional): The current step.
            Can be either a Tensor or None. Defaults to None.

    """
    lowered_name = layer.name.lower()
    if "encoder" in lowered_name:
        namescope = "Encoder"
    elif "decoder" in lowered_name:
        namescope = "Decoder"
    elif "latent" in lowered_name:
        namescope = "Latent"
    else:
        namescope = "InputOutputLayers"

    # zip() stops at the shorter sequence, so only variables that are
    # actually present get logged (weights first, then biases).
    for suffix, variable in zip(("/weights", "/biases"), layer.variables):
        variable_summaries(namescope, layer.name + suffix, variable, step)
def variable_summaries(namescope, name, variables, step=None):
    """Attach several summaries to a Tensor for TensorBoard visualization.

    For every variable the mean, standard deviation, maximum and minimum are
    written as scalar summaries and the values are written as a histogram
    summary.

    Args:
        namescope (str): The string to prepend to the layer names.
            Makes it easier to group the layers.
        name (str): The name of the layer.
        variables (tf.Tensor): The variables (weights, biases) of the layer.
            Can also be a list of such variables. In that case, the position
            of each variable in the list is appended to ``name``.
        step (Union[tf.Tensor, int, None], optional): The current step.
            Can be either a Tensor or None. Defaults to None.

    """
    if not isinstance(variables, list):
        variables = [variables]

    # An index is only needed to tell several variables apart.
    add_index = len(variables) > 1

    for i, var in enumerate(variables):
        # Derive the scope from the unmodified base name, so the indices do
        # not pile up from one iteration to the next ("w0", "w01", ...).
        scope_name = name + str(i) if add_index else name
        with tf.name_scope(namescope + "/" + scope_name):
            mean = tf.reduce_mean(var)
            tf.summary.scalar("mean", mean, step=step)
            with tf.name_scope("stddev"):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar("stddev", stddev, step=step)
            tf.summary.scalar("max", tf.reduce_max(var), step=step)
            tf.summary.scalar("min", tf.reduce_min(var), step=step)
            tf.summary.histogram("histogram", var, step=step)
def image_summary(
    lowd, step=None, scatter_kws=None, hist_kws=None, additional_fns=None
):
    """Writes an image to Tensorboard.

    If ``lowd`` contains any NaN, a single placeholder image is written to
    the "Latent Scatter" scope and nothing else is logged. Otherwise a
    scatter plot ("Latent Scatter"), a 2D histogram ("Latent Density") and
    the images of all ``additional_fns`` ("User Provided Plotting Functions")
    are written.

    Args:
        lowd (np.ndarray): The data to plot. Usually that will be the output of the latent space of the
            Autoencoder. This array has to be of dimensionality 2 (rows and columns). The first two points of the
            rows will be used as xy coordinates in a scatter plot.
        step (Union[int, None], optional): The training step under which you can find the image
            in tensorboard. Defaults to None.
        scatter_kws (Union[dict, None], optional): A dictionary with keyword arguments to be passed to
            matplotlib.pyplot.scatter(). If None, {'s': 20} is used. Defaults to None.
        hist_kws (Union[dict, None], optional): A dictionary with keyword arguments to be passed to
            matplotlib.pyplot.hist2d(). If None, {'bins': 50} is used. Defaults to None.
        additional_fns (Union[None, list], optional): A list of functions that take the data of the latent space
            and return a tf.Tensor that can be logged to tensorboard with tf.summary.image().

    Note:
        The shape of ``lowd`` is not validated by this function. Data that
        can not be sliced with ``lowd[:, :2]`` fails with the error of the
        underlying array type.

    """
    # Fresh dicts per call instead of mutable default arguments.
    if scatter_kws is None:
        scatter_kws = {"s": 20}
    if hist_kws is None:
        hist_kws = {"bins": 50}

    if np.any(np.isnan(lowd)):
        # Plotting NaNs is pointless; log a placeholder image instead.
        image = _gen_nan_image()
        with tf.name_scope("Latent Scatter"):
            tf.summary.image(f"Latent at step {step}", image, step=step)
        return

    xy = lowd[:, :2]
    scatter_image = _gen_scatter(xy, scatter_kws)
    hist_image = _gen_hist(xy, hist_kws)
    with tf.name_scope("Latent Scatter"):
        tf.summary.image(f"Latent at step {step}", scatter_image, step=step)
    with tf.name_scope("Latent Density"):
        tf.summary.image(f"Latent at step {step}", hist_image, step=step)

    if additional_fns is not None:
        with tf.name_scope("User Provided Plotting Functions"):
            for i, fn in enumerate(additional_fns):
                tf.summary.image(
                    f"User Plotting {i} at step {step}", fn(lowd), step=step
                )
def _gen_hist(data, hist_kws):
    """Creates matplotlib histogram and returns tensorflow Tensor that represents an image.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y.
        hist_kws (dict): Additional keywords to be passed to matplotlib.pyplot.hist2d().

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    if not isinstance(data, np.ndarray):
        # Same conversion as in _gen_scatter. Without it, `data.T` below is
        # looked up on the tensor itself (presumably not available on a
        # plain tf.Tensor - TODO confirm).
        data = data.numpy()
    fig = plt.figure()
    try:
        plt.hist2d(*data.T, **hist_kws)
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            # Decode with 4 channels and add a leading batch axis.
            image = tf.expand_dims(tf.image.decode_png(buf.getvalue(), 4), 0)
    finally:
        # Release the figure right away instead of keeping it registered
        # with pyplot until the next call.
        plt.close(fig)
    return image
def _gen_nan_image():
    """Creates matplotlib image, with debug info.

    The image only contains the centered text "Some data of lowd is nan".

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    fig, ax = plt.subplots()
    try:
        ax.text(
            0.5,
            0.5,
            "Some data of lowd is nan",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            # Decode with 4 channels and add a leading batch axis.
            image = tf.expand_dims(tf.image.decode_png(buf.getvalue(), 4), 0)
    finally:
        # Release the figure right away instead of keeping it registered
        # with pyplot until the next call.
        plt.close(fig)
    return image
def _gen_scatter(data, scatter_kws):
    """Creates matplotlib scatter plot and returns tensorflow Tensor that represents an image.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y.
        scatter_kws (dict): Additional keywords to be passed to matplotlib.pyplot.scatter().

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    # Anything that is not a numpy array is converted with its .numpy() method.
    points = data if isinstance(data, np.ndarray) else data.numpy()
    fig = plt.figure()
    try:
        plt.scatter(*points.T, **scatter_kws)
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            # Decode with 4 channels and add a leading batch axis.
            image = tf.expand_dims(tf.image.decode_png(buf.getvalue(), 4), 0)
    finally:
        # Release the figure right away instead of keeping it registered
        # with pyplot until the next call.
        plt.close(fig)
    return image