import { box_muller } from './rand.js';
import * as d3 from 'd3'

/**
 * Standard matrix product A · B.
 *
 * The result has A.length rows and B[0].length columns; entry (i, j) is the
 * dot product of row i of A with column j of B.
 *
 * @param {number[][]} A - Left operand.
 * @param {number[][]} B - Right operand (must have A[i].length rows).
 * @returns {number[][]} The product matrix ([] when A is empty).
 */
const matrixMultiplication = (A, B) => {
  const product = []
  for (let i = 0; i < A.length; i++) {
    const leftRow = A[i]
    const width = B[0].length
    const outRow = []
    for (let j = 0; j < width; j++) {
      let acc = 0
      for (let k = 0; k < leftRow.length; k++) {
        acc = acc + leftRow[k] * B[k][j]
      }
      outRow.push(acc)
    }
    product.push(outRow)
  }
  return product
}

// Binary operators supported by matrixElementwiseOperation, keyed by the
// operation name callers pass in.
const ELEMENTWISE_OPERATORS = new Map([
  ["Addition", (a, b) => a + b],
  ["Subtraction", (a, b) => a - b],
  ["Multiplication", (a, b) => a * b],
  ["Division", (a, b) => a / b],
])

/**
 * Apply a named binary operation element by element to two matrices.
 *
 * If exactly one operand is a primitive number it is broadcast to the other
 * operand's shape. If both are numbers, the scalar result is returned. The
 * result shape is that of the matrix operand (A unless A is the scalar). The
 * column count is taken from its first row.
 *
 * @param {number[][]|number} A - Left operand.
 * @param {number[][]|number} B - Right operand.
 * @param {"Addition"|"Subtraction"|"Multiplication"|"Division"} operation
 * @returns {number[][]|number} Elementwise result.
 * @throws {Error} If `operation` is not one of the supported names.
 */
const matrixElementwiseOperation = (A, B, operation) => {
  const apply = ELEMENTWISE_OPERATORS.get(operation)
  if (apply === undefined) {
    // Fail loudly: an unrecognised name would otherwise yield a meaningless matrix.
    throw new Error(`Unknown elementwise operation: ${operation}`)
  }
  const aIsScalar = typeof A === "number"
  const bIsScalar = typeof B === "number"
  if (aIsScalar && bIsScalar) {
    return apply(A, B)
  }
  const shape = aIsScalar ? B : A
  const rows = shape.length
  // Only touch shape[0] when there is a first row to read.
  const cols = rows > 0 ? shape[0].length : 0
  return Array.from({ length: rows }, (_, i) =>
    Array.from({ length: cols }, (_, j) =>
      apply(aIsScalar ? A : A[i][j], bIsScalar ? B : B[i][j])
    )
  )
}

/**
 * Map a unary function over every element of a matrix.
 *
 * When `additionalInputs` is provided, its items are spread as extra arguments
 * after the element on every call: func(value, ...additionalInputs). Calls are
 * made in row-major order. The column count comes from the first row.
 *
 * @param {Array<Array<*>>} matrix - Input matrix.
 * @param {Function} func - Function applied to each element.
 * @param {Iterable<*>} [additionalInputs] - Extra arguments passed to func.
 * @returns {Array<Array<*>>} A new matrix of func's results.
 */
const matrixElementwiseFunction = (matrix, func, additionalInputs) => {
  const extraArgs = additionalInputs === undefined ? [] : additionalInputs
  const rowCount = matrix.length
  if (rowCount === 0) {
    return []
  }
  const colCount = matrix[0].length
  return Array.from({ length: rowCount }, (_, r) =>
    Array.from({ length: colCount }, (_, c) => func(matrix[r][c], ...extraArgs))
  )
}

/**
 * Normalize each row (instance) of `layer` independently to zero mean and
 * unit standard deviation, using d3.mean and d3.deviation.
 *
 * d3.deviation is the bias-corrected (sample) standard deviation. It is
 * undefined for fewer than two values. Rows with no spread to divide by
 * (deviation 0 or undefined) are mapped to all zeros instead of NaN.
 *
 * @param {number[][]} layer - Rows to normalize.
 * @returns {number[][]} Normalized rows, same shape as `layer`.
 */
const layerNormalization = (layer) => {
  return layer.map((instance) => {
    const mean = d3.mean(instance)
    const std = d3.deviation(instance)
    // Dividing by 0 or undefined would produce NaN for the whole row.
    if (!std) {
      return instance.map(() => 0)
    }
    return instance.map((x) => (x - mean) / std)
  })
}

/**
 * Keep only the feature columns listed in `targetedInputs`; every other
 * column is replaced with 0.
 *
 * Each column is tested independently. Targeted columns are kept wherever
 * they appear, not only when they form a leading prefix.
 *
 * @param {number[][]} features - Rows of feature values.
 * @param {number[]} targetedInputs - Column indices to keep.
 * @returns {number[][]} A new matrix with non-targeted columns zeroed.
 */
const featureDropout = (features, targetedInputs) => {
  // Set lookup avoids an O(columns) includes() scan per element.
  const kept = new Set(targetedInputs)
  return features.map((row) => row.map((value, j) => (kept.has(j) ? value : 0)))
}

/**
 * Transpose a matrix: column c of the input becomes row c of the output.
 *
 * The number of output rows is the length of the first input row. An empty
 * input matrix throws a TypeError because it has no first row.
 *
 * @param {Array<Array<*>>} matrix - Input matrix.
 * @returns {Array<Array<*>>} The transposed matrix.
 */
const transposeMatrix = (matrix) => {
  const columnCount = matrix[0].length
  const transposed = []
  for (let col = 0; col < columnCount; col++) {
    const column = []
    for (let r = 0; r < matrix.length; r++) {
      column.push(matrix[r][col])
    }
    transposed.push(column)
  }
  return transposed
}

/**
 * Build an inputDim × outputDim weight matrix, filling each entry with
 * box_muller(0, 2 / (inputDim + outputDim)) in row-major order.
 *
 * NOTE(review): whether box_muller's second argument is a standard deviation
 * or a variance is defined in ./rand.js. Confirm it matches the intended
 * Glorot-style scaling.
 *
 * @param {number} inputDim - Number of rows.
 * @param {number} outputDim - Number of columns.
 * @returns {number[][]} The random matrix. If outputDim is not positive, no
 *   rows are produced and the result is [].
 */
const generateRandomMatrices = (inputDim, outputDim) => {
  const initScale = 2 / (inputDim + outputDim)
  // A row is only created once it receives its first entry, so a non-positive
  // column count yields no rows at all.
  if (!(outputDim > 0)) {
    return []
  }
  const weights = []
  for (let row = 0; row < inputDim; row++) {
    const values = []
    for (let col = 0; col < outputDim; col++) {
      values.push(box_muller(0, initScale))
    }
    weights.push(values)
  }
  return weights
}

/**
 * Identity (linear) activation.
 *
 * @param {number[][]} x - Pre-activation matrix.
 * @param {boolean} derivativeBool - When truthy, return the derivative instead.
 * @returns {number[][]} `x` itself for the forward pass. For the derivative,
 *   a new matrix of ones with the same shape, since d/dx x = 1. This matches
 *   how relu/sigmoid return elementwise derivatives of their input.
 */
const linear = (x, derivativeBool) => {
  if (derivativeBool) {
    return x.map((row) => row.map(() => 1))
  }
  return x
}

/**
 * Logistic sigmoid of one number, or its derivative s · (1 − s).
 *
 * @param {number} x - Input value.
 * @param {boolean} derivativeBool - When truthy, return the derivative.
 * @returns {number}
 */
const sigmoidScalar = (x, derivativeBool) => {
  const s = 1 / (1 + Math.exp(-x))
  return derivativeBool ? s * (1 - s) : s
}

/**
 * Elementwise sigmoid (or its derivative) over a matrix.
 *
 * @param {number[][]} x - Input matrix.
 * @param {boolean} derivativeBool - When truthy, return derivatives.
 * @returns {number[][]} New matrix of the same shape.
 */
const sigmoid = (x, derivativeBool) =>
  x.map((row) => row.map((value) => sigmoidScalar(value, derivativeBool)))

/**
 * Hyperbolic tangent of one number, or its derivative 1 − tanh(x)².
 *
 * Uses Math.tanh, which saturates cleanly to ±1. The explicit
 * (e^x − e^−x)/(e^x + e^−x) form overflows to Infinity/Infinity = NaN for
 * |x| beyond roughly 710.
 *
 * @param {number} x - Input value.
 * @param {boolean} derivativeBool - When truthy, return the derivative.
 * @returns {number}
 */
const tanhScalar = (x, derivativeBool) => {
  const t = Math.tanh(x)
  return derivativeBool ? 1 - t ** 2 : t
}

/**
 * Elementwise tanh (or its derivative) over a matrix.
 *
 * @param {number[][]} x - Input matrix.
 * @param {boolean} derivativeBool - When truthy, return derivatives.
 * @returns {number[][]} New matrix of the same shape.
 */
const tanh = (x, derivativeBool) =>
  x.map((row) => row.map((value) => tanhScalar(value, derivativeBool)))

/**
 * Elementwise ReLU, max(x, 0), or its derivative (1 where x > 0, else 0).
 *
 * @param {number[][]} x - Input matrix.
 * @param {boolean} derivativeBool - When truthy, return derivatives.
 * @returns {number[][]} New matrix of the same shape.
 */
const relu = (x, derivativeBool) => {
  const transform = derivativeBool
    ? (value) => (value > 0 ? 1 : 0)
    : (value) => (value > 0 ? value : 0)
  return x.map((row) => row.map((value) => transform(value)))
}

/**
 * Elementwise ELU with alpha = 1, or its derivative.
 *
 * Forward:    x for x > 0, alpha · (e^x − 1) otherwise.
 * Derivative: 1 for x > 0, alpha · (e^x − 1) + alpha otherwise.
 *
 * @param {number[][]} x - Input matrix.
 * @param {boolean} derivativeBool - When truthy, return derivatives.
 * @returns {number[][]} New matrix of the same shape.
 */
const elu = (x, derivativeBool) => {
  const alpha = 1
  const forward = (v) => (v > 0 ? v : alpha * (Math.exp(v) - 1))
  const gradient = (v) => (v > 0 ? 1 : alpha * (Math.exp(v) - 1) + alpha)
  const pick = derivativeBool ? gradient : forward
  return x.map((row) => row.map((v) => pick(v)))
}

/**
 * Row-wise softmax, or the diagonal of its Jacobian.
 *
 * Each row is shifted by its maximum (d3.max) before exponentiating, which
 * avoids overflow without changing the result.
 *
 * With `derivativeBool` truthy, returns s_i · (1 − s_i) per element. This is
 * the diagonal of the softmax Jacobian, analogous to the elementwise
 * derivatives of the other activations. The off-diagonal terms −s_i · s_j are
 * not included. When softmax feeds categorical cross-entropy, the combined
 * gradient with respect to the logits is simply (probabilities − labels).
 *
 * @param {number[][]} x - Logits, one row per instance.
 * @param {boolean} derivativeBool - When truthy, return the Jacobian diagonal.
 * @returns {number[][]} New matrix of the same shape.
 */
const softmax = (x, derivativeBool) => {
  return x.map((currentInstance) => {
    const xMax = d3.max(currentInstance)
    const stabilized = currentInstance.map((x_i) => x_i - xMax)
    const denom = stabilized.reduce((total, current) => total + Math.exp(current), 0)
    const probabilities = stabilized.map((x_i) => Math.exp(x_i) / denom)
    if (derivativeBool) {
      return probabilities.map((p) => p * (1 - p))
    }
    return probabilities
  })
}

/**
 * Convert each row to a one-hot vector marking its largest value (argmax).
 *
 * Ties resolve to the first maximal index. The search starts from -Infinity,
 * so rows whose values are all ≤ 0 (e.g. raw logits) still select their true
 * maximum. All rows use the width of the first row.
 *
 * @param {number[][]} x - Input matrix.
 * @returns {number[][]} One-hot matrix of the same shape ([] for empty input).
 */
const valuesToOneHot = (x) => {
  if (x.length === 0) {
    return []
  }
  const width = x[0].length
  return x.map((row) => {
    const oneHot = new Array(width).fill(0)
    let maxIdx = 0
    let maxNum = -Infinity
    for (let j = 0; j < width; j++) {
      if (row[j] > maxNum) {
        maxIdx = j
        maxNum = row[j]
      }
    }
    // Zero-width rows have no position to mark.
    if (width > 0) {
      oneHot[maxIdx] = 1
    }
    return oneHot
  })
}

/**
 * Convert a list of class indices into one-hot rows.
 *
 * @param {number[]} indices - Class index for each row.
 * @param {number} numOptions - Length of each one-hot row.
 * @returns {number[][]} indices.length rows of numOptions zeros, each with a 1
 *   at its index.
 */
const idxToOnehot = (indices, numOptions) =>
  Array.from({ length: indices.length }, (_, i) => {
    const row = new Array(numOptions).fill(0)
    row[indices[i]] = 1
    return row
  })

/**
 * Mean categorical cross-entropy over a batch:
 * −(1/N) · Σ_i Σ_j labels[i][j] · log(probabilities[i][j] + EPSILON).
 *
 * EPSILON keeps log() finite when a predicted probability is exactly 0.
 *
 * @param {number[][]} probabilities - Predicted probabilities, one row per instance.
 * @param {number[][]} labels - Target rows of the same shape (e.g. one-hot).
 * @returns {number} Average loss per row.
 */
const getCategoricalCrossentropy = (probabilities, labels) => {
  const EPSILON = 0.0001
  const shifted = matrixElementwiseOperation(probabilities, EPSILON, "Addition")
  const logPredictions = matrixElementwiseFunction(shifted, Math.log)
  const weighted = matrixElementwiseOperation(labels, logPredictions, "Multiplication")
  let total = 0
  for (const row of weighted) {
    let rowSum = 0
    for (const term of row) {
      rowSum += term
    }
    total += rowSum
  }
  return -total / weighted.length
}

/**
 * Flatten a list of parameter matrices (layers of rows) into one flat array,
 * layer by layer, row by row.
 *
 * Exactly two levels are flattened (layers, then rows), the same as
 * concatenating every row in order. flat(2) does this in a single linear pass
 * instead of re-copying the growing result on every concat (O(n²)).
 *
 * @param {Array<Array<Array<*>>>} parameters - Layers, each a list of rows.
 * @returns {Array<*>} All row entries in order.
 */
const flattenParameters = (parameters) => {
  return parameters.flat(2)
}

// Public API of this module: matrix helpers, activations (each taking
// (x, derivativeBool)), one-hot conversion, loss, and parameter flattening.
export {
  matrixMultiplication,
  matrixElementwiseOperation,
  matrixElementwiseFunction,
  transposeMatrix,
  generateRandomMatrices,
  featureDropout,
  layerNormalization,
  linear,
  sigmoid,
  tanh,
  relu,
  elu,
  softmax,
  valuesToOneHot,
  idxToOnehot,
  getCategoricalCrossentropy,
  flattenParameters
}
