
##################################### twostage.snmm.R ###################################################
#
#   daniel almirall (dalmiral [arroba] umich [punto] edu)
#   posted december, 13, 2010 at the university of michigan
#
#########################################################################################################
#
#   this R code accompanies the following journal article:
#
#   Almirall, D., Tenhave, T., and Murphy, S.A. (2009).
#   "Structural Nested Mean Models for Assessing Time-varying Effect Moderation"
#   Biometrics, Volume 66, Issue 1, pp131-139.
#
#   overview: This R code implements a simple version of the two-stage regression with residuals estimator 
#       (Almirall, et al. (2009)) of Robins' Structural Nested Mean Model (SNMM). It returns an estimate of the 
#       asymptotic variance-covariance matrix (and therefore standard errors) for the estimated SNMM parameters 
#       (stage-2), taking into account sampling error in the estimation of the stage-1 parameters. The SNMM, 
#       as described in Almirall, et al. (2009), is useful when scientists are interested in understanding 
#       causal effect moderation in settings in which the treatment (or primary exposure) of interest is 
#       time-varying and so are covariates hypothesized to moderate its effects. 
#
#   disclaimer: This R code is provided with no guarantees. It is an alpha version, written originally for 
#       research purposes only. Therefore, it may not be user-friendly and may be difficult to use. If you 
#       find bugs/errors in this code, please email Danny (dalmiral [arroba] umich [punto] edu): I will fix the 
#       code and then acknowledge you on this website. Coming soon: We are working on a user-friendly, 
#       distributable, implementation of the 2-stage regression estimator that we will post on CRAN.
#
#   usage: twostage.snmm( modgivpast, coefnuisfunc, intermeff, response, dat, 
#           verbose=T, full.return=T, nmods.vec )
#
#   arguments: values inside the less-than/greater-than brackets <...> are to be supplied by the user. note the
#               peculiar use of the tilde ~ throughout. the arguments to the function are:
#
#       modgivpast  =   a list of lists specifying formulas for the linear models for the "mod"erators 
#                       "giv"en the "past". the residuals from these models are used in the nuisance functions 
#                       (error terms) in the snmm. moderators at baseline must be listed first, followed by 
#                       moderators at time 2, and so on. the user must specify if moderator is continuous (continuous=1)
#                       or binary (continuous=0). if continuous=0, the stage 1 regression is a logistic regression; 
#                       otherwise, ordinary least squares regression is used in stage 1.
#
#                       list( 
#                           list( <moderator1> ~ <formula1>, <continuous1> ), 
#                           list( <moderator2> ~ <formula2>, <continuous2> ), 
#                           ...
#                       )
#
#       coefnuisfunc=   a list of models for the "coef"ficient of "nuis"ance "func"tions.
#                       each model in modgivpast is used to create a residual. the corresponding residual 
#                       is then multiplied by the terms specified here to make up the corresponding model for the
#                       error term (nuisance function). the order of the terms matters here; they should be in 
#                       the order corresponding to the moderators specified in modgivpast. this is how the 
#                       function will know which residual to multiply with which set of terms.
#
#                       list(
#                           ~<formula1>,
#                           ~<formula2>,
#                           ...
#                       )
#
#       intermeff   =   a list of lists specifying the primary treatment (or exposure) of interest and
#                       the models for the "interm"ediate causal "eff"ects.
#                       there should be as many inner lists as there are time-varying treatments.
#                       that is, there should be as many lists as there are intermediate causal effect functions
#                       
#                       list(
#                           list(~<treatment1>, ~<formula1>),
#                           list(~<treatment2>, ~<formula2>),
#                           ...
#                       )
#
#       response    =   name of the end of study outcome
#
#                       ~<outcome>
#
#       dat         =   name of the data set containing the outcome, the time-varying moderators, and the 
#                       treatments
#
#       verbose     =   set to T (TRUE) if you want to print a lot of output
#
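#       full.return =   set to T (TRUE) to also return the fitted stage-1 and stage-2 model objects; otherwise
#                       only the table of estimates, the names of the causal parameters, and the 
#                       variance-covariance matrix are returned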
#
#       nmods.vec   =   a vector (of dimension equal to the number of time points in the snmm) specifying the 
#                       number of moderators per time point. the length of the vector will tell the function 
#                       how many time points there are. the length of intermeff should be equal to the length 
#                       of nmods.vec. the sum of nmods.vec should equal the number of lists inside
#                       modgivpast and coefnuisfunc.
#
#########################################################################################################




## load some libraries for this to work
library(Matrix)
library(sandwich)

## 2-stage Least Squares Estimation of the SNMM

## twostage.snmm: two-stage regression-with-residuals estimator of the SNMM.
##
## stage 1 fits one model per moderator (lm if flagged 1, logistic glm if flagged 0) and
## keeps the residuals; stage 2 regresses the response on the intermediate causal effect
## terms plus the residual-based nuisance terms. the variance-covariance matrix of the
## stage-2 coefficients is built from the stage-1 and stage-2 estimating functions.
##
## arguments:
##   modgivpast    list of list( <moderator> ~ <formula>, <1 = continuous | 0 = binary> )
##   coefnuisfunc  list of one-sided formulas, one per entry of modgivpast (same order)
##   intermeff     list of list( ~<treatment>, ~<formula> ), one per time point
##   response      one-sided formula naming the outcome, ~<outcome>
##   dat           data frame holding the outcome, moderators and treatments
##   verbose       if TRUE, print a formatted summary of the fit
##   full.return   if TRUE, also return the stage-1 and stage-2 fitted objects
##   nmods.vec     number of moderators per time point; must sum to length(modgivpast)
##
## value: a list with
##   ans         matrix of estimates, asymptotic standard errors, z values and p-values
##   beta.names  names of the stage-2 coefficients that belong to the intercept and the
##               intermediate causal effects
##   vcov.theta  estimated variance-covariance matrix of all stage-2 coefficients
##   stage1, stage2  (only if full.return) the fitted stage-1 list and stage-2 lm object
##
## requires bdiag() from Matrix, estfun() from sandwich, and my.update.formula() from this file.
twostage.snmm <- function(modgivpast, coefnuisfunc, intermeff, response, dat, verbose=TRUE, full.return=TRUE, nmods.vec)
{

    n <- nrow(dat)
    nt <- length(intermeff)
    nmods <- length(modgivpast)
    if ( !(nmods==sum(nmods.vec))  ) { stop("nmods should equal sum(nmods.vec)")}

    ## names given to the stage-1 residuals: "e." followed by the moderator name
    e.names <- unlist(lapply(modgivpast, FUN=function(x) paste0("e.", as.character(x[[1]][[2]])) ))

    ## the residuals are appended to dat below; a column that already carries one of these
    ## names comes first in dat and so would be picked up by the stage-2 regression instead.
    ## only real name collisions are flagged (not every column that merely contains "e.")
    if ( any(e.names %in% names(dat)) ) { warning("'e.' variables were found in dat--CAREFUL: these variables will be used in the second regression INSTEAD of the estimated ones")}

    ########################################
    ############# Estimates ################
    ########################################

    ## get stage1 lm/glm objects; each element is list(<fit>, <continuous flag>)
    stg1.objs <- lapply(modgivpast, FUN=function(x){
        if (x[[2]]==1) {
            lm.obj <- lm(x[[1]], data=dat )
            return( list(lm.obj,1) )
        } else if (x[[2]]==0) {
            glm.obj <- glm(x[[1]], data=dat, family="binomial")
            return( list(glm.obj,0) )
        } else stop("Second argument in modgivpast must be either 1 (continuous moderator) or 0")
        })

    ## get stage1 residuals from the stage1 objects
    ## for the logistic fits this is observed response minus fitted probability
    stg1.res <- lapply(stg1.objs, FUN=function(x){
        if (x[[2]]==1) {
            res <- x[[1]]$residuals
            return(res)
        } else if (x[[2]]==0) {
            res <- x[[1]]$model[[1]] - x[[1]]$fitted.values
            return(res)
        } else stop("Second argument in modgivpast must be either 1 (continuous moderator) or 0")
        })


    ## get stage1 design matrices only
    ## for binary moderators, also keep the expit*(1-expit) term
    ## each element is list(<design matrix>, <continuous flag>, <expit term or NULL>)
    zs <- lapply(stg1.objs, FUN=function(x) {
        mf <- model.frame(x[[1]])
        mt <- attr(mf,"terms")
        z <- model.matrix(mt,mf)    # design matrix from the stg1 objects
        if (x[[2]]==1){
            return( list(z,x[[2]],NULL) )
        } else if (x[[2]]==0) {
            .p <- x[[1]]$fitted.values
            expit.term <- .p * ( 1 - .p )
            return( list(z,x[[2]],expit.term) )
        } else stop("Second argument in modgivpast must be either 1 (continuous moderator) or 0")
        })


    ## bind the residuals into the dataframe and give them appropriate names
    ## these residuals make up the nuisance functions
    names(stg1.res) <- e.names
    stg1.res <- as.data.frame(stg1.res)
    dat <- as.data.frame(cbind(dat,stg1.res))

    ## construct intermediate effect functions: treatment main effect plus
    ## treatment-by-moderator interactions
    intermeff.updated <- lapply(intermeff, FUN=function(x) {my.update.formula( x[[2]] , eval(substitute( ~ ddd + . : ddd, list(ddd = x[[1]][[2]]))) )} )

    ## construct coefficient of nuisance functions: residual main effect plus
    ## residual-by-term interactions
    coefnuisfunc.updated <- vector("list",nmods)
    for (i in seq_len(nmods)){
        tmp <- as.formula( paste0("~",e.names[i]) )
        ## intercept-only model (~1): same test as comparing with ~1, done on the deparsed
        ## text so that it does not rely on `==` between language objects
        if (identical(deparse(coefnuisfunc[[i]]), "~1")) {
            coefnuisfunc.updated[[i]] <- update( coefnuisfunc[[i]] , eval(substitute( ~ ddd, list(ddd = tmp[[2]]))) )
        }
        else {
            coefnuisfunc.updated[[i]] <- update( coefnuisfunc[[i]] , eval(substitute( ~ ddd + . : ddd, list(ddd = tmp[[2]]))) )
        }
    }

    ## build final snmm formula; begin picking up the intermediate effect functions
    snmm.formula <- eval(substitute( ddd ~ 1, list(ddd = response[[2]]) ))
    for(i in seq_len(nt)){
        snmm.formula <- my.update.formula( snmm.formula, eval(substitute( ~ . + ddd, list(ddd = intermeff.updated[[i]][[2]]))) )
    }

    ## pick up the (non-augmented) x matrix; the one corresponding only to the intermediate effects
    ## also obtain names for the intermediate effect functions; useful to isolate beta later
    mf <- model.frame( terms(snmm.formula, keep.order=TRUE) , data=dat)
    mt <- attr(mf, "terms")
    x <- model.matrix(mt,mf)
    x.names <- attr(x,"dimnames")[[2]]
    np.x <- dim(x)[2]   # number of intermediate effect parameters

    ## continue building final snmm formula; pick up the coefficient of nuisance functions
    for(i in seq_len(nmods)){
        snmm.formula <- my.update.formula( snmm.formula, eval(substitute( ~ . + ddd, list(ddd = coefnuisfunc.updated[[i]][[2]]))) )
    }

    ## get stage2/final model object
    stg2.obj <- lm( terms(snmm.formula, keep.order=TRUE) , data=dat)

    ## obtain estimates
    theta.hat <- stg2.obj$coefficients  ## vector corresponding to entire second stage estimation
    beta.hat <- theta.hat[x.names]  ## vector corresponding only to the intermediate effect functions

    ## list corresponding to the coefficient of nuisance functions.
    ## a coefficient belongs to residual e.<m> when e.<m> is one of the ':'-separated
    ## components of its name. matching whole components (rather than grep on the name)
    ## keeps e.m1 from also picking up e.m10, and keeps the '.' from acting as a wildcard
    theta.parts <- strsplit(names(theta.hat), ":", fixed=TRUE)
    eta.hat  <- lapply(e.names, FUN=function(x){
            theta.hat[ vapply(theta.parts, FUN=function(p) x %in% p, FUN.VALUE=logical(1)) ]
        })


    ##############################################
    ############# Standard Errors ################
    ##############################################

    ## get stage1 D.gamma.inv for standard errors (block diagonal, one block per stage-1 model)
    D.gamma.inv <- bdiag(lapply(zs, FUN=function(x){
        z <- x[[1]]
        n <- nrow(z)
        if (x[[2]]==1){
            ztz <- (t(z) %*% z) / n  ## take the average
            ztz.inv <- solve(ztz)  ## invert matrix
            return( ztz.inv )
        } else if (x[[2]]==0) {
            ## t(z) %*% diag(w) %*% z, computed by scaling the rows of z by w so that
            ## no n-by-n diagonal matrix has to be formed
            cztz <- crossprod(x[[3]] * z, z) / n
            cztz.inv <- solve(cztz)
            return( cztz.inv )
        }
        }))


    ## get coefficient of nuisance function design matrix for the standard errors
    qs <- lapply(coefnuisfunc, FUN=function(x) {
            mf <- model.frame(x, data=dat)
            mt <- attr(mf,"terms")
            q <- model.matrix(mt,mf)  ## design matrix
            return(q)
        })


    ## get stage1 estimating equations for standard errors (columns bound side by side)
    stg1.esteqs <- as.matrix(as.data.frame(lapply(stg1.objs, FUN=function(x){
            stg1.esteq <- estfun( x[[1]] )
            return(stg1.esteq)
        })))


    ## get qetas for the standard errors
    qetas <- vector("list",nmods)
    for (i in seq_len(nmods)){
        qetas[[i]] <- qs[[i]] %*% eta.hat[[i]]
    }


    ## get stage2 estimating equation for standard errors
    stg2.esteq <- estfun( stg2.obj )

    ## get stage2 residuals for standard errors
    stg2.res <- stg2.obj$residuals

    ## get stage2 D.theta.inv for standard errors
    ## pick up the augmented x matrix along the way
    mf <- model.frame( terms(snmm.formula, keep.order=TRUE) , data=dat)
    mt <- attr(mf, "terms")
    xaug <- model.matrix(mt, mf)
    n <- nrow(xaug)
    D.theta.inv <- solve( ( t(xaug) %*% xaug ) / n )
    np.theta <- dim(xaug)[2]
    np.nuis <- np.theta - np.x  # number of nuisance parameters

    ## get average.neg.stg2.res.qtzs for the standard errors
    avg.neg.stg2.res.qtzs <- vector("list",nmods)
    for (i in seq_len(nmods)){
        if (zs[[i]][[2]]==1) {
            avg.neg.stg2.res.qtzs[[i]] <- ( t(qs[[i]]) %*% (stg2.res * zs[[i]][[1]]) ) / n
        }
        else if (zs[[i]][[2]]==0) {
            avg.neg.stg2.res.qtzs[[i]] <- ( t(qs[[i]]) %*% (zs[[i]][[3]] * stg2.res * zs[[i]][[1]]) ) / n
        }
    }


    ## turn each of the matrices in the list avg.neg.stg2.res.qtzs into a matrix with np.theta rows,
    ## with zero matrices in the appropriate places throughout
    ## this is the first matrix in the sum that makes up the various matrices in D.gamma.theta
    pos <- np.theta - np.nuis   # rows 1..pos belong to the intercept and causal effects
    for (i in seq_len(nmods)){
        tmp1 <- avg.neg.stg2.res.qtzs[[i]]
        dim.tmp1 <- dim(tmp1)
        tmp2 <- matrix(0.0, nrow=np.theta, ncol=dim.tmp1[2])
        tmp2[ (pos+1):(pos+dim.tmp1[1]) ,] <- tmp1
        avg.neg.stg2.res.qtzs[[i]] <- tmp2
        pos <- pos + dim.tmp1[1]
    }

    ## get avg.qeta.xaugtzs for standard errors
    avg.qeta.xaugtzs <- vector("list",nmods)
    for (i in seq_len(nmods)){
        if (zs[[i]][[2]]==1) {
            avg.qeta.xaugtzs[[i]] <- ( t(xaug) %*% ( as.vector(qetas[[i]]) * zs[[i]][[1]]) ) / n
        }
        else if (zs[[i]][[2]]==0) {
            avg.qeta.xaugtzs[[i]] <- ( t(xaug) %*% ( as.vector(qetas[[i]]) * zs[[i]][[3]] * zs[[i]][[1]]) ) / n
        }
    }

    ## construct D.gamma.theta; this is the non-invertible matrix used in the standard errors
    D.gamma.theta.list <- vector("list",nmods)
    for (i in seq_len(nmods)){
        D.gamma.theta.list[[i]] <- avg.neg.stg2.res.qtzs[[i]] + avg.qeta.xaugtzs[[i]]
    }
    D.gamma.theta <- as.matrix(as.data.frame(D.gamma.theta.list))



    ## construct several averages needed for the standard errors
    Pn.stg2eeq.stg2eeqt <- ( t(stg2.esteq) %*% stg2.esteq ) / n
    Pn.stg2eeq.stg1eeqt <- ( t(stg2.esteq) %*% stg1.esteqs ) / n
    Pn.stg1eeq.stg1eeqt <- ( t(stg1.esteqs) %*% stg1.esteqs ) / n


    D.gamma.theta.D.gamma.inv <- D.gamma.theta %*% D.gamma.inv
    C <- Pn.stg2eeq.stg1eeqt %*% t(D.gamma.theta.D.gamma.inv)
    Pn.VVt <- Pn.stg2eeq.stg2eeqt - C - t(C) + D.gamma.theta.D.gamma.inv %*% Pn.stg1eeq.stg1eeqt %*% t(D.gamma.theta.D.gamma.inv)

    ## final asymptotic variance-covariance matrix
    VCov.theta <- ( D.theta.inv %*% Pn.VVt %*% D.theta.inv ) / n

    ##########################################################################
    ############# Consolidate Results and Print Model Summary ################
    ##########################################################################

    se.theta <- sqrt(diag(VCov.theta))
    zval <- theta.hat / se.theta

    ans <- cbind(theta.hat, se.theta, zval, 2 * pnorm(abs(zval), lower.tail=FALSE) )
    dimnames(ans) <- list(names(theta.hat), c("Estimate", "ASE", "z value", "Pr(>|z|)"))

    ## if verbose, print the results of the 2stage estimator of the snmm in a "pretty" way
    if (verbose) {
        cat("\n*********************************************************\n")
        cat("  2-Stage Estimator of the Structural Nested Mean Model  \n")
        cat("*********************************************************\n\n")
        cat("Outcome......................................",as.character(response)[[2]],"\n")
        cat("Sample size..................................",nrow(x),"\n")
        cat("Number of Time-points........................",nt,"\n")
        cat("Number of Stage-2 Parameters.................",np.theta,"\n")
        cat("Number of Stage-2 Nuisance Parameters........",np.nuis,"\n")
        cat("Number of Causal Parameters, plus Intercept..",np.x,"\n\n")
        cat("---------------------------------------------------------\n")
        cat("Marginal mean response under no treatment\n")
        cat("-----------------------------------------\n")
        tmpans <- ans[1, , drop=FALSE]
        rownames(tmpans) <- "(SNMM Intercept)    "
        print(round(tmpans,4))
        cat("\n")

        ## indintermeff is the last row of ans already printed; row 1 is the intercept
        indintermeff <- 1
        for (tt in seq_len(nt)){
            cat("---------------------------------------------------------\n")
            cat("Intermediate causal effect function at time",tt,"\n")
            cat("----------------------------------------------\n")
            cat("Treatment variable:",intermeff[[tt]][[1]][[2]],"\n")
            mf <- model.frame( terms(intermeff[[tt]][[2]],keep.order=TRUE),data=dat )
            mt <- attr(mf,"terms")
            mt.labels <- attr(mt,"term.labels")
            lenintermeff <- length(mt.labels)+1    # treatment main effect + one row per term
            intname <- paste0("(Causal ",tt," Intercept)")
            ## drop=FALSE keeps a single row as a one-row matrix, so both the
            ## one-row and the many-row case are labelled and printed the same way
            tmpans <- ans[ (indintermeff + 1):(indintermeff+lenintermeff), , drop=FALSE]
            rownames(tmpans) <- c(intname,mt.labels)
            indintermeff <- indintermeff + lenintermeff
            print(round(tmpans,4))
            cat("\n")
        }
        cntr <- 1       # counter for the de-meaned moderator
        for (tt in seq_len(nt)){
            cat("---------------------------------------------------------\n")
            cat("Non-causal nuisance terms at time",tt,"\n")
            cat("------------------------------------\n")
            ## seq_len() so that a time point with zero moderators prints nothing
            for (jj in seq_len(nmods.vec[tt])){
                cat("De-meaned moderator:",e.names[cntr],"\n")
                mf <- model.frame( terms( coefnuisfunc[[cntr]] , keep.order=TRUE),data=dat )
                mt <- attr(mf,"terms")
                mt.labels <- attr(mt,"term.labels")
                lenintermeff <- length(mt.labels)+1
                intname <- paste0("(Nuis (",tt,",",jj,") Intcpt)")
                tmpans <- ans[ (indintermeff + 1):(indintermeff+lenintermeff), , drop=FALSE]
                rownames(tmpans) <- c(intname,mt.labels)
                indintermeff <- indintermeff + lenintermeff
                print(round(tmpans,4))
                cntr <- cntr + 1  #update counter
                cat("\n")
            }
        }
    }

    ######################################
    ############# Returns ################
    ######################################

    ret <- list()
    ret$ans <- ans
    ret$beta.names <- x.names
    ret$vcov.theta <- VCov.theta
    if (full.return) {
        ## the fitted objects are only handed back on request
        ret$stage1 <- stg1.objs
        ret$stage2 <- stg2.obj
    }
    return(ret)

} ## end twostage function



############################### i use this function above ###########
## my.update.formula: update a formula the way update.formula() does, but keep the terms
## in the order in which they were written (keep.order) instead of sorting them by degree.
##
##   old         formula to be updated (anything as.formula() accepts)
##   new         template formula; every `.` stands for the matching side of old
##   keep.order  passed on to terms.formula(); TRUE keeps the written term order
##   ...         accepted for compatibility with update.formula(); not used
##
## returns the updated, simplified formula, carrying the environment of old.
##
## the `.` expansion is done in plain R here rather than through
## .Internal(update.formula(...)), which is not available in current versions of R.
my.update.formula <- function(old, new, keep.order=TRUE, ...)
{
    old <- as.formula(old)
    new <- as.formula(new)
    env <- environment(old)
    dot <- as.name(".")

    ## replace every `.` in expr by value. a bare `.` is replaced as is; inside a larger
    ## expression a call is wrapped in parentheses first so that operator precedence in
    ## the surrounding formula is not disturbed (e.g. `.:a` with value `x + y`)
    expand.dot <- function(expr, value) {
        if (identical(expr, dot)) return(value)
        if (is.call(value)) value <- call("(", value)
        do.call("substitute", list(expr, stats::setNames(list(value), ".")))
    }

    new.rhs <- new[[length(new)]]
    if (length(old) == 3L) {
        ## two-sided old formula: a one-sided new formula inherits the old response
        old.lhs <- old[[2L]]
        old.rhs <- old[[3L]]
        new.lhs <- if (length(new) == 3L) new[[2L]] else old.lhs
        tmp <- call("~", expand.dot(new.lhs, old.lhs), expand.dot(new.rhs, old.rhs))
    } else {
        ## one-sided old formula: only the right-hand side of new is expanded
        old.rhs <- old[[2L]]
        new.rhs <- expand.dot(new.rhs, old.rhs)
        tmp <- if (length(new) == 3L) call("~", new[[2L]], new.rhs) else call("~", new.rhs)
    }

    out <- formula(terms.formula(tmp, simplify = TRUE, keep.order=keep.order))
    environment(out) <- env
    return(out)
}
