Greetings! In this session on Decision Trees, we aim to implement a full Decision Tree from scratch in Python. Decision Trees are a type of Supervised Machine Learning in which data is repeatedly split according to certain parameters.
A Decision Tree has a tree-like structure with each internal node denoting a test on an attribute, each branch representing an outcome of the test, and each terminal node (leaf) holding a class label. Here are the parts of a Decision Tree:
- Root Node: This houses the entire dataset.
- Internal Nodes: These make decisions based on conditions.
- Edges/Branches: These connect nodes and carry the outcome of each decision rule.
- Leaves: These are terminal nodes for making predictions.
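In the implementation below, these parts map onto plain Python data structures: each internal node is a dictionary, and each leaf is a bare class label. Here is a sketch of a single internal node (the field names mirror the code later in this lesson; the concrete values are made up):

```python
# One internal node, in the shape produced by get_best_split and
# filled in by the tree-building code below (values are illustrative)
node = {
    'index': 0,   # which attribute this node tests
    'value': 10,  # split threshold: rows go left when row[index] < value
    'left': 0,    # child: another dict (internal node) or a class label (leaf)
    'right': 1,
}
```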
Which attribute is chosen for a split depends on how well it purifies the data, that is, how cleanly it separates the classes, typically measured with a metric such as Gini impurity.
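As a quick illustration, here is a tiny sketch of Gini impurity (one common purity metric, and the one we assume throughout this lesson): a pure group scores 0, while an evenly mixed two-class group scores 0.5, the worst possible.

```python
def gini(labels):
    # Gini impurity: 1 minus the sum of squared class proportions
    size = len(labels)
    return 1.0 - sum((labels.count(c) / size) ** 2 for c in set(labels))

print(gini([0, 0, 0, 0]))  # 0.0 -> perfectly pure
print(gini([0, 1, 0, 1]))  # 0.5 -> maximally impure for two classes
```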
The tree-building process begins with the full dataset at the root node, iteratively partitioning the data based on chosen attributes. Each child node becomes a new root that can be split further. This recursive process continues until predefined stopping criteria are met.
Standard stopping criteria include:
- Maximum Tree Depth: Limiting the maximum depth of the tree.
- Minimum Node Records: Stop splitting a node once it holds fewer records than a set threshold.
- Node Purity: Stop if all instances at a node belong to the same class.
These criteria keep the tree from growing overly complex, which helps prevent overfitting.
Now, we'll use Python to build the decision tree. We'll rely on the existing `get_best_split` function from the previous lesson to find the optimal split for our data.
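If you don't have that function handy, here is a minimal sketch of what `get_best_split` might look like, assuming an exhaustive Gini-based search over every attribute value. Treat it as an assumption rather than the exact implementation from the previous lesson, though it does reproduce the tree output shown at the end of this lesson:

```python
def gini_index(groups, classes):
    # Weighted Gini impurity across a candidate split's groups
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue  # avoid division by zero for an empty group
        score = sum(
            ([row[-1] for row in group].count(class_val) / size) ** 2
            for class_val in classes
        )
        gini += (1.0 - score) * (size / n_instances)
    return gini

def test_split(index, value, dataset):
    # Partition rows on whether the attribute at `index` is < value
    left = [row for row in dataset if row[index] < value]
    right = [row for row in dataset if row[index] >= value]
    return left, right

def get_best_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    best = {'index': None, 'value': None, 'groups': None}
    best_score = float('inf')
    for index in range(len(dataset[0]) - 1):  # every attribute column
        for row in dataset:                   # every observed value
            groups = test_split(index, row[index], dataset)
            score = gini_index(groups, class_values)
            if score < best_score:
                best_score = score
                best = {'index': index, 'value': row[index], 'groups': groups}
    return best
```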
Here is how a terminal node is created:
```python
def create_terminal(group):
    # Collect the class labels in this group and return the most common one
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)
```
The `create_terminal` function determines the most common class value in a group of rows and assigns that value as the final decision for that subset of data.
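For example, with a small hypothetical group where class 0 holds the majority:

```python
# Rows follow the [feature1, feature2, class_label] layout used below
group = [[5, 3, 0], [6, 3, 0], [10, 3, 1]]
print(create_terminal(group))  # 0 -> the majority class label
```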
Let's proceed to the actual tree building:
```python
def build_tree(train, max_depth, min_size):
    # Create the root node from the best split of the full training set,
    # then grow the rest of the tree recursively
    root = get_best_split(train)
    recurse_split(root, max_depth, min_size, 1)
    return root
```
This function finds the best split for the root node and then grows the tree recursively from there.
The `recurse_split` function is responsible for creating child nodes:
```python
def recurse_split(node, max_depth, min_size, depth):
    # Split into left and right groups
    left, right = node['groups']
    del node['groups']

    # If the left or right group is empty, create a terminal node
    if not left or not right:
        node['left'] = node['right'] = create_terminal(left + right)
        return

    # Check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = create_terminal(left), create_terminal(right)
        return

    # Process the left child
    if len(left) <= min_size:
        node['left'] = create_terminal(left)
    else:
        node['left'] = get_best_split(left)
        recurse_split(node['left'], max_depth, min_size, depth + 1)

    # Process the right child
    if len(right) <= min_size:
        node['right'] = create_terminal(right)
    else:
        node['right'] = get_best_split(right)
        recurse_split(node['right'], max_depth, min_size, depth + 1)
```
We can then build and print the decision tree based on the dataset and chosen parameters:
```python
# Sample dataset: each row is [feature1, feature2, class_label]
dataset = [
    [5, 3, 0], [6, 3, 0], [6, 4, 0], [10, 3, 1],
    [11, 4, 1], [12, 8, 0], [5, 5, 0], [12, 4, 1]
]

max_depth = 2
min_size = 1
tree = build_tree(dataset, max_depth, min_size)

# Print the tree
def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % (depth * ' ', node['index'] + 1, node['value']))
        print_tree(node['left'], depth + 1)
        print_tree(node['right'], depth + 1)
    else:
        print('%s[%s]' % (depth * ' ', node))

print_tree(tree)
'''Output:
[X1 < 10.000]
 [X1 < 5.000]
  [0]
  [0]
 [X2 < 8.000]
  [1]
  [0]
'''
```
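As a quick experiment with the stopping criteria, lowering `max_depth` to 1 collapses the tree to a single split. Assuming the same dataset and the Gini-based splitting sketched earlier, the result is a one-level "stump":

```python
stump = build_tree(dataset, 1, 1)
print_tree(stump)
'''Output:
[X1 < 10.000]
 [0]
 [1]
'''
```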
The `print_tree` output shows the decision tree. `[X1 < 10.000]` checks whether feature `X1` is less than 10.000. The left branch (`[X1 < 5.000]`) and its child nodes (`[0]` and `[0]`) show the conditions and predictions when the test is true ('Yes'), while the right branch (`[X2 < 8.000]`) and its nodes (`[1]` and `[0]`) cover the 'No' cases. Indentation indicates tree depth.
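To make those branch semantics concrete, here is one way such a tree could be traversed to classify a new row. A `predict` helper is not part of this lesson's code, so treat this as an illustrative sketch:

```python
def predict(node, row):
    # Take the 'Yes' (left) branch when the test passes, else 'No' (right)
    child = node['left'] if row[node['index']] < node['value'] else node['right']
    # Internal nodes are dicts; leaves are bare class labels
    return predict(child, row) if isinstance(child, dict) else child

print(predict(tree, [6, 3]))   # X1 < 10, then X1 >= 5 -> 0
print(predict(tree, [11, 4]))  # X1 >= 10, then X2 < 8 -> 1
```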
Congratulations! You've now built a decision tree from scratch in Python. Proceed to the practice to reinforce your understanding. Keep progressing and keep implementing! Happy tree building!