Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2020 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_function_FunctionOfVector_h
23 : #define __PLUMED_function_FunctionOfVector_h
24 :
25 : #include "core/ActionWithVector.h"
26 : //#include "core/CollectFrames.h"
27 : #include "core/ActionSetup.h"
28 : #include "tools/Matrix.h"
29 : #include "Sum.h"
30 :
31 : namespace PLMD {
32 : namespace function {
33 :
34 : template <class T>
35 : class FunctionOfVector : public ActionWithVector {
36 : private:
37 : /// Do the calculation at the end of the run
38 : bool doAtEnd;
39 : /// Is this the first time we are doing the calc
40 : bool firststep;
41 : /// The function that is being computed
42 : T myfunc;
43 : /// The number of derivatives for this action
44 : unsigned nderivatives;
45 : /// A vector that tells us if we have stored the input value
46 : std::vector<bool> stored_arguments;
47 : public:
48 : static void registerKeywords(Keywords&);
49 : /// This method is used to run the calculation with functions such as highest/lowest and sort.
50 : /// It is static so we can reuse the functionality in FunctionOfMatrix
51 : static void runSingleTaskCalculation( const Value* arg, ActionWithValue* action, T& f );
52 : explicit FunctionOfVector(const ActionOptions&);
53 4696 : ~FunctionOfVector() {}
54 : std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
55 : /// Get the size of the task list at the end of the run
56 : unsigned getNumberOfFinalTasks();
57 : /// Check if derivatives are available
58 : void turnOnDerivatives() override;
59 : /// Get the number of derivatives for this action
60 : unsigned getNumberOfDerivatives() override ;
61 : /// Resize vectors that are the wrong size
62 : void prepare() override ;
63 : /// Check if all he actions are required
64 : void areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions );
65 : /// Get the label to write in the graph
66 20 : std::string writeInGraph() const override {
67 20 : return myfunc.getGraphInfo( getName() );
68 : }
69 : /// This builds the task list for the action
70 : void calculate() override;
71 : /// This ensures that we create some bookeeping stuff during the first step
72 : void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override ;
73 : /// Calculate the function
74 : void performTask( const unsigned& current, MultiValue& myvals ) const override ;
75 : };
76 :
77 : template <class T>
78 4604 : void FunctionOfVector<T>::registerKeywords(Keywords& keys ) {
79 4604 : Action::registerKeywords(keys);
80 4604 : ActionWithValue::registerKeywords(keys);
81 4604 : ActionWithArguments::registerKeywords(keys);
82 4604 : std::string name = keys.getDisplayName();
83 4604 : std::size_t und=name.find("_VECTOR");
84 4604 : keys.setDisplayName( name.substr(0,und) );
85 9208 : keys.addInputKeyword("compulsory","ARG","scalar/vector","the labels of the scalar and vector that on which the function is being calculated elementwise");
86 4604 : keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO");
87 4604 : keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
88 2043 : T tfunc;
89 4604 : tfunc.registerKeywords( keys );
90 9208 : if( keys.getDisplayName()=="SUM" ) {
91 2432 : keys.setValueDescription("scalar","the sum of all the elements in the input vector");
92 6776 : } else if( keys.getDisplayName()=="MEAN" ) {
93 714 : keys.setValueDescription("scalar","the mean of all the elements in the input vector");
94 6062 : } else if( keys.getDisplayName()=="HIGHEST" ) {
95 84 : keys.setValueDescription("scalar/vector","the largest element of the input vector if one vector specified. If multiple vectors of the same size specified the largest elements of these vector computed elementwise.");
96 5978 : } else if( keys.getDisplayName()=="LOWEST" ) {
97 130 : keys.setValueDescription("scalar/vector","the smallest element in the input vector if one vector specified. If multiple vectors of the same size specified the largest elements of these vector computed elementwise.");
98 5848 : } else if( keys.getDisplayName()=="SORT" ) {
99 24 : keys.setValueDescription("vector","a vector that has been sorted into ascending order");
100 5824 : } else if( keys.outputComponentExists(".#!value") ) {
101 5804 : keys.setValueDescription("vector","the vector obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input vectors");
102 : }
103 7110 : }
104 :
105 : template <class T>
106 2279 : FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
107 : Action(ao),
108 : ActionWithVector(ao),
109 2279 : doAtEnd(true),
110 2279 : firststep(true),
111 2279 : nderivatives(0) {
112 : // Get the shape of the output
113 2279 : std::vector<unsigned> shape(1);
114 2279 : shape[0]=getNumberOfFinalTasks();
115 : // Read the input and do some checks
116 2279 : myfunc.read( this );
117 : // Create the task list
118 2136 : if( myfunc.doWithTasks() ) {
119 2239 : doAtEnd=false;
120 2239 : if( shape[0]>0 ) {
121 2176 : done_in_chain=true;
122 : }
123 : } else {
124 40 : plumed_assert( getNumberOfArguments()==1 );
125 40 : done_in_chain=false;
126 40 : getPntrToArgument(0)->buildDataStore();
127 : }
128 : // Get the names of the components
129 2279 : std::vector<std::string> components( keywords.getOutputComponents() );
130 : // Create the values to hold the output
131 59 : std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
132 4558 : for(unsigned i=0; i<components.size(); ++i) {
133 8 : if( str_ind.size()>0 ) {
134 8 : std::string strcompn = components[i];
135 8 : if( components[i]==".#!value" ) {
136 : strcompn = "";
137 : }
138 34 : for(unsigned j=0; j<str_ind.size(); ++j) {
139 7 : if( myfunc.zeroRank() ) {
140 52 : addComponentWithDerivatives( strcompn + str_ind[j] );
141 : } else {
142 0 : addComponent( strcompn + str_ind[j], shape );
143 : }
144 : }
145 2271 : } else if( components[i].find_first_of("_")!=std::string::npos ) {
146 0 : if( getNumberOfArguments()==1 && myfunc.zeroRank() ) {
147 0 : addValueWithDerivatives();
148 0 : } else if( getNumberOfArguments()==1 ) {
149 0 : addValue( shape );
150 : } else {
151 : unsigned argstart=myfunc.getArgStart();
152 0 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
153 0 : if( myfunc.zeroRank() ) {
154 0 : addComponentWithDerivatives( getPntrToArgument(i)->getName() + components[i] );
155 : } else {
156 0 : addComponent( getPntrToArgument(i)->getName() + components[i], shape );
157 : }
158 : }
159 : }
160 834 : } else if( components[i]==".#!value" && myfunc.zeroRank() ) {
161 1632 : addValueWithDerivatives();
162 1455 : } else if( components[i]==".#!value" ) {
163 1455 : addValue(shape);
164 0 : } else if( myfunc.zeroRank() ) {
165 0 : addComponentWithDerivatives( components[i] );
166 : } else {
167 0 : addComponent( components[i], shape );
168 : }
169 : }
170 : // Check if we can turn off the derivatives when they are zero
171 1016 : if( myfunc.getDerivativeZeroIfValueIsZero() ) {
172 624 : for(int i=0; i<getNumberOfComponents(); ++i) {
173 312 : getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
174 : }
175 : }
176 : // Check if this is a timeseries
177 : unsigned argstart=myfunc.getArgStart();
178 : // for(unsigned i=argstart; i<getNumberOfArguments();++i) {
179 : // if( getPntrToArgument(i)->isTimeSeries() ) {
180 : // for(unsigned i=0; i<getNumberOfComponents(); ++i) getPntrToOutput(i)->makeHistoryDependent();
181 : // break;
182 : // }
183 : // }
184 : // Set the periodicities of the output components
185 2279 : myfunc.setPeriodicityForOutputs( this );
186 : // Check if we can put the function in a chain
187 6176 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
188 : // CollectFrames* ab=dynamic_cast<CollectFrames*>( getPntrToArgument(i)->getPntrToAction() );
189 : // if( ab && ab->hasClear() ) { doNotChain=true; getPntrToArgument(i)->buildDataStore( getLabel() ); }
190 : // No chains if we are using a sum or a mean
191 3897 : if( getPntrToArgument(i)->getRank()==0 ) {
192 246 : FunctionOfVector<Sum>* as = dynamic_cast<FunctionOfVector<Sum>*>( getPntrToArgument(i)->getPntrToAction() );
193 246 : if(as) {
194 48 : done_in_chain=false;
195 : }
196 : } else {
197 3651 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
198 3651 : if( !av ) {
199 444 : done_in_chain=false;
200 : }
201 : }
202 : }
203 : // Don't need to do the calculation in a chain if the input is constant
204 : bool allconstant=true;
205 2610 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
206 2290 : if( !getPntrToArgument(i)->isConstant() ) {
207 : allconstant=false;
208 : break;
209 : }
210 : }
211 2279 : if( allconstant ) {
212 320 : done_in_chain=false;
213 : }
214 2279 : nderivatives = buildArgumentStore(myfunc.getArgStart());
215 4558 : }
216 :
217 : template <class T>
218 5 : std::string FunctionOfVector<T>::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
219 5 : if( getName().find("SORT")==std::string::npos ) {
220 0 : return ActionWithValue::getOutputComponentDescription( cname, keys );
221 : }
222 5 : if( getNumberOfArguments()==1 ) {
223 6 : return "the " + cname + "th largest element of the vector " + getPntrToArgument(0)->getName();
224 : }
225 4 : return "the " + cname + "th largest element in the input vectors";
226 : }
227 :
228 : template <class T>
229 6180 : void FunctionOfVector<T>::turnOnDerivatives() {
230 6180 : if( !getPntrToComponent(0)->isConstant() && !myfunc.derivativesImplemented() ) {
231 0 : error("derivatives have not been implemended for " + getName() );
232 : }
233 6180 : ActionWithValue::turnOnDerivatives();
234 6180 : myfunc.setup(this );
235 6180 : }
236 :
237 : template <class T>
238 53087 : unsigned FunctionOfVector<T>::getNumberOfDerivatives() {
239 53087 : return nderivatives;
240 : }
241 :
242 : template <class T>
243 185617 : void FunctionOfVector<T>::prepare() {
244 : unsigned argstart = myfunc.getArgStart();
245 185617 : std::vector<unsigned> shape(1);
246 232595 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
247 232595 : if( getPntrToArgument(i)->getRank()==1 ) {
248 185617 : shape[0] = getPntrToArgument(i)->getShape()[0];
249 185617 : break;
250 : }
251 : }
252 371650 : for(unsigned i=0; i<getNumberOfComponents(); ++i) {
253 186033 : Value* myval = getPntrToComponent(i);
254 186033 : if( myval->getRank()==1 && myval->getShape()[0]!=shape[0] ) {
255 47 : myval->setShape(shape);
256 : }
257 : }
258 185617 : ActionWithVector::prepare();
259 185617 : }
260 :
261 : template <class T>
262 312628 : void FunctionOfVector<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) {
263 312628 : if( firststep ) {
264 2183 : stored_arguments.resize( getNumberOfArguments() );
265 2183 : std::string control = getFirstActionInChain()->getLabel();
266 5969 : for(unsigned i=0; i<stored_arguments.size(); ++i) {
267 3786 : if( getPntrToArgument(i)->isConstant() ) {
268 : stored_arguments[i]=false;
269 : } else {
270 3267 : stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
271 : }
272 : }
273 2183 : firststep=false;
274 : }
275 312628 : ActionWithVector::setupStreamedComponents( headstr, nquants, nmat, maxcol, nbookeeping );
276 312628 : }
277 :
template <class T>
/// Evaluate the function for the element with index `current` and write the values
/// (and, if requested, the derivatives) into `myvals`.
/// Arguments before myfunc.getArgStart() are not passed to the function.
void FunctionOfVector<T>::performTask( const unsigned& current, MultiValue& myvals ) const {
  unsigned argstart=myfunc.getArgStart();
  std::vector<double> args( getNumberOfArguments()-argstart);
  if( actionInChain() ) {
    // In a chain: scalars are read directly; vector elements come from the stream
    // unless the argument's value has already been set, in which case the stored
    // element for this task index is used
    for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
      if( getPntrToArgument(i)->getRank()==0 ) {
        args[i-argstart] = getPntrToArgument(i)->get();
      } else if( !getPntrToArgument(i)->valueHasBeenSet() ) {
        args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
      } else {
        args[i-argstart] = getPntrToArgument(i)->get( myvals.getTaskIndex() );
      }
    }
  } else {
    // Not in a chain: read element `current` of each vector argument, or the scalar value
    for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
      if( getPntrToArgument(i)->getRank()==1 ) {
        args[i-argstart]=getPntrToArgument(i)->get(current);
      } else {
        args[i-argstart] = getPntrToArgument(i)->get();
      }
    }
  }
  // Calculate the function and its derivatives
  // (derivatives(i,j) is the derivative of output i with respect to args[j])
  std::vector<double> vals( getNumberOfComponents() );
  Matrix<double> derivatives( getNumberOfComponents(), args.size() );
  myfunc.calc( this, args, vals, derivatives );
  // And set the values
  for(unsigned i=0; i<vals.size(); ++i) {
    myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
  }
  // Return if we are not computing derivatives
  if( doNotCalculateDerivatives() ) {
    return;
  }
  // And now compute the derivatives.
  // In a chain, apply the chain rule through the stream: for every active derivative
  // index of argument j, output i receives derivatives(i,j) * d(arg_j)/d(index),
  // shifted by the argument's offset arg_deriv_starts[argstart+j].
  if( actionInChain() ) {
    for(unsigned j=0; j<args.size(); ++j) {
      unsigned istrn = getPntrToArgument(argstart+j)->getPositionInStream();
      if( stored_arguments[argstart+j] ) {
        // Stored arguments: seed a derivative of 1.0 at index task_index (0 for scalars)
        unsigned task_index = myvals.getTaskIndex();
        if( getPntrToArgument(argstart+j)->getRank()==0 ) {
          task_index=0;
        }
        myvals.addDerivative( istrn, task_index, 1.0 );
        myvals.updateIndex( istrn, task_index );
      }
      unsigned arg_deriv_s = arg_deriv_starts[argstart+j];
      for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
        unsigned kind=myvals.getActiveIndex(istrn,k);
        for(int i=0; i<getNumberOfComponents(); ++i) {
          unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
          myvals.addDerivative( ostrn, arg_deriv_s + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
        }
      }
      // Ensure we only store one lot of derivative indices
      // If an earlier argument shares this derivative offset, its indices were already
      // recorded. When that earlier argument came from a different action, this argument's
      // action (if it is an ActionWithVector) is asked to update its additional indices instead.
      bool found=false;
      ActionWithValue* aav=getPntrToArgument(argstart+j)->getPntrToAction();
      for(unsigned k=0; k<j; ++k) {
        if( arg_deriv_starts[argstart+k]==arg_deriv_s ) {
          if( getPntrToArgument(argstart+k)->getPntrToAction()!=aav ) {
            ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(argstart+j)->getPntrToAction() );
            if( av ) {
              for(int i=0; i<getNumberOfComponents(); ++i) {
                av->updateAdditionalIndices( getConstPntrToComponent(i)->getPositionInStream(), myvals );
              }
            }
          }
          found=true;
          break;
        }
      }
      if( found ) {
        continue;
      }
      // First argument with this offset: register its active indices on every output
      for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
        unsigned kind=myvals.getActiveIndex(istrn,k);
        for(int i=0; i<getNumberOfComponents(); ++i) {
          unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
          myvals.updateIndex( ostrn, arg_deriv_s + kind );
        }
      }
    }
  } else {
    // Not in a chain: derivative indices are laid out argument by argument, with `base`
    // advancing by each argument's number of values. A vector argument contributes at
    // base+current; a scalar argument contributes at base.
    unsigned base=0;
    for(unsigned j=0; j<args.size(); ++j) {
      if( getPntrToArgument(argstart+j)->getRank()==1 ) {
        for(int i=0; i<getNumberOfComponents(); ++i) {
          unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
          myvals.addDerivative( ostrn, base+current, derivatives(i,j) );
          myvals.updateIndex( ostrn, base+current );
        }
      } else {
        for(int i=0; i<getNumberOfComponents(); ++i) {
          unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
          myvals.addDerivative( ostrn, base, derivatives(i,j) );
          myvals.updateIndex( ostrn, base );
        }
      }
      base += getPntrToArgument(argstart+j)->getNumberOfValues();
    }
  }
}
385 :
386 : template <class T>
387 2279 : unsigned FunctionOfVector<T>::getNumberOfFinalTasks() {
388 : unsigned nelements=0, argstart=myfunc.getArgStart();
389 6176 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
390 3897 : plumed_assert( getPntrToArgument(i)->getRank()<2 );
391 3897 : if( getPntrToArgument(i)->getRank()==1 ) {
392 3651 : if( nelements>0 ) {
393 : // if( getPntrToArgument(i)->isTimeSeries() && getPntrToArgument(i)->getShape()[0]<nelements ) nelements=getPntrToArgument(i)->isTimeSeries();
394 : // else
395 1372 : if(getPntrToArgument(i)->getShape()[0]!=nelements ) {
396 0 : error("all vectors input should have the same length");
397 : }
398 : } else if( nelements==0 ) {
399 2279 : nelements=getPntrToArgument(i)->getShape()[0];
400 : }
401 3651 : plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
402 : }
403 : }
404 : // The prefactor for average and sum is set here so the number of input scalars is guaranteed to be correct
405 783 : myfunc.setPrefactor( this, 1.0 );
406 2279 : return nelements;
407 : }
408 :
409 : template <class T>
410 12151 : void FunctionOfVector<T>::areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions ) {
411 12151 : if( task_reducing_actions.size()==0 ) {
412 : return;
413 : }
414 2248 : if( !myfunc.allComponentsRequired( getArguments(), task_reducing_actions ) ) {
415 1765 : task_reducing_actions.push_back(this);
416 : }
417 : }
418 :
419 : template <class T>
420 5541 : void FunctionOfVector<T>::runSingleTaskCalculation( const Value* arg, ActionWithValue* action, T& f ) {
421 : // This is used if we are doing sorting actions on a single vector
422 5541 : unsigned nv = arg->getNumberOfValues();
423 5541 : std::vector<double> args( nv );
424 8198467 : for(unsigned i=0; i<nv; ++i) {
425 8192926 : args[i] = arg->get(i);
426 : }
427 5541 : std::vector<double> vals( action->getNumberOfComponents() );
428 5541 : Matrix<double> derivatives( action->getNumberOfComponents(), nv );
429 5541 : ActionWithArguments* aa=dynamic_cast<ActionWithArguments*>(action);
430 5541 : plumed_assert( aa );
431 5541 : f.calc( aa, args, vals, derivatives );
432 11498 : for(unsigned i=0; i<vals.size(); ++i) {
433 5957 : action->copyOutput(i)->set( vals[i] );
434 : }
435 : // Return if we are not computing derivatives
436 5541 : if( action->doNotCalculateDerivatives() ) {
437 : return;
438 : }
439 : // Now set the derivatives
440 198059 : for(unsigned j=0; j<nv; ++j) {
441 388720 : for(unsigned i=0; i<vals.size(); ++i) {
442 195074 : action->copyOutput(i)->setDerivative( j, derivatives(i,j) );
443 : }
444 : }
445 : }
446 :
447 : template <class T>
448 184988 : void FunctionOfVector<T>::calculate() {
449 : // Everything is done elsewhere
450 184988 : if( actionInChain() ) {
451 : return;
452 : }
453 : // This is done if we are calculating a function of multiple cvs
454 80095 : if( !doAtEnd ) {
455 74554 : runAllTasks();
456 : }
457 : // This is used if we are doing sorting actions on a single vector
458 5541 : else if( !myfunc.doWithTasks() ) {
459 5541 : runSingleTaskCalculation( getPntrToArgument(0), this, myfunc );
460 : }
461 : }
462 :
463 : }
464 : }
465 : #endif
|