Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_core_ActionWithMatrix_h
23 : #define __PLUMED_core_ActionWithMatrix_h
24 :
25 : #include "ActionWithVector.h"
26 :
27 : namespace PLMD {
28 :
29 : class ActionWithMatrix : public ActionWithVector {
30 : private:
31 : ActionWithMatrix* next_action_in_chain;
32 : ActionWithMatrix* matrix_to_do_before;
33 : ActionWithMatrix* matrix_to_do_after;
34 : /// This holds the bookeeping arrays for sparse matrices
35 : std::vector<unsigned> matrix_bookeeping;
36 : /// Update all the neighbour lists in the chain
37 : void updateAllNeighbourLists();
38 : /// This is used to clear up the matrix elements
39 : void clearMatrixElements( MultiValue& myvals ) const ;
40 : /// This is used to find the total amount of space we need for storing matrix elements
41 : void getTotalMatrixBookeeping( unsigned& stashsize );
42 : /// This transfers the non-zero elements to the Value
43 : void transferNonZeroMatrixElementsToValues( unsigned& nval, const std::vector<unsigned>& matbook );
44 : /// This does the calculation of a particular matrix element
45 : void runTask( const std::string& controller, const unsigned& current, const unsigned colno, MultiValue& myvals ) const ;
46 : protected:
47 : /// This turns off derivative clearing for contact matrix if we are not storing derivatives
48 : bool clearOnEachCycle;
49 : /// Does the matrix chain continue on from this action
50 : bool matrixChainContinues() const ;
51 : /// This returns the jelem th element of argument ic
52 : double getArgumentElement( const unsigned& ic, const unsigned& jelem, const MultiValue& myvals ) const ;
53 : /// This returns an element of a matrix that is passed an argument
54 : double getElementOfMatrixArgument( const unsigned& imat, const unsigned& irow, const unsigned& jcol, const MultiValue& myvals ) const ;
55 : /// Add derivatives given the derivative wrt to the input vector element as input
56 : void addDerivativeOnVectorArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& jelem, const double& der, MultiValue& myvals ) const ;
57 : /// Add derivatives given the derative wrt to the input matrix element as input
58 : void addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const ;
59 : public:
60 : static void registerKeywords( Keywords& keys );
61 : explicit ActionWithMatrix(const ActionOptions&);
62 : virtual ~ActionWithMatrix();
63 : ///
64 66890656 : virtual bool isAdjacencyMatrix() const {
65 66890656 : return false;
66 : }
67 : ///
68 : void getAllActionLabelsInMatrixChain( std::vector<std::string>& mylabels ) const override ;
69 : /// Get the first matrix in this chain
70 : const ActionWithMatrix* getFirstMatrixInChain() const ;
71 : ///
72 : void finishChainBuild( ActionWithVector* act );
73 : /// This should return the number of columns to help with sparse storage of matrices
74 : virtual unsigned getNumberOfColumns() const = 0;
75 : /// This requires some thought
76 : void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override;
77 : //// This does some setup before we run over the row of the matrix
78 : virtual void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const = 0;
79 : /// Run over one row of the matrix
80 : void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
81 : /// Gather a row of the matrix
82 : void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, const unsigned& bufstart, std::vector<double>& buffer ) const override;
83 : /// Gather all the data from the threads
84 : void gatherThreads( const unsigned& nt, const unsigned& bufsize, const std::vector<double>& omp_buffer, std::vector<double>& buffer, MultiValue& myvals ) override ;
85 : /// Gather all the data from the MPI processes
86 : void gatherProcesses( std::vector<double>& buffer ) override;
87 : /// This is the virtual that will do the calculation of the task for a particular matrix element
88 : virtual void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const = 0;
89 : /// This is the jobs that need to be done when we have run all the jobs in a row of the matrix
90 : virtual void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const = 0;
91 : /// This is overwritten in Adjacency matrix where we have a neighbour list
92 19272 : virtual void updateNeighbourList() {}
93 : /// Run the calculation
94 : virtual void calculate() override;
95 : /// Check if there are forces we need to account for on this task
96 : bool checkForTaskForce( const unsigned& itask, const Value* myval ) const override ;
97 : /// This gathers the force on a particular value
98 : void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override;
99 : };
100 :
101 : inline
102 : bool ActionWithMatrix::matrixChainContinues() const {
103 552731 : return matrix_to_do_after!=NULL;
104 : }
105 :
106 : inline
107 70002481 : double ActionWithMatrix::getArgumentElement( const unsigned& ic, const unsigned& jelem, const MultiValue& myvals ) const {
108 70002481 : if( !getPntrToArgument(ic)->valueHasBeenSet() ) {
109 1272875 : return myvals.get( getPntrToArgument(ic)->getPositionInStream() );
110 : }
111 68729606 : return getPntrToArgument(ic)->get( jelem );
112 : }
113 :
114 : inline
115 32301796 : double ActionWithMatrix::getElementOfMatrixArgument( const unsigned& imat, const unsigned& irow, const unsigned& jcol, const MultiValue& myvals ) const {
116 : plumed_dbg_assert( imat<getNumberOfArguments() && getPntrToArgument(imat)->getRank()==2 && !getPntrToArgument(imat)->hasDerivatives() );
117 32301796 : if( !getPntrToArgument(imat)->valueHasBeenSet() ) {
118 1533000 : return myvals.get( getPntrToArgument(imat)->getPositionInStream() );
119 : }
120 30768796 : return getArgumentElement( imat, irow*getPntrToArgument(imat)->getShape()[1] + jcol, myvals );
121 : }
122 :
123 : inline
124 32289795 : void ActionWithMatrix::addDerivativeOnVectorArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& jelem, const double& der, MultiValue& myvals ) const {
125 : plumed_dbg_massert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()<2, "failing in action " + getName() + " with label " + getLabel() );
126 32289795 : unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
127 32289795 : if( !inchain ) {
128 31433348 : myvals.addDerivative( ostrn, vstart + jelem, der );
129 31433348 : myvals.updateIndex( ostrn, vstart + jelem );
130 : } else {
131 : unsigned istrn = getPntrToArgument(jarg)->getPositionInStream();
132 43315006 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
133 42458559 : unsigned kind=myvals.getActiveIndex(istrn,k);
134 42458559 : myvals.addDerivative( ostrn, arg_deriv_starts[jarg] + kind, der*myvals.getDerivative( istrn, kind ) );
135 42458559 : myvals.updateIndex( ostrn, arg_deriv_starts[jarg] + kind );
136 : }
137 : }
138 32289795 : }
139 :
140 : inline
141 56805054 : void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const {
142 : plumed_dbg_assert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() );
143 56805054 : unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
144 56805054 : if( !inchain ) {
145 2243240 : unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns() + jcol;
146 2243240 : myvals.addDerivative( ostrn, dloc, der );
147 2243240 : myvals.updateIndex( ostrn, dloc );
148 : } else {
149 : unsigned istrn = getPntrToArgument(jarg)->getPositionInStream();
150 353992585 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
151 299430771 : unsigned kind=myvals.getActiveIndex(istrn,k);
152 299430771 : myvals.addDerivative( ostrn, kind, der*myvals.getDerivative( istrn, kind ) );
153 : }
154 : }
155 56805054 : }
156 :
157 : }
158 : #endif
|