# ***************************************************
# * CUHK, CSE, CMSC 5743 TensorRT Example Code
# *
# * Qi SUN, qsun@cse.cuhk.edu.hk
# ***************************************************

import argparse
import io
import os
import time

import torch
import tensorrt as trt
import torchvision
import torchvision.transforms as transforms
import pycuda
# automatically performs all the steps necessary to get CUDA environment ready for computing
import pycuda.autoinit
import numpy as np
from PIL import Image
from sklearn.metrics import mean_squared_error

# Shared TensorRT logger, passed to the Runtime, Builder and OnnxParser in this file.
# It is constructed with trt.Logger.WARNING as its severity threshold.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def get_args():
	"""Parse this example's command-line options.

	Returns:
		argparse.Namespace with the attributes onnx_file_path, engine_file_path,
		input_image_path (strings, None when omitted) and verbose, baseline
		(booleans, False unless the flag is given).
	"""
	# (flag, add_argument keyword options), registered in this order.
	option_table = (
		("--onnx_file_path", {"type": str}),
		("--engine_file_path", {"type": str}),
		("--verbose", {"action": 'store_true'}),
		("--baseline", {"action": 'store_true'}),
		("--input_image_path", {"type": str}),
	)
	cli = argparse.ArgumentParser()
	for flag, spec in option_table:
		cli.add_argument(flag, **spec)
	return cli.parse_args()

def alexnet_onnx(onnx_file_path, verbose):
	"""Return the AlexNet ONNX model as a readable binary stream, exporting it first if needed.

	If ``onnx_file_path`` does not exist, the pretrained torchvision AlexNet is
	moved to CUDA and exported to that path with ``torch.onnx.export``.

	Args:
		onnx_file_path: path of the ONNX file to load, or to create when missing.
		verbose: forwarded to ``torch.onnx.export``.

	Returns:
		An in-memory ``io.BytesIO`` holding the file's bytes. It supports
		``.read()`` like an open binary file, but leaves no OS file handle open.
	"""
	if not os.path.exists(onnx_file_path):
		print("Generating ONNX file for AlexNet: ", onnx_file_path)
		dummy_input = torch.randn(1, 3, 224, 224, device='cuda')
		model = torchvision.models.alexnet(pretrained=True).cuda()
		# Export the inference graph explicitly rather than relying on the exporter's default mode.
		model.eval()
		# One name for the image input; the 16 extra names appear intended for AlexNet's
		# learned weight/bias tensors. NOTE(review): newer torch versions may reject
		# more input names than graph inputs - confirm against the installed version.
		input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(16)]
		output_names = ["output1"]
		torch.onnx.export(model, dummy_input, onnx_file_path, verbose=verbose, input_names=input_names, output_names=output_names)

	print("Loading ONNX file from: ", onnx_file_path)
	# Read everything and close the file right away. Callers only need .read(), and
	# the cached-engine path never reads the model at all.
	with open(onnx_file_path, 'rb') as onnx_file:
		onnx_model = io.BytesIO(onnx_file.read())
	return onnx_model

def alexnet_engine(engine_file_path, onnx_model):
	"""Return a TensorRT engine for AlexNet, reusing a serialized engine when present.

	If ``engine_file_path`` exists, it is deserialized and ``onnx_model`` is not
	read. Otherwise the network is parsed from ``onnx_model`` (any binary stream
	with ``.read()``), built, and serialized to ``engine_file_path``.

	Returns:
		The engine produced by TensorRT. On the cached path this is whatever
		``deserialize_cuda_engine`` returns, which is passed through unchecked.

	Raises:
		RuntimeError: if the ONNX model fails to parse or the engine build returns None.
	"""
	model_engine = None
	if os.path.exists(engine_file_path):
		print("Reading engine from: ", engine_file_path)
		# deserialize the engine file
		with open(engine_file_path, "rb") as model, trt.Runtime(TRT_LOGGER) as runtime:
			model_engine = runtime.deserialize_cuda_engine(model.read())
		return model_engine

	with trt.Builder(TRT_LOGGER) as builder:
		# Specify that the network should be created with an explicit batch dimension
		explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
		network = builder.create_network(explicit_batch)
		parser = trt.OnnxParser(network, TRT_LOGGER)
		builder.max_workspace_size = 1 << 28
		# NOTE(review): TensorRT documents max_batch_size as having no effect on
		# explicit-batch networks; kept as in the original.
		builder.max_batch_size = 1
		# parse() returns False on failure. Report the parser's errors instead of
		# silently building from an incomplete network.
		if not parser.parse(onnx_model.read()):
			details = "\n".join(str(parser.get_error(i)) for i in range(parser.num_errors))
			raise RuntimeError("Failed to parse the ONNX model:\n" + details)
		model_engine = builder.build_cuda_engine(network)
		if model_engine is None:
			raise RuntimeError("TensorRT failed to build the engine")
		with open(engine_file_path, "wb") as f:
			f.write(model_engine.serialize())
	return model_engine

def get_image(input_image_path):
	"""Load an image from disk and preprocess it into the AlexNet input layout.

	The image is converted to 3-channel RGB. It is then resized so the shorter
	side is 256, center-cropped to 224x224, converted to a tensor, and normalized
	with the ImageNet mean/std used below.

	Returns:
		A C-contiguous float32 numpy array of shape (3, 224, 224).
	"""
	print("Get image: ", input_image_path)
	# The context manager closes the underlying file once the pixels are copied out.
	with Image.open(input_image_path) as image:
		print("Input image format {}, size {}, mode {}.".format(image.format, image.size, image.mode))
		# Grayscale, RGBA, CMYK or palette images would not match the 3-channel
		# Normalize below. RGB images are copied unchanged.
		rgb_image = image.convert("RGB")
	preprocess = transforms.Compose([
		transforms.Resize(256),
		transforms.CenterCrop(224),
		transforms.ToTensor(),
		transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
	])
	image_tensor = preprocess(rgb_image)
	print("Image size after preprocessing: ", image_tensor.shape)
	image_binary = np.array(image_tensor, dtype=np.float32, order='C')
	return image_binary

def allocate_buffers(model_engine):
	"""Allocate a page-locked host buffer and a device buffer for each engine binding.

	Returns:
		(inputs, outputs, bindings): ``inputs`` and ``outputs`` are lists of
		``{"host": ..., "device": ...}`` dicts for the input and output bindings.
		``bindings`` holds the integer device address of every binding, in the
		order the engine iterates them.
	"""
	inputs, outputs, bindings = [], [], []
	# Each binding is one input or output port of the engine.
	for port in model_engine:
		# NOTE(review): for an explicit-batch engine the binding shape presumably
		# already includes the batch dimension - confirm max_batch_size is 1 there.
		element_count = trt.volume(model_engine.get_binding_shape(port)) * model_engine.max_batch_size
		host_buffer = pycuda.driver.pagelocked_empty(element_count, trt.nptype(model_engine.get_binding_dtype(port)))
		device_buffer = pycuda.driver.mem_alloc(host_buffer.nbytes)
		# The execution context takes raw device addresses as ints.
		bindings.append(int(device_buffer))
		destination = inputs if model_engine.binding_is_input(port) else outputs
		destination.append({"host": host_buffer, "device": device_buffer})
	return inputs, outputs, bindings

def do_inference(context, bindings, inputs, outputs, stream):
	"""Run one asynchronous TensorRT inference on ``stream`` and time it.

	Every input host buffer is copied to the device and the engine is enqueued.
	Every output is then copied back, and the call waits for the stream to finish.

	Args:
		context: TensorRT execution context.
		bindings: device addresses for all engine bindings.
		inputs: list of ``{"host": ..., "device": ...}`` dicts to upload.
		outputs: list of ``{"host": ..., "device": ...}`` dicts to download.
		stream: CUDA stream (must expose ``.handle`` and ``.synchronize()``).

	Returns:
		(host_outputs, elapsed): the outputs' host buffers themselves (not
		copies) and the wall-clock seconds spent, copies included.

	Raises:
		RuntimeError: if ``execute_async_v2`` reports that enqueueing failed.
	"""
	# time.clock() was removed in Python 3.8 and measured CPU time on Unix,
	# which misses time spent waiting on the GPU. Use a wall-clock timer instead.
	start = time.perf_counter()
	# send inputs to device (GPU)
	for buffers in inputs:
		pycuda.driver.memcpy_htod_async(buffers["device"], buffers["host"], stream)
	# do inference; TensorRT returns False when the kernels could not be enqueued
	if context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) is False:
		raise RuntimeError("TensorRT failed to enqueue the inference")
	# send outputs to host (CPU)
	for buffers in outputs:
		pycuda.driver.memcpy_dtoh_async(buffers["host"], buffers["device"], stream)
	# wait for all activity on this stream to cease, then return.
	stream.synchronize()
	end = time.perf_counter()
	return [buffers["host"] for buffers in outputs], end - start

def post_process(outputs):
	"""Return the index of the most probable class in the first output buffer.

	The first entry of ``outputs`` is converted to a float tensor and
	softmax-normalized along dimension 0. The index of its maximum along that
	dimension is returned as a tensor.
	"""
	scores = torch.Tensor(outputs[0])
	probabilities = scores.softmax(dim=0)
	return torch.argmax(probabilities, dim=0)

def alexnet_baseline(input_image_path):
	"""Run the pretrained PyTorch AlexNet on a single image as a timing and accuracy baseline.

	The model is loaded from ``torch.hub`` and put in eval mode. The image goes
	through the same preprocessing as ``get_image``. Model and input are moved
	to CUDA when it is available.

	Returns:
		(output, elapsed): the model output for a batch of one image, and the
		forward pass's wall-clock time in seconds.
	"""
	model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
	model.eval()
	# The context manager closes the file once the pixels are copied out.
	with Image.open(input_image_path) as input_image:
		print("Input image format {}, size {}, mode {}.".format(input_image.format, input_image.size, input_image.mode))
		# Non-RGB images (grayscale, RGBA, CMYK, palette) would not match the 3-channel Normalize below.
		rgb_image = input_image.convert("RGB")
	preprocess = transforms.Compose([
		transforms.Resize(256),
		transforms.CenterCrop(224),
		transforms.ToTensor(),
		transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
	])
	input_tensor = preprocess(rgb_image)
	# Prepend the batch dimension the model expects.
	input_batch = input_tensor.unsqueeze(0)

	use_cuda = torch.cuda.is_available()
	if use_cuda:
		input_batch = input_batch.to('cuda')
		model.to('cuda')
		# Ensure the transfers are done before the clock starts.
		torch.cuda.synchronize()
	# time.clock() was removed in Python 3.8; perf_counter is a wall-clock timer.
	start = time.perf_counter()
	with torch.no_grad():
		output = model(input_batch)
	if use_cuda:
		# CUDA kernels run asynchronously. Wait for them so the timing covers the actual forward pass.
		torch.cuda.synchronize()
	end = time.perf_counter()

	return output, end - start
