summaryrefslogtreecommitdiff
path: root/stable_diffusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'stable_diffusion.cpp')
-rw-r--r--stable_diffusion.cpp248
1 files changed, 241 insertions, 7 deletions
diff --git a/stable_diffusion.cpp b/stable_diffusion.cpp
index 0c42f03..2eb230b 100644
--- a/stable_diffusion.cpp
+++ b/stable_diffusion.cpp
@@ -1,15 +1,121 @@
+struct curl_data {
+ char *response;
+ size_t size;
+};
+
+static size_t dumbcurlcallback(void *data, size_t size, size_t nmemb, void *userp)
+ {
+ size_t realsize = size * nmemb;
+ curl_data *mem = (curl_data *)userp;
+
+ Memory_Copy((uint8 *)&(mem->response[mem->size]), (uint8 *)data, realsize);
+ mem->size += realsize;
+ mem->response[mem->size] = 0;
+
+ return realsize;
+ }
+
+static void
+SD_JSONToSource(project_data *File, project_state *State, memory *Memory, void *JSONResponse, int Height, int Width)
+{
+ // We've gotta go from base64 to compressed PNG to the raw format we want.
+ uint8 *Data = (uint8 *)JSONResponse;
+ uint64 i = 0;
+ Assert(Data[2] != 'd');
+ while (Data[i] != '[')
+ i++;
+ i += 2;
+ Assert(Data[i] != '"');
+ uint64 UncompressedSize = Width * Height * 4;
+ uint64 c = 0;
+ while (Data[i+c] != ']')
+ c++;
+ uint64 MaxSize = Width * Height * 4;
+ void *PNGData = Memory_PushScratch(Memory, MaxSize);
+ uint64 PNGSize = 0;
+ base64_decode(&Data[i], c, (uint8 *)PNGData, &PNGSize);
+ int x, y, a;
+ void *RawData = stbi_load_from_memory((stbi_uc *)PNGData, PNGSize, &x, &y, &a, 4);
+ Assert(x == Width && y == Height);
+ Memory_PopScratch(Memory, MaxSize);
+ int SrcIdx = Source_Generate_Blank(File, State, Memory, Width, Height, 4);
+ block_source *Source = (block_source *)Memory_Block_AddressAtIndex(Memory, F_Sources, SrcIdx);
+ // Source->Type = source_type_principal_temp;
+ void *BitmapAddress = Memory_Block_AddressAtIndex(Memory, F_PrincipalBitmaps, Source->Bitmap_Index, 0);
+ Memory_Copy((uint8 *)BitmapAddress, (uint8 *)RawData, MaxSize);
+ stbi_image_free(RawData);
+}
static void
-SD_Txt2Txt(sd_state *SD)
+SD_ParseProgress(project_state *State, char *JSONInfo)
{
- char JSONPayload[1024];
- char CurlCommand[1024];
+ char P[4];
+ char P2[8];
+ int ProgLocation = 0;
+ while (!(JSONInfo[ProgLocation] == 's' && JSONInfo[ProgLocation+1] == 's')) {
+ ProgLocation++;
+ }
+ ProgLocation += 4;
+ int ETALocation = ProgLocation;
+ while (!(JSONInfo[ETALocation] == 'v' && JSONInfo[ETALocation+1] == 'e')) {
+ ETALocation++;
+ }
+ ETALocation += 4;
+ // Assert(JSONInfo[ProgLocation] >= '0' && JSONInfo[ProgLocation] <= '9');
+ // Assert(JSONInfo[ETALocation] >= '0' && JSONInfo[ETALocation] <= '9');
+ Memory_Copy((uint8 *)P, (uint8 *)&JSONInfo[12], 4);
+ Memory_Copy((uint8 *)P2, (uint8 *)&JSONInfo[32], 8);
+ real32 Percent = atof(P);
+ if (Percent > 0.0f) // occasionally returns negative
+ State->SDPercentDone = Percent;
+ real32 Time = atof(P2);
+ if (Time > 0.1f) // occasionally returns zero for some reason
+ State->SDTimeEstimate = Time;
+ // Assert(0);
+}
+
+static void
+JSON_AppendParam_String(char *String, uint64 *i, char *P1, char *P2)
+{
+ uint64 c = *i;
+ String[c++] = '"';
+ int a = 0;
+ while(P1[a] != '\0') {
+ String[c++] = P1[a++];
+ }
+ String[c++] = ':';
+ String[c++] = ' ';
+ String[c++] = '"';
+ a = 0;
+ while(P2[a] != '\0') {
+ String[c++] = P2[a++];
+ if (a > 64)
+ break;
+ }
+ c--;
+ String[c++] = '"';
+ String[c++] = ',';
+ String[c++] = '\n';
+ String[c++] = '\0';
+ *i = c;
+}
+
+static void
+SD_AssembleJSON(sd_state *SD, char *JSONPayload, void *Base64Bitmap = NULL)
+{
+ Arbitrary_Zero((uint8 *)JSONPayload, 1024);
+ // char CurlCommand[1024];
char *Test[] = { "prompt", "negative_prompt", "steps", "width", "height", "cfg_scale" };
void *Test2[6] = { (void *)SD->Prompt, (void *)SD->NegPrompt,
(void *)&SD->Steps, (void *)&SD->Width,
(void *)&SD->Height, (void *)&SD->CFG };
int Type[6] = { 0, 0, 1, 1, 1, 2};
- sprintf(JSONPayload, "%s{\n", JSONPayload);
+ JSONPayload[0] = '{';
+ JSONPayload[1] = '\n';
+ JSONPayload[2] = '\0';
+ uint64 i = 2;
+ if (SD->Mode)
+ JSON_AppendParam_String(JSONPayload, &i, "init_images", (char *)Base64Bitmap);
for (int i = 0; i < 6; i++) {
if (Type[i] == 0) {
sprintf(JSONPayload, "%s\"%s\": \"%s\",\n", JSONPayload, Test[i], (char *)Test2[i]);
@@ -21,8 +127,136 @@ SD_Txt2Txt(sd_state *SD)
Assert(0);
}
}
+ if (SD->Mode)
+ sprintf(JSONPayload, "%s\"%s\": %.2f,\n", JSONPayload, "denoising_strength", SD->DenoisingStrength);
+ sprintf(JSONPayload, "%s%s\n", JSONPayload, "\"sampler_index\": \"DDIM\"");
sprintf(JSONPayload, "%s}\n", JSONPayload);
- sprintf(CurlCommand, "curl -X POST -H 'Content-Type: application/json' -i '%s/sdapi/v1/txt2img' --data '%s'",
- SD->ServerAddress, JSONPayload);
- printf("%s\n", CurlCommand);
+ // sprintf(CurlCommand, "curl -X POST -H 'Content-Type: application/json' -i '%s/sdapi/v1/txt2img' --data '%s'", SD->ServerAddress, JSONPayload);
+ // printf("%s\n", CurlCommand);
};
+
+struct curl_state
+{
+ CURL *curl;
+ CURLM *curlm;
+ curl_slist *list = NULL;
+ curl_data CurlData;
+};
+
+static void
+Curl_Free(curl_state *Handle)
+{
+ curl_multi_remove_handle(Handle->curlm, Handle->curl);
+ curl_easy_cleanup(Handle->curl);
+ curl_multi_cleanup(Handle->curlm);
+}
+
+static bool32
+Curl_Check(curl_state *Handle)
+{
+ int IsActive;
+ CURLMcode mc = curl_multi_perform(Handle->curlm, &IsActive);
+ Assert(!mc);
+ if (!IsActive) {
+ int queue = 0;
+ CURLMsg *msg = curl_multi_info_read(Handle->curlm, &queue);
+ if (msg) {
+ CURL *e = msg->easy_handle;
+ Assert(e == Handle->curl);
+ Assert(msg->msg == CURLMSG_DONE);
+ if (!msg->data.result) {
+ return 1;
+ } else if (msg->data.result == CURLE_COULDNT_CONNECT) {
+ // printf("Active stable-diffusion-webui instance not found at URL.\n");
+ return -1;
+ } else {
+ // printf("curl error: %s!\n", curl_easy_strerror(msg->data.result));
+ return -1;
+ }
+ }
+ }
+ return 0;
+}
+
+static char *APIString[] = {"txt2img", "img2img" };
+
+static void
+Curl_GET_Init(curl_state *C, void *OutputData, char *JSONPayload, char *IP, bool32 API)
+{
+ C->list = curl_slist_append(C->list, "Content-Type: application/json");
+
+ C->curl = curl_easy_init();
+ Assert(C->curl);
+
+ char URL[512];
+ sprintf(URL, "%s/sdapi/v1/%s", IP, APIString[API]);
+ curl_easy_setopt(C->curl, CURLOPT_URL, URL);
+ curl_easy_setopt(C->curl, CURLOPT_HTTPHEADER, C->list);
+ curl_easy_setopt(C->curl, CURLOPT_POSTFIELDS, JSONPayload);
+
+ C->CurlData = { (char *)OutputData, 0 };
+
+ curl_easy_setopt(C->curl, CURLOPT_WRITEFUNCTION, dumbcurlcallback);
+ curl_easy_setopt(C->curl, CURLOPT_WRITEDATA, (void *)&C->CurlData);
+
+ C->curlm = curl_multi_init();
+ Assert(C->curlm);
+ curl_multi_add_handle(C->curlm, C->curl);
+
+ Curl_Check(C);
+}
+
+static void
+Curl_Prog_Init(curl_state *C, void *OutputData)
+{
+ C->curl = curl_easy_init();
+ Assert(C->curl);
+
+ C->CurlData = { (char *)OutputData, 0 };
+
+ curl_easy_setopt(C->curl, CURLOPT_URL, "http://127.0.0.1:7860/sdapi/v1/progress");
+ curl_easy_setopt(C->curl, CURLOPT_WRITEFUNCTION, dumbcurlcallback);
+ curl_easy_setopt(C->curl, CURLOPT_WRITEDATA, (void *)&C->CurlData);
+
+ C->curlm = curl_multi_init();
+ Assert(C->curlm);
+ curl_multi_add_handle(C->curlm, C->curl);
+}
+
+ /*
+ int num;
+ CURLMcode mc = curl_multi_perform(C->curlm, &num);
+ Assert(!mc);
+
+ int queue = 0;
+ CURLMsg *msg = curl_multi_info_read(C->curlm, &queue);
+ if (msg) {
+ CURL *e = msg->easy_handle;
+ if (msg->msg == CURLMSG_DONE) {
+ if (msg->data.result == CURLE_COULDNT_CONNECT) {
+ printf("Active stable-diffusion-webui instance not found at URL.\n");
+ } else {
+ printf("curl error: %s!\n", curl_easy_strerror(msg->data.result));
+ }
+ Assert(0);
+ }
+ }
+ */
+
+
+#if 0
+ res = curl_easy_perform(curl);
+ /* Check for errors */
+ if(res != CURLE_OK)
+ fprintf(stderr, "curl_easy_perform() failed: %s\n",
+ curl_easy_strerror(res));
+
+ /* Perform the request, res will get the return code */
+
+ /* always cleanup */
+ curl_slist_free_all(list);
+ curl_easy_cleanup(curl);
+
+
+ curl_global_cleanup();
+#endif