Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2015-2020 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionShortcut.h" 23 : #include "core/PlumedMain.h" 24 : #include "core/ActionSet.h" 25 : #include "core/ActionRegister.h" 26 : #include "core/ActionWithValue.h" 27 : 28 : //+PLUMEDOC DIMRED SKETCHMAP 29 : /* 30 : Construct a sketch map projection of the input data 31 : 32 : \par Examples 33 : 34 : */ 35 : //+ENDPLUMEDOC 36 : 37 : namespace PLMD { 38 : namespace dimred { 39 : 40 : class SketchMap : public ActionShortcut { 41 : public: 42 : static void registerKeywords( Keywords& keys ); 43 : explicit SketchMap( const ActionOptions& ao ); 44 : }; 45 : 46 : PLUMED_REGISTER_ACTION(SketchMap,"SKETCHMAP") 47 : 48 9 : void SketchMap::registerKeywords( Keywords& keys ) { 49 9 : ActionShortcut::registerKeywords( keys ); 50 18 : keys.add("compulsory","NLOW_DIM","number of low-dimensional coordinates required"); 51 18 : keys.add("optional","WEIGHTS","a vector containing the weights of the points"); 52 18 : keys.add("compulsory","ARG","the matrix of high dimensional coordinates that you want to project in the low dimensional space"); 53 18 : keys.add("compulsory","HIGH_DIM_FUNCTION","the parameters of the switching function in the high dimensional space"); 54 18 : keys.add("compulsory","LOW_DIM_FUNCTION","the parameters of the switching function in the low dimensional space"); 55 18 : keys.add("compulsory","CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the projection of the landmarks"); 56 18 : keys.add("compulsory","MAXITER","1000","maximum number of optimization cycles for optimisation algorithms"); 57 18 : keys.add("compulsory","NCYCLES","0","The number of cycles of pointwise global optimisation that are required"); 58 18 : keys.add("compulsory","BUFFER","1.1","grid extent for search is (max projection - minimum projection) multiplied by this value"); 59 18 : keys.add("compulsory","CGRID_SIZE","10","number of points to use in each grid direction"); 60 18 : keys.add("compulsory","FGRID_SIZE","0","interpolate the grid onto this number of points -- only works in 2D"); 61 18 : keys.addFlag("PROJECT_ALL",false,"if the input are landmark coordinates then project the out of sample configurations"); 62 18 : keys.add("compulsory","OS_CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the out of sample projections"); 63 18 : keys.addFlag("USE_SMACOF",false,"find the projection in the low dimensional space using the SMACOF algorithm"); 64 18 : keys.add("compulsory","SMACTOL","1E-4","the tolerance for the smacof algorithm"); 65 18 : keys.add("compulsory","SMACREG","0.001","this is used to ensure that we don't divide by zero when updating weights for SMACOF algorithm"); 66 9 : keys.setValueDescription("the sketch-map projection of the input points"); 67 18 : keys.addOutputComponent("osample","PROJECT_ALL","the out-of-sample projections"); 68 36 : keys.needsAction("CLASSICAL_MDS"); keys.needsAction("MORE_THAN"); keys.needsAction("SUM"); keys.needsAction("CUSTOM"); 69 36 : keys.needsAction("OUTER_PRODUCT"); keys.needsAction("ARRANGE_POINTS"); keys.needsAction("PROJECT_POINTS"); keys.needsAction("VSTACK"); 70 9 : } 71 : 72 4 : SketchMap::SketchMap( const ActionOptions& ao): 73 : Action(ao), 74 4 : ActionShortcut(ao) 75 : { 76 : // Get the high dimensioal data 77 8 : std::string argn; parse("ARG",argn); std::string dissimilarities = getShortcutLabel() + "_mds_mat"; 78 4 : ActionShortcut* as = plumed.getActionSet().getShortcutActionWithLabel( argn ); 79 4 : if( !as ) error("found no action with name " + argn ); 80 4 : if( as->getName()!="COLLECT_FRAMES" ) { 81 1 : if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) { 82 0 : error("found no COLLECT_FRAMES or LANDMARK_SELECT action with label " + argn ); 83 : } else { 84 1 : ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_sqrdissims"); 85 2 : if( dissims ) dissimilarities = argn + "_sqrdissims"; 86 : } 87 : } 88 8 : unsigned ndim; parse("NLOW_DIM",ndim); 89 4 : std::string str_ndim; Tools::convert( ndim, str_ndim ); 90 : // Construct a projection using classical MDS 91 8 : readInputLine( getShortcutLabel() + "_mds: CLASSICAL_MDS ARG=" + argn + " NLOW_DIM=" + str_ndim ); 92 : // Transform the dissimilarities using the switching function 93 4 : std::string hdfunc; parse("HIGH_DIM_FUNCTION",hdfunc); 94 8 : readInputLine( getShortcutLabel() + "_hdmat: MORE_THAN ARG=" + dissimilarities + " SQUARED SWITCH={" + hdfunc + "}"); 95 : // Now for the weights - read the vector of weights first 96 9 : std::string wvec; parse("WEIGHTS",wvec); if( wvec.length()==0 ) wvec = argn + "_weights"; 97 : // Now calculate the sum of thse weights 98 8 : readInputLine( wvec + "_sum: SUM ARG=" + wvec + " PERIODIC=NO"); 99 : // And normalise the vector of weights using this sum 100 8 : readInputLine( wvec + "_normed: CUSTOM ARG=" + wvec + "_sum," + wvec + " FUNC=y/x PERIODIC=NO"); 101 : // And now create the matrix of weights 102 8 : readInputLine( wvec + "_mat: OUTER_PRODUCT ARG=" + wvec + "_normed," + wvec + "_normed"); 103 : // Run the arrange points object 104 20 : std::string ldfunc, cgtol, maxiter; parse("LOW_DIM_FUNCTION",ldfunc); parse("CGTOL",cgtol); parse("MAXITER",maxiter); unsigned ncycles; parse("NCYCLES",ncycles); 105 5 : std::string num, argstr, lname=getShortcutLabel() + "_ap"; if( ncycles>0 ) lname = getShortcutLabel() + "_cg"; 106 12 : argstr = "ARG=" + getShortcutLabel() + "_mds-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_mds-" + num; } 107 4 : bool usesmacof; parseFlag("USE_SMACOF",usesmacof); 108 4 : if( usesmacof ) { 109 2 : std::string smactol, smacreg; parse("SMACTOL",smactol); parse("SMACREG",smacreg); 110 3 : readInputLine( lname + ": ARRANGE_POINTS " + argstr + " MINTYPE=smacof TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat" + 111 3 : " MAXITER=" + maxiter + " SMACTOL=" + smactol + " SMACREG=" + smacreg + " TARGET2=" + getShortcutLabel() + "_mds_mat WEIGHTS2=" + wvec + "_mat"); 112 : } else { 113 6 : readInputLine( lname + ": ARRANGE_POINTS " + argstr + " MINTYPE=conjgrad TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol); 114 3 : if( ncycles>0 ) { 115 2 : std::string buf; parse("BUFFER",buf); 116 2 : std::vector<std::string> fgrid; parseVector("FGRID_SIZE",fgrid); 117 2 : std::string ncyc; Tools::convert(ncycles,ncyc); std::string pwise_args=" NCYCLES=" + ncyc + " BUFFER=" + buf; 118 1 : if( fgrid.size()>0 ) { 119 1 : if( fgrid.size()!=ndim ) error("number of elements of fgrid is not correct"); 120 3 : pwise_args += " FGRID_SIZE=" + fgrid[0]; for(unsigned i=1; i<fgrid.size(); ++i) pwise_args += "," + fgrid[i]; 121 : } 122 2 : std::vector<std::string> cgrid(ndim); parseVector("CGRID_SIZE",cgrid); 123 3 : pwise_args += " CGRID_SIZE=" + cgrid[0]; for(unsigned i=1; i<cgrid.size(); ++i) pwise_args += "," + cgrid[i]; 124 3 : argstr="ARG=" + getShortcutLabel() + "_cg.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_cg.coord-" + num; } 125 2 : readInputLine( getShortcutLabel() + "_ap: ARRANGE_POINTS " + argstr + pwise_args + " MINTYPE=pointwise TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol); 126 2 : } 127 : } 128 12 : argstr="ARG=" + getShortcutLabel() + "_ap.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_ap.coord-" + num; } 129 8 : readInputLine( getShortcutLabel() + ": VSTACK " + argstr ); 130 8 : bool projall; parseFlag("PROJECT_ALL",projall); if( !projall ) return ; 131 4 : parse("OS_CGTOL",cgtol); argstr = getShortcutLabel() + "_ap.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_ap.coord-" + num; } 132 1 : if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) { 133 0 : readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS " + argstr + " TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol ); 134 : } else { 135 1 : ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_rectdissims"); 136 1 : if( !dissims ) error("cannot PROJECT_ALL as " + as->getName() + " with label " + argn + " was involved without the DISSIMILARITIES keyword"); 137 2 : readInputLine( getShortcutLabel() + "_lhdmat: MORE_THAN ARG=" + argn + "_rectdissims SQUARED SWITCH={" + hdfunc + "}"); 138 2 : readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS ARG=" + argstr + " TARGET1=" + getShortcutLabel() + "_lhdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol ); 139 : } 140 3 : argstr="ARG=" + getShortcutLabel() + "_osample_pp.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_osample_pp.coord-" + num; } 141 2 : readInputLine( getShortcutLabel() + "_osample: VSTACK " + argstr ); 142 0 : } 143 : 144 : } 145 : }