Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithMatrix.h"
23 : #include "core/ActionRegister.h"
24 :
25 : //+PLUMEDOC MCOLVAR MATRIX_VECTOR_PRODUCT
26 : /*
Calculate the product of one or more matrices and a vector, or of a single matrix and one or more vectors
28 :
29 : \par Examples
30 :
31 : */
32 : //+ENDPLUMEDOC
33 :
34 : namespace PLMD {
35 : namespace matrixtools {
36 :
37 : class MatrixTimesVector : public ActionWithMatrix {
38 : private:
39 : bool sumrows;
40 : unsigned nderivatives;
41 : std::vector<bool> stored_arg;
42 : public:
43 : static void registerKeywords( Keywords& keys );
44 : explicit MatrixTimesVector(const ActionOptions&);
45 : std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
46 0 : unsigned getNumberOfColumns() const override { plumed_error(); }
47 : unsigned getNumberOfDerivatives();
48 : void prepare() override ;
49 2151 : bool isInSubChain( unsigned& nder ) override { nder = arg_deriv_starts[0]; return true; }
50 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
51 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
52 : void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
53 : void updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const override ;
54 : };
55 :
// Register this class with PLUMED's action register under the name MATRIX_VECTOR_PRODUCT.
PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT")
57 :
58 629 : void MatrixTimesVector::registerKeywords( Keywords& keys ) {
59 629 : ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
60 629 : keys.setValueDescription("the vector that is obtained by taking the product between the matrix and the vector that were input");
61 629 : ActionWithValue::useCustomisableComponents(keys);
62 629 : }
63 :
64 6 : std::string MatrixTimesVector::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
65 6 : if( getPntrToArgument(1)->getRank()==1 ) {
66 0 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
67 0 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
68 0 : return "the product of the matrix " + getPntrToArgument(0)->getName() + " and the vector " + getPntrToArgument(i)->getName();
69 : }
70 : }
71 : }
72 21 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
73 21 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
74 12 : return "the product of the matrix " + getPntrToArgument(i)->getName() + " and the vector " + getPntrToArgument(getNumberOfArguments()-1)->getName();
75 : }
76 : }
77 0 : plumed_merror( "could not understand request for component " + cname ); return "";
78 : }
79 :
80 352 : MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao):
81 : Action(ao),
82 : ActionWithMatrix(ao),
83 352 : sumrows(false)
84 : {
85 352 : if( getNumberOfArguments()<2 ) error("Not enough arguments specified");
86 : unsigned nvectors=0, nmatrices=0;
87 1875 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
88 1523 : if( getPntrToArgument(i)->hasDerivatives() ) error("arguments should be vectors or matrices");
89 1523 : if( getPntrToArgument(i)->getRank()<=1 ) nvectors++;
90 1523 : if( getPntrToArgument(i)->getRank()==2 ) nmatrices++;
91 : }
92 :
93 352 : std::vector<unsigned> shape(1); shape[0]=getPntrToArgument(0)->getShape()[0];
94 352 : if( nvectors==1 ) {
95 343 : unsigned n = getNumberOfArguments()-1;
96 1320 : for(unsigned i=0; i<n; ++i) {
97 977 : if( getPntrToArgument(i)->getRank()!=2 || getPntrToArgument(i)->hasDerivatives() ) error("all arguments other than last argument should be matrices");
98 977 : if( getPntrToArgument(n)->getRank()==0 ) {
99 1 : if( getPntrToArgument(i)->getShape()[1]!=1 ) error("number of columns in input matrix does not equal number of elements in vector");
100 976 : } else if( getPntrToArgument(i)->getShape()[1]!=getPntrToArgument(n)->getShape()[0] ) error("number of columns in input matrix does not equal number of elements in vector");
101 : }
102 343 : if( getPntrToArgument(n)->getRank()>0 ) {
103 342 : if( getPntrToArgument(n)->getRank()!=1 || getPntrToArgument(n)->hasDerivatives() ) error("last argument to this action should be a vector");
104 : }
105 343 : getPntrToArgument(n)->buildDataStore();
106 :
107 343 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
108 343 : if( av ) done_in_chain=canBeAfterInChain( av );
109 :
110 343 : if( getNumberOfArguments()==2 ) {
111 301 : addValue( shape ); setNotPeriodic();
112 : } else {
113 718 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
114 676 : std::string name = getPntrToArgument(i)->getName();
115 1352 : if( name.find_first_of(".")!=std::string::npos ) { std::size_t dot=name.find_first_of("."); name = name.substr(dot+1); }
116 676 : addComponent( name, shape ); componentIsNotPeriodic( name );
117 : }
118 : }
119 343 : if( (getPntrToArgument(n)->getPntrToAction())->getName()=="CONSTANT" ) {
120 306 : sumrows=true;
121 306 : if( getPntrToArgument(n)->getRank()==0 ) {
122 1 : if( fabs( getPntrToArgument(n)->get() - 1.0 )>epsilon ) sumrows = false;
123 : } else {
124 180438 : for(unsigned i=0; i<getPntrToArgument(n)->getShape()[0]; ++i) {
125 180141 : if( fabs( getPntrToArgument(n)->get(i) - 1.0 )>epsilon ) { sumrows=false; break; }
126 : }
127 : }
128 : }
129 9 : } else if( nmatrices==1 ) {
130 9 : if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
131 203 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
132 194 : if( getPntrToArgument(i)->getRank()>1 || getPntrToArgument(i)->hasDerivatives() ) error("all arguments other than first argument should be vectors");
133 194 : if( getPntrToArgument(i)->getRank()==0 ) {
134 0 : if( getPntrToArgument(0)->getShape()[1]!=1 ) error("number of columns in input matrix does not equal number of elements in vector");
135 194 : } else if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(i)->getShape()[0] ) error("number of columns in input matrix does not equal number of elements in vector");
136 194 : getPntrToArgument(i)->buildDataStore();
137 : }
138 :
139 9 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
140 9 : if( av ) done_in_chain=canBeAfterInChain( av );
141 :
142 203 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
143 194 : std::string name = getPntrToArgument(i)->getName();
144 194 : if( name.find_first_of(".")!=std::string::npos ) { std::size_t dot=name.find_first_of("."); name = name.substr(dot+1); }
145 194 : addComponent( name, shape ); componentIsNotPeriodic( name );
146 : }
147 0 : } else error("You should either have one vector or one matrix in input");
148 :
149 352 : nderivatives = buildArgumentStore(0);
150 352 : std::string headstr=getFirstActionInChain()->getLabel(); stored_arg.resize( getNumberOfArguments() );
151 1875 : for(unsigned i=0; i<getNumberOfArguments(); ++i) stored_arg[i] = getPntrToArgument(i)->ignoreStoredValue( headstr );
152 352 : }
153 :
154 31643 : unsigned MatrixTimesVector::getNumberOfDerivatives() {
155 31643 : return nderivatives;
156 : }
157 :
158 13575 : void MatrixTimesVector::prepare() {
159 13575 : ActionWithVector::prepare(); Value* myval = getPntrToComponent(0);
160 13575 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] ) return;
161 10 : std::vector<unsigned> shape(1); shape[0] = getPntrToArgument(0)->getShape()[0]; myval->setShape(shape);
162 : }
163 :
164 6574 : void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
165 6574 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
166 6574 : if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );
167 842762 : for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
168 : myvals.setSplitIndex( size_v + 1 );
169 6574 : }
170 :
171 23970940 : void MatrixTimesVector::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
172 23970940 : unsigned ind2 = index2; if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
173 23970940 : if( sumrows ) {
174 22303792 : unsigned n=getNumberOfArguments()-1; double matval = 0;
175 87441027 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
176 65137235 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
177 : Value* myarg = getPntrToArgument(i);
178 65137235 : if( !myarg->valueHasBeenSet() ) myvals.addValue( ostrn, myvals.get( myarg->getPositionInStream() ) );
179 14718 : else myvals.addValue( ostrn, myarg->get( index1*myarg->getNumberOfColumns() + ind2, false ) );
180 : // Now lets work out the derivatives
181 65137235 : if( doNotCalculateDerivatives() ) continue;
182 32823346 : addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, 1.0, myvals );
183 : }
184 1667148 : } else if( getPntrToArgument(1)->getRank()==1 ) {
185 1667148 : double matval = 0; Value* myarg = getPntrToArgument(0); unsigned vcol = ind2;
186 1667148 : if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
187 : else {
188 827038 : matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
189 827038 : vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 );
190 : }
191 18356786 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
192 16689638 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
193 16689638 : double vecval=getArgumentElement( i+1, vcol, myvals );
194 : // And add this part of the product
195 16689638 : myvals.addValue( ostrn, matval*vecval );
196 : // Now lets work out the derivatives
197 16689638 : if( doNotCalculateDerivatives() ) continue;
198 15688768 : addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals );
199 : }
200 : } else {
201 0 : unsigned n=getNumberOfArguments()-1; double matval = 0; unsigned vcol = ind2;
202 0 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
203 0 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
204 : Value* myarg = getPntrToArgument(i);
205 0 : if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
206 : else {
207 0 : matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
208 0 : vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 );
209 : }
210 0 : double vecval=getArgumentElement( n, vcol, myvals );
211 : // And add this part of the product
212 0 : myvals.addValue( ostrn, matval*vecval );
213 : // Now lets work out the derivatives
214 0 : if( doNotCalculateDerivatives() ) continue;
215 0 : addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals );
216 : }
217 : }
218 23970940 : }
219 :
220 472445 : void MatrixTimesVector::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
221 472445 : if( doNotCalculateDerivatives() || !actionInChain() ) return ;
222 :
223 358714 : if( getPntrToArgument(1)->getRank()==1 ) {
224 : unsigned istrn = getPntrToArgument(0)->getPositionInMatrixStash();
225 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
226 1010565 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
227 671975 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
228 40971258 : for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) myvals.updateIndex( ostrn, mat_indices[i] );
229 : }
230 : } else {
231 530036 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
232 : unsigned istrn = getPntrToArgument(j)->getPositionInMatrixStash();
233 509912 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
234 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
235 17456348 : for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) myvals.updateIndex( ostrn, mat_indices[i] );
236 : }
237 : }
238 : }
239 :
240 372677 : void MatrixTimesVector::updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const {
241 372677 : unsigned n = getNumberOfArguments()-1; if( getPntrToArgument(1)->getRank()==1 ) n = 1;
242 372677 : unsigned nvals = getPntrToArgument(n)->getNumberOfValues();
243 1387754027 : for(unsigned i=0; i<nvals; ++i) myvals.updateIndex( ostrn, arg_deriv_starts[n] + i );
244 372677 : }
245 :
246 : }
247 : }
|