LCOV - code coverage report
Current view: top level - function - FunctionOfScalar.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 62 66 93.9 %
Date: 2025-04-08 21:11:17 Functions: 55 104 52.9 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_function_FunctionOfScalar_h
      23             : #define __PLUMED_function_FunctionOfScalar_h
      24             : 
      25             : #include "Function.h"
      26             : #include "tools/Matrix.h"
      27             : 
      28             : namespace PLMD {
      29             : namespace function {
      30             : 
      31             : /**
      32             : \ingroup INHERIT
      33             : This is the abstract base class to use for implementing new CV function, within it there is
      34             : \ref AddingAFunction "information" as to how to go about implementing a new function.
      35             : */
      36             : 
      37             : template <class T>
      38             : class FunctionOfScalar : public Function {
      39             : private:
      40             : /// The function that is being computed
      41             :   T myfunc;
      42             : /// Are we on the first step
      43             :   bool firststep;
      44             : public:
      45             :   explicit FunctionOfScalar(const ActionOptions&);
      46        2858 :   virtual ~FunctionOfScalar() {}
      47             : /// Get the label to write in the graph
      48           3 :   std::string writeInGraph() const override {
      49           3 :     return myfunc.getGraphInfo( getName() );
      50             :   }
      51             :   std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
      52             :   void calculate() override;
      53             :   static void registerKeywords(Keywords&);
      54             :   void turnOnDerivatives() override;
      55             : };
      56             : 
      57             : template <class T>
      58        2896 : void FunctionOfScalar<T>::registerKeywords(Keywords& keys) {
      59        2896 :   Function::registerKeywords(keys);
      60        2896 :   std::string name = keys.getDisplayName();
      61        2896 :   std::size_t und=name.find("_SCALAR");
      62        5792 :   keys.setDisplayName( name.substr(0,und) );
      63        2896 :   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
      64        2358 :   T tfunc;
      65        2896 :   tfunc.registerKeywords( keys );
      66        5792 :   if( keys.getDisplayName()=="SUM" ) {
      67           4 :     keys.setValueDescription("scalar","the sum of all the input arguments");
      68        5788 :   } else if( keys.getDisplayName()=="MEAN" ) {
      69           4 :     keys.setValueDescription("scalar","the mean of all the input arguments");
      70        5784 :   } else if( keys.getDisplayName()=="EVALUATE_FUNCTION_FROM_GRID" ) {
      71           8 :     keys.addInputKeyword("compulsory","ARG","scalar/grid","");
      72             :   }
      73        5733 : }
      74             : 
      75             : template <class T>
      76        1434 : FunctionOfScalar<T>::FunctionOfScalar(const ActionOptions&ao):
      77             :   Action(ao),
      78             :   Function(ao),
      79        1434 :   firststep(true) {
      80        1434 :   myfunc.read( this );
      81             :   // Get the names of the components
      82        1429 :   std::vector<std::string> components( keywords.getOutputComponents() );
      83             :   // Create the values to hold the output
      84        1386 :   std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
      85        2858 :   for(unsigned i=0; i<components.size(); ++i) {
      86          13 :     if( str_ind.size()>0 ) {
      87          13 :       std::string compstr = components[i];
      88          13 :       if( compstr==".#!value" ) {
      89             :         compstr = "";
      90             :       }
      91          40 :       for(unsigned j=0; j<str_ind.size(); ++j) {
      92          54 :         addComponentWithDerivatives( compstr + str_ind[j] );
      93             :       }
      94        1416 :     } else if( components[i]==".#!value" ) {
      95        1414 :       addValueWithDerivatives();
      96           2 :     } else if( components[i].find_first_of("_")!=std::string::npos ) {
      97           2 :       if( getNumberOfArguments()==1 ) {
      98           1 :         addValueWithDerivatives();
      99             :       } else {
     100           3 :         for(unsigned j=0; j<getNumberOfArguments(); ++j) {
     101           4 :           addComponentWithDerivatives( getPntrToArgument(j)->getName() + components[i] );
     102             :         }
     103             :       }
     104             :     } else {
     105           0 :       addComponentWithDerivatives( components[i] );
     106             :     }
     107             :   }
     108             :   // Set the periodicities of the output components
     109        1429 :   myfunc.setPeriodicityForOutputs( this );
     110           0 :   myfunc.setPrefactor( this, 1.0 );
     111        1442 : }
     112             : 
     113             : template <class T>
     114           3 : std::string FunctionOfScalar<T>::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
     115           3 :   if( getName().find("SORT")==std::string::npos ) {
     116           0 :     return ActionWithValue::getOutputComponentDescription( cname, keys );
     117             :   }
     118           6 :   return "the " + cname + "th largest of the input scalars";
     119             : }
     120             : 
     121             : template <class T>
     122        2190 : void FunctionOfScalar<T>::turnOnDerivatives() {
     123             :   if( !myfunc.derivativesImplemented() ) {
     124           0 :     error("derivatives have not been implemended for " + getName() );
     125             :   }
     126        2190 :   ActionWithValue::turnOnDerivatives();
     127        2190 : }
     128             : 
     129             : template <class T>
     130       92863 : void FunctionOfScalar<T>::calculate() {
     131       92863 :   if( firststep ) {
     132        1380 :     myfunc.setup( this );
     133        1380 :     firststep=false;
     134             :   }
     135        1675 :   unsigned argstart = myfunc.getArgStart();
     136       92863 :   std::vector<double> args( getNumberOfArguments() - argstart );
     137      210481 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     138      117618 :     args[i-argstart]=getPntrToArgument(i)->get();
     139             :   }
     140       92863 :   std::vector<double> vals( getNumberOfComponents() );
     141       92863 :   Matrix<double> derivatives( getNumberOfComponents(), args.size() );
     142       92863 :   myfunc.calc( this, args, vals, derivatives );
     143      185763 :   for(unsigned i=0; i<vals.size(); ++i) {
     144       92900 :     copyOutput(i)->set(vals[i]);
     145             :   }
     146       92863 :   if( doNotCalculateDerivatives() ) {
     147             :     return;
     148             :   }
     149             : 
     150      160612 :   for(unsigned i=0; i<vals.size(); ++i) {
     151       80321 :     Value* val = getPntrToComponent(i);
     152      177002 :     for(unsigned j=0; j<args.size(); ++j) {
     153       96681 :       setDerivative( val, j, derivatives(i,j) );
     154             :     }
     155             :   }
     156             : }
     157             : 
     158             : }
     159             : }
     160             : #endif

Generated by: LCOV version 1.16