/*******************************************************************************
*                                                                              *
*  File:        RSETest.java                            Revision:  1.0         *
*                                                                              *
*  Contents:    a simple tester for RSE en-/decoding based on lookup tables    *                           
*                                                                              *
*  Creation:    14.02.1998                     Last Modification:  14.02.1998  *
*                                                                              *
*  Platform:    Pentium PCI Tower running Windows 95                           *
*                                                                              *
*  Environment: Java 1.1.5                                                     *
*                                                                              *
*  Author:      Andreas Rozek                                                  *
*               Stuttgart University Computer Center                           *
*               Communication Systems & BelWü Development                      *
*               Allmandring 3a                                                 *
*             D-70550 Stuttgart                                                *
*               Germany                                                        *
*                                                                              *
*    Phone:     ++49 (711) 685-4514                                            *
*    Fax:       ++49 (711) 678-8363                                            *
*    EMail:     Andreas.Rozek@RUS.Uni-Stuttgart.De                             *
*                                                                              *
*  Comments:    the program should be invoked as follows:                      *
*                                                                              *
*                 java RSETest <symbolsize> <blocksize> <paritycount>          *
*                                                                              *
*               with the following meanings for the command line arguments:    *
*                                                                              *
*                 <symbolsize>   size of code symbols in bits                  *
*                 <blocksize>    number of symbols per transmission block      *
*                                (must be <= (2**symbolsize)-1)                *
*                                Note: this program is able to use "shortened" *
*                                RSE codes with block sizes < (2^symbolsize)-1 *
*                 <paritycount>  number of parity symbols per block (that many *
*                                symbol losses may be repaired by the          *
*                                resulting RSE code)                           *
*                                                                              *
*******************************************************************************/

import java.io.*;

public class RSETest {

/**** variables for encoder and decoder ****/

  static int SymbolSize;                            // number of bits per symbol
  static int BlockSize;                           // number of symbols per block
  static int DataCount, ParityCount;            // number of data/parity symbols

  static int PackingFactor;                 // number of symbols per "int" value

  static int SymbolMask;                    // to mask symbols out of an integer

  static int EncoderTable[], DecoderTable[];     // RSE encoding/decoding tables

/*******************************************************************************
*                                                                              *
* CaseCount                  calculates the number of remaining loss/use cases *
*                                                                              *
*******************************************************************************/

  // Returns the number of ways to pick "SymbolCount" distinct positions out of
  // StartIndex ... Size-1, i.e. the binomial coefficient
  // C(Size-StartIndex, SymbolCount). A SymbolCount <= 0 yields 1 (there is
  // exactly one way to pick nothing) instead of recursing without end.

  static int CaseCount (int StartIndex, int SymbolCount, int Size) {
    if (SymbolCount <= 0) return 1;
    if (SymbolCount == 1) return Size-StartIndex;  // "StartIndex" begins with 0

    int Result = 0;
    for (int Index = StartIndex; Index < Size-(SymbolCount-1); Index++) {
      Result += CaseCount(Index+1, SymbolCount-1, Size);  // first pick = Index
    };
    return Result;
  };

/*******************************************************************************
*                                                                              *
* binomial                                      calculates C(n,k) for n,k >= 0 *
*                                                                              *
*******************************************************************************/

  private static int binomial (int n, int k) {
    int Result = 1;
    for (int i = 0; i < k; i++) {
      Result = (Result*(n-i))/(i+1);      // exact: C(n,i+1) = C(n,i)*(n-i)/(i+1)
    };
    return Result;
  };

/*******************************************************************************
*                                                                              *
* lossCaseBase          index of the first loss case with "LossCount" losses   *
*                                                                              *
*******************************************************************************/

  // Loss cases are ordered by the number of lost symbols first; this returns
  // how many cases have fewer than "LossCount" losses, i.e. the sum of
  // C(DataSymbols,n)*C(ParitySymbols,n) for n = 1 ... LossCount-1. Called with
  // LossCount = min(DataSymbols,ParitySymbols)+1 it yields the total number
  // of loss cases.

  private static int lossCaseBase (int DataSymbols, int ParitySymbols, int LossCount) {
    int Base = 0;
    for (int n = 1; n < LossCount; n++) {
      Base += binomial(DataSymbols, n)*binomial(ParitySymbols, n);
    };
    return Base;
  };

/*******************************************************************************
*                                                                              *
* rankCombination / unrankCombination   map sorted position lists <-> indices  *
*                                                                              *
*******************************************************************************/

  // Returns the (0-based) index of the ascending position list List[0..Count-1]
  // (positions taken from 0 ... Size-1) in lexicographic order. This is the
  // exact inverse of "unrankCombination".

  private static int rankCombination (int List[], int Count, int Size) {
    int Rank     = 0;
    int Position = 0;                       // start at the beginning of a block
    for (int i = 0; i < Count-1; i++) {       // all positions but the last one
      for (int j = Position; j < List[i]; j++) {
        Rank += CaseCount(j+1, Count-i-1, Size);  // lists starting with j here
      };
      Position = List[i]+1;                           // prepare for next symbol
    };
    return Rank + (List[Count-1]-Position);          // the last case is trivial
  };

  // Writes the ascending position list with index "Rank" (see above) into
  // List[0..Count-1].

  private static void unrankCombination (int Rank, int Count, int Size, int List[]) {
    int Position = 0;                       // start at the beginning of a block
    for (int i = 0; i < Count-1; i++) {       // all positions but the last one
      int Base = 0;        // step through "Rank" and test every possibility
      for (int k = Position; k < Size-(Count-1)+i; k++) {
        int Factor = CaseCount(k+1, Count-i-1, Size);
        if (Rank < Base+Factor) {
          List[i]  = k;                                 // remember position
          Position = k+1;                             // prepare for next symbol
          Rank    -= Base;
          break;
        };
        Base += Factor;
      };
    };
    List[Count-1] = Position+Rank;                   // the last case is trivial
  };

/*******************************************************************************
*                                                                              *
* decodeRSEPackets  reconstructs lost packets from data and parity information * 
*                                                                              *
*******************************************************************************/

  // Candidate  : one row per block symbol (data rows first, then parity rows);
  //              rows of lost data symbols must be zero on entry, because the
  //              reconstructed symbols are OR-ed into them
  // LossList   : ascending positions of the "LossCount" lost data symbols
  // ParityList : ascending positions of the "LossCount" parity symbols to use
  // UseCases   : number of symbols per packet (determines the packet length)
  //
  // "DecoderTable" is addressed by (LossCase << DataCount*SymbolSize) OR-ed
  // with the surviving data symbols followed by the used parity symbols.
  // LossCase follows the enumeration used in "main": grouped by LossCount,
  // then data combination index * C(ParityCount,LossCount) + parity index.

  static void decodeRSEPackets (int Candidate[][], int LossList[], int ParityList[], int LossCount, int UseCases) {
    int PacketSize = (UseCases+(PackingFactor-1))/PackingFactor;      // # "int"s

  /**** calculate loss case number ****/

    int LossCase = lossCaseBase(DataCount, ParityCount, LossCount)
                 + rankCombination(LossList, LossCount, DataCount)*binomial(ParityCount, LossCount)
                 + rankCombination(ParityList, LossCount, ParityCount);

    int CaseAddress = LossCase << (DataCount*SymbolSize);  // table section

  /**** now reconstruct all symbols of the given packets ****/

    for (int i = 0; i < PacketSize; i++) {    // handle all "int"s of any packet
      for (int j = 0; j < PackingFactor; j++) {      // all symbols of an "int"
        int Offset = j*SymbolSize;    // used to access a symbol within an "int"

      /**** determine address for table lookup ****/

        int DecoderAddress = 0;
        for (int k = 0, l = 0; k < DataCount; k++) {   // surviving data symbols
          if (LossList[l] == k) {                // has actual symbol been lost?
            if (l < LossCount-1) l++;
          } else {
            DecoderAddress = (DecoderAddress << SymbolSize) | 
                             ((Candidate[k][i] >> Offset) & SymbolMask);
          };
        };

        for (int k = 0; k < LossCount; k++) {             // used parity symbols
          DecoderAddress = (DecoderAddress << SymbolSize) | 
                           ((Candidate[DataCount+ParityList[k]][i] >> Offset) & SymbolMask);
        };

        DecoderAddress |= CaseAddress;

      /**** get (compound) corrector information ****/

        int Corrector = DecoderTable[DecoderAddress];

      /**** scatter corrector word into lost packets ****/

        for (int k = 0; k < LossCount; k++) {
          Candidate[LossList[k]][i] |= ((Corrector & SymbolMask) << Offset);    
          Corrector >>= SymbolSize;  // shift compound corrector for next packet
        };
      };
    };
  };

/*******************************************************************************
*                                                                              *
* encodeRSEPackets         calculates parity information for the given packets * 
*                                                                              *
*******************************************************************************/

  // Reference : one row per block symbol; rows 0 ... DataCount-1 hold the data,
  //             rows DataCount ... BlockSize-1 receive the parity symbols,
  //             which are OR-ed in (so these rows should be zero on entry)
  // UseCases  : number of symbols per packet (determines the packet length)
  //
  // "EncoderTable" is addressed by the data symbols (row 0 most significant)
  // and yields all parity symbols packed into one word (parity 0 lowest).

  static void encodeRSEPackets (int Reference[][], int UseCases) {
    int PacketSize = (UseCases+(PackingFactor-1))/PackingFactor;      // # "int"s

    for (int i = 0; i < PacketSize; i++) {    // handle all "int"s of any packet
      for (int j = 0; j < PackingFactor; j++) {      // all symbols of an "int"
        int Offset = j*SymbolSize;    // used to access a symbol within an "int"

      /**** determine address for table lookup ****/

        int EncoderAddress = 0;
        for (int k = 0; k < DataCount; k++) {
          EncoderAddress = (EncoderAddress << SymbolSize) | 
                           ((Reference[k][i] >> Offset) & SymbolMask);
        };

      /**** get (compound) parity information ****/

        int Parity = EncoderTable[EncoderAddress];

      /**** scatter parity word into separate packets ****/

        for (int k = 0; k < ParityCount; k++) {
          Reference[DataCount+k][i] |= ((Parity & SymbolMask) << Offset);    
          Parity >>= SymbolSize;        // shift compound parity for next packet
        };
      };
    };
  };

/*******************************************************************************
*                                                                              *
* readTable          reads a table whose entries hold "EntryBits" bits each    *
*                                                                              *
*******************************************************************************/

  // Entries are stored as unsigned bytes (<= 8 bits), unsigned shorts
  // (<= 16 bits) or ints (anything wider).

  private static int[] readTable (DataInputStream InStream, int Length, int EntryBits)
    throws IOException
  {
    int Table[] = new int[Length];

    if        (EntryBits <=  8) {
      for (int i = 0; i < Length; i++) Table[i] = InStream.readUnsignedByte();
    } else if (EntryBits <= 16) {
      for (int i = 0; i < Length; i++) Table[i] = InStream.readUnsignedShort();
    } else {
      for (int i = 0; i < Length; i++) Table[i] = InStream.readInt();
    };

    return Table;
  };

/*******************************************************************************
*                                                                              *
* loadRSETable                              loads RSE encoding/decoding tables * 
*                                                                              *
*******************************************************************************/

  // Stores the code parameters in the static fields, derives DataCount,
  // SymbolMask and PackingFactor, and reads "EncoderTable"/"DecoderTable" from
  // "RSE-Table_<symbolsize>_<blocksize>_<paritycount>.bin". The file starts
  // with four unsigned bytes (SymbolSize, BlockSize, DataCount, ParityCount)
  // followed by both tables. On any failure the program exits:
  //   10 file missing/unreadable, 11 file not found while opening,
  //   12 header mismatch or premature EOF, 13 other I/O error,
  //   14 tables too large to be indexed by an "int"

  static void loadRSETable (int SymbolSize, int BlockSize, int ParityCount) {

  /**** preserve given arguments (no range checking) ****/

    RSETest.SymbolSize  = SymbolSize;
    RSETest.BlockSize   = BlockSize;
    RSETest.ParityCount = ParityCount;

  /**** calculate missing values ****/

    DataCount = BlockSize-ParityCount;

    SymbolMask  = (1 << SymbolSize)-1;      // to mask symbols out of an integer

    PackingFactor = 32 / SymbolSize;                      // # symbols per "int"

  /**** make sure both tables may be indexed by an "int" ****/

    int MaxLossCount  = (DataCount < ParityCount ? DataCount : ParityCount);
    int LossCaseCount = lossCaseBase(DataCount, ParityCount, MaxLossCount+1);

    // Java masks int shift distances to 5 bits, so a larger exponent would
    // silently produce a wrong (or negative) table length
    if ((DataCount*SymbolSize > 30) ||
        (((long) LossCaseCount << (DataCount*SymbolSize)) > Integer.MAX_VALUE)) {
      System.out.println("  RSE tables for SymbolSize = " + SymbolSize +
        ", BlockSize = " + BlockSize + ", ParityCount = " + ParityCount + " are too large");
      System.exit(14);
    };

  /**** construct a file name for the desired table and attach that file ****/

    String FileName = "RSE-Table_" + SymbolSize + "_" + BlockSize + "_" + ParityCount + ".bin";

    File InFile = new File(FileName);
    if (!InFile.exists() || !InFile.isFile() || !InFile.canRead()) {
      System.out.println("  unable to load file \"" + FileName + "\"");
      System.exit(10);
    };

    try {
      DataInputStream InStream = new DataInputStream(
        new BufferedInputStream(new FileInputStream(InFile))
      );

    /**** check file "header" ****/

      if ((InStream.readUnsignedByte() != SymbolSize) ||
          (InStream.readUnsignedByte() != BlockSize)  ||
          (InStream.readUnsignedByte() != DataCount)  ||
          (InStream.readUnsignedByte() != ParityCount)) {
        System.out.println("  file \"" + FileName + "\" does not contain a RSE codec table");
        System.out.println("  for SymbolSize = " + SymbolSize +
          ", BlockSize = " + BlockSize + ", ParityCount = " + ParityCount);
        System.exit(12);
      };

    /**** load RSE encoding and decoding tables ****/

      int EncoderLength = 1 << (DataCount*SymbolSize);
      int ParityBits    = ParityCount*SymbolSize;  // widest (compound) entry

      EncoderTable = readTable(InStream, EncoderLength,               ParityBits);
      DecoderTable = readTable(InStream, EncoderLength*LossCaseCount, ParityBits);

      InStream.close();
    } catch (FileNotFoundException Signal) {
      System.out.println("  unable to read from file \"" + FileName + "\"");
      System.out.println("  reason: \"" + Signal.getMessage() + "\"");
      System.exit(11);
    } catch (EOFException Signal) {
      System.out.println("  unexpected EOF in file \"" + FileName + "\"");
      System.out.println("  reason: \"" + Signal.getMessage() + "\"");
      System.exit(12);
    } catch (IOException Signal) {
      System.out.println("  error while reading file \"" + FileName + "\"");
      System.out.println("  reason: \"" + Signal.getMessage() + "\"");
      System.exit(13);
    };
  };

/*******************************************************************************
*                                                                              *
* main                                                            main program * 
*                                                                              *
*******************************************************************************/

  // Encodes every possible data combination, then enumerates every loss case
  // (a set of lost data symbols together with an equally large set of parity
  // symbols used for repair), decodes, and compares against the reference.

  public static void main (String ArgList[]) {
    int SymbolSize;                                 // number of bits per symbol
    int BlockSize;                                // number of symbols per block
    int DataCount, ParityCount;                 // number of data/parity symbols

    int UseCases, LossCases;   // # data combinations/failure cases to be tested

    int PackingFactor;                      // number of symbols per "int" value
    int PacketSize;                         // number of "int" values per packet

    int Reference[][], Candidate[][];              // test and reference packets

    int LossList[], ParityList[];               // list of loss/parity positions
    int LossCount;                                     // number of symbols lost
    int LossCase;           // index of a loss case within its "LossCount" group
    int ParityCases;           // # ways to choose "LossCount" parity symbols

    int i,j,k,l;                                               // loop variables


    System.out.println();
    System.out.println("  RSETest - a simple tester for RSE en-/decoding based on lookup tables");
    System.out.println();

  /**** fetch command line parameters ****/

    if (ArgList.length != 3) {
      System.out.println("  usage: java RSETest <symbolsize> <blocksize> <paritycount>");
      System.out.println();
      System.out.println("  <symbolsize>   size of code symbols in bits");
      System.out.println("  <blocksize>    number of symbols per transmission block");
      System.out.println("                 (must be <= (2**symbolsize)-1)");
      System.out.println("                 Note: this program is able to use \"shortened\"");
      System.out.println("                 RSE codes with block sizes < (2^symbolsize)-1");
      System.out.println("  <paritycount>  number of parity symbols per block (that many");
      System.out.println("                 symbol losses may be repared by the resulting");
      System.out.println("                 RSE code)");
      System.out.println();

      System.exit(0);
    };

    SymbolSize = BlockSize = ParityCount = 0;            // satisfy the compiler

    try {
      SymbolSize = Integer.parseInt(ArgList[0]);
    } catch (NumberFormatException Signal) {
      System.out.println("  invalid SymbolSize \"" + ArgList[0] + "\"");
      System.out.println("  please specify a numerical value in the range 2...16");
      System.exit(1);
    };

    try {
      BlockSize = Integer.parseInt(ArgList[1]);
    } catch (NumberFormatException Signal) {
      System.out.println("  invalid BlockSize \"" + ArgList[1] + "\"");
      System.out.println("  please specify a numerical value in the range 2...2**SymbolSize-1");
      System.exit(2);
    };

    try {
      ParityCount = Integer.parseInt(ArgList[2]);
    } catch (NumberFormatException Signal) {
      System.out.println("  invalid ParityCount \"" + ArgList[2] + "\"");
      System.out.println("  please specify a numerical value in the range 1...BlockSize/2");
      System.exit(3);
    };

  /**** now check the given numbers (regardless of any later restrictions) ****/

    if ((SymbolSize < 2) || (SymbolSize > 16)) {
      System.out.println("  illegal SymbolSize \"" + SymbolSize + "\"");
      System.out.println("  please specify a numerical value in the range 2...16");
      System.exit(4);
    };

    if ((BlockSize < 2) || (BlockSize > (1 << SymbolSize)-1)) {
      System.out.println("  illegal BlockSize \"" + BlockSize + "\"");
      System.out.println("  please specify a numerical value in the range 2...2**SymbolSize-1");
      System.exit(5);
    };

    if ((ParityCount < 1) || (ParityCount > BlockSize/2)) {
      System.out.println("  illegal ParityCount \"" + ParityCount + "\"");
      System.out.println("  please specify a numerical value in the range 1...BlockSize/2");
      System.exit(6);
    };

  /**** load RSE encoder/decoder tables ****/

    System.out.println("  loading RSE tables...");
      loadRSETable(SymbolSize, BlockSize, ParityCount);
    System.out.println("  done");
    System.out.println();

  /**** determine number of loss situations to be tested ****/

    DataCount = BlockSize-ParityCount;

    UseCases  = 1 << (SymbolSize*DataCount);       // possible data combinations
    LossCases = lossCaseBase(DataCount, ParityCount,
                  (DataCount < ParityCount ? DataCount : ParityCount)+1);

  /**** create (zero-initialized) arrays for test and reference symbols ****/

    PackingFactor = 32 / SymbolSize;                      // # symbols per "int"
    PacketSize    = (UseCases+(PackingFactor-1))/PackingFactor;      // # "int"s

    Reference = new int[BlockSize][PacketSize];
    Candidate = new int[BlockSize][PacketSize];

    LossList   = new int[ParityCount];
    ParityList = new int[ParityCount];

  /**** load "Reference" with test data ****/

    for (i = 0; i < UseCases; i++) {
      j = i/PackingFactor;          // which "int" word of a packet is affected?

      for (k = 0; k < DataCount; k++) {  // scatter "i" into consecutive packets
        Reference[DataCount-k-1][j] = (Reference[DataCount-k-1][j] << SymbolSize) |
                                      ((i >> k*SymbolSize) & SymbolMask);
      };
    };

  /**** encode reference data ****/

    encodeRSEPackets(Reference, UseCases);

  /**** load "Candidate" with "Reference" information ****/

    for (i = 0; i < BlockSize; i++) {
      System.arraycopy(Reference[i],0, Candidate[i],0, PacketSize);
    };

  /**** test all possible loss situations ****/

    System.out.println("  testing " + LossCases + " loss case(s)...");
      for (i = 0; i < LossCases; i++) {

      /**** calculate number of lost symbols ****/

        LossCount = 1;
        while (i >= lossCaseBase(DataCount, ParityCount, LossCount+1)) LossCount++;

        LossCase    = i-lossCaseBase(DataCount, ParityCount, LossCount); // "normalize"
        ParityCases = binomial(ParityCount, LossCount);

      /**** calculate positions of lost data and used parity symbols ****/

        unrankCombination(LossCase / ParityCases, LossCount, DataCount,   LossList);
        unrankCombination(LossCase % ParityCases, LossCount, ParityCount, ParityList);

      /**** destroy "lost" data packets and unused parity packets ****/

        for (j = 0; j < LossCount; j++) {
          for (k = 0; k < PacketSize; k++) Candidate[LossList[j]][k] = 0;
        };

        for (j = 0, k = 0; j < ParityCount; j++) {
          if (ParityList[k] == j) {       // is actual parity symbol to be used?
            System.arraycopy(Reference[DataCount+j],0, Candidate[DataCount+j],0, PacketSize);
            if (k < LossCount-1) k++;
          } else {
            for (l = 0; l < PacketSize; l++) Candidate[DataCount+j][l] = 0;
          };
        };

      /**** decode damaged "Candidate" ****/

        decodeRSEPackets(Candidate, LossList, ParityList, LossCount, UseCases);

      /**** compare decoded "Candidate" with original "Reference" ****/

        for (j = 0; j < LossCount; j++) {
          for (k = 0; k < PacketSize; k++) {
            if (Candidate[LossList[j]][k] != Reference[LossList[j]][k]) {
              System.out.println("    loss recovery error (LossCase = " + i + ")" +
                " at k = " + k);
              System.out.print  ("      reference data:      ");
                for (l = 0; l < BlockSize-1; l++) {
                  System.out.print(Integer.toHexString(Reference[l][k]) + ",");
                };
                System.out.print(Integer.toHexString(Reference[BlockSize-1][k]) + "\n");
              System.out.print  ("      reconstruction:      ");
                for (l = 0; l < BlockSize-1; l++) {
                  System.out.print(Integer.toHexString(Candidate[l][k]) + ",");
                };
                System.out.print(Integer.toHexString(Candidate[BlockSize-1][k]) + "\n");
              System.out.print  ("      lost data symbols:   ");
                for (l = 0; l < LossCount-1; l++) {
                  System.out.print(LossList[l] + ",");
                };
                System.out.print(LossList[LossCount-1] + "\n");
              System.out.print  ("      used parity symbols: ");
                for (l = 0; l < LossCount-1; l++) {
                  System.out.print(ParityList[l] + ",");
                };
                System.out.print(ParityList[LossCount-1] + "\n");
              System.exit(7);
            };
          };
        };
      };
    System.out.println("  done");
    System.out.println();

    System.out.println("  test successfully passed");

    System.exit(0);
  };
};
