#!/usr/bin/env python3

##
# SPDX-License-Identifier: LGPL-2.1-only
#
# Copyright (C) 2023 Mediatek
#
# @file tflite2dla.py
# @brief compile tflite to dla file and query input and output tensor info
# @author Kidd-kw.chen <Kidd-kw.chen@mediatek.com>

import os
import sys
import re
import logging
import subprocess

# Query the backends supported by this platform. The first backend reported is
# the one used later to compile the tflite model to a dla file.
# We expect the first backend to be MDLA.
def query_supported_backend():
  """Return the first backend listed by ``ncc-tflite --arch=?``.

  stdout and stderr of the tool are captured together. Every match of the
  pattern "newline, any text, '- ', <name>" is treated as one backend entry,
  and the first entry is returned with surrounding whitespace removed.

  Returns:
    The first backend name as a string.

  Raises:
    SystemExit: with an empty message when ncc-tflite cannot be started or
      when its output contains no backend entry.
  """
  # The arguments are passed as a list (no shell). In a bash command string
  # the unquoted '?' is a glob character and could be expanded against files
  # in the current directory.
  try:
    res = subprocess.run(['ncc-tflite', '--arch=?'], check=False,
                         stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                         universal_newlines=True)
  except OSError as err:
    # Without a shell, a missing executable raises instead of producing a
    # "command not found" message; report it the same way as "no backend".
    logging.error("FAIL to run ncc-tflite: %s", err)
    sys.exit("")

  backends = re.findall("\n.*?- (.*)", res.stdout)
  if not backends:
    logging.error("FAIL to find backends")
    # sys.exit() replaces the site-provided exit() helper, which is meant
    # for interactive sessions; the empty message is kept unchanged.
    sys.exit("")

  return backends[0].strip()

# Query input/output tensor information with ncc-tflite.
# Input/output tensor information cannot be obtained from the dla file, so it
# has to be queried when the tflite model is compiled to dla.
def parse_tensor_info(content, prefix):
  """Convert ``ncc-tflite --show-io-info`` text into a tensor property string.

  Args:
    content: Text containing one `` Type: <name>`` entry and one
      `` Shape: {d0,d1,...}`` entry per tensor.
    prefix: Property name prefix, e.g. 'input' or 'output'.

  Returns:
    On success, a tuple ``(properties, compile_options)``:
      * properties has the form
        ``"<prefix>type=<t0>,<t1>,... <prefix>=<dims0>,<dims1>,..."`` where
        each ``dims`` is one shape with its dimensions in reversed order,
        joined by ':'.
      * compile_options is a string of extra ncc-tflite options; it starts
        as ' ' and gets ' --relax-fp32 ' appended for every float32 tensor.
    On failure (no Type entry, an unsupported type, no Shape entry, or a
    shape without any dimension) an error is logged and the empty string
    ``""`` is returned instead of a tuple, so callers must check the result
    before unpacking it.
  """
  # Type name as printed in `content` -> type name written to the property.
  # kTfLiteInt64 is an integer type and therefore maps to 'int64'
  # (it used to be emitted as 'float64').
  type_names = {
    'kTfLiteFloat32': 'float32',
    'kTfLiteInt32': 'int32',
    'kTfLiteUInt8': 'uint8',
    'kTfLiteInt64': 'int64',
  }
  compile_options = ' '

  # Find Type
  found_types = re.findall(r' Type:\s*(\S+)', content)
  if not found_types:
    logging.error("FAIL to find Type")
    return ""

  converted_types = []
  for tflite_type in found_types:
    if tflite_type not in type_names:
      logging.error("Unknown Type")
      return ""
    if tflite_type == 'kTfLiteFloat32':
      compile_options += ' --relax-fp32 '
    converted_types.append(type_names[tflite_type])

  type_property = prefix + 'type=' + ','.join(converted_types)
  logging.debug(type_property)

  # Find Shape
  found_shapes = re.findall(r' Shape:\s*\{\s*([\d\s,]+)\s*\}', content)
  if not found_shapes:
    logging.error("FAIL to find Shape")
    return ""

  converted_shapes = []
  for shape in found_shapes:
    # The pattern also accepts whitespace around the numbers; strip it so no
    # blank ends up inside the space-separated property string.
    dims = [dim.strip() for dim in shape.split(',')]
    dims = [dim for dim in dims if dim]
    if not dims:
      logging.error("FAIL to find Shape")
      return ""
    # NOTE(review): dimensions are written in the reverse of the printed
    # order - presumably the order the consumer of the property expects;
    # confirm against the caller.
    converted_shapes.append(':'.join(reversed(dims)))

  properties = type_property + ' ' + prefix + '=' + ','.join(converted_shapes)
  logging.debug(properties)
  return properties, compile_options





logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)


def _run_ncc_tflite(args):
  """Run ncc-tflite with the argument list `args`.

  The arguments are passed as a list (no shell), so paths containing spaces
  or shell metacharacters reach the tool unchanged.

  Returns:
    A tuple (exit status, output) where output is stdout and stderr merged.

  Raises:
    SystemExit: with an empty message when the tool cannot be started.
  """
  try:
    res = subprocess.run(['ncc-tflite'] + args, check=False,
                         stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                         universal_newlines=True)
  except OSError as err:
    logging.error("FAIL to run ncc-tflite: %s", err)
    sys.exit("")
  return res.returncode, res.stdout


# Usage: tflite2dla.py <tflite model path> <dla file path>
if len(sys.argv) < 3:
  logging.error("Usage: %s <tflite model path> <dla file path>", sys.argv[0])
  sys.exit("")

tflite_path = sys.argv[1]
dla_path = sys.argv[2]
logging.debug(tflite_path)
logging.debug(dla_path)

# Query supported backend on platform
arch = query_supported_backend().strip()

# Get input/output tensor information
status, io_info = _run_ncc_tflite(['--arch=' + arch, tflite_path, '--show-io-info'])
if status != 0:
  logging.error("ncc-tflite --show-io-info exited with status %d", status)

# Everything before the "# of output tensors" header describes the inputs;
# the header and everything after it describe the outputs.
input_part, separator, remainder = io_info.partition("# of output tensors")
output_part = separator + remainder

# Parse and convert input/output tensor information to nnstreamer filter properties.
# parse_tensor_info() returns "" on failure, so check before unpacking.
input_info = parse_tensor_info(input_part, 'input')
output_info = parse_tensor_info(output_part, 'output')
if not input_info or not output_info:
  logging.error("FAIL to parse input/output tensor information")
  sys.exit("")

input_tensor_info_str, input_options = input_info
output_tensor_info_str, output_options = output_info
tensor_info_str = input_tensor_info_str + ' ' + output_tensor_info_str

# Use the options requested by the inputs AND the outputs (each option once),
# so an option needed only by the input tensors is not lost.
compile_options = []
for option in (input_options + ' ' + output_options).split():
  if option not in compile_options:
    compile_options.append(option)

# Compile tflite to dla file
status, compile_log = _run_ncc_tflite(
    ['--arch=' + arch, tflite_path, '-o', dla_path] + compile_options)
logging.debug(compile_log)
if status != 0:
  # Only report the failure; the tensor information is still handed back
  # below, as it was before the status was checked.
  logging.error("ncc-tflite exited with status %d while compiling %s", status, tflite_path)

# The result is handed back through the exit message: sys.exit() with a
# string prints that string to stderr and exits with status 1.
sys.exit(tensor_info_str)



