Line data Source code
/* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Copyright (c) 2022-2023 of Luigi Bonati and Enrico Trizio.

The pytorch module is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

The pytorch module is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with plumed. If not, see <http://www.gnu.org/licenses/>.
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */

#ifdef __PLUMED_HAS_LIBTORCH

#include "core/PlumedMain.h"
#include "function/Function.h"
#include "core/ActionRegister.h"

#include <torch/torch.h>
#include <torch/script.h>

#include <fstream>
#include <cmath>
#include <string>
#include <vector>

// Note: Freezing a ScriptModule (torch::jit::freeze) works only in >=1.11
// For 1.8 <= versions <=1.10 we need a hack
// (see https://discuss.pytorch.org/t/how-to-check-libtorch-version/77709/4 and also
// https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp#L479)
// adapted from NequIP https://github.com/mir-group/nequip
// The condition below enables the hack for LibTorch 1.x up to 1.10 AND for
// every 2.x release. NOTE(review): the rationale above only covers <= 1.10;
// confirm whether 2.x still needs it before narrowing the condition.
#if ( TORCH_VERSION_MAJOR == 2 || ( TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10 ) )
#define DO_TORCH_FREEZE_HACK
// For the hack, need more headers:
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#endif

namespace PLMD {
namespace function {
namespace pytorch {

//+PLUMEDOC PYTORCH_FUNCTION PYTORCH_MODEL
/*
Load a PyTorch model compiled with TorchScript.

This can be a function defined in Python or a more complex model, such as a neural network optimized on a set of data. In both cases the derivatives of the outputs with respect to the inputs are computed using the automatic differentiation (autograd) feature of PyTorch.

By default it is assumed that the model is saved as: `model.ptc`, unless otherwise indicated by the `FILE` keyword. The function automatically checks for the number of output dimensions and creates a component for each of them. The outputs are called node-i with i between 0 and N-1 for N outputs.

Note that this function requires the LibTorch C++ library (see \ref installation-libtorch). Check the instructions in the \ref PYTORCH page to enable the module.

\par Examples
Load a model called `torch_model.ptc` that takes as input two dihedral angles and returns two outputs.

\plumedfile
#SETTINGS AUXFILE=regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc
phi: TORSION ATOMS=5,7,9,15
psi: TORSION ATOMS=7,9,15,17
model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi,psi
PRINT FILE=COLVAR ARG=model.node-0,model.node-1
\endplumedfile

*/
//+ENDPLUMEDOC


/// Function action that evaluates a TorchScript model on its arguments and
/// exposes every model output as a component "node-<j>", with derivatives
/// obtained through LibTorch autograd.
class PytorchModel :
  public Function
{
  /// Number of arguments, i.e. model inputs.
  unsigned _n_in = 0;
  /// Number of model outputs (= number of components), measured at construction.
  unsigned _n_out = 0;
  /// The loaded (and frozen) TorchScript module.
  torch::jit::script::Module _model;
  /// Device on which the model is loaded and evaluated.
  torch::Device device = torch::kCPU;

public:
  explicit PytorchModel(const ActionOptions&);
  void calculate() override;
  static void registerKeywords(Keywords& keys);

  /// Copy the elements of a tensor into a std::vector<float>.
  std::vector<float> tensor_to_vector(const torch::Tensor& x);
};

PLUMED_REGISTER_ACTION(PytorchModel,"PYTORCH_MODEL")

void PytorchModel::registerKeywords(Keywords& keys) {
  Function::registerKeywords(keys);
  keys.add("optional","FILE","Filename of the PyTorch compiled model");
  keys.addOutputComponent("node", "default", "Model outputs");
}

// Auxiliary function to transform torch tensors in std vectors.
// The tensor is first converted to a contiguous float32 CPU tensor:
// data_ptr<float>() throws on non-float tensors, and reading numel()
// consecutive floats from data_ptr() of a non-contiguous tensor does not
// follow the tensor's logical element order. All conversions are no-ops
// when the tensor already satisfies them.
std::vector<float> PytorchModel::tensor_to_vector(const torch::Tensor& x) {
  // Keep the converted tensor alive while its storage is being copied.
  const torch::Tensor flat = x.detach().to(torch::kFloat).cpu().contiguous();
  const float* begin = flat.data_ptr<float>();
  return std::vector<float>(begin, begin + flat.numel());
}

PytorchModel::PytorchModel(const ActionOptions&ao):
  Action(ao),
  Function(ao)
{
  // print libtorch version (passed as an argument, never as the format string)
  const std::string version = std::to_string(TORCH_VERSION_MAJOR) + "." +
                              std::to_string(TORCH_VERSION_MINOR) + "." +
                              std::to_string(TORCH_VERSION_PATCH);
  log.printf(" LibTorch version: %s\n", version.c_str());

  //number of inputs of the model
  _n_in=getNumberOfArguments();

  //parse model name
  std::string fname="model.ptc";
  parse("FILE",fname);

  //deserialize the model from file
  try {
    _model = torch::jit::load(fname, device);
  }
  //if an error is thrown check if the file exists or not
  catch (const c10::Error&) {
    std::ifstream infile(fname);
    const bool exist = infile.good();
    infile.close();
    if (exist) {
      plumed_merror("Cannot load FILE: '"+fname+"'. Please check that it is a Pytorch compiled model (exported with 'torch.jit.trace' or 'torch.jit.script').");
    }
    else {
      plumed_merror("The FILE: '"+fname+"' does not exist.");
    }
  }
  checkRead();

  // Optimize model
  _model.eval();
#ifdef DO_TORCH_FREEZE_HACK
  // Do the hack
  // Copied from the implementation of torch::jit::freeze,
  // except without the broken check
  // See https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp
  const bool optimize_numerics = true; // the default
  // the {} is preserved_attrs
  auto out_mod = torch::jit::freeze_module(
                   _model, {}
                 );
  // See 1.11 bugfix in https://github.com/pytorch/pytorch/pull/71436
  auto graph = out_mod.get_method("forward").graph();
  torch::jit::OptimizeFrozenGraph(graph, optimize_numerics);
  _model = out_mod;
#else
  // Do it normally
  _model = torch::jit::freeze(_model);
#endif

  // Optimize model for inference
#if ( TORCH_VERSION_MAJOR == 2 || ( TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10 ) )
  _model = torch::jit::optimize_for_inference(_model);
#endif

  //check the dimension of the output by evaluating the model once on a zero
  //input of shape (1, _n_in); the number of returned elements sets _n_out
  log.printf(" Checking output dimension:\n");
  std::vector<float> input_test (_n_in);
  torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
  single_input = single_input.to(device);
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back( single_input );
  torch::Tensor output = _model.forward( inputs ).toTensor();
  const std::vector<float> cvs = tensor_to_vector(output);
  _n_out = static_cast<unsigned>(cvs.size());

  //create components
  for(unsigned j=0; j<_n_out; j++) {
    const std::string name_comp = "node-"+std::to_string(j);
    addComponentWithDerivatives( name_comp );
    componentIsNotPeriodic( name_comp );
  }

  //print log (%u: both counters are unsigned)
  log.printf(" Number of input: %u \n",_n_in);
  log.printf(" Number of outputs: %u \n",_n_out);
  log.printf(" Bibliography: ");
  log<<plumed.cite("Bonati, Trizio, Rizzi and Parrinello, J. Chem. Phys. 159, 014801 (2023)");
  log<<plumed.cite("Bonati, Rizzi and Parrinello, J. Phys. Chem. Lett. 11, 2998-3004 (2020)");
  log.printf("\n");

}


// Evaluate the model on the current arguments, set component "node-<j>" to
// the j-th output and its derivatives to d(output_j)/d(argument_i).
void PytorchModel::calculate() {

  // retrieve arguments (narrowed to float, the dtype fed to the model)
  std::vector<float> current_S(_n_in);
  for(unsigned i=0; i<_n_in; i++)
    current_S[i]=getArgument(i);
  //convert to tensor of shape (1, _n_in) that tracks gradients
  torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in}).to(device);
  input_S.set_requires_grad(true);
  //convert to Ivalue
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back( input_S );
  //calculate output
  torch::Tensor output = _model.forward( inputs ).toTensor();

  // output values (reading them does not interfere with the backward passes)
  const std::vector<float> cvs = tensor_to_vector(output);

  // seed of each vector-Jacobian product: identical for every output,
  // so it is built once instead of once per output
  const torch::Tensor grad_output = torch::ones({1, 1}).to(device);

  for(unsigned j=0; j<_n_out; j++) {
    // resolve the component once per output instead of once per argument
    Value* comp = getPntrToComponent("node-"+std::to_string(j));

    // gradient of output column j w.r.t. the input; slicing along dim 1
    // assumes the model returns a tensor shaped like (1, _n_out).
    // The graph is retained because it is reused for the next output.
    torch::Tensor gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)},
                             {input_S},
                             /*grad_outputs=*/ {grad_output},
                             /*retain_graph=*/true,
                             /*create_graph=*/false)[0]; // the [0] is to get a tensor and not a vector<at::tensor>

    const std::vector<float> der = tensor_to_vector(gradient);
    //set derivatives of component j
    for(unsigned i=0; i<_n_in; i++)
      setDerivative(comp, i, der[i]);

    //set CV value of component j
    comp->set(cvs[j]);
  }

}


} // namespace pytorch
} // namespace function
} // namespace PLMD

#endif //PLUMED_HAS_LIBTORCH