Skip to content

Commit

Permalink
Add test for copy wait packet
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 21, 2023
1 parent 48acfc4 commit cc0d2f1
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 6 deletions.
31 changes: 30 additions & 1 deletion common/bufio/copy_direct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestCopyWaitTCP(t *testing.T) {
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, LargeDataTest(t, inputConn, &readWaitWrapper{
require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
Conn: outputConn,
readWaiter: readWaiter,
}))
Expand Down Expand Up @@ -46,3 +46,32 @@ func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
r.buffer = buffer
return r.buffer.Read(p)
}

func TestCopyWaitUDP(t *testing.T) {
t.Parallel()
inputConn, outputConn, outputAddr := UDPPipe(t)
readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
require.True(t, created)
require.NotNil(t, readWaiter)
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
PacketConn: outputConn,
readWaiter: readWaiter,
}, outputAddr))
}

type packetReadWaitWrapper struct {
net.PacketConn
readWaiter N.PacketReadWaiter
}

func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer, destination, err := r.readWaiter.WaitReadPacket()
if err != nil {
return
}
n = copy(p, buffer.Bytes())
buffer.Release()
addr = destination.UDPAddr()
return
}
100 changes: 95 additions & 5 deletions common/bufio/net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"io"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -106,7 +107,7 @@ func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error)
return pingCh, pongCh, test
}

func LargeDataTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
times := 100
chunkSize := int64(64 * 1024)

Expand Down Expand Up @@ -135,7 +136,7 @@ func LargeDataTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error
buf := make([]byte, chunkSize)

for i := 0; i < times; i++ {
_, err := io.ReadFull(inputConn, buf)
_, err := io.ReadFull(outputConn, buf)
if err != nil {
t.Log(err.Error())
return
Expand All @@ -145,7 +146,7 @@ func LargeDataTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error
hashMap[int(buf[0])] = hash[:]
}

sendHash, err := writeRandData(inputConn)
sendHash, err := writeRandData(outputConn)
if err != nil {
t.Log(err.Error())
return
Expand All @@ -158,7 +159,7 @@ func LargeDataTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error
}()

go func() {
sendHash, err := writeRandData(outputConn)
sendHash, err := writeRandData(inputConn)
if err != nil {
t.Log(err.Error())
return
Expand All @@ -168,7 +169,7 @@ func LargeDataTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error
buf := make([]byte, chunkSize)

for i := 0; i < times; i++ {
_, err := io.ReadFull(outputConn, buf)
_, err = io.ReadFull(inputConn, buf)
if err != nil {
t.Log(err.Error())
return
Expand All @@ -185,3 +186,92 @@ func LargeDataTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error
}()
return test(t)
}

func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
rAddr := outputAddr.UDPAddr()
times := 50
chunkSize := 9000
pingCh, pongCh, test := newLargeDataPair()
writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
hashMap := map[int][]byte{}
mux := sync.Mutex{}
for i := 0; i < times; i++ {
buf := make([]byte, chunkSize)
if _, err := rand.Read(buf[1:]); err != nil {
t.Log(err.Error())
continue
}
buf[0] = byte(i)

hash := md5.Sum(buf)
mux.Lock()
hashMap[i] = hash[:]
mux.Unlock()

if _, err := pc.WriteTo(buf, addr); err != nil {
t.Log(err.Error())
}

time.Sleep(10 * time.Millisecond)
}

return hashMap, nil
}
go func() {
var (
lAddr net.Addr
err error
)
hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)

for i := 0; i < times; i++ {
_, lAddr, err = outputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}
hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}
sendHash, err := writeRandData(outputConn, lAddr)
if err != nil {
t.Log(err.Error())
return
}

pingCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()

go func() {
sendHash, err := writeRandData(inputConn, rAddr)
if err != nil {
t.Log(err.Error())
return
}

hashMap := map[int][]byte{}
buf := make([]byte, 64*1024)

for i := 0; i < times; i++ {
_, _, err := inputConn.ReadFrom(buf)
if err != nil {
t.Log(err.Error())
return
}

hash := md5.Sum(buf[:chunkSize])
hashMap[int(buf[0])] = hash[:]
}

pongCh <- hashPair{
sendHash: sendHash,
recvHash: hashMap,
}
}()

return test(t)
}

0 comments on commit cc0d2f1

Please sign in to comment.