/* alphaBeta2.t.cc
 */
#include "problems.h"
#include "osl/search/alphaBeta2.h"
#include "osl/search/simpleHashTable.h"
#include "osl/move_generator/legalMoves.h"
#include "osl/container/moveVector.h"
#include "osl/record/csaRecord.h"
#include "osl/record/csaString.h"
#include "osl/oslConfig.h"

#include <cppunit/TestCase.h>
#include <cppunit/extensions/HelperMacros.h>

#include <fstream>
#include <iostream>
#include <string>

class AlphaBeta2Test : public CppUnit::TestFixture 
{
  CPPUNIT_TEST_SUITE(AlphaBeta2Test);
  CPPUNIT_TEST(testProblems1);
  CPPUNIT_TEST(testProblems2);
  CPPUNIT_TEST(testProblems3);
  CPPUNIT_TEST(testRootIgnore);
  CPPUNIT_TEST(testRootAllIgnore);
  CPPUNIT_TEST_SUITE_END();
  void testProblems(const osl::Problem *problems, int num_problems, int depth);
public:
  void testProblems1();
  void testProblems2();
  void testProblems3();
  void testRootIgnore();
  void testRootAllIgnore();
};

// Register the fixture's suite with CppUnit's test factory registry.
CPPUNIT_TEST_SUITE_REGISTRATION(AlphaBeta2Test);

// The test bodies below refer to osl and osl::search names unqualified.
using namespace osl;
using namespace osl::search;

// Short alias for SearchState2::checkmate_t, the type of the `checkmate`
// object handed to every searcher constructed in this file.
typedef SearchState2::checkmate_t checkmate_t;

void AlphaBeta2Test::
testProblems(const Problem *problems, int numProblems, int depth)
{
  static bool loaded = eval::ProgressEval::setUp()
    && osl::eval::ml::OpenMidEndingEval::setUp()
    && osl::progress::ml::NewProgress::setUp();
  CPPUNIT_ASSERT(loaded);
 
  CountRecorder recorder;

  for (int i=0; i<numProblems; ++i)
  {
    NumEffectState state((CsaString(problems[i].state).getInitialState()));
    checkmate_t checkmate;
    SimpleHashTable table(80000, -1);
    AlphaBeta2OpenMidEndingEval searcher(state, checkmate, &table, recorder);

    const Move best_move = searcher.computeBestMoveIteratively(depth, 200);
    CPPUNIT_ASSERT(problems[i].acceptable(best_move));
  }
}

/** Runs the problems1 set through testProblems() with depth 600. */
void AlphaBeta2Test::testProblems1()
{
  const int depth = 600;
  testProblems(problems1, numProblems1, depth);
}
/**
 * Runs the problems1 set through testProblems() with depth 800.
 *
 * NOTE(review): this reuses problems1/numProblems1 (the same set as
 * testProblems1) rather than a separate second set -- confirm that this
 * is intentional and not a copy-paste slip.
 */
void AlphaBeta2Test::testProblems2()
{
  const int depth = 800;
  testProblems(problems1, numProblems1, depth);
}
/** Runs the problems3 set through testProblems() with depth 1000. */
void AlphaBeta2Test::testProblems3()
{
  const int depth = 1000;
  testProblems(problems3, numProblems3, depth);
}

void AlphaBeta2Test::testRootIgnore()
{
  NumEffectState state(CsaString(
			 "P1-GI-KE-KI-GI-KE-TO-TO-TO-OU\n"
			 "P2 *  *  *  *  *  *  * -TO-KI\n"
			 "P3-FU * -FU-FU *  *  *  * -KY\n"
			 "P4 * -KY *  *  *  *  * -UM-FU\n"
			 "P5 *  *  *  *  * -HI-FU *  * \n"
			 "P6 *  * +KE *  * +KY *  * +KE\n"
			 "P7+FU+FU+FU+FU+FU+FU+FU+FU+FU\n"
			 "P8 * +KA+KI * +OU *  * +HI * \n"
			 "P9+KY * +GI *  * +KI+GI *  * \n"
			 "+\n").getInitialState());
  
  checkmate_t checkmate;
  SimpleHashTable table(1000000, -1);
  CountRecorder recorder;
  AlphaBeta2ProgressEval searcher(state, checkmate, &table, recorder);

  const Move best_move(Square(1,6),Square(2,4),KNIGHT,PBISHOP,false,BLACK);
  {
    const Move move = searcher.computeBestMoveIteratively(1000, 200);
    table.clear();
    CPPUNIT_ASSERT_EQUAL(best_move, move);
  }
  
  MoveVector ignores;
  searcher.setRootIgnoreMoves(&ignores, true);

  {
    const Move move = searcher.computeBestMoveIteratively(1000, 200);
    table.clear();
    CPPUNIT_ASSERT_EQUAL(best_move, move);
  }

  ignores.push_back(best_move);

  const Move best_move2(Square(4,6),Square(4,5),LANCE,ROOK,false,BLACK);
  {
    const Move move = searcher.computeBestMoveIteratively(1000, 200);
    table.clear();
    CPPUNIT_ASSERT_EQUAL(best_move2, move);
  }

  ignores.push_back(best_move2);

  const Move best_move3(Square(7,6),Square(8,4),KNIGHT,LANCE,false,BLACK);
  {
    const Move move = searcher.computeBestMoveIteratively(1000, 200);
    table.clear();
    CPPUNIT_ASSERT_EQUAL(best_move3, move);
  }
}

void AlphaBeta2Test::testRootAllIgnore()
{
  NumEffectState state(CsaString(
			 "P1-GI-KE-KI-GI-KE-TO-TO-TO-OU\n"
			 "P2 *  *  *  *  *  *  * -TO-HI\n"
			 "P3-FU * -FU-FU *  *  *  * -KY\n"
			 "P4 * -KY *  *  *  *  * -KA-FU\n"
			 "P5 *  *  *  *  * -KI-FU *  * \n"
			 "P6 *  * +KE *  * +KY *  * +KE\n"
			 "P7+FU+FU+FU+FU+FU+FU+FU+FU+FU\n"
			 "P8 * +KA+KI * +OU *  * +HI * \n"
			 "P9+KY * +GI *  * +KI+GI *  * \n"
			 "+\n").getInitialState());
  
  checkmate_t checkmate;
  SimpleHashTable table(1000000, -1);
  CountRecorder recorder;
  AlphaBeta2ProgressEval searcher(state, checkmate, &table, recorder);

  MoveVector moves;
  LegalMoves::generate(state, moves);
  
  searcher.setRootIgnoreMoves(&moves, true);

  const Move move = searcher.computeBestMoveIteratively(1000, 200);
  std::cerr << move << "\n";
  CPPUNIT_ASSERT(move.isInvalid());
}

/* ------------------------------------------------------------------------- */
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
