Coverage for encodermap/misc/summaries.py: 12%

140 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/summaries.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22 

23""" 

24Functions that write stuff to tensorboard. Mainly used for the image callbacks. 

25""" 

26 

27################################################################################ 

28# Imports 

29################################################################################ 

30 

31 

32# Future Imports at the top 

33from __future__ import annotations 

34 

35# Standard Library Imports 

36import io 

37from collections.abc import Callable, Sequence 

38from typing import Any, Literal, Optional, Union 

39 

40# Third Party Imports 

41import matplotlib 

42import matplotlib.pyplot as plt 

43import numpy as np 

44import tensorflow as tf 

45from optional_imports import _optional_import 

46 

47# Encodermap imports 

48from encodermap.plot.plotting import _plot_free_energy 

49 

50 

51################################################################################ 

52# Optional Imports 

53################################################################################ 

54 

55 

56px = _optional_import("plotly", "express") 

57go = _optional_import("plotly", "graph_objects") 

58 

59 

60################################################################################ 

61# Globals 

62################################################################################ 

63 

64 

65__all__: list[str] = ["add_layer_summaries", "image_summary"] 

66 

67 

68################################################################################ 

69# Summary Functions 

70################################################################################ 

71 

72 

def add_layer_summaries(
    layer: tf.keras.layers.Layer,
    step: Optional[int] = None,
) -> None:
    """Adds summaries for the variables of a layer to Tensorboard.

    The first variable of the layer is logged as ``<layer.name>/weights`` and
    the second one as ``<layer.name>/biases``. Layers that own fewer than two
    variables (e.g. a layer without a bias) only get summaries for the
    variables they actually have, instead of raising an ``IndexError``.

    Args:
        layer (tf.keras.layers.Layer): The layer.
        step (Optional[int]): The current step. Is forwarded unchanged to
            `variable_summaries`. Defaults to None.

    """
    # Group the summaries by the part of the network the layer belongs to.
    # The first matching keyword wins, same precedence as an if/elif chain.
    lowered_name = layer.name.lower()
    for keyword, scope in (
        ("encoder", "Encoder"),
        ("decoder", "Decoder"),
        ("latent", "Latent"),
    ):
        if keyword in lowered_name:
            namescope = scope
            break
    else:
        namescope = "InputOutputLayers"

    # zip() stops at the shorter sequence, so only existing variables are used
    # and any variables beyond the second are ignored (as before).
    for suffix, variable in zip(("weights", "biases"), layer.variables):
        variable_summaries(namescope, layer.name + "/" + suffix, variable, step)

97 

98 

def variable_summaries(
    namescope: str,
    name: str,
    variables: Union[tf.Tensor, list[tf.Tensor]],
    step: Optional[int] = None,
) -> None:
    """Attach several summaries to a Tensor for TensorBoard visualization.

    For every variable the mean, standard deviation, maximum, minimum and a
    histogram are written inside the name scope ``<namescope>/<name>``.

    Args:
        namescope (str): The string to prepend to the layer names.
            Makes it easier to group the layers.
        name (str): The name of the layer.
        variables (Union[tf.Tensor, list[tf.Tensor]]): The variables
            (weights, biases) of the layer. A single tensor is treated as a
            list with one element. If a list with more than one element is
            provided, the index of the variable is appended to `name`.
        step (Optional[int]): The current step. Defaults to None.

    """
    if not isinstance(variables, list):
        variables = [variables]

    add_index = len(variables) > 1
    for i, var in enumerate(variables):
        # Build the scope name per variable from the unchanged base name.
        # Re-assigning `name` inside the loop would accumulate the indices
        # ("name0", "name01", "name012", ...).
        scope_name = name + str(i) if add_index else name
        with tf.name_scope(namescope + "/" + scope_name):
            mean = tf.reduce_mean(var)
            tf.summary.scalar("mean", mean, step=step)
            with tf.name_scope("stddev"):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar("stddev", stddev, step=step)
            tf.summary.scalar("max", tf.reduce_max(var), step=step)
            tf.summary.scalar("min", tf.reduce_min(var), step=step)
            tf.summary.histogram("histogram", var, step=step)
            # NOTE(review): this is the TF1 compat API and it is called
            # without `step` - confirm that it is still needed under TF2.
            tf.compat.v1.summary.tensor_summary("values", var)

137 

138 

139# def intermediate_summary( 

140# highd_data: Union[np.ndarray, Sequence[np.ndarray]], 

141# step: int, 

142# model: tf.keras.Model, 

143# parameters: "AnyParameters", 

144# backend: Literal["matplotlib", "plotly"] = "matplotlib", 

145# ) -> dict[str, io.BytesIO]: 

146# """Produces images of intermediate layers during training. 

147# 

148# Args: 

149# highd_data (Union[np.ndarray, Sequence[np.ndarray]]): The high dimensional 

150# data to be used with the autoencoder model. If the autoencoder is 

151# of type `Autoencoder`, `EncoderMap`, or `DihedralEncoderMap`, the 

152# `highd_data` should be a single np.ndarray. For the `AngleDihedralCartesianEncoderMap`,

153# the `highd_data` should be a Sequence of np.ndarray in the sequence 

154# angles, dihedrals, cartesians, distance, (and maybe side_dihedrals, 

155# depending on whether side_dihedrals are used for the training). 

156# step (int): The current training step.

157# model (tf.keras.Model): The autoencoder model. 

158# parameters (AnyParameters): Either an instance of `encodermap.parameters.Parameters`, or 

159# `encodermap.parameters.ADCParameters`, depending on which autoencoder model is used. 

160# The parameters are used to fix periodic input data. 

161# backend (Literal["matplotlib", "plotly"]: Which backend to use for 

162# plotting. Defaults to 'matplotlib'. 

163# 

164# """ 

165# if model.encoder_model.layers[0].name == "Encoder_0": 

166# return _encodermap_intermediate_summary( 

167# highd_data, 

168# step, 

169# model, 

170# parameters, 

171# backend, 

172# ) 

173# else: 

174# return _angledihedralcartesianencodermap_intermediate_summary( 

175# highd_data, 

176# step, 

177# model, 

178# parameters, 

179# backend, 

180# ) 

181# 

182# 

183# def _layer_output_and_input_summary( 

184# data: np.ndarray, 

185# name: str, 

186# step: int, 

187## type: Literal["input", "layer"] = "layer", 

188# backend: Literal["matplotlib", "plotly"] = "matplotlib", 

189# ) -> io.BytesIO: 

190# buf = io.BytesIO() 

191# raise Exception(f"Here") 

192# if isinstance(data, tf.SparseTensor): 

193# data = tf.sparse.to_dense(data, default_value=np.nan) 

194# if isinstance(data, tf.Tensor): 

195# data = data.numpy() 

196# mean = np.mean(data, 0) 

197# if backend == "plotly": 

198# fig = px.bar( 

199# x=mean, 

200# y=np.arange(data.shape[1]), 

201# orientation="h", 

202# color=mean, 

203# color_continuous_scale="Viridis", 

204# labels={ 

205# "x": f"mean value of {data.shape[0]} samples", 

206# "y": "feature" if type == "input" else "neuron", 

207# }, 

208# width=500, 

209# height=500, 

210# ) 

211# fig.update_layout( 

212# margin={"l": 0, "r": 0, "t": 0, "b": 0}, 

213# coloraxis_showscale=False, 

214# ) 

215# fig.write_image(buf) 

216# elif backend == "matplotlib": 

217# raise NotImplementedError(f"{mean.shape=}") 

218# else: 

219# raise Exception( 

220# f"Argument `backend` must be either 'plotly' or 'matplotlib'." 

221# ) 

222# image = tf.image.decode_png(buf.getvalue(), 4) 

223# image = tf.expand_dims(image, 0) 

224# with tf.name_scope("Layer Outputs"): 

225# tf.summary.image(name, image, step=step) 

226# buf.seek(0) 

227# return buf 

228# 

229# 

230# def _encodermap_intermediate_summary( 

231# highd_data: Union[np.ndarray, Sequence[np.ndarray]], 

232# step: int, 

233# model: tf.keras.Model, 

234# parameters: "AnyParameters", 

235# backend: Literal["matplotlib", "plotly"] = "matplotlib", 

236# ) -> dict[str, io.BytesIO]: 

237# # Local Folder Imports 

238# 

239# out = {} 

240# layers = [] 

241# for layer in model.encoder_model.layers: 

242# layers.append(layer) 

243# for layer in model.decoder_model.layers: 

244# layers.append(layer) 

245# 

246# if model.sparse: 

247# input = model.get_dense_model(highd_data) 

248# else: 

249# input = highd_data 

250# 

251# out["input"] = _layer_output_and_input_summary( 

252# input, name="Input", step=step, type="input", backend=backend 

253# ) 

254# 

255# if parameters.periodicity != float("inf"): 

256# if parameters.periodicity != 2 * np.pi: 

257# input = input / parameters.periodicity * 2 * np.pi 

258# input = tf.concat([tf.sin(input), tf.cos(input)], 1) 

259# 

260# for layer in layers[:-1]: 

261# input = layer(input) 

262# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend) 

263# 

264# # last layer needs to be treated individually because of maybe being periodic 

265# input = layers[-1](input) 

266# if parameters.periodicity != float("inf"): 

267# input = tf.atan2(*tf.split(input, 2, 1)) 

268# if parameters.periodicity != 2 * np.pi: 

269# input = input / (2 * np.pi) * p.periodicity 

270# out[layer.name] = _layer_output_and_input_summary( 

271# input, layers[-1].name, step=step, backend=backend 

272# ) 

273# return out 

274# 

275# 

276# def _angledihedralcartesianencodermap_intermediate_summary( 

277# highd_data: Union[np.ndarray, Sequence[np.ndarray]], 

278# step: int, 

279# model: tf.keras.Model, 

280# parameters: "AnyParameters", 

281# backend: Literal["matplotlib", "plotly"] = "matplotlib", 

282# ) -> dict[str, io.BytesIO]: 

283# out = {} 

284# 

285# if isinstance(highd_data, (list, tuple)): 

286# if len(highd_data) == 1: 

287# dihedrals = highd_data[0] 

288# angles = None 

289# side_dihedrals = None 

290# elif len(highd_data) == 2: 

291# angles, dihedrals = highd_data 

292# side_dihedrals = None 

293# elif len(highd_data) == 3: 

294# angles, dihedrals, side_dihedrals = highd_data 

295# else: 

296# raise Exception( 

297# f"Can't construct intermediate layer outputs for ADCEMap when " 

298# f"provided high-dimensional contains more than 3 elements." 

299# ) 

300# else: 

301# dihedrals = highd_data 

302# angles = None 

303# side_dihedrals = None 

304# 

305# layers = {} 

306# for layer in model.encoder_model.layers: 

307# layers[layer.name] = layer 

308# for layer in model.decoder_model.layers: 

309# layers[layer.name] = layer 

310# 

311# # angles 

312# if angles is not None: 

313# layer = layers["input_angles_to_unit_circle"] 

314# raise Exception(f"Here {layer=} {model.__class__.__name__=} {angles.shape=} {model.get_dense_model_central_angles.input_shape=}") 

315# if model.__class__.__name__ == "ADCSparseFunctionalModel": 

316# if model.get_dense_model_central_angles is not None: 

317# angles = model.get_dense_model_central_angles( 

318# angles 

319# ) 

320# raise Exception(f"Here") 

321# angles = layer(angles) 

322# out[layer.name] = _layer_output_and_input_summary(angles, layer.name, step=step, backend=backend) 

323# 

324# raise Exception("Here") 

325# 

326# # sidechain dihedrals 

327# if side_dihedrals is not None: 

328# layer = layers["input_side_dihedrals_to_unit_circle"] 

329# out["input_side_dihedrals"] = _layer_output_and_input_summary( 

330# side_dihedrals, name="Input Side Dihedrals", step=step, type="input" 

331# ) 

332# if model.__class__.__name__ == "ADCSparseFunctionalModel": 

333# if model.get_dense_model_side_dihedrals is not None: 

334# side_dihedrals = model.get_dense_model_side_dihedrals( 

335# side_dihedrals 

336# ) 

337# side_dihedrals = layer(side_dihedrals) 

338# out[layer.name] = _layer_output_and_input_summary( 

339# side_dihedrals, layer.name, step=step, backend=backend 

340# ) 

341# 

342# layer = layers["input_central_dihedrals_to_unit_circle"] 

343# out["input_dihedrals"] = _layer_output_and_input_summary( 

344# dihedrals, name="Input Dihedrals", step=step, type="input" 

345# ) 

346# if model.__class__.__name__ == "ADCSparseFunctionalModel": 

347# if model.get_dense_model_central_dihedrals is not None: 

348# dihedrals = model.get_dense_model_side_dihedrals( 

349# dihedrals 

350# ) 

351# dihedrals = layer(dihedrals) 

352# out[layer.name] = _layer_output_and_input_summary(dihedrals, layer.name, step=step, backend=backend) 

353# 

354# # concatenate 

355# if angles is not None and side_dihedrals is not None: 

356# layer = layers["concatenate_angular_inputs"] 

357# input = layer((angles, dihedrals, side_dihedrals)) 

358# splits = [angles.shape[1], dihedrals.shape[1], side_dihedrals.shape[1]] 

359# elif angles is not None and side_dihedrals is None: 

360# layer = layers["concatenate_angular_inputs"] 

361# input = layer((angles, dihedrals)) 

362# splits = [angles.shape[1], dihedrals.shape[1]] 

363# else: 

364# input = dihedrals 

365# splits = [dihedrals.shape[1]] 

366# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend) 

367# 

368# # Encoder 

369# i = 0 

370# while True: 

371# try: 

372# layer = layers[f"Encoder_{i}"] 

373# except KeyError: 

374# break 

375# input = layer(input) 

376# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend) 

377# i += 1 

378# 

379# # Decoder 

380# i = 0 

381# while True: 

382# try: 

383# layer = layers[f"Decoder_{i}"] 

384# except KeyError: 

385# break 

386# input = layer(input) 

387# out[layer.name] = _layer_output_and_input_summary(input, layer.name, step=step, backend=backend) 

388# i += 1 

389# 

390# # split 

391# if angles is not None and side_dihedrals is not None: 

392# input = tf.split(input, splits, 1) 

393# angles, dihedrals, side_dihedrals = input 

394# elif angles is not None and side_dihedrals is None: 

395# input = tf.split(input, splits, 1) 

396# angles, dihedrals = input 

397# else: 

398# dihedrals = input 

399# 

400# # rejig 

401# if angles is not None: 

402# layer = layers["angles_from_unit_circle"] 

403# angles = layer(angles) 

404# out["output_angles"] = _layer_output_and_input_summary( 

405# angles, name="Output Angles", step=step, backend=backend 

406# ) 

407# 

408# # sidechain dihedrals 

409# if side_dihedrals is not None: 

410# layer = layers["side_dihedrals_from_unit_circle"] 

411# side_dihedrals = layer(side_dihedrals) 

412# out["output_side_dihedrals"] = _layer_output_and_input_summary( 

413# side_dihedrals, name="Output Side Dihedrals", step=step, backend=backend 

414# ) 

415# 

416# layer = layers["dihedrals_from_unit_circle"] 

417# dihedrals = layer(dihedrals) 

418# out["output_dihedrals"] = _layer_output_and_input_summary( 

419# dihedrals, name="Output Dihedrals", step=step, backend=backend 

420# ) 

421# return out 

422 

423 

def image_summary(
    lowd: np.ndarray,
    step: Optional[int] = None,
    scatter_kws: Optional[dict[str, Any]] = None,
    hist_kws: Optional[dict[str, Any]] = None,
    additional_fns: Optional[Sequence[Callable]] = None,
    backend: Literal["matplotlib", "plotly"] = "matplotlib",
) -> None:
    """Writes images of the low-dimensional data to Tensorboard.

    A scatter plot and a 2D histogram of the first two columns of `lowd` are
    written. If `lowd` contains NaNs, a single placeholder image with a
    warning text is written instead and nothing else is plotted.

    Args:
        lowd (np.ndarray): The data to plot. Usually that
            will be the output of the latent space of the Autoencoder.
            This array has to be of dimensionality 2 (rows and columns).
            The first two columns will be used as xy coordinates.
        step (Optional[int]): The training step under which you can find the
            image in tensorboard. Defaults to None.
        scatter_kws (Optional[dict[str, Any]]): Keyword arguments for the
            scatter plot. With the 'plotly' backend they are passed to
            `plotly.express.scatter()` and default to
            `{'size_max': 1, 'opacity': 0.2}`. With the 'matplotlib' backend
            they are passed to `matplotlib.pyplot.scatter()` and default to
            `{'s': 20}`.
        hist_kws (Optional[dict[str, Any]]): Keyword arguments for the
            density plot. With the 'plotly' backend they are passed to
            `encodermap.plot.plotting._plot_free_energy()`, with the
            'matplotlib' backend to `matplotlib.pyplot.hist2d()`. Defaults to
            `{'bins': 50}` for both backends.
        additional_fns (Optional[Sequence[Callable]]): A sequence of functions that
            take the data of the latent space and return a tf.Tensor that can
            be logged to tensorboard with tf.summary.image().
        backend (Literal["matplotlib", "plotly"]): Which backend to use for
            plotting. Defaults to 'matplotlib'.

    Raises:
        ValueError: If `backend` is neither 'plotly' nor 'matplotlib'.

    """
    # Validate the backend first, so that nothing is plotted or logged for
    # an unsupported value.
    if backend == "plotly":
        default_scatter_kws = {"size_max": 1, "opacity": 0.2}
    elif backend == "matplotlib":
        default_scatter_kws = {"s": 20}
    else:
        raise ValueError(
            "Argument `backend` must be either 'plotly' or 'matplotlib'."
        )
    if scatter_kws is None:
        scatter_kws = default_scatter_kws
    if hist_kws is None:
        hist_kws = {"bins": 50}
    use_plotly = backend == "plotly"

    # NaNs can't be plotted. Log a placeholder image and stop.
    if np.any(np.isnan(lowd)):
        if use_plotly:
            image = _gen_nan_image_plotly()
        else:
            image = _gen_nan_image_matplotlib()
        with tf.name_scope("Latent Scatter"):
            tf.summary.image(f"Latent at step {step}", image, step=step)
        return

    xy = lowd[:, :2]
    if use_plotly:
        scatter_image = _gen_scatter_plotly(xy, scatter_kws)
        hist_image = _gen_hist_plotly(xy, hist_kws)
    else:
        scatter_image = _gen_scatter_matplotlib(xy, scatter_kws)
        hist_image = _gen_hist_matplotlib(xy, hist_kws)
    with tf.name_scope("Latent Output"):
        tf.summary.image("Latent Scatter", scatter_image, step=step)
        tf.summary.image("Latent Density", hist_image, step=step)

    # The user-provided functions receive the full `lowd`, not just the
    # first two columns.
    if additional_fns is not None:
        with tf.name_scope("User Provided Plotting Functions"):
            for i, fn in enumerate(additional_fns):
                tf.summary.image(f"User Plotting {i}", fn(lowd), step=step)

498 

499 

def _gen_hist_matplotlib(
    data: np.ndarray,
    hist_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates matplotlib histogram and returns tensorflow Tensor that represents an image.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y. Input that is not
            a np.ndarray is converted via its `.numpy()` method.
        hist_kws (dict): Additional keywords to be passed to matplotlib.pyplot.hist2d().

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    # Same conversion as in the scatter helpers: tensors are turned into
    # arrays before `.T` is used to split the data into x and y.
    if not isinstance(data, np.ndarray):
        data = data.numpy()
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    plt.figure()
    plt.hist2d(*data.T, **hist_kws)
    # Render the current figure as PNG into memory and decode it with
    # 4 channels. The leading axis added by expand_dims is the batch axis
    # tf.summary.image() works with.
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), 4)
    image = tf.expand_dims(image, 0)
    return image

525 

526 

def _gen_hist_plotly(
    data: np.ndarray,
    hist_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates a plotly density plot and returns it as an image tensor.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used.
            `data.ndim` should be 2. 1st dimension the datapoints, 2nd dimension x, y.
        hist_kws (dict[str, Any]): A dict with items that
            `encodermap.plot.plotting._plot_free_energy()` will accept,
            e.g. `{'bins': 50, 'cmap': 'plasma'}`.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    xs, ys = data[:, 0], data[:, 1]
    free_energy_trace = _plot_free_energy(x=xs, y=ys, **hist_kws)
    # Fixed 500x500 canvas without margins.
    layout = {
        "width": 500,
        "height": 500,
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
    }
    figure = go.Figure(data=[free_energy_trace], layout=layout)
    # Export the figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    figure.write_image(png_buffer)
    png_buffer.seek(0)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)

569 

570 

def _gen_nan_image_plotly() -> tf.Tensor:
    """Creates a plotly placeholder image that tells the user about NaNs in lowd.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    # Empty 500x500 figure without margins that only carries the message.
    layout = {
        "height": 500,
        "width": 500,
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
    }
    figure = go.Figure(layout=layout)
    figure.add_annotation(
        x=2.5,
        y=1.5,
        text="Some data of lowd is nan",
        showarrow=False,
        font={"size": 36},
    )
    # Export the figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    figure.write_image(png_buffer)
    png_buffer.seek(0)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)

598 

599 

def _gen_nan_image_matplotlib() -> tf.Tensor:
    """Creates a matplotlib placeholder image that tells the user about NaNs in lowd.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    _, axes = plt.subplots()
    # Centered message in axes coordinates.
    text_kws = {"ha": "center", "va": "center", "transform": axes.transAxes}
    axes.text(0.5, 0.5, "Some data of lowd is nan", **text_kws)
    # Render the current figure as PNG into memory and decode it with
    # 4 channels.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png")
    png_buffer.seek(0)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)

624 

625 

def _gen_scatter_plotly(
    data: np.ndarray,
    scatter_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates a plotly scatter plot and returns it as an image tensor.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y. Input that is not
            a np.ndarray is converted via its `.numpy()` method.
        scatter_kws (dict[str, Any]): A dict with items that
            `plotly.express.scatter()` will accept, e.g.
            `{'size_max': 1, 'opacity': 0.2}`.

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    points = data if isinstance(data, np.ndarray) else data.numpy()
    xs, ys = points[:, 0], points[:, 1]
    # Fixed 500x500 canvas with empty axis labels.
    figure = px.scatter(
        x=xs,
        y=ys,
        height=500,
        width=500,
        labels={"x": "", "y": ""},
        **scatter_kws,
    )
    # Remove margins, colorbar and legend so only the points remain.
    figure.update_layout(
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        coloraxis_showscale=False,
        showlegend=False,
    )
    # Export the figure into memory and decode it with 4 channels.
    png_buffer = io.BytesIO()
    figure.write_image(png_buffer)
    png_buffer.seek(0)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)

667 

668 

def _gen_scatter_matplotlib(
    data: np.ndarray,
    scatter_kws: dict[str, Any],
) -> tf.Tensor:
    """Creates a matplotlib scatter plot and returns it as an image tensor.

    Args:
        data (Union[np.ndarray, tf.Tensor]): The xy data to be used. data.ndim should be 2.
            1st dimension the datapoints, 2nd dimension x, y. Input that is not
            a np.ndarray is converted via its `.numpy()` method.
        scatter_kws (dict): Additional keywords to be passed to matplotlib.pyplot.scatter().

    Returns:
        tf.Tensor: A tensorflow tensor that can be written to Tensorboard with tf.summary.image().

    """
    plt.close("all")
    matplotlib.use("Agg")  # overwrites current backend of notebook
    plt.figure()
    # `.T` splits the (n_points, 2) data into the x and y arguments.
    points = data if isinstance(data, np.ndarray) else data.numpy()
    plt.scatter(*points.T, **scatter_kws)
    # Render the current figure as PNG into memory and decode it with
    # 4 channels.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png")
    png_buffer.seek(0)
    decoded = tf.image.decode_png(png_buffer.getvalue(), 4)
    return tf.expand_dims(decoded, 0)
696 return image