Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithValue.h" 23 : #include "core/ActionWithArguments.h" 24 : #include "core/ActionRegister.h" 25 : 26 : //+PLUMEDOC MCOLVAR CONCATENATE 27 : /* 28 : Join vectors or matrices together 29 : 30 : \par Examples 31 : 32 : */ 33 : //+ENDPLUMEDOC 34 : 35 : namespace PLMD { 36 : namespace valtools { 37 : 38 : class Concatenate : 39 : public ActionWithValue, 40 : public ActionWithArguments { 41 : private: 42 : bool vectors; 43 : std::vector<unsigned> row_starts; 44 : std::vector<unsigned> col_starts; 45 : public: 46 : static void registerKeywords( Keywords& keys ); 47 : /// Constructor 48 : explicit Concatenate(const ActionOptions&); 49 : /// Get the number of derivatives 50 257 : unsigned getNumberOfDerivatives() override { return 0; } 51 : /// Do the calculation 52 : void calculate() override; 53 : /// 54 : void apply(); 55 : }; 56 : 57 : PLUMED_REGISTER_ACTION(Concatenate,"CONCATENATE") 58 : 59 353 : void Concatenate::registerKeywords( Keywords& keys ) { 60 353 : Action::registerKeywords( keys ); ActionWithValue::registerKeywords( keys ); ActionWithArguments::registerKeywords( keys ); 61 706 : keys.addInputKeyword("optional","ARG","scalar/vector","the values that should be concatenated together to form the output vector"); 62 1059 : keys.addInputKeyword("numbered","MATRIX","scalar/matrix","specify the matrices that you wish to join together into a single matrix"); keys.reset_style("MATRIX","compulsory"); 63 706 : keys.setValueDescription("vector/matrix","the concatenated vector/matrix that was constructed from the input values"); 64 353 : } 65 : 66 176 : Concatenate::Concatenate(const ActionOptions& ao): 67 : Action(ao), 68 : ActionWithValue(ao), 69 176 : ActionWithArguments(ao) 70 : { 71 176 : if( getNumberOfArguments()>0 ) { 72 172 : vectors=true; std::vector<unsigned> shape(1); shape[0]=0; 73 547 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 74 375 : if( getPntrToArgument(i)->getRank()>1 ) error("cannot concatenate matrix with vectors"); 75 375 : getPntrToArgument(i)->buildDataStore(); shape[0] += getPntrToArgument(i)->getNumberOfValues(); 76 : } 77 172 : log.printf(" creating vector with %d elements \n", shape[0] ); 78 172 : addValue( shape ); bool period=getPntrToArgument(0)->isPeriodic(); 79 172 : std::string min, max; if( period ) getPntrToArgument(0)->getDomain( min, max ); 80 375 : for(unsigned i=1; i<getNumberOfArguments(); ++i) { 81 203 : if( period!=getPntrToArgument(i)->isPeriodic() ) error("periods of input arguments should match"); 82 203 : if( period ) { 83 0 : std::string min0, max0; getPntrToArgument(i)->getDomain( min0, max0 ); 84 0 : if( min0!=min || max0!=max ) error("domains of input arguments should match"); 85 : } 86 : } 87 172 : if( period ) setPeriodic( min, max ); else setNotPeriodic(); 88 172 : getPntrToComponent(0)->buildDataStore(); 89 172 : if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 90 : } else { 91 4 : unsigned nrows=0, ncols=0; std::vector<Value*> arglist; vectors=false; 92 7 : for(unsigned i=1;; i++) { 93 : unsigned nt_cols=0; unsigned size_b4 = arglist.size(); 94 14 : for(unsigned j=1;; j++) { 95 25 : if( j==10 ) error("cannot combine more than 9 matrices"); 96 50 : std::vector<Value*> argn; parseArgumentList("MATRIX", i*10+j, argn); 97 25 : if( argn.size()==0 ) break; 98 14 : if( argn.size()>1 ) error("should only be one argument to each matrix keyword"); 99 14 : if( argn[0]->getRank()!=0 && argn[0]->getRank()!=2 ) error("input arguments for this action should be matrices"); 100 14 : argn[0]->buildDataStore(); arglist.push_back( argn[0] ); nt_cols++; 101 14 : if( argn[0]->getRank()==0 ) log.printf(" %d %d component of composed matrix is scalar labelled %s\n", i, j, argn[0]->getName().c_str() ); 102 14 : else log.printf(" %d %d component of composed matrix is %d by %d matrix labelled %s\n", i, j, argn[0]->getShape()[0], argn[0]->getShape()[1], argn[0]->getName().c_str() ); 103 14 : } 104 11 : if( arglist.size()==size_b4 ) break; 105 7 : if( i==1 ) ncols=nt_cols; 106 3 : else if( nt_cols!=ncols ) error("should be joining same number of matrices in each row"); 107 7 : nrows++; 108 7 : } 109 : 110 4 : std::vector<unsigned> shape(2); shape[0]=0; unsigned k=0; 111 4 : row_starts.resize( arglist.size() ); col_starts.resize( arglist.size() ); 112 11 : for(unsigned i=0; i<nrows; ++i) { 113 7 : unsigned cstart = 0, nr = 1; if( arglist[k]->getRank()==2 ) nr=arglist[k]->getShape()[0]; 114 21 : for(unsigned j=0; j<ncols; ++j) { 115 14 : if( arglist[k]->getRank()==0 ) { 116 0 : if( nr!=1 ) error("mismatched matrix sizes"); 117 14 : } else if( nrows>1 && arglist[k]->getShape()[0]!=nr ) error("mismatched matrix sizes"); 118 14 : row_starts[k] = shape[0]; col_starts[k] = cstart; 119 14 : if( arglist[k]->getRank()==0 ) cstart += 1; 120 14 : else cstart += arglist[k]->getShape()[1]; 121 14 : k++; 122 : } 123 7 : if( i==0 ) shape[1]=cstart; 124 3 : else if( cstart!=shape[1] ) error("mismatched matrix sizes"); 125 7 : if( arglist[k-1]->getRank()==0 ) shape[0] += 1; 126 7 : else shape[0] += arglist[k-1]->getShape()[0]; 127 : } 128 : // Now request the arguments to make sure we store things we need 129 4 : requestArguments(arglist); addValue( shape ); setNotPeriodic(); getPntrToComponent(0)->buildDataStore(); 130 4 : if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 131 : } 132 176 : } 133 : 134 12191 : void Concatenate::calculate() { 135 12191 : Value* myval = getPntrToComponent(0); 136 12191 : if( vectors ) { 137 : unsigned k=0; 138 61297 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 139 49158 : Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues(); 140 404266 : for(unsigned j=0; j<nvals; ++j) { myval->set( k, myarg->get(j) ); k++; } 141 : } 142 : } else { 143 : // Retrieve the matrix from input 144 52 : unsigned ncols = myval->getShape()[1]; 145 258 : for(unsigned k=0; k<getNumberOfArguments(); ++k) { 146 : Value* argn = getPntrToArgument(k); 147 206 : if( argn->getRank()==0 ) { 148 0 : myval->set( ncols*row_starts[k]+col_starts[k], argn->get() ); 149 : } else { 150 : std::vector<double> vals; std::vector<std::pair<unsigned,unsigned> > pairs; 151 : bool symmetric=getPntrToArgument(k)->isSymmetric(); 152 206 : unsigned nedge=0; getPntrToArgument(k)->retrieveEdgeList( nedge, pairs, vals ); 153 8946 : for(unsigned l=0; l<nedge; ++l ) { 154 8740 : unsigned i=pairs[l].first, j=pairs[l].second; 155 8740 : myval->set( ncols*(row_starts[k]+i)+col_starts[k]+j, vals[l] ); 156 8740 : if( symmetric ) myval->set( ncols*(row_starts[k]+j)+col_starts[k]+i, vals[l] ); 157 : } 158 : } 159 : } 160 : } 161 12191 : } 162 : 163 12070 : void Concatenate::apply() { 164 12070 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return; 165 : 166 5033 : Value* val=getPntrToComponent(0); 167 5033 : if( vectors ) { 168 : unsigned k=0; 169 19923 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 170 14942 : Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues(); 171 205938 : for(unsigned j=0; j<nvals; ++j) { myarg->addForce( j, val->getForce(k) ); k++; } 172 : } 173 : } else { 174 52 : unsigned ncols=val->getShape()[1]; 175 258 : for(unsigned k=0; k<getNumberOfArguments(); ++k) { 176 : Value* argn=getPntrToArgument(k); 177 206 : if( argn->getRank()==0 ) argn->addForce( 0, val->getForce(ncols*row_starts[k]+col_starts[k]) ); 178 : else { 179 : unsigned val_ncols=val->getNumberOfColumns(); 180 : unsigned arg_ncols=argn->getNumberOfColumns(); 181 1686 : for(unsigned i=0; i<argn->getShape()[0]; ++i) { 182 : unsigned ncol = argn->getRowLength(i); 183 13140 : for(unsigned j=0; j<ncol; ++j) argn->addForce( i*arg_ncols+j, val->getForce( val_ncols*(row_starts[k]+i)+col_starts[k]+argn->getRowIndex(i,j) ), false ); 184 : } 185 : } 186 : } 187 : } 188 : } 189 : 190 : } 191 : }