LCOV - code coverage report
Current view: top level - gridtools - EvaluateGridFunction.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 90 102 88.2 %
Date: 2024-10-18 14:00:25 Functions: 6 6 100.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2015-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "EvaluateGridFunction.h"
      23             : #include "ActionWithGrid.h"
      24             : #include "core/PlumedMain.h"
      25             : #include "core/ActionSet.h"
      26             : 
      27             : namespace PLMD {
      28             : namespace gridtools {
      29             : 
      30         184 : void EvaluateGridFunction::registerKeywords( Keywords& keys ) {
      31         368 :   keys.add("compulsory","INTERPOLATION_TYPE","spline","the method to use for interpolation.  Can be spline, linear, ceiling or floor.");
      32         368 :   keys.addFlag("ZERO_OUTSIDE_GRID_RANGE",false,"if we are asked to evaluate the function for a number that is outside the range of the grid set it to zero");
      33         184 :   keys.setValueDescription("interpolation of the input grid to get the value of the function at the input arguments");
      34         184 : }
      35             : 
      36         242 : std::vector<bool> EvaluateGridFunction::getPbc() const {
      37         242 :   std::vector<bool> ipbc( gridobject.getDimension() );
      38         484 :   for(unsigned i=0; i<ipbc.size(); ++i) ipbc[i] = gridobject.isPeriodic(i);
      39         242 :   return ipbc;
      40             : }
      41             : 
      42          88 : void EvaluateGridFunction::read( ActionWithArguments* action ) {
      43          88 :   if( action->getPntrToArgument(0)->getRank()==0 || !action->getPntrToArgument(0)->hasDerivatives() ) action->error("should have one grid as input to this action");
      44             :   // Get the input grid
      45          88 :   ActionWithGrid* ag = ActionWithGrid::getInputActionWithGrid( (action->getPntrToArgument(0))->getPntrToAction() );
      46         176 :   if( ag->getGridCoordinatesObject().getGridType()!="flat" ) action->error("cannot interpolate on fibonacci sphere");
      47          88 :   std::vector<bool> ipbc( ag->getGridCoordinatesObject().getDimension() );
      48         184 :   for(unsigned i=0; i<ipbc.size(); ++i) ipbc[i] = ag->getGridCoordinatesObject().isPeriodic(i);
      49         176 :   gridobject.setup( "flat", ipbc, 0, 0.0 );
      50             :   // Now use this information to create a gridobject
      51             :   std::vector<std::string> argn;
      52          88 :   parseFlag(action,"ZERO_OUTSIDE_GRID_RANGE",set_zero_outside_range);
      53          88 :   if( set_zero_outside_range ) action->log.printf("  function is zero outside grid range \n");
      54             :   // Get the type of interpolation that we are doing
      55         176 :   std::string itype; parse(action,"INTERPOLATION_TYPE",itype);
      56          88 :   if( itype=="spline" ) {
      57           8 :     interpolation_type=spline;
      58           8 :     spline_interpolator=Tools::make_unique<Interpolator>( action->getPntrToArgument(0), gridobject );
      59          80 :   } else if( itype=="linear" ) {
      60          65 :     interpolation_type=linear;
      61          15 :   } else if( itype=="floor" ) {
      62           0 :     interpolation_type=floor;
      63          15 :   } else if( itype=="ceiling" ) {
      64          15 :     interpolation_type=ceiling;
      65           0 :   } else action->error("type " + itype + " of interpolation is not defined");
      66          88 :   action->log.printf("  generating off grid points using %s interpolation \n", itype.c_str() );
      67         176 : }
      68             : 
      69         170 : void EvaluateGridFunction::setup( ActionWithValue* action ) {
      70         170 :   FunctionTemplateBase::setup( action );
      71         170 :   ActionWithArguments* aarg = dynamic_cast<ActionWithArguments*>( action );
      72         170 :   ActionWithGrid* ag = ActionWithGrid::getInputActionWithGrid( (aarg->getPntrToArgument(0))->getPntrToAction() );
      73         170 :   const GridCoordinatesObject & ingrid = ag->getGridCoordinatesObject(); std::vector<double> sp( ingrid.getGridSpacing() );
      74         170 :   gridobject.setBounds( ingrid.getMin(), ingrid.getMax(), ingrid.getNbin(false), sp );
      75         170 : }
      76             : 
      77       39145 : void EvaluateGridFunction::calc( const ActionWithArguments* action, const std::vector<double>& args, std::vector<double>& vals, Matrix<double>& derivatives ) const {
      78       39145 :   if( set_zero_outside_range && !gridobject.inbounds( args ) ) { vals[0]=0.0; return; }
      79             :   unsigned dimension = gridobject.getDimension(); plumed_dbg_assert( args.size()==dimension && vals.size()==1 );
      80       39145 :   if( interpolation_type==spline ) {
      81       23729 :     std::vector<double> der( dimension );
      82       23729 :     vals[0] =  spline_interpolator->splineInterpolation( args, der );
      83       74416 :     for(unsigned j=0; j<dimension; ++j) derivatives(0,j) = der[j];
      84       15416 :   } else if( interpolation_type==linear ) {
      85        6600 :     Value* values=action->getPntrToArgument(0); std::vector<double> xfloor(dimension);
      86        6600 :     std::vector<unsigned> indices(dimension), nindices(dimension), ind(dimension);
      87        6600 :     gridobject.getIndices( args, indices ); unsigned nn=gridobject.getIndex(args);
      88        6600 :     gridobject.getGridPointCoordinates( nn, nindices, xfloor );
      89        6600 :     double y1 = values->get(nn); vals[0] = y1;
      90       13200 :     for(unsigned i=0; i<args.size(); ++i) {
      91        6600 :       int x0=1; if(nindices[i]==indices[i]) x0=0;
      92        6600 :       double ddx=gridobject.getGridSpacing()[i];
      93        6600 :       double X = fabs((args[i]-xfloor[i])/ddx-(double)x0);
      94       13200 :       for(unsigned j=0; j<args.size(); ++j) ind[j] = indices[j];
      95        6600 :       if( gridobject.isPeriodic(i) && (ind[i]+1)==gridobject.getNbin(false)[i] ) ind[i]=0;
      96        6600 :       else ind[i] = ind[i] + 1;
      97        6600 :       vals[0] += ( values->get( gridobject.getIndex(ind) ) - y1 )*X;
      98        6600 :       derivatives(0,i) = ( values->get( gridobject.getIndex(ind) ) - y1 ) / ddx;
      99             :     }
     100        8816 :   } else if( interpolation_type==floor ) {
     101           0 :     Value* values=action->getPntrToArgument(0); std::vector<unsigned> indices(dimension);
     102           0 :     gridobject.getIndices( args, indices ); unsigned nn = gridobject.getIndex(indices);
     103             :     plumed_dbg_assert( nn<values->getNumberOfValues() );
     104           0 :     vals[0] = values->get( nn );
     105           0 :     for(unsigned j=0; j<dimension; ++j) derivatives(0,j) = values->getGridDerivative( nn, j );
     106        8816 :   } else if( interpolation_type==ceiling ) {
     107        8816 :     Value* values=action->getPntrToArgument(0); std::vector<unsigned> indices(dimension);
     108        8816 :     gridobject.getIndices( args, indices );
     109       17632 :     for(unsigned i=0; i<indices.size(); ++i) {
     110       17632 :       if( gridobject.isPeriodic(i) && (indices[i]+1)==gridobject.getNbin(false)[i] ) indices[i]=0;
     111        6612 :       else indices[i] = indices[i] + 1;
     112             :     }
     113        8816 :     unsigned nn = gridobject.getIndex(indices); vals[0] = values->get( nn );
     114       17632 :     for(unsigned j=0; j<dimension; ++j) derivatives(0,j) = values->getGridDerivative( nn, j );
     115           0 :   } else plumed_error();
     116             : }
     117             : 
     118        1956 : void EvaluateGridFunction::applyForce( const ActionWithArguments* action, const std::vector<double>& args, const double& force, std::vector<double>& forcesToApply ) const {
     119             :   unsigned dimension = gridobject.getDimension();
     120        1956 :   if( interpolation_type==spline ) {
     121           0 :     action->error("can't apply forces on values interpolated using splines");
     122        1956 :   } else if( interpolation_type==linear ) {
     123         100 :     Value* values=action->getPntrToArgument(0); std::vector<double> xfloor(dimension);
     124         100 :     std::vector<unsigned> indices(dimension), nindices(dimension), ind(dimension);
     125         100 :     gridobject.getIndices( args, indices ); unsigned nn=gridobject.getIndex(args);
     126         100 :     gridobject.getGridPointCoordinates( nn, nindices, xfloor );
     127         200 :     for(unsigned i=0; i<args.size(); ++i) {
     128         100 :       int x0=1; if(nindices[i]==indices[i]) x0=0;
     129         100 :       double ddx=gridobject.getGridSpacing()[i];
     130         100 :       double X = fabs((args[i]-xfloor[i])/ddx-(double)x0);
     131         200 :       for(unsigned j=0; j<args.size(); ++j) ind[j] = indices[j];
     132         100 :       if( gridobject.isPeriodic(i) && (ind[i]+1)==gridobject.getNbin(false)[i] ) ind[i]=0;
     133         100 :       else ind[i] = ind[i] + 1;
     134         100 :       forcesToApply[nn] += force*(1-X); forcesToApply[gridobject.getIndex(ind)] += X*force;
     135             :     }
     136        1856 :   } else if( interpolation_type==floor ) {
     137           0 :     Value* values=action->getPntrToArgument(0); std::vector<unsigned> indices(dimension);
     138           0 :     gridobject.getIndices( args, indices ); unsigned nn = gridobject.getIndex(indices);
     139           0 :     forcesToApply[nn] += force;
     140        1856 :   } else if( interpolation_type==ceiling ) {
     141        1856 :     Value* values=action->getPntrToArgument(0); std::vector<unsigned> indices(dimension);
     142        1856 :     gridobject.getIndices( args, indices );
     143        3712 :     for(unsigned i=0; i<indices.size(); ++i) {
     144        3712 :       if( gridobject.isPeriodic(i) && (indices[i]+1)==gridobject.getNbin(false)[i] ) indices[i]=0;
     145        1392 :       else indices[i] = indices[i] + 1;
     146             :     }
     147        1856 :     unsigned nn = gridobject.getIndex(indices); forcesToApply[nn] += force;
     148           0 :   } else plumed_error();
     149        1956 : }
     150             : 
     151             : }
     152             : }

Generated by: LCOV version 1.16