import {
  matrixMultiplication,
  matrixElementwiseOperation,
  linear,
  sigmoid,
  tanh,
  relu,
  elu,
  softmax,
} from './matrix-math';
import { policyMapping } from './policy-mapping';

// Lookup table from the activation names accepted in
// `metadata.activationFunction` / `metadata.outputActivation` to the
// implementations imported from ./matrix-math.
const activationMapping = {
  relu,
  elu,
  sigmoid,
  tanh,
  linear,
  softmax,
};

/**
 * Fully connected feed-forward network whose parameters are supplied as flat arrays.
 *
 * Layer widths run nFeatures -> neurons[0] -> ... -> neurons[k-1] -> nActions.
 * After construction, `this.weights[layer]` is a row-major [inDim][outDim]
 * matrix and `this.biases[layer]` is a 1 x outDim matrix.
 */
export class neuralNetwork {
  /**
   * @param {number} nFeatures - Width of the input layer.
   * @param {number[]} neurons - Hidden-layer widths; may be empty (single input->output layer).
   * @param {number} nActions - Width of the output layer.
   * @param {number[]} flattenedWeights - All weights, layer by layer, each layer row-major [inDim][outDim].
   * @param {number[]} flattenedBiases - All biases, layer by layer, outDim values per layer.
   * @param {Object} metadata - Reads `activationFunction`, `outputActivation` and `policyMethod`.
   * @throws {Error} If the flattened arrays hold fewer values than the layer widths require.
   */
  constructor(nFeatures, neurons, nActions, flattenedWeights, flattenedBiases, metadata) {
    this.nFeatures = nFeatures;
    this.neurons = neurons;
    this.nActions = nActions;
    this.useActivation = true;
    this.usePrior = true; // not read anywhere in this class

    this.metadata = metadata;

    // Only own keys of activationMapping count, so names such as "toString"
    // fall back instead of resolving to inherited Object.prototype members.
    const lookupActivation = (name, fallback) =>
      Object.prototype.hasOwnProperty.call(activationMapping, name)
        ? activationMapping[name]
        : fallback;
    this.activationFunction = lookupActivation(this.metadata.activationFunction, relu);
    this.outputActivation = lookupActivation(this.metadata.outputActivation, softmax);

    this.flattenedWeights = flattenedWeights;
    this.flattenedBiases = flattenedBiases;
    this.reshapeParameters(this.flattenedWeights, this.flattenedBiases);
  }

  /**
   * Slice the flat parameter arrays into per-layer matrices and store them on
   * `this.weights` / `this.biases`.
   *
   * @param {number[]} flattenedWeights - Layer by layer, row-major [inDim][outDim].
   * @param {number[]} flattenedBiases - Layer by layer, outDim values each.
   * @throws {Error} If either array is too short for the configured layer widths.
   */
  reshapeParameters(flattenedWeights, flattenedBiases) {
    const outputLayer = this.neurons.length;
    const weightMatrices = [];
    const biasMatrices = [];
    let weightIdx = 0;
    let biasIdx = 0;

    for (let layer = 0; layer <= outputLayer; layer++) {
      // Resolve each side on its own so that with no hidden layers
      // (neurons = []) layer 0 maps nFeatures straight to nActions.
      const inDim = layer === 0 ? this.nFeatures : this.neurons[layer - 1];
      const outDim = layer === outputLayer ? this.nActions : this.neurons[layer];

      if (weightIdx + inDim * outDim > flattenedWeights.length) {
        throw new Error(
          `Expected at least ${weightIdx + inDim * outDim} weights for layer ${layer}, got ${flattenedWeights.length}`
        );
      }
      if (biasIdx + outDim > flattenedBiases.length) {
        throw new Error(
          `Expected at least ${biasIdx + outDim} biases for layer ${layer}, got ${flattenedBiases.length}`
        );
      }

      const matrix = [];
      for (let row = 0; row < inDim; row++) {
        matrix.push(flattenedWeights.slice(weightIdx, weightIdx + outDim));
        weightIdx += outDim;
      }
      weightMatrices.push(matrix);
      biasMatrices.push([flattenedBiases.slice(biasIdx, biasIdx + outDim)]);
      biasIdx += outDim;
    }

    this.weights = weightMatrices;
    this.biases = biasMatrices;
  }

  /**
   * Run the network forward.
   *
   * @param {number[][]} inputs - Assumed to be rows of nFeatures values (one row per sample) — confirm against callers.
   * @returns {[any, Object]} The final layer value, and a cache holding
   *   `z<layer>` (pre-activation), `a<layer>` (hidden activations) and `y`
   *   (output activation); only `z*` entries are written when `useActivation` is false.
   */
  forwardProp(inputs) {
    const cache = {};
    const outputLayer = this.neurons.length;
    let current = inputs;

    for (let layer = 0; layer <= outputLayer; layer++) {
      const product = matrixMultiplication(current, this.weights[layer]);
      // Tile the 1 x outDim bias once per row of the product. The row count
      // is the number of samples, not this layer's input width.
      const biasRows = Array(product.length).fill(this.biases[layer][0]);
      current = matrixElementwiseOperation(product, biasRows, "Addition");
      cache[`z${layer}`] = current;

      if (!this.useActivation) {
        continue;
      }
      // Second argument `false` is presumably a "derivative" flag in matrix-math — confirm.
      if (layer < outputLayer) {
        current = this.activationFunction(current, false);
        cache[`a${layer}`] = current;
      } else {
        current = this.outputActivation(current, false);
        cache["y"] = current;
      }
    }
    return [current, cache];
  }

  /**
   * Choose an action by passing the network output to the policy named by
   * `metadata.policyMethod`.
   *
   * @param {number[][]} inputs - Same shape expected by forwardProp.
   * @returns {*} Whatever the selected policy function returns.
   * @throws {Error} If `metadata.policyMethod` does not name a function in policyMapping.
   */
  selectAction(inputs) {
    const { policyMethod } = this.metadata;
    const policy = policyMapping[policyMethod];
    if (typeof policy !== "function") {
      throw new Error(`Unknown policy method: ${policyMethod}`);
    }
    const [actionProbabilities] = this.forwardProp(inputs);
    return policy(actionProbabilities);
  }
}
