#!/usr/bin/python
# -*- coding: utf-8 -*-


# For better print formatting
from __future__ import print_function

# Imports
from pycompss.api.parallel import parallel
from pycompss.api.constraint import constraint
from pycompss.api.task import task
from pycompss.api.api import compss_barrier, compss_wait_on

import numpy as np


# TODO: Extend for non-square matrices. There are no size verifications.


############################################
# MATRIX GENERATION
############################################

def generate_matrix(m_size, b_size, block_type='random'):
    """Build an ``m_size`` x ``m_size`` grid (list of lists) of blocks.

    Every block is produced by a ``create_block_task`` call with the given
    ``b_size`` and ``block_type``; tasks are submitted row by row.
    """
    return [[create_block_task(b_size, block_type=block_type) for _col in range(m_size)]
            for _row in range(m_size)]


def generate_identity(m_size, b_size):
    """Build an ``m_size`` x ``m_size`` block identity grid.

    Diagonal blocks are created as 'identity' blocks and all other blocks as
    'zeros' blocks, each through ``create_block_task``; tasks are submitted
    row by row, left to right.
    """
    return [[create_block_task(b_size, block_type='identity' if col == row else 'zeros')
             for col in range(m_size)]
            for row in range(m_size)]


@constraint(ComputingUnits="${ComputingUnits}")
@task(returns=list)
def create_block_task(b_size, block_type='random'):
    """Task wrapper around ``create_block`` (same arguments, same result)."""
    new_block = create_block(b_size, block_type=block_type)
    return new_block


def create_block(b_size, block_type='random'):
    """Create a ``b_size`` x ``b_size`` float64 ``numpy.matrix``.

    :param b_size: side length of the square block.
    :param block_type: 'zeros', 'identity', or anything else for uniform
        random values drawn with ``np.random.random``.
    :return: the new block as a ``numpy.matrix``.
    """
    shape = (b_size, b_size)
    if block_type == 'zeros':
        data = np.zeros(shape)
    elif block_type == 'identity':
        data = np.identity(b_size)
    else:
        # Any unrecognised type falls back to random content
        data = np.random.random(shape)
    return np.matrix(data, dtype=np.float64, copy=False)


############################################
# MAIN FUNCTION
############################################

# [COMPSs Autoparallel] Begin Autogenerated code
import math

from pycompss.api.api import compss_barrier, compss_wait_on, compss_open
from pycompss.api.task import task
from pycompss.api.parameter import *


@task(var3=IN, returns=2)
def S1(var3):
    """Task: QR-factorise ``var3``; returns (transposed Q, R) from ``qr``."""
    q_transposed, r_block = qr(var3, transpose=True)
    return q_transposed, r_block


@task(var2=IN, var3=IN, returns=1)
def S2(var2, var3):
    """Task: ``dot`` of ``var2`` with the transpose of ``var3``."""
    product = dot(var2, var3, transpose_b=True)
    return product


@task(var2=IN, var3=IN, returns=1)
def S3(var2, var3):
    """Task: plain ``dot`` of ``var2`` and ``var3``."""
    product = dot(var2, var3)
    return product


@task(var7=IN, var8=IN, b_size=IN, returns=6)
def S4(var7, var8, b_size):
    """Task: ``little_qr`` of the stacked blocks with ``transpose=True``.

    Forwards the six values returned by ``little_qr`` unchanged.
    """
    factors = little_qr(var7, var8, b_size, transpose=True)
    return factors


@task(b_size=IN, returns=1)
def S5(b_size):
    """Task: create a ``b_size`` x ``b_size`` zero block."""
    zero_block = create_block(b_size, block_type='zeros')
    return zero_block


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S6(var2, var3, var4):
    """Task: ``var4 + var2 * var3`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=False)
    return accumulated


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S7(var2, var3, var4):
    """Task: ``var4 + var2 * var3`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=False)
    return accumulated


@task(b_size=IN, returns=1)
def S8(b_size):
    """Task: create a ``b_size`` x ``b_size`` zero block."""
    zero_block = create_block(b_size, block_type='zeros')
    return zero_block


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S9(var2, var3, var4):
    """Task: ``var4 + var2 * var3`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=False)
    return accumulated


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S10(var2, var3, var4):
    """Task: ``var4 + var2 * var3`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=False)
    return accumulated


@task(var2=IN, returns=1)
def S11(var2):
    """Task: hand ``var2`` back as this task's output (``copy_reference``)."""
    same_block = copy_reference(var2)
    return same_block


@task(var2=IN, returns=1)
def S12(var2):
    """Task: hand ``var2`` back as this task's output (``copy_reference``)."""
    same_block = copy_reference(var2)
    return same_block


@task(b_size=IN, returns=1)
def S13(b_size):
    """Task: create a ``b_size`` x ``b_size`` zero block."""
    zero_block = create_block(b_size, block_type='zeros')
    return zero_block


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S14(var2, var3, var4):
    """Task: ``var4 + var2 * transpose(var3)`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=True)
    return accumulated


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S15(var2, var3, var4):
    """Task: ``var4 + var2 * transpose(var3)`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=True)
    return accumulated


@task(b_size=IN, returns=1)
def S16(b_size):
    """Task: create a ``b_size`` x ``b_size`` zero block."""
    zero_block = create_block(b_size, block_type='zeros')
    return zero_block


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S17(var2, var3, var4):
    """Task: ``var4 + var2 * transpose(var3)`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=True)
    return accumulated


@task(var2=IN, var3=IN, var4=IN, returns=1)
def S18(var2, var3, var4):
    """Task: ``var4 + var2 * transpose(var3)`` via ``multiply_single_block``."""
    accumulated = multiply_single_block(var2, var3, var4, transpose_b=True)
    return accumulated


@task(var2=IN, returns=1)
def S19(var2):
    """Task: hand ``var2`` back as this task's output (``copy_reference``)."""
    same_block = copy_reference(var2)
    return same_block


@task(var2=IN, returns=1)
def S20(var2):
    """Task: hand ``var2`` back as this task's output (``copy_reference``)."""
    same_block = copy_reference(var2)
    return same_block


def qr_blocked(a, m_size, b_size):
    """Blocked QR decomposition of ``a`` expressed as PyCOMPSs tasks S1..S20.

    Autogenerated by COMPSs Autoparallel: the loop nest and the order in which
    tasks are submitted come from the code generator, so edit with care.

    :param a: ``m_size`` x ``m_size`` grid (list of lists) of
        ``b_size`` x ``b_size`` blocks; the entries may still be PyCOMPSs
        futures (hence the ``compss_wait_on`` calls).
    :param m_size: number of blocks per side of the grid.
    :param b_size: side length of each square block.
    :return: None. Q and R are only gathered, printed and verified (through
        ``check_result``, which raises if Q * R is not close to A) when
        ``__debug__`` is true, i.e. when Python is not run with ``-O``.
    """
    # Debug only: synchronise A so it can be printed
    if __debug__:
        a = compss_wait_on(a)
        print('Matrix A:')
        print(a)
    # Q starts as the block identity; R as a shallow copy of A's block grid
    q = generate_identity(m_size, b_size)
    r = copy_blocked(a)
    # Single slot holding the transposed Q factor of the current diagonal block
    q_act = [None]
    # Placeholder 2x2 grid of blocks; always overwritten by S4 before being read
    q_sub = [[np.matrix(np.array([0])), np.matrix(np.array([0]))], [np.matrix(np.array([0])), np.matrix(np.array([0]))]]
    # Scratch accumulators for the two updated blocks of each row/column pair
    aux = [None, None]
    if m_size >= 1:
        # NOTE: several lbp/ubp assignments below are not read by the loop that
        # follows them (code-generator artifact)
        lbp = 0
        ubp = m_size - 2
        # Every diagonal block except the last one (handled after the loop)
        for t1 in range(0, m_size - 2 + 1):
            # Factorise diagonal block t1: q_act[0] <- transposed Q, r[t1][t1] <- R
            q_act[0], r[t1][t1] = S1(r[t1][t1])
            lbp = 0
            ubp = t1
            # Rows 0..t1 of Q's column t1: q <- q * transpose(q_act) (S2)
            for t3 in range(lbp, ubp + 1):
                q[t3][t1] = S2(q[t3][t1], q_act[0])
            lbp = t1 + 1
            ubp = m_size - 1
            # Remaining rows: left-multiply row t1 of R (right of the diagonal)
            # by q_act and finish updating column t1 of Q
            for t3 in range(lbp, ubp + 1):
                r[t1][t3] = S3(q_act[0], r[t1][t3])
                q[t3][t1] = S2(q[t3][t1], q_act[0])
            lbp = t1 + 1
            ubp = m_size - 1
            # Eliminate each sub-diagonal block r[t3][t1] against r[t1][t1]
            for t3 in range(t1 + 1, m_size - 1 + 1):
                # little_qr of [r[t1][t1]; r[t3][t1]]: q_sub gets the 2x2 blocks
                # of the transposed orthogonal factor, the two r blocks get R
                q_sub[0][0], q_sub[0][1], q_sub[1][0], q_sub[1][1], r[t1][t1], r[t3][t1] = S4(r[t1][t1], r[t3][t1],
                    b_size)
                lbp = t1 + 1
                ubp = m_size - 1
                # Rows t1 and t3 of R right of column t1:
                # [r_t1; r_t3] <- q_sub * [r_t1; r_t3]. Both accumulators read the
                # old blocks before S11/S12 write the new ones back.
                for t6 in range(t1 + 1, m_size - 1 + 1):
                    aux[1] = S8(b_size)
                    aux[1] = S9(q_sub[1][0], r[t1][t6], aux[1])
                    aux[1] = S10(q_sub[1][1], r[t3][t6], aux[1])
                    aux[0] = S5(b_size)
                    aux[0] = S6(q_sub[0][0], r[t1][t6], aux[0])
                    aux[0] = S7(q_sub[0][1], r[t3][t6], aux[0])
                    r[t3][t6] = S12(aux[1])
                    r[t1][t6] = S11(aux[0])
                lbp = 0
                ubp = m_size - 1
                # Columns t1 and t3 of Q, every row:
                # [q_t1, q_t3] <- [q_t1, q_t3] * transpose(q_sub)
                for t6 in range(0, m_size - 1 + 1):
                    aux[1] = S16(b_size)
                    aux[1] = S17(q[t6][t1], q_sub[1][0], aux[1])
                    aux[1] = S18(q[t6][t3], q_sub[1][1], aux[1])
                    aux[0] = S13(b_size)
                    aux[0] = S14(q[t6][t1], q_sub[0][0], aux[0])
                    aux[0] = S15(q[t6][t3], q_sub[0][1], aux[0])
                    q[t6][t3] = S20(aux[1])
                    q[t6][t1] = S19(aux[0])
        # Last diagonal block: nothing below or to its right, so only factorise
        # it and update Q's last column
        q_act[0], r[m_size - 1][m_size - 1] = S1(r[m_size - 1][m_size - 1])
        lbp = 0
        ubp = m_size - 1
        for t3 in range(lbp, ubp + 1):
            q[t3][m_size - 1] = S2(q[t3][m_size - 1], q_act[0])
    # Wait for all submitted tasks (PyCOMPSs barrier)
    compss_barrier()
    # Debug only: gather A, Q and R as full matrices, print them and verify
    if __debug__:
        input_a = join_matrix(compss_wait_on(a))
        q_res = join_matrix(compss_wait_on(q))
        r_res = join_matrix(compss_wait_on(r))
        print('Matrix A:')
        print(input_a)
        print('Matrix Q:')
        print(q_res)
        print('Matrix R:')
        print(r_res)
    # q_res/r_res/input_a exist here because this uses the same __debug__ guard
    if __debug__:
        check_result(q_res, r_res, input_a)

# [COMPSs Autoparallel] End Autogenerated code


############################################
# MATHEMATICAL FUNCTIONS
############################################

def qr(a, mode='reduced', transpose=False):
    """QR-factorise ``a`` with ``numpy.linalg.qr``.

    :param a: matrix to factorise.
    :param mode: forwarded to ``numpy.linalg.qr`` (must be a mode that
        returns a (Q, R) pair).
    :param transpose: when true, Q is returned transposed.
    :return: tuple (Q or transpose(Q), R).
    """
    from numpy.linalg import qr as qr_numpy

    q_factor, r_factor = qr_numpy(a, mode=mode)
    if not transpose:
        return q_factor, r_factor
    return np.transpose(q_factor), r_factor


def dot(a, b, transpose_result=False, transpose_b=False):
    """Return ``np.dot(a, b)`` with optional transpositions.

    :param transpose_b: use ``transpose(b)`` as the right operand.
    :param transpose_result: transpose the product before returning it.
    """
    right = np.transpose(b) if transpose_b else b
    product = np.dot(a, right)
    return np.transpose(product) if transpose_result else product


def little_qr(a, b, b_size, transpose=False):
    """Complete QR factorisation of two blocks stacked vertically.

    Factorises ``[[a], [b]]`` with ``numpy.linalg.qr(mode='complete')`` and
    splits the ``2*b_size`` x ``2*b_size`` orthogonal factor into its 2x2
    grid of ``b_size`` x ``b_size`` blocks.

    :param a: upper block.
    :param b: lower block.
    :param b_size: side length of the blocks.
    :param transpose: when true, return the blocks of transpose(Q) instead
        of Q, in the same (00, 01, 10, 11) block order.
    :return: (q00, q01, q10, q11, new_a, new_b), where new_a/new_b are the
        upper and lower ``b_size`` rows of R.
    """
    # Numpy call
    from numpy.linalg import qr as qr_numpy
    current_a = np.bmat([[a], [b]])
    sub_q, sub_r = qr_numpy(current_a, mode='complete')

    new_a = sub_r[0:b_size]
    new_b = sub_r[b_size:2 * b_size]

    # Split Q into its 2x2 block grid using the integer b_size directly
    # (a len()/2 split yields a float index on Python 3 and fails).
    upper, lower = slice(0, b_size), slice(b_size, 2 * b_size)
    q_00 = np.matrix(sub_q[upper, upper])
    q_01 = np.matrix(sub_q[upper, lower])
    q_10 = np.matrix(sub_q[lower, upper])
    q_11 = np.matrix(sub_q[lower, lower])

    if transpose:
        # Transposing the whole factor swaps the off-diagonal blocks
        return (np.transpose(q_00), np.transpose(q_10), np.transpose(q_01),
                np.transpose(q_11), new_a, new_b)
    return q_00, q_01, q_10, q_11, new_a, new_b


def multiply_single_block(a, b, c, transpose_b=False):
    """Return ``c + a * b`` (``transpose(b)`` instead of ``b`` if requested).

    ``*`` is matrix multiplication when the operands are ``numpy.matrix``.
    """
    right = np.transpose(b) if transpose_b else b
    return c + a * right


############################################
# BLOCK HANDLING FUNCTIONS
############################################

def copy_blocked(a, transpose=False):
    """Shallow-copy a grid (list of lists) of blocks, optionally transposed.

    The outer and inner lists are new; the block objects themselves are
    shared with ``a``. Transposition is at block level only: it swaps the
    grid positions, the blocks themselves are not transposed.

    :param a: rectangular grid of blocks; column count is taken from ``a[0]``.
    :param transpose: when true, ``result[j][i] is a[i][j]``.
    :return: the new grid; ``[]`` for an empty ``a``.
    """
    n_rows = len(a)
    n_cols = len(a[0]) if n_rows else 0
    if transpose:
        # Allocate the transposed shape (n_cols x n_rows) so non-square grids
        # do not index out of range.
        return [[a[i][j] for i in range(n_rows)] for j in range(n_cols)]
    return [[a[i][j] for j in range(n_cols)] for i in range(n_rows)]


def copy_reference(block):
    """Return ``block`` itself (the same object, not a copy)."""
    same = block
    return same


def split_matrix(a, m_size):
    """Split a square matrix into an ``m_size`` x ``m_size`` grid of blocks.

    :param a: 2-D array/matrix whose side is assumed divisible by ``m_size``.
    :param m_size: number of blocks per side.
    :return: list of lists of ``numpy.matrix`` blocks of side
        ``len(a) // m_size``.
    """
    # Integer division: ``/`` gives a float on Python 3, which is not a valid
    # slice index.
    b_size = len(a) // m_size

    return [[np.matrix(a[i * b_size:(i + 1) * b_size, j * b_size:(j + 1) * b_size])
             for j in range(m_size)]
            for i in range(m_size)]


def join_matrix(a):
    """Assemble a grid (list of lists) of blocks into a single ``numpy.matrix``.

    Blocks of each row are concatenated left to right, then the rows are
    stacked top to bottom. An empty grid yields ``np.matrix([[]])``.
    """
    stacked = np.matrix([[]])
    for row_index, block_row in enumerate(a):
        joined_row = block_row[0]
        for block in block_row[1:]:
            # Append the next block to the right of the row built so far
            joined_row = np.bmat([[joined_row, block]])
        # The first row starts the result; later rows go underneath it
        stacked = joined_row if row_index == 0 else np.bmat([[stacked], [joined_row]])
    return np.matrix(stacked)


def check_result(q_res, r_res, input_a):
    """Verify that ``q_res * r_res`` is close to ``input_a``.

    Prints the check status and raises ``Exception`` on mismatch.
    """
    reconstructed = q_res * r_res
    is_ok = np.allclose(reconstructed, input_a)
    print("Result check status: " + str(is_ok))
    if is_ok:
        return
    raise Exception("Result does not match expected result")


############################################
# MAIN
############################################

if __name__ == "__main__":
    # Import libraries
    import sys
    import time

    # Parse arguments: <MSIZE> <BSIZE>
    #  - MSIZE: number of blocks per side of the matrix
    #  - BSIZE: side length of each square block
    args = sys.argv[1:]
    if len(args) < 2:
        # Fail with a readable usage message instead of an IndexError traceback
        sys.exit("Usage: " + sys.argv[0] + " <MSIZE> <BSIZE>")
    MSIZE = int(args[0])
    BSIZE = int(args[1])

    # Log arguments if required
    if __debug__:
        print("Running QR application with:")
        print(" - MSIZE = " + str(MSIZE))
        print(" - BSIZE = " + str(BSIZE))

    # Initialize matrix (its generation time is reported as INIT_TIME)
    if __debug__:
        print("Initializing matrix")
    start_time = time.time()
    A = generate_matrix(MSIZE, BSIZE)
    compss_barrier()

    # Begin computation
    if __debug__:
        print("Performing computation")
    qr_start_time = time.time()
    qr_blocked(A, MSIZE, BSIZE)
    # NOTE(review): the True argument presumably tells the runtime that no
    # more tasks will follow — confirm against the PyCOMPSs API
    compss_barrier(True)
    end_time = time.time()

    # Log results and time
    if __debug__:
        print("Post-process results")
    total_time = end_time - start_time
    init_time = qr_start_time - start_time
    qr_time = end_time - qr_start_time

    print("RESULTS -----------------")
    print("VERSION AUTOPARALLEL")
    print("MSIZE " + str(MSIZE))
    print("BSIZE " + str(BSIZE))
    print("DEBUG " + str(__debug__))
    print("TOTAL_TIME " + str(total_time))
    print("INIT_TIME " + str(init_time))
    print("QR_TIME " + str(qr_time))
    print("-------------------------")
