Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2015-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "Function.h"
23 : #include "ActionRegister.h"
24 :
25 : namespace PLMD {
26 : namespace function {
27 :
28 : //+PLUMEDOC FUNCTION STATS
29 : /*
30 : Calculates statistical properties of a set of collective variables with respect to a set of reference values.
31 :
32 : In particular it calculates and stores as components the sum of the squared deviations, the correlation, the
33 : slope and the intercept of a linear fit.
34 :
35 : The reference values can be provided either as numbers using PARAMETERS or as values without derivatives
36 : taken from other actions using PARARG (for example experimental values from collective variables such as
37 : \ref CS2BACKBONE, \ref RDC, \ref NOE, \ref PRE).
38 :
39 : \par Examples
40 :
41 : The following input tells plumed to print the distances between three pairs of atoms
42 : and compare them with three reference distances.
43 :
44 : \plumedfile
45 : d1: DISTANCE ATOMS=10,50
46 : d2: DISTANCE ATOMS=1,100
47 : d3: DISTANCE ATOMS=45,75
48 : st: STATS ARG=d1,d2,d3 PARAMETERS=1.5,4.0,2.0
49 : PRINT ARG=d1,d2,d3,st.*
50 : \endplumedfile
51 :
52 : */
53 : //+ENDPLUMEDOC
54 :
55 :
56 : class Stats :
57 : public Function
58 : {
59 : std::vector<double> parameters;
60 : bool sqdonly;
61 : bool components;
62 : bool upperd;
63 : public:
64 : explicit Stats(const ActionOptions&);
65 : void calculate() override;
66 : static void registerKeywords(Keywords& keys);
67 : };
68 :
69 :
70 10481 : PLUMED_REGISTER_ACTION(Stats,"STATS")
71 :
72 32 : void Stats::registerKeywords(Keywords& keys) {
73 32 : Function::registerKeywords(keys);
74 32 : keys.use("ARG");
75 64 : keys.add("optional","PARARG","the input for this action is the scalar output from one or more other actions without derivatives.");
76 64 : keys.add("optional","PARAMETERS","the parameters of the arguments in your function");
77 64 : keys.addFlag("SQDEVSUM",false,"calculates only SQDEVSUM");
78 64 : keys.addFlag("SQDEV",false,"calculates and store the SQDEV as components");
79 64 : keys.addFlag("UPPERDISTS",false,"calculates and store the SQDEV as components");
80 64 : keys.addOutputComponent("sqdevsum","default","the sum of the squared deviations between arguments and parameters");
81 64 : keys.addOutputComponent("corr","default","the correlation between arguments and parameters");
82 64 : keys.addOutputComponent("slope","default","the slope of a linear fit between arguments and parameters");
83 64 : keys.addOutputComponent("intercept","default","the intercept of a linear fit between arguments and parameters");
84 64 : keys.addOutputComponent("sqd","SQDEV","the squared deviations between arguments and parameters");
85 32 : }
86 :
87 31 : Stats::Stats(const ActionOptions&ao):
88 : Action(ao),
89 : Function(ao),
90 31 : sqdonly(false),
91 31 : components(false),
92 31 : upperd(false)
93 : {
94 62 : parseVector("PARAMETERS",parameters);
95 31 : if(parameters.size()!=static_cast<unsigned>(getNumberOfArguments())&&!parameters.empty())
96 0 : error("Size of PARAMETERS array should be either 0 or the same as of the number of arguments in ARG1");
97 :
98 : std::vector<Value*> arg2;
99 62 : parseArgumentList("PARARG",arg2);
100 :
101 31 : if(!arg2.empty()) {
102 14 : if(parameters.size()>0) error("It is not possible to use PARARG and PARAMETERS together");
103 14 : if(arg2.size()!=getNumberOfArguments()) error("Size of PARARG array should be the same as number for arguments in ARG");
104 5912 : for(unsigned i=0; i<arg2.size(); i++) {
105 5898 : parameters.push_back(arg2[i]->get());
106 5898 : if(arg2[i]->hasDerivatives()==true) error("PARARG can only accept arguments without derivatives");
107 : }
108 : }
109 :
110 31 : if(parameters.size()!=getNumberOfArguments())
111 0 : error("PARARG or PARAMETERS arrays should include the same number of elements as the arguments in ARG");
112 :
113 31 : if(getNumberOfArguments()<2) error("STATS need at least two arguments to be used");
114 :
115 31 : parseFlag("SQDEVSUM",sqdonly);
116 31 : parseFlag("SQDEV",components);
117 31 : parseFlag("UPPERDISTS",upperd);
118 :
119 31 : if(sqdonly&&components) error("You cannot used SQDEVSUM and SQDEV at the sametime");
120 :
121 31 : if(components) sqdonly = true;
122 :
123 31 : if(!arg2.empty()) log.printf(" using %zu parameters from inactive actions:", arg2.size());
124 17 : else log.printf(" using %zu parameters:", arg2.size());
125 6000 : for(unsigned i=0; i<parameters.size(); i++) log.printf(" %f",parameters[i]);
126 31 : log.printf("\n");
127 :
128 31 : if(sqdonly) {
129 17 : if(components) {
130 60 : for(unsigned i=0; i<parameters.size(); i++) {
131 48 : std::string num; Tools::convert(i,num);
132 48 : addComponentWithDerivatives("sqd-"+num);
133 96 : componentIsNotPeriodic("sqd-"+num);
134 : }
135 : } else {
136 5 : addComponentWithDerivatives("sqdevsum");
137 10 : componentIsNotPeriodic("sqdevsum");
138 : }
139 : } else {
140 14 : addComponentWithDerivatives("sqdevsum");
141 14 : componentIsNotPeriodic("sqdevsum");
142 14 : addComponentWithDerivatives("corr");
143 14 : componentIsNotPeriodic("corr");
144 14 : addComponentWithDerivatives("slope");
145 14 : componentIsNotPeriodic("slope");
146 14 : addComponentWithDerivatives("intercept");
147 28 : componentIsNotPeriodic("intercept");
148 : }
149 :
150 31 : checkRead();
151 31 : }
152 :
153 122 : void Stats::calculate()
154 : {
155 122 : if(sqdonly) {
156 :
157 : double nsqd = 0.;
158 : Value* val;
159 106 : if(!components) val=getPntrToComponent("sqdevsum");
160 174 : for(unsigned i=0; i<parameters.size(); ++i) {
161 121 : double dev = getArgument(i)-parameters[i];
162 121 : if(upperd&&dev<0) dev=0.;
163 121 : if(components) {
164 0 : val=getPntrToComponent(i);
165 0 : val->set(dev*dev);
166 : } else {
167 121 : nsqd += dev*dev;
168 : }
169 121 : setDerivative(val,i,2.*dev);
170 : }
171 53 : if(!components) val->set(nsqd);
172 :
173 : } else {
174 :
175 : double scx=0., scx2=0., scy=0., scy2=0., scxy=0.;
176 :
177 6230 : for(unsigned i=0; i<parameters.size(); ++i) {
178 : const double tmpx=getArgument(i);
179 6161 : const double tmpy=parameters[i];
180 6161 : scx += tmpx;
181 6161 : scx2 += tmpx*tmpx;
182 6161 : scy += tmpy;
183 6161 : scy2 += tmpy*tmpy;
184 6161 : scxy += tmpx*tmpy;
185 : }
186 :
187 69 : const double ns = parameters.size();
188 :
189 69 : const double num = ns*scxy - scx*scy;
190 69 : const double idev2x = 1./(ns*scx2-scx*scx);
191 69 : const double idevx = std::sqrt(idev2x);
192 69 : const double idevy = 1./std::sqrt(ns*scy2-scy*scy);
193 :
194 : /* sd */
195 69 : const double nsqd = scx2 + scy2 - 2.*scxy;
196 : /* correlation */
197 69 : const double correlation = num * idevx * idevy;
198 : /* slope and intercept */
199 69 : const double slope = num * idev2x;
200 69 : const double inter = (scy - slope * scx)/ns;
201 :
202 69 : Value* valuea=getPntrToComponent("sqdevsum");
203 69 : Value* valueb=getPntrToComponent("corr");
204 69 : Value* valuec=getPntrToComponent("slope");
205 138 : Value* valued=getPntrToComponent("intercept");
206 :
207 : valuea->set(nsqd);
208 : valueb->set(correlation);
209 : valuec->set(slope);
210 : valued->set(inter);
211 :
212 : /* derivatives */
213 6230 : for(unsigned i=0; i<parameters.size(); ++i) {
214 6161 : const double common_d1 = (ns*parameters[i]-scy)*idevx;
215 6161 : const double common_d2 = num*(ns*getArgument(i)-scx)*idev2x*idevx;
216 6161 : const double common_d3 = common_d1 - common_d2;
217 :
218 : /* sqdevsum */
219 6161 : const double sq_der = 2.*(getArgument(i)-parameters[i]);
220 : /* correlation */
221 6161 : const double co_der = common_d3*idevy;
222 : /* slope */
223 6161 : const double sl_der = (common_d1-2.*common_d2)*idevx;
224 : /* intercept */
225 6161 : const double int_der = -(slope+ scx*sl_der)/ns;
226 :
227 : setDerivative(valuea,i,sq_der);
228 : setDerivative(valueb,i,co_der);
229 : setDerivative(valuec,i,sl_der);
230 : setDerivative(valued,i,int_der);
231 : }
232 :
233 : }
234 122 : }
235 :
236 : }
237 : }
238 :
239 :
|