Coverage for encodermap/misc/summaries.py: 12%
140 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
« prev ^ index » next coverage.py v7.4.1, created at 2024-12-31 16:54 +0100
1# -*- coding: utf-8 -*-
2# encodermap/misc/summaries.py
3################################################################################
4# EncoderMap: A python library for dimensionality reduction.
5#
6# Copyright 2019-2024 University of Konstanz and the Authors
7#
8# Authors:
9# Kevin Sawade, Tobias Lemke
10#
11# Encodermap is free software: you can redistribute it and/or modify
12# it under the terms of the GNU Lesser General Public License as
13# published by the Free Software Foundation, either version 2.1
14# of the License, or (at your option) any later version.
15# This package is distributed in the hope that it will be useful to other
16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the
17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
18# See the GNU Lesser General Public License for more details.
19#
20# See <http://www.gnu.org/licenses/>.
21################################################################################
23"""
24Functions that write stuff to tensorboard. Mainly used for the image callbacks.
25"""
27################################################################################
28# Imports
29################################################################################
32# Future Imports at the top
33from __future__ import annotations
35# Standard Library Imports
36import io
37from collections.abc import Callable, Sequence
38from typing import Any, Literal, Optional, Union
40# Third Party Imports
41import matplotlib
42import matplotlib.pyplot as plt
43import numpy as np
44import tensorflow as tf
45from optional_imports import _optional_import
47# Encodermap imports
48from encodermap.plot.plotting import _plot_free_energy
51################################################################################
52# Optional Imports
53################################################################################
56px = _optional_import("plotly", "express")
57go = _optional_import("plotly", "graph_objects")
60################################################################################
61# Globals
62################################################################################
65__all__: list[str] = ["add_layer_summaries", "image_summary"]
68################################################################################
69# Summary Functions
70################################################################################
def add_layer_summaries(
    layer: tf.keras.layers.Layer,
    step: Optional[int] = None,
) -> None:
    """Logs weight and bias summaries of a single layer to Tensorboard.

    The summaries are grouped under a namescope that is derived from the
    layer's name via a case-insensitive substring match (first match wins):
    'encoder' -> 'Encoder', 'decoder' -> 'Decoder', 'latent' -> 'Latent',
    everything else -> 'InputOutputLayers'.

    Args:
        layer (tf.keras.layers.Layer): The layer. The first two entries of
            `layer.variables` are used as weights and biases, respectively.
        step (Optional[int]): The current step. It is forwarded unchanged to
            `variable_summaries`. Defaults to None.

    """
    # Index both variables up front, so a layer with fewer than two
    # variables fails before any summary is written.
    kernel, bias = layer.variables[0], layer.variables[1]

    lowered_name = layer.name.lower()
    namescope = "InputOutputLayers"
    for keyword, scope in (
        ("encoder", "Encoder"),
        ("decoder", "Decoder"),
        ("latent", "Latent"),
    ):
        if keyword in lowered_name:
            namescope = scope
            break

    for suffix, values in (("/weights", kernel), ("/biases", bias)):
        variable_summaries(namescope, layer.name + suffix, values, step)
def variable_summaries(
    namescope: str,
    name: str,
    variables: tf.Tensor,
    step: Optional[int] = None,
) -> None:
    """
    Attach several summaries to a Tensor for TensorBoard visualization.

    For every variable the mean, standard deviation, maximum, and minimum
    are written as scalars, and the values are written as a histogram. All
    summaries of one variable live in the name scope `namescope/name`.

    Args:
        namescope (str): The string to prepend to the layer names.
            Makes it easier to group the layers.
        name (str): The name of the layer.
        variables (tf.Tensor): The variables (weights, biases) of the layer.
            Anything that is not a `list` is treated as a single variable.
            If a list with more than one entry is provided, the position of
            each entry is appended to `name` (`name0`, `name1`, ...).
        step (Optional[int]): The current step. Forwarded to the
            `tf.summary` calls. Defaults to None.

    """
    if not isinstance(variables, list):
        variables = [variables]

    # `variables` is guaranteed to be a list here, so this is loop-invariant.
    add_index = len(variables) > 1

    for i, var in enumerate(variables):
        # Derive the scope from the unmodified `name`. Re-assigning `name`
        # inside the loop would accumulate indices (name0, name01, name012).
        scoped_name = f"{name}{i}" if add_index else name
        with tf.name_scope(namescope + "/" + scoped_name):
            mean = tf.reduce_mean(var)
            tf.summary.scalar("mean", mean, step=step)
            with tf.name_scope("stddev"):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar("stddev", stddev, step=step)
            tf.summary.scalar("max", tf.reduce_max(var), step=step)
            tf.summary.scalar("min", tf.reduce_min(var), step=step)
            tf.summary.histogram("histogram", var, step=step)
            # NOTE(review): TF1 compat API that, unlike the calls above,
            # receives no `step` -- confirm it still has an effect with the
            # TF2 summary writers used by the callers.
            tf.compat.v1.summary.tensor_summary("values", var)
139# def intermediate_summary(
140# highd_data: Union[np.ndarray, Sequence[np.ndarray]],
141# step: int,
142# model: tf.keras.Model,
143# parameters: "AnyParameters",
144# backend: Literal["matplotlib", "plotly"] = "matplotlib",
145# ) -> dict[str, io.BytesIO]:
146# """Produces images of intermediate layers during training.
147#
148# Args:
149# highd_data (Union[np.ndarray, Sequence[np.ndarray]]): The high dimensional
150# data to be used with the autoencoder model. If the autoencoder is
151# of type `Autoencoder`, `EncoderMap`, or `DihedralEncoderMap`, the
152# `highd_data` should be a single np.ndarray. For the `AngleDihedralCartesianEncoderMap`,
153# the `highd_data` should be a Sequence of np.ndarray in the sequence
154# angles, dihedrals, cartesians, distance, (and maybe side_dihedrals,
155# depending on whether side_dihedrals are used for the training).
156# step (int): The current training step.
157# model (tf.keras.Model): The autoencoder model.
158# parameters (AnyParameters): Either an instance of `encodermap.parameters.Parameters`, or
159# `encodermap.parameters.ADCParameters`, depending on which autoencoder model is used.
160# The parameters are used to fix periodic input data.
161# backend (Literal["matplotlib", "plotly"]: Which backend to use for
162# plotting. Defaults to 'matplotlib'.
163#
164# """
165# if model.encoder_model.layers[0].name == "Encoder_0":
166# return _encodermap_intermediate_summary(
167# highd_data,
168# step,
169# model,
170# parameters,
171# backend,
172# )
173# else:
174# return _angledihedralcartesianencodermap_intermediate_summary(
175# highd_data,
176# step,
177# model,
178# parameters,
179# backend,
180# )
181#
182#
183# def _layer_output_and_input_summary(
184# data: np.ndarray,
185# name: str,
186# step: int,
187## type: Literal["input", "layer"] = "layer",
188# backend: Literal["matplotlib", "plotly"] = "matplotlib",
189# ) -> io.BytesIO:
190# buf = io.BytesIO()
191# raise Exception(f"Here")
192# if isinstance(data, tf.SparseTensor):
193# data = tf.sparse.to_dense(data, default_value=np.nan)
194# if isinstance(data, tf.Tensor):
195# data = data.numpy()
196# mean = np.mean(data, 0)
197# if backend == "plotly":
198# fig = px.bar(
199# x=mean,
200# y=np.arange(data.shape[1]),
201# orientation="h",
202# color=mean,
203# color_continuous_scale="Viridis",
204# labels={
205# "x": f"mean value of {data.shape[0]} samples",
206# "y": "feature" if type == "input" else "neuron",
207# },
208# width=500,
209# height=500,
210# )
211# fig.update_layout(
212# margin={"l": 0, "r": 0, "t": 0, "b": 0},
213# coloraxis_showscale=False,
214# )
215# fig.write_image(buf)
216# elif backend == "matplotlib":
217# raise NotImplementedError(f"{mean.shape=}")
218# else:
219# raise Exception(
220# f"Argument `backend` must be either 'plotly' or 'matplotlib'."
221# )
222# image = tf.image.decode_png(buf.getvalue(), 4)
223# image = tf.expand_dims(image, 0)
224# with tf.name_scope("Layer Outputs"):
225# tf.summary.image(name, image, step=step)
226# buf.seek(0)
227# return buf
228#
229#
230# def _encodermap_intermediate_summary(
231# highd_data: Union[np.ndarray, Sequence[np.ndarray]],
232# step: int,
233# model: tf.keras.Model,
234# parameters: "AnyParameters",
235# backend: Literal["matplotlib", "plotly"] = "matplotlib",
236# ) -> dict[str, io.BytesIO]:
237# # Local Folder Imports
238#
239# out = {}
240# layers = []
241# for layer in model.encoder_model.layers:
242# layers.append(layer)
243# for layer in model.decoder_model.layers:
244# layers.append(layer)
245#
246# if model.sparse:
247# input = model.get_dense_model(highd_data)
248# else:
249# input = highd_data
250#
251# out["input"] = _layer_output_and_input_summary(
252# input, name="Input", step=step, type="input", backend=backend
253# )
254#
255# if parameters.periodicity != float("inf"):
256# if parameters.periodicity != 2 * np.pi:
257# input = input / parameters.periodicity * 2 * np.pi
258# input = tf.concat([tf.sin(input), tf.cos(input)], 1)
259#
260# for layer in layers[:-1]:
261# input = layer(input)
262# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend)
263#
264# # last layer needs to be treated individually because of maybe being periodic
265# input = layers[-1](input)
266# if parameters.periodicity != float("inf"):
267# input = tf.atan2(*tf.split(input, 2, 1))
268# if parameters.periodicity != 2 * np.pi:
269# input = input / (2 * np.pi) * p.periodicity
270# out[layer.name] = _layer_output_and_input_summary(
271# input, layers[-1].name, step=step, backend=backend
272# )
273# return out
274#
275#
276# def _angledihedralcartesianencodermap_intermediate_summary(
277# highd_data: Union[np.ndarray, Sequence[np.ndarray]],
278# step: int,
279# model: tf.keras.Model,
280# parameters: "AnyParameters",
281# backend: Literal["matplotlib", "plotly"] = "matplotlib",
282# ) -> dict[str, io.BytesIO]:
283# out = {}
284#
285# if isinstance(highd_data, (list, tuple)):
286# if len(highd_data) == 1:
287# dihedrals = highd_data[0]
288# angles = None
289# side_dihedrals = None
290# elif len(highd_data) == 2:
291# angles, dihedrals = highd_data
292# side_dihedrals = None
293# elif len(highd_data) == 3:
294# angles, dihedrals, side_dihedrals = highd_data
295# else:
296# raise Exception(
297# f"Can't construct intermediate layer outputs for ADCEMap when "
298# f"provided high-dimensional contains more than 3 elements."
299# )
300# else:
301# dihedrals = highd_data
302# angles = None
303# side_dihedrals = None
304#
305# layers = {}
306# for layer in model.encoder_model.layers:
307# layers[layer.name] = layer
308# for layer in model.decoder_model.layers:
309# layers[layer.name] = layer
310#
311# # angles
312# if angles is not None:
313# layer = layers["input_angles_to_unit_circle"]
314# raise Exception(f"Here {layer=} {model.__class__.__name__=} {angles.shape=} {model.get_dense_model_central_angles.input_shape=}")
315# if model.__class__.__name__ == "ADCSparseFunctionalModel":
316# if model.get_dense_model_central_angles is not None:
317# angles = model.get_dense_model_central_angles(
318# angles
319# )
320# raise Exception(f"Here")
321# angles = layer(angles)
322# out[layer.name] = _layer_output_and_input_summary(angles, layer.name, step=step, backend=backend)
323#
324# raise Exception("Here")
325#
326# # sidechain dihedrals
327# if side_dihedrals is not None:
328# layer = layers["input_side_dihedrals_to_unit_circle"]
329# out["input_side_dihedrals"] = _layer_output_and_input_summary(
330# side_dihedrals, name="Input Side Dihedrals", step=step, type="input"
331# )
332# if model.__class__.__name__ == "ADCSparseFunctionalModel":
333# if model.get_dense_model_side_dihedrals is not None:
334# side_dihedrals = model.get_dense_model_side_dihedrals(
335# side_dihedrals
336# )
337# side_dihedrals = layer(side_dihedrals)
338# out[layer.name] = _layer_output_and_input_summary(
339# side_dihedrals, layer.name, step=step, backend=backend
340# )
341#
342# layer = layers["input_central_dihedrals_to_unit_circle"]
343# out["input_dihedrals"] = _layer_output_and_input_summary(
344# dihedrals, name="Input Dihedrals", step=step, type="input"
345# )
346# if model.__class__.__name__ == "ADCSparseFunctionalModel":
347# if model.get_dense_model_central_dihedrals is not None:
348# dihedrals = model.get_dense_model_side_dihedrals(
349# dihedrals
350# )
351# dihedrals = layer(dihedrals)
352# out[layer.name] = _layer_output_and_input_summary(dihedrals, layer.name, step=step, backend=backend)
353#
354# # concatenate
355# if angles is not None and side_dihedrals is not None:
356# layer = layers["concatenate_angular_inputs"]
357# input = layer((angles, dihedrals, side_dihedrals))
358# splits = [angles.shape[1], dihedrals.shape[1], side_dihedrals.shape[1]]
359# elif angles is not None and side_dihedrals is None:
360# layer = layers["concatenate_angular_inputs"]
361# input = layer((angles, dihedrals))
362# splits = [angles.shape[1], dihedrals.shape[1]]
363# else:
364# input = dihedrals
365# splits = [dihedrals.shape[1]]
366# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend)
367#
368# # Encoder
369# i = 0
370# while True:
371# try:
372# layer = layers[f"Encoder_{i}"]
373# except KeyError:
374# break
375# input = layer(input)
376# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend)
377# i += 1
378#
379# # Decoder
380# i = 0
381# while True:
382# try:
383# layer = layers[f"Decoder_{i}"]
384# except KeyError:
385# break
386# input = layer(input)
387# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend)
388# i += 1
389#
390# # split
391# if angles is not None and side_dihedrals is not None:
392# input = tf.split(input, splits, 1)
393# angles, dihedrals, side_dihedrals = input
394# elif angles is not None and side_dihedrals is None:
395# input = tf.split(input, splits, 1)
396# angles, dihedrals = input
397# else:
398# dihedrals = input
399#
400# # rejig
401# if angles is not None:
402# layer = layers["angles_from_unit_circle"]
403# angles = layer(angles)
404# out["output_angles"] = _layer_output_and_input_summary(
405# angles, name="Output Angles", step=step, backend=backend
406# )
407#
408# # sidechain dihedrals
409# if side_dihedrals is not None:
410# layer = layers["side_dihedrals_from_unit_circle"]
411# side_dihedrals = layer(side_dihedrals)
412# out["output_side_dihedrals"] = _layer_output_and_input_summary(
413# side_dihedrals, name="Output Side Dihedrals", step=step, backend=backend
414# )
415#
416# layer = layers["dihedrals_from_unit_circle"]
417# dihedrals = layer(dihedrals)
418# out["output_dihedrals"] = _layer_output_and_input_summary(
419# dihedrals, name="Output Dihedrals", step=step, backend=backend
420# )
421# return out
def image_summary(
    lowd: np.ndarray,
    step: Optional[int] = None,
    scatter_kws: Optional[dict[str, Any]] = None,
    hist_kws: Optional[dict[str, Any]] = None,
    additional_fns: Optional[Sequence[Callable]] = None,
    backend: Literal["matplotlib", "plotly"] = "matplotlib",
) -> None:
    """Writes images of the low-dimensional data to Tensorboard.

    If `lowd` is free of NaNs, a scatter plot and a 2D histogram of its first
    two columns are written under the name scope 'Latent Output', followed
    by the images produced by `additional_fns`. If `lowd` contains any NaN,
    only a placeholder image is written under the name scope
    'Latent Scatter' and the function returns early (the `additional_fns`
    are not called in that case).

    Args:
        lowd (np.ndarray): The data to plot. Usually that
            will be the output of the latent space of the Autoencoder.
            This array has to be of dimensionality 2 (rows and columns).
            The first two columns will be used as xy coordinates.
        step (Optional[int]): The training step under which you can find the
            image in tensorboard. Defaults to None.
        scatter_kws (Optional[dict[str, Any]]): Keyword arguments for the
            scatter plot. With `backend='plotly'`, they are passed on to
            `plotly.express.scatter()` and default to
            `{'size_max': 1, 'opacity': 0.2}`. With `backend='matplotlib'`,
            they are passed on to `matplotlib.pyplot.scatter()` and default
            to `{'s': 20}`.
        hist_kws (Optional[dict[str, Any]]): Keyword arguments for the
            histogram. With `backend='plotly'`, they are passed on to
            `encodermap.plot.plotting._plot_free_energy()`. With
            `backend='matplotlib'`, they are passed on to
            `matplotlib.pyplot.hist2d()`. Defaults to `{'bins': 50}` for
            both backends.
        additional_fns (Optional[Sequence[Callable]]): A sequence of functions that
            take the data of the latent space and return a tf.Tensor that can
            be logged to tensorboard with tf.summary.image().
        backend (Literal["matplotlib", "plotly"]): Which backend to use for
            plotting. Defaults to 'matplotlib'.

    Raises:
        ValueError: If `backend` is neither 'plotly' nor 'matplotlib'.

    """
    # Validate the backend before any plotting or logging happens.
    if backend == "plotly":
        default_scatter_kws = {"size_max": 1, "opacity": 0.2}
    elif backend == "matplotlib":
        default_scatter_kws = {"s": 20}
    else:
        # ValueError is still caught by callers that used `except Exception`.
        raise ValueError(
            "Argument `backend` must be either 'plotly' or 'matplotlib'."
        )
    if scatter_kws is None:
        scatter_kws = default_scatter_kws
    if hist_kws is None:
        hist_kws = {"bins": 50}
    use_plotly = backend == "plotly"

    # NaNs can't be plotted. Log a placeholder image instead and stop.
    if np.any(np.isnan(lowd)):
        if use_plotly:
            image = _gen_nan_image_plotly()
        else:
            image = _gen_nan_image_matplotlib()
        with tf.name_scope("Latent Scatter"):
            tf.summary.image(f"Latent at step {step}", image, step=step)
        return

    xy = lowd[:, :2]
    if use_plotly:
        scatter_image = _gen_scatter_plotly(xy, scatter_kws)
        hist_image = _gen_hist_plotly(xy, hist_kws)
    else:
        scatter_image = _gen_scatter_matplotlib(xy, scatter_kws)
        hist_image = _gen_hist_matplotlib(xy, hist_kws)
    with tf.name_scope("Latent Output"):
        tf.summary.image("Latent Scatter", scatter_image, step=step)
        tf.summary.image("Latent Density", hist_image, step=step)

    if additional_fns is not None:
        with tf.name_scope("User Provided Plotting Functions"):
            for i, fn in enumerate(additional_fns):
                # The user functions receive the full `lowd`, not just xy.
                tf.summary.image(f"User Plotting {i}", fn(lowd), step=step)
def _gen_hist_matplotlib(
    data: np.ndarray,
    hist_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates matplotlib histogram and returns tensorflow Tensor that represents an image.

    Note:
        This function closes all open matplotlib figures and switches the
        matplotlib backend to 'Agg'.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y.
        hist_kws (dict): Additional keywords to be passed to matplotlib.pyplot.hist2d().

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    # The x/y columns are unpacked via `.T` below. Normalizing the input
    # here makes the documented tf.Tensor input work, in line with
    # `_gen_scatter_matplotlib`. ndarrays pass through `np.asarray` untouched.
    data = np.asarray(data)

    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    plt.figure()
    plt.hist2d(*data.T, **hist_kws)

    # Render the current figure into an in-memory PNG.
    buf = io.BytesIO()
    plt.savefig(buf, format="png")

    # Decode with 4 channels and prepend an axis of size 1.
    image = tf.image.decode_png(buf.getvalue(), 4)
    return tf.expand_dims(image, 0)
def _gen_hist_plotly(
    data: np.ndarray,
    hist_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates a plotly figure from `_plot_free_energy()` and returns a
    tensorflow Tensor that represents an image.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used.
            `data.ndim` should be 2. 1st dimension the datapoints, 2nd dimension x, y.
        hist_kws (dict[str, Any]): A dict with items that
            `encodermap.plot.plotting._plot_free_energy()` will accept, e.g.
            `{'bins': 50, 'cmap': 'plasma'}`.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    xs, ys = data[:, 0], data[:, 1]
    free_energy_trace = _plot_free_energy(x=xs, y=ys, **hist_kws)

    # Fixed-size figure without margins.
    layout = {
        "width": 500,
        "height": 500,
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
    }
    fig = go.Figure(data=[free_energy_trace], layout=layout)

    # Export the figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    fig.write_image(png_buffer)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)
def _gen_nan_image_plotly() -> tf.Tensor:
    """Creates a plotly placeholder image with debug info.

    The image only contains the text 'Some data of lowd is nan'.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    layout = {
        "height": 500,
        "width": 500,
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
    }
    fig = go.Figure(layout=layout)

    # NOTE(review): (2.5, 1.5) is presumably the center of plotly's default
    # axis range of an empty figure -- confirm.
    annotation = {
        "x": 2.5,
        "y": 1.5,
        "text": "Some data of lowd is nan",
        "showarrow": False,
        "font": {"size": 36},
    }
    fig.add_annotation(**annotation)

    # Export the figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    fig.write_image(png_buffer)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)
def _gen_nan_image_matplotlib() -> tf.Tensor:
    """Creates a matplotlib placeholder image with debug info.

    The image shows empty axes with the text 'Some data of lowd is nan'
    centered in them. Calling this function closes all open matplotlib
    figures and switches the matplotlib backend to 'Agg'.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    _, axes = plt.subplots()

    # (0.5, 0.5) in axes coordinates, anchored at the center of the text.
    text_kws = {"ha": "center", "va": "center", "transform": axes.transAxes}
    axes.text(0.5, 0.5, "Some data of lowd is nan", **text_kws)

    # Render the current figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png")
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)
def _gen_scatter_plotly(
    data: np.ndarray,
    scatter_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates plotly scatter plot and returns tensorflow Tensor that represents an image.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y. Everything
            that is not a np.ndarray is converted via its `.numpy()` method.
        scatter_kws (dict[str, Any]): A dict with items that
            `plotly.express.scatter()` will accept, e.g.
            `{'size_max': 1, 'opacity': 0.2}`.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    xy = data if isinstance(data, np.ndarray) else data.numpy()

    # 500 x 500 figure with empty axis labels.
    fig = px.scatter(
        x=xy[:, 0],
        y=xy[:, 1],
        height=500,
        width=500,
        labels={"x": "", "y": ""},
        **scatter_kws,
    )
    # Strip margins, colorbar, and legend, so only the points remain.
    layout_updates = {
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
        "coloraxis_showscale": False,
        "showlegend": False,
    }
    fig.update_layout(**layout_updates)

    # Export the figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    fig.write_image(png_buffer)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)
def _gen_scatter_matplotlib(
    data: np.ndarray,
    scatter_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates matplotlib scatter plot and returns tensorflow Tensor that represents an image.

    Calling this function closes all open matplotlib figures and switches
    the matplotlib backend to 'Agg'.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y. Everything
            that is not a np.ndarray is converted via its `.numpy()` method.
        scatter_kws (dict): Additional keywords to be passed to matplotlib.pyplot.scatter().

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    plt.figure()

    # Transposing yields the x and y columns as positional arguments.
    xy = data if isinstance(data, np.ndarray) else data.numpy()
    plt.scatter(*xy.T, **scatter_kws)

    # Render the current figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png")
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)