LCOV - code coverage report
Current view: top level - refdist - MatrixProductDiagonal.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 60 60 100.0 %
Date: 2024-10-18 14:00:25 Functions: 5 6 83.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithVector.h"
      23             : #include "core/ActionRegister.h"
      24             : 
      25             : //+PLUMEDOC FUNCTION MATRIX_PRODUCT_DIAGONAL
      26             : /*
       27             : Calculate the product of two matrices and return a vector that contains the diagonal elements of the output matrix
      28             : 
      29             : \par Examples
      30             : 
      31             : */
      32             : //+ENDPLUMEDOC
      33             : 
      34             : namespace PLMD {
      35             : namespace refdist {
      36             : 
      37             : class MatrixProductDiagonal : public ActionWithVector {
      38             : private:
      39             : public:
      40             :   static void registerKeywords( Keywords& keys );
      41             :   explicit MatrixProductDiagonal(const ActionOptions&);
      42             :   unsigned getNumberOfDerivatives() override ;
      43             :   void calculate() override ;
      44             :   void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
      45             : };
      46             : 
// Register MatrixProductDiagonal with the PLUMED action register under the input keyword MATRIX_PRODUCT_DIAGONAL.
PLUMED_REGISTER_ACTION(MatrixProductDiagonal,"MATRIX_PRODUCT_DIAGONAL")
      48             : 
      49         112 : void MatrixProductDiagonal::registerKeywords( Keywords& keys ) {
      50         112 :   ActionWithVector::registerKeywords(keys); keys.use("ARG");
      51         112 :   keys.setValueDescription("a vector containing the diagonal elements of the matrix that obtaned by multiplying the two input matrices together");
      52         112 : }
      53             : 
      54          55 : MatrixProductDiagonal::MatrixProductDiagonal(const ActionOptions&ao):
      55             :   Action(ao),
      56          55 :   ActionWithVector(ao)
      57             : {
      58          55 :   if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector");
      59             : 
      60             :   unsigned ncols;
      61          55 :   if( getPntrToArgument(0)->getRank()==1 ) {
      62           2 :     if( getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a vector or matrix");
      63             :     ncols = 1;
      64          53 :   } else if( getPntrToArgument(0)->getRank()==2 ) {
      65          53 :     if( getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
      66          53 :     ncols = getPntrToArgument(0)->getShape()[1];
      67             :   }
      68             : 
      69          55 :   if( getPntrToArgument(1)->getRank()==1 ) {
      70          32 :     if( getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a vector or matrix");
      71          32 :     if( ncols!=getPntrToArgument(1)->getShape()[0] ) error("number of columns in first matrix does not equal number of elements in vector");
      72          32 :     if( getPntrToArgument(0)->getShape()[0]!=1 ) error("matrix output by this action must be square");
      73          64 :     addValueWithDerivatives(); setNotPeriodic();
      74             :   } else {
      75          23 :     if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a vector or a matrix");
      76          23 :     if( ncols!=getPntrToArgument(1)->getShape()[0] ) error("number of columns in first matrix does not equal number of rows in second matrix");
      77          23 :     if( getPntrToArgument(0)->getShape()[0]!=getPntrToArgument(1)->getShape()[1] ) error("matrix output by this action must be square");
      78          23 :     std::vector<unsigned> shape(1); shape[0]=getPntrToArgument(0)->getShape()[0];
      79          23 :     addValue( shape ); setNotPeriodic();
      80             :   }
      81          55 :   getPntrToArgument(0)->buildDataStore(); getPntrToArgument(1)->buildDataStore();
      82          55 : }
      83             : 
      84        2418 : unsigned MatrixProductDiagonal::getNumberOfDerivatives() {
      85        2418 :   if( doNotCalculateDerivatives() ) return 0;
      86         108 :   return getPntrToArgument(0)->getNumberOfValues() + getPntrToArgument(1)->getNumberOfValues();;
      87             : }
      88             : 
      89      119881 : void MatrixProductDiagonal::performTask( const unsigned& task_index, MultiValue& myvals ) const {
      90      119881 :   unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream();
      91             :   Value* arg1 = getPntrToArgument(0); Value* arg2 = getPntrToArgument(1);
      92      119881 :   if( arg1->getRank()==1 ) {
      93          40 :     double val1 = arg1->get( task_index );
      94          40 :     double val2 = arg2->get( task_index );
      95          40 :     myvals.addValue( ostrn, val1*val2 );
      96             : 
      97          40 :     if( doNotCalculateDerivatives() ) return;
      98             : 
      99          40 :     myvals.addDerivative( ostrn, task_index, val2 );
     100          40 :     myvals.updateIndex( ostrn, task_index );
     101          40 :     unsigned nvals = getPntrToArgument(0)->getNumberOfValues();
     102          40 :     myvals.addDerivative( ostrn, nvals + task_index, val1 );
     103          40 :     myvals.updateIndex( ostrn, nvals + task_index );
     104             :   } else {
     105             :     unsigned nmult = arg1->getRowLength(task_index);
     106      119841 :     unsigned nrowsA = getPntrToArgument(0)->getShape()[1];
     107      119841 :     unsigned nrowsB = 1; if( getPntrToArgument(1)->getRank()>1 ) nrowsB = getPntrToArgument(1)->getShape()[1];
     108      119841 :     unsigned nvals1 = getPntrToArgument(0)->getNumberOfValues();
     109             : 
     110             :     double matval = 0;
     111     3025572 :     for(unsigned i=0; i<nmult; ++i) {
     112             :       unsigned kind = arg1->getRowIndex( task_index, i );
     113     2905731 :       double val1 = arg1->get( task_index*nrowsA + kind );
     114     2905731 :       double val2 = arg2->get( kind*nrowsB + task_index );
     115     2905731 :       matval += val1*val2;
     116             : 
     117     2905731 :       if( doNotCalculateDerivatives() ) continue;
     118             : 
     119     2836839 :       myvals.addDerivative( ostrn, task_index*nrowsA + kind, val2 );
     120     2836839 :       myvals.updateIndex( ostrn, task_index*nrowsA + kind );
     121     2836839 :       myvals.addDerivative( ostrn, nvals1 + kind*nrowsB + task_index, val1 );
     122     2836839 :       myvals.updateIndex( ostrn, nvals1 + kind*nrowsB + task_index );
     123             :     }
     124             :     // And add this part of the product
     125      119841 :     myvals.addValue( ostrn, matval );
     126             :   }
     127             : }
     128             : 
     129        2792 : void MatrixProductDiagonal::calculate() {
     130        2792 :   if( getPntrToArgument(1)->getRank()==1 ) {
     131        1179 :     unsigned nder = getNumberOfDerivatives();
     132        1179 :     MultiValue myvals( 1, nder, 0, 0, 0 ); performTask( 0, myvals );
     133             : 
     134        1179 :     Value* myval=getPntrToComponent(0); myval->set( myvals.get(0) );
     135        1329 :     for(unsigned i=0; i<nder; ++i) myval->setDerivative( i, myvals.getDerivative(0,i) );
     136        2792 :   } else runAllTasks();
     137        2792 : }
     138             : 
     139             : }
     140             : }

Generated by: LCOV version 1.16