LCOV - code coverage report
Current view: top level - valtools - Concatenate.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 85 89 95.5 %
Date: 2024-10-18 13:59:31 Functions: 5 6 83.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2014-2017 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithValue.h"
      23             : #include "core/ActionWithArguments.h"
      24             : #include "core/ActionRegister.h"
      25             : 
      26             : //+PLUMEDOC MCOLVAR CONCATENATE
      27             : /*
      28             : Join vectors or matrices together
      29             : 
      30             : \par Examples
      31             : 
      32             : */
      33             : //+ENDPLUMEDOC
      34             : 
      35             : namespace PLMD {
      36             : namespace valtools {
      37             : 
      38             : class Concatenate :
      39             :   public ActionWithValue,
      40             :   public ActionWithArguments {
      41             : private:
      42             :   bool vectors;
      43             :   std::vector<unsigned> row_starts;
      44             :   std::vector<unsigned> col_starts;
      45             : public:
      46             :   static void registerKeywords( Keywords& keys );
      47             : /// Constructor
      48             :   explicit Concatenate(const ActionOptions&);
      49             : /// Get the number of derivatives
      50         257 :   unsigned getNumberOfDerivatives() override { return 0; }
      51             : /// Do the calculation
      52             :   void calculate() override;
      53             : ///
      54             :   void apply();
      55             : };
      56             : 
      57             : PLUMED_REGISTER_ACTION(Concatenate,"CONCATENATE")
      58             : 
      59         353 : void Concatenate::registerKeywords( Keywords& keys ) {
      60         353 :   Action::registerKeywords( keys ); ActionWithValue::registerKeywords( keys ); ActionWithArguments::registerKeywords( keys );
      61         706 :   keys.addInputKeyword("optional","ARG","scalar/vector","the values that should be concatenated together to form the output vector");
      62        1059 :   keys.addInputKeyword("numbered","MATRIX","scalar/matrix","specify the matrices that you wish to join together into a single matrix"); keys.reset_style("MATRIX","compulsory");
      63         706 :   keys.setValueDescription("vector/matrix","the concatenated vector/matrix that was constructed from the input values");
      64         353 : }
      65             : 
      66         176 : Concatenate::Concatenate(const ActionOptions& ao):
      67             :   Action(ao),
      68             :   ActionWithValue(ao),
      69         176 :   ActionWithArguments(ao)
      70             : {
      71         176 :   if( getNumberOfArguments()>0 ) {
      72         172 :     vectors=true; std::vector<unsigned> shape(1); shape[0]=0;
      73         547 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
      74         375 :       if( getPntrToArgument(i)->getRank()>1 ) error("cannot concatenate matrix with vectors");
      75         375 :       getPntrToArgument(i)->buildDataStore(); shape[0] += getPntrToArgument(i)->getNumberOfValues();
      76             :     }
      77         172 :     log.printf("  creating vector with %d elements \n", shape[0] );
      78         172 :     addValue( shape ); bool period=getPntrToArgument(0)->isPeriodic();
      79         172 :     std::string min, max; if( period ) getPntrToArgument(0)->getDomain( min, max );
      80         375 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
      81         203 :       if( period!=getPntrToArgument(i)->isPeriodic() ) error("periods of input arguments should match");
      82         203 :       if( period ) {
      83           0 :         std::string min0, max0; getPntrToArgument(i)->getDomain( min0, max0 );
      84           0 :         if( min0!=min || max0!=max ) error("domains of input arguments should match");
      85             :       }
      86             :     }
      87         172 :     if( period ) setPeriodic( min, max ); else setNotPeriodic();
      88         172 :     getPntrToComponent(0)->buildDataStore();
      89         172 :     if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
      90             :   } else {
      91           4 :     unsigned nrows=0, ncols=0; std::vector<Value*> arglist; vectors=false;
      92           7 :     for(unsigned i=1;; i++) {
      93             :       unsigned nt_cols=0; unsigned size_b4 = arglist.size();
      94          14 :       for(unsigned j=1;; j++) {
      95          25 :         if( j==10 ) error("cannot combine more than 9 matrices");
      96          50 :         std::vector<Value*> argn; parseArgumentList("MATRIX", i*10+j, argn);
      97          25 :         if( argn.size()==0 ) break;
      98          14 :         if( argn.size()>1 ) error("should only be one argument to each matrix keyword");
      99          14 :         if( argn[0]->getRank()!=0 && argn[0]->getRank()!=2 ) error("input arguments for this action should be matrices");
     100          14 :         argn[0]->buildDataStore(); arglist.push_back( argn[0] ); nt_cols++;
     101          14 :         if( argn[0]->getRank()==0 ) log.printf("  %d %d component of composed matrix is scalar labelled %s\n", i, j, argn[0]->getName().c_str() );
     102          14 :         else log.printf("  %d %d component of composed matrix is %d by %d matrix labelled %s\n", i, j, argn[0]->getShape()[0], argn[0]->getShape()[1], argn[0]->getName().c_str() );
     103          14 :       }
     104          11 :       if( arglist.size()==size_b4 ) break;
     105           7 :       if( i==1 ) ncols=nt_cols;
     106           3 :       else if( nt_cols!=ncols ) error("should be joining same number of matrices in each row");
     107           7 :       nrows++;
     108           7 :     }
     109             : 
     110           4 :     std::vector<unsigned> shape(2); shape[0]=0; unsigned k=0;
     111           4 :     row_starts.resize( arglist.size() ); col_starts.resize( arglist.size() );
     112          11 :     for(unsigned i=0; i<nrows; ++i) {
     113           7 :       unsigned cstart = 0, nr = 1; if( arglist[k]->getRank()==2 ) nr=arglist[k]->getShape()[0];
     114          21 :       for(unsigned j=0; j<ncols; ++j) {
     115          14 :         if( arglist[k]->getRank()==0 ) {
     116           0 :           if( nr!=1 ) error("mismatched matrix sizes");
     117          14 :         } else if( nrows>1 && arglist[k]->getShape()[0]!=nr ) error("mismatched matrix sizes");
     118          14 :         row_starts[k] = shape[0]; col_starts[k] = cstart;
     119          14 :         if( arglist[k]->getRank()==0 ) cstart += 1;
     120          14 :         else cstart += arglist[k]->getShape()[1];
     121          14 :         k++;
     122             :       }
     123           7 :       if( i==0 ) shape[1]=cstart;
     124           3 :       else if( cstart!=shape[1] ) error("mismatched matrix sizes");
     125           7 :       if( arglist[k-1]->getRank()==0 ) shape[0] += 1;
     126           7 :       else shape[0] += arglist[k-1]->getShape()[0];
     127             :     }
     128             :     // Now request the arguments to make sure we store things we need
     129           4 :     requestArguments(arglist); addValue( shape ); setNotPeriodic(); getPntrToComponent(0)->buildDataStore();
     130           4 :     if( getPntrToComponent(0)->getRank()==2 ) getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
     131             :   }
     132         176 : }
     133             : 
     134       12191 : void Concatenate::calculate() {
     135       12191 :   Value* myval = getPntrToComponent(0);
     136       12191 :   if( vectors ) {
     137             :     unsigned k=0;
     138       61297 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     139       49158 :       Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues();
     140      404266 :       for(unsigned j=0; j<nvals; ++j) { myval->set( k, myarg->get(j) ); k++; }
     141             :     }
     142             :   } else {
     143             :     // Retrieve the matrix from input
     144          52 :     unsigned ncols = myval->getShape()[1];
     145         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     146             :       Value* argn = getPntrToArgument(k);
     147         206 :       if( argn->getRank()==0 ) {
     148           0 :         myval->set( ncols*row_starts[k]+col_starts[k], argn->get() );
     149             :       } else {
     150             :         std::vector<double> vals; std::vector<std::pair<unsigned,unsigned> > pairs;
     151             :         bool symmetric=getPntrToArgument(k)->isSymmetric();
     152         206 :         unsigned nedge=0; getPntrToArgument(k)->retrieveEdgeList( nedge, pairs, vals );
     153        8946 :         for(unsigned l=0; l<nedge; ++l ) {
     154        8740 :           unsigned i=pairs[l].first, j=pairs[l].second;
     155        8740 :           myval->set( ncols*(row_starts[k]+i)+col_starts[k]+j, vals[l] );
     156        8740 :           if( symmetric ) myval->set( ncols*(row_starts[k]+j)+col_starts[k]+i, vals[l] );
     157             :         }
     158             :       }
     159             :     }
     160             :   }
     161       12191 : }
     162             : 
     163       12070 : void Concatenate::apply() {
     164       12070 :   if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) return;
     165             : 
     166        5033 :   Value* val=getPntrToComponent(0);
     167        5033 :   if( vectors ) {
     168             :     unsigned k=0;
     169       19923 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     170       14942 :       Value* myarg=getPntrToArgument(i); unsigned nvals=myarg->getNumberOfValues();
     171      205938 :       for(unsigned j=0; j<nvals; ++j) { myarg->addForce( j, val->getForce(k) ); k++; }
     172             :     }
     173             :   } else {
     174          52 :     unsigned ncols=val->getShape()[1];
     175         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     176             :       Value* argn=getPntrToArgument(k);
     177         206 :       if( argn->getRank()==0 ) argn->addForce( 0, val->getForce(ncols*row_starts[k]+col_starts[k]) );
     178             :       else {
     179             :         unsigned val_ncols=val->getNumberOfColumns();
     180             :         unsigned arg_ncols=argn->getNumberOfColumns();
     181        1686 :         for(unsigned i=0; i<argn->getShape()[0]; ++i) {
     182             :           unsigned ncol = argn->getRowLength(i);
     183       13140 :           for(unsigned j=0; j<ncol; ++j) argn->addForce( i*arg_ncols+j, val->getForce( val_ncols*(row_starts[k]+i)+col_starts[k]+argn->getRowIndex(i,j) ), false );
     184             :         }
     185             :       }
     186             :     }
     187             :   }
     188             : }
     189             : 
     190             : }
     191             : }

Generated by: LCOV version 1.16