LCOV - code coverage report
Current view: top level - function - FunctionOfMatrix.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 201 217 92.6 %
Date: 2024-10-18 13:59:31 Functions: 85 108 78.7 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2020 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_function_FunctionOfMatrix_h
      23             : #define __PLUMED_function_FunctionOfMatrix_h
      24             : 
      25             : #include "core/ActionWithMatrix.h"
      26             : #include "FunctionOfVector.h"
      27             : #include "Sum.h"
      28             : #include "tools/Matrix.h"
      29             : 
      30             : namespace PLMD {
      31             : namespace function {
      32             : 
      33             : template <class T>
      34             : class FunctionOfMatrix : public ActionWithMatrix {
      35             : private:
      36             : /// Is this the first step of the calculation
      37             :   bool firststep;
      38             : /// The function that is being computed
      39             :   T myfunc;
      40             : /// The number of derivatives for this action
      41             :   unsigned nderivatives;
      42             : /// A vector that tells us if we have stored the input value
      43             :   std::vector<bool> stored_arguments;
      44             : /// Switch off updating the arguments for this action
      45             :   std::vector<bool> update_arguments;
      46             : /// The list of actiosn in this chain
      47             :   std::vector<std::string> actionsLabelsInChain;
      48             : /// Get the shape of the output matrix
      49             :   std::vector<unsigned> getValueShapeFromArguments();
      50             : public:
      51             :   static void registerKeywords(Keywords&);
      52             :   explicit FunctionOfMatrix(const ActionOptions&);
      53             : /// Get the label to write in the graph
      54           0 :   std::string writeInGraph() const override { return myfunc.getGraphInfo( getName() ); }
      55             : /// Make sure the derivatives are turned on
      56             :   void turnOnDerivatives() override;
      57             : /// Get the number of derivatives for this action
      58             :   unsigned getNumberOfDerivatives() override ;
      59             : /// Resize the matrices
      60             :   void prepare() override ;
      61             : /// This gets the number of columns
      62             :   unsigned getNumberOfColumns() const override ;
      63             : /// This checks for tasks in the parent class
      64             : //  void buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) override ;
      65             : /// This ensures that we create some bookeeping stuff during the first step
      66             :   void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override ;
      67             : /// This sets up for the task
      68             :   void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
      69             : /// Calculate the full matrix
      70             :   void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override ;
      71             : /// This updates the indices for the matrix
      72             : //  void updateCentralMatrixIndex( const unsigned& ind, const std::vector<unsigned>& indices, MultiValue& myvals ) const override ;
      73             :   void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
      74             : };
      75             : 
      76             : template <class T>
      77        1013 : void FunctionOfMatrix<T>::registerKeywords(Keywords& keys ) {
      78        1013 :   ActionWithMatrix::registerKeywords(keys); std::string name = keys.getDisplayName();
      79        1013 :   std::size_t und=name.find("_MATRIX"); keys.setDisplayName( name.substr(0,und) );
      80        2026 :   keys.addInputKeyword("compulsory","ARG","scalar/matrix","the labels of the scalar and matrices that on which the function is being calculated elementwise");
      81        2026 :   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
      82        2026 :   keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function.  If the output is not periodic you must state this using PERIODIC=NO");
      83        1013 :   T tfunc; tfunc.registerKeywords( keys );
      84        2026 :   if( keys.getDisplayName()=="SUM" ) {
      85         168 :     keys.setValueDescription("scalar","the sum of all the elements in the input matrix");
      86        1858 :   } else if( keys.getDisplayName()=="HIGHEST" ) {
      87           0 :     keys.setValueDescription("scalar","the largest element of the input matrix");
      88        1858 :   } else if( keys.getDisplayName()=="LOWEST" ) {
      89           0 :     keys.setValueDescription("scalar","the smallest element in the input matrix");
      90        1858 :   } else if( keys.outputComponentExists(".#!value") ) {
      91        1672 :     keys.setValueDescription("matrix","the matrix obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input matrix");
      92             :   }
      93        1899 : }
      94             : 
      95             : template <class T>
      96         495 : FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao):
      97             :   Action(ao),
      98             :   ActionWithMatrix(ao),
      99         495 :   firststep(true)
     100             : {
     101         451 :   if( myfunc.getArgStart()>0 ) error("this has not beeen implemented -- if you are interested email gareth.tribello@gmail.com");
     102             :   // Get the shape of the output
     103         495 :   std::vector<unsigned> shape( getValueShapeFromArguments() );
     104             :   // Check if the output matrix is symmetric
     105         495 :   bool symmetric=true; unsigned argstart=myfunc.getArgStart();
     106        1508 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     107        1013 :     if( getPntrToArgument(i)->getRank()==2 ) {
     108         948 :       if( !getPntrToArgument(i)->isSymmetric() ) { symmetric=false;  }
     109             :     }
     110             :   }
     111             :   // Read the input and do some checks
     112         495 :   myfunc.read( this );
     113             :   // Setup to do this in chain if possible
     114         495 :   if( myfunc.doWithTasks() ) done_in_chain=true;
     115             :   // Check we are not calculating a sum
     116          41 :   if( myfunc.zeroRank() ) shape.resize(0);
     117             :   // Get the names of the components
     118         495 :   std::vector<std::string> components( keywords.getOutputComponents() );
     119             :   // Create the values to hold the output
     120          42 :   std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
     121        1034 :   for(unsigned i=0; i<components.size(); ++i) {
     122          84 :     if( str_ind.size()>0 ) {
     123         168 :       std::string compstr = components[i]; if( components[i]==".#!value" ) compstr = "";
     124         760 :       for(unsigned j=0; j<str_ind.size(); ++j) {
     125             :         if( myfunc.zeroRank() ) {
     126             :           addComponentWithDerivatives( compstr + str_ind[j], shape );
     127             :         } else {
     128        1352 :           addComponent( compstr + str_ind[j], shape );
     129         676 :           getPntrToComponent(i*str_ind.size()+j)->setSymmetric( symmetric );
     130             :         }
     131             :       }
     132          41 :     } else if( components[i]==".#!value" && myfunc.zeroRank() ) {
     133          41 :       addValueWithDerivatives( shape );
     134         414 :     } else if( components[i]==".#!value" ) {
     135         410 :       addValue( shape ); getPntrToComponent(0)->setSymmetric( symmetric );
     136           4 :     } else if( components[i].find_first_of("_")!=std::string::npos ) {
     137           0 :       if( getNumberOfArguments()-argstart==1 ) { addValue( shape ); getPntrToComponent(0)->setSymmetric( symmetric ); }
     138             :       else {
     139           0 :         for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     140           0 :           addComponent( getPntrToArgument(j)->getName() + components[i], shape );
     141           0 :           getPntrToComponent(i*(getNumberOfArguments()-argstart)+j-argstart)->setSymmetric( symmetric );
     142             :         }
     143             :       }
     144           4 :     } else { addComponent( components[i], shape ); getPntrToComponent(i)->setSymmetric( symmetric ); }
     145             :   }
     146             :   // Check if this can be sped up
     147         370 :   if( myfunc.getDerivativeZeroIfValueIsZero() )  {
     148         174 :     for(int i=0; i<getNumberOfComponents(); ++i) getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
     149             :   }
     150             :   // Set the periodicities of the output components
     151         495 :   myfunc.setPeriodicityForOutputs( this );
     152             :   // We can't do this with if we are dividing a stack by some a product v.v^T product as we need to store the vector
     153             :   // In order to do this type of calculation.  There should be a neater fix than this but I can't see it.
     154             :   bool foundneigh=false; const ActionWithMatrix* chainstart = NULL;
     155        1503 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     156        1207 :     if( getPntrToArgument(i)->isConstant() && getNumberOfArguments()>1 ) continue ;
     157         932 :     std::string argname=(getPntrToArgument(i)->getPntrToAction())->getName();
     158         932 :     if( argname=="NEIGHBORS" ) { foundneigh=true; break; }
     159         929 :     ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
     160         929 :     if( !av ) done_in_chain=false;
     161         929 :     if( getPntrToArgument(i)->getRank()==0 ) {
     162           0 :       function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() );
     163           0 :       if(as) done_in_chain=false;
     164         929 :     } else if( getPntrToArgument(i)->ignoreStoredValue( getLabel() ) ) {
     165             :       // This option deals with the case when you have two adjacency matrices, A_ij and B_ij, multiplied together.  This cannot be done in the chain as the rows
     166             :       // of the two adjacency matrix are run over separately.  The value A_ij is thus not available when B_ij is calculated.
     167         853 :       ActionWithMatrix* am = dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
     168         853 :       plumed_assert( am ); const ActionWithMatrix* thischain = am->getFirstMatrixInChain();
     169         853 :       if( !thischain->isAdjacencyMatrix() && thischain->getName()!="VSTACK" ) continue;
     170         657 :       if( !chainstart ) chainstart = thischain;
     171         317 :       else if( thischain!=chainstart ) done_in_chain=false;
     172             :     }
     173             :   }
     174             :   // If we are working with neighbors we trick PLUMED into storing ALL the components of the other arguments
     175             :   // in this way we can ensure that the function of the neighbours matrix is in a chain starting from the
     176             :   // Neighbours matrix action.
     177             :   if( foundneigh ) {
     178           9 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     179           6 :       ActionWithValue* av=getPntrToArgument(i)->getPntrToAction();
     180           6 :       if( av->getName()!="NEIGHBORS" ) {
     181           8 :         for(int i=0; i<av->getNumberOfComponents(); ++i) (av->copyOutput(i))->buildDataStore();
     182             :       }
     183             :     }
     184             :   }
     185             :   // Now setup the action in the chain if we can
     186         495 :   nderivatives = buildArgumentStore(myfunc.getArgStart());
     187         990 : }
     188             : 
     189             : template <class T>
     190        1921 : void FunctionOfMatrix<T>::turnOnDerivatives() {
     191        1921 :   if( !myfunc.derivativesImplemented() ) error("derivatives have not been implemended for " + getName() );
     192        1921 :   ActionWithValue::turnOnDerivatives(); myfunc.setup(this);
     193        1921 : }
     194             : 
     195             : template <class T>
     196       30411 : unsigned FunctionOfMatrix<T>::getNumberOfDerivatives() {
     197       30411 :   return nderivatives;
     198             : }
     199             : 
     200             : template <class T>
     201        2229 : void FunctionOfMatrix<T>::prepare() {
     202        2229 :   unsigned argstart = myfunc.getArgStart(); std::vector<unsigned> shape(2);
     203        2229 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     204        2229 :     if( getPntrToArgument(i)->getRank()==2 ) {
     205        2229 :       shape[0] = getPntrToArgument(i)->getShape()[0];
     206        2229 :       shape[1] = getPntrToArgument(i)->getShape()[1];
     207        2229 :       break;
     208             :     }
     209             :   }
     210        6682 :   for(unsigned i=0; i<getNumberOfComponents(); ++i) {
     211        4453 :     Value* myval = getPntrToComponent(i);
     212        4453 :     if( myval->getRank()==2 && (myval->getShape()[0]!=shape[0] || myval->getShape()[1]!=shape[1]) ) {
     213          18 :       myval->setShape(shape); if( myval->valueIsStored() ) myval->reshapeMatrixStore( shape[1] );
     214             :     }
     215             :   }
     216        2229 :   ActionWithVector::prepare();
     217        2229 : }
     218             : 
     219             : template <class T>
     220      281844 : unsigned FunctionOfMatrix<T>::getNumberOfColumns() const {
     221      281844 :   if( getConstPntrToComponent(0)->getRank()==2 ) {
     222             :     unsigned argstart=myfunc.getArgStart();
     223      281844 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     224      281844 :       if( getPntrToArgument(i)->getRank()==2 ) {
     225      281844 :         ActionWithMatrix* am=dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
     226      281844 :         if( am ) return am->getNumberOfColumns();
     227        2238 :         return getPntrToArgument(i)->getShape()[1];
     228             :       }
     229             :     }
     230             :   }
     231           0 :   plumed_error(); return 0;
     232             : }
     233             : 
     234             : template <class T>
     235        4209 : void FunctionOfMatrix<T>::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
     236       11667 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) plumed_assert( getPntrToArgument(i)->getRank()==2 );
     237        4209 :   unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getShape()[1];
     238        4209 :   if( indices.size()!=size_v+1 ) indices.resize( size_v+1 );
     239      642613 :   for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
     240             :   myvals.setSplitIndex( size_v + 1 );
     241        4209 : }
     242             : 
     243             : // template <class T>
     244             : // void FunctionOfMatrix<T>::buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) {
     245             : //   // Check if this is the first element in a chain
     246             : //   if( actionInChain() ) return;
     247             : //   // If it is computed outside a chain get the tassks the daughter chain needs
     248             : //   propegateTaskListsForValue( 0, ntasks, reduce, otasks );
     249             : // }
     250             : 
     251             : template <class T>
     252        2525 : void FunctionOfMatrix<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) {
     253        2525 :   if( firststep ) {
     254         489 :     stored_arguments.resize( getNumberOfArguments() );
     255         489 :     update_arguments.resize( getNumberOfArguments(), true );
     256         489 :     std::string control = getFirstActionInChain()->getLabel();
     257        1484 :     for(unsigned i=0; i<stored_arguments.size(); ++i) {
     258         995 :       stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
     259         995 :       if( !stored_arguments[i] ) update_arguments[i] = true;
     260         164 :       else update_arguments[i] = !argumentDependsOn( headstr, this, getPntrToArgument(i) );
     261             :     }
     262         489 :     firststep=false;
     263             :   }
     264        2525 :   ActionWithMatrix::setupStreamedComponents( headstr, nquants, nmat, maxcol, nbookeeping );
     265        2525 : }
     266             : 
     267             : template <class T>
     268    27928651 : void FunctionOfMatrix<T>::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
     269    27928651 :   unsigned argstart=myfunc.getArgStart(); std::vector<double> args( getNumberOfArguments() - argstart );
     270    27928651 :   unsigned ind2 = index2;
     271    27928651 :   if( getConstPntrToComponent(0)->getRank()==2 && index2>=getConstPntrToComponent(0)->getShape()[0] ) ind2 = index2 - getConstPntrToComponent(0)->getShape()[0];
     272    24292268 :   else if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
     273    27928651 :   if( actionInChain() ) {
     274    85619946 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     275    58329699 :       if( getPntrToArgument(i)->getRank()==0 ) args[i-argstart] = getPntrToArgument(i)->get();
     276    58193979 :       else if( !getPntrToArgument(i)->valueHasBeenSet() ) args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
     277     1188593 :       else args[i-argstart] = getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
     278             :     }
     279             :   } else {
     280     1727072 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     281     1088668 :       if( getPntrToArgument(i)->getRank()==2 ) args[i-argstart]=getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
     282           0 :       else args[i-argstart] = getPntrToArgument(i)->get();
     283             :     }
     284             :   }
     285             :   // Calculate the function and its derivatives
     286    27928651 :   std::vector<double> vals( getNumberOfComponents() ); Matrix<double> derivatives( getNumberOfComponents(), getNumberOfArguments()-argstart );
     287    27928651 :   myfunc.calc( this, args, vals, derivatives );
     288             :   // And set the values
     289    99634355 :   for(unsigned i=0; i<vals.size(); ++i) myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
     290             :   // Return if we are not computing derivatives
     291    27928651 :   if( doNotCalculateDerivatives() ) return;
     292             : 
     293     5399619 :   if( actionInChain() ) {
     294    33385311 :     for(int i=0; i<getNumberOfComponents(); ++i) {
     295    27990552 :       unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     296   131996523 :       for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     297   104005971 :         if( getPntrToArgument(j)->getRank()==2 ) {
     298   103890411 :           unsigned istrn = getPntrToArgument(j)->getPositionInStream();
     299   103890411 :           if( stored_arguments[j] ) {
     300      395048 :             unsigned task_index = getPntrToArgument(i)->getShape()[1]*index1 + ind2;
     301      395048 :             myvals.clearDerivatives(istrn); myvals.addDerivative( istrn, task_index, 1.0 ); myvals.updateIndex( istrn, task_index );
     302             :           }
     303   470717695 :           for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     304   366827284 :             unsigned kind=myvals.getActiveIndex(istrn,k);
     305   366827284 :             myvals.addDerivative( ostrn, arg_deriv_starts[j] + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
     306             :           }
     307             :         }
     308             :       }
     309             :     }
     310             :     // If we are computing a matrix we need to update the indices here so that derivatives are calcualted correctly in functions of these
     311     5394759 :     if( getConstPntrToComponent(0)->getRank()==2 ) {
     312    32784527 :       for(int i=0; i<getNumberOfComponents(); ++i) {
     313    27690160 :         unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     314   131395739 :         for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     315   103705579 :           if( !update_arguments[j] || getPntrToArgument(j)->getRank()==0 ) continue ;
     316             :           // Ensure we only store one lot of derivative indices
     317             :           bool found=false;
     318   105009550 :           for(unsigned k=0; k<j; ++k) {
     319    76601003 :             if( arg_deriv_starts[k]==arg_deriv_starts[j] ) { found=true; break; }
     320             :           }
     321   103589995 :           if( found ) continue;
     322             :           unsigned istrn = getPntrToArgument(j)->getPositionInStream();
     323   138447375 :           for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     324   110038828 :             unsigned kind=myvals.getActiveIndex(istrn,k);
     325   110038828 :             myvals.updateIndex( ostrn, arg_deriv_starts[j] + kind );
     326             :           }
     327             :         }
     328             :       }
     329             :     }
     330             :   } else {
     331        4860 :     unsigned base=0; ind2 = index2;
     332        4860 :     for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     333        4860 :       if( getPntrToArgument(j)->getRank()!=2 ) continue ;
     334        4860 :       if( index2>=getPntrToArgument(j)->getShape()[0] ) ind2 = index2 - getPntrToArgument(j)->getShape()[0];
     335             :       break;
     336             :     }
     337       13965 :     for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     338        9105 :       if( getPntrToArgument(j)->getRank()==2 ) {
     339       18210 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     340        9105 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     341        9105 :           unsigned myind = base + getPntrToArgument(j)->getShape()[1]*index1 + ind2;
     342        9105 :           myvals.addDerivative( ostrn, myind, derivatives(i,j) );
     343        9105 :           myvals.updateIndex( ostrn, myind );
     344             :         }
     345             :       } else {
     346           0 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     347           0 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     348           0 :           myvals.addDerivative( ostrn, base, derivatives(i,j) );
     349           0 :           myvals.updateIndex( ostrn, base );
     350             :         }
     351             :       }
     352        9105 :       base += getPntrToArgument(j)->getNumberOfValues();
     353             :     }
     354             :   }
     355             : }
     356             : 
     357             : template <class T>
     358      226389 : void FunctionOfMatrix<T>::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
     359      226389 :   if( doNotCalculateDerivatives() ) return;
     360             : 
     361             :   unsigned argstart=myfunc.getArgStart();
     362       71237 :   if( actionInChain() && getConstPntrToComponent(0)->getRank()==2 ) {
     363             :     // This is triggered if we are outputting a matrix
     364      624578 :     for(int vv=0; vv<getNumberOfComponents(); ++vv) {
     365      558183 :       unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
     366             :       std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); unsigned ntot_mat=0;
     367      558183 :       if( mat_indices.size()<nderivatives ) mat_indices.resize( nderivatives );
     368     2701544 :       for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     369     2143361 :         if( !update_arguments[i] || getPntrToArgument(i)->getRank()==0 ) continue ;
     370             :         // Ensure we only store one lot of derivative indices
     371             :         bool found=false;
     372     2168017 :         for(unsigned j=0; j<i; ++j) {
     373     1591516 :           if( arg_deriv_starts[j]==arg_deriv_starts[i] ) { found=true; break; }
     374             :         }
     375     2142277 :         if( found ) continue;
     376             : 
     377      576501 :         if( stored_arguments[i] ) {
     378       15483 :           unsigned tbase = getPntrToArgument(i)->getShape()[1]*ind;
     379      410507 :           for(unsigned k=1; k<indices.size(); ++k) {
     380      395024 :             unsigned ind2 = indices[k] - getConstPntrToComponent(0)->getShape()[0];
     381      395024 :             mat_indices[ntot_mat + k - 1] = arg_deriv_starts[i] + tbase + ind2;
     382             :           }
     383       15483 :           ntot_mat += indices.size()-1;
     384             :         } else {
     385             :           unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
     386             :           std::vector<unsigned>& imat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
     387    31848426 :           for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) mat_indices[ntot_mat + k] = arg_deriv_starts[i] + imat_indices[k];
     388      561018 :           ntot_mat += myvals.getNumberOfMatrixRowDerivatives( istrn );
     389             :         }
     390             :       }
     391             :       myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
     392             :     }
     393        4842 :   } else if( actionInChain() ) {
     394             :     // This is triggered if we are calculating a single scalar in the function
     395        8822 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     396             :       bool found=false;
     397        4411 :       for(unsigned j=0; j<i; ++j) {
     398           0 :         if( arg_deriv_starts[j]==arg_deriv_starts[i] ) { found=true; break; }
     399             :       }
     400        4411 :       if( found ) continue;
     401             :       unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
     402             :       std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
     403      926766 :       for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) {
     404     1844710 :         for(int j=0; j<getNumberOfComponents(); ++j) {
     405      922355 :           unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
     406      922355 :           myvals.updateIndex( ostrn, arg_deriv_starts[i] + mat_indices[k] );
     407             :         }
     408             :       }
     409             :     }
     410         431 :   } else if( getConstPntrToComponent(0)->getRank()==2 ) {
     411         760 :     for(int vv=0; vv<getNumberOfComponents(); ++vv) {
     412         380 :       unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
     413             :       std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); unsigned ntot_mat=0;
     414         380 :       if( mat_indices.size()<nderivatives ) mat_indices.resize( nderivatives ); unsigned matderbase = 0;
     415         986 :       for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     416         606 :         if( getPntrToArgument(i)->getRank()==0 ) continue ;
     417         606 :         unsigned ss = getPntrToArgument(i)->getShape()[1]; unsigned tbase = matderbase + ss*myvals.getTaskIndex();
     418        9558 :         for(unsigned k=0; k<ss; ++k) mat_indices[ntot_mat + k] = tbase + k;
     419         606 :         ntot_mat += ss; matderbase += getPntrToArgument(i)->getNumberOfValues();
     420             :       }
     421             :       myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
     422             :     }
     423             :   }
     424             : }
     425             : 
     426             : template <class T>
     427         495 : std::vector<unsigned> FunctionOfMatrix<T>::getValueShapeFromArguments() {
     428         495 :   unsigned argstart=myfunc.getArgStart(); std::vector<unsigned> shape(2); shape[0]=shape[1]=0;
     429        1508 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     430        1013 :     plumed_assert( getPntrToArgument(i)->getRank()==2 || getPntrToArgument(i)->getRank()==0 );
     431        1013 :     if( getPntrToArgument(i)->getRank()==2 ) {
     432         948 :       if( shape[0]>0 && (getPntrToArgument(i)->getShape()[0]!=shape[0] || getPntrToArgument(i)->getShape()[1]!=shape[1]) ) error("all matrices input should have the same shape");
     433         948 :       else if( shape[0]==0 ) { shape[0]=getPntrToArgument(i)->getShape()[0]; shape[1]=getPntrToArgument(i)->getShape()[1]; }
     434         948 :       plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
     435             :     }
     436             :   }
     437         495 :   myfunc.setPrefactor( this, 1.0 ); return shape;
     438             : }
     439             : 
     440             : }
     441             : }
     442             : #endif

Generated by: LCOV version 1.16