Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2012-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "SwitchingFunction.h"
23 : #include "Tools.h"
24 : #include "Keywords.h"
25 : #include "OpenMP.h"
26 : #include <vector>
27 : #include <limits>
28 : #include <algorithm>
29 : #include <optional>
30 :
31 : /*
32 : IMPORTANT NOTE FOR DEVELOPERS:
33 :
34 : If you add a new type of switching function in this file please add documentation for your new switching function type in function/LessThan.cpp
35 : */
36 :
37 : namespace PLMD {
38 :
39 : namespace switchContainers {
40 :
41 1654 : baseSwitch::baseSwitch(double D0,double DMAX, double R0, std::string_view name)
42 1654 : : d0(D0),
43 1654 : dmax(DMAX),
44 1654 : dmax_2([](const double d) {
45 1654 : if(d<std::sqrt(std::numeric_limits<double>::max())) {
46 244 : return d*d;
47 : } else {
48 : return std::numeric_limits<double>::max();
49 : }
50 : }(dmax)),
51 1654 : invr0(1.0/R0),
52 1654 : invr0_2(invr0*invr0),
53 1898 : mytype(name) {}
54 :
55 1654 : baseSwitch::~baseSwitch()=default;
56 :
57 162833080 : double baseSwitch::calculate(const double distance, double& dfunc) const {
58 : double res = 0.0;//RVO!
59 162833080 : dfunc = 0.0;
60 162833080 : if(distance <= dmax) {
61 : res = 1.0;
62 156015549 : const double rdist = (distance-d0)*invr0;
63 156015549 : if(rdist > 0.0) {
64 59652852 : res = function(rdist,dfunc);
65 : //the following comments came from the original
66 : // this is for the chain rule (derivative of rdist):
67 59652852 : dfunc *= invr0;
68 : // for any future switching functions, be aware that multiplying invr0 is only
69 : // correct for functions of rdist = (r-d0)/r0.
70 :
71 : // this is because calculate() sets dfunc to the derivative divided times the
72 : // distance.
73 : // (I think this is misleading and I would like to modify it - GB)
74 59652852 : dfunc /= distance;
75 : }
76 156015549 : res=res*stretch+shift;
77 156015549 : dfunc*=stretch;
78 : }
79 162833080 : return res;
80 : }
81 :
82 31818564 : double baseSwitch::calculateSqr(double distance2,double&dfunc) const {
83 31818564 : double res= calculate(std::sqrt(distance2),dfunc);//RVO!
84 31818564 : return res;
85 : }
86 8 : double baseSwitch::get_d0() const {
87 8 : return d0;
88 : }
89 1534 : double baseSwitch::get_r0() const {
90 1534 : return 1.0/invr0;
91 : }
92 536580542 : double baseSwitch::get_dmax() const {
93 536580542 : return dmax;
94 : }
95 49030642 : double baseSwitch::get_dmax2() const {
96 49030642 : return dmax_2;
97 : }
98 1502 : std::string baseSwitch::description() const {
99 1502 : std::ostringstream ostr;
100 1502 : ostr<<get_r0()
101 : <<". Using "
102 : << mytype
103 3004 : <<" switching function with parameters d0="<< d0
104 3004 : << specificDescription();
105 1502 : return ostr.str();
106 1502 : }
107 150 : std::string baseSwitch::specificDescription() const {
108 150 : return "";
109 : }
110 216 : void baseSwitch::setupStretch() {
111 216 : if(dmax!=std::numeric_limits<double>::max()) {
112 216 : stretch=1.0;
113 216 : shift=0.0;
114 : double dummy;
115 216 : double s0=calculate(0.0,dummy);
116 216 : double sd=calculate(dmax,dummy);
117 216 : stretch=1.0/(s0-sd);
118 216 : shift=-sd*stretch;
119 : }
120 216 : }
121 0 : void baseSwitch::removeStretch() {
122 0 : stretch=1.0;
123 0 : shift=0.0;
124 0 : }
125 : template<int N, std::enable_if_t< (N >0), bool> = true, std::enable_if_t< (N %2 == 0), bool> = true>
126 : class fixedRational :public baseSwitch {
127 263 : std::string specificDescription() const override {
128 263 : std::ostringstream ostr;
129 263 : ostr << " nn=" << N << " mm=" <<N*2;
130 263 : return ostr.str();
131 263 : }
132 : public:
133 282 : fixedRational(double D0,double DMAX, double R0)
134 282 : :baseSwitch(D0,DMAX,R0,"rational") {}
135 :
136 : template <int POW>
137 1382 : static inline double doRational(const double rdist, double&dfunc, double result=0.0) {
138 : const double rNdist=Tools::fastpow<POW-1>(rdist);
139 27485030 : result=1.0/(1.0+rNdist*rdist);
140 27485030 : dfunc = -POW*rNdist*result*result;
141 1382 : return result;
142 : }
143 :
144 16154932 : inline double function(double rdist,double&dfunc) const override {
145 : //preRes and preDfunc are passed already set
146 1382 : dfunc=0.0;
147 1382 : double result = doRational<N>(rdist,dfunc);
148 16154932 : return result;
149 : }
150 :
151 11475850 : double calculateSqr(double distance2,double&dfunc) const override {
152 : double result=0.0;
153 11475850 : dfunc=0.0;
154 11475850 : if(distance2 <= dmax_2) {
155 11330098 : const double rdist = distance2*invr0_2;
156 : result = doRational<N/2>(rdist,dfunc);
157 11330098 : dfunc*=2*invr0_2;
158 : // stretch:
159 11330098 : result=result*stretch+shift;
160 11330098 : dfunc*=stretch;
161 : }
162 11475850 : return result;
163 :
164 : }
165 : };
166 :
// Named compile-time switches for the `rational` template: each factory
// branch then documents itself.
/// Whether the function can be evaluated directly on the squared distance.
enum class rationalPow:bool {
  standard, ///< go through sqrt(distance2)
  fast      ///< evaluate on distance2 with halved exponents
};
/// Whether mm==2*nn, so that (1-x^nn)/(1-x^mm) reduces to 1/(1+x^nn).
enum class rationalForm:bool {
  standard,
  simplified
};
171 :
172 : template<rationalPow isFast, rationalForm nis2m>
173 : class rational : public baseSwitch {
174 : protected:
175 : const int nn=6;
176 : const int mm=12;
177 : const double preRes;
178 : const double preDfunc;
179 : const double preSecDev;
180 : const int nnf;
181 : const int mmf;
182 : const double preDfuncF;
183 : const double preSecDevF;
184 : //I am using PLMD::epsilon to be certain to call the one defined in Tools.h
185 : static constexpr double moreThanOne=1.0+5.0e10*PLMD::epsilon;
186 : static constexpr double lessThanOne=1.0-5.0e10*PLMD::epsilon;
187 :
188 177 : std::string specificDescription() const override {
189 177 : std::ostringstream ostr;
190 177 : ostr << " nn=" << nn << " mm=" <<mm;
191 177 : return ostr.str();
192 177 : }
193 : public:
194 196 : rational(double D0,double DMAX, double R0, int N, int M)
195 : :baseSwitch(D0,DMAX,R0,"rational"),
196 196 : nn(N),
197 196 : mm([](int m,int n) {
198 196 : if (m==0) {
199 89 : return n*2;
200 : } else {
201 : return m;
202 : }
203 : }(M,N)),
204 196 : preRes(static_cast<double>(nn)/mm),
205 196 : preDfunc(0.5*nn*(nn-mm)/static_cast<double>(mm)),
206 : //wolfram <3:lim_(x->1) d^2/(dx^2) (1 - x^N)/(1 - x^M) = (N (M^2 - 3 M (-1 + N) + N (-3 + 2 N)))/(6 M)
207 196 : preSecDev ((nn * (mm * mm - 3.0* mm * (-1 + nn ) + nn *(-3 + 2* nn )))/(6.0* mm )),
208 196 : nnf(nn/2),
209 196 : mmf(mm/2),
210 196 : preDfuncF(0.5*nnf*(nnf-mmf)/static_cast<double>(mmf)),
211 196 : preSecDevF((nnf* (mmf*mmf - 3.0* mmf* (-1 + nnf) + nnf*(-3 + 2* nnf)))/(6.0* mmf)) {}
212 :
213 18240673 : static inline double doRational(const double rdist, double&dfunc,double secDev, const int N,
214 : const int M,double result=0.0) {
215 : //the result and dfunc are assigned in the drivers for doRational
216 : //if(rdist>(1.0-100.0*epsilon) && rdist<(1.0+100.0*epsilon)) {
217 : //result=preRes;
218 : //dfunc=preDfunc;
219 : //} else {
220 : if constexpr (nis2m==rationalForm::simplified) {
221 2113979 : const double rNdist=Tools::fastpow(rdist,N-1);
222 2113979 : result=1.0/(1.0+rNdist*rdist);
223 2113979 : dfunc = -N*rNdist*result*result;
224 : } else {
225 16126694 : if(!((rdist > lessThanOne) && (rdist < moreThanOne))) {
226 16126682 : const double rNdist=Tools::fastpow(rdist,N-1);
227 16126682 : const double rMdist=Tools::fastpow(rdist,M-1);
228 16126682 : const double num = 1.0-rNdist*rdist;
229 16126682 : const double iden = 1.0/(1.0-rMdist*rdist);
230 16126682 : result = num*iden;
231 16126682 : dfunc = ((M*result*rMdist)-(N*rNdist))*iden;
232 16126682 : } else {
233 : //here I imply that the correct initialized are being passed to doRational
234 12 : const double x =(rdist-1.0);
235 12 : result = result+ x * ( dfunc + 0.5 * x * secDev);
236 12 : dfunc = dfunc + x * secDev;
237 : }
238 : }
239 18240673 : return result;
240 : }
241 18240621 : inline double function(double rdist,double&dfunc) const override {
242 : //preRes and preDfunc are passed already set
243 18240621 : dfunc=preDfunc;
244 18240621 : double result = doRational(rdist,dfunc,preSecDev,nn,mm,preRes);
245 18240621 : return result;
246 : }
247 :
248 3408359 : double calculateSqr(double distance2,double&dfunc) const override {
249 : if constexpr (isFast==rationalPow::fast) {
250 : double result=0.0;
251 60 : dfunc=0.0;
252 60 : if(distance2 <= dmax_2) {
253 52 : const double rdist = distance2*invr0_2;
254 52 : dfunc=preDfuncF;
255 52 : result = doRational(rdist,dfunc,preSecDevF,nnf,mmf,preRes);
256 52 : dfunc*=2*invr0_2;
257 : // stretch:
258 52 : result=result*stretch+shift;
259 52 : dfunc*=stretch;
260 : }
261 60 : return result;
262 : } else {
263 3408299 : double res= calculate(std::sqrt(distance2),dfunc);//RVO!
264 3408299 : return res;
265 : }
266 : }
267 : };
268 :
269 :
270 : template<int EXP,std::enable_if_t< (EXP %2 == 0), bool> = true>
271 1079 : std::optional<std::unique_ptr<baseSwitch>> fixedRationalFactory(double D0,double DMAX, double R0, int N) {
272 : if constexpr (EXP == 0) {
273 0 : return std::nullopt;
274 : } else {
275 1079 : if (N==EXP) {
276 282 : return PLMD::Tools::make_unique<switchContainers::fixedRational<EXP>>(D0,DMAX,R0);
277 : } else {
278 797 : return fixedRationalFactory<EXP-2>(D0,DMAX,R0,N);
279 : }
280 : }
281 : }
282 :
283 : std::unique_ptr<baseSwitch>
284 478 : rationalFactory(double D0,double DMAX, double R0, int N, int M) {
285 478 : bool fast = N%2==0 && M%2==0 && D0==0.0;
286 : //if (M==0) M will automatically became 2*NN
287 : constexpr int highestPrecompiledPower=12;
288 : //precompiled rational
289 478 : if(((2*N)==M || M == 0) && fast && N<=highestPrecompiledPower) {
290 282 : auto tmp = fixedRationalFactory<highestPrecompiledPower>(D0,DMAX,R0,N);
291 282 : if(tmp) {
292 : return std::move(*tmp);
293 : }
294 : //else continue with the at runtime implementation
295 : }
296 : //template<bool isFast, bool n2m>
297 : //class rational : public baseSwitch
298 196 : if(2*N==M || M == 0) {
299 132 : if(fast) {
300 : //fast rational
301 : return PLMD::Tools::make_unique<switchContainers::rational<
302 0 : rationalPow::fast,rationalForm::simplified>>(D0,DMAX,R0,N,M);
303 : }
304 : return PLMD::Tools::make_unique<switchContainers::rational<
305 132 : rationalPow::standard,rationalForm::simplified>>(D0,DMAX,R0,N,M);
306 : }
307 64 : if(fast) {
308 : //fast rational
309 : return PLMD::Tools::make_unique<switchContainers::rational<
310 61 : rationalPow::fast,rationalForm::standard>>(D0,DMAX,R0,N,M);
311 : }
312 : return PLMD::Tools::make_unique<switchContainers::rational<
313 3 : rationalPow::standard,rationalForm::standard>>(D0,DMAX,R0,N,M);
314 : }
315 : //function =
316 :
317 : class exponentialSwitch: public baseSwitch {
318 : public:
319 75 : exponentialSwitch(double D0, double DMAX, double R0)
320 75 : :baseSwitch(D0,DMAX,R0,"exponential") {}
321 : protected:
322 2404247 : inline double function(const double rdist,double&dfunc) const override {
323 2404247 : double result = std::exp(-rdist);
324 2404247 : dfunc=-result;
325 2404247 : return result;
326 : }
327 : };
328 :
329 : class gaussianSwitch: public baseSwitch {
330 : public:
331 66 : gaussianSwitch(double D0, double DMAX, double R0)
332 66 : :baseSwitch(D0,DMAX,R0,"gaussian") {}
333 : protected:
334 279640 : inline double function(const double rdist,double&dfunc) const override {
335 279640 : double result = std::exp(-0.5*rdist*rdist);
336 279640 : dfunc=-rdist*result;
337 279640 : return result;
338 : }
339 : };
340 :
341 : class fastGaussianSwitch: public baseSwitch {
342 : public:
343 114 : fastGaussianSwitch(double /*D0*/, double DMAX, double /*R0*/)
344 114 : :baseSwitch(0.0,DMAX,1.0,"fastgaussian") {}
345 : protected:
346 1 : inline double function(const double rdist,double&dfunc) const override {
347 1 : double result = std::exp(-0.5*rdist*rdist);
348 1 : dfunc=-rdist*result;
349 1 : return result;
350 : }
351 38317812 : inline double calculateSqr(double distance2,double&dfunc) const override {
352 : double result = 0.0;
353 38317812 : if(distance2>dmax_2) {
354 8 : dfunc=0.0;
355 : } else {
356 38317804 : result = exp(-0.5*distance2);
357 38317804 : dfunc = -result;
358 : // stretch:
359 38317804 : result=result*stretch+shift;
360 38317804 : dfunc*=stretch;
361 : }
362 38317812 : return result;
363 : }
364 : };
365 :
366 : class smapSwitch: public baseSwitch {
367 : const int a=0;
368 : const int b=0;
369 : const double c=0.0;
370 : const double d=0.0;
371 : protected:
372 15 : std::string specificDescription() const override {
373 15 : std::ostringstream ostr;
374 15 : ostr<<" a="<<a<<" b="<<b;
375 15 : return ostr.str();
376 15 : }
377 : public:
378 15 : smapSwitch(double D0, double DMAX, double R0, int A, int B)
379 15 : :baseSwitch(D0,DMAX,R0,"smap"),
380 15 : a(A),
381 15 : b(B),
382 15 : c(std::pow(2., static_cast<double>(a)/static_cast<double>(b) ) - 1.0),
383 15 : d(-static_cast<double>(b) / static_cast<double>(a)) {}
384 : protected:
385 21911326 : inline double function(const double rdist,double&dfunc) const override {
386 :
387 21911326 : const double sx=c*Tools::fastpow( rdist, a );
388 21911326 : double result=std::pow( 1.0 + sx, d );
389 21911326 : dfunc=-b*sx/rdist*result/(1.0+sx);
390 21911326 : return result;
391 : }
392 : };
393 :
394 : class cubicSwitch: public baseSwitch {
395 : protected:
396 15 : std::string specificDescription() const override {
397 15 : std::ostringstream ostr;
398 15 : ostr<<" dmax="<<dmax;
399 15 : return ostr.str();
400 15 : }
401 : public:
402 15 : cubicSwitch(double D0, double DMAX)
403 15 : :baseSwitch(D0,DMAX,DMAX-D0,"cubic") {
404 : //this operation should be already done!!
405 : // R0 = dmax - d0;
406 : // invr0 = 1/R0;
407 : // invr0_2 = invr0*invr0;
408 15 : }
409 15 : ~cubicSwitch()=default;
410 : protected:
411 127256 : inline double function(const double rdist,double&dfunc) const override {
412 127256 : const double tmp1 = rdist - 1.0;
413 127256 : const double tmp2 = 1.0+2.0*rdist;
414 : //double result = tmp1*tmp1*tmp2;
415 127256 : dfunc = 2*tmp1*tmp2 + 2*tmp1*tmp1;
416 127256 : return tmp1*tmp1*tmp2;
417 : }
418 : };
419 :
420 : class tanhSwitch: public baseSwitch {
421 : public:
422 4 : tanhSwitch(double D0, double DMAX, double R0)
423 4 : :baseSwitch(D0,DMAX,R0,"tanh") {}
424 : protected:
425 12718 : inline double function(const double rdist,double&dfunc) const override {
426 12718 : const double tmp1 = std::tanh(rdist);
427 : //was dfunc=-(1-tmp1*tmp1);
428 12718 : dfunc = tmp1 * tmp1 - 1.0;
429 : //return result;
430 12718 : return 1.0 - tmp1;
431 : }
432 : };
433 :
434 : class cosinusSwitch: public baseSwitch {
435 : public:
436 3 : cosinusSwitch(double D0, double DMAX, double R0)
437 3 : :baseSwitch(D0,DMAX,R0,"cosinus") {}
438 : protected:
439 522111 : inline double function(const double rdist,double&dfunc) const override {
440 : double result = 0.0;
441 522111 : dfunc=0.0;
442 522111 : if(rdist<=1.0) {
443 : // rdist = (r-r1)/(r2-r1) ; 0.0<=rdist<=1.0 if r1 <= r <=r2; (r2-r1)/(r2-r1)=1
444 227012 : double rdistPI = rdist * PLMD::pi;
445 227012 : result = 0.5 * (std::cos ( rdistPI ) + 1.0);
446 227012 : dfunc = -0.5 * PLMD::pi * std::sin ( rdistPI ) * invr0;
447 : }
448 522111 : return result;
449 : }
450 : };
451 :
452 : class nativeqSwitch: public baseSwitch {
453 : double beta = 50.0; // nm-1
454 : double lambda = 1.8; // unitless
455 : double ref=0.0;
456 : protected:
457 864 : std::string specificDescription() const override {
458 864 : std::ostringstream ostr;
459 864 : ostr<<" beta="<<beta<<" lambda="<<lambda<<" ref="<<ref;
460 864 : return ostr.str();
461 864 : }
462 0 : inline double function(const double rdist,double&dfunc) const override {
463 0 : return 0.0;
464 : }
465 : public:
466 : nativeqSwitch(double D0, double DMAX, double R0, double BETA, double LAMBDA,double REF)
467 864 : : baseSwitch(D0,DMAX,R0,"nativeq"),beta(BETA),lambda(LAMBDA),ref(REF) {}
468 292924 : double calculate(const double distance, double& dfunc) const override {
469 : double res = 0.0;//RVO!
470 292924 : dfunc = 0.0;
471 292924 : if(distance<=dmax) {
472 : res = 1.0;
473 292916 : if(distance > d0) {
474 292909 : const double rdist = beta*(distance - lambda * ref);
475 292909 : double exprdist=std::exp(rdist);
476 292909 : res=1.0/(1.0+exprdist);
477 : /*2.9
478 : //need to see if this (5op+assign)
479 : //double exprmdist=1.0 + exprdist;
480 : //dfunc = - (beta *exprdist)/(exprmdist*exprmdist);
481 : //or this (5op but 2 divisions) is faster
482 : dfunc = - beta /(exprdist+ 2.0 +1.0/exprdist);
483 : //this cames from - beta * exprdist/(exprdist*exprdist+ 2.0 *exprdist +1.0)
484 : //dfunc *= invr0;
485 : dfunc /= distance;
486 : */
487 : //2.10
488 292909 : dfunc = - beta /(exprdist+ 2.0 +1.0/exprdist) /distance;
489 :
490 292909 : dfunc*=stretch;
491 : }
492 292916 : res=res*stretch+shift;
493 : }
494 292924 : return res;
495 : }
496 : };
497 :
498 : class leptonSwitch: public baseSwitch {
499 : /// Lepton expression.
500 62 : class funcAndDeriv {
501 : lepton::CompiledExpression expression;
502 : lepton::CompiledExpression deriv;
503 : double* varRef=nullptr;
504 : double* varDevRef=nullptr;
505 : public:
506 20 : funcAndDeriv(const std::string &func) {
507 20 : lepton::ParsedExpression pe=lepton::Parser::parse(func).optimize(lepton::Constants());
508 20 : expression=pe.createCompiledExpression();
509 22 : std::string arg="x";
510 :
511 : {
512 20 : auto vars=expression.getVariables();
513 20 : bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
514 20 : bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
515 :
516 20 : if(found_x2) {
517 : arg="x2";
518 : }
519 20 : if (vars.size()==0) {
520 : // this is necessary since in some cases lepton thinks a variable is not present even though it is present
521 : // e.g. func=0*x
522 0 : varRef=nullptr;
523 20 : } else if(vars.size()==1 && (found_x || found_x2)) {
524 18 : varRef=&expression.getVariableReference(arg);
525 : } else {
526 4 : plumed_error()
527 : <<"Please declare a function with only ONE argument that can only be x or x2. Your function is: "
528 4 : << func;
529 : }
530 : }
531 :
532 38 : lepton::ParsedExpression ped=lepton::Parser::parse(func).differentiate(arg).optimize(lepton::Constants());
533 18 : deriv=ped.createCompiledExpression();
534 : {
535 18 : auto vars=expression.getVariables();
536 18 : if (vars.size()==0) {
537 0 : varDevRef=nullptr;
538 : } else {
539 18 : varDevRef=&deriv.getVariableReference(arg);
540 : }
541 : }
542 :
543 22 : }
544 44 : funcAndDeriv (const funcAndDeriv& other):
545 44 : expression(other.expression),
546 44 : deriv(other.deriv) {
547 44 : std::string arg="x";
548 :
549 : {
550 44 : auto vars=expression.getVariables();
551 44 : bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
552 44 : bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
553 :
554 44 : if(found_x2) {
555 : arg="x2";
556 : }
557 44 : if (vars.size()==0) {
558 0 : varRef=nullptr;
559 44 : } else if(vars.size()==1 && (found_x || found_x2)) {
560 44 : varRef=&expression.getVariableReference(arg);
561 : }// UB: I assume that the function is already correct
562 : }
563 :
564 : {
565 44 : auto vars=expression.getVariables();
566 44 : if (vars.size()==0) {
567 0 : varDevRef=nullptr;
568 : } else {
569 44 : varDevRef=&deriv.getVariableReference(arg);
570 : }
571 : }
572 44 : }
573 :
574 : funcAndDeriv& operator= (const funcAndDeriv& other) {
575 : if(this != &other) {
576 : expression = other.expression;
577 : deriv = other.deriv;
578 : std::string arg="x";
579 :
580 : {
581 : auto vars=expression.getVariables();
582 : bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
583 : bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
584 :
585 : if(found_x2) {
586 : arg="x2";
587 : }
588 : if (vars.size()==0) {
589 : varRef=nullptr;
590 : } else if(vars.size()==1 && (found_x || found_x2)) {
591 : varRef=&expression.getVariableReference(arg);
592 : }// UB: I assume that the function is already correct
593 : }
594 :
595 : {
596 : auto vars=expression.getVariables();
597 : if (vars.size()==0) {
598 : varDevRef=nullptr;
599 : } else {
600 : varDevRef=&deriv.getVariableReference(arg);
601 : }
602 : }
603 : }
604 : return *this;
605 : }
606 :
607 6515285 : std::pair<double,double> operator()(double const x) const {
608 : //FAQ: why this works? this thing is const and you are modifying things!
609 : //Actually I am modifying something that is pointed at, not my pointers,
610 : //so I am not mutating the state of this!
611 6515285 : if(varRef) {
612 6515285 : *varRef=x;
613 : }
614 6515285 : if(varDevRef) {
615 6515285 : *varDevRef=x;
616 : }
617 : return std::make_pair(
618 6515285 : expression.evaluate(),
619 6515285 : deriv.evaluate());
620 : }
621 :
622 : auto& getVariables() const {
623 18 : return expression.getVariables();
624 : }
625 : auto& getVariables_derivative() const {
626 : return deriv.getVariables();
627 : }
628 : };
629 : /// Function for lepton
630 : std::string lepton_func;
631 : /// \warning Since lepton::CompiledExpression is mutable, a vector is necessary for multithreading!
632 : std::vector <funcAndDeriv> expressions{};
633 : /// Set to true if lepton only uses x2
634 : bool leptonx2=false;
635 : protected:
636 18 : std::string specificDescription() const override {
637 18 : std::ostringstream ostr;
638 18 : ostr<<" func=" << lepton_func;
639 18 : return ostr.str();
640 18 : }
641 0 : inline double function(const double,double&) const override {
642 0 : return 0.0;
643 : }
644 : public:
645 20 : leptonSwitch(double D0, double DMAX, double R0, const std::string & func)
646 20 : :baseSwitch(D0,DMAX,R0,"lepton"),
647 20 : lepton_func(func),
648 38 : expressions (OpenMP::getNumThreads(), lepton_func) {
649 : //this is a bit odd, but it works
650 : auto vars=expressions[0].getVariables();
651 18 : leptonx2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
652 20 : }
653 :
654 5877796 : double calculate(const double distance,double&dfunc) const override {
655 5877796 : double res = 0.0;//RVO!
656 5877796 : dfunc = 0.0;
657 5877796 : if(leptonx2) {
658 2 : res= calculateSqr(distance*distance,dfunc);
659 : } else {
660 5877794 : if(distance<=dmax) {
661 5573105 : res = 1.0;
662 5573105 : const double rdist = (distance-d0)*invr0;
663 5573105 : if(rdist > 0.0) {
664 5267183 : const unsigned t=OpenMP::getThreadNum();
665 5267183 : plumed_assert(t<expressions.size());
666 5267183 : std::tie(res,dfunc) = expressions[t](rdist);
667 5267183 : dfunc *= invr0;
668 5267183 : dfunc /= distance;
669 : }
670 5573105 : res=res*stretch+shift;
671 5573105 : dfunc*=stretch;
672 : }
673 : }
674 5877796 : return res;
675 : }
676 :
677 7125890 : double calculateSqr(const double distance2,double&dfunc) const override {
678 7125890 : double result =0.0;
679 7125890 : dfunc=0.0;
680 7125890 : if(leptonx2) {
681 1248110 : if(distance2<=dmax_2) {
682 1248102 : const unsigned t=OpenMP::getThreadNum();
683 1248102 : const double rdist_2 = distance2*invr0_2;
684 1248102 : plumed_assert(t<expressions.size());
685 1248102 : std::tie(result,dfunc) = expressions[t](rdist_2);
686 : // chain rule:
687 1248102 : dfunc*=2*invr0_2;
688 : // stretch:
689 1248102 : result=result*stretch+shift;
690 1248102 : dfunc*=stretch;
691 : }
692 : } else {
693 5877780 : result = calculate(std::sqrt(distance2),dfunc);
694 : }
695 7125890 : return result;
696 : }
697 : };
698 : } // namespace switchContainers
699 :
700 0 : void SwitchingFunction::registerKeywords( Keywords& keys ) {
701 0 : keys.add("compulsory","R_0","the value of R_0 in the switching function");
702 0 : keys.add("compulsory","D_0","0.0","the value of D_0 in the switching function");
703 0 : keys.add("optional","D_MAX","the value at which the switching function can be assumed equal to zero");
704 0 : keys.add("compulsory","NN","6","the value of n in the switching function (only needed for TYPE=RATIONAL)");
705 0 : keys.add("compulsory","MM","0","the value of m in the switching function (only needed for TYPE=RATIONAL); 0 implies 2*NN");
706 0 : keys.add("compulsory","A","the value of a in the switching function (only needed for TYPE=SMAP)");
707 0 : keys.add("compulsory","B","the value of b in the switching function (only needed for TYPE=SMAP)");
708 0 : }
709 :
710 1581 : void SwitchingFunction::set(const std::string & definition,std::string& errormsg) {
711 1581 : std::vector<std::string> data=Tools::getWords(definition);
712 : #define CHECKandPARSE(datastring,keyword,variable,errormsg) \
713 : if(Tools::findKeyword(datastring,keyword) && !Tools::parse(datastring,keyword,variable))\
714 : errormsg="could not parse " keyword; //adiacent strings are automagically concatenated
715 : #define REQUIREDPARSE(datastring,keyword,variable,errormsg) \
716 : if(!Tools::parse(datastring,keyword,variable))\
717 : errormsg=keyword " is required for " + name ; //adiacent strings are automagically concatenated
718 :
719 1581 : if( data.size()<1 ) {
720 : errormsg="missing all input for switching function";
721 : return;
722 : }
723 1581 : std::string name=data[0];
724 : data.erase(data.begin());
725 1581 : double r0=0.0;
726 1581 : double d0=0.0;
727 1581 : double dmax=std::numeric_limits<double>::max();
728 1581 : init=true;
729 2009 : CHECKandPARSE(data,"D_0",d0,errormsg);
730 1921 : CHECKandPARSE(data,"D_MAX",dmax,errormsg);
731 :
732 1581 : bool dostretch=false;
733 1581 : Tools::parseFlag(data,"STRETCH",dostretch); // this is ignored now
734 1581 : dostretch=true;
735 1581 : bool dontstretch=false;
736 1581 : Tools::parseFlag(data,"NOSTRETCH",dontstretch); // this is ignored now
737 1581 : if(dontstretch) {
738 175 : dostretch=false;
739 : }
740 1581 : if(name=="CUBIC") {
741 : //cubic is the only switch type that only uses d0 and dmax
742 15 : function = PLMD::Tools::make_unique<switchContainers::cubicSwitch>(d0,dmax);
743 : } else {
744 3132 : REQUIREDPARSE(data,"R_0",r0,errormsg);
745 1566 : if(name=="RATIONAL") {
746 404 : int nn=6;
747 404 : int mm=0;
748 660 : CHECKandPARSE(data,"NN",nn,errormsg);
749 654 : CHECKandPARSE(data,"MM",mm,errormsg);
750 808 : function = switchContainers::rationalFactory(d0,dmax,r0,nn,mm);
751 1162 : } else if(name=="SMAP") {
752 15 : int a=0;
753 15 : int b=0;
754 : //in the original a and b are "default=0",
755 : //but you divide by a and b during the initialization!
756 : //better an error message than an UB, so no default
757 30 : REQUIREDPARSE(data,"A",a,errormsg);
758 30 : REQUIREDPARSE(data,"B",b,errormsg);
759 15 : function = PLMD::Tools::make_unique<switchContainers::smapSwitch>(d0,dmax,r0,a,b);
760 1147 : } else if(name=="Q") {
761 864 : double beta = 50.0; // nm-1
762 864 : double lambda = 1.8; // unitless
763 : double ref;
764 2592 : CHECKandPARSE(data,"BETA",beta,errormsg);
765 2592 : CHECKandPARSE(data,"LAMBDA",lambda,errormsg);
766 1728 : REQUIREDPARSE(data,"REF",ref,errormsg);
767 : //the original error message was not standard
768 : // if(!Tools::parse(data,"REF",ref))
769 : // errormsg="REF (reference distaance) is required for native Q";
770 864 : function = PLMD::Tools::make_unique<switchContainers::nativeqSwitch>(d0,dmax,r0,beta,lambda,ref);
771 283 : } else if(name=="EXP") {
772 75 : function = PLMD::Tools::make_unique<switchContainers::exponentialSwitch>(d0,dmax,r0);
773 208 : } else if(name=="GAUSSIAN") {
774 180 : if ( r0==1.0 && d0==0.0 ) {
775 114 : function = PLMD::Tools::make_unique<switchContainers::fastGaussianSwitch>(d0,dmax,r0);
776 : } else {
777 66 : function = PLMD::Tools::make_unique<switchContainers::gaussianSwitch>(d0,dmax,r0);
778 : }
779 28 : } else if(name=="TANH") {
780 4 : function = PLMD::Tools::make_unique<switchContainers::tanhSwitch>(d0,dmax,r0);
781 24 : } else if(name=="COSINUS") {
782 3 : function = PLMD::Tools::make_unique<switchContainers::cosinusSwitch>(d0,dmax,r0);
783 39 : } else if((name=="MATHEVAL" || name=="CUSTOM")) {
784 : std::string func;
785 40 : Tools::parse(data,"FUNC",func);
786 18 : function = PLMD::Tools::make_unique<switchContainers::leptonSwitch>(d0,dmax,r0,func);
787 : } else {
788 2 : errormsg="cannot understand switching function type '"+name+"'";
789 : }
790 : }
791 : #undef CHECKandPARSE
792 : #undef REQUIREDPARSE
793 :
794 1579 : if( !data.empty() ) {
795 : errormsg="found the following rogue keywords in switching function input : ";
796 0 : for(unsigned i=0; i<data.size(); ++i) {
797 2 : errormsg = errormsg + data[i] + " ";
798 : }
799 : }
800 :
801 1579 : if(dostretch && dmax!=std::numeric_limits<double>::max()) {
802 142 : function->setupStretch();
803 : }
804 1581 : }
805 :
806 1502 : std::string SwitchingFunction::description() const {
807 : // if this error is necessary, something went wrong in the constructor
808 : // plumed_merror("Unknown switching function type");
809 1502 : return function->description();
810 : }
811 :
812 92146473 : double SwitchingFunction::calculateSqr(double distance2,double&dfunc)const {
813 92146473 : return function -> calculateSqr(distance2, dfunc);
814 : }
815 :
816 127898725 : double SwitchingFunction::calculate(double distance,double&dfunc)const {
817 127898725 : plumed_massert(init,"you are trying to use an unset SwitchingFunction");
818 127898725 : double result=function->calculate(distance,dfunc);
819 127898725 : return result;
820 : }
821 :
822 74 : void SwitchingFunction::set(const int nn,int mm, const double r0, const double d0) {
823 74 : init=true;
824 74 : if(mm == 0) {
825 70 : mm = 2*nn;
826 : }
827 74 : double dmax=d0+r0*std::pow(0.00001,1./(nn-mm));
828 148 : function = switchContainers::rationalFactory(d0,dmax,r0,nn,mm);
829 74 : function->setupStretch();
830 74 : }
831 :
832 32 : double SwitchingFunction::get_r0() const {
833 32 : return function->get_r0();
834 : }
835 :
836 8 : double SwitchingFunction::get_d0() const {
837 8 : return function->get_d0();
838 : }
839 :
840 536580542 : double SwitchingFunction::get_dmax() const {
841 536580542 : return function->get_dmax();
842 : }
843 :
844 49030642 : double SwitchingFunction::get_dmax2() const {
845 49030642 : return function->get_dmax2();
846 : }
847 :
848 : }// Namespace PLMD
|