Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2015-2020 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionShortcut.h"
23 : #include "core/PlumedMain.h"
24 : #include "core/ActionSet.h"
25 : #include "core/ActionRegister.h"
26 : #include "core/ActionWithValue.h"
27 :
28 : //+PLUMEDOC DIMRED SKETCHMAP
29 : /*
30 : Construct a sketch map projection of the input data
31 :
32 : \par Examples
33 :
34 : */
35 : //+ENDPLUMEDOC
36 :
37 : namespace PLMD {
38 : namespace dimred {
39 :
40 : class SketchMap : public ActionShortcut {
41 : public:
42 : static void registerKeywords( Keywords& keys );
43 : explicit SketchMap( const ActionOptions& ao );
44 : };
45 :
46 : PLUMED_REGISTER_ACTION(SketchMap,"SKETCHMAP")
47 :
48 9 : void SketchMap::registerKeywords( Keywords& keys ) {
49 9 : ActionShortcut::registerKeywords( keys );
50 18 : keys.add("compulsory","NLOW_DIM","number of low-dimensional coordinates required");
51 18 : keys.add("optional","WEIGHTS","a vector containing the weights of the points");
52 18 : keys.add("compulsory","ARG","the matrix of high dimensional coordinates that you want to project in the low dimensional space");
53 18 : keys.add("compulsory","HIGH_DIM_FUNCTION","the parameters of the switching function in the high dimensional space");
54 18 : keys.add("compulsory","LOW_DIM_FUNCTION","the parameters of the switching function in the low dimensional space");
55 18 : keys.add("compulsory","CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the projection of the landmarks");
56 18 : keys.add("compulsory","MAXITER","1000","maximum number of optimization cycles for optimisation algorithms");
57 18 : keys.add("compulsory","NCYCLES","0","The number of cycles of pointwise global optimisation that are required");
58 18 : keys.add("compulsory","BUFFER","1.1","grid extent for search is (max projection - minimum projection) multiplied by this value");
59 18 : keys.add("compulsory","CGRID_SIZE","10","number of points to use in each grid direction");
60 18 : keys.add("compulsory","FGRID_SIZE","0","interpolate the grid onto this number of points -- only works in 2D");
61 18 : keys.addFlag("PROJECT_ALL",false,"if the input are landmark coordinates then project the out of sample configurations");
62 18 : keys.add("compulsory","OS_CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the out of sample projections");
63 18 : keys.addFlag("USE_SMACOF",false,"find the projection in the low dimensional space using the SMACOF algorithm");
64 18 : keys.add("compulsory","SMACTOL","1E-4","the tolerance for the smacof algorithm");
65 18 : keys.add("compulsory","SMACREG","0.001","this is used to ensure that we don't divide by zero when updating weights for SMACOF algorithm");
66 18 : keys.setValueDescription("matrix","the sketch-map projection of the input points");
67 18 : keys.addOutputComponent("osample","PROJECT_ALL","matrix","the out-of-sample projections");
68 36 : keys.needsAction("CLASSICAL_MDS"); keys.needsAction("MORE_THAN"); keys.needsAction("SUM"); keys.needsAction("CUSTOM");
69 36 : keys.needsAction("OUTER_PRODUCT"); keys.needsAction("ARRANGE_POINTS"); keys.needsAction("PROJECT_POINTS"); keys.needsAction("VSTACK");
70 9 : }
71 :
72 4 : SketchMap::SketchMap( const ActionOptions& ao):
73 : Action(ao),
74 4 : ActionShortcut(ao)
75 : {
76 : // Get the high dimensioal data
77 8 : std::string argn; parse("ARG",argn); std::string dissimilarities = getShortcutLabel() + "_mds_mat";
78 4 : ActionShortcut* as = plumed.getActionSet().getShortcutActionWithLabel( argn );
79 4 : if( !as ) error("found no action with name " + argn );
80 4 : if( as->getName()!="COLLECT_FRAMES" ) {
81 1 : if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
82 0 : error("found no COLLECT_FRAMES or LANDMARK_SELECT action with label " + argn );
83 : } else {
84 1 : ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_sqrdissims");
85 2 : if( dissims ) dissimilarities = argn + "_sqrdissims";
86 : }
87 : }
88 8 : unsigned ndim; parse("NLOW_DIM",ndim);
89 4 : std::string str_ndim; Tools::convert( ndim, str_ndim );
90 : // Construct a projection using classical MDS
91 8 : readInputLine( getShortcutLabel() + "_mds: CLASSICAL_MDS ARG=" + argn + " NLOW_DIM=" + str_ndim );
92 : // Transform the dissimilarities using the switching function
93 4 : std::string hdfunc; parse("HIGH_DIM_FUNCTION",hdfunc);
94 8 : readInputLine( getShortcutLabel() + "_hdmat: MORE_THAN ARG=" + dissimilarities + " SQUARED SWITCH={" + hdfunc + "}");
95 : // Now for the weights - read the vector of weights first
96 9 : std::string wvec; parse("WEIGHTS",wvec); if( wvec.length()==0 ) wvec = argn + "_weights";
97 : // Now calculate the sum of thse weights
98 8 : readInputLine( wvec + "_sum: SUM ARG=" + wvec + " PERIODIC=NO");
99 : // And normalise the vector of weights using this sum
100 8 : readInputLine( wvec + "_normed: CUSTOM ARG=" + wvec + "_sum," + wvec + " FUNC=y/x PERIODIC=NO");
101 : // And now create the matrix of weights
102 8 : readInputLine( wvec + "_mat: OUTER_PRODUCT ARG=" + wvec + "_normed," + wvec + "_normed");
103 : // Run the arrange points object
104 20 : std::string ldfunc, cgtol, maxiter; parse("LOW_DIM_FUNCTION",ldfunc); parse("CGTOL",cgtol); parse("MAXITER",maxiter); unsigned ncycles; parse("NCYCLES",ncycles);
105 5 : std::string num, argstr, lname=getShortcutLabel() + "_ap"; if( ncycles>0 ) lname = getShortcutLabel() + "_cg";
106 12 : argstr = "ARG=" + getShortcutLabel() + "_mds-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_mds-" + num; }
107 4 : bool usesmacof; parseFlag("USE_SMACOF",usesmacof);
108 4 : if( usesmacof ) {
109 2 : std::string smactol, smacreg; parse("SMACTOL",smactol); parse("SMACREG",smacreg);
110 3 : readInputLine( lname + ": ARRANGE_POINTS " + argstr + " MINTYPE=smacof TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat" +
111 3 : " MAXITER=" + maxiter + " SMACTOL=" + smactol + " SMACREG=" + smacreg + " TARGET2=" + getShortcutLabel() + "_mds_mat WEIGHTS2=" + wvec + "_mat");
112 : } else {
113 6 : readInputLine( lname + ": ARRANGE_POINTS " + argstr + " MINTYPE=conjgrad TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
114 3 : if( ncycles>0 ) {
115 2 : std::string buf; parse("BUFFER",buf);
116 2 : std::vector<std::string> fgrid; parseVector("FGRID_SIZE",fgrid);
117 2 : std::string ncyc; Tools::convert(ncycles,ncyc); std::string pwise_args=" NCYCLES=" + ncyc + " BUFFER=" + buf;
118 1 : if( fgrid.size()>0 ) {
119 1 : if( fgrid.size()!=ndim ) error("number of elements of fgrid is not correct");
120 3 : pwise_args += " FGRID_SIZE=" + fgrid[0]; for(unsigned i=1; i<fgrid.size(); ++i) pwise_args += "," + fgrid[i];
121 : }
122 2 : std::vector<std::string> cgrid(ndim); parseVector("CGRID_SIZE",cgrid);
123 3 : pwise_args += " CGRID_SIZE=" + cgrid[0]; for(unsigned i=1; i<cgrid.size(); ++i) pwise_args += "," + cgrid[i];
124 3 : argstr="ARG=" + getShortcutLabel() + "_cg.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_cg.coord-" + num; }
125 2 : readInputLine( getShortcutLabel() + "_ap: ARRANGE_POINTS " + argstr + pwise_args + " MINTYPE=pointwise TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
126 2 : }
127 : }
128 12 : argstr="ARG=" + getShortcutLabel() + "_ap.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_ap.coord-" + num; }
129 8 : readInputLine( getShortcutLabel() + ": VSTACK " + argstr );
130 8 : bool projall; parseFlag("PROJECT_ALL",projall); if( !projall ) return ;
131 4 : parse("OS_CGTOL",cgtol); argstr = getShortcutLabel() + "_ap.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_ap.coord-" + num; }
132 1 : if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
133 0 : readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS " + argstr + " TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
134 : } else {
135 1 : ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_rectdissims");
136 1 : if( !dissims ) error("cannot PROJECT_ALL as " + as->getName() + " with label " + argn + " was involved without the DISSIMILARITIES keyword");
137 2 : readInputLine( getShortcutLabel() + "_lhdmat: MORE_THAN ARG=" + argn + "_rectdissims SQUARED SWITCH={" + hdfunc + "}");
138 2 : readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS ARG=" + argstr + " TARGET1=" + getShortcutLabel() + "_lhdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
139 : }
140 3 : argstr="ARG=" + getShortcutLabel() + "_osample_pp.coord-1"; for(unsigned i=1; i<ndim; ++i) { Tools::convert( i+1, num ); argstr += "," + getShortcutLabel() + "_osample_pp.coord-" + num; }
141 2 : readInputLine( getShortcutLabel() + "_osample: VSTACK " + argstr );
142 0 : }
143 :
144 : }
145 : }
|