smooth
A C++ library for Lie theory
Loading...
Searching...
No Matches
ceres.hpp
Go to the documentation of this file.
1// Copyright (C) 2022 Petter Nilsson. MIT License.
2
3#pragma once
4
10#include <utility>
11
12#include <ceres/autodiff_manifold.h>
13#include <ceres/internal/autodiff.h>
14
15#define SMOOTH_DIFF_CERES
16
17#include "smooth/detail/traits.hpp"
18#include "smooth/detail/wrt_impl.hpp"
19#include "smooth/lie_group_base.hpp"
20#include "smooth/manifolds.hpp"
21#include "smooth/wrt.hpp"
22
23SMOOTH_BEGIN_NAMESPACE
24
25// mark Jet as a valid scalar
26template<typename T, int I>
27struct detail::scalar_trait<ceres::Jet<T, I>>
28{
29 static constexpr bool value = true;
30};
31
32// \cond
33template<Manifold G>
34struct CeresParamFunctor
35{
36 template<typename Scalar>
37 bool Plus(const Scalar * x, const Scalar * delta, Scalar * x_plus_delta) const
38 {
39 smooth::MapDispatch<const CastT<Scalar, G>> mx(x);
40 Eigen::Map<const Tangent<CastT<Scalar, G>>> mdelta(delta);
41 smooth::MapDispatch<CastT<Scalar, G>> mx_plus_delta(x_plus_delta);
42
43 mx_plus_delta = rplus(mx, mdelta);
44
45 return true;
46 }
47
48 template<typename Scalar>
49 bool Minus(const Scalar * x, const Scalar * y, Scalar * x_minus_y) const
50 {
51 smooth::MapDispatch<const CastT<Scalar, G>> mx(x);
52 smooth::MapDispatch<const CastT<Scalar, G>> my(y);
53 Eigen::Map<Tangent<CastT<Scalar, G>>> m_x_minus_y(x_minus_y);
54
55 m_x_minus_y = rminus<CastT<Scalar, G>, CastT<Scalar, G>>(mx, my);
56
57 return true;
58 }
59};
60// \endcond
61
65template<Manifold G>
66using CeresLocalParameterization = ceres::AutoDiffManifold<CeresParamFunctor<G>, G::RepSize, G::Dof>;
67
75auto dr_ceres(auto && f, auto && x)
76{
77 // There is potential to improve thie speed of this by reducing casting.
78 // The ceres Jet type supports binary operations with e.g. double, but currently
79 // the Lie operations require everything to have a uniform scalar type. Enabling
80 // plus and minus for different scalars would thus save some casts.
81 using Result = decltype(std::apply(f, x));
82 using Scalar = ::smooth::Scalar<Result>;
83
84 static_assert(Manifold<Result>, "f(x) is not a Manifold");
85
86 Result fval = std::apply(f, x);
87
88 static constexpr Eigen::Index Nx = wrt_Dof<decltype(x)>();
89 static constexpr Eigen::Index Ny = Dof<Result>;
90 const Eigen::Index nx = std::apply([](auto &&... args) { return (dof(args) + ...); }, x);
91 const Eigen::Index ny = dof<Result>(fval);
92
93 static_assert(Nx > 0, "Ceres autodiff does not support dynamic sizes");
94
95 Eigen::Matrix<Scalar, Nx, 1> a = Eigen::Matrix<Scalar, Nx, 1>::Zero(nx);
96 Eigen::Matrix<Scalar, Ny, 1> b(ny);
97 Eigen::Matrix<Scalar, Ny, Nx, (Nx == 1) ? Eigen::ColMajor : Eigen::RowMajor> jac(ny, nx);
98 jac.setZero();
99
100 const auto f_deriv = [&]<typename T>(const T * in, T * out) {
101 Eigen::Map<const Eigen::Matrix<T, Nx, 1>> mi(in, nx);
102 Eigen::Map<Eigen::Matrix<T, Ny, 1>> mo(out, ny);
103 mo = rminus<CastT<T, Result>>(std::apply(f, wrt_rplus(wrt_cast<T>(x), mi)), cast<T, Result>(fval));
104 return true;
105 }; // NOLINT
106 const Scalar * a_ptr[1] = {a.data()};
107 Scalar * jac_ptr[1] = {jac.data()};
108
109 ceres::internal::AutoDifferentiate<Dof<Result>, ceres::internal::StaticParameterDims<Nx>>(
110 f_deriv, a_ptr, static_cast<int>(b.size()), b.data(), jac_ptr);
111
112 return std::make_pair(std::move(fval), Eigen::Matrix<Scalar, Ny, Nx>(jac));
113}
114
115SMOOTH_END_NAMESPACE
ceres::AutoDiffManifold< CeresParamFunctor< G >, G::RepSize, G::Dof > CeresLocalParameterization
Parameterization for on-manifold optimization with Ceres.
Definition ceres.hpp:66
auto dr_ceres(auto &&f, auto &&x)
Automatic differentiation in tangent space.
Definition ceres.hpp:75
Class-external Manifold interface defined through the traits::man trait class.
Definition manifold.hpp:31
typename traits::man< M >::Scalar Scalar
Manifold scalar type.
Definition manifold.hpp:88
Eigen::Index dof(const M &m)
Manifold degrees of freedom (tangent space dimension)
Definition manifold.hpp:145
PlainObject< M > rplus(const M &m, const Eigen::MatrixBase< Derived > &a)
Manifold right-plus.
Definition manifold.hpp:163
typename traits::man< M >::template CastT< NewScalar > CastT
Cast type.
Definition manifold.hpp:100
Meta header to include all Manifold concept model specifications.