xresnet.jl

FastVision/models/xresnet.jl is a source file in module FastVision

			
			
			
			# Activation function shared by the whole file: `convxlayer` hands it to `BatchNorm`
			# (when `act = true`) and the `ResBlock` forward pass broadcasts it over the sum of
			# the convolutional and shortcut branches.
			const act_fn = relu
			

			

			
			"""
			    convx(ni, nf; ks = 3, stride = 1, ndim = 2)

			Build a `Conv` layer going from `ni` to `nf` channels.

			# Keywords
			- `ks`: kernel size, repeated along each of the `ndim` dimensions.
			- `stride`: stride forwarded to `Conv`.
			- `ndim`: number of kernel dimensions.

			The layer is padded with `ks ÷ 2` and its weights are initialised with
			`Flux.kaiming_normal`.
			"""
			function convx(ni, nf; ks = 3, stride = 1, ndim = 2)
			    kernel = ntuple(_ -> ks, ndim)
			    channels = ni => nf
			    return Conv(kernel, channels;
			                stride = stride,
			                pad = ks ÷ 2,
			                init = Flux.kaiming_normal)
			end
			

			

			
			"""
			    convxlayer(ni, nf; ks = 3, stride = 1, zero_bn = false, act = true, ndim = 2)

			Return a `Chain` of a [`convx`](@ref) convolution (`ni` to `nf` channels) followed by
			a `BatchNorm` over `nf` channels.

			# Keywords
			- `ks`, `stride`, `ndim`: forwarded to `convx`.
			- `zero_bn`: when `true`, the batch norm scale `γ` is filled with `0` instead of `1`.
			- `act`: when `true`, the batch norm applies `act_fn`; otherwise it applies `identity`.
			"""
			function convxlayer(ni, nf; ks = 3, stride = 1, zero_bn = false, act = true, ndim = 2)
			    activation = act ? act_fn : identity
			    norm = BatchNorm(nf, activation)

			    # Overwrite the scale parameter in place with its requested starting value.
			    scale_init = zero_bn ? 0 : 1
			    fill!(norm.γ, scale_init)

			    conv = convx(ni, nf; ks = ks, stride = stride, ndim = ndim)
			    return Chain(conv, norm)
			end
			

			

			
			
			"""
			    ResBlock(convs, idconv, pool)

			Residual block holding three callables. Calling a block `r` on `x` evaluates
			`act_fn.(r.convs(x) .+ r.idconv(r.pool(x)))`.

			# Fields
			- `convs`: the main (convolutional) branch.
			- `idconv`: transformation applied on the shortcut branch after `pool`.
			- `pool`: transformation applied to the input of the shortcut branch.

			The field types are type parameters rather than `Any`, so the concrete type of each
			branch is known to the compiler when the block is called.
			"""
			struct ResBlock{C, I, P}
			    convs::C
			    idconv::I
			    pool::P
			end
			

			

			
			"""
			    ResBlock(expansion, ni, nh; stride = 1, ndim = 2)

			Construct a residual block from channel counts.

			# Arguments
			- `expansion::Int`: channel multiplier. Both `ni` and `nh` are multiplied by it to
			  obtain the block's input and output widths. `1` builds a two-layer branch, any other
			  value builds a three-layer branch whose first and last convolutions use `ks = 1`.
			- `ni::Int`: input width before multiplication by `expansion`.
			- `nh::Int`: hidden width; the output width is `nh * expansion`.

			# Keywords
			- `stride::Int = 1`: stride of the strided convolution in the main branch. When it is
			  not `1`, the shortcut branch is downsampled with a `MeanPool` whose window equals
			  `stride`.
			- `ndim = 2`: number of dimensions of the convolution and pooling kernels.
			"""
			function ResBlock(expansion::Int, ni::Int, nh::Int; stride::Int = 1, ndim = 2)
			    # From here on `ni`/`nf` are the real input/output widths of the block.
			    nf, ni = nh * expansion, ni * expansion

			    layers = if expansion == 1
			        [
			            convxlayer(ni, nh; stride = stride, ndim = ndim),
			            convxlayer(nh, nf; zero_bn = true, act = false, ndim = ndim),
			        ]
			    else
			        [
			            convxlayer(ni, nh; ks = 1, ndim = ndim),
			            convxlayer(nh, nh; stride = stride, ndim = ndim),
			            convxlayer(nh, nf; ks = 1, zero_bn = true, act = false, ndim = ndim),
			        ]
			    end
			    convs = Chain(layers...)

			    # The shortcut only needs a convolution when the width changes.
			    # NOTE(review): this shortcut layer keeps the default `act = true`; the fastai
			    # reference appears to disable the activation here — confirm which is intended.
			    idconv = ni == nf ? identity :
			             convxlayer(ni, nf; ks = 1, stride = 1, ndim = ndim)

			    # Downsample the shortcut by the same factor as the main branch. The window was
			    # previously fixed at 2, which only agrees with the main branch for `stride == 2`.
			    pool = stride == 1 ? identity : MeanPool(ntuple(_ -> stride, ndim))

			    return ResBlock(convs, idconv, pool)
			end
			

			
			

	
			# Apply Flux's `@functor` macro to `ResBlock`.
			# NOTE(review): presumably this makes the `convs`, `idconv` and `pool` fields visible
			# to Flux's parameter collection and device-transfer utilities — confirm against the
			# Flux documentation.
			Flux.@functor ResBlock
			

			

			
			
			# Forward pass of a residual block: sum the main branch and the (pooled, optionally
			# convolved) shortcut branch, then apply `act_fn` elementwise.
			function (r::ResBlock)(x)
			    main = r.convs(x)
			    shortcut = r.idconv(r.pool(x))
			    return act_fn.(main .+ shortcut)
			end
			

			

			
			"""
			    make_layer(expansion, ni, nf, n_blocks, stride; ndim = 2)

			Return a `Chain` of `n_blocks` [`ResBlock`](@ref)s.

			The first block receives `ni` as input width and the given `stride`; every following
			block receives `nf` as input width and a stride of `1`. All blocks use `nf` as hidden
			width and share `expansion` and `ndim`.
			"""
			function make_layer(expansion, ni, nf, n_blocks, stride; ndim = 2)
			    blocks = map(1:n_blocks) do idx
			        leading = idx == 1
			        width_in = leading ? ni : nf
			        block_stride = leading ? stride : 1
			        ResBlock(expansion, width_in, nf; stride = block_stride, ndim = ndim)
			    end
			    return Chain(blocks...)
			end
			

			

			
			
			
			"""
			

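			    XResNet(expansion, layers; c_in = 3, ndim = 2)

			Create an XResNet model backbone following the [implementation in
			fastai](https://github.com/fastai/fastai/blob/master/fastai/vision/models/xresnet.py#L22).

			## Arguments

			- `expansion`: Channel multiplier of the residual blocks. With `1`, each block has two
			    convolutional layers; with any other value, each block has three, the first and
			    last of which use a kernel size of `1`.
			- `layers`: The number of residual blocks in each stage, e.g. `[2, 2, 2, 2]`.

			## Keyword arguments

			- `c_in::Int = 3`: The number of input channels, e.g. `1` for grayscale images and `3` for
			    RGB images.
			- `ndim::Int = 2`: The number of dimensions for the convolutional and pooling layers, e.g.
			    `2` for 2D input images and `3` for 3D volumes.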

			"""
			

			
			function XResNet(expansion, layers; c_in = 3, ndim = 2)
			    # Builds: stem (three conv layers) -> max pooling -> one residual stage per entry
			    # of `layers`, where each entry is the number of `ResBlock`s in that stage.

			    # Stem widths; only the first stem layer downsamples (stride 2).
			    stem_widths = [c_in, (c_in + 1) * 8, 64, 64]
			    stem = [convxlayer(stem_widths[i], stem_widths[i + 1];
			                       stride = i == 1 ? 2 : 1, ndim = ndim)
			            for i in 1:3]

			    # Base width of each residual stage. Stages past the fourth used to index out of
			    # bounds; they now fall back to a base width of 256.
			    # NOTE(review): 256 is meant to mirror the fastai reference implementation's
			    # handling of extra stages — confirm against it.
			    base_widths = (64, 128, 256, 512)
			    stage_width(i) = i <= length(base_widths) ? base_widths[i] : 256

			    # `ResBlock` multiplies its input width by `expansion`, so the 64 channels coming
			    # out of the stem are passed to the first stage as `64 ÷ expansion`.
			    input_width(i) = i == 1 ? 64 ÷ expansion : stage_width(i - 1)

			    # Every stage except the first downsamples with stride 2.
			    res_layers = [make_layer(expansion, input_width(i), stage_width(i), n_blocks,
			                             i == 1 ? 1 : 2; ndim = ndim)
			                  for (i, n_blocks) in enumerate(layers)]

			    pool = MaxPool(ntuple(_ -> 3, ndim); pad = 1, stride = 2)
			    return Chain(stem..., pool, res_layers...)
			end
			

			

			
			

	
			"""
			    xresnet18(; kwargs...)

			Shorthand for `XResNet(1, [2, 2, 2, 2]; kwargs...)`: an expansion of `1` with two
			residual blocks in each of four stages. Keyword arguments are forwarded to `XResNet`.
			"""
			function xresnet18(; kwargs...)
			    return XResNet(1, [2, 2, 2, 2]; kwargs...)
			end
			

			
			

	
			"""
			    xresnet50(; kwargs...)

			Shorthand for `XResNet(4, [3, 4, 6, 3]; kwargs...)`: an expansion of `4` with 3, 4, 6
			and 3 residual blocks in the four stages. Keyword arguments are forwarded to `XResNet`.
			"""
			function xresnet50(; kwargs...)
			    return XResNet(4, [3, 4, 6, 3]; kwargs...)
			end
			

			

			
			# Shape checks for the xresnet18 backbone: constructing the model and computing its
			# output size must not emit warnings, and the output size must match.
			@testset "XResNet [model]" begin
			    @testset "Basic" begin
			        @test_nowarn begin
			            # Default `ndim = 2`.
			            net = xresnet18()
			            insize = (128, 128, 3, 1)
			            @test Flux.outputsize(net, insize) == (4, 4, 512, 1)
			        end
			    end

			    @testset "3D" begin
			        @test_nowarn begin
			            # Same architecture with three-dimensional kernels.
			            net = xresnet18(; ndim = 3)
			            insize = (128, 128, 128, 3, 1)
			            @test Flux.outputsize(net, insize) == (4, 4, 4, 512, 1)
			        end
			    end
			end