Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2012-2023 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "LeptonCall.h" 23 : #include "OpenMP.h" 24 : 25 : namespace PLMD { 26 : 27 3091 : void LeptonCall::set(const std::string & func, const std::vector<std::string>& var, Action* action, const bool& a ) { 28 3091 : unsigned nth=OpenMP::getNumThreads(); expression.resize(nth); expression_deriv.resize(var.size()); 29 : // Resize the expression for the derivatives 30 8085 : for(unsigned i=0; i<expression_deriv.size(); ++i) expression_deriv[i].resize(OpenMP::getNumThreads()); 31 3091 : allow_extra_args=a; nargs=var.size(); 32 : 33 3091 : lepton_ref.resize(nth*nargs,nullptr); 34 3091 : lepton::ParsedExpression pe=lepton::Parser::parse(func).optimize(lepton::Constants()); unsigned nt=0; 35 3091 : if( action ) action->log<<" function as parsed by lepton: "<<pe<<"\n"; 36 9273 : for(auto & e : expression) { 37 6182 : e=pe.createCompiledExpression(); 38 16170 : for(unsigned j=0; j<var.size(); ++j) { 39 : try { 40 9988 : lepton_ref[nt*var.size()+j]=&const_cast<lepton::CompiledExpression*>(&expression[nt])->getVariableReference(var[j]); 41 88 : } catch(const PLMD::lepton::Exception& exc) { 42 : // this is necessary since in some cases lepton things a variable is not present even though it is present 43 : // e.g. func=0*x 44 88 : } 45 : } 46 6182 : nt++; 47 : } 48 8041 : for(auto & p : expression[0].getVariables()) { 49 4950 : if(std::find(var.begin(),var.end(),p)==var.end()) { 50 0 : if( action ) action->error("variable " + p + " is not defined"); 51 0 : else plumed_merror("variable " + p + " is not defined in lepton function"); 52 : } 53 : } 54 3091 : if( action ) action->log<<" derivatives as computed by lepton:\n"; 55 3091 : lepton_ref_deriv.resize(nth*nargs*nargs,nullptr); 56 8085 : for(unsigned i=0; i<var.size(); i++) { 57 9988 : lepton::ParsedExpression pe=lepton::Parser::parse(func).differentiate(var[i]).optimize(lepton::Constants()); nt=0; if( action ) action->log<<" "<<pe<<"\n"; 58 14982 : for(auto & e : expression_deriv[i]) { 59 9988 : e=pe.createCompiledExpression(); 60 28932 : for(unsigned j=0; j<var.size(); ++j) { 61 : try { 62 18944 : lepton_ref_deriv[i*OpenMP::getNumThreads()*var.size() + nt*var.size()+j]=&const_cast<lepton::CompiledExpression*>(&expression_deriv[i][nt])->getVariableReference(var[j]); 63 6110 : } catch(const PLMD::lepton::Exception& exc) { 64 : // this is necessary since in some cases lepton things a variable is not present even though it is present 65 : // e.g. func=0*x 66 6110 : } 67 : } 68 9988 : nt++; 69 : } 70 : } 71 3091 : } 72 : 73 22966081 : double LeptonCall::evaluate( const std::vector<double>& args ) const { 74 : plumed_dbg_assert( allow_extra_args || args.size()==nargs ); 75 22966081 : const unsigned t=OpenMP::getThreadNum(), tbas=t*nargs; 76 67582801 : for(unsigned i=0; i<nargs; ++i) { 77 44616720 : if( lepton_ref[tbas+i] ) *lepton_ref[tbas+i] = args[i]; 78 : } 79 22966081 : return expression[t].evaluate(); 80 : } 81 : 82 32189400 : double LeptonCall::evaluateDeriv( const unsigned& ider, const std::vector<double>& args ) const { 83 : plumed_dbg_assert( allow_extra_args || args.size()==nargs ); plumed_dbg_assert( ider<nargs ); 84 32189400 : const unsigned t=OpenMP::getThreadNum(), dbas = ider*OpenMP::getNumThreads()*nargs + t*nargs; 85 106819836 : for(unsigned j=0; j<nargs; j++) { 86 74630436 : if(lepton_ref_deriv[dbas+j] ) *lepton_ref_deriv[dbas+j] = args[j]; 87 : } 88 32189400 : return expression_deriv[ider][t].evaluate(); 89 : } 90 : 91 : }