Greetings! In this session on Decision Trees, we aim to implement a full Decision Tree from scratch in Python. Decision Trees are a type of supervised machine learning model in which the data is repeatedly split according to certain parameters.
A Decision Tree has a tree-like structure in which each internal node denotes a test on an attribute, each branch represents an outcome of that test, and each terminal node (leaf) holds a class label.
The attribute chosen at each split depends on how well it helps purify the data, that is, how cleanly it separates the rows into groups dominated by a single class.
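To make the idea of purity concrete, here is a minimal sketch of the Gini impurity measure, one common way to score a candidate split (the metric actually used by get_best_split in the previous lesson may differ):

```python
def gini_index(groups, classes):
    # Total number of rows across both groups of the candidate split
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue  # skip empty groups to avoid division by zero
        score = 0.0
        # Sum of squared class proportions within the group
        for class_val in classes:
            proportion = [row[-1] for row in group].count(class_val) / size
            score += proportion * proportion
        # Weight the group's impurity by its relative size
        gini += (1.0 - score) * (size / n_instances)
    return gini  # 0.0 means the groups are perfectly pure
```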
The tree-building process begins with the full dataset at the root node, iteratively partitioning the data based on chosen attributes. Each child node becomes a new root that can be split further. This recursive process continues until predefined stopping criteria are met.
Standard stopping criteria include reaching a maximum tree depth and having too few samples left in a node to justify another split. These criteria keep the tree from growing overly complex, which helps prevent overfitting.
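In the code below, these two criteria appear as the max_depth and min_size parameters; the values shown here are the ones used in the example at the end of this lesson:

```python
max_depth = 2  # maximum number of split levels allowed in the tree
min_size = 1   # a group with this many rows or fewer becomes a terminal node
```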
Now, we'll use Python to build the decision tree. We'll rely on the existing get_best_split function from the previous lesson to find the optimal split for our data.
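If you don't have the previous lesson open, the sketch below shows one way get_best_split could be implemented; the test_split helper is introduced here purely for illustration, gini_index is the purity sketch from above, and the 'index', 'value', and 'groups' keys match how the node dictionary is used in the rest of this lesson:

```python
def test_split(index, value, dataset):
    # Partition rows by whether their value at the given index is below the split value
    left = [row for row in dataset if row[index] < value]
    right = [row for row in dataset if row[index] >= value]
    return left, right

def get_best_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    best_index, best_value, best_score, best_groups = None, None, float('inf'), None
    # Try every value of every feature as a candidate split point
    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < best_score:
                best_index, best_value, best_score, best_groups = index, row[index], gini, groups
    return {'index': best_index, 'value': best_value, 'groups': best_groups}
```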
Here is how a terminal node is created:
```python
def create_terminal(group):
    # Return the most common class value in the group as this leaf's prediction
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)
```
The create_terminal function determines the most common class value in a group of rows and assigns that value as the final decision for that subset of data.
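For example, given a small group of rows in the same format as the dataset used later in this lesson (the last column is the class label), create_terminal returns the majority class:

```python
group = [[5, 3, 0], [6, 3, 0], [10, 3, 1]]
print(create_terminal(group))  # prints 0: two rows of class 0 outvote one row of class 1
```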
Let's proceed to the actual tree building:
```python
def build_tree(train, max_depth, min_size):
    # Find the best initial split to form the root, then grow the tree recursively
    root = get_best_split(train)
    recurse_split(root, max_depth, min_size, 1)
    return root
```
This function starts the tree-building process: it finds the best first split of the training data to form the root node, then hands that node to recurse_split to grow the rest of the tree.
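The returned root is a nested dictionary. For the sample dataset used at the end of this lesson, it has the following shape (leaf positions hold plain class labels rather than nested dictionaries; the float formatting of the printed output is omitted here):

```python
example_tree = {
    'index': 0, 'value': 10,                                   # root test: X1 < 10
    'left':  {'index': 0, 'value': 5, 'left': 0, 'right': 0},  # X1 < 5 -> both leaves predict 0
    'right': {'index': 1, 'value': 8, 'left': 1, 'right': 0},  # X2 < 8 -> predict 1, else 0
}
```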
The recurse_split function is responsible for creating the child nodes:
```python
def recurse_split(node, max_depth, min_size, depth):
    # Split into left and right groups
    left, right = node['groups']
    del(node['groups'])

    # If the left or right group is empty, create a terminal node
    if not left or not right:
        node['left'] = node['right'] = create_terminal(left + right)
        return

    # Check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = create_terminal(left), create_terminal(right)
        return

    # Process the child nodes
    if len(left) <= min_size:
        node['left'] = create_terminal(left)
    else:
        node['left'] = get_best_split(left)
        recurse_split(node['left'], max_depth, min_size, depth + 1)

    if len(right) <= min_size:
        node['right'] = create_terminal(right)
    else:
        node['right'] = get_best_split(right)
        recurse_split(node['right'], max_depth, min_size, depth + 1)
```
We can then build and print the decision tree based on the dataset and chosen parameters:
```python
# Sample dataset
dataset = [
    [5, 3, 0], [6, 3, 0], [6, 4, 0], [10, 3, 1],
    [11, 4, 1], [12, 8, 0], [5, 5, 0], [12, 4, 1]
]

max_depth = 2
min_size = 1
tree = build_tree(dataset, max_depth, min_size)

# Print the tree
def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % (depth * ' ', node['index'] + 1, node['value']))
        print_tree(node['left'], depth + 1)
        print_tree(node['right'], depth + 1)
    else:
        print('%s[%s]' % (depth * ' ', node))

print_tree(tree)
'''Output:
[X1 < 10.000]
 [X1 < 5.000]
  [0]
  [0]
 [X2 < 8.000]
  [1]
  [0]
'''
```
The print_tree output shows the decision tree. [X1 < 10.000] checks whether feature X1 is less than 10.000. The left branch ([X1 < 5.000]) and its child nodes ([0] and [0]) give the conditions and predictions when that check is true; the right branch ([X2 < 8.000]) and its child nodes ([1] and [0]) cover the cases where it is false. Indentation reflects tree depth.
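To classify a new row with the finished tree, you traverse it in exactly the same way: go left when the node's condition is true and right otherwise. This lesson does not include a prediction function, but a minimal sketch might look like this:

```python
def predict(node, row):
    # Follow the left branch when the condition holds, the right branch otherwise
    if row[node['index']] < node['value']:
        branch = node['left']
    else:
        branch = node['right']
    # Dictionaries are internal nodes; anything else is a terminal class label
    if isinstance(branch, dict):
        return predict(branch, row)
    return branch

print(predict(tree, [11, 4]))  # 1, following the right branch (X1 >= 10) then left (X2 < 8)
```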
Congratulations! You've now built a decision tree from scratch in Python. Proceed to the practice to reinforce your understanding. Keep progressing and keep implementing! Happy tree building!