Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2015-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "Function.h"
23 : #include "core/ActionRegister.h"
24 :
25 : namespace PLMD {
26 : namespace function {
27 :
28 : //+PLUMEDOC FUNCTION STATS
29 : /*
30 : Calculates statistical properties of a set of collective variables with respect to a set of reference values.
31 :
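In particular, it calculates and stores as components the sum of the squared deviations, the correlation, and the
slope and intercept of a linear fit of the reference values as a function of the arguments.
When SQDEVSUM is used only the sum of the squared deviations is computed, while with SQDEV the individual
squared deviations are stored as the components sqd-0, sqd-1, and so on. With UPPERDISTS, used together with
SQDEVSUM or SQDEV, deviations where an argument is smaller than its reference value are set to zero.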
34 :
The reference values can be provided either directly with PARAMETERS or as values without derivatives
taken from other actions with PARARG (for example, experimental values from collective variables such as
[CS2BACKBONE](CS2BACKBONE.md), [RDC](RDC.md), [NOE](NOE.md), [PRE](PRE.md)).
38 :
39 : \par Examples
40 :
The following input tells plumed to print the distances between three pairs of atoms
and compare them with three reference distances.
43 :
44 : ```plumed
45 : d1: DISTANCE ATOMS=10,50
46 : d2: DISTANCE ATOMS=1,100
47 : d3: DISTANCE ATOMS=45,75
48 : st: STATS ARG=d1,d2,d3 PARAMETERS=1.5,4.0,2.0
49 : PRINT ARG=d1,d2,d3,st.*
50 : ```
51 :
52 : */
53 : //+ENDPLUMEDOC
54 :
55 :
56 : class Stats :
57 : public Function {
58 : std::vector<double> parameters;
59 : bool sqdonly;
60 : bool components;
61 : bool upperd;
62 : public:
63 : explicit Stats(const ActionOptions&);
64 : void calculate() override;
65 : static void registerKeywords(Keywords& keys);
66 : };
67 :
68 :
// Make the Stats action available from the input file under the STATS keyword.
PLUMED_REGISTER_ACTION(Stats,"STATS")
70 :
71 33 : void Stats::registerKeywords(Keywords& keys) {
72 33 : Function::registerKeywords(keys);
73 66 : keys.addInputKeyword("optional","PARARG","scalar","the input for this action is the scalar output from one or more other actions without derivatives.");
74 33 : keys.add("optional","PARAMETERS","the parameters of the arguments in your function");
75 33 : keys.addFlag("SQDEVSUM",false,"calculates only SQDEVSUM");
76 33 : keys.addFlag("SQDEV",false,"calculates and store the SQDEV as components");
77 33 : keys.addFlag("UPPERDISTS",false,"calculates and store the SQDEV as components");
78 66 : keys.addOutputComponent("sqdevsum","default","scalar","the sum of the squared deviations between arguments and parameters");
79 66 : keys.addOutputComponent("corr","default","scalar","the correlation between arguments and parameters");
80 66 : keys.addOutputComponent("slope","default","scalar","the slope of a linear fit between arguments and parameters");
81 66 : keys.addOutputComponent("intercept","default","scalar","the intercept of a linear fit between arguments and parameters");
82 66 : keys.addOutputComponent("sqd","SQDEV","scalar","the squared deviations between arguments and parameters");
83 33 : }
84 :
85 31 : Stats::Stats(const ActionOptions&ao):
86 : Action(ao),
87 : Function(ao),
88 31 : sqdonly(false),
89 31 : components(false),
90 31 : upperd(false) {
91 62 : parseVector("PARAMETERS",parameters);
92 31 : if(parameters.size()!=static_cast<unsigned>(getNumberOfArguments())&&!parameters.empty()) {
93 0 : error("Size of PARAMETERS array should be either 0 or the same as of the number of arguments in ARG1");
94 : }
95 :
96 : std::vector<Value*> arg2;
97 62 : parseArgumentList("PARARG",arg2);
98 :
99 31 : if(!arg2.empty()) {
100 14 : if(parameters.size()>0) {
101 0 : error("It is not possible to use PARARG and PARAMETERS together");
102 : }
103 14 : if(arg2.size()!=getNumberOfArguments()) {
104 0 : error("Size of PARARG array should be the same as number for arguments in ARG");
105 : }
106 5912 : for(unsigned i=0; i<arg2.size(); i++) {
107 5898 : parameters.push_back(arg2[i]->get());
108 5898 : if(arg2[i]->hasDerivatives()==true) {
109 0 : error("PARARG can only accept arguments without derivatives");
110 : }
111 : }
112 : }
113 :
114 31 : if(parameters.size()!=getNumberOfArguments()) {
115 0 : error("PARARG or PARAMETERS arrays should include the same number of elements as the arguments in ARG");
116 : }
117 :
118 31 : if(getNumberOfArguments()<2) {
119 0 : error("STATS need at least two arguments to be used");
120 : }
121 :
122 31 : parseFlag("SQDEVSUM",sqdonly);
123 31 : parseFlag("SQDEV",components);
124 31 : parseFlag("UPPERDISTS",upperd);
125 :
126 31 : if(sqdonly&&components) {
127 0 : error("You cannot used SQDEVSUM and SQDEV at the sametime");
128 : }
129 :
130 31 : if(components) {
131 12 : sqdonly = true;
132 : }
133 :
134 31 : if(!arg2.empty()) {
135 14 : log.printf(" using %zu parameters from inactive actions:", arg2.size());
136 : } else {
137 17 : log.printf(" using %zu parameters:", arg2.size());
138 : }
139 6000 : for(unsigned i=0; i<parameters.size(); i++) {
140 5969 : log.printf(" %f",parameters[i]);
141 : }
142 31 : log.printf("\n");
143 :
144 31 : if(sqdonly) {
145 17 : if(components) {
146 60 : for(unsigned i=0; i<parameters.size(); i++) {
147 : std::string num;
148 48 : Tools::convert(i,num);
149 48 : addComponentWithDerivatives("sqd-"+num);
150 96 : componentIsNotPeriodic("sqd-"+num);
151 : }
152 : } else {
153 5 : addComponentWithDerivatives("sqdevsum");
154 10 : componentIsNotPeriodic("sqdevsum");
155 : }
156 : } else {
157 14 : addComponentWithDerivatives("sqdevsum");
158 14 : componentIsNotPeriodic("sqdevsum");
159 14 : addComponentWithDerivatives("corr");
160 14 : componentIsNotPeriodic("corr");
161 14 : addComponentWithDerivatives("slope");
162 14 : componentIsNotPeriodic("slope");
163 14 : addComponentWithDerivatives("intercept");
164 28 : componentIsNotPeriodic("intercept");
165 : }
166 :
167 31 : checkRead();
168 31 : }
169 :
170 122 : void Stats::calculate() {
171 122 : if(sqdonly) {
172 :
173 : double nsqd = 0.;
174 : Value* val;
175 53 : if(!components) {
176 106 : val=getPntrToComponent("sqdevsum");
177 : }
178 174 : for(unsigned i=0; i<parameters.size(); ++i) {
179 121 : double dev = getArgument(i)-parameters[i];
180 121 : if(upperd&&dev<0) {
181 : dev=0.;
182 : }
183 121 : if(components) {
184 0 : val=getPntrToComponent(i);
185 0 : val->set(dev*dev);
186 : } else {
187 121 : nsqd += dev*dev;
188 : }
189 121 : setDerivative(val,i,2.*dev);
190 : }
191 53 : if(!components) {
192 : val->set(nsqd);
193 : }
194 :
195 : } else {
196 :
197 : double scx=0., scx2=0., scy=0., scy2=0., scxy=0.;
198 :
199 6230 : for(unsigned i=0; i<parameters.size(); ++i) {
200 6161 : const double tmpx=getArgument(i);
201 6161 : const double tmpy=parameters[i];
202 6161 : scx += tmpx;
203 6161 : scx2 += tmpx*tmpx;
204 6161 : scy += tmpy;
205 6161 : scy2 += tmpy*tmpy;
206 6161 : scxy += tmpx*tmpy;
207 : }
208 :
209 69 : const double ns = parameters.size();
210 :
211 69 : const double num = ns*scxy - scx*scy;
212 69 : const double idev2x = 1./(ns*scx2-scx*scx);
213 69 : const double idevx = std::sqrt(idev2x);
214 69 : const double idevy = 1./std::sqrt(ns*scy2-scy*scy);
215 :
216 : /* sd */
217 69 : const double nsqd = scx2 + scy2 - 2.*scxy;
218 : /* correlation */
219 69 : const double correlation = num * idevx * idevy;
220 : /* slope and intercept */
221 69 : const double slope = num * idev2x;
222 69 : const double inter = (scy - slope * scx)/ns;
223 :
224 69 : Value* valuea=getPntrToComponent("sqdevsum");
225 69 : Value* valueb=getPntrToComponent("corr");
226 69 : Value* valuec=getPntrToComponent("slope");
227 138 : Value* valued=getPntrToComponent("intercept");
228 :
229 : valuea->set(nsqd);
230 : valueb->set(correlation);
231 : valuec->set(slope);
232 : valued->set(inter);
233 :
234 : /* derivatives */
235 6230 : for(unsigned i=0; i<parameters.size(); ++i) {
236 6161 : const double common_d1 = (ns*parameters[i]-scy)*idevx;
237 6161 : const double common_d2 = num*(ns*getArgument(i)-scx)*idev2x*idevx;
238 6161 : const double common_d3 = common_d1 - common_d2;
239 :
240 : /* sqdevsum */
241 6161 : const double sq_der = 2.*(getArgument(i)-parameters[i]);
242 : /* correlation */
243 6161 : const double co_der = common_d3*idevy;
244 : /* slope */
245 6161 : const double sl_der = (common_d1-2.*common_d2)*idevx;
246 : /* intercept */
247 6161 : const double int_der = -(slope+ scx*sl_der)/ns;
248 :
249 : setDerivative(valuea,i,sq_der);
250 6161 : setDerivative(valueb,i,co_der);
251 6161 : setDerivative(valuec,i,sl_der);
252 6161 : setDerivative(valued,i,int_der);
253 : }
254 :
255 : }
256 122 : }
257 :
258 : }
259 : }
260 :
261 :
|