/*  Jazz, a program for playing chess
 *  Copyright (C) 2009, 2011  Evert Glebbeek
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#ifdef SMP
#include <pthread.h>
#endif
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "assert.h"
#include "smp.h"
#include "alphabeta.h"
#include "history.h"

#ifdef SMP

/* Tracing is compiled out unless this #undef is changed to #define. */
#undef SMP_DEBUG

#ifdef SMP_DEBUG

/* Debug tracing of split-point and scheduling events. No trailing semicolon:
 * the caller's "TRACE(...);" must form exactly one statement in both
 * configurations, so it stays safe inside an unbraced if/else.
 */
#define TRACE(...) printf(__VA_ARGS__)

#else

#define TRACE(...) (void)0

#endif


#if defined _WIN32 || defined _WIN64
#include <windows.h>
#else
#include <unistd.h>
#endif

/* Per-thread bookkeeping for the SMP search; one entry per thread in the
 * static thread[] array. Thread 0 is the calling (main) thread: init_threads()
 * gives it a work space but does not create a pthread for it.
 * NOTE(review): the volatile fields are polled by spin loops in other threads;
 * volatile gives no atomicity or ordering guarantees in C11 - consider _Atomic.
 */
typedef struct thread_t {
   struct gamestate_t *game;     /* Private work space (game state) for this thread */
   split_t *splitpoint;          /* Split point this thread is currently working on (NULL if none) */
   int thread_id;                /* Index of this entry in thread[] */
   volatile int active_splits;   /* Number of entries in use on this thread's splitpoint[] stack */
   volatile uint8_t state;       /* The current state of the thread: one of the THREAD_* values */
   volatile bool must_die;       /* Thread will exit if this becomes true (FIXME: should be a state flag?) */
   pthread_t pt;                 /* pthreads thread ID (for pthread_join); unused for thread 0 */
   char barrier[128];            /* Padding so data for different threads is on different cache lines */
} thread_t;

/* Values stored in thread_t.state. They are distinct bits, but a thread is
 * always assigned exactly one of them at a time.
 */
enum {
   THREAD_WORKING   = 0x01,  /* Thread is currently searching */
   THREAD_SLEEPING  = 0x02,  /* Thread is sleeping */
   THREAD_WAITING   = 0x04,  /* Thread is waiting to be assigned work */
   THREAD_SCHEDULED = 0x08   /* Thread has been assigned work but has not started searching yet */
};

/* Per-thread state, indexed by thread ID; the first threads_created entries
 * are in use.
 */
static thread_t thread[MAX_THREADS];
/* Number of threads set up by init_threads(); 0 when SMP is not in use. */
static int threads_created = 0;

/* Split point stack of each thread: the first thread[n].active_splits entries
 * of splitpoint[n] are in use.
 */
static split_t splitpoint[MAX_THREADS][MAX_SPLIT];

/* Condition variable (and its mutex) on which THREAD_SLEEPING threads block
 * until wake_all_threads() broadcasts.
 */
pthread_cond_t sleep_condition;
pthread_mutex_t sleep_mutex;
/* Guards the split point stacks and worker scheduling in split(). */
lock_t split_mutex;

/* Return the number of online processors on the current machine.
 * NB: returns the number of "logical" cores on hyper-threading machines.
 * Always returns at least 1, even when the count cannot be determined.
 */
int get_number_of_cores(void)
{
#if defined _WIN32 || defined _WIN64
   /* Same platform test as the <windows.h> include; MSVC predefines _WIN32
    * but not WIN32.
    */
   SYSTEM_INFO sysinfo;
   GetSystemInfo(&sysinfo);
   return sysinfo.dwNumberOfProcessors > 0 ? (int)sysinfo.dwNumberOfProcessors : 1;
#else
   long cores = sysconf(_SC_NPROCESSORS_ONLN);
   if (cores < 1)       /* sysconf() returns -1 on error */
      return 1;
   return (int)cores;
#endif
}

/* Copy relevant data from the master's game into the split point's game, where
 * the worker threads pick it up with copy_game_state_from_splitpoint().
 *
 * The history arrays (board_list, fifty_counter, move_list) are not copied:
 * the split point's game is pointed at the master's arrays, and each worker
 * copies the slice it needs into its own work space. The transposition table
 * and evaluation hash pointers are shared, the clock is copied, and the search
 * statistics are reset so that split() can add up what was searched below
 * this split point.
 *
 * NOTE(review): overwriting the array pointers drops whatever arrays `to`
 * owned before (e.g. those allocated by create_game()) without freeing them;
 * confirm whether this is an intentional one-time leak.
 */
static void copy_game_state_to_splitpoint(const gamestate_t *from, gamestate_t *to)
{
   /* Share the master's history arrays instead of copying them */
   to->fifty_counter = from->fifty_counter;
   to->move_list     = from->move_list;
   to->board_list    = from->board_list;

   to->moves_played = to->last_move = from->moves_played;
   to->side_to_move = from->side_to_move;
   to->board        = to->board_list + to->moves_played;
   to->root_board   = from->root_board;

   assert(get_occupied(to->board));

   assert(from->board->bbc[0] == to->board->bbc[0]);
   assert(from->board->bbc[1] == to->board->bbc[1]);

   to->max_nodes = from->max_nodes;
   to->positions_evaluated = 0;
   to->moves_searched = 0;
   to->positions_in_hashtable = 0;
   to->branches_pruned = 0;
   to->branches_pruned_1st = 0;

   /* The transposition table is shared */
   to->transposition_table = from->transposition_table;
   to->eval_hash = from->eval_hash;

   /* We have to be able to check the clock. Copying the whole clock also
    * carries over root_moves_played.
    */
   to->clock = from->clock;
}


/* Copy the relevant game state from a split point's game into a thread's own
 * work space, before that thread calls search_split().
 *
 * Only the history after the root (entries root_moves_played..moves_played of
 * board_list, fifty_counter and move_list) is copied. Earlier entries are
 * presumably already in place from copy_game_history(). Search statistics are
 * reset. The transposition table, evaluation hash and clock are taken over
 * from the split point.
 */
static void copy_game_state_from_splitpoint(const gamestate_t *from, gamestate_t *to)
{
   size_t s = (from->moves_played - from->clock.root_moves_played + 1);
   size_t off = from->clock.root_moves_played;
   assert(from->clock.root_moves_played == to->clock.root_moves_played);

   /* memmove rather than memcpy: copy_game_state_to_splitpoint() makes the
    * split point share the master's arrays, so when the master copies back
    * into its own work space the source and destination are the same memory,
    * which is undefined behaviour for memcpy.
    * Each array is sized by its own element type.
    */
   memmove(to->board_list    + off, from->board_list    + off, s * sizeof *from->board_list);
   memmove(to->fifty_counter + off, from->fifty_counter + off, s * sizeof *from->fifty_counter);
   memmove(to->move_list     + off, from->move_list     + off, s * sizeof *from->move_list);

   to->moves_played = to->last_move = from->moves_played;
   to->side_to_move = from->side_to_move;
   to->board        = to->board_list + to->moves_played;
   to->root_board   = from->root_board;

   assert(get_occupied(to->board));

   assert(from->board->bbc[0] == to->board->bbc[0]);
   assert(from->board->bbc[1] == to->board->bbc[1]);

   to->max_nodes = from->max_nodes;
   to->positions_evaluated = 0;
   to->moves_searched = 0;
   to->positions_in_hashtable = 0;
   to->branches_pruned = 0;
   to->branches_pruned_1st = 0;

   /* The transposition table is shared */
   to->transposition_table = from->transposition_table;
   to->eval_hash = from->eval_hash;

   /* We have to be able to check the clock */
   to->clock = from->clock;
}


/* Main loop for each helper thread; `arg` is the thread's own thread[] entry.
 *
 * Until must_die is set, the thread cycles through its states:
 *  - THREAD_SLEEPING:  block on sleep_condition until a broadcast arrives,
 *                      then become THREAD_WAITING;
 *  - THREAD_WAITING:   spin until split() schedules work (or must_die);
 *  - THREAD_SCHEDULED: spin until the master switches it to THREAD_WORKING;
 *  - THREAD_WORKING:   copy the split point's position into the private work
 *                      space, run search_split(), then return to
 *                      THREAD_WAITING.
 * Always returns NULL.
 */
static void *thread_func(void *arg)
{
   thread_t *self = (thread_t *)arg;

   while (!self->must_die) {
      if (self->state == THREAD_SLEEPING) {
         assert(self->thread_id != 0);
         pthread_mutex_lock(&sleep_mutex);
         /* Re-check must_die while holding sleep_mutex. kill_threads() sets
          * must_die before wake_all_threads() takes this mutex to broadcast.
          * So either we see the flag here, or the broadcast cannot be sent
          * until we are blocked inside pthread_cond_wait(). Without this check
          * a broadcast sent between the state test above and the wait is lost,
          * and pthread_join() in kill_threads() hangs.
          * NOTE(review): an ordinary wake_all_threads() broadcast can still be
          * missed the same way; the thread then sleeps until the next one.
          */
         if (!self->must_die)
            pthread_cond_wait(&sleep_condition, &sleep_mutex);
         pthread_mutex_unlock(&sleep_mutex);
         self->state = THREAD_WAITING;
      }

      /* Wait for work (spin) */
      while (self->state == THREAD_WAITING && !self->must_die)
         ;

      /* Scheduled: spin until the master tells us to start working */
      while (self->state == THREAD_SCHEDULED)
         ;

      if (self->state == THREAD_WORKING) {
         TRACE("Thread %d accepted work\n", self->thread_id);

         copy_game_state_from_splitpoint(self->splitpoint->game, self->game);
         assert(self->game != self->splitpoint->game);
         assert(self->game->board != self->splitpoint->game->board);
         assert(self->game->board->bbc[0] == self->splitpoint->game->board->bbc[0]);
         assert(self->game->board->bbc[1] == self->splitpoint->game->board->bbc[1]);

         /* Call SMP search */
         search_split(self->game, self->splitpoint);

         /* Mark thread as waiting for work. */
         self->state = THREAD_WAITING;
      }
   }
   return NULL;
}

/* Wake every thread blocked in the THREAD_SLEEPING state by broadcasting on
 * sleep_condition. Does nothing when no threads have been created.
 */
void wake_all_threads(void)
{
   if (!threads_created)
      return;

   pthread_mutex_lock(&sleep_mutex);
   pthread_cond_broadcast(&sleep_condition);
   pthread_mutex_unlock(&sleep_mutex);
}

/* Tell all threads to go to sleep */
void sleep_all_threads(void)
{
   int n;
   for (n=1; n<threads_created; n++)
      thread[n].state = THREAD_SLEEPING;
}

/* Initialise data structures for the specified number of threads.
 *
 * Previously created threads are terminated first. Requests above MAX_THREADS
 * are clamped. With fewer than two threads no helper threads are started and
 * get_number_of_threads() returns 0.
 *
 * Thread 0 is the calling thread: it gets a work space but no pthread.
 * Threads 1..num_threads-1 start in THREAD_SLEEPING. They are created
 * joinable (default attributes) because kill_threads() reaps them with
 * pthread_join(). The program exits with a failure status if a thread cannot
 * be launched.
 */
void init_threads(int num_threads)
{
   static bool sync_initialised = false;
   int n;

   if (num_threads > MAX_THREADS) num_threads = MAX_THREADS;

   if (threads_created)
      kill_threads();

   threads_created = 0;
   if (num_threads < 2) return;

   /* Initialise global condition and mutex variables. POSIX leaves
    * re-initialising an already initialised mutex or condition variable
    * undefined, and nothing destroys these, so do it only once.
    */
   if (!sync_initialised) {
      pthread_cond_init(&sleep_condition, NULL);
      pthread_mutex_init(&sleep_mutex, NULL);
      init_lock(&split_mutex);
      sync_initialised = true;
   }

   /* Initialise split points. Their games are never freed, so only allocate
    * (and initialise the lock of) split points that have not been set up
    * before, instead of leaking a full set on every re-initialisation.
    */
   for (n=0; n<num_threads; n++) {
      for (int k=0; k<MAX_SPLIT; k++) {
         if (splitpoint[n][k].game) continue;
         splitpoint[n][k].game = create_game();
         init_lock(&splitpoint[n][k].lock);
      }
   }

   /* Initialise threads */
   memset(&thread, 0, sizeof thread);
   for (n=0; n<num_threads; n++) {
      gamestate_t *game = create_game();
      thread[n].thread_id = n;
      thread[n].must_die = false;
      thread[n].game = game;
      game->default_hash_size = 1;
      start_new_game(game);
      /* Each thread uses the shared tables; drop its private ones */
      destroy_hash_table(game->transposition_table);
      destroy_eval_hash_table(game->eval_hash);
      clear_history(game);
      game->transposition_table = NULL;
      game->eval_hash = NULL;
      game->output_iteration = NULL;
      game->uci_output = NULL;
      game->xboard_output = NULL;
      game->thread_id = thread[n].thread_id;

      if (n == 0) continue;   /* Thread 0 is the caller: no pthread */

      thread[n].state = THREAD_SLEEPING;
      if (pthread_create(&thread[n].pt, NULL, thread_func, &thread[n]) != 0) {
         fprintf(stderr, "Failed to launch thread %d\n", n);
         exit(EXIT_FAILURE);
      }
   }
   thread[0].state = THREAD_WORKING;
   thread[0].must_die = false;
   threads_created = num_threads;
}

/* Terminate all threads and release their work spaces.
 *
 * Sets must_die on every helper, wakes any that are sleeping and joins them.
 * It then frees the game of every created thread, including thread 0: it has
 * no pthread, but init_threads() gives it a work space, which was previously
 * leaked. The shared transposition table and evaluation hash are detached
 * (set to NULL) before end_game() because they are not owned by the thread.
 * Afterwards get_number_of_threads() returns 0.
 */
void kill_threads(void)
{
   int n;

   for (n=1; n<MAX_THREADS; n++) {
      thread[n].must_die = true;
   }

   wake_all_threads();

   for (n=0; n<threads_created; n++) {
      if (n > 0)
         pthread_join(thread[n].pt, NULL);

      if (thread[n].game) {
         thread[n].game->transposition_table = NULL;  /* The transposition table is shared */
         thread[n].game->eval_hash = NULL;            /* The evaluation hash is shared */
         end_game(thread[n].game);
         thread[n].game = NULL;                       /* Don't leave a dangling pointer */
      }
   }
   threads_created = 0;
}

/* Give every thread its own copy of the static piece-square table `psq`
 * (used only for move ordering). Meant to be called once per root search.
 * Threads without a work space are skipped.
 */
void share_static_psq(combined_psq_t *psq)
{
   int count = threads_created;
   for (int idx = 0; idx < count; idx++) {
      gamestate_t *g = thread[idx].game;
      if (g == NULL)
         continue;
      memcpy(&g->dynamic_psq, psq, sizeof *psq);
   }
}

/* Copy the game history of `from` into the work space of every thread.
 *
 * The copy covers entries 0..moves_played of board_list, fifty_counter and
 * move_list, plus the root move number (clock.root_moves_played).
 */
void copy_game_history(const gamestate_t *from)
{
   size_t count = (size_t)from->moves_played + 1;

   for (int n=0; n<threads_created; n++) {
      gamestate_t *to = thread[n].game;

      assert(to);

      to->clock.root_moves_played = from->clock.root_moves_played;
      memcpy(to->board_list,    from->board_list,    count * sizeof *from->board_list);
      memcpy(to->fifty_counter, from->fifty_counter, count * sizeof *from->fifty_counter);
      /* Sized by move_list's own element type (was sizeof *fifty_counter) */
      memcpy(to->move_list,     from->move_list,     count * sizeof *from->move_list);
   }
}

/* Flag the given split point as stopped (e.g. after a beta cutoff). Every
 * thread working at or below it will then see thread_should_stop() return
 * true.
 */
void stop_split_search(split_t *split)
{
   split_t *const sp = split;
   sp->stop = true;
}

/* Number of threads set up by init_threads(); 0 when SMP is not in use. */
int get_number_of_threads(void)
{
   const int count = threads_created;
   return count;
}

/* Returns true if the thread should stop searching (because a cutoff has occurred
 * in one of the other threads working on this split point).
 */
bool thread_should_stop(int id)
{
   split_t *split = thread[id].splitpoint;

   while (split) {
      if (split->stop)
         return true;
      split = split->parent;  
   }
   return false;
}

static bool thread_can_help(int helper, int master)
{
   if (helper == master) return false;                            /* Can't help self         */
   if (thread[helper].state != THREAD_WAITING) return false;      /* Can't help if busy      */

   int active_splits = thread[helper].active_splits;
   if (active_splits == 0) return true;            /* Can always help if idle */

   /* If we're "the other thread" we can always help. */
   if (threads_created == 2) return true;

   /* TODO: if the helper has active splits, but they are on the same branch as the
    * master is searching, then it can still help because accumulated data is
    * relevant.
    */
   if (splitpoint[helper][active_splits - 1].workers & (1<<master))
      return true;

   return false;
}

/* Returns true if at least one thread can currently help `master`, as judged
 * by thread_can_help().
 */
static bool thread_available(int master)
{
   int n = 0;
   while (n < threads_created && !thread_can_help(n, master))
      n++;
   return n < threads_created;
}

/* Try to share the search of the remaining moves at the current node with
 * other threads ("split" the node).
 *
 * Returns immediately, leaving *alpha, *beta, *score and *best_move untouched,
 * in three cases:
 *  - fewer than two threads exist;
 *  - the calling thread's split point stack is full (MAX_SPLIT);
 *  - no thread can help (see thread_can_help()).
 *
 * Otherwise it:
 *  1. pushes a new split point for the calling (master) thread;
 *  2. schedules every thread that can help;
 *  3. searches the split point itself via search_split();
 *  4. waits until all workers have left the split point;
 *  5. copies the split point's alpha, beta, score and best move back to the
 *     caller, and adds the split point's search statistics to `game`.
 *
 * Parameters:
 *  game           - the calling thread's game; game->thread_id selects the
 *                   master thread.
 *  movelist       - move list handed to the workers through the split point.
 *  static_score   - not used by this function.
 *  alpha, beta, score, best_move
 *                 - in/out: stored in the split point, then overwritten with
 *                   the split point's values after the joint search.
 *  depth, draft   - stored in the split point for search_split().
 */
void split(gamestate_t *game, movelist_t *movelist, int static_score, int *alpha, int *beta, int depth, int draft, int *score, move_t *best_move)
{
   int me = game->thread_id;
   int n;

   if (threads_created <= 1) return;

   /* We have to manipulate the split point stack */
   acquire_lock(&split_mutex);

   /* Maximum number of split points reached? */
   if (thread[me].active_splits == MAX_SPLIT) goto done;

   /* Test if there is an idle thread waiting */
   if (!thread_available(me)) goto done;

   /* There are threads available, initialise a split point */
   split_t *sp = &splitpoint[me][thread[me].active_splits];
   thread[me].active_splits++;

   /* Copy data */
   TRACE("Creating split point, ply = %d depth = %d [a,b]=[%d, %d]\n", depth, draft/PLY, *alpha, *beta);
   copy_game_state_to_splitpoint(game, sp->game);
   sp->movelist = movelist;
   sp->parent = thread[me].splitpoint;
   sp->alpha = *alpha;
   sp->beta  = *beta;
   sp->score = *score;
   sp->depth = depth;
   sp->draft = draft;
   sp->best_move = *best_move;
   sp->stop  = false;
   sp->workers = 0;

   thread[me].splitpoint = sp;

   /* Assign all workers (the master always works on its own split point) */
   int num_workers = 0;
   for (n=0; n<threads_created; n++) {
      sp->workers |= (1<<me);   /* Sets the master's bit; repeating it each iteration is redundant but harmless */
      if (thread_can_help(n, me) || n == me) {
         TRACE("Scheduling work for thread %d\n", n);
         thread[n].state = THREAD_SCHEDULED;
         thread[n].splitpoint = sp;
         TRACE("> %d %d\n", n, me);
         sp->workers |= (1<<n);
         num_workers++;
      }
   }

   /* We should have more than one worker available, otherwise we would have exited
    * this function at the top.
    */
   assert(num_workers > 1);

   /* We can release the split lock since workers have already been marked as
    * "scheduled"
    */
   release_lock(&split_mutex);

   /* Now kick all workers into action: thread_func() spins in
    * THREAD_SCHEDULED until it sees THREAD_WORKING here.
    */
   TRACE("Starting threads\n");
   for (n=0; n<threads_created; n++) {
      if (sp->workers & (1<<n)) {
         TRACE("< %d %d\n", n, me);
         thread[n].state = THREAD_WORKING;
      }
   }

   /* Enter thread loop for the master thread */
   TRACE("Entering loop (#%d)\n", me);
   copy_game_state_from_splitpoint(thread[me].splitpoint->game, thread[me].game);
   search_split(thread[me].game, thread[me].splitpoint);
   /* search_split() is expected to have cleared the master's worker bit */
   assert( (sp->workers & (1<<me)) == 0);
   /* Busy-wait until every helper has left the split point (presumably each
    * clears its bit in search_split()).
    * NOTE(review): the master spins here in THREAD_WORKING, so
    * thread_can_help() never selects it to help its own helpers meanwhile.
    */
   while (sp->workers);
   TRACE("Exit loop (#%d)\n", me);

   /* Update variables from the split point and return */
   acquire_lock(&split_mutex);
   *alpha = sp->alpha;
   *beta  = sp->beta;
   *score = sp->score;
   *best_move = sp->best_move;
   thread[me].active_splits--;
   thread[me].splitpoint = sp->parent;
   TRACE("Joined split point, ply = %d depth = %d [a,b]=[%d, %d]\n", depth, draft/PLY, *alpha, *beta);

   /* Fold the statistics gathered below this split point into the caller's game */
   game->positions_evaluated    += sp->game->positions_evaluated;
   game->moves_searched         += sp->game->moves_searched;
   game->positions_in_hashtable += sp->game->positions_in_hashtable;
   game->branches_pruned        += sp->game->branches_pruned;
   game->branches_pruned_1st    += sp->game->branches_pruned_1st;

   /* Every path reaches this label holding split_mutex */
done:
   release_lock(&split_mutex);
   return;
}
#endif
