Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2014-2017 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithValue.h"
23 : #include "core/ActionWithArguments.h"
24 : #include "core/ActionRegister.h"
25 : #include "core/PlumedMain.h"
26 : #include "core/ActionSet.h"
27 :
28 : //+PLUMEDOC PRINTANALYSIS SELECT_WITH_MASK
29 : /*
30 : Use a mask to select elements of an array
31 :
32 : Output a scalar, vector or matrix that contains a subset of the elements in the input vector or matrix.
The following example shows how we can output a scalar, `v`, that contains the distance between atoms 3 and 4
by using the mask vector `m` to select this element from the three-element vector `d`:
35 :
36 : ```plumed
37 : d: DISTANCE ATOMS1=1,2 ATOMS2=3,4 ATOMS3=5,6
38 : m: CONSTANT VALUES=1,0,1
39 : v: SELECT_WITH_MASK ARG=d MASK=m
40 : ```
41 :
42 : The value, `m`, that is passed to the keyword MASK here is a vector with the same length as `d`.
Elements of `d` whose corresponding elements in `m` are zero are copied to the output value `v`.
44 : When elements of `m` are non-zero the corresponding elements in `d` are not transferred to the output
45 : value - they are masked.
46 :
If you use this action with matrices you must use the keyword `ROW_MASK`, the keyword `COLUMN_MASK`, or both. As shown in the example
inputs below, these keywords take vectors as input. In this first example, the output matrix is $3 \times 5$ because the rows
of the matrix whose corresponding elements in `m` are non-zero are not transferred:
50 :
51 : ```plumed
52 : d: DISTANCE_MATRIX GROUP=1-5
53 : m: CONSTANT VALUES=0,1,1,0,0
54 : v: SELECT_WITH_MASK ARG=d ROW_MASK=m
55 : ```
56 :
57 : For this second example the output matrix is $5 \times 3$ as columns of the matrix whose corresponding elements in `m` are non-zero
58 : are not transferred:
59 :
60 : ```plumed
61 : d: DISTANCE_MATRIX GROUP=1-5
62 : m: CONSTANT VALUES=0,1,1,0,0
63 : v: SELECT_WITH_MASK ARG=d COLUMN_MASK=m
64 : ```
65 :
66 : For this final example the output matrix is $3 \times 3$ as we do not transfer the rows and the columns in `d` whose corresponding
67 : elements in `m` are non-zero.
68 :
69 : ```plumed
70 : d: DISTANCE_MATRIX GROUP=1-5
71 : m: CONSTANT VALUES=0,1,1,0,0
72 : v: SELECT_WITH_MASK ARG=d ROW_MASK=m COLUMN_MASK=m
73 : ```
74 :
75 : */
76 : //+ENDPLUMEDOC
77 :
78 : namespace PLMD {
79 : namespace valtools {
80 :
81 : class SelectWithMask :
82 : public ActionWithValue,
83 : public ActionWithArguments {
84 : private:
85 : unsigned getOutputVectorLength( const Value* mask ) const ;
86 : public:
87 : static void registerKeywords( Keywords& keys );
88 : /// Constructor
89 : explicit SelectWithMask(const ActionOptions&);
90 : /// Get the number of derivatives
91 98 : unsigned getNumberOfDerivatives() override {
92 98 : return 0;
93 : }
94 : ///
95 : void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ;
96 : ///
97 : void prepare() override ;
98 : /// Do the calculation
99 : void calculate() override;
100 : ///
101 : void apply() override;
102 : };
103 :
104 : PLUMED_REGISTER_ACTION(SelectWithMask,"SELECT_WITH_MASK")
105 :
106 178 : void SelectWithMask::registerKeywords( Keywords& keys ) {
107 178 : Action::registerKeywords( keys );
108 178 : ActionWithValue::registerKeywords( keys );
109 178 : ActionWithArguments::registerKeywords( keys );
110 356 : keys.addInputKeyword("compulsory","ARG","scalar/vector/matrix","the label for the value upon which you are going to apply the mask");
111 356 : keys.addInputKeyword("optional","ROW_MASK","vector","an array with ones in the rows of the matrix that you want to discard");
112 356 : keys.addInputKeyword("optional","COLUMN_MASK","vector","an array with ones in the columns of the matrix that you want to discard");
113 356 : keys.addInputKeyword("compulsory","MASK","vector/matrix","an array with ones in the components that you want to discard");
114 356 : keys.setValueDescription("vector/matrix","a vector/matrix of values that is obtained using a mask to select elements of interest");
115 178 : }
116 :
117 93 : SelectWithMask::SelectWithMask(const ActionOptions& ao):
118 : Action(ao),
119 : ActionWithValue(ao),
120 93 : ActionWithArguments(ao) {
121 93 : if( getNumberOfArguments()!=1 ) {
122 0 : error("should only be one argument for this action");
123 : }
124 93 : getPntrToArgument(0)->buildDataStore();
125 : std::vector<unsigned> shape;
126 93 : if( getPntrToArgument(0)->getRank()==1 ) {
127 : std::vector<Value*> mask;
128 136 : parseArgumentList("MASK",mask);
129 68 : if( mask.size()!=1 ) {
130 0 : error("should only be one input for mask");
131 : }
132 68 : if( mask[0]->getNumberOfValues()!=getPntrToArgument(0)->getNumberOfValues() ) {
133 0 : error("mismatch between size of mask and input vector");
134 : }
135 68 : log.printf(" creating vector from elements of %s who have a corresponding element in %s that is zero\n", getPntrToArgument(0)->getName().c_str(), mask[0]->getName().c_str() );
136 68 : std::vector<Value*> args( getArguments() );
137 68 : args.push_back( mask[0] );
138 68 : requestArguments( args );
139 68 : shape.resize(1,0);
140 68 : shape[0]=getOutputVectorLength(mask[0]);
141 25 : } else if( getPntrToArgument(0)->getRank()==2 ) {
142 : std::vector<Value*> rmask, cmask;
143 25 : parseArgumentList("ROW_MASK",rmask);
144 50 : parseArgumentList("COLUMN_MASK",cmask);
145 25 : if( rmask.size()==0 && cmask.size()==0 ) {
146 0 : error("no mask elements have been specified");
147 25 : } else if( cmask.size()==0 ) {
148 11 : std::string con="0";
149 144 : for(unsigned i=1; i<getPntrToArgument(0)->getShape()[1]; ++i) {
150 : con += ",0";
151 : }
152 11 : plumed.readInputWords( Tools::getWords(getLabel() + "_colmask: CONSTANT VALUES=" + con), false );
153 11 : std::vector<std::string> labs(1, getLabel() + "_colmask");
154 11 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, cmask );
155 25 : } else if( rmask.size()==0 ) {
156 1 : std::string con="0";
157 13 : for(unsigned i=1; i<getPntrToArgument(0)->getShape()[0]; ++i) {
158 : con += ",0";
159 : }
160 1 : plumed.readInputWords( Tools::getWords(getLabel() + "_rowmask: CONSTANT VALUES=" + con), false );
161 1 : std::vector<std::string> labs(1, getLabel() + "_rowmask");
162 1 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, rmask );
163 1 : }
164 25 : shape.resize(2);
165 25 : rmask[0]->buildDataStore();
166 25 : shape[0] = getOutputVectorLength( rmask[0] );
167 25 : cmask[0]->buildDataStore();
168 25 : shape[1] = getOutputVectorLength( cmask[0] );
169 25 : std::vector<Value*> args( getArguments() );
170 25 : args.push_back( rmask[0] );
171 25 : args.push_back( cmask[0] );
172 25 : requestArguments( args );
173 : } else {
174 0 : error("input should be vector or matrix");
175 : }
176 :
177 93 : addValue( shape );
178 93 : getPntrToComponent(0)->buildDataStore();
179 93 : if( getPntrToArgument(0)->isPeriodic() ) {
180 : std::string min, max;
181 7 : getPntrToArgument(0)->getDomain( min, max );
182 7 : setPeriodic( min, max );
183 : } else {
184 86 : setNotPeriodic();
185 : }
186 93 : if( getPntrToComponent(0)->getRank()==2 ) {
187 25 : getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
188 : }
189 93 : }
190 :
191 10704 : unsigned SelectWithMask::getOutputVectorLength( const Value* mask ) const {
192 : unsigned l=0;
193 154174 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
194 143470 : if( fabs(mask->get(i))>0 ) {
195 10015 : continue;
196 : }
197 133455 : l++;
198 : }
199 10704 : return l;
200 : }
201 :
202 18 : void SelectWithMask::getMatrixColumnTitles( std::vector<std::string>& argnames ) const {
203 : std::vector<std::string> alltitles;
204 18 : (getPntrToArgument(0)->getPntrToAction())->getMatrixColumnTitles( alltitles );
205 103 : for(unsigned i=0; i<alltitles.size(); ++i) {
206 85 : if( fabs(getPntrToArgument(2)->get(i))>0 ) {
207 34 : continue;
208 : }
209 51 : argnames.push_back( alltitles[i] );
210 : }
211 18 : }
212 :
213 10551 : void SelectWithMask::prepare() {
214 : Value* arg = getPntrToArgument(0);
215 10551 : Value* out = getPntrToComponent(0);
216 10551 : if( arg->getRank()==1 ) {
217 : Value* mask = getPntrToArgument(1);
218 10516 : std::vector<unsigned> shape(1);
219 10516 : shape[0]=getOutputVectorLength( mask );
220 10516 : if( out->getNumberOfValues()!=shape[0] ) {
221 19 : if( shape[0]==1 ) {
222 0 : shape.resize(0);
223 : }
224 19 : out->setShape(shape);
225 : }
226 35 : } else if( arg->getRank()==2 ) {
227 35 : std::vector<unsigned> outshape(2);
228 : Value* rmask = getPntrToArgument(1);
229 35 : outshape[0] = getOutputVectorLength( rmask );
230 : Value* cmask = getPntrToArgument(2);
231 35 : outshape[1] = getOutputVectorLength( cmask );
232 35 : if( out->getShape()[0]!=outshape[0] || out->getShape()[1]!=outshape[1] ) {
233 19 : out->setShape(outshape);
234 19 : out->reshapeMatrixStore( outshape[1] );
235 : }
236 : }
237 10551 : }
238 :
239 10543 : void SelectWithMask::calculate() {
240 : Value* arg = getPntrToArgument(0);
241 10543 : Value* out = getPntrToComponent(0);
242 10543 : if( arg->getRank()==1 ) {
243 : Value* mask = getPntrToArgument(1);
244 : unsigned n=0;
245 149144 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
246 138632 : if( fabs(mask->get(i))>0 ) {
247 7434 : continue;
248 : }
249 131198 : out->set(n, arg->get(i) );
250 131198 : n++;
251 : }
252 31 : } else if ( arg->getRank()==2 ) {
253 31 : std::vector<unsigned> outshape( out->getShape() );
254 : unsigned n = 0;
255 31 : std::vector<unsigned> inshape( arg->getShape() );
256 : Value* rmask = getPntrToArgument(1);
257 : Value* cmask = getPntrToArgument(2);
258 1774 : for(unsigned i=0; i<inshape[0]; ++i) {
259 1743 : if( fabs(rmask->get(i))>0 ) {
260 592 : continue;
261 : }
262 : unsigned m = 0;
263 378651 : for(unsigned j=0; j<inshape[1]; ++j) {
264 377500 : if( fabs(cmask->get(j))>0 ) {
265 188095 : continue;
266 : }
267 189405 : out->set( n*outshape[1] + m, arg->get(i*inshape[1] + j) );
268 189405 : m++;
269 : }
270 1151 : n++;
271 : }
272 : }
273 10543 : }
274 :
275 10505 : void SelectWithMask::apply() {
276 10505 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) {
277 62 : return ;
278 : }
279 :
280 : Value* arg = getPntrToArgument(0);
281 10443 : Value* out = getPntrToComponent(0);
282 10443 : if( arg->getRank()==1 ) {
283 : unsigned n=0;
284 : Value* mask = getPntrToArgument(1);
285 145276 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
286 134833 : if( fabs(mask->get(i))>0 ) {
287 4153 : continue;
288 : }
289 130680 : arg->addForce(i, out->getForce(n) );
290 130680 : n++;
291 : }
292 0 : } else if( arg->getRank()==2 ) {
293 : unsigned n = 0;
294 0 : std::vector<unsigned> inshape( arg->getShape() );
295 0 : std::vector<unsigned> outshape( out->getShape() );
296 : Value* rmask = getPntrToArgument(1);
297 : Value* cmask = getPntrToArgument(2);
298 0 : for(unsigned i=0; i<inshape[0]; ++i) {
299 0 : if( fabs(rmask->get(i))>0 ) {
300 0 : continue;
301 : }
302 : unsigned m = 0;
303 0 : for(unsigned j=0; j<inshape[1]; ++j) {
304 0 : if( fabs(cmask->get(j))>0 ) {
305 0 : continue;
306 : }
307 0 : arg->addForce( i*inshape[1] + j, out->getForce(n*outshape[1] + m) );
308 0 : m++;
309 : }
310 0 : n++;
311 : }
312 : }
313 : }
314 :
315 :
316 :
317 : }
318 : }
|