# file nnet/knn.q copyright (C) 1994-8 W. N. Venables and B. D. Ripley
#
knn1 <- function(train, test, cl)
{
    ## Thin R wrapper around the compiled routine VR_knn1, which (per the
    ## function name) does 1-nearest-neighbour classification of the rows
    ## of `test` against the labelled rows of `train`.
    ## `cl` gives one class label per row of `train`.
    ## Returns a factor, one value per test row, with levels(as.factor(cl)).
    train <- as.matrix(train)
    ## a plain vector is treated as a single test case (a 1-row matrix)
    if(is.null(dim(test))) dim(test) <- c(1, length(test))
    test <- as.matrix(test)
    ntrain <- nrow(train)
    nvars <- ncol(train)
    if(length(cl) != ntrain) stop("train and class have different lengths")
    ntest <- nrow(test)
    if(ncol(test) != nvars) stop("Dims of test and train differ")
    clf <- as.factor(cl)
    lev <- levels(clf)
    clcodes <- unclass(clf)
    ncl <- max(clcodes)
    out <- .C("VR_knn1",
              as.integer(ntrain),
              as.integer(ntest),
              as.integer(nvars),
              as.double(train),
              as.integer(clcodes),
              as.double(test),
              res = integer(ntest),
              integer(ncl + 1),   # scratch of length ncl+1 for the C code
              as.integer(ncl),
              d = double(ntest)
              )
    ## VR_knn1 writes integer class codes into `res`; map them back to labels
    factor(out$res, levels = seq_along(lev), labels = lev)
}

knn <- function(train, test, cl, k=1, l=0, prob=FALSE, use.all=TRUE)
{
    ## R wrapper around the compiled routine VR_knn: classifies the rows of
    ## `test` using the labelled rows of `train` (labels in `cl`).
    ##   k       neighbour count; if it exceeds the number of training rows
    ##           it is reduced to that number with a warning, and it must
    ##           end up >= 1
    ##   l       passed to VR_knn unchanged (presumably a minimum vote for a
    ##           definite decision -- confirm against the C source)
    ##   prob    if TRUE, VR_knn's `pr` output is attached as attr "prob"
    ##   use.all passed to VR_knn as an integer flag
    ## Returns a factor, one value per test row, with levels(as.factor(cl)).
    train <- as.matrix(train)
    ## a plain vector is treated as a single test case (a 1-row matrix)
    if(is.null(dim(test))) dim(test) <- c(1, length(test))
    test <- as.matrix(test)
    nvars <- ncol(train)
    ntrain <- nrow(train)
    if(length(cl) != ntrain) stop("train and class have different lengths")
    if(k > ntrain) {
        warning(paste("k =", k, "exceeds number", ntrain, "of patterns"))
        k <- ntrain
    }
    if(k < 1) stop(paste("k =", k, "must be at least 1"))
    ntest <- nrow(test)
    if(ncol(test) != nvars) stop("Dims of test and train differ")
    clf <- as.factor(cl)
    lev <- levels(clf)
    clcodes <- unclass(clf)
    ncl <- max(clcodes)
    fit <- .C("VR_knn",
              as.integer(k), as.integer(l),
              as.integer(ntrain), as.integer(ntest), as.integer(nvars),
              as.double(train), as.integer(clcodes), as.double(test),
              res = integer(ntest),
              pr = double(ntest),
              integer(ncl + 1),    # scratch of length ncl+1 for the C code
              as.integer(ncl),
              as.integer(FALSE),   # knn.cv passes TRUE in this slot
              as.integer(use.all)
              )
    pred <- factor(fit$res, levels = seq_along(lev), labels = lev)
    if(prob) attr(pred, "prob") <- fit$pr
    pred
}

knn.cv <- function(train, cl, k=1, l=0, prob=FALSE, use.all=TRUE)
{
    ## Cross-validatory k-NN classification of the rows of `train`.
    ## VR_knn is called with `train` as both the training and the test set,
    ## with TRUE in the flag slot where knn() passes FALSE (presumably
    ## "exclude each row from its own neighbour search", which is why k is
    ## capped at ntr - 1 below).
    ##   k, l, prob, use.all  as for knn()
    ## Returns a factor, one value per row of `train`, with
    ## levels(as.factor(cl)); if prob is TRUE, VR_knn's `pr` output is
    ## attached as attr "prob".
    train <- as.matrix(train)
    p <- ncol(train)
    ntr <- nrow(train)
    ## Same check as knn()/knn1().  Without it a class vector of the wrong
    ## length would be handed to the compiled code unchecked.
    if(length(cl) != ntr) stop("train and class have different lengths")
    if(ntr-1 < k) {
        warning(paste("k =",k,"exceeds number",ntr-1,"of patterns"))
        k <- ntr - 1
    }
    if (k < 1) stop(paste("k =",k,"must be at least 1"))
    clf <- as.factor(cl)
    nc <- max(unclass(clf))
    Z <- .C("VR_knn",
            as.integer(k),
            as.integer(l),
            as.integer(ntr),
            as.integer(ntr),
            as.integer(p),
            as.double(train),
            as.integer(unclass(clf)),
            as.double(train),
            res = integer(ntr),
            pr = double(ntr),
            integer(nc+1),
            as.integer(nc),
            as.integer(TRUE),
            as.integer(use.all)
            )
    res <- factor(Z$res, levels=seq_along(levels(clf)), labels=levels(clf))
    if(prob) attr(res, "prob") <- Z$pr
    res
}
# file nnet/lvq.q copyright (C) 1994-8 W. N. Venables and B. D. Ripley
#
lvqinit <- function(x, cl, size, prior, k=5)
{
    ## Build an initial LVQ codebook by sampling rows of `x` class by class.
    ## Only rows whose class agrees with knn.cv(x, cl, k) are candidates.
    ## Class i contributes min(#candidates, round(size * prior[i])) rows at
    ## random, except that a lone candidate is always kept.
    ##   size   total codebook size; default
    ##          min(round(0.4 * nc * (nc - 1 + p/2)), n), where nc is the
    ##          number of classes, p = ncol(x) and n = nrow(x)
    ##   prior  class weights; default the observed class proportions.  If
    ##          supplied it must be non-negative, sum to 1 (to 5 decimal
    ##          places) and have one entry per class.
    ## Returns list(x = codebook rows of x, cl = their labels from cl).
    x <- as.matrix(x)
    nobs <- nrow(x)
    nvars <- ncol(x)
    if(length(cl) != nobs) stop("x and cl have different lengths")
    g <- as.factor(cl)
    prop <- tapply(rep(1, length(g)), g, sum) / nobs
    ncls <- length(prop)
    if(missing(prior)) {
        prior <- prop
    } else if(any(prior < 0) || round(sum(prior), 5) != 1) {
        stop("invalid prior")
    }
    if(length(prior) != ncls) stop("prior is of incorrect length")
    if(missing(size))
        size <- min(round(0.4 * ncls * (ncls - 1 + nvars/2), 0), nobs)
    agree <- knn.cv(x, cl, k) == cl
    gcode <- unclass(g)
    picks <- lapply(seq_len(ncls), function(i) {
        cand <- seq_along(g)[gcode == i & agree]
        ## sample() on a single number would draw from 1:cand, so only
        ## subsample when there is more than one candidate
        if(length(cand) > 1)
            cand <- sample(cand, min(length(cand), round(size * prior[i])))
        cand
    })
    chosen <- unlist(picks)
    list(x = x[chosen, , drop = FALSE], cl = cl[chosen])
}

olvq1 <- function(x, cl, codebk, niter = 40*nrow(codebk$x), alpha = 0.3)
{
    ## Update an LVQ codebook with the compiled routine VR_olvq (OLVQ1, per
    ## the function name), presenting `niter` rows of `x` drawn at random
    ## with replacement.
    ##   codebk  list(x = codebook matrix, cl = codebook labels), e.g. from
    ##           lvqinit().  cl and codebk$cl are passed to C as integer
    ##           codes, so they should share one coding (e.g. factors with
    ##           the same levels).
    ##   alpha   passed to VR_olvq (presumably the initial learning rate)
    ## Returns list(x = updated codebook matrix, cl = codebk$cl).
    x <- as.matrix(x)
    n <- nrow(x)
    p <- ncol(x)
    if(length(cl) != n) stop("x and cl have different lengths")
    ## The C code is given as.double(codebk$x) as nc rows of p values and nc
    ## labels.  It presumably indexes them that way, so mismatched shapes
    ## would mean out-of-bounds access.  Check them here.
    if(is.null(dim(codebk$x)) || ncol(codebk$x) != p)
        stop("codebk$x and x have different numbers of columns")
    nc <- nrow(codebk$x)
    if(length(codebk$cl) != nc)
        stop("codebk$x and codebk$cl have different lengths")
    iters <- sample(n, niter, TRUE)
    z <- .C("VR_olvq",
            as.double(alpha),
            as.integer(n),
            as.integer(p),
            as.double(x),
            as.integer(unclass(cl)),
            as.integer(nc),
            xc = as.double(codebk$x),
            as.integer(codebk$cl),
            as.integer(niter),
            as.integer(iters-1)     # 0-based row indices for C
            )
    xc <- matrix(z$xc,nc,p)
    dimnames(xc) <- dimnames(codebk$x)
    list(x = xc, cl = codebk$cl)
}

lvq1 <- function(x, cl, codebk, niter = 100*nrow(codebk$x), alpha = 0.03)
{
    ## Update an LVQ codebook with the compiled routine VR_lvq1 (LVQ1, per
    ## the function name), presenting `niter` rows of `x` drawn at random
    ## with replacement.
    ##   codebk  list(x = codebook matrix, cl = codebook labels), e.g. from
    ##           lvqinit().  cl and codebk$cl are passed to C as integer
    ##           codes, so they should share one coding (e.g. factors with
    ##           the same levels).
    ##   alpha   passed to VR_lvq1 (presumably the learning rate)
    ## Returns list(x = updated codebook matrix, cl = codebk$cl).
    x <- as.matrix(x)
    n <- nrow(x)
    p <- ncol(x)
    if(length(cl) != n) stop("x and cl have different lengths")
    ## The C code is given as.double(codebk$x) as nc rows of p values and nc
    ## labels.  It presumably indexes them that way, so mismatched shapes
    ## would mean out-of-bounds access.  Check them here.
    if(is.null(dim(codebk$x)) || ncol(codebk$x) != p)
        stop("codebk$x and x have different numbers of columns")
    nc <- nrow(codebk$x)
    if(length(codebk$cl) != nc)
        stop("codebk$x and codebk$cl have different lengths")
    iters <- sample(n, niter, TRUE)
    z <- .C("VR_lvq1",
            as.double(alpha),
            as.integer(n),
            as.integer(p),
            as.double(x),
            as.integer(unclass(cl)),
            as.integer(nc),
            xc = as.double(codebk$x),
            as.integer(codebk$cl),
            as.integer(niter),
            as.integer(iters-1)     # 0-based row indices for C
            )
    xc <- matrix(z$xc,nc,p)
    dimnames(xc) <- dimnames(codebk$x)
    list(x = xc, cl = codebk$cl)
}

lvq2 <- function(x, cl, codebk, niter = 100*nrow(codebk$x), alpha = 0.03,
                 win = 0.3)
{
    ## Update an LVQ codebook with the compiled routine VR_lvq2 (LVQ2, per
    ## the function name), presenting `niter` rows of `x` drawn at random
    ## with replacement.
    ##   codebk  list(x = codebook matrix, cl = codebook labels), e.g. from
    ##           lvqinit().  cl and codebk$cl are passed to C as integer
    ##           codes, so they should share one coding (e.g. factors with
    ##           the same levels).
    ##   alpha   passed to VR_lvq2 (presumably the learning rate)
    ##   win     passed to VR_lvq2 (presumably the relative window width)
    ## Returns list(x = updated codebook matrix, cl = codebk$cl).
    x <- as.matrix(x)
    n <- nrow(x)
    p <- ncol(x)
    if(length(cl) != n) stop("x and cl have different lengths")
    ## The C code is given as.double(codebk$x) as nc rows of p values and nc
    ## labels.  It presumably indexes them that way, so mismatched shapes
    ## would mean out-of-bounds access.  Check them here.
    if(is.null(dim(codebk$x)) || ncol(codebk$x) != p)
        stop("codebk$x and x have different numbers of columns")
    nc <- nrow(codebk$x)
    if(length(codebk$cl) != nc)
        stop("codebk$x and codebk$cl have different lengths")
    iters <- sample(n, niter, TRUE)
    z <- .C("VR_lvq2",
            as.double(alpha),
            as.double(win),
            as.integer(n),
            as.integer(p),
            as.double(x),
            as.integer(unclass(cl)),
            as.integer(nc),
            xc = as.double(codebk$x),
            as.integer(codebk$cl),
            as.integer(niter),
            as.integer(iters-1)     # 0-based row indices for C
            )
    xc <- matrix(z$xc,nc,p)
    dimnames(xc) <- dimnames(codebk$x)
    list(x = xc, cl = codebk$cl)
}

lvq3 <- function(x, cl, codebk, niter = 100*nrow(codebk$x),
                 alpha = 0.03, win = 0.3, epsilon = 0.1)
{
    ## Update an LVQ codebook with the compiled routine VR_lvq3 (LVQ3, per
    ## the function name), presenting `niter` rows of `x` drawn at random
    ## with replacement.
    ##   codebk   list(x = codebook matrix, cl = codebook labels), e.g. from
    ##            lvqinit().  cl and codebk$cl are passed to C as integer
    ##            codes, so they should share one coding (e.g. factors with
    ##            the same levels).
    ##   alpha    passed to VR_lvq3 (presumably the learning rate)
    ##   win      passed to VR_lvq3 (presumably the relative window width)
    ##   epsilon  passed to VR_lvq3 (presumably the same-class rate factor)
    ## Returns list(x = updated codebook matrix, cl = codebk$cl).
    x <- as.matrix(x)
    n <- nrow(x)
    p <- ncol(x)
    if(length(cl) != n) stop("x and cl have different lengths")
    ## The C code is given as.double(codebk$x) as nc rows of p values and nc
    ## labels.  It presumably indexes them that way, so mismatched shapes
    ## would mean out-of-bounds access.  Check them here.
    if(is.null(dim(codebk$x)) || ncol(codebk$x) != p)
        stop("codebk$x and x have different numbers of columns")
    nc <- nrow(codebk$x)
    if(length(codebk$cl) != nc)
        stop("codebk$x and codebk$cl have different lengths")
    iters <- sample(n, niter, TRUE)
    z <- .C("VR_lvq3",
            as.double(alpha),
            as.double(win),
            as.double(epsilon),
            as.integer(n),
            as.integer(p),
            as.double(x),
            as.integer(unclass(cl)),
            as.integer(nc),
            xc = as.double(codebk$x),
            as.integer(codebk$cl),
            as.integer(niter),
            as.integer(iters-1)     # 0-based row indices for C
            )
    xc <- matrix(z$xc,nc,p)
    dimnames(xc) <- dimnames(codebk$x)
    list(x = xc, cl = codebk$cl)
}

lvqtest <- function(codebk, test)
{
    ## Classify the rows of `test` with knn1() against the codebook vectors
    ## codebk$x, whose labels are codebk$cl.  Returns knn1()'s factor.
    knn1(train = codebk$x, test = test, cl = codebk$cl)
}
# file nnet/multiedit.q copyright (C) 1994-8 W. N. Venables and B. D. Ripley
#
multiedit <- function(x, class, k=1, V=3, I=5, trace=TRUE)
{
    ## Multiedit the training set (rows of `x`, labels `class`).
    ## Each pass splits the current points at random into V groups,
    ## classifies group (i %% V) + 1 with knn(k = k) trained on group i, and
    ## drops every misclassified point.  Editing stops after I consecutive
    ## passes that drop nothing, or (with a warning) once fewer than 5*V
    ## points remain.  With trace = TRUE the retained size is printed after
    ## each pass that does not end the editing.
    ## Returns the indices into the original rows of `x` that are retained.
    n1 <- length(class)
    ## Work with integer class codes.  This used to call codes(), which is
    ## defunct in current R.  Any one-to-one coding works, because knn()'s
    ## predictions are only compared with these codes.
    class <- as.integer(as.factor(class))
    index <- seq_len(n1)
    pass <- lpass <- 0
    repeat{
        if(n1 < 5*V) {
            warning("retained set is now too small to proceed")
            break
        }
        pass <- pass + 1
        sub <- sample(V, length(class), replace=TRUE)
        keep <- logical(length(class))
        for (i in seq_len(V)){
            train <- sub == i
            test <- sub == (1 + i %% V)
            if(!any(test)) next
            if(!any(train)) {
                ## The random split left group i empty, and knn() stops on
                ## a zero-row training set, so keep these points untested.
                keep[test] <- TRUE
                next
            }
            keep[test] <- (knn(x[train, , drop=FALSE], x[test, , drop=FALSE],
                               class[train], k) == class[test])
        }
        x <- x[keep, , drop=FALSE]; class <- class[keep]; index <- index[keep]
        n2 <- length(class)
        if(n2 < n1) lpass <- pass       # last pass that removed something
        if(lpass <= pass - I) break
        n1 <- n2
        if(trace) cat(paste("pass ", pass," size ", n2, "\n"))
    }
    index
}

condense <- function(train, class, store=sample(seq(n), 1), trace=TRUE)
{
    ## Condense a training set.  Start from the rows in `store` (by default
    ## one random row) and, while knn1() with the store misclassifies any
    ## row outside it, move one such row (chosen at random) into the store.
    ## With trace = TRUE the store is printed at the start of every round.
    ## Returns the row indices of the final store.
    n <- length(class)
    in.bag <- rep(TRUE, n)      # TRUE = row not (yet) in the store
    in.bag[store] <- FALSE      # the `store` default is forced here, after n exists
    repeat {
        if(trace) print(seq(n)[!in.bag])
        if(sum(in.bag) == 0) break
        pred <- knn1(train[!in.bag, , drop = FALSE],
                     train[in.bag, , drop = FALSE],
                     class[!in.bag])
        wrong <- pred != class[in.bag]
        if(sum(wrong) == 0) break
        pick <- seq(n)[in.bag][wrong]
        ## sample() on a single number would draw from 1:pick instead
        if(length(pick) > 1) pick <- sample(pick, 1)
        in.bag[pick] <- FALSE
    }
    seq(n)[!in.bag]
}

reduce.nn <- function(train, ind, class)
{
    ## Reduce a prototype set: `ind` are the prototype rows of `train`.
    ## Prototypes are visited in random order.  Each is dropped if 1-NN
    ## (knn1) with the remaining prototypes still classifies every
    ## non-prototype row, and the row itself, correctly.
    ## Returns the sorted row indices of the prototypes that survive.
    n <- length(class)
    ## Rows outside the prototype set.  setdiff() replaces seq(n)[-ind]
    ## because negative indexing with an empty vector selects nothing.  That
    ## made the result empty whenever `ind` covered every row and nothing
    ## could be dropped.
    rest <- setdiff(seq_len(n), ind)
    ## Random visiting order.  Index ind rather than call sample(ind), which
    ## permutes 1:ind when ind is a single number.
    ## This must be done iteratively, not simultaneously.
    for(i in ind[sample.int(length(ind))]) {
        if(i %in% rest) next            # duplicated entry in `ind`
        cand <- c(rest, i)
        left <- setdiff(seq_len(n), cand)
        if(length(left) == 0L) next     # never drop the last prototype
        res <- knn1(train[left, , drop=FALSE],
                    train[cand, , drop=FALSE],
                    class[left])
        if(all(res == class[cand])) rest <- cand
    }
    setdiff(seq_len(n), rest)
}

.First.lib <- function(lib, pkg)
{
    ## Attach hook: load the compiled code of package "class" from `lib`
    ## (presumably the shared library providing the VR_* routines that this
    ## file calls through .C()).
    library.dynam("class", pkg, lib)
}
