LCOV - code coverage report
Current view: top level - valtools - Concatenate.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 84 88 95.5 %
Date: 2024-10-18 14:00:25 Functions: 5 6 83.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2014-2017 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithValue.h"
      23             : #include "core/ActionWithArguments.h"
      24             : #include "core/ActionRegister.h"
      25             : 
      26             : //+PLUMEDOC MCOLVAR CONCATENATE
      27             : /*
      28             : Join vectors together into a single vector, or join matrices together into a single matrix.
      29             : 
      30             : \par Examples
      31             : 
      32             : */
      33             : //+ENDPLUMEDOC
      34             : 
      35             : namespace PLMD {
      36             : namespace valtools {
      37             : 
      38             : class Concatenate : // Action that joins its input values into a single output vector or matrix
      39             :   public ActionWithValue,
      40             :   public ActionWithArguments {
      41             : private:
      42             :   bool vectors; // true: inputs are joined end-to-end into a vector; false: inputs are blocks of a composed matrix
      43             :   std::vector<unsigned> row_starts; // matrix mode: output row at which each input block starts
      44             :   std::vector<unsigned> col_starts; // matrix mode: output column at which each input block starts
      45             : public:
      46             :   static void registerKeywords( Keywords& keys ); // declares the ARG and numbered MATRIX keywords
      47             : /// Constructor: checks the inputs, works out the output shape and creates the output value
      48             :   explicit Concatenate(const ActionOptions&);
      49             : /// Get the number of derivatives (this action always reports zero)
      50         257 :   unsigned getNumberOfDerivatives() override { return 0; }
      51             : /// Do the calculation: copy the input values into the output vector/matrix
      52             :   void calculate() override;
      53             : /// Pass the forces acting on the output value back onto the input arguments
      54             :   void apply(); // NOTE(review): not marked override, unlike calculate() — presumably overrides a base-class virtual; confirm and add override
      55             : };
      56             : // Presumably makes the Concatenate class available under the directive name CONCATENATE (macro defined outside this file — confirm)
      57             : PLUMED_REGISTER_ACTION(Concatenate,"CONCATENATE")
      58             : // Declare the keywords understood by this action: those of the base classes, ARG, and the numbered MATRIX keyword.
      59         353 : void Concatenate::registerKeywords( Keywords& keys ) {
      60         353 :   Action::registerKeywords( keys ); ActionWithValue::registerKeywords( keys ); ActionWithArguments::registerKeywords( keys ); keys.use("ARG"); // base-class keywords first, then enable ARG
      61        1059 :   keys.add("numbered","MATRIX","specify the matrices that you wish to join together into a single matrix"); keys.reset_style("MATRIX","compulsory"); // the constructor reads MATRIX with the number 10*row+column
      62         353 :   keys.setValueDescription("the concatenated vector/matrix that was constructed from the input values");
      63         353 : }
      64             : // Constructor: if arguments are already present they are joined into a vector; otherwise the numbered MATRIX keywords are read and joined into a matrix.
      65         176 : Concatenate::Concatenate(const ActionOptions& ao):
      66             :   Action(ao),
      67             :   ActionWithValue(ao),
      68         176 :   ActionWithArguments(ao)
      69             : {
      70         176 :   if( getNumberOfArguments()>0 ) { // vector mode (arguments presumably come from ARG — confirm)
      71         172 :     vectors=true; std::vector<unsigned> shape(1); shape[0]=0; // shape[0] accumulates the total length of the output vector
      72         547 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
      73         375 :       if( getPntrToArgument(i)->getRank()>1 ) error("cannot concatenate matrix with vectors"); // only inputs of rank 0 or 1 are accepted here
      74         375 :       getPntrToArgument(i)->buildDataStore(); shape[0] += getPntrToArgument(i)->getNumberOfValues(); // add this input's number of values to the output length
      75             :     }
      76         172 :     log.printf("  creating vector with %d elements \n", shape[0] ); // NOTE(review): shape[0] is unsigned but printed with %d — consider %u
      77         172 :     addValue( shape ); bool period=getPntrToArgument(0)->isPeriodic(); // create the output value; the first input is the reference for periodicity
      78         172 :     std::string min, max; if( period ) getPntrToArgument(0)->getDomain( min, max ); // domain of the first input, only fetched when it is periodic
      79         375 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
      80         203 :       if( period!=getPntrToArgument(i)->isPeriodic() ) error("periods of input arguments should match"); // all inputs must agree with the first on being periodic or not
      81         203 :       if( period ) {
      82           0 :         std::string min0, max0; getPntrToArgument(i)->getDomain( min0, max0 );
      83           0 :         if( min0!=min || max0!=max ) error("domains of input arguments should match"); // domains are compared as strings
      84             :       }
      85             :     }
      86         172 :     if( period ) setPeriodic( min, max ); else setNotPeriodic(); // the output takes the periodicity shared by the inputs
      87         172 :     getPntrToComponent(0)->buildDataStore();
      88         172 :     if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); // NOTE(review): shape has one element in this branch, so shape[1] would be out of range if the rank were ever 2 — presumably never taken; confirm
      89             :   } else { // matrix mode: no arguments yet, so read the numbered MATRIX keywords
      90           4 :     unsigned nrows=0, ncols=0; std::vector<Value*> arglist; vectors=false; // nrows/ncols count blocks, not elements; arglist collects the blocks row by row
      91           7 :     for(unsigned i=1;; i++) { // i is the block row, counted from 1; the loop ends at the first row with no blocks
      92             :       unsigned nt_cols=0; unsigned size_b4 = arglist.size(); // blocks found in this row, and the list size before reading it
      93          14 :       for(unsigned j=1;; j++) { // j is the block column, counted from 1
      94          25 :         if( j==10 ) error("cannot combine more than 9 matrices"); // presumably because the keyword number i*10+j leaves a single digit for the column
      95          50 :         std::vector<Value*> argn; parseArgumentList("MATRIX", i*10+j, argn); // read the MATRIX keyword numbered i*10+j
      96          25 :         if( argn.size()==0 ) break; // nothing read for this number: end of the row
      97          14 :         if( argn.size()>1 ) error("should only be one argument to each matrix keyword");
      98          14 :         if( argn[0]->getRank()!=0 && argn[0]->getRank()!=2 ) error("input arguments for this action should be matrices"); // scalars (rank 0) and matrices (rank 2) are both accepted
      99          14 :         argn[0]->buildDataStore(); arglist.push_back( argn[0] ); nt_cols++;
     100          14 :         if( argn[0]->getRank()==0 ) log.printf("  %d %d component of composed matrix is scalar labelled %s\n", i, j, argn[0]->getName().c_str() ); // NOTE(review): i and j are unsigned but printed with %d
     101          14 :         else log.printf("  %d %d component of composed matrix is %d by %d matrix labelled %s\n", i, j, argn[0]->getShape()[0], argn[0]->getShape()[1], argn[0]->getName().c_str() ); // NOTE(review): i and j are unsigned but printed with %d
     102          14 :       }
     103          11 :       if( arglist.size()==size_b4 ) break; // this row added nothing, so there are no more rows
     104           7 :       if( i==1 ) ncols=nt_cols; // the first row fixes the number of blocks per row
     105           3 :       else if( nt_cols!=ncols ) error("should be joining same number of matrices in each row");
     106           7 :       nrows++;
     107           7 :     }
     108             : 
     109           4 :     std::vector<unsigned> shape(2); shape[0]=0; unsigned k=0; // shape = {rows, columns} of the output in elements; k indexes arglist
     110           4 :     row_starts.resize( arglist.size() ); col_starts.resize( arglist.size() );
     111          11 :     for(unsigned i=0; i<nrows; ++i) {
     112           7 :       unsigned cstart = 0, nr = 1; if( arglist[k]->getRank()==2 ) nr=arglist[k]->getShape()[0]; // cstart = running column offset; nr = height of the first block in this row (1 for a scalar)
     113          21 :       for(unsigned j=0; j<ncols; ++j) {
     114          14 :         if( arglist[k]->getRank()==0 ) {
     115           0 :           if( nr!=1 ) error("mismatched matrix sizes"); // a scalar only fits in a row of height 1
     116          14 :         } else if( nrows>1 && arglist[k]->getShape()[0]!=nr ) error("mismatched matrix sizes"); // NOTE(review): this height check is skipped when nrows==1 — confirm that is intended
     117          14 :         row_starts[k] = shape[0]; col_starts[k] = cstart; // top-left output position of block k
     118          14 :         if( arglist[k]->getRank()==0 ) cstart += 1;
     119          14 :         else cstart += arglist[k]->getShape()[1];
     120          14 :         k++;
     121             :       }
     122           7 :       if( i==0 ) shape[1]=cstart; // the first row fixes the total number of output columns
     123           3 :       else if( cstart!=shape[1] ) error("mismatched matrix sizes"); // later rows must have the same total width
     124           7 :       if( arglist[k-1]->getRank()==0 ) shape[0] += 1; // advance the row offset by the height of the last block in this row
     125           7 :       else shape[0] += arglist[k-1]->getShape()[0];
     126             :     }
     127             :     // Now request the arguments to make sure we store things we need
     128           4 :     requestArguments(arglist); addValue( shape ); setNotPeriodic(); getPntrToComponent(0)->buildDataStore(); // the collected blocks become this action's arguments; the output matrix is not periodic
     129           4 :     if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); // shape[1] is the output column count in this branch
     130             :   }
     131         176 : }
     132             : // Copy the current values of the input arguments into component 0, the output vector or matrix.
     133       12191 : void Concatenate::calculate() {
     134       12191 :   Value* myval = getPntrToComponent(0);
     135       12191 :   if( vectors ) { // vector mode: the arguments' values are written one after another
     136             :     unsigned k=0; // next output index to write
     137       61297 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     138       49158 :       Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues();
     139      404266 :       for(unsigned j=0; j<nvals; ++j) { myval->set( k, myarg->get(j) ); k++; }
     140             :     }
     141             :   } else {
     142             :     // Matrix mode: copy each input block to its offset position in the output matrix
     143          52 :     unsigned ncols = myval->getShape()[1]; // output column count, used as the row stride of the flat index
     144         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     145             :       Value* argn = getPntrToArgument(k);
     146         206 :       if( argn->getRank()==0 ) {
     147           0 :         myval->set( ncols*row_starts[k]+col_starts[k], argn->get() ); // a scalar block fills the single element at (row_starts[k], col_starts[k])
     148             :       } else {
     149             :         std::vector<double> vals; std::vector<std::pair<unsigned,unsigned> > pairs;
     150             :         bool symmetric=getPntrToArgument(k)->isSymmetric();
     151         206 :         unsigned nedge=0; getPntrToArgument(k)->retrieveEdgeList( nedge, pairs, vals ); // fills nedge, the (row,column) pairs and the matching values for this block
     152        8946 :         for(unsigned l=0; l<nedge; ++l ) {
     153        8740 :           unsigned i=pairs[l].first, j=pairs[l].second;
     154        8740 :           myval->set( ncols*(row_starts[k]+i)+col_starts[k]+j, vals[l] ); // element (i,j) of the block, shifted by the block's offset
     155        8740 :           if( symmetric ) myval->set( ncols*(row_starts[k]+j)+col_starts[k]+i, vals[l] ); // for a symmetric block also write the transposed element (j,i)
     156             :         }
     157             :       }
     158             :     }
     159             :   }
     160       12191 : }
     161             : // Transfer the forces acting on component 0 (the output) back onto the input arguments they were copied from.
     162       12070 : void Concatenate::apply() {
     163       12070 :   if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return; // nothing to do when derivatives are off or the output carries no forces
     164             : 
     165        5033 :   Value* val=getPntrToComponent(0);
     166        5033 :   if( vectors ) { // vector mode: walk the output in the same order calculate() filled it
     167             :     unsigned k=0; // index of the output element whose force is read next
     168       19923 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     169       14942 :       Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues();
     170      205938 :       for(unsigned j=0; j<nvals; ++j) { myarg->addForce( j, val->getForce(k) ); k++; }
     171             :     }
     172             :   } else {
     173          52 :     unsigned ncols=val->getShape()[1]; // output column count, used for the scalar blocks below
     174         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     175             :       Value* argn=getPntrToArgument(k);
     176         206 :       if( argn->getRank()==0 ) argn->addForce( 0, val->getForce(ncols*row_starts[k]+col_starts[k]) ); // a scalar block takes the force on its single output element
     177             :       else {
     178             :         unsigned val_ncols=val->getNumberOfColumns(); // NOTE(review): calculate() uses getShape()[1] as the output stride while this uses getNumberOfColumns() — confirm the two agree
     179             :         unsigned arg_ncols=argn->getNumberOfColumns();
     180        1686 :         for(unsigned i=0; i<argn->getShape()[0]; ++i) { // loop over the rows of this block
     181             :           unsigned ncol = argn->getRowLength(i); // presumably the number of stored elements in row i — confirm
     182       13140 :           for(unsigned j=0; j<ncol; ++j) argn->addForce( i*arg_ncols+j, val->getForce( val_ncols*(row_starts[k]+i)+col_starts[k]+argn->getRowIndex(i,j) ), false ); // getRowIndex(i,j) supplies the column within the block for the j-th entry of row i
     183             :         }
     184             :       }
     185             :     }
     186             :   }
     187             : }
     188             : 
     189             : }
     190             : }

Generated by: LCOV version 1.16