# EvoDevo_script.R

# Manuscript:
# Cardoso-Moreira, M., Halbert, J., Valloton, D., Velten, B., Chen, C., Shao, Y., ... & 
# Mazin, P. V. (2019). Gene expression across mammalian organ development. Nature, 571(7766), 505-509.

# Created by Ruby Sharma
# Modified by Sajal Kumar
# Copyright (c) NMSU Song lab

# A script to build co-expression and differential co-expression networks for:
# i) Human Ectoderm vs Mouse Ectoderm
# ii) Human Primitive Streak vs Mouse Primitive Streak
# iii) Human Ectoderm vs Human Primitive Streak
# iv) Mouse Ectoderm vs Mouse Primitive Streak

# We start from normalized Human and Mouse expression matrices in which genes
# have already been matched to their cognates and outliers have been removed.
# These pre-processing steps are described in the manuscript.

# Work relative to this script's folder so that the "../Data" paths used in
# PrepareData resolve. rstudioapi can only report the script path inside
# RStudio, so outside it (e.g. Rscript) the current directory is kept.
if (requireNamespace("rstudioapi", quietly = TRUE) && rstudioapi::isAvailable()) {
  setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
}

# library() stops immediately when a package is missing; require() would only
# return FALSE and the script would fail later with "could not find function"
library(stringr)
library(DiffXTablesCoexpnet)

# default number of parallel threads for the analysis
nthreads = 10

# Prepare the data for the co-expression and differential co-expression
# experiments.
#
# Loads the normalized human and mouse matrices (samples on rows, genes on
# columns), drops genes that FilterGenesByMAD flags in either species, splits
# the samples into ectoderm and primitive-streak subsets by tissue, and finally
# drops genes whose variance is zero in any of the four subsets.
#
# nthreads: number of parallel threads passed to FilterGenesByMAD.
# returns: list with hprim, mprim, hecto, mecto (sample x gene subsets) and
#   gnames_human, gnames_mouse (gene names matching their columns).
PrepareData = function(nthreads){

  # load() creates `human` and `mouse` in this function's environment
  load(file.path("..", "Data", "EDHuman.RData"))
  load(file.path("..", "Data", "EDMouse.RData"))

  gnames_human = colnames(human)
  gnames_mouse = colnames(mouse)

  # FilterGenesByMAD is used here as returning TRUE for genes to remove (low
  # MAD). A single shared mask keeps both species aligned, assuming column j
  # is the same cognate gene in both matrices.
  rem_human = FilterGenesByMAD(human, nthreads)
  rem_mouse = FilterGenesByMAD(mouse, nthreads)
  keep = !(rem_mouse | rem_human)

  # drop = FALSE keeps a matrix even when a single gene survives, so that
  # rownames() (needed below to identify tissues) is still available
  human = human[ , keep, drop = FALSE]
  mouse = mouse[ , keep, drop = FALSE]
  gnames_human = gnames_human[keep]
  gnames_mouse = gnames_mouse[keep]

  # tissue names for ectoderm and primitive streak
  ectoderm = c("Brain", "Cerebellum")
  primstreak = c("Heart", "Kidney", "Liver", "Ovary", "Testis")

  # tissue of each sample: its row name up to the first "."
  tissue_of = function(x) gsub("\\..*", "", rownames(x))

  # subsets of the data for the four experiments
  hprim = human[tissue_of(human) %in% primstreak, , drop = FALSE]
  mprim = mouse[tissue_of(mouse) %in% primstreak, , drop = FALSE]
  hecto = human[tissue_of(human) %in% ectoderm, , drop = FALSE]
  mecto = mouse[tissue_of(mouse) %in% ectoderm, , drop = FALSE]

  # remove genes with zero variance in any of the four subsets
  is_zero_var = function(x) apply(x, 2, var) == 0
  zero_var_genes = is_zero_var(hprim) | is_zero_var(mprim) |
    is_zero_var(hecto) | is_zero_var(mecto)
  keep = !zero_var_genes

  hprim = hprim[ , keep, drop = FALSE]
  mprim = mprim[ , keep, drop = FALSE]
  hecto = hecto[ , keep, drop = FALSE]
  mecto = mecto[ , keep, drop = FALSE]
  gnames_human = gnames_human[keep]
  gnames_mouse = gnames_mouse[keep]

  # return the four datasets and the matching gene names
  return(list(hprim = hprim, mprim = mprim, hecto = hecto, mecto = mecto,
              gnames_human = gnames_human, gnames_mouse = gnames_mouse))
}

# Run the full analysis.
#
# Builds the four co-expression networks and the four differential
# co-expression networks. Derives an effect-size threshold from the
# distribution simulated by SharmaSongEsDist. Then calls Compileresults to
# report, per comparison, the fraction of interactions that pass both
# thresholds.
#
# nthreads: number of parallel threads.
# pthres: p-value cutoff handed to Compileresults (default 0.05).
# es_quantile: quantile of the simulated effect sizes used as the effect-size
#   cutoff (default 0.6, i.e. the 60th percentile).
# returns (invisibly): list with the co-expression networks (coexpnets), the
#   differential co-expression networks (diffcoexpnets), the effect-size
#   threshold (es_thres) and the value returned by Compileresults (compiled).
EvoDevo_script = function(nthreads, pthres = 0.05, es_quantile = 0.6){

  stopifnot(es_quantile > 0, es_quantile <= 1)

  # prepare data
  prep_data = PrepareData(nthreads)
  hecto = prep_data$hecto
  mecto = prep_data$mecto
  mprim = prep_data$mprim
  hprim = prep_data$hprim

  gnames_human = prep_data$gnames_human
  gnames_mouse = prep_data$gnames_mouse

  # every comparison below contrasts exactly two conditions
  n_conditions = 2

  # co-expression networks
  human_ecto_coexpnet = GetCoexpnet(hecto, gnames_human, nthreads)
  mouse_ecto_coexpnet = GetCoexpnet(mecto, gnames_mouse, nthreads)
  human_prim_coexpnet = GetCoexpnet(hprim, gnames_human, nthreads)
  mouse_prim_coexpnet = GetCoexpnet(mprim, gnames_mouse, nthreads)

  # differential co-expression networks; the two groups' co-expression
  # networks are passed on to select the interactions to test
  human_mouse_ecto_diffcoexpnet = Exp1_vs_Exp2(hecto, gnames_human, mecto, nthreads,
                                               list(human_ecto_coexpnet, mouse_ecto_coexpnet))
  human_ecto_prim_diffcoexpnet = Exp1_vs_Exp2(hecto, gnames_human, hprim, nthreads,
                                              list(human_ecto_coexpnet, human_prim_coexpnet))
  mouse_ecto_prim_diffcoexpnet = Exp1_vs_Exp2(mecto, gnames_mouse, mprim, nthreads,
                                              list(mouse_ecto_coexpnet, mouse_prim_coexpnet))
  human_mouse_prim_diffcoexpnet = Exp1_vs_Exp2(hprim, gnames_human, mprim, nthreads,
                                               list(human_prim_coexpnet, mouse_prim_coexpnet))

  # effect-size threshold: largest combined sample size over the four
  # comparisons performed above
  max_sm = max(nrow(hecto) + nrow(mecto), nrow(hprim) + nrow(mprim),
               nrow(hecto) + nrow(hprim), nrow(mecto) + nrow(mprim))
  d = max(2, ceiling(sqrt(max_sm / 5) / n_conditions))

  # simulated effect-size distribution
  sim_shsong_res = SharmaSongEsDist(N = c(100, 100), n_conditions = n_conditions,
                                    nthreads = nthreads, d = d)

  # Empirical es_quantile of column 4 of the simulation.
  # NOTE(review): column 4 is assumed to hold the simulated effect sizes;
  # confirm against SharmaSongEsDist.
  # sort() drops NAs, so the rank uses the sorted vector's own length;
  # max(1, ...) keeps the index valid for small quantiles.
  sim_es = sort(sim_shsong_res[, 4])
  es_thres = sim_es[max(1, round(es_quantile * length(sim_es)))]

  # compile percentages
  compiled = Compileresults(human_ecto_prim_diffcoexpnet, mouse_ecto_prim_diffcoexpnet,
                            human_mouse_ecto_diffcoexpnet, human_mouse_prim_diffcoexpnet,
                            pthres, es_thres)

  invisible(list(
    coexpnets = list(human_ecto = human_ecto_coexpnet,
                     mouse_ecto = mouse_ecto_coexpnet,
                     human_prim = human_prim_coexpnet,
                     mouse_prim = mouse_prim_coexpnet),
    diffcoexpnets = list(human_mouse_ecto = human_mouse_ecto_diffcoexpnet,
                         human_ecto_prim = human_ecto_prim_diffcoexpnet,
                         mouse_ecto_prim = mouse_ecto_prim_diffcoexpnet,
                         human_mouse_prim = human_mouse_prim_diffcoexpnet),
    es_thres = es_thres,
    compiled = compiled
  ))
}


# Print, for each of the four differential co-expression networks, the fraction
# of interactions with PVALUE < pthres and ESTIMATE > esthres.
#
# hecto_hprim, mecto_mprim, hecto_mecto, hprim_mprim: differential
#   co-expression networks (data frames with PVALUE and ESTIMATE columns).
# pthres: p-value cutoff (strict "<").
# esthres: effect-size cutoff (strict ">").
# returns (invisibly): named numeric vector of the printed fractions, in
#   printing order. Rows whose PVALUE or ESTIMATE is NA count as not passing
#   instead of turning the whole fraction into NA.
Compileresults = function(hecto_hprim, mecto_mprim, hecto_mecto, hprim_mprim,
                          pthres, esthres){

  # fraction of rows of one network passing both thresholds
  frac_passing = function(net) {
    sum(net$PVALUE < pthres & net$ESTIMATE > esthres, na.rm = TRUE) / nrow(net)
  }

  nets = list(hecto_mecto = hecto_mecto, hprim_mprim = hprim_mprim,
              hecto_hprim = hecto_hprim, mecto_mprim = mecto_mprim)
  labels = c("Human vs Mouse Ectoderm: ",
             "Human vs Mouse Primitve Streak: ",
             "Human Ectoderm vs Human Primitve Streak: ",
             "Mouse Ectoderm vs Mouse Primitve Streak: ")

  fractions = vapply(nets, frac_passing, numeric(1))

  for (i in seq_along(fractions)) {
    cat(labels[i], fractions[[i]], "\n")
  }

  invisible(fractions)
}

# Build a co-expression network for one dataset and keep only strong,
# significant interactions.
#
# data: expression data (samples on rows, genes on columns); it is used as
#   both parent and child expression for Build_Coexpnet.
# gnames: gene names used for both parents and children.
# nthreads: number of parallel threads.
# pthres: keep interactions with PVALUE < pthres (default 0.1).
# esthres: keep interactions with ESTIMATE > esthres (default 0.8).
# returns: the filtered network as returned by Build_Coexpnet. Rows with an NA
#   PVALUE or ESTIMATE are dropped rather than turned into all-NA rows.
GetCoexpnet = function(data, gnames, nthreads, pthres = 0.1, esthres = 0.8){

  # build co-expression networks
  coexpnet = Build_Coexpnet(parent_expr = data,
                            child_expr = data,
                            c_names = gnames,
                            p_names = gnames,
                            method = 'univariate',
                            nthreads = nthreads)

  # filter results; a logical NA row index would insert an all-NA row
  keep = coexpnet$PVALUE < pthres & coexpnet$ESTIMATE > esthres
  keep[is.na(keep)] = FALSE
  coexpnet = coexpnet[keep, , drop = FALSE]

  # release memory held by the unfiltered network
  gc()

  return(coexpnet)
}

# Build the differential co-expression network between two datasets.
#
# data1, data2: expression data (samples on rows, genes on columns). Their rows
#   are stacked, and the rows of data1 are labelled condition 1 and those of
#   data2 condition 2.
# gnames: gene names, passed to Build_DiffCoexpnet as g_names.
# nthreads: number of parallel threads.
# coexpnet_list: co-expression networks handed to GetInteractionIndices to
#   select the gene pairs that are tested.
# returns: the result of Build_DiffCoexpnet.
Exp1_vs_Exp2 = function(data1, gnames, data2, nthreads, coexpnet_list){

  # gene pairs to test, taken from the supplied co-expression networks
  pairs_to_test = GetInteractionIndices(coexpnet_list)

  # one combined matrix plus a per-row condition label (1 = data1, 2 = data2)
  combined = rbind(data1, data2)
  row_condition = c(rep(1, nrow(data1)), rep(2, nrow(data2)))

  Build_DiffCoexpnet(exp_matr = combined,
                     n_conditions = 2,
                     conditions = row_condition,
                     indices = pairs_to_test,
                     g_names = gnames,
                     method = 'univariate',
                     nthreads = nthreads)
}