Linear Regression, the Basics
An important part of Artificial Intelligence is making predictions based on what an algorithm has learned in the past. From a high-level perspective, such an algorithm works as follows:
1. Take in a bunch of labeled data
2. Try to find a pattern that makes sense
3. Make predictions for other data points using this pattern
Linear regression is a method for finding such a pattern. It comes from classical statistics and works well when the relationship in the data is linear. This is easy to check by plotting the data points: intuitively, a straight line can be drawn that matches the trend in the data.
Other relationships between x and y are better described by different methods, such as logistic regression (categorical targets) or polynomial regression (curved relationships).
A practical example where linear regression does its magic is finding a relationship between different properties of houses and their actual market value. Some terminology:
Symbol | Terminology | Definition | Example |
---|---|---|---|
x | Feature | The input variable we know | Size of the house |
y | Target | The output variable we're trying to predict | Price of the house |
m | Number of training examples | Number of (x, y) pairs available for training | Dataset of known sizes and prices |
(x⁽ⁱ⁾, y⁽ⁱ⁾) | i-th training example | A particular pair from the data set | (120 m², $450,000) |
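To make this terminology concrete, here is a minimal sketch in Python; the sizes and prices are made-up example values:

```python
# Made-up training data: house sizes in m² (feature x) and
# selling prices in dollars (target y).
x_train = [45, 60, 85, 100, 120, 150]
y_train = [150_000, 210_000, 290_000, 330_000, 450_000, 520_000]

m = len(x_train)  # m: number of training examples
print(f"m = {m}")

# One particular training example (x⁽ⁱ⁾, y⁽ⁱ⁾); note Python indexes from 0
i = 4
print(f"(x⁽{i}⁾, y⁽{i}⁾) = ({x_train[i]} m², ${y_train[i]:,})")
```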
In supervised learning the training data has features and known outputs. In this example, it would be a long list of houses with their sizes and past selling price. If the output (selling price) is not known, then the method is called unsupervised learning. In real-world applications supervised learning is more common.
In the case of housing prices, an easy way to gain insight into the data is to create a scatter plot of house sizes (x) against house prices (y).
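With matplotlib this takes only a few lines; a sketch reusing the made-up data from above:

```python
import matplotlib.pyplot as plt

x_train = [45, 60, 85, 100, 120, 150]  # house sizes (m²)
y_train = [150_000, 210_000, 290_000, 330_000, 450_000, 520_000]  # prices ($)

plt.scatter(x_train, y_train, marker="x")
plt.xlabel("Size (m²)")
plt.ylabel("Price ($)")
plt.title("House sizes vs. prices")
plt.show()
```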
Now what are we trying to achieve? The goal of linear regression is to find a simple model, or function, that takes a new x and outputs a prediction called ŷ.
This model is simply a straight line, expressed by the function:

ŷ = f(x) = w·x + b

where w is the slope of the line and b the intercept: the point where the line crosses the y-axis, i.e. where x equals zero.
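As a sketch in Python, with arbitrary (not yet optimized) values for w and b:

```python
def predict(x, w, b):
    """The linear model: predicted price ŷ for a house of size x."""
    return w * x + b

# Arbitrary example parameters, purely for illustration
w, b = 3500.0, 20_000.0
print(predict(120, w, b))  # ŷ for a 120 m² house: 440000.0
```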
The challenge is to find an optimal line, i.e. a pair (w, b), for which the "vertical" distance to the data points is minimized. This means that given a known x, the predicted ŷ (the model) is close to the actual y.
This graph shows a rather good model. It only shows the "costs" of 9 points as an example, but if we were to calculate the total "cost" of all data points, the result would be low, perhaps even minimal. That would mean there is no other model with a lower total cost.
Although it is possible to just try many values of w and b by hand, we have computers and algorithms to do this work for us. So how do we know if some pair (w, b) is any good? How do we know if any better pairs exist?
The effectiveness of a (w, b) pair is measured with something called a Cost Function. AI people use the letter J for this function. It calculates the average squared difference between the predicted and the actual output over all training examples.
So: high J is bad, low J is good.
The mathematical formula for the Cost Function J is:

J(w, b) = (1 / 2m) · Σᵢ₌₁ᵐ (ŷ⁽ⁱ⁾ − y⁽ⁱ⁾)²   with   ŷ⁽ⁱ⁾ = w·x⁽ⁱ⁾ + b

Dividing by 2m instead of m is a convention that makes the derivative cleaner later on; it does not change where the minimum lies.
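A direct translation of this formula into Python might look like this:

```python
def compute_cost(x, y, w, b):
    """Cost J(w, b): the average squared prediction error over all m examples.

    The extra factor 1/2 is a convention that cancels neatly when
    taking the derivative; it does not change where the minimum lies.
    """
    m = len(x)
    total = 0.0
    for i in range(m):
        y_hat = w * x[i] + b           # prediction ŷ⁽ⁱ⁾ for example i
        total += (y_hat - y[i]) ** 2   # squared error for example i
    return total / (2 * m)
```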
Plotting different values of (w, b) against the Cost Function J creates a bowl-shaped 3D graph.
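Such a surface can be visualized with a sketch like the following (assuming numpy and matplotlib are available; the grid ranges are arbitrary choices for the made-up data):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.array([45, 60, 85, 100, 120, 150])  # made-up sizes (m²)
y = np.array([150_000, 210_000, 290_000, 330_000, 450_000, 520_000])

# Evaluate the cost J on a grid of (w, b) values
W, B = np.meshgrid(np.linspace(0, 7000, 80), np.linspace(-200_000, 300_000, 80))
J = np.zeros_like(W)
for r in range(W.shape[0]):
    for c in range(W.shape[1]):
        y_hat = W[r, c] * x + B[r, c]
        J[r, c] = np.mean((y_hat - y) ** 2) / 2

ax = plt.axes(projection="3d")
ax.plot_surface(W, B, J, cmap="viridis")
ax.set_xlabel("w")
ax.set_ylabel("b")
ax.set_zlabel("J(w, b)")
plt.show()
```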
Remember, to find the optimal model for our data, we need to find the minimum, or the bottom of the bowl. By using an algorithm called Gradient Descent we can start anywhere and work our way towards this bottom.
Basically, it works like this (a code sketch follows after the list):
1. Start on a random point on the 3D bowl
2. Calculate the derivative (slope) of the Cost Function J(w, b) at the current point
3. Take a step downwards into the bowl
4. Use these new coordinates as input again for step 2
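A minimal sketch of these steps in Python, assuming the squared-error cost defined above (the derivative formulas follow from differentiating J with respect to w and b):

```python
def gradient_descent(x, y, w=0.0, b=0.0, alpha=1e-4, iterations=10_000):
    """Walk down the cost surface J(w, b) step by step.

    alpha is the learning rate: it scales how big each step downwards is.
    """
    m = len(x)
    for _ in range(iterations):
        # Derivatives of J with respect to w and b (the local slope)
        dj_dw = sum((w * x[i] + b - y[i]) * x[i] for i in range(m)) / m
        dj_db = sum((w * x[i] + b - y[i]) for i in range(m)) / m
        # Step downwards: both parameters are updated simultaneously
        w -= alpha * dj_dw
        b -= alpha * dj_db
    return w, b
```

Note that both derivatives are computed before either parameter is updated, so w and b move simultaneously.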
As the coordinates are slowly moving down into the bowl, the algorithm comes closer to the minimum.
After a while, the slope at the current (w, b) pair approaches zero. Because each step is proportional to the slope, the steps also become smaller and smaller, until the algorithm converges.
Taking note of the values of the (w, b) pair at that point gives us the model for our scatter plot, thereby finalizing our Linear Regression method.
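Putting it all together on the made-up housing data (reusing predict and gradient_descent from the sketches above):

```python
x_train = [45, 60, 85, 100, 120, 150]
y_train = [150_000, 210_000, 290_000, 330_000, 450_000, 520_000]

# The unscaled feature makes convergence on b slow, hence many iterations;
# in practice, feature scaling would speed this up considerably.
w, b = gradient_descent(x_train, y_train, alpha=1e-4, iterations=100_000)
print(f"model: ŷ = {w:.1f}·x + {b:.1f}")
print(f"prediction for 110 m²: ${predict(110, w, b):,.0f}")
```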
A couple of considerations:
- The 3D-graph can be a simple bowl, but for other models and cost functions it can be a very irregular shape with many hills and valleys. In that case it is not guaranteed that the algorithm finds the global minimum; it may get stuck in a local minimum, because it always tries to go down, never up. (For linear regression with the squared-error cost this cannot happen: the surface is always a single convex bowl.)
- If the step downwards (controlled by the learning rate α) is too small, the algorithm is very slow
- If the step downwards is too big, it may overshoot the minimum and fail to converge, as the sketch below shows
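Both failure modes are easy to observe by training for a few steps with different learning rates (hypothetical values, reusing gradient_descent and compute_cost from the sketches above):

```python
# Too small, reasonable, and too large learning rates for this data set
for alpha in (1e-7, 1e-4, 2.5e-4):
    w, b = gradient_descent(x_train, y_train, alpha=alpha, iterations=50)
    cost = compute_cost(x_train, y_train, w, b)
    print(f"alpha = {alpha}: J = {cost:,.0f}")
# Typical outcome: the smallest alpha barely reduces J, the middle one
# reduces it a lot, and the largest one makes J explode (divergence).
```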