Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2020 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_function_FunctionOfMatrix_h
23 : #define __PLUMED_function_FunctionOfMatrix_h
24 :
25 : #include "core/ActionWithMatrix.h"
26 : #include "FunctionOfVector.h"
27 : #include "Sum.h"
28 : #include "tools/Matrix.h"
29 :
30 : namespace PLMD {
31 : namespace function {
32 :
33 : template <class T>
34 : class FunctionOfMatrix : public ActionWithMatrix {
35 : private:
36 : /// Is this the first step of the calculation
37 : bool firststep;
38 : /// The function that is being computed
39 : T myfunc;
40 : /// The number of derivatives for this action
41 : unsigned nderivatives;
42 : /// A vector that tells us if we have stored the input value
43 : std::vector<bool> stored_arguments;
44 : /// Switch off updating the arguments for this action
45 : std::vector<bool> update_arguments;
46 : /// The list of actiosn in this chain
47 : std::vector<std::string> actionsLabelsInChain;
48 : /// Get the shape of the output matrix
49 : std::vector<unsigned> getValueShapeFromArguments();
50 : public:
51 : static void registerKeywords(Keywords&);
52 : explicit FunctionOfMatrix(const ActionOptions&);
53 : /// Get the label to write in the graph
54 0 : std::string writeInGraph() const override { return myfunc.getGraphInfo( getName() ); }
55 : /// Make sure the derivatives are turned on
56 : void turnOnDerivatives() override;
57 : /// Get the number of derivatives for this action
58 : unsigned getNumberOfDerivatives() override ;
59 : /// Resize the matrices
60 : void prepare() override ;
61 : /// This gets the number of columns
62 : unsigned getNumberOfColumns() const override ;
63 : /// This checks for tasks in the parent class
64 : // void buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) override ;
65 : /// This ensures that we create some bookeeping stuff during the first step
66 : void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override ;
67 : /// This sets up for the task
68 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
69 : /// Calculate the full matrix
70 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override ;
71 : /// This updates the indices for the matrix
72 : // void updateCentralMatrixIndex( const unsigned& ind, const std::vector<unsigned>& indices, MultiValue& myvals ) const override ;
73 : void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
74 : };
75 :
76 : template <class T>
77 1013 : void FunctionOfMatrix<T>::registerKeywords(Keywords& keys ) {
78 1013 : ActionWithMatrix::registerKeywords(keys); std::string name = keys.getDisplayName();
79 1013 : std::size_t und=name.find("_MATRIX"); keys.setDisplayName( name.substr(0,und) );
80 2026 : keys.addInputKeyword("compulsory","ARG","scalar/matrix","the labels of the scalar and matrices that on which the function is being calculated elementwise");
81 2026 : keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
82 2026 : keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO");
83 1013 : T tfunc; tfunc.registerKeywords( keys );
84 2026 : if( keys.getDisplayName()=="SUM" ) {
85 168 : keys.setValueDescription("scalar","the sum of all the elements in the input matrix");
86 1858 : } else if( keys.getDisplayName()=="HIGHEST" ) {
87 0 : keys.setValueDescription("scalar","the largest element of the input matrix");
88 1858 : } else if( keys.getDisplayName()=="LOWEST" ) {
89 0 : keys.setValueDescription("scalar","the smallest element in the input matrix");
90 1858 : } else if( keys.outputComponentExists(".#!value") ) {
91 1672 : keys.setValueDescription("matrix","the matrix obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input matrix");
92 : }
93 1899 : }
94 :
95 : template <class T>
96 495 : FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao):
97 : Action(ao),
98 : ActionWithMatrix(ao),
99 495 : firststep(true)
100 : {
101 451 : if( myfunc.getArgStart()>0 ) error("this has not beeen implemented -- if you are interested email gareth.tribello@gmail.com");
102 : // Get the shape of the output
103 495 : std::vector<unsigned> shape( getValueShapeFromArguments() );
104 : // Check if the output matrix is symmetric
105 495 : bool symmetric=true; unsigned argstart=myfunc.getArgStart();
106 1508 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
107 1013 : if( getPntrToArgument(i)->getRank()==2 ) {
108 948 : if( !getPntrToArgument(i)->isSymmetric() ) { symmetric=false; }
109 : }
110 : }
111 : // Read the input and do some checks
112 495 : myfunc.read( this );
113 : // Setup to do this in chain if possible
114 495 : if( myfunc.doWithTasks() ) done_in_chain=true;
115 : // Check we are not calculating a sum
116 41 : if( myfunc.zeroRank() ) shape.resize(0);
117 : // Get the names of the components
118 495 : std::vector<std::string> components( keywords.getOutputComponents() );
119 : // Create the values to hold the output
120 42 : std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
121 1034 : for(unsigned i=0; i<components.size(); ++i) {
122 84 : if( str_ind.size()>0 ) {
123 168 : std::string compstr = components[i]; if( components[i]==".#!value" ) compstr = "";
124 760 : for(unsigned j=0; j<str_ind.size(); ++j) {
125 : if( myfunc.zeroRank() ) {
126 : addComponentWithDerivatives( compstr + str_ind[j], shape );
127 : } else {
128 1352 : addComponent( compstr + str_ind[j], shape );
129 676 : getPntrToComponent(i*str_ind.size()+j)->setSymmetric( symmetric );
130 : }
131 : }
132 41 : } else if( components[i]==".#!value" && myfunc.zeroRank() ) {
133 41 : addValueWithDerivatives( shape );
134 414 : } else if( components[i]==".#!value" ) {
135 410 : addValue( shape ); getPntrToComponent(0)->setSymmetric( symmetric );
136 4 : } else if( components[i].find_first_of("_")!=std::string::npos ) {
137 0 : if( getNumberOfArguments()-argstart==1 ) { addValue( shape ); getPntrToComponent(0)->setSymmetric( symmetric ); }
138 : else {
139 0 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
140 0 : addComponent( getPntrToArgument(j)->getName() + components[i], shape );
141 0 : getPntrToComponent(i*(getNumberOfArguments()-argstart)+j-argstart)->setSymmetric( symmetric );
142 : }
143 : }
144 4 : } else { addComponent( components[i], shape ); getPntrToComponent(i)->setSymmetric( symmetric ); }
145 : }
146 : // Check if this can be sped up
147 370 : if( myfunc.getDerivativeZeroIfValueIsZero() ) {
148 174 : for(int i=0; i<getNumberOfComponents(); ++i) getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
149 : }
150 : // Set the periodicities of the output components
151 495 : myfunc.setPeriodicityForOutputs( this );
152 : // We can't do this with if we are dividing a stack by some a product v.v^T product as we need to store the vector
153 : // In order to do this type of calculation. There should be a neater fix than this but I can't see it.
154 : bool foundneigh=false; const ActionWithMatrix* chainstart = NULL;
155 1503 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
156 1207 : if( getPntrToArgument(i)->isConstant() && getNumberOfArguments()>1 ) continue ;
157 932 : std::string argname=(getPntrToArgument(i)->getPntrToAction())->getName();
158 932 : if( argname=="NEIGHBORS" ) { foundneigh=true; break; }
159 929 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
160 929 : if( !av ) done_in_chain=false;
161 929 : if( getPntrToArgument(i)->getRank()==0 ) {
162 0 : function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() );
163 0 : if(as) done_in_chain=false;
164 929 : } else if( getPntrToArgument(i)->ignoreStoredValue( getLabel() ) ) {
165 : // This option deals with the case when you have two adjacency matrices, A_ij and B_ij, multiplied together. This cannot be done in the chain as the rows
166 : // of the two adjacency matrix are run over separately. The value A_ij is thus not available when B_ij is calculated.
167 853 : ActionWithMatrix* am = dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
168 853 : plumed_assert( am ); const ActionWithMatrix* thischain = am->getFirstMatrixInChain();
169 853 : if( !thischain->isAdjacencyMatrix() && thischain->getName()!="VSTACK" ) continue;
170 657 : if( !chainstart ) chainstart = thischain;
171 317 : else if( thischain!=chainstart ) done_in_chain=false;
172 : }
173 : }
174 : // If we are working with neighbors we trick PLUMED into storing ALL the components of the other arguments
175 : // in this way we can ensure that the function of the neighbours matrix is in a chain starting from the
176 : // Neighbours matrix action.
177 : if( foundneigh ) {
178 9 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
179 6 : ActionWithValue* av=getPntrToArgument(i)->getPntrToAction();
180 6 : if( av->getName()!="NEIGHBORS" ) {
181 8 : for(int i=0; i<av->getNumberOfComponents(); ++i) (av->copyOutput(i))->buildDataStore();
182 : }
183 : }
184 : }
185 : // Now setup the action in the chain if we can
186 495 : nderivatives = buildArgumentStore(myfunc.getArgStart());
187 990 : }
188 :
189 : template <class T>
190 1921 : void FunctionOfMatrix<T>::turnOnDerivatives() {
191 1921 : if( !myfunc.derivativesImplemented() ) error("derivatives have not been implemended for " + getName() );
192 1921 : ActionWithValue::turnOnDerivatives(); myfunc.setup(this);
193 1921 : }
194 :
195 : template <class T>
196 30411 : unsigned FunctionOfMatrix<T>::getNumberOfDerivatives() {
197 30411 : return nderivatives;
198 : }
199 :
200 : template <class T>
201 2229 : void FunctionOfMatrix<T>::prepare() {
202 2229 : unsigned argstart = myfunc.getArgStart(); std::vector<unsigned> shape(2);
203 2229 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
204 2229 : if( getPntrToArgument(i)->getRank()==2 ) {
205 2229 : shape[0] = getPntrToArgument(i)->getShape()[0];
206 2229 : shape[1] = getPntrToArgument(i)->getShape()[1];
207 2229 : break;
208 : }
209 : }
210 6682 : for(unsigned i=0; i<getNumberOfComponents(); ++i) {
211 4453 : Value* myval = getPntrToComponent(i);
212 4453 : if( myval->getRank()==2 && (myval->getShape()[0]!=shape[0] || myval->getShape()[1]!=shape[1]) ) {
213 18 : myval->setShape(shape); if( myval->valueIsStored() ) myval->reshapeMatrixStore( shape[1] );
214 : }
215 : }
216 2229 : ActionWithVector::prepare();
217 2229 : }
218 :
219 : template <class T>
220 281844 : unsigned FunctionOfMatrix<T>::getNumberOfColumns() const {
221 281844 : if( getConstPntrToComponent(0)->getRank()==2 ) {
222 : unsigned argstart=myfunc.getArgStart();
223 281844 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
224 281844 : if( getPntrToArgument(i)->getRank()==2 ) {
225 281844 : ActionWithMatrix* am=dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
226 281844 : if( am ) return am->getNumberOfColumns();
227 2238 : return getPntrToArgument(i)->getShape()[1];
228 : }
229 : }
230 : }
231 0 : plumed_error(); return 0;
232 : }
233 :
234 : template <class T>
235 4209 : void FunctionOfMatrix<T>::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
236 11667 : for(unsigned i=0; i<getNumberOfArguments(); ++i) plumed_assert( getPntrToArgument(i)->getRank()==2 );
237 4209 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getShape()[1];
238 4209 : if( indices.size()!=size_v+1 ) indices.resize( size_v+1 );
239 642613 : for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
240 : myvals.setSplitIndex( size_v + 1 );
241 4209 : }
242 :
243 : // template <class T>
244 : // void FunctionOfMatrix<T>::buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) {
245 : // // Check if this is the first element in a chain
246 : // if( actionInChain() ) return;
247             : //   // If it is computed outside a chain get the tasks the daughter chain needs
248 : // propegateTaskListsForValue( 0, ntasks, reduce, otasks );
249 : // }
250 :
251 : template <class T>
252 2525 : void FunctionOfMatrix<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) {
253 2525 : if( firststep ) {
254 489 : stored_arguments.resize( getNumberOfArguments() );
255 489 : update_arguments.resize( getNumberOfArguments(), true );
256 489 : std::string control = getFirstActionInChain()->getLabel();
257 1484 : for(unsigned i=0; i<stored_arguments.size(); ++i) {
258 995 : stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
259 995 : if( !stored_arguments[i] ) update_arguments[i] = true;
260 164 : else update_arguments[i] = !argumentDependsOn( headstr, this, getPntrToArgument(i) );
261 : }
262 489 : firststep=false;
263 : }
264 2525 : ActionWithMatrix::setupStreamedComponents( headstr, nquants, nmat, maxcol, nbookeeping );
265 2525 : }
266 :
267 : template <class T>
268 27928651 : void FunctionOfMatrix<T>::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
269 27928651 : unsigned argstart=myfunc.getArgStart(); std::vector<double> args( getNumberOfArguments() - argstart );
270 27928651 : unsigned ind2 = index2;
271 27928651 : if( getConstPntrToComponent(0)->getRank()==2 && index2>=getConstPntrToComponent(0)->getShape()[0] ) ind2 = index2 - getConstPntrToComponent(0)->getShape()[0];
272 24292268 : else if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
273 27928651 : if( actionInChain() ) {
274 85619946 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
275 58329699 : if( getPntrToArgument(i)->getRank()==0 ) args[i-argstart] = getPntrToArgument(i)->get();
276 58193979 : else if( !getPntrToArgument(i)->valueHasBeenSet() ) args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
277 1188593 : else args[i-argstart] = getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
278 : }
279 : } else {
280 1727072 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
281 1088668 : if( getPntrToArgument(i)->getRank()==2 ) args[i-argstart]=getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
282 0 : else args[i-argstart] = getPntrToArgument(i)->get();
283 : }
284 : }
285 : // Calculate the function and its derivatives
286 27928651 : std::vector<double> vals( getNumberOfComponents() ); Matrix<double> derivatives( getNumberOfComponents(), getNumberOfArguments()-argstart );
287 27928651 : myfunc.calc( this, args, vals, derivatives );
288 : // And set the values
289 99634355 : for(unsigned i=0; i<vals.size(); ++i) myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
290 : // Return if we are not computing derivatives
291 27928651 : if( doNotCalculateDerivatives() ) return;
292 :
293 5399619 : if( actionInChain() ) {
294 33385311 : for(int i=0; i<getNumberOfComponents(); ++i) {
295 27990552 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
296 131996523 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
297 104005971 : if( getPntrToArgument(j)->getRank()==2 ) {
298 103890411 : unsigned istrn = getPntrToArgument(j)->getPositionInStream();
299 103890411 : if( stored_arguments[j] ) {
300 395048 : unsigned task_index = getPntrToArgument(i)->getShape()[1]*index1 + ind2;
301 395048 : myvals.clearDerivatives(istrn); myvals.addDerivative( istrn, task_index, 1.0 ); myvals.updateIndex( istrn, task_index );
302 : }
303 470717695 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
304 366827284 : unsigned kind=myvals.getActiveIndex(istrn,k);
305 366827284 : myvals.addDerivative( ostrn, arg_deriv_starts[j] + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
306 : }
307 : }
308 : }
309 : }
310 : // If we are computing a matrix we need to update the indices here so that derivatives are calcualted correctly in functions of these
311 5394759 : if( getConstPntrToComponent(0)->getRank()==2 ) {
312 32784527 : for(int i=0; i<getNumberOfComponents(); ++i) {
313 27690160 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
314 131395739 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
315 103705579 : if( !update_arguments[j] || getPntrToArgument(j)->getRank()==0 ) continue ;
316 : // Ensure we only store one lot of derivative indices
317 : bool found=false;
318 105009550 : for(unsigned k=0; k<j; ++k) {
319 76601003 : if( arg_deriv_starts[k]==arg_deriv_starts[j] ) { found=true; break; }
320 : }
321 103589995 : if( found ) continue;
322 : unsigned istrn = getPntrToArgument(j)->getPositionInStream();
323 138447375 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
324 110038828 : unsigned kind=myvals.getActiveIndex(istrn,k);
325 110038828 : myvals.updateIndex( ostrn, arg_deriv_starts[j] + kind );
326 : }
327 : }
328 : }
329 : }
330 : } else {
331 4860 : unsigned base=0; ind2 = index2;
332 4860 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
333 4860 : if( getPntrToArgument(j)->getRank()!=2 ) continue ;
334 4860 : if( index2>=getPntrToArgument(j)->getShape()[0] ) ind2 = index2 - getPntrToArgument(j)->getShape()[0];
335 : break;
336 : }
337 13965 : for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
338 9105 : if( getPntrToArgument(j)->getRank()==2 ) {
339 18210 : for(int i=0; i<getNumberOfComponents(); ++i) {
340 9105 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
341 9105 : unsigned myind = base + getPntrToArgument(j)->getShape()[1]*index1 + ind2;
342 9105 : myvals.addDerivative( ostrn, myind, derivatives(i,j) );
343 9105 : myvals.updateIndex( ostrn, myind );
344 : }
345 : } else {
346 0 : for(int i=0; i<getNumberOfComponents(); ++i) {
347 0 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
348 0 : myvals.addDerivative( ostrn, base, derivatives(i,j) );
349 0 : myvals.updateIndex( ostrn, base );
350 : }
351 : }
352 9105 : base += getPntrToArgument(j)->getNumberOfValues();
353 : }
354 : }
355 : }
356 :
357 : template <class T>
358 226389 : void FunctionOfMatrix<T>::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
359 226389 : if( doNotCalculateDerivatives() ) return;
360 :
361 : unsigned argstart=myfunc.getArgStart();
362 71237 : if( actionInChain() && getConstPntrToComponent(0)->getRank()==2 ) {
363 : // This is triggered if we are outputting a matrix
364 624578 : for(int vv=0; vv<getNumberOfComponents(); ++vv) {
365 558183 : unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
366 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); unsigned ntot_mat=0;
367 558183 : if( mat_indices.size()<nderivatives ) mat_indices.resize( nderivatives );
368 2701544 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
369 2143361 : if( !update_arguments[i] || getPntrToArgument(i)->getRank()==0 ) continue ;
370 : // Ensure we only store one lot of derivative indices
371 : bool found=false;
372 2168017 : for(unsigned j=0; j<i; ++j) {
373 1591516 : if( arg_deriv_starts[j]==arg_deriv_starts[i] ) { found=true; break; }
374 : }
375 2142277 : if( found ) continue;
376 :
377 576501 : if( stored_arguments[i] ) {
378 15483 : unsigned tbase = getPntrToArgument(i)->getShape()[1]*ind;
379 410507 : for(unsigned k=1; k<indices.size(); ++k) {
380 395024 : unsigned ind2 = indices[k] - getConstPntrToComponent(0)->getShape()[0];
381 395024 : mat_indices[ntot_mat + k - 1] = arg_deriv_starts[i] + tbase + ind2;
382 : }
383 15483 : ntot_mat += indices.size()-1;
384 : } else {
385 : unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
386 : std::vector<unsigned>& imat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
387 31848426 : for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) mat_indices[ntot_mat + k] = arg_deriv_starts[i] + imat_indices[k];
388 561018 : ntot_mat += myvals.getNumberOfMatrixRowDerivatives( istrn );
389 : }
390 : }
391 : myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
392 : }
393 4842 : } else if( actionInChain() ) {
394 : // This is triggered if we are calculating a single scalar in the function
395 8822 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
396 : bool found=false;
397 4411 : for(unsigned j=0; j<i; ++j) {
398 0 : if( arg_deriv_starts[j]==arg_deriv_starts[i] ) { found=true; break; }
399 : }
400 4411 : if( found ) continue;
401 : unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
402 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
403 926766 : for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) {
404 1844710 : for(int j=0; j<getNumberOfComponents(); ++j) {
405 922355 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
406 922355 : myvals.updateIndex( ostrn, arg_deriv_starts[i] + mat_indices[k] );
407 : }
408 : }
409 : }
410 431 : } else if( getConstPntrToComponent(0)->getRank()==2 ) {
411 760 : for(int vv=0; vv<getNumberOfComponents(); ++vv) {
412 380 : unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
413 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); unsigned ntot_mat=0;
414 380 : if( mat_indices.size()<nderivatives ) mat_indices.resize( nderivatives ); unsigned matderbase = 0;
415 986 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
416 606 : if( getPntrToArgument(i)->getRank()==0 ) continue ;
417 606 : unsigned ss = getPntrToArgument(i)->getShape()[1]; unsigned tbase = matderbase + ss*myvals.getTaskIndex();
418 9558 : for(unsigned k=0; k<ss; ++k) mat_indices[ntot_mat + k] = tbase + k;
419 606 : ntot_mat += ss; matderbase += getPntrToArgument(i)->getNumberOfValues();
420 : }
421 : myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
422 : }
423 : }
424 : }
425 :
426 : template <class T>
427 495 : std::vector<unsigned> FunctionOfMatrix<T>::getValueShapeFromArguments() {
428 495 : unsigned argstart=myfunc.getArgStart(); std::vector<unsigned> shape(2); shape[0]=shape[1]=0;
429 1508 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
430 1013 : plumed_assert( getPntrToArgument(i)->getRank()==2 || getPntrToArgument(i)->getRank()==0 );
431 1013 : if( getPntrToArgument(i)->getRank()==2 ) {
432 948 : if( shape[0]>0 && (getPntrToArgument(i)->getShape()[0]!=shape[0] || getPntrToArgument(i)->getShape()[1]!=shape[1]) ) error("all matrices input should have the same shape");
433 948 : else if( shape[0]==0 ) { shape[0]=getPntrToArgument(i)->getShape()[0]; shape[1]=getPntrToArgument(i)->getShape()[1]; }
434 948 : plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
435 : }
436 : }
437 495 : myfunc.setPrefactor( this, 1.0 ); return shape;
438 : }
439 :
440 : }
441 : }
442 : #endif
|