LCOV - code coverage report
Current view: top level - tools - SwitchingFunction.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 392 416 94.2 %
Date: 2025-03-25 09:33:27 Functions: 82 106 77.4 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2012-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "SwitchingFunction.h"
      23             : #include "Tools.h"
      24             : #include "Keywords.h"
      25             : #include "OpenMP.h"
      26             : #include <vector>
      27             : #include <limits>
      28             : #include <algorithm>
      29             : #include <optional>
      30             : 
      31             : /*
      32             : IMPORTANT NOTE FOR DEVELOPERS:
      33             : 
      34             : If you add a new type of switching function in this file please add documentation for your new switching function type in function/LessThan.cpp
      35             : */
      36             : 
      37             : namespace PLMD {
      38             : 
      39             : namespace switchContainers {
      40             : 
      41        1654 : baseSwitch::baseSwitch(double D0,double DMAX, double R0, std::string_view name)
      42        1654 :   : d0(D0),
      43        1654 :     dmax(DMAX),
      44        1654 :     dmax_2([](const double d) {
      45        1654 :   if(d<std::sqrt(std::numeric_limits<double>::max())) {
      46         244 :     return  d*d;
      47             :   } else {
      48             :     return std::numeric_limits<double>::max();
      49             :   }
      50             : }(dmax)),
      51        1654 : invr0(1.0/R0),
      52        1654 : invr0_2(invr0*invr0),
      53        1898 : mytype(name) {}
      54             : 
      55        1654 : baseSwitch::~baseSwitch()=default;
      56             : 
      57   162833080 : double baseSwitch::calculate(const double distance, double& dfunc) const {
      58             :   double res = 0.0;//RVO!
      59   162833080 :   dfunc = 0.0;
      60   162833080 :   if(distance <= dmax) {
      61             :     res = 1.0;
      62   156015549 :     const double rdist = (distance-d0)*invr0;
      63   156015549 :     if(rdist > 0.0) {
      64    59652852 :       res = function(rdist,dfunc);
      65             :       //the following comments came from the original
      66             :       // this is for the chain rule (derivative of rdist):
      67    59652852 :       dfunc *= invr0;
      68             :       // for any future switching functions, be aware that multiplying invr0 is only
      69             :       // correct for functions of rdist = (r-d0)/r0.
      70             : 
      71             :       // this is because calculate() sets dfunc to the derivative divided times the
      72             :       // distance.
      73             :       // (I think this is misleading and I would like to modify it - GB)
      74    59652852 :       dfunc /= distance;
      75             :     }
      76   156015549 :     res=res*stretch+shift;
      77   156015549 :     dfunc*=stretch;
      78             :   }
      79   162833080 :   return res;
      80             : }
      81             : 
      82    31818564 : double baseSwitch::calculateSqr(double distance2,double&dfunc) const {
      83    31818564 :   double res= calculate(std::sqrt(distance2),dfunc);//RVO!
      84    31818564 :   return res;
      85             : }
      86           8 : double baseSwitch::get_d0() const {
      87           8 :   return d0;
      88             : }
      89        1534 : double baseSwitch::get_r0() const {
      90        1534 :   return 1.0/invr0;
      91             : }
      92   536580542 : double baseSwitch::get_dmax() const {
      93   536580542 :   return dmax;
      94             : }
      95    49030642 : double baseSwitch::get_dmax2() const {
      96    49030642 :   return dmax_2;
      97             : }
      98        1502 : std::string baseSwitch::description() const {
      99        1502 :   std::ostringstream ostr;
     100        1502 :   ostr<<get_r0()
     101             :       <<".  Using "
     102             :       << mytype
     103        3004 :       <<" switching function with parameters d0="<< d0
     104        3004 :       << specificDescription();
     105        1502 :   return ostr.str();
     106        1502 : }
     107         150 : std::string baseSwitch::specificDescription() const {
     108         150 :   return "";
     109             : }
     110         216 : void baseSwitch::setupStretch() {
     111         216 :   if(dmax!=std::numeric_limits<double>::max()) {
     112         216 :     stretch=1.0;
     113         216 :     shift=0.0;
     114             :     double dummy;
     115         216 :     double s0=calculate(0.0,dummy);
     116         216 :     double sd=calculate(dmax,dummy);
     117         216 :     stretch=1.0/(s0-sd);
     118         216 :     shift=-sd*stretch;
     119             :   }
     120         216 : }
     121           0 : void baseSwitch::removeStretch() {
     122           0 :   stretch=1.0;
     123           0 :   shift=0.0;
     124           0 : }
     125             : template<int N, std::enable_if_t< (N >0), bool> = true, std::enable_if_t< (N %2 == 0), bool> = true>
     126             :     class fixedRational :public baseSwitch {
     127         263 :   std::string specificDescription() const override {
     128         263 :     std::ostringstream ostr;
     129         263 :     ostr << " nn=" << N << " mm=" <<N*2;
     130         263 :     return ostr.str();
     131         263 :   }
     132             : public:
     133         282 :   fixedRational(double D0,double DMAX, double R0)
     134         282 :     :baseSwitch(D0,DMAX,R0,"rational") {}
     135             : 
     136             :   template <int POW>
     137        1382 :   static inline double doRational(const double rdist, double&dfunc, double result=0.0) {
     138             :     const double rNdist=Tools::fastpow<POW-1>(rdist);
     139    27485030 :     result=1.0/(1.0+rNdist*rdist);
     140    27485030 :     dfunc = -POW*rNdist*result*result;
     141        1382 :     return result;
     142             :   }
     143             : 
     144    16154932 :   inline double function(double rdist,double&dfunc) const override {
     145             :     //preRes and preDfunc are passed already set
     146        1382 :     dfunc=0.0;
     147        1382 :     double result = doRational<N>(rdist,dfunc);
     148    16154932 :     return result;
     149             :   }
     150             : 
     151    11475850 :   double calculateSqr(double distance2,double&dfunc) const override {
     152             :     double result=0.0;
     153    11475850 :     dfunc=0.0;
     154    11475850 :     if(distance2 <= dmax_2) {
     155    11330098 :       const double rdist = distance2*invr0_2;
     156             :       result = doRational<N/2>(rdist,dfunc);
     157    11330098 :       dfunc*=2*invr0_2;
     158             :       // stretch:
     159    11330098 :       result=result*stretch+shift;
     160    11330098 :       dfunc*=stretch;
     161             :     }
     162    11475850 :     return result;
     163             : 
     164             :   }
     165             : };
     166             : 
     167             : //these enums are useful for clarifying the settings in the factory
     168             : //and the code is autodocumented ;)
     169             : enum class rationalPow:bool {standard, fast};
     170             : enum class rationalForm:bool {standard, simplified};
     171             : 
     172             : template<rationalPow isFast, rationalForm nis2m>
     173             : class rational : public baseSwitch {
     174             : protected:
     175             :   const int nn=6;
     176             :   const int mm=12;
     177             :   const double preRes;
     178             :   const double preDfunc;
     179             :   const double preSecDev;
     180             :   const int nnf;
     181             :   const int mmf;
     182             :   const double preDfuncF;
     183             :   const double preSecDevF;
     184             :   //I am using PLMD::epsilon to be certain to call the one defined in Tools.h
     185             :   static constexpr double moreThanOne=1.0+5.0e10*PLMD::epsilon;
     186             :   static constexpr double lessThanOne=1.0-5.0e10*PLMD::epsilon;
     187             : 
     188         177 :   std::string specificDescription() const override {
     189         177 :     std::ostringstream ostr;
     190         177 :     ostr << " nn=" << nn << " mm=" <<mm;
     191         177 :     return ostr.str();
     192         177 :   }
     193             : public:
     194         196 :   rational(double D0,double DMAX, double R0, int N, int M)
     195             :     :baseSwitch(D0,DMAX,R0,"rational"),
     196         196 :      nn(N),
     197         196 :      mm([](int m,int n) {
     198         196 :     if (m==0) {
     199          89 :       return n*2;
     200             :     } else {
     201             :       return m;
     202             :     }
     203             :   }(M,N)),
     204         196 :   preRes(static_cast<double>(nn)/mm),
     205         196 :   preDfunc(0.5*nn*(nn-mm)/static_cast<double>(mm)),
     206             :   //wolfram <3:lim_(x->1) d^2/(dx^2) (1 - x^N)/(1 - x^M) = (N (M^2 - 3 M (-1 + N) + N (-3 + 2 N)))/(6 M)
     207         196 :   preSecDev ((nn * (mm * mm - 3.0* mm * (-1 + nn ) + nn *(-3 + 2* nn )))/(6.0* mm )),
     208         196 :   nnf(nn/2),
     209         196 :   mmf(mm/2),
     210         196 :   preDfuncF(0.5*nnf*(nnf-mmf)/static_cast<double>(mmf)),
     211         196 :   preSecDevF((nnf* (mmf*mmf - 3.0* mmf* (-1 + nnf) + nnf*(-3 + 2* nnf)))/(6.0* mmf)) {}
     212             : 
     213    18240673 :   static inline double doRational(const double rdist, double&dfunc,double secDev, const int N,
     214             :                                   const int M,double result=0.0) {
     215             :     //the result and dfunc are assigned in the drivers for doRational
     216             :     //if(rdist>(1.0-100.0*epsilon) && rdist<(1.0+100.0*epsilon)) {
     217             :     //result=preRes;
     218             :     //dfunc=preDfunc;
     219             :     //} else {
     220             :     if constexpr (nis2m==rationalForm::simplified) {
     221     2113979 :       const double rNdist=Tools::fastpow(rdist,N-1);
     222     2113979 :       result=1.0/(1.0+rNdist*rdist);
     223     2113979 :       dfunc = -N*rNdist*result*result;
     224             :     } else {
     225    16126694 :       if(!((rdist > lessThanOne) && (rdist < moreThanOne))) {
     226    16126682 :         const double rNdist=Tools::fastpow(rdist,N-1);
     227    16126682 :         const double rMdist=Tools::fastpow(rdist,M-1);
     228    16126682 :         const double num = 1.0-rNdist*rdist;
     229    16126682 :         const double iden = 1.0/(1.0-rMdist*rdist);
     230    16126682 :         result = num*iden;
     231    16126682 :         dfunc = ((M*result*rMdist)-(N*rNdist))*iden;
     232    16126682 :       } else {
     233             :         //here I imply that the correct initialized are being passed to doRational
     234          12 :         const double x =(rdist-1.0);
     235          12 :         result = result+ x * ( dfunc + 0.5 * x * secDev);
     236          12 :         dfunc  = dfunc + x * secDev;
     237             :       }
     238             :     }
     239    18240673 :     return result;
     240             :   }
     241    18240621 :   inline double function(double rdist,double&dfunc) const override {
     242             :     //preRes and preDfunc are passed already set
     243    18240621 :     dfunc=preDfunc;
     244    18240621 :     double result = doRational(rdist,dfunc,preSecDev,nn,mm,preRes);
     245    18240621 :     return result;
     246             :   }
     247             : 
     248     3408359 :   double calculateSqr(double distance2,double&dfunc) const override {
     249             :     if constexpr (isFast==rationalPow::fast) {
     250             :       double result=0.0;
     251          60 :       dfunc=0.0;
     252          60 :       if(distance2 <= dmax_2) {
     253          52 :         const double rdist = distance2*invr0_2;
     254          52 :         dfunc=preDfuncF;
     255          52 :         result = doRational(rdist,dfunc,preSecDevF,nnf,mmf,preRes);
     256          52 :         dfunc*=2*invr0_2;
     257             : // stretch:
     258          52 :         result=result*stretch+shift;
     259          52 :         dfunc*=stretch;
     260             :       }
     261          60 :       return result;
     262             :     } else {
     263     3408299 :       double res= calculate(std::sqrt(distance2),dfunc);//RVO!
     264     3408299 :       return res;
     265             :     }
     266             :   }
     267             : };
     268             : 
     269             : 
     270             : template<int EXP,std::enable_if_t< (EXP %2 == 0), bool> = true>
     271        1079 : std::optional<std::unique_ptr<baseSwitch>> fixedRationalFactory(double D0,double DMAX, double R0, int N) {
     272             :   if constexpr (EXP == 0) {
     273           0 :     return  std::nullopt;
     274             :   } else {
     275        1079 :     if (N==EXP) {
     276         282 :       return PLMD::Tools::make_unique<switchContainers::fixedRational<EXP>>(D0,DMAX,R0);
     277             :     } else {
     278         797 :       return fixedRationalFactory<EXP-2>(D0,DMAX,R0,N);
     279             :     }
     280             :   }
     281             : }
     282             : 
     283             : std::unique_ptr<baseSwitch>
     284         478 : rationalFactory(double D0,double DMAX, double R0, int N, int M) {
     285         478 :   bool fast = N%2==0 && M%2==0 && D0==0.0;
     286             :   //if (M==0) M will automatically became 2*NN
     287             :   constexpr int highestPrecompiledPower=12;
     288             :   //precompiled rational
     289         478 :   if(((2*N)==M || M == 0) && fast && N<=highestPrecompiledPower) {
     290         282 :     auto tmp = fixedRationalFactory<highestPrecompiledPower>(D0,DMAX,R0,N);
     291         282 :     if(tmp) {
     292             :       return std::move(*tmp);
     293             :     }
     294             :     //else continue with the at runtime implementation
     295             :   }
     296             :   //template<bool isFast, bool n2m>
     297             :   //class rational : public baseSwitch
     298         196 :   if(2*N==M || M == 0) {
     299         132 :     if(fast) {
     300             :       //fast rational
     301             :       return PLMD::Tools::make_unique<switchContainers::rational<
     302           0 :              rationalPow::fast,rationalForm::simplified>>(D0,DMAX,R0,N,M);
     303             :     }
     304             :     return PLMD::Tools::make_unique<switchContainers::rational<
     305         132 :            rationalPow::standard,rationalForm::simplified>>(D0,DMAX,R0,N,M);
     306             :   }
     307          64 :   if(fast) {
     308             :     //fast rational
     309             :     return PLMD::Tools::make_unique<switchContainers::rational<
     310          61 :            rationalPow::fast,rationalForm::standard>>(D0,DMAX,R0,N,M);
     311             :   }
     312             :   return PLMD::Tools::make_unique<switchContainers::rational<
     313           3 :          rationalPow::standard,rationalForm::standard>>(D0,DMAX,R0,N,M);
     314             : }
     315             : //function =
     316             : 
     317             : class exponentialSwitch: public baseSwitch {
     318             : public:
     319          75 :   exponentialSwitch(double D0, double DMAX, double R0)
     320          75 :     :baseSwitch(D0,DMAX,R0,"exponential") {}
     321             : protected:
     322     2404247 :   inline double function(const double rdist,double&dfunc) const override {
     323     2404247 :     double result = std::exp(-rdist);
     324     2404247 :     dfunc=-result;
     325     2404247 :     return result;
     326             :   }
     327             : };
     328             : 
     329             : class gaussianSwitch: public baseSwitch {
     330             : public:
     331          66 :   gaussianSwitch(double D0, double DMAX, double R0)
     332          66 :     :baseSwitch(D0,DMAX,R0,"gaussian") {}
     333             : protected:
     334      279640 :   inline double function(const double rdist,double&dfunc) const override {
     335      279640 :     double result = std::exp(-0.5*rdist*rdist);
     336      279640 :     dfunc=-rdist*result;
     337      279640 :     return result;
     338             :   }
     339             : };
     340             : 
     341             : class fastGaussianSwitch: public baseSwitch {
     342             : public:
     343         114 :   fastGaussianSwitch(double /*D0*/, double DMAX, double /*R0*/)
     344         114 :     :baseSwitch(0.0,DMAX,1.0,"fastgaussian") {}
     345             : protected:
     346           1 :   inline double function(const double rdist,double&dfunc) const override {
     347           1 :     double result = std::exp(-0.5*rdist*rdist);
     348           1 :     dfunc=-rdist*result;
     349           1 :     return result;
     350             :   }
     351    38317812 :   inline double calculateSqr(double distance2,double&dfunc) const override {
     352             :     double result = 0.0;
     353    38317812 :     if(distance2>dmax_2) {
     354           8 :       dfunc=0.0;
     355             :     } else  {
     356    38317804 :       result = exp(-0.5*distance2);
     357    38317804 :       dfunc = -result;
     358             :       // stretch:
     359    38317804 :       result=result*stretch+shift;
     360    38317804 :       dfunc*=stretch;
     361             :     }
     362    38317812 :     return result;
     363             :   }
     364             : };
     365             : 
     366             : class smapSwitch: public baseSwitch {
     367             :   const int a=0;
     368             :   const int b=0;
     369             :   const double c=0.0;
     370             :   const double d=0.0;
     371             : protected:
     372          15 :   std::string specificDescription() const override {
     373          15 :     std::ostringstream ostr;
     374          15 :     ostr<<" a="<<a<<" b="<<b;
     375          15 :     return ostr.str();
     376          15 :   }
     377             : public:
     378          15 :   smapSwitch(double D0, double DMAX, double R0, int A, int B)
     379          15 :     :baseSwitch(D0,DMAX,R0,"smap"),
     380          15 :      a(A),
     381          15 :      b(B),
     382          15 :      c(std::pow(2., static_cast<double>(a)/static_cast<double>(b) ) - 1.0),
     383          15 :      d(-static_cast<double>(b) / static_cast<double>(a)) {}
     384             : protected:
     385    21911326 :   inline double function(const double rdist,double&dfunc) const override {
     386             : 
     387    21911326 :     const double sx=c*Tools::fastpow( rdist, a );
     388    21911326 :     double result=std::pow( 1.0 + sx, d );
     389    21911326 :     dfunc=-b*sx/rdist*result/(1.0+sx);
     390    21911326 :     return result;
     391             :   }
     392             : };
     393             : 
     394             : class cubicSwitch: public baseSwitch {
     395             : protected:
     396          15 :   std::string specificDescription() const override {
     397          15 :     std::ostringstream ostr;
     398          15 :     ostr<<" dmax="<<dmax;
     399          15 :     return ostr.str();
     400          15 :   }
     401             : public:
     402          15 :   cubicSwitch(double D0, double DMAX)
     403          15 :     :baseSwitch(D0,DMAX,DMAX-D0,"cubic") {
     404             :     //this operation should be already done!!
     405             :     // R0 = dmax - d0;
     406             :     // invr0 = 1/R0;
     407             :     // invr0_2 = invr0*invr0;
     408          15 :   }
     409          15 :   ~cubicSwitch()=default;
     410             : protected:
     411      127256 :   inline double function(const double rdist,double&dfunc) const override {
     412      127256 :     const double tmp1 = rdist - 1.0;
     413      127256 :     const double tmp2 = 1.0+2.0*rdist;
     414             :     //double result = tmp1*tmp1*tmp2;
     415      127256 :     dfunc = 2*tmp1*tmp2 + 2*tmp1*tmp1;
     416      127256 :     return tmp1*tmp1*tmp2;
     417             :   }
     418             : };
     419             : 
     420             : class tanhSwitch: public baseSwitch {
     421             : public:
     422           4 :   tanhSwitch(double D0, double DMAX, double R0)
     423           4 :     :baseSwitch(D0,DMAX,R0,"tanh") {}
     424             : protected:
     425       12718 :   inline double function(const double rdist,double&dfunc) const override {
     426       12718 :     const double tmp1 = std::tanh(rdist);
     427             :     //was dfunc=-(1-tmp1*tmp1);
     428       12718 :     dfunc = tmp1 * tmp1 - 1.0;
     429             :     //return result;
     430       12718 :     return 1.0 - tmp1;
     431             :   }
     432             : };
     433             : 
     434             : class cosinusSwitch: public baseSwitch {
     435             : public:
     436           3 :   cosinusSwitch(double D0, double DMAX, double R0)
     437           3 :     :baseSwitch(D0,DMAX,R0,"cosinus") {}
     438             : protected:
     439      522111 :   inline double function(const double rdist,double&dfunc) const override {
     440             :     double result = 0.0;
     441      522111 :     dfunc=0.0;
     442      522111 :     if(rdist<=1.0) {
     443             : // rdist = (r-r1)/(r2-r1) ; 0.0<=rdist<=1.0 if r1 <= r <=r2; (r2-r1)/(r2-r1)=1
     444      227012 :       double rdistPI = rdist * PLMD::pi;
     445      227012 :       result = 0.5 * (std::cos ( rdistPI ) + 1.0);
     446      227012 :       dfunc = -0.5 * PLMD::pi * std::sin ( rdistPI ) * invr0;
     447             :     }
     448      522111 :     return result;
     449             :   }
     450             : };
     451             : 
     452             : class nativeqSwitch: public baseSwitch {
     453             :   double beta = 50.0;  // nm-1
     454             :   double lambda = 1.8; // unitless
     455             :   double ref=0.0;
     456             : protected:
     457         864 :   std::string specificDescription() const override {
     458         864 :     std::ostringstream ostr;
     459         864 :     ostr<<" beta="<<beta<<" lambda="<<lambda<<" ref="<<ref;
     460         864 :     return ostr.str();
     461         864 :   }
     462           0 :   inline double function(const double rdist,double&dfunc) const override {
     463           0 :     return 0.0;
     464             :   }
     465             : public:
     466             :   nativeqSwitch(double D0, double DMAX, double R0, double BETA, double LAMBDA,double REF)
     467         864 :     :  baseSwitch(D0,DMAX,R0,"nativeq"),beta(BETA),lambda(LAMBDA),ref(REF) {}
     468      292924 :   double calculate(const double distance, double& dfunc) const override {
     469             :     double res = 0.0;//RVO!
     470      292924 :     dfunc = 0.0;
     471      292924 :     if(distance<=dmax) {
     472             :       res = 1.0;
     473      292916 :       if(distance > d0) {
     474      292909 :         const double rdist = beta*(distance - lambda * ref);
     475      292909 :         double exprdist=std::exp(rdist);
     476      292909 :         res=1.0/(1.0+exprdist);
     477             :         /*2.9
     478             :         //need to see if this (5op+assign)
     479             :         //double exprmdist=1.0 + exprdist;
     480             :         //dfunc = - (beta *exprdist)/(exprmdist*exprmdist);
     481             :         //or this (5op but 2 divisions) is faster
     482             :         dfunc = - beta /(exprdist+ 2.0 +1.0/exprdist);
     483             :         //this cames from - beta * exprdist/(exprdist*exprdist+ 2.0 *exprdist +1.0)
     484             :         //dfunc *= invr0;
     485             :         dfunc /= distance;
     486             :         */
     487             :         //2.10
     488      292909 :         dfunc = - beta /(exprdist+ 2.0 +1.0/exprdist) /distance;
     489             : 
     490      292909 :         dfunc*=stretch;
     491             :       }
     492      292916 :       res=res*stretch+shift;
     493             :     }
     494      292924 :     return res;
     495             :   }
     496             : };
     497             : 
     498             : class leptonSwitch: public baseSwitch {
     499             : /// Lepton expression.
     500          62 :   class funcAndDeriv {
     501             :     lepton::CompiledExpression expression;
     502             :     lepton::CompiledExpression deriv;
     503             :     double* varRef=nullptr;
     504             :     double* varDevRef=nullptr;
     505             :   public:
     506          20 :     funcAndDeriv(const std::string &func) {
     507          20 :       lepton::ParsedExpression pe=lepton::Parser::parse(func).optimize(lepton::Constants());
     508          20 :       expression=pe.createCompiledExpression();
     509          22 :       std::string arg="x";
     510             : 
     511             :       {
     512          20 :         auto vars=expression.getVariables();
     513          20 :         bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
     514          20 :         bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     515             : 
     516          20 :         if(found_x2) {
     517             :           arg="x2";
     518             :         }
     519          20 :         if (vars.size()==0) {
     520             : // this is necessary since in some cases lepton thinks a variable is not present even though it is present
     521             : // e.g. func=0*x
     522           0 :           varRef=nullptr;
     523          20 :         } else if(vars.size()==1 && (found_x || found_x2)) {
     524          18 :           varRef=&expression.getVariableReference(arg);
     525             :         } else {
     526           4 :           plumed_error()
     527             :               <<"Please declare a function with only ONE argument that can only be x or x2. Your function is: "
     528           4 :               << func;
     529             :         }
     530             :       }
     531             : 
     532          38 :       lepton::ParsedExpression ped=lepton::Parser::parse(func).differentiate(arg).optimize(lepton::Constants());
     533          18 :       deriv=ped.createCompiledExpression();
     534             :       {
     535          18 :         auto vars=expression.getVariables();
     536          18 :         if (vars.size()==0) {
     537           0 :           varDevRef=nullptr;
     538             :         } else {
     539          18 :           varDevRef=&deriv.getVariableReference(arg);
     540             :         }
     541             :       }
     542             : 
     543          22 :     }
     544          44 :     funcAndDeriv (const funcAndDeriv& other):
     545          44 :       expression(other.expression),
     546          44 :       deriv(other.deriv) {
     547          44 :       std::string arg="x";
     548             : 
     549             :       {
     550          44 :         auto vars=expression.getVariables();
     551          44 :         bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
     552          44 :         bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     553             : 
     554          44 :         if(found_x2) {
     555             :           arg="x2";
     556             :         }
     557          44 :         if (vars.size()==0) {
     558           0 :           varRef=nullptr;
     559          44 :         } else if(vars.size()==1 && (found_x || found_x2)) {
     560          44 :           varRef=&expression.getVariableReference(arg);
     561             :         }// UB: I assume that the function is already correct
     562             :       }
     563             : 
     564             :       {
     565          44 :         auto vars=expression.getVariables();
     566          44 :         if (vars.size()==0) {
     567           0 :           varDevRef=nullptr;
     568             :         } else {
     569          44 :           varDevRef=&deriv.getVariableReference(arg);
     570             :         }
     571             :       }
     572          44 :     }
     573             : 
     574             :     funcAndDeriv& operator= (const funcAndDeriv& other) {
     575             :       if(this != &other) {
     576             :         expression = other.expression;
     577             :         deriv = other.deriv;
     578             :         std::string arg="x";
     579             : 
     580             :         {
     581             :           auto vars=expression.getVariables();
     582             :           bool found_x=std::find(vars.begin(),vars.end(),"x")!=vars.end();
     583             :           bool found_x2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     584             : 
     585             :           if(found_x2) {
     586             :             arg="x2";
     587             :           }
     588             :           if (vars.size()==0) {
     589             :             varRef=nullptr;
     590             :           } else if(vars.size()==1 && (found_x || found_x2)) {
     591             :             varRef=&expression.getVariableReference(arg);
     592             :           }// UB: I assume that the function is already correct
     593             :         }
     594             : 
     595             :         {
     596             :           auto vars=expression.getVariables();
     597             :           if (vars.size()==0) {
     598             :             varDevRef=nullptr;
     599             :           } else {
     600             :             varDevRef=&deriv.getVariableReference(arg);
     601             :           }
     602             :         }
     603             :       }
     604             :       return *this;
     605             :     }
     606             : 
     607     6515285 :     std::pair<double,double> operator()(double const x) const {
     608             :       //FAQ: why this works? this thing is const and you are modifying things!
     609             :       //Actually I am modifying something that is pointed at, not my pointers,
     610             :       //so I am not mutating the state of this!
     611     6515285 :       if(varRef) {
     612     6515285 :         *varRef=x;
     613             :       }
     614     6515285 :       if(varDevRef) {
     615     6515285 :         *varDevRef=x;
     616             :       }
     617             :       return std::make_pair(
     618     6515285 :                expression.evaluate(),
     619     6515285 :                deriv.evaluate());
     620             :     }
     621             : 
     622             :     auto& getVariables() const {
     623          18 :       return expression.getVariables();
     624             :     }
     625             :     auto& getVariables_derivative() const {
     626             :       return deriv.getVariables();
     627             :     }
     628             :   };
     629             :   /// Function for lepton
     630             :   std::string lepton_func;
     631             :   /// \warning Since lepton::CompiledExpression is mutable, a vector is necessary for multithreading!
     632             :   std::vector <funcAndDeriv> expressions{};
     633             :   /// Set to true if lepton only uses x2
     634             :   bool leptonx2=false;
     635             : protected:
     636          18 :   std::string specificDescription() const override {
     637          18 :     std::ostringstream ostr;
     638          18 :     ostr<<" func=" << lepton_func;
     639          18 :     return ostr.str();
     640          18 :   }
     641           0 :   inline double function(const double,double&) const override {
     642           0 :     return 0.0;
     643             :   }
     644             : public:
     645          20 :   leptonSwitch(double D0, double DMAX, double R0, const std::string & func)
     646          20 :     :baseSwitch(D0,DMAX,R0,"lepton"),
     647          20 :      lepton_func(func),
     648          38 :      expressions  (OpenMP::getNumThreads(), lepton_func) {
     649             :     //this is a bit odd, but it works
     650             :     auto vars=expressions[0].getVariables();
     651          18 :     leptonx2=std::find(vars.begin(),vars.end(),"x2")!=vars.end();
     652          20 :   }
     653             : 
     654     5877796 :   double calculate(const double distance,double&dfunc) const override {
     655     5877796 :     double res = 0.0;//RVO!
     656     5877796 :     dfunc = 0.0;
     657     5877796 :     if(leptonx2) {
     658           2 :       res= calculateSqr(distance*distance,dfunc);
     659             :     } else {
     660     5877794 :       if(distance<=dmax) {
     661     5573105 :         res = 1.0;
     662     5573105 :         const double rdist = (distance-d0)*invr0;
     663     5573105 :         if(rdist > 0.0) {
     664     5267183 :           const unsigned t=OpenMP::getThreadNum();
     665     5267183 :           plumed_assert(t<expressions.size());
     666     5267183 :           std::tie(res,dfunc) = expressions[t](rdist);
     667     5267183 :           dfunc *= invr0;
     668     5267183 :           dfunc /= distance;
     669             :         }
     670     5573105 :         res=res*stretch+shift;
     671     5573105 :         dfunc*=stretch;
     672             :       }
     673             :     }
     674     5877796 :     return res;
     675             :   }
     676             : 
     677     7125890 :   double calculateSqr(const double distance2,double&dfunc) const override {
     678     7125890 :     double result =0.0;
     679     7125890 :     dfunc=0.0;
     680     7125890 :     if(leptonx2) {
     681     1248110 :       if(distance2<=dmax_2) {
     682     1248102 :         const unsigned t=OpenMP::getThreadNum();
     683     1248102 :         const double rdist_2 = distance2*invr0_2;
     684     1248102 :         plumed_assert(t<expressions.size());
     685     1248102 :         std::tie(result,dfunc) = expressions[t](rdist_2);
     686             :         // chain rule:
     687     1248102 :         dfunc*=2*invr0_2;
     688             :         // stretch:
     689     1248102 :         result=result*stretch+shift;
     690     1248102 :         dfunc*=stretch;
     691             :       }
     692             :     } else {
     693     5877780 :       result = calculate(std::sqrt(distance2),dfunc);
     694             :     }
     695     7125890 :     return result;
     696             :   }
     697             : };
     698             : } // namespace switchContainers
     699             : 
     700           0 : void SwitchingFunction::registerKeywords( Keywords& keys ) {
     701           0 :   keys.add("compulsory","R_0","the value of R_0 in the switching function");
     702           0 :   keys.add("compulsory","D_0","0.0","the value of D_0 in the switching function");
     703           0 :   keys.add("optional","D_MAX","the value at which the switching function can be assumed equal to zero");
     704           0 :   keys.add("compulsory","NN","6","the value of n in the switching function (only needed for TYPE=RATIONAL)");
     705           0 :   keys.add("compulsory","MM","0","the value of m in the switching function (only needed for TYPE=RATIONAL); 0 implies 2*NN");
     706           0 :   keys.add("compulsory","A","the value of a in the switching function (only needed for TYPE=SMAP)");
     707           0 :   keys.add("compulsory","B","the value of b in the switching function (only needed for TYPE=SMAP)");
     708           0 : }
     709             : 
     710        1581 : void SwitchingFunction::set(const std::string & definition,std::string& errormsg) {
     711        1581 :   std::vector<std::string> data=Tools::getWords(definition);
     712             : #define CHECKandPARSE(datastring,keyword,variable,errormsg) \
     713             :   if(Tools::findKeyword(datastring,keyword) && !Tools::parse(datastring,keyword,variable))\
     714             :     errormsg="could not parse " keyword; //adiacent strings are automagically concatenated
     715             : #define REQUIREDPARSE(datastring,keyword,variable,errormsg) \
     716             :   if(!Tools::parse(datastring,keyword,variable))\
     717             :     errormsg=keyword " is required for " + name ; //adiacent strings are automagically concatenated
     718             : 
     719        1581 :   if( data.size()<1 ) {
     720             :     errormsg="missing all input for switching function";
     721             :     return;
     722             :   }
     723        1581 :   std::string name=data[0];
     724             :   data.erase(data.begin());
     725        1581 :   double r0=0.0;
     726        1581 :   double d0=0.0;
     727        1581 :   double dmax=std::numeric_limits<double>::max();
     728        1581 :   init=true;
     729        2009 :   CHECKandPARSE(data,"D_0",d0,errormsg);
     730        1921 :   CHECKandPARSE(data,"D_MAX",dmax,errormsg);
     731             : 
     732        1581 :   bool dostretch=false;
     733        1581 :   Tools::parseFlag(data,"STRETCH",dostretch); // this is ignored now
     734        1581 :   dostretch=true;
     735        1581 :   bool dontstretch=false;
     736        1581 :   Tools::parseFlag(data,"NOSTRETCH",dontstretch); // this is ignored now
     737        1581 :   if(dontstretch) {
     738         175 :     dostretch=false;
     739             :   }
     740        1581 :   if(name=="CUBIC") {
     741             :     //cubic is the only switch type that only uses d0 and dmax
     742          15 :     function = PLMD::Tools::make_unique<switchContainers::cubicSwitch>(d0,dmax);
     743             :   } else {
     744        3132 :     REQUIREDPARSE(data,"R_0",r0,errormsg);
     745        1566 :     if(name=="RATIONAL") {
     746         404 :       int nn=6;
     747         404 :       int mm=0;
     748         660 :       CHECKandPARSE(data,"NN",nn,errormsg);
     749         654 :       CHECKandPARSE(data,"MM",mm,errormsg);
     750         808 :       function = switchContainers::rationalFactory(d0,dmax,r0,nn,mm);
     751        1162 :     } else if(name=="SMAP") {
     752          15 :       int a=0;
     753          15 :       int b=0;
     754             :       //in the original a and b are "default=0",
     755             :       //but you divide by a and b during the initialization!
     756             :       //better an error message than an UB, so no default
     757          30 :       REQUIREDPARSE(data,"A",a,errormsg);
     758          30 :       REQUIREDPARSE(data,"B",b,errormsg);
     759          15 :       function = PLMD::Tools::make_unique<switchContainers::smapSwitch>(d0,dmax,r0,a,b);
     760        1147 :     } else if(name=="Q") {
     761         864 :       double beta = 50.0;  // nm-1
     762         864 :       double lambda = 1.8; // unitless
     763             :       double ref;
     764        2592 :       CHECKandPARSE(data,"BETA",beta,errormsg);
     765        2592 :       CHECKandPARSE(data,"LAMBDA",lambda,errormsg);
     766        1728 :       REQUIREDPARSE(data,"REF",ref,errormsg);
     767             :       //the original error message was not standard
     768             :       // if(!Tools::parse(data,"REF",ref))
     769             :       //   errormsg="REF (reference distaance) is required for native Q";
     770         864 :       function = PLMD::Tools::make_unique<switchContainers::nativeqSwitch>(d0,dmax,r0,beta,lambda,ref);
     771         283 :     } else if(name=="EXP") {
     772          75 :       function = PLMD::Tools::make_unique<switchContainers::exponentialSwitch>(d0,dmax,r0);
     773         208 :     } else if(name=="GAUSSIAN") {
     774         180 :       if ( r0==1.0 && d0==0.0 ) {
     775         114 :         function = PLMD::Tools::make_unique<switchContainers::fastGaussianSwitch>(d0,dmax,r0);
     776             :       } else {
     777          66 :         function = PLMD::Tools::make_unique<switchContainers::gaussianSwitch>(d0,dmax,r0);
     778             :       }
     779          28 :     } else if(name=="TANH") {
     780           4 :       function = PLMD::Tools::make_unique<switchContainers::tanhSwitch>(d0,dmax,r0);
     781          24 :     } else if(name=="COSINUS") {
     782           3 :       function = PLMD::Tools::make_unique<switchContainers::cosinusSwitch>(d0,dmax,r0);
     783          39 :     } else if((name=="MATHEVAL" || name=="CUSTOM")) {
     784             :       std::string func;
     785          40 :       Tools::parse(data,"FUNC",func);
     786          18 :       function = PLMD::Tools::make_unique<switchContainers::leptonSwitch>(d0,dmax,r0,func);
     787             :     } else {
     788           2 :       errormsg="cannot understand switching function type '"+name+"'";
     789             :     }
     790             :   }
     791             : #undef CHECKandPARSE
     792             : #undef REQUIREDPARSE
     793             : 
     794        1579 :   if( !data.empty() ) {
     795             :     errormsg="found the following rogue keywords in switching function input : ";
     796           0 :     for(unsigned i=0; i<data.size(); ++i) {
     797           2 :       errormsg = errormsg + data[i] + " ";
     798             :     }
     799             :   }
     800             : 
     801        1579 :   if(dostretch && dmax!=std::numeric_limits<double>::max()) {
     802         142 :     function->setupStretch();
     803             :   }
     804        1581 : }
     805             : 
     806        1502 : std::string SwitchingFunction::description() const {
     807             :   // if this error is necessary, something went wrong in the constructor
     808             :   //  plumed_merror("Unknown switching function type");
     809        1502 :   return function->description();
     810             : }
     811             : 
     812    92146473 : double SwitchingFunction::calculateSqr(double distance2,double&dfunc)const {
     813    92146473 :   return function -> calculateSqr(distance2, dfunc);
     814             : }
     815             : 
     816   127898725 : double SwitchingFunction::calculate(double distance,double&dfunc)const {
     817   127898725 :   plumed_massert(init,"you are trying to use an unset SwitchingFunction");
     818   127898725 :   double result=function->calculate(distance,dfunc);
     819   127898725 :   return result;
     820             : }
     821             : 
     822          74 : void SwitchingFunction::set(const int nn,int mm, const double r0, const double d0) {
     823          74 :   init=true;
     824          74 :   if(mm == 0) {
     825          70 :     mm = 2*nn;
     826             :   }
     827          74 :   double dmax=d0+r0*std::pow(0.00001,1./(nn-mm));
     828         148 :   function = switchContainers::rationalFactory(d0,dmax,r0,nn,mm);
     829          74 :   function->setupStretch();
     830          74 : }
     831             : 
     832          32 : double SwitchingFunction::get_r0() const {
     833          32 :   return function->get_r0();
     834             : }
     835             : 
     836           8 : double SwitchingFunction::get_d0() const {
     837           8 :   return function->get_d0();
     838             : }
     839             : 
     840   536580542 : double SwitchingFunction::get_dmax() const {
     841   536580542 :   return function->get_dmax();
     842             : }
     843             : 
     844    49030642 : double SwitchingFunction::get_dmax2() const {
     845    49030642 :   return function->get_dmax2();
     846             : }
     847             : 
     848             : }// Namespace PLMD

Generated by: LCOV version 1.16