LCOV - code coverage report
Current view: top level - dimred - SketchMap.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 73 76 96.1 %
Date: 2024-10-18 13:59:31 Functions: 2 3 66.7 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2015-2020 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionShortcut.h"
      23             : #include "core/PlumedMain.h"
      24             : #include "core/ActionSet.h"
      25             : #include "core/ActionRegister.h"
      26             : #include "core/ActionWithValue.h"
      27             : 
      28             : //+PLUMEDOC DIMRED SKETCHMAP
      29             : /*
      30             : Construct a sketch map projection of the input data
      31             : 
      32             : \par Examples
      33             : 
      34             : */
      35             : //+ENDPLUMEDOC
      36             : 
      37             : namespace PLMD {
      38             : namespace dimred {
      39             : 
      40             : class SketchMap : public ActionShortcut {
      41             : public:
      42             :   static void registerKeywords( Keywords& keys );
      43             :   explicit SketchMap( const ActionOptions& ao );
      44             : };
      45             : 
      46             : PLUMED_REGISTER_ACTION(SketchMap,"SKETCHMAP")
      47             : 
      48           9 : void SketchMap::registerKeywords( Keywords& keys ) {
      49           9 :   ActionShortcut::registerKeywords( keys );
      50          18 :   keys.add("compulsory","NLOW_DIM","number of low-dimensional coordinates required");
      51          18 :   keys.add("optional","WEIGHTS","a vector containing the weights of the points");
      52          18 :   keys.add("compulsory","ARG","the matrix of high dimensional coordinates that you want to project in the low dimensional space");
      53          18 :   keys.add("compulsory","HIGH_DIM_FUNCTION","the parameters of the switching function in the high dimensional space");
      54          18 :   keys.add("compulsory","LOW_DIM_FUNCTION","the parameters of the switching function in the low dimensional space");
      55          18 :   keys.add("compulsory","CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the projection of the landmarks");
      56          18 :   keys.add("compulsory","MAXITER","1000","maximum number of optimization cycles for optimisation algorithms");
      57          18 :   keys.add("compulsory","NCYCLES","0","The number of cycles of pointwise global optimisation that are required");
      58          18 :   keys.add("compulsory","BUFFER","1.1","grid extent for search is (max projection - minimum projection) multiplied by this value");
      59          18 :   keys.add("compulsory","CGRID_SIZE","10","number of points to use in each grid direction");
      60          18 :   keys.add("compulsory","FGRID_SIZE","0","interpolate the grid onto this number of points -- only works in 2D");
      61          18 :   keys.addFlag("PROJECT_ALL",false,"if the input are landmark coordinates then project the out of sample configurations");
      62          18 :   keys.add("compulsory","OS_CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the out of sample projections");
      63          18 :   keys.addFlag("USE_SMACOF",false,"find the projection in the low dimensional space using the SMACOF algorithm");
      64          18 :   keys.add("compulsory","SMACTOL","1E-4","the tolerance for the smacof algorithm");
      65          18 :   keys.add("compulsory","SMACREG","0.001","this is used to ensure that we don't divide by zero when updating weights for SMACOF algorithm");
      66          18 :   keys.setValueDescription("matrix","the sketch-map projection of the input points");
      67          18 :   keys.addOutputComponent("osample","PROJECT_ALL","matrix","the out-of-sample projections");
      68          36 :   keys.needsAction("CLASSICAL_MDS"); keys.needsAction("MORE_THAN"); keys.needsAction("SUM"); keys.needsAction("CUSTOM");
      69          36 :   keys.needsAction("OUTER_PRODUCT"); keys.needsAction("ARRANGE_POINTS"); keys.needsAction("PROJECT_POINTS"); keys.needsAction("VSTACK");
      70           9 : }
      71             : 
      72           4 : SketchMap::SketchMap( const ActionOptions& ao):
      73             :   Action(ao),
      74           4 :   ActionShortcut(ao)
      75             : {
      76             :   // Get the high dimensioal data
      77           8 :   std::string argn; parse("ARG",argn); std::string dissimilarities = getShortcutLabel() + "_mds_mat";
      78           4 :   ActionShortcut* as = plumed.getActionSet().getShortcutActionWithLabel( argn );
      79           4 :   if( !as ) error("found no action with name " + argn );
      80           4 :   if( as->getName()!="COLLECT_FRAMES" ) {
      81           1 :     if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
      82           0 :       error("found no COLLECT_FRAMES or LANDMARK_SELECT action with label " + argn );
      83             :     } else {
      84           1 :       ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_sqrdissims");
      85           2 :       if( dissims ) dissimilarities = argn + "_sqrdissims";
      86             :     }
      87             :   }
      88           8 :   unsigned ndim; parse("NLOW_DIM",ndim);
      89           4 :   std::string str_ndim; Tools::convert( ndim, str_ndim );
      90             :   // Construct a projection using classical MDS
      91           8 :   readInputLine( getShortcutLabel() + "_mds: CLASSICAL_MDS ARG=" + argn + " NLOW_DIM=" + str_ndim );
      92             :   // Transform the dissimilarities using the switching function
      93           4 :   std::string hdfunc; parse("HIGH_DIM_FUNCTION",hdfunc);
      94           8 :   readInputLine( getShortcutLabel() + "_hdmat: MORE_THAN ARG=" + dissimilarities + " SQUARED SWITCH={" + hdfunc + "}");
      95             :   // Now for the weights - read the vector of weights first
      96           9 :   std::string wvec; parse("WEIGHTS",wvec); if( wvec.length()==0 ) wvec = argn + "_weights";
      97             :   // Now calculate the sum of thse weights
      98           8 :   readInputLine( wvec + "_sum: SUM ARG=" + wvec + " PERIODIC=NO");
      99             :   // And normalise the vector of weights using this sum
     100           8 :   readInputLine( wvec + "_normed: CUSTOM ARG=" + wvec + "_sum," + wvec + " FUNC=y/x PERIODIC=NO");
     101             :   // And now create the matrix of weights
     102           8 :   readInputLine( wvec + "_mat: OUTER_PRODUCT ARG=" + wvec + "_normed," + wvec + "_normed");
     103             :   // Run the arrange points object
     104          20 :   std::string ldfunc, cgtol, maxiter; parse("LOW_DIM_FUNCTION",ldfunc); parse("CGTOL",cgtol); parse("MAXITER",maxiter); unsigned ncycles; parse("NCYCLES",ncycles);
     105           5 :   std::string num, argstr, lname=getShortcutLabel() + "_ap"; if( ncycles>0 ) lname = getShortcutLabel() + "_cg";
     106          12 :   argstr = "ARG=" + getShortcutLabel() + "_mds-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_mds-" + num; }
     107           4 :   bool usesmacof; parseFlag("USE_SMACOF",usesmacof);
     108           4 :   if( usesmacof ) {
     109           2 :     std::string smactol, smacreg; parse("SMACTOL",smactol); parse("SMACREG",smacreg);
     110           3 :     readInputLine( lname + ": ARRANGE_POINTS " + argstr  + " MINTYPE=smacof TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat" +
     111           3 :                    " MAXITER=" + maxiter + " SMACTOL=" + smactol + " SMACREG=" + smacreg + " TARGET2=" + getShortcutLabel() + "_mds_mat WEIGHTS2=" + wvec + "_mat");
     112             :   } else {
     113           6 :     readInputLine( lname + ": ARRANGE_POINTS " + argstr  + " MINTYPE=conjgrad TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
     114           3 :     if( ncycles>0 ) {
     115           2 :       std::string buf; parse("BUFFER",buf);
     116           2 :       std::vector<std::string> fgrid; parseVector("FGRID_SIZE",fgrid);
     117           2 :       std::string ncyc; Tools::convert(ncycles,ncyc); std::string pwise_args=" NCYCLES=" + ncyc + " BUFFER=" + buf;
     118           1 :       if( fgrid.size()>0 ) {
     119           1 :         if( fgrid.size()!=ndim ) error("number of elements of fgrid is not correct");
     120           3 :         pwise_args += " FGRID_SIZE=" + fgrid[0];  for(unsigned i=1; i<fgrid.size(); ++i) pwise_args += "," + fgrid[i];
     121             :       }
     122           2 :       std::vector<std::string> cgrid(ndim); parseVector("CGRID_SIZE",cgrid);
     123           3 :       pwise_args += " CGRID_SIZE=" + cgrid[0]; for(unsigned i=1; i<cgrid.size(); ++i) pwise_args += "," + cgrid[i];
     124           3 :       argstr="ARG=" + getShortcutLabel() + "_cg.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_cg.coord-" + num; }
     125           2 :       readInputLine( getShortcutLabel() + "_ap: ARRANGE_POINTS " + argstr  + pwise_args + " MINTYPE=pointwise TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
     126           2 :     }
     127             :   }
     128          12 :   argstr="ARG=" + getShortcutLabel() + "_ap.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_ap.coord-" + num; }
     129           8 :   readInputLine( getShortcutLabel() + ": VSTACK " + argstr );
     130           8 :   bool projall; parseFlag("PROJECT_ALL",projall); if( !projall ) return ;
     131           4 :   parse("OS_CGTOL",cgtol); argstr = getShortcutLabel() + "_ap.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_ap.coord-" + num; }
     132           1 :   if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
     133           0 :     readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS " + argstr + " TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
     134             :   } else {
     135           1 :     ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_rectdissims");
     136           1 :     if( !dissims ) error("cannot PROJECT_ALL as " + as->getName() + " with label " + argn + " was involved without the DISSIMILARITIES keyword");
     137           2 :     readInputLine( getShortcutLabel() + "_lhdmat: MORE_THAN ARG=" + argn + "_rectdissims SQUARED SWITCH={" + hdfunc + "}");
     138           2 :     readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS ARG=" + argstr + " TARGET1=" + getShortcutLabel() + "_lhdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
     139             :   }
     140           3 :   argstr="ARG=" + getShortcutLabel() + "_osample_pp.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_osample_pp.coord-" + num; }
     141           2 :   readInputLine( getShortcutLabel() + "_osample: VSTACK " + argstr );
     142           0 : }
     143             : 
     144             : }
     145             : }

Generated by: LCOV version 1.16