Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2020 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_function_FunctionOfVector_h
23 : #define __PLUMED_function_FunctionOfVector_h
24 :
25 : #include "core/ActionWithVector.h"
26 : //#include "core/CollectFrames.h"
27 : #include "core/ActionSetup.h"
28 : #include "tools/Matrix.h"
29 : #include "Sum.h"
30 :
31 : namespace PLMD {
32 : namespace function {
33 :
34 : template <class T>
35 : class FunctionOfVector : public ActionWithVector {
36 : private:
37 : /// Do the calculation at the end of the run
38 : bool doAtEnd;
39 : /// Is this the first time we are doing the calc
40 : bool firststep;
41 : /// The function that is being computed
42 : T myfunc;
43 : /// The number of derivatives for this action
44 : unsigned nderivatives;
45 : /// A vector that tells us if we have stored the input value
46 : std::vector<bool> stored_arguments;
47 : public:
48 : static void registerKeywords(Keywords&);
49 : /// This method is used to run the calculation with functions such as highest/lowest and sort.
50 : /// It is static so we can reuse the functionality in FunctionOfMatrix
51 : static void runSingleTaskCalculation( const Value* arg, ActionWithValue* action, T& f );
52 : explicit FunctionOfVector(const ActionOptions&);
53 4660 : ~FunctionOfVector() {}
54 : std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
55 : /// Get the size of the task list at the end of the run
56 : unsigned getNumberOfFinalTasks();
57 : /// Check if derivatives are available
58 : void turnOnDerivatives() override;
59 : /// Get the number of derivatives for this action
60 : unsigned getNumberOfDerivatives() override ;
61 : /// Resize vectors that are the wrong size
62 : void prepare() override ;
63 : /// Check if all he actions are required
64 : void areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions );
65 : /// Get the label to write in the graph
66 20 : std::string writeInGraph() const override { return myfunc.getGraphInfo( getName() ); }
67 : /// This builds the task list for the action
68 : void calculate() override;
69 : /// This ensures that we create some bookeeping stuff during the first step
70 : void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override ;
71 : /// Calculate the function
72 : void performTask( const unsigned& current, MultiValue& myvals ) const override ;
73 : };
74 :
75 : template <class T>
76 4574 : void FunctionOfVector<T>::registerKeywords(Keywords& keys ) {
77 4574 : Action::registerKeywords(keys); ActionWithValue::registerKeywords(keys); ActionWithArguments::registerKeywords(keys); keys.use("ARG");
78 4574 : std::string name = keys.getDisplayName(); std::size_t und=name.find("_VECTOR"); keys.setDisplayName( name.substr(0,und) );
79 9148 : keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO");
80 9148 : keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
81 4574 : T tfunc; tfunc.registerKeywords( keys );
82 9148 : if( keys.getDisplayName()=="SUM" ) {
83 2408 : keys.setValueDescription("the sum of all the elements in the input vector");
84 6740 : } else if( keys.getDisplayName()=="MEAN" ) {
85 714 : keys.setValueDescription("the mean of all the elements in the input vector");
86 6026 : } else if( keys.getDisplayName()=="HIGHEST" ) {
87 84 : keys.setValueDescription("the largest element of the input vector");
88 5942 : } else if( keys.getDisplayName()=="LOWEST" ) {
89 118 : keys.setValueDescription("the smallest element in the input vector");
90 5824 : } else if( keys.getDisplayName()=="SORT" ) {
91 24 : keys.setValueDescription("a vector that has been sorted into ascending order");
92 5800 : } else if( keys.outputComponentExists(".#!value") ) {
93 5780 : keys.setValueDescription("the vector obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input vectors");
94 : }
95 7080 : }
96 :
97 : template <class T>
98 2264 : FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
99 : Action(ao),
100 : ActionWithVector(ao),
101 2264 : doAtEnd(true),
102 2264 : firststep(true),
103 2264 : nderivatives(0)
104 : {
105 : // Get the shape of the output
106 2264 : std::vector<unsigned> shape(1); shape[0]=getNumberOfFinalTasks();
107 : // Read the input and do some checks
108 2264 : myfunc.read( this );
109 : // Create the task list
110 2121 : if( myfunc.doWithTasks() ) {
111 2224 : doAtEnd=false; if( shape[0]>0 ) done_in_chain=true;
112 40 : } else { plumed_assert( getNumberOfArguments()==1 ); done_in_chain=false; getPntrToArgument(0)->buildDataStore(); }
113 : // Get the names of the components
114 2264 : std::vector<std::string> components( keywords.getOutputComponents() );
115 : // Create the values to hold the output
116 56 : std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
117 4528 : for(unsigned i=0; i<components.size(); ++i) {
118 8 : if( str_ind.size()>0 ) {
119 16 : std::string strcompn = components[i]; if( components[i]==".#!value" ) strcompn = "";
120 34 : for(unsigned j=0; j<str_ind.size(); ++j) {
121 52 : if( myfunc.zeroRank() ) addComponentWithDerivatives( strcompn + str_ind[j] );
122 0 : else addComponent( strcompn + str_ind[j], shape );
123 : }
124 2256 : } else if( components[i].find_first_of("_")!=std::string::npos ) {
125 0 : if( getNumberOfArguments()==1 && myfunc.zeroRank() ) addValueWithDerivatives();
126 0 : else if( getNumberOfArguments()==1 ) addValue( shape );
127 : else {
128 : unsigned argstart=myfunc.getArgStart();
129 0 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
130 0 : if( myfunc.zeroRank() ) addComponentWithDerivatives( getPntrToArgument(i)->getName() + components[i] );
131 0 : else addComponent( getPntrToArgument(i)->getName() + components[i], shape );
132 : }
133 : }
134 1635 : } else if( components[i]==".#!value" && myfunc.zeroRank() ) addValueWithDerivatives();
135 1446 : else if( components[i]==".#!value" ) addValue(shape);
136 0 : else if( myfunc.zeroRank() ) addComponentWithDerivatives( components[i] );
137 0 : else addComponent( components[i], shape );
138 : }
139 : // Check if we can turn off the derivatives when they are zero
140 1016 : if( myfunc.getDerivativeZeroIfValueIsZero() ) {
141 612 : for(int i=0; i<getNumberOfComponents(); ++i) getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
142 : }
143 : // Check if this is a timeseries
144 : unsigned argstart=myfunc.getArgStart();
145 : // for(unsigned i=argstart; i<getNumberOfArguments();++i) {
146 : // if( getPntrToArgument(i)->isTimeSeries() ) {
147 : // for(unsigned i=0; i<getNumberOfComponents(); ++i) getPntrToOutput(i)->makeHistoryDependent();
148 : // break;
149 : // }
150 : // }
151 : // Set the periodicities of the output components
152 2264 : myfunc.setPeriodicityForOutputs( this );
153 : // Check if we can put the function in a chain
154 6143 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
155 : // CollectFrames* ab=dynamic_cast<CollectFrames*>( getPntrToArgument(i)->getPntrToAction() );
156 : // if( ab && ab->hasClear() ) { doNotChain=true; getPntrToArgument(i)->buildDataStore( getLabel() ); }
157 : // No chains if we are using a sum or a mean
158 3879 : if( getPntrToArgument(i)->getRank()==0 ) {
159 246 : FunctionOfVector<Sum>* as = dynamic_cast<FunctionOfVector<Sum>*>( getPntrToArgument(i)->getPntrToAction() );
160 246 : if(as) done_in_chain=false;
161 : } else {
162 3633 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
163 3633 : if( !av ) done_in_chain=false;
164 : }
165 : }
166 : // Don't need to do the calculation in a chain if the input is constant
167 : bool allconstant=true;
168 2595 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
169 2275 : if( !getPntrToArgument(i)->isConstant() ) { allconstant=false; break; }
170 : }
171 2264 : if( allconstant ) done_in_chain=false;
172 2264 : nderivatives = buildArgumentStore(myfunc.getArgStart());
173 4528 : }
174 :
175 : template <class T>
176 5 : std::string FunctionOfVector<T>::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
177 5 : if( getName().find("SORT")==std::string::npos ) return ActionWithValue::getOutputComponentDescription( cname, keys );
178 8 : if( getNumberOfArguments()==1 ) return "the " + cname + "th largest element of the vector " + getPntrToArgument(0)->getName();
179 4 : return "the " + cname + "th largest element in the input vectors";
180 : }
181 :
182 : template <class T>
183 6140 : void FunctionOfVector<T>::turnOnDerivatives() {
184 6140 : if( !getPntrToComponent(0)->isConstant() && !myfunc.derivativesImplemented() ) error("derivatives have not been implemended for " + getName() );
185 6140 : ActionWithValue::turnOnDerivatives(); myfunc.setup(this );
186 6140 : }
187 :
188 : template <class T>
189 53011 : unsigned FunctionOfVector<T>::getNumberOfDerivatives() {
190 53011 : return nderivatives;
191 : }
192 :
193 : template <class T>
194 185527 : void FunctionOfVector<T>::prepare() {
195 185527 : unsigned argstart = myfunc.getArgStart(); std::vector<unsigned> shape(1);
196 232505 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
197 232505 : if( getPntrToArgument(i)->getRank()==1 ) {
198 185527 : shape[0] = getPntrToArgument(i)->getShape()[0]; break;
199 : }
200 : }
201 371470 : for(unsigned i=0; i<getNumberOfComponents(); ++i) {
202 185943 : Value* myval = getPntrToComponent(i);
203 185943 : if( myval->getRank()==1 && myval->getShape()[0]!=shape[0] ) { myval->setShape(shape); }
204 : }
205 185527 : ActionWithVector::prepare();
206 185527 : }
207 :
208 : template <class T>
209 312538 : void FunctionOfVector<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) {
210 312538 : if( firststep ) {
211 2168 : stored_arguments.resize( getNumberOfArguments() );
212 2168 : std::string control = getFirstActionInChain()->getLabel();
213 5936 : for(unsigned i=0; i<stored_arguments.size(); ++i) {
214 3768 : if( getPntrToArgument(i)->isConstant() ) stored_arguments[i]=false;
215 3249 : else stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
216 : }
217 2168 : firststep=false;
218 : }
219 312538 : ActionWithVector::setupStreamedComponents( headstr, nquants, nmat, maxcol, nbookeeping );
220 312538 : }
221 :
222 : template <class T>
223 6634172 : void FunctionOfVector<T>::performTask( const unsigned& current, MultiValue& myvals ) const {
224 6634172 : unsigned argstart=myfunc.getArgStart(); std::vector<double> args( getNumberOfArguments()-argstart);
225 6634172 : if( actionInChain() ) {
226 6832197 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
227 3923742 : if( getPntrToArgument(i)->getRank()==0 ) args[i-argstart] = getPntrToArgument(i)->get();
228 3905954 : else if( !getPntrToArgument(i)->valueHasBeenSet() ) args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
229 154930 : else args[i-argstart] = getPntrToArgument(i)->get( myvals.getTaskIndex() );
230 : }
231 : } else {
232 11032808 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
233 7307091 : if( getPntrToArgument(i)->getRank()==1 ) args[i-argstart]=getPntrToArgument(i)->get(current);
234 3242545 : else args[i-argstart] = getPntrToArgument(i)->get();
235 : }
236 : }
237 : // Calculate the function and its derivatives
238 6634172 : std::vector<double> vals( getNumberOfComponents() ); Matrix<double> derivatives( getNumberOfComponents(), args.size() );
239 6634172 : myfunc.calc( this, args, vals, derivatives );
240 : // And set the values
241 13268344 : for(unsigned i=0; i<vals.size(); ++i) myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
242 : // Return if we are not computing derivatives
243 6634172 : if( doNotCalculateDerivatives() ) return;
244 : // And now compute the derivatives
245 : // Second condition here is only not true if actionInChain()==True if
246 : // input arguments the only non-constant objects in input are scalars.
247 : // In that case we can use the non chain version to calculate the derivatives
248 : // with respect to the scalar.
249 5721550 : if( actionInChain() ) {
250 5183456 : for(unsigned j=0; j<args.size(); ++j) {
251 8375 : unsigned istrn = getPntrToArgument(argstart+j)->getPositionInStream();
252 2886517 : if( stored_arguments[argstart+j] ) {
253 70 : unsigned task_index = myvals.getTaskIndex(); if( getPntrToArgument(argstart+j)->getRank()==0 ) task_index=0;
254 70 : myvals.addDerivative( istrn, task_index, 1.0 ); myvals.updateIndex( istrn, task_index );
255 : }
256 2886517 : unsigned arg_deriv_s = arg_deriv_starts[argstart+j];
257 114509648 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
258 111623131 : unsigned kind=myvals.getActiveIndex(istrn,k);
259 223246262 : for(int i=0; i<getNumberOfComponents(); ++i) {
260 111623131 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
261 111623131 : myvals.addDerivative( ostrn, arg_deriv_s + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
262 : }
263 : }
264 : // Ensure we only store one lot of derivative indices
265 2886517 : bool found=false; ActionWithValue* aav=getPntrToArgument(argstart+j)->getPntrToAction();
266 2920439 : for(unsigned k=0; k<j; ++k) {
267 589642 : if( arg_deriv_starts[argstart+k]==arg_deriv_s ) {
268 555720 : if( getPntrToArgument(argstart+k)->getPntrToAction()!=aav ) {
269 386484 : ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(argstart+j)->getPntrToAction() );
270 386484 : if( av ) {
271 772968 : for(int i=0; i<getNumberOfComponents(); ++i) av->updateAdditionalIndices( getConstPntrToComponent(i)->getPositionInStream(), myvals );
272 : }
273 : }
274 : found=true; break;
275 : }
276 : }
277 555720 : if( found ) continue;
278 85327769 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
279 82996972 : unsigned kind=myvals.getActiveIndex(istrn,k);
280 165993944 : for(int i=0; i<getNumberOfComponents(); ++i) {
281 82996972 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
282 82996972 : myvals.updateIndex( ostrn, arg_deriv_s + kind );
283 : }
284 : }
285 : }
286 : } else {
287 : unsigned base=0;
288 10141540 : for(unsigned j=0; j<args.size(); ++j) {
289 6716929 : if( getPntrToArgument(argstart+j)->getRank()==1 ) {
290 7241248 : for(int i=0; i<getNumberOfComponents(); ++i) {
291 3620624 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
292 3620624 : myvals.addDerivative( ostrn, base+current, derivatives(i,j) );
293 3620624 : myvals.updateIndex( ostrn, base+current );
294 : }
295 : } else {
296 6192610 : for(int i=0; i<getNumberOfComponents(); ++i) {
297 3096305 : unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
298 3096305 : myvals.addDerivative( ostrn, base, derivatives(i,j) );
299 3096305 : myvals.updateIndex( ostrn, base );
300 : }
301 : }
302 6716929 : base += getPntrToArgument(argstart+j)->getNumberOfValues();
303 : }
304 : }
305 : }
306 :
307 : template <class T>
308 2264 : unsigned FunctionOfVector<T>::getNumberOfFinalTasks() {
309 : unsigned nelements=0, argstart=myfunc.getArgStart();
310 6143 : for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
311 3879 : plumed_assert( getPntrToArgument(i)->getRank()<2 );
312 3879 : if( getPntrToArgument(i)->getRank()==1 ) {
313 3633 : if( nelements>0 ) {
314 : // if( getPntrToArgument(i)->isTimeSeries() && getPntrToArgument(i)->getShape()[0]<nelements ) nelements=getPntrToArgument(i)->isTimeSeries();
315 : // else
316 1369 : if(getPntrToArgument(i)->getShape()[0]!=nelements ) error("all vectors input should have the same length");
317 2264 : } else if( nelements==0 ) nelements=getPntrToArgument(i)->getShape()[0];
318 3633 : plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
319 : }
320 : }
321 : // The prefactor for average and sum is set here so the number of input scalars is guaranteed to be correct
322 777 : myfunc.setPrefactor( this, 1.0 );
323 2264 : return nelements;
324 : }
325 :
326 : template <class T>
327 12124 : void FunctionOfVector<T>::areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions ) {
328 12124 : if( task_reducing_actions.size()==0 ) return;
329 2221 : if( !myfunc.allComponentsRequired( getArguments(), task_reducing_actions ) ) task_reducing_actions.push_back(this);
330 : }
331 :
332 : template <class T>
333 5541 : void FunctionOfVector<T>::runSingleTaskCalculation( const Value* arg, ActionWithValue* action, T& f ) {
334 : // This is used if we are doing sorting actions on a single vector
335 5541 : unsigned nv = arg->getNumberOfValues(); std::vector<double> args( nv );
336 8198467 : for(unsigned i=0; i<nv; ++i) args[i] = arg->get(i);
337 5541 : std::vector<double> vals( action->getNumberOfComponents() ); Matrix<double> derivatives( action->getNumberOfComponents(), nv );
338 5541 : ActionWithArguments* aa=dynamic_cast<ActionWithArguments*>(action); plumed_assert( aa ); f.calc( aa, args, vals, derivatives );
339 11498 : for(unsigned i=0; i<vals.size(); ++i) action->copyOutput(i)->set( vals[i] );
340 : // Return if we are not computing derivatives
341 5541 : if( action->doNotCalculateDerivatives() ) return;
342 : // Now set the derivatives
343 198059 : for(unsigned j=0; j<nv; ++j) {
344 388720 : for(unsigned i=0; i<vals.size(); ++i) action->copyOutput(i)->setDerivative( j, derivatives(i,j) );
345 : }
346 : }
347 :
348 : template <class T>
349 184898 : void FunctionOfVector<T>::calculate() {
350 : // Everything is done elsewhere
351 184898 : if( actionInChain() ) return;
352 : // This is done if we are calculating a function of multiple cvs
353 80095 : if( !doAtEnd ) runAllTasks();
354 : // This is used if we are doing sorting actions on a single vector
355 5541 : else if( !myfunc.doWithTasks() ) runSingleTaskCalculation( getPntrToArgument(0), this, myfunc );
356 : }
357 :
358 : }
359 : }
360 : #endif
|