Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2020 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_function_FunctionOfMatrix_h
23 : #define __PLUMED_function_FunctionOfMatrix_h
24 :
25 : #include "core/ActionWithMatrix.h"
26 : #include "FunctionOfVector.h"
27 : #include "Sum.h"
28 : #include "tools/Matrix.h"
29 :
30 : namespace PLMD {
31 : namespace function {
32 :
33 : template <class T>
34 : class FunctionOfMatrix : public ActionWithMatrix {
35 : private:
36 : /// Is this the first step of the calculation
37 : bool firststep;
38 : /// The function that is being computed
39 : T myfunc;
40 : /// The number of derivatives for this action
41 : unsigned nderivatives;
42 : /// A vector that tells us if we have stored the input value
43 : std::vector<bool> stored_arguments;
44 : /// Switch off updating the arguments for this action
45 : std::vector<bool> update_arguments;
46 : /// The list of actiosn in this chain
47 : std::vector<std::string> actionsLabelsInChain;
48 : /// Get the shape of the output matrix
49 : std::vector<unsigned> getValueShapeFromArguments();
50 : public:
51 : static void registerKeywords(Keywords&);
52 : explicit FunctionOfMatrix(const ActionOptions&);
53 : /// Get the label to write in the graph
54 0 : std::string writeInGraph() const override {
55 0 : return myfunc.getGraphInfo( getName() );
56 : }
57 : /// Make sure the derivatives are turned on
58 : void turnOnDerivatives() override;
59 : /// Get the number of derivatives for this action
60 : unsigned getNumberOfDerivatives() override ;
61 : /// Resize the matrices
62 : void prepare() override ;
63 : /// This gets the number of columns
64 : unsigned getNumberOfColumns() const override ;
65 : /// This checks for tasks in the parent class
66 : // void buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) override ;
67 : /// This ensures that we create some bookeeping stuff during the first step
68 : void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override ;
69 : /// This sets up for the task
70 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
71 : /// Calculate the full matrix
72 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override ;
73 : /// This updates the indices for the matrix
74 : // void updateCentralMatrixIndex( const unsigned& ind, const std::vector<unsigned>& indices, MultiValue& myvals ) const override ;
75 : void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
76 : };
77 :
78 : template <class T>
79 1013 : void FunctionOfMatrix<T>::registerKeywords(Keywords& keys ) {
80 1013 : ActionWithMatrix::registerKeywords(keys);
81 1013 : std::string name = keys.getDisplayName();
82 1013 : std::size_t und=name.find("_MATRIX");
83 1013 : keys.setDisplayName( name.substr(0,und) );
84 2026 : keys.addInputKeyword("compulsory","ARG","scalar/matrix","the labels of the scalar and matrices that on which the function is being calculated elementwise");
85 1013 : keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
86 1013 : keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO");
87 745 : T tfunc;
88 1013 : tfunc.registerKeywords( keys );
89 2026 : if( keys.getDisplayName()=="SUM" ) {
90 168 : keys.setValueDescription("scalar","the sum of all the elements in the input matrix");
91 1858 : } else if( keys.getDisplayName()=="HIGHEST" ) {
92 0 : keys.setValueDescription("scalar","the largest element of the input matrix");
93 1858 : } else if( keys.getDisplayName()=="LOWEST" ) {
94 0 : keys.setValueDescription("scalar","the smallest element in the input matrix");
95 1858 : } else if( keys.outputComponentExists(".#!value") ) {
96 1672 : keys.setValueDescription("matrix","the matrix obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input matrix");
97 : }
98 1899 : }
99 :
100 : template <class T>
101 495 : FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao):
102 : Action(ao),
103 : ActionWithMatrix(ao),
104 495 : firststep(true) {
105 451 : if( myfunc.getArgStart()>0 ) {
106 : error("this has not beeen implemented -- if you are interested email gareth.tribello@gmail.com");
107 : }
108 : // Get the shape of the output
109 495 : std::vector<unsigned> shape( getValueShapeFromArguments() );
110 : // Check if the output matrix is symmetric
111 495 : bool symmetric=true;
112 : unsigned argstart=myfunc.getArgStart();
113 1508 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
114 1013 : if( getPntrToArgument(i)->getRank()==2 ) {
115 948 : if( !getPntrToArgument(i)->isSymmetric() ) {
116 831 : symmetric=false;
117 : }
118 : }
119 : }
120 : // Read the input and do some checks
121 495 : myfunc.read( this );
122 : // Setup to do this in chain if possible
123 : if( myfunc.doWithTasks() ) {
124 495 : done_in_chain=true;
125 : }
126 : // Check we are not calculating a sum
127 41 : if( myfunc.zeroRank() ) {
128 41 : shape.resize(0);
129 : }
130 : // Get the names of the components
131 495 : std::vector<std::string> components( keywords.getOutputComponents() );
132 : // Create the values to hold the output
133 42 : std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
134 1034 : for(unsigned i=0; i<components.size(); ++i) {
135 84 : if( str_ind.size()>0 ) {
136 84 : std::string compstr = components[i];
137 84 : if( components[i]==".#!value" ) {
138 : compstr = "";
139 : }
140 760 : for(unsigned j=0; j<str_ind.size(); ++j) {
141 : if( myfunc.zeroRank() ) {
142 : addComponentWithDerivatives( compstr + str_ind[j], shape );
143 : } else {
144 1352 : addComponent( compstr + str_ind[j], shape );
145 676 : getPntrToComponent(i*str_ind.size()+j)->setSymmetric( symmetric );
146 : }
147 : }
148 41 : } else if( components[i]==".#!value" && myfunc.zeroRank() ) {
149 41 : addValueWithDerivatives( shape );
150 414 : } else if( components[i]==".#!value" ) {
151 410 : addValue( shape );
152 410 : getPntrToComponent(0)->setSymmetric( symmetric );
153 4 : } else if( components[i].find_first_of("_")!=std::string::npos ) {
154 0 : if( getNumberOfArguments()-argstart==1 ) {
155 0 : addValue( shape );
156 0 : getPntrToComponent(0)->setSymmetric( symmetric );
157 : } else {
158 0 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
159 0 : addComponent( getPntrToArgument(j)->getName() + components[i], shape );
160 0 : getPntrToComponent(i*(getNumberOfArguments()-argstart)+j-argstart)->setSymmetric( symmetric );
161 : }
162 : }
163 : } else {
164 4 : addComponent( components[i], shape );
165 4 : getPntrToComponent(i)->setSymmetric( symmetric );
166 : }
167 : }
168 : // Check if this can be sped up
169 370 : if( myfunc.getDerivativeZeroIfValueIsZero() ) {
170 174 : for(int i=0; i<getNumberOfComponents(); ++i) {
171 87 : getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
172 : }
173 : }
174 : // Set the periodicities of the output components
175 495 : myfunc.setPeriodicityForOutputs( this );
176 : // We can't do this with if we are dividing a stack by some a product v.v^T product as we need to store the vector
177 : // In order to do this type of calculation. There should be a neater fix than this but I can't see it.
178 : bool foundneigh=false;
179 : const ActionWithMatrix* chainstart = NULL;
180 1503 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
181 1011 : if( getPntrToArgument(i)->isConstant() && getNumberOfArguments()>1 ) {
182 275 : continue ;
183 : }
184 932 : std::string argname=(getPntrToArgument(i)->getPntrToAction())->getName();
185 932 : if( argname=="NEIGHBORS" ) {
186 : foundneigh=true;
187 : break;
188 : }
189 929 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
190 929 : if( !av ) {
191 29 : done_in_chain=false;
192 : }
193 929 : if( getPntrToArgument(i)->getRank()==0 ) {
194 0 : function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() );
195 0 : if(as) {
196 0 : done_in_chain=false;
197 : }
198 929 : } else if( getPntrToArgument(i)->ignoreStoredValue( getLabel() ) ) {
199 : // This option deals with the case when you have two adjacency matrices, A_ij and B_ij, multiplied together. This cannot be done in the chain as the rows
200 : // of the two adjacency matrix are run over separately. The value A_ij is thus not available when B_ij is calculated.
201 853 : ActionWithMatrix* am = dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
202 853 : plumed_assert( am );
203 853 : const ActionWithMatrix* thischain = am->getFirstMatrixInChain();
204 853 : if( !thischain->isAdjacencyMatrix() && thischain->getName()!="VSTACK" ) {
205 : continue;
206 : }
207 657 : if( !chainstart ) {
208 : chainstart = thischain;
209 317 : } else if( thischain!=chainstart ) {
210 1 : done_in_chain=false;
211 : }
212 : }
213 : }
214 : // If we are working with neighbors we trick PLUMED into storing ALL the components of the other arguments
215 : // in this way we can ensure that the function of the neighbours matrix is in a chain starting from the
216 : // Neighbours matrix action.
217 : if( foundneigh ) {
218 9 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
219 6 : ActionWithValue* av=getPntrToArgument(i)->getPntrToAction();
220 6 : if( av->getName()!="NEIGHBORS" ) {
221 8 : for(int i=0; i<av->getNumberOfComponents(); ++i) {
222 5 : (av->copyOutput(i))->buildDataStore();
223 : }
224 : }
225 : }
226 : }
227 : bool allconstant=true;
228 502 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
229 501 : if( !getPntrToArgument(i)->isConstant() ) {
230 : allconstant=false;
231 : break;
232 : }
233 : }
234 495 : if( allconstant ) {
235 1 : done_in_chain=false;
236 : }
237 : // Now setup the action in the chain if we can
238 495 : nderivatives = buildArgumentStore(myfunc.getArgStart());
239 990 : }
240 :
241 : template <class T>
242 1921 : void FunctionOfMatrix<T>::turnOnDerivatives() {
243 1921 : if( !myfunc.derivativesImplemented() ) {
244 : error("derivatives have not been implemended for " + getName() );
245 : }
246 1921 : ActionWithValue::turnOnDerivatives();
247 1921 : myfunc.setup(this);
248 1921 : }
249 :
250 : template <class T>
251 30411 : unsigned FunctionOfMatrix<T>::getNumberOfDerivatives() {
252 30411 : return nderivatives;
253 : }
254 :
255 : template <class T>
256 2229 : void FunctionOfMatrix<T>::prepare() {
257 : unsigned argstart = myfunc.getArgStart();
258 2229 : std::vector<unsigned> shape(2);
259 2229 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
260 2229 : if( getPntrToArgument(i)->getRank()==2 ) {
261 2229 : shape[0] = getPntrToArgument(i)->getShape()[0];
262 2229 : shape[1] = getPntrToArgument(i)->getShape()[1];
263 2229 : break;
264 : }
265 : }
266 6682 : for(unsigned i=0; i<getNumberOfComponents(); ++i) {
267 4453 : Value* myval = getPntrToComponent(i);
268 4453 : if( myval->getRank()==2 && (myval->getShape()[0]!=shape[0] || myval->getShape()[1]!=shape[1]) ) {
269 18 : myval->setShape(shape);
270 18 : if( myval->valueIsStored() ) {
271 18 : myval->reshapeMatrixStore( shape[1] );
272 : }
273 : }
274 : }
275 2229 : ActionWithVector::prepare();
276 2229 : }
277 :
278 : template <class T>
279 281844 : unsigned FunctionOfMatrix<T>::getNumberOfColumns() const {
280 281844 : if( getConstPntrToComponent(0)->getRank()==2 ) {
281 : unsigned argstart=myfunc.getArgStart();
282 281844 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
283 281844 : if( getPntrToArgument(i)->getRank()==2 ) {
284 281844 : ActionWithMatrix* am=dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
285 281844 : if( am ) {
286 279606 : return am->getNumberOfColumns();
287 : }
288 2238 : return getPntrToArgument(i)->getShape()[1];
289 : }
290 : }
291 : }
292 0 : plumed_error();
293 : return 0;
294 : }
295 :
296 : template <class T>
297 4209 : void FunctionOfMatrix<T>::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
298 11667 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
299 7458 : plumed_assert( getPntrToArgument(i)->getRank()==2 );
300 : }
301 4209 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getShape()[1];
302 4209 : if( indices.size()!=size_v+1 ) {
303 421 : indices.resize( size_v+1 );
304 : }
305 642613 : for(unsigned i=0; i<size_v; ++i) {
306 638404 : indices[i+1] = start_n + i;
307 : }
308 : myvals.setSplitIndex( size_v + 1 );
309 4209 : }
310 :
311 : // template <class T>
312 : // void FunctionOfMatrix<T>::buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) {
313 : // // Check if this is the first element in a chain
314 : // if( actionInChain() ) return;
315 : // // If it is computed outside a chain get the tassks the daughter chain needs
316 : // propegateTaskListsForValue( 0, ntasks, reduce, otasks );
317 : // }
318 :
319 : template <class T>
320 2525 : void FunctionOfMatrix<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) {
321 2525 : if( firststep ) {
322 489 : stored_arguments.resize( getNumberOfArguments() );
323 489 : update_arguments.resize( getNumberOfArguments(), true );
324 489 : std::string control = getFirstActionInChain()->getLabel();
325 1484 : for(unsigned i=0; i<stored_arguments.size(); ++i) {
326 995 : stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
327 995 : if( !stored_arguments[i] ) {
328 : update_arguments[i] = true;
329 : } else {
330 164 : update_arguments[i] = !argumentDependsOn( headstr, this, getPntrToArgument(i) );
331 : }
332 : }
333 489 : firststep=false;
334 : }
335 2525 : ActionWithMatrix::setupStreamedComponents( headstr, nquants, nmat, maxcol, nbookeeping );
336 2525 : }
337 :
338 : template <class T>
339 27928651 : void FunctionOfMatrix<T>::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
340 : unsigned argstart=myfunc.getArgStart();
341 27928651 : std::vector<double> args( getNumberOfArguments() - argstart );
342 27928651 : unsigned ind2 = index2;
343 27928651 : if( getConstPntrToComponent(0)->getRank()==2 && index2>=getConstPntrToComponent(0)->getShape()[0] ) {
344 3636383 : ind2 = index2 - getConstPntrToComponent(0)->getShape()[0];
345 24292268 : } else if( index2>=getPntrToArgument(0)->getShape()[0] ) {
346 447979 : ind2 = index2 - getPntrToArgument(0)->getShape()[0];
347 : }
348 27928651 : if( actionInChain() ) {
349 85619946 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
350 58329699 : if( getPntrToArgument(i)->getRank()==0 ) {
351 135720 : args[i-argstart] = getPntrToArgument(i)->get();
352 58193979 : } else if( !getPntrToArgument(i)->valueHasBeenSet() ) {
353 57005386 : args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
354 : } else {
355 1188593 : args[i-argstart] = getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
356 : }
357 : }
358 : } else {
359 1727072 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
360 1088668 : if( getPntrToArgument(i)->getRank()==2 ) {
361 1088668 : args[i-argstart]=getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
362 : } else {
363 0 : args[i-argstart] = getPntrToArgument(i)->get();
364 : }
365 : }
366 : }
367 : // Calculate the function and its derivatives
368 27928651 : std::vector<double> vals( getNumberOfComponents() );
369 27928651 : Matrix<double> derivatives( getNumberOfComponents(), getNumberOfArguments()-argstart );
370 27928651 : myfunc.calc( this, args, vals, derivatives );
371 : // And set the values
372 99634355 : for(unsigned i=0; i<vals.size(); ++i) {
373 71705704 : myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
374 : }
375 : // Return if we are not computing derivatives
376 27928651 : if( doNotCalculateDerivatives() ) {
377 : return;
378 : }
379 :
380 5399619 : if( actionInChain() ) {
381 33385311 : for(int i=0; i<getNumberOfComponents(); ++i) {
382 27990552 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
383 131996523 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
384 104005971 : if( getPntrToArgument(j)->getRank()==2 ) {
385 103890411 : unsigned istrn = getPntrToArgument(j)->getPositionInStream();
386 103890411 : if( stored_arguments[j] ) {
387 395048 : unsigned task_index = getPntrToArgument(i)->getShape()[1]*index1 + ind2;
388 395048 : myvals.clearDerivatives(istrn);
389 395048 : myvals.addDerivative( istrn, task_index, 1.0 );
390 395048 : myvals.updateIndex( istrn, task_index );
391 : }
392 470717695 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
393 366827284 : unsigned kind=myvals.getActiveIndex(istrn,k);
394 366827284 : myvals.addDerivative( ostrn, arg_deriv_starts[j] + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
395 : }
396 : }
397 : }
398 : }
399 : // If we are computing a matrix we need to update the indices here so that derivatives are calcualted correctly in functions of these
400 5394759 : if( getConstPntrToComponent(0)->getRank()==2 ) {
401 32784527 : for(int i=0; i<getNumberOfComponents(); ++i) {
402 27690160 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
403 131395739 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
404 103705579 : if( !update_arguments[j] || getPntrToArgument(j)->getRank()==0 ) {
405 115584 : continue ;
406 : }
407 : // Ensure we only store one lot of derivative indices
408 : bool found=false;
409 105009550 : for(unsigned k=0; k<j; ++k) {
410 76601003 : if( arg_deriv_starts[k]==arg_deriv_starts[j] ) {
411 : found=true;
412 : break;
413 : }
414 : }
415 103589995 : if( found ) {
416 75181448 : continue;
417 : }
418 : unsigned istrn = getPntrToArgument(j)->getPositionInStream();
419 138447375 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
420 110038828 : unsigned kind=myvals.getActiveIndex(istrn,k);
421 110038828 : myvals.updateIndex( ostrn, arg_deriv_starts[j] + kind );
422 : }
423 : }
424 : }
425 : }
426 : } else {
427 : unsigned base=0;
428 4860 : ind2 = index2;
429 4860 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
430 4860 : if( getPntrToArgument(j)->getRank()!=2 ) {
431 : continue ;
432 : }
433 4860 : if( index2>=getPntrToArgument(j)->getShape()[0] ) {
434 4860 : ind2 = index2 - getPntrToArgument(j)->getShape()[0];
435 : }
436 : break;
437 : }
438 13965 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
439 9105 : if( getPntrToArgument(j)->getRank()==2 ) {
440 18210 : for(int i=0; i<getNumberOfComponents(); ++i) {
441 9105 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
442 9105 : unsigned myind = base + getPntrToArgument(j)->getShape()[1]*index1 + ind2;
443 9105 : myvals.addDerivative( ostrn, myind, derivatives(i,j) );
444 9105 : myvals.updateIndex( ostrn, myind );
445 : }
446 : } else {
447 0 : for(int i=0; i<getNumberOfComponents(); ++i) {
448 0 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
449 0 : myvals.addDerivative( ostrn, base, derivatives(i,j) );
450 0 : myvals.updateIndex( ostrn, base );
451 : }
452 : }
453 9105 : base += getPntrToArgument(j)->getNumberOfValues();
454 : }
455 : }
456 : }
457 :
458 : template <class T>
459 226389 : void FunctionOfMatrix<T>::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
460 226389 : if( doNotCalculateDerivatives() ) {
461 : return;
462 : }
463 :
464 : unsigned argstart=myfunc.getArgStart();
465 71237 : if( actionInChain() && getConstPntrToComponent(0)->getRank()==2 ) {
466 : // This is triggered if we are outputting a matrix
467 624578 : for(int vv=0; vv<getNumberOfComponents(); ++vv) {
468 558183 : unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
469 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
470 : unsigned ntot_mat=0;
471 558183 : if( mat_indices.size()<nderivatives ) {
472 0 : mat_indices.resize( nderivatives );
473 : }
474 2701544 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
475 2143361 : if( !update_arguments[i] || getPntrToArgument(i)->getRank()==0 ) {
476 1084 : continue ;
477 : }
478 : // Ensure we only store one lot of derivative indices
479 : bool found=false;
480 2168017 : for(unsigned j=0; j<i; ++j) {
481 1591516 : if( arg_deriv_starts[j]==arg_deriv_starts[i] ) {
482 : found=true;
483 : break;
484 : }
485 : }
486 2142277 : if( found ) {
487 1565776 : continue;
488 : }
489 :
490 576501 : if( stored_arguments[i] ) {
491 15483 : unsigned tbase = getPntrToArgument(i)->getShape()[1]*ind;
492 410507 : for(unsigned k=1; k<indices.size(); ++k) {
493 395024 : unsigned ind2 = indices[k] - getConstPntrToComponent(0)->getShape()[0];
494 395024 : mat_indices[ntot_mat + k - 1] = arg_deriv_starts[i] + tbase + ind2;
495 : }
496 15483 : ntot_mat += indices.size()-1;
497 : } else {
498 : unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
499 : std::vector<unsigned>& imat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
500 31848426 : for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) {
501 31287408 : mat_indices[ntot_mat + k] = arg_deriv_starts[i] + imat_indices[k];
502 : }
503 561018 : ntot_mat += myvals.getNumberOfMatrixRowDerivatives( istrn );
504 : }
505 : }
506 : myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
507 : }
508 4842 : } else if( actionInChain() ) {
509 : // This is triggered if we are calculating a single scalar in the function
510 8822 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
511 : bool found=false;
512 4411 : for(unsigned j=0; j<i; ++j) {
513 0 : if( arg_deriv_starts[j]==arg_deriv_starts[i] ) {
514 : found=true;
515 : break;
516 : }
517 : }
518 4411 : if( found ) {
519 : continue;
520 : }
521 : unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
522 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
523 926766 : for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) {
524 1844710 : for(int j=0; j<getNumberOfComponents(); ++j) {
525 922355 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
526 922355 : myvals.updateIndex( ostrn, arg_deriv_starts[i] + mat_indices[k] );
527 : }
528 : }
529 : }
530 431 : } else if( getConstPntrToComponent(0)->getRank()==2 ) {
531 760 : for(int vv=0; vv<getNumberOfComponents(); ++vv) {
532 380 : unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
533 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
534 : unsigned ntot_mat=0;
535 380 : if( mat_indices.size()<nderivatives ) {
536 0 : mat_indices.resize( nderivatives );
537 : }
538 : unsigned matderbase = 0;
539 986 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
540 606 : if( getPntrToArgument(i)->getRank()==0 ) {
541 0 : continue ;
542 : }
543 606 : unsigned ss = getPntrToArgument(i)->getShape()[1];
544 606 : unsigned tbase = matderbase + ss*myvals.getTaskIndex();
545 9558 : for(unsigned k=0; k<ss; ++k) {
546 8952 : mat_indices[ntot_mat + k] = tbase + k;
547 : }
548 606 : ntot_mat += ss;
549 606 : matderbase += getPntrToArgument(i)->getNumberOfValues();
550 : }
551 : myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
552 : }
553 : }
554 : }
555 :
556 : template <class T>
557 495 : std::vector<unsigned> FunctionOfMatrix<T>::getValueShapeFromArguments() {
558 : unsigned argstart=myfunc.getArgStart();
559 495 : std::vector<unsigned> shape(2);
560 495 : shape[0]=shape[1]=0;
561 1508 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
562 1013 : plumed_assert( getPntrToArgument(i)->getRank()==2 || getPntrToArgument(i)->getRank()==0 );
563 1013 : if( getPntrToArgument(i)->getRank()==2 ) {
564 948 : if( shape[0]>0 && (getPntrToArgument(i)->getShape()[0]!=shape[0] || getPntrToArgument(i)->getShape()[1]!=shape[1]) ) {
565 0 : error("all matrices input should have the same shape");
566 948 : } else if( shape[0]==0 ) {
567 509 : shape[0]=getPntrToArgument(i)->getShape()[0];
568 509 : shape[1]=getPntrToArgument(i)->getShape()[1];
569 : }
570 948 : plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
571 : }
572 : }
573 41 : myfunc.setPrefactor( this, 1.0 );
574 495 : return shape;
575 : }
576 :
577 : }
578 : }
579 : #endif
|