Gene homology Part 3 - Visualizing Gene Ontology of Conserved Genes

2017-01-05 00:00:00 +0000

Which genes have homologs in many species?

In Part 1 and Part 2 I have already explored gene homology between humans and other species. But there, I only considered how many genes were shared between the species.

In this post I want to have a closer look at genes with homologs in many species. These genes are likely evolutionarily conserved in their biological function. To find out what biological functions these genes have, I am using gene ontology (GO) enrichment analysis.

Because GO categories have a hierarchical parent-child structure, comparing enriched GO terms between groups of genes in a meaningful way isn’t trivial. Simply calculating the overlap between categories misses the bigger picture when, for example, two groups are enriched for closely related but not exactly the same GO categories. Here, I will explore a few options for visualizing GO enrichment with the purpose of comparing enrichments between groups (here: species).

A network of the partial hierarchical structure of enriched biological-process GO terms captures how the enriched terms are related to each other. Custom annotations highlight GO terms that are common to all species, as well as terms that are unique to individual species.

Highly conserved biological functions include core regulatory processes that are essential for the cell cycle, energy production and metabolism, nucleoside biogenesis, protein metabolism, etc.



I’m starting with the same table as from this post, called homologs_table_combined.

head(homologs_table_combined)
##   celegans_gene_ensembl cfamiliaris_gene_ensembl
## 1                    NA                       NA
## 2                    NA                   477699
## 3                    NA                   477023
## 4                173979                   490207
## 5                179795                   607852
## 6                177055                       NA
##   dmelanogaster_gene_ensembl drerio_gene_ensembl ggallus_gene_ensembl
## 1                         NA                  NA                   NA
## 2                      44071                  NA                   NA
## 3                         NA              431754               770094
## 4                      38864              406283                   NA
## 5                      36760              794259               395373
## 6                    3772179              541489               421909
##   hsapiens_gene_ensembl mmusculus_gene_ensembl ptroglodytes_gene_ensembl
## 1                    NA                     NA                        NA
## 2                     2                 232345                    465372
## 3                    30                 235674                    460268
## 4                    34                  11364                    469356
## 5                    47                 104112                    454672
## 6                    52                  11431                    458990
##   rnorvegicus_gene_ensembl scerevisiae_gene_ensembl sscrofa_gene_ensembl
## 1                    24152                       NA                   NA
## 2                    24153                       NA               403166
## 3                    24157                   854646            100515577
## 4                    24158                       NA               397104
## 5                    24159                   854310                   NA
## 6                    24161                   856187            100737301

Each row in this table denotes a gene with its Entrez ID and the corresponding Entrez IDs of its homologs in each of the other 10 species I explored. If a gene doesn’t have a homolog, the table says “NA”. However, some genes have duplicate entries for a species when there are multiple homologs.

By counting the number of NAs per row, we can identify genes with homologs in all species (sum of NAs = 0) and genes which are specific (sum of NAs = 10).

homologs_na <- rowSums(is.na(homologs_table_combined))

Before I delve deeper into the biology behind these genes, I want to examine the distribution of the NA-counts. To do so, I am plotting a histogram.

But first I’m preparing my custom ggplot2 theme:

library(ggplot2)
my_theme <- function(base_size = 12, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
  theme(
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 14),
    panel.grid.major = element_line(color = "grey"),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "aliceblue"),
    strip.background = element_rect(fill = "royalblue", color = "grey", size = 1),
    strip.text = element_text(face = "bold", size = 12, color = "white"),
    legend.position = "bottom",
    legend.justification = "top", 
    legend.box = "horizontal",
    legend.box.background = element_rect(colour = "grey50"),
    legend.background = element_blank(),
    panel.border = element_rect(color = "grey", fill = NA, size = 0.5)
  )
}

theme_set(my_theme())
ggplot() + 
  aes(homologs_na) + 
  geom_histogram(binwidth = 1, color = "royalblue", fill = "royalblue", alpha = 0.8) +
  labs(
    x = "Number of NAs per row of gene homology table",
    y = "Count",
    title = "How many homologs do genes have?",
    subtitle = "Showing all bins from 0 to 10",
    caption = "\nEach row of the gene homology table lists a gene and all its homologs in 10 other species.
    If a gene doesn't have a homolog, its value in the table is NA. Thus, rows with many NAs refer to a
    gene that is specific to one species, while rows with no NAs show genes with homologs in all species -
    such genes are said to be highly conserved."
  )


Clearly, most genes are specific to a species: they have NAs in all but one column. The rest of the histogram is hard to make out next to this peak at 10, so let’s look at the same data again without these genes:


ggplot() + 
  aes(subset(homologs_na, homologs_na  < 10)) + 
  geom_histogram(binwidth = 1, color = "royalblue", fill = "royalblue", alpha = 0.8) +
  labs(
    x = "Number of NAs per row of gene homology table",
    y = "Count",
    title = "How many homologs do genes have?",
    subtitle = "Showing bins from 0 to 9",
    caption = "\nEach row of the gene homology table lists a gene and all its homologs in 10 other species.
    If a gene doesn't have a homolog, its value in the table is NA. Thus, rows with many NAs refer to a
    gene that is specific to one species, while rows with no NAs show genes with homologs in all species -
    such genes are said to be highly conserved."
  )



Now we can see that most of the remaining genes have homologs in 9 species (2 NAs). But there are still quite a few genes with homologs in all species.
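For the exact numbers behind both histograms, a simple tabulation of the NA counts works as well (a minimal sketch; the counts depend on the homology table):

# tabulate how many genes fall into each NA count
# (0 = homologs in all species, 10 = specific to a single species)
table(homologs_na)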


Which genes are highly conserved and what biological functions do they have?

There are 3461 rows in the original table with no NAs.

genes_homologs_all <- homologs_table_combined[which(rowSums(is.na(homologs_table_combined)) == 0), ]
nrow(genes_homologs_all)
## [1] 3461

Looking at all of these genes by hand wouldn’t be feasible. So, to find out what biological functions these genes have, I am using gene ontology (GO) enrichment analysis as implemented in clusterProfiler. GO terms are grouped into three ontologies:

  • Biological processes (BP): Collections of molecular events and functions contributing to a biological process
  • Cellular component (CC): (sub-) cellular structures and locations, macromolecular complexes
  • Molecular functions (MF): Molecular functions like catalytic or binding activity
library(clusterProfiler)
library(DOSE)

# run GO enrichment (BP, MF, CC) for each species, using all genes of that species
# in the homology table as the universe
for (i in 1:nrow(datasets)) {
  species <- datasets$dataset[i]
  genes <- as.character(unique(genes_homologs_all[, species]))
  universe <- as.character(unique(na.omit(homologs_table_combined[, species])))
  
  cat("\nSpecies", datasets$description[i], "has", length(universe), "unique Entrez IDs, of which", length(genes), "have homologs in all species.\n")
  
  try(go_enrich_BP <- enrichGO(gene = genes, 
                        keytype = "ENTREZID",
                        OrgDb = get(datasets$orgDb[i]),
                        ont = "BP",
                        qvalueCutoff = 0.05,
                        universe = universe,
                        readable = TRUE))
  
  try(go_enrich_MF <- enrichGO(gene = genes, 
                        keytype = "ENTREZID",
                        OrgDb = get(datasets$orgDb[i]),
                        ont = "MF",
                        qvalueCutoff = 0.05,
                        universe = universe,
                        readable = TRUE))
  
  try(go_enrich_CC <- enrichGO(gene = genes, 
                        keytype = "ENTREZID",
                        OrgDb = get(datasets$orgDb[i]),
                        ont = "CC",
                        qvalueCutoff = 0.05,
                        universe = universe,
                        readable = TRUE))
  
  try(assign(paste("go_enrich_BP", species, sep = "_"), go_enrich_BP))
  try(assign(paste("go_enrich_MF", species, sep = "_"), go_enrich_MF))
  try(assign(paste("go_enrich_CC", species, sep = "_"), go_enrich_CC))
}
## 
## Species Rat has 47989 unique Entrez IDs, of which 1796 have homologs in all species.
## 
## Species Yeast has 6350 unique Entrez IDs, of which 1363 have homologs in all species.
## 
## Error in .testForValidCols(x, cols) : 
##   Invalid columns: SYMBOL. Please use the columns method to see a listing of valid arguments.
## Error in .testForValidCols(x, cols) : 
##   Invalid columns: SYMBOL. Please use the columns method to see a listing of valid arguments.
## Error in .testForValidCols(x, cols) : 
##   Invalid columns: SYMBOL. Please use the columns method to see a listing of valid arguments.
## 
## Species C. elegans has 46727 unique Entrez IDs, of which 1758 have homologs in all species.
## 
## Species Chimpanzee has 40540 unique Entrez IDs, of which 1704 have homologs in all species.
## 
## Species Pig has 54622 unique Entrez IDs, of which 1708 have homologs in all species.
## 
## Species Human has 60136 unique Entrez IDs, of which 1757 have homologs in all species.
## 
## Species Chicken has 34668 unique Entrez IDs, of which 1684 have homologs in all species.
## 
## Species Zebrafish has 47900 unique Entrez IDs, of which 1910 have homologs in all species.
## 
## Species Fruitfly has 25038 unique Entrez IDs, of which 1805 have homologs in all species.
## 
## Species Mouse has 68603 unique Entrez IDs, of which 1761 have homologs in all species.
## 
## Species Dog has 31376 unique Entrez IDs, of which 1716 have homologs in all species.

I’m not sure why yeast throws an error; judging from the message, the yeast annotation package seems to lack a SYMBOL column, which readable = TRUE needs. Since all other species worked, I will ignore yeast genes for now…
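A quick way to check which annotation columns an OrgDb package actually provides (a diagnostic sketch; if SYMBOL is missing from the output, that would explain the error above):

library(org.Sc.sgd.db)

# list the annotation fields and key types the yeast OrgDb offers
columns(org.Sc.sgd.db)
keytypes(org.Sc.sgd.db)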


Visualizing enriched GO categories

A heatmap of gene enrichment ratios shows enriched GO categories and their enrichment strength and allows us to capture comparable information for all species.

To generate this heatmap, I am using the output tables from clusterProfiler. They tell us the enrichment p-value (before and after correction for multiple testing, i.e. the adjusted p-value or False Discovery Rate, FDR) and the proportion of genes belonging to each GO category, both among the background genes and among the genes in our list. To keep the graph reasonably small, I’m subsetting only the top 10 enriched GO terms with the lowest adjusted p-values for biological processes (BP), molecular functions (MF) and cellular components (CC).
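The GeneRatio column in these tables is a character string of the form "k/n" (genes from our list in the category / genes in our list), so the loop below extracts the denominator with a small gsub() call. A minimal sketch with made-up numbers:

# hypothetical GeneRatio entry and its Count, as reported by clusterProfiler
gene_ratio <- "15/1796"
count <- 15

# keep only the denominator (everything after the slash) and divide
count / as.numeric(gsub(".*/", "", gene_ratio))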

# removing yeast
datasets_2 <- datasets[-grep("scerevisiae_gene_ensembl", datasets$dataset), ]

library(dplyr)

cutoff = 10

# for each species, combine the top enriched terms of the three ontologies into one plotting data frame
for (i in 1:nrow(datasets_2)) {
  species <- datasets_2$dataset[i]
  
  # gene ratio (GR) = number of genes in the GO category / number of genes in our list
  go_enrich_BP_df <- as.data.frame(get(paste("go_enrich_BP", species, sep = "_"))) %>%
    mutate(GR = as.numeric(Count) / as.numeric(gsub(".*/", "", GeneRatio)),
           Category = "BP")
  
  plot_BP <- go_enrich_BP_df[1:cutoff, ]
  plot_BP$Description <- factor(plot_BP$Description, levels = plot_BP$Description[order(plot_BP$GR)])
    
  go_enrich_MF_df <- as.data.frame(get(paste("go_enrich_MF", species, sep = "_"))) %>%
    mutate(GR = as.numeric(Count) / as.numeric(gsub(".*/", "", GeneRatio)),
           Category = "MF")
  
  plot_MF <- go_enrich_MF_df[1:cutoff, ]
  plot_MF$Description <- factor(plot_MF$Description, levels = plot_MF$Description[order(plot_MF$GR)])

  go_enrich_CC_df <- as.data.frame(get(paste("go_enrich_CC", species, sep = "_"))) %>%
    mutate(GR = as.numeric(Count) / as.numeric(gsub(".*/", "", GeneRatio)),
           Category = "CC")
  
  plot_CC <- go_enrich_CC_df[1:cutoff, ]
  plot_CC$Description <- factor(plot_CC$Description, levels = plot_CC$Description[order(plot_CC$GR)])
  
  plot <- rbind(plot_BP, plot_MF, plot_CC)
  plot$Species <- paste(datasets_2$description[i])

  if (i == 1) {
    plot_df <- plot
  } else {
    plot_df <- rbind(plot_df, plot)
  }
}

length(unique(plot_df$ID))
## [1] 123
ggplot(plot_df, aes(x = Species, y = Description, fill = GR)) + 
  geom_tile(width = 1, height = 1) +
  scale_fill_gradient2(low = "white", mid = "blue", high = "red", space = "Lab", name = "Gene Ratio") +
  facet_grid(Category ~ ., scales = "free_y") +
  scale_x_discrete(position = "top") +
  labs(
      title = paste("Top", cutoff, "enriched GO terms"),
      x = "",
      y = ""
    )



Word clouds

Another way to capture information about GO terms is to use the GO description texts as the basis for word clouds. I’m using the general method described here. I’ve also included a list of words to remove, like process, activity, etc. These words occur often and on their own don’t tell us much about the GO term.

library(tm)
library(SnowballC)
library(wordcloud)
library(RColorBrewer)

wordcloud_function <- function(data = data,
                               removewords = c("process", "activity", "positive", "negative", "response", "regulation"),
                               min.freq = 4,
                               max.words=Inf,
                               random.order=TRUE){
  input <- Corpus(VectorSource(data))
  
  input <- tm_map(input, content_transformer(tolower))
  input <- tm_map(input, content_transformer(removePunctuation))
  input <- tm_map(input, removeNumbers)
  input <- tm_map(input, stripWhitespace)
  
  toSpace <- content_transformer(function(x , pattern ) gsub(pattern, " ", x))
  input <- tm_map(input, toSpace, "/")
  input <- tm_map(input, toSpace, "@")
  input <- tm_map(input, toSpace, "\\|")
  
  input <- tm_map(input, function(x) removeWords(x, stopwords("english")))
  
  # specify your stopwords as a character vector
  input <- tm_map(input, removeWords, removewords)
  
  tdm <- TermDocumentMatrix(input)
  m <- as.matrix(tdm)
  v <- sort(rowSums(m),decreasing = TRUE)
  d <- data.frame(word = names(v),freq = v)
  
  set.seed(1234)
  wordcloud(words = d$word, freq = d$freq, min.freq = min.freq, scale = c(8,.2),
            max.words = max.words, random.order = random.order, rot.per = 0.15,
            colors = brewer.pal(8, "Dark2"))
}
layout(matrix(c(1:80), nrow = 4, byrow = FALSE), heights = c(0.1, 1))

for (i in 1:nrow(datasets_2)) {
  species <- datasets_2$dataset[i]
  df_BP <- as.data.frame(get(paste("go_enrich_BP", species, sep = "_")))
  df_MF <- as.data.frame(get(paste("go_enrich_MF", species, sep = "_")))
  df_CC <- as.data.frame(get(paste("go_enrich_CC", species, sep = "_")))
  
  par(mar = rep(0, 4))
  plot.new()
  text(x = 0.5, y = 0.5, paste0(datasets_2$description[i]), cex = 2)
  wordcloud_function(data = df_BP$Description)
  wordcloud_function(data = df_MF$Description)
  wordcloud_function(data = df_CC$Description)
}


Which GO terms are enriched in all species?

I also want to know which specific GO categories are enriched in all species. Even though we might miss related terms, this will give us a first idea about the biological functions that are highly conserved between species.

I’m creating a list of enriched GO categories of all species and asking for common elements using the list.common function of the rlist package.
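list.common() simply returns the elements that occur in every member of a list; a toy sketch with made-up GO IDs:

library(rlist)

# only the ID that appears in both vectors is returned
list.common(list(speciesA = c("GO:0000001", "GO:0000002"),
                 speciesB = c("GO:0000002", "GO:0000003")))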

# create empty list to populate with a loop
go_list_BP <- lapply(datasets_2$dataset, function(x) NULL)
names(go_list_BP) <- paste(datasets_2$dataset)

for (species in datasets_2$dataset) {
  df_BP <- as.data.frame(get(paste("go_enrich_BP", species, sep = "_")))
  go_list_BP[[species]] <- df_BP$ID
}

# Now I know which GO IDs are common but I want to know the descriptions as well.
# Because they are in all go_enrich tables, I'm choosing one to subset

library(rlist)
common_gos_BP <- as.data.frame(go_enrich_BP_cfamiliaris_gene_ensembl) %>%
  filter(ID %in% list.common(go_list_BP)) %>%
  select(ID, Description) %>%
  arrange(Description)


go_list_MF <- lapply(datasets_2$dataset, function(x) NULL)
names(go_list_MF) <- paste(datasets_2$dataset)

for (species in datasets_2$dataset) {
  df_MF <- as.data.frame(get(paste("go_enrich_MF", species, sep = "_")))
  go_list_MF[[species]] <- df_MF$ID
}

library(rlist)
common_gos_MF <- as.data.frame(go_enrich_MF_cfamiliaris_gene_ensembl) %>%
  filter(ID %in% list.common(go_list_MF)) %>%
  select(ID, Description) %>%
  arrange(Description)


go_list_CC <- lapply(datasets_2$dataset, function(x) NULL)
names(go_list_CC) <- paste(datasets_2$dataset)

for (species in datasets_2$dataset) {
  df_CC <- as.data.frame(get(paste("go_enrich_CC", species, sep = "_")))
  go_list_CC[[species]] <- df_CC$ID
}

library(rlist)
common_gos_CC <- as.data.frame(go_enrich_CC_cfamiliaris_gene_ensembl) %>%
  filter(ID %in% list.common(go_list_CC)) %>%
  select(ID, Description) %>%
  arrange(Description)

To visualize these common GO categories, I’m going back to the clusterProfiler output tables. This time I’m not using a heatmap, but a point plot to show enrichment statistics for all species.

# same loop as above, but keeping only the GO terms that are enriched in all species
for (i in 1:nrow(datasets_2)) {
  species <- datasets_2$dataset[i]
  
  go_enrich_BP_df <- as.data.frame(get(paste("go_enrich_BP", species, sep = "_"))) %>%
    filter(ID %in% list.common(go_list_BP)) %>%
    mutate(GR = as.numeric(Count) / as.numeric(gsub(".*/", "", GeneRatio)),
           Category = "BP")
  
  go_enrich_BP_df$Description <- factor(go_enrich_BP_df$Description, levels = go_enrich_BP_df$Description[order(go_enrich_BP_df$GR)])
    
  go_enrich_MF_df <- as.data.frame(get(paste("go_enrich_MF", species, sep = "_"))) %>%
    filter(ID %in% list.common(go_list_MF)) %>%
    mutate(GR = as.numeric(Count) / as.numeric(gsub(".*/", "", GeneRatio)),
           Category = "MF")
  
  go_enrich_MF_df$Description <- factor(go_enrich_MF_df$Description, levels = go_enrich_MF_df$Description[order(go_enrich_MF_df$GR)])

  go_enrich_CC_df <- as.data.frame(get(paste("go_enrich_CC", species, sep = "_"))) %>%
    filter(ID %in% list.common(go_list_CC)) %>%
    mutate(GR = as.numeric(Count) / as.numeric(gsub(".*/", "", GeneRatio)),
           Category = "CC")
  
  go_enrich_CC_df$Description <- factor(go_enrich_CC_df$Description, levels = go_enrich_CC_df$Description[order(go_enrich_CC_df$GR)])
  
  plot <- rbind(go_enrich_BP_df, go_enrich_MF_df, go_enrich_CC_df)
  plot$Species <- paste(datasets_2$description[i])

  if (i == 1) {
    plot_df <- plot
  } else {
    plot_df <- rbind(plot_df, plot)
  }
}

ggplot(plot_df, aes(x = GR, y = Description, size = Count, color = p.adjust)) +
    geom_point() +
    labs(
      title = "Common enriched GO terms",
      x = "Gene Ratio",
      y = ""
    ) +
    facet_grid(Category ~ Species, scales = "free_y")


Gene Ratio is shown on the x-axis, point size shows the gene counts and point color shows the adjusted p-value.


GO hierarchical structure

The package RamiGO can be used to plot all parents of the 10 common GO categories:

library(RamiGO)
library(RColorBrewer)
set3 <- brewer.pal(10, "Set3")
goIDs <- list.common(go_list_BP)
getAmigoTree(goIDs, set3, filename = "common_gos_BP", picType = "png", modeType = "amigo", saveResult = TRUE)

set3 <- brewer.pal(8, "Set3")
goIDs <- list.common(go_list_MF)
getAmigoTree(goIDs, set3, filename = "common_gos_MF", picType = "png", modeType = "amigo", saveResult = TRUE)

goIDs <- list.common(go_list_CC)
getAmigoTree(goIDs, set3, filename = "common_gos_CC", picType = "png", modeType = "amigo", saveResult = TRUE)
  • Common GOs BP

  • Common GOs MF

  • Common GOs CC


GO network

To visualize clusters of species-specific and common enriched GO categories within their hierarchical structure, I’m creating a network of the top 15 enriched biological-process categories of each species, together with two levels of parent terms.

for (i in 1:nrow(datasets_2)) {
  species <- datasets_2$dataset[i]
  df_BP <- as.data.frame(get(paste("go_enrich_BP", species, sep = "_")))
  df_BP <- df_BP[1:15, 1, drop = FALSE]
  df_BP$species <- datasets_2$description[i]
  
  if (i == 1) {
    df_BP_all <- df_BP
  } else {
    df_BP_all <- rbind(df_BP_all, df_BP)
  }
}

length(unique(df_BP_all$ID))
## [1] 69

Most of these top GO categories are not unique to individual species.

In the network I want to show GO categories that are species-specific and those that are common among all species. To create the annotation data with this information, I’m building a data frame with the names of the species along with their specific GO terms and combining it with the common GO terms from above.

library(tidyr)
df_BP_all$value <- 1
df_BP_all <- spread(df_BP_all, species, value)

go_sums <- data.frame(ID = df_BP_all$ID,
                      sum = rowSums(!is.na(df_BP_all[, -1])))

# GO in all species
go_anno <- data.frame(ID = go_sums$ID)
go_anno$group <- ifelse(go_sums$sum == 10, "common_all", NA)

x <- subset(go_sums, sum == 1)
xx <- subset(df_BP_all, ID %in% x$ID)

for (i in 2:11) {
  xxx <- xx[, i]
  xxx[which(!is.na(xxx))] <- colnames(xx[, i, drop = FALSE])
  unique <- as.data.frame(cbind(xx[, 1], xxx)) %>%
    subset(!is.na(xxx))
  
  if (i == 2) {
    unique_all <- unique
  } else {
    unique_all <- rbind(unique_all, unique)
  }
}

# Common GOs
unique_common_all <- rbind(unique_all, data.frame(V1 = list.common(go_list_BP), xxx = "common"))

To identify the parent GO terms, I’m using the GO.db package. I’m adding two levels of parental structure for each GO term.
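Each element of GOBPPARENTS is a named character vector of the direct parent GO IDs of one term, with the relation type (is_a, part_of, ...) as names. A quick way to inspect a single entry (GO:0007049, "cell cycle", is used here only as an example):

library(GO.db)

# direct parents of one biological-process term;
# the names of the returned vector give the relation type
GOBPPARENTS[["GO:0007049"]]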

library(GO.db)

# GO parents list
ontology <- as.list(GOBPPARENTS)

goChilds <- unique(df_BP_all$ID)
goChilds <- goChilds[goChilds %in% names(ontology)]

adjList <- matrix(nrow = 0, ncol = 3)
colnames(adjList) <- c("parent", "rel", "child")

# Find parents
goChilds <- goChilds[which(goChilds != "all")]
        
goFamilyList <- lapply(ontology[goChilds], function(x) data.frame(parent = x, rel = names(x)))
adjList <- rbind(adjList, as.matrix(cbind(do.call(rbind, goFamilyList), child = rep(names(goFamilyList), sapply(goFamilyList, nrow)))))

# Next parents
goParents <- unique(as.character(adjList[which(!adjList[,"parent"] %in% adjList[,"child"]), "parent"]))    
goFamilyList_2 <- lapply(ontology[goParents], function(x) data.frame(parent = x, rel = names(x)))
adjList <- rbind(adjList, as.matrix(cbind(do.call(rbind, goFamilyList_2), child = rep(names(goFamilyList_2), sapply(goFamilyList_2, nrow)))))

adjList <- adjList[which(adjList[,1] != "all"), , drop = FALSE]
adjList <- adjList[, c("child","parent", "rel"), drop = FALSE]

The graph is then produced with igraph as a directed network. Colors show species-specific, common and remaining enriched GO categories. Parents are shown in light blue.

library(igraph)

goGraph <- graph.edgelist(adjList[, c("parent", "child"), drop = FALSE], directed = TRUE)
goGraph <- set.edge.attribute(goGraph, "type", value = adjList[,"rel"])

goTerms <- sapply(V(goGraph)$name, function(x) {
  term <- GOTERM[[x]]@Term
  
  # abbreviate common words to keep node labels short
  term <- gsub("( process)$", " pr.", term)
  term <- gsub(" development ", " dev. ", term)
  term <- gsub("( development)$", " dev.", term, perl = TRUE)
  term <- gsub("Development ", "Dev. ", term)
  
  # split long labels over two lines
  spaceLoc <- gregexpr(" ", term, fixed = TRUE)[[1]]
  if ((spaceLoc[1] != -1) && (length(spaceLoc) > 1)) {
    spaceLoc <- spaceLoc[ceiling(length(spaceLoc) / 2)]
    term <- paste(substr(term, 1, spaceLoc - 1), substr(term, spaceLoc + 1, nchar(term)), sep = "\n")
  }
  
  term
})

goTerms_df <- as.data.frame(goTerms)

goTerms_df$color <- ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "C. elegans"), 1], "green",
                           ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Chicken"), 1], "yellowgreen",
                                  ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Chimpanzee"), 1], "violetred",
                                         ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Dog"), 1], "turquoise",
                                                ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Fruitfly"), 1], "violet", 
                                                       ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Human"), 1], "thistle", 
                                                              ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Mouse"), 1], "tan1", 
                                                                     ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Pig"), 1], "steelblue", 
                                                                            ifelse(rownames(goTerms_df) %in% unique_common_all[which(unique_common_all$xxx == "Rat"), 1], "orangered", 
                                                                                   ifelse(rownames(goTerms_df) %in% 
                                                                                   unique_common_all[which(unique_common_all$xxx == "Zebrafish"), 1], "olivedrab", 
                                                                                          ifelse(rownames(goTerms_df) %in% 
                                                                                          unique_common_all[which(unique_common_all$xxx == "common"), 1], "cyan", 
                                                                                          ifelse(rownames(goTerms_df) %in% goChilds, "red", "skyblue2"))))))))))))

V(goGraph)$color <- adjustcolor(goTerms_df$color, alpha.f = 0.8)
V(goGraph)$label <- goTerms

Ecolor <- ifelse(E(goGraph)$type == "is_a", "slategrey",
                           ifelse(E(goGraph)$type == "part_of", "springgreen4",
                           ifelse(E(goGraph)$type == "regulates", "tomato",
                           ifelse(E(goGraph)$type == "negatively_regulates", "orange", "greenyellow"))))

E(goGraph)$color <- Ecolor
plot(goGraph,
     vertex.label.color = "black",
     vertex.label.cex = 0.8,
     vertex.size = 2,
     margin = 0)

labels <- c("C. elegans", "Chicken", "Chimpanzee", "Dog", "Fruitfly", "Human", "Mouse", "Pig", "Rat", "Zebrafish", "common", "other enriched GO", "parent")
colors <- c("green", "yellowgreen", "violetred", "turquoise", "violet", "thistle", "tan1", "steelblue", "orangered", "olivedrab", "cyan", "red", "skyblue2")
legend("topleft", labels, pch = 19,
       col = colors, pt.cex = 2, cex = 2, bty = "n", ncol = 2)

For the output network graph, see summary at the beginning of this post.



sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
## 
## locale:
## [1] LC_COLLATE=English_United States.1252 
## [2] LC_CTYPE=English_United States.1252   
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C                          
## [5] LC_TIME=English_United States.1252    
## 
## attached base packages:
## [1] parallel  stats4    stats     graphics  grDevices utils     datasets 
## [8] methods   base     
## 
## other attached packages:
##  [1] igraph_1.0.1          GO.db_3.4.0           tidyr_0.6.0          
##  [4] RamiGO_1.20.0         gsubfn_0.6-6          proto_1.0.0          
##  [7] rlist_0.4.6.1         wordcloud_2.5         RColorBrewer_1.1-2   
## [10] SnowballC_0.5.1       tm_0.6-2              NLP_0.1-9            
## [13] dplyr_0.5.0           clusterProfiler_3.2.9 DOSE_3.0.9           
## [16] ggplot2_2.2.1         org.Cf.eg.db_3.4.0    org.Mm.eg.db_3.4.0   
## [19] org.Dm.eg.db_3.4.0    org.Dr.eg.db_3.4.0    org.Gg.eg.db_3.4.0   
## [22] org.Hs.eg.db_3.4.0    org.Ss.eg.db_3.4.0    org.Pt.eg.db_3.4.0   
## [25] org.Ce.eg.db_3.4.0    org.Sc.sgd.db_3.4.0   org.Rn.eg.db_3.4.0   
## [28] AnnotationDbi_1.36.0  IRanges_2.8.1         S4Vectors_0.12.1     
## [31] Biobase_2.34.0        BiocGenerics_0.20.0   biomaRt_2.30.0       
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.8        png_0.1-7          assertthat_0.1    
##  [4] rprojroot_1.1      digest_0.6.11      slam_0.1-40       
##  [7] R6_2.2.0           plyr_1.8.4         backports_1.0.4   
## [10] RSQLite_1.1-1      evaluate_0.10      lazyeval_0.2.0    
## [13] data.table_1.10.0  qvalue_2.6.0       rmarkdown_1.3     
## [16] labeling_0.3       splines_3.3.2      BiocParallel_1.8.1
## [19] RCytoscape_1.24.1  stringr_1.1.0      RCurl_1.95-4.8    
## [22] munsell_0.4.3      fgsea_1.0.2        htmltools_0.3.5   
## [25] tcltk_3.3.2        tibble_1.2         gridExtra_2.2.1   
## [28] codetools_0.2-15   XML_3.98-1.5       bitops_1.0-6      
## [31] grid_3.3.2         gtable_0.2.0       DBI_0.5-1         
## [34] magrittr_1.5       scales_0.4.1       graph_1.52.0      
## [37] stringi_1.1.2      GOSemSim_2.0.3     reshape2_1.4.2    
## [40] DO.db_2.9          fastmatch_1.0-4    tools_3.3.2       
## [43] yaml_2.1.14        colorspace_1.3-2   memoise_1.0.0     
## [46] knitr_1.15.1       XMLRPC_0.3-0

How to map your Google location history with R

2016-12-30 00:00:00 +0000

It’s no secret that Google Big Brothers most of us. But at least they allow us to access quite a lot of the data they have collected on us. Among this is the Google location history.

If you want to see a few ways how to quickly and easily visualize your location history with R, stay tuned…

The Google location history can be downloaded from your Google account under https://takeout.google.com/settings/takeout. Make sure you only tick “location history” for download, otherwise it will take super long to get all your Google data.

The data Google provides you for download is a .json file and can be loaded with the jsonlite package. Loading this file into R might take a few minutes because it can be quite big, depending on how many location points Google had saved about you.

library(jsonlite)
system.time(x <- fromJSON("Standortverlauf.json"))
##    user  system elapsed 
##  112.52    1.12  132.59


The date and time column is in POSIX milliseconds format, so I converted it to a human-readable POSIXct date-time.

Similarly, longitude and latitude are saved in E7 format and were converted to GPS coordinates.

# extracting the locations dataframe
loc = x$locations

# converting time column from posix milliseconds into a readable time scale
loc$time = as.POSIXct(as.numeric(x$locations$timestampMs)/1000, origin = "1970-01-01")

# converting longitude and latitude from E7 to GPS coordinates
loc$lat = loc$latitudeE7 / 1e7
loc$lon = loc$longitudeE7 / 1e7


This is what the data looks like now:

head(loc)
##     timestampMs latitudeE7 longitudeE7 accuracy                 activitys
## 1 1482393378938  519601402    76004708       29                      NULL
## 2 1482393333953  519601402    76004708       29                      NULL
## 3 1482393033893  519603616    76002628       20 1482393165600, still, 100
## 4 1482392814435  519603684    76001572       20 1482392817678, still, 100
## 5 1482392734911  519603684    76001572       20                      NULL
## 6 1482392433788  519603684    76001572       20                      NULL
##   velocity heading altitude                time      lat      lon
## 1       NA      NA       NA 2016-12-22 08:56:18 51.96014 7.600471
## 2       NA      NA       NA 2016-12-22 08:55:33 51.96014 7.600471
## 3       NA      NA       NA 2016-12-22 08:50:33 51.96036 7.600263
## 4       NA      NA       NA 2016-12-22 08:46:54 51.96037 7.600157
## 5       NA      NA       NA 2016-12-22 08:45:34 51.96037 7.600157
## 6       NA      NA       NA 2016-12-22 08:40:33 51.96037 7.600157

We have the original and converted time, latitude and longitude columns, plus accuracy, activities, velocity, heading and altitude. Accuracy gives the error distance around the point in metres. Activities are saved as a list of data frames and will be explored further down. Velocity, heading and altitude were not recorded for earlier data points.
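The nested activities column can be inspected directly to get a feeling for this structure; a quick sketch (row 3 just happens to be a non-empty entry in the example output above, and the exact layout depends on your own export):

# peek at one non-empty activities entry: each is a data frame with a timestamp
# and a nested list of activity types with confidence values
str(loc$activitys[[3]], max.level = 3)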


Data stats

Before I get to actually plotting maps, I want to explore a few basic statistics of the data.


How many data points did Google record over what period of time?

# how many rows are in the data frame?
nrow(loc)
## [1] 600897
min(loc$time)
## [1] "2013-09-06 19:33:41 CEST"
max(loc$time)
## [1] "2016-12-22 08:56:18 CET"


And how are they distributed over days, months and years?

# calculate the number of data points per day, month and year
library(lubridate)
library(zoo)

loc$date <- as.Date(loc$time, '%Y/%m/%d')
loc$year <- year(loc$date)
loc$month_year <- as.yearmon(loc$date)

points_p_day <- data.frame(table(loc$date), group = "day")
points_p_month <- data.frame(table(loc$month_year), group = "month")
points_p_year <- data.frame(table(loc$year), group = "year")

How many days were recorded?

nrow(points_p_day)
## [1] 1057

How many months?

nrow(points_p_month)
## [1] 39

And how many years?

nrow(points_p_year)
## [1] 4
# set up plotting theme
library(ggplot2)
library(ggmap)

my_theme <- function(base_size = 12, base_family = "sans"){
  theme_grey(base_size = base_size, base_family = base_family) +
  theme(
    axis.text = element_text(size = 12),
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    axis.title = element_text(size = 14),
    panel.grid.major = element_line(color = "grey"),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "aliceblue"),
    strip.background = element_rect(fill = "lightgrey", color = "grey", size = 1),
    strip.text = element_text(face = "bold", size = 12, color = "navy"),
    legend.position = "right",
    legend.background = element_blank(),
    panel.margin = unit(.5, "lines"),
    panel.border = element_rect(color = "grey", fill = NA, size = 0.5)
  )
}
points <- rbind(points_p_day[, -1], points_p_month[, -1], points_p_year[, -1])

ggplot(points, aes(x = group, y = Freq)) + 
  geom_point(position = position_jitter(width = 0.2), alpha = 0.3) + 
  geom_boxplot(aes(color = group), size = 1, outlier.colour = NA) + 
  facet_grid(group ~ ., scales = "free") + my_theme() +
  theme(
    legend.position = "none",
    strip.placement = "outside",
    strip.background = element_blank(),
    strip.text = element_blank(),
    axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5)
  ) +
  labs(
    x = "",
    y = "Number of data points",
    title = "How many data points did Google collect about me?",
    subtitle = "Number of data points per day, month and year",
    caption = "\nGoogle collected between 0 and 1500 data points per day
    (median ~500), between 0 and 40,000 per month (median ~15,000) and 
    between 80,000 and 220,000 per year (median ~140,000)."
  )


How accurate is the data?

Accuracy is given in meters, i.e. the smaller the better.

accuracy <- data.frame(accuracy = loc$accuracy, group = ifelse(loc$accuracy < 800, "high", ifelse(loc$accuracy < 5000, "middle", "low")))

accuracy$group <- factor(accuracy$group, levels = c("high", "middle", "low"))

ggplot(accuracy, aes(x = accuracy, fill = group)) + 
  geom_histogram() + 
  facet_grid(group ~ ., scales="free") + 
  my_theme() +
  theme(
    legend.position = "none",
    strip.placement = "outside",
    strip.background = element_blank(),
    axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5)
  ) +
  labs(
    x = "Accuracy in metres",
    y = "Count",
    title = "How accurate is the location data?",
    subtitle = "Histogram of accuracy of location points",
    caption = "\nMost data points are pretty accurate, 
but there are still many data points with a high inaccuracy.
    These were probably from areas with bad satellite reception."
  )


Plotting data points on maps

Finally, we are actually going to plot some maps!

The first map is a simple point plot of all locations recorded around Germany.

You first specify the map area and the zoom factor (1 is the furthest away); the bigger the zoom, the closer in you get to the center of the specified location. The location can be given as a longitude/latitude pair or as a city or country name.

On this map we can plot different types of plot with the regular ggplot2 syntax. For example, a point plot.

germany <- get_map(location = 'Germany', zoom = 5)

ggmap(germany) + geom_point(data = loc, aes(x = lon, y = lat), alpha = 0.5, color = "red") + 
  theme(legend.position = "right") + 
  labs(
    x = "Longitude", 
    y = "Latitude", 
    title = "Location history data points in Europe",
    caption = "\nA simple point plot shows recorded positions.")


The second map shows a 2D bin plot of accuracy measured for all data points recorded in my home town Münster.

munster <- get_map(location = 'Munster', zoom = 12)

options(stringsAsFactors = T)
ggmap(munster) + 
  stat_summary_2d(geom = "tile", bins = 100, data = loc, aes(x = lon, y = lat, z = accuracy), alpha = 0.5) + 
  scale_fill_gradient(low = "blue", high = "red", guide = guide_legend(title = "Accuracy")) +
  labs(
    x = "Longitude", 
    y = "Latitude", 
    title = "Location history data points around Münster",
    subtitle = "Color scale shows accuracy (low: blue, high: red)",
    caption = "\nThis bin plot shows recorded positions 
    and their accuracy in and around Münster")


We can also plot the velocity of each data point:

loc_2 <- loc[which(!is.na(loc$velocity)), ]

munster <- get_map(location = 'Munster', zoom = 10)

ggmap(munster) + geom_point(data = loc_2, aes(x = lon, y = lat, color = velocity), alpha = 0.3) + 
  theme(legend.position = "right") + 
  labs(x = "Longitude", y = "Latitude", 
       title = "Location history data points in Münster",
       subtitle = "Color scale shows velocity measured for location",
       caption = "\nA point plot where points are colored according 
       to velocity nicely reflects that I moved generally 
       slower in the city center than on the autobahn") +
  scale_colour_gradient(low = "blue", high = "red", guide = guide_legend(title = "Velocity"))


What distance did I travel?

To obtain the distance I moved, I am calculating the distance between data points. Because it takes a long time to calculate, I am only looking at data from last year.

# keep only the data points recorded in 2016 (up to the download date)
loc3 <- subset(loc, time > as.POSIXct('2016-01-01 0:00:01') & time < as.POSIXct('2016-12-22 23:59:59'))

# Shifting vectors for latitude and longitude to include end position
shift.vec <- function(vec, shift){
  if (length(vec) <= abs(shift)){
    rep(NA ,length(vec))
  } else {
    if (shift >= 0) {
      c(rep(NA, shift), vec[1:(length(vec) - shift)]) }
    else {
      c(vec[(abs(shift) + 1):length(vec)], rep(NA, abs(shift)))
    }
  }
}

loc3$lat.p1 <- shift.vec(loc3$lat, -1)
loc3$lon.p1 <- shift.vec(loc3$lon, -1)

# Calculating distances between points (in metres) with the function pointDistance from the 'raster' package.
library(raster)
loc3$dist.to.prev <- apply(loc3, 1, FUN = function(row) {
  pointDistance(c(as.numeric(as.character(row["lat.p1"])),
                  as.numeric(as.character(row["lon.p1"]))),
                c(as.numeric(as.character(row["lat"])), as.numeric(as.character(row["lon"]))),
                lonlat = T) # Parameter 'lonlat' has to be TRUE!
})
# distance in km
round(sum(as.numeric(as.character(loc3$dist.to.prev)), na.rm = TRUE)*0.001, digits = 2)
## [1] 54466.08
distance_p_month <- aggregate(loc3$dist.to.prev, by = list(month_year = as.factor(loc3$month_year)), FUN = sum)
distance_p_month$x <- distance_p_month$x*0.001
ggplot(distance_p_month[-1, ], aes(x = month_year, y = x,  fill = month_year)) + 
  geom_bar(stat = "identity")  + 
  guides(fill = FALSE) +
  my_theme() +
  labs(
    x = "",
    y = "Distance in km",
    title = "Distance traveled per month in 2016",
    caption = "This barplot shows the sum of distances between recorded 
    positions for 2016. In September we went to the US and Canada."
  )


Activities

Google also guesses my activity based on distance travelled per time.

Here again, it would take too long to look at the activities of all data points, so I am restricting the analysis to the 2016 subset (loc3).

activities <- loc3$activitys

list.condition <- sapply(activities, function(x) !is.null(x[[1]]))
activities  <- activities[list.condition]

df <- do.call("rbind", activities)
main_activity <- sapply(df$activities, function(x) x[[1]][1][[1]][1])

activities_2 <- data.frame(main_activity = main_activity, 
                           time = as.POSIXct(as.numeric(df$timestampMs)/1000, origin = "1970-01-01"))

head(activities_2)
##   main_activity                time
## 1         still 2016-12-22 08:52:45
## 2         still 2016-12-22 08:46:57
## 3         still 2016-12-22 08:33:24
## 4         still 2016-12-22 08:21:31
## 5         still 2016-12-22 08:15:32
## 6         still 2016-12-22 08:10:25
ggplot(activities_2, aes(x = main_activity, group = main_activity, fill = main_activity)) + 
  geom_bar()  + 
  guides(fill = FALSE) +
  my_theme() +
  labs(
    x = "",
    y = "Count",
    title = "Main activities in 2016",
    caption = "Associated activity for recorded positions in 2016. 
    Because Google records activity probabilities for each position, 
    only the activity with the highest likelihood was chosen for each position."
  )



sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
## 
## locale:
## [1] LC_COLLATE=English_United States.1252 
## [2] LC_CTYPE=English_United States.1252   
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C                          
## [5] LC_TIME=English_United States.1252    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] raster_2.5-8    sp_1.2-3        ggmap_2.6.2     ggplot2_2.2.0  
## [5] zoo_1.7-13      lubridate_1.6.0 jsonlite_1.1   
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.8       plyr_1.8.4        bitops_1.0-6     
##  [4] tools_3.3.2       digest_0.6.10     evaluate_0.10    
##  [7] tibble_1.2        gtable_0.2.0      lattice_0.20-34  
## [10] png_0.1-7         DBI_0.5-1         mapproj_1.2-4    
## [13] yaml_2.1.14       proto_1.0.0       stringr_1.1.0    
## [16] dplyr_0.5.0       knitr_1.15.1      RgoogleMaps_1.4.1
## [19] maps_3.1.1        rprojroot_1.1     grid_3.3.2       
## [22] R6_2.2.0          jpeg_0.1-8        rmarkdown_1.2    
## [25] reshape2_1.4.2    magrittr_1.5      codetools_0.2-15 
## [28] backports_1.0.4   scales_0.4.1      htmltools_0.3.5  
## [31] assertthat_0.1    colorspace_1.3-2  geosphere_1.5-5  
## [34] labeling_0.3      stringi_1.1.2     lazyeval_0.2.0   
## [37] munsell_0.4.3     rjson_0.2.15

Animating Plots of Beer Ingredients and Sin Taxes over Time

2016-12-22 00:00:00 +0000

With the upcoming holidays, I thought it fitting to finally explore the ttbbeer package. It contains data on beer ingredients used in US breweries from 2006 to 2015 and on the (sin) tax rates for beer, champagne, distilled spirits, wine and various tobacco items since 1862.

I will be exploring these datasets and create animated plots with the animation package and the gganimate extension for ggplot2.


Merry R-Mas and Cheers! :-)

(For full code see below…)



Beer materials since 2006

This dataset includes the monthly aggregates of ingredients as reported by the Beer Monthly Statistical Releases, the Brewer’s Report of Operations and the Quarterly Brewer’s Report of Operations from the US Alcohol and Tobacco Tax and Trade Bureau (TTB).

The following nine ingredients are listed:

  • Malt and malt products
  • Corn and corn products
  • Rice and rice products
  • Barley and barley products
  • Wheat and wheat products
  • Sugar and syrups
  • Dry hops
  • Hops extracts
  • Other ingredients, like flavors, etc.

All ingredient amounts are given in pounds.

I’m not only interested in the total amounts of ingredients that were used but also in the change compared to the previous month. I’m calculating this as the percent change compared to the previous entry, starting with 0 in January 2006.

The first thing I wanted to know was how many ingredients were used in total each month. This probably gives a rough estimate of how much beer was produced that month.


The seasonal variation in ingredient use shows that more beer is produced in late spring/ early summer than in winter. Given that it takes about one to four months from production to the finished product, it fits my intuition that people drink more beer when it’s warm outside rather than in the cold of winter.

Now lets look at the ingredients separately. This plot shows the monthly sum of pounds for each ingredient:


We can see the following trends:

  • The seasonal trend in overall beer production is reflected in the individual ingredients
  • Barley and wheat use has been steadily increasing since 2006
  • Corn and corn products spiked between 2008 and 2011 and have slightly dropped since
  • Dry hops show two very prominent spikes in 2014 and 2015; we will look at these in more detail in the following figures
  • The use of hops extracts, malt and rice has slightly decreased since 2006
  • Sugar and syrup use dropped off between 2006 and 2010 and has since been relatively stable, with the exception of two months in 2014 and 2015
  • Other additives, like flavors, have increased sharply since 2009; in the last two years their use shows extremely high variability between summer (many added) and winter (fewer added); this might be due to the popularity of craft beers and light summer variants.

We can see that barley use spikes every third month, except for a few outlier months. Dry hops are interesting, as there are two very prominent spikes in December 2014 and December 2015. In the same months we see a drop for corn, hops extracts, malt, other ingredients, sugar and syrups, and wheat.

The following plot shows the percent of change compared to the previous month for the monthly sums of ingredients:


Here we can again see the seasonality of ingredient use, while we can also see that the overall trend shows no significant change from zero.

The two outlier months, December 2014 and December 2015, are quite conspicuous in these plots as well. This might be due to specific trends, or there might have been supply shortages of fresh hops and malt during these months (even though we don’t see a drop in overall ingredient amounts).

The first animated plot shows the percentages of ingredients that were used per month. Each frame shows one year.


The percentage plots show that there are no major differences in the proportions of ingredients, again with the exception of December 2014 and 2015.



Sin taxes since 1862

Tax rates for beer, champagne, distilled spirits, wine (separated into three alcohol percentage classes: up to 14%, 14-21% and 21-24%) and various tobacco items are included in separate datasets. The tax rates are given in US Dollars per

  • Barrel of beer (31 gallons)
  • Wine gallon
  • Proof barrel of distilled spirits
  • 1000 units of small and large cigarettes and cigars
  • 1 lb. of pipe, chewing and roll-your-own tobacco and snuff

Before I can visualize the data, I need to do some tidying and preprocessing:

  • I’m adding a column with the tax group
  • to keep it consistent, I’m converting all tobacco item names to upper case letters
  • two dates in the tobacco set were probably wrong, because they messed up the plots, so I corrected them

For each tax group I also want to know the percent of change in tax rate compared to the previous time point.

The first plot shows a time line of tax rate changes since 1862. Major historical events are marked.


The strongest increase in tax rates can be seen for champagne in the 1950s and for wine (14%) following the end of the Cold War. At the end of the Civil War, there is a sharp drop in the tax rates for distilled spirits. Spirit tax increased little between 1874 and the beginning of prohibition. During prohibition we can also see an increase in tax rates for champagne and wine (21-24%). While tax rates remained stable during the Cold War, they increased again in 1991 for all drinks except spirits.

The same plot as above but with absolute tax rates is visualized in the animated plot below:



The plot below catches the same information in a static plot:


All tax rates have increased strongly since 1862. Beer and wine show a strong jump in 1991 and at the beginning of prohibition. Between the end of prohibition and World War II, beer and wine taxes had dropped again, albeit slightly. Since the 1940s, all alcohol taxes have risen steadily. For tobacco there is only sporadic information.



R code

library(ttbbeer)
data("beermaterials")
library(dplyr)

# order by year from least to most recent
beermaterials_change <- arrange(beermaterials, Year)

# calculate percent change for each ingredient column
for (i in 3:ncol(beermaterials_change)){
  
  value <- beermaterials_change[[colnames(beermaterials_change)[i]]]
  beermaterials_change[, i] <- c(0, 
                       100 * round((value[-1] - value[1:(length(value)-1)]) /
                                     value[1:(length(value)-1)], 2))
  
}

head(beermaterials_change)
## # A tibble: 6 × 11
##      Month  Year Malt_and_malt_products `Corn_and _corn_products`
##      <chr> <int>                  <dbl>                     <dbl>
## 1  January  2006                      0                         0
## 2 February  2006                     -8                        -1
## 3    March  2006                     11                         7
## 4    April  2006                      2                        -6
## 5      May  2006                      5                        10
## 6     June  2006                     -2                       -15
## # ... with 7 more variables: Rice_and_rice_products <dbl>,
## #   Barley_and_barley_products <dbl>, Wheat_and_wheat_products <dbl>,
## #   Sugar_and_syrups <dbl>, Hops_dry <dbl>, Hops_extracts <dbl>,
## #   Other <dbl>
# calculate the sum of ingredients per month
beermaterials$sum <- rowSums(beermaterials[, -c(1:2)])

For plotting I’m converting the tables from wide to long format:

library(tidyr)

beermaterials_gather <- beermaterials %>%
  gather(Ingredient, Amount, Malt_and_malt_products:Other)
beermaterials_gather$Date <- paste("01", beermaterials_gather$Month, beermaterials_gather$Year, sep = "-") %>%
  as.Date(format = "%d-%B-%Y")
beermaterials_gather$Ingredient <- gsub("_", " ", beermaterials_gather$Ingredient)
beermaterials_change_gather <- beermaterials_change %>%
  gather(Ingredient, Amount, Malt_and_malt_products:Other)
beermaterials_change_gather$Date <- paste("01", beermaterials_change_gather$Month, beermaterials_change_gather$Year, sep = "-") %>%
  as.Date(format = "%d-%B-%Y")
beermaterials_change_gather$Ingredient <- gsub("_", " ", beermaterials_change_gather$Ingredient)

And I’m preparing my custom ggplot2 theme:

library(ggplot2)
my_theme <- function(base_size = 12, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
  theme(
    axis.text = element_text(size = 12),
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    axis.title = element_text(size = 14),
    panel.grid.major = element_line(color = "grey"),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "aliceblue"),
    strip.background = element_rect(fill = "royalblue", color = "grey", size = 1),
    strip.text = element_text(face = "bold", size = 12, color = "white"),
    legend.position = "bottom",
    legend.justification = "top", 
    legend.box = "horizontal",
    legend.box.background = element_rect(colour = "grey50"),
    legend.background = element_blank(),
    panel.border = element_rect(color = "grey", fill = NA, size = 0.5)
  )
}

theme_set(my_theme())
ggplot(beermaterials_gather, aes(x = Date, y = sum)) +
  geom_point(size = 1.5, alpha = 0.6, color = "navy") +
  geom_smooth(alpha = 0.3, color = "black", size = 0.5, linetype = "dashed") +
  geom_line(color = "navy") +
  guides(color = FALSE) +
  labs(
    x = "",
    y = "Amount in pounds",
    title = "Sum of Beer Ingredients by Month and Year",
    subtitle = "Sum of monthly aggregates of all ingredients from 2006 to 2015"
  ) +
  scale_x_date(date_breaks = "6 month", date_labels = "%B %Y")
ggplot(beermaterials_gather, aes(x = Date, y = Amount, color = Ingredient)) +
  geom_point(size = 1.5, alpha = 0.6) +
  geom_smooth(alpha = 0.3, color = "black", size = 0.5, linetype = "dashed") +
  geom_line() +
  facet_wrap(~Ingredient, ncol = 3, scales = "free_y") +
  guides(color = FALSE) +
  labs(
    x = "",
    y = "Amount in pounds",
    title = "Beer Ingredients by Month and Year",
    subtitle = "Monthly aggregates of ingredients from 2006 to 2015"
  ) +
  scale_x_date(date_breaks = "1 year", date_labels = "%Y")
ggplot(beermaterials_gather, aes(x = Date, y = Amount, color = Ingredient)) +
  geom_point(size = 1.5, alpha = 0.6) +
  guides(color = FALSE) +
  geom_line() +
  labs(
    x = "",
    y = "Amount in pounds",
    title = "Beer Ingredients by Month and Year",
    subtitle = "from January 2006 to December 2015"
  ) +
  scale_x_date(date_breaks = "1 month", date_labels = "%B %Y") +
  facet_wrap(~Ingredient, ncol = 1, scales = "free_y")

To explore the seasonal changes in more detail, I’m widening the resolution of the x-axis.

ggplot(beermaterials_change_gather, aes(x = Date, y = Amount, color = Ingredient)) +
  geom_smooth(alpha = 0.3, color = "black", size = 0.5, linetype = "dashed") +
  geom_point(size = 1.5, alpha = 0.6) +
  geom_line() +
  facet_wrap(~Ingredient, ncol = 3, scales = "free_y") +
  guides(color = FALSE) +
  labs(
    x = "",
    y = "Percent change",
    title = "Beer Ingredients by Month and Year",
    subtitle = "Percent change of monthly aggregates of ingredients compared to previous month from 2006 to 2015"
  ) +
  scale_x_date(date_breaks = "1 year", date_labels = "%Y")
beermaterials_percent <- cbind(beermaterials[, 1:2], prop.table(as.matrix(beermaterials[, -c(1:2)]), margin = 1) * 100)

beermaterials_percent_gather <- beermaterials_percent %>%
  gather(Ingredient, Amount, Malt_and_malt_products:Other)
beermaterials_percent_gather$Ingredient <- gsub("_", " ", beermaterials_percent_gather$Ingredient)

f = unique(beermaterials_percent_gather$Month)
beermaterials_percent_gather <- within(beermaterials_percent_gather, Month <- factor(Month, levels = f))
library(animation)
ani.options(ani.width=800, ani.height=500)

saveGIF({
  for (year in rev(unique(beermaterials_percent_gather$Year))){
  
  pie <- ggplot(subset(beermaterials_percent_gather, Year == paste(year)), aes(x = "", y = Amount, fill = Ingredient, frame = Year)) + 
  geom_bar(width = 1, stat = "identity") + theme_minimal() +
  coord_polar("y", start = 0) +
  labs(
    title = "Percentage of Beer Ingredients by Year and Month",
    subtitle = paste(year)) +
  theme(
  axis.title.x = element_blank(),
  axis.title.y = element_blank(),
  panel.border = element_blank(),
  panel.grid = element_blank(),
  axis.ticks = element_blank(),
  plot.title = element_text(size = 14, face = "bold"),
  legend.title = element_blank(),
  legend.position = "bottom",
  legend.text = element_text(size = 8)
  ) + 
  facet_wrap( ~ Month, ncol = 6) +
    guides(fill = guide_legend(ncol = 9, byrow = FALSE))

  print(pie)
  }
  
}, movie.name = "beer_ingredients.gif")
ggplot(beermaterials_percent_gather, aes(x = Month, y = Amount, color = Ingredient, fill = Ingredient)) + 
  geom_bar(width = 0.8, stat = "identity", alpha = 0.7) + 
  my_theme() +
  guides(color = FALSE) +
  facet_wrap( ~ Year, ncol = 5) +
  labs(
    x = "",
    y = "Percent",
    title = "Percentage of Beer Ingredients by Year and Month",
    subtitle = "from 2006 to 2015",
    fill = ""
  )

To see the information from the animated figure at one glance, the next plot shows the percentages of ingredients per month and year as a barplot:


beertax$tax <- "BEER"
champagnetax$tax <- "CHAMPAGNE"
spirittax$tax <- "SPIRIT"
tobaccotax$ITEM <- toupper(tobaccotax$ITEM)
tobaccotax$tax <- paste("TOBACCO", tobaccotax$ITEM, sep = " ")
tobaccotax <- tobaccotax[, -1]
tobaccotax[18, 2] <- "1977-01-31"
tobaccotax[18, 1] <- "1977-01-01"
winetax14$tax <- "WINE 14%"
winetax1421$tax <- "WINE 14-21%"
winetax2124$tax <- "WINE 21-24%"
taxes_list <- c("beertax", "champagnetax", "spirittax", "tobaccotax", "winetax14", "winetax1421", "winetax2124")

for (tax in taxes_list){
  
  tax_df <- get(tax)
  tax_df <- tax_df[match(sort(tax_df$FROM), tax_df$FROM), ]
  
  if (tax != "tobaccotax"){
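    # percent change of each tax rate compared to the previous time point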
    tax_df$change <- c(0, 
                       100 * round((tax_df$RATE[-1] - tax_df$RATE[1:(length(tax_df$RATE)-1)]) /
                                     tax_df$RATE[1:(length(tax_df$RATE)-1)], 2))

  } else {
    tax_df$change <- NA
  }
  
  if (tax == "beertax"){
    tax_df_2 <- tax_df
  } else {
    tax_df_2 <- rbind(tax_df_2, tax_df)
  }
}

tax_2 <- tax_df_2
tax_2$ID <- paste(tax_2$tax, rownames(tax_2), sep = "_")
head(tax_2)
##         FROM         TO RATE  tax change     ID
## 4 1862-09-01 1863-03-03  1.0 BEER      0 BEER_4
## 5 1863-03-03 1864-03-31  0.6 BEER    -40 BEER_5
## 6 1864-04-01 1898-06-13  1.0 BEER     67 BEER_6
## 7 1898-06-14 1901-06-30  2.0 BEER    100 BEER_7
## 8 1901-07-01 1902-06-30  1.6 BEER    -20 BEER_8
## 9 1902-07-01 1914-10-22  1.0 BEER    -38 BEER_9

Now I can gather the dataframe for plotting.

library(tidyr)
tax_gather <- tax_2 %>%
  gather(dat_column, Date, FROM:TO)
tax_gather$group <- gsub(" .*", "", tax_gather$tax)
d <- data.frame(xmin = as.Date(c("1861-04-12", "1914-07-28", "1920-01-16", "1939-11-01", "1947-03-12")), 
                ymin = -Inf, 
                xmax = as.Date(c("1865-04-09", "1918-11-11", "1933-12-05", "1945-05-08", "1991-12-26")), 
                ymax = Inf, 
                description = c("Civil War", "WW1", "Prohibition", "WW2", "Cold War"),
                pos_y = as.numeric(c(rep(c("1700", "1500"), 2), "1700")))
d$pos_x <- as.Date(NA)

for (i in 1:nrow(d)){
  d[i, "pos_x"] <- as.Date(as.character(median(c(d$xmin[i], d$xmax[i]))))
}

tax_gather$group2 <- ifelse(tax_gather$group == "TOBACCO", "TOBACCO", "ALCOHOL")

ggplot() +
  geom_rect(data = d, aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax), alpha=0.2, fill = "navy") +
  geom_label(data = d, aes(pos_x, pos_y, label = description), check_overlap = TRUE) +
  geom_point(data = subset(tax_gather, group2 == "ALCOHOL"), aes(x = Date, y = as.numeric(change), color = tax), size = 1.5, alpha = 0.6) +
  geom_line(data = subset(tax_gather, group2 == "ALCOHOL"), aes(x = Date, y = as.numeric(change), color = tax), size = 1) +
  facet_wrap(~group2, ncol = 1, scales = "free_y") +
  labs(
    x = "",
    y = "Tax rate change in %",
    color = "",
    title = "Alcohol tax rate changes during history",
    subtitle = "Change in percent compared to previous timepoint from 1862 to 2016"
  ) +
  guides(color = guide_legend(ncol = 6, byrow = FALSE)) +
  scale_x_date(date_breaks = "10 year", date_labels = "%Y")
d$pos_y <- as.numeric(c(rep(c("20", "18"), 2), "20"))

library(gganimate)

p <- ggplot() +
  geom_rect(data = d, aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax), alpha=0.2, fill = "navy") +
  geom_label(data = d, aes(pos_x, pos_y, label = description), check_overlap = TRUE) +
  geom_point(data = tax_gather, aes(x = Date, y = RATE, color = tax, frame = tax), size = 2, alpha = 0.6) +
  geom_line(data = tax_gather, aes(x = Date, y = RATE, color = tax, frame = tax), size = 1.5) +
  labs(
    x = "",
    y = "Tax rate",
    color = "",
    title = "Tax rates of ",
    subtitle = "from 1862 to 2016"
  ) +
  guides(color = guide_legend(ncol = 6, byrow = FALSE)) +
  scale_x_date(date_breaks = "10 year", date_labels = "%Y")

ani.options(ani.width=1000, ani.height=600)
gganimate(p, filename = "taxes_by_tax.mp4", cumulative = TRUE)
gganimate(p, filename = "taxes_by_tax.gif", cumulative = TRUE)
ggplot() +
  geom_rect(data = d, aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax), alpha=0.2, fill = "navy") +
  geom_smooth(data = tax_gather, aes(x = Date, y = RATE, color = tax), alpha = 0.3, color = "black", size = 0.5, linetype = "dashed") +
  geom_point(data = tax_gather, aes(x = Date, y = RATE, color = tax), size = 1.5, alpha = 0.6) +
  geom_line(data = tax_gather, aes(x = Date, y = RATE, color = tax), size = 1) +
  facet_wrap(~group, ncol = 2, scales = "free_y") +
  labs(
    x = "",
    y = "Tax rate",
    color = "",
    title = "Alcohol and tobacco tax rates",
    subtitle = "from 1862 to 2016"
  ) +
 guides(color = guide_legend(ncol = 6, byrow = FALSE)) +
 scale_x_date(date_breaks = "20 year", date_labels = "%Y")

## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
## 
## locale:
## [1] LC_COLLATE=English_United States.1252 
## [2] LC_CTYPE=English_United States.1252   
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C                          
## [5] LC_TIME=English_United States.1252    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] ggplot2_2.2.0 tidyr_0.6.0   dplyr_0.5.0   ttbbeer_1.1.0
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.8      knitr_1.15.1     magrittr_1.5     munsell_0.4.3   
##  [5] lattice_0.20-34  colorspace_1.3-2 R6_2.2.0         stringr_1.1.0   
##  [9] plyr_1.8.4       tools_3.3.2      grid_3.3.2       nlme_3.1-128    
## [13] gtable_0.2.0     mgcv_1.8-16      DBI_0.5-1        htmltools_0.3.5 
## [17] yaml_2.1.14      lazyeval_0.2.0   assertthat_0.1   rprojroot_1.1   
## [21] digest_0.6.10    tibble_1.2       Matrix_1.2-7.1   evaluate_0.10   
## [25] rmarkdown_1.2    labeling_0.3     stringi_1.1.2    scales_0.4.1    
## [29] backports_1.0.4

How to build a Shiny app for disease- & trait-associated locations of the human genome

2016-12-18 00:00:00 +0000

This app is based on the gwascat R package and its ebicat38 database and shows trait-associated SNP locations of the human genome. You can visualize and compare the genomic locations of up to 8 traits simultaneously.

The National Human Genome Research Institute (NHGRI) catalog of Genome-Wide Association Studies (GWAS) is a curated resource of single-nucleotide polymorphism (SNP)-trait associations. The database contains more than 100,000 SNPs and all SNP-trait associations with a p-value <1 × 10^−5.


You can access the app at

https://shiring.shinyapps.io/gwas_shiny_app/

Loading the data might take a few seconds. Patience you must have, my young padawan… ;-)


Alternatively, if you are using R, you can load the app via Github with shiny:

library(shiny)
runGitHub("ShirinG/GWAS_Shiny_App") 



If you want to know how I built this app, continue reading.


Data preparation

Initially, I wanted to load all data directly from R packages. While this worked in principle, it made the app load very slowly.

So, I decided to prepare the data beforehand, save the tables as tab-delimited text files, and load those in the app.


Genome information

First, I’m preparing the genome information: for each chromosome I want to know its length:

library(AnnotationDbi)
library(org.Hs.eg.db)

library(EnsDb.Hsapiens.v79)
edb <- EnsDb.Hsapiens.v79

keys <- keys(edb, keytype="SEQNAME")
chromosome_length <- select(edb, keys = keys, columns = c("SEQLENGTH", "SEQNAME"), keytype = "SEQNAME")
chromosome_length <- chromosome_length[grep("^[0-9]+$|^X$|^Y$|^MT$", chromosome_length$SEQNAME), ]

write.table(chromosome_length, "chromosome_length.txt", row.names = FALSE, col.names = TRUE, sep = "\t")


GWAS SNP data

I am saving the GWAS data as a text file as well; this datatable will be used for plotting the SNP locations.

I am also saving a table with the alphabetically sorted traits as input for the drop-down menu.

library(gwascat)
data(ebicat38)

gwas38 <- as.data.frame(ebicat38)
gwas38$DISEASE.TRAIT <- gsub("&beta;", "beta", gwas38$DISEASE.TRAIT)

# PVALUE_MLOG: -log(p-value)
gwas38[is.infinite(gwas38$PVALUE_MLOG), "PVALUE_MLOG"] <- max(gwas38[is.finite(gwas38$PVALUE_MLOG), "PVALUE_MLOG"]) + 10
summary(gwas38[is.finite(gwas38$PVALUE_MLOG), "PVALUE_MLOG"])
summary(gwas38$PVALUE_MLOG)

# OR or BETA*: Reported odds ratio or beta-coefficient associated with strongest SNP risk allele. Note that if an OR <1 is reported this is inverted, along with the reported allele, so that all ORs included in the Catalog are >1. Appropriate unit and increase/decrease are included for beta coefficients.
summary(gwas38$OR.or.BETA)

write.table(gwas38, "gwas38.txt", row.names = FALSE, col.names = TRUE, sep = "\t")

gwas38_traits <- as.data.frame(table(gwas38$DISEASE.TRAIT))
colnames(gwas38_traits) <- c("Trait", "Frequency")

write.table(gwas38_traits, "gwas38_traits.txt", row.names = FALSE, col.names = TRUE, sep = "\t")


The Shiny App

I built my Shiny app with the traditional two-file system. This means that I have a “ui.R” file containing the layout and a “server.R” file, which contains the R code.
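
For orientation, here is a minimal sketch of how the two files relate (the slider input "n" and the plot are just placeholders, not part of the actual app):

# ui.R - layout only
library(shiny)
shinyUI(fluidPage(
  titlePanel("Minimal example"),
  sidebarLayout(
    sidebarPanel(sliderInput("n", "Number of points:", min = 10, max = 100, value = 50)),
    mainPanel(plotOutput("plot"))
  )
))

# server.R - the R code that fills the output slots defined in ui.R
library(shiny)
shinyServer(function(input, output) {
  output$plot <- renderPlot({
    plot(rnorm(input$n))
  })
})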


ui.R

From top to bottom I chose the following settings:

  • loading the table with all traits
  • adding a “choose below” option before the traits to have no trait chosen by default
  • “united” theme for layout
  • two slider bars, one for the p-value threshold and one for the odds ratio/ beta coefficient threshold
  • drop-down menus for all traits; a maximum of eight can be chosen to be plotted simultaneously
  • main panel with explanatory text, main plot and output tables
library(shiny)
library(shinythemes)

gwas38_traits <- read.table("gwas38_traits.txt", header = TRUE, sep = "\t")

diseases <- c("choose below", as.character(gwas38_traits$Trait))

shinyUI(fluidPage(theme = shinytheme("united"),
                  titlePanel("GWAS disease- & trait-associated SNP locations of the human genome"),

                  sidebarLayout(
                    sidebarPanel(

                      sliderInput("pvalmlog",
                                  "-log(p-value):",
                                  min = -13,
                                  max = 332,
                                  value = -13),

                      sliderInput("orbeta",
                                  "odds ratio/ beta-coefficient :",
                                  min = 0,
                                  max = 4426,
                                  value = 0),

                      selectInput("variable1", "First trait:",
                                  choices = diseases[-1]),

                      selectInput("variable2", "Second trait:",
                                  choices = diseases),

                      selectInput("variable3", "Third trait:",
                                  choices = diseases),

                      selectInput("variable4", "Fourth trait:",
                                  choices = diseases),

                      selectInput("variable5", "Fifth trait:",
                                  choices = diseases),

                      selectInput("variable6", "Sixth trait:",
                                  choices = diseases),

                      selectInput("variable7", "Seventh trait:",
                                  choices = diseases),

                      selectInput("variable8", "Eighth trait:",
                                  choices = diseases)
                    ),
                    mainPanel(
                      br(),
                      p("The National Human Genome Research Institute (NHGRI) catalog of Genome-Wide Association Studies (GWAS) is a curated resource of single-nucleotide polymorphisms (SNP)-trait associations. The database contains more than 100,000 SNPs and all SNP-trait associations with a p-value <1 × 10^−5."),
                      p("This app is based on the 'gwascat' R package and its 'ebicat38' database and shows trait-associated SNP locations of the human genome."),
                      p("For more info on how I built this app check out", a("my blog.", href = "https://shiring.github.io/")),
                      br(),
                      h4("How to use this app:"),
                      div("Out of 1320 available traits or diseases you can choose up to 8 on the left-side panel und see their chromosomal locations below. The traits are sorted alphabetically. You can also start typing in the drop-down panel and traits matching your query will be suggested.", style = "color:blue"),
                      br(),
                      div("With the two sliders on the left-side panel you can select SNPs above a p-value threshold (-log of association p-value) and/or above an odds ratio/ beta-coefficient threshold. The higher the -log of the p-value the more significant the association of the SNP with the trait. Beware that some SNPs have no odds ratio value and will be shown regardless of the threshold.", style = "color:blue"),
                      br(),
                      div("The table directly below the plot shows the number of SNPs for each selected trait (without subsetting when p-value or odds ratio are changed). The second table below the first shows detailed information for each SNP of the chosen traits. This table shows only SNPs which are plotted (it subsets according to p-value and odds ratio thresholds).", style = "color:blue"),
                      br(),
                      div("Loading might take a few seconds...", style = "color:red"),
                      br(),
                      plotOutput("plot"),
                      tableOutput('table'),
                      tableOutput('table2'),
                      br(),
                      p("GWAS catalog:", a("http://www.ebi.ac.uk/gwas/", href = "http://www.ebi.ac.uk/gwas/")),
                      p(a("Welter, Danielle et al. “The NHGRI GWAS Catalog, a Curated Resource of SNP-Trait Associations.” Nucleic Acids Research 42.Database issue (2014): D1001–D1006. PMC. Web. 17 Dec. 2016.", href = "https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3965119/"))
                    )
                  )
))


server.R

This file contains the R code to produce the plot and output tables.

  • The chromosome data is loaded for the chromosome length barplot. In order to have the correct order, I am setting the chromosome factor order manually
  • renderPlot() contains the code for the plot
  • from the eight possible input traits, all unique traits are plotted (after removing “choose below”); duplicates are dropped so that the same trait is not plotted on top of itself if it is chosen in two drop-down menus
  • the first output table shows the number of SNPs for the chosen traits
  • the second output table shows SNP information for all SNPs of the chosen trait(s) within the chosen thresholds
library(shiny)

chr_data <- read.table("chromosome_length.txt", header = TRUE, sep = "\t")

chr_data$SEQNAME <- as.factor(chr_data$SEQNAME)
f = c("1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "X", "Y", "MT")
chr_data <- within(chr_data, SEQNAME <- factor(SEQNAME, levels = f))

library(ggplot2)

gwas38 <- read.table("gwas38.txt", header = TRUE, sep = "\t")
gwas38_traits <- read.table("gwas38_traits.txt", header = TRUE, sep = "\t")

# Define server logic required to plot variables
shinyServer(function(input, output) {

  # Generate a plot of the requested variables
  output$plot <- renderPlot({

    plot_snps <- function(trait){

      if (any(trait == "choose below")){
        trait <- trait[-which(trait == "choose below")]
      } else {
        trait <- unique(trait)
      }

        for (i in 1:length(unique(trait))){

          trait_pre <- unique(trait)[i]
          snps_data_pre <- gwas38[which(gwas38$DISEASE.TRAIT == paste(trait_pre)), ]
          snps_data_pre <- data.frame(Chr = snps_data_pre$seqnames,
                                      Start = snps_data_pre$CHR_POS,
                                      SNPid = snps_data_pre$SNPS,
                                      Trait = rep(paste(trait_pre), nrow(snps_data_pre)),
                                      PVALUE_MLOG = snps_data_pre$PVALUE_MLOG,
                                      OR.or.BETA = snps_data_pre$OR.or.BETA)

          snps_data_pre <- subset(snps_data_pre, PVALUE_MLOG > input$pvalmlog)
          snps_data_pre <- subset(snps_data_pre, OR.or.BETA > input$orbeta | is.na(OR.or.BETA))

          if (i == 1){

            snps_data <- snps_data_pre

          } else {

            snps_data <- rbind(snps_data, snps_data_pre)

          }
        }

      snps_data <- within(snps_data, Chr <- factor(Chr, levels = f))

      p <- ggplot(data = snps_data, aes(x = Chr, y = as.numeric(Start))) +
        geom_bar(data = chr_data, aes(x = SEQNAME, y = as.numeric(SEQLENGTH)), stat = "identity", fill = "grey90", color = "black") +
        theme(
          axis.text = element_text(size = 14),
          axis.title = element_text(size = 14),
          panel.grid.major = element_blank(),
          panel.grid.minor = element_blank(),
          panel.background = element_rect(fill = "white"),
          legend.position = "bottom"
        ) +
        labs(x = "Chromosome", y = "Position")

      p + geom_segment(data = snps_data, aes(x = as.numeric(as.character(Chr)) - 0.45, xend = as.numeric(as.character(Chr)) + 0.45,
                                           y = Start, yend = Start, colour = Trait), size = 2, alpha = 0.5) +
        scale_colour_brewer(palette = "Set1") +
        guides(colour = guide_legend(ncol = 3, byrow = FALSE))

      }

    plot_snps(trait = c(input$variable1, input$variable2, input$variable3, input$variable4, input$variable5, input$variable6, input$variable7, input$variable8))
  })

  output$table <- renderTable({
    table <- gwas38_traits[which(gwas38_traits$Trait %in% c(input$variable1, input$variable2, input$variable3, input$variable4, input$variable5, input$variable6, input$variable7, input$variable8)), ]

    })

  output$table2 <- renderTable({
    table <- gwas38[which(gwas38$DISEASE.TRAIT %in% c(input$variable1, input$variable2, input$variable3, input$variable4, input$variable5, input$variable6, input$variable7, input$variable8)), c(1:5, 8, 10, 11, 13, 14, 20, 27, 32, 33, 34, 36)]

    table <- subset(table, PVALUE_MLOG > input$pvalmlog)
    table <- subset(table, OR.or.BETA > input$orbeta | is.na(OR.or.BETA))

    table[order(table$seqnames), ]
    })

})


Deploying to shinyapps.io

Finally, I am deploying my finished app to shinyapps.io with the rsconnect package. You will need to register with shinyapps.io before you can host your Shiny app there and register rsconnect with the token you received.
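
Registering rsconnect with your account only needs to be done once; with placeholder values copied from your shinyapps.io dashboard it looks something like this:

# name, token and secret are placeholders - copy them from your shinyapps.io account page
rsconnect::setAccountInfo(name = "youraccount",
                          token = "YOUR_TOKEN",
                          secret = "YOUR_SECRET")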

library(rsconnect)
rsconnect::deployApp('~/Documents/Github/blog_posts_prep/gwas/shiny/GWAS_Shiny_App')

This will open your deployed app right away. You can now share the link to your app with the world! :-)



If you are interested in human genomics…

… you might also like these posts:



## R version 3.3.2 (2016-10-31)
## Platform: x86_64-apple-darwin13.4.0 (64-bit)
## Running under: macOS Sierra 10.12.1
## 
## locale:
## [1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/de_DE.UTF-8
## 
## attached base packages:
## [1] parallel  stats4    stats     graphics  grDevices utils     datasets 
## [8] methods   base     
## 
## other attached packages:
##  [1] rsconnect_0.6                          
##  [2] gwascat_2.6.0                          
##  [3] Homo.sapiens_1.3.1                     
##  [4] TxDb.Hsapiens.UCSC.hg19.knownGene_3.2.2
##  [5] GO.db_3.4.0                            
##  [6] OrganismDbi_1.16.0                     
##  [7] EnsDb.Hsapiens.v79_2.1.0               
##  [8] ensembldb_1.6.2                        
##  [9] GenomicFeatures_1.26.0                 
## [10] GenomicRanges_1.26.1                   
## [11] GenomeInfoDb_1.10.1                    
## [12] org.Hs.eg.db_3.4.0                     
## [13] AnnotationDbi_1.36.0                   
## [14] IRanges_2.8.1                          
## [15] S4Vectors_0.12.1                       
## [16] Biobase_2.34.0                         
## [17] BiocGenerics_0.20.0                    
## 
## loaded via a namespace (and not attached):
##   [1] colorspace_1.3-1              rprojroot_1.1                
##   [3] biovizBase_1.22.0             htmlTable_1.7                
##   [5] XVector_0.14.0                base64enc_0.1-3              
##   [7] dichromat_2.0-0               base64_2.0                   
##   [9] interactiveDisplayBase_1.12.0 codetools_0.2-15             
##  [11] splines_3.3.2                 snpStats_1.24.0              
##  [13] ggbio_1.22.0                  fail_1.3                     
##  [15] doParallel_1.0.10             knitr_1.15.1                 
##  [17] Formula_1.2-1                 gQTLBase_1.6.0               
##  [19] Rsamtools_1.26.1              cluster_2.0.5                
##  [21] graph_1.52.0                  shiny_0.14.2                 
##  [23] httr_1.2.1                    backports_1.0.4              
##  [25] assertthat_0.1                Matrix_1.2-7.1               
##  [27] lazyeval_0.2.0                limma_3.30.6                 
##  [29] acepack_1.4.1                 htmltools_0.3.5              
##  [31] tools_3.3.2                   gtable_0.2.0                 
##  [33] reshape2_1.4.2                dplyr_0.5.0                  
##  [35] fastmatch_1.0-4               Rcpp_0.12.8                  
##  [37] Biostrings_2.42.1             nlme_3.1-128                 
##  [39] rtracklayer_1.34.1            iterators_1.0.8              
##  [41] ffbase_0.12.3                 stringr_1.1.0                
##  [43] mime_0.5                      XML_3.98-1.5                 
##  [45] AnnotationHub_2.6.4           zlibbioc_1.20.0              
##  [47] scales_0.4.1                  BSgenome_1.42.0              
##  [49] VariantAnnotation_1.20.2      BiocInstaller_1.24.0         
##  [51] SummarizedExperiment_1.4.0    RBGL_1.50.0                  
##  [53] RColorBrewer_1.1-2            BBmisc_1.10                  
##  [55] yaml_2.1.14                   memoise_1.0.0                
##  [57] gridExtra_2.2.1               ggplot2_2.2.0                
##  [59] biomaRt_2.30.0                rpart_4.1-10                 
##  [61] reshape_0.8.6                 latticeExtra_0.6-28          
##  [63] stringi_1.1.2                 RSQLite_1.1-1                
##  [65] foreach_1.4.3                 checkmate_1.8.2              
##  [67] BiocParallel_1.8.1            BatchJobs_1.6                
##  [69] GenomicFiles_1.10.3           bitops_1.0-6                 
##  [71] matrixStats_0.51.0            evaluate_0.10                
##  [73] lattice_0.20-34               GenomicAlignments_1.10.0     
##  [75] bit_1.1-12                    GGally_1.3.0                 
##  [77] plyr_1.8.4                    magrittr_1.5                 
##  [79] sendmailR_1.2-1               R6_2.2.0                     
##  [81] Hmisc_4.0-1                   DBI_0.5-1                    
##  [83] foreign_0.8-67                mgcv_1.8-16                  
##  [85] survival_2.40-1               RCurl_1.95-4.8               
##  [87] nnet_7.3-12                   tibble_1.2                   
##  [89] rmarkdown_1.2                 grid_3.3.2                   
##  [91] data.table_1.10.0             gQTLstats_1.6.0              
##  [93] digest_0.6.10                 xtable_1.8-2                 
##  [95] ff_2.2-13                     httpuv_1.3.3                 
##  [97] brew_1.0-6                    openssl_0.9.5                
##  [99] munsell_0.4.3                 beeswarm_0.2.3               
## [101] Gviz_1.18.1

Gene homology Part 2 - creating directed networks with igraph

2016-12-14 00:00:00 +0000

In my last post I created a gene homology network for human genes. In this post I want to extend the network to include edges for other species.

First, I am loading biomaRt and extracting a list of all available datasets.

library(biomaRt)
ensembl = useMart("ensembl")
datasets <- listDatasets(ensembl)

For this network I don’t want to include all species with available information in Biomart because it would get way too messy. So, I am only picking the 11 species below: Human, fruitfly, mouse, C. elegans, dog, zebrafish, chicken, chimp, rat, yeast and pig. For all of these species, I am going to use their respective org.db database to extract gene information.

In order to loop over all species’ org.dbs for installation and loading, I create a new column with the name of the library.

datasets$orgDb <- NA

datasets[grep("hsapiens", datasets$dataset), "orgDb"] <- "org.Hs.eg.db"
datasets[grep("dmel", datasets$dataset), "orgDb"] <- "org.Dm.eg.db"
datasets[grep("mmus", datasets$dataset), "orgDb"] <- "org.Mm.eg.db"
datasets[grep("celegans", datasets$dataset), "orgDb"] <- "org.Ce.eg.db"
datasets[grep("cfam", datasets$dataset), "orgDb"] <- "org.Cf.eg.db"
datasets[grep("drerio", datasets$dataset), "orgDb"] <- "org.Dr.eg.db"
datasets[grep("ggallus", datasets$dataset), "orgDb"] <- "org.Gg.eg.db"
datasets[grep("ptrog", datasets$dataset), "orgDb"] <- "org.Pt.eg.db"
datasets[grep("rnor", datasets$dataset), "orgDb"] <- "org.Rn.eg.db"
datasets[grep("scer", datasets$dataset), "orgDb"] <- "org.Sc.sgd.db"
datasets[grep("sscrofa", datasets$dataset), "orgDb"] <- "org.Ss.eg.db"

datasets <- datasets[!is.na(datasets$orgDb), ]
datasets
##                       dataset                              description
## 9    rnorvegicus_gene_ensembl                     Rat genes (Rnor_6.0)
## 13   scerevisiae_gene_ensembl Saccharomyces cerevisiae genes (R64-1-1)
## 14      celegans_gene_ensembl  Caenorhabditis elegans genes (WBcel235)
## 22  ptroglodytes_gene_ensembl            Chimpanzee genes (CHIMP2.1.4)
## 26       sscrofa_gene_ensembl                  Pig genes (Sscrofa10.2)
## 32      hsapiens_gene_ensembl                  Human genes (GRCh38.p7)
## 36       ggallus_gene_ensembl        Chicken genes (Gallus_gallus-5.0)
## 41        drerio_gene_ensembl                 Zebrafish genes (GRCz10)
## 53 dmelanogaster_gene_ensembl                   Fruitfly genes (BDGP6)
## 61     mmusculus_gene_ensembl                  Mouse genes (GRCm38.p5)
## 69   cfamiliaris_gene_ensembl                    Dog genes (CanFam3.1)
##              version         orgDb
## 9           Rnor_6.0  org.Rn.eg.db
## 13           R64-1-1 org.Sc.sgd.db
## 14          WBcel235  org.Ce.eg.db
## 22        CHIMP2.1.4  org.Pt.eg.db
## 26       Sscrofa10.2  org.Ss.eg.db
## 32         GRCh38.p7  org.Hs.eg.db
## 36 Gallus_gallus-5.0  org.Gg.eg.db
## 41            GRCz10  org.Dr.eg.db
## 53             BDGP6  org.Dm.eg.db
## 61         GRCm38.p5  org.Mm.eg.db
## 69         CanFam3.1  org.Cf.eg.db

Now I can load the Ensembl mart for each species.

for (i in 1:nrow(datasets)) {
  ensembl <- datasets[i, 1]
  assign(paste0(ensembl), useMart("ensembl", dataset = paste0(ensembl)))
}

specieslist <- datasets$dataset

And install all org.db libraries, if necessary.

library(AnnotationDbi)

load_orgDb <- function(orgDb){
  if(!orgDb %in% installed.packages()[,"Package"]){
    source("https://bioconductor.org/biocLite.R")
    biocLite(orgDb)
  }
}

sapply(datasets$orgDb, load_orgDb, simplify = TRUE, USE.NAMES = TRUE)

Once they are all installed, I can load the libraries.

lapply(datasets$orgDb, require, character.only = TRUE)

Because I want to extract and compare information on homologous genes for all possible combinations of species, they need to have a common identifier. To find them, I first produce a list with all available keytypes and then ask for common elements using rlist’s list.common() function.

keytypes_list <- lapply(datasets$orgDb, function(x) NULL)
names(keytypes_list) <- paste(datasets$orgDb)

for (orgDb in datasets$orgDb){
  keytypes_list[[orgDb]] <- keytypes(get(orgDb))
}

library(rlist)
list.common(keytypes_list)
##  [1] "ENTREZID"    "ENZYME"      "EVIDENCE"    "EVIDENCEALL" "GENENAME"   
##  [6] "GO"          "GOALL"       "ONTOLOGY"    "ONTOLOGYALL" "PATH"       
## [11] "PMID"        "REFSEQ"      "UNIPROT"

Entrez IDs are available for all species, so these are the keys I am using as gene identifiers. I am first creating a table of homologous genes for each species with all other species separately by looping over the datasets table.

for (i in 1:nrow(datasets)){
  orgDbs <- datasets$orgDb[i]
  values <- keys(get(orgDbs), keytype = "ENTREZID")

  ds <- datasets$dataset[i]
  mart <- useMart("ensembl", dataset = paste(ds))
  print(mart)

  if (length(grep("^entrezgene$", listFilters(mart)$name)) > 0){
    if (length(grep("^entrezgene$", listAttributes(mart)$name)) > 0){
    print("TRUE")
    for (species in specieslist) {
      print(species)
      if (species != ds){
        assign(paste("homologs", orgDbs, species, sep = "_"), getLDS(attributes = c("entrezgene"),
                                                                     filters = "entrezgene",
                                                                     values = values,
                                                                     mart = mart,
                                                                     attributesL = c("entrezgene"),
                                                                     martL = get(species)))
      }
    }
  }
  }
}

Now I sort and combine these tables into one big table and remove duplicate entries.

library(dplyr)

for (i in 1:nrow(datasets)){
  orgDbs <- datasets$orgDb[i]
  values <- data.frame(GeneID = keys(get(orgDbs), keytype = "ENTREZID"))
  values$GeneID <- as.character(values$GeneID)
  ds <- datasets$dataset[i]

  for (j in 1:length(specieslist)){
    species <- specieslist[j]

    if (j == 1){
      homologs_table <- values
    }

    if (species != ds){
      homologs_species <- get(paste("homologs", orgDbs, species, sep = "_"))
      homologs_species$EntrezGene.ID <- as.character(homologs_species$EntrezGene.ID)

      homologs <- left_join(values, homologs_species, by = c("GeneID" = "EntrezGene.ID"))
      homologs <- homologs[!duplicated(homologs$GeneID), ]
      colnames(homologs)[2] <- paste(species)

      homologs_table <- left_join(homologs_table, homologs, by = "GeneID")
  }
  }

  colnames(homologs_table)[1] <- paste(ds)

  assign(paste("homologs_table", ds, sep = "_"), homologs_table)
}


for (i in 1:nrow(datasets)){
  ds <- datasets$dataset[i]

  if (i == 1){
    homologs_table_combined <- get(paste("homologs_table", ds, sep = "_"))
    homologs_table_combined <- homologs_table_combined[, order(colnames(homologs_table_combined))]
  } else {
    homologs_table_species <- get(paste("homologs_table", ds, sep = "_"))
    homologs_table_species <- homologs_table_species[, order(colnames(homologs_table_species))]

    homologs_table_combined <- rbind(homologs_table_combined, homologs_table_species)
  }
}

homologs_table_combined <- homologs_table_combined[!duplicated(homologs_table_combined), ]

Now each row in the table shows one gene and its homologs in all of the 11 species. Some genes have multiple homologs in other species (e.g. gene 44071 of the fruitfly has two homologs in zebrafish).
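
To check such a case directly, one could, for example, pull out all rows for this fruitfly gene (a quick look, assuming the combined table from above):

subset(homologs_table_combined, dmelanogaster_gene_ensembl == 44071,
       select = c(dmelanogaster_gene_ensembl, drerio_gene_ensembl))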

head(homologs_table_combined)
##   celegans_gene_ensembl cfamiliaris_gene_ensembl
## 1                    NA                       NA
## 2                    NA                   477699
## 3                    NA                   477023
## 4                173979                   490207
## 5                179795                   607852
## 6                177055                       NA
##   dmelanogaster_gene_ensembl drerio_gene_ensembl ggallus_gene_ensembl
## 1                         NA                  NA                   NA
## 2                      44071                  NA                   NA
## 3                         NA              431754               770094
## 4                      38864              406283                   NA
## 5                      36760              794259               395373
## 6                    3772179              541489               421909
##   hsapiens_gene_ensembl mmusculus_gene_ensembl ptroglodytes_gene_ensembl
## 1                    NA                     NA                        NA
## 2                     2                 232345                    465372
## 3                    30                 235674                    460268
## 4                    34                  11364                    469356
## 5                    47                 104112                    454672
## 6                    52                  11431                    458990
##   rnorvegicus_gene_ensembl scerevisiae_gene_ensembl sscrofa_gene_ensembl
## 1                    24152                       NA                   NA
## 2                    24153                       NA               403166
## 3                    24157                   854646            100515577
## 4                    24158                       NA               397104
## 5                    24159                   854310                   NA
## 6                    24161                   856187            100737301

In order to create the cooccurrence matrix, I am converting gene names to “1” and NAs to “0” and multiplying the matrix with its transpose. This is now the basis for igraph’s graph_from_adjacency_matrix() function with which I’m creating a directed network. Before plotting the network, I am removing multiple entries and loops.
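
To illustrate why the cross product counts cooccurrences, here is a small toy example with three genes and two made-up species:

# 3 genes x 2 species; 1 = homolog present, 0 = absent
toy <- matrix(c(1, 1,
                1, 0,
                0, 1),
              ncol = 2, byrow = TRUE,
              dimnames = list(NULL, c("species_A", "species_B")))
t(toy) %*% toy
# diagonal: number of genes per species; off-diagonal: genes shared by both species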

homologs_table_combined_matrix <- as.matrix(ifelse(is.na(homologs_table_combined), 0, 1))

co_occurrence <- t(as.matrix(homologs_table_combined_matrix)) %*% as.matrix(homologs_table_combined_matrix)

library(igraph)
g <- graph_from_adjacency_matrix(co_occurrence,
                                 weighted = TRUE,
                                 diag = FALSE,
                                 mode = "directed")

g <- simplify(g, remove.multiple = TRUE, remove.loops = TRUE)

I am also preparing a node annotation table with colors for each species group and the count of unique gene identifiers per species.

datasets[, 2] <- gsub("(.*)( genes (.*))", "\\1", datasets[, 2])
datasets$description[grep("Saccharomyces", datasets$description)] <- "Yeast"
datasets$description[grep("elegans", datasets$description)] <- "C. elegans"

datasets$group <- ifelse(datasets$description == "Yeast", "fungus",
                         ifelse(datasets$description == "C. elegans", "roundworm",
                                ifelse(datasets$description == "Chicken", "bird",
                                       ifelse(datasets$description == "Zebrafish", "fish",
                                              ifelse(datasets$description == "Fruitfly", "insect", "mammal")))))

list <- as.character(unique(datasets$dataset))

library(RColorBrewer)
set3 <- brewer.pal(length(list), "Set3")

datasets$col <- NA
for (i in 1:length(list)){
  datasets[datasets$dataset == list[i], "col"] <- set3[i]
}

no_genes <- as.data.frame(apply(homologs_table_combined, 2, function(x) length(unique(x))))
no_genes$dataset <- rownames(no_genes)
colnames(no_genes)[1] <- "no_genes"

library(dplyr)
datasets <- left_join(datasets, no_genes, by = "dataset")

datasets <- datasets[order(datasets$dataset), ]

I also want to have the proportion of genes that have homologs in each of the respective other species as edge attributes. For this, I am preparing an extra table of all possible combinations of nodes.

for (i in 1:ncol(homologs_table_combined)){
  input <- unique(na.omit(homologs_table_combined[, i]))
  homologs_table_combined_subset <- homologs_table_combined[!is.na(homologs_table_combined[, i]), ]

  for (j in 1:ncol(homologs_table_combined)){
    if (i != j){
      output <- unique(na.omit(homologs_table_combined_subset[, j]))
      
      if (i == 1 & j == 2){
        edge_table <- data.frame(source = colnames(homologs_table_combined)[i], 
                                 target = colnames(homologs_table_combined)[j], 
                                 weight = length(output)/length(input))
      } else {
        table <- data.frame(source = colnames(homologs_table_combined)[i], 
                            target = colnames(homologs_table_combined)[j], 
                            weight = length(output)/length(input))
        edge_table <- rbind(edge_table, table)
      }
    }
  }
}

list <- as.character(unique(edge_table$source))

edge_table$col <- NA
for (i in 1:length(list)){
  edge_table[edge_table$source == list[i], "col"] <- set3[i]
}

And finally I can produce the plot:

V(g)$color <- datasets$col
V(g)$label <- datasets$description
V(g)$size <- datasets$no_genes/2000
E(g)$arrow.size <- 3
E(g)$width <- edge_table$weight*25

plot(g,
     vertex.label.font = 1,
     vertex.shape = "sphere",
     vertex.label.cex = 3,
     vertex.label.color = "black",
     vertex.frame.color = NA,
     edge.curved = rep(0.1, ecount(g)),
     edge.color = edge_table$col,
     layout = layout_in_circle)

Node size shows the total number of genes of each species and edge width shows the proportion of these genes with homologs in the respective other species. Edge and node colors show the 11 species and their outgoing edges.

Weirdly, the mouse seems to have more gene entries in the org.db library than human. If anyone knows why that is, please let me know!
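
A quick way to look at the numbers behind this observation (using the annotation table prepared above):

datasets[order(datasets$no_genes, decreasing = TRUE), c("description", "no_genes")]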



## R version 3.3.2 (2016-10-31)
## Platform: x86_64-apple-darwin13.4.0 (64-bit)
## Running under: macOS Sierra 10.12.1
## 
## locale:
## [1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/de_DE.UTF-8
## 
## attached base packages:
## [1] parallel  stats4    stats     graphics  grDevices utils     datasets 
## [8] methods   base     
## 
## other attached packages:
##  [1] dplyr_0.5.0          RColorBrewer_1.1-2   igraph_1.0.1        
##  [4] rlist_0.4.6.1        org.Cf.eg.db_3.4.0   org.Mm.eg.db_3.4.0  
##  [7] org.Dm.eg.db_3.4.0   org.Dr.eg.db_3.4.0   org.Gg.eg.db_3.4.0  
## [10] org.Hs.eg.db_3.4.0   org.Ss.eg.db_3.4.0   org.Pt.eg.db_3.4.0  
## [13] org.Ce.eg.db_3.4.0   org.Sc.sgd.db_3.4.0  org.Rn.eg.db_3.4.0  
## [16] AnnotationDbi_1.36.0 IRanges_2.8.1        S4Vectors_0.12.1    
## [19] Biobase_2.34.0       BiocGenerics_0.20.0  biomaRt_2.30.0      
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.8       bitops_1.0-6      tools_3.3.2      
##  [4] digest_0.6.10     tibble_1.2        RSQLite_1.1-1    
##  [7] evaluate_0.10     memoise_1.0.0     DBI_0.5-1        
## [10] yaml_2.1.14       stringr_1.1.0     knitr_1.15.1     
## [13] rprojroot_1.1     data.table_1.10.0 R6_2.2.0         
## [16] XML_3.98-1.5      rmarkdown_1.2     magrittr_1.5     
## [19] backports_1.0.4   htmltools_0.3.5   assertthat_0.1   
## [22] stringi_1.1.2     RCurl_1.95-4.8

Creating a network of human gene homology with R and D3

2016-12-11 00:00:00 +0000

Edited on 20 December 2016


This week I want to explore how many of our human genes have homologs in other species. I will use this question to show how to visualize dendrograms and (D3-) networks.

Here, I am looking at gene homologs for all genes of the human genome. In Part 2 I am creating a full network between a subset of the species from this post.


Accessing species information from Biomart

I am using the biomaRt package to access all datasets available in the Biomart Ensembl database. The listDatasets() function shows that currently there is data for 69 species (including humans). This dataframe gives us the name of each dataset in Biomart Ensembl, the respective common species name and version number.

library(biomaRt)
ensembl = useMart("ensembl")
datasets <- listDatasets(ensembl)

human = useMart("ensembl", dataset = "hsapiens_gene_ensembl")

I am then looping over each species to call useMart() and assign an object to each species’ Ensembl database.

for (i in 1:nrow(datasets)) {
  ensembl <- datasets[i, 1]
  assign(paste0(ensembl), useMart("ensembl", dataset = paste0(ensembl)))
}

specieslist <- datasets$dataset


Identifying human gene homologs

Protein coding genes

The majority of human genes are protein coding genes. This means that their DNA sequence will be translated into a protein with specific cellular functions. For an overview of all gene biotypes in the human genome, have a look at this previous post.

To identify all protein coding genes in the human genome, I am using AnnotationDbi and the EnsDb.Hsapiens.v79 database.

library(AnnotationDbi)
library(EnsDb.Hsapiens.v79)

# get all Ensembl gene IDs
human_EnsDb <- keys(EnsDb.Hsapiens.v79, keytype = "GENEID")

# get the biotype of each gene ID
human_gene_EnsDb <- ensembldb::select(EnsDb.Hsapiens.v79, keys = human_EnsDb, columns = "GENEBIOTYPE", keytype = "GENEID")

# and keep only protein coding genes
human_prot_coding_genes <- human_gene_EnsDb[which(human_gene_EnsDb$GENEBIOTYPE == "protein_coding"), ]

To get a dataframe of human genes and their homologs in other species, I am looping over all Biomart Ensembl databases using biomaRt’s getLDS() function. getLDS() links two datasets and is equivalent to homology mapping with Ensembl. I am also including the human database so that I can see how many genes from EnsDb.Hsapiens.v79 are found via Biomart.

for (species in specieslist) {
  print(species)
  assign(paste0("homologs_human_", species), getLDS(attributes = c("ensembl_gene_id", "chromosome_name"), 
       filters = "ensembl_gene_id", 
       values = human_prot_coding_genes$GENEID, 
       mart = human, 
       attributesL = c("ensembl_gene_id", "chromosome_name"), 
       martL = get(species)))
}

Based on EnsDb.Hsapiens.v79 protein coding genes, I now have 69 individual datasets with homologous genes found in other species. I now want to combine these 69 datasets into one cooccurrence matrix. To do this, I am first joining the homologous gene subsets back to the original dataframe with all protein coding genes. Then, I convert existing homologous gene names to 1 and NAs to 0. These, I then merge into one table. This final table now contains the information whether each gene has a homolog in each of the other species (“1”) or not (“0”).

library(dplyr)

for (i in 1:length(specieslist)){
  species <- specieslist[i]
  homologs_human <- left_join(human_prot_coding_genes, get(paste0("homologs_human_", species)), by = c("GENEID" = "Gene.ID"))
  homologs_human[, paste0(species)] <- ifelse(is.na(homologs_human$Gene.ID.1), 0, 1)
  homologs_human <- homologs_human[, c(1, 6)]
  homologs_human <- homologs_human[!duplicated(homologs_human$GENEID), ]
  
  if (i == 1){
    homologs_human_table <- homologs_human
  } else {
    homologs_human_table <- left_join(homologs_human_table, homologs_human, by = "GENEID")
  }
}

I only want to keep the 21,925 human genes found via Biomart (77 genes were not found).

homologs_human_table <- homologs_human_table[which(homologs_human_table[, grep("sapiens", colnames(homologs_human_table))] == 1), ]

The final steps to creating a cooccurrence matrix are

  • removing the column with Ensembl gene IDs (because we have the Homo sapiens column from Biomart),
  • multiplying with the transposed matrix and
  • setting all rows and columns which don’t show cooccurrences with humans to 0 (because I only looked at homology from the human perspective).
gene_matrix <- homologs_human_table[, -1]

co_occurrence <- t(as.matrix(gene_matrix)) %*% as.matrix(gene_matrix)
co_occurrence[-grep("sapiens", rownames(co_occurrence)), -grep("sapiens", colnames(co_occurrence))] <- 0

The first information I want to extract from this data is how many and what proportion of human genes had a homolog in each of the other species.

genes <- data.frame(organism = colnames(gene_matrix),
                    number_genes = colSums(gene_matrix),
                    proportion_human_genes = colSums(gene_matrix)/nrow(homologs_human_table))

I also want to know which species are similar regarding the gene homology they share with humans. This, I will visualize with a hierarchical clustering dendrogram using the dendextend and circlize packages.

Before I can produce plots however, I want to create annotation attributes for each species:

  • their common name (because most people won’t know off the top of their heads what species T. nigroviridis is) and
  • a grouping attribute, e.g. mammal, fish, bird, etc.

In order to have the correct order of species, I create the annotation dataframe based off the dendrogram.

Most of the common names I can extract from the list of Biomart Ensembl datasets but a few I have to change manually.

library(dendextend)
library(circlize)

# create a dendrogram
h <- hclust(dist(scale(as.matrix(t(gene_matrix))), method = "manhattan"))
dend <- as.dendrogram(h)

library(dplyr)
labels <- as.data.frame(dend %>% labels) %>%
  left_join(datasets, by = c("dend %>% labels" = "dataset")) %>%
  left_join(genes, by = c("dend %>% labels" = "organism"))

labels[, 2] <- gsub("(.*)( genes (.*))", "\\1", labels[, 2])
labels$group <- c(rep("mammal", 2), rep("fish", 8), "amphibia", rep("fish", 4), rep("bird", 5), rep("reptile", 2), rep("mammal", 41), "fungus", "lamprey", rep("seasquirt", 2), "nematode", "insect")

labels$description[grep("hedgehog", labels$description)] <- "Hedgehog Tenrec"
labels$description[grep("Saccharomyces", labels$description)] <- "Yeast"
labels$description[grep("savignyi", labels$description)] <- "C. savignyi"
labels$description[grep("intestinalis", labels$description)] <- "C. intestinalis"
labels$description[grep("elegans", labels$description)] <- "C. elegans"
labels$description[grep("turtle", labels$description)] <- "Turtle"
labels$description[grep("Vervet", labels$description)] <- "Vervet"

The first plot shows the proportion of human genes with a homolog in each of the other species.

library(ggplot2)

my_theme <- function(base_size = 12, base_family = "sans"){
  theme_grey(base_size = base_size, base_family = base_family) +
  theme(
    axis.text = element_text(size = 12),
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    axis.title = element_text(size = 14),
    panel.grid.major = element_line(color = "grey"),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "aliceblue"),
    strip.background = element_rect(fill = "lightgrey", color = "grey", size = 1),
    strip.text = element_text(face = "bold", size = 12, color = "navy"),
    legend.position = "bottom",
    legend.background = element_blank(),
    panel.margin = unit(.5, "lines"),
    panel.border = element_rect(color = "grey", fill = NA, size = 0.5)
  )
}

labels_p <- labels[-1, ]

f = as.factor(labels_p[order(labels_p$proportion_human_genes, decreasing = TRUE), "description"])
labels_p <- within(labels_p, description <- factor(description, levels = f))

ggplot(labels_p, aes(x = description, y = proportion_human_genes, fill = group)) +
  geom_bar(stat = "identity") +
  my_theme() +
  labs(
    x = "",
    y = "Proportion of homology",
    fill = "",
    title = "Human gene homology",
    subtitle = "Proportion of human protein coding genes with homologs in 68 other species"
  )

As expected, the highest proportions of homologous genes are found between humans and most other mammals: between 70% and 81% of our genes have a homolog in most mammals. The lowest proportion of homologous genes is found in yeast, but even such a simple organism shares more than 20% of its genes with us. Think about that the next time you enjoy a delicious freshly baked bread or drink a nice cold glass of beer…

However, based on these Biomart Ensembl databases, more of our genes have a homolog in mice than in primates. I would have expected primates to come out on top, but what we see here might be confounded by the different amounts of information we have on species’ genomes and the accuracy of this information: there has been more research done on the mouse than e.g. on the alpaca. We therefore have more and more accurate information on the mouse genome and have identified more homologous genes than in other species.
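
The ranking behind this observation can be checked directly, for example by looking at the ten species with the highest proportions:

head(labels_p[order(labels_p$proportion_human_genes, decreasing = TRUE),
              c("description", "group", "proportion_human_genes")], 10)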


For plotting the dendrograms, I am adding a column with color specifications for each group to the annotation table.

labels$col <- c(rep("aquamarine4", 2), rep("darkorange3", 8), "blue4", rep("darkorange3", 4), rep("darkred", 5), rep("darkmagenta", 2), rep("aquamarine4", 41), "darkslateblue", "deepskyblue1", rep("deeppink3", 2), "forestgreen", "brown3")

This dendrogram shows the hierarchical clustering of human gene homology with 68 other species. The species are grouped by color (see legend).

labels(dend) <- labels$description
dend %>% 
  set("labels_col", labels$col) %>% 
  set("labels_cex", 2) %>% 
  circlize_dendrogram(labels_track_height = NA, dend_track_height = 0.3) 

legend("topleft", unique(labels$group), pch = 19,
       col = unique(labels$col), pt.cex = 3, cex = 3, bty = "n", ncol = 1)

The group most distant from all other species includes yeast, fruit fly, lamprey, C. elegans and the two seasquirt species, followed by tarsier, alpaca, sloth and shrew. The next group includes the fish, reptiles, amphibians, birds and the platypus. Interestingly, the platypus and Xenopus are more similar to fish than to birds and reptiles.

And finally, we have the remaining mammals: here, the primates cluster nicely together.


In order to show the homology of human genes with the other species, we can also plot a network. The 2D network is created from the cooccurrence matrix, with node size and edge width representing the proportion of homology and colors depicting the species’ group.

# plot the network graph
library(igraph)
g <- graph_from_adjacency_matrix(co_occurrence,
                         weighted = TRUE,
                         diag = FALSE,
                         mode = "undirected")

g <- simplify(g, remove.multiple = F, remove.loops = T, edge.attr.comb = c(weight = "sum", type = "ignore"))

labels_2 <- labels[match(rownames(co_occurrence), labels[, 1]), ]

V(g)$color <- labels_2$col
V(g)$label <- labels_2$description
V(g)$size <- labels_2$proportion_human_genes*25
E(g)$arrow.size <- 0.2
E(g)$edge.color <- "gray80"
E(g)$width <- labels_2$proportion_human_genes*10
plot(g,
     vertex.label.font = 1,
     vertex.shape = "sphere",
     vertex.label.cex = 1,
     vertex.label.color = "white",
     vertex.frame.color = NA)

legend("topleft", unique(labels_2$group), pch = 19,
       col = unique(labels_2$col), pt.cex = 2, cex = 1.5, bty = "n", ncol = 1)

Another nice way to visualize networks is to create a D3 JavaScript network graph with the networkD3 package. It shows the same information as the 2D-network above but here you can interact with the nodes, which makes the species names easier to read.

library(networkD3)

net_d3 <- igraph_to_networkD3(g, group = labels_2$group)
net_d3$nodes <- merge(net_d3$nodes, genes, by.x = "name", by.y = "row.names")
net_d3$nodes <- merge(net_d3$nodes, labels_2[, 1:2], by.x = "name", by.y = "dend %>% labels")
net_d3$nodes[, 1] <- net_d3$nodes$description

net_d3$nodes <- net_d3$nodes[match(labels_2[, 1], net_d3$nodes$organism), ]

# Create force directed network plot
forceNetwork(Links = net_d3$links, 
             Nodes = net_d3$nodes,
             Nodesize = "proportion_human_genes",
             radiusCalculation = JS(" Math.sqrt(d.nodesize)*15"),
             linkDistance = 100,
             zoom = TRUE, 
             opacity = 0.9,
             Source = 'source', 
             Target = 'target',
             linkWidth = networkD3::JS("function(d) { return d.value/5; }"),
             bounded = FALSE, 
             colourScale = JS("d3.scale.category10()"),
             NodeID = 'name', 
             Group = 'group', 
             fontSize = 10, 
             charge = -200,
             opacityNoHover = 0.7)

Click to open the interactive network and use the mouse to zoom and turn the plot:

networkD3



## R version 3.3.2 (2016-10-31)
## Platform: x86_64-apple-darwin13.4.0 (64-bit)
## Running under: macOS Sierra 10.12
## 
## locale:
## [1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/de_DE.UTF-8
## 
## attached base packages:
## [1] parallel  stats4    stats     graphics  grDevices utils     datasets 
## [8] methods   base     
## 
## other attached packages:
##  [1] networkD3_0.2.13         igraph_1.0.1            
##  [3] ggplot2_2.2.0            dplyr_0.5.0             
##  [5] circlize_0.3.9           dendextend_1.3.0        
##  [7] EnsDb.Hsapiens.v79_1.1.0 ensembldb_1.6.0         
##  [9] GenomicFeatures_1.26.0   GenomicRanges_1.26.0    
## [11] GenomeInfoDb_1.10.0      AnnotationDbi_1.36.0    
## [13] IRanges_2.8.0            S4Vectors_0.12.0        
## [15] Biobase_2.34.0           BiocGenerics_0.20.0     
## [17] biomaRt_2.30.0          
## 
## loaded via a namespace (and not attached):
##  [1] httr_1.2.1                    jsonlite_1.1                 
##  [3] AnnotationHub_2.6.0           shiny_0.14.2                 
##  [5] assertthat_0.1                interactiveDisplayBase_1.12.0
##  [7] Rsamtools_1.26.0              yaml_2.1.14                  
##  [9] robustbase_0.92-6             RSQLite_1.0.0                
## [11] lattice_0.20-34               digest_0.6.10                
## [13] XVector_0.14.0                colorspace_1.3-1             
## [15] htmltools_0.3.5               httpuv_1.3.3                 
## [17] Matrix_1.2-7.1                plyr_1.8.4                   
## [19] XML_3.98-1.5                  zlibbioc_1.20.0              
## [21] xtable_1.8-2                  mvtnorm_1.0-5                
## [23] scales_0.4.1                  whisker_0.3-2                
## [25] BiocParallel_1.8.0            tibble_1.2                   
## [27] SummarizedExperiment_1.4.0    nnet_7.3-12                  
## [29] lazyeval_0.2.0                magrittr_1.5                 
## [31] mime_0.5                      mclust_5.2                   
## [33] evaluate_0.10                 MASS_7.3-45                  
## [35] class_7.3-14                  BiocInstaller_1.24.0         
## [37] tools_3.3.2                   GlobalOptions_0.0.10         
## [39] trimcluster_0.1-2             stringr_1.1.0                
## [41] kernlab_0.9-25                munsell_0.4.3                
## [43] cluster_2.0.5                 fpc_2.1-10                   
## [45] Biostrings_2.42.0             grid_3.3.2                   
## [47] RCurl_1.95-4.8                htmlwidgets_0.8              
## [49] bitops_1.0-6                  labeling_0.3                 
## [51] rmarkdown_1.1                 gtable_0.2.0                 
## [53] flexmix_2.3-13                DBI_0.5-1                    
## [55] R6_2.2.0                      GenomicAlignments_1.10.0     
## [57] knitr_1.15                    prabclus_2.2-6               
## [59] rtracklayer_1.34.0            shape_1.4.2                  
## [61] modeltools_0.2-21             stringi_1.1.2                
## [63] Rcpp_0.12.8                   DEoptimR_1.0-8               
## [65] diptest_0.75-7

How to set up your own R blog with Github pages and Jekyll Bootstrap

2016-12-04 00:00:00 +0000

This post is in reply to a request: How did I set up this R blog?

I have wanted to have my own R blog for a while before I actually went ahead and realised this page. I had seen all the cool things people do with R by following R-bloggers and reading their newsletter every day!

While I am using R every day at work, the thematic scope there is of course very bioinformatics-centric (with a little bit of machine learning and lots of visualization and statistics), so I wanted to have an incentive to regularly explore other types of analyses and other types of data that I don’t normally work with. I have also very often benefited from other people’s published code, in that it gave me ideas for my own work, and I hope that sharing my own analyses will inspire others as much as I am often inspired by what can be done with data.

Moreover, since I’ve started my R blog, I’ve very much enjoyed interacting more with other R-bloggers and R users and getting valuable feedback. I definitely feel more part of the community now that I’m an active contributor!


In this post, I will describe how I go about publishing R code and its output to my blog. I will not go into detail on how to customize your page but only give a quick introduction to setting up a blog with Github pages and JekyllBootstrap. I will focus on how to write content with RMarkdown and RStudio.


Setting up Github pages

I decided to host my blog on Github pages because I was already used to working with Github (it’s where I store my exprAnalysis R package).

Setting up your own Github page is free and very easy! All you need is a Github account. On Github you need to create a new repository called “username.github.io” (replace username with your Github user name, in my case: ShirinG.github.io).

I set up the rest via the terminal interface:

  • go to the folder where your repository should live (in this example called “Test”)
  • clone the repository and change to its directory
  • pull from Github to make sure everything is synchronized and up to date between your computer and Github
cd Home/Documents/Test
git clone https://github.com/username/username.github.io
cd username.github.io
git pull origin master


JekyllBootstrap

Github pages also supports Jekyll, an easy-to-use static site generator that is perfect for blogging. I set up my blog with JekyllBootstrap, a full blog scaffold for Jekyll-based blogs. Follow their Quick Start Guide for instructions on how to install it.

It also gives detailed instructions on how to set up comments, analytics and permalinks.


Customizing your page

Most of the heavy-lifting has already been done for you with JekyllBootstrap but of course there are many options for customizing and adapting your blog’s layout to your liking.

The main file for customizing your page is Jekyll’s “*_config.yml*” file (check back with JekyllBootstrap for details). Here you need to fill in all the relevant information about your page, like title, author, etc.

If you know a little bit about JavaScript, CSS and HTML, you can also directly modify the style sheets in the directory “*_includes*”.

You can change your page’s theme by following this tutorial.

In your page repository’s main folder you have a markdown file called “index”. This is your main page and it is already set up to show a list of all your blog posts.

You can have more pages in the navigation bar if you want, like an “About me”, etc. You simply need to create a new markdown file with the respective content in your page repository’s main folder, for example like the sketch below.
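As a rough sketch, an “about.md” file in the main folder could look like the following (the layout name here is an assumption; “page” is what JekyllBootstrap themes typically provide, so check your theme’s “*_layouts*” folder for the name it actually uses):

---
layout: page
title: About me
---

A few sentences about who I am and what I do with R.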

You will also automatically get .xml files with your blog’s RSS and Atom feeds.


Preparing R blog posts with RMarkdown

I am generally working in RStudio, so this is where I’m preparing all my blog posts. I’m writing my code in RMarkdown and produce html output files until I deem a blog post finished.

This is what an example RMarkdown document could look like.

Your RMarkdown document will need a header with title, author, date and output information:

---
title: "This is a test post for my R blog"
author: "Dr. Shirin Glander"
date: '`r Sys.Date()`'
output: html_document
---

Below the header you can write text, insert pictures, create tables and much more. Code chunks go between triple backticks:

```{r, echo=TRUE}
summary(cars)
```

As always, there is much more information on how to work with RMarkdown on the web.

The final code chunk will look something like this in the output:

library(ggplot2)
ggplot(mpg, aes(displ, hwy)) +
  geom_point(aes(color = class)) +
  geom_smooth(se = FALSE, method = "loess") +
  labs(
    title = "Fuel efficiency generally decreases with engine size",
    subtitle = "Two seaters (sports cars) are an exception because of their light weight",
    caption = "Data from fueleconomy.gov"
  )


Once I’m happy with my finished blog post, I’m changing the output format in the YAML header to create a markdown file (I’m using markdown_github) and knit again.

output: 
  md_document:
    variant: markdown_github

This will produce a markdown document (ending in .md) of your blog post. This markdown file is now almost ready to be published, but first you need to add a header to the top of it that Jekyll can recognize. An example markdown header looks like this:

---
layout: post
title: "This is a test post for my R blog"
date: 2016-12-01
categories: rblogging
tags: test ggplot2
---

It needs to contain the layout definition for post, the title and date. Categories and tags are optional but I like using them to categorize the posts I have written and make them easier to find again.

Now this file is ready to go on your blog.


Publishing a blog post

I’m working in a different directory to produce my posts and test out things (in this example called “blog_prep”), so once I have my markdown file ready (in this example called “example_post.md”), I go back to the terminal window, copy it to my blog repository’s folder “_posts” and rename it so that it has the following format: Date-name.md (of course, you could also manually copy, paste and rename the file).

cp ../blog_prep/example_post.md _posts
mv _posts/example_post.md _posts/2016-11-20-example_post.md

If you also have figures or graphs in your blog post, a folder will be generated along with your markdown with the name “example_post_files”. You need to copy this folder to your blog repository as well. If you gave your post a category, it will look for the files in a folder with the same name as your category, with subfolders for the date. With the example above, which was categorized as “rblogging”, I need to make a folder with the same name and create subdirectories for the date on which my post is published:

mkdir -p rblogging/2016/12/01
cp -r ../blog_prep/example_post_files rblogging/2016/12/01


Testing your blog page locally

Before I push my updated blog to Github where it will go live, I’m testing the output locally with Jekyll. In order to be able to run Jekyll locally, you need to follow these excellent guidelines. Once you have Jekyll up and running, all you need to do is to be in your page repository and run jekyll serve.

bundle exec jekyll serve

This will build your page from the contents of your repository and you can preview whether everything looks the way you want it to.


If my new blog post looks okay and I don’t want to change anything any more, I can finally stage, commit and push the updated blog to Github.

git add *
git commit -m "name"
git push origin master


Alternatively, if you’d rather use RStudio for pushing changes to Github, you can also connect your blog repository to RStudio and stage, commit and push from there.


And voilà, your new R blog post is out there for the world to see!

Extreme Gradient Boosting and Preprocessing in Machine Learning - Addendum to predicting flu outcome with R

2016-12-02 00:00:00 +0000

In last week’s post I explored whether machine learning models can be applied to predict flu deaths from the 2013 outbreak of influenza A H7N9 in China. There, I compared random forests, elastic-net regularized generalized linear models, k-nearest neighbors, penalized discriminant analysis, stabilized linear discriminant analysis, nearest shrunken centroids, single C5.0 tree and partial least squares.


Extreme gradient boosting

Extreme gradient boosting (XGBoost) is a faster and improved implementation of gradient boosting for supervised learning and has recently been very successfully applied in Kaggle competitions. Because I’ve heard XGBoost’s praise being sung everywhere lately, I wanted to get my feet wet with it too. So this week I want to compare the prediction success of extreme gradient boosting on the same dataset with last week’s models. Additionally, I want to test the influence of different preprocessing methods on the outcome.


“XGBoost uses a more regularized model formalization to control over-fitting, which gives it better performance.” Tianqi Chen, developer of xgboost

XGBoost is a tree ensemble model, which means that it sums the predictions from a set of classification and regression trees (CART). In that, XGBoost is similar to Random Forests, but it uses a different approach to model training.
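As a rough conceptual sketch with made-up numbers (an illustration of the idea, not actual xgboost internals): the ensemble score for one case is the sum of the individual trees’ leaf scores, and with the binary:logistic objective this sum is passed through a logistic function to obtain a probability.

tree_outputs <- c(0.3, -0.1, 0.2)            # hypothetical leaf scores of three trees for one case
ensemble_score <- sum(tree_outputs)
prob_death <- 1 / (1 + exp(-ensemble_score)) # logistic link used by the binary:logistic objective
prob_death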


Starting with the same test and training data (partitioned into validation test and validation train subsets) from last week’s post, I am training extreme gradient boosting models as implemented in the xgboost and caret packages with different preprocessing settings.

Out of the different implementations and variations of gradient boosting algorithms, caret performed best on PCA-preprocessed data in the validation set. These parameters were then used to predict the outcome in the test set and compare it to last week’s predictions.

Compared to last week, there is much less uncertainty in the predictions from XGBoost. Overall, I would say that this algorithm is superior to the others I have used before.


xgboost

Extreme gradient boosting is implemented in the xgboost package.

# install the stable/pre-compiled version from CRAN

install.packages('xgboost')

# or install from weekly updated drat repo

install.packages("drat", repos="https://cran.rstudio.com")
drat:::addRepo("dmlc")
install.packages("xgboost", repos="http://dmlc.ml/drat/", type="source")

XGBoost only works with numeric data, so the outcome classes have to be converted into integers and both training and test data have to be in numeric matrix format.

library(xgboost)
matrix_train <- apply(val_train_X, 2, function(x) as.numeric(as.character(x)))
outcome_death_train <- ifelse(val_train_data$outcome == "Death", 1, 0)

matrix_test <- apply(val_test_X, 2, function(x) as.numeric(as.character(x)))
outcome_death_test <- ifelse(val_test_data$outcome == "Death", 1, 0)

xgb_train_matrix <- xgb.DMatrix(data = as.matrix(matrix_train), label = outcome_death_train)
xgb_test_matrix <- xgb.DMatrix(data = as.matrix(matrix_test), label = outcome_death_test)

watchlist <- list(train = xgb_train_matrix, test = xgb_test_matrix)
label <- getinfo(xgb_test_matrix, "label")

I am using cross validation to evaluate the error rate.

param <- list("objective" = "binary:logistic")

xgb.cv(param = param, 
       data = xgb_train_matrix, 
       nfold = 3,
       label = getinfo(xgb_train_matrix, "label"),
       nrounds = 5)
## [1]  train-error:0.133950+0.022131   test-error:0.341131+0.129919 
## [2]  train-error:0.124941+0.033432   test-error:0.358674+0.105191 
## [3]  train-error:0.115932+0.045501   test-error:0.340156+0.092598 
## [4]  train-error:0.124941+0.033432   test-error:0.340156+0.092598 
## [5]  train-error:0.089616+0.051301   test-error:0.394737+0.098465


Training with gbtree

gbtree is the default booster for xgb.train.

bst_1 <- xgb.train(data = xgb_train_matrix, 
                   label = getinfo(xgb_train_matrix, "label"),
                   max.depth = 2, 
                   eta = 1, 
                   nthread = 4, 
                   nround = 50, # number of trees used for model building
                   watchlist = watchlist, 
                   objective = "binary:logistic")
## [1]  train-error:0.250000    test-error:0.347826 
## [2]  train-error:0.178571    test-error:0.304348 
## [3]  train-error:0.142857    test-error:0.347826 
## [4]  train-error:0.089286    test-error:0.260870 
## [5]  train-error:0.053571    test-error:0.347826 
## [6]  train-error:0.053571    test-error:0.391304 
## [7]  train-error:0.035714    test-error:0.347826 
## [8]  train-error:0.000000    test-error:0.391304 
## [9]  train-error:0.017857    test-error:0.391304 
## [10] train-error:0.000000    test-error:0.391304 
## [11] train-error:0.000000    test-error:0.391304 
## [12] train-error:0.000000    test-error:0.347826 
## [13] train-error:0.000000    test-error:0.347826 
## [14] train-error:0.000000    test-error:0.347826 
## [15] train-error:0.000000    test-error:0.347826 
## [16] train-error:0.000000    test-error:0.347826 
## [17] train-error:0.000000    test-error:0.347826 
## [18] train-error:0.000000    test-error:0.347826 
## [19] train-error:0.000000    test-error:0.347826 
## [20] train-error:0.000000    test-error:0.347826 
## [21] train-error:0.000000    test-error:0.347826 
## [22] train-error:0.000000    test-error:0.347826 
## [23] train-error:0.000000    test-error:0.347826 
## [24] train-error:0.000000    test-error:0.347826 
## [25] train-error:0.000000    test-error:0.347826 
## [26] train-error:0.000000    test-error:0.347826 
## [27] train-error:0.000000    test-error:0.347826 
## [28] train-error:0.000000    test-error:0.347826 
## [29] train-error:0.000000    test-error:0.347826 
## [30] train-error:0.000000    test-error:0.347826 
## [31] train-error:0.000000    test-error:0.347826 
## [32] train-error:0.000000    test-error:0.347826 
## [33] train-error:0.000000    test-error:0.347826 
## [34] train-error:0.000000    test-error:0.347826 
## [35] train-error:0.000000    test-error:0.347826 
## [36] train-error:0.000000    test-error:0.347826 
## [37] train-error:0.000000    test-error:0.347826 
## [38] train-error:0.000000    test-error:0.347826 
## [39] train-error:0.000000    test-error:0.347826 
## [40] train-error:0.000000    test-error:0.347826 
## [41] train-error:0.000000    test-error:0.347826 
## [42] train-error:0.000000    test-error:0.347826 
## [43] train-error:0.000000    test-error:0.347826 
## [44] train-error:0.000000    test-error:0.347826 
## [45] train-error:0.000000    test-error:0.347826 
## [46] train-error:0.000000    test-error:0.347826 
## [47] train-error:0.000000    test-error:0.347826 
## [48] train-error:0.000000    test-error:0.347826 
## [49] train-error:0.000000    test-error:0.347826 
## [50] train-error:0.000000    test-error:0.347826

Each feature is grouped by importance with k-means clustering. Gain is the improvement in accuracy that the addition of a feature brings to the branches it is on.

features = colnames(matrix_train)
importance_matrix_1 <- xgb.importance(features, model = bst_1)
print(importance_matrix_1)
##                   Feature       Gain      Cover  Frequency
## 1:                    age 0.54677689 0.45726557 0.43333333
## 2:  days_onset_to_outcome 0.14753698 0.11686537 0.11111111
## 3:            early_onset 0.06724531 0.07365416 0.07777778
## 4:          early_outcome 0.06679880 0.08869426 0.07777778
## 5:      province_Shanghai 0.05325295 0.04317030 0.04444444
## 6: days_onset_to_hospital 0.03883631 0.11300882 0.13333333
## 7:         province_other 0.03502971 0.03977657 0.04444444
## 8:               gender_f 0.02610959 0.02484170 0.01111111
## 9:               hospital 0.01841345 0.04272324 0.06666667
xgb.ggplot.importance(importance_matrix_1) +
  theme_minimal()

pred_1 <- predict(bst_1, xgb_test_matrix)

result_1 <- data.frame(case_ID = rownames(val_test_data),
                       outcome = val_test_data$outcome, 
                       label = label, 
                       prediction_p_death = round(pred_1, digits = 2),
                       prediction = as.integer(pred_1 > 0.5),
                       prediction_eval = ifelse(as.integer(pred_1 > 0.5) != label, "wrong", "correct"))
result_1
##     case_ID outcome label prediction_p_death prediction prediction_eval
## 1  case_123   Death     1               0.01          0           wrong
## 2  case_127 Recover     0               0.00          0         correct
## 3  case_128 Recover     0               0.96          1           wrong
## 4   case_14 Recover     0               0.76          1           wrong
## 5   case_19   Death     1               0.95          1         correct
## 6    case_2   Death     1               0.85          1         correct
## 7   case_20 Recover     0               0.98          1           wrong
## 8   case_21 Recover     0               0.06          0         correct
## 9   case_34   Death     1               1.00          1         correct
## 10  case_37 Recover     0               0.00          0         correct
## 11   case_5 Recover     0               0.03          0         correct
## 12  case_51 Recover     0               0.27          0         correct
## 13  case_55 Recover     0               0.11          0         correct
## 14   case_6   Death     1               0.26          0           wrong
## 15  case_61   Death     1               1.00          1         correct
## 16  case_65 Recover     0               0.00          0         correct
## 17  case_74 Recover     0               0.92          1           wrong
## 18  case_78   Death     1               0.03          0           wrong
## 19  case_79 Recover     0               0.00          0         correct
## 20   case_8   Death     1               0.06          0           wrong
## 21  case_87   Death     1               0.64          1         correct
## 22  case_91 Recover     0               0.00          0         correct
## 23  case_94 Recover     0               0.01          0         correct
err <- as.numeric(sum(as.integer(pred_1 > 0.5) != label))/length(label)
print(paste("test-error =", round(err, digits = 2)))
## [1] "test-error = 0.35"


Training with gblinear

bst_2 <- xgb.train(data = xgb_train_matrix, 
                   booster = "gblinear", 
                   label = getinfo(xgb_train_matrix, "label"),
                   max.depth = 2, 
                   eta = 1, 
                   nthread = 4, 
                   nround = 50, # number of trees used for model building
                   watchlist = watchlist, 
                   objective = "binary:logistic")
## [1]  train-error:0.339286    test-error:0.434783 
## [2]  train-error:0.303571    test-error:0.391304 
## [3]  train-error:0.285714    test-error:0.347826 
## [4]  train-error:0.303571    test-error:0.347826 
## [5]  train-error:0.285714    test-error:0.347826 
## [6]  train-error:0.285714    test-error:0.391304 
## [7]  train-error:0.285714    test-error:0.391304 
## [8]  train-error:0.267857    test-error:0.391304 
## [9]  train-error:0.250000    test-error:0.391304 
## [10] train-error:0.250000    test-error:0.478261 
## [11] train-error:0.250000    test-error:0.434783 
## [12] train-error:0.232143    test-error:0.434783 
## [13] train-error:0.232143    test-error:0.434783 
## [14] train-error:0.232143    test-error:0.434783 
## [15] train-error:0.232143    test-error:0.434783 
## [16] train-error:0.232143    test-error:0.434783 
## [17] train-error:0.232143    test-error:0.434783 
## [18] train-error:0.232143    test-error:0.434783 
## [19] train-error:0.232143    test-error:0.434783 
## [20] train-error:0.232143    test-error:0.434783 
## [21] train-error:0.214286    test-error:0.434783 
## [22] train-error:0.214286    test-error:0.434783 
## [23] train-error:0.214286    test-error:0.434783 
## [24] train-error:0.214286    test-error:0.434783 
## [25] train-error:0.214286    test-error:0.434783 
## [26] train-error:0.214286    test-error:0.434783 
## [27] train-error:0.214286    test-error:0.434783 
## [28] train-error:0.214286    test-error:0.434783 
## [29] train-error:0.214286    test-error:0.434783 
## [30] train-error:0.232143    test-error:0.434783 
## [31] train-error:0.214286    test-error:0.434783 
## [32] train-error:0.214286    test-error:0.434783 
## [33] train-error:0.214286    test-error:0.434783 
## [34] train-error:0.214286    test-error:0.434783 
## [35] train-error:0.214286    test-error:0.434783 
## [36] train-error:0.214286    test-error:0.434783 
## [37] train-error:0.214286    test-error:0.434783 
## [38] train-error:0.214286    test-error:0.434783 
## [39] train-error:0.214286    test-error:0.434783 
## [40] train-error:0.214286    test-error:0.434783 
## [41] train-error:0.214286    test-error:0.434783 
## [42] train-error:0.214286    test-error:0.434783 
## [43] train-error:0.214286    test-error:0.434783 
## [44] train-error:0.214286    test-error:0.434783 
## [45] train-error:0.214286    test-error:0.434783 
## [46] train-error:0.214286    test-error:0.434783 
## [47] train-error:0.214286    test-error:0.434783 
## [48] train-error:0.214286    test-error:0.434783 
## [49] train-error:0.214286    test-error:0.434783 
## [50] train-error:0.214286    test-error:0.434783

For the linear booster, the importance matrix shows the weight of each feature in the linear model instead of the gain, cover and frequency reported for trees.

features = colnames(matrix_train)
importance_matrix_2 <- xgb.importance(features, model = bst_2)
print(importance_matrix_2)
##                    Feature     Weight
##  1:                    age  0.0439012
##  2:               hospital -0.3131160
##  3:               gender_f  0.8067830
##  4:       province_Jiangsu -0.8807380
##  5:      province_Shanghai -0.4065330
##  6:      province_Zhejiang -0.1284590
##  7:         province_other  0.4510330
##  8:  days_onset_to_outcome  0.0209952
##  9: days_onset_to_hospital -0.0856504
## 10:            early_onset  0.6452190
## 11:          early_outcome  2.5134300
xgb.ggplot.importance(importance_matrix_2) +
  theme_minimal()

pred_2 <- predict(bst_2, xgb_test_matrix)

result_2 <- data.frame(case_ID = rownames(val_test_data),
                       outcome = val_test_data$outcome, 
                       label = label, 
                       prediction_p_death = round(pred_2, digits = 2),
                       prediction = as.integer(pred_2 > 0.5),
                       prediction_eval = ifelse(as.integer(pred_2 > 0.5) != label, "wrong", "correct"))
result_2
##     case_ID outcome label prediction_p_death prediction prediction_eval
## 1  case_123   Death     1               0.10          0           wrong
## 2  case_127 Recover     0               0.03          0         correct
## 3  case_128 Recover     0               0.13          0         correct
## 4   case_14 Recover     0               0.15          0         correct
## 5   case_19   Death     1               0.28          0           wrong
## 6    case_2   Death     1               0.28          0           wrong
## 7   case_20 Recover     0               0.80          1           wrong
## 8   case_21 Recover     0               0.79          1           wrong
## 9   case_34   Death     1               0.74          1         correct
## 10  case_37 Recover     0               0.06          0         correct
## 11   case_5 Recover     0               0.14          0         correct
## 12  case_51 Recover     0               0.77          1           wrong
## 13  case_55 Recover     0               0.23          0         correct
## 14   case_6   Death     1               0.47          0           wrong
## 15  case_61   Death     1               0.52          1         correct
## 16  case_65 Recover     0               0.03          0         correct
## 17  case_74 Recover     0               0.70          1           wrong
## 18  case_78   Death     1               0.22          0           wrong
## 19  case_79 Recover     0               0.06          0         correct
## 20   case_8   Death     1               0.36          0           wrong
## 21  case_87   Death     1               0.65          1         correct
## 22  case_91 Recover     0               0.05          0         correct
## 23  case_94 Recover     0               0.07          0         correct
err <- as.numeric(sum(as.integer(pred_2 > 0.5) != label))/length(label)
print(paste("test-error =", round(err, digits = 2)))
## [1] "test-error = 0.43"


caret

Extreme gradient boosting is also implemented in the caret package. Caret also provides options for preprocessing, of which I will compare a few.


No preprocessing

library(caret)

set.seed(27)
model_xgb_null <-train(outcome ~ .,
                 data=val_train_data,
                 method="xgbTree",
                 preProcess = NULL,
                 trControl = trainControl(method = "repeatedcv", number = 5, repeats = 10, verboseIter = FALSE))
confusionMatrix(predict(model_xgb_null, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       5       4
##    Recover     4      10
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2698          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.5556          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.5556          
##          Neg Pred Value : 0.7143          
##              Prevalence : 0.3913          
##          Detection Rate : 0.2174          
##    Detection Prevalence : 0.3913          
##       Balanced Accuracy : 0.6349          
##                                           
##        'Positive' Class : Death           
## 


Scaling and centering

With this method the column variables are centered (subtracting the column mean from each value in a column) and standardized (dividing by the column standard deviation).
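A minimal sketch of what this does to a single column (caret applies it internally, using means and standard deviations estimated from the training data):

x <- c(20, 35, 47, 60, 87)    # e.g. a column of ages
(x - mean(x)) / sd(x)         # centered and scaled values, same as as.numeric(scale(x))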

set.seed(27)
model_xgb_sc_cen <-train(outcome ~ .,
                 data=val_train_data,
                 method="xgbTree",
                 preProcess = c("scale", "center"),
                 trControl = trainControl(method = "repeatedcv", number = 5, repeats = 10, verboseIter = FALSE))
confusionMatrix(predict(model_xgb_sc_cen, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       5       4
##    Recover     4      10
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2698          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.5556          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.5556          
##          Neg Pred Value : 0.7143          
##              Prevalence : 0.3913          
##          Detection Rate : 0.2174          
##    Detection Prevalence : 0.3913          
##       Balanced Accuracy : 0.6349          
##                                           
##        'Positive' Class : Death           
## 

Scaling and centering did not improve the predictions.

pred_3 <- predict(model_xgb_sc_cen, val_test_data[, -1])
pred_3b <- round(predict(model_xgb_sc_cen, val_test_data[, -1], type="prob"), digits = 2)

result_3 <- data.frame(case_ID = rownames(val_test_data),
                       outcome = val_test_data$outcome, 
                       label = label, 
                       prediction = pred_3,
                       pred_3b)
result_3$prediction_eval <- ifelse(result_3$prediction != result_3$outcome, "wrong", "correct")
result_3
##     case_ID outcome label prediction Death Recover prediction_eval
## 1  case_123   Death     1    Recover  0.04    0.96           wrong
## 2  case_127 Recover     0    Recover  0.04    0.96         correct
## 3  case_128 Recover     0      Death  0.81    0.19           wrong
## 4   case_14 Recover     0      Death  0.70    0.30           wrong
## 5   case_19   Death     1      Death  0.67    0.33         correct
## 6    case_2   Death     1      Death  0.66    0.34         correct
## 7   case_20 Recover     0      Death  0.76    0.24           wrong
## 8   case_21 Recover     0    Recover  0.17    0.83         correct
## 9   case_34   Death     1      Death  0.94    0.06         correct
## 10  case_37 Recover     0    Recover  0.02    0.98         correct
## 11   case_5 Recover     0    Recover  0.09    0.91         correct
## 12  case_51 Recover     0    Recover  0.41    0.59         correct
## 13  case_55 Recover     0    Recover  0.12    0.88         correct
## 14   case_6   Death     1    Recover  0.30    0.70           wrong
## 15  case_61   Death     1      Death  0.99    0.01         correct
## 16  case_65 Recover     0    Recover  0.01    0.99         correct
## 17  case_74 Recover     0      Death  0.85    0.15           wrong
## 18  case_78   Death     1    Recover  0.21    0.79           wrong
## 19  case_79 Recover     0    Recover  0.01    0.99         correct
## 20   case_8   Death     1    Recover  0.20    0.80           wrong
## 21  case_87   Death     1      Death  0.65    0.35         correct
## 22  case_91 Recover     0    Recover  0.01    0.99         correct
## 23  case_94 Recover     0    Recover  0.07    0.93         correct


Box-Cox transformation

The Box-Cox power transformation is used to normalize data.
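As a minimal sketch with made-up, right-skewed example data (when preProcess = "BoxCox" is used, caret estimates the transformation parameter lambda from the training data internally):

set.seed(27)
skewed <- rexp(100, rate = 0.5)      # right-skewed, strictly positive example data
bc <- BoxCoxTrans(skewed)            # estimate lambda; BoxCoxTrans() comes with caret, loaded above
transformed <- predict(bc, skewed)   # transformed values are closer to a normal distribution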

set.seed(27)
model_xgb_BoxCox <-train(outcome ~ .,
                 data=val_train_data,
                 method="xgbTree",
                 preProcess = "BoxCox",
                 trControl = trainControl(method = "repeatedcv", number = 5, repeats = 10, verboseIter = FALSE))
confusionMatrix(predict(model_xgb_BoxCox, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       5       4
##    Recover     4      10
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2698          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.5556          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.5556          
##          Neg Pred Value : 0.7143          
##              Prevalence : 0.3913          
##          Detection Rate : 0.2174          
##    Detection Prevalence : 0.3913          
##       Balanced Accuracy : 0.6349          
##                                           
##        'Positive' Class : Death           
## 

Box-Cox transformation did not improve the predictions either.

pred_4 <- predict(model_xgb_BoxCox, val_test_data[, -1])
pred_4b <- round(predict(model_xgb_BoxCox, val_test_data[, -1], type="prob"), digits = 2)

result_4 <- data.frame(case_ID = rownames(val_test_data),
                       outcome = val_test_data$outcome, 
                       label = label, 
                       prediction = pred_4,
                       pred_4b)
result_4$prediction_eval <- ifelse(result_4$prediction != result_4$outcome, "wrong", "correct")
result_4
##     case_ID outcome label prediction Death Recover prediction_eval
## 1  case_123   Death     1    Recover  0.01    0.99           wrong
## 2  case_127 Recover     0    Recover  0.01    0.99         correct
## 3  case_128 Recover     0      Death  0.88    0.12           wrong
## 4   case_14 Recover     0      Death  0.81    0.19           wrong
## 5   case_19   Death     1      Death  0.88    0.12         correct
## 6    case_2   Death     1      Death  0.79    0.21         correct
## 7   case_20 Recover     0      Death  0.84    0.16           wrong
## 8   case_21 Recover     0    Recover  0.07    0.93         correct
## 9   case_34   Death     1      Death  0.97    0.03         correct
## 10  case_37 Recover     0    Recover  0.01    0.99         correct
## 11   case_5 Recover     0    Recover  0.06    0.94         correct
## 12  case_51 Recover     0    Recover  0.27    0.73         correct
## 13  case_55 Recover     0    Recover  0.09    0.91         correct
## 14   case_6   Death     1    Recover  0.30    0.70           wrong
## 15  case_61   Death     1      Death  1.00    0.00         correct
## 16  case_65 Recover     0    Recover  0.00    1.00         correct
## 17  case_74 Recover     0      Death  0.90    0.10           wrong
## 18  case_78   Death     1    Recover  0.11    0.89           wrong
## 19  case_79 Recover     0    Recover  0.00    1.00         correct
## 20   case_8   Death     1    Recover  0.13    0.87           wrong
## 21  case_87   Death     1      Death  0.61    0.39         correct
## 22  case_91 Recover     0    Recover  0.00    1.00         correct
## 23  case_94 Recover     0    Recover  0.04    0.96         correct


Principal Component Analysis (PCA)

PCA is used for dimensionality reduction. When applied as a preprocessing method, the number of features is reduced by projecting the data onto the eigenvectors of the covariance matrix (the principal components) and keeping only the leading components.
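As a rough sketch of the idea with base R’s prcomp() on the built-in iris data (by default, caret’s preProcess = "pca" keeps as many components as are needed to explain 95% of the variance):

pca <- prcomp(iris[, 1:4], center = TRUE, scale. = TRUE)
summary(pca)          # proportion of variance explained per principal component
head(pca$x[, 1:2])    # the first two components could serve as new features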

set.seed(27)
model_xgb_pca <-train(outcome ~ .,
                 data=val_train_data,
                 method="xgbTree",
                 preProcess = "pca",
                 trControl = trainControl(method = "repeatedcv", number = 5, repeats = 10, verboseIter = FALSE))
confusionMatrix(predict(model_xgb_pca, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       4       2
##    Recover     5      12
##                                           
##                Accuracy : 0.6957          
##                  95% CI : (0.4708, 0.8679)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.2644          
##                                           
##                   Kappa : 0.3207          
##  Mcnemar's Test P-Value : 0.4497          
##                                           
##             Sensitivity : 0.4444          
##             Specificity : 0.8571          
##          Pos Pred Value : 0.6667          
##          Neg Pred Value : 0.7059          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1739          
##    Detection Prevalence : 0.2609          
##       Balanced Accuracy : 0.6508          
##                                           
##        'Positive' Class : Death           
## 

PCA did improve the predictions slightly.

pred_5 <- predict(model_xgb_pca, val_test_data[, -1])
pred_5b <- round(predict(model_xgb_pca, val_test_data[, -1], type="prob"), digits = 2)

result_5 <- data.frame(case_ID = rownames(val_test_data),
                       outcome = val_test_data$outcome, 
                       label = label, 
                       prediction = pred_5,
                       pred_5b)
result_5$prediction_eval <- ifelse(result_5$prediction != result_5$outcome, "wrong", "correct")
result_5
##     case_ID outcome label prediction Death Recover prediction_eval
## 1  case_123   Death     1    Recover  0.02    0.98           wrong
## 2  case_127 Recover     0    Recover  0.29    0.71         correct
## 3  case_128 Recover     0    Recover  0.01    0.99         correct
## 4   case_14 Recover     0    Recover  0.44    0.56         correct
## 5   case_19   Death     1    Recover  0.07    0.93           wrong
## 6    case_2   Death     1    Recover  0.26    0.74           wrong
## 7   case_20 Recover     0    Recover  0.16    0.84         correct
## 8   case_21 Recover     0      Death  0.57    0.43           wrong
## 9   case_34   Death     1      Death  0.93    0.07         correct
## 10  case_37 Recover     0    Recover  0.03    0.97         correct
## 11   case_5 Recover     0    Recover  0.11    0.89         correct
## 12  case_51 Recover     0    Recover  0.19    0.81         correct
## 13  case_55 Recover     0    Recover  0.16    0.84         correct
## 14   case_6   Death     1      Death  0.55    0.45         correct
## 15  case_61   Death     1    Recover  0.22    0.78           wrong
## 16  case_65 Recover     0    Recover  0.08    0.92         correct
## 17  case_74 Recover     0      Death  0.67    0.33           wrong
## 18  case_78   Death     1      Death  0.87    0.13         correct
## 19  case_79 Recover     0    Recover  0.16    0.84         correct
## 20   case_8   Death     1    Recover  0.22    0.78           wrong
## 21  case_87   Death     1      Death  0.66    0.34         correct
## 22  case_91 Recover     0    Recover  0.01    0.99         correct
## 23  case_94 Recover     0    Recover  0.14    0.86         correct


Median imputation
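Median imputation replaces missing values in each column with that column’s median (caret’s preProcess = "medianImpute" does this per column, with medians computed on the training data). A minimal sketch of the idea:

x <- c(20, 35, NA, 60, 87)               # example column with a missing value
x[is.na(x)] <- median(x, na.rm = TRUE)   # replace NA with the column median
x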

set.seed(27)
model_xgb_medianImpute <-train(outcome ~ .,
                 data=val_train_data,
                 method="xgbTree",
                 preProcess = "medianImpute",
                 trControl = trainControl(method = "repeatedcv", number = 5, repeats = 10, verboseIter = FALSE))
confusionMatrix(predict(model_xgb_medianImpute, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       5       4
##    Recover     4      10
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2698          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.5556          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.5556          
##          Neg Pred Value : 0.7143          
##              Prevalence : 0.3913          
##          Detection Rate : 0.2174          
##    Detection Prevalence : 0.3913          
##       Balanced Accuracy : 0.6349          
##                                           
##        'Positive' Class : Death           
## 

Median imputation did not improve the predictions either.

pred_6 <- predict(model_xgb_medianImpute, val_test_data[, -1])
pred_6b <- round(predict(model_xgb_medianImpute, val_test_data[, -1], type="prob"), digits = 2)

result_6 <- data.frame(case_ID = rownames(val_test_data),
                       outcome = val_test_data$outcome, 
                       label = label, 
                       prediction = pred_6,
                       pred_6b)
result_6$prediction_eval <- ifelse(result_6$prediction != result_6$outcome, "wrong", "correct")
result_6
##     case_ID outcome label prediction Death Recover prediction_eval
## 1  case_123   Death     1    Recover  0.04    0.96           wrong
## 2  case_127 Recover     0    Recover  0.04    0.96         correct
## 3  case_128 Recover     0      Death  0.81    0.19           wrong
## 4   case_14 Recover     0      Death  0.70    0.30           wrong
## 5   case_19   Death     1      Death  0.67    0.33         correct
## 6    case_2   Death     1      Death  0.66    0.34         correct
## 7   case_20 Recover     0      Death  0.76    0.24           wrong
## 8   case_21 Recover     0    Recover  0.17    0.83         correct
## 9   case_34   Death     1      Death  0.94    0.06         correct
## 10  case_37 Recover     0    Recover  0.02    0.98         correct
## 11   case_5 Recover     0    Recover  0.09    0.91         correct
## 12  case_51 Recover     0    Recover  0.41    0.59         correct
## 13  case_55 Recover     0    Recover  0.12    0.88         correct
## 14   case_6   Death     1    Recover  0.30    0.70           wrong
## 15  case_61   Death     1      Death  0.99    0.01         correct
## 16  case_65 Recover     0    Recover  0.01    0.99         correct
## 17  case_74 Recover     0      Death  0.85    0.15           wrong
## 18  case_78   Death     1    Recover  0.21    0.79           wrong
## 19  case_79 Recover     0    Recover  0.01    0.99         correct
## 20   case_8   Death     1    Recover  0.20    0.80           wrong
## 21  case_87   Death     1      Death  0.65    0.35         correct
## 22  case_91 Recover     0    Recover  0.01    0.99         correct
## 23  case_94 Recover     0    Recover  0.07    0.93         correct


Comparison of extreme gradient boosting models

Combining results

library(dplyr)
result <- left_join(result_1[, c(1, 2, 6)], result_2[, c(1, 6)], by = "case_ID")
result <- left_join(result, result_3[, c(1, 7)], by = "case_ID")
result <- left_join(result, result_4[, c(1, 7)], by = "case_ID")
result <- left_join(result, result_5[, c(1, 7)], by = "case_ID")
result <- left_join(result, result_6[, c(1, 7)], by = "case_ID")
colnames(result)[-c(1:2)] <- c("pred_xgboost_gbtree", "pred_xgboost_gblinear", "model_xgb_sc_cen", "model_xgb_BoxCox", "pred_xgbTree_pca", "model_xgb_medianImpute")


What’s the rate of correctly predicted cases in the validation data?

round(sum(result$pred_xgboost_gbtree == "correct")/nrow(result), digits = 2)
## [1] 0.65
round(sum(result$pred_xgboost_gblinear == "correct")/nrow(result), digits = 2)
## [1] 0.57
round(sum(result$model_xgb_sc_cen == "correct")/nrow(result), digits = 2)
## [1] 0.65
round(sum(result$model_xgb_BoxCox == "correct")/nrow(result), digits = 2)
## [1] 0.65
round(sum(result$pred_xgbTree_pca == "correct")/nrow(result), digits = 2)
## [1] 0.7
round(sum(result$model_xgb_medianImpute == "correct")/nrow(result), digits = 2)
## [1] 0.65

PCA preprocessing combined with caret’s implementation of extreme gradient boosting had the highest number of accurately predicted cases in the validation data set.


Predicting unknown output

Because PCA preprocessing and caret’s implementation of extreme gradient boosting had the highest prediction accuracy, this model will be used to predict the unknown cases from the original dataset.

To account for uncertainty in the predictions, cases where neither class reaches a predicted probability above 0.7 are flagged as uncertain.

set.seed(27)
model_xgb_pca <-train(outcome ~ .,
                 data = train_data,
                 method = "xgbTree",
                 preProcess = "pca",
                 trControl = trainControl(method = "repeatedcv", number = 5, repeats = 10, verboseIter = FALSE))
pred <- predict(model_xgb_pca, test_data)
predb <- round(predict(model_xgb_pca, test_data, type="prob"), digits = 2)

result <- data.frame(case_ID = rownames(test_data),
                       prediction = pred,
                       predb)
result$predicted_outcome <- ifelse(result$Death > 0.7, "Death",
                                   ifelse(result$Recover > 0.7, "Recover", "uncertain"))
result
##     case_ID prediction Death Recover predicted_outcome
## 1  case_100    Recover  0.08    0.92           Recover
## 2  case_101    Recover  0.11    0.89           Recover
## 3  case_102    Recover  0.04    0.96           Recover
## 4  case_103    Recover  0.01    0.99           Recover
## 5  case_104    Recover  0.19    0.81           Recover
## 6  case_105    Recover  0.01    0.99           Recover
## 7  case_108      Death  0.99    0.01             Death
## 8  case_109    Recover  0.30    0.70         uncertain
## 9  case_110    Recover  0.16    0.84           Recover
## 10 case_112      Death  0.54    0.46         uncertain
## 11 case_113    Recover  0.02    0.98           Recover
## 12 case_114    Recover  0.30    0.70         uncertain
## 13 case_115    Recover  0.02    0.98           Recover
## 14 case_118      Death  0.65    0.35         uncertain
## 15 case_120      Death  0.89    0.11             Death
## 16 case_122    Recover  0.00    1.00           Recover
## 17 case_126    Recover  0.04    0.96           Recover
## 18 case_130    Recover  0.42    0.58         uncertain
## 19 case_132    Recover  0.47    0.53         uncertain
## 20 case_136      Death  0.86    0.14             Death
## 21  case_15      Death  0.99    0.01             Death
## 22  case_16      Death  0.91    0.09             Death
## 23  case_22      Death  0.80    0.20             Death
## 24  case_28      Death  0.91    0.09             Death
## 25  case_31      Death  1.00    0.00             Death
## 26  case_32      Death  0.93    0.07             Death
## 27  case_38      Death  0.87    0.13             Death
## 28  case_39      Death  0.88    0.12             Death
## 29   case_4      Death  0.99    0.01             Death
## 30  case_40      Death  1.00    0.00             Death
## 31  case_41      Death  0.95    0.05             Death
## 32  case_42    Recover  0.49    0.51         uncertain
## 33  case_47      Death  0.91    0.09             Death
## 34  case_48      Death  0.95    0.05             Death
## 35  case_52      Death  0.62    0.38         uncertain
## 36  case_54    Recover  0.43    0.57         uncertain
## 37  case_56      Death  0.81    0.19             Death
## 38  case_62      Death  0.58    0.42         uncertain
## 39  case_63    Recover  0.44    0.56         uncertain
## 40  case_66    Recover  0.18    0.82           Recover
## 41  case_67    Recover  0.02    0.98           Recover
## 42  case_68    Recover  0.27    0.73           Recover
## 43  case_69    Recover  0.35    0.65         uncertain
## 44  case_70    Recover  0.02    0.98           Recover
## 45  case_71    Recover  0.17    0.83           Recover
## 46  case_80    Recover  0.05    0.95           Recover
## 47  case_84    Recover  0.02    0.98           Recover
## 48  case_85      Death  0.99    0.01             Death
## 49  case_86    Recover  0.02    0.98           Recover
## 50  case_88    Recover  0.02    0.98           Recover
## 51   case_9      Death  0.98    0.02             Death
## 52  case_90    Recover  0.08    0.92           Recover
## 53  case_92    Recover  0.04    0.96           Recover
## 54  case_93    Recover  0.43    0.57         uncertain
## 55  case_95    Recover  0.11    0.89           Recover
## 56  case_96      Death  0.68    0.32         uncertain
## 57  case_99    Recover  0.04    0.96           Recover


Comparison with predicted outcome from last week’s analyses


Plotting predicted outcome

Results from last week’s analysis come from the following part of the analysis:

results <- data.frame(randomForest = predict(model_rf, newdata = test_data, type="prob"),
glmnet = predict(model_glmnet, newdata = test_data, type="prob"),
kknn = predict(model_kknn, newdata = test_data, type="prob"),
pda = predict(model_pda, newdata = test_data, type="prob"),
slda = predict(model_slda, newdata = test_data, type="prob"),
pam = predict(model_pam, newdata = test_data, type="prob"),
C5.0Tree = predict(model_C5.0Tree, newdata = test_data, type="prob"),
pls = predict(model_pls, newdata = test_data, type="prob"))

results$sum_Death <- rowSums(results[, grep("Death", colnames(results))])
results$sum_Recover <- rowSums(results[, grep("Recover", colnames(results))])
results$log2_ratio <- log2(results$sum_Recover/results$sum_Death)
results$predicted_outcome <- ifelse(results$log2_ratio > 1.5, "Recover", ifelse(results$log2_ratio < -1.5, "Death", "uncertain"))
results[, -c(1:16)]

results_last_week <- results # store last week's predictions under a separate name

Combining the table with predicted outcomes from this and last week with the original data.

result <- merge(result[, c(1, 5)], results_last_week[, ncol(results_last_week), drop = FALSE], by.x = "case_ID", by.y = "row.names")
colnames(result)[2:3] <- c("predicted_outcome_xgboost", "predicted_outcome_last_week")

results_combined <- merge(result, fluH7N9.china.2013[which(fluH7N9.china.2013$case.ID %in% result$case_ID), ], 
                          by.x = "case_ID", by.y = "case.ID")
results_combined <- results_combined[, -c(6,7)]

For plotting with ggplot2, the dataframe needs to be gathered.

library(tidyr)
results_combined_gather <- results_combined %>%
  gather(group_dates, date, date.of.onset:date.of.hospitalisation)

results_combined_gather$group_dates <- factor(results_combined_gather$group_dates, levels = c("date.of.onset", "date.of.hospitalisation"))

library(plyr) # for mapvalues()
results_combined_gather$group_dates <- mapvalues(results_combined_gather$group_dates, from = c("date.of.onset", "date.of.hospitalisation"), 
                                             to = c("Date of onset", "Date of hospitalisation"))

results_combined_gather$gender <- mapvalues(results_combined_gather$gender, from = c("f", "m"), 
                                             to = c("Female", "Male"))
levels(results_combined_gather$gender) <- c(levels(results_combined_gather$gender), "unknown")
results_combined_gather$gender[is.na(results_combined_gather$gender)] <- "unknown"

results_combined_gather <- results_combined_gather %>%
  gather(group_pred, prediction, predicted_outcome_xgboost:predicted_outcome_last_week)

results_combined_gather$group_pred <- mapvalues(results_combined_gather$group_pred, from = c("predicted_outcome_xgboost", "predicted_outcome_last_week"), 
                                             to = c("Predicted outcome from XGBoost", "Predicted outcome from last week"))

Setting a custom theme for plotting:

library(ggplot2)
my_theme <- function(base_size = 12, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
  theme(
    axis.text = element_text(size = 12),
    axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 0.5),
    axis.title = element_text(size = 14),
    panel.grid.major = element_line(color = "grey"),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "aliceblue"),
    strip.background = element_rect(fill = "lightgrey", color = "grey", size = 1),
    strip.text = element_text(face = "bold", size = 12, color = "black"),
    legend.position = "bottom",
    legend.justification = "top", 
    legend.box = "horizontal",
    legend.box.background = element_rect(colour = "grey50"),
    legend.background = element_blank(),
    panel.border = element_rect(color = "grey", fill = NA, size = 0.5),
    panel.spacing = unit(1, "lines")
  )
}
results_combined_gather$province <- mapvalues(results_combined_gather$province, 
                                                from = c("Anhui", "Beijing", "Fujian", "Guangdong", "Hebei", "Henan", "Hunan", "Jiangxi", "Shandong", "Taiwan"), 
                                                to = rep("Other", 10))

levels(results_combined_gather$gender) <- c(levels(results_combined_gather$gender), "unknown")
results_combined_gather$gender[is.na(results_combined_gather$gender)] <- "unknown"

results_combined_gather$province <- factor(results_combined_gather$province, levels = c("Jiangsu",  "Shanghai", "Zhejiang", "Other"))
ggplot(data = subset(results_combined_gather, group_dates == "Date of onset"), aes(x = date, y = as.numeric(age), fill = prediction)) +
  stat_density2d(aes(alpha = ..level..), geom = "polygon") +
  geom_jitter(aes(color = prediction, shape = gender), size = 2) +
  geom_rug(aes(color = prediction)) +
  labs(
    fill = "Predicted outcome",
    color = "Predicted outcome",
    alpha = "Level",
    shape = "Gender",
    x = "Date of onset in 2013",
    y = "Age",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Predicted outcome of cases with unknown outcome",
    caption = ""
  ) +
  facet_grid(group_pred ~ province) +
  my_theme() +
  scale_shape_manual(values = c(15, 16, 17)) +
  scale_color_brewer(palette="Set1", na.value = "grey50") +
  scale_fill_brewer(palette="Set1")


There is much less uncertainty in the XGBoost data, even though I used slightly different methods for classifying uncertainty: in last week’s analysis I based uncertainty on the ratio of combined prediction values from all analyses, while this week uncertainty is based on the prediction value from a single analysis.


For a comparison of feature selection methods see here.


If you are interested in more machine learning posts, check out the category listing for machine_learning.


sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
## 
## locale:
## [1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252    LC_MONETARY=English_United States.1252 LC_NUMERIC=C                           LC_TIME=English_United States.1252    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] tidyr_0.6.0     plyr_1.8.4      xgboost_0.6-0   caret_6.0-73    ggplot2_2.2.0   lattice_0.20-34 mice_2.25       Rcpp_0.12.8     dplyr_0.5.0     outbreaks_1.0.0
## 
## loaded via a namespace (and not attached):
##  [1] Ckmeans.1d.dp_3.4.6-4 RColorBrewer_1.1-2    compiler_3.3.2        nloptr_1.0.4          class_7.3-14          iterators_1.0.8       tools_3.3.2           rpart_4.1-10          digest_0.6.10         lme4_1.1-12           evaluate_0.10         tibble_1.2            gtable_0.2.0          nlme_3.1-128          mgcv_1.8-16           Matrix_1.2-7.1        foreach_1.4.3         DBI_0.5-1             parallel_3.3.2        yaml_2.1.14           SparseM_1.74          e1071_1.6-7           stringr_1.1.0         knitr_1.15            MatrixModels_0.4-1    stats4_3.3.2          grid_3.3.2            nnet_7.3-12           data.table_1.9.6      R6_2.2.0              survival_2.40-1       rmarkdown_1.1         minqa_1.2.4           reshape2_1.4.2        car_2.1-3             magrittr_1.5          scales_0.4.1          codetools_0.2-15      ModelMetrics_1.1.0    htmltools_0.3.5       MASS_7.3-45           splines_3.3.2         assertthat_0.1        pbkrtest_0.4-6        colorspace_1.3-0      labeling_0.3          quantreg_5.29         stringi_1.1.2         lazyeval_0.2.0        munsell_0.4.3         chron_2.3-47

Can we predict flu deaths with Machine Learning and R?

2016-11-27 00:00:00 +0000

Edited on 26 December 2016


Among the many R packages, there is the outbreaks package. It contains datasets on epidemics, one of which is from the 2013 outbreak of influenza A H7N9 in China, as analysed by Kucharski et al. (2014):

A. Kucharski, H. Mills, A. Pinsent, C. Fraser, M. Van Kerkhove, C. A. Donnelly, and S. Riley. 2014. Distinguishing between reservoir exposure and human-to-human transmission for emerging pathogens using case onset data. PLOS Currents Outbreaks. Mar 7, edition 1. doi: 10.1371/currents.outbreaks.e1473d9bfc99d080ca242139a06c455f.

A. Kucharski, H. Mills, A. Pinsent, C. Fraser, M. Van Kerkhove, C. A. Donnelly, and S. Riley. 2014. Data from: Distinguishing between reservoir exposure and human-to-human transmission for emerging pathogens using case onset data. Dryad Digital Repository. http://dx.doi.org/10.5061/dryad.2g43n.

I will be using their data as an example to test whether we can use Machine Learning algorithms for predicting disease outcome.

To do so, I selected and extracted features from the raw data, including age, days between onset and outcome, gender, whether the patients were hospitalised, etc. Missing values were imputed and different model algorithms were used to predict outcome (death or recovery); the models were compared based on prediction accuracy, sensitivity and specificity. The thus prepared dataset was divided into training and testing subsets. The test subset contained all cases with an unknown outcome. Before I applied the models to the test data, I further split the training data into validation subsets.

The tested modeling algorithms were similarly successful at predicting the outcomes of the validation data. To decide on final classifications, I compared predictions from all models and defined the outcome “Death” or “Recovery” as a function of all models, whereas classifications with a low prediction probability were flagged as “uncertain”. Accounting for this uncertainty led to a 100% correct classification of the validation test set.

The test cases with unknown outcome were then classified based on the same algorithms. From 57 unknown cases, 14 were classified as “Recovery”, 10 as “Death” and 33 as uncertain.


In a Part 2 I am looking at how extreme gradient boosting performs on this dataset.


The data

The dataset contains case ID, date of onset, date of hospitalisation, date of outcome, gender, age, province and of course the outcome: Death or Recovery. I can already see that there are a couple of missing values in the data, which I will deal with later.

# install and load package
if (!require("outbreaks")) install.packages("outbreaks")
library(outbreaks)
fluH7N9.china.2013_backup <- fluH7N9.china.2013 # back up original dataset in case something goes awry along the way

# convert ? to NAs
fluH7N9.china.2013$age[which(fluH7N9.china.2013$age == "?")] <- NA

# create a new column with case ID
fluH7N9.china.2013$case.ID <- paste("case", fluH7N9.china.2013$case.ID, sep = "_")
head(fluH7N9.china.2013)
##   case.ID date.of.onset date.of.hospitalisation date.of.outcome outcome gender age province
## 1  case_1    2013-02-19                    <NA>      2013-03-04   Death      m  87 Shanghai
## 2  case_2    2013-02-27              2013-03-03      2013-03-10   Death      m  27 Shanghai
## 3  case_3    2013-03-09              2013-03-19      2013-04-09   Death      f  35    Anhui
## 4  case_4    2013-03-19              2013-03-27            <NA>    <NA>      f  45  Jiangsu
## 5  case_5    2013-03-19              2013-03-30      2013-05-15 Recover      f  48  Jiangsu
## 6  case_6    2013-03-21              2013-03-28      2013-04-26   Death      f  32  Jiangsu


Before I start preparing the data for Machine Learning, I want to get an idea of the distribution of the data points and their different variables by plotting.

Most provinces have only a handful of cases, so I am combining them into the category “Other” and keeping only Jiangsu, Shanghai and Zhejiang as separate provinces.

# gather for plotting with ggplot2
library(tidyr)
fluH7N9.china.2013_gather <- fluH7N9.china.2013 %>%
  gather(Group, Date, date.of.onset:date.of.outcome)

# rearrange group order
fluH7N9.china.2013_gather$Group <- factor(fluH7N9.china.2013_gather$Group, levels = c("date.of.onset", "date.of.hospitalisation", "date.of.outcome"))

# rename groups
library(plyr)
fluH7N9.china.2013_gather$Group <- mapvalues(fluH7N9.china.2013_gather$Group, from = c("date.of.onset", "date.of.hospitalisation", "date.of.outcome"), 
          to = c("Date of onset", "Date of hospitalisation", "Date of outcome"))

# renaming provinces
fluH7N9.china.2013_gather$province <- mapvalues(fluH7N9.china.2013_gather$province, 
                                                from = c("Anhui", "Beijing", "Fujian", "Guangdong", "Hebei", "Henan", "Hunan", "Jiangxi", "Shandong", "Taiwan"), 
                                                to = rep("Other", 10))

# add a level for unknown gender
levels(fluH7N9.china.2013_gather$gender) <- c(levels(fluH7N9.china.2013_gather$gender), "unknown")
fluH7N9.china.2013_gather$gender[is.na(fluH7N9.china.2013_gather$gender)] <- "unknown"

# rearrange province order so that Other is the last
fluH7N9.china.2013_gather$province <- factor(fluH7N9.china.2013_gather$province, levels = c("Jiangsu",  "Shanghai", "Zhejiang", "Other"))

# convert age to numeric
fluH7N9.china.2013_gather$age <- as.numeric(as.character(fluH7N9.china.2013_gather$age))
# preparing my ggplot2 theme of choice

library(ggplot2)
my_theme <- function(base_size = 12, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
  theme(
    axis.text = element_text(size = 12),
    axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 0.5),
    axis.title = element_text(size = 14),
    panel.grid.major = element_line(color = "grey"),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "aliceblue"),
    strip.background = element_rect(fill = "lightgrey", color = "grey", size = 1),
    strip.text = element_text(face = "bold", size = 12, color = "black"),
    legend.position = "bottom",
    legend.justification = "top", 
    legend.box = "horizontal",
    legend.box.background = element_rect(colour = "grey50"),
    legend.background = element_blank(),
    panel.border = element_rect(color = "grey", fill = NA, size = 0.5)
  )
}
# plotting raw data

ggplot(data = fluH7N9.china.2013_gather, aes(x = Date, y = age, fill = outcome)) +
  stat_density2d(aes(alpha = ..level..), geom = "polygon") +
  geom_jitter(aes(color = outcome, shape = gender), size = 1.5) +
  geom_rug(aes(color = outcome)) +
  labs(
    fill = "Outcome",
    color = "Outcome",
    alpha = "Level",
    shape = "Gender",
    x = "Date in 2013",
    y = "Age",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Dataset from 'outbreaks' package (Kucharski et al. 2014)",
    caption = ""
  ) +
  facet_grid(Group ~ province) +
  my_theme() +
  scale_shape_manual(values = c(15, 16, 17)) +
  scale_color_brewer(palette="Set1", na.value = "grey50") +
  scale_fill_brewer(palette="Set1")

This plot shows the dates of onset, hospitalisation and outcome (if known) of each data point. Outcome is marked by color and age is shown on the y-axis. Gender is marked by point shape.

The density distribution of date by age for the cases seems to indicate that older people died more frequently in the Jiangsu and Zhejiang provinces than in Shanghai and in the other provinces.

When we look at the distribution of points along the time axis, it suggests that there might be a positive correlation between the likelihood of death and an early onset or early outcome date.


I also want to know how many cases there are for each gender and province and compare the genders’ age distribution.

fluH7N9.china.2013_gather_2 <- fluH7N9.china.2013_gather[, -4] %>%
  gather(group_2, value, gender:province)

fluH7N9.china.2013_gather_2$value <- mapvalues(fluH7N9.china.2013_gather_2$value, from = c("m", "f", "unknown", "Other"), 
          to = c("Male", "Female", "Unknown gender", "Other province"))

fluH7N9.china.2013_gather_2$value <- factor(fluH7N9.china.2013_gather_2$value, 
                                            levels = c("Female", "Male", "Unknown gender", "Jiangsu", "Shanghai", "Zhejiang", "Other province"))

p1 <- ggplot(data = fluH7N9.china.2013_gather_2, aes(x = value, fill = outcome, color = outcome)) +
  geom_bar(position = "dodge", alpha = 0.7, size = 1) +
  my_theme() +
  scale_fill_brewer(palette="Set1", na.value = "grey50") +
  scale_color_brewer(palette="Set1", na.value = "grey50") +
  labs(
    color = "",
    fill = "",
    x = "",
    y = "Count",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Gender and Province numbers of flu cases",
    caption = ""
  )

p2 <- ggplot(data = fluH7N9.china.2013_gather, aes(x = age, fill = outcome, color = outcome)) +
  geom_density(alpha = 0.3, size = 1) +
  geom_rug() +
  scale_color_brewer(palette="Set1", na.value = "grey50") +
  scale_fill_brewer(palette="Set1", na.value = "grey50") +
  my_theme() +
  labs(
    color = "",
    fill = "",
    x = "Age",
    y = "Density",
    title = "",
    subtitle = "Age distribution of flu cases",
    caption = ""
  )

library(gridExtra)
library(grid)

grid.arrange(p1, p2, ncol = 2)

In the dataset, there are more male than female cases and correspondingly, we see more deaths, recoveries and unknown outcomes in men than in women. This is potentially a problem later on for modeling because the inherent likelihoods for outcome are not directly comparable between the sexes.

Most unknown outcomes were recorded in Zhejiang. Similarly to gender, we don’t have an equal distribution of data points across provinces either.

When we look at the age distribution it is obvious that people who died tended to be slightly older than those who recovered. The density curve of unknown outcomes is more similar to that of death than of recovery, suggesting that among these people there might have been more deaths than recoveries.


And lastly, I want to plot how many days passed between onset, hospitalisation and outcome for each case.

ggplot(data = fluH7N9.china.2013_gather, aes(x = Date, y = age, color = outcome)) +
  geom_point(aes(shape = gender), size = 1.5, alpha = 0.6) +
  geom_path(aes(group = case.ID)) +
  facet_wrap( ~ province, ncol = 2) +
  my_theme() +
  scale_shape_manual(values = c(15, 16, 17)) +
  scale_color_brewer(palette="Set1", na.value = "grey50") +
  scale_fill_brewer(palette="Set1") +
  labs(
    color = "Outcome",
    shape = "Gender",
    x = "Date in 2013",
    y = "Age",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Dataset from 'outbreaks' package (Kucharski et al. 2014)",
    caption = "\nTime from onset of flu to outcome."
  )

This plot shows that there are many missing values in the dates, so it is hard to draw a general conclusion.


Features

In Machine Learning-speak, features are the variables used for model training. Using the right features dramatically influences the accuracy of the model.

Because we don’t have many features, I am keeping age as it is, but I am also generating new features:

  • from the date information I am calculating the days between onset and outcome and between onset and hospitalisation
  • I am converting gender into numeric values with 1 for female and 0 for male
  • similarly, I am converting provinces to binary classifiers (yes == 1, no == 0) for Shanghai, Zhejiang, Jiangsu and other provinces
  • the same binary classification is given for whether a case was hospitalised, and whether they had an early onset or early outcome (earlier than the median date)
# preparing the data frame for modeling
# 
library(dplyr)

dataset <- fluH7N9.china.2013 %>%
  mutate(hospital = as.factor(ifelse(is.na(date.of.hospitalisation), 0, 1)),
         gender_f = as.factor(ifelse(gender == "f", 1, 0)),
         province_Jiangsu = as.factor(ifelse(province == "Jiangsu", 1, 0)),
         province_Shanghai = as.factor(ifelse(province == "Shanghai", 1, 0)),
         province_Zhejiang = as.factor(ifelse(province == "Zhejiang", 1, 0)),
         province_other = as.factor(ifelse(province == "Zhejiang" | province == "Jiangsu" | province == "Shanghai", 0, 1)),
         days_onset_to_outcome = as.numeric(as.character(gsub(" days", "",
                                      as.Date(as.character(date.of.outcome), format = "%Y-%m-%d") - 
                                        as.Date(as.character(date.of.onset), format = "%Y-%m-%d")))),
         days_onset_to_hospital = as.numeric(as.character(gsub(" days", "",
                                      as.Date(as.character(date.of.hospitalisation), format = "%Y-%m-%d") - 
                                        as.Date(as.character(date.of.onset), format = "%Y-%m-%d")))),
         age = as.numeric(as.character(age)),
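         # summary(date)[[3]] extracts the median date: cases before the median are flagged as early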
         early_onset = as.factor(ifelse(date.of.onset < summary(fluH7N9.china.2013$date.of.onset)[[3]], 1, 0)),
         early_outcome = as.factor(ifelse(date.of.outcome < summary(fluH7N9.china.2013$date.of.outcome)[[3]], 1, 0))) %>%
  subset(select = -c(2:4, 6, 8))
rownames(dataset) <- dataset$case.ID
dataset <- dataset[, -1]
head(dataset)
##        outcome age hospital gender_f province_Jiangsu province_Shanghai province_Zhejiang province_other days_onset_to_outcome days_onset_to_hospital early_onset early_outcome
## case_1   Death  87        0        0                0                 1                 0              0                    13                     NA           1             1
## case_2   Death  27        1        0                0                 1                 0              0                    11                      4           1             1
## case_3   Death  35        1        1                0                 0                 0              1                    31                     10           1             1
## case_4    <NA>  45        1        1                1                 0                 0              0                    NA                      8           1          <NA>
## case_5 Recover  48        1        1                1                 0                 0              0                    57                     11           1             0
## case_6   Death  32        1        1                1                 0                 0              0                    36                      7           1             1


Imputing missing values

When looking at the dataset I created for modeling, it is obvious that we have quite a few missing values.

The missing values from the outcome column are what I want to predict but for the rest I would either have to remove the entire row from the data or impute the missing information. I decided to try the latter with the mice package.

# impute missing data

library(mice)

dataset_impute <- mice(dataset[, -1],  print = FALSE)
dataset_impute
## Multiply imputed dataset
## Call:
## mice(data = dataset[, -1], printFlag = FALSE)
## Number of multiple imputations:  5
## Missing cells per column:
##                    age               hospital               gender_f       province_Jiangsu      province_Shanghai      province_Zhejiang         province_other  days_onset_to_outcome days_onset_to_hospital            early_onset          early_outcome 
##                      2                      0                      2                      0                      0                      0                      0                     67                     74                     10                     65 
## Imputation methods:
##                    age               hospital               gender_f       province_Jiangsu      province_Shanghai      province_Zhejiang         province_other  days_onset_to_outcome days_onset_to_hospital            early_onset          early_outcome 
##                  "pmm"                     ""               "logreg"                     ""                     ""                     ""                     ""                  "pmm"                  "pmm"               "logreg"               "logreg" 
## VisitSequence:
##                    age               gender_f  days_onset_to_outcome days_onset_to_hospital            early_onset          early_outcome 
##                      1                      3                      8                      9                     10                     11 
## PredictorMatrix:
##                        age hospital gender_f province_Jiangsu province_Shanghai province_Zhejiang province_other days_onset_to_outcome days_onset_to_hospital early_onset early_outcome
## age                      0        1        1                1                 1                 1              1                     1                      1           1             1
## hospital                 0        0        0                0                 0                 0              0                     0                      0           0             0
## gender_f                 1        1        0                1                 1                 1              1                     1                      1           1             1
## province_Jiangsu         0        0        0                0                 0                 0              0                     0                      0           0             0
## province_Shanghai        0        0        0                0                 0                 0              0                     0                      0           0             0
## province_Zhejiang        0        0        0                0                 0                 0              0                     0                      0           0             0
## province_other           0        0        0                0                 0                 0              0                     0                      0           0             0
## days_onset_to_outcome    1        1        1                1                 1                 1              1                     0                      1           1             1
## days_onset_to_hospital   1        1        1                1                 1                 1              1                     1                      0           1             1
## early_onset              1        1        1                1                 1                 1              1                     1                      1           0             1
## early_outcome            1        1        1                1                 1                 1              1                     1                      1           1             0
## Random generator seed value:  NA
# recombine imputed data frame with the outcome column
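# mice::complete(dataset_impute, 1) extracts the first of the five imputed datasets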

dataset_complete <- merge(dataset[, 1, drop = FALSE], mice::complete(dataset_impute, 1), by = "row.names", all = TRUE)
rownames(dataset_complete) <- dataset_complete$Row.names
dataset_complete <- dataset_complete[, -1]


Test, train and validation datasets

For building the model, I am separating the imputed data frame into training and test data. Test data are the 57 cases with unknown outcome.

summary(dataset$outcome)
##   Death Recover    NA's 
##      32      47      57


The training data will be further divided for validation of the models: 70% of the training data will be kept for model building and the remaining 30% will be used for model testing.

I am using the caret package for modeling.

train_index <- which(is.na(dataset_complete$outcome))
train_data <- dataset_complete[-train_index, ]
test_data  <- dataset_complete[train_index, -1]

library(caret)
set.seed(27)
val_index <- createDataPartition(train_data$outcome, p = 0.7, list=FALSE)
val_train_data <- train_data[val_index, ]
val_test_data  <- train_data[-val_index, ]
val_train_X <- val_train_data[,-1]
val_test_X <- val_test_data[,-1]
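
As a quick sanity check, the sizes of the resulting subsets can be confirmed (56 cases for model building and 23 for validation testing):

# check the resulting split sizes
nrow(val_train_data)
nrow(val_test_data)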


Decision trees

To get an idea about how each feature contributes to the prediction of the outcome, I first built a decision tree based on the training data using rpart and rattle.

library(rpart)
library(rattle)
library(rpart.plot)
library(RColorBrewer)

set.seed(27)
fit <- rpart(outcome ~ .,
                    data = train_data,
                    method = "class",
                    control = rpart.control(xval = 10, minbucket = 2, cp = 0), parms = list(split = "information"))

fancyRpartPlot(fit)

This decision tree shows that cases with an early outcome were most likely to die when they were 68 or older, when they also had an early onset and when they were sick for fewer than 13 days. If a person was not among the first cases and was younger than 52, they had a good chance of recovering, but if they were 82 or older, they were more likely to die from the flu.


Feature Importance

Not all of the features I created will be equally important to the model. The decision tree already gave me an idea of which features might be most important but I also want to estimate feature importance using a Random Forest approach with repeated cross validation.

# prepare training scheme
control <- trainControl(method = "repeatedcv", number = 10, repeats = 10)

# train the model
set.seed(27)
model <- train(outcome ~ ., data = train_data, method = "rf", preProcess = NULL, trControl = control)

# estimate variable importance
importance <- varImp(model, scale=TRUE)
# prepare for plotting
importance_df_1 <- importance$importance
importance_df_1$group <- rownames(importance_df_1)

importance_df_1$group <- mapvalues(importance_df_1$group, 
                                           from = c("age", "hospital2", "gender_f2", "province_Jiangsu2", "province_Shanghai2", "province_Zhejiang2",
                                                    "province_other2", "days_onset_to_outcome", "days_onset_to_hospital", "early_onset2", "early_outcome2" ), 
                                           to = c("Age", "Hospital", "Female", "Jiangsu", "Shanghai", "Zhejiang",
                                                    "Other province", "Days onset to outcome", "Days onset to hospital", "Early onset", "Early outcome" ))
f = importance_df_1[order(importance_df_1$Overall, decreasing = FALSE), "group"]

importance_df_2 <- importance_df_1
importance_df_2$Overall <- 0

importance_df <- rbind(importance_df_1, importance_df_2)

# setting factor levels
importance_df <- within(importance_df, group <- factor(group, levels = f))
importance_df_1 <- within(importance_df_1, group <- factor(group, levels = f))

ggplot() +
  geom_point(data = importance_df_1, aes(x = Overall, y = group, color = group), size = 2) +
  geom_path(data = importance_df, aes(x = Overall, y = group, color = group, group = group), size = 1) +
  scale_color_manual(values = rep(brewer.pal(1, "Set1")[1], 11)) +
  my_theme() +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5)) +
  labs(
    x = "Importance",
    y = "",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Scaled feature importance",
    caption = "\nDetermined with Random Forest and
    repeated cross validation (10 repeats, 10 times)"
  )

This tells me that age is the most important determining factor for predicting disease outcome, followed by days between onset and outcome, early outcome and days between onset and hospitalisation.


Feature Plot

Before I start actually building models, I want to check whether the distribution of feature values is comparable between training, validation and test datasets.

# prepare for plotting
dataset_complete_gather <- dataset_complete %>%
  mutate(set = ifelse(rownames(dataset_complete) %in% rownames(test_data), "Test Data", 
                               ifelse(rownames(dataset_complete) %in% rownames(val_train_data), "Validation Train Data",
                                      ifelse(rownames(dataset_complete) %in% rownames(val_test_data), "Validation Test Data", "NA"))),
         case_ID = rownames(.)) %>%
  gather(group, value, age:early_outcome)

dataset_complete_gather$group <- mapvalues(dataset_complete_gather$group, 
                                           from = c("age", "hospital", "gender_f", "province_Jiangsu", "province_Shanghai", "province_Zhejiang",
                                                    "province_other", "days_onset_to_outcome", "days_onset_to_hospital", "early_onset", "early_outcome" ), 
                                           to = c("Age", "Hospital", "Female", "Jiangsu", "Shanghai", "Zhejiang",
                                                    "Other province", "Days onset to outcome", "Days onset to hospital", "Early onset", "Early outcome" ))
ggplot(data = dataset_complete_gather, aes(x = as.numeric(value), fill = outcome, color = outcome)) +
  geom_density(alpha = 0.2) +
  geom_rug() +
  scale_color_brewer(palette="Set1", na.value = "grey50") +
  scale_fill_brewer(palette="Set1", na.value = "grey50") +
  my_theme() +
  facet_wrap(set ~ group, ncol = 11, scales = "free") +
  labs(
    x = "Value",
    y = "Density",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Features for classifying outcome",
    caption = "\nDensity distribution of all features used for classification of flu outcome."
  )
ggplot(subset(dataset_complete_gather, group == "Age" | group == "Days onset to hospital" | group == "Days onset to outcome"), 
       aes(x=outcome, y=as.numeric(value), fill=set)) + geom_boxplot() +
  my_theme() +
  scale_fill_brewer(palette="Set1", type = "div") +
  facet_wrap( ~ group, ncol = 3, scales = "free") +
  labs(
    fill = "",
    x = "Outcome",
    y = "Value",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Features for classifying outcome",
    caption = "\nBoxplot of the features age, days from onset to hospitalisation and days from onset to outcome."
  )

Luckily, the distributions look reasonably similar between the validation and test data, except for a few outliers.



Comparing Machine Learning algorithms

Before I try to predict the outcome of the unknown cases, I am testing the models’ accuracy on the validation datasets with a couple of algorithms. I have chosen only a few of the more well-known algorithms, but caret implements many more.

I chose not to do any preprocessing because I was worried that the different data distributions of continuous variables (e.g. age) and binary variables (i.e. 0/1 classification of e.g. hospitalisation) would lead to problems.


Random Forest

Random Forest predictions are based on an ensemble of many classification trees whose individual votes are combined into the final classification.

This model classified 14 out of 23 cases correctly.

set.seed(27)
model_rf <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "rf",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_rf
## Random Forest 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.7801905  0.5338408
##    6    0.7650952  0.4926366
##   11    0.7527619  0.4638073
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.
confusionMatrix(predict(model_rf, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       4       4
##    Recover     5      10
##                                           
##                Accuracy : 0.6087          
##                  95% CI : (0.3854, 0.8029)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.5901          
##                                           
##                   Kappa : 0.1619          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.4444          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.5000          
##          Neg Pred Value : 0.6667          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1739          
##    Detection Prevalence : 0.3478          
##       Balanced Accuracy : 0.5794          
##                                           
##        'Positive' Class : Death           
## 


Elastic-Net Regularized Generalized Linear Models

Lasso or elastic net regularization for generalized linear models is based on penalized linear regression and is useful when we have correlated features in our model.

This model classified 13 out of 23 cases correctly.

set.seed(27)
model_glmnet <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "glmnet",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_glmnet
## glmnet 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results across tuning parameters:
## 
##   alpha  lambda        Accuracy   Kappa    
##   0.10   0.0005154913  0.7211905  0.4218952
##   0.10   0.0051549131  0.7189524  0.4189835
##   0.10   0.0515491312  0.7469524  0.4704687
##   0.55   0.0005154913  0.7211905  0.4218952
##   0.55   0.0051549131  0.7236190  0.4280528
##   0.55   0.0515491312  0.7531905  0.4801031
##   1.00   0.0005154913  0.7228571  0.4245618
##   1.00   0.0051549131  0.7278571  0.4361809
##   1.00   0.0515491312  0.7678571  0.5094194
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final values used for the model were alpha = 1 and lambda = 0.05154913.
confusionMatrix(predict(model_glmnet, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       3       4
##    Recover     6      10
##                                           
##                Accuracy : 0.5652          
##                  95% CI : (0.3449, 0.7681)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.7418          
##                                           
##                   Kappa : 0.0496          
##  Mcnemar's Test P-Value : 0.7518          
##                                           
##             Sensitivity : 0.3333          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.4286          
##          Neg Pred Value : 0.6250          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1304          
##    Detection Prevalence : 0.3043          
##       Balanced Accuracy : 0.5238          
##                                           
##        'Positive' Class : Death           
## 
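
If you want to see which features the lasso penalty actually kept, one option is to inspect the coefficients of the final glmnet fit at the selected lambda:

# non-zero coefficients of the final glmnet model at the chosen lambda
coef(model_glmnet$finalModel, s = model_glmnet$bestTune$lambda)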


k-Nearest Neighbors

k-nearest neighbors predicts the class of a sample based on the classes of its k closest neighbors in feature space.

This model classified 14 out of 23 cases correctly.

set.seed(27)
model_kknn <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "kknn",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_kknn
## k-Nearest Neighbors 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results across tuning parameters:
## 
##   kmax  Accuracy   Kappa    
##   5     0.6531905  0.2670442
##   7     0.7120476  0.4031836
##   9     0.7106190  0.4004564
## 
## Tuning parameter 'distance' was held constant at a value of 2
## Tuning parameter 'kernel' was held constant at a value of optimal
## Accuracy was used to select the optimal model using  the largest value.
## The final values used for the model were kmax = 7, distance = 2 and kernel = optimal.
confusionMatrix(predict(model_kknn, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       3       3
##    Recover     6      11
##                                           
##                Accuracy : 0.6087          
##                  95% CI : (0.3854, 0.8029)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.5901          
##                                           
##                   Kappa : 0.1266          
##  Mcnemar's Test P-Value : 0.5050          
##                                           
##             Sensitivity : 0.3333          
##             Specificity : 0.7857          
##          Pos Pred Value : 0.5000          
##          Neg Pred Value : 0.6471          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1304          
##    Detection Prevalence : 0.2609          
##       Balanced Accuracy : 0.5595          
##                                           
##        'Positive' Class : Death           
## 


Penalized Discriminant Analysis

Penalized Discriminant Analysis is a penalized version of linear discriminant analysis and is also useful when we have highly correlated features.

This model classified 14 out of 23 cases correctly.

set.seed(27)
model_pda <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "pda",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pda
## Penalized Discriminant Analysis 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results across tuning parameters:
## 
##   lambda  Accuracy   Kappa    
##   0e+00         NaN        NaN
##   1e-04   0.7255238  0.4373766
##   1e-01   0.7235238  0.4342554
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was lambda = 1e-04.
confusionMatrix(predict(model_pda, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       4       4
##    Recover     5      10
##                                           
##                Accuracy : 0.6087          
##                  95% CI : (0.3854, 0.8029)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.5901          
##                                           
##                   Kappa : 0.1619          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.4444          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.5000          
##          Neg Pred Value : 0.6667          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1739          
##    Detection Prevalence : 0.3478          
##       Balanced Accuracy : 0.5794          
##                                           
##        'Positive' Class : Death           
## 


Stabilized Linear Discriminant Analysis

Stabilized Linear Discriminant Analysis is designed for high-dimensional data and correlated co-variables.

This model classified 15 out of 23 cases correctly.

set.seed(27)
model_slda <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "slda",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_slda
## Stabilized Linear Discriminant Analysis 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.6886667  0.3512234
## 
## 
confusionMatrix(predict(model_slda, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       3       2
##    Recover     6      12
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2069          
##  Mcnemar's Test P-Value : 0.2888          
##                                           
##             Sensitivity : 0.3333          
##             Specificity : 0.8571          
##          Pos Pred Value : 0.6000          
##          Neg Pred Value : 0.6667          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1304          
##    Detection Prevalence : 0.2174          
##       Balanced Accuracy : 0.5952          
##                                           
##        'Positive' Class : Death           
## 


Nearest Shrunken Centroids

Nearest Shrunken Centroids computes a standardized centroid for each class and shrinks each centroid toward the overall centroid for all classes.

This model classified 15 out of 23 cases correctly.

set.seed(27)
model_pam <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "pam",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pam
## Nearest Shrunken Centroids 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results across tuning parameters:
## 
##   threshold  Accuracy   Kappa    
##   0.1200215  0.7283333  0.4418904
##   1.7403111  0.6455714  0.2319674
##   3.3606007  0.5904762  0.0000000
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was threshold = 0.1200215.
confusionMatrix(predict(model_pam, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       4       3
##    Recover     5      11
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2397          
##  Mcnemar's Test P-Value : 0.7237          
##                                           
##             Sensitivity : 0.4444          
##             Specificity : 0.7857          
##          Pos Pred Value : 0.5714          
##          Neg Pred Value : 0.6875          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1739          
##    Detection Prevalence : 0.3043          
##       Balanced Accuracy : 0.6151          
##                                           
##        'Positive' Class : Death           
## 


Single C5.0 Tree

C5.0 is another tree-based modeling algorithm.

This model classified 15 out of 23 cases correctly.

set.seed(27)
model_C5.0Tree <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "C5.0Tree",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_C5.0Tree
## Single C5.0 Tree 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.7103333  0.4047334
## 
## 
confusionMatrix(predict(model_C5.0Tree, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       5       4
##    Recover     4      10
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2698          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.5556          
##             Specificity : 0.7143          
##          Pos Pred Value : 0.5556          
##          Neg Pred Value : 0.7143          
##              Prevalence : 0.3913          
##          Detection Rate : 0.2174          
##    Detection Prevalence : 0.3913          
##       Balanced Accuracy : 0.6349          
##                                           
##        'Positive' Class : Death           
## 


Partial Least Squares

Partial least squares regression combines principal component analysis and multiple regression and is useful for modeling with correlated features.

This model classified 15 out of 23 cases correctly.

set.seed(27)
model_pls <- caret::train(outcome ~ .,
                             data = val_train_data,
                             method = "pls",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pls
## Partial Least Squares 
## 
## 56 samples
## 11 predictors
##  2 classes: 'Death', 'Recover' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ... 
## Resampling results across tuning parameters:
## 
##   ncomp  Accuracy   Kappa    
##   1      0.6198571  0.2112982
##   2      0.6376190  0.2436222
##   3      0.6773810  0.3305780
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was ncomp = 3.
confusionMatrix(predict(model_pls, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Recover
##    Death       3       2
##    Recover     6      12
##                                           
##                Accuracy : 0.6522          
##                  95% CI : (0.4273, 0.8362)
##     No Information Rate : 0.6087          
##     P-Value [Acc > NIR] : 0.4216          
##                                           
##                   Kappa : 0.2069          
##  Mcnemar's Test P-Value : 0.2888          
##                                           
##             Sensitivity : 0.3333          
##             Specificity : 0.8571          
##          Pos Pred Value : 0.6000          
##          Neg Pred Value : 0.6667          
##              Prevalence : 0.3913          
##          Detection Rate : 0.1304          
##    Detection Prevalence : 0.2174          
##       Balanced Accuracy : 0.5952          
##                                           
##        'Positive' Class : Death           
## 


Comparing accuracy of models

All models were similarly accurate.

# Create a list of models
models <- list(rf = model_rf, glmnet = model_glmnet, kknn = model_kknn, pda = model_pda, slda = model_slda,
               pam = model_pam, C5.0Tree = model_C5.0Tree, pls = model_pls)

# Resample the models
resample_results <- resamples(models)

# Generate a summary
summary(resample_results, metric = c("Kappa", "Accuracy"))
## 
## Call:
## summary.resamples(object = resample_results, metric = c("Kappa", "Accuracy"))
## 
## Models: rf, glmnet, kknn, pda, slda, pam, C5.0Tree, pls 
## Number of resamples: 100 
## 
## Kappa 
##             Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
## rf       -0.3636  0.3214 0.5714 0.5338  0.6800    1    0
## glmnet   -0.3636  0.2768 0.5455 0.5094  0.6667    1    0
## kknn     -0.3636  0.1667 0.5035 0.4032  0.6282    1    0
## pda      -0.6667  0.2431 0.5455 0.4374  0.6667    1    0
## slda     -0.5217  0.0000 0.4308 0.3512  0.6667    1    0
## pam      -0.5000  0.2292 0.5455 0.4419  0.6667    1    0
## C5.0Tree -0.4286  0.1667 0.4167 0.4047  0.6154    1    0
## pls      -0.6667  0.0000 0.3333 0.3306  0.6282    1    0
## 
## Accuracy 
##            Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
## rf       0.4000  0.6667 0.8000 0.7802  0.8393    1    0
## glmnet   0.4000  0.6500 0.8000 0.7679  0.8333    1    0
## kknn     0.3333  0.6000 0.7571 0.7120  0.8333    1    0
## pda      0.2000  0.6000 0.8000 0.7255  0.8333    1    0
## slda     0.2000  0.6000 0.6905 0.6887  0.8333    1    0
## pam      0.2000  0.6000 0.8000 0.7283  0.8333    1    0
## C5.0Tree 0.2000  0.6000 0.7143 0.7103  0.8333    1    0
## pls      0.1429  0.5929 0.6667 0.6774  0.8333    1    0
bwplot(resample_results , metric = c("Kappa","Accuracy"))


Combined results of predicting validation test samples

To compare the predictions from all models, I summed up the prediction probabilities for Death and Recovery from all models and calculated the log2 of the ratio of the summed probabilities for Recovery to the summed probabilities for Death. All cases with a log2 ratio bigger than 1.5 were defined as Recover, cases with a log2 ratio below -1.5 were defined as Death, and the remaining cases were defined as uncertain.

results <- data.frame(randomForest = predict(model_rf, newdata = val_test_data[, -1], type="prob"),
                      glmnet = predict(model_glmnet, newdata = val_test_data[, -1], type="prob"),
                      kknn = predict(model_kknn, newdata = val_test_data[, -1], type="prob"),
                      pda = predict(model_pda, newdata = val_test_data[, -1], type="prob"),
                      slda = predict(model_slda, newdata = val_test_data[, -1], type="prob"),
                      pam = predict(model_pam, newdata = val_test_data[, -1], type="prob"),
                      C5.0Tree = predict(model_C5.0Tree, newdata = val_test_data[, -1], type="prob"),
                      pls = predict(model_pls, newdata = val_test_data[, -1], type="prob"))

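# sum the predicted class probabilities for Death and Recover across all eight models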
results$sum_Death <- rowSums(results[, grep("Death", colnames(results))])
results$sum_Recover <- rowSums(results[, grep("Recover", colnames(results))])
results$log2_ratio <- log2(results$sum_Recover/results$sum_Death)
results$true_outcome <- val_test_data$outcome
results$pred_outcome <- ifelse(results$log2_ratio > 1.5, "Recover", ifelse(results$log2_ratio < -1.5, "Death", "uncertain"))
results$prediction <- ifelse(results$pred_outcome == results$true_outcome, "CORRECT", 
                             ifelse(results$pred_outcome == "uncertain", "uncertain", "wrong"))
results[, -c(1:16)]
##          sum_Death sum_Recover  log2_ratio true_outcome pred_outcome prediction
## case_123 2.2236374    5.776363  1.37723972        Death    uncertain  uncertain
## case_127 0.7522674    7.247733  3.26821230      Recover      Recover    CORRECT
## case_128 2.4448101    5.555190  1.18411381      Recover    uncertain  uncertain
## case_14  2.9135620    5.086438  0.80387172      Recover    uncertain  uncertain
## case_19  2.2113377    5.788662  1.38831062        Death    uncertain  uncertain
## case_2   3.6899508    4.310049  0.22410277        Death    uncertain  uncertain
## case_20  4.5621189    3.437881 -0.40818442      Recover    uncertain  uncertain
## case_21  4.9960238    3.003976 -0.73390698      Recover    uncertain  uncertain
## case_34  6.1430233    1.856977 -1.72599312        Death        Death    CORRECT
## case_37  0.8717595    7.128241  3.03154401      Recover      Recover    CORRECT
## case_5   1.7574945    6.242505  1.82860496      Recover      Recover    CORRECT
## case_51  3.8917736    4.108226  0.07808791      Recover    uncertain  uncertain
## case_55  3.1548368    4.845163  0.61897986      Recover    uncertain  uncertain
## case_6   3.5163388    4.483661  0.35060317        Death    uncertain  uncertain
## case_61  5.0803241    2.919676 -0.79911232        Death    uncertain  uncertain
## case_65  1.1282313    6.871769  2.60661863      Recover      Recover    CORRECT
## case_74  5.3438427    2.656157 -1.00853699      Recover    uncertain  uncertain
## case_78  3.2183947    4.781605  0.57115378        Death    uncertain  uncertain
## case_79  1.9521617    6.047838  1.63134700      Recover      Recover    CORRECT
## case_8   3.6106045    4.389395  0.28178184        Death    uncertain  uncertain
## case_87  4.9712879    3.028712 -0.71491522        Death    uncertain  uncertain
## case_91  1.8648837    6.135116  1.71800508      Recover      Recover    CORRECT
## case_94  2.1198087    5.880191  1.47192901      Recover    uncertain  uncertain

All predictions based on all models were either correct or uncertain.


Predicting unknown outcomes

The above models will now be used to predict the outcome of cases with unknown fate.

set.seed(27)
model_rf <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "rf",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_glmnet <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "glmnet",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_kknn <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "kknn",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pda <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "pda",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_slda <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "slda",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pam <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "pam",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_C5.0Tree <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "C5.0Tree",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pls <- caret::train(outcome ~ .,
                             data = train_data,
                             method = "pls",
                             preProcess = NULL,
                             trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))

models <- list(rf = model_rf, glmnet = model_glmnet, kknn = model_kknn, pda = model_pda, slda = model_slda,
               pam = model_pam, C5.0Tree = model_C5.0Tree, pls = model_pls)

# Resample the models
resample_results <- resamples(models)

bwplot(resample_results , metric = c("Kappa","Accuracy"))

Here again, the accuracy is similar in all models.
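
As an aside, the eight individual caret::train() calls above could also be written as a loop over the method names; a minimal sketch that produces an equivalent named list of models (exact resampling results may differ slightly because of the random seeds):

# train all models in one loop over the caret method names
methods <- c("rf", "glmnet", "kknn", "pda", "slda", "pam", "C5.0Tree", "pls")
set.seed(27)
models <- lapply(methods, function(m) {
  caret::train(outcome ~ ., data = train_data, method = m, preProcess = NULL,
               trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
})
names(models) <- methods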


The final results are calculated as described above.

results <- data.frame(randomForest = predict(model_rf, newdata = test_data, type="prob"),
                      glmnet = predict(model_glmnet, newdata = test_data, type="prob"),
                      kknn = predict(model_kknn, newdata = test_data, type="prob"),
                      pda = predict(model_pda, newdata = test_data, type="prob"),
                      slda = predict(model_slda, newdata = test_data, type="prob"),
                      pam = predict(model_pam, newdata = test_data, type="prob"),
                      C5.0Tree = predict(model_C5.0Tree, newdata = test_data, type="prob"),
                      pls = predict(model_pls, newdata = test_data, type="prob"))

results$sum_Death <- rowSums(results[, grep("Death", colnames(results))])
results$sum_Recover <- rowSums(results[, grep("Recover", colnames(results))])
results$log2_ratio <- log2(results$sum_Recover/results$sum_Death)
results$predicted_outcome <- ifelse(results$log2_ratio > 1.5, "Recover", ifelse(results$log2_ratio < -1.5, "Death", "uncertain"))
results[, -c(1:16)]
##          sum_Death sum_Recover  log2_ratio predicted_outcome
## case_100 3.6707808    4.329219  0.23801986         uncertain
## case_101 6.3194220    1.680578 -1.91083509             Death
## case_102 6.4960471    1.503953 -2.11080274             Death
## case_103 2.3877223    5.612278  1.23295134         uncertain
## case_104 4.1254604    3.874540 -0.09053024         uncertain
## case_105 2.1082161    5.891784  1.48268174         uncertain
## case_108 6.5941436    1.405856 -2.22973606             Death
## case_109 2.2780779    5.721922  1.32868278         uncertain
## case_110 1.7853361    6.214664  1.79948072           Recover
## case_112 4.5102439    3.489756 -0.37007925         uncertain
## case_113 4.7047554    3.295245 -0.51373420         uncertain
## case_114 1.6397105    6.360290  1.95565132           Recover
## case_115 1.2351519    6.764848  2.45336902           Recover
## case_118 1.4675571    6.532443  2.15420601           Recover
## case_120 1.9322149    6.067785  1.65091447           Recover
## case_122 1.1165786    6.883421  2.62404109           Recover
## case_126 5.4933927    2.506607 -1.13196145         uncertain
## case_130 4.3035732    3.696427 -0.21940367         uncertain
## case_132 5.0166184    2.983382 -0.74976671         uncertain
## case_136 2.6824566    5.317543  0.98720505         uncertain
## case_15  5.8545370    2.145463 -1.44826607         uncertain
## case_16  6.5063794    1.493621 -2.12304122             Death
## case_22  4.2929861    3.707014 -0.21172398         uncertain
## case_28  6.6129447    1.387055 -2.25326754             Death
## case_31  6.2625800    1.737420 -1.84981057             Death
## case_32  6.0986316    1.901368 -1.68144746             Death
## case_38  1.7581540    6.241846  1.82791134           Recover
## case_39  2.4765188    5.523481  1.15726428         uncertain
## case_4   5.1309910    2.869009 -0.83868503         uncertain
## case_40  6.6069297    1.393070 -2.24571191             Death
## case_41  5.5499144    2.450086 -1.17963336         uncertain
## case_42  2.0709160    5.929084  1.51754019           Recover
## case_47  6.1988812    1.801119 -1.78311450             Death
## case_48  5.6500772    2.349923 -1.26565724         uncertain
## case_52  5.9096543    2.090346 -1.49933214         uncertain
## case_54  2.6763066    5.323693  0.99218409         uncertain
## case_56  2.6826414    5.317359  0.98705557         uncertain
## case_62  6.0722552    1.927745 -1.65531833             Death
## case_63  3.9541381    4.045862  0.03308379         uncertain
## case_66  5.8320010    2.167999 -1.42762690         uncertain
## case_67  2.1732059    5.826794  1.42287745         uncertain
## case_68  4.3464174    3.653583 -0.25051491         uncertain
## case_69  1.8902845    6.109715  1.69250181           Recover
## case_70  4.5986242    3.401376 -0.43508393         uncertain
## case_71  2.7966938    5.203306  0.89570626         uncertain
## case_80  2.7270986    5.272901  0.95123017         uncertain
## case_84  1.9485253    6.051475  1.63490411           Recover
## case_85  3.0076953    4.992305  0.73104755         uncertain
## case_86  2.0417802    5.958220  1.54505376           Recover
## case_88  0.7285419    7.271458  3.31916083           Recover
## case_9   2.2188620    5.781138  1.38153353         uncertain
## case_90  1.3973262    6.602674  2.24038155           Recover
## case_92  5.0994993    2.900501 -0.81405366         uncertain
## case_93  4.3979048    3.602095 -0.28798006         uncertain
## case_95  3.5561792    4.443821  0.32147260         uncertain
## case_96  1.3359082    6.664092  2.31858744           Recover
## case_99  5.0776686    2.922331 -0.79704644         uncertain

From 57 cases, 14 were defined as Recover, 10 as Death and 33 as uncertain.
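
These counts can be obtained directly from the results table:

# tally the predicted outcomes
table(results$predicted_outcome)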

results_combined <- merge(results[, -c(1:16)], fluH7N9.china.2013[which(fluH7N9.china.2013$case.ID %in% rownames(results)), ], 
                          by.x = "row.names", by.y = "case.ID")
results_combined <- results_combined[, -c(2, 3, 8, 9)]
results_combined_gather <- results_combined %>%
  gather(group_dates, date, date.of.onset:date.of.hospitalisation)

results_combined_gather$group_dates <- factor(results_combined_gather$group_dates, levels = c("date.of.onset", "date.of.hospitalisation"))

results_combined_gather$group_dates <- mapvalues(results_combined_gather$group_dates, from = c("date.of.onset", "date.of.hospitalisation"), 
                                             to = c("Date of onset", "Date of hospitalisation"))

results_combined_gather$gender <- mapvalues(results_combined_gather$gender, from = c("f", "m"), 
                                             to = c("Female", "Male"))
levels(results_combined_gather$gender) <- c(levels(results_combined_gather$gender), "unknown")
results_combined_gather$gender[is.na(results_combined_gather$gender)] <- "unknown"
results_combined_gather$age <- as.numeric(as.character(results_combined_gather$age))
ggplot(data = results_combined_gather, aes(x = date, y = log2_ratio, color = predicted_outcome)) +
  geom_jitter(aes(size = age), alpha = 0.3) +
  geom_rug() +
  facet_grid(gender ~ group_dates) +
  labs(
    color = "Predicted outcome",
    size = "Age",
    x = "Date in 2013",
    y = "log2 ratio of prediction Recover vs Death",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Predicted outcome",
    caption = ""
  ) +
  my_theme() +
  scale_color_brewer(palette="Set1") +
  scale_fill_brewer(palette="Set1")

The comparison of date of onset, date of hospitalisation, gender and age with predicted outcome shows that predicted deaths were associated with older age than predicted recoveries. Date of onset does not show an obvious bias in either direction.


Conclusions

This dataset posed a couple of difficulties to begin with, like unequal distribution of data points across variables and missing data. This makes the modeling inherently prone to flaws. However, real life data isn’t perfect either, so I went ahead and tested the modeling success anyway.

By flagging classifications with a low prediction probability as uncertain, the validation data could be classified accurately. However, these few cases don’t provide enough information to reliably predict the outcome of new cases. More cases, more information (i.e. more features) and fewer missing data would improve the modeling outcome.

Also, this example is only applicable to this specific flu dataset. In order to draw more general conclusions about flu outcome, other cases and additional information, for example on medical parameters like preexisting medical conditions, disease parameters, demographic information, etc., would be necessary.

All in all, this dataset served as a nice example of the possibilities (and pitfalls) of machine learning applications and showcases a basic workflow for building prediction models with R.


For a comparison of feature selection methods see here.

If you see any mistakes or have tips and tricks for improvement, please don’t hesitate to let me know! Thanks. :-)


If you are interested in more machine learning posts, check out the category listing for machine_learning.


sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
## 
## locale:
## [1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252    LC_MONETARY=English_United States.1252 LC_NUMERIC=C                           LC_TIME=English_United States.1252    
## 
## attached base packages:
## [1] grid      stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] pls_2.5-0           C50_0.1.0-24        pamr_1.55           survival_2.40-1     cluster_2.0.5       ipred_0.9-5         mda_0.4-9           class_7.3-14        kknn_1.3.1          glmnet_2.0-5        foreach_1.4.3       Matrix_1.2-7.1      randomForest_4.6-12 RColorBrewer_1.1-2  rpart.plot_2.1.0    rattle_4.1.0        rpart_4.1-10        caret_6.0-73        lattice_0.20-34     mice_2.25           Rcpp_0.12.7         dplyr_0.5.0         gridExtra_2.2.1     ggplot2_2.2.0       plyr_1.8.4          tidyr_0.6.0         outbreaks_1.0.0    
## 
## loaded via a namespace (and not attached):
##  [1] RGtk2_2.20.31      assertthat_0.1     digest_0.6.10      R6_2.2.0           MatrixModels_0.4-1 stats4_3.3.2       evaluate_0.10      e1071_1.6-7        lazyeval_0.2.0     minqa_1.2.4        SparseM_1.74       car_2.1-3          nloptr_1.0.4       partykit_1.1-1     rmarkdown_1.1      labeling_0.3       splines_3.3.2      lme4_1.1-12        stringr_1.1.0      igraph_1.0.1       munsell_0.4.3      compiler_3.3.2     mgcv_1.8-16        htmltools_0.3.5    nnet_7.3-12        tibble_1.2         prodlim_1.5.7      codetools_0.2-15   MASS_7.3-45        ModelMetrics_1.1.0 nlme_3.1-128       gtable_0.2.0       DBI_0.5-1          magrittr_1.5       scales_0.4.1       stringi_1.1.2      reshape2_1.4.2     Formula_1.2-1      lava_1.4.5         iterators_1.0.8    tools_3.3.2        parallel_3.3.2     pbkrtest_0.4-6     yaml_2.1.14        colorspace_1.3-0   knitr_1.15         quantreg_5.29

Analysing the Gilmore Girls' coffee addiction with R

2016-11-20 00:00:00 +0000

Last week’s post showed how to create a Gilmore Girls character network.

In this week’s short post, I want to explore the Gilmore Girls’ famous coffee addiction by analysing the same episode transcripts that were also used last week.

I am also showcasing how to use the recently updated ggplot2 2.2.0.

The transcripts were prepared as described in last week’s post. Full R code can be found below.


How often do they talk about coffee?

The first question I want to explore is how often Lorelai, Rory, Luke, Sookie, Lane and Paris talk about coffee. For this, I subsetted these characters’ lines to those mentioning coffee (i.e. lines that include the word “coffee” at least once). I then created a plot showing how often they talk about coffee per episode and season.
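As a minimal sketch of this subsetting step for a single character (mirroring the full code at the end of this post and assuming the transcripts_2 data frame with character and dialogue columns from last week’s preparation):

# Lorelai's lines, and those of her lines that mention coffee at least once
lines_LORELAI <- transcripts_2[grepl("LORELAI", transcripts_2$character), ]
coffee_LORELAI <- lines_LORELAI[grepl("coffee", lines_LORELAI$dialogue), ]
nrow(coffee_LORELAI)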



The differences in coffee mentions between the characters are of course somewhat biased by differences in each character’s total number of lines in the respective episodes, as well as by the length of the episodes.

This led me to wonder what proportion of each episode was occupied by which character.
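The per-episode proportions are calculated further down with base R’s sweep(), which applies a function along one margin of a table; as a toy illustration with made-up line counts:

# hypothetical line counts per episode (rows) and character (columns)
m <- data.frame(LORELAI = c(3, 2), RORY = c(2, 2), OTHER = c(5, 6))
# divide each row by its row sum and convert to percent
sweep(m, 1, rowSums(m), FUN = "/") * 100
# row 1: 30 20 50, row 2: 20 20 60 - each row now sums to 100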





R code

# set names
names <- c("LORELAI", "RORY", "LUKE", "SOOKIE", "LANE", "PARIS")

# extract lines for each character in names
for (name in names){
  assign(paste("lines", name, sep = "_"), transcripts_2[grep(paste0(name), transcripts_2$character), ])
}

# extract lines not from characters in names
lines_OTHER <- transcripts_2[!grepl("LORELAI|RORY|LUKE|SOOKIE|LANE|PARIS", transcripts_2$character), ]

# update names to include "other"
names <- c(names, "OTHER")
# which lines mention coffee.
for (name in names){
  df <- get(paste("lines", name, sep = "_"))
  assign(paste("coffee", name, sep = "_"), df[grep("coffee", df$dialogue), ])
}
# calculate number of lines in total and with coffee mentions per episode
library(dplyr)

for (name in names){
  lines_p_ep <- get(paste("lines", name, sep = "_"))
  lines_p_ep <- as.data.frame(table(lines_p_ep$episode))
  
  df <- get(paste("coffee", name, sep = "_"))
  df_2 <- as.data.frame(table(df$episode))
  
  # calculate coffee lines per episode
  coffee_lines_p_ep <- full_join(lines_p_ep, df_2, by = "Var1")
  coffee_lines_p_ep$ratio_p_ep <- coffee_lines_p_ep$Freq.y/coffee_lines_p_ep$Freq.x
  
  coffee_lines_p_ep$Season <- gsub("_.*", "", coffee_lines_p_ep$Var1)
  coffee_lines_p_ep$character <- paste(name)
  
  assign(paste("coffee_lines_p_ep", name, sep = "_"), coffee_lines_p_ep)
}

coffee_df <- rbind(coffee_lines_p_ep_LORELAI, coffee_lines_p_ep_RORY, coffee_lines_p_ep_LUKE, coffee_lines_p_ep_SOOKIE, 
                   coffee_lines_p_ep_LANE, coffee_lines_p_ep_PARIS)
library(ggplot2)

my_theme <- function(base_size = 12, base_family = "sans"){
  theme_grey(base_size = base_size, base_family = base_family) +
  theme(
    axis.text = element_text(size = 12),
    axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 0.5),
    axis.title = element_text(size = 14),
    panel.grid.major = element_line(color = "grey"),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(fill = "aliceblue"),
    strip.background = element_rect(fill = "lightgrey", color = "grey", size = 1),
    strip.text = element_text(face = "bold", size = 12, color = "black"),
    legend.position = "bottom",
    legend.justification = "top", 
    legend.box = "horizontal",
    legend.box.background = element_rect(colour = "grey50"),
    legend.background = element_blank(),
    panel.border = element_rect(color = "grey", fill = NA, size = 0.5)
  )
}
library(tidyr)

coffee_df_gather <- coffee_df %>% 
  gather(group, value, Freq.x:ratio_p_ep)

coffee_df_gather[which(coffee_df_gather$group == "Freq.x"), "group"] <- "Lines per episode"
coffee_df_gather[which(coffee_df_gather$group == "Freq.y"), "group"] <- "Coffee lines per episode"
coffee_df_gather[which(coffee_df_gather$group == "ratio_p_ep"), "group"] <- "Ratio of coffee lines per episode"

coffee_df_gather$character <- factor(coffee_df_gather$character, levels = names)
coffee_df_gather$group <- factor(coffee_df_gather$group, levels = c("Lines per episode", "Coffee lines per episode", "Ratio of coffee lines per episode"))

ggplot(data = coffee_df_gather, aes(x = character, y = value, color = Season, fill = Season)) +
  geom_boxplot(alpha = 0.5) +
  labs(
    x = "",
    y = "Number or ratio of lines",
    title = "Lorelai & Luke are the coffee queen and king",
    subtitle = "Coffee mentions per episode and season",
    caption = "\nThese boxplots show the number of total lines per episode and season, lines with coffee mentions per epsiode and season,
     as well as the ratio between the two for the main characters Lorelai, Rory, Luke, Sookie, Lane and Paris. Lorelai consistently had
     the most lines per episode, followed closely by Rory and Luke. Sookie, Lane and Paris had roughly the same numbers of lines per episode,
     although their variance is higher. The same trend is reflected by the number of lines with coffee mentions, except that Rory and Luke seem
     to talk similarly often about coffee. Interestingly, Lorelai's (verbal) coffee obsession seems to have decreased slightly over the seasons.
     The ratio of coffee lines divided by the total number of lines for each episode reflects that even though Lorelai, Rory and Luke talk a lot about coffee,
     they also talk a lot in general, so their ratio of coffee mentions is only slightly higher than Sookie's, Lane's and Paris'. 
     While the latter's mean ratio is lower. they have a much higher variance with a few episodes where they seem to have talked relatively much about coffee.
     During the first seasons, Luke is the character with the highest ratio of coffee vs general talk but this is decreasing over time and becomes more
     similar to Lorelai and Rory's coffee talk ratio."
  ) +
  my_theme() +
  facet_wrap(~ group, ncol = 3, scales = "free") +
  guides(fill = guide_legend(nrow = 1, byrow = TRUE)) +
  scale_fill_brewer(palette="Dark2") +
  scale_color_brewer(palette="Dark2")
# calculate total number of lines per episode for each character
for (name in names){
  lines_p_ep <- get(paste("lines", name, sep = "_"))
  lines_p_ep <- as.data.frame(table(lines_p_ep$episode))

  lines_p_ep$character <- paste(name)
  
  assign(paste("lines_p_ep", name, sep = "_"), lines_p_ep)
}

# combine and convert to wide format to make calculating proportions easier
lines_df_wide <- spread(rbind(lines_p_ep_LORELAI, lines_p_ep_RORY, lines_p_ep_LUKE, lines_p_ep_SOOKIE, 
                              lines_p_ep_LANE, lines_p_ep_PARIS, lines_p_ep_OTHER), Var1, Freq)
rownames(lines_df_wide) <- lines_df_wide[, 1]
lines_df_wide <- as.data.frame(t(lines_df_wide[, -1]))

# calculate proportions and percentage
prop_lines_df_wide <- sweep(lines_df_wide, 1, rowSums(lines_df_wide), FUN = "/")
percent_lines_df_wide <- prop_lines_df_wide * 100
percent_lines_df_wide$Episode <- rownames(percent_lines_df_wide)
percent_lines_df_wide$Season <- gsub("_.*", "", percent_lines_df_wide$Episode)

# gather for plotting
percent_lines_df_gather <- percent_lines_df_wide %>% 
  gather(Character, Percent, LANE:SOOKIE)

percent_lines_df_gather$Character <- factor(percent_lines_df_gather$Character, levels = names)
percent_lines_df_gather <- merge(percent_lines_df_gather, subset(lines_LORELAI, select = c(episode, episode_running_nr)), by.x = "Episode", by.y = "episode")
percent_lines_df_gather$Episode <- factor(percent_lines_df_gather$Episode, levels = percent_lines_df_gather[order(percent_lines_df_gather$episode_running_nr, decreasing = TRUE), "Episode"])
percent_lines_df_gather <- percent_lines_df_gather[!duplicated(percent_lines_df_gather), ]
ggplot(data = percent_lines_df_gather, aes(x = Character, y = Percent, fill = Season)) +
  geom_boxplot(alpha = 0.5) +
  labs(
    x = "",
    y = "Percent",
    title = "Who's the biggest talker? (a)",
    subtitle = "Percentage of lines per episode spoken by main characters",
    caption = "\nThe boxplot shows the percentages of lines spoken by the main characters per episode and season.
    In the first three seasons the majority of lines are clearly spoken by Lorelai and Rory, followed with some distance by Luke.
    During the later seasons, the gap between Lorelai, Rory and the rest is not as pronounced any more."
  ) +
  my_theme() +
  facet_wrap(~ Season, ncol = 4, scales = "free") +
  guides(fill = guide_legend(nrow = 1, byrow = TRUE)) +
  scale_fill_brewer(palette = "Dark2") +
  scale_color_brewer(palette = "Dark2")
# calculate total number of lines per season for each character
for (name in names){
  lines_p_season <- get(paste("lines", name, sep = "_"))
  lines_p_season <- as.data.frame(table(lines_p_season$season))

  lines_p_season$character <- paste(name)
  
  assign(paste("lines_p_season", name, sep = "_"), lines_p_season)
}

# combine and convert to wide format to make calculating proportions easier
lines_df_wide <- spread(rbind(lines_p_season_LORELAI, lines_p_season_RORY, lines_p_season_LUKE, lines_p_season_SOOKIE, 
                              lines_p_season_LANE, lines_p_season_PARIS, lines_p_season_OTHER), Var1, Freq)
rownames(lines_df_wide) <- lines_df_wide[, 1]
lines_df_wide <- as.data.frame(t(lines_df_wide[, -1]))

# calculate proportions and percentage
prop_lines_df_wide <- sweep(lines_df_wide, 1, rowSums(lines_df_wide), FUN = "/")
percent_lines_df_wide <- prop_lines_df_wide * 100
percent_lines_df_wide$Season <- rownames(percent_lines_df_wide)

# gather for plotting
percent_lines_df_gather <- percent_lines_df_wide %>%
  gather(Character, Percent, LANE:SOOKIE)

percent_lines_df_gather$Character <- factor(percent_lines_df_gather$Character, levels = names)
# plot
bp <- ggplot(percent_lines_df_gather, aes(x = "", y = Percent, fill = Character)) + 
  geom_bar(width = 1, stat = "identity") + theme_minimal() +
  scale_fill_brewer(palette = "Spectral")

pie <- bp + coord_polar("y", start = 0) +
  theme(
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    panel.border = element_blank(),
    panel.grid = element_blank(),
    axis.ticks = element_blank(),
    plot.title = element_text(size = 14, face = "bold"),
    legend.title = element_blank(),
    legend.position = "bottom",
    legend.text = element_text(size = 8),
    legend.justification = "top",
    legend.box = "horizontal",
    legend.box.background = element_rect(colour = "grey50"),
    legend.background = element_blank()
  ) + guides(fill = guide_legend(nrow = 1, byrow = TRUE))

pie + facet_wrap( ~ Season, ncol = 4) +
  labs(
    title = "Who's the biggest talker? (b)",
    subtitle = "Percent of lines per season spoken by main characters")



sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-apple-darwin13.4.0 (64-bit)
## Running under: macOS Sierra 10.12
## 
## locale:
## [1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/de_DE.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] ggplot2_2.2.0         dplyr_0.5.0           splitstackshape_1.4.2
## [4] data.table_1.9.6      tidyr_0.6.0          
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.8        knitr_1.15         magrittr_1.5      
##  [4] munsell_0.4.3      colorspace_1.3-1   R6_2.2.0          
##  [7] stringr_1.1.0      plyr_1.8.4         tools_3.3.2       
## [10] grid_3.3.2         gtable_0.2.0       DBI_0.5-1         
## [13] htmltools_0.3.5    lazyeval_0.2.0     yaml_2.1.14       
## [16] assertthat_0.1     digest_0.6.10      tibble_1.2        
## [19] RColorBrewer_1.1-2 evaluate_0.10      rmarkdown_1.1     
## [22] labeling_0.3       stringi_1.1.2      scales_0.4.1      
## [25] chron_2.3-47