Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithMatrix.h" 23 : #include "core/ActionRegister.h" 24 : 25 : //+PLUMEDOC MCOLVAR VSTACK 26 : /* 27 : Create a matrix by stacking vectors together 28 : 29 : \par Examples 30 : 31 : */ 32 : //+ENDPLUMEDOC 33 : 34 : namespace PLMD { 35 : namespace valtools { 36 : 37 : class VStack : public ActionWithMatrix { 38 : private: 39 : std::vector<bool> stored; 40 : public: 41 : static void registerKeywords( Keywords& keys ); 42 : /// Constructor 43 : explicit VStack(const ActionOptions&); 44 : /// Get the number of derivatives 45 294 : unsigned getNumberOfDerivatives() override { return 0; } 46 : /// 47 : void prepare() override ; 48 : /// 49 1270505 : unsigned getNumberOfColumns() const override { return getNumberOfArguments(); } 50 : /// 51 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const override ; 52 : /// 53 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override ; 54 : /// 55 : void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ; 56 : /// 57 : void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ; 58 : }; 59 : 60 : PLUMED_REGISTER_ACTION(VStack,"VSTACK") 61 : 62 254 : void VStack::registerKeywords( Keywords& keys ) { 63 254 : ActionWithMatrix::registerKeywords( keys ); keys.use("ARG"); 64 254 : keys.setValueDescription("a matrix that contains the input vectors in its columns"); 65 254 : } 66 : 67 139 : VStack::VStack(const ActionOptions& ao): 68 : Action(ao), 69 139 : ActionWithMatrix(ao) 70 : { 71 139 : if( getNumberOfArguments()==0 ) error("no arguments were specificed"); 72 139 : if( getPntrToArgument(0)->getRank()>1 ) error("all arguments should be vectors"); 73 : unsigned nvals=1; bool periodic=false; std::string smin, smax; 74 139 : if( getPntrToArgument(0)->getRank()==1 ) nvals = getPntrToArgument(0)->getShape()[0]; 75 139 : if( getPntrToArgument(0)->isPeriodic() ) { periodic=true; getPntrToArgument(0)->getDomain( smin, smax ); } 76 : 77 1228 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 78 1089 : if( getPntrToArgument(i)->getRank()>1 || (getPntrToArgument(i)->getRank()==1 && getPntrToArgument(i)->hasDerivatives()) ) error("all arguments should be vectors"); 79 1089 : if( getPntrToArgument(i)->getRank()==0 ) { 80 41 : if( nvals!=1 ) error("all input vector should have same number of elements"); 81 1048 : } else if( getPntrToArgument(i)->getShape()[0]!=nvals ) error("all input vector should have same number of elements"); 82 1089 : if( periodic ) { 83 51 : if( !getPntrToArgument(i)->isPeriodic() ) error("one argument is periodic but " + getPntrToArgument(i)->getName() + " is not periodic"); 84 51 : std::string tmin, tmax; getPntrToArgument(i)->getDomain( tmin, tmax ); 85 51 : if( tmin!=smin || tmax!=smax ) error("domain of argument " + getPntrToArgument(i)->getName() + " is different from domain for all other arguments"); 86 1038 : } else if( getPntrToArgument(i)->isPeriodic() ) error("one argument is not periodic but " + getPntrToArgument(i)->getName() + " is periodic"); 87 : } 88 : // And create a value to hold the matrix 89 139 : std::vector<unsigned> shape(2); shape[0]=nvals; shape[1]=getNumberOfArguments(); addValue( shape ); 90 139 : if( periodic ) setPeriodic( smin, smax ); else setNotPeriodic(); 91 : // And store this value 92 139 : getPntrToComponent(0)->buildDataStore(); getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 93 : // Setup everything so we can build the store 94 139 : done_in_chain=true; ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() ); 95 139 : if( av ) { 96 99 : const ActionWithVector* head0 = av->getFirstActionInChain(); 97 996 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 98 928 : ActionWithVector* avv=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() ); 99 928 : if( !avv ) continue; 100 924 : if( head0!=avv->getFirstActionInChain() ) { done_in_chain=false; break; } 101 : } 102 40 : } else done_in_chain=false; 103 139 : unsigned nder = buildArgumentStore(0); 104 : // This checks which values have been stored 105 139 : stored.resize( getNumberOfArguments() ); std::string headstr=getFirstActionInChain()->getLabel(); 106 1228 : for(unsigned i=0; i<stored.size(); ++i) stored[i] = getPntrToArgument(i)->ignoreStoredValue( headstr ); 107 139 : } 108 : 109 23 : void VStack::getMatrixColumnTitles( std::vector<std::string>& argnames ) const { 110 60 : for(unsigned j=0; j<getNumberOfArguments(); ++j) { 111 37 : if( (getPntrToArgument(j)->getPntrToAction())->getName()=="COLLECT" ) { 112 17 : ActionWithArguments* aa = dynamic_cast<ActionWithArguments*>( getPntrToArgument(j)->getPntrToAction() ); 113 17 : plumed_assert( aa && aa->getNumberOfArguments()==1 ); argnames.push_back( (aa->getPntrToArgument(0))->getName() ); 114 20 : } else argnames.push_back( getPntrToArgument(j)->getName() ); 115 : } 116 23 : } 117 : 118 3100 : void VStack::prepare() { 119 3100 : ActionWithVector::prepare(); 120 3100 : if( getPntrToArgument(0)->getRank()==0 || getPntrToArgument(0)->getShape()[0]==getPntrToComponent(0)->getShape()[0] ) return ; 121 18 : std::vector<unsigned> shape(2); shape[0] = getPntrToArgument(0)->getShape()[0]; shape[1] = getNumberOfArguments(); 122 18 : getPntrToComponent(0)->setShape(shape); getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 123 : } 124 : 125 214418 : void VStack::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const { 126 214418 : unsigned nargs = getNumberOfArguments(); unsigned nvals = getConstPntrToComponent(0)->getShape()[0]; 127 214418 : if( indices.size()!=nargs+1 ) indices.resize( nargs+1 ); 128 4287565 : for(unsigned i=0; i<nargs; ++i) indices[i+1] = nvals + i; 129 : myvals.setSplitIndex( nargs + 1 ); 130 214418 : } 131 : 132 4073147 : void VStack::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const { 133 4073147 : unsigned ind2 = index2; if( index2>=getConstPntrToComponent(0)->getShape()[0] ) ind2 = index2 - getConstPntrToComponent(0)->getShape()[0]; 134 4073147 : myvals.addValue( getConstPntrToComponent(0)->getPositionInStream(), getArgumentElement( ind2, index1, myvals ) ); 135 : 136 4073147 : if( doNotCalculateDerivatives() ) return; 137 3692599 : addDerivativeOnVectorArgument( stored[ind2], 0, ind2, index1, 1.0, myvals ); 138 : } 139 : 140 214418 : void VStack::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const { 141 214418 : if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ; 142 : 143 3 : unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat ); 144 : std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); 145 3 : plumed_assert( nmat_ind<matrix_indices.size() ); 146 12 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 147 9 : bool found=false; ActionWithValue* iav = getPntrToArgument(i)->getPntrToAction(); 148 9 : for(unsigned j=0; j<i; ++j) { 149 6 : if( iav==getPntrToArgument(j)->getPntrToAction() ) { found=true; break; } 150 : } 151 9 : if( found ) continue ; 152 : 153 : unsigned istrn = getPntrToArgument(i)->getPositionInStream(); 154 48 : for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) { 155 45 : matrix_indices[nmat_ind] = myvals.getActiveIndex(istrn,k); nmat_ind++; 156 : } 157 : } 158 : myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind ); 159 : } 160 : 161 : } 162 : }