Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithMatrix.h"
23 : #include "core/ActionRegister.h"
24 :
25 : //+PLUMEDOC MCOLVAR MATRIX_VECTOR_PRODUCT
26 : /*
27 : Calculate the product of the matrix and the vector
28 :
29 : This action allows you to [multiply](https://en.wikipedia.org/wiki/Matrix_multiplication) a matrix and a vector together.
30 : This action is primarily used to calculate coordination numbers and symmetry functions, which is what is done by the example below:
31 :
32 : ```plumed
33 : c1: CONTACT_MATRIX GROUP=1-7 SWITCH={RATIONAL R_0=2.6 NN=6 MM=12}
34 : ones: ONES SIZE=7
35 : cc: MATRIX_VECTOR_PRODUCT ARG=c1,ones
36 : PRINT ARG=cc FILE=colvar
37 : ```
38 :
39 : Notice that you can use this action to multiply multiple matrices by a single vector as shown below:
40 :
41 : ```plumed
42 : c1: CONTACT_MATRIX COMPONENTS GROUP=1-7 SWITCH={RATIONAL R_0=2.6 NN=6 MM=12 D_MAX=10.0}
43 : ones: ONES SIZE=7
44 : cc: MATRIX_VECTOR_PRODUCT ARG=c1.x,c1.y,c1.z,ones
45 : PRINT ARG=cc.x,cc.y,cc.z FILE=colvar
46 : ```
47 :
48 : Notice that if you use this option all the input matrices must have the same sparsity pattern. This feature
49 : was implemented in order to make calculating Steinhardt parameters such as [Q6](Q6.md) straightforward.
50 :
51 : You can also multiply a single matrix by multiple vectors:
52 :
53 : ```plumed
54 : c1: CONTACT_MATRIX GROUP=1-7 SWITCH={RATIONAL R_0=2.6 NN=6 MM=12 D_MAX=10.0}
55 : ones: ONES SIZE=7
56 : twos: CONSTANT VALUES=1,2,3,4,5,6,7
57 : cc: MATRIX_VECTOR_PRODUCT ARG=c1,ones,twos
58 : PRINT ARG=cc.ones,cc.twos FILE=colvar
59 : ```
60 :
61 : This feature was implemented to make calculating local averages of the Steinhardt parameters straightforward.
62 :
63 : */
64 : //+ENDPLUMEDOC
65 :
66 : namespace PLMD {
67 : namespace matrixtools {
68 :
69 : class MatrixTimesVector : public ActionWithMatrix {
70 : private:
71 : bool sumrows;
72 : unsigned nderivatives;
73 : std::vector<bool> stored_arg;
74 : public:
75 : static void registerKeywords( Keywords& keys );
76 : explicit MatrixTimesVector(const ActionOptions&);
77 : std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
78 0 : unsigned getNumberOfColumns() const override {
79 0 : plumed_error();
80 : }
81 : unsigned getNumberOfDerivatives();
82 : void prepare() override ;
83 2151 : bool isInSubChain( unsigned& nder ) override {
84 2151 : nder = arg_deriv_starts[0];
85 2151 : return true;
86 : }
87 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
88 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
89 : void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
90 : void updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const override ;
91 : };
92 :
93 : PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT")
94 :
95 629 : void MatrixTimesVector::registerKeywords( Keywords& keys ) {
96 629 : ActionWithMatrix::registerKeywords(keys);
97 1258 : keys.addInputKeyword("compulsory","ARG","matrix/vector/scalar","the label for the matrix and the vector/scalar that are being multiplied. Alternatively, you can provide labels for multiple matrices and a single vector or labels for a single matrix and multiple vectors. In these cases multiple matrix vector products will be computed.");
98 1258 : keys.setValueDescription("vector","the vector that is obtained by taking the product between the matrix and the vector that were input");
99 629 : ActionWithValue::useCustomisableComponents(keys);
100 629 : }
101 :
102 6 : std::string MatrixTimesVector::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
103 6 : if( getPntrToArgument(1)->getRank()==1 ) {
104 0 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
105 0 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
106 0 : return "the product of the matrix " + getPntrToArgument(0)->getName() + " and the vector " + getPntrToArgument(i)->getName();
107 : }
108 : }
109 : }
110 21 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
111 21 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
112 12 : return "the product of the matrix " + getPntrToArgument(i)->getName() + " and the vector " + getPntrToArgument(getNumberOfArguments()-1)->getName();
113 : }
114 : }
115 0 : plumed_merror( "could not understand request for component " + cname );
116 : return "";
117 : }
118 :
119 352 : MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao):
120 : Action(ao),
121 : ActionWithMatrix(ao),
122 352 : sumrows(false) {
123 352 : if( getNumberOfArguments()<2 ) {
124 0 : error("Not enough arguments specified");
125 : }
126 : unsigned nvectors=0, nmatrices=0;
127 1875 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
128 1523 : if( getPntrToArgument(i)->hasDerivatives() ) {
129 0 : error("arguments should be vectors or matrices");
130 : }
131 1523 : if( getPntrToArgument(i)->getRank()<=1 ) {
132 537 : nvectors++;
133 : }
134 1523 : if( getPntrToArgument(i)->getRank()==2 ) {
135 986 : nmatrices++;
136 : }
137 : }
138 :
139 352 : std::vector<unsigned> shape(1);
140 352 : shape[0]=getPntrToArgument(0)->getShape()[0];
141 352 : if( nvectors==1 ) {
142 343 : unsigned n = getNumberOfArguments()-1;
143 1320 : for(unsigned i=0; i<n; ++i) {
144 977 : if( getPntrToArgument(i)->getRank()!=2 || getPntrToArgument(i)->hasDerivatives() ) {
145 0 : error("all arguments other than last argument should be matrices");
146 : }
147 977 : if( getPntrToArgument(n)->getRank()==0 ) {
148 1 : if( getPntrToArgument(i)->getShape()[1]!=1 ) {
149 0 : error("number of columns in input matrix does not equal number of elements in vector");
150 : }
151 976 : } else if( getPntrToArgument(i)->getShape()[1]!=getPntrToArgument(n)->getShape()[0] ) {
152 : std::string str_nmat, str_nvec;
153 0 : Tools::convert( getPntrToArgument(i)->getShape()[1], str_nmat);
154 0 : Tools::convert( getPntrToArgument(n)->getShape()[0], str_nvec );
155 0 : error("number of columns in input matrix is " + str_nmat + " which does not equal number of elements in vector, which is " + str_nvec);
156 : }
157 : }
158 343 : if( getPntrToArgument(n)->getRank()>0 ) {
159 342 : if( getPntrToArgument(n)->getRank()!=1 || getPntrToArgument(n)->hasDerivatives() ) {
160 0 : error("last argument to this action should be a vector");
161 : }
162 : }
163 343 : getPntrToArgument(n)->buildDataStore();
164 :
165 343 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
166 343 : if( av ) {
167 314 : done_in_chain=canBeAfterInChain( av );
168 : }
169 :
170 343 : if( getNumberOfArguments()==2 ) {
171 301 : addValue( shape );
172 301 : setNotPeriodic();
173 : } else {
174 718 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
175 676 : std::string name = getPntrToArgument(i)->getName();
176 676 : if( name.find_first_of(".")!=std::string::npos ) {
177 676 : std::size_t dot=name.find_first_of(".");
178 1352 : name = name.substr(dot+1);
179 : }
180 676 : addComponent( name, shape );
181 676 : componentIsNotPeriodic( name );
182 : }
183 : }
184 343 : if( (getPntrToArgument(n)->getPntrToAction())->getName()=="CONSTANT" ) {
185 306 : sumrows=true;
186 306 : if( getPntrToArgument(n)->getRank()==0 ) {
187 1 : if( fabs( getPntrToArgument(n)->get() - 1.0 )>epsilon ) {
188 0 : sumrows = false;
189 : }
190 : } else {
191 180438 : for(unsigned i=0; i<getPntrToArgument(n)->getShape()[0]; ++i) {
192 180141 : if( fabs( getPntrToArgument(n)->get(i) - 1.0 )>epsilon ) {
193 8 : sumrows=false;
194 8 : break;
195 : }
196 : }
197 : }
198 : }
199 9 : } else if( nmatrices==1 ) {
200 9 : if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) {
201 0 : error("first argument to this action should be a matrix");
202 : }
203 203 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
204 194 : if( getPntrToArgument(i)->getRank()>1 || getPntrToArgument(i)->hasDerivatives() ) {
205 0 : error("all arguments other than first argument should be vectors");
206 : }
207 194 : if( getPntrToArgument(i)->getRank()==0 ) {
208 0 : if( getPntrToArgument(0)->getShape()[1]!=1 ) {
209 0 : error("number of columns in input matrix does not equal number of elements in vector");
210 : }
211 194 : } else if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(i)->getShape()[0] ) {
212 0 : error("number of columns in input matrix does not equal number of elements in vector");
213 : }
214 194 : getPntrToArgument(i)->buildDataStore();
215 : }
216 :
217 9 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
218 9 : if( av ) {
219 9 : done_in_chain=canBeAfterInChain( av );
220 : }
221 :
222 203 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
223 194 : std::string name = getPntrToArgument(i)->getName();
224 194 : if( name.find_first_of(".")!=std::string::npos ) {
225 0 : std::size_t dot=name.find_first_of(".");
226 0 : name = name.substr(dot+1);
227 : }
228 194 : addComponent( name, shape );
229 194 : componentIsNotPeriodic( name );
230 : }
231 : } else {
232 0 : error("You should either have one vector or one matrix in input");
233 : }
234 :
235 352 : nderivatives = buildArgumentStore(0);
236 352 : std::string headstr=getFirstActionInChain()->getLabel();
237 352 : stored_arg.resize( getNumberOfArguments() );
238 1875 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
239 1523 : stored_arg[i] = getPntrToArgument(i)->ignoreStoredValue( headstr );
240 : }
241 352 : }
242 :
243 31643 : unsigned MatrixTimesVector::getNumberOfDerivatives() {
244 31643 : return nderivatives;
245 : }
246 :
247 13575 : void MatrixTimesVector::prepare() {
248 13575 : ActionWithVector::prepare();
249 13575 : Value* myval = getPntrToComponent(0);
250 13575 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] ) {
251 13565 : return;
252 : }
253 10 : std::vector<unsigned> shape(1);
254 10 : shape[0] = getPntrToArgument(0)->getShape()[0];
255 10 : myval->setShape(shape);
256 : }
257 :
258 6574 : void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
259 6574 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
260 6574 : if( indices.size()!=size_v+1 ) {
261 3508 : indices.resize( size_v + 1 );
262 : }
263 842762 : for(unsigned i=0; i<size_v; ++i) {
264 836188 : indices[i+1] = start_n + i;
265 : }
266 : myvals.setSplitIndex( size_v + 1 );
267 6574 : }
268 :
269 23970940 : void MatrixTimesVector::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
270 23970940 : unsigned ind2 = index2;
271 23970940 : if( index2>=getPntrToArgument(0)->getShape()[0] ) {
272 1600742 : ind2 = index2 - getPntrToArgument(0)->getShape()[0];
273 : }
274 23970940 : if( sumrows ) {
275 22303792 : unsigned n=getNumberOfArguments()-1;
276 : double matval = 0;
277 87441027 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
278 65137235 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
279 : Value* myarg = getPntrToArgument(i);
280 65137235 : if( !myarg->valueHasBeenSet() ) {
281 65122517 : myvals.addValue( ostrn, myvals.get( myarg->getPositionInStream() ) );
282 : } else {
283 14718 : myvals.addValue( ostrn, myarg->get( index1*myarg->getNumberOfColumns() + ind2, false ) );
284 : }
285 : // Now lets work out the derivatives
286 65137235 : if( doNotCalculateDerivatives() ) {
287 32313889 : continue;
288 : }
289 32823346 : addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, 1.0, myvals );
290 : }
291 1667148 : } else if( getPntrToArgument(1)->getRank()==1 ) {
292 : double matval = 0;
293 : Value* myarg = getPntrToArgument(0);
294 1667148 : unsigned vcol = ind2;
295 1667148 : if( !myarg->valueHasBeenSet() ) {
296 840110 : matval = myvals.get( myarg->getPositionInStream() );
297 : } else {
298 827038 : matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
299 827038 : vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 );
300 : }
301 18356786 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
302 16689638 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
303 16689638 : double vecval=getArgumentElement( i+1, vcol, myvals );
304 : // And add this part of the product
305 16689638 : myvals.addValue( ostrn, matval*vecval );
306 : // Now lets work out the derivatives
307 16689638 : if( doNotCalculateDerivatives() ) {
308 1000870 : continue;
309 : }
310 15688768 : addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals );
311 15688768 : addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals );
312 : }
313 : } else {
314 0 : unsigned n=getNumberOfArguments()-1;
315 0 : double matval = 0;
316 0 : unsigned vcol = ind2;
317 0 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
318 0 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
319 : Value* myarg = getPntrToArgument(i);
320 0 : if( !myarg->valueHasBeenSet() ) {
321 0 : matval = myvals.get( myarg->getPositionInStream() );
322 : } else {
323 0 : matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
324 0 : vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 );
325 : }
326 0 : double vecval=getArgumentElement( n, vcol, myvals );
327 : // And add this part of the product
328 0 : myvals.addValue( ostrn, matval*vecval );
329 : // Now lets work out the derivatives
330 0 : if( doNotCalculateDerivatives() ) {
331 0 : continue;
332 : }
333 0 : addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals );
334 0 : addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals );
335 : }
336 : }
337 23970940 : }
338 :
339 472445 : void MatrixTimesVector::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
340 472445 : if( doNotCalculateDerivatives() || !actionInChain() ) {
341 : return ;
342 : }
343 :
344 358714 : if( getPntrToArgument(1)->getRank()==1 ) {
345 : unsigned istrn = getPntrToArgument(0)->getPositionInMatrixStash();
346 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
347 1010565 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
348 671975 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
349 40971258 : for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) {
350 40299283 : myvals.updateIndex( ostrn, mat_indices[i] );
351 : }
352 : }
353 : } else {
354 530036 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
355 : unsigned istrn = getPntrToArgument(j)->getPositionInMatrixStash();
356 509912 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
357 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
358 17456348 : for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) {
359 16946436 : myvals.updateIndex( ostrn, mat_indices[i] );
360 : }
361 : }
362 : }
363 : }
364 :
365 372677 : void MatrixTimesVector::updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const {
366 372677 : unsigned n = getNumberOfArguments()-1;
367 372677 : if( getPntrToArgument(1)->getRank()==1 ) {
368 : n = 1;
369 : }
370 372677 : unsigned nvals = getPntrToArgument(n)->getNumberOfValues();
371 1387754027 : for(unsigned i=0; i<nvals; ++i) {
372 1387381350 : myvals.updateIndex( ostrn, arg_deriv_starts[n] + i );
373 : }
374 372677 : }
375 :
376 : }
377 : }
|