1. Business Understanding¶
As the case intro states, extreme weather conditions can negatively affect the reliability and capacity of public transportation by potentially damaging the transport infrastructure. This leads to increased costs - not only for maintaining the infrastructure and transport vehicles, but also through unexpected delays and reduced public transport reliability.
If we can predict such events, we can take precautions - planning costs in advance, rescheduling in advance, or choosing (or recommending) alternative transportation methods.
As a side note, public transport disruptions do not affect only the government-provided transportation service - they can also serve as a reference, a signal to all drivers and passengers that there are problems on the road.
2. Data Understanding¶
Data collection
The data used in this project consists of the following:
- Dubai Weather data, provided in the case details
- Dubai public bus transportation data, downloaded from the Dubai Pulse website. This includes all 640 files available between 01-01-2018 and 16-03-2020. This period is selected so that it fully covers the period of the weather data.
Data processing
The weather data needs to be cleaned of redundant columns. Visualizations will be made in order to look for dependencies, cyclic occurrences and relationships between features.
The bus transportation data needs to be downloaded, read and processed - most importantly, to identify the time difference between two consecutive rides of the same bus line. The information aggregated by date will be joined to the weather data, and machine learning models will be applied in order to find relationships between weather conditions and transport duration disruptions. It is assumed that the weather conditions mostly affect the bus ride duration, therefore the median time interval between bus rides will be evaluated. The median is chosen over the mean because it is a more reliable metric when dealing with non-normal data with outliers.
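As a rough illustration of why the median is preferred - a minimal sketch with made-up interval values, not the actual data:
import pandas as pd
# hypothetical intervals (seconds) between consecutive rides of one line on a single day
intervals = pd.Series([600, 660, 720, 640, 86000])  # one outlier: a line that runs once per day
print(intervals.mean())    # ~17,724 s - dominated by the single outlier
print(intervals.median())  # 660 s - a much better "typical" interval for that day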
Other public transportation methods are out of scope.
Approach limitations
- Aggregated vs. granular data
One of the major limitations of the current approach is that the public transportation data was aggregated on a daily (instead of an hourly) basis and therefore lacks granularity.
- Additional data is needed
Additional data could have been helpful - for example, historical infrastructure damage, government spending on road reconstruction by month, public transport rescheduling, etc. Also, weather forecasts for the broader region, as well as traffic estimates, could improve the predictions and provide valuable information.
- Transportation type
Another limitation is that the current project is based on bus transportation data only, while other transport types could also be taken into account. It was assumed that bus transportation is the type most heavily impacted by weather conditions, as well as by other factors (e.g. other drivers on the road), therefore it was preferred over the other transportation types.
- Seasonal and hourly specifics have not been taken into account
Seasonality has not been incorporated into the modelling part. Information about public holidays and hourly segments (day, night, peak, off-peak) could also improve the model results - a sketch of how such calendar features could be derived is shown below.
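For illustration only (not implemented in this project), such calendar features could be derived from the hourly timestamp created later in section 3 - a minimal sketch with hypothetical peak-hour windows:
import pandas as pd
def add_calendar_features(df):
    # month and hour capture seasonal and intraday cycles
    df['month'] = df['timestamp'].dt.month
    df['hour'] = df['timestamp'].dt.hour
    # hypothetical peak windows - real ones should come from RTA schedules
    df['is_peak'] = df['hour'].isin([7, 8, 9, 17, 18, 19]).astype(int)
    # Friday/Saturday was the UAE weekend during this period
    df['is_weekend'] = df['timestamp'].dt.dayofweek.isin([4, 5]).astype(int)
    return df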
3. Data Preparation¶
3.1. Import libraries¶
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import itertools
from urllib.request import urlopen
from pandas.io.json import json_normalize
from datetime import datetime, timedelta
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (10, 5)
%matplotlib inline
3.2. Load raw data¶
3.2.1. Load weather data¶
response = urlopen("https://datacases.s3.us-east-2.amazonaws.com/datathon-2020/Ernst+and+Young/Dubai+Weather_20180101_20200316.txt")
rawdata = json.loads(response.read().decode('utf-8', 'replace'))
df = json_normalize(rawdata)
#weather = pd.DataFrame([y for x in df['weather'].values.tolist() for y in x])
for i in range(len(df.weather)):
    if len(df.weather[i]) > 1:
        print(f"{i}, {len(df.weather[i])}")
There are 3 cases where an hourly record has more than one weather condition recorded. Only the first one will be kept.
for i in range(len(df.weather)):
    df.weather[i] = df.weather[i][0]
condition = json_normalize(df.weather)
df = pd.concat([df, condition], axis=1)
data = df.copy()
data.head(3)
Creating timestamp:
data = data.reset_index()
for i in range(len(data.dt)):
    data.loc[data.index==i, 'timestamp'] = datetime.fromtimestamp(data.dt[i])
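As a side note, this row-by-row conversion could likely be replaced with a single vectorized call - a sketch; note that pd.to_datetime with unit='s' interprets the Unix timestamps as UTC, while datetime.fromtimestamp uses the local timezone:
# hypothetical vectorized alternative to the loop above (UTC-based)
data['timestamp'] = pd.to_datetime(data['dt'], unit='s')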
data = data.drop(['index', 'dt', 'dt_iso', 'weather', 'id', 'icon', 'city_name',
                  'lat', 'lon', 'timezone'], axis=1)
There are missing values in both rain columns - they will be infilled with zeros.
data['rain.1h'].fillna(0, inplace=True)
data['rain.3h'].fillna(0, inplace=True)
3.2.2. Data Exploration¶
3.2.2.1. Data types¶
data.info()
data.rename(columns={'main.temp':'temp', 'main.temp_min':'temp_min', 'main.temp_max':'temp_max', 'main.feels_like': 'temp_feel',
'main.pressure':'pressure', 'main.humidity':'humidity', 'clouds.all':'cloudiness',
'wind.speed':'wind_speed', 'wind.deg':'wind_deg', 'rain.1h':'rain_1h', 'rain.3h':'rain_3h',
'main':'condition'}, inplace=True)
data.head()
data.describe()
3.2.2.2. Data visualization¶
f = plt.figure(figsize=(20,25));
gs = f.add_gridspec(4, 2)
ax1 = f.add_subplot(gs[0, 0])
sns.boxplot(data=data, x='condition', y='temp')
plt.title('Temperature Boxplots by Weather Condition')
ax2 = f.add_subplot(gs[0, 1])
sns.boxplot(data=data, x='condition', y='pressure')
plt.title('Atm. Pressure (hPa) by Weather Condition')
ax3 = f.add_subplot(gs[1, 0])
sns.boxplot(data=data, x='condition', y='humidity')
plt.title('Humidity (%) by Weather Condition')
ax4 = f.add_subplot(gs[1, 1])
sns.boxplot(data=data, x='condition', y='cloudiness')
plt.title('Cloudiness (%) by Weather Condition')
ax5 = f.add_subplot(gs[2, 0])
sns.boxplot(data=data, x='condition', y='wind_speed')
plt.title('Wind Speed (m/s) by Weather Condition')
ax6 = f.add_subplot(gs[2, 1])
sns.boxplot(data=data, x='condition', y='wind_deg')
plt.title('Wind Direction (degrees) by Weather Condition')
ax7 = f.add_subplot(gs[3, 0])
sns.boxplot(data=data, x='condition', y='rain_1h')
plt.title('Rain volume in mm (last 1 hour) by Weather Condition')
ax8 = f.add_subplot(gs[3, 1])
sns.boxplot(data=data, x='condition', y='rain_3h')
plt.title('Rain volume in mm (last 3 hours) by Weather Condition')
plt.show();
Several findings:
- The temperature is highest during the Dust weather condition.
- When it is Clear, the average temperature is relatively high, but so is the range. The highest temperatures (37+ degrees) were recorded only in clear, cloudy or dusty conditions, while the lowest occurred in foggy and misty weather.
- The lowest pressure is recorded in dusty weather.
- Humidity is low during Clear, Cloudy, Dust and Smoke conditions; highest during Fog, Mist and Haze.
- Wind speed is high during Dust, Thunderstorm and Smoke conditions.
- No rain volume was recorded during thunderstorms.
plt.rcParams['figure.figsize'] = (10, 5)
plt.title('Temperature throughout the period');
sns.lineplot(data=data, x='timestamp', y='temp_min', color='b');
sns.lineplot(data=data, x='timestamp', y='temp', color='yellow');
sns.lineplot(data=data, x='timestamp', y='temp_max', color='r');
sns.lineplot(data=data, x='timestamp', y='temp_feel', color='g');
plt.legend(['min','avg','max','feels']);
All temperature features move closely together, with the highest temperatures recorded in July and August and the lowest in January and February.
sns.lineplot(data=data, x='timestamp', y='pressure', color='c')
plt.title('Atmospheric pressure');
Generally, the atmospheric pressure varies within tight intervals on a monthly basis, but shows a clear seasonal character, being highest in January and lowest in July. A correlation plot between the features will be shown later and will likely highlight a negative correlation between atmospheric pressure and temperature.
sns.lineplot(data=data, x='timestamp', y='humidity', color='orange')
plt.title('Humidity');
Although the volatility of humidity varies greatly, here we can also observe some kind of cyclic behaviour, with the highest humidity recorded in January and the lowest in May-June.
sns.lineplot(data=data, x='timestamp', y='cloudiness', color='c')
plt.title('Cloudiness');
It is slightly more difficult to spot patterns in the dynamics of the cloudiness percentage, but it looks somewhat lower in May and June, as well as in parts of February, August and October.
f=plt.figure(figsize=(20,5))
gs = f.add_gridspec(1,2)
f.add_subplot(gs[0,0])
sns.lineplot(data=data, x='timestamp', y='wind_speed', color='darkcyan')
plt.title('Wind speed');
f.add_subplot(gs[0,1])
sns.lineplot(data=data, x='timestamp', y='wind_deg', color='darkolivegreen')
plt.title('Wind direction (High=N,W; Low=E,S)');
sns.lineplot(data=data, x='timestamp', y='rain_1h', color='y')
sns.lineplot(data=data, x='timestamp', y='rain_3h', color='r')
plt.title('Rain volume (mm) by date')
plt.legend(['last 1 hour', 'last 3 hours']);
Due to the specifics of Dubai's climate, there is not much rain to observe. Data for "the last 3 hours" does not exist before the end of 2019 - it has either not been observed or not been collected. There is some sort of seasonality, although not very well established:
- in 2018 the highest rain volumes were recorded in March and October.
- in 2019 rain was recorded in February, April and May, with high volumes in November and December.
- in 2020 the highest figures are in January.
From June to September there is typically no rain.
sns.lineplot(data=data, x='condition', y='rain_1h', color='y');
sns.lineplot(data=data, x='condition', y='rain_3h', color='r');
plt.title('Rain volume by weather condition');
plt.legend(['last 1 hour','last 3 hours']);
corrmatrix = data.iloc[:,:9].corr()
plt.figure(figsize=(10,6))
ax = sns.heatmap(corrmatrix, vmin=-1, vmax=1, annot=True,);
bottom, top = ax.get_ylim()
ax.set_ylim(bottom + 0.5, top - 0.5)
plt.title('Correlation matrix')
plt.show();
Several observations:
- The strong negative correlation between temperature and pressure mentioned above is confirmed. Many years ago Gay-Lussac claimed the opposite, but his law applies only to closed systems with constant gas volume (Source: https://en.wikipedia.org/wiki/Gay-Lussac%27s_law#Pressure-temperature_law)
- There is a fairly strong negative correlation between temperature and humidity. However, this correlation becomes somewhat weaker when calculated between the perceived ('feels like') temperature and humidity
- There is some evidence of a negative relationship between wind speed and humidity - probably the wind carries away the small water droplets in the air
- There is a positive correlation between wind speed and wind direction in degrees - higher speeds tend to come with higher degree values, so winds from the North and West more often have higher speeds than those from the South and East (Source: https://snowfence.umn.edu/Components/winddirectionanddegrees.htm)
f = plt.figure(figsize=(20,6))
gs = f.add_gridspec(1,2)
f.add_subplot(gs[0,0])
ax1 = sns.heatmap(data.loc[data.condition=='Clear','temp':'cloudiness'].corr(), vmin=-1, vmax=1, annot=True, cmap='YlGnBu');
bottom, top = ax1.get_ylim()
ax1.set_ylim(bottom + 0.5, top - 0.5)
plt.title('Correlation matrix - Clear condition')
f.add_subplot(gs[0,1])
ax2 = sns.heatmap(data.loc[data.condition=='Thunderstorm','temp':'cloudiness'].corr(), vmin=-1, vmax=1, annot=True, cmap='YlGnBu');
bottom, top = ax2.get_ylim()
ax2.set_ylim(bottom + 0.5, top - 0.5)
plt.title('Correlation matrix - Thunderstorms')
plt.show();
In order to catch what specific changes happen during a thunderstorm, we can compare the correlation matrices for clear weather and for thunderstorms. Here are the main differences:
- In clear weather, there is a strong negative correlation between temperature and atmospheric pressure. During thunderstorms, this correlation almost disappears.
- In clear weather, there is a weak positive correlation between temperature and wind speed; during thunderstorms, there is none. The perceived ('feels like') temperature, however, does not correlate with wind speed in clear weather, but correlates negatively with both wind speed and wind direction (degrees) during thunderstorms - meaning that lower perceived temperatures come with winds mainly from the North and West. This is also associated with higher humidity and a higher cloudiness percentage.
f = plt.figure(figsize=(20,6))
gs = f.add_gridspec(1,2)
f.add_subplot(gs[0,0])
sns.lineplot(data=data, x='condition', y='pressure')
plt.title('Atmospheric pressure (hPa) by weather condition');
gs = f.add_gridspec(1,2)
f.add_subplot(gs[0,1])
sns.lineplot(data=data, x='condition', y='humidity');
plt.title('Humidity (%) by weather condition');
plt.show();
The two charts above show some interesting evidence - both atmospheric pressure and humidity are high during thunderstorms. This could be explored further, to find out whether a thunderstorm can be predicted from the dynamics of pressure and humidity, and probably wind speed/direction. However, this analysis will not be implemented in the current project.
data.tail()
3.2.3. Load public transportation data¶
The case description does not specify which particular type of public transportation should be analyzed, but due to the tight timeline the best approach is to focus on bus transportation - assuming that it is the type most affected by weather conditions. Moreover, any bus transportation issues could also signal general road transportation problems - something that the other public transportation types cannot hint at.
The Dubai Pulse website API seems to be inaccessible to non-UAE citizens, therefore the only way to acquire the data is to download each of the 640 bus transportation CSV files since 1st January 2018 separately. To reduce the manual effort, a Selenium webdriver is set up to download the files one by one, with a generous 10-second wait between download requests to avoid overloading the server. This took about 5 hours and required 60 GB of storage.
from selenium import webdriver
import time
url = 'https://www.dubaipulse.gov.ae/data/allresource/rta-bus/rta_bus_ridership-open?organisation=rta&service=rta-bus&count=1212&result_per_page=1300&as_sfid=AAAAAAVQxFA0BFeFROVV-_FUrwIfaEqwRWpoZA-y-UptqSEqxmERCKYLhWrwqWh3AfDCDdi1moQM5yS3Qjy2NzBMeMFf3DsQYwQOBarG4FRgrDCOeBE9L_Tq7J9m8CMoTDSCXIY%3D&as_fid=ff49229e06fa994326e53390b91e89d1dc5e2954'
driver = webdriver.Firefox(executable_path=r'D:\Library\env\geckodriver.exe')
driver.get(url)
The easiest way to find each 'Download' button is to get all such buttons by their XPath:
results = driver.find_elements_by_xpath('//div[@class="file-action col-sm-12 col-xs-5"]')
print('Number of results', len(results))
There are 1212 bus transportation files uploaded, but some of them belong to 2017. Below is the code that downloads the files - it was executed in several separate runs, using the parameter i as the starting index.
# the code below is commented out in the submitted version of this file
# i = 571
# for result in results[i:]:
#     result.click()
#     time.sleep(10)
3.2.4. Explore sample bus transportation data - proof of concept¶
The purpose of this section is to set up a proof of concept for later use - explore one of the bus transportation files, observe the data and find potential ways to extract the information that could be useful for the task. The biggest challenge is to find the most efficient way to process 600+ files, bearing in mind that some of them contain more than a million records. This makes a row-by-row looping approach undesirable, because iterating through all the files would probably run past the Datathon article submission deadline!
csv = pd.read_csv(r"D:\Library\Datathon\2020\Bus\bus_ridership_2018-01-16_00-00-00.csv")
def timestamping(date, time):
    # combine separate date ('YYYY-MM-DD') and time ('HH:MM:SS') strings into a datetime object
    datesplit = date.split('-')
    timesplit = time.split(':')
    timestamp = datetime(int(datesplit[0]), int(datesplit[1]), int(datesplit[2]),
                         int(timesplit[0]), int(timesplit[1]), int(timesplit[2]))
    return timestamp
start = datetime.now()
csv['timestamp'] = csv.apply(lambda x: timestamping(x['txn_date'],x['txn_time']), axis=1)
print(f"time: {datetime.now() - start}")
Almost a minute to compute the timestamps for the sample file is not great, but it is definitely better than iterating through the rows one by one.
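As a side note, a vectorized conversion would likely be much faster - a minimal sketch, assuming txn_date and txn_time are 'YYYY-MM-DD' and 'HH:MM:SS' strings:
# hypothetical vectorized alternative to the apply-based timestamping above
csv['timestamp'] = pd.to_datetime(csv['txn_date'] + ' ' + csv['txn_time'])
The next task is to sort the data and calculate the time difference between consecutive transactions.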
csv = csv.sort_values(by=['route_name','end_location','timestamp']).reset_index()
start = datetime.now()
csv['diff_prev'] = csv.timestamp.diff()
print(f"time: {datetime.now() - start}")
csv.head()
The initial approach was to find each transaction recorded 30 seconds or more after the previous one (or before the following one) and assign an incrementing group value to each. However, this process did not complete within half an hour for just one file, so unfortunately the approach was discarded. A quicker one was to identify the first occurrence of each group and put it into a new data frame.
csv2 = csv.loc[(csv.diff_prev > np.timedelta64(30,'s')) | (csv.diff_prev < np.timedelta64(0,'s')),:]
len(csv)
len(csv2)
From the initial 785k records in the file, about 75% of the data was discarded and not used in the project. Now we can calculate the duration between rides.
csv2['duration'] = (abs(csv2.timestamp.diff(periods=-1)) / np.timedelta64(1,'s'))
The last record has no duration, therefore it needs manual calculation - the timestamp of the last record in the initial (detailed) data minus its own timestamp.
csv2.iloc[-1,-1] = (csv.iloc[-1,-2] - csv2.iloc[-1,-3]) / np.timedelta64(1,'s')
csv2.head()
For the purposes of the project I will only use aggregated data from the bus transportation files.
csv2.duration.describe()
As expected, the data is non-normal and highly right-skewed. The mean duration is three times the median, meaning it is biased by the outliers. The maximum duration (difference between two rides of a single bus line) is almost 24 hours - probably corresponding to bus lines that run once per day. The standard deviation is also high - almost 3 hours. The most meaningful piece of the descriptive statistics is the median, approximately 17 minutes - the median time difference between consecutive rides of a single bus line.
sns.distplot(csv2.duration)
plt.title('Ride duration histogram (16-01-2018)');
3.2.5. Read all bus transportation data¶
Now that the proof of concept is complete, we can apply it to all the files. This is a huge effort, requiring more than 7 hours of processing time.
ride_duration = pd.DataFrame(columns=['date','count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max'])
# this cell was rerun as blank at the end, before the submission, as the output was too long and irrelevant to the article
files = []
src = "D:/Library/Datathon/2020/Bus/"
i=1
with os.scandir(src) as files:
    for file in files:
        print(f"Processing file: {i}, the time is: {datetime.now()}")
        csv = pd.read_csv(src + file.name)
        # sort values, create timestamp, calculate time difference (in seconds) between rides of each bus line
        csv['timestamp'] = csv.apply(lambda x: timestamping(x['txn_date'],x['txn_time']), axis=1)
        csv = csv.sort_values(by=['route_name','end_location','timestamp']).reset_index()
        csv['diff_prev'] = csv.timestamp.diff()
        # get the first record in each ride, on each stop, for each bus line
        csv2 = csv.loc[(csv.diff_prev > np.timedelta64(30,'s')) | (csv.diff_prev < np.timedelta64(0,'s')),:].copy()
        # calculate ride durations
        csv2['duration'] = (abs(csv2.timestamp.diff(periods=-1)) / np.timedelta64(1,'s'))
        # get date and duration descriptive statistics from each provided file
        filedate = file.name.split('_')[2]
        datetime_file = datetime(int(filedate.split('-')[0]), int(filedate.split('-')[1]), int(filedate.split('-')[2]))
        a = []
        b = []
        b.append(csv2.duration.describe().values)
        b = list(itertools.chain.from_iterable(b))
        a.append([datetime_file, b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
        # save the date and duration descriptive statistics from each file to the ride_duration data frame
        ride_duration = pd.concat([ride_duration, pd.DataFrame(a,
                                   columns=['date','count', 'mean', 'std', 'min', '25%', '50%',
                                            '75%', 'max'])])
        i += 1
ride_duration_backup = ride_duration.copy()
ride_duration.head()
sns.lineplot(data=ride_duration, x='date', y='50%');
plt.ylabel('median');
plt.title('Median bus ride duration by date');
plt.show();
3.2.6. Combine weather and public transportation data¶
data['date'] = pd.to_datetime(data.timestamp).apply(lambda x: x.date())
ride_duration['date'] = pd.to_datetime(ride_duration.date).apply(lambda x: x.date())
combined_data = data.join(ride_duration[['date','50%','std']].rename(columns={'50%':'median'}).set_index('date'), on='date')
combined_data.info()
combined_data.loc[combined_data['median'].isna()].date.unique()
Unfortunately, bus transportation statistics are indeed not provided for these dates on the Dubai Pulse website. Moreover, the missing dates account for 4007 records in the weather dataset, which is about 20% of the total. Dropping such a big part of the data is not recommended, so the better approach is to fill the NULLs with the median values for the respective month and thus keep the impact (if any) of the seasonal fluctuations.
combined_data['month'] = combined_data.date.apply(lambda x: str(x)[:7])
monthly_agg = combined_data.groupby('month')[['median','std']].median()
combined_data = combined_data.join(monthly_agg.rename(columns={'median':'monthly_median', 'std':'monthly_std'}), on='month')
combined_data.set_index('timestamp', inplace=True)
for i in combined_data.loc[combined_data['median'].isna()].index:
    combined_data.loc[combined_data.index == i, 'median'] = combined_data.loc[combined_data.index == i, 'monthly_median']
    combined_data.loc[combined_data.index == i, 'std'] = combined_data.loc[combined_data.index == i, 'monthly_std']
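As a side note, the same imputation could likely be done without an explicit loop - a minimal sketch using the helper columns created above:
# hypothetical vectorized equivalent of the loop above
combined_data['median'] = combined_data['median'].fillna(combined_data['monthly_median'])
combined_data['std'] = combined_data['std'].fillna(combined_data['monthly_std'])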
combined_data.info()
combined_data.tail(10)
As only the last 3 records have no value for condition, and the several before them have the value 'Clear', we can fill the last 3 with 'Clear' as well. Several columns, including the description, will not be needed and can be dropped.
combined_data.condition.fillna('Clear', inplace=True)
final_data = combined_data.drop(['date', 'description', 'month', 'monthly_median', 'monthly_std'], axis=1)
final_data.head()
4. Modelling¶
4.1. Pre-processing¶
X = final_data.iloc[:,:-2]
y = final_data.iloc[:,-2]
The 'condition' column is a categorical one, therefore needs to be one-hot encoded.
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
cond_labelenc = LabelEncoder()
X.iloc[:,-1] = cond_labelenc.fit_transform(X.iloc[:,-1])
X.head(2)
from sklearn.compose import ColumnTransformer
ct = ColumnTransformer([('one_hot_encoder', OneHotEncoder(categories='auto'), [11])], remainder='passthrough')
X = ct.fit_transform(X)
X[0]
y = np.array(y)
sns.distplot(y);
The distribution of the dependent variable is not normal, so we can improve the models' performance by transforming it with log(1+y).
y = np.log1p(y)
Now we can split the data into training and testing sets.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
The metric that is going to be used to evaluate the models is the mean absolute error. There will also be a visual representation of the predictions.
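For reference, the mean absolute error is MAE = (1/n) * Σ|y_i - ŷ_i|. In the evaluations below it is computed on the back-transformed (expm1) values, so the error is expressed in seconds of ride duration.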
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import cross_val_score
4.2. Regression models¶
The following models will be put into action:
- Multiple Linear Regression
- Random Forest Regression
- Support Vector Regression
- Extra Gradient Boosting Regressor
4.2.1. Multiple Linear Regression¶
from sklearn.linear_model import LinearRegression
regressor1 = LinearRegression()
regressor1.fit(X_train, y_train)
y_pred1 = np.expm1(regressor1.predict(X_test))
MAE1 = mean_absolute_error(np.expm1(y_test), y_pred1)
print(f"Mean absolute error: {MAE1}")
4.2.2. Random Forest Regression¶
Here we will also apply a grid search to find the best parameters for the model.
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV
parameters = {'n_estimators': [2, 3, 5, 10, 15, 20, 50, 100, 200, 500, 750, 1000],
'max_leaf_nodes': [3, 5, 10, 20, 35, 50],
'random_state': [0]}
grid_search = GridSearchCV(estimator = RandomForestRegressor(),
param_grid = parameters,
scoring = 'neg_mean_absolute_error',
cv = 5,
n_jobs = -1)
grid_search = grid_search.fit(X_train, y_train)
print(f"Best MAE: {grid_search.best_score_ * (-1)}")
print(f"Best parameters: {grid_search.best_params_}")
regressor2 = RandomForestRegressor(n_estimators=200, max_leaf_nodes=5, random_state=0)
regressor2.fit(X_train, y_train)
y_pred2 = np.expm1(regressor2.predict(X_test))
MAE2 = mean_absolute_error(np.expm1(y_test), y_pred2)
print(f"Mean absolute error: {MAE2}")
4.2.3. Support Vector Regression¶
from sklearn.svm import SVR
The Support Vector models work best if the data is scaled first.
from sklearn.preprocessing import StandardScaler
sc_X = StandardScaler()
sc_y = StandardScaler()
sc_X_train = sc_X.fit_transform(X_train)
sc_y_train = sc_y.fit_transform(y_train.reshape(-1,1))
sc_X_test = sc_X.transform(X_test)
parameters = {'C': [0.5, 1, 5, 10, 20, 50, 100],
'kernel': ['rbf']}
grid_search = GridSearchCV(estimator = SVR(),
param_grid = parameters,
scoring = 'neg_mean_absolute_error',
cv = 5,
n_jobs = -1)
grid_search = grid_search.fit(sc_X_train, sc_y_train)
print(f"Best MAE: {grid_search.best_score_ * (-1)}")
print(f"Best parameters: {grid_search.best_params_}")
regressor3 = SVR(kernel = 'rbf', C = 20)
regressor3.fit(sc_X_train, sc_y_train)
y_pred3 = regressor3.predict(sc_X_test)
y_pred3 = np.expm1(sc_y.inverse_transform(y_pred3))
MAE3 = mean_absolute_error(np.expm1(y_test), y_pred3)
print(f"Mean absolute error: {MAE3}")
4.2.4. Extra Gradient Boosting (XGBoost) Regressor¶
from xgboost import XGBRegressor
parameters = {'base_score': [0.1, 0.5, 0.7, 1, 2],
'learning_rate': [0.01, 0.03, 0.1, 0.3],
'n_estimators': [20, 50, 100, 150, 200, 300, 400, 1000],
'max_depth': [3, 5, 7, 9]}
grid_search = GridSearchCV(estimator = XGBRegressor(),
param_grid = parameters,
scoring = 'neg_mean_absolute_error',
cv = 5,
n_jobs = -1)
#grid_search = grid_search.fit(X_train, y_train)
#print(f"Best MAE: {grid_search.best_score_ * (-1)}")
#print(f"Best parameters: {grid_search.best_params_}")
regressor4 = XGBRegressor(learning_rate=0.02, n_estimators=300, max_depth=9)
regressor4.fit(X_train, y_train)
y_pred4 = np.expm1(regressor4.predict(X_test))
MAE4 = mean_absolute_error(np.expm1(y_test), y_pred4)
print(f"Mean absolute error: {MAE4}")
5. Evaluation¶
5.1. Models summary¶
models_summary = {'Multiple Linear': MAE1, 'Random Forest': MAE2, 'SVR': MAE3, 'XGB': MAE4}
sns.barplot(x=list(models_summary.keys()), y=list(models_summary.values()), color='r')
plt.title("Mean absolute error by model (the lower, the better)");
The lowest mean absolute error is returned by the SVR model. However, the task is not as trivial as usual - we need to find not only the best-fitting model overall, but also the model that does the best job of predicting the outliers, because they are the ones causing the issue described in the business understanding section.
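As a rough check, the MAE could also be computed only on the upper tail of the test set (e.g. the top 10% of observed durations) - a minimal sketch:
# rough check (hypothetical): MAE restricted to the highest observed durations
threshold = np.quantile(np.expm1(y_test), 0.9)
tail_mask = np.expm1(y_test) >= threshold
for name, pred in [('Linear', y_pred1), ('RF', y_pred2), ('SVR', y_pred3), ('XGB', y_pred4)]:
    print(name, mean_absolute_error(np.expm1(y_test)[tail_mask], np.ravel(pred)[tail_mask]))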
plt.figure(figsize=(20,10))
sns.lineplot(data=y_test, color='r').set(yscale='log');
sns.lineplot(data=np.log1p(y_pred4), color='b').set(yscale='log');
sns.lineplot(data=np.log1p(y_pred3), color='y').set(yscale='log');
sns.lineplot(data=np.log1p(y_pred2), color='brown').set(yscale='log');
sns.lineplot(data=np.log1p(y_pred1), color='cyan').set(yscale='log');
plt.title('Models summary: predictions vs the test set')
plt.legend(['Test set','XGB prediction','SVR prediction','RF prediction','Linear prediction']);