diff --git a/mlpp/lin_alg/mlpp_matrix.cpp b/mlpp/lin_alg/mlpp_matrix.cpp index d234f9a..2aa0d2f 100644 --- a/mlpp/lin_alg/mlpp_matrix.cpp +++ b/mlpp/lin_alg/mlpp_matrix.cpp @@ -1999,6 +1999,21 @@ Ref MLPPMatrix::identity_mat(int d) const { return identity_mat; } +Ref MLPPMatrix::create_identity_mat(int d) { + Ref identity_mat; + identity_mat.instance(); + identity_mat->resize(Size2i(d, d)); + identity_mat->fill(0); + + real_t *im_ptr = identity_mat->ptrw(); + + for (int i = 0; i < d; i++) { + im_ptr[identity_mat->calculate_index(i, i)] = 1; + } + + return identity_mat; +} + Ref MLPPMatrix::cov() const { MLPPStat stat; diff --git a/mlpp/lin_alg/mlpp_matrix.h b/mlpp/lin_alg/mlpp_matrix.h index a78af96..378cd68 100644 --- a/mlpp/lin_alg/mlpp_matrix.h +++ b/mlpp/lin_alg/mlpp_matrix.h @@ -248,6 +248,8 @@ public: Ref identityn() const; Ref identity_mat(int d) const; + static Ref create_identity_mat(int d); + Ref cov() const; void covo(Ref out) const;