Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2012-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "ActionWithVessel.h"
23 : #include "tools/Communicator.h"
24 : #include "Vessel.h"
25 : #include "ShortcutVessel.h"
26 : #include "StoreDataVessel.h"
27 : #include "VesselRegister.h"
28 : #include "BridgeVessel.h"
29 : #include "FunctionVessel.h"
30 : #include "StoreDataVessel.h"
31 : #include "tools/OpenMP.h"
32 : #include "tools/Stopwatch.h"
33 :
34 : namespace PLMD {
35 : namespace vesselbase {
36 :
37 682 : void ActionWithVessel::registerKeywords(Keywords& keys) {
38 1364 : keys.add("hidden","TOL","this keyword can be used to speed up your calculation. When accumulating sums in which the individual "
39 : "terms are numbers in between zero and one it is assumed that terms less than a certain tolerance "
40 : "make only a small contribution to the sum. They can thus be safely ignored as can the the derivatives "
41 : "wrt these small quantities.");
42 1364 : keys.add("hidden","MAXDERIVATIVES","The maximum number of derivatives that can be used when storing data. This controls when "
43 : "we have to start using lowmem");
44 1364 : keys.addFlag("SERIAL",false,"do the calculation in serial. Do not use MPI");
45 1364 : keys.addFlag("LOWMEM",false,"lower the memory requirements");
46 1364 : keys.addFlag("TIMINGS",false,"output information on the timings of the various parts of the calculation");
47 1364 : keys.reserveFlag("HIGHMEM",false,"use a more memory intensive version of this collective variable");
48 682 : keys.add( vesselRegister().getKeywords() );
49 682 : }
50 :
51 569 : ActionWithVessel::ActionWithVessel(const ActionOptions&ao):
52 : Action(ao),
53 569 : serial(false),
54 569 : lowmem(false),
55 569 : noderiv(true),
56 569 : actionIsBridged(false),
57 569 : nactive_tasks(0),
58 569 : dertime_can_be_off(false),
59 569 : dertime(true),
60 569 : contributorsAreUnlocked(false),
61 569 : weightHasDerivatives(false),
62 569 : mydata(NULL)
63 : {
64 569 : maxderivatives=309; parse("MAXDERIVATIVES",maxderivatives);
65 1674 : if( keywords.exists("SERIAL") ) parseFlag("SERIAL",serial);
66 33 : else serial=true;
67 569 : if(serial)log.printf(" doing calculation in serial\n");
68 1138 : if( keywords.exists("LOWMEM") ) {
69 1074 : plumed_assert( !keywords.exists("HIGHMEM") );
70 537 : parseFlag("LOWMEM",lowmem);
71 537 : if(lowmem) {
72 29 : log.printf(" lowering memory requirements\n");
73 29 : dertime_can_be_off=true;
74 : }
75 : }
76 1138 : if( keywords.exists("HIGHMEM") ) {
77 46 : plumed_assert( !keywords.exists("LOWMEM") );
78 23 : bool highmem; parseFlag("HIGHMEM",highmem);
79 23 : lowmem=!highmem;
80 23 : if(!lowmem) log.printf(" increasing the memory requirements\n");
81 : }
82 569 : tolerance=nl_tolerance=epsilon;
83 1596 : if( keywords.exists("TOL") ) parse("TOL",tolerance);
84 569 : if( tolerance>epsilon) {
85 6 : log.printf(" Ignoring contributions less than %f \n",tolerance);
86 : }
87 569 : parseFlag("TIMINGS",timers);
88 569 : stopwatch.start(); stopwatch.pause();
89 569 : }
90 :
91 569 : ActionWithVessel::~ActionWithVessel() {
92 569 : stopwatch.start(); stopwatch.stop();
93 569 : if(timers) {
94 0 : log.printf("timings for action %s with label %s \n", getName().c_str(), getLabel().c_str() );
95 0 : log<<stopwatch;
96 : }
97 569 : }
98 :
99 357 : void ActionWithVessel::addVessel( const std::string& name, const std::string& input, const int numlab ) {
100 357 : VesselOptions da(name,"",numlab,input,this);
101 714 : auto vv=vesselRegister().create(name,da);
102 357 : FunctionVessel* fv=dynamic_cast<FunctionVessel*>(vv.get());
103 357 : if( fv ) {
104 327 : std::string mylabel=Vessel::transformName( name );
105 327 : plumed_massert( keywords.outputComponentExists(mylabel,false), "a description of the value calculated by vessel " + name + " has not been added to the manual");
106 : }
107 357 : addVessel(std::move(vv));
108 357 : }
109 :
110 553 : void ActionWithVessel::addVessel( std::unique_ptr<Vessel> vv_ptr ) {
111 :
112 : // In the original code, the dynamically casted pointer was deleted here.
113 : // Now that vv_ptr is a unique_ptr, the object will be deleted automatically when
114 : // exiting this routine.
115 553 : if(dynamic_cast<ShortcutVessel*>(vv_ptr.get())) return;
116 :
117 542 : vv_ptr->checkRead();
118 :
119 542 : StoreDataVessel* mm=dynamic_cast<StoreDataVessel*>( vv_ptr.get() );
120 542 : if( mydata && mm ) error("cannot have more than one StoreDataVessel in one action");
121 542 : else if( mm ) mydata=mm;
122 411 : else dertime_can_be_off=false;
123 :
124 : // Ownership is transferred to functions
125 542 : functions.emplace_back(std::move(vv_ptr));
126 : }
127 :
128 46 : BridgeVessel* ActionWithVessel::addBridgingVessel( ActionWithVessel* tome ) {
129 92 : VesselOptions da("","",0,"",this);
130 46 : auto bv=Tools::make_unique<BridgeVessel>(da);
131 46 : bv->setOutputAction( tome );
132 46 : tome->actionIsBridged=true; dertime_can_be_off=false;
133 : // store this pointer in order to return it later.
134 : // notice that I cannot access this with functions.tail().get()
135 : // since functions contains pointers to a different class (Vessel)
136 : auto toBeReturned=bv.get();
137 46 : functions.emplace_back( std::move(bv) );
138 46 : resizeFunctions();
139 46 : return toBeReturned;
140 46 : }
141 :
142 192 : StoreDataVessel* ActionWithVessel::buildDataStashes( ActionWithVessel* actionThatUses ) {
143 192 : if(mydata) {
144 87 : if( actionThatUses ) mydata->addActionThatUses( actionThatUses );
145 87 : return mydata;
146 : }
147 :
148 210 : VesselOptions da("","",0,"",this);
149 105 : auto mm=Tools::make_unique<StoreDataVessel>(da);
150 105 : if( actionThatUses ) mm->addActionThatUses( actionThatUses );
151 105 : addVessel(std::move(mm));
152 :
153 : // Make sure resizing of vessels is done
154 105 : resizeFunctions();
155 :
156 105 : return mydata;
157 105 : }
158 :
159 12661273 : void ActionWithVessel::addTaskToList( const unsigned& taskCode ) {
160 12661273 : fullTaskList.push_back( taskCode ); taskFlags.push_back(0);
161 12661273 : plumed_assert( fullTaskList.size()==taskFlags.size() );
162 12661273 : }
163 :
164 416 : void ActionWithVessel::readVesselKeywords() {
165 : // Set maxderivatives if it is too big
166 416 : if( maxderivatives>getNumberOfDerivatives() ) maxderivatives=getNumberOfDerivatives();
167 :
168 : // Loop over all keywords find the vessels and create appropriate functions
169 9678 : for(unsigned i=0; i<keywords.size(); ++i) {
170 9262 : std::string thiskey,input; thiskey=keywords.getKeyword(i);
171 : // Check if this is a key for a vessel
172 18524 : if( vesselRegister().check(thiskey) ) {
173 5504 : plumed_assert( keywords.style(thiskey,"vessel") );
174 2752 : bool dothis=false; parseFlag(thiskey,dothis);
175 2752 : if(dothis) addVessel( thiskey, input );
176 :
177 2752 : parse(thiskey,input);
178 2752 : if(input.size()!=0) {
179 124 : addVessel( thiskey, input );
180 : } else {
181 2628 : for(unsigned i=1;; ++i) {
182 2653 : if( !parseNumbered(thiskey,i,input) ) break;
183 25 : std::string ss; Tools::convert(i,ss);
184 25 : addVessel( thiskey, input, i );
185 : input.clear();
186 25 : }
187 : }
188 : }
189 : }
190 :
191 : // Make sure all vessels have had been resized at start
192 416 : if( functions.size()>0 ) resizeFunctions();
193 416 : }
194 :
195 1374 : void ActionWithVessel::resizeFunctions() {
196 3545 : for(unsigned i=0; i<functions.size(); ++i) functions[i]->resize();
197 1374 : }
198 :
199 812 : void ActionWithVessel::needsDerivatives() {
200 : // Turn on the derivatives and resize
201 812 : noderiv=false; resizeFunctions();
202 : // Setting contributors unlocked here ensures that link cells are ignored
203 812 : contributorsAreUnlocked=true; contributorsAreUnlocked=false;
204 : // And turn on the derivatives in all actions on which we are dependent
205 1110 : for(unsigned i=0; i<getDependencies().size(); ++i) {
206 298 : ActionWithVessel* vv=dynamic_cast<ActionWithVessel*>( getDependencies()[i] );
207 298 : if(vv) vv->needsDerivatives();
208 : }
209 812 : }
210 :
211 3315 : void ActionWithVessel::lockContributors() {
212 3315 : nactive_tasks = 0;
213 18492694 : for(unsigned i=0; i<fullTaskList.size(); ++i) {
214 18489379 : if( taskFlags[i]>0 ) nactive_tasks++;
215 : }
216 :
217 : unsigned n=0;
218 3315 : partialTaskList.resize( nactive_tasks );
219 3315 : indexOfTaskInFullList.resize( nactive_tasks );
220 18492694 : for(unsigned i=0; i<fullTaskList.size(); ++i) {
221 : // Deactivate sets inactive tasks to number not equal to zero
222 18489379 : if( taskFlags[i]>0 ) {
223 5482378 : partialTaskList[n] = fullTaskList[i];
224 5482378 : indexOfTaskInFullList[n]=i;
225 5482378 : n++;
226 : }
227 : }
228 : plumed_dbg_assert( n==nactive_tasks );
229 8193 : for(unsigned i=0; i<functions.size(); ++i) {
230 4878 : BridgeVessel* bb = dynamic_cast<BridgeVessel*>( functions[i].get() );
231 4878 : if( bb ) bb->copyTaskFlags();
232 : }
233 : // Resize mydata to accommodate all active tasks
234 3315 : if( mydata ) mydata->resize();
235 3315 : contributorsAreUnlocked=false;
236 3315 : }
237 :
238 3315 : void ActionWithVessel::deactivateAllTasks() {
239 3315 : contributorsAreUnlocked=true; nactive_tasks = 0;
240 3315 : taskFlags.assign(taskFlags.size(),0);
241 3315 : }
242 :
243 213721 : bool ActionWithVessel::taskIsCurrentlyActive( const unsigned& index ) const {
244 213721 : plumed_dbg_assert( index<taskFlags.size() ); return (taskFlags[index]>0);
245 : }
246 :
247 22990 : void ActionWithVessel::doJobsRequiredBeforeTaskList() {
248 : // Do any preparatory stuff for functions
249 58257 : for(unsigned j=0; j<functions.size(); ++j) functions[j]->prepare();
250 22990 : }
251 :
252 23503 : unsigned ActionWithVessel::getSizeOfBuffer( unsigned& bufsize ) {
253 59446 : for(unsigned i=0; i<functions.size(); ++i) functions[i]->setBufferStart( bufsize );
254 23503 : if( buffer.size()!=bufsize ) buffer.resize( bufsize );
255 23503 : if( mydata ) {
256 : unsigned dsize=mydata->getSizeOfDerivativeList();
257 2554 : if( der_list.size()!=dsize ) der_list.resize( dsize );
258 : }
259 23503 : return bufsize;
260 : }
261 :
262 19965 : void ActionWithVessel::runAllTasks() {
263 19965 : plumed_massert( !contributorsAreUnlocked && functions.size()>0, "you must have a call to readVesselKeywords somewhere" );
264 19965 : unsigned stride=comm.Get_size();
265 19965 : unsigned rank=comm.Get_rank();
266 19965 : if(serial) { stride=1; rank=0; }
267 :
268 : // Make sure jobs are done
269 19965 : if(timers) stopwatch.start("1 Prepare Tasks");
270 19965 : doJobsRequiredBeforeTaskList();
271 19965 : if(timers) stopwatch.stop("1 Prepare Tasks");
272 :
273 : // Get number of threads for OpenMP
274 19965 : unsigned nt=OpenMP::getNumThreads();
275 19965 : if( nt*stride*2>nactive_tasks || !threadSafe()) nt=1;
276 :
277 : // Get size for buffer
278 19965 : unsigned bsize=0, bufsize=getSizeOfBuffer( bsize );
279 : // Clear buffer
280 19965 : buffer.assign( buffer.size(), 0.0 );
281 : // Switch off calculation of derivatives in main loop
282 19965 : if( dertime_can_be_off ) dertime=false;
283 :
284 19965 : if(timers) stopwatch.start("2 Loop over tasks");
285 19965 : #pragma omp parallel num_threads(nt)
286 : {
287 : std::vector<double> omp_buffer;
288 : if( nt>1 ) omp_buffer.resize( bufsize, 0.0 );
289 : MultiValue myvals( getNumberOfQuantities(), getNumberOfDerivatives() );
290 : MultiValue bvals( getNumberOfQuantities(), getNumberOfDerivatives() );
291 : myvals.clearAll(); bvals.clearAll();
292 :
293 : #pragma omp for nowait schedule(dynamic)
294 : for(unsigned i=rank; i<nactive_tasks; i+=stride) {
295 : // Calculate the stuff in the loop for this action
296 : performTask( indexOfTaskInFullList[i], partialTaskList[i], myvals );
297 :
298 : // Check for conditions that allow us to just to skip the calculation
299 : // the condition is that the weight of the contribution is low
300 : // N.B. Here weights are assumed to be between zero and one
301 : if( myvals.get(0)<tolerance ) {
302 : // Clear the derivatives
303 : myvals.clearAll();
304 : continue;
305 : }
306 :
307 : // Now calculate all the functions
308 : // If the contribution of this quantity is very small at neighbour list time ignore it
309 : // until next neighbour list time
310 : if( nt>1 ) {
311 : calculateAllVessels( indexOfTaskInFullList[i], myvals, bvals, omp_buffer, der_list );
312 : } else {
313 : calculateAllVessels( indexOfTaskInFullList[i], myvals, bvals, buffer, der_list );
314 : }
315 :
316 : // Clear the value
317 : myvals.clearAll();
318 : }
319 : #pragma omp critical
320 : if(nt>1) for(unsigned i=0; i<bufsize; ++i) buffer[i]+=omp_buffer[i];
321 : }
322 19965 : if(timers) stopwatch.stop("2 Loop over tasks");
323 : // Turn back on derivative calculation
324 19965 : dertime=true;
325 :
326 19965 : if(timers) stopwatch.start("3 MPI gather");
327 : // MPI Gather everything
328 19965 : if( !serial && buffer.size()>0 ) comm.Sum( buffer );
329 : // MPI Gather index stores
330 19965 : if( mydata && !lowmem && !noderiv ) {
331 690 : comm.Sum( der_list ); mydata->setActiveValsAndDerivatives( der_list );
332 : }
333 : // Update the elements that are makign contributions to the sum here
334 : // this causes problems if we do it in prepare
335 19965 : if(timers) stopwatch.stop("3 MPI gather");
336 :
337 19965 : if(timers) stopwatch.start("4 Finishing computations");
338 19965 : finishComputations( buffer );
339 19965 : if(timers) stopwatch.stop("4 Finishing computations");
340 19965 : }
341 :
342 0 : void ActionWithVessel::transformBridgedDerivatives( const unsigned& current, MultiValue& invals, MultiValue& outvals ) const {
343 0 : plumed_error();
344 : }
345 :
346 467098 : void ActionWithVessel::calculateAllVessels( const unsigned& taskCode, MultiValue& myvals, MultiValue& bvals, std::vector<double>& buffer, std::vector<unsigned>& der_list ) {
347 1078808 : for(unsigned j=0; j<functions.size(); ++j) {
348 : // Calculate returns a bool that tells us if this particular
349 : // quantity is contributing more than the tolerance
350 611710 : functions[j]->calculate( taskCode, functions[j]->transformDerivatives(taskCode, myvals, bvals), buffer, der_list );
351 611710 : if( !actionIsBridged ) bvals.clearAll();
352 : }
353 467098 : return;
354 : }
355 :
356 22990 : void ActionWithVessel::finishComputations( const std::vector<double>& buffer ) {
357 : // Set the final value of the function
358 58257 : for(unsigned j=0; j<functions.size(); ++j) functions[j]->finish( buffer );
359 22990 : }
360 :
361 7221 : bool ActionWithVessel::getForcesFromVessels( std::vector<double>& forcesToApply ) {
362 : #ifndef NDEBUG
363 : if( forcesToApply.size()>0 ) plumed_dbg_assert( forcesToApply.size()==getNumberOfDerivatives() );
364 : #endif
365 7221 : if(tmpforces.size()!=forcesToApply.size() ) tmpforces.resize( forcesToApply.size() );
366 :
367 7221 : forcesToApply.assign( forcesToApply.size(),0.0 );
368 : bool wasforced=false;
369 21107 : for(unsigned i=0; i<getNumberOfVessels(); ++i) {
370 13886 : if( (functions[i]->applyForce( tmpforces )) ) {
371 : wasforced=true;
372 524255 : for(unsigned j=0; j<forcesToApply.size(); ++j) forcesToApply[j]+=tmpforces[j];
373 : }
374 : }
375 7221 : return wasforced;
376 : }
377 :
378 0 : void ActionWithVessel::retrieveDomain( std::string& min, std::string& max ) {
379 0 : plumed_merror("If your function is periodic you need to add a retrieveDomain function so that ActionWithVessel can retrieve the domain");
380 : }
381 :
382 0 : Vessel* ActionWithVessel::getVesselWithName( const std::string& mynam ) {
383 : int target=-1;
384 0 : for(unsigned i=0; i<functions.size(); ++i) {
385 0 : if( functions[i]->getName().find(mynam)!=std::string::npos ) {
386 0 : if( target<0 ) target=i;
387 0 : else error("found more than one " + mynam + " object in action");
388 : }
389 : }
390 0 : plumed_assert(target>=0);
391 0 : return functions[target].get();
392 : }
393 :
394 : }
395 : }
|