Support for group chats

- Save the names of group chats so you can query them
- Ability to send messages to group chats by using the jid
This commit is contained in:
Luke Harries
2025-03-30 16:49:42 +01:00
committed by GitHub
3 changed files with 209 additions and 100 deletions

View File

@@ -8,19 +8,21 @@ import (
"net/http"
"os"
"os/signal"
"reflect"
"strings"
"syscall"
"time"
"github.com/mdp/qrterminal"
_ "github.com/mattn/go-sqlite3"
"github.com/mdp/qrterminal"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/store/sqlstore"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow/types/events"
waLog "go.mau.fi/whatsmeow/util/log"
"google.golang.org/protobuf/proto"
waProto "go.mau.fi/whatsmeow/binary/proto"
)
// Message represents a chat message for our client
@@ -42,13 +44,13 @@ func NewMessageStore() (*MessageStore, error) {
if err := os.MkdirAll("store", 0755); err != nil {
return nil, fmt.Errorf("failed to create store directory: %v", err)
}
// Open SQLite database for messages
db, err := sql.Open("sqlite3", "file:store/messages.db?_foreign_keys=on")
if err != nil {
return nil, fmt.Errorf("failed to open message database: %v", err)
}
// Create tables if they don't exist
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS chats (
@@ -72,7 +74,7 @@ func NewMessageStore() (*MessageStore, error) {
db.Close()
return nil, fmt.Errorf("failed to create tables: %v", err)
}
return &MessageStore{db: db}, nil
}
@@ -96,7 +98,7 @@ func (store *MessageStore) StoreMessage(id, chatJID, sender, content string, tim
if content == "" {
return nil
}
_, err := store.db.Exec(
"INSERT OR REPLACE INTO messages (id, chat_jid, sender, content, timestamp, is_from_me) VALUES (?, ?, ?, ?, ?, ?)",
id, chatJID, sender, content, timestamp, isFromMe,
@@ -114,7 +116,7 @@ func (store *MessageStore) GetMessages(chatJID string, limit int) ([]Message, er
return nil, err
}
defer rows.Close()
var messages []Message
for rows.Next() {
var msg Message
@@ -126,7 +128,7 @@ func (store *MessageStore) GetMessages(chatJID string, limit int) ([]Message, er
msg.Time = timestamp
messages = append(messages, msg)
}
return messages, nil
}
@@ -137,7 +139,7 @@ func (store *MessageStore) GetChats() (map[string]time.Time, error) {
return nil, err
}
defer rows.Close()
chats := make(map[string]time.Time)
for rows.Next() {
var jid string
@@ -148,7 +150,7 @@ func (store *MessageStore) GetChats() (map[string]time.Time, error) {
}
chats[jid] = lastMessageTime
}
return chats, nil
}
@@ -157,14 +159,14 @@ func extractTextContent(msg *waProto.Message) string {
if msg == nil {
return ""
}
// Try to get text content
if text := msg.GetConversation(); text != "" {
return text
} else if extendedText := msg.GetExtendedTextMessage(); extendedText != nil {
return extendedText.GetText()
}
// For now, we're ignoring non-text messages
return ""
}
@@ -177,32 +179,47 @@ type SendMessageResponse struct {
// SendMessageRequest represents the request body for the send message API
type SendMessageRequest struct {
Phone string `json:"phone"`
Message string `json:"message"`
Recipient string `json:"recipient"`
Message string `json:"message"`
}
// Function to send a WhatsApp message
func sendWhatsAppMessage(client *whatsmeow.Client, phone, message string) (bool, string) {
func sendWhatsAppMessage(client *whatsmeow.Client, recipient string, message string) (bool, string) {
if !client.IsConnected() {
return false, "Not connected to WhatsApp"
}
// Create JID for recipient
recipientJID := types.JID{
User: phone,
Server: "s.whatsapp.net", // For personal chats
var recipientJID types.JID
var err error
// Check if recipient is a JID
isJID := strings.Contains(recipient, "@")
if isJID {
// Parse the JID string
recipientJID, err = types.ParseJID(recipient)
if err != nil {
return false, fmt.Sprintf("Error parsing JID: %v", err)
}
} else {
// Create JID from phone number
recipientJID = types.JID{
User: recipient,
Server: "s.whatsapp.net", // For personal chats
}
}
// Send the message
_, err := client.SendMessage(context.Background(), recipientJID, &waProto.Message{
_, err = client.SendMessage(context.Background(), recipientJID, &waProto.Message{
Conversation: proto.String(message),
})
if err != nil {
return false, fmt.Sprintf("Error sending message: %v", err)
}
return true, fmt.Sprintf("Message sent to %s", phone)
return true, fmt.Sprintf("Message sent to %s", recipient)
}
// Start a REST API server to expose the WhatsApp client functionality
@@ -215,42 +232,42 @@ func startRESTServer(client *whatsmeow.Client, port int) {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse the request body
var req SendMessageRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request format", http.StatusBadRequest)
return
}
// Validate request
if req.Phone == "" || req.Message == "" {
http.Error(w, "Phone and message are required", http.StatusBadRequest)
if req.Recipient == "" || req.Message == "" {
http.Error(w, "Recipient and message are required", http.StatusBadRequest)
return
}
// Send the message
success, message := sendWhatsAppMessage(client, req.Phone, req.Message)
success, message := sendWhatsAppMessage(client, req.Recipient, req.Message)
fmt.Println("Message sent", success, message)
// Set response headers
w.Header().Set("Content-Type", "application/json")
// Set appropriate status code
if !success {
w.WriteHeader(http.StatusInternalServerError)
}
// Send response
json.NewEncoder(w).Encode(SendMessageResponse{
Success: success,
Message: message,
})
})
// Start the server
serverAddr := fmt.Sprintf(":%d", port)
fmt.Printf("Starting REST API server on %s...\n", serverAddr)
// Run server in a goroutine so it doesn't block
go func() {
if err := http.ListenAndServe(serverAddr, nil); err != nil {
@@ -266,13 +283,13 @@ func main() {
// Create database connection for storing session data
dbLog := waLog.Stdout("Database", "INFO", true)
// Create directory for database if it doesn't exist
if err := os.MkdirAll("store", 0755); err != nil {
logger.Errorf("Failed to create store directory: %v", err)
return
}
container, err := sqlstore.New("sqlite3", "file:store/whatsapp.db?_foreign_keys=on", dbLog)
if err != nil {
logger.Errorf("Failed to connect to database: %v", err)
@@ -298,7 +315,7 @@ func main() {
logger.Errorf("Failed to create WhatsApp client")
return
}
// Initialize message store
messageStore, err := NewMessageStore()
if err != nil {
@@ -306,29 +323,29 @@ func main() {
return
}
defer messageStore.Close()
// Setup event handling for messages and history sync
client.AddEventHandler(func(evt interface{}) {
switch v := evt.(type) {
case *events.Message:
// Process regular messages
handleMessage(client, messageStore, v, logger)
case *events.HistorySync:
// Process history sync events
handleHistorySync(client, messageStore, v, logger)
case *events.Connected:
logger.Infof("Connected to WhatsApp")
case *events.LoggedOut:
logger.Warnf("Device logged out, please scan QR code to log in again")
}
})
// Create channel to track connection success
connected := make(chan bool, 1)
// Connect to WhatsApp
if client.Store.ID == nil {
// No ID stored, this is a new client, need to pair with phone
@@ -349,7 +366,7 @@ func main() {
break
}
}
// Wait for connection
select {
case <-connected:
@@ -370,31 +387,114 @@ func main() {
// Wait a moment for connection to stabilize
time.Sleep(2 * time.Second)
if !client.IsConnected() {
logger.Errorf("Failed to establish stable connection")
return
}
fmt.Println("\n✓ Connected to WhatsApp! Type 'help' for commands.")
// Start REST API server
startRESTServer(client, 8080)
// Create a channel to keep the main goroutine alive
exitChan := make(chan os.Signal, 1)
signal.Notify(exitChan, syscall.SIGINT, syscall.SIGTERM)
fmt.Println("REST server is running. Press Ctrl+C to disconnect and exit.")
// Wait for termination signal
<-exitChan
fmt.Println("Disconnecting...")
// Disconnect client
client.Disconnect()
}
// GetChatName determines the appropriate name for a chat based on JID and other info
func GetChatName(client *whatsmeow.Client, messageStore *MessageStore, jid types.JID, chatJID string, conversation interface{}, sender string, logger waLog.Logger) string {
// First, check if chat already exists in database with a name
var existingName string
err := messageStore.db.QueryRow("SELECT name FROM chats WHERE jid = ?", chatJID).Scan(&existingName)
if err == nil && existingName != "" {
// Chat exists with a name, use that
logger.Infof("Using existing chat name for %s: %s", chatJID, existingName)
return existingName
}
// Need to determine chat name
var name string
if jid.Server == "g.us" {
// This is a group chat
logger.Infof("Getting name for group: %s", chatJID)
// Use conversation data if provided (from history sync)
if conversation != nil {
// Extract name from conversation if available
// This uses type assertions to handle different possible types
var displayName, convName *string
// Try to extract the fields we care about regardless of the exact type
v := reflect.ValueOf(conversation)
if v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
// Try to find DisplayName field
if displayNameField := v.FieldByName("DisplayName"); displayNameField.IsValid() && displayNameField.Kind() == reflect.Ptr && !displayNameField.IsNil() {
dn := displayNameField.Elem().String()
displayName = &dn
}
// Try to find Name field
if nameField := v.FieldByName("Name"); nameField.IsValid() && nameField.Kind() == reflect.Ptr && !nameField.IsNil() {
n := nameField.Elem().String()
convName = &n
}
}
// Use the name we found
if displayName != nil && *displayName != "" {
name = *displayName
} else if convName != nil && *convName != "" {
name = *convName
}
}
// If we didn't get a name, try group info
if name == "" {
groupInfo, err := client.GetGroupInfo(jid)
if err == nil && groupInfo.Name != "" {
name = groupInfo.Name
} else {
// Fallback name for groups
name = fmt.Sprintf("Group %s", jid.User)
}
}
logger.Infof("Using group name: %s", name)
} else {
// This is an individual contact
logger.Infof("Getting name for contact: %s", chatJID)
// Just use contact info (full name)
contact, err := client.Store.Contacts.GetContact(jid)
if err == nil && contact.FullName != "" {
name = contact.FullName
} else if sender != "" {
// Fallback to sender
name = sender
} else {
// Last fallback to JID
name = jid.User
}
logger.Infof("Using contact name: %s", name)
}
return name
}
// Handle regular incoming messages
func handleMessage(client *whatsmeow.Client, messageStore *MessageStore, msg *events.Message, logger waLog.Logger) {
// Extract text content
@@ -402,29 +502,25 @@ func handleMessage(client *whatsmeow.Client, messageStore *MessageStore, msg *ev
if content == "" {
return // Skip non-text messages
}
// Save message to database
chatJID := msg.Info.Chat.String()
sender := msg.Info.Sender.User
// Get contact name if possible
name := sender
contact, err := client.Store.Contacts.GetContact(msg.Info.Sender)
if err == nil && contact.FullName != "" {
name = contact.FullName
}
// Update chat in database
err = messageStore.StoreChat(chatJID, name, msg.Info.Timestamp)
// Get appropriate chat name (pass nil for conversation since we don't have one for regular messages)
name := GetChatName(client, messageStore, msg.Info.Chat, chatJID, nil, sender, logger)
// Update chat in database with the message timestamp (keeps last message time updated)
err := messageStore.StoreChat(chatJID, name, msg.Info.Timestamp)
if err != nil {
logger.Warnf("Failed to store chat: %v", err)
}
// Store message in database
err = messageStore.StoreMessage(
msg.Info.ID,
chatJID,
sender,
msg.Info.ID,
chatJID,
sender,
content,
msg.Info.Timestamp,
msg.Info.IsFromMe,
@@ -445,30 +541,26 @@ func handleMessage(client *whatsmeow.Client, messageStore *MessageStore, msg *ev
// Handle history sync events
func handleHistorySync(client *whatsmeow.Client, messageStore *MessageStore, historySync *events.HistorySync, logger waLog.Logger) {
fmt.Printf("Received history sync event with %d conversations\n", len(historySync.Data.Conversations))
syncedCount := 0
for _, conversation := range historySync.Data.Conversations {
// Parse JID from the conversation
if conversation.ID == nil {
continue
}
chatJID := *conversation.ID
// Try to parse the JID
jid, err := types.ParseJID(chatJID)
if err != nil {
logger.Warnf("Failed to parse JID %s: %v", chatJID, err)
continue
}
// Get contact name
name := jid.User
contact, err := client.Store.Contacts.GetContact(jid)
if err == nil && contact.FullName != "" {
name = contact.FullName
}
// Get appropriate chat name by passing the history sync conversation directly
name := GetChatName(client, messageStore, jid, chatJID, conversation, "", logger)
// Process messages
messages := conversation.Messages
if len(messages) > 0 {
@@ -477,7 +569,7 @@ func handleHistorySync(client *whatsmeow.Client, messageStore *MessageStore, his
if latestMsg == nil || latestMsg.Message == nil {
continue
}
// Get timestamp from message info
timestamp := time.Time{}
if ts := latestMsg.Message.GetMessageTimestamp(); ts != 0 {
@@ -485,15 +577,15 @@ func handleHistorySync(client *whatsmeow.Client, messageStore *MessageStore, his
} else {
continue
}
messageStore.StoreChat(chatJID, name, timestamp)
// Store messages
for _, msg := range messages {
if msg == nil || msg.Message == nil {
continue
}
// Extract text content
var content string
if msg.Message.Message != nil {
@@ -503,15 +595,15 @@ func handleHistorySync(client *whatsmeow.Client, messageStore *MessageStore, his
content = ext.GetText()
}
}
// Log the message content for debugging
logger.Infof("Message content: %v", content)
// Skip non-text messages
if content == "" {
continue
}
// Determine sender
var sender string
isFromMe := false
@@ -529,13 +621,13 @@ func handleHistorySync(client *whatsmeow.Client, messageStore *MessageStore, his
} else {
sender = jid.User
}
// Store message
msgID := ""
if msg.Message.Key != nil && msg.Message.Key.ID != nil {
msgID = *msg.Message.Key.ID
}
// Get message timestamp
timestamp := time.Time{}
if ts := msg.Message.GetMessageTimestamp(); ts != 0 {
@@ -543,7 +635,7 @@ func handleHistorySync(client *whatsmeow.Client, messageStore *MessageStore, his
} else {
continue
}
err = messageStore.StoreMessage(
msgID,
chatJID,
@@ -562,7 +654,7 @@ func handleHistorySync(client *whatsmeow.Client, messageStore *MessageStore, his
}
}
}
fmt.Printf("History sync complete. Stored %d text messages.\n", syncedCount)
}
@@ -594,7 +686,7 @@ func requestHistorySync(client *whatsmeow.Client) {
Server: "s.whatsapp.net",
User: "status",
}, historyMsg)
if err != nil {
fmt.Printf("Failed to request history sync: %v\n", err)
} else {

View File

@@ -150,17 +150,29 @@ def get_message_context(
return context
@mcp.tool()
def send_message(phone_number: str, message: str) -> Dict[str, Any]:
"""Send a WhatsApp message to the specified phone number.
def send_message(
recipient: str,
message: str
) -> Dict[str, Any]:
"""Send a WhatsApp message to a person or group. For group chats use the JID.
Args:
phone_number: The recipient's phone number, with country code but no + or other symbols
recipient: The recipient - either a phone number with country code but no + or other symbols,
or a JID (e.g., "123456789@s.whatsapp.net" or a group JID like "123456789@g.us")
message: The message text to send
Returns:
A dictionary containing success status and a status message
"""
success, status_message = whatsapp_send_message(phone_number, message)
# Validate input
if not recipient:
return {
"success": False,
"message": "Recipient must be provided"
}
# Call the whatsapp_send_message function with the unified recipient parameter
success, status_message = whatsapp_send_message(recipient, message)
return {
"success": success,
"message": status_message

View File

@@ -662,20 +662,25 @@ def get_direct_chat_by_contact(sender_phone_number: str) -> Optional[Chat]:
if 'conn' in locals():
conn.close()
def send_message(phone_number: str, message: str) -> Tuple[bool, str]:
"""Send a WhatsApp message to the specified phone number.
def send_message(recipient: str, message: str) -> Tuple[bool, str]:
"""Send a WhatsApp message to the specified recipient. For group messages use the JID.
Args:
phone_number (str): The recipient's phone number, with country code but no + or other symbols
message (str): The message text to send
recipient: The recipient - either a phone number with country code but no + or other symbols,
or a JID (e.g., "123456789@s.whatsapp.net" or a group JID like "123456789@g.us").
message: The message text to send
Returns:
Tuple[bool, str]: A tuple containing success status and a status message
"""
try:
# Validate input
if not recipient:
return False, "Recipient must be provided"
url = f"{WHATSAPP_API_BASE_URL}/send"
payload = {
"phone": phone_number,
"recipient": recipient,
"message": message
}