metahunt/Assets/barScene/AI Toolbox/Runtime/ChatGpt.cs
2024-02-11 13:53:18 +05:30

454 lines
18 KiB
C#

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using UnityEngine.Networking;
// ReSharper disable InconsistentNaming
namespace AiToolbox {
/// <summary>
/// Author of a chat message. <see cref="User"/> is sent to the API as "user"; every other value is sent as
/// "assistant".
/// </summary>
public enum Role {
    /// <summary>A message written by the player/user.</summary>
    User = 0,
    /// <summary>A message produced by the model.</summary>
    AI = 1,
}
/// <summary>
/// One entry of a conversation: the text of the message and who wrote it.
/// </summary>
public struct Message {
    /// <summary>Body of the message.</summary>
    public string text;
    /// <summary>Who authored the message.</summary>
    public Role role;

    /// <summary>Creates a message with the given body and author.</summary>
    /// <param name="text">Body of the message.</param>
    /// <param name="role">Who authored the message.</param>
    public Message(string text, Role role) {
        (this.text, this.role) = (text, role);
    }
}
/// <summary>
/// Library-specific error codes passed (cast to <c>long</c>) as the first argument of failure callbacks.
/// The same argument also carries the HTTP response code when the web request itself fails, so callers
/// cannot always tell the two apart by value alone.
/// </summary>
public enum ChatGptErrorCodes {
    /// <summary>A streamed response stopped with finish_reason "length".</summary>
    MaxTokensExceeded = 0,
    /// <summary>The number of in-flight requests reached <c>Parameters.throttle</c>.</summary>
    ThrottleExceeded = 1,
    /// <summary>The remote-config key lookup task faulted.</summary>
    RemoteConfigConnectionFailure = 2,
    /// <summary>Not raised in this file; presumably reported by the remote key service.</summary>
    RemoteConfigKeyNotFound = 3,
    /// <summary>Any other failure (parse errors, empty choices, key decryption errors, ...).</summary>
    Unknown = 4,
}
/// <summary>
/// Bookkeeping entry for one in-flight request. It only stores the action that cancels that request.
/// </summary>
internal sealed class RequestRecord {
    private Action _onCancel;

    /// <summary>Stores the action run by <see cref="Cancel"/>, replacing any previously stored one.</summary>
    public void SetCancelCallback(Action callback) => _onCancel = callback;

    /// <summary>Runs the stored cancel action; does nothing if none has been set.</summary>
    public void Cancel() => _onCancel?.Invoke();
}
public static class ChatGpt {
/// <summary>
/// Records of the requests currently in flight. Its count is what <c>Parameters.throttle</c> is checked
/// against, and <see cref="CancelAllRequests"/> cancels every entry in it.
/// </summary>
private static readonly List<RequestRecord> _requestRecords = new List<RequestRecord>();
/// <summary>
/// Sends a single-prompt request to ChatGPT. This is shorthand for a one-message conversation whose
/// only message is authored by <see cref="Role.User"/>.
/// </summary>
/// <param name="prompt">The text of the request, e.g. "Generate a character description".</param>
/// <param name="parameters">Settings of the request.</param>
/// <param name="completeCallback">Invoked on success with the full ChatGPT response.</param>
/// <param name="failureCallback">Invoked on failure with an error code and a message.</param>
/// <param name="updateCallback">Optional. When provided the response is streamed, and this is invoked with each
/// new chunk of text.</param>
/// <returns>An action that cancels the request when invoked.</returns>
// ReSharper disable once MemberCanBePrivate.Global
public static Action Request(string prompt, Parameters parameters, Action<string> completeCallback,
    Action<long, string> failureCallback, Action<string> updateCallback = null) {
    var conversation = new List<Message> { new Message(prompt, Role.User) };
    return Request(conversation, parameters, completeCallback, failureCallback, updateCallback);
}
/// <summary>
/// Send a request to ChatGPT.
/// </summary>
/// <param name="messages">Sequence of messages to send to ChatGPT. The order of messages should be the same as the
/// chronological order of messages in the conversation, i.e. the first message should be the oldest one. The roles
/// of the messages should switch between User and AI.</param>
/// <param name="parameters">Settings of the request.</param>
/// <param name="completeCallback">The function to be called on successful completion. ChatGPT response is provided
/// as a parameter.</param>
/// <param name="failureCallback">The function to be called on failure. Error code and message are provided as
/// parameters.</param>
/// <param name="updateCallback">The function to be called when a new response chunk is generated. ChatGPT response
/// data is provided as a parameter. When null, a single non-streaming request is made instead.</param>
/// <returns>A function that can be called to cancel the request. When the throttle limit is hit, the failure is
/// reported immediately and a no-op function is returned.</returns>
// ReSharper disable once MemberCanBePrivate.Global
public static Action Request(IEnumerable<Message> messages, Parameters parameters, Action<string> completeCallback,
    Action<long, string> failureCallback, Action<string> updateCallback = null) {
    Debug.Assert(parameters != null, "Parameters cannot be null.");
    // With RemoteConfig the key is fetched later (GetRemoteConfig writes parameters.apiKey), so an empty key is
    // expected at this point and must not trip the assertion.
    Debug.Assert(parameters!.apiKeyEncryption == ApiKeyEncryption.RemoteConfig ||
                 !string.IsNullOrEmpty(parameters.apiKey), "API key cannot be null or empty.");
    Debug.Assert(messages != null, "Messages cannot be null.");

    if (updateCallback == null) {
        return QuickRequest(messages, parameters, completeCallback, failureCallback);
    }

    // Throttle.
    if (parameters.throttle > 0 && _requestRecords.Count >= parameters.throttle) {
        failureCallback?.Invoke((long)ChatGptErrorCodes.ThrottleExceeded,
            $"Too many requests. Maximum allowed: {parameters.throttle}.");
        return () => { };
    }

    var requestRecord = new RequestRecord();
    var enumerator = Stream(messages, parameters, updateCallback, completeCallback, failureCallback, requestRecord);
    // NOTE(review): StopCoroutine does not run the iterator's `using` cleanup, so the UnityWebRequest created
    // inside Stream is not disposed on cancel — confirm and consider a dedicated abort path.
    var cancelCallback = new Action(() => {
        ChatGptContainer.Instance.StopCoroutine(enumerator);
        _requestRecords.Remove(requestRecord);
    });
    requestRecord.SetCancelCallback(cancelCallback);
    // The record must be registered before the coroutine starts: Stream checks its presence to detect
    // failures reported by GetWebRequest.
    _requestRecords.Add(requestRecord);
    ChatGptContainer.Instance.StartCoroutine(enumerator);
    return cancelCallback;
}
// ReSharper disable once MemberCanBePrivate.Global
/// <summary>
/// Cancel all pending requests.
/// </summary>
/// <remarks>
/// Cancels a snapshot of the in-flight records. A cancel callback that throws is logged and does not stop
/// the remaining requests from being cancelled.
/// </remarks>
public static void CancelAllRequests() {
    // Iterate over a copy. Each cancel callback removes its own record from the live list, and a record
    // whose callback is unset or throws before removing itself would otherwise make a
    // "while (Count > 0)" loop spin forever.
    var pending = _requestRecords.ToArray();
    foreach (var record in pending) {
        try {
            record.Cancel();
        }
        catch (Exception e) {
            Debug.LogException(e);
        }
    }
    _requestRecords.Clear();
}
/// <summary>
/// Starts a non-streaming request. When the key comes from remote config, the key is fetched first in a
/// coroutine, and the web request is sent only if that fetch succeeded.
/// </summary>
/// <returns>An action that cancels the request. It stops the key fetch if that is still running, and aborts
/// the web request if it has already been sent.</returns>
private static Action QuickRequest(IEnumerable<Message> messages, Parameters parameters,
    Action<string> completeCallback, Action<long, string> failureCallback) {
    if (parameters.apiKeyEncryption != ApiKeyEncryption.RemoteConfig) {
        return QuickRequestBlocking(messages, parameters, completeCallback, failureCallback);
    }

    // Cancel handle of the actual web request; only available once the key has been fetched and sent.
    Action requestCancel = null;

    IEnumerator FetchKeyThenRequest() {
        yield return GetRemoteConfig(parameters, failureCallback);
        // GetRemoteConfig switches the mode away from RemoteConfig only on success. Otherwise the failure
        // has already been reported, and a request without a key would just fail a second time.
        if (parameters.apiKeyEncryption == ApiKeyEncryption.RemoteConfig) {
            yield break;
        }
        requestCancel = QuickRequestBlocking(messages, parameters, completeCallback, failureCallback);
    }

    var enumerator = FetchKeyThenRequest();
    ChatGptContainer.Instance.StartCoroutine(enumerator);

    void CancelCallback() {
        ChatGptContainer.Instance.StopCoroutine(enumerator);
        requestCancel?.Invoke();
    }

    return CancelCallback;
}
/// <summary>
/// Coroutine that resolves a remote-config API key if needed, then sends a non-streaming request.
/// If the key could not be fetched, no request is sent: the failure has already been reported.
/// </summary>
private static IEnumerator QuickRequestCoroutine(IEnumerable<Message> messages, Parameters parameters,
    Action<string> completeCallback,
    Action<long, string> failureCallback) {
    if (parameters.apiKeyEncryption == ApiKeyEncryption.RemoteConfig) {
        yield return GetRemoteConfig(parameters, failureCallback);
        // GetRemoteConfig leaves the mode at RemoteConfig only when the fetch failed (and was reported).
        if (parameters.apiKeyEncryption == ApiKeyEncryption.RemoteConfig) {
            yield break;
        }
    }
    QuickRequestBlocking(messages, parameters, completeCallback, failureCallback);
}
/// <summary>
/// Sends a non-streaming chat completion request. Despite the name this does not block: it starts the web
/// request and returns immediately. The outcome is delivered from the request's <c>completed</c> event.
/// </summary>
/// <param name="messages">Conversation to send, oldest first.</param>
/// <param name="parameters">Request settings. This method does not fetch remote-config keys; the key is used
/// as-is, or decrypted when it is locally encrypted.</param>
/// <param name="completeCallback">Invoked with the first choice's message content on success.</param>
/// <param name="failureCallback">Invoked with an HTTP status or <see cref="ChatGptErrorCodes"/> value on failure.</param>
/// <returns>An action that aborts the request. It is a no-op when the request was rejected up front.</returns>
private static Action QuickRequestBlocking(IEnumerable<Message> messages, Parameters parameters,
    Action<string> completeCallback, Action<long, string> failureCallback) {
    Debug.Assert(parameters != null, "Parameters cannot be null.");
    Debug.Assert(!string.IsNullOrEmpty(parameters!.apiKey), "API key cannot be null or empty.");
    Debug.Assert(messages != null, "Messages cannot be null.");

    // Throttle.
    if (parameters.throttle > 0 && _requestRecords.Count >= parameters.throttle) {
        failureCallback?.Invoke((long)ChatGptErrorCodes.ThrottleExceeded,
            $"Too many requests. Maximum allowed: {parameters.throttle}.");
        return () => { };
    }

    var requestObject = new RequestMessage {
        model = GetModelName(parameters.model),
        temperature = parameters.temperature,
        stream = false,
        messages = ConvertMessages(messages, parameters.role),
    };
    var requestJson = JsonUtility.ToJson(requestObject);

    // Register the record before building the request. If GetWebRequest fails to set the Authorization
    // header, it reports the error and removes the record, which lets us detect that case below.
    var requestRecord = new RequestRecord();
    _requestRecords.Add(requestRecord);
    UnityWebRequest request;
    try {
        request = GetWebRequest(requestJson, parameters, failureCallback, requestRecord);
    }
    catch {
        _requestRecords.Remove(requestRecord);
        throw;
    }
    if (!_requestRecords.Contains(requestRecord)) {
        // Failure already reported; sending without credentials would only produce a second error.
        request.Dispose();
        return () => { };
    }

    // Both cancellation and completion release the request; make that happen exactly once.
    var disposed = false;
    void DisposeRequest() {
        if (disposed) {
            return;
        }
        disposed = true;
        request.Dispose();
    }

    var cancelCallback = new Action(() => {
        // Release the record first so it is freed even if Abort throws.
        _requestRecords.Remove(requestRecord);
        if (disposed) {
            return;
        }
        try {
            request.Abort();
        }
        finally {
            DisposeRequest();
        }
    });
    requestRecord.SetCancelCallback(cancelCallback);

    request.SendWebRequest().completed += _ => {
        _requestRecords.Remove(requestRecord);
        Application.quitting -= cancelCallback;
        if (disposed) {
            // The request was cancelled; nothing to report.
            return;
        }
        try {
            HandleQuickResponse(request, completeCallback, failureCallback);
        }
        finally {
            // Release the native request on every path, not only on success.
            DisposeRequest();
        }
    };
    Application.quitting += cancelCallback;
    return cancelCallback;
}

/// <summary>
/// Turns a finished non-streaming response into at most one callback. It calls
/// <paramref name="failureCallback"/> for transport/HTTP errors, unparsable JSON or a response without
/// choices. Otherwise it calls <paramref name="completeCallback"/> with the first choice's content.
/// </summary>
private static void HandleQuickResponse(UnityWebRequest request, Action<string> completeCallback,
    Action<long, string> failureCallback) {
    var error = request.error;
    if (!string.IsNullOrEmpty(error)) {
        failureCallback?.Invoke(request.responseCode, error);
        return;
    }

    ResponseMessage response;
    try {
        response = JsonUtility.FromJson<ResponseMessage>(request.downloadHandler.text);
    }
    catch (Exception e) {
        failureCallback?.Invoke((long)ChatGptErrorCodes.Unknown, e.Message);
        return;
    }

    if (response.choices == null || response.choices.Length == 0) {
        failureCallback?.Invoke((long)ChatGptErrorCodes.Unknown,
            "No response choices returned from the server.");
        return;
    }

    completeCallback?.Invoke(response.choices[0].message.content);
}
/// <summary>
/// Coroutine behind streaming requests. It:
/// <list type="bullet">
/// <item>resolves a remote-config key if needed;</item>
/// <item>sends the request with <c>stream = true</c>;</item>
/// <item>parses server-sent <c>data:</c> lines as they arrive and forwards each content delta to
/// <paramref name="updateCallback"/>;</item>
/// <item>finally passes the concatenated text to <paramref name="completeCallback"/>.</item>
/// </list>
/// Every path that runs to an end removes <paramref name="requestRecord"/> from the in-flight list before
/// invoking a callback, so it no longer counts against the throttle.
/// </summary>
/// <remarks>
/// Expects <paramref name="requestRecord"/> to be registered in the in-flight list before the coroutine starts.
/// </remarks>
private static IEnumerator Stream(IEnumerable<Message> messages, Parameters parameters,
    Action<string> updateCallback, Action<string> completeCallback,
    Action<long, string> failureCallback, RequestRecord requestRecord) {
    const string dataPrefix = "data:";

    var requestObject = new RequestMessage {
        model = GetModelName(parameters.model),
        temperature = parameters.temperature,
        stream = true,
        messages = ConvertMessages(messages, parameters.role),
    };

    if (parameters.apiKeyEncryption == ApiKeyEncryption.RemoteConfig) {
        yield return GetRemoteConfig(parameters, failureCallback);
        // GetRemoteConfig switches the mode away from RemoteConfig only on success. Otherwise the failure
        // has already been reported, and a request without a key would just fail a second time.
        if (parameters.apiKeyEncryption == ApiKeyEncryption.RemoteConfig) {
            _requestRecords.Remove(requestRecord);
            yield break;
        }
    }

    var requestJson = JsonUtility.ToJson(requestObject);
    using var request = GetWebRequest(requestJson, parameters, failureCallback, requestRecord);
    // GetWebRequest removes the record after reporting a failure to set the Authorization header.
    if (!_requestRecords.Contains(requestRecord)) {
        yield break;
    }

    var webRequest = request.SendWebRequest();
    var consumedLength = 0;
    var completeText = "";
    var sawDoneMarker = false;

    while (true) {
        // Sample completion *before* reading, so that data arriving in the final frame is still parsed.
        var isDone = webRequest.isDone;

        if (!string.IsNullOrEmpty(request.error)) {
            _requestRecords.Remove(requestRecord);
            failureCallback?.Invoke(request.responseCode, request.error);
            yield break;
        }

        if (!sawDoneMarker) {
            var text = request.downloadHandler.text;
            // While streaming, consume only complete lines. A partially received line (possibly ending in a
            // partially received multi-byte character) waits for the next frame. Once done, take everything.
            var end = isDone ? text.Length : text.LastIndexOf('\n') + 1;
            if (end > consumedLength) {
                var lines = text.Substring(consumedLength, end - consumedLength).Split('\n');
                consumedLength = end;

                foreach (var rawLine in lines) {
                    var line = rawLine.Trim();
                    if (!line.StartsWith(dataPrefix, StringComparison.Ordinal)) {
                        continue; // Blank separators, comments and other SSE fields carry no payload.
                    }
                    var payload = line.Substring(dataPrefix.Length).Trim();
                    if (payload.Length == 0) {
                        continue;
                    }
                    if (payload == "[DONE]") {
                        sawDoneMarker = true;
                        break;
                    }

                    // Collect the failure here and report it outside the try, so an exception thrown by
                    // failureCallback itself is not turned into a second failure report.
                    long? failureCode = null;
                    string failureMessage = null;
                    try {
                        var data = JsonUtility.FromJson<ResponseMessage>(payload);
                        if (data.choices == null || data.choices.Length == 0) {
                            failureCode = (long)ChatGptErrorCodes.Unknown;
                            failureMessage = "No response choices returned from the server.";
                        }
                        else if (data.choices[0].finish_reason == "length") {
                            failureCode = (long)ChatGptErrorCodes.MaxTokensExceeded;
                            failureMessage = completeText;
                        }
                        else {
                            var delta = data.choices[0].delta.content;
                            completeText += delta;
                            updateCallback?.Invoke(delta);
                        }
                    }
                    catch (Exception e) {
                        failureCode = (long)ChatGptErrorCodes.Unknown;
                        failureMessage = e.Message;
                    }

                    if (failureCode.HasValue) {
                        _requestRecords.Remove(requestRecord);
                        failureCallback?.Invoke(failureCode.Value, failureMessage);
                        yield break;
                    }
                }
            }
        }

        if (isDone) {
            break;
        }
        yield return null;
    }

    // Always release the record and report completion, even when no content was streamed; otherwise the
    // caller is never notified and the record keeps counting toward the throttle.
    _requestRecords.Remove(requestRecord);
    completeCallback?.Invoke(completeText);
}
/// <summary>
/// Coroutine that fetches the OpenAI API key through <c>RemoteKeyService.GetOpenAiKey</c>.
/// On success, the key is written to <paramref name="parameters"/> and its encryption mode is switched to
/// None, so later requests use the key directly. On failure the mode stays RemoteConfig and
/// <paramref name="failureCallback"/> is invoked once.
/// </summary>
private static IEnumerator GetRemoteConfig(Parameters parameters, Action<long, string> failureCallback) {
    var callbackReceived = false;
    var task = RemoteKeyService.GetOpenAiKey(parameters.apiKeyRemoteConfigKey, key => {
        parameters.apiKeyEncryption = ApiKeyEncryption.None;
        parameters.apiKey = key;
        callbackReceived = true;
    }, (errorCode, error) => {
        failureCallback?.Invoke(errorCode, error);
        callbackReceived = true;
    });

    // A faulted or cancelled task may never invoke either callback. Waiting for a callback unconditionally
    // would hang this coroutine (and the request waiting on it) forever.
    yield return new WaitUntil(() => task.IsCompleted && (callbackReceived || task.IsFaulted || task.IsCanceled));

    // Report a connection failure only when no callback has already reported the outcome.
    if (!callbackReceived && (task.IsFaulted || task.IsCanceled)) {
        failureCallback?.Invoke((long)ChatGptErrorCodes.RemoteConfigConnectionFailure,
            "Failed to retrieve API key from remote config.");
    }
}
/// <summary>
/// Builds (but does not send) a JSON POST to the OpenAI chat completions endpoint. It applies the configured
/// timeout and a Bearer Authorization header. A locally encrypted key is decrypted with
/// <c>Key.B</c> first.
/// </summary>
/// <remarks>
/// If resolving the key or setting the header throws, the error is reported through
/// <paramref name="failureCallback"/> and <paramref name="requestRecord"/> is removed from the in-flight
/// list. The request is still returned, without an Authorization header.
/// </remarks>
private static UnityWebRequest GetWebRequest(string requestJson, Parameters parameters,
    Action<long, string> failureCallback, RequestRecord requestRecord) {
    const string endpoint = "https://api.openai.com/v1/chat/completions";
    UnityWebRequest webRequest;
#if UNITY_2022_2_OR_NEWER
    webRequest = UnityWebRequest.Post(endpoint, requestJson, "application/json");
#else
    webRequest = new UnityWebRequest(endpoint, "POST") {
        uploadHandler = new UploadHandlerRaw(System.Text.Encoding.UTF8.GetBytes(requestJson)),
        downloadHandler = new DownloadHandlerBuffer(),
    };
    webRequest.SetRequestHeader("Content-Type", "application/json");
#endif
    webRequest.timeout = parameters.timeout;

    try {
        var bearerKey = parameters.apiKeyEncryption == ApiKeyEncryption.LocallyEncrypted
            ? Key.B(parameters.apiKey, parameters.apiKeyEncryptionPassword)
            : parameters.apiKey;
        webRequest.SetRequestHeader("Authorization", $"Bearer {bearerKey}");
    }
    catch (Exception e) {
        failureCallback?.Invoke((long)ChatGptErrorCodes.Unknown, e.Message);
        _requestRecords.Remove(requestRecord);
    }

    return webRequest;
}
/// <summary>
/// Maps a <c>Model</c> value to the model identifier string expected by the API.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">The model has no known identifier.</exception>
private static string GetModelName(Model model) {
    switch (model) {
        case Model.Gpt35Turbo:
            return "gpt-3.5-turbo";
        case Model.Gpt4:
            return "gpt-4";
        default:
            throw new ArgumentOutOfRangeException(nameof(model), model, null);
    }
}
/// <summary>
/// Converts the conversation to the API's role/content form.
/// A non-empty <paramref name="role"/> is prepended as a "system" message.
/// <see cref="Role.User"/> becomes "user"; any other role becomes "assistant".
/// </summary>
/// <param name="messages">Conversation, oldest first. Enumerated exactly once.</param>
/// <param name="role">Optional system prompt; null or empty means no system message.</param>
private static RoleContentMessage[] ConvertMessages(IEnumerable<Message> messages, string role) {
    var converted = messages.Select(message => new RoleContentMessage {
        role = message.role == Role.User ? "user" : "assistant",
        content = message.text,
    });

    if (!string.IsNullOrEmpty(role)) {
        var systemMessage = new RoleContentMessage { role = "system", content = role };
        converted = new[] { systemMessage }.Concat(converted);
    }

    return converted.ToArray();
}
/// <summary>
/// Hidden MonoBehaviour, kept across scene loads, that hosts the coroutines started by this class.
/// It cancels every in-flight request when the application quits.
/// </summary>
private class ChatGptContainer : MonoBehaviour {
    private static ChatGptContainer _instance;

    /// <summary>
    /// The shared host, created lazily on first access. The <c>!= null</c> check is Unity's overloaded
    /// equality, so a destroyed host counts as missing and is recreated.
    /// </summary>
    internal static ChatGptContainer Instance {
        get {
            if (_instance != null) {
                return _instance;
            }
            var host = new GameObject("ChatGptContainer");
            DontDestroyOnLoad(host);
            host.hideFlags = HideFlags.HideInHierarchy;
            _instance = host.AddComponent<ChatGptContainer>();
            return _instance;
        }
    }

    private void OnApplicationQuit() => CancelAllRequests();
}
// CS0649 ("field is never assigned") is silenced for the JSON DTOs below. Their response fields are only
// populated by JsonUtility.FromJson, which the compiler cannot see.
#pragma warning disable 0649
// ReSharper disable NotAccessedField.Local
/// <summary>
/// Body of a chat completions request, serialized with <c>JsonUtility.ToJson</c>. The field names are used
/// as the JSON keys, so they must not be renamed.
/// </summary>
[Serializable]
private struct RequestMessage {
// Model identifier string, as produced by GetModelName.
public string model;
// Conversation, including an optional leading "system" message (built by ConvertMessages).
public RoleContentMessage[] messages;
// Sampling temperature, copied from Parameters.temperature.
public float temperature;
// True on the streaming path (Stream), false on the one-shot path (QuickRequestBlocking).
public bool stream;
// Omitted fields: int n, string stop, int max_tokens,
// float presence_penalty, float frequency_penalty;
}
/// <summary>
/// One message in API wire form. It is used both in requests and in response choices (message / delta).
/// </summary>
[Serializable]
private struct RoleContentMessage {
// "system", "user" or "assistant" when written by ConvertMessages.
public string role;
// Message text (or, in a streamed delta, the newly generated fragment).
public string content;
}
/// <summary>
/// A full response, or one streamed chunk, deserialized with <c>JsonUtility.FromJson</c>. Of its fields,
/// only <c>choices</c> is read by the code in this file.
/// </summary>
[Serializable]
private struct ResponseMessage {
public string id;
// NOTE(review): declared as string; confirm how JsonUtility maps the value the API actually sends here.
public string created;
// Completion candidates; the code only ever reads choices[0].
public ResponseChoice[] choices;
public string model;
public Usage usage;
}
/// <summary>
/// One completion candidate within a response.
/// </summary>
[Serializable]
private struct ResponseChoice {
public int index;
// Incremental content; read by the streaming path (Stream).
public RoleContentMessage delta;
// Complete message; read by the non-streaming path.
public RoleContentMessage message;
// The streaming path treats "length" as ChatGptErrorCodes.MaxTokensExceeded.
public string finish_reason;
}
/// <summary>
/// Token counts from the response. Declared so the JSON can be deserialized; no code in this file reads it.
/// </summary>
[Serializable]
private struct Usage {
public int completion_tokens;
public int prompt_tokens;
public int total_tokens;
}
}
}