
Batch Normalization, Overfitting, Dropout, and Optimization

by JK from Korea 2022. 12. 16.


 

Date : 2022.10.16

 

*The contents of this post are heavily based on Stanford University's CS231n course.

 

[Batch Normalization]

In the previous post, we explored various methods of weight initialization. The purpose of weight initialization was to keep the activation values of each layer well spread out across all nodes.

 

Batch normalization spreads out the activation values directly, by normalizing them inside the network, instead of relying on weight initialization. Its benefits are as follows.

 

  1. Speeds Up Learning (a larger learning rate can be used)
  2. Decreases Reliance on Weight Initialization
  3. Reduces Overfitting

 

A network with batch normalization layers inserted between its layers looks like the following.

 

[Batch Normalization between Layers]
[Batch Normalization Steps]

The process above transforms each mini-batch so that it has a mean of 0 and a variance of 1, which is the core goal of batch normalization. Each batch normalization layer then scales and shifts the normalized data.
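For reference, here are the standard batch normalization steps written out for a mini-batch B = {x_1, ..., x_m}. This is the textbook formulation, and it may differ in small details from the figure. ε is a small constant that prevents division by zero. γ and β are the learned scale and shift parameters.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{aligned}
```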

 

𝛾 is the scale parameter and 𝛽 is the shift parameter. They start at 1 and 0, respectively, and are adjusted as learning proceeds.
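Here is a rough sketch of the training-time forward pass in plain Python, with no libraries. The function name batch_norm_forward and the eps value are illustrative choices of my own. The sketch normalizes each feature (column) over the mini-batch, scales it by gamma, and shifts it by beta. gamma starts at 1 and beta at 0. At test time, real implementations usually substitute running averages of the mean and variance collected during training, and this sketch leaves that part out.

```python
import math


def batch_norm_forward(x, gamma, beta, eps=1e-7):
    """Training-mode batch normalization (illustrative sketch).

    x:     list of m samples, each a list of d feature values
    gamma: list of d scale parameters (start at 1)
    beta:  list of d shift parameters (start at 0)
    Returns an m x d list where each feature is normalized over the batch,
    then scaled by gamma and shifted by beta.
    """
    m, d = len(x), len(x[0])
    out = [[0.0] * d for _ in range(m)]
    for j in range(d):
        column = [x[i][j] for i in range(m)]
        mu = sum(column) / m                           # batch mean
        var = sum((v - mu) ** 2 for v in column) / m   # batch variance
        std = math.sqrt(var + eps)                     # eps avoids division by zero
        for i in range(m):
            x_hat = (column[i] - mu) / std             # mean 0, variance ~1
            out[i][j] = gamma[j] * x_hat + beta[j]     # scale, then shift
    return out


if __name__ == "__main__":
    batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
    d = len(batch[0])
    gamma = [1.0] * d   # scale parameter starts at 1
    beta = [0.0] * d    # shift parameter starts at 0
    for row in batch_norm_forward(batch, gamma, beta):
        print([round(v, 3) for v in row])
```

Both columns come out nearly identical even though the second is ten times larger than the first. That scale independence is what reduces the network's sensitivity to the initial weights.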

 

[Computational Graph for Batch Normalization]

 

[Batch Normalization Test]

Using MNIST, let's compare how quickly the network learns with and without batch normalization.

 

[Batch Normalization Increases the Learning Rate]
[Comparison for Different Weights]

The W above each graph is the standard deviation of the initial weights. The results across these scenarios show that the initial weight distribution strongly affects the learning curve. With batch normalization, however, the learning curves remain stable across the different weight initializations. You can find the code on my GitHub page.

 

[Overfitting]

Overfitting occurs when a network adapts so closely to its training data that it performs poorly on data it has not seen before. To reduce overfitting, we are going to implement weight decay and dropout.

 

[When does overfitting occur?]

 

  1. Complex Model (many layers and parameters)
  2. Insufficient Training Data
  3. Uneven Distribution of, or Outliers in, the Weight Values

 

[Weight Decay]

If some weights grow noticeably larger than the rest, the model may overfit to particular training data. Each weight controls how strongly the network responds to a specific characteristic of the input, so a disproportionately large weight lets one characteristic dominate.

 

In simple terms, weight decay adds a penalty for weights that stand out by being relatively large, which pushes them back toward smaller values during training.
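A minimal sketch of this penalty follows, assuming the common L2 form in which 0.5 * λ * Σw² is added to the loss, so each weight's gradient gains an extra λ * w term. The function names and the flat list of weights are hypothetical simplifications for illustration.

```python
def l2_penalty(weights, weight_decay_lambda):
    """L2 weight decay term added to the data loss: 0.5 * lambda * sum(w^2)."""
    return 0.5 * weight_decay_lambda * sum(w * w for w in weights)


def sgd_step_with_weight_decay(weights, grads, lr, weight_decay_lambda):
    """One SGD update; the penalty's gradient (lambda * w) is added to each weight's gradient."""
    return [w - lr * (g + weight_decay_lambda * w) for w, g in zip(weights, grads)]


if __name__ == "__main__":
    weights = [0.1, -0.2, 3.0]   # the last weight "stands out"
    grads = [0.0, 0.0, 0.0]      # zero data gradient, to isolate the effect of the penalty
    print(l2_penalty(weights, weight_decay_lambda=0.1))
    for _ in range(10):
        weights = sgd_step_with_weight_decay(weights, grads, lr=0.1, weight_decay_lambda=0.1)
    # Every weight shrinks by the same factor per step,
    # so the largest one loses the most in absolute terms.
    print(weights)
```

The larger λ is, the harder the large weights are pulled toward zero. λ is one of the hyperparameters searched for later in this post.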

 

[Simply Set the Weight Decay Variable]
[Accuracy Comparison]

[Dropout]

Dropout randomly selects nodes in the network and temporarily deletes them during training. Following the book, I created a Trainer class that handles the training process as a separate object. The results with dropout are shown below.
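A minimal sketch of a mask-based dropout layer in plain Python follows. The class name Dropout and the parameters dropout_ratio and train_flg are illustrative. At test time the sketch keeps every node and scales the outputs by (1 - dropout_ratio), so their expected size matches training. Many libraries instead use "inverted" dropout, which scales during training. Both are valid designs.

```python
import random


class Dropout:
    """Mask-based dropout (illustrative sketch, not a library API)."""

    def __init__(self, dropout_ratio=0.5, seed=None):
        self.dropout_ratio = dropout_ratio
        self.rng = random.Random(seed)
        self.mask = None

    def forward(self, x, train_flg=True):
        if train_flg:
            # True = keep this node, False = temporarily delete it for this pass
            self.mask = [self.rng.random() > self.dropout_ratio for _ in x]
            return [v if keep else 0.0 for v, keep in zip(x, self.mask)]
        # Test time: keep every node, scale to match the expected training output
        return [v * (1.0 - self.dropout_ratio) for v in x]

    def backward(self, dout):
        # Gradients flow only through the nodes kept in the last training forward()
        return [g if keep else 0.0 for g, keep in zip(dout, self.mask)]


if __name__ == "__main__":
    layer = Dropout(dropout_ratio=0.5, seed=42)
    print(layer.forward([1.0, 2.0, 3.0, 4.0]))                   # each entry kept or zeroed at random
    print(layer.forward([1.0, 2.0, 3.0, 4.0], train_flg=False))  # [0.5, 1.0, 1.5, 2.0]
```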

 

[Dropout = True]

The gap between the training and test accuracy has narrowed.

 

[Finding Appropriate Hyperparameters]

Hyperparameters are values we must set ourselves before training. They include, but are not limited to, the number of nodes per layer, the batch size, the learning rate, and the weight decay coefficient. These values have no fixed correct answer, so we must find good values through trial and error, although intuition and previous studies can narrow down the options for us.

 

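Before testing hyperparameters, we need to split the data into three groups. If we tuned the hyperparameters on the test data, they would overfit to it, and the test accuracy would no longer reflect performance on unseen data.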

 

  1. Validation Data: Used to evaluate hyperparameter choices
  2. Training Data: Used to learn the weights
  3. Test Data: Used at the end to check the network's accuracy on unseen data

 

The shuffle function randomly mixes the data, and we then separate out as much of it as we want for validation.

 

[Data Separation]
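A minimal sketch of this separation in plain Python follows. The function names shuffle_dataset and split_validation and the 20% validation rate are illustrative assumptions, not necessarily what the original code uses. Inputs and labels are shuffled with the same permutation so that each pair stays aligned. Only the training data is split, and the test data stays untouched until the end.

```python
import random


def shuffle_dataset(x, t, seed=None):
    """Shuffle inputs x and labels t with the same permutation so pairs stay aligned."""
    idx = list(range(len(x)))
    random.Random(seed).shuffle(idx)
    return [x[i] for i in idx], [t[i] for i in idx]


def split_validation(x, t, validation_rate=0.2, seed=None):
    """Shuffle the training data, then hold out the first validation_rate fraction as validation data."""
    x, t = shuffle_dataset(x, t, seed)
    n_val = int(len(x) * validation_rate)
    return (x[n_val:], t[n_val:]), (x[:n_val], t[:n_val])


if __name__ == "__main__":
    x = list(range(10))              # stand-in inputs
    t = [v * 10 for v in x]          # stand-in labels paired with x
    (x_train, t_train), (x_val, t_val) = split_validation(x, t, validation_rate=0.2, seed=0)
    print(len(x_train), len(x_val))  # 8 2
```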

Hyperparameter optimization takes a long time because of the number of combinations to try. To speed up the process, we can reduce the number of epochs and the size of the data to some degree. The code is on GitHub, and the results are presented below.
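A random search loop along these lines can be sketched as follows. The learning rate and weight decay are sampled on a log scale, because good values can differ by orders of magnitude. The sampling ranges (10^-6 ~ 10^-2 for the learning rate, 10^-8 ~ 10^-4 for weight decay) are assumptions. toy_evaluate is a hypothetical stand-in for "train for a few epochs and return the validation accuracy". It returns a made-up score, not a real accuracy.

```python
import math
import random


def sample_hyperparams(rng):
    """Draw one candidate on a log scale (the ranges are assumptions)."""
    return {
        "lr": 10 ** rng.uniform(-6, -2),
        "weight_decay": 10 ** rng.uniform(-8, -4),
    }


def random_search(evaluate, n_trials=100, seed=None):
    """Try n_trials random candidates; return (score, params) pairs sorted best-first."""
    rng = random.Random(seed)
    results = []
    for _ in range(n_trials):
        params = sample_hyperparams(rng)
        score = evaluate(**params)  # e.g. validation accuracy after a short training run
        results.append((score, params))
    results.sort(key=lambda r: r[0], reverse=True)
    return results


def toy_evaluate(lr, weight_decay):
    """HYPOTHETICAL stand-in for real training: a made-up smooth score, not accuracy."""
    return -abs(math.log10(lr) + 3) - 0.1 * abs(math.log10(weight_decay) + 6)


if __name__ == "__main__":
    for score, params in random_search(toy_evaluate, n_trials=100, seed=0)[:5]:
        print(f"score={score:.3f} lr={params['lr']:.2e} weight_decay={params['weight_decay']:.2e}")
```

In practice, `evaluate` would train on the (reduced) training data for a few epochs and return the validation accuracy, never the test accuracy.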

 

[Testing]
[Results]

 

It is reasonable to say that the top 5 or 6 trials perform better than the others. From these results, we can narrow the learning rate to 0.005 ~ 0.009 and the weight decay coefficient to 10^-8 ~ 10^-5. From there, we keep narrowing the ranges with further rounds of search until the accuracy is reasonably high, and then choose the final values.
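The narrowing step can be sketched as follows. Take the range each hyperparameter spans among the top-k trials, then sample the next round log-uniformly inside that range. The function names and the top_k value are hypothetical. The trial values in the usage example are made-up placeholders, not the results shown in this post.

```python
import math
import random


def narrow_ranges(results, top_k=5):
    """Given (score, params) trials sorted best-first, return the (low, high)
    range each hyperparameter spans among the top_k trials."""
    top = [params for _, params in results[:top_k]]
    return {name: (min(p[name] for p in top), max(p[name] for p in top))
            for name in top[0]}


def sample_log_uniform(rng, low, high):
    """Sample between low and high on a log scale, for the next search round."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))


if __name__ == "__main__":
    # Made-up placeholder trials for illustration only.
    trials = [
        (0.92, {"lr": 0.007, "weight_decay": 1e-6}),
        (0.91, {"lr": 0.005, "weight_decay": 3e-7}),
        (0.40, {"lr": 1e-5, "weight_decay": 1e-4}),
    ]
    ranges = narrow_ranges(trials, top_k=2)
    print(ranges)  # {'lr': (0.005, 0.007), 'weight_decay': (3e-07, 1e-06)}
    rng = random.Random(0)
    lo, hi = ranges["lr"]
    print(sample_log_uniform(rng, lo, hi))  # a new learning rate inside the narrowed range
```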

 
