LCOV - code coverage report
Current view: top level - refdist - MatrixProductDiagonal.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 70 80 87.5 %
Date: 2025-04-08 21:11:17 Functions: 5 6 83.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithVector.h"
      23             : #include "core/ActionRegister.h"
      24             : 
      25             : //+PLUMEDOC FUNCTION MATRIX_PRODUCT_DIAGONAL
      26             : /*
      27             : Calculate the product of two matrices and return a vector that contains the diagonal elements of the output matrix
      28             : 
      29             : \par Examples
      30             : 
      31             : */
      32             : //+ENDPLUMEDOC
      33             : 
      34             : namespace PLMD {
      35             : namespace refdist {
      36             : 
      37             : class MatrixProductDiagonal : public ActionWithVector {
      38             : private:
      39             : public:
      40             :   static void registerKeywords( Keywords& keys );
      41             :   explicit MatrixProductDiagonal(const ActionOptions&);
      42             :   unsigned getNumberOfDerivatives() override ;
      43             :   void calculate() override ;
      44             :   void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
      45             : };
      46             : 
      47             : PLUMED_REGISTER_ACTION(MatrixProductDiagonal,"MATRIX_PRODUCT_DIAGONAL")
      48             : 
      49         112 : void MatrixProductDiagonal::registerKeywords( Keywords& keys ) {
      50         112 :   ActionWithVector::registerKeywords(keys);
      51         224 :   keys.addInputKeyword("compulsory","ARG","vector/matrix","the two vectors/matrices whose product are to be taken");
      52         224 :   keys.setValueDescription("scalar/vector","a vector containing the diagonal elements of the matrix that obtaned by multiplying the two input matrices together");
      53         112 : }
      54             : 
      55          55 : MatrixProductDiagonal::MatrixProductDiagonal(const ActionOptions&ao):
      56             :   Action(ao),
      57          55 :   ActionWithVector(ao) {
      58          55 :   if( getNumberOfArguments()!=2 ) {
      59           0 :     error("should be two arguments to this action, a matrix and a vector");
      60             :   }
      61             : 
      62             :   unsigned ncols;
      63          55 :   if( getPntrToArgument(0)->getRank()==1 ) {
      64           2 :     if( getPntrToArgument(0)->hasDerivatives() ) {
      65           0 :       error("first argument to this action should be a vector or matrix");
      66             :     }
      67             :     ncols = 1;
      68          53 :   } else if( getPntrToArgument(0)->getRank()==2 ) {
      69          53 :     if( getPntrToArgument(0)->hasDerivatives() ) {
      70           0 :       error("first argument to this action should be a matrix");
      71             :     }
      72          53 :     ncols = getPntrToArgument(0)->getShape()[1];
      73             :   }
      74             : 
      75          55 :   if( getPntrToArgument(1)->getRank()==1 ) {
      76          32 :     if( getPntrToArgument(1)->hasDerivatives() ) {
      77           0 :       error("second argument to this action should be a vector or matrix");
      78             :     }
      79          32 :     if( ncols!=getPntrToArgument(1)->getShape()[0] ) {
      80           0 :       error("number of columns in first matrix does not equal number of elements in vector");
      81             :     }
      82          32 :     if( getPntrToArgument(0)->getShape()[0]!=1 ) {
      83           0 :       error("matrix output by this action must be square");
      84             :     }
      85          32 :     addValueWithDerivatives();
      86          32 :     setNotPeriodic();
      87             :   } else {
      88          23 :     if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) {
      89           0 :       error("second argument to this action should be a vector or a matrix");
      90             :     }
      91          23 :     if( ncols!=getPntrToArgument(1)->getShape()[0] ) {
      92           0 :       error("number of columns in first matrix does not equal number of rows in second matrix");
      93             :     }
      94          23 :     if( getPntrToArgument(0)->getShape()[0]!=getPntrToArgument(1)->getShape()[1] ) {
      95           0 :       error("matrix output by this action must be square");
      96             :     }
      97          23 :     std::vector<unsigned> shape(1);
      98          23 :     shape[0]=getPntrToArgument(0)->getShape()[0];
      99          23 :     addValue( shape );
     100          23 :     setNotPeriodic();
     101             :   }
     102          55 :   getPntrToArgument(0)->buildDataStore();
     103          55 :   getPntrToArgument(1)->buildDataStore();
     104          55 : }
     105             : 
     106        2418 : unsigned MatrixProductDiagonal::getNumberOfDerivatives() {
     107        2418 :   if( doNotCalculateDerivatives() ) {
     108             :     return 0;
     109             :   }
     110         108 :   return getPntrToArgument(0)->getNumberOfValues() + getPntrToArgument(1)->getNumberOfValues();;
     111             : }
     112             : 
     113      119881 : void MatrixProductDiagonal::performTask( const unsigned& task_index, MultiValue& myvals ) const {
     114      119881 :   unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream();
     115             :   Value* arg1 = getPntrToArgument(0);
     116             :   Value* arg2 = getPntrToArgument(1);
     117      119881 :   if( arg1->getRank()==1 ) {
     118          40 :     double val1 = arg1->get( task_index );
     119          40 :     double val2 = arg2->get( task_index );
     120          40 :     myvals.addValue( ostrn, val1*val2 );
     121             : 
     122          40 :     if( doNotCalculateDerivatives() ) {
     123           0 :       return;
     124             :     }
     125             : 
     126          40 :     myvals.addDerivative( ostrn, task_index, val2 );
     127          40 :     myvals.updateIndex( ostrn, task_index );
     128          40 :     unsigned nvals = getPntrToArgument(0)->getNumberOfValues();
     129          40 :     myvals.addDerivative( ostrn, nvals + task_index, val1 );
     130          40 :     myvals.updateIndex( ostrn, nvals + task_index );
     131             :   } else {
     132             :     unsigned nmult = arg1->getRowLength(task_index);
     133      119841 :     unsigned nrowsA = getPntrToArgument(0)->getShape()[1];
     134             :     unsigned nrowsB = 1;
     135      119841 :     if( getPntrToArgument(1)->getRank()>1 ) {
     136      118662 :       nrowsB = getPntrToArgument(1)->getShape()[1];
     137             :     }
     138      119841 :     unsigned nvals1 = getPntrToArgument(0)->getNumberOfValues();
     139             : 
     140             :     double matval = 0;
     141     3025572 :     for(unsigned i=0; i<nmult; ++i) {
     142             :       unsigned kind = arg1->getRowIndex( task_index, i );
     143     2905731 :       double val1 = arg1->get( task_index*nrowsA + kind );
     144     2905731 :       double val2 = arg2->get( kind*nrowsB + task_index );
     145     2905731 :       matval += val1*val2;
     146             : 
     147     2905731 :       if( doNotCalculateDerivatives() ) {
     148       68892 :         continue;
     149             :       }
     150             : 
     151     2836839 :       myvals.addDerivative( ostrn, task_index*nrowsA + kind, val2 );
     152     2836839 :       myvals.updateIndex( ostrn, task_index*nrowsA + kind );
     153     2836839 :       myvals.addDerivative( ostrn, nvals1 + kind*nrowsB + task_index, val1 );
     154     2836839 :       myvals.updateIndex( ostrn, nvals1 + kind*nrowsB + task_index );
     155             :     }
     156             :     // And add this part of the product
     157      119841 :     myvals.addValue( ostrn, matval );
     158             :   }
     159             : }
     160             : 
     161        2792 : void MatrixProductDiagonal::calculate() {
     162        2792 :   if( getPntrToArgument(1)->getRank()==1 ) {
     163        1179 :     unsigned nder = getNumberOfDerivatives();
     164        1179 :     MultiValue myvals( 1, nder, 0, 0, 0 );
     165        1179 :     performTask( 0, myvals );
     166             : 
     167        1179 :     Value* myval=getPntrToComponent(0);
     168             :     myval->set( myvals.get(0) );
     169        1329 :     for(unsigned i=0; i<nder; ++i) {
     170         150 :       myval->setDerivative( i, myvals.getDerivative(0,i) );
     171             :     }
     172        1179 :   } else {
     173        1613 :     runAllTasks();
     174             :   }
     175        2792 : }
     176             : 
     177             : }
     178             : }

Generated by: LCOV version 1.16