import * as tf from '@tensorflow/tfjs/dist/index';
import {loadGraphModel} from '@tensorflow/tfjs-converter/dist/index';
import {imageDataToTensor} from "./DataUtils";


/**
 * Runs a stateful TF.js graph model over video clips, one frame at a time.
 *
 * Usage: construct, `await loadModel()`, then call `runNextClip()` for each clip supplied by
 * `dbHandler`. Every frame is fed together with the state tensors produced for the previous
 * frame; the `logits` output of every frame is collected and returned.
 */
export default class TFJSRunner {
    /**
     * @param {string} modelUrl - URL of the model's model.json (e.g. 'model_directory/model.json').
     * @param {object} dbHandler - frame source; must provide `getVideoFrameDataArray()`, returning a
     *     Promise of an array of frames (each with an `imgArray` field) or undefined when no clip exists.
     * @param {number} [frameSize=172] - height and width, in pixels, of the square frames fed to the model.
     */
    constructor(modelUrl, dbHandler, frameSize = 172) {
        this.dbHandler = dbHandler;
        // Placeholder until loadModel() succeeds.
        this.model = {};
        this.frameSize = frameSize;
        this.modelUrl = modelUrl;
    }

    /**
     * Loads the graph model from `modelUrl` into `this.model`.
     *
     * Failures are logged rather than thrown so callers can branch on the boolean result.
     * @returns {Promise<boolean>} true if the model was loaded, false otherwise.
     */
    async loadModel() {
        try {
            this.model = await loadGraphModel(this.modelUrl);
            console.log('TFJSRunner initialized: ');
            console.log(this.model);
            return true;
        } catch (e) {
            console.log('model could not be loaded!');
            console.log(e);
            return false;
        }
    }

    /**
     * Fetches the next clip from `dbHandler` and runs the model over its frames in order.
     *
     * Each frame is converted to a [1, 1, frameSize, frameSize, 3] tensor and fed as the `image`
     * input alongside the current state tensors. Every model output other than `logits` becomes
     * the state for the next frame. Frame tensors and superseded state tensors are disposed as the
     * loop advances.
     *
     * @returns {Promise<Array<tf.Tensor>|undefined>} one `logits` tensor per frame (the caller owns
     *     them and must dispose them), or undefined when no clip is available.
     * @throws {Error} if the model is not loaded, or if its output is not a named map containing `logits`.
     */
    async runNextClip() {
        if (!this.model || typeof this.model.predict !== 'function') {
            throw new Error('TFJSRunner: model is not loaded; call loadModel() first');
        }
        console.log('starting new prediction');
        const start = performance.now();
        const clip = await this.dbHandler.getVideoFrameDataArray();
        if (clip == null) {
            console.log('no clip found for prediction!');
            return undefined;
        }

        let states = this.prepare_states();
        const predictions = [];
        try {
            for (const frame of clip) {
                // tidy() frees the frame tensor and any temporaries created while building the
                // input; only the tensors returned from the callback survive.
                const outputs = tf.tidy(() => {
                    const frameTensor = imageDataToTensor(
                        frame.imgArray, [1, 1, this.frameSize, this.frameSize, 3]);
                    // Build a fresh input map instead of writing the image into `states`.
                    const result = this.model.predict(Object.assign({}, states, {image: frameTensor}));
                    if (result == null || result.logits === undefined) {
                        throw new Error('TFJSRunner: model did not return a named "logits" output');
                    }
                    return result;
                });
                console.log('model output tensor: ');
                console.log(outputs);
                predictions.push(outputs.logits);
                // Every output except the logits is state to feed into the next frame.
                const nextStates = Object.assign({}, outputs);
                delete nextStates.logits;
                tf.dispose(states);
                states = nextStates;
            }
        } catch (err) {
            // Partial results would otherwise leak, since the caller never receives them.
            tf.dispose(predictions);
            throw err;
        } finally {
            // The state after the last frame is not needed by anyone.
            tf.dispose(states);
        }
        console.log(`run took: ${performance.now() - start}ms`);
        return predictions;
    }

    /**
     * Builds the initial state for a clip: one zero tensor per model input, except the image.
     *
     * Mirrors the Python reference, which creates `tf.zeros(shape, dtype)` for every input and then
     * removes 'image'. Dimensions that are unknown (size < 1) or missing are replaced by 1.
     *
     * @returns {Object<string, tf.Tensor>} zero tensors keyed by input node name; the caller owns them.
     */
    prepare_states() {
        // TF dtype enum -> TF.js dtype. TF.js has no 64-bit tensors and stores these as int32/float32.
        const dtypes = new Map([
            ['DT_FLOAT', 'float32'],
            ['DT_DOUBLE', 'float32'],
            ['DT_INT32', 'int32'],
            ['DT_INT64', 'int32'],
            ['DT_BOOL', 'bool'],
        ]);
        const states = {};
        for (const [key, input] of Object.entries(this.model.signature.inputs)) {
            // The image is supplied separately for each frame by runNextClip().
            if (key === 'image' || input.name === 'image') {
                continue;
            }
            // A scalar input has no `dim` list.
            const dims = (input.tensorShape && input.tensorShape.dim) || [];
            const shape = dims.map((dim) => {
                const size = Number.parseInt(dim.size, 10);
                return Number.isNaN(size) || size < 1 ? 1 : size;
            });
            states[input.name] = tf.zeros(shape, dtypes.get(input.dtype) || 'float32');
        }
        // NOTE(review): kept from the original, which excluded an input named 'input'; confirm it is still needed.
        delete states.input;
        return states;
    }
}
