package commands import ( "context" "database/sql" "fmt" "strconv" "strings" punchctx "git.tjp.lol/punchcard/internal/context" "git.tjp.lol/punchcard/internal/database" "git.tjp.lol/punchcard/internal/queries" "github.com/spf13/cobra" ) func NewSetCmd() *cobra.Command { cmd := &cobra.Command{ Use: "set [key=value ...]", Short: "Set configuration values for clients, projects, or contractor info", Long: `Set configuration values using key=value pairs. Examples: # Set contractor information (no flags) punch set name="John Doe" label="Software Engineer" email="john@example.com" # Set client information punch set -c "Acme Corp" name="Acme Corporation" email="billing@acme.com" hourly-rate=150.00 # Set project information punch set -p "Website Redesign" name="Website Redesign v2" hourly-rate=180.00 Valid keys: - With no flags (contractor): name, label, email - With -c/--client: name, email, hourly-rate (in dollars) - With -p/--project: name, hourly-rate (in dollars)`, RunE: func(cmd *cobra.Command, args []string) error { return runSetCommand(cmd, args) }, } cmd.Flags().StringP("client", "c", "", "Set values for specified client") cmd.Flags().StringP("project", "p", "", "Set values for specified project") cmd.MarkFlagsMutuallyExclusive("client", "project") return cmd } func runSetCommand(cmd *cobra.Command, args []string) error { // Get flag values clientName, _ := cmd.Flags().GetString("client") projectName, _ := cmd.Flags().GetString("project") // Parse key=value pairs updates, err := parseKeyValuePairs(args) if err != nil { return err } if len(updates) == 0 { return fmt.Errorf("no key=value pairs provided") } // Get database connection q := punchctx.GetDB(cmd.Context()) if q == nil { var err error q, err = database.GetDB() if err != nil { return fmt.Errorf("failed to connect to database: %w", err) } } // Route to appropriate handler if clientName != "" { return setClientValues(q, clientName, updates) } else if projectName != "" { return setProjectValues(q, projectName, updates) } else { return setContractorValues(q, updates) } } func parseKeyValuePairs(args []string) (map[string]string, error) { updates := make(map[string]string) for _, arg := range args { parts := strings.SplitN(arg, "=", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid key=value pair: %s", arg) } key := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) if key == "" { return nil, fmt.Errorf("empty key in pair: %s", arg) } updates[key] = value } return updates, nil } func setClientValues(q *queries.Queries, clientName string, updates map[string]string) error { // Validate keys validKeys := map[string]bool{"name": true, "email": true, "hourly-rate": true} for key := range updates { if !validKeys[key] { return fmt.Errorf("invalid key '%s' for client. Valid keys: name, email, hourly-rate", key) } } // Find the client client, err := findClient(context.Background(), q, clientName) if err != nil { if err == sql.ErrNoRows { return fmt.Errorf("client not found: %s", clientName) } return fmt.Errorf("failed to find client: %w", err) } // Prepare update values (start with current values) newName := client.Name newEmail := client.Email newBillableRate := client.BillableRate // Apply updates for key, value := range updates { switch key { case "name": newName = value case "email": if value == "" { newEmail = sql.NullString{Valid: false} } else { newEmail = sql.NullString{String: value, Valid: true} } case "hourly-rate": if value == "" { newBillableRate = sql.NullInt64{Valid: false} } else { rateFloat, err := strconv.ParseFloat(value, 64) if err != nil { return fmt.Errorf("invalid hourly-rate value '%s': must be a number (in dollars)", value) } if rateFloat < 0 { return fmt.Errorf("hourly-rate must be non-negative") } // Convert dollars to cents rateCents := int64(rateFloat * 100) newBillableRate = sql.NullInt64{Int64: rateCents, Valid: true} } } } // Update the client updated, err := q.UpdateClient(context.Background(), queries.UpdateClientParams{ ID: client.ID, Name: newName, Email: newEmail, BillableRate: newBillableRate, }) if err != nil { return fmt.Errorf("failed to update client: %w", err) } fmt.Printf("Updated client '%s':\n", clientName) fmt.Printf(" name: %s\n", updated.Name) if updated.Email.Valid { fmt.Printf(" email: %s\n", updated.Email.String) } else { fmt.Printf(" email: (not set)\n") } if updated.BillableRate.Valid { fmt.Printf(" billable_rate: %d cents ($%.2f/hour)\n", updated.BillableRate.Int64, float64(updated.BillableRate.Int64)/100.0) } else { fmt.Printf(" billable_rate: (not set)\n") } return nil } func setProjectValues(q *queries.Queries, projectName string, updates map[string]string) error { // Validate keys validKeys := map[string]bool{"name": true, "hourly-rate": true} for key := range updates { if !validKeys[key] { return fmt.Errorf("invalid key '%s' for project. Valid keys: name, hourly-rate", key) } } // Find the project project, err := findProject(context.Background(), q, projectName) if err != nil { return fmt.Errorf("failed to find project: %w", err) } // Prepare update values (start with current values) newName := project.Name newBillableRate := project.BillableRate // Apply updates for key, value := range updates { switch key { case "name": newName = value case "hourly-rate": if value == "" { newBillableRate = sql.NullInt64{Valid: false} } else { rateFloat, err := strconv.ParseFloat(value, 64) if err != nil { return fmt.Errorf("invalid hourly-rate value '%s': must be a number (in dollars)", value) } if rateFloat < 0 { return fmt.Errorf("hourly-rate must be non-negative") } // Convert dollars to cents rateCents := int64(rateFloat * 100) newBillableRate = sql.NullInt64{Int64: rateCents, Valid: true} } } } // Update the project updated, err := q.UpdateProject(context.Background(), queries.UpdateProjectParams{ ID: project.ID, Name: newName, BillableRate: newBillableRate, }) if err != nil { return fmt.Errorf("failed to update project: %w", err) } fmt.Printf("Updated project '%s':\n", projectName) fmt.Printf(" name: %s\n", updated.Name) if updated.BillableRate.Valid { fmt.Printf(" billable_rate: %d cents ($%.2f/hour)\n", updated.BillableRate.Int64, float64(updated.BillableRate.Int64)/100.0) } else { fmt.Printf(" billable_rate: (not set)\n") } return nil } func setContractorValues(q *queries.Queries, updates map[string]string) error { // Validate keys validKeys := map[string]bool{"name": true, "label": true, "email": true} for key := range updates { if !validKeys[key] { return fmt.Errorf("invalid key '%s' for contractor. Valid keys: name, label, email", key) } } // Try to get existing contractor contractor, err := q.GetContractor(context.Background()) var newName, newLabel, newEmail string if err == sql.ErrNoRows { // No contractor exists, we'll create one // Set default values newName = "" newLabel = "" newEmail = "" } else if err != nil { return fmt.Errorf("failed to get contractor information: %w", err) } else { // Contractor exists, start with current values newName = contractor.Name newLabel = contractor.Label newEmail = contractor.Email } // Apply updates for key, value := range updates { switch key { case "name": newName = value case "label": newLabel = value case "email": newEmail = value } } // Validate required fields if newName == "" { return fmt.Errorf("contractor name cannot be empty") } if newLabel == "" { return fmt.Errorf("contractor label cannot be empty") } if newEmail == "" { return fmt.Errorf("contractor email cannot be empty") } // Create or update contractor if err == sql.ErrNoRows { // Create new contractor created, err := q.CreateContractor(context.Background(), queries.CreateContractorParams{ Name: newName, Label: newLabel, Email: newEmail, }) if err != nil { return fmt.Errorf("failed to create contractor: %w", err) } fmt.Printf("Created contractor:\n") fmt.Printf(" name: %s\n", created.Name) fmt.Printf(" label: %s\n", created.Label) fmt.Printf(" email: %s\n", created.Email) } else { // Update existing contractor updated, err := q.UpdateContractor(context.Background(), queries.UpdateContractorParams{ Name: newName, Label: newLabel, Email: newEmail, }) if err != nil { return fmt.Errorf("failed to update contractor: %w", err) } fmt.Printf("Updated contractor:\n") fmt.Printf(" name: %s\n", updated.Name) fmt.Printf(" label: %s\n", updated.Label) fmt.Printf(" email: %s\n", updated.Email) } return nil }