LCOV - code coverage report
Current view: top level - dimred - SketchMap.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 112 118 94.9 %
Date: 2025-03-25 09:33:27 Functions: 2 3 66.7 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2015-2020 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionShortcut.h"
      23             : #include "core/PlumedMain.h"
      24             : #include "core/ActionSet.h"
      25             : #include "core/ActionRegister.h"
      26             : #include "core/ActionWithValue.h"
      27             : 
      28             : //+PLUMEDOC DIMRED SKETCHMAP
      29             : /*
      30             : Construct a sketch map projection of the input data
      31             : 
      32             : \par Examples
      33             : 
      34             : */
      35             : //+ENDPLUMEDOC
      36             : 
      37             : namespace PLMD {
      38             : namespace dimred {
      39             : 
      40             : class SketchMap : public ActionShortcut {
      41             : public:
      42             :   static void registerKeywords( Keywords& keys );
      43             :   explicit SketchMap( const ActionOptions& ao );
      44             : };
      45             : 
      46             : PLUMED_REGISTER_ACTION(SketchMap,"SKETCHMAP")
      47             : 
      48           9 : void SketchMap::registerKeywords( Keywords& keys ) {
      49           9 :   ActionShortcut::registerKeywords( keys );
      50           9 :   keys.add("compulsory","NLOW_DIM","number of low-dimensional coordinates required");
      51           9 :   keys.add("optional","WEIGHTS","a vector containing the weights of the points");
      52           9 :   keys.add("compulsory","ARG","the matrix of high dimensional coordinates that you want to project in the low dimensional space");
      53           9 :   keys.add("compulsory","HIGH_DIM_FUNCTION","the parameters of the switching function in the high dimensional space");
      54           9 :   keys.add("compulsory","LOW_DIM_FUNCTION","the parameters of the switching function in the low dimensional space");
      55           9 :   keys.add("compulsory","CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the projection of the landmarks");
      56           9 :   keys.add("compulsory","MAXITER","1000","maximum number of optimization cycles for optimisation algorithms");
      57           9 :   keys.add("compulsory","NCYCLES","0","The number of cycles of pointwise global optimisation that are required");
      58           9 :   keys.add("compulsory","BUFFER","1.1","grid extent for search is (max projection - minimum projection) multiplied by this value");
      59           9 :   keys.add("compulsory","CGRID_SIZE","10","number of points to use in each grid direction");
      60           9 :   keys.add("compulsory","FGRID_SIZE","0","interpolate the grid onto this number of points -- only works in 2D");
      61           9 :   keys.addFlag("PROJECT_ALL",false,"if the input are landmark coordinates then project the out of sample configurations");
      62           9 :   keys.add("compulsory","OS_CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the out of sample projections");
      63           9 :   keys.addFlag("USE_SMACOF",false,"find the projection in the low dimensional space using the SMACOF algorithm");
      64           9 :   keys.add("compulsory","SMACTOL","1E-4","the tolerance for the smacof algorithm");
      65           9 :   keys.add("compulsory","SMACREG","0.001","this is used to ensure that we don't divide by zero when updating weights for SMACOF algorithm");
      66          18 :   keys.setValueDescription("matrix","the sketch-map projection of the input points");
      67          18 :   keys.addOutputComponent("osample","PROJECT_ALL","matrix","the out-of-sample projections");
      68           9 :   keys.needsAction("CLASSICAL_MDS");
      69           9 :   keys.needsAction("MORE_THAN");
      70           9 :   keys.needsAction("SUM");
      71           9 :   keys.needsAction("CUSTOM");
      72           9 :   keys.needsAction("OUTER_PRODUCT");
      73           9 :   keys.needsAction("ARRANGE_POINTS");
      74           9 :   keys.needsAction("PROJECT_POINTS");
      75           9 :   keys.needsAction("VSTACK");
      76           9 : }
      77             : 
      78           4 : SketchMap::SketchMap( const ActionOptions& ao):
      79             :   Action(ao),
      80           4 :   ActionShortcut(ao) {
      81             :   // Get the high dimensioal data
      82             :   std::string argn;
      83           4 :   parse("ARG",argn);
      84           4 :   std::string dissimilarities = getShortcutLabel() + "_mds_mat";
      85           4 :   ActionShortcut* as = plumed.getActionSet().getShortcutActionWithLabel( argn );
      86           4 :   if( !as ) {
      87           0 :     error("found no action with name " + argn );
      88             :   }
      89           4 :   if( as->getName()!="COLLECT_FRAMES" ) {
      90           1 :     if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
      91           0 :       error("found no COLLECT_FRAMES or LANDMARK_SELECT action with label " + argn );
      92             :     } else {
      93           1 :       ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_sqrdissims");
      94           1 :       if( dissims ) {
      95           2 :         dissimilarities = argn + "_sqrdissims";
      96             :       }
      97             :     }
      98             :   }
      99             :   unsigned ndim;
     100           8 :   parse("NLOW_DIM",ndim);
     101             :   std::string str_ndim;
     102           4 :   Tools::convert( ndim, str_ndim );
     103             :   // Construct a projection using classical MDS
     104           8 :   readInputLine( getShortcutLabel() + "_mds: CLASSICAL_MDS ARG=" + argn + " NLOW_DIM=" + str_ndim );
     105             :   // Transform the dissimilarities using the switching function
     106             :   std::string hdfunc;
     107           4 :   parse("HIGH_DIM_FUNCTION",hdfunc);
     108           8 :   readInputLine( getShortcutLabel() + "_hdmat: MORE_THAN ARG=" + dissimilarities + " SQUARED SWITCH={" + hdfunc + "}");
     109             :   // Now for the weights - read the vector of weights first
     110             :   std::string wvec;
     111           8 :   parse("WEIGHTS",wvec);
     112           4 :   if( wvec.length()==0 ) {
     113           2 :     wvec = argn + "_weights";
     114             :   }
     115             :   // Now calculate the sum of thse weights
     116           8 :   readInputLine( wvec + "_sum: SUM ARG=" + wvec + " PERIODIC=NO");
     117             :   // And normalise the vector of weights using this sum
     118           8 :   readInputLine( wvec + "_normed: CUSTOM ARG=" + wvec + "_sum," + wvec + " FUNC=y/x PERIODIC=NO");
     119             :   // And now create the matrix of weights
     120           8 :   readInputLine( wvec + "_mat: OUTER_PRODUCT ARG=" + wvec + "_normed," + wvec + "_normed");
     121             :   // Run the arrange points object
     122             :   std::string ldfunc, cgtol, maxiter;
     123           4 :   parse("LOW_DIM_FUNCTION",ldfunc);
     124           4 :   parse("CGTOL",cgtol);
     125           4 :   parse("MAXITER",maxiter);
     126             :   unsigned ncycles;
     127           8 :   parse("NCYCLES",ncycles);
     128           4 :   std::string num, argstr, lname=getShortcutLabel() + "_ap";
     129           4 :   if( ncycles>0 ) {
     130           2 :     lname = getShortcutLabel() + "_cg";
     131             :   }
     132           8 :   argstr = "ARG=" + getShortcutLabel() + "_mds-1";
     133           8 :   for(unsigned i=1; i<ndim; ++i) {
     134           4 :     Tools::convert( i+1, num );
     135           8 :     argstr += "," + getShortcutLabel() + "_mds-" + num;
     136             :   }
     137             :   bool usesmacof;
     138           4 :   parseFlag("USE_SMACOF",usesmacof);
     139           4 :   if( usesmacof ) {
     140             :     std::string smactol, smacreg;
     141           1 :     parse("SMACTOL",smactol);
     142           1 :     parse("SMACREG",smacreg);
     143           3 :     readInputLine( lname + ": ARRANGE_POINTS " + argstr  + " MINTYPE=smacof TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat" +
     144           3 :                    " MAXITER=" + maxiter + " SMACTOL=" + smactol + " SMACREG=" + smacreg + " TARGET2=" + getShortcutLabel() + "_mds_mat WEIGHTS2=" + wvec + "_mat");
     145             :   } else {
     146           6 :     readInputLine( lname + ": ARRANGE_POINTS " + argstr  + " MINTYPE=conjgrad TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
     147           3 :     if( ncycles>0 ) {
     148             :       std::string buf;
     149           2 :       parse("BUFFER",buf);
     150             :       std::vector<std::string> fgrid;
     151           2 :       parseVector("FGRID_SIZE",fgrid);
     152             :       std::string ncyc;
     153           1 :       Tools::convert(ncycles,ncyc);
     154           2 :       std::string pwise_args=" NCYCLES=" + ncyc + " BUFFER=" + buf;
     155           1 :       if( fgrid.size()>0 ) {
     156           1 :         if( fgrid.size()!=ndim ) {
     157           0 :           error("number of elements of fgrid is not correct");
     158             :         }
     159           1 :         pwise_args += " FGRID_SIZE=" + fgrid[0];
     160           2 :         for(unsigned i=1; i<fgrid.size(); ++i) {
     161           2 :           pwise_args += "," + fgrid[i];
     162             :         }
     163             :       }
     164           1 :       std::vector<std::string> cgrid(ndim);
     165           2 :       parseVector("CGRID_SIZE",cgrid);
     166           1 :       pwise_args += " CGRID_SIZE=" + cgrid[0];
     167           2 :       for(unsigned i=1; i<cgrid.size(); ++i) {
     168           2 :         pwise_args += "," + cgrid[i];
     169             :       }
     170           2 :       argstr="ARG=" + getShortcutLabel() + "_cg.coord-1";
     171           2 :       for(unsigned i=1; i<ndim; ++i) {
     172           1 :         Tools::convert( i+1, num );
     173           2 :         argstr += "," + getShortcutLabel() + "_cg.coord-" + num;
     174             :       }
     175           2 :       readInputLine( getShortcutLabel() + "_ap: ARRANGE_POINTS " + argstr  + pwise_args + " MINTYPE=pointwise TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
     176           2 :     }
     177             :   }
     178           8 :   argstr="ARG=" + getShortcutLabel() + "_ap.coord-1";
     179           8 :   for(unsigned i=1; i<ndim; ++i) {
     180           4 :     Tools::convert( i+1, num );
     181           8 :     argstr += "," + getShortcutLabel() + "_ap.coord-" + num;
     182             :   }
     183           8 :   readInputLine( getShortcutLabel() + ": VSTACK " + argstr );
     184             :   bool projall;
     185           4 :   parseFlag("PROJECT_ALL",projall);
     186           4 :   if( !projall ) {
     187             :     return ;
     188             :   }
     189           1 :   parse("OS_CGTOL",cgtol);
     190           1 :   argstr = getShortcutLabel() + "_ap.coord-1";
     191           2 :   for(unsigned i=1; i<ndim; ++i) {
     192           1 :     Tools::convert( i+1, num );
     193           2 :     argstr += "," + getShortcutLabel() + "_ap.coord-" + num;
     194             :   }
     195           1 :   if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
     196           0 :     readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS " + argstr + " TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
     197             :   } else {
     198           1 :     ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_rectdissims");
     199           1 :     if( !dissims ) {
     200           0 :       error("cannot PROJECT_ALL as " + as->getName() + " with label " + argn + " was involved without the DISSIMILARITIES keyword");
     201             :     }
     202           2 :     readInputLine( getShortcutLabel() + "_lhdmat: MORE_THAN ARG=" + argn + "_rectdissims SQUARED SWITCH={" + hdfunc + "}");
     203           2 :     readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS ARG=" + argstr + " TARGET1=" + getShortcutLabel() + "_lhdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
     204             :   }
     205           2 :   argstr="ARG=" + getShortcutLabel() + "_osample_pp.coord-1";
     206           2 :   for(unsigned i=1; i<ndim; ++i) {
     207           1 :     Tools::convert( i+1, num );
     208           2 :     argstr += "," + getShortcutLabel() + "_osample_pp.coord-" + num;
     209             :   }
     210           2 :   readInputLine( getShortcutLabel() + "_osample: VSTACK " + argstr );
     211           0 : }
     212             : 
     213             : }
     214             : }

Generated by: LCOV version 1.16