LCOV - code coverage report
Current view: top level - valtools - Concatenate.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 113 134 84.3 %
Date: 2025-04-08 21:11:17 Functions: 5 6 83.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2014-2017 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithValue.h"
      23             : #include "core/ActionWithArguments.h"
      24             : #include "core/ActionRegister.h"
      25             : 
      26             : //+PLUMEDOC MCOLVAR CONCATENATE
      27             : /*
      28             : Join vectors or matrices together
      29             : 
      30             : This action can be used to join sets of scalars, vectors or matrices together into a single output value.  The following
      31             : example shows how this works with vectors and scalars:
      32             : 
      33             : ```plumed
      34             : d1: DISTANCE ATOMS=1,2
      35             : d2: DISTANCE ATOMS1=3,4 ATOMS2=5,6
      36             : v: CONCATENATE ARG=d1,d2
      37             : ```
      38             : 
      39             : The output vector here, `v`, has three components.  The first of these is the distance between atoms 1 and 2 and the second
      40             : and third are the distances between atoms 3 and 4 and 5 and 6 respectively.
      41             : 
      42             : Concatenating two or more vectors is similarly straightforward. Concatenating matrices is a little more difficult as the
      43             : user must explain the axes that each matrix should be joined to. The following input should give some idea as to how
      44             : the joining of matrices is done in practice.  In this example, we define two groups of atoms and calculate the contact
      45             : matrix between all the atoms in the two groups by first calculating three contact matrices that describe the inter and intra
      46             : group contacts.
      47             : 
      48             : ```plumed
      49             : # Define two groups of atoms
      50             : g: GROUP ATOMS=1-5
      51             : h: GROUP ATOMS=6-20
      52             : 
      53             : # Calculate the CONTACT_MATRIX for the atoms in group g
      54             : cg: CONTACT_MATRIX GROUP=g SWITCH={RATIONAL R_0=0.1}
      55             : 
      56             : # Calculate the CONTACT_MATRIX for the atoms in group h
      57             : ch: CONTACT_MATRIX GROUP=h SWITCH={RATIONAL R_0=0.2}
      58             : 
      59             : # Calculate the CONTACT_MATRIX between the atoms in group g and group h
      60             : cgh: CONTACT_MATRIX GROUPA=g GROUPB=h SWITCH={RATIONAL R_0=0.15}
      61             : 
      62             : # Now calculate the contact matrix between the atoms in group h and group g
      63             : # Notice this is just the transpose of cgh
      64             : cghT: TRANSPOSE ARG=cgh
      65             : 
      66             : # And concatenate the matrices together to construct the adjacency matrix between
      67             : # all the atoms in the two groups
      68             : m: CONCATENATE ...
      69             :  MATRIX11=cg MATRIX12=cgh
      70             :  MATRIX21=cghT MATRIX22=ch
      71             : ...
      72             : ```
      73             : 
      74             : */
      75             : //+ENDPLUMEDOC
      76             : 
      77             : namespace PLMD {
      78             : namespace valtools {
      79             : 
      80             : class Concatenate :
      81             :   public ActionWithValue,
      82             :   public ActionWithArguments {
      83             : private:
      84             :   bool vectors; // true when the constructor found arguments and builds a rank-1 output; false when the numbered MATRIX keywords are used
      85             :   std::vector<unsigned> row_starts; // matrix mode only: first output row occupied by each input block (indexed like the arguments)
      86             :   std::vector<unsigned> col_starts; // matrix mode only: first output column occupied by each input block (indexed like the arguments)
      87             : public:
      88             :   static void registerKeywords( Keywords& keys );
      89             : /// Constructor: works out the shape of the output from the input values and creates the output value
      90             :   explicit Concatenate(const ActionOptions&);
      91             : /// Get the number of derivatives (this action always reports zero)
      92         257 :   unsigned getNumberOfDerivatives() override {
      93         257 :     return 0;
      94             :   }
      95             : /// Do the calculation: copy the input values into the output vector/matrix
      96             :   void calculate() override;
      97             : /// Transfer the forces acting on the output value back onto the input values
      98             :   void apply(); // NOTE(review): not marked override, unlike calculate() - confirm against the base class declaration
      99             : };
     100             : 
     101             : PLUMED_REGISTER_ACTION(Concatenate,"CONCATENATE") // presumably makes this class available under the CONCATENATE keyword used in the examples above - macro is defined elsewhere
     102             : 
     103         353 : void Concatenate::registerKeywords( Keywords& keys ) { // Declare the keywords understood by CONCATENATE and describe its output value
     104         353 :   Action::registerKeywords( keys ); // start from the keywords registered by the three base classes
     105         353 :   ActionWithValue::registerKeywords( keys );
     106         353 :   ActionWithArguments::registerKeywords( keys );
     107         706 :   keys.addInputKeyword("optional","ARG","scalar/vector","the values that should be concatenated together to form the output vector"); // input for the vector mode
     108         706 :   keys.addInputKeyword("numbered","MATRIX","scalar/matrix","specify the matrices that you wish to join together into a single matrix"); // input for the matrix mode: MATRIX11, MATRIX12, ... as in the documentation example
     109         706 :   keys.reset_style("MATRIX","compulsory"); // NOTE(review): MATRIX is re-styled as compulsory although the constructor also accepts ARG-only input - confirm what reset_style implies here
     110         706 :   keys.setValueDescription("vector/matrix","the concatenated vector/matrix that was constructed from the input values");
     111         353 : }
     112             : 
     113         176 : Concatenate::Concatenate(const ActionOptions& ao): // Set up either a concatenated vector (arguments already present) or a block matrix (numbered MATRIX keywords)
     114             :   Action(ao),
     115             :   ActionWithValue(ao),
     116         176 :   ActionWithArguments(ao) {
     117         176 :   if( getNumberOfArguments()>0 ) { // vector mode: arguments are already available here, presumably read from ARG by ActionWithArguments
     118         172 :     vectors=true;
     119         172 :     std::vector<unsigned> shape(1);
     120         172 :     shape[0]=0;
     121         547 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     122         375 :       if( getPntrToArgument(i)->getRank()>1 ) { // only scalars and vectors can be joined in this mode
     123           0 :         error("cannot concatenate matrix with vectors");
     124             :       }
     125         375 :       getPntrToArgument(i)->buildDataStore();
     126         375 :       shape[0] += getPntrToArgument(i)->getNumberOfValues(); // output length is the total number of values over all arguments
     127             :     }
     128         172 :     log.printf("  creating vector with %d elements \n", shape[0] ); // NOTE(review): %d is used with an unsigned value
     129         172 :     addValue( shape );
     130         172 :     bool period=getPntrToArgument(0)->isPeriodic(); // periodicity of the output is taken from the first argument; all others must agree with it
     131             :     std::string min, max;
     132         172 :     if( period ) {
     133           0 :       getPntrToArgument(0)->getDomain( min, max );
     134             :     }
     135         375 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
     136         203 :       if( period!=getPntrToArgument(i)->isPeriodic() ) {
     137           0 :         error("periods of input arguments should match");
     138             :       }
     139         203 :       if( period ) {
     140             :         std::string min0, max0;
     141           0 :         getPntrToArgument(i)->getDomain( min0, max0 );
     142           0 :         if( min0!=min || max0!=max ) { // domains are compared as strings
     143           0 :           error("domains of input arguments should match");
     144             :         }
     145             :       }
     146             :     }
     147         172 :     if( period ) {
     148           0 :       setPeriodic( min, max );
     149             :     } else {
     150         172 :       setNotPeriodic();
     151             :     }
     152         172 :     getPntrToComponent(0)->buildDataStore();
     153         172 :     if( getPntrToComponent(0)->getRank()==2 ) { // NOTE(review): shape has a single element in this branch, so shape[1] below would be out of range if this were ever taken
     154           0 :       getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
     155             :     }
     156             :   } else { // matrix mode: read the blocks from the numbered MATRIX keywords
     157             :     unsigned nrows=0, ncols=0; // number of block rows / block columns (counted in blocks, not in matrix elements)
     158             :     std::vector<Value*> arglist; // blocks in row-major order
     159           4 :     vectors=false;
     160           7 :     for(unsigned i=1;; i++) { // i is the block-row index; loop ends at the first row that contains no blocks
     161             :       unsigned nt_cols=0; // number of blocks found in this row
     162             :       unsigned size_b4 = arglist.size();
     163          14 :       for(unsigned j=1;; j++) { // j is the block-column index; loop ends at the first missing keyword
     164          25 :         if( j==10 ) { // keyword number is i*10+j, so j must stay a single digit
     165           0 :           error("cannot combine more than 9 matrices");
     166             :         }
     167             :         std::vector<Value*> argn;
     168          50 :         parseArgumentList("MATRIX", i*10+j, argn); // reads MATRIX<i><j>, e.g. MATRIX11, MATRIX12, MATRIX21
     169          25 :         if( argn.size()==0 ) {
     170             :           break;
     171             :         }
     172          14 :         if( argn.size()>1 ) {
     173           0 :           error("should only be one argument to each matrix keyword");
     174             :         }
     175          14 :         if( argn[0]->getRank()!=0 && argn[0]->getRank()!=2 ) { // each block must be a scalar (rank 0) or a matrix (rank 2)
     176           0 :           error("input arguments for this action should be matrices");
     177             :         }
     178          14 :         argn[0]->buildDataStore();
     179          14 :         arglist.push_back( argn[0] );
     180          14 :         nt_cols++;
     181          14 :         if( argn[0]->getRank()==0 ) {
     182           0 :           log.printf("  %d %d component of composed matrix is scalar labelled %s\n", i, j, argn[0]->getName().c_str() ); // NOTE(review): %d is used with unsigned values here and below
     183             :         } else {
     184          14 :           log.printf("  %d %d component of composed matrix is %d by %d matrix labelled %s\n", i, j, argn[0]->getShape()[0], argn[0]->getShape()[1], argn[0]->getName().c_str() );
     185             :         }
     186          14 :       }
     187          11 :       if( arglist.size()==size_b4 ) { // nothing was added for this row, so there are no more rows
     188             :         break;
     189             :       }
     190           7 :       if( i==1 ) {
     191             :         ncols=nt_cols; // the first row fixes the number of blocks per row
     192           3 :       } else if( nt_cols!=ncols ) {
     193           0 :         error("should be joining same number of matrices in each row");
     194             :       }
     195           7 :       nrows++;
     196           7 :     }
     197             : 
     198           4 :     std::vector<unsigned> shape(2); // shape[0], shape[1]: number of rows and columns of the output matrix
     199           4 :     shape[0]=0;
     200             :     unsigned k=0; // running index into arglist
     201           4 :     row_starts.resize( arglist.size() );
     202           4 :     col_starts.resize( arglist.size() );
     203          11 :     for(unsigned i=0; i<nrows; ++i) {
     204             :       unsigned cstart = 0, nr = 1; // cstart: column offset of the next block; nr: row count expected in this block row (1 if the first block is a scalar)
     205           7 :       if( arglist[k]->getRank()==2 ) {
     206           7 :         nr=arglist[k]->getShape()[0];
     207             :       }
     208          21 :       for(unsigned j=0; j<ncols; ++j) {
     209          14 :         if( arglist[k]->getRank()==0 ) {
     210           0 :           if( nr!=1 ) {
     211           0 :             error("mismatched matrix sizes");
     212             :           }
     213          14 :         } else if( nrows>1 && arglist[k]->getShape()[0]!=nr ) { // NOTE(review): the row-count check is skipped when there is only one block row - confirm this is intended
     214           0 :           error("mismatched matrix sizes");
     215             :         }
     216          14 :         row_starts[k] = shape[0]; // shape[0] currently holds the number of rows accumulated from previous block rows
     217          14 :         col_starts[k] = cstart;
     218          14 :         if( arglist[k]->getRank()==0 ) {
     219           0 :           cstart += 1; // a scalar block occupies a single column
     220             :         } else {
     221          14 :           cstart += arglist[k]->getShape()[1];
     222             :         }
     223          14 :         k++;
     224             :       }
     225           7 :       if( i==0 ) {
     226           4 :         shape[1]=cstart; // the first block row fixes the total number of columns
     227           3 :       } else if( cstart!=shape[1] ) {
     228           0 :         error("mismatched matrix sizes");
     229             :       }
     230           7 :       if( arglist[k-1]->getRank()==0 ) { // advance the row offset using the last block of this block row
     231           0 :         shape[0] += 1;
     232             :       } else {
     233           7 :         shape[0] += arglist[k-1]->getShape()[0];
     234             :       }
     235             :     }
     236             :     // Now request the arguments to make sure we store things we need
     237           4 :     requestArguments(arglist);
     238           4 :     addValue( shape );
     239           4 :     setNotPeriodic();
     240           4 :     getPntrToComponent(0)->buildDataStore();
     241           4 :     if( getPntrToComponent(0)->getRank()==2 ) {
     242           4 :       getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
     243             :     }
     244             :   }
     245         176 : }
     246             : 
     247       12191 : void Concatenate::calculate() { // Copy the current values of all input arguments into the output value
     248       12191 :   Value* myval = getPntrToComponent(0);
     249       12191 :   if( vectors ) { // vector mode: arguments are written one after another into the output
     250             :     unsigned k=0; // position in the output vector
     251       61297 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     252             :       Value* myarg=getPntrToArgument(i);
     253       49158 :       unsigned nvals=myarg->getNumberOfValues();
     254      404266 :       for(unsigned j=0; j<nvals; ++j) {
     255      355108 :         myval->set( k, myarg->get(j) );
     256      355108 :         k++;
     257             :       }
     258             :     }
     259             :   } else {
     260             :     // Retrieve the matrix from input
     261          52 :     unsigned ncols = myval->getShape()[1]; // number of columns of the output, used to flatten (row,column) into a single index
     262         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     263             :       Value* argn = getPntrToArgument(k);
     264         206 :       if( argn->getRank()==0 ) { // scalar block: one element at the block's row/column offset
     265           0 :         myval->set( ncols*row_starts[k]+col_starts[k], argn->get() );
     266             :       } else {
     267             :         std::vector<double> vals;
     268             :         std::vector<std::pair<unsigned,unsigned> > pairs;
     269             :         bool symmetric=getPntrToArgument(k)->isSymmetric();
     270         206 :         unsigned nedge=0;
     271         206 :         getPntrToArgument(k)->retrieveEdgeList( nedge, pairs, vals ); // fills nedge (row,column) pairs and their values; presumably only the stored elements of the block - confirm in Value
     272        8946 :         for(unsigned l=0; l<nedge; ++l ) {
     273        8740 :           unsigned i=pairs[l].first, j=pairs[l].second;
     274        8740 :           myval->set( ncols*(row_starts[k]+i)+col_starts[k]+j, vals[l] ); // element (i,j) of the block shifted by the block's offsets
     275        8740 :           if( symmetric ) { // for a symmetric block the transposed element (j,i) is written as well
     276        2142 :             myval->set( ncols*(row_starts[k]+j)+col_starts[k]+i, vals[l] );
     277             :           }
     278             :         }
     279             :       }
     280             :     }
     281             :   }
     282       12191 : }
     283             : 
     284       12070 : void Concatenate::apply() { // Pass the forces on the output value back to the corresponding elements of the input arguments
     285       12070 :   if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) { // nothing to do without derivatives or without forces on the output
     286        7037 :     return;
     287             :   }
     288             : 
     289        5033 :   Value* val=getPntrToComponent(0);
     290        5033 :   if( vectors ) { // vector mode: walk the output in the same order in which calculate() filled it
     291             :     unsigned k=0; // position in the output vector
     292       19923 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     293             :       Value* myarg=getPntrToArgument(i);
     294       14942 :       unsigned nvals=myarg->getNumberOfValues();
     295      205938 :       for(unsigned j=0; j<nvals; ++j) {
     296      190996 :         myarg->addForce( j, val->getForce(k) );
     297      190996 :         k++;
     298             :       }
     299             :     }
     300             :   } else {
     301          52 :     unsigned ncols=val->getShape()[1];
     302         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     303             :       Value* argn=getPntrToArgument(k);
     304         206 :       if( argn->getRank()==0 ) { // scalar block: single force taken from the block's row/column offset
     305           0 :         argn->addForce( 0, val->getForce(ncols*row_starts[k]+col_starts[k]) );
     306             :       } else {
     307             :         unsigned val_ncols=val->getNumberOfColumns(); // NOTE(review): the scalar branch indexes with getShape()[1] while this branch uses getNumberOfColumns() - confirm both agree for the output
     308             :         unsigned arg_ncols=argn->getNumberOfColumns();
     309        1686 :         for(unsigned i=0; i<argn->getShape()[0]; ++i) {
     310             :           unsigned ncol = argn->getRowLength(i); // number of entries to visit in row i of the block
     311       13140 :           for(unsigned j=0; j<ncol; ++j) {
     312       11660 :             argn->addForce( i*arg_ncols+j, val->getForce( val_ncols*(row_starts[k]+i)+col_starts[k]+argn->getRowIndex(i,j) ), false ); // getRowIndex(i,j) supplies the column used in the output; meaning of the trailing false is defined in Value::addForce - TODO confirm
     313             :           }
     314             :         }
     315             :       }
     316             :     }
     317             :   }
     318             : }
     319             : 
     320             : }
     321             : }

Generated by: LCOV version 1.16