Skip to content

Generalize view and size functions #1660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stan/math/prim/err.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <stan/math/prim/err/check_symmetric.hpp>
#include <stan/math/prim/err/check_unit_vector.hpp>
#include <stan/math/prim/err/check_vector.hpp>
#include <stan/math/prim/err/check_vector_index.hpp>
#include <stan/math/prim/err/constraint_tolerance.hpp>
#include <stan/math/prim/err/domain_error.hpp>
#include <stan/math/prim/err/domain_error_vec.hpp>
Expand Down
8 changes: 3 additions & 5 deletions stan/math/prim/err/check_column_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@ namespace math {
* <code>stan::error_index::value</code>. This function will
* throw an <code>std::out_of_range</code> exception if
* the index is out of bounds.
* @tparam T_y Type of scalar
* @tparam R number of rows or Eigen::Dynamic
* @tparam C number of columns or Eigen::Dynamic
* @tparam T_y Type of matrix
* @param function Function name (for error messages)
* @param name Variable name (for error messages)
* @param y matrix to test
* @param i column index to check
* @throw <code>std::out_of_range</code> if index is an invalid column
*/
template <typename T_y, int R, int C>
template <typename T_y, typename = require_eigen_t<T_y>>
inline void check_column_index(const char* function, const char* name,
const Eigen::Matrix<T_y, R, C>& y, size_t i) {
const T_y& y, size_t i) {
if (i >= stan::error_index::value
&& i < static_cast<size_t>(y.cols()) + stan::error_index::value) {
return;
Expand Down
8 changes: 3 additions & 5 deletions stan/math/prim/err/check_row_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@ namespace math {
* Check if the specified index is a valid row of the matrix
* This check is 1-indexed by default. This behavior can be changed
* by setting <code>stan::error_index::value</code>.
* @tparam T Scalar type
* @tparam R number of rows or Eigen::Dynamic
* @tparam C number of columns or Eigen::Dynamic
* @tparam T Matrix type
* @param function Function name (for error messages)
* @param name Variable name (for error messages)
* @param y matrix to test
* @param i row index to check
* @throw <code>std::out_of_range</code> if the index is out of range.
*/
template <typename T_y, int R, int C>
template <typename T_y, typename = require_eigen_t<T_y>>
inline void check_row_index(const char* function, const char* name,
const Eigen::Matrix<T_y, R, C>& y, size_t i) {
const T_y& y, size_t i) {
if (i >= stan::error_index::value
&& i < static_cast<size_t>(y.rows()) + stan::error_index::value) {
return;
Expand Down
40 changes: 40 additions & 0 deletions stan/math/prim/err/check_vector_index.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef STAN_MATH_PRIM_ERR_CHECK_VECTOR_INDEX_HPP
#define STAN_MATH_PRIM_ERR_CHECK_VECTOR_INDEX_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/out_of_range.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <sstream>
#include <string>

namespace stan {
namespace math {

/**
* Check if the specified index is a valid element of the row or column vector
* This check is 1-indexed by default. This behavior can be changed
* by setting <code>stan::error_index::value</code>.
* @tparam T Vector type
* @param function Function name (for error messages)
* @param name Variable name (for error messages)
* @param y vector to test
* @param i row index to check
* @throw <code>std::out_of_range</code> if the index is out of range.
*/
template <typename T, typename = require_eigen_vector_t<T>>
inline void check_vector_index(const char* function, const char* name,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is missing tests.

const T& y, size_t i) {
if (i >= stan::error_index::value
&& i < static_cast<size_t>(y.size()) + stan::error_index::value) {
return;
}

std::stringstream msg;
msg << " for size of " << name;
std::string msg_str(msg.str());
out_of_range(function, y.rows(), i, msg_str.c_str());
}

} // namespace math
} // namespace stan
#endif
9 changes: 4 additions & 5 deletions stan/math/prim/fun/col.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@ namespace math {
* This is equivalent to calling <code>m.col(i - 1)</code> and
* assigning the resulting template expression to a column vector.
*
* @tparam T type of elements in the matrix
* @tparam T type of the matrix
* @param m Matrix.
* @param j Column index (count from 1).
* @return Specified column of the matrix.
* @throw std::out_of_range if j is out of range.
*/
template <typename T>
inline Eigen::Matrix<T, Eigen::Dynamic, 1> col(
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t j) {
template <typename T, typename = require_eigen_t<T>>
inline auto col(const T& m, size_t j) {
check_column_index("col", "j", m, j);
return m.col(j - 1);
return m.col(j - 1).eval();
}

} // namespace math
Expand Down
9 changes: 4 additions & 5 deletions stan/math/prim/fun/cols.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_PRIM_FUN_COLS_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {
Expand All @@ -10,14 +11,12 @@ namespace math {
* Return the number of columns in the specified
* matrix, vector, or row vector.
*
* @tparam T type of elements in the matrix
* @tparam R number of rows, can be Eigen::Dynamic
* @tparam C number of columns, can be Eigen::Dynamic
* @tparam T type of the matrix
* @param[in] m Input matrix, vector, or row vector.
* @return Number of columns.
*/
template <typename T, int R, int C>
inline int cols(const Eigen::Matrix<T, R, C>& m) {
template <typename T, typename = require_eigen_t<T>>
inline int cols(const T& m) {
return m.cols();
}

Expand Down
10 changes: 5 additions & 5 deletions stan/math/prim/fun/diagonal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_PRIM_FUN_DIAGONAL_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {
Expand All @@ -10,14 +11,13 @@ namespace math {
* Return a column vector of the diagonal elements of the
* specified matrix. The matrix is not required to be square.
*
* @tparam T type of elements in the matrix
* @tparam T type of the matrix
* @param m Specified matrix.
* @return Diagonal of the matrix.
*/
template <typename T>
inline Eigen::Matrix<T, Eigen::Dynamic, 1> diagonal(
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m) {
return m.diagonal();
template <typename T, typename = require_eigen_t<T>>
inline auto diagonal(const T& m) {
return m.diagonal().eval();
}

} // namespace math
Expand Down
7 changes: 4 additions & 3 deletions stan/math/prim/fun/dims.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
#define STAN_MATH_PRIM_FUN_DIMS_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <vector>

namespace stan {
namespace math {

template <typename T>
template <typename T, typename = require_stan_scalar_t<T>>
inline void dims(const T& x, std::vector<int>& result) {
/* no op */
}
template <typename T, int R, int C>
inline void dims(const Eigen::Matrix<T, R, C>& x, std::vector<int>& result) {
template <typename T, typename = require_eigen_t<T>, typename = void>
inline void dims(const T& x, std::vector<int>& result) {
result.push_back(x.rows());
result.push_back(x.cols());
}
Expand Down
39 changes: 8 additions & 31 deletions stan/math/prim/fun/head.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,21 @@ namespace stan {
namespace math {

/**
* Return the specified number of elements as a vector
* from the front of the specified vector.
* Return the specified number of elements as a vector or row vector (same as
* input) from the front of the specified vector or row vector.
*
* @tparam T type of elements in the vector
* @tparam T type of the vector
* @param v Vector input.
* @param n Size of return.
* @return The first n elements of v.
* @throw std::out_of_range if n is out of range.
*/
template <typename T>
inline Eigen::Matrix<T, Eigen::Dynamic, 1> head(
const Eigen::Matrix<T, Eigen::Dynamic, 1>& v, size_t n) {
template <typename T, typename = require_eigen_vector_t<T>>
inline auto head(const T& v, size_t n) {
if (n != 0) {
check_row_index("head", "n", v, n);
check_vector_index("head", "n", v, n);
}
return v.head(n);
}

/**
* Return the specified number of elements as a row vector
* from the front of the specified row vector.
*
* @tparam T type of elements in the vector
* @param rv Row vector.
* @param n Size of return row vector.
* @return The first n elements of rv.
* @throw std::out_of_range if n is out of range.
*/
template <typename T>
inline Eigen::Matrix<T, 1, Eigen::Dynamic> head(
const Eigen::Matrix<T, 1, Eigen::Dynamic>& rv, size_t n) {
if (n != 0) {
check_column_index("head", "n", rv, n);
}
return rv.head(n);
return v.head(n).eval();
}

/**
Expand All @@ -62,10 +42,7 @@ std::vector<T> head(const std::vector<T>& sv, size_t n) {
check_std_vector_index("head", "n", sv, n);
}

std::vector<T> s;
for (size_t i = 0; i < n; ++i) {
s.push_back(sv[i]);
}
std::vector<T> s(sv.begin(), sv.begin() + n);
return s;
}

Expand Down
11 changes: 5 additions & 6 deletions stan/math/prim/fun/num_elements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_PRIM_FUN_NUM_ELEMENTS_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <vector>

namespace stan {
Expand All @@ -14,23 +15,21 @@ namespace math {
* @param x Argument of primitive type.
* @return 1
*/
template <typename T>
template <typename T, typename = require_stan_scalar_t<T>>
inline int num_elements(const T& x) {
return 1;
}

/**
* Returns the size of the specified matrix.
*
* @tparam T type of elements in the matrix
* @tparam R number of rows, can be Eigen::Dynamic
* @tparam C number of columns, can be Eigen::Dynamic
* @tparam T type of the matrix
*
* @param m argument matrix
* @return size of matrix
*/
template <typename T, int R, int C>
inline int num_elements(const Eigen::Matrix<T, R, C>& m) {
template <typename T, typename = require_eigen_t<T>, typename = void>
inline int num_elements(const T& m) {
return m.size();
}

Expand Down
9 changes: 4 additions & 5 deletions stan/math/prim/fun/row.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,17 @@ namespace math {
* This is equivalent to calling <code>m.row(i - 1)</code> and
* assigning the resulting template expression to a row vector.
*
* @tparam T type of elements in the matrix
* @tparam T type of the matrix
* @param m Matrix.
* @param i Row index (count from 1).
* @return Specified row of the matrix.
* @throw std::out_of_range if i is out of range.
*/
template <typename T>
inline Eigen::Matrix<T, 1, Eigen::Dynamic> row(
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t i) {
template <typename T, typename = require_eigen_t<T>>
inline auto row(const T& m, size_t i) {
check_row_index("row", "i", m, i);

return m.row(i - 1);
return m.row(i - 1).eval();
}

} // namespace math
Expand Down
10 changes: 4 additions & 6 deletions stan/math/prim/fun/rows.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_PRIM_FUN_ROWS_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {
Expand All @@ -10,15 +11,12 @@ namespace math {
* Return the number of rows in the specified
* matrix, vector, or row vector.
*
* @tparam T type of elements in the matrix
* @tparam R number of rows, can be Eigen::Dynamic
* @tparam C number of columns, can be Eigen::Dynamic
*
* @tparam T type of the matrix
* @param[in] m Input matrix, vector, or row vector.
* @return Number of rows.
*/
template <typename T, int R, int C>
inline int rows(const Eigen::Matrix<T, R, C>& m) {
template <typename T, typename = require_eigen_t<T>>
inline int rows(const T& m) {
return m.rows();
}

Expand Down
10 changes: 4 additions & 6 deletions stan/math/prim/fun/sub_col.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,21 @@ namespace math {
/**
* Return a nrows x 1 subcolumn starting at (i-1, j-1).
*
* @tparam T type of elements in the matrix
* @tparam T type of the matrix
* @param m Matrix.
* @param i Starting row + 1.
* @param j Starting column + 1.
* @param nrows Number of rows in block.
* @throw std::out_of_range if either index is out of range.
*/
template <typename T>
inline Eigen::Matrix<T, Eigen::Dynamic, 1> sub_col(
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t i,
size_t j, size_t nrows) {
template <typename T, typename = require_eigen_t<T>>
inline auto sub_col(const T& m, size_t i, size_t j, size_t nrows) {
check_row_index("sub_col", "i", m, i);
if (nrows > 0) {
check_row_index("sub_col", "i+nrows-1", m, i + nrows - 1);
}
check_column_index("sub_col", "j", m, j);
return m.block(i - 1, j - 1, nrows, 1);
return m.col(j - 1).segment(i - 1, nrows).eval();
}

} // namespace math
Expand Down
10 changes: 4 additions & 6 deletions stan/math/prim/fun/sub_row.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,21 @@ namespace math {
/**
* Return a 1 x nrows subrow starting at (i-1, j-1).
*
* @tparam T type of elements in the matrix
* @tparam T type of the matrix
* @param m Matrix Input matrix.
* @param i Starting row + 1.
* @param j Starting column + 1.
* @param ncols Number of columns in block.
* @throw std::out_of_range if either index is out of range.
*/
template <typename T>
inline Eigen::Matrix<T, 1, Eigen::Dynamic> sub_row(
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t i,
size_t j, size_t ncols) {
template <typename T, typename = require_eigen_t<T>>
inline auto sub_row(const T& m, size_t i, size_t j, size_t ncols) {
check_row_index("sub_row", "i", m, i);
check_column_index("sub_row", "j", m, j);
if (ncols > 0) {
check_column_index("sub_col", "j+ncols-1", m, j + ncols - 1);
}
return m.block(i - 1, j - 1, 1, ncols);
return m.row(i - 1).segment(j - 1, ncols).eval();
}

} // namespace math
Expand Down
Loading