Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithValue.h" 23 : #include "core/ActionWithArguments.h" 24 : #include "core/ActionRegister.h" 25 : #include "core/PlumedMain.h" 26 : #include "core/ActionSet.h" 27 : 28 : //+PLUMEDOC PRINTANALYSIS SELECT_WITH_MASK 29 : /* 30 : Use a mask to select elements of an array 31 : 32 : \par Examples 33 : 34 : */ 35 : //+ENDPLUMEDOC 36 : 37 : namespace PLMD { 38 : namespace valtools { 39 : 40 : class SelectWithMask : 41 : public ActionWithValue, 42 : public ActionWithArguments { 43 : private: 44 : unsigned getOutputVectorLength( const Value* mask ) const ; 45 : public: 46 : static void registerKeywords( Keywords& keys ); 47 : /// Constructor 48 : explicit SelectWithMask(const ActionOptions&); 49 : /// Get the number of derivatives 50 98 : unsigned getNumberOfDerivatives() override { return 0; } 51 : /// 52 : void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ; 53 : /// 54 : void prepare() override ; 55 : /// Do the calculation 56 : void calculate() override; 57 : /// 58 : void apply() override; 59 : }; 60 : 61 : PLUMED_REGISTER_ACTION(SelectWithMask,"SELECT_WITH_MASK") 62 : 63 178 : void SelectWithMask::registerKeywords( Keywords& keys ) { 64 178 : Action::registerKeywords( keys ); ActionWithValue::registerKeywords( keys ); 65 178 : ActionWithArguments::registerKeywords( keys ); 66 356 : keys.addInputKeyword("compulsory","ARG","scalar/vector/matrix","the label for the value upon which you are going to apply the mask"); 67 356 : keys.addInputKeyword("optional","ROW_MASK","vector","an array with ones in the rows of the matrix that you want to discard"); 68 356 : keys.addInputKeyword("optional","COLUMN_MASK","vector","an array with ones in the columns of the matrix that you want to discard"); 69 356 : keys.addInputKeyword("compulsory","MASK","vector/matrix","an array with ones in the components that you want to discard"); 70 356 : keys.setValueDescription("vector/matrix","a vector/matrix of values that is obtained using a mask to select elements of interest"); 71 178 : } 72 : 73 93 : SelectWithMask::SelectWithMask(const ActionOptions& ao): 74 : Action(ao), 75 : ActionWithValue(ao), 76 93 : ActionWithArguments(ao) 77 : { 78 93 : if( getNumberOfArguments()!=1 ) error("should only be one argument for this action"); 79 93 : getPntrToArgument(0)->buildDataStore(); std::vector<unsigned> shape; 80 93 : if( getPntrToArgument(0)->getRank()==1 ) { 81 136 : std::vector<Value*> mask; parseArgumentList("MASK",mask); 82 68 : if( mask.size()!=1 ) error("should only be one input for mask"); 83 68 : if( mask[0]->getNumberOfValues()!=getPntrToArgument(0)->getNumberOfValues() ) error("mismatch between size of mask and input vector"); 84 68 : log.printf(" creating vector from elements of %s who have a corresponding element in %s that is zero\n", getPntrToArgument(0)->getName().c_str(), mask[0]->getName().c_str() ); 85 68 : std::vector<Value*> args( getArguments() ); args.push_back( mask[0] ); requestArguments( args ); 86 68 : shape.resize(1,0); if( (mask[0]->getPntrToAction())->getName()=="CONSTANT" ) shape[0]=getOutputVectorLength(mask[0]); 87 25 : } else if( getPntrToArgument(0)->getRank()==2 ) { 88 75 : std::vector<Value*> rmask, cmask; parseArgumentList("ROW_MASK",rmask); parseArgumentList("COLUMN_MASK",cmask); 89 25 : if( rmask.size()==0 && cmask.size()==0 ) { 90 0 : error("no mask elements have been specified"); 91 25 : } else if( cmask.size()==0 ) { 92 144 : std::string con="0"; for(unsigned i=1; i<getPntrToArgument(0)->getShape()[1]; ++i) con += ",0"; 93 22 : plumed.readInputLine( getLabel() + "_colmask: CONSTANT VALUES=" + con ); std::vector<std::string> labs(1, getLabel() + "_colmask"); 94 11 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, cmask ); 95 25 : } else if( rmask.size()==0 ) { 96 13 : std::string con="0"; for(unsigned i=1; i<getPntrToArgument(0)->getShape()[0]; ++i) con += ",0"; 97 2 : plumed.readInputLine( getLabel() + "_rowmask: CONSTANT VALUES=" + con ); std::vector<std::string> labs(1, getLabel() + "_rowmask"); 98 1 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, rmask ); 99 1 : } 100 25 : shape.resize(2); 101 25 : rmask[0]->buildDataStore(); shape[0] = getOutputVectorLength( rmask[0] ); 102 25 : cmask[0]->buildDataStore(); shape[1] = getOutputVectorLength( cmask[0] ); 103 25 : std::vector<Value*> args( getArguments() ); args.push_back( rmask[0] ); 104 25 : args.push_back( cmask[0] ); requestArguments( args ); 105 0 : } else error("input should be vector or matrix"); 106 : 107 93 : addValue( shape ); getPntrToComponent(0)->buildDataStore(); 108 93 : if( getPntrToArgument(0)->isPeriodic() ) { 109 7 : std::string min, max; getPntrToArgument(0)->getDomain( min, max ); setPeriodic( min, max ); 110 86 : } else setNotPeriodic(); 111 93 : if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 112 93 : } 113 : 114 10693 : unsigned SelectWithMask::getOutputVectorLength( const Value* mask ) const { 115 : unsigned l=0; 116 154139 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 117 143446 : if( fabs(mask->get(i))>0 ) continue; 118 133435 : l++; 119 : } 120 10693 : return l; 121 : } 122 : 123 18 : void SelectWithMask::getMatrixColumnTitles( std::vector<std::string>& argnames ) const { 124 18 : std::vector<std::string> alltitles; (getPntrToArgument(0)->getPntrToAction())->getMatrixColumnTitles( alltitles ); 125 103 : for(unsigned i=0; i<alltitles.size(); ++i) { 126 85 : if( fabs(getPntrToArgument(2)->get(i))>0 ) continue; 127 51 : argnames.push_back( alltitles[i] ); 128 : } 129 18 : } 130 : 131 10551 : void SelectWithMask::prepare() { 132 10551 : Value* arg = getPntrToArgument(0); Value* out = getPntrToComponent(0); 133 10551 : if( arg->getRank()==1 ) { 134 : Value* mask = getPntrToArgument(1); 135 10516 : std::vector<unsigned> shape(1); shape[0]=getOutputVectorLength( mask ); 136 10516 : if( out->getNumberOfValues()!=shape[0] ) { 137 19 : if( shape[0]==1 ) shape.resize(0); 138 19 : out->setShape(shape); 139 : } 140 35 : } else if( arg->getRank()==2 ) { 141 35 : std::vector<unsigned> outshape(2); 142 35 : Value* rmask = getPntrToArgument(1); outshape[0] = getOutputVectorLength( rmask ); 143 35 : Value* cmask = getPntrToArgument(2); outshape[1] = getOutputVectorLength( cmask ); 144 35 : if( out->getShape()[0]!=outshape[0] || out->getShape()[1]!=outshape[1] ) { 145 20 : out->setShape(outshape); out->reshapeMatrixStore( outshape[1] ); 146 : } 147 : } 148 10551 : } 149 : 150 10543 : void SelectWithMask::calculate() { 151 10543 : Value* arg = getPntrToArgument(0); Value* out = getPntrToComponent(0); 152 10543 : if( arg->getRank()==1 ) { 153 : Value* mask = getPntrToArgument(1); unsigned n=0; 154 149144 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 155 138632 : if( fabs(mask->get(i))>0 ) continue; 156 131198 : out->set(n, arg->get(i) ); n++; 157 : } 158 31 : } else if ( arg->getRank()==2 ) { 159 31 : std::vector<unsigned> outshape( out->getShape() ); 160 31 : unsigned n = 0; std::vector<unsigned> inshape( arg->getShape() ); 161 : Value* rmask = getPntrToArgument(1); Value* cmask = getPntrToArgument(2); 162 1774 : for(unsigned i=0; i<inshape[0]; ++i) { 163 1743 : if( fabs(rmask->get(i))>0 ) continue; 164 : unsigned m = 0; 165 378651 : for(unsigned j=0; j<inshape[1]; ++j) { 166 377500 : if( fabs(cmask->get(j))>0 ) continue; 167 189405 : out->set( n*outshape[1] + m, arg->get(i*inshape[1] + j) ); 168 189405 : m++; 169 : } 170 1151 : n++; 171 : } 172 : } 173 10543 : } 174 : 175 10505 : void SelectWithMask::apply() { 176 10505 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return ; 177 : 178 10443 : Value* arg = getPntrToArgument(0); Value* out = getPntrToComponent(0); 179 10443 : if( arg->getRank()==1 ) { 180 : unsigned n=0; Value* mask = getPntrToArgument(1); 181 145276 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 182 134833 : if( fabs(mask->get(i))>0 ) continue; 183 130680 : arg->addForce(i, out->getForce(n) ); n++; 184 : } 185 0 : } else if( arg->getRank()==2 ) { 186 : unsigned n = 0; 187 0 : std::vector<unsigned> inshape( arg->getShape() ); 188 0 : std::vector<unsigned> outshape( out->getShape() ); 189 : Value* rmask = getPntrToArgument(1); Value* cmask = getPntrToArgument(2); 190 0 : for(unsigned i=0; i<inshape[0]; ++i) { 191 0 : if( fabs(rmask->get(i))>0 ) continue; 192 : unsigned m = 0; 193 0 : for(unsigned j=0; j<inshape[1]; ++j) { 194 0 : if( fabs(cmask->get(j))>0 ) continue; 195 0 : arg->addForce( i*inshape[1] + j, out->getForce(n*outshape[1] + m) ); 196 0 : m++; 197 : } 198 0 : n++; 199 : } 200 : } 201 : } 202 : 203 : 204 : 205 : } 206 : }