1.3 Minibatch Gradient Descent
Sec. 1.1의 Linear regression에서는 3개의 데이터만 사용했다. 반면 Sec. 1.2의 Multivariable linear regression에서는 5개의 데이터를 사용했다. 복잡한 모델을 학습하기 위해서는 많은 양의 데이터가 필요하다. 10만, 100만, 1000만 단위의 데이터를 사용할수록 점점 정확한 모델을 얻을 수 있다.
그러나 이 많은 양의 데이터를 한 번에 학습시킬 수가 없다. 각 데이터마다 cost function을 계산하는 과정에서 많은 시간이 소요된다.
일부분의 데이터만을 가져와 학습하는 건 어떨까?
1.3.1 Minibatch Gradient Descent
아이디어는 다음과 같다. 전체 데이터를 여러 균일한 크기의 작은 데이터(minibatch)들로 나누어서 학습시키자!
각각의 minibatch에 대한 cost를 계산하여 학습시킨다. 모든 데이터를 사용하지 않기 위해 한 번의 계산마다 업데이트 주기가 빨라진다. 대신 잘못된 방향으로 업데이트를 할 가능성이 있다.
다음은 데이터를 minibatch로 쪼개는 데 사용되는 torch.utils.data
패키지이다.
- import Dataset : torch.utils.data.Dataset 클래스 상속
- __len__() : 이 dataset의 총 데이터 수를 반환
- __getitem__() : index를 받았을 때 이 위치에 있는 입출력 데이터 반환
- import DataLoader : torch.utils.data.DataLoader 클래스 상속
- batch_size : 각 minibatch의 크기. 보통 2의 제곱수로 설정한다.
- shuffle : Epoch마다 dataset을 섞어 데이터가 학습되는 순서를 바꾼다.
이제 minibatch를 사용해 데이터를 학습해보자.
- enumerate(dataloader) : minibatch의 index와 data를 받아 온다
- len(dataloader) : 한 epoch 당 minibatch의 개수
---
- Result :
Epoch 100/1000 Batch 1/3 hypothesis: tensor([4284.1865, 1546.3556]), W: tensor([30.0853, 35.0039, 59.9021]), b: 1.9362, Cost: 1.249685
Epoch 100/1000 Batch 2/3 hypothesis: tensor([3186.9971, 1987.5029]), W: tensor([30.0738, 34.9962, 59.8934]), b: 1.9358, Cost: 5.126473
Epoch 100/1000 Batch 3/3 hypothesis: 4405.82177734375, W: tensor([30.0664, 34.9934, 59.8866]), b: 1.9356, Cost: 0.675318
Epoch 200/1000 Batch 1/3 hypothesis: tensor([3185.6082, 1545.4385]), W: tensor([29.9730, 34.9332, 59.9984]), b: 1.9279, Cost: 0.281057
Epoch 200/1000 Batch 2/3 hypothesis: tensor([4284.7207, 4404.5107]), W: tensor([29.9759, 34.9347, 60.0017]), b: 1.9280, Cost: 0.158690
Epoch 200/1000 Batch 3/3 hypothesis: 1984.9007568359375, W: tensor([29.9762, 34.9352, 60.0019]), b: 1.9280, Cost: 0.009849
Epoch 300/1000 Batch 1/3 hypothesis: tensor([1545.3967, 3185.3601]), W: tensor([29.9675, 34.9334, 59.9988]), b: 1.9235, Cost: 0.143535
Epoch 300/1000 Batch 2/3 hypothesis: tensor([4404.2798, 1984.6947]), W: tensor([29.9713, 34.9354, 60.0020]), b: 1.9236, Cost: 0.305958
Epoch 300/1000 Batch 3/3 hypothesis: 4284.9013671875, W: tensor([29.9717, 34.9359, 60.0030]), b: 1.9236, Cost: 0.009728
Epoch 400/1000 Batch 1/3 hypothesis: tensor([4285.0542, 1545.4668]), W: tensor([29.9732, 34.9336, 60.0039]), b: 1.9192, Cost: 0.110418
Epoch 400/1000 Batch 2/3 hypothesis: tensor([1984.8416, 3185.6094]), W: tensor([29.9712, 34.9336, 60.0022]), b: 1.9191, Cost: 0.198222
Epoch 400/1000 Batch 3/3 hypothesis: 4404.58447265625, W: tensor([29.9750, 34.9350, 60.0056]), b: 1.9192, Cost: 0.172663
Epoch 500/1000 Batch 1/3 hypothesis: tensor([1545.4055, 1984.8169]), W: tensor([29.9734, 34.9331, 60.0016]), b: 1.9148, Cost: 0.098986
Epoch 500/1000 Batch 2/3 hypothesis: tensor([4284.8662, 4404.6455]), W: tensor([29.9753, 34.9340, 60.0037]), b: 1.9148, Cost: 0.071782
Epoch 500/1000 Batch 3/3 hypothesis: 3185.677001953125, W: tensor([29.9704, 34.9331, 59.9995]), b: 1.9147, Cost: 0.458332
Epoch 600/1000 Batch 1/3 hypothesis: tensor([1545.4696, 4285.0161]), W: tensor([29.9740, 34.9349, 60.0025]), b: 1.9104, Cost: 0.110394
Epoch 600/1000 Batch 2/3 hypothesis: tensor([4404.7349, 1984.8646]), W: tensor([29.9754, 34.9357, 60.0037]), b: 1.9104, Cost: 0.044312
Epoch 600/1000 Batch 3/3 hypothesis: 3185.68994140625, W: tensor([29.9704, 34.9347, 59.9995]), b: 1.9103, Cost: 0.476019
Epoch 700/1000 Batch 1/3 hypothesis: tensor([1984.9572, 1545.5155]), W: tensor([29.9746, 34.9362, 60.0044]), b: 1.9060, Cost: 0.133790
Epoch 700/1000 Batch 2/3 hypothesis: tensor([4404.8613, 3185.6833]), W: tensor([29.9728, 34.9360, 60.0029]), b: 1.9060, Cost: 0.243098
Epoch 700/1000 Batch 3/3 hypothesis: 4284.9697265625, W: tensor([29.9729, 34.9361, 60.0032]), b: 1.9060, Cost: 0.000916
Epoch 800/1000 Batch 1/3 hypothesis: tensor([4404.4023, 1984.7732]), W: tensor([29.9724, 34.9376, 60.0019]), b: 1.9019, Cost: 0.204317
Epoch 800/1000 Batch 2/3 hypothesis: tensor([4284.9487, 1545.4744]), W: tensor([29.9723, 34.9367, 60.0016]), b: 1.9018, Cost: 0.113825
Epoch 800/1000 Batch 3/3 hypothesis: 3185.511474609375, W: tensor([29.9686, 34.9360, 59.9984]), b: 1.9017, Cost: 0.261606
Epoch 900/1000 Batch 1/3 hypothesis: tensor([1984.8707, 4284.9219]), W: tensor([29.9721, 34.9380, 60.0023]), b: 1.8976, Cost: 0.011407
Epoch 900/1000 Batch 2/3 hypothesis: tensor([1545.4807, 3185.5295]), W: tensor([29.9699, 34.9366, 60.0002]), b: 1.8975, Cost: 0.255749
Epoch 900/1000 Batch 3/3 hypothesis: 4404.4716796875, W: tensor([29.9747, 34.9384, 60.0045]), b: 1.8976, Cost: 0.279122
Epoch 1000/1000 Batch 1/3 hypothesis: tensor([1545.5229, 4285.1377]), W: tensor([29.9742, 34.9369, 60.0034]), b: 1.8933, Cost: 0.146218
Epoch 1000/1000 Batch 2/3 hypothesis: tensor([3185.6277, 4404.7983]), W: tensor([29.9728, 34.9368, 60.0023]), b: 1.8932, Cost: 0.217328
Epoch 1000/1000 Batch 3/3 hypothesis: 1984.8743896484375, W: tensor([29.9732, 34.9374, 60.0026]), b: 1.8933, Cost: 0.015778