Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2011-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionShortcut.h" 23 : #include "core/PlumedMain.h" 24 : #include "core/ActionSet.h" 25 : #include "core/ActionRegister.h" 26 : #include "core/ActionWithValue.h" 27 : #include "tools/IFile.h" 28 : 29 : #include <cmath> 30 : 31 : namespace PLMD { 32 : namespace refdist { 33 : 34 : //+PLUMEDOC FUNCTION KERNEL 35 : /* 36 : Use a switching function to determine how many of the input variables are less than a certain cutoff. 37 : 38 : \par Examples 39 : 40 : */ 41 : //+ENDPLUMEDOC 42 : 43 : 44 : class Kernel : public ActionShortcut { 45 : public: 46 : static std::string fixArgumentDot( const std::string& argin ); 47 : explicit Kernel(const ActionOptions&); 48 : static void registerKeywords(Keywords& keys); 49 : }; 50 : 51 : 52 : PLUMED_REGISTER_ACTION(Kernel,"KERNEL") 53 : 54 20 : void Kernel::registerKeywords(Keywords& keys) { 55 20 : ActionShortcut::registerKeywords( keys ); 56 40 : keys.add("numbered","ARG","the arguments that should be used as input to this method"); 57 40 : keys.add("compulsory","TYPE","gaussian","the type of kernel to use"); 58 40 : keys.add("compulsory","CENTER","the position of the center of the kernel"); 59 40 : keys.add("optional","SIGMA","square root of variance of the cluster"); 60 40 : keys.add("compulsory","COVAR","the covariance of the kernel"); 61 40 : keys.add("compulsory","WEIGHT","1.0","the weight to multiply this kernel function by"); 62 40 : keys.add("optional","REFERENCE","the file from which to read the kernel parameters"); 63 40 : keys.add("compulsory","NUMBER","1","if there are multiple sets of kernel parameters in the input file which set of kernel parameters would you like to read in here"); 64 40 : keys.addFlag("NORMALIZED",false,"would you like the kernel function to be normalized"); 65 40 : keys.setValueDescription("scalar/vector","the value of the kernel evaluated at the argument values"); 66 60 : keys.needsAction("CONSTANT"); keys.needsAction("CUSTOM"); keys.needsAction("NORMALIZED_EUCLIDEAN_DISTANCE"); 67 60 : keys.needsAction("PRODUCT"); keys.needsAction("INVERT_MATRIX"); keys.needsAction("MAHALANOBIS_DISTANCE"); 68 60 : keys.needsAction("DIAGONALIZE"); keys.needsAction("CONCATENATE"); keys.needsAction("DETERMINANT"); 69 20 : keys.needsAction("BESSEL"); 70 20 : } 71 : 72 32 : std::string Kernel::fixArgumentDot( const std::string& argin ) { 73 32 : std::string argout = argin; std::size_t dot=argin.find("."); 74 32 : if( dot!=std::string::npos ) argout = argin.substr(0,dot) + "_" + argin.substr(dot+1); 75 32 : return argout; 76 : } 77 : 78 9 : Kernel::Kernel(const ActionOptions&ao): 79 : Action(ao), 80 9 : ActionShortcut(ao) 81 : { 82 : // Read in the arguments 83 18 : std::vector<std::string> argnames; parseVector("ARG",argnames); 84 9 : if( argnames.size()==0 ) error("no arguments were specified"); 85 : // Now sort out the parameters 86 18 : double weight; std::string fname; parse("REFERENCE",fname); bool usemahalanobis=false; 87 9 : if( fname.length()>0 ) { 88 9 : IFile ifile; ifile.open(fname); ifile.allowIgnoredFields(); 89 9 : unsigned number; parse("NUMBER",number); bool readline=false; 90 : // Create actions to hold the position of the center 91 31 : for(unsigned line=0; line<number; ++line) { 92 90 : for(unsigned i=0; i<argnames.size(); ++i) { 93 59 : std::string val; ifile.scanField(argnames[i], val); 94 75 : if( line==number-1 ) readInputLine( getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref: CONSTANT VALUES=" + val ); 95 : } 96 62 : if( ifile.FieldExist("sigma_" + argnames[0]) ) { 97 : std::string varstr; 98 0 : for(unsigned i=0; i<argnames.size(); ++i) { 99 0 : std::string val; ifile.scanField("sigma_" + argnames[i], val); 100 0 : if( i==0 ) varstr = val; else varstr += "," + val; 101 : } 102 0 : if( line==number-1 ) readInputLine( getShortcutLabel() + "_var: CONSTANT VALUES=" + varstr ); 103 : } else { 104 31 : std::string varstr, nvals; Tools::convert( argnames.size(), nvals ); usemahalanobis=(argnames.size()>1); 105 90 : for(unsigned i=0; i<argnames.size(); ++i) { 106 174 : for(unsigned j=0; j<argnames.size(); j++) { 107 230 : std::string val; ifile.scanField("sigma_" +argnames[i] + "_" + argnames[j], val ); 108 199 : if(i==0 && j==0 ) varstr = val; else varstr += "," + val; 109 : } 110 : } 111 31 : if( line==number-1 ) { 112 11 : if( !usemahalanobis ) readInputLine( getShortcutLabel() + "_var: CONSTANT VALUES=" + varstr ); 113 14 : else readInputLine( getShortcutLabel() + "_cov: CONSTANT NCOLS=" + nvals + " NROWS=" + nvals + " VALUES=" + varstr ); 114 : } 115 : } 116 31 : if( line==number-1 ) { readline=true; break; } 117 22 : ifile.scanField(); 118 : } 119 9 : if( !readline ) error("could not read reference configuration"); 120 9 : ifile.scanField(); ifile.close(); 121 9 : } else { 122 : // Create actions to hold the position of the center 123 0 : std::vector<std::string> center(argnames.size()); parseVector("CENTER",center); 124 0 : for(unsigned i=0; i<argnames.size(); ++i) readInputLine( getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref: CONSTANT VALUES=" + center[i] ); 125 0 : std::vector<std::string> sig; parseVector("SIGMA",sig); 126 0 : if( sig.size()==0 ) { 127 : // Create actions to hold the covariance 128 0 : std::string cov; parse("COVAR",cov); usemahalanobis=(argnames.size()>1); 129 0 : if( !usemahalanobis ) { 130 0 : readInputLine( getShortcutLabel() + "_var: CONSTANT VALUES=" + cov ); 131 : } else { 132 0 : std::string nvals; Tools::convert( argnames.size(), nvals ); 133 0 : readInputLine( getShortcutLabel() + "_cov: CONSTANT NCOLS=" + nvals + " NROWS=" + nvals + " VALUES=" + cov ); 134 : } 135 0 : } else if( sig.size()==argnames.size() ) { 136 : // And actions to hold the standard deviation 137 0 : std::string valstr = sig[0]; for(unsigned i=1; i<sig.size(); ++i) valstr += "," + sig[i]; 138 0 : readInputLine( getShortcutLabel() + "_sigma: CONSTANT VALUES=" + valstr ); 139 0 : readInputLine( getShortcutLabel() + "_var: CUSTOM ARG=" + getShortcutLabel() + "_sigma FUNC=x*x PERIODIC=NO"); 140 0 : } else error("sigma has wrong length"); 141 0 : } 142 : 143 : // Create the reference point and arguments 144 : std::string refpoint, argstr; 145 25 : for(unsigned i=0; i<argnames.size(); ++i) { 146 34 : if( i==0 ) { argstr = argnames[0]; refpoint = getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref"; } 147 21 : else { argstr += "," + argnames[1]; refpoint += "," + getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref"; } 148 : } 149 : 150 : // Get the information on the kernel type 151 18 : std::string func_str, ktype; parse("TYPE",ktype); 152 16 : if( ktype=="gaussian" || ktype=="von-misses" ) func_str = "exp(-x/2)"; 153 0 : else if( ktype=="triangular" ) func_str = "step(1.-sqrt(x))*(1.-sqrt(x))"; 154 : else func_str = ktype; 155 9 : std::string vm_str=""; if( ktype=="von-misses" ) vm_str=" VON_MISSES"; 156 : 157 9 : unsigned nvals = argnames.size(); bool norm; parseFlag("NORMALIZED",norm); 158 9 : if( !usemahalanobis ) { 159 : // Invert the variance 160 4 : readInputLine( getShortcutLabel() + "_icov: CUSTOM ARG=" + getShortcutLabel() + "_var FUNC=1/x PERIODIC=NO"); 161 : // Compute the distance between the center of the basin and the current configuration 162 4 : readInputLine( getShortcutLabel() + "_dist_2: NORMALIZED_EUCLIDEAN_DISTANCE SQUARED" + vm_str +" ARG1=" + argstr + " ARG2=" + refpoint + " METRIC=" + getShortcutLabel() + "_icov"); 163 : // And compute a determinent for the input covariance matrix if it is required 164 2 : if( norm ) { 165 2 : if( ktype=="von-misses" ) readInputLine( getShortcutLabel() + "_vec: CUSTOM ARG=" + getShortcutLabel() + "_icov FUNC=x PERIODIC=NO" ); 166 4 : else readInputLine( getShortcutLabel() + "_det: PRODUCT ARG=" + getShortcutLabel() + "_var"); 167 : } 168 : } else { 169 : // Invert the input covariance matrix 170 14 : readInputLine( getShortcutLabel() + "_icov: INVERT_MATRIX ARG=" + getShortcutLabel() + "_cov" ); 171 : // Compute the distance between the center of the basin and the current configuration 172 14 : readInputLine( getShortcutLabel() + "_dist_2: MAHALANOBIS_DISTANCE SQUARED ARG1=" + argstr + " ARG2=" + refpoint + " METRIC=" + getShortcutLabel() + "_icov " + vm_str ); 173 : // And compute a determinent for the input covariance matrix if it is required 174 7 : if( norm ) { 175 7 : if( ktype=="von-misses" ) { 176 14 : readInputLine( getShortcutLabel() + "_det: DIAGONALIZE ARG=" + getShortcutLabel() + "_cov VECTORS=all" ); 177 7 : std::string num, argnames= getShortcutLabel() + "_det.vals-1"; 178 14 : for(unsigned i=1; i<nvals; ++i) { Tools::convert( i+1, num ); argnames += "," + getShortcutLabel() + "_det.vals-" + num; } 179 14 : readInputLine( getShortcutLabel() + "_comp: CONCATENATE ARG=" + argnames ); 180 14 : readInputLine( getShortcutLabel() + "_vec: CUSTOM ARG=" + getShortcutLabel() + "_comp FUNC=1/x PERIODIC=NO"); 181 : } else { 182 0 : readInputLine( getShortcutLabel() + "_det: DETERMINANT ARG=" + getShortcutLabel() + "_cov"); 183 : } 184 : } 185 : } 186 : 187 : // Compute the Gaussian 188 9 : std::string wstr; parse("WEIGHT",wstr); 189 9 : if( norm ) { 190 9 : if( ktype=="gaussian" ) { 191 2 : std::string pstr; Tools::convert( sqrt(pow(2*pi,nvals)), pstr ); 192 4 : readInputLine( getShortcutLabel() + "_vol: CUSTOM ARG=" + getShortcutLabel() + "_det FUNC=(sqrt(x)*" + pstr + ") PERIODIC=NO"); 193 7 : } else if( ktype=="von-misses" ) { 194 : std::string wstr, min, max; 195 14 : ActionWithValue* av=plumed.getActionSet().selectWithLabel<ActionWithValue*>( getShortcutLabel() + "_dist_2_diff" ); plumed_assert( av ); 196 7 : if( !av->copyOutput(0)->isPeriodic() ) error("VON_MISSES only works with periodic variables"); 197 7 : av->copyOutput(0)->getDomain(min,max); 198 14 : readInputLine( getShortcutLabel() + "_bes: BESSEL ORDER=0 ARG=" + getShortcutLabel() + "_vec"); 199 14 : readInputLine( getShortcutLabel() + "_cc: CUSTOM ARG=" + getShortcutLabel() + "_bes FUNC=("+max+"-"+min+")*x PERIODIC=NO"); 200 14 : readInputLine( getShortcutLabel() + "_vol: PRODUCT ARG=" + getShortcutLabel() + "_cc"); 201 0 : } else error("only gaussian and von-misses kernels are normalizable"); 202 : // And the (suitably normalized) kernel 203 18 : readInputLine( getShortcutLabel() + ": CUSTOM ARG=" + getShortcutLabel() + "_dist_2," + getShortcutLabel() + "_vol FUNC=" + wstr + "*exp(-x/2)/y PERIODIC=NO"); 204 : } else { 205 0 : readInputLine( getShortcutLabel() + ": CUSTOM ARG1=" + getShortcutLabel() + "_dist_2 FUNC=" + wstr + "*" + func_str + " PERIODIC=NO"); 206 : } 207 9 : checkRead(); 208 : 209 9 : } 210 : 211 : } 212 : } 213 : 214 :