51 lines
1.2 KiB
Go
51 lines
1.2 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
type OpenAIProvider struct {
|
|
client *openai.Client
|
|
name string
|
|
}
|
|
|
|
// NewOpenAIProvider creates a new provider that uses the official or compatible OpenAI API.
|
|
// It can also handle DeepSeek via a custom BaseURL.
|
|
func NewOpenAIProvider(apiKey, baseURL, name string) *OpenAIProvider {
|
|
config := openai.DefaultConfig(apiKey)
|
|
if baseURL != "" {
|
|
config.BaseURL = baseURL
|
|
}
|
|
return &OpenAIProvider{
|
|
client: openai.NewClientWithConfig(config),
|
|
name: name,
|
|
}
|
|
}
|
|
|
|
func (p *OpenAIProvider) Name() string {
|
|
return p.name
|
|
}
|
|
|
|
func (p *OpenAIProvider) GenerateReply(ctx context.Context, model string, systemPrompt, userPrompt string) (string, error) {
|
|
resp, err := p.client.CreateChatCompletion(
|
|
ctx,
|
|
openai.ChatCompletionRequest{
|
|
Model: model,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
|
|
{Role: openai.ChatMessageRoleUser, Content: userPrompt},
|
|
},
|
|
},
|
|
)
|
|
if err != nil {
|
|
return "", fmt.Errorf("%s api error: %w", p.name, err)
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return "", fmt.Errorf("%s returned no choices", p.name)
|
|
}
|
|
return resp.Choices[0].Message.Content, nil
|
|
}
|