/*******************************************************************************
 *
 *  instanton search for iterative decoding
 *  Copyright (C) 2010-2011 Misha Stepanov
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License
 *    http://www.gnu.org/licenses/gpl.txt
 *  for more details.
 *
 ******************************************************************************/

/*******************************************************************************
 *
 *  Finding iterative decoding instantons using the procedure described in
 *    M. Stepanov, Instantons causing iterative decoding to cycle,
 *    submitted to IEEE Transactions on Information Theory
 *    [arxiv: cs.IT/1108.5547]
 *
 *  usage: ./instanton <n_iter;max> <noise> <time> <id>
 *    <n_iter;max> --- maximal number of iterations
 *    <noise> --- minimal possible value of the perturbation amplitude
 *    <time> --- run for <time> CPU time, then exit
 *    <id> --- id of the run, used to generate filenames
 *
 *  The working directory should contain the files
 *    matrix_H --- parity check matrix of the code, see ldpcc.c for the format
 *    data/height/, data/noise/, and data/time/ directories
 *
 *  The instantons are stored in data/noise/<id> file, while data/height/<id>
 *  and data/time/<id> could be used to track progress.
 *
 ******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>

#define LDPCC_RANDOM yes

#include "../ldpcc/3.2.3/ldpcc.c"

/* One candidate noise configuration tracked by the instanton search. */
typedef struct {
  real *xi;         /* noise configuration, one entry per code bit */
  real height;      /* cost of the noise (sum of xi[i]^2); smaller is better */
  int n_iter;       /* decoder iterations this noise causes; larger is better */
  int *mask;        /* per-bit flag: 1 = bit may be perturbed, 0 = frozen */
  int n_iter_min;   /* the point is the best one for n_iter_min .. n_iter */
  real noise;       /* desired amplitude of the random perturbation */
  int io_flag;      /* scratch flag: nonzero once the point has been saved */
} search_point;

/* Largest iteration count tracked (sp[] has slots 0 .. N_ITER). */
int N_ITER;
/* Slot index used by main()'s loops over sp[]. */
int IT;

/* sp[it]: the kept search point with the lowest height among those whose
   n_iter >= it; consecutive slots may share one point.
   tmpsp: candidate being built before process_tmp_search_point() decides
   whether to keep it. */
search_point **sp;
search_point *tmpsp = NULL;

/* scratch FILE handle */
FILE *data;

/* channel, parity-check code and decoder state passed to the ldpcc routines */
channel C;
code H;
iterative_decoder D;

/* Allocate a new search point into SP (an lvalue of type search_point *):
   an H.bits-long noise vector xi[] (zeroed) and mask[] (all 1, i.e. every
   bit may be perturbed), default perturbation amplitude 0.1, other fields 0.
   ALLOCATE is defined outside this file; its second argument is an element
   count (compare the xi/mask calls), so one struct is requested with 1.
   The macro uses a private loop index, so callers need no variable `i`,
   and the do/while(0) wrapper makes it a single statement. */
#define CREATE_SEARCH_POINT(SP) \
  do { int csp_k_; \
    ALLOCATE(SP, 1, search_point, search_point *, "SP"); \
    ALLOCATE((SP)->xi, H.bits, real, real *, "SP_xi"); \
    ALLOCATE((SP)->mask, H.bits, int, int *, "SP_mask"); \
    for (csp_k_ = 0; csp_k_ < H.bits; csp_k_++) \
      { (SP)->xi[csp_k_] = 0.; (SP)->mask[csp_k_] = 1; } \
    (SP)->height = 0.; (SP)->n_iter = 0; (SP)->n_iter_min = 0; \
    (SP)->io_flag = 0; (SP)->noise = 0.1; \
  } while (0)

/* Zero and mask out the bits of tmpsp on which the noise is almost zero:
   bit k is dropped (xi[k] = 0, mask[k] = 0) when xi[k]^2 < 4e-16 * |xi|^2,
   i.e. |xi[k]| < 2e-8 * |xi|.  tmpsp->height is left equal to |xi|^2 as
   computed before the zeroing.  Uses a private loop index, so callers need
   no variable `i`, and the do/while(0) wrapper makes it a single statement. */
#define ZEROING_AND_MASKING \
  do { int zam_k_; \
    for (tmpsp->height = 0., zam_k_ = 0; zam_k_ < H.bits; zam_k_++) \
      tmpsp->height += (tmpsp->xi[zam_k_] * tmpsp->xi[zam_k_]); \
    for (zam_k_ = 0; zam_k_ < H.bits; zam_k_++) \
      if (tmpsp->xi[zam_k_] * tmpsp->xi[zam_k_] < 4.e-16 * tmpsp->height) \
        { tmpsp->xi[zam_k_] = 0.; tmpsp->mask[zam_k_] = 0; } \
  } while (0)

/* Evaluate the candidate search point held in tmpsp and decide its fate.
   Its height (sum of xi[i]^2) is computed and the decoder is run on it.
   If the candidate is better than, or fills an empty slot ahead of, the
   current sp[] entries, it is installed in the slots n_iter_min .. n_iter.
   Points whose whole run of slots is overwritten are freed.
   Otherwise the candidate is freed.
   In both cases ownership leaves tmpsp, which is reset to 0. */
void process_tmp_search_point(void)
 {
  int i, it;

/* Gaussian (or AWGN) channel */
  for (tmpsp->height = 0., i = 0; i < H.bits; i++)
    tmpsp->height += (tmpsp->xi[i] * tmpsp->xi[i]);

  send_msg_w2b_calc_h(&H, &C, tmpsp->xi, 0);
  tmpsp->n_iter = iterative_decoding(&H, &D, 0) - 1;
  /* sp[] has only N_ITER + 1 slots; main() sets D.n_iter = N_ITER + 1, so
     presumably the decoder never reports more -- clamp defensively anyway */
  if (tmpsp->n_iter > N_ITER) tmpsp->n_iter = N_ITER;

  /* first slot (up to the candidate's own n_iter) that is empty or holds a
     point with a larger height than the candidate */
  for (tmpsp->n_iter_min = -1, it = 0; it <= tmpsp->n_iter; it++)
    if (sp[it] == 0 || tmpsp->height < sp[it]->height)
      { tmpsp->n_iter_min = it; break; }

  if (tmpsp->n_iter_min >= 0)
   {
    for (it = tmpsp->n_iter_min; it <= tmpsp->n_iter; it++)
     {
      /* a point occupies a contiguous run of slots ending at its own n_iter;
         once that last slot is overwritten nothing refers to it any more */
      if (sp[it] != 0 && it == sp[it]->n_iter)
        { free(sp[it]->xi); free(sp[it]->mask); free(sp[it]); }
      sp[it] = tmpsp;
     }
    /* the point following the candidate's run (if any) now starts there;
       the slot can still be empty while sp[] is being initialized */
    if (tmpsp->n_iter < N_ITER && sp[tmpsp->n_iter + 1] != 0)
      sp[tmpsp->n_iter + 1]->n_iter_min = tmpsp->n_iter + 1;
   }
   else { free(tmpsp->xi); free(tmpsp->mask); free(tmpsp); }
  tmpsp = 0;
 }

/* Build "<dir><id>" into buf (size bytes).  Returns 0 on success, -1 if the
   result would not fit in buf. */
static int instanton_make_path(char *buf, size_t size, const char *dir,
                               const char *id)
 {
  int n = snprintf(buf, size, "%s%s", dir, id);
  return (n < 0 || (size_t) n >= size) ? -1 : 0;
 }

/* Read previously saved search points from noisefile and offer each one to
   process_tmp_search_point().  Each record is H.bits lines
   "<bit> <xi> <mask>" followed by a "#noise <amplitude>" line, as written by
   instanton_write_noise().  A missing file is not an error; an incomplete
   record is discarded and ends the scan.  Noise amplitudes are raised to at
   least noise_min. */
static void instanton_load_points(const char *noisefile, real noise_min)
 {
  FILE *in;
  int i, idx, mask0;
  real xi0;
  char label[256];

  if ((in = fopen(noisefile, "r")) == NULL) return;
  while (fscanf(in, "%d %le %d", &idx, &xi0, &mask0) == 3)
   {
    CREATE_SEARCH_POINT(tmpsp);
    tmpsp->xi[0] = xi0; tmpsp->mask[0] = mask0;
    for (i = 1; i < H.bits; i++)
      if (fscanf(in, "%d %le %d", &idx, &(tmpsp->xi[i]), &(tmpsp->mask[i])) != 3)
        break;
    if (i < H.bits)
     {
      /* truncated record: drop it rather than decode unread xi[] entries */
      free(tmpsp->xi); free(tmpsp->mask); free(tmpsp); tmpsp = 0;
      break;
     }
    /* the "#noise <amplitude>" trailer; the width bound protects label[] */
    fscanf(in, "%255s %le", label, &(tmpsp->noise));
    if (tmpsp->noise < noise_min) tmpsp->noise = noise_min;
    ZEROING_AND_MASKING;
    process_tmp_search_point();
   }
  fclose(in);
 }

/* Append one progress line to timefile.  The line holds the CPU time, then
   the heights of sp[N_ITER] and sp[min(25, N_ITER)], then their noise
   amplitudes.  When leading_newline is nonzero a blank line is written
   first.  Returns 0 on success, -1 if the file cannot be opened or closed. */
static int instanton_log_time(const char *timefile, real cpu_time,
                              int leading_newline)
 {
  FILE *out;
  /* secondary slot that is reported; sp[] only has N_ITER + 1 entries */
  int it_aux = (N_ITER < 25) ? N_ITER : 25;

  if ((out = fopen(timefile, "a")) == NULL) return -1;
  fprintf(out, "%s%e  %22.16e %22.16e  %8.2e %8.2e\n",
    leading_newline ? "\n" : "", cpu_time,
    sp[N_ITER]->height, sp[it_aux]->height,
    sp[N_ITER]->noise, sp[it_aux]->noise);
  return (fclose(out) == 0) ? 0 : -1;
 }

/* Overwrite noisefile with every distinct point in sp[].  Points shared by
   several slots are written only once; io_flag is used for this.  Returns 0
   on success, -1 if the file cannot be opened or closed. */
static int instanton_write_noise(const char *noisefile)
 {
  FILE *out;
  int it, i;

  /* mark all search points as "not written yet" */
  for (it = 0; it <= N_ITER; it++) sp[it]->io_flag = 0;

  if ((out = fopen(noisefile, "w")) == NULL) return -1;
  for (it = 0; it <= N_ITER; it++)
   {
    if (sp[it]->io_flag != 0) continue;
    for (i = 0; i < H.bits; i++)
     {
      if (sp[it]->mask[i])
        fprintf(out, "%d %22.16e 1\n", i, sp[it]->xi[i]);
      else
        fprintf(out, "%d 0. 0\n", i);
     }
    fprintf(out, "#noise %8.2e\n", sp[it]->noise);
    /* mark the search point as "written", so it is not written twice */
    sp[it]->io_flag = 1;
   }
  return (fclose(out) == 0) ? 0 : -1;
 }

/* Overwrite heightfile with one "slot height noise" line per sp[] slot.
   Returns 0 on success, -1 if the file cannot be opened or closed. */
static int instanton_write_height(const char *heightfile)
 {
  FILE *out;
  int it;

  if ((out = fopen(heightfile, "w")) == NULL) return -1;
  for (it = 0; it <= N_ITER; it++)
    fprintf(out, "%3d %22.16e   %8.2e\n", it, sp[it]->height, sp[it]->noise);
  return (fclose(out) == 0) ? 0 : -1;
 }

/* Free every distinct search point and the sp[] array itself.  A point
   occupies a contiguous run of slots, so it is freed at the last slot of
   its run. */
static void instanton_free_points(void)
 {
  int it;

  for (it = 0; it <= N_ITER; it++)
    if (it == N_ITER || sp[it] != sp[it + 1])
      { free(sp[it]->xi); free(sp[it]->mask); free(sp[it]); }
  free(sp);
  sp = 0;
 }

/* Entry point: ./instanton <n_iter;max> <noise> <time> <id>.
   Returns 0 after the CPU-time budget is used, 1 on bad usage or when the
   required data files cannot be set up. */
int main(int argc, char **argv)
 {
  int i, j;
  real rtmp, CPU_time, NOISE_min, TIME_max;
  gsl_rng *RNG;
  char noisefile[256], heightfile[256], timefile[256];
  clock_t start, current;

  if (argc < 5)
   {
    fprintf(stderr, "usage: ./instanton <n_iter;max> <noise> <time> <id>\n");
    return 1;
   }
  if (instanton_make_path(heightfile, sizeof heightfile, "data/height/", argv[4]) != 0
      || instanton_make_path(noisefile, sizeof noisefile, "data/noise/", argv[4]) != 0
      || instanton_make_path(timefile, sizeof timefile, "data/time/", argv[4]) != 0)
   {
    fprintf(stderr, "instanton: run id \"%s\" is too long\n", argv[4]);
    return 1;
   }

  gsl_rng_env_setup();
  RNG = gsl_rng_alloc(gsl_rng_ranlxd2);
  gsl_rng_set(RNG, (unsigned long int)time(NULL));

  C.type = Gaussian_channel; C.SNR = 1.;
  C.RNG_grown = 0; grow_RNG(&C);

  H.HiHa_grown = H.ID_grown = 0;
  read_H_matrix("matrix_H", &H);
  grow_iterative_decoder(&H);

  N_ITER = atoi(argv[1]);
  if (N_ITER < 0) N_ITER = 0;
  if (N_ITER > 512) N_ITER = 512;
  ALLOCATE(sp, N_ITER + 1, search_point *, search_point **, "sp");
/* min-sum decoder with checking the output for being a codeword */
  D.n_iter = N_ITER + 1; D.WCC = 1; D.relaxed = 0;

  NOISE_min = atof(argv[2]);
  if (NOISE_min > 0.1) NOISE_min = 0.1;
  if (NOISE_min < 1.e-14) NOISE_min = 1.e-14;

  TIME_max = atof(argv[3]);
  if (TIME_max <= 0.) TIME_max = 1.e+10;

/* initialize a point with xi[] = 1, then add the points saved earlier */
  for (IT = 0; IT <= N_ITER; IT++) sp[IT] = 0;
  CREATE_SEARCH_POINT(tmpsp);
  for (i = 0; i < H.bits; i++) tmpsp->xi[i] = 1.;
  process_tmp_search_point();

  instanton_load_points(noisefile, NOISE_min);

  /* the main cycle perturbs sp[IT] for every slot, so all must be filled */
  for (IT = 0; IT <= N_ITER; IT++)
    if (sp[IT] == 0)
     {
      fprintf(stderr, "instanton: no search point for %d iterations\n", IT);
      return 1;
     }

  if (instanton_log_time(timefile, 0., 1) != 0)
   {
    fprintf(stderr, "instanton: cannot append to %s\n", timefile);
    return 1;
   }

/* main cycle: perturb every slot's point once, then save progress */
  for (CPU_time = 0.;;)
   {
    start = clock();
    for (IT = 0; IT <= N_ITER; IT++)
     {
      CREATE_SEARCH_POINT(tmpsp);
      /* copy sp[IT] into the candidate, measuring its height and counting
         (in j) the bits that may be perturbed */
      for (tmpsp->height = 0., j = 0, i = 0; i < H.bits; i++)
       {
        tmpsp->xi[i] = sp[IT]->xi[i];
        tmpsp->height += (tmpsp->xi[i] * tmpsp->xi[i]);
        if ((tmpsp->mask[i] = sp[IT]->mask[i]) == 1) j++;
       }
      /* the candidate uses twice the parent's amplitude (capped at 0.1);
         the parent's own amplitude decays towards NOISE_min */
      if ((tmpsp->noise = 2. * sp[IT]->noise) > 0.1) tmpsp->noise = 0.1;
      sp[IT]->noise *= 0.999;
      if (sp[IT]->noise < NOISE_min) sp[IT]->noise = NOISE_min;

      /* shrink the unmasked components before adding j Gaussian kicks so
         that the height is roughly preserved in expectation */
      rtmp = 1. - j * tmpsp->noise * tmpsp->noise / tmpsp->height;
      if (rtmp <= 0.) rtmp = 0.; else rtmp = sqrt(rtmp);
      for (i = 0; i < H.bits; i++)
        if (tmpsp->mask[i] == 1)
         {
          tmpsp->xi[i] *= rtmp;
          tmpsp->xi[i] += tmpsp->noise * gsl_ran_gaussian(RNG, 1.);
         }
      ZEROING_AND_MASKING;

      process_tmp_search_point();
     }
    current = clock();
    CPU_time += (((real)(current - start)) / CLOCKS_PER_SEC);

    /* a failed save is reported but does not abort a long-running search */
    if (instanton_log_time(timefile, CPU_time, 0) != 0)
      fprintf(stderr, "instanton: cannot append to %s\n", timefile);
    if (instanton_write_noise(noisefile) != 0)
      fprintf(stderr, "instanton: cannot write %s\n", noisefile);
    if (instanton_write_height(heightfile) != 0)
      fprintf(stderr, "instanton: cannot write %s\n", heightfile);

    if (CPU_time >= TIME_max) break;
   }

  instanton_free_points();
  gsl_rng_free(RNG);
  kill_RNG(&C);
  kill_iterative_decoder(&H);
  kill_H_matrix(&H);

  return 0;
 }

