/*
 * Copyright (c) 2016 Macrofocus GmbH. All Rights Reserved.
 */
package com.macrofocus.high_d.scatterplotmatrix

import com.macrofocus.common.command.Future
import com.macrofocus.common.concurrent.Callable
import com.macrofocus.common.concurrent.ExecutorService
import com.macrofocus.common.concurrent.Runtime
import com.macrofocus.common.crossplatform.CPHelper
import com.macrofocus.common.properties.Property
import com.macrofocus.common.timer.CPTimer
import com.macrofocus.common.timer.CPTimerListener
import com.macrofocus.high_d.axis.AxisListener
import com.macrofocus.high_d.axis.AxisModel
import com.macrofocus.high_d.parallelcoordinatesmatrix.ScatterPlotMatrixComponent
import com.macrofocus.order.OrderEvent
import com.macrofocus.order.OrderListener
import org.mkui.canvas.CPCanvas
import org.mkui.canvas.PaletteProvider
import org.mkui.geom.Point2D
import org.mkui.geom.Rectangle
import org.mkui.geom.Rectangle2D
import org.mkui.graphics.AbstractIDrawing
import org.mkui.graphics.IDrawing
import org.mkui.graphics.IGraphics
import org.mkui.graphics.colortheme.ColorTheme
import org.mkui.graphics.pressure.LogPressure
import org.mkui.palette.FixedPalette
import org.mkui.rubberband.RubberbandDrawing
import org.mkui.visual.VisualLayer
import org.mkui.visual.VisualListener
import kotlin.math.min

/**
 * Base implementation of a scatter-plot matrix drawn onto a [CPCanvas].
 *
 * The matrix is laid out from the model's axis locations. Each location is treated as a fraction of the
 * component size (it is multiplied by [getWidth] / [getHeight]). The band of axis `i` runs from its own
 * location to the location of axis `i + 1`, or to `1.0` for the last axis. The cell at column band `i` and
 * row band `j` plots axis `i` horizontally against axis `j` vertically. Diagonal cells (`i == j`) are not
 * plotted.
 *
 * Subclasses supply the component size ([getWidth], [getHeight]) and the [repaint] hook.
 *
 * Created by luc on 25/04/16.
 */
abstract class AbstractScatterPlotMatrixComponent<Row, Column>(
    private val view: ScatterPlotMatrixView<Row, Column>,
) : ScatterPlotMatrixComponent<Row, Column> {
    protected val canvas: CPCanvas = CPCanvas()

    /**
     * The model being displayed.
     *
     * Assigning a value does the following:
     * - detaches this component's listeners from the previous model (if any);
     * - rebuilds the canvas layers for the new model and attaches the listeners to it;
     * - restarts the refresh [timer].
     */
    override var model: ScatterPlotMatrixModel<Row, Column>? = null
        set(value) {
            field?.let { previous ->
                previous.removeScatterPlotMatrixListener(listener)
                previous.getAxisGroupModel().axisOrder?.removeOrderListener(orderListener)
                previous.getAxisGroupModel().removeAxisListener(axisListener)
            }
            field = value
            if (value != null) {
                createOverplots()
                value.addScatterPlotMatrixListener(listener)
                value.getAxisGroupModel().axisOrder?.addOrderListener(orderListener)
                value.getAxisGroupModel().addAxisListener(axisListener)
            }
            timer.restart()
        }

    /**
     * Timer whose trigger refreshes the canvas, but only while the component has a non-empty size.
     * It is restarted on every matrix change.
     * NOTE(review): the `40` and `true` arguments are presumably a delay in ms and a repeat/coalesce flag;
     * confirm against `CPHelper.createTimer`.
     */
    protected val timer: CPTimer = CPHelper.instance.createTimer("ScatterPlotMatrixResizer", 40, true, object : CPTimerListener {
        override fun timerTriggered() {
            if (getWidth() > 0 && getHeight() > 0) {
                refresh()
            }
        }
    })

    /** Restarts the refresh [timer] whenever the matrix model reports a change. */
    private val listener: ScatterPlotMatrixListener = object : ScatterPlotMatrixListener {
        override fun scatterPlotMatrixChanged() {
            timer.restart()
        }
    }

    /** Schedules an update whenever an axis of the axis group changes. */
    private val axisListener: AxisListener = object : AxisListener {
        override fun axisChanged() {
            scheduleUpdate()
        }
    }

    /** Schedules an update whenever the axis order changes, an axis is added or removed, or visibility changes. */
    private val orderListener: OrderListener<AxisModel<Row, Column>> = object : OrderListener<AxisModel<Row, Column>> {
        override fun orderChanged(event: OrderEvent<AxisModel<Row, Column>>?) {
            scheduleUpdate()
        }

        override fun orderVisibility(event: OrderEvent<AxisModel<Row, Column>>) {
            scheduleUpdate()
        }

        override fun orderAdded(event: OrderEvent<AxisModel<Row, Column>>) {
            scheduleUpdate()
        }

        override fun orderRemoved(event: OrderEvent<AxisModel<Row, Column>>) {
            scheduleUpdate()
        }
    }

    /** Redraws the canvas; called by [timer] when the component has a non-empty size. */
    protected open fun refresh() {
        canvas.refresh()
    }

    /** Current component width in pixels. */
    protected abstract fun getWidth(): Int

    /** Current component height in pixels. */
    protected abstract fun getHeight(): Int

    protected abstract fun repaint()

    /**
     * Rebuilds every canvas layer for the current [model]. Does nothing when there is no model.
     *
     * Layers are added bottom to top:
     * 1. background and axis separators;
     * 2. filtered points;
     * 3. visible points;
     * 4. colour-mapped points;
     * 5. coloured points;
     * 6. multi-selected points;
     * 7. the single-selected point;
     * 8. the probed point;
     * 9. the rubber band.
     */
    override fun createOverplots() {
        val current = model ?: return
        canvas.removeAllLayers()
        val visual = current.getVisual()

        val filteredDrawing: IDrawing = object : AbstractVisualLayerIDrawing(visual.filtered) {
            override val isActive: Boolean
                get() = view.getShowFiltered().value && super.isActive

            override fun draw(g: IGraphics, mp: Point2D, row: Row) {
                g.drawPoint(mp)
            }
        }
        val visibleDrawing: IDrawing = object : AbstractVisualLayerIDrawing(visual.visible) {
            override fun draw(g: IGraphics, mp: Point2D, row: Row) {
                fillDot(g, mp, POINT_RADIUS)
            }
        }
        val colorMappedDrawing: IDrawing = object : AbstractVisualLayerIDrawing(visual.colorMapped) {
            override fun draw(g: IGraphics, mp: Point2D, row: Row) {
                // Skip rows without a mapped colour instead of aborting the whole layer.
                val color = current.getColorMapping().getColor(row) ?: return
                g.setColor(color)
                fillDot(g, mp, POINT_RADIUS)
            }
        }
        val coloredDrawing: IDrawing = object : AbstractVisualLayerIDrawing(visual.colored) {
            override fun draw(g: IGraphics, mp: Point2D, row: Row) {
                // Skip rows without an assigned colour instead of aborting the whole layer.
                val color = current.getColoring().getColor(row) ?: return
                g.setColor(color)
                fillDot(g, mp, POINT_RADIUS)
            }
        }
        val singleSelectedDrawing: IDrawing = object : AbstractVisualLayerIDrawing(visual.singleSelected) {
            override fun draw(g: IGraphics, mp: Point2D, row: Row) {
                g.setColor(view.getColorTheme().value.selection)
                fillDot(g, mp, POINT_RADIUS)
            }
        }
        val multiSelectedDrawing: IDrawing = object : AbstractVisualLayerIDrawing(visual.multipleSelected) {
            override fun draw(g: IGraphics, mp: Point2D, row: Row) {
                g.setColor(view.getColorTheme().value.selection)
                fillDot(g, mp, POINT_RADIUS)
            }
        }
        val probedDrawing: IDrawing = object : AbstractVisualLayerIDrawing(visual.probed) {
            override fun draw(g: IGraphics, mp: Point2D, row: Row) {
                g.setColor(view.getColorTheme().value.probing)
                fillDot(g, mp, PROBE_RADIUS)
                // A probed row that is also selected gets a smaller selection-coloured dot on top.
                if (current.getSelection().isSelected(row)) {
                    g.setColor(view.getColorTheme().value.selection)
                    fillDot(g, mp, POINT_RADIUS)
                }
            }
        }

        // Background plus the separator lines drawn at every axis location.
        canvas.addLayer(object : AbstractIDrawing() {
            override fun draw(g: IGraphics, point: Point2D?, width: Double, height: Double, clipBounds: Rectangle) {
                g.setColor(view.getColorTheme().value.background)
                g.fillRectangle(0, 0, getWidth(), getHeight())
                val drawnModel = model ?: return
                val axisCount = drawnModel.getAxisCount()
                if (axisCount == 0) return
                g.setColor(view.getColorTheme().value.foreground)
                g.setLineWidth(SEPARATOR_WIDTH)
                for (i in 0 until axisCount) {
                    val location = drawnModel.getLocation(drawnModel.getAxisModel(i))
                    val x = (location * getWidth()).toInt()
                    val y = (location * getHeight()).toInt()
                    g.drawLine(x, 0, x, getHeight())
                    g.drawLine(0, y, getWidth(), y)
                }
            }
        })

        val rendering = view.getRendering().value
        canvas.addDensityLayer(rendering, filteredDrawing, LogPressure(), object : PaletteProvider {
            override val palette: FixedPalette
                get() = view.getColorTheme().value.ghostedPalette
        })
        canvas.addDensityLayer(rendering, visibleDrawing, LogPressure(), object : PaletteProvider {
            override val palette: FixedPalette
                get() = view.getColorTheme().value.visiblePalette
        })
        canvas.addAveragingLayer(rendering, colorMappedDrawing)
        canvas.addBufferedLayer(coloredDrawing)
        canvas.addDensityLayer(rendering, multiSelectedDrawing, LogPressure(), object : PaletteProvider {
            override val palette: FixedPalette
                get() = view.getColorTheme().value.selectedPalette
        })
        canvas.addLayer(singleSelectedDrawing)
        canvas.addLayer(probedDrawing)
        canvas.addLayer(object : RubberbandDrawing(view.getRubberBand()) {
            override val colorTheme: Property<ColorTheme>
                get() = view.getColorTheme()
        })
    }

    /**
     * Returns the active row plotted nearest to pixel ([x], [y]).
     *
     * Only the first off-diagonal cell containing the pixel (edges inclusive) is searched. Returns `null`
     * in these cases:
     * - there is no model;
     * - the pixel lies in no off-diagonal cell;
     * - no active row has a plottable point in that cell.
     */
    override fun getClosestRow(x: Int, y: Int): Any? {
        val current = model ?: return null
        val px = x.toDouble()
        val py = y.toDouble()
        val cell = plotCells(current).firstOrNull { it.contains(px, py) } ?: return null
        var closest: Row? = null
        var minDistance = Double.MAX_VALUE
        val visible: VisualLayer<Row> = current.getVisual().active
        for (row in visible) {
            val mp = cell.pointOf(AxisModel.DATA, row) ?: continue
            val distance: Double = Point2D.distance(px, py, mp.x, mp.y)
            if (distance < minDistance) {
                closest = row
                minDistance = distance
            }
        }
        return closest
    }

    /**
     * Returns the active rows whose plotted point lies inside [rect].
     *
     * The cell is chosen by the top-left corner of [rect]: only the first off-diagonal cell containing
     * (`rect.x`, `rect.y`) is searched. The result is empty when there is no model or no such cell.
     */
    override fun getRows(rect: Rectangle2D): List<Row> {
        val rows: MutableList<Row> = ArrayList<Row>()
        val current = model ?: return rows
        val cell = plotCells(current).firstOrNull { it.contains(rect.x, rect.y) } ?: return rows
        val visible: VisualLayer<Row> = current.getVisual().active
        for (row in visible) {
            val mp = cell.pointOf(AxisModel.DATA, row)
            if (mp != null && rect.contains(mp)) {
                rows.add(row)
            }
        }
        return rows
    }

    /**
     * Maps [row] to pixel coordinates inside the cell whose top-left corner is ([x], [y]) and whose size is
     * [width] x [height]. The cell is shrunk by a fixed inset on every side before mapping.
     *
     * Each value is scaled linearly between its axis' `minimum` and `maximum`. The y axis grows upwards.
     * Returns `null` when either axis has no value for [row] on [layer].
     * NOTE(review): an axis with `maximum == minimum` yields NaN/infinite coordinates.
     */
    fun getPoint(layer: Int, row: Row, xAxisModel: AxisModel<Row, *>, yAxisModel: AxisModel<Row, *>, x: Int, y: Int, width: Int, height: Int): Point2D? {
        val xValue: Number? = xAxisModel.getValue(layer, row)
        val yValue: Number? = yAxisModel.getValue(layer, row)
        if (xValue == null || yValue == null) return null

        val left = x + POINT_INSET
        val top = y + POINT_INSET
        val innerWidth = width - 2 * POINT_INSET
        val innerHeight = height - 2 * POINT_INSET
        val xFraction = (xValue.toDouble() - xAxisModel.minimum) / (xAxisModel.maximum - xAxisModel.minimum)
        val yFraction = (yValue.toDouble() - yAxisModel.minimum) / (yAxisModel.maximum - yAxisModel.minimum)
        return Point2D.Double(
            left + ((innerWidth - 1) * xFraction),
            top + innerHeight - ((innerHeight - 1) * yFraction),
        )
    }

    /** Fills a circle of [radius] pixels centred on [center]. */
    private fun fillDot(g: IGraphics, center: Point2D, radius: Double) {
        g.fillCircle(Rectangle2D.Double(center.x - radius, center.y - radius, radius * 2.0, radius * 2.0))
    }

    /**
     * Computes the pixel bounds of every off-diagonal cell for [current] at the current component size.
     *
     * Cells are ordered by x-axis index, then y-axis index. The shared hit-testing and drawing code relies
     * on that order.
     */
    private fun plotCells(current: ScatterPlotMatrixModel<Row, Column>): List<PlotCell> {
        val axisCount = current.getAxisCount()
        val totalWidth = getWidth()
        val totalHeight = getHeight()
        // Start of each axis band as a fraction of the component size; the last band ends at 1.0.
        val starts = DoubleArray(axisCount) { current.getLocation(current.getAxisModel(it)) }
        fun endOf(index: Int): Double = if (index + 1 < axisCount) starts[index + 1] else 1.0

        val cells = ArrayList<PlotCell>()
        for (i in 0 until axisCount) {
            val xAxisModel: AxisModel<Row, Column> = current.getAxisModel(i)
            val x1 = (starts[i] * totalWidth).toInt()
            val x2 = (endOf(i) * totalWidth).toInt()
            for (j in 0 until axisCount) {
                if (i == j) continue
                val yAxisModel: AxisModel<Row, Column> = current.getAxisModel(j)
                val y1 = (starts[j] * totalHeight).toInt()
                val y2 = (endOf(j) * totalHeight).toInt()
                cells.add(PlotCell(xAxisModel, yAxisModel, x1, y1, x2 - x1, y2 - y1))
            }
        }
        return cells
    }

    /** Pixel bounds of one off-diagonal matrix cell together with the two axes it plots. */
    private inner class PlotCell(
        val xAxisModel: AxisModel<Row, Column>,
        val yAxisModel: AxisModel<Row, Column>,
        val x: Int,
        val y: Int,
        val width: Int,
        val height: Int,
    ) {
        /** Whether ([px], [py]) lies inside this cell, edges included. */
        fun contains(px: Double, py: Double): Boolean =
            px >= x && px <= x + width && py >= y && py <= y + height

        /** Pixel position of [row] in this cell, or `null` when it has no value on either axis. */
        fun pointOf(layer: Int, row: Row): Point2D? =
            getPoint(layer, row, xAxisModel, yAxisModel, x, y, width, height)
    }

    /**
     * Drawing that renders every row of a [VisualLayer] into each off-diagonal cell.
     *
     * The layer reports itself as changed whenever [visualLayer] notifies a change. Subclasses draw a single
     * mapped point via [draw] (`g`, `p`, `row`).
     */
    abstract inner class AbstractVisualLayerIDrawing protected constructor(
        private val visualLayer: VisualLayer<Row>,
    ) : AbstractIDrawing() {
        override val isActive: Boolean
            get() = visualLayer.isActive

        init {
            visualLayer.addVisualListener(object : VisualListener {
                override fun visualChanged() {
                    notifyIDrawingChanged()
                }
            })
        }

        /**
         * Draws all rows of the layer.
         *
         * Rows are split across tasks on the shared visualization executor when all of these hold:
         * - multithreading is enabled;
         * - [g] reports itself thread safe;
         * - an executor is available.
         *
         * Otherwise rows are drawn sequentially on the calling thread. Does nothing without a model.
         */
        override fun draw(g: IGraphics, point: Point2D?, width: Double, height: Double, clipBounds: Rectangle) {
            val rowCount = visualLayer.objectCount
            if (rowCount <= 0) return
            val current = model ?: return
            // Layout is computed once per layer draw and shared (read-only) by all rows and tasks.
            val cells = plotCells(current)
            val pool = executor
            if (!MULTITHREADED || !g.isThreadSafe() || pool == null) {
                for (row in visualLayer) {
                    drawRow(g, row, cells)
                }
            } else {
                drawInParallel(g, cells, rowCount, pool)
            }
        }

        /**
         * Splits the rows into contiguous ranges, one per task, and waits for all tasks to finish.
         *
         * Failures are printed rather than propagated, so one failing task does not abort the rest.
         */
        private fun drawInParallel(g: IGraphics, cells: List<PlotCell>, rowCount: Int, pool: ExecutorService) {
            // Guard against a runtime reporting zero processors, which would divide by zero below.
            val nTasks: Int = min(Runtime.getRuntime().availableProcessors(), rowCount).coerceAtLeast(1)
            val nRowsPerTask: Int = rowCount / nTasks
            val todo: MutableList<Callable<Any?>> = ArrayList<Callable<Any?>>(nTasks)
            for (nTask in 0 until nTasks) {
                val fromRow = nTask * nRowsPerTask
                // The last task also takes the remainder rows.
                val toRow = if (nTask < nTasks - 1) fromRow + nRowsPerTask else rowCount
                todo.add(object : Callable<Any?> {
                    @Throws(Exception::class)
                    override fun call(): Any? {
                        for (row in visualLayer.iterable(fromRow, toRow - 1)) {
                            drawRow(g, row, cells)
                        }
                        return null
                    }
                })
            }
            try {
                val answers: List<Future<Any?>> = pool.invokeAll(todo)
                for (answer in answers) {
                    try {
                        answer.get()
                    } catch (e: Exception) {
                        e.printStackTrace()
                    }
                }
            } catch (e: Exception) {
                e.printStackTrace()
            }
        }

        /** Draws [row] into every off-diagonal cell of the current model's layout. Does nothing without a model. */
        protected fun draw(g: IGraphics, row: Row) {
            val current = model ?: return
            drawRow(g, row, plotCells(current))
        }

        /** Draws [row] into each of [cells] where both of the cell's axes have a value for it. */
        private fun drawRow(g: IGraphics, row: Row, cells: List<PlotCell>) {
            val layer = getLayer()
            for (cell in cells) {
                val mp = cell.pointOf(layer, row) ?: continue
                draw(g, mp, row)
            }
        }

        /** The axis value layer used to position points; always [AxisModel.DATA]. */
        fun getLayer(): Int {
            return AxisModel.DATA
        }

        /** Draws a single point for [row] at pixel position [p]. */
        abstract fun draw(g: IGraphics, p: Point2D, row: Row)
    }

    companion object {
        private const val MULTITHREADED = true

        /** Inset, in pixels, between a cell's bounds and its plotting area. */
        private const val POINT_INSET = 5

        /** Radius, in pixels, of regular and selected points. */
        private const val POINT_RADIUS = 2.0

        /** Radius, in pixels, of the probed point. */
        private const val PROBE_RADIUS = 3.0

        /** Line width of the separators drawn at each axis location. */
        private const val SEPARATOR_WIDTH = 3.0

        private val executor: ExecutorService? = CPHelper.instance.visualizationExecutorService()
    }
}