1 Star 0 Fork 6

C-Band / MLTool

forked from Yang9527 / MLTool 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
lu.h 5.24 KB
一键复制 编辑 原始数据 按行查看 历史
Yang9527 提交于 2013-07-07 22:43 . 13/07/07
#ifndef MLTOOL_LU_H
#define MLTOOL_LU_H
#include<algorithm>
#include<cassert>
#include<cmath>
#include"vector.h"
#include"tags.h"
#include"triangular_solver.h"
namespace math{
// Record of row exchanges shared by the LU routines in this header.
//
// Stores one size_t per index. After construction entry i holds i; the
// pivoting lu_factorize overload writes into entry i the row that was
// exchanged with row i, and swap_rows() replays those exchanges in order.
class permutation_matrix
{
public:
    // Creates n entries, with entry i initialised to i.
    explicit permutation_matrix(size_t n);

    // Writable access to entry `index` (forwards to the underlying vector).
    inline size_t& operator[](size_t index);

    // Read-only access to entry `index`.
    inline const size_t& operator[](size_t index) const;

    // Number of stored entries.
    inline const size_t size() const
    {
        return v.size();
    }

private:
    vector<size_t> v;   // entry storage, one element per index
};
// Constructs a permutation of n entries in which entry i equals i.
//
// Marked inline because this is a non-template definition living in a header:
// without it every translation unit including the header would emit its own
// external definition and the program would fail to link.
inline permutation_matrix::permutation_matrix(size_t n):v(n)
{
    for(size_t i = 0; i < n; ++i)
        v[i] = i;
}
// Returns a writable reference to entry `index` of the stored vector.
// No range check is performed here; whatever vector::operator[] does applies.
inline size_t& permutation_matrix::operator[](size_t index)
{
    size_t& entry = v[index];
    return entry;
}
// Returns a read-only reference to entry `index` of the stored vector.
// No range check is performed here; whatever vector::operator[] does applies.
inline const size_t& permutation_matrix::operator[](size_t index) const
{
    const size_t& entry = v[index];
    return entry;
}
// Applies the row exchanges recorded in P to the matrix m, in place.
//
// Rows are visited in increasing order; whenever P[row] differs from row,
// the two rows are exchanged element by element. The third (unnamed)
// parameter only selects this overload for matrix-like types.
// The dimension check is compiled in only when CHECK_DIMENSION_MATCH is defined.
template<class Mat>
void swap_rows(const permutation_matrix & P, Mat& m, matrix_tag)
{
#ifdef CHECK_DIMENSION_MATCH
    assert(P.size() == m.size1());
#endif
    for(size_t row = 0; row < m.size1(); ++row)
    {
        const size_t partner = P[row];
        if(partner == row)
            continue;   // nothing recorded for this row
        for(size_t col = 0; col < m.size2(); ++col)
            std::swap(m(row, col), m(partner, col));
    }
}
template<class Vec>
void swap_rows(const permutation_matrix& P, Vec &m, vector_tag)
{
#ifdef CHECK_DIMENSION_MATCH
assert(P.size() == m.size());
#endif
for(int i = 0; i < m.size(); ++i)
{
if(P[i] != i)
std::swap(m(P[i]),m(i));
}
}
// Entry point for applying P to a container: picks the matrix or vector
// overload of swap_rows by constructing the container's nested `category` tag.
template<class T>
void swap_rows(const permutation_matrix& P, T &m)
{
    typedef typename T::category container_tag;
    swap_rows(P, m, container_tag());
}
// In-place LU factorization without pivoting.
//
// For each diagonal position i (up to min(rows, cols)) the entries below the
// diagonal in column i are replaced by the elimination multipliers and the
// trailing sub-matrix is updated. On success the strictly lower part of m
// holds the multipliers (L with an implied unit diagonal) and the upper part
// holds U.
//
// Returns 0 on success. If the diagonal entry at step i compares equal to a
// value-initialised T, factorization stops immediately and i + 1 is returned;
// m is then only partially factorized.
//
// Mat must provide value_type, size1(), size2() and operator()(row, col).
template<class Mat>
size_t lu_factorize(Mat & m)
{
    typedef typename Mat::value_type T;
    const size_t size1 = m.size1();
    const size_t size2 = m.size2();
    const size_t size = std::min(size1, size2);
    for(size_t i = 0; i < size; ++i)
    {
        if(m(i,i) != T())
        {
            // Reciprocal is formed in T itself so non-double element types work too.
            const T m_inv = T(1) / m(i,i);
            for(size_t k = i + 1; k < size1; ++k)
            {
                m(k,i) *= m_inv;    // store the multiplier in the L part
                for(size_t j = i + 1; j < size2; ++j)
                    m(k,j) -= m(i,j) * m(k,i);
            }
        }
        else
        {
            // Exact zero pivot: report its 1-based position.
            return i + 1;
        }
    }
    return 0;
}
template<class Mat>
size_t lu_factorize(Mat& m, permutation_matrix& P)
{
typedef typename Mat::value_type T;
size_t size1 = m.size1();
size_t size2 = m.size2();
size_t size = std::min(size1, size2);
size_t singular_ = 0;
for(int i = 0; i < size; ++i)
{
size_t index_norm_inf = i;
T max_elem = abs(m(i,i));
for(int j = i + 1; j < size1; ++j)
{
if(abs(m(j,i))> max_elem)
{
max_elem = abs(m(j,i));
index_norm_inf = j;
}
}
if(i != index_norm_inf)
{
P[i] = index_norm_inf;
for(int k = 0; k < size2; ++k)
std::swap(m(i,k), m(index_norm_inf,k));
}
if(m(i,i) != T())
{
T m_inv = 1.0 / m(i,i);
for(int k = i+1; k < size1; ++k)
{
m(k,i) *= m_inv;
for(int j = i+1; j < size2; ++j)
m(k,j) -= m(i,j)*m(k,i);
}
}
else if(singular_ == 0)
{
singular_ = i + 1;
return singular_;
}
}
return singular_;
}
/*
template<class Mat>
size_t lu_factorize(Mat &m, permutation_matrix &P, permutation_matrix &Q)
{
typedef typename Mat::value_type T;
size_t size1 = m.size1();
size_t size2 = m.size2();
size_t size = std::min(size1,size2);
size_t singular_ = 0;
for(int i = 0; i < size; ++i)
{
// find the pivot element
size_t index_1 = i;
size_t index_2 = i;
T max_elem = std::abs(m(i,i));
for(int k = i + 1; k < size1; ++k)
for(int j = i + 1; j < size2; ++j)
if(std::abs(m(k,j)) > max_elem)
{
index_1 = k;
index_2 = j;
max_elem = std::abs(m(k,j));
}
if(index_1 != i)
{
P[i] = index_1;
for(int j = 0; j < size2; ++j)
std::swap(m(i,j),m(index_1,j));
}
if(index_2 != i)
{
Q[i] = index_2;
for(int k = 0; k < size1; ++k)
std::swap(m(k,i),m(k,index_2));
}
if(m(i,i) != T())
{
T m_inv = 1.0 / m(i,i);
for(int k = i + 1; k < size1; ++k)
{
m(k,i) *= m_inv;
for(int j = i + 1; j < size2; ++j)
m(k,j) -= m(i,j) * m(k,i);
}
}
else if(singular_ == 0)
{
singular_ = i + 1;
return singular_;
}
}
return singular_;
}
*/
template<class Mat,class V>
int lu_inplace_solve(Mat &A, V &B)
{
permutation_matrix P(A.size1());
int singular = lu_factorize(A,P);
if(0 == singular)
{
swap_rows(P,B);
inplace_solve(A,B,unit_lower_tag());
inplace_solve(A,B,upper_tag());
}
return singular;
}
}
#endif // MLTOOL_LU_H
C++
1
https://gitee.com/C-BAND/mltool.git
git@gitee.com:C-BAND/mltool.git
C-BAND
mltool
MLTool
master

搜索帮助

53164aa7 5694891 3bd8fe86 5694891