Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithMatrix.h"
23 : #include "core/ActionRegister.h"
24 : #include "tools/LeptonCall.h"
25 :
26 : //+PLUMEDOC COLVAR OUTER_PRODUCT
27 : /*
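Calculate the outer product matrix of two vectors

Element (i,j) of the output matrix is obtained by evaluating the function given to FUNC with x set to the i-th element of the first vector and y set to the j-th element of the second vector. FUNC=min and FUNC=max instead take the smaller or the larger of these two elements, respectively.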
29 :
30 : \par Examples
31 :
32 : */
33 : //+ENDPLUMEDOC
34 :
35 : namespace PLMD {
36 : namespace matrixtools {
37 :
38 : class OuterProduct : public ActionWithMatrix {
39 : private:
40 : bool domin, domax, diagzero;
41 : LeptonCall function;
42 : unsigned nderivatives;
43 : bool stored_vector1, stored_vector2;
44 : public:
45 : static void registerKeywords( Keywords& keys );
46 : explicit OuterProduct(const ActionOptions&);
47 : unsigned getNumberOfDerivatives();
48 : void prepare() override ;
49 2298 : unsigned getNumberOfColumns() const override { return getConstPntrToComponent(0)->getShape()[1]; }
50 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
51 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
52 : void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
53 : };
54 :
55 : PLUMED_REGISTER_ACTION(OuterProduct,"OUTER_PRODUCT")
56 :
57 137 : void OuterProduct::registerKeywords( Keywords& keys ) {
58 137 : ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
59 274 : keys.add("compulsory","FUNC","x*y","the function of the input vectors that should be put in the elements of the outer product");
60 274 : keys.addFlag("ELEMENTS_ON_DIAGONAL_ARE_ZERO",false,"set all diagonal elements to zero");
61 137 : keys.setValueDescription("a matrix containing the outer product of the two input vectors that was obtained using the function that was input");
62 137 : }
63 :
64 77 : OuterProduct::OuterProduct(const ActionOptions&ao):
65 : Action(ao),
66 : ActionWithMatrix(ao),
67 77 : domin(false),
68 77 : domax(false)
69 : {
70 77 : if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector");
71 77 : if( getPntrToArgument(0)->getRank()!=1 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a vector");
72 77 : if( getPntrToArgument(1)->getRank()!=1 || getPntrToArgument(1)->hasDerivatives() ) error("first argument to this action should be a vector");
73 :
74 154 : std::string func; parse("FUNC",func);
75 77 : if( func=="min") {
76 0 : domin=true;
77 0 : log.printf(" taking minimum of two input vectors \n");
78 77 : } else if( func=="max" ) {
79 2 : domax=true;
80 2 : log.printf(" taking maximum of two input vectors \n");
81 : } else {
82 75 : log.printf(" with function : %s \n", func.c_str() );
83 75 : std::vector<std::string> var(2); var[0]="x"; var[1]="y";
84 75 : function.set( func, var, this );
85 75 : }
86 77 : parseFlag("ELEMENTS_ON_DIAGONAL_ARE_ZERO",diagzero);
87 77 : if( diagzero ) log.printf(" setting diagonal elements equal to zero\n");
88 :
89 77 : std::vector<unsigned> shape(2); shape[0]=getPntrToArgument(0)->getShape()[0]; shape[1]=getPntrToArgument(1)->getShape()[0];
90 77 : addValue( shape ); setNotPeriodic(); nderivatives = buildArgumentStore(0);
91 77 : std::string headstr=getFirstActionInChain()->getLabel();
92 77 : stored_vector1 = getPntrToArgument(0)->ignoreStoredValue( headstr );
93 77 : stored_vector2 = getPntrToArgument(1)->ignoreStoredValue( headstr );
94 77 : if( getPntrToArgument(0)->isDerivativeZeroWhenValueIsZero() || getPntrToArgument(1)->isDerivativeZeroWhenValueIsZero() ) getPntrToComponent(0)->setDerivativeIsZeroWhenValueIsZero();
95 77 : }
96 :
97 96 : unsigned OuterProduct::getNumberOfDerivatives() {
98 96 : return nderivatives;
99 : }
100 :
101 158 : void OuterProduct::prepare() {
102 158 : ActionWithVector::prepare(); Value* myval=getPntrToComponent(0);
103 158 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] && myval->getShape()[1]==getPntrToArgument(1)->getShape()[0] ) return;
104 17 : std::vector<unsigned> shape(2); shape[0] = getPntrToArgument(0)->getShape()[0]; shape[1] = getPntrToArgument(1)->getShape()[0];
105 17 : myval->setShape( shape );
106 : }
107 :
108 27151 : void OuterProduct::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
109 27151 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(1)->getShape()[0];
110 27151 : if( diagzero ) {
111 990 : if( indices.size()!=size_v ) indices.resize( size_v );
112 : unsigned k=1;
113 99000 : for(unsigned i=0; i<size_v; ++i) {
114 98010 : if( task_index==i ) continue ;
115 97020 : indices[k] = size_v + i; k++;
116 : }
117 : myvals.setSplitIndex( size_v );
118 : } else {
119 26161 : if( indices.size()!=size_v+1 ) indices.resize( size_v+1 );
120 1690193 : for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
121 : myvals.setSplitIndex( size_v + 1 );
122 : }
123 27151 : }
124 :
125 6874326 : void OuterProduct::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
126 6874326 : unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream(), ind2=index2;
127 6874326 : if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
128 13385626 : if( diagzero && index1==ind2 ) return;
129 :
130 6874326 : double fval; unsigned jarg = 0, kelem = index1; bool jstore=stored_vector1;
131 6874326 : std::vector<double> args(2);
132 6874326 : args[0] = getArgumentElement( 0, index1, myvals );
133 6874326 : args[1] = getArgumentElement( 1, ind2, myvals );
134 6874326 : if( domin ) {
135 0 : fval=args[0]; if( args[1]<args[0] ) { fval=args[1]; jarg=1; kelem=ind2; jstore=stored_vector2; }
136 6874326 : } else if( domax ) {
137 315192 : fval=args[0]; if( args[1]>args[0] ) { fval=args[1]; jarg=1; kelem=ind2; jstore=stored_vector2; }
138 6559134 : } else { fval=function.evaluate( args ); }
139 :
140 6874326 : myvals.addValue( ostrn, fval );
141 6874326 : if( doNotCalculateDerivatives() ) return ;
142 :
143 366326 : if( domin || domax ) {
144 0 : addDerivativeOnVectorArgument( jstore, 0, jarg, kelem, 1.0, myvals );
145 : } else {
146 366326 : addDerivativeOnVectorArgument( stored_vector1, 0, 0, index1, function.evaluateDeriv( 0, args ), myvals );
147 366326 : addDerivativeOnVectorArgument( stored_vector2, 0, 1, ind2, function.evaluateDeriv( 1, args ), myvals );
148 : }
149 366326 : if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
150 363026 : unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
151 363026 : myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = arg_deriv_starts[1] + ind2; myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
152 : }
153 :
154 39963 : void OuterProduct::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
155 39963 : if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
156 11402 : unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
157 11402 : myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = ival; myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
158 : }
159 :
160 : }
161 : }
|