import {InferenceSession} from "onnxruntime-web";


/**
 * Loads the model, creates a session, loads current video frame and runs inference
 */
export default class ONNXRunner {
    /**
     * @param {object} dataLoader - Data source used by this runner. The runner
     *     reads `dataLoader.modelUrl` and calls `dataLoader.loadImage()` and
     *     `dataLoader.loadVideo()`, each of which is awaited and expected to
     *     yield an input tensor, or `undefined` when no data is available.
     */
    constructor(dataLoader) {
        this.dataLoader = dataLoader;
        // Placeholder until initSession() succeeds; it has no input/output
        // names, so the run methods refuse to use it.
        this.session = {};
    }

    /**
     * Creates the inference session for the model at `dataLoader.modelUrl`
     * using the "webgl" execution provider.
     *
     * @returns {Promise<boolean>} `true` when the session was created,
     *     `false` when creation failed (the error is logged, not rethrown).
     */
    async initSession() {
        try {
            this.session = await InferenceSession.create(
                this.dataLoader.modelUrl,
                {
                    executionProviders: ["webgl"],
                }
            );
            console.log('ONNXRunner initialized: ' + this.session);
            return true;
        } catch (e) {
            console.log('model could not be loaded!');
            console.log(e);
            return false;
        }
    }

    /**
     * Loads an image tensor via `dataLoader.loadImage()` and runs the model on it.
     *
     * @returns {Promise<*>} The `data` of the model's first output, or
     *     `undefined` when no input tensor was available or inference failed.
     * @throws {Error} If the session has not been initialized.
     */
    async runTestImage() {
        console.log('starting new prediction');
        // Timing deliberately includes the data-loading step.
        const start = performance.now();
        const inputTensor = await this.dataLoader.loadImage();
        console.log('input tensor');
        console.log(inputTensor);
        return runInference(this.session, inputTensor, start);
    }

    /**
     * Loads a video tensor via `dataLoader.loadVideo()` and runs the model on it.
     * Only the model's first input is fed; no audio input is wired in.
     *
     * @returns {Promise<*>} The `data` of the model's first output, or
     *     `undefined` when no input tensor was available or inference failed.
     * @throws {Error} If the session has not been initialized.
     */
    async runNextClip() {
        console.log('starting new prediction');
        // Timing deliberately includes the data-loading step.
        const start = performance.now();
        const videoInputTensor = await this.dataLoader.loadVideo();
        console.log('video input tensor: ');
        console.log(videoInputTensor);
        return runInference(this.session, videoInputTensor, start);
    }
}

/**
 * Feeds one tensor to the session's first input and returns the data of its
 * first output. Shared by the public run methods of ONNXRunner.
 *
 * @param {object} session - Session exposing `inputNames`, `outputNames` and `run(feeds)`.
 * @param {*} inputTensor - Tensor to feed; `undefined` means "nothing to run".
 * @param {number} start - `performance.now()` timestamp used for the duration log.
 * @returns {Promise<*>} Output data, or `undefined` when there was no input
 *     or the run failed.
 * @throws {Error} If the session exposes no input or output names.
 */
async function runInference(session, inputTensor, start) {
    if (inputTensor === undefined) {
        return undefined;
    }

    const inputNames = session ? session.inputNames : undefined;
    const outputNames = session ? session.outputNames : undefined;
    if (!inputNames || !outputNames || inputNames.length === 0 || outputNames.length === 0) {
        throw new Error('ONNXRunner session is not initialized; call initSession() first');
    }

    const feeds = {};
    feeds[inputNames[0]] = inputTensor;

    let outputMap;
    try {
        outputMap = await session.run(feeds);
    } catch (error) {
        // A failed run is reported and treated like "no result" instead of
        // falling through to a property access on an undefined output map.
        console.log(error);
        return undefined;
    }

    const results = outputMap[outputNames[0]].data;
    console.log('model output tensor: ');
    console.log(results);
    console.log('run took: ' + (performance.now() - start) + 'ms');
    return results;
}





