Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithValue.h" 23 : #include "core/ActionWithArguments.h" 24 : #include "core/ActionRegister.h" 25 : 26 : //+PLUMEDOC MCOLVAR CONCATENATE 27 : /* 28 : Join vectors or matrices together 29 : 30 : \par Examples 31 : 32 : */ 33 : //+ENDPLUMEDOC 34 : 35 : namespace PLMD { 36 : namespace valtools { 37 : 38 : class Concatenate : 39 : public ActionWithValue, 40 : public ActionWithArguments { 41 : private: 42 : bool vectors; 43 : std::vector<unsigned> row_starts; 44 : std::vector<unsigned> col_starts; 45 : public: 46 : static void registerKeywords( Keywords& keys ); 47 : /// Constructor 48 : explicit Concatenate(const ActionOptions&); 49 : /// Get the number of derivatives 50 257 : unsigned getNumberOfDerivatives() override { return 0; } 51 : /// Do the calculation 52 : void calculate() override; 53 : /// 54 : void apply(); 55 : }; 56 : 57 : PLUMED_REGISTER_ACTION(Concatenate,"CONCATENATE") 58 : 59 353 : void Concatenate::registerKeywords( Keywords& keys ) { 60 353 : Action::registerKeywords( keys ); ActionWithValue::registerKeywords( keys ); ActionWithArguments::registerKeywords( keys ); keys.use("ARG"); 61 1059 : keys.add("numbered","MATRIX","specify the matrices that you wish to join together into a single matrix"); keys.reset_style("MATRIX","compulsory"); 62 353 : keys.setValueDescription("the concatenated vector/matrix that was constructed from the input values"); 63 353 : } 64 : 65 176 : Concatenate::Concatenate(const ActionOptions& ao): 66 : Action(ao), 67 : ActionWithValue(ao), 68 176 : ActionWithArguments(ao) 69 : { 70 176 : if( getNumberOfArguments()>0 ) { 71 172 : vectors=true; std::vector<unsigned> shape(1); shape[0]=0; 72 547 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 73 375 : if( getPntrToArgument(i)->getRank()>1 ) error("cannot concatenate matrix with vectors"); 74 375 : getPntrToArgument(i)->buildDataStore(); shape[0] += getPntrToArgument(i)->getNumberOfValues(); 75 : } 76 172 : log.printf(" creating vector with %d elements \n", shape[0] ); 77 172 : addValue( shape ); bool period=getPntrToArgument(0)->isPeriodic(); 78 172 : std::string min, max; if( period ) getPntrToArgument(0)->getDomain( min, max ); 79 375 : for(unsigned i=1; i<getNumberOfArguments(); ++i) { 80 203 : if( period!=getPntrToArgument(i)->isPeriodic() ) error("periods of input arguments should match"); 81 203 : if( period ) { 82 0 : std::string min0, max0; getPntrToArgument(i)->getDomain( min0, max0 ); 83 0 : if( min0!=min || max0!=max ) error("domains of input arguments should match"); 84 : } 85 : } 86 172 : if( period ) setPeriodic( min, max ); else setNotPeriodic(); 87 172 : getPntrToComponent(0)->buildDataStore(); 88 172 : if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 89 : } else { 90 4 : unsigned nrows=0, ncols=0; std::vector<Value*> arglist; vectors=false; 91 7 : for(unsigned i=1;; i++) { 92 : unsigned nt_cols=0; unsigned size_b4 = arglist.size(); 93 14 : for(unsigned j=1;; j++) { 94 25 : if( j==10 ) error("cannot combine more than 9 matrices"); 95 50 : std::vector<Value*> argn; parseArgumentList("MATRIX", i*10+j, argn); 96 25 : if( argn.size()==0 ) break; 97 14 : if( argn.size()>1 ) error("should only be one argument to each matrix keyword"); 98 14 : if( argn[0]->getRank()!=0 && argn[0]->getRank()!=2 ) error("input arguments for this action should be matrices"); 99 14 : argn[0]->buildDataStore(); arglist.push_back( argn[0] ); nt_cols++; 100 14 : if( argn[0]->getRank()==0 ) log.printf(" %d %d component of composed matrix is scalar labelled %s\n", i, j, argn[0]->getName().c_str() ); 101 14 : else log.printf(" %d %d component of composed matrix is %d by %d matrix labelled %s\n", i, j, argn[0]->getShape()[0], argn[0]->getShape()[1], argn[0]->getName().c_str() ); 102 14 : } 103 11 : if( arglist.size()==size_b4 ) break; 104 7 : if( i==1 ) ncols=nt_cols; 105 3 : else if( nt_cols!=ncols ) error("should be joining same number of matrices in each row"); 106 7 : nrows++; 107 7 : } 108 : 109 4 : std::vector<unsigned> shape(2); shape[0]=0; unsigned k=0; 110 4 : row_starts.resize( arglist.size() ); col_starts.resize( arglist.size() ); 111 11 : for(unsigned i=0; i<nrows; ++i) { 112 7 : unsigned cstart = 0, nr = 1; if( arglist[k]->getRank()==2 ) nr=arglist[k]->getShape()[0]; 113 21 : for(unsigned j=0; j<ncols; ++j) { 114 14 : if( arglist[k]->getRank()==0 ) { 115 0 : if( nr!=1 ) error("mismatched matrix sizes"); 116 14 : } else if( nrows>1 && arglist[k]->getShape()[0]!=nr ) error("mismatched matrix sizes"); 117 14 : row_starts[k] = shape[0]; col_starts[k] = cstart; 118 14 : if( arglist[k]->getRank()==0 ) cstart += 1; 119 14 : else cstart += arglist[k]->getShape()[1]; 120 14 : k++; 121 : } 122 7 : if( i==0 ) shape[1]=cstart; 123 3 : else if( cstart!=shape[1] ) error("mismatched matrix sizes"); 124 7 : if( arglist[k-1]->getRank()==0 ) shape[0] += 1; 125 7 : else shape[0] += arglist[k-1]->getShape()[0]; 126 : } 127 : // Now request the arguments to make sure we store things we need 128 4 : requestArguments(arglist); addValue( shape ); setNotPeriodic(); getPntrToComponent(0)->buildDataStore(); 129 4 : if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 130 : } 131 176 : } 132 : 133 12191 : void Concatenate::calculate() { 134 12191 : Value* myval = getPntrToComponent(0); 135 12191 : if( vectors ) { 136 : unsigned k=0; 137 61297 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 138 49158 : Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues(); 139 404266 : for(unsigned j=0; j<nvals; ++j) { myval->set( k, myarg->get(j) ); k++; } 140 : } 141 : } else { 142 : // Retrieve the matrix from input 143 52 : unsigned ncols = myval->getShape()[1]; 144 258 : for(unsigned k=0; k<getNumberOfArguments(); ++k) { 145 : Value* argn = getPntrToArgument(k); 146 206 : if( argn->getRank()==0 ) { 147 0 : myval->set( ncols*row_starts[k]+col_starts[k], argn->get() ); 148 : } else { 149 : std::vector<double> vals; std::vector<std::pair<unsigned,unsigned> > pairs; 150 : bool symmetric=getPntrToArgument(k)->isSymmetric(); 151 206 : unsigned nedge=0; getPntrToArgument(k)->retrieveEdgeList( nedge, pairs, vals ); 152 8946 : for(unsigned l=0; l<nedge; ++l ) { 153 8740 : unsigned i=pairs[l].first, j=pairs[l].second; 154 8740 : myval->set( ncols*(row_starts[k]+i)+col_starts[k]+j, vals[l] ); 155 8740 : if( symmetric ) myval->set( ncols*(row_starts[k]+j)+col_starts[k]+i, vals[l] ); 156 : } 157 : } 158 : } 159 : } 160 12191 : } 161 : 162 12070 : void Concatenate::apply() { 163 12070 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return; 164 : 165 5033 : Value* val=getPntrToComponent(0); 166 5033 : if( vectors ) { 167 : unsigned k=0; 168 19923 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 169 14942 : Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues(); 170 205938 : for(unsigned j=0; j<nvals; ++j) { myarg->addForce( j, val->getForce(k) ); k++; } 171 : } 172 : } else { 173 52 : unsigned ncols=val->getShape()[1]; 174 258 : for(unsigned k=0; k<getNumberOfArguments(); ++k) { 175 : Value* argn=getPntrToArgument(k); 176 206 : if( argn->getRank()==0 ) argn->addForce( 0, val->getForce(ncols*row_starts[k]+col_starts[k]) ); 177 : else { 178 : unsigned val_ncols=val->getNumberOfColumns(); 179 : unsigned arg_ncols=argn->getNumberOfColumns(); 180 1686 : for(unsigned i=0; i<argn->getShape()[0]; ++i) { 181 : unsigned ncol = argn->getRowLength(i); 182 13140 : for(unsigned j=0; j<ncol; ++j) argn->addForce( i*arg_ncols+j, val->getForce( val_ncols*(row_starts[k]+i)+col_starts[k]+argn->getRowIndex(i,j) ), false ); 183 : } 184 : } 185 : } 186 : } 187 : } 188 : 189 : } 190 : }