summaryrefslogtreecommitdiff
path: root/internal/commands/add_project.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/commands/add_project.go')
-rw-r--r--internal/commands/add_project.go106
1 files changed, 106 insertions, 0 deletions
diff --git a/internal/commands/add_project.go b/internal/commands/add_project.go
new file mode 100644
index 0000000..6c37e2a
--- /dev/null
+++ b/internal/commands/add_project.go
@@ -0,0 +1,106 @@
+package commands
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strconv"
+
+ punchctx "punchcard/internal/context"
+ "punchcard/internal/queries"
+
+ "github.com/spf13/cobra"
+)
+
+func NewAddProjectCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "project <name>",
+ Short: "Add a new project",
+ Long: `Add a new project to the database. Client can be specified by ID or name using the -c/--client flag.
+
+Examples:
+ punch add project "Website Redesign" -c "Acme Corp"
+ punch add project "Mobile App" --client 1`,
+ Args: cobra.ExactArgs(1),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ projectName := args[0]
+
+ clientRef, err := cmd.Flags().GetString("client")
+ if err != nil {
+ return fmt.Errorf("failed to get client flag: %w", err)
+ }
+ if clientRef == "" {
+ return fmt.Errorf("client is required, use -c/--client flag")
+ }
+
+ billableRateFloat, _ := cmd.Flags().GetFloat64("hourly-rate")
+ billableRate := int64(billableRateFloat * 100) // Convert dollars to cents
+
+ q := punchctx.GetDB(cmd.Context())
+ if q == nil {
+ return fmt.Errorf("database not available in context")
+ }
+
+ // Find client by ID or name
+ client, err := findClient(cmd.Context(), q, clientRef)
+ if err != nil {
+ return fmt.Errorf("failed to find client: %w", err)
+ }
+
+ // Create project
+ var billableRateParam sql.NullInt64
+ if billableRate > 0 {
+ billableRateParam = sql.NullInt64{Int64: billableRate, Valid: true}
+ }
+
+ project, err := q.CreateProject(cmd.Context(), queries.CreateProjectParams{
+ Name: projectName,
+ ClientID: client.ID,
+ BillableRate: billableRateParam,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create project: %w", err)
+ }
+
+ output := fmt.Sprintf("Created project: %s for client %s (ID: %d)", project.Name, client.Name, project.ID)
+ cmd.Print(output + "\n")
+
+ return nil
+ },
+ }
+
+ cmd.Flags().StringP("client", "c", "", "Client name or ID (required)")
+ cmd.Flags().Float64P("hourly-rate", "r", 0, "Default hourly billable rate for this project")
+ if err := cmd.MarkFlagRequired("client"); err != nil {
+ panic(fmt.Sprintf("Failed to mark client flag as required: %v", err))
+ }
+
+ return cmd
+}
+
+func findClient(ctx context.Context, q *queries.Queries, clientRef string) (queries.Client, error) {
+ // Parse clientRef as ID if possible, otherwise use 0
+ var idParam int64
+ if id, err := strconv.ParseInt(clientRef, 10, 64); err == nil {
+ idParam = id
+ }
+
+ // Search by both ID and name using UNION ALL
+ clients, err := q.FindClient(ctx, queries.FindClientParams{
+ ID: idParam,
+ Name: clientRef,
+ })
+ if err != nil {
+ return queries.Client{}, fmt.Errorf("database error looking up client: %w", err)
+ }
+
+ // Check results
+ switch len(clients) {
+ case 0:
+ return queries.Client{}, fmt.Errorf("client not found: %s", clientRef)
+ case 1:
+ return clients[0], nil
+ default:
+ return queries.Client{}, fmt.Errorf("ambiguous client: %s", clientRef)
+ }
+}