Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithValue.h" 23 : #include "core/ActionWithArguments.h" 24 : #include "core/ActionRegister.h" 25 : #include "core/PlumedMain.h" 26 : #include "core/ActionSet.h" 27 : 28 : //+PLUMEDOC PRINTANALYSIS SELECT_WITH_MASK 29 : /* 30 : Use a mask to select elements of an array 31 : 32 : \par Examples 33 : 34 : */ 35 : //+ENDPLUMEDOC 36 : 37 : namespace PLMD { 38 : namespace valtools { 39 : 40 : class SelectWithMask : 41 : public ActionWithValue, 42 : public ActionWithArguments { 43 : private: 44 : unsigned getOutputVectorLength( const Value* mask ) const ; 45 : public: 46 : static void registerKeywords( Keywords& keys ); 47 : /// Constructor 48 : explicit SelectWithMask(const ActionOptions&); 49 : /// Get the number of derivatives 50 98 : unsigned getNumberOfDerivatives() override { return 0; } 51 : /// 52 : void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ; 53 : /// 54 : void prepare() override ; 55 : /// Do the calculation 56 : void calculate() override; 57 : /// 58 : void apply() override; 59 : }; 60 : 61 : PLUMED_REGISTER_ACTION(SelectWithMask,"SELECT_WITH_MASK") 62 : 63 178 : void SelectWithMask::registerKeywords( Keywords& keys ) { 64 178 : Action::registerKeywords( keys ); ActionWithValue::registerKeywords( keys ); 65 178 : ActionWithArguments::registerKeywords( keys ); keys.use("ARG"); 66 356 : keys.add("optional","ROW_MASK","an array with ones in the rows of the matrix that you want to discard"); 67 356 : keys.add("optional","COLUMN_MASK","an array with ones in the columns of the matrix that you want to discard"); 68 356 : keys.add("compulsory","MASK","an array with ones in the components that you want to discard"); 69 178 : keys.setValueDescription("a vector/matrix of values that is obtained using a mask to select elements of interest"); 70 178 : } 71 : 72 93 : SelectWithMask::SelectWithMask(const ActionOptions& ao): 73 : Action(ao), 74 : ActionWithValue(ao), 75 93 : ActionWithArguments(ao) 76 : { 77 93 : if( getNumberOfArguments()!=1 ) error("should only be one argument for this action"); 78 93 : getPntrToArgument(0)->buildDataStore(); std::vector<unsigned> shape; 79 93 : if( getPntrToArgument(0)->getRank()==1 ) { 80 136 : std::vector<Value*> mask; parseArgumentList("MASK",mask); 81 68 : if( mask.size()!=1 ) error("should only be one input for mask"); 82 68 : if( mask[0]->getNumberOfValues()!=getPntrToArgument(0)->getNumberOfValues() ) error("mismatch between size of mask and input vector"); 83 68 : log.printf(" creating vector from elements of %s who have a corresponding element in %s that is zero\n", getPntrToArgument(0)->getName().c_str(), mask[0]->getName().c_str() ); 84 68 : std::vector<Value*> args( getArguments() ); args.push_back( mask[0] ); requestArguments( args ); 85 68 : shape.resize(1,0); if( (mask[0]->getPntrToAction())->getName()=="CONSTANT" ) shape[0]=getOutputVectorLength(mask[0]); 86 25 : } else if( getPntrToArgument(0)->getRank()==2 ) { 87 75 : std::vector<Value*> rmask, cmask; parseArgumentList("ROW_MASK",rmask); parseArgumentList("COLUMN_MASK",cmask); 88 25 : if( rmask.size()==0 && cmask.size()==0 ) { 89 0 : error("no mask elements have been specified"); 90 25 : } else if( cmask.size()==0 ) { 91 144 : std::string con="0"; for(unsigned i=1; i<getPntrToArgument(0)->getShape()[1]; ++i) con += ",0"; 92 22 : plumed.readInputLine( getLabel() + "_colmask: CONSTANT VALUES=" + con ); std::vector<std::string> labs(1, getLabel() + "_colmask"); 93 11 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, cmask ); 94 25 : } else if( rmask.size()==0 ) { 95 13 : std::string con="0"; for(unsigned i=1; i<getPntrToArgument(0)->getShape()[0]; ++i) con += ",0"; 96 2 : plumed.readInputLine( getLabel() + "_rowmask: CONSTANT VALUES=" + con ); std::vector<std::string> labs(1, getLabel() + "_rowmask"); 97 1 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, rmask ); 98 1 : } 99 25 : shape.resize(2); 100 25 : rmask[0]->buildDataStore(); shape[0] = getOutputVectorLength( rmask[0] ); 101 25 : cmask[0]->buildDataStore(); shape[1] = getOutputVectorLength( cmask[0] ); 102 25 : std::vector<Value*> args( getArguments() ); args.push_back( rmask[0] ); 103 25 : args.push_back( cmask[0] ); requestArguments( args ); 104 0 : } else error("input should be vector or matrix"); 105 : 106 93 : addValue( shape ); getPntrToComponent(0)->buildDataStore(); 107 93 : if( getPntrToArgument(0)->isPeriodic() ) { 108 7 : std::string min, max; getPntrToArgument(0)->getDomain( min, max ); setPeriodic( min, max ); 109 86 : } else setNotPeriodic(); 110 93 : if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 111 93 : } 112 : 113 10693 : unsigned SelectWithMask::getOutputVectorLength( const Value* mask ) const { 114 : unsigned l=0; 115 154139 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 116 143446 : if( fabs(mask->get(i))>0 ) continue; 117 133435 : l++; 118 : } 119 10693 : return l; 120 : } 121 : 122 18 : void SelectWithMask::getMatrixColumnTitles( std::vector<std::string>& argnames ) const { 123 18 : std::vector<std::string> alltitles; (getPntrToArgument(0)->getPntrToAction())->getMatrixColumnTitles( alltitles ); 124 103 : for(unsigned i=0; i<alltitles.size(); ++i) { 125 85 : if( fabs(getPntrToArgument(2)->get(i))>0 ) continue; 126 51 : argnames.push_back( alltitles[i] ); 127 : } 128 18 : } 129 : 130 10551 : void SelectWithMask::prepare() { 131 10551 : Value* arg = getPntrToArgument(0); Value* out = getPntrToComponent(0); 132 10551 : if( arg->getRank()==1 ) { 133 : Value* mask = getPntrToArgument(1); 134 10516 : std::vector<unsigned> shape(1); shape[0]=getOutputVectorLength( mask ); 135 10516 : if( out->getNumberOfValues()!=shape[0] ) { 136 19 : if( shape[0]==1 ) shape.resize(0); 137 19 : out->setShape(shape); 138 : } 139 35 : } else if( arg->getRank()==2 ) { 140 35 : std::vector<unsigned> outshape(2); 141 35 : Value* rmask = getPntrToArgument(1); outshape[0] = getOutputVectorLength( rmask ); 142 35 : Value* cmask = getPntrToArgument(2); outshape[1] = getOutputVectorLength( cmask ); 143 35 : if( out->getShape()[0]!=outshape[0] || out->getShape()[1]!=outshape[1] ) { 144 20 : out->setShape(outshape); out->reshapeMatrixStore( outshape[1] ); 145 : } 146 : } 147 10551 : } 148 : 149 10543 : void SelectWithMask::calculate() { 150 10543 : Value* arg = getPntrToArgument(0); Value* out = getPntrToComponent(0); 151 10543 : if( arg->getRank()==1 ) { 152 : Value* mask = getPntrToArgument(1); unsigned n=0; 153 149144 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 154 138632 : if( fabs(mask->get(i))>0 ) continue; 155 131198 : out->set(n, arg->get(i) ); n++; 156 : } 157 31 : } else if ( arg->getRank()==2 ) { 158 31 : std::vector<unsigned> outshape( out->getShape() ); 159 31 : unsigned n = 0; std::vector<unsigned> inshape( arg->getShape() ); 160 : Value* rmask = getPntrToArgument(1); Value* cmask = getPntrToArgument(2); 161 1774 : for(unsigned i=0; i<inshape[0]; ++i) { 162 1743 : if( fabs(rmask->get(i))>0 ) continue; 163 : unsigned m = 0; 164 378651 : for(unsigned j=0; j<inshape[1]; ++j) { 165 377500 : if( fabs(cmask->get(j))>0 ) continue; 166 189405 : out->set( n*outshape[1] + m, arg->get(i*inshape[1] + j) ); 167 189405 : m++; 168 : } 169 1151 : n++; 170 : } 171 : } 172 10543 : } 173 : 174 10505 : void SelectWithMask::apply() { 175 10505 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return ; 176 : 177 10443 : Value* arg = getPntrToArgument(0); Value* out = getPntrToComponent(0); 178 10443 : if( arg->getRank()==1 ) { 179 : unsigned n=0; Value* mask = getPntrToArgument(1); 180 145276 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 181 134833 : if( fabs(mask->get(i))>0 ) continue; 182 130680 : arg->addForce(i, out->getForce(n) ); n++; 183 : } 184 0 : } else if( arg->getRank()==2 ) { 185 : unsigned n = 0; 186 0 : std::vector<unsigned> inshape( arg->getShape() ); 187 0 : std::vector<unsigned> outshape( out->getShape() ); 188 : Value* rmask = getPntrToArgument(1); Value* cmask = getPntrToArgument(2); 189 0 : for(unsigned i=0; i<inshape[0]; ++i) { 190 0 : if( fabs(rmask->get(i))>0 ) continue; 191 : unsigned m = 0; 192 0 : for(unsigned j=0; j<inshape[1]; ++j) { 193 0 : if( fabs(cmask->get(j))>0 ) continue; 194 0 : arg->addForce( i*inshape[1] + j, out->getForce(n*outshape[1] + m) ); 195 0 : m++; 196 : } 197 0 : n++; 198 : } 199 : } 200 : } 201 : 202 : 203 : 204 : } 205 : }