Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2011-2020 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #ifndef __PLUMED_gridtools_FunctionOfGrid_h 23 : #define __PLUMED_gridtools_FunctionOfGrid_h 24 : 25 : #include "ActionWithGrid.h" 26 : #include "function/Custom.h" 27 : #include "tools/Matrix.h" 28 : 29 : namespace PLMD { 30 : namespace gridtools { 31 : 32 : template <class T> 33 : class FunctionOfGrid : public ActionWithGrid { 34 : private: 35 : /// The function that is being computed 36 : T myfunc; 37 : public: 38 : static void registerKeywords(Keywords&); 39 : explicit FunctionOfGrid(const ActionOptions&); 40 : /// This does setup required on first step 41 : void setupOnFirstStep( const bool incalc ) override ; 42 : /// Get the number of derivatives for this action 43 : unsigned getNumberOfDerivatives() override ; 44 : /// Get the label to write in the graph 45 0 : std::string writeInGraph() const override { return myfunc.getGraphInfo( getName() ); } 46 : /// Get the underlying names 47 : std::vector<std::string> getGridCoordinateNames() const override ; 48 : /// Get the underlying grid coordinates object 49 : const GridCoordinatesObject& getGridCoordinatesObject() const override ; 50 : /// Calculate the function 51 : void performTask( const unsigned& current, MultiValue& myvals ) const override ; 52 : /// 53 : void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, 54 : const unsigned& bufstart, std::vector<double>& buffer ) const override ; 55 : /// Add the forces 56 : void apply() override; 57 : }; 58 : 59 : template <class T> 60 1047 : void FunctionOfGrid<T>::registerKeywords(Keywords& keys ) { 61 2094 : ActionWithGrid::registerKeywords(keys); keys.use("ARG"); std::string name = keys.getDisplayName(); 62 1047 : std::size_t und=name.find("_GRID"); keys.setDisplayName( name.substr(0,und) ); 63 2094 : keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO"); 64 1047 : T tfunc; tfunc.registerKeywords( keys ); if( typeid(tfunc)==typeid(function::Custom()) ) keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log"); 65 2094 : if( keys.getDisplayName()=="INTEGRATE") { 66 294 : keys.setValueDescription("the numerical integral of the input function over its whole domain"); 67 1800 : } else if( keys.outputComponentExists(".#!value") ) { 68 1800 : keys.setValueDescription("the grid obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input grid"); 69 : } 70 1918 : } 71 : 72 : template <class T> 73 523 : FunctionOfGrid<T>::FunctionOfGrid(const ActionOptions&ao): 74 : Action(ao), 75 523 : ActionWithGrid(ao) 76 : { 77 523 : if( getNumberOfArguments()==0 ) error("found no arguments"); 78 : // This will require a fix 79 523 : if( getPntrToArgument(0)->getRank()==0 || !getPntrToArgument(0)->hasDerivatives() ) error("first input to this action must be a grid"); 80 : // Get the shape of the input grid 81 523 : std::vector<unsigned> shape( getPntrToArgument(0)->getShape() ); 82 937 : for(unsigned i=1; i<getNumberOfArguments(); ++i ) { 83 414 : if( getPntrToArgument(i)->getRank()==0 ) continue; 84 292 : std::vector<unsigned> s( getPntrToArgument(i)->getShape() ); 85 292 : if( s.size()!=shape.size() ) error("mismatch between dimensionalities of input grids"); 86 : } 87 : // Read the input and do some checks 88 523 : myfunc.read( this ); 89 : // Check we are not calculating an integral 90 523 : if( myfunc.zeroRank() ) { shape.resize(0); } 91 : // Check that derivatives are available 92 90 : if( !myfunc.derivativesImplemented() ) error("derivatives have not been implemended for " + getName() ); 93 : // Get the names of the components 94 523 : std::vector<std::string> components( keywords.getOutputComponents() ); 95 : // Create the values to hold the output 96 1046 : if( components.size()!=1 || components[0]!=".#!value" ) error("functions of grid should only output one grid"); 97 523 : addValueWithDerivatives( shape ); 98 : // Set the periodicities of the output components 99 523 : myfunc.setPeriodicityForOutputs( this ); 100 : // Check if we can turn off the derivatives when they are zero 101 433 : if( myfunc.getDerivativeZeroIfValueIsZero() ) { 102 492 : for(int i=0; i<getNumberOfComponents(); ++i) getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero(); 103 : } 104 523 : setupOnFirstStep( false ); 105 1046 : } 106 : 107 : template <class T> 108 1045 : void FunctionOfGrid<T>::setupOnFirstStep( const bool incalc ) { 109 : double volume = 1.0; 110 1045 : const GridCoordinatesObject& mygrid = getGridCoordinatesObject(); 111 1045 : unsigned npoints = getPntrToArgument(0)->getNumberOfValues(); 112 2090 : if( mygrid.getGridType()=="flat" ) { 113 999 : std::vector<unsigned> shape( getGridCoordinatesObject().getNbin(true) ); 114 1794 : for(unsigned i=1; i<getNumberOfArguments(); ++i ) { 115 795 : if( getPntrToArgument(i)->getRank()==0 ) continue; 116 573 : std::vector<unsigned> s( getPntrToArgument(i)->getShape() ); 117 1160 : for(unsigned j=0; j<shape.size(); ++j) { 118 587 : if( shape[j]!=s[j] ) error("mismatch between sizes of input grids"); 119 : } 120 : } 121 1998 : for(int i=0; i<getNumberOfComponents(); ++i) { 122 999 : if( getPntrToComponent(i)->getRank()>0 ) getPntrToComponent(i)->setShape(shape); 123 : } 124 999 : std::vector<double> vv( getGridCoordinatesObject().getGridSpacing() ); 125 1081 : volume=vv[0]; for(unsigned i=1; i<vv.size(); ++i) volume *=vv[i]; 126 14 : } else volume=4*pi / static_cast<double>( npoints ); 127 : // This resizes the scalars 128 2090 : for(int i=0; i<getNumberOfComponents(); ++i) { 129 1045 : if( getPntrToComponent(i)->getRank()==0 ) getPntrToComponent(i)->resizeDerivatives( npoints ); 130 : } 131 1045 : if( getName()=="SUM_GRID" ) volume = 1.0; 132 : // This sets the prefactor to the volume which converts integrals to sums 133 1045 : myfunc.setup( this ); myfunc.setPrefactor( this, volume ); 134 1045 : } 135 : 136 : template <class T> 137 15460 : const GridCoordinatesObject& FunctionOfGrid<T>::getGridCoordinatesObject() const { 138 15460 : ActionWithGrid* ag=ActionWithGrid::getInputActionWithGrid( getPntrToArgument(0)->getPntrToAction() ); 139 15460 : plumed_assert( ag ); return ag->getGridCoordinatesObject(); 140 : } 141 : 142 : template <class T> 143 198 : std::vector<std::string> FunctionOfGrid<T>::getGridCoordinateNames() const { 144 198 : ActionWithGrid* ag=ActionWithGrid::getInputActionWithGrid( getPntrToArgument(0)->getPntrToAction() ); 145 198 : plumed_assert( ag ); return ag->getGridCoordinateNames(); 146 : } 147 : 148 : template <class T> 149 1814 : unsigned FunctionOfGrid<T>::getNumberOfDerivatives() { 150 514 : if( myfunc.zeroRank() ) return getPntrToArgument(0)->getNumberOfValues(); 151 1300 : unsigned nder = getGridCoordinatesObject().getDimension(); 152 1300 : return getGridCoordinatesObject().getDimension() + getNumberOfArguments() - myfunc.getArgStart(); 153 : } 154 : 155 : template <class T> 156 974814 : void FunctionOfGrid<T>::performTask( const unsigned& current, MultiValue& myvals ) const { 157 974814 : unsigned argstart=myfunc.getArgStart(); std::vector<double> args( getNumberOfArguments() - argstart ); 158 2629182 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) { 159 1654368 : if( getPntrToArgument(i)->getRank()==0 ) args[i-argstart]=getPntrToArgument(i)->get(); 160 1110136 : else args[i-argstart] = getPntrToArgument(i)->get(current); 161 : } 162 : // Calculate the function and its derivatives 163 974814 : std::vector<double> vals(1); Matrix<double> derivatives( 1, getNumberOfArguments()-argstart ); 164 974814 : myfunc.calc( this, args, vals, derivatives ); unsigned np = myvals.getTaskIndex(); 165 : // And set the values and derivatives 166 974814 : unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream(); 167 974814 : myvals.addValue( ostrn, vals[0] ); 168 23665 : if( !myfunc.zeroRank() ) { 169 : // Add the derivatives for a grid 170 2581852 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) { 171 : // We store all the derivatives of all the input values - i.e. the grid points these are used in apply 172 1630703 : myvals.addDerivative( ostrn, getConstPntrToComponent(0)->getRank()+j-argstart, derivatives(0,j-argstart) ); 173 : // And now we calculate the derivatives of the value that is stored on the grid correctly so that we can interpolate functions 174 1630703 : if( getPntrToArgument(j)->getRank()!=0 ) { 175 3413677 : for(unsigned k=0; k<getPntrToArgument(j)->getRank(); ++k) myvals.addDerivative( ostrn, k, derivatives(0,j-argstart)*getPntrToArgument(j)->getGridDerivative( np, k ) ); 176 : } 177 : } 178 951149 : unsigned nderivatives = getConstPntrToComponent(0)->getNumberOfGridDerivatives(); 179 4575808 : for(unsigned j=0; j<nderivatives; ++j) myvals.updateIndex( ostrn, j ); 180 23665 : } else if( !doNotCalculateDerivatives() ) { 181 : // These are the derivatives of the integral 182 8161 : myvals.addDerivative( ostrn, current, derivatives(0,0) ); myvals.updateIndex( ostrn, current ); 183 : } 184 974814 : } 185 : 186 : template <class T> 187 951149 : void FunctionOfGrid<T>::gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, 188 : const unsigned& bufstart, std::vector<double>& buffer ) const { 189 951149 : if( getConstPntrToComponent(0)->getRank()>0 && getConstPntrToComponent(0)->hasDerivatives() ) { 190 : plumed_dbg_assert( getNumberOfComponents()==1 && valindex==0 ); 191 951149 : unsigned nder = getConstPntrToComponent(0)->getNumberOfGridDerivatives(); 192 951149 : unsigned ostr = getConstPntrToComponent(0)->getPositionInStream(); 193 951149 : unsigned kp = bufstart + code*(1+nder); buffer[kp] += myvals.get( ostr ); 194 4575808 : for(unsigned i=0; i<nder; ++i) buffer[kp + 1 + i] += myvals.getDerivative( ostr, i ); 195 0 : } else ActionWithVector::gatherStoredValue( valindex, code, myvals, bufstart, buffer ); 196 951149 : } 197 : 198 : template <class T> 199 6067 : void FunctionOfGrid<T>::apply() { 200 6561 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return; 201 : 202 : // This applies forces for the integral 203 494 : if( myfunc.zeroRank() ) { ActionWithVector::apply(); return; } 204 : 205 : // Work out how to deal with arguments 206 : unsigned nscalars=0, argstart=myfunc.getArgStart(); 207 1870 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) { 208 1245 : if( getPntrToArgument(i)->getRank()==0 ) { nscalars++; } 209 : } 210 : 211 625 : std::vector<double> totv(nscalars,0); Value* outval=getPntrToComponent(0); 212 7575 : for(unsigned i=0; i<outval->getNumberOfValues(); ++i) { 213 : nscalars=0; 214 19845 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) { 215 12895 : double fforce = outval->getForce(i); 216 12895 : if( getPntrToArgument(j)->getRank()==0 ) { 217 3205 : totv[nscalars] += fforce*outval->getGridDerivative( i, outval->getRank()+j ); nscalars++; 218 : } else { 219 9690 : double vval = outval->getGridDerivative( i, outval->getRank()+j ); 220 9690 : getPntrToArgument(j)->addForce( i, fforce*vval ); 221 : } 222 : } 223 : } 224 : nscalars=0; 225 1870 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) { 226 1245 : if( getPntrToArgument(i)->getRank()==0 ) { getPntrToArgument(i)->addForce( 0, totv[nscalars] ); nscalars++; } 227 : } 228 : 229 : } 230 : 231 : } 232 : } 233 : #endif