programs/tensorkit.js

const FileWriter = require('../writers/filewriter');

/**
 * Program that emits a TensorKit architecture file: a Python module that
 * defines `class Architect(ArchitectBase)`, whose `build_graph` body is
 * produced by compiling the attached graph.
 * @memberof mentality.programs
 */
class TensorKitProgram {
  /**
   * @param {Object} [args={}] - Construction options.
   * @param {Object} [args.writerOpts={}] - Options forwarded to the default
   *   `FileWriter` when no `writer` is supplied.
   * @param {Object} [args.writer] - Writer used for all output; defaults to
   *   `new FileWriter(writerOpts)`.
   */
  constructor(args = {}) {
    const {
      writerOpts = {},
      writer = new FileWriter(writerOpts),
    } = args;

    this.writer = writer;
  }

  /**
   * Attach the graph whose compilation fills in `build_graph`.
   * @param {Object} graph - Must provide `compile(opts)` and a `children`
   *   array whose elements expose `getName()`.
   */
  setGraph(graph) {
    this.graph = graph;
  }

  /**
   * Write the complete architecture file through `this.writer`.
   *
   * The output contains the Python header imports, the `Architect` class, and
   * a `build_graph` method. The method body is the graph's own output,
   * followed by a `return` of a dict keyed by each top-level child's name.
   *
   * @param {Object} [mopts={}] - Options passed on to `graph.compile`. The
   *   caller's object is copied, not mutated. The copy gets `writer` set to
   *   this program's writer.
   * @throws {Error} If `setGraph` has not been called.
   */
  compile(mopts = {}) {
    // Report a missing graph clearly, and before the writer is opened.
    if (!this.graph) {
      throw new Error('TensorKitProgram: setGraph() must be called before compile()');
    }

    // Copy so the caller's options object is not modified as a side effect.
    const opts = Object.assign({}, mopts, { writer: this.writer });
    this.writer.open();

    const imports = [
      'from __future__ import absolute_import',
      'from __future__ import division',
      'from __future__ import print_function',
      '',
      'import tensorflow as tf',
      'import keras',
      'from tensorkit.base import ArchitectBase',
      '',
      '',
    ];

    this.writer.emitLines(imports);
    this.writer.emitLine('class Architect(ArchitectBase):');
    this.writer.incIndent();
    // NOTE(review): the emitted parameter is named `phrase`; TensorKit's
    // ArchitectBase may expect `phase` — confirm before renaming.
    this.writer.emitLine('def build_graph(self, hypes, input, phrase):');
    this.writer.incIndent();
    this.graph.compile(opts);

    // Each top-level child becomes a `'name': name` dict entry. Presumably
    // graph.compile() defined a Python variable with that name — confirm.
    const entries = this.graph.children.map((child) => {
      const name = child.getName();
      return `'${name}': ${name}`;
    });

    this.writer.emitLine(`return {${entries.join(',')}}`);
  }
}

module.exports = TensorKitProgram;