LCOV - code coverage report
Current view: top level - landmarks - LandmarkSelection.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 68 68 100.0 %
Date: 2024-10-18 13:59:31 Functions: 2 3 66.7 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2015-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionShortcut.h"
      23             : #include "core/ActionRegister.h"
      24             : #include "core/ActionWithValue.h"
      25             : #include "core/ActionPilot.h"
      26             : #include "core/PlumedMain.h"
      27             : #include "core/ActionSet.h"
      28             : 
      29             : //+PLUMEDOC LANDMARKS LANDMARK_SELECT_STRIDE
      30             : /*
      31             : Select every ith frame from the stored data
      32             : 
      33             : \par Examples
      34             : 
      35             : */
      36             : //+ENDPLUMEDOC
      37             : 
      38             : //+PLUMEDOC LANDMARKS LANDMARK_SELECT_RANDOM
      39             : /*
      40             : Select a random set of landmarks from a large set of configurations.
      41             : 
      42             : \par Examples
      43             : 
      44             : */
      45             : //+ENDPLUMEDOC
      46             : 
      47             : //+PLUMEDOC LANDMARKS LANDMARK_SELECT_FPS
      48             : /*
      49             : Select a set of landmarks from a large set of configurations using farthest point sampling.
      50             : 
      51             : \par Examples
      52             : 
      53             : */
      54             : //+ENDPLUMEDOC
      55             : 
      56             : namespace PLMD {
      57             : namespace landmarks {
      58             : 
      59             : class LandmarkSelection : public ActionShortcut {
      60             : public:
      61             :   static void registerKeywords( Keywords& keys );
      62             :   explicit LandmarkSelection( const ActionOptions& ao );
      63             : };
      64             : 
      65             : PLUMED_REGISTER_ACTION(LandmarkSelection,"LANDMARK_SELECT_STRIDE")
      66             : PLUMED_REGISTER_ACTION(LandmarkSelection,"LANDMARK_SELECT_RANDOM")
      67             : PLUMED_REGISTER_ACTION(LandmarkSelection,"LANDMARK_SELECT_FPS")
      68             : 
      69          17 : void LandmarkSelection::registerKeywords( Keywords& keys ) {
      70          17 :   ActionShortcut::registerKeywords( keys );
      71          34 :   keys.add("optional","ARG","the COLLECT_FRAMES action that you used to get the data");
      72          34 :   keys.add("optional","DISSIMILARITIES","the matrix of dissimilarities if this is not provided the squared dissimilarities are calculated");
      73          34 :   keys.add("compulsory","NLANDMARKS","the numbe rof landmarks you would like to create");
      74          34 :   keys.add("optional","SEED","a random number seed");
      75          34 :   keys.addFlag("NOVORONOI",false,"do not do a Voronoi analysis of the data to determine weights of final points");
      76          34 :   keys.addFlag("NODISSIMILARITIES",false,"do not calculate the dissimilarities");
      77          34 :   keys.addOutputComponent("data","ARG","matrix","the data that is being collected by this action");
      78          34 :   keys.addOutputComponent("logweights","ARG","vector","the logarithms of the weights of the data points");
      79          34 :   keys.addOutputComponent("rectdissims","DISSIMILARITIES","matrix","a rectangular matrix containing the distances between the landmark points and the rest of the points");
      80          34 :   keys.addOutputComponent("sqrdissims","DISSIMILARITIES","matrix","a square matrix containing the distances between each pair of landmark points");
      81          51 :   keys.needsAction("LOGSUMEXP"); keys.needsAction("TRANSPOSE"); keys.needsAction("DISSIMILARITIES");
      82          51 :   keys.needsAction("ONES"); keys.needsAction("CREATE_MASK"); keys.needsAction("FARTHEST_POINT_SAMPLING");
      83          51 :   keys.needsAction("SELECT_WITH_MASK"); keys.needsAction("COMBINE"); keys.needsAction("VORONOI");
      84          34 :   keys.needsAction("MATRIX_PRODUCT"); keys.needsAction("CUSTOM");
      85          17 : }
      86             : 
      87           8 : LandmarkSelection::LandmarkSelection( const ActionOptions& ao ):
      88             :   Action(ao),
      89           8 :   ActionShortcut(ao)
      90             : {
      91          16 :   std::string nlandmarks; parse("NLANDMARKS",nlandmarks); bool novoronoi; parseFlag("NOVORONOI",novoronoi);
      92             : 
      93          16 :   bool nodissims; parseFlag("NODISSIMILARITIES",nodissims);
      94          24 :   std::string argn, dissims; parse("ARG",argn); parse("DISSIMILARITIES",dissims);
      95           8 :   if( argn.length()>0 ) {
      96           7 :     ActionShortcut* as = plumed.getActionSet().getShortcutActionWithLabel( argn );
      97           7 :     if( !as || as->getName()!="COLLECT_FRAMES" ) error("found no COLLECT_FRAMES action with label " + argn );
      98             :     // Get the weights
      99          14 :     readInputLine( getShortcutLabel() + "_allweights: LOGSUMEXP ARG=" + argn + "_logweights");
     100             :   }
     101           8 :   if( dissims.length()>0 ) {
     102           4 :     ActionWithValue* ds = plumed.getActionSet().selectWithLabel<ActionWithValue*>( dissims );
     103           4 :     if( (ds->copyOutput(0))->getRank()!=2 ) error("input for dissimilarities shoudl be a matrix");
     104             :     // Calculate the dissimilarities if the user didn't specify them
     105           4 :   } else if( !nodissims ) {
     106           2 :     readInputLine( getShortcutLabel() + "_" + argn + "_dataT: TRANSPOSE ARG=" + argn + "_data"); dissims = getShortcutLabel() + "_dissims";
     107           2 :     readInputLine( getShortcutLabel() + "_dissims: DISSIMILARITIES SQUARED ARG=" + argn + "_data," + getShortcutLabel() + "_" + argn + "_dataT");
     108             :   }
     109             :   // This deals with a corner case whereby users have a matrix of dissimilarities but no corresponding coordinates for these frames
     110           8 :   if( argn.length()==0 && dissims.size()>0 ) {
     111           1 :     ActionWithValue* ds = plumed.getActionSet().selectWithLabel<ActionWithValue*>( dissims );
     112           1 :     if( ds->getName()!="CONSTANT" || (ds->copyOutput(0))->getRank()!=2 ) error("set ARG as well as DISSIMILARITIES");
     113           1 :     std::string size; Tools::convert(  (ds->copyOutput(0))->getShape()[0], size );
     114           2 :     readInputLine( getShortcutLabel() + "_allweights: ONES SIZE=" + size );
     115             :   }
     116             : 
     117           8 :   if( getName()=="LANDMARK_SELECT_STRIDE" ) {
     118          12 :     readInputLine( getShortcutLabel() + "_mask: CREATE_MASK ARG=" + getShortcutLabel() + "_allweights TYPE=stride NZEROS=" + nlandmarks );
     119           2 :   } else if( getName()=="LANDMARK_SELECT_RANDOM" ) {
     120           1 :     if( argn.length()==0 ) error("must set COLLECT_FRAMES object for landmark selection using ARG keyword");
     121           3 :     std::string seed; parse("SEED",seed); if( seed.length()>0 ) seed = " SEED=" + seed;
     122           2 :     readInputLine( getShortcutLabel() + "_mask: CREATE_MASK ARG=" + getShortcutLabel() + "_allweights TYPE=random NZEROS=" + nlandmarks + seed );
     123           1 :   } else if( getName()=="LANDMARK_SELECT_FPS" ) {
     124           1 :     if( dissims.length()==0 ) error("dissimiarities must be defined to use FPS sampling");
     125           2 :     std::string seed; parse("SEED",seed); if( seed.length()>0 ) seed = " SEED=" + seed;
     126           2 :     readInputLine( getShortcutLabel() + "_mask: FARTHEST_POINT_SAMPLING ARG=" + dissims + " NZEROS=" + nlandmarks + seed );
     127             :   }
     128             : 
     129          15 :   if( argn.length()>0 ) readInputLine( getShortcutLabel() + "_data: SELECT_WITH_MASK ARG=" + argn + "_data ROW_MASK=" + getShortcutLabel() + "_mask");
     130             : 
     131           8 :   unsigned nland; Tools::convert( nlandmarks, nland );
     132           8 :   if( dissims.length()>0 ) {
     133           5 :     ActionWithValue* ds = plumed.getActionSet().selectWithLabel<ActionWithValue*>( dissims );
     134           5 :     if( (ds->copyOutput(0))->getShape()[0]==nland ) {
     135           1 :       if( !novoronoi ) { warning("cannot use voronoi procedure to give weights as not all distances between points are known"); novoronoi=true; }
     136           2 :       readInputLine( getShortcutLabel() + "_sqrdissims: COMBINE ARG=" + dissims + " PERIODIC=NO");
     137             :     } else {
     138           8 :       readInputLine( getShortcutLabel() + "_rmask: CREATE_MASK ARG=" + getShortcutLabel() + "_allweights TYPE=nomask");
     139           8 :       readInputLine( getShortcutLabel() + "_rectdissims: SELECT_WITH_MASK ARG=" + dissims + " COLUMN_MASK=" + getShortcutLabel() + "_mask ROW_MASK=" + getShortcutLabel() + "_rmask");
     140           8 :       readInputLine( getShortcutLabel() + "_sqrdissims: SELECT_WITH_MASK ARG=" + dissims + " ROW_MASK=" + getShortcutLabel() + "_mask COLUMN_MASK=" + getShortcutLabel() + "_mask");
     141             :     }
     142             :   }
     143             : 
     144           8 :   if( !novoronoi && argn.length()>0 && dissims.length()>0 ) {
     145           6 :     readInputLine( getShortcutLabel() + "_voronoi: VORONOI ARG=" + getShortcutLabel() + "_rectdissims");
     146           6 :     readInputLine( getShortcutLabel() + "_allweightsT: TRANSPOSE ARG=" + getShortcutLabel() + "_allweights");
     147           6 :     readInputLine( getShortcutLabel() + "_weightsT: MATRIX_PRODUCT ARG=" + getShortcutLabel() + "_allweightsT," + getShortcutLabel() + "_voronoi");
     148           6 :     readInputLine( getShortcutLabel() + "_weights: TRANSPOSE ARG=" + getShortcutLabel() + "_weightsT");
     149           6 :     readInputLine( getShortcutLabel() + "_logweights: CUSTOM ARG=" + getShortcutLabel() + "_weights FUNC=log(x) PERIODIC=NO");
     150           5 :   } else if( argn.length()>0 ) {
     151           4 :     if( !novoronoi ) warning("cannot use voronoi procedure to give weights to landmark points as DISSIMILARITIES was not set");
     152           8 :     readInputLine( getShortcutLabel() + "_logweights: SELECT_WITH_MASK ARG=" + argn + "_logweights MASK=" + getShortcutLabel() + "_mask");
     153             :   }
     154             :   // Create the vector of ones that is needed by Classical MDS
     155          15 :   if( argn.length()>0 ) readInputLine( getShortcutLabel() + "_ones: SELECT_WITH_MASK ARG=" + argn + "_ones MASK=" + getShortcutLabel() + "_mask");
     156           8 : }
     157             : 
     158             : }
     159             : }

Generated by: LCOV version 1.16