from __future__ import absolute_import, print_function

import tvm
import tvm.testing
from tvm import te
import numpy as np

# Host-side code is compiled with LLVM; the device kernel targets CUDA.
tgt_host = "llvm"
tgt = "cuda"

# Symbolic length: the computation is declared for an arbitrary vector size n.
n = te.var("n")
vec_shape = (n,)
A = te.placeholder(vec_shape, name="A")
B = te.placeholder(vec_shape, name="B")
# Declare C element-wise over A's shape. This only describes the computation;
# nothing is compiled until tvm.build and nothing executes until the built
# function is called.
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
print("type(C): \t", type(C))

# Without further scheduling, C lowers to a plain serial loop:
#   for (int i = 0; i < n; ++i) {
#     C[i] = A[i] + B[i];
#   }
s = te.create_schedule(C.op)  # schedule rooted at the op that produces C

# Split C's single axis into an outer loop (block_x) and an inner loop of
# extent 64 (thread_x):
#   for (int block_x = 0; block_x < ceil(n / 64); ++block_x) {
#     for (int thread_x = 0; thread_x < 64; ++thread_x) {
#       int i = block_x * 64 + thread_x;
#       if (i < n) {
#         C[i] = A[i] + B[i];
#       }
#     }
#   }
block_x, thread_x = s[C].split(C.op.axis[0], factor=64)

# GPU targets need explicit binding: the outer loop maps to blockIdx.x and the
# inner loop to threadIdx.x of the GPU compute grid.
if tgt in ("cuda", "rocm") or tgt.startswith("opencl"):
    s[C].bind(block_x, te.thread_axis("blockIdx.x"))
    s[C].bind(thread_x, te.thread_axis("threadIdx.x"))


# Compile the schedule with tvm.build. The result, fadd, is a host wrapper that
# holds a reference to the generated device function.
fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")

# Run on device 0 of the target and compare against NumPy.
ctx = tvm.context(tgt, 0)
n = 1024  # concrete length; note this rebinds the symbolic n declared earlier
host_a = np.random.uniform(size=n).astype(A.dtype)
host_b = np.random.uniform(size=n).astype(B.dtype)
a = tvm.nd.array(host_a, ctx)
b = tvm.nd.array(host_b, ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd(a, b, c)

expected = a.asnumpy() + b.asnumpy()
tvm.testing.assert_allclose(c.asnumpy(), expected)
print("verify correctness")
print("c.asnumpy(): \n", c.asnumpy())

# Print the generated source. For GPU targets the kernel is in the first
# imported module of fadd; otherwise the host module's own source is printed.
is_gpu_target = tgt in ("cuda", "rocm") or tgt.startswith("opencl")
if not is_gpu_target:
    print(fadd.get_source())
else:
    device_mod = fadd.imported_modules[0]
    print("-----GPU code-----")
    print(device_mod.get_source())


# Save the compiled host and device modules to a temporary directory.
from tvm.contrib import cc
from tvm.contrib import utils

temp = utils.tempdir()  # temp directory whose contents are deleted on exit
fadd.save(temp.relpath("myadd.o"))  # host code as an object file

# The device module (if any) is saved separately, with a backend-specific extension.
if tgt == "cuda":
    device_file = "myadd.ptx"
elif tgt == "rocm":
    device_file = "myadd.hsaco"
elif tgt.startswith("opencl"):
    device_file = "myadd.cl"
else:
    device_file = None
if device_file is not None:
    fadd.imported_modules[0].save(temp.relpath(device_file))

# Link the host object into a shared library (.so).
cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
print(temp.listdir())

# Reload the host shared library and re-attach the device module saved earlier.
fadd1 = tvm.runtime.load_module(temp.relpath("myadd.so"))
if tgt == "cuda":
    saved_device_file = "myadd.ptx"
elif tgt == "rocm":
    saved_device_file = "myadd.hsaco"
elif tgt.startswith("opencl"):
    saved_device_file = "myadd.cl"
else:
    saved_device_file = None  # CPU-only target: no separate device module
if saved_device_file is not None:
    fadd1_dev = tvm.runtime.load_module(temp.relpath(saved_device_file))
    fadd1.import_module(fadd1_dev)

# Reset the output buffer first. c still holds the correct result from the
# earlier fadd call, so without this the check below would pass even if fadd1
# never wrote to c.
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd1(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
print("load compiled module")
print("c.asnumpy(): \n", c.asnumpy())

# Pack the host module and its imported device module into one shared library
# with export_library, then load it back as a single module.
fadd.export_library(temp.relpath("myadd_pack.so"))
fadd2 = tvm.runtime.load_module(temp.relpath("myadd_pack.so"))

# Clear the output first so that a stale result left in c by an earlier call
# cannot make the check below pass on its own.
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd2(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
print("load the library")
print("c.asnumpy(): \n", c.asnumpy())