WEBVTT

00:00.630 --> 00:03.150
Narrator: Hello and welcome to this Python tutorial.

00:03.150 --> 00:05.130
All right, so now in the next function

00:05.130 --> 00:06.510
that we're about to implement,

00:06.510 --> 00:08.790
we will train the deep neural network

00:08.790 --> 00:11.460
that is inside our artificial intelligence.

00:11.460 --> 00:13.980
So basically we're gonna do the whole process

00:13.980 --> 00:16.920
of forward propagation and then back propagation.

00:16.920 --> 00:19.140
That is, we're gonna get our output,

00:19.140 --> 00:20.820
we're gonna get the target,

00:20.820 --> 00:22.560
we will compare the output to the target

00:22.560 --> 00:24.510
to compute the loss error.

00:24.510 --> 00:27.180
Then we're gonna back propagate this loss error

00:27.180 --> 00:28.620
into the neural network

00:28.620 --> 00:30.450
and using stochastic gradient descent,

00:30.450 --> 00:31.650
we will update the weights

00:31.650 --> 00:35.160
according to how much they contributed to the loss error.

00:35.160 --> 00:36.480
So let's do all this.

00:36.480 --> 00:39.030
For those of you coming from the deep learning course,

00:39.030 --> 00:40.410
this will be good stuff,

00:40.410 --> 00:42.120
but for the others, don't worry,

00:42.120 --> 00:44.130
I'm going to explain that again.

00:44.130 --> 00:47.670
So we're gonna call this new function learn.

00:47.670 --> 00:51.630
And this learn function is going to take several arguments.

00:51.630 --> 00:53.730
First self, of course,

00:53.730 --> 00:57.510
which will refer to the object of the DQN class.

00:57.510 --> 01:01.020
Then we're gonna take our batch state

01:01.020 --> 01:02.640
for the current state,

01:02.640 --> 01:07.640
then our batch next state

01:07.680 --> 01:11.250
then our batch reward,

01:11.250 --> 01:15.540
and finally our batch action.

01:15.540 --> 01:16.950
So why do we take this?

01:16.950 --> 01:19.860
You probably recognize what this series is.

01:19.860 --> 01:22.050
Well, that's of course a transition.

01:22.050 --> 01:25.440
A transition of the Markov decision process that is

01:25.440 --> 01:27.480
at the base of deep Q learning.

01:27.480 --> 01:30.540
And why do we take them in batches?

01:30.540 --> 01:32.160
Well, that's because, you know, remember,

01:32.160 --> 01:34.110
we don't consider the transitions one at a time,

01:34.110 --> 01:36.690
as single tuples of current state,

01:36.690 --> 01:39.660
next state, current reward, and current action.

01:39.660 --> 01:42.480
We created some sample batches here

01:42.480 --> 01:44.190
thanks to the sample function.

01:44.190 --> 01:46.830
And so now our transitions are in the form,

01:46.830 --> 01:48.510
a first batch for the state,

01:48.510 --> 01:50.550
a second batch for the next state,

01:50.550 --> 01:53.490
a batch for the reward, and a batch for the action.
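
NOTE A minimal sketch of how a sample function can turn stored transitions into the aligned batches described here. It assumes, hypothetically, that the memory is a Python list of (state, next state, action, reward) tuples whose tensors each carry a leading dimension of size 1; the names memory and sample are illustrative, not the course's exact code.
```python
import random
import torch
# Hypothetical replay memory of 10 transitions: (state, next_state, action, reward),
# every tensor having a leading dimension of size 1
memory = [
    (torch.rand(1, 5), torch.rand(1, 5), torch.tensor([i % 3]), torch.tensor([float(i)]))
    for i in range(10)
]
def sample(memory, batch_size):
    # zip(*...) regroups the sampled tuples field by field
    fields = zip(*random.sample(memory, batch_size))
    # Concatenating along dimension 0 keeps row k of every batch from the same transition
    return [torch.cat(field, 0) for field in fields]
batch_state, batch_next_state, batch_action, batch_reward = sample(memory, 4)
print(batch_state.shape, batch_action.shape)  # torch.Size([4, 5]) torch.Size([4])
```
Because every field is concatenated along dimension 0, row k of each batch comes from the same transition, which is the alignment in time mentioned above.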

01:53.490 --> 01:55.650
That's the form of our transitions now

01:55.650 --> 01:58.530
and they're all well aligned with respect to time

01:58.530 --> 02:01.470
thanks to this concatenation that we made here

02:01.470 --> 02:04.140
with respect to the first dimension.

02:04.140 --> 02:08.160
So the point is, now we have these batches of transitions.

02:08.160 --> 02:09.630
One batch for each of the states,

02:09.630 --> 02:11.640
next state, reward and action.

02:11.640 --> 02:12.510
And we do all this

02:12.510 --> 02:15.480
because we're using this experience replay trick

02:15.480 --> 02:18.570
so that our deep neural network can learn something.

02:18.570 --> 02:22.200
Remember, if we only had the transitions by themselves,

02:22.200 --> 02:24.240
well it would be some instant learning

02:24.240 --> 02:26.640
or, if you want, some very short-term memory learning

02:26.640 --> 02:29.130
and therefore the model wouldn't learn anything.

02:29.130 --> 02:32.820
So we have to take these batches from the memory

02:32.820 --> 02:34.830
which become our transitions

02:34.830 --> 02:37.200
and then eventually we will get the different outputs

02:37.200 --> 02:39.510
for each of the states of the input batch state.

02:39.510 --> 02:41.760
And we will do this for the batch state

02:41.760 --> 02:43.170
and for the batch next state

02:43.170 --> 02:45.660
because we will need both to compute the loss.

02:45.660 --> 02:48.630
I will soon remind you of the Bellman equation

02:48.630 --> 02:51.840
that is at the heart of the deep Q learning algorithm.

02:51.840 --> 02:53.550
So now let's go into the function

02:53.550 --> 02:57.150
and let's first get the outputs of the batch state.

02:57.150 --> 02:59.763
So I'm gonna call this first variable outputs.

03:00.660 --> 03:05.370
And then we're gonna take of course our self dot model.

03:05.370 --> 03:08.940
So self dot model,

03:08.940 --> 03:12.110
because we want to get our model outputs

03:12.110 --> 03:14.490
of the input states of the batch state.

03:14.490 --> 03:18.150
And since our model is actually expecting a batch

03:18.150 --> 03:19.260
of input states,

03:19.260 --> 03:23.820
well we can totally input batch state right now

03:23.820 --> 03:25.320
for the input of the model,

03:25.320 --> 03:28.230
that's exactly how we initialized these states

03:28.230 --> 03:29.820
that are going into the network

03:29.820 --> 03:33.930
with a torch tensor, with this fake dimension for the batch.
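
NOTE A small illustration of the batch dimension, using a hypothetical nn.Linear layer with 5 inputs and 3 outputs as a stand-in for the course's network: a single state gets a fake batch dimension of size 1 with unsqueeze(0), while the sampled batch state already has one row per state, so it can go straight into the model and comes out as one row of Q-values per state.
```python
import torch
import torch.nn as nn
model = nn.Linear(5, 3)                    # hypothetical stand-in: 5 state values, 3 actions
single_state = torch.rand(5)
batch_of_one = single_state.unsqueeze(0)   # shape (1, 5): fake batch dimension at index 0
batch_state = torch.rand(4, 5)             # shape (4, 5): four states sampled from memory
print(model(batch_of_one).shape)  # torch.Size([1, 3])
print(model(batch_state).shape)   # torch.Size([4, 3])
```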

03:33.930 --> 03:35.130
So that's perfect.

03:35.130 --> 03:37.860
We now get the outputs of the model.

03:37.860 --> 03:40.770
But then there is another technical trick.

03:40.770 --> 03:43.830
If we only do self dot model batch state

03:43.830 --> 03:45.870
well we will get the outputs

03:45.870 --> 03:47.820
of all the possible actions,

03:47.820 --> 03:49.500
you know, zero, one, and two.

03:49.500 --> 03:50.900
But that's not what we want.

03:50.900 --> 03:54.720
We are only interested in the actions that were chosen.

03:54.720 --> 03:56.130
The actions that were decided

03:56.130 --> 03:58.890
by the network to play at each time.

03:58.890 --> 04:01.350
And so to get these actions we're interested in,

04:01.350 --> 04:03.000
that is, the actions played,

04:03.000 --> 04:07.200
well we have to use this gather function

04:07.200 --> 04:09.060
in which we input one,

04:09.060 --> 04:12.150
the dimension of the actions, since we only want the action that was chosen

04:12.150 --> 04:15.540
and then we add batch action.

04:15.540 --> 04:18.180
With this, the one and the batch action,

04:18.180 --> 04:22.080
we will gather each time the output of the action that was played

04:22.080 --> 04:24.780
for each of the input states of the batch state.

04:24.780 --> 04:26.940
We only want the action that is played

04:26.940 --> 04:28.320
the action that is chosen.

04:28.320 --> 04:32.130
And we get this with this gather one and batch action.

04:32.130 --> 04:33.420
But then be careful,

04:33.420 --> 04:36.960
the batch state here has this fake dimension

04:36.960 --> 04:38.430
corresponding to the batch,

04:38.430 --> 04:40.170
and batch action doesn't have it.

04:40.170 --> 04:42.870
Batch state has it because we used the unsqueeze here,

04:42.870 --> 04:46.410
but we haven't used any unsqueeze for the actions.

04:46.410 --> 04:48.060
So we have to add it here

04:48.060 --> 04:51.450
so that the batch action has the same number of dimensions

04:51.450 --> 04:53.130
as the outputs for the batch state.

04:53.130 --> 04:58.100
So we're gonna add a dot unsqueeze zero right here.

04:59.130 --> 05:02.220
And actually this is not zero, but one,

05:02.220 --> 05:04.950
because zero corresponds to the fake dimension

05:04.950 --> 05:05.790
of the state

05:05.790 --> 05:06.900
and one will correspond

05:06.900 --> 05:09.360
to the fake dimension of the actions.
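
NOTE A toy example of gather with made-up Q-values for 4 states and 3 actions. gather(1, index) reads along dimension 1 (the actions) and requires index to have as many dimensions as the Q-value tensor, which is why batch_action is unsqueezed at index 1 to shape (4, 1).
```python
import torch
# Made-up Q-values: one row per state, one column per action (0, 1 and 2)
q_values = torch.tensor([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0],
                         [7.0, 8.0, 9.0],
                         [0.5, 0.6, 0.7]])
batch_action = torch.tensor([2, 0, 1, 1])   # shape (4,): action played in each state
index = batch_action.unsqueeze(1)           # shape (4, 1)
chosen = q_values.gather(1, index)          # row k keeps column batch_action[k]
print(chosen.shape)  # torch.Size([4, 1]), holding 3.0, 4.0, 8.0 and 0.6
```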

05:09.360 --> 05:12.480
And finally, the last thing we need to do here

05:12.480 --> 05:16.500
is to kill this fake dimension with a squeeze.

05:16.500 --> 05:18.120
And why do we need to do that?

05:18.120 --> 05:20.160
Because now we are out of the neural network,

05:20.160 --> 05:22.920
we have our outputs, but we don't want them in a batch.

05:22.920 --> 05:26.250
We want them in a simple tensor, that is,

05:26.250 --> 05:27.573
a vector of outputs.

05:28.943 --> 05:30.980
The batch is just when we work in the neural network

05:30.980 --> 05:32.940
because the neural network is expecting the format

05:32.940 --> 05:35.010
of tensors in a batch.

05:35.010 --> 05:36.450
But now we have our outputs

05:36.450 --> 05:39.480
and in the Bellman equation of deep Q learning,

05:39.480 --> 05:41.520
we won't need them in a batch.

05:41.520 --> 05:43.200
So I'm killing the batch here.

05:43.200 --> 05:44.970
I'm killing the fake dimension

05:44.970 --> 05:48.150
to get back the simple form of our outputs.

05:48.150 --> 05:51.510
So I'm just adding here dot and then squeeze.

05:51.510 --> 05:54.060
And then, since I want to kill the fake dimension

05:54.060 --> 05:56.220
corresponding to the batch of the action,

05:56.220 --> 05:59.490
well since this fake dimension has index one

05:59.490 --> 06:01.530
I'm adding one here.

06:01.530 --> 06:03.630
All right, and now there we go,

06:03.630 --> 06:05.490
we have our outputs.
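
NOTE Putting the pieces together, a sketch of the complete outputs line, again with a hypothetical nn.Linear stand-in for the network. squeeze(1) removes the size-1 dimension left by gather, giving one Q-value per transition.
```python
import torch
import torch.nn as nn
torch.manual_seed(0)
model = nn.Linear(5, 3)                     # hypothetical stand-in network
batch_state = torch.rand(4, 5)              # 4 sampled states
batch_action = torch.tensor([2, 0, 1, 1])   # action actually played in each state
# Q-values of all actions, then the played action's Q-value, then drop the extra dimension
outputs = model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
print(outputs.shape)  # torch.Size([4])
```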

06:05.490 --> 06:07.920
Okay, we have a little warning, what is it?

06:07.920 --> 06:11.160
Local variable outputs is assigned but never used.

06:11.160 --> 06:13.890
That's okay, we will use it very quickly.

06:13.890 --> 06:15.600
So that's our outputs.

06:15.600 --> 06:20.490
And now we want to get our next outputs.

06:20.490 --> 06:21.930
So now you might be thinking

06:21.930 --> 06:23.820
why do we need the next outputs?

06:23.820 --> 06:25.680
Well, to understand this, we need to go back

06:25.680 --> 06:28.920
to the deep Q learning algorithm, which is right here.

06:28.920 --> 06:31.830
That's a part of the AI handbook.

06:31.830 --> 06:33.840
So that's the whole deep Q learning process.

06:33.840 --> 06:36.750
At the beginning, we're initializing all the Q values.

06:36.750 --> 06:40.140
And then at each time T, well, there we go,

06:40.140 --> 06:42.270
we select the action with Softmax.

06:42.270 --> 06:44.850
That's what we did with the select action function.

06:44.850 --> 06:47.490
Then we append the transition and then,

06:47.490 --> 06:49.590
as you can see, we get the prediction,

06:49.590 --> 06:52.170
we get the target and we compute the loss.

06:52.170 --> 06:54.330
So why do we need the next outputs as well?

06:54.330 --> 06:56.160
That's because of the target.

06:56.160 --> 06:57.180
The target is equal

06:57.180 --> 07:01.470
to gamma times the next output plus the reward

07:01.470 --> 07:03.840
And we will compute the targets right after that.
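
NOTE For reference, the two quantities compared in the loss, written in standard deep Q-learning notation (the symbols are chosen for this note, not taken from the handbook): for each sampled transition i with state s_i, action a_i, reward r_i and next state s'_i, and with discount factor gamma,
```latex
\begin{aligned}
\text{prediction}_i &= Q(s_i, a_i) \\
\text{target}_i &= r_i + \gamma \max_{a} Q(s'_i, a)
\end{aligned}
```
The outputs computed so far are the predictions; the next outputs are the max term inside the target.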

07:03.840 --> 07:06.150
But since we need the next output for the target,

07:06.150 --> 07:07.800
let's compute this first.

07:07.800 --> 07:10.950
So again, to get the next output, that's very simple.

07:10.950 --> 07:13.170
The next output is going to be the result

07:13.170 --> 07:15.000
of our neural network

07:15.000 --> 07:19.140
when the batch next state is entering it as input.

07:19.140 --> 07:23.370
So very simply, we take our model,

07:23.370 --> 07:25.140
that is our neural network

07:25.140 --> 07:27.900
and then this time the input of the neural network

07:27.900 --> 07:30.693
is going to be the batch next state.

07:31.650 --> 07:33.150
The batch next state.

07:33.150 --> 07:34.080
But now remember,

07:34.080 --> 07:37.140
if we go back to the deep Q learning algorithm,

07:37.140 --> 07:40.500
well you can see that the next output

07:40.500 --> 07:43.800
is the maximum of the Q values for the next state

07:43.800 --> 07:45.690
with respect to all the actions.

07:45.690 --> 07:47.310
So right now, to get the next output

07:47.310 --> 07:50.190
we need to get the maximum of these Q values.

07:50.190 --> 07:54.300
But first I'm gonna do here a detach,

07:54.300 --> 07:57.690
to detach the outputs of the model from the computation graph,

07:57.690 --> 08:01.500
so that no gradient flows back through them, for every state in the batch next state.

08:01.500 --> 08:03.750
That's the batch of all the next states

08:03.750 --> 08:05.610
and all the transitions taken

08:05.610 --> 08:07.830
from the random sample of our memory.

08:07.830 --> 08:11.190
So I'm detaching all of them using the detach function.

08:11.190 --> 08:15.300
And then I'm taking the max of all these Q values.

08:15.300 --> 08:19.033
And since we are taking the maximum of these Q values

08:19.033 --> 08:19.866
with respect to the action,

08:19.866 --> 08:20.940
well we have to specify

08:20.940 --> 08:23.070
that it is with respect to the action.

08:23.070 --> 08:26.490
And since the action is represented by the index one

08:26.490 --> 08:29.820
well again, we have to put the index one here

08:29.820 --> 08:33.030
and then, since max gives back both the maximum Q values

08:33.030 --> 08:36.270
of S.T. plus one, the next state, and the indices of the best actions,

08:36.270 --> 08:39.330
we keep only the values, which come first, at index zero

08:39.330 --> 08:42.480
of the pair returned by max.

08:42.480 --> 08:45.570
And therefore here we need to add brackets

08:45.570 --> 08:47.760
with the index zero.

08:47.760 --> 08:51.030
That way we get the maximum of the Q values

08:51.030 --> 08:53.970
of the next state, selected by the index zero,

08:53.970 --> 08:55.590
according to all the actions

08:55.590 --> 08:58.170
that are represented by the index one.

08:58.170 --> 09:01.560
And now, perfect, we get our next outputs.
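
NOTE A sketch of the next outputs line with a hypothetical nn.Linear stand-in network. detach() cuts the result out of the computation graph, so no gradient flows through the target; max(1) returns a pair (values, indices) taken over dimension 1, the actions, and [0] keeps the values.
```python
import torch
import torch.nn as nn
torch.manual_seed(0)
model = nn.Linear(5, 3)                       # hypothetical stand-in network
batch_next_state = torch.rand(4, 5)           # 4 sampled next states
q_next = model(batch_next_state).detach()     # no gradient will flow through these values
values, indices = q_next.max(1)               # max over dimension 1, the actions
next_outputs = q_next.max(1)[0]               # [0] keeps the values, [1] would be the indices
print(next_outputs.shape, next_outputs.requires_grad)  # torch.Size([4]) False
```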

09:01.560 --> 09:02.880
These are not used yet,

09:02.880 --> 09:04.320
that's why we had the warning,

09:04.320 --> 09:05.153
but that's fine.

09:05.153 --> 09:08.430
We will use it right now to compute the target.

09:08.430 --> 09:09.840
And speaking of the target

09:09.840 --> 09:12.510
that's the next step of this learn function.

09:12.510 --> 09:13.343
So there we go,

09:13.343 --> 09:15.660
target equals...

09:15.660 --> 09:18.360
Now let's get back to our AI handbook.

09:18.360 --> 09:20.700
The target is equal to the reward

09:20.700 --> 09:23.810
plus gamma times the next output,

09:23.810 --> 09:25.920
that is the maximum of the Q values of the next state

09:25.920 --> 09:27.540
according to the actions.

09:27.540 --> 09:29.340
So there we go, let's compute that.

09:29.340 --> 09:33.390
So that is equal to self dot gamma.

09:33.390 --> 09:36.000
And self dot gamma was initialized here.

09:36.000 --> 09:39.870
It was introduced there, so it's a variable of our DQN object.

09:39.870 --> 09:43.503
Self dot gamma times the next output,

09:44.940 --> 09:48.270
as we just said, plus the reward.

09:48.270 --> 09:49.890
That is, the batch reward.

09:49.890 --> 09:51.660
We're working with batches here.

09:51.660 --> 09:55.830
So plus batch reward.

09:55.830 --> 09:59.760
And those are the targets for one sample of the memory.

09:59.760 --> 10:03.870
Gamma multiplied by the next outputs plus the reward.
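
NOTE A worked example of the target computation on made-up numbers, assuming a discount factor gamma of 0.9 (an illustrative value). The operations are element-wise, so each transition in the batch gets its own target.
```python
import torch
gamma = 0.9                                          # assumed discount factor
next_outputs = torch.tensor([1.0, 2.0, 0.0, 4.0])    # max over actions of Q(next state, a)
batch_reward = torch.tensor([0.5, -1.0, 0.0, 0.1])
target = gamma * next_outputs + batch_reward         # one target per transition
print(target)  # approximately 1.4, 0.8, 0.0 and 3.7
```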

10:03.870 --> 10:05.100
All right, perfect.

10:05.100 --> 10:08.850
So now we have our outputs, we also have our target

10:08.850 --> 10:11.160
and therefore we can compute the loss.

10:11.160 --> 10:14.460
The loss that is representing the error of the prediction.

10:14.460 --> 10:18.330
So let's call this loss TD loss.

10:18.330 --> 10:21.000
TD is of course for the temporal difference.

10:21.000 --> 10:23.250
That is again at the heart of Q learning.

10:23.250 --> 10:28.250
And this TD loss is going to be equal to the Huber loss.

10:28.500 --> 10:30.510
It improves Q learning a lot.

10:30.510 --> 10:32.370
That's the loss function we will choose

10:32.370 --> 10:34.770
for our artificial intelligence.

10:34.770 --> 10:36.870
For those of you coming from the deep learning course

10:36.870 --> 10:38.640
that's really the loss I recommend

10:38.640 --> 10:40.830
if you want to implement deep Q learning.

10:40.830 --> 10:43.470
And so how are we gonna get this Huber loss?

10:43.470 --> 10:46.020
Well, again, we're gonna take a function

10:46.020 --> 10:49.860
from the functional module, represented by F,

10:49.860 --> 10:52.140
and therefore here I'm going to use

10:52.140 --> 10:55.500
our functional module F dot

10:55.500 --> 10:57.360
and the Huber loss can be obtained

10:57.360 --> 11:01.680
thanks to the function smooth L one loss.

11:01.680 --> 11:03.810
That one. So pressing enter.

11:03.810 --> 11:06.720
And that's really the best loss function I recommend

11:06.720 --> 11:07.890
for deep Q learning.

11:07.890 --> 11:09.753
It really improves deep Q learning,

11:10.737 --> 11:12.810
but this is a function, so I'm adding some parentheses

11:12.810 --> 11:14.820
and now there is nothing more simple.

11:14.820 --> 11:16.650
The arguments we need to input

11:16.650 --> 11:19.380
are the predictions and the targets.

11:19.380 --> 11:22.140
So the prediction is of course our outputs,

11:22.140 --> 11:24.150
because that's the output of the neural network.

11:24.150 --> 11:25.860
You know, the output of the neural network

11:25.860 --> 11:27.690
is what the neural network predicts.

11:27.690 --> 11:29.220
So that's the prediction.

11:29.220 --> 11:32.460
So the first argument here is outputs.

11:32.460 --> 11:36.060
And then the second argument is of course the target,

11:36.060 --> 11:38.160
the thing we're trying to get,

11:38.160 --> 11:39.480
and it's already computed.

11:39.480 --> 11:42.839
Perfect, we can directly input target.

11:42.839 --> 11:45.450
Perfect. And now we have the loss.

11:45.450 --> 11:48.210
Just forgot a little T here.

11:48.210 --> 11:50.880
There we go, now the warning should disappear.

11:50.880 --> 11:52.140
Yes, perfect.
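
NOTE A quick check, on made-up tensors, that F.smooth_l1_loss with its default beta of 1.0 matches the Huber loss written by hand: 0.5 * d**2 when the absolute error d is below 1, d - 0.5 otherwise, averaged over the batch (the default reduction is the mean).
```python
import torch
import torch.nn.functional as F
outputs = torch.tensor([1.0, 2.0, 3.0])   # predictions
target = torch.tensor([1.5, 4.0, 3.0])    # targets
td_loss = F.smooth_l1_loss(outputs, target)
d = (outputs - target).abs()
manual = torch.where(d < 1, 0.5 * d ** 2, d - 0.5).mean()
print(td_loss.item(), manual.item())  # both equal (0.125 + 1.5 + 0.0) / 3
```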

11:52.140 --> 11:53.850
And now that we have the loss error

11:53.850 --> 11:55.920
we can back propagate this error

11:55.920 --> 11:58.680
back into the network to update the weights

11:58.680 --> 12:00.450
with stochastic gradient descent.

12:00.450 --> 12:03.480
And that's exactly what we're gonna do in the next step.

12:03.480 --> 12:06.180
So of course, now what we have to do, as you might guess,

12:06.180 --> 12:10.170
is take our optimizer,

12:10.170 --> 12:13.770
our optimizer, which, again, we introduced here.

12:13.770 --> 12:15.000
We initialized it.

12:15.000 --> 12:17.340
And that's an Adam optimizer

12:17.340 --> 12:20.160
which is actually an object of the Adam class.

12:20.160 --> 12:23.790
And it was already given the parameters of our model.

12:23.790 --> 12:27.810
And we already chose a learning rate of 0.1%.
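
NOTE A sketch of the optimizer setup described here, with a hypothetical nn.Linear stand-in for the model. Passing model.parameters() is what ties the optimizer to the network's weights, and lr=0.001 is the 0.1% learning rate.
```python
import torch.nn as nn
import torch.optim as optim
model = nn.Linear(5, 3)   # hypothetical stand-in for the network
# The optimizer receives the model's parameters, so step() updates exactly these weights
optimizer = optim.Adam(model.parameters(), lr=0.001)   # 0.001 is the 0.1% learning rate
print(optimizer.param_groups[0]["lr"])  # 0.001
```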

12:27.810 --> 12:30.300
So perfect, our optimizer is ready

12:30.300 --> 12:33.780
but now we need to apply it on the loss error

12:33.780 --> 12:35.970
to perform stochastic gradient descent

12:35.970 --> 12:37.170
and update the weights.

12:37.170 --> 12:38.610
So when working with PyTorch,

12:38.610 --> 12:42.180
the first thing we need to do is reset the gradients

12:42.180 --> 12:44.160
at each iteration of the loop.

12:44.160 --> 12:47.460
We must reset the gradients of the optimizer's parameters

12:47.460 --> 12:49.350
from one iteration to the other

12:49.350 --> 12:52.020
in the loop of stochastic gradient descent, because PyTorch accumulates them.

12:52.020 --> 12:55.140
And to reset them at each iteration of the loop

12:55.140 --> 12:57.540
well we're gonna use the following method

12:57.540 --> 13:00.390
which is zero grad, here we go.

13:00.390 --> 13:03.510
Zero grad will reset the gradients

13:03.510 --> 13:05.190
at each iteration of the loop.

13:05.190 --> 13:07.470
Then let's not forget the parentheses.
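
NOTE A minimal demonstration of why the gradients must be reset: PyTorch adds each new gradient to the .grad already stored, so two backward passes without zero_grad() give double the gradient. The single weight w and the SGD optimizer are illustrative only.
```python
import torch
w = torch.tensor([2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)
for _ in range(2):
    loss = (3 * w).sum()   # d(loss)/dw = 3
    loss.backward()        # adds 3 to w.grad each time
accumulated = w.grad.clone()
print(accumulated)  # tensor([6.])
optimizer.zero_grad()
# Recent PyTorch versions set the gradient to None here; older ones fill it with zeros
print(w.grad)
```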

13:07.470 --> 13:10.830
Perfect. And now that the gradients are reset,

13:10.830 --> 13:15.180
well we can perform back propagation.

13:15.180 --> 13:16.500
And how do we do that?

13:16.500 --> 13:18.480
Well we take our TD loss

13:18.480 --> 13:22.110
and we are going to back propagate it back into the network.

13:22.110 --> 13:24.450
And to back propagate it into the network

13:24.450 --> 13:28.170
we need to use the backward function.

13:28.170 --> 13:29.610
And inside this backward function,

13:29.610 --> 13:34.380
Here you can input retain underscore variables,

13:34.380 --> 13:37.230
named retain underscore graph in current PyTorch, set equal to true.

13:37.230 --> 13:38.220
Here is what this does.

13:38.220 --> 13:41.160
By default, back propagation frees the computation graph

13:41.160 --> 13:43.350
to save memory. Retain graph equals true

13:43.350 --> 13:46.170
keeps the graph in memory instead,

13:46.170 --> 13:49.170
which we need only if we go several times through the same loss.

13:49.170 --> 13:52.890
Since each call to learn builds a new graph, it's optional here.
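
NOTE A small experiment on retain_graph (the current name of the backward argument): with retain_graph=True the graph survives the backward pass, so a second backward works; after a backward without it, the saved buffers are freed and another backward raises a RuntimeError. The tensor w is illustrative only.
```python
import torch
w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = w.exp().sum()               # exp keeps its result for the backward pass
loss.backward(retain_graph=True)   # graph kept, so the next call is allowed
loss.backward()                    # graph freed after this call
try:
    loss.backward()                # the saved buffers are gone now
    third_call_failed = False
except RuntimeError:
    third_call_failed = True
print(third_call_failed)  # True
```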

13:52.890 --> 13:56.100
And finally, last step of this learn function

13:56.100 --> 13:57.870
is to update the weights according

13:57.870 --> 14:00.540
to the back propagation, that is according to how much

14:00.540 --> 14:02.790
the weights contributed to the error.

14:02.790 --> 14:07.080
And to do this, we take our optimizer again,

14:07.080 --> 14:10.170
which was initialized, and whose gradients we just reset,

14:10.170 --> 14:13.410
and we use the step function.

14:13.410 --> 14:15.810
And simply with this line of code

14:15.810 --> 14:19.470
by using this step function, this will update the weights.
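
NOTE To make the update concrete, a sketch with plain SGD swapped in for Adam (Adam rescales the gradient, so its arithmetic is less transparent): step() moves each weight by minus the learning rate times its gradient. The weight w and the loss are illustrative only.
```python
import torch
w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)   # plain SGD so the arithmetic stays visible
loss = (3 * w).sum()                        # d(loss)/dw = 3
optimizer.zero_grad()                       # reset gradients
loss.backward()                             # w.grad is now 3
optimizer.step()                            # w = 1.0 - 0.1 * 3
print(w.item())  # approximately 0.7
```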

14:19.470 --> 14:21.870
That's this line of code that updates the weights.

14:21.870 --> 14:24.180
This line of code back propagates

14:24.180 --> 14:26.190
the error into the neural network.

14:26.190 --> 14:28.890
And this line of code uses the optimizer

14:28.890 --> 14:30.420
to update the weights.

14:30.420 --> 14:31.470
And there we go.

14:31.470 --> 14:33.993
We have a learning neural network.
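
NOTE For reference, one way to assemble the whole learn function as a self-contained sketch. The Network class (a hypothetical hidden layer of 30 neurons), the constructor arguments and the returned loss value are assumptions for illustration; backward() is called without retain_graph because each call to learn builds a fresh graph.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Network(nn.Module):
    # Hypothetical architecture: one hidden layer of 30 neurons
    def __init__(self, input_size, nb_action):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 30)
        self.fc2 = nn.Linear(30, nb_action)
    def forward(self, state):
        return self.fc2(F.relu(self.fc1(state)))
class Dqn:
    def __init__(self, input_size, nb_action, gamma):
        self.gamma = gamma
        self.model = Network(input_size, nb_action)
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
    def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
        # Q(s, a) of the actions actually played
        outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
        # max over actions of Q(next state, a), detached so it acts as a fixed target
        next_outputs = self.model(batch_next_state).detach().max(1)[0]
        target = self.gamma * next_outputs + batch_reward
        td_loss = F.smooth_l1_loss(outputs, target)   # Huber loss
        self.optimizer.zero_grad()   # reset the gradients from the previous call
        td_loss.backward()           # back propagate the TD loss
        self.optimizer.step()        # update the weights with Adam
        return td_loss.item()        # returned only so the loss can be inspected
torch.manual_seed(0)
dqn = Dqn(input_size=5, nb_action=3, gamma=0.9)
loss = dqn.learn(torch.rand(8, 5), torch.rand(8, 5), torch.rand(8), torch.randint(0, 3, (8,)))
print(loss)
```
The last three lines run one training step on a random batch of 8 transitions, which is enough to check that shapes line up before plugging the function into the update step.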

14:35.160 --> 14:36.810
All right, so congratulations.

14:36.810 --> 14:40.530
This was probably the most technical and difficult part

14:40.530 --> 14:42.450
of all this deep Q learning model.

14:42.450 --> 14:44.490
I know PyTorch can be tricky sometimes

14:44.490 --> 14:46.830
with these unsqueeze and squeeze,

14:46.830 --> 14:48.420
but in the end I promise you will get

14:48.420 --> 14:50.970
a very functional neural network and

14:50.970 --> 14:52.500
therefore deep Q learning model

14:52.500 --> 14:56.550
and eventually a great artificial intelligence.

14:56.550 --> 14:58.740
So now let's move on to the next step

14:58.740 --> 15:00.450
of our deep Q learning model,

15:00.450 --> 15:02.520
which will be the update function

15:02.520 --> 15:04.500
that will update everything

15:04.500 --> 15:07.140
when the AI discovers a new state.

15:07.140 --> 15:09.150
So you know, it will discover the new state

15:09.150 --> 15:11.520
and then it will receive the reward

15:11.520 --> 15:13.650
depending on the action that it played

15:13.650 --> 15:15.150
and this new state.

15:15.150 --> 15:17.580
So we'll take care of this with the update function

15:17.580 --> 15:19.680
and we'll do this in the next tutorial.

15:19.680 --> 15:21.543
Until then, enjoy AI.
