Provision default workspaces and harden daemon pairing

This commit is contained in:
Jiayuan Zhang 2026-03-24 15:19:27 +08:00
parent 2e5e24f194
commit 4c6eb81789
11 changed files with 492 additions and 5 deletions

View file

@ -1,6 +1,7 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"strings"
@ -42,6 +43,111 @@ type LoginResponse struct {
User UserResponse `json:"user"`
}
func defaultWorkspaceName(user db.User) string {
name := strings.TrimSpace(user.Name)
if name == "" {
email := strings.TrimSpace(user.Email)
if at := strings.Index(email, "@"); at > 0 {
name = email[:at]
}
}
if name == "" {
name = "Personal"
}
return name + "'s Workspace"
}
func slugifyWorkspacePart(value string) string {
value = strings.ToLower(strings.TrimSpace(value))
var b strings.Builder
lastWasDash := false
for _, r := range value {
switch {
case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
b.WriteRune(r)
lastWasDash = false
case b.Len() > 0 && !lastWasDash:
b.WriteByte('-')
lastWasDash = true
}
}
return strings.Trim(b.String(), "-")
}
func defaultWorkspaceSlug(user db.User) string {
candidates := []string{
slugifyWorkspacePart(user.Name),
slugifyWorkspacePart(strings.Split(strings.TrimSpace(user.Email), "@")[0]),
"workspace",
}
base := "workspace"
for _, candidate := range candidates {
if candidate != "" {
base = candidate
break
}
}
userID := uuidToString(user.ID)
if len(userID) >= 8 {
return base + "-" + userID[:8]
}
return base
}
func (h *Handler) ensureUserWorkspace(ctx context.Context, user db.User) error {
workspaces, err := h.Queries.ListWorkspaces(ctx, user.ID)
if err != nil {
return err
}
if len(workspaces) > 0 {
return nil
}
tx, err := h.TxStarter.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
qtx := h.Queries.WithTx(tx)
workspaces, err = qtx.ListWorkspaces(ctx, user.ID)
if err != nil {
return err
}
if len(workspaces) > 0 {
return nil
}
workspace, err := qtx.CreateWorkspace(ctx, db.CreateWorkspaceParams{
Name: defaultWorkspaceName(user),
Slug: defaultWorkspaceSlug(user),
Description: pgtype.Text{},
})
if err != nil {
if isUniqueViolation(err) {
workspaces, lookupErr := h.Queries.ListWorkspaces(ctx, user.ID)
if lookupErr == nil && len(workspaces) > 0 {
return nil
}
}
return err
}
if _, err := qtx.CreateMember(ctx, db.CreateMemberParams{
WorkspaceID: workspace.ID,
UserID: user.ID,
Role: "owner",
}); err != nil {
return err
}
return tx.Commit(ctx)
}
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -89,6 +195,11 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
}
}
if err := h.ensureUserWorkspace(r.Context(), user); err != nil {
writeError(w, http.StatusInternalServerError, "failed to provision workspace")
return
}
// Generate JWT
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": uuidToString(user.ID),

View file

@ -33,6 +33,10 @@ func (h *Handler) DaemonRegister(w http.ResponseWriter, r *http.Request) {
return
}
req.WorkspaceID = strings.TrimSpace(req.WorkspaceID)
req.DaemonID = strings.TrimSpace(req.DaemonID)
req.DeviceName = strings.TrimSpace(req.DeviceName)
if req.DaemonID == "" {
writeError(w, http.StatusBadRequest, "daemon_id is required")
return
@ -45,6 +49,10 @@ func (h *Handler) DaemonRegister(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "at least one runtime is required")
return
}
if _, err := h.Queries.GetWorkspace(r.Context(), parseUUID(req.WorkspaceID)); err != nil {
writeError(w, http.StatusNotFound, "workspace not found")
return
}
resp := make([]AgentRuntimeResponse, 0, len(req.Runtimes))
for _, runtime := range req.Runtimes {

View file

@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/go-chi/chi/v5"
@ -17,6 +18,7 @@ import (
)
var testHandler *Handler
var testPool *pgxpool.Pool
var testUserID string
var testWorkspaceID string
@ -43,6 +45,7 @@ func TestMain(m *testing.M) {
hub := realtime.NewHub()
go hub.Run()
testHandler = New(queries, pool, hub)
testPool = pool
testUserID, testWorkspaceID, err = setupHandlerTestFixture(ctx, pool)
if err != nil {
@ -371,3 +374,73 @@ func TestAuthLogin(t *testing.T) {
t.Fatalf("Login: expected email 'test-handler@multica.ai', got '%s'", resp.User.Email)
}
}
func TestAuthLoginCreatesWorkspaceForNewUser(t *testing.T) {
const email = "new-handler-login@multica.ai"
ctx := context.Background()
t.Cleanup(func() {
user, err := testHandler.Queries.GetUserByEmail(ctx, email)
if err == nil {
workspaces, listErr := testHandler.Queries.ListWorkspaces(ctx, user.ID)
if listErr == nil {
for _, workspace := range workspaces {
_ = testHandler.Queries.DeleteWorkspace(ctx, workspace.ID)
}
}
}
_, _ = testPool.Exec(ctx, `DELETE FROM "user" WHERE email = $1`, email)
})
_, _ = testPool.Exec(ctx, `DELETE FROM "user" WHERE email = $1`, email)
w := httptest.NewRecorder()
body := map[string]string{"email": email, "name": "Workspace Owner"}
var buf bytes.Buffer
json.NewEncoder(&buf).Encode(body)
req := httptest.NewRequest("POST", "/auth/login", &buf)
req.Header.Set("Content-Type", "application/json")
testHandler.Login(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Login: expected 200, got %d: %s", w.Code, w.Body.String())
}
user, err := testHandler.Queries.GetUserByEmail(ctx, email)
if err != nil {
t.Fatalf("GetUserByEmail: %v", err)
}
workspaces, err := testHandler.Queries.ListWorkspaces(ctx, user.ID)
if err != nil {
t.Fatalf("ListWorkspaces: %v", err)
}
if len(workspaces) != 1 {
t.Fatalf("ListWorkspaces: expected 1 workspace, got %d", len(workspaces))
}
if !strings.Contains(workspaces[0].Name, "Workspace") {
t.Fatalf("expected auto-created workspace name, got %q", workspaces[0].Name)
}
if workspaces[0].Slug == "" {
t.Fatal("expected auto-created workspace slug")
}
}
func TestDaemonRegisterMissingWorkspaceReturns404(t *testing.T) {
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/daemon/register", bytes.NewBufferString(`{
"workspace_id":"00000000-0000-0000-0000-000000000001",
"daemon_id":"local-daemon",
"device_name":"test-machine",
"runtimes":[{"name":"Local Codex","type":"codex","version":"1.0.0","status":"online"}]
}`))
req.Header.Set("Content-Type", "application/json")
testHandler.DaemonRegister(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("DaemonRegister: expected 404, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "workspace not found") {
t.Fatalf("DaemonRegister: expected workspace not found error, got %s", w.Body.String())
}
}