2016-04-20 14 views
5

Chciałbym sprawdzić wszystkie obserwacje, które osiągnęły pewien węzeł w drzewku decyzyjnym rpart. Na przykład, w następujący kodu:Uzyskiwanie obserwacji w węźle rpart (tj .: CART)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 
fit 

n= 81 

node), split, n, loss, yval, (yprob) 
     * denotes terminal node 

1) root 81 17 absent (0.79.20987654) 
    2) Start>=8.5 62 6 absent (0.90322581 0.09677419) 
    4) Start>=14.5 29 0 absent (1.00000000 0.00000000) * 
    5) Start< 14.5 33 6 absent (0.81818182 0.18181818) 
     10) Age< 55 12 0 absent (1.00000000 0.00000000) * 
     11) Age>=55 21 6 absent (0.71428571 0.28571429) 
     22) Age>=111 14 2 absent (0.85714286 0.14285714) * 
     23) Age< 111 7 3 present (0.42857143 0.57142857) * 
    3) Start< 8.5 19 8 present (0.42105263 0.57894737) * 

ja widok Wszystkie uwagi w węźle (5) (tj .: W 33 obserwacji, dla których Home> = 8,5 & start < 14.5). Oczywiście mogłem ręcznie się do nich dostać. Ale chciałbym mieć funkcję podobną do (powiedzmy) "get_node_date". Dla których mógłbym po prostu uruchomić get_node_date (5) - i uzyskać odpowiednie obserwacje.

Wszelkie sugestie, jak to zrobić?

Odpowiedz

1

Wydaje się, że nie ma takiej funkcji, która umożliwia wydobycie obserwacjami z określonego węzła. Rozwiążę go w następujący sposób: najpierw określ, która zasada/s jest/są używane dla węzła, którego jesteś zainteresowany. Możesz użyć do tego celu path.rpart. Następnie możesz zastosować regułę/zasady jeden po drugim, aby wyodrębnić obserwacje.

Takie podejście jako funkcję:

get_node_date <- function(tree = fit, node = 5){ 
    rule <- path.rpart(tree, node) 
    rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE)) 
    ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all) 
    kyphosis[ind,] 
    } 

dla węzła 5 otrzymasz:

get_node_date() 

node number: 5 
    root 
    Start>=8.5 
    Start< 14.5 
    Kyphosis Age Number Start 
2 absent 158  3 14 
10 present 59  6 12 
11 present 82  5 14 
14 absent 1  4 12 
18 absent 175  5 13 
20 absent 27  4  9 
23 present 96  3 12 
26 absent 9  5 13 
28 absent 100  3 14 
32 absent 125  2 11 
33 absent 130  5 13 
35 absent 140  5 11 
37 absent 1  3  9 
39 absent 20  6  9 
40 present 91  5 12 
42 absent 35  3 13 
46 present 139  3 10 
48 absent 131  5 13 
50 absent 177  2 14 
51 absent 68  5 10 
57 absent 2  3 13 
59 absent 51  7  9 
60 absent 102  3 13 
66 absent 17  4 10 
68 absent 159  4 13 
69 absent 18  4 11 
71 absent 158  5 14 
72 absent 127  4 12 
74 absent 206  4 10 
77 present 157  3 13 
78 absent 26  7 13 
79 absent 120  2 13 
81 absent 36  4 13 
1

rpart zwraca rpart.object element, który zawiera potrzebne informacje:

require(rpart) 
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 
fit2 

get_node_date <-function(nodeId,fit) 
{ 
    fit$frame[toString(nodeId),"n"] 
} 


for (i in c(1,2,4,5,10,11,22,23,3)) 
    cat(get_node_date(i,fit2),"\n") 
+1

Nie dostaniesz uwagi poprzez to, ale tylko liczbę abservations które należą do kategorii – DatamineR

+1

masz rację, misread pytanie –

1

Pakiet partykit zapewnia również konserwy rozwiązanie tego problemu. Musisz tylko przekonwertować obiekt rpart na klasę party, aby użyć zunifikowanego interfejsu do obsługi drzew. A następnie możesz użyć funkcji data_party().

Używanie fit z pytaniem i po załadowaniu library("partykit") można najpierw zmusić rpart drzewo do party:

pfit <- as.party(fit) 
plot(pfit) 

full pfit tree

Istnieją tylko dwa małe zagrożenia dla wydobywania danych w sposób chcesz: (1) Model model.frame() z oryginalnego dopasowania zawsze jest rzucany w przymus i musi zostać ponownie podłączony ręcznie. (2) Dla węzłów stosowany jest inny schemat numeracji. Chcesz węzeł 4 (zamiast 5) teraz.

pfit$data <- model.frame(fit) 
data4 <- data_party(pfit, 4) 
dim(data4) 
## [1] 33 5 
head(data4) 
## Kyphosis Age Start (fitted) (response) 
## 2 absent 158 14  7  absent 
## 10 present 59 12  8 present 
## 11 present 82 14  8 present 
## 14 absent 1 12  5  absent 
## 18 absent 175 13  7  absent 
## 20 absent 27  9  5  absent 

Inna droga jest podzbiór poddrzewa począwszy od węzła 4, a następnie biorąc dane z tego:

pfit4 <- pfit[4] 
plot(pfit4) 

subtree of pfit from node 4

Następnie data_party(pfit4) daje takie same jak data4 powyżej. Ponadto pfit4$data podaje dane bez węzła (fitted) i przewidywanego (response).

+0

jeśli użyto 'ptree $ data <- model.frame (eval (tree $ call $ data)) 'zmienne nie używane w formule nie zostaną usunięte – rawr

+0

Prawda ... ale tylko wtedy, gdy' dane' zawiera wszystkie zmienne w 'formule', co niekoniecznie musi się zdarzyć. Z 'model.frame()' otrzymujesz również zmienne transformowane, np. 'Log()', 'Surv()' lub 'factor()' wersje zmiennych, które często są tworzone w locie. –

+0

BTW: Współczynnik 'as.party()' dla obiektów 'rpart' teraz domyślnie przyjmuje wartości danych! Tak więc możesz zrobić "as.party (dopasowanie, dane = TRUE)" (co jest nowym domyślnym) lub 'as.party (dopasowanie, dane = FAŁSZ)' (co odpowiada starszemu zachowaniu). –

1

Jeszcze innym sposobem jest znalezienie wszystkich węzłów końcowych dowolnego węzła i zwrócenie podzbioru danych używanych w wywołaniu.

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 

head(subset.rpart(fit, 5)) 
# Kyphosis Age Number Start 
# 2 absent 158  3 14 
# 10 present 59  6 12 
# 11 present 82  5 14 
# 14 absent 1  4 12 
# 18 absent 175  5 13 
# 20 absent 27  4  9 


subset.rpart <- function(tree, node = 1L) { 
    data <- eval(tree$call$data, parent.frame(1L)) 
    wh <- sapply(as.integer(rownames(tree$frame)), parent) 
    wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)])) 
    data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ] 
} 

parent <- function(x) { 
    if (x[1] != 1) 
    c(Recall(if (x %% 2 == 0L) x/2 else (x - 1)/2), x) else x 
}