# ******************************************
#
# Code for analysing the selected random forest models
#
# Data (labelled dat) is the extracted feature dataset as will be made available on the Rothamsted Research Repository (https://repository.rothamsted.ac.uk/) under a Creative Commons Attribution 4.0 Licence.
#
# *****************************************

library(randomForest)
library(ggplot2)

# The workspace is intentionally not cleared here: the feature dataset `dat`
# (see header) is expected to be loaded before the Data section below, and
# wiping the global environment would silently remove it.



# Functions ----------
# functions to extract error rates from a RF
getOOB <- function(rf) {
  # Overall error (%) implied by rf$confusion: 100 minus the percentage of
  # cases lying on the diagonal of the square class-by-class block. Columns
  # beyond the first nrow (e.g. randomForest's class.error) are ignored.
  k <- nrow(rf$confusion)
  counts <- rf$confusion[, 1:k]
  accuracy <- sum(diag(counts)) / sum(counts)
  100 - accuracy * 100
}

# true positive
getTP <- function(rf) {
  # Per-class true positives: the diagonal of the square class-by-class block
  # of rf$confusion (extra columns such as class.error are excluded). diag()
  # carries the class names over when row and column names agree.
  conf <- rf$confusion
  k <- nrow(conf)
  c(diag(conf[, 1:k]))
}

# true negative
getTN <- function(rf) {
  # Per-class true negatives in one-vs-rest terms: for class i, the count of
  # cells outside both row i and column i of the square class-by-class block.
  # Equivalent to the grand total minus row i and column i, adding back the
  # diagonal cell that was subtracted twice. Returned as an unnamed vector.
  n <- nrow(rf$confusion)
  conf <- rf$confusion[1:n, 1:n]
  unname(sum(conf) - rowSums(conf) - colSums(conf) + diag(conf))
}

# false negative
getFN <- function(rf) {
  # Per-class false negatives: the off-diagonal total of each row of the
  # square class-by-class block (row sum minus the diagonal cell).
  # Returned as an unnamed vector.
  n <- nrow(rf$confusion)
  conf <- rf$confusion[1:n, 1:n]
  unname(rowSums(conf) - diag(conf))
}

# false positive
getFP <- function(rf) {
  # Per-class false positives: the off-diagonal total of each column of the
  # square class-by-class block (column sum minus the diagonal cell).
  # Returned as an unnamed vector.
  n <- nrow(rf$confusion)
  conf <- rf$confusion[1:n, 1:n]
  unname(colSums(conf) - diag(conf))
}



# Data ------------


# read in data from doi
# The extracted feature dataset (see header) must already be loaded as a data
# frame called `dat`; stop early with a clear message if it is not.
if (!exists("dat")) {
  stop("`dat` not found: load the extracted feature dataset before running this section.",
       call. = FALSE)
}
head(dat)

# create common name for each species
# Scientific name -> common name; species without an entry keep their
# scientific name.
speciesCommon <- c(
  "Brassicogethes aeneus"     = "Pollen beetle",
  "Psylliodes chrysocephala"  = "Cabbage stem flea beetle",
  "Aphis fabae"               = "Black bean aphid",
  "Drepanosiphum platanoidis" = "Sycamore aphid",
  "Myzus persicae"            = "Peach-potato aphid",
  "Periphyllus testudinaceus" = "Maple aphid",
  "Rhopalosiphum padi"        = "Bird cherry-oat aphid",
  "Sitobion avenae"           = "English grain aphid"
)
dat$species_common <- local({
  common <- as.character(dat$species)
  isMapped <- common %in% names(speciesCommon)
  common[isMapped] <- speciesCommon[common[isMapped]]
  factor(common)
})


# select 52 extracted features
# 25 base features followed by 27 features labelled "[g]" in the plots below
# (the "_g" suffix, plus log_gamAmp and log_fundFreq). The meaning of the "_g"
# variants is not defined in this script.
model1 <- c(
  "log_maxAmp", "log_rangeAmp", "log_iqrAmp", "log_crestFactor", "log_power", "log_rms",
  "sqrt_domfreq",
  "bioacousticIndex", "bioacousticIndex_2", "bioacousticIndex_3", "bioacousticIndex_4",
  "log_amplitudeIndex", "spectralEntropy", "temporalEntropy", "acousticEntropy",
  paste0("harmonics.", 1:10),
  "log_gamAmp",
  "log_maxAmp_g", "log_rangeAmp_g", "log_iqrAmp_g", "log_crestFactor_g", "log_power_g", "log_rms_g",
  "log_fundFreq", "sqrt_domfreq_g",
  "bioacousticIndex_g", "bioacousticIndex_g_2", "bioacousticIndex_g_3", "bioacousticIndex_g_4",
  "log_amplitudeIndex_g", "spectralEntropy_g", "temporalEntropy_g", "acousticEntropy_g",
  paste0("harmonics_g.", 1:10)
)

# selecting the columns also fails loudly if any feature is absent from dat
dim(dat[, model1])


# split the data into training and validation (70/30 random split)
set.seed(6546) # for reproducibility
train <- sample(nrow(dat), 0.7*nrow(dat), replace=FALSE)

dat$set <- ifelse(seq_len(nrow(dat)) %in% train, "training", "validation")
# The two species with too few observations are marked "null" so they fall
# into neither set.
dat$set[dat$species %in% c("Myzus persicae", "Rhopalosiphum padi")] <- "null"
table(dat$set)
table(dat$set, dat$species_common)

# Feature tables and species labels per set; droplevels() removes unused
# factor levels (notably the excluded species).
trainTree <- droplevels(dat[dat$set == "training", model1])
trainSp   <- droplevels(dat[dat$set == "training", "species_common"])

validTree <- droplevels(dat[dat$set == "validation", model1])
validSp   <- droplevels(dat[dat$set == "validation", "species_common"])






# the random forest model ---------------------

# Two forests on the same training data, differing in how missing feature
# values are handled:
#   rfB      - NAs filled by na.roughfix
#   rfB_omit - rows with NAs dropped (na.omit), with smaller per-class sampsize
# sampsize is a positional vector of per-class draws; presumably it follows the
# level order of trainSp -- confirm before changing the class set.

# make it reproducible (rfB_omit is not reseeded, so it also depends on the
# RNG state left behind by fitting rfB)
set.seed(78)
rfB <- randomForest(
  trainSp ~ .,
  data = trainTree,
  importance = TRUE,
  na.action = na.roughfix,
  mtry = 10,
  ntree = 1000,
  sampsize = c(75, 75, 120, 50, 120, 300)
)
# Recorded output: classification; 1000 trees; 10 variables tried per split;
# OOB estimate of  error rate: 20.62%

rfB_omit <- randomForest(
  trainSp ~ .,
  data = trainTree,
  importance = TRUE,
  na.action = na.omit,
  mtry = 10,
  ntree = 1000,
  sampsize = c(60, 60, 96, 40, 96, 240)
)
# Recorded output: classification; 1000 trees; 10 variables tried per split;
# OOB estimate of  error rate: 17.88%


# classification accuracy ------------
# OOB error (%) of each forest; getOOB() derives the number of classes from the
# confusion matrix instead of assuming six.
getOOB(rfB)
getOOB(rfB_omit)


# Checking classification accuracy on validation set
validPred <- predict(rfB, validTree, type="class")
table(validPred, useNA="ifany")
# 244 with NA classification (presumably rows with missing feature values; the
# median-imputation step below addresses them). table() drops NA predictions,
# so the error below is computed over classified rows only.
confval <- table(validSp, validPred)
# diag() only counts correct classifications if rows and columns list the same
# classes in the same order.
stopifnot(identical(rownames(confval), colnames(confval)))
100 - (sum(diag(confval)) / sum(confval)) * 100

# imputing missing values as the median of the training set
# Each NA in a validation column is replaced by that column's training-set
# median. Columns are looped over by name, so the feature count is not
# hard-coded.
validTree_med <- validTree
for (col in names(validTree)) {
  naRows <- is.na(validTree[[col]])
  validTree_med[naRows, col] <- median(trainTree[[col]], na.rm=TRUE)
}
validPred_impute <- predict(rfB, validTree_med, type="class")
table(validPred_impute, useNA="ifany")
confval_impute <- table(validSp, validPred_impute)
# diag() only counts correct classifications if rows and columns align
stopifnot(identical(rownames(confval_impute), colnames(confval_impute)))
100 - (sum(diag(confval_impute)) / sum(confval_impute)) * 100

# accuracy when omitting observations with NAs
# Validation rows that receive an NA prediction are dropped by table(), so the
# error is over classified rows only.
validPred_omit <- predict(rfB_omit, validTree, type="class")
table(validPred_omit, useNA="ifany")
confval_omit <- table(validSp, validPred_omit)
# diag() only counts correct classifications if rows and columns align
stopifnot(identical(rownames(confval_omit), colnames(confval_omit)))
100 - (sum(diag(confval_omit)) / sum(confval_omit)) * 100


# Investigating the misclassifications
# Each forest's OOB confusion matrix (training data) is printed next to its
# confusion table on the validation set for side-by-side inspection.
# 1. na roughfix (rfB; validation predictions made on median-imputed features)
rfB$confusion
confval_impute

# 2. na omit (rfB_omit; validation rows with an NA prediction are not counted)
rfB_omit$confusion
confval_omit


# Error Rates --------------

# class specific one-vs-rest counts from the OOB confusion matrix of rfB
TN <- getTN(rfB)
TP <- getTP(rfB)
FN <- getFN(rfB)
FP <- getFP(rfB)

# get class error rates (off-diagonal share of each confusion-matrix row)
classErr <- 1 - TP / (TP + FN)
# get true positive rate
TPR <- TP / (TP + FN)
# get true negative rate
TNR <- TN/(TN + FP)
# weighted accuracy: equal-weight mean of TNR and TPR
wAcc <- 0.5* TNR + 0.5* TPR

# Species labels are taken from the confusion-matrix rows, which is the order
# the vectors above are in; the class count is derived rather than fixed at 6.
classNames <- rownames(rfB$confusion)
errRates <- data.frame(Species=rep(classNames, 4),
                       error=factor(rep(c("TPR", "TNR", "wAcc", "clErr"), each=length(classNames))),
                       value=c(TPR,TNR,wAcc,classErr))

# NOTE(review): only clErr is an error rate; TPR, TNR and wAcc are accuracy-type
# rates, so the "Error rate" axis label is misleading for them.
ggplot(errRates, aes(x=Species, y=value, colour=error, group=error)) +
  geom_point() +
  geom_line() +
  labs(x="", y="Error rate") +
  guides(colour=guide_legend(title="Error type")) +
  theme(axis.text.x=element_text(angle=45, hjust=1),
        panel.background=element_rect(fill="white",colour="black"))








# Classification Certainty ---------


# the probability of classifying each observation to a particular class,
# predicted from the median-imputed validation features
predValid_prob <- predict(object = rfB, newdata = validTree_med, type = "prob")
head(predValid_prob)
# probabilities alongside the (non-imputed) validation features
outValid <- data.frame(predValid_prob, validTree)

# highest class probability per validation case, overall and by true species
maxProbs <- apply(X = predValid_prob, MARGIN = 1, FUN = max)
hist(maxProbs)
boxplot(maxProbs ~ validSp)

returnProb <- function(ind){
  # Probability the forest assigned to the TRUE species of validation case
  # `ind`. Reads the globals validSp (true species per case) and
  # predValid_prob (class-probability matrix with one named column per class).
  # The species is converted to character so the column is matched by name:
  # indexing with a factor uses its integer code, which picks the right column
  # only if validSp's levels happen to match the column order.
  sp <- as.character(validSp[ind])
  predValid_prob[ind, sp]
}


predValid <- predict(rfB, validTree_med, type="class")
# probability assigned to each case's true species; vapply guarantees exactly
# one numeric value per validation case
trueProbs <- vapply(seq_len(nrow(validTree_med)), returnProb, numeric(1))

# Compare labels as character: `==` between two factors errors when their
# level sets differ.
correct <- as.character(validSp) == as.character(predValid)
df <- data.frame(Probability=trueProbs, Species = validSp, correct = factor(correct), maxProbs)

ggplot(df, aes(x=Species, y=Probability, fill=Species)) + geom_boxplot() + guides(fill="none") +
  theme(axis.text.x=element_text(angle=45, hjust=1), panel.background=element_rect(fill="white",colour="black")) +
  labs(x="", y="Class Probability")
ggplot(df, aes(x=Species, y=Probability, fill=correct)) + geom_boxplot()
ggplot(df, aes(x=Species, y=maxProbs, fill=correct)) + geom_boxplot() +
  guides(fill=guide_legend(title="Classification")) +
  theme(axis.text.x=element_text(angle=45, hjust=1), panel.background=element_rect(fill="white",colour="black")) +
  labs(x="", y="Class Probability")



# Feature Importance -------------------

# Human-readable display names for the importance plots, in the row order of
# rfB$importance. Features without an entry keep their original name.
labs <- local({
  displayNames <- c(
    "log_maxAmp"           = "Maximum amplitude",
    "log_rangeAmp"         = "Amplitude range",
    "log_iqrAmp"           = "Amplitude IQR",
    "log_power"            = "Power",
    "log_rms"              = "RMS",
    "log_crestFactor"      = "Crest factor",
    "log_amplitudeIndex"   = "Amplitude index",
    "temporalEntropy"      = "Temporal entropy",
    "bioacousticIndex"     = "Bioacoustics index (1)",
    "bioacousticIndex_2"   = "Bioacoustics index (2)",
    "bioacousticIndex_3"   = "Bioacoustics index (3)",
    "bioacousticIndex_4"   = "Bioacoustics index (4)",
    "spectralEntropy"      = "Spectral entropy",
    "acousticEntropy"      = "Acoustic entropy",
    "sqrt_domfreq"         = "Dominant frequency",
    "harmonics.1"          = "1st harmonic",
    "harmonics.2"          = "2nd harmonic",
    "harmonics.3"          = "3rd harmonic",
    "harmonics.4"          = "4th harmonic",
    "harmonics.5"          = "5th harmonic",
    "harmonics.6"          = "6th harmonic",
    "harmonics.7"          = "7th harmonic",
    "harmonics.8"          = "8th harmonic",
    "harmonics.9"          = "9th harmonic",
    "harmonics.10"         = "10th harmonic",
    "log_gamAmp"           = "GAM amplitude range [g]",
    "log_maxAmp_g"         = "Maximum amplitude [g]",
    "log_rangeAmp_g"       = "Amplitude range [g]",
    "log_iqrAmp_g"         = "Amplitude IQR [g]",
    "log_power_g"          = "Power [g]",
    "log_rms_g"            = "RMS [g]",
    "log_crestFactor_g"    = "Crest factor [g]",
    "log_amplitudeIndex_g" = "Amplitude index [g]",
    "temporalEntropy_g"    = "Temporal entropy [g]",
    "bioacousticIndex_g"   = "Bioacoustics index (1) [g]",
    "bioacousticIndex_g_2" = "Bioacoustics index (2) [g]",
    "bioacousticIndex_g_3" = "Bioacoustics index (3) [g]",
    "bioacousticIndex_g_4" = "Bioacoustics index (4) [g]",
    "spectralEntropy_g"    = "Spectral entropy [g]",
    "acousticEntropy_g"    = "Acoustic entropy [g]",
    "sqrt_domfreq_g"       = "Dominant frequency [g]",
    "log_fundFreq"         = "Fundamental frequency [g]",
    "harmonics_g.1"        = "1st harmonic [g]",
    "harmonics_g.2"        = "2nd harmonic [g]",
    "harmonics_g.3"        = "3rd harmonic [g]",
    "harmonics_g.4"        = "4th harmonic [g]",
    "harmonics_g.5"        = "5th harmonic [g]",
    "harmonics_g.6"        = "6th harmonic [g]",
    "harmonics_g.7"        = "7th harmonic [g]",
    "harmonics_g.8"        = "8th harmonic [g]",
    "harmonics_g.9"        = "9th harmonic [g]",
    "harmonics_g.10"       = "10th harmonic [g]"
  )
  featureNames <- row.names(rfB$importance)
  known <- featureNames %in% names(displayNames)
  featureNames[known] <- displayNames[featureNames[known]]
  featureNames
})


# Base-graphics barplots of rfB's overall feature importance. `labs` is
# indexed with the same `ind` as the importance rows, so bars and names stay
# paired. par() changes persist for the rest of the session.

# Horizontal bars, ascending order (largest importance drawn at the top);
# wide left margin for the feature names.
par(mar=c(5,10,4,4), mfrow=c(1, 1))
ind <- order(rfB$importance[,"MeanDecreaseAccuracy"])
barplot(rfB$importance[ind,"MeanDecreaseAccuracy"], horiz=TRUE, xlab="Mean decrease in accuracy", names.arg=labs[ind],las=2, cex.names=0.75, col="tomato")
box()
ind <- order(rfB$importance[,"MeanDecreaseGini"])
barplot(rfB$importance[ind,"MeanDecreaseGini"], horiz=TRUE, xlab="Mean decrease in Gini index", names.arg=labs[ind],las=2, col="skyblue", cex.names=0.75)
box()

# Vertical bars, descending order (rev(order(...)), so tied values appear in
# reverse of their original order); wide bottom margin for rotated names.
par(mar=c(10, 5,2,2))
ind <- rev(order(rfB$importance[,"MeanDecreaseAccuracy"]))
barplot(rfB$importance[ind,"MeanDecreaseAccuracy"], horiz=FALSE, ylab="Mean decrease in accuracy", names.arg=labs[ind],las=2, cex.names=0.9, col="tomato")
box()
ind <- rev(order(rfB$importance[,"MeanDecreaseGini"]))
barplot(rfB$importance[ind,"MeanDecreaseGini"], horiz=FALSE, ylab="Mean decrease in Gini index", names.arg=labs[ind],las=2, col="skyblue", cex.names=0.9)
box()


# only extract the top 10 features by each measure (same descending order)
ind <- rev(order(rfB$importance[,"MeanDecreaseAccuracy"]))[1:10]
barplot(rfB$importance[ind,"MeanDecreaseAccuracy"], horiz=FALSE, ylab="Mean decrease in accuracy", names.arg=labs[ind],las=2, cex.names=0.9, col="tomato")
box()
ind <- rev(order(rfB$importance[,"MeanDecreaseGini"]))[1:10]
barplot(rfB$importance[ind,"MeanDecreaseGini"], horiz=FALSE, ylab="Mean decrease in Gini index", names.arg=labs[ind],las=2, col="skyblue", cex.names=0.9)
box()


rfImportance <- as.data.frame(rfB$importance)
rfImportance$labs <- labs
# Per-class importance columns are every column except the two overall
# summaries. Select them by name rather than by position, so the reshape does
# not depend on the number of classes.
summaryCols <- c("MeanDecreaseAccuracy", "MeanDecreaseGini")
speciesCols <- setdiff(colnames(rfB$importance), summaryCols)
# reshape to long format: one row per (feature, species)
rfImportance_long <- reshape(rfImportance, direction="long",
                             varying=list(speciesCols), timevar="Species",
                             times=speciesCols, v.names="Importance",
                             drop=summaryCols)
row.names(rfImportance_long) <- NULL
# Fixed display order of features on the x axis (base features, then [g]).
rfImportance_long$labs <- factor(rfImportance_long$labs,
                                 levels=c("Maximum amplitude", "Amplitude range", "Amplitude IQR", "Power", "RMS", "Crest factor", "Amplitude index", "Temporal entropy",
                                          "Bioacoustics index (1)", "Bioacoustics index (2)", "Bioacoustics index (3)", "Bioacoustics index (4)", "Spectral entropy", "Acoustic entropy",
                                          "Dominant frequency", "1st harmonic", "2nd harmonic", "3rd harmonic", "4th harmonic", "5th harmonic", "6th harmonic", "7th harmonic", "8th harmonic", "9th harmonic", "10th harmonic",
                                          "GAM amplitude range [g]", "Maximum amplitude [g]", "Amplitude range [g]", "Amplitude IQR [g]", "Power [g]", "RMS [g]", "Crest factor [g]", "Amplitude index [g]", "Temporal entropy [g]",
                                          "Bioacoustics index (1) [g]", "Bioacoustics index (2) [g]", "Bioacoustics index (3) [g]", "Bioacoustics index (4) [g]", "Spectral entropy [g]", "Acoustic entropy [g]",
                                          "Dominant frequency [g]", "Fundamental frequency [g]", "1st harmonic [g]", "2nd harmonic [g]", "3rd harmonic [g]", "4th harmonic [g]", "5th harmonic [g]", "6th harmonic [g]", "7th harmonic [g]", "8th harmonic [g]", "9th harmonic [g]", "10th harmonic [g]"))
# a label missing from the level list above would silently become NA
stopifnot(!anyNA(rfImportance_long$labs))
rfImportance_long$Insect <- factor(ifelse(rfImportance_long$Species %in% c("Pollen beetle", "Cabbage stem flea beetle"), "Beetle", "Aphid"))


# Per-feature, per-species importance. Species is mapped to the fill of
# fillable point shapes (21 = circle, 24 = triangle) and insect type to shape.
# The fill legend is given shape 21 in its keys; otherwise the keys use the
# default solid point, which cannot show a fill colour.
ggplot(rfImportance_long, aes(x=labs, y=Importance, fill=Species, shape=Insect)) +
  geom_point(alpha=0.8, colour="black", size=2) +
  labs(x="") +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust=0.3),
        panel.background = element_rect(fill="white",colour="black")) +
  scale_shape_manual(values=c(21, 24)) +
  guides(fill=guide_legend("Species", override.aes=list(shape=21))) +
  theme(legend.text=element_text(face="italic"))

# Alternative: species mapped to point colour. theme_bw() is a complete theme
# and resets any theme() settings added before it, so it comes first and the
# rotated x-axis labels are applied afterwards.
ggplot(rfImportance_long, aes(x=labs, y=(Importance), colour=Species, shape=Insect)) +
  geom_point(size=2) +
  labs(x="") +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust=0.3)) +
  scale_shape_manual(values=c(21, 24)) +
  guides(colour=guide_legend("Species")) +
  theme(legend.text=element_text(face="italic"))



