Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2016-2018 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionRegister.h"
23 : #include "core/PlumedMain.h"
24 : #include "core/ActionSet.h"
25 : #include "core/ActionShortcut.h"
26 : #include "core/ActionWithValue.h"
27 :
28 : //+PLUMEDOC FUNCTION MAHALANOBIS_DISTANCE
29 : /*
30 : Calculate the Mahalanobis distance between two points in CV space
31 :
32 : If we have two $n$-dimensional vectors $u$ and $v$ we can calculate the
33 : [Mahalanobis distance](https://en.wikipedia.org/wiki/Mahalanobis_distance) between the two points as
34 :
35 : $$
36 : d = \sqrt{ \sum_{i=1}^n \sum_{j=1}^n m_{ij} (u_i - v_i)(u_j - v_j) }
37 : $$
38 :
39 : which can be expressed in matrix form as:
40 :
41 : $$
42 : d^2 = (u-v)^T M (u-v)
43 : $$
44 :
45 : The inputs below show examples where this is used to calculate the Mahalanobis distance
46 : between the instantaneous values of some distances and some reference values
47 : for these distances. The inverse covariance values are provided in the constant value with label `m`.
48 : In this first example the input values are vectors:
49 :
50 : ```plumed
51 : m: CONSTANT VALUES=2.45960237E-0001,-1.30615381E-0001,-1.30615381E-0001,2.40239117E-0001 NROWS=2 NCOLS=2
52 : c: CONSTANT VALUES=1,2
53 : d: DISTANCE ATOMS1=1,2 ATOMS2=3,4
54 : dd: MAHALANOBIS_DISTANCE ARG1=c ARG2=d METRIC=m
55 : PRINT ARG=dd FILE=colvar
56 : ```
57 :
58 : while this second example does the same thing but uses scalars in input.
59 :
60 : ```plumed
61 : m: CONSTANT VALUES=2.45960237E-0001,-1.30615381E-0001,-1.30615381E-0001,2.40239117E-0001 NROWS=2 NCOLS=2
62 : c1: CONSTANT VALUE=1
63 : d1: DISTANCE ATOMS=1,2
64 : c2: CONSTANT VALUE=2
65 : d2: DISTANCE ATOMS=3,4
66 : dd: MAHALANOBIS_DISTANCE ARG1=c1,c2 ARG2=d1,d2 METRIC=m
67 : PRINT ARG=dd FILE=colvar
68 : ```
69 :
70 : ## Dealing with periodic variables
71 :
72 : When you are calculating a distance from a reference point you need to be careful when the input variables
73 : are periodic. If you are calculating the distance using the [EUCLIDEAN_DISTANCE](EUCLIDEAN_DISTANCE.md) and
74 : [NORMALIZED_EUCLIDEAN_DISTANCE](NORMALIZED_EUCLIDEAN_DISTANCE.md) commands this is not a problem. The problems are
75 : specific to the Mahalanobis distance command and have been resolved in the papers that are cited below by defining
76 : the following alternative to the Mahalanobis distance:
77 :
78 : $$
79 : d^2 = 2\sum_{i=1}^n m_{ii} \left[ 1 - \cos\left( \frac{2\pi(u_i-v_i)}{P_i} \right) \right] + \sum_{i\ne j} m_{ij} \sin\left( \frac{2\pi(u_i-v_i)}{P_i} \right) \sin\left( \frac{2\pi(u_j-v_j)}{P_j} \right)
80 : $$
81 :
82 : In this expression, $P_i$ indicates the periodicity of the domain for variable $i$. If you would like to compute this
83 : distance with PLUMED you use the `VON_MISSES` flag as shown below:
84 :
85 : ```plumed
86 : m: CONSTANT VALUES=2.45960237E-0001,-1.30615381E-0001,-1.30615381E-0001,2.40239117E-0001 NROWS=2 NCOLS=2
87 : c: CONSTANT VALUES=1,2
88 : d: TORSION ATOMS1=1,2,3,4 ATOMS2=5,6,7,8
89 : dd: MAHALANOBIS_DISTANCE ARG1=c ARG2=d METRIC=m VON_MISSES
90 : PRINT ARG=dd FILE=colvar
91 : ```
92 :
93 : ## Calculating multiple distances
94 :
95 : Suppose that we now have $m$ reference configurations we can define the following $m$ distances
96 : from these reference configurations:
97 :
98 : $$
99 : d_j^2 = (u-v_j)^T M (u-v_j)
100 : $$
101 :
102 : Let's suppose that we put the $m$, $n$-dimensional $(u-v_j)$ vectors in this expression into a
103 : $n\times m$ matrix, $A$, by using the [DISPLACEMENT](DISPLACEMENT.md) command. It is then
104 : straightforward to show that the $d_j^2$ values in the above expression are the diagonal
105 : elements of the matrix product $A^T M A$.
106 :
107 : We can use this idea to calculate multiple MAHALANOBIS_DISTANCE values in the following inputs.
108 : This first example calculates the three distances between the instantaneous values of two torsions
109 : and three reference configurations.
110 :
111 : ```plumed
112 : m: CONSTANT VALUES=2.45960237E-0001,-1.30615381E-0001,-1.30615381E-0001,2.40239117E-0001 NROWS=2 NCOLS=2
113 : ref_psi: CONSTANT VALUES=2.25,1.3,-1.5
114 : ref_phi: CONSTANT VALUES=-1.91,-0.6,2.4
115 :
116 : psi: TORSION ATOMS=1,2,3,4
117 : phi: TORSION ATOMS=13,14,15,16
118 :
119 : dd: MAHALANOBIS_DISTANCE ARG1=psi,phi ARG2=ref_psi,ref_phi METRIC=m
120 : PRINT ARG=dd FILE=colvar
121 : ```
122 :
123 : This second example calculates the three distances between a single reference value for the two
124 : torsions and three instances of this pair of torsions.
125 :
126 : ```plumed
127 : m: CONSTANT VALUES=2.45960237E-0001,-1.30615381E-0001,-1.30615381E-0001,2.40239117E-0001 NROWS=2 NCOLS=2
128 : ref_psi: CONSTANT VALUES=2.25
129 : ref_phi: CONSTANT VALUES=-1.91
130 :
131 : psi: TORSION ATOMS1=1,2,3,4 ATOMS2=5,6,7,8 ATOMS3=9,10,11,12
132 : phi: TORSION ATOMS1=13,14,15,16 ATOMS2=17,18,19,20 ATOMS3=21,22,23,24
133 :
134 : dd: MAHALANOBIS_DISTANCE ARG1=psi,phi ARG2=ref_psi,ref_phi METRIC=m
135 : PRINT ARG=dd FILE=colvar
136 : ```
137 :
138 : This final example then computes three distances between three pairs of torsional angles and three
139 : reference values for these three values.
140 :
141 : ```plumed
142 : m: CONSTANT VALUES=2.45960237E-0001,-1.30615381E-0001,-1.30615381E-0001,2.40239117E-0001 NROWS=2 NCOLS=2
143 : ref_psi: CONSTANT VALUES=2.25,1.3,-1.5
144 : ref_phi: CONSTANT VALUES=-1.91,-0.6,2.4
145 :
146 : psi: TORSION ATOMS1=1,2,3,4 ATOMS2=5,6,7,8 ATOMS3=9,10,11,12
147 : phi: TORSION ATOMS1=13,14,15,16 ATOMS2=17,18,19,20 ATOMS3=21,22,23,24
148 :
149 : dd: MAHALANOBIS_DISTANCE ARG1=psi,phi ARG2=ref_psi,ref_phi METRIC=m
150 : PRINT ARG=dd FILE=colvar
151 : ```
152 :
153 : Notice, finally, that you can also calculate multiple distances if you use the `VON_MISSES` option:
154 :
155 : ```plumed
156 : m: CONSTANT VALUES=2.45960237E-0001,-1.30615381E-0001,-1.30615381E-0001,2.40239117E-0001 NROWS=2 NCOLS=2
157 : ref_psi: CONSTANT VALUES=2.25,1.3,-1.5
158 : ref_phi: CONSTANT VALUES=-1.91,-0.6,2.4
159 :
160 : psi: TORSION ATOMS1=1,2,3,4 ATOMS2=5,6,7,8 ATOMS3=9,10,11,12
161 : phi: TORSION ATOMS1=13,14,15,16 ATOMS2=17,18,19,20 ATOMS3=21,22,23,24
162 :
163 : dd: MAHALANOBIS_DISTANCE ARG1=psi,phi ARG2=ref_psi,ref_phi METRIC=m VON_MISSES
164 : PRINT ARG=dd FILE=colvar
165 : ```
166 :
167 : */
168 : //+ENDPLUMEDOC
169 :
170 : namespace PLMD {
171 : namespace refdist {
172 :
173 : class MahalanobisDistance : public ActionShortcut {
174 : public:
175 : static void registerKeywords( Keywords& keys );
176 : explicit MahalanobisDistance(const ActionOptions&ao);
177 : };
178 :
179 : PLUMED_REGISTER_ACTION(MahalanobisDistance,"MAHALANOBIS_DISTANCE")
180 :
181 21 : void MahalanobisDistance::registerKeywords( Keywords& keys ) {
182 21 : ActionShortcut::registerKeywords(keys);
183 21 : keys.add("compulsory","ARG1","The point that we are calculating the distance from");
184 21 : keys.add("compulsory","ARG2","The point that we are calculating the distance to");
185 21 : keys.add("compulsory","METRIC","The inverse covariance matrix that should be used when calculating the distance");
186 21 : keys.addFlag("SQUARED",false,"The squared distance should be calculated");
187 21 : keys.addFlag("VON_MISSES",false,"Compute the mahalanobis distance in a way that is more sympathetic to the periodic boundary conditions");
188 42 : keys.setValueDescription("scalar/vector","the Mahalanobis distances between the input vectors");
189 21 : keys.needsAction("DISPLACEMENT");
190 21 : keys.needsAction("CUSTOM");
191 21 : keys.needsAction("OUTER_PRODUCT");
192 21 : keys.needsAction("TRANSPOSE");
193 21 : keys.needsAction("MATRIX_PRODUCT_DIAGONAL");
194 21 : keys.needsAction("CONSTANT");
195 21 : keys.needsAction("MATRIX_VECTOR_PRODUCT");
196 21 : keys.needsAction("MATRIX_PRODUCT");
197 21 : keys.needsAction("COMBINE");
198 21 : keys.addDOI("10.1073/pnas.1011511107");
199 21 : keys.addDOI("10.1021/acs.jctc.7b00993");
200 21 : }
201 :
202 12 : MahalanobisDistance::MahalanobisDistance( const ActionOptions& ao):
203 : Action(ao),
204 12 : ActionShortcut(ao) {
205 : std::string arg1, arg2, metstr;
206 12 : parse("ARG1",arg1);
207 12 : parse("ARG2",arg2);
208 12 : parse("METRIC",metstr);
209 : // Check on input metric
210 12 : ActionWithValue* mav=plumed.getActionSet().selectWithLabel<ActionWithValue*>( metstr );
211 12 : if( !mav ) {
212 0 : error("could not find action named " + metstr + " to use for metric");
213 : }
214 12 : if( mav->copyOutput(0)->getRank()!=2 ) {
215 0 : error("metric has incorrect rank");
216 : }
217 :
218 24 : readInputLine( getShortcutLabel() + "_diff: DISPLACEMENT ARG1=" + arg1 + " ARG2=" + arg2 );
219 24 : readInputLine( getShortcutLabel() + "_diffT: TRANSPOSE ARG=" + getShortcutLabel() + "_diff");
220 : bool von_miss, squared;
221 12 : parseFlag("VON_MISSES",von_miss);
222 12 : parseFlag("SQUARED",squared);
223 12 : if( von_miss ) {
224 7 : unsigned nrows = mav->copyOutput(0)->getShape()[0];
225 7 : if( mav->copyOutput(0)->getShape()[1]!=nrows ) {
226 0 : error("metric is not symmetric");
227 : }
228 : // Create a matrix that can be used to compute the off diagonal elements
229 : std::string valstr, nrstr;
230 7 : Tools::convert( mav->copyOutput(0)->get(0), valstr );
231 7 : Tools::convert( nrows, nrstr );
232 7 : std::string diagmet = getShortcutLabel() + "_diagmet: CONSTANT VALUES=" + valstr;
233 14 : std::string offdiagmet = getShortcutLabel() + "_offdiagmet: CONSTANT NROWS=" + nrstr + " NCOLS=" + nrstr + " VALUES=0";
234 21 : for(unsigned i=0; i<nrows; ++i) {
235 42 : for(unsigned j=0; j<nrows; ++j) {
236 28 : Tools::convert( mav->copyOutput(0)->get(i*nrows+j), valstr );
237 28 : if( i==j && i>0 ) {
238 : offdiagmet += ",0";
239 14 : diagmet += "," + valstr;
240 21 : } else if( i!=j ) {
241 28 : offdiagmet += "," + valstr;
242 : }
243 : }
244 : }
245 7 : readInputLine( diagmet );
246 7 : readInputLine( offdiagmet );
247 : // Compute distances scaled by periods
248 7 : ActionWithValue* av=plumed.getActionSet().selectWithLabel<ActionWithValue*>( getShortcutLabel() + "_diff" );
249 7 : plumed_assert( av );
250 7 : if( !av->copyOutput(0)->isPeriodic() ) {
251 0 : error("VON_MISSES only works with periodic variables");
252 : }
253 : std::string min, max;
254 7 : av->copyOutput(0)->getDomain(min,max);
255 14 : readInputLine( getShortcutLabel() + "_scaled: CUSTOM ARG=" + getShortcutLabel() + "_diffT FUNC=2*pi*x/(" + max +"-" + min + ") PERIODIC=NO");
256 : // We start calculating off-diagonal elements by computing the sines of the scaled differences (this is a column vector)
257 14 : readInputLine( getShortcutLabel() + "_sinediffT: CUSTOM ARG=" + getShortcutLabel() + "_scaled FUNC=sin(x) PERIODIC=NO");
258 : // Transpose sines to get a row vector
259 14 : readInputLine( getShortcutLabel() + "_sinediff: TRANSPOSE ARG=" + getShortcutLabel() + "_sinediffT");
260 : // Compute the off diagonal elements
261 7 : ActionWithValue* avs=plumed.getActionSet().selectWithLabel<ActionWithValue*>( getShortcutLabel() + "_sinediffT" );
262 7 : plumed_assert( avs && avs->getNumberOfComponents()==1 );
263 7 : if( (avs->copyOutput(0))->getRank()==1 ) {
264 0 : readInputLine( getShortcutLabel() + "_matvec: MATRIX_VECTOR_PRODUCT ARG=" + metstr + "," + getShortcutLabel() +"_sinediffT");
265 : } else {
266 14 : readInputLine( getShortcutLabel() + "_matvec: MATRIX_PRODUCT ARG=" + getShortcutLabel() + "_offdiagmet," + getShortcutLabel() +"_sinediffT");
267 : }
268 14 : readInputLine( getShortcutLabel() + "_offdiag: MATRIX_PRODUCT_DIAGONAL ARG=" + getShortcutLabel() + "_sinediff," + getShortcutLabel() +"_matvec");
269 : // Sort out the metric for the diagonal elements
270 7 : std::string metstr2 = getShortcutLabel() + "_diagmet";
271 : // If this is a matrix we need create a matrix to multiply by
272 7 : if( av->copyOutput(0)->getShape()[0]>1 ) {
273 : // Create some ones
274 7 : std::string ones=" VALUES=1";
275 21 : for(unsigned i=1; i<av->copyOutput(0)->getShape()[0]; ++i ) {
276 : ones += ",1";
277 : }
278 14 : readInputLine( getShortcutLabel() + "_ones: CONSTANT " + ones );
279 : // Now do some multiplication to create a matrix that can be multiplied by our "inverse variance" vector
280 14 : readInputLine( getShortcutLabel() + "_" + metstr + ": OUTER_PRODUCT ARG=" + metstr2 + "," + getShortcutLabel() + "_ones");
281 14 : metstr2 = getShortcutLabel() + "_" + metstr;
282 : }
283 : // Compute the diagonal elements
284 14 : readInputLine( getShortcutLabel() + "_prod: CUSTOM ARG=" + getShortcutLabel() + "_scaled," + metstr2 + " FUNC=2*(1-cos(x))*y PERIODIC=NO");
285 : std::string ncstr;
286 7 : Tools::convert( nrows, ncstr );
287 7 : Tools::convert( av->copyOutput(0)->getShape()[0], nrstr );
288 7 : std::string ones=" VALUES=1";
289 42 : for(unsigned i=1; i<av->copyOutput(0)->getNumberOfValues(); ++i) {
290 : ones += ",1";
291 : }
292 14 : readInputLine( getShortcutLabel() + "_matones: CONSTANT NROWS=" + nrstr + " NCOLS=" + ncstr + ones );
293 14 : readInputLine( getShortcutLabel() + "_diag: MATRIX_PRODUCT_DIAGONAL ARG=" + getShortcutLabel() + "_matones," + getShortcutLabel() + "_prod");
294 : // Sum everything
295 7 : if( !squared ) {
296 0 : readInputLine( getShortcutLabel() + "_2: COMBINE ARG=" + getShortcutLabel() + "_offdiag," + getShortcutLabel() + "_diag PERIODIC=NO");
297 : } else {
298 14 : readInputLine( getShortcutLabel() + ": COMBINE ARG=" + getShortcutLabel() + "_offdiag," + getShortcutLabel() + "_diag PERIODIC=NO");
299 : }
300 : } else {
301 5 : ActionWithValue* av=plumed.getActionSet().selectWithLabel<ActionWithValue*>( getShortcutLabel() + "_diffT" );
302 5 : plumed_assert( av && av->getNumberOfComponents()==1 );
303 5 : if( (av->copyOutput(0))->getRank()==1 ) {
304 8 : readInputLine( getShortcutLabel() + "_matvec: MATRIX_VECTOR_PRODUCT ARG=" + metstr + "," + getShortcutLabel() +"_diffT");
305 : } else {
306 2 : readInputLine( getShortcutLabel() + "_matvec: MATRIX_PRODUCT ARG=" + metstr + "," + getShortcutLabel() +"_diffT");
307 : }
308 5 : std::string olab = getShortcutLabel();
309 5 : if( !squared ) {
310 : olab += "_2";
311 : }
312 10 : readInputLine( olab + ": MATRIX_PRODUCT_DIAGONAL ARG=" + getShortcutLabel() + "_diff," + getShortcutLabel() +"_matvec");
313 : }
314 12 : if( !squared ) {
315 10 : readInputLine( getShortcutLabel() + ": CUSTOM ARG=" + getShortcutLabel() + "_2 FUNC=sqrt(x) PERIODIC=NO");
316 : }
317 12 : }
318 :
319 : }
320 : }
|