summaryrefslogtreecommitdiff
path: root/internal/commands/set.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/commands/set.go')
-rw-r--r--internal/commands/set.go333
1 files changed, 333 insertions, 0 deletions
diff --git a/internal/commands/set.go b/internal/commands/set.go
new file mode 100644
index 0000000..32f3b96
--- /dev/null
+++ b/internal/commands/set.go
@@ -0,0 +1,333 @@
+package commands
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strconv"
+ "strings"
+
+ punchctx "punchcard/internal/context"
+ "punchcard/internal/database"
+ "punchcard/internal/queries"
+
+ "github.com/spf13/cobra"
+)
+
+func NewSetCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "set [key=value ...]",
+ Short: "Set configuration values for clients, projects, or contractor info",
+ Long: `Set configuration values using key=value pairs.
+
+Examples:
+ # Set contractor information (no flags)
+ punch set name="John Doe" label="Software Engineer" email="john@example.com"
+
+ # Set client information
+ punch set -c "Acme Corp" name="Acme Corporation" email="billing@acme.com" hourly-rate=150.00
+
+ # Set project information
+ punch set -p "Website Redesign" name="Website Redesign v2" hourly-rate=180.00
+
+Valid keys:
+ - With no flags (contractor): name, label, email
+ - With -c/--client: name, email, hourly-rate (in dollars)
+ - With -p/--project: name, hourly-rate (in dollars)`,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return runSetCommand(cmd, args)
+ },
+ }
+
+ cmd.Flags().StringP("client", "c", "", "Set values for specified client")
+ cmd.Flags().StringP("project", "p", "", "Set values for specified project")
+ cmd.MarkFlagsMutuallyExclusive("client", "project")
+
+ return cmd
+}
+
+func runSetCommand(cmd *cobra.Command, args []string) error {
+ // Get flag values
+ clientName, _ := cmd.Flags().GetString("client")
+ projectName, _ := cmd.Flags().GetString("project")
+
+ // Parse key=value pairs
+ updates, err := parseKeyValuePairs(args)
+ if err != nil {
+ return err
+ }
+
+ if len(updates) == 0 {
+ return fmt.Errorf("no key=value pairs provided")
+ }
+
+ // Get database connection
+ q := punchctx.GetDB(cmd.Context())
+ if q == nil {
+ var err error
+ q, err = database.GetDB()
+ if err != nil {
+ return fmt.Errorf("failed to connect to database: %w", err)
+ }
+ }
+
+ // Route to appropriate handler
+ if clientName != "" {
+ return setClientValues(q, clientName, updates)
+ } else if projectName != "" {
+ return setProjectValues(q, projectName, updates)
+ } else {
+ return setContractorValues(q, updates)
+ }
+}
+
+func parseKeyValuePairs(args []string) (map[string]string, error) {
+ updates := make(map[string]string)
+
+ for _, arg := range args {
+ parts := strings.SplitN(arg, "=", 2)
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid key=value pair: %s", arg)
+ }
+
+ key := strings.TrimSpace(parts[0])
+ value := strings.TrimSpace(parts[1])
+
+ if key == "" {
+ return nil, fmt.Errorf("empty key in pair: %s", arg)
+ }
+
+ updates[key] = value
+ }
+
+ return updates, nil
+}
+
+func setClientValues(q *queries.Queries, clientName string, updates map[string]string) error {
+ // Validate keys
+ validKeys := map[string]bool{"name": true, "email": true, "hourly-rate": true}
+ for key := range updates {
+ if !validKeys[key] {
+ return fmt.Errorf("invalid key '%s' for client. Valid keys: name, email, hourly-rate", key)
+ }
+ }
+
+ // Find the client
+ client, err := findClient(context.Background(), q, clientName)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return fmt.Errorf("client not found: %s", clientName)
+ }
+ return fmt.Errorf("failed to find client: %w", err)
+ }
+
+ // Prepare update values (start with current values)
+ newName := client.Name
+ newEmail := client.Email
+ newBillableRate := client.BillableRate
+
+ // Apply updates
+ for key, value := range updates {
+ switch key {
+ case "name":
+ newName = value
+ case "email":
+ if value == "" {
+ newEmail = sql.NullString{Valid: false}
+ } else {
+ newEmail = sql.NullString{String: value, Valid: true}
+ }
+ case "hourly-rate":
+ if value == "" {
+ newBillableRate = sql.NullInt64{Valid: false}
+ } else {
+ rateFloat, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return fmt.Errorf("invalid hourly-rate value '%s': must be a number (in dollars)", value)
+ }
+ if rateFloat < 0 {
+ return fmt.Errorf("hourly-rate must be non-negative")
+ }
+ // Convert dollars to cents
+ rateCents := int64(rateFloat * 100)
+ newBillableRate = sql.NullInt64{Int64: rateCents, Valid: true}
+ }
+ }
+ }
+
+ // Update the client
+ updated, err := q.UpdateClient(context.Background(), queries.UpdateClientParams{
+ ID: client.ID,
+ Name: newName,
+ Email: newEmail,
+ BillableRate: newBillableRate,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to update client: %w", err)
+ }
+
+ fmt.Printf("Updated client '%s':\n", clientName)
+ fmt.Printf(" name: %s\n", updated.Name)
+ if updated.Email.Valid {
+ fmt.Printf(" email: %s\n", updated.Email.String)
+ } else {
+ fmt.Printf(" email: (not set)\n")
+ }
+ if updated.BillableRate.Valid {
+ fmt.Printf(" billable_rate: %d cents ($%.2f/hour)\n", updated.BillableRate.Int64, float64(updated.BillableRate.Int64)/100.0)
+ } else {
+ fmt.Printf(" billable_rate: (not set)\n")
+ }
+
+ return nil
+}
+
+func setProjectValues(q *queries.Queries, projectName string, updates map[string]string) error {
+ // Validate keys
+ validKeys := map[string]bool{"name": true, "hourly-rate": true}
+ for key := range updates {
+ if !validKeys[key] {
+ return fmt.Errorf("invalid key '%s' for project. Valid keys: name, hourly-rate", key)
+ }
+ }
+
+ // Find the project
+ project, err := findProject(context.Background(), q, projectName)
+ if err != nil {
+ return fmt.Errorf("failed to find project: %w", err)
+ }
+
+ // Prepare update values (start with current values)
+ newName := project.Name
+ newBillableRate := project.BillableRate
+
+ // Apply updates
+ for key, value := range updates {
+ switch key {
+ case "name":
+ newName = value
+ case "hourly-rate":
+ if value == "" {
+ newBillableRate = sql.NullInt64{Valid: false}
+ } else {
+ rateFloat, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return fmt.Errorf("invalid hourly-rate value '%s': must be a number (in dollars)", value)
+ }
+ if rateFloat < 0 {
+ return fmt.Errorf("hourly-rate must be non-negative")
+ }
+ // Convert dollars to cents
+ rateCents := int64(rateFloat * 100)
+ newBillableRate = sql.NullInt64{Int64: rateCents, Valid: true}
+ }
+ }
+ }
+
+ // Update the project
+ updated, err := q.UpdateProject(context.Background(), queries.UpdateProjectParams{
+ ID: project.ID,
+ Name: newName,
+ BillableRate: newBillableRate,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to update project: %w", err)
+ }
+
+ fmt.Printf("Updated project '%s':\n", projectName)
+ fmt.Printf(" name: %s\n", updated.Name)
+ if updated.BillableRate.Valid {
+ fmt.Printf(" billable_rate: %d cents ($%.2f/hour)\n", updated.BillableRate.Int64, float64(updated.BillableRate.Int64)/100.0)
+ } else {
+ fmt.Printf(" billable_rate: (not set)\n")
+ }
+
+ return nil
+}
+
+func setContractorValues(q *queries.Queries, updates map[string]string) error {
+ // Validate keys
+ validKeys := map[string]bool{"name": true, "label": true, "email": true}
+ for key := range updates {
+ if !validKeys[key] {
+ return fmt.Errorf("invalid key '%s' for contractor. Valid keys: name, label, email", key)
+ }
+ }
+
+ // Try to get existing contractor
+ contractor, err := q.GetContractor(context.Background())
+
+ var newName, newLabel, newEmail string
+
+ if err == sql.ErrNoRows {
+ // No contractor exists, we'll create one
+ // Set default values
+ newName = ""
+ newLabel = ""
+ newEmail = ""
+ } else if err != nil {
+ return fmt.Errorf("failed to get contractor information: %w", err)
+ } else {
+ // Contractor exists, start with current values
+ newName = contractor.Name
+ newLabel = contractor.Label
+ newEmail = contractor.Email
+ }
+
+ // Apply updates
+ for key, value := range updates {
+ switch key {
+ case "name":
+ newName = value
+ case "label":
+ newLabel = value
+ case "email":
+ newEmail = value
+ }
+ }
+
+ // Validate required fields
+ if newName == "" {
+ return fmt.Errorf("contractor name cannot be empty")
+ }
+ if newLabel == "" {
+ return fmt.Errorf("contractor label cannot be empty")
+ }
+ if newEmail == "" {
+ return fmt.Errorf("contractor email cannot be empty")
+ }
+
+ // Create or update contractor
+ if err == sql.ErrNoRows {
+ // Create new contractor
+ created, err := q.CreateContractor(context.Background(), queries.CreateContractorParams{
+ Name: newName,
+ Label: newLabel,
+ Email: newEmail,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create contractor: %w", err)
+ }
+ fmt.Printf("Created contractor:\n")
+ fmt.Printf(" name: %s\n", created.Name)
+ fmt.Printf(" label: %s\n", created.Label)
+ fmt.Printf(" email: %s\n", created.Email)
+ } else {
+ // Update existing contractor
+ updated, err := q.UpdateContractor(context.Background(), queries.UpdateContractorParams{
+ Name: newName,
+ Label: newLabel,
+ Email: newEmail,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to update contractor: %w", err)
+ }
+ fmt.Printf("Updated contractor:\n")
+ fmt.Printf(" name: %s\n", updated.Name)
+ fmt.Printf(" label: %s\n", updated.Label)
+ fmt.Printf(" email: %s\n", updated.Email)
+ }
+
+ return nil
+}
+