LCOV - code coverage report
Current view: top level - gridtools - FunctionOfGrid.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 101 103 98.1 %
Date: 2024-10-18 14:00:25 Functions: 16 20 80.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2020 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_gridtools_FunctionOfGrid_h
      23             : #define __PLUMED_gridtools_FunctionOfGrid_h
      24             : 
      25             : #include "ActionWithGrid.h"
      26             : #include "function/Custom.h"
      27             : #include "tools/Matrix.h"
      28             : 
      29             : namespace PLMD {
      30             : namespace gridtools {
      31             : 
      32             : template <class T>
      33             : class FunctionOfGrid : public ActionWithGrid {
      34             : private:
      35             : /// The function that is being computed
      36             :   T myfunc;
      37             : public:
      38             :   static void registerKeywords(Keywords&);
      39             :   explicit FunctionOfGrid(const ActionOptions&);
      40             : /// This does setup required on first step
      41             :   void setupOnFirstStep( const bool incalc ) override ;
      42             : /// Get the number of derivatives for this action
      43             :   unsigned getNumberOfDerivatives() override ;
      44             : /// Get the label to write in the graph
      45           0 :   std::string writeInGraph() const override { return myfunc.getGraphInfo( getName() ); }
      46             : /// Get the underlying names
      47             :   std::vector<std::string> getGridCoordinateNames() const override ;
      48             : /// Get the underlying grid coordinates object
      49             :   const GridCoordinatesObject& getGridCoordinatesObject() const override ;
      50             : /// Calculate the function
      51             :   void performTask( const unsigned& current, MultiValue& myvals ) const override ;
      52             : ///
      53             :   void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals,
      54             :                           const unsigned& bufstart, std::vector<double>& buffer ) const override ;
      55             : /// Add the forces
      56             :   void apply() override;
      57             : };
      58             : 
      59             : template <class T>
      60        1047 : void FunctionOfGrid<T>::registerKeywords(Keywords& keys ) {
      61        2094 :   ActionWithGrid::registerKeywords(keys); keys.use("ARG"); std::string name = keys.getDisplayName();
      62        1047 :   std::size_t und=name.find("_GRID"); keys.setDisplayName( name.substr(0,und) );
      63        2094 :   keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function.  If the output is not periodic you must state this using PERIODIC=NO");
      64        1047 :   T tfunc; tfunc.registerKeywords( keys ); if( typeid(tfunc)==typeid(function::Custom()) ) keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
      65        2094 :   if( keys.getDisplayName()=="INTEGRATE") {
      66         294 :     keys.setValueDescription("the numerical integral of the input function over its whole domain");
      67        1800 :   } else if( keys.outputComponentExists(".#!value") ) {
      68        1800 :     keys.setValueDescription("the grid obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input grid");
      69             :   }
      70        1918 : }
      71             : 
      72             : template <class T>
      73         523 : FunctionOfGrid<T>::FunctionOfGrid(const ActionOptions&ao):
      74             :   Action(ao),
      75         523 :   ActionWithGrid(ao)
      76             : {
      77         523 :   if( getNumberOfArguments()==0 ) error("found no arguments");
      78             :   // This will require a fix
      79         523 :   if( getPntrToArgument(0)->getRank()==0 || !getPntrToArgument(0)->hasDerivatives() ) error("first input to this action must be a grid");
      80             :   // Get the shape of the input grid
      81         523 :   std::vector<unsigned> shape( getPntrToArgument(0)->getShape() );
      82         937 :   for(unsigned i=1; i<getNumberOfArguments(); ++i ) {
      83         414 :     if( getPntrToArgument(i)->getRank()==0 ) continue;
      84         292 :     std::vector<unsigned> s( getPntrToArgument(i)->getShape() );
      85         292 :     if( s.size()!=shape.size() ) error("mismatch between dimensionalities of input grids");
      86             :   }
      87             :   // Read the input and do some checks
      88         523 :   myfunc.read( this );
      89             :   // Check we are not calculating an integral
      90         523 :   if( myfunc.zeroRank() ) { shape.resize(0); }
      91             :   // Check that derivatives are available
      92          90 :   if( !myfunc.derivativesImplemented() ) error("derivatives have not been implemended for " + getName() );
      93             :   // Get the names of the components
      94         523 :   std::vector<std::string> components( keywords.getOutputComponents() );
      95             :   // Create the values to hold the output
      96        1046 :   if( components.size()!=1 || components[0]!=".#!value" ) error("functions of grid should only output one grid");
      97         523 :   addValueWithDerivatives( shape );
      98             :   // Set the periodicities of the output components
      99         523 :   myfunc.setPeriodicityForOutputs( this );
     100             :   // Check if we can turn off the derivatives when they are zero
     101         433 :   if( myfunc.getDerivativeZeroIfValueIsZero() )  {
     102         492 :     for(int i=0; i<getNumberOfComponents(); ++i) getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
     103             :   }
     104         523 :   setupOnFirstStep( false );
     105        1046 : }
     106             : 
     107             : template <class T>
     108        1045 : void FunctionOfGrid<T>::setupOnFirstStep( const bool incalc ) {
     109             :   double volume = 1.0;
     110        1045 :   const GridCoordinatesObject& mygrid = getGridCoordinatesObject();
     111        1045 :   unsigned npoints = getPntrToArgument(0)->getNumberOfValues();
     112        2090 :   if( mygrid.getGridType()=="flat" ) {
     113         999 :     std::vector<unsigned> shape( getGridCoordinatesObject().getNbin(true) );
     114        1794 :     for(unsigned i=1; i<getNumberOfArguments(); ++i ) {
     115         795 :       if( getPntrToArgument(i)->getRank()==0 ) continue;
     116         573 :       std::vector<unsigned> s( getPntrToArgument(i)->getShape() );
     117        1160 :       for(unsigned j=0; j<shape.size(); ++j) {
     118         587 :         if( shape[j]!=s[j] ) error("mismatch between sizes of input grids");
     119             :       }
     120             :     }
     121        1998 :     for(int i=0; i<getNumberOfComponents(); ++i) {
     122         999 :       if( getPntrToComponent(i)->getRank()>0 ) getPntrToComponent(i)->setShape(shape);
     123             :     }
     124         999 :     std::vector<double> vv( getGridCoordinatesObject().getGridSpacing() );
     125        1081 :     volume=vv[0]; for(unsigned i=1; i<vv.size(); ++i) volume *=vv[i];
     126          14 :   } else volume=4*pi / static_cast<double>( npoints );
     127             :   // This resizes the scalars
     128        2090 :   for(int i=0; i<getNumberOfComponents(); ++i) {
     129        1045 :     if( getPntrToComponent(i)->getRank()==0 ) getPntrToComponent(i)->resizeDerivatives( npoints );
     130             :   }
     131        1045 :   if( getName()=="SUM_GRID" ) volume = 1.0;
     132             :   // This sets the prefactor to the volume which converts integrals to sums
     133        1045 :   myfunc.setup( this ); myfunc.setPrefactor( this, volume );
     134        1045 : }
     135             : 
     136             : template <class T>
     137       15460 : const GridCoordinatesObject& FunctionOfGrid<T>::getGridCoordinatesObject() const {
     138       15460 :   ActionWithGrid* ag=ActionWithGrid::getInputActionWithGrid( getPntrToArgument(0)->getPntrToAction() );
     139       15460 :   plumed_assert( ag ); return ag->getGridCoordinatesObject();
     140             : }
     141             : 
     142             : template <class T>
     143         198 : std::vector<std::string> FunctionOfGrid<T>::getGridCoordinateNames() const {
     144         198 :   ActionWithGrid* ag=ActionWithGrid::getInputActionWithGrid( getPntrToArgument(0)->getPntrToAction() );
     145         198 :   plumed_assert( ag ); return ag->getGridCoordinateNames();
     146             : }
     147             : 
     148             : template <class T>
     149        1814 : unsigned FunctionOfGrid<T>::getNumberOfDerivatives() {
     150         514 :   if( myfunc.zeroRank() ) return getPntrToArgument(0)->getNumberOfValues();
     151        1300 :   unsigned nder = getGridCoordinatesObject().getDimension();
     152        1300 :   return getGridCoordinatesObject().getDimension() + getNumberOfArguments() - myfunc.getArgStart();
     153             : }
     154             : 
     155             : template <class T>
     156      974814 : void FunctionOfGrid<T>::performTask( const unsigned& current, MultiValue& myvals ) const {
     157      974814 :   unsigned argstart=myfunc.getArgStart(); std::vector<double> args( getNumberOfArguments() - argstart );
     158     2629182 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     159     1654368 :     if( getPntrToArgument(i)->getRank()==0 ) args[i-argstart]=getPntrToArgument(i)->get();
     160     1110136 :     else args[i-argstart] = getPntrToArgument(i)->get(current);
     161             :   }
     162             :   // Calculate the function and its derivatives
     163      974814 :   std::vector<double> vals(1); Matrix<double> derivatives( 1, getNumberOfArguments()-argstart );
     164      974814 :   myfunc.calc( this, args, vals, derivatives ); unsigned np = myvals.getTaskIndex();
     165             :   // And set the values and derivatives
     166      974814 :   unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream();
     167      974814 :   myvals.addValue( ostrn, vals[0] );
     168       23665 :   if( !myfunc.zeroRank() ) {
     169             :     // Add the derivatives for a grid
     170     2581852 :     for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     171             :       // We store all the derivatives of all the input values - i.e. the grid points these are used in apply
     172     1630703 :       myvals.addDerivative( ostrn, getConstPntrToComponent(0)->getRank()+j-argstart, derivatives(0,j-argstart) );
     173             :       // And now we calculate the derivatives of the value that is stored on the grid correctly so that we can interpolate functions
     174     1630703 :       if( getPntrToArgument(j)->getRank()!=0 ) {
     175     3413677 :         for(unsigned k=0; k<getPntrToArgument(j)->getRank(); ++k) myvals.addDerivative( ostrn, k, derivatives(0,j-argstart)*getPntrToArgument(j)->getGridDerivative( np, k ) );
     176             :       }
     177             :     }
     178      951149 :     unsigned nderivatives = getConstPntrToComponent(0)->getNumberOfGridDerivatives();
     179     4575808 :     for(unsigned j=0; j<nderivatives; ++j) myvals.updateIndex( ostrn, j );
     180       23665 :   } else if( !doNotCalculateDerivatives() ) {
     181             :     // These are the derivatives of the integral
     182        8161 :     myvals.addDerivative( ostrn, current, derivatives(0,0) ); myvals.updateIndex( ostrn, current );
     183             :   }
     184      974814 : }
     185             : 
     186             : template <class T>
     187      951149 : void FunctionOfGrid<T>::gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals,
     188             :     const unsigned& bufstart, std::vector<double>& buffer ) const {
     189      951149 :   if( getConstPntrToComponent(0)->getRank()>0 && getConstPntrToComponent(0)->hasDerivatives() ) {
     190             :     plumed_dbg_assert( getNumberOfComponents()==1 && valindex==0 );
     191      951149 :     unsigned nder = getConstPntrToComponent(0)->getNumberOfGridDerivatives();
     192      951149 :     unsigned ostr = getConstPntrToComponent(0)->getPositionInStream();
     193      951149 :     unsigned kp = bufstart + code*(1+nder); buffer[kp] += myvals.get( ostr );
     194     4575808 :     for(unsigned i=0; i<nder; ++i) buffer[kp + 1 + i] += myvals.getDerivative( ostr, i );
     195           0 :   } else ActionWithVector::gatherStoredValue( valindex, code, myvals, bufstart, buffer );
     196      951149 : }
     197             : 
     198             : template <class T>
     199        6067 : void FunctionOfGrid<T>::apply() {
     200        6561 :   if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return;
     201             : 
     202             :   // This applies forces for the integral
     203         494 :   if( myfunc.zeroRank() ) { ActionWithVector::apply(); return; }
     204             : 
     205             :   // Work out how to deal with arguments
     206             :   unsigned nscalars=0, argstart=myfunc.getArgStart();
     207        1870 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     208        1245 :     if( getPntrToArgument(i)->getRank()==0 ) { nscalars++; }
     209             :   }
     210             : 
     211         625 :   std::vector<double> totv(nscalars,0); Value* outval=getPntrToComponent(0);
     212        7575 :   for(unsigned i=0; i<outval->getNumberOfValues(); ++i) {
     213             :     nscalars=0;
     214       19845 :     for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     215       12895 :       double fforce = outval->getForce(i);
     216       12895 :       if( getPntrToArgument(j)->getRank()==0 ) {
     217        3205 :         totv[nscalars] += fforce*outval->getGridDerivative( i, outval->getRank()+j ); nscalars++;
     218             :       } else {
     219        9690 :         double vval = outval->getGridDerivative( i, outval->getRank()+j  );
     220        9690 :         getPntrToArgument(j)->addForce( i, fforce*vval );
     221             :       }
     222             :     }
     223             :   }
     224             :   nscalars=0;
     225        1870 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     226        1245 :     if( getPntrToArgument(i)->getRank()==0 ) { getPntrToArgument(i)->addForce( 0, totv[nscalars] ); nscalars++; }
     227             :   }
     228             : 
     229             : }
     230             : 
     231             : }
     232             : }
     233             : #endif

Generated by: LCOV version 1.16