Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_core_ActionWithMatrix_h
23 : #define __PLUMED_core_ActionWithMatrix_h
24 :
25 : #include "ActionWithVector.h"
26 :
27 : namespace PLMD {
28 :
29 : class ActionWithMatrix : public ActionWithVector {
30 : private:
31 : ActionWithMatrix* next_action_in_chain;
32 : ActionWithMatrix* matrix_to_do_before;
33 : ActionWithMatrix* matrix_to_do_after;
34 : /// This holds the bookeeping arrays for sparse matrices
35 : std::vector<unsigned> matrix_bookeeping;
36 : /// Update all the neighbour lists in the chain
37 : void updateAllNeighbourLists();
38 : /// This is used to clear up the matrix elements
39 : void clearMatrixElements( MultiValue& myvals ) const ;
40 : /// This is used to find the total amount of space we need for storing matrix elements
41 : void getTotalMatrixBookeeping( unsigned& stashsize );
42 : /// This transfers the non-zero elements to the Value
43 : void transferNonZeroMatrixElementsToValues( unsigned& nval, const std::vector<unsigned>& matbook );
44 : /// This does the calculation of a particular matrix element
45 : void runTask( const std::string& controller, const unsigned& current, const unsigned colno, MultiValue& myvals ) const ;
46 : protected:
47 : /// This turns off derivative clearing for contact matrix if we are not storing derivatives
48 : bool clearOnEachCycle;
49 : /// Does the matrix chain continue on from this action
50 : bool matrixChainContinues() const ;
51 : /// This returns the jelem th element of argument ic
52 : double getArgumentElement( const unsigned& ic, const unsigned& jelem, const MultiValue& myvals ) const ;
53 : /// This returns an element of a matrix that is passed an argument
54 : double getElementOfMatrixArgument( const unsigned& imat, const unsigned& irow, const unsigned& jcol, const MultiValue& myvals ) const ;
55 : /// Add derivatives given the derivative wrt to the input vector element as input
56 : void addDerivativeOnVectorArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& jelem, const double& der, MultiValue& myvals ) const ;
57 : /// Add derivatives given the derative wrt to the input matrix element as input
58 : void addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const ;
59 : public:
60 : static void registerKeywords( Keywords& keys );
61 : explicit ActionWithMatrix(const ActionOptions&);
62 : virtual ~ActionWithMatrix();
63 : ///
64 66890656 : virtual bool isAdjacencyMatrix() const { return false; }
65 : ///
66 : void getAllActionLabelsInMatrixChain( std::vector<std::string>& mylabels ) const override ;
67 : /// Get the first matrix in this chain
68 : const ActionWithMatrix* getFirstMatrixInChain() const ;
69 : ///
70 : void finishChainBuild( ActionWithVector* act );
71 : /// This should return the number of columns to help with sparse storage of matrices
72 : virtual unsigned getNumberOfColumns() const = 0;
73 : /// This requires some thought
74 : void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override;
75 : //// This does some setup before we run over the row of the matrix
76 : virtual void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const = 0;
77 : /// Run over one row of the matrix
78 : void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
79 : /// Gather a row of the matrix
80 : void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, const unsigned& bufstart, std::vector<double>& buffer ) const override;
81 : /// Gather all the data from the threads
82 : void gatherThreads( const unsigned& nt, const unsigned& bufsize, const std::vector<double>& omp_buffer, std::vector<double>& buffer, MultiValue& myvals ) override ;
83 : /// Gather all the data from the MPI processes
84 : void gatherProcesses( std::vector<double>& buffer ) override;
85 : /// This is the virtual that will do the calculation of the task for a particular matrix element
86 : virtual void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const = 0;
87 : /// This is the jobs that need to be done when we have run all the jobs in a row of the matrix
88 : virtual void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const = 0;
89 : /// This is overwritten in Adjacency matrix where we have a neighbour list
90 19272 : virtual void updateNeighbourList() {}
91 : /// Run the calculation
92 : virtual void calculate() override;
93 : /// Check if there are forces we need to account for on this task
94 : bool checkForTaskForce( const unsigned& itask, const Value* myval ) const override ;
95 : /// This gathers the force on a particular value
96 : void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override;
97 : };
98 :
99 : inline
100 : bool ActionWithMatrix::matrixChainContinues() const {
101 552731 : return matrix_to_do_after!=NULL;
102 : }
103 :
104 : inline
105 70002481 : double ActionWithMatrix::getArgumentElement( const unsigned& ic, const unsigned& jelem, const MultiValue& myvals ) const {
106 70002481 : if( !getPntrToArgument(ic)->valueHasBeenSet() ) return myvals.get( getPntrToArgument(ic)->getPositionInStream() );
107 68729606 : return getPntrToArgument(ic)->get( jelem );
108 : }
109 :
110 : inline
111 32301796 : double ActionWithMatrix::getElementOfMatrixArgument( const unsigned& imat, const unsigned& irow, const unsigned& jcol, const MultiValue& myvals ) const {
112 : plumed_dbg_assert( imat<getNumberOfArguments() && getPntrToArgument(imat)->getRank()==2 && !getPntrToArgument(imat)->hasDerivatives() );
113 32301796 : if( !getPntrToArgument(imat)->valueHasBeenSet() ) return myvals.get( getPntrToArgument(imat)->getPositionInStream() );
114 30768796 : return getArgumentElement( imat, irow*getPntrToArgument(imat)->getShape()[1] + jcol, myvals );
115 : }
116 :
117 : inline
118 32289795 : void ActionWithMatrix::addDerivativeOnVectorArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& jelem, const double& der, MultiValue& myvals ) const {
119 : plumed_dbg_massert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()<2, "failing in action " + getName() + " with label " + getLabel() );
120 32289795 : unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
121 32289795 : if( !inchain ) {
122 31433348 : myvals.addDerivative( ostrn, vstart + jelem, der ); myvals.updateIndex( ostrn, vstart + jelem );
123 : } else {
124 : unsigned istrn = getPntrToArgument(jarg)->getPositionInStream();
125 43315006 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
126 42458559 : unsigned kind=myvals.getActiveIndex(istrn,k);
127 42458559 : myvals.addDerivative( ostrn, arg_deriv_starts[jarg] + kind, der*myvals.getDerivative( istrn, kind ) );
128 42458559 : myvals.updateIndex( ostrn, arg_deriv_starts[jarg] + kind );
129 : }
130 : }
131 32289795 : }
132 :
133 : inline
134 56805054 : void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const {
135 : plumed_dbg_assert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() );
136 56805054 : unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
137 56805054 : if( !inchain ) {
138 2243240 : unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns() + jcol;
139 2243240 : myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc );
140 : } else {
141 : unsigned istrn = getPntrToArgument(jarg)->getPositionInStream();
142 353992585 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
143 299430771 : unsigned kind=myvals.getActiveIndex(istrn,k);
144 299430771 : myvals.addDerivative( ostrn, kind, der*myvals.getDerivative( istrn, kind ) );
145 : }
146 : }
147 56805054 : }
148 :
149 : }
150 : #endif
|