LCOV - code coverage report
Current view: top level - isdb - EMMIVox.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 35 647 5.4 %
Date: 2024-10-18 13:59:31 Functions: 1 26 3.8 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : 
      23             : #ifdef __PLUMED_HAS_LIBTORCH
      24             : #include "colvar/Colvar.h"
      25             : #include "core/ActionRegister.h"
      26             : #include "core/PlumedMain.h"
      27             : #include "tools/Communicator.h"
      28             : #include "tools/Matrix.h"
      29             : #include "core/GenericMolInfo.h"
      30             : #include "core/ActionSet.h"
      31             : #include "tools/File.h"
      32             : #include "tools/OpenMP.h"
      33             : #include <string>
      34             : #include <cmath>
      35             : #include <map>
      36             : #include <numeric>
      37             : #include <ctime>
      38             : #include "tools/Random.h"
      39             : 
      40             : #include <torch/torch.h>
      41             : #include <torch/script.h>
      42             : 
      43             : #ifndef M_PI
      44             : #define M_PI           3.14159265358979323846
      45             : #endif
      46             : 
      47             : namespace PLMD {
      48             : namespace isdb {
      49             : 
      50             : //+PLUMEDOC ISDB_COLVAR EMMIVOX
      51             : /*
      52             : Bayesian single-structure and ensemble refinement with cryo-EM maps.
      53             : 
      54             : This action implements the Bayesian approach for single-structure and ensemble refinement from cryo-EM maps introduced <a href="https://www.biorxiv.org/content/10.1101/2023.10.18.562710v1">here</a>.
      55             : EMMIVox does not require fitting the cryo-EM map with a Gaussian Mixture Model, as done in \ref EMMI, but uses directly the voxels in the deposited map.
      56             : 
      57             : When run in single-replica mode, this action allows atomistic, flexible refinement (and B-factors inference) of an individual structure into a density map.
      58             : A coarse-grained forward model can also be used in combination with the MARTINI force field.
      59             : Combined with a multi-replica framework (such as the -multi option in GROMACS), the user can model an ensemble of structures using
      60             : the Metainference approach \cite Bonomi:2016ip . The approach can be used to model continous dynamics of flexible regions as well as semi-ordered waters, lipids, and ions.
      61             : 
      62             : \warning
      63             :     To use EMMIVOX, PLUMED must be linked against the LibTorch library as described \ref ISDB "here"
      64             : 
      65             : \par Examples
      66             : 
      67             : Complete tutorials for single-structure and ensemble refinement can be found <a href="https://github.com/COSBlab/EMMIVox">here</a>.
      68             : 
      69             : */
      70             : //+ENDPLUMEDOC
      71             : 
      72             : class EMMIVOX : public Colvar {
      73             : 
      74             : private:
      75             : 
      76             : // temperature in kbt
      77             :   double kbt_;
      78             : // model - atom types
      79             :   std::vector<unsigned> Model_type_;
      80             : // model - list of atom sigmas - one per atom type
      81             :   std::vector<Vector5d> Model_s_;
      82             : // model - list of atom weights - one per atom type
      83             :   std::vector<Vector5d> Model_w_;
      84             : // model - map between residue/chain IDs and list of atoms
      85             :   std::map< std::pair<unsigned,std::string>, std::vector<unsigned> > Model_resmap_;
      86             : // model - list of residue/chain IDs per atom
      87             :   std::vector< std::pair<unsigned,std::string> > Model_res_;
      88             : // model - list of neighboring voxels per atom
      89             :   std::vector< std::vector<unsigned> > Model_nb_;
      90             : // model - map between residue/chain ID and bfactor
      91             :   std::map< std::pair<unsigned, std::string>, double> Model_b_;
      92             : // model - global list of residue/chain IDs
      93             :   std::vector< std::pair<unsigned,std::string> > Model_rlist_;
      94             : // model density
      95             :   std::vector<double> ovmd_;
      96             : 
      97             : // data map - voxel position
      98             :   std::vector<Vector> Map_m_;
      99             : // data map - density
     100             :   std::vector<double> ovdd_;
     101             : // data map - error
     102             :   std::vector<double> exp_err_;
     103             : 
     104             : // derivatives
     105             :   std::vector<Vector> ovmd_der_;
     106             :   std::vector<Vector> atom_der_;
     107             :   std::vector<double> score_der_;
     108             : // constants
     109             :   double inv_sqrt2_, sqrt2_pi_, inv_pi2_;
     110             :   std::vector<Vector5d> pref_;
     111             :   std::vector<Vector5d> invs2_;
     112             :   std::vector<Vector5d> cfact_;
     113             :   std::vector<double> cut_;
     114             : // metainference
     115             :   unsigned nrep_;
     116             :   unsigned replica_;
     117             :   std::vector<double> ismin_;
     118             : // neighbor list
     119             :   double nl_dist_cutoff_;
     120             :   double nl_gauss_cutoff_;
     121             :   unsigned nl_stride_;
     122             :   bool first_time_;
     123             :   std::vector< std::pair<unsigned,unsigned> > nl_;
     124             :   std::vector< std::pair<unsigned,unsigned> > ns_;
     125             :   std::vector<Vector> refpos_;
     126             : // averaging
     127             :   bool no_aver_;
     128             : // correlation;
     129             :   bool do_corr_;
     130             : // Monte Carlo stuff
     131             :   Random   random_;
     132             :   // Scale and Offset
     133             :   double scale_;
     134             :   double offset_;
     135             : // Bfact Monte Carlo
     136             :   double   dbfact_;
     137             :   double   bfactmin_;
     138             :   double   bfactmax_;
     139             :   double   bfactsig_;
     140             :   bool     bfactnoc_;
     141             :   bool     bfactread_;
     142             :   int      MCBstride_;
     143             :   double   MCBaccept_;
     144             :   double   MCBtrials_;
     145             : // residue neighbor list
     146             :   std::vector< std::vector<unsigned> > nl_res_;
     147             :   bool bfactemin_;
     148             : // Martini scattering factors
     149             :   bool martini_;
     150             :   // status stuff
     151             :   unsigned int statusstride_;
     152             :   std::string       statusfilename_;
     153             :   OFile        statusfile_;
     154             :   bool         first_status_;
     155             :   // total energy and virial
     156             :   double ene_;
     157             :   Tensor virial_;
     158             :   double eps_;
     159             :   // model density file
     160             :   unsigned int mapstride_;
     161             :   std::string       mapfilename_;
     162             :   // Libtorch stuff
     163             :   bool gpu_;
     164             :   torch::Tensor ovmd_gpu_;
     165             :   torch::Tensor ovmd_der_gpu_;
     166             :   torch::Tensor ismin_gpu_;
     167             :   torch::Tensor ovdd_gpu_;
     168             :   torch::Tensor Map_m_gpu_;
     169             :   torch::Tensor pref_gpu_;
     170             :   torch::Tensor invs2_gpu_;
     171             :   torch::Tensor nl_id_gpu_;
     172             :   torch::Tensor nl_im_gpu_;
     173             :   torch::Tensor pref_nl_gpu_;
     174             :   torch::Tensor invs2_nl_gpu_;
     175             :   torch::Tensor Map_m_nl_gpu_;
     176             :   torch::DeviceType device_t_;
     177             : //
     178             : // write file with model density
     179             :   void write_model_density(long int step);
     180             : // get median of vector
     181             :   double get_median(std::vector<double> v);
     182             : // read and write status
     183             :   void read_status();
     184             :   void print_status(long int step);
     185             : // accept or reject
     186             :   bool doAccept(double oldE, double newE, double kbt);
     187             : // vector of close residues
     188             :   void get_close_residues();
     189             : // do MonteCarlo for Bfactor
     190             :   void doMonteCarloBfact();
     191             : // calculate model parameters
     192             :   std::vector<double> get_Model_param(std::vector<AtomNumber> &atoms);
     193             : // read data file
     194             :   void get_exp_data(const std::string &datafile);
     195             : // auxiliary methods
     196             :   void prepare_gpu();
     197             :   void initialize_Bfactor(double reso);
     198             :   void get_auxiliary_vectors();
     199             :   void push_auxiliary_gpu();
     200             : // calculate overlap between two Gaussians
     201             :   double get_overlap(const Vector &d_m, const Vector &m_m,
     202             :                      const Vector5d &cfact, const Vector5d &m_s, double bfact);
     203             : // update the neighbor list
     204             :   void update_neighbor_list();
     205             : // update data on device
     206             :   void update_gpu();
     207             : // update the neighbor sphere
     208             :   void update_neighbor_sphere();
     209             :   bool do_neighbor_sphere();
     210             : // calculate forward model and score on device
     211             :   void calculate_fmod();
     212             :   void calculate_score();
     213             : // calculate correlation
     214             :   void calculate_corr();
     215             : 
     216             : public:
     217             :   static void registerKeywords( Keywords& keys );
     218             :   explicit EMMIVOX(const ActionOptions&);
     219             : // active methods:
     220             :   void prepare() override;
     221             :   void calculate() override;
     222             : };
     223             : 
     224             : PLUMED_REGISTER_ACTION(EMMIVOX,"EMMIVOX")
     225             : 
     226           2 : void EMMIVOX::registerKeywords( Keywords& keys ) {
     227           2 :   Colvar::registerKeywords( keys );
     228           4 :   keys.add("atoms","ATOMS","atoms used in the calculation of the density map, typically all heavy atoms");
     229           4 :   keys.add("compulsory","DATA_FILE","file with cryo-EM map");
     230           4 :   keys.add("compulsory","RESOLUTION", "cryo-EM map resolution");
     231           4 :   keys.add("compulsory","NORM_DENSITY","integral of experimental density");
     232           4 :   keys.add("compulsory","WRITE_STRIDE","stride for writing status file");
     233           4 :   keys.add("optional","NL_DIST_CUTOFF","neighbor list distance cutoff");
     234           4 :   keys.add("optional","NL_GAUSS_CUTOFF","neighbor list Gaussian sigma cutoff");
     235           4 :   keys.add("optional","NL_STRIDE","neighbor list update frequency");
     236           4 :   keys.add("optional","SIGMA_MIN","minimum density error");
     237           4 :   keys.add("optional","DBFACT","Bfactor MC step");
     238           4 :   keys.add("optional","BFACT_MAX","Bfactor maximum value");
     239           4 :   keys.add("optional","MCBFACT_STRIDE", "Bfactor MC stride");
     240           4 :   keys.add("optional","BFACT_SIGMA","Bfactor sigma prior");
     241           4 :   keys.add("optional","STATUS_FILE","write a file with all the data useful for restart");
     242           4 :   keys.add("optional","SCALE","scale factor");
     243           4 :   keys.add("optional","OFFSET","offset");
     244           4 :   keys.add("optional","TEMP","temperature");
     245           4 :   keys.add("optional","WRITE_MAP","file with model density");
     246           4 :   keys.add("optional","WRITE_MAP_STRIDE","stride for writing model density to file");
     247           4 :   keys.addFlag("NO_AVER",false,"no ensemble averaging in multi-replica mode");
     248           4 :   keys.addFlag("CORRELATION",false,"calculate correlation coefficient");
     249           4 :   keys.addFlag("GPU",false,"calculate EMMIVOX on GPU with Libtorch");
     250           4 :   keys.addFlag("BFACT_NOCHAIN",false,"Do not use chain ID for Bfactor MC");
     251           4 :   keys.addFlag("BFACT_READ",false,"Read Bfactor on RESTART (automatic with DBFACT>0)");
     252           4 :   keys.addFlag("BFACT_MINIMIZE",false,"Accept only moves that decrease energy");
     253           4 :   keys.addFlag("MARTINI",false,"Use Martini scattering factors");
     254           4 :   keys.addOutputComponent("scoreb","default","scalar","Bayesian score");
     255           4 :   keys.addOutputComponent("scale", "default","scalar","scale factor");
     256           4 :   keys.addOutputComponent("offset","default","scalar","offset");
     257           4 :   keys.addOutputComponent("accB",  "default","scalar", "Bfactor MC acceptance");
     258           4 :   keys.addOutputComponent("kbt",   "default","scalar", "temperature in energy unit");
     259           4 :   keys.addOutputComponent("corr",  "CORRELATION","scalar", "correlation coefficient");
     260           2 : }
     261             : 
     262           0 : EMMIVOX::EMMIVOX(const ActionOptions&ao):
     263             :   PLUMED_COLVAR_INIT(ao),
     264           0 :   nl_dist_cutoff_(1.0), nl_gauss_cutoff_(3.0), nl_stride_(50),
     265           0 :   first_time_(true), no_aver_(false), do_corr_(false),
     266           0 :   scale_(1.), offset_(0.),
     267           0 :   dbfact_(0.0), bfactmin_(0.05), bfactmax_(5.0),
     268           0 :   bfactsig_(0.1), bfactnoc_(false), bfactread_(false),
     269           0 :   MCBstride_(1), MCBaccept_(0.), MCBtrials_(0.), bfactemin_(false),
     270           0 :   martini_(false), statusstride_(0), first_status_(true),
     271           0 :   eps_(0.0001), mapstride_(0), gpu_(false)
     272             : {
     273             :   // set constants
     274           0 :   inv_sqrt2_ = 1.0/sqrt(2.0);
     275           0 :   sqrt2_pi_  = sqrt(2.0 / M_PI);
     276           0 :   inv_pi2_   = 0.5 / M_PI / M_PI;
     277             : 
     278             :   // list of atoms
     279             :   std::vector<AtomNumber> atoms;
     280           0 :   parseAtomList("ATOMS", atoms);
     281             : 
     282             :   // file with experimental cryo-EM map
     283             :   std::string datafile;
     284           0 :   parse("DATA_FILE", datafile);
     285             : 
     286             :   // neighbor list cutoffs
     287           0 :   parse("NL_DIST_CUTOFF",nl_dist_cutoff_);
     288           0 :   parse("NL_GAUSS_CUTOFF",nl_gauss_cutoff_);
     289             :   // checks
     290           0 :   if(nl_dist_cutoff_<=0. && nl_gauss_cutoff_<=0.) error("You must specify either NL_DIST_CUTOFF or NL_GAUSS_CUTOFF or both");
     291           0 :   if(nl_gauss_cutoff_<=0.) nl_gauss_cutoff_ = 1.0e+10;
     292           0 :   if(nl_dist_cutoff_<=0.) nl_dist_cutoff_ = 1.0e+10;
     293             :   // neighbor list update stride
     294           0 :   parse("NL_STRIDE",nl_stride_);
     295           0 :   if(nl_stride_<=0) error("NL_STRIDE must be explicitly specified and positive");
     296             : 
     297             :   // minimum value for error
     298           0 :   double sigma_min = 0.2;
     299           0 :   parse("SIGMA_MIN", sigma_min);
     300           0 :   if(sigma_min<0.) error("SIGMA_MIN must be greater or equal to zero");
     301             : 
     302             :   // status file parameters
     303           0 :   parse("WRITE_STRIDE", statusstride_);
     304           0 :   if(statusstride_<=0) error("you must specify a positive WRITE_STRIDE");
     305           0 :   parse("STATUS_FILE",  statusfilename_);
     306           0 :   if(statusfilename_=="") statusfilename_ = "EMMIStatus"+getLabel();
     307             : 
     308             :   // integral of the experimetal density
     309             :   double norm_d;
     310           0 :   parse("NORM_DENSITY", norm_d);
     311             : 
     312             :   // temperature
     313           0 :   kbt_ = getkBT();
     314             : 
     315             :   // scale and offset
     316           0 :   parse("SCALE", scale_);
     317           0 :   parse("OFFSET",offset_);
     318             : 
     319             :   // B-factors MC
     320           0 :   parse("DBFACT",dbfact_);
     321             :   // read Bfactors
     322           0 :   parseFlag("BFACT_READ",bfactread_);
     323             :   // do not use chains
     324           0 :   parseFlag("BFACT_NOCHAIN",bfactnoc_);
     325             :   // other parameters
     326           0 :   if(dbfact_>0.) {
     327           0 :     parse("MCBFACT_STRIDE",MCBstride_);
     328           0 :     parse("BFACT_MAX",bfactmax_);
     329           0 :     parse("BFACT_SIGMA",bfactsig_);
     330           0 :     parseFlag("BFACT_MINIMIZE",bfactemin_);
     331             :     // checks
     332           0 :     if(MCBstride_<=0) error("you must specify a positive MCBFACT_STRIDE");
     333           0 :     if(bfactmax_<=bfactmin_) error("you must specify a positive BFACT_MAX");
     334           0 :     if(MCBstride_%nl_stride_!=0) error("MCBFACT_STRIDE must be multiple of NL_STRIDE");
     335           0 :     if(bfactsig_<=0.) error("you must specify a positive BFACT_SIGMA");
     336             :   }
     337             : 
     338             :   // read map resolution
     339             :   double reso;
     340           0 :   parse("RESOLUTION", reso);
     341           0 :   if(reso<=0.) error("RESOLUTION must be strictly positive");
     342             : 
     343             :   // averaging or not
     344           0 :   parseFlag("NO_AVER",no_aver_);
     345             : 
     346             :   // calculate correlation coefficient
     347           0 :   parseFlag("CORRELATION",do_corr_);
     348             : 
     349             :   // write density file
     350           0 :   parse("WRITE_MAP_STRIDE", mapstride_);
     351           0 :   parse("WRITE_MAP", mapfilename_);
     352           0 :   if(mapstride_>0 && mapfilename_=="") error("With WRITE_MAP_STRIDE you must specify WRITE_MAP");
     353             : 
     354             :   // use GPU?
     355           0 :   parseFlag("GPU",gpu_);
     356             :   // set device
     357           0 :   if (gpu_ && torch::cuda::is_available()) {
     358           0 :     device_t_ = torch::kCUDA;
     359             :   } else {
     360           0 :     device_t_ = torch::kCPU;
     361           0 :     gpu_ = false;
     362             :   }
     363             : 
     364             : // Martini model
     365           0 :   parseFlag("MARTINI",martini_);
     366             : 
     367             :   // check read
     368           0 :   checkRead();
     369             : 
     370             :   // set parallel stuff
     371           0 :   unsigned mpisize=comm.Get_size();
     372           0 :   if(mpisize>1) error("EMMIVOX supports only OpenMP parallelization");
     373             : 
     374             :   // get number of replicas
     375           0 :   if(no_aver_) {
     376           0 :     nrep_ = 1;
     377             :   } else {
     378           0 :     nrep_ = multi_sim_comm.Get_size();
     379             :   }
     380           0 :   replica_ = multi_sim_comm.Get_rank();
     381             : 
     382           0 :   if(nrep_>1 && dbfact_>0) error("Bfactor sampling not supported with ensemble averaging");
     383             : 
     384           0 :   log.printf("  number of atoms involved : %u\n", atoms.size());
     385           0 :   log.printf("  experimental density map : %s\n", datafile.c_str());
     386           0 :   if(no_aver_) log.printf("  without ensemble averaging\n");
     387           0 :   if(gpu_) {log.printf("  running on GPU \n");}
     388           0 :   else {log.printf("  running on CPU \n");}
     389           0 :   if(nl_dist_cutoff_ <1.0e+10) log.printf("  neighbor list distance cutoff : %lf\n", nl_dist_cutoff_);
     390           0 :   if(nl_gauss_cutoff_<1.0e+10) log.printf("  neighbor list Gaussian sigma cutoff : %lf\n", nl_gauss_cutoff_);
     391           0 :   log.printf("  neighbor list update stride : %u\n",  nl_stride_);
     392           0 :   log.printf("  minimum density error : %f\n", sigma_min);
     393           0 :   log.printf("  scale factor : %lf\n", scale_);
     394           0 :   log.printf("  offset : %lf\n", offset_);
     395           0 :   log.printf("  reading/writing to status file : %s\n", statusfilename_.c_str());
     396           0 :   log.printf("  with stride : %u\n", statusstride_);
     397           0 :   if(dbfact_>0) {
     398           0 :     log.printf("  maximum Bfactor MC move : %f\n", dbfact_);
     399           0 :     log.printf("  stride MC move : %u\n", MCBstride_);
     400           0 :     log.printf("  using prior with sigma : %f\n", bfactsig_);
     401             :   }
     402           0 :   if(bfactread_) log.printf("  reading Bfactors from file : %s\n", statusfilename_.c_str());
     403           0 :   log.printf("  temperature of the system in energy unit : %f\n", kbt_);
     404           0 :   if(nrep_>1) {
     405           0 :     log.printf("  number of replicas for averaging: %u\n", nrep_);
     406           0 :     log.printf("  replica ID : %u\n", replica_);
     407             :   }
     408           0 :   if(mapstride_>0) {
     409           0 :     log.printf("  writing model density to file : %s\n", mapfilename_.c_str());
     410           0 :     log.printf("  with stride : %u\n", mapstride_);
     411             :   }
     412           0 :   if(martini_) log.printf("  using Martini scattering factors\n");
     413             : 
     414             :   // calculate model constant parameters
     415           0 :   std::vector<double> Model_w = get_Model_param(atoms);
     416             : 
     417             :   // read experimental map and errors
     418           0 :   get_exp_data(datafile);
     419           0 :   log.printf("  number of voxels : %u\n", static_cast<unsigned>(ovdd_.size()));
     420             : 
     421             :   // normalize atom weight map
     422             :   double norm_m = accumulate(Model_w.begin(),  Model_w.end(),  0.0);
     423             :   // renormalization and constant factor on atom types
     424           0 :   for(unsigned i=0; i<Model_w_.size(); ++i) {
     425           0 :     Vector5d cf;
     426           0 :     for(unsigned j=0; j<5; ++j) {
     427           0 :       Model_w_[i][j] *= norm_d / norm_m;
     428           0 :       cf[j] = Model_w_[i][j]/pow( 2.0*pi, 1.5 );
     429             :     }
     430           0 :     cfact_.push_back(cf);
     431             :   }
     432             : 
     433             :   // median density
     434           0 :   double ovdd_m = get_median(ovdd_);
     435             :   // median experimental error
     436           0 :   double err_m  = get_median(exp_err_);
     437             :   // minimum error
     438           0 :   double minerr = sigma_min*ovdd_m;
     439             :   // print out statistics
     440           0 :   log.printf("     median density : %lf\n", ovdd_m);
     441           0 :   log.printf("     minimum error  : %lf\n", minerr);
     442           0 :   log.printf("     median error   : %lf\n", err_m);
     443             :   // populate ismin: cycle on all voxels
     444           0 :   for(unsigned id=0; id<ovdd_.size(); ++id) {
     445             :     // define smin
     446           0 :     double smin = std::max(minerr, exp_err_[id]);
     447             :     // and to ismin_
     448           0 :     ismin_.push_back(1.0/smin);
     449             :   }
     450             : 
     451             :   // prepare gpu stuff: map centers, data, and error
     452           0 :   prepare_gpu();
     453             : 
     454             :   // initialize Bfactors
     455           0 :   initialize_Bfactor(reso);
     456             : 
     457             :   // read status file if restarting
     458           0 :   if(getRestart() || bfactread_) read_status();
     459             : 
     460             :   // prepare auxiliary vectors
     461           0 :   get_auxiliary_vectors();
     462             : 
     463             :   // prepare other vectors: data and derivatives
     464           0 :   ovmd_.resize(ovdd_.size());
     465           0 :   atom_der_.resize(Model_type_.size());
     466           0 :   score_der_.resize(ovdd_.size());
     467             : 
     468             :   // add components
     469           0 :   addComponentWithDerivatives("scoreb"); componentIsNotPeriodic("scoreb");
     470           0 :   addComponent("scale");                 componentIsNotPeriodic("scale");
     471           0 :   addComponent("offset");                componentIsNotPeriodic("offset");
     472           0 :   addComponent("kbt");                   componentIsNotPeriodic("kbt");
     473           0 :   if(dbfact_>0)   {addComponent("accB"); componentIsNotPeriodic("accB");}
     474           0 :   if(do_corr_)    {addComponent("corr"); componentIsNotPeriodic("corr");}
     475             : 
     476             :   // initialize random seed
     477           0 :   unsigned iseed = time(NULL)+replica_;
     478           0 :   random_.setSeed(-iseed);
     479             : 
     480             :   // request atoms
     481           0 :   requestAtoms(atoms);
     482             : 
     483             :   // print bibliography
     484           0 :   log<<"  Bibliography "<<plumed.cite("Bonomi, Camilloni, Bioinformatics, 33, 3999 (2017)");
     485           0 :   log<<plumed.cite("Hoff, Thomasen, Lindorff-Larsen, Bonomi, bioRxiv (2023) doi: 10.1101/2023.10.18.562710");
     486           0 :   if(!no_aver_ && nrep_>1)log<<plumed.cite("Bonomi, Camilloni, Cavalli, Vendruscolo, Sci. Adv. 2, e150117 (2016)");
     487           0 :   log<<"\n";
     488           0 : }
     489             : 
     490           0 : void EMMIVOX::prepare_gpu()
     491             : {
     492             :   // number of data points
     493           0 :   int nd = ovdd_.size();
     494             :   // 1) put ismin_ on device_t_
     495           0 :   ismin_gpu_ = torch::from_blob(ismin_.data(), {nd}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
     496             :   // 2) put ovdd_ on device_t_
     497           0 :   ovdd_gpu_  = torch::from_blob(ovdd_.data(),  {nd}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
     498             :   // 3) put Map_m_ on device_t_
     499           0 :   std::vector<double> Map_m_gpu(3*nd);
     500           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
     501             :   for(int i=0; i<nd; ++i) {
     502             :     Map_m_gpu[i]      = Map_m_[i][0];
     503             :     Map_m_gpu[i+nd]   = Map_m_[i][1];
     504             :     Map_m_gpu[i+2*nd] = Map_m_[i][2];
     505             :   }
     506             :   // libtorch tensor
     507           0 :   Map_m_gpu_ = torch::from_blob(Map_m_gpu.data(), {3,nd}, torch::kFloat64).clone().to(torch::kFloat32).to(device_t_);
     508           0 : }
     509             : 
     510           0 : void EMMIVOX::write_model_density(long int step)
     511             : {
     512           0 :   OFile ovfile;
     513           0 :   ovfile.link(*this);
     514           0 :   std::string num; Tools::convert(step,num);
     515           0 :   std::string name = mapfilename_+"-"+num;
     516           0 :   ovfile.open(name);
     517             :   ovfile.setHeavyFlush();
     518           0 :   ovfile.fmtField("%10.7e ");
     519             : // write density
     520           0 :   for(unsigned i=0; i<ovmd_.size(); ++i) {
     521           0 :     ovfile.printField("Model", ovmd_[i]);
     522           0 :     ovfile.printField("ModelScaled", scale_ * ovmd_[i] + offset_);
     523           0 :     ovfile.printField("Data", ovdd_[i]);
     524           0 :     ovfile.printField();
     525             :   }
     526           0 :   ovfile.close();
     527           0 : }
     528             : 
     529           0 : double EMMIVOX::get_median(std::vector<double> v)
     530             : {
     531             : // dimension of vector
     532           0 :   unsigned size = v.size();
     533             : // in case of only one entry
     534           0 :   if (size==1) {
     535           0 :     return v[0];
     536             :   } else {
     537             :     // reorder vector
     538           0 :     sort(v.begin(), v.end());
     539             :     // odd or even?
     540           0 :     if (size%2==0) {
     541           0 :       return (v[size/2-1]+v[size/2])/2.0;
     542             :     } else {
     543           0 :       return v[size/2];
     544             :     }
     545             :   }
     546             : }
     547             : 
     548           0 : void EMMIVOX::read_status()
     549             : {
     550             :   double MDtime;
     551             : // open file
     552           0 :   IFile *ifile = new IFile();
     553           0 :   ifile->link(*this);
     554           0 :   if(ifile->FileExist(statusfilename_)) {
     555           0 :     ifile->open(statusfilename_);
     556           0 :     while(ifile->scanField("MD_time", MDtime)) {
     557             :       // read scale and offset
     558           0 :       ifile->scanField("scale", scale_);
     559           0 :       ifile->scanField("offset", offset_);
     560             :       // read bfactors if doing fitting of reading it at restart
     561           0 :       if(dbfact_>0 || bfactread_) {
     562             :         // cycle on residues
     563           0 :         for(unsigned ir=0; ir<Model_rlist_.size(); ++ir) {
     564             :           // key: pair of residue/chain IDs
     565             :           std::pair<unsigned,std::string> key = Model_rlist_[ir];
     566             :           // convert ires to std::string
     567           0 :           std::string num; Tools::convert(key.first,num);
     568             :           // read entry
     569           0 :           std::string ch = key.second;
     570           0 :           if(ch==" ") ch="";
     571           0 :           ifile->scanField("bf-"+num+":"+ch, Model_b_[key]);
     572             :         }
     573             :       }
     574             :       // new line
     575           0 :       ifile->scanField();
     576             :     }
     577           0 :     ifile->close();
     578             :   } else {
     579           0 :     error("Cannot find status file "+statusfilename_+"\n");
     580             :   }
     581           0 :   delete ifile;
     582           0 : }
     583             : 
     584           0 : void EMMIVOX::print_status(long int step)
     585             : {
     586             : // if first time open the file
     587           0 :   if(first_status_) {
     588           0 :     first_status_ = false;
     589           0 :     statusfile_.link(*this);
     590           0 :     statusfile_.open(statusfilename_);
     591             :     statusfile_.setHeavyFlush();
     592           0 :     statusfile_.fmtField("%10.7e ");
     593             :   }
     594             : // write fields
     595           0 :   double MDtime = static_cast<double>(step)*getTimeStep();
     596           0 :   statusfile_.printField("MD_time", MDtime);
     597             :   // write scale and offset
     598           0 :   statusfile_.printField("scale", scale_);
     599           0 :   statusfile_.printField("offset", offset_);
     600             :   // write bfactors only if doing fitting or reading bfactors
     601           0 :   if(dbfact_>0 || bfactread_) {
     602             :     // cycle on residues
     603           0 :     for(unsigned ir=0; ir<Model_rlist_.size(); ++ir) {
     604             :       // key: pair of residue/chain IDs
     605             :       std::pair<unsigned,std::string> key = Model_rlist_[ir];
     606             :       // bfactor from map
     607           0 :       double bf = Model_b_[key];
     608             :       // convert ires to std::string
     609           0 :       std::string num; Tools::convert(key.first,num);
     610             :       // print entry
     611           0 :       statusfile_.printField("bf-"+num+":"+key.second, bf);
     612             :     }
     613             :   }
     614           0 :   statusfile_.printField();
     615           0 : }
     616             : 
     617           0 : bool EMMIVOX::doAccept(double oldE, double newE, double kbt)
     618             : {
     619             :   bool accept = false;
     620             :   // calculate delta energy
     621           0 :   double delta = ( newE - oldE ) / kbt;
     622             :   // if delta is negative always accept move
     623           0 :   if( delta < 0.0 ) {
     624             :     accept = true;
     625             :   } else {
     626             :     // otherwise extract random number
     627           0 :     double s = random_.RandU01();
     628           0 :     if( s < exp(-delta) ) { accept = true; }
     629             :   }
     630           0 :   return accept;
     631             : }
     632             : 
     633           0 : std::vector<double> EMMIVOX::get_Model_param(std::vector<AtomNumber> &atoms)
     634             : {
     635             :   // check if MOLINFO line is present
     636           0 :   auto* moldat=plumed.getActionSet().selectLatest<GenericMolInfo*>(this);
     637           0 :   if(!moldat) error("MOLINFO DATA not found\n");
     638           0 :   log<<"  MOLINFO DATA found with label " <<moldat->getLabel()<<", using proper atom names\n";
     639             : 
     640             :   // list of weights - one per atom
     641             :   std::vector<double> Model_w;
     642             :   // 5-Gaussians parameters
     643             :   // map of atom types to A and B coefficients of scattering factor
     644             :   // f(s) = A * exp(-B*s**2)
     645             :   // B is in Angstrom squared
     646             :   // Elastic atomic scattering factors of electrons for neutral atoms
     647             :   // and s up to 6.0 A^-1: as implemented in PLUMED
     648             :   // map between an atom type and an index
     649             :   std::map<std::string, unsigned> type_map;
     650             :   // atomistic types
     651           0 :   type_map["C"]=0;  type_map["O"]=1;  type_map["N"]=2;
     652           0 :   type_map["S"]=3;  type_map["P"]=4;  type_map["F"]=5;
     653           0 :   type_map["NA"]=6; type_map["MG"]=7; type_map["CL"]=8;
     654           0 :   type_map["CA"]=9; type_map["K"]=10; type_map["ZN"]=11;
     655             :   // Martini types
     656           0 :   type_map["ALA_BB"]=12;    type_map["ALA_SC1"]=13;   type_map["CYS_BB"]=14;
     657           0 :   type_map["CYS_SC1"]=15;   type_map["ASP_BB"]=16;    type_map["ASP_SC1"]=17;
     658           0 :   type_map["GLU_BB"]=18;    type_map["GLU_SC1"]=19;   type_map["PHE_BB"]=20;
     659           0 :   type_map["PHE_SC1"]=21;   type_map["PHE_SC2"]=22;   type_map["PHE_SC3"]=23;
     660           0 :   type_map["GLY_BB"]=24;    type_map["HIS_BB"]=25;    type_map["HIS_SC1"]=26;
     661           0 :   type_map["HIS_SC2"]=27;   type_map["HIS_SC3"]=28;   type_map["ILE_BB"]=29;
     662           0 :   type_map["ILE_SC1"]=30;   type_map["LYS_BB"]=31;    type_map["LYS_SC1"]=32;
     663           0 :   type_map["LYS_SC2"]=33;   type_map["LEU_BB"]=34;    type_map["LEU_SC1"]=35;
     664           0 :   type_map["MET_BB"]=36;    type_map["MET_SC1"]=37;   type_map["ASN_BB"]=38;
     665           0 :   type_map["ASN_SC1"]=39;   type_map["PRO_BB"]=40;    type_map["PRO_SC1"]=41;
     666           0 :   type_map["GLN_BB"]=42;    type_map["GLN_SC1"]=43;   type_map["ARG_BB"]=44;
     667           0 :   type_map["ARG_SC1"]=45;   type_map["ARG_SC2"]=46;   type_map["SER_BB"]=47;
     668           0 :   type_map["SER_SC1"]=48;   type_map["THR_BB"]=49;    type_map["THR_SC1"]=50;
     669           0 :   type_map["VAL_BB"]=51;    type_map["VAL_SC1"]=52;   type_map["TRP_BB"]=53;
     670           0 :   type_map["TRP_SC1"]=54;   type_map["TRP_SC2"]=55;   type_map["TRP_SC3"]=56;
     671           0 :   type_map["TRP_SC4"]=57;   type_map["TRP_SC5"]=58;   type_map["TYR_BB"]=59;
     672           0 :   type_map["TYR_SC1"]=60;   type_map["TYR_SC2"]=61;   type_map["TYR_SC3"]=62;
     673           0 :   type_map["TYR_SC4"]=63;
     674             :   // fill in sigma vector for atoms
     675           0 :   Model_s_.push_back(0.01*Vector5d(0.114,1.0825,5.4281,17.8811,51.1341));   // C
     676           0 :   Model_s_.push_back(0.01*Vector5d(0.0652,0.6184,2.9449,9.6298,28.2194));   // O
     677           0 :   Model_s_.push_back(0.01*Vector5d(0.0541,0.5165,2.8207,10.6297,34.3764));  // N
     678           0 :   Model_s_.push_back(0.01*Vector5d(0.0838,0.7788,4.3462,15.5846,44.63655)); // S
     679           0 :   Model_s_.push_back(0.01*Vector5d(0.0977,0.9084,4.9654,18.5471,54.3648));  // P
     680           0 :   Model_s_.push_back(0.01*Vector5d(0.0613,0.5753,2.6858,8.8214,25.6668));   // F
     681           0 :   Model_s_.push_back(0.01*Vector5d(0.1684,1.7150,8.8386,50.8265,147.2073)); // NA
     682           0 :   Model_s_.push_back(0.01*Vector5d(0.1356,1.3579,6.9255,32.3165,92.1138));  // MG
     683           0 :   Model_s_.push_back(0.01*Vector5d(0.0694,0.6443,3.5351,12.5058,35.8633));  // CL
     684           0 :   Model_s_.push_back(0.01*Vector5d(0.1742,1.8329,8.8407,47.4583,134.9613)); // CA
     685           0 :   Model_s_.push_back(0.01*Vector5d(0.1660,1.6906,8.7447,46.7825,165.6923)); // K
     686           0 :   Model_s_.push_back(0.01*Vector5d(0.0876,0.8650,3.8612,18.8726,64.7016));  // ZN
     687             :   // fill in sigma vector for Martini beads
     688           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // ALA_BB
     689           0 :   Model_s_.push_back(0.01*Vector5d(0.500000,0.500000,0.500000,0.500000,0.500000)); // ALA_SC1
     690           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // CYS_BB
     691           0 :   Model_s_.push_back(0.01*Vector5d(8.500000,8.500000,8.500000,8.500000,8.500000)); // CYS_SC1
     692           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // ASP_BB
     693           0 :   Model_s_.push_back(0.01*Vector5d(17.000000,17.000000,17.000000,17.000000,17.000000)); // ASP_SC1
     694           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // GLU_BB
     695           0 :   Model_s_.push_back(0.01*Vector5d(24.000000,24.000000,24.000000,24.000000,24.000000)); // GLU_SC1
     696           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // PHE_BB
     697           0 :   Model_s_.push_back(0.01*Vector5d(17.500000,17.500000,17.500000,17.500000,17.500000)); // PHE_SC1
     698           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // PHE_SC2
     699           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // PHE_SC3
     700           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // GLY_BB
     701           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // HIS_BB
     702           0 :   Model_s_.push_back(0.01*Vector5d(11.500000,11.500000,11.500000,11.500000,11.500000)); // HIS_SC1
     703           0 :   Model_s_.push_back(0.01*Vector5d(9.000000,9.000000,9.000000,9.000000,9.000000)); // HIS_SC2
     704           0 :   Model_s_.push_back(0.01*Vector5d(8.500000,8.500000,8.500000,8.500000,8.500000)); // HIS_SC3
     705           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // ILE_BB
     706           0 :   Model_s_.push_back(0.01*Vector5d(25.500000,25.500000,25.500000,25.500000,25.500000)); // ILE_SC1
     707           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // LYS_BB
     708           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // LYS_SC1
     709           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // LYS_SC2
     710           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // LEU_BB
     711           0 :   Model_s_.push_back(0.01*Vector5d(21.500000,21.500000,21.500000,21.500000,21.500000)); // LEU_SC1
     712           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // MET_BB
     713           0 :   Model_s_.push_back(0.01*Vector5d(22.500000,22.500000,22.500000,22.500000,22.500000)); // MET_SC1
     714           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // ASN_BB
     715           0 :   Model_s_.push_back(0.01*Vector5d(18.500000,18.500000,18.500000,18.500000,18.500000)); // ASN_SC1
     716           0 :   Model_s_.push_back(0.01*Vector5d(23.500000,23.500000,23.500000,23.500000,23.500000)); // PRO_BB
     717           0 :   Model_s_.push_back(0.01*Vector5d(17.500000,17.500000,17.500000,17.500000,17.500000)); // PRO_SC1
     718           0 :   Model_s_.push_back(0.01*Vector5d(22.000000,22.000000,22.000000,22.000000,22.000000)); // GLN_BB
     719           0 :   Model_s_.push_back(0.01*Vector5d(24.500000,24.500000,24.500000,24.500000,24.500000)); // GLN_SC1
     720           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // ARG_BB
     721           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // ARG_SC1
     722           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // ARG_SC2
     723           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // SER_BB
     724           0 :   Model_s_.push_back(0.01*Vector5d(9.000000,9.000000,9.000000,9.000000,9.000000)); // SER_SC1
     725           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // THR_BB
     726           0 :   Model_s_.push_back(0.01*Vector5d(17.000000,17.000000,17.000000,17.000000,17.000000)); // THR_SC1
     727           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // VAL_BB
     728           0 :   Model_s_.push_back(0.01*Vector5d(18.000000,18.000000,18.000000,18.000000,18.000000)); // VAL_SC1
     729           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // TRP_BB
     730           0 :   Model_s_.push_back(0.01*Vector5d(11.500000,11.500000,11.500000,11.500000,11.500000)); // TRP_SC1
     731           0 :   Model_s_.push_back(0.01*Vector5d(9.000000,9.000000,9.000000,9.000000,9.000000)); // TRP_SC2
     732           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TRP_SC3
     733           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TRP_SC4
     734           0 :   Model_s_.push_back(0.01*Vector5d(9.500000,9.500000,9.500000,9.500000,9.500000)); // TRP_SC5
     735           0 :   Model_s_.push_back(0.01*Vector5d(23.000000,23.000000,23.000000,23.000000,23.000000)); // TYR_BB
     736           0 :   Model_s_.push_back(0.01*Vector5d(12.000000,12.000000,12.000000,12.000000,12.000000)); // TYR_SC1
     737           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TYR_SC2
     738           0 :   Model_s_.push_back(0.01*Vector5d(11.000000,11.000000,11.000000,11.000000,11.000000)); // TYR_SC3
     739           0 :   Model_s_.push_back(0.01*Vector5d(8.500000,8.500000,8.500000,8.500000,8.500000)); // TYR_SC4
     740             :   // fill in weight vector for atoms
     741           0 :   Model_w_.push_back(Vector5d(0.0489,0.2091,0.7537,1.1420,0.3555)); // C
     742           0 :   Model_w_.push_back(Vector5d(0.0365,0.1729,0.5805,0.8814,0.3121)); // O
     743           0 :   Model_w_.push_back(Vector5d(0.0267,0.1328,0.5301,1.1020,0.4215)); // N
     744           0 :   Model_w_.push_back(Vector5d(0.0915,0.4312,1.0847,2.4671,1.0852)); // S
     745           0 :   Model_w_.push_back(Vector5d(0.1005,0.4615,1.0663,2.5854,1.2725)); // P
     746           0 :   Model_w_.push_back(Vector5d(0.0382,0.1822,0.5972,0.7707,0.2130)); // F
     747           0 :   Model_w_.push_back(Vector5d(0.1260,0.6442,0.8893,1.8197,1.2988)); // NA
     748           0 :   Model_w_.push_back(Vector5d(0.1130,0.5575,0.9046,2.1580,1.4735)); // MG
     749           0 :   Model_w_.push_back(Vector5d(0.0799,0.3891,1.0037,2.3332,1.0507)); // CL
     750           0 :   Model_w_.push_back(Vector5d(0.2355,0.9916,2.3959,3.7252,2.5647)); // CA
     751           0 :   Model_w_.push_back(Vector5d(0.2149,0.8703,2.4999,2.3591,3.0318)); // K
     752           0 :   Model_w_.push_back(Vector5d(0.1780,0.8096,1.6744,1.9499,1.4495)); // ZN
     753             :   // fill in weight vector for Martini beads
     754           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ALA_BB
     755           0 :   Model_w_.push_back(Vector5d(0.100000,0.100000,0.100000,0.100000,0.100000)); // ALA_SC1
     756           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // CYS_BB
     757           0 :   Model_w_.push_back(Vector5d(1.100000,1.100000,1.100000,1.100000,1.100000)); // CYS_SC1
     758           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // ASP_BB
     759           0 :   Model_w_.push_back(Vector5d(1.700000,1.700000,1.700000,1.700000,1.700000)); // ASP_SC1
     760           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // GLU_BB
     761           0 :   Model_w_.push_back(Vector5d(2.300000,2.300000,2.300000,2.300000,2.300000)); // GLU_SC1
     762           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // PHE_BB
     763           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // PHE_SC1
     764           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // PHE_SC2
     765           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // PHE_SC3
     766           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // GLY_BB
     767           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // HIS_BB
     768           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // HIS_SC1
     769           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // HIS_SC2
     770           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // HIS_SC3
     771           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // ILE_BB
     772           0 :   Model_w_.push_back(Vector5d(2.000000,2.000000,2.000000,2.000000,2.000000)); // ILE_SC1
     773           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // LYS_BB
     774           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // LYS_SC1
     775           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // LYS_SC2
     776           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // LEU_BB
     777           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // LEU_SC1
     778           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // MET_BB
     779           0 :   Model_w_.push_back(Vector5d(2.300000,2.300000,2.300000,2.300000,2.300000)); // MET_SC1
     780           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ASN_BB
     781           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ASN_SC1
     782           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // PRO_BB
     783           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // PRO_SC1
     784           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // GLN_BB
     785           0 :   Model_w_.push_back(Vector5d(2.300000,2.300000,2.300000,2.300000,2.300000)); // GLN_SC1
     786           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // ARG_BB
     787           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // ARG_SC1
     788           0 :   Model_w_.push_back(Vector5d(1.800000,1.800000,1.800000,1.800000,1.800000)); // ARG_SC2
     789           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // SER_BB
     790           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // SER_SC1
     791           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // THR_BB
     792           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // THR_SC1
     793           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // VAL_BB
     794           0 :   Model_w_.push_back(Vector5d(1.400000,1.400000,1.400000,1.400000,1.400000)); // VAL_SC1
     795           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // TRP_BB
     796           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TRP_SC1
     797           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // TRP_SC2
     798           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TRP_SC3
     799           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TRP_SC4
     800           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // TRP_SC5
     801           0 :   Model_w_.push_back(Vector5d(1.900000,1.900000,1.900000,1.900000,1.900000)); // TYR_BB
     802           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TYR_SC1
     803           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TYR_SC2
     804           0 :   Model_w_.push_back(Vector5d(0.900000,0.900000,0.900000,0.900000,0.900000)); // TYR_SC3
     805           0 :   Model_w_.push_back(Vector5d(0.800000,0.800000,0.800000,0.800000,0.800000)); // TYR_SC4
     806             :   // cycle on atoms
     807           0 :   for(unsigned i=0; i<atoms.size(); ++i) {
     808             :     // get atom name
     809           0 :     std::string name = moldat->getAtomName(atoms[i]);
     810             :     // get residue name
     811           0 :     std::string resname = moldat->getResidueName(atoms[i]);
     812             :     // type of atoms/bead
     813             :     std::string type_s;
     814             :     // Martini model
     815           0 :     if(martini_) {
     816           0 :       type_s = resname+"_"+name;
     817             :       // Atomistic model
     818             :     } else {
     819             :       char type;
     820             :       // get atom type
     821           0 :       char first = name.at(0);
     822             :       // GOLDEN RULE: type is first letter, if not a number
     823           0 :       if (!isdigit(first)) {
     824             :         type = first;
     825             :         // otherwise is the second
     826             :       } else {
     827           0 :         type = name.at(1);
     828             :       }
     829             :       // convert to std::string
     830           0 :       type_s = std::string(1,type);
     831             :       // special cases
     832           0 :       if(name=="SOD" || name=="NA" || name =="Na") type_s = "NA";
     833           0 :       if(name=="MG"  || name=="Mg")                type_s = "MG";
     834           0 :       if(name=="CLA" || name=="CL" || name =="Cl") type_s = "CL";
     835           0 :       if((resname=="CAL" || resname=="CA") && (name=="CAL" || name=="CA" || name =="C0")) type_s = "CA";
     836           0 :       if(name=="POT" || name=="K")                 type_s = "K";
     837           0 :       if(name=="ZN"  || name=="Zn")                type_s = "ZN";
     838             :     }
     839             :     // check if key in map
     840           0 :     if(type_map.find(type_s) != type_map.end()) {
     841             :       // save atom type
     842           0 :       Model_type_.push_back(type_map[type_s]);
     843             :       // this will be normalized in the final density
     844           0 :       Vector5d w = Model_w_[type_map[type_s]];
     845           0 :       Model_w.push_back(w[0]+w[1]+w[2]+w[3]+w[4]);
     846             :       // get residue id
     847           0 :       unsigned ires = moldat->getResidueNumber(atoms[i]);
     848             :       // and chain
     849           0 :       std::string c ("*");
     850           0 :       if(!bfactnoc_) c = moldat->getChainID(atoms[i]);
     851             :       // define pair residue/chain IDs
     852             :       std::pair<unsigned,std::string> key = std::make_pair(ires,c);
     853             :       // add to map between residue/chain and list of atoms
     854           0 :       Model_resmap_[key].push_back(i);
     855             :       // and global list of residue/chain per atom
     856           0 :       Model_res_.push_back(key);
     857             :       // initialize Bfactor map
     858           0 :       Model_b_[key] = 0.0;
     859             :     } else {
     860           0 :       error("Wrong atom type "+type_s+" from atom name "+name+"\n");
     861             :     }
     862             :   }
     863             :   // create ordered vector of residue-chain IDs
     864           0 :   for(unsigned i=0; i<Model_res_.size(); ++i) {
     865             :     std::pair<unsigned,std::string> key = Model_res_[i];
     866             :     // search in Model_rlist_
     867           0 :     if(find(Model_rlist_.begin(), Model_rlist_.end(), key) == Model_rlist_.end())
     868           0 :       Model_rlist_.push_back(key);
     869             :   }
     870             :   // return weights
     871           0 :   return Model_w;
     872             : }
     873             : 
     874             : // read experimental data file in PLUMED format:
     875           0 : void EMMIVOX::get_exp_data(const std::string &datafile)
     876             : {
     877           0 :   Vector pos;
     878             :   double dens, err;
     879             :   int idcomp;
     880             : 
     881             : // open file
     882           0 :   IFile *ifile = new IFile();
     883           0 :   if(ifile->FileExist(datafile)) {
     884           0 :     ifile->open(datafile);
     885           0 :     while(ifile->scanField("Id",idcomp)) {
     886           0 :       ifile->scanField("Pos_0",pos[0]);
     887           0 :       ifile->scanField("Pos_1",pos[1]);
     888           0 :       ifile->scanField("Pos_2",pos[2]);
     889           0 :       ifile->scanField("Density",dens);
     890           0 :       ifile->scanField("Error",err);
     891             :       // voxel center
     892           0 :       Map_m_.push_back(pos);
     893             :       // experimental density
     894           0 :       ovdd_.push_back(dens);
     895             :       // error
     896           0 :       exp_err_.push_back(err);
     897             :       // new line
     898           0 :       ifile->scanField();
     899             :     }
     900           0 :     ifile->close();
     901             :   } else {
     902           0 :     error("Cannot find DATA_FILE "+datafile+"\n");
     903             :   }
     904           0 :   delete ifile;
     905           0 : }
     906             : 
     907           0 : void EMMIVOX::initialize_Bfactor(double reso)
     908             : {
     909           0 :   double bfactini = 0.0;
     910             :   // if doing Bfactor Monte Carlo
     911           0 :   if(dbfact_>0) {
     912             :     // initialize B factor based on empirical relation between resolution and average bfactor
     913             :     // calculated on ~8000 cryo-EM data with resolution < 5 Ang
     914             :     // Bfact = A*reso**2+B; with A=6.95408 B=-2.45697/100.0 nm^2
     915           0 :     bfactini = 6.95408*reso*reso - 0.01*2.45697;
     916             :     // check for min and max
     917           0 :     bfactini = std::min(bfactmax_, std::max(bfactmin_, bfactini));
     918             :   }
     919             :   // set initial Bfactor
     920           0 :   for(std::map< std::pair<unsigned,std::string>, double>::iterator it=Model_b_.begin(); it!=Model_b_.end(); ++it) {
     921           0 :     it->second = bfactini;
     922             :   }
     923           0 :   log.printf("  experimental map resolution : %3.2f\n", reso);
     924             :   // if doing Bfactor Monte Carlo
     925           0 :   if(dbfact_>0) {
     926           0 :     log.printf("  minimum Bfactor value : %3.2f\n", bfactmin_);
     927           0 :     log.printf("  maximum Bfactor value : %3.2f\n", bfactmax_);
     928           0 :     log.printf("  initial Bfactor value : %3.2f\n", bfactini);
     929             :   }
     930           0 : }
     931             : 
     932             : // prepare auxiliary vectors
     933           0 : void EMMIVOX::get_auxiliary_vectors()
     934             : {
     935             : // number of atoms
     936           0 :   unsigned natoms = Model_res_.size();
     937             : // clear lists
     938           0 :   pref_.clear(); invs2_.clear(); cut_.clear();
     939             : // resize
     940           0 :   pref_.resize(natoms); invs2_.resize(natoms); cut_.resize(natoms);
     941             : // cycle on all atoms
     942           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
     943             :   for(unsigned im=0; im<natoms; ++im) {
     944             :     // get atom type
     945             :     unsigned atype = Model_type_[im];
     946             :     // get residue/chain IDs
     947             :     std::pair<unsigned,std::string> key = Model_res_[im];
     948             :     // get bfactor
     949             :     double bfact = Model_b_[key];
     950             :     // sigma for 5 gaussians
     951             :     Vector5d m_s = Model_s_[atype];
     952             :     // calculate constant quantities
     953             :     Vector5d pref, invs2;
     954             :     // calculate cutoff
     955             :     double n = 0.0;
     956             :     double d = 0.0;
     957             :     for(unsigned j=0; j<5; ++j) {
     958             :       // total value of b
     959             :       double m_b = m_s[j] + bfact/4.0;
     960             :       // calculate invs2
     961             :       invs2[j] = 1.0/(inv_pi2_*m_b);
     962             :       // prefactor
     963             :       pref[j]  = cfact_[atype][j] * pow(invs2[j],1.5);
     964             :       // cutoff
     965             :       n += pref[j] / invs2[j];
     966             :       d += pref[j];
     967             :     }
     968             :     // put into global lists
     969             :     pref_[im]  = pref;
     970             :     invs2_[im] = invs2;
     971             :     cut_[im] = std::min(nl_dist_cutoff_, sqrt(n/d)*nl_gauss_cutoff_);
     972             :   }
     973             :   // push to GPU
     974           0 :   push_auxiliary_gpu();
     975           0 : }
     976             : 
     977           0 : void EMMIVOX::push_auxiliary_gpu()
     978             : {
     979             :   // 1) create vector of pref_ and invs2_
     980           0 :   int natoms = Model_type_.size();
     981           0 :   std::vector<double> pref(5*natoms), invs2(5*natoms);
     982           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
     983             :   for(int i=0; i<natoms; ++i) {
     984             :     for(int j=0; j<5; ++j) {
     985             :       pref[i+j*natoms]  = pref_[i][j];
     986             :       invs2[i+j*natoms] = invs2_[i][j];
     987             :     }
     988             :   }
     989             :   // 2) initialize gpu tensors
     990           0 :   pref_gpu_  = torch::from_blob(pref.data(),  {5,natoms}, torch::kFloat64).clone().to(torch::kFloat32).to(device_t_);
     991           0 :   invs2_gpu_ = torch::from_blob(invs2.data(), {5,natoms}, torch::kFloat64).clone().to(torch::kFloat32).to(device_t_);
     992           0 : }
     993             : 
     994           0 : void EMMIVOX::get_close_residues()
     995             : {
     996             :   // clear neighbor list
     997           0 :   nl_res_.clear(); nl_res_.resize(Model_rlist_.size());
     998             : 
     999             :   // loop in parallel
    1000           0 :   #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1001             :   {
    1002             :     // private variable
    1003             :     std::vector< std::vector<unsigned> > nl_res_l(Model_rlist_.size());
    1004             :     // cycle on residues/chains #1
    1005             :     #pragma omp for
    1006             :     for(unsigned i=0; i<Model_rlist_.size()-1; ++i) {
    1007             : 
    1008             :       // key1: pair of residue/chain IDs
    1009             :       std::pair<unsigned,std::string> key1 = Model_rlist_[i];
    1010             : 
    1011             :       // cycle over residues/chains #2
    1012             :       for(unsigned j=i+1; j<Model_rlist_.size(); ++j) {
    1013             : 
    1014             :         // key2: pair of residue/chain IDs
    1015             :         std::pair<unsigned,std::string> key2 = Model_rlist_[j];
    1016             : 
    1017             :         // set flag neighbor
    1018             :         bool neigh = false;
    1019             : 
    1020             :         // cycle over all the atoms belonging to key1
    1021             :         for(unsigned im1=0; im1<Model_resmap_[key1].size(); ++im1) {
    1022             :           // get atom position #1
    1023             :           Vector pos1 = getPosition(Model_resmap_[key1][im1]);
    1024             :           // cycle over all the atoms belonging to key2
    1025             :           for(unsigned im2=0; im2<Model_resmap_[key2].size(); ++im2) {
    1026             :             // get atom position #2
    1027             :             Vector pos2 = getPosition(Model_resmap_[key2][im2]);
    1028             :             // if closer than 0.5 nm, then residues key1 and key2 are neighbors
    1029             :             if(delta(pos1,pos2).modulo()<0.5) {
    1030             :               // set neighbors
    1031             :               neigh = true;
    1032             :               // and exit
    1033             :               break;
    1034             :             }
    1035             :           }
    1036             :           // check if neighbor already found
    1037             :           if(neigh) break;
    1038             :         }
    1039             : 
    1040             :         // if neighbors, add to local list
    1041             :         if(neigh) {
    1042             :           nl_res_l[i].push_back(j);
    1043             :           nl_res_l[j].push_back(i);
    1044             :         }
    1045             :       }
    1046             :     }
    1047             :     // add to global list
    1048             :     #pragma omp critical
    1049             :     {
    1050             :       for(unsigned i=0; i<nl_res_.size(); ++i)
    1051             :         nl_res_[i].insert(nl_res_[i].end(), nl_res_l[i].begin(), nl_res_l[i].end());
    1052             :     }
    1053             :   }
    1054           0 : }
    1055             : 
    1056           0 : void EMMIVOX::doMonteCarloBfact()
    1057             : {
    1058             : // update residue neighbor list
    1059           0 :   get_close_residues();
    1060             : 
    1061             : // cycle over residues/chains
    1062           0 :   for(unsigned ir=0; ir<Model_rlist_.size(); ++ir) {
    1063             : 
    1064             :     // key: pair of residue/chain IDs
    1065             :     std::pair<unsigned,std::string> key = Model_rlist_[ir];
    1066             :     // old bfactor
    1067           0 :     double bfactold = Model_b_[key];
    1068             : 
    1069             :     // propose move in bfactor
    1070           0 :     double bfactnew = bfactold + dbfact_ * ( 2.0 * random_.RandU01() - 1.0 );
    1071             :     // check boundaries
    1072           0 :     if(bfactnew > bfactmax_) {bfactnew = 2.0*bfactmax_ - bfactnew;}
    1073           0 :     if(bfactnew < bfactmin_) {bfactnew = 2.0*bfactmin_ - bfactnew;}
    1074             : 
    1075             :     // useful quantities
    1076             :     std::map<unsigned, double> deltaov;
    1077             : 
    1078           0 :     #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1079             :     {
    1080             :       // private variables
    1081             :       std::map<unsigned, double> deltaov_l;
    1082             :       #pragma omp for
    1083             :       // cycle over all the atoms belonging to key (residue/chain)
    1084             :       for(unsigned ia=0; ia<Model_resmap_[key].size(); ++ia) {
    1085             : 
    1086             :         // get atom id
    1087             :         unsigned im = Model_resmap_[key][ia];
    1088             :         // get atom type
    1089             :         unsigned atype = Model_type_[im];
    1090             :         // sigma for 5 Gaussians
    1091             :         Vector5d m_s = Model_s_[atype];
    1092             :         // prefactors
    1093             :         Vector5d cfact = cfact_[atype];
    1094             :         // and position
    1095             :         Vector pos = getPosition(im);
    1096             : 
    1097             :         // cycle on all the neighboring voxels affected by a change in Bfactor
    1098             :         for(unsigned i=0; i<Model_nb_[im].size(); ++i) {
    1099             :           // voxel id
    1100             :           unsigned id = Model_nb_[im][i];
    1101             :           // get contribution to density in id before change
    1102             :           double dold = get_overlap(Map_m_[id], pos, cfact, m_s, bfactold);
    1103             :           // get contribution after change
    1104             :           double dnew = get_overlap(Map_m_[id], pos, cfact, m_s, bfactnew);
    1105             :           // update delta density
    1106             :           deltaov_l[id] += dnew-dold;
    1107             :         }
    1108             :       }
    1109             :       // add to global list
    1110             :       #pragma omp critical
    1111             :       {
    1112             :         for(std::map<unsigned,double>::iterator itov=deltaov_l.begin(); itov!=deltaov_l.end(); ++itov)
    1113             :           deltaov[itov->first] += itov->second;
    1114             :       }
    1115             :     }
    1116             : 
    1117             :     // now calculate new and old score
    1118             :     double old_ene = 0.0;
    1119             :     double new_ene = 0.0;
    1120             : 
    1121             :     // cycle on all affected voxels
    1122           0 :     for(std::map<unsigned,double>::iterator itov=deltaov.begin(); itov!=deltaov.end(); ++itov) {
    1123             :       // id of the component
    1124           0 :       unsigned id = itov->first;
    1125             :       // new value
    1126           0 :       double ovmdnew = ovmd_[id]+itov->second;
    1127             :       // deviations
    1128           0 :       double devold = scale_ * ovmd_[id] + offset_ - ovdd_[id];
    1129           0 :       double devnew = scale_ * ovmdnew   + offset_ - ovdd_[id];
    1130             :       // inverse of sigma_min
    1131           0 :       double ismin = ismin_[id];
    1132             :       // scores
    1133           0 :       if(devold==0.0) {
    1134           0 :         old_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1135             :       } else {
    1136           0 :         old_ene += -kbt_ * std::log( 0.5 / devold * erf ( devold * inv_sqrt2_ * ismin ));
    1137             :       }
    1138           0 :       if(devnew==0.0) {
    1139           0 :         new_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1140             :       } else {
    1141           0 :         new_ene += -kbt_ * std::log( 0.5 / devnew * erf ( devnew * inv_sqrt2_ * ismin ));
    1142             :       }
    1143             :     }
    1144             : 
    1145             :     // list of neighboring residues
    1146           0 :     std::vector<unsigned> close = nl_res_[ir];
    1147             :     // add restraint to keep Bfactors of close residues close
    1148           0 :     for(unsigned i=0; i<close.size(); ++i) {
    1149             :       // residue/chain IDs of neighbor
    1150           0 :       std::pair<unsigned,std::string> keyn = Model_rlist_[close[i]];
    1151             :       // deviations
    1152           0 :       double devold = bfactold - Model_b_[keyn];
    1153           0 :       double devnew = bfactnew - Model_b_[keyn];
    1154             :       // inverse of sigma_min
    1155           0 :       double ismin = 1.0 / bfactsig_;
    1156             :       // scores
    1157           0 :       if(devold==0.0) {
    1158           0 :         old_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1159             :       } else {
    1160           0 :         old_ene += -kbt_ * std::log( 0.5 / devold * erf ( devold * inv_sqrt2_ * ismin ));
    1161             :       }
    1162           0 :       if(devnew==0.0) {
    1163           0 :         new_ene += -kbt_ * std::log( 0.5 * sqrt2_pi_ * ismin );
    1164             :       } else {
    1165           0 :         new_ene += -kbt_ * std::log( 0.5 / devnew * erf ( devnew * inv_sqrt2_ * ismin ));
    1166             :       }
    1167             :     }
    1168             : 
    1169             :     // increment number of trials
    1170           0 :     MCBtrials_ += 1.0;
    1171             : 
    1172             :     // accept or reject
    1173             :     bool accept = false;
    1174           0 :     if(bfactemin_) {
    1175           0 :       if(new_ene < old_ene) accept = true;
    1176             :     } else {
    1177           0 :       accept = doAccept(old_ene, new_ene, kbt_);
    1178             :     }
    1179             : 
    1180             :     // in case of acceptance
    1181           0 :     if(accept) {
    1182             :       // update acceptance rate
    1183           0 :       MCBaccept_ += 1.0;
    1184             :       // update bfactor
    1185           0 :       Model_b_[key] = bfactnew;
    1186             :       // change all the ovmd_ affected
    1187           0 :       for(std::map<unsigned,double>::iterator itov=deltaov.begin(); itov!=deltaov.end(); ++itov)
    1188           0 :         ovmd_[itov->first] += itov->second;
    1189             :     }
    1190             : 
    1191             :   } // end cycle on bfactors
    1192             : 
    1193             : // update auxiliary lists (to update pref_gpu_, invs2_gpu_, and cut_ on CPU/GPU)
    1194           0 :   get_auxiliary_vectors();
    1195             : // update neighbor list (new cut_ + update pref_nl_gpu_ and invs2_nl_gpu_ on GPU)
    1196           0 :   update_neighbor_list();
    1197             : // recalculate fmod (to update derivatives)
    1198           0 :   calculate_fmod();
    1199           0 : }
    1200             : 
    1201             : // get overlap
    1202           0 : double EMMIVOX::get_overlap(const Vector &d_m, const Vector &m_m,
    1203             :                             const Vector5d &cfact, const Vector5d &m_s, double bfact)
    1204             : {
    1205             :   // calculate vector difference
    1206           0 :   Vector md = delta(m_m, d_m);
    1207             :   // norm squared
    1208           0 :   double md2 = md[0]*md[0]+md[1]*md[1]+md[2]*md[2];
    1209             :   // cycle on 5 Gaussians
    1210             :   double ov_tot = 0.0;
    1211           0 :   for(unsigned j=0; j<5; ++j) {
    1212             :     // total value of b
    1213           0 :     double m_b = m_s[j]+bfact/4.0;
    1214             :     // calculate invs2
    1215           0 :     double invs2 = 1.0/(inv_pi2_*m_b);
    1216             :     // final calculation
    1217           0 :     ov_tot += cfact[j] * pow(invs2, 1.5) * std::exp(-0.5 * md2 * invs2);
    1218             :   }
    1219           0 :   return ov_tot;
    1220             : }
    1221             : 
    1222           0 : void EMMIVOX::update_neighbor_sphere()
    1223             : {
    1224             :   // number of atoms
    1225           0 :   unsigned natoms = Model_type_.size();
    1226             :   // clear neighbor sphere
    1227             :   ns_.clear();
    1228             :   // store reference positions
    1229           0 :   refpos_ = getPositions();
    1230             : 
    1231             :   // cycle on voxels - in parallel
    1232           0 :   #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1233             :   {
    1234             :     // private variables
    1235             :     std::vector< std::pair<unsigned,unsigned> > ns_l;
    1236             :     #pragma omp for
    1237             :     for(unsigned id=0; id<ovdd_.size(); ++id) {
    1238             :       // grid point
    1239             :       Vector d_m = Map_m_[id];
    1240             :       // cycle on atoms
    1241             :       for(unsigned im=0; im<natoms; ++im) {
    1242             :         // calculate distance
    1243             :         double dist = delta(getPosition(im), d_m).modulo();
    1244             :         // add to local list
    1245             :         if(dist<=2.0*cut_[im]) ns_l.push_back(std::make_pair(id,im));
    1246             :       }
    1247             :     }
    1248             :     // add to global list
    1249             :     #pragma omp critical
    1250             :     ns_.insert(ns_.end(), ns_l.begin(), ns_l.end());
    1251             :   }
    1252           0 : }
    1253             : 
    1254           0 : bool EMMIVOX::do_neighbor_sphere()
    1255             : {
    1256           0 :   std::vector<double> dist(getPositions().size());
    1257             :   bool update = false;
    1258             : 
    1259             : // calculate displacement
    1260           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1261             :   for(unsigned im=0; im<dist.size(); ++im) {
    1262             :     dist[im] = delta(getPosition(im),refpos_[im]).modulo()/cut_[im];
    1263             :   }
    1264             : 
    1265             : // check if update or not
    1266           0 :   double maxdist = *max_element(dist.begin(), dist.end());
    1267           0 :   if(maxdist>=1.0) update=true;
    1268             : 
    1269             : // return if update or not
    1270           0 :   return update;
    1271             : }
    1272             : 
    1273           0 : void EMMIVOX::update_neighbor_list()
    1274             : {
    1275             :   // number of atoms
    1276           0 :   unsigned natoms = Model_type_.size();
    1277             :   // clear neighbor list
    1278             :   nl_.clear();
    1279             : 
    1280             :   // cycle on neighbour sphere - in parallel
    1281           0 :   #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1282             :   {
    1283             :     // private variables
    1284             :     std::vector< std::pair<unsigned,unsigned> > nl_l;
    1285             :     #pragma omp for
    1286             :     for(unsigned long long i=0; i<ns_.size(); ++i) {
    1287             :       // calculate distance
    1288             :       double dist = delta(Map_m_[ns_[i].first], getPosition(ns_[i].second)).modulo();
    1289             :       // add to local neighbour list
    1290             :       if(dist<=cut_[ns_[i].second]) nl_l.push_back(ns_[i]);
    1291             :     }
    1292             :     // add to global list
    1293             :     #pragma omp critical
    1294             :     nl_.insert(nl_.end(), nl_l.begin(), nl_l.end());
    1295             :   }
    1296             : 
    1297             :   // new dimension of neighbor list
    1298             :   unsigned long long nl_size = nl_.size();
    1299             :   // now resize derivatives
    1300           0 :   ovmd_der_.resize(nl_size);
    1301             : 
    1302             :   // in case of B-factors sampling - at the right step
    1303           0 :   if(dbfact_>0 && getStep()%MCBstride_==0) {
    1304             :     // clear vectors
    1305           0 :     Model_nb_.clear(); Model_nb_.resize(natoms);
    1306             :     // cycle over the neighbor list to creat a list of voxels per atom
    1307           0 :     #pragma omp parallel num_threads(OpenMP::getNumThreads())
    1308             :     {
    1309             :       // private variables
    1310             :       std::vector< std::vector<unsigned> > Model_nb_l(natoms);
    1311             :       #pragma omp for
    1312             :       for(unsigned long long i=0; i<nl_size; ++i) {
    1313             :         Model_nb_l[nl_[i].second].push_back(nl_[i].first);
    1314             :       }
    1315             :       // add to global list
    1316             :       #pragma omp critical
    1317             :       {
    1318             :         for(unsigned i=0; i<natoms; ++i)
    1319             :           Model_nb_[i].insert(Model_nb_[i].end(), Model_nb_l[i].begin(), Model_nb_l[i].end());
    1320             :       }
    1321             :     }
    1322             :   }
    1323             : 
    1324             :   // transfer data to gpu
    1325           0 :   update_gpu();
    1326           0 : }
    1327             : 
    1328           0 : void EMMIVOX::update_gpu()
    1329             : {
    1330             :   // dimension of neighbor list
    1331             :   long long nl_size = nl_.size();
    1332             :   // create useful vectors
    1333           0 :   std::vector<int> nl_id(nl_size), nl_im(nl_size);
    1334           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1335             :   for(unsigned long long i=0; i<nl_size; ++i) {
    1336             :     nl_id[i] = static_cast<int>(nl_[i].first);
    1337             :     nl_im[i] = static_cast<int>(nl_[i].second);
    1338             :   }
    1339             :   // create tensors on device
    1340           0 :   nl_id_gpu_ = torch::from_blob(nl_id.data(), {nl_size}, torch::kInt32).clone().to(device_t_);
    1341           0 :   nl_im_gpu_ = torch::from_blob(nl_im.data(), {nl_size}, torch::kInt32).clone().to(device_t_);
    1342             :   // now we need to create pref_nl_gpu_ [5,nl_size]
    1343           0 :   pref_nl_gpu_  = torch::index_select(pref_gpu_,1,nl_im_gpu_);
    1344             :   // and invs2_nl_gpu_ [5,nl_size]
    1345           0 :   invs2_nl_gpu_ = torch::index_select(invs2_gpu_,1,nl_im_gpu_);
    1346             :   // and Map_m_nl_gpu_ [3,nl_size]
    1347           0 :   Map_m_nl_gpu_ = torch::index_select(Map_m_gpu_,1,nl_id_gpu_);
    1348           0 : }
    1349             : 
    1350           0 : void EMMIVOX::prepare()
    1351             : {
    1352           0 :   if(getExchangeStep()) first_time_=true;
    1353           0 : }
    1354             : 
    1355             : // calculate forward model on gpu
    1356           0 : void EMMIVOX::calculate_fmod()
    1357             : {
    1358             :   // number of atoms
    1359           0 :   int natoms = Model_type_.size();
    1360             :   // number of data points
    1361           0 :   int nd = ovdd_.size();
    1362             : 
    1363             :   // fill positions in in parallel
    1364           0 :   std::vector<double> posg(3*natoms);
    1365           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1366             :   for (int i=0; i<natoms; ++i) {
    1367             :     // fill vectors
    1368             :     posg[i]          = getPosition(i)[0];
    1369             :     posg[i+natoms]   = getPosition(i)[1];
    1370             :     posg[i+2*natoms] = getPosition(i)[2];
    1371             :   }
    1372             :   // transfer positions to pos_gpu [3,natoms]
    1373           0 :   torch::Tensor pos_gpu = torch::from_blob(posg.data(), {3,natoms}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
    1374             :   // create pos_nl_gpu_ [3,nl_size]
    1375           0 :   torch::Tensor pos_nl_gpu = torch::index_select(pos_gpu,1,nl_im_gpu_);
    1376             :   // calculate vector difference [3,nl_size]
    1377           0 :   torch::Tensor md = Map_m_nl_gpu_ - pos_nl_gpu;
    1378             :   // calculate norm squared by column [1,nl_size]
    1379           0 :   torch::Tensor md2 = torch::sum(md*md,0);
    1380             :   // calculate density [5,nl_size]
    1381           0 :   torch::Tensor ov = pref_nl_gpu_ * torch::exp(-0.5 * md2 * invs2_nl_gpu_);
    1382             :   // and derivatives [5,nl_size]
    1383           0 :   ovmd_der_gpu_ = invs2_nl_gpu_ * ov;
    1384             :   // sum density over 5 columns [1,nl_size]
    1385           0 :   ov = torch::sum(ov,0);
    1386             :   // sum contributions from the same atom
    1387           0 :   auto options = torch::TensorOptions().device(device_t_).dtype(torch::kFloat32);
    1388           0 :   ovmd_gpu_ = torch::zeros({nd}, options);
    1389           0 :   ovmd_gpu_.index_add_(0, nl_id_gpu_, ov);
    1390             :   // sum derivatives over 5 rows [1,nl_size] and multiply by md [3,nl_size]
    1391           0 :   ovmd_der_gpu_ = md * torch::sum(ovmd_der_gpu_,0);
    1392             : 
    1393             :   // in case of metainference: average them across replicas
    1394           0 :   if(!no_aver_ && nrep_>1) {
    1395             :     // communicate ovmd_gpu_ to CPU [1, nd]
    1396           0 :     torch::Tensor ovmd_cpu = ovmd_gpu_.detach().to(torch::kCPU).to(torch::kFloat64);
    1397             :     // and put them in ovmd_
    1398           0 :     ovmd_ = std::vector<double>(ovmd_cpu.data_ptr<double>(), ovmd_cpu.data_ptr<double>() + ovmd_cpu.numel());
    1399             :     // sum across replicas
    1400           0 :     multi_sim_comm.Sum(&ovmd_[0], nd);
    1401             :     // and divide by number of replicas
    1402           0 :     double escale = 1.0 / static_cast<double>(nrep_);
    1403           0 :     for(int i=0; i<nd; ++i) ovmd_[i] *= escale;
    1404             :     // put back on device
    1405           0 :     ovmd_gpu_ = torch::from_blob(ovmd_.data(), {nd}, torch::kFloat64).to(torch::kFloat32).to(device_t_);
    1406             :   }
    1407             : 
    1408             :   // communicate back model density
    1409             :   // this is needed only in certain situations
    1410           0 :   long int step = getStep();
    1411             :   bool do_comm = false;
    1412           0 :   if(mapstride_>0 && step%mapstride_==0) do_comm = true;
    1413           0 :   if(dbfact_>0    && step%MCBstride_==0) do_comm = true;
    1414           0 :   if(do_corr_) do_comm = true;
    1415             :   // in case of metainference: already communicated
    1416           0 :   if(!no_aver_ && nrep_>1) do_comm = false;
    1417           0 :   if(do_comm) {
    1418             :     // communicate ovmd_gpu_ to CPU [1, nd]
    1419           0 :     torch::Tensor ovmd_cpu = ovmd_gpu_.detach().to(torch::kCPU).to(torch::kFloat64);
    1420             :     // and put them in ovmd_
    1421           0 :     ovmd_ = std::vector<double>(ovmd_cpu.data_ptr<double>(), ovmd_cpu.data_ptr<double>() + ovmd_cpu.numel());
    1422             :   }
    1423           0 : }
    1424             : 
    1425             : // calculate score
    1426           0 : void EMMIVOX::calculate_score()
    1427             : {
    1428             :   // number of atoms
    1429           0 :   int natoms = Model_type_.size();
    1430             : 
    1431             :   // calculate deviation model/data [1, nd]
    1432           0 :   torch::Tensor dev = scale_ * ovmd_gpu_ + offset_ - ovdd_gpu_;
    1433             :   // error function [1, nd]
    1434           0 :   torch::Tensor errf = torch::erf( dev * inv_sqrt2_ * ismin_gpu_ );
    1435             :   // take care of dev = zero
    1436           0 :   torch::Tensor zeros_d = torch::ne(dev, 0.0);
    1437             :   // redefine dev
    1438           0 :   dev = dev * zeros_d + eps_ * torch::logical_not(zeros_d);
    1439             :   // take care of errf = zero
    1440           0 :   torch::Tensor zeros_e = torch::ne(errf, 0.0);
    1441             :   // redefine errf
    1442           0 :   errf = errf * zeros_e + eps_ * torch::logical_not(zeros_e);
    1443             :   // logical AND: both dev and errf different from zero
    1444             :   torch::Tensor zeros = torch::logical_and(zeros_d, zeros_e);
    1445             :   // energy - with limit dev going to zero
    1446           0 :   torch::Tensor ene = 0.5 * ( errf / dev * zeros + torch::logical_not(zeros) * sqrt2_pi_ *  ismin_gpu_);
    1447             :   // logarithm and sum
    1448           0 :   ene = -kbt_ * torch::sum(torch::log(ene));
    1449             :   // and derivatives [1, nd]
    1450           0 :   torch::Tensor d_der = -kbt_ * zeros * ( sqrt2_pi_ * torch::exp( -0.5 * dev * dev * ismin_gpu_ * ismin_gpu_ ) * ismin_gpu_ / errf - 1.0 / dev );
    1451             :   // tensor for derivatives wrt atoms [1, nl_size]
    1452           0 :   torch::Tensor der_gpu = torch::index_select(d_der,0,nl_id_gpu_);
    1453             :   // multiply by ovmd_der_gpu_ and scale [3, nl_size]
    1454           0 :   der_gpu = ovmd_der_gpu_ * scale_ * der_gpu;
    1455             :   // sum contributions for each atom
    1456           0 :   auto options = torch::TensorOptions().device(device_t_).dtype(torch::kFloat32);
    1457           0 :   torch::Tensor atoms_der_gpu = torch::zeros({3,natoms}, options);
    1458           0 :   atoms_der_gpu.index_add_(1, nl_im_gpu_, der_gpu);
    1459             : 
    1460             :   // FINAL STUFF
    1461             :   //
    1462             :   // 1) communicate total energy to CPU
    1463           0 :   torch::Tensor ene_cpu = ene.detach().to(torch::kCPU).to(torch::kFloat64);
    1464           0 :   ene_ = *ene_cpu.data_ptr<double>();
    1465             :   // with marginal, simply multiply by number of replicas!
    1466           0 :   if(!no_aver_ && nrep_>1) ene_ *= static_cast<double>(nrep_);
    1467             :   //
    1468             :   // 2) communicate derivatives to CPU
    1469           0 :   torch::Tensor atom_der_cpu = atoms_der_gpu.detach().to(torch::kCPU).to(torch::kFloat64);
    1470             :   // convert to std::vector<double>
    1471           0 :   std::vector<double> atom_der = std::vector<double>(atom_der_cpu.data_ptr<double>(), atom_der_cpu.data_ptr<double>() + atom_der_cpu.numel());
    1472             :   // and put in atom_der_
    1473           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads())
    1474             :   for(int i=0; i<natoms; ++i) {
    1475             :     atom_der_[i] = Vector(atom_der[i],atom_der[i+natoms],atom_der[i+2*natoms]);
    1476             :   }
    1477             :   //
    1478             :   // 3) calculate virial on CPU
    1479           0 :   Tensor virial;
    1480             :   // declare omp reduction for Tensors
    1481             :   #pragma omp declare reduction( sumTensor : Tensor : omp_out += omp_in )
    1482             : 
    1483           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads()) reduction (sumTensor : virial)
    1484             :   for(int i=0; i<natoms; ++i) {
    1485             :     virial += Tensor(getPosition(i), -atom_der_[i]);
    1486             :   }
    1487             :   // store virial
    1488           0 :   virial_ = virial;
    1489           0 : }
    1490             : 
    1491           0 : void EMMIVOX::calculate()
    1492             : {
    1493             :   // get time step
    1494           0 :   long int step = getStep();
    1495             : 
    1496             :   // set temperature value
    1497           0 :   getPntrToComponent("kbt")->set(kbt_);
    1498             : 
    1499             :   // neighbor list update
    1500           0 :   if(first_time_ || getExchangeStep() || step%nl_stride_==0) {
    1501             :     // check if time to update neighbor sphere
    1502             :     bool update = false;
    1503           0 :     if(first_time_ || getExchangeStep()) update = true;
    1504           0 :     else update = do_neighbor_sphere();
    1505             :     // update neighbor sphere
    1506           0 :     if(update) update_neighbor_sphere();
    1507             :     // update neighbor list
    1508           0 :     update_neighbor_list();
    1509             :     // set flag
    1510           0 :     first_time_=false;
    1511             :   }
    1512             : 
    1513             :   // calculate forward model
    1514           0 :   calculate_fmod();
    1515             : 
    1516             :   // Monte Carlo on bfactors
    1517           0 :   if(dbfact_>0) {
    1518             :     double acc = 0.0;
    1519             :     // do Monte Carlo
    1520           0 :     if(step%MCBstride_==0 && !getExchangeStep() && step>0) doMonteCarloBfact();
    1521             :     // calculate acceptance ratio
    1522           0 :     if(MCBtrials_>0) acc = MCBaccept_ / MCBtrials_;
    1523             :     // set value
    1524           0 :     getPntrToComponent("accB")->set(acc);
    1525             :   }
    1526             : 
    1527             :   // calculate score
    1528           0 :   calculate_score();
    1529             : 
    1530             :   // set score, virial, and derivatives
    1531           0 :   Value* score = getPntrToComponent("scoreb");
    1532           0 :   score->set(ene_);
    1533           0 :   setBoxDerivatives(score, virial_);
    1534           0 :   #pragma omp parallel for
    1535             :   for(unsigned i=0; i<atom_der_.size(); ++i) setAtomsDerivatives(score, i, atom_der_[i]);
    1536             :   // set scale and offset value
    1537           0 :   getPntrToComponent("scale")->set(scale_);
    1538           0 :   getPntrToComponent("offset")->set(offset_);
    1539             :   // calculate correlation coefficient
    1540           0 :   if(do_corr_) calculate_corr();
    1541             :   // PRINT other quantities to files
    1542             :   // - status file
    1543           0 :   if(step%statusstride_==0) print_status(step);
    1544             :   // - density file
    1545           0 :   if(mapstride_>0 && step%mapstride_==0) write_model_density(step);
    1546           0 : }
    1547             : 
    1548           0 : void EMMIVOX::calculate_corr()
    1549             : {
    1550             : // number of data points
    1551           0 :   double nd = static_cast<double>(ovdd_.size());
    1552             : // average ovmd_ and ovdd_
    1553           0 :   double ave_md = std::accumulate(ovmd_.begin(), ovmd_.end(), 0.) / nd;
    1554           0 :   double ave_dd = std::accumulate(ovdd_.begin(), ovdd_.end(), 0.) / nd;
    1555             : // calculate correlation
    1556             :   double num = 0.;
    1557             :   double den1 = 0.;
    1558             :   double den2 = 0.;
    1559           0 :   #pragma omp parallel for num_threads(OpenMP::getNumThreads()) reduction( + : num, den1, den2)
    1560             :   for(unsigned i=0; i<ovdd_.size(); ++i) {
    1561             :     double md = ovmd_[i]-ave_md;
    1562             :     double dd = ovdd_[i]-ave_dd;
    1563             :     num  += md*dd;
    1564             :     den1 += md*md;
    1565             :     den2 += dd*dd;
    1566             :   }
    1567             : // correlation coefficient
    1568           0 :   double cc = num / sqrt(den1*den2);
    1569             : // set plumed
    1570           0 :   getPntrToComponent("corr")->set(cc);
    1571           0 : }
    1572             : 
    1573             : 
    1574             : }
    1575             : }
    1576             : 
    1577             : #endif

Generated by: LCOV version 1.16