Basix
ndarray.h
1 // Copyright (C) 2021 Igor Baratta
2 //
3 // This file is part of DOLFINX (https://www.fenicsproject.org)
4 //
5 // SPDX-License-Identifier: LGPL-3.0-or-later
6 
7 #pragma once
8 
#include "span.hpp"
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <numeric>
#include <ostream>
#include <type_traits>
#include <vector>
15 
16 namespace basix
17 {
18 
/// @brief Detects array-like types.
///
/// `has_shape<C>::value` is true when the member `C::shape` is declared
/// with exactly the type `std::array<std::size_t, C::rank>`, and false
/// otherwise. Forming `has_shape<C>` requires `C::rank` to be a constant.
template <typename Candidate,
          typename Expected = std::array<std::size_t, Candidate::rank>>
struct has_shape : std::integral_constant<bool, false>
{
};

/// Partial specialisation chosen only when the declared type of
/// `Candidate::shape` coincides with the defaulted `Expected` argument.
template <typename Candidate>
struct has_shape<Candidate, decltype(Candidate::shape)>
    : std::integral_constant<bool, true>
{
};

/// Forward declaration: ndarray's rank-3 row accessors return ndspan views.
template <typename T, std::size_t N>
class ndspan;
31 
34 template <typename T, std::size_t N, class Allocator = std::allocator<T>>
35 class ndarray
36 {
37 public:
39  using value_type = T;
40  using allocator_type = Allocator;
41  using size_type = typename std::vector<T, Allocator>::size_type;
42  using reference = typename std::vector<T, Allocator>::reference;
43  using const_reference = typename std::vector<T, Allocator>::const_reference;
44  using pointer = typename std::vector<T, Allocator>::pointer;
45  using iterator = typename std::vector<T, Allocator>::iterator;
46  using const_iterator = typename std::vector<T, Allocator>::const_iterator;
48 
53  ndarray(std::array<size_type, N> shape, value_type value = T(),
54  const Allocator& alloc = Allocator())
55  : shape(shape)
56  {
57  size_type size = std::accumulate(shape.begin(), shape.end(), 1,
58  std::multiplies<size_type>());
59  _storage = std::vector<T, Allocator>(size, value, alloc);
60  }
61 
67  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
68  ndarray(size_type rows, size_type cols, value_type value = T(),
69  const Allocator& alloc = Allocator())
70  : shape({rows, cols})
71  {
72  _storage = std::vector<T, Allocator>(shape[0] * shape[1], value, alloc);
73  }
74 
76  template <typename Vector,
77  typename = std::enable_if_t<std::is_class<Vector>::value>>
78  ndarray(std::array<size_type, N> shape, Vector&& x)
79  : shape(shape), _storage(std::forward<Vector>(x))
80  {
81  // Do nothing
82  }
83 
87  template <typename = std::enable_if_t<N == 2>>
88  constexpr ndarray(std::initializer_list<std::initializer_list<T>> list)
89  : shape({list.size(), (*list.begin()).size()})
90  {
91  _storage.reserve(shape[0] * shape[1]);
92  for (std::initializer_list<T> l : list)
93  for (const T val : l)
94  _storage.push_back(val);
95  }
96 
99  template <typename Span, typename = std::enable_if_t<has_shape<Span>::value>>
100  constexpr ndarray(Span& s)
101  : shape(s.shape), _storage(s.data(), s.data() + s.size())
102  {
103  // Do nothing
104  }
105 
107  ndarray(const ndarray& x) = default;
108 
110  ndarray(ndarray&& x) = default;
111 
113  ~ndarray() = default;
114 
116  ndarray& operator=(const ndarray& x) = default;
117 
119  ndarray& operator=(ndarray&& x) = default;
120 
126  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
127  constexpr reference operator()(size_type i, size_type j)
128  {
129  return _storage[i * shape[1] + j];
130  }
131 
138  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
139  constexpr const_reference operator()(size_type i, size_type j) const
140  {
141  return _storage[i * shape[1] + j];
142  }
143 
145  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
146  constexpr reference operator()(size_type i, size_type j, size_type k)
147  {
148  return _storage[shape[2] * (i * shape[1] + j) + k];
149  }
150 
152  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
153  constexpr const_reference operator()(size_type i, size_type j,
154  size_type k) const
155  {
156  return _storage[shape[2] * (i * shape[1] + j) + k];
157  }
158 
162  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
163  constexpr tcb::span<value_type> row(size_type i)
164  {
165  return tcb::span<value_type>(std::next(_storage.data(), i * shape[1]),
166  shape[1]);
167  }
168 
172  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
173  constexpr tcb::span<const value_type> row(size_type i) const
174  {
175  return tcb::span<const value_type>(std::next(_storage.data(), i * shape[1]),
176  shape[1]);
177  }
178 
180  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
181  constexpr ndspan<value_type, 2> row(size_type i)
182  {
183  return ndspan<value_type, 2>(
184  std::next(_storage.data(), i * shape[2] * shape[1]),
185  {shape[1], shape[2]});
186  }
187 
189  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
190  constexpr ndspan<const value_type, 2> row(size_type i) const
191  {
193  std::next(_storage.data(), i * shape[2] * shape[1]),
194  {shape[1], shape[2]});
195  }
196 
199  constexpr value_type* data() noexcept { return _storage.data(); }
200 
204  constexpr const value_type* data() const noexcept { return _storage.data(); };
205 
210  constexpr size_type size() const noexcept { return _storage.size(); }
211 
213  template <int _N = N, typename = std::enable_if_t<_N == 2>>
214  constexpr std::array<size_type, 2> strides() const noexcept
215  {
216  return {shape[1] * sizeof(T), sizeof(T)};
217  }
218 
221  constexpr bool empty() const noexcept { return _storage.empty(); }
222 
224  std::array<size_type, N> shape;
225 
227  static constexpr size_type rank = size_type(N);
228 
230  template <typename Array>
231  friend std::ostream& operator<<(std::ostream& out, const Array& array);
232 
233 private:
234  std::vector<T, Allocator> _storage;
235 };
236 
238 template <typename T, std::size_t N = 2>
239 class ndspan
240 {
241 public:
242  // /// \cond DO_NOT_DOCUMENT
243  using value_type = T;
244  using size_type = std::size_t;
245  using reference = T&;
246  using const_reference = const T&;
247  using pointer = T*;
248  using const_pointer = const T*;
249  // /// \endcond
250 
254  constexpr ndspan(T* data, std::array<size_type, N> shape)
255  : _storage(data), shape(shape)
256  {
257  // Do nothing
258  }
259 
261  template <typename Array,
262  typename = std::enable_if_t<has_shape<Array>::value>>
263  constexpr ndspan(Array& x) : shape(x.shape), _storage(x.data())
264  {
265  // Do nothing
266  }
267 
273  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
274  constexpr reference operator()(size_type i, size_type j)
275  {
276  return _storage[i * shape[1] + j];
277  }
278 
285  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
286  constexpr reference operator()(size_type i, size_type j) const
287  {
288  return _storage[i * shape[1] + j];
289  }
290 
292  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
293  constexpr reference operator()(size_type i, size_type j, size_type k)
294  {
295  return _storage[shape[2] * (i * shape[1] + j) + k];
296  }
297 
299  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
300  constexpr const_reference operator()(size_type i, size_type j,
301  size_type k) const
302  {
303  return _storage[shape[2] * (i * shape[1] + j) + k];
304  }
305 
309  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
310  constexpr tcb::span<value_type> row(size_type i)
311  {
312  return tcb::span<value_type>(_storage + i * shape[1], shape[1]);
313  }
314 
318  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
319  constexpr tcb::span<const value_type> row(size_type i) const
320  {
321  return tcb::span<const value_type>(_storage + i * shape[1], shape[1]);
322  }
323 
325  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
326  constexpr ndspan<value_type, 2> row(size_type i)
327  {
328  return ndspan<value_type, 2>(_storage + i * shape[2] * shape[1],
329  {shape[1], shape[2]});
330  }
331 
333  template <std::size_t _N = N, typename = std::enable_if_t<_N == 3>>
334  constexpr ndspan<const value_type, 2> row(size_type i) const
335  {
336  return ndspan<const value_type, 2>(_storage + i * shape[2] * shape[1],
337  {shape[1], shape[2]});
338  }
339 
342  // constexpr value_type* data() noexcept { return _storage; }
343 
347  constexpr value_type* data() const noexcept { return _storage; };
348 
353  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
354  constexpr size_type size() const noexcept
355  {
356  return std::accumulate(shape.begin(), shape.end(), 1,
357  std::multiplies<size_type>());
358  }
359 
361  template <std::size_t _N = N, typename = std::enable_if_t<_N == 2>>
362  constexpr std::array<size_type, 2> strides() const noexcept
363  {
364  return {shape[1] * sizeof(T), sizeof(T)};
365  }
366 
368  std::array<size_type, N> shape;
369 
371  static constexpr size_type rank = size_type(N);
372 
374  template <typename Span>
375  friend std::ostream& operator<<(std::ostream& out, const Span& array);
376 
377 private:
378  T* _storage;
379 };
380 
/// @brief Write a rank-2 or rank-3 array-like object to a stream.
///
/// Rank 2: one line per row, formatted as `{a, b, }`.
/// Rank 3: for each leading index, the rank-2 slice in that format followed
/// by a blank line. Any other rank writes nothing.
///
/// @tparam Array Type exposing a static `rank`, a `shape` array and an
/// `operator()` taking `rank` indices
/// @param[in,out] out Stream to write to (flushed before returning)
/// @param[in] array Object to print
/// @return Reference to `out`
template <typename Array>
std::ostream& print_array(std::ostream& out, const Array& array)
{
  // Query the rank through the type rather than through the reference
  // parameter: `array.rank` is not a constant expression before C++23
  // (P2280) and is rejected by some compilers, e.g. Clang.
  if constexpr (Array::rank == 2)
  {
    for (std::size_t i = 0; i < array.shape[0]; ++i)
    {
      out << "{";
      for (std::size_t j = 0; j < array.shape[1]; ++j)
        out << array(i, j) << ", ";
      out << "}\n";
    }
  }
  else if constexpr (Array::rank == 3)
  {
    for (std::size_t i = 0; i < array.shape[0]; ++i)
    {
      for (std::size_t j = 0; j < array.shape[1]; ++j)
      {
        out << "{";
        for (std::size_t k = 0; k < array.shape[2]; ++k)
          out << array(i, j, k) << ", ";
        out << "}\n";
      }
      out << '\n';
    }
  }

  // A single flush replaces the former per-line std::endl flushes, so the
  // output is still visible once the call returns.
  return out << std::flush;
}
409 
411 template <typename T, std::size_t N>
412 std::ostream& operator<<(std::ostream& out, const ndarray<T, N>& array)
413 {
414  return print_array(out, array);
415 }
416 
418 template <typename T, std::size_t N>
419 std::ostream& operator<<(std::ostream& out, const ndspan<T, N>& span)
420 {
421  return print_array(out, span);
422 }
423 
424 } // namespace basix
Definition: ndarray.h:36
constexpr tcb::span< value_type > row(size_type i)
Definition: ndarray.h:163
constexpr ndspan< value_type, 2 > row(size_type i)
Access a row in the array.
Definition: ndarray.h:181
ndarray & operator=(ndarray &&x)=default
Move assignment.
friend std::ostream & operator<<(std::ostream &out, const Array &array)
Pretty printing, useful for debugging.
ndarray(ndarray &&x)=default
Move constructor.
constexpr reference operator()(size_type i, size_type j)
Definition: ndarray.h:127
ndarray(std::array< size_type, N > shape, value_type value=T(), const Allocator &alloc=Allocator())
Definition: ndarray.h:53
constexpr ndspan< const value_type, 2 > row(size_type i) const
Access a row in the array (const version)
Definition: ndarray.h:190
constexpr const value_type * data() const noexcept
Definition: ndarray.h:204
constexpr ndarray(Span &s)
Definition: ndarray.h:100
constexpr size_type size() const noexcept
Definition: ndarray.h:210
constexpr std::array< size_type, 2 > strides() const noexcept
Returns the strides of the array.
Definition: ndarray.h:214
ndarray(std::array< size_type, N > shape, Vector &&x)
Constructs an n-dimensional array from a vector.
Definition: ndarray.h:78
constexpr const_reference operator()(size_type i, size_type j, size_type k) const
Return a reference to the element at specified location (i, j, k)
Definition: ndarray.h:153
ndarray(const ndarray &x)=default
Copy constructor.
static constexpr size_type rank
The rank of the array.
Definition: ndarray.h:227
constexpr ndarray(std::initializer_list< std::initializer_list< T >> list)
Definition: ndarray.h:88
constexpr value_type * data() noexcept
Definition: ndarray.h:199
constexpr tcb::span< const value_type > row(size_type i) const
Definition: ndarray.h:173
constexpr reference operator()(size_type i, size_type j, size_type k)
Return a reference to the element at specified location (i, j, k)
Definition: ndarray.h:146
constexpr const_reference operator()(size_type i, size_type j) const
Definition: ndarray.h:139
constexpr bool empty() const noexcept
Definition: ndarray.h:221
ndarray(size_type rows, size_type cols, value_type value=T(), const Allocator &alloc=Allocator())
Definition: ndarray.h:68
~ndarray()=default
Destructor.
ndarray & operator=(const ndarray &x)=default
Copy assignment.
std::array< size_type, N > shape
The shape of the array.
Definition: ndarray.h:224
This class provides a view into an n-dimensional row-wise array of data.
Definition: ndarray.h:240
constexpr ndspan< value_type, 2 > row(size_type i)
Access a row in the array.
Definition: ndarray.h:326
constexpr tcb::span< value_type > row(size_type i)
Definition: ndarray.h:310
constexpr value_type * data() const noexcept
Definition: ndarray.h:347
std::array< size_type, N > shape
The shape of the span.
Definition: ndarray.h:368
constexpr ndspan(Array &x)
Construct an n-dimensional span from an n-dimensional array.
Definition: ndarray.h:263
static constexpr size_type rank
The rank of the span.
Definition: ndarray.h:371
constexpr ndspan(T *data, std::array< size_type, N > shape)
Definition: ndarray.h:254
constexpr reference operator()(size_type i, size_type j) const
Definition: ndarray.h:286
constexpr const_reference operator()(size_type i, size_type j, size_type k) const
Return a reference to the element at specified location (i, j, k)
Definition: ndarray.h:300
constexpr tcb::span< const value_type > row(size_type i) const
Definition: ndarray.h:319
constexpr std::array< size_type, 2 > strides() const noexcept
Returns the strides of the span.
Definition: ndarray.h:362
constexpr ndspan< const value_type, 2 > row(size_type i) const
Access a row in the array (const version)
Definition: ndarray.h:334
friend std::ostream & operator<<(std::ostream &out, const Span &array)
Pretty printing, useful for debugging.
constexpr reference operator()(size_type i, size_type j)
Definition: ndarray.h:274
constexpr reference operator()(size_type i, size_type j, size_type k)
Return a reference to the element at specified location (i, j, k)
Definition: ndarray.h:293
constexpr size_type size() const noexcept
Definition: ndarray.h:354
Placeholder.
Definition: basix.h:10
std::ostream & operator<<(std::ostream &out, const ndarray< T, N > &array)
Pretty printing, useful for debugging.
Definition: ndarray.h:412
std::ostream & print_array(std::ostream &out, const Array &array)
Convenience function for outputting arrays.
Definition: ndarray.h:383
Definition: ndarray.h:21