2022-05-18 14:10:40 +00:00
package cli
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"github.com/pion/udp"
"golang.org/x/xerrors"
2022-09-02 23:26:01 +00:00
"github.com/coder/coder/agent"
2023-03-23 22:42:20 +00:00
"github.com/coder/coder/cli/clibase"
2022-05-18 14:10:40 +00:00
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
)
2023-03-23 22:42:20 +00:00
func ( r * RootCmd ) portForward ( ) * clibase . Cmd {
2022-05-18 14:10:40 +00:00
var (
2022-09-13 20:55:56 +00:00
tcpForwards [ ] string // <port>:<port>
udpForwards [ ] string // <port>:<port>
2022-05-18 14:10:40 +00:00
)
2023-03-23 22:42:20 +00:00
client := new ( codersdk . Client )
cmd := & clibase . Cmd {
2022-05-18 14:10:40 +00:00
Use : "port-forward <workspace>" ,
2022-09-19 16:36:18 +00:00
Short : "Forward ports from machine to a workspace" ,
2022-05-18 14:10:40 +00:00
Aliases : [ ] string { "tunnel" } ,
2023-03-23 22:42:20 +00:00
Long : formatExamples (
2022-07-11 16:08:09 +00:00
example {
Description : "Port forward a single TCP port from 1234 in the workspace to port 5678 on your local machine" ,
Command : "coder port-forward <workspace> --tcp 5678:1234" ,
} ,
example {
Description : "Port forward a single UDP port from port 9000 to port 9000 on your local machine" ,
Command : "coder port-forward <workspace> --udp 9000" ,
} ,
example {
Description : "Port forward multiple TCP ports and a UDP port" ,
Command : "coder port-forward <workspace> --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53" ,
} ,
2022-10-03 08:58:43 +00:00
example {
Description : "Port forward multiple ports (TCP or UDP) in condensed syntax" ,
Command : "coder port-forward <workspace> --tcp 8080,9000:3000,9090-9092,10000-10002:10010-10012" ,
} ,
2022-07-11 16:08:09 +00:00
) ,
2023-03-23 22:42:20 +00:00
Middleware : clibase . Chain (
clibase . RequireNArgs ( 1 ) ,
r . InitClient ( client ) ,
) ,
Handler : func ( inv * clibase . Invocation ) error {
ctx , cancel := context . WithCancel ( inv . Context ( ) )
2022-08-02 14:44:59 +00:00
defer cancel ( )
2022-09-13 20:55:56 +00:00
specs , err := parsePortForwards ( tcpForwards , udpForwards )
2022-05-18 14:10:40 +00:00
if err != nil {
return xerrors . Errorf ( "parse port-forward specs: %w" , err )
}
if len ( specs ) == 0 {
2023-03-23 22:42:20 +00:00
err = inv . Command . HelpHandler ( inv )
2022-05-18 14:10:40 +00:00
if err != nil {
return xerrors . Errorf ( "generate help output: %w" , err )
}
return xerrors . New ( "no port-forwards requested" )
}
2023-03-23 22:42:20 +00:00
workspace , workspaceAgent , err := getWorkspaceAndAgent ( ctx , inv , client , codersdk . Me , inv . Args [ 0 ] )
2022-05-18 14:10:40 +00:00
if err != nil {
return err
}
2022-05-19 18:04:44 +00:00
if workspace . LatestBuild . Transition != codersdk . WorkspaceTransitionStart {
2022-05-18 14:10:40 +00:00
return xerrors . New ( "workspace must be in start transition to port-forward" )
}
if workspace . LatestBuild . Job . CompletedAt == nil {
2023-03-23 22:42:20 +00:00
err = cliui . WorkspaceBuild ( ctx , inv . Stderr , client , workspace . LatestBuild . ID )
2022-05-18 14:10:40 +00:00
if err != nil {
return err
}
}
2023-03-23 22:42:20 +00:00
err = cliui . Agent ( ctx , inv . Stderr , cliui . AgentOptions {
2022-05-18 14:10:40 +00:00
WorkspaceName : workspace . Name ,
Fetch : func ( ctx context . Context ) ( codersdk . WorkspaceAgent , error ) {
2022-09-02 23:26:01 +00:00
return client . WorkspaceAgent ( ctx , workspaceAgent . ID )
2022-05-18 14:10:40 +00:00
} ,
} )
if err != nil {
return xerrors . Errorf ( "await agent: %w" , err )
}
2022-10-17 13:43:30 +00:00
conn , err := client . DialWorkspaceAgent ( ctx , workspaceAgent . ID , nil )
2022-05-18 14:10:40 +00:00
if err != nil {
2022-09-02 23:26:01 +00:00
return err
2022-05-18 14:10:40 +00:00
}
defer conn . Close ( )
// Start all listeners.
var (
wg = new ( sync . WaitGroup )
listeners = make ( [ ] net . Listener , len ( specs ) )
closeAllListeners = func ( ) {
for _ , l := range listeners {
if l == nil {
continue
}
_ = l . Close ( )
}
}
)
2022-08-02 14:44:59 +00:00
defer closeAllListeners ( )
2022-05-18 14:10:40 +00:00
for i , spec := range specs {
2023-03-23 22:42:20 +00:00
l , err := listenAndPortForward ( ctx , inv , conn , wg , spec )
2022-05-18 14:10:40 +00:00
if err != nil {
return err
}
listeners [ i ] = l
}
// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
2022-08-02 14:44:59 +00:00
wg . Add ( 1 )
2022-05-18 14:10:40 +00:00
go func ( ) {
2022-08-02 14:44:59 +00:00
defer wg . Done ( )
2022-05-18 14:10:40 +00:00
sigs := make ( chan os . Signal , 1 )
signal . Notify ( sigs , syscall . SIGINT , syscall . SIGTERM )
select {
case <- ctx . Done ( ) :
closeErr = ctx . Err ( )
case <- sigs :
2023-03-23 22:42:20 +00:00
_ , _ = fmt . Fprintln ( inv . Stderr , "\nReceived signal, closing all listeners and active connections" )
2022-05-18 14:10:40 +00:00
}
cancel ( )
closeAllListeners ( )
} ( )
2022-11-13 17:33:05 +00:00
conn . AwaitReachable ( ctx )
2023-03-23 22:42:20 +00:00
_ , _ = fmt . Fprintln ( inv . Stderr , "Ready!" )
2022-05-18 14:10:40 +00:00
wg . Wait ( )
return closeErr
} ,
}
2023-03-23 22:42:20 +00:00
cmd . Options = clibase . OptionSet {
{
Flag : "tcp" ,
FlagShorthand : "p" ,
Env : "CODER_PORT_FORWARD_TCP" ,
Description : "Forward TCP port(s) from the workspace to the local machine." ,
Value : clibase . StringArrayOf ( & tcpForwards ) ,
} ,
{
Flag : "udp" ,
Env : "CODER_PORT_FORWARD_UDP" ,
Description : "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols." ,
Value : clibase . StringArrayOf ( & udpForwards ) ,
} ,
}
2022-05-18 14:10:40 +00:00
return cmd
}
2023-03-23 22:42:20 +00:00
func listenAndPortForward ( ctx context . Context , inv * clibase . Invocation , conn * codersdk . WorkspaceAgentConn , wg * sync . WaitGroup , spec portForwardSpec ) ( net . Listener , error ) {
_ , _ = fmt . Fprintf ( inv . Stderr , "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n" , spec . listenNetwork , spec . listenAddress , spec . dialNetwork , spec . dialAddress )
2022-05-18 14:10:40 +00:00
var (
l net . Listener
err error
)
switch spec . listenNetwork {
case "tcp" :
l , err = net . Listen ( spec . listenNetwork , spec . listenAddress )
case "udp" :
var host , port string
host , port , err = net . SplitHostPort ( spec . listenAddress )
if err != nil {
return nil , xerrors . Errorf ( "split %q: %w" , spec . listenAddress , err )
}
var portInt int
portInt , err = strconv . Atoi ( port )
if err != nil {
return nil , xerrors . Errorf ( "parse port %v from %q as int: %w" , port , spec . listenAddress , err )
}
l , err = udp . Listen ( spec . listenNetwork , & net . UDPAddr {
IP : net . ParseIP ( host ) ,
Port : portInt ,
} )
default :
return nil , xerrors . Errorf ( "unknown listen network %q" , spec . listenNetwork )
}
if err != nil {
return nil , xerrors . Errorf ( "listen '%v://%v': %w" , spec . listenNetwork , spec . listenAddress , err )
}
wg . Add ( 1 )
go func ( spec portForwardSpec ) {
defer wg . Done ( )
for {
netConn , err := l . Accept ( )
if err != nil {
2022-10-17 16:45:29 +00:00
// Silently ignore net.ErrClosed errors.
if xerrors . Is ( err , net . ErrClosed ) {
return
}
2023-03-23 22:42:20 +00:00
_ , _ = fmt . Fprintf ( inv . Stderr , "Error accepting connection from '%v://%v': %v\n" , spec . listenNetwork , spec . listenAddress , err )
_ , _ = fmt . Fprintln ( inv . Stderr , "Killing listener" )
2022-05-18 14:10:40 +00:00
return
}
go func ( netConn net . Conn ) {
defer netConn . Close ( )
remoteConn , err := conn . DialContext ( ctx , spec . dialNetwork , spec . dialAddress )
if err != nil {
2023-03-23 22:42:20 +00:00
_ , _ = fmt . Fprintf ( inv . Stderr , "Failed to dial '%v://%v' in workspace: %s\n" , spec . dialNetwork , spec . dialAddress , err )
2022-05-18 14:10:40 +00:00
return
}
defer remoteConn . Close ( )
2022-09-02 23:26:01 +00:00
agent . Bicopy ( ctx , netConn , remoteConn )
2022-05-18 14:10:40 +00:00
} ( netConn )
}
} ( spec )
return l , nil
}
// portForwardSpec describes one forward: the local endpoint to listen on and
// the endpoint to dial for every accepted connection.
type portForwardSpec struct {
	// Listening side: network (tcp, udp) and address (<ip>:<port> or path).
	listenNetwork, listenAddress string
	// Dialing side: network (tcp, udp) and address (<ip>:<port> or path).
	dialNetwork, dialAddress string
}
2022-09-13 20:55:56 +00:00
func parsePortForwards ( tcpSpecs , udpSpecs [ ] string ) ( [ ] portForwardSpec , error ) {
2022-05-18 14:10:40 +00:00
specs := [ ] portForwardSpec { }
2022-10-03 08:58:43 +00:00
for _ , specEntry := range tcpSpecs {
for _ , spec := range strings . Split ( specEntry , "," ) {
ports , err := parseSrcDestPorts ( spec )
if err != nil {
return nil , xerrors . Errorf ( "failed to parse TCP port-forward specification %q: %w" , spec , err )
}
2022-05-18 14:10:40 +00:00
2022-10-03 08:58:43 +00:00
for _ , port := range ports {
specs = append ( specs , portForwardSpec {
listenNetwork : "tcp" ,
listenAddress : fmt . Sprintf ( "127.0.0.1:%v" , port . local ) ,
dialNetwork : "tcp" ,
dialAddress : fmt . Sprintf ( "127.0.0.1:%v" , port . remote ) ,
} )
}
}
2022-05-18 14:10:40 +00:00
}
2022-10-03 08:58:43 +00:00
for _ , specEntry := range udpSpecs {
for _ , spec := range strings . Split ( specEntry , "," ) {
ports , err := parseSrcDestPorts ( spec )
if err != nil {
return nil , xerrors . Errorf ( "failed to parse UDP port-forward specification %q: %w" , spec , err )
}
2022-05-18 14:10:40 +00:00
2022-10-03 08:58:43 +00:00
for _ , port := range ports {
specs = append ( specs , portForwardSpec {
listenNetwork : "udp" ,
listenAddress : fmt . Sprintf ( "127.0.0.1:%v" , port . local ) ,
dialNetwork : "udp" ,
dialAddress : fmt . Sprintf ( "127.0.0.1:%v" , port . remote ) ,
} )
}
}
2022-05-18 14:10:40 +00:00
}
// Check for duplicate entries.
locals := map [ string ] struct { } { }
for _ , spec := range specs {
localStr := fmt . Sprintf ( "%v:%v" , spec . listenNetwork , spec . listenAddress )
if _ , ok := locals [ localStr ] ; ok {
return nil , xerrors . Errorf ( "local %v %v is specified twice" , spec . listenNetwork , spec . listenAddress )
}
locals [ localStr ] = struct { } { }
}
return specs , nil
}
// parsePort converts a decimal port string into a uint16, ignoring leading and
// trailing whitespace. It fails when the text is not a base-10 number that
// fits in 16 bits, or when the port is 0.
func parsePort(in string) (uint16, error) {
	trimmed := strings.TrimSpace(in)
	value, err := strconv.ParseUint(trimmed, 10, 16)
	switch {
	case err != nil:
		return 0, xerrors.Errorf("parse port %q: %w", in, err)
	case value == 0:
		return 0, xerrors.New("port cannot be 0")
	}

	// bitSize 16 above guarantees the value fits.
	return uint16(value), nil
}
2022-10-03 08:58:43 +00:00
// parsedSrcDestPort pairs a local port with the remote port it maps to.
type parsedSrcDestPort struct {
	local  uint16
	remote uint16
}
func parseSrcDestPorts ( in string ) ( [ ] parsedSrcDestPort , error ) {
2022-05-18 14:10:40 +00:00
parts := strings . Split ( in , ":" )
if len ( parts ) > 2 {
2022-10-03 08:58:43 +00:00
return nil , xerrors . Errorf ( "invalid port specification %q" , in )
2022-05-18 14:10:40 +00:00
}
if len ( parts ) == 1 {
// Duplicate the single part
parts = append ( parts , parts [ 0 ] )
}
2022-10-03 08:58:43 +00:00
if ! strings . Contains ( parts [ 0 ] , "-" ) {
local , err := parsePort ( parts [ 0 ] )
if err != nil {
return nil , xerrors . Errorf ( "parse local port from %q: %w" , in , err )
}
remote , err := parsePort ( parts [ 1 ] )
if err != nil {
return nil , xerrors . Errorf ( "parse remote port from %q: %w" , in , err )
}
2022-05-18 14:10:40 +00:00
2022-10-03 08:58:43 +00:00
return [ ] parsedSrcDestPort { { local : local , remote : remote } } , nil
}
local , err := parsePortRange ( parts [ 0 ] )
2022-05-18 14:10:40 +00:00
if err != nil {
2022-10-03 08:58:43 +00:00
return nil , xerrors . Errorf ( "parse local port range from %q: %w" , in , err )
2022-05-18 14:10:40 +00:00
}
2022-10-03 08:58:43 +00:00
remote , err := parsePortRange ( parts [ 1 ] )
2022-05-18 14:10:40 +00:00
if err != nil {
2022-10-03 08:58:43 +00:00
return nil , xerrors . Errorf ( "parse remote port range from %q: %w" , in , err )
}
if len ( local ) != len ( remote ) {
return nil , xerrors . Errorf ( "port ranges must be the same length, got %d ports forwarded to %d ports" , len ( local ) , len ( remote ) )
}
var out [ ] parsedSrcDestPort
for i := range local {
out = append ( out , parsedSrcDestPort {
local : local [ i ] ,
remote : remote [ i ] ,
} )
2022-05-18 14:10:40 +00:00
}
2022-10-03 08:58:43 +00:00
return out , nil
}
2022-05-18 14:10:40 +00:00
2022-10-03 08:58:43 +00:00
func parsePortRange ( in string ) ( [ ] uint16 , error ) {
parts := strings . Split ( in , "-" )
if len ( parts ) != 2 {
return nil , xerrors . Errorf ( "invalid port range specification %q" , in )
}
start , err := parsePort ( parts [ 0 ] )
if err != nil {
return nil , xerrors . Errorf ( "parse range start port from %q: %w" , in , err )
}
end , err := parsePort ( parts [ 1 ] )
if err != nil {
return nil , xerrors . Errorf ( "parse range end port from %q: %w" , in , err )
}
if end < start {
return nil , xerrors . Errorf ( "range end port %v is less than start port %v" , end , start )
}
var ports [ ] uint16
for i := start ; i <= end ; i ++ {
ports = append ( ports , i )
}
return ports , nil
2022-05-18 14:10:40 +00:00
}