LCOV - code coverage report
Current view: top level - valtools - SelectWithMask.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 121 139 87.1 %
Date: 2025-04-08 21:11:17 Functions: 8 9 88.9 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2014-2017 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithValue.h"
      23             : #include "core/ActionWithArguments.h"
      24             : #include "core/ActionRegister.h"
      25             : #include "core/PlumedMain.h"
      26             : #include "core/ActionSet.h"
      27             : 
      28             : //+PLUMEDOC PRINTANALYSIS SELECT_WITH_MASK
      29             : /*
      30             : Use a mask to select elements of an array
      31             : 
      32             : Output a scalar, vector or matrix that contains a subset of the elements in the input vector or matrix.
      33             : The following example shows how we can output a scalar, `v`, that contains the distance between and 3 and 4
      34             : by using the mask vector `m` to select this element from the three element vector `d`:
      35             : 
      36             : ```plumed
      37             : d: DISTANCE ATOMS1=1,2 ATOMS2=3,4 ATOMS3=5,6
      38             : m: CONSTANT VALUES=1,0,1
      39             : v: SELECT_WITH_MASK ARG=d MASK=m
      40             : ```
      41             : 
      42             : The value, `m`, that is passed to the keyword MASK here is a vector with the same length as `d`.
      43             : Elements of `d` that whose corresponding elements in `m` are zero are copied to the output value `v`.
      44             : When elements of `m` are non-zero the corresponding elements in `d` are not transferred to the output
      45             : value - they are masked.
      46             : 
      47             : If you use this action with matrices you must use the keywords `ROW_MASK` and `COLUMN_MASK`. As shown in the example
      48             : inputs below, these keywords take vectors as input.  In this first example, the output matrix is $3 \times 5$ as rows
      49             : of the matrix whose corresponding elements in `m` are non-zero are not transferred:
      50             : 
      51             : ```plumed
      52             : d: DISTANCE_MATRIX GROUP=1-5
      53             : m: CONSTANT VALUES=0,1,1,0,0
      54             : v: SELECT_WITH_MASK ARG=d ROW_MASK=m
      55             : ```
      56             : 
      57             : For this second example the output matrix is $5 \times 3$ as columns of the matrix whose corresponding elements in `m` are non-zero
      58             : are not transferred:
      59             : 
      60             : ```plumed
      61             : d: DISTANCE_MATRIX GROUP=1-5
      62             : m: CONSTANT VALUES=0,1,1,0,0
      63             : v: SELECT_WITH_MASK ARG=d COLUMN_MASK=m
      64             : ```
      65             : 
      66             : For this final example the output matrix is $3 \times 3$ as we do not transfer the rows and the columns in `d` whose corresponding
      67             : elements in `m` are non-zero.
      68             : 
      69             : ```plumed
      70             : d: DISTANCE_MATRIX GROUP=1-5
      71             : m: CONSTANT VALUES=0,1,1,0,0
      72             : v: SELECT_WITH_MASK ARG=d ROW_MASK=m COLUMN_MASK=m
      73             : ```
      74             : 
      75             : */
      76             : //+ENDPLUMEDOC
      77             : 
      78             : namespace PLMD {
      79             : namespace valtools {
      80             : 
      81             : class SelectWithMask :
      82             :   public ActionWithValue,
      83             :   public ActionWithArguments {
      84             : private:
      85             :   unsigned getOutputVectorLength( const Value* mask ) const ;
      86             : public:
      87             :   static void registerKeywords( Keywords& keys );
      88             : /// Constructor
      89             :   explicit SelectWithMask(const ActionOptions&);
      90             : /// Get the number of derivatives
      91          98 :   unsigned getNumberOfDerivatives() override {
      92          98 :     return 0;
      93             :   }
      94             : ///
      95             :   void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ;
      96             : ///
      97             :   void prepare() override ;
      98             : /// Do the calculation
      99             :   void calculate() override;
     100             : ///
     101             :   void apply() override;
     102             : };
     103             : 
     104             : PLUMED_REGISTER_ACTION(SelectWithMask,"SELECT_WITH_MASK")
     105             : 
     106         178 : void SelectWithMask::registerKeywords( Keywords& keys ) {
     107         178 :   Action::registerKeywords( keys );
     108         178 :   ActionWithValue::registerKeywords( keys );
     109         178 :   ActionWithArguments::registerKeywords( keys );
     110         356 :   keys.addInputKeyword("compulsory","ARG","scalar/vector/matrix","the label for the value upon which you are going to apply the mask");
     111         356 :   keys.addInputKeyword("optional","ROW_MASK","vector","an array with ones in the rows of the matrix that you want to discard");
     112         356 :   keys.addInputKeyword("optional","COLUMN_MASK","vector","an array with ones in the columns of the matrix that you want to discard");
     113         356 :   keys.addInputKeyword("compulsory","MASK","vector/matrix","an array with ones in the components that you want to discard");
     114         356 :   keys.setValueDescription("vector/matrix","a vector/matrix of values that is obtained using a mask to select elements of interest");
     115         178 : }
     116             : 
     117          93 : SelectWithMask::SelectWithMask(const ActionOptions& ao):
     118             :   Action(ao),
     119             :   ActionWithValue(ao),
     120          93 :   ActionWithArguments(ao) {
     121          93 :   if( getNumberOfArguments()!=1 ) {
     122           0 :     error("should only be one argument for this action");
     123             :   }
     124          93 :   getPntrToArgument(0)->buildDataStore();
     125             :   std::vector<unsigned> shape;
     126          93 :   if( getPntrToArgument(0)->getRank()==1 ) {
     127             :     std::vector<Value*> mask;
     128         136 :     parseArgumentList("MASK",mask);
     129          68 :     if( mask.size()!=1 ) {
     130           0 :       error("should only be one input for mask");
     131             :     }
     132          68 :     if( mask[0]->getNumberOfValues()!=getPntrToArgument(0)->getNumberOfValues() ) {
     133           0 :       error("mismatch between size of mask and input vector");
     134             :     }
     135          68 :     log.printf("  creating vector from elements of %s who have a corresponding element in %s that is zero\n", getPntrToArgument(0)->getName().c_str(), mask[0]->getName().c_str() );
     136          68 :     std::vector<Value*> args( getArguments() );
     137          68 :     args.push_back( mask[0] );
     138          68 :     requestArguments( args );
     139          68 :     shape.resize(1,0);
     140          68 :     shape[0]=getOutputVectorLength(mask[0]);
     141          25 :   } else if( getPntrToArgument(0)->getRank()==2 ) {
     142             :     std::vector<Value*> rmask, cmask;
     143          25 :     parseArgumentList("ROW_MASK",rmask);
     144          50 :     parseArgumentList("COLUMN_MASK",cmask);
     145          25 :     if( rmask.size()==0 && cmask.size()==0 ) {
     146           0 :       error("no mask elements have been specified");
     147          25 :     } else if( cmask.size()==0 ) {
     148          11 :       std::string con="0";
     149         144 :       for(unsigned i=1; i<getPntrToArgument(0)->getShape()[1]; ++i) {
     150             :         con += ",0";
     151             :       }
     152          11 :       plumed.readInputWords( Tools::getWords(getLabel() + "_colmask: CONSTANT VALUES=" + con), false );
     153          11 :       std::vector<std::string> labs(1, getLabel() + "_colmask");
     154          11 :       ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, cmask );
     155          25 :     } else if( rmask.size()==0 ) {
     156           1 :       std::string con="0";
     157          13 :       for(unsigned i=1; i<getPntrToArgument(0)->getShape()[0]; ++i) {
     158             :         con += ",0";
     159             :       }
     160           1 :       plumed.readInputWords( Tools::getWords(getLabel() + "_rowmask: CONSTANT VALUES=" + con), false );
     161           1 :       std::vector<std::string> labs(1, getLabel() + "_rowmask");
     162           1 :       ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, rmask );
     163           1 :     }
     164          25 :     shape.resize(2);
     165          25 :     rmask[0]->buildDataStore();
     166          25 :     shape[0] = getOutputVectorLength( rmask[0] );
     167          25 :     cmask[0]->buildDataStore();
     168          25 :     shape[1] = getOutputVectorLength( cmask[0] );
     169          25 :     std::vector<Value*> args( getArguments() );
     170          25 :     args.push_back( rmask[0] );
     171          25 :     args.push_back( cmask[0] );
     172          25 :     requestArguments( args );
     173             :   } else {
     174           0 :     error("input should be vector or matrix");
     175             :   }
     176             : 
     177          93 :   addValue( shape );
     178          93 :   getPntrToComponent(0)->buildDataStore();
     179          93 :   if( getPntrToArgument(0)->isPeriodic() ) {
     180             :     std::string min, max;
     181           7 :     getPntrToArgument(0)->getDomain( min, max );
     182           7 :     setPeriodic( min, max );
     183             :   } else {
     184          86 :     setNotPeriodic();
     185             :   }
     186          93 :   if( getPntrToComponent(0)->getRank()==2 ) {
     187          25 :     getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
     188             :   }
     189          93 : }
     190             : 
     191       10704 : unsigned SelectWithMask::getOutputVectorLength( const Value* mask ) const  {
     192             :   unsigned l=0;
     193      154174 :   for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
     194      143470 :     if( fabs(mask->get(i))>0 ) {
     195       10015 :       continue;
     196             :     }
     197      133455 :     l++;
     198             :   }
     199       10704 :   return l;
     200             : }
     201             : 
     202          18 : void SelectWithMask::getMatrixColumnTitles( std::vector<std::string>& argnames ) const {
     203             :   std::vector<std::string> alltitles;
     204          18 :   (getPntrToArgument(0)->getPntrToAction())->getMatrixColumnTitles( alltitles );
     205         103 :   for(unsigned i=0; i<alltitles.size(); ++i) {
     206          85 :     if( fabs(getPntrToArgument(2)->get(i))>0 ) {
     207          34 :       continue;
     208             :     }
     209          51 :     argnames.push_back( alltitles[i] );
     210             :   }
     211          18 : }
     212             : 
     213       10551 : void SelectWithMask::prepare() {
     214             :   Value* arg = getPntrToArgument(0);
     215       10551 :   Value* out = getPntrToComponent(0);
     216       10551 :   if( arg->getRank()==1 ) {
     217             :     Value* mask = getPntrToArgument(1);
     218       10516 :     std::vector<unsigned> shape(1);
     219       10516 :     shape[0]=getOutputVectorLength( mask );
     220       10516 :     if( out->getNumberOfValues()!=shape[0] ) {
     221          19 :       if( shape[0]==1 ) {
     222           0 :         shape.resize(0);
     223             :       }
     224          19 :       out->setShape(shape);
     225             :     }
     226          35 :   } else if( arg->getRank()==2 ) {
     227          35 :     std::vector<unsigned> outshape(2);
     228             :     Value* rmask = getPntrToArgument(1);
     229          35 :     outshape[0] = getOutputVectorLength( rmask );
     230             :     Value* cmask = getPntrToArgument(2);
     231          35 :     outshape[1] = getOutputVectorLength( cmask );
     232          35 :     if( out->getShape()[0]!=outshape[0] || out->getShape()[1]!=outshape[1] ) {
     233          19 :       out->setShape(outshape);
     234          19 :       out->reshapeMatrixStore( outshape[1] );
     235             :     }
     236             :   }
     237       10551 : }
     238             : 
     239       10543 : void SelectWithMask::calculate() {
     240             :   Value* arg = getPntrToArgument(0);
     241       10543 :   Value* out = getPntrToComponent(0);
     242       10543 :   if( arg->getRank()==1 ) {
     243             :     Value* mask = getPntrToArgument(1);
     244             :     unsigned n=0;
     245      149144 :     for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
     246      138632 :       if( fabs(mask->get(i))>0 ) {
     247        7434 :         continue;
     248             :       }
     249      131198 :       out->set(n, arg->get(i) );
     250      131198 :       n++;
     251             :     }
     252          31 :   } else if ( arg->getRank()==2 ) {
     253          31 :     std::vector<unsigned> outshape( out->getShape() );
     254             :     unsigned n = 0;
     255          31 :     std::vector<unsigned> inshape( arg->getShape() );
     256             :     Value* rmask = getPntrToArgument(1);
     257             :     Value* cmask = getPntrToArgument(2);
     258        1774 :     for(unsigned i=0; i<inshape[0]; ++i) {
     259        1743 :       if( fabs(rmask->get(i))>0 ) {
     260         592 :         continue;
     261             :       }
     262             :       unsigned m = 0;
     263      378651 :       for(unsigned j=0; j<inshape[1]; ++j) {
     264      377500 :         if( fabs(cmask->get(j))>0 ) {
     265      188095 :           continue;
     266             :         }
     267      189405 :         out->set( n*outshape[1] + m, arg->get(i*inshape[1] + j) );
     268      189405 :         m++;
     269             :       }
     270        1151 :       n++;
     271             :     }
     272             :   }
     273       10543 : }
     274             : 
     275       10505 : void SelectWithMask::apply() {
     276       10505 :   if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) {
     277          62 :     return ;
     278             :   }
     279             : 
     280             :   Value* arg = getPntrToArgument(0);
     281       10443 :   Value* out = getPntrToComponent(0);
     282       10443 :   if( arg->getRank()==1 ) {
     283             :     unsigned n=0;
     284             :     Value* mask = getPntrToArgument(1);
     285      145276 :     for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
     286      134833 :       if( fabs(mask->get(i))>0 ) {
     287        4153 :         continue;
     288             :       }
     289      130680 :       arg->addForce(i, out->getForce(n) );
     290      130680 :       n++;
     291             :     }
     292           0 :   } else if( arg->getRank()==2 ) {
     293             :     unsigned n = 0;
     294           0 :     std::vector<unsigned> inshape( arg->getShape() );
     295           0 :     std::vector<unsigned> outshape( out->getShape() );
     296             :     Value* rmask = getPntrToArgument(1);
     297             :     Value* cmask = getPntrToArgument(2);
     298           0 :     for(unsigned i=0; i<inshape[0]; ++i) {
     299           0 :       if( fabs(rmask->get(i))>0 ) {
     300           0 :         continue;
     301             :       }
     302             :       unsigned m = 0;
     303           0 :       for(unsigned j=0; j<inshape[1]; ++j) {
     304           0 :         if( fabs(cmask->get(j))>0 ) {
     305           0 :           continue;
     306             :         }
     307           0 :         arg->addForce( i*inshape[1] + j, out->getForce(n*outshape[1] + m) );
     308           0 :         m++;
     309             :       }
     310           0 :       n++;
     311             :     }
     312             :   }
     313             : }
     314             : 
     315             : 
     316             : 
     317             : }
     318             : }

Generated by: LCOV version 1.16