LCOV - code coverage report
Current view: top level - matrixtools - MatrixTimesVector.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 141 179 78.8 %
Date: 2025-04-08 21:11:17 Functions: 10 12 83.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithMatrix.h"
      23             : #include "core/ActionRegister.h"
      24             : 
      25             : //+PLUMEDOC MCOLVAR MATRIX_VECTOR_PRODUCT
      26             : /*
      27             : Calculate the product of the matrix and the vector
      28             : 
      29             : This action allows you to [multiply](https://en.wikipedia.org/wiki/Matrix_multiplication) a matrix and a vector together.
      30             : This action is primarily used to calculate coordination numbers and symmetry functions, which is what is done by the example below:
      31             : 
      32             : ```plumed
      33             : c1: CONTACT_MATRIX GROUP=1-7 SWITCH={RATIONAL R_0=2.6 NN=6 MM=12}
      34             : ones: ONES SIZE=7
      35             : cc: MATRIX_VECTOR_PRODUCT ARG=c1,ones
      36             : PRINT ARG=cc FILE=colvar
      37             : ```
      38             : 
      39             : Notice that you can use this action to multiply multiple matrices by a single vector as shown below:
      40             : 
      41             : ```plumed
      42             : c1: CONTACT_MATRIX COMPONENTS GROUP=1-7 SWITCH={RATIONAL R_0=2.6 NN=6 MM=12 D_MAX=10.0}
      43             : ones: ONES SIZE=7
      44             : cc: MATRIX_VECTOR_PRODUCT ARG=c1.x,c1.y,c1.z,ones
      45             : PRINT ARG=cc.x,cc.y,cc.z FILE=colvar
      46             : ```
      47             : 
      48             : Notice that if you use this option, all the input matrices must have the same sparsity pattern.  This feature
      49             : was implemented to make calculating Steinhardt parameters such as [Q6](Q6.md) straightforward.
      50             : 
      51             : You can also multiply a single matrix by multiple vectors:
      52             : 
      53             : ```plumed
      54             : c1: CONTACT_MATRIX GROUP=1-7 SWITCH={RATIONAL R_0=2.6 NN=6 MM=12 D_MAX=10.0}
      55             : ones: ONES SIZE=7
      56             : twos: CONSTANT VALUES=1,2,3,4,5,6,7
      57             : cc: MATRIX_VECTOR_PRODUCT ARG=c1,ones,twos
      58             : PRINT ARG=cc.ones,cc.twos FILE=colvar
      59             : ```
      60             : 
      61             : This feature was implemented to make calculating local averages of the Steinhardt parameters straightforward.
      62             : 
      63             : */
      64             : //+ENDPLUMEDOC
      65             : 
      66             : namespace PLMD {
      67             : namespace matrixtools {
      68             : 
      69             : class MatrixTimesVector : public ActionWithMatrix {
      70             : private:
      71             :   bool sumrows;
      72             :   unsigned nderivatives;
      73             :   std::vector<bool> stored_arg;
      74             : public:
      75             :   static void registerKeywords( Keywords& keys );
      76             :   explicit MatrixTimesVector(const ActionOptions&);
      77             :   std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
      78           0 :   unsigned getNumberOfColumns() const override {
      79           0 :     plumed_error();
      80             :   }
      81             :   unsigned getNumberOfDerivatives();
      82             :   void prepare() override ;
      83        2151 :   bool isInSubChain( unsigned& nder ) override {
      84        2151 :     nder = arg_deriv_starts[0];
      85        2151 :     return true;
      86             :   }
      87             :   void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
      88             :   void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
      89             :   void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
      90             :   void updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const override ;
      91             : };
      92             : 
      93             : PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT")
      94             : 
      95         629 : void MatrixTimesVector::registerKeywords( Keywords& keys ) {
      96         629 :   ActionWithMatrix::registerKeywords(keys);
      97        1258 :   keys.addInputKeyword("compulsory","ARG","matrix/vector/scalar","the label for the matrix and the vector/scalar that are being multiplied.  Alternatively, you can provide labels for multiple matrices and a single vector or labels for a single matrix and multiple vectors. In these cases multiple matrix vector products will be computed.");
      98        1258 :   keys.setValueDescription("vector","the vector that is obtained by taking the product between the matrix and the vector that were input");
      99         629 :   ActionWithValue::useCustomisableComponents(keys);
     100         629 : }
     101             : 
     102           6 : std::string MatrixTimesVector::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
     103           6 :   if( getPntrToArgument(1)->getRank()==1 ) {
     104           0 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
     105           0 :       if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
     106           0 :         return "the product of the matrix " + getPntrToArgument(0)->getName() + " and the vector " + getPntrToArgument(i)->getName();
     107             :       }
     108             :     }
     109             :   }
     110          21 :   for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
     111          21 :     if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
     112          12 :       return "the product of the matrix " + getPntrToArgument(i)->getName() + " and the vector " + getPntrToArgument(getNumberOfArguments()-1)->getName();
     113             :     }
     114             :   }
     115           0 :   plumed_merror( "could not understand request for component " + cname );
     116             :   return "";
     117             : }
     118             : 
     119         352 : MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao):
     120             :   Action(ao),
     121             :   ActionWithMatrix(ao),
     122         352 :   sumrows(false) {
     123         352 :   if( getNumberOfArguments()<2 ) {
     124           0 :     error("Not enough arguments specified");
     125             :   }
     126             :   unsigned nvectors=0, nmatrices=0;
     127        1875 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     128        1523 :     if( getPntrToArgument(i)->hasDerivatives() ) {
     129           0 :       error("arguments should be vectors or matrices");
     130             :     }
     131        1523 :     if( getPntrToArgument(i)->getRank()<=1 ) {
     132         537 :       nvectors++;
     133             :     }
     134        1523 :     if( getPntrToArgument(i)->getRank()==2 ) {
     135         986 :       nmatrices++;
     136             :     }
     137             :   }
     138             : 
     139         352 :   std::vector<unsigned> shape(1);
     140         352 :   shape[0]=getPntrToArgument(0)->getShape()[0];
     141         352 :   if( nvectors==1 ) {
     142         343 :     unsigned n = getNumberOfArguments()-1;
     143        1320 :     for(unsigned i=0; i<n; ++i) {
     144         977 :       if( getPntrToArgument(i)->getRank()!=2 || getPntrToArgument(i)->hasDerivatives() ) {
     145           0 :         error("all arguments other than last argument should be matrices");
     146             :       }
     147         977 :       if( getPntrToArgument(n)->getRank()==0 ) {
     148           1 :         if( getPntrToArgument(i)->getShape()[1]!=1 ) {
     149           0 :           error("number of columns in input matrix does not equal number of elements in vector");
     150             :         }
     151         976 :       } else if( getPntrToArgument(i)->getShape()[1]!=getPntrToArgument(n)->getShape()[0] ) {
     152             :         std::string str_nmat, str_nvec;
     153           0 :         Tools::convert( getPntrToArgument(i)->getShape()[1], str_nmat);
     154           0 :         Tools::convert( getPntrToArgument(n)->getShape()[0], str_nvec );
     155           0 :         error("number of columns in input matrix is " + str_nmat + " which does not equal number of elements in vector, which is " + str_nvec);
     156             :       }
     157             :     }
     158         343 :     if( getPntrToArgument(n)->getRank()>0 ) {
     159         342 :       if( getPntrToArgument(n)->getRank()!=1 || getPntrToArgument(n)->hasDerivatives() ) {
     160           0 :         error("last argument to this action should be a vector");
     161             :       }
     162             :     }
     163         343 :     getPntrToArgument(n)->buildDataStore();
     164             : 
     165         343 :     ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
     166         343 :     if( av ) {
     167         314 :       done_in_chain=canBeAfterInChain( av );
     168             :     }
     169             : 
     170         343 :     if( getNumberOfArguments()==2 ) {
     171         301 :       addValue( shape );
     172         301 :       setNotPeriodic();
     173             :     } else {
     174         718 :       for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
     175         676 :         std::string name = getPntrToArgument(i)->getName();
     176         676 :         if( name.find_first_of(".")!=std::string::npos ) {
     177         676 :           std::size_t dot=name.find_first_of(".");
     178        1352 :           name = name.substr(dot+1);
     179             :         }
     180         676 :         addComponent( name, shape );
     181         676 :         componentIsNotPeriodic( name );
     182             :       }
     183             :     }
     184         343 :     if( (getPntrToArgument(n)->getPntrToAction())->getName()=="CONSTANT" ) {
     185         306 :       sumrows=true;
     186         306 :       if( getPntrToArgument(n)->getRank()==0 ) {
     187           1 :         if( fabs( getPntrToArgument(n)->get() - 1.0 )>epsilon ) {
     188           0 :           sumrows = false;
     189             :         }
     190             :       } else {
     191      180438 :         for(unsigned i=0; i<getPntrToArgument(n)->getShape()[0]; ++i) {
     192      180141 :           if( fabs( getPntrToArgument(n)->get(i) - 1.0 )>epsilon ) {
     193           8 :             sumrows=false;
     194           8 :             break;
     195             :           }
     196             :         }
     197             :       }
     198             :     }
     199           9 :   } else if( nmatrices==1 ) {
     200           9 :     if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) {
     201           0 :       error("first argument to this action should be a matrix");
     202             :     }
     203         203 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
     204         194 :       if( getPntrToArgument(i)->getRank()>1 || getPntrToArgument(i)->hasDerivatives() ) {
     205           0 :         error("all arguments other than first argument should be vectors");
     206             :       }
     207         194 :       if( getPntrToArgument(i)->getRank()==0 ) {
     208           0 :         if( getPntrToArgument(0)->getShape()[1]!=1 ) {
     209           0 :           error("number of columns in input matrix does not equal number of elements in vector");
     210             :         }
     211         194 :       } else if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(i)->getShape()[0] ) {
     212           0 :         error("number of columns in input matrix does not equal number of elements in vector");
     213             :       }
     214         194 :       getPntrToArgument(i)->buildDataStore();
     215             :     }
     216             : 
     217           9 :     ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
     218           9 :     if( av ) {
     219           9 :       done_in_chain=canBeAfterInChain( av );
     220             :     }
     221             : 
     222         203 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
     223         194 :       std::string name = getPntrToArgument(i)->getName();
     224         194 :       if( name.find_first_of(".")!=std::string::npos ) {
     225           0 :         std::size_t dot=name.find_first_of(".");
     226           0 :         name = name.substr(dot+1);
     227             :       }
     228         194 :       addComponent( name, shape );
     229         194 :       componentIsNotPeriodic( name );
     230             :     }
     231             :   } else {
     232           0 :     error("You should either have one vector or one matrix in input");
     233             :   }
     234             : 
     235         352 :   nderivatives = buildArgumentStore(0);
     236         352 :   std::string headstr=getFirstActionInChain()->getLabel();
     237         352 :   stored_arg.resize( getNumberOfArguments() );
     238        1875 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     239        1523 :     stored_arg[i] = getPntrToArgument(i)->ignoreStoredValue( headstr );
     240             :   }
     241         352 : }
     242             : 
     243       31643 : unsigned MatrixTimesVector::getNumberOfDerivatives() {
     244       31643 :   return nderivatives;
     245             : }
     246             : 
     247       13575 : void MatrixTimesVector::prepare() {
     248       13575 :   ActionWithVector::prepare();
     249       13575 :   Value* myval = getPntrToComponent(0);
     250       13575 :   if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] ) {
     251       13565 :     return;
     252             :   }
     253          10 :   std::vector<unsigned> shape(1);
     254          10 :   shape[0] = getPntrToArgument(0)->getShape()[0];
     255          10 :   myval->setShape(shape);
     256             : }
     257             : 
     258        6574 : void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
     259        6574 :   unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
     260        6574 :   if( indices.size()!=size_v+1 ) {
     261        3508 :     indices.resize( size_v + 1 );
     262             :   }
     263      842762 :   for(unsigned i=0; i<size_v; ++i) {
     264      836188 :     indices[i+1] = start_n + i;
     265             :   }
     266             :   myvals.setSplitIndex( size_v + 1 );
     267        6574 : }
     268             : 
     269    23970940 : void MatrixTimesVector::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
     270    23970940 :   unsigned ind2 = index2;
     271    23970940 :   if( index2>=getPntrToArgument(0)->getShape()[0] ) {
     272     1600742 :     ind2 = index2 - getPntrToArgument(0)->getShape()[0];
     273             :   }
     274    23970940 :   if( sumrows ) {
     275    22303792 :     unsigned n=getNumberOfArguments()-1;
     276             :     double matval = 0;
     277    87441027 :     for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
     278    65137235 :       unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
     279             :       Value* myarg = getPntrToArgument(i);
     280    65137235 :       if( !myarg->valueHasBeenSet() ) {
     281    65122517 :         myvals.addValue( ostrn, myvals.get( myarg->getPositionInStream() ) );
     282             :       } else {
     283       14718 :         myvals.addValue( ostrn, myarg->get( index1*myarg->getNumberOfColumns() + ind2, false ) );
     284             :       }
     285             :       // Now lets work out the derivatives
     286    65137235 :       if( doNotCalculateDerivatives() ) {
     287    32313889 :         continue;
     288             :       }
     289    32823346 :       addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, 1.0, myvals );
     290             :     }
     291     1667148 :   } else if( getPntrToArgument(1)->getRank()==1 ) {
     292             :     double matval = 0;
     293             :     Value* myarg = getPntrToArgument(0);
     294     1667148 :     unsigned vcol = ind2;
     295     1667148 :     if( !myarg->valueHasBeenSet() ) {
     296      840110 :       matval = myvals.get( myarg->getPositionInStream() );
     297             :     } else {
     298      827038 :       matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
     299      827038 :       vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 );
     300             :     }
     301    18356786 :     for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
     302    16689638 :       unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
     303    16689638 :       double vecval=getArgumentElement( i+1, vcol, myvals );
     304             :       // And add this part of the product
     305    16689638 :       myvals.addValue( ostrn, matval*vecval );
     306             :       // Now lets work out the derivatives
     307    16689638 :       if( doNotCalculateDerivatives() ) {
     308     1000870 :         continue;
     309             :       }
     310    15688768 :       addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals );
     311    15688768 :       addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals );
     312             :     }
     313             :   } else {
     314           0 :     unsigned n=getNumberOfArguments()-1;
     315           0 :     double matval = 0;
     316           0 :     unsigned vcol = ind2;
     317           0 :     for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
     318           0 :       unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
     319             :       Value* myarg = getPntrToArgument(i);
     320           0 :       if( !myarg->valueHasBeenSet() ) {
     321           0 :         matval = myvals.get( myarg->getPositionInStream() );
     322             :       } else {
     323           0 :         matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
     324           0 :         vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 );
     325             :       }
     326           0 :       double vecval=getArgumentElement( n, vcol, myvals );
     327             :       // And add this part of the product
     328           0 :       myvals.addValue( ostrn, matval*vecval );
     329             :       // Now lets work out the derivatives
     330           0 :       if( doNotCalculateDerivatives() ) {
     331           0 :         continue;
     332             :       }
     333           0 :       addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals );
     334           0 :       addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals );
     335             :     }
     336             :   }
     337    23970940 : }
     338             : 
     339      472445 : void MatrixTimesVector::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
     340      472445 :   if( doNotCalculateDerivatives() || !actionInChain() ) {
     341             :     return ;
     342             :   }
     343             : 
     344      358714 :   if( getPntrToArgument(1)->getRank()==1 ) {
     345             :     unsigned istrn = getPntrToArgument(0)->getPositionInMatrixStash();
     346             :     std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
     347     1010565 :     for(unsigned j=0; j<getNumberOfComponents(); ++j) {
     348      671975 :       unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
     349    40971258 :       for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) {
     350    40299283 :         myvals.updateIndex( ostrn, mat_indices[i] );
     351             :       }
     352             :     }
     353             :   } else {
     354      530036 :     for(unsigned j=0; j<getNumberOfComponents(); ++j) {
     355             :       unsigned istrn = getPntrToArgument(j)->getPositionInMatrixStash();
     356      509912 :       unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
     357             :       std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
     358    17456348 :       for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) {
     359    16946436 :         myvals.updateIndex( ostrn, mat_indices[i] );
     360             :       }
     361             :     }
     362             :   }
     363             : }
     364             : 
     365      372677 : void MatrixTimesVector::updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const {
     366      372677 :   unsigned n = getNumberOfArguments()-1;
     367      372677 :   if( getPntrToArgument(1)->getRank()==1 ) {
     368             :     n = 1;
     369             :   }
     370      372677 :   unsigned nvals = getPntrToArgument(n)->getNumberOfValues();
     371  1387754027 :   for(unsigned i=0; i<nvals; ++i) {
     372  1387381350 :     myvals.updateIndex( ostrn, arg_deriv_starts[n] + i );
     373             :   }
     374      372677 : }
     375             : 
     376             : }
     377             : }

Generated by: LCOV version 1.16