Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2024 The METATENSOR code team
3 : (see the PEOPLE-METATENSOR file at the root of this folder for a list of names)
4 :
5 : See https://docs.metatensor.org/latest/ for more information about the
6 : metatensor package that this module allows you to call from PLUMED.
7 :
8 : This file is part of METATENSOR-PLUMED module.
9 :
10 : The METATENSOR-PLUMED module is free software: you can redistribute it and/or modify
11 : it under the terms of the GNU Lesser General Public License as published by
12 : the Free Software Foundation, either version 3 of the License, or
13 : (at your option) any later version.
14 :
15 : The METATENSOR-PLUMED module is distributed in the hope that it will be useful,
16 : but WITHOUT ANY WARRANTY; without even the implied warranty of
17 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 : GNU Lesser General Public License for more details.
19 :
20 : You should have received a copy of the GNU Lesser General Public License
21 : along with the METATENSOR-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
22 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
23 : /*INDENT-OFF*/
#include "vesin.h"
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <limits>
#include <new>
#include <tuple>
32 :
33 : #ifndef VESIN_CPU_CELL_LIST_HPP
34 : #define VESIN_CPU_CELL_LIST_HPP
35 :
36 : #include <vector>
37 :
38 : #include "vesin.h"
39 :
40 : #ifndef VESIN_TYPES_HPP
41 : #define VESIN_TYPES_HPP
42 :
43 : #ifndef VESIN_MATH_HPP
44 : #define VESIN_MATH_HPP
45 :
46 : #include <array>
47 : #include <cmath>
48 : #include <stdexcept>
49 :
50 : namespace PLMD {
51 : namespace metatensor {
52 : namespace vesin {
struct Vector;

Vector operator*(Vector vector, double scalar);

/// Three-dimensional vector of doubles, with the small set of linear-algebra
/// helpers needed by the neighbor search.
struct Vector: public std::array<double, 3> {
    /// Scalar (inner) product of this vector with `other`.
    double dot(Vector other) const {
        const auto& self = *this;
        return self[0] * other[0] + self[1] * other[1] + self[2] * other[2];
    }

    /// Euclidean length of this vector.
    double norm() const {
        return std::sqrt(dot(*this));
    }

    /// Unit vector with the same direction as this one. A zero vector gives
    /// non-finite components, since its norm is zero.
    Vector normalize() const {
        const double inverse_norm = 1.0 / norm();
        return *this * inverse_norm;
    }

    /// Cross product `this x other`.
    Vector cross(Vector other) const {
        const auto& self = *this;
        const double x = self[1] * other[2] - self[2] * other[1];
        const double y = self[2] * other[0] - self[0] * other[2];
        const double z = self[0] * other[1] - self[1] * other[0];
        return Vector{x, y, z};
    }
};

/// Component-wise sum `u + v`.
inline Vector operator+(Vector u, Vector v) {
    Vector sum;
    for (std::size_t i = 0; i < sum.size(); ++i) {
        sum[i] = u[i] + v[i];
    }
    return sum;
}

/// Component-wise difference `u - v`.
inline Vector operator-(Vector u, Vector v) {
    Vector difference;
    for (std::size_t i = 0; i < difference.size(); ++i) {
        difference[i] = u[i] - v[i];
    }
    return difference;
}

/// Scale every component of `vector` by `scalar`.
inline Vector operator*(double scalar, Vector vector) {
    Vector scaled;
    for (std::size_t i = 0; i < scaled.size(); ++i) {
        scaled[i] = scalar * vector[i];
    }
    return scaled;
}

/// Scale every component of `vector` by `scalar` (scalar on the right).
inline Vector operator*(Vector vector, double scalar) {
    return scalar * vector;
}
110 :
111 :
/// 3x3 matrix of doubles, stored row by row.
struct Matrix: public std::array<std::array<double, 3>, 3> {
    /// Determinant, expanded along the first row.
    double determinant() const {
        const Matrix& m = *this;
        return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2])
             - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
             + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
    }

    /// Inverse of this matrix, computed from the adjugate divided by the
    /// determinant.
    ///
    /// Throws `std::runtime_error` when the absolute value of the determinant
    /// is below 1e-30.
    Matrix inverse() const {
        const Matrix& m = *this;
        const double det = m.determinant();

        if (std::abs(det) < 1e-30) {
            throw std::runtime_error("this matrix is not invertible");
        }

        Matrix inv{};
        inv[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]) / det;
        inv[0][1] = (m[0][2] * m[2][1] - m[0][1] * m[2][2]) / det;
        inv[0][2] = (m[0][1] * m[1][2] - m[0][2] * m[1][1]) / det;
        inv[1][0] = (m[1][2] * m[2][0] - m[1][0] * m[2][2]) / det;
        inv[1][1] = (m[0][0] * m[2][2] - m[0][2] * m[2][0]) / det;
        inv[1][2] = (m[1][0] * m[0][2] - m[0][0] * m[1][2]) / det;
        inv[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]) / det;
        inv[2][1] = (m[2][0] * m[0][1] - m[0][0] * m[2][1]) / det;
        inv[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]) / det;
        return inv;
    }
};
139 :
140 :
141 : inline Vector operator*(Matrix matrix, Vector vector) {
142 : return Vector{
143 : matrix[0][0] * vector[0] + matrix[0][1] * vector[1] + matrix[0][2] * vector[2],
144 : matrix[1][0] * vector[0] + matrix[1][1] * vector[1] + matrix[1][2] * vector[2],
145 : matrix[2][0] * vector[0] + matrix[2][1] * vector[1] + matrix[2][2] * vector[2],
146 : };
147 : }
148 :
149 0 : inline Vector operator*(Vector vector, Matrix matrix) {
150 : return Vector{
151 0 : vector[0] * matrix[0][0] + vector[1] * matrix[1][0] + vector[2] * matrix[2][0],
152 0 : vector[0] * matrix[0][1] + vector[1] * matrix[1][1] + vector[2] * matrix[2][1],
153 0 : vector[0] * matrix[0][2] + vector[1] * matrix[1][2] + vector[2] * matrix[2][2],
154 0 : };
155 : }
156 :
157 : } // namespace vesin
158 : } // namespace metatensor
159 : } // namespace PLMD
160 :
161 : #endif
162 :
163 : namespace PLMD {
164 : namespace metatensor {
165 : namespace vesin {
166 :
167 : class BoundingBox {
168 : public:
169 0 : BoundingBox(Matrix matrix, bool periodic): matrix_(matrix), periodic_(periodic) {
170 0 : if (periodic) {
171 0 : this->inverse_ = matrix_.inverse();
172 : } else {
173 0 : this->matrix_ = Matrix{{{
174 : {{1, 0, 0}},
175 : {{0, 1, 0}},
176 : {{0, 0, 1}}
177 : }}};
178 0 : this->inverse_ = matrix_;
179 : }
180 0 : }
181 :
182 : const Matrix& matrix() const {
183 : return this->matrix_;
184 : }
185 :
186 : bool periodic() const {
187 0 : return this->periodic_;
188 : }
189 :
190 : /// Convert a vector from cartesian coordinates to fractional coordinates
191 : Vector cartesian_to_fractional(Vector cartesian) const {
192 0 : return cartesian * inverse_;
193 : }
194 :
195 : /// Convert a vector from fractional coordinates to cartesian coordinates
196 : Vector fractional_to_cartesian(Vector fractional) const {
197 : return fractional * matrix_;
198 : }
199 :
200 : /// Get the three distances between faces of the bounding box
201 0 : Vector distances_between_faces() const {
202 0 : auto a = Vector{matrix_[0]};
203 0 : auto b = Vector{matrix_[1]};
204 0 : auto c = Vector{matrix_[2]};
205 :
206 : // Plans normal vectors
207 0 : auto na = b.cross(c).normalize();
208 0 : auto nb = c.cross(a).normalize();
209 0 : auto nc = a.cross(b).normalize();
210 :
211 : return Vector{
212 : std::abs(na.dot(a)),
213 : std::abs(nb.dot(b)),
214 : std::abs(nc.dot(c)),
215 0 : };
216 : }
217 :
218 : private:
219 : Matrix matrix_;
220 : Matrix inverse_;
221 : bool periodic_;
222 : };
223 :
224 :
225 : /// A cell shift represents the displacement along cell axis between the actual
226 : /// position of an atom and a periodic image of this atom.
227 : ///
228 : /// The cell shift can be used to reconstruct the vector between two points,
229 : /// wrapped inside the unit cell.
230 : struct CellShift: public std::array<int32_t, 3> {
231 : /// Compute the shift vector in cartesian coordinates, using the given cell
232 : /// matrix (stored in row major order).
233 0 : Vector cartesian(Matrix cell) const {
234 : auto vector = Vector{
235 0 : static_cast<double>((*this)[0]),
236 0 : static_cast<double>((*this)[1]),
237 0 : static_cast<double>((*this)[2]),
238 0 : };
239 0 : return vector * cell;
240 : }
241 : };
242 :
243 : inline CellShift operator+(CellShift a, CellShift b) {
244 : return CellShift{
245 0 : a[0] + b[0],
246 0 : a[1] + b[1],
247 0 : a[2] + b[2],
248 : };
249 : }
250 :
251 : inline CellShift operator-(CellShift a, CellShift b) {
252 : return CellShift{
253 0 : a[0] - b[0],
254 0 : a[1] - b[1],
255 0 : a[2] - b[2],
256 : };
257 : }
258 :
259 :
260 : } // namespace vesin
261 : } // namespace metatensor
262 : } // namespace PLMD
263 :
264 : #endif
265 :
266 : namespace PLMD {
267 : namespace metatensor {
268 : namespace vesin { namespace cpu {
269 :
270 : void free_neighbors(VesinNeighborList& neighbors);
271 :
272 : void neighbors(
273 : const Vector* points,
274 : size_t n_points,
275 : BoundingBox cell,
276 : VesinOptions options,
277 : VesinNeighborList& neighbors
278 : );
279 :
280 :
281 : /// The cell list is used to sort atoms inside bins/cells.
282 : ///
283 : /// The list of potential pairs is then constructed by looking through all
284 : /// neighboring cells (the number of cells to search depends on the cutoff and
285 : /// the size of the cells) for each atom to create pair candidates.
286 0 : class CellList {
287 : public:
288 : /// Create a new `CellList` for the given bounding box and cutoff,
289 : /// determining all required parameters.
290 : CellList(BoundingBox box, double cutoff);
291 :
292 : /// Add a single point to the cell list at the given `position`. The point
293 : /// is uniquely identified by its `index`.
294 : void add_point(size_t index, Vector position);
295 :
296 : /// Iterate over all possible pairs, calling the given callback every time
297 : template <typename Function>
298 : void foreach_pair(Function callback);
299 :
300 : private:
301 : /// How many cells do we need to look at when searching neighbors to include
302 : /// all neighbors below cutoff
303 : std::array<int32_t, 3> n_search_;
304 :
305 : /// the cells themselves are a list of points & corresponding
306 : /// shift to place the point inside the cell
307 : struct Point {
308 : size_t index;
309 : CellShift shift;
310 : };
311 0 : struct Cell: public std::vector<Point> {};
312 :
313 : // raw data for the cells
314 : std::vector<Cell> cells_;
315 : // shape of the cell array
316 : std::array<size_t, 3> cells_shape_;
317 :
318 : BoundingBox box_;
319 :
320 : Cell& get_cell(std::array<int32_t, 3> index);
321 : };
322 :
323 : /// Wrapper around `VesinNeighborList` that behaves like a std::vector,
324 : /// automatically growing memory allocations.
325 : class GrowableNeighborList {
326 : public:
327 : VesinNeighborList& neighbors;
328 : size_t capacity;
329 : VesinOptions options;
330 :
331 : size_t length() const {
332 0 : return neighbors.length;
333 : }
334 :
335 : void increment_length() {
336 0 : neighbors.length += 1;
337 0 : }
338 :
339 : void set_pair(size_t index, size_t first, size_t second);
340 : void set_shift(size_t index, PLMD::metatensor::vesin::CellShift shift);
341 : void set_distance(size_t index, double distance);
342 : void set_vector(size_t index, PLMD::metatensor::vesin::Vector vector);
343 :
344 : // reset length to 0, and allocate/deallocate members of
345 : // `neighbors` according to `options`
346 : void reset();
347 :
348 : // allocate more memory & update capacity
349 : void grow();
350 : };
351 :
352 : } // namespace cpu
353 : } // namespace vesin
354 : } // namespace metatensor
355 : } // namespace PLMD
356 :
357 : #endif
358 :
359 : using namespace PLMD::metatensor::vesin;
360 : using namespace PLMD::metatensor::vesin::cpu;
361 :
362 0 : void PLMD::metatensor::vesin::cpu::neighbors(
363 : const Vector* points,
364 : size_t n_points,
365 : BoundingBox cell,
366 : VesinOptions options,
367 : VesinNeighborList& raw_neighbors
368 : ) {
369 0 : auto cell_list = CellList(cell, options.cutoff);
370 :
371 0 : for (size_t i=0; i<n_points; i++) {
372 0 : cell_list.add_point(i, points[i]);
373 : }
374 :
375 0 : auto cell_matrix = cell.matrix();
376 0 : auto cutoff2 = options.cutoff * options.cutoff;
377 :
378 : // the cell list creates too many pairs, we only need to keep the
379 : // one where the distance is actually below the cutoff
380 0 : auto neighbors = GrowableNeighborList{raw_neighbors, raw_neighbors.length, options};
381 0 : neighbors.reset();
382 :
383 0 : cell_list.foreach_pair([&](size_t first, size_t second, CellShift shift) {
384 0 : if (!options.full) {
385 : // filter out some pairs for half neighbor lists
386 0 : if (first > second) {
387 : return;
388 : }
389 :
390 0 : if (first == second) {
391 : // When creating pairs between a point and one of its periodic
392 : // images, the code generate multiple redundant pairs (e.g. with
393 : // shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
394 : // these.
395 0 : if (shift[0] + shift[1] + shift[2] < 0) {
396 : // drop shifts on the negative half-space
397 : return;
398 : }
399 :
400 : if ((shift[0] + shift[1] + shift[2] == 0)
401 0 : && (shift[2] < 0 || (shift[2] == 0 && shift[1] < 0))) {
402 : // drop shifts in the negative half plane or the negative
403 : // shift[1] axis. See below for a graphical representation:
404 : // we are keeping the shifts indicated with `O` and dropping
405 : // the ones indicated with `X`
406 : //
407 : // O O O │ O O O
408 : // O O O │ O O O
409 : // O O O │ O O O
410 : // ─X─X─X─┼─O─O─O─
411 : // X X X │ X X X
412 : // X X X │ X X X
413 : // X X X │ X X X
414 : return;
415 : }
416 : }
417 : }
418 :
419 0 : auto vector = points[second] - points[first] + shift.cartesian(cell_matrix);
420 : auto distance2 = vector.dot(vector);
421 :
422 0 : if (distance2 < cutoff2) {
423 0 : auto index = neighbors.length();
424 0 : neighbors.set_pair(index, first, second);
425 :
426 0 : if (options.return_shifts) {
427 0 : neighbors.set_shift(index, shift);
428 : }
429 :
430 0 : if (options.return_distances) {
431 0 : neighbors.set_distance(index, std::sqrt(distance2));
432 : }
433 :
434 0 : if (options.return_vectors) {
435 0 : neighbors.set_vector(index, vector);
436 : }
437 :
438 : neighbors.increment_length();
439 : }
440 : });
441 0 : }
442 :
443 : /* ========================================================================== */
444 :
/// Maximal number of cells in a cell list. This prevents creating too many
/// cells (and using too much memory) with a large bounding box and a small
/// cutoff.
///
/// A typed compile-time constant rather than a macro, so it is scoped and
/// type-checked.
static constexpr double MAX_NUMBER_OF_CELLS = 1e5;
448 :
449 :
/// Compute both the quotient and the remainder of the division of `a` by `b`,
/// following Python's `divmod` convention (floor division).
///
/// Since `b` is positive, the remainder is always in `[0, b)`, and
/// `a == quotient * b + remainder`. `b` must be non-zero and representable as
/// an `int32_t`.
static std::tuple<int32_t, int32_t> divmod(int32_t a, size_t b) {
    assert(b > 0);
    assert(b <= static_cast<size_t>(std::numeric_limits<int32_t>::max()));
    const auto b_32 = static_cast<int32_t>(b);
    auto quotient = a / b_32;
    auto remainder = a % b_32;
    if (remainder < 0) {
        // C++ division truncates toward zero; adjust to floor semantics
        remainder += b_32;
        quotient -= 1;
    }
    return std::make_tuple(quotient, remainder);
}

/// Apply the scalar `divmod` to each of the three components, returning the
/// quotients and the remainders.
static std::tuple<std::array<int32_t, 3>, std::array<int32_t, 3>>
divmod(std::array<int32_t, 3> a, std::array<size_t, 3> b) {
    auto quotients = std::array<int32_t, 3>{};
    auto remainders = std::array<int32_t, 3>{};
    for (size_t i = 0; i < 3; i++) {
        std::tie(quotients[i], remainders[i]) = divmod(a[i], b[i]);
    }
    return std::make_tuple(quotients, remainders);
}
476 :
477 0 : CellList::CellList(BoundingBox box, double cutoff):
478 0 : n_search_({0, 0, 0}),
479 0 : cells_shape_({0, 0, 0}),
480 0 : box_(box)
481 : {
482 0 : auto distances_between_faces = box_.distances_between_faces();
483 :
484 : auto n_cells = Vector{
485 0 : std::clamp(std::trunc(distances_between_faces[0] / cutoff), 1.0, HUGE_VAL),
486 0 : std::clamp(std::trunc(distances_between_faces[1] / cutoff), 1.0, HUGE_VAL),
487 0 : std::clamp(std::trunc(distances_between_faces[2] / cutoff), 1.0, HUGE_VAL),
488 0 : };
489 :
490 : assert(std::isfinite(n_cells[0]) && std::isfinite(n_cells[1]) && std::isfinite(n_cells[2]));
491 :
492 : // limit memory consumption by ensuring we have less than `MAX_N_CELLS`
493 : // cells to look though
494 0 : auto n_cells_total = n_cells[0] * n_cells[1] * n_cells[2];
495 0 : if (n_cells_total > MAX_NUMBER_OF_CELLS) {
496 : // set the total number of cells close to MAX_N_CELLS, while keeping
497 : // roughly the ratio of cells in each direction
498 0 : auto ratio_x_y = n_cells[0] / n_cells[1];
499 0 : auto ratio_y_z = n_cells[1] / n_cells[2];
500 :
501 0 : n_cells[2] = std::trunc(std::cbrt(MAX_NUMBER_OF_CELLS / (ratio_x_y * ratio_y_z * ratio_y_z)));
502 0 : n_cells[1] = std::trunc(ratio_y_z * n_cells[2]);
503 0 : n_cells[0] = std::trunc(ratio_x_y * n_cells[1]);
504 : }
505 :
506 : // number of cells to search in each direction to make sure all possible
507 : // pairs below the cutoff are accounted for.
508 0 : this->n_search_ = std::array<int32_t, 3>{
509 0 : static_cast<int32_t>(std::ceil(cutoff * n_cells[0] / distances_between_faces[0])),
510 0 : static_cast<int32_t>(std::ceil(cutoff * n_cells[1] / distances_between_faces[1])),
511 0 : static_cast<int32_t>(std::ceil(cutoff * n_cells[2] / distances_between_faces[2])),
512 : };
513 :
514 0 : this->cells_shape_ = std::array<size_t, 3>{
515 0 : static_cast<size_t>(n_cells[0]),
516 0 : static_cast<size_t>(n_cells[1]),
517 0 : static_cast<size_t>(n_cells[2]),
518 : };
519 :
520 0 : for (size_t spatial=0; spatial<3; spatial++) {
521 0 : if (n_search_[spatial] < 1) {
522 0 : n_search_[spatial] = 1;
523 : }
524 :
525 : // don't look for neighboring cells if we have only one cell and no
526 : // periodic boundary condition
527 0 : if (n_cells[spatial] == 1 && !box.periodic()) {
528 0 : n_search_[spatial] = 0;
529 : }
530 : }
531 :
532 0 : this->cells_.resize(cells_shape_[0] * cells_shape_[1] * cells_shape_[2]);
533 0 : }
534 :
535 0 : void CellList::add_point(size_t index, Vector position) {
536 0 : auto fractional = box_.cartesian_to_fractional(position);
537 :
538 : // find the cell in which this atom should go
539 : auto cell_index = std::array<int32_t, 3>{
540 0 : static_cast<int32_t>(std::floor(fractional[0] * static_cast<double>(cells_shape_[0]))),
541 0 : static_cast<int32_t>(std::floor(fractional[1] * static_cast<double>(cells_shape_[1]))),
542 0 : static_cast<int32_t>(std::floor(fractional[2] * static_cast<double>(cells_shape_[2]))),
543 0 : };
544 :
545 : // deal with pbc by wrapping the atom inside if it was outside of the
546 : // cell
547 : CellShift shift;
548 : // auto (shift, cell_index) =
549 0 : if (box_.periodic()) {
550 0 : auto result = divmod(cell_index, cells_shape_);
551 0 : shift = CellShift{std::get<0>(result)};
552 0 : cell_index = std::get<1>(result);
553 : } else {
554 : shift = CellShift({0, 0, 0});
555 0 : cell_index = std::array<int32_t, 3>{
556 0 : std::clamp(cell_index[0], 0, static_cast<int32_t>(cells_shape_[0] - 1)),
557 0 : std::clamp(cell_index[1], 0, static_cast<int32_t>(cells_shape_[1] - 1)),
558 0 : std::clamp(cell_index[2], 0, static_cast<int32_t>(cells_shape_[2] - 1)),
559 : };
560 : }
561 :
562 0 : this->get_cell(cell_index).emplace_back(Point{index, shift});
563 0 : }
564 :
565 :
566 : template <typename Function>
567 0 : void CellList::foreach_pair(Function callback) {
568 0 : for (int32_t cell_i_x=0; cell_i_x<static_cast<int32_t>(cells_shape_[0]); cell_i_x++) {
569 0 : for (int32_t cell_i_y=0; cell_i_y<static_cast<int32_t>(cells_shape_[1]); cell_i_y++) {
570 0 : for (int32_t cell_i_z=0; cell_i_z<static_cast<int32_t>(cells_shape_[2]); cell_i_z++) {
571 0 : const auto& current_cell = this->get_cell({cell_i_x, cell_i_y, cell_i_z});
572 : // look through each neighboring cell
573 0 : for (int32_t delta_x=-n_search_[0]; delta_x<=n_search_[0]; delta_x++) {
574 0 : for (int32_t delta_y=-n_search_[1]; delta_y<=n_search_[1]; delta_y++) {
575 0 : for (int32_t delta_z=-n_search_[2]; delta_z<=n_search_[2]; delta_z++) {
576 0 : auto cell_i = std::array<int32_t, 3>{
577 0 : cell_i_x + delta_x,
578 0 : cell_i_y + delta_y,
579 0 : cell_i_z + delta_z,
580 : };
581 :
582 : // shift vector from one cell to the other and index of
583 : // the neighboring cell
584 0 : auto [cell_shift, neighbor_cell_i] = divmod(cell_i, cells_shape_);
585 :
586 0 : for (const auto& atom_i: current_cell) {
587 0 : for (const auto& atom_j: this->get_cell(neighbor_cell_i)) {
588 0 : auto shift = CellShift{cell_shift} + atom_i.shift - atom_j.shift;
589 0 : auto shift_is_zero = shift[0] == 0 && shift[1] == 0 && shift[2] == 0;
590 :
591 0 : if (!box_.periodic() && !shift_is_zero) {
592 : // do not create pairs crossing the periodic
593 : // boundaries in a non-periodic box
594 0 : continue;
595 : }
596 :
597 0 : if (atom_i.index == atom_j.index && shift_is_zero) {
598 : // only create pairs with the same atom twice if the
599 : // pair spans more than one bounding box
600 0 : continue;
601 : }
602 :
603 0 : callback(atom_i.index, atom_j.index, shift);
604 : }
605 : } // loop over atoms in current neighbor cells
606 : }}}
607 : }}} // loop over neighboring cells
608 0 : }
609 :
610 0 : CellList::Cell& CellList::get_cell(std::array<int32_t, 3> index) {
611 0 : size_t linear_index = (cells_shape_[0] * cells_shape_[1] * index[2])
612 0 : + (cells_shape_[0] * index[1])
613 0 : + index[0];
614 0 : return cells_[linear_index];
615 : }
616 :
617 : /* ========================================================================== */
618 :
619 :
620 0 : void GrowableNeighborList::set_pair(size_t index, size_t first, size_t second) {
621 0 : if (index >= this->capacity) {
622 0 : this->grow();
623 : }
624 :
625 0 : this->neighbors.pairs[index][0] = first;
626 0 : this->neighbors.pairs[index][1] = second;
627 0 : }
628 :
629 0 : void GrowableNeighborList::set_shift(size_t index, PLMD::metatensor::vesin::CellShift shift) {
630 0 : if (index >= this->capacity) {
631 0 : this->grow();
632 : }
633 :
634 0 : this->neighbors.shifts[index][0] = shift[0];
635 0 : this->neighbors.shifts[index][1] = shift[1];
636 0 : this->neighbors.shifts[index][2] = shift[2];
637 0 : }
638 :
639 0 : void GrowableNeighborList::set_distance(size_t index, double distance) {
640 0 : if (index >= this->capacity) {
641 0 : this->grow();
642 : }
643 :
644 0 : this->neighbors.distances[index] = distance;
645 0 : }
646 :
647 0 : void GrowableNeighborList::set_vector(size_t index, PLMD::metatensor::vesin::Vector vector) {
648 0 : if (index >= this->capacity) {
649 0 : this->grow();
650 : }
651 :
652 0 : this->neighbors.vectors[index][0] = vector[0];
653 0 : this->neighbors.vectors[index][1] = vector[1];
654 0 : this->neighbors.vectors[index][2] = vector[2];
655 0 : }
656 :
/// Resize (or create, when `ptr` is null) a heap array of `N`-element rows
/// from `size` to `new_size` rows, using `std::realloc`.
///
/// The new rows `[size, new_size)` are filled with all-bits-set bytes. That
/// is NaN for `double` and -1 for `int32_t`, so unset entries are easy to
/// spot. At least one row is always requested, so asking for zero rows gives
/// a valid non-null pointer instead of relying on the implementation-defined
/// result of a zero-size `realloc`.
///
/// Throws `std::bad_alloc` on failure. In that case `ptr` is left untouched
/// and still owned by the caller.
template <typename scalar_t, size_t N>
static scalar_t (*alloc(scalar_t (*ptr)[N], size_t size, size_t new_size))[N] {
    const auto bytes = std::max<size_t>(new_size, 1) * sizeof(scalar_t[N]);
    auto* new_ptr = reinterpret_cast<scalar_t (*)[N]>(std::realloc(ptr, bytes));

    if (new_ptr == nullptr) {
        throw std::bad_alloc();
    }

    // initialize only the newly added rows, with a bit pattern that maps
    // to NaN for double
    if (new_size > size) {
        std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t[N]));
    }

    return new_ptr;
}

/// Same as above, for a flat array of `scalar_t`.
template <typename scalar_t>
static scalar_t* alloc(scalar_t* ptr, size_t size, size_t new_size) {
    const auto bytes = std::max<size_t>(new_size, 1) * sizeof(scalar_t);
    auto* new_ptr = reinterpret_cast<scalar_t*>(std::realloc(ptr, bytes));

    if (new_ptr == nullptr) {
        throw std::bad_alloc();
    }

    // initialize only the newly added entries, with a bit pattern that maps
    // to NaN for double
    if (new_size > size) {
        std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t));
    }

    return new_ptr;
}
684 :
685 0 : void GrowableNeighborList::grow() {
686 0 : auto new_size = neighbors.length * 2;
687 : if (new_size == 0) {
688 : new_size = 1;
689 : }
690 :
691 0 : auto* new_pairs = alloc<size_t, 2>(neighbors.pairs, neighbors.length, new_size);
692 :
693 : int32_t (*new_shifts)[3] = nullptr;
694 0 : if (options.return_shifts) {
695 0 : new_shifts = alloc<int32_t, 3>(neighbors.shifts, neighbors.length, new_size);
696 : }
697 :
698 : double *new_distances = nullptr;
699 0 : if (options.return_distances) {
700 0 : new_distances = alloc<double>(neighbors.distances, neighbors.length, new_size);
701 : }
702 :
703 : double (*new_vectors)[3] = nullptr;
704 0 : if (options.return_vectors) {
705 0 : new_vectors = alloc<double, 3>(neighbors.vectors, neighbors.length, new_size);
706 : }
707 :
708 0 : this->neighbors.pairs = new_pairs;
709 0 : this->neighbors.shifts = new_shifts;
710 0 : this->neighbors.distances = new_distances;
711 0 : this->neighbors.vectors = new_vectors;
712 :
713 0 : this->capacity = new_size;
714 0 : }
715 :
716 0 : void GrowableNeighborList::reset() {
717 : // set all allocated data to zero
718 0 : auto size = this->neighbors.length;
719 0 : std::memset(this->neighbors.pairs, 0, size * sizeof(size_t[2]));
720 :
721 0 : if (this->neighbors.shifts != nullptr) {
722 0 : std::memset(this->neighbors.shifts, 0, size * sizeof(int32_t[3]));
723 : }
724 :
725 0 : if (this->neighbors.distances != nullptr) {
726 0 : std::memset(this->neighbors.distances, 0, size * sizeof(double));
727 : }
728 :
729 0 : if (this->neighbors.vectors != nullptr) {
730 0 : std::memset(this->neighbors.vectors, 0, size * sizeof(double[3]));
731 : }
732 :
733 : // reset length (but keep the capacity where it's at)
734 0 : this->neighbors.length = 0;
735 :
736 : // allocate/deallocate pointers as required
737 0 : auto* shifts = this->neighbors.shifts;
738 0 : if (this->options.return_shifts && shifts == nullptr) {
739 0 : shifts = alloc<int32_t, 3>(shifts, 0, capacity);
740 0 : } else if (!this->options.return_shifts && shifts != nullptr) {
741 0 : std::free(shifts);
742 : shifts = nullptr;
743 : }
744 :
745 0 : auto* distances = this->neighbors.distances;
746 0 : if (this->options.return_distances && distances == nullptr) {
747 0 : distances = alloc<double>(distances, 0, capacity);
748 0 : } else if (!this->options.return_distances && distances != nullptr) {
749 0 : std::free(distances);
750 : distances = nullptr;
751 : }
752 :
753 0 : auto* vectors = this->neighbors.vectors;
754 0 : if (this->options.return_vectors && vectors == nullptr) {
755 0 : vectors = alloc<double, 3>(vectors, 0, capacity);
756 0 : } else if (!this->options.return_vectors && vectors != nullptr) {
757 0 : std::free(vectors);
758 : vectors = nullptr;
759 : }
760 :
761 0 : this->neighbors.shifts = shifts;
762 0 : this->neighbors.distances = distances;
763 0 : this->neighbors.vectors = vectors;
764 0 : }
765 :
766 :
767 0 : void PLMD::metatensor::vesin::cpu::free_neighbors(VesinNeighborList& neighbors) {
768 : assert(neighbors.device == VesinCPU);
769 :
770 0 : std::free(neighbors.pairs);
771 0 : std::free(neighbors.shifts);
772 0 : std::free(neighbors.vectors);
773 0 : std::free(neighbors.distances);
774 0 : }
775 : #include <cstring>
776 : #include <string>
777 :
778 :
779 :
// Storage for the message returned through `error_message` when an exception
// is caught. It is thread-local so that concurrent callers do not overwrite
// each other's messages, and it outlives the call so that the returned
// `c_str()` pointer stays valid until the next error on the same thread.
thread_local std::string LAST_ERROR{};
781 :
782 0 : extern "C" int vesin_neighbors(
783 : const double (*points)[3],
784 : size_t n_points,
785 : const double box[3][3],
786 : bool periodic,
787 : VesinDevice device,
788 : VesinOptions options,
789 : VesinNeighborList* neighbors,
790 : const char** error_message
791 : ) {
792 0 : if (error_message == nullptr) {
793 : return EXIT_FAILURE;
794 : }
795 :
796 0 : if (points == nullptr) {
797 0 : *error_message = "`points` can not be a NULL pointer";
798 0 : return EXIT_FAILURE;
799 : }
800 :
801 0 : if (box == nullptr) {
802 0 : *error_message = "`cell` can not be a NULL pointer";
803 0 : return EXIT_FAILURE;
804 : }
805 :
806 0 : if (neighbors == nullptr) {
807 0 : *error_message = "`neighbors` can not be a NULL pointer";
808 0 : return EXIT_FAILURE;
809 : }
810 :
811 0 : if (neighbors->device != VesinUnknownDevice && neighbors->device != device) {
812 0 : *error_message = "`neighbors` device and data `device` do not match, free the neighbors first";
813 0 : return EXIT_FAILURE;
814 : }
815 :
816 0 : if (device == VesinUnknownDevice) {
817 0 : *error_message = "got an unknown device to use when running simulation";
818 0 : return EXIT_FAILURE;
819 : }
820 :
821 0 : if (neighbors->device == VesinUnknownDevice) {
822 : // initialize the device
823 0 : neighbors->device = device;
824 0 : } else if (neighbors->device != device) {
825 0 : *error_message = "`neighbors.device` and `device` do not match, free the neighbors first";
826 0 : return EXIT_FAILURE;
827 : }
828 :
829 : try {
830 0 : if (device == VesinCPU) {
831 : auto matrix = PLMD::metatensor::vesin::Matrix{{{
832 0 : {{box[0][0], box[0][1], box[0][2]}},
833 0 : {{box[1][0], box[1][1], box[1][2]}},
834 0 : {{box[2][0], box[2][1], box[2][2]}},
835 0 : }}};
836 :
837 0 : PLMD::metatensor::vesin::cpu::neighbors(
838 : reinterpret_cast<const PLMD::metatensor::vesin::Vector*>(points),
839 : n_points,
840 : PLMD::metatensor::vesin::BoundingBox(matrix, periodic),
841 : options,
842 : *neighbors
843 : );
844 : } else {
845 0 : throw std::runtime_error("unknown device " + std::to_string(device));
846 : }
847 0 : } catch (const std::bad_alloc&) {
848 0 : LAST_ERROR = "failed to allocate memory";
849 0 : *error_message = LAST_ERROR.c_str();
850 : return EXIT_FAILURE;
851 0 : } catch (const std::exception& e) {
852 0 : LAST_ERROR = e.what();
853 0 : *error_message = LAST_ERROR.c_str();
854 : return EXIT_FAILURE;
855 0 : } catch (...) {
856 0 : *error_message = "fatal error: unknown type thrown as exception";
857 : return EXIT_FAILURE;
858 0 : }
859 :
860 0 : return EXIT_SUCCESS;
861 : }
862 :
863 :
864 0 : extern "C" void vesin_free(VesinNeighborList* neighbors) {
865 0 : if (neighbors == nullptr) {
866 : return;
867 : }
868 :
869 0 : if (neighbors->device == VesinUnknownDevice) {
870 : // nothing to do
871 0 : } else if (neighbors->device == VesinCPU) {
872 0 : PLMD::metatensor::vesin::cpu::free_neighbors(*neighbors);
873 : }
874 :
875 : std::memset(neighbors, 0, sizeof(VesinNeighborList));
876 : }
|