summaryrefslogtreecommitdiff
path: root/internal/commands/add_project.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/commands/add_project.go')
-rw-r--r--internal/commands/add_project.go58
1 files changed, 11 insertions, 47 deletions
diff --git a/internal/commands/add_project.go b/internal/commands/add_project.go
index 6c37e2a..1ed42db 100644
--- a/internal/commands/add_project.go
+++ b/internal/commands/add_project.go
@@ -1,13 +1,10 @@
package commands
import (
- "context"
- "database/sql"
"fmt"
- "strconv"
+ "punchcard/internal/actions"
punchctx "punchcard/internal/context"
- "punchcard/internal/queries"
"github.com/spf13/cobra"
)
@@ -34,32 +31,26 @@ Examples:
}
billableRateFloat, _ := cmd.Flags().GetFloat64("hourly-rate")
- billableRate := int64(billableRateFloat * 100) // Convert dollars to cents
+ var billableRate *float64
+ if billableRateFloat > 0 {
+ billableRate = &billableRateFloat
+ }
q := punchctx.GetDB(cmd.Context())
if q == nil {
return fmt.Errorf("database not available in context")
}
- // Find client by ID or name
- client, err := findClient(cmd.Context(), q, clientRef)
+ a := actions.New(q)
+ project, err := a.CreateProject(cmd.Context(), projectName, clientRef, billableRate)
if err != nil {
- return fmt.Errorf("failed to find client: %w", err)
- }
-
- // Create project
- var billableRateParam sql.NullInt64
- if billableRate > 0 {
- billableRateParam = sql.NullInt64{Int64: billableRate, Valid: true}
+ return err
}
- project, err := q.CreateProject(cmd.Context(), queries.CreateProjectParams{
- Name: projectName,
- ClientID: client.ID,
- BillableRate: billableRateParam,
- })
+ // Get client name for output
+ client, err := a.FindClient(cmd.Context(), clientRef)
if err != nil {
- return fmt.Errorf("failed to create project: %w", err)
+ return fmt.Errorf("failed to get client name: %w", err)
}
output := fmt.Sprintf("Created project: %s for client %s (ID: %d)", project.Name, client.Name, project.ID)
@@ -77,30 +68,3 @@ Examples:
return cmd
}
-
-func findClient(ctx context.Context, q *queries.Queries, clientRef string) (queries.Client, error) {
- // Parse clientRef as ID if possible, otherwise use 0
- var idParam int64
- if id, err := strconv.ParseInt(clientRef, 10, 64); err == nil {
- idParam = id
- }
-
- // Search by both ID and name using UNION ALL
- clients, err := q.FindClient(ctx, queries.FindClientParams{
- ID: idParam,
- Name: clientRef,
- })
- if err != nil {
- return queries.Client{}, fmt.Errorf("database error looking up client: %w", err)
- }
-
- // Check results
- switch len(clients) {
- case 0:
- return queries.Client{}, fmt.Errorf("client not found: %s", clientRef)
- case 1:
- return clients[0], nil
- default:
- return queries.Client{}, fmt.Errorf("ambiguous client: %s", clientRef)
- }
-}