LCOV - code coverage report
Current view: top level - gridtools - KDE.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 379 461 82.2 %
Date: 2025-03-25 09:33:27 Functions: 18 19 94.7 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2012-2017 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "ActionWithGrid.h"
      23             : #include "core/PlumedMain.h"
      24             : #include "core/ActionSet.h"
      25             : #include "core/ActionRegister.h"
      26             : #include "core/PbcAction.h"
      27             : #include "tools/HistogramBead.h"
      28             : #include "tools/SwitchingFunction.h"
      29             : #include "tools/Matrix.h"
      30             : 
      31             : //+PLUMEDOC ANALYSIS KDE
      32             : /*
      33             : Create a histogram from the input scalar/vector/matrix using KDE
      34             : 
      35             : \par Examples
      36             : 
      37             : 
      38             : */
      39             : //+ENDPLUMEDOC
      40             : 
      41             : //+PLUMEDOC ANALYSIS SPHERICAL_KDE
      42             : /*
      43             : Create a histogram from the input scalar/vector/matrix using SPHERICAL_KDE
      44             : 
      45             : \par Examples
      46             : 
      47             : 
      48             : */
      49             : //+ENDPLUMEDOC
      50             : 
      51             : namespace PLMD {
      52             : namespace gridtools {
      53             : 
      54             : class KDE : public ActionWithGrid {
      55             : private:
      56             :   double hh;
      57             :   bool hasheight;
      58             :   bool ignore_out_of_bounds, fixed_width;
      59             :   double dp2cutoff;
      60             :   std::string kerneltype;
      61             :   GridCoordinatesObject gridobject;
      62             :   std::vector<std::string> gmin, gmax;
      63             :   std::vector<double> center;
      64             :   std::vector<double> gspacing;
      65             :   unsigned num_neigh, bwargno;
      66             :   std::vector<Value> grid_diff_value;
      67             :   std::vector<unsigned> nbin, nneigh, neighbors;
      68             :   unsigned numberOfKernels, nbins;
      69             :   SwitchingFunction switchingFunction;
      70             :   double von_misses_concentration, von_misses_norm;
      71             :   void setupNeighborsVector();
      72             :   void retrieveArgumentsAndHeight( const MultiValue& myvals, std::vector<double>& args, double& height ) const ;
      73             :   double evaluateKernel( const std::vector<double>& gpoint, const std::vector<double>& args, const double& height, std::vector<double>& der ) const ;
      74             :   void setupHistogramBeads( std::vector<HistogramBead>& bead ) const ;
      75             :   double evaluateBeadValue( std::vector<HistogramBead>& bead, const std::vector<double>& gpoint, const std::vector<double>& args, const double& height, std::vector<double>& der ) const ;
      76             : public:
      77             :   static void registerKeywords( Keywords& keys );
      78             :   explicit KDE(const ActionOptions&ao);
      79             :   std::vector<std::string> getGridCoordinateNames() const override ;
      80             :   const GridCoordinatesObject& getGridCoordinatesObject() const override ;
      81             :   unsigned getNumberOfDerivatives() override;
      82             :   void setupOnFirstStep( const bool incalc ) override ;
      83             :   void getNumberOfTasks( unsigned& ntasks ) override ;
      84             :   void areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions ) override ;
      85             :   int checkTaskStatus( const unsigned& taskno, int& flag ) const override ;
      86             :   void performTask( const unsigned& current, MultiValue& myvals ) const override ;
      87             :   void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals,
      88             :                           const unsigned& bufstart, std::vector<double>& buffer ) const override ;
      89             :   void updateForceTasksFromValue( const Value* myval, std::vector<unsigned>& force_tasks ) const override ;
      90             :   void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override ;
      91             : };
      92             : 
      93             : PLUMED_REGISTER_ACTION(KDE,"KDE")
      94             : PLUMED_REGISTER_ACTION(KDE,"SPHERICAL_KDE")
      95             : 
      96         274 : void KDE::registerKeywords( Keywords& keys ) {
      97         274 :   ActionWithGrid::registerKeywords( keys );
      98         548 :   keys.addInputKeyword("compulsory","ARG","scalar/vector/matrix","the label for the value that should be used to construct the histogram");
      99         274 :   keys.add("optional","HEIGHTS","this keyword takes the label of an action that calculates a vector of values.  The elements of this vector "
     100             :            "are used as weights for the Gaussians.");
     101         274 :   keys.add("optional","VOLUMES","this keyword take the label of an action that calculates a vector of values.  The elements of this vector "
     102             :            "divided by the volume of the Gaussian are used as weights for the Gaussians");
     103             :   // Keywords for KDE
     104         274 :   keys.add("compulsory","GRID_MIN","auto","the lower bounds for the grid");
     105         274 :   keys.add("compulsory","GRID_MAX","auto","the upper bounds for the grid");
     106         274 :   keys.add("optional","BANDWIDTH","the bandwidths for kernel density esimtation");
     107         274 :   keys.add("compulsory","METRIC","the inverse covariance to use for the kernels that are added to the grid");
     108         274 :   keys.add("compulsory","CUTOFF","6.25","the cutoff at which to stop evaluating the kernel functions is set equal to sqrt(2*x)*bandwidth in each direction where x is this number");
     109         274 :   keys.add("compulsory","KERNEL","GAUSSIAN","the kernel function you are using.  More details on  the kernels available "
     110             :            "in plumed plumed can be found in \\ref kernelfunctions.");
     111         274 :   keys.add("optional","GRID_BIN","the number of bins for the grid");
     112         274 :   keys.addFlag("IGNORE_IF_OUT_OF_RANGE",false,"if a kernel is outside of the range of the grid it is safe to ignore");
     113         274 :   keys.add("optional","GRID_SPACING","the approximate grid spacing (to be used as an alternative or together with GRID_BIN)");
     114             :   // Keywords for spherical KDE
     115         274 :   keys.add("compulsory","CONCENTRATION","the concentration parameter for Von Mises-Fisher distributions (only required for SPHERICAL_KDE)");
     116         548 :   keys.setValueDescription("grid","a function on a grid that was obtained by doing a Kernel Density Estimation using the input arguments");
     117         274 : }
     118             : 
     119         149 : KDE::KDE(const ActionOptions&ao):
     120             :   Action(ao),
     121             :   ActionWithGrid(ao),
     122         149 :   hasheight(false),
     123         149 :   fixed_width(false) {
     124         149 :   std::vector<unsigned> shape( getNumberOfArguments() );
     125         149 :   center.resize( getNumberOfArguments() );
     126         149 :   numberOfKernels=getPntrToArgument(0)->getNumberOfValues();
     127         199 :   for(unsigned i=1; i<shape.size(); ++i) {
     128          50 :     if( numberOfKernels!=getPntrToArgument(i)->getNumberOfValues() ) {
     129           0 :       error("mismatch between numbers of values in input arguments");
     130             :     }
     131             :   }
     132             : 
     133             :   bool weights_are_volumes=true;
     134             :   std::vector<std::string> weight_str;
     135         298 :   parseVector("VOLUMES",weight_str);
     136         149 :   if( weight_str.size()==0 ) {
     137         146 :     parseVector("HEIGHTS",weight_str);
     138          73 :     if( weight_str.size()>0 ) {
     139             :       weights_are_volumes=false;
     140             :     }
     141             :   }
     142         149 :   hasheight=(weight_str.size()==1);
     143         149 :   if( weight_str.size()>1 ) {
     144           0 :     error("only one scalar/vector/matrix should be input to HEIGHTS");
     145             :   }
     146             : 
     147         149 :   if( getName()=="KDE" ) {
     148         278 :     parse("KERNEL",kerneltype);
     149         139 :     if( kerneltype!="DISCRETE" ) {
     150             :       std::string bandwidth;
     151             :       std::vector<std::string> bwidths;
     152         254 :       parseVector("BANDWIDTH",bwidths);
     153         127 :       if( bwidths.size()>0 ) {
     154         127 :         std::string band="VALUES=" + bwidths[0];
     155         284 :         for(unsigned i=0; i<bwidths.size(); ++i) {
     156         157 :           if( i>0 ) {
     157          60 :             band += "," + bwidths[i];
     158             :           }
     159             :         }
     160         127 :         plumed.readInputWords( Tools::getWords(getLabel() + "_sigma: CONSTANT " + band), false );
     161         127 :         plumed.readInputWords( Tools::getWords(getLabel() + "_cov: CUSTOM ARG=" + getLabel() + "_sigma FUNC=x*x PERIODIC=NO"), false );
     162         127 :         plumed.readInputWords( Tools::getWords(getLabel() + "_icov: CUSTOM ARG=" + getLabel() + "_cov FUNC=1/x PERIODIC=NO"), false );
     163         254 :         bandwidth = getLabel() + "_icov";
     164             : 
     165         254 :         if( (kerneltype=="gaussian" || kerneltype=="GAUSSIAN") && weights_are_volumes ) {
     166             :           std::string pstr;
     167         105 :           Tools::convert( sqrt(pow(2*pi,bwidths.size())), pstr );
     168         105 :           plumed.readInputWords( Tools::getWords(getLabel() + "_bwprod: PRODUCT ARG=" + getLabel() + "_cov"), false );
     169         105 :           plumed.readInputWords( Tools::getWords(getLabel() + "_vol: CUSTOM ARG=" + getLabel() + "_bwprod FUNC=(sqrt(x)*" + pstr + ") PERIODIC=NO"), false );
     170         105 :           if( hasheight ) {
     171          76 :             plumed.readInputWords( Tools::getWords(getLabel() + "_height: CUSTOM ARG=" + weight_str[0] + "," + getLabel() + "_vol FUNC=x/y PERIODIC=NO"), false);
     172             :           } else {
     173          29 :             plumed.readInputWords( Tools::getWords(getLabel() + "_height: CUSTOM ARG=" + getLabel() + "_vol FUNC=1/x PERIODIC=NO"), false);
     174             :           }
     175         105 :           hasheight=true;
     176         105 :           weight_str.resize(1);
     177         210 :           weight_str[0] = getLabel() + "_height";
     178             :         }
     179             :       } else {
     180           0 :         parse("METRIC",bandwidth);
     181             :       }
     182         127 :       weight_str.push_back( bandwidth );
     183         127 :     }
     184             :   }
     185         149 :   if( weight_str.size()>0 ) {
     186             :     std::vector<Value*> weight_args;
     187         145 :     ActionWithArguments::interpretArgumentList( weight_str, plumed.getActionSet(), this, weight_args );
     188         145 :     std::vector<Value*> args( getArguments() );
     189         145 :     args.push_back( weight_args[0] );
     190         145 :     if( hasheight && weight_args[0]->getNumberOfValues()>1 && numberOfKernels!=weight_args[0]->getNumberOfValues() ) {
     191           0 :       error("mismatch between numbers of values in input arguments and HEIGHTS");
     192             :     }
     193             : 
     194         145 :     if( weight_str.size()==2 ) {
     195         224 :       log.printf("  quantities used for weights are : %s \n", weight_str[0].c_str() );
     196         112 :       args.push_back( weight_args[1] );
     197         112 :       if( weight_args[1]->getRank()==1 && weight_args[1]->getNumberOfValues()!=shape.size() ) {
     198           0 :         error("size of bandwidth vector is incorrect");
     199             :       }
     200         112 :       if( weight_args[1]->getRank()>2 ) {
     201           0 :         error("bandwidths cannot have rank greater than 2");
     202             :       }
     203         112 :       bwargno=args.size()-1;
     204         112 :       log.printf("  bandwidths are taken from : %s \n", weight_str[1].c_str() );
     205          33 :     } else if( !hasheight ) {
     206          15 :       if( weight_args[0]->getRank()==1 && weight_args[0]->getNumberOfValues()!=shape.size() ) {
     207           0 :         error("size of bandwidth vector is incorrect");
     208             :       }
     209          15 :       if( weight_args[0]->getRank()>2 ) {
     210           0 :         error("bandwidths cannot have rank greater than 2");
     211             :       }
     212          15 :       bwargno=args.size()-1;
     213          15 :       log.printf("  bandwidths are taken from : %s \n", weight_str[0].c_str() );
     214          18 :     } else if ( weight_str.size()==1 ) {
     215          18 :       log.printf("  quantities used for weights are : %s \n", weight_str[0].c_str() );
     216             :     } else {
     217           0 :       error("only one scalar/vector/matrix should be input to HEIGHTS");
     218             :     }
     219         145 :     requestArguments( args );
     220             :   }
     221             : 
     222         149 :   if( getName()=="KDE" ) {
     223             :     bool hasauto=false;
     224         139 :     gmin.resize( shape.size() );
     225         139 :     gmax.resize( shape.size() );
     226         139 :     parseVector("GRID_MIN",gmin);
     227         139 :     parseVector("GRID_MAX",gmax);
     228         308 :     for(unsigned i=0; i<gmin.size(); ++i) {
     229         169 :       if( gmin[i]=="auto" ) {
     230          52 :         log.printf("  for %dth coordinate min and max are set from cell directions \n", (i+1) );
     231             :         hasauto=true;  // We need to do a preparation step to set the grid from the box size
     232          52 :         if( gmax[i]!="auto" ) {
     233           0 :           error("if gmin is set from box vectors gmax must also be set in the same way");
     234             :         }
     235          52 :         if( getPntrToArgument(i)->isPeriodic() ) {
     236           0 :           if( gmin[i]=="auto" ) {
     237           0 :             getPntrToArgument(i)->getDomain( gmin[i], gmax[i] );
     238             :           } else {
     239             :             std::string str_min, str_max;
     240           0 :             getPntrToArgument(i)->getDomain( str_min, str_max );
     241           0 :             if( str_min!=gmin[i] || str_max!=gmax[i] ) {
     242           0 :               error("all periodic arguments should have the same domain");
     243             :             }
     244             :           }
     245          52 :         } else if( getPntrToArgument(i)->getName().find(".")!=std::string::npos ) {
     246          52 :           std::size_t dot = getPntrToArgument(i)->getName().find_first_of(".");
     247          52 :           std::string name = getPntrToArgument(i)->getName().substr(dot+1);
     248          90 :           if( name!="x" && name!="y" && name!="z" ) {
     249           0 :             error("cannot set GRID_MIN and GRID_MAX automatically if input argument is not component of distance");
     250             :           }
     251             :         } else {
     252           0 :           error("cannot set GRID_MIN and GRID_MAX automatically if input argument is not component of distance");
     253             :         }
     254             :       } else {
     255         117 :         log.printf("  for %dth coordinate min is set to %s and max is set to %s \n", (i+1), gmin[i].c_str(), gmax[i].c_str() );
     256             :       }
     257             :     }
     258         139 :     if( hasauto && gmin.size()>3 ) {
     259           0 :       error("can only set GRID_MIN and GRID_MAX automatically if components of distance are used in input");
     260             :     }
     261             : 
     262         139 :     parseVector("GRID_BIN",nbin);
     263         139 :     parseVector("GRID_SPACING",gspacing);
     264         139 :     parse("CUTOFF",dp2cutoff);
     265         260 :     if( kerneltype.find("bin")==std::string::npos && kerneltype!="DISCRETE" ) {
     266             :       std::string errors;
     267         218 :       switchingFunction.set( kerneltype + " R_0=1.0 NOSTRETCH", errors );
     268         109 :       if( errors.length()!=0 ) {
     269           0 :         error("problem reading switching function description " + errors);
     270             :       }
     271             :     }
     272             : 
     273         139 :     if( nbin.size()!=shape.size() && gspacing.size()!=shape.size() ) {
     274           0 :       error("GRID_BIN or GRID_SPACING must be set");
     275             :     }
     276             :     // Create a value
     277         139 :     std::vector<bool> ipbc( shape.size() );
     278         308 :     for(unsigned i=0; i<shape.size(); ++i) {
     279         324 :       if( getPntrToArgument( i )->isPeriodic() || gmin[i]=="auto" ) {
     280             :         ipbc[i]=true;
     281             :       } else {
     282             :         ipbc[i]=false;
     283             :       }
     284             :     }
     285         278 :     gridobject.setup( "flat", ipbc, 0, 0.0 );
     286             :   } else {
     287          10 :     if( shape.size()!=3 ) {
     288           0 :       error("should have three coordinates in input to this action");
     289             :     }
     290             : 
     291          10 :     parse("GRID_BIN",nbins);
     292          10 :     log.printf("  setting number of bins to %d \n", nbins );
     293          10 :     parse("CONCENTRATION",von_misses_concentration);
     294          10 :     fixed_width=true;
     295          10 :     von_misses_norm = von_misses_concentration / ( 4*pi*sinh( von_misses_concentration ) );
     296          10 :     log.printf("  setting concentration parameter to %f \n", von_misses_concentration );
     297             : 
     298             :     // Create a value
     299          10 :     std::vector<bool> ipbc( shape.size(), false );
     300          10 :     double fib_cutoff = std::log( epsilon / von_misses_norm ) / von_misses_concentration;
     301          10 :     gridobject.setup( "fibonacci", ipbc, nbins, fib_cutoff );
     302          10 :     checkRead();
     303             : 
     304             :     // Setup the grid
     305          10 :     shape[0]=nbins;
     306          10 :     shape[1]=shape[2]=1;
     307             :   }
     308         149 :   parseFlag("IGNORE_IF_OUT_OF_RANGE",ignore_out_of_bounds);
     309         149 :   if( ignore_out_of_bounds ) {
     310          65 :     log.printf("  ignoring kernels that are outside of grid \n");
     311             :   }
     312         149 :   addValueWithDerivatives( shape );
     313         149 :   setNotPeriodic();
     314         149 :   getPntrToComponent(0)->setDerivativeIsZeroWhenValueIsZero();
     315             :   // Make sure we store all the arguments
     316         605 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     317         456 :     getPntrToArgument(i)->buildDataStore();
     318             :   }
     319             :   // Check for task reduction
     320         149 :   updateTaskListReductionStatus();
     321         149 :   setupOnFirstStep( false );
     322         298 : }
     323             : 
     324         295 : void KDE::setupOnFirstStep( const bool incalc ) {
     325         295 :   if( getName()=="SPHERICAL_KDE" ) {
     326          20 :     return ;
     327             :   }
     328             : 
     329         606 :   for(unsigned i=0; i<getNumberOfDerivatives(); ++i) {
     330         331 :     if( gmin[i]=="auto" && incalc ) {
     331             :       double lcoord, ucoord;
     332          92 :       PbcAction* bv = plumed.getActionSet().selectWithLabel<PbcAction*>("Box");
     333          46 :       Tensor box( bv->getPbc().getBox() );
     334          46 :       std::size_t dot = getPntrToArgument(i)->getName().find_first_of(".");
     335          46 :       std::string name = getPntrToArgument(i)->getName().substr(dot+1);
     336          46 :       if( name=="x" ) {
     337          24 :         lcoord=-0.5*box(0,0);
     338          24 :         ucoord=0.5*box(0,0);
     339          22 :       } else if( name=="y" ) {
     340          12 :         lcoord=-0.5*box(1,1);
     341          12 :         ucoord=0.5*box(1,1);
     342          10 :       } else if( name=="z" ) {
     343          10 :         lcoord=-0.5*box(2,2);
     344          10 :         ucoord=0.5*box(2,2);
     345             :       } else {
     346           0 :         plumed_error();
     347             :       }
     348             :       // And convert to strings for bin and bmax
     349          46 :       Tools::convert( lcoord, gmin[i] );
     350          46 :       Tools::convert( ucoord, gmax[i] );
     351             :     }
     352         331 :     if( incalc ) {
     353         324 :       grid_diff_value.push_back( Value() );
     354         162 :       if( gridobject.isPeriodic(i) ) {
     355          59 :         grid_diff_value[i].setDomain( gmin[i], gmax[i] );
     356             :       } else {
     357         103 :         grid_diff_value[i].setNotPeriodic();
     358             :       }
     359             :     }
     360             :   }
     361             :   // And setup the grid object
     362         275 :   gridobject.setBounds( gmin, gmax, nbin, gspacing );
     363         275 :   std::vector<unsigned> shape( gridobject.getNbin(true) );
     364         275 :   getPntrToComponent(0)->setShape( shape );
     365             :   bool hasauto=false;
     366         552 :   for(unsigned i=0; i<gmin.size(); ++i) {
     367         592 :     if(gmin[i]=="auto" || gmax[i]=="auto" ) {
     368             :       hasauto=true;
     369             :       break;
     370             :     }
     371             :   }
     372             :   // And setup the neighbors
     373         275 :   if( !hasauto && kerneltype!="DISCRETE" && getPntrToArgument(bwargno)->isConstant() ) {
     374         214 :     fixed_width=true;
     375         214 :     setupNeighborsVector();
     376             :   }
     377             : }
     378             : 
     379         237 : void KDE::setupNeighborsVector() {
     380         237 :   if( kerneltype!="DISCRETE" ) {
     381         214 :     std::vector<double> support(gmin.size(),0);
     382         214 :     nneigh.resize( gmin.size() );
     383         214 :     if( kerneltype.find("bin")!=std::string::npos ) {
     384          18 :       std::size_t dd = kerneltype.find("-bin");
     385          18 :       HistogramBead bead;
     386          18 :       bead.setKernelType( kerneltype.substr(0,dd) );
     387          18 :       Value* bw_arg=getPntrToArgument(bwargno);
     388          18 :       if( bw_arg->getRank()<2 ) {
     389          36 :         for(unsigned i=0; i<support.size(); ++i) {
     390          18 :           bead.set( 0, gridobject.getGridSpacing()[i], 1./sqrt(bw_arg->get(i)) );
     391          18 :           support[i] = bead.getCutoff();
     392          18 :           nneigh[i] = static_cast<unsigned>( ceil( support[i]/gridobject.getGridSpacing()[i] ));
     393             :         }
     394             :       } else {
     395           0 :         plumed_error();
     396             :       }
     397             :     } else {
     398         196 :       Value* bw_arg=getPntrToArgument(bwargno);
     399         196 :       if( bw_arg->getRank()<2 ) {
     400         432 :         for(unsigned i=0; i<support.size(); ++i) {
     401         236 :           support[i] = sqrt(2.0*dp2cutoff)*(1.0/sqrt(bw_arg->get(i)));
     402         236 :           nneigh[i] = static_cast<unsigned>( ceil( support[i] / gridobject.getGridSpacing()[i] ) );
     403             :         }
     404           0 :       } else if( bw_arg->getRank()==2 ) {
     405           0 :         Matrix<double> metric(support.size(),support.size());
     406             :         unsigned k=0;
     407           0 :         for(unsigned i=0; i<support.size(); ++i) {
     408           0 :           for(unsigned j=0; j<support.size(); ++j) {
     409           0 :             metric(i,j)=bw_arg->get(k);
     410           0 :             k++;
     411             :           }
     412             :         }
     413           0 :         Matrix<double> myautovec(support.size(),support.size());
     414           0 :         std::vector<double> myautoval(support.size());
     415           0 :         diagMat(metric,myautoval,myautovec);
     416           0 :         double maxautoval=1/myautoval[0];
     417             :         unsigned ind_maxautoval=0;
     418           0 :         for(unsigned i=1; i<support.size(); i++) {
     419           0 :           double neweig=1/myautoval[i];
     420           0 :           if(neweig>maxautoval) {
     421             :             maxautoval=neweig;
     422             :             ind_maxautoval=i;
     423             :           }
     424             :         }
     425           0 :         for(unsigned i=0; i<support.size(); i++) {
     426           0 :           support[i] = sqrt(2.0*dp2cutoff)*fabs(sqrt(maxautoval)*myautovec(i,ind_maxautoval));
     427           0 :           nneigh[i] = static_cast<unsigned>( ceil( support[i] / gridobject.getGridSpacing()[i] ) );
     428             :         }
     429             :       } else {
     430           0 :         plumed_error();
     431             :       }
     432             :     }
     433         468 :     for(unsigned i=0; i<gridobject.getDimension(); ++i) {
     434             :       double fmax, fmin;
     435         254 :       Tools::convert( gridobject.getMin()[i], fmin );
     436         254 :       Tools::convert( gridobject.getMax()[i], fmax );
     437         254 :       if( gridobject.isPeriodic(i) && 2*support[i]>(fmax-fmin) ) {
     438           0 :         error("bandwidth is too large for periodic grid");
     439             :       }
     440             :     }
     441             :   }
     442         237 : }
     443             : 
     444        1065 : unsigned KDE::getNumberOfDerivatives() {
     445        1065 :   return gridobject.getDimension();
     446             : }
     447             : 
     448         125 : std::vector<std::string> KDE::getGridCoordinateNames() const {
     449         125 :   std::vector<std::string> names( gridobject.getDimension() );
     450         301 :   for(unsigned i=0; i<names.size(); ++i) {
     451             :     names[i] = getPntrToArgument(i)->getName();
     452             :   }
     453         125 :   return names;
     454           0 : }
     455             : 
     456        6222 : const GridCoordinatesObject& KDE::getGridCoordinatesObject() const {
     457        6222 :   return gridobject;
     458             : }
     459             : 
     460         149 : void KDE::areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions ) {
     461         149 :   if( numberOfKernels==1 || (hasheight && getPntrToArgument(gridobject.getDimension())->getRank()>0) ) {
     462         127 :     task_reducing_actions.push_back(this);
     463             :   }
     464         149 : }
     465             : 
     466        2651 : void KDE::getNumberOfTasks( unsigned& ntasks ) {
     467        2651 :   if( !fixed_width ) {
     468          23 :     setupNeighborsVector();
     469             :   }
     470        2651 :   ntasks = numberOfKernels = getPntrToArgument(0)->getNumberOfValues();
     471        2651 :   if( numberOfKernels>1 ) {
     472             :     return;
     473             :   }
     474             : 
     475         132 :   hh = 1.0;
     476         132 :   if( hasheight ) {
     477         120 :     hh = getPntrToArgument(gridobject.getDimension())->get();
     478             :   }
     479         309 :   for(unsigned i=0; i<center.size(); ++i) {
     480         177 :     center[i]=getPntrToArgument(i)->get();
     481             :   }
     482         132 :   if( !ignore_out_of_bounds && !gridobject.inbounds( center ) ) {
     483             :     //if( fabs(height)>epsilon ) warning("bounds are possibly set too small as hills with substantial heights are being ignored");
     484             :     return;
     485             :   }
     486         132 :   if( kerneltype=="DISCRETE" ) {
     487          12 :     num_neigh=1;
     488          12 :     neighbors.resize(1);
     489          24 :     for(unsigned i=0; i<center.size(); ++i) {
     490          12 :       center[i] += 0.5*gridobject.getGridSpacing()[i];
     491             :     }
     492          12 :     neighbors[0]=gridobject.getIndex( center );
     493             :   } else {
     494         120 :     gridobject.getNeighbors( center, nneigh, num_neigh, neighbors );
     495             :   }
     496         132 :   ntasks = getPntrToComponent(0)->getNumberOfValues();
     497         132 :   return;
     498             : }
     499             : 
     500     1500435 : int KDE::checkTaskStatus( const unsigned& taskno, int& flag ) const {
     501     1500435 :   if( numberOfKernels>1 ) {
     502     1400374 :     if( hasheight && getPntrToArgument(gridobject.getDimension())->getRank()>0
     503     2823898 :         && fabs(getPntrToArgument(gridobject.getDimension())->get(taskno))<epsilon ) {
     504     1238166 :       return 0;
     505             :     }
     506      185408 :     return 1;
     507             :   }
     508    19394629 :   for(unsigned i=0; i<num_neigh; ++i) {
     509    19339045 :     if( taskno==neighbors[i] ) {
     510             :       return 1;
     511             :     }
     512             :   }
     513             :   return 0;
     514             : }
     515             : 
     516      309567 : void KDE::performTask( const unsigned& current, MultiValue& myvals ) const {
     517      309567 :   if( numberOfKernels==1 ) {
     518             :     double newval;
     519       21277 :     std::vector<double> args( gridobject.getDimension() ), der( gridobject.getDimension() );
     520       21277 :     unsigned valout = getConstPntrToComponent(0)->getPositionInStream();
     521       21277 :     gridobject.getGridPointCoordinates( current, args );
     522       21277 :     if( getName()=="KDE" ) {
     523       21277 :       if( kerneltype=="DISCRETE" ) {
     524             :         newval = 1.0;
     525       21265 :       } else if( kerneltype.find("bin")!=std::string::npos ) {
     526           0 :         double val=hh;
     527           0 :         std::size_t dd = kerneltype.find("-bin");
     528           0 :         HistogramBead bead;
     529           0 :         bead.setKernelType( kerneltype.substr(0,dd) );
     530           0 :         Value* bw_arg=getPntrToArgument(bwargno);
     531           0 :         for(unsigned j=0; j<args.size(); ++j) {
     532           0 :           if( gridobject.isPeriodic(j) ) {
     533             :             double lcoord,  ucoord;
     534           0 :             Tools::convert( gmin[j], lcoord );
     535           0 :             Tools::convert( gmax[j], ucoord );
     536           0 :             bead.isPeriodic( lcoord, ucoord );
     537             :           } else {
     538             :             bead.isNotPeriodic();
     539             :           }
     540           0 :           if( bw_arg->getRank()<2 ) {
     541           0 :             bead.set( args[j], args[j]+gridobject.getGridSpacing()[j], 1/sqrt(bw_arg->get(j)) );
     542           0 :           } else if( bw_arg->getRank()==2 ) {
     543           0 :             plumed_error();
     544             :           }
     545           0 :           double contr = bead.calculateWithCutoff( args[j], der[j] );
     546           0 :           val = val*contr;
     547           0 :           der[j] = der[j] / contr;
     548             :         }
     549           0 :         for(unsigned j=0; j<args.size(); ++j) {
     550           0 :           der[j] *= val;
     551             :         }
     552             :         newval=val;
     553             :       } else {
     554       21265 :         newval = evaluateKernel( args, center, hh, der );
     555             :       }
     556             :     } else {
     557             :       double dot=0;
     558           0 :       for(unsigned i=0; i<der.size(); ++i) {
     559           0 :         dot += args[i]*center[i];
     560             :       }
     561           0 :       newval = hh*von_misses_norm*exp( von_misses_concentration*dot );
     562           0 :       for(unsigned i=0; i<der.size(); ++i) {
     563           0 :         der[i] = von_misses_concentration*newval*args[i];
     564             :       }
     565             :     }
     566       21277 :     myvals.setValue( valout, newval );
     567       78047 :     for(unsigned i=0; i<der.size(); ++i) {
     568       56770 :       myvals.addDerivative( valout, i, der[i] );
     569       56770 :       myvals.updateIndex( valout, i );
     570             :     }
     571             :   }
     572      309567 : }
     573             : 
     574      288290 : void KDE::retrieveArgumentsAndHeight( const MultiValue& myvals, std::vector<double>& args, double& height ) const {
     575      288290 :   height=1.0;
     576      712540 :   for(unsigned i=0; i<args.size(); ++i) {
     577      424250 :     args[i]=getPntrToArgument(i)->get( myvals.getTaskIndex() );
     578             :   }
     579      288290 :   if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) {
     580       15932 :     height = getPntrToArgument( args.size() )->get();
     581      272358 :   } else if( hasheight ) {
     582      162158 :     height = getPntrToArgument( args.size() )->get( myvals.getTaskIndex() );
     583             :   }
     584      288290 : }
     585             : 
     586    38119671 : double KDE::evaluateKernel( const std::vector<double>& gpoint, const std::vector<double>& args, const double& height, std::vector<double>& der ) const {
     587    38119671 :   double r2=0, hval = height;
     588    38119671 :   Value* bw_arg=getPntrToArgument(bwargno);
     589    38119671 :   if( bw_arg->getRank()<2 ) {
     590   151428182 :     for(unsigned j=0; j<der.size(); ++j) {
     591   113308511 :       double tmp = -grid_diff_value[j].difference( gpoint[j], args[j] );
     592   113308511 :       der[j] = tmp*bw_arg->get(j);
     593   113308511 :       r2 += tmp*der[j];
     594             :     }
     595           0 :   } else if( bw_arg->getRank()==2 ) {
     596           0 :     for(unsigned j=0; j<der.size(); ++j) {
     597           0 :       der[j]=0;
     598             :       double dp_j, dp_k;
     599           0 :       dp_j = -grid_diff_value[j].difference( gpoint[j], args[j] );
     600           0 :       for(unsigned k=0; k<der.size(); ++k ) {
     601           0 :         if(j==k) {
     602             :           dp_k = dp_j;
     603             :         } else {
     604           0 :           dp_k = -grid_diff_value[k].difference( gpoint[k], args[k] );
     605             :         }
     606           0 :         der[j] += bw_arg->get(j*der.size()+k)*dp_k;
     607           0 :         r2 += dp_j*dp_k*bw_arg->get(j*der.size()+k);
     608             :       }
     609             :     }
     610             :   } else {
     611           0 :     plumed_error();
     612             :   }
     613    38119671 :   double dval, val=hval*switchingFunction.calculateSqr( r2, dval );
     614    38119671 :   dval *= hval;
     615   151428182 :   for(unsigned j=0; j<der.size(); ++j) {
     616   113308511 :     der[j] *= dval;
     617             :   }
     618    38119671 :   return val;
     619             : }
     620             : 
     621      133400 : void KDE::setupHistogramBeads( std::vector<HistogramBead>& bead ) const {
     622      133400 :   std::size_t dd = kerneltype.find("-bin");
     623      133400 :   std::string ktype=kerneltype.substr(0,dd);
     624      266800 :   for(unsigned j=0; j<bead.size(); ++j) {
     625      133400 :     bead[j].setKernelType( ktype );
     626      133400 :     if( gridobject.isPeriodic(j) ) {
     627             :       double lcoord,  ucoord;
     628      133400 :       Tools::convert( gmin[j], lcoord );
     629      133400 :       Tools::convert( gmax[j], ucoord );
     630      133400 :       bead[j].isPeriodic( lcoord, ucoord );
     631             :     } else {
     632             :       bead[j].isNotPeriodic();
     633             :     }
     634             :   }
     635      133400 : }
     636             : 
     637      533600 : double KDE::evaluateBeadValue( std::vector<HistogramBead>& bead, const std::vector<double>& gpoint, const std::vector<double>& args,
     638             :                                const double& height, std::vector<double>& der ) const {
     639      533600 :   double val=height;
     640      533600 :   std::vector<double> contr( args.size() );
     641      533600 :   Value* bw_arg=getPntrToArgument(bwargno);
     642      533600 :   if( bw_arg->getRank()<2 ) {
     643     1067200 :     for(unsigned j=0; j<args.size(); ++j) {
     644      533600 :       bead[j].set( gpoint[j], gpoint[j]+gridobject.getGridSpacing()[j], 1/sqrt(bw_arg->get(j)) );
     645      533600 :       contr[j] = bead[j].calculateWithCutoff( args[j], der[j] );
     646      533600 :       val = val*contr[j];
     647             :     }
     648             :   } else {
     649           0 :     plumed_error();
     650             :   }
     651     1067200 :   for(unsigned j=0; j<args.size(); ++j) {
     652      533600 :     if( fabs(contr[j])>epsilon ) {
     653      445921 :       der[j] *= val / contr[j];
     654             :     }
     655             :   }
     656      533600 :   return val;
     657             : }
     658             : 
     659      251539 : void KDE::gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals,
     660             :                              const unsigned& bufstart, std::vector<double>& buffer ) const {
     661             :   plumed_dbg_assert( valindex==0 );
     662      251539 :   if( numberOfKernels==1 ) {
     663       21277 :     unsigned istart = bufstart + (1+gridobject.getDimension())*code;
     664       21277 :     unsigned valout = getConstPntrToComponent(0)->getPositionInStream();
     665       21277 :     buffer[istart] += myvals.get( valout );
     666       78047 :     for(unsigned i=0; i<gridobject.getDimension(); ++i) {
     667       56770 :       buffer[istart+1+i] += myvals.getDerivative( valout, i );
     668             :     }
     669       46591 :     return;
     670             :   }
     671      230262 :   std::vector<double> args( gridobject.getDimension() );
     672             :   double height;
     673      230262 :   retrieveArgumentsAndHeight( myvals, args, height );
     674      230262 :   if( !ignore_out_of_bounds && !gridobject.inbounds( args ) ) {
     675             :     // if( fabs(height)>epsilon ) warning("bounds are possibly set too small as hills with substantial heights are being ignored");
     676             :     return ;
     677             :   }
     678             :   // Add the kernel to the grid
     679             :   unsigned num_neigh;
     680             :   std::vector<unsigned> neighbors;
     681      204948 :   if( kerneltype!="DISCRETE" ) {
     682      173454 :     gridobject.getNeighbors( args, nneigh, num_neigh, neighbors );
     683             :   }
     684      204948 :   std::vector<double> der( args.size() ), gpoint( args.size() );
     685      204948 :   if( fabs(height)>epsilon ) {
     686      204948 :     if( getName()=="KDE" ) {
     687      197340 :       if( kerneltype=="DISCRETE" ) {
     688       31494 :         std::vector<double> newargs( args.size() );
     689       62988 :         for(unsigned i=0; i<args.size(); ++i) {
     690       31494 :           newargs[i] = args[i] + 0.5*gridobject.getGridSpacing()[i];
     691             :         }
     692       31494 :         plumed_assert( bufstart + gridobject.getIndex( newargs )*(1+args.size())<buffer.size() );
     693       31494 :         buffer[ bufstart + gridobject.getIndex( newargs )*(1+args.size()) ] += height;
     694      165846 :       } else if( kerneltype.find("bin")!=std::string::npos ) {
     695      104400 :         std::vector<HistogramBead> bead( args.size() );
     696      104400 :         setupHistogramBeads( bead );
     697      522000 :         for(unsigned i=0; i<num_neigh; ++i) {
     698      417600 :           gridobject.getGridPointCoordinates( neighbors[i], gpoint );
     699      417600 :           double val = evaluateBeadValue( bead, gpoint, args, height, der );
     700      417600 :           buffer[ bufstart + neighbors[i]*(1+der.size()) ] += val;
     701      835200 :           for(unsigned j=0; j<der.size(); ++j) {
     702      417600 :             buffer[ bufstart + neighbors[i]*(1+der.size()) + 1 + j ] += val*der[j];
     703             :           }
     704             :         }
     705             :       } else {
     706    37979124 :         for(unsigned i=0; i<num_neigh; ++i) {
     707    37917678 :           gridobject.getGridPointCoordinates( neighbors[i], gpoint );
     708    37917678 :           buffer[ bufstart + neighbors[i]*(1+der.size()) ] += evaluateKernel( gpoint, args, height, der );
     709   150985777 :           for(unsigned j=0; j<der.size(); ++j) {
     710   113068099 :             buffer[ bufstart + neighbors[i]*(1+der.size()) + 1 + j ] += der[j];
     711             :           }
     712             :         }
     713             :       }
     714             :     } else {
     715      828405 :       for(unsigned i=0; i<num_neigh; ++i) {
     716      820797 :         gridobject.getGridPointCoordinates( neighbors[i], gpoint );
     717             :         double dot=0;
     718     3283188 :         for(unsigned j=0; j<gpoint.size(); ++j) {
     719     2462391 :           dot += args[j]*gpoint[j];
     720             :         }
     721      820797 :         double newval = height*von_misses_norm*exp( von_misses_concentration*dot );
     722      820797 :         buffer[ bufstart + neighbors[i]*(1+gpoint.size()) ] += newval;
     723     3283188 :         for(unsigned j=0; j<gpoint.size(); ++j) {
     724     2462391 :           buffer[ bufstart + neighbors[i]*(1+gpoint.size()) + 1 + j ] += von_misses_concentration*newval*gpoint[j];
     725             :         }
     726             :       }
     727             :     }
     728             :   }
     729             : }
     730             : 
     731         610 : void KDE::updateForceTasksFromValue( const Value* myval, std::vector<unsigned>& force_tasks ) const {
     732         610 :   if( !myval->forcesWereAdded() ) {
     733           0 :     return ;
     734             :   }
     735         610 :   if( numberOfKernels==1 ) {
     736           0 :     plumed_error();
     737             :   }
     738             : 
     739         610 :   int flag=1;
     740      290135 :   for(unsigned i=0; i<numberOfKernels; ++i) {
     741      289525 :     if( checkTaskStatus( i, flag ) ) {
     742       58028 :       force_tasks.push_back(i);
     743             :     }
     744             :   }
     745             : }
     746             : 
     747       58028 : void KDE::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const {
     748       58028 :   if( numberOfKernels==1 ) {
     749           0 :     plumed_error();
     750             :     return;
     751             :   }
     752             :   double height;
     753       58028 :   std::vector<double> args( gridobject.getDimension() );
     754       58028 :   retrieveArgumentsAndHeight( myvals, args, height );
     755             :   unsigned num_neigh;
     756             :   std::vector<unsigned> neighbors;
     757       58028 :   gridobject.getNeighbors( args, nneigh, num_neigh, neighbors );
     758       58028 :   std::vector<double> der( args.size() ), gpoint( args.size() );
     759             :   unsigned hforce_start = 0;
     760      129700 :   for(unsigned j=0; j<der.size(); ++j) {
     761       71672 :     hforce_start += getPntrToArgument(j)->getNumberOfStoredValues();
     762             :   }
     763       58028 :   if( fabs(height)>epsilon ) {
     764       58028 :     if( getName()=="KDE" ) {
     765       51226 :       if( kerneltype.find("bin")!=std::string::npos ) {
     766       29000 :         std::vector<HistogramBead> bead( args.size() );
     767       29000 :         setupHistogramBeads( bead );
     768      145000 :         for(unsigned i=0; i<num_neigh; ++i) {
     769      116000 :           gridobject.getGridPointCoordinates( neighbors[i], gpoint );
     770      116000 :           double val = evaluateBeadValue( bead, gpoint, args, height, der );
     771      116000 :           double fforce = getConstPntrToComponent(0)->getForce( neighbors[i] );
     772      116000 :           if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) {
     773           0 :             forces[ hforce_start ] += val*fforce / height;
     774      116000 :           } else if( hasheight ) {
     775       23200 :             forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += val*fforce / height;
     776             :           }
     777             :           unsigned n=0;
     778      232000 :           for(unsigned j=0; j<der.size(); ++j) {
     779      116000 :             forces[n + getPntrToArgument(j)->getIndexInStore(itask)] += der[j]*fforce;
     780      116000 :             n += getPntrToArgument(j)->getNumberOfStoredValues();
     781             :           }
     782             :         }
     783             :       } else {
     784      202954 :         for(unsigned i=0; i<num_neigh; ++i) {
     785      180728 :           gridobject.getGridPointCoordinates( neighbors[i], gpoint );
     786      180728 :           double val = evaluateKernel( gpoint, args, height, der ), fforce = getConstPntrToComponent(0)->getForce( neighbors[i] );
     787      180728 :           if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) {
     788         818 :             forces[ hforce_start ] += val*fforce / height;
     789      179910 :           } else if( hasheight ) {
     790      179910 :             forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += val*fforce / height;
     791             :           }
     792             :           unsigned n=0;
     793      364382 :           for(unsigned j=0; j<der.size(); ++j) {
     794      183654 :             forces[n + getPntrToArgument(j)->getIndexInStore(itask)] += -der[j]*fforce;
     795      183654 :             n += getPntrToArgument(j)->getNumberOfStoredValues();
     796             :           }
     797             :         }
     798             :       }
     799             :     } else {
     800      687002 :       for(unsigned i=0; i<num_neigh; ++i) {
     801      680200 :         gridobject.getGridPointCoordinates( neighbors[i], gpoint );
     802             :         double dot=0;
     803     2720800 :         for(unsigned j=0; j<gpoint.size(); ++j) {
     804     2040600 :           dot += args[j]*gpoint[j];
     805             :         }
     806      680200 :         double fforce = myval->getForce( neighbors[i] );
     807      680200 :         double newval = height*von_misses_norm*exp( von_misses_concentration*dot );
     808      680200 :         if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) {
     809           0 :           forces[ hforce_start ] += newval*fforce / height;
     810      680200 :         } else if( hasheight ) {
     811      680200 :           forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += newval*fforce / height;
     812             :         }
     813             :         unsigned n=0;
     814     2720800 :         for(unsigned j=0; j<gpoint.size(); ++j) {
     815     2040600 :           forces[n + getPntrToArgument(j)->getIndexInStore(itask)] += von_misses_concentration*newval*gpoint[j]*fforce;
     816     2040600 :           n += getPntrToArgument(j)->getNumberOfStoredValues();
     817             :         }
     818             :       }
     819             :     }
     820             :   }
     821             : }
     822             : 
     823             : }
     824             : }

Generated by: LCOV version 1.16