diff --git a/MLPP/Data/Data.cpp b/MLPP/Data/Data.cpp
index 85dfd82..477ca31 100644
--- a/MLPP/Data/Data.cpp
+++ b/MLPP/Data/Data.cpp
@@ -463,6 +463,46 @@ namespace MLPP{
         return {wordEmbeddings, wordList};
     }
 
+    // Latent semantic analysis: run an SVD on the binary bag-of-words matrix of
+    // `sentences` and keep only the first `dim` diagonal entries of S and rows of Vt.
+    //
+    // sentences: documents handed to BOW(sentences, "Binary").
+    // dim:       number of components to keep. It is clamped to what S and Vt
+    //            actually hold, so an oversized request no longer indexes past
+    //            their ends.
+    // Returns S_trunc * Vt_trunc (dim rows, one column per column of Vt), or an
+    // empty matrix when no component can be kept.
+    std::vector<std::vector<double>> Data::LSA(std::vector<std::string> sentences, int dim){
+        LinAlg alg;
+        std::vector<std::vector<double>> docWordData = BOW(sentences, "Binary");
+
+        auto [U, S, Vt] = alg.SVD(docWordData);
+
+        // S[i][i] and Vt[i] are both read below, so i has to be valid for both.
+        int maxDim = static_cast<int>(S.size() < Vt.size() ? S.size() : Vt.size());
+        if(!S.empty() && static_cast<int>(S[0].size()) < maxDim){
+            maxDim = static_cast<int>(S[0].size());
+        }
+        if(dim > maxDim){
+            dim = maxDim;
+        }
+        if(dim <= 0){
+            return {};
+        }
+
+        std::vector<std::vector<double>> S_trunc = alg.zeromat(dim, dim);
+        std::vector<std::vector<double>> Vt_trunc;
+        for(int i = 0; i < dim; i++){
+            S_trunc[i][i] = S[i][i];
+            Vt_trunc.push_back(Vt[i]);
+        }
+
+        // S_trunc is dim x dim, so its right-hand factor must have exactly dim
+        // rows: multiply by the truncated Vt_trunc, not the full Vt.
+        std::vector<std::vector<double>> embeddings = alg.matmult(S_trunc, Vt_trunc);
+        return embeddings;
+    }
+
     std::vector<std::string> Data::createWordList(std::vector<std::string> sentences){
         std::string combinedText = "";
         for(int i = 0; i < sentences.size(); i++){
diff --git a/MLPP/Data/Data.hpp b/MLPP/Data/Data.hpp
index aae8de0..9de85d9 100644
--- a/MLPP/Data/Data.hpp
+++ b/MLPP/Data/Data.hpp
@@ -47,6 +47,7 @@ class Data{
         std::vector<std::vector<double>> BOW(std::vector<std::string> sentences, std::string = "Default");
         std::vector<std::vector<double>> TFIDF(std::vector<std::string> sentences);
         std::tuple<std::vector<std::vector<double>>, std::vector<std::string>> word2Vec(std::vector<std::string> sentences, std::string type, int windowSize, int dimension, double learning_rate, int max_epoch);
+        std::vector<std::vector<double>> LSA(std::vector<std::string> sentences, int dim);
 
         std::vector<std::string> createWordList(std::vector<std::string> sentences);
 
diff --git a/a.out b/a.out
index 9eea7b9..82823d6 100755
Binary files a/a.out and b/a.out differ
diff --git a/main.cpp b/main.cpp
index 093abd4..ab06dba 100644
--- a/main.cpp
+++ b/main.cpp
@@ -487,6 +487,13 @@ int main() {
     // alg.printMatrix(wordEmbeddings);
     // std::cout << std::endl;
 
+    std::vector<std::string> textArchive = {"pizza", "pizza hamburger cookie", "hamburger", "ramen", "sushi", "ramen sushi"};
+
+    alg.printMatrix(data.LSA(textArchive, 2));
+    //alg.printMatrix(data.BOW(textArchive, "Default"));
+    std::cout << std::endl;
+
+
     // std::vector<std::vector<double>> inputSet = {{1,2},{2,3},{3,4},{4,5},{5,6}};
     // std::cout << "Feature Scaling Example:" << std::endl;
     // alg.printMatrix(data.featureScaling(inputSet));
@@ -629,9 +636,9 @@ int main() {
     // std::cout << std::endl;
     // }
     // Harris detector works. Life is good!
-    std::vector<double> a = {3,4,4};
-    std::vector<double> b= {4,4,4};
-    alg.printVector(alg.cross(a,b));
+    // std::vector<double> a = {3,4,4};
+    // std::vector<double> b = {4,4,4};
+    // alg.printVector(alg.cross(a,b));