Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions cherrysrv/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ package main
import (
"fmt"
"sort"
"strings"
"sync"
)

type Channel struct {
clients []*Client // clients in the channel.
Name string // Name of the channel (incl #)
hidden bool
closeOnEmpty bool // only #main should have this as false
sync.RWMutex // for adding/removing client connections
closeOnEmpty bool // only #main should have this as false
sync.RWMutex // for adding/removing client connections
history *ringbuffer // keep the last n messages for repeating
}

func newChannel(name string, hiddenChannel bool) *Channel {
Expand All @@ -21,6 +23,7 @@ func newChannel(name string, hiddenChannel bool) *Channel {
hidden: hiddenChannel,
closeOnEmpty: true,
RWMutex: sync.RWMutex{},
history: newRingBuffer(10),
}
}

Expand Down Expand Up @@ -81,16 +84,12 @@ func (c *Channel) findClient(client *Client) bool {
return false
}

// TODO: Review subtle bug:
// if addClient is blocked because removeClient is working
// and removeClient leaves channel with 0 elements removing it
// from CHANNELS directory completely, addClient will add a client to a
// removed channel.
func (channel *Channel) addClient(newClient *Client) {
channel.Lock()
defer channel.Unlock()

channel.clients = append(channel.clients, newClient)
newClient.Channels = append(newClient.Channels, channel)
}

func (channel *Channel) removeClient(client *Client) bool {
Expand Down Expand Up @@ -125,11 +124,9 @@ func (channel *Channel) removeClient(client *Client) bool {

// len(c.clients) >= 2
// we loop through all the slice, NOT starting in pos=2

for i := 0; i < len; i++ {
if channel.clients[i] == client {
channel.clients[i] = channel.clients[len-1] // TODO: confirm is copying pointer, not content
channel.clients = channel.clients[:len-1]
for i, c := range channel.clients {
if c == client {
channel.clients = append(channel.clients[:i], channel.clients[i+1:]...)

return true
}
Expand All @@ -150,8 +147,10 @@ func (channel *Channel) Say(from *Client, format string, args ...interface{}) {
}

func (c *Channel) write(from *Client, message string) {
trimmed := strings.TrimSpace(message)
c.RLock()
defer c.RUnlock()
c.history.add(&trimmed)

len := len(c.clients)

Expand Down
30 changes: 17 additions & 13 deletions cherrysrv/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ const (

// Client connection storing basic PC data
type Client struct {
conn *net.TCPConn // tcpsocket connection.
Name string // Name of the user.
Status atomic.Int32
conn net.Conn // tcpsocket connection.
Name string // Name of the user.
Status atomic.Int32
Channels []*Channel
}

func (c *Client) String() string {
return c.Name
}

func newClient(conn *net.TCPConn) *Client {
func newClient(conn net.Conn) *Client {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change makes it much easier to build tests since the entire server doesn't have to be constructed.


client := &Client{
conn: conn,
Expand Down Expand Up @@ -182,7 +183,7 @@ func (clt *Client) UpdateInMain(format string, args ...interface{}) {

broadcast := func(key string, client *Client) bool {

if clt == client { // we don't want to send the message to us
if clt == client || !client.isLogged() { // we don't want to send the message to us or non-logged users
return true
}

Expand All @@ -193,16 +194,19 @@ func (clt *Client) UpdateInMain(format string, args ...interface{}) {
CLIENTS.Range(broadcast)
}

// delete me from all the channels. This is extremely CPU consuming.
// TODO: add a slice of channels that the user has joined.
func (clt *Client) RemoveMeFromAllChannels() {

removeClient := func(key string, channel *Channel) bool {

for _, channel := range clt.Channels {
channel.removeClient(clt)

return true
}

CHANNELS.Range(removeClient)
clt.Channels = nil
}

func (clt *Client) RemoveChannel(channel *Channel) {
for i, c := range clt.Channels {
if c == channel {
clt.Channels = append(clt.Channels[:i], clt.Channels[i+1:]...)
return
}
}
}
189 changes: 189 additions & 0 deletions cherrysrv/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package main

import (
"bufio"
"fmt"
"net"
"testing"
"time"
)

func genClient() (c *Client, out net.Conn, in *bufio.Reader) {
server, out := net.Pipe()

in = bufio.NewReader(out)

c = newClient(server)
go c.clientLoop()

if res, _, err := in.ReadLine(); err == nil {
fmt.Println(string(res))
}

return
}

type multiCase struct {
i *bufio.Reader
c net.Conn
s []string
}

// TestClient is a set of ordered happy path tests
func TestClient(t *testing.T) {
init_logger()
init_commands()
main_channel := NewChannelMain("#main")
CHANNELS.Store(main_channel.Key(), main_channel)

c1, out, in := genClient()

username := "@tester"
chan1 := "#test"
chan2 := "#test2"

clientTests := []struct {
input []byte
expected []string
}{
{[]byte("/who\n"), []string{fmt.Sprintf(">/who>0>%s", c1.Name)}},
{[]byte("/join #test\n"), []string{">/join>0>/join requires you to be logged"}},
{[]byte("/nusers\n"), []string{">/nusers>0>/nusers requires you to be logged"}},
{[]byte("/users\n"), []string{">/users>0>/users requires you to be logged"}},
{[]byte("/login\n"), []string{">/login>0>/login <account>"}},
{[]byte(fmt.Sprintf("/login %s\n", username)), []string{fmt.Sprintf(">/login>0>you're now %s", username)}},
{[]byte("/login @tester2\n"), []string{">/login>0>you're already logged in"}},
{[]byte("/nusers\n"), []string{">/nusers>0>1"}},
{[]byte("/users\n"), []string{">/users>0>@tester"}},
{[]byte("/join\n"), []string{">/join>0>/join <#channel>"}},
{[]byte(fmt.Sprintf("/join %s\n", chan1)), []string{fmt.Sprintf(">/join>0>%s joined %s", username, chan1)}},
{[]byte(fmt.Sprintf("/say %s hello\n", chan1)), []string{fmt.Sprintf(">%s>%s>hello", chan1, username)}},
{[]byte(fmt.Sprintf("/say %s goodbye\n", chan1)), []string{fmt.Sprintf(">%s>%s>goodbye", chan1, username)}},
{[]byte("/nusers #bigapple\n"), []string{">/users #bigapple>0>#bigapple is not a valid channel"}},
{[]byte(fmt.Sprintf("/nusers %s\n", chan1)), []string{fmt.Sprintf(">/users %s>0>1", chan1)}},
{[]byte(fmt.Sprintf("/users %s\n", chan1)), []string{fmt.Sprintf(">/users %s>0>%s", chan1, username)}},
{[]byte(fmt.Sprintf("/history %s\n", chan1)), []string{fmt.Sprintf(">/history>1>%s>%s>hello", chan1, username), fmt.Sprintf(">/history>0>%s>%s>goodbye", chan1, username)}},
{[]byte("/list\n"), []string{fmt.Sprintf(">/list>1>%s", main_channel.Name), fmt.Sprintf(">/list>0>%s", chan1)}},
{[]byte(fmt.Sprintf("/leave %s\n", chan1)), []string{fmt.Sprintf(">%s>%s>left the channel", chan1, username)}},
{[]byte("/list\n"), []string{fmt.Sprintf(">/list>0>%s", main_channel.Name)}},
{[]byte(fmt.Sprintf("/hjoin %s\n", chan2)), []string{fmt.Sprintf(">/hjoin>0>%s hjoined %s", username, chan2)}},
{[]byte("/list\n"), []string{fmt.Sprintf(">/list>0>%s", main_channel.Name)}},
{[]byte(fmt.Sprintf("/leave %s\n", chan2)), []string{fmt.Sprintf(">%s>%s>left the channel", chan2, username)}},
{[]byte("/logoff\n"), []string{fmt.Sprintf(">/logoff>0>Goodbye %s", username)}},
}

for _, test := range clientTests {
out.Write(test.input)

for i, ex := range test.expected {
if res, _, err := in.ReadLine(); err != nil {
t.Errorf("failed read, expected %s", test.expected[i])
} else if string(res) != ex {
t.Errorf("got %s, expected %s", string(res), ex)
}
}

}
}

func fullRead(buff *bufio.Reader, conn net.Conn, c chan []string) {
conn.SetReadDeadline(time.Now().Add(250 * time.Millisecond))
var s []string
for {
if res, _, err := buff.ReadLine(); err != nil {
break
} else {
s = append(s, string(res))
}
}
c <- s
}

// TestMultipleClients
Copy link
Contributor Author

@kdedon kdedon Jul 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test cases tend to deadlock and need to be redesigned.

Edit: Fixed the test. The pipe acts like an unbuffered channel and blocks until the data is read. I modified the test handling to read from all the clients at once in different goroutines and then handle them sequentially once they had been collated. I also modified it to fully read everything before moving on to make it easier to debug/create tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if the code is valuable, I think the your code to solve this issue is too complex. I need some time to think about it.

Could you do a PR for the testing only? This is something I'm happy to apply.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where did we leave off on this? @kdedon did you intend to break this into 2 pieces? I hate to have this dangling out here for so long.

func TestMultipleClients(t *testing.T) {
init_logger()
init_commands()
main_channel := NewChannelMain("#main")
CHANNELS.Store(main_channel.Key(), main_channel)

_, out1, in1 := genClient()
_, out2, in2 := genClient()
_, out3, in3 := genClient()

username1 := "@tester1"
username2 := "@tester2"
username3 := "@tester3"

chan1 := "#test"

clientTests := []struct {
o net.Conn
input []byte
expected []multiCase
}{
{out1, []byte(fmt.Sprintf("/login %s\n", username1)),
[]multiCase{{in1, out1, []string{fmt.Sprintf(">/login>0>you're now %s", username1)}}}},

{out2, []byte(fmt.Sprintf("/login %s\n", username1)),
[]multiCase{{in2, out2, []string{fmt.Sprintf(">/login>0>%s is already taken, please select another @name", username1)}}}},

{out2, []byte(fmt.Sprintf("/login %s\n", username2)),
[]multiCase{{in2, out2, []string{fmt.Sprintf(">/login>0>you're now %s", username2)}},
{in1, out1, []string{fmt.Sprintf(">#main>!login>%s has joined the server", username2)}}}},

{out3, []byte(fmt.Sprintf("/login %s\n", username3)),
[]multiCase{{in3, out3, []string{fmt.Sprintf(">/login>0>you're now %s", username3)}},
{in2, out2, []string{fmt.Sprintf(">#main>!login>%s has joined the server", username3)}},
{in1, out1, []string{fmt.Sprintf(">#main>!login>%s has joined the server", username3)}}}},

{out1, []byte(fmt.Sprintf("/join %s\n", chan1)),
[]multiCase{{in1, out1, []string{fmt.Sprintf(">/join>0>%s joined %s", username1, chan1)}}}},

{out2, []byte(fmt.Sprintf("/join %s\n", chan1)),
[]multiCase{{in2, out2, []string{fmt.Sprintf(">%s>%s>joined the channel", chan1, username2)}},
{in1, out1, []string{fmt.Sprintf(">%s>%s>joined the channel", chan1, username2)}}}},

{out3, []byte(fmt.Sprintf("/join %s\n", chan1)),
[]multiCase{{in3, out3, []string{fmt.Sprintf(">%s>%s>joined the channel", chan1, username3)}},
{in2, out2, []string{fmt.Sprintf(">%s>%s>joined the channel", chan1, username3)}},
{in1, out1, []string{fmt.Sprintf(">%s>%s>joined the channel", chan1, username3)}}}},

{out1, []byte(fmt.Sprintf("/say %s hello\n", chan1)),
[]multiCase{{in3, out3, []string{fmt.Sprintf(">%s>%s>hello", chan1, username1)}},
{in2, out2, []string{fmt.Sprintf(">%s>%s>hello", chan1, username1)}},
{in1, out1, []string{fmt.Sprintf(">%s>%s>hello", chan1, username1)}}}},

{out1, []byte("/logoff\n"),
[]multiCase{{in1, out1, []string{fmt.Sprintf(">/logoff>0>Goodbye %s", username1)}},
{in2, out2, []string{fmt.Sprintf(">#main>!logoff>%s is leaving", username1)}},
{in3, out3, []string{fmt.Sprintf(">#main>!logoff>%s is leaving", username1)}}}},
{out2, []byte("/logoff\n"),
[]multiCase{{in2, out2, []string{fmt.Sprintf(">/logoff>0>Goodbye %s", username2)}},
{in3, out3, []string{fmt.Sprintf(">#main>!logoff>%s is leaving", username2)}}}},
{out3, []byte("/logoff\n"),
[]multiCase{{in3, out3, []string{fmt.Sprintf(">/logoff>0>Goodbye %s", username3)}}}},
}

for _, test := range clientTests {
test.o.Write(test.input)

var rets []chan []string
for i, ex := range test.expected {
rets = append(rets, make(chan []string))
go fullRead(ex.i, ex.c, rets[i])
}
for i, ex := range test.expected {
res := <-rets[i]
if len(ex.s) != len(res) {
t.Errorf("got %v, expected %v", res, ex.s)
} else {
for i, s := range ex.s {
if s != res[i] {
t.Errorf("got %s, expected %s", res[i], s)
}
}
}
}

}
}
Loading