/* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Copyright (c) 2022-2023 of Luigi Bonati and Enrico Trizio.

The pytorch module is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

The pytorch module is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with plumed. If not, see <http://www.gnu.org/licenses/>.
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */

#ifdef __PLUMED_HAS_LIBTORCH
// convert LibTorch version to string (disabled: the version string is built at runtime in the constructor)
//#define STRINGIFY(x) #x
//#define TOSTR(x) STRINGIFY(x)
//#define LIBTORCH_VERSION TOSTR(TORCH_VERSION_MAJOR) "." TOSTR(TORCH_VERSION_MINOR) "." TOSTR(TORCH_VERSION_PATCH)

#include "core/PlumedMain.h"
#include "function/Function.h"
#include "core/ActionRegister.h"

#include <torch/torch.h>
#include <torch/script.h>

#include <fstream>
#include <cmath>
#include <string>
#include <vector>

// Note: Freezing a ScriptModule (torch::jit::freeze) works only in >=1.11
// For 1.8 <= versions <=1.10 we need a hack
// (see https://discuss.pytorch.org/t/how-to-check-libtorch-version/77709/4 and also
// https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp#L479)
// adapted from NequIP https://github.com/mir-group/nequip
// NOTE(review): the condition below also enables the hack for every 2.x release,
// which the note above does not mention -- confirm that this is intended.
#if ( TORCH_VERSION_MAJOR == 2 || ( TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10 ) )
#define DO_TORCH_FREEZE_HACK
// For the hack, need more headers:
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#endif

namespace PLMD {
namespace function {
namespace pytorch {

//+PLUMEDOC PYTORCH_FUNCTION PYTORCH_MODEL
/*
Load a PyTorch model compiled with TorchScript.

This can be a function defined in Python or a more complex model, such as a neural network optimized on a set of data. In both cases the derivatives of the outputs with respect to the inputs are computed using the automatic differentiation (autograd) feature of Pytorch.

By default it is assumed that the model is saved as: `model.ptc`, unless otherwise indicated by the `FILE` keyword. The function automatically checks for the number of output dimensions and creates a component for each of them. The outputs are called node-i with i between 0 and N-1 for N outputs.

Note that this function requires \ref installation-libtorch LibTorch C++ library. Check the instructions in the \ref PYTORCH page to enable the module.

\par Examples
Load a model called `torch_model.ptc` that takes as input two dihedral angles and returns two outputs.

\plumedfile
#SETTINGS AUXFILE=regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc
phi: TORSION ATOMS=5,7,9,15
psi: TORSION ATOMS=7,9,15,17
model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi,psi
PRINT FILE=COLVAR ARG=model.node-0,model.node-1
\endplumedfile

*/
//+ENDPLUMEDOC


/// Function action that evaluates a TorchScript model on its arguments and
/// exposes every model output as a component "node-<i>" with derivatives.
class PytorchModel :
  public Function {
  /// Number of arguments passed to the action (= size of the model input).
  unsigned _n_in = 0;
  /// Number of model outputs, measured in the constructor with a test evaluation.
  unsigned _n_out = 0;
  /// The deserialized (and frozen/optimized) TorchScript module.
  torch::jit::script::Module _model;
  /// Device on which the model is loaded and input tensors are placed.
  torch::Device device = torch::kCPU;

public:
  explicit PytorchModel(const ActionOptions&);
  void calculate() override;
  static void registerKeywords(Keywords& keys);

  std::vector<float> tensor_to_vector(const torch::Tensor& x);
};

PLUMED_REGISTER_ACTION(PytorchModel,"PYTORCH_MODEL")

/// Register the keywords of the action: the optional FILE keyword and the
/// "node" output components, on top of the generic Function keywords.
void PytorchModel::registerKeywords(Keywords& keys) {
  Function::registerKeywords(keys);
  keys.add("optional","FILE","Filename of the PyTorch compiled model");
  keys.addOutputComponent("node", "default", "Model outputs");
}

/// Auxiliary function to transform torch tensors in std vectors.
/// Copies all numel() elements of x into a std::vector<float>.
/// The tensor is made contiguous first: reading numel() values from the raw
/// data pointer is only meaningful when the elements are stored densely
/// (contiguous() returns the same tensor when this is already the case).
/// data_ptr<float>() still requires the tensor to hold float elements.
std::vector<float> PytorchModel::tensor_to_vector(const torch::Tensor& x) {
  // keep the contiguous tensor alive while its buffer is being copied
  const torch::Tensor dense = x.contiguous();
  const float* first = dense.data_ptr<float>();
  return std::vector<float>(first, first + dense.numel());
}

/// Load the model given by FILE (default "model.ptc"), freeze/optimize it,
/// run it once on a zero-filled input to count the outputs, and create one
/// non-periodic component with derivatives per output.
/// Aborts through plumed_merror when the model cannot be loaded.
PytorchModel::PytorchModel(const ActionOptions&ao):
  Action(ao),
  Function(ao) {
  // print libtorch version
  const std::string version = std::to_string(TORCH_VERSION_MAJOR) + "." +
                              std::to_string(TORCH_VERSION_MINOR) + "." +
                              std::to_string(TORCH_VERSION_PATCH);
  // use a literal format string: never pass runtime text as the format
  log.printf(" LibTorch version: %s\n", version.c_str());

  //number of inputs of the model
  _n_in=getNumberOfArguments();

  //parse model name
  std::string fname="model.ptc";
  parse("FILE",fname);

  //deserialize the model from file
  try {
    _model = torch::jit::load(fname, device);
  }

  //if an error is thrown check if the file exists or not
  catch (const c10::Error& e) {
    std::ifstream infile(fname);
    bool exist = infile.good();
    infile.close();
    if (exist) {
      plumed_merror("Cannot load FILE: '"+fname+"'. Please check that it is a Pytorch compiled model (exported with 'torch.jit.trace' or 'torch.jit.script').");
    } else {
      plumed_merror("The FILE: '"+fname+"' does not exist.");
    }
  }
  checkRead();

  // Optimize model
  _model.eval();
#ifdef DO_TORCH_FREEZE_HACK
  // Do the hack
  // Copied from the implementation of torch::jit::freeze,
  // except without the broken check
  // See https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp
  bool optimize_numerics = true;  // the default
  // the {} is preserved_attrs
  auto out_mod = torch::jit::freeze_module(
                   _model, {}
                 );
  // See 1.11 bugfix in https://github.com/pytorch/pytorch/pull/71436
  auto graph = out_mod.get_method("forward").graph();
  OptimizeFrozenGraph(graph, optimize_numerics);
  _model = out_mod;
#else
  // Do it normally
  _model = torch::jit::freeze(_model);
#endif

  // Optimize model for inference
#if ( TORCH_VERSION_MAJOR == 2 || ( TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10 ) )
  _model = torch::jit::optimize_for_inference(_model);
#endif

  //check the dimension of the output
  log.printf(" Checking output dimension:\n");
  // zero-filled dummy input of shape (1, _n_in)
  std::vector<float> input_test (_n_in);
  torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
  single_input = single_input.to(device);
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back( single_input );
  torch::Tensor output = _model.forward( inputs ).toTensor();
  std::vector<float> cvs = this->tensor_to_vector (output);
  _n_out=cvs.size();

  //create components
  for(unsigned j=0; j<_n_out; j++) {
    std::string name_comp = "node-"+std::to_string(j);
    addComponentWithDerivatives( name_comp );
    componentIsNotPeriodic( name_comp );
  }

  //print log (both counters are unsigned, hence %u)
  log.printf(" Number of input: %u \n",_n_in);
  log.printf(" Number of outputs: %u \n",_n_out);
  log.printf(" Bibliography: ");
  log<<plumed.cite("Bonati, Trizio, Rizzi and Parrinello, J. Chem. Phys. 159, 014801 (2023)");
  log<<plumed.cite("Bonati, Rizzi and Parrinello, J. Phys. Chem. Lett. 11, 2998-3004 (2020)");
  log.printf("\n");

}


/// Evaluate the model on the current arguments, then set the value of every
/// "node-<j>" component and its derivatives with respect to each argument.
/// Derivatives are obtained with one torch::autograd::grad call per output.
void PytorchModel::calculate() {

  // retrieve arguments
  std::vector<float> current_S(_n_in);
  for(unsigned i=0; i<_n_in; i++) {
    current_S[i]=getArgument(i);
  }
  //convert to tensor
  torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in}).to(device);
  input_S.set_requires_grad(true);
  //convert to Ivalue
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back( input_S );
  //calculate output
  torch::Tensor output = _model.forward( inputs ).toTensor();

  // the (1,1) tensor of ones used as grad_outputs is the same for every
  // output column, so it is built once instead of once per iteration
  const auto grad_output = torch::ones({1}).expand({1, 1}).to(device);

  for(unsigned j=0; j<_n_out; j++) {
    auto gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)},
    {input_S},
    /*grad_outputs=*/ {grad_output},
    /*retain_graph=*/true,
    /*create_graph=*/false)[0]; // the [0] is to get a tensor and not a vector<at::tensor>

    std::vector<float> der = this->tensor_to_vector ( gradient );
    std::string name_comp = "node-"+std::to_string(j);
    // look the component up once, not once per argument
    auto* component = getPntrToComponent(name_comp);
    //set derivatives of component j
    for(unsigned i=0; i<_n_in; i++) {
      setDerivative( component,i, der[i] );
    }
  }

  //set CV values
  std::vector<float> cvs = this->tensor_to_vector (output);
  for(unsigned j=0; j<_n_out; j++) {
    std::string name_comp = "node-"+std::to_string(j);
    getPntrToComponent(name_comp)->set(cvs[j]);
  }

}


} //pytorch
} //function
} //PLMD

#endif //PLUMED_HAS_LIBTORCH