LCOV - code coverage report
Current view: top level - isdb - EMMIVox.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 35 739 4.7 %
Date: 2025-03-25 09:33:27 Functions: 1 26 3.8 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : 
      23             : #ifdef __PLUMED_HAS_LIBTORCH
      24             : #include "colvar/Colvar.h"
      25             : #include "core/ActionRegister.h"
      26             : #include "core/PlumedMain.h"
      27             : #include "tools/Communicator.h"
      28             : #include "tools/Matrix.h"
      29             : #include "core/GenericMolInfo.h"
      30             : #include "core/ActionSet.h"
      31             : #include "tools/File.h"
      32             : #include "tools/OpenMP.h"
      33             : #include <string>
      34             : #include <cmath>
      35             : #include <map>
      36             : #include <numeric>
      37             : #include <ctime>
      38             : #include "tools/Random.h"
      39             : 
      40             : #include <torch/torch.h>
      41             : #include <torch/script.h>
      42             : 
      43             : #ifndef M_PI
      44             : #define M_PI           3.14159265358979323846
      45             : #endif
      46             : 
      47             : namespace PLMD {
      48             : namespace isdb {
      49             : 
      50             : //+PLUMEDOC ISDB_COLVAR EMMIVOX
      51             : /*
      52             : Bayesian single-structure and ensemble refinement with cryo-EM maps.
      53             : 
      54             : This action implements the Bayesian approach for single-structure and ensemble refinement from cryo-EM maps introduced <a href="https://www.biorxiv.org/content/10.1101/2023.10.18.562710v1">here</a>.
      55             : EMMIVox does not require fitting the cryo-EM map with a Gaussian Mixture Model, as done in \ref EMMI, but uses directly the voxels in the deposited map.
      56             : 
      57             : When run in single-replica mode, this action allows atomistic, flexible refinement (and B-factors inference) of an individual structure into a density map.
      58             : A coarse-grained forward model can also be used in combination with the MARTINI force field.
      59             : Combined with a multi-replica framework (such as the -multi option in GROMACS), the user can model an ensemble of structures using
      60             : the Metainference approach \cite Bonomi:2016ip . The approach can be used to model continous dynamics of flexible regions as well as semi-ordered waters, lipids, and ions.
      61             : 
      62             : \warning
      63             :     To use EMMIVOX, PLUMED must be linked against the LibTorch library as described \ref ISDB "here"
      64             : 
      65             : \par Examples
      66             : 
      67             : Complete tutorials for single-structure and ensemble refinement can be found <a href="https://github.com/COSBlab/EMMIVox">here</a>.
      68             : 
      69             : */
      70             : //+ENDPLUMEDOC
      71             : 
      72             : class EMMIVOX : public Colvar {
      73             : 
      74             : private:
      75             : 
      76             : // temperature in kbt
      77             :   double kbt_;
      78             : // model - atom types
      79             :   std::vector<unsigned> Model_type_;
      80             : // model - list of atom sigmas - one per atom type
      81             :   std::vector<Vector5d> Model_s_;
      82             : // model - list of atom weights - one per atom type
      83             :   std::vector<Vector5d> Model_w_;
      84             : // model - map between residue/chain IDs and list of atoms
      85             :   std::map< std::pair<unsigned,std::string>, std::vector<unsigned> > Model_resmap_;
      86             : // model - list of residue/chain IDs per atom
      87             :   std::vector< std::pair<unsigned,std::string> > Model_res_;
      88             : // model - list of neighboring voxels per atom
      89             :   std::vector< std::vector<unsigned> > Model_nb_;
      90             : // model - map between residue/chain ID and bfactor
      91             :   std::map< std::pair<unsigned, std::string>, double> Model_b_;
      92             : // model - global list of residue/chain IDs
      93             :   std::vector< std::pair<unsigned,std::string> > Model_rlist_;
      94             : // model density
      95             :   std::vector<double> ovmd_;
      96             : 
      97             : // data map - voxel position
      98             :   std::vector<Vector> Map_m_;
      99             : // data map - density
     100             :   std::vector<double> ovdd_;
     101             : // data map - error
     102             :   std::vector<double> exp_err_;
     103             : 
     104             : // derivatives
     105             :   std::vector<Vector> ovmd_der_;
     106             :   std::vector<Vector> atom_der_;
     107             :   std::vector<double> score_der_;
     108             : // constants
     109             :   double inv_sqrt2_, sqrt2_pi_, inv_pi2_;
     110             :   std::vector<Vector5d> pref_;
     111             :   std::vector<Vector5d> invs2_;
     112             :   std::vector<Vector5d> cfact_;
     113             :   std::vector<double> cut_;
     114             : // metainference
     115             :   unsigned nrep_;
     116             :   unsigned replica_;
     117             :   std::vector<double> ismin_;
     118             : // neighbor list
     119             :   double nl_dist_cutoff_;
     120             :   double nl_gauss_cutoff_;
     121             :   unsigned nl_stride_;
     122             :   bool first_time_;
     123             :   std::vector< std::pair<unsigned,unsigned> > nl_;
     124             :   std::vector< std::pair<unsigned,unsigned> > ns_;
     125             :   std::vector<Vector> refpos_;
     126             : // averaging
     127             :   bool no_aver_;
     128             : // correlation;
     129             :   bool do_corr_;
     130             : // Monte Carlo stuff
     131             :   Random   random_;
     132             :   // Scale and Offset
     133             :   double scale_;
     134             :   double offset_;
     135             : // Bfact Monte Carlo
     136             :   double   dbfact_;
     137             :   double   bfactmin_;
     138             :   double   bfactmax_;
     139             :   double   bfactsig_;
     140             :   bool     bfactnoc_;
     141             :   bool     bfactread_;
     142             :   int      MCBstride_;
     143             :   double   MCBaccept_;
     144             :   double   MCBtrials_;
     145             : // residue neighbor list
     146             :   std::vector< std::vector<unsigned> > nl_res_;
     147             :   bool bfactemin_;
     148             : // Martini scattering factors
     149             :   bool martini_;
     150             :   // status stuff
     151             :   unsigned int statusstride_;
     152             :   std::string       statusfilename_;
     153             :   OFile        statusfile_;
     154             :   bool         first_status_;
     155             :   // total energy and virial
     156             :   double ene_;
     157             :   Tensor virial_;
     158             :   double eps_;
     159             :   // model density file
     160             :   unsigned int mapstride_;
     161             :   std::string       mapfilename_;
     162             :   // Libtorch stuff
     163             :   bool gpu_;
     164             :   torch::Tensor ovmd_gpu_;
     165             :   torch::Tensor ovmd_der_gpu_;
     166             :   torch::Tensor ismin_gpu_;
     167             :   torch::Tensor ovdd_gpu_;
     168             :   torch::Tensor Map_m_gpu_;
     169             :   torch::Tensor pref_gpu_;
     170             :   torch::Tensor invs2_gpu_;
     171             :   torch::Tensor nl_id_gpu_;
     172             :   torch::Tensor nl_im_gpu_;
     173             :   torch::Tensor pref_nl_gpu_;
     174             :   torch::Tensor invs2_nl_gpu_;
     175             :   torch::Tensor Map_m_nl_gpu_;
     176             :   torch::DeviceType device_t_;
     177             : //
     178             : // write file with model density
     179             :   void write_model_density(long int step);
     180             : // get median of vector
     181             :   double get_median(std::vector<double> v);
     182             : // read and write status
     183             :   void read_status();
     184             :   void print_status(long int step);
     185             : // accept or reject
     186             :   bool doAccept(double oldE, double newE, double kbt);
     187             : // vector of close residues
     188             :   void get_close_residues();
     189             : // do MonteCarlo for Bfactor
     190             :   void doMonteCarloBfact();
     191             : // calculate model parameters
     192             :   std::vector<double> get_Model_param(std::vector<AtomNumber> &atoms);
     193             : // read data file
     194             :   void get_exp_data(const std::string &datafile);
     195             : // auxiliary methods
     196             :   void prepare_gpu();
     197             :   void initialize_Bfactor(double reso);
     198             :   void get_auxiliary_vectors();
     199             :   void push_auxiliary_gpu();
     200             : // calculate overlap between two Gaussians
     201             :   double get_overlap(const Vector &d_m, const Vector &m_m,
     202             :                      const Vector5d &cfact, const Vector5d &m_s, double bfact);
     203             : // update the neighbor list
     204             :   void update_neighbor_list();
     205             : // update data on device
     206             :   void update_gpu();
     207             : // update the neighbor sphere
     208             :   void update_neighbor_sphere();
     209             :   bool do_neighbor_sphere();
     210             : // calculate forward model and score on device
     211             :   void calculate_fmod();
     212             :   void calculate_score();
     213             : // calculate correlation
     214             :   void calculate_corr();
     215             : 
     216             : public:
     217             :   static void registerKeywords( Keywords& keys );
     218             :   explicit EMMIVOX(const ActionOptions&);
     219             : // active methods:
     220             :   void prepare() override;
     221             :   void calculate() override;
     222             : };
     223             : 
     224             : PLUMED_REGISTER_ACTION(EMMIVOX,"EMMIVOX")
     225             : 
     226           2 : void EMMIVOX::registerKeywords( Keywords& keys ) {
     227           2 :   Colvar::registerKeywords( keys );
     228           2 :   keys.add("atoms","ATOMS","atoms used in the calculation of the density map, typically all heavy atoms");
     229           2 :   keys.add("compulsory","DATA_FILE","file with cryo-EM map");
     230           2 :   keys.add("compulsory","RESOLUTION", "cryo-EM map resolution");
     231           2 :   keys.add("compulsory","NORM_DENSITY","integral of experimental density");
     232           2 :   keys.add("compulsory","WRITE_STRIDE","stride for writing status file");
     233           2 :   keys.add("optional","NL_DIST_CUTOFF","neighbor list distance cutoff");
     234           2 :   keys.add("optional","NL_GAUSS_CUTOFF","neighbor list Gaussian sigma cutoff");
     235           2 :   keys.add("optional","NL_STRIDE","neighbor list update frequency");
     236           2 :   keys.add("optional","SIGMA_MIN","minimum density error");
     237           2 :   keys.add("optional","DBFACT","Bfactor MC step");
     238           2 :   keys.add("optional","BFACT_MAX","Bfactor maximum value");
     239           2 :   keys.add("optional","MCBFACT_STRIDE", "Bfactor MC stride");
     240           2 :   keys.add("optional","BFACT_SIGMA","Bfactor sigma prior");
     241           2 :   keys.add("optional","STATUS_FILE","write a file with all the data useful for restart");
     242           2 :   keys.add("optional","SCALE","scale factor");
     243           2 :   keys.add("optional","OFFSET","offset");
     244           2 :   keys.add("optional","TEMP","temperature");
     245           2 :   keys.add("optional","WRITE_MAP","file with model density");
     246           2 :   keys.add("optional","WRITE_MAP_STRIDE","stride for writing model density to file");
     247           2 :   keys.addFlag("NO_AVER",false,"no ensemble averaging in multi-replica mode");
     248           2 :   keys.addFlag("CORRELATION",false,"calculate correlation coefficient");
     249           2 :   keys.addFlag("GPU",false,"calculate EMMIVOX on GPU with Libtorch");
     250           2 :   keys.addFlag("BFACT_NOCHAIN",false,"Do not use chain ID for Bfactor MC");
     251           2 :   keys.addFlag("BFACT_READ",false,"Read Bfactor on RESTART (automatic with DBFACT>0)");
     252           2 :   keys.addFlag("BFACT_MINIMIZE",false,"Accept only moves that decrease energy");
     253           2 :   keys.addFlag("MARTINI",false,"Use Martini scattering factors");
     254           4 :   keys.addOutputComponent("scoreb","default","scalar","Bayesian score");
     255           4 :   keys.addOutputComponent("scale", "default","scalar","scale factor");
     256           4 :   keys.addOutputComponent("offset","default","scalar","offset");
     257           4 :   keys.addOutputComponent("accB",  "default","scalar", "Bfactor MC acceptance");
     258           4 :   keys.addOutputComponent("kbt",   "default","scalar", "temperature in energy unit");
     259           4 :   keys.addOutputComponent("corr",  "CORRELATION","scalar", "correlation coefficient");
     260           2 : }
     261             : 
     262           0 : EMMIVOX::EMMIVOX(const ActionOptions&ao):
     263             :   PLUMED_COLVAR_INIT(ao),
     264           0 :   nl_dist_cutoff_(1.0), nl_gauss_cutoff_(3.0), nl_stride_(50),
     265           0 :   first_time_(true), no_aver_(false), do_corr_(false),
     266           0 :   scale_(1.), offset_(0.),
     267           0 :   dbfact_(0.0), bfactmin_(0.05), bfactmax_(5.0),
     268           0 :   bfactsig_(0.1), bfactnoc_(false), bfactread_(false),
     269           0 :   MCBstride_(1), MCBaccept_(0.), MCBtrials_(0.), bfactemin_(false),
     270           0 :   martini_(false), statusstride_(0), first_status_(true),
     271           0 :   eps_(0.0001), mapstride_(0), gpu_(false) {
     272             :   // set constants
     273           0 :   inv_sqrt2_ = 1.0/sqrt(2.0);
     274           0 :   sqrt2_pi_  = sqrt(2.0 / M_PI);
     275           0 :   inv_pi2_   = 0.5 / M_PI / M_PI;
     276             : 
     277             :   // list of atoms
     278             :   std::vector<AtomNumber> atoms;
     279           0 :   parseAtomList("ATOMS", atoms);
     280             : 
     281             :   // file with experimental cryo-EM map
     282             :   std::string datafile;
     283           0 :   parse("DATA_FILE", datafile);
     284             : 
     285             :   // neighbor list cutoffs
     286           0 :   parse("NL_DIST_CUTOFF",nl_dist_cutoff_);
     287           0 :   parse("NL_GAUSS_CUTOFF",nl_gauss_cutoff_);
     288             :   // checks
     289           0 :   if(nl_dist_cutoff_<=0. && nl_gauss_cutoff_<=0.) {
     290           0 :     error("You must specify either NL_DIST_CUTOFF or NL_GAUSS_CUTOFF or both");
     291             :   }
     292           0 :   if(nl_gauss_cutoff_<=0.) {
     293           0 :     nl_gauss_cutoff_ = 1.0e+10;
     294             :   }
     295           0 :   if(nl_dist_cutoff_<=0.) {
     296           0 :     nl_dist_cutoff_ = 1.0e+10;
     297             :   }
     298             :   // neighbor list update stride
     299           0 :   parse("NL_STRIDE",nl_stride_);
     300           0 :   if(nl_stride_<=0) {
     301           0 :     error("NL_STRIDE must be explicitly specified and positive");
     302             :   }
     303             : 
     304             :   // minimum value for error
     305           0 :   double sigma_min = 0.2;
     306           0 :   parse("SIGMA_MIN", sigma_min);
     307           0 :   if(sigma_min<0.) {
     308           0 :     error("SIGMA_MIN must be greater or equal to zero");
     309             :   }
     310             : 
     311             :   // status file parameters
     312           0 :   parse("WRITE_STRIDE", statusstride_);
     313           0 :   if(statusstride_<=0) {
     314           0 :     error("you must specify a positive WRITE_STRIDE");
     315             :   }
     316           0 :   parse("STATUS_FILE",  statusfilename_);
     317           0 :   if(statusfilename_=="") {
     318           0 :     statusfilename_ = "EMMIStatus"+getLabel();
     319             :   }
     320             : 
     321             :   // integral of the experimetal density
     322             :   double norm_d;
     323           0 :   parse("NORM_DENSITY", norm_d);
     324             : 
     325             :   // temperature
     326           0 :   kbt_ = getkBT();
     327             : 
     328             :   // scale and offset
     329           0 :   parse("SCALE", scale_);
     330           0 :   parse("OFFSET",offset_);
     331             : 
     332             :   // B-factors MC
     333           0 :   parse("DBFACT",dbfact_);
     334             :   // read Bfactors
     335           0 :   parseFlag("BFACT_READ",bfactread_);
     336             :   // do not use chains
     337           0 :   parseFlag("BFACT_NOCHAIN",bfactnoc_);
     338             :   // other parameters
     339           0 :   if(dbfact_>0.) {
     340           0 :     parse("MCBFACT_STRIDE",MCBstride_);
     341           0 :     parse("BFACT_MAX",bfactmax_);
     342           0 :     parse("BFACT_SIGMA",bfactsig_);
     343           0 :     parseFlag("BFACT_MINIMIZE",bfactemin_);
     344             :     // checks
     345           0 :     if(MCBstride_<=0) {
     346           0 :       error("you must specify a positive MCBFACT_STRIDE");
     347             :     }
     348           0 :     if(bfactmax_<=bfactmin_) {
     349           0 :       error("you must specify a positive BFACT_MAX");
     350             :     }
     351           0 :     if(MCBstride_%nl_stride_!=0) {
     352           0 :       error("MCBFACT_STRIDE must be multiple of NL_STRIDE");
     353             :     }
     354           0 :     if(bfactsig_<=0.) {
     355           0 :       error("you must specify a positive BFACT_SIGMA");
     356             :     }
     357             :   }
     358             : 
     359             :   // read map resolution
     360             :   double reso;
     361           0 :   parse("RESOLUTION", reso);
     362           0 :   if(reso<=0.) {
     363           0 :     error("RESOLUTION must be strictly positive");
     364             :   }
     365             : 
     366             :   // averaging or not
     367           0 :   parseFlag("NO_AVER",no_aver_);
     368             : 
     369             :   // calculate correlation coefficient
     370           0 :   parseFlag("CORRELATION",do_corr_);
     371             : 
     372             :   // write density file
     373           0 :   parse("WRITE_MAP_STRIDE", mapstride_);
     374           0 :   parse("WRITE_MAP", mapfilename_);
     375           0 :   if(mapstride_>0 && mapfilename_=="") {
     376           0 :     error("With WRITE_MAP_STRIDE you must specify WRITE_MAP");
     377             :   }
     378             : 
     379             :   // use GPU?
     380           0 :   parseFlag("GPU",gpu_);
     381             :   // set device
     382           0 :   if (gpu_ && torch::cuda::is_available()) {
     383           0 :     device_t_ = torch::kCUDA;
     384             :   } else {
     385           0 :     device_t_ = torch::kCPU;
     386           0 :     gpu_ = false;
     387             :   }
     388             : 
     389             : // Martini model
     390           0 :   parseFlag("MARTINI",martini_);
     391             : 
     392             :   // check read
     393           0 :   checkRead();
     394             : 
     395             :   // set parallel stuff
     396           0 :   unsigned mpisize=comm.Get_size();
     397           0 :   if(mpisize>1) {
     398           0 :     error("EMMIVOX supports only OpenMP parallelization");
     399             :   }
     400             : 
     401             :   // get number of replicas
     402           0 :   if(no_aver_) {
     403           0 :     nrep_ = 1;
     404             :   } else {
     405           0 :     nrep_ = multi_sim_comm.Get_size();
     406             :   }
     407           0 :   replica_ = multi_sim_comm.Get_rank();
     408             : 
     409           0 :   if(nrep_>1 && dbfact_>0) {
     410           0 :     error("Bfactor sampling not supported with ensemble averaging");
     411             :   }
     412             : 
     413           0 :   log.printf("  number of atoms involved : %u\n", atoms.size());
     414           0 :   log.printf("  experimental density map : %s\n", datafile.c_str());
     415           0 :   if(no_aver_) {
     416           0 :     log.printf("  without ensemble averaging\n");
     417             :   }
     418           0 :   if(gpu_) {
     419           0 :     log.printf("  running on GPU \n");
     420             :   } else {
     421           0 :     log.printf("  running on CPU \n");
     422             :   }
     423           0 :   if(nl_dist_cutoff_ <1.0e+10) {
     424           0 :     log.printf("  neighbor list distance cutoff : %lf\n", nl_dist_cutoff_);
     425             :   }
     426           0 :   if(nl_gauss_cutoff_<1.0e+10) {
     427           0 :     log.printf("  neighbor list Gaussian sigma cutoff : %lf\n", nl_gauss_cutoff_);
     428             :   }
     429           0 :   log.printf("  neighbor list update stride : %u\n",  nl_stride_);
     430           0 :   log.printf("  minimum density error : %f\n", sigma_min);
     431           0 :   log.printf("  scale factor : %lf\n", scale_);
     432           0 :   log.printf("  offset : %lf\n", offset_);
     433           0 :   log.printf("  reading/writing to status file : %s\n", statusfilename_.c_str());
     434           0 :   log.printf("  with stride : %u\n", statusstride_);
     435           0 :   if(dbfact_>0) {
     436           0 :     log.printf("  maximum Bfactor MC move : %f\n", dbfact_);
     437           0 :     log.printf("  stride MC move : %u\n", MCBstride_);
     438           0 :     log.printf("  using prior with sigma : %f\n", bfactsig_);
     439             :   }
     440           0 :   if(bfactread_) {
     441           0 :     log.printf("  reading Bfactors from file : %s\n", statusfilename_.c_str());
     442             :   }
     443           0 :   log.printf("  temperature of the system in energy unit : %f\n", kbt_);
     444           0 :   if(nrep_>1) {
     445           0 :     log.printf("  number of replicas for averaging: %u\n", nrep_);
     446           0 :     log.printf("  replica ID : %u\n", replica_);
     447             :   }
     448           0 :   if(mapstride_>0) {
     449           0 :     log.printf("  writing model density to file : %s\n", mapfilename_.c_str());
     450           0 :     log.printf("  with stride : %u\n", mapstride_);
     451             :   }
     452           0 :   if(martini_) {
     453           0 :     log.printf("  using Martini scattering factors\n");
     454             :   }
     455             : 
     456             :   // calculate model constant parameters
     457           0 :   std::vector<double> Model_w = get_Model_param(atoms);
     458             : 
     459             :   // read experimental map and errors
     460           0 :   get_exp_data(datafile);
     461           0 :   log.printf("  number of voxels : %u\n", static_cast<unsigned>(ovdd_.size()));
     462             : 
     463             :   // normalize atom weight map
     464             :   double norm_m = accumulate(Model_w.begin(),  Model_w.end(),  0.0);
     465             :   // renormalization and constant factor on atom types
     466           0 :   for(unsigned i=0; i<Model_w_.size(); ++i) {
     467           0 :     Vector5d cf;
     468           0 :     for(unsigned j=0; j<5; ++j) {
     469           0 :       Model_w_[i][j] *= norm_d / norm_m;
     470           0 :       cf[j] = Model_w_[i][j]/pow( 2.0*pi, 1.5 );
     471             :     }
     472           0 :     cfact_.push_back(cf);
     473             :   }
     474             : 
     475             :   // median density
     476           0 :   double ovdd_m = get_median(ovdd_);
     477             :   // median experimental error
     478           0 :   double err_m  = get_median(exp_err_);
     479             :   // minimum error
     480           0 :   double minerr = sigma_min*ovdd_m;
     481             :   // print out statistics
     482           0 :   log.printf("     median density : %lf\n", ovdd_m);
     483           0 :   log.printf("     minimum error  : %lf\n", minerr);
     484           0 :   log.printf("     median error   : %lf\n", err_m);
     485             :   // populate ismin: cycle on all voxels
     486           0 :   for(unsigned id=0; id<ovdd_.size(); ++id) {
     487             :     // define smin
     488           0 :     double smin = std::max(minerr, exp_err_[id]);
     489             :     // and to ismin_
     490           0 :     ismin_.push_back(1.0/smin);
     491             :   }
     492             : 
     493             :   // prepare gpu stuff: map centers, data, and error
     494           0 :   prepare_gpu();
     495             : 
     496             :   // initialize Bfactors
     497           0 :   initialize_Bfactor(reso);
     498             : 
     499             :   // read status file if restarting
     500           0 :   if(getRestart() || bfactread_) {
     501           0 :     read_status();
     502             :   }
     503             : 
     504             :   // prepare auxiliary vectors
     505           0 :   get_auxiliary_vectors();
     506             : 
     507             :   // prepare other vectors: data and derivatives
     508           0 :   ovmd_.resize(ovdd_.size());
     509           0 :   atom_der_.resize(Model_type_.size());
     510           0 :   score_der_.resize(ovdd_.size());
     511             : 
     512             :   // add components
     513           0 :   addComponentWithDerivatives("scoreb");
     514           0 :   componentIsNotPeriodic("scoreb");
     515           0 :   addComponent("scale");
     516           0 :   componentIsNotPeriodic("scale");
     517           0 :   addComponent("offset");
     518           0 :   componentIsNotPeriodic("offset");
     519           0 :   addComponent("kbt");
     520           0 :   componentIsNotPeriodic("kbt");
     521           0 :   if(dbfact_>0)   {
     522           0 :     addComponent("accB");
     523           0 :     componentIsNotPeriodic("accB");
     524             :   }
     525           0 :   if(do_corr_)    {
     526           0 :     addComponent("corr");
     527           0 :     componentIsNotPeriodic("corr");
     528             :   }
     529             : 
     530             :   // initialize random seed
     531           0 :   unsigned iseed = time(NULL)+replica_;
     532           0 :   random_.setSeed(-iseed);
     533             : 
     534             :   // request atoms
     535           0 :   requestAtoms(atoms);
     536             : 
     537             :   // print bibliography
     538           0 :   log<<"  Bibliography "<<plumed.cite("Bonomi, Camilloni, Bioinformatics, 33, 3999 (2017)");
     539           0 :   log<<plumed.cite("Hoff, Thomasen, Lindorff-Larsen, Bonomi, bioRxiv (2023) doi: 10.1101/2023.10.18.562710");
     540           0 :   if(!no_aver_ && nrep_>1) {
     541           0 :     log<<plumed.cite("Bonomi, Camilloni, Cavalli, Vendruscolo, Sci. Adv. 2, e150117 (2016)");
     542             :   }
     543           0 :   log<<"\n";
     544           0 : }
     545             : 
     546           0 : void EMMIVOX::prepare_gpu() {
     547             :   // number of data points
     548           0 :   int nd = ovdd_.size();
     549             :   // 1) put ismin_ on device_t_
     550           0 :   ismin_gpu_ = torch::from_blob(ismin_.data(), {nd}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
     551             :   // 2) put ovdd_ on device_t_
     552           0 :   ovdd_gpu_  = torch::from_blob(ovdd_.data(),  {nd}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
     553             :   // 3) put Map_m_ on device_t_
     554           0 :   std::vector<double> Map_m_gpu(3*nd);
     555           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
     556             :   for(int i=0; i<nd; ++i) {
     557             :     Map_m_gpu[i]      = Map_m_[i][0];
     558             :     Map_m_gpu[i+nd]   = Map_m_[i][1];
     559             :     Map_m_gpu[i+2*nd] = Map_m_[i][2];
     560             :   }
     561             :   // libtorch tensor
     562           0 :   Map_m_gpu_ = torch::from_blob(Map_m_gpu.data(), {3,nd}, torch::kFloat64).clone().to(torch::kFloat32).to(device_t_);
     563           0 : }
     564             : 
     565           0 : void EMMIVOX::write_model_density(long int step) { // write one record per entry of ovmd_ to a file named mapfilename_ + "-" + step
     566           0 :   OFile ovfile;
     567           0 :   ovfile.link(*this);
     568             :   std::string num;
     569           0 :   Tools::convert(step,num);
     570           0 :   std::string name = mapfilename_+"-"+num;
     571           0 :   ovfile.open(name);
     572             :   ovfile.setHeavyFlush();
     573           0 :   ovfile.fmtField("%10.7e ");
     574             : // write density: model value, model value after applying scale_ and offset_, and experimental value
     575           0 :   for(unsigned i=0; i<ovmd_.size(); ++i) {
     576           0 :     ovfile.printField("Model", ovmd_[i]);
     577           0 :     ovfile.printField("ModelScaled", scale_ * ovmd_[i] + offset_);
     578           0 :     ovfile.printField("Data", ovdd_[i]); // NOTE(review): ovdd_ is indexed up to ovmd_.size() - assumes both vectors have the same length; confirm
     579           0 :     ovfile.printField(); // field-less call, used here after each triple of values - presumably ends the record; confirm against OFile
     580             :   }
     581           0 :   ovfile.close();
     582           0 : }
     583             : 
     584           0 : double EMMIVOX::get_median(std::vector<double> v) { // return the median of v; v is taken by value, so the sort below does not touch the caller's vector
     585             : // number of entries; NOTE(review): an empty v is not handled - with size==0 the even branch below evaluates size/2-1 on an unsigned and indexes out of range
     586           0 :   unsigned size = v.size();
     587             : // in case of only one entry, that entry is the median
     588           0 :   if (size==1) {
     589           0 :     return v[0];
     590             :   } else {
     591             :     // reorder the local copy
     592           0 :     sort(v.begin(), v.end());
     593             :     // odd or even? even size: mean of the two middle entries; odd size: the middle entry
     594           0 :     if (size%2==0) {
     595           0 :       return (v[size/2-1]+v[size/2])/2.0;
     596             :     } else {
     597           0 :       return v[size/2];
     598             :     }
     599             :   }
     600             : }
     601             : 
     602           0 : void EMMIVOX::read_status() { // load scale_, offset_ and (optionally) Model_b_ from statusfilename_; each record overwrites the previous one, so the last record in the file wins
     603             :   double MDtime; // only used as the scan target that detects whether another record is present
     604             : // open file
     605           0 :   IFile *ifile = new IFile(); // NOTE(review): raw owning pointer - if error() below does not return normally, the delete at the end is skipped; confirm
     606           0 :   ifile->link(*this);
     607           0 :   if(ifile->FileExist(statusfilename_)) {
     608           0 :     ifile->open(statusfilename_);
     609           0 :     while(ifile->scanField("MD_time", MDtime)) {
     610             :       // read scale and offset
     611           0 :       ifile->scanField("scale", scale_);
     612           0 :       ifile->scanField("offset", offset_);
     613             :       // read bfactors if doing fitting or reading them at restart
     614           0 :       if(dbfact_>0 || bfactread_) {
     615             :         // cycle on residues
     616           0 :         for(unsigned ir=0; ir<Model_rlist_.size(); ++ir) {
     617             :           // key: pair of residue/chain IDs
     618             :           std::pair<unsigned,std::string> key = Model_rlist_[ir];
     619             :           // convert ires to std::string
     620             :           std::string num;
     621           0 :           Tools::convert(key.first,num);
     622             :           // read entry; a chain ID made of a single blank is looked up as an empty string in the field name
     623           0 :           std::string ch = key.second;
     624           0 :           if(ch==" ") {
     625             :             ch="";
     626             :           }
     627           0 :           ifile->scanField("bf-"+num+":"+ch, Model_b_[key]);
     628             :         }
     629             :       }
     630             :       // new line
     631           0 :       ifile->scanField();
     632             :     }
     633           0 :     ifile->close();
     634             :   } else {
     635           0 :     error("Cannot find status file "+statusfilename_+"\n");
     636             :   }
     637           0 :   delete ifile;
     638           0 : }
     639             : 
     640           0 : void EMMIVOX::print_status(long int step) { // write one record (MD_time, scale, offset and optionally per-residue bfactors) to statusfile_
     641             : // if first time open the file and set its flush mode and number format
     642           0 :   if(first_status_) {
     643           0 :     first_status_ = false;
     644           0 :     statusfile_.link(*this);
     645           0 :     statusfile_.open(statusfilename_);
     646             :     statusfile_.setHeavyFlush();
     647           0 :     statusfile_.fmtField("%10.7e ");
     648             :   }
     649             : // write fields; MD_time is the step number multiplied by getTimeStep()
     650           0 :   double MDtime = static_cast<double>(step)*getTimeStep();
     651           0 :   statusfile_.printField("MD_time", MDtime);
     652             :   // write scale and offset
     653           0 :   statusfile_.printField("scale", scale_);
     654           0 :   statusfile_.printField("offset", offset_);
     655             :   // write bfactors only if doing fitting or reading bfactors
     656           0 :   if(dbfact_>0 || bfactread_) {
     657             :     // cycle on residues
     658           0 :     for(unsigned ir=0; ir<Model_rlist_.size(); ++ir) {
     659             :       // key: pair of residue/chain IDs
     660             :       std::pair<unsigned,std::string> key = Model_rlist_[ir];
     661             :       // bfactor from map
     662           0 :       double bf = Model_b_[key];
     663             :       // convert ires to std::string
     664             :       std::string num;
     665           0 :       Tools::convert(key.first,num);
     666             :       // print entry; NOTE(review): the chain ID is written verbatim here, while read_status() looks up a blank chain as "" - presumably the trailing blank is lost in the written field name; confirm
     667           0 :       statusfile_.printField("bf-"+num+":"+key.second, bf);
     668             :     }
     669             :   }
     670           0 :   statusfile_.printField();
     671           0 : }
     672             : 
     673           0 : bool EMMIVOX::doAccept(double oldE, double newE, double kbt) { // Metropolis-style test: returns true if the move from energy oldE to newE is accepted
     674             :   bool accept = false;
     675             :   // calculate delta energy in units of kbt; NOTE(review): kbt==0 is not guarded against in this function
     676           0 :   double delta = ( newE - oldE ) / kbt;
     677             :   // if delta is negative always accept move
     678           0 :   if( delta < 0.0 ) {
     679             :     accept = true;
     680             :   } else {
     681             :     // otherwise extract random number (presumably uniform in [0,1) - confirm against the type of random_)
     682           0 :     double s = random_.RandU01();
     683           0 :     if( s < exp(-delta) ) { // accept when the random number falls below exp(-delta)
     684             :       accept = true;
     685             :     }
     686             :   }
     687           0 :   return accept;
     688             : }
     689             : 
     690           0 : std::vector<double> EMMIVOX::get_Model_param(std::vector<AtomNumber> &atoms) { // fill the per-type tables and the per-atom/per-residue bookkeeping members, and return one summed weight per atom
     691             :   // check if MOLINFO line is present; error() is called when none is found
     692           0 :   auto* moldat=plumed.getActionSet().selectLatest<GenericMolInfo*>(this);
     693           0 :   if(!moldat) {
     694           0 :     error("MOLINFO DATA not found\n");
     695             :   }
     696           0 :   log<<"  MOLINFO DATA found with label " <<moldat->getLabel()<<", using proper atom names\n";
     697             : 
     698             :   // list of weights - one per atom (local return value, distinct from the member table Model_w_)
     699             :   std::vector<double> Model_w;
     700             :   // 5-Gaussians parameters
     701             :   // map of atom types to A and B coefficients of scattering factor
     702             :   // f(s) = A * exp(-B*s**2)
     703             :   // B is in Angstrom squared
     704             :   // Elastic atomic scattering factors of electrons for neutral atoms
     705             :   // and s up to 6.0 A^-1: as implemented in PLUMED
     706             :   // map between an atom type and an index; the index is later used to address Model_w_ and is stored in Model_type_
     707             :   std::map<std::string, unsigned> type_map;
     708             :   // atomistic types (indices 0-11)
     709           0 :   type_map["C"]=0;
     710           0 :   type_map["O"]=1;
     711           0 :   type_map["N"]=2;
     712           0 :   type_map["S"]=3;
     713           0 :   type_map["P"]=4;
     714           0 :   type_map["F"]=5;
     715           0 :   type_map["NA"]=6;
     716           0 :   type_map["MG"]=7;
     717           0 :   type_map["CL"]=8;
     718           0 :   type_map["CA"]=9;
     719           0 :   type_map["K"]=10;
     720           0 :   type_map["ZN"]=11;
     721             :   // Martini types (indices 12-63), keyed as <residue name>_<bead name>
     722           0 :   type_map["ALA_BB"]=12;
     723           0 :   type_map["ALA_SC1"]=13;
     724           0 :   type_map["CYS_BB"]=14;
     725           0 :   type_map["CYS_SC1"]=15;
     726           0 :   type_map["ASP_BB"]=16;
     727           0 :   type_map["ASP_SC1"]=17;
     728           0 :   type_map["GLU_BB"]=18;
     729           0 :   type_map["GLU_SC1"]=19;
     730           0 :   type_map["PHE_BB"]=20;
     731           0 :   type_map["PHE_SC1"]=21;
     732           0 :   type_map["PHE_SC2"]=22;
     733           0 :   type_map["PHE_SC3"]=23;
     734           0 :   type_map["GLY_BB"]=24;
     735           0 :   type_map["HIS_BB"]=25;
     736           0 :   type_map["HIS_SC1"]=26;
     737           0 :   type_map["HIS_SC2"]=27;
     738           0 :   type_map["HIS_SC3"]=28;
     739           0 :   type_map["ILE_BB"]=29;
     740           0 :   type_map["ILE_SC1"]=30;
     741           0 :   type_map["LYS_BB"]=31;
     742           0 :   type_map["LYS_SC1"]=32;
     743           0 :   type_map["LYS_SC2"]=33;
     744           0 :   type_map["LEU_BB"]=34;
     745           0 :   type_map["LEU_SC1"]=35;
     746           0 :   type_map["MET_BB"]=36;
     747           0 :   type_map["MET_SC1"]=37;
     748           0 :   type_map["ASN_BB"]=38;
     749           0 :   type_map["ASN_SC1"]=39;
     750           0 :   type_map["PRO_BB"]=40;
     751           0 :   type_map["PRO_SC1"]=41;
     752           0 :   type_map["GLN_BB"]=42;
     753           0 :   type_map["GLN_SC1"]=43;
     754           0 :   type_map["ARG_BB"]=44;
     755           0 :   type_map["ARG_SC1"]=45;
     756           0 :   type_map["ARG_SC2"]=46;
     757           0 :   type_map["SER_BB"]=47;
     758           0 :   type_map["SER_SC1"]=48;
     759           0 :   type_map["THR_BB"]=49;
     760           0 :   type_map["THR_SC1"]=50;
     761           0 :   type_map["VAL_BB"]=51;
     762           0 :   type_map["VAL_SC1"]=52;
     763           0 :   type_map["TRP_BB"]=53;
     764           0 :   type_map["TRP_SC1"]=54;
     765           0 :   type_map["TRP_SC2"]=55;
     766           0 :   type_map["TRP_SC3"]=56;
     767           0 :   type_map["TRP_SC4"]=57;
     768           0 :   type_map["TRP_SC5"]=58;
     769           0 :   type_map["TYR_BB"]=59;
     770           0 :   type_map["TYR_SC1"]=60;
     771           0 :   type_map["TYR_SC2"]=61;
     772           0 :   type_map["TYR_SC3"]=62;
     773           0 :   type_map["TYR_SC4"]=63;
     774             :   // fill in sigma vector for atoms, in type_map index order (0.01 factor: presumably Angstrom^2 -> nm^2, confirm); NOTE(review): push_back appends, so indices only line up if Model_s_ is empty on entry - confirm
     775           0 :   Model_s_.push_back(0.01*Vector5d(0.114,1.0825,5.4281,17.8811,51.1341));   // C
     776           0 :   Model_s_.push_back(0.01*Vector5d(0.0652,0.6184,2.9449,9.6298,28.2194));   // O
     777           0 :   Model_s_.push_back(0.01*Vector5d(0.0541,0.5165,2.8207,10.6297,34.3764));  // N
     778           0 :   Model_s_.push_back(0.01*Vector5d(0.0838,0.7788,4.3462,15.5846,44.63655)); // S
     779           0 :   Model_s_.push_back(0.01*Vector5d(0.0977,0.9084,4.9654,18.5471,54.3648));  // P
     780           0 :   Model_s_.push_back(0.01*Vector5d(0.0613,0.5753,2.6858,8.8214,25.6668));   // F
     781           0 :   Model_s_.push_back(0.01*Vector5d(0.1684,1.7150,8.8386,50.8265,147.2073)); // NA
     782           0 :   Model_s_.push_back(0.01*Vector5d(0.1356,1.3579,6.9255,32.3165,92.1138));  // MG
     783           0 :   Model_s_.push_back(0.01*Vector5d(0.0694,0.6443,3.5351,12.5058,35.8633));  // CL
     784           0 :   Model_s_.push_back(0.01*Vector5d(0.1742,1.8329,8.8407,47.4583,134.9613)); // CA
     785           0 :   Model_s_.push_back(0.01*Vector5d(0.1660,1.6906,8.7447,46.7825,165.6923)); // K
     786           0 :   Model_s_.push_back(0.01*Vector5d(0.0876,0.8650,3.8612,18.8726,64.7016));  // ZN
     787             :   // fill in sigma vector for Martini beads (the five components of each entry are identical)
     788           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // ALA_BB
     789           0 :   Model_s_.push_back(0.01*Vector5d(0.500000,0.500000,0.500000,0.500000,0.500000)); // ALA_SC1
     790           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // CYS_BB
     791           0 :   Model_s_.push_back(0.01*Vector5d(8.500000,8.500000,8.500000,8.500000,8.500000)); // CYS_SC1
     792           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // ASP_BB
     793           0 :   Model_s_.push_back(0.01*Vector5d(17.000000,17.000000,17.000000,17.000000,17.000000)); // ASP_SC1
     794           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // GLU_BB
     795           0 :   Model_s_.push_back(0.01*Vector5d(24.000000,24.000000,24.000000,24.000000,24.000000)); // GLU_SC1
     796           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // PHE_BB
     797           0 :   Model_s_.push_back(0.01*Vector5d(17.500000,17.500000,17.500000,17.500000,17.500000)); // PHE_SC1
     798           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // PHE_SC2
     799           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // PHE_SC3
     800           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // GLY_BB
     801           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // HIS_BB
     802           0 :   Model_s_.push_back(0.01*Vector5d(11.500000,11.500000,11.500000,11.500000,11.500000)); // HIS_SC1
     803           0 :   Model_s_.push_back(0.01*Vector5d(9.000000,9.000000,9.000000,9.000000,9.000000)); // HIS_SC2
     804           0 :   Model_s_.push_back(0.01*Vector5d(8.500000,8.500000,8.500000,8.500000,8.500000)); // HIS_SC3
     805           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // ILE_BB
     806           0 :   Model_s_.push_back(0.01*Vector5d(25.500000,25.500000,25.500000,25.500000,25.500000)); // ILE_SC1
     807           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // LYS_BB
     808           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // LYS_SC1
     809           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // LYS_SC2
     810           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // LEU_BB
     811           0 :   Model_s_.push_back(0.01*Vector5d(21.500000,21.500000,21.500000,21.500000,21.500000)); // LEU_SC1
     812           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // MET_BB
     813           0 :   Model_s_.push_back(0.01*Vector5d(22.500000,22.500000,22.500000,22.500000,22.500000)); // MET_SC1
     814           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // ASN_BB
     815           0 :   Model_s_.push_back(0.01*Vector5d(18.500000,18.500000,18.500000,18.500000,18.500000)); // ASN_SC1
     816           0 :   Model_s_.push_back(0.01*Vector5d(23.500000,23.500000,23.500000,23.500000,23.500000)); // PRO_BB
     817           0 :   Model_s_.push_back(0.01*Vector5d(17.500000,17.500000,17.500000,17.500000,17.500000)); // PRO_SC1
     818           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // GLN_BB
     819           0 :   Model_s_.push_back(0.01*Vector5d(24.500000,24.500000,24.500000,24.500000,24.500000)); // GLN_SC1
     820           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // ARG_BB
     821           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // ARG_SC1
     822           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // ARG_SC2
     823           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // SER_BB
     824           0 :   Model_s_.push_back(0.01*Vector5d(9.000000,9.000000,9.000000,9.000000,9.000000)); // SER_SC1
     825           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // THR_BB
     826           0 :   Model_s_.push_back(0.01*Vector5d(17.000000,17.000000,17.000000,17.000000,17.000000)); // THR_SC1
     827           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // VAL_BB
     828           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // VAL_SC1
     829           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // TRP_BB
     830           0 :   Model_s_.push_back(0.01*Vector5d(11.500000,11.500000,11.500000,11.500000,11.500000)); // TRP_SC1
     831           0 :   Model_s_.push_back(0.01*Vector5d(9.000000,9.000000,9.000000,9.000000,9.000000)); // TRP_SC2
     832           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TRP_SC3
     833           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TRP_SC4
     834           0 :   Model_s_.push_back(0.01*Vector5d(9.500000,9.500000,9.500000,9.500000,9.500000)); // TRP_SC5
     835           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // TYR_BB
     836           0 :   Model_s_.push_back(0.01*Vector5d(12.000000,12.000000,12.000000,12.000000,12.000000)); // TYR_SC1
     837           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TYR_SC2
     838           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TYR_SC3
     839           0 :   Model_s_.push_back(0.01*Vector5d(8.500000,8.500000,8.500000,8.500000,8.500000)); // TYR_SC4
     840             :   // fill in weight vector for atoms, in type_map index order; NOTE(review): same append-only caveat as for Model_s_ - assumes Model_w_ is empty on entry, confirm
     841           0 :   Model_w_.push_back(Vector5d(0.0489,0.2091,0.7537,1.1420,0.3555)); // C
     842           0 :   Model_w_.push_back(Vector5d(0.0365,0.1729,0.5805,0.8814,0.3121)); // O
     843           0 :   Model_w_.push_back(Vector5d(0.0267,0.1328,0.5301,1.1020,0.4215)); // N
     844           0 :   Model_w_.push_back(Vector5d(0.0915,0.4312,1.0847,2.4671,1.0852)); // S
     845           0 :   Model_w_.push_back(Vector5d(0.1005,0.4615,1.0663,2.5854,1.2725)); // P
     846           0 :   Model_w_.push_back(Vector5d(0.0382,0.1822,0.5972,0.7707,0.2130)); // F
     847           0 :   Model_w_.push_back(Vector5d(0.1260,0.6442,0.8893,1.8197,1.2988)); // NA
     848           0 :   Model_w_.push_back(Vector5d(0.1130,0.5575,0.9046,2.1580,1.4735)); // MG
     849           0 :   Model_w_.push_back(Vector5d(0.0799,0.3891,1.0037,2.3332,1.0507)); // CL
     850           0 :   Model_w_.push_back(Vector5d(0.2355,0.9916,2.3959,3.7252,2.5647)); // CA
     851           0 :   Model_w_.push_back(Vector5d(0.2149,0.8703,2.4999,2.3591,3.0318)); // K
     852           0 :   Model_w_.push_back(Vector5d(0.1780,0.8096,1.6744,1.9499,1.4495)); // ZN
     853             :   // fill in weight vector for Martini beads (the five components of each entry are identical)
     854           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ALA_BB
     855           0 :   Model_w_.push_back(Vector5d(0.100000,0.100000,0.100000,0.100000,0.100000)); // ALA_SC1
     856           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // CYS_BB
     857           0 :   Model_w_.push_back(Vector5d(1.100000,1.100000,1.100000,1.100000,1.100000)); // CYS_SC1
     858           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // ASP_BB
     859           0 :   Model_w_.push_back(Vector5d(1.700000,1.700000,1.700000,1.700000,1.700000)); // ASP_SC1
     860           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // GLU_BB
     861           0 :   Model_w_.push_back(Vector5d(2.300000,2.300000,2.300000,2.300000,2.300000)); // GLU_SC1
     862           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // PHE_BB
     863           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // PHE_SC1
     864           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // PHE_SC2
     865           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // PHE_SC3
     866           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // GLY_BB
     867           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // HIS_BB
     868           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // HIS_SC1
     869           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // HIS_SC2
     870           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // HIS_SC3
     871           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // ILE_BB
     872           0 :   Model_w_.push_back(Vector5d(2.000000,2.000000,2.000000,2.000000,2.000000)); // ILE_SC1
     873           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // LYS_BB
     874           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // LYS_SC1
     875           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // LYS_SC2
     876           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // LEU_BB
     877           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // LEU_SC1
     878           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // MET_BB
     879           0 :   Model_w_.push_back(Vector5d(2.300000,2.300000,2.300000,2.300000,2.300000)); // MET_SC1
     880           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ASN_BB
     881           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ASN_SC1
     882           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // PRO_BB
     883           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // PRO_SC1
     884           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // GLN_BB
     885           0 :   Model_w_.push_back(Vector5d(2.300000,2.300000,2.300000,2.300000,2.300000)); // GLN_SC1
     886           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // ARG_BB
     887           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // ARG_SC1
     888           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ARG_SC2
     889           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // SER_BB
     890           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // SER_SC1
     891           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // THR_BB
     892           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // THR_SC1
     893           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // VAL_BB
     894           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // VAL_SC1
     895           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // TRP_BB
     896           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TRP_SC1
     897           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // TRP_SC2
     898           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TRP_SC3
     899           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TRP_SC4
     900           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // TRP_SC5
     901           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // TYR_BB
     902           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TYR_SC1
     903           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TYR_SC2
     904           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TYR_SC3
     905           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // TYR_SC4
     906             :   // cycle on atoms: derive a type string for each one, then record its type index, weight and residue/chain key
     907           0 :   for(unsigned i=0; i<atoms.size(); ++i) {
     908             :     // get atom name
     909           0 :     std::string name = moldat->getAtomName(atoms[i]);
     910             :     // get residue name
     911           0 :     std::string resname = moldat->getResidueName(atoms[i]);
     912             :     // type of atoms/bead
     913             :     std::string type_s;
     914             :     // Martini model: the type is <residue name>_<bead name>
     915           0 :     if(martini_) {
     916           0 :       type_s = resname+"_"+name;
     917             :       // Atomistic model
     918             :     } else {
     919             :       char type;
     920             :       // get atom type; std::string::at throws std::out_of_range if name is empty
     921           0 :       char first = name.at(0);
     922             :       // GOLDEN RULE: type is first letter, if not a number
     923           0 :       if (!isdigit(first)) {
     924             :         type = first;
     925             :         // otherwise is the second (at(1) throws std::out_of_range for a one-character name)
     926             :       } else {
     927           0 :         type = name.at(1);
     928             :       }
     929             :       // convert to std::string
     930           0 :       type_s = std::string(1,type);
     931             :       // special cases: these atom names override the single-letter type chosen above
     932           0 :       if(name=="SOD" || name=="NA" || name =="Na") {
     933             :         type_s = "NA";
     934             :       }
     935           0 :       if(name=="MG"  || name=="Mg") {
     936             :         type_s = "MG";
     937             :       }
     938           0 :       if(name=="CLA" || name=="CL" || name =="Cl") {
     939             :         type_s = "CL";
     940             :       }
     941           0 :       if((resname=="CAL" || resname=="CA") && (name=="CAL" || name=="CA" || name =="C0")) { // the residue name is checked too, presumably to avoid matching the alpha-carbon atom name CA - confirm
     942             :         type_s = "CA";
     943             :       }
     944           0 :       if(name=="POT" || name=="K") {
     945             :         type_s = "K";
     946             :       }
     947           0 :       if(name=="ZN"  || name=="Zn") {
     948             :         type_s = "ZN";
     949             :       }
     950             :     }
     951             :     // check if key in map; unknown types are reported through error() in the else branch
     952           0 :     if(type_map.find(type_s) != type_map.end()) {
     953             :       // save atom type
     954           0 :       Model_type_.push_back(type_map[type_s]);
     955             :       // this will be normalized in the final density; the per-atom weight is the sum of the 5 Gaussian weights
     956           0 :       Vector5d w = Model_w_[type_map[type_s]];
     957           0 :       Model_w.push_back(w[0]+w[1]+w[2]+w[3]+w[4]);
     958             :       // get residue id
     959           0 :       unsigned ires = moldat->getResidueNumber(atoms[i]);
     960             :       // and chain; the placeholder "*" is kept for every atom when bfactnoc_ is set
     961           0 :       std::string c ("*");
     962           0 :       if(!bfactnoc_) {
     963           0 :         c = moldat->getChainID(atoms[i]);
     964             :       }
     965             :       // define pair residue/chain IDs
     966             :       std::pair<unsigned,std::string> key = std::make_pair(ires,c);
     967             :       // add to map between residue/chain and list of atoms (stores the position i within atoms)
     968           0 :       Model_resmap_[key].push_back(i);
     969             :       // and global list of residue/chain per atom
     970           0 :       Model_res_.push_back(key);
     971             :       // initialize Bfactor map
     972           0 :       Model_b_[key] = 0.0;
     973             :     } else {
     974           0 :       error("Wrong atom type "+type_s+" from atom name "+name+"\n");
     975             :     }
     976             :   }
     977             :   // create ordered vector of residue-chain IDs: each distinct key once, in order of first appearance in Model_res_
     978           0 :   for(unsigned i=0; i<Model_res_.size(); ++i) {
     979             :     std::pair<unsigned,std::string> key = Model_res_[i];
     980             :     // search in Model_rlist_ (one find() call per atom)
     981           0 :     if(find(Model_rlist_.begin(), Model_rlist_.end(), key) == Model_rlist_.end()) {
     982           0 :       Model_rlist_.push_back(key);
     983             :     }
     984             :   }
     985             :   // return weights (the local per-atom list, not the member Model_w_)
     986           0 :   return Model_w;
     987             : }
     988             : 
     989             : // read experimental data file in PLUMED format: one record per voxel with fields Id, Pos_0, Pos_1, Pos_2, Density and Error
     990           0 : void EMMIVOX::get_exp_data(const std::string &datafile) { // appends to Map_m_, ovdd_ and exp_err_; calls error() if datafile does not exist
     991           0 :   Vector pos;
     992             :   double dens, err;
     993             :   int idcomp; // scanned to detect the next record; its value is not used afterwards in this function
     994             : 
     995             : // open file; NOTE(review): unlike read_status(), the IFile is not linked to this action before use - confirm this is intended
     996           0 :   IFile *ifile = new IFile();
     997           0 :   if(ifile->FileExist(datafile)) {
     998           0 :     ifile->open(datafile);
     999           0 :     while(ifile->scanField("Id",idcomp)) {
    1000           0 :       ifile->scanField("Pos_0",pos[0]);
    1001           0 :       ifile->scanField("Pos_1",pos[1]);
    1002           0 :       ifile->scanField("Pos_2",pos[2]);
    1003           0 :       ifile->scanField("Density",dens);
    1004           0 :       ifile->scanField("Error",err);
    1005             :       // voxel center
    1006           0 :       Map_m_.push_back(pos);
    1007             :       // experimental density
    1008           0 :       ovdd_.push_back(dens);
    1009             :       // error
    1010           0 :       exp_err_.push_back(err);
    1011             :       // new line
    1012           0 :       ifile->scanField();
    1013             :     }
    1014           0 :     ifile->close();
    1015             :   } else {
    1016           0 :     error("Cannot find DATA_FILE "+datafile+"\n");
    1017             :   }
    1018           0 :   delete ifile; // NOTE(review): skipped if error() above does not return normally - confirm
    1019           0 : }
    1020             : 
    1021           0 : void EMMIVOX::initialize_Bfactor(double reso) { // set every entry of Model_b_ to one common initial value and log the settings
    1022           0 :   double bfactini = 0.0; // value used when dbfact_ is not positive
    1023             :   // if doing Bfactor Monte Carlo
    1024           0 :   if(dbfact_>0) {
    1025             :     // initialize B factor based on empirical relation between resolution and average bfactor
    1026             :     // calculated on ~8000 cryo-EM data with resolution < 5 Ang
    1027             :     // Bfact = A*reso**2+B; with A=6.95408 B=-2.45697/100.0 nm^2
    1028           0 :     bfactini = 6.95408*reso*reso - 0.01*2.45697;
    1029             :     // check for min and max: clamp the value to the range [bfactmin_, bfactmax_]
    1030           0 :     bfactini = std::min(bfactmax_, std::max(bfactmin_, bfactini));
    1031             :   }
    1032             :   // set initial Bfactor for every residue/chain key already present in Model_b_
    1033           0 :   for(std::map< std::pair<unsigned,std::string>, double>::iterator it=Model_b_.begin(); it!=Model_b_.end(); ++it) {
    1034           0 :     it->second = bfactini;
    1035             :   }
    1036           0 :   log.printf("  experimental map resolution : %3.2f\n", reso);
    1037             :   // if doing Bfactor Monte Carlo, also report the bounds and the initial value
    1038           0 :   if(dbfact_>0) {
    1039           0 :     log.printf("  minimum Bfactor value : %3.2f\n", bfactmin_);
    1040           0 :     log.printf("  maximum Bfactor value : %3.2f\n", bfactmax_);
    1041           0 :     log.printf("  initial Bfactor value : %3.2f\n", bfactini);
    1042             :   }
    1043           0 : }
    1044             : 
    1045             : // prepare auxiliary vectors: per-atom prefactors (pref_), inverse variances (invs2_) and distance cutoffs (cut_), then push pref_/invs2_ to the device
    1046           0 : void EMMIVOX::get_auxiliary_vectors() {
    1047             : // number of atoms
    1048           0 :   unsigned natoms = Model_res_.size();
    1049             : // clear lists
    1050           0 :   pref_.clear();
    1051           0 :   invs2_.clear();
    1052           0 :   cut_.clear();
    1053             : // resize to one entry per atom
    1054           0 :   pref_.resize(natoms);
    1055           0 :   invs2_.resize(natoms);
    1056           0 :   cut_.resize(natoms);
    1057             : // cycle on all atoms; each iteration writes only its own slot im of the three lists
    1058           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1059             :   for(unsigned im=0; im<natoms; ++im) {
    1060             :     // get atom type
    1061             :     unsigned atype = Model_type_[im];
    1062             :     // get residue/chain IDs
    1063             :     std::pair<unsigned,std::string> key = Model_res_[im];
    1064             :     // get bfactor of the residue/chain this atom belongs to
    1065             :     double bfact = Model_b_[key];
    1066             :     // sigma for 5 gaussians
    1067             :     Vector5d m_s = Model_s_[atype];
    1068             :     // calculate constant quantities
    1069             :     Vector5d pref, invs2;
    1070             :     // accumulators for the cutoff: n = sum_j pref[j]/invs2[j], d = sum_j pref[j]
    1071             :     double n = 0.0;
    1072             :     double d = 0.0;
    1073             :     for(unsigned j=0; j<5; ++j) {
    1074             :       // total value of b
    1075             :       double m_b = m_s[j] + bfact/4.0;
    1076             :       // calculate invs2
    1077             :       invs2[j] = 1.0/(inv_pi2_*m_b);
    1078             :       // prefactor
    1079             :       pref[j]  = cfact_[atype][j] * pow(invs2[j],1.5);
    1080             :       // cutoff accumulators: n/d is the pref-weighted average of 1/invs2 over the 5 gaussians
    1081             :       n += pref[j] / invs2[j];
    1082             :       d += pref[j];
    1083             :     }
    1084             :     // put into global lists; the cutoff is sqrt(n/d)*nl_gauss_cutoff_, capped at nl_dist_cutoff_
    1085             :     pref_[im]  = pref;
    1086             :     invs2_[im] = invs2;
    1087             :     cut_[im] = std::min(nl_dist_cutoff_, sqrt(n/d)*nl_gauss_cutoff_);
    1088             :   }
    1089             :   // push to GPU
    1090           0 :   push_auxiliary_gpu();
    1091           0 : }
    1092             : // Repack pref_ and invs2_ into [5,natoms] buffers and store them as float32 tensors on device_t_ (pref_gpu_, invs2_gpu_)
    1093           0 : void EMMIVOX::push_auxiliary_gpu() {
    1094             :   // 1) create vector of pref_ and invs2_: element (j,i) goes to index i+j*natoms, matching the {5,natoms} shape used below
    1095           0 :   int natoms = Model_type_.size();
    1096           0 :   std::vector<double> pref(5*natoms), invs2(5*natoms);
    1097           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1098             :   for(int i=0; i<natoms; ++i) {
    1099             :     for(int j=0; j<5; ++j) {
    1100             :       pref[i+j*natoms]  = pref_[i][j];
    1101             :       invs2[i+j*natoms] = invs2_[i][j];
    1102             :     }
    1103             :   }
    1104             :   // 2) initialize gpu tensors: wrap the local buffers, clone them, convert double -> float32 and move to device_t_
    1105           0 :   pref_gpu_  = torch::from_blob(pref.data(),  {5,natoms}, torch::kFloat64).clone().to(torch::kFloat32).to(device_t_);
    1106           0 :   invs2_gpu_ = torch::from_blob(invs2.data(), {5,natoms}, torch::kFloat64).clone().to(torch::kFloat32).to(device_t_);
    1107           0 : }
    1108             : // Build nl_res_: for each entry of Model_rlist_, the indices of the other entries that have at least one pair of atoms closer than 0.5 nm
    1109           0 : void EMMIVOX::get_close_residues() {
    1110             :   // clear neighbor list
    1111           0 :   nl_res_.clear();
    1112           0 :   nl_res_.resize(Model_rlist_.size());
    1113             : 
    1114             :   // loop in parallel
    1115           0 :   #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1116             :   {
    1117             :     // private variable: per-thread copy of the neighbor lists, merged into nl_res_ at the end
    1118             :     std::vector< std::vector<unsigned> > nl_res_l(Model_rlist_.size());
    1119             :     // cycle on residues/chains #1 (bound is size(), not size()-1: the unsigned size()-1 wraps around for an empty list; the last i just has no j to visit)
    1120             :     #pragma omp for
    1121             :     for(unsigned i=0; i<Model_rlist_.size(); ++i) {
    1122             : 
    1123             :       // key1: pair of residue/chain IDs
    1124             :       std::pair<unsigned,std::string> key1 = Model_rlist_[i];
    1125             : 
    1126             :       // cycle over residues/chains #2 (only j>i, each pair is tested once)
    1127             :       for(unsigned j=i+1; j<Model_rlist_.size(); ++j) {
    1128             : 
    1129             :         // key2: pair of residue/chain IDs
    1130             :         std::pair<unsigned,std::string> key2 = Model_rlist_[j];
    1131             : 
    1132             :         // set flag neighbor
    1133             :         bool neigh = false;
    1134             : 
    1135             :         // cycle over all the atoms belonging to key1
    1136             :         for(unsigned im1=0; im1<Model_resmap_[key1].size(); ++im1) {
    1137             :           // get atom position #1
    1138             :           Vector pos1 = getPosition(Model_resmap_[key1][im1]);
    1139             :           // cycle over all the atoms belonging to key2
    1140             :           for(unsigned im2=0; im2<Model_resmap_[key2].size(); ++im2) {
    1141             :             // get atom position #2
    1142             :             Vector pos2 = getPosition(Model_resmap_[key2][im2]);
    1143             :             // if closer than 0.5 nm, then residues key1 and key2 are neighbors
    1144             :             if(delta(pos1,pos2).modulo()<0.5) {
    1145             :               // set neighbors
    1146             :               neigh = true;
    1147             :               // and exit
    1148             :               break;
    1149             :             }
    1150             :           }
    1151             :           // check if neighbor already found
    1152             :           if(neigh) {
    1153             :             break;
    1154             :           }
    1155             :         }
    1156             : 
    1157             :         // if neighbors, add to local list (in both directions)
    1158             :         if(neigh) {
    1159             :           nl_res_l[i].push_back(j);
    1160             :           nl_res_l[j].push_back(i);
    1161             :         }
    1162             :       }
    1163             :     }
    1164             :     // add to global list, one thread at a time
    1165             :     #pragma omp critical
    1166             :     {
    1167             :       for(unsigned i=0; i<nl_res_.size(); ++i) {
    1168             :         nl_res_[i].insert(nl_res_[i].end(), nl_res_l[i].begin(), nl_res_l[i].end());
    1169             :       }
    1170             :     }
    1171             :   }
    1172           0 : }
    1173             : // One Monte Carlo sweep over the residue/chain Bfactors: propose, score and accept/reject each in turn, then refresh everything that depends on them
    1174           0 : void EMMIVOX::doMonteCarloBfact() {
    1175             : // update residue neighbor list
    1176           0 :   get_close_residues();
    1177             : 
    1178             : // cycle over residues/chains
    1179           0 :   for(unsigned ir=0; ir<Model_rlist_.size(); ++ir) {
    1180             : 
    1181             :     // key: pair of residue/chain IDs
    1182             :     std::pair<unsigned,std::string> key = Model_rlist_[ir];
    1183             :     // old bfactor
    1184           0 :     double bfactold = Model_b_[key];
    1185             : 
    1186             :     // propose move in bfactor: random displacement of amplitude dbfact_
    1187           0 :     double bfactnew = bfactold + dbfact_ * ( 2.0 * random_.RandU01() - 1.0 );
    1188             :     // check boundaries: reflect the proposal about the limit it crossed
    1189           0 :     if(bfactnew > bfactmax_) {
    1190           0 :       bfactnew = 2.0*bfactmax_ - bfactnew;
    1191             :     }
    1192           0 :     if(bfactnew < bfactmin_) {
    1193           0 :       bfactnew = 2.0*bfactmin_ - bfactnew;
    1194             :     }
    1195             : 
    1196             :     // change in model density, per affected voxel id, caused by the proposed Bfactor
    1197             :     std::map<unsigned, double> deltaov;
    1198             : 
    1199           0 :     #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1200             :     {
    1201             :       // private variables
    1202             :       std::map<unsigned, double> deltaov_l;
    1203             :       #pragma omp for
    1204             :       // cycle over all the atoms belonging to key (residue/chain)
    1205             :       for(unsigned ia=0; ia<Model_resmap_[key].size(); ++ia) {
    1206             : 
    1207             :         // get atom id
    1208             :         unsigned im = Model_resmap_[key][ia];
    1209             :         // get atom type
    1210             :         unsigned atype = Model_type_[im];
    1211             :         // sigma for 5 Gaussians
    1212             :         Vector5d m_s = Model_s_[atype];
    1213             :         // prefactors
    1214             :         Vector5d cfact = cfact_[atype];
    1215             :         // and position
    1216             :         Vector pos = getPosition(im);
    1217             : 
    1218             :         // cycle on all the neighboring voxels affected by a change in Bfactor
    1219             :         for(unsigned i=0; i<Model_nb_[im].size(); ++i) {
    1220             :           // voxel id
    1221             :           unsigned id = Model_nb_[im][i];
    1222             :           // get contribution to density in id before change
    1223             :           double dold = get_overlap(Map_m_[id], pos, cfact, m_s, bfactold);
    1224             :           // get contribution after change
    1225             :           double dnew = get_overlap(Map_m_[id], pos, cfact, m_s, bfactnew);
    1226             :           // update delta density
    1227             :           deltaov_l[id] += dnew-dold;
    1228             :         }
    1229             :       }
    1230             :       // add to global list, one thread at a time
    1231             :       #pragma omp critical
    1232             :       {
    1233             :         for(std::map<unsigned,double>::iterator itov=deltaov_l.begin(); itov!=deltaov_l.end(); ++itov) {
    1234             :           deltaov[itov->first] += itov->second;
    1235             :         }
    1236             :       }
    1237             :     }
    1238             : 
    1239             :     // now calculate new and old score (only terms that can change with this move are summed)
    1240             :     double old_ene = 0.0;
    1241             :     double new_ene = 0.0;
    1242             : 
    1243             :     // cycle on all affected voxels
    1244           0 :     for(std::map<unsigned,double>::iterator itov=deltaov.begin(); itov!=deltaov.end(); ++itov) {
    1245             :       // id of the component
    1246           0 :       unsigned id = itov->first;
    1247             :       // new value
    1248           0 :       double ovmdnew = ovmd_[id]+itov->second;
    1249             :       // deviations of the scaled+offset model density from the experimental one
    1250           0 :       double devold = scale_ * ovmd_[id] + offset_ - ovdd_[id];
    1251           0 :       double devnew = scale_ * ovmdnew   + offset_ - ovdd_[id];
    1252             :       // inverse of sigma_min
    1253           0 :       double ismin = ismin_[id];
    1254             :       // scores; the dev==0 branches avoid the division by zero (presumably the dev->0 limit of the general expression - confirm against the value of sqrt2_pi_)
    1255           0 :       if(devold==0.0) {
    1256           0 :         old_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1257             :       } else {
    1258           0 :         old_ene += -kbt_ * std::log( 0.5 / devold * erf ( devold * inv_sqrt2_ * ismin ));
    1259             :       }
    1260           0 :       if(devnew==0.0) {
    1261           0 :         new_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1262             :       } else {
    1263           0 :         new_ene += -kbt_ * std::log( 0.5 / devnew * erf ( devnew * inv_sqrt2_ * ismin ));
    1264             :       }
    1265             :     }
    1266             : 
    1267             :     // list of neighboring residues
    1268           0 :     std::vector<unsigned> close = nl_res_[ir];
    1269             :     // add restraint to keep Bfactors of close residues close
    1270           0 :     for(unsigned i=0; i<close.size(); ++i) {
    1271             :       // residue/chain IDs of neighbor
    1272           0 :       std::pair<unsigned,std::string> keyn = Model_rlist_[close[i]];
    1273             :       // deviations from the neighbor's current Bfactor
    1274           0 :       double devold = bfactold - Model_b_[keyn];
    1275           0 :       double devnew = bfactnew - Model_b_[keyn];
    1276             :       // inverse of the restraint width bfactsig_ (plays the role of sigma_min above)
    1277           0 :       double ismin = 1.0 / bfactsig_;
    1278             :       // scores, same functional form as for the voxels
    1279           0 :       if(devold==0.0) {
    1280           0 :         old_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1281             :       } else {
    1282           0 :         old_ene += -kbt_ * std::log( 0.5 / devold * erf ( devold * inv_sqrt2_ * ismin ));
    1283             :       }
    1284           0 :       if(devnew==0.0) {
    1285           0 :         new_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1286             :       } else {
    1287           0 :         new_ene += -kbt_ * std::log( 0.5 / devnew * erf ( devnew * inv_sqrt2_ * ismin ));
    1288             :       }
    1289             :     }
    1290             : 
    1291             :     // increment number of trials
    1292           0 :     MCBtrials_ += 1.0;
    1293             : 
    1294             :     // accept or reject: with bfactemin_ set only moves that lower the energy are accepted, otherwise doAccept() decides
    1295             :     bool accept = false;
    1296           0 :     if(bfactemin_) {
    1297           0 :       if(new_ene < old_ene) {
    1298             :         accept = true;
    1299             :       }
    1300             :     } else {
    1301           0 :       accept = doAccept(old_ene, new_ene, kbt_);
    1302             :     }
    1303             : 
    1304             :     // in case of acceptance
    1305           0 :     if(accept) {
    1306             :       // update acceptance rate
    1307           0 :       MCBaccept_ += 1.0;
    1308             :       // update bfactor
    1309           0 :       Model_b_[key] = bfactnew;
    1310             :       // change all the ovmd_ affected, so later residues in this sweep are scored against the updated density
    1311           0 :       for(std::map<unsigned,double>::iterator itov=deltaov.begin(); itov!=deltaov.end(); ++itov) {
    1312           0 :         ovmd_[itov->first] += itov->second;
    1313             :       }
    1314             :     }
    1315             : 
    1316             :   } // end cycle on bfactors
    1317             : 
    1318             : // update auxiliary lists (to update pref_gpu_, invs2_gpu_, and cut_ on CPU/GPU)
    1319           0 :   get_auxiliary_vectors();
    1320             : // update neighbor list (new cut_ + update pref_nl_gpu_ and invs2_nl_gpu_ on GPU)
    1321           0 :   update_neighbor_list();
    1322             : // recalculate fmod (to update derivatives)
    1323           0 :   calculate_fmod();
    1324           0 : }
    1325             : 
    1326             : // get overlap: sum over 5 Gaussians of cfact[j]*invs2^1.5*exp(-0.5*|m_m-d_m|^2*invs2), with invs2 built from m_s[j]+bfact/4
    1327           0 : double EMMIVOX::get_overlap(const Vector &d_m, const Vector &m_m,
    1328             :                             const Vector5d &cfact, const Vector5d &m_s, double bfact) {
    1329             :   // calculate vector difference
    1330           0 :   Vector md = delta(m_m, d_m);
    1331             :   // norm squared
    1332           0 :   double md2 = md[0]*md[0]+md[1]*md[1]+md[2]*md[2];
    1333             :   // cycle on 5 Gaussians
    1334             :   double ov_tot = 0.0;
    1335           0 :   for(unsigned j=0; j<5; ++j) {
    1336             :     // total value of b
    1337           0 :     double m_b = m_s[j]+bfact/4.0;
    1338             :     // calculate invs2 (same expression as in get_auxiliary_vectors)
    1339           0 :     double invs2 = 1.0/(inv_pi2_*m_b);
    1340             :     // final calculation
    1341           0 :     ov_tot += cfact[j] * pow(invs2, 1.5) * std::exp(-0.5 * md2 * invs2);
    1342             :   }
    1343           0 :   return ov_tot;
    1344             : }
    1345             : // Rebuild ns_, the list of (voxel id, atom id) pairs whose distance is within twice the atom's cut_, and store the positions used in refpos_
    1346           0 : void EMMIVOX::update_neighbor_sphere() {
    1347             :   // number of atoms
    1348           0 :   unsigned natoms = Model_type_.size();
    1349             :   // clear neighbor sphere
    1350             :   ns_.clear();
    1351             :   // store reference positions (compared against in do_neighbor_sphere)
    1352           0 :   refpos_ = getPositions();
    1353             : 
    1354             :   // cycle on voxels - in parallel
    1355           0 :   #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1356             :   {
    1357             :     // private variables
    1358             :     std::vector< std::pair<unsigned,unsigned> > ns_l;
    1359             :     #pragma omp for
    1360             :     for(unsigned id=0; id<ovdd_.size(); ++id) {
    1361             :       // grid point
    1362             :       Vector d_m = Map_m_[id];
    1363             :       // cycle on atoms
    1364             :       for(unsigned im=0; im<natoms; ++im) {
    1365             :         // calculate distance
    1366             :         double dist = delta(getPosition(im), d_m).modulo();
    1367             :         // add to local list if within twice the atom cutoff
    1368             :         if(dist<=2.0*cut_[im]) {
    1369             :           ns_l.push_back(std::make_pair(id,im));
    1370             :         }
    1371             :       }
    1372             :     }
    1373             :     // add to global list, one thread at a time
    1374             :     #pragma omp critical
    1375             :     ns_.insert(ns_.end(), ns_l.begin(), ns_l.end());
    1376             :   }
    1377           0 : }
    1378             : // Return true when at least one atom has moved away from its refpos_ by its own cut_ or more
    1379           0 : bool EMMIVOX::do_neighbor_sphere() {
    1380           0 :   std::vector<double> dist(getPositions().size());
    1381             :   bool update = false;
    1382             : 
    1383             : // calculate displacement from the reference position, in units of the atom cutoff
    1384           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1385             :   for(unsigned im=0; im<dist.size(); ++im) {
    1386             :     dist[im] = delta(getPosition(im),refpos_[im]).modulo()/cut_[im];
    1387             :   }
    1388             : 
    1389             : // check if update or not (NOTE(review): max_element on an empty dist would dereference end() - confirm there is always at least one atom)
    1390           0 :   double maxdist = *max_element(dist.begin(), dist.end());
    1391           0 :   if(maxdist>=1.0) {
    1392             :     update=true;
    1393             :   }
    1394             : 
    1395             : // return if update or not
    1396           0 :   return update;
    1397             : }
    1398             : // Filter ns_ down to nl_ (pairs within cut_), resize ovmd_der_, rebuild Model_nb_ on Bfactor sampling steps, then call update_gpu()
    1399           0 : void EMMIVOX::update_neighbor_list() {
    1400             :   // number of atoms
    1401           0 :   unsigned natoms = Model_type_.size();
    1402             :   // clear neighbor list
    1403             :   nl_.clear();
    1404             : 
    1405             :   // cycle on neighbour sphere - in parallel
    1406           0 :   #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1407             :   {
    1408             :     // private variables
    1409             :     std::vector< std::pair<unsigned,unsigned> > nl_l;
    1410             :     #pragma omp for
    1411             :     for(unsigned long long i=0; i<ns_.size(); ++i) {
    1412             :       // calculate distance between voxel (first) and atom (second)
    1413             :       double dist = delta(Map_m_[ns_[i].first], getPosition(ns_[i].second)).modulo();
    1414             :       // add to local neighbour list
    1415             :       if(dist<=cut_[ns_[i].second]) {
    1416             :         nl_l.push_back(ns_[i]);
    1417             :       }
    1418             :     }
    1419             :     // add to global list, one thread at a time
    1420             :     #pragma omp critical
    1421             :     nl_.insert(nl_.end(), nl_l.begin(), nl_l.end());
    1422             :   }
    1423             : 
    1424             :   // new dimension of neighbor list
    1425             :   unsigned long long nl_size = nl_.size();
    1426             :   // now resize derivatives
    1427           0 :   ovmd_der_.resize(nl_size);
    1428             : 
    1429             :   // in case of B-factors sampling - at the right step
    1430           0 :   if(dbfact_>0 && getStep()%MCBstride_==0) {
    1431             :     // clear vectors
    1432           0 :     Model_nb_.clear();
    1433           0 :     Model_nb_.resize(natoms);
    1434             :     // cycle over the neighbor list to create a list of voxels per atom
    1435           0 :     #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1436             :     {
    1437             :       // private variables
    1438             :       std::vector< std::vector<unsigned> > Model_nb_l(natoms);
    1439             :       #pragma omp for
    1440             :       for(unsigned long long i=0; i<nl_size; ++i) {
    1441             :         Model_nb_l[nl_[i].second].push_back(nl_[i].first);
    1442             :       }
    1443             :       // add to global list, one thread at a time
    1444             :       #pragma omp critical
    1445             :       {
    1446             :         for(unsigned i=0; i<natoms; ++i) {
    1447             :           Model_nb_[i].insert(Model_nb_[i].end(), Model_nb_l[i].begin(), Model_nb_l[i].end());
    1448             :         }
    1449             :       }
    1450             :     }
    1451             :   }
    1452             : 
    1453             :   // transfer data to gpu
    1454           0 :   update_gpu();
    1455           0 : }
    1456             : // Upload the neighbor list as two int32 index tensors (voxel ids, atom ids) and gather the per-pair pref, invs2 and voxel-position tensors from them
    1457           0 : void EMMIVOX::update_gpu() {
    1458             :   // dimension of neighbor list
    1459             :   long long nl_size = nl_.size();
    1460             :   // create useful vectors: nl_id holds the voxel id (first), nl_im the atom id (second) of every pair
    1461           0 :   std::vector<int> nl_id(nl_size), nl_im(nl_size);
    1462           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1463             :   for(long long i=0; i<nl_size; ++i) {
    1464             :     nl_id[i] = static_cast<int>(nl_[i].first);
    1465             :     nl_im[i] = static_cast<int>(nl_[i].second);
    1466             :   }
    1467             :   // create tensors on device (clone() copies the data out of the local vectors first)
    1468           0 :   nl_id_gpu_ = torch::from_blob(nl_id.data(), {nl_size}, torch::kInt32).clone().to(device_t_);
    1469           0 :   nl_im_gpu_ = torch::from_blob(nl_im.data(), {nl_size}, torch::kInt32).clone().to(device_t_);
    1470             :   // now we need to create pref_nl_gpu_ [5,nl_size]
    1471           0 :   pref_nl_gpu_  = torch::index_select(pref_gpu_,1,nl_im_gpu_);
    1472             :   // and invs2_nl_gpu_ [5,nl_size]
    1473           0 :   invs2_nl_gpu_ = torch::index_select(invs2_gpu_,1,nl_im_gpu_);
    1474             :   // and Map_m_nl_gpu_ [3,nl_size]
    1475           0 :   Map_m_nl_gpu_ = torch::index_select(Map_m_gpu_,1,nl_id_gpu_);
    1476           0 : }
    1477             : // On exchange steps set first_time_; calculate() checks this flag when deciding whether to update the neighbor lists
    1478           0 : void EMMIVOX::prepare() {
    1479           0 :   if(getExchangeStep()) {
    1480           0 :     first_time_=true;
    1481             :   }
    1482           0 : }
    1483             : 
    1484             : // calculate forward model on gpu: model density per voxel (ovmd_gpu_) and its per-pair derivatives (ovmd_der_gpu_); copies the density to ovmd_ when needed on the CPU
    1485           0 : void EMMIVOX::calculate_fmod() {
    1486             :   // number of atoms
    1487           0 :   int natoms = Model_type_.size();
    1488             :   // number of data points
    1489           0 :   int nd = ovdd_.size();
    1490             : 
    1491             :   // fill positions in parallel (layout: all x, then all y, then all z, matching the {3,natoms} shape below)
    1492           0 :   std::vector<double> posg(3*natoms);
    1493           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1494             :   for (int i=0; i<natoms; ++i) {
    1495             :     // fill vectors
    1496             :     posg[i]          = getPosition(i)[0];
    1497             :     posg[i+natoms]   = getPosition(i)[1];
    1498             :     posg[i+2*natoms] = getPosition(i)[2];
    1499             :   }
    1500             :   // transfer positions to pos_gpu [3,natoms]
    1501           0 :   torch::Tensor pos_gpu = torch::from_blob(posg.data(), {3,natoms}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
    1502             :   // create pos_nl_gpu_ [3,nl_size]
    1503           0 :   torch::Tensor pos_nl_gpu = torch::index_select(pos_gpu,1,nl_im_gpu_);
    1504             :   // calculate vector difference [3,nl_size]
    1505           0 :   torch::Tensor md = Map_m_nl_gpu_ - pos_nl_gpu;
    1506             :   // calculate norm squared by column [1,nl_size]
    1507           0 :   torch::Tensor md2 = torch::sum(md*md,0);
    1508             :   // calculate density [5,nl_size]
    1509           0 :   torch::Tensor ov = pref_nl_gpu_ * torch::exp(-0.5 * md2 * invs2_nl_gpu_);
    1510             :   // and derivatives [5,nl_size]
    1511           0 :   ovmd_der_gpu_ = invs2_nl_gpu_ * ov;
    1512             :   // sum density over 5 columns [1,nl_size]
    1513           0 :   ov = torch::sum(ov,0);
    1514             :   // sum contributions to the same voxel (nl_id_gpu_ holds the voxel id of each neighbor-list pair)
    1515           0 :   auto options = torch::TensorOptions().device(device_t_).dtype(torch::kFloat32);
    1516           0 :   ovmd_gpu_ = torch::zeros({nd}, options);
    1517           0 :   ovmd_gpu_.index_add_(0, nl_id_gpu_, ov);
    1518             :   // sum derivatives over 5 rows [1,nl_size] and multiply by md [3,nl_size]
    1519           0 :   ovmd_der_gpu_ = md * torch::sum(ovmd_der_gpu_,0);
    1520             : 
    1521             :   // in case of metainference: average them across replicas
    1522           0 :   if(!no_aver_ && nrep_>1) {
    1523             :     // communicate ovmd_gpu_ to CPU [1, nd]
    1524           0 :     torch::Tensor ovmd_cpu = ovmd_gpu_.detach().to(torch::kCPU).to(torch::kFloat64);
    1525             :     // and put them in ovmd_
    1526           0 :     ovmd_ = std::vector<double>(ovmd_cpu.data_ptr<double>(), ovmd_cpu.data_ptr<double>() + ovmd_cpu.numel());
    1527             :     // sum across replicas
    1528           0 :     multi_sim_comm.Sum(&ovmd_[0], nd);
    1529             :     // and divide by number of replicas
    1530           0 :     double escale = 1.0 / static_cast<double>(nrep_);
    1531           0 :     for(int i=0; i<nd; ++i) {
    1532           0 :       ovmd_[i] *= escale;
    1533             :     }
    1534             :     // put back on device
    1535           0 :     ovmd_gpu_ = torch::from_blob(ovmd_.data(), {nd}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
    1536             :   }
    1537             : 
    1538             :   // communicate back model density
    1539             :   // this is needed only in certain situations: map output steps, Bfactor Monte Carlo steps, or when do_corr_ is set
    1540           0 :   long int step = getStep();
    1541             :   bool do_comm = false;
    1542           0 :   if(mapstride_>0 && step%mapstride_==0) {
    1543             :     do_comm = true;
    1544             :   }
    1545           0 :   if(dbfact_>0    && step%MCBstride_==0) {
    1546             :     do_comm = true;
    1547             :   }
    1548           0 :   if(do_corr_) {
    1549             :     do_comm = true;
    1550             :   }
    1551             :   // in case of metainference: already communicated (ovmd_ was filled in the averaging branch above)
    1552           0 :   if(!no_aver_ && nrep_>1) {
    1553             :     do_comm = false;
    1554             :   }
    1555           0 :   if(do_comm) {
    1556             :     // communicate ovmd_gpu_ to CPU [1, nd]
    1557           0 :     torch::Tensor ovmd_cpu = ovmd_gpu_.detach().to(torch::kCPU).to(torch::kFloat64);
    1558             :     // and put them in ovmd_
    1559           0 :     ovmd_ = std::vector<double>(ovmd_cpu.data_ptr<double>(), ovmd_cpu.data_ptr<double>() + ovmd_cpu.numel());
    1560             :   }
    1561           0 : }
    1562             : 
    1563             : // calculate score: energy (ene_), per-atom derivatives (atom_der_) and virial (virial_) from the current ovmd_gpu_ and ovmd_der_gpu_
    1564           0 : void EMMIVOX::calculate_score() {
    1565             :   // number of atoms
    1566           0 :   int natoms = Model_type_.size();
    1567             : 
    1568             :   // calculate deviation model/data [1, nd]
    1569           0 :   torch::Tensor dev = scale_ * ovmd_gpu_ + offset_ - ovdd_gpu_;
    1570             :   // error function [1, nd]
    1571           0 :   torch::Tensor errf = torch::erf( dev * inv_sqrt2_ * ismin_gpu_ );
    1572             :   // take care of dev = zero (zeros_d is true where dev is nonzero)
    1573           0 :   torch::Tensor zeros_d = torch::ne(dev, 0.0);
    1574             :   // redefine dev: entries equal to zero are replaced by eps_ so the divisions below stay finite
    1575           0 :   dev = dev * zeros_d + eps_ * torch::logical_not(zeros_d);
    1576             :   // take care of errf = zero (zeros_e is true where errf is nonzero)
    1577           0 :   torch::Tensor zeros_e = torch::ne(errf, 0.0);
    1578             :   // redefine errf: entries equal to zero are replaced by eps_
    1579           0 :   errf = errf * zeros_e + eps_ * torch::logical_not(zeros_e);
    1580             :   // logical AND: both dev and errf different from zero
    1581             :   torch::Tensor zeros = torch::logical_and(zeros_d, zeros_e);
    1582             :   // energy - with limit dev going to zero
    1583           0 :   torch::Tensor ene = 0.5 * ( errf / dev * zeros + torch::logical_not(zeros) * sqrt2_pi_ *  ismin_gpu_);
    1584             :   // logarithm and sum
    1585           0 :   ene = -kbt_ * torch::sum(torch::log(ene));
    1586             :   // and derivatives [1, nd]; masked to zero where dev or errf was zero
    1587           0 :   torch::Tensor d_der = -kbt_ * zeros * ( sqrt2_pi_ * torch::exp( -0.5 * dev * dev * ismin_gpu_ * ismin_gpu_ ) * ismin_gpu_ / errf - 1.0 / dev );
    1588             :   // tensor for derivatives wrt atoms [1, nl_size]
    1589           0 :   torch::Tensor der_gpu = torch::index_select(d_der,0,nl_id_gpu_);
    1590             :   // multiply by ovmd_der_gpu_ and scale [3, nl_size]
    1591           0 :   der_gpu = ovmd_der_gpu_ * scale_ * der_gpu;
    1592             :   // sum contributions for each atom (nl_im_gpu_ holds the atom id of each neighbor-list pair)
    1593           0 :   auto options = torch::TensorOptions().device(device_t_).dtype(torch::kFloat32);
    1594           0 :   torch::Tensor atoms_der_gpu = torch::zeros({3,natoms}, options);
    1595           0 :   atoms_der_gpu.index_add_(1, nl_im_gpu_, der_gpu);
    1596             : 
    1597             :   // FINAL STUFF
    1598             :   //
    1599             :   // 1) communicate total energy to CPU
    1600           0 :   torch::Tensor ene_cpu = ene.detach().to(torch::kCPU).to(torch::kFloat64);
    1601           0 :   ene_ = *ene_cpu.data_ptr<double>();
    1602             :   // with marginal, simply multiply by number of replicas!
    1603           0 :   if(!no_aver_ && nrep_>1) {
    1604           0 :     ene_ *= static_cast<double>(nrep_);
    1605             :   }
    1606             :   //
    1607             :   // 2) communicate derivatives to CPU
    1608           0 :   torch::Tensor atom_der_cpu = atoms_der_gpu.detach().to(torch::kCPU).to(torch::kFloat64);
    1609             :   // convert to std::vector<double> (flat [3,natoms]: all x, then all y, then all z)
    1610           0 :   std::vector<double> atom_der = std::vector<double>(atom_der_cpu.data_ptr<double>(), atom_der_cpu.data_ptr<double>() + atom_der_cpu.numel());
    1611             :   // and put in atom_der_
    1612           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1613             :   for(int i=0; i<natoms; ++i) {
    1614             :     atom_der_[i] = Vector(atom_der[i],atom_der[i+natoms],atom_der[i+2*natoms]);
    1615             :   }
    1616             :   //
    1617             :   // 3) calculate virial on CPU
    1618           0 :   Tensor virial;
    1619             :   // declare omp reduction for Tensors
    1620             :   #pragma omp declare reduction( sumTensor : Tensor : omp_out += omp_in )
    1621             : 
    1622           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads()) reduction (sumTensor : virial)
    1623             :   for(int i=0; i<natoms; ++i) {
    1624             :     virial += Tensor(getPosition(i), -atom_der_[i]);
    1625             :   }
    1626             :   // store virial
    1627           0 :   virial_ = virial;
    1628           0 : }
    1629             : 
    1630           0 : void EMMIVOX::calculate() {
    1631             :   // get time step
    1632           0 :   long int step = getStep();
    1633             : 
    1634             :   // set temperature value
    1635           0 :   getPntrToComponent("kbt")->set(kbt_);
    1636             : 
    1637             :   // neighbor list update
    1638           0 :   if(first_time_ || getExchangeStep() || step%nl_stride_==0) {
    1639             :     // check if time to update neighbor sphere
    1640             :     bool update = false;
    1641           0 :     if(first_time_ || getExchangeStep()) {
    1642             :       update = true;
    1643             :     } else {
    1644           0 :       update = do_neighbor_sphere();
    1645             :     }
    1646             :     // update neighbor sphere
    1647           0 :     if(update) {
    1648           0 :       update_neighbor_sphere();
    1649             :     }
    1650             :     // update neighbor list
    1651           0 :     update_neighbor_list();
    1652             :     // set flag
    1653           0 :     first_time_=false;
    1654             :   }
    1655             : 
    1656             :   // calculate forward model
    1657           0 :   calculate_fmod();
    1658             : 
    1659             :   // Monte Carlo on bfactors
    1660           0 :   if(dbfact_>0) {
    1661             :     double acc = 0.0;
    1662             :     // do Monte Carlo
    1663           0 :     if(step%MCBstride_==0 && !getExchangeStep() && step>0) {
    1664           0 :       doMonteCarloBfact();
    1665             :     }
    1666             :     // calculate acceptance ratio
    1667           0 :     if(MCBtrials_>0) {
    1668           0 :       acc = MCBaccept_ / MCBtrials_;
    1669             :     }
    1670             :     // set value
    1671           0 :     getPntrToComponent("accB")->set(acc);
    1672             :   }
    1673             : 
    1674             :   // calculate score
    1675           0 :   calculate_score();
    1676             : 
    1677             :   // set score, virial, and derivatives
    1678           0 :   Value* score = getPntrToComponent("scoreb");
    1679           0 :   score->set(ene_);
    1680           0 :   setBoxDerivatives(score, virial_);
    1681           0 :   #pragma omp parallel for
    1682             :   for(unsigned i=0; i<atom_der_.size(); ++i) {
    1683             :     setAtomsDerivatives(score, i, atom_der_[i]);
    1684             :   }
    1685             :   // set scale and offset value
    1686           0 :   getPntrToComponent("scale")->set(scale_);
    1687           0 :   getPntrToComponent("offset")->set(offset_);
    1688             :   // calculate correlation coefficient
    1689           0 :   if(do_corr_) {
    1690           0 :     calculate_corr();
    1691             :   }
    1692             :   // PRINT other quantities to files
    1693             :   // - status file
    1694           0 :   if(step%statusstride_==0) {
    1695           0 :     print_status(step);
    1696             :   }
    1697             :   // - density file
    1698           0 :   if(mapstride_>0 && step%mapstride_==0) {
    1699           0 :     write_model_density(step);
    1700             :   }
    1701           0 : }
    1702             : 
    1703           0 : void EMMIVOX::calculate_corr() {
    1704             : // number of data points
    1705           0 :   double nd = static_cast<double>(ovdd_.size());
    1706             : // average ovmd_ and ovdd_
    1707           0 :   double ave_md = std::accumulate(ovmd_.begin(), ovmd_.end(), 0.) / nd;
    1708           0 :   double ave_dd = std::accumulate(ovdd_.begin(), ovdd_.end(), 0.) / nd;
    1709             : // calculate correlation
    1710             :   double num = 0.;
    1711             :   double den1 = 0.;
    1712             :   double den2 = 0.;
    1713           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads()) reduction( + : num, den1, den2)
    1714             :   for(unsigned i=0; i<ovdd_.size(); ++i) {
    1715             :     double md = ovmd_[i]-ave_md;
    1716             :     double dd = ovdd_[i]-ave_dd;
    1717             :     num  += md*dd;
    1718             :     den1 += md*md;
    1719             :     den2 += dd*dd;
    1720             :   }
    1721             : // correlation coefficient
    1722           0 :   double cc = num / sqrt(den1*den2);
    1723             : // set plumed
    1724           0 :   getPntrToComponent("corr")->set(cc);
    1725           0 : }
    1726             : 
    1727             : 
    1728             : }
    1729             : }
    1730             : 
    1731             : #endif

Generated by: LCOV version 1.16