Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2015-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "Function.h"
23 : #include "core/ActionRegister.h"
24 :
25 : namespace PLMD {
26 : namespace function {
27 :
28 : //+PLUMEDOC FUNCTION STATS
29 : /*
30 : Calculates statistical properties of a set of collective variables with respect to a set of reference values.
31 :
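In particular it calculates and stores as components the sum of the squared deviations, the correlation, the
slope and the intercept of a linear fit. With SQDEVSUM only the sum of the squared deviations is computed,
while with SQDEV each squared deviation is stored as a separate component (sqd-0, sqd-1, ...).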
34 :
The reference values can be provided either as numbers using PARAMETERS or as values without derivatives
taken from other actions using PARARG (for example experimental values from collective variables such as
\ref CS2BACKBONE, \ref RDC, \ref NOE, \ref PRE).
38 :
39 : \par Examples
40 :
The following input tells plumed to print the distances between three pairs of atoms
and compare them with three reference distances.
43 :
44 : \plumedfile
45 : d1: DISTANCE ATOMS=10,50
46 : d2: DISTANCE ATOMS=1,100
47 : d3: DISTANCE ATOMS=45,75
48 : st: STATS ARG=d1,d2,d3 PARAMETERS=1.5,4.0,2.0
49 : PRINT ARG=d1,d2,d3,st.*
50 : \endplumedfile
51 :
52 : */
53 : //+ENDPLUMEDOC
54 :
55 :
56 : class Stats :
57 : public Function
58 : {
59 : std::vector<double> parameters;
60 : bool sqdonly;
61 : bool components;
62 : bool upperd;
63 : public:
64 : explicit Stats(const ActionOptions&);
65 : void calculate() override;
66 : static void registerKeywords(Keywords& keys);
67 : };
68 :
69 :
70 : PLUMED_REGISTER_ACTION(Stats,"STATS")
71 :
72 33 : void Stats::registerKeywords(Keywords& keys) {
73 33 : Function::registerKeywords(keys);
74 66 : keys.addInputKeyword("optional","PARARG","scalar","the input for this action is the scalar output from one or more other actions without derivatives.");
75 66 : keys.add("optional","PARAMETERS","the parameters of the arguments in your function");
76 66 : keys.addFlag("SQDEVSUM",false,"calculates only SQDEVSUM");
77 66 : keys.addFlag("SQDEV",false,"calculates and store the SQDEV as components");
78 66 : keys.addFlag("UPPERDISTS",false,"calculates and store the SQDEV as components");
79 66 : keys.addOutputComponent("sqdevsum","default","scalar","the sum of the squared deviations between arguments and parameters");
80 66 : keys.addOutputComponent("corr","default","scalar","the correlation between arguments and parameters");
81 66 : keys.addOutputComponent("slope","default","scalar","the slope of a linear fit between arguments and parameters");
82 66 : keys.addOutputComponent("intercept","default","scalar","the intercept of a linear fit between arguments and parameters");
83 66 : keys.addOutputComponent("sqd","SQDEV","scalar","the squared deviations between arguments and parameters");
84 33 : }
85 :
86 31 : Stats::Stats(const ActionOptions&ao):
87 : Action(ao),
88 : Function(ao),
89 31 : sqdonly(false),
90 31 : components(false),
91 31 : upperd(false)
92 : {
93 62 : parseVector("PARAMETERS",parameters);
94 31 : if(parameters.size()!=static_cast<unsigned>(getNumberOfArguments())&&!parameters.empty())
95 0 : error("Size of PARAMETERS array should be either 0 or the same as of the number of arguments in ARG1");
96 :
97 : std::vector<Value*> arg2;
98 62 : parseArgumentList("PARARG",arg2);
99 :
100 31 : if(!arg2.empty()) {
101 14 : if(parameters.size()>0) error("It is not possible to use PARARG and PARAMETERS together");
102 14 : if(arg2.size()!=getNumberOfArguments()) error("Size of PARARG array should be the same as number for arguments in ARG");
103 5912 : for(unsigned i=0; i<arg2.size(); i++) {
104 5898 : parameters.push_back(arg2[i]->get());
105 5898 : if(arg2[i]->hasDerivatives()==true) error("PARARG can only accept arguments without derivatives");
106 : }
107 : }
108 :
109 : if(parameters.size()!=getNumberOfArguments())
110 0 : error("PARARG or PARAMETERS arrays should include the same number of elements as the arguments in ARG");
111 :
112 31 : if(getNumberOfArguments()<2) error("STATS need at least two arguments to be used");
113 :
114 31 : parseFlag("SQDEVSUM",sqdonly);
115 31 : parseFlag("SQDEV",components);
116 31 : parseFlag("UPPERDISTS",upperd);
117 :
118 31 : if(sqdonly&&components) error("You cannot used SQDEVSUM and SQDEV at the sametime");
119 :
120 31 : if(components) sqdonly = true;
121 :
122 31 : if(!arg2.empty()) log.printf(" using %zu parameters from inactive actions:", arg2.size());
123 17 : else log.printf(" using %zu parameters:", arg2.size());
124 6000 : for(unsigned i=0; i<parameters.size(); i++) log.printf(" %f",parameters[i]);
125 31 : log.printf("\n");
126 :
127 31 : if(sqdonly) {
128 17 : if(components) {
129 60 : for(unsigned i=0; i<parameters.size(); i++) {
130 48 : std::string num; Tools::convert(i,num);
131 48 : addComponentWithDerivatives("sqd-"+num);
132 96 : componentIsNotPeriodic("sqd-"+num);
133 : }
134 : } else {
135 5 : addComponentWithDerivatives("sqdevsum");
136 10 : componentIsNotPeriodic("sqdevsum");
137 : }
138 : } else {
139 14 : addComponentWithDerivatives("sqdevsum");
140 14 : componentIsNotPeriodic("sqdevsum");
141 14 : addComponentWithDerivatives("corr");
142 14 : componentIsNotPeriodic("corr");
143 14 : addComponentWithDerivatives("slope");
144 14 : componentIsNotPeriodic("slope");
145 14 : addComponentWithDerivatives("intercept");
146 28 : componentIsNotPeriodic("intercept");
147 : }
148 :
149 31 : checkRead();
150 31 : }
151 :
152 122 : void Stats::calculate()
153 : {
154 122 : if(sqdonly) {
155 :
156 : double nsqd = 0.;
157 : Value* val;
158 106 : if(!components) val=getPntrToComponent("sqdevsum");
159 174 : for(unsigned i=0; i<parameters.size(); ++i) {
160 121 : double dev = getArgument(i)-parameters[i];
161 121 : if(upperd&&dev<0) dev=0.;
162 121 : if(components) {
163 0 : val=getPntrToComponent(i);
164 0 : val->set(dev*dev);
165 : } else {
166 121 : nsqd += dev*dev;
167 : }
168 121 : setDerivative(val,i,2.*dev);
169 : }
170 53 : if(!components) val->set(nsqd);
171 :
172 : } else {
173 :
174 : double scx=0., scx2=0., scy=0., scy2=0., scxy=0.;
175 :
176 6230 : for(unsigned i=0; i<parameters.size(); ++i) {
177 6161 : const double tmpx=getArgument(i);
178 6161 : const double tmpy=parameters[i];
179 6161 : scx += tmpx;
180 6161 : scx2 += tmpx*tmpx;
181 6161 : scy += tmpy;
182 6161 : scy2 += tmpy*tmpy;
183 6161 : scxy += tmpx*tmpy;
184 : }
185 :
186 69 : const double ns = parameters.size();
187 :
188 69 : const double num = ns*scxy - scx*scy;
189 69 : const double idev2x = 1./(ns*scx2-scx*scx);
190 69 : const double idevx = std::sqrt(idev2x);
191 69 : const double idevy = 1./std::sqrt(ns*scy2-scy*scy);
192 :
193 : /* sd */
194 69 : const double nsqd = scx2 + scy2 - 2.*scxy;
195 : /* correlation */
196 69 : const double correlation = num * idevx * idevy;
197 : /* slope and intercept */
198 69 : const double slope = num * idev2x;
199 69 : const double inter = (scy - slope * scx)/ns;
200 :
201 69 : Value* valuea=getPntrToComponent("sqdevsum");
202 69 : Value* valueb=getPntrToComponent("corr");
203 69 : Value* valuec=getPntrToComponent("slope");
204 138 : Value* valued=getPntrToComponent("intercept");
205 :
206 : valuea->set(nsqd);
207 : valueb->set(correlation);
208 : valuec->set(slope);
209 : valued->set(inter);
210 :
211 : /* derivatives */
212 6230 : for(unsigned i=0; i<parameters.size(); ++i) {
213 6161 : const double common_d1 = (ns*parameters[i]-scy)*idevx;
214 6161 : const double common_d2 = num*(ns*getArgument(i)-scx)*idev2x*idevx;
215 6161 : const double common_d3 = common_d1 - common_d2;
216 :
217 : /* sqdevsum */
218 6161 : const double sq_der = 2.*(getArgument(i)-parameters[i]);
219 : /* correlation */
220 6161 : const double co_der = common_d3*idevy;
221 : /* slope */
222 6161 : const double sl_der = (common_d1-2.*common_d2)*idevx;
223 : /* intercept */
224 6161 : const double int_der = -(slope+ scx*sl_der)/ns;
225 :
226 : setDerivative(valuea,i,sq_der);
227 6161 : setDerivative(valueb,i,co_der);
228 6161 : setDerivative(valuec,i,sl_der);
229 6161 : setDerivative(valued,i,int_der);
230 : }
231 :
232 : }
233 122 : }
234 :
235 : }
236 : }
237 :
238 :
|