# ***************************************************
# * CUHK, CSE, CMSC 5743 TensorRT Example Code
# *
# * Qi SUN, qsun@cse.cuhk.edu.hk
# ***************************************************

# ###################################################
# 1. get model onnx file
# 2. build cuda engine
# 3. prepare inputs
# 4. create cuda context
# 5. do inference
# 6. process outputs
# ###################################################

from lab_utils import *

def main():
    """Run one image through the engine and report its label and timing.

    Walks through the six steps listed at the top of this file using the
    options returned by ``get_args()``. When ``args.baseline`` is truthy,
    ``alexnet_baseline`` is additionally run on the same image path and its
    label, time, the MSE between the two softmax outputs, and the
    ``pth_time / trt_time`` ratio are printed.
    """
    softmax = torch.nn.functional.softmax
    args = get_args()

    # 1. get model onnx file
    onnx_model = alexnet_onnx(args.onnx_file_path, args.verbose)
    # 2. build cuda engine
    engine = alexnet_engine(args.engine_file_path, onnx_model)
    # 3. prepare inputs
    image = get_image(args.input_image_path)

    # Defaults, overwritten inside the context block below.
    trt_probs = None
    trt_seconds = 0

    # 4. create cuda context (an execution context obtained from the engine)
    with engine.create_execution_context() as context:
        inputs, outputs, bindings = allocate_buffers(engine)
        # The stream is a handle for a queue of operations carried out in order.
        stream = pycuda.driver.Stream()
        # Place the image into the host side of the first input buffer.
        inputs[0]["host"] = image

        # 5. do inference
        outputs, trt_seconds = do_inference(
            context, bindings, inputs, outputs, stream
        )

        # 6. process outputs: softmax over the first output, then arg-max.
        raw_scores = torch.Tensor(outputs[0])
        trt_probs = softmax(raw_scores, dim=0)
        trt_label = trt_probs.argmax(dim=0).numpy()
        print("trt_label:   ", trt_label)
        print("trt_time:     %.6f seconds." % trt_seconds)
        torch.cuda.empty_cache()

    # Nothing more to do unless a baseline comparison was requested.
    if not args.baseline:
        return

    baseline_out, pth_seconds = alexnet_baseline(args.input_image_path)
    # Only the first entry of the baseline output is compared; it is moved to
    # the CPU so it can be compared against the engine result.
    pth_probs = softmax(baseline_out[0], dim=0).cpu()
    print("pth_label:   ", pth_probs.argmax(dim=0).item())
    print("pth_time:     %.6f seconds." % pth_seconds)
    print("MSE:         ", mean_squared_error(trt_probs, pth_probs))
    # Ratio of baseline time to engine time.
    print("ratio:        %.2f" % (pth_seconds / trt_seconds))

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
