#include "matlib.h"
#include <assert.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>

// Scalar type

/* A single double-precision value; on disk it is "SCLR" followed by one double. */
struct scalar { double value; };

/*
 * Read a scalar from 'file'.
 *
 * Expected layout: the four ASCII bytes "SCLR" followed by one double in
 * the host's byte order.  Returns a newly allocated scalar that the caller
 * releases with free_scalar(), or NULL if the file cannot be opened, is
 * truncated, has the wrong magic number, or memory runs out.
 */
struct scalar* read_scalar(const char* file) {
  FILE *f = fopen(file, "rb");
  if (f == NULL) {
    return NULL;
  }

  struct scalar *ret = NULL;
  char magic[4];
  double v;

  // Every failure jumps to 'out' so the file is closed on all paths.
  if (fread(magic, sizeof magic, 1, f) != 1 ||
      memcmp(magic, "SCLR", sizeof magic) != 0) {
    goto out;
  }
  if (fread(&v, sizeof v, 1, f) != 1) {
    goto out;
  }

  ret = malloc(sizeof *ret);
  if (ret != NULL) {
    ret->value = v;
  }

out:
  fclose(f);
  return ret;
}

/*
 * Write 's' to 'file' (replacing any existing file) as "SCLR" followed by
 * one double in the host's byte order.  Returns 0 on success and 1 if the
 * file cannot be opened or any write, including the final flush performed
 * by fclose(), fails.
 */
int write_scalar(const char* file, struct scalar* s) {
  FILE *f = fopen(file, "wb");
  if (f == NULL) {
    return 1;
  }

  int ret = 0;
  // fwrite writes exactly the four magic bytes and reports short writes.
  if (fwrite("SCLR", 4, 1, f) != 1 ||
      fwrite(&s->value, sizeof s->value, 1, f) != 1) {
    ret = 1;
  }
  // fclose flushes buffered output; if it fails the data may not be on disk.
  if (fclose(f) != 0) {
    ret = 1;
  }
  return ret;
}

/* Release a scalar returned by read_scalar() or mul_vv(); NULL is a no-op. */
void free_scalar(struct scalar *s) { free(s); }

/* Accessor: the double stored in 's'. */
double scalar_value(struct scalar *s)
{
  const double stored = s->value;
  return stored;
}

// Vector type

/*
 * Dense vector: 'elements' points to 'n' doubles that are owned by the
 * struct and released together with it by free_vector().
 */
struct vector {
  int n;             /* element count */
  double *elements;  /* heap array of n values */
};

/*
 * Read a dense vector from 'file'.
 *
 * Layout: the ASCII bytes "VDNS", an int n, then n doubles, all in the
 * host's byte order.  Returns a newly allocated vector (release with
 * free_vector()), or NULL if the file cannot be opened, is truncated or
 * malformed (wrong magic, negative n), or memory runs out.  An empty vector
 * (n == 0) is accepted.
 */
struct vector* read_vector(const char* file) {
  FILE *f = fopen(file, "rb");
  if (f == NULL) {
    return NULL;
  }

  struct vector *vec = NULL;
  double *elems = NULL;
  size_t count = 0;
  char magic[4];
  int n;

  if (fread(magic, sizeof magic, 1, f) != 1 ||
      memcmp(magic, "VDNS", sizeof magic) != 0) {
    goto fail;
  }
  if (fread(&n, sizeof n, 1, f) != 1 || n < 0) {
    goto fail;
  }
  count = (size_t)n;

  // calloc guards the size computation against overflow; allocating at
  // least one element keeps a successful result non-NULL when n == 0.
  elems = calloc(count > 0 ? count : 1, sizeof *elems);
  if (elems == NULL) {
    goto fail;
  }
  // Reading 'count' items of one double each (rather than one item of
  // count * sizeof(double) bytes) lets n == 0 succeed: fread returns 0 == count.
  if (fread(elems, sizeof *elems, count, f) != count) {
    goto fail;
  }

  vec = malloc(sizeof *vec);
  if (vec == NULL) {
    goto fail;
  }
  vec->n = n;
  vec->elements = elems;
  fclose(f);
  return vec;

fail:
  free(elems);
  fclose(f);
  return NULL;
}



/*
 * Write 'v' to 'file' in the layout read_vector() expects: "VDNS", the int
 * n, then n doubles in the host's byte order.  Returns 0 on success and 1
 * if n is negative or any open, write, or close fails.  Empty vectors are
 * written successfully.
 */
int write_vector(const char* file, struct vector* v) {
  // A negative length would become a huge size_t element count below.
  if (v->n < 0) {
    return 1;
  }
  size_t count = (size_t)v->n;

  FILE *f = fopen(file, "wb");
  if (f == NULL) {
    return 1;
  }

  int ret = 0;
  // The element array is only written when non-empty: with n == 0 it may
  // be NULL, and a zero-byte write would otherwise be counted as a failure.
  if (fwrite("VDNS", 4, 1, f) != 1 ||
      fwrite(&v->n, sizeof v->n, 1, f) != 1 ||
      (count > 0 &&
       fwrite(v->elements, sizeof *v->elements, count, f) != count)) {
    ret = 1;
  }
  // fclose flushes buffered output; if it fails the data may not be on disk.
  if (fclose(f) != 0) {
    ret = 1;
  }
  return ret;
}

/*
 * Release a vector and its element array.  Like free(), passing NULL is a
 * no-op, so a NULL result from read_vector() or a mul_* function can be
 * released unconditionally.
 */
void free_vector(struct vector* v) {
  if (v == NULL) {
    return;
  }
  free(v->elements);
  free(v);
}

/* Number of elements held by 'v'. */
int vector_n(struct vector *v)
{
  const int len = v->n;
  return len;
}

/* Element i of 'v'.  i is not range-checked; callers must keep 0 <= i < n. */
double vector_idx(struct vector *v, int i)
{
  const double *data = v->elements;
  return data[i];
}

// Dense matrices

/*
 * Dense n-by-m matrix stored row-major: entry (i, j) lives at
 * elements[i * m + j].
 */
struct matrix_dense {
  int n;             /* number of rows */
  int m;             /* number of columns */
  double *elements;  /* n * m values */
};

/*
 * Read a dense matrix from 'file'.
 *
 * Layout: "MDNS", int n (rows), int m (columns), then n * m doubles in
 * row-major order, all in the host's byte order.  Returns a newly allocated
 * matrix (release with free_matrix_dense()), or NULL if the file cannot be
 * opened, is truncated or malformed (wrong magic, negative or overflowing
 * dimensions), or memory runs out.  Zero-sized matrices are accepted.
 */
struct matrix_dense* read_matrix_dense(const char* file) {
  FILE *f = fopen(file, "rb");
  if (f == NULL) {
    return NULL;
  }

  struct matrix_dense *mat = NULL;
  double *elems = NULL;
  size_t count = 0;
  char magic[4];
  int n, m;

  if (fread(magic, sizeof magic, 1, f) != 1 ||
      memcmp(magic, "MDNS", sizeof magic) != 0) {
    goto fail;
  }
  if (fread(&n, sizeof n, 1, f) != 1 || fread(&m, sizeof m, 1, f) != 1 ||
      n < 0 || m < 0) {
    goto fail;
  }
  // n * m in int arithmetic can overflow (undefined behavior), so the
  // element count is computed in size_t after checking that it fits.
  if (m != 0 && (size_t)n > SIZE_MAX / (size_t)m) {
    goto fail;
  }
  count = (size_t)n * (size_t)m;

  // calloc guards count * sizeof(double) against overflow; at least one
  // element keeps a successful result non-NULL for empty matrices.
  elems = calloc(count > 0 ? count : 1, sizeof *elems);
  if (elems == NULL) {
    goto fail;
  }
  if (fread(elems, sizeof *elems, count, f) != count) {
    goto fail;
  }

  mat = malloc(sizeof *mat);
  if (mat == NULL) {
    goto fail;
  }
  mat->n = n;
  mat->m = m;
  mat->elements = elems;
  fclose(f);
  return mat;

fail:
  free(elems);
  fclose(f);
  return NULL;
}


/*
 * Write 'X' to 'file' in the layout read_matrix_dense() expects: "MDNS",
 * int n, int m, then the n * m row-major doubles in the host's byte order.
 * Returns 0 on success and 1 if a dimension is negative or any open,
 * write, or close fails.  Zero-sized matrices are written successfully.
 */
int write_matrix_dense(const char* file, struct matrix_dense* X) {
  if (X->n < 0 || X->m < 0) {
    return 1;
  }
  // Element count in size_t: X->n * X->m in int could overflow.
  size_t count = (size_t)X->n * (size_t)X->m;

  FILE *f = fopen(file, "wb");
  if (f == NULL) {
    return 1;
  }

  int ret = 0;
  // The element array is only written when non-empty: it may be NULL, and
  // a zero-byte write would otherwise be counted as a failure.
  if (fwrite("MDNS", 4, 1, f) != 1 ||
      fwrite(&X->n, sizeof X->n, 1, f) != 1 ||
      fwrite(&X->m, sizeof X->m, 1, f) != 1 ||
      (count > 0 &&
       fwrite(X->elements, sizeof *X->elements, count, f) != count)) {
    ret = 1;
  }
  // fclose flushes buffered output; if it fails the data may not be on disk.
  if (fclose(f) != 0) {
    ret = 1;
  }
  return ret;
}

/*
 * Release a dense matrix and its element array.  Like free(), passing NULL
 * is a no-op, so a NULL result from read_matrix_dense() can be released
 * unconditionally.
 */
void free_matrix_dense(struct matrix_dense* X) {
  if (X == NULL) {
    return;
  }
  free(X->elements);
  free(X);
}

/* Row count of the dense matrix 'X'. */
int matrix_dense_n(struct matrix_dense *X)
{
  const int rows = X->n;
  return rows;
}

/* Column count of the dense matrix 'X'. */
int matrix_dense_m(struct matrix_dense *X)
{
  const int cols = X->m;
  return cols;
}

/*
 * Entry (i, j) of 'X' using the row-major index i * m + j, where m is the
 * number of columns.  Indices are not range-checked.
 */
double matrix_dense_idx(struct matrix_dense* X, int i, int j) {
  // Index in size_t so that i * m cannot overflow int for large matrices.
  return X->elements[(size_t)i * (size_t)X->m + (size_t)j];
}

// Sparse matrices in CSR format

/*
 * Sparse n-by-m matrix in compressed sparse row (CSR) form.  Row i owns the
 * entries v[k] / v_col[k] for k in [row_start[i], row_start[i + 1]); since
 * row_start holds only n offsets, the last row ends at nnz instead.
 */
struct matrix_csr {
  int n;          /* number of rows */
  int m;          /* number of columns */
  int nnz;        /* number of stored (nonzero) entries */
  double *v;      /* the nnz stored values */
  int *v_col;     /* column of each stored value, nnz entries */
  int *row_start; /* offset into v where each row begins, n entries */
};

/*
 * Check that CSR arrays loaded from a file are safe to index: every
 * row_start offset lies in [0, nnz] and never decreases, and every column
 * index lies in [0, m).  Returns 1 if so, 0 otherwise.
 */
static int matrix_csr_arrays_valid(int n, int m, int nnz,
                                   const int *v_col, const int *row_start) {
  for (int i = 0; i < n; i++) {
    if (row_start[i] < 0 || row_start[i] > nnz ||
        (i > 0 && row_start[i] < row_start[i - 1])) {
      return 0;
    }
  }
  for (int k = 0; k < nnz; k++) {
    if (v_col[k] < 0 || v_col[k] >= m) {
      return 0;
    }
  }
  return 1;
}

/*
 * Read a CSR sparse matrix from 'file'.
 *
 * Layout: "MCSR", int n, int m, int nnz, then nnz doubles (values), nnz
 * ints (column indices) and n ints (row start offsets), all in the host's
 * byte order.  Returns a newly allocated matrix (release with
 * free_matrix_csr()), or NULL if the file cannot be opened, is truncated or
 * malformed (wrong magic, negative sizes, out-of-range offsets or column
 * indices), or memory runs out.  Empty matrices are accepted.
 */
struct matrix_csr* read_matrix_csr(const char* file) {
  FILE *f = fopen(file, "rb");
  if (f == NULL) {
    return NULL;
  }

  struct matrix_csr *mc = NULL;
  double *vals = NULL;
  int *cols = NULL;
  int *starts = NULL;
  size_t nnz_count = 0;
  size_t row_count = 0;
  char magic[4];
  int n, m, nnz;

  if (fread(magic, sizeof magic, 1, f) != 1 ||
      memcmp(magic, "MCSR", sizeof magic) != 0) {
    goto fail;
  }
  if (fread(&n, sizeof n, 1, f) != 1 || fread(&m, sizeof m, 1, f) != 1 ||
      fread(&nnz, sizeof nnz, 1, f) != 1 || n < 0 || m < 0 || nnz < 0) {
    goto fail;
  }
  nnz_count = (size_t)nnz;
  row_count = (size_t)n;

  // calloc guards each size computation against overflow; at least one
  // element per array keeps successful allocations non-NULL when empty.
  vals = calloc(nnz_count > 0 ? nnz_count : 1, sizeof *vals);
  cols = calloc(nnz_count > 0 ? nnz_count : 1, sizeof *cols);
  starts = calloc(row_count > 0 ? row_count : 1, sizeof *starts);
  if (vals == NULL || cols == NULL || starts == NULL) {
    goto fail;
  }
  // Item-count reads succeed for empty arrays (fread returns 0 == count).
  if (fread(vals, sizeof *vals, nnz_count, f) != nnz_count ||
      fread(cols, sizeof *cols, nnz_count, f) != nnz_count ||
      fread(starts, sizeof *starts, row_count, f) != row_count) {
    goto fail;
  }
  // The products and matrix_csr_idx index v, v_col and x with these values,
  // so a malformed file must be rejected here.
  if (!matrix_csr_arrays_valid(n, m, nnz, cols, starts)) {
    goto fail;
  }

  mc = malloc(sizeof *mc);
  if (mc == NULL) {
    goto fail;
  }
  mc->n = n;
  mc->m = m;
  mc->nnz = nnz;
  mc->v = vals;
  mc->v_col = cols;
  mc->row_start = starts;
  fclose(f);
  return mc;

fail:
  free(vals);
  free(cols);
  free(starts);
  fclose(f);
  return NULL;
}

/*
 * Write 'X' to 'file' in the layout read_matrix_csr() expects: "MCSR",
 * int n, int m, int nnz, then nnz doubles (values), nnz ints (column
 * indices) and n ints (row start offsets), all in the host's byte order.
 * Returns 0 on success and 1 if a size is negative or any open, write, or
 * close fails.  Empty matrices are written successfully.
 */
int write_matrix_csr(const char* file, struct matrix_csr* X) {
  if (X->n < 0 || X->m < 0 || X->nnz < 0) {
    return 1;
  }
  size_t nnz_count = (size_t)X->nnz;
  size_t row_count = (size_t)X->n;

  FILE *f = fopen(file, "wb");
  if (f == NULL) {
    return 1;
  }

  int ret = 0;
  // The magic number must be the "MCSR" that read_matrix_csr() checks.
  // Empty arrays are skipped: their pointers may be NULL, and a zero-byte
  // write would otherwise be counted as a failure.
  if (fwrite("MCSR", 4, 1, f) != 1 ||
      fwrite(&X->n, sizeof X->n, 1, f) != 1 ||
      fwrite(&X->m, sizeof X->m, 1, f) != 1 ||
      fwrite(&X->nnz, sizeof X->nnz, 1, f) != 1 ||
      (nnz_count > 0 &&
       (fwrite(X->v, sizeof *X->v, nnz_count, f) != nnz_count ||
        fwrite(X->v_col, sizeof *X->v_col, nnz_count, f) != nnz_count)) ||
      (row_count > 0 &&
       fwrite(X->row_start, sizeof *X->row_start, row_count, f) != row_count)) {
    ret = 1;
  }
  // Closing releases the FILE and flushes buffered data; if it fails the
  // file may be incomplete.
  if (fclose(f) != 0) {
    ret = 1;
  }
  return ret;
}


/*
 * Release a CSR matrix and its three arrays.  Like free(), passing NULL is
 * a no-op, so a NULL result from read_matrix_csr() can be released
 * unconditionally.
 */
void free_matrix_csr(struct matrix_csr* A) {
  if (A == NULL) {
    return;
  }
  free(A->v);
  free(A->v_col);
  free(A->row_start);
  free(A);
}

/* Row count of the CSR matrix 'X'. */
int matrix_csr_n(struct matrix_csr *X)
{
  const int rows = X->n;
  return rows;
}

/* Column count of the CSR matrix 'X'. */
int matrix_csr_m(struct matrix_csr *X)
{
  const int cols = X->m;
  return cols;
}

/*
 * Entry (i, j) of the CSR matrix 'A', or 0 if (i, j) is not stored.
 *
 * Row i spans v[row_start[i] .. end), where end is row_start[i + 1], or
 * nnz for the last row because row_start holds only n offsets.  Indices
 * are not range-checked.
 */
double matrix_csr_idx(struct matrix_csr* A, int i, int j) {
  int start = A->row_start[i];
  // Reading row_start[i + 1] for the last row would run past the array.
  int end = (i + 1 < A->n) ? A->row_start[i + 1] : A->nnz;
  for (int k = start; k < end; k++) {
    if (A->v_col[k] == j) {
      return A->v[k];
    }
  }
  return 0;
}

// Operations


/*
 * Dot product of x and y.  Returns a newly allocated scalar (release with
 * free_scalar()), or NULL if the lengths differ or allocation fails.
 */
struct scalar* mul_vv(struct vector* x, struct vector* y) {
  if (x->n != y->n) {
    return NULL;
  }

  double res = 0;
#pragma omp parallel for reduction(+: res)
  for (int i = 0; i < x->n; i++) {
    res += x->elements[i] * y->elements[i];
  }

  struct scalar *out = malloc(sizeof *out);
  if (out == NULL) {
    return NULL;
  }
  out->value = res;
  return out;
}

/*
 * Matrix-vector product A * x for a dense n-by-m matrix A and a length-m
 * vector x.  Returns a newly allocated length-n vector (release with
 * free_vector()), or NULL if the dimensions do not match or allocation
 * fails.
 */
struct vector* mul_mv(struct matrix_dense* A, struct vector* x) {
  if (x->n != A->m) {
    return NULL;
  }
  int rows = A->n;
  int cols = A->m;

  // At least one element so an empty result still gets a non-NULL buffer.
  double *res = malloc((rows > 0 ? (size_t)rows : 1) * sizeof *res);
  struct vector *out = malloc(sizeof *out);
  if (res == NULL || out == NULL) {
    free(res);
    free(out);
    return NULL;
  }

  // Rows are independent: each iteration computes one complete dot product
  // into its own res[i], so no reduction over the result array is needed.
#pragma omp parallel for
  for (int i = 0; i < rows; i++) {
    double sum = 0;
    for (int j = 0; j < cols; j++) {
      sum += A->elements[(size_t)i * (size_t)cols + (size_t)j] * x->elements[j];
    }
    res[i] = sum;
  }

  out->elements = res;
  out->n = rows;
  return out;
}

/*
 * Transposed product A^T * x for a dense n-by-m matrix A and a length-n
 * vector x.  Returns a newly allocated length-m vector (release with
 * free_vector()), or NULL if the dimensions do not match or allocation
 * fails.
 */
struct vector* mul_mTv(struct matrix_dense* A, struct vector* x) {
  if (x->n != A->n) {
    return NULL;
  }
  int rows = A->n;
  int cols = A->m;

  // calloc zero-fills the accumulator (all-bits-zero is 0.0 under IEEE 754)
  // and allocates at least one element so an empty result is non-NULL.
  double *res = calloc(cols > 0 ? (size_t)cols : 1, sizeof *res);
  struct vector *out = malloc(sizeof *out);
  if (res == NULL || out == NULL) {
    free(res);
    free(out);
    return NULL;
  }

  // Walk A row by row (its storage order) for spatial locality, adding
  // row i scaled by x[i] into res.  Different rows update the same res[j],
  // so the array is reduced across threads.
#pragma omp parallel for collapse(2) reduction(+:res[:cols])
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < cols; j++) {
      res[j] += A->elements[(size_t)i * (size_t)cols + (size_t)j] * x->elements[i];
    }
  }

  out->elements = res;
  out->n = cols;
  return out;
}

/*
 * Sparse matrix-vector product A * x for an n-by-m CSR matrix A and a
 * length-m vector x.  Returns a newly allocated length-n vector (release
 * with free_vector()), or NULL if the dimensions do not match or
 * allocation fails.
 */
struct vector* mul_spmv(struct matrix_csr* A, struct vector* x) {
  if (x->n != A->m) {
    return NULL;
  }
  int rows = A->n;

  // At least one element so an empty result still gets a non-NULL buffer.
  double *y = malloc((rows > 0 ? (size_t)rows : 1) * sizeof *y);
  struct vector *out = malloc(sizeof *out);
  if (y == NULL || out == NULL) {
    free(y);
    free(out);
    return NULL;
  }

  // Each row writes only its own y[i], so rows can run in parallel.
#pragma omp parallel for
  for (int i = 0; i < rows; i++) {
    int start = A->row_start[i];
    // row_start holds n offsets; the last row ends at nnz.
    int end = (i + 1 < rows) ? A->row_start[i + 1] : A->nnz;
    double sum = 0;
    for (int k = start; k < end; k++) {
      sum += A->v[k] * x->elements[A->v_col[k]];
    }
    y[i] = sum;
  }

  out->elements = y;
  out->n = rows;
  return out;
}


/*
 * Transposed sparse product A^T * x for an n-by-m CSR matrix A and a
 * length-n vector x.  Returns a newly allocated length-m vector (release
 * with free_vector()), or NULL if the dimensions do not match or
 * allocation fails.
 */
struct vector* mul_spmTv(struct matrix_csr* A, struct vector* x) {
  if (x->n != A->n) {
    return NULL;
  }
  int rows = A->n;
  int cols = A->m;

  // calloc zero-fills the accumulator (all-bits-zero is 0.0 under IEEE 754)
  // and allocates at least one element so an empty result is non-NULL.
  double *y = calloc(cols > 0 ? (size_t)cols : 1, sizeof *y);
  struct vector *out = malloc(sizeof *out);
  if (y == NULL || out == NULL) {
    free(y);
    free(out);
    return NULL;
  }

  // Entries of different rows can hit the same column y[j], so the result
  // array is reduced across threads to avoid a data race.
#pragma omp parallel for reduction(+:y[:cols])
  for (int i = 0; i < rows; i++) {
    int start = A->row_start[i];
    // row_start holds n offsets; the last row ends at nnz.
    int end = (i + 1 < rows) ? A->row_start[i + 1] : A->nnz;
    for (int k = start; k < end; k++) {
      y[A->v_col[k]] += A->v[k] * x->elements[i];
    }
  }

  out->elements = y;
  out->n = cols;
  return out;
}
