LCOV - code coverage report
Current view: top level - function - FunctionOfVector.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 160 168 95.2 %
Date: 2024-10-18 14:00:25 Functions: 147 171 86.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2020 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_function_FunctionOfVector_h
      23             : #define __PLUMED_function_FunctionOfVector_h
      24             : 
      25             : #include "core/ActionWithVector.h"
      26             : //#include "core/CollectFrames.h"
      27             : #include "core/ActionSetup.h"
      28             : #include "tools/Matrix.h"
      29             : #include "Sum.h"
      30             : 
      31             : namespace PLMD {
      32             : namespace function {
      33             : 
      34             : template <class T>
      35             : class FunctionOfVector : public ActionWithVector {
      36             : private:
      37             : /// Do the calculation at the end of the run
      38             :   bool doAtEnd;
      39             : /// Is this the first time we are doing the calc
      40             :   bool firststep;
      41             : /// The function that is being computed
      42             :   T myfunc;
      43             : /// The number of derivatives for this action
      44             :   unsigned nderivatives;
      45             : /// A vector that tells us if we have stored the input value
      46             :   std::vector<bool> stored_arguments;
      47             : public:
      48             :   static void registerKeywords(Keywords&);
      49             : /// This method is used to run the calculation with functions such as highest/lowest and sort.
      50             : /// It is static so we can reuse the functionality in FunctionOfMatrix
      51             :   static void runSingleTaskCalculation( const Value* arg, ActionWithValue* action, T& f );
      52             :   explicit FunctionOfVector(const ActionOptions&);
      53        4660 :   ~FunctionOfVector() {}
      54             :   std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
      55             : /// Get the size of the task list at the end of the run
      56             :   unsigned getNumberOfFinalTasks();
      57             : /// Check if derivatives are available
      58             :   void turnOnDerivatives() override;
      59             : /// Get the number of derivatives for this action
      60             :   unsigned getNumberOfDerivatives() override ;
      61             : /// Resize vectors that are the wrong size
      62             :   void prepare() override ;
      63             : /// Check if all he actions are required
      64             :   void areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions );
      65             : /// Get the label to write in the graph
      66          20 :   std::string writeInGraph() const override { return myfunc.getGraphInfo( getName() ); }
      67             : /// This builds the task list for the action
      68             :   void calculate() override;
      69             : /// This ensures that we create some bookeeping stuff during the first step
      70             :   void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override ;
      71             : /// Calculate the function
      72             :   void performTask( const unsigned& current, MultiValue& myvals ) const override ;
      73             : };
      74             : 
      75             : template <class T>
      76        4574 : void FunctionOfVector<T>::registerKeywords(Keywords& keys ) {
      77        4574 :   Action::registerKeywords(keys); ActionWithValue::registerKeywords(keys); ActionWithArguments::registerKeywords(keys); keys.use("ARG");
      78        4574 :   std::string name = keys.getDisplayName(); std::size_t und=name.find("_VECTOR"); keys.setDisplayName( name.substr(0,und) );
      79        9148 :   keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function.  If the output is not periodic you must state this using PERIODIC=NO");
      80        9148 :   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
      81        4574 :   T tfunc; tfunc.registerKeywords( keys );
      82        9148 :   if( keys.getDisplayName()=="SUM" ) {
      83        2408 :     keys.setValueDescription("the sum of all the elements in the input vector");
      84        6740 :   } else if( keys.getDisplayName()=="MEAN" ) {
      85         714 :     keys.setValueDescription("the mean of all the elements in the input vector");
      86        6026 :   } else if( keys.getDisplayName()=="HIGHEST" ) {
      87          84 :     keys.setValueDescription("the largest element of the input vector");
      88        5942 :   } else if( keys.getDisplayName()=="LOWEST" ) {
      89         118 :     keys.setValueDescription("the smallest element in the input vector");
      90        5824 :   } else if( keys.getDisplayName()=="SORT" ) {
      91          24 :     keys.setValueDescription("a vector that has been sorted into ascending order");
      92        5800 :   } else if( keys.outputComponentExists(".#!value") ) {
      93        5780 :     keys.setValueDescription("the vector obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input vectors");
      94             :   }
      95        7080 : }
      96             : 
      97             : template <class T>
      98        2264 : FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
      99             :   Action(ao),
     100             :   ActionWithVector(ao),
     101        2264 :   doAtEnd(true),
     102        2264 :   firststep(true),
     103        2264 :   nderivatives(0)
     104             : {
     105             :   // Get the shape of the output
     106        2264 :   std::vector<unsigned> shape(1); shape[0]=getNumberOfFinalTasks();
     107             :   // Read the input and do some checks
     108        2264 :   myfunc.read( this );
     109             :   // Create the task list
     110        2121 :   if( myfunc.doWithTasks() ) {
     111        2224 :     doAtEnd=false; if( shape[0]>0 ) done_in_chain=true;
     112          40 :   } else { plumed_assert( getNumberOfArguments()==1 ); done_in_chain=false; getPntrToArgument(0)->buildDataStore(); }
     113             :   // Get the names of the components
     114        2264 :   std::vector<std::string> components( keywords.getOutputComponents() );
     115             :   // Create the values to hold the output
     116          56 :   std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
     117        4528 :   for(unsigned i=0; i<components.size(); ++i) {
     118           8 :     if( str_ind.size()>0 ) {
     119          16 :       std::string strcompn = components[i]; if( components[i]==".#!value" ) strcompn = "";
     120          34 :       for(unsigned j=0; j<str_ind.size(); ++j) {
     121          52 :         if( myfunc.zeroRank() ) addComponentWithDerivatives( strcompn + str_ind[j] );
     122           0 :         else addComponent( strcompn + str_ind[j], shape );
     123             :       }
     124        2256 :     } else if( components[i].find_first_of("_")!=std::string::npos ) {
     125           0 :       if( getNumberOfArguments()==1 && myfunc.zeroRank() ) addValueWithDerivatives();
     126           0 :       else if( getNumberOfArguments()==1 ) addValue( shape );
     127             :       else {
     128             :         unsigned argstart=myfunc.getArgStart();
     129           0 :         for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     130           0 :           if( myfunc.zeroRank() ) addComponentWithDerivatives( getPntrToArgument(i)->getName() + components[i] );
     131           0 :           else addComponent( getPntrToArgument(i)->getName() + components[i], shape );
     132             :         }
     133             :       }
     134        1635 :     } else if( components[i]==".#!value" && myfunc.zeroRank() ) addValueWithDerivatives();
     135        1446 :     else if( components[i]==".#!value" ) addValue(shape);
     136           0 :     else if( myfunc.zeroRank() ) addComponentWithDerivatives( components[i] );
     137           0 :     else addComponent( components[i], shape );
     138             :   }
     139             :   // Check if we can turn off the derivatives when they are zero
     140        1016 :   if( myfunc.getDerivativeZeroIfValueIsZero() )  {
     141         612 :     for(int i=0; i<getNumberOfComponents(); ++i) getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
     142             :   }
     143             :   // Check if this is a timeseries
     144             :   unsigned argstart=myfunc.getArgStart();
     145             :   // for(unsigned i=argstart; i<getNumberOfArguments();++i) {
     146             :   //   if( getPntrToArgument(i)->isTimeSeries() ) {
     147             :   //       for(unsigned i=0; i<getNumberOfComponents(); ++i) getPntrToOutput(i)->makeHistoryDependent();
     148             :   //       break;
     149             :   //   }
     150             :   // }
     151             :   // Set the periodicities of the output components
     152        2264 :   myfunc.setPeriodicityForOutputs( this );
     153             :   // Check if we can put the function in a chain
     154        6143 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     155             :     // CollectFrames* ab=dynamic_cast<CollectFrames*>( getPntrToArgument(i)->getPntrToAction() );
     156             :     // if( ab && ab->hasClear() ) { doNotChain=true; getPntrToArgument(i)->buildDataStore( getLabel() ); }
     157             :     // No chains if we are using a sum or a mean
     158        3879 :     if( getPntrToArgument(i)->getRank()==0 ) {
     159         246 :       FunctionOfVector<Sum>* as = dynamic_cast<FunctionOfVector<Sum>*>( getPntrToArgument(i)->getPntrToAction() );
     160         246 :       if(as) done_in_chain=false;
     161             :     } else {
     162        3633 :       ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
     163        3633 :       if( !av ) done_in_chain=false;
     164             :     }
     165             :   }
     166             :   // Don't need to do the calculation in a chain if the input is constant
     167             :   bool allconstant=true;
     168        2595 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     169        2275 :     if( !getPntrToArgument(i)->isConstant() ) { allconstant=false; break; }
     170             :   }
     171        2264 :   if( allconstant ) done_in_chain=false;
     172        2264 :   nderivatives = buildArgumentStore(myfunc.getArgStart());
     173        4528 : }
     174             : 
     175             : template <class T>
     176           5 : std::string FunctionOfVector<T>::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
     177           5 :   if( getName().find("SORT")==std::string::npos ) return ActionWithValue::getOutputComponentDescription( cname, keys );
     178           8 :   if( getNumberOfArguments()==1 ) return "the " + cname + "th largest element of the vector " + getPntrToArgument(0)->getName();
     179           4 :   return "the " + cname + "th largest element in the input vectors";
     180             : }
     181             : 
     182             : template <class T>
     183        6140 : void FunctionOfVector<T>::turnOnDerivatives() {
     184        6140 :   if( !getPntrToComponent(0)->isConstant() && !myfunc.derivativesImplemented() ) error("derivatives have not been implemended for " + getName() );
     185        6140 :   ActionWithValue::turnOnDerivatives(); myfunc.setup(this );
     186        6140 : }
     187             : 
     188             : template <class T>
     189       53011 : unsigned FunctionOfVector<T>::getNumberOfDerivatives() {
     190       53011 :   return nderivatives;
     191             : }
     192             : 
     193             : template <class T>
     194      185527 : void FunctionOfVector<T>::prepare() {
     195      185527 :   unsigned argstart = myfunc.getArgStart(); std::vector<unsigned> shape(1);
     196      232505 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     197      232505 :     if( getPntrToArgument(i)->getRank()==1 ) {
     198      185527 :       shape[0] = getPntrToArgument(i)->getShape()[0]; break;
     199             :     }
     200             :   }
     201      371470 :   for(unsigned i=0; i<getNumberOfComponents(); ++i) {
     202      185943 :     Value* myval = getPntrToComponent(i);
     203      185943 :     if( myval->getRank()==1 && myval->getShape()[0]!=shape[0] ) { myval->setShape(shape); }
     204             :   }
     205      185527 :   ActionWithVector::prepare();
     206      185527 : }
     207             : 
     208             : template <class T>
     209      312538 : void FunctionOfVector<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) {
     210      312538 :   if( firststep ) {
     211        2168 :     stored_arguments.resize( getNumberOfArguments() );
     212        2168 :     std::string control = getFirstActionInChain()->getLabel();
     213        5936 :     for(unsigned i=0; i<stored_arguments.size(); ++i) {
     214        3768 :       if( getPntrToArgument(i)->isConstant() ) stored_arguments[i]=false;
     215        3249 :       else stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
     216             :     }
     217        2168 :     firststep=false;
     218             :   }
     219      312538 :   ActionWithVector::setupStreamedComponents( headstr, nquants, nmat, maxcol, nbookeeping );
     220      312538 : }
     221             : 
     222             : template <class T>
     223     6634172 : void FunctionOfVector<T>::performTask( const unsigned& current, MultiValue& myvals ) const {
     224     6634172 :   unsigned argstart=myfunc.getArgStart(); std::vector<double> args( getNumberOfArguments()-argstart);
     225     6634172 :   if( actionInChain() ) {
     226     6832197 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     227     3923742 :       if(  getPntrToArgument(i)->getRank()==0 ) args[i-argstart] = getPntrToArgument(i)->get();
     228     3905954 :       else if( !getPntrToArgument(i)->valueHasBeenSet() ) args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
     229      154930 :       else args[i-argstart] = getPntrToArgument(i)->get( myvals.getTaskIndex() );
     230             :     }
     231             :   } else {
     232    11032808 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     233     7307091 :       if( getPntrToArgument(i)->getRank()==1 ) args[i-argstart]=getPntrToArgument(i)->get(current);
     234     3242545 :       else args[i-argstart] = getPntrToArgument(i)->get();
     235             :     }
     236             :   }
     237             :   // Calculate the function and its derivatives
     238     6634172 :   std::vector<double> vals( getNumberOfComponents() ); Matrix<double> derivatives( getNumberOfComponents(), args.size() );
     239     6634172 :   myfunc.calc( this, args, vals, derivatives );
     240             :   // And set the values
     241    13268344 :   for(unsigned i=0; i<vals.size(); ++i) myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
     242             :   // Return if we are not computing derivatives
     243     6634172 :   if( doNotCalculateDerivatives() ) return;
     244             :   // And now compute the derivatives
     245             :   // Second condition here is only not true if actionInChain()==True if
     246             :   // input arguments the only non-constant objects in input are scalars.
     247             :   // In that case we can use the non chain version to calculate the derivatives
     248             :   // with respect to the scalar.
     249     5721550 :   if( actionInChain() ) {
     250     5183456 :     for(unsigned j=0; j<args.size(); ++j) {
     251        8375 :       unsigned istrn = getPntrToArgument(argstart+j)->getPositionInStream();
     252     2886517 :       if( stored_arguments[argstart+j] ) {
     253          70 :         unsigned task_index = myvals.getTaskIndex(); if( getPntrToArgument(argstart+j)->getRank()==0 ) task_index=0;
     254          70 :         myvals.addDerivative( istrn, task_index, 1.0 ); myvals.updateIndex( istrn, task_index );
     255             :       }
     256     2886517 :       unsigned arg_deriv_s = arg_deriv_starts[argstart+j];
     257   114509648 :       for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     258   111623131 :         unsigned kind=myvals.getActiveIndex(istrn,k);
     259   223246262 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     260   111623131 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     261   111623131 :           myvals.addDerivative( ostrn, arg_deriv_s + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
     262             :         }
     263             :       }
     264             :       // Ensure we only store one lot of derivative indices
     265     2886517 :       bool found=false; ActionWithValue* aav=getPntrToArgument(argstart+j)->getPntrToAction();
     266     2920439 :       for(unsigned k=0; k<j; ++k) {
     267      589642 :         if( arg_deriv_starts[argstart+k]==arg_deriv_s ) {
     268      555720 :           if( getPntrToArgument(argstart+k)->getPntrToAction()!=aav ) {
     269      386484 :             ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(argstart+j)->getPntrToAction() );
     270      386484 :             if( av ) {
     271      772968 :               for(int i=0; i<getNumberOfComponents(); ++i) av->updateAdditionalIndices( getConstPntrToComponent(i)->getPositionInStream(), myvals );
     272             :             }
     273             :           }
     274             :           found=true; break;
     275             :         }
     276             :       }
     277      555720 :       if( found ) continue;
     278    85327769 :       for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     279    82996972 :         unsigned kind=myvals.getActiveIndex(istrn,k);
     280   165993944 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     281    82996972 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     282    82996972 :           myvals.updateIndex( ostrn, arg_deriv_s + kind );
     283             :         }
     284             :       }
     285             :     }
     286             :   } else {
     287             :     unsigned base=0;
     288    10141540 :     for(unsigned j=0; j<args.size(); ++j) {
     289     6716929 :       if( getPntrToArgument(argstart+j)->getRank()==1 ) {
     290     7241248 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     291     3620624 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     292     3620624 :           myvals.addDerivative( ostrn, base+current, derivatives(i,j) );
     293     3620624 :           myvals.updateIndex( ostrn, base+current );
     294             :         }
     295             :       } else {
     296     6192610 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     297     3096305 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     298     3096305 :           myvals.addDerivative( ostrn, base, derivatives(i,j) );
     299     3096305 :           myvals.updateIndex( ostrn, base );
     300             :         }
     301             :       }
     302     6716929 :       base += getPntrToArgument(argstart+j)->getNumberOfValues();
     303             :     }
     304             :   }
     305             : }
     306             : 
     307             : template <class T>
     308        2264 : unsigned FunctionOfVector<T>::getNumberOfFinalTasks() {
     309             :   unsigned nelements=0, argstart=myfunc.getArgStart();
     310        6143 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     311        3879 :     plumed_assert( getPntrToArgument(i)->getRank()<2 );
     312        3879 :     if( getPntrToArgument(i)->getRank()==1 ) {
     313        3633 :       if( nelements>0 ) {
     314             :         // if( getPntrToArgument(i)->isTimeSeries() && getPntrToArgument(i)->getShape()[0]<nelements ) nelements=getPntrToArgument(i)->isTimeSeries();
     315             :         // else
     316        1369 :         if(getPntrToArgument(i)->getShape()[0]!=nelements ) error("all vectors input should have the same length");
     317        2264 :       } else if( nelements==0 ) nelements=getPntrToArgument(i)->getShape()[0];
     318        3633 :       plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
     319             :     }
     320             :   }
     321             :   // The prefactor for average and sum is set here so the number of input scalars is guaranteed to be correct
     322         777 :   myfunc.setPrefactor( this, 1.0 );
     323        2264 :   return nelements;
     324             : }
     325             : 
     326             : template <class T>
     327       12124 : void FunctionOfVector<T>::areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions ) {
     328       12124 :   if( task_reducing_actions.size()==0 ) return;
     329        2221 :   if( !myfunc.allComponentsRequired( getArguments(), task_reducing_actions ) ) task_reducing_actions.push_back(this);
     330             : }
     331             : 
     332             : template <class T>
     333        5541 : void FunctionOfVector<T>::runSingleTaskCalculation( const Value* arg, ActionWithValue* action, T& f ) {
     334             :   // This is used if we are doing sorting actions on a single vector
     335        5541 :   unsigned nv = arg->getNumberOfValues(); std::vector<double> args( nv );
     336     8198467 :   for(unsigned i=0; i<nv; ++i) args[i] = arg->get(i);
     337        5541 :   std::vector<double> vals( action->getNumberOfComponents() ); Matrix<double> derivatives( action->getNumberOfComponents(), nv );
     338        5541 :   ActionWithArguments* aa=dynamic_cast<ActionWithArguments*>(action); plumed_assert( aa ); f.calc( aa, args, vals, derivatives );
     339       11498 :   for(unsigned i=0; i<vals.size(); ++i) action->copyOutput(i)->set( vals[i] );
     340             :   // Return if we are not computing derivatives
     341        5541 :   if( action->doNotCalculateDerivatives() ) return;
     342             :   // Now set the derivatives
     343      198059 :   for(unsigned j=0; j<nv; ++j) {
     344      388720 :     for(unsigned i=0; i<vals.size(); ++i) action->copyOutput(i)->setDerivative( j, derivatives(i,j) );
     345             :   }
     346             : }
     347             : 
     348             : template <class T>
     349      184898 : void FunctionOfVector<T>::calculate() {
     350             :   // Everything is done elsewhere
     351      184898 :   if( actionInChain() ) return;
     352             :   // This is done if we are calculating a function of multiple cvs
     353       80095 :   if( !doAtEnd ) runAllTasks();
     354             :   // This is used if we are doing sorting actions on a single vector
     355        5541 :   else if( !myfunc.doWithTasks() ) runSingleTaskCalculation( getPntrToArgument(0), this, myfunc );
     356             : }
     357             : 
     358             : }
     359             : }
     360             : #endif

Generated by: LCOV version 1.16