Coverage for encodermap/misc/distances.py: 98%

42 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-07 11:05 +0000

1# -*- coding: utf-8 -*- 

2# encodermap/misc/distances.py 

3################################################################################ 

4# Encodermap: A python library for dimensionality reduction. 

5# 

6# Copyright 2019-2022 University of Konstanz and the Authors 

7# 

8# Authors: 

9# Kevin Sawade, Tobias Lemke 

10# 

11# Encodermap is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU Lesser General Public License as 

13# published by the Free Software Foundation, either version 2.1 

14# of the License, or (at your option) any later version. 

15# This package is distributed in the hope that it will be useful to other 

16# researchers. IT DOES NOT COME WITH ANY WARRANTY WHATSOEVER; without even the 

17# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

18# See the GNU Lesser General Public License for more details. 

19# 

20# See <http://www.gnu.org/licenses/>. 

21################################################################################ 

22################################################################################ 

23# Imports 

24################################################################################ 

25 

26from math import pi 

27 

28import numpy as np 

29import tensorflow as tf 

30 

31################################################################################ 

32# Globals 

33################################################################################ 

34 

# Names exported by a star-import of this module.
__all__ = [
    "sigmoid",
    "periodic_distance",
    "periodic_distance_np",
    "pairwise_dist",
    "pairwise_dist_periodic",
]

42 

43################################################################################ 

44# Functions 

45################################################################################ 

46 

47 

def sigmoid(sig, a, b):
    """Builds a sigmoid-shaped function from three fixed parameters.

    The returned callable evaluates
    ``1 - (1 + (2 ** (a / b) - 1) * (r / sig) ** a) ** (-b / a)``.
    For positive ``a`` and ``b`` this is 0 at ``r = 0`` and 0.5 at ``r = sig``.

    Args:
        sig (float): Sigma. The value of ``r`` is divided by it.
        a (float): a. Exponent applied to ``r / sig``.
        b (float): b. Together with ``a`` it sets the outer exponent ``-b / a``.

    Returns:
        function: A one-argument function ``func(r)`` that calculates the
            sigmoid with the specified parameters.
    """

    def func(r):
        # Everything is evaluated at call time, so invalid parameters
        # (e.g. a == 0) only raise once the returned function is used.
        prefactor = 2 ** (a / b) - 1
        inner = 1 + prefactor * (r / sig) ** a
        return 1 - inner ** (-b / a)

    return func

65 

66 

def periodic_distance_np(a, b, periodicity=2 * pi):
    """Calculates distance between two points and respects periodicity.

    If the provided dataset is periodic (i.e. angles and torsion angles), the
    returned distance is corrected: it is the shorter of the two ways around
    the periodic interval and is never negative.

    Args:
        a (np.ndarray): Coordinate of point a.
        b (np.ndarray): Coordinate of point b.
        periodicity (float, optional): The periodicity (i.e. the box length/ maximum angle)
            of your data. Defaults to 2*pi. Provide float('inf') for no periodicity.

    Returns:
        np.ndarray: Element-wise periodic distance between `a` and `b`.

    """
    # Fold the raw separation into [0, periodicity) first. Without this,
    # points more than one period apart made `periodicity - d` negative and
    # the function returned a negative "distance". For an infinite
    # periodicity the modulo leaves the non-negative separation untouched.
    d = np.mod(np.abs(b - a), periodicity)
    return np.minimum(d, periodicity - d)

82 

83 

def periodic_distance(a, b, periodicity=2 * pi):
    """Calculates distance between two points and respects periodicity.

    If the provided dataset is periodic (i.e. angles and torsion angles), the
    returned distance is corrected: it is the shorter of the two ways around
    the periodic interval and is never negative.

    Args:
        a (tf.Tensor): Coordinate of point a.
        b (tf.Tensor): Coordinate of point b.
        periodicity (float, optional): The periodicity (i.e. the box length/ maximum angle)
            of your data. Defaults to 2*pi. Provide float('inf') for no periodicity.

    Returns:
        tf.Tensor: Element-wise periodic distance between `a` and `b`.

    Example:
        >>> import encodermap as em
        >>> x = tf.convert_to_tensor(np.array([[1.5], [1.5]]))
        >>> y = tf.convert_to_tensor(np.array([[-3.1], [-3.1]]))
        >>> r = em.misc.periodic_distance(x, y)
        >>> print(r.numpy())
        [[1.68318531]
         [1.68318531]]

    """
    # Fold the raw separation into [0, periodicity) first. Without this,
    # points more than one period apart made `periodicity - d` negative and
    # the function returned a negative "distance". Mirrors
    # `periodic_distance_np`.
    d = tf.math.floormod(tf.abs(b - a), periodicity)
    return tf.minimum(d, periodicity - d)

108 

109 

def pairwise_dist_periodic(positions, periodicity):
    """Pairwise distances using periodicity.

    Every point is compared with every other point via `periodic_distance`
    and the per-dimension differences are reduced with a norm over the last
    axis.

    Args:
        positions (Union[np.ndarray, tf.Tensor]): The positions of the points.
            Currently only 2D arrays with positions.shape[0] == n_points
            and positions.shape[1] == 1 (rotational values) is supported.
        periodicity (float): The periodicity of the data. Most often
            you will use either 2*pi or 360.

    Returns:
        tf.Tensor: Tensor of shape (n_points, n_points). Entries whose
            periodic difference is exactly zero (e.g. the diagonal) carry a
            tiny positive epsilon instead of an exact 0.

    """
    assert len(positions.shape) == 2
    if not tf.debugging.is_numeric_tensor(positions):
        positions = tf.convert_to_tensor(positions)
    # Broadcasting (n, 1, d) against (1, n, d) yields all pairs: (n, n, d).
    vecs = periodic_distance(
        tf.expand_dims(positions, axis=1),
        tf.expand_dims(positions, axis=0),
        periodicity,
    )
    # The mask must share the dtype of `vecs`: a hard-coded float32 mask
    # cannot be added to float64 data (the default for NumPy input).
    mask = tf.cast(tf.equal(vecs, 0.0), vecs.dtype)
    vecs = vecs + mask * 1e-16  # gradient infinite for 0
    dists = tf.norm(vecs, axis=2)
    return dists

133 

134 

def pairwise_dist(positions, squared=False, flat=False):
    """Tensorflow implementation of `scipy.spatial.distance.cdist`.

    Computes the euclidean distance matrix of the provided positions. The
    matrix is hollow, i.e. the diagonal elements are zero.

    Args:
        positions (Union[np.ndarray, tf.Tensor]): Collection of
            n-dimensional points. For 2D input, positions.shape[0] are
            points and positions.shape[1] are dimensions; such input is
            treated as a batch of size one. 3D input is interpreted as
            (batch, points, dimensions).
        squared (bool, optional): Whether to return the pairwise squared
            euclidean distance matrix or normal euclidean distance matrix.
            Defaults to False.
        flat (bool, optional): Whether to return only the strict upper
            triangle of the hollow matrix, flattened row by row. Setting
            this to true mimics the behavior of
            `scipy.spatial.distance.pdist`. Defaults to False.

    Returns:
        tf.Tensor: Tensor of shape (batch, n_points, n_points), or
            (batch, n_points * (n_points - 1) / 2) if `flat` is True. The
            batch axis has length 1 for 2D input.

    """
    # thanks to https://omoindrot.github.io/triplet-loss

    if not tf.debugging.is_numeric_tensor(positions):
        positions = tf.convert_to_tensor(positions)
    if len(positions.get_shape()) == 2:
        positions = tf.expand_dims(positions, 0)

    # Dot product between all pairs of points.
    # shape (batch, n_points, n_points)
    dot_product = tf.matmul(positions, tf.transpose(positions, [0, 2, 1]))

    # Squared L2 norm of each point. We can just take the diagonal of
    # `dot_product`. This also provides more numerical stability (the
    # diagonal of the result will be exactly 0).
    # shape (batch, n_points)
    square_norm = tf.linalg.diag_part(dot_product)

    # Compute the pairwise distance matrix as we have:
    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    # shape (batch, n_points, n_points)
    distances = (
        tf.expand_dims(square_norm, 1)
        - 2.0 * dot_product
        + tf.expand_dims(square_norm, 2)
    )

    # Because of computation errors, some distances might be negative so we
    # put everything >= 0.0
    distances = tf.maximum(distances, 0.0)

    if flat:
        n = int(positions.shape[1])
        # Drop the diagonal and everything below it; what remains is the
        # strict upper triangle.
        keep_upper = np.ones((n, n), dtype=bool)
        keep_upper[np.tril_indices(n)] = False
        distances = tf.boolean_mask(distances, keep_upper, axis=1)

    if not squared:
        # Because the gradient of sqrt is infinite when distances == 0.0
        # (ex: on the diagonal) we need to add a small epsilon where
        # distances == 0.0. The mask must share the dtype of `distances`:
        # a hard-coded float32 mask cannot be combined with float64 data
        # (the default for NumPy input).
        zero_mask = tf.cast(tf.equal(distances, 0.0), distances.dtype)
        distances = distances + zero_mask * 1e-16

        distances = tf.sqrt(distances)

        # Correct the epsilon added: set the distances on the mask to be
        # exactly 0.0
        distances = distances * (1.0 - zero_mask)

    return distances