Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2011-2023 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithVector.h" 23 : #include "core/ActionRegister.h" 24 : 25 : //+PLUMEDOC FUNCTION MATRIX_PRODUCT_DIAGONAL 26 : /* 27 : Calculate the product of two matrices and return a vector that contains the diagonal elements of the ouptut vector 28 : 29 : \par Examples 30 : 31 : */ 32 : //+ENDPLUMEDOC 33 : 34 : namespace PLMD { 35 : namespace refdist { 36 : 37 : class MatrixProductDiagonal : public ActionWithVector { 38 : private: 39 : public: 40 : static void registerKeywords( Keywords& keys ); 41 : explicit MatrixProductDiagonal(const ActionOptions&); 42 : unsigned getNumberOfDerivatives() override ; 43 : void calculate() override ; 44 : void performTask( const unsigned& task_index, MultiValue& myvals ) const override ; 45 : }; 46 : 47 : PLUMED_REGISTER_ACTION(MatrixProductDiagonal,"MATRIX_PRODUCT_DIAGONAL") 48 : 49 112 : void MatrixProductDiagonal::registerKeywords( Keywords& keys ) { 50 112 : ActionWithVector::registerKeywords(keys); 51 224 : keys.addInputKeyword("compulsory","ARG","vector/matrix","the two vectors/matrices whose product are to be taken"); 52 224 : keys.setValueDescription("scalar/vector","a vector containing the diagonal elements of the matrix that obtaned by multiplying the two input matrices together"); 53 112 : } 54 : 55 55 : MatrixProductDiagonal::MatrixProductDiagonal(const ActionOptions&ao): 56 : Action(ao), 57 55 : ActionWithVector(ao) { 58 55 : if( getNumberOfArguments()!=2 ) { 59 0 : error("should be two arguments to this action, a matrix and a vector"); 60 : } 61 : 62 : unsigned ncols; 63 55 : if( getPntrToArgument(0)->getRank()==1 ) { 64 2 : if( getPntrToArgument(0)->hasDerivatives() ) { 65 0 : error("first argument to this action should be a vector or matrix"); 66 : } 67 : ncols = 1; 68 53 : } else if( getPntrToArgument(0)->getRank()==2 ) { 69 53 : if( getPntrToArgument(0)->hasDerivatives() ) { 70 0 : error("first argument to this action should be a matrix"); 71 : } 72 53 : ncols = getPntrToArgument(0)->getShape()[1]; 73 : } 74 : 75 55 : if( getPntrToArgument(1)->getRank()==1 ) { 76 32 : if( getPntrToArgument(1)->hasDerivatives() ) { 77 0 : error("second argument to this action should be a vector or matrix"); 78 : } 79 32 : if( ncols!=getPntrToArgument(1)->getShape()[0] ) { 80 0 : error("number of columns in first matrix does not equal number of elements in vector"); 81 : } 82 32 : if( getPntrToArgument(0)->getShape()[0]!=1 ) { 83 0 : error("matrix output by this action must be square"); 84 : } 85 32 : addValueWithDerivatives(); 86 32 : setNotPeriodic(); 87 : } else { 88 23 : if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) { 89 0 : error("second argument to this action should be a vector or a matrix"); 90 : } 91 23 : if( ncols!=getPntrToArgument(1)->getShape()[0] ) { 92 0 : error("number of columns in first matrix does not equal number of rows in second matrix"); 93 : } 94 23 : if( getPntrToArgument(0)->getShape()[0]!=getPntrToArgument(1)->getShape()[1] ) { 95 0 : error("matrix output by this action must be square"); 96 : } 97 23 : std::vector<unsigned> shape(1); 98 23 : shape[0]=getPntrToArgument(0)->getShape()[0]; 99 23 : addValue( shape ); 100 23 : setNotPeriodic(); 101 : } 102 55 : getPntrToArgument(0)->buildDataStore(); 103 55 : getPntrToArgument(1)->buildDataStore(); 104 55 : } 105 : 106 2418 : unsigned MatrixProductDiagonal::getNumberOfDerivatives() { 107 2418 : if( doNotCalculateDerivatives() ) { 108 : return 0; 109 : } 110 108 : return getPntrToArgument(0)->getNumberOfValues() + getPntrToArgument(1)->getNumberOfValues();; 111 : } 112 : 113 119881 : void MatrixProductDiagonal::performTask( const unsigned& task_index, MultiValue& myvals ) const { 114 119881 : unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream(); 115 : Value* arg1 = getPntrToArgument(0); 116 : Value* arg2 = getPntrToArgument(1); 117 119881 : if( arg1->getRank()==1 ) { 118 40 : double val1 = arg1->get( task_index ); 119 40 : double val2 = arg2->get( task_index ); 120 40 : myvals.addValue( ostrn, val1*val2 ); 121 : 122 40 : if( doNotCalculateDerivatives() ) { 123 0 : return; 124 : } 125 : 126 40 : myvals.addDerivative( ostrn, task_index, val2 ); 127 40 : myvals.updateIndex( ostrn, task_index ); 128 40 : unsigned nvals = getPntrToArgument(0)->getNumberOfValues(); 129 40 : myvals.addDerivative( ostrn, nvals + task_index, val1 ); 130 40 : myvals.updateIndex( ostrn, nvals + task_index ); 131 : } else { 132 : unsigned nmult = arg1->getRowLength(task_index); 133 119841 : unsigned nrowsA = getPntrToArgument(0)->getShape()[1]; 134 : unsigned nrowsB = 1; 135 119841 : if( getPntrToArgument(1)->getRank()>1 ) { 136 118662 : nrowsB = getPntrToArgument(1)->getShape()[1]; 137 : } 138 119841 : unsigned nvals1 = getPntrToArgument(0)->getNumberOfValues(); 139 : 140 : double matval = 0; 141 3025572 : for(unsigned i=0; i<nmult; ++i) { 142 : unsigned kind = arg1->getRowIndex( task_index, i ); 143 2905731 : double val1 = arg1->get( task_index*nrowsA + kind ); 144 2905731 : double val2 = arg2->get( kind*nrowsB + task_index ); 145 2905731 : matval += val1*val2; 146 : 147 2905731 : if( doNotCalculateDerivatives() ) { 148 68892 : continue; 149 : } 150 : 151 2836839 : myvals.addDerivative( ostrn, task_index*nrowsA + kind, val2 ); 152 2836839 : myvals.updateIndex( ostrn, task_index*nrowsA + kind ); 153 2836839 : myvals.addDerivative( ostrn, nvals1 + kind*nrowsB + task_index, val1 ); 154 2836839 : myvals.updateIndex( ostrn, nvals1 + kind*nrowsB + task_index ); 155 : } 156 : // And add this part of the product 157 119841 : myvals.addValue( ostrn, matval ); 158 : } 159 : } 160 : 161 2792 : void MatrixProductDiagonal::calculate() { 162 2792 : if( getPntrToArgument(1)->getRank()==1 ) { 163 1179 : unsigned nder = getNumberOfDerivatives(); 164 1179 : MultiValue myvals( 1, nder, 0, 0, 0 ); 165 1179 : performTask( 0, myvals ); 166 : 167 1179 : Value* myval=getPntrToComponent(0); 168 : myval->set( myvals.get(0) ); 169 1329 : for(unsigned i=0; i<nder; ++i) { 170 150 : myval->setDerivative( i, myvals.getDerivative(0,i) ); 171 : } 172 1179 : } else { 173 1613 : runAllTasks(); 174 : } 175 2792 : } 176 : 177 : } 178 : }