package org.molap.exporter.parquet

import org.apache.parquet.column.ColumnReadStore
import org.apache.parquet.column.ColumnReader
import org.apache.parquet.column.impl.ColumnReadStoreImpl
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.example.data.simple.convert.GroupRecordConverter
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.schema.OriginalType
import org.apache.parquet.schema.PrimitiveType
import org.molap.dataframe.AbstractDataFrame
import org.molap.dataframe.DataFrame
import org.molap.dataframe.DefaultDataFrame
import org.molap.dataframe.IndexedDataFrame
import org.molap.index.UniqueIndex
import org.molap.series.Series
import org.molap.series.SeriesFactory
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.UnsupportedEncodingException
import java.math.BigDecimal
import java.math.BigInteger
import java.nio.ByteOrder
import java.sql.SQLException
import java.time.*
import kotlin.reflect.KClass

/**
 * Read-only data frame whose content is decoded from a Parquet file supplied as a stream.
 *
 * The stream is decoded once, while the instance is constructed, and every [AbstractDataFrame]
 * accessor simply delegates to the decoded frame held in [df].
 *
 * Column decoding (physical type / annotation -> element type of the resulting series):
 * - `BOOLEAN` -> `Boolean`, `FLOAT` -> `Float`, `DOUBLE` -> `Double`
 * - `INT32`: none, `UINT_16`, `INT_32` -> `Int`; `INT_8` -> `Byte`; `UINT_8`, `INT_16` -> `Short`;
 *   `DECIMAL` -> [BigDecimal]; `DATE` -> [LocalDate]; `TIME_MILLIS` -> [LocalTime]
 * - `INT64`: none, `INT_64` -> `Long`; `DECIMAL` -> [BigDecimal]; `TIME_MICROS` -> [LocalTime];
 *   `TIMESTAMP_MILLIS`, `TIMESTAMP_MICROS` -> [LocalDateTime] at UTC
 * - `INT96` -> [LocalDateTime] (nanos-of-day followed by a Julian day number)
 * - `BINARY` / `FIXED_LEN_BYTE_ARRAY`: `UTF8` -> `String`; `DECIMAL` -> [BigDecimal];
 *   anything else -> `ByteArray`
 *
 * A value whose definition level is below the column maximum is stored as `null`.
 *
 * NOTE(review): only the FIRST row group of the file is decoded; rows in later row groups are
 * silently ignored. This mirrors the long-standing behaviour of this class - confirm whether
 * multi-row-group files are expected before relying on it.
 *
 * @param bis stream positioned at the start of a Parquet file; it is read to the end but not closed
 */
class ParquetDataFrame constructor(bis: InputStream) : AbstractDataFrame<Int, String, Any?>() {
    /** Frame decoded from the constructor stream, or `null` when the file yielded no row group. */
    val df = loadDataFrameFromInputStream(bis)

    override fun getRowClass(row: Int): KClass<*>? = frame().getRowClass(row)

    override fun getColumnClass(column: String): KClass<*> = frame().getColumnClass(column)

    override fun getValueAt(row: Int, column: String): Any? = frame().getValueAt(row, column)

    override val rowIndex: UniqueIndex<Int>
        get() = frame().rowIndex
    override val columnIndex: UniqueIndex<String>
        get() = frame().columnIndex

    /**
     * Reads [bis] completely into memory and decodes the first row group of the Parquet file it
     * contains.
     *
     * @param bis stream holding a complete Parquet file; must not be `null`
     * @return the decoded frame, or `null` when the file contains no row group
     * @throws IllegalArgumentException if [bis] is `null`, or if a column could not be decoded
     *   into a series
     * @throws NotImplementedError if an INT32/INT64 column carries an annotation that is not
     *   handled (see the class documentation for the handled combinations)
     * @throws IOException if the stream or the Parquet structure cannot be read
     */
    @Throws(IOException::class, SQLException::class)
    fun loadDataFrameFromInputStream(bis: InputStream?): DataFrame<Int, String, Any?>? {
        val input = requireNotNull(bis) { "bis must not be null" }
        val stream = ByteArrayOutputStream()
        stream.write(input.readAllBytes())
        return ParquetFileReader.open(ParquetStream("parquet", stream)).use { reader ->
            // Only the first row group is decoded (see the NOTE(review) on the class).
            reader.readNextRowGroup()?.let { rowGroup -> decodeRowGroup(reader, rowGroup) }
        }
    }

    /** Returns the decoded frame or fails with a descriptive message when none was produced. */
    private fun frame(): DataFrame<Int, String, Any?> =
        checkNotNull(df) { "The Parquet input contained no row group, so no data frame is available" }

    /** Decodes every leaf column of [rowGroup] into a series and assembles the resulting frame. */
    private fun decodeRowGroup(reader: ParquetFileReader, rowGroup: PageReadStore): DataFrame<Int, String, Any?> {
        val fileMetaData = reader.footer.fileMetaData
        val schema = fileMetaData.schema
        val columnStore: ColumnReadStore = ColumnReadStoreImpl(
            rowGroup,
            GroupRecordConverter(schema).rootConverter,
            schema,
            fileMetaData.createdBy
        )

        val columns = schema.columns
        val vectors: Array<Series<Int, Any?>?> = arrayOfNulls<Series<Int, Any?>>(columns.size)
        for (c in columns.indices) {
            val column = columns[c]
            val path = column.path
            val cursor = ColumnCursor(
                // The leaf name (last path element) becomes the series name.
                path[path.size - 1],
                column.primitiveType,
                columnStore.getColumnReader(column),
                column.maxDefinitionLevel
            )
            vectors[c] = decodeColumn(cursor, rowGroup.rowCount)
        }

        // A column that produced no series is reported here as an IllegalArgumentException.
        val series: Array<Series<Int, Any?>> = vectors.requireNoNulls()
        val indexed = IndexedDataFrame<Int, String, Any?>(series)
        return DefaultDataFrame(indexed as DataFrame<Any?, Any?, Any?>)
    }

    /**
     * Dispatches on the physical type of the column behind [cursor].
     *
     * `SeriesFactory.create` is deliberately invoked at every call site with a concretely typed
     * array, so that whichever overload / type argument it resolves to is decided per element type.
     *
     * @param rowCount number of rows in the row group, used only to skip unsupported columns
     * @return the decoded series, or `null` when the column could not be decoded
     */
    private fun decodeColumn(cursor: ColumnCursor, rowCount: Long): Series<Int, Any?>? =
        when (cursor.primitiveType.primitiveTypeName) {
            PrimitiveType.PrimitiveTypeName.BOOLEAN ->
                cursor.readValues<Boolean> { reader -> reader.boolean }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            PrimitiveType.PrimitiveTypeName.INT32 -> decodeInt32(cursor)

            PrimitiveType.PrimitiveTypeName.INT64 -> decodeInt64(cursor)

            PrimitiveType.PrimitiveTypeName.INT96 ->
                cursor.readValues<LocalDateTime> { reader ->
                    val buf = reader.binary.toByteBuffer().order(ByteOrder.LITTLE_ENDIAN)
                    val nanoOfDay = buf.long
                    val julianDay = buf.int
                    // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian
                    // it's 2440587.5, rounded up to stay compatible with Hive
                    val date = LocalDate.ofEpochDay((julianDay - JULIAN_DAY_OF_UNIX_EPOCH).toLong())
                    LocalDateTime.of(date, LocalTime.ofNanoOfDay(nanoOfDay))
                }?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            PrimitiveType.PrimitiveTypeName.FLOAT ->
                cursor.readValues<Float> { reader -> reader.float }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            PrimitiveType.PrimitiveTypeName.DOUBLE ->
                cursor.readValues<Double> { reader -> reader.double }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            PrimitiveType.PrimitiveTypeName.BINARY,
            PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY -> decodeBinary(cursor)

            else -> {
                System.err.println("Unsupported type ${cursor.primitiveType}, ${cursor.primitiveType.originalType}")
                var consumed = 0L
                while (consumed < rowCount) {
                    cursor.reader.consume()
                    consumed++
                }
                null
            }
        }

    /** Decodes an INT32 column according to its (possibly absent) annotation. */
    private fun decodeInt32(cursor: ColumnCursor): Series<Int, Any?>? {
        val type = cursor.primitiveType
        return when (val originalType = type.originalType) {
            null, OriginalType.UINT_16, OriginalType.INT_32 ->
                cursor.readValues<Int> { reader -> reader.integer }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            OriginalType.INT_8 ->
                cursor.readValues<Byte> { reader -> reader.integer.toByte() }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            OriginalType.UINT_8, OriginalType.INT_16 ->
                cursor.readValues<Short> { reader -> reader.integer.toShort() }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            OriginalType.DECIMAL -> {
                val scale = type.decimalMetadata.scale
                cursor.readValues<BigDecimal> { reader -> BigDecimal.valueOf(reader.integer.toLong(), scale) }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }
            }

            OriginalType.DATE ->
                cursor.readValues<LocalDate> { reader -> LocalDate.ofEpochDay(reader.integer.toLong()) }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            OriginalType.TIME_MILLIS ->
                // Widen to Long BEFORE multiplying: millis-of-day * 1_000_000 overflows Int
                // for anything later than ~00:00:02.147.
                cursor.readValues<LocalTime> { reader ->
                    LocalTime.ofNanoOfDay(reader.integer.toLong() * NANOS_PER_MILLI)
                }?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            else -> TODO("INT32 column '${cursor.name}' annotated as $originalType is not supported")
        }
    }

    /** Decodes an INT64 column according to its (possibly absent) annotation. */
    private fun decodeInt64(cursor: ColumnCursor): Series<Int, Any?>? {
        val type = cursor.primitiveType
        return when (val originalType = type.originalType) {
            null, OriginalType.INT_64 ->
                cursor.readValues<Long> { reader -> reader.long }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            OriginalType.DECIMAL -> {
                val scale = type.decimalMetadata.scale
                cursor.readValues<BigDecimal> { reader -> BigDecimal.valueOf(reader.long, scale) }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }
            }

            OriginalType.TIME_MICROS ->
                cursor.readValues<LocalTime> { reader -> LocalTime.ofNanoOfDay(reader.long * NANOS_PER_MICRO) }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            OriginalType.TIMESTAMP_MILLIS ->
                cursor.readValues<LocalDateTime> { reader ->
                    LocalDateTime.ofInstant(Instant.ofEpochMilli(reader.long), ZoneOffset.UTC)
                }?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            OriginalType.TIMESTAMP_MICROS ->
                cursor.readValues<LocalDateTime> { reader ->
                    val micros = reader.long
                    // Floor division keeps the nano-of-second in 0..999_999_999 for instants
                    // before 1970; truncating '/' and '%' would yield a negative nano part,
                    // which LocalDateTime.ofEpochSecond rejects.
                    val second = Math.floorDiv(micros, MICROS_PER_SECOND)
                    val nano = (Math.floorMod(micros, MICROS_PER_SECOND) * NANOS_PER_MICRO).toInt()
                    LocalDateTime.ofEpochSecond(second, nano, ZoneOffset.UTC)
                }?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }

            else -> TODO("INT64 column '${cursor.name}' annotated as $originalType is not supported")
        }
    }

    /** Decodes a BINARY or FIXED_LEN_BYTE_ARRAY column according to its (possibly absent) annotation. */
    private fun decodeBinary(cursor: ColumnCursor): Series<Int, Any?>? {
        val type = cursor.primitiveType
        return when (type.originalType) {
            OriginalType.UTF8 -> {
                if (cursor.repetitionLevel > 1) {
                    throw UnsupportedEncodingException("Unsuported repition level ${cursor.repetitionLevel}")
                }
                cursor.readValues<String> { reader -> reader.binary.toStringUsingUTF8() }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }
            }

            OriginalType.DECIMAL -> {
                val scale = type.decimalMetadata.scale
                cursor.readValues<BigDecimal> { reader -> BigDecimal(BigInteger(reader.binary.bytes), scale) }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }
            }

            // Raw bytes. The decimal metadata must NOT be touched here: it is only present on
            // DECIMAL-annotated columns, and dereferencing it broke every plain binary column.
            else ->
                cursor.readValues<ByteArray> { reader -> reader.binary.bytes }
                    ?.let { values -> SeriesFactory.create(cursor.name, values) as Series<Int, Any?> }
        }
    }

    /**
     * Bundles what is needed to decode one leaf column of a row group.
     *
     * @property name series name (last element of the column path)
     * @property primitiveType physical type and annotation of the column
     * @property reader value reader for the column
     * @property maxDefinitionLevel definition level at which a value is actually present
     */
    private class ColumnCursor(
        val name: String,
        val primitiveType: PrimitiveType,
        val reader: ColumnReader,
        val maxDefinitionLevel: Int
    ) {
        /** Number of values the reader reports for this column. */
        val valueCount: Int = reader.totalValueCount.toInt()

        /** Repetition level reported by the reader before any value is consumed. */
        val repetitionLevel: Int = reader.currentRepetitionLevel

        /**
         * Reads all [valueCount] values with [decode], leaving `null` wherever the definition
         * level is below [maxDefinitionLevel].
         *
         * @return the values, or `null` (nothing consumed) when [repetitionLevel] is not 0 or 1
         */
        inline fun <reified T> readValues(decode: (ColumnReader) -> T): Array<T?>? {
            if (repetitionLevel !in 0..1) return null
            val values = arrayOfNulls<T>(valueCount)
            for (i in 0 until valueCount) {
                if (reader.currentDefinitionLevel == maxDefinitionLevel) {
                    values[i] = decode(reader)
                }
                reader.consume()
            }
            return values
        }
    }

    private companion object {
        private const val NANOS_PER_MILLI = 1_000_000L
        private const val NANOS_PER_MICRO = 1_000L
        private const val MICROS_PER_SECOND = 1_000_000L

        /** Julian day number used for 1970-01-01 when decoding INT96 timestamps. */
        private const val JULIAN_DAY_OF_UNIX_EPOCH = 2440588
    }

}