LCOV - code coverage report
Current view: top level - core - ActionWithMatrix.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 27 27 100.0 %
Date: 2024-10-18 14:00:25 Functions: 6 6 100.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_core_ActionWithMatrix_h
      23             : #define __PLUMED_core_ActionWithMatrix_h
      24             : 
      25             : #include "ActionWithVector.h"
      26             : 
      27             : namespace PLMD {
      28             : 
      29             : class ActionWithMatrix : public ActionWithVector {
      30             : private:
      31             :   ActionWithMatrix* next_action_in_chain;
      32             :   ActionWithMatrix* matrix_to_do_before;
      33             :   ActionWithMatrix* matrix_to_do_after;
      34             : /// This holds the bookkeeping arrays for sparse matrices
      35             :   std::vector<unsigned> matrix_bookeeping;
      36             : /// Update all the neighbour lists in the chain
      37             :   void updateAllNeighbourLists();
      38             : /// This is used to clear up the matrix elements
      39             :   void clearMatrixElements( MultiValue& myvals ) const ;
      40             : /// This is used to find the total amount of space we need for storing matrix elements
      41             :   void getTotalMatrixBookeeping( unsigned& stashsize );
      42             : /// This transfers the non-zero matrix elements to the output Values
      43             :   void transferNonZeroMatrixElementsToValues( unsigned& nval, const std::vector<unsigned>& matbook );
      44             : /// This does the calculation of a particular matrix element
      45             :   void runTask( const std::string& controller, const unsigned& current, const unsigned colno, MultiValue& myvals ) const ;
      46             : protected:
      47             : /// This turns off derivative clearing for the contact matrix if we are not storing derivatives
      48             :   bool clearOnEachCycle;
      49             : /// Does the matrix chain continue on from this action (true if a matrix action is set to run after this one)
      50             :   bool matrixChainContinues() const ;
      51             : /// Return element jelem of argument ic; if the argument's value has not been set, the value is read from myvals at the argument's stream position instead
      52             :   double getArgumentElement( const unsigned& ic, const unsigned& jelem, const MultiValue& myvals ) const ;
      53             : /// Return element (irow,jcol) of the matrix passed as argument imat (row-major lookup into the stored matrix, or the streamed value if it has not been set)
      54             :   double getElementOfMatrixArgument( const unsigned& imat, const unsigned& irow, const unsigned& jcol, const MultiValue& myvals ) const ;
      55             : /// Add derivatives to output component ival given der, the derivative with respect to element jelem of vector argument jarg
      56             :   void addDerivativeOnVectorArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& jelem, const double& der, MultiValue& myvals ) const ;
      57             : /// Add derivatives to output component ival given der, the derivative with respect to element (irow,jcol) of matrix argument jarg
      58             :   void addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const ;
      59             : public:
      60             :   static void registerKeywords( Keywords& keys );
      61             :   explicit ActionWithMatrix(const ActionOptions&);
      62             :   virtual ~ActionWithMatrix();
      63             : /// Is this action an adjacency matrix? This base class always returns false
      64    66890656 :   virtual bool isAdjacencyMatrix() const { return false; }
      65             : /// Collect the labels of the actions in this matrix chain into mylabels (override)
      66             :   void getAllActionLabelsInMatrixChain( std::vector<std::string>& mylabels ) const override ;
      67             : /// Get the first matrix in this chain
      68             :   const ActionWithMatrix* getFirstMatrixInChain() const ;
      69             : /// Complete the construction of the matrix chain; NOTE(review): the role of act is defined in the implementation file
      70             :   void finishChainBuild( ActionWithVector* act );
      71             : /// This should return the number of columns to help with sparse storage of matrices
      72             :   virtual unsigned getNumberOfColumns() const = 0;
      73             : /// Set up the streamed components of this action (override; see the implementation file for the meaning of each counter)
      74             :   void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override;
      75             : /// This does some setup before we run over a row of the matrix
      76             :   virtual void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const = 0;
      77             : /// Run over one row of the matrix
      78             :   void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
      79             : /// Gather a row of the matrix
      80             :   void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, const unsigned& bufstart, std::vector<double>& buffer ) const override;
      81             : /// Gather all the data from the threads
      82             :   void gatherThreads( const unsigned& nt, const unsigned& bufsize, const std::vector<double>& omp_buffer, std::vector<double>& buffer, MultiValue& myvals ) override ;
      83             : /// Gather all the data from the MPI processes
      84             :   void gatherProcesses( std::vector<double>& buffer ) override;
      85             : /// This pure virtual does the calculation of the task for a particular matrix element
      86             :   virtual void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const = 0;
      87             : /// These are the jobs that need to be done once all the jobs in a row of the matrix have been run
      88             :   virtual void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const = 0;
      89             : /// This is overridden in the adjacency matrix, which has a neighbour list; it does nothing by default
      90       19272 :   virtual void updateNeighbourList() {}
      91             : /// Run the calculation
      92             :   virtual void calculate() override;
      93             : /// Check if there are forces we need to account for on this task
      94             :   bool checkForTaskForce( const unsigned& itask, const Value* myval ) const override ;
      95             : /// This gathers the force on a particular value
      96             :   void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override;
      97             : };
      98             : 
      99             : inline
     100             : bool ActionWithMatrix::matrixChainContinues() const { // true when some matrix action is registered to run after this one
     101      552731 :   return matrix_to_do_after!=NULL;
     102             : }
     103             : 
     104             : inline
     105    70002481 : double ActionWithMatrix::getArgumentElement( const unsigned& ic, const unsigned& jelem, const MultiValue& myvals ) const { // element jelem of argument ic
     106    70002481 :   if( !getPntrToArgument(ic)->valueHasBeenSet() ) return myvals.get( getPntrToArgument(ic)->getPositionInStream() ); // value not stored yet: read it from the argument's stream slot in myvals (jelem is not used on this path)
     107    68729606 :   return getPntrToArgument(ic)->get( jelem ); // value stored: read element jelem directly from the argument
     108             : }
     109             : 
     110             : inline
     111    32301796 : double ActionWithMatrix::getElementOfMatrixArgument( const unsigned& imat, const unsigned& irow, const unsigned& jcol, const MultiValue& myvals ) const { // element (irow,jcol) of matrix argument imat
     112             :   plumed_dbg_assert( imat<getNumberOfArguments() && getPntrToArgument(imat)->getRank()==2 && !getPntrToArgument(imat)->hasDerivatives() );
     113    32301796 :   if( !getPntrToArgument(imat)->valueHasBeenSet() ) return myvals.get( getPntrToArgument(imat)->getPositionInStream() ); // matrix not stored yet: read the streamed value for the current element
     114    30768796 :   return getArgumentElement( imat, irow*getPntrToArgument(imat)->getShape()[1] + jcol, myvals ); // row-major flat index; note getArgumentElement repeats the valueHasBeenSet() check
     115             : }
     116             : 
     117             : inline
     118    32289795 : void ActionWithMatrix::addDerivativeOnVectorArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& jelem, const double& der, MultiValue& myvals ) const { // der = derivative of output ival wrt element jelem of vector argument jarg
     119             :   plumed_dbg_massert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()<2, "failing in action " + getName() + " with label " + getLabel() );
     120    32289795 :   unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg]; // ostrn: stream slot of output ival; vstart: offset of argument jarg's derivatives
     121    32289795 :   if( !inchain ) { // not in chain: add der directly against element jelem of the argument
     122    31433348 :     myvals.addDerivative( ostrn, vstart + jelem, der ); myvals.updateIndex( ostrn, vstart + jelem );
     123             :   } else {
     124             :     unsigned istrn = getPntrToArgument(jarg)->getPositionInStream(); // in chain: apply the chain rule through the argument's own active derivatives
     125    43315006 :     for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     126    42458559 :       unsigned kind=myvals.getActiveIndex(istrn,k);
     127    42458559 :       myvals.addDerivative( ostrn, arg_deriv_starts[jarg] + kind, der*myvals.getDerivative( istrn, kind ) ); // same offset as vstart
     128    42458559 :       myvals.updateIndex( ostrn, arg_deriv_starts[jarg] + kind );
     129             :     }
     130             :   }
     131    32289795 : }
     132             : 
     133             : inline
     134    56805054 : void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const { // der = derivative of output ival wrt element (irow,jcol) of matrix argument jarg
     135             :   plumed_dbg_assert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() );
     136    56805054 :   unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg]; // vstart is only used in the !inchain branch
     137    56805054 :   if( !inchain ) { // not in chain: add der against the flattened (irow,jcol) element of the argument
     138     2243240 :     unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns() + jcol; // NOTE(review): row stride is getNumberOfColumns() whereas getElementOfMatrixArgument uses getShape()[1]; confirm these agree here
     139     2243240 :     myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc );
     140             :   } else {
     141             :     unsigned istrn = getPntrToArgument(jarg)->getPositionInStream(); // in chain: apply the chain rule through the argument's own active derivatives
     142   353992585 :     for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     143   299430771 :       unsigned kind=myvals.getActiveIndex(istrn,k);
     144   299430771 :       myvals.addDerivative( ostrn, kind, der*myvals.getDerivative( istrn, kind ) ); // NOTE(review): unlike addDerivativeOnVectorArgument, no arg_deriv_starts offset and no updateIndex call here; confirm this is intended
     145             :     }
     146             :   }
     147    56805054 : }
     148             : 
     149             : }
     150             : #endif

Generated by: LCOV version 1.16