keras/layers/core/reshape.js

const lodashDefaults = require('lodash').defaults;

const Layer = require('../layer');
const utils = require('../../../utils/utils');

/**
 * Reshape layer used to reshape a tensor.
 * @extends mentality.keras.layers.Layer
 * @memberof mentality.keras.layers
 */
class Reshape extends Layer {
  /**
   * Constructor
   * @param  {Object}             args              Properties of the reshape layer.
   * @param  {Number[]}           args.targetShape  Target shape, without the leading dimension
   *                                                (that one is taken from the input's shape).
   * @param  {Layer | undefined}  input             Input layer.
   * @return {Reshape}
   */
  constructor(args = {}, input) {
    super(args);

    // Store a copy so the layer never aliases the array owned by the caller.
    this.targetShape = Array.isArray(args.targetShape) ? [...args.targetShape] : args.targetShape;

    this.setInput(input);
  }

  /**
   * Set a layer as the input of this layer. Validates that reshaping the input
   * to the target shape is achievable: ignoring the first dimension, both shapes
   * must contain the same number of elements.
   *
   * On success `this.targetShape` becomes the requested shape prefixed with the
   * first dimension of the input's shape. The requested shape is remembered in
   * `this.baseTargetShape`, so attaching an input again replaces that leading
   * dimension instead of prepending another one. A rejected input leaves the
   * layer unchanged.
   * @param  {Layer | undefined} input  Input layer. Nothing happens when it is falsy.
   * @throws {TypeError} If no target shape array is available.
   * @throws {Error}     If the target shape is not compatible with the input shape.
   */
  setInput(input) {
    if (!input) return;

    // After a successful call `this.targetShape` already carries the leading
    // dimension, so later calls must start from the remembered requested shape.
    const requestedShape = this.baseTargetShape === undefined ? this.targetShape : this.baseTargetShape;

    if (!Array.isArray(requestedShape)) {
      throw new TypeError('Reshape layer requires `targetShape` to be an array of dimensions.');
    }

    const countElements = (dims) => dims.reduce((total, dim) => total * dim, 1);

    // Built as a fresh array: neither the caller's array nor the current state
    // is modified until the validation below has passed.
    const outputShape = [input.shape[0], ...requestedShape];

    if (countElements(input.shape.slice(1)) !== countElements(requestedShape)) {
      throw new Error(`Target shape is not compatible with input shape. Got ${utils.toString(input.shape)} and ${utils.toString(outputShape)}.`);
    }

    this.baseTargetShape = [...requestedShape];
    this.targetShape = outputShape;

    super.setInput(input);
  }

  /**
   * Compute output shape.
   * @return {Number[]} Output shape, i.e. the current value of `this.targetShape`.
   */
  computeOutputShape() {
    return this.targetShape;
  }

  /**
   * Build layer: emits one `tf.reshape(...)` assignment followed by a blank line.
   * Reads `this.input`, which this class never assigns itself (presumably it is
   * set by `Layer#setInput` - confirm against the base class).
   * @param  {Writer} writer Writer object used to build.
   * @param  {Object} opts   Options (currently unused).
   */
  build(writer, opts = {}) {
    const line = `${this.getName()} = tf.reshape(tensor=${this.input.getName()}, shape=${utils.toString(this.targetShape)})`;

    writer.emitLine(line);
    writer.emitNewline();
  }

  /**
   * Export layer as JSON.
   * @param  {Object} opts  Options (currently unused).
   * @return {Object}       Layer properties as JSON: the base layer's JSON plus `targetShape`.
   */
  toJson(opts = {}) {
    // `defaults` only fills in missing keys, so `targetShape` from this layer
    // wins over any `targetShape` key in the base layer's JSON.
    return lodashDefaults({
      targetShape: this.targetShape,
    }, super.toJson());
  }

  /**
   * Get neurons in this layer.
   * @return {Number}   Number of neurons (always 0 for this layer).
   */
  countNeurons() {
    return 0;
  }

  /**
   * Get connections in this layer.
   * @return {Number}   Number of weights (always 0 for this layer).
   */
  countWeights() {
    return 0;
  }
}

module.exports = Reshape;