unet.jl

FastVision/models/unet.jl is a source file in module FastVision

			
			
			
			
			
			"""
			

			    UNetDynamic(backbone, inputsize, k_out[; kwargs...])

			

			Create a U-Net model from convolutional `backbone` architecture. After every

			downsampling layer (i.e. pooling or strided convolution), a skip connection and

			an upsampling block are inserted, resulting in a convolutional network with

			the same spatial output dimensions as its input. Outputs an array with `k_out`

			channels.

			

			## Keyword arguments

			

			- `fdownscale = 0`: Number of upsampling steps to leave out. By default there will be one

			    upsampling step for every downsampling step in `backbone`. Hence if the input spatial

			    size is `(h, w)`, the output size will be `(h/2^fdownscale, w/2^fdownscale)`, i.e.

			    to get outputs at half the resolution, set `fdownscale = 1`.

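			- `final = UNetFinalBlock`: Constructor for the output head. It is called as
			    `final(k, k_out, ndim)`, where `k` is the channel count of the U-Net body's output
			    and `ndim` is the number of spatial dimensions.
			- `kwargs...`: Other keyword arguments are passed through to `unetlayers`.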

			

			## Examples

			

			```julia

			using FastAI, Metalhead

			

			backbone = Metalhead.ResNet50(pretrain=true).layers[1][1:end-1]

			unet = UNetDynamic(backbone, (256, 256, 3, 1), 10)
			Flux.outputsize(unet, (256, 256, 3, 1)) == (256, 256, 10, 1)
			
			unet = UNetDynamic(backbone, (256, 256, 3, 1), 10; fdownscale = 1)
			Flux.outputsize(unet, (256, 256, 3, 1)) == (128, 128, 10, 1)

			```

			"""
			

			
			function UNetDynamic(backbone,
			                     inputsize,
			                     k_out::Int;
			                     final = UNetFinalBlock,
			                     fdownscale = 0,
			                     kwargs...)
			    # Flatten (possibly nested) `Chain`s of the backbone into a plain vector of layers.
			    encoder_layers = collect(iterlayers(backbone))

			    # Build the encoder/decoder body; `fdownscale` is the number of downsampling
			    # steps that get no matching upsampling block.
			    body = unetlayers(encoder_layers,
			                      inputsize;
			                      m_middle = UNetMiddleBlock,
			                      skip_upscale = fdownscale,
			                      kwargs...)

			    # The output head needs the channel count and the number of spatial
			    # dimensions of whatever the body produces.
			    bodysize = Flux.outputsize(body, inputsize)
			    nchannels = bodysize[end - 1]
			    nspatial = length(bodysize) - 2
			    head = final(nchannels, k_out, nspatial)

			    return Chain(body, head)
			end
			

			

			
			"""
			    catchannels(x1, x2)

			Concatenate `x1` and `x2` along their second-to-last dimension
			(`ndims(x1) - 1`). Throws an error if the two inputs do not have the
			same number of dimensions.
			"""
			function catchannels(x1, x2)
			    nd = ndims(x1)
			    if nd != ndims(x2)
			        error("Expected inputs with same number of dimensions!")
			    end
			    return cat(x1, x2; dims = nd - 1)
			end
			

			

			
			"""
			    unetlayers(layers, sz; k_out = nothing, skip_upscale = 0, m_middle = _ -> (identity,))

			Recursively assemble the body of a U-Net from the flat collection `layers`,
			where `sz` is the input size of the first layer in `layers`.

			Layers are consumed front to back and `Flux.outputsize` decides how each one
			is treated:

			- If the first entry of its output size is not `sz[1] ÷ 2`, the layer is chained
			  in front of the result for the remaining layers.
			- If it is and `skip_upscale > 0`, the layer is chained in the same way and the
			  remaining layers are processed with `skip_upscale - 1`.
			- Otherwise the layer and the network built from the remaining layers are
			  wrapped in a `UNetBlock`, which adds the upsampling step and skip connection.

			When `layers` is empty, `m_middle(sz[end - 1])` is returned.

			`k_out`, if not `nothing`, is used as the output channel count of the outermost
			`UNetBlock`; nested blocks are built without it and fall back to their own
			input channel count.

			Note that the recursive calls do not forward `m_middle`, so a custom value is
			only used when the top-level call itself receives an empty `layers`.
			"""
			function unetlayers(layers,
			                    sz;
			                    k_out = nothing,
			                    skip_upscale = 0,
			                    m_middle = _ -> (identity,))
			    if isempty(layers)
			        return m_middle(sz[end - 1])
			    end

			    head = layers[1]
			    rest = layers[2:end]
			    headsz = Flux.outputsize(head, sz)

			    # `head` does not scale down the spatial dimensions: just prepend it.
			    if !(sz[1] ÷ 2 == headsz[1])
			        return Chain(head,
			                     unetlayers(rest, headsz; k_out = k_out, skip_upscale = skip_upscale)...)
			    end

			    # `head` scales down, but this downsampling step is one of those that
			    # should not be undone: prepend it and use up one skipped step.
			    if skip_upscale > 0
			        return Chain(head,
			                     unetlayers(rest, headsz; k_out = k_out,
			                                skip_upscale = skip_upscale - 1)...)
			    end

			    # `head` scales down and gets a matching upsampling block plus a skip
			    # connection that brings the dimensions back up.
			    inner = Chain(unetlayers(rest, headsz; skip_upscale = skip_upscale)...)
			    innersz = Flux.outputsize(inner, headsz)

			    nskip = sz[end - 1]        # channels of the block input (skip path)
			    nmid = innersz[end - 1]    # channels coming out of the nested network
			    nout = isnothing(k_out) ? nskip : k_out

			    return UNetBlock(Chain(head, inner),
			                     nskip,
			                     nmid,
			                     nout,
			                     length(innersz) - 2)
			end
			

			

			
			

	
			"""
			    iterlayers(m::Chain)

			Lazily iterate over the layers of `m`, recursing into every entry of
			`m.layers` so that nested `Chain`s are flattened.
			"""
			function iterlayers(m::Chain)
			    return Iterators.flatten(iterlayers(sub) for sub in m.layers)
			end
			

			
			

	
			"""
			    iterlayers(m)

			Fallback for anything that is not a `Chain`: return `m` wrapped in a
			one-element tuple.
			"""
			function iterlayers(m)
			    return (m,)
			end
			

			

			
			
			
			"""
			

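			    UNetBlock(m_child, k_in, k_mid, k_out = 2k_in, ndim = 2)
			
			Given a convolutional module `m_child` that takes an input with `k_in` channels,
			halves the spatial dimensions and outputs `k_mid` channels, create a module that
			upsamples the output of `m_child` and then aggregates it with the original input
			via a skip connection, combining both into `k_out` channels.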

			"""
			

			
			function UNetBlock(m_child, k_in, k_mid, k_out = 2k_in, ndim = 2)
			    # Downsampling and processing, followed by pixel-shuffle upsampling.
			    inner = Chain(child = m_child,
			                  upsample = PixelShuffleICNR(k_mid, k_mid, ndim))

			    # Joins the two paths: the upsampled result passes through unchanged, the
			    # block input goes through `BatchNorm(k_in)`, then both are concatenated
			    # with `catchannels`.
			    joiner = Parallel(catchannels, identity, BatchNorm(k_in))

			    skip = SkipConnection(inner, joiner)

			    return Chain(upsample = skip,
			                 act = xs -> relu.(xs),
			                 combine = UNetCombineLayer(k_in + k_mid, k_out, ndim))
			end
			

			

			
			"""
			    PixelShuffleICNR(k_in, k_out, ndim; r = 2)

			Pixel-shuffle upsampling block: a `convxlayer` (called with `ks = 1`) mapping
			`k_in` to `k_out * r^ndim` channels, followed by `Flux.PixelShuffle(r)`.

			NOTE(review): despite the name, this function does not set up any ICNR
			initialisation itself; confirm whether `convxlayer` is expected to handle it.
			"""
			function PixelShuffleICNR(k_in, k_out, ndim; r = 2)
			    expanded = k_out * (r^ndim)
			    expand = convxlayer(k_in, expanded; ks = 1, ndim = ndim)
			    shuffle = Flux.PixelShuffle(r)
			    return Chain(expand, shuffle)
			end
			

			

			
			"""
			    UNetCombineLayer(k_in, k_out, ndim)

			Chain of two `convxlayer`s: the first is built with `(k_in, k_out)`, the
			second with `(k_out, k_out)`; both receive `ndim`.
			"""
			function UNetCombineLayer(k_in, k_out, ndim)
			    first_conv = convxlayer(k_in, k_out; ndim = ndim)
			    second_conv = convxlayer(k_out, k_out; ndim = ndim)
			    return Chain(first_conv, second_conv)
			end
			

			

			
			"""
			    UNetMiddleBlock(k, ndim)

			Chain of two `convxlayer`s: the first is built with `(k, 2k)`, the second
			with `(2k, k)`; both receive `ndim`.
			"""
			function UNetMiddleBlock(k, ndim)
			    widen = convxlayer(k, 2 * k; ndim = ndim)
			    narrow = convxlayer(2 * k, k; ndim = ndim)
			    return Chain(widen, narrow)
			end
			

			

			
			"""
			    UNetFinalBlock(k_in, k_out, ndim)

			Default output head used by `UNetDynamic`: `ResBlock(1, k_in, k_in)` followed
			by a `convxlayer` built with `(k_in, k_out)` and `ks = 1`; both receive `ndim`.
			"""
			function UNetFinalBlock(k_in, k_out, ndim)
			    residual = ResBlock(1, k_in, k_in; ndim = ndim)
			    projection = convxlayer(k_in, k_out; ks = 1, ndim = ndim)
			    return Chain(residual, projection)
			end
			

			

			
			
			
			"""
			

			    upsample_block_small(insize, k_out)

			

			An upsampling block that increases the spatial dimensions of the input by a
			factor of 2 using pixel-shuffle upsampling.

			"""
			

			
			function upsample_block_small(insize, k_out; ks = 3, kwargs...)
			    # `Flux.PixelShuffle(2)` runs first; the convolution is therefore built for
			    # `insize[end - 1] ÷ 4` input channels.
			    # NOTE(review): the divisor 4 matches 2^2, i.e. two spatial dimensions;
			    # confirm before using this with other dimensionalities.
			    shuffled_channels = insize[end - 1] ÷ 4

			    # Forward `ks` explicitly. It used to be accepted but dropped, so a
			    # caller-supplied value was silently ignored and the convolution always
			    # used `convxlayer`'s own default (`conv_final` already forwards it).
			    # NOTE(review): assumes `convxlayer`'s default `ks` is also 3, so calls
			    # that do not pass `ks` build the same layer as before; confirm.
			    return Chain(Flux.PixelShuffle(2),
			                 convxlayer(shuffled_channels, k_out; ks = ks, kwargs...))
			end
			

			

			
			"""
			    conv_final(insize, k_out; ks = 1, kwargs...)

			Build a `convxlayer` with `insize[end - 1]` as its first argument and `k_out`
			as its second. `ks` and any further keyword arguments are forwarded to
			`convxlayer`.
			"""
			function conv_final(insize, k_out; ks = 1, kwargs...)
			    nchannels = insize[end - 1]
			    return convxlayer(nchannels, k_out; ks = ks, kwargs...)
			end
			

			

			
			# Checks that `UNetDynamic` on an xresnet18 backbone builds without warnings
			# and yields the expected output sizes for 2D input, 2D input with
			# `fdownscale = 1`, and 3D input.
			@testset "UNetDynamic [model]" begin
			    @test_nowarn begin
			        insize2d = (128, 128, 3, 1)

			        # Full-resolution output.
			        model = UNetDynamic(Models.xresnet18(), insize2d, 4)
			        @test Flux.outputsize(model, insize2d) == (128, 128, 4, 1)

			        # One upsampling step left out: half-resolution output.
			        model = UNetDynamic(Models.xresnet18(), insize2d, 4, fdownscale = 1)
			        @test Flux.outputsize(model, insize2d) == (64, 64, 4, 1)

			        # Three spatial dimensions.
			        insize3d = (128, 128, 128, 3, 1)
			        model = UNetDynamic(Models.xresnet18(ndim = 3), insize3d, 4)
			        @test Flux.outputsize(model, insize3d) == (128, 128, 128, 4, 1)
			    end
			end