Coverage for encodermap/misc/distances.py: 18%

53 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-12-31 16:54 +0100

1# -*- coding: utf-8 -*- 

2# encodermap/misc/distances.py 

3################################################################################ 

4# EncoderMap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2024 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researches. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22"""EncoderMap implements different distance computations. 

23 

24* Normal: Euclidean distance between two points. 

25* Periodic: Euclidean distance between two points lying in a periodic space. 

26* Pairwise: Euclidean distance between sets of points. Either with or without periodicity. 

27 

28""" 

29################################################################################ 

30# Imports 

31################################################################################ 

32 

33# Standard Library Imports 

34from collections.abc import Callable 

35from math import pi 

36from numbers import Number 

37from typing import Union, overload 

38 

39# Third Party Imports 

40import numpy as np 

41import tensorflow as tf 

42 

43 

44################################################################################ 

45# Typing 

46################################################################################ 

47 

48 

49################################################################################ 

50# Globals 

51################################################################################ 

52 

__all__: list[str] = [
    # element-wise helpers
    "sigmoid",
    "periodic_distance",
    "periodic_distance_np",
    # pairwise distance matrices
    "pairwise_dist",
    "pairwise_dist_periodic",
]

60 

61################################################################################ 

62# Functions 

63################################################################################ 

64 

65 

def sigmoid(sig: float, a: float, b: float) -> Callable:
    """Create a sigmoid function with fixed parameters.

    The returned callable evaluates

        s(r) = 1 - (1 + (2 ** (a / b) - 1) * (r / sig) ** a) ** (-b / a)

    using plain arithmetic operators, so it accepts scalars as well as
    NumPy arrays (evaluated element-wise). For nonzero ``a`` and ``b`` the
    formula gives ``s(sig) == 0.5``, i.e. ``sig`` is the value of ``r`` at
    which the sigmoid reaches one half.

    Args:
        sig (float): Value of ``r`` at which the sigmoid equals 0.5.
        a (float): First exponent of the formula above.
        b (float): Second exponent of the formula above.

    Returns:
        Callable: A function ``r -> s(r)`` bound to the given parameters.

    """

    @overload
    def _evaluate(r: Number) -> Number: ...

    @overload
    def _evaluate(r: np.ndarray) -> np.ndarray: ...

    def _evaluate(r: Union[Number, np.ndarray]) -> Union[Number, np.ndarray]:
        # Parameters are read lazily on every call (no precomputation), so
        # invalid parameters raise at call time, not at construction time.
        inner = 1 + (2 ** (a / b) - 1) * (r / sig) ** a
        return 1 - inner ** (-b / a)

    return _evaluate

89 

90 

def periodic_distance_np(
    a: np.ndarray, b: np.ndarray, periodicity: float = 2 * pi
) -> np.ndarray:
    """Calculates the element-wise distance between points and respects periodicity.

    If the provided dataset is periodic (e.g. angles and torsion angles), the
    returned distance is the shorter way around the period, so it always lies
    in ``[0, periodicity / 2]``. Differences spanning more than one period are
    folded back into a single period first.

    Args:
        a (np.ndarray): Coordinate of point a.
        b (np.ndarray): Coordinate of point b.
        periodicity (float): The periodicity (i.e. the box length/ maximum angle)
            of your data. Defaults to 2*pi. Provide float('inf') for no periodicity.

    Returns:
        np.ndarray: The distances accounting for periodicity.

    Example:
        >>> import numpy as np
        >>> periodic_distance_np(np.array([1.5]), np.array([-3.1]))
        array([1.68318531])

    """
    # Fold the absolute difference back into one period. Without this, a
    # difference larger than ``periodicity`` makes ``periodicity - d``
    # negative, and the minimum below would be a negative "distance".
    # ``np.fmod`` (C fmod semantics) returns finite ``x`` unchanged for an
    # infinite divisor, so ``periodicity=float('inf')`` still means
    # "no periodicity".
    d = np.fmod(np.abs(b - a), periodicity)
    return np.minimum(d, periodicity - d)

111 

112 

def periodic_distance(
    a: tf.Tensor, b: tf.Tensor, periodicity: float = 2 * pi
) -> tf.Tensor:
    """Calculates distance between two points and respects periodicity.

    If the provided dataset is periodic (e.g. angles and torsion angles), the
    returned distance is the shorter way around the period. Differences
    spanning more than one period are folded back into a single period first,
    so the result is never negative.

    Args:
        a (tf.Tensor): Coordinate of point a.
        b (tf.Tensor): Coordinate of point b.
        periodicity (float): The periodicity (i.e. the box length/ maximum angle)
            of your data. Defaults to 2*pi. Provide float('inf') for no periodicity.

    Returns:
        tf.Tensor: The distances accounting for periodicity.

    Example:
        >>> import numpy as np
        >>> import tensorflow as tf
        >>> import encodermap as em
        >>> x = tf.convert_to_tensor(np.array([[1.5], [1.5]]))
        >>> y = tf.convert_to_tensor(np.array([[-3.1], [-3.1]]))
        >>> r = em.misc.periodic_distance(x, y)
        >>> print(r.numpy())
        [[1.68318531]
         [1.68318531]]

    """
    # Fold the absolute difference back into one period. Without this, a
    # difference larger than ``periodicity`` makes ``periodicity - d``
    # negative, and the minimum below would be a negative "distance".
    # floormod follows Python's ``%`` semantics (``x % inf == x`` for
    # x >= 0), so ``periodicity=float('inf')`` still means "no periodicity".
    # This mirrors ``periodic_distance_np``.
    d = tf.math.floormod(tf.abs(b - a), periodicity)
    return tf.minimum(d, periodicity - d)

142 

143 

def pairwise_dist_periodic(
    positions: tf.Tensor,
    periodicity: float,
) -> tf.Tensor:
    """Pairwise distances using periodicity.

    Every column of ``positions`` is treated as a periodic coordinate with the
    same ``periodicity``. Per-coordinate periodic differences are computed
    with ``periodic_distance`` and combined into a Euclidean norm.

    Args:
        positions (Union[np.ndarray, tf.Tensor]): The positions of the
            points as a 2D array with shape (n_points, n_dims), e.g.
            (n_points, 1) for a single rotational value per point.
        periodicity (float): The periodicity of the data. Most often
            you will use either 2*pi or 360.

    Returns:
        tf.Tensor: The distance matrix with shape (n_points, n_points).
            To keep gradients finite, exact zero differences are replaced by
            a tiny epsilon and 1e-12 is added to every distance, so the
            diagonal is a small positive number, not exactly 0.

    Raises:
        AssertionError: If ``positions`` is not two-dimensional.

    """
    # Convert before inspecting the shape, so plain sequences are accepted.
    if not tf.debugging.is_numeric_tensor(positions):
        positions = tf.convert_to_tensor(positions)
    assert len(positions.shape) == 2, (
        "`positions` must be a 2D array of shape (n_points, n_dims), "
        f"got shape {positions.shape}."
    )
    # Broadcast (n, 1, d) against (1, n, d) -> (n, n, d) periodic differences.
    vecs = periodic_distance(
        tf.expand_dims(positions, axis=1),
        tf.expand_dims(positions, axis=0),
        periodicity,
    )
    # The gradient of sqrt is infinite at 0, so nudge exact zeros by an
    # epsilon. The mask must share ``vecs``' dtype: TensorFlow does not
    # promote float32 and float64, so a float32 mask broke float64 input.
    mask = tf.cast(tf.equal(vecs, 0.0), vecs.dtype)
    vecs = vecs + mask * 1e-12
    # tf.norm(vecs, axis=2) still yields infinite gradients here, hence the
    # manual sqrt of the sum of squares. See:
    # https://datascience.stackexchange.com/q/80898
    dists = tf.sqrt(tf.reduce_sum(tf.square(vecs), axis=2)) + 1.0e-12
    return dists

177 

178 

def pairwise_dist(
    positions: tf.Tensor,
    squared: bool = False,
    flat: bool = False,
) -> tf.Tensor:
    """Tensorflow equivalent of `scipy.spatial.distance.cdist(x, x)`.

    Computes the Euclidean distance matrix between all points in
    ``positions``. The matrix is hollow, i.e., the diagonal elements are
    exactly zero.

    Thanks to https://omoindrot.github.io/triplet-loss
    for this implementation. Find an archived link here:
    https://archive.is/lNT2L

    Args:
        positions (Union[np.ndarray, tf.Tensor]): Collection of
            n-dimensional points, either with shape (n_points, n_dims) or
            batched with shape (batch, n_points, n_dims). A 2D input is
            treated as a batch of size 1.
        squared (bool): Whether to return the pairwise squared
            Euclidean distance matrix or normal Euclidean distance matrix.
            Defaults to False.
        flat (bool): Whether to return only the strictly upper triangle of
            the hollow matrix, flattened row by row. This yields the same
            ordering as `scipy.spatial.distance.pdist`. Defaults to False.

    Returns:
        tf.Tensor: The distances with shape (batch, n_points, n_points), or
            (batch, n_points * (n_points - 1) // 2) if ``flat`` is True.
            For 2D input the leading batch axis of size 1 is kept.

    """
    if not tf.debugging.is_numeric_tensor(positions):
        positions = tf.convert_to_tensor(positions)
    if len(positions.get_shape()) == 2:
        positions = tf.expand_dims(positions, 0)

    # Dot product between all points within each batch entry.
    # shape (batch, n_points, n_points)
    dot_product = tf.matmul(positions, tf.transpose(positions, [0, 2, 1]))

    # The squared L2 norm of every point is the diagonal of `dot_product`.
    # Reusing it makes the diagonal of the result exactly 0.
    # shape (batch, n_points)
    square_norm = tf.linalg.diag_part(dot_product)

    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    # shape (batch, n_points, n_points)
    distances = (
        tf.expand_dims(square_norm, 1)
        - 2.0 * dot_product
        + tf.expand_dims(square_norm, 2)
    )

    # Rounding errors can make some distances slightly negative; clamp at 0.
    distances = tf.maximum(distances, 0.0)

    if flat:
        # Keep only the strictly upper triangle (row < column), which in
        # row-major order matches scipy's pdist. Building the mask from the
        # dynamic shape keeps this correct even when the number of points is
        # not known statically (e.g. while tracing a tf.function).
        n_points = tf.shape(positions)[1]
        indices = tf.range(n_points)
        upper = indices[:, tf.newaxis] < indices[tf.newaxis, :]
        distances = tf.boolean_mask(distances, upper, axis=1)

    if not squared:
        # The gradient of sqrt is infinite at 0.0 (e.g. on the diagonal), so
        # temporarily add a small epsilon where distances == 0.0. The mask
        # uses the dtype of `distances`, because TensorFlow does not promote
        # float32 and float64 automatically (float64 input used to crash).
        zero_mask = tf.cast(tf.equal(distances, 0.0), distances.dtype)
        distances = distances + zero_mask * 1e-16

        distances = tf.sqrt(distances)

        # Undo the epsilon: entries that were 0.0 become exactly 0.0 again.
        distances = distances * (1.0 - zero_mask)

    return distances