Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2015-2020 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionShortcut.h"
23 : #include "core/PlumedMain.h"
24 : #include "core/ActionSet.h"
25 : #include "core/ActionRegister.h"
26 : #include "core/ActionWithValue.h"
27 :
28 : //+PLUMEDOC DIMRED SKETCHMAP
29 : /*
30 : Construct a sketch map projection of the input data
31 :
32 : \par Examples
33 :
34 : */
35 : //+ENDPLUMEDOC
36 :
37 : namespace PLMD {
38 : namespace dimred {
39 :
40 : class SketchMap : public ActionShortcut {
41 : public:
42 : static void registerKeywords( Keywords& keys );
43 : explicit SketchMap( const ActionOptions& ao );
44 : };
45 :
46 : PLUMED_REGISTER_ACTION(SketchMap,"SKETCHMAP")
47 :
48 9 : void SketchMap::registerKeywords( Keywords& keys ) {
49 9 : ActionShortcut::registerKeywords( keys );
50 9 : keys.add("compulsory","NLOW_DIM","number of low-dimensional coordinates required");
51 9 : keys.add("optional","WEIGHTS","a vector containing the weights of the points");
52 9 : keys.add("compulsory","ARG","the matrix of high dimensional coordinates that you want to project in the low dimensional space");
53 9 : keys.add("compulsory","HIGH_DIM_FUNCTION","the parameters of the switching function in the high dimensional space");
54 9 : keys.add("compulsory","LOW_DIM_FUNCTION","the parameters of the switching function in the low dimensional space");
55 9 : keys.add("compulsory","CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the projection of the landmarks");
56 9 : keys.add("compulsory","MAXITER","1000","maximum number of optimization cycles for optimisation algorithms");
57 9 : keys.add("compulsory","NCYCLES","0","The number of cycles of pointwise global optimisation that are required");
58 9 : keys.add("compulsory","BUFFER","1.1","grid extent for search is (max projection - minimum projection) multiplied by this value");
59 9 : keys.add("compulsory","CGRID_SIZE","10","number of points to use in each grid direction");
60 9 : keys.add("compulsory","FGRID_SIZE","0","interpolate the grid onto this number of points -- only works in 2D");
61 9 : keys.addFlag("PROJECT_ALL",false,"if the input are landmark coordinates then project the out of sample configurations");
62 9 : keys.add("compulsory","OS_CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the out of sample projections");
63 9 : keys.addFlag("USE_SMACOF",false,"find the projection in the low dimensional space using the SMACOF algorithm");
64 9 : keys.add("compulsory","SMACTOL","1E-4","the tolerance for the smacof algorithm");
65 9 : keys.add("compulsory","SMACREG","0.001","this is used to ensure that we don't divide by zero when updating weights for SMACOF algorithm");
66 18 : keys.setValueDescription("matrix","the sketch-map projection of the input points");
67 18 : keys.addOutputComponent("osample","PROJECT_ALL","matrix","the out-of-sample projections");
68 9 : keys.needsAction("CLASSICAL_MDS");
69 9 : keys.needsAction("MORE_THAN");
70 9 : keys.needsAction("SUM");
71 9 : keys.needsAction("CUSTOM");
72 9 : keys.needsAction("OUTER_PRODUCT");
73 9 : keys.needsAction("ARRANGE_POINTS");
74 9 : keys.needsAction("PROJECT_POINTS");
75 9 : keys.needsAction("VSTACK");
76 9 : }
77 :
78 4 : SketchMap::SketchMap( const ActionOptions& ao):
79 : Action(ao),
80 4 : ActionShortcut(ao) {
81 : // Get the high dimensioal data
82 : std::string argn;
83 4 : parse("ARG",argn);
84 4 : std::string dissimilarities = getShortcutLabel() + "_mds_mat";
85 4 : ActionShortcut* as = plumed.getActionSet().getShortcutActionWithLabel( argn );
86 4 : if( !as ) {
87 0 : error("found no action with name " + argn );
88 : }
89 4 : if( as->getName()!="COLLECT_FRAMES" ) {
90 1 : if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
91 0 : error("found no COLLECT_FRAMES or LANDMARK_SELECT action with label " + argn );
92 : } else {
93 1 : ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_sqrdissims");
94 1 : if( dissims ) {
95 2 : dissimilarities = argn + "_sqrdissims";
96 : }
97 : }
98 : }
99 : unsigned ndim;
100 8 : parse("NLOW_DIM",ndim);
101 : std::string str_ndim;
102 4 : Tools::convert( ndim, str_ndim );
103 : // Construct a projection using classical MDS
104 8 : readInputLine( getShortcutLabel() + "_mds: CLASSICAL_MDS ARG=" + argn + " NLOW_DIM=" + str_ndim );
105 : // Transform the dissimilarities using the switching function
106 : std::string hdfunc;
107 4 : parse("HIGH_DIM_FUNCTION",hdfunc);
108 8 : readInputLine( getShortcutLabel() + "_hdmat: MORE_THAN ARG=" + dissimilarities + " SQUARED SWITCH={" + hdfunc + "}");
109 : // Now for the weights - read the vector of weights first
110 : std::string wvec;
111 8 : parse("WEIGHTS",wvec);
112 4 : if( wvec.length()==0 ) {
113 2 : wvec = argn + "_weights";
114 : }
115 : // Now calculate the sum of thse weights
116 8 : readInputLine( wvec + "_sum: SUM ARG=" + wvec + " PERIODIC=NO");
117 : // And normalise the vector of weights using this sum
118 8 : readInputLine( wvec + "_normed: CUSTOM ARG=" + wvec + "_sum," + wvec + " FUNC=y/x PERIODIC=NO");
119 : // And now create the matrix of weights
120 8 : readInputLine( wvec + "_mat: OUTER_PRODUCT ARG=" + wvec + "_normed," + wvec + "_normed");
121 : // Run the arrange points object
122 : std::string ldfunc, cgtol, maxiter;
123 4 : parse("LOW_DIM_FUNCTION",ldfunc);
124 4 : parse("CGTOL",cgtol);
125 4 : parse("MAXITER",maxiter);
126 : unsigned ncycles;
127 8 : parse("NCYCLES",ncycles);
128 4 : std::string num, argstr, lname=getShortcutLabel() + "_ap";
129 4 : if( ncycles>0 ) {
130 2 : lname = getShortcutLabel() + "_cg";
131 : }
132 8 : argstr = "ARG=" + getShortcutLabel() + "_mds-1";
133 8 : for(unsigned i=1; i<ndim; ++i) {
134 4 : Tools::convert( i+1, num );
135 8 : argstr += "," + getShortcutLabel() + "_mds-" + num;
136 : }
137 : bool usesmacof;
138 4 : parseFlag("USE_SMACOF",usesmacof);
139 4 : if( usesmacof ) {
140 : std::string smactol, smacreg;
141 1 : parse("SMACTOL",smactol);
142 1 : parse("SMACREG",smacreg);
143 3 : readInputLine( lname + ": ARRANGE_POINTS " + argstr + " MINTYPE=smacof TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat" +
144 3 : " MAXITER=" + maxiter + " SMACTOL=" + smactol + " SMACREG=" + smacreg + " TARGET2=" + getShortcutLabel() + "_mds_mat WEIGHTS2=" + wvec + "_mat");
145 : } else {
146 6 : readInputLine( lname + ": ARRANGE_POINTS " + argstr + " MINTYPE=conjgrad TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
147 3 : if( ncycles>0 ) {
148 : std::string buf;
149 2 : parse("BUFFER",buf);
150 : std::vector<std::string> fgrid;
151 2 : parseVector("FGRID_SIZE",fgrid);
152 : std::string ncyc;
153 1 : Tools::convert(ncycles,ncyc);
154 2 : std::string pwise_args=" NCYCLES=" + ncyc + " BUFFER=" + buf;
155 1 : if( fgrid.size()>0 ) {
156 1 : if( fgrid.size()!=ndim ) {
157 0 : error("number of elements of fgrid is not correct");
158 : }
159 1 : pwise_args += " FGRID_SIZE=" + fgrid[0];
160 2 : for(unsigned i=1; i<fgrid.size(); ++i) {
161 2 : pwise_args += "," + fgrid[i];
162 : }
163 : }
164 1 : std::vector<std::string> cgrid(ndim);
165 2 : parseVector("CGRID_SIZE",cgrid);
166 1 : pwise_args += " CGRID_SIZE=" + cgrid[0];
167 2 : for(unsigned i=1; i<cgrid.size(); ++i) {
168 2 : pwise_args += "," + cgrid[i];
169 : }
170 2 : argstr="ARG=" + getShortcutLabel() + "_cg.coord-1";
171 2 : for(unsigned i=1; i<ndim; ++i) {
172 1 : Tools::convert( i+1, num );
173 2 : argstr += "," + getShortcutLabel() + "_cg.coord-" + num;
174 : }
175 2 : readInputLine( getShortcutLabel() + "_ap: ARRANGE_POINTS " + argstr + pwise_args + " MINTYPE=pointwise TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
176 2 : }
177 : }
178 8 : argstr="ARG=" + getShortcutLabel() + "_ap.coord-1";
179 8 : for(unsigned i=1; i<ndim; ++i) {
180 4 : Tools::convert( i+1, num );
181 8 : argstr += "," + getShortcutLabel() + "_ap.coord-" + num;
182 : }
183 8 : readInputLine( getShortcutLabel() + ": VSTACK " + argstr );
184 : bool projall;
185 4 : parseFlag("PROJECT_ALL",projall);
186 4 : if( !projall ) {
187 : return ;
188 : }
189 1 : parse("OS_CGTOL",cgtol);
190 1 : argstr = getShortcutLabel() + "_ap.coord-1";
191 2 : for(unsigned i=1; i<ndim; ++i) {
192 1 : Tools::convert( i+1, num );
193 2 : argstr += "," + getShortcutLabel() + "_ap.coord-" + num;
194 : }
195 1 : if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
196 0 : readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS " + argstr + " TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
197 : } else {
198 1 : ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_rectdissims");
199 1 : if( !dissims ) {
200 0 : error("cannot PROJECT_ALL as " + as->getName() + " with label " + argn + " was involved without the DISSIMILARITIES keyword");
201 : }
202 2 : readInputLine( getShortcutLabel() + "_lhdmat: MORE_THAN ARG=" + argn + "_rectdissims SQUARED SWITCH={" + hdfunc + "}");
203 2 : readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS ARG=" + argstr + " TARGET1=" + getShortcutLabel() + "_lhdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
204 : }
205 2 : argstr="ARG=" + getShortcutLabel() + "_osample_pp.coord-1";
206 2 : for(unsigned i=1; i<ndim; ++i) {
207 1 : Tools::convert( i+1, num );
208 2 : argstr += "," + getShortcutLabel() + "_osample_pp.coord-" + num;
209 : }
210 2 : readInputLine( getShortcutLabel() + "_osample: VSTACK " + argstr );
211 0 : }
212 :
213 : }
214 : }
|