/* -*- Mode: C; c-basic-offset:4 ; -*- */
/*
 * Copyright (c) 2016      The University of Tennessee and The University
 *                         of Tennessee Research Foundation.  All rights
 *                         reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 * $HEADER$
 */

#include <stdio.h>
#include <stdlib.h>

#include "ompi_config.h"
#include "ompi/datatype/ompi_datatype.h"
#include "opal/runtime/opal.h"
#include "opal/datatype/opal_convertor.h"
#include "opal/datatype/opal_datatype_internal.h"
#include <arpa/inet.h>

/* When non-zero, main() additionally prints the raw send/receive values and
 * hex dumps of the buffers around each pack/unpack round trip. Nothing in the
 * code visible here ever changes it from its initial value. */
static int verbose = 0;
 
/* Validator callback invoked by pack_unpack_datatype() right after packing.
 * Arguments, in order: the original send buffer, the packed buffer, the
 * datatype and count that were packed, and an opaque caller-supplied argument.
 * Must return 0 when the packed data is acceptable; any other value makes
 * pack_unpack_datatype() bail out before unpacking. */
typedef int (*checker_t)(void*, void*, ompi_datatype_t *, int, void*);

/* Validator for contiguous 16/32-bit integer data (element type taken from
 * the datatype argument). */
int check_contiguous( void* send_buffer, void* packed,
                      ompi_datatype_t * datatype, int count, void* arg );

/* Validator for data read from the send buffer with a stride of two elements
 * (element type taken from arg, which is interpreted as an ompi_datatype_t*). */
int check_vector( void* send_buffer, void* packed,
                  ompi_datatype_t * datatype, int count, void* arg );

/* Packs send_data to external32, runs the validator on the packed bytes, then
 * unpacks into recv_data. Returns 0 on success and -1 on any failure. */
static int pack_unpack_datatype( void* send_data, ompi_datatype_t *datatype, int count,
                                 void* recv_data, checker_t validator, void *validator_arg );

/**
 * Print a hexadecimal dump of a byte buffer on stdout.
 *
 * Each byte is written as a space followed by two hex digits. Every output
 * line starts with a newline followed by msg (when msg is not NULL); inside
 * a line, two extra spaces are inserted in front of every 4th byte. Output
 * begins on the line that contains start_from, and the positions of that
 * line that precede start_from are filled with blanks so that the columns
 * stay aligned. No trailing newline is written.
 *
 * @param msg           optional prefix printed at the start of every line
 * @param vbuf          buffer to dump (nothing is printed when NULL)
 * @param nbytes        number of valid bytes in vbuf
 * @param start_from    offset of the first byte to print; negative values
 *                      are treated as 0
 * @param stop_at       offset one past the last byte to print, or -1 to
 *                      dump up to nbytes
 * @param vals_per_line number of bytes per output line; a non-positive
 *                      value puts the whole dump on a single line
 */
static void
dump_hex(const char* msg, const void* vbuf, int nbytes,
         int start_from, int stop_at, int vals_per_line)
{
    const unsigned char* bytes = (const unsigned char*)vbuf;

    if( (NULL == bytes) || (nbytes <= 0) ) return;

    /* Guard the division/modulo below: without it a zero line width is
     * undefined behavior (integer division by zero). */
    if( vals_per_line <= 0 ) vals_per_line = nbytes;
    /* A negative start would make the loop index the buffer below 0. */
    if( start_from < 0 ) start_from = 0;
    /* -1 means "until the end"; never read past the valid bytes. */
    if( (-1 == stop_at) || (stop_at > nbytes) ) stop_at = nbytes;

    /* Start at the beginning of the line holding start_from so that the
     * first printed line is padded and aligned like the following ones. */
    for( int i = (start_from / vals_per_line) * vals_per_line; i < stop_at; ++i ) {
        if( 0 == (i % vals_per_line) ) {
            if( NULL == msg ) printf("\n");
            else printf("\n%s", msg);
        } else if( 0 == (i % 4) ) {
            printf("  ");
        }
        if( i < start_from )
            printf("   ");  /* blank placeholder: separator + two digits */
        else
            printf(" %02x", bytes[i]);
    }
}

/**
 * Validate the packed form of a contiguous array of 16- or 32-bit integers.
 *
 * Every element of send_buffer is converted to network (big-endian) byte
 * order with htonl()/htons() and compared with the element at the same index
 * in the packed buffer. The first mismatch is reported on stdout.
 *
 * @param send_buffer host-order source data that was packed
 * @param packed      buffer holding the packed data
 * @param datatype    element type; only the int / int32_t (compared as 32-bit
 *                    values) and short / int16_t (compared as 16-bit values)
 *                    predefined types are supported
 * @param count       number of elements to compare
 * @param arg         unused
 * @return 0 when all elements match, -1 on a mismatch or unsupported type
 */
int check_contiguous( void* send_buffer, void* packed,
                      ompi_datatype_t* datatype, int count, void* arg )
{
    (void)arg;  /* the element type comes from datatype */

    if( (datatype == &ompi_mpi_int.dt) || (datatype == &ompi_mpi_int32_t.dt) ) {
        const uint32_t* host = (const uint32_t*)send_buffer;
        const uint32_t* wire = (const uint32_t*)packed;

        for( int i = 0; i < count; i++ ) {
            const uint32_t expected = htonl(host[i]);
            if( expected != wire[i] ) {
                /* Report the value the packed buffer should contain and the
                 * one it really contains, both as stored in memory. */
                printf("Error at position %d expected %x found %x (type %s)\n",
                       i, (unsigned)expected, (unsigned)wire[i], datatype->name);
                return -1;
            }
        }
    } else if( (datatype == &ompi_mpi_short.dt) || (datatype == &ompi_mpi_int16_t.dt) ) {
        const uint16_t* host = (const uint16_t*)send_buffer;
        const uint16_t* wire = (const uint16_t*)packed;

        for( int i = 0; i < count; i++ ) {
            const uint16_t expected = htons(host[i]);
            if( expected != wire[i] ) {
                printf("Error at position %d expected %x found %x (type %s)\n",
                       i, (unsigned)expected, (unsigned)wire[i], datatype->name);
                return -1;
            }
        }
    } else {
        printf("Unknown type\n");
        return -1;
    }
    return 0;
}

/**
 * Validate the packed form of strided 16- or 32-bit integer data.
 *
 * Packed element i is compared with send_buffer element 2*i after converting
 * the latter to network (big-endian) byte order with htonl()/htons(); i.e.
 * the function assumes a block length of 1 and a stride of 2 elements. The
 * first mismatch is reported on stdout.
 *
 * NOTE(review): count is the number of packed elements compared. The caller
 * visible in this file forwards the count of vector datatypes it packed (1),
 * so only the first packed element ends up being checked.
 *
 * @param send_buffer host-order source data that was packed
 * @param packed      buffer holding the packed data
 * @param datatype    datatype that was packed; only its name is used, in
 *                    diagnostics
 * @param count       number of packed elements to compare
 * @param arg         ompi_datatype_t* naming the element type; only the
 *                    int / int32_t (compared as 32-bit values) and
 *                    short / int16_t (compared as 16-bit values) predefined
 *                    types are supported
 * @return 0 when all compared elements match, -1 on a mismatch or
 *         unsupported element type
 */
int check_vector( void* send_buffer, void* packed,
                  ompi_datatype_t* datatype, int count, void* arg )
{
    const ompi_datatype_t *origtype = (const ompi_datatype_t *)arg;

    if( (origtype == &ompi_mpi_int.dt) || (origtype == &ompi_mpi_int32_t.dt) ) {
        const uint32_t* host = (const uint32_t*)send_buffer;
        const uint32_t* wire = (const uint32_t*)packed;

        for( int i = 0; i < count; i++ ) {
            const uint32_t expected = htonl(host[2*i]);
            if( expected != wire[i] ) {
                /* Report the value the packed buffer should contain and the
                 * one it really contains, both as stored in memory. */
                printf("Error at position %d expected %x found %x (type %s)\n",
                       i, (unsigned)expected, (unsigned)wire[i], datatype->name);
                return -1;
            }
        }
    } else if( (origtype == &ompi_mpi_short.dt) || (origtype == &ompi_mpi_int16_t.dt) ) {
        const uint16_t* host = (const uint16_t*)send_buffer;
        const uint16_t* wire = (const uint16_t*)packed;

        for( int i = 0; i < count; i++ ) {
            const uint16_t expected = htons(host[2*i]);
            if( expected != wire[i] ) {
                printf("Error at position %d expected %x found %x (type %s)\n",
                       i, (unsigned)expected, (unsigned)wire[i], datatype->name);
                return -1;
            }
        }
    } else {
        printf("Unknown %s type\n", datatype->name);
        return -1;
    }
    return 0;
}

/**
 * Round-trip a buffer through the "external32" representation.
 *
 * The function queries the packed size, allocates a temporary buffer, packs
 * send_data into it, lets the validator inspect the packed bytes, prints the
 * packed bytes as a hex dump, and finally unpacks them into recv_data.
 *
 * @param send_data     data to pack
 * @param datatype      datatype describing send_data and recv_data
 * @param count         number of datatype items to pack and unpack
 * @param recv_data     destination of the unpack step
 * @param validator     callback checking the packed bytes; a non-zero
 *                      return aborts the round trip before unpacking
 * @param validator_arg opaque argument forwarded to the validator
 * @return 0 on success, -1 if any step (size query, allocation, pack,
 *         validation, unpack) fails
 */
static int pack_unpack_datatype( void* send_data, ompi_datatype_t *datatype, int count,
                                 void* recv_data,
                                 checker_t validator, void* validator_arg)
{
    MPI_Aint position = 0, buffer_size = 0;
    void* buffer = NULL;   /* stays NULL until allocated so cleanup can always free it */
    int rc = -1;           /* pessimistic default; only set to 0 after the unpack succeeds */

    if( MPI_SUCCESS != ompi_datatype_pack_external_size("external32",
                                                        count, datatype, &buffer_size) )
        goto cleanup;

    buffer = malloc((size_t)buffer_size);
    if( NULL == buffer ) goto cleanup;

    if( MPI_SUCCESS != ompi_datatype_pack_external("external32", send_data, count, datatype,
                                                   buffer, buffer_size, &position) )
        goto cleanup;

    if( 0 != validator(send_data, buffer, datatype, count, validator_arg) ) {
        printf("Error during pack external. Bailing out\n");
        goto cleanup;
    }

    /* MPI_Aint is not guaranteed to be long: convert explicitly for %ld. */
    printf("packed %ld bytes into a %ld bytes buffer ", (long)position, (long)buffer_size);
    dump_hex(NULL, buffer, (int)position, 0, -1, 24); printf("\n");

    /* Rewind: unpacking starts reading at the beginning of the packed data. */
    position = 0;
    if( MPI_SUCCESS != ompi_datatype_unpack_external("external32", buffer, buffer_size, &position,
                                                     recv_data, count, datatype) )
        goto cleanup;

    rc = 0;

 cleanup:
    /* Single exit point so the packing buffer is released on every path. */
    free(buffer);
    return rc;
}

/**
 * Test driver: round-trips three data sets through external32 packing
 * (two contiguous int32_t, two contiguous int16_t, and a 2-element int
 * vector with stride 2) and verifies that the received values match the
 * sent ones.
 *
 * A test fails when pack_unpack_datatype() reports an error or when the
 * unpacked values differ from the originals; in that case a message is
 * printed and the process exits with status -1.
 *
 * @return 0 when all three round trips succeed
 */
int main(int argc, char *argv[])
{
    int rc;

    opal_init_util(&argc, &argv);
    ompi_datatype_init();

    /* Simple contiguous data: MPI_INT32_T */
    {
        int32_t send_data[2] = {1234, 5678};
        int32_t recv_data[2] = {-1, -1};

        if( verbose ) {
            printf("send data %08x %08x \n", send_data[0], send_data[1]);
            printf("data ");
            dump_hex(NULL, &send_data, sizeof(int32_t) * 2, 0, -1, 24); printf("\n");
        }
        rc = pack_unpack_datatype( send_data, &ompi_mpi_int32_t.dt, 2,
                                   recv_data, check_contiguous, (void*)&ompi_mpi_int32_t.dt );
        if( verbose ) {
            printf("recv ");
            dump_hex(NULL, &recv_data, sizeof(int32_t) * 2, 0, -1, 24); printf("\n");
            printf("recv data %08x %08x \n", recv_data[0], recv_data[1]);
        }
        if( (0 != rc) || (send_data[0] != recv_data[0]) || (send_data[1] != recv_data[1]) ) {
            printf("Error during external32 pack/unack for contiguous types (MPI_INT32_T)\n");
            exit(-1);
        }
    }
    /* Simple contiguous data: MPI_INT16_T */
    {
        int16_t send_data[2] = {1234, 5678};
        int16_t recv_data[2] = {-1, -1};

        if( verbose ) {
            printf("send data %08x %08x \n", send_data[0], send_data[1]);
            printf("data ");
            dump_hex(NULL, &send_data, sizeof(int16_t) * 2, 0, -1, 24); printf("\n");
        }
        rc = pack_unpack_datatype( send_data, &ompi_mpi_int16_t.dt, 2,
                                   recv_data, check_contiguous, (void*)&ompi_mpi_int16_t.dt );
        if( verbose ) {
            printf("recv ");
            dump_hex(NULL, &recv_data, sizeof(int16_t) * 2, 0, -1, 24); printf("\n");
            printf("recv data %08x %08x \n", recv_data[0], recv_data[1]);
        }
        if( (0 != rc) || (send_data[0] != recv_data[0]) || (send_data[1] != recv_data[1]) ) {
            printf("Error during external32 pack/unack for contiguous types\n");
            exit(-1);
        }
    }

    /* Vector datatype */
    printf("\n\nVector datatype\n\n");
    {
        int count=2, blocklength=1, stride=2;
        /* Only elements 0 and 2 belong to the vector; element 1 is the gap. */
        int send_data[3] = {1234, 0, 5678};
        int recv_data[3] = {-1, -1, -1};
        ompi_datatype_t *ddt;

        ompi_datatype_create_vector ( count, blocklength, stride, &ompi_mpi_int.dt, &ddt );
        {
            /* Record the constructor arguments on the new datatype. */
            const int* a_i[3] = {&count, &blocklength, &stride};
            ompi_datatype_t *type = &ompi_mpi_int.dt;

            ompi_datatype_set_args( ddt, 3, a_i, 0, NULL, 1, &type, MPI_COMBINER_VECTOR );
        }
        ompi_datatype_commit(&ddt);

        if( verbose ) {
            printf("send data %08x %08x %08x \n", send_data[0], send_data[1], send_data[2]);
            printf("data "); dump_hex(NULL, &send_data, sizeof(int32_t) * 3, 0, -1, 24); printf("\n");
        }
        rc = pack_unpack_datatype( send_data, ddt, 1, recv_data, check_vector, (void*)&ompi_mpi_int32_t.dt );
        if( verbose ) {
            printf("recv "); dump_hex(NULL, &recv_data, sizeof(int32_t) * 3, 0, -1, 24); printf("\n");
            printf("recv data %08x %08x %08x \n", recv_data[0], recv_data[1], recv_data[2]);
        }
        ompi_datatype_destroy(&ddt);
        /* The gap element [1] is not part of the datatype and is not compared. */
        if( (0 != rc) || (send_data[0] != recv_data[0]) || (send_data[2] != recv_data[2]) ) {
            printf("Error during external32 pack/unack for vector types (MPI_INT32_T)\n");
            printf("[0]: %d ? %d  |  [2]: %d ? %d  ([1]: %d ? %d)\n", send_data[0], recv_data[0],
                   send_data[2], recv_data[2], send_data[1], recv_data[1]);
            exit(-1);
        }
    }

    ompi_datatype_finalize();

    return 0;
}