#!/usr/bin/env python3

##
# SPDX-License-Identifier: LGPL-2.1-only
#
# Copyright (C) 2023 Mediatek
#
# @file generateTest.py
# @brief generate test data and golden test result
# @author kidd-kw.chen <kidd-kw.chen@mediatek.com>

import numpy as np

def save_test_data_for_two_input_one_output(filename, shape, type):
    """Write a random input tensor and the golden file for a one-output test.

    The input tensor is written to ``filename`` and the expected output
    (the input doubled) to ``filename + '.golden'``. Both are raw
    ``tobytes()`` dumps with no header, and existing files are overwritten.

    Args:
        filename: Path of the raw input file to create.
        shape: Shape of the generated tensor, e.g. ``[4, 4, 4, 1]``.
        type: NumPy dtype of the written data, e.g. ``np.float32``.
    """
    # MDLA can only handle float16. Values are rounded to the nearest
    # integer in [-100, 100] so that inference loses no precision on them.
    samples = np.random.uniform(-100, 100, shape)
    data = np.round(samples.astype(type))

    # NOTE(review): the golden output is data * 2, presumably because the
    # same tensor feeds both model inputs and the model adds them. Confirm.
    expected = data * 2

    # Input file first, then the golden file, matching the original order.
    for path, payload in ((filename, data), (filename + '.golden', expected)):
        with open(path, 'wb') as handle:
            handle.write(payload.tobytes())


def save_test_data_for_two_input_two_output(filename, shape, type):
    """Write a random input tensor and the golden file for a two-output test.

    The input tensor is written to ``filename``. The expected outputs are
    written to ``filename + '.golden'`` as two concatenated raw dumps:
    first ``data - 20``, then ``data + 10``. Both files are overwritten on
    every call.

    Args:
        filename: Path of the raw input file to create.
        shape: Shape of the generated tensor, e.g. ``[4, 4, 4, 1]``.
        type: NumPy dtype of the written data, e.g. ``np.float32``.
    """
    # MDLA can only handle float16. Values are rounded to the nearest
    # integer in [-100, 100] so that inference loses no precision on them.
    data = np.round(np.random.uniform(-100, 100, shape).astype(type))

    with open(filename, 'wb') as file:
        file.write(data.tobytes())

    golden1 = data + 10.0
    golden2 = data - 20.0
    # Open with 'wb', not 'ab'. Appending grew the golden file on every
    # re-run and left stale outputs ahead of the ones matching the new input.
    # NOTE(review): golden2 is written before golden1, presumably to match
    # the model's output order. Confirm.
    with open(filename + '.golden', 'wb') as file:
        file.write(golden2.tobytes())
        file.write(golden1.tobytes())


# Generate every test case: one float32 tensor of shape [4, 4, 4, 1] per
# generator, written to the current working directory in this order.
for _generator, _out_name in (
    (save_test_data_for_two_input_one_output, 'test_2_input_1_output.dat'),
    (save_test_data_for_two_input_two_output, 'test_2_input_2_output.dat'),
):
    _generator(_out_name, [4, 4, 4, 1], np.float32)
