/*
 * Copyright (c) 2013 Macrofocus GmbH. All Rights Reserved.
 */
package com.macrofocus.high_d.parallelcoordinates.layout

import com.macrofocus.hierarchy.Hierarchy
import com.macrofocus.high_d.axis.Alignment
import com.macrofocus.high_d.axis.AxisModel
import com.macrofocus.high_d.axis.group.AxisGroupModel
import com.macrofocus.high_d.parallelcoordinates.ParallelCoordinatesModel
import com.macrofocus.high_d.parallelcoordinates.ParallelCoordinatesView
import org.mkui.geom.Rectangle2D
import kotlin.math.max

/**
 * Default [ParallelCoordinatesLayoutEngine]: derives the geometry of a parallel-coordinates [view]
 * from the axis-group hierarchy of its model.
 *
 * Horizontally, the leaf groups split the view width from left to right. Collapsed leaves get a fixed
 * [COLLAPSED_AXIS_WIDTH] pixels per visible axis. Expanded leaves share the remaining width in
 * proportion to their visible axis count. Inner groups span the union of their children.
 * Vertically, every hierarchy level is a full-width band. Non-root levels reserve
 * [NON_ROOT_LEVEL_FIXED_HEIGHT] pixels, and levels that contain a group with visible axes share the
 * remaining height equally.
 *
 * @param view the view whose size, alignment and model drive the layout; its `model` must be set
 *   before [doLayout] is called.
 */
class DefaultParallelCoordinatesLayoutEngine<Row, Column>(
    private val view: ParallelCoordinatesView<Row, Column>
) : ParallelCoordinatesLayoutEngine<Row, Column> {

    /**
     * Computes the group and level bounds for the current view size and hierarchy.
     *
     * Group and level rectangles are computed once, here. The returned layout still reads the view's
     * alignment, header heights and axis positions each time one of its methods is called.
     *
     * @param axisPreferredWidth width, in pixels, of each axis slider.
     * @param axisBeforeTrackGap gap between the top of a slider and its track.
     * @param axisAfterTrackGap gap between the bottom of a track and its slider.
     */
    override fun doLayout(
        axisPreferredWidth: Int,
        axisBeforeTrackGap: Int,
        axisAfterTrackGap: Int
    ): ParallelCoordinatesLayout<Row, Column> {
        val verticalLayout: Map<AxisGroupModel<Row, Column>, Rectangle2D> = doVerticalLayout()
        val horizontalLayout: Map<AxisGroupModel<Row, Column>, Rectangle2D> = doHorizontalLayout()
        val axisGroupsPerLevels: List<MutableList<AxisGroupModel<Row, Column>>> = doAxisGroupsPerLevel()
        val levelsLayout: Map<Int, Rectangle2D> = doLevelsLayout(axisGroupsPerLevels, verticalLayout)
        return object : ParallelCoordinatesLayout<Row, Column> {
            override fun getAxisPreferredWidth(): Int = axisPreferredWidth

            override fun getAxisBeforeTrackGap(): Int = axisBeforeTrackGap

            override fun getAxisAfterTrackGap(): Int = axisAfterTrackGap

            /** The group's band on its level, restricted to the group's horizontal extent. */
            override fun getBounds(axisGroup: AxisGroupModel<Row, Column>): Rectangle2D =
                verticalLayout.getValue(axisGroup).createIntersection(horizontalLayout.getValue(axisGroup))

            /**
             * Bounds of the slot owned by [axisModel] inside [axisGroup]. The x coordinate is the axis
             * position scaled to the group width.
             */
            override fun getBounds(axisGroup: AxisGroupModel<Row, Column>, axisModel: AxisModel<Row, Column>): Rectangle2D {
                val groupBounds: Rectangle2D = getBounds(axisGroup)
                val position = getPosition(axisGroup, axisModel)
                val x: Double = groupBounds.x + groupBounds.width * position
                val axisCount = axisGroup.visibleAxisCount
                // In Fill mode, axes at positions 0 and 1 sit on the group edges (see the header logic
                // below), so n axes delimit n - 1 gaps. The other alignments give each axis an equal
                // slot. Divisors are clamped to 1 so that groups with 0 or 1 visible axes produce a
                // finite, non-negative width instead of Infinity or a negative value.
                var width = 0.0
                when (view.getAlignment().value) {
                    Alignment.Left, Alignment.Center, Alignment.Right ->
                        width = groupBounds.width / axisCount.coerceAtLeast(1)
                    Alignment.Fill ->
                        width = groupBounds.width / (axisCount - 1).coerceAtLeast(1)
                }
                return Rectangle2D.Double(x, groupBounds.y, width, groupBounds.height)
            }

            /**
             * Header strip at the top of the group's bounds. It has zero height for the root group and
             * for groups without a group view.
             */
            override fun getHeaderBounds(axisGroup: AxisGroupModel<Row, Column>): Rectangle2D {
                val bounds: Rectangle2D = getBounds(axisGroup)
                val hasHeader = view.getAxisGroupView(axisGroup) != null && !axisGroupHierarchy().isRoot(axisGroup)
                return Rectangle2D.Double(
                    bounds.x,
                    bounds.y,
                    bounds.width,
                    if (hasHeader) view.getHeaderAxisMaximumHeight().toDouble() else 0.0
                )
            }

            /** Header of a single axis, placed directly below its group's header. */
            override fun getHeaderBounds(axisGroup: AxisGroupModel<Row, Column>, axisModel: AxisModel<Row, Column>): Rectangle2D? {
                val axisGroupBounds: Rectangle2D = getHeaderBounds(axisGroup)
                val axisBounds: Rectangle2D = getBounds(axisGroup, axisModel)
                val y = axisGroupBounds.maxY
                val height = view.getHeaderAxisMaximumHeight().toDouble()
                when (view.getAlignment().value) {
                    Alignment.Left -> return Rectangle2D.Double(axisBounds.x, y, axisBounds.width, height)
                    Alignment.Center -> return Rectangle2D.Double(axisBounds.x - axisBounds.width / 2.0, y, axisBounds.width, height)
                    Alignment.Fill -> {
                        val halfWidth = axisBounds.width / 2.0
                        val p = getPosition(axisGroup, axisModel)
                        // Inner axes are centred on their line. The first (p == 0) and last (p == 1) axes
                        // only get the half-slot lying inside the group.
                        val shift = if (p == 0.0) 0.0 else halfWidth
                        val width = if (p == 0.0 || p == 1.0) halfWidth else axisBounds.width
                        return Rectangle2D.Double(axisBounds.x - shift, y, width, height)
                    }
                    Alignment.Right -> return Rectangle2D.Double(axisBounds.x - axisBounds.width, y, axisBounds.width, height)
                }
                return null
            }

            /**
             * Slider of a single axis. It is [axisPreferredWidth] wide and runs from the bottom of the
             * axis header to the bottom of the axis slot.
             */
            override fun getSliderBounds(axisGroup: AxisGroupModel<Row, Column>, axisModel: AxisModel<Row, Column>): Rectangle2D? {
                val headerBounds: Rectangle2D = checkNotNull(getHeaderBounds(axisGroup, axisModel)) {
                    "No header bounds for axis $axisModel"
                }
                val axisBounds: Rectangle2D = getBounds(axisGroup, axisModel)
                val sliderWidth = axisPreferredWidth.toDouble()
                val y = headerBounds.maxY
                val height = axisBounds.height - (headerBounds.maxY - axisBounds.y)
                when (view.getAlignment().value) {
                    Alignment.Left -> return Rectangle2D.Double(axisBounds.x, y, sliderWidth, height)
                    Alignment.Center -> return Rectangle2D.Double(axisBounds.x - sliderWidth / 2.0, y, sliderWidth, height)
                    Alignment.Fill -> {
                        // Same rule as the Fill header: only the first axis of a group (position 0) keeps
                        // its slider to the right of the axis line; every other slider sits to its left.
                        // This keeps the slider inside its own group and next to its header.
                        val shift = if (getPosition(axisGroup, axisModel) == 0.0) 0.0 else sliderWidth
                        return Rectangle2D.Double(axisBounds.x - shift, y, sliderWidth, height)
                    }
                    Alignment.Right -> return Rectangle2D.Double(axisBounds.x - sliderWidth, y, sliderWidth, height)
                }
                return null
            }

            /** Slider bounds shrunk vertically by the before/after track gaps. */
            override fun getTrackBounds(axisGroup: AxisGroupModel<Row, Column>, axisModel: AxisModel<Row, Column>): Rectangle2D {
                val sliderBounds: Rectangle2D = checkNotNull(getSliderBounds(axisGroup, axisModel)) {
                    "No slider bounds for axis $axisModel"
                }
                return Rectangle2D.Double(
                    sliderBounds.x,
                    sliderBounds.y + axisBeforeTrackGap,
                    sliderBounds.width,
                    sliderBounds.height - axisBeforeTrackGap - axisAfterTrackGap
                )
            }

            override fun getLevelBounds(level: Int): Rectangle2D = levelsLayout.getValue(level)

            override fun getLevelCount(): Int = axisGroupsPerLevels.size

            override fun getAxisGroups(level: Int): Iterable<AxisGroupModel<Row, Column>> = axisGroupsPerLevels[level]
        }
    }

    /** Returns the axis-group hierarchy of the view's model, failing fast if the view has no model. */
    private fun axisGroupHierarchy(): Hierarchy<AxisGroupModel<Row, Column>> =
        checkNotNull(view.model) { "ParallelCoordinatesView has no model" }.getAxisHierarchy().axisGroupHierarchy

    /**
     * Computes the horizontal extent of every axis group. Each rectangle has y = 0 and the full view
     * height; only x and width are meaningful.
     */
    private fun doHorizontalLayout(): Map<AxisGroupModel<Row, Column>, Rectangle2D> {
        val horizontalLayout = HashMap<AxisGroupModel<Row, Column>, Rectangle2D>()
        val hierarchy = axisGroupHierarchy()

        // Width reserved by collapsed leaves, and the number of axes sharing the remaining width.
        var fixedWidth = 0
        var relativeAxisCount = 0
        for (axisGroup in hierarchy.leavesIterator()) {
            if (axisGroup.isCollapsed) {
                fixedWidth += axisGroup.visibleAxisCount * COLLAPSED_AXIS_WIDTH
            } else {
                relativeAxisCount += axisGroup.visibleAxisCount
            }
        }

        // Lay the leaves out from left to right, starting at x = 0 for every alignment.
        var x = 0.0
        for (axisGroup in hierarchy.leavesIterator()) {
            val width: Double = when {
                axisGroup.isCollapsed -> axisGroup.visibleAxisCount * COLLAPSED_AXIS_WIDTH.toDouble()
                // No expanded axis anywhere: avoid 0 / 0 (NaN) and give the empty group no width.
                relativeAxisCount == 0 -> 0.0
                else -> (view.getWidth() - fixedWidth) * axisGroup.visibleAxisCount / relativeAxisCount.toDouble()
            }
            horizontalLayout[axisGroup] = Rectangle2D.Double(x, 0.0, width, view.getHeight().toDouble())
            x += width
        }

        // Grow each parent to cover its children.
        // NOTE(review): this assumes depthFirstIterator() yields a group's children before the group
        // itself, so an inner group's bounds exist before they are propagated upward. Confirm.
        for (axisGroup in hierarchy.depthFirstIterator()) {
            val parent: AxisGroupModel<Row, Column> = hierarchy.getParent(axisGroup) ?: continue
            val childBounds: Rectangle2D = horizontalLayout.getValue(axisGroup)
            horizontalLayout[parent] = horizontalLayout[parent]?.createUnion(childBounds) ?: childBounds
        }
        return horizontalLayout
    }

    /** Relative position of [axisModel] within [axisGroup] for the view's current alignment. */
    fun getPosition(axisGroup: AxisGroupModel<Row, Column>, axisModel: AxisModel<Row, Column>): Double {
        val model = checkNotNull(view.model) { "ParallelCoordinatesView has no model" }
        val locations = checkNotNull(model.getAxisLocations(axisGroup)) { "No axis locations for group $axisGroup" }
        return locations.getLocation(view.getAlignment().value, axisModel)
    }

    /**
     * Computes the vertical band of every axis group. All groups on the same hierarchy level share the
     * same full-width band.
     */
    private fun doVerticalLayout(): Map<AxisGroupModel<Row, Column>, Rectangle2D> {
        val hierarchy = axisGroupHierarchy()
        val levelSizes: List<VerticalLayoutSize> = List(hierarchy.depth + 1) { VerticalLayoutSize() }

        // Work out how much fixed and relative space each level needs.
        for (axisGroup in hierarchy.breadthFirstIterator()) {
            val levelSize = levelSizes[getLevel(hierarchy, axisGroup)]
            // A level gets a share of the free height once any of its groups shows an axis. Using max
            // makes the result independent of the order in which groups are visited.
            levelSize.relative = max(levelSize.relative, if (axisGroup.visibleAxisCount == 0) 0.0 else 1.0)
            levelSize.fixed = if (hierarchy.isRoot(axisGroup)) 0 else max(levelSize.fixed, NON_ROOT_LEVEL_FIXED_HEIGHT)
        }

        val sumFixed = levelSizes.sumOf { it.fixed }
        val sumRelative = levelSizes.sumOf { it.relative }

        // Stack the levels from top to bottom.
        val remainingHeight: Double = (view.getHeight() - sumFixed).toDouble()
        var y = 0.0
        for (levelSize in levelSizes) {
            levelSize.y = y
            // Guard against 0 / 0 (NaN) when no level contains a visible axis.
            val share = if (sumRelative > 0.0) remainingHeight * levelSize.relative / sumRelative else 0.0
            levelSize.height = levelSize.fixed + share
            // Level origins advance by whole pixels (the height is truncated before accumulating).
            y += levelSize.height.toInt()
        }

        val verticalLayout = HashMap<AxisGroupModel<Row, Column>, Rectangle2D>()
        for (axisGroup in hierarchy.breadthFirstIterator()) {
            val levelSize = levelSizes[getLevel(hierarchy, axisGroup)]
            verticalLayout[axisGroup] = Rectangle2D.Double(0.0, levelSize.y, view.getWidth().toDouble(), levelSize.height)
        }
        return verticalLayout
    }

    /**
     * Groups the hierarchy's axis groups by level. Index 0 is the root level, and each list is in
     * breadth-first order.
     */
    fun doAxisGroupsPerLevel(): List<MutableList<AxisGroupModel<Row, Column>>> {
        val hierarchy = axisGroupHierarchy()
        val axisGroupsPerLevels = List(hierarchy.depth + 1) { ArrayList<AxisGroupModel<Row, Column>>() }
        for (axisGroup in hierarchy.breadthFirstIterator()) {
            axisGroupsPerLevels[getLevel(hierarchy, axisGroup)].add(axisGroup)
        }
        return axisGroupsPerLevels
    }

    /**
     * Bounds of each level: the union of the vertical bounds of the groups on that level. Levels with
     * no groups get no entry.
     */
    private fun doLevelsLayout(
        axisGroupsPerLevels: List<MutableList<AxisGroupModel<Row, Column>>>,
        verticalLayout: Map<AxisGroupModel<Row, Column>, Rectangle2D>
    ): Map<Int, Rectangle2D> {
        val levelsLayout = HashMap<Int, Rectangle2D>()
        axisGroupsPerLevels.forEachIndexed { level, axisGroups ->
            for (axisGroup in axisGroups) {
                val groupBounds: Rectangle2D = verticalLayout.getValue(axisGroup)
                levelsLayout[level] = levelsLayout[level]?.createUnion(groupBounds) ?: groupBounds
            }
        }
        return levelsLayout
    }

    private fun getLevel(hierarchy: Hierarchy<AxisGroupModel<Row, Column>>, axisGroup: AxisGroupModel<Row, Column>): Int {
        return hierarchy.getLevel(axisGroup)
    }

    /** Mutable per-level accumulator used by [doVerticalLayout]. */
    private class VerticalLayoutSize {
        /** Fixed pixel height reserved for the level. */
        var fixed = 0
        /** Weight of the level in the distribution of the remaining height (0 or 1). */
        var relative = 0.0
        /** Top coordinate of the level. */
        var y = 0.0
        /** Total height of the level (fixed part plus relative share). */
        var height = 0.0
    }

    private companion object {
        /** Width, in pixels, given to each visible axis of a collapsed leaf group. */
        const val COLLAPSED_AXIS_WIDTH = 4

        /** Fixed height, in pixels, reserved for every non-root hierarchy level (presumably the group header band). */
        const val NON_ROOT_LEVEL_FIXED_HEIGHT = 18
    }
}