How does k-Nearest Neighbors work?

k-Nearest Neighbors (k-NN) is one of the classification algorithms in machine learning. Since k-NN simply memorizes (stores) the training data, we can also say that it does not learn a mapping function (f) between inputs and labels.

The k-NN algorithm can be summarized in the following four steps:

1. Compute the distance between the test point and each of the training points
2. Sort the distances in ascending order
3. Pick the 'k' nearest neighbors from the sorted list
4. Apply a majority vote of the labels for classification, or average the label values for regression
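The four steps above can be sketched as a small Python function (a minimal illustration, assuming Euclidean distance and plain Python lists, not an optimized implementation):

```python
import math
from collections import Counter

def knn_predict(X_train, y_train, x_test, k=3):
    """Predict the label of x_test by majority vote among its k nearest neighbors."""
    # Step 1: distance between the test point and each training point
    distances = [(math.dist(x, x_test), label) for x, label in zip(X_train, y_train)]
    # Step 2: sort by distance, smallest first
    distances.sort(key=lambda pair: pair[0])
    # Step 3: keep the k nearest neighbors
    nearest = distances[:k]
    # Step 4: majority vote over the neighbors' labels
    votes = [label for _, label in nearest]
    return Counter(votes).most_common(1)[0][0]
```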

Before going to the implementation, let us set up some standard notation so that it is easier to follow the code in the implementation section (next post):

          X_training = given training data
          Y = labels for your given training data
          X_test = new data (For which we need to predict the labels)
          

The whole algorithm can be divided into three major steps:

  1. Find the nearest neighbors (most similar data instances) in your training data for a given test point (X_test)

    Let's say we have 20 training instances in total, and we find 4 of them to be the nearest neighbors
    of one of our test points.
  2. Once you have the nearest neighbors, collect the labels of those selected training instances

    In our case, we have 4 training instances as nearest neighbors, so we collect the labels of all
    4 of them.
  3. Finally, predict the label of your test point (X_test) by majority count.

    In our case, suppose 3 out of the 4 training instances have the same label. Then the majority count
    assigns that label to the new data point.

K defines the number of neighbors that the algorithm collects when making a prediction for your new data.
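To see how the choice of K changes the outcome, consider a tiny hypothetical example (the points and labels below are made up purely for illustration):

```python
import math
from collections import Counter

X_train = [[0, 0], [1, 0], [1, 1]]   # toy training points
y_train = ['x', 'y', 'y']            # toy labels
query = [0.1, 0]

def predict(k):
    # sort training points by distance to the query, then vote among the first k
    dists = sorted((math.dist(p, query), lbl) for p, lbl in zip(X_train, y_train))
    votes = [lbl for _, lbl in dists[:k]]
    return Counter(votes).most_common(1)[0][0]

print(predict(1))  # 'x' -- the single nearest point decides
print(predict(3))  # 'y' -- the two 'y' points outvote it
```

With k=1 the prediction follows the single closest point; with k=3 the majority of the neighborhood wins, which is why K is the main tuning knob of the algorithm.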

Example

Suppose we have a data set of fruits with the features 'weight' and 'height', and the labels 'Apple', 'Peach', and 'Pear', as follows:

In [2]:
from IPython.display import HTML, display
import tabulate
table = [["Height(cm)","Weight(gm)","Class"],
         [16, 160, 'Pear'],
         [14, 159, 'Pear'],
         [7, 90, 'Apple'],
         [8, 95, 'Apple'],
         [15, 165, 'Pear'],
         [4, 150, 'Peach'],
         [5, 145, 'Peach'],
         [7.5, 100, 'Apple'],
         [6, 144, 'Peach'],
        ]
display(HTML(tabulate.tabulate(table, tablefmt='html')))
Height(cm) Weight(gm) Class
16 160 Pear
14 159 Pear
7 90 Apple
8 95 Apple
15 165 Pear
4 150 Peach
5 145 Peach
7.5 100 Apple
6 144 Peach

Now, suppose we have a machine that classifies fruit based on 'weight' and 'height' information. Let's assume a new entry has the following features:

     Height : 10 cm
     Weight : 170 gm
     

Which class does this new fruit belong to?

Now, let us compute a distance measure between this new point and each of the training data points. For example, for the first training point:

sqrt((10-16)^2 + (170-160)^2) = 11.66

In [24]:
import math
def distance(p1, p2):
    # squared Euclidean distance between two (height, weight) points
    sq_distance = (p1[0] - p2[0])**2 + (p1[1] - p2[1])**2
    # Euclidean distance, rounded up to the nearest integer
    return math.ceil(sq_distance**(1/2))
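As a quick sanity check of the hand-computed value above (the helper is repeated here so the snippet runs on its own):

```python
import math

def distance(p1, p2):
    sq_distance = (p1[0] - p2[0])**2 + (p1[1] - p2[1])**2
    return math.ceil(sq_distance**(1/2))

# sqrt((10-16)^2 + (170-160)^2) = sqrt(136) ≈ 11.66, which math.ceil rounds up to 12
print(distance([16, 160], [10, 170]))  # 12
```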
In [33]:
table = [["Height(cm)","Weight(gm)","Class", "Distance"],
         [16, 160, 'Pear', distance([16, 160], [10, 170])],
         [14, 159, 'Pear', distance([14, 159], [10, 170])],
         [7, 90, 'Apple', distance([7, 90], [10, 170])],
         [8, 95, 'Apple', distance([8, 95], [10, 170])],
         [15, 165, 'Pear', distance([15, 165], [10, 170])],
         [4, 150, 'Peach', distance([4, 150], [10, 170])],
         [5, 145, 'Peach', distance([5, 145], [10, 170])],
         [7.5, 100, 'Apple', distance([7.5, 100], [10, 170])],
         [6, 144, 'Peach', distance([6, 144], [10, 170])],
        ]
display(HTML(tabulate.tabulate(table, tablefmt='html')))
Height(cm) Weight(gm) Class Distance
16 160 Pear 12
14 159 Pear 12
7 90 Apple 81
8 95 Apple 76
15 165 Pear 8
4 150 Peach 21
5 145 Peach 26
7.5 100 Apple 71
6 144 Peach 27

From the result table, let's select the k=3 neighbors having the minimal distance (nearest points):

 a. 15 	165 	Pear 	      8
 b. 16 	160 	Pear 	      12
 c. 14 	159 	Pear 	      12

So the labels of the selected neighbors are: ['Pear', 'Pear', 'Pear']

With a majority vote, we classify the new fruit as 'Pear'.
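Putting the whole example together, the pipeline can be checked end to end in plain Python (same fruit data as above, using exact Euclidean distances without the rounding, which does not change the ranking):

```python
import math
from collections import Counter

fruits = [
    (16, 160, 'Pear'), (14, 159, 'Pear'), (7, 90, 'Apple'),
    (8, 95, 'Apple'), (15, 165, 'Pear'), (4, 150, 'Peach'),
    (5, 145, 'Peach'), (7.5, 100, 'Apple'), (6, 144, 'Peach'),
]
new_point = (10, 170)

# distance from the new point to every training fruit, sorted nearest-first
dists = sorted(
    (math.hypot(h - new_point[0], w - new_point[1]), label)
    for h, w, label in fruits
)
top3 = [label for _, label in dists[:3]]
prediction = Counter(top3).most_common(1)[0][0]
print(top3, prediction)  # ['Pear', 'Pear', 'Pear'] Pear
```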
