Merge pull request #6023 from SeldonIO/v2
fix(ci): Changes from v2 for release 2.8.5 (3)
sakoush authored Nov 1, 2024
2 parents c3ec9ed + 2c79e77 commit e3691bb
Showing 8 changed files with 361 additions and 14 deletions.
23 changes: 22 additions & 1 deletion scheduler/data-flow/src/main/kotlin/io/seldon/dataflow/Cli.kt
@@ -11,6 +11,7 @@ package io.seldon.dataflow

import com.natpryce.konfig.CommandLineOption
import com.natpryce.konfig.Configuration
import com.natpryce.konfig.ConfigurationMap
import com.natpryce.konfig.ConfigurationProperties
import com.natpryce.konfig.EnvironmentVariables
import com.natpryce.konfig.Key
@@ -25,6 +26,8 @@ import io.klogging.Level
import io.klogging.noCoLogger
import io.seldon.dataflow.kafka.security.KafkaSaslMechanisms
import io.seldon.dataflow.kafka.security.KafkaSecurityProtocols
import java.net.InetAddress
import java.util.UUID

object Cli {
private const val ENV_VAR_PREFIX = "SELDON_"
@@ -34,6 +37,7 @@ object Cli {
val logLevelApplication = Key("log.level.app", enumType(*Level.values()))
val logLevelKafka = Key("log.level.kafka", enumType(*Level.values()))
val namespace = Key("pod.namespace", stringType)
val dataflowReplicaId = Key("dataflow.replica.id", stringType)

// Seldon components
val upstreamHost = Key("upstream.host", stringType)
@@ -75,6 +79,7 @@ object Cli {
logLevelApplication,
logLevelKafka,
namespace,
dataflowReplicaId,
upstreamHost,
upstreamPort,
kafkaBootstrapServers,
@@ -105,10 +110,26 @@ object Cli {

fun configWith(rawArgs: Array<String>): Configuration {
val fromProperties = ConfigurationProperties.fromResource("local.properties")
val fromSystem = getSystemConfig()
val fromEnv = EnvironmentVariables(prefix = ENV_VAR_PREFIX)
val fromArgs = parseArguments(rawArgs)

return fromArgs overriding fromEnv overriding fromProperties
return fromArgs overriding fromEnv overriding fromSystem overriding fromProperties
}

private fun getSystemConfig(): Configuration {
val dataflowIdPair = this.dataflowReplicaId to getNewDataflowId()
return ConfigurationMap(dataflowIdPair)
}

fun getNewDataflowId(assignRandomUuid: Boolean = false): String {
if (!assignRandomUuid) {
try {
return InetAddress.getLocalHost().hostName
} catch (_: Exception) {
}
}
return "seldon-dataflow-engine-" + UUID.randomUUID().toString()
}

private fun parseArguments(rawArgs: Array<String>): Configuration {
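The net effect of the Cli.kt change above: a new `dataflow.replica.id` key whose default comes from a system-level configuration layer, namely the pod hostname when it can be resolved, otherwise a `seldon-dataflow-engine-<uuid>` identifier. A minimal standalone sketch of that fallback, using only the JDK (the function name `defaultReplicaId` and the `main` driver are illustrative, not part of the commit):

```kotlin
import java.net.InetAddress
import java.util.UUID

// Sketch of the default replica-id logic: prefer the local hostname,
// otherwise fall back to "seldon-dataflow-engine-<random UUID>".
fun defaultReplicaId(assignRandomUuid: Boolean = false): String {
    if (!assignRandomUuid) {
        try {
            return InetAddress.getLocalHost().hostName
        } catch (_: Exception) {
            // hostname could not be resolved: fall through to the UUID-based id
        }
    }
    return "seldon-dataflow-engine-" + UUID.randomUUID().toString()
}

fun main() {
    println(defaultReplicaId())     // typically the hostname
    println(defaultReplicaId(true)) // always a UUID-suffixed id
}
```

Because this system layer sits below CLI arguments and SELDON_-prefixed environment variables in the overriding chain, the id can still be pinned explicitly, for example with the `--dataflow-replica-id` flag exercised in CliTest below.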
@@ -102,9 +102,11 @@ object Main {
describeRetries = config[Cli.topicDescribeRetries],
describeRetryDelayMillis = config[Cli.topicDescribeRetryDelayMillis],
)
val subscriberId = config[Cli.dataflowReplicaId]

val subscriber =
PipelineSubscriber(
"seldon-dataflow-engine",
subscriberId,
kafkaProperties,
kafkaAdminProperties,
kafkaStreamsParams,
1 change: 1 addition & 0 deletions scheduler/data-flow/src/main/resources/local.properties
@@ -1,5 +1,6 @@
log.level.app=INFO
log.level.kafka=WARN
dataflow.replica.id=seldon-dataflow-engine
kafka.bootstrap.servers=localhost:9092
kafka.consumer.prefix=
kafka.security.protocol=PLAINTEXT
31 changes: 31 additions & 0 deletions scheduler/data-flow/src/test/kotlin/io/seldon/dataflow/CliTest.kt
@@ -16,9 +16,15 @@ import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.Arguments.arguments
import org.junit.jupiter.params.provider.MethodSource
import strikt.api.expectCatching
import strikt.api.expectThat
import strikt.assertions.hasLength
import strikt.assertions.isEqualTo
import strikt.assertions.isNotEqualTo
import strikt.assertions.isSuccess
import strikt.assertions.startsWith
import java.util.UUID
import java.util.stream.Stream
import kotlin.test.Test

internal class CliTest {
@DisplayName("Passing auth mechanism via cli argument")
@@ -36,6 +42,31 @@ internal class CliTest {
.isEqualTo(expectedMechanism)
}

@Test
fun `should handle dataflow replica id`() {
val cliDefault = Cli.configWith(arrayOf<String>())
val testReplicaId = "dataflow-id-1"
val cli = Cli.configWith(arrayOf("--dataflow-replica-id", testReplicaId))

expectThat(cliDefault[Cli.dataflowReplicaId]) {
isNotEqualTo("seldon-dataflow-engine")
}
expectThat(cli[Cli.dataflowReplicaId]) {
isEqualTo(testReplicaId)
}

// test random Uuid (v4)
val expectedReplicaIdPrefix = "seldon-dataflow-engine-"
val uuidStringLength = 36
val randomReplicaUuid = Cli.getNewDataflowId(true)
expectThat(randomReplicaUuid) {
startsWith(expectedReplicaIdPrefix)
hasLength(expectedReplicaIdPrefix.length + uuidStringLength)
}
expectCatching { UUID.fromString(randomReplicaUuid.removePrefix(expectedReplicaIdPrefix)) }
.isSuccess()
}

companion object {
@JvmStatic
private fun saslMechanisms(): Stream<Arguments> {
14 changes: 12 additions & 2 deletions scheduler/pkg/agent/server.go
@@ -112,6 +112,7 @@ type Server struct {
certificateStore *seldontls.CertificateStore
waiter *modelRelocatedWaiter // waiter for when we want to drain a particular server replica
autoscalingServiceEnabled bool
agentMutex sync.Map // to force a serial order per agent (serverName, replicaIdx)
}

type SchedulerAgent interface {
@@ -138,6 +139,7 @@ func NewAgentServer(
scheduler: scheduler,
waiter: newModelRelocatedWaiter(),
autoscalingServiceEnabled: autoscalingServiceEnabled,
agentMutex: sync.Map{},
}

hub.RegisterModelEventHandler(
@@ -383,12 +385,20 @@ func (s *Server) ModelScalingTrigger(stream pb.AgentService_ModelScalingTriggerS

func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentService_SubscribeServer) error {
logger := s.logger.WithField("func", "Subscribe")
key := ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}

// this is forcing a serial order per agent (serverName, replicaIdx)
// in general this will make sure that a given agent disconnects fully before another agent is allowed to connect
mu, _ := s.agentMutex.LoadOrStore(key, &sync.Mutex{})
mu.(*sync.Mutex).Lock()
defer mu.(*sync.Mutex).Unlock()

logger.Infof("Received subscribe request from %s:%d", request.ServerName, request.ReplicaIdx)

fin := make(chan bool)

s.mutex.Lock()
s.agents[ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx}] = &AgentSubscriber{
s.agents[key] = &AgentSubscriber{
finished: fin,
stream: stream,
}
@@ -414,7 +424,7 @@ func (s *Server) Subscribe(request *pb.AgentSubscribeRequest, stream pb.AgentSer
case <-ctx.Done():
logger.Infof("Client replica %s:%d has disconnected", request.ServerName, request.ReplicaIdx)
s.mutex.Lock()
delete(s.agents, ServerKey{serverName: request.ServerName, replicaIdx: request.ReplicaIdx})
delete(s.agents, key)
s.mutex.Unlock()
s.removeServerReplicaImpl(request.GetServerName(), int(request.GetReplicaIdx())) // this is non-blocking beyond rescheduling models on removed server
return nil
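The Subscribe change above keys a per-agent mutex in a `sync.Map` so that a disconnect for one `(serverName, replicaIdx)` finishes before the same replica can subscribe again. A minimal standalone sketch of that LoadOrStore pattern (the `serverKey` type and `withAgentLock` helper are illustrative stand-ins, not the scheduler's own types):

```go
package main

import (
	"fmt"
	"sync"
)

type serverKey struct {
	serverName string
	replicaIdx uint32
}

// agentMutex maps serverKey -> *sync.Mutex, created lazily per agent replica.
var agentMutex sync.Map

// withAgentLock runs fn while holding the mutex for one (serverName, replicaIdx),
// so connect and disconnect for the same agent replica never interleave.
func withAgentLock(key serverKey, fn func()) {
	mu, _ := agentMutex.LoadOrStore(key, &sync.Mutex{})
	mu.(*sync.Mutex).Lock()
	defer mu.(*sync.Mutex).Unlock()
	fn()
}

func main() {
	var wg sync.WaitGroup
	key := serverKey{serverName: "dummy", replicaIdx: 1}
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			withAgentLock(key, func() { fmt.Println("agent op", i) })
		}(i)
	}
	wg.Wait()
}
```

The `duplicates` cases in server_test.go below exercise this path: repeated subscriptions for the same replica leave a single entry in `server.agents`.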
142 changes: 140 additions & 2 deletions scheduler/pkg/agent/server_test.go
@@ -10,21 +10,40 @@ the Change License after the Change Date as each is defined in accordance with t
package agent

import (
"context"
"fmt"
"sync"
"testing"
"time"

. "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/scheduler"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
)

type mockScheduler struct {
}

var _ scheduler.Scheduler = (*mockScheduler)(nil)

func (s mockScheduler) Schedule(_ string) error {
return nil
}

func (s mockScheduler) ScheduleFailedModels() ([]string, error) {
return nil, nil
}

type mockStore struct {
models map[string]*store.ModelSnapshot
}
@@ -91,15 +110,15 @@ func (m *mockStore) UpdateModelState(modelKey string, version uint32, serverKey
}

func (m *mockStore) AddServerReplica(request *pb.AgentSubscribeRequest) error {
panic("implement me")
return nil
}

func (m *mockStore) ServerNotify(request *pbs.ServerNotify) error {
panic("implement me")
}

func (m *mockStore) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) {
panic("implement me")
return nil, nil
}

func (m *mockStore) DrainServerReplica(serverName string, replicaIdx int) ([]string, error) {
@@ -943,3 +962,122 @@ func TestAutoscalingEnabled(t *testing.T) {
}

}

func TestSubscribe(t *testing.T) {
log.SetLevel(log.DebugLevel)
g := NewGomegaWithT(t)

type ag struct {
id uint32
doClose bool
}
type test struct {
name string
agents []ag
expectedAgentsCount int
expectedAgentsCountAfterClose int
}
tests := []test{
{
name: "simple",
agents: []ag{
{1, true}, {2, true},
},
expectedAgentsCount: 2,
expectedAgentsCountAfterClose: 0,
},
{
name: "simple - no close",
agents: []ag{
{1, true}, {2, false},
},
expectedAgentsCount: 2,
expectedAgentsCountAfterClose: 1,
},
{
name: "duplicates",
agents: []ag{
{1, true}, {1, false},
},
expectedAgentsCount: 1,
expectedAgentsCountAfterClose: 1,
},
{
name: "duplicates with all close",
agents: []ag{
{1, true}, {1, true}, {1, true},
},
expectedAgentsCount: 1,
expectedAgentsCountAfterClose: 0,
},
}

getStream := func(id uint32, context context.Context, port int) *grpc.ClientConn {
conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
grpcClient := agent.NewAgentServiceClient(conn)
_, _ = grpcClient.Subscribe(
context,
&agent.AgentSubscribeRequest{
ServerName: "dummy",
ReplicaIdx: id,
ReplicaConfig: &agent.ReplicaConfig{},
Shared: true,
AvailableMemoryBytes: 0,
},
)
return conn
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := log.New()
eventHub, err := coordinator.NewEventHub(logger)
g.Expect(err).To(BeNil())
server := NewAgentServer(logger, &mockStore{}, mockScheduler{}, eventHub, false)
port, err := testing_utils.GetFreePortForTest()
if err != nil {
t.Fatal(err)
}
err = server.startServer(uint(port), false)
if err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)

mu := sync.Mutex{}
streams := make([]*grpc.ClientConn, 0)
for _, a := range test.agents {
go func(id uint32) {
conn := getStream(id, context.Background(), port)
mu.Lock()
streams = append(streams, conn)
mu.Unlock()
}(a.id)
}

maxCount := 10
count := 0
for len(server.agents) != test.expectedAgentsCount && count < maxCount {
time.Sleep(100 * time.Millisecond)
count++
}
g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCount))

for idx, s := range streams {
go func(idx int, s *grpc.ClientConn) {
if test.agents[idx].doClose {
s.Close()
}
}(idx, s)
}

count = 0
for len(server.agents) != test.expectedAgentsCountAfterClose && count < maxCount {
time.Sleep(100 * time.Millisecond)
count++
}
g.Expect(len(server.agents)).To(Equal(test.expectedAgentsCountAfterClose))

server.StopAgentStreams()
})
}
}