LCOV - code coverage report
Current view: top level - metatensor - vesin.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 0 321 0.0 %
Date: 2024-10-18 13:59:31 Functions: 0 29 0.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2024 The METATENSOR code team
       3             : (see the PEOPLE-METATENSOR file at the root of this folder for a list of names)
       4             : 
       5             : See https://docs.metatensor.org/latest/ for more information about the
       6             : metatensor package that this module allows you to call from PLUMED.
       7             : 
       8             : This file is part of METATENSOR-PLUMED module.
       9             : 
      10             : The METATENSOR-PLUMED module is free software: you can redistribute it and/or modify
      11             : it under the terms of the GNU Lesser General Public License as published by
      12             : the Free Software Foundation, either version 3 of the License, or
      13             : (at your option) any later version.
      14             : 
      15             : The METATENSOR-PLUMED module is distributed in the hope that it will be useful,
      16             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      17             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      18             : GNU Lesser General Public License for more details.
      19             : 
      20             : You should have received a copy of the GNU Lesser General Public License
      21             : along with the METATENSOR-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
      22             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      23             : /*INDENT-OFF*/
#include "vesin.h"
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <limits>
#include <new>
#include <tuple>
      32             : 
      33             : #ifndef VESIN_CPU_CELL_LIST_HPP
      34             : #define VESIN_CPU_CELL_LIST_HPP
      35             : 
      36             : #include <vector>
      37             : 
      38             : #include "vesin.h"
      39             : 
      40             : #ifndef VESIN_TYPES_HPP
      41             : #define VESIN_TYPES_HPP
      42             : 
      43             : #ifndef VESIN_MATH_HPP
      44             : #define VESIN_MATH_HPP
      45             : 
      46             : #include <array>
      47             : #include <cmath>
      48             : #include <stdexcept>
      49             : 
      50             : namespace PLMD {
      51             : namespace metatensor {
      52             : namespace vesin {
      53             : struct Vector;
      54             : 
      55             : Vector operator*(Vector vector, double scalar);
      56             : 
      57             : struct Vector: public std::array<double, 3> {
      58             :     double dot(Vector other) const {
      59           0 :         return (*this)[0] * other[0] + (*this)[1] * other[1] + (*this)[2] * other[2];
      60             :     }
      61             : 
      62           0 :     double norm() const {
      63           0 :         return std::sqrt(this->dot(*this));
      64             :     }
      65             : 
      66           0 :     Vector normalize() const {
      67           0 :         return *this * (1.0 / this->norm());
      68             :     }
      69             : 
      70             :     Vector cross(Vector other) const {
      71             :         return Vector{
      72           0 :             (*this)[1] * other[2] - (*this)[2] * other[1],
      73           0 :             (*this)[2] * other[0] - (*this)[0] * other[2],
      74           0 :             (*this)[0] * other[1] - (*this)[1] * other[0],
      75           0 :         };
      76             :     }
      77             : };
      78             : 
      79             : inline Vector operator+(Vector u, Vector v) {
      80             :     return Vector{
      81           0 :         u[0] + v[0],
      82           0 :         u[1] + v[1],
      83           0 :         u[2] + v[2],
      84             :     };
      85             : }
      86             : 
      87             : inline Vector operator-(Vector u, Vector v) {
      88             :     return Vector{
      89           0 :         u[0] - v[0],
      90           0 :         u[1] - v[1],
      91           0 :         u[2] - v[2],
      92             :     };
      93             : }
      94             : 
      95             : inline Vector operator*(double scalar, Vector vector) {
      96             :     return Vector{
      97             :         scalar * vector[0],
      98             :         scalar * vector[1],
      99             :         scalar * vector[2],
     100             :     };
     101             : }
     102             : 
     103             : inline Vector operator*(Vector vector, double scalar) {
     104             :     return Vector{
     105           0 :         scalar * vector[0],
     106           0 :         scalar * vector[1],
     107           0 :         scalar * vector[2],
     108           0 :     };
     109             : }
     110             : 
     111             : 
     112             : struct Matrix: public std::array<std::array<double, 3>, 3> {
     113           0 :     double determinant() const {
     114           0 :         return (*this)[0][0] * ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2])
     115           0 :              - (*this)[0][1] * ((*this)[1][0] * (*this)[2][2] - (*this)[1][2] * (*this)[2][0])
     116           0 :              + (*this)[0][2] * ((*this)[1][0] * (*this)[2][1] - (*this)[1][1] * (*this)[2][0]);
     117             :     }
     118             : 
     119           0 :     Matrix inverse() const {
     120           0 :         auto det = this->determinant();
     121             : 
     122           0 :         if (std::abs(det) < 1e-30) {
     123           0 :             throw std::runtime_error("this matrix is not invertible");
     124             :         }
     125             : 
     126             :         auto inverse = Matrix();
     127           0 :         inverse[0][0] = ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) / det;
     128           0 :         inverse[0][1] = ((*this)[0][2] * (*this)[2][1] - (*this)[0][1] * (*this)[2][2]) / det;
     129           0 :         inverse[0][2] = ((*this)[0][1] * (*this)[1][2] - (*this)[0][2] * (*this)[1][1]) / det;
     130           0 :         inverse[1][0] = ((*this)[1][2] * (*this)[2][0] - (*this)[1][0] * (*this)[2][2]) / det;
     131           0 :         inverse[1][1] = ((*this)[0][0] * (*this)[2][2] - (*this)[0][2] * (*this)[2][0]) / det;
     132           0 :         inverse[1][2] = ((*this)[1][0] * (*this)[0][2] - (*this)[0][0] * (*this)[1][2]) / det;
     133           0 :         inverse[2][0] = ((*this)[1][0] * (*this)[2][1] - (*this)[2][0] * (*this)[1][1]) / det;
     134           0 :         inverse[2][1] = ((*this)[2][0] * (*this)[0][1] - (*this)[0][0] * (*this)[2][1]) / det;
     135           0 :         inverse[2][2] = ((*this)[0][0] * (*this)[1][1] - (*this)[1][0] * (*this)[0][1]) / det;
     136           0 :         return inverse;
     137             :     }
     138             : };
     139             : 
     140             : 
     141             : inline Vector operator*(Matrix matrix, Vector vector) {
     142             :     return Vector{
     143             :         matrix[0][0] * vector[0] + matrix[0][1] * vector[1] + matrix[0][2] * vector[2],
     144             :         matrix[1][0] * vector[0] + matrix[1][1] * vector[1] + matrix[1][2] * vector[2],
     145             :         matrix[2][0] * vector[0] + matrix[2][1] * vector[1] + matrix[2][2] * vector[2],
     146             :     };
     147             : }
     148             : 
     149           0 : inline Vector operator*(Vector vector, Matrix matrix) {
     150             :     return Vector{
     151           0 :         vector[0] * matrix[0][0] + vector[1] * matrix[1][0] + vector[2] * matrix[2][0],
     152           0 :         vector[0] * matrix[0][1] + vector[1] * matrix[1][1] + vector[2] * matrix[2][1],
     153           0 :         vector[0] * matrix[0][2] + vector[1] * matrix[1][2] + vector[2] * matrix[2][2],
     154           0 :     };
     155             : }
     156             : 
     157             : } // namespace vesin
     158             : } // namespace metatensor
     159             : } // namespace PLMD
     160             : 
     161             : #endif
     162             : 
     163             : namespace PLMD {
     164             : namespace metatensor {
     165             : namespace vesin {
     166             : 
     167             : class BoundingBox {
     168             : public:
     169           0 :     BoundingBox(Matrix matrix, bool periodic): matrix_(matrix), periodic_(periodic) {
     170           0 :         if (periodic) {
     171           0 :             this->inverse_ = matrix_.inverse();
     172             :         } else {
     173           0 :             this->matrix_ = Matrix{{{
     174             :                 {{1, 0, 0}},
     175             :                 {{0, 1, 0}},
     176             :                 {{0, 0, 1}}
     177             :             }}};
     178           0 :             this->inverse_ = matrix_;
     179             :         }
     180           0 :     }
     181             : 
     182             :     const Matrix& matrix() const {
     183             :         return this->matrix_;
     184             :     }
     185             : 
     186             :     bool periodic() const {
     187           0 :         return this->periodic_;
     188             :     }
     189             : 
     190             :     /// Convert a vector from cartesian coordinates to fractional coordinates
     191             :     Vector cartesian_to_fractional(Vector cartesian) const {
     192           0 :         return cartesian * inverse_;
     193             :     }
     194             : 
     195             :     /// Convert a vector from fractional coordinates to cartesian coordinates
     196             :     Vector fractional_to_cartesian(Vector fractional) const {
     197             :         return fractional * matrix_;
     198             :     }
     199             : 
     200             :     /// Get the three distances between faces of the bounding box
     201           0 :     Vector distances_between_faces() const {
     202           0 :         auto a = Vector{matrix_[0]};
     203           0 :         auto b = Vector{matrix_[1]};
     204           0 :         auto c = Vector{matrix_[2]};
     205             : 
     206             :         // Plans normal vectors
     207           0 :         auto na = b.cross(c).normalize();
     208           0 :         auto nb = c.cross(a).normalize();
     209           0 :         auto nc = a.cross(b).normalize();
     210             : 
     211             :         return Vector{
     212             :             std::abs(na.dot(a)),
     213             :             std::abs(nb.dot(b)),
     214             :             std::abs(nc.dot(c)),
     215           0 :         };
     216             :     }
     217             : 
     218             : private:
     219             :     Matrix matrix_;
     220             :     Matrix inverse_;
     221             :     bool periodic_;
     222             : };
     223             : 
     224             : 
     225             : /// A cell shift represents the displacement along cell axis between the actual
     226             : /// position of an atom and a periodic image of this atom.
     227             : ///
     228             : /// The cell shift can be used to reconstruct the vector between two points,
     229             : /// wrapped inside the unit cell.
     230             : struct CellShift: public std::array<int32_t, 3> {
     231             :     /// Compute the shift vector in cartesian coordinates, using the given cell
     232             :     /// matrix (stored in row major order).
     233           0 :     Vector cartesian(Matrix cell) const {
     234             :         auto vector = Vector{
     235           0 :             static_cast<double>((*this)[0]),
     236           0 :             static_cast<double>((*this)[1]),
     237           0 :             static_cast<double>((*this)[2]),
     238           0 :         };
     239           0 :         return vector * cell;
     240             :     }
     241             : };
     242             : 
     243             : inline CellShift operator+(CellShift a, CellShift b) {
     244             :     return CellShift{
     245           0 :         a[0] + b[0],
     246           0 :         a[1] + b[1],
     247           0 :         a[2] + b[2],
     248             :     };
     249             : }
     250             : 
     251             : inline CellShift operator-(CellShift a, CellShift b) {
     252             :     return CellShift{
     253           0 :         a[0] - b[0],
     254           0 :         a[1] - b[1],
     255           0 :         a[2] - b[2],
     256             :     };
     257             : }
     258             : 
     259             : 
     260             : } // namespace vesin
     261             : } // namespace metatensor
     262             : } // namespace PLMD
     263             : 
     264             : #endif
     265             : 
     266             : namespace PLMD {
     267             : namespace metatensor {
     268             : namespace vesin { namespace cpu {
     269             : 
     270             : void free_neighbors(VesinNeighborList& neighbors);
     271             : 
     272             : void neighbors(
     273             :     const Vector* points,
     274             :     size_t n_points,
     275             :     BoundingBox cell,
     276             :     VesinOptions options,
     277             :     VesinNeighborList& neighbors
     278             : );
     279             : 
     280             : 
     281             : /// The cell list is used to sort atoms inside bins/cells.
     282             : ///
     283             : /// The list of potential pairs is then constructed by looking through all
     284             : /// neighboring cells (the number of cells to search depends on the cutoff and
     285             : /// the size of the cells) for each atom to create pair candidates.
     286           0 : class CellList {
     287             : public:
     288             :     /// Create a new `CellList` for the given bounding box and cutoff,
     289             :     /// determining all required parameters.
     290             :     CellList(BoundingBox box, double cutoff);
     291             : 
     292             :     /// Add a single point to the cell list at the given `position`. The point
     293             :     /// is uniquely identified by its `index`.
     294             :     void add_point(size_t index, Vector position);
     295             : 
     296             :     /// Iterate over all possible pairs, calling the given callback every time
     297             :     template <typename Function>
     298             :     void foreach_pair(Function callback);
     299             : 
     300             : private:
     301             :     /// How many cells do we need to look at when searching neighbors to include
     302             :     /// all neighbors below cutoff
     303             :     std::array<int32_t, 3> n_search_;
     304             : 
     305             :     /// the cells themselves are a list of points & corresponding
     306             :     /// shift to place the point inside the cell
     307             :     struct Point {
     308             :         size_t index;
     309             :         CellShift shift;
     310             :     };
     311           0 :     struct Cell: public std::vector<Point> {};
     312             : 
     313             :     // raw data for the cells
     314             :     std::vector<Cell> cells_;
     315             :     // shape of the cell array
     316             :     std::array<size_t, 3> cells_shape_;
     317             : 
     318             :     BoundingBox box_;
     319             : 
     320             :     Cell& get_cell(std::array<int32_t, 3> index);
     321             : };
     322             : 
     323             : /// Wrapper around `VesinNeighborList` that behaves like a std::vector,
     324             : /// automatically growing memory allocations.
     325             : class GrowableNeighborList {
     326             : public:
     327             :     VesinNeighborList& neighbors;
     328             :     size_t capacity;
     329             :     VesinOptions options;
     330             : 
     331             :     size_t length() const {
     332           0 :         return neighbors.length;
     333             :     }
     334             : 
     335             :     void increment_length() {
     336           0 :         neighbors.length += 1;
     337           0 :     }
     338             : 
     339             :     void set_pair(size_t index, size_t first, size_t second);
     340             :     void set_shift(size_t index, PLMD::metatensor::vesin::CellShift shift);
     341             :     void set_distance(size_t index, double distance);
     342             :     void set_vector(size_t index, PLMD::metatensor::vesin::Vector vector);
     343             : 
     344             :     // reset length to 0, and allocate/deallocate members of
     345             :     // `neighbors` according to `options`
     346             :     void reset();
     347             : 
     348             :     // allocate more memory & update capacity
     349             :     void grow();
     350             : };
     351             : 
     352             : } // namespace vesin
     353             : } // namespace metatensor
     354             : } // namespace PLMD
     355             : } // namespace cpu
     356             : 
     357             : #endif
     358             : 
     359             : using namespace PLMD::metatensor::vesin;
     360             : using namespace PLMD::metatensor::vesin::cpu;
     361             : 
     362           0 : void PLMD::metatensor::vesin::cpu::neighbors(
     363             :     const Vector* points,
     364             :     size_t n_points,
     365             :     BoundingBox cell,
     366             :     VesinOptions options,
     367             :     VesinNeighborList& raw_neighbors
     368             : ) {
     369           0 :     auto cell_list = CellList(cell, options.cutoff);
     370             : 
     371           0 :     for (size_t i=0; i<n_points; i++) {
     372           0 :         cell_list.add_point(i, points[i]);
     373             :     }
     374             : 
     375           0 :     auto cell_matrix = cell.matrix();
     376           0 :     auto cutoff2 = options.cutoff * options.cutoff;
     377             : 
     378             :     // the cell list creates too many pairs, we only need to keep the
     379             :     // one where the distance is actually below the cutoff
     380           0 :     auto neighbors = GrowableNeighborList{raw_neighbors, raw_neighbors.length, options};
     381           0 :     neighbors.reset();
     382             : 
     383           0 :     cell_list.foreach_pair([&](size_t first, size_t second, CellShift shift) {
     384           0 :         if (!options.full) {
     385             :             // filter out some pairs for half neighbor lists
     386           0 :             if (first > second) {
     387             :                 return;
     388             :             }
     389             : 
     390           0 :             if (first == second) {
     391             :                 // When creating pairs between a point and one of its periodic
     392             :                 // images, the code generate multiple redundant pairs (e.g. with
     393             :                 // shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
     394             :                 // these.
     395           0 :                 if (shift[0] + shift[1] + shift[2] < 0) {
     396             :                     // drop shifts on the negative half-space
     397             :                     return;
     398             :                 }
     399             : 
     400             :                 if ((shift[0] + shift[1] + shift[2] == 0)
     401           0 :                     && (shift[2] < 0 || (shift[2] == 0 && shift[1] < 0))) {
     402             :                     // drop shifts in the negative half plane or the negative
     403             :                     // shift[1] axis. See below for a graphical representation:
     404             :                     // we are keeping the shifts indicated with `O` and dropping
     405             :                     // the ones indicated with `X`
     406             :                     //
     407             :                     //  O O O │ O O O
     408             :                     //  O O O │ O O O
     409             :                     //  O O O │ O O O
     410             :                     // ─X─X─X─┼─O─O─O─
     411             :                     //  X X X │ X X X
     412             :                     //  X X X │ X X X
     413             :                     //  X X X │ X X X
     414             :                     return;
     415             :                 }
     416             :             }
     417             :         }
     418             : 
     419           0 :         auto vector = points[second] - points[first] + shift.cartesian(cell_matrix);
     420             :         auto distance2 = vector.dot(vector);
     421             : 
     422           0 :         if (distance2 < cutoff2) {
     423           0 :             auto index = neighbors.length();
     424           0 :             neighbors.set_pair(index, first, second);
     425             : 
     426           0 :             if (options.return_shifts) {
     427           0 :                 neighbors.set_shift(index, shift);
     428             :             }
     429             : 
     430           0 :             if (options.return_distances) {
     431           0 :                 neighbors.set_distance(index, std::sqrt(distance2));
     432             :             }
     433             : 
     434           0 :             if (options.return_vectors) {
     435           0 :                 neighbors.set_vector(index, vector);
     436             :             }
     437             : 
     438             :             neighbors.increment_length();
     439             :         }
     440             :     });
     441           0 : }
     442             : 
     443             : /* ========================================================================== */
     444             : 
     445             : /// Maximal number of cells, we need to use this to prevent having too many
     446             : /// cells with a small bounding box and a large cutoff
     447             : #define MAX_NUMBER_OF_CELLS 1e5
     448             : 
     449             : 
     450             : /// Function to compute both quotient and remainder of the division of a by b.
     451             : /// This function follows Python convention, making sure the remainder have the
     452             : /// same sign as `b`.
     453           0 : static std::tuple<int32_t, int32_t> divmod(int32_t a, size_t b) {
     454             :     assert(b < (std::numeric_limits<int32_t>::max()));
     455           0 :     auto b_32 = static_cast<int32_t>(b);
     456           0 :     auto quotient = a / b_32;
     457           0 :     auto remainder = a % b_32;
     458           0 :     if (remainder < 0) {
     459           0 :         remainder += b_32;
     460           0 :         quotient -= 1;
     461             :     }
     462           0 :     return std::make_tuple(quotient, remainder);
     463             : }
     464             : 
     465             : /// Apply the `divmod` function to three components at the time
     466             : static std::tuple<std::array<int32_t, 3>, std::array<int32_t, 3>>
     467           0 : divmod(std::array<int32_t, 3> a, std::array<size_t, 3> b) {
     468           0 :     auto [qx, rx] = divmod(a[0], b[0]);
     469           0 :     auto [qy, ry] = divmod(a[1], b[1]);
     470           0 :     auto [qz, rz] = divmod(a[2], b[2]);
     471             :     return std::make_tuple(
     472           0 :         std::array<int32_t, 3>{qx, qy, qz},
     473           0 :         std::array<int32_t, 3>{rx, ry, rz}
     474           0 :     );
     475             : }
     476             : 
     477           0 : CellList::CellList(BoundingBox box, double cutoff):
     478           0 :     n_search_({0, 0, 0}),
     479           0 :     cells_shape_({0, 0, 0}),
     480           0 :     box_(box)
     481             : {
     482           0 :     auto distances_between_faces = box_.distances_between_faces();
     483             : 
     484             :     auto n_cells = Vector{
     485           0 :         std::clamp(std::trunc(distances_between_faces[0] / cutoff), 1.0, HUGE_VAL),
     486           0 :         std::clamp(std::trunc(distances_between_faces[1] / cutoff), 1.0, HUGE_VAL),
     487           0 :         std::clamp(std::trunc(distances_between_faces[2] / cutoff), 1.0, HUGE_VAL),
     488           0 :     };
     489             : 
     490             :     assert(std::isfinite(n_cells[0]) && std::isfinite(n_cells[1]) && std::isfinite(n_cells[2]));
     491             : 
     492             :     // limit memory consumption by ensuring we have less than `MAX_N_CELLS`
     493             :     // cells to look though
     494           0 :     auto n_cells_total = n_cells[0] * n_cells[1] * n_cells[2];
     495           0 :     if (n_cells_total > MAX_NUMBER_OF_CELLS) {
     496             :         // set the total number of cells close to MAX_N_CELLS, while keeping
     497             :         // roughly the ratio of cells in each direction
     498           0 :         auto ratio_x_y = n_cells[0] / n_cells[1];
     499           0 :         auto ratio_y_z = n_cells[1] / n_cells[2];
     500             : 
     501           0 :         n_cells[2] = std::trunc(std::cbrt(MAX_NUMBER_OF_CELLS / (ratio_x_y * ratio_y_z * ratio_y_z)));
     502           0 :         n_cells[1] = std::trunc(ratio_y_z * n_cells[2]);
     503           0 :         n_cells[0] = std::trunc(ratio_x_y * n_cells[1]);
     504             :     }
     505             : 
     506             :     // number of cells to search in each direction to make sure all possible
     507             :     // pairs below the cutoff are accounted for.
     508           0 :     this->n_search_ = std::array<int32_t, 3>{
     509           0 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[0] / distances_between_faces[0])),
     510           0 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[1] / distances_between_faces[1])),
     511           0 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[2] / distances_between_faces[2])),
     512             :     };
     513             : 
     514           0 :     this->cells_shape_ = std::array<size_t, 3>{
     515           0 :         static_cast<size_t>(n_cells[0]),
     516           0 :         static_cast<size_t>(n_cells[1]),
     517           0 :         static_cast<size_t>(n_cells[2]),
     518             :     };
     519             : 
     520           0 :     for (size_t spatial=0; spatial<3; spatial++) {
     521           0 :         if (n_search_[spatial] < 1) {
     522           0 :             n_search_[spatial] = 1;
     523             :         }
     524             : 
     525             :         // don't look for neighboring cells if we have only one cell and no
     526             :         // periodic boundary condition
     527           0 :         if (n_cells[spatial] == 1 && !box.periodic()) {
     528           0 :             n_search_[spatial] = 0;
     529             :         }
     530             :     }
     531             : 
     532           0 :     this->cells_.resize(cells_shape_[0] * cells_shape_[1] * cells_shape_[2]);
     533           0 : }
     534             : 
     535           0 : void CellList::add_point(size_t index, Vector position) {
     536           0 :     auto fractional = box_.cartesian_to_fractional(position);
     537             : 
     538             :     // find the cell in which this atom should go
     539             :     auto cell_index = std::array<int32_t, 3>{
     540           0 :         static_cast<int32_t>(std::floor(fractional[0] * static_cast<double>(cells_shape_[0]))),
     541           0 :         static_cast<int32_t>(std::floor(fractional[1] * static_cast<double>(cells_shape_[1]))),
     542           0 :         static_cast<int32_t>(std::floor(fractional[2] * static_cast<double>(cells_shape_[2]))),
     543           0 :     };
     544             : 
     545             :     // deal with pbc by wrapping the atom inside if it was outside of the
     546             :     // cell
     547             :     CellShift shift;
     548             :     // auto (shift, cell_index) =
     549           0 :     if (box_.periodic()) {
     550           0 :         auto result = divmod(cell_index, cells_shape_);
     551           0 :         shift = CellShift{std::get<0>(result)};
     552           0 :         cell_index = std::get<1>(result);
     553             :     } else {
     554             :         shift = CellShift({0, 0, 0});
     555           0 :         cell_index = std::array<int32_t, 3>{
     556           0 :             std::clamp(cell_index[0], 0, static_cast<int32_t>(cells_shape_[0] - 1)),
     557           0 :             std::clamp(cell_index[1], 0, static_cast<int32_t>(cells_shape_[1] - 1)),
     558           0 :             std::clamp(cell_index[2], 0, static_cast<int32_t>(cells_shape_[2] - 1)),
     559             :         };
     560             :     }
     561             : 
     562           0 :     this->get_cell(cell_index).emplace_back(Point{index, shift});
     563           0 : }
     564             : 
     565             : 
     566             : template <typename Function>
     567           0 : void CellList::foreach_pair(Function callback) {
     568           0 :     for (int32_t cell_i_x=0; cell_i_x<static_cast<int32_t>(cells_shape_[0]); cell_i_x++) {
     569           0 :     for (int32_t cell_i_y=0; cell_i_y<static_cast<int32_t>(cells_shape_[1]); cell_i_y++) {
     570           0 :     for (int32_t cell_i_z=0; cell_i_z<static_cast<int32_t>(cells_shape_[2]); cell_i_z++) {
     571           0 :         const auto& current_cell = this->get_cell({cell_i_x, cell_i_y, cell_i_z});
     572             :         // look through each neighboring cell
     573           0 :         for (int32_t delta_x=-n_search_[0]; delta_x<=n_search_[0]; delta_x++) {
     574           0 :         for (int32_t delta_y=-n_search_[1]; delta_y<=n_search_[1]; delta_y++) {
     575           0 :         for (int32_t delta_z=-n_search_[2]; delta_z<=n_search_[2]; delta_z++) {
     576           0 :             auto cell_i = std::array<int32_t, 3>{
     577           0 :                 cell_i_x + delta_x,
     578           0 :                 cell_i_y + delta_y,
     579           0 :                 cell_i_z + delta_z,
     580             :             };
     581             : 
     582             :             // shift vector from one cell to the other and index of
     583             :             // the neighboring cell
     584           0 :             auto [cell_shift, neighbor_cell_i] = divmod(cell_i, cells_shape_);
     585             : 
     586           0 :             for (const auto& atom_i: current_cell) {
     587           0 :                 for (const auto& atom_j: this->get_cell(neighbor_cell_i)) {
     588           0 :                     auto shift = CellShift{cell_shift} + atom_i.shift - atom_j.shift;
     589           0 :                     auto shift_is_zero = shift[0] == 0 && shift[1] == 0 && shift[2] == 0;
     590             : 
     591           0 :                     if (!box_.periodic() && !shift_is_zero) {
     592             :                         // do not create pairs crossing the periodic
     593             :                         // boundaries in a non-periodic box
     594           0 :                         continue;
     595             :                     }
     596             : 
     597           0 :                     if (atom_i.index == atom_j.index && shift_is_zero) {
     598             :                         // only create pairs with the same atom twice if the
     599             :                         // pair spans more than one bounding box
     600           0 :                         continue;
     601             :                     }
     602             : 
     603           0 :                     callback(atom_i.index, atom_j.index, shift);
     604             :                 }
     605             :             } // loop over atoms in current neighbor cells
     606             :         }}}
     607             :     }}} // loop over neighboring cells
     608           0 : }
     609             : 
     610           0 : CellList::Cell& CellList::get_cell(std::array<int32_t, 3> index) {
     611           0 :     size_t linear_index = (cells_shape_[0] * cells_shape_[1] * index[2])
     612           0 :                         + (cells_shape_[0] * index[1])
     613           0 :                         + index[0];
     614           0 :     return cells_[linear_index];
     615             : }
     616             : 
     617             : /* ========================================================================== */
     618             : 
     619             : 
     620           0 : void GrowableNeighborList::set_pair(size_t index, size_t first, size_t second) {
     621           0 :     if (index >= this->capacity) {
     622           0 :         this->grow();
     623             :     }
     624             : 
     625           0 :     this->neighbors.pairs[index][0] = first;
     626           0 :     this->neighbors.pairs[index][1] = second;
     627           0 : }
     628             : 
     629           0 : void GrowableNeighborList::set_shift(size_t index, PLMD::metatensor::vesin::CellShift shift) {
     630           0 :     if (index >= this->capacity) {
     631           0 :         this->grow();
     632             :     }
     633             : 
     634           0 :     this->neighbors.shifts[index][0] = shift[0];
     635           0 :     this->neighbors.shifts[index][1] = shift[1];
     636           0 :     this->neighbors.shifts[index][2] = shift[2];
     637           0 : }
     638             : 
     639           0 : void GrowableNeighborList::set_distance(size_t index, double distance) {
     640           0 :     if (index >= this->capacity) {
     641           0 :         this->grow();
     642             :     }
     643             : 
     644           0 :     this->neighbors.distances[index] = distance;
     645           0 : }
     646             : 
     647           0 : void GrowableNeighborList::set_vector(size_t index, PLMD::metatensor::vesin::Vector vector) {
     648           0 :     if (index >= this->capacity) {
     649           0 :         this->grow();
     650             :     }
     651             : 
     652           0 :     this->neighbors.vectors[index][0] = vector[0];
     653           0 :     this->neighbors.vectors[index][1] = vector[1];
     654           0 :     this->neighbors.vectors[index][2] = vector[2];
     655           0 : }
     656             : 
     657             : template <typename scalar_t, size_t N>
     658           0 : static scalar_t (*alloc(scalar_t (*ptr)[N], size_t size, size_t new_size))[N] {
     659           0 :     auto* new_ptr = reinterpret_cast<scalar_t (*)[N]>(std::realloc(ptr, new_size * sizeof(scalar_t[N])));
     660             : 
     661           0 :     if (new_ptr == nullptr) {
     662           0 :         throw std::bad_alloc();
     663             :     }
     664             : 
     665             :     // initialize with a bit pattern that maps to NaN for double
     666           0 :     std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t[N]));
     667             : 
     668           0 :     return new_ptr;
     669             : }
     670             : 
     671             : template <typename scalar_t>
     672           0 : static scalar_t* alloc(scalar_t* ptr, size_t size, size_t new_size) {
     673           0 :     auto* new_ptr = reinterpret_cast<scalar_t*>(std::realloc(ptr, new_size * sizeof(scalar_t)));
     674             : 
     675           0 :     if (new_ptr == nullptr) {
     676           0 :         throw std::bad_alloc();
     677             :     }
     678             : 
     679             :     // initialize with a bit pattern that maps to NaN for double
     680           0 :     std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t));
     681             : 
     682           0 :     return new_ptr;
     683             : }
     684             : 
     685           0 : void GrowableNeighborList::grow() {
     686           0 :     auto new_size = neighbors.length * 2;
     687             :     if (new_size == 0) {
     688             :         new_size = 1;
     689             :     }
     690             : 
     691           0 :     auto* new_pairs = alloc<size_t, 2>(neighbors.pairs, neighbors.length, new_size);
     692             : 
     693             :     int32_t (*new_shifts)[3] = nullptr;
     694           0 :     if (options.return_shifts) {
     695           0 :         new_shifts = alloc<int32_t, 3>(neighbors.shifts, neighbors.length, new_size);
     696             :     }
     697             : 
     698             :     double *new_distances = nullptr;
     699           0 :     if (options.return_distances) {
     700           0 :         new_distances = alloc<double>(neighbors.distances, neighbors.length, new_size);
     701             :     }
     702             : 
     703             :     double (*new_vectors)[3] = nullptr;
     704           0 :     if (options.return_vectors) {
     705           0 :         new_vectors = alloc<double, 3>(neighbors.vectors, neighbors.length, new_size);
     706             :     }
     707             : 
     708           0 :     this->neighbors.pairs = new_pairs;
     709           0 :     this->neighbors.shifts = new_shifts;
     710           0 :     this->neighbors.distances = new_distances;
     711           0 :     this->neighbors.vectors = new_vectors;
     712             : 
     713           0 :     this->capacity = new_size;
     714           0 : }
     715             : 
     716           0 : void GrowableNeighborList::reset() {
     717             :     // set all allocated data to zero
     718           0 :     auto size = this->neighbors.length;
     719           0 :     std::memset(this->neighbors.pairs, 0, size * sizeof(size_t[2]));
     720             : 
     721           0 :     if (this->neighbors.shifts != nullptr) {
     722           0 :         std::memset(this->neighbors.shifts, 0, size * sizeof(int32_t[3]));
     723             :     }
     724             : 
     725           0 :     if (this->neighbors.distances != nullptr) {
     726           0 :         std::memset(this->neighbors.distances, 0, size * sizeof(double));
     727             :     }
     728             : 
     729           0 :     if (this->neighbors.vectors != nullptr) {
     730           0 :         std::memset(this->neighbors.vectors, 0, size * sizeof(double[3]));
     731             :     }
     732             : 
     733             :     // reset length (but keep the capacity where it's at)
     734           0 :     this->neighbors.length = 0;
     735             : 
     736             :     // allocate/deallocate pointers as required
     737           0 :     auto* shifts = this->neighbors.shifts;
     738           0 :     if (this->options.return_shifts && shifts == nullptr) {
     739           0 :         shifts = alloc<int32_t, 3>(shifts, 0, capacity);
     740           0 :     } else if (!this->options.return_shifts && shifts != nullptr) {
     741           0 :         std::free(shifts);
     742             :         shifts = nullptr;
     743             :     }
     744             : 
     745           0 :     auto* distances = this->neighbors.distances;
     746           0 :     if (this->options.return_distances && distances == nullptr) {
     747           0 :         distances = alloc<double>(distances, 0, capacity);
     748           0 :     } else if (!this->options.return_distances && distances != nullptr) {
     749           0 :         std::free(distances);
     750             :         distances = nullptr;
     751             :     }
     752             : 
     753           0 :     auto* vectors = this->neighbors.vectors;
     754           0 :     if (this->options.return_vectors && vectors == nullptr) {
     755           0 :         vectors = alloc<double, 3>(vectors, 0, capacity);
     756           0 :     } else if (!this->options.return_vectors && vectors != nullptr) {
     757           0 :         std::free(vectors);
     758             :         vectors = nullptr;
     759             :     }
     760             : 
     761           0 :     this->neighbors.shifts = shifts;
     762           0 :     this->neighbors.distances = distances;
     763           0 :     this->neighbors.vectors = vectors;
     764           0 : }
     765             : 
     766             : 
     767           0 : void PLMD::metatensor::vesin::cpu::free_neighbors(VesinNeighborList& neighbors) {
     768             :     assert(neighbors.device == VesinCPU);
     769             : 
     770           0 :     std::free(neighbors.pairs);
     771           0 :     std::free(neighbors.shifts);
     772           0 :     std::free(neighbors.vectors);
     773           0 :     std::free(neighbors.distances);
     774           0 : }
     775             : #include <cstring>
     776             : #include <string>
     777             : 
     778             : 
     779             : 
     780             : thread_local std::string LAST_ERROR;
     781             : 
     782           0 : extern "C" int vesin_neighbors(
     783             :     const double (*points)[3],
     784             :     size_t n_points,
     785             :     const double box[3][3],
     786             :     bool periodic,
     787             :     VesinDevice device,
     788             :     VesinOptions options,
     789             :     VesinNeighborList* neighbors,
     790             :     const char** error_message
     791             : ) {
     792           0 :     if (error_message == nullptr) {
     793             :         return EXIT_FAILURE;
     794             :     }
     795             : 
     796           0 :     if (points == nullptr) {
     797           0 :         *error_message = "`points` can not be a NULL pointer";
     798           0 :         return EXIT_FAILURE;
     799             :     }
     800             : 
     801           0 :     if (box == nullptr) {
     802           0 :         *error_message = "`cell` can not be a NULL pointer";
     803           0 :         return EXIT_FAILURE;
     804             :     }
     805             : 
     806           0 :     if (neighbors == nullptr) {
     807           0 :         *error_message = "`neighbors` can not be a NULL pointer";
     808           0 :         return EXIT_FAILURE;
     809             :     }
     810             : 
     811           0 :     if (neighbors->device != VesinUnknownDevice && neighbors->device != device) {
     812           0 :         *error_message = "`neighbors` device and data `device` do not match, free the neighbors first";
     813           0 :         return EXIT_FAILURE;
     814             :     }
     815             : 
     816           0 :     if (device == VesinUnknownDevice) {
     817           0 :         *error_message = "got an unknown device to use when running simulation";
     818           0 :         return EXIT_FAILURE;
     819             :     }
     820             : 
     821           0 :     if (neighbors->device == VesinUnknownDevice) {
     822             :         // initialize the device
     823           0 :         neighbors->device = device;
     824           0 :     } else if (neighbors->device != device) {
     825           0 :         *error_message = "`neighbors.device` and `device` do not match, free the neighbors first";
     826           0 :         return EXIT_FAILURE;
     827             :     }
     828             : 
     829             :     try {
     830           0 :         if (device == VesinCPU) {
     831             :             auto matrix = PLMD::metatensor::vesin::Matrix{{{
     832           0 :                 {{box[0][0], box[0][1], box[0][2]}},
     833           0 :                 {{box[1][0], box[1][1], box[1][2]}},
     834           0 :                 {{box[2][0], box[2][1], box[2][2]}},
     835           0 :             }}};
     836             : 
     837           0 :             PLMD::metatensor::vesin::cpu::neighbors(
     838             :                 reinterpret_cast<const PLMD::metatensor::vesin::Vector*>(points),
     839             :                 n_points,
     840             :                 PLMD::metatensor::vesin::BoundingBox(matrix, periodic),
     841             :                 options,
     842             :                 *neighbors
     843             :             );
     844             :         } else {
     845           0 :             throw std::runtime_error("unknown device " + std::to_string(device));
     846             :         }
     847           0 :     } catch (const std::bad_alloc&) {
     848           0 :         LAST_ERROR = "failed to allocate memory";
     849           0 :         *error_message = LAST_ERROR.c_str();
     850             :         return EXIT_FAILURE;
     851           0 :     } catch (const std::exception& e) {
     852           0 :         LAST_ERROR = e.what();
     853           0 :         *error_message = LAST_ERROR.c_str();
     854             :         return EXIT_FAILURE;
     855           0 :     } catch (...) {
     856           0 :         *error_message = "fatal error: unknown type thrown as exception";
     857             :         return EXIT_FAILURE;
     858           0 :     }
     859             : 
     860           0 :     return EXIT_SUCCESS;
     861             : }
     862             : 
     863             : 
     864           0 : extern "C" void vesin_free(VesinNeighborList* neighbors) {
     865           0 :     if (neighbors == nullptr) {
     866             :         return;
     867             :     }
     868             : 
     869           0 :     if (neighbors->device == VesinUnknownDevice) {
     870             :         // nothing to do
     871           0 :     } else if (neighbors->device == VesinCPU) {
     872           0 :         PLMD::metatensor::vesin::cpu::free_neighbors(*neighbors);
     873             :     }
     874             : 
     875             :     std::memset(neighbors, 0, sizeof(VesinNeighborList));
     876             : }

Generated by: LCOV version 1.16