LCOV - code coverage report
Current view: top level - maze - Optimizer.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 161 179 89.9 %
Date: 2024-10-18 14:00:25 Functions: 9 10 90.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2019 Jakub Rydzewski (jr@fizyka.umk.pl). All rights reserved.
       3             : 
       4             : See http://www.maze-code.github.io for more information.
       5             : 
       6             : This file is part of maze.
       7             : 
       8             : maze is free software: you can redistribute it and/or modify it under the
       9             : terms of the GNU Lesser General Public License as published by the Free
      10             : Software Foundation, either version 3 of the License, or (at your option)
      11             : any later version.
      12             : 
      13             : maze is distributed in the hope that it will be useful, but WITHOUT ANY
      14             : WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
      15             : FOR A PARTICULAR PURPOSE.
      16             : 
      17             : See the GNU Lesser General Public License for more details.
      18             : 
      19             : You should have received a copy of the GNU Lesser General Public License
      20             : along with maze. If not, see <https://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : 
      23             : /**
      24             :  * @file Optimizer.cpp
      25             :  *
      26             :  * @author J. Rydzewski (jr@fizyka.umk.pl)
      27             :  */
      28             : 
      29             : #include "Optimizer.h"
      30             : #include "core/PlumedMain.h"
      31             : #include "tools/Tools.h"
      32             : 
      33             : namespace PLMD {
      34             : namespace maze {
      35             : 
      36          17 : void Optimizer::registerKeywords(Keywords& keys) {
      37          17 :   Colvar::registerKeywords(keys);
      38             : 
      39          34 :   keys.addFlag(
      40             :     "SERIAL",
      41             :     false,
      42             :     "Perform the simulation in serial -- used only for debugging purposes, "
      43             :     "should not be used otherwise."
      44             :   );
      45             : 
      46          34 :   keys.addFlag(
      47             :     "PAIR",
      48             :     false,
      49             :     "Pair only the 1st element of the 1st group with the 1st element in the "
      50             :     "second, etc."
      51             :   );
      52             : 
      53          34 :   keys.addFlag(
      54             :     "NLIST",
      55             :     false,
      56             :     "Use a neighbor list of ligand-protein atom pairs to speed up the "
      57             :     "calculating of the distances."
      58             :   );
      59             : 
      60          34 :   keys.add(
      61             :     "optional",
      62             :     "NL_CUTOFF",
      63             :     "Neighbor list cut-off for the distances of ligand-protein atom pairs."
      64             :   );
      65             : 
      66          34 :   keys.add(
      67             :     "optional",
      68             :     "NL_STRIDE",
      69             :     "Update stride for the ligand-protein atom pairs in the neighbor list."
      70             :   );
      71             : 
      72          34 :   keys.add(
      73             :     "compulsory",
      74             :     "N_ITER",
      75             :     "Number of optimization steps. Required only for optimizers, do not pass "
      76             :     "this keyword to the fake optimizers (results in crash) , e.g., random "
      77             :     "walk, steered MD, or random acceleration MD."
      78             :   );
      79             : 
      80          34 :   keys.add(
      81             :     "optional",
      82             :     "LOSS",
      83             :     "Loss function describing ligand-protein interactions required by every "
      84             :     "optimizer."
      85             :   );
      86             : 
      87          34 :   keys.add(
      88             :     "atoms",
      89             :     "LIGAND",
      90             :     "Indices of ligand atoms."
      91             :   );
      92             : 
      93          34 :   keys.add(
      94             :     "atoms",
      95             :     "PROTEIN",
      96             :     "Indices of protein atoms."
      97             :   );
      98             : 
      99          34 :   keys.add(
     100             :     "compulsory",
     101             :     "OPTIMIZER_STRIDE",
     102             :     "Optimizer stride. Sets up a callback function that launches the "
     103             :     "optimization process every OPTIMIZER_STRIDE."
     104             :   );
     105             : 
     106          34 :   keys.addOutputComponent(
     107             :     "x",
     108             :     "default",
     109             :     "Optimal biasing direction; x component."
     110             :   );
     111             : 
     112          34 :   keys.addOutputComponent(
     113             :     "y",
     114             :     "default",
     115             :     "Optimal biasing direction; y component."
     116             :   );
     117             : 
     118          34 :   keys.addOutputComponent(
     119             :     "z",
     120             :     "default",
     121             :     "Optimal biasing direction; z component."
     122             :   );
     123             : 
     124          34 :   keys.addOutputComponent(
     125             :     "loss",
     126             :     "default",
     127             :     "Loss function value defined by the provided pairing function."
     128             :   );
     129             : 
     130          34 :   keys.addOutputComponent(
     131             :     "sr",
     132             :     "default",
     133             :     "Sampling radius. Reduces sampling to the local proximity of the ligand "
     134             :     "position."
     135             :   );
     136          17 : }
     137             : 
     138           7 : Optimizer::Optimizer(const ActionOptions& ao)
     139             :   : PLUMED_COLVAR_INIT(ao),
     140           7 :     first_step_(true),
     141           7 :     opt_value_(0.0),
     142           7 :     pbc_(true),
     143           7 :     sampling_r_(0.0),
     144           7 :     serial_(false),
     145           7 :     validate_list_(true),
     146           7 :     first_time_(true)
     147             : {
     148           7 :   parseFlag("SERIAL", serial_);
     149             : 
     150          14 :   if (keywords.exists("LOSS")) {
     151           7 :     std::vector<std::string> loss_labels(0);
     152          14 :     parseVector("LOSS", loss_labels);
     153             : 
     154           7 :     plumed_massert(
     155             :       loss_labels.size() > 0,
     156             :       "maze> Something went wrong with the LOSS keyword.\n"
     157             :     );
     158             : 
     159           7 :     std::string error_msg = "";
     160           7 :     vec_loss_ = tls::get_pointers_labels<Loss*>(
     161             :                   loss_labels,
     162           7 :                   plumed.getActionSet(),
     163             :                   error_msg
     164             :                 );
     165             : 
     166           7 :     if (error_msg.size() > 0) {
     167           0 :       plumed_merror(
     168             :         "maze> Error in the LOSS keyword " + getName() + ": " + error_msg
     169             :       );
     170             :     }
     171             : 
     172           7 :     loss_ = vec_loss_[0];
     173           7 :     log.printf("maze> Loss function linked to the optimizer.\n");
     174           7 :   }
     175             : 
     176          14 :   if (keywords.exists("N_ITER")) {
     177           3 :     parse("N_ITER", n_iter_);
     178             : 
     179           3 :     plumed_massert(
     180             :       n_iter_ > 0,
     181             :       "maze> N_ITER should be explicitly specified and positive.\n"
     182             :     );
     183             : 
     184           3 :     log.printf(
     185             :       "maze> Optimizer will run %u iterations once launched.\n",
     186             :       n_iter_
     187             :     );
     188             :   }
     189             : 
     190             :   std::vector<AtomNumber> ga_list, gb_list;
     191           7 :   parseAtomList("LIGAND", ga_list);
     192           7 :   parseAtomList("PROTEIN", gb_list);
     193             : 
     194           7 :   bool nopbc = !pbc_;
     195           7 :   parseFlag("NOPBC", nopbc);
     196             : 
     197           7 :   bool do_pair = false;
     198           7 :   parseFlag("PAIR", do_pair);
     199             : 
     200           7 :   nl_stride_ = 0;
     201           7 :   bool do_neigh = false;
     202           7 :   parseFlag("NLIST", do_neigh);
     203             : 
     204           7 :   if (do_neigh) {
     205          14 :     if (keywords.exists("NL_CUTOFF")) {
     206           7 :       parse("NL_CUTOFF", nl_cutoff_);
     207             : 
     208           7 :       plumed_massert(
     209             :         nl_cutoff_ > 0,
     210             :         "maze> NL_CUTOFF should be explicitly specified and positive.\n"
     211             :       );
     212             :     }
     213             : 
     214          14 :     if (keywords.exists("NL_STRIDE")) {
     215           7 :       parse("NL_STRIDE", nl_stride_);
     216             : 
     217           7 :       plumed_massert(
     218             :         nl_stride_ > 0,
     219             :         "maze> NL_STRIDE should be explicitly specified and positive.\n"
     220             :       );
     221             :     }
     222             :   }
     223             : 
     224           7 :   if (gb_list.size() > 0) {
     225           7 :     if (do_neigh) {
     226           7 :       neighbor_list_ = Tools::make_unique<NeighborList>(
     227             :                          ga_list,
     228             :                          gb_list,
     229             :                          serial_,
     230             :                          do_pair,
     231           7 :                          pbc_,
     232             :                          getPbc(),
     233             :                          comm,
     234           7 :                          nl_cutoff_,
     235           7 :                          nl_stride_
     236             :                        );
     237             :     }
     238             :     else {
     239           0 :       neighbor_list_=Tools::make_unique<NeighborList>(
     240             :                        ga_list,
     241             :                        gb_list,
     242             :                        serial_,
     243             :                        do_pair,
     244           0 :                        pbc_,
     245             :                        getPbc(),
     246             :                        comm
     247             :                      );
     248             :     }
     249             :   }
     250             :   else {
     251           0 :     if (do_neigh) {
     252           0 :       neighbor_list_ = Tools::make_unique<NeighborList>(
     253             :                          ga_list,
     254             :                          serial_,
     255           0 :                          pbc_,
     256             :                          getPbc(),
     257             :                          comm,
     258           0 :                          nl_cutoff_,
     259           0 :                          nl_stride_
     260             :                        );
     261             :     }
     262             :     else {
     263           0 :       neighbor_list_=Tools::make_unique<NeighborList>(
     264             :                        ga_list,
     265             :                        serial_,
     266           0 :                        pbc_,
     267             :                        getPbc(),
     268             :                        comm
     269             :                      );
     270             :     }
     271             :   }
     272             : 
     273           7 :   requestAtoms(neighbor_list_->getFullAtomList());
     274             : 
     275           7 :   log.printf(
     276             :     "maze> Loss will be calculated between two groups of %u and %u atoms.\n",
     277             :     static_cast<unsigned>(ga_list.size()),
     278             :     static_cast<unsigned>(gb_list.size())
     279             :   );
     280             : 
     281           7 :   log.printf(
     282             :     "maze> First group (LIGAND): from %d to %d.\n",
     283             :     ga_list[0].serial(),
     284             :     ga_list[ga_list.size()-1].serial()
     285             :   );
     286             : 
     287           7 :   if (gb_list.size() > 0) {
     288           7 :     log.printf(
     289             :       "maze> Second group (PROTEIN): from %d to %d.\n",
     290             :       gb_list[0].serial(),
     291             :       gb_list[gb_list.size()-1].serial()
     292             :     );
     293             :   }
     294             : 
     295           7 :   if (pbc_) {
     296           7 :     log.printf("maze> Using periodic boundary conditions.\n");
     297             :   }
     298             :   else {
     299           0 :     log.printf("maze> Without periodic boundary conditions.\n");
     300             :   }
     301             : 
     302           7 :   if (do_pair) {
     303           0 :     log.printf("maze> With PAIR option.\n");
     304             :   }
     305             : 
     306           7 :   if (do_neigh) {
     307           7 :     log.printf(
     308             :       "maze> Using neighbor lists updated every %d steps and cutoff %f.\n",
     309             :       nl_stride_,
     310             :       nl_cutoff_
     311             :     );
     312             :   }
     313             : 
     314             :   // OpenMP
     315           7 :   stride_ = comm.Get_size();
     316           7 :   rank_ = comm.Get_rank();
     317             : 
     318           7 :   n_threads_ = OpenMP::getNumThreads();
     319           7 :   unsigned int nn = neighbor_list_->size();
     320             : 
     321           7 :   if (n_threads_ * stride_ * 10 > nn) {
     322           0 :     n_threads_ = nn / stride_ / 10;
     323             :   }
     324             : 
     325           7 :   if (n_threads_ == 0) {
     326           0 :     n_threads_ = 1;
     327             :   }
     328             : 
     329          14 :   if (keywords.exists("OPTIMIZER_STRIDE")) {
     330           7 :     parse("OPTIMIZER_STRIDE", optimizer_stride_);
     331             : 
     332           7 :     plumed_massert(
     333             :       optimizer_stride_,
     334             :       "maze> OPTIMIZER_STRIDE should be explicitly specified and positive.\n"
     335             :     );
     336             : 
     337           7 :     log.printf(
     338             :       "maze> Launching optimization every %u steps.\n",
     339             :       optimizer_stride_
     340             :     );
     341             :   }
     342             : 
     343           7 :   rnd::randomize();
     344             : 
     345           7 :   opt_.zero();
     346             : 
     347          14 :   addComponentWithDerivatives("x");
     348          14 :   componentIsNotPeriodic("x");
     349             : 
     350          14 :   addComponentWithDerivatives("y");
     351          14 :   componentIsNotPeriodic("y");
     352             : 
     353          14 :   addComponentWithDerivatives("z");
     354          14 :   componentIsNotPeriodic("z");
     355             : 
     356          14 :   addComponent("loss");
     357          14 :   componentIsNotPeriodic("loss");
     358             : 
     359          14 :   addComponent("sr");
     360           7 :   componentIsNotPeriodic("sr");
     361             : 
     362           7 :   value_x_ = getPntrToComponent("x");
     363           7 :   value_y_ = getPntrToComponent("y");
     364           7 :   value_z_ = getPntrToComponent("z");
     365           7 :   value_action_ = getPntrToComponent("loss");
     366           7 :   value_sampling_radius_ = getPntrToComponent("sr");
     367           7 : }
     368             : 
     369    16239210 : double Optimizer::pairing(double distance) const {
     370    16239210 :   return loss_->pairing(distance);
     371             : }
     372             : 
     373           6 : Vector Optimizer::center_of_mass() const {
     374           6 :   const unsigned nl_size = neighbor_list_->size();
     375             : 
     376           6 :   Vector center_of_mass;
     377           6 :   center_of_mass.zero();
     378             :   double mass = 0;
     379             : 
     380      189654 :   for (unsigned int i = 0; i < nl_size; ++i) {
     381      189648 :     unsigned int i0 = neighbor_list_->getClosePair(i).first;
     382      189648 :     center_of_mass += getPosition(i0) * getMass(i0);
     383      189648 :     mass += getMass(i0);
     384             :   }
     385             : 
     386           6 :   return center_of_mass / mass;
     387             : }
     388             : 
     389         210 : void Optimizer::prepare() {
     390         210 :   if (neighbor_list_->getStride() > 0) {
     391         210 :     if (first_time_ || (getStep() % neighbor_list_->getStride() == 0)) {
     392           7 :       requestAtoms(neighbor_list_->getFullAtomList());
     393             : 
     394           7 :       validate_list_ = true;
     395           7 :       first_time_ = false;
     396             :     }
     397             :     else {
     398         203 :       requestAtoms(neighbor_list_->getReducedAtomList());
     399             : 
     400         203 :       validate_list_ = false;
     401             : 
     402         203 :       if (getExchangeStep()) {
     403           0 :         plumed_merror(
     404             :           "maze> Neighbor lists should be updated on exchange steps -- choose "
     405             :           "an NL_STRIDE which divides the exchange stride.\n");
     406             :       }
     407             :     }
     408             : 
     409         210 :     if (getExchangeStep()) {
     410           0 :       first_time_ = true;
     411             :     }
     412             :   }
     413         210 : }
     414             : 
     415         226 : double Optimizer::score() {
     416         226 :   const unsigned nl_size = neighbor_list_->size();
     417         226 :   Vector distance;
     418             :   double function = 0;
     419             : 
     420         226 :   #pragma omp parallel num_threads(n_threads_)
     421             :   {
     422             :     #pragma omp for reduction(+:function)
     423             :     for(unsigned int i = 0; i < nl_size; i++) {
     424             :       unsigned i0 = neighbor_list_->getClosePair(i).first;
     425             :       unsigned i1 = neighbor_list_->getClosePair(i).second;
     426             : 
     427             :       if (getAbsoluteIndex(i0) == getAbsoluteIndex(i1)) {
     428             :         continue;
     429             :       }
     430             : 
     431             :       if (pbc_) {
     432             :         distance = pbcDistance(getPosition(i0), getPosition(i1));
     433             :       }
     434             :       else {
     435             :         distance = delta(getPosition(i0), getPosition(i1));
     436             :       }
     437             : 
     438             :       function += pairing(distance.modulo());
     439             :     }
     440             :   }
     441             : 
     442         226 :   return function;
     443             : }
     444             : 
     445         210 : void Optimizer::update_nl() {
     446         210 :   if (neighbor_list_->getStride() > 0 && validate_list_) {
     447           7 :     neighbor_list_->update(getPositions());
     448             :   }
     449         210 : }
     450             : 
     451         367 : double Optimizer::sampling_radius()
     452             : {
     453         367 :   const unsigned nl_size=neighbor_list_->size();
     454         367 :   Vector d;
     455             :   double min=std::numeric_limits<int>::max();
     456             : 
     457     9812323 :   for (unsigned int i = 0; i < nl_size; ++i) {
     458     9811956 :     unsigned i0 = neighbor_list_->getClosePair(i).first;
     459     9811956 :     unsigned i1 = neighbor_list_->getClosePair(i).second;
     460             : 
     461     9811956 :     if (getAbsoluteIndex(i0) == getAbsoluteIndex(i1)) {
     462           0 :       continue;
     463             :     }
     464             : 
     465     9811956 :     if (pbc_) {
     466     9811956 :       d = pbcDistance(getPosition(i0), getPosition(i1));
     467             :     }
     468             :     else {
     469           0 :       d = delta(getPosition(i0), getPosition(i1));
     470             :     }
     471             : 
     472     9811956 :     double dist = d.modulo();
     473             : 
     474     9811956 :     if(dist < min) {
     475             :       min = dist;
     476             :     }
     477             :   }
     478             : 
     479         367 :   return min;
     480             : }
     481             : 
     482         210 : void Optimizer::calculate() {
     483         210 :   update_nl();
     484             : 
     485         210 :   if (getStep() % optimizer_stride_ == 0 && !first_step_) {
     486          19 :     optimize();
     487             : 
     488          19 :     value_x_->set(opt_[0]);
     489          19 :     value_y_->set(opt_[1]);
     490          19 :     value_z_->set(opt_[2]);
     491             : 
     492          19 :     value_action_->set(score());
     493          19 :     value_sampling_radius_->set(sampling_radius());
     494             :   }
     495             :   else {
     496         191 :     first_step_=false;
     497             : 
     498         191 :     value_x_->set(opt_[0]);
     499         191 :     value_y_->set(opt_[1]);
     500         191 :     value_z_->set(opt_[2]);
     501             : 
     502         191 :     value_action_->set(score());
     503         191 :     value_sampling_radius_->set(sampling_radius());
     504             :   }
     505         210 : }
     506             : 
     507             : } // namespace maze
     508             : } // namespace PLMD

Generated by: LCOV version 1.16