LCOV - code coverage report
Current view: top level - tools - SwitchingFunction.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 385 406 94.8 %
Date: 2024-10-18 13:59:31 Functions: 82 106 77.4 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2012-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "SwitchingFunction.h"
      23             : #include "Tools.h"
      24             : #include "Keywords.h"
      25             : #include "OpenMP.h"
      26             : #include <vector>
      27             : #include <limits>
      28             : #include <algorithm>
      29             : #include <optional>
      30             : 
      31             : namespace PLMD {
      32             : 
      33             : //+PLUMEDOC INTERNAL switchingfunction
      34             : /*
      35             : Functions that measure whether values are less than a certain quantity.
      36             : 
      37             : Switching functions \f$s(r)\f$ take a minimum of one input parameter \f$r_0\f$.
      38             : For \f$r \le d_0 \quad s(r)=1.0\f$ while for \f$r > d_0\f$ the function decays smoothly to 0.
      39             : The various switching functions available in PLUMED differ in terms of how this decay is performed.
      40             : 
      41             : Where there is an accepted convention in the literature (e.g. \ref COORDINATION) on the form of the
      42             : switching function we use the convention as the default.  However, the flexibility to use different
      43             : switching functions is always present generally through a single keyword. This keyword generally
      44             : takes an input with the following form:
      45             : 
      46             : \verbatim
      47             : KEYWORD={TYPE <list of parameters>}
      48             : \endverbatim
      49             : 
      50             : The following table contains a list of the various switching functions that are available in PLUMED 2
      51             : together with an example input.
      52             : 
      53             : <table align=center frame=void width=95%% cellpadding=5%%>
      54             : <tr>
      55             : <td> TYPE </td> <td> FUNCTION </td> <td> EXAMPLE INPUT </td> <td> DEFAULT PARAMETERS </td>
      56             : </tr> <tr> <td>RATIONAL </td> <td>
      57             : \f$
      58             : s(r)=\frac{ 1 - \left(\frac{ r - d_0 }{ r_0 }\right)^{n} }{ 1 - \left(\frac{ r - d_0 }{ r_0 }\right)^{m} }
      59             : \f$
      60             : </td> <td>
      61             : {RATIONAL R_0=\f$r_0\f$ D_0=\f$d_0\f$ NN=\f$n\f$ MM=\f$m\f$}
      62             : </td> <td> \f$d_0=0.0\f$, \f$n=6\f$, \f$m=2n\f$ </td>
      63             : </tr> <tr>
      64             : <td> EXP </td> <td>
      65             : \f$
      66             : s(r)=\exp\left(-\frac{ r - d_0 }{ r_0 }\right)
      67             : \f$
      68             : </td> <td>
      69             : {EXP  R_0=\f$r_0\f$ D_0=\f$d_0\f$}
      70             : </td> <td> \f$d_0=0.0\f$ </td>
      71             : </tr> <tr>
      72             : <td> GAUSSIAN </td> <td>
      73             : \f$
      74             : s(r)=\exp\left(-\frac{ (r - d_0)^2 }{ 2r_0^2 }\right)
      75             : \f$
      76             : </td> <td>
      77             : {GAUSSIAN R_0=\f$r_0\f$ D_0=\f$d_0\f$}
      78             : </td> <td> \f$d_0=0.0\f$ </td>
      79             : </tr> <tr>
      80             : <td> SMAP </td> <td>
      81             : \f$
      82             : s(r) = \left[ 1 + ( 2^{a/b} -1 )\left( \frac{r-d_0}{r_0} \right)^a \right]^{-b/a}
      83             : \f$
      84             : </td> <td>
      85             : {SMAP R_0=\f$r_0\f$ D_0=\f$d_0\f$ A=\f$a\f$ B=\f$b\f$}
      86             : </td> <td> \f$d_0=0.0\f$ </td>
      87             : </tr> <tr>
      88             : <td> Q </td> <td>
      89             : \f$
      90             : s(r) = \frac{1}{1 + \exp(\beta(r_{ij} - \lambda r_{ij}^0))}
      91             : \f$
      92             : </td> <td>
      93             : {Q REF=\f$r_{ij}^0\f$ BETA=\f$\beta\f$ LAMBDA=\f$\lambda\f$ }
      94             : </td> <td> \f$\lambda=1.8\f$,  \f$\beta=50 nm^-1\f$ (all-atom)<br/>\f$\lambda=1.5\f$,  \f$\beta=50 nm^-1\f$ (coarse-grained)  </td>
      95             : </tr> <tr>
      96             : <td> CUBIC </td> <td>
      97             : \f$
      98             : s(r) = (y-1)^2(1+2y) \qquad \textrm{where} \quad y = \frac{r - r_1}{r_0-r_1}
      99             : \f$
     100             : </td> <td>
     101             : {CUBIC D_0=\f$r_1\f$ D_MAX=\f$r_0\f$}
     102             : </td> <td> </td>
     103             : </tr> <tr>
     104             : <td> TANH </td> <td>
     105             : \f$
     106             : s(r) = 1 - \tanh\left( \frac{ r - d_0 }{ r_0 } \right)
     107             : \f$
     108             : </td> <td>
     109             : {TANH R_0=\f$r_0\f$ D_0=\f$d_0\f$}
     110             : </td> <td> </td>
     111             : </tr> <tr>
     112             : <td> COSINUS </td> <td>
     113             : \f$s(r) =\left\{\begin{array}{ll}
     114             :    1                                                           & \mathrm{if } r \leq d_0 \\
     115             :    0.5 \left( \cos ( \frac{ r - d_0 }{ r_0 } \pi ) + 1 \right) & \mathrm{if } d_0 < r\leq d_0 + r_0 \\
     116             :    0                                                           & \mathrm{if } r > d_0 + r_0
     117             :   \end{array}\right.
     118             : \f$
     119             : </td> <td>
     120             : {COSINUS R_0=\f$r_0\f$ D_0=\f$d_0\f$}
     121             : </td> <td> </td>
     122             : </tr> <tr>
     123             : <td> CUSTOM </td> <td>
     124             : \f$
     125             : s(r) = FUNC
     126             : \f$
     127             : </td> <td>
     128             : {CUSTOM FUNC=1/(1+x^6) R_0=\f$r_0\f$ D_0=\f$d_0\f$}
     129             : </td> <td> </td>
     130             : </tr>
     131             : </table>
     132             : 
     133             : Notice that most commonly used rational functions are better optimized and might run faster.
     134             : 
     135             : Notice that for backward compatibility we allow using `MATHEVAL` instead of `CUSTOM`.
     136             : Also notice that if the a `CUSTOM` switching function only depends on even powers of `x` it can be
     137             : made faster by using `x2` as a variable. For instance
     138             : \verbatim
     139             : {CUSTOM FUNC=1/(1+x2^3) R_0=0.3}
     140             : \endverbatim
     141             : is equivalent to
     142             : \verbatim
     143             : {CUSTOM FUNC=1/(1+x^6) R_0=0.3}
     144             : \endverbatim
     145             : but runs faster. The reason is that there is an expensive square root calculation that can be optimized out.
     146             : 
     147             : 
     148             : \attention
     149             : With the default implementation CUSTOM is slower than other functions
     150             : (e.g., it is slower than an equivalent RATIONAL function by approximately a factor 2).
     151             : Checkout page \ref Lepton to see how to improve its performance.
     152             : 
     153             : For all the switching functions in the above table one can also specify a further (optional) parameter using the parameter
     154             : keyword D_MAX to assert that for \f$r>d_{\textrm{max}}\f$ the switching function can be assumed equal to zero.
     155             : In this case the function is brought smoothly to zero by stretching and shifting it.
     156             : \verbatim
     157             : KEYWORD={RATIONAL R_0=1 D_MAX=3}
     158             : \endverbatim
     159             : the resulting switching function will be
     160             : \f$
     161             : s(r) = \frac{s'(r)-s'(d_{max})}{s'(0)-s'(d_{max})}
     162             : \f$
     163             : where
     164             : \f$
     165             : s'(r)=\frac{1-r^6}{1-r^{12}}
     166             : \f$
     167             : Since PLUMED 2.2 this is the default. The old behavior (no stretching) can be obtained with the
     168             : NOSTRETCH flag. The NOSTRETCH keyword is only provided for backward compatibility and might be
     169             : removed in the future. Similarly, the STRETCH keyword is still allowed but has no effect.
     170             : 
     171             : Notice that switching functions defined with the simplified syntax are never stretched
     172             : for backward compatibility. This might change in the future.
     173             : 
     174             : */
     175             : //+ENDPLUMEDOC
     176             : 
     177             : namespace switchContainers {
     178             : 
     179        1356 : baseSwitch::baseSwitch(double D0,double DMAX, double R0, std::string_view name)
     180        1356 :   : d0(D0),
     181        1356 :     dmax(DMAX),
     182        1356 :     dmax_2([](const double d) {
     183        1356 :   if(d<std::sqrt(std::numeric_limits<double>::max())) {
     184         244 :     return  d*d;
     185             :   } else {
     186             :     return std::numeric_limits<double>::max();
     187             :   }
     188             : }(dmax)),
     189        1356 : invr0(1.0/R0),
     190        1356 : invr0_2(invr0*invr0),
     191        1600 : mytype(name) {}
     192             : 
     193        1356 : baseSwitch::~baseSwitch()=default;
     194             : 
     195   162817906 : double baseSwitch::calculate(const double distance, double& dfunc) const {
     196             :   double res = 0.0;//RVO!
     197   162817906 :   dfunc = 0.0;
     198   162817906 :   if(distance <= dmax) {
     199             :     res = 1.0;
     200   156000375 :     const double rdist = (distance-d0)*invr0;
     201   156000375 :     if(rdist > 0.0) {
     202    59637678 :       res = function(rdist,dfunc);
     203             :       //the following comments came from the original
     204             :       // this is for the chain rule (derivative of rdist):
     205    59637678 :       dfunc *= invr0;
     206             :       // for any future switching functions, be aware that multiplying invr0 is only
     207             :       // correct for functions of rdist = (r-d0)/r0.
     208             : 
     209             :       // this is because calculate() sets dfunc to the derivative divided times the
     210             :       // distance.
     211             :       // (I think this is misleading and I would like to modify it - GB)
     212    59637678 :       dfunc /= distance;
     213             :     }
     214   156000375 :     res=res*stretch+shift;
     215   156000375 :     dfunc*=stretch;
     216             :   }
     217   162817906 :   return res;
     218             : }
     219             : 
     220    31818564 : double baseSwitch::calculateSqr(double distance2,double&dfunc) const {
     221    31818564 :   double res= calculate(std::sqrt(distance2),dfunc);//RVO!
     222    31818564 :   return res;
     223             : }
     224           8 : double baseSwitch::get_d0() const {return d0;}
     225        1236 : double baseSwitch::get_r0() const {return 1.0/invr0;}
     226   536580542 : double baseSwitch::get_dmax() const {return dmax;}
     227    49030642 : double baseSwitch::get_dmax2() const {return dmax_2;}
     228        1204 : std::string baseSwitch::description() const {
     229        1204 :   std::ostringstream ostr;
     230        1204 :   ostr<<get_r0()
     231             :       <<".  Using "
     232             :       << mytype
     233        2408 :       <<" switching function with parameters d0="<< d0
     234        2408 :       << specificDescription();
     235        1204 :   return ostr.str();
     236        1204 : }
     237         150 : std::string baseSwitch::specificDescription() const {return "";}
     238         216 : void baseSwitch::setupStretch() {
     239         216 :   if(dmax!=std::numeric_limits<double>::max()) {
     240         216 :     stretch=1.0;
     241         216 :     shift=0.0;
     242             :     double dummy;
     243         216 :     double s0=calculate(0.0,dummy);
     244         216 :     double sd=calculate(dmax,dummy);
     245         216 :     stretch=1.0/(s0-sd);
     246         216 :     shift=-sd*stretch;
     247             :   }
     248         216 : }
     249           0 : void baseSwitch::removeStretch() {
     250           0 :   stretch=1.0;
     251           0 :   shift=0.0;
     252           0 : }
     253             : template<int N, std::enable_if_t< (N >0), bool> = true, std::enable_if_t< (N %2 == 0), bool> = true>
     254             :     class fixedRational :public baseSwitch {
     255         263 :   std::string specificDescription() const override {
     256         263 :     std::ostringstream ostr;
     257         263 :     ostr << " nn=" << N << " mm=" <<N*2;
     258         263 :     return ostr.str();
     259         263 :   }
     260             : public:
     261         282 :   fixedRational(double D0,double DMAX, double R0)
     262         282 :     :baseSwitch(D0,DMAX,R0,"rational") {}
     263             : 
     264             :   template <int POW>
     265        1382 :   static inline double doRational(const double rdist, double&dfunc, double result=0.0) {
     266             :     const double rNdist=Tools::fastpow<POW-1>(rdist);
     267    27485030 :     result=1.0/(1.0+rNdist*rdist);
     268    27485030 :     dfunc = -POW*rNdist*result*result;
     269        1382 :     return result;
     270             :   }
     271             : 
     272    16154932 :   inline double function(double rdist,double&dfunc) const override {
     273             :     //preRes and preDfunc are passed already set
     274        1382 :     dfunc=0.0;
     275        1382 :     double result = doRational<N>(rdist,dfunc);
     276    16154932 :     return result;
     277             :   }
     278             : 
     279    11475850 :   double calculateSqr(double distance2,double&dfunc) const override {
     280             :     double result=0.0;
     281    11475850 :     dfunc=0.0;
     282    11475850 :     if(distance2 <= dmax_2) {
     283    11330098 :       const double rdist = distance2*invr0_2;
     284             :       result = doRational<N/2>(rdist,dfunc);
     285    11330098 :       dfunc*=2*invr0_2;
     286             :       // stretch:
     287    11330098 :       result=result*stretch+shift;
     288    11330098 :       dfunc*=stretch;
     289             :     }
     290    11475850 :     return result;
     291             : 
     292             :   }
     293             : };
     294             : 
     295             : //these enums are useful for clarifying the settings in the factory
     296             : //and the code is autodocumented ;)
     297             : enum class rationalPow:bool {standard, fast};
     298             : enum class rationalForm:bool {standard, simplified};
     299             : 
     300             : template<rationalPow isFast, rationalForm nis2m>
     301             : class rational : public baseSwitch {
     302             : protected:
     303             :   const int nn=6;
     304             :   const int mm=12;
     305             :   const double preRes;
     306             :   const double preDfunc;
     307             :   const double preSecDev;
     308             :   const int nnf;
     309             :   const int mmf;
     310             :   const double preDfuncF;
     311             :   const double preSecDevF;
     312             :   //I am using PLMD::epsilon to be certain to call the one defined in Tools.h
     313             :   static constexpr double moreThanOne=1.0+5.0e10*PLMD::epsilon;
     314             :   static constexpr double lessThanOne=1.0-5.0e10*PLMD::epsilon;
     315             : 
     316         171 :   std::string specificDescription() const override {
     317         171 :     std::ostringstream ostr;
     318         171 :     ostr << " nn=" << nn << " mm=" <<mm;
     319         171 :     return ostr.str();
     320         171 :   }
     321             : public:
     322         190 :   rational(double D0,double DMAX, double R0, int N, int M)
     323             :     :baseSwitch(D0,DMAX,R0,"rational"),
     324         190 :      nn(N),
     325          89 :      mm([](int m,int n) {if (m==0) {return n*2;} else {return m;}}(M,N)),
     326         190 :   preRes(static_cast<double>(nn)/mm),
     327         190 :   preDfunc(0.5*nn*(nn-mm)/static_cast<double>(mm)),
     328             :   //wolfram <3:lim_(x->1) d^2/(dx^2) (1 - x^N)/(1 - x^M) = (N (M^2 - 3 M (-1 + N) + N (-3 + 2 N)))/(6 M)
     329         190 :   preSecDev ((nn * (mm * mm - 3.0* mm * (-1 + nn ) + nn *(-3 + 2* nn )))/(6.0* mm )),
     330         190 :   nnf(nn/2),
     331         190 :   mmf(mm/2),
     332         190 :   preDfuncF(0.5*nnf*(nnf-mmf)/static_cast<double>(mmf)),
     333         190 :   preSecDevF((nnf* (mmf*mmf - 3.0* mmf* (-1 + nnf) + nnf*(-3 + 2* nnf)))/(6.0* mmf)) {}
     334             : 
     335    18225499 :   static inline double doRational(const double rdist, double&dfunc,double secDev, const int N,
     336             :                                   const int M,double result=0.0) {
     337             :     //the result and dfunc are assigned in the drivers for doRational
     338             :     //if(rdist>(1.0-100.0*epsilon) && rdist<(1.0+100.0*epsilon)) {
     339             :     //result=preRes;
     340             :     //dfunc=preDfunc;
     341             :     //} else {
     342             :     if constexpr (nis2m==rationalForm::simplified) {
     343     2113979 :       const double rNdist=Tools::fastpow(rdist,N-1);
     344     2113979 :       result=1.0/(1.0+rNdist*rdist);
     345     2113979 :       dfunc = -N*rNdist*result*result;
     346             :     } else {
     347    16111520 :       if(!((rdist > lessThanOne) && (rdist < moreThanOne))) {
     348    16111508 :         const double rNdist=Tools::fastpow(rdist,N-1);
     349    16111508 :         const double rMdist=Tools::fastpow(rdist,M-1);
     350    16111508 :         const double num = 1.0-rNdist*rdist;
     351    16111508 :         const double iden = 1.0/(1.0-rMdist*rdist);
     352    16111508 :         result = num*iden;
     353    16111508 :         dfunc = ((M*result*rMdist)-(N*rNdist))*iden;
     354    16111508 :       } else {
     355             :         //here I imply that the correct initialized are being passed to doRational
     356          12 :         const double x =(rdist-1.0);
     357          12 :         result = result+ x * ( dfunc + 0.5 * x * secDev);
     358          12 :         dfunc  = dfunc + x * secDev;
     359             :       }
     360             :     }
     361    18225499 :     return result;
     362             :   }
     363    18225447 :   inline double function(double rdist,double&dfunc) const override {
     364             :     //preRes and preDfunc are passed already set
     365    18225447 :     dfunc=preDfunc;
     366    18225447 :     double result = doRational(rdist,dfunc,preSecDev,nn,mm,preRes);
     367    18225447 :     return result;
     368             :   }
     369             : 
     370     3408359 :   double calculateSqr(double distance2,double&dfunc) const override {
     371             :     if constexpr (isFast==rationalPow::fast) {
     372             :       double result=0.0;
     373          60 :       dfunc=0.0;
     374          60 :       if(distance2 <= dmax_2) {
     375          52 :         const double rdist = distance2*invr0_2;
     376          52 :         dfunc=preDfuncF;
     377          52 :         result = doRational(rdist,dfunc,preSecDevF,nnf,mmf,preRes);
     378          52 :         dfunc*=2*invr0_2;
     379             : // stretch:
     380          52 :         result=result*stretch+shift;
     381          52 :         dfunc*=stretch;
     382             :       }
     383          60 :       return result;
     384             :     } else {
     385     3408299 :       double res= calculate(std::sqrt(distance2),dfunc);//RVO!
     386     3408299 :       return res;
     387             :     }
     388             :   }
     389             : };
     390             : 
     391             : 
     392             : template<int EXP,std::enable_if_t< (EXP %2 == 0), bool> = true>
     393        1079 : std::optional<std::unique_ptr<baseSwitch>> fixedRationalFactory(double D0,double DMAX, double R0, int N) {
     394             :   if constexpr (EXP == 0) {
     395           0 :     return  std::nullopt;
     396             :   } else {
     397        1079 :     if (N==EXP) {
     398         282 :       return PLMD::Tools::make_unique<switchContainers::fixedRational<EXP>>(D0,DMAX,R0);
     399             :     } else {
     400         797 :       return fixedRationalFactory<EXP-2>(D0,DMAX,R0,N);
     401             :     }
     402             :   }
     403             : }
     404             : 
     405             : std::unique_ptr<baseSwitch>
     406         472 : rationalFactory(double D0,double DMAX, double R0, int N, int M) {
     407         472 :   bool fast = N%2==0 && M%2==0 && D0==0.0;
     408             :   //if (M==0) M will automatically became 2*NN
     409             :   constexpr int highestPrecompiledPower=12;
     410             :   //precompiled rational
     411         472 :   if(((2*N)==M || M == 0) && fast && N<=highestPrecompiledPower) {
     412         282 :     auto tmp = fixedRationalFactory<highestPrecompiledPower>(D0,DMAX,R0,N);
     413         282 :     if(tmp) {
     414             :       return std::move(*tmp);
     415             :     }
     416             :     //else continue with the at runtime implementation
     417             :   }
     418             :   //template<bool isFast, bool n2m>
     419             :   //class rational : public baseSwitch
     420         190 :   if(2*N==M || M == 0) {
     421         132 :     if(fast) {
     422             :       //fast rational
     423             :       return PLMD::Tools::make_unique<switchContainers::rational<
     424           0 :              rationalPow::fast,rationalForm::simplified>>(D0,DMAX,R0,N,M);
     425             :     }
     426             :     return PLMD::Tools::make_unique<switchContainers::rational<
     427         132 :            rationalPow::standard,rationalForm::simplified>>(D0,DMAX,R0,N,M);
     428             :   }
     429          58 :   if(fast) {
     430             :     //fast rational
     431             :     return PLMD::Tools::make_unique<switchContainers::rational<
     432          55 :            rationalPow::fast,rationalForm::standard>>(D0,DMAX,R0,N,M);
     433             :   }
     434             :   return PLMD::Tools::make_unique<switchContainers::rational<
     435           3 :          rationalPow::standard,rationalForm::standard>>(D0,DMAX,R0,N,M);
     436             : }
     437             : //function =
     438             : 
     439             : class exponentialSwitch: public baseSwitch {
     440             : public:
     441          75 :   exponentialSwitch(double D0, double DMAX, double R0)
     442          75 :     :baseSwitch(D0,DMAX,R0,"exponential") {}
     443             : protected:
     444     2404247 :   inline double function(const double rdist,double&dfunc) const override {
     445     2404247 :     double result = std::exp(-rdist);
     446     2404247 :     dfunc=-result;
     447     2404247 :     return result;
     448             :   }
     449             : };
     450             : 
     451             : class gaussianSwitch: public baseSwitch {
     452             : public:
     453          66 :   gaussianSwitch(double D0, double DMAX, double R0)
     454          66 :     :baseSwitch(D0,DMAX,R0,"gaussian") {}
     455             : protected:
     456      279640 :   inline double function(const double rdist,double&dfunc) const override {
     457      279640 :     double result = std::exp(-0.5*rdist*rdist);
     458      279640 :     dfunc=-rdist*result;
     459      279640 :     return result;
     460             :   }
     461             : };
     462             : 
     463             : class fastGaussianSwitch: public baseSwitch {
     464             : public:
     465         114 :   fastGaussianSwitch(double /*D0*/, double DMAX, double /*R0*/)
     466         114 :     :baseSwitch(0.0,DMAX,1.0,"fastgaussian") {}
     467             : protected:
     468           1 :   inline double function(const double rdist,double&dfunc) const override {
     469           1 :     double result = std::exp(-0.5*rdist*rdist);
     470           1 :     dfunc=-rdist*result;
     471           1 :     return result;
     472             :   }
     473    38317812 :   inline double calculateSqr(double distance2,double&dfunc) const override {
     474             :     double result = 0.0;
     475    38317812 :     if(distance2>dmax_2) {
     476           8 :       dfunc=0.0;
     477             :     } else  {
     478    38317804 :       result = exp(-0.5*distance2);
     479    38317804 :       dfunc = -result;
     480             :       // stretch:
     481    38317804 :       result=result*stretch+shift;
     482    38317804 :       dfunc*=stretch;
     483             :     }
     484    38317812 :     return result;
     485             :   }
     486             : };
     487             : 
     488             : class smapSwitch: public baseSwitch {
     489             :   const int a=0;
     490             :   const int b=0;
     491             :   const double c=0.0;
     492             :   const double d=0.0;
     493             : protected:
     494          15 :   std::string specificDescription() const override {
     495          15 :     std::ostringstream ostr;
     496          15 :     ostr<<" a="<<a<<" b="<<b;
     497          15 :     return ostr.str();
     498          15 :   }
     499             : public:
     500          15 :   smapSwitch(double D0, double DMAX, double R0, int A, int B)
     501          15 :     :baseSwitch(D0,DMAX,R0,"smap"),
     502          15 :      a(A),
     503          15 :      b(B),
     504          15 :      c(std::pow(2., static_cast<double>(a)/static_cast<double>(b) ) - 1.0),
     505          15 :      d(-static_cast<double>(b) / static_cast<double>(a)) {}
     506             : protected:
     507    21911326 :   inline double function(const double rdist,double&dfunc) const override {
     508             : 
     509    21911326 :     const double sx=c*Tools::fastpow( rdist, a );
     510    21911326 :     double result=std::pow( 1.0 + sx, d );
     511    21911326 :     dfunc=-b*sx/rdist*result/(1.0+sx);
     512    21911326 :     return result;
     513             :   }
     514             : };
     515             : 
     516             : class cubicSwitch: public baseSwitch {
     517             : protected:
     518          15 :   std::string specificDescription() const override {
     519          15 :     std::ostringstream ostr;
     520          15 :     ostr<<" dmax="<<dmax;
     521          15 :     return ostr.str();
     522          15 :   }
     523             : public:
     524          15 :   cubicSwitch(double D0, double DMAX)
     525          15 :     :baseSwitch(D0,DMAX,DMAX-D0,"cubic") {
     526             :     //this operation should be already done!!
     527             :     // R0 = dmax - d0;
     528             :     // invr0 = 1/R0;
     529             :     // invr0_2 = invr0*invr0;
     530          15 :   }
     531          15 :   ~cubicSwitch()=default;
     532             : protected:
     533      127256 :   inline double function(const double rdist,double&dfunc) const override {
     534      127256 :     const double tmp1 = rdist - 1.0;
     535      127256 :     const double tmp2 = 1.0+2.0*rdist;
     536             :     //double result = tmp1*tmp1*tmp2;
     537      127256 :     dfunc = 2*tmp1*tmp2 + 2*tmp1*tmp1;
     538      127256 :     return tmp1*tmp1*tmp2;
     539             :   }
     540             : };
     541             : 
     542             : class tanhSwitch: public baseSwitch {
     543             : public:
     544           4 :   tanhSwitch(double D0, double DMAX, double R0)
     545           4 :     :baseSwitch(D0,DMAX,R0,"tanh") {}
     546             : protected:
     547       12718 :   inline double function(const double rdist,double&dfunc) const override {
     548       12718 :     const double tmp1 = std::tanh(rdist);
     549             :     //was dfunc=-(1-tmp1*tmp1);
     550       12718 :     dfunc = tmp1 * tmp1 - 1.0;
     551             :     //return result;
     552       12718 :     return 1.0 - tmp1;
     553             :   }
     554             : };
     555             : 
     556             : class cosinusSwitch: public baseSwitch {
     557             : public:
     558           3 :   cosinusSwitch(double D0, double DMAX, double R0)
     559           3 :     :baseSwitch(D0,DMAX,R0,"cosinus") {}
     560             : protected:
     561      522111 :   inline double function(const double rdist,double&dfunc) const override {
     562             :     double result = 0.0;
     563      522111 :     dfunc=0.0;
     564      522111 :     if(rdist<=1.0) {
     565             : // rdist = (r-r1)/(r2-r1) ; 0.0<=rdist<=1.0 if r1 <= r <=r2; (r2-r1)/(r2-r1)=1
     566      227012 :       double rdistPI = rdist * PLMD::pi;
     567      227012 :       result = 0.5 * (std::cos ( rdistPI ) + 1.0);
     568      227012 :       dfunc = -0.5 * PLMD::pi * std::sin ( rdistPI ) * invr0;
     569             :     }
     570      522111 :     return result;
     571             :   }
     572             : };
     573             : 
     574             : class nativeqSwitch: public baseSwitch {
     575             :   double beta = 50.0;  // nm-1
     576             :   double lambda = 1.8; // unitless
     577             :   double ref=0.0;
     578             : protected:
     579         572 :   std::string specificDescription() const override {
     580         572 :     std::ostringstream ostr;
     581         572 :     ostr<<" beta="<<beta<<" lambda="<<lambda<<" ref="<<ref;
     582         572 :     return ostr.str();
     583         572 :   }
     584           0 :   inline double function(const double rdist,double&dfunc) const override {return 0.0;  }
     585             : public:
     586             :   nativeqSwitch(double D0, double DMAX, double R0, double BETA, double LAMBDA,double REF)
     587         572 :     :  baseSwitch(D0,DMAX,R0,"nativeq"),beta(BETA),lambda(LAMBDA),ref(REF) {}
     588      146632 :   double calculate(const double distance, double& dfunc) const override {
     589             :     double res = 0.0;//RVO!
     590      146632 :     dfunc = 0.0;
     591      146632 :     if(distance<=dmax) {
     592             :       res = 1.0;
     593      146624 :       if(distance > d0) {
     594      146617 :         const double rdist = beta*(distance - lambda * ref);
     595      146617 :         double exprdist=std::exp(rdist);
     596      146617 :         res=1.0/(1.0+exprdist);
     597             :         /*2.9
     598             :         //need to see if this (5op+assign)
     599             :         //double exprmdist=1.0 + exprdist;
     600             :         //dfunc = - (beta *exprdist)/(exprmdist*exprmdist);
     601             :         //or this (5op but 2 divisions) is faster
     602             :         dfunc = - beta /(exprdist+ 2.0 +1.0/exprdist);
     603             :         //this cames from - beta * exprdist/(exprdist*exprdist+ 2.0 *exprdist +1.0)
     604             :         //dfunc *= invr0;
     605             :         dfunc /= distance;
     606             :         */
     607             :         //2.10
     608      146617 :         dfunc = - beta /(exprdist+ 2.0 +1.0/exprdist) /distance;
     609             : 
     610      146617 :         dfunc*=stretch;
     611             :       }
     612      146624 :       res=res*stretch+shift;
     613             :     }
     614      146632 :     return res;
     615             :   }
     616             : };
     617             : 
     618             : class leptonSwitch: public baseSwitch {
     619             : /// Lepton expression.
     620          62 :   class funcAndDeriv {
     621             :     lepton::CompiledExpression expression;
     622             :     lepton::CompiledExpression deriv;
     623             :     double* varRef=nullptr;
     624             :     double* varDevRef=nullptr;
     625             :   public:
     626          20 :     funcAndDeriv(const std::string &func) {
     627          20 :       lepton::ParsedExpression pe=lepton::Parser::parse(func).optimize(lepton::Constants());
     628          20 :       expression=pe.createCompiledExpression();
     629          22 :       std::string arg="x";
     630             : 
     631             :       {
     632          20 :         auto vars=expression.getVariables();
     633          20 :         bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
     634          20 :         bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     635             : 
     636          20 :         if(found_x2) {
     637             :           arg="x2";
     638             :         }
     639          20 :         if (vars.size()==0) {
     640             : // this is necessary since in some cases lepton thinks a variable is not present even though it is present
     641             : // e.g. func=0*x
     642           0 :           varRef=nullptr;
     643          20 :         } else if(vars.size()==1 && (found_x || found_x2)) {
     644          18 :           varRef=&expression.getVariableReference(arg);
     645             :         } else {
     646           4 :           plumed_error()
     647             :               <<"Please declare a function with only ONE argument that can only be x or x2. Your function is: "
     648           4 :               << func;
     649             :         }
     650             :       }
     651             : 
     652          38 :       lepton::ParsedExpression ped=lepton::Parser::parse(func).differentiate(arg).optimize(lepton::Constants());
     653          18 :       deriv=ped.createCompiledExpression();
     654             :       {
     655          18 :         auto vars=expression.getVariables();
     656          18 :         if (vars.size()==0) {
     657           0 :           varDevRef=nullptr;
     658             :         } else {
     659          18 :           varDevRef=&deriv.getVariableReference(arg);
     660             :         }
     661             :       }
     662             : 
     663          22 :     }
     664          44 :     funcAndDeriv (const funcAndDeriv& other):
     665          44 :       expression(other.expression),
     666          44 :       deriv(other.deriv) {
     667          44 :       std::string arg="x";
     668             : 
     669             :       {
     670          44 :         auto vars=expression.getVariables();
     671          44 :         bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
     672          44 :         bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     673             : 
     674          44 :         if(found_x2) {
     675             :           arg="x2";
     676             :         }
     677          44 :         if (vars.size()==0) {
     678           0 :           varRef=nullptr;
     679          44 :         } else if(vars.size()==1 && (found_x || found_x2)) {
     680          44 :           varRef=&expression.getVariableReference(arg);
     681             :         }// UB: I assume that the function is already correct
     682             :       }
     683             : 
     684             :       {
     685          44 :         auto vars=expression.getVariables();
     686          44 :         if (vars.size()==0) {
     687           0 :           varDevRef=nullptr;
     688             :         } else {
     689          44 :           varDevRef=&deriv.getVariableReference(arg);
     690             :         }
     691             :       }
     692          44 :     }
     693             : 
     694             :     funcAndDeriv& operator= (const funcAndDeriv& other) {
     695             :       if(this != &other) {
     696             :         expression = other.expression;
     697             :         deriv = other.deriv;
     698             :         std::string arg="x";
     699             : 
     700             :         {
     701             :           auto vars=expression.getVariables();
     702             :           bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
     703             :           bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     704             : 
     705             :           if(found_x2) {
     706             :             arg="x2";
     707             :           }
     708             :           if (vars.size()==0) {
     709             :             varRef=nullptr;
     710             :           } else if(vars.size()==1 && (found_x || found_x2)) {
     711             :             varRef=&expression.getVariableReference(arg);
     712             :           }// UB: I assume that the function is already correct
     713             :         }
     714             : 
     715             :         {
     716             :           auto vars=expression.getVariables();
     717             :           if (vars.size()==0) {
     718             :             varDevRef=nullptr;
     719             :           } else {
     720             :             varDevRef=&deriv.getVariableReference(arg);
     721             :           }
     722             :         }
     723             :       }
     724             :       return *this;
     725             :     }
     726             : 
     727     6515285 :     std::pair<double,double> operator()(double const x) const {
     728             :       //FAQ: why this works? this thing is const and you are modifying things!
     729             :       //Actually I am modifying something that is pointed at, not my pointers,
     730             :       //so I am not mutating the state of this!
     731     6515285 :       if(varRef) {
     732     6515285 :         *varRef=x;
     733             :       }
     734     6515285 :       if(varDevRef) {
     735     6515285 :         *varDevRef=x;
     736             :       }
     737             :       return std::make_pair(
     738     6515285 :                expression.evaluate(),
     739     6515285 :                deriv.evaluate());
     740             :     }
     741             : 
     742             :     auto& getVariables() const {
     743          18 :       return expression.getVariables();
     744             :     }
     745             :     auto& getVariables_derivative() const {
     746             :       return deriv.getVariables();
     747             :     }
     748             :   };
     749             :   /// Function for lepton
     750             :   std::string lepton_func;
     751             :   /// \warning Since lepton::CompiledExpression is mutable, a vector is necessary for multithreading!
     752             :   std::vector <funcAndDeriv> expressions{};
     753             :   /// Set to true if lepton only uses x2
     754             :   bool leptonx2=false;
     755             : protected:
     756          18 :   std::string specificDescription() const override {
     757          18 :     std::ostringstream ostr;
     758          18 :     ostr<<" func=" << lepton_func;
     759          18 :     return ostr.str();
     760          18 :   }
     761           0 :   inline double function(const double,double&) const override {return 0.0;}
     762             : public:
     763          20 :   leptonSwitch(double D0, double DMAX, double R0, const std::string & func)
     764          20 :     :baseSwitch(D0,DMAX,R0,"lepton"),
     765          20 :      lepton_func(func),
     766          38 :      expressions  (OpenMP::getNumThreads(), lepton_func) {
     767             :     //this is a bit odd, but it works
     768             :     auto vars=expressions[0].getVariables();
     769          18 :     leptonx2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     770          20 :   }
     771             : 
     772     5877796 :   double calculate(const double distance,double&dfunc) const override {
     773     5877796 :     double res = 0.0;//RVO!
     774     5877796 :     dfunc = 0.0;
     775     5877796 :     if(leptonx2) {
     776           2 :       res= calculateSqr(distance*distance,dfunc);
     777             :     } else {
     778     5877794 :       if(distance<=dmax) {
     779     5573105 :         res = 1.0;
     780     5573105 :         const double rdist = (distance-d0)*invr0;
     781     5573105 :         if(rdist > 0.0) {
     782     5267183 :           const unsigned t=OpenMP::getThreadNum();
     783     5267183 :           plumed_assert(t<expressions.size());
     784     5267183 :           std::tie(res,dfunc) = expressions[t](rdist);
     785     5267183 :           dfunc *= invr0;
     786     5267183 :           dfunc /= distance;
     787             :         }
     788     5573105 :         res=res*stretch+shift;
     789     5573105 :         dfunc*=stretch;
     790             :       }
     791             :     }
     792     5877796 :     return res;
     793             :   }
     794             : 
     795     7125890 :   double calculateSqr(const double distance2,double&dfunc) const override {
     796     7125890 :     double result =0.0;
     797     7125890 :     dfunc=0.0;
     798     7125890 :     if(leptonx2) {
     799     1248110 :       if(distance2<=dmax_2) {
     800     1248102 :         const unsigned t=OpenMP::getThreadNum();
     801     1248102 :         const double rdist_2 = distance2*invr0_2;
     802     1248102 :         plumed_assert(t<expressions.size());
     803     1248102 :         std::tie(result,dfunc) = expressions[t](rdist_2);
     804             :         // chain rule:
     805     1248102 :         dfunc*=2*invr0_2;
     806             :         // stretch:
     807     1248102 :         result=result*stretch+shift;
     808     1248102 :         dfunc*=stretch;
     809             :       }
     810             :     } else {
     811     5877780 :       result = calculate(std::sqrt(distance2),dfunc);
     812             :     }
     813     7125890 :     return result;
     814             :   }
     815             : };
     816             : } // namespace switchContainers
     817             : 
     818           0 : void SwitchingFunction::registerKeywords( Keywords& keys ) {
     819           0 :   keys.add("compulsory","R_0","the value of R_0 in the switching function");
     820           0 :   keys.add("compulsory","D_0","0.0","the value of D_0 in the switching function");
     821           0 :   keys.add("optional","D_MAX","the value at which the switching function can be assumed equal to zero");
     822           0 :   keys.add("compulsory","NN","6","the value of n in the switching function (only needed for TYPE=RATIONAL)");
     823           0 :   keys.add("compulsory","MM","0","the value of m in the switching function (only needed for TYPE=RATIONAL); 0 implies 2*NN");
     824           0 :   keys.add("compulsory","A","the value of a in the switching function (only needed for TYPE=SMAP)");
     825           0 :   keys.add("compulsory","B","the value of b in the switching function (only needed for TYPE=SMAP)");
     826           0 : }
     827             : 
     828        1283 : void SwitchingFunction::set(const std::string & definition,std::string& errormsg) {
     829        1283 :   std::vector<std::string> data=Tools::getWords(definition);
     830             : #define CHECKandPARSE(datastring,keyword,variable,errormsg) \
     831             :   if(Tools::findKeyword(datastring,keyword) && !Tools::parse(datastring,keyword,variable))\
     832             :     errormsg="could not parse " keyword; //adiacent strings are automagically concatenated
     833             : #define REQUIREDPARSE(datastring,keyword,variable,errormsg) \
     834             :   if(!Tools::parse(datastring,keyword,variable))\
     835             :     errormsg=keyword " is required for " + name ; //adiacent strings are automagically concatenated
     836             : 
     837        1283 :   if( data.size()<1 ) {
     838             :     errormsg="missing all input for switching function";
     839             :     return;
     840             :   }
     841        1283 :   std::string name=data[0];
     842             :   data.erase(data.begin());
     843        1283 :   double r0=0.0;
     844        1283 :   double d0=0.0;
     845        1283 :   double dmax=std::numeric_limits<double>::max();
     846        1283 :   init=true;
     847        1711 :   CHECKandPARSE(data,"D_0",d0,errormsg);
     848        1623 :   CHECKandPARSE(data,"D_MAX",dmax,errormsg);
     849             : 
     850        1283 :   bool dostretch=false;
     851        1283 :   Tools::parseFlag(data,"STRETCH",dostretch); // this is ignored now
     852        1283 :   dostretch=true;
     853        1283 :   bool dontstretch=false;
     854        1283 :   Tools::parseFlag(data,"NOSTRETCH",dontstretch); // this is ignored now
     855        1283 :   if(dontstretch)
     856         169 :     dostretch=false;
     857        1283 :   if(name=="CUBIC") {
     858             :     //cubic is the only switch type that only uses d0 and dmax
     859          15 :     function = PLMD::Tools::make_unique<switchContainers::cubicSwitch>(d0,dmax);
     860             :   } else {
     861        2536 :     REQUIREDPARSE(data,"R_0",r0,errormsg);
     862        1268 :     if(name=="RATIONAL") {
     863         398 :       int nn=6;
     864         398 :       int mm=0;
     865         642 :       CHECKandPARSE(data,"NN",nn,errormsg);
     866         636 :       CHECKandPARSE(data,"MM",mm,errormsg);
     867         796 :       function = switchContainers::rationalFactory(d0,dmax,r0,nn,mm);
     868         870 :     } else if(name=="SMAP") {
     869          15 :       int a=0;
     870          15 :       int b=0;
     871             :       //in the original a and b are "default=0",
     872             :       //but you divide by a and b during the initialization!
     873             :       //better an error message than an UB, so no default
     874          30 :       REQUIREDPARSE(data,"A",a,errormsg);
     875          30 :       REQUIREDPARSE(data,"B",b,errormsg);
     876          15 :       function = PLMD::Tools::make_unique<switchContainers::smapSwitch>(d0,dmax,r0,a,b);
     877         855 :     } else if(name=="Q") {
     878         572 :       double beta = 50.0;  // nm-1
     879         572 :       double lambda = 1.8; // unitless
     880             :       double ref;
     881        1716 :       CHECKandPARSE(data,"BETA",beta,errormsg);
     882        1716 :       CHECKandPARSE(data,"LAMBDA",lambda,errormsg);
     883        1144 :       REQUIREDPARSE(data,"REF",ref,errormsg);
     884             :       //the original error message was not standard
     885             :       // if(!Tools::parse(data,"REF",ref))
     886             :       //   errormsg="REF (reference distaance) is required for native Q";
     887         572 :       function = PLMD::Tools::make_unique<switchContainers::nativeqSwitch>(d0,dmax,r0,beta,lambda,ref);
     888         283 :     } else if(name=="EXP") {
     889          75 :       function = PLMD::Tools::make_unique<switchContainers::exponentialSwitch>(d0,dmax,r0);
     890         208 :     } else if(name=="GAUSSIAN") {
     891         180 :       if ( r0==1.0 && d0==0.0 ) {
     892         114 :         function = PLMD::Tools::make_unique<switchContainers::fastGaussianSwitch>(d0,dmax,r0);
     893             :       } else {
     894          66 :         function = PLMD::Tools::make_unique<switchContainers::gaussianSwitch>(d0,dmax,r0);
     895             :       }
     896          28 :     } else if(name=="TANH") {
     897           4 :       function = PLMD::Tools::make_unique<switchContainers::tanhSwitch>(d0,dmax,r0);
     898          24 :     } else if(name=="COSINUS") {
     899           3 :       function = PLMD::Tools::make_unique<switchContainers::cosinusSwitch>(d0,dmax,r0);
     900          39 :     } else if((name=="MATHEVAL" || name=="CUSTOM")) {
     901             :       std::string func;
     902          40 :       Tools::parse(data,"FUNC",func);
     903          18 :       function = PLMD::Tools::make_unique<switchContainers::leptonSwitch>(d0,dmax,r0,func);
     904             :     } else {
     905           2 :       errormsg="cannot understand switching function type '"+name+"'";
     906             :     }
     907             :   }
     908             : #undef CHECKandPARSE
     909             : #undef REQUIREDPARSE
     910             : 
     911        1281 :   if( !data.empty() ) {
     912             :     errormsg="found the following rogue keywords in switching function input : ";
     913           2 :     for(unsigned i=0; i<data.size(); ++i) errormsg = errormsg + data[i] + " ";
     914             :   }
     915             : 
     916        1281 :   if(dostretch && dmax!=std::numeric_limits<double>::max()) {
     917         142 :     function->setupStretch();
     918             :   }
     919        1283 : }
     920             : 
     921        1204 : std::string SwitchingFunction::description() const {
     922             :   // if this error is necessary, something went wrong in the constructor
     923             :   //  plumed_merror("Unknown switching function type");
     924        1204 :   return function->description();
     925             : }
     926             : 
     927    92146473 : double SwitchingFunction::calculateSqr(double distance2,double&dfunc)const {
     928    92146473 :   return function -> calculateSqr(distance2, dfunc);
     929             : }
     930             : 
     931   127737259 : double SwitchingFunction::calculate(double distance,double&dfunc)const {
     932   127737259 :   plumed_massert(init,"you are trying to use an unset SwitchingFunction");
     933   127737259 :   double result=function->calculate(distance,dfunc);
     934   127737259 :   return result;
     935             : }
     936             : 
     937          74 : void SwitchingFunction::set(const int nn,int mm, const double r0, const double d0) {
     938          74 :   init=true;
     939          74 :   if(mm == 0) {
     940          70 :     mm = 2*nn;
     941             :   }
     942          74 :   double dmax=d0+r0*std::pow(0.00001,1./(nn-mm));
     943         148 :   function = switchContainers::rationalFactory(d0,dmax,r0,nn,mm);
     944          74 :   function->setupStretch();
     945          74 : }
     946             : 
     947          32 : double SwitchingFunction::get_r0() const {
     948          32 :   return function->get_r0();
     949             : }
     950             : 
     951           8 : double SwitchingFunction::get_d0() const {
     952           8 :   return function->get_d0();
     953             : }
     954             : 
     955   536580542 : double SwitchingFunction::get_dmax() const {
     956   536580542 :   return function->get_dmax();
     957             : }
     958             : 
     959    49030642 : double SwitchingFunction::get_dmax2() const {
     960    49030642 :   return function->get_dmax2();
     961             : }
     962             : 
     963             : }// Namespace PLMD

Generated by: LCOV version 1.16