LCOV - code coverage report
Current view: top level - ves - BF_Wavelets.cpp
Test: plumed test coverage
Date: 2024-10-18 14:00:25

             Hit   Total   Coverage
Lines:       118     118    100.0 %
Functions:     5       5    100.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2016-2021 The VES code team
       3             :    (see the PEOPLE-VES file at the root of this folder for a list of names)
       4             : 
       5             :    See http://www.ves-code.org for more information.
       6             : 
       7             :    This file is part of VES code module.
       8             : 
       9             :    The VES code module is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    The VES code module is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with the VES code module.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : 
      23             : 
      24             : #include "BasisFunctions.h"
      25             : #include "GridLinearInterpolation.h"
      26             : #include "tools/Grid.h"
      27             : #include "VesTools.h"
      28             : #include "WaveletGrid.h"
      29             : #include "core/ActionRegister.h"
      30             : #include "tools/Exception.h"
      31             : #include "core/PlumedMain.h"
      32             : 
      33             : 
      34             : namespace PLMD {
      35             : namespace ves {
      36             : 
      37             : 
      38             : //+PLUMEDOC VES_BASISF BF_WAVELETS
      39             : /*
      40             : Daubechies Wavelets basis functions.
      41             : 
      42             : Note: at the moment only bases with a single level of scaling functions are usable, as multiscale optimization is not yet implemented.
      43             : 
      44             : This basis set uses Daubechies Wavelets \cite daubechies_ten_1992 to construct a complete and orthogonal basis. See \cite ValssonPampel_Wavelets_2022 for full details.
      45             : 
      46             : The basis set is built from a pair of functions, the scaling function (or father wavelet) \f$\phi\f$ and the wavelet function (or mother wavelet) \f$\psi\f$.
      47             : They are defined via the two-scale relations for scale \f$j\f$ and shift \f$k\f$:
      48             : 
      49             : \f{align*}{
      50             :   \phi_k^j \left(x\right) = 2^{-j/2} \phi \left( 2^{-j} x - k\right)\\
      51             :   \psi_k^j \left(x\right) = 2^{-j/2} \psi \left( 2^{-j} x - k\right)
      52             : \f}
      53             : 
      54             : The exact properties are set by choosing filter coefficients, e.g. choosing \f$h_k\f$ for the father wavelet:
      55             : 
      56             : \f[
      57             :   \phi\left(x\right) = \sqrt{2} \sum_k h_k\, \phi \left( 2 x - k\right)
      58             : \f]
      59             : 
      60             : The filter coefficients by Daubechies result in an orthonormal basis of all integer shifted functions:
      61             : \f[
      62             :   \int \phi(x+i) \phi(x+j) \mathop{}\!\mathrm{d}x = \delta_{ij} \quad \text{for} \quad i,j \in \mathbb{Z}
      63             : \f]
      64             : 
      65             : Because no analytic formula for these wavelets exists, they are instead constructed iteratively on a grid.
      66             : The method of construction is close to the "Vector cascade algorithm" described in \cite strang_wavelets_1997 .
      67             : The needed filter coefficients of the scaling function are hardcoded; they were generated beforehand with a Python script.
      68             : Currently the "maximum phase" type (Db) and the "least asymmetric" (Sym) type are implemented.
      69             : We recommend using Symlets.
      70             : 
      71             : As an example, two adjacent basis functions of both Sym8 (ORDER=8, TYPE=SYMLETS) and Db8 (ORDER=8, TYPE=DAUBECHIES) are shown in the figure.
      72             : The full basis consists of shifted wavelets in the full specified interval.
      73             : 
      74             : \image html ves_basisf-wavelets.png
      75             : 
      76             : 
      77             : \par Specify the wavelet type
      78             : 
      79             : The TYPE keyword sets the type of wavelet; at the moment "DAUBECHIES" and "SYMLETS" are available.
      80             : The specified ORDER of the basis corresponds to the number of vanishing moments of the wavelet, i.e. if TYPE was specified as "DAUBECHIES" an order of 8 results in Db8 wavelets.
      81             : 
      82             : 
      83             : \par Specify the number of functions
      84             : 
      85             : The resulting basis set consists of integer shifts of the wavelet with some scaling \f$j\f$,
      86             : \f[
      87             :   V(x) = \sum_i \alpha_i \phi_i (x) = \sum_i \alpha_i \phi\left(\frac{x}{j} + i\right)
      88             : \f]
      89             : with the variational parameters \f$ \alpha \f$.
      90             : Additionally, a constant basis function is included.
      91             : 
      92             : Two different ways to specify the number of basis functions are implemented.
      93             : You can either specify the scale or, alternatively, a fixed number of basis functions.
      94             : 
      95             : Following the multiresolution property of wavelets, you can set the scale of the father wavelets, i.e. the largest scale used for the approximation.
      96             : This can be done with the FUNCTION_LENGTH keyword.
      97             : It is given in the same units as the CV and specifies the length of the domain interval of the individual father wavelet functions.
      98             : 
      99             : Alternatively a fixed number of basis functions for the bias expansion can be specified with the NUM_BF keyword, which will set the scale automatically to match the desired number of functions.
     100             : Note that this also includes the constant function.
     101             : 
     102             : If neither keyword is specified, the scale of the wavelet functions is chosen to match the range of the bias.
     103             : More precisely, the basis functions are scaled to match the specified size of the CV space (MINIMUM and MAXIMUM keywords).
     104             : So far this has proven to be a good initial choice.
     105             : 
     106             : If the wavelets are scaled to match the CV range exactly, there are \f$4 \cdot \text{ORDER} - 3\f$ shifted wavelets whose domain lies at least partially in this region, in addition to the constant function.
     107             : This number is adjusted if FUNCTION_LENGTH or NUM_BF is specified.
     108             : Additionally, some of the shifted basis functions will not have significant contributions because of their function values being close to zero over the full range of the bias.
     109             : These 'tail wavelets' can be omitted by using the TAILS_THRESHOLD keyword.
     110             : This omits all shifted functions whose values inside the bias range are all smaller than the given fraction of the maximum value of the wavelet.
     111             : Using a value of e.g. 0.01 will already reduce the number of basis functions significantly.
     112             : The default setting will not omit any tail wavelets (i.e. TAILS_THRESHOLD=0).
     113             : 
     114             : The number of basis functions is then not easy to determine a priori, but it is printed in the logfile.
     115             : Additionally, the starting point (leftmost defined point) of each basis function is printed.
     116             : 
     117             : 
     118             : With the PERIODIC keyword the basis set can also be used to bias periodic CVs.
     119             : The shift between the functions is then chosen such that the functions at the left and right borders coincide.
     120             : If the FUNCTION_LENGTH keyword is used together with PERIODIC, a smaller length might be chosen to satisfy this requirement.
     121             : 
     122             : 
     123             : \par Grid
     124             : 
     125             : The values of the wavelet function are generated on a grid.
     126             : Each iteration of the cascade algorithm doubles the number of grid bins.
     127             : This means that the number of grid bins is always a power of two multiplied by the support length of the specified wavelet (\f$ 2 \cdot \text{ORDER} - 1\f$).
     128             : Using the MIN_GRID_SIZE keyword a lower bound for the number of grid points can be specified.
     129             : By default at least 1,000 grid points are used.
     130             : Function values in between grid points are calculated by linear interpolation.
     131             : 
     132             : \par Optimization notes
     133             : 
     134             : To avoid 'blind' optimization of the basis functions outside the currently sampled area, it is often beneficial to use the OPTIMIZATION_THRESHOLD keyword of the \ref VES_LINEAR_EXPANSION (set it to a small value, e.g. 1e-6).
     135             : 
     136             : \par Examples
     137             : 
     138             : 
     139             : First a very simple example that relies on the default values.
     140             : We want to bias some CV in the range of 0 to 4.
     141             : The wavelets will therefore be scaled to match that range.
     142             : Using Db8 wavelets this results in 30 basis functions (including the constant one), with their starting points given by \f$ -14*\frac{4}{15}, -13*\frac{4}{15}, \cdots , 0 , \cdots, 13*\frac{4}{15}, 14*\frac{4}{15} \f$.
     143             : \plumedfile
     144             : BF_WAVELETS ...
     145             :  ORDER=8
     146             :  TYPE=DAUBECHIES
     147             :  MINIMUM=0.0
     148             :  MAXIMUM=4.0
     149             :  LABEL=bf
     150             : ... BF_WAVELETS
     151             : \endplumedfile
     152             : 
     153             : 
     154             : By omitting wavelets with only insignificant parts, we can reduce the number of basis functions. Using a threshold of 0.01 will in this example remove the 8 leftmost shifts, which we can check in the logfile.
     155             : \plumedfile
     156             : BF_WAVELETS ...
     157             :  ORDER=8
     158             :  TYPE=DAUBECHIES
     159             :  MINIMUM=0.0
     160             :  MAXIMUM=4.0
     161             :  TAILS_THRESHOLD=0.01
     162             :  LABEL=bf
     163             : ... BF_WAVELETS
     164             : \endplumedfile
     165             : 
     166             : 
     167             : The length of the individual basis functions can also be adjusted to fit the specific problem.
     168             : If for example the wavelets are instead scaled to length 3, there will be 35 basis functions, with leftmost points at \f$ -14*\frac{3}{15}, -13*\frac{3}{15}, \cdots, 0, \cdots, 18*\frac{3}{15}, 19*\frac{3}{15} \f$.
     169             : \plumedfile
     170             : BF_WAVELETS ...
     171             :  ORDER=8
     172             :  TYPE=DAUBECHIES
     173             :  MINIMUM=0.0
     174             :  MAXIMUM=4.0
     175             :  FUNCTION_LENGTH=3
     176             :  LABEL=bf
     177             : ... BF_WAVELETS
     178             : \endplumedfile
     179             : 
     180             : 
     181             : Alternatively, you can also specify the number of basis functions. Here we request 40 basis functions, i.e. the constant function plus 39 shifted Sym10 wavelets. We also use a custom minimum size for the grid and print the grid to a file with a specific numerical format.
     182             : \plumedfile
     183             : BF_WAVELETS ...
     184             :  ORDER=10
     185             :  TYPE=SYMLETS
     186             :  MINIMUM=0.0
     187             :  MAXIMUM=4.0
     188             :  NUM_BF=40
     189             :  MIN_GRID_SIZE=500
     190             :  DUMP_WAVELET_GRID
     191             :  WAVELET_FILE_FMT=%11.4f
     192             :  LABEL=bf
     193             : ... BF_WAVELETS
     194             : \endplumedfile
     195             : 
     196             : */
     197             : //+ENDPLUMEDOC
     198             : 
     199             : 
     200             : class BF_Wavelets : public BasisFunctions {
     201             : private:
     202             :   void setupLabels() override;
     203             :   /// ptr to Grid that holds the Wavelet values and its derivative
     204             :   std::unique_ptr<Grid> waveletGrid_;
     205             :   /// calculate the cutoff points used to omit tail wavelets
     206             :   std::vector<double> getCutoffPoints(const double& threshold);
     207             :   /// scale factor of the individual BFs to match specified length
     208             :   double scale_;
     209             :   /// shift of the individual BFs
     210             :   std::vector<double> shifts_;
     211             : public:
     212             :   static void registerKeywords( Keywords&);
     213             :   explicit BF_Wavelets(const ActionOptions&);
     214             :   void getAllValues(const double, double&, bool&, std::vector<double>&, std::vector<double>&) const override;
     215             : };
     216             : 
     217             : 
     218             : PLUMED_REGISTER_ACTION(BF_Wavelets,"BF_WAVELETS")
     219             : 
     220             : 
     221          49 : void BF_Wavelets::registerKeywords(Keywords& keys) {
     222          49 :   BasisFunctions::registerKeywords(keys);
     223          98 :   keys.add("compulsory","TYPE","Specify the wavelet type. Currently available are DAUBECHIES Wavelets with minimum phase and the more symmetric SYMLETS");
     224          98 :   keys.add("optional","FUNCTION_LENGTH","The domain size (length) of the individual basis functions. This is used to alter the scaling of the basis functions. By default it is set to the total size of the interval. This also influences the number of actually used basis functions, as all shifted functions that are partially supported in the CV space are used.");
     225          98 :   keys.add("optional","NUM_BF","The number of basis functions that should be used. Includes the constant one and N-1 shifted wavelets within the specified range. Cannot be used together with FUNCTION_LENGTH.");
     226          98 :   keys.add("optional","TAILS_THRESHOLD","The threshold for cutting off tail wavelets as a fraction of the maximum value. All shifted wavelet functions that only have values smaller than the threshold in the bias range will be excluded from the basis set. Defaults to 0 (include all).");
     227          98 :   keys.addFlag("MOTHER_WAVELET", false, "If this flag is set mother wavelets will be used instead of the scaling function (father wavelet). This only makes sense for multiresolution, which is not usable at the moment.");
     228          98 :   keys.add("optional","MIN_GRID_SIZE","The minimal number of grid bins of the Wavelet function. The true number depends also on the used wavelet type and will probably be larger. Defaults to 1000.");
     229          98 :   keys.addFlag("DUMP_WAVELET_GRID", false, "If this flag is set the grid with the wavelet values will be written to a file. This file is called LABEL.wavelet_grid.data.");
     230          98 :   keys.add("optional","WAVELET_FILE_FMT","The number format of the wavelet grid values and derivatives written to file. By default it is %13.6f.");
     231          98 :   keys.addFlag("PERIODIC", false, "Use periodic version of basis set.");
     232          49 :   keys.remove("NUMERICAL_INTEGRALS");
     233          49 : }
     234             : 
     235             : 
     236          47 : BF_Wavelets::BF_Wavelets(const ActionOptions& ao):
     237             :   PLUMED_VES_BASISFUNCTIONS_INIT(ao),
     238          47 :   waveletGrid_(nullptr),
     239          47 :   scale_(0.0)
     240             : {
     241          47 :   log.printf("  Wavelet basis functions, see and cite ");
     242          94 :   log << plumed.cite("Pampel and Valsson, J. Chem. Theory Comput. 18, 4127-4141 (2022) - DOI:10.1021/acs.jctc.2c00197");
     243             : 
     244             :   // parse properties for waveletGrid and set it up
     245             :   bool use_mother_wavelet;
     246          94 :   parseFlag("MOTHER_WAVELET", use_mother_wavelet);
     247             : 
     248             :   std::string wavelet_type_str;
     249          47 :   parse("TYPE", wavelet_type_str);
     250          94 :   addKeywordToList("TYPE", wavelet_type_str);
     251             : 
     252          47 :   unsigned min_grid_size = 1000;
     253          47 :   parse("MIN_GRID_SIZE", min_grid_size);
     254          83 :   if(min_grid_size != 1000) {addKeywordToList("MIN_GRID_SIZE",min_grid_size);}
     255             : 
     256          94 :   waveletGrid_ = WaveletGrid::setupGrid(getOrder(), min_grid_size, use_mother_wavelet, WaveletGrid::stringToType(wavelet_type_str));
     257          47 :   bool dump_wavelet_grid=false;
     258          47 :   parseFlag("DUMP_WAVELET_GRID", dump_wavelet_grid);
     259          47 :   if (dump_wavelet_grid) {
     260          36 :     OFile wavelet_gridfile;
     261          36 :     std::string fmt = "%13.6f";
     262          72 :     parse("WAVELET_FILE_FMT",fmt);
     263             :     waveletGrid_->setOutputFmt(fmt); // property of grid not OFile determines fmt
     264          36 :     wavelet_gridfile.link(*this);
     265          36 :     wavelet_gridfile.enforceBackup();
     266          72 :     wavelet_gridfile.open(getLabel()+".wavelet_grid.data");
     267          36 :     waveletGrid_->writeToFile(wavelet_gridfile);
     268          36 :   }
     269             : 
     270          47 :   bool periodic = false;
     271          47 :   parseFlag("PERIODIC",periodic);
     272          51 :   if (periodic) {addKeywordToList("PERIODIC",periodic);}
     273             : 
     274             :   // now set up properties of basis set
     275          47 :   unsigned intrinsic_length = 2*getOrder() - 1; // length of unscaled wavelet
     276          47 :   double bias_length = intervalMax() - intervalMin(); // intervalRange() is not yet set
     277             : 
     278             :   // parse threshold for tail wavelets and get respective cutoff points
     279          47 :   double threshold = 0.0;
     280          47 :   std::vector<double> cutoffpoints (2);
     281          47 :   parse("TAILS_THRESHOLD",threshold);
     282          47 :   plumed_massert(threshold < 1, "TAILS_THRESHOLD should be significantly smaller than 1.");
     283          47 :   if(threshold == 0.0) {
     284          45 :     cutoffpoints = {0.0, static_cast<double>(intrinsic_length)};
     285             :   }
     286             :   else {
     287           2 :     plumed_massert(!periodic, "TAILS_THRESHOLD can't be used with the periodic wavelet variant");
     288           2 :     addKeywordToList("TAILS_THRESHOLD",threshold);
     289           4 :     cutoffpoints = getCutoffPoints(threshold);
     290             :   }
     291             : 
     292          47 :   double function_length = bias_length;
     293          47 :   parse("FUNCTION_LENGTH",function_length);
     294          47 :   if(function_length != bias_length) {
     295           4 :     if (periodic) {  // shifted functions need to fit into interval exactly -> reduce size if not
     296           2 :       unsigned num_shifts = ceil(bias_length * intrinsic_length / function_length);
     297           2 :       function_length = bias_length * intrinsic_length / num_shifts;
     298             :     }
     299           8 :     addKeywordToList("FUNCTION_LENGTH",function_length);
     300             :   }
     301             : 
     302             :   // determine number of BFs and needed scaling
     303          47 :   unsigned num_BFs = 0;
     304          47 :   parse("NUM_BF",num_BFs);
     305          47 :   if(num_BFs == 0) { // get from function length
     306          43 :     scale_ = intrinsic_length / function_length;
     307          43 :     if (periodic) {
     308             :       // this is the same value as num_shifts above + constant
     309           2 :       num_BFs = static_cast<unsigned>(bias_length * scale_) + 1;
     310             :     }
     311             :     else {
     312          41 :       num_BFs = 1; // constant one
     313             :       // left shifts (w/o left cutoff) + right shifts - right cutoff - 1
     314          41 :       num_BFs += static_cast<unsigned>(ceil(cutoffpoints[1] + (bias_length)*scale_ - cutoffpoints[0]) - 1);
     315             :     }
     316             :   }
     317             :   else {
     318             :     plumed_massert(num_BFs > 0, "The number of basis functions has to be positive (NUM_BF > 0)");
     319             :     // check does not catch FUNCTION_LENGTH given equal to the bias length, but keyword use can't be checked directly
     320           4 :     plumed_massert(function_length==bias_length,"The keywords \"NUM_BF\" and \"FUNCTION_LENGTH\" cannot be used at the same time");
     321           4 :     addKeywordToList("NUM_BF",num_BFs);
     322             : 
     323           4 :     if (periodic) {  // inverted num_BFs calculation from where FUNCTION_LENGTH is specified
     324           2 :       scale_ = (num_BFs  - 1) / bias_length ;
     325             :     }
     326             :     else {
     327           2 :       double cutoff_length = cutoffpoints[1] - cutoffpoints [0];
     328           2 :       double intrinsic_bias_length = num_BFs - cutoff_length + 1; // length of bias in intrinsic scale of wavelets
     329           2 :       scale_ = intrinsic_bias_length / bias_length;
     330             :     }
     331             :   }
     332             : 
     333          47 :   setNumberOfBasisFunctions(num_BFs);
     334             : 
     335             :   // now set up the starting points of the basis functions
     336          47 :   shifts_.push_back(0.0); // constant BF – never used, just for clearer notation
     337        1908 :   for(unsigned int i = 1; i < getNumberOfBasisFunctions(); ++i) {
     338        1861 :     shifts_.push_back(-intervalMin()*scale_ + cutoffpoints[1] - i);
     339             :   }
     340             : 
     341             :   // set some properties
     342          47 :   setIntrinsicInterval(0.0,intrinsic_length);
     343          47 :   periodic ? setPeriodic() : setNonPeriodic();
     344             :   setIntervalBounded();
     345             :   setType(wavelet_type_str);
     346          47 :   setDescription("Wavelets as localized basis functions");
     347          47 :   setupBF();
     348          47 :   checkRead();
     349             : 
     350          47 :   log.printf("  Each basisfunction spans %f in CV space\n", intrinsic_length/scale_);
     351          47 : }
     352             : 
     353             : 
     354       62249 : void BF_Wavelets::getAllValues(const double arg, double& argT, bool& inside_range, std::vector<double>& values, std::vector<double>& derivs) const {
     355       62249 :   argT=checkIfArgumentInsideInterval(arg,inside_range);
     356             :   //
     357       62249 :   values[0]=1.0;
     358       62249 :   derivs[0]=0.0;
     359     2315762 :   for(unsigned int i = 1; i < getNumberOfBasisFunctions(); ++i) {
     360             :     // scale and shift argument to match current wavelet
     361     2253513 :     double x = shifts_[i] + argT*scale_;
     362     2253513 :     if (arePeriodic()) { // periodic interval [0,intervalRange*scale]
     363      171766 :       x = x - floor(x/(intervalRange()*scale_))*intervalRange()*scale_;
     364             :     }
     365             : 
     366     2253513 :     if (x < 0 || x >= intrinsicIntervalMax()) { // Wavelets are 0 outside the defined range
     367      989659 :       values[i] = 0.0; derivs[i] = 0.0;
     368             :     }
     369             :     else {
     370     1263854 :       std::vector<double> temp_deriv (1);
     371     1263854 :       values[i] = GridLinearInterpolation::getGridValueAndDerivativesWithLinearInterpolation(waveletGrid_.get(), {x}, temp_deriv);
     372     1263854 :       derivs[i] = temp_deriv[0] * scale_; // scale derivative
     373             :     }
     374             :   }
     375       67885 :   if(!inside_range) {for(auto& deriv : derivs) {deriv=0.0;}}
     376       62249 : }
     377             : 
     378             : 
     379             : // returns left and right cutoff points of the wavelet
     380             : // threshold is a fraction of the maximum value (e.g. 0.01 for 1 %)
     381           2 : std::vector<double> BF_Wavelets::getCutoffPoints(const double& threshold) {
     382           2 :   double threshold_value = threshold * waveletGrid_->getMaxValue();
     383             :   std::vector<double> cutoffpoints;
     384             : 
     385         475 :   for (size_t i = 0; i < waveletGrid_->getSize(); ++i) {
     386         475 :     if (fabs(waveletGrid_->getValue(i)) >= threshold_value) {
     387           2 :       cutoffpoints.push_back(waveletGrid_->getPoint(i)[0]);
     388           2 :       break;
     389             :     }
     390             :   }
     391             : 
     392        1073 :   for (int i = waveletGrid_->getSize() - 1; i >= 0; --i) {
     393        1073 :     if (fabs(waveletGrid_->getValue(i)) >= threshold_value) {
     394           2 :       cutoffpoints.push_back(waveletGrid_->getPoint(i)[0]);
     395           2 :       break;
     396             :     }
     397             :   }
     398             : 
     399           2 :   return cutoffpoints;
     400             : }
     401             : 
     402             : 
     403             : // labels according to minimum position in CV space
     404          47 : void BF_Wavelets::setupLabels() {
     405          47 :   setLabel(0,"const");
     406        1908 :   for(unsigned int i=1; i < getNumberOfBasisFunctions(); i++) {
     407        1861 :     double pos = -shifts_[i]/scale_;
     408        1861 :     if (arePeriodic()) {
     409          88 :       pos = pos - floor((pos-intervalMin())/intervalRange())*intervalRange();
     410             :     }
     411        1861 :     std::string is; Tools::convert(pos, is);
     412        3722 :     setLabel(i,"i="+is);
     413             :   }
     414          47 : }
     415             : 
     416             : 
     417             : }
     418             : }

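The documentation states that the wavelet values are built on a grid from the filter coefficients, with the number of bins doubling per iteration. The sketch below is a simplified illustration of that idea. It first finds exact values of the scaling function at the integers from an eigenvector, and then refines them with the two-scale relation, one dyadic level at a time. This is not PLUMED's WaveletGrid code and only loosely follows the vector cascade algorithm. The function names are hypothetical. The Db2 (ORDER=2) coefficients are the textbook values, chosen because the resulting values are known in closed form.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: textbook Db2 (ORDER=2) coefficients in the form
// c_k = sqrt(2) * h_k, so that phi(x) = sum_k c_k phi(2x - k) and sum_k c_k = 2.
std::vector<double> db2Filter() {
  const double s3 = std::sqrt(3.0);
  return {(1.0 + s3) / 4.0, (3.0 + s3) / 4.0, (3.0 - s3) / 4.0, (1.0 - s3) / 4.0};
}

// Values of phi on the dyadic grid x = m / 2^levels, m = 0 .. L*2^levels,
// where L = c.size() - 1 is the support length (2*ORDER - 1 for Daubechies).
std::vector<double> scalingFunctionDyadic(const std::vector<double>& c, unsigned levels) {
  assert(c.size() >= 4);  // at least Db2, so there are interior integer points
  const long L = static_cast<long>(c.size()) - 1;

  // 1) phi at the integers: eigenvector of M[n][m] = c[2n - m] (n, m = 1..L-1)
  //    with eigenvalue 1, via power iteration normalized to sum 1
  //    (partition of unity); phi(0) = phi(L) = 0.
  std::vector<double> vals(L + 1, 0.0);
  for (long n = 1; n < L; ++n) vals[n] = 1.0 / static_cast<double>(L - 1);
  for (int it = 0; it < 200; ++it) {
    std::vector<double> next(L + 1, 0.0);
    double sum = 0.0;
    for (long n = 1; n < L; ++n) {
      for (long m = 1; m < L; ++m) {
        const long k = 2 * n - m;
        if (k >= 0 && k <= L) next[n] += c[k] * vals[m];
      }
      sum += next[n];
    }
    for (long n = 1; n < L; ++n) vals[n] = next[n] / sum;
  }

  // 2) refinement: phi(m / 2^(j+1)) = sum_k c_k phi(m / 2^j - k);
  //    every level doubles the number of grid bins.
  for (unsigned j = 0; j < levels; ++j) {
    const long step = 1L << j;        // old grid points per unit length
    const long old_last = L * step;   // index of x = L on the old grid
    std::vector<double> next(2 * old_last + 1, 0.0);
    for (long m = 0; m <= 2 * old_last; ++m) {
      for (long k = 0; k <= L; ++k) {
        const long idx = m - k * step;  // position of m/2^j - k on the old grid
        if (idx >= 0 && idx <= old_last) next[m] += c[k] * vals[idx];
      }
    }
    vals.swap(next);
  }
  return vals;
}
```

With the Db2 filter, one refinement level yields 7 values, i.e. 3*2 bins. More generally, the bin count is a power of two times the support length 2*ORDER-1. The sketch reproduces phi(1) = (1+sqrt(3))/2 and phi(1/2) = (2+sqrt(3))/4.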
Generated by: LCOV version 1.16
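
getCutoffPoints scans the tabulated scaling function from both ends. From each end, it looks for the first value whose magnitude reaches threshold times the maximum value. The sketch below reproduces that scan on plain vectors instead of the PLUMED Grid class. The name cutoffPoints is hypothetical, and the sample data in the checks are made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-alone version of the scan in BF_Wavelets::getCutoffPoints,
// working on samples (xs[i], ys[i]) instead of a PLUMED Grid.
// threshold is a fraction of the maximum value, e.g. 0.01 for 1 %.
std::vector<double> cutoffPoints(const std::vector<double>& xs,
                                 const std::vector<double>& ys, double threshold) {
  assert(!ys.empty() && xs.size() == ys.size());
  assert(threshold > 0.0 && threshold < 1.0);
  const double threshold_value = threshold * *std::max_element(ys.begin(), ys.end());
  std::vector<double> cut;
  // left cutoff: first sample from the left whose magnitude reaches the threshold
  for (std::size_t i = 0; i < ys.size(); ++i) {
    if (std::fabs(ys[i]) >= threshold_value) { cut.push_back(xs[i]); break; }
  }
  // right cutoff: first sample from the right whose magnitude reaches the threshold
  for (std::size_t i = ys.size(); i-- > 0;) {
    if (std::fabs(ys[i]) >= threshold_value) { cut.push_back(xs[i]); break; }
  }
  return cut;
}
```

The returned pair (c0, c1) replaces (0, 2*ORDER-1) in the non-periodic count. The count, including the constant, then becomes ceil(c1 - c0 + bias_length*scale) functions.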