00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028 #ifndef CLevenbergMarquardt_H
00029 #define CLevenbergMarquardt_H
00030
00031 #include <mrpt/utils/CDebugOutputCapable.h>
00032 #include <mrpt/math/CMatrixD.h>
00033 #include <mrpt/math/utils.h>
00034
00035
00036
00037
00038 namespace mrpt
00039 {
00040 namespace math
00041 {
00042
00043
00044
00045
00046
00047
00048
00049
00050 template <typename NUMTYPE = double, class USERPARAM = std::vector<NUMTYPE> >
00051 class CLevenbergMarquardtTempl : public mrpt::utils::CDebugOutputCapable
00052 {
00053 public:
00054
00055
00056
00057
00058
00059
00060 typedef void (*TFunctor)(
00061 const std::vector<NUMTYPE> &x,
00062 const USERPARAM &y,
00063 std::vector<NUMTYPE> &out);
00064
00065 struct TResultInfo
00066 {
00067 NUMTYPE final_sqr_err;
00068 size_t iterations_executed;
00069 CMatrixTemplateNumeric<NUMTYPE> path;
00070
00071
00072
00073
00074
00075
00076
00077
00078 };
00079
00080
00081
00082
00083
00084
00085
00086 static void execute(
00087 std::vector<NUMTYPE> &out_optimal_x,
00088 const std::vector<NUMTYPE> &x0,
00089 TFunctor functor,
00090 const std::vector<NUMTYPE> &increments,
00091 const USERPARAM &userParam,
00092 TResultInfo &out_info,
00093 bool verbose = false,
00094 const size_t &maxIter = 200,
00095 const NUMTYPE tau = 1e-3,
00096 const NUMTYPE e1 = 1e-8,
00097 const NUMTYPE e2 = 1e-8
00098 )
00099 {
00100 using namespace mrpt;
00101 using namespace mrpt::utils;
00102 using namespace mrpt::math;
00103 using namespace std;
00104
00105 MRPT_TRY_START;
00106
00107 std::vector<NUMTYPE> &x=out_optimal_x;
00108
00109
00110 ASSERT_( increments.size() == x0.size() );
00111
00112 x=x0;
00113 vector<NUMTYPE> f_x;
00114 CMatrixTemplateNumeric<NUMTYPE> AUX;
00115 CMatrixTemplateNumeric<NUMTYPE> J;
00116 CMatrixTemplateNumeric<NUMTYPE> H;
00117 vector<NUMTYPE> g;
00118
00119
00120 mrpt::math::estimateJacobian( x, functor, increments, userParam, J);
00121 H.multiply_AtA(J);
00122
00123 const size_t H_len = H.getColCount();
00124
00125
00126 functor(x, userParam ,f_x);
00127 J.multiply_Atb(f_x, g);
00128
00129
00130 bool found = math::norm_inf(g)<=e1;
00131 if (verbose && found) cout << "[LM] End condition: math::norm_inf(g)<=e1 :" << math::norm_inf(g) << endl;
00132
00133 NUMTYPE lambda = tau * H.maximumDiagonal();
00134 size_t iter = 0;
00135 NUMTYPE v = 2;
00136
00137 vector<NUMTYPE> h_lm;
00138 vector<NUMTYPE> xnew, f_xnew ;
00139 NUMTYPE F_x = pow( math::norm( f_x ), 2);
00140
00141 const size_t N = x.size();
00142
00143 out_info.path.setSize(maxIter,N+1);
00144 out_info.path.insertRow(iter,x);
00145
00146 while (!found && ++iter<maxIter)
00147 {
00148
00149 for (size_t k=0;k<H_len;k++)
00150 H(k,k)+= lambda;
00151
00152
00153 H.inv_fast(AUX);
00154 AUX.multiply_Ab(g,h_lm);
00155 h_lm *= NUMTYPE(-1.0);
00156
00157 double h_lm_n2 = math::norm(h_lm);
00158 double x_n2 = math::norm(x);
00159
00160 if (verbose) cout << "[LM] Iter: " << iter << " x:" << x << endl;
00161
00162 if (h_lm_n2<e2*(x_n2+e2))
00163 {
00164
00165 found = true;
00166 if (verbose)
00167 {
00168 cout.precision(10);
00169 cout << "[LM] End condition: " << scientific << h_lm_n2 << " < " << e2*(x_n2+e2) << endl;
00170 }
00171 }
00172 else
00173 {
00174
00175 xnew = x;
00176 xnew += h_lm;
00177 functor(xnew, userParam ,f_xnew );
00178 const double F_xnew = pow( math::norm(f_xnew), 2);
00179
00180
00181 vector<NUMTYPE> tmp(h_lm);
00182 tmp *= lambda;
00183 tmp -= g;
00184 tmp *= h_lm;
00185 double denom = math::sum(tmp);
00186 double l = (F_x - F_xnew) / denom;
00187
00188
00189
00190 if (l>0)
00191 {
00192
00193 x = xnew;
00194 f_x = f_xnew;
00195 F_x = F_xnew;
00196
00197 math::estimateJacobian( x, functor, increments, userParam, J);
00198 H.multiply_AtA(J);
00199 J.multiply_Atb(f_x, g);
00200
00201 found = math::norm_inf(g)<=e1;
00202 if (verbose && found) cout << "[LM] End condition: math::norm_inf(g)<=e1 :" << math::norm_inf(g) << endl;
00203
00204 lambda *= max(0.33, 1-pow(2*l-1,3) );
00205 v = 2;
00206 }
00207 else
00208 {
00209
00210 lambda *= v;
00211 v*= 2;
00212 }
00213
00214
00215 out_info.path.insertRow(iter,x);
00216 out_info.path(iter,x.size()) = F_x;
00217 }
00218 }
00219
00220
00221 out_info.final_sqr_err = F_x;
00222 out_info.iterations_executed = iter;
00223 out_info.path.setSize(iter,N+1);
00224
00225 MRPT_TRY_END;
00226 }
00227
00228 };
00229
00230
00231 typedef CLevenbergMarquardtTempl<double> CLevenbergMarquardt;
00232
00233 }
00234 }
00235 #endif