/*
 * Copyright (c) 2014 Macrofocus GmbH. All Rights Reserved.
 */
package com.macrofocus.high_d.parallelcoordinates

import com.macrofocus.common.command.Future
import com.macrofocus.common.concurrent.Callable
import com.macrofocus.common.concurrent.ExecutorService
import com.macrofocus.common.concurrent.Runtime
import com.macrofocus.common.crossplatform.CPHelper
import com.macrofocus.common.filter.Filter
import com.macrofocus.common.properties.Property
import com.macrofocus.common.selection.Selection
import com.macrofocus.common.selection.SelectionEvent
import com.macrofocus.common.selection.SelectionListener
import com.macrofocus.common.timer.CPTimer
import com.macrofocus.common.timer.CPTimerListener
import com.macrofocus.high_d.axis.AxisModel
import com.macrofocus.high_d.parallelcoordinates.geometry.Geometry
import com.macrofocus.high_d.parallelcoordinates.layout.ParallelCoordinatesLayout
import org.mkui.canvas.CPCanvas
import org.mkui.canvas.PaletteProvider
import org.mkui.color.MkColor
import org.mkui.color.colorOf
import org.mkui.component.CPComponentWrapper
import org.mkui.geom.*
import org.mkui.graphics.AbstractIDrawing
import org.mkui.graphics.IGraphics
import org.mkui.graphics.colortheme.ColorTheme
import org.mkui.graphics.pressure.LogPressure
import org.mkui.graphics.pressure.Pressure
import org.mkui.palette.FixedPalette
import org.mkui.rubberband.RubberbandDrawing
import org.mkui.visual.VisualLayer
import org.mkui.visual.VisualListener
import kotlin.math.min
import kotlin.math.sqrt

/**
 * Base canvas component that renders the parallel-coordinates plot of one layout [level] of a
 * [ParallelCoordinatesView].
 *
 * The plot is a stack of canvas layers: background, axis lines, filtered / visible / color-mapped / colored /
 * selected / probed polylines, annotations, probed value labels and the rubber band. Per-row polyline geometries
 * are memoized in a [Cache] that is cleared whenever the model changes or [clearCache] is called.
 * Subclasses supply the platform-specific size queries and refresh/repaint hooks.
 *
 * @param view source of layout, geometry, color theme and display options
 * @param canvas canvas onto which the layers are installed
 * @param level layout level (as passed to `getAxisGroups` / `getLevelBounds`) rendered by this component
 */
abstract class AbstractParallelCoordinatesComponent<Row, Column> protected constructor(
    protected val view: ParallelCoordinatesView<Row, Column>,
    protected val canvas: CPCanvas = CPCanvas(),
    protected val level: Int
) : CPComponentWrapper(canvas), ParallelCoordinatesComponent<Row, Column> {

    /**
     * The model being displayed.
     *
     * Assigning a value does the following:
     * - detaches [listener] from the previous model;
     * - invalidates the geometry cache;
     * - for a non-null model, rebuilds the layer stack, attaches [listener] and restarts [timer];
     * - for `null`, removes every layer so nothing bound to the detached model keeps drawing.
     */
    override var model: ParallelCoordinatesModel<Row, Column>? = null
        set(value) {
            field?.removeParallelCoordinateListener(listener)
            field = value
            if (value == null) {
                // The layers built by createOverplots() reference the previous model; drop them.
                canvas.removeAllLayers()
                cache.clear()
            } else {
                createOverplots()
                value.addParallelCoordinatesListener(listener)
                cache.clear()
                timer.restart()
            }
        }

    private val cache: Cache = Cache()

    /** Repaints the component whenever the model reports a change. */
    private val listener: ParallelCoordinatesListener = object : ParallelCoordinatesListener {
        override fun pararallelCoordinatesChanged() {
            repaint()
        }
    }

    /**
     * Timer created with a 40 ms delay that calls [refresh] once the component has a positive size.
     * It is restarted whenever a non-null [model] is assigned.
     */
    protected val timer: CPTimer =
        CPHelper.instance.createTimer("ParallelCoordinates\$Resizer", 40, true, object : CPTimerListener {
            override fun timerTriggered() {
                if (getWidth() > 0 && getHeight() > 0) {
                    refresh()
                }
            }
        })

    /** Current component width in pixels. */
    protected abstract fun getWidth(): Int

    /** Current component height in pixels. */
    protected abstract fun getHeight(): Int

    /** Platform hook invoked by [timer] once the component has a non-empty size. */
    protected abstract fun refresh()

    /** Platform hook invoked when the model reports a change. */
    protected abstract fun repaint()

    /**
     * Rebuilds the canvas layer stack for the current [model]. Does nothing when no model is set.
     *
     * Layers are added below in order, starting with the background. This is presumably also the
     * painting order; confirm against CPCanvas.
     */
    override fun createOverplots() {
        val currentModel = model ?: return
        val visual = currentModel.getVisual()
        val rendering = view.getRendering().value

        canvas.removeAllLayers()

        // Background fill, only when the theme defines a background color.
        canvas.addLayer(object : AbstractIDrawing() {
            override fun draw(g: IGraphics, point: Point2D?, width: Double, height: Double, clipBounds: Rectangle) {
                val background: MkColor = view.getColorTheme().value.background ?: return
                g.setColor(background)
                g.fillRectangle2D(Rectangle2D.Double(0.0, 0.0, width, height))
            }
        })

        // Vertical axis lines; skipped for the Steps geometry or when axis lines are disabled in the view.
        canvas.addLayer(object : AbstractIDrawing() {
            override fun draw(g: IGraphics, point: Point2D?, width: Double, height: Double, clipBounds: Rectangle) {
                if (view.getGeometry().value === Geometry.Steps || !view.getAxisLine().value) return
                g.setColor(view.getColorTheme().value.foreground)
                g.setLineWidth(AXIS_STROKE)
                val layout: ParallelCoordinatesLayout<Row, Column> = view.getParallelCoordinatesLayout()
                val halfAxisWidth = layout.getAxisPreferredWidth() / 2
                for (axisGroup in layout.getAxisGroups(level)) {
                    for (axisModel in axisGroup.axisOrder!!) {
                        val x: Int = view.getAxisX(axisGroup, axisModel) + halfAxisWidth
                        val trackBounds: Rectangle2D = layout.getTrackBounds(axisGroup, axisModel)
                        g.drawLine(x, trackBounds.y.toInt(), x, (trackBounds.y + trackBounds.height).toInt())
                    }
                }
            }
        })

        // Filtered-out rows, drawn as a ghosted density only when the view asks for them.
        canvas.addDensityLayer(
            rendering,
            object : AbstractVisualLayerIDrawing(visual.filtered, cache, AxisModel.DATA, level) {
                override val isActive: Boolean
                    get() = view.getShowFiltered().value && super.isActive

                override fun draw(g: IGraphics, shape: Shape, row: Row) {
                    g.drawShape(shape)
                }
            },
            LogPressure(),
            object : PaletteProvider {
                override val palette: FixedPalette
                    get() = view.getColorTheme().value.ghostedPalette
            }
        )

        // Visible rows as a density layer.
        canvas.addDensityLayer(
            rendering,
            object : AbstractVisualLayerIDrawing(visual.visible, cache, AxisModel.DATA, level) {
                override fun draw(g: IGraphics, shape: Shape, row: Row) {
                    g.drawShape(shape)
                }
            },
            LogPressure(),
            object : PaletteProvider {
                override val palette: FixedPalette
                    get() = view.getColorTheme().value.visiblePalette
            }
        )

        // Color-mapped rows, each drawn with its color-mapping color.
        canvas.addAveragingLayer(
            rendering,
            object : AbstractVisualLayerIDrawing(visual.colorMapped, cache, AxisModel.DATA, level) {
                override fun draw(g: IGraphics, shape: Shape, row: Row) {
                    // NOTE(review): assumes the color mapping yields a color for every row of this layer.
                    g.setColor(currentModel.getColorMapping().getColor(row)!!)
                    g.drawShape(shape)
                }
            }
        )

        // Explicitly colored rows, drawn with a slightly thicker stroke.
        canvas.addBufferedLayer(object : AbstractVisualLayerIDrawing(visual.colored, cache, AxisModel.DATA, level) {
            override fun draw(g: IGraphics, shape: Shape, row: Row) {
                g.setLineWidth(COLORED_STROKE)
                g.setColor(currentModel.getColoring().getColor(row)!!)
                g.drawShape(shape)
            }
        })

        // Rows from a multiple selection, as a density layer.
        canvas.addDensityLayer(
            rendering,
            object : AbstractVisualLayerIDrawing(visual.multipleSelected, cache, AxisModel.DATA, level) {
                override fun draw(g: IGraphics, shape: Shape, row: Row) {
                    g.setColor(view.getColorTheme().value.selection)
                    g.drawShape(shape)
                }
            },
            LogPressure(),
            object : PaletteProvider {
                override val palette: FixedPalette
                    get() = view.getColorTheme().value.selectedPalette
            }
        )

        // A single selected row, stroked individually.
        canvas.addLayer(object : AbstractVisualLayerIDrawing(visual.singleSelected, cache, AxisModel.DATA, level) {
            override fun draw(g: IGraphics, shape: Shape, row: Row) {
                g.setLineWidth(SELECTION_STROKE)
                g.setColor(view.getColorTheme().value.selection)
                g.drawShape(shape)
            }
        })

        // Probed rows: a wide probing stroke, overlaid with the selection stroke when the row is also selected.
        canvas.addLayer(object : AbstractVisualLayerIDrawing(visual.probed, cache, AxisModel.DATA, level) {
            override fun draw(g: IGraphics, shape: Shape, row: Row) {
                g.setLineWidth(PROBING_STROKE)
                g.setColor(view.getColorTheme().value.probing)
                g.drawShape(shape)
                if (visual.selection.isSelected(row)) {
                    g.setLineWidth(SELECTION_STROKE)
                    g.setColor(view.getColorTheme().value.selection)
                    g.drawShape(shape)
                }
            }
        })

        canvas.addLayer(AnnotationIDrawing(visual.annotationProbing, colorOf(255, 150, 0, 200)))
        canvas.addLayer(AnnotationIDrawing(visual.annotationSelection, colorOf(255, 100, 0, 200)))
        canvas.addLayer(ProbedValuesIDrawing(currentModel))
        canvas.addLayer(object : RubberbandDrawing(view.getRubberBand()) {
            override val colorTheme: Property<ColorTheme>
                get() = view.getColorTheme()
        })
    }

    /** Returns the (cached) geometry of [row] for the given [layer] and [level], or `null` if none exists. */
    protected fun getShape(layer: Int, row: Row, level: Int): Shape? {
        return cache.getShape(layer, row, level)
    }

    /** Discards all memoized row geometries; they are rebuilt lazily on next use. */
    override fun clearCache() {
        cache.clear()
    }

    /**
     * Returns the unfiltered row whose polyline passes closest to the point ([x], [y]).
     *
     * Returns `null` when no model is set or no row has geometry.
     *
     * @throws Error if a polyline contains a path segment other than MOVETO or LINETO
     */
    override fun getClosestRow(x: Int, y: Int): Row? {
        val currentModel = model ?: return null
        val px = x.toDouble()
        val py = y.toDouble()
        var bestDistance = Double.MAX_VALUE
        var bestRow: Row? = null
        val data = FloatArray(6)
        forEachUnfilteredLine(currentModel) { rowObject, line ->
            val pi: PathIterator = line.getPathIterator()
            // Vertices are truncated to whole pixels, as they are when the lines are hit-tested in getRows().
            var previousX = 0
            var previousY = 0
            while (!pi.isDone) {
                val segmentDistance: Double = when (pi.currentSegment(data)) {
                    PathIterator.SEG_MOVETO -> {
                        previousX = data[0].toInt()
                        previousY = data[1].toInt()
                        distance(previousX.toDouble(), previousY.toDouble(), px, py)
                    }
                    PathIterator.SEG_LINETO -> {
                        val currentX = data[0].toInt()
                        val currentY = data[1].toInt()
                        val d = Line2D.ptSegDist(
                            previousX.toDouble(), previousY.toDouble(),
                            currentX.toDouble(), currentY.toDouble(),
                            px, py
                        )
                        previousX = currentX
                        previousY = currentY
                        d
                    }
                    else -> throw Error("Segmented type unsupported")
                }
                if (segmentDistance < bestDistance) {
                    bestDistance = segmentDistance
                    bestRow = rowObject
                }
                pi.next()
            }
        }
        return bestRow
    }

    /**
     * Returns every unfiltered row whose polyline has at least one line segment intersecting [r].
     *
     * Returns `null` when no model is set. Segment types other than MOVETO and LINETO are ignored here.
     */
    override fun getRows(r: Rectangle2D): List<Row>? {
        val currentModel = model ?: return null
        val hits: MutableList<Row> = ArrayList()
        val data = FloatArray(6)
        forEachUnfilteredLine(currentModel) { rowObject, line ->
            if (lineIntersects(line, r, data)) {
                hits.add(rowObject)
            }
        }
        return hits
    }

    /**
     * Calls [action] for each row of [source] that is not filtered out and has a data-layer geometry.
     */
    private fun forEachUnfilteredLine(
        source: ParallelCoordinatesModel<Row, Column>,
        action: (row: Row, line: Shape) -> Unit
    ) {
        val filter: Filter<Row> = source.getFilter()
        for (index in 0 until source.objectCount) {
            val rowObject: Row = source.getObject(index)
            if (!filter.isFiltered(rowObject)) {
                val line: Shape? = getShape(AxisModel.DATA, rowObject, level)
                if (line != null) {
                    action(rowObject, line)
                }
            }
        }
    }

    /**
     * Returns `true` as soon as one LINETO segment of [line] intersects [r].
     *
     * [data] is scratch space reused across calls.
     */
    private fun lineIntersects(line: Shape, r: Rectangle2D, data: FloatArray): Boolean {
        val pi: PathIterator = line.getPathIterator()
        var previousX = 0
        var previousY = 0
        while (!pi.isDone) {
            when (pi.currentSegment(data)) {
                PathIterator.SEG_MOVETO -> {
                    previousX = data[0].toInt()
                    previousY = data[1].toInt()
                }
                PathIterator.SEG_LINETO -> {
                    val currentX = data[0].toInt()
                    val currentY = data[1].toInt()
                    if (r.intersectsLine(previousX.toDouble(), previousY.toDouble(), currentX.toDouble(), currentY.toDouble())) {
                        return true
                    }
                    previousX = currentX
                    previousY = currentY
                }
                else -> {
                    // Unsupported segment types are skipped (the previous code built an Error and discarded it).
                }
            }
            pi.next()
        }
        return false
    }

    /**
     * Drawing that renders every row of a [VisualLayer] using the row geometry from a [Cache].
     *
     * Subclasses implement the per-row [draw] overload.
     *
     * When [MULTITHREADED] is set, an executor is available and the graphics context reports itself thread-safe,
     * rows are split into contiguous ranges and drawn in parallel on the shared executor. Otherwise they are drawn
     * serially on the calling thread.
     */
    abstract inner class AbstractVisualLayerIDrawing protected constructor(
        private val visualLayer: VisualLayer<Row>,
        private val cache: Cache,
        private val layer: Int,
        private val level: Int
    ) : AbstractIDrawing() {

        /** Active exactly when the underlying visual layer is active. */
        override val isActive: Boolean
            get() = visualLayer.isActive

        override fun draw(g: IGraphics, point: Point2D?, width: Double, height: Double, clipBounds: Rectangle) {
            val rowCount = visualLayer.objectCount
            if (rowCount <= 0) return
            init(g)
            val pool: ExecutorService? = executor
            // Without an executor the parallel path cannot run, so fall back to serial drawing.
            if (!MULTITHREADED || pool == null || !g.isThreadSafe()) {
                for (row in visualLayer) {
                    drawRow(g, row)
                }
                return
            }

            // Cache is backed by an unsynchronized HashMap: populate it here, on one thread, so the
            // parallel tasks below only read from it.
            for (row in visualLayer.iterable(0, rowCount - 1)) {
                cache.getShape(layer, row, level)
            }

            val nTasks: Int = min(Runtime.getRuntime().availableProcessors(), rowCount)
            val rowsPerTask: Int = rowCount / nTasks
            val tasks: MutableList<Callable<Any?>> = ArrayList(nTasks)
            for (taskIndex in 0 until nTasks) {
                val fromRow = taskIndex * rowsPerTask
                // The last task also takes the remainder rows.
                val toRow = if (taskIndex < nTasks - 1) fromRow + rowsPerTask else rowCount
                tasks.add(object : Callable<Any?> {
                    @Throws(Exception::class)
                    override fun call(): Any? {
                        // iterable(from, to) is called with an inclusive upper index, hence toRow - 1.
                        for (row in visualLayer.iterable(fromRow, toRow - 1)) {
                            drawRow(g, row)
                        }
                        return null
                    }
                })
            }
            try {
                val answers: List<Future<Any?>> = pool.invokeAll(tasks)
                for (answer in answers) {
                    try {
                        answer.get()
                    } catch (e: Exception) {
                        e.printStackTrace()
                    }
                }
            } catch (e: Exception) {
                e.printStackTrace() // something wrong, maybe
            }
        }

        /** Draws [row] if it has a geometry for this drawing's layer and level. */
        private fun drawRow(g: IGraphics, row: Row) {
            val geometry: Shape? = cache.getShape(layer, row, level)
            if (geometry != null) {
                draw(g, geometry, row)
            }
        }

        /** Hook called once per paint before any row is drawn; currently a no-op. */
        fun init(g: IGraphics) {}

        /** Draws a single [row] whose precomputed geometry is [shape]. */
        abstract fun draw(g: IGraphics, shape: Shape, row: Row)

        init {
            visualLayer.addVisualListener(object : VisualListener {
                override fun visualChanged() {
                    notifyIDrawingChanged()
                }
            })
        }
    }

    /**
     * Memoizes row geometries produced by the view's current geometry.
     *
     * Entries are keyed by row only; `layer` and `level` are NOT part of the key, so callers must use one
     * consistent layer/level per cache. `null` results are cached as well.
     *
     * Not thread-safe: it is backed by a plain [HashMap].
     */
    inner class Cache {
        private val shapes: MutableMap<Row, Shape?> = HashMap()

        /** Returns the memoized geometry of [row], creating and storing it on first access. */
        fun getShape(layer: Int, row: Row, level: Int): Shape? {
            if (shapes.containsKey(row)) {
                return shapes[row]
            }
            val fullHeight = view.getParallelCoordinatesLayout().getLevelBounds(level).height.toInt()
            val shape: Shape? = createShape(fullHeight, layer, row, level)
            shapes[row] = shape
            return shape
        }

        /** Builds (without caching) the geometry of [row] using the view's current geometry. */
        fun createShape(fullHeight: Int, layer: Int, row: Row, level: Int): Shape? {
            return view.getGeometry().value.createGeometry(view, model, fullHeight, layer, row, level)
        }

        /** Drops every memoized geometry. */
        fun clear() {
            shapes.clear()
        }
    }

    /**
     * Draws the annotation-layer geometry of every row in [selection] using [color], and requests a redraw
     * whenever the selection changes.
     */
    private inner class AnnotationIDrawing(
        private val selection: Selection<Any>,
        private val color: MkColor
    ) : AbstractIDrawing() {
        // Held in a field because it is registered through addWeakSelectionListener.
        private val selectionListener: SelectionListener<Any> = object : SelectionListener<Any> {
            override fun selectionChanged(event: SelectionEvent<Any>) {
                notifyIDrawingChanged()
            }
        }

        override val isActive: Boolean
            get() = selection.isActive

        override fun draw(g: IGraphics, point: Point2D?, width: Double, height: Double, clipBounds: Rectangle) {
            val fullHeight = view.getParallelCoordinatesLayout().getLevelBounds(level).height.toInt()
            g.setColor(color)
            for (row in selection) {
                @Suppress("UNCHECKED_CAST")
                val shape: Shape = cache.createShape(fullHeight, AxisModel.ANNOTATION, row as Row, level) ?: continue
                g.drawShape(shape)
            }
        }

        init {
            selection.addWeakSelectionListener(selectionListener)
        }
    }

    /**
     * Labels each axis of every non-collapsed axis group with the formatted value of the currently probed row.
     *
     * Each label sits on a translucent box, offset away from the vertical middle of the level.
     */
    private inner class ProbedValuesIDrawing(
        private val probedModel: ParallelCoordinatesModel<Row, Column>
    ) : AbstractIDrawing() {
        override val isActive: Boolean
            get() = view.getShowProbedValues().value

        override fun draw(g: IGraphics, point: Point2D?, width: Double, height: Double, clipBounds: Rectangle) {
            val row = probedModel.getProbing().selected ?: return
            val layout: ParallelCoordinatesLayout<Row, Column> = view.getParallelCoordinatesLayout()
            val fullHeight = layout.getLevelBounds(level).height.toInt()
            val halfAxisWidth = layout.getAxisPreferredWidth() / 2
            for (axisGroup in layout.getAxisGroups(level)) {
                if (axisGroup.isCollapsed) continue
                for (axisModel in axisGroup.axisOrder!!) {
                    val trackBounds: Rectangle2D = layout.getTrackBounds(axisGroup, axisModel)
                    val x: Int = trackBounds.x.toInt() + halfAxisWidth
                    val value: Number = axisModel.getValue(AxisModel.DATA, row) ?: continue
                    var y: Double = Geometry.getY(
                        fullHeight,
                        trackBounds.height.toInt(),
                        layout.getAxisAfterTrackGap(),
                        value.toDouble(),
                        axisModel
                    )
                    val formattedValue: String = axisModel.getFormattedValue(row)
                    val bounds: Rectangle2D = g.getStringBounds(formattedValue)
                    // NOTE(review): the original code used this same color for dark and light themes; confirm whether
                    // light themes were meant to get a different label background.
                    g.setColor(colorOf(0, 0, 0, 191))
                    // Move the label toward the middle of the level so it stays inside the plot.
                    if (y > fullHeight / 2) {
                        y -= 16 + bounds.height / 2
                    } else {
                        y += 16 + bounds.height / 2
                    }
                    g.fillRectangle2D(Rectangle2D.Double(x - 2.0, y - bounds.height / 2 - 2, bounds.width + 4, bounds.height + 4))
                    g.setTextBaseline(IGraphics.TextBaseline.Middle)
                    g.setColor(view.getColorTheme().value.probing)
                    g.drawString(formattedValue, x.toFloat(), y.toFloat())
                }
            }
        }
    }

    companion object {
        private const val COLORED_STROKE = 1.5
        private const val PROBING_STROKE = 4.0
        private const val SELECTION_STROKE = 2.0
        private const val AXIS_STROKE = 3.0

        /** Enables parallel row drawing when the graphics context is thread-safe and an executor exists. */
        private const val MULTITHREADED = true

        /** Shared executor for parallel row drawing; may be `null`, in which case drawing stays serial. */
        private val executor: ExecutorService? = CPHelper.instance.visualizationExecutorService()

        /**
         * Computes the Euclidean distance between ([x1], [y1]) and ([x2], [y2]).
         *
         * @return the Euclidean distance between the two points
         */
        fun distance(
            x1: Double, y1: Double,
            x2: Double, y2: Double
        ): Double {
            val dx = x1 - x2
            val dy = y1 - y2
            return sqrt(dx * dx + dy * dy)
        }
    }
}