// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.


#include <dlib/matrix.h>
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <vector>
#include "../stl_checked.h"
#include "../array.h"
#include "../rand.h"
#include "checkerboard.h"
#include <dlib/statistics.h>

#include "tester.h"
#include <dlib/svm.h>


namespace  
{

    // Everything in this anonymous namespace is local to this translation unit, so
    // the using-directives below do not leak into other files.
    using namespace test;
    using namespace dlib;
    using namespace std;

    // Logger used by the tests in this file; messages are emitted under the name
    // "test.svm_c_linear".
    logger dlog("test.svm_c_linear");

    // Dense sample: a column vector of doubles.  The tests construct it with an
    // explicit length (e.g. sample_type samp(2)).
    typedef matrix<double, 0, 1> sample_type;
    // Sparse sample: a list of (unsigned index, double value) pairs.  The sparse
    // tests fill it with the same coordinates used for the dense samples.
    typedef std::vector<std::pair<unsigned int, double> > sparse_sample_type;

// ----------------------------------------------------------------------------------------

    void get_simple_points (
        std::vector<sample_type>& samples,
        std::vector<double>& labels
    )
    {
        samples.clear();
        labels.clear();
        sample_type samp(2);

        samp = 0,0;
        samples.push_back(samp);
        labels.push_back(-1);

        samp = 0,1;
        samples.push_back(samp);
        labels.push_back(-1);

        samp = 3,0;
        samples.push_back(samp);
        labels.push_back(+1);

        samp = 3,1;
        samples.push_back(samp);
        labels.push_back(+1);
    }

// ----------------------------------------------------------------------------------------

    void get_simple_points_sparse (
        std::vector<sparse_sample_type>& samples,
        std::vector<double>& labels
    )
    {
        samples.clear();
        labels.clear();
        sparse_sample_type samp;

        samp.push_back(make_pair(0, 0.0));
        samp.push_back(make_pair(1, 0.0));
        samples.push_back(samp);
        labels.push_back(-1);

        samp.clear();
        samp.push_back(make_pair(0, 0.0));
        samp.push_back(make_pair(1, 1.0));
        samples.push_back(samp);
        labels.push_back(-1);

        samp.clear();
        samp.push_back(make_pair(0, 3.0));
        samp.push_back(make_pair(1, 0.0));
        samples.push_back(samp);
        labels.push_back(+1);

        samp.clear();
        samp.push_back(make_pair(0, 3.0));
        samp.push_back(make_pair(1, 1.0));
        samples.push_back(samp);
        labels.push_back(+1);
    }

// ----------------------------------------------------------------------------------------

    // Trains svm_c_linear_trainer with a sparse linear kernel on the four-point toy
    // problem produced by get_simple_points_sparse() and checks that:
    //   - the objective value reported by train() is 0.72222222222 (within 1e-8), and
    //   - the learned decision function outputs -1 for the first two samples and
    //     +1 for the last two (within 1e-6).
    void test_sparse (
    )
    {
        print_spinner();
        dlog << LINFO << "test with sparse vectors";
        std::vector<sparse_sample_type> samples;
        std::vector<double> labels;

        get_simple_points_sparse(samples,labels);

        svm_c_linear_trainer<sparse_linear_kernel<sparse_sample_type> > trainer;
        trainer.set_c(1e4);
        //trainer.be_verbose();
        // The stopping tolerance matches the 1e-8 tolerance used on the objective below.
        trainer.set_epsilon(1e-8);

        // Same expected objective value as the dense variant of this test (~13/18).
        const double expected_obj = 0.72222222222;

        // obj is an output argument of train(); it is initialized here only so it
        // never holds an indeterminate value.
        double obj = 0;
        decision_function<sparse_linear_kernel<sparse_sample_type> > df = trainer.train(samples, labels, obj);
        dlog << LDEBUG << "obj: "<< obj;
        DLIB_TEST_MSG(abs(obj - expected_obj) < 1e-8, obj);

        // Every training point should be mapped to its label value.
        DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6);
        DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6);
        DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6);
        DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6);
    }

// ----------------------------------------------------------------------------------------

    // Trains svm_c_linear_trainer with a dense linear kernel on the four-point toy
    // problem produced by get_simple_points() and checks that:
    //   - the objective value reported by train() is 0.72222222222 (within 1e-8),
    //   - the same value is obtained from the learned weights and bias alone, and
    //   - the learned decision function outputs -1 for the first two samples and
    //     +1 for the last two (within 1e-6).
    void test_dense (
    )
    {
        print_spinner();
        dlog << LINFO << "test with dense vectors";
        std::vector<sample_type> samples;
        std::vector<double> labels;

        get_simple_points(samples,labels);

        svm_c_linear_trainer<linear_kernel<sample_type> > trainer;
        trainer.set_c(1e4);
        //trainer.be_verbose();
        // The stopping tolerance matches the 1e-8 tolerance used on the objective below.
        trainer.set_epsilon(1e-8);

        // Expected objective, ~13/18.  This is the value of 0.5*|w|^2 + 0.5*b^2 for
        // w = (2/3, 0) and |b| = 1.
        const double expected_obj = 0.72222222222;

        // obj is an output argument of train(); it is initialized here only so it
        // never holds an indeterminate value.
        double obj = 0;
        decision_function<linear_kernel<sample_type> > df = trainer.train(samples, labels, obj);
        dlog << LDEBUG << "obj: "<< obj;
        DLIB_TEST_MSG(abs(obj - expected_obj) < 1e-8, obj);

        // There shouldn't be any margin violations since this dataset is so trivial.  So the
        // objective should be exactly half the squared norm of the decision plane, where the
        // bias term df.b is counted as part of that norm.
        const double half_sq_norm = length_squared(df.basis_vectors(0))*0.5 + df.b*df.b*0.5;
        DLIB_TEST_MSG(abs(half_sq_norm - expected_obj) < 1e-8, half_sq_norm);

        // Every training point should be mapped to its label value.
        DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6);
        DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6);
        DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6);
        DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6);
    }

// ----------------------------------------------------------------------------------------

    // Test-suite entry for svm_c_linear_trainer.  The name and description strings
    // are handed to the tester base class; perform_test() runs the dense test first
    // and the sparse test second.
    class svm_c_linear_tester : public tester
    {
    public:
        svm_c_linear_tester ()
            : tester ("test_svm_c_linear", "Runs tests on the svm_c_linear_trainer.")
        {
        }

        void perform_test ()
        {
            test_dense();
            test_sparse();
        }
    };

    // Single file-local instance of the tester.  NOTE(review): presumably constructing
    // it is what makes the test known to the framework -- confirm in tester.h.
    svm_c_linear_tester a;

}