recurrent.jl

Flux/layers/recurrent.jl is a source file in module Flux

			
			
			
			

	
			# Row indices of the `n`-th block of `h` consecutive rows, i.e. `(n-1)*h+1 : n*h`.
			function gate(h, n)
			    offset = h * (n - 1)
			    return (1:h) .+ offset
			end

			# View (no copy) of the `n`-th block of `h` entries of a vector.
			function gate(x::AbstractVector, h, n)
			    rows = gate(h, n)
			    return view(x, rows)
			end

			# View (no copy) of the `n`-th block of `h` rows of a matrix, keeping all columns.
			function gate(x::AbstractMatrix, h, n)
			    rows = gate(h, n)
			    return @view x[rows, :]
			end

AD-friendly helper for dividing monolithic RNN params into equally sized gates


			
			
			
			
			

	
			# Split `x` into `N` consecutive gates of height `h`, returned as a tuple of
			# `gate(x, h, k)` for `k = 1:N`.
			function multigate(x::AbstractArray, h, ::Val{N}) where {N}
			    return ntuple(N) do k
			        gate(x, h, k)
			    end
			end
			

			

			
			# Reverse rule for `multigate`: the pullback allocates one zeroed array shaped like
			# `x` and accumulates each gate's cotangent into the matching gate of that array.
			function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
			    gates = multigate(x, h, c)
			    function multigate_pullback(dy)
			        # Zero-filled buffer with a floating-point element type and the axes of `x`.
			        dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
			        for (dx_gate, dy_gate) in zip(multigate(dx, h, c), unthunk(dy))
			            # Gates whose cotangent is a structural zero contribute nothing.
			            dy_gate isa AbstractZero && continue
			            dx_gate .+= dy_gate
			        end
			        return (NoTangent(), dx, NoTangent(), NoTangent())
			    end
			    return gates, multigate_pullback
			end

Type-stable and AD-friendly helper for iterating over the last dimension of an array


			
			
			
			# Lazily iterate over views of `A` taken along its last dimension: the `k`-th
			# element is `view(A, :, ..., :, k)`.
			function eachlastdim(A::AbstractArray{T,N}) where {T,N}
			    leading = ntuple(_ -> Colon(), N - 1)  # one `:` per non-final dimension
			    return Base.Generator(k -> view(A, leading..., k), axes(A, N))
			end

Adapted from https://github.com/JuliaDiff/ChainRules.jl/blob/f13e0a45d10bb13f48d6208e9c9d5b4a52b96732/src/rulesets/Base/indexing.jl#L77


			
			
			
			# Assemble the cotangent of `x` from the per-slice cotangents `dys_raw` produced
			# for `eachlastdim(x)`. Slices whose cotangent is an `AbstractZero` become zeros.
			function ∇eachlastdim(dys_raw, x::AbstractArray{T, N}) where {T, N}
			    dys = unthunk(dys_raw)
			    first_array = findfirst(dy -> dy isa AbstractArray, dys)
			    if first_array === nothing
			        # No slice carries an array cotangent: the result is all zeros.
			        return fill!(similar(x, T, axes(x)), zero(T))
			    end
			    # A single `dx` array is allocated and filled slice by slice.
			    dx = similar(x, T, axes(x))::AbstractArray
			    for t in axes(x, N)
			        dx_t = selectdim(dx, N, t)
			        dy_t = dys[t]
			        if dy_t isa AbstractZero
			            fill!(dx_t, zero(eltype(dx_t)))
			        else
			            copyto!(dx_t, dy_t)
			        end
			    end
			    return ProjectTo(x)(dx)
			end
			

			

			
			# Reverse rule for `eachlastdim`: the primal is the collected vector of slices and
			# the pullback delegates to `∇eachlastdim`.
			function ChainRulesCore.rrule(::typeof(eachlastdim), x::AbstractArray{T,N}) where {T,N}
			    slices = collect(eachlastdim(x))
			    function lastdims(dy)
			        dx = ∇eachlastdim(unthunk(dy), x)
			        return (NoTangent(), dx)
			    end
			    return slices, lastdims
			end
			

			

			
			

	
			# Reshape `h` so that its trailing dimensions equal those of `x` after the first;
			# the leading dimension is inferred (`:`).
			function reshape_cell_output(h, x)
			    trailing_dims = size(x)[2:end]
			    return reshape(h, :, trailing_dims...)
			end

Stateful recurrence


			
			
			
			
			
			"""
			    Recur(cell)

			`Recur` wraps a recurrent cell and makes it stateful: the hidden state is kept in
			the `state` field and managed in the background. `cell` should be a model of the form:

			    h, y = cell(h, x...)

			For example, here's a recurrent network that keeps a running total of its inputs:

			# Examples
			```jldoctest
			julia> accum(h, x) = (h + x, x)
			accum (generic function with 1 method)

			julia> rnn = Flux.Recur(accum, 0)
			Recur(accum)

			julia> rnn(2)
			2

			julia> rnn(3)
			3

			julia> rnn.state
			5
			```

			Folding over a 3d Array of dimensions `(features, batch, time)` is also supported:

			```jldoctest
			julia> accum(h, x) = (h .+ x, x)
			accum (generic function with 1 method)

			julia> rnn = Flux.Recur(accum, zeros(Int, 1, 1))
			Recur(accum)

			julia> rnn([2])
			1-element Vector{Int64}:
			 2

			julia> rnn([3])
			1-element Vector{Int64}:
			 3

			julia> rnn.state
			1×1 Matrix{Int64}:
			 5

			julia> out = rnn(reshape(1:10, 1, 1, :));  # apply to a sequence of (features, batch, time)

			julia> out |> size
			(1, 1, 10)

			julia> vec(out)
			10-element Vector{Int64}:
			  1
			  2
			  3
			  4
			  5
			  6
			  7
			  8
			  9
			 10

			julia> rnn.state
			1×1 Matrix{Int64}:
			 60
			```
			"""
			mutable struct Recur{T,S}
			    cell::T    # the wrapped cell, called as `cell(state, x)`
			    state::S   # current hidden state; overwritten on every call
			end
			

			

			
			# Run one step: call the cell with the stored state and `x`, store the new state,
			# and return the cell's output.
			function (m::Recur)(x)
			    new_state, out = m.cell(m.state, x)
			    m.state = new_state
			    return out
			end
			

			

			
			@functor Recur

			# Only the wrapped cell is reported as trainable; `state` is left out.
			function trainable(a::Recur)
			    return (cell = a.cell,)
			end
			

			

			
			
			
			# Compact display: `Recur(<cell>)`.
			function Base.show(io::IO, m::Recur)
			    print(io, "Recur(")
			    print(io, m.cell)
			    print(io, ")")
			end
			

			

			
			
			
			"""
			    reset!(rnn)

			Reset the hidden state of a recurrent layer back to its original value.

			Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:

			    rnn.state = rnn.cell.state0

			# Examples
			```jldoctest
			julia> r = Flux.RNNCell(relu, ones(1,1), zeros(1,1), ones(1,1), zeros(1,1));  # users should use the RNN wrapper struct instead

			julia> y = Flux.Recur(r, ones(1,1));

			julia> y.state
			1×1 Matrix{Float64}:
			 1.0

			julia> y(ones(1,1))  # relu(1*1 + 1)
			1×1 Matrix{Float64}:
			 2.0

			julia> y.state
			1×1 Matrix{Float64}:
			 2.0

			julia> Flux.reset!(y)
			1×1 Matrix{Float64}:
			 0.0

			julia> y.state
			1×1 Matrix{Float64}:
			 0.0
			```
			"""
			function reset!(m::Recur)
			    # The assignment is the last expression, so the restored state is returned.
			    m.state = m.cell.state0
			end

			# Generic fallback: recurse into the children reported by `functor(m)`.
			function reset!(m)
			    children = functor(m)[1]
			    foreach(reset!, children)
			end
			

			

			
			

	
			# Apply `f` to the elements of `xs` from last to first, returning the results in
			# the original order of `xs`.
			function flip(f, xs)
			    reversed_outputs = [f(x) for x in reverse(xs)]
			    return reverse(reversed_outputs)
			end
			

			

			
			# Fold over a 3-d array: apply the layer to each slice along the last dimension in
			# order, then stack the per-step outputs along a new third dimension.
			function (m::Recur)(x::AbstractArray{T, 3}) where T
			    outputs = [m(x_t) for x_t in eachlastdim(x)]
			    step_dims = size(outputs[1])
			    stacked = reduce(hcat, outputs)
			    return reshape(stacked, step_dims[1], step_dims[2], length(outputs))
			end

Vanilla RNN


			
			
			
			
			# Parameters of a vanilla RNN cell.
			struct RNNCell{F,I,H,V,S}
			    σ::F        # activation applied to the pre-activation
			    Wi::I       # weights multiplying the input
			    Wh::H       # weights multiplying the hidden state
			    b::V        # bias
			    state0::S   # initial hidden state
			end
			

			

			
			

	
			# Build an `RNNCell` from `in => out`, an optional activation, and initialisers for
			# the weights (`init`), the bias (`initb`) and the initial state (`init_state`).
			function RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32)
			    Wi = init(out, in)
			    Wh = init(out, out)
			    b = initb(out)
			    state0 = init_state(out, 1)
			    return RNNCell(σ, Wi, Wh, b, state0)
			end
			

			

			
			# One RNN step: `h′ = σ.(Wi*x .+ Wh*h .+ b)`. Returns the new state together with
			# the output reshaped by `reshape_cell_output`.
			function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {F,I,H,V,T}
			    _size_check(m, x, 1 => size(m.Wi, 2))
			    act = NNlib.fast_act(m.σ, x)
			    x_cast = _match_eltype(m, T, x)
			    h_next = act.(m.Wi * x_cast .+ m.Wh * h .+ m.b)
			    return h_next, reshape_cell_output(h_next, x)
			end
			

			

			
			@functor RNNCell

			# Display as `RNNCell(in => out, σ)`; the activation is omitted when it is `identity`.
			function Base.show(io::IO, l::RNNCell)
			    nin = size(l.Wi, 2)
			    nout = size(l.Wi, 1)
			    print(io, "RNNCell(", nin, " => ", nout)
			    if l.σ != identity
			        print(io, ", ", l.σ)
			    end
			    print(io, ")")
			end
			

			

			
			
			
			"""
			    RNN(in => out, σ = tanh)

			The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
			output fed back into the input each time step.

			The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

			This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

			# Examples
			```jldoctest
			julia> r = RNN(3 => 5)
			Recur(
			  RNNCell(3 => 5, tanh),                # 50 parameters
			)         # Total: 4 trainable arrays, 50 parameters,
			          # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

			julia> r(rand(Float32, 3)) |> size
			(5,)

			julia> Flux.reset!(r);

			julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
			(5, 10)
			```

			!!! warning "Batch size changes"

			    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example:

			    ```julia
			    julia> r = RNN(3 => 5)
			    Recur(
			      RNNCell(3 => 5, tanh),                # 50 parameters
			    )         # Total: 4 trainable arrays, 50 parameters,
			              # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

			    julia> r.state |> size
			    (5, 1)

			    julia> r(rand(Float32, 3)) |> size
			    (5,)

			    julia> r.state |> size
			    (5, 1)

			    julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
			    (5, 10)

			    julia> r.state |> size # state shape has changed
			    (5, 10)

			    julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
			    (50,)
			    ```

			# Note:
			`RNNCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The  `Wi` and `Wh` matrices do not need to be the same type, but if `Wh` is `dxd`, then `Wi` should be of shape `dxN`.

			```julia
			julia> using LinearAlgebra

			julia> r = Flux.Recur(Flux.RNNCell(tanh, rand(5, 4), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)))

			julia> r(rand(4, 10)) |> size # batch size of 10
			(5, 10)
			```
			"""
			function RNN(a...; ka...)
			    cell = RNNCell(a...; ka...)
			    return Recur(cell)
			end

			# Wrap an `RNNCell`, starting from the cell's own initial state.
			function Recur(m::RNNCell)
			    return Recur(m, m.state0)
			end

LSTM


			
			
			
			
			# Parameters of an LSTM cell.
			struct LSTMCell{I,H,V,S}
			    Wi::I       # weights multiplying the input
			    Wh::H       # weights multiplying the hidden state
			    b::V        # bias
			    state0::S   # initial state
			end
			

			

			
			# Build an `LSTMCell` from `in => out`. The four gates are stacked, so the weights
			# and bias have `out * 4` rows; the initial state is a pair of `init_state(out, 1)`
			# arrays.
			function LSTMCell((in, out)::Pair;
			                  init = glorot_uniform,
			                  initb = zeros32,
			                  init_state = zeros32)
			    Wi = init(out * 4, in)
			    Wh = init(out * 4, out)
			    b = initb(out * 4)
			    state0 = (init_state(out, 1), init_state(out, 1))
			    # Set the bias of the second gate (the forget gate in the forward pass) to 1.
			    b[gate(out, 2)] .= 1
			    return LSTMCell(Wi, Wh, b, state0)
			end
			

			

			
			# One LSTM step. The stacked pre-activation `Wi*x + Wh*h + b` is split into the
			# input, forget, cell and output gates, from which the new `(h′, c′)` is computed.
			# Returns the new state pair and the reshaped output.
			function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractVecOrMat) where {I,H,V,T}
			    _size_check(m, x, 1 => size(m.Wi, 2))
			    nhidden = size(h, 1)
			    x_cast = _match_eltype(m, T, x)
			    preact = muladd(m.Wi, x_cast, muladd(m.Wh, h, m.b))
			    input, forget, cell, output = multigate(preact, nhidden, Val(4))
			    c′ = sigmoid_fast.(forget) .* c .+ sigmoid_fast.(input) .* tanh_fast.(cell)
			    h′ = sigmoid_fast.(output) .* tanh_fast.(c′)
			    return (h′, c′), reshape_cell_output(h′, x)
			end
			

			

			
			@functor LSTMCell

			# Display as `LSTMCell(in => out)`; `out` is a quarter of the stacked row count.
			function Base.show(io::IO, l::LSTMCell)
			    nin = size(l.Wi, 2)
			    nout = div(size(l.Wi, 1), 4)
			    print(io, "LSTMCell(", nin, " => ", nout, ")")
			end
			

			

			
			
			
			"""
			    LSTM(in => out)

			[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
			recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.

			The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

			This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

			See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
			for a good overview of the internals.

			# Examples
			```jldoctest
			julia> l = LSTM(3 => 5)
			Recur(
			  LSTMCell(3 => 5),                     # 190 parameters
			)         # Total: 5 trainable arrays, 190 parameters,
			          # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.

			julia> l(rand(Float32, 3)) |> size
			(5,)

			julia> Flux.reset!(l);

			julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
			(5, 10)
			```

			!!! warning "Batch size changes"
			    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).

			# Note:
			  `LSTMCell`s can be constructed directly by specifying the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0` (a tuple of two matrices). The  `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
			"""
			function LSTM(a...; ka...)
			    cell = LSTMCell(a...; ka...)
			    return Recur(cell)
			end

			# Wrap an `LSTMCell`, starting from the cell's own initial state.
			function Recur(m::LSTMCell)
			    return Recur(m, m.state0)
			end

GRU


			
			
			
			# Compute the first two GRU gates: `sigmoid_fast` of the summed input, hidden and
			# bias contributions for gate 1 (`r`) and gate 2 (`z`).
			function _gru_output(gxs, ghs, bs)
			    r = sigmoid_fast.(gxs[1] .+ ghs[1] .+ bs[1])
			    z = sigmoid_fast.(gxs[2] .+ ghs[2] .+ bs[2])
			    return r, z
			end
			

			

			
			
			# Parameters of a GRU cell.
			struct GRUCell{I,H,V,S}
			    Wi::I       # weights multiplying the input
			    Wh::H       # weights multiplying the hidden state
			    b::V        # bias
			    state0::S   # initial hidden state
			end
			

			

			
			

	
			# Build a `GRUCell` from `in => out`. The three gates are stacked, so the weights
			# and bias have `out * 3` rows.
			function GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32)
			    Wi = init(out * 3, in)
			    Wh = init(out * 3, out)
			    b = initb(out * 3)
			    state0 = init_state(out, 1)
			    return GRUCell(Wi, Wh, b, state0)
			end
			

			

			
			# One GRU step.
			#
			# `Wi*x`, `Wh*h` and `b` are each split into three gates. `_gru_output` turns the
			# first two into `r` and `z`. The third gives the candidate state `h̃`, in which `r`
			# scales the hidden contribution, and `z` then blends `h̃` with the previous state `h`.
			# Returns the new state together with the output reshaped by `reshape_cell_output`.
			function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,T}
			    _size_check(m, x, 1 => size(m.Wi, 2))
			    Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
			    xT = _match_eltype(m, T, x)
			    gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
			    r, z = _gru_output(gxs, ghs, bs)
			    # The candidate state needs a name: it is assigned here and read on the next line.
			    h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
			    h′ = @. (1 - z) * h̃ + z * h
			    return h′, reshape_cell_output(h′, x)
			end
			

			

			
			@functor GRUCell

			# Display as `GRUCell(in => out)`; `out` is a third of the stacked row count.
			function Base.show(io::IO, l::GRUCell)
			    nin = size(l.Wi, 2)
			    nout = div(size(l.Wi, 1), 3)
			    print(io, "GRUCell(", nin, " => ", nout, ")")
			end
			

			

			
			
			
			"""
			    GRU(in => out)

			[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
			RNN but generally exhibits a longer memory span over sequences. This implements
			the variant proposed in v1 of the referenced paper.

			The integer arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

			This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

			See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
			for a good overview of the internals.

			# Examples
			```jldoctest
			julia> g = GRU(3 => 5)
			Recur(
			  GRUCell(3 => 5),                      # 140 parameters
			)         # Total: 4 trainable arrays, 140 parameters,
			          # plus 1 non-trainable, 5 parameters, summarysize 792 bytes.

			julia> g(rand(Float32, 3)) |> size
			(5,)

			julia> Flux.reset!(g);

			julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
			(5, 10)
			```

			!!! warning "Batch size changes"
			    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).

			# Note:
			  `GRUCell`s can be constructed directly by specifying the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The  `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
			"""
			function GRU(a...; ka...)
			    cell = GRUCell(a...; ka...)
			    return Recur(cell)
			end

			# Wrap a `GRUCell`, starting from the cell's own initial state.
			function Recur(m::GRUCell)
			    return Recur(m, m.state0)
			end

GRU v3


			
			
			
			
			# Parameters of a GRU (v3 variant) cell.
			struct GRUv3Cell{I,H,V,HH,S}
			    Wi::I       # weights multiplying the input
			    Wh::H       # weights multiplying the hidden state
			    b::V        # bias
			    Wh_h̃::HH    # weights multiplying `r .* h` in the forward pass
			    state0::S   # initial hidden state
			end
			

			

			
			

	
			# Build a `GRUv3Cell` from `in => out`. `Wi` and `b` stack three gates (`out * 3`
			# rows), `Wh` stacks two (`out * 2` rows), and `Wh_h̃` is `out × out`.
			function GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32)
			    Wi = init(out * 3, in)
			    Wh = init(out * 2, out)
			    b = initb(out * 3)
			    Wh_candidate = init(out, out)
			    state0 = init_state(out, 1)
			    return GRUv3Cell(Wi, Wh, b, Wh_candidate, state0)
			end
			

			

			
			# One GRU (v3 variant) step.
			#
			# `Wi*x` and `b` are split into three gates and `Wh*h` into two. `_gru_output` turns
			# the first two gates into `r` and `z`. The candidate state `h̃` applies the separate
			# matrix `Wh_h̃` to `r .* h`, and `z` then blends `h̃` with the previous state `h`.
			# Returns the new state together with the output reshaped by `reshape_cell_output`.
			function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,HH,T}
			    _size_check(m, x, 1 => size(m.Wi, 2))
			    Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
			    xT = _match_eltype(m, T, x)
			    gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
			    r, z = _gru_output(gxs, ghs, bs)
			    # The candidate state needs a name: it is assigned here and read on the next line.
			    h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
			    h′ = @. (1 - z) * h̃ + z * h
			    return h′, reshape_cell_output(h′, x)
			end
			

			

			
			@functor GRUv3Cell

			# Display as `GRUv3Cell(in => out)`; `out` is a third of the row count of `Wi`.
			function Base.show(io::IO, l::GRUv3Cell)
			    nin = size(l.Wi, 2)
			    nout = div(size(l.Wi, 1), 3)
			    print(io, "GRUv3Cell(", nin, " => ", nout, ")")
			end
			

			

			
			
			
			"""
			    GRUv3(in => out)

			[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
			RNN but generally exhibits a longer memory span over sequences. This implements
			the variant proposed in v3 of the referenced paper.

			The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

			This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

			See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
			for a good overview of the internals.

			# Examples
			```jldoctest
			julia> g = GRUv3(3 => 5)
			Recur(
			  GRUv3Cell(3 => 5),                    # 140 parameters
			)         # Total: 5 trainable arrays, 140 parameters,
			          # plus 1 non-trainable, 5 parameters, summarysize 848 bytes.

			julia> g(rand(Float32, 3)) |> size
			(5,)

			julia> Flux.reset!(g);

			julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
			(5, 10)
			```

			!!! warning "Batch size changes"
			    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).

			# Note:
			  `GRUv3Cell`s can be constructed directly by specifying the `Wi` and `Wh` internal matrices, a bias vector `b`, the `Wh_h̃` internal matrix, and a learnable initial state `state0`, in that order. The  `Wi`, `Wh`, and `Wh_h̃` matrices do not need to be the same type. See the example in [`RNN`](@ref).
			"""
			function GRUv3(a...; ka...)
			    cell = GRUv3Cell(a...; ka...)
			    return Recur(cell)
			end

			# Wrap a `GRUv3Cell`, starting from the cell's own initial state.
			function Recur(m::GRUv3Cell)
			    return Recur(m, m.state0)
			end